1use std::ffi::c_void;
19use std::hash::{Hash, Hasher};
20use std::sync::Arc;
21
22use abi_stable::StableAbi;
23use abi_stable::std_types::{RResult, RString, RVec};
24use arrow::array::ArrayRef;
25use arrow::datatypes::{DataType, Field};
26use arrow::error::ArrowError;
27use arrow::ffi::{FFI_ArrowSchema, from_ffi, to_ffi};
28use arrow_schema::FieldRef;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::{DataFusionError, Result, internal_err};
31use datafusion_expr::type_coercion::functions::fields_with_udf;
32use datafusion_expr::{
33 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
34 Signature,
35};
36use return_type_args::{
37 FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
38};
39
40use crate::arrow_wrappers::{WrappedArray, WrappedSchema};
41use crate::util::{
42 FFIResult, rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped,
43};
44use crate::volatility::FFI_Volatility;
45use crate::{df_result, rresult, rresult_return};
46
47pub mod return_type_args;
48
49#[repr(C)]
51#[derive(Debug, StableAbi)]
52pub struct FFI_ScalarUDF {
53 pub name: RString,
55
56 pub aliases: RVec<RString>,
58
59 pub volatility: FFI_Volatility,
61
62 pub return_field_from_args: unsafe extern "C" fn(
64 udf: &Self,
65 args: FFI_ReturnFieldArgs,
66 ) -> FFIResult<WrappedSchema>,
67
68 pub invoke_with_args: unsafe extern "C" fn(
71 udf: &Self,
72 args: RVec<WrappedArray>,
73 arg_fields: RVec<WrappedSchema>,
74 num_rows: usize,
75 return_field: WrappedSchema,
76 ) -> FFIResult<WrappedArray>,
77
78 pub short_circuits: bool,
80
81 pub coerce_types: unsafe extern "C" fn(
86 udf: &Self,
87 arg_types: RVec<WrappedSchema>,
88 ) -> FFIResult<RVec<WrappedSchema>>,
89
90 pub clone: unsafe extern "C" fn(udf: &Self) -> Self,
93
94 pub release: unsafe extern "C" fn(udf: &mut Self),
96
97 pub private_data: *mut c_void,
100
101 pub library_marker_id: extern "C" fn() -> usize,
105}
106
107unsafe impl Send for FFI_ScalarUDF {}
108unsafe impl Sync for FFI_ScalarUDF {}
109
110pub struct ScalarUDFPrivateData {
111 pub udf: Arc<ScalarUDF>,
112}
113
114impl FFI_ScalarUDF {
115 fn inner(&self) -> &Arc<ScalarUDF> {
116 let private_data = self.private_data as *const ScalarUDFPrivateData;
117 unsafe { &(*private_data).udf }
118 }
119}
120
121unsafe extern "C" fn return_field_from_args_fn_wrapper(
122 udf: &FFI_ScalarUDF,
123 args: FFI_ReturnFieldArgs,
124) -> FFIResult<WrappedSchema> {
125 let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into());
126 let args_ref: ForeignReturnFieldArgs = (&args).into();
127
128 let return_type = udf
129 .inner()
130 .return_field_from_args((&args_ref).into())
131 .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from))
132 .map(WrappedSchema);
133
134 rresult!(return_type)
135}
136
137unsafe extern "C" fn coerce_types_fn_wrapper(
138 udf: &FFI_ScalarUDF,
139 arg_types: RVec<WrappedSchema>,
140) -> FFIResult<RVec<WrappedSchema>> {
141 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
142
143 let arg_fields = arg_types
144 .iter()
145 .map(|dt| Arc::new(Field::new("f", dt.clone(), true)))
146 .collect::<Vec<_>>();
147 let return_types =
148 rresult_return!(fields_with_udf(&arg_fields, udf.inner().as_ref()))
149 .into_iter()
150 .map(|f| f.data_type().to_owned())
151 .collect::<Vec<_>>();
152
153 rresult!(vec_datatype_to_rvec_wrapped(&return_types))
154}
155
156unsafe extern "C" fn invoke_with_args_fn_wrapper(
157 udf: &FFI_ScalarUDF,
158 args: RVec<WrappedArray>,
159 arg_fields: RVec<WrappedSchema>,
160 number_rows: usize,
161 return_field: WrappedSchema,
162) -> FFIResult<WrappedArray> {
163 unsafe {
164 let args = args
165 .into_iter()
166 .map(|arr| {
167 from_ffi(arr.array, &arr.schema.0)
168 .map(|v| ColumnarValue::Array(arrow::array::make_array(v)))
169 })
170 .collect::<std::result::Result<_, _>>();
171
172 let args = rresult_return!(args);
173 let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
174
175 let arg_fields = arg_fields
176 .into_iter()
177 .map(|wrapped_field| {
178 Field::try_from(&wrapped_field.0)
179 .map(Arc::new)
180 .map_err(DataFusionError::from)
181 })
182 .collect::<Result<Vec<FieldRef>>>();
183 let arg_fields = rresult_return!(arg_fields);
184
185 let args = ScalarFunctionArgs {
186 args,
187 arg_fields,
188 number_rows,
189 return_field,
190 config_options: Arc::new(ConfigOptions::default()),
192 };
193
194 let result = rresult_return!(
195 udf.inner()
196 .invoke_with_args(args)
197 .and_then(|r| r.to_array(number_rows))
198 );
199
200 let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
201
202 RResult::ROk(WrappedArray {
203 array: result_array,
204 schema: WrappedSchema(result_schema),
205 })
206 }
207}
208
209unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) {
210 unsafe {
211 debug_assert!(!udf.private_data.is_null());
212 let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData);
213 drop(private_data);
214 udf.private_data = std::ptr::null_mut();
215 }
216}
217
218unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF {
219 unsafe {
220 let private_data = udf.private_data as *const ScalarUDFPrivateData;
221 let udf_data = &(*private_data);
222
223 Arc::clone(&udf_data.udf).into()
224 }
225}
226
227impl Clone for FFI_ScalarUDF {
228 fn clone(&self) -> Self {
229 unsafe { (self.clone)(self) }
230 }
231}
232
233impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
234 fn from(udf: Arc<ScalarUDF>) -> Self {
235 let name = udf.name().into();
236 let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
237 let volatility = udf.signature().volatility.into();
238 let short_circuits = udf.short_circuits();
239
240 let private_data = Box::new(ScalarUDFPrivateData { udf });
241
242 Self {
243 name,
244 aliases,
245 volatility,
246 short_circuits,
247 invoke_with_args: invoke_with_args_fn_wrapper,
248 return_field_from_args: return_field_from_args_fn_wrapper,
249 coerce_types: coerce_types_fn_wrapper,
250 clone: clone_fn_wrapper,
251 release: release_fn_wrapper,
252 private_data: Box::into_raw(private_data) as *mut c_void,
253 library_marker_id: crate::get_library_marker_id,
254 }
255 }
256}
257
258impl Drop for FFI_ScalarUDF {
259 fn drop(&mut self) {
260 unsafe { (self.release)(self) }
261 }
262}
263
264#[derive(Debug)]
271pub struct ForeignScalarUDF {
272 name: String,
273 aliases: Vec<String>,
274 udf: FFI_ScalarUDF,
275 signature: Signature,
276}
277
278unsafe impl Send for ForeignScalarUDF {}
279unsafe impl Sync for ForeignScalarUDF {}
280
281impl PartialEq for ForeignScalarUDF {
282 fn eq(&self, other: &Self) -> bool {
283 let Self {
284 name,
285 aliases,
286 udf,
287 signature,
288 } = self;
289 name == &other.name
290 && aliases == &other.aliases
291 && std::ptr::eq(udf, &other.udf)
292 && signature == &other.signature
293 }
294}
295impl Eq for ForeignScalarUDF {}
296
297impl Hash for ForeignScalarUDF {
298 fn hash<H: Hasher>(&self, state: &mut H) {
299 let Self {
300 name,
301 aliases,
302 udf,
303 signature,
304 } = self;
305 name.hash(state);
306 aliases.hash(state);
307 std::ptr::hash(udf, state);
308 signature.hash(state);
309 }
310}
311
312impl From<&FFI_ScalarUDF> for Arc<dyn ScalarUDFImpl> {
313 fn from(udf: &FFI_ScalarUDF) -> Self {
314 if (udf.library_marker_id)() == crate::get_library_marker_id() {
315 Arc::clone(udf.inner().inner())
316 } else {
317 let name = udf.name.to_owned().into();
318 let signature = Signature::user_defined((&udf.volatility).into());
319
320 let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
321
322 Arc::new(ForeignScalarUDF {
323 name,
324 udf: udf.clone(),
325 aliases,
326 signature,
327 })
328 }
329 }
330}
331
332impl ScalarUDFImpl for ForeignScalarUDF {
333 fn as_any(&self) -> &dyn std::any::Any {
334 self
335 }
336
337 fn name(&self) -> &str {
338 &self.name
339 }
340
341 fn signature(&self) -> &Signature {
342 &self.signature
343 }
344
345 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
346 internal_err!("ForeignScalarUDF implements return_field_from_args instead.")
347 }
348
349 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
350 let args: FFI_ReturnFieldArgs = args.try_into()?;
351
352 let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) };
353
354 let result = df_result!(result);
355
356 result.and_then(|r| {
357 Field::try_from(&r.0)
358 .map(Arc::new)
359 .map_err(DataFusionError::from)
360 })
361 }
362
363 fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
364 let ScalarFunctionArgs {
365 args,
366 arg_fields,
367 number_rows,
368 return_field,
369 config_options: _config_options,
371 } = invoke_args;
372
373 let args = args
374 .into_iter()
375 .map(|v| v.to_array(number_rows))
376 .collect::<Result<Vec<_>>>()?
377 .into_iter()
378 .map(|v| {
379 to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray {
380 array: ffi_array,
381 schema: WrappedSchema(ffi_schema),
382 })
383 })
384 .collect::<std::result::Result<Vec<_>, ArrowError>>()?
385 .into();
386
387 let arg_fields_wrapped = arg_fields
388 .iter()
389 .map(FFI_ArrowSchema::try_from)
390 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
391
392 let arg_fields = arg_fields_wrapped
393 .into_iter()
394 .map(WrappedSchema)
395 .collect::<RVec<_>>();
396
397 let return_field = return_field.as_ref().clone();
398 let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?);
399
400 let result = unsafe {
401 (self.udf.invoke_with_args)(
402 &self.udf,
403 args,
404 arg_fields,
405 number_rows,
406 return_field,
407 )
408 };
409
410 let result = df_result!(result)?;
411 let result_array: ArrayRef = result.try_into()?;
412
413 Ok(ColumnarValue::Array(result_array))
414 }
415
416 fn aliases(&self) -> &[String] {
417 &self.aliases
418 }
419
420 fn short_circuits(&self) -> bool {
421 self.udf.short_circuits
422 }
423
424 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
425 unsafe {
426 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
427 let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
428 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
429 }
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::functions::math::abs::AbsFunc;

    /// Builds the `abs` UDF used as the local implementation in these tests.
    fn abs_udf() -> Arc<ScalarUDF> {
        Arc::new(ScalarUDF::from(AbsFunc::new()))
    }

    /// A struct marked as foreign goes through `ForeignScalarUDF` and keeps
    /// its name.
    #[test]
    fn test_round_trip_scalar_udf() -> Result<()> {
        let original_udf = abs_udf();

        let mut local_udf = FFI_ScalarUDF::from(Arc::clone(&original_udf));
        // Pretend the struct came from another library so the bypass is skipped.
        local_udf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udf: Arc<dyn ScalarUDFImpl> = (&local_udf).into();
        assert_eq!(original_udf.name(), foreign_udf.name());

        Ok(())
    }

    /// A same-library struct is unwrapped directly; a foreign one is wrapped.
    #[test]
    fn test_ffi_udf_local_bypass() -> Result<()> {
        let mut ffi_udf = FFI_ScalarUDF::from(abs_udf());

        let bypassed: Arc<dyn ScalarUDFImpl> = (&ffi_udf).into();
        assert!(bypassed.as_any().downcast_ref::<AbsFunc>().is_some());

        ffi_udf.library_marker_id = crate::mock_foreign_marker_id;
        let wrapped: Arc<dyn ScalarUDFImpl> = (&ffi_udf).into();
        assert!(
            wrapped
                .as_any()
                .downcast_ref::<ForeignScalarUDF>()
                .is_some()
        );

        Ok(())
    }
}