1use crate::{
19 config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry,
20 runtime_env::RuntimeEnv,
21};
22use datafusion_common::{Result, internal_datafusion_err, plan_datafusion_err};
23use datafusion_expr::planner::ExprPlanner;
24use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF};
25use std::collections::HashSet;
26use std::{collections::HashMap, sync::Arc};
27
/// Per-task execution state.
///
/// Bundles the session and task identifiers, the session configuration,
/// the user-defined scalar / aggregate / window functions visible to the
/// task, and the runtime environment. Also implements [`FunctionRegistry`]
/// so functions can be looked up by name while executing the task.
#[derive(Debug)]
pub struct TaskContext {
    /// Session identifier (`"DEFAULT"` for a `TaskContext::default()`).
    session_id: String,
    /// Optional task identifier (`None` for a `TaskContext::default()`).
    task_id: Option<String>,
    /// Session-level configuration, exposed through `session_config()`.
    session_config: SessionConfig,
    /// Scalar UDFs keyed by name; `register_udf` also adds one entry per alias.
    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
    /// Aggregate UDFs keyed by name; `register_udaf` also adds one entry per alias.
    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
    /// Window UDFs keyed by name; `register_udwf` also adds one entry per alias.
    window_functions: HashMap<String, Arc<WindowUDF>>,
    /// Shared runtime environment; `memory_pool()` reads its memory pool.
    runtime: Arc<RuntimeEnv>,
}
52
53impl Default for TaskContext {
54 fn default() -> Self {
55 let runtime = Arc::new(RuntimeEnv::default());
56
57 Self {
59 session_id: "DEFAULT".to_string(),
60 task_id: None,
61 session_config: SessionConfig::new(),
62 scalar_functions: HashMap::new(),
63 aggregate_functions: HashMap::new(),
64 window_functions: HashMap::new(),
65 runtime,
66 }
67 }
68}
69
70impl TaskContext {
71 pub fn new(
77 task_id: Option<String>,
78 session_id: String,
79 session_config: SessionConfig,
80 scalar_functions: HashMap<String, Arc<ScalarUDF>>,
81 aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
82 window_functions: HashMap<String, Arc<WindowUDF>>,
83 runtime: Arc<RuntimeEnv>,
84 ) -> Self {
85 Self {
86 task_id,
87 session_id,
88 session_config,
89 scalar_functions,
90 aggregate_functions,
91 window_functions,
92 runtime,
93 }
94 }
95
96 pub fn session_config(&self) -> &SessionConfig {
98 &self.session_config
99 }
100
101 pub fn session_id(&self) -> String {
103 self.session_id.clone()
104 }
105
106 pub fn task_id(&self) -> Option<String> {
108 self.task_id.clone()
109 }
110
111 pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> {
113 &self.runtime.memory_pool
114 }
115
116 pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
118 Arc::clone(&self.runtime)
119 }
120
121 pub fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
122 &self.scalar_functions
123 }
124
125 pub fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
126 &self.aggregate_functions
127 }
128
129 pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
130 &self.window_functions
131 }
132
133 pub fn with_session_config(mut self, session_config: SessionConfig) -> Self {
135 self.session_config = session_config;
136 self
137 }
138
139 pub fn with_runtime(mut self, runtime: Arc<RuntimeEnv>) -> Self {
141 self.runtime = runtime;
142 self
143 }
144}
145
146impl FunctionRegistry for TaskContext {
147 fn udfs(&self) -> HashSet<String> {
148 self.scalar_functions.keys().cloned().collect()
149 }
150
151 fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
152 let result = self.scalar_functions.get(name);
153
154 result.cloned().ok_or_else(|| {
155 plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext")
156 })
157 }
158
159 fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
160 let result = self.aggregate_functions.get(name);
161
162 result.cloned().ok_or_else(|| {
163 plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext")
164 })
165 }
166
167 fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
168 let result = self.window_functions.get(name);
169
170 result.cloned().ok_or_else(|| {
171 internal_datafusion_err!(
172 "There is no UDWF named \"{name}\" in the TaskContext"
173 )
174 })
175 }
176 fn register_udaf(
177 &mut self,
178 udaf: Arc<AggregateUDF>,
179 ) -> Result<Option<Arc<AggregateUDF>>> {
180 udaf.aliases().iter().for_each(|alias| {
181 self.aggregate_functions
182 .insert(alias.clone(), Arc::clone(&udaf));
183 });
184 Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
185 }
186 fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
187 udwf.aliases().iter().for_each(|alias| {
188 self.window_functions
189 .insert(alias.clone(), Arc::clone(&udwf));
190 });
191 Ok(self.window_functions.insert(udwf.name().into(), udwf))
192 }
193 fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
194 udf.aliases().iter().for_each(|alias| {
195 self.scalar_functions
196 .insert(alias.clone(), Arc::clone(&udf));
197 });
198 Ok(self.scalar_functions.insert(udf.name().into(), udf))
199 }
200
201 fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
202 vec![]
203 }
204
205 fn udafs(&self) -> HashSet<String> {
206 self.aggregate_functions.keys().cloned().collect()
207 }
208
209 fn udwfs(&self) -> HashSet<String> {
210 self.window_functions.keys().cloned().collect()
211 }
212}
213
/// Something that can hand out a shared [`TaskContext`].
pub trait TaskContextProvider {
    /// Returns a [`TaskContext`] wrapped in an [`Arc`] so it can be shared.
    fn task_ctx(&self) -> Arc<TaskContext>;
}
218
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{
        config::{ConfigExtension, ConfigOptions, Extensions},
        extensions_options,
    };

    extensions_options! {
        struct TestExtension {
            value: usize, default = 42
            option_value: Option<usize>, default = None
        }
    }

    impl ConfigExtension for TestExtension {
        const PREFIX: &'static str = "test";
    }

    /// Config options with a default-valued `TestExtension` registered.
    fn options_with_test_extension() -> ConfigOptions {
        let mut extensions = Extensions::new();
        extensions.insert(TestExtension::default());
        ConfigOptions::new().with_extensions(extensions)
    }

    /// A `TaskContext` built from `options`, with no registered functions
    /// and a default runtime.
    fn task_context_from(options: ConfigOptions) -> TaskContext {
        TaskContext::new(
            Some("task_id".to_string()),
            "session_id".to_string(),
            SessionConfig::from(options),
            HashMap::default(),
            HashMap::default(),
            HashMap::default(),
            Arc::new(RuntimeEnv::default()),
        )
    }

    #[test]
    fn task_context_extensions() -> Result<()> {
        let mut options = options_with_test_extension();
        options.set("test.value", "24")?;
        options.set("test.option_value", "42")?;

        let task_context = task_context_from(options);
        let ext = task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be registered on the session config");

        assert_eq!(ext.value, 24);
        assert_eq!(ext.option_value, Some(42));

        Ok(())
    }

    #[test]
    fn task_context_extensions_default() -> Result<()> {
        let task_context = task_context_from(options_with_test_extension());
        let ext = task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .expect("TestExtension should be registered on the session config");

        assert_eq!(ext.value, 42);
        assert_eq!(ext.option_value, None);

        Ok(())
    }
}