Skip to main content

daft_ext/
function.rs

1use std::{
2    ffi::{CStr, c_char, c_int, c_void},
3    sync::Arc,
4};
5
6use crate::{
7    abi::{ArrowArray, ArrowData, ArrowSchema, FFI_ScalarFunction},
8    error::DaftResult,
9    ffi::trampoline::trampoline,
10};
11
/// Trait that extension authors implement to define a scalar function.
///
/// Implementations are exposed to the host by wrapping them in a
/// [`DaftScalarFunctionRef`] and passing it to [`into_ffi`], which builds the
/// [`FFI_ScalarFunction`] vtable.
// NOTE(review): there is no `Send + Sync` bound here, yet the vtable context is handed
// to the host as a raw pointer — confirm the host never invokes callbacks concurrently
// from multiple threads, or require `Send + Sync` in a future breaking release.
pub trait DaftScalarFunction {
    /// Returns the function's name.
    ///
    /// The FFI layer hands the string's raw pointer straight to the host, so the
    /// returned `CStr` must remain valid for as long as `self` is alive.
    fn name(&self) -> &CStr;

    /// Computes the output schema from the argument schemas.
    ///
    /// `args` is only borrowed: the FFI layer views the host's schema array as a
    /// slice without taking ownership. Returning `Err` makes the FFI callback
    /// report a non-zero status together with an error message.
    fn return_field(&self, args: &[ArrowSchema]) -> DaftResult<ArrowSchema>;

    /// Evaluates the function on the given arguments.
    ///
    /// The arguments arrive by value: the FFI layer bitwise-moves each host
    /// array/schema into an [`ArrowData`], so this method owns them. Returning
    /// `Err` makes the FFI callback report a non-zero status with a message.
    fn call(&self, args: Vec<ArrowData>) -> DaftResult<ArrowData>;
}
18
/// A shared, type-erased scalar function reference.
///
/// This is the form accepted by [`into_ffi`]; cloning it only bumps the `Arc`
/// reference count.
pub type DaftScalarFunctionRef = Arc<dyn DaftScalarFunction>;
21
22/// Convert a [`DaftScalarFunctionRef`] into a [`FFI_ScalarFunction`] vtable.
23///
24/// The `Arc` is moved into the vtable's opaque context and released
25/// when the host calls `fini`.
26pub fn into_ffi(func: DaftScalarFunctionRef) -> FFI_ScalarFunction {
27    let ctx_ptr = Box::into_raw(Box::new(func));
28    FFI_ScalarFunction {
29        ctx: ctx_ptr.cast(),
30        name: ffi_name,
31        get_return_field: ffi_get_return_field,
32        call: ffi_call,
33        fini: ffi_fini,
34    }
35}
36
37/// Returns the function name as a null-terminated UTF-8 string.
38unsafe extern "C" fn ffi_name(ctx: *const c_void) -> *const c_char {
39    unsafe { &*ctx.cast::<DaftScalarFunctionRef>() }
40        .name()
41        .as_ptr()
42}
43
44/// Returns the output field given the input fields.
45#[rustfmt::skip]
46unsafe extern "C" fn ffi_get_return_field(
47    ctx:        *const c_void,
48    args:       *const ArrowSchema,
49    args_count: usize,
50    ret:        *mut ArrowSchema,
51    errmsg:     *mut *mut c_char,
52) -> c_int {
53    unsafe { trampoline(errmsg, "panic in get_return_field", || {
54        let ctx = &*ctx.cast::<DaftScalarFunctionRef>();
55        let schemas = if args_count == 0 {
56            &[]
57        } else {
58            std::slice::from_raw_parts(args, args_count)
59        };
60        let result = ctx.return_field(schemas)?;
61        std::ptr::write(ret, result);
62        Ok(())
63    })}
64}
65
66/// Evaluates the function on Arrow arrays via the C Data Interface.
67#[rustfmt::skip]
68unsafe extern "C" fn ffi_call(
69    ctx:          *const c_void,
70    args:         *const ArrowArray,
71    args_schemas: *const ArrowSchema,
72    args_count:   usize,
73    ret_array:    *mut ArrowArray,
74    ret_schema:   *mut ArrowSchema,
75    errmsg:       *mut *mut c_char,
76) -> c_int {
77    unsafe { trampoline(errmsg, "panic in call", || {
78        let ctx = &*ctx.cast::<DaftScalarFunctionRef>();
79        let mut data = Vec::with_capacity(args_count);
80        for i in 0..args_count {
81            let array = std::ptr::read(args.add(i));
82            let schema = std::ptr::read(args_schemas.add(i));
83            data.push(ArrowData { schema, array });
84        }
85        let result = ctx.call(data)?;
86        std::ptr::write(ret_array, result.array);
87        std::ptr::write(ret_schema, result.schema);
88        Ok(())
89    })}
90}
91
92/// Finalizes the function, freeing all owned resources.
93unsafe extern "C" fn ffi_fini(ctx: *mut c_void) {
94    let _ = std::panic::catch_unwind(|| unsafe {
95        drop(Box::from_raw(ctx.cast::<DaftScalarFunctionRef>()));
96    });
97}
98
#[cfg(all(test, feature = "arrow-57"))]
mod tests {
    // These tests drive the vtable produced by `into_ffi` the way a C host would:
    // through raw pointers and the exported `extern "C"` callbacks.

    use arrow_schema_57::{DataType, Field, Schema};

    use super::*;
    use crate::{abi::ffi::strings::free_string, error::DaftError};

    // ── Raw-level test helpers (no arrow-array dependency) ──────────

    /// Create an [`ArrowSchema`] from an `arrow_schema_57::Schema`.
    fn export_schema(schema: &Schema) -> ArrowSchema {
        let ffi = arrow_schema_57::ffi::FFI_ArrowSchema::try_from(schema).unwrap();
        unsafe { ArrowSchema::from_owned(ffi) }
    }

    /// Read an [`ArrowSchema`] back into an `arrow_schema_57::Schema`.
    fn import_schema(schema: &ArrowSchema) -> Schema {
        let ffi: &arrow_schema_57::ffi::FFI_ArrowSchema = unsafe { schema.as_raw() };
        Schema::try_from(ffi).unwrap()
    }

    /// Build an [`ArrowData`] containing a non-nullable Int32 column from raw values.
    ///
    /// Allocates two heap blocks: the values buffer (owned via `private_data`)
    /// and the two-entry buffer-pointer array (owned via `buffers`). Both are
    /// freed by [`release_int32`].
    fn make_int32(values: &[i32]) -> ArrowData {
        let schema = {
            let field = Field::new("", DataType::Int32, false);
            let ffi = arrow_schema_57::ffi::FFI_ArrowSchema::try_from(&field).unwrap();
            unsafe { ArrowSchema::from_owned(ffi) }
        };

        let values: Box<[i32]> = values.into();
        let len = values.len();

        // Arrow Int32 layout: [validity (null = no nulls), values]. Moving the
        // `Box` below does not move its heap data, so this pointer stays valid.
        let buffers = Box::new([std::ptr::null::<c_void>(), values.as_ptr().cast::<c_void>()]);

        // `private_data` keeps the values buffer alive; the buffer-pointer array
        // is reclaimed separately through `buffers` in `release_int32`.
        let private = Box::new(values);
        let array = ArrowArray {
            length: len as i64,
            null_count: 0,
            offset: 0,
            n_buffers: 2,
            n_children: 0,
            buffers: Box::into_raw(buffers).cast::<*const c_void>(),
            children: std::ptr::null_mut(),
            dictionary: std::ptr::null_mut(),
            release: Some(release_int32),
            private_data: Box::into_raw(private).cast::<c_void>(),
        };

        ArrowData { schema, array }
    }

    /// Release callback for arrays built by [`make_int32`].
    unsafe extern "C" fn release_int32(array: *mut ArrowArray) {
        let a = unsafe { &mut *array };
        // Free the values buffer owned through `private_data`.
        drop(unsafe { Box::from_raw(a.private_data.cast::<Box<[i32]>>()) });
        // Free the buffers pointer array.
        drop(unsafe { Box::from_raw(a.buffers.cast::<[*const c_void; 2]>()) });
        // Mark the array as released.
        a.release = None;
    }

    /// Read the i32 values out of an [`ArrowData`] (non-null, zero-offset).
    fn read_int32(data: &ArrowData) -> &[i32] {
        unsafe {
            let bufs = std::slice::from_raw_parts(data.array.buffers.cast_const(), 2);
            std::slice::from_raw_parts(bufs[1].cast::<i32>(), data.array.length as usize)
        }
    }

    // ── Test function impls ─────────────────────────────────────────

    struct IncrementFn;

    impl DaftScalarFunction for IncrementFn {
        fn name(&self) -> &CStr {
            c"increment"
        }

        fn return_field(&self, _args: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
            let field = Field::new("result", DataType::Int32, false);
            Ok(export_schema(&Schema::new(vec![field])))
        }

        fn call(&self, args: Vec<ArrowData>) -> DaftResult<ArrowData> {
            let input = args
                .first()
                .ok_or_else(|| DaftError::TypeError("expected at least one argument".into()))?;
            let values = read_int32(input);
            let output: Vec<i32> = values.iter().map(|x| x + 1).collect();
            Ok(make_int32(&output))
        }
    }

    // ── Tests ───────────────────────────────────────────────────────

    #[test]
    fn vtable_name_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        let name = unsafe { CStr::from_ptr((vtable.name)(vtable.ctx)) };
        assert_eq!(name.to_str().unwrap(), "increment");

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn vtable_get_return_field_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        // `get_return_field` only borrows its arguments, so `ffi_schema` is
        // dropped normally by this test.
        let field = Field::new("x", DataType::Int32, false);
        let ffi_schema = export_schema(&Schema::new(vec![field]));

        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                &raw const ffi_schema,
                1,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_eq!(rc, 0, "get_return_field should succeed");

        let schema = import_schema(&ret_schema);
        assert_eq!(schema.field(0).name(), "result");
        assert_eq!(*schema.field(0).data_type(), DataType::Int32);

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn vtable_call_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        // `call` moves the inputs out with a bitwise read (`ptr::read` in
        // `ffi_call`) and passes them to the function by value, so ownership
        // transfers to the callee. `ManuallyDrop` keeps this test from releasing
        // the same buffers a second time.
        let data = std::mem::ManuallyDrop::new(make_int32(&[1, 2, 3]));

        let mut ret_array = ArrowArray::empty();
        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.call)(
                vtable.ctx,
                &raw const data.array,
                &raw const data.schema,
                1,
                &raw mut ret_array,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_eq!(rc, 0, "call should succeed");

        let result = ArrowData {
            schema: ret_schema,
            array: ret_array,
        };
        assert_eq!(read_int32(&result), &[2, 3, 4]);

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn vtable_error_propagation() {
        struct FailingFn;
        impl DaftScalarFunction for FailingFn {
            fn name(&self) -> &CStr {
                c"failing"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Err(DaftError::TypeError("bad type".into()))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Err(DaftError::RuntimeError("compute failed".into()))
            }
        }

        let vtable = into_ffi(Arc::new(FailingFn));

        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                std::ptr::null(),
                0,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_ne!(rc, 0, "should return non-zero on error");
        assert!(!errmsg.is_null());

        let err_str = unsafe { CStr::from_ptr(errmsg) }.to_str().unwrap();
        assert!(err_str.contains("bad type"), "error message: {err_str}");

        unsafe { free_string(errmsg) };
        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn vtable_call_error_propagation() {
        struct CallFailFn;
        impl DaftScalarFunction for CallFailFn {
            fn name(&self) -> &CStr {
                c"call_fail"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "x",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Err(DaftError::RuntimeError("compute failed".into()))
            }
        }

        let vtable = into_ffi(Arc::new(CallFailFn));

        // Ownership of the input moves into `call` even when it fails, so the
        // test must not release it again.
        let data = std::mem::ManuallyDrop::new(make_int32(&[1]));

        let mut ret_array = ArrowArray::empty();
        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.call)(
                vtable.ctx,
                &raw const data.array,
                &raw const data.schema,
                1,
                &raw mut ret_array,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_ne!(rc, 0, "call should return non-zero on error");
        assert!(!errmsg.is_null());

        let err_str = unsafe { CStr::from_ptr(errmsg) }.to_str().unwrap();
        assert!(
            err_str.contains("compute failed"),
            "error message: {err_str}"
        );

        unsafe { free_string(errmsg) };
        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn vtable_zero_args() {
        struct NoArgFn;
        impl DaftScalarFunction for NoArgFn {
            fn name(&self) -> &CStr {
                c"no_args"
            }
            fn return_field(&self, args: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                assert!(args.is_empty());
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "result",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, args: Vec<ArrowData>) -> DaftResult<ArrowData> {
                assert!(args.is_empty());
                Ok(make_int32(&[42]))
            }
        }

        let vtable = into_ffi(Arc::new(NoArgFn));

        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        // Null argument pointers are valid when the count is zero.
        let rc = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                std::ptr::null(),
                0,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };
        assert_eq!(rc, 0, "get_return_field with zero args should succeed");

        let schema = import_schema(&ret_schema);
        assert_eq!(schema.field(0).name(), "result");

        // The zero-argument `call` path must also accept null input pointers.
        let mut call_array = ArrowArray::empty();
        let mut call_schema = ArrowSchema::empty();
        let rc = unsafe {
            (vtable.call)(
                vtable.ctx,
                std::ptr::null(),
                std::ptr::null(),
                0,
                &raw mut call_array,
                &raw mut call_schema,
                &raw mut errmsg,
            )
        };
        assert_eq!(rc, 0, "call with zero args should succeed");

        let result = ArrowData {
            schema: call_schema,
            array: call_array,
        };
        assert_eq!(read_int32(&result), &[42]);

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    #[test]
    fn fini_is_callable() {
        struct DisposableFn;
        impl DaftScalarFunction for DisposableFn {
            fn name(&self) -> &CStr {
                c"disposable"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "x",
                    DataType::Null,
                    true,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(make_int32(&[0]))
            }
        }

        // Keep a second strong reference so we can observe `fini` releasing
        // the vtable's copy of the `Arc`.
        let func = Arc::new(DisposableFn);
        let vtable = into_ffi(func.clone());
        assert_eq!(Arc::strong_count(&func), 2, "vtable should hold one reference");

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
        assert_eq!(Arc::strong_count(&func), 1, "fini should drop the vtable's Arc");
    }
}