datafusion_ffi/
udf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{ffi::c_void, sync::Arc};
19
20use abi_stable::{
21    std_types::{RResult, RString, RVec},
22    StableAbi,
23};
24use arrow::datatypes::DataType;
25use arrow::{
26    array::ArrayRef,
27    error::ArrowError,
28    ffi::{from_ffi, to_ffi, FFI_ArrowSchema},
29};
30use datafusion::{
31    error::DataFusionError,
32    logical_expr::type_coercion::functions::data_types_with_scalar_udf,
33};
34use datafusion::{
35    error::Result,
36    logical_expr::{
37        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
38    },
39};
40
41use crate::{
42    arrow_wrappers::{WrappedArray, WrappedSchema},
43    df_result, rresult, rresult_return,
44    util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
45    volatility::FFI_Volatility,
46};
47
48/// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries.
49#[repr(C)]
50#[derive(Debug, StableAbi)]
51#[allow(non_camel_case_types)]
52pub struct FFI_ScalarUDF {
53    /// FFI equivalent to the `name` of a [`ScalarUDF`]
54    pub name: RString,
55
56    /// FFI equivalent to the `aliases` of a [`ScalarUDF`]
57    pub aliases: RVec<RString>,
58
59    /// FFI equivalent to the `volatility` of a [`ScalarUDF`]
60    pub volatility: FFI_Volatility,
61
62    /// Determines the return type of the underlying [`ScalarUDF`] based on the
63    /// argument types.
64    pub return_type: unsafe extern "C" fn(
65        udf: &Self,
66        arg_types: RVec<WrappedSchema>,
67    ) -> RResult<WrappedSchema, RString>,
68
69    /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray`
70    /// within an AbiStable wrapper.
71    pub invoke_with_args: unsafe extern "C" fn(
72        udf: &Self,
73        args: RVec<WrappedArray>,
74        num_rows: usize,
75        return_type: WrappedSchema,
76    ) -> RResult<WrappedArray, RString>,
77
78    /// See [`ScalarUDFImpl`] for details on short_circuits
79    pub short_circuits: bool,
80
81    /// Performs type coersion. To simply this interface, all UDFs are treated as having
82    /// user defined signatures, which will in turn call coerce_types to be called. This
83    /// call should be transparent to most users as the internal function performs the
84    /// appropriate calls on the underlying [`ScalarUDF`]
85    pub coerce_types: unsafe extern "C" fn(
86        udf: &Self,
87        arg_types: RVec<WrappedSchema>,
88    ) -> RResult<RVec<WrappedSchema>, RString>,
89
90    /// Used to create a clone on the provider of the udf. This should
91    /// only need to be called by the receiver of the udf.
92    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,
93
94    /// Release the memory of the private data when it is no longer being used.
95    pub release: unsafe extern "C" fn(udf: &mut Self),
96
97    /// Internal data. This is only to be accessed by the provider of the udf.
98    /// A [`ForeignScalarUDF`] should never attempt to access this data.
99    pub private_data: *mut c_void,
100}
101
102unsafe impl Send for FFI_ScalarUDF {}
103unsafe impl Sync for FFI_ScalarUDF {}
104
105pub struct ScalarUDFPrivateData {
106    pub udf: Arc<ScalarUDF>,
107}
108
109unsafe extern "C" fn return_type_fn_wrapper(
110    udf: &FFI_ScalarUDF,
111    arg_types: RVec<WrappedSchema>,
112) -> RResult<WrappedSchema, RString> {
113    let private_data = udf.private_data as *const ScalarUDFPrivateData;
114    let udf = &(*private_data).udf;
115
116    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
117
118    let return_type = udf
119        .return_type(&arg_types)
120        .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from))
121        .map(WrappedSchema);
122
123    rresult!(return_type)
124}
125
126unsafe extern "C" fn coerce_types_fn_wrapper(
127    udf: &FFI_ScalarUDF,
128    arg_types: RVec<WrappedSchema>,
129) -> RResult<RVec<WrappedSchema>, RString> {
130    let private_data = udf.private_data as *const ScalarUDFPrivateData;
131    let udf = &(*private_data).udf;
132
133    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
134
135    let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf));
136
137    rresult!(vec_datatype_to_rvec_wrapped(&return_types))
138}
139
140unsafe extern "C" fn invoke_with_args_fn_wrapper(
141    udf: &FFI_ScalarUDF,
142    args: RVec<WrappedArray>,
143    number_rows: usize,
144    return_type: WrappedSchema,
145) -> RResult<WrappedArray, RString> {
146    let private_data = udf.private_data as *const ScalarUDFPrivateData;
147    let udf = &(*private_data).udf;
148
149    let args = args
150        .into_iter()
151        .map(|arr| {
152            from_ffi(arr.array, &arr.schema.0)
153                .map(|v| ColumnarValue::Array(arrow::array::make_array(v)))
154        })
155        .collect::<std::result::Result<_, _>>();
156
157    let args = rresult_return!(args);
158    let return_type = rresult_return!(DataType::try_from(&return_type.0));
159
160    let args = ScalarFunctionArgs {
161        args,
162        number_rows,
163        return_type: &return_type,
164    };
165
166    let result = rresult_return!(udf
167        .invoke_with_args(args)
168        .and_then(|r| r.to_array(number_rows)));
169
170    let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
171
172    RResult::ROk(WrappedArray {
173        array: result_array,
174        schema: WrappedSchema(result_schema),
175    })
176}
177
178unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) {
179    let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData);
180    drop(private_data);
181}
182
183unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF {
184    let private_data = udf.private_data as *const ScalarUDFPrivateData;
185    let udf_data = &(*private_data);
186
187    Arc::clone(&udf_data.udf).into()
188}
189
190impl Clone for FFI_ScalarUDF {
191    fn clone(&self) -> Self {
192        unsafe { (self.clone)(self) }
193    }
194}
195
196impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
197    fn from(udf: Arc<ScalarUDF>) -> Self {
198        let name = udf.name().into();
199        let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
200        let volatility = udf.signature().volatility.into();
201        let short_circuits = udf.short_circuits();
202
203        let private_data = Box::new(ScalarUDFPrivateData { udf });
204
205        Self {
206            name,
207            aliases,
208            volatility,
209            short_circuits,
210            invoke_with_args: invoke_with_args_fn_wrapper,
211            return_type: return_type_fn_wrapper,
212            coerce_types: coerce_types_fn_wrapper,
213            clone: clone_fn_wrapper,
214            release: release_fn_wrapper,
215            private_data: Box::into_raw(private_data) as *mut c_void,
216        }
217    }
218}
219
220impl Drop for FFI_ScalarUDF {
221    fn drop(&mut self) {
222        unsafe { (self.release)(self) }
223    }
224}
225
226/// This struct is used to access an UDF provided by a foreign
227/// library across a FFI boundary.
228///
229/// The ForeignScalarUDF is to be used by the caller of the UDF, so it has
230/// no knowledge or access to the private data. All interaction with the UDF
231/// must occur through the functions defined in FFI_ScalarUDF.
232#[derive(Debug)]
233pub struct ForeignScalarUDF {
234    name: String,
235    aliases: Vec<String>,
236    udf: FFI_ScalarUDF,
237    signature: Signature,
238}
239
240unsafe impl Send for ForeignScalarUDF {}
241unsafe impl Sync for ForeignScalarUDF {}
242
243impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF {
244    type Error = DataFusionError;
245
246    fn try_from(udf: &FFI_ScalarUDF) -> Result<Self, Self::Error> {
247        let name = udf.name.to_owned().into();
248        let signature = Signature::user_defined((&udf.volatility).into());
249
250        let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
251
252        Ok(Self {
253            name,
254            udf: udf.clone(),
255            aliases,
256            signature,
257        })
258    }
259}
260
261impl ScalarUDFImpl for ForeignScalarUDF {
262    fn as_any(&self) -> &dyn std::any::Any {
263        self
264    }
265
266    fn name(&self) -> &str {
267        &self.name
268    }
269
270    fn signature(&self) -> &Signature {
271        &self.signature
272    }
273
274    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
275        let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
276
277        let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) };
278
279        let result = df_result!(result);
280
281        result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from))
282    }
283
284    fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
285        let ScalarFunctionArgs {
286            args,
287            number_rows,
288            return_type,
289        } = invoke_args;
290
291        let args = args
292            .into_iter()
293            .map(|v| v.to_array(number_rows))
294            .collect::<Result<Vec<_>>>()?
295            .into_iter()
296            .map(|v| {
297                to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray {
298                    array: ffi_array,
299                    schema: WrappedSchema(ffi_schema),
300                })
301            })
302            .collect::<std::result::Result<Vec<_>, ArrowError>>()?
303            .into();
304
305        let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?);
306
307        let result = unsafe {
308            (self.udf.invoke_with_args)(&self.udf, args, number_rows, return_type)
309        };
310
311        let result = df_result!(result)?;
312        let result_array: ArrayRef = result.try_into()?;
313
314        Ok(ColumnarValue::Array(result_array))
315    }
316
317    fn aliases(&self) -> &[String] {
318        &self.aliases
319    }
320
321    fn short_circuits(&self) -> bool {
322        self.udf.short_circuits
323    }
324
325    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
326        unsafe {
327            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
328            let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
329            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
330        }
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Exports a built-in UDF through `FFI_ScalarUDF` and re-imports it as a
    /// `ForeignScalarUDF`. The metadata and the FFI return-type path must
    /// survive the round trip.
    #[test]
    fn test_round_trip_scalar_udf() -> Result<()> {
        let original_udf = datafusion::functions::math::abs::AbsFunc::new();
        let original_udf = Arc::new(ScalarUDF::from(original_udf));

        let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into();

        let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?;

        assert_eq!(original_udf.name(), foreign_udf.name());
        assert_eq!(original_udf.aliases(), foreign_udf.aliases());
        assert_eq!(
            original_udf.signature().volatility,
            foreign_udf.signature().volatility
        );

        // Exercises `return_type` across the FFI boundary, including the
        // `FFI_ArrowSchema` encode/decode of both argument and result types.
        let arg_types = [DataType::Float64];
        assert_eq!(
            original_udf.return_type(&arg_types)?,
            foreign_udf.return_type(&arg_types)?
        );

        Ok(())
    }
}