use std::collections::HashMap;
use std::sync::Arc;
use datafusion::execution::SessionStateBuilder;
use crate::planner::SparkFunctionPlanner;
use crate::{
all_default_aggregate_functions, all_default_scalar_functions,
all_default_table_functions, all_default_window_functions,
};
/// Extension trait that adds Apache Spark compatibility features to a
/// DataFusion `SessionStateBuilder`.
pub trait SessionStateBuilderSpark {
/// Registers the Spark expression planner and the default Spark scalar,
/// aggregate, window, and table functions on this builder, returning the
/// builder for further chaining.
fn with_spark_features(self) -> Self;
}
impl SessionStateBuilderSpark for SessionStateBuilder {
    /// Registers the Spark expression planner plus every default Spark
    /// scalar, aggregate, window, and table function on this builder.
    fn with_spark_features(mut self) -> Self {
        // Prepend the Spark planner so it takes precedence over any
        // expression planners already configured on the builder.
        let planners = self.expr_planners().get_or_insert_with(Vec::new);
        planners.insert(0, Arc::new(SparkFunctionPlanner));

        let scalars = self.scalar_functions().get_or_insert_with(Vec::new);
        scalars.extend(all_default_scalar_functions());

        let aggregates = self.aggregate_functions().get_or_insert_with(Vec::new);
        aggregates.extend(all_default_aggregate_functions());

        let windows = self.window_functions().get_or_insert_with(Vec::new);
        windows.extend(all_default_window_functions());

        // Table functions are stored in a map keyed by function name.
        let tables = self.table_functions().get_or_insert_with(HashMap::new);
        for udtf in all_default_table_functions() {
            tables.insert(udtf.name().to_string(), udtf);
        }

        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder with Spark features enabled must expose the Spark
    /// functions and expression planners on the built session state.
    #[test]
    fn test_session_state_with_spark_features() {
        let session_state = SessionStateBuilder::new()
            .with_spark_features()
            .build();

        let scalars = session_state.scalar_functions();
        assert!(
            scalars.contains_key("sha2"),
            "Apache Spark scalar function 'sha2' should be registered"
        );

        let aggregates = session_state.aggregate_functions();
        assert!(
            aggregates.contains_key("try_sum"),
            "Apache Spark aggregate function 'try_sum' should be registered"
        );

        assert!(
            !session_state.expr_planners().is_empty(),
            "Apache Spark expr planners should be registered"
        );
    }
}