1use std::{ffi::c_void, sync::Arc};
19
20use abi_stable::{
21 std_types::{RResult, RString, RVec},
22 StableAbi,
23};
24use arrow::datatypes::DataType;
25use arrow::{
26 array::ArrayRef,
27 error::ArrowError,
28 ffi::{from_ffi, to_ffi, FFI_ArrowSchema},
29};
30use datafusion::{
31 error::DataFusionError,
32 logical_expr::type_coercion::functions::data_types_with_scalar_udf,
33};
34use datafusion::{
35 error::Result,
36 logical_expr::{
37 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
38 },
39};
40
41use crate::{
42 arrow_wrappers::{WrappedArray, WrappedSchema},
43 df_result, rresult, rresult_return,
44 util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
45 volatility::FFI_Volatility,
46};
47
/// A stable, C-layout struct that exposes a [`ScalarUDF`] across an FFI
/// boundary.
///
/// Every operation is reached through the function pointers below. Each pointer
/// takes the struct itself as its first argument so the producer can recover
/// its own state from `private_data`.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_ScalarUDF {
    /// Name of the function.
    pub name: RString,

    /// Alternative names under which the function may be invoked.
    pub aliases: RVec<RString>,

    /// Volatility of the function, taken from the wrapped UDF's signature.
    pub volatility: FFI_Volatility,

    /// Computes the function's return type for the given argument types.
    /// Data types cross the boundary as exported Arrow schemas.
    pub return_type: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<WrappedSchema, RString>,

    /// Invokes the function on the given argument arrays.
    ///
    /// `num_rows` is the number of rows to produce and `return_type` is the
    /// expected output type. The result is returned as an exported Arrow array.
    pub invoke_with_args: unsafe extern "C" fn(
        udf: &Self,
        args: RVec<WrappedArray>,
        num_rows: usize,
        return_type: WrappedSchema,
    ) -> RResult<WrappedArray, RString>,

    /// Mirrors `short_circuits()` of the wrapped UDF.
    pub short_circuits: bool,

    /// Coerces the given argument types into the types the function should be
    /// called with. Types cross the boundary as exported Arrow schemas.
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Produces a new, independently owned copy of this struct.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Frees the state referenced by `private_data`. Invoked from `Drop`.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Opaque pointer to producer-side state. Only the producer's function
    /// pointers above interpret it.
    pub private_data: *mut c_void,
}
101
// SAFETY: apart from `private_data`, the fields are owned abi_stable values and
// function pointers.
// In this module `private_data` is only ever set (by the `From<Arc<ScalarUDF>>`
// impl) to a boxed `ScalarUDFPrivateData`, which holds an `Arc<ScalarUDF>`.
// Sharing it across threads therefore relies on `ScalarUDF` being
// `Send + Sync`.
// NOTE(review): producers in other libraries must uphold the same guarantee
// for their private data — confirm.
unsafe impl Send for FFI_ScalarUDF {}
unsafe impl Sync for FFI_ScalarUDF {}
104
/// Producer-side state stored behind `FFI_ScalarUDF::private_data`.
///
/// It is allocated with `Box::into_raw` when the FFI struct is built from an
/// `Arc<ScalarUDF>`. It is reclaimed with `Box::from_raw` by the struct's
/// `release` function.
pub struct ScalarUDFPrivateData {
    /// The UDF that all FFI calls are forwarded to.
    pub udf: Arc<ScalarUDF>,
}
108
109unsafe extern "C" fn return_type_fn_wrapper(
110 udf: &FFI_ScalarUDF,
111 arg_types: RVec<WrappedSchema>,
112) -> RResult<WrappedSchema, RString> {
113 let private_data = udf.private_data as *const ScalarUDFPrivateData;
114 let udf = &(*private_data).udf;
115
116 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
117
118 let return_type = udf
119 .return_type(&arg_types)
120 .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from))
121 .map(WrappedSchema);
122
123 rresult!(return_type)
124}
125
126unsafe extern "C" fn coerce_types_fn_wrapper(
127 udf: &FFI_ScalarUDF,
128 arg_types: RVec<WrappedSchema>,
129) -> RResult<RVec<WrappedSchema>, RString> {
130 let private_data = udf.private_data as *const ScalarUDFPrivateData;
131 let udf = &(*private_data).udf;
132
133 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
134
135 let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf));
136
137 rresult!(vec_datatype_to_rvec_wrapped(&return_types))
138}
139
140unsafe extern "C" fn invoke_with_args_fn_wrapper(
141 udf: &FFI_ScalarUDF,
142 args: RVec<WrappedArray>,
143 number_rows: usize,
144 return_type: WrappedSchema,
145) -> RResult<WrappedArray, RString> {
146 let private_data = udf.private_data as *const ScalarUDFPrivateData;
147 let udf = &(*private_data).udf;
148
149 let args = args
150 .into_iter()
151 .map(|arr| {
152 from_ffi(arr.array, &arr.schema.0)
153 .map(|v| ColumnarValue::Array(arrow::array::make_array(v)))
154 })
155 .collect::<std::result::Result<_, _>>();
156
157 let args = rresult_return!(args);
158 let return_type = rresult_return!(DataType::try_from(&return_type.0));
159
160 let args = ScalarFunctionArgs {
161 args,
162 number_rows,
163 return_type: &return_type,
164 };
165
166 let result = rresult_return!(udf
167 .invoke_with_args(args)
168 .and_then(|r| r.to_array(number_rows)));
169
170 let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
171
172 RResult::ROk(WrappedArray {
173 array: result_array,
174 schema: WrappedSchema(result_schema),
175 })
176}
177
178unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) {
179 let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData);
180 drop(private_data);
181}
182
183unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF {
184 let private_data = udf.private_data as *const ScalarUDFPrivateData;
185 let udf_data = &(*private_data);
186
187 Arc::clone(&udf_data.udf).into()
188}
189
190impl Clone for FFI_ScalarUDF {
191 fn clone(&self) -> Self {
192 unsafe { (self.clone)(self) }
193 }
194}
195
196impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
197 fn from(udf: Arc<ScalarUDF>) -> Self {
198 let name = udf.name().into();
199 let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
200 let volatility = udf.signature().volatility.into();
201 let short_circuits = udf.short_circuits();
202
203 let private_data = Box::new(ScalarUDFPrivateData { udf });
204
205 Self {
206 name,
207 aliases,
208 volatility,
209 short_circuits,
210 invoke_with_args: invoke_with_args_fn_wrapper,
211 return_type: return_type_fn_wrapper,
212 coerce_types: coerce_types_fn_wrapper,
213 clone: clone_fn_wrapper,
214 release: release_fn_wrapper,
215 private_data: Box::into_raw(private_data) as *mut c_void,
216 }
217 }
218}
219
220impl Drop for FFI_ScalarUDF {
221 fn drop(&mut self) {
222 unsafe { (self.release)(self) }
223 }
224}
225
/// Consumer-side wrapper that implements [`ScalarUDFImpl`] by forwarding calls
/// through an [`FFI_ScalarUDF`]'s function pointers.
///
/// The name and aliases are copied out of the FFI struct on construction.
/// The signature is always `Signature::user_defined`, so argument coercion goes
/// through `coerce_types`, which forwards to the producer.
#[derive(Debug)]
pub struct ForeignScalarUDF {
    /// Function name, copied from the FFI struct.
    name: String,
    /// Aliases, copied from the FFI struct.
    aliases: Vec<String>,
    /// Owned copy of the FFI struct; dropping it invokes its `release` function.
    udf: FFI_ScalarUDF,
    /// User-defined signature built from the FFI struct's volatility.
    signature: Signature,
}
239
// SAFETY: `FFI_ScalarUDF` is declared `Send + Sync` in this module. The other
// fields are an owned `String`, a `Vec<String>`, and a `Signature`.
// NOTE(review): these impls may be redundant if `Signature` is itself
// `Send + Sync` — confirm before removing them.
unsafe impl Send for ForeignScalarUDF {}
unsafe impl Sync for ForeignScalarUDF {}
242
243impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF {
244 type Error = DataFusionError;
245
246 fn try_from(udf: &FFI_ScalarUDF) -> Result<Self, Self::Error> {
247 let name = udf.name.to_owned().into();
248 let signature = Signature::user_defined((&udf.volatility).into());
249
250 let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
251
252 Ok(Self {
253 name,
254 udf: udf.clone(),
255 aliases,
256 signature,
257 })
258 }
259}
260
261impl ScalarUDFImpl for ForeignScalarUDF {
262 fn as_any(&self) -> &dyn std::any::Any {
263 self
264 }
265
266 fn name(&self) -> &str {
267 &self.name
268 }
269
270 fn signature(&self) -> &Signature {
271 &self.signature
272 }
273
274 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
275 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
276
277 let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) };
278
279 let result = df_result!(result);
280
281 result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from))
282 }
283
284 fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
285 let ScalarFunctionArgs {
286 args,
287 number_rows,
288 return_type,
289 } = invoke_args;
290
291 let args = args
292 .into_iter()
293 .map(|v| v.to_array(number_rows))
294 .collect::<Result<Vec<_>>>()?
295 .into_iter()
296 .map(|v| {
297 to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray {
298 array: ffi_array,
299 schema: WrappedSchema(ffi_schema),
300 })
301 })
302 .collect::<std::result::Result<Vec<_>, ArrowError>>()?
303 .into();
304
305 let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?);
306
307 let result = unsafe {
308 (self.udf.invoke_with_args)(&self.udf, args, number_rows, return_type)
309 };
310
311 let result = df_result!(result)?;
312 let result_array: ArrayRef = result.try_into()?;
313
314 Ok(ColumnarValue::Array(result_array))
315 }
316
317 fn aliases(&self) -> &[String] {
318 &self.aliases
319 }
320
321 fn short_circuits(&self) -> bool {
322 self.udf.short_circuits
323 }
324
325 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
326 unsafe {
327 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
328 let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
329 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
330 }
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Float64Array};

    /// Converting a UDF to FFI and back must preserve its metadata.
    #[test]
    fn test_round_trip_scalar_udf() -> Result<()> {
        let original_udf = datafusion::functions::math::abs::AbsFunc::new();
        let original_udf = Arc::new(ScalarUDF::from(original_udf));

        let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into();

        let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?;

        // `assert_eq!` reports both values on failure, unlike `assert!(a == b)`.
        assert_eq!(original_udf.name(), foreign_udf.name());
        assert_eq!(original_udf.aliases(), foreign_udf.aliases());
        assert_eq!(original_udf.short_circuits(), foreign_udf.short_circuits());

        Ok(())
    }

    /// Return-type resolution and invocation must work through the FFI boundary.
    #[test]
    fn test_foreign_scalar_udf_invoke() -> Result<()> {
        let original_udf = Arc::new(ScalarUDF::from(
            datafusion::functions::math::abs::AbsFunc::new(),
        ));
        let local_udf: FFI_ScalarUDF = original_udf.into();
        let foreign_udf = ForeignScalarUDF::try_from(&local_udf)?;

        let return_type = foreign_udf.return_type(&[DataType::Float64])?;
        assert_eq!(return_type, DataType::Float64);

        let input: ArrayRef = Arc::new(Float64Array::from(vec![-1.5, 2.0]));
        let result = foreign_udf.invoke_with_args(ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input)],
            number_rows: 2,
            return_type: &return_type,
        })?;

        let ColumnarValue::Array(result) = result else {
            panic!("ForeignScalarUDF always returns an array");
        };
        let result = result
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("abs of Float64 should produce a Float64Array");
        assert_eq!(result.values().to_vec(), vec![1.5, 2.0]);

        Ok(())
    }
}