1use std::{ffi::c_void, ops::Deref, sync::Arc};
19
20use crate::{
21 arrow_wrappers::{WrappedArray, WrappedSchema},
22 df_result, rresult, rresult_return,
23};
24use abi_stable::{
25 std_types::{ROption, RResult, RString, RVec},
26 StableAbi,
27};
28use arrow::{
29 array::{Array, ArrayRef, BooleanArray},
30 error::ArrowError,
31 ffi::to_ffi,
32};
33use datafusion::{
34 error::{DataFusionError, Result},
35 logical_expr::{EmitTo, GroupsAccumulator},
36};
37
/// A C-ABI-stable handle to a [`GroupsAccumulator`].
///
/// Each function pointer takes this struct as its first argument and mirrors
/// the [`GroupsAccumulator`] method of the same name. Arrow arrays cross the
/// boundary as [`WrappedArray`] and errors come back as `RResult<_, RString>`.
/// `private_data` is released through `release` when this struct is dropped.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_GroupsAccumulator {
    /// Mirrors [`GroupsAccumulator::update_batch`].
    pub update_batch: unsafe extern "C" fn(
        accumulator: &mut Self,
        values: RVec<WrappedArray>,
        group_indices: RVec<usize>,
        opt_filter: ROption<WrappedArray>,
        total_num_groups: usize,
    ) -> RResult<(), RString>,

    /// Mirrors [`GroupsAccumulator::evaluate`].
    pub evaluate: unsafe extern "C" fn(
        accumulator: &mut Self,
        emit_to: FFI_EmitTo,
    ) -> RResult<WrappedArray, RString>,

    /// Mirrors [`GroupsAccumulator::size`].
    pub size: unsafe extern "C" fn(accumulator: &Self) -> usize,

    /// Mirrors [`GroupsAccumulator::state`].
    pub state: unsafe extern "C" fn(
        accumulator: &mut Self,
        emit_to: FFI_EmitTo,
    ) -> RResult<RVec<WrappedArray>, RString>,

    /// Mirrors [`GroupsAccumulator::merge_batch`].
    pub merge_batch: unsafe extern "C" fn(
        accumulator: &mut Self,
        values: RVec<WrappedArray>,
        group_indices: RVec<usize>,
        opt_filter: ROption<WrappedArray>,
        total_num_groups: usize,
    ) -> RResult<(), RString>,

    /// Mirrors [`GroupsAccumulator::convert_to_state`].
    pub convert_to_state: unsafe extern "C" fn(
        accumulator: &Self,
        values: RVec<WrappedArray>,
        opt_filter: ROption<WrappedArray>,
    )
        -> RResult<RVec<WrappedArray>, RString>,

    /// Result of [`GroupsAccumulator::supports_convert_to_state`], captured
    /// once when the handle is created rather than queried across FFI.
    pub supports_convert_to_state: bool,

    /// Releases the state behind `private_data`; invoked from this struct's
    /// `Drop` impl.
    pub release: unsafe extern "C" fn(accumulator: &mut Self),

    /// Opaque provider-side state. When built through the
    /// `From<Box<dyn GroupsAccumulator>>` impl it is a boxed
    /// `GroupsAccumulatorPrivateData`.
    pub private_data: *mut c_void,
}

// NOTE(review): these impls assert that the provider's state behind
// `private_data` may be moved and shared across threads (the raw pointer
// field otherwise blocks the auto traits). `inner()` hands out a shared
// `&dyn GroupsAccumulator` from `&self`, so `Sync` relies on the provider's
// accumulator tolerating that — confirm against the `GroupsAccumulator` bounds.
unsafe impl Send for FFI_GroupsAccumulator {}
unsafe impl Sync for FFI_GroupsAccumulator {}
93
/// Provider-side state stored (boxed) behind
/// [`FFI_GroupsAccumulator::private_data`].
pub struct GroupsAccumulatorPrivateData {
    /// The accumulator that the FFI function pointers forward to.
    pub accumulator: Box<dyn GroupsAccumulator>,
}
97
98impl FFI_GroupsAccumulator {
99 #[inline]
100 unsafe fn inner_mut(&mut self) -> &mut Box<dyn GroupsAccumulator> {
101 let private_data = self.private_data as *mut GroupsAccumulatorPrivateData;
102 &mut (*private_data).accumulator
103 }
104
105 #[inline]
106 unsafe fn inner(&self) -> &dyn GroupsAccumulator {
107 let private_data = self.private_data as *const GroupsAccumulatorPrivateData;
108 (*private_data).accumulator.deref()
109 }
110}
111
112fn process_values(values: RVec<WrappedArray>) -> Result<Vec<Arc<dyn Array>>> {
113 values
114 .into_iter()
115 .map(|v| v.try_into().map_err(DataFusionError::from))
116 .collect::<Result<Vec<ArrayRef>>>()
117}
118
119fn process_opt_filter(opt_filter: ROption<WrappedArray>) -> Result<Option<BooleanArray>> {
121 opt_filter
122 .into_option()
123 .map(|filter| {
124 ArrayRef::try_from(filter)
125 .map_err(DataFusionError::from)
126 .map(|arr| BooleanArray::from(arr.into_data()))
127 })
128 .transpose()
129}
130
131unsafe extern "C" fn update_batch_fn_wrapper(
132 accumulator: &mut FFI_GroupsAccumulator,
133 values: RVec<WrappedArray>,
134 group_indices: RVec<usize>,
135 opt_filter: ROption<WrappedArray>,
136 total_num_groups: usize,
137) -> RResult<(), RString> {
138 let accumulator = accumulator.inner_mut();
139 let values = rresult_return!(process_values(values));
140 let group_indices: Vec<usize> = group_indices.into_iter().collect();
141 let opt_filter = rresult_return!(process_opt_filter(opt_filter));
142
143 rresult!(accumulator.update_batch(
144 &values,
145 &group_indices,
146 opt_filter.as_ref(),
147 total_num_groups
148 ))
149}
150
151unsafe extern "C" fn evaluate_fn_wrapper(
152 accumulator: &mut FFI_GroupsAccumulator,
153 emit_to: FFI_EmitTo,
154) -> RResult<WrappedArray, RString> {
155 let accumulator = accumulator.inner_mut();
156
157 let result = rresult_return!(accumulator.evaluate(emit_to.into()));
158
159 rresult!(WrappedArray::try_from(&result))
160}
161
162unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize {
163 let accumulator = accumulator.inner();
164 accumulator.size()
165}
166
167unsafe extern "C" fn state_fn_wrapper(
168 accumulator: &mut FFI_GroupsAccumulator,
169 emit_to: FFI_EmitTo,
170) -> RResult<RVec<WrappedArray>, RString> {
171 let accumulator = accumulator.inner_mut();
172
173 let state = rresult_return!(accumulator.state(emit_to.into()));
174 rresult!(state
175 .into_iter()
176 .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from))
177 .collect::<Result<RVec<_>>>())
178}
179
180unsafe extern "C" fn merge_batch_fn_wrapper(
181 accumulator: &mut FFI_GroupsAccumulator,
182 values: RVec<WrappedArray>,
183 group_indices: RVec<usize>,
184 opt_filter: ROption<WrappedArray>,
185 total_num_groups: usize,
186) -> RResult<(), RString> {
187 let accumulator = accumulator.inner_mut();
188 let values = rresult_return!(process_values(values));
189 let group_indices: Vec<usize> = group_indices.into_iter().collect();
190 let opt_filter = rresult_return!(process_opt_filter(opt_filter));
191
192 rresult!(accumulator.merge_batch(
193 &values,
194 &group_indices,
195 opt_filter.as_ref(),
196 total_num_groups
197 ))
198}
199
200unsafe extern "C" fn convert_to_state_fn_wrapper(
201 accumulator: &FFI_GroupsAccumulator,
202 values: RVec<WrappedArray>,
203 opt_filter: ROption<WrappedArray>,
204) -> RResult<RVec<WrappedArray>, RString> {
205 let accumulator = accumulator.inner();
206 let values = rresult_return!(process_values(values));
207 let opt_filter = rresult_return!(process_opt_filter(opt_filter));
208 let state =
209 rresult_return!(accumulator.convert_to_state(&values, opt_filter.as_ref()));
210
211 rresult!(state
212 .iter()
213 .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from))
214 .collect::<Result<RVec<_>>>())
215}
216
217unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) {
218 let private_data =
219 Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData);
220 drop(private_data);
221}
222
223impl From<Box<dyn GroupsAccumulator>> for FFI_GroupsAccumulator {
224 fn from(accumulator: Box<dyn GroupsAccumulator>) -> Self {
225 let supports_convert_to_state = accumulator.supports_convert_to_state();
226 let private_data = GroupsAccumulatorPrivateData { accumulator };
227
228 Self {
229 update_batch: update_batch_fn_wrapper,
230 evaluate: evaluate_fn_wrapper,
231 size: size_fn_wrapper,
232 state: state_fn_wrapper,
233 merge_batch: merge_batch_fn_wrapper,
234 convert_to_state: convert_to_state_fn_wrapper,
235 supports_convert_to_state,
236
237 release: release_fn_wrapper,
238 private_data: Box::into_raw(Box::new(private_data)) as *mut c_void,
239 }
240 }
241}
242
243impl Drop for FFI_GroupsAccumulator {
244 fn drop(&mut self) {
245 unsafe { (self.release)(self) }
246 }
247}
248
249#[derive(Debug)]
256pub struct ForeignGroupsAccumulator {
257 accumulator: FFI_GroupsAccumulator,
258}
259
260unsafe impl Send for ForeignGroupsAccumulator {}
261unsafe impl Sync for ForeignGroupsAccumulator {}
262
263impl From<FFI_GroupsAccumulator> for ForeignGroupsAccumulator {
264 fn from(accumulator: FFI_GroupsAccumulator) -> Self {
265 Self { accumulator }
266 }
267}
268
269impl GroupsAccumulator for ForeignGroupsAccumulator {
270 fn update_batch(
271 &mut self,
272 values: &[ArrayRef],
273 group_indices: &[usize],
274 opt_filter: Option<&BooleanArray>,
275 total_num_groups: usize,
276 ) -> Result<()> {
277 unsafe {
278 let values = values
279 .iter()
280 .map(WrappedArray::try_from)
281 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
282 let group_indices = group_indices.iter().cloned().collect();
283 let opt_filter = opt_filter
284 .map(|bool_array| to_ffi(&bool_array.to_data()))
285 .transpose()?
286 .map(|(array, schema)| WrappedArray {
287 array,
288 schema: WrappedSchema(schema),
289 })
290 .into();
291
292 df_result!((self.accumulator.update_batch)(
293 &mut self.accumulator,
294 values.into(),
295 group_indices,
296 opt_filter,
297 total_num_groups
298 ))
299 }
300 }
301
302 fn size(&self) -> usize {
303 unsafe { (self.accumulator.size)(&self.accumulator) }
304 }
305
306 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
307 unsafe {
308 let return_array = df_result!((self.accumulator.evaluate)(
309 &mut self.accumulator,
310 emit_to.into()
311 ))?;
312
313 return_array.try_into().map_err(DataFusionError::from)
314 }
315 }
316
317 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
318 unsafe {
319 let returned_arrays = df_result!((self.accumulator.state)(
320 &mut self.accumulator,
321 emit_to.into()
322 ))?;
323
324 returned_arrays
325 .into_iter()
326 .map(|wrapped_array| {
327 wrapped_array.try_into().map_err(DataFusionError::from)
328 })
329 .collect::<Result<Vec<_>>>()
330 }
331 }
332
333 fn merge_batch(
334 &mut self,
335 values: &[ArrayRef],
336 group_indices: &[usize],
337 opt_filter: Option<&BooleanArray>,
338 total_num_groups: usize,
339 ) -> Result<()> {
340 unsafe {
341 let values = values
342 .iter()
343 .map(WrappedArray::try_from)
344 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
345 let group_indices = group_indices.iter().cloned().collect();
346 let opt_filter = opt_filter
347 .map(|bool_array| to_ffi(&bool_array.to_data()))
348 .transpose()?
349 .map(|(array, schema)| WrappedArray {
350 array,
351 schema: WrappedSchema(schema),
352 })
353 .into();
354
355 df_result!((self.accumulator.merge_batch)(
356 &mut self.accumulator,
357 values.into(),
358 group_indices,
359 opt_filter,
360 total_num_groups
361 ))
362 }
363 }
364
365 fn convert_to_state(
366 &self,
367 values: &[ArrayRef],
368 opt_filter: Option<&BooleanArray>,
369 ) -> Result<Vec<ArrayRef>> {
370 unsafe {
371 let values = values
372 .iter()
373 .map(WrappedArray::try_from)
374 .collect::<std::result::Result<RVec<_>, ArrowError>>()?;
375
376 let opt_filter = opt_filter
377 .map(|bool_array| to_ffi(&bool_array.to_data()))
378 .transpose()?
379 .map(|(array, schema)| WrappedArray {
380 array,
381 schema: WrappedSchema(schema),
382 })
383 .into();
384
385 let returned_array = df_result!((self.accumulator.convert_to_state)(
386 &self.accumulator,
387 values,
388 opt_filter
389 ))?;
390
391 returned_array
392 .into_iter()
393 .map(|arr| arr.try_into().map_err(DataFusionError::from))
394 .collect()
395 }
396 }
397
398 fn supports_convert_to_state(&self) -> bool {
399 self.accumulator.supports_convert_to_state
400 }
401}
402
/// C-ABI-stable mirror of [`EmitTo`]; the `From` impls convert in both
/// directions variant-for-variant.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub enum FFI_EmitTo {
    /// Mirrors `EmitTo::All`.
    All,
    /// Mirrors `EmitTo::First(n)`.
    First(usize),
}
410
411impl From<EmitTo> for FFI_EmitTo {
412 fn from(value: EmitTo) -> Self {
413 match value {
414 EmitTo::All => Self::All,
415 EmitTo::First(v) => Self::First(v),
416 }
417 }
418}
419
420impl From<FFI_EmitTo> for EmitTo {
421 fn from(value: FFI_EmitTo) -> Self {
422 match value {
423 FFI_EmitTo::All => Self::All,
424 FFI_EmitTo::First(v) => Self::First(v),
425 }
426 }
427}
428
#[cfg(test)]
mod tests {
    use arrow::array::{make_array, Array, BooleanArray};
    use datafusion::{
        common::create_array,
        error::Result,
        logical_expr::{EmitTo, GroupsAccumulator},
    };
    use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator;

    use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator};

    /// Round-trips a boolean AND groups accumulator through the FFI layer and
    /// exercises each `GroupsAccumulator` method via the foreign wrapper.
    #[test]
    fn test_foreign_bool_and_accumulator() -> Result<()> {
        let boxed_accum: Box<dyn GroupsAccumulator> =
            Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true));
        let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into();
        let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into();

        // Group 0: true && true, group 1: true && false, group 2: all rows
        // filtered out.
        let values = create_array!(Boolean, vec![true, true, true, false, true, true]);
        let opt_filter =
            create_array!(Boolean, vec![true, true, true, true, false, false]);
        foreign_accum.update_batch(
            &[values],
            &[0, 0, 1, 1, 2, 2],
            Some(opt_filter.as_ref()),
            3,
        )?;

        let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
        let groups_bool = groups_bool.as_any().downcast_ref::<BooleanArray>().unwrap();

        assert_eq!(
            groups_bool,
            create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref()
        );

        let state = foreign_accum.state(EmitTo::All)?;
        assert_eq!(state.len(), 1);

        let second_states =
            vec![make_array(create_array!(Boolean, vec![false]).to_data())];

        let opt_filter = create_array!(Boolean, vec![true]);
        foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?;
        let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
        assert_eq!(groups_bool.len(), 1);
        assert_eq!(
            groups_bool.as_ref(),
            make_array(create_array!(Boolean, vec![false]).to_data()).as_ref()
        );

        let values = create_array!(Boolean, vec![false]);
        let opt_filter = create_array!(Boolean, vec![true]);
        let converted_state =
            foreign_accum.convert_to_state(&[values], Some(opt_filter.as_ref()))?;

        assert_eq!(
            converted_state[0].as_ref(),
            make_array(create_array!(Boolean, vec![false]).to_data()).as_ref()
        );

        Ok(())
    }

    /// Asserts that `value` survives an `EmitTo -> FFI_EmitTo -> EmitTo` trip.
    fn test_emit_to_round_trip(value: EmitTo) -> Result<()> {
        let ffi_value: FFI_EmitTo = value.into();
        let round_trip_value: EmitTo = ffi_value.into();

        assert_eq!(value, round_trip_value);
        Ok(())
    }

    #[test]
    fn test_all_emit_to_round_trip() -> Result<()> {
        test_emit_to_round_trip(EmitTo::All)?;
        test_emit_to_round_trip(EmitTo::First(10))?;

        Ok(())
    }
}