use std::collections::HashMap;
use std::sync::Arc;
use datafusion::logical_expr::builder::subquery_alias;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::SessionContext;
use hamelin_lib::catalog::Column;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::identifier::Identifier;
use hamelin_translation::{IRSideEffect, IRStatement};
use crate::expr::ExprTranslationContext;
use crate::pipeline::translate_pipeline_with_ctes;
/// Outcome of translating an [`IRStatement`] into a DataFusion logical plan.
pub enum TranslatedStatement {
    /// A plain query: the plan is executed and its rows returned.
    Query {
        /// Logical plan for the statement's main pipeline.
        plan: LogicalPlan,
        /// Columns derived from the main pipeline's output schema.
        output_schema: Vec<Column>,
    },
    /// A data-modifying statement (built by `translate_statement` for an
    /// `IRSideEffect::Append`) whose rows come from `source_plan`.
    Dml {
        /// Logical plan producing the rows to write.
        source_plan: LogicalPlan,
        /// Table named by the side effect.
        target_table: Identifier,
        /// Identifiers forwarded from the side effect's `distinct_by`;
        /// presumably used for de-duplication on write — confirm with callers.
        distinct_by: Vec<Identifier>,
    },
}

impl TranslatedStatement {
    /// Consumes `self` and yields `(plan, output_schema)` when this is a
    /// [`TranslatedStatement::Query`]; yields `None` for a
    /// [`TranslatedStatement::Dml`].
    pub fn query(self) -> Option<(LogicalPlan, Vec<Column>)> {
        match self {
            Self::Query {
                plan,
                output_schema,
            } => Some((plan, output_schema)),
            Self::Dml { .. } => None,
        }
    }

    /// Consumes `self` and yields `(source_plan, target_table, distinct_by)`
    /// when this is a [`TranslatedStatement::Dml`]; yields `None` for a
    /// [`TranslatedStatement::Query`].
    pub fn dml(self) -> Option<(LogicalPlan, Identifier, Vec<Identifier>)> {
        match self {
            Self::Dml {
                source_plan,
                target_table,
                distinct_by,
            } => Some((source_plan, target_table, distinct_by)),
            Self::Query { .. } => None,
        }
    }
}
/// Translates one [`IRStatement`] into a [`TranslatedStatement`].
///
/// `WITH` clauses are translated in declaration order. Each one is wrapped in
/// a subquery alias carrying the clause's name and registered in a CTE map.
/// Every later clause, and the main pipeline, is translated against that map.
///
/// The result is [`TranslatedStatement::Dml`] when the statement's side
/// effect is `IRSideEffect::Append`. Otherwise it is
/// [`TranslatedStatement::Query`], with the output schema taken from the
/// main pipeline.
///
/// # Errors
///
/// Returns the first error produced while translating a CTE pipeline,
/// aliasing a CTE plan, or translating the main pipeline.
pub async fn translate_statement(
    statement: &IRStatement,
    ctx: &SessionContext,
) -> Result<TranslatedStatement, Arc<TranslationError>> {
    let expr_ctx = ExprTranslationContext::default();
    let mut ctes: HashMap<Identifier, Arc<LogicalPlan>> = HashMap::new();

    for clause in &statement.with_clauses {
        // Only CTEs registered so far are visible, so a clause can see the
        // clauses declared before it.
        let translated =
            translate_pipeline_with_ctes(&clause.pipeline, ctx, &ctes, &expr_ctx).await?;
        let alias = clause.name.as_str().to_string();
        let aliased = subquery_alias(translated, &alias)
            .map_err(|err| Arc::new(TranslationError::wrap(clause, err)))?;
        ctes.insert(clause.name.clone().into(), Arc::new(aliased));
    }

    let plan = translate_pipeline_with_ctes(&statement.pipeline, ctx, &ctes, &expr_ctx).await?;

    let translated = match &statement.side_effect {
        IRSideEffect::Append { table, distinct_by } => TranslatedStatement::Dml {
            source_plan: plan,
            target_table: table.clone(),
            distinct_by: distinct_by.clone(),
        },
        _ => TranslatedStatement::Query {
            plan,
            output_schema: statement.pipeline.output_schema.as_struct().into(),
        },
    };
    Ok(translated)
}