Skip to main content

daft_ext/
session.rs

1use crate::abi::FFI_SessionContext;
2
3use crate::function::{DaftScalarFunctionRef, into_ffi};
4
5/// Trait for installing an extension within a session.
6pub trait DaftSession {
7    fn define_function(&mut self, function: DaftScalarFunctionRef);
8}
9
10/// Trait implemented by extension crates to install themselves.
11pub trait DaftExtension {
12    fn install(session: &mut dyn DaftSession);
13}
14
15/// A [`DaftSession`] backed by a [`FFI_SessionContext`] from the host.
16///
17/// This bridges the safe `DaftSession` trait to the C ABI session
18/// context provided by the Daft host.
19pub struct SessionContext<'a> {
20    session: &'a mut FFI_SessionContext,
21}
22
23impl<'a> SessionContext<'a> {
24    pub fn new(session: &'a mut FFI_SessionContext) -> Self {
25        Self { session }
26    }
27}
28
29impl DaftSession for SessionContext<'_> {
30    fn define_function(&mut self, func: DaftScalarFunctionRef) {
31        let vtable = into_ffi(func);
32        let rc = unsafe { (self.session.define_function)(self.session.ctx, vtable) };
33        assert_eq!(rc, 0, "host define_function returned non-zero: {rc}");
34    }
35}
36
#[cfg(test)]
mod tests {
    use std::{
        ffi::{CStr, c_int, c_void},
        sync::Arc,
    };

    use arrow_array::{Array, new_empty_array};
    use arrow_schema::{DataType, Field, Schema};

    use super::*;
    use crate::{
        abi::{ArrowArray, ArrowData, ArrowSchema, FFI_ScalarFunction},
        error::DaftResult,
        function::DaftScalarFunction,
    };

    /// Exports an in-memory Arrow array through the Arrow C data interface.
    fn export_array(array: &dyn Array) -> ArrowData {
        let (ffi_array, ffi_schema) = arrow::ffi::to_ffi(&array.to_data()).unwrap();
        // NOTE(review): assumes `from_owned` only requires a freshly exported,
        // otherwise-unreferenced FFI struct — confirm against its `# Safety` docs.
        ArrowData {
            array: unsafe { ArrowArray::from_owned(ffi_array) },
            schema: unsafe { ArrowSchema::from_owned(ffi_schema) },
        }
    }

    /// Exports an Arrow schema through the Arrow C data interface.
    fn export_schema(schema: &Schema) -> ArrowSchema {
        let ffi = arrow::ffi::FFI_ArrowSchema::try_from(schema).unwrap();
        // NOTE(review): same `from_owned` assumption as in `export_array`.
        unsafe { ArrowSchema::from_owned(ffi) }
    }

    /// Test double for a scalar function with a fixed name and return field.
    struct MockFn {
        name: &'static CStr,
        field: Field,
    }

    impl DaftScalarFunction for MockFn {
        fn name(&self) -> &CStr {
            self.name
        }

        fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
            Ok(export_schema(&Schema::new(vec![self.field.clone()])))
        }

        fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
            // Emit an empty array of the declared type so `call` agrees with
            // `return_field` (e.g. a Utf8 field does not produce an Int32 array).
            let output = new_empty_array(self.field.data_type());
            Ok(export_array(&*output))
        }
    }

    /// Mock host `define_function` callback.
    ///
    /// Appends the function's name to the `Vec<String>` that `ctx` points to,
    /// then releases the function through its `fini` entry.
    unsafe extern "C" fn recording_define(ctx: *mut c_void, func: FFI_ScalarFunction) -> c_int {
        // Lossy conversion never panics. A panic here cannot unwind out of an
        // `extern "C"` fn, so it would take down the test process instead of
        // failing the test.
        let name = unsafe { CStr::from_ptr((func.name)(func.ctx)) }
            .to_string_lossy()
            .into_owned();
        // SAFETY: `define_all` is the only code that installs this callback. It
        // always passes a pointer to its live `Vec<String>`, and nothing else
        // borrows that Vec while the session is in use.
        unsafe { &mut *ctx.cast::<Vec<String>>() }.push(name);
        // Like a host would, release the vtable once it is no longer needed.
        unsafe { (func.fini)(func.ctx.cast_mut()) };
        0
    }

    /// Installs `functions` through a [`SessionContext`] backed by
    /// [`recording_define`].
    ///
    /// Returns the names the mock host observed, in the order they arrived.
    fn define_all(functions: Vec<MockFn>) -> Vec<String> {
        let mut recorded: Vec<String> = Vec::new();
        // Route a real context pointer through the session. This also checks
        // that `SessionContext` forwards the host `ctx` to the callback unchanged.
        let mut raw_session = FFI_SessionContext {
            ctx: std::ptr::addr_of_mut!(recorded).cast::<c_void>(),
            define_function: recording_define,
        };
        let mut session = SessionContext::new(&mut raw_session);
        for function in functions {
            session.define_function(Arc::new(function));
        }
        recorded
    }

    #[test]
    fn session_context_integration() {
        let recorded = define_all(vec![MockFn {
            name: c"my_add",
            field: Field::new("sum", DataType::Int32, false),
        }]);
        assert_eq!(
            recorded,
            ["my_add"],
            "expected exactly 'my_add' in recorded functions"
        );
    }

    #[test]
    fn session_context_multiple_functions() {
        let recorded = define_all(vec![
            MockFn {
                name: c"fn_a",
                field: Field::new("a", DataType::Int32, false),
            },
            MockFn {
                name: c"fn_b",
                field: Field::new("b", DataType::Utf8, true),
            },
        ]);
        // `define_function` calls the host synchronously, so the host sees
        // functions exactly once each, in call order.
        assert_eq!(recorded, ["fn_a", "fn_b"]);
    }
}