1use abi_stable::{
19 std_types::{ROption, RResult, RStr, RString, RVec},
20 StableAbi,
21};
22use accumulator::{FFI_Accumulator, ForeignAccumulator};
23use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
24use arrow::datatypes::{DataType, Field};
25use arrow::ffi::FFI_ArrowSchema;
26use arrow_schema::FieldRef;
27use datafusion::{
28 error::DataFusionError,
29 logical_expr::{
30 function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
31 type_coercion::functions::fields_with_aggregate_udf,
32 utils::AggregateOrderSensitivity,
33 Accumulator, GroupsAccumulator,
34 },
35};
36use datafusion::{
37 error::Result,
38 logical_expr::{AggregateUDF, AggregateUDFImpl, Signature},
39};
40use datafusion_proto_common::from_proto::parse_proto_fields_to_fields;
41use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator};
42use std::hash::{Hash, Hasher};
43use std::{ffi::c_void, sync::Arc};
44
45use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
46use crate::{
47 arrow_wrappers::WrappedSchema,
48 df_result, rresult, rresult_return,
49 util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
50 volatility::FFI_Volatility,
51};
52use prost::{DecodeError, Message};
53
54mod accumulator;
55mod accumulator_args;
56mod groups_accumulator;
57
58#[repr(C)]
60#[derive(Debug, StableAbi)]
61#[allow(non_camel_case_types)]
62pub struct FFI_AggregateUDF {
63 pub name: RString,
65
66 pub aliases: RVec<RString>,
68
69 pub volatility: FFI_Volatility,
71
72 pub return_field: unsafe extern "C" fn(
75 udaf: &Self,
76 arg_fields: RVec<WrappedSchema>,
77 ) -> RResult<WrappedSchema, RString>,
78
79 pub is_nullable: bool,
81
82 pub groups_accumulator_supported:
84 unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool,
85
86 pub accumulator: unsafe extern "C" fn(
88 udaf: &FFI_AggregateUDF,
89 args: FFI_AccumulatorArgs,
90 ) -> RResult<FFI_Accumulator, RString>,
91
92 pub create_sliding_accumulator:
94 unsafe extern "C" fn(
95 udaf: &FFI_AggregateUDF,
96 args: FFI_AccumulatorArgs,
97 ) -> RResult<FFI_Accumulator, RString>,
98
99 #[allow(clippy::type_complexity)]
101 pub state_fields: unsafe extern "C" fn(
102 udaf: &FFI_AggregateUDF,
103 name: &RStr,
104 input_fields: RVec<WrappedSchema>,
105 return_field: WrappedSchema,
106 ordering_fields: RVec<RVec<u8>>,
107 is_distinct: bool,
108 ) -> RResult<RVec<RVec<u8>>, RString>,
109
110 pub create_groups_accumulator:
112 unsafe extern "C" fn(
113 udaf: &FFI_AggregateUDF,
114 args: FFI_AccumulatorArgs,
115 ) -> RResult<FFI_GroupsAccumulator, RString>,
116
117 pub with_beneficial_ordering:
119 unsafe extern "C" fn(
120 udaf: &FFI_AggregateUDF,
121 beneficial_ordering: bool,
122 ) -> RResult<ROption<FFI_AggregateUDF>, RString>,
123
124 pub order_sensitivity:
126 unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity,
127
128 pub coerce_types: unsafe extern "C" fn(
133 udf: &Self,
134 arg_types: RVec<WrappedSchema>,
135 ) -> RResult<RVec<WrappedSchema>, RString>,
136
137 pub clone: unsafe extern "C" fn(udaf: &Self) -> Self,
140
141 pub release: unsafe extern "C" fn(udaf: &mut Self),
143
144 pub private_data: *mut c_void,
147}
148
149unsafe impl Send for FFI_AggregateUDF {}
150unsafe impl Sync for FFI_AggregateUDF {}
151
152pub struct AggregateUDFPrivateData {
153 pub udaf: Arc<AggregateUDF>,
154}
155
156impl FFI_AggregateUDF {
157 unsafe fn inner(&self) -> &Arc<AggregateUDF> {
158 let private_data = self.private_data as *const AggregateUDFPrivateData;
159 &(*private_data).udaf
160 }
161}
162
163unsafe extern "C" fn return_field_fn_wrapper(
164 udaf: &FFI_AggregateUDF,
165 arg_fields: RVec<WrappedSchema>,
166) -> RResult<WrappedSchema, RString> {
167 let udaf = udaf.inner();
168
169 let arg_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&arg_fields));
170
171 let return_field = udaf
172 .return_field(&arg_fields)
173 .and_then(|v| {
174 FFI_ArrowSchema::try_from(v.as_ref()).map_err(DataFusionError::from)
175 })
176 .map(WrappedSchema);
177
178 rresult!(return_field)
179}
180
181unsafe extern "C" fn accumulator_fn_wrapper(
182 udaf: &FFI_AggregateUDF,
183 args: FFI_AccumulatorArgs,
184) -> RResult<FFI_Accumulator, RString> {
185 let udaf = udaf.inner();
186
187 let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
188
189 rresult!(udaf
190 .accumulator(accumulator_args.into())
191 .map(FFI_Accumulator::from))
192}
193
194unsafe extern "C" fn create_sliding_accumulator_fn_wrapper(
195 udaf: &FFI_AggregateUDF,
196 args: FFI_AccumulatorArgs,
197) -> RResult<FFI_Accumulator, RString> {
198 let udaf = udaf.inner();
199
200 let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
201
202 rresult!(udaf
203 .create_sliding_accumulator(accumulator_args.into())
204 .map(FFI_Accumulator::from))
205}
206
207unsafe extern "C" fn create_groups_accumulator_fn_wrapper(
208 udaf: &FFI_AggregateUDF,
209 args: FFI_AccumulatorArgs,
210) -> RResult<FFI_GroupsAccumulator, RString> {
211 let udaf = udaf.inner();
212
213 let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
214
215 rresult!(udaf
216 .create_groups_accumulator(accumulator_args.into())
217 .map(FFI_GroupsAccumulator::from))
218}
219
220unsafe extern "C" fn groups_accumulator_supported_fn_wrapper(
221 udaf: &FFI_AggregateUDF,
222 args: FFI_AccumulatorArgs,
223) -> bool {
224 let udaf = udaf.inner();
225
226 ForeignAccumulatorArgs::try_from(args)
227 .map(|a| udaf.groups_accumulator_supported((&a).into()))
228 .unwrap_or_else(|e| {
229 log::warn!("Unable to parse accumulator args. {e}");
230 false
231 })
232}
233
234unsafe extern "C" fn with_beneficial_ordering_fn_wrapper(
235 udaf: &FFI_AggregateUDF,
236 beneficial_ordering: bool,
237) -> RResult<ROption<FFI_AggregateUDF>, RString> {
238 let udaf = udaf.inner().as_ref().clone();
239
240 let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering));
241 let result = rresult_return!(result
242 .map(|func| func.with_beneficial_ordering(beneficial_ordering))
243 .transpose())
244 .flatten()
245 .map(|func| FFI_AggregateUDF::from(Arc::new(func)));
246
247 RResult::ROk(result.into())
248}
249
250unsafe extern "C" fn state_fields_fn_wrapper(
251 udaf: &FFI_AggregateUDF,
252 name: &RStr,
253 input_fields: RVec<WrappedSchema>,
254 return_field: WrappedSchema,
255 ordering_fields: RVec<RVec<u8>>,
256 is_distinct: bool,
257) -> RResult<RVec<RVec<u8>>, RString> {
258 let udaf = udaf.inner();
259
260 let input_fields = &rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
261 let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
262
263 let ordering_fields = &rresult_return!(ordering_fields
264 .into_iter()
265 .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref()))
266 .collect::<std::result::Result<Vec<_>, DecodeError>>());
267
268 let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields))
269 .into_iter()
270 .map(Arc::new)
271 .collect::<Vec<_>>();
272
273 let args = StateFieldsArgs {
274 name: name.as_str(),
275 input_fields,
276 return_field,
277 ordering_fields,
278 is_distinct,
279 };
280
281 let state_fields = rresult_return!(udaf.state_fields(args));
282 let state_fields = rresult_return!(state_fields
283 .iter()
284 .map(|f| f.as_ref())
285 .map(datafusion_proto::protobuf::Field::try_from)
286 .map(|v| v.map_err(DataFusionError::from))
287 .collect::<Result<Vec<_>>>())
288 .into_iter()
289 .map(|field| field.encode_to_vec().into())
290 .collect();
291
292 RResult::ROk(state_fields)
293}
294
295unsafe extern "C" fn order_sensitivity_fn_wrapper(
296 udaf: &FFI_AggregateUDF,
297) -> FFI_AggregateOrderSensitivity {
298 udaf.inner().order_sensitivity().into()
299}
300
301unsafe extern "C" fn coerce_types_fn_wrapper(
302 udaf: &FFI_AggregateUDF,
303 arg_types: RVec<WrappedSchema>,
304) -> RResult<RVec<WrappedSchema>, RString> {
305 let udaf = udaf.inner();
306
307 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
308
309 let arg_fields = arg_types
310 .iter()
311 .map(|dt| Field::new("f", dt.clone(), true))
312 .map(Arc::new)
313 .collect::<Vec<_>>();
314 let return_types = rresult_return!(fields_with_aggregate_udf(&arg_fields, udaf))
315 .into_iter()
316 .map(|f| f.data_type().to_owned())
317 .collect::<Vec<_>>();
318
319 rresult!(vec_datatype_to_rvec_wrapped(&return_types))
320}
321
322unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) {
323 let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData);
324 drop(private_data);
325}
326
327unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF {
328 Arc::clone(udaf.inner()).into()
329}
330
331impl Clone for FFI_AggregateUDF {
332 fn clone(&self) -> Self {
333 unsafe { (self.clone)(self) }
334 }
335}
336
337impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
338 fn from(udaf: Arc<AggregateUDF>) -> Self {
339 let name = udaf.name().into();
340 let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect();
341 let is_nullable = udaf.is_nullable();
342 let volatility = udaf.signature().volatility.into();
343
344 let private_data = Box::new(AggregateUDFPrivateData { udaf });
345
346 Self {
347 name,
348 is_nullable,
349 volatility,
350 aliases,
351 return_field: return_field_fn_wrapper,
352 accumulator: accumulator_fn_wrapper,
353 create_sliding_accumulator: create_sliding_accumulator_fn_wrapper,
354 create_groups_accumulator: create_groups_accumulator_fn_wrapper,
355 groups_accumulator_supported: groups_accumulator_supported_fn_wrapper,
356 with_beneficial_ordering: with_beneficial_ordering_fn_wrapper,
357 state_fields: state_fields_fn_wrapper,
358 order_sensitivity: order_sensitivity_fn_wrapper,
359 coerce_types: coerce_types_fn_wrapper,
360 clone: clone_fn_wrapper,
361 release: release_fn_wrapper,
362 private_data: Box::into_raw(private_data) as *mut c_void,
363 }
364 }
365}
366
367impl Drop for FFI_AggregateUDF {
368 fn drop(&mut self) {
369 unsafe { (self.release)(self) }
370 }
371}
372
373#[derive(Debug)]
380pub struct ForeignAggregateUDF {
381 signature: Signature,
382 aliases: Vec<String>,
383 udaf: FFI_AggregateUDF,
384}
385
386unsafe impl Send for ForeignAggregateUDF {}
387unsafe impl Sync for ForeignAggregateUDF {}
388
389impl PartialEq for ForeignAggregateUDF {
390 fn eq(&self, other: &Self) -> bool {
391 std::ptr::eq(self, other)
393 }
394}
395impl Eq for ForeignAggregateUDF {}
396impl Hash for ForeignAggregateUDF {
397 fn hash<H: Hasher>(&self, state: &mut H) {
398 std::ptr::hash(self, state)
399 }
400}
401
402impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF {
403 type Error = DataFusionError;
404
405 fn try_from(udaf: &FFI_AggregateUDF) -> Result<Self, Self::Error> {
406 let signature = Signature::user_defined((&udaf.volatility).into());
407 let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect();
408
409 Ok(Self {
410 udaf: udaf.clone(),
411 signature,
412 aliases,
413 })
414 }
415}
416
417impl AggregateUDFImpl for ForeignAggregateUDF {
418 fn as_any(&self) -> &dyn std::any::Any {
419 self
420 }
421
422 fn name(&self) -> &str {
423 self.udaf.name.as_str()
424 }
425
426 fn signature(&self) -> &Signature {
427 &self.signature
428 }
429
430 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
431 unimplemented!()
432 }
433
434 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
435 let arg_fields = vec_fieldref_to_rvec_wrapped(arg_fields)?;
436
437 let result = unsafe { (self.udaf.return_field)(&self.udaf, arg_fields) };
438
439 let result = df_result!(result);
440
441 result.and_then(|r| {
442 Field::try_from(&r.0)
443 .map(Arc::new)
444 .map_err(DataFusionError::from)
445 })
446 }
447
448 fn is_nullable(&self) -> bool {
449 self.udaf.is_nullable
450 }
451
452 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
453 let args = acc_args.try_into()?;
454 unsafe {
455 df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| {
456 Box::new(ForeignAccumulator::from(accum)) as Box<dyn Accumulator>
457 })
458 }
459 }
460
461 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
462 unsafe {
463 let name = RStr::from_str(args.name);
464 let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?;
465 let return_field =
466 WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
467 let ordering_fields = args
468 .ordering_fields
469 .iter()
470 .map(|f| f.as_ref())
471 .map(datafusion_proto::protobuf::Field::try_from)
472 .map(|v| v.map_err(DataFusionError::from))
473 .collect::<Result<Vec<_>>>()?
474 .into_iter()
475 .map(|proto_field| proto_field.encode_to_vec().into())
476 .collect();
477
478 let fields = df_result!((self.udaf.state_fields)(
479 &self.udaf,
480 &name,
481 input_fields,
482 return_field,
483 ordering_fields,
484 args.is_distinct
485 ))?;
486 let fields = fields
487 .into_iter()
488 .map(|field_bytes| {
489 datafusion_proto_common::Field::decode(field_bytes.as_ref())
490 .map_err(|e| DataFusionError::Execution(e.to_string()))
491 })
492 .collect::<Result<Vec<_>>>()?;
493
494 parse_proto_fields_to_fields(fields.iter())
495 .map(|fields| fields.into_iter().map(Arc::new).collect())
496 .map_err(|e| DataFusionError::Execution(e.to_string()))
497 }
498 }
499
500 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
501 let args = match FFI_AccumulatorArgs::try_from(args) {
502 Ok(v) => v,
503 Err(e) => {
504 log::warn!("Attempting to convert accumulator arguments: {e}");
505 return false;
506 }
507 };
508
509 unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) }
510 }
511
512 fn create_groups_accumulator(
513 &self,
514 args: AccumulatorArgs,
515 ) -> Result<Box<dyn GroupsAccumulator>> {
516 let args = FFI_AccumulatorArgs::try_from(args)?;
517
518 unsafe {
519 df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map(
520 |accum| {
521 Box::new(ForeignGroupsAccumulator::from(accum))
522 as Box<dyn GroupsAccumulator>
523 },
524 )
525 }
526 }
527
528 fn aliases(&self) -> &[String] {
529 &self.aliases
530 }
531
532 fn create_sliding_accumulator(
533 &self,
534 args: AccumulatorArgs,
535 ) -> Result<Box<dyn Accumulator>> {
536 let args = args.try_into()?;
537 unsafe {
538 df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map(
539 |accum| Box::new(ForeignAccumulator::from(accum)) as Box<dyn Accumulator>,
540 )
541 }
542 }
543
544 fn with_beneficial_ordering(
545 self: Arc<Self>,
546 beneficial_ordering: bool,
547 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
548 unsafe {
549 let result = df_result!((self.udaf.with_beneficial_ordering)(
550 &self.udaf,
551 beneficial_ordering
552 ))?
553 .into_option();
554
555 let result = result
556 .map(|func| ForeignAggregateUDF::try_from(&func))
557 .transpose()?;
558
559 Ok(result.map(|func| Arc::new(func) as Arc<dyn AggregateUDFImpl>))
560 }
561 }
562
563 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
564 unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() }
565 }
566
567 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
568 None
569 }
570
571 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
572 unsafe {
573 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
574 let result_types =
575 df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?;
576 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
577 }
578 }
579}
580
581#[repr(C)]
582#[derive(Debug, StableAbi)]
583#[allow(non_camel_case_types)]
584pub enum FFI_AggregateOrderSensitivity {
585 Insensitive,
586 HardRequirement,
587 SoftRequirement,
588 Beneficial,
589}
590
591impl From<FFI_AggregateOrderSensitivity> for AggregateOrderSensitivity {
592 fn from(value: FFI_AggregateOrderSensitivity) -> Self {
593 match value {
594 FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive,
595 FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
596 FFI_AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
597 FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial,
598 }
599 }
600}
601
602impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
603 fn from(value: AggregateOrderSensitivity) -> Self {
604 match value {
605 AggregateOrderSensitivity::Insensitive => Self::Insensitive,
606 AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
607 AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
608 AggregateOrderSensitivity::Beneficial => Self::Beneficial,
609 }
610 }
611}
612
#[cfg(test)]
mod tests {
    use arrow::datatypes::Schema;
    use datafusion::{
        common::create_array, functions_aggregate::sum::Sum,
        physical_expr::PhysicalSortExpr, physical_plan::expressions::col,
        scalar::ScalarValue,
    };
    use std::any::Any;
    use std::collections::HashMap;

    use super::*;

    /// `Sum` variant whose return field is its first argument field, so the
    /// argument's metadata shows up on the return field.
    #[derive(Default, Debug, Hash, Eq, PartialEq)]
    struct SumWithCopiedMetadata {
        inner: Sum,
    }

    impl AggregateUDFImpl for SumWithCopiedMetadata {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.inner.name()
        }

        fn signature(&self) -> &Signature {
            self.inner.signature()
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
            Ok(Arc::clone(&arg_fields[0]))
        }

        fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
            self.inner.accumulator(acc_args)
        }
    }

    /// Sends `original_udaf` through the full round trip
    /// (`AggregateUDF` -> `FFI_AggregateUDF` -> `ForeignAggregateUDF`) and
    /// returns the foreign wrapper as an `AggregateUDF`.
    fn create_test_foreign_udaf(
        original_udaf: impl AggregateUDFImpl + 'static,
    ) -> Result<AggregateUDF> {
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();

        let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
        Ok(foreign_udaf.into())
    }

    #[test]
    fn test_round_trip_udaf() -> Result<()> {
        let original_udaf = Sum::new();
        let original_name = original_udaf.name().to_owned();

        let foreign_udaf = create_test_foreign_udaf(original_udaf)?;

        assert_eq!(original_name, foreign_udaf.name());
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_aliases() -> Result<()> {
        let foreign_udaf =
            create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]);

        // The alias must be visible on the aliased function.
        assert!(foreign_udaf
            .aliases()
            .iter()
            .any(|alias| alias == "my_function"));

        // The aliased function must still forward calls to the foreign UDAF.
        let return_field =
            foreign_udaf.return_field(&[Field::new("a", DataType::Float64, true).into()])?;
        let return_type = return_field.data_type();
        assert_eq!(return_type, &DataType::Float64);
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &[col("a", &schema)?],
        };
        let mut accumulator = foreign_udaf.accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    #[test]
    fn test_round_trip_udaf_metadata() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(SumWithCopiedMetadata::default())?;

        let metadata: HashMap<String, String> =
            [("a_key".to_string(), "a_value".to_string())]
                .into_iter()
                .collect();
        let input_field = Arc::new(
            Field::new("a", DataType::Float64, false).with_metadata(metadata.clone()),
        );
        let return_field = foreign_udaf.return_field(&[input_field])?;

        assert_eq!(&metadata, return_field.metadata());
        Ok(())
    }

    #[test]
    fn test_beneficial_ordering() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(
            datafusion::functions_aggregate::first_last::FirstValue::new(),
        )?;

        let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap();

        assert_eq!(
            foreign_udaf.order_sensitivity(),
            AggregateOrderSensitivity::Beneficial
        );

        let a_field = Arc::new(Field::new("a", DataType::Float64, true));
        let state_fields = foreign_udaf.state_fields(StateFieldsArgs {
            name: "a",
            input_fields: &[Field::new("f", DataType::Float64, true).into()],
            return_field: Field::new("f", DataType::Float64, true).into(),
            ordering_fields: &[Arc::clone(&a_field)],
            is_distinct: false,
        })?;

        assert_eq!(state_fields.len(), 3);
        assert_eq!(state_fields[1], a_field);
        Ok(())
    }

    #[test]
    fn test_sliding_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: false,
            exprs: &[col("a", &schema)?],
        };

        let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    /// Asserts that `sensitivity` survives the conversion to its FFI form and
    /// back. Not a test itself; called once per variant below.
    fn assert_order_sensitivity_round_trip(sensitivity: AggregateOrderSensitivity) {
        let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into();
        let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into();

        assert_eq!(sensitivity, round_trip_sensitivity);
    }

    #[test]
    fn test_round_trip_all_order_sensitivities() {
        assert_order_sensitivity_round_trip(AggregateOrderSensitivity::Insensitive);
        assert_order_sensitivity_round_trip(AggregateOrderSensitivity::HardRequirement);
        assert_order_sensitivity_round_trip(AggregateOrderSensitivity::SoftRequirement);
        assert_order_sensitivity_round_trip(AggregateOrderSensitivity::Beneficial);
    }
}