1use std::{
2 ffi::{CStr, c_char, c_int, c_void},
3 sync::Arc,
4};
5
6use crate::{
7 abi::{ArrowArray, ArrowData, ArrowSchema, FFI_ScalarFunction},
8 error::DaftResult,
9 ffi::trampoline::trampoline,
10};
11
12pub trait DaftScalarFunction {
14 fn name(&self) -> &CStr;
15 fn return_field(&self, args: &[ArrowSchema]) -> DaftResult<ArrowSchema>;
16 fn call(&self, args: Vec<ArrowData>) -> DaftResult<ArrowData>;
17}
18
19pub type DaftScalarFunctionRef = Arc<dyn DaftScalarFunction>;
21
22pub fn into_ffi(func: DaftScalarFunctionRef) -> FFI_ScalarFunction {
27 let ctx_ptr = Box::into_raw(Box::new(func));
28 FFI_ScalarFunction {
29 ctx: ctx_ptr.cast(),
30 name: ffi_name,
31 get_return_field: ffi_get_return_field,
32 call: ffi_call,
33 fini: ffi_fini,
34 }
35}
36
37unsafe extern "C" fn ffi_name(ctx: *const c_void) -> *const c_char {
39 unsafe { &*ctx.cast::<DaftScalarFunctionRef>() }
40 .name()
41 .as_ptr()
42}
43
44#[rustfmt::skip]
46unsafe extern "C" fn ffi_get_return_field(
47 ctx: *const c_void,
48 args: *const ArrowSchema,
49 args_count: usize,
50 ret: *mut ArrowSchema,
51 errmsg: *mut *mut c_char,
52) -> c_int {
53 unsafe { trampoline(errmsg, "panic in get_return_field", || {
54 let ctx = &*ctx.cast::<DaftScalarFunctionRef>();
55 let schemas = if args_count == 0 {
56 &[]
57 } else {
58 std::slice::from_raw_parts(args, args_count)
59 };
60 let result = ctx.return_field(schemas)?;
61 std::ptr::write(ret, result);
62 Ok(())
63 })}
64}
65
66#[rustfmt::skip]
68unsafe extern "C" fn ffi_call(
69 ctx: *const c_void,
70 args: *const ArrowArray,
71 args_schemas: *const ArrowSchema,
72 args_count: usize,
73 ret_array: *mut ArrowArray,
74 ret_schema: *mut ArrowSchema,
75 errmsg: *mut *mut c_char,
76) -> c_int {
77 unsafe { trampoline(errmsg, "panic in call", || {
78 let ctx = &*ctx.cast::<DaftScalarFunctionRef>();
79 let mut data = Vec::with_capacity(args_count);
80 for i in 0..args_count {
81 let array = std::ptr::read(args.add(i));
82 let schema = std::ptr::read(args_schemas.add(i));
83 data.push(ArrowData { schema, array });
84 }
85 let result = ctx.call(data)?;
86 std::ptr::write(ret_array, result.array);
87 std::ptr::write(ret_schema, result.schema);
88 Ok(())
89 })}
90}
91
92unsafe extern "C" fn ffi_fini(ctx: *mut c_void) {
94 let _ = std::panic::catch_unwind(|| unsafe {
95 drop(Box::from_raw(ctx.cast::<DaftScalarFunctionRef>()));
96 });
97}
98
#[cfg(all(test, feature = "arrow-57"))]
mod tests {
    use arrow_schema_57::{DataType, Field, Schema};

    use super::*;
    use crate::{abi::ffi::strings::free_string, error::DaftError};

    /// Exports an arrow-rs `Schema` as an owned C Data Interface schema.
    fn export_schema(schema: &Schema) -> ArrowSchema {
        let ffi = arrow_schema_57::ffi::FFI_ArrowSchema::try_from(schema).unwrap();
        unsafe { ArrowSchema::from_owned(ffi) }
    }

    /// Converts a C Data Interface schema back into an arrow-rs `Schema`,
    /// viewing it through a borrowed raw reference.
    fn import_schema(schema: &ArrowSchema) -> Schema {
        let ffi: &arrow_schema_57::ffi::FFI_ArrowSchema = unsafe { schema.as_raw() };
        Schema::try_from(ffi).unwrap()
    }

    /// Builds a non-nullable Int32 array by hand.
    ///
    /// The array has no validity bitmap (buffer 0 is null) and owns its values
    /// through `private_data`, which `release_int32` frees.
    fn make_int32(values: &[i32]) -> ArrowData {
        let schema = {
            let field = Field::new("", DataType::Int32, false);
            let ffi = arrow_schema_57::ffi::FFI_ArrowSchema::try_from(&field).unwrap();
            unsafe { ArrowSchema::from_owned(ffi) }
        };

        let values: Box<[i32]> = values.into();
        let len = values.len();

        // The data pointer stays valid after `values` moves into `private`:
        // moving a `Box` does not move its heap allocation.
        let buffers = Box::new([std::ptr::null::<c_void>(), values.as_ptr().cast::<c_void>()]);

        let private = Box::new((values, std::ptr::null::<c_void>()));
        let array = ArrowArray {
            length: len as i64,
            null_count: 0,
            offset: 0,
            n_buffers: 2,
            n_children: 0,
            buffers: Box::into_raw(buffers).cast::<*const c_void>(),
            children: std::ptr::null_mut(),
            dictionary: std::ptr::null_mut(),
            release: Some(release_int32),
            private_data: Box::into_raw(private).cast::<c_void>(),
        };

        ArrowData { schema, array }
    }

    /// Release callback for arrays built by `make_int32`: frees the owned
    /// values and the buffer-pointer array, then marks the array as released.
    unsafe extern "C" fn release_int32(array: *mut ArrowArray) {
        let a = unsafe { &mut *array };
        drop(unsafe { Box::from_raw(a.private_data.cast::<(Box<[i32]>, *const c_void)>()) });
        drop(unsafe { Box::from_raw(a.buffers.cast::<[*const c_void; 2]>()) });
        a.release = None;
    }

    /// Views the values buffer (buffer 1) of an Int32 array; ignores `offset`.
    fn read_int32(data: &ArrowData) -> &[i32] {
        unsafe {
            let bufs = std::slice::from_raw_parts(data.array.buffers.cast_const(), 2);
            std::slice::from_raw_parts(bufs[1].cast::<i32>(), data.array.length as usize)
        }
    }

    /// Test function: adds one to every value of its first Int32 argument.
    struct IncrementFn;

    impl DaftScalarFunction for IncrementFn {
        fn name(&self) -> &CStr {
            c"increment"
        }

        fn return_field(&self, _args: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
            let field = Field::new("result", DataType::Int32, false);
            Ok(export_schema(&Schema::new(vec![field])))
        }

        fn call(&self, args: Vec<ArrowData>) -> DaftResult<ArrowData> {
            let input = args
                .first()
                .ok_or_else(|| DaftError::TypeError("expected at least one argument".into()))?;
            let values = read_int32(input);
            let output: Vec<i32> = values.iter().map(|x| x + 1).collect();
            Ok(make_int32(&output))
        }
    }

    #[test]
    fn vtable_name_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        let name = unsafe { CStr::from_ptr((vtable.name)(vtable.ctx)) };
        assert_eq!(name.to_str().unwrap(), "increment");

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn vtable_get_return_field_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        let field = Field::new("x", DataType::Int32, false);
        let ffi_schema = export_schema(&Schema::new(vec![field]));

        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                &raw const ffi_schema,
                1,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_eq!(rc, 0, "get_return_field should succeed");

        let schema = import_schema(&ret_schema);
        assert_eq!(schema.field(0).name(), "result");
        assert_eq!(*schema.field(0).data_type(), DataType::Int32);

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn vtable_call_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        let data = make_int32(&[1, 2, 3]);

        let mut ret_array = ArrowArray::empty();
        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.call)(
                vtable.ctx,
                &raw const data.array,
                &raw const data.schema,
                1,
                &raw mut ret_array,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_eq!(rc, 0, "call should succeed");

        let result = ArrowData {
            schema: ret_schema,
            array: ret_array,
        };
        assert_eq!(read_int32(&result), &[2, 3, 4]);

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn vtable_error_propagation() {
        struct FailingFn;
        impl DaftScalarFunction for FailingFn {
            fn name(&self) -> &CStr {
                c"failing"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Err(DaftError::TypeError("bad type".into()))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Err(DaftError::RuntimeError("compute failed".into()))
            }
        }

        let vtable = into_ffi(Arc::new(FailingFn));

        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                std::ptr::null(),
                0,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_ne!(rc, 0, "should return non-zero on error");
        assert!(!errmsg.is_null());

        let err_str = unsafe { CStr::from_ptr(errmsg) }.to_str().unwrap();
        assert!(err_str.contains("bad type"), "error message: {err_str}");

        unsafe { free_string(errmsg) };
        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn vtable_call_error_propagation() {
        struct CallFailFn;
        impl DaftScalarFunction for CallFailFn {
            fn name(&self) -> &CStr {
                c"call_fail"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "x",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Err(DaftError::RuntimeError("compute failed".into()))
            }
        }

        let vtable = into_ffi(Arc::new(CallFailFn));

        let data = make_int32(&[1]);

        let mut ret_array = ArrowArray::empty();
        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.call)(
                vtable.ctx,
                &raw const data.array,
                &raw const data.schema,
                1,
                &raw mut ret_array,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_ne!(rc, 0, "call should return non-zero on error");
        assert!(!errmsg.is_null());

        let err_str = unsafe { CStr::from_ptr(errmsg) }.to_str().unwrap();
        assert!(
            err_str.contains("compute failed"),
            "error message: {err_str}"
        );

        unsafe { free_string(errmsg) };
        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    /// Both `get_return_field` and `call` must accept null argument pointers
    /// when the argument count is zero.
    #[test]
    fn vtable_zero_args() {
        struct NoArgFn;
        impl DaftScalarFunction for NoArgFn {
            fn name(&self) -> &CStr {
                c"no_args"
            }
            fn return_field(&self, args: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                assert!(args.is_empty());
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "result",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, args: Vec<ArrowData>) -> DaftResult<ArrowData> {
                assert!(args.is_empty());
                Ok(make_int32(&[42]))
            }
        }

        let vtable = into_ffi(Arc::new(NoArgFn));

        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                std::ptr::null(),
                0,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };
        assert_eq!(rc, 0, "get_return_field with zero args should succeed");

        let schema = import_schema(&ret_schema);
        assert_eq!(schema.field(0).name(), "result");

        // `call` has its own zero-argument path (empty input loop over null
        // pointers) that must also succeed.
        let mut ret_array = ArrowArray::empty();
        let mut call_schema = ArrowSchema::empty();

        let rc = unsafe {
            (vtable.call)(
                vtable.ctx,
                std::ptr::null(),
                std::ptr::null(),
                0,
                &raw mut ret_array,
                &raw mut call_schema,
                &raw mut errmsg,
            )
        };
        assert_eq!(rc, 0, "call with zero args should succeed");

        let result = ArrowData {
            schema: call_schema,
            array: ret_array,
        };
        assert_eq!(read_int32(&result), &[42]);

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn fini_is_callable() {
        struct DisposableFn;
        impl DaftScalarFunction for DisposableFn {
            fn name(&self) -> &CStr {
                c"disposable"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "x",
                    DataType::Null,
                    true,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(make_int32(&[0]))
            }
        }

        let vtable = into_ffi(Arc::new(DisposableFn));
        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }
}