use std::collections::HashMap;
use std::sync::Arc;
use datafusion::common::TableReference;
use datafusion::datasource::DefaultTableSource;
use datafusion::logical_expr::builder::subquery_alias;
use datafusion::logical_expr::dml::InsertOp;
use datafusion::logical_expr::{ident, Expr, LogicalPlan, LogicalPlanBuilder};
use datafusion::prelude::SessionContext;
use datafusion_functions::core::expr_fn as core_fn;
use hamelin_executor::executor::ExecutorError;
use hamelin_lib::catalog::Column;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::identifier::{Identifier, SimpleIdentifier};
use hamelin_translation::{IRSideEffect, IRStatement};
use crate::expr::ExprTranslationContext;
use crate::pipeline::translate_pipeline_with_ctes;
/// Result of translating an IR statement: either a read-only query plan or
/// the pieces needed to later assemble a DML (append) plan.
pub enum TranslatedStatement {
/// A read-only statement: the logical plan plus the columns it produces.
Query {
plan: LogicalPlan,
output_schema: Vec<Column>,
},
/// An append-style DML statement. `source_plan` produces the rows to insert
/// into `target_table`; `distinct_by` (possibly empty) lists the key
/// identifiers used to skip rows whose keys already exist in the target.
Dml {
source_plan: LogicalPlan,
target_table: Identifier,
distinct_by: Vec<Identifier>,
},
}
impl TranslatedStatement {
pub fn query(self) -> Option<(LogicalPlan, Vec<Column>)> {
if let Self::Query {
plan,
output_schema,
} = self
{
Some((plan, output_schema))
} else {
None
}
}
pub fn dml(self) -> Option<(LogicalPlan, Identifier, Vec<Identifier>)> {
if let Self::Dml {
source_plan,
target_table,
distinct_by,
} = self
{
Some((source_plan, target_table, distinct_by))
} else {
None
}
}
pub async fn into_dml_plan(self, ctx: &SessionContext) -> Result<LogicalPlan, ExecutorError> {
let (source_plan, target_table, distinct_by) = self.dml().ok_or_else(|| {
ExecutorError::QueryError(anyhow::anyhow!("expected DML statement, got query").into())
})?;
let segments = target_table.segments();
let table_name = segments
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(".");
let table_ref = match segments {
[name] => TableReference::bare(name.as_str()),
[schema, name] => TableReference::partial(schema.as_str(), name.as_str()),
[catalog, schema, name] => {
TableReference::full(catalog.as_str(), schema.as_str(), name.as_str())
}
_ => {
return Err(ExecutorError::QueryError(
anyhow::anyhow!("Invalid DML target table identifier").into(),
))
}
};
let table_provider = ctx.table_provider(table_ref).await.map_err(|e| {
ExecutorError::QueryError(
anyhow::anyhow!("Target table '{}' not found: {}", table_name, e).into(),
)
})?;
let effective_source = if distinct_by.is_empty() {
source_plan
} else {
let target_scan = LogicalPlanBuilder::scan(
table_name.clone(),
Arc::new(DefaultTableSource::new(table_provider.clone())),
None,
)
.map_err(|e| {
ExecutorError::QueryError(
anyhow::anyhow!("Failed to scan target table for DISTINCT BY: {}", e).into(),
)
})?
.build()
.map_err(|e| {
ExecutorError::QueryError(
anyhow::anyhow!("Failed to build target scan: {}", e).into(),
)
})?;
let build_field_expr = |segments: &[SimpleIdentifier]| -> Expr {
let mut expr = ident(segments[0].as_str());
for seg in &segments[1..] {
expr = core_fn::get_field(expr, seg.as_str());
}
expr
};
let left_keys: Vec<Expr> = distinct_by
.iter()
.map(|id| build_field_expr(id.segments()))
.collect();
let right_keys: Vec<Expr> = distinct_by
.iter()
.map(|id| build_field_expr(id.segments()))
.collect();
LogicalPlanBuilder::from(source_plan)
.join_with_expr_keys(
target_scan,
datafusion::logical_expr::JoinType::LeftAnti,
(left_keys, right_keys),
None,
)
.map_err(|e| {
ExecutorError::QueryError(
anyhow::anyhow!("Failed to build anti-join for DISTINCT BY: {}", e).into(),
)
})?
.build()
.map_err(|e| {
ExecutorError::QueryError(
anyhow::anyhow!("Failed to finalize anti-join plan: {}", e).into(),
)
})?
};
LogicalPlanBuilder::insert_into(
effective_source,
&table_name,
Arc::new(DefaultTableSource::new(table_provider)),
InsertOp::Append,
)
.map_err(|e| {
ExecutorError::QueryError(anyhow::anyhow!("Failed to build INSERT plan: {}", e).into())
})?
.build()
.map_err(|e| {
ExecutorError::QueryError(
anyhow::anyhow!("Failed to finalize INSERT plan: {}", e).into(),
)
})
}
}
/// Translates an IR statement into a [`TranslatedStatement`].
///
/// WITH clauses are translated first, in order, so each CTE plan is available
/// to the ones that follow it and to the main pipeline. An `Append` side
/// effect yields a `Dml` variant; anything else yields a `Query` carrying the
/// pipeline's output schema.
///
/// # Errors
/// Propagates any [`TranslationError`] raised while translating a CTE
/// pipeline, aliasing a CTE, or translating the main pipeline.
pub async fn translate_statement(
    statement: &IRStatement,
    ctx: &SessionContext,
) -> Result<TranslatedStatement, Arc<TranslationError>> {
    let expr_ctx = ExprTranslationContext::default();

    // Accumulate CTE plans; later clauses may reference earlier ones.
    let mut cte_map: HashMap<Identifier, Arc<LogicalPlan>> = HashMap::new();
    for cte in &statement.with_clauses {
        let alias = cte.name.as_str().to_string();
        let translated =
            translate_pipeline_with_ctes(&cte.pipeline, ctx, &cte_map, &expr_ctx).await?;
        let aliased = subquery_alias(translated, &alias)
            .map_err(|e| Arc::new(TranslationError::wrap(cte, e)))?;
        cte_map.insert(cte.name.clone().into(), Arc::new(aliased));
    }

    let plan = translate_pipeline_with_ctes(&statement.pipeline, ctx, &cte_map, &expr_ctx).await?;

    match &statement.side_effect {
        IRSideEffect::Append { table, distinct_by } => Ok(TranslatedStatement::Dml {
            source_plan: plan,
            target_table: table.clone(),
            distinct_by: distinct_by.clone(),
        }),
        _ => Ok(TranslatedStatement::Query {
            output_schema: statement.pipeline.output_schema.as_struct().into(),
            plan,
        }),
    }
}