datafusion_ffi/execution/
task_ctx.rs1use std::ffi::c_void;
19use std::sync::Arc;
20
21use abi_stable::StableAbi;
22use abi_stable::pmr::ROption;
23use abi_stable::std_types::{RHashMap, RString};
24use datafusion_execution::TaskContext;
25use datafusion_execution::config::SessionConfig;
26use datafusion_execution::runtime_env::RuntimeEnv;
27use datafusion_expr::{
28 AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl,
29};
30
31use crate::session::config::FFI_SessionConfig;
32use crate::udaf::FFI_AggregateUDF;
33use crate::udf::FFI_ScalarUDF;
34use crate::udwf::FFI_WindowUDF;
35
36#[repr(C)]
38#[derive(Debug, StableAbi)]
39pub struct FFI_TaskContext {
40 pub session_id: unsafe extern "C" fn(&Self) -> RString,
42
43 pub task_id: unsafe extern "C" fn(&Self) -> ROption<RString>,
45
46 pub session_config: unsafe extern "C" fn(&Self) -> FFI_SessionConfig,
48
49 pub scalar_functions: unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_ScalarUDF>,
51
52 pub aggregate_functions:
54 unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_AggregateUDF>,
55
56 pub window_functions: unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_WindowUDF>,
58
59 pub release: unsafe extern "C" fn(arg: &mut Self),
61
62 pub private_data: *mut c_void,
65
66 pub library_marker_id: extern "C" fn() -> usize,
70}
71
72struct TaskContextPrivateData {
73 ctx: Arc<TaskContext>,
74}
75
76impl FFI_TaskContext {
77 unsafe fn inner(&self) -> &Arc<TaskContext> {
78 unsafe {
79 let private_data = self.private_data as *const TaskContextPrivateData;
80 &(*private_data).ctx
81 }
82 }
83}
84
85unsafe extern "C" fn session_id_fn_wrapper(ctx: &FFI_TaskContext) -> RString {
86 unsafe {
87 let ctx = ctx.inner();
88 ctx.session_id().into()
89 }
90}
91
92unsafe extern "C" fn task_id_fn_wrapper(ctx: &FFI_TaskContext) -> ROption<RString> {
93 unsafe {
94 let ctx = ctx.inner();
95 ctx.task_id().map(|s| s.as_str().into()).into()
96 }
97}
98
99unsafe extern "C" fn session_config_fn_wrapper(
100 ctx: &FFI_TaskContext,
101) -> FFI_SessionConfig {
102 unsafe {
103 let ctx = ctx.inner();
104 ctx.session_config().into()
105 }
106}
107
108unsafe extern "C" fn scalar_functions_fn_wrapper(
109 ctx: &FFI_TaskContext,
110) -> RHashMap<RString, FFI_ScalarUDF> {
111 unsafe {
112 let ctx = ctx.inner();
113 ctx.scalar_functions()
114 .iter()
115 .map(|(name, udf)| (name.to_owned().into(), Arc::clone(udf).into()))
116 .collect()
117 }
118}
119
120unsafe extern "C" fn aggregate_functions_fn_wrapper(
121 ctx: &FFI_TaskContext,
122) -> RHashMap<RString, FFI_AggregateUDF> {
123 unsafe {
124 let ctx = ctx.inner();
125 ctx.aggregate_functions()
126 .iter()
127 .map(|(name, udaf)| {
128 (
129 name.to_owned().into(),
130 FFI_AggregateUDF::from(Arc::clone(udaf)),
131 )
132 })
133 .collect()
134 }
135}
136
137unsafe extern "C" fn window_functions_fn_wrapper(
138 ctx: &FFI_TaskContext,
139) -> RHashMap<RString, FFI_WindowUDF> {
140 unsafe {
141 let ctx = ctx.inner();
142 ctx.window_functions()
143 .iter()
144 .map(|(name, udf)| {
145 (name.to_owned().into(), FFI_WindowUDF::from(Arc::clone(udf)))
146 })
147 .collect()
148 }
149}
150
151unsafe extern "C" fn release_fn_wrapper(ctx: &mut FFI_TaskContext) {
152 unsafe {
153 let private_data = Box::from_raw(ctx.private_data as *mut TaskContextPrivateData);
154 drop(private_data);
155 }
156}
157
158impl Drop for FFI_TaskContext {
159 fn drop(&mut self) {
160 unsafe { (self.release)(self) }
161 }
162}
163
164impl From<Arc<TaskContext>> for FFI_TaskContext {
165 fn from(ctx: Arc<TaskContext>) -> Self {
166 let private_data = Box::new(TaskContextPrivateData { ctx });
167
168 FFI_TaskContext {
169 session_id: session_id_fn_wrapper,
170 task_id: task_id_fn_wrapper,
171 session_config: session_config_fn_wrapper,
172 scalar_functions: scalar_functions_fn_wrapper,
173 aggregate_functions: aggregate_functions_fn_wrapper,
174 window_functions: window_functions_fn_wrapper,
175 release: release_fn_wrapper,
176 private_data: Box::into_raw(private_data) as *mut c_void,
177 library_marker_id: crate::get_library_marker_id,
178 }
179 }
180}
181
182impl From<FFI_TaskContext> for Arc<TaskContext> {
183 fn from(ffi_ctx: FFI_TaskContext) -> Self {
184 unsafe {
185 if (ffi_ctx.library_marker_id)() == crate::get_library_marker_id() {
186 return Arc::clone(ffi_ctx.inner());
187 }
188
189 let task_id = (ffi_ctx.task_id)(&ffi_ctx).map(|s| s.to_string()).into();
190 let session_id = (ffi_ctx.session_id)(&ffi_ctx).into();
191 let session_config = (ffi_ctx.session_config)(&ffi_ctx);
192 let session_config =
193 SessionConfig::try_from(&session_config).unwrap_or_default();
194
195 let scalar_functions = (ffi_ctx.scalar_functions)(&ffi_ctx)
196 .into_iter()
197 .map(|kv_pair| {
198 let udf = <Arc<dyn ScalarUDFImpl>>::from(&kv_pair.1);
199
200 (
201 kv_pair.0.into_string(),
202 Arc::new(ScalarUDF::new_from_shared_impl(udf)),
203 )
204 })
205 .collect();
206 let aggregate_functions = (ffi_ctx.aggregate_functions)(&ffi_ctx)
207 .into_iter()
208 .map(|kv_pair| {
209 let udaf = <Arc<dyn AggregateUDFImpl>>::from(&kv_pair.1);
210
211 (
212 kv_pair.0.into_string(),
213 Arc::new(AggregateUDF::new_from_shared_impl(udaf)),
214 )
215 })
216 .collect();
217 let window_functions = (ffi_ctx.window_functions)(&ffi_ctx)
218 .into_iter()
219 .map(|kv_pair| {
220 let udwf = <Arc<dyn WindowUDFImpl>>::from(&kv_pair.1);
221
222 (
223 kv_pair.0.into_string(),
224 Arc::new(WindowUDF::new_from_shared_impl(udwf)),
225 )
226 })
227 .collect();
228
229 let runtime = Arc::new(RuntimeEnv::default());
230
231 Arc::new(TaskContext::new(
232 task_id,
233 session_id,
234 session_config,
235 scalar_functions,
236 aggregate_functions,
237 window_functions,
238 runtime,
239 ))
240 }
241 }
242}
243
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::prelude::SessionContext;
    use datafusion_common::Result;
    use datafusion_execution::TaskContext;

    use crate::execution::FFI_TaskContext;

    /// Forces the foreign-library path by swapping the marker, then checks the
    /// rebuilt context matches the original.
    #[test]
    fn ffi_task_ctx_round_trip() -> Result<()> {
        let session_ctx = SessionContext::new();
        let original = session_ctx.task_ctx();
        let mut ffi_task_ctx = FFI_TaskContext::from(Arc::clone(&original));
        ffi_task_ctx.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_task_ctx: Arc<TaskContext> = ffi_task_ctx.into();

        assert_eq!(
            original.scalar_functions(),
            foreign_task_ctx.scalar_functions()
        );
        assert_eq!(
            original.aggregate_functions(),
            foreign_task_ctx.aggregate_functions()
        );
        assert_eq!(
            original.window_functions(),
            foreign_task_ctx.window_functions()
        );
        assert_eq!(original.task_id(), foreign_task_ctx.task_id());
        assert_eq!(original.session_id(), foreign_task_ctx.session_id());
        assert_eq!(
            format!("{:?}", original.session_config()),
            format!("{:?}", foreign_task_ctx.session_config())
        );

        Ok(())
    }

    /// When the FFI struct returns to the library that created it, the very
    /// same `Arc` must come back rather than a rebuilt context.
    #[test]
    fn ffi_task_ctx_local_round_trip() {
        let session_ctx = SessionContext::new();
        let original = session_ctx.task_ctx();
        let ffi_task_ctx = FFI_TaskContext::from(Arc::clone(&original));

        let local_task_ctx: Arc<TaskContext> = ffi_task_ctx.into();

        assert!(Arc::ptr_eq(&original, &local_task_ctx));
    }
}