datafusion_ffi/udaf/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::c_void;
19use std::hash::{Hash, Hasher};
20use std::sync::Arc;
21
22use abi_stable::StableAbi;
23use abi_stable::std_types::{ROption, RResult, RStr, RString, RVec};
24use accumulator::FFI_Accumulator;
25use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
26use arrow::datatypes::{DataType, Field};
27use arrow::ffi::FFI_ArrowSchema;
28use arrow_schema::FieldRef;
29use datafusion_common::{DataFusionError, Result, ffi_datafusion_err};
30use datafusion_expr::function::AggregateFunctionSimplification;
31use datafusion_expr::type_coercion::functions::fields_with_udf;
32use datafusion_expr::{
33    Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
34};
35use datafusion_functions_aggregate_common::accumulator::{
36    AccumulatorArgs, StateFieldsArgs,
37};
38use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
39use datafusion_proto_common::from_proto::parse_proto_fields_to_fields;
40use groups_accumulator::FFI_GroupsAccumulator;
41use prost::{DecodeError, Message};
42
43use crate::arrow_wrappers::WrappedSchema;
44use crate::util::{
45    FFIResult, rvec_wrapped_to_vec_datatype, rvec_wrapped_to_vec_fieldref,
46    vec_datatype_to_rvec_wrapped, vec_fieldref_to_rvec_wrapped,
47};
48use crate::volatility::FFI_Volatility;
49use crate::{df_result, rresult, rresult_return};
50
51mod accumulator;
52mod accumulator_args;
53mod groups_accumulator;
54
55/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries.
56#[repr(C)]
57#[derive(Debug, StableAbi)]
58pub struct FFI_AggregateUDF {
59    /// FFI equivalent to the `name` of a [`AggregateUDF`]
60    pub name: RString,
61
62    /// FFI equivalent to the `aliases` of a [`AggregateUDF`]
63    pub aliases: RVec<RString>,
64
65    /// FFI equivalent to the `volatility` of a [`AggregateUDF`]
66    pub volatility: FFI_Volatility,
67
68    /// Determines the return field of the underlying [`AggregateUDF`] based on the
69    /// argument fields.
70    pub return_field: unsafe extern "C" fn(
71        udaf: &Self,
72        arg_fields: RVec<WrappedSchema>,
73    ) -> FFIResult<WrappedSchema>,
74
75    /// FFI equivalent to the `is_nullable` of a [`AggregateUDF`]
76    pub is_nullable: bool,
77
78    /// FFI equivalent to [`AggregateUDF::groups_accumulator_supported`]
79    pub groups_accumulator_supported:
80        unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool,
81
82    /// FFI equivalent to [`AggregateUDF::accumulator`]
83    pub accumulator: unsafe extern "C" fn(
84        udaf: &FFI_AggregateUDF,
85        args: FFI_AccumulatorArgs,
86    ) -> FFIResult<FFI_Accumulator>,
87
88    /// FFI equivalent to [`AggregateUDF::create_sliding_accumulator`]
89    pub create_sliding_accumulator: unsafe extern "C" fn(
90        udaf: &FFI_AggregateUDF,
91        args: FFI_AccumulatorArgs,
92    )
93        -> FFIResult<FFI_Accumulator>,
94
95    /// FFI equivalent to [`AggregateUDF::state_fields`]
96    pub state_fields: unsafe extern "C" fn(
97        udaf: &FFI_AggregateUDF,
98        name: &RStr,
99        input_fields: RVec<WrappedSchema>,
100        return_field: WrappedSchema,
101        ordering_fields: RVec<RVec<u8>>,
102        is_distinct: bool,
103    ) -> FFIResult<RVec<RVec<u8>>>,
104
105    /// FFI equivalent to [`AggregateUDF::create_groups_accumulator`]
106    pub create_groups_accumulator:
107        unsafe extern "C" fn(
108            udaf: &FFI_AggregateUDF,
109            args: FFI_AccumulatorArgs,
110        ) -> FFIResult<FFI_GroupsAccumulator>,
111
112    /// FFI equivalent to [`AggregateUDF::with_beneficial_ordering`]
113    pub with_beneficial_ordering:
114        unsafe extern "C" fn(
115            udaf: &FFI_AggregateUDF,
116            beneficial_ordering: bool,
117        ) -> FFIResult<ROption<FFI_AggregateUDF>>,
118
119    /// FFI equivalent to [`AggregateUDF::order_sensitivity`]
120    pub order_sensitivity:
121        unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity,
122
123    /// Performs type coercion. To simply this interface, all UDFs are treated as having
124    /// user defined signatures, which will in turn call coerce_types to be called. This
125    /// call should be transparent to most users as the internal function performs the
126    /// appropriate calls on the underlying [`AggregateUDF`]
127    pub coerce_types: unsafe extern "C" fn(
128        udf: &Self,
129        arg_types: RVec<WrappedSchema>,
130    ) -> FFIResult<RVec<WrappedSchema>>,
131
132    /// Used to create a clone on the provider of the udaf. This should
133    /// only need to be called by the receiver of the udaf.
134    pub clone: unsafe extern "C" fn(udaf: &Self) -> Self,
135
136    /// Release the memory of the private data when it is no longer being used.
137    pub release: unsafe extern "C" fn(udaf: &mut Self),
138
139    /// Internal data. This is only to be accessed by the provider of the udaf.
140    /// A [`ForeignAggregateUDF`] should never attempt to access this data.
141    pub private_data: *mut c_void,
142
143    /// Utility to identify when FFI objects are accessed locally through
144    /// the foreign interface. See [`crate::get_library_marker_id`] and
145    /// the crate's `README.md` for more information.
146    pub library_marker_id: extern "C" fn() -> usize,
147}
148
149unsafe impl Send for FFI_AggregateUDF {}
150unsafe impl Sync for FFI_AggregateUDF {}
151
152pub struct AggregateUDFPrivateData {
153    pub udaf: Arc<AggregateUDF>,
154}
155
156impl FFI_AggregateUDF {
157    unsafe fn inner(&self) -> &Arc<AggregateUDF> {
158        unsafe {
159            let private_data = self.private_data as *const AggregateUDFPrivateData;
160            &(*private_data).udaf
161        }
162    }
163}
164
165unsafe extern "C" fn return_field_fn_wrapper(
166    udaf: &FFI_AggregateUDF,
167    arg_fields: RVec<WrappedSchema>,
168) -> FFIResult<WrappedSchema> {
169    unsafe {
170        let udaf = udaf.inner();
171
172        let arg_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&arg_fields));
173
174        let return_field = udaf
175            .return_field(&arg_fields)
176            .and_then(|v| {
177                FFI_ArrowSchema::try_from(v.as_ref()).map_err(DataFusionError::from)
178            })
179            .map(WrappedSchema);
180
181        rresult!(return_field)
182    }
183}
184
185unsafe extern "C" fn accumulator_fn_wrapper(
186    udaf: &FFI_AggregateUDF,
187    args: FFI_AccumulatorArgs,
188) -> FFIResult<FFI_Accumulator> {
189    unsafe {
190        let udaf = udaf.inner();
191
192        let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
193
194        rresult!(
195            udaf.accumulator(accumulator_args.into())
196                .map(FFI_Accumulator::from)
197        )
198    }
199}
200
201unsafe extern "C" fn create_sliding_accumulator_fn_wrapper(
202    udaf: &FFI_AggregateUDF,
203    args: FFI_AccumulatorArgs,
204) -> FFIResult<FFI_Accumulator> {
205    unsafe {
206        let udaf = udaf.inner();
207
208        let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
209
210        rresult!(
211            udaf.create_sliding_accumulator(accumulator_args.into())
212                .map(FFI_Accumulator::from)
213        )
214    }
215}
216
217unsafe extern "C" fn create_groups_accumulator_fn_wrapper(
218    udaf: &FFI_AggregateUDF,
219    args: FFI_AccumulatorArgs,
220) -> FFIResult<FFI_GroupsAccumulator> {
221    unsafe {
222        let udaf = udaf.inner();
223
224        let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args));
225
226        rresult!(
227            udaf.create_groups_accumulator(accumulator_args.into())
228                .map(FFI_GroupsAccumulator::from)
229        )
230    }
231}
232
233unsafe extern "C" fn groups_accumulator_supported_fn_wrapper(
234    udaf: &FFI_AggregateUDF,
235    args: FFI_AccumulatorArgs,
236) -> bool {
237    unsafe {
238        let udaf = udaf.inner();
239
240        ForeignAccumulatorArgs::try_from(args)
241            .map(|a| udaf.groups_accumulator_supported((&a).into()))
242            .unwrap_or_else(|e| {
243                log::warn!("Unable to parse accumulator args. {e}");
244                false
245            })
246    }
247}
248
249unsafe extern "C" fn with_beneficial_ordering_fn_wrapper(
250    udaf: &FFI_AggregateUDF,
251    beneficial_ordering: bool,
252) -> FFIResult<ROption<FFI_AggregateUDF>> {
253    unsafe {
254        let udaf = udaf.inner().as_ref().clone();
255
256        let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering));
257        let result = rresult_return!(
258            result
259                .map(|func| func.with_beneficial_ordering(beneficial_ordering))
260                .transpose()
261        )
262        .flatten()
263        .map(|func| FFI_AggregateUDF::from(Arc::new(func)));
264
265        RResult::ROk(result.into())
266    }
267}
268
269unsafe extern "C" fn state_fields_fn_wrapper(
270    udaf: &FFI_AggregateUDF,
271    name: &RStr,
272    input_fields: RVec<WrappedSchema>,
273    return_field: WrappedSchema,
274    ordering_fields: RVec<RVec<u8>>,
275    is_distinct: bool,
276) -> FFIResult<RVec<RVec<u8>>> {
277    unsafe {
278        let udaf = udaf.inner();
279
280        let input_fields = &rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
281        let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
282
283        let ordering_fields = &rresult_return!(
284            ordering_fields
285                .into_iter()
286                .map(|field_bytes| datafusion_proto_common::Field::decode(
287                    field_bytes.as_ref()
288                ))
289                .collect::<std::result::Result<Vec<_>, DecodeError>>()
290        );
291
292        let ordering_fields =
293            &rresult_return!(parse_proto_fields_to_fields(ordering_fields))
294                .into_iter()
295                .map(Arc::new)
296                .collect::<Vec<_>>();
297
298        let args = StateFieldsArgs {
299            name: name.as_str(),
300            input_fields,
301            return_field,
302            ordering_fields,
303            is_distinct,
304        };
305
306        let state_fields = rresult_return!(udaf.state_fields(args));
307        let state_fields = rresult_return!(
308            state_fields
309                .iter()
310                .map(|f| f.as_ref())
311                .map(datafusion_proto::protobuf::Field::try_from)
312                .map(|v| v.map_err(DataFusionError::from))
313                .collect::<Result<Vec<_>>>()
314        )
315        .into_iter()
316        .map(|field| field.encode_to_vec().into())
317        .collect();
318
319        RResult::ROk(state_fields)
320    }
321}
322
323unsafe extern "C" fn order_sensitivity_fn_wrapper(
324    udaf: &FFI_AggregateUDF,
325) -> FFI_AggregateOrderSensitivity {
326    unsafe { udaf.inner().order_sensitivity().into() }
327}
328
329unsafe extern "C" fn coerce_types_fn_wrapper(
330    udaf: &FFI_AggregateUDF,
331    arg_types: RVec<WrappedSchema>,
332) -> FFIResult<RVec<WrappedSchema>> {
333    unsafe {
334        let udaf = udaf.inner();
335
336        let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
337
338        let arg_fields = arg_types
339            .iter()
340            .map(|dt| Field::new("f", dt.clone(), true))
341            .map(Arc::new)
342            .collect::<Vec<_>>();
343        let return_types = rresult_return!(fields_with_udf(&arg_fields, udaf.as_ref()))
344            .into_iter()
345            .map(|f| f.data_type().to_owned())
346            .collect::<Vec<_>>();
347
348        rresult!(vec_datatype_to_rvec_wrapped(&return_types))
349    }
350}
351
352unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) {
353    unsafe {
354        debug_assert!(!udaf.private_data.is_null());
355        let private_data =
356            Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData);
357        drop(private_data);
358        udaf.private_data = std::ptr::null_mut();
359    }
360}
361
362unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF {
363    unsafe { Arc::clone(udaf.inner()).into() }
364}
365
366impl Clone for FFI_AggregateUDF {
367    fn clone(&self) -> Self {
368        unsafe { (self.clone)(self) }
369    }
370}
371
372impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
373    fn from(udaf: Arc<AggregateUDF>) -> Self {
374        let name = udaf.name().into();
375        let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect();
376        let is_nullable = udaf.is_nullable();
377        let volatility = udaf.signature().volatility.into();
378
379        let private_data = Box::new(AggregateUDFPrivateData { udaf });
380
381        Self {
382            name,
383            is_nullable,
384            volatility,
385            aliases,
386            return_field: return_field_fn_wrapper,
387            accumulator: accumulator_fn_wrapper,
388            create_sliding_accumulator: create_sliding_accumulator_fn_wrapper,
389            create_groups_accumulator: create_groups_accumulator_fn_wrapper,
390            groups_accumulator_supported: groups_accumulator_supported_fn_wrapper,
391            with_beneficial_ordering: with_beneficial_ordering_fn_wrapper,
392            state_fields: state_fields_fn_wrapper,
393            order_sensitivity: order_sensitivity_fn_wrapper,
394            coerce_types: coerce_types_fn_wrapper,
395            clone: clone_fn_wrapper,
396            release: release_fn_wrapper,
397            private_data: Box::into_raw(private_data) as *mut c_void,
398            library_marker_id: crate::get_library_marker_id,
399        }
400    }
401}
402
403impl Drop for FFI_AggregateUDF {
404    fn drop(&mut self) {
405        unsafe { (self.release)(self) }
406    }
407}
408
409/// This struct is used to access an UDF provided by a foreign
410/// library across a FFI boundary.
411///
412/// The ForeignAggregateUDF is to be used by the caller of the UDF, so it has
413/// no knowledge or access to the private data. All interaction with the UDF
414/// must occur through the functions defined in FFI_AggregateUDF.
415#[derive(Debug)]
416pub struct ForeignAggregateUDF {
417    signature: Signature,
418    aliases: Vec<String>,
419    udaf: FFI_AggregateUDF,
420}
421
422unsafe impl Send for ForeignAggregateUDF {}
423unsafe impl Sync for ForeignAggregateUDF {}
424
425impl PartialEq for ForeignAggregateUDF {
426    fn eq(&self, other: &Self) -> bool {
427        // FFI_AggregateUDF cannot be compared, so identity equality is the best we can do.
428        std::ptr::eq(self, other)
429    }
430}
431impl Eq for ForeignAggregateUDF {}
432impl Hash for ForeignAggregateUDF {
433    fn hash<H: Hasher>(&self, state: &mut H) {
434        std::ptr::hash(self, state)
435    }
436}
437
438impl From<&FFI_AggregateUDF> for Arc<dyn AggregateUDFImpl> {
439    fn from(udaf: &FFI_AggregateUDF) -> Self {
440        if (udaf.library_marker_id)() == crate::get_library_marker_id() {
441            return Arc::clone(unsafe { udaf.inner().inner() });
442        }
443
444        let signature = Signature::user_defined((&udaf.volatility).into());
445        let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect();
446
447        Arc::new(ForeignAggregateUDF {
448            udaf: udaf.clone(),
449            signature,
450            aliases,
451        })
452    }
453}
454
455impl AggregateUDFImpl for ForeignAggregateUDF {
456    fn as_any(&self) -> &dyn std::any::Any {
457        self
458    }
459
460    fn name(&self) -> &str {
461        self.udaf.name.as_str()
462    }
463
464    fn signature(&self) -> &Signature {
465        &self.signature
466    }
467
468    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
469        unimplemented!()
470    }
471
472    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
473        let arg_fields = vec_fieldref_to_rvec_wrapped(arg_fields)?;
474
475        let result = unsafe { (self.udaf.return_field)(&self.udaf, arg_fields) };
476
477        let result = df_result!(result);
478
479        result.and_then(|r| {
480            Field::try_from(&r.0)
481                .map(Arc::new)
482                .map_err(DataFusionError::from)
483        })
484    }
485
486    fn is_nullable(&self) -> bool {
487        self.udaf.is_nullable
488    }
489
490    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
491        let args = acc_args.try_into()?;
492        unsafe {
493            df_result!((self.udaf.accumulator)(&self.udaf, args))
494                .map(<Box<dyn Accumulator>>::from)
495        }
496    }
497
498    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
499        unsafe {
500            let name = RStr::from_str(args.name);
501            let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?;
502            let return_field =
503                WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
504            let ordering_fields = args
505                .ordering_fields
506                .iter()
507                .map(|f| f.as_ref())
508                .map(datafusion_proto::protobuf::Field::try_from)
509                .map(|v| v.map_err(DataFusionError::from))
510                .collect::<Result<Vec<_>>>()?
511                .into_iter()
512                .map(|proto_field| proto_field.encode_to_vec().into())
513                .collect();
514
515            let fields = df_result!((self.udaf.state_fields)(
516                &self.udaf,
517                &name,
518                input_fields,
519                return_field,
520                ordering_fields,
521                args.is_distinct
522            ))?;
523            let fields = fields
524                .into_iter()
525                .map(|field_bytes| {
526                    datafusion_proto_common::Field::decode(field_bytes.as_ref())
527                        .map_err(|e| ffi_datafusion_err!("{e}"))
528                })
529                .collect::<Result<Vec<_>>>()?;
530
531            parse_proto_fields_to_fields(fields.iter())
532                .map(|fields| fields.into_iter().map(Arc::new).collect())
533                .map_err(|e| ffi_datafusion_err!("{e}"))
534        }
535    }
536
537    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
538        let args = match FFI_AccumulatorArgs::try_from(args) {
539            Ok(v) => v,
540            Err(e) => {
541                log::warn!("Attempting to convert accumulator arguments: {e}");
542                return false;
543            }
544        };
545
546        unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) }
547    }
548
549    fn create_groups_accumulator(
550        &self,
551        args: AccumulatorArgs,
552    ) -> Result<Box<dyn GroupsAccumulator>> {
553        let args = FFI_AccumulatorArgs::try_from(args)?;
554
555        unsafe {
556            df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args))
557                .map(<Box<dyn GroupsAccumulator>>::from)
558        }
559    }
560
561    fn aliases(&self) -> &[String] {
562        &self.aliases
563    }
564
565    fn create_sliding_accumulator(
566        &self,
567        args: AccumulatorArgs,
568    ) -> Result<Box<dyn Accumulator>> {
569        let args = args.try_into()?;
570        unsafe {
571            df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args))
572                .map(<Box<dyn Accumulator>>::from)
573        }
574    }
575
576    fn with_beneficial_ordering(
577        self: Arc<Self>,
578        beneficial_ordering: bool,
579    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
580        unsafe {
581            let result = df_result!((self.udaf.with_beneficial_ordering)(
582                &self.udaf,
583                beneficial_ordering
584            ))?
585            .into_option();
586
587            let result = result.map(|func| <Arc<dyn AggregateUDFImpl>>::from(&func));
588
589            Ok(result)
590        }
591    }
592
593    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
594        unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() }
595    }
596
597    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
598        None
599    }
600
601    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
602        unsafe {
603            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
604            let result_types =
605                df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?;
606            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
607        }
608    }
609}
610
611#[repr(C)]
612#[derive(Debug, StableAbi)]
613pub enum FFI_AggregateOrderSensitivity {
614    Insensitive,
615    HardRequirement,
616    SoftRequirement,
617    Beneficial,
618}
619
620impl From<FFI_AggregateOrderSensitivity> for AggregateOrderSensitivity {
621    fn from(value: FFI_AggregateOrderSensitivity) -> Self {
622        match value {
623            FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive,
624            FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
625            FFI_AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
626            FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial,
627        }
628    }
629}
630
631impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
632    fn from(value: AggregateOrderSensitivity) -> Self {
633        match value {
634            AggregateOrderSensitivity::Insensitive => Self::Insensitive,
635            AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
636            AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
637            AggregateOrderSensitivity::Beneficial => Self::Beneficial,
638        }
639    }
640}
641
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::collections::HashMap;

    use arrow::datatypes::Schema;
    use datafusion::common::create_array;
    use datafusion::functions_aggregate::sum::Sum;
    use datafusion::physical_expr::PhysicalSortExpr;
    use datafusion::physical_plan::expressions::col;
    use datafusion::scalar::ScalarValue;

    use super::*;

    /// A [`Sum`] variant whose return field is a copy of its first input field,
    /// so any field metadata is passed through to the result.
    #[derive(Default, Debug, Hash, Eq, PartialEq)]
    struct SumWithCopiedMetadata {
        inner: Sum,
    }

    impl AggregateUDFImpl for SumWithCopiedMetadata {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.inner.name()
        }

        fn signature(&self) -> &Signature {
            self.inner.signature()
        }

        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            self.inner.return_type(arg_types)
        }

        fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
            // Copy the input field, so any metadata gets returned
            Ok(Arc::clone(&arg_fields[0]))
        }

        fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
            self.inner.accumulator(acc_args)
        }
    }

    /// Exports `original_udaf` through [`FFI_AggregateUDF`] and imports it again
    /// with a mocked foreign library marker. The result is therefore backed by
    /// a [`ForeignAggregateUDF`] rather than the local bypass.
    fn create_test_foreign_udaf(
        original_udaf: impl AggregateUDFImpl + 'static,
    ) -> Result<AggregateUDF> {
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        Ok(AggregateUDF::new_from_shared_impl(foreign_udaf))
    }

    #[test]
    fn test_round_trip_udaf() -> Result<()> {
        let original_udaf = Sum::new();
        let original_name = original_udaf.name().to_owned();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        // Convert to FFI format
        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        // Convert back to native format
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf);

        assert_eq!(original_name, foreign_udaf.name());
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_aliases() -> Result<()> {
        // Aliases added on the consumer side, after crossing the boundary.
        let foreign_udaf =
            create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]);

        let return_field =
            foreign_udaf
                .return_field(&[Field::new("a", DataType::Float64, true).into()])?;
        let return_type = return_field.data_type();
        assert_eq!(return_type, &DataType::Float64);
        assert!(
            foreign_udaf
                .aliases()
                .iter()
                .any(|alias| alias == "my_function")
        );

        // Aliases defined by the provider must survive the FFI round trip.
        let original_udaf =
            Arc::new(AggregateUDF::from(Sum::new()).with_aliases(["provider_alias"]));
        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_impl: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        assert!(
            foreign_impl
                .aliases()
                .iter()
                .any(|alias| alias == "provider_alias")
        );

        Ok(())
    }

    #[test]
    fn test_foreign_udaf_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Float64, true).into()],
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &[col("a", &schema)?],
        };
        let mut accumulator = foreign_udaf.accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    #[test]
    fn test_round_trip_udaf_metadata() -> Result<()> {
        // Use a mocked foreign marker so the return field really crosses the
        // FFI boundary instead of taking the local bypass.
        let foreign_udaf = create_test_foreign_udaf(SumWithCopiedMetadata::default())?;

        let metadata: HashMap<String, String> =
            [("a_key".to_string(), "a_value".to_string())]
                .into_iter()
                .collect();
        let input_field = Arc::new(
            Field::new("a", DataType::Float64, false).with_metadata(metadata.clone()),
        );
        let return_field = foreign_udaf.return_field(&[input_field])?;

        assert_eq!(&metadata, return_field.metadata());
        Ok(())
    }

    #[test]
    fn test_beneficial_ordering() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(
            datafusion::functions_aggregate::first_last::FirstValue::new(),
        )?;

        let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap();

        assert_eq!(
            foreign_udaf.order_sensitivity(),
            AggregateOrderSensitivity::Beneficial
        );

        let a_field = Arc::new(Field::new("a", DataType::Float64, true));
        let state_fields = foreign_udaf.state_fields(StateFieldsArgs {
            name: "a",
            input_fields: &[Field::new("f", DataType::Float64, true).into()],
            return_field: Field::new("f", DataType::Float64, true).into(),
            ordering_fields: &[Arc::clone(&a_field)],
            is_distinct: false,
        })?;

        assert_eq!(state_fields.len(), 3);
        assert_eq!(state_fields[1], a_field);
        Ok(())
    }

    #[test]
    fn test_sliding_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Float64, true).into()],
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: false,
            exprs: &[col("a", &schema)?],
        };

        let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) {
        let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into();
        let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into();

        assert_eq!(sensitivity, round_trip_sensitivity);
    }

    #[test]
    fn test_round_trip_all_order_sensitivities() {
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::SoftRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial);
    }

    #[test]
    fn test_ffi_udaf_local_bypass() -> Result<()> {
        let original_udaf = Sum::new();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let mut ffi_udaf = FFI_AggregateUDF::from(original_udaf);

        // Verify local libraries can be downcast to their original
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&ffi_udaf).into();
        assert!(foreign_udaf.as_any().downcast_ref::<Sum>().is_some());

        // Verify different library markers generate foreign providers
        ffi_udaf.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&ffi_udaf).into();
        assert!(
            foreign_udaf
                .as_any()
                .downcast_ref::<ForeignAggregateUDF>()
                .is_some()
        );

        Ok(())
    }
}