1use crate::{
19 arrow_wrappers::{WrappedArray, WrappedSchema},
20 df_result, rresult, rresult_return,
21 util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
22 volatility::FFI_Volatility,
23};
24use abi_stable::{
25 std_types::{RResult, RString, RVec},
26 StableAbi,
27};
28use arrow::datatypes::{DataType, Field};
29use arrow::{
30 array::ArrayRef,
31 error::ArrowError,
32 ffi::{from_ffi, to_ffi, FFI_ArrowSchema},
33};
34use arrow_schema::FieldRef;
35use datafusion::config::ConfigOptions;
36use datafusion::logical_expr::ReturnFieldArgs;
37use datafusion::{
38 error::DataFusionError,
39 logical_expr::type_coercion::functions::data_types_with_scalar_udf,
40};
41use datafusion::{
42 error::Result,
43 logical_expr::{
44 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
45 },
46};
47use return_type_args::{
48 FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
49};
50use std::hash::{Hash, Hasher};
51use std::{ffi::c_void, sync::Arc};
52
53pub mod return_type_args;
54
/// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries.
///
/// The struct carries plain metadata (`name`, `aliases`, `volatility`,
/// `short_circuits`) plus a table of `extern "C"` function pointers. Every
/// function pointer receives the struct itself as `udf` so that the
/// implementation can reach its `private_data`. Fallible calls report errors
/// as an `RString` inside an `RResult`.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_ScalarUDF {
    /// The name of the user defined function.
    pub name: RString,

    /// Alternative names the function is known by.
    pub aliases: RVec<RString>,

    /// The volatility of the function, taken from its signature.
    pub volatility: FFI_Volatility,

    /// Computes the return type for the given argument types. Data types are
    /// exchanged as Arrow C Data Interface schemas (`WrappedSchema`).
    pub return_type: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<WrappedSchema, RString>,

    /// Computes the return field from an [`FFI_ReturnFieldArgs`]. The result is
    /// an Arrow schema describing the output field.
    pub return_field_from_args: unsafe extern "C" fn(
        udf: &Self,
        args: FFI_ReturnFieldArgs,
    )
        -> RResult<WrappedSchema, RString>,

    /// Executes the function on one batch.
    ///
    /// `args` holds one array per argument (`ForeignScalarUDF` expands scalar
    /// arguments to `num_rows` rows before calling), `arg_fields` holds one
    /// field per argument, and `return_field` describes the output field.
    /// The result is always returned as an array.
    #[allow(clippy::type_complexity)]
    pub invoke_with_args: unsafe extern "C" fn(
        udf: &Self,
        args: RVec<WrappedArray>,
        arg_fields: RVec<WrappedSchema>,
        num_rows: usize,
        return_field: WrappedSchema,
    ) -> RResult<WrappedArray, RString>,

    /// Whether the wrapped function short-circuits; copied from
    /// `ScalarUDF::short_circuits` when the struct is created.
    pub short_circuits: bool,

    /// Coerces the given argument types into types the function accepts.
    /// The implementation installed by `From<Arc<ScalarUDF>>` delegates to
    /// `data_types_with_scalar_udf`.
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Produces a new `FFI_ScalarUDF` referring to the same underlying
    /// function; used by the `Clone` impl.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Frees the state behind `private_data`; invoked by the `Drop` impl.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Opaque state owned by the side that created this struct. For structs
    /// built via `From<Arc<ScalarUDF>>` it is a boxed `ScalarUDFPrivateData`.
    /// Consumers should only hand it back through the function pointers above.
    pub private_data: *mut c_void,
}
118
// SAFETY: NOTE(review): these impls assert that an `FFI_ScalarUDF` may be sent
// to and shared between threads. In this file `private_data` points at a
// `ScalarUDFPrivateData` (an `Arc<ScalarUDF>`), which the callbacks only read
// through shared pointers, except `release`, which takes `&mut Self`. Confirm
// that any foreign provider upholds the same guarantees for its own state.
unsafe impl Send for FFI_ScalarUDF {}
unsafe impl Sync for FFI_ScalarUDF {}
121
/// Provider-side state stored (boxed) behind `FFI_ScalarUDF::private_data`.
pub struct ScalarUDFPrivateData {
    /// The wrapped function that every FFI callback dispatches to.
    pub udf: Arc<ScalarUDF>,
}
125
126unsafe extern "C" fn return_type_fn_wrapper(
127 udf: &FFI_ScalarUDF,
128 arg_types: RVec<WrappedSchema>,
129) -> RResult<WrappedSchema, RString> {
130 let private_data = udf.private_data as *const ScalarUDFPrivateData;
131 let udf = &(*private_data).udf;
132
133 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
134
135 let return_type = udf
136 .return_type(&arg_types)
137 .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from))
138 .map(WrappedSchema);
139
140 rresult!(return_type)
141}
142
143unsafe extern "C" fn return_field_from_args_fn_wrapper(
144 udf: &FFI_ScalarUDF,
145 args: FFI_ReturnFieldArgs,
146) -> RResult<WrappedSchema, RString> {
147 let private_data = udf.private_data as *const ScalarUDFPrivateData;
148 let udf = &(*private_data).udf;
149
150 let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into());
151 let args_ref: ForeignReturnFieldArgs = (&args).into();
152
153 let return_type = udf
154 .return_field_from_args((&args_ref).into())
155 .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from))
156 .map(WrappedSchema);
157
158 rresult!(return_type)
159}
160
161unsafe extern "C" fn coerce_types_fn_wrapper(
162 udf: &FFI_ScalarUDF,
163 arg_types: RVec<WrappedSchema>,
164) -> RResult<RVec<WrappedSchema>, RString> {
165 let private_data = udf.private_data as *const ScalarUDFPrivateData;
166 let udf = &(*private_data).udf;
167
168 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
169
170 let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf));
171
172 rresult!(vec_datatype_to_rvec_wrapped(&return_types))
173}
174
175unsafe extern "C" fn invoke_with_args_fn_wrapper(
176 udf: &FFI_ScalarUDF,
177 args: RVec<WrappedArray>,
178 arg_fields: RVec<WrappedSchema>,
179 number_rows: usize,
180 return_field: WrappedSchema,
181) -> RResult<WrappedArray, RString> {
182 let private_data = udf.private_data as *const ScalarUDFPrivateData;
183 let udf = &(*private_data).udf;
184
185 let args = args
186 .into_iter()
187 .map(|arr| {
188 from_ffi(arr.array, &arr.schema.0)
189 .map(|v| ColumnarValue::Array(arrow::array::make_array(v)))
190 })
191 .collect::<std::result::Result<_, _>>();
192
193 let args = rresult_return!(args);
194 let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
195
196 let arg_fields = arg_fields
197 .into_iter()
198 .map(|wrapped_field| {
199 Field::try_from(&wrapped_field.0)
200 .map(Arc::new)
201 .map_err(DataFusionError::from)
202 })
203 .collect::<Result<Vec<FieldRef>>>();
204 let arg_fields = rresult_return!(arg_fields);
205
206 let args = ScalarFunctionArgs {
207 args,
208 arg_fields,
209 number_rows,
210 return_field,
211 config_options: Arc::new(ConfigOptions::default()),
213 };
214
215 let result = rresult_return!(udf
216 .invoke_with_args(args)
217 .and_then(|r| r.to_array(number_rows)));
218
219 let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
220
221 RResult::ROk(WrappedArray {
222 array: result_array,
223 schema: WrappedSchema(result_schema),
224 })
225}
226
227unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) {
228 let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData);
229 drop(private_data);
230}
231
232unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF {
233 let private_data = udf.private_data as *const ScalarUDFPrivateData;
234 let udf_data = &(*private_data);
235
236 Arc::clone(&udf_data.udf).into()
237}
238
239impl Clone for FFI_ScalarUDF {
240 fn clone(&self) -> Self {
241 unsafe { (self.clone)(self) }
242 }
243}
244
245impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
246 fn from(udf: Arc<ScalarUDF>) -> Self {
247 let name = udf.name().into();
248 let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
249 let volatility = udf.signature().volatility.into();
250 let short_circuits = udf.short_circuits();
251
252 let private_data = Box::new(ScalarUDFPrivateData { udf });
253
254 Self {
255 name,
256 aliases,
257 volatility,
258 short_circuits,
259 invoke_with_args: invoke_with_args_fn_wrapper,
260 return_type: return_type_fn_wrapper,
261 return_field_from_args: return_field_from_args_fn_wrapper,
262 coerce_types: coerce_types_fn_wrapper,
263 clone: clone_fn_wrapper,
264 release: release_fn_wrapper,
265 private_data: Box::into_raw(private_data) as *mut c_void,
266 }
267 }
268}
269
270impl Drop for FFI_ScalarUDF {
271 fn drop(&mut self) {
272 unsafe { (self.release)(self) }
273 }
274}
275
/// Consumer-side wrapper that implements `ScalarUDFImpl` by forwarding each
/// call through an [`FFI_ScalarUDF`].
///
/// `name` and `aliases` are copied out of the FFI struct when the wrapper is
/// built. `signature` is always `Signature::user_defined` with the FFI
/// struct's volatility, and `coerce_types` forwards to the FFI struct.
#[derive(Debug)]
pub struct ForeignScalarUDF {
    /// Local copy of the function name.
    name: String,
    /// Local copy of the function aliases.
    aliases: Vec<String>,
    /// The FFI handle that every call is dispatched through.
    udf: FFI_ScalarUDF,
    /// `Signature::user_defined` carrying the FFI struct's volatility.
    signature: Signature,
}
289
// SAFETY: NOTE(review): asserted manually. The `udf` field relies on the
// `unsafe impl Send/Sync for FFI_ScalarUDF`. The other fields are owned
// `String`, `Vec<String>` and `Signature` values. These impls may be redundant
// given that; confirm before adding fields.
unsafe impl Send for ForeignScalarUDF {}
unsafe impl Sync for ForeignScalarUDF {}
292
293impl PartialEq for ForeignScalarUDF {
294 fn eq(&self, other: &Self) -> bool {
295 let Self {
296 name,
297 aliases,
298 udf,
299 signature,
300 } = self;
301 name == &other.name
302 && aliases == &other.aliases
303 && std::ptr::eq(udf, &other.udf)
304 && signature == &other.signature
305 }
306}
307impl Eq for ForeignScalarUDF {}
308
309impl Hash for ForeignScalarUDF {
310 fn hash<H: Hasher>(&self, state: &mut H) {
311 let Self {
312 name,
313 aliases,
314 udf,
315 signature,
316 } = self;
317 name.hash(state);
318 aliases.hash(state);
319 std::ptr::hash(udf, state);
320 signature.hash(state);
321 }
322}
323
324impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF {
325 type Error = DataFusionError;
326
327 fn try_from(udf: &FFI_ScalarUDF) -> Result<Self, Self::Error> {
328 let name = udf.name.to_owned().into();
329 let signature = Signature::user_defined((&udf.volatility).into());
330
331 let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
332
333 Ok(Self {
334 name,
335 udf: udf.clone(),
336 aliases,
337 signature,
338 })
339 }
340}
341
342impl ScalarUDFImpl for ForeignScalarUDF {
343 fn as_any(&self) -> &dyn std::any::Any {
344 self
345 }
346
347 fn name(&self) -> &str {
348 &self.name
349 }
350
351 fn signature(&self) -> &Signature {
352 &self.signature
353 }
354
355 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
356 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
357
358 let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) };
359
360 let result = df_result!(result);
361
362 result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from))
363 }
364
365 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
366 let args: FFI_ReturnFieldArgs = args.try_into()?;
367
368 let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) };
369
370 let result = df_result!(result);
371
372 result.and_then(|r| {
373 Field::try_from(&r.0)
374 .map(Arc::new)
375 .map_err(DataFusionError::from)
376 })
377 }
378
379 fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
380 let ScalarFunctionArgs {
381 args,
382 arg_fields,
383 number_rows,
384 return_field,
385 config_options: _config_options,
387 } = invoke_args;
388
389 let args = args
390 .into_iter()
391 .map(|v| v.to_array(number_rows))
392 .collect::<Result<Vec<_>>>()?
393 .into_iter()
394 .map(|v| {
395 to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray {
396 array: ffi_array,
397 schema: WrappedSchema(ffi_schema),
398 })
399 })
400 .collect::<std::result::Result<Vec<_>, ArrowError>>()?
401 .into();
402
403 let arg_fields_wrapped = arg_fields
404 .iter()
405 .map(FFI_ArrowSchema::try_from)
406 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
407
408 let arg_fields = arg_fields_wrapped
409 .into_iter()
410 .map(WrappedSchema)
411 .collect::<RVec<_>>();
412
413 let return_field = return_field.as_ref().clone();
414 let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?);
415
416 let result = unsafe {
417 (self.udf.invoke_with_args)(
418 &self.udf,
419 args,
420 arg_fields,
421 number_rows,
422 return_field,
423 )
424 };
425
426 let result = df_result!(result)?;
427 let result_array: ArrayRef = result.try_into()?;
428
429 Ok(ColumnarValue::Array(result_array))
430 }
431
432 fn aliases(&self) -> &[String] {
433 &self.aliases
434 }
435
436 fn short_circuits(&self) -> bool {
437 self.udf.short_circuits
438 }
439
440 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
441 unsafe {
442 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
443 let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
444 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
445 }
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Exports a built-in UDF over FFI, re-imports it, and checks that both
    /// the copied metadata and an actual FFI call agree with the original.
    #[test]
    fn test_round_trip_scalar_udf() -> Result<()> {
        let original_udf = datafusion::functions::math::abs::AbsFunc::new();
        let original_udf = Arc::new(ScalarUDF::from(original_udf));

        let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into();

        let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?;

        // Metadata copied at construction time must match the original.
        assert_eq!(original_udf.name(), foreign_udf.name());
        assert_eq!(original_udf.aliases(), foreign_udf.aliases());
        assert_eq!(original_udf.short_circuits(), foreign_udf.short_circuits());
        assert_eq!(
            original_udf.signature().volatility,
            foreign_udf.signature().volatility
        );

        // A call that crosses the FFI boundary must give the same answer.
        let arg_types = [DataType::Float64];
        assert_eq!(
            original_udf.return_type(&arg_types)?,
            foreign_udf.return_type(&arg_types)?
        );

        Ok(())
    }
}