datafusion_ffi/udf/mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{
19    arrow_wrappers::{WrappedArray, WrappedSchema},
20    df_result, rresult, rresult_return,
21    util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
22    volatility::FFI_Volatility,
23};
24use abi_stable::{
25    std_types::{RResult, RString, RVec},
26    StableAbi,
27};
28use arrow::datatypes::{DataType, Field};
29use arrow::{
30    array::ArrayRef,
31    error::ArrowError,
32    ffi::{from_ffi, to_ffi, FFI_ArrowSchema},
33};
34use arrow_schema::FieldRef;
35use datafusion::config::ConfigOptions;
36use datafusion::logical_expr::ReturnFieldArgs;
37use datafusion::{
38    error::DataFusionError,
39    logical_expr::type_coercion::functions::data_types_with_scalar_udf,
40};
41use datafusion::{
42    error::Result,
43    logical_expr::{
44        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
45    },
46};
47use return_type_args::{
48    FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
49};
50use std::hash::{Hash, Hasher};
51use std::{ffi::c_void, sync::Arc};
52
53pub mod return_type_args;
54
/// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries.
///
/// The provider fills in the metadata fields and the function pointers. The
/// receiver (for example through [`ForeignScalarUDF`]) only calls the function
/// pointers and never reads `private_data` itself.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_ScalarUDF {
    /// FFI equivalent to the `name` of a [`ScalarUDF`]
    pub name: RString,

    /// FFI equivalent to the `aliases` of a [`ScalarUDF`]
    pub aliases: RVec<RString>,

    /// FFI equivalent to the `volatility` of a [`ScalarUDF`]
    pub volatility: FFI_Volatility,

    /// Determines the return type of the underlying [`ScalarUDF`] based on the
    /// argument types. Argument types and the result are exchanged as Arrow C
    /// schemas; errors are reported as `RString` messages.
    pub return_type: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<WrappedSchema, RString>,

    /// Determines the return field of the underlying [`ScalarUDF`] from the
    /// full set of return-field arguments. Either this or `return_type` may be
    /// implemented on a UDF.
    pub return_field_from_args: unsafe extern "C" fn(
        udf: &Self,
        args: FFI_ReturnFieldArgs,
    )
        -> RResult<WrappedSchema, RString>,

    /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray`
    /// within an AbiStable wrapper.
    ///
    /// `args` holds one exported array per argument (the [`ForeignScalarUDF`]
    /// caller expands scalar arguments to `num_rows` rows before exporting),
    /// `arg_fields` holds the matching field descriptions, and `return_field`
    /// describes the expected output field.
    #[allow(clippy::type_complexity)]
    pub invoke_with_args: unsafe extern "C" fn(
        udf: &Self,
        args: RVec<WrappedArray>,
        arg_fields: RVec<WrappedSchema>,
        num_rows: usize,
        return_field: WrappedSchema,
    ) -> RResult<WrappedArray, RString>,

    /// See [`ScalarUDFImpl`] for details on short_circuits
    pub short_circuits: bool,

    /// Performs type coercion. To simplify this interface, all UDFs are treated
    /// as having user-defined signatures, which in turn causes `coerce_types`
    /// to be called. This call should be transparent to most users, because the
    /// internal function performs the appropriate calls on the underlying
    /// [`ScalarUDF`].
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Used to create a clone on the provider of the udf. This should
    /// only need to be called by the receiver of the udf.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Release the memory of the private data when it is no longer being used.
    /// Invoked automatically by this struct's `Drop` implementation.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the udf.
    /// A [`ForeignScalarUDF`] should never attempt to access this data.
    pub private_data: *mut c_void,
}
118
119unsafe impl Send for FFI_ScalarUDF {}
120unsafe impl Sync for FFI_ScalarUDF {}
121
/// Provider-side state behind [`FFI_ScalarUDF::private_data`].
///
/// `From<Arc<ScalarUDF>> for FFI_ScalarUDF` boxes this value and stores it as
/// a raw pointer; the `release` callback frees it again.
pub struct ScalarUDFPrivateData {
    /// The UDF that every FFI callback delegates to.
    pub udf: Arc<ScalarUDF>,
}
125
126unsafe extern "C" fn return_type_fn_wrapper(
127    udf: &FFI_ScalarUDF,
128    arg_types: RVec<WrappedSchema>,
129) -> RResult<WrappedSchema, RString> {
130    let private_data = udf.private_data as *const ScalarUDFPrivateData;
131    let udf = &(*private_data).udf;
132
133    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
134
135    let return_type = udf
136        .return_type(&arg_types)
137        .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from))
138        .map(WrappedSchema);
139
140    rresult!(return_type)
141}
142
143unsafe extern "C" fn return_field_from_args_fn_wrapper(
144    udf: &FFI_ScalarUDF,
145    args: FFI_ReturnFieldArgs,
146) -> RResult<WrappedSchema, RString> {
147    let private_data = udf.private_data as *const ScalarUDFPrivateData;
148    let udf = &(*private_data).udf;
149
150    let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into());
151    let args_ref: ForeignReturnFieldArgs = (&args).into();
152
153    let return_type = udf
154        .return_field_from_args((&args_ref).into())
155        .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from))
156        .map(WrappedSchema);
157
158    rresult!(return_type)
159}
160
161unsafe extern "C" fn coerce_types_fn_wrapper(
162    udf: &FFI_ScalarUDF,
163    arg_types: RVec<WrappedSchema>,
164) -> RResult<RVec<WrappedSchema>, RString> {
165    let private_data = udf.private_data as *const ScalarUDFPrivateData;
166    let udf = &(*private_data).udf;
167
168    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
169
170    let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf));
171
172    rresult!(vec_datatype_to_rvec_wrapped(&return_types))
173}
174
175unsafe extern "C" fn invoke_with_args_fn_wrapper(
176    udf: &FFI_ScalarUDF,
177    args: RVec<WrappedArray>,
178    arg_fields: RVec<WrappedSchema>,
179    number_rows: usize,
180    return_field: WrappedSchema,
181) -> RResult<WrappedArray, RString> {
182    let private_data = udf.private_data as *const ScalarUDFPrivateData;
183    let udf = &(*private_data).udf;
184
185    let args = args
186        .into_iter()
187        .map(|arr| {
188            from_ffi(arr.array, &arr.schema.0)
189                .map(|v| ColumnarValue::Array(arrow::array::make_array(v)))
190        })
191        .collect::<std::result::Result<_, _>>();
192
193    let args = rresult_return!(args);
194    let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
195
196    let arg_fields = arg_fields
197        .into_iter()
198        .map(|wrapped_field| {
199            Field::try_from(&wrapped_field.0)
200                .map(Arc::new)
201                .map_err(DataFusionError::from)
202        })
203        .collect::<Result<Vec<FieldRef>>>();
204    let arg_fields = rresult_return!(arg_fields);
205
206    let args = ScalarFunctionArgs {
207        args,
208        arg_fields,
209        number_rows,
210        return_field,
211        // TODO: pass config options: https://github.com/apache/datafusion/issues/17035
212        config_options: Arc::new(ConfigOptions::default()),
213    };
214
215    let result = rresult_return!(udf
216        .invoke_with_args(args)
217        .and_then(|r| r.to_array(number_rows)));
218
219    let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
220
221    RResult::ROk(WrappedArray {
222        array: result_array,
223        schema: WrappedSchema(result_schema),
224    })
225}
226
227unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) {
228    let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData);
229    drop(private_data);
230}
231
232unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF {
233    let private_data = udf.private_data as *const ScalarUDFPrivateData;
234    let udf_data = &(*private_data);
235
236    Arc::clone(&udf_data.udf).into()
237}
238
239impl Clone for FFI_ScalarUDF {
240    fn clone(&self) -> Self {
241        unsafe { (self.clone)(self) }
242    }
243}
244
245impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
246    fn from(udf: Arc<ScalarUDF>) -> Self {
247        let name = udf.name().into();
248        let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
249        let volatility = udf.signature().volatility.into();
250        let short_circuits = udf.short_circuits();
251
252        let private_data = Box::new(ScalarUDFPrivateData { udf });
253
254        Self {
255            name,
256            aliases,
257            volatility,
258            short_circuits,
259            invoke_with_args: invoke_with_args_fn_wrapper,
260            return_type: return_type_fn_wrapper,
261            return_field_from_args: return_field_from_args_fn_wrapper,
262            coerce_types: coerce_types_fn_wrapper,
263            clone: clone_fn_wrapper,
264            release: release_fn_wrapper,
265            private_data: Box::into_raw(private_data) as *mut c_void,
266        }
267    }
268}
269
270impl Drop for FFI_ScalarUDF {
271    fn drop(&mut self) {
272        unsafe { (self.release)(self) }
273    }
274}
275
/// This struct is used to access a UDF provided by a foreign
/// library across an FFI boundary.
///
/// The ForeignScalarUDF is to be used by the caller of the UDF, so it has
/// no knowledge of or access to the private data. All interaction with the UDF
/// must occur through the functions defined in FFI_ScalarUDF.
#[derive(Debug)]
pub struct ForeignScalarUDF {
    // Name copied from `FFI_ScalarUDF::name` when the wrapper is built.
    name: String,
    // Aliases copied from `FFI_ScalarUDF::aliases` when the wrapper is built.
    aliases: Vec<String>,
    // This wrapper's own clone of the FFI struct; every call is dispatched
    // through its function pointers.
    udf: FFI_ScalarUDF,
    // Always a user-defined signature built from the FFI volatility, so that
    // argument coercion is delegated to `coerce_types`.
    signature: Signature,
}
289
290unsafe impl Send for ForeignScalarUDF {}
291unsafe impl Sync for ForeignScalarUDF {}
292
293impl PartialEq for ForeignScalarUDF {
294    fn eq(&self, other: &Self) -> bool {
295        let Self {
296            name,
297            aliases,
298            udf,
299            signature,
300        } = self;
301        name == &other.name
302            && aliases == &other.aliases
303            && std::ptr::eq(udf, &other.udf)
304            && signature == &other.signature
305    }
306}
307impl Eq for ForeignScalarUDF {}
308
309impl Hash for ForeignScalarUDF {
310    fn hash<H: Hasher>(&self, state: &mut H) {
311        let Self {
312            name,
313            aliases,
314            udf,
315            signature,
316        } = self;
317        name.hash(state);
318        aliases.hash(state);
319        std::ptr::hash(udf, state);
320        signature.hash(state);
321    }
322}
323
324impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF {
325    type Error = DataFusionError;
326
327    fn try_from(udf: &FFI_ScalarUDF) -> Result<Self, Self::Error> {
328        let name = udf.name.to_owned().into();
329        let signature = Signature::user_defined((&udf.volatility).into());
330
331        let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
332
333        Ok(Self {
334            name,
335            udf: udf.clone(),
336            aliases,
337            signature,
338        })
339    }
340}
341
342impl ScalarUDFImpl for ForeignScalarUDF {
343    fn as_any(&self) -> &dyn std::any::Any {
344        self
345    }
346
347    fn name(&self) -> &str {
348        &self.name
349    }
350
351    fn signature(&self) -> &Signature {
352        &self.signature
353    }
354
355    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
356        let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
357
358        let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) };
359
360        let result = df_result!(result);
361
362        result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from))
363    }
364
365    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
366        let args: FFI_ReturnFieldArgs = args.try_into()?;
367
368        let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) };
369
370        let result = df_result!(result);
371
372        result.and_then(|r| {
373            Field::try_from(&r.0)
374                .map(Arc::new)
375                .map_err(DataFusionError::from)
376        })
377    }
378
379    fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
380        let ScalarFunctionArgs {
381            args,
382            arg_fields,
383            number_rows,
384            return_field,
385            // TODO: pass config options: https://github.com/apache/datafusion/issues/17035
386            config_options: _config_options,
387        } = invoke_args;
388
389        let args = args
390            .into_iter()
391            .map(|v| v.to_array(number_rows))
392            .collect::<Result<Vec<_>>>()?
393            .into_iter()
394            .map(|v| {
395                to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray {
396                    array: ffi_array,
397                    schema: WrappedSchema(ffi_schema),
398                })
399            })
400            .collect::<std::result::Result<Vec<_>, ArrowError>>()?
401            .into();
402
403        let arg_fields_wrapped = arg_fields
404            .iter()
405            .map(FFI_ArrowSchema::try_from)
406            .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
407
408        let arg_fields = arg_fields_wrapped
409            .into_iter()
410            .map(WrappedSchema)
411            .collect::<RVec<_>>();
412
413        let return_field = return_field.as_ref().clone();
414        let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?);
415
416        let result = unsafe {
417            (self.udf.invoke_with_args)(
418                &self.udf,
419                args,
420                arg_fields,
421                number_rows,
422                return_field,
423            )
424        };
425
426        let result = df_result!(result)?;
427        let result_array: ArrayRef = result.try_into()?;
428
429        Ok(ColumnarValue::Array(result_array))
430    }
431
432    fn aliases(&self) -> &[String] {
433        &self.aliases
434    }
435
436    fn short_circuits(&self) -> bool {
437        self.udf.short_circuits
438    }
439
440    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
441        unsafe {
442            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
443            let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
444            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
445        }
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Exports a built-in UDF through `FFI_ScalarUDF`, imports it back as a
    /// `ForeignScalarUDF` and checks that the name survives the round trip.
    #[test]
    fn test_round_trip_scalar_udf() -> Result<()> {
        let abs_udf = Arc::new(ScalarUDF::from(
            datafusion::functions::math::abs::AbsFunc::new(),
        ));

        let exported = FFI_ScalarUDF::from(Arc::clone(&abs_udf));
        let imported = ForeignScalarUDF::try_from(&exported)?;

        assert_eq!(abs_udf.name(), imported.name());

        Ok(())
    }
}