datafusion_ffi/execution/
task_ctx.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::c_void;
19use std::sync::Arc;
20
21use abi_stable::StableAbi;
22use abi_stable::pmr::ROption;
23use abi_stable::std_types::{RHashMap, RString};
24use datafusion_execution::TaskContext;
25use datafusion_execution::config::SessionConfig;
26use datafusion_execution::runtime_env::RuntimeEnv;
27use datafusion_expr::{
28    AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl,
29};
30
31use crate::session::config::FFI_SessionConfig;
32use crate::udaf::FFI_AggregateUDF;
33use crate::udf::FFI_ScalarUDF;
34use crate::udwf::FFI_WindowUDF;
35
/// A stable struct for sharing [`TaskContext`] across FFI boundaries.
///
/// The struct is a table of `extern "C"` accessor functions plus an opaque
/// pointer to data owned by the library that created it. Instances are built
/// with `From<Arc<TaskContext>>` and turned back into a native context with
/// `From<FFI_TaskContext> for Arc<TaskContext>`. Dropping the struct calls
/// `release` so the creating library can free `private_data`.
///
/// The layout is `#[repr(C)]` and part of the ABI: do not reorder fields.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_TaskContext {
    /// Return the session ID.
    pub session_id: unsafe extern "C" fn(&Self) -> RString,

    /// Return the task ID, if one is set.
    pub task_id: unsafe extern "C" fn(&Self) -> ROption<RString>,

    /// Return the session configuration.
    pub session_config: unsafe extern "C" fn(&Self) -> FFI_SessionConfig,

    /// Returns a hashmap of names to scalar functions.
    pub scalar_functions: unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_ScalarUDF>,

    /// Returns a hashmap of names to aggregate functions.
    pub aggregate_functions:
        unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_AggregateUDF>,

    /// Returns a hashmap of names to window functions.
    pub window_functions: unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_WindowUDF>,

    /// Release the memory of the private data when it is no longer being used.
    /// Invoked from this struct's `Drop` implementation.
    pub release: unsafe extern "C" fn(arg: &mut Self),

    /// Internal data. This is only to be accessed by the library that created
    /// this task context. The foreign library should never attempt to access
    /// this data.
    pub private_data: *mut c_void,

    /// Utility to identify when FFI objects are accessed locally through
    /// the foreign interface. See [`crate::get_library_marker_id`] and
    /// the crate's `README.md` for more information.
    pub library_marker_id: extern "C" fn() -> usize,
}
71
/// Library-local state behind [`FFI_TaskContext::private_data`].
///
/// A boxed instance is created when a native `Arc<TaskContext>` is wrapped for
/// FFI. `release_fn_wrapper` frees it again.
struct TaskContextPrivateData {
    /// The wrapped native task context.
    ctx: Arc<TaskContext>,
}
75
76impl FFI_TaskContext {
77    unsafe fn inner(&self) -> &Arc<TaskContext> {
78        unsafe {
79            let private_data = self.private_data as *const TaskContextPrivateData;
80            &(*private_data).ctx
81        }
82    }
83}
84
85unsafe extern "C" fn session_id_fn_wrapper(ctx: &FFI_TaskContext) -> RString {
86    unsafe {
87        let ctx = ctx.inner();
88        ctx.session_id().into()
89    }
90}
91
92unsafe extern "C" fn task_id_fn_wrapper(ctx: &FFI_TaskContext) -> ROption<RString> {
93    unsafe {
94        let ctx = ctx.inner();
95        ctx.task_id().map(|s| s.as_str().into()).into()
96    }
97}
98
99unsafe extern "C" fn session_config_fn_wrapper(
100    ctx: &FFI_TaskContext,
101) -> FFI_SessionConfig {
102    unsafe {
103        let ctx = ctx.inner();
104        ctx.session_config().into()
105    }
106}
107
108unsafe extern "C" fn scalar_functions_fn_wrapper(
109    ctx: &FFI_TaskContext,
110) -> RHashMap<RString, FFI_ScalarUDF> {
111    unsafe {
112        let ctx = ctx.inner();
113        ctx.scalar_functions()
114            .iter()
115            .map(|(name, udf)| (name.to_owned().into(), Arc::clone(udf).into()))
116            .collect()
117    }
118}
119
120unsafe extern "C" fn aggregate_functions_fn_wrapper(
121    ctx: &FFI_TaskContext,
122) -> RHashMap<RString, FFI_AggregateUDF> {
123    unsafe {
124        let ctx = ctx.inner();
125        ctx.aggregate_functions()
126            .iter()
127            .map(|(name, udaf)| {
128                (
129                    name.to_owned().into(),
130                    FFI_AggregateUDF::from(Arc::clone(udaf)),
131                )
132            })
133            .collect()
134    }
135}
136
137unsafe extern "C" fn window_functions_fn_wrapper(
138    ctx: &FFI_TaskContext,
139) -> RHashMap<RString, FFI_WindowUDF> {
140    unsafe {
141        let ctx = ctx.inner();
142        ctx.window_functions()
143            .iter()
144            .map(|(name, udf)| {
145                (name.to_owned().into(), FFI_WindowUDF::from(Arc::clone(udf)))
146            })
147            .collect()
148    }
149}
150
151unsafe extern "C" fn release_fn_wrapper(ctx: &mut FFI_TaskContext) {
152    unsafe {
153        let private_data = Box::from_raw(ctx.private_data as *mut TaskContextPrivateData);
154        drop(private_data);
155    }
156}
157
158impl Drop for FFI_TaskContext {
159    fn drop(&mut self) {
160        unsafe { (self.release)(self) }
161    }
162}
163
164impl From<Arc<TaskContext>> for FFI_TaskContext {
165    fn from(ctx: Arc<TaskContext>) -> Self {
166        let private_data = Box::new(TaskContextPrivateData { ctx });
167
168        FFI_TaskContext {
169            session_id: session_id_fn_wrapper,
170            task_id: task_id_fn_wrapper,
171            session_config: session_config_fn_wrapper,
172            scalar_functions: scalar_functions_fn_wrapper,
173            aggregate_functions: aggregate_functions_fn_wrapper,
174            window_functions: window_functions_fn_wrapper,
175            release: release_fn_wrapper,
176            private_data: Box::into_raw(private_data) as *mut c_void,
177            library_marker_id: crate::get_library_marker_id,
178        }
179    }
180}
181
182impl From<FFI_TaskContext> for Arc<TaskContext> {
183    fn from(ffi_ctx: FFI_TaskContext) -> Self {
184        unsafe {
185            if (ffi_ctx.library_marker_id)() == crate::get_library_marker_id() {
186                return Arc::clone(ffi_ctx.inner());
187            }
188
189            let task_id = (ffi_ctx.task_id)(&ffi_ctx).map(|s| s.to_string()).into();
190            let session_id = (ffi_ctx.session_id)(&ffi_ctx).into();
191            let session_config = (ffi_ctx.session_config)(&ffi_ctx);
192            let session_config =
193                SessionConfig::try_from(&session_config).unwrap_or_default();
194
195            let scalar_functions = (ffi_ctx.scalar_functions)(&ffi_ctx)
196                .into_iter()
197                .map(|kv_pair| {
198                    let udf = <Arc<dyn ScalarUDFImpl>>::from(&kv_pair.1);
199
200                    (
201                        kv_pair.0.into_string(),
202                        Arc::new(ScalarUDF::new_from_shared_impl(udf)),
203                    )
204                })
205                .collect();
206            let aggregate_functions = (ffi_ctx.aggregate_functions)(&ffi_ctx)
207                .into_iter()
208                .map(|kv_pair| {
209                    let udaf = <Arc<dyn AggregateUDFImpl>>::from(&kv_pair.1);
210
211                    (
212                        kv_pair.0.into_string(),
213                        Arc::new(AggregateUDF::new_from_shared_impl(udaf)),
214                    )
215                })
216                .collect();
217            let window_functions = (ffi_ctx.window_functions)(&ffi_ctx)
218                .into_iter()
219                .map(|kv_pair| {
220                    let udwf = <Arc<dyn WindowUDFImpl>>::from(&kv_pair.1);
221
222                    (
223                        kv_pair.0.into_string(),
224                        Arc::new(WindowUDF::new_from_shared_impl(udwf)),
225                    )
226                })
227                .collect();
228
229            let runtime = Arc::new(RuntimeEnv::default());
230
231            Arc::new(TaskContext::new(
232                task_id,
233                session_id,
234                session_config,
235                scalar_functions,
236                aggregate_functions,
237                window_functions,
238                runtime,
239            ))
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::prelude::SessionContext;
    use datafusion_common::Result;
    use datafusion_execution::TaskContext;

    use crate::execution::FFI_TaskContext;

    /// Foreign round trip: the library marker is overridden, so converting back
    /// rebuilds the context through the FFI accessor functions.
    #[test]
    fn ffi_task_ctx_round_trip() -> Result<()> {
        let session_ctx = SessionContext::new();
        let original = session_ctx.task_ctx();
        let mut ffi_task_ctx = FFI_TaskContext::from(Arc::clone(&original));
        ffi_task_ctx.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_task_ctx: Arc<TaskContext> = ffi_task_ctx.into();

        // TaskContext doesn't implement Eq (nor should it) so check some of the
        // data is round tripping correctly.

        assert_eq!(
            original.scalar_functions(),
            foreign_task_ctx.scalar_functions()
        );
        assert_eq!(
            original.aggregate_functions(),
            foreign_task_ctx.aggregate_functions()
        );
        assert_eq!(
            original.window_functions(),
            foreign_task_ctx.window_functions()
        );
        assert_eq!(original.task_id(), foreign_task_ctx.task_id());
        assert_eq!(original.session_id(), foreign_task_ctx.session_id());
        assert_eq!(
            format!("{:?}", original.session_config()),
            format!("{:?}", foreign_task_ctx.session_config())
        );

        Ok(())
    }

    /// Local round trip: with the real library marker, converting back must
    /// short-circuit and return the very same `TaskContext` allocation.
    #[test]
    fn ffi_task_ctx_local_round_trip() -> Result<()> {
        let session_ctx = SessionContext::new();
        let original = session_ctx.task_ctx();
        let ffi_task_ctx = FFI_TaskContext::from(Arc::clone(&original));

        let local_task_ctx: Arc<TaskContext> = ffi_task_ctx.into();

        assert!(Arc::ptr_eq(&original, &local_task_ctx));

        Ok(())
    }
}