datafusion_ffi/udf/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::c_void;
19use std::hash::{Hash, Hasher};
20use std::sync::Arc;
21
22use abi_stable::StableAbi;
23use abi_stable::std_types::{RResult, RString, RVec};
24use arrow::array::ArrayRef;
25use arrow::datatypes::{DataType, Field};
26use arrow::error::ArrowError;
27use arrow::ffi::{FFI_ArrowSchema, from_ffi, to_ffi};
28use arrow_schema::FieldRef;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::{DataFusionError, Result, internal_err};
31use datafusion_expr::type_coercion::functions::fields_with_udf;
32use datafusion_expr::{
33    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
34    Signature,
35};
36use return_type_args::{
37    FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
38};
39
40use crate::arrow_wrappers::{WrappedArray, WrappedSchema};
41use crate::util::{
42    FFIResult, rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped,
43};
44use crate::volatility::FFI_Volatility;
45use crate::{df_result, rresult, rresult_return};
46
47pub mod return_type_args;
48
/// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries.
///
/// The struct is `#[repr(C)]` and derives [`StableAbi`]. It carries plain data
/// describing the function (name, aliases, volatility, short-circuit flag)
/// together with function pointers that are invoked with a reference to this
/// struct. Cloning the struct calls `clone` and dropping it calls `release`.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_ScalarUDF {
    /// FFI equivalent to the `name` of a [`ScalarUDF`]
    pub name: RString,

    /// FFI equivalent to the `aliases` of a [`ScalarUDF`]
    pub aliases: RVec<RString>,

    /// FFI equivalent to the `volatility` of a [`ScalarUDF`]
    pub volatility: FFI_Volatility,

    /// Determines the return field of the underlying [`ScalarUDF`] for the
    /// given arguments. The field is returned as a wrapped Arrow schema.
    pub return_field_from_args: unsafe extern "C" fn(
        udf: &Self,
        args: FFI_ReturnFieldArgs,
    ) -> FFIResult<WrappedSchema>,

    /// Executes the underlying [`ScalarUDF`] and returns the result as an
    /// `FFI_ArrowArray` within an AbiStable wrapper.
    ///
    /// `args` and `arg_fields` are the argument arrays and their fields,
    /// `num_rows` is the number of rows in the batch, and `return_field` is
    /// the field describing the expected output.
    pub invoke_with_args: unsafe extern "C" fn(
        udf: &Self,
        args: RVec<WrappedArray>,
        arg_fields: RVec<WrappedSchema>,
        num_rows: usize,
        return_field: WrappedSchema,
    ) -> FFIResult<WrappedArray>,

    /// See [`ScalarUDFImpl`] for details on short_circuits
    pub short_circuits: bool,

    /// Performs type coercion. To simplify this interface, all UDFs are
    /// treated as having user defined signatures, which in turn causes
    /// `coerce_types` to be called. This should be transparent to most users,
    /// as the internal function performs the appropriate calls on the
    /// underlying [`ScalarUDF`].
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> FFIResult<RVec<WrappedSchema>>,

    /// Used to create a clone on the provider of the udf. This should
    /// only need to be called by the receiver of the udf.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Release the memory of the private data when it is no longer being used.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the udf.
    /// A [`ForeignScalarUDF`] should never attempt to access this data.
    pub private_data: *mut c_void,

    /// Utility to identify when FFI objects are accessed locally through
    /// the foreign interface. See [`crate::get_library_marker_id`] and
    /// the crate's `README.md` for more information.
    pub library_marker_id: extern "C" fn() -> usize,
}
106
// `FFI_ScalarUDF` holds a raw pointer (`private_data`), so the compiler does
// not derive `Send`/`Sync` automatically; they are asserted manually here.
// NOTE(review): soundness relies on the data behind `private_data` being safe
// to share and move between threads. For instances built in this file that is
// a boxed `ScalarUDFPrivateData` holding an `Arc<ScalarUDF>`; for instances
// received from another library it depends on the provider — confirm.
unsafe impl Send for FFI_ScalarUDF {}
unsafe impl Sync for FFI_ScalarUDF {}
109
/// Provider-side state stored behind [`FFI_ScalarUDF::private_data`].
///
/// An instance is boxed when an [`FFI_ScalarUDF`] is built from an
/// `Arc<ScalarUDF>` and is freed again by the struct's `release` callback.
pub struct ScalarUDFPrivateData {
    /// The wrapped function that the FFI callbacks delegate to.
    pub udf: Arc<ScalarUDF>,
}
113
114impl FFI_ScalarUDF {
115    fn inner(&self) -> &Arc<ScalarUDF> {
116        let private_data = self.private_data as *const ScalarUDFPrivateData;
117        unsafe { &(*private_data).udf }
118    }
119}
120
121unsafe extern "C" fn return_field_from_args_fn_wrapper(
122    udf: &FFI_ScalarUDF,
123    args: FFI_ReturnFieldArgs,
124) -> FFIResult<WrappedSchema> {
125    let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into());
126    let args_ref: ForeignReturnFieldArgs = (&args).into();
127
128    let return_type = udf
129        .inner()
130        .return_field_from_args((&args_ref).into())
131        .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from))
132        .map(WrappedSchema);
133
134    rresult!(return_type)
135}
136
137unsafe extern "C" fn coerce_types_fn_wrapper(
138    udf: &FFI_ScalarUDF,
139    arg_types: RVec<WrappedSchema>,
140) -> FFIResult<RVec<WrappedSchema>> {
141    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
142
143    let arg_fields = arg_types
144        .iter()
145        .map(|dt| Arc::new(Field::new("f", dt.clone(), true)))
146        .collect::<Vec<_>>();
147    let return_types =
148        rresult_return!(fields_with_udf(&arg_fields, udf.inner().as_ref()))
149            .into_iter()
150            .map(|f| f.data_type().to_owned())
151            .collect::<Vec<_>>();
152
153    rresult!(vec_datatype_to_rvec_wrapped(&return_types))
154}
155
156unsafe extern "C" fn invoke_with_args_fn_wrapper(
157    udf: &FFI_ScalarUDF,
158    args: RVec<WrappedArray>,
159    arg_fields: RVec<WrappedSchema>,
160    number_rows: usize,
161    return_field: WrappedSchema,
162) -> FFIResult<WrappedArray> {
163    unsafe {
164        let args = args
165            .into_iter()
166            .map(|arr| {
167                from_ffi(arr.array, &arr.schema.0)
168                    .map(|v| ColumnarValue::Array(arrow::array::make_array(v)))
169            })
170            .collect::<std::result::Result<_, _>>();
171
172        let args = rresult_return!(args);
173        let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
174
175        let arg_fields = arg_fields
176            .into_iter()
177            .map(|wrapped_field| {
178                Field::try_from(&wrapped_field.0)
179                    .map(Arc::new)
180                    .map_err(DataFusionError::from)
181            })
182            .collect::<Result<Vec<FieldRef>>>();
183        let arg_fields = rresult_return!(arg_fields);
184
185        let args = ScalarFunctionArgs {
186            args,
187            arg_fields,
188            number_rows,
189            return_field,
190            // TODO: pass config options: https://github.com/apache/datafusion/issues/17035
191            config_options: Arc::new(ConfigOptions::default()),
192        };
193
194        let result = rresult_return!(
195            udf.inner()
196                .invoke_with_args(args)
197                .and_then(|r| r.to_array(number_rows))
198        );
199
200        let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
201
202        RResult::ROk(WrappedArray {
203            array: result_array,
204            schema: WrappedSchema(result_schema),
205        })
206    }
207}
208
209unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) {
210    unsafe {
211        debug_assert!(!udf.private_data.is_null());
212        let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData);
213        drop(private_data);
214        udf.private_data = std::ptr::null_mut();
215    }
216}
217
218unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF {
219    unsafe {
220        let private_data = udf.private_data as *const ScalarUDFPrivateData;
221        let udf_data = &(*private_data);
222
223        Arc::clone(&udf_data.udf).into()
224    }
225}
226
227impl Clone for FFI_ScalarUDF {
228    fn clone(&self) -> Self {
229        unsafe { (self.clone)(self) }
230    }
231}
232
233impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
234    fn from(udf: Arc<ScalarUDF>) -> Self {
235        let name = udf.name().into();
236        let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
237        let volatility = udf.signature().volatility.into();
238        let short_circuits = udf.short_circuits();
239
240        let private_data = Box::new(ScalarUDFPrivateData { udf });
241
242        Self {
243            name,
244            aliases,
245            volatility,
246            short_circuits,
247            invoke_with_args: invoke_with_args_fn_wrapper,
248            return_field_from_args: return_field_from_args_fn_wrapper,
249            coerce_types: coerce_types_fn_wrapper,
250            clone: clone_fn_wrapper,
251            release: release_fn_wrapper,
252            private_data: Box::into_raw(private_data) as *mut c_void,
253            library_marker_id: crate::get_library_marker_id,
254        }
255    }
256}
257
258impl Drop for FFI_ScalarUDF {
259    fn drop(&mut self) {
260        unsafe { (self.release)(self) }
261    }
262}
263
/// This struct is used to access a UDF provided by a foreign
/// library across a FFI boundary.
///
/// The ForeignScalarUDF is to be used by the caller of the UDF, so it has
/// no knowledge or access to the private data. All interaction with the UDF
/// must occur through the functions defined in FFI_ScalarUDF.
#[derive(Debug)]
pub struct ForeignScalarUDF {
    /// Function name, copied from the FFI struct when this wrapper is built.
    name: String,
    /// Function aliases, copied from the FFI struct when this wrapper is built.
    aliases: Vec<String>,
    /// Clone of the FFI struct; every call into the foreign library goes
    /// through its function pointers.
    udf: FFI_ScalarUDF,
    /// User-defined signature built from the foreign function's volatility.
    signature: Signature,
}
277
// Manually asserts that `ForeignScalarUDF` can be sent and shared between
// threads.
// NOTE(review): the contained `FFI_ScalarUDF` is itself declared
// `Send + Sync` via its own unsafe impls, and the remaining fields look like
// ordinary owned data, so these impls may be redundant — confirm.
unsafe impl Send for ForeignScalarUDF {}
unsafe impl Sync for ForeignScalarUDF {}
280
281impl PartialEq for ForeignScalarUDF {
282    fn eq(&self, other: &Self) -> bool {
283        let Self {
284            name,
285            aliases,
286            udf,
287            signature,
288        } = self;
289        name == &other.name
290            && aliases == &other.aliases
291            && std::ptr::eq(udf, &other.udf)
292            && signature == &other.signature
293    }
294}
295impl Eq for ForeignScalarUDF {}
296
297impl Hash for ForeignScalarUDF {
298    fn hash<H: Hasher>(&self, state: &mut H) {
299        let Self {
300            name,
301            aliases,
302            udf,
303            signature,
304        } = self;
305        name.hash(state);
306        aliases.hash(state);
307        std::ptr::hash(udf, state);
308        signature.hash(state);
309    }
310}
311
312impl From<&FFI_ScalarUDF> for Arc<dyn ScalarUDFImpl> {
313    fn from(udf: &FFI_ScalarUDF) -> Self {
314        if (udf.library_marker_id)() == crate::get_library_marker_id() {
315            Arc::clone(udf.inner().inner())
316        } else {
317            let name = udf.name.to_owned().into();
318            let signature = Signature::user_defined((&udf.volatility).into());
319
320            let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
321
322            Arc::new(ForeignScalarUDF {
323                name,
324                udf: udf.clone(),
325                aliases,
326                signature,
327            })
328        }
329    }
330}
331
332impl ScalarUDFImpl for ForeignScalarUDF {
333    fn as_any(&self) -> &dyn std::any::Any {
334        self
335    }
336
337    fn name(&self) -> &str {
338        &self.name
339    }
340
341    fn signature(&self) -> &Signature {
342        &self.signature
343    }
344
345    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
346        internal_err!("ForeignScalarUDF implements return_field_from_args instead.")
347    }
348
349    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
350        let args: FFI_ReturnFieldArgs = args.try_into()?;
351
352        let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) };
353
354        let result = df_result!(result);
355
356        result.and_then(|r| {
357            Field::try_from(&r.0)
358                .map(Arc::new)
359                .map_err(DataFusionError::from)
360        })
361    }
362
363    fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
364        let ScalarFunctionArgs {
365            args,
366            arg_fields,
367            number_rows,
368            return_field,
369            // TODO: pass config options: https://github.com/apache/datafusion/issues/17035
370            config_options: _config_options,
371        } = invoke_args;
372
373        let args = args
374            .into_iter()
375            .map(|v| v.to_array(number_rows))
376            .collect::<Result<Vec<_>>>()?
377            .into_iter()
378            .map(|v| {
379                to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray {
380                    array: ffi_array,
381                    schema: WrappedSchema(ffi_schema),
382                })
383            })
384            .collect::<std::result::Result<Vec<_>, ArrowError>>()?
385            .into();
386
387        let arg_fields_wrapped = arg_fields
388            .iter()
389            .map(FFI_ArrowSchema::try_from)
390            .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
391
392        let arg_fields = arg_fields_wrapped
393            .into_iter()
394            .map(WrappedSchema)
395            .collect::<RVec<_>>();
396
397        let return_field = return_field.as_ref().clone();
398        let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?);
399
400        let result = unsafe {
401            (self.udf.invoke_with_args)(
402                &self.udf,
403                args,
404                arg_fields,
405                number_rows,
406                return_field,
407            )
408        };
409
410        let result = df_result!(result)?;
411        let result_array: ArrayRef = result.try_into()?;
412
413        Ok(ColumnarValue::Array(result_array))
414    }
415
416    fn aliases(&self) -> &[String] {
417        &self.aliases
418    }
419
420    fn short_circuits(&self) -> bool {
421        self.udf.short_circuits
422    }
423
424    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
425        unsafe {
426            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
427            let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
428            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
429        }
430    }
431}
432
#[cfg(test)]
mod tests {
    use datafusion::functions::math::abs::AbsFunc;

    use super::*;

    /// Builds the `abs` function used as the fixture for every test here.
    fn abs_udf() -> Arc<ScalarUDF> {
        Arc::new(ScalarUDF::from(AbsFunc::new()))
    }

    /// A UDF exported and re-imported as a foreign function keeps its name.
    #[test]
    fn test_round_trip_scalar_udf() -> Result<()> {
        let original_udf = abs_udf();

        // Swap the marker so the struct is treated as coming from another
        // library.
        let mut ffi_udf = FFI_ScalarUDF::from(Arc::clone(&original_udf));
        ffi_udf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udf: Arc<dyn ScalarUDFImpl> = (&ffi_udf).into();

        assert_eq!(original_udf.name(), foreign_udf.name());

        Ok(())
    }

    /// The library marker decides whether the original implementation or a
    /// [`ForeignScalarUDF`] wrapper is returned.
    #[test]
    fn test_ffi_udf_local_bypass() -> Result<()> {
        let mut ffi_udf = FFI_ScalarUDF::from(abs_udf());

        // Verify local libraries can be downcast to their original
        let local_impl: Arc<dyn ScalarUDFImpl> = (&ffi_udf).into();
        assert!(local_impl.as_any().downcast_ref::<AbsFunc>().is_some());

        // Verify different library markers generate foreign providers
        ffi_udf.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_impl: Arc<dyn ScalarUDFImpl> = (&ffi_udf).into();
        let as_foreign = foreign_impl.as_any().downcast_ref::<ForeignScalarUDF>();
        assert!(as_foreign.is_some());

        Ok(())
    }
}
476}