use std::any::Any;
#[cfg(test)]
use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::{sync::Arc, vec};
use arrow::datatypes::*;
use datafusion_common::config::ConfigOptions;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference};
use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner};
use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF};
use datafusion_functions_nested::expr_fn::make_array;
use datafusion_sql::planner::ContextProvider;
/// Mock file type used by the test context provider; its `GetExt` impl
/// always reports the `csv` extension.
struct MockCsvType {}
impl GetExt for MockCsvType {
    /// Returns the fixed extension `"csv"` as an owned string.
    fn get_ext(&self) -> String {
        String::from("csv")
    }
}
impl FileType for MockCsvType {
/// Returns `self` as `&dyn Any`.
fn as_any(&self) -> &dyn Any {
self
}
}
impl Display for MockCsvType {
    /// Writes the value of `get_ext()` to the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let extension = self.get_ext();
        f.write_str(&extension)
    }
}
/// Mock session state: the functions, planners and configuration that are
/// exposed through `MockContextProvider`'s `ContextProvider` impl.
#[derive(Default)]
pub(crate) struct MockSessionState {
/// Scalar UDFs, keyed by the name they were registered under.
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate UDFs, keyed by their lowercased name.
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
/// Expression planners, in registration order.
expr_planners: Vec<Arc<dyn ExprPlanner>>,
/// Optional custom type planner.
type_planner: Option<Arc<dyn TypePlanner>>,
/// Window UDFs, keyed by the name they were registered under.
window_functions: HashMap<String, Arc<WindowUDF>>,
/// Configuration returned by `ContextProvider::options`.
pub config_options: ConfigOptions,
}
/// Builder-style registration methods; each consumes the state and returns it
/// so calls can be chained.
impl MockSessionState {
    /// Appends `expr_planner` to the registered expression planners.
    pub fn with_expr_planner(mut self, expr_planner: Arc<dyn ExprPlanner>) -> Self {
        self.expr_planners.push(expr_planner);
        self
    }

    /// Sets the type planner, replacing any previously configured one.
    pub fn with_type_planner(self, type_planner: Arc<dyn TypePlanner>) -> Self {
        Self {
            type_planner: Some(type_planner),
            ..self
        }
    }

    /// Registers a scalar UDF under its own name; an existing entry with the
    /// same name is overwritten.
    pub fn with_scalar_function(mut self, scalar_function: Arc<ScalarUDF>) -> Self {
        let key = scalar_function.name().to_owned();
        self.scalar_functions.insert(key, scalar_function);
        self
    }

    /// Registers an aggregate UDF under its lowercased name; an existing entry
    /// with the same key is overwritten.
    pub fn with_aggregate_function(
        mut self,
        aggregate_function: Arc<AggregateUDF>,
    ) -> Self {
        // Unlike scalar and window functions, aggregates are keyed by the
        // lowercased name.
        let key = aggregate_function.name().to_lowercase();
        self.aggregate_functions.insert(key, aggregate_function);
        self
    }

    /// Registers a window UDF under its own name; an existing entry with the
    /// same name is overwritten.
    pub fn with_window_function(mut self, window_function: Arc<WindowUDF>) -> Self {
        let key = window_function.name().to_owned();
        self.window_functions.insert(key, window_function);
        self
    }
}
/// `ContextProvider` implementation backed by a `MockSessionState`.
pub(crate) struct MockContextProvider {
/// State consulted for function lookups, planners and config options.
pub(crate) state: MockSessionState,
}
impl ContextProvider for MockContextProvider {
    /// Resolves one of the hard-coded mock tables.
    ///
    /// Only `name.table()` is inspected; any catalog or schema qualifier on
    /// the reference is ignored. The returned source carries a schema only.
    ///
    /// # Errors
    ///
    /// Returns a planning error when no mock table has the requested name.
    fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
        let schema = match name.table() {
            "test" => Schema::new(vec![
                Field::new("t_date32", DataType::Date32, false),
                Field::new("t_date64", DataType::Date64, false),
            ]),
            "j1" => Schema::new(vec![
                Field::new("j1_id", DataType::Int32, false),
                Field::new("j1_string", DataType::Utf8, false),
            ]),
            "j2" => Schema::new(vec![
                Field::new("j2_id", DataType::Int32, false),
                Field::new("j2_string", DataType::Utf8, false),
            ]),
            "j3" => Schema::new(vec![
                Field::new("j3_id", DataType::Int32, false),
                Field::new("j3_string", DataType::Utf8, false),
            ]),
            "test_decimal" => Schema::new(vec![
                Field::new("id", DataType::Int32, false),
                Field::new("price", DataType::Decimal128(10, 2), false),
            ]),
            "person" => Schema::new(vec![
                Field::new("id", DataType::UInt32, false),
                Field::new("first_name", DataType::Utf8, false),
                Field::new("last_name", DataType::Utf8, false),
                Field::new("age", DataType::Int32, false),
                Field::new("state", DataType::Utf8, false),
                Field::new("salary", DataType::Float64, false),
                Field::new(
                    "birth_date",
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    false,
                ),
                Field::new("😀", DataType::Int32, false),
            ]),
            "person_quoted_cols" => Schema::new(vec![
                Field::new("id", DataType::UInt32, false),
                Field::new("First Name", DataType::Utf8, false),
                Field::new("Last Name", DataType::Utf8, false),
                Field::new("Age", DataType::Int32, false),
                Field::new("State", DataType::Utf8, false),
                Field::new("Salary", DataType::Float64, false),
                Field::new(
                    "Birth Date",
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    false,
                ),
                Field::new("😀", DataType::Int32, false),
            ]),
            "orders" => Schema::new(vec![
                Field::new("order_id", DataType::UInt32, false),
                Field::new("customer_id", DataType::UInt32, false),
                Field::new("o_item_id", DataType::Utf8, false),
                Field::new("qty", DataType::Int32, false),
                Field::new("price", DataType::Float64, false),
                Field::new("delivered", DataType::Boolean, false),
            ]),
            "array" => Schema::new(vec![
                Field::new(
                    "left",
                    DataType::List(Arc::new(Field::new_list_field(
                        DataType::Int64,
                        true,
                    ))),
                    false,
                ),
                Field::new(
                    "right",
                    DataType::List(Arc::new(Field::new_list_field(
                        DataType::Int64,
                        true,
                    ))),
                    false,
                ),
            ]),
            "lineitem" => Schema::new(vec![
                Field::new("l_item_id", DataType::UInt32, false),
                Field::new("l_description", DataType::Utf8, false),
                Field::new("price", DataType::Float64, false),
            ]),
            "aggregate_test_100" => Schema::new(vec![
                Field::new("c1", DataType::Utf8, false),
                Field::new("c2", DataType::UInt32, false),
                Field::new("c3", DataType::Int8, false),
                Field::new("c4", DataType::Int16, false),
                Field::new("c5", DataType::Int32, false),
                Field::new("c6", DataType::Int64, false),
                Field::new("c7", DataType::UInt8, false),
                Field::new("c8", DataType::UInt16, false),
                Field::new("c9", DataType::UInt32, false),
                Field::new("c10", DataType::UInt64, false),
                Field::new("c11", DataType::Float32, false),
                Field::new("c12", DataType::Float64, false),
                Field::new("c13", DataType::Utf8, false),
            ]),
            "UPPERCASE_test" => Schema::new(vec![
                Field::new("Id", DataType::UInt32, false),
                Field::new("lower", DataType::UInt32, false),
            ]),
            "unnest_table" => Schema::new(vec![
                Field::new(
                    "array_col",
                    DataType::List(Arc::new(Field::new_list_field(
                        DataType::Int64,
                        true,
                    ))),
                    false,
                ),
                Field::new(
                    "struct_col",
                    DataType::Struct(Fields::from(vec![
                        Field::new("field1", DataType::Int64, true),
                        Field::new("field2", DataType::Utf8, true),
                    ])),
                    false,
                ),
            ]),
            // Unknown table: bail out early instead of threading a `Result`
            // through every arm.
            _ => return plan_err!("No table named: {} found", name.table()),
        };

        Ok(Arc::new(EmptyTable::new(Arc::new(schema))))
    }

    /// Looks up a scalar UDF by the exact key it was registered under.
    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
        self.state.scalar_functions.get(name).cloned()
    }

    /// Looks up an aggregate UDF by exact key.
    ///
    /// Aggregates are stored under their lowercased name and `name` is not
    /// normalised here, so callers must pass the lowercased form.
    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.state.aggregate_functions.get(name).cloned()
    }

    /// Not supported by this mock.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
        unimplemented!()
    }

    /// Looks up a window UDF by the exact key it was registered under.
    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.state.window_functions.get(name).cloned()
    }

    /// Returns the configuration stored in the mock session state.
    fn options(&self) -> &ConfigOptions {
        &self.state.config_options
    }

    /// Ignores the requested extension and always returns the mock CSV type.
    fn get_file_type(&self, _ext: &str) -> Result<Arc<dyn FileType>> {
        Ok(Arc::new(MockCsvType {}))
    }

    /// Creates a schema-only table source for a CTE work table; the name is
    /// ignored.
    fn create_cte_work_table(
        &self,
        _name: &str,
        schema: SchemaRef,
    ) -> Result<Arc<dyn TableSource>> {
        Ok(Arc::new(EmptyTable::new(schema)))
    }

    /// Keys of all registered scalar UDFs, in unspecified (hash map) order.
    fn udf_names(&self) -> Vec<String> {
        self.state.scalar_functions.keys().cloned().collect()
    }

    /// Keys of all registered aggregate UDFs, in unspecified (hash map) order.
    fn udaf_names(&self) -> Vec<String> {
        self.state.aggregate_functions.keys().cloned().collect()
    }

    /// Keys of all registered window UDFs, in unspecified (hash map) order.
    ///
    /// Reports the functions added through `with_window_function`, so this
    /// listing agrees with `get_window_meta` the same way `udf_names` and
    /// `udaf_names` agree with their lookups.
    fn udwf_names(&self) -> Vec<String> {
        self.state.window_functions.keys().cloned().collect()
    }

    /// Returns the registered expression planners in registration order.
    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
        &self.state.expr_planners
    }

    /// Returns a new handle to the configured type planner, if any.
    fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
        self.state.type_planner.as_ref().map(Arc::clone)
    }
}
/// Table source that carries only a schema.
struct EmptyTable {
/// Schema returned by `TableSource::schema`.
table_schema: SchemaRef,
}
impl EmptyTable {
fn new(table_schema: SchemaRef) -> Self {
Self { table_schema }
}
}
impl TableSource for EmptyTable {
/// Returns `self` as `&dyn Any`.
fn as_any(&self) -> &dyn Any {
self
}
/// Returns a new reference-counted handle to the stored schema.
fn schema(&self) -> SchemaRef {
Arc::clone(&self.table_schema)
}
}
/// Type planner that maps SQL `DATETIME` types to Arrow `Timestamp` types.
#[derive(Debug)]
pub struct CustomTypePlanner {}
impl TypePlanner for CustomTypePlanner {
fn plan_type(&self, sql_type: &sqlparser::ast::DataType) -> Result<Option<DataType>> {
match sql_type {
sqlparser::ast::DataType::Datetime(precision) => {
let precision = match precision {
Some(0) => TimeUnit::Second,
Some(3) => TimeUnit::Millisecond,
Some(6) => TimeUnit::Microsecond,
None | Some(9) => TimeUnit::Nanosecond,
_ => unreachable!(),
};
Ok(Some(DataType::Timestamp(precision, None)))
}
_ => Ok(None),
}
}
}
/// Expression planner that plans array literals with `make_array`.
#[derive(Debug)]
pub struct CustomExprPlanner {}
impl ExprPlanner for CustomExprPlanner {
    /// Plans an array literal by passing its element expressions to
    /// `make_array`; the schema is not consulted.
    ///
    /// Always returns `PlannerResult::Planned`.
    fn plan_array_literal(
        &self,
        exprs: Vec<Expr>,
        _schema: &DFSchema,
    ) -> Result<PlannerResult<Vec<Expr>>> {
        let array_expr = make_array(exprs);
        Ok(PlannerResult::Planned(array_expr))
    }
}