datafusion_ffi/udaf/
groups_accumulator.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::c_void;
19use std::ops::Deref;
20use std::ptr::null_mut;
21use std::sync::Arc;
22
23use abi_stable::StableAbi;
24use abi_stable::std_types::{ROption, RVec};
25use arrow::array::{Array, ArrayRef, BooleanArray};
26use arrow::error::ArrowError;
27use arrow::ffi::to_ffi;
28use datafusion_common::error::{DataFusionError, Result};
29use datafusion_expr::{EmitTo, GroupsAccumulator};
30
31use crate::arrow_wrappers::{WrappedArray, WrappedSchema};
32use crate::util::FFIResult;
33use crate::{df_result, rresult, rresult_return};
34
35/// A stable struct for sharing [`GroupsAccumulator`] across FFI boundaries.
36/// For an explanation of each field, see the corresponding function
37/// defined in [`GroupsAccumulator`].
38#[repr(C)]
39#[derive(Debug, StableAbi)]
40pub struct FFI_GroupsAccumulator {
41    pub update_batch: unsafe extern "C" fn(
42        accumulator: &mut Self,
43        values: RVec<WrappedArray>,
44        group_indices: RVec<usize>,
45        opt_filter: ROption<WrappedArray>,
46        total_num_groups: usize,
47    ) -> FFIResult<()>,
48
49    // Evaluate and return a ScalarValues as protobuf bytes
50    pub evaluate: unsafe extern "C" fn(
51        accumulator: &mut Self,
52        emit_to: FFI_EmitTo,
53    ) -> FFIResult<WrappedArray>,
54
55    pub size: unsafe extern "C" fn(accumulator: &Self) -> usize,
56
57    pub state: unsafe extern "C" fn(
58        accumulator: &mut Self,
59        emit_to: FFI_EmitTo,
60    ) -> FFIResult<RVec<WrappedArray>>,
61
62    pub merge_batch: unsafe extern "C" fn(
63        accumulator: &mut Self,
64        values: RVec<WrappedArray>,
65        group_indices: RVec<usize>,
66        opt_filter: ROption<WrappedArray>,
67        total_num_groups: usize,
68    ) -> FFIResult<()>,
69
70    pub convert_to_state: unsafe extern "C" fn(
71        accumulator: &Self,
72        values: RVec<WrappedArray>,
73        opt_filter: ROption<WrappedArray>,
74    ) -> FFIResult<RVec<WrappedArray>>,
75
76    pub supports_convert_to_state: bool,
77
78    /// Release the memory of the private data when it is no longer being used.
79    pub release: unsafe extern "C" fn(accumulator: &mut Self),
80
81    /// Internal data. This is only to be accessed by the provider of the accumulator.
82    /// A [`ForeignGroupsAccumulator`] should never attempt to access this data.
83    pub private_data: *mut c_void,
84
85    /// Utility to identify when FFI objects are accessed locally through
86    /// the foreign interface. See [`crate::get_library_marker_id`] and
87    /// the crate's `README.md` for more information.
88    pub library_marker_id: extern "C" fn() -> usize,
89}
90
/// Private data owned by an [`FFI_GroupsAccumulator`] created in this library.
///
/// It is boxed and stored type-erased in `FFI_GroupsAccumulator::private_data`.
/// It is freed by `release_fn_wrapper`, or reclaimed when the FFI struct is
/// converted back into a local `Box<dyn GroupsAccumulator>`.
pub struct GroupsAccumulatorPrivateData {
    /// The local accumulator that every FFI call is forwarded to.
    pub accumulator: Box<dyn GroupsAccumulator>,
}
94
95impl FFI_GroupsAccumulator {
96    #[inline]
97    unsafe fn inner_mut(&mut self) -> &mut Box<dyn GroupsAccumulator> {
98        unsafe {
99            let private_data = self.private_data as *mut GroupsAccumulatorPrivateData;
100            &mut (*private_data).accumulator
101        }
102    }
103
104    #[inline]
105    unsafe fn inner(&self) -> &dyn GroupsAccumulator {
106        unsafe {
107            let private_data = self.private_data as *const GroupsAccumulatorPrivateData;
108            (*private_data).accumulator.deref()
109        }
110    }
111}
112
113fn process_values(values: RVec<WrappedArray>) -> Result<Vec<Arc<dyn Array>>> {
114    values
115        .into_iter()
116        .map(|v| v.try_into().map_err(DataFusionError::from))
117        .collect::<Result<Vec<ArrayRef>>>()
118}
119
120/// Convert C-typed opt_filter into the internal type.
121fn process_opt_filter(opt_filter: ROption<WrappedArray>) -> Result<Option<BooleanArray>> {
122    opt_filter
123        .into_option()
124        .map(|filter| {
125            ArrayRef::try_from(filter)
126                .map_err(DataFusionError::from)
127                .map(|arr| BooleanArray::from(arr.into_data()))
128        })
129        .transpose()
130}
131
132unsafe extern "C" fn update_batch_fn_wrapper(
133    accumulator: &mut FFI_GroupsAccumulator,
134    values: RVec<WrappedArray>,
135    group_indices: RVec<usize>,
136    opt_filter: ROption<WrappedArray>,
137    total_num_groups: usize,
138) -> FFIResult<()> {
139    unsafe {
140        let accumulator = accumulator.inner_mut();
141        let values = rresult_return!(process_values(values));
142        let group_indices: Vec<usize> = group_indices.into_iter().collect();
143        let opt_filter = rresult_return!(process_opt_filter(opt_filter));
144
145        rresult!(accumulator.update_batch(
146            &values,
147            &group_indices,
148            opt_filter.as_ref(),
149            total_num_groups
150        ))
151    }
152}
153
154unsafe extern "C" fn evaluate_fn_wrapper(
155    accumulator: &mut FFI_GroupsAccumulator,
156    emit_to: FFI_EmitTo,
157) -> FFIResult<WrappedArray> {
158    unsafe {
159        let accumulator = accumulator.inner_mut();
160
161        let result = rresult_return!(accumulator.evaluate(emit_to.into()));
162
163        rresult!(WrappedArray::try_from(&result))
164    }
165}
166
167unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize {
168    unsafe {
169        let accumulator = accumulator.inner();
170        accumulator.size()
171    }
172}
173
174unsafe extern "C" fn state_fn_wrapper(
175    accumulator: &mut FFI_GroupsAccumulator,
176    emit_to: FFI_EmitTo,
177) -> FFIResult<RVec<WrappedArray>> {
178    unsafe {
179        let accumulator = accumulator.inner_mut();
180
181        let state = rresult_return!(accumulator.state(emit_to.into()));
182        rresult!(
183            state
184                .into_iter()
185                .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from))
186                .collect::<Result<RVec<_>>>()
187        )
188    }
189}
190
191unsafe extern "C" fn merge_batch_fn_wrapper(
192    accumulator: &mut FFI_GroupsAccumulator,
193    values: RVec<WrappedArray>,
194    group_indices: RVec<usize>,
195    opt_filter: ROption<WrappedArray>,
196    total_num_groups: usize,
197) -> FFIResult<()> {
198    unsafe {
199        let accumulator = accumulator.inner_mut();
200        let values = rresult_return!(process_values(values));
201        let group_indices: Vec<usize> = group_indices.into_iter().collect();
202        let opt_filter = rresult_return!(process_opt_filter(opt_filter));
203
204        rresult!(accumulator.merge_batch(
205            &values,
206            &group_indices,
207            opt_filter.as_ref(),
208            total_num_groups
209        ))
210    }
211}
212
213unsafe extern "C" fn convert_to_state_fn_wrapper(
214    accumulator: &FFI_GroupsAccumulator,
215    values: RVec<WrappedArray>,
216    opt_filter: ROption<WrappedArray>,
217) -> FFIResult<RVec<WrappedArray>> {
218    unsafe {
219        let accumulator = accumulator.inner();
220        let values = rresult_return!(process_values(values));
221        let opt_filter = rresult_return!(process_opt_filter(opt_filter));
222        let state =
223            rresult_return!(accumulator.convert_to_state(&values, opt_filter.as_ref()));
224
225        rresult!(
226            state
227                .iter()
228                .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from))
229                .collect::<Result<RVec<_>>>()
230        )
231    }
232}
233
234unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) {
235    unsafe {
236        if !accumulator.private_data.is_null() {
237            let private_data = Box::from_raw(
238                accumulator.private_data as *mut GroupsAccumulatorPrivateData,
239            );
240            drop(private_data);
241            accumulator.private_data = null_mut();
242        }
243    }
244}
245
246impl From<Box<dyn GroupsAccumulator>> for FFI_GroupsAccumulator {
247    fn from(accumulator: Box<dyn GroupsAccumulator>) -> Self {
248        let supports_convert_to_state = accumulator.supports_convert_to_state();
249        let private_data = GroupsAccumulatorPrivateData { accumulator };
250
251        Self {
252            update_batch: update_batch_fn_wrapper,
253            evaluate: evaluate_fn_wrapper,
254            size: size_fn_wrapper,
255            state: state_fn_wrapper,
256            merge_batch: merge_batch_fn_wrapper,
257            convert_to_state: convert_to_state_fn_wrapper,
258            supports_convert_to_state,
259
260            release: release_fn_wrapper,
261            private_data: Box::into_raw(Box::new(private_data)) as *mut c_void,
262            library_marker_id: crate::get_library_marker_id,
263        }
264    }
265}
266
267impl Drop for FFI_GroupsAccumulator {
268    fn drop(&mut self) {
269        unsafe { (self.release)(self) }
270    }
271}
272
273/// This struct is used to access an UDF provided by a foreign
274/// library across a FFI boundary.
275///
276/// The ForeignGroupsAccumulator is to be used by the caller of the UDF, so it has
277/// no knowledge or access to the private data. All interaction with the UDF
278/// must occur through the functions defined in FFI_GroupsAccumulator.
279#[derive(Debug)]
280pub struct ForeignGroupsAccumulator {
281    accumulator: FFI_GroupsAccumulator,
282}
283
// SAFETY: `FFI_GroupsAccumulator` holds a raw `private_data` pointer, so `Send`
// and `Sync` are not derived automatically and are asserted manually here.
// NOTE(review): soundness relies on the foreign accumulator behind that pointer
// being safe to move between threads. For `Sync`, it also relies on the `&self`
// entry points (`size`, `convert_to_state`) tolerating concurrent calls —
// confirm against the providing library.
unsafe impl Send for ForeignGroupsAccumulator {}
unsafe impl Sync for ForeignGroupsAccumulator {}
286
287impl From<FFI_GroupsAccumulator> for Box<dyn GroupsAccumulator> {
288    fn from(mut accumulator: FFI_GroupsAccumulator) -> Self {
289        if (accumulator.library_marker_id)() == crate::get_library_marker_id() {
290            unsafe {
291                let private_data = Box::from_raw(
292                    accumulator.private_data as *mut GroupsAccumulatorPrivateData,
293                );
294                // We must set this to null to avoid a double free
295                accumulator.private_data = null_mut();
296                private_data.accumulator
297            }
298        } else {
299            Box::new(ForeignGroupsAccumulator { accumulator })
300        }
301    }
302}
303
304impl GroupsAccumulator for ForeignGroupsAccumulator {
305    fn update_batch(
306        &mut self,
307        values: &[ArrayRef],
308        group_indices: &[usize],
309        opt_filter: Option<&BooleanArray>,
310        total_num_groups: usize,
311    ) -> Result<()> {
312        unsafe {
313            let values = values
314                .iter()
315                .map(WrappedArray::try_from)
316                .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
317            let group_indices = group_indices.iter().cloned().collect();
318            let opt_filter = opt_filter
319                .map(|bool_array| to_ffi(&bool_array.to_data()))
320                .transpose()?
321                .map(|(array, schema)| WrappedArray {
322                    array,
323                    schema: WrappedSchema(schema),
324                })
325                .into();
326
327            df_result!((self.accumulator.update_batch)(
328                &mut self.accumulator,
329                values.into(),
330                group_indices,
331                opt_filter,
332                total_num_groups
333            ))
334        }
335    }
336
337    fn size(&self) -> usize {
338        unsafe { (self.accumulator.size)(&self.accumulator) }
339    }
340
341    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
342        unsafe {
343            let return_array = df_result!((self.accumulator.evaluate)(
344                &mut self.accumulator,
345                emit_to.into()
346            ))?;
347
348            return_array.try_into().map_err(DataFusionError::from)
349        }
350    }
351
352    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
353        unsafe {
354            let returned_arrays = df_result!((self.accumulator.state)(
355                &mut self.accumulator,
356                emit_to.into()
357            ))?;
358
359            returned_arrays
360                .into_iter()
361                .map(|wrapped_array| {
362                    wrapped_array.try_into().map_err(DataFusionError::from)
363                })
364                .collect::<Result<Vec<_>>>()
365        }
366    }
367
368    fn merge_batch(
369        &mut self,
370        values: &[ArrayRef],
371        group_indices: &[usize],
372        opt_filter: Option<&BooleanArray>,
373        total_num_groups: usize,
374    ) -> Result<()> {
375        unsafe {
376            let values = values
377                .iter()
378                .map(WrappedArray::try_from)
379                .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
380            let group_indices = group_indices.iter().cloned().collect();
381            let opt_filter = opt_filter
382                .map(|bool_array| to_ffi(&bool_array.to_data()))
383                .transpose()?
384                .map(|(array, schema)| WrappedArray {
385                    array,
386                    schema: WrappedSchema(schema),
387                })
388                .into();
389
390            df_result!((self.accumulator.merge_batch)(
391                &mut self.accumulator,
392                values.into(),
393                group_indices,
394                opt_filter,
395                total_num_groups
396            ))
397        }
398    }
399
400    fn convert_to_state(
401        &self,
402        values: &[ArrayRef],
403        opt_filter: Option<&BooleanArray>,
404    ) -> Result<Vec<ArrayRef>> {
405        unsafe {
406            let values = values
407                .iter()
408                .map(WrappedArray::try_from)
409                .collect::<std::result::Result<RVec<_>, ArrowError>>()?;
410
411            let opt_filter = opt_filter
412                .map(|bool_array| to_ffi(&bool_array.to_data()))
413                .transpose()?
414                .map(|(array, schema)| WrappedArray {
415                    array,
416                    schema: WrappedSchema(schema),
417                })
418                .into();
419
420            let returned_array = df_result!((self.accumulator.convert_to_state)(
421                &self.accumulator,
422                values,
423                opt_filter
424            ))?;
425
426            returned_array
427                .into_iter()
428                .map(|arr| arr.try_into().map_err(DataFusionError::from))
429                .collect()
430        }
431    }
432
433    fn supports_convert_to_state(&self) -> bool {
434        self.accumulator.supports_convert_to_state
435    }
436}
437
/// FFI-safe mirror of [`EmitTo`], passed to `evaluate` and `state` to select
/// which groups are emitted.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub enum FFI_EmitTo {
    /// Corresponds to [`EmitTo::All`].
    All,
    /// Corresponds to [`EmitTo::First`], carrying the same `usize` value.
    First(usize),
}
444
445impl From<EmitTo> for FFI_EmitTo {
446    fn from(value: EmitTo) -> Self {
447        match value {
448            EmitTo::All => Self::All,
449            EmitTo::First(v) => Self::First(v),
450        }
451    }
452}
453
454impl From<FFI_EmitTo> for EmitTo {
455    fn from(value: FFI_EmitTo) -> Self {
456        match value {
457            FFI_EmitTo::All => Self::All,
458            FFI_EmitTo::First(v) => Self::First(v),
459        }
460    }
461}
462
#[cfg(test)]
mod tests {
    use arrow::array::{Array, BooleanArray, make_array};
    use datafusion::common::create_array;
    use datafusion::error::Result;
    use datafusion::functions_aggregate::stddev::StddevGroupsAccumulator;
    use datafusion::logical_expr::{EmitTo, GroupsAccumulator};
    use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator;
    use datafusion_functions_aggregate_common::stats::StatsType;

    use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator};

    /// Round-trips a `bool_and` groups accumulator through the FFI layer. A
    /// mock library marker forces the foreign path. The test exercises update,
    /// evaluate, state, merge and convert_to_state.
    #[test]
    fn test_foreign_bool_and_accumulator() -> Result<()> {
        let boxed_accum: Box<dyn GroupsAccumulator> =
            Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true));
        let mut ffi_accum: FFI_GroupsAccumulator = boxed_accum.into();
        ffi_accum.library_marker_id = crate::mock_foreign_marker_id;
        let mut foreign_accum: Box<dyn GroupsAccumulator> = ffi_accum.into();

        // Two rows per group across three groups. The filter drops both rows of
        // group 2, so group 0 is `true && true`, group 1 is `true && false`, and
        // group 2 receives no input and evaluates to null.
        let values = create_array!(Boolean, vec![true, true, true, false, true, true]);
        let opt_filter =
            create_array!(Boolean, vec![true, true, true, true, false, false]);
        foreign_accum.update_batch(
            &[values],
            &[0, 0, 1, 1, 2, 2],
            Some(opt_filter.as_ref()),
            3,
        )?;

        let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
        let groups_bool = groups_bool.as_any().downcast_ref::<BooleanArray>().unwrap();

        assert_eq!(
            groups_bool,
            create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref()
        );

        let state = foreign_accum.state(EmitTo::All)?;
        assert_eq!(state.len(), 1);

        // To verify merging batches works, merge a single `false` state value
        // into group 0. Its `bool_and` result should then evaluate to `false`.
        let second_states =
            vec![make_array(create_array!(Boolean, vec![false]).to_data())];

        let opt_filter = create_array!(Boolean, vec![true]);
        foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?;
        let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
        assert_eq!(groups_bool.len(), 1);
        assert_eq!(
            groups_bool.as_ref(),
            make_array(create_array!(Boolean, vec![false]).to_data()).as_ref()
        );

        // Converting a single `false` row (kept by the filter) directly to state
        // should yield a state column containing `false`.
        let values = create_array!(Boolean, vec![false]);
        let opt_filter = create_array!(Boolean, vec![true]);
        let groups_bool =
            foreign_accum.convert_to_state(&[values], Some(opt_filter.as_ref()))?;

        assert_eq!(
            groups_bool[0].as_ref(),
            make_array(create_array!(Boolean, vec![false]).to_data()).as_ref()
        );

        Ok(())
    }

    /// Assert that `value` survives a conversion to `FFI_EmitTo` and back.
    fn test_emit_to_round_trip(value: EmitTo) -> Result<()> {
        let ffi_value: FFI_EmitTo = value.into();
        let round_trip_value: EmitTo = ffi_value.into();

        assert_eq!(value, round_trip_value);
        Ok(())
    }

    /// This test ensures all enum values are properly translated
    #[test]
    fn test_all_emit_to_round_trip() -> Result<()> {
        test_emit_to_round_trip(EmitTo::All)?;
        test_emit_to_round_trip(EmitTo::First(10))?;

        Ok(())
    }

    #[test]
    fn test_ffi_groups_accumulator_local_bypass_inner() -> Result<()> {
        let original_accum = StddevGroupsAccumulator::new(StatsType::Population);
        let boxed_accum: Box<dyn GroupsAccumulator> = Box::new(original_accum);
        let original_size = boxed_accum.size();

        let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into();

        // Verify local libraries can be downcast to their original
        let foreign_accum: Box<dyn GroupsAccumulator> = ffi_accum.into();
        unsafe {
            let concrete = &*(foreign_accum.as_ref() as *const dyn GroupsAccumulator
                as *const StddevGroupsAccumulator);
            assert_eq!(original_size, concrete.size());
        }

        // Verify different library markers generate foreign accumulator
        let original_accum = StddevGroupsAccumulator::new(StatsType::Population);
        let boxed_accum: Box<dyn GroupsAccumulator> = Box::new(original_accum);
        let mut ffi_accum: FFI_GroupsAccumulator = boxed_accum.into();
        ffi_accum.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_accum: Box<dyn GroupsAccumulator> = ffi_accum.into();
        unsafe {
            let concrete = &*(foreign_accum.as_ref() as *const dyn GroupsAccumulator
                as *const ForeignGroupsAccumulator);
            assert_eq!(original_size, concrete.size());
        }

        Ok(())
    }
}