1use std::ffi::c_void;
19use std::hash::{Hash, Hasher};
20use std::sync::Arc;
21
22use abi_stable::StableAbi;
23use abi_stable::std_types::{ROption, RResult, RStr, RString, RVec};
24use accumulator::FFI_Accumulator;
25use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
26use arrow::datatypes::{DataType, Field};
27use arrow::ffi::FFI_ArrowSchema;
28use arrow_schema::FieldRef;
29use datafusion_common::{DataFusionError, Result, ffi_datafusion_err};
30use datafusion_expr::function::AggregateFunctionSimplification;
31use datafusion_expr::type_coercion::functions::fields_with_udf;
32use datafusion_expr::{
33 Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
34};
35use datafusion_functions_aggregate_common::accumulator::{
36 AccumulatorArgs, StateFieldsArgs,
37};
38use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
39use datafusion_proto_common::from_proto::parse_proto_fields_to_fields;
40use groups_accumulator::FFI_GroupsAccumulator;
41use prost::{DecodeError, Message};
42
43use crate::arrow_wrappers::WrappedSchema;
44use crate::util::{
45 FFIResult, rvec_wrapped_to_vec_datatype, rvec_wrapped_to_vec_fieldref,
46 vec_datatype_to_rvec_wrapped, vec_fieldref_to_rvec_wrapped,
47};
48use crate::volatility::FFI_Volatility;
49use crate::{df_result, rresult, rresult_return};
50
51mod accumulator;
52mod accumulator_args;
53mod groups_accumulator;
54
/// A stable-ABI (`#[repr(C)]`, `StableAbi`) description of an [`AggregateUDF`]
/// that can be passed between separately compiled libraries.
///
/// A few plain-data properties are captured when the struct is built (see
/// `From<Arc<AggregateUDF>>` in this module). Every behavioral operation is
/// reached through an `extern "C"` function pointer, so it runs in the library
/// that installed the pointers.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_AggregateUDF {
    /// Function name, as returned by `AggregateUDF::name` at construction.
    pub name: RString,

    /// Alternative names, as returned by `AggregateUDF::aliases` at construction.
    pub aliases: RVec<RString>,

    /// Volatility taken from the wrapped function's signature.
    pub volatility: FFI_Volatility,

    /// Computes the output field from the argument fields; each field crosses
    /// the boundary as an Arrow C Data Interface schema.
    pub return_field: unsafe extern "C" fn(
        udaf: &Self,
        arg_fields: RVec<WrappedSchema>,
    ) -> FFIResult<WrappedSchema>,

    /// Value of `AggregateUDF::is_nullable` captured at construction.
    pub is_nullable: bool,

    /// Reports whether a groups accumulator can be created for `args`. The
    /// wrapper installed by this module logs a warning and returns `false` if
    /// `args` cannot be converted back to native form.
    pub groups_accumulator_supported:
        unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool,

    /// Creates a row-level accumulator for `args`.
    pub accumulator: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        args: FFI_AccumulatorArgs,
    ) -> FFIResult<FFI_Accumulator>,

    /// Creates a sliding-window accumulator for `args`.
    pub create_sliding_accumulator: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        args: FFI_AccumulatorArgs,
    )
        -> FFIResult<FFI_Accumulator>,

    /// Computes the intermediate state fields. `ordering_fields` and each
    /// returned entry are protobuf-encoded `Field` messages, one byte buffer
    /// per field.
    pub state_fields: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        name: &RStr,
        input_fields: RVec<WrappedSchema>,
        return_field: WrappedSchema,
        ordering_fields: RVec<RVec<u8>>,
        is_distinct: bool,
    ) -> FFIResult<RVec<RVec<u8>>>,

    /// Creates a groups accumulator for `args`.
    pub create_groups_accumulator:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            args: FFI_AccumulatorArgs,
        ) -> FFIResult<FFI_GroupsAccumulator>,

    /// Returns an updated UDAF for the given beneficial-ordering flag, or
    /// `RNone` when the wrapped function returns `None`.
    pub with_beneficial_ordering:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            beneficial_ordering: bool,
        ) -> FFIResult<ROption<FFI_AggregateUDF>>,

    /// Reports the wrapped function's order sensitivity.
    pub order_sensitivity:
        unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity,

    /// Coerces argument data types (each carried as a `WrappedSchema`) to the
    /// types the function accepts.
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> FFIResult<RVec<WrappedSchema>>,

    /// Produces another handle to the same UDAF. The wrapper in this module
    /// clones the inner `Arc`.
    pub clone: unsafe extern "C" fn(udaf: &Self) -> Self,

    /// Frees `private_data`; invoked by `Drop`.
    pub release: unsafe extern "C" fn(udaf: &mut Self),

    /// Opaque data owned by the library that created this struct. For structs
    /// built in this module it is a boxed `AggregateUDFPrivateData`. Only that
    /// library's function pointers may dereference it.
    pub private_data: *mut c_void,

    /// Identifies the creating library. When it matches the current library,
    /// conversion back to `Arc<dyn AggregateUDFImpl>` skips the FFI layer and
    /// returns the original implementation.
    pub library_marker_id: extern "C" fn() -> usize,
}
148
// SAFETY: NOTE(review): these impls assert that the function pointers and the
// data behind `private_data` (a boxed `Arc<AggregateUDF>` for structs built in
// this module) may be shared and moved across threads. That relies on the
// wrapped UDAF being thread-safe; confirm for structs produced by other
// libraries.
unsafe impl Send for FFI_AggregateUDF {}
unsafe impl Sync for FFI_AggregateUDF {}
151
/// Payload stored behind `FFI_AggregateUDF::private_data` for structs created
/// in this library via `From<Arc<AggregateUDF>>`. It is boxed on creation and
/// reclaimed by `release_fn_wrapper`.
pub struct AggregateUDFPrivateData {
    /// The UDAF that every FFI wrapper function in this module delegates to.
    pub udaf: Arc<AggregateUDF>,
}
155
156impl FFI_AggregateUDF {
157 unsafe fn inner(&self) -> &Arc<AggregateUDF> {
158 unsafe {
159 let private_data = self.private_data as *const AggregateUDFPrivateData;
160 &(*private_data).udaf
161 }
162 }
163}
164
165unsafe extern "C" fn return_field_fn_wrapper(
166 udaf: &FFI_AggregateUDF,
167 arg_fields: RVec<WrappedSchema>,
168) -> FFIResult<WrappedSchema> {
169 unsafe {
170 let udaf = udaf.inner();
171
172 let arg_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&arg_fields));
173
174 let return_field = udaf
175 .return_field(&arg_fields)
176 .and_then(|v| {
177 FFI_ArrowSchema::try_from(v.as_ref()).map_err(DataFusionError::from)
178 })
179 .map(WrappedSchema);
180
181 rresult!(return_field)
182 }
183}
184
185unsafe extern "C" fn accumulator_fn_wrapper(
186 udaf: &FFI_AggregateUDF,
187 args: FFI_AccumulatorArgs,
188) -> FFIResult<FFI_Accumulator> {
189 unsafe {
190 let udaf = udaf.inner();
191
192 let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
193
194 rresult!(
195 udaf.accumulator(accumulator_args.into())
196 .map(FFI_Accumulator::from)
197 )
198 }
199}
200
201unsafe extern "C" fn create_sliding_accumulator_fn_wrapper(
202 udaf: &FFI_AggregateUDF,
203 args: FFI_AccumulatorArgs,
204) -> FFIResult<FFI_Accumulator> {
205 unsafe {
206 let udaf = udaf.inner();
207
208 let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
209
210 rresult!(
211 udaf.create_sliding_accumulator(accumulator_args.into())
212 .map(FFI_Accumulator::from)
213 )
214 }
215}
216
217unsafe extern "C" fn create_groups_accumulator_fn_wrapper(
218 udaf: &FFI_AggregateUDF,
219 args: FFI_AccumulatorArgs,
220) -> FFIResult<FFI_GroupsAccumulator> {
221 unsafe {
222 let udaf = udaf.inner();
223
224 let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
225
226 rresult!(
227 udaf.create_groups_accumulator(accumulator_args.into())
228 .map(FFI_GroupsAccumulator::from)
229 )
230 }
231}
232
233unsafe extern "C" fn groups_accumulator_supported_fn_wrapper(
234 udaf: &FFI_AggregateUDF,
235 args: FFI_AccumulatorArgs,
236) -> bool {
237 unsafe {
238 let udaf = udaf.inner();
239
240 ForeignAccumulatorArgs::try_from(args)
241 .map(|a| udaf.groups_accumulator_supported((&a).into()))
242 .unwrap_or_else(|e| {
243 log::warn!("Unable to parse accumulator args. {e}");
244 false
245 })
246 }
247}
248
249unsafe extern "C" fn with_beneficial_ordering_fn_wrapper(
250 udaf: &FFI_AggregateUDF,
251 beneficial_ordering: bool,
252) -> FFIResult<ROption<FFI_AggregateUDF>> {
253 unsafe {
254 let udaf = udaf.inner().as_ref().clone();
255
256 let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering));
257 let result = rresult_return!(
258 result
259 .map(|func| func.with_beneficial_ordering(beneficial_ordering))
260 .transpose()
261 )
262 .flatten()
263 .map(|func| FFI_AggregateUDF::from(Arc::new(func)));
264
265 RResult::ROk(result.into())
266 }
267}
268
269unsafe extern "C" fn state_fields_fn_wrapper(
270 udaf: &FFI_AggregateUDF,
271 name: &RStr,
272 input_fields: RVec<WrappedSchema>,
273 return_field: WrappedSchema,
274 ordering_fields: RVec<RVec<u8>>,
275 is_distinct: bool,
276) -> FFIResult<RVec<RVec<u8>>> {
277 unsafe {
278 let udaf = udaf.inner();
279
280 let input_fields = &rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
281 let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
282
283 let ordering_fields = &rresult_return!(
284 ordering_fields
285 .into_iter()
286 .map(|field_bytes| datafusion_proto_common::Field::decode(
287 field_bytes.as_ref()
288 ))
289 .collect::<std::result::Result<Vec<_>, DecodeError>>()
290 );
291
292 let ordering_fields =
293 &rresult_return!(parse_proto_fields_to_fields(ordering_fields))
294 .into_iter()
295 .map(Arc::new)
296 .collect::<Vec<_>>();
297
298 let args = StateFieldsArgs {
299 name: name.as_str(),
300 input_fields,
301 return_field,
302 ordering_fields,
303 is_distinct,
304 };
305
306 let state_fields = rresult_return!(udaf.state_fields(args));
307 let state_fields = rresult_return!(
308 state_fields
309 .iter()
310 .map(|f| f.as_ref())
311 .map(datafusion_proto::protobuf::Field::try_from)
312 .map(|v| v.map_err(DataFusionError::from))
313 .collect::<Result<Vec<_>>>()
314 )
315 .into_iter()
316 .map(|field| field.encode_to_vec().into())
317 .collect();
318
319 RResult::ROk(state_fields)
320 }
321}
322
323unsafe extern "C" fn order_sensitivity_fn_wrapper(
324 udaf: &FFI_AggregateUDF,
325) -> FFI_AggregateOrderSensitivity {
326 unsafe { udaf.inner().order_sensitivity().into() }
327}
328
329unsafe extern "C" fn coerce_types_fn_wrapper(
330 udaf: &FFI_AggregateUDF,
331 arg_types: RVec<WrappedSchema>,
332) -> FFIResult<RVec<WrappedSchema>> {
333 unsafe {
334 let udaf = udaf.inner();
335
336 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
337
338 let arg_fields = arg_types
339 .iter()
340 .map(|dt| Field::new("f", dt.clone(), true))
341 .map(Arc::new)
342 .collect::<Vec<_>>();
343 let return_types = rresult_return!(fields_with_udf(&arg_fields, udaf.as_ref()))
344 .into_iter()
345 .map(|f| f.data_type().to_owned())
346 .collect::<Vec<_>>();
347
348 rresult!(vec_datatype_to_rvec_wrapped(&return_types))
349 }
350}
351
352unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) {
353 unsafe {
354 debug_assert!(!udaf.private_data.is_null());
355 let private_data =
356 Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData);
357 drop(private_data);
358 udaf.private_data = std::ptr::null_mut();
359 }
360}
361
362unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF {
363 unsafe { Arc::clone(udaf.inner()).into() }
364}
365
366impl Clone for FFI_AggregateUDF {
367 fn clone(&self) -> Self {
368 unsafe { (self.clone)(self) }
369 }
370}
371
372impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
373 fn from(udaf: Arc<AggregateUDF>) -> Self {
374 let name = udaf.name().into();
375 let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect();
376 let is_nullable = udaf.is_nullable();
377 let volatility = udaf.signature().volatility.into();
378
379 let private_data = Box::new(AggregateUDFPrivateData { udaf });
380
381 Self {
382 name,
383 is_nullable,
384 volatility,
385 aliases,
386 return_field: return_field_fn_wrapper,
387 accumulator: accumulator_fn_wrapper,
388 create_sliding_accumulator: create_sliding_accumulator_fn_wrapper,
389 create_groups_accumulator: create_groups_accumulator_fn_wrapper,
390 groups_accumulator_supported: groups_accumulator_supported_fn_wrapper,
391 with_beneficial_ordering: with_beneficial_ordering_fn_wrapper,
392 state_fields: state_fields_fn_wrapper,
393 order_sensitivity: order_sensitivity_fn_wrapper,
394 coerce_types: coerce_types_fn_wrapper,
395 clone: clone_fn_wrapper,
396 release: release_fn_wrapper,
397 private_data: Box::into_raw(private_data) as *mut c_void,
398 library_marker_id: crate::get_library_marker_id,
399 }
400 }
401}
402
403impl Drop for FFI_AggregateUDF {
404 fn drop(&mut self) {
405 unsafe { (self.release)(self) }
406 }
407}
408
/// Consumer-side adapter that implements [`AggregateUDFImpl`] by forwarding
/// each call through an [`FFI_AggregateUDF`] exported by another library.
///
/// The FFI struct carries no full signature. `signature` is therefore built as
/// `Signature::user_defined` with the exported volatility, and argument type
/// coercion is forwarded through the FFI `coerce_types` pointer.
#[derive(Debug)]
pub struct ForeignAggregateUDF {
    /// User-defined signature built from the exported volatility.
    signature: Signature,
    /// Aliases copied from the FFI struct when this adapter was created.
    aliases: Vec<String>,
    /// Owned handle to the exported UDAF; released when this adapter drops.
    udaf: FFI_AggregateUDF,
}
421
// SAFETY: NOTE(review): `ForeignAggregateUDF` adds only owned `Signature` and
// `Vec<String>` data to an `FFI_AggregateUDF`. That FFI struct is declared
// `Send + Sync` in this module, and these impls inherit the same assumption.
unsafe impl Send for ForeignAggregateUDF {}
unsafe impl Sync for ForeignAggregateUDF {}
424
425impl PartialEq for ForeignAggregateUDF {
426 fn eq(&self, other: &Self) -> bool {
427 std::ptr::eq(self, other)
429 }
430}
431impl Eq for ForeignAggregateUDF {}
432impl Hash for ForeignAggregateUDF {
433 fn hash<H: Hasher>(&self, state: &mut H) {
434 std::ptr::hash(self, state)
435 }
436}
437
438impl From<&FFI_AggregateUDF> for Arc<dyn AggregateUDFImpl> {
439 fn from(udaf: &FFI_AggregateUDF) -> Self {
440 if (udaf.library_marker_id)() == crate::get_library_marker_id() {
441 return Arc::clone(unsafe { udaf.inner().inner() });
442 }
443
444 let signature = Signature::user_defined((&udaf.volatility).into());
445 let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect();
446
447 Arc::new(ForeignAggregateUDF {
448 udaf: udaf.clone(),
449 signature,
450 aliases,
451 })
452 }
453}
454
455impl AggregateUDFImpl for ForeignAggregateUDF {
456 fn as_any(&self) -> &dyn std::any::Any {
457 self
458 }
459
460 fn name(&self) -> &str {
461 self.udaf.name.as_str()
462 }
463
464 fn signature(&self) -> &Signature {
465 &self.signature
466 }
467
468 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
469 unimplemented!()
470 }
471
472 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
473 let arg_fields = vec_fieldref_to_rvec_wrapped(arg_fields)?;
474
475 let result = unsafe { (self.udaf.return_field)(&self.udaf, arg_fields) };
476
477 let result = df_result!(result);
478
479 result.and_then(|r| {
480 Field::try_from(&r.0)
481 .map(Arc::new)
482 .map_err(DataFusionError::from)
483 })
484 }
485
486 fn is_nullable(&self) -> bool {
487 self.udaf.is_nullable
488 }
489
490 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
491 let args = acc_args.try_into()?;
492 unsafe {
493 df_result!((self.udaf.accumulator)(&self.udaf, args))
494 .map(<Box<dyn Accumulator>>::from)
495 }
496 }
497
498 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
499 unsafe {
500 let name = RStr::from_str(args.name);
501 let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?;
502 let return_field =
503 WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
504 let ordering_fields = args
505 .ordering_fields
506 .iter()
507 .map(|f| f.as_ref())
508 .map(datafusion_proto::protobuf::Field::try_from)
509 .map(|v| v.map_err(DataFusionError::from))
510 .collect::<Result<Vec<_>>>()?
511 .into_iter()
512 .map(|proto_field| proto_field.encode_to_vec().into())
513 .collect();
514
515 let fields = df_result!((self.udaf.state_fields)(
516 &self.udaf,
517 &name,
518 input_fields,
519 return_field,
520 ordering_fields,
521 args.is_distinct
522 ))?;
523 let fields = fields
524 .into_iter()
525 .map(|field_bytes| {
526 datafusion_proto_common::Field::decode(field_bytes.as_ref())
527 .map_err(|e| ffi_datafusion_err!("{e}"))
528 })
529 .collect::<Result<Vec<_>>>()?;
530
531 parse_proto_fields_to_fields(fields.iter())
532 .map(|fields| fields.into_iter().map(Arc::new).collect())
533 .map_err(|e| ffi_datafusion_err!("{e}"))
534 }
535 }
536
537 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
538 let args = match FFI_AccumulatorArgs::try_from(args) {
539 Ok(v) => v,
540 Err(e) => {
541 log::warn!("Attempting to convert accumulator arguments: {e}");
542 return false;
543 }
544 };
545
546 unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) }
547 }
548
549 fn create_groups_accumulator(
550 &self,
551 args: AccumulatorArgs,
552 ) -> Result<Box<dyn GroupsAccumulator>> {
553 let args = FFI_AccumulatorArgs::try_from(args)?;
554
555 unsafe {
556 df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args))
557 .map(<Box<dyn GroupsAccumulator>>::from)
558 }
559 }
560
561 fn aliases(&self) -> &[String] {
562 &self.aliases
563 }
564
565 fn create_sliding_accumulator(
566 &self,
567 args: AccumulatorArgs,
568 ) -> Result<Box<dyn Accumulator>> {
569 let args = args.try_into()?;
570 unsafe {
571 df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args))
572 .map(<Box<dyn Accumulator>>::from)
573 }
574 }
575
576 fn with_beneficial_ordering(
577 self: Arc<Self>,
578 beneficial_ordering: bool,
579 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
580 unsafe {
581 let result = df_result!((self.udaf.with_beneficial_ordering)(
582 &self.udaf,
583 beneficial_ordering
584 ))?
585 .into_option();
586
587 let result = result.map(|func| <Arc<dyn AggregateUDFImpl>>::from(&func));
588
589 Ok(result)
590 }
591 }
592
593 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
594 unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() }
595 }
596
597 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
598 None
599 }
600
601 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
602 unsafe {
603 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
604 let result_types =
605 df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?;
606 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
607 }
608 }
609}
610
/// FFI-safe mirror of [`AggregateOrderSensitivity`].
///
/// The enum is `#[repr(C)]` with implicit discriminants. Reordering the
/// variants changes their numeric values and breaks compatibility with
/// libraries built against the old order.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub enum FFI_AggregateOrderSensitivity {
    /// Maps to `AggregateOrderSensitivity::Insensitive`.
    Insensitive,
    /// Maps to `AggregateOrderSensitivity::HardRequirement`.
    HardRequirement,
    /// Maps to `AggregateOrderSensitivity::SoftRequirement`.
    SoftRequirement,
    /// Maps to `AggregateOrderSensitivity::Beneficial`.
    Beneficial,
}
619
620impl From<FFI_AggregateOrderSensitivity> for AggregateOrderSensitivity {
621 fn from(value: FFI_AggregateOrderSensitivity) -> Self {
622 match value {
623 FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive,
624 FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
625 FFI_AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
626 FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial,
627 }
628 }
629}
630
631impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
632 fn from(value: AggregateOrderSensitivity) -> Self {
633 match value {
634 AggregateOrderSensitivity::Insensitive => Self::Insensitive,
635 AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
636 AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
637 AggregateOrderSensitivity::Beneficial => Self::Beneficial,
638 }
639 }
640}
641
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::collections::HashMap;

    use arrow::datatypes::Schema;
    use datafusion::common::create_array;
    use datafusion::functions_aggregate::sum::Sum;
    use datafusion::physical_expr::PhysicalSortExpr;
    use datafusion::physical_plan::expressions::col;
    use datafusion::scalar::ScalarValue;

    use super::*;

    /// `Sum` variant whose return field is the first argument field, so input
    /// metadata is carried through to the output.
    #[derive(Default, Debug, Hash, Eq, PartialEq)]
    struct SumWithCopiedMetadata {
        inner: Sum,
    }

    impl AggregateUDFImpl for SumWithCopiedMetadata {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.inner.name()
        }

        fn signature(&self) -> &Signature {
            self.inner.signature()
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
            // Return the input field unchanged, including its metadata.
            Ok(Arc::clone(&arg_fields[0]))
        }

        fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
            self.inner.accumulator(acc_args)
        }
    }

    /// Exports `original_udaf` over FFI and re-imports it with a mocked library
    /// marker, so the `ForeignAggregateUDF` path is exercised instead of the
    /// local bypass.
    fn create_test_foreign_udaf(
        original_udaf: impl AggregateUDFImpl + 'static,
    ) -> Result<AggregateUDF> {
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        Ok(AggregateUDF::new_from_shared_impl(foreign_udaf))
    }

    #[test]
    fn test_round_trip_udaf() -> Result<()> {
        let original_udaf = Sum::new();
        let original_name = original_udaf.name().to_owned();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        // Convert to FFI format and force the foreign code path.
        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        // Convert back to native format.
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf);

        assert_eq!(original_name, foreign_udaf.name());
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_aliases() -> Result<()> {
        // Aliases must survive the FFI boundary: add them before exporting and
        // check that the foreign wrapper reports them.
        let original_udaf =
            Arc::new(AggregateUDF::from(Sum::new()).with_aliases(["my_function"]));

        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        assert!(
            foreign_udaf
                .aliases()
                .iter()
                .any(|alias| alias == "my_function")
        );

        let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf);
        let return_field =
            foreign_udaf
                .return_field(&[Field::new("a", DataType::Float64, true).into()])?;
        let return_type = return_field.data_type();
        assert_eq!(return_type, &DataType::Float64);
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Float64, true).into()],
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &[col("a", &schema)?],
        };
        let mut accumulator = foreign_udaf.accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    #[test]
    fn test_round_trip_udaf_metadata() -> Result<()> {
        let original_udaf = SumWithCopiedMetadata::default();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        // Convert to FFI format. Without the mocked marker the conversion
        // below would return the original implementation and never cross FFI.
        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        // Convert back to native format through the foreign wrapper.
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf);

        let metadata: HashMap<String, String> =
            [("a_key".to_string(), "a_value".to_string())]
                .into_iter()
                .collect();
        let input_field = Arc::new(
            Field::new("a", DataType::Float64, false).with_metadata(metadata.clone()),
        );
        let return_field = foreign_udaf.return_field(&[input_field])?;

        assert_eq!(&metadata, return_field.metadata());
        Ok(())
    }

    #[test]
    fn test_beneficial_ordering() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(
            datafusion::functions_aggregate::first_last::FirstValue::new(),
        )?;

        let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap();

        assert_eq!(
            foreign_udaf.order_sensitivity(),
            AggregateOrderSensitivity::Beneficial
        );

        let a_field = Arc::new(Field::new("a", DataType::Float64, true));
        let state_fields = foreign_udaf.state_fields(StateFieldsArgs {
            name: "a",
            input_fields: &[Field::new("f", DataType::Float64, true).into()],
            return_field: Field::new("f", DataType::Float64, true).into(),
            ordering_fields: &[Arc::clone(&a_field)],
            is_distinct: false,
        })?;

        assert_eq!(state_fields.len(), 3);
        assert_eq!(state_fields[1], a_field);
        Ok(())
    }

    #[test]
    fn test_sliding_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Float64, true).into()],
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: false,
            exprs: &[col("a", &schema)?],
        };

        let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    /// Converts `sensitivity` to its FFI form and back, and asserts nothing changed.
    fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) {
        let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into();
        let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into();

        assert_eq!(sensitivity, round_trip_sensitivity);
    }

    #[test]
    fn test_round_trip_all_order_sensitivities() {
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::SoftRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial);
    }

    #[test]
    fn test_ffi_udaf_local_bypass() -> Result<()> {
        let original_udaf = Sum::new();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let mut ffi_udaf = FFI_AggregateUDF::from(original_udaf);

        // Same library marker: the original implementation comes back directly.
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&ffi_udaf).into();
        assert!(foreign_udaf.as_any().downcast_ref::<Sum>().is_some());

        // Different library marker: the FFI wrapper is used.
        ffi_udaf.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&ffi_udaf).into();
        assert!(
            foreign_udaf
                .as_any()
                .downcast_ref::<ForeignAggregateUDF>()
                .is_some()
        );

        Ok(())
    }
}