1use abi_stable::{
19 std_types::{ROption, RResult, RStr, RString, RVec},
20 StableAbi,
21};
22use accumulator::{FFI_Accumulator, ForeignAccumulator};
23use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
24use arrow::datatypes::{DataType, Field};
25use arrow::ffi::FFI_ArrowSchema;
26use arrow_schema::FieldRef;
27use datafusion::{
28 error::DataFusionError,
29 logical_expr::{
30 function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
31 type_coercion::functions::fields_with_aggregate_udf,
32 utils::AggregateOrderSensitivity,
33 Accumulator, GroupsAccumulator,
34 },
35};
36use datafusion::{
37 error::Result,
38 logical_expr::{AggregateUDF, AggregateUDFImpl, Signature},
39};
40use datafusion_common::exec_datafusion_err;
41use datafusion_proto_common::from_proto::parse_proto_fields_to_fields;
42use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator};
43use std::hash::{Hash, Hasher};
44use std::{ffi::c_void, sync::Arc};
45
46use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
47use crate::{
48 arrow_wrappers::WrappedSchema,
49 df_result, rresult, rresult_return,
50 util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
51 volatility::FFI_Volatility,
52};
53use prost::{DecodeError, Message};
54
55mod accumulator;
56mod accumulator_args;
57mod groups_accumulator;
58
/// A stable-ABI handle to an [`AggregateUDF`] that can be passed across an FFI
/// boundary.
///
/// Descriptive metadata (name, aliases, volatility, nullability) is stored
/// inline. Every other operation is reached through an `extern "C"` function
/// pointer that receives this struct back as its first argument. The function
/// itself lives behind `private_data`.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_AggregateUDF {
    /// Function name, copied from `AggregateUDF::name`.
    pub name: RString,

    /// Alternative names, copied from `AggregateUDF::aliases`.
    pub aliases: RVec<RString>,

    /// Volatility of the function's signature.
    pub volatility: FFI_Volatility,

    /// Computes the output field from the argument fields. Both sides are
    /// passed as Arrow C schemas, so field metadata is preserved.
    pub return_field: unsafe extern "C" fn(
        udaf: &Self,
        arg_fields: RVec<WrappedSchema>,
    ) -> RResult<WrappedSchema, RString>,

    /// Whether the output may be null, copied from `AggregateUDF::is_nullable`.
    pub is_nullable: bool,

    /// Reports whether `create_groups_accumulator` is supported for `args`.
    pub groups_accumulator_supported:
        unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool,

    /// Creates a new row-based accumulator.
    pub accumulator: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        args: FFI_AccumulatorArgs,
    ) -> RResult<FFI_Accumulator, RString>,

    /// Creates an accumulator via `AggregateUDF::create_sliding_accumulator`.
    pub create_sliding_accumulator:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            args: FFI_AccumulatorArgs,
        ) -> RResult<FFI_Accumulator, RString>,

    /// Returns the intermediate state fields.
    ///
    /// Both `ordering_fields` and the returned fields are protobuf-encoded
    /// `Field` messages, one byte vector per field.
    #[allow(clippy::type_complexity)]
    pub state_fields: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        name: &RStr,
        input_fields: RVec<WrappedSchema>,
        return_field: WrappedSchema,
        ordering_fields: RVec<RVec<u8>>,
        is_distinct: bool,
    ) -> RResult<RVec<RVec<u8>>, RString>,

    /// Creates a grouped accumulator.
    pub create_groups_accumulator:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            args: FFI_AccumulatorArgs,
        ) -> RResult<FFI_GroupsAccumulator, RString>,

    /// Returns a replacement function for the given `beneficial_ordering`, or
    /// `None` if the local function produced no replacement.
    pub with_beneficial_ordering:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            beneficial_ordering: bool,
        ) -> RResult<ROption<FFI_AggregateUDF>, RString>,

    /// Reports how the function depends on the ordering of its input.
    pub order_sensitivity:
        unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity,

    /// Coerces the argument types (encoded as Arrow C schemas) to the types
    /// the function accepts.
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Produces a new handle to the same underlying function.
    pub clone: unsafe extern "C" fn(udaf: &Self) -> Self,

    /// Frees `private_data`; invoked from `Drop`.
    pub release: unsafe extern "C" fn(udaf: &mut Self),

    /// Opaque state owned by the code that created this struct.
    ///
    /// When built via `From<Arc<AggregateUDF>>` it points to a boxed
    /// `AggregateUDFPrivateData`.
    pub private_data: *mut c_void,
}

// SAFETY (review note): the inline fields are function pointers and abi_stable
// containers. `private_data` is assumed to point to an `AggregateUDFPrivateData`
// (an `Arc<AggregateUDF>`), which this module only reads through shared
// references. Soundness relies on every producer of this struct upholding that
// assumption.
unsafe impl Send for FFI_AggregateUDF {}
unsafe impl Sync for FFI_AggregateUDF {}
152
/// State stored behind [`FFI_AggregateUDF::private_data`]: the local function
/// that the FFI callbacks in this module forward to.
pub struct AggregateUDFPrivateData {
    /// The wrapped function; shared with clones produced by `clone_fn_wrapper`.
    pub udaf: Arc<AggregateUDF>,
}
156
157impl FFI_AggregateUDF {
158 unsafe fn inner(&self) -> &Arc<AggregateUDF> {
159 let private_data = self.private_data as *const AggregateUDFPrivateData;
160 &(*private_data).udaf
161 }
162}
163
164unsafe extern "C" fn return_field_fn_wrapper(
165 udaf: &FFI_AggregateUDF,
166 arg_fields: RVec<WrappedSchema>,
167) -> RResult<WrappedSchema, RString> {
168 let udaf = udaf.inner();
169
170 let arg_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&arg_fields));
171
172 let return_field = udaf
173 .return_field(&arg_fields)
174 .and_then(|v| {
175 FFI_ArrowSchema::try_from(v.as_ref()).map_err(DataFusionError::from)
176 })
177 .map(WrappedSchema);
178
179 rresult!(return_field)
180}
181
182unsafe extern "C" fn accumulator_fn_wrapper(
183 udaf: &FFI_AggregateUDF,
184 args: FFI_AccumulatorArgs,
185) -> RResult<FFI_Accumulator, RString> {
186 let udaf = udaf.inner();
187
188 let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
189
190 rresult!(udaf
191 .accumulator(accumulator_args.into())
192 .map(FFI_Accumulator::from))
193}
194
195unsafe extern "C" fn create_sliding_accumulator_fn_wrapper(
196 udaf: &FFI_AggregateUDF,
197 args: FFI_AccumulatorArgs,
198) -> RResult<FFI_Accumulator, RString> {
199 let udaf = udaf.inner();
200
201 let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
202
203 rresult!(udaf
204 .create_sliding_accumulator(accumulator_args.into())
205 .map(FFI_Accumulator::from))
206}
207
208unsafe extern "C" fn create_groups_accumulator_fn_wrapper(
209 udaf: &FFI_AggregateUDF,
210 args: FFI_AccumulatorArgs,
211) -> RResult<FFI_GroupsAccumulator, RString> {
212 let udaf = udaf.inner();
213
214 let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
215
216 rresult!(udaf
217 .create_groups_accumulator(accumulator_args.into())
218 .map(FFI_GroupsAccumulator::from))
219}
220
221unsafe extern "C" fn groups_accumulator_supported_fn_wrapper(
222 udaf: &FFI_AggregateUDF,
223 args: FFI_AccumulatorArgs,
224) -> bool {
225 let udaf = udaf.inner();
226
227 ForeignAccumulatorArgs::try_from(args)
228 .map(|a| udaf.groups_accumulator_supported((&a).into()))
229 .unwrap_or_else(|e| {
230 log::warn!("Unable to parse accumulator args. {e}");
231 false
232 })
233}
234
235unsafe extern "C" fn with_beneficial_ordering_fn_wrapper(
236 udaf: &FFI_AggregateUDF,
237 beneficial_ordering: bool,
238) -> RResult<ROption<FFI_AggregateUDF>, RString> {
239 let udaf = udaf.inner().as_ref().clone();
240
241 let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering));
242 let result = rresult_return!(result
243 .map(|func| func.with_beneficial_ordering(beneficial_ordering))
244 .transpose())
245 .flatten()
246 .map(|func| FFI_AggregateUDF::from(Arc::new(func)));
247
248 RResult::ROk(result.into())
249}
250
251unsafe extern "C" fn state_fields_fn_wrapper(
252 udaf: &FFI_AggregateUDF,
253 name: &RStr,
254 input_fields: RVec<WrappedSchema>,
255 return_field: WrappedSchema,
256 ordering_fields: RVec<RVec<u8>>,
257 is_distinct: bool,
258) -> RResult<RVec<RVec<u8>>, RString> {
259 let udaf = udaf.inner();
260
261 let input_fields = &rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
262 let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
263
264 let ordering_fields = &rresult_return!(ordering_fields
265 .into_iter()
266 .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref()))
267 .collect::<std::result::Result<Vec<_>, DecodeError>>());
268
269 let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields))
270 .into_iter()
271 .map(Arc::new)
272 .collect::<Vec<_>>();
273
274 let args = StateFieldsArgs {
275 name: name.as_str(),
276 input_fields,
277 return_field,
278 ordering_fields,
279 is_distinct,
280 };
281
282 let state_fields = rresult_return!(udaf.state_fields(args));
283 let state_fields = rresult_return!(state_fields
284 .iter()
285 .map(|f| f.as_ref())
286 .map(datafusion_proto::protobuf::Field::try_from)
287 .map(|v| v.map_err(DataFusionError::from))
288 .collect::<Result<Vec<_>>>())
289 .into_iter()
290 .map(|field| field.encode_to_vec().into())
291 .collect();
292
293 RResult::ROk(state_fields)
294}
295
296unsafe extern "C" fn order_sensitivity_fn_wrapper(
297 udaf: &FFI_AggregateUDF,
298) -> FFI_AggregateOrderSensitivity {
299 udaf.inner().order_sensitivity().into()
300}
301
302unsafe extern "C" fn coerce_types_fn_wrapper(
303 udaf: &FFI_AggregateUDF,
304 arg_types: RVec<WrappedSchema>,
305) -> RResult<RVec<WrappedSchema>, RString> {
306 let udaf = udaf.inner();
307
308 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
309
310 let arg_fields = arg_types
311 .iter()
312 .map(|dt| Field::new("f", dt.clone(), true))
313 .map(Arc::new)
314 .collect::<Vec<_>>();
315 let return_types = rresult_return!(fields_with_aggregate_udf(&arg_fields, udaf))
316 .into_iter()
317 .map(|f| f.data_type().to_owned())
318 .collect::<Vec<_>>();
319
320 rresult!(vec_datatype_to_rvec_wrapped(&return_types))
321}
322
323unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) {
324 let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData);
325 drop(private_data);
326}
327
328unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF {
329 Arc::clone(udaf.inner()).into()
330}
331
332impl Clone for FFI_AggregateUDF {
333 fn clone(&self) -> Self {
334 unsafe { (self.clone)(self) }
335 }
336}
337
338impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
339 fn from(udaf: Arc<AggregateUDF>) -> Self {
340 let name = udaf.name().into();
341 let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect();
342 let is_nullable = udaf.is_nullable();
343 let volatility = udaf.signature().volatility.into();
344
345 let private_data = Box::new(AggregateUDFPrivateData { udaf });
346
347 Self {
348 name,
349 is_nullable,
350 volatility,
351 aliases,
352 return_field: return_field_fn_wrapper,
353 accumulator: accumulator_fn_wrapper,
354 create_sliding_accumulator: create_sliding_accumulator_fn_wrapper,
355 create_groups_accumulator: create_groups_accumulator_fn_wrapper,
356 groups_accumulator_supported: groups_accumulator_supported_fn_wrapper,
357 with_beneficial_ordering: with_beneficial_ordering_fn_wrapper,
358 state_fields: state_fields_fn_wrapper,
359 order_sensitivity: order_sensitivity_fn_wrapper,
360 coerce_types: coerce_types_fn_wrapper,
361 clone: clone_fn_wrapper,
362 release: release_fn_wrapper,
363 private_data: Box::into_raw(private_data) as *mut c_void,
364 }
365 }
366}
367
368impl Drop for FFI_AggregateUDF {
369 fn drop(&mut self) {
370 unsafe { (self.release)(self) }
371 }
372}
373
/// An aggregate function supplied through an [`FFI_AggregateUDF`], usable
/// wherever DataFusion expects an [`AggregateUDFImpl`].
///
/// The signature is always `Signature::user_defined` with the provider's
/// volatility (see `TryFrom<&FFI_AggregateUDF>`). Argument coercion is
/// therefore routed to `coerce_types`, which this type forwards to the
/// provider.
#[derive(Debug)]
pub struct ForeignAggregateUDF {
    signature: Signature,
    aliases: Vec<String>,
    udaf: FFI_AggregateUDF,
}

// SAFETY (review note): every field is either plain owned data or an
// `FFI_AggregateUDF`, which is itself declared `Send + Sync` above.
unsafe impl Send for ForeignAggregateUDF {}
unsafe impl Sync for ForeignAggregateUDF {}
389
390impl PartialEq for ForeignAggregateUDF {
391 fn eq(&self, other: &Self) -> bool {
392 std::ptr::eq(self, other)
394 }
395}
396impl Eq for ForeignAggregateUDF {}
397impl Hash for ForeignAggregateUDF {
398 fn hash<H: Hasher>(&self, state: &mut H) {
399 std::ptr::hash(self, state)
400 }
401}
402
403impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF {
404 type Error = DataFusionError;
405
406 fn try_from(udaf: &FFI_AggregateUDF) -> Result<Self, Self::Error> {
407 let signature = Signature::user_defined((&udaf.volatility).into());
408 let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect();
409
410 Ok(Self {
411 udaf: udaf.clone(),
412 signature,
413 aliases,
414 })
415 }
416}
417
418impl AggregateUDFImpl for ForeignAggregateUDF {
419 fn as_any(&self) -> &dyn std::any::Any {
420 self
421 }
422
423 fn name(&self) -> &str {
424 self.udaf.name.as_str()
425 }
426
427 fn signature(&self) -> &Signature {
428 &self.signature
429 }
430
431 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
432 unimplemented!()
433 }
434
435 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
436 let arg_fields = vec_fieldref_to_rvec_wrapped(arg_fields)?;
437
438 let result = unsafe { (self.udaf.return_field)(&self.udaf, arg_fields) };
439
440 let result = df_result!(result);
441
442 result.and_then(|r| {
443 Field::try_from(&r.0)
444 .map(Arc::new)
445 .map_err(DataFusionError::from)
446 })
447 }
448
449 fn is_nullable(&self) -> bool {
450 self.udaf.is_nullable
451 }
452
453 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
454 let args = acc_args.try_into()?;
455 unsafe {
456 df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| {
457 Box::new(ForeignAccumulator::from(accum)) as Box<dyn Accumulator>
458 })
459 }
460 }
461
462 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
463 unsafe {
464 let name = RStr::from_str(args.name);
465 let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?;
466 let return_field =
467 WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
468 let ordering_fields = args
469 .ordering_fields
470 .iter()
471 .map(|f| f.as_ref())
472 .map(datafusion_proto::protobuf::Field::try_from)
473 .map(|v| v.map_err(DataFusionError::from))
474 .collect::<Result<Vec<_>>>()?
475 .into_iter()
476 .map(|proto_field| proto_field.encode_to_vec().into())
477 .collect();
478
479 let fields = df_result!((self.udaf.state_fields)(
480 &self.udaf,
481 &name,
482 input_fields,
483 return_field,
484 ordering_fields,
485 args.is_distinct
486 ))?;
487 let fields = fields
488 .into_iter()
489 .map(|field_bytes| {
490 datafusion_proto_common::Field::decode(field_bytes.as_ref())
491 .map_err(|e| exec_datafusion_err!("{e}"))
492 })
493 .collect::<Result<Vec<_>>>()?;
494
495 parse_proto_fields_to_fields(fields.iter())
496 .map(|fields| fields.into_iter().map(Arc::new).collect())
497 .map_err(|e| exec_datafusion_err!("{e}"))
498 }
499 }
500
501 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
502 let args = match FFI_AccumulatorArgs::try_from(args) {
503 Ok(v) => v,
504 Err(e) => {
505 log::warn!("Attempting to convert accumulator arguments: {e}");
506 return false;
507 }
508 };
509
510 unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) }
511 }
512
513 fn create_groups_accumulator(
514 &self,
515 args: AccumulatorArgs,
516 ) -> Result<Box<dyn GroupsAccumulator>> {
517 let args = FFI_AccumulatorArgs::try_from(args)?;
518
519 unsafe {
520 df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map(
521 |accum| {
522 Box::new(ForeignGroupsAccumulator::from(accum))
523 as Box<dyn GroupsAccumulator>
524 },
525 )
526 }
527 }
528
529 fn aliases(&self) -> &[String] {
530 &self.aliases
531 }
532
533 fn create_sliding_accumulator(
534 &self,
535 args: AccumulatorArgs,
536 ) -> Result<Box<dyn Accumulator>> {
537 let args = args.try_into()?;
538 unsafe {
539 df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map(
540 |accum| Box::new(ForeignAccumulator::from(accum)) as Box<dyn Accumulator>,
541 )
542 }
543 }
544
545 fn with_beneficial_ordering(
546 self: Arc<Self>,
547 beneficial_ordering: bool,
548 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
549 unsafe {
550 let result = df_result!((self.udaf.with_beneficial_ordering)(
551 &self.udaf,
552 beneficial_ordering
553 ))?
554 .into_option();
555
556 let result = result
557 .map(|func| ForeignAggregateUDF::try_from(&func))
558 .transpose()?;
559
560 Ok(result.map(|func| Arc::new(func) as Arc<dyn AggregateUDFImpl>))
561 }
562 }
563
564 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
565 unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() }
566 }
567
568 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
569 None
570 }
571
572 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
573 unsafe {
574 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
575 let result_types =
576 df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?;
577 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
578 }
579 }
580}
581
/// FFI-safe mirror of [`AggregateOrderSensitivity`].
///
/// The variant list must stay in sync with DataFusion's enum. The `From`
/// conversions below map variants by name in both directions.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub enum FFI_AggregateOrderSensitivity {
    Insensitive,
    HardRequirement,
    SoftRequirement,
    Beneficial,
}
591
592impl From<FFI_AggregateOrderSensitivity> for AggregateOrderSensitivity {
593 fn from(value: FFI_AggregateOrderSensitivity) -> Self {
594 match value {
595 FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive,
596 FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
597 FFI_AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
598 FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial,
599 }
600 }
601}
602
603impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
604 fn from(value: AggregateOrderSensitivity) -> Self {
605 match value {
606 AggregateOrderSensitivity::Insensitive => Self::Insensitive,
607 AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
608 AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
609 AggregateOrderSensitivity::Beneficial => Self::Beneficial,
610 }
611 }
612}
613
#[cfg(test)]
mod tests {
    use arrow::datatypes::Schema;
    use datafusion::{
        common::create_array, functions_aggregate::sum::Sum,
        physical_expr::PhysicalSortExpr, physical_plan::expressions::col,
        scalar::ScalarValue,
    };
    use std::any::Any;
    use std::collections::HashMap;

    use super::*;

    /// `Sum` variant whose return field is the first input field itself, so
    /// any metadata on the input field is visible on the output field.
    #[derive(Default, Debug, Hash, Eq, PartialEq)]
    struct SumWithCopiedMetadata {
        inner: Sum,
    }

    impl AggregateUDFImpl for SumWithCopiedMetadata {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.inner.name()
        }

        fn signature(&self) -> &Signature {
            self.inner.signature()
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
            // Reuse the input field as-is, metadata included.
            Ok(Arc::clone(&arg_fields[0]))
        }

        fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
            self.inner.accumulator(acc_args)
        }
    }

    /// Exports `original_udaf` through `FFI_AggregateUDF`, imports it back as a
    /// `ForeignAggregateUDF`, and returns it wrapped in an `AggregateUDF`.
    fn create_test_foreign_udaf(
        original_udaf: impl AggregateUDFImpl + 'static,
    ) -> Result<AggregateUDF> {
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();

        let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
        Ok(foreign_udaf.into())
    }

    #[test]
    fn test_round_trip_udaf() -> Result<()> {
        let original_udaf = Sum::new();
        let original_name = original_udaf.name().to_owned();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        // Export to the FFI representation.
        let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();

        // Import it back as a foreign UDAF.
        let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
        let foreign_udaf: AggregateUDF = foreign_udaf.into();

        assert_eq!(original_name, foreign_udaf.name());
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_aliases() -> Result<()> {
        let foreign_udaf =
            create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]);

        // The alias added on top of the foreign function must be reported.
        assert_eq!(foreign_udaf.aliases(), &["my_function".to_string()]);

        // Aliasing must not change how the return field is resolved.
        let return_field =
            foreign_udaf
                .return_field(&[Field::new("a", DataType::Float64, true).into()])?;
        let return_type = return_field.data_type();
        assert_eq!(return_type, &DataType::Float64);
        Ok(())
    }

    #[test]
    fn test_round_trip_udaf_aliases() -> Result<()> {
        // Aliases present on the local function must survive the FFI export.
        let original_udaf =
            Arc::new(AggregateUDF::from(Sum::new()).with_aliases(["my_sum"]));

        let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;

        assert_eq!(foreign_udaf.aliases(), &["my_sum".to_string()]);
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Float64, true).into()],
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &[col("a", &schema)?],
        };
        let mut accumulator = foreign_udaf.accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    #[test]
    fn test_round_trip_udaf_metadata() -> Result<()> {
        let original_udaf = SumWithCopiedMetadata::default();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        // Export to the FFI representation.
        let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();

        // Import it back as a foreign UDAF.
        let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
        let foreign_udaf: AggregateUDF = foreign_udaf.into();

        let metadata: HashMap<String, String> =
            [("a_key".to_string(), "a_value".to_string())]
                .into_iter()
                .collect();
        let input_field = Arc::new(
            Field::new("a", DataType::Float64, false).with_metadata(metadata.clone()),
        );
        let return_field = foreign_udaf.return_field(&[input_field])?;

        // Field metadata must survive the round trip through FFI.
        assert_eq!(&metadata, return_field.metadata());
        Ok(())
    }

    #[test]
    fn test_beneficial_ordering() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(
            datafusion::functions_aggregate::first_last::FirstValue::new(),
        )?;

        let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap();

        assert_eq!(
            foreign_udaf.order_sensitivity(),
            AggregateOrderSensitivity::Beneficial
        );

        let a_field = Arc::new(Field::new("a", DataType::Float64, true));
        let state_fields = foreign_udaf.state_fields(StateFieldsArgs {
            name: "a",
            input_fields: &[Field::new("f", DataType::Float64, true).into()],
            return_field: Field::new("f", DataType::Float64, true).into(),
            ordering_fields: &[Arc::clone(&a_field)],
            is_distinct: false,
        })?;

        assert_eq!(state_fields.len(), 3);
        assert_eq!(state_fields[1], a_field);
        Ok(())
    }

    #[test]
    fn test_sliding_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Float64, true).into()],
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: false,
            exprs: &[col("a", &schema)?],
        };

        let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    /// Converts `sensitivity` to its FFI form and back, asserting nothing changed.
    fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) {
        let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into();
        let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into();

        assert_eq!(sensitivity, round_trip_sensitivity);
    }

    #[test]
    fn test_round_trip_all_order_sensitivities() {
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::SoftRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial);
    }
}