1use std::{
2 ffi::{CStr, c_char, c_int, c_void},
3 sync::Arc,
4};
5
6use crate::abi::{ArrowArray, ArrowData, ArrowSchema, FFI_ScalarFunction};
7
8use crate::{error::DaftResult, ffi::trampoline::trampoline};
9
10pub trait DaftScalarFunction {
12 fn name(&self) -> &CStr;
13 fn return_field(&self, args: &[ArrowSchema]) -> DaftResult<ArrowSchema>;
14 fn call(&self, args: Vec<ArrowData>) -> DaftResult<ArrowData>;
15}
16
17pub type DaftScalarFunctionRef = Arc<dyn DaftScalarFunction>;
19
20pub fn into_ffi(func: DaftScalarFunctionRef) -> FFI_ScalarFunction {
25 let ctx_ptr = Box::into_raw(Box::new(func));
26 FFI_ScalarFunction {
27 ctx: ctx_ptr.cast(),
28 name: ffi_name,
29 get_return_field: ffi_get_return_field,
30 call: ffi_call,
31 fini: ffi_fini,
32 }
33}
34
35unsafe extern "C" fn ffi_name(ctx: *const c_void) -> *const c_char {
37 unsafe { &*ctx.cast::<DaftScalarFunctionRef>() }
38 .name()
39 .as_ptr()
40}
41
42#[rustfmt::skip]
44unsafe extern "C" fn ffi_get_return_field(
45 ctx: *const c_void,
46 args: *const ArrowSchema,
47 args_count: usize,
48 ret: *mut ArrowSchema,
49 errmsg: *mut *mut c_char,
50) -> c_int {
51 unsafe { trampoline(errmsg, "panic in get_return_field", || {
52 let ctx = &*ctx.cast::<DaftScalarFunctionRef>();
53 let schemas = if args_count == 0 {
54 &[]
55 } else {
56 std::slice::from_raw_parts(args, args_count)
57 };
58 let result = ctx.return_field(schemas)?;
59 std::ptr::write(ret, result);
60 Ok(())
61 })}
62}
63
64#[rustfmt::skip]
66unsafe extern "C" fn ffi_call(
67 ctx: *const c_void,
68 args: *const ArrowArray,
69 args_schemas: *const ArrowSchema,
70 args_count: usize,
71 ret_array: *mut ArrowArray,
72 ret_schema: *mut ArrowSchema,
73 errmsg: *mut *mut c_char,
74) -> c_int {
75 unsafe { trampoline(errmsg, "panic in call", || {
76 let ctx = &*ctx.cast::<DaftScalarFunctionRef>();
77 let mut data = Vec::with_capacity(args_count);
78 for i in 0..args_count {
79 let array = std::ptr::read(args.add(i));
80 let schema = std::ptr::read(args_schemas.add(i));
81 data.push(ArrowData { schema, array });
82 }
83 let result = ctx.call(data)?;
84 std::ptr::write(ret_array, result.array);
85 std::ptr::write(ret_schema, result.schema);
86 Ok(())
87 })}
88}
89
90unsafe extern "C" fn ffi_fini(ctx: *mut c_void) {
92 let _ = std::panic::catch_unwind(|| unsafe {
93 drop(Box::from_raw(ctx.cast::<DaftScalarFunctionRef>()));
94 });
95}
96
#[cfg(test)]
mod tests {
    //! Round-trip tests that drive scalar functions exclusively through the
    //! vtable returned by `into_ffi`, the way a foreign caller would.

    use arrow_array::{Array, ArrayRef, Int32Array};
    use arrow_schema::{DataType, Field, Schema};
    use crate::abi::ffi::strings::free_string;

    use super::*;
    use crate::error::DaftError;

    /// Exports an Arrow array via `arrow::ffi::to_ffi` and wraps the resulting
    /// FFI array/schema pair in the ABI's `ArrowData`.
    fn export_array(array: &dyn Array) -> ArrowData {
        let (ffi_array, ffi_schema) = arrow::ffi::to_ffi(&array.to_data()).unwrap();
        ArrowData {
            array: unsafe { ArrowArray::from_owned(ffi_array) },
            schema: unsafe { ArrowSchema::from_owned(ffi_schema) },
        }
    }

    /// Inverse of [`export_array`]: consumes an `ArrowData` and rebuilds an
    /// Arrow array from it via `arrow::ffi::from_ffi`.
    fn import_array(data: ArrowData) -> ArrayRef {
        let ffi_array: arrow::ffi::FFI_ArrowArray = unsafe { data.array.into_owned() };
        let ffi_schema: arrow::ffi::FFI_ArrowSchema = unsafe { data.schema.into_owned() };
        let arrow_data = unsafe { arrow::ffi::from_ffi(ffi_array, &ffi_schema) }.unwrap();
        arrow_array::make_array(arrow_data)
    }

    /// Converts an Arrow `Schema` into the ABI's `ArrowSchema`.
    fn export_schema(schema: &Schema) -> ArrowSchema {
        let ffi = arrow::ffi::FFI_ArrowSchema::try_from(schema).unwrap();
        unsafe { ArrowSchema::from_owned(ffi) }
    }

    /// Reads an ABI `ArrowSchema` back into an Arrow `Schema` without
    /// consuming it.
    fn import_schema(schema: &ArrowSchema) -> Schema {
        let ffi: &arrow::ffi::FFI_ArrowSchema = unsafe { schema.as_raw() };
        Schema::try_from(ffi).unwrap()
    }

    /// Test function that adds one to every value of its first `Int32`
    /// argument.
    struct IncrementFn;

    impl DaftScalarFunction for IncrementFn {
        fn name(&self) -> &CStr {
            c"increment"
        }

        fn return_field(&self, _args: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
            let field = Field::new("result", DataType::Int32, false);
            Ok(export_schema(&Schema::new(vec![field])))
        }

        fn call(&self, args: Vec<ArrowData>) -> DaftResult<ArrowData> {
            let input_array = import_array(args.into_iter().next().unwrap());
            let input = input_array
                .as_any()
                .downcast_ref::<Int32Array>()
                .ok_or_else(|| DaftError::TypeError("expected Int32".into()))?;
            let output: Int32Array = input.iter().map(|v| v.map(|x| x + 1)).collect();
            Ok(export_array(&output))
        }
    }

    /// The `name` entry returns the implementation's C string.
    #[test]
    fn vtable_name_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        let name = unsafe { CStr::from_ptr((vtable.name)(vtable.ctx)) };
        assert_eq!(name.to_str().unwrap(), "increment");

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    /// The `get_return_field` entry writes the function's result schema into
    /// the output slot and reports success.
    #[test]
    fn vtable_get_return_field_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        let field = Field::new("x", DataType::Int32, false);
        let ffi_schema = export_schema(&Schema::new(vec![field]));

        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                &raw const ffi_schema,
                1,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_eq!(rc, 0, "get_return_field should succeed");

        let schema = import_schema(&ret_schema);
        assert_eq!(schema.field(0).name(), "result");
        assert_eq!(*schema.field(0).data_type(), DataType::Int32);

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    /// The `call` entry evaluates the function on one argument and writes the
    /// result array and schema into the output slots.
    #[test]
    fn vtable_call_roundtrip() {
        let vtable = into_ffi(Arc::new(IncrementFn));

        let input = Int32Array::from(vec![1, 2, 3]);
        let data = export_array(&input);

        let mut ret_array = ArrowArray::empty();
        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.call)(
                vtable.ctx,
                &raw const data.array,
                &raw const data.schema,
                1,
                &raw mut ret_array,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_eq!(rc, 0, "call should succeed");

        let result_array = import_array(ArrowData {
            schema: ret_schema,
            array: ret_array,
        });
        let result = result_array.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(result.values(), &[2, 3, 4]);

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    /// An `Err` from `return_field` surfaces as a non-zero status plus an
    /// error message that the caller must free.
    #[test]
    fn vtable_error_propagation() {
        struct FailingFn;
        impl DaftScalarFunction for FailingFn {
            fn name(&self) -> &CStr {
                c"failing"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Err(DaftError::TypeError("bad type".into()))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Err(DaftError::RuntimeError("compute failed".into()))
            }
        }

        let vtable = into_ffi(Arc::new(FailingFn));

        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                std::ptr::null(),
                0,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_ne!(rc, 0, "should return non-zero on error");
        assert!(!errmsg.is_null());

        let err_str = unsafe { CStr::from_ptr(errmsg) }.to_str().unwrap();
        assert!(err_str.contains("bad type"), "error message: {err_str}");

        unsafe { free_string(errmsg) };
        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    /// An `Err` from `call` surfaces as a non-zero status plus an error
    /// message that the caller must free.
    #[test]
    fn vtable_call_error_propagation() {
        struct CallFailFn;
        impl DaftScalarFunction for CallFailFn {
            fn name(&self) -> &CStr {
                c"call_fail"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "x",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Err(DaftError::RuntimeError("compute failed".into()))
            }
        }

        let vtable = into_ffi(Arc::new(CallFailFn));

        let input = Int32Array::from(vec![1]);
        let data = export_array(&input);

        let mut ret_array = ArrowArray::empty();
        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.call)(
                vtable.ctx,
                &raw const data.array,
                &raw const data.schema,
                1,
                &raw mut ret_array,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };

        assert_ne!(rc, 0, "call should return non-zero on error");
        assert!(!errmsg.is_null());

        let err_str = unsafe { CStr::from_ptr(errmsg) }.to_str().unwrap();
        assert!(
            err_str.contains("compute failed"),
            "error message: {err_str}"
        );

        unsafe { free_string(errmsg) };
        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    /// Both `get_return_field` and `call` accept a zero argument count with
    /// null argument pointers.
    #[test]
    fn vtable_zero_args() {
        struct NoArgFn;
        impl DaftScalarFunction for NoArgFn {
            fn name(&self) -> &CStr {
                c"no_args"
            }
            fn return_field(&self, args: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                assert!(args.is_empty());
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "result",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                let output = Int32Array::from(vec![42]);
                Ok(export_array(&output))
            }
        }

        let vtable = into_ffi(Arc::new(NoArgFn));

        let mut ret_schema = ArrowSchema::empty();
        let mut errmsg: *mut c_char = std::ptr::null_mut();

        let rc = unsafe {
            (vtable.get_return_field)(
                vtable.ctx,
                std::ptr::null(),
                0,
                &raw mut ret_schema,
                &raw mut errmsg,
            )
        };
        assert_eq!(rc, 0, "get_return_field with zero args should succeed");

        let schema = import_schema(&ret_schema);
        assert_eq!(schema.field(0).name(), "result");

        // Also drive the `call` entry with no arguments; with a count of zero
        // the shim must not dereference the (null) argument buffers.
        let mut call_array = ArrowArray::empty();
        let mut call_schema = ArrowSchema::empty();

        let rc = unsafe {
            (vtable.call)(
                vtable.ctx,
                std::ptr::null(),
                std::ptr::null(),
                0,
                &raw mut call_array,
                &raw mut call_schema,
                &raw mut errmsg,
            )
        };
        assert_eq!(rc, 0, "call with zero args should succeed");

        let result_array = import_array(ArrowData {
            schema: call_schema,
            array: call_array,
        });
        let result = result_array.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(result.values(), &[42]);

        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }

    /// `fini` can be invoked on a freshly created vtable without any other
    /// entry having been called first.
    #[test]
    fn fini_is_callable() {
        struct DisposableFn;
        impl DaftScalarFunction for DisposableFn {
            fn name(&self) -> &CStr {
                c"disposable"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "x",
                    DataType::Null,
                    true,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                let output = Int32Array::from(vec![0]);
                Ok(export_array(&output))
            }
        }

        let vtable = into_ffi(Arc::new(DisposableFn));
        unsafe { (vtable.fini)(vtable.ctx.cast_mut()) };
    }
}