Skip to main content

daft_ext/
session.rs

1use crate::{
2    abi::FFI_SessionContext,
3    function::{DaftScalarFunctionRef, into_ffi},
4};
5
6/// Trait for installing an extension within a session.
7pub trait DaftSession {
8    fn define_function(&mut self, function: DaftScalarFunctionRef);
9}
10
11/// Trait implemented by extension crates to install themselves.
12pub trait DaftExtension {
13    fn install(session: &mut dyn DaftSession);
14}
15
16/// A [`DaftSession`] backed by a [`FFI_SessionContext`] from the host.
17///
18/// This bridges the safe `DaftSession` trait to the C ABI session
19/// context provided by the Daft host.
20pub struct SessionContext<'a> {
21    session: &'a mut FFI_SessionContext,
22}
23
24impl<'a> SessionContext<'a> {
25    pub fn new(session: &'a mut FFI_SessionContext) -> Self {
26        Self { session }
27    }
28}
29
30impl DaftSession for SessionContext<'_> {
31    fn define_function(&mut self, func: DaftScalarFunctionRef) {
32        let vtable = into_ffi(func);
33        let rc = unsafe { (self.session.define_function)(self.session.ctx, vtable) };
34        assert_eq!(rc, 0, "host define_function returned non-zero: {rc}");
35    }
36}
37
/// Tests driving [`SessionContext`] against mock host `define_function` callbacks.
#[cfg(all(test, feature = "arrow-57"))]
mod tests {
    use std::{
        ffi::{CStr, c_int, c_void},
        sync::{Arc, Mutex},
    };

    use arrow_schema_57::{DataType, Field, Schema};

    use super::*;
    use crate::{
        abi::{ArrowData, ArrowSchema, FFI_ScalarFunction},
        error::DaftResult,
        function::DaftScalarFunction,
    };

    /// Exports `schema` through `FFI_ArrowSchema` and wraps it as an owned [`ArrowSchema`].
    fn export_schema(schema: &Schema) -> ArrowSchema {
        let ffi = arrow_schema_57::ffi::FFI_ArrowSchema::try_from(schema).unwrap();
        unsafe { ArrowSchema::from_owned(ffi) }
    }

    /// Minimal stub [`ArrowData`] for tests that never inspect the array.
    fn stub_int32() -> ArrowData {
        let schema = {
            let field = Field::new("", DataType::Int32, false);
            let ffi = arrow_schema_57::ffi::FFI_ArrowSchema::try_from(&field).unwrap();
            unsafe { ArrowSchema::from_owned(ffi) }
        };
        ArrowData {
            schema,
            array: crate::abi::ArrowArray::empty(),
        }
    }

    /// Shared body of the mock host `define_function` callbacks.
    ///
    /// Pushes the function's name onto `sink`, then releases the function via its
    /// `fini` hook, because the mock host takes ownership of every function it receives.
    ///
    /// # Safety
    ///
    /// `func.name` and `func.fini` must be callable with `func.ctx`, `func.name`
    /// must return a valid NUL-terminated string, and `func` must not already
    /// have been finalized.
    unsafe fn record_and_release(func: FFI_ScalarFunction, sink: &Mutex<Vec<String>>) {
        let name = unsafe { CStr::from_ptr((func.name)(func.ctx)) }
            .to_str()
            .unwrap()
            .to_string();
        sink.lock().unwrap().push(name);
        unsafe { (func.fini)(func.ctx.cast_mut()) };
    }

    #[test]
    fn session_context_integration() {
        static RECORDED: Mutex<Vec<String>> = Mutex::new(Vec::new());

        unsafe extern "C" fn mock_define(_ctx: *mut c_void, func: FFI_ScalarFunction) -> c_int {
            // SAFETY: `SessionContext::define_function` hands over a fresh vtable.
            unsafe { record_and_release(func, &RECORDED) };
            0
        }

        let mut raw_session = FFI_SessionContext {
            ctx: std::ptr::null_mut(),
            define_function: mock_define,
        };

        let mut session = SessionContext::new(&mut raw_session);

        struct AddFn;
        impl DaftScalarFunction for AddFn {
            fn name(&self) -> &CStr {
                c"my_add"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "sum",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(stub_int32())
            }
        }

        session.define_function(Arc::new(AddFn));

        // Exactly one registration must reach the host, carrying the function's name.
        assert_eq!(
            *RECORDED.lock().unwrap(),
            ["my_add"],
            "expected exactly 'my_add' in recorded functions"
        );
    }

    #[test]
    fn session_context_multiple_functions() {
        static NAMES: Mutex<Vec<String>> = Mutex::new(Vec::new());

        unsafe extern "C" fn mock_define(_ctx: *mut c_void, func: FFI_ScalarFunction) -> c_int {
            // SAFETY: `SessionContext::define_function` hands over a fresh vtable.
            unsafe { record_and_release(func, &NAMES) };
            0
        }

        let mut raw_session = FFI_SessionContext {
            ctx: std::ptr::null_mut(),
            define_function: mock_define,
        };

        let mut session = SessionContext::new(&mut raw_session);

        struct FnA;
        impl DaftScalarFunction for FnA {
            fn name(&self) -> &CStr {
                c"fn_a"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "a",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(stub_int32())
            }
        }

        struct FnB;
        impl DaftScalarFunction for FnB {
            fn name(&self) -> &CStr {
                c"fn_b"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "b",
                    DataType::Utf8,
                    true,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(stub_int32())
            }
        }

        session.define_function(Arc::new(FnA));
        session.define_function(Arc::new(FnB));

        // Each call forwards exactly one registration, in call order.
        assert_eq!(*NAMES.lock().unwrap(), ["fn_a", "fn_b"]);
    }

    #[test]
    #[should_panic(expected = "host define_function returned non-zero: 7")]
    fn session_context_panics_on_host_error() {
        unsafe extern "C" fn failing_define(
            _ctx: *mut c_void,
            func: FFI_ScalarFunction,
        ) -> c_int {
            // This mock host rejects the registration. It still releases the
            // function so the test does not leak; whether a real host does so
            // on failure is outside this file's view.
            unsafe { (func.fini)(func.ctx.cast_mut()) };
            7
        }

        let mut raw_session = FFI_SessionContext {
            ctx: std::ptr::null_mut(),
            define_function: failing_define,
        };

        let mut session = SessionContext::new(&mut raw_session);

        struct RejectedFn;
        impl DaftScalarFunction for RejectedFn {
            fn name(&self) -> &CStr {
                c"rejected"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "out",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(stub_int32())
            }
        }

        // A non-zero host status must surface as a panic.
        session.define_function(Arc::new(RejectedFn));
    }
}