Skip to main content

datafusion_ffi/udwf/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::c_void;
19use std::hash::{Hash, Hasher};
20use std::sync::Arc;
21
22use arrow::compute::SortOptions;
23use arrow::datatypes::{DataType, Schema, SchemaRef};
24use arrow_schema::{Field, FieldRef};
25use datafusion_common::{Result, ffi_err};
26use datafusion_expr::function::WindowUDFFieldArgs;
27use datafusion_expr::type_coercion::functions::fields_with_udf;
28use datafusion_expr::{
29    LimitEffect, PartitionEvaluator, Signature, WindowUDF, WindowUDFImpl,
30};
31use datafusion_physical_expr::PhysicalExpr;
32use partition_evaluator::FFI_PartitionEvaluator;
33use partition_evaluator_args::{
34    FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs,
35};
36
37use stabby::string::String as SString;
38use stabby::vec::Vec as SVec;
39
40mod partition_evaluator;
41mod partition_evaluator_args;
42mod range;
43
44use crate::arrow_wrappers::WrappedSchema;
45use crate::util::{
46    FFI_Option, FFI_Result, rvec_wrapped_to_vec_datatype, rvec_wrapped_to_vec_fieldref,
47    vec_datatype_to_rvec_wrapped, vec_fieldref_to_rvec_wrapped,
48};
49use crate::volatility::FFI_Volatility;
50use crate::{df_result, sresult, sresult_return};
51
/// A stable struct for sharing a [`WindowUDF`] across FFI boundaries.
///
/// The struct combines plain data describing the function (`name`, `aliases`,
/// `volatility`, `sort_options`) with function pointers that call back into
/// the library that created it. The `private_data` pointer is opaque to the
/// receiver and is only interpreted by those function pointers.
#[repr(C)]
#[derive(Debug)]
pub struct FFI_WindowUDF {
    /// FFI equivalent to the `name` of a [`WindowUDF`]
    pub name: SString,

    /// FFI equivalent to the `aliases` of a [`WindowUDF`]
    pub aliases: SVec<SString>,

    /// FFI equivalent to the `volatility` of a [`WindowUDF`]
    pub volatility: FFI_Volatility,

    /// Creates the partition evaluator for the supplied arguments, returning
    /// it as an [`FFI_PartitionEvaluator`] or an error.
    pub partition_evaluator: unsafe extern "C" fn(
        udwf: &Self,
        args: FFI_PartitionEvaluatorArgs,
    )
        -> FFI_Result<FFI_PartitionEvaluator>,

    /// Computes the output field of the window function for the given input
    /// fields and display name. Both the inputs and the result travel as
    /// [`WrappedSchema`] values; the returned schema carries the output field
    /// as its first (and only) field.
    pub field: unsafe extern "C" fn(
        udwf: &Self,
        input_types: SVec<WrappedSchema>,
        display_name: SString,
    ) -> FFI_Result<WrappedSchema>,

    /// Performs type coercion. To simplify this interface, all UDFs are treated as
    /// having user defined signatures, which will in turn cause `coerce_types` to be
    /// called. This call should be transparent to most users as the internal function
    /// performs the appropriate calls on the underlying [`WindowUDF`]
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: SVec<WrappedSchema>,
    ) -> FFI_Result<SVec<WrappedSchema>>,

    /// FFI equivalent to the `sort_options` of a [`WindowUDF`]. The value is
    /// captured once, when the FFI struct is created.
    pub sort_options: FFI_Option<FFI_SortOptions>,

    /// Used to create a clone on the provider of the udf. This should
    /// only need to be called by the receiver of the udf.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Release the memory of the private data when it is no longer being used.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the udf.
    /// A [`ForeignWindowUDF`] should never attempt to access this data.
    pub private_data: *mut c_void,

    /// Utility to identify when FFI objects are accessed locally through
    /// the foreign interface. See [`crate::get_library_marker_id`] and
    /// the crate's `README.md` for more information.
    pub library_marker_id: extern "C" fn() -> usize,
}
104
// The raw `private_data` pointer makes this struct `!Send`/`!Sync` by default,
// so thread safety is asserted manually here.
// NOTE(review): presumably sound because the private data only holds an
// `Arc<WindowUDF>` and is never mutated through the function pointers — confirm.
unsafe impl Send for FFI_WindowUDF {}
unsafe impl Sync for FFI_WindowUDF {}
107
/// Provider-side state stored behind the `private_data` pointer of an
/// [`FFI_WindowUDF`].
pub struct WindowUDFPrivateData {
    /// The wrapped window function that the FFI function pointers delegate to.
    pub udf: Arc<WindowUDF>,
}
111
112impl FFI_WindowUDF {
113    unsafe fn inner(&self) -> &Arc<WindowUDF> {
114        unsafe {
115            let private_data = self.private_data as *const WindowUDFPrivateData;
116            &(*private_data).udf
117        }
118    }
119}
120
121unsafe extern "C" fn partition_evaluator_fn_wrapper(
122    udwf: &FFI_WindowUDF,
123    args: FFI_PartitionEvaluatorArgs,
124) -> FFI_Result<FFI_PartitionEvaluator> {
125    unsafe {
126        let inner = udwf.inner();
127
128        let args = sresult_return!(ForeignPartitionEvaluatorArgs::try_from(args));
129
130        let evaluator =
131            sresult_return!(inner.partition_evaluator_factory((&args).into()));
132
133        FFI_Result::Ok(evaluator.into())
134    }
135}
136
137unsafe extern "C" fn field_fn_wrapper(
138    udwf: &FFI_WindowUDF,
139    input_fields: SVec<WrappedSchema>,
140    display_name: SString,
141) -> FFI_Result<WrappedSchema> {
142    unsafe {
143        let inner = udwf.inner();
144
145        let input_fields = sresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
146
147        let field = sresult_return!(inner.field(WindowUDFFieldArgs::new(
148            &input_fields,
149            display_name.as_str()
150        )));
151
152        let schema = Arc::new(Schema::new(vec![field]));
153
154        FFI_Result::Ok(WrappedSchema::from(schema))
155    }
156}
157
158unsafe extern "C" fn coerce_types_fn_wrapper(
159    udwf: &FFI_WindowUDF,
160    arg_types: SVec<WrappedSchema>,
161) -> FFI_Result<SVec<WrappedSchema>> {
162    unsafe {
163        let inner = udwf.inner();
164
165        let arg_fields = sresult_return!(rvec_wrapped_to_vec_datatype(&arg_types))
166            .into_iter()
167            .map(|dt| Field::new("f", dt, false))
168            .map(Arc::new)
169            .collect::<Vec<_>>();
170
171        let return_fields = sresult_return!(fields_with_udf(&arg_fields, inner.as_ref()));
172        let return_types = return_fields
173            .into_iter()
174            .map(|f| f.data_type().to_owned())
175            .collect::<Vec<_>>();
176
177        sresult!(vec_datatype_to_rvec_wrapped(&return_types))
178    }
179}
180
181unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) {
182    unsafe {
183        debug_assert!(!udwf.private_data.is_null());
184        let private_data = Box::from_raw(udwf.private_data as *mut WindowUDFPrivateData);
185        drop(private_data);
186        udwf.private_data = std::ptr::null_mut();
187    }
188}
189
190unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
191    unsafe {
192        // let private_data = udf.private_data as *const WindowUDFPrivateData;
193        // let udf_data = &(*private_data);
194
195        // let private_data = Box::new(WindowUDFPrivateData {
196        //     udf: Arc::clone(&udf_data.udf),
197        // });
198        let private_data = Box::new(WindowUDFPrivateData {
199            udf: Arc::clone(udwf.inner()),
200        });
201
202        FFI_WindowUDF {
203            name: udwf.name.clone(),
204            aliases: udwf.aliases.clone(),
205            volatility: udwf.volatility.clone(),
206            partition_evaluator: partition_evaluator_fn_wrapper,
207            sort_options: udwf.sort_options.clone(),
208            coerce_types: coerce_types_fn_wrapper,
209            field: field_fn_wrapper,
210            clone: clone_fn_wrapper,
211            release: release_fn_wrapper,
212            private_data: Box::into_raw(private_data) as *mut c_void,
213            library_marker_id: crate::get_library_marker_id,
214        }
215    }
216}
217
218impl Clone for FFI_WindowUDF {
219    fn clone(&self) -> Self {
220        unsafe { (self.clone)(self) }
221    }
222}
223
224impl From<Arc<WindowUDF>> for FFI_WindowUDF {
225    fn from(udf: Arc<WindowUDF>) -> Self {
226        if let Some(udwf) = udf.inner().downcast_ref::<ForeignWindowUDF>() {
227            return udwf.udf.clone();
228        }
229
230        let name = udf.name().into();
231        let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
232        let volatility = udf.signature().volatility.into();
233        let sort_options = udf.sort_options().map(|v| (&v).into()).into();
234
235        let private_data = Box::new(WindowUDFPrivateData { udf });
236
237        Self {
238            name,
239            aliases,
240            volatility,
241            partition_evaluator: partition_evaluator_fn_wrapper,
242            sort_options,
243            coerce_types: coerce_types_fn_wrapper,
244            field: field_fn_wrapper,
245            clone: clone_fn_wrapper,
246            release: release_fn_wrapper,
247            private_data: Box::into_raw(private_data) as *mut c_void,
248            library_marker_id: crate::get_library_marker_id,
249        }
250    }
251}
252
253impl Drop for FFI_WindowUDF {
254    fn drop(&mut self) {
255        unsafe { (self.release)(self) }
256    }
257}
258
/// This struct is used to access a UDF provided by a foreign
/// library across an FFI boundary.
///
/// The ForeignWindowUDF is to be used by the caller of the UDF, so it has
/// no knowledge or access to the private data. All interaction with the UDF
/// must occur through the functions defined in FFI_WindowUDF.
#[derive(Debug)]
pub struct ForeignWindowUDF {
    /// Native copy of the foreign udf's name.
    name: String,
    /// Native copies of the foreign udf's aliases.
    aliases: Vec<String>,
    /// The FFI struct through which every call into the foreign library is made.
    udf: FFI_WindowUDF,
    /// A user-defined signature carrying the foreign udf's volatility.
    signature: Signature,
}
272
// NOTE(review): `FFI_WindowUDF` already has manual `Send`/`Sync` impls and the
// remaining fields are owned strings plus a `Signature`, so these explicit
// impls look redundant — confirm before removing.
unsafe impl Send for ForeignWindowUDF {}
unsafe impl Sync for ForeignWindowUDF {}
275
276impl PartialEq for ForeignWindowUDF {
277    fn eq(&self, other: &Self) -> bool {
278        // FFI_WindowUDF cannot be compared, so identity equality is the best we can do.
279        std::ptr::eq(self, other)
280    }
281}
282impl Eq for ForeignWindowUDF {}
283impl Hash for ForeignWindowUDF {
284    fn hash<H: Hasher>(&self, state: &mut H) {
285        std::ptr::hash(self, state)
286    }
287}
288
289impl From<&FFI_WindowUDF> for Arc<dyn WindowUDFImpl> {
290    fn from(udf: &FFI_WindowUDF) -> Self {
291        if (udf.library_marker_id)() == crate::get_library_marker_id() {
292            Arc::clone(unsafe { udf.inner().inner() })
293        } else {
294            let name = udf.name.to_owned().into();
295            let signature = Signature::user_defined((&udf.volatility).into());
296
297            let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
298
299            Arc::new(ForeignWindowUDF {
300                name,
301                udf: udf.clone(),
302                aliases,
303                signature,
304            })
305        }
306    }
307}
308
309impl WindowUDFImpl for ForeignWindowUDF {
310    fn name(&self) -> &str {
311        &self.name
312    }
313
314    fn signature(&self) -> &Signature {
315        &self.signature
316    }
317
318    fn aliases(&self) -> &[String] {
319        &self.aliases
320    }
321
322    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
323        unsafe {
324            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
325            let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
326            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
327        }
328    }
329
330    fn partition_evaluator(
331        &self,
332        args: datafusion_expr::function::PartitionEvaluatorArgs,
333    ) -> Result<Box<dyn PartitionEvaluator>> {
334        let evaluator = unsafe {
335            let args = FFI_PartitionEvaluatorArgs::try_from(args)?;
336            (self.udf.partition_evaluator)(&self.udf, args)
337        };
338
339        df_result!(evaluator).map(<Box<dyn PartitionEvaluator>>::from)
340    }
341
342    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
343        unsafe {
344            let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?;
345            let schema = df_result!((self.udf.field)(
346                &self.udf,
347                input_types,
348                field_args.name().into()
349            ))?;
350            let schema: SchemaRef = schema.into();
351
352            match schema.fields().is_empty() {
353                true => ffi_err!(
354                    "Unable to retrieve field in WindowUDF via FFI - schema has no fields"
355                ),
356                false => Ok(schema.field(0).to_owned().into()),
357            }
358        }
359    }
360
361    fn sort_options(&self) -> Option<SortOptions> {
362        let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref();
363        options.map(|s| s.into())
364    }
365
366    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
367        LimitEffect::Unknown
368    }
369}
370
/// FFI-stable mirror of [`SortOptions`], holding the same two flags.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct FFI_SortOptions {
    /// Mirrors `SortOptions::descending`.
    pub descending: bool,
    /// Mirrors `SortOptions::nulls_first`.
    pub nulls_first: bool,
}
377
378impl From<&SortOptions> for FFI_SortOptions {
379    fn from(value: &SortOptions) -> Self {
380        Self {
381            descending: value.descending,
382            nulls_first: value.nulls_first,
383        }
384    }
385}
386
387impl From<&FFI_SortOptions> for SortOptions {
388    fn from(value: &FFI_SortOptions) -> Self {
389        Self {
390            descending: value.descending,
391            nulls_first: value.nulls_first,
392        }
393    }
394}
395
// These tests are compiled only for test builds with the `integration-tests`
// feature enabled.
#[cfg(test)]
#[cfg(feature = "integration-tests")]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, create_array};
    use datafusion::functions_window::lead_lag::{WindowShift, lag_udwf};
    use datafusion::logical_expr::expr::Sort;
    use datafusion::logical_expr::{ExprFunctionExt, WindowUDF, WindowUDFImpl, col};
    use datafusion::prelude::SessionContext;

    use crate::tests::create_record_batch;
    use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF};

    /// Wraps `original_udwf` in an [`FFI_WindowUDF`] and converts it back as
    /// if it had come from another library.
    ///
    /// The library marker is replaced with the mock foreign marker so the
    /// conversion produces a [`ForeignWindowUDF`] instead of taking the
    /// same-library shortcut.
    fn create_test_foreign_udwf(
        original_udwf: impl WindowUDFImpl + 'static,
    ) -> datafusion::common::Result<WindowUDF> {
        let original_udwf = Arc::new(WindowUDF::from(original_udwf));

        let mut local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into();
        local_udwf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udwf: Arc<dyn WindowUDFImpl> = (&local_udwf).into();
        Ok(WindowUDF::new_from_shared_impl(foreign_udwf))
    }

    /// The name of a udf survives a native -> FFI -> foreign round trip.
    #[test]
    fn test_round_trip_udwf() -> datafusion::common::Result<()> {
        let original_udwf = lag_udwf();
        let original_name = original_udwf.name().to_owned();

        // Convert to FFI format
        let mut local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into();
        // Pretend the struct came from a different library.
        local_udwf.library_marker_id = crate::mock_foreign_marker_id;

        // Convert back to native format
        let foreign_udwf: Arc<dyn WindowUDFImpl> = (&local_udwf).into();
        let foreign_udwf = WindowUDF::new_from_shared_impl(foreign_udwf);

        assert_eq!(original_name, foreign_udwf.name());
        Ok(())
    }

    /// Runs `lag` through the foreign wrapper in a real query and checks the
    /// values it produces.
    #[tokio::test]
    async fn test_lag_udwf() -> datafusion::common::Result<()> {
        let udwf = create_test_foreign_udwf(WindowShift::lag())?;

        let ctx = SessionContext::default();
        // NOTE(review): presumably yields a batch whose column `a` runs from
        // -5 upward for 5 rows, given the expected values below — confirm
        // against `create_record_batch`.
        let df = ctx.read_batch(create_record_batch(-5, 5))?;

        let df = df.select(vec![
            col("a"),
            udwf.call(vec![col("a")])
                .order_by(vec![Sort::new(col("a"), true, true)])
                .build()
                .unwrap()
                .alias("lag_a"),
        ])?;

        // Prints the result to stdout; not part of the assertions.
        df.clone().show().await?;

        let result = df.collect().await?;
        // The first row has no predecessor, so its lag value is null.
        let expected =
            create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)])
                as ArrayRef;

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].column(1), &expected);

        Ok(())
    }

    /// The library marker decides whether conversion returns the original
    /// implementation or a [`ForeignWindowUDF`] wrapper.
    #[test]
    fn test_ffi_udwf_local_bypass() -> datafusion_common::Result<()> {
        let original_udwf = Arc::new(WindowUDF::from(WindowShift::lag()));

        let mut ffi_udwf = FFI_WindowUDF::from(original_udwf);

        // Verify local libraries can be downcast to their original
        let foreign_udwf: Arc<dyn WindowUDFImpl> = (&ffi_udwf).into();
        assert!(foreign_udwf.is::<WindowShift>());

        // Verify different library markers generate foreign providers
        ffi_udwf.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_udwf: Arc<dyn WindowUDFImpl> = (&ffi_udwf).into();
        assert!(foreign_udwf.is::<ForeignWindowUDF>());

        Ok(())
    }
}