1use accumulator::FFI_Accumulator;
19use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
20use arrow::datatypes::{DataType, Field};
21use arrow::ffi::FFI_ArrowSchema;
22use arrow_schema::FieldRef;
23use datafusion_common::{DataFusionError, Result, ffi_datafusion_err};
24use datafusion_expr::function::AggregateFunctionSimplification;
25use datafusion_expr::type_coercion::functions::fields_with_udf;
26use datafusion_expr::{
27 Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
28};
29use datafusion_functions_aggregate_common::accumulator::{
30 AccumulatorArgs, StateFieldsArgs,
31};
32use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
33use datafusion_proto_common::from_proto::parse_proto_fields_to_fields;
34use groups_accumulator::FFI_GroupsAccumulator;
35use prost::{DecodeError, Message};
36
37use stabby::str::Str as SStr;
38use stabby::string::String as SString;
39use stabby::vec::Vec as SVec;
40use std::ffi::c_void;
41use std::hash::{Hash, Hasher};
42use std::sync::Arc;
43
44use crate::arrow_wrappers::WrappedSchema;
45use crate::util::{
46 FFI_Option, FFI_Result, rvec_wrapped_to_vec_datatype, rvec_wrapped_to_vec_fieldref,
47 vec_datatype_to_rvec_wrapped, vec_fieldref_to_rvec_wrapped,
48};
49use crate::volatility::FFI_Volatility;
50use crate::{df_result, sresult, sresult_return};
51
52mod accumulator;
53mod accumulator_args;
54mod groups_accumulator;
55
56#[repr(C)]
58#[derive(Debug)]
59pub struct FFI_AggregateUDF {
60 pub name: SString,
62
63 pub aliases: SVec<SString>,
65
66 pub volatility: FFI_Volatility,
68
69 pub return_field: unsafe extern "C" fn(
72 udaf: &Self,
73 arg_fields: SVec<WrappedSchema>,
74 ) -> FFI_Result<WrappedSchema>,
75
76 pub is_nullable: bool,
78
79 pub groups_accumulator_supported:
81 unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool,
82
83 pub accumulator: unsafe extern "C" fn(
85 udaf: &FFI_AggregateUDF,
86 args: FFI_AccumulatorArgs,
87 ) -> FFI_Result<FFI_Accumulator>,
88
89 pub create_sliding_accumulator: unsafe extern "C" fn(
91 udaf: &FFI_AggregateUDF,
92 args: FFI_AccumulatorArgs,
93 )
94 -> FFI_Result<FFI_Accumulator>,
95
96 pub state_fields: unsafe extern "C" fn(
98 udaf: &FFI_AggregateUDF,
99 name: &SStr,
100 input_fields: SVec<WrappedSchema>,
101 return_field: WrappedSchema,
102 ordering_fields: SVec<SVec<u8>>,
103 is_distinct: bool,
104 ) -> FFI_Result<SVec<SVec<u8>>>,
105
106 pub create_groups_accumulator:
108 unsafe extern "C" fn(
109 udaf: &FFI_AggregateUDF,
110 args: FFI_AccumulatorArgs,
111 ) -> FFI_Result<FFI_GroupsAccumulator>,
112
113 pub with_beneficial_ordering:
115 unsafe extern "C" fn(
116 udaf: &FFI_AggregateUDF,
117 beneficial_ordering: bool,
118 ) -> FFI_Result<FFI_Option<FFI_AggregateUDF>>,
119
120 pub order_sensitivity:
122 unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity,
123
124 pub coerce_types: unsafe extern "C" fn(
129 udf: &Self,
130 arg_types: SVec<WrappedSchema>,
131 ) -> FFI_Result<SVec<WrappedSchema>>,
132
133 pub clone: unsafe extern "C" fn(udaf: &Self) -> Self,
136
137 pub release: unsafe extern "C" fn(udaf: &mut Self),
139
140 pub private_data: *mut c_void,
143
144 pub library_marker_id: extern "C" fn() -> usize,
148}
149
150unsafe impl Send for FFI_AggregateUDF {}
151unsafe impl Sync for FFI_AggregateUDF {}
152
153pub struct AggregateUDFPrivateData {
154 pub udaf: Arc<AggregateUDF>,
155}
156
157impl FFI_AggregateUDF {
158 unsafe fn inner(&self) -> &Arc<AggregateUDF> {
159 unsafe {
160 let private_data = self.private_data as *const AggregateUDFPrivateData;
161 &(*private_data).udaf
162 }
163 }
164}
165
166unsafe extern "C" fn return_field_fn_wrapper(
167 udaf: &FFI_AggregateUDF,
168 arg_fields: SVec<WrappedSchema>,
169) -> FFI_Result<WrappedSchema> {
170 unsafe {
171 let udaf = udaf.inner();
172
173 let arg_fields = sresult_return!(rvec_wrapped_to_vec_fieldref(&arg_fields));
174
175 let return_field = udaf
176 .return_field(&arg_fields)
177 .and_then(|v| {
178 FFI_ArrowSchema::try_from(v.as_ref()).map_err(DataFusionError::from)
179 })
180 .map(WrappedSchema);
181
182 sresult!(return_field)
183 }
184}
185
186unsafe extern "C" fn accumulator_fn_wrapper(
187 udaf: &FFI_AggregateUDF,
188 args: FFI_AccumulatorArgs,
189) -> FFI_Result<FFI_Accumulator> {
190 unsafe {
191 let udaf = udaf.inner();
192
193 let accumulator_args = &sresult_return!(ForeignAccumulatorArgs::try_from(args));
194
195 sresult!(
196 udaf.accumulator(accumulator_args.into())
197 .map(FFI_Accumulator::from)
198 )
199 }
200}
201
202unsafe extern "C" fn create_sliding_accumulator_fn_wrapper(
203 udaf: &FFI_AggregateUDF,
204 args: FFI_AccumulatorArgs,
205) -> FFI_Result<FFI_Accumulator> {
206 unsafe {
207 let udaf = udaf.inner();
208
209 let accumulator_args = &sresult_return!(ForeignAccumulatorArgs::try_from(args));
210
211 sresult!(
212 udaf.create_sliding_accumulator(accumulator_args.into())
213 .map(FFI_Accumulator::from)
214 )
215 }
216}
217
218unsafe extern "C" fn create_groups_accumulator_fn_wrapper(
219 udaf: &FFI_AggregateUDF,
220 args: FFI_AccumulatorArgs,
221) -> FFI_Result<FFI_GroupsAccumulator> {
222 unsafe {
223 let udaf = udaf.inner();
224
225 let accumulator_args = &sresult_return!(ForeignAccumulatorArgs::try_from(args));
226
227 sresult!(
228 udaf.create_groups_accumulator(accumulator_args.into())
229 .map(FFI_GroupsAccumulator::from)
230 )
231 }
232}
233
234unsafe extern "C" fn groups_accumulator_supported_fn_wrapper(
235 udaf: &FFI_AggregateUDF,
236 args: FFI_AccumulatorArgs,
237) -> bool {
238 unsafe {
239 let udaf = udaf.inner();
240
241 ForeignAccumulatorArgs::try_from(args)
242 .map(|a| udaf.groups_accumulator_supported((&a).into()))
243 .unwrap_or_else(|e| {
244 log::warn!("Unable to parse accumulator args. {e}");
245 false
246 })
247 }
248}
249
250unsafe extern "C" fn with_beneficial_ordering_fn_wrapper(
251 udaf: &FFI_AggregateUDF,
252 beneficial_ordering: bool,
253) -> FFI_Result<FFI_Option<FFI_AggregateUDF>> {
254 unsafe {
255 let udaf = udaf.inner().as_ref().clone();
256
257 let result = sresult_return!(udaf.with_beneficial_ordering(beneficial_ordering));
258 let result = sresult_return!(
259 result
260 .map(|func| func.with_beneficial_ordering(beneficial_ordering))
261 .transpose()
262 )
263 .flatten()
264 .map(|func| FFI_AggregateUDF::from(Arc::new(func)));
265
266 FFI_Result::Ok(result.into())
267 }
268}
269
270unsafe extern "C" fn state_fields_fn_wrapper(
271 udaf: &FFI_AggregateUDF,
272 name: &SStr,
273 input_fields: SVec<WrappedSchema>,
274 return_field: WrappedSchema,
275 ordering_fields: SVec<SVec<u8>>,
276 is_distinct: bool,
277) -> FFI_Result<SVec<SVec<u8>>> {
278 unsafe {
279 let udaf = udaf.inner();
280
281 let input_fields = &sresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
282 let return_field = sresult_return!(Field::try_from(&return_field.0)).into();
283
284 let ordering_fields = &sresult_return!(
285 ordering_fields
286 .into_iter()
287 .map(|field_bytes| datafusion_proto_common::Field::decode(
288 field_bytes.as_ref()
289 ))
290 .collect::<std::result::Result<Vec<_>, DecodeError>>()
291 );
292
293 let ordering_fields =
294 &sresult_return!(parse_proto_fields_to_fields(ordering_fields))
295 .into_iter()
296 .map(Arc::new)
297 .collect::<Vec<_>>();
298
299 let args = StateFieldsArgs {
300 name: name.as_str(),
301 input_fields,
302 return_field,
303 ordering_fields,
304 is_distinct,
305 };
306
307 let state_fields = sresult_return!(udaf.state_fields(args));
308 let state_fields = sresult_return!(
309 state_fields
310 .iter()
311 .map(|f| f.as_ref())
312 .map(datafusion_proto::protobuf::Field::try_from)
313 .map(|v| v.map_err(DataFusionError::from))
314 .collect::<Result<Vec<_>>>()
315 )
316 .into_iter()
317 .map(|field| field.encode_to_vec().into_iter().collect())
318 .collect();
319
320 FFI_Result::Ok(state_fields)
321 }
322}
323
324unsafe extern "C" fn order_sensitivity_fn_wrapper(
325 udaf: &FFI_AggregateUDF,
326) -> FFI_AggregateOrderSensitivity {
327 unsafe { udaf.inner().order_sensitivity().into() }
328}
329
330unsafe extern "C" fn coerce_types_fn_wrapper(
331 udaf: &FFI_AggregateUDF,
332 arg_types: SVec<WrappedSchema>,
333) -> FFI_Result<SVec<WrappedSchema>> {
334 unsafe {
335 let udaf = udaf.inner();
336
337 let arg_types = sresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
338
339 let arg_fields = arg_types
340 .iter()
341 .map(|dt| Field::new("f", dt.clone(), true))
342 .map(Arc::new)
343 .collect::<Vec<_>>();
344 let return_types = sresult_return!(fields_with_udf(&arg_fields, udaf.as_ref()))
345 .into_iter()
346 .map(|f| f.data_type().to_owned())
347 .collect::<Vec<_>>();
348
349 sresult!(vec_datatype_to_rvec_wrapped(&return_types))
350 }
351}
352
353unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) {
354 unsafe {
355 debug_assert!(!udaf.private_data.is_null());
356 let private_data =
357 Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData);
358 drop(private_data);
359 udaf.private_data = std::ptr::null_mut();
360 }
361}
362
363unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF {
364 unsafe { Arc::clone(udaf.inner()).into() }
365}
366
367impl Clone for FFI_AggregateUDF {
368 fn clone(&self) -> Self {
369 unsafe { (self.clone)(self) }
370 }
371}
372
373impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
374 fn from(udaf: Arc<AggregateUDF>) -> Self {
375 if let Some(udaf) = udaf.inner().downcast_ref::<ForeignAggregateUDF>() {
376 return udaf.udaf.clone();
377 }
378
379 let name = udaf.name().into();
380 let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect();
381 let is_nullable = udaf.is_nullable();
382 let volatility = udaf.signature().volatility.into();
383
384 let private_data = Box::new(AggregateUDFPrivateData { udaf });
385
386 Self {
387 name,
388 is_nullable,
389 volatility,
390 aliases,
391 return_field: return_field_fn_wrapper,
392 accumulator: accumulator_fn_wrapper,
393 create_sliding_accumulator: create_sliding_accumulator_fn_wrapper,
394 create_groups_accumulator: create_groups_accumulator_fn_wrapper,
395 groups_accumulator_supported: groups_accumulator_supported_fn_wrapper,
396 with_beneficial_ordering: with_beneficial_ordering_fn_wrapper,
397 state_fields: state_fields_fn_wrapper,
398 order_sensitivity: order_sensitivity_fn_wrapper,
399 coerce_types: coerce_types_fn_wrapper,
400 clone: clone_fn_wrapper,
401 release: release_fn_wrapper,
402 private_data: Box::into_raw(private_data) as *mut c_void,
403 library_marker_id: crate::get_library_marker_id,
404 }
405 }
406}
407
408impl Drop for FFI_AggregateUDF {
409 fn drop(&mut self) {
410 unsafe { (self.release)(self) }
411 }
412}
413
414#[derive(Debug)]
421pub struct ForeignAggregateUDF {
422 signature: Signature,
423 aliases: Vec<String>,
424 udaf: FFI_AggregateUDF,
425}
426
427unsafe impl Send for ForeignAggregateUDF {}
428unsafe impl Sync for ForeignAggregateUDF {}
429
430impl PartialEq for ForeignAggregateUDF {
431 fn eq(&self, other: &Self) -> bool {
432 std::ptr::eq(self, other)
434 }
435}
436impl Eq for ForeignAggregateUDF {}
437impl Hash for ForeignAggregateUDF {
438 fn hash<H: Hasher>(&self, state: &mut H) {
439 std::ptr::hash(self, state)
440 }
441}
442
443impl From<&FFI_AggregateUDF> for Arc<dyn AggregateUDFImpl> {
444 fn from(udaf: &FFI_AggregateUDF) -> Self {
445 if (udaf.library_marker_id)() == crate::get_library_marker_id() {
446 return Arc::clone(unsafe { udaf.inner().inner() });
447 }
448
449 let signature = Signature::user_defined((&udaf.volatility).into());
450 let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect();
451
452 Arc::new(ForeignAggregateUDF {
453 udaf: udaf.clone(),
454 signature,
455 aliases,
456 })
457 }
458}
459
460impl AggregateUDFImpl for ForeignAggregateUDF {
461 fn name(&self) -> &str {
462 self.udaf.name.as_str()
463 }
464
465 fn signature(&self) -> &Signature {
466 &self.signature
467 }
468
469 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
470 unimplemented!()
471 }
472
473 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
474 let arg_fields = vec_fieldref_to_rvec_wrapped(arg_fields)?;
475
476 let result = unsafe { (self.udaf.return_field)(&self.udaf, arg_fields) };
477
478 let result = df_result!(result);
479
480 result.and_then(|r| {
481 Field::try_from(&r.0)
482 .map(Arc::new)
483 .map_err(DataFusionError::from)
484 })
485 }
486
487 fn is_nullable(&self) -> bool {
488 self.udaf.is_nullable
489 }
490
491 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
492 let args = acc_args.try_into()?;
493 unsafe {
494 df_result!((self.udaf.accumulator)(&self.udaf, args))
495 .map(<Box<dyn Accumulator>>::from)
496 }
497 }
498
499 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
500 unsafe {
501 let name = SStr::from(args.name);
502 let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?;
503 let return_field =
504 WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
505 let ordering_fields = args
506 .ordering_fields
507 .iter()
508 .map(|f| f.as_ref())
509 .map(datafusion_proto::protobuf::Field::try_from)
510 .map(|v| v.map_err(DataFusionError::from))
511 .collect::<Result<Vec<_>>>()?
512 .into_iter()
513 .map(|proto_field| proto_field.encode_to_vec().into_iter().collect())
514 .collect();
515
516 let fields = df_result!((self.udaf.state_fields)(
517 &self.udaf,
518 &name,
519 input_fields,
520 return_field,
521 ordering_fields,
522 args.is_distinct
523 ))?;
524 let fields = fields
525 .into_iter()
526 .map(|field_bytes| {
527 datafusion_proto_common::Field::decode(field_bytes.as_ref())
528 .map_err(|e| ffi_datafusion_err!("{e}"))
529 })
530 .collect::<Result<Vec<_>>>()?;
531
532 parse_proto_fields_to_fields(fields.iter())
533 .map(|fields| fields.into_iter().map(Arc::new).collect())
534 .map_err(|e| ffi_datafusion_err!("{e}"))
535 }
536 }
537
538 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
539 let args = match FFI_AccumulatorArgs::try_from(args) {
540 Ok(v) => v,
541 Err(e) => {
542 log::warn!("Attempting to convert accumulator arguments: {e}");
543 return false;
544 }
545 };
546
547 unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) }
548 }
549
550 fn create_groups_accumulator(
551 &self,
552 args: AccumulatorArgs,
553 ) -> Result<Box<dyn GroupsAccumulator>> {
554 let args = FFI_AccumulatorArgs::try_from(args)?;
555
556 unsafe {
557 df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args))
558 .map(<Box<dyn GroupsAccumulator>>::from)
559 }
560 }
561
562 fn aliases(&self) -> &[String] {
563 &self.aliases
564 }
565
566 fn create_sliding_accumulator(
567 &self,
568 args: AccumulatorArgs,
569 ) -> Result<Box<dyn Accumulator>> {
570 let args = args.try_into()?;
571 unsafe {
572 df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args))
573 .map(<Box<dyn Accumulator>>::from)
574 }
575 }
576
577 fn with_beneficial_ordering(
578 self: Arc<Self>,
579 beneficial_ordering: bool,
580 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
581 unsafe {
582 let result = df_result!((self.udaf.with_beneficial_ordering)(
583 &self.udaf,
584 beneficial_ordering
585 ))?
586 .into_option();
587
588 let result = result.map(|func| <Arc<dyn AggregateUDFImpl>>::from(&func));
589
590 Ok(result)
591 }
592 }
593
594 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
595 unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() }
596 }
597
598 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
599 None
600 }
601
602 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
603 unsafe {
604 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
605 let result_types =
606 df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?;
607 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
608 }
609 }
610}
611
612#[repr(C)]
613#[derive(Debug)]
614pub enum FFI_AggregateOrderSensitivity {
615 Insensitive,
616 HardRequirement,
617 SoftRequirement,
618 Beneficial,
619}
620
621impl From<FFI_AggregateOrderSensitivity> for AggregateOrderSensitivity {
622 fn from(value: FFI_AggregateOrderSensitivity) -> Self {
623 match value {
624 FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive,
625 FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
626 FFI_AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
627 FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial,
628 }
629 }
630}
631
632impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
633 fn from(value: AggregateOrderSensitivity) -> Self {
634 match value {
635 AggregateOrderSensitivity::Insensitive => Self::Insensitive,
636 AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
637 AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
638 AggregateOrderSensitivity::Beneficial => Self::Beneficial,
639 }
640 }
641}
642
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use arrow::datatypes::Schema;
    use datafusion::common::create_array;
    use datafusion::functions_aggregate::sum::Sum;
    use datafusion::physical_expr::PhysicalSortExpr;
    use datafusion::physical_plan::expressions::col;
    use datafusion::scalar::ScalarValue;

    use super::*;

    /// A `Sum` variant whose return field is the first argument field verbatim,
    /// so any metadata on the input field is carried to the output.
    #[derive(Default, Debug, Hash, Eq, PartialEq)]
    struct SumWithCopiedMetadata {
        inner: Sum,
    }

    impl AggregateUDFImpl for SumWithCopiedMetadata {
        fn name(&self) -> &str {
            self.inner.name()
        }

        fn signature(&self) -> &Signature {
            self.inner.signature()
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
            Ok(Arc::clone(&arg_fields[0]))
        }

        fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
            self.inner.accumulator(acc_args)
        }
    }

    /// Exports `original_udaf` through FFI and imports it back with a mocked
    /// foreign library marker, forcing every call across the FFI layer instead
    /// of the same-library bypass.
    fn create_test_foreign_udaf(
        original_udaf: impl AggregateUDFImpl + 'static,
    ) -> Result<AggregateUDF> {
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        Ok(AggregateUDF::new_from_shared_impl(foreign_udaf))
    }

    #[test]
    fn test_round_trip_udaf() -> Result<()> {
        let original_udaf = Sum::new();
        let original_name = original_udaf.name().to_owned();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf);

        assert_eq!(original_name, foreign_udaf.name());
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_aliases() -> Result<()> {
        // Aliases attached before crossing the FFI boundary must survive it.
        let original_udaf =
            Arc::new(AggregateUDF::from(Sum::new()).with_aliases(["my_function"]));

        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf);

        assert!(
            foreign_udaf
                .aliases()
                .iter()
                .any(|alias| alias.as_str() == "my_function")
        );

        let return_field =
            foreign_udaf.return_field(&[Field::new("a", DataType::Float64, true).into()])?;
        assert_eq!(return_field.data_type(), &DataType::Float64);
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Float64, true).into()],
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &[col("a", &schema)?],
        };
        let mut accumulator = foreign_udaf.accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    #[test]
    fn test_round_trip_udaf_metadata() -> Result<()> {
        // Use the mocked foreign marker so `return_field` really crosses the FFI
        // boundary; with the real marker the same-library bypass would return the
        // original implementation and metadata transfer would go untested.
        let foreign_udaf = create_test_foreign_udaf(SumWithCopiedMetadata::default())?;

        let metadata: HashMap<String, String> =
            [("a_key".to_string(), "a_value".to_string())]
                .into_iter()
                .collect();
        let input_field = Arc::new(
            Field::new("a", DataType::Float64, false).with_metadata(metadata.clone()),
        );
        let return_field = foreign_udaf.return_field(&[input_field])?;

        assert_eq!(&metadata, return_field.metadata());
        Ok(())
    }

    #[test]
    fn test_beneficial_ordering() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(
            datafusion::functions_aggregate::first_last::FirstValue::new(),
        )?;

        let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap();

        assert_eq!(
            foreign_udaf.order_sensitivity(),
            AggregateOrderSensitivity::Beneficial
        );

        let a_field = Arc::new(Field::new("a", DataType::Float64, true));
        let state_fields = foreign_udaf.state_fields(StateFieldsArgs {
            name: "a",
            input_fields: &[Field::new("f", DataType::Float64, true).into()],
            return_field: Field::new("f", DataType::Float64, true).into(),
            ordering_fields: &[Arc::clone(&a_field)],
            is_distinct: false,
        })?;

        assert_eq!(state_fields.len(), 3);
        assert_eq!(state_fields[1], a_field);
        Ok(())
    }

    #[test]
    fn test_sliding_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Float64, true).into()],
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: false,
            exprs: &[col("a", &schema)?],
        };

        let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) {
        let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into();
        let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into();

        assert_eq!(sensitivity, round_trip_sensitivity);
    }

    #[test]
    fn test_round_trip_all_order_sensitivities() {
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::SoftRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial);
    }

    #[test]
    fn test_ffi_udaf_local_bypass() -> Result<()> {
        let original_udaf = Sum::new();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let mut ffi_udaf = FFI_AggregateUDF::from(original_udaf);

        // Same library marker: the original implementation comes back directly.
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&ffi_udaf).into();
        assert!(foreign_udaf.is::<Sum>());

        // Mocked foreign marker: the struct is wrapped in `ForeignAggregateUDF`.
        ffi_udaf.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&ffi_udaf).into();
        assert!(foreign_udaf.is::<ForeignAggregateUDF>());

        Ok(())
    }
}