Skip to main content

datafusion_ffi/execution/
task_ctx.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::ffi::c_void;
20use std::sync::Arc;
21
22use datafusion_execution::TaskContext;
23use datafusion_execution::config::SessionConfig;
24use datafusion_execution::runtime_env::RuntimeEnv;
25use datafusion_expr::{
26    AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl,
27};
28
29use stabby::string::String as SString;
30use stabby::vec::Vec as SVec;
31
32use crate::session::config::FFI_SessionConfig;
33use crate::udaf::FFI_AggregateUDF;
34use crate::udf::FFI_ScalarUDF;
35use crate::udwf::FFI_WindowUDF;
36use crate::util::FFI_Option;
37
/// A stable struct for sharing [`TaskContext`] across FFI boundaries.
///
/// Build one with `From<Arc<TaskContext>>` and convert it back with
/// `From<FFI_TaskContext> for Arc<TaskContext>`. Dropping this struct invokes
/// the `release` callback.
#[repr(C)]
#[derive(Debug)]
pub struct FFI_TaskContext {
    /// Return the session ID.
    pub session_id: unsafe extern "C" fn(&Self) -> SString,

    /// Return the task ID, if the task context has one.
    pub task_id: unsafe extern "C" fn(&Self) -> FFI_Option<SString>,

    /// Return the session configuration.
    pub session_config: unsafe extern "C" fn(&Self) -> FFI_SessionConfig,

    /// Returns a vec of name-function pairs for scalar functions.
    pub scalar_functions: unsafe extern "C" fn(&Self) -> SVec<(SString, FFI_ScalarUDF)>,

    /// Returns a vec of name-function pairs for aggregate functions.
    pub aggregate_functions:
        unsafe extern "C" fn(&Self) -> SVec<(SString, FFI_AggregateUDF)>,

    /// Returns a vec of name-function pairs for window functions.
    pub window_functions: unsafe extern "C" fn(&Self) -> SVec<(SString, FFI_WindowUDF)>,

    /// Release the memory of the private data when it is no longer being used.
    /// Called from this struct's `Drop` implementation.
    pub release: unsafe extern "C" fn(arg: &mut Self),

    /// Internal data. This is only to be accessed by the library that created
    /// this task context. The foreign library should never attempt to access
    /// this data.
    pub private_data: *mut c_void,

    /// Utility to identify when FFI objects are accessed locally through
    /// the foreign interface. See [`crate::get_library_marker_id`] and
    /// the crate's `README.md` for more information.
    pub library_marker_id: extern "C" fn() -> usize,
}
73
/// Heap-allocated state that `FFI_TaskContext::private_data` points to.
///
/// Created (boxed) in `From<Arc<TaskContext>> for FFI_TaskContext` and freed
/// by `release_fn_wrapper`.
struct TaskContextPrivateData {
    /// The task context being shared across the FFI boundary.
    ctx: Arc<TaskContext>,
}
77
78impl FFI_TaskContext {
79    unsafe fn inner(&self) -> &Arc<TaskContext> {
80        unsafe {
81            let private_data = self.private_data as *const TaskContextPrivateData;
82            &(*private_data).ctx
83        }
84    }
85}
86
87unsafe extern "C" fn session_id_fn_wrapper(ctx: &FFI_TaskContext) -> SString {
88    unsafe {
89        let ctx = ctx.inner();
90        ctx.session_id().into()
91    }
92}
93
94unsafe extern "C" fn task_id_fn_wrapper(ctx: &FFI_TaskContext) -> FFI_Option<SString> {
95    unsafe {
96        let ctx = ctx.inner();
97        ctx.task_id().map(|s| s.as_str().into()).into()
98    }
99}
100
101unsafe extern "C" fn session_config_fn_wrapper(
102    ctx: &FFI_TaskContext,
103) -> FFI_SessionConfig {
104    unsafe {
105        let ctx = ctx.inner();
106        ctx.session_config().into()
107    }
108}
109
110unsafe extern "C" fn scalar_functions_fn_wrapper(
111    ctx: &FFI_TaskContext,
112) -> SVec<(SString, FFI_ScalarUDF)> {
113    unsafe {
114        let ctx = ctx.inner();
115        ctx.scalar_functions()
116            .iter()
117            .map(|(name, udf)| (name.to_owned().into(), Arc::clone(udf).into()))
118            .collect()
119    }
120}
121
122unsafe extern "C" fn aggregate_functions_fn_wrapper(
123    ctx: &FFI_TaskContext,
124) -> SVec<(SString, FFI_AggregateUDF)> {
125    unsafe {
126        let ctx = ctx.inner();
127        ctx.aggregate_functions()
128            .iter()
129            .map(|(name, udaf)| {
130                (
131                    name.to_owned().into(),
132                    FFI_AggregateUDF::from(Arc::clone(udaf)),
133                )
134            })
135            .collect()
136    }
137}
138
139unsafe extern "C" fn window_functions_fn_wrapper(
140    ctx: &FFI_TaskContext,
141) -> SVec<(SString, FFI_WindowUDF)> {
142    unsafe {
143        let ctx = ctx.inner();
144        ctx.window_functions()
145            .iter()
146            .map(|(name, udf)| {
147                (name.to_owned().into(), FFI_WindowUDF::from(Arc::clone(udf)))
148            })
149            .collect()
150    }
151}
152
153unsafe extern "C" fn release_fn_wrapper(ctx: &mut FFI_TaskContext) {
154    unsafe {
155        let private_data = Box::from_raw(ctx.private_data as *mut TaskContextPrivateData);
156        drop(private_data);
157    }
158}
159
160impl Drop for FFI_TaskContext {
161    fn drop(&mut self) {
162        unsafe { (self.release)(self) }
163    }
164}
165
166impl From<Arc<TaskContext>> for FFI_TaskContext {
167    fn from(ctx: Arc<TaskContext>) -> Self {
168        let private_data = Box::new(TaskContextPrivateData { ctx });
169
170        FFI_TaskContext {
171            session_id: session_id_fn_wrapper,
172            task_id: task_id_fn_wrapper,
173            session_config: session_config_fn_wrapper,
174            scalar_functions: scalar_functions_fn_wrapper,
175            aggregate_functions: aggregate_functions_fn_wrapper,
176            window_functions: window_functions_fn_wrapper,
177            release: release_fn_wrapper,
178            private_data: Box::into_raw(private_data) as *mut c_void,
179            library_marker_id: crate::get_library_marker_id,
180        }
181    }
182}
183
184impl From<FFI_TaskContext> for Arc<TaskContext> {
185    fn from(ffi_ctx: FFI_TaskContext) -> Self {
186        unsafe {
187            if (ffi_ctx.library_marker_id)() == crate::get_library_marker_id() {
188                return Arc::clone(ffi_ctx.inner());
189            }
190
191            let task_id = (ffi_ctx.task_id)(&ffi_ctx).map(|s| s.to_string()).into();
192            let session_id = (ffi_ctx.session_id)(&ffi_ctx).into();
193            let session_config = (ffi_ctx.session_config)(&ffi_ctx);
194            let session_config =
195                SessionConfig::try_from(&session_config).unwrap_or_default();
196
197            let scalar_functions = (ffi_ctx.scalar_functions)(&ffi_ctx)
198                .into_iter()
199                .map(|kv_pair| {
200                    let udf = <Arc<dyn ScalarUDFImpl>>::from(&kv_pair.1);
201
202                    (
203                        kv_pair.0.to_string(),
204                        Arc::new(ScalarUDF::new_from_shared_impl(udf)),
205                    )
206                })
207                .collect();
208            let aggregate_functions = (ffi_ctx.aggregate_functions)(&ffi_ctx)
209                .into_iter()
210                .map(|kv_pair| {
211                    let udaf = <Arc<dyn AggregateUDFImpl>>::from(&kv_pair.1);
212
213                    (
214                        kv_pair.0.to_string(),
215                        Arc::new(AggregateUDF::new_from_shared_impl(udaf)),
216                    )
217                })
218                .collect();
219            let window_functions = (ffi_ctx.window_functions)(&ffi_ctx)
220                .into_iter()
221                .map(|kv_pair| {
222                    let udwf = <Arc<dyn WindowUDFImpl>>::from(&kv_pair.1);
223
224                    (
225                        kv_pair.0.to_string(),
226                        Arc::new(WindowUDF::new_from_shared_impl(udwf)),
227                    )
228                })
229                .collect();
230
231            let runtime = Arc::new(RuntimeEnv::default());
232
233            Arc::new(TaskContext::new(
234                task_id,
235                session_id,
236                session_config,
237                scalar_functions,
238                HashMap::new(),
239                aggregate_functions,
240                window_functions,
241                runtime,
242            ))
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::prelude::SessionContext;
    use datafusion_common::Result;
    use datafusion_execution::TaskContext;

    use crate::execution::FFI_TaskContext;

    /// Exercises the foreign path: the marker ID is swapped so the conversion
    /// rebuilds a `TaskContext` through the FFI callbacks.
    #[test]
    fn ffi_task_ctx_round_trip() -> Result<()> {
        let session_ctx = SessionContext::new();
        let original = session_ctx.task_ctx();
        let mut ffi_task_ctx = FFI_TaskContext::from(Arc::clone(&original));
        ffi_task_ctx.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_task_ctx: Arc<TaskContext> = ffi_task_ctx.into();

        // TaskContext doesn't implement Eq (nor should it) so check some of the
        // data is round tripping correctly.

        assert_eq!(
            original.scalar_functions(),
            foreign_task_ctx.scalar_functions()
        );
        assert_eq!(
            original.aggregate_functions(),
            foreign_task_ctx.aggregate_functions()
        );
        assert_eq!(
            original.window_functions(),
            foreign_task_ctx.window_functions()
        );
        assert_eq!(original.task_id(), foreign_task_ctx.task_id());
        assert_eq!(original.session_id(), foreign_task_ctx.session_id());
        assert_eq!(
            format!("{:?}", original.session_config()),
            format!("{:?}", foreign_task_ctx.session_config())
        );

        Ok(())
    }

    /// Exercises the local fast path: converting back inside the library that
    /// created the FFI struct must return the very same `Arc`, not a rebuilt copy.
    #[test]
    fn ffi_task_ctx_local_round_trip_returns_same_arc() {
        let session_ctx = SessionContext::new();
        let original = session_ctx.task_ctx();
        let ffi_task_ctx = FFI_TaskContext::from(Arc::clone(&original));

        let local_task_ctx: Arc<TaskContext> = ffi_task_ctx.into();

        assert!(Arc::ptr_eq(&original, &local_task_ctx));
    }
}