datafusion_ffi/udwf/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use abi_stable::{
19    std_types::{ROption, RResult, RString, RVec},
20    StableAbi,
21};
22use arrow::datatypes::Schema;
23use arrow::{
24    compute::SortOptions,
25    datatypes::{DataType, SchemaRef},
26};
27use arrow_schema::{Field, FieldRef};
28use datafusion::logical_expr::LimitEffect;
29use datafusion::physical_expr::PhysicalExpr;
30use datafusion::{
31    error::DataFusionError,
32    logical_expr::{
33        function::WindowUDFFieldArgs, type_coercion::functions::fields_with_window_udf,
34        PartitionEvaluator,
35    },
36};
37use datafusion::{
38    error::Result,
39    logical_expr::{Signature, WindowUDF, WindowUDFImpl},
40};
41use datafusion_common::exec_err;
42use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator};
43use partition_evaluator_args::{
44    FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs,
45};
46use std::hash::{Hash, Hasher};
47use std::{ffi::c_void, sync::Arc};
48
49mod partition_evaluator;
50mod partition_evaluator_args;
51mod range;
52
53use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
54use crate::{
55    arrow_wrappers::WrappedSchema,
56    df_result, rresult, rresult_return,
57    util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
58    volatility::FFI_Volatility,
59};
60
61/// A stable struct for sharing a [`WindowUDF`] across FFI boundaries.
62#[repr(C)]
63#[derive(Debug, StableAbi)]
64#[allow(non_camel_case_types)]
65pub struct FFI_WindowUDF {
66    /// FFI equivalent to the `name` of a [`WindowUDF`]
67    pub name: RString,
68
69    /// FFI equivalent to the `aliases` of a [`WindowUDF`]
70    pub aliases: RVec<RString>,
71
72    /// FFI equivalent to the `volatility` of a [`WindowUDF`]
73    pub volatility: FFI_Volatility,
74
75    pub partition_evaluator:
76        unsafe extern "C" fn(
77            udwf: &Self,
78            args: FFI_PartitionEvaluatorArgs,
79        ) -> RResult<FFI_PartitionEvaluator, RString>,
80
81    pub field: unsafe extern "C" fn(
82        udwf: &Self,
83        input_types: RVec<WrappedSchema>,
84        display_name: RString,
85    ) -> RResult<WrappedSchema, RString>,
86
87    /// Performs type coercion. To simply this interface, all UDFs are treated as having
88    /// user defined signatures, which will in turn call coerce_types to be called. This
89    /// call should be transparent to most users as the internal function performs the
90    /// appropriate calls on the underlying [`WindowUDF`]
91    pub coerce_types: unsafe extern "C" fn(
92        udf: &Self,
93        arg_types: RVec<WrappedSchema>,
94    ) -> RResult<RVec<WrappedSchema>, RString>,
95
96    pub sort_options: ROption<FFI_SortOptions>,
97
98    /// Used to create a clone on the provider of the udf. This should
99    /// only need to be called by the receiver of the udf.
100    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,
101
102    /// Release the memory of the private data when it is no longer being used.
103    pub release: unsafe extern "C" fn(udf: &mut Self),
104
105    /// Internal data. This is only to be accessed by the provider of the udf.
106    /// A [`ForeignWindowUDF`] should never attempt to access this data.
107    pub private_data: *mut c_void,
108}
109
110unsafe impl Send for FFI_WindowUDF {}
111unsafe impl Sync for FFI_WindowUDF {}
112
113pub struct WindowUDFPrivateData {
114    pub udf: Arc<WindowUDF>,
115}
116
117impl FFI_WindowUDF {
118    unsafe fn inner(&self) -> &Arc<WindowUDF> {
119        let private_data = self.private_data as *const WindowUDFPrivateData;
120        &(*private_data).udf
121    }
122}
123
124unsafe extern "C" fn partition_evaluator_fn_wrapper(
125    udwf: &FFI_WindowUDF,
126    args: FFI_PartitionEvaluatorArgs,
127) -> RResult<FFI_PartitionEvaluator, RString> {
128    let inner = udwf.inner();
129
130    let args = rresult_return!(ForeignPartitionEvaluatorArgs::try_from(args));
131
132    let evaluator = rresult_return!(inner.partition_evaluator_factory((&args).into()));
133
134    RResult::ROk(evaluator.into())
135}
136
137unsafe extern "C" fn field_fn_wrapper(
138    udwf: &FFI_WindowUDF,
139    input_fields: RVec<WrappedSchema>,
140    display_name: RString,
141) -> RResult<WrappedSchema, RString> {
142    let inner = udwf.inner();
143
144    let input_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
145
146    let field = rresult_return!(inner.field(WindowUDFFieldArgs::new(
147        &input_fields,
148        display_name.as_str()
149    )));
150
151    let schema = Arc::new(Schema::new(vec![field]));
152
153    RResult::ROk(WrappedSchema::from(schema))
154}
155
156unsafe extern "C" fn coerce_types_fn_wrapper(
157    udwf: &FFI_WindowUDF,
158    arg_types: RVec<WrappedSchema>,
159) -> RResult<RVec<WrappedSchema>, RString> {
160    let inner = udwf.inner();
161
162    let arg_fields = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types))
163        .into_iter()
164        .map(|dt| Field::new("f", dt, false))
165        .map(Arc::new)
166        .collect::<Vec<_>>();
167
168    let return_fields = rresult_return!(fields_with_window_udf(&arg_fields, inner));
169    let return_types = return_fields
170        .into_iter()
171        .map(|f| f.data_type().to_owned())
172        .collect::<Vec<_>>();
173
174    rresult!(vec_datatype_to_rvec_wrapped(&return_types))
175}
176
177unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) {
178    let private_data = Box::from_raw(udwf.private_data as *mut WindowUDFPrivateData);
179    drop(private_data);
180}
181
182unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
183    // let private_data = udf.private_data as *const WindowUDFPrivateData;
184    // let udf_data = &(*private_data);
185
186    // let private_data = Box::new(WindowUDFPrivateData {
187    //     udf: Arc::clone(&udf_data.udf),
188    // });
189    let private_data = Box::new(WindowUDFPrivateData {
190        udf: Arc::clone(udwf.inner()),
191    });
192
193    FFI_WindowUDF {
194        name: udwf.name.clone(),
195        aliases: udwf.aliases.clone(),
196        volatility: udwf.volatility.clone(),
197        partition_evaluator: partition_evaluator_fn_wrapper,
198        sort_options: udwf.sort_options.clone(),
199        coerce_types: coerce_types_fn_wrapper,
200        field: field_fn_wrapper,
201        clone: clone_fn_wrapper,
202        release: release_fn_wrapper,
203        private_data: Box::into_raw(private_data) as *mut c_void,
204    }
205}
206
207impl Clone for FFI_WindowUDF {
208    fn clone(&self) -> Self {
209        unsafe { (self.clone)(self) }
210    }
211}
212
213impl From<Arc<WindowUDF>> for FFI_WindowUDF {
214    fn from(udf: Arc<WindowUDF>) -> Self {
215        let name = udf.name().into();
216        let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
217        let volatility = udf.signature().volatility.into();
218        let sort_options = udf.sort_options().map(|v| (&v).into()).into();
219
220        let private_data = Box::new(WindowUDFPrivateData { udf });
221
222        Self {
223            name,
224            aliases,
225            volatility,
226            partition_evaluator: partition_evaluator_fn_wrapper,
227            sort_options,
228            coerce_types: coerce_types_fn_wrapper,
229            field: field_fn_wrapper,
230            clone: clone_fn_wrapper,
231            release: release_fn_wrapper,
232            private_data: Box::into_raw(private_data) as *mut c_void,
233        }
234    }
235}
236
237impl Drop for FFI_WindowUDF {
238    fn drop(&mut self) {
239        unsafe { (self.release)(self) }
240    }
241}
242
243/// This struct is used to access an UDF provided by a foreign
244/// library across a FFI boundary.
245///
246/// The ForeignWindowUDF is to be used by the caller of the UDF, so it has
247/// no knowledge or access to the private data. All interaction with the UDF
248/// must occur through the functions defined in FFI_WindowUDF.
249#[derive(Debug)]
250pub struct ForeignWindowUDF {
251    name: String,
252    aliases: Vec<String>,
253    udf: FFI_WindowUDF,
254    signature: Signature,
255}
256
257unsafe impl Send for ForeignWindowUDF {}
258unsafe impl Sync for ForeignWindowUDF {}
259
260impl PartialEq for ForeignWindowUDF {
261    fn eq(&self, other: &Self) -> bool {
262        // FFI_WindowUDF cannot be compared, so identity equality is the best we can do.
263        std::ptr::eq(self, other)
264    }
265}
266impl Eq for ForeignWindowUDF {}
267impl Hash for ForeignWindowUDF {
268    fn hash<H: Hasher>(&self, state: &mut H) {
269        std::ptr::hash(self, state)
270    }
271}
272
273impl TryFrom<&FFI_WindowUDF> for ForeignWindowUDF {
274    type Error = DataFusionError;
275
276    fn try_from(udf: &FFI_WindowUDF) -> Result<Self, Self::Error> {
277        let name = udf.name.to_owned().into();
278        let signature = Signature::user_defined((&udf.volatility).into());
279
280        let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
281
282        Ok(Self {
283            name,
284            udf: udf.clone(),
285            aliases,
286            signature,
287        })
288    }
289}
290
291impl WindowUDFImpl for ForeignWindowUDF {
292    fn as_any(&self) -> &dyn std::any::Any {
293        self
294    }
295
296    fn name(&self) -> &str {
297        &self.name
298    }
299
300    fn signature(&self) -> &Signature {
301        &self.signature
302    }
303
304    fn aliases(&self) -> &[String] {
305        &self.aliases
306    }
307
308    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
309        unsafe {
310            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
311            let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
312            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
313        }
314    }
315
316    fn partition_evaluator(
317        &self,
318        args: datafusion::logical_expr::function::PartitionEvaluatorArgs,
319    ) -> Result<Box<dyn PartitionEvaluator>> {
320        let evaluator = unsafe {
321            let args = FFI_PartitionEvaluatorArgs::try_from(args)?;
322            (self.udf.partition_evaluator)(&self.udf, args)
323        };
324
325        df_result!(evaluator).map(|evaluator| {
326            Box::new(ForeignPartitionEvaluator::from(evaluator))
327                as Box<dyn PartitionEvaluator>
328        })
329    }
330
331    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
332        unsafe {
333            let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?;
334            let schema = df_result!((self.udf.field)(
335                &self.udf,
336                input_types,
337                field_args.name().into()
338            ))?;
339            let schema: SchemaRef = schema.into();
340
341            match schema.fields().is_empty() {
342                true => exec_err!(
343                    "Unable to retrieve field in WindowUDF via FFI - schema has no fields"
344                ),
345                false => Ok(schema.field(0).to_owned().into()),
346            }
347        }
348    }
349
350    fn sort_options(&self) -> Option<SortOptions> {
351        let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref().into();
352        options.map(|s| s.into())
353    }
354
355    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
356        LimitEffect::Unknown
357    }
358}
359
360#[repr(C)]
361#[derive(Debug, StableAbi, Clone)]
362#[allow(non_camel_case_types)]
363pub struct FFI_SortOptions {
364    pub descending: bool,
365    pub nulls_first: bool,
366}
367
368impl From<&SortOptions> for FFI_SortOptions {
369    fn from(value: &SortOptions) -> Self {
370        Self {
371            descending: value.descending,
372            nulls_first: value.nulls_first,
373        }
374    }
375}
376
377impl From<&FFI_SortOptions> for SortOptions {
378    fn from(value: &FFI_SortOptions) -> Self {
379        Self {
380            descending: value.descending,
381            nulls_first: value.nulls_first,
382        }
383    }
384}
385
#[cfg(test)]
#[cfg(feature = "integration-tests")]
mod tests {
    use crate::tests::create_record_batch;
    use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF};
    use arrow::array::{create_array, ArrayRef};
    use datafusion::functions_window::lead_lag::{lag_udwf, WindowShift};
    use datafusion::logical_expr::expr::Sort;
    use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl};
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Sends `original_udwf` through [`FFI_WindowUDF`] and wraps the result in
    /// a [`ForeignWindowUDF`], the way a consumer of the FFI struct would.
    fn create_test_foreign_udwf(
        original_udwf: impl WindowUDFImpl + 'static,
    ) -> datafusion::common::Result<WindowUDF> {
        let original_udwf = Arc::new(WindowUDF::from(original_udwf));

        let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into();

        let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?;
        Ok(foreign_udwf.into())
    }

    /// The name of a udwf must survive the trip through the FFI struct.
    #[test]
    fn test_round_trip_udwf() -> datafusion::common::Result<()> {
        let original_udwf = lag_udwf();
        let original_name = original_udwf.name().to_owned();

        // Convert to FFI format
        let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into();

        // Convert back to native format
        let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?;
        let foreign_udwf: WindowUDF = foreign_udwf.into();

        assert_eq!(original_name, foreign_udwf.name());
        Ok(())
    }

    /// Runs `lag` through the FFI layer inside a real query and checks the
    /// produced column.
    #[tokio::test]
    async fn test_lag_udwf() -> datafusion::common::Result<()> {
        let udwf = create_test_foreign_udwf(WindowShift::lag())?;

        let ctx = SessionContext::default();
        let df = ctx.read_batch(create_record_batch(-5, 5))?;

        let df = df.select(vec![
            col("a"),
            udwf.call(vec![col("a")])
                .order_by(vec![Sort::new(col("a"), true, true)])
                // Propagate a build failure as a test error rather than a panic.
                .build()?
                .alias("lag_a"),
        ])?;

        let result = df.collect().await?;
        let expected =
            create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)])
                as ArrayRef;

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].column(1), &expected);

        Ok(())
    }
}