use arrow_schema::{DataType, Field, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_common::{plan_err, Result};
use datafusion_expr::WindowUDF;
use datafusion_expr::{
logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource,
};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_sql::{
planner::{ContextProvider, SqlToRel},
sqlparser::{dialect::GenericDialect, parser::Parser},
TableReference,
};
use std::{collections::HashMap, sync::Arc};
/// Parses a sample three-table aggregation query and prints the logical plan
/// produced by [`SqlToRel`] against an in-memory [`MyContextProvider`].
///
/// # Panics
/// Panics if the SQL fails to parse, yields no statements, or fails to plan.
fn main() {
    // Join customers to their state and orders, then aggregate per customer.
    let sql = "SELECT \
        c.id, c.first_name, c.last_name, \
        COUNT(*) as num_orders, \
        sum(o.price) AS total_price, \
        sum(o.price * s.sales_tax) AS state_tax \
        FROM customer c \
        JOIN state s ON c.state = s.id \
        JOIN orders o ON c.id = o.customer_id \
        WHERE o.price > 0 \
        AND c.last_name LIKE 'G%' \
        GROUP BY 1, 2, 3 \
        ORDER BY state_tax DESC";

    // Only the first parsed statement is planned.
    let dialect = GenericDialect {};
    let statements = Parser::parse_sql(&dialect, sql).unwrap();
    let first_statement = statements[0].clone();

    // The query uses SUM and COUNT, so both aggregates must be resolvable.
    let provider = MyContextProvider::new()
        .with_udaf(sum_udaf())
        .with_udaf(count_udaf());

    let planner = SqlToRel::new(&provider);
    let logical_plan = planner.sql_statement_to_plan(first_statement).unwrap();
    println!("{logical_plan:?}");
}
struct MyContextProvider {
options: ConfigOptions,
tables: HashMap<String, Arc<dyn TableSource>>,
udafs: HashMap<String, Arc<AggregateUDF>>,
}
impl MyContextProvider {
fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
self.udafs.insert(udaf.name().to_lowercase(), udaf);
self
}
fn new() -> Self {
let mut tables = HashMap::new();
tables.insert(
"customer".to_string(),
create_table_source(vec![
Field::new("id", DataType::Int32, false),
Field::new("first_name", DataType::Utf8, false),
Field::new("last_name", DataType::Utf8, false),
Field::new("state", DataType::Utf8, false),
]),
);
tables.insert(
"state".to_string(),
create_table_source(vec![
Field::new("id", DataType::Int32, false),
Field::new("sales_tax", DataType::Decimal128(10, 2), false),
]),
);
tables.insert(
"orders".to_string(),
create_table_source(vec![
Field::new("id", DataType::Int32, false),
Field::new("customer_id", DataType::Int32, false),
Field::new("item_id", DataType::Int32, false),
Field::new("quantity", DataType::Int32, false),
Field::new("price", DataType::Decimal128(10, 2), false),
]),
);
Self {
tables,
options: Default::default(),
udafs: Default::default(),
}
}
}
/// Wraps `fields` in a metadata-free [`Schema`] and exposes it as a
/// [`LogicalTableSource`] behind a shared `TableSource` trait object.
fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> {
    // `Schema::new` is `new_with_metadata` with an empty metadata map.
    let schema = Arc::new(Schema::new(fields));
    let source = LogicalTableSource::new(schema);
    Arc::new(source)
}
/// Exposes the provider's tables and registered aggregate UDFs to the SQL
/// planner. Scalar UDFs, window UDFs and variables are not supported.
impl ContextProvider for MyContextProvider {
    /// Resolves a table by `name.table()` only; any schema/catalog qualifier
    /// on `name` is ignored. Unknown tables yield a plan error.
    fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
        match self.tables.get(name.table()) {
            Some(table) => Ok(Arc::clone(table)),
            None => plan_err!("Table not found: {}", name.table()),
        }
    }

    /// No scalar UDFs are registered.
    fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
        None
    }

    /// Looks up an aggregate UDF by the exact key it was stored under
    /// (`with_udaf` stores lower-cased names).
    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.udafs.get(name).cloned()
    }

    /// Variables are not supported.
    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
        None
    }

    /// No window UDFs are registered.
    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
        None
    }

    fn options(&self) -> &ConfigOptions {
        &self.options
    }

    /// No scalar UDFs are registered, so there are no names to report.
    fn udf_names(&self) -> Vec<String> {
        Vec::new()
    }

    /// Reports every aggregate UDF registered on this provider, so the names
    /// the planner sees agree with what `get_aggregate_meta` can resolve.
    /// Sorted so the output does not depend on `HashMap` iteration order.
    fn udaf_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.udafs.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// No window UDFs are registered, so there are no names to report.
    fn udwf_names(&self) -> Vec<String> {
        Vec::new()
    }
}