Skip to main content

daft_ext/
function.rs

1use std::{
2    ffi::{CStr, c_char, c_int, c_void},
3    sync::Arc,
4};
5
6use crate::abi::{ArrowArray, ArrowData, ArrowSchema, FFI_ScalarFunction};
7
8use crate::{error::DaftResult, ffi::trampoline::trampoline};
9
10/// Trait that extension authors implement to define a scalar function.
11pub trait DaftScalarFunction {
12    fn name(&self) -> &CStr;
13    fn return_field(&self, args: &[ArrowSchema]) -> DaftResult<ArrowSchema>;
14    fn call(&self, args: Vec<ArrowData>) -> DaftResult<ArrowData>;
15}
16
17/// A shared, type-erased scalar function reference.
18pub type DaftScalarFunctionRef = Arc<dyn DaftScalarFunction>;
19
20/// Convert a [`DaftScalarFunctionRef`] into a [`FFI_ScalarFunction`] vtable.
21///
22/// The `Arc` is moved into the vtable's opaque context and released
23/// when the host calls `fini`.
24pub fn into_ffi(func: DaftScalarFunctionRef) -> FFI_ScalarFunction {
25    let ctx_ptr = Box::into_raw(Box::new(func));
26    FFI_ScalarFunction {
27        ctx: ctx_ptr.cast(),
28        name: ffi_name,
29        get_return_field: ffi_get_return_field,
30        call: ffi_call,
31        fini: ffi_fini,
32    }
33}
34
35/// Returns the function name as a null-terminated UTF-8 string.
36unsafe extern "C" fn ffi_name(ctx: *const c_void) -> *const c_char {
37    unsafe { &*ctx.cast::<DaftScalarFunctionRef>() }
38        .name()
39        .as_ptr()
40}
41
42/// Returns the output field given the input fields.
43#[rustfmt::skip]
44unsafe extern "C" fn ffi_get_return_field(
45    ctx:        *const c_void,
46    args:       *const ArrowSchema,
47    args_count: usize,
48    ret:        *mut ArrowSchema,
49    errmsg:     *mut *mut c_char,
50) -> c_int {
51    unsafe { trampoline(errmsg, "panic in get_return_field", || {
52        let ctx = &*ctx.cast::<DaftScalarFunctionRef>();
53        let schemas = if args_count == 0 {
54            &[]
55        } else {
56            std::slice::from_raw_parts(args, args_count)
57        };
58        let result = ctx.return_field(schemas)?;
59        std::ptr::write(ret, result);
60        Ok(())
61    })}
62}
63
64/// Evaluates the function on Arrow arrays via the C Data Interface.
65#[rustfmt::skip]
66unsafe extern "C" fn ffi_call(
67    ctx:          *const c_void,
68    args:         *const ArrowArray,
69    args_schemas: *const ArrowSchema,
70    args_count:   usize,
71    ret_array:    *mut ArrowArray,
72    ret_schema:   *mut ArrowSchema,
73    errmsg:       *mut *mut c_char,
74) -> c_int {
75    unsafe { trampoline(errmsg, "panic in call", || {
76        let ctx = &*ctx.cast::<DaftScalarFunctionRef>();
77        let mut data = Vec::with_capacity(args_count);
78        for i in 0..args_count {
79            let array = std::ptr::read(args.add(i));
80            let schema = std::ptr::read(args_schemas.add(i));
81            data.push(ArrowData { schema, array });
82        }
83        let result = ctx.call(data)?;
84        std::ptr::write(ret_array, result.array);
85        std::ptr::write(ret_schema, result.schema);
86        Ok(())
87    })}
88}
89
90/// Finalizes the function, freeing all owned resources.
91unsafe extern "C" fn ffi_fini(ctx: *mut c_void) {
92    let _ = std::panic::catch_unwind(|| unsafe {
93        drop(Box::from_raw(ctx.cast::<DaftScalarFunctionRef>()));
94    });
95}
96
#[cfg(test)]
mod tests {
    use arrow_array::{Array, ArrayRef, Int32Array};
    use arrow_schema::{DataType, Field, Schema};
    use crate::abi::ffi::strings::free_string;

    use super::*;
    use crate::error::DaftError;

    /// Exports `array` through the Arrow C Data Interface and wraps the
    /// resulting FFI structs in an [`ArrowData`].
    fn export_array(array: &dyn Array) -> ArrowData {
        let (raw_array, raw_schema) = arrow::ffi::to_ffi(&array.to_data()).unwrap();
        let wrapped_array = unsafe { ArrowArray::from_owned(raw_array) };
        let wrapped_schema = unsafe { ArrowSchema::from_owned(raw_schema) };
        ArrowData {
            array: wrapped_array,
            schema: wrapped_schema,
        }
    }

    /// Consumes an [`ArrowData`] and rebuilds an arrow-rs array from it.
    fn import_array(data: ArrowData) -> ArrayRef {
        let ArrowData { schema, array } = data;
        let raw_array: arrow::ffi::FFI_ArrowArray = unsafe { array.into_owned() };
        let raw_schema: arrow::ffi::FFI_ArrowSchema = unsafe { schema.into_owned() };
        let imported = unsafe { arrow::ffi::from_ffi(raw_array, &raw_schema) }.unwrap();
        arrow_array::make_array(imported)
    }

    /// Exports an arrow-rs [`Schema`] as an ABI [`ArrowSchema`].
    fn export_schema(schema: &Schema) -> ArrowSchema {
        let raw_schema = arrow::ffi::FFI_ArrowSchema::try_from(schema).unwrap();
        unsafe { ArrowSchema::from_owned(raw_schema) }
    }

    /// Reads an ABI [`ArrowSchema`] back into an arrow-rs [`Schema`] without
    /// consuming it.
    fn import_schema(schema: &ArrowSchema) -> Schema {
        let raw_schema: &arrow::ffi::FFI_ArrowSchema = unsafe { schema.as_raw() };
        Schema::try_from(raw_schema).unwrap()
    }

    /// Exports a schema that contains exactly one field.
    fn single_field_schema(name: &str, data_type: DataType, nullable: bool) -> ArrowSchema {
        export_schema(&Schema::new(vec![Field::new(name, data_type, nullable)]))
    }

    /// Invokes the vtable's `fini` callback on its own context.
    fn finalize(vtable: &FFI_ScalarFunction) {
        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    /// Test function that adds one to every non-null `Int32` input value.
    struct IncrementFn;

    impl DaftScalarFunction for IncrementFn {
        fn name(&self) -> &CStr {
            c"increment"
        }

        fn return_field(&self, _args: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
            Ok(single_field_schema("result", DataType::Int32, false))
        }

        fn call(&self, args: Vec<ArrowData>) -> DaftResult<ArrowData> {
            let imported = import_array(args.into_iter().next().unwrap());
            let Some(values) = imported.as_any().downcast_ref::<Int32Array>() else {
                return Err(DaftError::TypeError("expected Int32".into()));
            };
            let incremented: Int32Array = values.iter().map(|v| v.map(|x| x + 1)).collect();
            Ok(export_array(&incremented))
        }
    }

    /// The `name` entry returns the implementor's name unchanged.
    #[test]
    fn vtable_name_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        let name_ptr = unsafe { (vtable.name)(vtable.ctx) };
        let name = unsafe { CStr::from_ptr(name_ptr) };
        assert_eq!(name.to_str().unwrap(), "increment");

        finalize(&vtable);
    }

    /// The `get_return_field` entry writes the implementor's output field.
    #[test]
    fn vtable_get_return_field_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        let input_schema = single_field_schema("x", DataType::Int32, false);

        let mut out_schema = ArrowSchema::empty();
        let mut err_ptr: *mut c_char = std::ptr::null_mut();

        let status = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                &raw const input_schema,
                1,
                &raw mut out_schema,
                &raw mut err_ptr,
            )
        };

        assert_eq!(status, 0, "get_return_field should succeed");

        let imported = import_schema(&out_schema);
        let first_field = imported.field(0);
        assert_eq!(first_field.name(), "result");
        assert_eq!(*first_field.data_type(), DataType::Int32);

        finalize(&vtable);
    }

    /// The `call` entry evaluates the function and returns the result arrays.
    #[test]
    fn vtable_call_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        let exported = export_array(&Int32Array::from(vec![1, 2, 3]));

        let mut out_array = ArrowArray::empty();
        let mut out_schema = ArrowSchema::empty();
        let mut err_ptr: *mut c_char = std::ptr::null_mut();

        let status = unsafe {
            (vtable.call)(
                vtable.ctx,
                &raw const exported.array,
                &raw const exported.schema,
                1,
                &raw mut out_array,
                &raw mut out_schema,
                &raw mut err_ptr,
            )
        };

        assert_eq!(status, 0, "call should succeed");

        let output = import_array(ArrowData {
            schema: out_schema,
            array: out_array,
        });
        let incremented = output.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(incremented.values(), &[2, 3, 4]);

        finalize(&vtable);
    }

    /// An `Err` from `return_field` surfaces as a non-zero code plus message.
    #[test]
    fn vtable_error_propagation() {
        struct FailingFn;
        impl DaftScalarFunction for FailingFn {
            fn name(&self) -> &CStr {
                c"failing"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Err(DaftError::TypeError("bad type".into()))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Err(DaftError::RuntimeError("compute failed".into()))
            }
        }

        let vtable = into_ffi(Arc::new(FailingFn));

        let mut out_schema = ArrowSchema::empty();
        let mut err_ptr: *mut c_char = std::ptr::null_mut();

        let status = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                std::ptr::null(),
                0,
                &raw mut out_schema,
                &raw mut err_ptr,
            )
        };

        assert_ne!(status, 0, "should return non-zero on error");
        assert!(!err_ptr.is_null());

        let err_str = unsafe { CStr::from_ptr(err_ptr) }.to_str().unwrap();
        assert!(err_str.contains("bad type"), "error message: {err_str}");

        // The message is released through the crate's own `free_string`.
        unsafe { free_string(err_ptr) };
        finalize(&vtable);
    }

    /// An `Err` from `call` surfaces as a non-zero code plus message.
    #[test]
    fn vtable_call_error_propagation() {
        struct CallFailFn;
        impl DaftScalarFunction for CallFailFn {
            fn name(&self) -> &CStr {
                c"call_fail"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(single_field_schema("x", DataType::Int32, false))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Err(DaftError::RuntimeError("compute failed".into()))
            }
        }

        let vtable = into_ffi(Arc::new(CallFailFn));

        let exported = export_array(&Int32Array::from(vec![1]));

        let mut out_array = ArrowArray::empty();
        let mut out_schema = ArrowSchema::empty();
        let mut err_ptr: *mut c_char = std::ptr::null_mut();

        let status = unsafe {
            (vtable.call)(
                vtable.ctx,
                &raw const exported.array,
                &raw const exported.schema,
                1,
                &raw mut out_array,
                &raw mut out_schema,
                &raw mut err_ptr,
            )
        };

        assert_ne!(status, 0, "call should return non-zero on error");
        assert!(!err_ptr.is_null());

        let err_str = unsafe { CStr::from_ptr(err_ptr) }.to_str().unwrap();
        assert!(
            err_str.contains("compute failed"),
            "error message: {err_str}"
        );

        unsafe { free_string(err_ptr) };
        finalize(&vtable);
    }

    /// A null `args` pointer is accepted when the argument count is zero.
    #[test]
    fn vtable_zero_args() {
        struct NoArgFn;
        impl DaftScalarFunction for NoArgFn {
            fn name(&self) -> &CStr {
                c"no_args"
            }
            fn return_field(&self, args: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                assert!(args.is_empty());
                Ok(single_field_schema("result", DataType::Int32, false))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(export_array(&Int32Array::from(vec![42])))
            }
        }

        let vtable = into_ffi(Arc::new(NoArgFn));

        let mut out_schema = ArrowSchema::empty();
        let mut err_ptr: *mut c_char = std::ptr::null_mut();

        let status = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                std::ptr::null(),
                0,
                &raw mut out_schema,
                &raw mut err_ptr,
            )
        };
        assert_eq!(status, 0, "get_return_field with zero args should succeed");

        let imported = import_schema(&out_schema);
        assert_eq!(imported.field(0).name(), "result");

        finalize(&vtable);
    }

    /// `fini` can be invoked on a freshly created vtable.
    #[test]
    fn fini_is_callable() {
        struct DisposableFn;
        impl DaftScalarFunction for DisposableFn {
            fn name(&self) -> &CStr {
                c"disposable"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(single_field_schema("x", DataType::Null, true))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(export_array(&Int32Array::from(vec![0])))
            }
        }

        let vtable = into_ffi(Arc::new(DisposableFn));
        finalize(&vtable);
    }
}