1use std::ffi::c_void;
19use std::ops::Deref;
20use std::ptr::null_mut;
21use std::sync::Arc;
22
23use abi_stable::StableAbi;
24use abi_stable::std_types::{ROption, RVec};
25use arrow::array::{Array, ArrayRef, BooleanArray};
26use arrow::error::ArrowError;
27use arrow::ffi::to_ffi;
28use datafusion_common::error::{DataFusionError, Result};
29use datafusion_expr::{EmitTo, GroupsAccumulator};
30
31use crate::arrow_wrappers::{WrappedArray, WrappedSchema};
32use crate::util::FFIResult;
33use crate::{df_result, rresult, rresult_return};
34
/// A stable, `#[repr(C)]` table of function pointers for sharing a
/// [`GroupsAccumulator`] across an FFI boundary.
///
/// Each function pointer corresponds to the [`GroupsAccumulator`] method of
/// the same name. Arrays cross the boundary as [`WrappedArray`] values and
/// fallible calls report their outcome through [`FFIResult`].
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_GroupsAccumulator {
    /// FFI entry point for [`GroupsAccumulator::update_batch`].
    pub update_batch: unsafe extern "C" fn(
        accumulator: &mut Self,
        values: RVec<WrappedArray>,
        group_indices: RVec<usize>,
        opt_filter: ROption<WrappedArray>,
        total_num_groups: usize,
    ) -> FFIResult<()>,

    /// FFI entry point for [`GroupsAccumulator::evaluate`]; the resulting
    /// array is returned as a [`WrappedArray`].
    pub evaluate: unsafe extern "C" fn(
        accumulator: &mut Self,
        emit_to: FFI_EmitTo,
    ) -> FFIResult<WrappedArray>,

    /// FFI entry point for [`GroupsAccumulator::size`].
    pub size: unsafe extern "C" fn(accumulator: &Self) -> usize,

    /// FFI entry point for [`GroupsAccumulator::state`]; each state array is
    /// returned as a [`WrappedArray`].
    pub state: unsafe extern "C" fn(
        accumulator: &mut Self,
        emit_to: FFI_EmitTo,
    ) -> FFIResult<RVec<WrappedArray>>,

    /// FFI entry point for [`GroupsAccumulator::merge_batch`].
    pub merge_batch: unsafe extern "C" fn(
        accumulator: &mut Self,
        values: RVec<WrappedArray>,
        group_indices: RVec<usize>,
        opt_filter: ROption<WrappedArray>,
        total_num_groups: usize,
    ) -> FFIResult<()>,

    /// FFI entry point for [`GroupsAccumulator::convert_to_state`].
    pub convert_to_state: unsafe extern "C" fn(
        accumulator: &Self,
        values: RVec<WrappedArray>,
        opt_filter: ROption<WrappedArray>,
    ) -> FFIResult<RVec<WrappedArray>>,

    /// Value of [`GroupsAccumulator::supports_convert_to_state`], captured
    /// once when the FFI struct is built from a boxed accumulator.
    pub supports_convert_to_state: bool,

    /// Frees `private_data`. Invoked from this struct's `Drop` implementation.
    pub release: unsafe extern "C" fn(accumulator: &mut Self),

    /// Opaque pointer to the provider-side state. The wrapper functions in
    /// this file interpret it as a `GroupsAccumulatorPrivateData`.
    pub private_data: *mut c_void,

    /// Returns an identifier for the library that produced this struct. It is
    /// compared with the local library's identifier when converting back to a
    /// `Box<dyn GroupsAccumulator>` to decide whether the original accumulator
    /// can be recovered directly instead of going through the FFI wrappers.
    pub library_marker_id: extern "C" fn() -> usize,
}
90
/// Provider-side state stored behind `FFI_GroupsAccumulator::private_data`.
pub struct GroupsAccumulatorPrivateData {
    /// The wrapped accumulator that the FFI wrapper functions delegate to.
    pub accumulator: Box<dyn GroupsAccumulator>,
}
94
95impl FFI_GroupsAccumulator {
96 #[inline]
97 unsafe fn inner_mut(&mut self) -> &mut Box<dyn GroupsAccumulator> {
98 unsafe {
99 let private_data = self.private_data as *mut GroupsAccumulatorPrivateData;
100 &mut (*private_data).accumulator
101 }
102 }
103
104 #[inline]
105 unsafe fn inner(&self) -> &dyn GroupsAccumulator {
106 unsafe {
107 let private_data = self.private_data as *const GroupsAccumulatorPrivateData;
108 (*private_data).accumulator.deref()
109 }
110 }
111}
112
113fn process_values(values: RVec<WrappedArray>) -> Result<Vec<Arc<dyn Array>>> {
114 values
115 .into_iter()
116 .map(|v| v.try_into().map_err(DataFusionError::from))
117 .collect::<Result<Vec<ArrayRef>>>()
118}
119
120fn process_opt_filter(opt_filter: ROption<WrappedArray>) -> Result<Option<BooleanArray>> {
122 opt_filter
123 .into_option()
124 .map(|filter| {
125 ArrayRef::try_from(filter)
126 .map_err(DataFusionError::from)
127 .map(|arr| BooleanArray::from(arr.into_data()))
128 })
129 .transpose()
130}
131
132unsafe extern "C" fn update_batch_fn_wrapper(
133 accumulator: &mut FFI_GroupsAccumulator,
134 values: RVec<WrappedArray>,
135 group_indices: RVec<usize>,
136 opt_filter: ROption<WrappedArray>,
137 total_num_groups: usize,
138) -> FFIResult<()> {
139 unsafe {
140 let accumulator = accumulator.inner_mut();
141 let values = rresult_return!(process_values(values));
142 let group_indices: Vec<usize> = group_indices.into_iter().collect();
143 let opt_filter = rresult_return!(process_opt_filter(opt_filter));
144
145 rresult!(accumulator.update_batch(
146 &values,
147 &group_indices,
148 opt_filter.as_ref(),
149 total_num_groups
150 ))
151 }
152}
153
154unsafe extern "C" fn evaluate_fn_wrapper(
155 accumulator: &mut FFI_GroupsAccumulator,
156 emit_to: FFI_EmitTo,
157) -> FFIResult<WrappedArray> {
158 unsafe {
159 let accumulator = accumulator.inner_mut();
160
161 let result = rresult_return!(accumulator.evaluate(emit_to.into()));
162
163 rresult!(WrappedArray::try_from(&result))
164 }
165}
166
167unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize {
168 unsafe {
169 let accumulator = accumulator.inner();
170 accumulator.size()
171 }
172}
173
174unsafe extern "C" fn state_fn_wrapper(
175 accumulator: &mut FFI_GroupsAccumulator,
176 emit_to: FFI_EmitTo,
177) -> FFIResult<RVec<WrappedArray>> {
178 unsafe {
179 let accumulator = accumulator.inner_mut();
180
181 let state = rresult_return!(accumulator.state(emit_to.into()));
182 rresult!(
183 state
184 .into_iter()
185 .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from))
186 .collect::<Result<RVec<_>>>()
187 )
188 }
189}
190
191unsafe extern "C" fn merge_batch_fn_wrapper(
192 accumulator: &mut FFI_GroupsAccumulator,
193 values: RVec<WrappedArray>,
194 group_indices: RVec<usize>,
195 opt_filter: ROption<WrappedArray>,
196 total_num_groups: usize,
197) -> FFIResult<()> {
198 unsafe {
199 let accumulator = accumulator.inner_mut();
200 let values = rresult_return!(process_values(values));
201 let group_indices: Vec<usize> = group_indices.into_iter().collect();
202 let opt_filter = rresult_return!(process_opt_filter(opt_filter));
203
204 rresult!(accumulator.merge_batch(
205 &values,
206 &group_indices,
207 opt_filter.as_ref(),
208 total_num_groups
209 ))
210 }
211}
212
213unsafe extern "C" fn convert_to_state_fn_wrapper(
214 accumulator: &FFI_GroupsAccumulator,
215 values: RVec<WrappedArray>,
216 opt_filter: ROption<WrappedArray>,
217) -> FFIResult<RVec<WrappedArray>> {
218 unsafe {
219 let accumulator = accumulator.inner();
220 let values = rresult_return!(process_values(values));
221 let opt_filter = rresult_return!(process_opt_filter(opt_filter));
222 let state =
223 rresult_return!(accumulator.convert_to_state(&values, opt_filter.as_ref()));
224
225 rresult!(
226 state
227 .iter()
228 .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from))
229 .collect::<Result<RVec<_>>>()
230 )
231 }
232}
233
234unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) {
235 unsafe {
236 if !accumulator.private_data.is_null() {
237 let private_data = Box::from_raw(
238 accumulator.private_data as *mut GroupsAccumulatorPrivateData,
239 );
240 drop(private_data);
241 accumulator.private_data = null_mut();
242 }
243 }
244}
245
246impl From<Box<dyn GroupsAccumulator>> for FFI_GroupsAccumulator {
247 fn from(accumulator: Box<dyn GroupsAccumulator>) -> Self {
248 let supports_convert_to_state = accumulator.supports_convert_to_state();
249 let private_data = GroupsAccumulatorPrivateData { accumulator };
250
251 Self {
252 update_batch: update_batch_fn_wrapper,
253 evaluate: evaluate_fn_wrapper,
254 size: size_fn_wrapper,
255 state: state_fn_wrapper,
256 merge_batch: merge_batch_fn_wrapper,
257 convert_to_state: convert_to_state_fn_wrapper,
258 supports_convert_to_state,
259
260 release: release_fn_wrapper,
261 private_data: Box::into_raw(Box::new(private_data)) as *mut c_void,
262 library_marker_id: crate::get_library_marker_id,
263 }
264 }
265}
266
267impl Drop for FFI_GroupsAccumulator {
268 fn drop(&mut self) {
269 unsafe { (self.release)(self) }
270 }
271}
272
/// Consumer-side handle to a [`GroupsAccumulator`] provided through an
/// [`FFI_GroupsAccumulator`].
///
/// Its [`GroupsAccumulator`] implementation never touches `private_data`;
/// every call goes through the function pointers stored in the wrapped struct.
#[derive(Debug)]
pub struct ForeignGroupsAccumulator {
    /// The FFI struct whose function pointers are invoked.
    accumulator: FFI_GroupsAccumulator,
}

// NOTE(review): `FFI_GroupsAccumulator` holds a raw `private_data` pointer, so
// these impls assert thread-safety by hand. Presumably this is sound because
// the wrapped accumulator is itself `Send`/`Sync` — confirm against the
// `GroupsAccumulator` trait bounds.
unsafe impl Send for ForeignGroupsAccumulator {}
unsafe impl Sync for ForeignGroupsAccumulator {}
286
287impl From<FFI_GroupsAccumulator> for Box<dyn GroupsAccumulator> {
288 fn from(mut accumulator: FFI_GroupsAccumulator) -> Self {
289 if (accumulator.library_marker_id)() == crate::get_library_marker_id() {
290 unsafe {
291 let private_data = Box::from_raw(
292 accumulator.private_data as *mut GroupsAccumulatorPrivateData,
293 );
294 accumulator.private_data = null_mut();
296 private_data.accumulator
297 }
298 } else {
299 Box::new(ForeignGroupsAccumulator { accumulator })
300 }
301 }
302}
303
304impl GroupsAccumulator for ForeignGroupsAccumulator {
305 fn update_batch(
306 &mut self,
307 values: &[ArrayRef],
308 group_indices: &[usize],
309 opt_filter: Option<&BooleanArray>,
310 total_num_groups: usize,
311 ) -> Result<()> {
312 unsafe {
313 let values = values
314 .iter()
315 .map(WrappedArray::try_from)
316 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
317 let group_indices = group_indices.iter().cloned().collect();
318 let opt_filter = opt_filter
319 .map(|bool_array| to_ffi(&bool_array.to_data()))
320 .transpose()?
321 .map(|(array, schema)| WrappedArray {
322 array,
323 schema: WrappedSchema(schema),
324 })
325 .into();
326
327 df_result!((self.accumulator.update_batch)(
328 &mut self.accumulator,
329 values.into(),
330 group_indices,
331 opt_filter,
332 total_num_groups
333 ))
334 }
335 }
336
337 fn size(&self) -> usize {
338 unsafe { (self.accumulator.size)(&self.accumulator) }
339 }
340
341 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
342 unsafe {
343 let return_array = df_result!((self.accumulator.evaluate)(
344 &mut self.accumulator,
345 emit_to.into()
346 ))?;
347
348 return_array.try_into().map_err(DataFusionError::from)
349 }
350 }
351
352 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
353 unsafe {
354 let returned_arrays = df_result!((self.accumulator.state)(
355 &mut self.accumulator,
356 emit_to.into()
357 ))?;
358
359 returned_arrays
360 .into_iter()
361 .map(|wrapped_array| {
362 wrapped_array.try_into().map_err(DataFusionError::from)
363 })
364 .collect::<Result<Vec<_>>>()
365 }
366 }
367
368 fn merge_batch(
369 &mut self,
370 values: &[ArrayRef],
371 group_indices: &[usize],
372 opt_filter: Option<&BooleanArray>,
373 total_num_groups: usize,
374 ) -> Result<()> {
375 unsafe {
376 let values = values
377 .iter()
378 .map(WrappedArray::try_from)
379 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
380 let group_indices = group_indices.iter().cloned().collect();
381 let opt_filter = opt_filter
382 .map(|bool_array| to_ffi(&bool_array.to_data()))
383 .transpose()?
384 .map(|(array, schema)| WrappedArray {
385 array,
386 schema: WrappedSchema(schema),
387 })
388 .into();
389
390 df_result!((self.accumulator.merge_batch)(
391 &mut self.accumulator,
392 values.into(),
393 group_indices,
394 opt_filter,
395 total_num_groups
396 ))
397 }
398 }
399
400 fn convert_to_state(
401 &self,
402 values: &[ArrayRef],
403 opt_filter: Option<&BooleanArray>,
404 ) -> Result<Vec<ArrayRef>> {
405 unsafe {
406 let values = values
407 .iter()
408 .map(WrappedArray::try_from)
409 .collect::<std::result::Result<RVec<_>, ArrowError>>()?;
410
411 let opt_filter = opt_filter
412 .map(|bool_array| to_ffi(&bool_array.to_data()))
413 .transpose()?
414 .map(|(array, schema)| WrappedArray {
415 array,
416 schema: WrappedSchema(schema),
417 })
418 .into();
419
420 let returned_array = df_result!((self.accumulator.convert_to_state)(
421 &self.accumulator,
422 values,
423 opt_filter
424 ))?;
425
426 returned_array
427 .into_iter()
428 .map(|arr| arr.try_into().map_err(DataFusionError::from))
429 .collect()
430 }
431 }
432
433 fn supports_convert_to_state(&self) -> bool {
434 self.accumulator.supports_convert_to_state
435 }
436}
437
/// `#[repr(C)]` counterpart of [`EmitTo`] used in the FFI function signatures.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub enum FFI_EmitTo {
    /// Counterpart of [`EmitTo::All`].
    All,
    /// Counterpart of [`EmitTo::First`], carrying the same `usize` payload.
    First(usize),
}
444
445impl From<EmitTo> for FFI_EmitTo {
446 fn from(value: EmitTo) -> Self {
447 match value {
448 EmitTo::All => Self::All,
449 EmitTo::First(v) => Self::First(v),
450 }
451 }
452}
453
454impl From<FFI_EmitTo> for EmitTo {
455 fn from(value: FFI_EmitTo) -> Self {
456 match value {
457 FFI_EmitTo::All => Self::All,
458 FFI_EmitTo::First(v) => Self::First(v),
459 }
460 }
461}
462
#[cfg(test)]
mod tests {
    use arrow::array::{Array, BooleanArray, make_array};
    use datafusion::common::create_array;
    use datafusion::error::Result;
    use datafusion::functions_aggregate::stddev::StddevGroupsAccumulator;
    use datafusion::logical_expr::{EmitTo, GroupsAccumulator};
    use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator;
    use datafusion_functions_aggregate_common::stats::StatsType;

    use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator};

    // Exercises every forwarded method of the foreign wrapper. Despite the
    // name, the accumulator under test is a boolean AND accumulator.
    #[test]
    fn test_foreign_avg_accumulator() -> Result<()> {
        let boxed_accum: Box<dyn GroupsAccumulator> =
            Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true));
        let mut ffi_accum: FFI_GroupsAccumulator = boxed_accum.into();
        // Swap in a different marker so the conversion below takes the
        // foreign path instead of handing back the original accumulator.
        ffi_accum.library_marker_id = crate::mock_foreign_marker_id;
        let mut foreign_accum: Box<dyn GroupsAccumulator> = ffi_accum.into();

        let values = create_array!(Boolean, vec![true, true, true, false, true, true]);
        // The filter excludes both rows of group 2.
        let opt_filter =
            create_array!(Boolean, vec![true, true, true, true, false, false]);
        foreign_accum.update_batch(
            &[values],
            &[0, 0, 1, 1, 2, 2],
            Some(opt_filter.as_ref()),
            3,
        )?;

        let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
        let groups_bool = groups_bool.as_any().downcast_ref::<BooleanArray>().unwrap();

        // Group 0 saw only `true`, group 1 saw a `false`, group 2 saw nothing.
        assert_eq!(
            groups_bool,
            create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref()
        );

        let state = foreign_accum.state(EmitTo::All)?;
        assert_eq!(state.len(), 1);

        // Merge a single-group state into the accumulator.
        let second_states =
            vec![make_array(create_array!(Boolean, vec![false]).to_data())];

        let opt_filter = create_array!(Boolean, vec![true]);
        foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?;
        let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
        assert_eq!(groups_bool.len(), 1);
        assert_eq!(
            groups_bool.as_ref(),
            make_array(create_array!(Boolean, vec![false]).to_data()).as_ref()
        );

        let values = create_array!(Boolean, vec![false]);
        let opt_filter = create_array!(Boolean, vec![true]);
        let groups_bool =
            foreign_accum.convert_to_state(&[values], Some(opt_filter.as_ref()))?;

        assert_eq!(
            groups_bool[0].as_ref(),
            make_array(create_array!(Boolean, vec![false]).to_data()).as_ref()
        );

        Ok(())
    }

    // Helper: converts `value` to `FFI_EmitTo` and back and checks that the
    // result equals the input.
    fn test_emit_to_round_trip(value: EmitTo) -> Result<()> {
        let ffi_value: FFI_EmitTo = value.into();
        let round_trip_value: EmitTo = ffi_value.into();

        assert_eq!(value, round_trip_value);
        Ok(())
    }

    // Round-trips both `EmitTo` variants used by the FFI layer.
    #[test]
    fn test_all_emit_to_round_trip() -> Result<()> {
        test_emit_to_round_trip(EmitTo::All)?;
        test_emit_to_round_trip(EmitTo::First(10))?;

        Ok(())
    }

    // Checks which concrete type comes back from the FFI round trip: the
    // original accumulator when the library marker matches, and a
    // `ForeignGroupsAccumulator` when it does not.
    #[test]
    fn test_ffi_groups_accumulator_local_bypass_inner() -> Result<()> {
        let original_accum = StddevGroupsAccumulator::new(StatsType::Population);
        let boxed_accum: Box<dyn GroupsAccumulator> = Box::new(original_accum);
        let original_size = boxed_accum.size();

        let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into();

        // Marker untouched: the conversion is expected to return the original
        // accumulator, so the pointer is reinterpreted as the concrete type.
        let foreign_accum: Box<dyn GroupsAccumulator> = ffi_accum.into();
        unsafe {
            let concrete = &*(foreign_accum.as_ref() as *const dyn GroupsAccumulator
                as *const StddevGroupsAccumulator);
            assert_eq!(original_size, concrete.size());
        }

        // Marker replaced: the conversion is expected to return the foreign
        // wrapper, whose `size` goes through the FFI function pointer.
        let original_accum = StddevGroupsAccumulator::new(StatsType::Population);
        let boxed_accum: Box<dyn GroupsAccumulator> = Box::new(original_accum);
        let mut ffi_accum: FFI_GroupsAccumulator = boxed_accum.into();
        ffi_accum.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_accum: Box<dyn GroupsAccumulator> = ffi_accum.into();
        unsafe {
            let concrete = &*(foreign_accum.as_ref() as *const dyn GroupsAccumulator
                as *const ForeignGroupsAccumulator);
            assert_eq!(original_size, concrete.size());
        }

        Ok(())
    }
}