1use crate::{
19 config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry,
20 runtime_env::RuntimeEnv,
21};
22use datafusion_common::{Result, internal_datafusion_err, plan_datafusion_err};
23use datafusion_expr::planner::ExprPlanner;
24use datafusion_expr::{AggregateUDF, HigherOrderUDF, ScalarUDF, WindowUDF};
25use std::collections::HashSet;
26use std::{collections::HashMap, sync::Arc};
27
28#[derive(Debug)]
52pub struct TaskContext {
53 session_id: String,
55 task_id: Option<String>,
57 session_config: SessionConfig,
59 scalar_functions: HashMap<String, Arc<ScalarUDF>>,
61 higher_order_functions: HashMap<String, Arc<HigherOrderUDF>>,
63 aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
65 window_functions: HashMap<String, Arc<WindowUDF>>,
67 runtime: Arc<RuntimeEnv>,
69}
70
71impl Default for TaskContext {
72 fn default() -> Self {
73 let runtime = Arc::new(RuntimeEnv::default());
74
75 Self {
77 session_id: "DEFAULT".to_string(),
78 task_id: None,
79 session_config: SessionConfig::new(),
80 scalar_functions: HashMap::new(),
81 higher_order_functions: HashMap::new(),
82 aggregate_functions: HashMap::new(),
83 window_functions: HashMap::new(),
84 runtime,
85 }
86 }
87}
88
89impl TaskContext {
90 #[expect(clippy::too_many_arguments)]
96 pub fn new(
97 task_id: Option<String>,
98 session_id: String,
99 session_config: SessionConfig,
100 scalar_functions: HashMap<String, Arc<ScalarUDF>>,
101 higher_order_functions: HashMap<String, Arc<HigherOrderUDF>>,
102 aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
103 window_functions: HashMap<String, Arc<WindowUDF>>,
104 runtime: Arc<RuntimeEnv>,
105 ) -> Self {
106 Self {
107 task_id,
108 session_id,
109 session_config,
110 scalar_functions,
111 higher_order_functions,
112 aggregate_functions,
113 window_functions,
114 runtime,
115 }
116 }
117
118 pub fn session_config(&self) -> &SessionConfig {
120 &self.session_config
121 }
122
123 pub fn session_id(&self) -> String {
125 self.session_id.clone()
126 }
127
128 pub fn task_id(&self) -> Option<String> {
130 self.task_id.clone()
131 }
132
133 pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> {
135 &self.runtime.memory_pool
136 }
137
138 pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
140 Arc::clone(&self.runtime)
141 }
142
143 pub fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
144 &self.scalar_functions
145 }
146
147 pub fn higher_order_functions(&self) -> &HashMap<String, Arc<HigherOrderUDF>> {
148 &self.higher_order_functions
149 }
150
151 pub fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
152 &self.aggregate_functions
153 }
154
155 pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
156 &self.window_functions
157 }
158
159 pub fn with_session_config(mut self, session_config: SessionConfig) -> Self {
161 self.session_config = session_config;
162 self
163 }
164
165 pub fn with_runtime(mut self, runtime: Arc<RuntimeEnv>) -> Self {
167 self.runtime = runtime;
168 self
169 }
170}
171
172impl FunctionRegistry for TaskContext {
173 fn udfs(&self) -> HashSet<String> {
174 self.scalar_functions.keys().cloned().collect()
175 }
176
177 fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
178 let result = self.scalar_functions.get(name);
179
180 result.cloned().ok_or_else(|| {
181 plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext")
182 })
183 }
184
185 fn higher_order_function(&self, name: &str) -> Result<Arc<HigherOrderUDF>> {
186 let result = self.higher_order_functions.get(name);
187
188 result.cloned().ok_or_else(|| {
189 plan_datafusion_err!(
190 "There is no higher-order function named \"{name}\" in the TaskContext"
191 )
192 })
193 }
194
195 fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
196 let result = self.aggregate_functions.get(name);
197
198 result.cloned().ok_or_else(|| {
199 plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext")
200 })
201 }
202
203 fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
204 let result = self.window_functions.get(name);
205
206 result.cloned().ok_or_else(|| {
207 internal_datafusion_err!(
208 "There is no UDWF named \"{name}\" in the TaskContext"
209 )
210 })
211 }
212 fn register_udaf(
213 &mut self,
214 udaf: Arc<AggregateUDF>,
215 ) -> Result<Option<Arc<AggregateUDF>>> {
216 udaf.aliases().iter().for_each(|alias| {
217 self.aggregate_functions
218 .insert(alias.clone(), Arc::clone(&udaf));
219 });
220 Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
221 }
222 fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
223 udwf.aliases().iter().for_each(|alias| {
224 self.window_functions
225 .insert(alias.clone(), Arc::clone(&udwf));
226 });
227 Ok(self.window_functions.insert(udwf.name().into(), udwf))
228 }
229 fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
230 udf.aliases().iter().for_each(|alias| {
231 self.scalar_functions
232 .insert(alias.clone(), Arc::clone(&udf));
233 });
234 Ok(self.scalar_functions.insert(udf.name().into(), udf))
235 }
236
237 fn register_higher_order_function(
238 &mut self,
239 function: Arc<HigherOrderUDF>,
240 ) -> Result<Option<Arc<HigherOrderUDF>>> {
241 function.aliases().iter().for_each(|alias| {
242 self.higher_order_functions
243 .insert(alias.clone(), Arc::clone(&function));
244 });
245 Ok(self
246 .higher_order_functions
247 .insert(function.name().into(), function))
248 }
249
250 fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
251 vec![]
252 }
253
254 fn higher_order_function_names(&self) -> HashSet<String> {
255 self.higher_order_functions.keys().cloned().collect()
256 }
257
258 fn udafs(&self) -> HashSet<String> {
259 self.aggregate_functions.keys().cloned().collect()
260 }
261
262 fn udwfs(&self) -> HashSet<String> {
263 self.window_functions.keys().cloned().collect()
264 }
265}
266
267pub trait TaskContextProvider {
269 fn task_ctx(&self) -> Arc<TaskContext>;
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{
        config::{ConfigExtension, ConfigOptions, Extensions},
        extensions_options,
    };

    // Minimal config extension used to check that extension values survive
    // the trip through `SessionConfig` into a `TaskContext`.
    extensions_options! {
        struct TestExtension {
            value: usize, default = 42
            option_value: Option<usize>, default = None
        }
    }

    impl ConfigExtension for TestExtension {
        const PREFIX: &'static str = "test";
    }

    /// `ConfigOptions` carrying a default-initialised `TestExtension`.
    fn options_with_test_extension() -> ConfigOptions {
        let mut extensions = Extensions::new();
        extensions.insert(TestExtension::default());
        ConfigOptions::new().with_extensions(extensions)
    }

    /// A `TaskContext` with empty registries, built on top of `options`.
    fn task_context_from(options: ConfigOptions) -> TaskContext {
        let runtime = Arc::new(RuntimeEnv::default());
        TaskContext::new(
            Some("task_id".to_string()),
            "session_id".to_string(),
            SessionConfig::from(options),
            HashMap::default(),
            HashMap::default(),
            HashMap::default(),
            HashMap::default(),
            runtime,
        )
    }

    /// Values set explicitly on the extension are visible through the context.
    #[test]
    fn task_context_extensions() -> Result<()> {
        let mut options = options_with_test_extension();
        options.set("test.value", "24")?;
        options.set("test.option_value", "42")?;

        let task_context = task_context_from(options);
        let extension = task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be present in the task context");

        assert_eq!(extension.value, 24);
        assert_eq!(extension.option_value, Some(42));

        Ok(())
    }

    /// Without explicit settings the extension keeps its declared defaults.
    #[test]
    fn task_context_extensions_default() -> Result<()> {
        let task_context = task_context_from(options_with_test_extension());
        let extension = task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be present in the task context");

        assert_eq!(extension.value, 42);
        assert_eq!(extension.option_value, None);

        Ok(())
    }
}