Skip to main content

datafusion_execution/
task.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{
19    config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry,
20    runtime_env::RuntimeEnv,
21};
22use datafusion_common::{Result, internal_datafusion_err, plan_datafusion_err};
23use datafusion_expr::planner::ExprPlanner;
24use datafusion_expr::{AggregateUDF, HigherOrderUDF, ScalarUDF, WindowUDF};
25use std::collections::HashSet;
26use std::{collections::HashMap, sync::Arc};
27
28/// Task Execution Context
29///
30/// A [`TaskContext`] contains the state required during a single query's
31/// execution. Please see the documentation on [`SessionContext`] for more
32/// information.
33///
34/// # Relationship with [`ExecutionProps`]
35///
36/// [`TaskContext`] is intentionally distinct from [`ExecutionProps`].
37/// [`ExecutionProps`] is state used while optimizing a logical
38/// plan and constructing a physical plan.
39///
40/// [`TaskContext`] is the runtime context passed to physical operators when
41/// executing a physical plan. It carries runtime services and session state
42/// needed at that stage, such as [`RuntimeEnv`], memory-pool access, session
43/// configuration, and function lookup.
44///
45/// Keeping these structures separate avoids threading execution/runtime state
46/// through planning APIs, and avoids making execution depend on planner-only
47/// scratch state.
48///
49/// [`SessionContext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html
50/// [`ExecutionProps`]: datafusion_expr::execution_props::ExecutionProps
51#[derive(Debug)]
52pub struct TaskContext {
53    /// Session Id
54    session_id: String,
55    /// Optional Task Identify
56    task_id: Option<String>,
57    /// Session configuration
58    session_config: SessionConfig,
59    /// Scalar functions associated with this task context
60    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
61    /// Higher order functions associated with this task context
62    higher_order_functions: HashMap<String, Arc<HigherOrderUDF>>,
63    /// Aggregate functions associated with this task context
64    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
65    /// Window functions associated with this task context
66    window_functions: HashMap<String, Arc<WindowUDF>>,
67    /// Runtime environment associated with this task context
68    runtime: Arc<RuntimeEnv>,
69}
70
71impl Default for TaskContext {
72    fn default() -> Self {
73        let runtime = Arc::new(RuntimeEnv::default());
74
75        // Create a default task context, mostly useful for testing
76        Self {
77            session_id: "DEFAULT".to_string(),
78            task_id: None,
79            session_config: SessionConfig::new(),
80            scalar_functions: HashMap::new(),
81            higher_order_functions: HashMap::new(),
82            aggregate_functions: HashMap::new(),
83            window_functions: HashMap::new(),
84            runtime,
85        }
86    }
87}
88
89impl TaskContext {
90    /// Create a new [`TaskContext`] instance.
91    ///
92    /// Most users will use [`SessionContext::task_ctx`] to create [`TaskContext`]s
93    ///
94    /// [`SessionContext::task_ctx`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.task_ctx
95    #[expect(clippy::too_many_arguments)]
96    pub fn new(
97        task_id: Option<String>,
98        session_id: String,
99        session_config: SessionConfig,
100        scalar_functions: HashMap<String, Arc<ScalarUDF>>,
101        higher_order_functions: HashMap<String, Arc<HigherOrderUDF>>,
102        aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
103        window_functions: HashMap<String, Arc<WindowUDF>>,
104        runtime: Arc<RuntimeEnv>,
105    ) -> Self {
106        Self {
107            task_id,
108            session_id,
109            session_config,
110            scalar_functions,
111            higher_order_functions,
112            aggregate_functions,
113            window_functions,
114            runtime,
115        }
116    }
117
118    /// Return the SessionConfig associated with this [TaskContext]
119    pub fn session_config(&self) -> &SessionConfig {
120        &self.session_config
121    }
122
123    /// Return the `session_id` of this [TaskContext]
124    pub fn session_id(&self) -> String {
125        self.session_id.clone()
126    }
127
128    /// Return the `task_id` of this [TaskContext]
129    pub fn task_id(&self) -> Option<String> {
130        self.task_id.clone()
131    }
132
133    /// Return the [`MemoryPool`] associated with this [TaskContext]
134    pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> {
135        &self.runtime.memory_pool
136    }
137
138    /// Return the [RuntimeEnv] associated with this [TaskContext]
139    pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
140        Arc::clone(&self.runtime)
141    }
142
143    pub fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
144        &self.scalar_functions
145    }
146
147    pub fn higher_order_functions(&self) -> &HashMap<String, Arc<HigherOrderUDF>> {
148        &self.higher_order_functions
149    }
150
151    pub fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
152        &self.aggregate_functions
153    }
154
155    pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
156        &self.window_functions
157    }
158
159    /// Update the [`SessionConfig`]
160    pub fn with_session_config(mut self, session_config: SessionConfig) -> Self {
161        self.session_config = session_config;
162        self
163    }
164
165    /// Update the [`RuntimeEnv`]
166    pub fn with_runtime(mut self, runtime: Arc<RuntimeEnv>) -> Self {
167        self.runtime = runtime;
168        self
169    }
170}
171
172impl FunctionRegistry for TaskContext {
173    fn udfs(&self) -> HashSet<String> {
174        self.scalar_functions.keys().cloned().collect()
175    }
176
177    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
178        let result = self.scalar_functions.get(name);
179
180        result.cloned().ok_or_else(|| {
181            plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext")
182        })
183    }
184
185    fn higher_order_function(&self, name: &str) -> Result<Arc<HigherOrderUDF>> {
186        let result = self.higher_order_functions.get(name);
187
188        result.cloned().ok_or_else(|| {
189            plan_datafusion_err!(
190                "There is no higher-order function named \"{name}\" in the TaskContext"
191            )
192        })
193    }
194
195    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
196        let result = self.aggregate_functions.get(name);
197
198        result.cloned().ok_or_else(|| {
199            plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext")
200        })
201    }
202
203    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
204        let result = self.window_functions.get(name);
205
206        result.cloned().ok_or_else(|| {
207            internal_datafusion_err!(
208                "There is no UDWF named \"{name}\" in the TaskContext"
209            )
210        })
211    }
212    fn register_udaf(
213        &mut self,
214        udaf: Arc<AggregateUDF>,
215    ) -> Result<Option<Arc<AggregateUDF>>> {
216        udaf.aliases().iter().for_each(|alias| {
217            self.aggregate_functions
218                .insert(alias.clone(), Arc::clone(&udaf));
219        });
220        Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
221    }
222    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
223        udwf.aliases().iter().for_each(|alias| {
224            self.window_functions
225                .insert(alias.clone(), Arc::clone(&udwf));
226        });
227        Ok(self.window_functions.insert(udwf.name().into(), udwf))
228    }
229    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
230        udf.aliases().iter().for_each(|alias| {
231            self.scalar_functions
232                .insert(alias.clone(), Arc::clone(&udf));
233        });
234        Ok(self.scalar_functions.insert(udf.name().into(), udf))
235    }
236
237    fn register_higher_order_function(
238        &mut self,
239        function: Arc<HigherOrderUDF>,
240    ) -> Result<Option<Arc<HigherOrderUDF>>> {
241        function.aliases().iter().for_each(|alias| {
242            self.higher_order_functions
243                .insert(alias.clone(), Arc::clone(&function));
244        });
245        Ok(self
246            .higher_order_functions
247            .insert(function.name().into(), function))
248    }
249
250    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
251        vec![]
252    }
253
254    fn higher_order_function_names(&self) -> HashSet<String> {
255        self.higher_order_functions.keys().cloned().collect()
256    }
257
258    fn udafs(&self) -> HashSet<String> {
259        self.aggregate_functions.keys().cloned().collect()
260    }
261
262    fn udwfs(&self) -> HashSet<String> {
263        self.window_functions.keys().cloned().collect()
264    }
265}
266
267/// Produce the [`TaskContext`].
268pub trait TaskContextProvider {
269    fn task_ctx(&self) -> Arc<TaskContext>;
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{
        config::{ConfigExtension, ConfigOptions, Extensions},
        extensions_options,
    };

    extensions_options! {
        struct TestExtension {
            value: usize, default = 42
            option_value: Option<usize>, default = None
        }
    }

    impl ConfigExtension for TestExtension {
        const PREFIX: &'static str = "test";
    }

    /// Builds a [`TaskContext`] around `session_config` with empty function
    /// maps and a default [`RuntimeEnv`].
    fn task_context_with_config(session_config: SessionConfig) -> TaskContext {
        TaskContext::new(
            Some("task_id".to_string()),
            "session_id".to_string(),
            session_config,
            HashMap::default(),
            HashMap::default(),
            HashMap::default(),
            HashMap::default(),
            Arc::new(RuntimeEnv::default()),
        )
    }

    #[test]
    fn task_context_extensions() -> Result<()> {
        let mut extensions = Extensions::new();
        extensions.insert(TestExtension::default());

        let mut config = ConfigOptions::new().with_extensions(extensions);
        config.set("test.value", "24")?;
        config.set("test.option_value", "42")?;

        let task_context = task_context_with_config(SessionConfig::from(config));

        let test = task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be registered in the session config");

        assert_eq!(test.value, 24);
        assert_eq!(test.option_value, Some(42));

        Ok(())
    }

    #[test]
    fn task_context_extensions_default() -> Result<()> {
        let mut extensions = Extensions::new();
        extensions.insert(TestExtension::default());

        let config = ConfigOptions::new().with_extensions(extensions);

        let task_context = task_context_with_config(SessionConfig::from(config));

        let test = task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be registered in the session config");

        assert_eq!(test.value, 42);
        assert_eq!(test.option_value, None);

        Ok(())
    }

    #[test]
    fn task_context_unknown_functions_are_errors() {
        let task_context = TaskContext::default();

        assert!(task_context.udfs().is_empty());
        assert!(task_context.higher_order_function_names().is_empty());
        assert!(task_context.udafs().is_empty());
        assert!(task_context.udwfs().is_empty());

        assert!(task_context.udf("missing").is_err());
        assert!(task_context.higher_order_function("missing").is_err());
        assert!(task_context.udaf("missing").is_err());
        assert!(task_context.udwf("missing").is_err());
    }
}