datafusion_execution/
task.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{
19    config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry,
20    runtime_env::RuntimeEnv,
21};
22use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
23use datafusion_expr::planner::ExprPlanner;
24use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF};
25use std::collections::HashSet;
26use std::{collections::HashMap, sync::Arc};
27
28/// Task Execution Context
29///
30/// A [`TaskContext`] contains the state required during a single query's
31/// execution. Please see the documentation on [`SessionContext`] for more
32/// information.
33///
34/// [`SessionContext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html
35#[derive(Debug)]
36pub struct TaskContext {
37    /// Session Id
38    session_id: String,
39    /// Optional Task Identify
40    task_id: Option<String>,
41    /// Session configuration
42    session_config: SessionConfig,
43    /// Scalar functions associated with this task context
44    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
45    /// Aggregate functions associated with this task context
46    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
47    /// Window functions associated with this task context
48    window_functions: HashMap<String, Arc<WindowUDF>>,
49    /// Runtime environment associated with this task context
50    runtime: Arc<RuntimeEnv>,
51}
52
53impl Default for TaskContext {
54    fn default() -> Self {
55        let runtime = Arc::new(RuntimeEnv::default());
56
57        // Create a default task context, mostly useful for testing
58        Self {
59            session_id: "DEFAULT".to_string(),
60            task_id: None,
61            session_config: SessionConfig::new(),
62            scalar_functions: HashMap::new(),
63            aggregate_functions: HashMap::new(),
64            window_functions: HashMap::new(),
65            runtime,
66        }
67    }
68}
69
70impl TaskContext {
71    /// Create a new [`TaskContext`] instance.
72    ///
73    /// Most users will use [`SessionContext::task_ctx`] to create [`TaskContext`]s
74    ///
75    /// [`SessionContext::task_ctx`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.task_ctx
76    pub fn new(
77        task_id: Option<String>,
78        session_id: String,
79        session_config: SessionConfig,
80        scalar_functions: HashMap<String, Arc<ScalarUDF>>,
81        aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
82        window_functions: HashMap<String, Arc<WindowUDF>>,
83        runtime: Arc<RuntimeEnv>,
84    ) -> Self {
85        Self {
86            task_id,
87            session_id,
88            session_config,
89            scalar_functions,
90            aggregate_functions,
91            window_functions,
92            runtime,
93        }
94    }
95
96    /// Return the SessionConfig associated with this [TaskContext]
97    pub fn session_config(&self) -> &SessionConfig {
98        &self.session_config
99    }
100
101    /// Return the `session_id` of this [TaskContext]
102    pub fn session_id(&self) -> String {
103        self.session_id.clone()
104    }
105
106    /// Return the `task_id` of this [TaskContext]
107    pub fn task_id(&self) -> Option<String> {
108        self.task_id.clone()
109    }
110
111    /// Return the [`MemoryPool`] associated with this [TaskContext]
112    pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> {
113        &self.runtime.memory_pool
114    }
115
116    /// Return the [RuntimeEnv] associated with this [TaskContext]
117    pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
118        Arc::clone(&self.runtime)
119    }
120
121    pub fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
122        &self.scalar_functions
123    }
124
125    pub fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
126        &self.aggregate_functions
127    }
128
129    pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
130        &self.window_functions
131    }
132
133    /// Update the [`SessionConfig`]
134    pub fn with_session_config(mut self, session_config: SessionConfig) -> Self {
135        self.session_config = session_config;
136        self
137    }
138
139    /// Update the [`RuntimeEnv`]
140    pub fn with_runtime(mut self, runtime: Arc<RuntimeEnv>) -> Self {
141        self.runtime = runtime;
142        self
143    }
144}
145
146impl FunctionRegistry for TaskContext {
147    fn udfs(&self) -> HashSet<String> {
148        self.scalar_functions.keys().cloned().collect()
149    }
150
151    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
152        let result = self.scalar_functions.get(name);
153
154        result.cloned().ok_or_else(|| {
155            plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext")
156        })
157    }
158
159    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
160        let result = self.aggregate_functions.get(name);
161
162        result.cloned().ok_or_else(|| {
163            plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext")
164        })
165    }
166
167    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
168        let result = self.window_functions.get(name);
169
170        result.cloned().ok_or_else(|| {
171            DataFusionError::Internal(format!(
172                "There is no UDWF named \"{name}\" in the TaskContext"
173            ))
174        })
175    }
176    fn register_udaf(
177        &mut self,
178        udaf: Arc<AggregateUDF>,
179    ) -> Result<Option<Arc<AggregateUDF>>> {
180        udaf.aliases().iter().for_each(|alias| {
181            self.aggregate_functions
182                .insert(alias.clone(), Arc::clone(&udaf));
183        });
184        Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
185    }
186    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
187        udwf.aliases().iter().for_each(|alias| {
188            self.window_functions
189                .insert(alias.clone(), Arc::clone(&udwf));
190        });
191        Ok(self.window_functions.insert(udwf.name().into(), udwf))
192    }
193    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
194        udf.aliases().iter().for_each(|alias| {
195            self.scalar_functions
196                .insert(alias.clone(), Arc::clone(&udf));
197        });
198        Ok(self.scalar_functions.insert(udf.name().into(), udf))
199    }
200
201    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
202        vec![]
203    }
204
205    fn udafs(&self) -> HashSet<String> {
206        self.aggregate_functions.keys().cloned().collect()
207    }
208
209    fn udwfs(&self) -> HashSet<String> {
210        self.window_functions.keys().cloned().collect()
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{
        config::{ConfigExtension, ConfigOptions, Extensions},
        extensions_options,
    };

    extensions_options! {
        struct TestExtension {
            value: usize, default = 42
            option_value: Option<usize>, default = None
        }
    }

    impl ConfigExtension for TestExtension {
        const PREFIX: &'static str = "test";
    }

    /// Build [`ConfigOptions`] with a default [`TestExtension`] registered.
    fn config_with_test_extension() -> ConfigOptions {
        let mut extensions = Extensions::new();
        extensions.insert(TestExtension::default());
        ConfigOptions::new().with_extensions(extensions)
    }

    /// Wrap `config` in a [`TaskContext`] with fixed ids, no registered
    /// functions and a default [`RuntimeEnv`].
    fn task_context_from(config: ConfigOptions) -> TaskContext {
        TaskContext::new(
            Some("task_id".to_string()),
            "session_id".to_string(),
            SessionConfig::from(config),
            HashMap::default(),
            HashMap::default(),
            HashMap::default(),
            Arc::new(RuntimeEnv::default()),
        )
    }

    #[test]
    fn task_context_extensions() -> Result<()> {
        let mut config = config_with_test_extension();
        config.set("test.value", "24")?;
        config.set("test.option_value", "42")?;
        let task_context = task_context_from(config);

        let test = task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be registered in the session config");

        assert_eq!(test.value, 24);
        assert_eq!(test.option_value, Some(42));

        Ok(())
    }

    #[test]
    fn task_context_extensions_default() -> Result<()> {
        let task_context = task_context_from(config_with_test_extension());

        let test = task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be registered in the session config");

        assert_eq!(test.value, 42);
        assert_eq!(test.option_value, None);

        Ok(())
    }

    #[test]
    fn task_context_missing_functions() {
        let task_context = TaskContext::default();

        assert!(task_context.udfs().is_empty());
        assert!(task_context.udafs().is_empty());
        assert!(task_context.udwfs().is_empty());

        assert!(task_context.udf("missing").is_err());
        assert!(task_context.udaf("missing").is_err());
        assert!(task_context.udwf("missing").is_err());
    }
}