1use crate::{
19 config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry,
20 runtime_env::RuntimeEnv,
21};
22use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
23use datafusion_expr::planner::ExprPlanner;
24use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF};
25use std::collections::HashSet;
26use std::{collections::HashMap, sync::Arc};
27
28#[derive(Debug)]
36pub struct TaskContext {
37 session_id: String,
39 task_id: Option<String>,
41 session_config: SessionConfig,
43 scalar_functions: HashMap<String, Arc<ScalarUDF>>,
45 aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
47 window_functions: HashMap<String, Arc<WindowUDF>>,
49 runtime: Arc<RuntimeEnv>,
51}
52
53impl Default for TaskContext {
54 fn default() -> Self {
55 let runtime = Arc::new(RuntimeEnv::default());
56
57 Self {
59 session_id: "DEFAULT".to_string(),
60 task_id: None,
61 session_config: SessionConfig::new(),
62 scalar_functions: HashMap::new(),
63 aggregate_functions: HashMap::new(),
64 window_functions: HashMap::new(),
65 runtime,
66 }
67 }
68}
69
70impl TaskContext {
71 pub fn new(
77 task_id: Option<String>,
78 session_id: String,
79 session_config: SessionConfig,
80 scalar_functions: HashMap<String, Arc<ScalarUDF>>,
81 aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
82 window_functions: HashMap<String, Arc<WindowUDF>>,
83 runtime: Arc<RuntimeEnv>,
84 ) -> Self {
85 Self {
86 task_id,
87 session_id,
88 session_config,
89 scalar_functions,
90 aggregate_functions,
91 window_functions,
92 runtime,
93 }
94 }
95
96 pub fn session_config(&self) -> &SessionConfig {
98 &self.session_config
99 }
100
101 pub fn session_id(&self) -> String {
103 self.session_id.clone()
104 }
105
106 pub fn task_id(&self) -> Option<String> {
108 self.task_id.clone()
109 }
110
111 pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> {
113 &self.runtime.memory_pool
114 }
115
116 pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
118 Arc::clone(&self.runtime)
119 }
120
121 pub fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
122 &self.scalar_functions
123 }
124
125 pub fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
126 &self.aggregate_functions
127 }
128
129 pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
130 &self.window_functions
131 }
132
133 pub fn with_session_config(mut self, session_config: SessionConfig) -> Self {
135 self.session_config = session_config;
136 self
137 }
138
139 pub fn with_runtime(mut self, runtime: Arc<RuntimeEnv>) -> Self {
141 self.runtime = runtime;
142 self
143 }
144}
145
146impl FunctionRegistry for TaskContext {
147 fn udfs(&self) -> HashSet<String> {
148 self.scalar_functions.keys().cloned().collect()
149 }
150
151 fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
152 let result = self.scalar_functions.get(name);
153
154 result.cloned().ok_or_else(|| {
155 plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext")
156 })
157 }
158
159 fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
160 let result = self.aggregate_functions.get(name);
161
162 result.cloned().ok_or_else(|| {
163 plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext")
164 })
165 }
166
167 fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
168 let result = self.window_functions.get(name);
169
170 result.cloned().ok_or_else(|| {
171 DataFusionError::Internal(format!(
172 "There is no UDWF named \"{name}\" in the TaskContext"
173 ))
174 })
175 }
176 fn register_udaf(
177 &mut self,
178 udaf: Arc<AggregateUDF>,
179 ) -> Result<Option<Arc<AggregateUDF>>> {
180 udaf.aliases().iter().for_each(|alias| {
181 self.aggregate_functions
182 .insert(alias.clone(), Arc::clone(&udaf));
183 });
184 Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
185 }
186 fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
187 udwf.aliases().iter().for_each(|alias| {
188 self.window_functions
189 .insert(alias.clone(), Arc::clone(&udwf));
190 });
191 Ok(self.window_functions.insert(udwf.name().into(), udwf))
192 }
193 fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
194 udf.aliases().iter().for_each(|alias| {
195 self.scalar_functions
196 .insert(alias.clone(), Arc::clone(&udf));
197 });
198 Ok(self.scalar_functions.insert(udf.name().into(), udf))
199 }
200
201 fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
202 vec![]
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{
        config::{ConfigExtension, ConfigOptions, Extensions},
        extensions_options,
    };

    extensions_options! {
        struct TestExtension {
            value: usize, default = 42
            option_value: Option<usize>, default = None
        }
    }

    impl ConfigExtension for TestExtension {
        const PREFIX: &'static str = "test";
    }

    /// Returns `ConfigOptions` with a default-valued `TestExtension` registered.
    fn options_with_test_extension() -> ConfigOptions {
        let mut extensions = Extensions::new();
        extensions.insert(TestExtension::default());
        ConfigOptions::new().with_extensions(extensions)
    }

    /// Wraps `session_config` in a `TaskContext` with fixed ids, no
    /// registered functions, and a default runtime.
    fn task_context_with(session_config: SessionConfig) -> TaskContext {
        TaskContext::new(
            Some("task_id".to_string()),
            "session_id".to_string(),
            session_config,
            HashMap::default(),
            HashMap::default(),
            HashMap::default(),
            Arc::new(RuntimeEnv::default()),
        )
    }

    /// Reads `TestExtension` back out of the task's session config, failing
    /// the test if it was not carried through.
    fn test_extension(task_context: &TaskContext) -> &TestExtension {
        task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be registered in the session config")
    }

    #[test]
    fn task_context_extensions() -> Result<()> {
        let mut config = options_with_test_extension();
        config.set("test.value", "24")?;
        config.set("test.option_value", "42")?;
        let task_context = task_context_with(SessionConfig::from(config));

        let test = test_extension(&task_context);
        assert_eq!(test.value, 24);
        assert_eq!(test.option_value, Some(42));

        Ok(())
    }

    #[test]
    fn task_context_extensions_default() -> Result<()> {
        let task_context =
            task_context_with(SessionConfig::from(options_with_test_extension()));

        let test = test_extension(&task_context);
        assert_eq!(test.value, 42);
        assert_eq!(test.option_value, None);

        Ok(())
    }
}
290}