datafusion_ffi/udwf/
partition_evaluator.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::c_void;
19use std::ops::Range;
20
21use abi_stable::StableAbi;
22use abi_stable::std_types::{RResult, RVec};
23use arrow::array::ArrayRef;
24use arrow::error::ArrowError;
25use datafusion_common::scalar::ScalarValue;
26use datafusion_common::{DataFusionError, Result};
27use datafusion_expr::PartitionEvaluator;
28use datafusion_expr::window_state::WindowAggState;
29use prost::Message;
30
31use super::range::FFI_Range;
32use crate::arrow_wrappers::WrappedArray;
33use crate::util::FFIResult;
34use crate::{df_result, rresult, rresult_return};
35
/// A stable struct for sharing [`PartitionEvaluator`] across FFI boundaries.
/// For an explanation of each field, see the corresponding function
/// defined in [`PartitionEvaluator`].
///
/// The struct is `#[repr(C)]`, so the declaration order of the fields
/// determines its memory layout.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_PartitionEvaluator {
    /// Forwards to [`PartitionEvaluator::evaluate_all`]. The input arrays
    /// and the returned array cross the boundary as [`WrappedArray`].
    pub evaluate_all: unsafe extern "C" fn(
        evaluator: &mut Self,
        values: RVec<WrappedArray>,
        num_rows: usize,
    ) -> FFIResult<WrappedArray>,

    /// Forwards to [`PartitionEvaluator::evaluate`]. On success the returned
    /// bytes are the protobuf encoding of the resulting scalar value.
    pub evaluate: unsafe extern "C" fn(
        evaluator: &mut Self,
        values: RVec<WrappedArray>,
        range: FFI_Range,
    ) -> FFIResult<RVec<u8>>,

    /// Forwards to [`PartitionEvaluator::evaluate_all_with_rank`].
    pub evaluate_all_with_rank: unsafe extern "C" fn(
        evaluator: &Self,
        num_rows: usize,
        ranks_in_partition: RVec<FFI_Range>,
    ) -> FFIResult<WrappedArray>,

    /// Forwards to [`PartitionEvaluator::get_range`].
    pub get_range: unsafe extern "C" fn(
        evaluator: &Self,
        idx: usize,
        n_rows: usize,
    ) -> FFIResult<FFI_Range>,

    /// Result of [`PartitionEvaluator::is_causal`], read once when the
    /// struct is built from a local evaluator.
    pub is_causal: bool,

    /// Result of [`PartitionEvaluator::supports_bounded_execution`], read
    /// once when the struct is built from a local evaluator.
    pub supports_bounded_execution: bool,
    /// Result of [`PartitionEvaluator::uses_window_frame`], read once when
    /// the struct is built from a local evaluator.
    pub uses_window_frame: bool,
    /// Result of [`PartitionEvaluator::include_rank`], read once when the
    /// struct is built from a local evaluator.
    pub include_rank: bool,

    /// Release the memory of the private data when it is no longer being used.
    pub release: unsafe extern "C" fn(evaluator: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the evaluator.
    /// A [`ForeignPartitionEvaluator`] should never attempt to access this data.
    pub private_data: *mut c_void,

    /// Utility to identify when FFI objects are accessed locally through
    /// the foreign interface. See [`crate::get_library_marker_id`] and
    /// the crate's `README.md` for more information.
    pub library_marker_id: extern "C" fn() -> usize,
}
84
// The struct holds a raw pointer (`private_data`), so it is neither `Send`
// nor `Sync` unless asserted manually, which is what these impls do.
// NOTE(review): nothing in this file enforces thread-safety of the boxed
// evaluator behind the pointer — confirm that the bounds on
// `PartitionEvaluator` justify `Sync` as well as `Send`.
unsafe impl Send for FFI_PartitionEvaluator {}
unsafe impl Sync for FFI_PartitionEvaluator {}
87
/// Provider-side payload stored behind `FFI_PartitionEvaluator::private_data`.
///
/// It is heap allocated when an [`FFI_PartitionEvaluator`] is created from a
/// local evaluator and freed again by the struct's `release` callback.
pub struct PartitionEvaluatorPrivateData {
    /// The local evaluator that the exported function pointers delegate to.
    pub evaluator: Box<dyn PartitionEvaluator>,
}
91
92impl FFI_PartitionEvaluator {
93    unsafe fn inner_mut(&mut self) -> &mut Box<dyn PartitionEvaluator + 'static> {
94        unsafe {
95            let private_data = self.private_data as *mut PartitionEvaluatorPrivateData;
96            &mut (*private_data).evaluator
97        }
98    }
99
100    unsafe fn inner(&self) -> &(dyn PartitionEvaluator + 'static) {
101        unsafe {
102            let private_data = self.private_data as *mut PartitionEvaluatorPrivateData;
103            (*private_data).evaluator.as_ref()
104        }
105    }
106}
107
108unsafe extern "C" fn evaluate_all_fn_wrapper(
109    evaluator: &mut FFI_PartitionEvaluator,
110    values: RVec<WrappedArray>,
111    num_rows: usize,
112) -> FFIResult<WrappedArray> {
113    unsafe {
114        let inner = evaluator.inner_mut();
115
116        let values_arrays = values
117            .into_iter()
118            .map(|v| v.try_into().map_err(DataFusionError::from))
119            .collect::<Result<Vec<ArrayRef>>>();
120        let values_arrays = rresult_return!(values_arrays);
121
122        let return_array =
123            inner
124                .evaluate_all(&values_arrays, num_rows)
125                .and_then(|array| {
126                    WrappedArray::try_from(&array).map_err(DataFusionError::from)
127                });
128
129        rresult!(return_array)
130    }
131}
132
133unsafe extern "C" fn evaluate_fn_wrapper(
134    evaluator: &mut FFI_PartitionEvaluator,
135    values: RVec<WrappedArray>,
136    range: FFI_Range,
137) -> FFIResult<RVec<u8>> {
138    unsafe {
139        let inner = evaluator.inner_mut();
140
141        let values_arrays = values
142            .into_iter()
143            .map(|v| v.try_into().map_err(DataFusionError::from))
144            .collect::<Result<Vec<ArrayRef>>>();
145        let values_arrays = rresult_return!(values_arrays);
146
147        // let return_array = (inner.evaluate(&values_arrays, &range.into()));
148        // .and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from));
149        let scalar_result =
150            rresult_return!(inner.evaluate(&values_arrays, &range.into()));
151        let proto_result: datafusion_proto::protobuf::ScalarValue =
152            rresult_return!((&scalar_result).try_into());
153
154        RResult::ROk(proto_result.encode_to_vec().into())
155    }
156}
157
158unsafe extern "C" fn evaluate_all_with_rank_fn_wrapper(
159    evaluator: &FFI_PartitionEvaluator,
160    num_rows: usize,
161    ranks_in_partition: RVec<FFI_Range>,
162) -> FFIResult<WrappedArray> {
163    unsafe {
164        let inner = evaluator.inner();
165
166        let ranks_in_partition = ranks_in_partition
167            .into_iter()
168            .map(Range::from)
169            .collect::<Vec<_>>();
170
171        let return_array = inner
172            .evaluate_all_with_rank(num_rows, &ranks_in_partition)
173            .and_then(|array| {
174                WrappedArray::try_from(&array).map_err(DataFusionError::from)
175            });
176
177        rresult!(return_array)
178    }
179}
180
181unsafe extern "C" fn get_range_fn_wrapper(
182    evaluator: &FFI_PartitionEvaluator,
183    idx: usize,
184    n_rows: usize,
185) -> FFIResult<FFI_Range> {
186    unsafe {
187        let inner = evaluator.inner();
188        let range = inner.get_range(idx, n_rows).map(FFI_Range::from);
189
190        rresult!(range)
191    }
192}
193
194unsafe extern "C" fn release_fn_wrapper(evaluator: &mut FFI_PartitionEvaluator) {
195    unsafe {
196        if !evaluator.private_data.is_null() {
197            let private_data = Box::from_raw(
198                evaluator.private_data as *mut PartitionEvaluatorPrivateData,
199            );
200            drop(private_data);
201            evaluator.private_data = std::ptr::null_mut();
202        }
203    }
204}
205
206impl From<Box<dyn PartitionEvaluator>> for FFI_PartitionEvaluator {
207    fn from(evaluator: Box<dyn PartitionEvaluator>) -> Self {
208        let is_causal = evaluator.is_causal();
209        let supports_bounded_execution = evaluator.supports_bounded_execution();
210        let include_rank = evaluator.include_rank();
211        let uses_window_frame = evaluator.uses_window_frame();
212
213        let private_data = PartitionEvaluatorPrivateData { evaluator };
214
215        Self {
216            evaluate: evaluate_fn_wrapper,
217            evaluate_all: evaluate_all_fn_wrapper,
218            evaluate_all_with_rank: evaluate_all_with_rank_fn_wrapper,
219            get_range: get_range_fn_wrapper,
220            is_causal,
221            supports_bounded_execution,
222            include_rank,
223            uses_window_frame,
224            release: release_fn_wrapper,
225            private_data: Box::into_raw(Box::new(private_data)) as *mut c_void,
226            library_marker_id: crate::get_library_marker_id,
227        }
228    }
229}
230
231impl Drop for FFI_PartitionEvaluator {
232    fn drop(&mut self) {
233        unsafe { (self.release)(self) }
234    }
235}
236
/// This struct is used to access a partition evaluator provided by a
/// foreign library across a FFI boundary.
///
/// The ForeignPartitionEvaluator is to be used by the caller of the
/// evaluator, so it has no knowledge of or access to the private data. All
/// interaction with the evaluator must occur through the functions defined
/// in FFI_PartitionEvaluator.
#[derive(Debug)]
pub struct ForeignPartitionEvaluator {
    /// The FFI struct received from the providing library.
    evaluator: FFI_PartitionEvaluator,
}
247
248impl From<FFI_PartitionEvaluator> for Box<dyn PartitionEvaluator> {
249    fn from(mut evaluator: FFI_PartitionEvaluator) -> Self {
250        if (evaluator.library_marker_id)() == crate::get_library_marker_id() {
251            unsafe {
252                let private_data = Box::from_raw(
253                    evaluator.private_data as *mut PartitionEvaluatorPrivateData,
254                );
255                // We must set this to null to avoid a double free
256                evaluator.private_data = std::ptr::null_mut();
257                private_data.evaluator
258            }
259        } else {
260            Box::new(ForeignPartitionEvaluator { evaluator })
261        }
262    }
263}
264
265impl PartitionEvaluator for ForeignPartitionEvaluator {
266    fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> {
267        // Exposing `memoize` increases the surface are of the FFI work
268        // so for now we dot support it.
269        Ok(())
270    }
271
272    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
273        let range = unsafe { (self.evaluator.get_range)(&self.evaluator, idx, n_rows) };
274        df_result!(range).map(Range::from)
275    }
276
277    /// Get whether evaluator needs future data for its result (if so returns `false`) or not
278    fn is_causal(&self) -> bool {
279        self.evaluator.is_causal
280    }
281
282    fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
283        let result = unsafe {
284            let values = values
285                .iter()
286                .map(WrappedArray::try_from)
287                .collect::<std::result::Result<RVec<_>, ArrowError>>()?;
288            (self.evaluator.evaluate_all)(&mut self.evaluator, values, num_rows)
289        };
290
291        let array = df_result!(result)?;
292
293        Ok(array.try_into()?)
294    }
295
296    fn evaluate(
297        &mut self,
298        values: &[ArrayRef],
299        range: &Range<usize>,
300    ) -> Result<ScalarValue> {
301        unsafe {
302            let values = values
303                .iter()
304                .map(WrappedArray::try_from)
305                .collect::<std::result::Result<RVec<_>, ArrowError>>()?;
306
307            let scalar_bytes = df_result!((self.evaluator.evaluate)(
308                &mut self.evaluator,
309                values,
310                range.to_owned().into()
311            ))?;
312
313            let proto_scalar =
314                datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref())
315                    .map_err(|e| DataFusionError::External(Box::new(e)))?;
316
317            ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from)
318        }
319    }
320
321    fn evaluate_all_with_rank(
322        &self,
323        num_rows: usize,
324        ranks_in_partition: &[Range<usize>],
325    ) -> Result<ArrayRef> {
326        let result = unsafe {
327            let ranks_in_partition = ranks_in_partition
328                .iter()
329                .map(|rank| FFI_Range::from(rank.to_owned()))
330                .collect();
331            (self.evaluator.evaluate_all_with_rank)(
332                &self.evaluator,
333                num_rows,
334                ranks_in_partition,
335            )
336        };
337
338        let array = df_result!(result)?;
339
340        Ok(array.try_into()?)
341    }
342
343    fn supports_bounded_execution(&self) -> bool {
344        self.evaluator.supports_bounded_execution
345    }
346
347    fn uses_window_frame(&self) -> bool {
348        self.evaluator.uses_window_frame
349    }
350
351    fn include_rank(&self) -> bool {
352        self.evaluator.include_rank
353    }
354}
355
#[cfg(test)]
mod tests {
    use arrow::array::ArrayRef;
    use datafusion::logical_expr::PartitionEvaluator;

    use crate::udwf::partition_evaluator::{
        FFI_PartitionEvaluator, ForeignPartitionEvaluator,
    };

    /// Minimal evaluator that hands back its first input column.
    #[derive(Debug)]
    struct TestPartitionEvaluator {}

    impl PartitionEvaluator for TestPartitionEvaluator {
        fn evaluate_all(
            &mut self,
            values: &[ArrayRef],
            _num_rows: usize,
        ) -> datafusion_common::Result<ArrayRef> {
            Ok(values[0].to_owned())
        }
    }

    /// Round-trips an evaluator through the FFI struct, once with the local
    /// library marker and once with a mocked foreign one.
    ///
    /// NOTE(review): the raw pointer casts below perform no runtime type
    /// check, so they only reinterpret the object as the expected type.
    #[test]
    fn test_ffi_partition_evaluator_local_bypass_inner() -> datafusion_common::Result<()>
    {
        // Same library marker: the conversion is expected to hand back the
        // original evaluator.
        let local: Box<dyn PartitionEvaluator> = Box::new(TestPartitionEvaluator {});
        let ffi_local: FFI_PartitionEvaluator = local.into();
        let round_tripped: Box<dyn PartitionEvaluator> = ffi_local.into();
        unsafe {
            let as_original = &*(round_tripped.as_ref() as *const dyn PartitionEvaluator
                as *const TestPartitionEvaluator);
            assert!(!as_original.uses_window_frame());
        }

        // Different library marker: the conversion is expected to produce a
        // foreign partition evaluator.
        let remote: Box<dyn PartitionEvaluator> = Box::new(TestPartitionEvaluator {});
        let mut ffi_remote: FFI_PartitionEvaluator = remote.into();
        ffi_remote.library_marker_id = crate::mock_foreign_marker_id;
        let wrapped: Box<dyn PartitionEvaluator> = ffi_remote.into();
        unsafe {
            let as_foreign = &*(wrapped.as_ref() as *const dyn PartitionEvaluator
                as *const ForeignPartitionEvaluator);
            assert!(!as_foreign.uses_window_frame());
        }

        Ok(())
    }
}