datafusion_ffi/udaf/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use abi_stable::{
19    std_types::{ROption, RResult, RStr, RString, RVec},
20    StableAbi,
21};
22use accumulator::{FFI_Accumulator, ForeignAccumulator};
23use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
24use arrow::datatypes::{DataType, Field};
25use arrow::ffi::FFI_ArrowSchema;
26use arrow_schema::FieldRef;
27use datafusion::{
28    error::DataFusionError,
29    logical_expr::{
30        function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
31        type_coercion::functions::fields_with_aggregate_udf,
32        utils::AggregateOrderSensitivity,
33        Accumulator, GroupsAccumulator,
34    },
35};
36use datafusion::{
37    error::Result,
38    logical_expr::{AggregateUDF, AggregateUDFImpl, Signature},
39};
40use datafusion_common::exec_datafusion_err;
41use datafusion_proto_common::from_proto::parse_proto_fields_to_fields;
42use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator};
43use std::hash::{Hash, Hasher};
44use std::{ffi::c_void, sync::Arc};
45
46use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
47use crate::{
48    arrow_wrappers::WrappedSchema,
49    df_result, rresult, rresult_return,
50    util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
51    volatility::FFI_Volatility,
52};
53use prost::{DecodeError, Message};
54
55mod accumulator;
56mod accumulator_args;
57mod groups_accumulator;
58
/// A stable struct for sharing an [`AggregateUDF`] across FFI boundaries.
///
/// The struct is `#[repr(C)]` and is made of plain data fields (`name`,
/// `aliases`, `volatility`, `is_nullable`) followed by a table of
/// `extern "C"` function pointers. Every function pointer receives the struct
/// itself as its first argument. State owned by the provider is kept behind
/// `private_data`.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_AggregateUDF {
    /// FFI equivalent to the `name` of a [`AggregateUDF`]
    pub name: RString,

    /// FFI equivalent to the `aliases` of a [`AggregateUDF`]
    pub aliases: RVec<RString>,

    /// FFI equivalent to the `volatility` of a [`AggregateUDF`]
    pub volatility: FFI_Volatility,

    /// Determines the return field of the underlying [`AggregateUDF`] based on the
    /// argument fields. Fields cross the boundary as [`WrappedSchema`] values.
    pub return_field: unsafe extern "C" fn(
        udaf: &Self,
        arg_fields: RVec<WrappedSchema>,
    ) -> RResult<WrappedSchema, RString>,

    /// FFI equivalent to the `is_nullable` of a [`AggregateUDF`]
    pub is_nullable: bool,

    /// FFI equivalent to [`AggregateUDF::groups_accumulator_supported`]
    pub groups_accumulator_supported:
        unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool,

    /// FFI equivalent to [`AggregateUDF::accumulator`]
    pub accumulator: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        args: FFI_AccumulatorArgs,
    ) -> RResult<FFI_Accumulator, RString>,

    /// FFI equivalent to [`AggregateUDF::create_sliding_accumulator`]
    pub create_sliding_accumulator:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            args: FFI_AccumulatorArgs,
        ) -> RResult<FFI_Accumulator, RString>,

    /// FFI equivalent to [`AggregateUDF::state_fields`]
    ///
    /// The ordering fields passed in and the state fields returned are each
    /// transported as protobuf-encoded byte buffers, one buffer per field.
    #[allow(clippy::type_complexity)]
    pub state_fields: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        name: &RStr,
        input_fields: RVec<WrappedSchema>,
        return_field: WrappedSchema,
        ordering_fields: RVec<RVec<u8>>,
        is_distinct: bool,
    ) -> RResult<RVec<RVec<u8>>, RString>,

    /// FFI equivalent to [`AggregateUDF::create_groups_accumulator`]
    pub create_groups_accumulator:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            args: FFI_AccumulatorArgs,
        ) -> RResult<FFI_GroupsAccumulator, RString>,

    /// FFI equivalent to [`AggregateUDF::with_beneficial_ordering`]
    pub with_beneficial_ordering:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            beneficial_ordering: bool,
        ) -> RResult<ROption<FFI_AggregateUDF>, RString>,

    /// FFI equivalent to [`AggregateUDF::order_sensitivity`]
    pub order_sensitivity:
        unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity,

    /// Performs type coercion. To simplify this interface, all UDFs are treated as
    /// having user defined signatures, which in turn causes `coerce_types` to be
    /// called. This should be transparent to most users, as the internal function
    /// performs the appropriate calls on the underlying [`AggregateUDF`].
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Used to create a clone on the provider of the udaf. This should
    /// only need to be called by the receiver of the udaf.
    pub clone: unsafe extern "C" fn(udaf: &Self) -> Self,

    /// Release the memory of the private data when it is no longer being used.
    pub release: unsafe extern "C" fn(udaf: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the udaf.
    /// A [`ForeignAggregateUDF`] should never attempt to access this data.
    pub private_data: *mut c_void,
}
149
// NOTE(review): `private_data` is a raw pointer, so the compiler does not derive
// `Send`/`Sync` for this struct; these impls assert both manually. For the
// provider in this file the pointer targets an `AggregateUDFPrivateData` holding
// an `Arc<AggregateUDF>` — presumably safe to share across threads; confirm for
// structs produced by other providers.
unsafe impl Send for FFI_AggregateUDF {}
unsafe impl Sync for FFI_AggregateUDF {}
152
/// Provider-side state stored behind [`FFI_AggregateUDF::private_data`].
///
/// It is boxed and turned into a raw pointer when an [`FFI_AggregateUDF`] is
/// built from an `Arc<AggregateUDF>`, and reclaimed by the `release` callback.
pub struct AggregateUDFPrivateData {
    /// The wrapped aggregate function that the FFI callbacks delegate to.
    pub udaf: Arc<AggregateUDF>,
}
156
157impl FFI_AggregateUDF {
158    unsafe fn inner(&self) -> &Arc<AggregateUDF> {
159        let private_data = self.private_data as *const AggregateUDFPrivateData;
160        &(*private_data).udaf
161    }
162}
163
164unsafe extern "C" fn return_field_fn_wrapper(
165    udaf: &FFI_AggregateUDF,
166    arg_fields: RVec<WrappedSchema>,
167) -> RResult<WrappedSchema, RString> {
168    let udaf = udaf.inner();
169
170    let arg_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&arg_fields));
171
172    let return_field = udaf
173        .return_field(&arg_fields)
174        .and_then(|v| {
175            FFI_ArrowSchema::try_from(v.as_ref()).map_err(DataFusionError::from)
176        })
177        .map(WrappedSchema);
178
179    rresult!(return_field)
180}
181
182unsafe extern "C" fn accumulator_fn_wrapper(
183    udaf: &FFI_AggregateUDF,
184    args: FFI_AccumulatorArgs,
185) -> RResult<FFI_Accumulator, RString> {
186    let udaf = udaf.inner();
187
188    let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
189
190    rresult!(udaf
191        .accumulator(accumulator_args.into())
192        .map(FFI_Accumulator::from))
193}
194
195unsafe extern "C" fn create_sliding_accumulator_fn_wrapper(
196    udaf: &FFI_AggregateUDF,
197    args: FFI_AccumulatorArgs,
198) -> RResult<FFI_Accumulator, RString> {
199    let udaf = udaf.inner();
200
201    let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
202
203    rresult!(udaf
204        .create_sliding_accumulator(accumulator_args.into())
205        .map(FFI_Accumulator::from))
206}
207
208unsafe extern "C" fn create_groups_accumulator_fn_wrapper(
209    udaf: &FFI_AggregateUDF,
210    args: FFI_AccumulatorArgs,
211) -> RResult<FFI_GroupsAccumulator, RString> {
212    let udaf = udaf.inner();
213
214    let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
215
216    rresult!(udaf
217        .create_groups_accumulator(accumulator_args.into())
218        .map(FFI_GroupsAccumulator::from))
219}
220
221unsafe extern "C" fn groups_accumulator_supported_fn_wrapper(
222    udaf: &FFI_AggregateUDF,
223    args: FFI_AccumulatorArgs,
224) -> bool {
225    let udaf = udaf.inner();
226
227    ForeignAccumulatorArgs::try_from(args)
228        .map(|a| udaf.groups_accumulator_supported((&a).into()))
229        .unwrap_or_else(|e| {
230            log::warn!("Unable to parse accumulator args. {e}");
231            false
232        })
233}
234
235unsafe extern "C" fn with_beneficial_ordering_fn_wrapper(
236    udaf: &FFI_AggregateUDF,
237    beneficial_ordering: bool,
238) -> RResult<ROption<FFI_AggregateUDF>, RString> {
239    let udaf = udaf.inner().as_ref().clone();
240
241    let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering));
242    let result = rresult_return!(result
243        .map(|func| func.with_beneficial_ordering(beneficial_ordering))
244        .transpose())
245    .flatten()
246    .map(|func| FFI_AggregateUDF::from(Arc::new(func)));
247
248    RResult::ROk(result.into())
249}
250
251unsafe extern "C" fn state_fields_fn_wrapper(
252    udaf: &FFI_AggregateUDF,
253    name: &RStr,
254    input_fields: RVec<WrappedSchema>,
255    return_field: WrappedSchema,
256    ordering_fields: RVec<RVec<u8>>,
257    is_distinct: bool,
258) -> RResult<RVec<RVec<u8>>, RString> {
259    let udaf = udaf.inner();
260
261    let input_fields = &rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
262    let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
263
264    let ordering_fields = &rresult_return!(ordering_fields
265        .into_iter()
266        .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref()))
267        .collect::<std::result::Result<Vec<_>, DecodeError>>());
268
269    let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields))
270        .into_iter()
271        .map(Arc::new)
272        .collect::<Vec<_>>();
273
274    let args = StateFieldsArgs {
275        name: name.as_str(),
276        input_fields,
277        return_field,
278        ordering_fields,
279        is_distinct,
280    };
281
282    let state_fields = rresult_return!(udaf.state_fields(args));
283    let state_fields = rresult_return!(state_fields
284        .iter()
285        .map(|f| f.as_ref())
286        .map(datafusion_proto::protobuf::Field::try_from)
287        .map(|v| v.map_err(DataFusionError::from))
288        .collect::<Result<Vec<_>>>())
289    .into_iter()
290    .map(|field| field.encode_to_vec().into())
291    .collect();
292
293    RResult::ROk(state_fields)
294}
295
296unsafe extern "C" fn order_sensitivity_fn_wrapper(
297    udaf: &FFI_AggregateUDF,
298) -> FFI_AggregateOrderSensitivity {
299    udaf.inner().order_sensitivity().into()
300}
301
302unsafe extern "C" fn coerce_types_fn_wrapper(
303    udaf: &FFI_AggregateUDF,
304    arg_types: RVec<WrappedSchema>,
305) -> RResult<RVec<WrappedSchema>, RString> {
306    let udaf = udaf.inner();
307
308    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
309
310    let arg_fields = arg_types
311        .iter()
312        .map(|dt| Field::new("f", dt.clone(), true))
313        .map(Arc::new)
314        .collect::<Vec<_>>();
315    let return_types = rresult_return!(fields_with_aggregate_udf(&arg_fields, udaf))
316        .into_iter()
317        .map(|f| f.data_type().to_owned())
318        .collect::<Vec<_>>();
319
320    rresult!(vec_datatype_to_rvec_wrapped(&return_types))
321}
322
323unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) {
324    let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData);
325    drop(private_data);
326}
327
328unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF {
329    Arc::clone(udaf.inner()).into()
330}
331
332impl Clone for FFI_AggregateUDF {
333    fn clone(&self) -> Self {
334        unsafe { (self.clone)(self) }
335    }
336}
337
338impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
339    fn from(udaf: Arc<AggregateUDF>) -> Self {
340        let name = udaf.name().into();
341        let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect();
342        let is_nullable = udaf.is_nullable();
343        let volatility = udaf.signature().volatility.into();
344
345        let private_data = Box::new(AggregateUDFPrivateData { udaf });
346
347        Self {
348            name,
349            is_nullable,
350            volatility,
351            aliases,
352            return_field: return_field_fn_wrapper,
353            accumulator: accumulator_fn_wrapper,
354            create_sliding_accumulator: create_sliding_accumulator_fn_wrapper,
355            create_groups_accumulator: create_groups_accumulator_fn_wrapper,
356            groups_accumulator_supported: groups_accumulator_supported_fn_wrapper,
357            with_beneficial_ordering: with_beneficial_ordering_fn_wrapper,
358            state_fields: state_fields_fn_wrapper,
359            order_sensitivity: order_sensitivity_fn_wrapper,
360            coerce_types: coerce_types_fn_wrapper,
361            clone: clone_fn_wrapper,
362            release: release_fn_wrapper,
363            private_data: Box::into_raw(private_data) as *mut c_void,
364        }
365    }
366}
367
368impl Drop for FFI_AggregateUDF {
369    fn drop(&mut self) {
370        unsafe { (self.release)(self) }
371    }
372}
373
/// This struct is used to access a UDAF provided by a foreign
/// library across a FFI boundary.
///
/// The ForeignAggregateUDF is to be used by the caller of the UDAF, so it has
/// no knowledge or access to the private data. All interaction with the UDAF
/// must occur through the functions defined in FFI_AggregateUDF.
#[derive(Debug)]
pub struct ForeignAggregateUDF {
    /// Locally built signature: always user-defined, carrying the volatility
    /// reported by the foreign function.
    signature: Signature,
    /// Aliases copied out of the FFI struct into native strings.
    aliases: Vec<String>,
    /// Clone of the FFI struct through which every call is dispatched.
    udaf: FFI_AggregateUDF,
}
386
// NOTE(review): asserts that the wrapper may be sent to and shared between
// threads. It embeds an `FFI_AggregateUDF`, which this file already declares
// `Send + Sync`; the thread-safety of the foreign implementation itself cannot
// be checked from here — confirm with the provider.
unsafe impl Send for ForeignAggregateUDF {}
unsafe impl Sync for ForeignAggregateUDF {}
389
390impl PartialEq for ForeignAggregateUDF {
391    fn eq(&self, other: &Self) -> bool {
392        // FFI_AggregateUDF cannot be compared, so identity equality is the best we can do.
393        std::ptr::eq(self, other)
394    }
395}
396impl Eq for ForeignAggregateUDF {}
397impl Hash for ForeignAggregateUDF {
398    fn hash<H: Hasher>(&self, state: &mut H) {
399        std::ptr::hash(self, state)
400    }
401}
402
403impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF {
404    type Error = DataFusionError;
405
406    fn try_from(udaf: &FFI_AggregateUDF) -> Result<Self, Self::Error> {
407        let signature = Signature::user_defined((&udaf.volatility).into());
408        let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect();
409
410        Ok(Self {
411            udaf: udaf.clone(),
412            signature,
413            aliases,
414        })
415    }
416}
417
418impl AggregateUDFImpl for ForeignAggregateUDF {
419    fn as_any(&self) -> &dyn std::any::Any {
420        self
421    }
422
423    fn name(&self) -> &str {
424        self.udaf.name.as_str()
425    }
426
427    fn signature(&self) -> &Signature {
428        &self.signature
429    }
430
431    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
432        unimplemented!()
433    }
434
435    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
436        let arg_fields = vec_fieldref_to_rvec_wrapped(arg_fields)?;
437
438        let result = unsafe { (self.udaf.return_field)(&self.udaf, arg_fields) };
439
440        let result = df_result!(result);
441
442        result.and_then(|r| {
443            Field::try_from(&r.0)
444                .map(Arc::new)
445                .map_err(DataFusionError::from)
446        })
447    }
448
449    fn is_nullable(&self) -> bool {
450        self.udaf.is_nullable
451    }
452
453    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
454        let args = acc_args.try_into()?;
455        unsafe {
456            df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| {
457                Box::new(ForeignAccumulator::from(accum)) as Box<dyn Accumulator>
458            })
459        }
460    }
461
462    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
463        unsafe {
464            let name = RStr::from_str(args.name);
465            let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?;
466            let return_field =
467                WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
468            let ordering_fields = args
469                .ordering_fields
470                .iter()
471                .map(|f| f.as_ref())
472                .map(datafusion_proto::protobuf::Field::try_from)
473                .map(|v| v.map_err(DataFusionError::from))
474                .collect::<Result<Vec<_>>>()?
475                .into_iter()
476                .map(|proto_field| proto_field.encode_to_vec().into())
477                .collect();
478
479            let fields = df_result!((self.udaf.state_fields)(
480                &self.udaf,
481                &name,
482                input_fields,
483                return_field,
484                ordering_fields,
485                args.is_distinct
486            ))?;
487            let fields = fields
488                .into_iter()
489                .map(|field_bytes| {
490                    datafusion_proto_common::Field::decode(field_bytes.as_ref())
491                        .map_err(|e| exec_datafusion_err!("{e}"))
492                })
493                .collect::<Result<Vec<_>>>()?;
494
495            parse_proto_fields_to_fields(fields.iter())
496                .map(|fields| fields.into_iter().map(Arc::new).collect())
497                .map_err(|e| exec_datafusion_err!("{e}"))
498        }
499    }
500
501    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
502        let args = match FFI_AccumulatorArgs::try_from(args) {
503            Ok(v) => v,
504            Err(e) => {
505                log::warn!("Attempting to convert accumulator arguments: {e}");
506                return false;
507            }
508        };
509
510        unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) }
511    }
512
513    fn create_groups_accumulator(
514        &self,
515        args: AccumulatorArgs,
516    ) -> Result<Box<dyn GroupsAccumulator>> {
517        let args = FFI_AccumulatorArgs::try_from(args)?;
518
519        unsafe {
520            df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map(
521                |accum| {
522                    Box::new(ForeignGroupsAccumulator::from(accum))
523                        as Box<dyn GroupsAccumulator>
524                },
525            )
526        }
527    }
528
529    fn aliases(&self) -> &[String] {
530        &self.aliases
531    }
532
533    fn create_sliding_accumulator(
534        &self,
535        args: AccumulatorArgs,
536    ) -> Result<Box<dyn Accumulator>> {
537        let args = args.try_into()?;
538        unsafe {
539            df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map(
540                |accum| Box::new(ForeignAccumulator::from(accum)) as Box<dyn Accumulator>,
541            )
542        }
543    }
544
545    fn with_beneficial_ordering(
546        self: Arc<Self>,
547        beneficial_ordering: bool,
548    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
549        unsafe {
550            let result = df_result!((self.udaf.with_beneficial_ordering)(
551                &self.udaf,
552                beneficial_ordering
553            ))?
554            .into_option();
555
556            let result = result
557                .map(|func| ForeignAggregateUDF::try_from(&func))
558                .transpose()?;
559
560            Ok(result.map(|func| Arc::new(func) as Arc<dyn AggregateUDFImpl>))
561        }
562    }
563
564    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
565        unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() }
566    }
567
568    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
569        None
570    }
571
572    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
573        unsafe {
574            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
575            let result_types =
576                df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?;
577            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
578        }
579    }
580}
581
/// FFI-stable mirror of [`AggregateOrderSensitivity`].
///
/// The variants correspond one-to-one to the native enum; the `From` impls in
/// this file convert in both directions.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub enum FFI_AggregateOrderSensitivity {
    /// Mirrors [`AggregateOrderSensitivity::Insensitive`].
    Insensitive,
    /// Mirrors [`AggregateOrderSensitivity::HardRequirement`].
    HardRequirement,
    /// Mirrors [`AggregateOrderSensitivity::SoftRequirement`].
    SoftRequirement,
    /// Mirrors [`AggregateOrderSensitivity::Beneficial`].
    Beneficial,
}
591
592impl From<FFI_AggregateOrderSensitivity> for AggregateOrderSensitivity {
593    fn from(value: FFI_AggregateOrderSensitivity) -> Self {
594        match value {
595            FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive,
596            FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
597            FFI_AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
598            FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial,
599        }
600    }
601}
602
603impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
604    fn from(value: AggregateOrderSensitivity) -> Self {
605        match value {
606            AggregateOrderSensitivity::Insensitive => Self::Insensitive,
607            AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
608            AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
609            AggregateOrderSensitivity::Beneficial => Self::Beneficial,
610        }
611    }
612}
613
#[cfg(test)]
mod tests {
    use arrow::datatypes::Schema;
    use datafusion::{
        common::create_array, functions_aggregate::sum::Sum,
        physical_expr::PhysicalSortExpr, physical_plan::expressions::col,
        scalar::ScalarValue,
    };
    use std::any::Any;
    use std::collections::HashMap;

    use super::*;

    /// `Sum` variant whose return field is a copy of its first argument field,
    /// used to check that field metadata survives the FFI round trip.
    #[derive(Default, Debug, Hash, Eq, PartialEq)]
    struct SumWithCopiedMetadata {
        inner: Sum,
    }

    impl AggregateUDFImpl for SumWithCopiedMetadata {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.inner.name()
        }

        fn signature(&self) -> &Signature {
            self.inner.signature()
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
            // Hand back the first input field unchanged so its metadata is kept
            Ok(Arc::clone(&arg_fields[0]))
        }

        fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
            self.inner.accumulator(acc_args)
        }
    }

    /// Sends `original_udaf` through the FFI layer (native -> FFI -> foreign)
    /// and returns the foreign side wrapped in a regular [`AggregateUDF`].
    fn create_test_foreign_udaf(
        original_udaf: impl AggregateUDFImpl + 'static,
    ) -> Result<AggregateUDF> {
        let provider_side =
            FFI_AggregateUDF::from(Arc::new(AggregateUDF::from(original_udaf)));

        let consumer_side = ForeignAggregateUDF::try_from(&provider_side)?;
        Ok(AggregateUDF::from(consumer_side))
    }

    /// Feeds the values 10..=50 (step 10) into `accumulator` and checks the sum.
    fn assert_sums_to_150(mut accumulator: Box<dyn Accumulator>) -> Result<()> {
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;

        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));
        Ok(())
    }

    #[test]
    fn test_round_trip_udaf() -> Result<()> {
        let sum = Sum::new();
        let original_name = sum.name().to_owned();

        // Native -> FFI -> native
        let foreign_udaf = create_test_foreign_udaf(sum)?;

        assert_eq!(original_name, foreign_udaf.name());
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_aliases() -> Result<()> {
        let foreign_udaf =
            create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]);

        let arg_field = Arc::new(Field::new("a", DataType::Float64, true));
        let return_field = foreign_udaf.return_field(&[arg_field])?;

        assert_eq!(return_field.data_type(), &DataType::Float64);
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let expr_fields = [Arc::new(Field::new("a", DataType::Float64, true))];
        let order_bys = [PhysicalSortExpr::new_default(col("a", &schema)?)];
        let exprs = [col("a", &schema)?];

        let acc_args = AccumulatorArgs {
            return_field: Arc::new(Field::new("f", DataType::Float64, true)),
            schema: &schema,
            expr_fields: &expr_fields,
            ignore_nulls: true,
            order_bys: &order_bys,
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &exprs,
        };

        assert_sums_to_150(foreign_udaf.accumulator(acc_args)?)
    }

    #[test]
    fn test_round_trip_udaf_metadata() -> Result<()> {
        // Native -> FFI -> native
        let foreign_udaf = create_test_foreign_udaf(SumWithCopiedMetadata::default())?;

        let metadata: HashMap<String, String> =
            HashMap::from([("a_key".to_string(), "a_value".to_string())]);
        let input_field = Arc::new(
            Field::new("a", DataType::Float64, false).with_metadata(metadata.clone()),
        );

        let return_field = foreign_udaf.return_field(&[input_field])?;

        assert_eq!(&metadata, return_field.metadata());
        Ok(())
    }

    #[test]
    fn test_beneficial_ordering() -> Result<()> {
        let first_value = datafusion::functions_aggregate::first_last::FirstValue::new();
        let foreign_udaf = create_test_foreign_udaf(first_value)?
            .with_beneficial_ordering(true)?
            .unwrap();

        assert_eq!(
            foreign_udaf.order_sensitivity(),
            AggregateOrderSensitivity::Beneficial
        );

        let a_field = Arc::new(Field::new("a", DataType::Float64, true));
        let input_fields = [Arc::new(Field::new("f", DataType::Float64, true))];
        let ordering_fields = [Arc::clone(&a_field)];

        let state_fields = foreign_udaf.state_fields(StateFieldsArgs {
            name: "a",
            input_fields: &input_fields,
            return_field: Arc::new(Field::new("f", DataType::Float64, true)),
            ordering_fields: &ordering_fields,
            is_distinct: false,
        })?;

        assert_eq!(state_fields.len(), 3);
        assert_eq!(state_fields[1], a_field);
        Ok(())
    }

    #[test]
    fn test_sliding_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let expr_fields = [Arc::new(Field::new("a", DataType::Float64, true))];
        let order_bys = [PhysicalSortExpr::new_default(col("a", &schema)?)];
        let exprs = [col("a", &schema)?];

        // Note: sum distinct only supports Int64 so far
        let acc_args = AccumulatorArgs {
            return_field: Arc::new(Field::new("f", DataType::Float64, true)),
            schema: &schema,
            expr_fields: &expr_fields,
            ignore_nulls: true,
            order_bys: &order_bys,
            is_reversed: false,
            name: "round_trip",
            is_distinct: false,
            exprs: &exprs,
        };

        assert_sums_to_150(foreign_udaf.create_sliding_accumulator(acc_args)?)
    }

    /// Converts `sensitivity` to its FFI form and back, expecting no change.
    fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) {
        let ffi_sensitivity = FFI_AggregateOrderSensitivity::from(sensitivity);
        let round_trip_sensitivity = AggregateOrderSensitivity::from(ffi_sensitivity);

        assert_eq!(sensitivity, round_trip_sensitivity);
    }

    #[test]
    fn test_round_trip_all_order_sensitivities() {
        for sensitivity in [
            AggregateOrderSensitivity::Insensitive,
            AggregateOrderSensitivity::HardRequirement,
            AggregateOrderSensitivity::SoftRequirement,
            AggregateOrderSensitivity::Beneficial,
        ] {
            test_round_trip_order_sensitivity(sensitivity);
        }
    }
}
818}