datafusion_execution/task.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{
19    config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry,
20    runtime_env::RuntimeEnv,
21};
22use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
23use datafusion_expr::planner::ExprPlanner;
24use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF};
25use std::collections::HashSet;
26use std::{collections::HashMap, sync::Arc};
27
28/// Task Execution Context
29///
30/// A [`TaskContext`] contains the state required during a single query's
31/// execution. Please see the documentation on [`SessionContext`] for more
32/// information.
33///
34/// [`SessionContext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html
35#[derive(Debug)]
36pub struct TaskContext {
37    /// Session Id
38    session_id: String,
39    /// Optional Task Identify
40    task_id: Option<String>,
41    /// Session configuration
42    session_config: SessionConfig,
43    /// Scalar functions associated with this task context
44    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
45    /// Aggregate functions associated with this task context
46    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
47    /// Window functions associated with this task context
48    window_functions: HashMap<String, Arc<WindowUDF>>,
49    /// Runtime environment associated with this task context
50    runtime: Arc<RuntimeEnv>,
51}
52
53impl Default for TaskContext {
54    fn default() -> Self {
55        let runtime = Arc::new(RuntimeEnv::default());
56
57        // Create a default task context, mostly useful for testing
58        Self {
59            session_id: "DEFAULT".to_string(),
60            task_id: None,
61            session_config: SessionConfig::new(),
62            scalar_functions: HashMap::new(),
63            aggregate_functions: HashMap::new(),
64            window_functions: HashMap::new(),
65            runtime,
66        }
67    }
68}
69
70impl TaskContext {
71    /// Create a new [`TaskContext`] instance.
72    ///
73    /// Most users will use [`SessionContext::task_ctx`] to create [`TaskContext`]s
74    ///
75    /// [`SessionContext::task_ctx`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.task_ctx
76    pub fn new(
77        task_id: Option<String>,
78        session_id: String,
79        session_config: SessionConfig,
80        scalar_functions: HashMap<String, Arc<ScalarUDF>>,
81        aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
82        window_functions: HashMap<String, Arc<WindowUDF>>,
83        runtime: Arc<RuntimeEnv>,
84    ) -> Self {
85        Self {
86            task_id,
87            session_id,
88            session_config,
89            scalar_functions,
90            aggregate_functions,
91            window_functions,
92            runtime,
93        }
94    }
95
96    /// Return the SessionConfig associated with this [TaskContext]
97    pub fn session_config(&self) -> &SessionConfig {
98        &self.session_config
99    }
100
101    /// Return the `session_id` of this [TaskContext]
102    pub fn session_id(&self) -> String {
103        self.session_id.clone()
104    }
105
106    /// Return the `task_id` of this [TaskContext]
107    pub fn task_id(&self) -> Option<String> {
108        self.task_id.clone()
109    }
110
111    /// Return the [`MemoryPool`] associated with this [TaskContext]
112    pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> {
113        &self.runtime.memory_pool
114    }
115
116    /// Return the [RuntimeEnv] associated with this [TaskContext]
117    pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
118        Arc::clone(&self.runtime)
119    }
120
121    pub fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
122        &self.scalar_functions
123    }
124
125    pub fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
126        &self.aggregate_functions
127    }
128
129    pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
130        &self.window_functions
131    }
132
133    /// Update the [`SessionConfig`]
134    pub fn with_session_config(mut self, session_config: SessionConfig) -> Self {
135        self.session_config = session_config;
136        self
137    }
138
139    /// Update the [`RuntimeEnv`]
140    pub fn with_runtime(mut self, runtime: Arc<RuntimeEnv>) -> Self {
141        self.runtime = runtime;
142        self
143    }
144}
145
146impl FunctionRegistry for TaskContext {
147    fn udfs(&self) -> HashSet<String> {
148        self.scalar_functions.keys().cloned().collect()
149    }
150
151    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
152        let result = self.scalar_functions.get(name);
153
154        result.cloned().ok_or_else(|| {
155            plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext")
156        })
157    }
158
159    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
160        let result = self.aggregate_functions.get(name);
161
162        result.cloned().ok_or_else(|| {
163            plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext")
164        })
165    }
166
167    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
168        let result = self.window_functions.get(name);
169
170        result.cloned().ok_or_else(|| {
171            DataFusionError::Internal(format!(
172                "There is no UDWF named \"{name}\" in the TaskContext"
173            ))
174        })
175    }
176    fn register_udaf(
177        &mut self,
178        udaf: Arc<AggregateUDF>,
179    ) -> Result<Option<Arc<AggregateUDF>>> {
180        udaf.aliases().iter().for_each(|alias| {
181            self.aggregate_functions
182                .insert(alias.clone(), Arc::clone(&udaf));
183        });
184        Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
185    }
186    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
187        udwf.aliases().iter().for_each(|alias| {
188            self.window_functions
189                .insert(alias.clone(), Arc::clone(&udwf));
190        });
191        Ok(self.window_functions.insert(udwf.name().into(), udwf))
192    }
193    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
194        udf.aliases().iter().for_each(|alias| {
195            self.scalar_functions
196                .insert(alias.clone(), Arc::clone(&udf));
197        });
198        Ok(self.scalar_functions.insert(udf.name().into(), udf))
199    }
200
201    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
202        vec![]
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{
        config::{ConfigExtension, ConfigOptions, Extensions},
        extensions_options,
    };

    extensions_options! {
        struct TestExtension {
            value: usize, default = 42
            option_value: Option<usize>, default = None
        }
    }

    impl ConfigExtension for TestExtension {
        const PREFIX: &'static str = "test";
    }

    /// [`ConfigOptions`] carrying a default-initialised [`TestExtension`]
    /// (`value == 42`, `option_value == None` until overridden).
    fn options_with_test_extension() -> ConfigOptions {
        let mut extensions = Extensions::new();
        extensions.insert(TestExtension::default());
        ConfigOptions::new().with_extensions(extensions)
    }

    /// Wrap `options` in a [`TaskContext`] with fixed task/session ids, empty
    /// function maps and a default [`RuntimeEnv`].
    fn task_context_from(options: ConfigOptions) -> TaskContext {
        TaskContext::new(
            Some("task_id".to_string()),
            "session_id".to_string(),
            SessionConfig::from(options),
            HashMap::default(),
            HashMap::default(),
            HashMap::default(),
            Arc::new(RuntimeEnv::default()),
        )
    }

    #[test]
    fn task_context_extensions() -> Result<()> {
        // Override both extension options through their `test.`-prefixed keys.
        let mut options = options_with_test_extension();
        options.set("test.value", "24")?;
        options.set("test.option_value", "42")?;

        let ctx = task_context_from(options);
        let ext = ctx
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be present in the session config");

        assert_eq!(ext.value, 24);
        assert_eq!(ext.option_value, Some(42));
        Ok(())
    }

    #[test]
    fn task_context_extensions_default() -> Result<()> {
        // Without overrides the extension keeps its declared defaults.
        let ctx = task_context_from(options_with_test_extension());
        let ext = ctx
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be present in the session config");

        assert_eq!(ext.value, 42);
        assert_eq!(ext.option_value, None);
        Ok(())
    }
}