datafusion_ffi/udaf/accumulator.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::c_void;
19use std::ops::Deref;
20use std::ptr::null_mut;
21
22use abi_stable::StableAbi;
23use abi_stable::std_types::{RResult, RVec};
24use arrow::array::ArrayRef;
25use arrow::error::ArrowError;
26use datafusion_common::error::{DataFusionError, Result};
27use datafusion_common::scalar::ScalarValue;
28use datafusion_expr::Accumulator;
29use prost::Message;
30
31use crate::arrow_wrappers::WrappedArray;
32use crate::util::FFIResult;
33use crate::{df_result, rresult, rresult_return};
34
/// A stable struct for sharing [`Accumulator`] across FFI boundaries.
/// For an explanation of each field, see the corresponding function
/// defined in [`Accumulator`].
///
/// Arrays cross the boundary as [`WrappedArray`] and scalar values as
/// protobuf-encoded bytes, so only ABI-stable types appear in the
/// function signatures below.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_Accumulator {
    /// Add a batch of input arrays to the accumulator.
    /// See [`Accumulator::update_batch`].
    pub update_batch: unsafe extern "C" fn(
        accumulator: &mut Self,
        values: RVec<WrappedArray>,
    ) -> FFIResult<()>,

    /// Evaluate the accumulator and return the resulting [`ScalarValue`]
    /// encoded as protobuf bytes. See [`Accumulator::evaluate`].
    pub evaluate: unsafe extern "C" fn(accumulator: &mut Self) -> FFIResult<RVec<u8>>,

    /// See [`Accumulator::size`].
    pub size: unsafe extern "C" fn(accumulator: &Self) -> usize,

    /// Return the accumulator's intermediate state, one protobuf-encoded
    /// [`ScalarValue`] per entry. See [`Accumulator::state`].
    pub state: unsafe extern "C" fn(accumulator: &mut Self) -> FFIResult<RVec<RVec<u8>>>,

    /// Merge intermediate state arrays into this accumulator.
    /// See [`Accumulator::merge_batch`].
    pub merge_batch: unsafe extern "C" fn(
        accumulator: &mut Self,
        states: RVec<WrappedArray>,
    ) -> FFIResult<()>,

    /// Remove a batch of input arrays from the accumulator.
    /// See [`Accumulator::retract_batch`].
    pub retract_batch: unsafe extern "C" fn(
        accumulator: &mut Self,
        values: RVec<WrappedArray>,
    ) -> FFIResult<()>,

    /// Result of [`Accumulator::supports_retract_batch`] for the wrapped
    /// accumulator, stored as a plain value rather than a function pointer.
    pub supports_retract_batch: bool,

    /// Release the memory of the private data when it is no longer being used.
    pub release: unsafe extern "C" fn(accumulator: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the accumulator.
    /// A [`ForeignAccumulator`] should never attempt to access this data.
    pub private_data: *mut c_void,

    /// Utility to identify when FFI objects are accessed locally through
    /// the foreign interface. See [`crate::get_library_marker_id`] and
    /// the crate's `README.md` for more information.
    pub library_marker_id: extern "C" fn() -> usize,
}
77
// SAFETY: the only field that is not automatically `Send + Sync` is the raw
// `private_data` pointer. When created locally it owns an
// `AccumulatorPrivateData` holding a `Box<dyn Accumulator>`, and the
// `Accumulator` trait requires `Send + Sync`. A foreign provider is trusted
// to uphold the same guarantee for its own private data.
unsafe impl Send for FFI_Accumulator {}
unsafe impl Sync for FFI_Accumulator {}
80
/// Provider-side data stored behind [`FFI_Accumulator::private_data`].
///
/// It is allocated with `Box::into_raw` when an [`FFI_Accumulator`] is built
/// from a local accumulator. It is freed either by the `release` function
/// pointer or by the local-bypass conversion back into `Box<dyn Accumulator>`.
pub struct AccumulatorPrivateData {
    /// The local accumulator that every FFI call is forwarded to.
    pub accumulator: Box<dyn Accumulator>,
}
84
85impl FFI_Accumulator {
86    #[inline]
87    unsafe fn inner_mut(&mut self) -> &mut Box<dyn Accumulator> {
88        unsafe {
89            let private_data = self.private_data as *mut AccumulatorPrivateData;
90            &mut (*private_data).accumulator
91        }
92    }
93
94    #[inline]
95    unsafe fn inner(&self) -> &dyn Accumulator {
96        unsafe {
97            let private_data = self.private_data as *const AccumulatorPrivateData;
98            (*private_data).accumulator.deref()
99        }
100    }
101}
102
103unsafe extern "C" fn update_batch_fn_wrapper(
104    accumulator: &mut FFI_Accumulator,
105    values: RVec<WrappedArray>,
106) -> FFIResult<()> {
107    unsafe {
108        let accumulator = accumulator.inner_mut();
109
110        let values_arrays = values
111            .into_iter()
112            .map(|v| v.try_into().map_err(DataFusionError::from))
113            .collect::<Result<Vec<ArrayRef>>>();
114        let values_arrays = rresult_return!(values_arrays);
115
116        rresult!(accumulator.update_batch(&values_arrays))
117    }
118}
119
120unsafe extern "C" fn evaluate_fn_wrapper(
121    accumulator: &mut FFI_Accumulator,
122) -> FFIResult<RVec<u8>> {
123    unsafe {
124        let accumulator = accumulator.inner_mut();
125
126        let scalar_result = rresult_return!(accumulator.evaluate());
127        let proto_result: datafusion_proto::protobuf::ScalarValue =
128            rresult_return!((&scalar_result).try_into());
129
130        RResult::ROk(proto_result.encode_to_vec().into())
131    }
132}
133
134unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize {
135    unsafe { accumulator.inner().size() }
136}
137
138unsafe extern "C" fn state_fn_wrapper(
139    accumulator: &mut FFI_Accumulator,
140) -> FFIResult<RVec<RVec<u8>>> {
141    unsafe {
142        let accumulator = accumulator.inner_mut();
143
144        let state = rresult_return!(accumulator.state());
145        let state = state
146            .into_iter()
147            .map(|state_val| {
148                datafusion_proto::protobuf::ScalarValue::try_from(&state_val)
149                    .map_err(DataFusionError::from)
150                    .map(|v| RVec::from(v.encode_to_vec()))
151            })
152            .collect::<Result<Vec<_>>>()
153            .map(|state_vec| state_vec.into());
154
155        rresult!(state)
156    }
157}
158
159unsafe extern "C" fn merge_batch_fn_wrapper(
160    accumulator: &mut FFI_Accumulator,
161    states: RVec<WrappedArray>,
162) -> FFIResult<()> {
163    unsafe {
164        let accumulator = accumulator.inner_mut();
165
166        let states = rresult_return!(
167            states
168                .into_iter()
169                .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from))
170                .collect::<Result<Vec<_>>>()
171        );
172
173        rresult!(accumulator.merge_batch(&states))
174    }
175}
176
177unsafe extern "C" fn retract_batch_fn_wrapper(
178    accumulator: &mut FFI_Accumulator,
179    values: RVec<WrappedArray>,
180) -> FFIResult<()> {
181    unsafe {
182        let accumulator = accumulator.inner_mut();
183
184        let values_arrays = values
185            .into_iter()
186            .map(|v| v.try_into().map_err(DataFusionError::from))
187            .collect::<Result<Vec<ArrayRef>>>();
188        let values_arrays = rresult_return!(values_arrays);
189
190        rresult!(accumulator.retract_batch(&values_arrays))
191    }
192}
193
194unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) {
195    unsafe {
196        if !accumulator.private_data.is_null() {
197            let private_data =
198                Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData);
199            drop(private_data);
200            accumulator.private_data = null_mut();
201        }
202    }
203}
204
205impl From<Box<dyn Accumulator>> for FFI_Accumulator {
206    fn from(accumulator: Box<dyn Accumulator>) -> Self {
207        let supports_retract_batch = accumulator.supports_retract_batch();
208        let private_data = AccumulatorPrivateData { accumulator };
209
210        Self {
211            update_batch: update_batch_fn_wrapper,
212            evaluate: evaluate_fn_wrapper,
213            size: size_fn_wrapper,
214            state: state_fn_wrapper,
215            merge_batch: merge_batch_fn_wrapper,
216            retract_batch: retract_batch_fn_wrapper,
217            supports_retract_batch,
218            release: release_fn_wrapper,
219            private_data: Box::into_raw(Box::new(private_data)) as *mut c_void,
220            library_marker_id: crate::get_library_marker_id,
221        }
222    }
223}
224
225impl Drop for FFI_Accumulator {
226    fn drop(&mut self) {
227        unsafe { (self.release)(self) }
228    }
229}
230
/// This struct is used to access an [`Accumulator`] provided by a foreign
/// library across an FFI boundary.
///
/// The ForeignAccumulator is to be used by the caller of the accumulator, so it
/// has no knowledge of, or access to, the private data. All interaction with the
/// accumulator must occur through the functions defined in [`FFI_Accumulator`].
#[derive(Debug)]
pub struct ForeignAccumulator {
    /// The FFI struct received from the provider. Dropping it invokes the
    /// provider's `release` function.
    accumulator: FFI_Accumulator,
}
241
242impl From<FFI_Accumulator> for Box<dyn Accumulator> {
243    fn from(mut accumulator: FFI_Accumulator) -> Self {
244        if (accumulator.library_marker_id)() == crate::get_library_marker_id() {
245            unsafe {
246                let private_data = Box::from_raw(
247                    accumulator.private_data as *mut AccumulatorPrivateData,
248                );
249                // We must set this to null to avoid a double free
250                accumulator.private_data = null_mut();
251                private_data.accumulator
252            }
253        } else {
254            Box::new(ForeignAccumulator { accumulator })
255        }
256    }
257}
258
259impl Accumulator for ForeignAccumulator {
260    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
261        unsafe {
262            let values = values
263                .iter()
264                .map(WrappedArray::try_from)
265                .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
266            df_result!((self.accumulator.update_batch)(
267                &mut self.accumulator,
268                values.into()
269            ))
270        }
271    }
272
273    fn evaluate(&mut self) -> Result<ScalarValue> {
274        unsafe {
275            let scalar_bytes =
276                df_result!((self.accumulator.evaluate)(&mut self.accumulator))?;
277
278            let proto_scalar =
279                datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref())
280                    .map_err(|e| DataFusionError::External(Box::new(e)))?;
281
282            ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from)
283        }
284    }
285
286    fn size(&self) -> usize {
287        unsafe { (self.accumulator.size)(&self.accumulator) }
288    }
289
290    fn state(&mut self) -> Result<Vec<ScalarValue>> {
291        unsafe {
292            let state_protos =
293                df_result!((self.accumulator.state)(&mut self.accumulator))?;
294
295            state_protos
296                .into_iter()
297                .map(|proto_bytes| {
298                    datafusion_proto::protobuf::ScalarValue::decode(proto_bytes.as_ref())
299                        .map_err(|e| DataFusionError::External(Box::new(e)))
300                        .and_then(|proto_value| {
301                            ScalarValue::try_from(&proto_value)
302                                .map_err(DataFusionError::from)
303                        })
304                })
305                .collect::<Result<Vec<_>>>()
306        }
307    }
308
309    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
310        unsafe {
311            let states = states
312                .iter()
313                .map(WrappedArray::try_from)
314                .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
315            df_result!((self.accumulator.merge_batch)(
316                &mut self.accumulator,
317                states.into()
318            ))
319        }
320    }
321
322    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
323        unsafe {
324            let values = values
325                .iter()
326                .map(WrappedArray::try_from)
327                .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
328            df_result!((self.accumulator.retract_batch)(
329                &mut self.accumulator,
330                values.into()
331            ))
332        }
333    }
334
335    fn supports_retract_batch(&self) -> bool {
336        self.accumulator.supports_retract_batch
337    }
338}
339
#[cfg(test)]
mod tests {
    use arrow::array::{Array, make_array};
    use datafusion::common::create_array;
    use datafusion::error::Result;
    use datafusion::functions_aggregate::average::AvgAccumulator;
    use datafusion::logical_expr::Accumulator;
    use datafusion::scalar::ScalarValue;

    use super::FFI_Accumulator;

    /// Return the concrete type name behind an accumulator trait object, read
    /// from the leading identifier of its `Debug` output (`Debug` is a
    /// supertrait of `Accumulator`).
    ///
    /// This checks which type a conversion produced without raw pointer casts.
    /// A raw cast to the wrong type would be undefined behavior rather than a
    /// clean test failure.
    fn debug_type_name(accum: &dyn Accumulator) -> String {
        let debug = format!("{accum:?}");
        debug
            .split(|c: char| !(c.is_alphanumeric() || c == '_'))
            .next()
            .unwrap_or_default()
            .to_string()
    }

    #[test]
    fn test_foreign_avg_accumulator() -> Result<()> {
        let original_accum = AvgAccumulator::default();
        let original_size = original_accum.size();
        let original_supports_retract = original_accum.supports_retract_batch();

        let boxed_accum: Box<dyn Accumulator> = Box::new(original_accum);
        let mut ffi_accum: FFI_Accumulator = boxed_accum.into();
        ffi_accum.library_marker_id = crate::mock_foreign_marker_id;
        let mut foreign_accum: Box<dyn Accumulator> = ffi_accum.into();

        // Send in an array to average. There are 5 values and it should average to 30.0
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        foreign_accum.update_batch(&[values])?;

        let avg = foreign_accum.evaluate()?;
        assert_eq!(avg, ScalarValue::Float64(Some(30.0)));

        let state = foreign_accum.state()?;
        assert_eq!(state.len(), 2);
        assert_eq!(state[0], ScalarValue::UInt64(Some(5)));
        assert_eq!(state[1], ScalarValue::Float64(Some(150.0)));

        // To verify merging batches works, create a second state to add in
        // This should cause our average to go down to 25.0
        let second_states = vec![
            make_array(create_array!(UInt64, vec![1]).to_data()),
            make_array(create_array!(Float64, vec![0.0]).to_data()),
        ];

        foreign_accum.merge_batch(&second_states)?;
        let avg = foreign_accum.evaluate()?;
        assert_eq!(avg, ScalarValue::Float64(Some(25.0)));

        // If we remove a batch that is equivalent to the state we added
        // we should go back to our original value of 30.0
        let values = create_array!(Float64, vec![0.0]);
        foreign_accum.retract_batch(&[values])?;
        let avg = foreign_accum.evaluate()?;
        assert_eq!(avg, ScalarValue::Float64(Some(30.0)));

        assert_eq!(original_size, foreign_accum.size());
        assert_eq!(
            original_supports_retract,
            foreign_accum.supports_retract_batch()
        );

        Ok(())
    }

    #[test]
    fn test_ffi_accumulator_local_bypass() -> Result<()> {
        let original_accum = AvgAccumulator::default();
        let boxed_accum: Box<dyn Accumulator> = Box::new(original_accum);
        let original_size = boxed_accum.size();

        let ffi_accum: FFI_Accumulator = boxed_accum.into();

        // Verify local libraries get their original accumulator back
        let local_accum: Box<dyn Accumulator> = ffi_accum.into();
        assert_eq!(debug_type_name(local_accum.as_ref()), "AvgAccumulator");
        assert_eq!(original_size, local_accum.size());

        // Verify different library markers generate a foreign accumulator
        let original_accum = AvgAccumulator::default();
        let boxed_accum: Box<dyn Accumulator> = Box::new(original_accum);
        let mut ffi_accum: FFI_Accumulator = boxed_accum.into();
        ffi_accum.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_accum: Box<dyn Accumulator> = ffi_accum.into();
        assert_eq!(
            debug_type_name(foreign_accum.as_ref()),
            "ForeignAccumulator"
        );
        assert_eq!(original_size, foreign_accum.size());

        Ok(())
    }
}