Skip to main content

datafusion_ffi/udaf/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use accumulator::FFI_Accumulator;
19use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
20use arrow::datatypes::{DataType, Field};
21use arrow::ffi::FFI_ArrowSchema;
22use arrow_schema::FieldRef;
23use datafusion_common::{DataFusionError, Result, ffi_datafusion_err};
24use datafusion_expr::function::AggregateFunctionSimplification;
25use datafusion_expr::type_coercion::functions::fields_with_udf;
26use datafusion_expr::{
27    Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
28};
29use datafusion_functions_aggregate_common::accumulator::{
30    AccumulatorArgs, StateFieldsArgs,
31};
32use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
33use datafusion_proto_common::from_proto::parse_proto_fields_to_fields;
34use groups_accumulator::FFI_GroupsAccumulator;
35use prost::{DecodeError, Message};
36
37use stabby::str::Str as SStr;
38use stabby::string::String as SString;
39use stabby::vec::Vec as SVec;
40use std::ffi::c_void;
41use std::hash::{Hash, Hasher};
42use std::sync::Arc;
43
44use crate::arrow_wrappers::WrappedSchema;
45use crate::util::{
46    FFI_Option, FFI_Result, rvec_wrapped_to_vec_datatype, rvec_wrapped_to_vec_fieldref,
47    vec_datatype_to_rvec_wrapped, vec_fieldref_to_rvec_wrapped,
48};
49use crate::volatility::FFI_Volatility;
50use crate::{df_result, sresult, sresult_return};
51
52mod accumulator;
53mod accumulator_args;
54mod groups_accumulator;
55
/// A stable struct for sharing an [`AggregateUDF`] across FFI boundaries.
///
/// The provider library builds this struct from an `Arc<AggregateUDF>`. The
/// receiving library interacts with the UDAF only through the function pointers
/// stored here. The struct is `#[repr(C)]`, so its field order is part of the
/// layout shared across the boundary.
#[repr(C)]
#[derive(Debug)]
pub struct FFI_AggregateUDF {
    /// FFI equivalent to the `name` of an [`AggregateUDF`]
    pub name: SString,

    /// FFI equivalent to the `aliases` of an [`AggregateUDF`]
    pub aliases: SVec<SString>,

    /// FFI equivalent to the `volatility` of an [`AggregateUDF`]
    pub volatility: FFI_Volatility,

    /// Determines the return field of the underlying [`AggregateUDF`] based on the
    /// argument fields, which are passed as Arrow C schemas.
    pub return_field: unsafe extern "C" fn(
        udaf: &Self,
        arg_fields: SVec<WrappedSchema>,
    ) -> FFI_Result<WrappedSchema>,

    /// FFI equivalent to the `is_nullable` of an [`AggregateUDF`]
    pub is_nullable: bool,

    /// FFI equivalent to [`AggregateUDF::groups_accumulator_supported`].
    ///
    /// Only a plain `bool` crosses the boundary, so conversion failures cannot be
    /// reported. The provider-side wrapper in this module answers `false` in that case.
    pub groups_accumulator_supported:
        unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool,

    /// FFI equivalent to [`AggregateUDF::accumulator`]
    pub accumulator: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        args: FFI_AccumulatorArgs,
    ) -> FFI_Result<FFI_Accumulator>,

    /// FFI equivalent to [`AggregateUDF::create_sliding_accumulator`]
    pub create_sliding_accumulator: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        args: FFI_AccumulatorArgs,
    )
        -> FFI_Result<FFI_Accumulator>,

    /// FFI equivalent to [`AggregateUDF::state_fields`].
    ///
    /// `ordering_fields` and the returned state fields are each encoded as
    /// protobuf `Field` messages.
    pub state_fields: unsafe extern "C" fn(
        udaf: &FFI_AggregateUDF,
        name: &SStr,
        input_fields: SVec<WrappedSchema>,
        return_field: WrappedSchema,
        ordering_fields: SVec<SVec<u8>>,
        is_distinct: bool,
    ) -> FFI_Result<SVec<SVec<u8>>>,

    /// FFI equivalent to [`AggregateUDF::create_groups_accumulator`]
    pub create_groups_accumulator:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            args: FFI_AccumulatorArgs,
        ) -> FFI_Result<FFI_GroupsAccumulator>,

    /// FFI equivalent to [`AggregateUDF::with_beneficial_ordering`]. An empty
    /// option means the UDAF did not produce an updated function.
    pub with_beneficial_ordering:
        unsafe extern "C" fn(
            udaf: &FFI_AggregateUDF,
            beneficial_ordering: bool,
        ) -> FFI_Result<FFI_Option<FFI_AggregateUDF>>,

    /// FFI equivalent to [`AggregateUDF::order_sensitivity`]
    pub order_sensitivity:
        unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity,

    /// Performs type coercion. To simplify this interface, all UDAFs are treated as
    /// having user-defined signatures, which causes `coerce_types` to be called on
    /// the receiving side. This should be transparent to most users, because the
    /// internal function makes the appropriate calls on the underlying
    /// [`AggregateUDF`].
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: SVec<WrappedSchema>,
    ) -> FFI_Result<SVec<WrappedSchema>>,

    /// Creates a clone on the provider side of the udaf. This should only need
    /// to be called by the receiver of the udaf.
    pub clone: unsafe extern "C" fn(udaf: &Self) -> Self,

    /// Releases the memory of the private data when it is no longer being used.
    pub release: unsafe extern "C" fn(udaf: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the udaf.
    /// A [`ForeignAggregateUDF`] should never attempt to access this data.
    pub private_data: *mut c_void,

    /// Utility to identify when FFI objects are accessed locally through
    /// the foreign interface. See [`crate::get_library_marker_id`] and
    /// the crate's `README.md` for more information.
    pub library_marker_id: extern "C" fn() -> usize,
}
149
// SAFETY(review): the raw `private_data` pointer prevents the compiler from
// deriving `Send`/`Sync`. On the provider side it points to an
// `AggregateUDFPrivateData` that is only read through `inner()` and freed once by
// `release`. NOTE(review): this assumes every provider's private data and
// function pointers are safe to use from multiple threads — confirm against the
// crate's FFI contract.
unsafe impl Send for FFI_AggregateUDF {}
unsafe impl Sync for FFI_AggregateUDF {}
152
/// Provider-side data stored behind [`FFI_AggregateUDF::private_data`].
///
/// It is boxed and leaked with `Box::into_raw` when an [`FFI_AggregateUDF`] is
/// built from an `Arc<AggregateUDF>`, and reclaimed by the struct's `release`
/// function.
pub struct AggregateUDFPrivateData {
    /// The UDAF that every provider-side wrapper function forwards to.
    pub udaf: Arc<AggregateUDF>,
}
156
157impl FFI_AggregateUDF {
158    unsafe fn inner(&self) -> &Arc<AggregateUDF> {
159        unsafe {
160            let private_data = self.private_data as *const AggregateUDFPrivateData;
161            &(*private_data).udaf
162        }
163    }
164}
165
166unsafe extern "C" fn return_field_fn_wrapper(
167    udaf: &FFI_AggregateUDF,
168    arg_fields: SVec<WrappedSchema>,
169) -> FFI_Result<WrappedSchema> {
170    unsafe {
171        let udaf = udaf.inner();
172
173        let arg_fields = sresult_return!(rvec_wrapped_to_vec_fieldref(&arg_fields));
174
175        let return_field = udaf
176            .return_field(&arg_fields)
177            .and_then(|v| {
178                FFI_ArrowSchema::try_from(v.as_ref()).map_err(DataFusionError::from)
179            })
180            .map(WrappedSchema);
181
182        sresult!(return_field)
183    }
184}
185
186unsafe extern "C" fn accumulator_fn_wrapper(
187    udaf: &FFI_AggregateUDF,
188    args: FFI_AccumulatorArgs,
189) -> FFI_Result<FFI_Accumulator> {
190    unsafe {
191        let udaf = udaf.inner();
192
193        let accumulator_args = &sresult_return!(ForeignAccumulatorArgs::try_from(args));
194
195        sresult!(
196            udaf.accumulator(accumulator_args.into())
197                .map(FFI_Accumulator::from)
198        )
199    }
200}
201
202unsafe extern "C" fn create_sliding_accumulator_fn_wrapper(
203    udaf: &FFI_AggregateUDF,
204    args: FFI_AccumulatorArgs,
205) -> FFI_Result<FFI_Accumulator> {
206    unsafe {
207        let udaf = udaf.inner();
208
209        let accumulator_args = &sresult_return!(ForeignAccumulatorArgs::try_from(args));
210
211        sresult!(
212            udaf.create_sliding_accumulator(accumulator_args.into())
213                .map(FFI_Accumulator::from)
214        )
215    }
216}
217
218unsafe extern "C" fn create_groups_accumulator_fn_wrapper(
219    udaf: &FFI_AggregateUDF,
220    args: FFI_AccumulatorArgs,
221) -> FFI_Result<FFI_GroupsAccumulator> {
222    unsafe {
223        let udaf = udaf.inner();
224
225        let accumulator_args = &sresult_return!(ForeignAccumulatorArgs::try_from(args));
226
227        sresult!(
228            udaf.create_groups_accumulator(accumulator_args.into())
229                .map(FFI_GroupsAccumulator::from)
230        )
231    }
232}
233
234unsafe extern "C" fn groups_accumulator_supported_fn_wrapper(
235    udaf: &FFI_AggregateUDF,
236    args: FFI_AccumulatorArgs,
237) -> bool {
238    unsafe {
239        let udaf = udaf.inner();
240
241        ForeignAccumulatorArgs::try_from(args)
242            .map(|a| udaf.groups_accumulator_supported((&a).into()))
243            .unwrap_or_else(|e| {
244                log::warn!("Unable to parse accumulator args. {e}");
245                false
246            })
247    }
248}
249
250unsafe extern "C" fn with_beneficial_ordering_fn_wrapper(
251    udaf: &FFI_AggregateUDF,
252    beneficial_ordering: bool,
253) -> FFI_Result<FFI_Option<FFI_AggregateUDF>> {
254    unsafe {
255        let udaf = udaf.inner().as_ref().clone();
256
257        let result = sresult_return!(udaf.with_beneficial_ordering(beneficial_ordering));
258        let result = sresult_return!(
259            result
260                .map(|func| func.with_beneficial_ordering(beneficial_ordering))
261                .transpose()
262        )
263        .flatten()
264        .map(|func| FFI_AggregateUDF::from(Arc::new(func)));
265
266        FFI_Result::Ok(result.into())
267    }
268}
269
270unsafe extern "C" fn state_fields_fn_wrapper(
271    udaf: &FFI_AggregateUDF,
272    name: &SStr,
273    input_fields: SVec<WrappedSchema>,
274    return_field: WrappedSchema,
275    ordering_fields: SVec<SVec<u8>>,
276    is_distinct: bool,
277) -> FFI_Result<SVec<SVec<u8>>> {
278    unsafe {
279        let udaf = udaf.inner();
280
281        let input_fields = &sresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
282        let return_field = sresult_return!(Field::try_from(&return_field.0)).into();
283
284        let ordering_fields = &sresult_return!(
285            ordering_fields
286                .into_iter()
287                .map(|field_bytes| datafusion_proto_common::Field::decode(
288                    field_bytes.as_ref()
289                ))
290                .collect::<std::result::Result<Vec<_>, DecodeError>>()
291        );
292
293        let ordering_fields =
294            &sresult_return!(parse_proto_fields_to_fields(ordering_fields))
295                .into_iter()
296                .map(Arc::new)
297                .collect::<Vec<_>>();
298
299        let args = StateFieldsArgs {
300            name: name.as_str(),
301            input_fields,
302            return_field,
303            ordering_fields,
304            is_distinct,
305        };
306
307        let state_fields = sresult_return!(udaf.state_fields(args));
308        let state_fields = sresult_return!(
309            state_fields
310                .iter()
311                .map(|f| f.as_ref())
312                .map(datafusion_proto::protobuf::Field::try_from)
313                .map(|v| v.map_err(DataFusionError::from))
314                .collect::<Result<Vec<_>>>()
315        )
316        .into_iter()
317        .map(|field| field.encode_to_vec().into_iter().collect())
318        .collect();
319
320        FFI_Result::Ok(state_fields)
321    }
322}
323
324unsafe extern "C" fn order_sensitivity_fn_wrapper(
325    udaf: &FFI_AggregateUDF,
326) -> FFI_AggregateOrderSensitivity {
327    unsafe { udaf.inner().order_sensitivity().into() }
328}
329
330unsafe extern "C" fn coerce_types_fn_wrapper(
331    udaf: &FFI_AggregateUDF,
332    arg_types: SVec<WrappedSchema>,
333) -> FFI_Result<SVec<WrappedSchema>> {
334    unsafe {
335        let udaf = udaf.inner();
336
337        let arg_types = sresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
338
339        let arg_fields = arg_types
340            .iter()
341            .map(|dt| Field::new("f", dt.clone(), true))
342            .map(Arc::new)
343            .collect::<Vec<_>>();
344        let return_types = sresult_return!(fields_with_udf(&arg_fields, udaf.as_ref()))
345            .into_iter()
346            .map(|f| f.data_type().to_owned())
347            .collect::<Vec<_>>();
348
349        sresult!(vec_datatype_to_rvec_wrapped(&return_types))
350    }
351}
352
353unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) {
354    unsafe {
355        debug_assert!(!udaf.private_data.is_null());
356        let private_data =
357            Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData);
358        drop(private_data);
359        udaf.private_data = std::ptr::null_mut();
360    }
361}
362
363unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF {
364    unsafe { Arc::clone(udaf.inner()).into() }
365}
366
367impl Clone for FFI_AggregateUDF {
368    fn clone(&self) -> Self {
369        unsafe { (self.clone)(self) }
370    }
371}
372
373impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
374    fn from(udaf: Arc<AggregateUDF>) -> Self {
375        if let Some(udaf) = udaf.inner().downcast_ref::<ForeignAggregateUDF>() {
376            return udaf.udaf.clone();
377        }
378
379        let name = udaf.name().into();
380        let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect();
381        let is_nullable = udaf.is_nullable();
382        let volatility = udaf.signature().volatility.into();
383
384        let private_data = Box::new(AggregateUDFPrivateData { udaf });
385
386        Self {
387            name,
388            is_nullable,
389            volatility,
390            aliases,
391            return_field: return_field_fn_wrapper,
392            accumulator: accumulator_fn_wrapper,
393            create_sliding_accumulator: create_sliding_accumulator_fn_wrapper,
394            create_groups_accumulator: create_groups_accumulator_fn_wrapper,
395            groups_accumulator_supported: groups_accumulator_supported_fn_wrapper,
396            with_beneficial_ordering: with_beneficial_ordering_fn_wrapper,
397            state_fields: state_fields_fn_wrapper,
398            order_sensitivity: order_sensitivity_fn_wrapper,
399            coerce_types: coerce_types_fn_wrapper,
400            clone: clone_fn_wrapper,
401            release: release_fn_wrapper,
402            private_data: Box::into_raw(private_data) as *mut c_void,
403            library_marker_id: crate::get_library_marker_id,
404        }
405    }
406}
407
408impl Drop for FFI_AggregateUDF {
409    fn drop(&mut self) {
410        unsafe { (self.release)(self) }
411    }
412}
413
/// This struct is used to access a UDAF provided by a foreign
/// library across an FFI boundary.
///
/// The ForeignAggregateUDF is to be used by the caller of the UDAF, so it has
/// no knowledge of or access to the private data. All interaction with the UDAF
/// must occur through the functions defined in FFI_AggregateUDF.
#[derive(Debug)]
pub struct ForeignAggregateUDF {
    /// Always a user-defined signature built from the FFI volatility, so type
    /// coercion is routed through `coerce_types`.
    signature: Signature,
    /// Aliases copied out of the FFI struct when this wrapper was created.
    aliases: Vec<String>,
    /// This wrapper's own clone of the provider's FFI struct.
    udaf: FFI_AggregateUDF,
}
426
// SAFETY(review): the only non-auto-Send/Sync member is the `FFI_AggregateUDF`,
// which this module already declares `Send + Sync`. NOTE(review): soundness
// therefore rests on the same assumption about the provider's thread safety.
unsafe impl Send for ForeignAggregateUDF {}
unsafe impl Sync for ForeignAggregateUDF {}
429
430impl PartialEq for ForeignAggregateUDF {
431    fn eq(&self, other: &Self) -> bool {
432        // FFI_AggregateUDF cannot be compared, so identity equality is the best we can do.
433        std::ptr::eq(self, other)
434    }
435}
436impl Eq for ForeignAggregateUDF {}
437impl Hash for ForeignAggregateUDF {
438    fn hash<H: Hasher>(&self, state: &mut H) {
439        std::ptr::hash(self, state)
440    }
441}
442
443impl From<&FFI_AggregateUDF> for Arc<dyn AggregateUDFImpl> {
444    fn from(udaf: &FFI_AggregateUDF) -> Self {
445        if (udaf.library_marker_id)() == crate::get_library_marker_id() {
446            return Arc::clone(unsafe { udaf.inner().inner() });
447        }
448
449        let signature = Signature::user_defined((&udaf.volatility).into());
450        let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect();
451
452        Arc::new(ForeignAggregateUDF {
453            udaf: udaf.clone(),
454            signature,
455            aliases,
456        })
457    }
458}
459
460impl AggregateUDFImpl for ForeignAggregateUDF {
461    fn name(&self) -> &str {
462        self.udaf.name.as_str()
463    }
464
465    fn signature(&self) -> &Signature {
466        &self.signature
467    }
468
469    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
470        unimplemented!()
471    }
472
473    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
474        let arg_fields = vec_fieldref_to_rvec_wrapped(arg_fields)?;
475
476        let result = unsafe { (self.udaf.return_field)(&self.udaf, arg_fields) };
477
478        let result = df_result!(result);
479
480        result.and_then(|r| {
481            Field::try_from(&r.0)
482                .map(Arc::new)
483                .map_err(DataFusionError::from)
484        })
485    }
486
487    fn is_nullable(&self) -> bool {
488        self.udaf.is_nullable
489    }
490
491    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
492        let args = acc_args.try_into()?;
493        unsafe {
494            df_result!((self.udaf.accumulator)(&self.udaf, args))
495                .map(<Box<dyn Accumulator>>::from)
496        }
497    }
498
499    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
500        unsafe {
501            let name = SStr::from(args.name);
502            let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?;
503            let return_field =
504                WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
505            let ordering_fields = args
506                .ordering_fields
507                .iter()
508                .map(|f| f.as_ref())
509                .map(datafusion_proto::protobuf::Field::try_from)
510                .map(|v| v.map_err(DataFusionError::from))
511                .collect::<Result<Vec<_>>>()?
512                .into_iter()
513                .map(|proto_field| proto_field.encode_to_vec().into_iter().collect())
514                .collect();
515
516            let fields = df_result!((self.udaf.state_fields)(
517                &self.udaf,
518                &name,
519                input_fields,
520                return_field,
521                ordering_fields,
522                args.is_distinct
523            ))?;
524            let fields = fields
525                .into_iter()
526                .map(|field_bytes| {
527                    datafusion_proto_common::Field::decode(field_bytes.as_ref())
528                        .map_err(|e| ffi_datafusion_err!("{e}"))
529                })
530                .collect::<Result<Vec<_>>>()?;
531
532            parse_proto_fields_to_fields(fields.iter())
533                .map(|fields| fields.into_iter().map(Arc::new).collect())
534                .map_err(|e| ffi_datafusion_err!("{e}"))
535        }
536    }
537
538    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
539        let args = match FFI_AccumulatorArgs::try_from(args) {
540            Ok(v) => v,
541            Err(e) => {
542                log::warn!("Attempting to convert accumulator arguments: {e}");
543                return false;
544            }
545        };
546
547        unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) }
548    }
549
550    fn create_groups_accumulator(
551        &self,
552        args: AccumulatorArgs,
553    ) -> Result<Box<dyn GroupsAccumulator>> {
554        let args = FFI_AccumulatorArgs::try_from(args)?;
555
556        unsafe {
557            df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args))
558                .map(<Box<dyn GroupsAccumulator>>::from)
559        }
560    }
561
562    fn aliases(&self) -> &[String] {
563        &self.aliases
564    }
565
566    fn create_sliding_accumulator(
567        &self,
568        args: AccumulatorArgs,
569    ) -> Result<Box<dyn Accumulator>> {
570        let args = args.try_into()?;
571        unsafe {
572            df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args))
573                .map(<Box<dyn Accumulator>>::from)
574        }
575    }
576
577    fn with_beneficial_ordering(
578        self: Arc<Self>,
579        beneficial_ordering: bool,
580    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
581        unsafe {
582            let result = df_result!((self.udaf.with_beneficial_ordering)(
583                &self.udaf,
584                beneficial_ordering
585            ))?
586            .into_option();
587
588            let result = result.map(|func| <Arc<dyn AggregateUDFImpl>>::from(&func));
589
590            Ok(result)
591        }
592    }
593
594    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
595        unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() }
596    }
597
598    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
599        None
600    }
601
602    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
603        unsafe {
604            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
605            let result_types =
606                df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?;
607            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
608        }
609    }
610}
611
/// FFI-stable mirror of [`AggregateOrderSensitivity`].
///
/// It is `#[repr(C)]` so the discriminant layout is fixed across the boundary.
/// Each variant maps one-to-one onto the variant of the same name.
#[repr(C)]
#[derive(Debug)]
pub enum FFI_AggregateOrderSensitivity {
    /// Mirrors [`AggregateOrderSensitivity::Insensitive`].
    Insensitive,
    /// Mirrors [`AggregateOrderSensitivity::HardRequirement`].
    HardRequirement,
    /// Mirrors [`AggregateOrderSensitivity::SoftRequirement`].
    SoftRequirement,
    /// Mirrors [`AggregateOrderSensitivity::Beneficial`].
    Beneficial,
}
620
621impl From<FFI_AggregateOrderSensitivity> for AggregateOrderSensitivity {
622    fn from(value: FFI_AggregateOrderSensitivity) -> Self {
623        match value {
624            FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive,
625            FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
626            FFI_AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
627            FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial,
628        }
629    }
630}
631
632impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
633    fn from(value: AggregateOrderSensitivity) -> Self {
634        match value {
635            AggregateOrderSensitivity::Insensitive => Self::Insensitive,
636            AggregateOrderSensitivity::HardRequirement => Self::HardRequirement,
637            AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement,
638            AggregateOrderSensitivity::Beneficial => Self::Beneficial,
639        }
640    }
641}
642
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use arrow::datatypes::Schema;
    use datafusion::common::create_array;
    use datafusion::functions_aggregate::sum::Sum;
    use datafusion::physical_expr::PhysicalSortExpr;
    use datafusion::physical_plan::expressions::col;
    use datafusion::scalar::ScalarValue;

    use super::*;

    /// A `Sum` wrapper whose return field is a copy of its first argument field.
    /// It is used to check that field metadata survives the FFI round trip.
    #[derive(Default, Debug, Hash, Eq, PartialEq)]
    struct SumWithCopiedMetadata {
        inner: Sum,
    }

    impl AggregateUDFImpl for SumWithCopiedMetadata {
        fn name(&self) -> &str {
            self.inner.name()
        }

        fn signature(&self) -> &Signature {
            self.inner.signature()
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
            // Copy the input field, so any metadata gets returned
            Ok(Arc::clone(&arg_fields[0]))
        }

        fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
            self.inner.accumulator(acc_args)
        }
    }

    /// Exports `original_udaf` over FFI and imports it again.
    ///
    /// The library marker is replaced with a mock, so the local-library bypass is
    /// not taken and every call goes through a `ForeignAggregateUDF`.
    fn create_test_foreign_udaf(
        original_udaf: impl AggregateUDFImpl + 'static,
    ) -> Result<AggregateUDF> {
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        Ok(AggregateUDF::new_from_shared_impl(foreign_udaf))
    }

    #[test]
    fn test_round_trip_udaf() -> Result<()> {
        let original_udaf = Sum::new();
        let original_name = original_udaf.name().to_owned();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        // Convert to FFI format
        let mut local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
        local_udaf.library_marker_id = crate::mock_foreign_marker_id;

        // Convert back to native format
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&local_udaf).into();
        let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf);

        assert_eq!(original_name, foreign_udaf.name());
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_aliases() -> Result<()> {
        let foreign_udaf =
            create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]);

        // The test is about aliases, so actually check that the new alias is visible.
        assert!(
            foreign_udaf
                .aliases()
                .iter()
                .any(|alias| alias == "my_function")
        );

        let return_field =
            foreign_udaf
                .return_field(&[Field::new("a", DataType::Float64, true).into()])?;
        let return_type = return_field.data_type();
        assert_eq!(return_type, &DataType::Float64);
        Ok(())
    }

    #[test]
    fn test_foreign_udaf_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Float64, true).into()],
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &[col("a", &schema)?],
        };
        let mut accumulator = foreign_udaf.accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    #[test]
    fn test_round_trip_udaf_metadata() -> Result<()> {
        // Use the mock foreign marker so `return_field` really crosses the FFI
        // boundary instead of taking the local-library bypass.
        let foreign_udaf = create_test_foreign_udaf(SumWithCopiedMetadata::default())?;

        let metadata: HashMap<String, String> =
            [("a_key".to_string(), "a_value".to_string())]
                .into_iter()
                .collect();
        let input_field = Arc::new(
            Field::new("a", DataType::Float64, false).with_metadata(metadata.clone()),
        );
        let return_field = foreign_udaf.return_field(&[input_field])?;

        assert_eq!(&metadata, return_field.metadata());
        Ok(())
    }

    #[test]
    fn test_beneficial_ordering() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(
            datafusion::functions_aggregate::first_last::FirstValue::new(),
        )?;

        let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap();

        assert_eq!(
            foreign_udaf.order_sensitivity(),
            AggregateOrderSensitivity::Beneficial
        );

        let a_field = Arc::new(Field::new("a", DataType::Float64, true));
        let state_fields = foreign_udaf.state_fields(StateFieldsArgs {
            name: "a",
            input_fields: &[Field::new("f", DataType::Float64, true).into()],
            return_field: Field::new("f", DataType::Float64, true).into(),
            ordering_fields: &[Arc::clone(&a_field)],
            is_distinct: false,
        })?;

        assert_eq!(state_fields.len(), 3);
        assert_eq!(state_fields[1], a_field);
        Ok(())
    }

    #[test]
    fn test_sliding_accumulator() -> Result<()> {
        let foreign_udaf = create_test_foreign_udaf(Sum::new())?;

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        // Note: sum distinct only supports Int64 at the moment
        let acc_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Float64, true).into()],
            ignore_nulls: true,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: false,
            exprs: &[col("a", &schema)?],
        };

        let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?;
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        accumulator.update_batch(&[values])?;
        let resultant_value = accumulator.evaluate()?;
        assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));

        Ok(())
    }

    fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) {
        let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into();
        let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into();

        assert_eq!(sensitivity, round_trip_sensitivity);
    }

    #[test]
    fn test_round_trip_all_order_sensitivities() {
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::SoftRequirement);
        test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial);
    }

    #[test]
    fn test_ffi_udaf_local_bypass() -> Result<()> {
        let original_udaf = Sum::new();
        let original_udaf = Arc::new(AggregateUDF::from(original_udaf));

        let mut ffi_udaf = FFI_AggregateUDF::from(original_udaf);

        // Verify local libraries can be downcast to their original
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&ffi_udaf).into();
        assert!(foreign_udaf.is::<Sum>());

        // Verify different library markers generate foreign providers
        ffi_udaf.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_udaf: Arc<dyn AggregateUDFImpl> = (&ffi_udaf).into();
        assert!(foreign_udaf.is::<ForeignAggregateUDF>());

        Ok(())
    }
}