1use crate::abi::FFI_SessionContext;
2
3use crate::function::{DaftScalarFunctionRef, into_ffi};
4
5pub trait DaftSession {
7 fn define_function(&mut self, function: DaftScalarFunctionRef);
8}
9
10pub trait DaftExtension {
12 fn install(session: &mut dyn DaftSession);
13}
14
15pub struct SessionContext<'a> {
20 session: &'a mut FFI_SessionContext,
21}
22
23impl<'a> SessionContext<'a> {
24 pub fn new(session: &'a mut FFI_SessionContext) -> Self {
25 Self { session }
26 }
27}
28
29impl DaftSession for SessionContext<'_> {
30 fn define_function(&mut self, func: DaftScalarFunctionRef) {
31 let vtable = into_ffi(func);
32 let rc = unsafe { (self.session.define_function)(self.session.ctx, vtable) };
33 assert_eq!(rc, 0, "host define_function returned non-zero: {rc}");
34 }
35}
36
#[cfg(test)]
mod tests {
    use std::{
        ffi::{CStr, c_int, c_void},
        sync::{Arc, Mutex},
    };

    use arrow_array::{Array, Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    use super::*;
    use crate::{
        abi::{ArrowArray, ArrowData, ArrowSchema, FFI_ScalarFunction},
        error::DaftResult,
        function::DaftScalarFunction,
    };

    /// Exports `array` through the Arrow C data interface as an owned
    /// [`ArrowData`] pair.
    fn export_array(array: &dyn Array) -> ArrowData {
        let (ffi_array, ffi_schema) = arrow::ffi::to_ffi(&array.to_data()).unwrap();
        // SAFETY: both structs were just produced by arrow's exporter and are
        // moved in here, so nothing else owns them (assumed to be what
        // `from_owned` requires).
        ArrowData {
            array: unsafe { ArrowArray::from_owned(ffi_array) },
            schema: unsafe { ArrowSchema::from_owned(ffi_schema) },
        }
    }

    /// Exports `schema` through the Arrow C data interface.
    fn export_schema(schema: &Schema) -> ArrowSchema {
        let ffi = arrow::ffi::FFI_ArrowSchema::try_from(schema).unwrap();
        // SAFETY: `ffi` was freshly exported and is moved in, so it is uniquely owned.
        unsafe { ArrowSchema::from_owned(ffi) }
    }

    /// Mimics a host taking ownership of a defined function: reads its name
    /// through the vtable, then releases it via `fini`.
    ///
    /// Panics (aborting, when called from an `extern "C"` mock) if the name is
    /// not valid UTF-8.
    ///
    /// # Safety
    ///
    /// `func` must be a live vtable (e.g. produced by `into_ffi`) that has not
    /// been finalized; it is finalized by this call and must not be used again.
    unsafe fn consume_function(func: FFI_ScalarFunction) -> String {
        // Copy the name out *before* `fini`: the returned pointer may borrow
        // from `func.ctx`.
        let name = unsafe { CStr::from_ptr((func.name)(func.ctx)) }
            .to_str()
            .unwrap()
            .to_string();
        unsafe { (func.fini)(func.ctx.cast_mut()) };
        name
    }

    #[test]
    fn session_context_integration() {
        static RECORDED: Mutex<Vec<String>> = Mutex::new(Vec::new());

        unsafe extern "C" fn mock_define(_ctx: *mut c_void, func: FFI_ScalarFunction) -> c_int {
            let name = unsafe { consume_function(func) };
            RECORDED.lock().unwrap().push(name);
            0
        }

        struct AddFn;
        impl DaftScalarFunction for AddFn {
            fn name(&self) -> &CStr {
                c"my_add"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "sum",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                let output = Int32Array::from(vec![0]);
                Ok(export_array(&output))
            }
        }

        let mut raw_session = FFI_SessionContext {
            ctx: std::ptr::null_mut(),
            define_function: mock_define,
        };
        let mut session = SessionContext::new(&mut raw_session);

        session.define_function(Arc::new(AddFn));

        let recorded = RECORDED.lock().unwrap();
        assert_eq!(
            *recorded,
            ["my_add"],
            "expected exactly 'my_add' to be registered once"
        );
    }

    #[test]
    fn session_context_multiple_functions() {
        static NAMES: Mutex<Vec<String>> = Mutex::new(Vec::new());

        unsafe extern "C" fn mock_define(_ctx: *mut c_void, func: FFI_ScalarFunction) -> c_int {
            let name = unsafe { consume_function(func) };
            NAMES.lock().unwrap().push(name);
            0
        }

        struct FnA;
        impl DaftScalarFunction for FnA {
            fn name(&self) -> &CStr {
                c"fn_a"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "a",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                let output = Int32Array::from(vec![0]);
                Ok(export_array(&output))
            }
        }

        struct FnB;
        impl DaftScalarFunction for FnB {
            fn name(&self) -> &CStr {
                c"fn_b"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "b",
                    DataType::Utf8,
                    true,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                // Must match the Utf8 type declared by `return_field`.
                let output = StringArray::from(vec!["b"]);
                Ok(export_array(&output))
            }
        }

        let mut raw_session = FFI_SessionContext {
            ctx: std::ptr::null_mut(),
            define_function: mock_define,
        };
        let mut session = SessionContext::new(&mut raw_session);

        session.define_function(Arc::new(FnA));
        session.define_function(Arc::new(FnB));

        // Registrations are sequential, so the host must observe them in order.
        let names = NAMES.lock().unwrap();
        assert_eq!(*names, ["fn_a", "fn_b"]);
    }

    #[test]
    #[should_panic(expected = "host define_function returned non-zero: 1")]
    fn session_context_panics_when_host_rejects() {
        unsafe extern "C" fn rejecting_define(
            _ctx: *mut c_void,
            func: FFI_ScalarFunction,
        ) -> c_int {
            // Release the function so the test does not leak it, then report failure.
            let _ = unsafe { consume_function(func) };
            1
        }

        struct RejectedFn;
        impl DaftScalarFunction for RejectedFn {
            fn name(&self) -> &CStr {
                c"rejected"
            }
            fn return_field(&self, _: &[ArrowSchema]) -> DaftResult<ArrowSchema> {
                Ok(export_schema(&Schema::new(vec![Field::new(
                    "r",
                    DataType::Int32,
                    false,
                )])))
            }
            fn call(&self, _: Vec<ArrowData>) -> DaftResult<ArrowData> {
                let output = Int32Array::from(vec![0]);
                Ok(export_array(&output))
            }
        }

        let mut raw_session = FFI_SessionContext {
            ctx: std::ptr::null_mut(),
            define_function: rejecting_define,
        };
        let mut session = SessionContext::new(&mut raw_session);

        session.define_function(Arc::new(RejectedFn));
    }
}