1use crate::{
19 arrow_wrappers::{WrappedArray, WrappedSchema},
20 df_result, rresult, rresult_return,
21 util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
22 volatility::FFI_Volatility,
23};
24use abi_stable::{
25 std_types::{RResult, RString, RVec},
26 StableAbi,
27};
28use arrow::datatypes::{DataType, Field};
29use arrow::{
30 array::ArrayRef,
31 error::ArrowError,
32 ffi::{from_ffi, to_ffi, FFI_ArrowSchema},
33};
34use arrow_schema::FieldRef;
35use datafusion::logical_expr::ReturnFieldArgs;
36use datafusion::{
37 error::DataFusionError,
38 logical_expr::type_coercion::functions::data_types_with_scalar_udf,
39};
40use datafusion::{
41 error::Result,
42 logical_expr::{
43 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
44 },
45};
46use return_type_args::{
47 FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
48};
49use std::hash::{DefaultHasher, Hash, Hasher};
50use std::{ffi::c_void, sync::Arc};
51
52pub mod return_type_args;
53
/// FFI-safe, `#[repr(C)]` description of a scalar UDF: plain metadata fields
/// plus a table of `extern "C"` function pointers that each receive this
/// struct as their first argument.
///
/// Field order and types are part of the binary layout; do not reorder them.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_ScalarUDF {
    /// Name of the function.
    pub name: RString,

    /// Alternative names of the function.
    pub aliases: RVec<RString>,

    /// Volatility of the function in its FFI representation.
    pub volatility: FFI_Volatility,

    /// Computes the return data type for the given argument data types.
    /// Data types are exchanged as [`WrappedSchema`] values and failures are
    /// reported as an error string.
    pub return_type: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<WrappedSchema, RString>,

    /// Computes the return field from [`FFI_ReturnFieldArgs`]. The resulting
    /// field is exchanged as a [`WrappedSchema`].
    pub return_field_from_args: unsafe extern "C" fn(
        udf: &Self,
        args: FFI_ReturnFieldArgs,
    )
        -> RResult<WrappedSchema, RString>,

    /// Executes the function. Arguments are passed as arrays together with
    /// their fields, the row count and the expected return field; the result
    /// comes back as a single array.
    #[allow(clippy::type_complexity)]
    pub invoke_with_args: unsafe extern "C" fn(
        udf: &Self,
        args: RVec<WrappedArray>,
        arg_fields: RVec<WrappedSchema>,
        num_rows: usize,
        return_field: WrappedSchema,
    ) -> RResult<WrappedArray, RString>,

    /// Value reported by `ScalarUDFImpl::short_circuits` for this function.
    pub short_circuits: bool,

    /// Coerces the given argument data types to the types the function
    /// accepts, returning the coerced list.
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Produces a new, independently owned copy of this struct. Used by the
    /// `Clone` implementation.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Releases whatever `private_data` owns. Called from the `Drop`
    /// implementation.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Opaque pointer to implementation-owned state. Instances built through
    /// `From<Arc<ScalarUDF>>` in this file store a boxed
    /// `ScalarUDFPrivateData` here.
    pub private_data: *mut c_void,
}
117
// `FFI_ScalarUDF` contains a raw `private_data` pointer, so `Send` and `Sync`
// are not derived automatically and are asserted by hand here.
// NOTE(review): soundness depends on whatever `private_data` points to being
// safe to move and share across threads — confirm for every producer.
unsafe impl Send for FFI_ScalarUDF {}
unsafe impl Sync for FFI_ScalarUDF {}
120
/// State stored behind `FFI_ScalarUDF::private_data` for UDFs created in this
/// file: the wrapped [`ScalarUDF`] that the `*_fn_wrapper` functions call into.
pub struct ScalarUDFPrivateData {
    /// The wrapped UDF, shared through an `Arc`.
    pub udf: Arc<ScalarUDF>,
}
124
125unsafe extern "C" fn return_type_fn_wrapper(
126 udf: &FFI_ScalarUDF,
127 arg_types: RVec<WrappedSchema>,
128) -> RResult<WrappedSchema, RString> {
129 let private_data = udf.private_data as *const ScalarUDFPrivateData;
130 let udf = &(*private_data).udf;
131
132 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
133
134 let return_type = udf
135 .return_type(&arg_types)
136 .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from))
137 .map(WrappedSchema);
138
139 rresult!(return_type)
140}
141
142unsafe extern "C" fn return_field_from_args_fn_wrapper(
143 udf: &FFI_ScalarUDF,
144 args: FFI_ReturnFieldArgs,
145) -> RResult<WrappedSchema, RString> {
146 let private_data = udf.private_data as *const ScalarUDFPrivateData;
147 let udf = &(*private_data).udf;
148
149 let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into());
150 let args_ref: ForeignReturnFieldArgs = (&args).into();
151
152 let return_type = udf
153 .return_field_from_args((&args_ref).into())
154 .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from))
155 .map(WrappedSchema);
156
157 rresult!(return_type)
158}
159
160unsafe extern "C" fn coerce_types_fn_wrapper(
161 udf: &FFI_ScalarUDF,
162 arg_types: RVec<WrappedSchema>,
163) -> RResult<RVec<WrappedSchema>, RString> {
164 let private_data = udf.private_data as *const ScalarUDFPrivateData;
165 let udf = &(*private_data).udf;
166
167 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
168
169 let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf));
170
171 rresult!(vec_datatype_to_rvec_wrapped(&return_types))
172}
173
174unsafe extern "C" fn invoke_with_args_fn_wrapper(
175 udf: &FFI_ScalarUDF,
176 args: RVec<WrappedArray>,
177 arg_fields: RVec<WrappedSchema>,
178 number_rows: usize,
179 return_field: WrappedSchema,
180) -> RResult<WrappedArray, RString> {
181 let private_data = udf.private_data as *const ScalarUDFPrivateData;
182 let udf = &(*private_data).udf;
183
184 let args = args
185 .into_iter()
186 .map(|arr| {
187 from_ffi(arr.array, &arr.schema.0)
188 .map(|v| ColumnarValue::Array(arrow::array::make_array(v)))
189 })
190 .collect::<std::result::Result<_, _>>();
191
192 let args = rresult_return!(args);
193 let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
194
195 let arg_fields = arg_fields
196 .into_iter()
197 .map(|wrapped_field| {
198 Field::try_from(&wrapped_field.0)
199 .map(Arc::new)
200 .map_err(DataFusionError::from)
201 })
202 .collect::<Result<Vec<FieldRef>>>();
203 let arg_fields = rresult_return!(arg_fields);
204
205 let args = ScalarFunctionArgs {
206 args,
207 arg_fields,
208 number_rows,
209 return_field,
210 };
211
212 let result = rresult_return!(udf
213 .invoke_with_args(args)
214 .and_then(|r| r.to_array(number_rows)));
215
216 let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
217
218 RResult::ROk(WrappedArray {
219 array: result_array,
220 schema: WrappedSchema(result_schema),
221 })
222}
223
224unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) {
225 let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData);
226 drop(private_data);
227}
228
229unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF {
230 let private_data = udf.private_data as *const ScalarUDFPrivateData;
231 let udf_data = &(*private_data);
232
233 Arc::clone(&udf_data.udf).into()
234}
235
236impl Clone for FFI_ScalarUDF {
237 fn clone(&self) -> Self {
238 unsafe { (self.clone)(self) }
239 }
240}
241
242impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
243 fn from(udf: Arc<ScalarUDF>) -> Self {
244 let name = udf.name().into();
245 let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
246 let volatility = udf.signature().volatility.into();
247 let short_circuits = udf.short_circuits();
248
249 let private_data = Box::new(ScalarUDFPrivateData { udf });
250
251 Self {
252 name,
253 aliases,
254 volatility,
255 short_circuits,
256 invoke_with_args: invoke_with_args_fn_wrapper,
257 return_type: return_type_fn_wrapper,
258 return_field_from_args: return_field_from_args_fn_wrapper,
259 coerce_types: coerce_types_fn_wrapper,
260 clone: clone_fn_wrapper,
261 release: release_fn_wrapper,
262 private_data: Box::into_raw(private_data) as *mut c_void,
263 }
264 }
265}
266
267impl Drop for FFI_ScalarUDF {
268 fn drop(&mut self) {
269 unsafe { (self.release)(self) }
270 }
271}
272
/// [`ScalarUDFImpl`] backed by an [`FFI_ScalarUDF`].
///
/// Metadata is copied out of the FFI struct once at construction time; the
/// remaining trait methods forward to the function pointers held in `udf`.
#[derive(Debug)]
pub struct ForeignScalarUDF {
    /// Function name copied from the FFI struct.
    name: String,
    /// Aliases copied from the FFI struct.
    aliases: Vec<String>,
    /// Clone of the FFI struct this wrapper forwards to.
    udf: FFI_ScalarUDF,
    /// User-defined signature built from the FFI struct's volatility.
    signature: Signature,
}
286
// Manual `Send`/`Sync` assertions for `ForeignScalarUDF`.
// NOTE(review): the fields are `String`, `Vec<String>`, `FFI_ScalarUDF`
// (itself manually marked `Send + Sync` in this file) and `Signature`; if
// `Signature` is `Send + Sync` these impls may be redundant — confirm.
unsafe impl Send for ForeignScalarUDF {}
unsafe impl Sync for ForeignScalarUDF {}
289
290impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF {
291 type Error = DataFusionError;
292
293 fn try_from(udf: &FFI_ScalarUDF) -> Result<Self, Self::Error> {
294 let name = udf.name.to_owned().into();
295 let signature = Signature::user_defined((&udf.volatility).into());
296
297 let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
298
299 Ok(Self {
300 name,
301 udf: udf.clone(),
302 aliases,
303 signature,
304 })
305 }
306}
307
308impl ScalarUDFImpl for ForeignScalarUDF {
309 fn as_any(&self) -> &dyn std::any::Any {
310 self
311 }
312
313 fn name(&self) -> &str {
314 &self.name
315 }
316
317 fn signature(&self) -> &Signature {
318 &self.signature
319 }
320
321 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
322 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
323
324 let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) };
325
326 let result = df_result!(result);
327
328 result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from))
329 }
330
331 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
332 let args: FFI_ReturnFieldArgs = args.try_into()?;
333
334 let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) };
335
336 let result = df_result!(result);
337
338 result.and_then(|r| {
339 Field::try_from(&r.0)
340 .map(Arc::new)
341 .map_err(DataFusionError::from)
342 })
343 }
344
345 fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
346 let ScalarFunctionArgs {
347 args,
348 arg_fields,
349 number_rows,
350 return_field,
351 } = invoke_args;
352
353 let args = args
354 .into_iter()
355 .map(|v| v.to_array(number_rows))
356 .collect::<Result<Vec<_>>>()?
357 .into_iter()
358 .map(|v| {
359 to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray {
360 array: ffi_array,
361 schema: WrappedSchema(ffi_schema),
362 })
363 })
364 .collect::<std::result::Result<Vec<_>, ArrowError>>()?
365 .into();
366
367 let arg_fields_wrapped = arg_fields
368 .iter()
369 .map(FFI_ArrowSchema::try_from)
370 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
371
372 let arg_fields = arg_fields_wrapped
373 .into_iter()
374 .map(WrappedSchema)
375 .collect::<RVec<_>>();
376
377 let return_field = return_field.as_ref().clone();
378 let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?);
379
380 let result = unsafe {
381 (self.udf.invoke_with_args)(
382 &self.udf,
383 args,
384 arg_fields,
385 number_rows,
386 return_field,
387 )
388 };
389
390 let result = df_result!(result)?;
391 let result_array: ArrayRef = result.try_into()?;
392
393 Ok(ColumnarValue::Array(result_array))
394 }
395
396 fn aliases(&self) -> &[String] {
397 &self.aliases
398 }
399
400 fn short_circuits(&self) -> bool {
401 self.udf.short_circuits
402 }
403
404 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
405 unsafe {
406 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
407 let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
408 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
409 }
410 }
411
412 fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
413 let Some(other) = other.as_any().downcast_ref::<Self>() else {
414 return false;
415 };
416 let Self {
417 name,
418 aliases,
419 udf,
420 signature,
421 } = self;
422 name == &other.name
423 && aliases == &other.aliases
424 && std::ptr::eq(udf, &other.udf)
425 && signature == &other.signature
426 }
427
428 fn hash_value(&self) -> u64 {
429 let Self {
430 name,
431 aliases,
432 udf,
433 signature,
434 } = self;
435 let mut hasher = DefaultHasher::new();
436 std::any::type_name::<Self>().hash(&mut hasher);
437 name.hash(&mut hasher);
438 aliases.hash(&mut hasher);
439 std::ptr::hash(udf, &mut hasher);
440 signature.hash(&mut hasher);
441 hasher.finish()
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps the built-in `abs` function as an `FFI_ScalarUDF`, converts it
    /// back into a `ForeignScalarUDF`, and checks that the name is preserved.
    #[test]
    fn test_round_trip_scalar_udf() -> Result<()> {
        let abs_impl = datafusion::functions::math::abs::AbsFunc::new();
        let original_udf = Arc::new(ScalarUDF::from(abs_impl));

        let local_udf = FFI_ScalarUDF::from(Arc::clone(&original_udf));
        let foreign_udf = ForeignScalarUDF::try_from(&local_udf)?;

        assert_eq!(original_udf.name(), foreign_udf.name());

        Ok(())
    }
}