datafusion_ffi/udaf/groups_accumulator.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::{ffi::c_void, ops::Deref, sync::Arc};
19
20use crate::{
21    arrow_wrappers::{WrappedArray, WrappedSchema},
22    df_result, rresult, rresult_return,
23};
24use abi_stable::{
25    std_types::{ROption, RResult, RString, RVec},
26    StableAbi,
27};
28use arrow::{
29    array::{Array, ArrayRef, BooleanArray},
30    error::ArrowError,
31    ffi::to_ffi,
32};
33use datafusion::{
34    error::{DataFusionError, Result},
35    logical_expr::{EmitTo, GroupsAccumulator},
36};
37
38/// A stable struct for sharing [`GroupsAccumulator`] across FFI boundaries.
39/// For an explanation of each field, see the corresponding function
40/// defined in [`GroupsAccumulator`].
41#[repr(C)]
42#[derive(Debug, StableAbi)]
43#[allow(non_camel_case_types)]
44pub struct FFI_GroupsAccumulator {
45    pub update_batch: unsafe extern "C" fn(
46        accumulator: &mut Self,
47        values: RVec<WrappedArray>,
48        group_indices: RVec<usize>,
49        opt_filter: ROption<WrappedArray>,
50        total_num_groups: usize,
51    ) -> RResult<(), RString>,
52
53    // Evaluate and return a ScalarValues as protobuf bytes
54    pub evaluate: unsafe extern "C" fn(
55        accumulator: &mut Self,
56        emit_to: FFI_EmitTo,
57    ) -> RResult<WrappedArray, RString>,
58
59    pub size: unsafe extern "C" fn(accumulator: &Self) -> usize,
60
61    pub state: unsafe extern "C" fn(
62        accumulator: &mut Self,
63        emit_to: FFI_EmitTo,
64    ) -> RResult<RVec<WrappedArray>, RString>,
65
66    pub merge_batch: unsafe extern "C" fn(
67        accumulator: &mut Self,
68        values: RVec<WrappedArray>,
69        group_indices: RVec<usize>,
70        opt_filter: ROption<WrappedArray>,
71        total_num_groups: usize,
72    ) -> RResult<(), RString>,
73
74    pub convert_to_state: unsafe extern "C" fn(
75        accumulator: &Self,
76        values: RVec<WrappedArray>,
77        opt_filter: ROption<WrappedArray>,
78    )
79        -> RResult<RVec<WrappedArray>, RString>,
80
81    pub supports_convert_to_state: bool,
82
83    /// Release the memory of the private data when it is no longer being used.
84    pub release: unsafe extern "C" fn(accumulator: &mut Self),
85
86    /// Internal data. This is only to be accessed by the provider of the accumulator.
87    /// A [`ForeignGroupsAccumulator`] should never attempt to access this data.
88    pub private_data: *mut c_void,
89}
90
91unsafe impl Send for FFI_GroupsAccumulator {}
92unsafe impl Sync for FFI_GroupsAccumulator {}
93
94pub struct GroupsAccumulatorPrivateData {
95    pub accumulator: Box<dyn GroupsAccumulator>,
96}
97
98impl FFI_GroupsAccumulator {
99    #[inline]
100    unsafe fn inner_mut(&mut self) -> &mut Box<dyn GroupsAccumulator> {
101        let private_data = self.private_data as *mut GroupsAccumulatorPrivateData;
102        &mut (*private_data).accumulator
103    }
104
105    #[inline]
106    unsafe fn inner(&self) -> &dyn GroupsAccumulator {
107        let private_data = self.private_data as *const GroupsAccumulatorPrivateData;
108        (*private_data).accumulator.deref()
109    }
110}
111
112fn process_values(values: RVec<WrappedArray>) -> Result<Vec<Arc<dyn Array>>> {
113    values
114        .into_iter()
115        .map(|v| v.try_into().map_err(DataFusionError::from))
116        .collect::<Result<Vec<ArrayRef>>>()
117}
118
119/// Convert C-typed opt_filter into the internal type.
120fn process_opt_filter(opt_filter: ROption<WrappedArray>) -> Result<Option<BooleanArray>> {
121    opt_filter
122        .into_option()
123        .map(|filter| {
124            ArrayRef::try_from(filter)
125                .map_err(DataFusionError::from)
126                .map(|arr| BooleanArray::from(arr.into_data()))
127        })
128        .transpose()
129}
130
131unsafe extern "C" fn update_batch_fn_wrapper(
132    accumulator: &mut FFI_GroupsAccumulator,
133    values: RVec<WrappedArray>,
134    group_indices: RVec<usize>,
135    opt_filter: ROption<WrappedArray>,
136    total_num_groups: usize,
137) -> RResult<(), RString> {
138    let accumulator = accumulator.inner_mut();
139    let values = rresult_return!(process_values(values));
140    let group_indices: Vec<usize> = group_indices.into_iter().collect();
141    let opt_filter = rresult_return!(process_opt_filter(opt_filter));
142
143    rresult!(accumulator.update_batch(
144        &values,
145        &group_indices,
146        opt_filter.as_ref(),
147        total_num_groups
148    ))
149}
150
151unsafe extern "C" fn evaluate_fn_wrapper(
152    accumulator: &mut FFI_GroupsAccumulator,
153    emit_to: FFI_EmitTo,
154) -> RResult<WrappedArray, RString> {
155    let accumulator = accumulator.inner_mut();
156
157    let result = rresult_return!(accumulator.evaluate(emit_to.into()));
158
159    rresult!(WrappedArray::try_from(&result))
160}
161
162unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize {
163    let accumulator = accumulator.inner();
164    accumulator.size()
165}
166
167unsafe extern "C" fn state_fn_wrapper(
168    accumulator: &mut FFI_GroupsAccumulator,
169    emit_to: FFI_EmitTo,
170) -> RResult<RVec<WrappedArray>, RString> {
171    let accumulator = accumulator.inner_mut();
172
173    let state = rresult_return!(accumulator.state(emit_to.into()));
174    rresult!(state
175        .into_iter()
176        .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from))
177        .collect::<Result<RVec<_>>>())
178}
179
180unsafe extern "C" fn merge_batch_fn_wrapper(
181    accumulator: &mut FFI_GroupsAccumulator,
182    values: RVec<WrappedArray>,
183    group_indices: RVec<usize>,
184    opt_filter: ROption<WrappedArray>,
185    total_num_groups: usize,
186) -> RResult<(), RString> {
187    let accumulator = accumulator.inner_mut();
188    let values = rresult_return!(process_values(values));
189    let group_indices: Vec<usize> = group_indices.into_iter().collect();
190    let opt_filter = rresult_return!(process_opt_filter(opt_filter));
191
192    rresult!(accumulator.merge_batch(
193        &values,
194        &group_indices,
195        opt_filter.as_ref(),
196        total_num_groups
197    ))
198}
199
200unsafe extern "C" fn convert_to_state_fn_wrapper(
201    accumulator: &FFI_GroupsAccumulator,
202    values: RVec<WrappedArray>,
203    opt_filter: ROption<WrappedArray>,
204) -> RResult<RVec<WrappedArray>, RString> {
205    let accumulator = accumulator.inner();
206    let values = rresult_return!(process_values(values));
207    let opt_filter = rresult_return!(process_opt_filter(opt_filter));
208    let state =
209        rresult_return!(accumulator.convert_to_state(&values, opt_filter.as_ref()));
210
211    rresult!(state
212        .iter()
213        .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from))
214        .collect::<Result<RVec<_>>>())
215}
216
217unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) {
218    let private_data =
219        Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData);
220    drop(private_data);
221}
222
223impl From<Box<dyn GroupsAccumulator>> for FFI_GroupsAccumulator {
224    fn from(accumulator: Box<dyn GroupsAccumulator>) -> Self {
225        let supports_convert_to_state = accumulator.supports_convert_to_state();
226        let private_data = GroupsAccumulatorPrivateData { accumulator };
227
228        Self {
229            update_batch: update_batch_fn_wrapper,
230            evaluate: evaluate_fn_wrapper,
231            size: size_fn_wrapper,
232            state: state_fn_wrapper,
233            merge_batch: merge_batch_fn_wrapper,
234            convert_to_state: convert_to_state_fn_wrapper,
235            supports_convert_to_state,
236
237            release: release_fn_wrapper,
238            private_data: Box::into_raw(Box::new(private_data)) as *mut c_void,
239        }
240    }
241}
242
243impl Drop for FFI_GroupsAccumulator {
244    fn drop(&mut self) {
245        unsafe { (self.release)(self) }
246    }
247}
248
249/// This struct is used to access an UDF provided by a foreign
250/// library across a FFI boundary.
251///
252/// The ForeignGroupsAccumulator is to be used by the caller of the UDF, so it has
253/// no knowledge or access to the private data. All interaction with the UDF
254/// must occur through the functions defined in FFI_GroupsAccumulator.
255#[derive(Debug)]
256pub struct ForeignGroupsAccumulator {
257    accumulator: FFI_GroupsAccumulator,
258}
259
260unsafe impl Send for ForeignGroupsAccumulator {}
261unsafe impl Sync for ForeignGroupsAccumulator {}
262
263impl From<FFI_GroupsAccumulator> for ForeignGroupsAccumulator {
264    fn from(accumulator: FFI_GroupsAccumulator) -> Self {
265        Self { accumulator }
266    }
267}
268
269impl GroupsAccumulator for ForeignGroupsAccumulator {
270    fn update_batch(
271        &mut self,
272        values: &[ArrayRef],
273        group_indices: &[usize],
274        opt_filter: Option<&BooleanArray>,
275        total_num_groups: usize,
276    ) -> Result<()> {
277        unsafe {
278            let values = values
279                .iter()
280                .map(WrappedArray::try_from)
281                .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
282            let group_indices = group_indices.iter().cloned().collect();
283            let opt_filter = opt_filter
284                .map(|bool_array| to_ffi(&bool_array.to_data()))
285                .transpose()?
286                .map(|(array, schema)| WrappedArray {
287                    array,
288                    schema: WrappedSchema(schema),
289                })
290                .into();
291
292            df_result!((self.accumulator.update_batch)(
293                &mut self.accumulator,
294                values.into(),
295                group_indices,
296                opt_filter,
297                total_num_groups
298            ))
299        }
300    }
301
302    fn size(&self) -> usize {
303        unsafe { (self.accumulator.size)(&self.accumulator) }
304    }
305
306    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
307        unsafe {
308            let return_array = df_result!((self.accumulator.evaluate)(
309                &mut self.accumulator,
310                emit_to.into()
311            ))?;
312
313            return_array.try_into().map_err(DataFusionError::from)
314        }
315    }
316
317    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
318        unsafe {
319            let returned_arrays = df_result!((self.accumulator.state)(
320                &mut self.accumulator,
321                emit_to.into()
322            ))?;
323
324            returned_arrays
325                .into_iter()
326                .map(|wrapped_array| {
327                    wrapped_array.try_into().map_err(DataFusionError::from)
328                })
329                .collect::<Result<Vec<_>>>()
330        }
331    }
332
333    fn merge_batch(
334        &mut self,
335        values: &[ArrayRef],
336        group_indices: &[usize],
337        opt_filter: Option<&BooleanArray>,
338        total_num_groups: usize,
339    ) -> Result<()> {
340        unsafe {
341            let values = values
342                .iter()
343                .map(WrappedArray::try_from)
344                .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
345            let group_indices = group_indices.iter().cloned().collect();
346            let opt_filter = opt_filter
347                .map(|bool_array| to_ffi(&bool_array.to_data()))
348                .transpose()?
349                .map(|(array, schema)| WrappedArray {
350                    array,
351                    schema: WrappedSchema(schema),
352                })
353                .into();
354
355            df_result!((self.accumulator.merge_batch)(
356                &mut self.accumulator,
357                values.into(),
358                group_indices,
359                opt_filter,
360                total_num_groups
361            ))
362        }
363    }
364
365    fn convert_to_state(
366        &self,
367        values: &[ArrayRef],
368        opt_filter: Option<&BooleanArray>,
369    ) -> Result<Vec<ArrayRef>> {
370        unsafe {
371            let values = values
372                .iter()
373                .map(WrappedArray::try_from)
374                .collect::<std::result::Result<RVec<_>, ArrowError>>()?;
375
376            let opt_filter = opt_filter
377                .map(|bool_array| to_ffi(&bool_array.to_data()))
378                .transpose()?
379                .map(|(array, schema)| WrappedArray {
380                    array,
381                    schema: WrappedSchema(schema),
382                })
383                .into();
384
385            let returned_array = df_result!((self.accumulator.convert_to_state)(
386                &self.accumulator,
387                values,
388                opt_filter
389            ))?;
390
391            returned_array
392                .into_iter()
393                .map(|arr| arr.try_into().map_err(DataFusionError::from))
394                .collect()
395        }
396    }
397
398    fn supports_convert_to_state(&self) -> bool {
399        self.accumulator.supports_convert_to_state
400    }
401}
402
403#[repr(C)]
404#[derive(Debug, StableAbi)]
405#[allow(non_camel_case_types)]
406pub enum FFI_EmitTo {
407    All,
408    First(usize),
409}
410
411impl From<EmitTo> for FFI_EmitTo {
412    fn from(value: EmitTo) -> Self {
413        match value {
414            EmitTo::All => Self::All,
415            EmitTo::First(v) => Self::First(v),
416        }
417    }
418}
419
420impl From<FFI_EmitTo> for EmitTo {
421    fn from(value: FFI_EmitTo) -> Self {
422        match value {
423            FFI_EmitTo::All => Self::All,
424            FFI_EmitTo::First(v) => Self::First(v),
425        }
426    }
427}
428
#[cfg(test)]
mod tests {
    use arrow::array::{make_array, Array, BooleanArray};
    use datafusion::{
        common::create_array,
        error::Result,
        logical_expr::{EmitTo, GroupsAccumulator},
    };
    use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator;

    use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator};

    /// Round-trips a logical-AND boolean groups accumulator (identity `true`)
    /// through the FFI wrappers and exercises every trait method.
    #[test]
    fn test_foreign_bool_and_accumulator() -> Result<()> {
        let boxed_accum: Box<dyn GroupsAccumulator> =
            Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true));
        let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into();
        let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into();

        // Three groups with two rows each. The filter keeps both rows of
        // groups 0 and 1 and drops both rows of group 2.
        let values = create_array!(Boolean, vec![true, true, true, false, true, true]);
        let opt_filter =
            create_array!(Boolean, vec![true, true, true, true, false, false]);
        foreign_accum.update_batch(
            &[values],
            &[0, 0, 1, 1, 2, 2],
            Some(opt_filter.as_ref()),
            3,
        )?;

        // Group 0 is `true AND true`, group 1 is `true AND false`, and group 2
        // had every row filtered out, so it is null.
        let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
        let groups_bool = groups_bool.as_any().downcast_ref::<BooleanArray>().unwrap();

        assert_eq!(
            groups_bool,
            create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref()
        );

        let state = foreign_accum.state(EmitTo::All)?;
        assert_eq!(state.len(), 1);

        // Merge a single `false` state into group 0; AND-ing it in must make
        // the group evaluate to `false`.
        let second_states =
            vec![make_array(create_array!(Boolean, vec![false]).to_data())];

        let opt_filter = create_array!(Boolean, vec![true]);
        foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?;
        let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
        assert_eq!(groups_bool.len(), 1);
        assert_eq!(
            groups_bool.as_ref(),
            make_array(create_array!(Boolean, vec![false]).to_data()).as_ref()
        );

        // Converting a single unfiltered `false` input directly to state should
        // yield a `[false]` state array.
        let values = create_array!(Boolean, vec![false]);
        let opt_filter = create_array!(Boolean, vec![true]);
        let groups_bool =
            foreign_accum.convert_to_state(&[values], Some(opt_filter.as_ref()))?;

        assert_eq!(
            groups_bool[0].as_ref(),
            make_array(create_array!(Boolean, vec![false]).to_data()).as_ref()
        );

        Ok(())
    }

    /// Converts `value` to FFI and back, asserting it is unchanged.
    fn test_emit_to_round_trip(value: EmitTo) -> Result<()> {
        let ffi_value: FFI_EmitTo = value.into();
        let round_trip_value: EmitTo = ffi_value.into();

        assert_eq!(value, round_trip_value);
        Ok(())
    }

    /// This test ensures all enum values are properly translated
    #[test]
    fn test_all_emit_to_round_trip() -> Result<()> {
        test_emit_to_round_trip(EmitTo::All)?;
        test_emit_to_round_trip(EmitTo::First(10))?;

        Ok(())
    }
}