datafusion_ffi/udaf/
mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use abi_stable::{
19    std_types::{ROption, RResult, RStr, RString, RVec},
20    StableAbi,
21};
22use accumulator::{FFI_Accumulator, ForeignAccumulator};
23use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
24use arrow::datatypes::{DataType, Field};
25use arrow::ffi::FFI_ArrowSchema;
26use arrow_schema::FieldRef;
27use datafusion::{
28    error::DataFusionError,
29    logical_expr::{
30        function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
31        type_coercion::functions::fields_with_aggregate_udf,
32        utils::AggregateOrderSensitivity,
33        Accumulator, GroupsAccumulator,
34    },
35};
36use datafusion::{
37    error::Result,
38    logical_expr::{AggregateUDF, AggregateUDFImpl, Signature},
39};
40use datafusion_proto_common::from_proto::parse_proto_fields_to_fields;
41use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator};
42use std::hash::{Hash, Hasher};
43use std::{ffi::c_void, sync::Arc};
44
45use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
46use crate::{
47    arrow_wrappers::WrappedSchema,
48    df_result, rresult, rresult_return,
49    util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
50    volatility::FFI_Volatility,
51};
52use prost::{DecodeError, Message};
53
54mod accumulator;
55mod accumulator_args;
56mod groups_accumulator;
57
/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries.
///
/// The plain-data fields (`name`, `aliases`, `volatility`, `is_nullable`) are
/// captured when the struct is built from an [`AggregateUDF`]. Every other
/// operation is exposed as an `unsafe extern "C"` function pointer that takes
/// the struct itself as its first argument.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_AggregateUDF {
    /// FFI equivalent to the `name` of a [`AggregateUDF`]
    pub name: RString,

    /// FFI equivalent to the `aliases` of a [`AggregateUDF`]
    pub aliases: RVec<RString>,

    /// FFI equivalent to the `volatility` of a [`AggregateUDF`]
    pub volatility: FFI_Volatility,

    /// Determines the return field of the underlying [`AggregateUDF`] based on the
    /// argument fields.
    pub return_field: unsafe extern "C" fn(
        udaf: &Self,
        arg_fields: RVec<WrappedSchema>,
    ) -> RResult<WrappedSchema, RString>,

    /// FFI equivalent to the `is_nullable` of a [`AggregateUDF`]
    pub is_nullable: bool,

    /// FFI equivalent to [`AggregateUDF::groups_accumulator_supported`]
    pub groups_accumulator_supported:
        unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool,

    /// FFI equivalent to [`AggregateUDF::accumulator`]
    pub accumulator: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        args: FFI_AccumulatorArgs,
    ) -> RResult<FFI_Accumulator, RString>,

    /// FFI equivalent to [`AggregateUDF::create_sliding_accumulator`]
    pub create_sliding_accumulator:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            args: FFI_AccumulatorArgs,
        ) -> RResult<FFI_Accumulator, RString>,

    /// FFI equivalent to [`AggregateUDF::state_fields`]
    ///
    /// The ordering fields passed in, and the state fields returned, are
    /// exchanged as protobuf-encoded byte buffers (one buffer per field).
    #[allow(clippy::type_complexity)]
    pub state_fields: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        name: &RStr,
        input_fields: RVec<WrappedSchema>,
        return_field: WrappedSchema,
        ordering_fields: RVec<RVec<u8>>,
        is_distinct: bool,
    ) -> RResult<RVec<RVec<u8>>, RString>,

    /// FFI equivalent to [`AggregateUDF::create_groups_accumulator`]
    pub create_groups_accumulator:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            args: FFI_AccumulatorArgs,
        ) -> RResult<FFI_GroupsAccumulator, RString>,

    /// FFI equivalent to [`AggregateUDF::with_beneficial_ordering`]
    pub with_beneficial_ordering:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            beneficial_ordering: bool,
        ) -> RResult<ROption<FFI_AggregateUDF>, RString>,

    /// FFI equivalent to [`AggregateUDF::order_sensitivity`]
    pub order_sensitivity:
        unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity,

    /// Performs type coercion. To simplify this interface, all UDFs are treated as
    /// having user defined signatures, which will in turn cause `coerce_types` to be
    /// called. This call should be transparent to most users as the internal function
    /// performs the appropriate calls on the underlying [`AggregateUDF`]
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Used to create a clone on the provider of the udaf. This should
    /// only need to be called by the receiver of the udaf.
    pub clone: unsafe extern "C" fn(udaf: &Self) -> Self,

    /// Release the memory of the private data when it is no longer being used.
    pub release: unsafe extern "C" fn(udaf: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the udaf.
    /// A [`ForeignAggregateUDF`] should never attempt to access this data.
    pub private_data: *mut c_void,
}

// NOTE(review): the raw `private_data` pointer prevents these traits from being
// derived automatically, so they are asserted by hand. Presumably this is sound
// because the provider only ever stores an `Arc<AggregateUDF>` behind the
// pointer — confirm that `AggregateUDF` is itself `Send + Sync`.
unsafe impl Send for FFI_AggregateUDF {}
unsafe impl Sync for FFI_AggregateUDF {}
151
/// Provider-side payload stored behind the `private_data` pointer of an
/// [`FFI_AggregateUDF`].
///
/// It is boxed and turned into a raw pointer when an [`FFI_AggregateUDF`] is
/// built from an `Arc<AggregateUDF>`, and reclaimed by the `release` callback.
pub struct AggregateUDFPrivateData {
    /// The wrapped aggregate function that the FFI callbacks delegate to.
    pub udaf: Arc<AggregateUDF>,
}
155
156impl FFI_AggregateUDF {
157    unsafe fn inner(&self) -> &Arc<AggregateUDF> {
158        let private_data = self.private_data as *const AggregateUDFPrivateData;
159        &(*private_data).udaf
160    }
161}
162
163unsafe extern "C" fn return_field_fn_wrapper(
164    udaf: &FFI_AggregateUDF,
165    arg_fields: RVec<WrappedSchema>,
166) -> RResult<WrappedSchema, RString> {
167    let udaf = udaf.inner();
168
169    let arg_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&arg_fields));
170
171    let return_field = udaf
172        .return_field(&arg_fields)
173        .and_then(|v| {
174            FFI_ArrowSchema::try_from(v.as_ref()).map_err(DataFusionError::from)
175        })
176        .map(WrappedSchema);
177
178    rresult!(return_field)
179}
180
181unsafe extern "C" fn accumulator_fn_wrapper(
182    udaf: &FFI_AggregateUDF,
183    args: FFI_AccumulatorArgs,
184) -> RResult<FFI_Accumulator, RString> {
185    let udaf = udaf.inner();
186
187    let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
188
189    rresult!(udaf
190        .accumulator(accumulator_args.into())
191        .map(FFI_Accumulator::from))
192}
193
194unsafe extern "C" fn create_sliding_accumulator_fn_wrapper(
195    udaf: &FFI_AggregateUDF,
196    args: FFI_AccumulatorArgs,
197) -> RResult<FFI_Accumulator, RString> {
198    let udaf = udaf.inner();
199
200    let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
201
202    rresult!(udaf
203        .create_sliding_accumulator(accumulator_args.into())
204        .map(FFI_Accumulator::from))
205}
206
207unsafe extern "C" fn create_groups_accumulator_fn_wrapper(
208    udaf: &FFI_AggregateUDF,
209    args: FFI_AccumulatorArgs,
210) -> RResult<FFI_GroupsAccumulator, RString> {
211    let udaf = udaf.inner();
212
213    let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
214
215    rresult!(udaf
216        .create_groups_accumulator(accumulator_args.into())
217        .map(FFI_GroupsAccumulator::from))
218}
219
220unsafe extern "C" fn groups_accumulator_supported_fn_wrapper(
221    udaf: &FFI_AggregateUDF,
222    args: FFI_AccumulatorArgs,
223) -> bool {
224    let udaf = udaf.inner();
225
226    ForeignAccumulatorArgs::try_from(args)
227        .map(|a| udaf.groups_accumulator_supported((&a).into()))
228        .unwrap_or_else(|e| {
229            log::warn!("Unable to parse accumulator args. {e}");
230            false
231        })
232}
233
234unsafe extern "C" fn with_beneficial_ordering_fn_wrapper(
235    udaf: &FFI_AggregateUDF,
236    beneficial_ordering: bool,
237) -> RResult<ROption<FFI_AggregateUDF>, RString> {
238    let udaf = udaf.inner().as_ref().clone();
239
240    let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering));
241    let result = rresult_return!(result
242        .map(|func| func.with_beneficial_ordering(beneficial_ordering))
243        .transpose())
244    .flatten()
245    .map(|func| FFI_AggregateUDF::from(Arc::new(func)));
246
247    RResult::ROk(result.into())
248}
249
250unsafe extern "C" fn state_fields_fn_wrapper(
251    udaf: &FFI_AggregateUDF,
252    name: &RStr,
253    input_fields: RVec<WrappedSchema>,
254    return_field: WrappedSchema,
255    ordering_fields: RVec<RVec<u8>>,
256    is_distinct: bool,
257) -> RResult<RVec<RVec<u8>>, RString> {
258    let udaf = udaf.inner();
259
260    let input_fields = &rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
261    let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
262
263    let ordering_fields = &rresult_return!(ordering_fields
264        .into_iter()
265        .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref()))
266        .collect::<std::result::Result<Vec<_>, DecodeError>>());
267
268    let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields))
269        .into_iter()
270        .map(Arc::new)
271        .collect::<Vec<_>>();
272
273    let args = StateFieldsArgs {
274        name: name.as_str(),
275        input_fields,
276        return_field,
277        ordering_fields,
278        is_distinct,
279    };
280
281    let state_fields = rresult_return!(udaf.state_fields(args));
282    let state_fields = rresult_return!(state_fields
283        .iter()
284        .map(|f| f.as_ref())
285        .map(datafusion_proto::protobuf::Field::try_from)
286        .map(|v| v.map_err(DataFusionError::from))
287        .collect::<Result<Vec<_>>>())
288    .into_iter()
289    .map(|field| field.encode_to_vec().into())
290    .collect();
291
292    RResult::ROk(state_fields)
293}
294
295unsafe extern "C" fn order_sensitivity_fn_wrapper(
296    udaf: &FFI_AggregateUDF,
297) -> FFI_AggregateOrderSensitivity {
298    udaf.inner().order_sensitivity().into()
299}
300
301unsafe extern "C" fn coerce_types_fn_wrapper(
302    udaf: &FFI_AggregateUDF,
303    arg_types: RVec<WrappedSchema>,
304) -> RResult<RVec<WrappedSchema>, RString> {
305    let udaf = udaf.inner();
306
307    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
308
309    let arg_fields = arg_types
310        .iter()
311        .map(|dt| Field::new("f", dt.clone(), true))
312        .map(Arc::new)
313        .collect::<Vec<_>>();
314    let return_types = rresult_return!(fields_with_aggregate_udf(&arg_fields, udaf))
315        .into_iter()
316        .map(|f| f.data_type().to_owned())
317        .collect::<Vec<_>>();
318
319    rresult!(vec_datatype_to_rvec_wrapped(&return_types))
320}
321
322unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) {
323    let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData);
324    drop(private_data);
325}
326
327unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF {
328    Arc::clone(udaf.inner()).into()
329}
330
331impl Clone for FFI_AggregateUDF {
332    fn clone(&self) -> Self {
333        unsafe { (self.clone)(self) }
334    }
335}
336
337impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
338    fn from(udaf: Arc<AggregateUDF>) -> Self {
339        let name = udaf.name().into();
340        let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect();
341        let is_nullable = udaf.is_nullable();
342        let volatility = udaf.signature().volatility.into();
343
344        let private_data = Box::new(AggregateUDFPrivateData { udaf });
345
346        Self {
347            name,
348            is_nullable,
349            volatility,
350            aliases,
351            return_field: return_field_fn_wrapper,
352            accumulator: accumulator_fn_wrapper,
353            create_sliding_accumulator: create_sliding_accumulator_fn_wrapper,
354            create_groups_accumulator: create_groups_accumulator_fn_wrapper,
355            groups_accumulator_supported: groups_accumulator_supported_fn_wrapper,
356            with_beneficial_ordering: with_beneficial_ordering_fn_wrapper,
357            state_fields: state_fields_fn_wrapper,
358            order_sensitivity: order_sensitivity_fn_wrapper,
359            coerce_types: coerce_types_fn_wrapper,
360            clone: clone_fn_wrapper,
361            release: release_fn_wrapper,
362            private_data: Box::into_raw(private_data) as *mut c_void,
363        }
364    }
365}
366
367impl Drop for FFI_AggregateUDF {
368    fn drop(&mut self) {
369        unsafe { (self.release)(self) }
370    }
371}
372
/// This struct is used to access a UDAF provided by a foreign
/// library across a FFI boundary.
///
/// The ForeignAggregateUDF is to be used by the caller of the UDAF, so it has
/// no knowledge or access to the private data. All interaction with the UDAF
/// must occur through the functions defined in FFI_AggregateUDF.
#[derive(Debug)]
pub struct ForeignAggregateUDF {
    /// Always a user-defined signature carrying the volatility reported by the
    /// foreign UDAF; type coercion is delegated to the foreign side.
    signature: Signature,
    /// Owned copy of the aliases reported by the foreign UDAF.
    aliases: Vec<String>,
    /// Handle used for every call into the foreign library.
    udaf: FFI_AggregateUDF,
}

// NOTE(review): asserted by hand alongside the matching impls on
// `FFI_AggregateUDF` — confirm the foreign provider's data is thread-safe.
unsafe impl Send for ForeignAggregateUDF {}
unsafe impl Sync for ForeignAggregateUDF {}
388
389impl PartialEq for ForeignAggregateUDF {
390    fn eq(&self, other: &Self) -> bool {
391        // FFI_AggregateUDF cannot be compared, so identity equality is the best we can do.
392        std::ptr::eq(self, other)
393    }
394}
395impl Eq for ForeignAggregateUDF {}
396impl Hash for ForeignAggregateUDF {
397    fn hash<H: Hasher>(&self, state: &mut H) {
398        std::ptr::hash(self, state)
399    }
400}
401
402impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF {
403    type Error = DataFusionError;
404
405    fn try_from(udaf: &FFI_AggregateUDF) -> Result<Self, Self::Error> {
406        let signature = Signature::user_defined((&udaf.volatility).into());
407        let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect();
408
409        Ok(Self {
410            udaf: udaf.clone(),
411            signature,
412            aliases,
413        })
414    }
415}
416
417impl AggregateUDFImpl for ForeignAggregateUDF {
418    fn as_any(&self) -> &dyn std::any::Any {
419        self
420    }
421
422    fn name(&self) -> &str {
423        self.udaf.name.as_str()
424    }
425
426    fn signature(&self) -> &Signature {
427        &self.signature
428    }
429
430    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
431        unimplemented!()
432    }
433
434    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
435        let arg_fields = vec_fieldref_to_rvec_wrapped(arg_fields)?;
436
437        let result = unsafe { (self.udaf.return_field)(&self.udaf, arg_fields) };
438
439        let result = df_result!(result);
440
441        result.and_then(|r| {
442            Field::try_from(&r.0)
443                .map(Arc::new)
444                .map_err(DataFusionError::from)
445        })
446    }
447
448    fn is_nullable(&self) -> bool {
449        self.udaf.is_nullable
450    }
451
452    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
453        let args = acc_args.try_into()?;
454        unsafe {
455            df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| {
456                Box::new(ForeignAccumulator::from(accum)) as Box<dyn Accumulator>
457            })
458        }
459    }
460
461    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
462        unsafe {
463            let name = RStr::from_str(args.name);
464            let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?;
465            let return_field =
466                WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
467            let ordering_fields = args
468                .ordering_fields
469                .iter()
470                .map(|f| f.as_ref())
471                .map(datafusion_proto::protobuf::Field::try_from)
472                .map(|v| v.map_err(DataFusionError::from))
473                .collect::<Result<Vec<_>>>()?
474                .into_iter()
475                .map(|proto_field| proto_field.encode_to_vec().into())
476                .collect();
477
478            let fields = df_result!((self.udaf.state_fields)(
479                &self.udaf,
480                &name,
481                input_fields,
482                return_field,
483                ordering_fields,
484                args.is_distinct
485            ))?;
486            let fields = fields
487                .into_iter()
488                .map(|field_bytes| {
489                    datafusion_proto_common::Field::decode(field_bytes.as_ref())
490                        .map_err(|e| DataFusionError::Execution(e.to_string()))
491                })
492                .collect::<Result<Vec<_>>>()?;
493
494            parse_proto_fields_to_fields(fields.iter())
495                .map(|fields| fields.into_iter().map(Arc::new).collect())
496                .map_err(|e| DataFusionError::Execution(e.to_string()))
497        }
498    }
499
500    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
501        let args = match FFI_AccumulatorArgs::try_from(args) {
502            Ok(v) => v,
503            Err(e) => {
504                log::warn!("Attempting to convert accumulator arguments: {e}");
505                return false;
506            }
507        };
508
509        unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) }
510    }
511
512    fn create_groups_accumulator(
513        &self,
514        args: AccumulatorArgs,
515    ) -> Result<Box<dyn GroupsAccumulator>> {
516        let args = FFI_AccumulatorArgs::try_from(args)?;
517
518        unsafe {
519            df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map(
520                |accum| {
521                    Box::new(ForeignGroupsAccumulator::from(accum))
522                        as Box<dyn GroupsAccumulator>
523                },
524            )
525        }
526    }
527
528    fn aliases(&self) -> &[String] {
529        &self.aliases
530    }
531
532    fn create_sliding_accumulator(
533        &self,
534        args: AccumulatorArgs,
535    ) -> Result<Box<dyn Accumulator>> {
536        let args = args.try_into()?;
537        unsafe {
538            df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map(
539                |accum| Box::new(ForeignAccumulator::from(accum)) as Box<dyn Accumulator>,
540            )
541        }
542    }
543
544    fn with_beneficial_ordering(
545        self: Arc<Self>,
546        beneficial_ordering: bool,
547    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
548        unsafe {
549            let result = df_result!((self.udaf.with_beneficial_ordering)(
550                &self.udaf,
551                beneficial_ordering
552            ))?
553            .into_option();
554
555            let result = result
556                .map(|func| ForeignAggregateUDF::try_from(&func))
557                .transpose()?;
558
559            Ok(result.map(|func| Arc::new(func) as Arc<dyn AggregateUDFImpl>))
560        }
561    }
562
563    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
564        unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() }
565    }
566
567    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
568        None
569    }
570
571    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
572        unsafe {
573            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
574            let result_types =
575                df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?;
576            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
577        }
578    }
579}
580
/// FFI-stable counterpart of [`AggregateOrderSensitivity`].
///
/// The variants map one-to-one onto the variants of the same name in
/// [`AggregateOrderSensitivity`]; see the `From` impls in both directions.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub enum FFI_AggregateOrderSensitivity {
    /// Maps to [`AggregateOrderSensitivity::Insensitive`].
    Insensitive,
    /// Maps to [`AggregateOrderSensitivity::HardRequirement`].
    HardRequirement,
    /// Maps to [`AggregateOrderSensitivity::SoftRequirement`].
    SoftRequirement,
    /// Maps to [`AggregateOrderSensitivity::Beneficial`].
    Beneficial,
}
590
591impl From<FFI_AggregateOrderSensitivity> for AggregateOrderSensitivity {
592    fn from(value: FFI_AggregateOrderSensitivity) -> Self {
593        match value {
594            FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive,
595            FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
596            FFI_AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
597            FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial,
598        }
599    }
600}
601
602impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
603    fn from(value: AggregateOrderSensitivity) -> Self {
604        match value {
605            AggregateOrderSensitivity::Insensitive => Self::Insensitive,
606            AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
607            AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
608            AggregateOrderSensitivity::Beneficial => Self::Beneficial,
609        }
610    }
611}
612
#[cfg(test)]
mod tests {
    use arrow::datatypes::Schema;
    use datafusion::{
        common::create_array, functions_aggregate::sum::Sum,
        physical_expr::PhysicalSortExpr, physical_plan::expressions::col,
        scalar::ScalarValue,
    };
    use std::any::Any;
    use std::collections::HashMap;

    use super::*;

    /// A `Sum` wrapper whose `return_field` echoes its first argument field,
    /// used to check that field metadata survives the FFI round trip.
    #[derive(Default, Debug, Hash, Eq, PartialEq)]
    struct SumWithCopiedMetadata {
        inner: Sum,
    }

    impl AggregateUDFImpl for SumWithCopiedMetadata {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.inner.name()
        }

        fn signature(&self) -> &Signature {
            self.inner.signature()
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
            // Copy the input field, so any metadata gets returned
            Ok(Arc::clone(&arg_fields[0]))
        }

        fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
            self.inner.accumulator(acc_args)
        }
    }

    /// Sends `original_udaf` through the FFI representation and back, returning
    /// the resulting foreign-backed [`AggregateUDF`].
    fn create_test_foreign_udaf(
        original_udaf: impl AggregateUDFImpl + 'static,
    ) -> Result<AggregateUDF> {
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();

        let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
        Ok(foreign_udaf.into())
    }

    #[test]
    fn test_round_trip_udaf() -> Result<()> {
        let original_udaf = Sum::new();
        let original_name = original_udaf.name().to_owned();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        // Convert to FFI format
        let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();

        // Convert back to native format
        let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
        let foreign_udaf: AggregateUDF = foreign_udaf.into();

        assert_eq!(original_name, foreign_udaf.name());
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_aliases() -> Result<()> {
        let foreign_udaf =
            create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]);

        // The alias added on top of the foreign UDAF must be visible.
        assert!(foreign_udaf
            .aliases()
            .iter()
            .any(|alias| alias == "my_function"));

        // The aliased function must still be callable through the FFI layer.
        let return_field =
            foreign_udaf
                .return_field(&[Field::new("a", DataType::Float64, true).into()])?;
        let return_type = return_field.data_type();
        assert_eq!(return_type, &DataType::Float64);
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &[col("a", &schema)?],
        };
        let mut accumulator = foreign_udaf.accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    #[test]
    fn test_round_trip_udaf_metadata() -> Result<()> {
        let original_udaf = SumWithCopiedMetadata::default();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        // Convert to FFI format
        let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();

        // Convert back to native format
        let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
        let foreign_udaf: AggregateUDF = foreign_udaf.into();

        let metadata: HashMap<String, String> =
            [("a_key".to_string(), "a_value".to_string())]
                .into_iter()
                .collect();
        let input_field = Arc::new(
            Field::new("a", DataType::Float64, false).with_metadata(metadata.clone()),
        );
        let return_field = foreign_udaf.return_field(&[input_field])?;

        assert_eq!(&metadata, return_field.metadata());
        Ok(())
    }

    #[test]
    fn test_beneficial_ordering() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(
            datafusion::functions_aggregate::first_last::FirstValue::new(),
        )?;

        let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap();

        assert_eq!(
            foreign_udaf.order_sensitivity(),
            AggregateOrderSensitivity::Beneficial
        );

        let a_field = Arc::new(Field::new("a", DataType::Float64, true));
        let state_fields = foreign_udaf.state_fields(StateFieldsArgs {
            name: "a",
            input_fields: &[Field::new("f", DataType::Float64, true).into()],
            return_field: Field::new("f", DataType::Float64, true).into(),
            ordering_fields: &[Arc::clone(&a_field)],
            is_distinct: false,
        })?;

        assert_eq!(state_fields.len(), 3);
        assert_eq!(state_fields[1], a_field);
        Ok(())
    }

    #[test]
    fn test_sliding_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        // Note: distinct sum only supports Int64 so far
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: false,
            exprs: &[col("a", &schema)?],
        };

        let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    /// Helper: converts `sensitivity` to its FFI form and back and checks that
    /// the value is unchanged.
    fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) {
        let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into();
        let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into();

        assert_eq!(sensitivity, round_trip_sensitivity);
    }

    #[test]
    fn test_round_trip_all_order_sensitivities() {
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::SoftRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial);
    }
}