datafusion_ffi/udwf/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use abi_stable::{
19    std_types::{ROption, RResult, RString, RVec},
20    StableAbi,
21};
22use arrow::datatypes::Schema;
23use arrow::{
24    compute::SortOptions,
25    datatypes::{DataType, SchemaRef},
26};
27use arrow_schema::{Field, FieldRef};
28use datafusion::{
29    error::DataFusionError,
30    logical_expr::{
31        function::WindowUDFFieldArgs, type_coercion::functions::fields_with_window_udf,
32        PartitionEvaluator,
33    },
34};
35use datafusion::{
36    error::Result,
37    logical_expr::{Signature, WindowUDF, WindowUDFImpl},
38};
39use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator};
40use partition_evaluator_args::{
41    FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs,
42};
43use std::hash::{Hash, Hasher};
44use std::{ffi::c_void, sync::Arc};
45
46mod partition_evaluator;
47mod partition_evaluator_args;
48mod range;
49
50use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
51use crate::{
52    arrow_wrappers::WrappedSchema,
53    df_result, rresult, rresult_return,
54    util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
55    volatility::FFI_Volatility,
56};
57
58/// A stable struct for sharing a [`WindowUDF`] across FFI boundaries.
59#[repr(C)]
60#[derive(Debug, StableAbi)]
61#[allow(non_camel_case_types)]
62pub struct FFI_WindowUDF {
63    /// FFI equivalent to the `name` of a [`WindowUDF`]
64    pub name: RString,
65
66    /// FFI equivalent to the `aliases` of a [`WindowUDF`]
67    pub aliases: RVec<RString>,
68
69    /// FFI equivalent to the `volatility` of a [`WindowUDF`]
70    pub volatility: FFI_Volatility,
71
72    pub partition_evaluator:
73        unsafe extern "C" fn(
74            udwf: &Self,
75            args: FFI_PartitionEvaluatorArgs,
76        ) -> RResult<FFI_PartitionEvaluator, RString>,
77
78    pub field: unsafe extern "C" fn(
79        udwf: &Self,
80        input_types: RVec<WrappedSchema>,
81        display_name: RString,
82    ) -> RResult<WrappedSchema, RString>,
83
84    /// Performs type coercion. To simply this interface, all UDFs are treated as having
85    /// user defined signatures, which will in turn call coerce_types to be called. This
86    /// call should be transparent to most users as the internal function performs the
87    /// appropriate calls on the underlying [`WindowUDF`]
88    pub coerce_types: unsafe extern "C" fn(
89        udf: &Self,
90        arg_types: RVec<WrappedSchema>,
91    ) -> RResult<RVec<WrappedSchema>, RString>,
92
93    pub sort_options: ROption<FFI_SortOptions>,
94
95    /// Used to create a clone on the provider of the udf. This should
96    /// only need to be called by the receiver of the udf.
97    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,
98
99    /// Release the memory of the private data when it is no longer being used.
100    pub release: unsafe extern "C" fn(udf: &mut Self),
101
102    /// Internal data. This is only to be accessed by the provider of the udf.
103    /// A [`ForeignWindowUDF`] should never attempt to access this data.
104    pub private_data: *mut c_void,
105}
106
107unsafe impl Send for FFI_WindowUDF {}
108unsafe impl Sync for FFI_WindowUDF {}
109
110pub struct WindowUDFPrivateData {
111    pub udf: Arc<WindowUDF>,
112}
113
114impl FFI_WindowUDF {
115    unsafe fn inner(&self) -> &Arc<WindowUDF> {
116        let private_data = self.private_data as *const WindowUDFPrivateData;
117        &(*private_data).udf
118    }
119}
120
121unsafe extern "C" fn partition_evaluator_fn_wrapper(
122    udwf: &FFI_WindowUDF,
123    args: FFI_PartitionEvaluatorArgs,
124) -> RResult<FFI_PartitionEvaluator, RString> {
125    let inner = udwf.inner();
126
127    let args = rresult_return!(ForeignPartitionEvaluatorArgs::try_from(args));
128
129    let evaluator = rresult_return!(inner.partition_evaluator_factory((&args).into()));
130
131    RResult::ROk(evaluator.into())
132}
133
134unsafe extern "C" fn field_fn_wrapper(
135    udwf: &FFI_WindowUDF,
136    input_fields: RVec<WrappedSchema>,
137    display_name: RString,
138) -> RResult<WrappedSchema, RString> {
139    let inner = udwf.inner();
140
141    let input_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
142
143    let field = rresult_return!(inner.field(WindowUDFFieldArgs::new(
144        &input_fields,
145        display_name.as_str()
146    )));
147
148    let schema = Arc::new(Schema::new(vec![field]));
149
150    RResult::ROk(WrappedSchema::from(schema))
151}
152
153unsafe extern "C" fn coerce_types_fn_wrapper(
154    udwf: &FFI_WindowUDF,
155    arg_types: RVec<WrappedSchema>,
156) -> RResult<RVec<WrappedSchema>, RString> {
157    let inner = udwf.inner();
158
159    let arg_fields = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types))
160        .into_iter()
161        .map(|dt| Field::new("f", dt, false))
162        .map(Arc::new)
163        .collect::<Vec<_>>();
164
165    let return_fields = rresult_return!(fields_with_window_udf(&arg_fields, inner));
166    let return_types = return_fields
167        .into_iter()
168        .map(|f| f.data_type().to_owned())
169        .collect::<Vec<_>>();
170
171    rresult!(vec_datatype_to_rvec_wrapped(&return_types))
172}
173
174unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) {
175    let private_data = Box::from_raw(udwf.private_data as *mut WindowUDFPrivateData);
176    drop(private_data);
177}
178
179unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
180    // let private_data = udf.private_data as *const WindowUDFPrivateData;
181    // let udf_data = &(*private_data);
182
183    // let private_data = Box::new(WindowUDFPrivateData {
184    //     udf: Arc::clone(&udf_data.udf),
185    // });
186    let private_data = Box::new(WindowUDFPrivateData {
187        udf: Arc::clone(udwf.inner()),
188    });
189
190    FFI_WindowUDF {
191        name: udwf.name.clone(),
192        aliases: udwf.aliases.clone(),
193        volatility: udwf.volatility.clone(),
194        partition_evaluator: partition_evaluator_fn_wrapper,
195        sort_options: udwf.sort_options.clone(),
196        coerce_types: coerce_types_fn_wrapper,
197        field: field_fn_wrapper,
198        clone: clone_fn_wrapper,
199        release: release_fn_wrapper,
200        private_data: Box::into_raw(private_data) as *mut c_void,
201    }
202}
203
204impl Clone for FFI_WindowUDF {
205    fn clone(&self) -> Self {
206        unsafe { (self.clone)(self) }
207    }
208}
209
210impl From<Arc<WindowUDF>> for FFI_WindowUDF {
211    fn from(udf: Arc<WindowUDF>) -> Self {
212        let name = udf.name().into();
213        let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
214        let volatility = udf.signature().volatility.into();
215        let sort_options = udf.sort_options().map(|v| (&v).into()).into();
216
217        let private_data = Box::new(WindowUDFPrivateData { udf });
218
219        Self {
220            name,
221            aliases,
222            volatility,
223            partition_evaluator: partition_evaluator_fn_wrapper,
224            sort_options,
225            coerce_types: coerce_types_fn_wrapper,
226            field: field_fn_wrapper,
227            clone: clone_fn_wrapper,
228            release: release_fn_wrapper,
229            private_data: Box::into_raw(private_data) as *mut c_void,
230        }
231    }
232}
233
234impl Drop for FFI_WindowUDF {
235    fn drop(&mut self) {
236        unsafe { (self.release)(self) }
237    }
238}
239
240/// This struct is used to access an UDF provided by a foreign
241/// library across a FFI boundary.
242///
243/// The ForeignWindowUDF is to be used by the caller of the UDF, so it has
244/// no knowledge or access to the private data. All interaction with the UDF
245/// must occur through the functions defined in FFI_WindowUDF.
246#[derive(Debug)]
247pub struct ForeignWindowUDF {
248    name: String,
249    aliases: Vec<String>,
250    udf: FFI_WindowUDF,
251    signature: Signature,
252}
253
254unsafe impl Send for ForeignWindowUDF {}
255unsafe impl Sync for ForeignWindowUDF {}
256
257impl PartialEq for ForeignWindowUDF {
258    fn eq(&self, other: &Self) -> bool {
259        // FFI_WindowUDF cannot be compared, so identity equality is the best we can do.
260        std::ptr::eq(self, other)
261    }
262}
263impl Eq for ForeignWindowUDF {}
264impl Hash for ForeignWindowUDF {
265    fn hash<H: Hasher>(&self, state: &mut H) {
266        std::ptr::hash(self, state)
267    }
268}
269
270impl TryFrom<&FFI_WindowUDF> for ForeignWindowUDF {
271    type Error = DataFusionError;
272
273    fn try_from(udf: &FFI_WindowUDF) -> Result<Self, Self::Error> {
274        let name = udf.name.to_owned().into();
275        let signature = Signature::user_defined((&udf.volatility).into());
276
277        let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
278
279        Ok(Self {
280            name,
281            udf: udf.clone(),
282            aliases,
283            signature,
284        })
285    }
286}
287
288impl WindowUDFImpl for ForeignWindowUDF {
289    fn as_any(&self) -> &dyn std::any::Any {
290        self
291    }
292
293    fn name(&self) -> &str {
294        &self.name
295    }
296
297    fn signature(&self) -> &Signature {
298        &self.signature
299    }
300
301    fn aliases(&self) -> &[String] {
302        &self.aliases
303    }
304
305    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
306        unsafe {
307            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
308            let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
309            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
310        }
311    }
312
313    fn partition_evaluator(
314        &self,
315        args: datafusion::logical_expr::function::PartitionEvaluatorArgs,
316    ) -> Result<Box<dyn PartitionEvaluator>> {
317        let evaluator = unsafe {
318            let args = FFI_PartitionEvaluatorArgs::try_from(args)?;
319            (self.udf.partition_evaluator)(&self.udf, args)
320        };
321
322        df_result!(evaluator).map(|evaluator| {
323            Box::new(ForeignPartitionEvaluator::from(evaluator))
324                as Box<dyn PartitionEvaluator>
325        })
326    }
327
328    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
329        unsafe {
330            let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?;
331            let schema = df_result!((self.udf.field)(
332                &self.udf,
333                input_types,
334                field_args.name().into()
335            ))?;
336            let schema: SchemaRef = schema.into();
337
338            match schema.fields().is_empty() {
339                true => Err(DataFusionError::Execution(
340                    "Unable to retrieve field in WindowUDF via FFI".to_string(),
341                )),
342                false => Ok(schema.field(0).to_owned().into()),
343            }
344        }
345    }
346
347    fn sort_options(&self) -> Option<SortOptions> {
348        let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref().into();
349        options.map(|s| s.into())
350    }
351}
352
353#[repr(C)]
354#[derive(Debug, StableAbi, Clone)]
355#[allow(non_camel_case_types)]
356pub struct FFI_SortOptions {
357    pub descending: bool,
358    pub nulls_first: bool,
359}
360
361impl From<&SortOptions> for FFI_SortOptions {
362    fn from(value: &SortOptions) -> Self {
363        Self {
364            descending: value.descending,
365            nulls_first: value.nulls_first,
366        }
367    }
368}
369
370impl From<&FFI_SortOptions> for SortOptions {
371    fn from(value: &FFI_SortOptions) -> Self {
372        Self {
373            descending: value.descending,
374            nulls_first: value.nulls_first,
375        }
376    }
377}
378
#[cfg(test)]
#[cfg(feature = "integration-tests")]
mod tests {
    use crate::tests::create_record_batch;
    use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF};
    use arrow::array::{create_array, ArrayRef};
    use datafusion::functions_window::lead_lag::{lag_udwf, WindowShift};
    use datafusion::logical_expr::expr::Sort;
    use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl};
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Sends `original_udwf` through the FFI representation and back, yielding
    /// a `WindowUDF` whose implementation is a `ForeignWindowUDF`.
    fn create_test_foreign_udwf(
        original_udwf: impl WindowUDFImpl + 'static,
    ) -> datafusion::common::Result<WindowUDF> {
        let native = Arc::new(WindowUDF::from(original_udwf));

        let ffi_udwf = FFI_WindowUDF::from(Arc::clone(&native));

        let foreign = ForeignWindowUDF::try_from(&ffi_udwf)?;
        Ok(WindowUDF::from(foreign))
    }

    /// The UDF name must survive the native -> FFI -> native round trip.
    #[test]
    fn test_round_trip_udwf() -> datafusion::common::Result<()> {
        let native = lag_udwf();
        let expected_name = native.name().to_owned();

        // Native -> FFI
        let ffi_udwf = FFI_WindowUDF::from(Arc::clone(&native));

        // FFI -> native
        let foreign = ForeignWindowUDF::try_from(&ffi_udwf)?;
        let round_tripped = WindowUDF::from(foreign);

        assert_eq!(expected_name, round_tripped.name());
        Ok(())
    }

    /// Runs `lag` through the foreign wrapper inside a real query and checks
    /// the produced column.
    #[tokio::test]
    async fn test_lag_udwf() -> datafusion::common::Result<()> {
        let foreign_lag = create_test_foreign_udwf(WindowShift::lag())?;

        let ctx = SessionContext::default();
        let input = ctx.read_batch(create_record_batch(-5, 5))?;

        let lag_expr = foreign_lag
            .call(vec![col("a")])
            .order_by(vec![Sort::new(col("a"), true, true)])
            .build()
            .unwrap()
            .alias("lag_a");
        let df = input.select(vec![col("a"), lag_expr])?;

        df.clone().show().await?;

        let batches = df.collect().await?;
        let expected: ArrayRef =
            create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)]);

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].column(1), &expected);

        Ok(())
    }
}