datafusion_ffi/udaf/
accumulator.rs1use std::ffi::c_void;
19use std::ops::Deref;
20use std::ptr::null_mut;
21
22use abi_stable::StableAbi;
23use abi_stable::std_types::{RResult, RVec};
24use arrow::array::ArrayRef;
25use arrow::error::ArrowError;
26use datafusion_common::error::{DataFusionError, Result};
27use datafusion_common::scalar::ScalarValue;
28use datafusion_expr::Accumulator;
29use prost::Message;
30
31use crate::arrow_wrappers::WrappedArray;
32use crate::util::FFIResult;
33use crate::{df_result, rresult, rresult_return};
34
35#[repr(C)]
39#[derive(Debug, StableAbi)]
40pub struct FFI_Accumulator {
41 pub update_batch: unsafe extern "C" fn(
42 accumulator: &mut Self,
43 values: RVec<WrappedArray>,
44 ) -> FFIResult<()>,
45
46 pub evaluate: unsafe extern "C" fn(accumulator: &mut Self) -> FFIResult<RVec<u8>>,
48
49 pub size: unsafe extern "C" fn(accumulator: &Self) -> usize,
50
51 pub state: unsafe extern "C" fn(accumulator: &mut Self) -> FFIResult<RVec<RVec<u8>>>,
52
53 pub merge_batch: unsafe extern "C" fn(
54 accumulator: &mut Self,
55 states: RVec<WrappedArray>,
56 ) -> FFIResult<()>,
57
58 pub retract_batch: unsafe extern "C" fn(
59 accumulator: &mut Self,
60 values: RVec<WrappedArray>,
61 ) -> FFIResult<()>,
62
63 pub supports_retract_batch: bool,
64
65 pub release: unsafe extern "C" fn(accumulator: &mut Self),
67
68 pub private_data: *mut c_void,
71
72 pub library_marker_id: extern "C" fn() -> usize,
76}
77
78unsafe impl Send for FFI_Accumulator {}
79unsafe impl Sync for FFI_Accumulator {}
80
81pub struct AccumulatorPrivateData {
82 pub accumulator: Box<dyn Accumulator>,
83}
84
85impl FFI_Accumulator {
86 #[inline]
87 unsafe fn inner_mut(&mut self) -> &mut Box<dyn Accumulator> {
88 unsafe {
89 let private_data = self.private_data as *mut AccumulatorPrivateData;
90 &mut (*private_data).accumulator
91 }
92 }
93
94 #[inline]
95 unsafe fn inner(&self) -> &dyn Accumulator {
96 unsafe {
97 let private_data = self.private_data as *const AccumulatorPrivateData;
98 (*private_data).accumulator.deref()
99 }
100 }
101}
102
103unsafe extern "C" fn update_batch_fn_wrapper(
104 accumulator: &mut FFI_Accumulator,
105 values: RVec<WrappedArray>,
106) -> FFIResult<()> {
107 unsafe {
108 let accumulator = accumulator.inner_mut();
109
110 let values_arrays = values
111 .into_iter()
112 .map(|v| v.try_into().map_err(DataFusionError::from))
113 .collect::<Result<Vec<ArrayRef>>>();
114 let values_arrays = rresult_return!(values_arrays);
115
116 rresult!(accumulator.update_batch(&values_arrays))
117 }
118}
119
120unsafe extern "C" fn evaluate_fn_wrapper(
121 accumulator: &mut FFI_Accumulator,
122) -> FFIResult<RVec<u8>> {
123 unsafe {
124 let accumulator = accumulator.inner_mut();
125
126 let scalar_result = rresult_return!(accumulator.evaluate());
127 let proto_result: datafusion_proto::protobuf::ScalarValue =
128 rresult_return!((&scalar_result).try_into());
129
130 RResult::ROk(proto_result.encode_to_vec().into())
131 }
132}
133
134unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize {
135 unsafe { accumulator.inner().size() }
136}
137
138unsafe extern "C" fn state_fn_wrapper(
139 accumulator: &mut FFI_Accumulator,
140) -> FFIResult<RVec<RVec<u8>>> {
141 unsafe {
142 let accumulator = accumulator.inner_mut();
143
144 let state = rresult_return!(accumulator.state());
145 let state = state
146 .into_iter()
147 .map(|state_val| {
148 datafusion_proto::protobuf::ScalarValue::try_from(&state_val)
149 .map_err(DataFusionError::from)
150 .map(|v| RVec::from(v.encode_to_vec()))
151 })
152 .collect::<Result<Vec<_>>>()
153 .map(|state_vec| state_vec.into());
154
155 rresult!(state)
156 }
157}
158
159unsafe extern "C" fn merge_batch_fn_wrapper(
160 accumulator: &mut FFI_Accumulator,
161 states: RVec<WrappedArray>,
162) -> FFIResult<()> {
163 unsafe {
164 let accumulator = accumulator.inner_mut();
165
166 let states = rresult_return!(
167 states
168 .into_iter()
169 .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from))
170 .collect::<Result<Vec<_>>>()
171 );
172
173 rresult!(accumulator.merge_batch(&states))
174 }
175}
176
177unsafe extern "C" fn retract_batch_fn_wrapper(
178 accumulator: &mut FFI_Accumulator,
179 values: RVec<WrappedArray>,
180) -> FFIResult<()> {
181 unsafe {
182 let accumulator = accumulator.inner_mut();
183
184 let values_arrays = values
185 .into_iter()
186 .map(|v| v.try_into().map_err(DataFusionError::from))
187 .collect::<Result<Vec<ArrayRef>>>();
188 let values_arrays = rresult_return!(values_arrays);
189
190 rresult!(accumulator.retract_batch(&values_arrays))
191 }
192}
193
194unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) {
195 unsafe {
196 if !accumulator.private_data.is_null() {
197 let private_data =
198 Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData);
199 drop(private_data);
200 accumulator.private_data = null_mut();
201 }
202 }
203}
204
205impl From<Box<dyn Accumulator>> for FFI_Accumulator {
206 fn from(accumulator: Box<dyn Accumulator>) -> Self {
207 let supports_retract_batch = accumulator.supports_retract_batch();
208 let private_data = AccumulatorPrivateData { accumulator };
209
210 Self {
211 update_batch: update_batch_fn_wrapper,
212 evaluate: evaluate_fn_wrapper,
213 size: size_fn_wrapper,
214 state: state_fn_wrapper,
215 merge_batch: merge_batch_fn_wrapper,
216 retract_batch: retract_batch_fn_wrapper,
217 supports_retract_batch,
218 release: release_fn_wrapper,
219 private_data: Box::into_raw(Box::new(private_data)) as *mut c_void,
220 library_marker_id: crate::get_library_marker_id,
221 }
222 }
223}
224
225impl Drop for FFI_Accumulator {
226 fn drop(&mut self) {
227 unsafe { (self.release)(self) }
228 }
229}
230
231#[derive(Debug)]
238pub struct ForeignAccumulator {
239 accumulator: FFI_Accumulator,
240}
241
242impl From<FFI_Accumulator> for Box<dyn Accumulator> {
243 fn from(mut accumulator: FFI_Accumulator) -> Self {
244 if (accumulator.library_marker_id)() == crate::get_library_marker_id() {
245 unsafe {
246 let private_data = Box::from_raw(
247 accumulator.private_data as *mut AccumulatorPrivateData,
248 );
249 accumulator.private_data = null_mut();
251 private_data.accumulator
252 }
253 } else {
254 Box::new(ForeignAccumulator { accumulator })
255 }
256 }
257}
258
259impl Accumulator for ForeignAccumulator {
260 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
261 unsafe {
262 let values = values
263 .iter()
264 .map(WrappedArray::try_from)
265 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
266 df_result!((self.accumulator.update_batch)(
267 &mut self.accumulator,
268 values.into()
269 ))
270 }
271 }
272
273 fn evaluate(&mut self) -> Result<ScalarValue> {
274 unsafe {
275 let scalar_bytes =
276 df_result!((self.accumulator.evaluate)(&mut self.accumulator))?;
277
278 let proto_scalar =
279 datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref())
280 .map_err(|e| DataFusionError::External(Box::new(e)))?;
281
282 ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from)
283 }
284 }
285
286 fn size(&self) -> usize {
287 unsafe { (self.accumulator.size)(&self.accumulator) }
288 }
289
290 fn state(&mut self) -> Result<Vec<ScalarValue>> {
291 unsafe {
292 let state_protos =
293 df_result!((self.accumulator.state)(&mut self.accumulator))?;
294
295 state_protos
296 .into_iter()
297 .map(|proto_bytes| {
298 datafusion_proto::protobuf::ScalarValue::decode(proto_bytes.as_ref())
299 .map_err(|e| DataFusionError::External(Box::new(e)))
300 .and_then(|proto_value| {
301 ScalarValue::try_from(&proto_value)
302 .map_err(DataFusionError::from)
303 })
304 })
305 .collect::<Result<Vec<_>>>()
306 }
307 }
308
309 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
310 unsafe {
311 let states = states
312 .iter()
313 .map(WrappedArray::try_from)
314 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
315 df_result!((self.accumulator.merge_batch)(
316 &mut self.accumulator,
317 states.into()
318 ))
319 }
320 }
321
322 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
323 unsafe {
324 let values = values
325 .iter()
326 .map(WrappedArray::try_from)
327 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
328 df_result!((self.accumulator.retract_batch)(
329 &mut self.accumulator,
330 values.into()
331 ))
332 }
333 }
334
335 fn supports_retract_batch(&self) -> bool {
336 self.accumulator.supports_retract_batch
337 }
338}
339
#[cfg(test)]
mod tests {
    use arrow::array::{Array, make_array};
    use datafusion::common::create_array;
    use datafusion::error::Result;
    use datafusion::functions_aggregate::average::AvgAccumulator;
    use datafusion::logical_expr::Accumulator;
    use datafusion::scalar::ScalarValue;

    use super::{FFI_Accumulator, ForeignAccumulator};

    /// Drives an `AvgAccumulator` through the foreign code path and checks
    /// update, evaluate, state, merge, retract, size and the retract flag.
    #[test]
    fn test_foreign_avg_accumulator() -> Result<()> {
        let local = AvgAccumulator::default();
        let expected_size = local.size();
        let expected_retract = local.supports_retract_batch();

        let boxed: Box<dyn Accumulator> = Box::new(local);
        let mut ffi = FFI_Accumulator::from(boxed);
        // Swap in the mock marker so the conversion below yields a
        // `ForeignAccumulator` instead of unwrapping the local accumulator.
        ffi.library_marker_id = crate::mock_foreign_marker_id;
        let mut foreign: Box<dyn Accumulator> = ffi.into();

        // Five values summing to 150 give an average of 30.
        let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
        foreign.update_batch(&[values])?;

        let avg = foreign.evaluate()?;
        assert_eq!(avg, ScalarValue::Float64(Some(30.0)));

        let state = foreign.state()?;
        assert_eq!(state.len(), 2);
        assert_eq!(state[0], ScalarValue::UInt64(Some(5)));
        assert_eq!(state[1], ScalarValue::Float64(Some(150.0)));

        // Merging a state of (count = 1, sum = 0) lowers the average to 150 / 6.
        let second_states = vec![
            make_array(create_array!(UInt64, vec![1]).to_data()),
            make_array(create_array!(Float64, vec![0.0]).to_data()),
        ];

        foreign.merge_batch(&second_states)?;
        let avg = foreign.evaluate()?;
        assert_eq!(avg, ScalarValue::Float64(Some(25.0)));

        // Retracting a single 0.0 restores the original average.
        let values = create_array!(Float64, vec![0.0]);
        foreign.retract_batch(&[values])?;
        let avg = foreign.evaluate()?;
        assert_eq!(avg, ScalarValue::Float64(Some(30.0)));

        assert_eq!(expected_size, foreign.size());
        assert_eq!(expected_retract, foreign.supports_retract_batch());

        Ok(())
    }

    /// Checks which concrete type comes back from the FFI round trip: the
    /// original accumulator for a local marker, a `ForeignAccumulator` for a
    /// foreign one.
    #[test]
    fn test_ffi_accumulator_local_bypass() -> Result<()> {
        let boxed: Box<dyn Accumulator> = Box::new(AvgAccumulator::default());
        let expected_size = boxed.size();

        // Local marker: the round trip should hand back the `AvgAccumulator`.
        let ffi = FFI_Accumulator::from(boxed);
        let round_tripped: Box<dyn Accumulator> = ffi.into();
        unsafe {
            let concrete = &*(round_tripped.as_ref() as *const dyn Accumulator
                as *const AvgAccumulator);
            assert_eq!(expected_size, concrete.size());
        }

        // Mock foreign marker: the round trip should yield a `ForeignAccumulator`.
        let boxed: Box<dyn Accumulator> = Box::new(AvgAccumulator::default());
        let mut ffi = FFI_Accumulator::from(boxed);
        ffi.library_marker_id = crate::mock_foreign_marker_id;
        let round_tripped: Box<dyn Accumulator> = ffi.into();
        unsafe {
            let concrete = &*(round_tripped.as_ref() as *const dyn Accumulator
                as *const ForeignAccumulator);
            assert_eq!(expected_size, concrete.size());
        }

        Ok(())
    }
}