use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use datafusion_common::{
config::{ConfigOptions, Extensions},
DataFusionError, Result,
};
use datafusion_expr::{AggregateUDF, ScalarUDF};
use crate::{
config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry,
runtime_env::RuntimeEnv,
};
/// State handed to a task: its identifiers, the session configuration,
/// the registered user-defined functions and the runtime environment.
///
/// Instances are built with [`TaskContext::new`]; the fields are private and
/// exposed through accessor methods and the `FunctionRegistry` implementation.
pub struct TaskContext {
/// Identifier of the session this context belongs to.
session_id: String,
/// Identifier of the task, if one was supplied at construction time.
task_id: Option<String>,
/// Session configuration, including any registered config extensions.
session_config: SessionConfig,
/// Scalar user-defined functions, looked up by name.
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate user-defined functions, looked up by name.
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
/// Shared runtime environment; the memory pool is read from here.
runtime: Arc<RuntimeEnv>,
}
impl TaskContext {
    /// Assembles a [`TaskContext`] from parts that have already been built.
    ///
    /// Every argument is stored exactly as given; nothing is validated here.
    pub fn new(
        task_id: Option<String>,
        session_id: String,
        session_config: SessionConfig,
        scalar_functions: HashMap<String, Arc<ScalarUDF>>,
        aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
        runtime: Arc<RuntimeEnv>,
    ) -> Self {
        Self {
            session_id,
            task_id,
            session_config,
            scalar_functions,
            aggregate_functions,
            runtime,
        }
    }

    /// Builds a [`TaskContext`] from string key/value properties.
    ///
    /// A fresh [`ConfigOptions`] is created with `extensions` attached, each
    /// entry of `task_props` is applied to it, and the result is converted
    /// into the [`SessionConfig`] stored in the context.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by [`ConfigOptions::set`] while
    /// applying `task_props`; later entries are not applied.
    #[deprecated(
        since = "21.0.0",
        note = "Construct SessionConfig and call TaskContext::new() instead"
    )]
    pub fn try_new(
        task_id: String,
        session_id: String,
        task_props: HashMap<String, String>,
        scalar_functions: HashMap<String, Arc<ScalarUDF>>,
        aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
        runtime: Arc<RuntimeEnv>,
        extensions: Extensions,
    ) -> Result<Self> {
        let mut options = ConfigOptions::new().with_extensions(extensions);

        // Stop at the first property the configuration rejects.
        task_props
            .into_iter()
            .try_for_each(|(key, value)| options.set(&key, &value))?;

        Ok(Self::new(
            Some(task_id),
            session_id,
            SessionConfig::from(options),
            scalar_functions,
            aggregate_functions,
            runtime,
        ))
    }

    /// Borrows the session configuration held by this context.
    pub fn session_config(&self) -> &SessionConfig {
        let Self { session_config, .. } = self;
        session_config
    }

    /// Returns an owned copy of the session identifier.
    pub fn session_id(&self) -> String {
        self.session_id.to_owned()
    }

    /// Returns an owned copy of the task identifier, or `None` if the context
    /// was created without one.
    pub fn task_id(&self) -> Option<String> {
        self.task_id.as_deref().map(str::to_owned)
    }

    /// Borrows the memory pool stored in the runtime environment.
    pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> {
        let runtime = &self.runtime;
        &runtime.memory_pool
    }

    /// Returns another handle to the shared runtime environment.
    pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
        Arc::clone(&self.runtime)
    }
}
impl FunctionRegistry for TaskContext {
    /// Collects the names of all scalar UDFs held by this context.
    fn udfs(&self) -> HashSet<String> {
        let mut names = HashSet::with_capacity(self.scalar_functions.len());
        names.extend(self.scalar_functions.keys().cloned());
        names
    }

    /// Looks up the scalar UDF registered under `name`.
    ///
    /// Returns a `DataFusionError::Internal` when no such function exists.
    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
        match self.scalar_functions.get(name) {
            Some(function) => Ok(Arc::clone(function)),
            None => Err(DataFusionError::Internal(format!(
                "There is no UDF named \"{name}\" in the TaskContext"
            ))),
        }
    }

    /// Looks up the aggregate UDF registered under `name`.
    ///
    /// Returns a `DataFusionError::Internal` when no such function exists.
    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
        match self.aggregate_functions.get(name) {
            Some(function) => Ok(Arc::clone(function)),
            None => Err(DataFusionError::Internal(format!(
                "There is no UDAF named \"{name}\" in the TaskContext"
            ))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{config::ConfigExtension, extensions_options};

    // Minimal config extension with a single numeric option, used to check
    // that extension values survive the trip through `TaskContext`.
    extensions_options! {
        struct TestExtension {
            value: usize, default = 42
        }
    }

    impl ConfigExtension for TestExtension {
        const PREFIX: &'static str = "test";
    }

    /// An extension value set through `ConfigOptions::set` must be readable
    /// from the session configuration exposed by the task context.
    #[test]
    fn task_context_extensions() -> Result<()> {
        let runtime = Arc::new(RuntimeEnv::default());

        let mut registered = Extensions::new();
        registered.insert(TestExtension::default());

        // Override the default (42) via the string-keyed setter.
        let mut options = ConfigOptions::new().with_extensions(registered);
        options.set("test.value", "24")?;

        let ctx = TaskContext::new(
            Some(String::from("task_id")),
            String::from("session_id"),
            SessionConfig::from(options),
            HashMap::new(),
            HashMap::new(),
            runtime,
        );

        let found = ctx
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>();
        assert!(found.is_some());
        assert_eq!(found.unwrap().value, 24);

        Ok(())
    }
}