1use crate::{
2 abi::FFI_SessionContext,
3 function::{DaftScalarFunctionRef, into_ffi},
4};
5
6pub trait DaftSession {
8 fn define_function(&mut self, function: DaftScalarFunctionRef);
9}
10
11pub trait DaftExtension {
13 fn install(session: &mut dyn DaftSession);
14}
15
16pub struct SessionContext<'a> {
21 session: &'a mut FFI_SessionContext,
22}
23
24impl<'a> SessionContext<'a> {
25 pub fn new(session: &'a mut FFI_SessionContext) -> Self {
26 Self { session }
27 }
28}
29
30impl DaftSession for SessionContext<'_> {
31 fn define_function(&mut self, func: DaftScalarFunctionRef) {
32 let vtable = into_ffi(func);
33 let rc = unsafe { (self.session.define_function)(self.session.ctx, vtable) };
34 assert_eq!(rc, 0, "host define_function returned non-zero: {rc}");
35 }
36}
37
#[cfg(all(test, feature = "arrow-57"))]
mod tests {
    use std::{
        ffi::{CStr, c_int, c_void},
        sync::{Arc, Mutex},
    };

    use arrow_schema_57::{DataType, Field, Schema};

    use super::*;
    use crate::{
        abi::{ArrowData, ArrowSchema, FFI_ScalarFunction},
        error::DaftResult,
        function::DaftScalarFunction,
    };

    /// Exports an arrow-rs `Schema` through the C data interface and wraps it as an `ArrowSchema`.
    fn export_schema(schema: &Schema) -> ArrowSchema {
        let ffi = arrow_schema_57::ffi::FFI_ArrowSchema::try_from(schema).unwrap();
        // SAFETY: `ffi` was just produced by arrow-rs and is uniquely owned here; presumably
        // that is what `from_owned` requires — confirm against its `# Safety` docs.
        unsafe { ArrowSchema::from_owned(ffi) }
    }

    /// Builds placeholder `ArrowData`: a non-nullable, unnamed Int32 field schema paired with
    /// `ArrowArray::empty()`. Only used as a return value; the tests never inspect it.
    fn stub_int32() -> ArrowData {
        let schema = {
            let field = Field::new("", DataType::Int32, false);
            let ffi = arrow_schema_57::ffi::FFI_ArrowSchema::try_from(&field).unwrap();
            // SAFETY: same reasoning as in `export_schema`.
            unsafe { ArrowSchema::from_owned(ffi) }
        };
        ArrowData {
            schema,
            array: crate::abi::ArrowArray::empty(),
        }
    }

    /// Reads the function name from a vtable, then releases the vtable by calling its `fini`.
    ///
    /// The name is decoded lossily so that a malformed name surfaces as a failed assertion in
    /// the test body. A panic here would occur inside an `extern "C"` mock and abort the whole
    /// test process instead.
    ///
    /// # Safety
    /// `func` must be a live vtable (e.g. produced by `into_ffi`) that has not yet been finalized.
    unsafe fn take_name(func: FFI_ScalarFunction) -> String {
        let name = unsafe { CStr::from_ptr((func.name)(func.ctx)) }
            .to_string_lossy()
            .into_owned();
        unsafe { (func.fini)(func.ctx.cast_mut()) };
        name
    }

    #[test]
    fn session_context_integration() {
        // Private to this test, so parallel tests cannot interleave writes.
        static RECORDED: Mutex<Vec<String>> = Mutex::new(Vec::new());

        unsafe extern "C" fn mock_define(_ctx: *mut c_void, func: FFI_ScalarFunction) -> c_int {
            // SAFETY: the vtable comes straight from `into_ffi` and is consumed exactly once.
            let name = unsafe { take_name(func) };
            RECORDED.lock().unwrap().push(name);
            0
        }

        let mut raw_session = FFI_SessionContext {
            ctx: std::ptr::null_mut(),
            define_function: mock_define,
        };

        let mut session = SessionContext::new(&mut raw_session);

        struct AddFn;
        impl DaftScalarFunction for AddFn {
            fn name(&self) -> &CStr {
                c"my_add"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "sum",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(stub_int32())
            }
        }

        session.define_function(Arc::new(AddFn));

        // Exact comparison also catches duplicate registrations, which `contains` would miss.
        let recorded = RECORDED.lock().unwrap();
        assert_eq!(
            *recorded,
            ["my_add"],
            "expected exactly one registration of 'my_add'"
        );
    }

    #[test]
    fn session_context_multiple_functions() {
        // Private to this test, so parallel tests cannot interleave writes.
        static NAMES: Mutex<Vec<String>> = Mutex::new(Vec::new());

        unsafe extern "C" fn mock_define(_ctx: *mut c_void, func: FFI_ScalarFunction) -> c_int {
            // SAFETY: the vtable comes straight from `into_ffi` and is consumed exactly once.
            let name = unsafe { take_name(func) };
            NAMES.lock().unwrap().push(name);
            0
        }

        let mut raw_session = FFI_SessionContext {
            ctx: std::ptr::null_mut(),
            define_function: mock_define,
        };

        let mut session = SessionContext::new(&mut raw_session);

        struct FnA;
        impl DaftScalarFunction for FnA {
            fn name(&self) -> &CStr {
                c"fn_a"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "a",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(stub_int32())
            }
        }

        struct FnB;
        impl DaftScalarFunction for FnB {
            fn name(&self) -> &CStr {
                c"fn_b"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "b",
                    DataType::Utf8,
                    true,
                )])))
            }
            // Never invoked by this test; the stub's Int32 type need not match `return_field`.
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                Ok(stub_int32())
            }
        }

        session.define_function(Arc::new(FnA));
        session.define_function(Arc::new(FnB));

        // Registrations happen synchronously in call order, so the full sequence is deterministic.
        let names = NAMES.lock().unwrap();
        assert_eq!(*names, ["fn_a", "fn_b"]);
    }
}