datafusion_ffi/udwf/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::c_void;
19use std::hash::{Hash, Hasher};
20use std::sync::Arc;
21
22use abi_stable::StableAbi;
23use abi_stable::std_types::{ROption, RResult, RString, RVec};
24use arrow::compute::SortOptions;
25use arrow::datatypes::{DataType, Schema, SchemaRef};
26use arrow_schema::{Field, FieldRef};
27use datafusion_common::{Result, ffi_err};
28use datafusion_expr::function::WindowUDFFieldArgs;
29use datafusion_expr::type_coercion::functions::fields_with_udf;
30use datafusion_expr::{
31    LimitEffect, PartitionEvaluator, Signature, WindowUDF, WindowUDFImpl,
32};
33use datafusion_physical_expr::PhysicalExpr;
34use partition_evaluator::FFI_PartitionEvaluator;
35use partition_evaluator_args::{
36    FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs,
37};
38
39mod partition_evaluator;
40mod partition_evaluator_args;
41mod range;
42
43use crate::arrow_wrappers::WrappedSchema;
44use crate::util::{
45    FFIResult, rvec_wrapped_to_vec_datatype, rvec_wrapped_to_vec_fieldref,
46    vec_datatype_to_rvec_wrapped, vec_fieldref_to_rvec_wrapped,
47};
48use crate::volatility::FFI_Volatility;
49use crate::{df_result, rresult, rresult_return};
50
/// A stable struct for sharing a [`WindowUDF`] across FFI boundaries.
///
/// The struct is `#[repr(C)]`: the order and types of its fields are part of
/// the binary interface shared between libraries. Plain data (`name`,
/// `aliases`, `volatility`, `sort_options`) is stored directly, while
/// behaviour is exposed through `extern "C"` function pointers that receive
/// the struct itself as their first argument.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_WindowUDF {
    /// FFI equivalent to the `name` of a [`WindowUDF`]
    pub name: RString,

    /// FFI equivalent to the `aliases` of a [`WindowUDF`]
    pub aliases: RVec<RString>,

    /// FFI equivalent to the `volatility` of a [`WindowUDF`]
    pub volatility: FFI_Volatility,

    /// Builds the partition evaluator for the given arguments and returns
    /// it in its FFI-safe form, or an error if it could not be created.
    pub partition_evaluator: unsafe extern "C" fn(
        udwf: &Self,
        args: FFI_PartitionEvaluatorArgs,
    )
        -> FFIResult<FFI_PartitionEvaluator>,

    /// Computes the output field of the function for the given input fields
    /// and display name. Each input field is passed as a [`WrappedSchema`];
    /// the result is a [`WrappedSchema`] whose first field is the output
    /// field.
    pub field: unsafe extern "C" fn(
        udwf: &Self,
        input_types: RVec<WrappedSchema>,
        display_name: RString,
    ) -> FFIResult<WrappedSchema>,

    /// Performs type coercion. To simplify this interface, all UDFs are treated as
    /// having user defined signatures, which will in turn cause coerce_types to be
    /// called. This call should be transparent to most users as the internal function
    /// performs the appropriate calls on the underlying [`WindowUDF`]
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> FFIResult<RVec<WrappedSchema>>,

    /// FFI equivalent to the `sort_options` of a [`WindowUDF`], captured when
    /// this struct is created from the original function.
    pub sort_options: ROption<FFI_SortOptions>,

    /// Used to create a clone on the provider of the udf. This should
    /// only need to be called by the receiver of the udf.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Release the memory of the private data when it is no longer being used.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the udf.
    /// A [`ForeignWindowUDF`] should never attempt to access this data.
    pub private_data: *mut c_void,

    /// Utility to identify when FFI objects are accessed locally through
    /// the foreign interface. See [`crate::get_library_marker_id`] and
    /// the crate's `README.md` for more information.
    pub library_marker_id: extern "C" fn() -> usize,
}
103
// The raw `private_data` pointer prevents the compiler from deriving
// `Send`/`Sync`, so both are asserted manually here.
// NOTE(review): this presumably relies on the private data only holding an
// `Arc<WindowUDF>` that is itself safe to share across threads - confirm.
unsafe impl Send for FFI_WindowUDF {}
unsafe impl Sync for FFI_WindowUDF {}
106
/// Provider-side state stored behind [`FFI_WindowUDF::private_data`].
///
/// It is boxed and turned into a raw pointer when an [`FFI_WindowUDF`] is
/// created or cloned, and turned back into a `Box` and dropped by the
/// `release` callback.
pub struct WindowUDFPrivateData {
    /// The wrapped window function that the FFI callbacks delegate to.
    pub udf: Arc<WindowUDF>,
}
110
111impl FFI_WindowUDF {
112    unsafe fn inner(&self) -> &Arc<WindowUDF> {
113        unsafe {
114            let private_data = self.private_data as *const WindowUDFPrivateData;
115            &(*private_data).udf
116        }
117    }
118}
119
120unsafe extern "C" fn partition_evaluator_fn_wrapper(
121    udwf: &FFI_WindowUDF,
122    args: FFI_PartitionEvaluatorArgs,
123) -> FFIResult<FFI_PartitionEvaluator> {
124    unsafe {
125        let inner = udwf.inner();
126
127        let args = rresult_return!(ForeignPartitionEvaluatorArgs::try_from(args));
128
129        let evaluator =
130            rresult_return!(inner.partition_evaluator_factory((&args).into()));
131
132        RResult::ROk(evaluator.into())
133    }
134}
135
136unsafe extern "C" fn field_fn_wrapper(
137    udwf: &FFI_WindowUDF,
138    input_fields: RVec<WrappedSchema>,
139    display_name: RString,
140) -> FFIResult<WrappedSchema> {
141    unsafe {
142        let inner = udwf.inner();
143
144        let input_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
145
146        let field = rresult_return!(inner.field(WindowUDFFieldArgs::new(
147            &input_fields,
148            display_name.as_str()
149        )));
150
151        let schema = Arc::new(Schema::new(vec![field]));
152
153        RResult::ROk(WrappedSchema::from(schema))
154    }
155}
156
157unsafe extern "C" fn coerce_types_fn_wrapper(
158    udwf: &FFI_WindowUDF,
159    arg_types: RVec<WrappedSchema>,
160) -> FFIResult<RVec<WrappedSchema>> {
161    unsafe {
162        let inner = udwf.inner();
163
164        let arg_fields = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types))
165            .into_iter()
166            .map(|dt| Field::new("f", dt, false))
167            .map(Arc::new)
168            .collect::<Vec<_>>();
169
170        let return_fields = rresult_return!(fields_with_udf(&arg_fields, inner.as_ref()));
171        let return_types = return_fields
172            .into_iter()
173            .map(|f| f.data_type().to_owned())
174            .collect::<Vec<_>>();
175
176        rresult!(vec_datatype_to_rvec_wrapped(&return_types))
177    }
178}
179
180unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) {
181    unsafe {
182        debug_assert!(!udwf.private_data.is_null());
183        let private_data = Box::from_raw(udwf.private_data as *mut WindowUDFPrivateData);
184        drop(private_data);
185        udwf.private_data = std::ptr::null_mut();
186    }
187}
188
189unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
190    unsafe {
191        // let private_data = udf.private_data as *const WindowUDFPrivateData;
192        // let udf_data = &(*private_data);
193
194        // let private_data = Box::new(WindowUDFPrivateData {
195        //     udf: Arc::clone(&udf_data.udf),
196        // });
197        let private_data = Box::new(WindowUDFPrivateData {
198            udf: Arc::clone(udwf.inner()),
199        });
200
201        FFI_WindowUDF {
202            name: udwf.name.clone(),
203            aliases: udwf.aliases.clone(),
204            volatility: udwf.volatility.clone(),
205            partition_evaluator: partition_evaluator_fn_wrapper,
206            sort_options: udwf.sort_options.clone(),
207            coerce_types: coerce_types_fn_wrapper,
208            field: field_fn_wrapper,
209            clone: clone_fn_wrapper,
210            release: release_fn_wrapper,
211            private_data: Box::into_raw(private_data) as *mut c_void,
212            library_marker_id: crate::get_library_marker_id,
213        }
214    }
215}
216
217impl Clone for FFI_WindowUDF {
218    fn clone(&self) -> Self {
219        unsafe { (self.clone)(self) }
220    }
221}
222
223impl From<Arc<WindowUDF>> for FFI_WindowUDF {
224    fn from(udf: Arc<WindowUDF>) -> Self {
225        let name = udf.name().into();
226        let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
227        let volatility = udf.signature().volatility.into();
228        let sort_options = udf.sort_options().map(|v| (&v).into()).into();
229
230        let private_data = Box::new(WindowUDFPrivateData { udf });
231
232        Self {
233            name,
234            aliases,
235            volatility,
236            partition_evaluator: partition_evaluator_fn_wrapper,
237            sort_options,
238            coerce_types: coerce_types_fn_wrapper,
239            field: field_fn_wrapper,
240            clone: clone_fn_wrapper,
241            release: release_fn_wrapper,
242            private_data: Box::into_raw(private_data) as *mut c_void,
243            library_marker_id: crate::get_library_marker_id,
244        }
245    }
246}
247
248impl Drop for FFI_WindowUDF {
249    fn drop(&mut self) {
250        unsafe { (self.release)(self) }
251    }
252}
253
/// This struct is used to access a UDF provided by a foreign
/// library across an FFI boundary.
///
/// The ForeignWindowUDF is to be used by the caller of the UDF, so it has
/// no knowledge or access to the private data. All interaction with the UDF
/// must occur through the functions defined in FFI_WindowUDF.
#[derive(Debug)]
pub struct ForeignWindowUDF {
    /// Owned copy of the name reported by the foreign function.
    name: String,
    /// Owned copies of the aliases reported by the foreign function.
    aliases: Vec<String>,
    /// The FFI struct through which every call is forwarded.
    udf: FFI_WindowUDF,
    /// User-defined signature built from the foreign function's volatility.
    signature: Signature,
}
267
// Asserted manually rather than derived.
// NOTE(review): presumably sound for the same reason as the manual
// `Send`/`Sync` impls on the contained `FFI_WindowUDF` - confirm.
unsafe impl Send for ForeignWindowUDF {}
unsafe impl Sync for ForeignWindowUDF {}
270
271impl PartialEq for ForeignWindowUDF {
272    fn eq(&self, other: &Self) -> bool {
273        // FFI_WindowUDF cannot be compared, so identity equality is the best we can do.
274        std::ptr::eq(self, other)
275    }
276}
277impl Eq for ForeignWindowUDF {}
278impl Hash for ForeignWindowUDF {
279    fn hash<H: Hasher>(&self, state: &mut H) {
280        std::ptr::hash(self, state)
281    }
282}
283
284impl From<&FFI_WindowUDF> for Arc<dyn WindowUDFImpl> {
285    fn from(udf: &FFI_WindowUDF) -> Self {
286        if (udf.library_marker_id)() == crate::get_library_marker_id() {
287            Arc::clone(unsafe { udf.inner().inner() })
288        } else {
289            let name = udf.name.to_owned().into();
290            let signature = Signature::user_defined((&udf.volatility).into());
291
292            let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
293
294            Arc::new(ForeignWindowUDF {
295                name,
296                udf: udf.clone(),
297                aliases,
298                signature,
299            })
300        }
301    }
302}
303
304impl WindowUDFImpl for ForeignWindowUDF {
305    fn as_any(&self) -> &dyn std::any::Any {
306        self
307    }
308
309    fn name(&self) -> &str {
310        &self.name
311    }
312
313    fn signature(&self) -> &Signature {
314        &self.signature
315    }
316
317    fn aliases(&self) -> &[String] {
318        &self.aliases
319    }
320
321    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
322        unsafe {
323            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
324            let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
325            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
326        }
327    }
328
329    fn partition_evaluator(
330        &self,
331        args: datafusion_expr::function::PartitionEvaluatorArgs,
332    ) -> Result<Box<dyn PartitionEvaluator>> {
333        let evaluator = unsafe {
334            let args = FFI_PartitionEvaluatorArgs::try_from(args)?;
335            (self.udf.partition_evaluator)(&self.udf, args)
336        };
337
338        df_result!(evaluator).map(<Box<dyn PartitionEvaluator>>::from)
339    }
340
341    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
342        unsafe {
343            let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?;
344            let schema = df_result!((self.udf.field)(
345                &self.udf,
346                input_types,
347                field_args.name().into()
348            ))?;
349            let schema: SchemaRef = schema.into();
350
351            match schema.fields().is_empty() {
352                true => ffi_err!(
353                    "Unable to retrieve field in WindowUDF via FFI - schema has no fields"
354                ),
355                false => Ok(schema.field(0).to_owned().into()),
356            }
357        }
358    }
359
360    fn sort_options(&self) -> Option<SortOptions> {
361        let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref().into();
362        options.map(|s| s.into())
363    }
364
365    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
366        LimitEffect::Unknown
367    }
368}
369
/// FFI-stable mirror of arrow's [`SortOptions`].
///
/// The struct is `#[repr(C)]`, so the order of its fields is part of the
/// binary interface.
#[repr(C)]
#[derive(Debug, StableAbi, Clone)]
pub struct FFI_SortOptions {
    /// Mirrors `SortOptions::descending`.
    pub descending: bool,
    /// Mirrors `SortOptions::nulls_first`.
    pub nulls_first: bool,
}
376
377impl From<&SortOptions> for FFI_SortOptions {
378    fn from(value: &SortOptions) -> Self {
379        Self {
380            descending: value.descending,
381            nulls_first: value.nulls_first,
382        }
383    }
384}
385
386impl From<&FFI_SortOptions> for SortOptions {
387    fn from(value: &FFI_SortOptions) -> Self {
388        Self {
389            descending: value.descending,
390            nulls_first: value.nulls_first,
391        }
392    }
393}
394
#[cfg(test)]
#[cfg(feature = "integration-tests")]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, create_array};
    use datafusion::functions_window::lead_lag::{WindowShift, lag_udwf};
    use datafusion::logical_expr::expr::Sort;
    use datafusion::logical_expr::{ExprFunctionExt, WindowUDF, WindowUDFImpl, col};
    use datafusion::prelude::SessionContext;

    use crate::tests::create_record_batch;
    use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF};

    /// Sends `original_udwf` through the FFI layer and back, swapping in the
    /// mock marker so that the result is treated as coming from another
    /// library.
    fn create_test_foreign_udwf(
        original_udwf: impl WindowUDFImpl + 'static,
    ) -> datafusion::common::Result<WindowUDF> {
        let mut ffi_udwf = FFI_WindowUDF::from(Arc::new(WindowUDF::from(original_udwf)));
        ffi_udwf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_impl: Arc<dyn WindowUDFImpl> = (&ffi_udwf).into();
        Ok(WindowUDF::new_from_shared_impl(foreign_impl))
    }

    /// The name of a function survives the trip through the FFI layer.
    #[test]
    fn test_round_trip_udwf() -> datafusion::common::Result<()> {
        let native_udwf = lag_udwf();
        let expected_name = native_udwf.name().to_owned();

        // Native -> FFI, marked as foreign.
        let mut ffi_udwf = FFI_WindowUDF::from(Arc::clone(&native_udwf));
        ffi_udwf.library_marker_id = crate::mock_foreign_marker_id;

        // FFI -> native.
        let foreign_impl: Arc<dyn WindowUDFImpl> = (&ffi_udwf).into();
        let round_tripped = WindowUDF::new_from_shared_impl(foreign_impl);

        assert_eq!(expected_name, round_tripped.name());
        Ok(())
    }

    /// `lag` evaluated through the FFI layer produces the expected column.
    #[tokio::test]
    async fn test_lag_udwf() -> datafusion::common::Result<()> {
        let foreign_lag = create_test_foreign_udwf(WindowShift::lag())?;

        let ctx = SessionContext::default();
        let input = ctx.read_batch(create_record_batch(-5, 5))?;

        let lag_expr = foreign_lag
            .call(vec![col("a")])
            .order_by(vec![Sort::new(col("a"), true, true)])
            .build()
            .unwrap()
            .alias("lag_a");
        let df = input.select(vec![col("a"), lag_expr])?;

        df.clone().show().await?;

        let batches = df.collect().await?;
        let expected =
            create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)])
                as ArrayRef;

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].column(1), &expected);

        Ok(())
    }

    /// A matching library marker returns the original implementation, while
    /// a different marker yields a `ForeignWindowUDF`.
    #[test]
    fn test_ffi_udwf_local_bypass() -> datafusion_common::Result<()> {
        let native_udwf = Arc::new(WindowUDF::from(WindowShift::lag()));
        let mut ffi_udwf = FFI_WindowUDF::from(native_udwf);

        // Same library: the original implementation comes back.
        let local_impl: Arc<dyn WindowUDFImpl> = (&ffi_udwf).into();
        assert!(local_impl.as_any().is::<WindowShift>());

        // Different library marker: a foreign wrapper is produced.
        ffi_udwf.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_impl: Arc<dyn WindowUDFImpl> = (&ffi_udwf).into();
        assert!(foreign_impl.as_any().is::<ForeignWindowUDF>());

        Ok(())
    }
}
494}