datafusion_ffi/execution/
task_ctx.rs1use std::collections::HashMap;
19use std::ffi::c_void;
20use std::sync::Arc;
21
22use datafusion_execution::TaskContext;
23use datafusion_execution::config::SessionConfig;
24use datafusion_execution::runtime_env::RuntimeEnv;
25use datafusion_expr::{
26 AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl,
27};
28
29use stabby::string::String as SString;
30use stabby::vec::Vec as SVec;
31
32use crate::session::config::FFI_SessionConfig;
33use crate::udaf::FFI_AggregateUDF;
34use crate::udf::FFI_ScalarUDF;
35use crate::udwf::FFI_WindowUDF;
36use crate::util::FFI_Option;
37
38#[repr(C)]
40#[derive(Debug)]
41pub struct FFI_TaskContext {
42 pub session_id: unsafe extern "C" fn(&Self) -> SString,
44
45 pub task_id: unsafe extern "C" fn(&Self) -> FFI_Option<SString>,
47
48 pub session_config: unsafe extern "C" fn(&Self) -> FFI_SessionConfig,
50
51 pub scalar_functions: unsafe extern "C" fn(&Self) -> SVec<(SString, FFI_ScalarUDF)>,
53
54 pub aggregate_functions:
56 unsafe extern "C" fn(&Self) -> SVec<(SString, FFI_AggregateUDF)>,
57
58 pub window_functions: unsafe extern "C" fn(&Self) -> SVec<(SString, FFI_WindowUDF)>,
60
61 pub release: unsafe extern "C" fn(arg: &mut Self),
63
64 pub private_data: *mut c_void,
67
68 pub library_marker_id: extern "C" fn() -> usize,
72}
73
/// Heap-allocated state stored behind `FFI_TaskContext::private_data`.
///
/// It is boxed when an `FFI_TaskContext` is built from an `Arc<TaskContext>`
/// and freed again by the `release` callback.
struct TaskContextPrivateData {
    /// The wrapped task context that every accessor callback reads from.
    ctx: Arc<TaskContext>,
}
77
78impl FFI_TaskContext {
79 unsafe fn inner(&self) -> &Arc<TaskContext> {
80 unsafe {
81 let private_data = self.private_data as *const TaskContextPrivateData;
82 &(*private_data).ctx
83 }
84 }
85}
86
87unsafe extern "C" fn session_id_fn_wrapper(ctx: &FFI_TaskContext) -> SString {
88 unsafe {
89 let ctx = ctx.inner();
90 ctx.session_id().into()
91 }
92}
93
94unsafe extern "C" fn task_id_fn_wrapper(ctx: &FFI_TaskContext) -> FFI_Option<SString> {
95 unsafe {
96 let ctx = ctx.inner();
97 ctx.task_id().map(|s| s.as_str().into()).into()
98 }
99}
100
101unsafe extern "C" fn session_config_fn_wrapper(
102 ctx: &FFI_TaskContext,
103) -> FFI_SessionConfig {
104 unsafe {
105 let ctx = ctx.inner();
106 ctx.session_config().into()
107 }
108}
109
110unsafe extern "C" fn scalar_functions_fn_wrapper(
111 ctx: &FFI_TaskContext,
112) -> SVec<(SString, FFI_ScalarUDF)> {
113 unsafe {
114 let ctx = ctx.inner();
115 ctx.scalar_functions()
116 .iter()
117 .map(|(name, udf)| (name.to_owned().into(), Arc::clone(udf).into()))
118 .collect()
119 }
120}
121
122unsafe extern "C" fn aggregate_functions_fn_wrapper(
123 ctx: &FFI_TaskContext,
124) -> SVec<(SString, FFI_AggregateUDF)> {
125 unsafe {
126 let ctx = ctx.inner();
127 ctx.aggregate_functions()
128 .iter()
129 .map(|(name, udaf)| {
130 (
131 name.to_owned().into(),
132 FFI_AggregateUDF::from(Arc::clone(udaf)),
133 )
134 })
135 .collect()
136 }
137}
138
139unsafe extern "C" fn window_functions_fn_wrapper(
140 ctx: &FFI_TaskContext,
141) -> SVec<(SString, FFI_WindowUDF)> {
142 unsafe {
143 let ctx = ctx.inner();
144 ctx.window_functions()
145 .iter()
146 .map(|(name, udf)| {
147 (name.to_owned().into(), FFI_WindowUDF::from(Arc::clone(udf)))
148 })
149 .collect()
150 }
151}
152
153unsafe extern "C" fn release_fn_wrapper(ctx: &mut FFI_TaskContext) {
154 unsafe {
155 let private_data = Box::from_raw(ctx.private_data as *mut TaskContextPrivateData);
156 drop(private_data);
157 }
158}
159
160impl Drop for FFI_TaskContext {
161 fn drop(&mut self) {
162 unsafe { (self.release)(self) }
163 }
164}
165
166impl From<Arc<TaskContext>> for FFI_TaskContext {
167 fn from(ctx: Arc<TaskContext>) -> Self {
168 let private_data = Box::new(TaskContextPrivateData { ctx });
169
170 FFI_TaskContext {
171 session_id: session_id_fn_wrapper,
172 task_id: task_id_fn_wrapper,
173 session_config: session_config_fn_wrapper,
174 scalar_functions: scalar_functions_fn_wrapper,
175 aggregate_functions: aggregate_functions_fn_wrapper,
176 window_functions: window_functions_fn_wrapper,
177 release: release_fn_wrapper,
178 private_data: Box::into_raw(private_data) as *mut c_void,
179 library_marker_id: crate::get_library_marker_id,
180 }
181 }
182}
183
184impl From<FFI_TaskContext> for Arc<TaskContext> {
185 fn from(ffi_ctx: FFI_TaskContext) -> Self {
186 unsafe {
187 if (ffi_ctx.library_marker_id)() == crate::get_library_marker_id() {
188 return Arc::clone(ffi_ctx.inner());
189 }
190
191 let task_id = (ffi_ctx.task_id)(&ffi_ctx).map(|s| s.to_string()).into();
192 let session_id = (ffi_ctx.session_id)(&ffi_ctx).into();
193 let session_config = (ffi_ctx.session_config)(&ffi_ctx);
194 let session_config =
195 SessionConfig::try_from(&session_config).unwrap_or_default();
196
197 let scalar_functions = (ffi_ctx.scalar_functions)(&ffi_ctx)
198 .into_iter()
199 .map(|kv_pair| {
200 let udf = <Arc<dyn ScalarUDFImpl>>::from(&kv_pair.1);
201
202 (
203 kv_pair.0.to_string(),
204 Arc::new(ScalarUDF::new_from_shared_impl(udf)),
205 )
206 })
207 .collect();
208 let aggregate_functions = (ffi_ctx.aggregate_functions)(&ffi_ctx)
209 .into_iter()
210 .map(|kv_pair| {
211 let udaf = <Arc<dyn AggregateUDFImpl>>::from(&kv_pair.1);
212
213 (
214 kv_pair.0.to_string(),
215 Arc::new(AggregateUDF::new_from_shared_impl(udaf)),
216 )
217 })
218 .collect();
219 let window_functions = (ffi_ctx.window_functions)(&ffi_ctx)
220 .into_iter()
221 .map(|kv_pair| {
222 let udwf = <Arc<dyn WindowUDFImpl>>::from(&kv_pair.1);
223
224 (
225 kv_pair.0.to_string(),
226 Arc::new(WindowUDF::new_from_shared_impl(udwf)),
227 )
228 })
229 .collect();
230
231 let runtime = Arc::new(RuntimeEnv::default());
232
233 Arc::new(TaskContext::new(
234 task_id,
235 session_id,
236 session_config,
237 scalar_functions,
238 HashMap::new(),
239 aggregate_functions,
240 window_functions,
241 runtime,
242 ))
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::prelude::SessionContext;
    use datafusion_common::Result;
    use datafusion_execution::TaskContext;

    use crate::execution::FFI_TaskContext;

    /// Converts a task context to its FFI handle and back, then checks that
    /// ids, configuration and registered functions match the original.
    #[test]
    fn ffi_task_ctx_round_trip() -> Result<()> {
        let session_ctx = SessionContext::new();
        let original = session_ctx.task_ctx();

        // Replace the library marker with the mock one; presumably it differs
        // from this library's marker, so the conversion below rebuilds the
        // context through the callbacks instead of cloning the inner `Arc`.
        let mut ffi_task_ctx = FFI_TaskContext::from(Arc::clone(&original));
        ffi_task_ctx.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_task_ctx = Arc::<TaskContext>::from(ffi_task_ctx);

        // Registered functions of every kind must match.
        assert_eq!(
            original.scalar_functions(),
            foreign_task_ctx.scalar_functions()
        );
        assert_eq!(
            original.aggregate_functions(),
            foreign_task_ctx.aggregate_functions()
        );
        assert_eq!(
            original.window_functions(),
            foreign_task_ctx.window_functions()
        );

        // Identifiers must match.
        assert_eq!(original.task_id(), foreign_task_ctx.task_id());
        assert_eq!(original.session_id(), foreign_task_ctx.session_id());

        // The configuration is compared through its `Debug` rendering.
        let original_config = format!("{:?}", original.session_config());
        let foreign_config = format!("{:?}", foreign_task_ctx.session_config());
        assert_eq!(original_config, foreign_config);

        Ok(())
    }
}