1use crate::{
19 config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry,
20 runtime_env::RuntimeEnv,
21};
22use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
23use datafusion_expr::planner::ExprPlanner;
24use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF};
25use std::collections::HashSet;
26use std::{collections::HashMap, sync::Arc};
27
/// Per-task execution context.
///
/// Bundles the identifiers, session configuration, function maps and runtime
/// environment exposed to code running a task. It also implements
/// [`FunctionRegistry`], so scalar, aggregate and window UDFs can be looked
/// up (and registered) by name.
#[derive(Debug)]
pub struct TaskContext {
    /// Session identifier (`"DEFAULT"` for `TaskContext::default()`).
    session_id: String,
    /// Optional task identifier; `None` when not provided.
    task_id: Option<String>,
    /// Session configuration, exposed via `session_config()`.
    session_config: SessionConfig,
    /// Scalar UDFs keyed by name; `register_udf` also inserts one key per alias.
    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
    /// Aggregate UDFs keyed by name; `register_udaf` also inserts one key per alias.
    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
    /// Window UDFs keyed by name; `register_udwf` also inserts one key per alias.
    window_functions: HashMap<String, Arc<WindowUDF>>,
    /// Runtime environment; `memory_pool()` reads its memory pool from here.
    runtime: Arc<RuntimeEnv>,
}
52
53impl Default for TaskContext {
54 fn default() -> Self {
55 let runtime = Arc::new(RuntimeEnv::default());
56
57 Self {
59 session_id: "DEFAULT".to_string(),
60 task_id: None,
61 session_config: SessionConfig::new(),
62 scalar_functions: HashMap::new(),
63 aggregate_functions: HashMap::new(),
64 window_functions: HashMap::new(),
65 runtime,
66 }
67 }
68}
69
70impl TaskContext {
71 pub fn new(
77 task_id: Option<String>,
78 session_id: String,
79 session_config: SessionConfig,
80 scalar_functions: HashMap<String, Arc<ScalarUDF>>,
81 aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
82 window_functions: HashMap<String, Arc<WindowUDF>>,
83 runtime: Arc<RuntimeEnv>,
84 ) -> Self {
85 Self {
86 task_id,
87 session_id,
88 session_config,
89 scalar_functions,
90 aggregate_functions,
91 window_functions,
92 runtime,
93 }
94 }
95
96 pub fn session_config(&self) -> &SessionConfig {
98 &self.session_config
99 }
100
101 pub fn session_id(&self) -> String {
103 self.session_id.clone()
104 }
105
106 pub fn task_id(&self) -> Option<String> {
108 self.task_id.clone()
109 }
110
111 pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> {
113 &self.runtime.memory_pool
114 }
115
116 pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
118 Arc::clone(&self.runtime)
119 }
120
121 pub fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
122 &self.scalar_functions
123 }
124
125 pub fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
126 &self.aggregate_functions
127 }
128
129 pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
130 &self.window_functions
131 }
132
133 pub fn with_session_config(mut self, session_config: SessionConfig) -> Self {
135 self.session_config = session_config;
136 self
137 }
138
139 pub fn with_runtime(mut self, runtime: Arc<RuntimeEnv>) -> Self {
141 self.runtime = runtime;
142 self
143 }
144}
145
146impl FunctionRegistry for TaskContext {
147 fn udfs(&self) -> HashSet<String> {
148 self.scalar_functions.keys().cloned().collect()
149 }
150
151 fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
152 let result = self.scalar_functions.get(name);
153
154 result.cloned().ok_or_else(|| {
155 plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext")
156 })
157 }
158
159 fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
160 let result = self.aggregate_functions.get(name);
161
162 result.cloned().ok_or_else(|| {
163 plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext")
164 })
165 }
166
167 fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
168 let result = self.window_functions.get(name);
169
170 result.cloned().ok_or_else(|| {
171 DataFusionError::Internal(format!(
172 "There is no UDWF named \"{name}\" in the TaskContext"
173 ))
174 })
175 }
176 fn register_udaf(
177 &mut self,
178 udaf: Arc<AggregateUDF>,
179 ) -> Result<Option<Arc<AggregateUDF>>> {
180 udaf.aliases().iter().for_each(|alias| {
181 self.aggregate_functions
182 .insert(alias.clone(), Arc::clone(&udaf));
183 });
184 Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
185 }
186 fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
187 udwf.aliases().iter().for_each(|alias| {
188 self.window_functions
189 .insert(alias.clone(), Arc::clone(&udwf));
190 });
191 Ok(self.window_functions.insert(udwf.name().into(), udwf))
192 }
193 fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
194 udf.aliases().iter().for_each(|alias| {
195 self.scalar_functions
196 .insert(alias.clone(), Arc::clone(&udf));
197 });
198 Ok(self.scalar_functions.insert(udf.name().into(), udf))
199 }
200
201 fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
202 vec![]
203 }
204
205 fn udafs(&self) -> HashSet<String> {
206 self.aggregate_functions.keys().cloned().collect()
207 }
208
209 fn udwfs(&self) -> HashSet<String> {
210 self.window_functions.keys().cloned().collect()
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{
        config::{ConfigExtension, ConfigOptions, Extensions},
        extensions_options,
    };

    extensions_options! {
        struct TestExtension {
            value: usize, default = 42
            option_value: Option<usize>, default = None
        }
    }

    impl ConfigExtension for TestExtension {
        const PREFIX: &'static str = "test";
    }

    /// Builds a `TaskContext` whose session config has `TestExtension`
    /// registered, after applying each `(key, value)` pair in `overrides`
    /// through `ConfigOptions::set`.
    fn task_context_with_overrides(overrides: &[(&str, &str)]) -> Result<TaskContext> {
        let mut extensions = Extensions::new();
        extensions.insert(TestExtension::default());

        let mut config = ConfigOptions::new().with_extensions(extensions);
        for &(key, value) in overrides {
            config.set(key, value)?;
        }

        Ok(TaskContext::new(
            Some("task_id".to_string()),
            "session_id".to_string(),
            SessionConfig::from(config),
            HashMap::default(),
            HashMap::default(),
            HashMap::default(),
            Arc::new(RuntimeEnv::default()),
        ))
    }

    /// Fetches the `TestExtension` from `task_context`'s session config,
    /// failing the test if it is not registered.
    fn test_extension(task_context: &TaskContext) -> &TestExtension {
        task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be registered in the session config")
    }

    #[test]
    fn task_context_extensions() -> Result<()> {
        let task_context = task_context_with_overrides(&[
            ("test.value", "24"),
            ("test.option_value", "42"),
        ])?;

        let test = test_extension(&task_context);
        assert_eq!(test.value, 24);
        assert_eq!(test.option_value, Some(42));

        Ok(())
    }

    #[test]
    fn task_context_extensions_default() -> Result<()> {
        let task_context = task_context_with_overrides(&[])?;

        let test = test_extension(&task_context);
        assert_eq!(test.value, 42);
        assert_eq!(test.option_value, None);

        Ok(())
    }
}