datafusion_ffi/udf/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{
19    arrow_wrappers::{WrappedArray, WrappedSchema},
20    df_result, rresult, rresult_return,
21    util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
22    volatility::FFI_Volatility,
23};
24use abi_stable::{
25    std_types::{RResult, RString, RVec},
26    StableAbi,
27};
28use arrow::datatypes::{DataType, Field};
29use arrow::{
30    array::ArrayRef,
31    error::ArrowError,
32    ffi::{from_ffi, to_ffi, FFI_ArrowSchema},
33};
34use arrow_schema::FieldRef;
35use datafusion::logical_expr::ReturnFieldArgs;
36use datafusion::{
37    error::DataFusionError,
38    logical_expr::type_coercion::functions::data_types_with_scalar_udf,
39};
40use datafusion::{
41    error::Result,
42    logical_expr::{
43        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
44    },
45};
46use return_type_args::{
47    FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
48};
49use std::hash::{DefaultHasher, Hash, Hasher};
50use std::{ffi::c_void, sync::Arc};
51
52pub mod return_type_args;
53
/// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries.
///
/// The library that provides the UDF builds this struct from an
/// `Arc<ScalarUDF>`; the library that consumes it interacts with the UDF only
/// through the function pointers below (see [`ForeignScalarUDF`]). Dropping
/// the struct invokes its `release` callback.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_ScalarUDF {
    /// FFI equivalent to the `name` of a [`ScalarUDF`]
    pub name: RString,

    /// FFI equivalent to the `aliases` of a [`ScalarUDF`]
    pub aliases: RVec<RString>,

    /// FFI equivalent to the `volatility` of a [`ScalarUDF`]
    pub volatility: FFI_Volatility,

    /// Determines the return type of the underlying [`ScalarUDF`] based on the
    /// argument types.
    pub return_type: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<WrappedSchema, RString>,

    /// Determines the return field (encoded as an Arrow C schema) of the
    /// underlying [`ScalarUDF`]. Either this or `return_type` may be
    /// implemented on a UDF.
    pub return_field_from_args: unsafe extern "C" fn(
        udf: &Self,
        args: FFI_ReturnFieldArgs,
    )
        -> RResult<WrappedSchema, RString>,

    /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray`
    /// within an AbiStable wrapper. [`ForeignScalarUDF`] expands every argument
    /// to an array of `num_rows` rows before calling this; `arg_fields` and
    /// `return_field` describe the argument and result fields.
    // The callback signature is fixed by the FFI layout, so it cannot be simplified.
    #[allow(clippy::type_complexity)]
    pub invoke_with_args: unsafe extern "C" fn(
        udf: &Self,
        args: RVec<WrappedArray>,
        arg_fields: RVec<WrappedSchema>,
        num_rows: usize,
        return_field: WrappedSchema,
    ) -> RResult<WrappedArray, RString>,

    /// See [`ScalarUDFImpl`] for details on short_circuits
    pub short_circuits: bool,

    /// Performs type coercion. To simplify this interface, all UDFs are treated as
    /// having user-defined signatures, which in turn causes `coerce_types` to be
    /// called. This call should be transparent to most users, as the internal
    /// function performs the appropriate calls on the underlying [`ScalarUDF`].
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Used to create a clone on the provider of the udf. This should
    /// only need to be called by the receiver of the udf.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Release the memory of the private data when it is no longer being used.
    /// Invoked automatically from this struct's `Drop` implementation.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the udf.
    /// A [`ForeignScalarUDF`] should never attempt to access this data.
    pub private_data: *mut c_void,
}
117
// SAFETY (NOTE(review)): the struct holds only RString/RVec data, plain
// function pointers, and an opaque `private_data` pointer. For structs built by
// `From<Arc<ScalarUDF>>`, that pointer owns a `ScalarUDFPrivateData` holding an
// `Arc<ScalarUDF>`. These impls assume `ScalarUDF` is itself safe to send and
// share across threads. Confirm this against the `ScalarUDFImpl` trait bounds
// and against any foreign provider's private data.
unsafe impl Send for FFI_ScalarUDF {}
unsafe impl Sync for FFI_ScalarUDF {}

/// Provider-side state stored behind `FFI_ScalarUDF::private_data`.
///
/// Boxed by `From<Arc<ScalarUDF>>`, read by the `*_fn_wrapper` callbacks in
/// this module, and freed by `release_fn_wrapper`.
pub struct ScalarUDFPrivateData {
    /// The UDF whose behavior is exposed through the FFI callbacks.
    pub udf: Arc<ScalarUDF>,
}
124
125unsafe extern "C" fn return_type_fn_wrapper(
126    udf: &FFI_ScalarUDF,
127    arg_types: RVec<WrappedSchema>,
128) -> RResult<WrappedSchema, RString> {
129    let private_data = udf.private_data as *const ScalarUDFPrivateData;
130    let udf = &(*private_data).udf;
131
132    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
133
134    let return_type = udf
135        .return_type(&arg_types)
136        .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from))
137        .map(WrappedSchema);
138
139    rresult!(return_type)
140}
141
142unsafe extern "C" fn return_field_from_args_fn_wrapper(
143    udf: &FFI_ScalarUDF,
144    args: FFI_ReturnFieldArgs,
145) -> RResult<WrappedSchema, RString> {
146    let private_data = udf.private_data as *const ScalarUDFPrivateData;
147    let udf = &(*private_data).udf;
148
149    let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into());
150    let args_ref: ForeignReturnFieldArgs = (&args).into();
151
152    let return_type = udf
153        .return_field_from_args((&args_ref).into())
154        .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from))
155        .map(WrappedSchema);
156
157    rresult!(return_type)
158}
159
160unsafe extern "C" fn coerce_types_fn_wrapper(
161    udf: &FFI_ScalarUDF,
162    arg_types: RVec<WrappedSchema>,
163) -> RResult<RVec<WrappedSchema>, RString> {
164    let private_data = udf.private_data as *const ScalarUDFPrivateData;
165    let udf = &(*private_data).udf;
166
167    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
168
169    let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf));
170
171    rresult!(vec_datatype_to_rvec_wrapped(&return_types))
172}
173
174unsafe extern "C" fn invoke_with_args_fn_wrapper(
175    udf: &FFI_ScalarUDF,
176    args: RVec<WrappedArray>,
177    arg_fields: RVec<WrappedSchema>,
178    number_rows: usize,
179    return_field: WrappedSchema,
180) -> RResult<WrappedArray, RString> {
181    let private_data = udf.private_data as *const ScalarUDFPrivateData;
182    let udf = &(*private_data).udf;
183
184    let args = args
185        .into_iter()
186        .map(|arr| {
187            from_ffi(arr.array, &arr.schema.0)
188                .map(|v| ColumnarValue::Array(arrow::array::make_array(v)))
189        })
190        .collect::<std::result::Result<_, _>>();
191
192    let args = rresult_return!(args);
193    let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
194
195    let arg_fields = arg_fields
196        .into_iter()
197        .map(|wrapped_field| {
198            Field::try_from(&wrapped_field.0)
199                .map(Arc::new)
200                .map_err(DataFusionError::from)
201        })
202        .collect::<Result<Vec<FieldRef>>>();
203    let arg_fields = rresult_return!(arg_fields);
204
205    let args = ScalarFunctionArgs {
206        args,
207        arg_fields,
208        number_rows,
209        return_field,
210    };
211
212    let result = rresult_return!(udf
213        .invoke_with_args(args)
214        .and_then(|r| r.to_array(number_rows)));
215
216    let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
217
218    RResult::ROk(WrappedArray {
219        array: result_array,
220        schema: WrappedSchema(result_schema),
221    })
222}
223
224unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) {
225    let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData);
226    drop(private_data);
227}
228
229unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF {
230    let private_data = udf.private_data as *const ScalarUDFPrivateData;
231    let udf_data = &(*private_data);
232
233    Arc::clone(&udf_data.udf).into()
234}
235
236impl Clone for FFI_ScalarUDF {
237    fn clone(&self) -> Self {
238        unsafe { (self.clone)(self) }
239    }
240}
241
242impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
243    fn from(udf: Arc<ScalarUDF>) -> Self {
244        let name = udf.name().into();
245        let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
246        let volatility = udf.signature().volatility.into();
247        let short_circuits = udf.short_circuits();
248
249        let private_data = Box::new(ScalarUDFPrivateData { udf });
250
251        Self {
252            name,
253            aliases,
254            volatility,
255            short_circuits,
256            invoke_with_args: invoke_with_args_fn_wrapper,
257            return_type: return_type_fn_wrapper,
258            return_field_from_args: return_field_from_args_fn_wrapper,
259            coerce_types: coerce_types_fn_wrapper,
260            clone: clone_fn_wrapper,
261            release: release_fn_wrapper,
262            private_data: Box::into_raw(private_data) as *mut c_void,
263        }
264    }
265}
266
267impl Drop for FFI_ScalarUDF {
268    fn drop(&mut self) {
269        unsafe { (self.release)(self) }
270    }
271}
272
/// This struct is used to access a UDF provided by a foreign
/// library across an FFI boundary.
///
/// The ForeignScalarUDF is to be used by the caller of the UDF, so it has
/// no knowledge of or access to the private data. All interaction with the UDF
/// must occur through the functions defined in FFI_ScalarUDF.
#[derive(Debug)]
pub struct ForeignScalarUDF {
    /// Copied from the FFI struct at construction so `name()` can return `&str`.
    name: String,
    /// Copied from the FFI struct's `aliases`.
    aliases: Vec<String>,
    /// This value's own clone of the provider's FFI struct; its `Drop` calls the
    /// provider's `release` callback.
    udf: FFI_ScalarUDF,
    /// Always `Signature::user_defined` with the provider's volatility, so that
    /// argument coercion is routed to the provider's `coerce_types`.
    signature: Signature,
}

// SAFETY (NOTE(review)): relies on `FFI_ScalarUDF` being Send + Sync (asserted
// by its own unsafe impls); the remaining fields are owned Strings and a
// `Signature`. Confirm if the provider's callbacks are not thread-safe.
unsafe impl Send for ForeignScalarUDF {}
unsafe impl Sync for ForeignScalarUDF {}
289
290impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF {
291    type Error = DataFusionError;
292
293    fn try_from(udf: &FFI_ScalarUDF) -> Result<Self, Self::Error> {
294        let name = udf.name.to_owned().into();
295        let signature = Signature::user_defined((&udf.volatility).into());
296
297        let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
298
299        Ok(Self {
300            name,
301            udf: udf.clone(),
302            aliases,
303            signature,
304        })
305    }
306}
307
308impl ScalarUDFImpl for ForeignScalarUDF {
309    fn as_any(&self) -> &dyn std::any::Any {
310        self
311    }
312
313    fn name(&self) -> &str {
314        &self.name
315    }
316
317    fn signature(&self) -> &Signature {
318        &self.signature
319    }
320
321    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
322        let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
323
324        let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) };
325
326        let result = df_result!(result);
327
328        result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from))
329    }
330
331    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
332        let args: FFI_ReturnFieldArgs = args.try_into()?;
333
334        let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) };
335
336        let result = df_result!(result);
337
338        result.and_then(|r| {
339            Field::try_from(&r.0)
340                .map(Arc::new)
341                .map_err(DataFusionError::from)
342        })
343    }
344
345    fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
346        let ScalarFunctionArgs {
347            args,
348            arg_fields,
349            number_rows,
350            return_field,
351        } = invoke_args;
352
353        let args = args
354            .into_iter()
355            .map(|v| v.to_array(number_rows))
356            .collect::<Result<Vec<_>>>()?
357            .into_iter()
358            .map(|v| {
359                to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray {
360                    array: ffi_array,
361                    schema: WrappedSchema(ffi_schema),
362                })
363            })
364            .collect::<std::result::Result<Vec<_>, ArrowError>>()?
365            .into();
366
367        let arg_fields_wrapped = arg_fields
368            .iter()
369            .map(FFI_ArrowSchema::try_from)
370            .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
371
372        let arg_fields = arg_fields_wrapped
373            .into_iter()
374            .map(WrappedSchema)
375            .collect::<RVec<_>>();
376
377        let return_field = return_field.as_ref().clone();
378        let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?);
379
380        let result = unsafe {
381            (self.udf.invoke_with_args)(
382                &self.udf,
383                args,
384                arg_fields,
385                number_rows,
386                return_field,
387            )
388        };
389
390        let result = df_result!(result)?;
391        let result_array: ArrayRef = result.try_into()?;
392
393        Ok(ColumnarValue::Array(result_array))
394    }
395
396    fn aliases(&self) -> &[String] {
397        &self.aliases
398    }
399
400    fn short_circuits(&self) -> bool {
401        self.udf.short_circuits
402    }
403
404    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
405        unsafe {
406            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
407            let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
408            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
409        }
410    }
411
412    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
413        let Some(other) = other.as_any().downcast_ref::<Self>() else {
414            return false;
415        };
416        let Self {
417            name,
418            aliases,
419            udf,
420            signature,
421        } = self;
422        name == &other.name
423            && aliases == &other.aliases
424            && std::ptr::eq(udf, &other.udf)
425            && signature == &other.signature
426    }
427
428    fn hash_value(&self) -> u64 {
429        let Self {
430            name,
431            aliases,
432            udf,
433            signature,
434        } = self;
435        let mut hasher = DefaultHasher::new();
436        std::any::type_name::<Self>().hash(&mut hasher);
437        name.hash(&mut hasher);
438        aliases.hash(&mut hasher);
439        std::ptr::hash(udf, &mut hasher);
440        signature.hash(&mut hasher);
441        hasher.finish()
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Exports a built-in UDF through `FFI_ScalarUDF`, imports it again as a
    /// `ForeignScalarUDF`, and checks that the name survives the round trip.
    #[test]
    fn test_round_trip_scalar_udf() -> Result<()> {
        let original_udf = Arc::new(ScalarUDF::from(
            datafusion::functions::math::abs::AbsFunc::new(),
        ));

        // Provider side: export over FFI.
        let local_udf = FFI_ScalarUDF::from(Arc::clone(&original_udf));

        // Consumer side: import from FFI.
        let foreign_udf = ForeignScalarUDF::try_from(&local_udf)?;

        assert_eq!(original_udf.name(), foreign_udf.name());
        Ok(())
    }
}