use std::sync::Arc;
use hamelin_lib::err::{TranslationError, TranslationErrors};
use hamelin_lib::sql::expression::apply::BinaryOperatorApply;
use hamelin_lib::sql::expression::identifier::SimpleIdentifier as SQLSimpleIdentifier;
use hamelin_lib::sql::expression::literal::ColumnReference;
use hamelin_lib::sql::expression::operator::Operator;
use hamelin_lib::sql::query::dml::{
Insert, Merge, MergeAction, MergeInsert, MergeWhenNotMatched, DML,
};
use hamelin_lib::sql::query::{SQLQuery, TableReference};
use hamelin_lib::sql::statement::Statement;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::tree::ast::identifier::Identifier;
use hamelin_translation::{IRPipeline, IRSideEffect, IRStatement};
use crate::command::CommandTranslator;
use crate::context::TranslationContext;
use crate::TranslationRegistry;
/// Outcome of translating one IR statement: the resulting SQL [`Statement`]
/// on success, or [`TranslationErrors`] on failure.
pub type StatementTranslationResult = Result<Statement, TranslationErrors>;
/// Translates an IR statement into a SQL [`Statement`] using a fresh
/// [`TranslationContext`] built from `registry`.
///
/// This is a convenience entry point; the actual work is delegated to
/// [`translate_statement_with_context`].
///
/// # Errors
///
/// Returns whatever [`TranslationErrors`] the delegated translation produces.
pub fn translate_statement<T: CommandTranslator>(
    ir: &IRStatement,
    registry: &TranslationRegistry,
    translator: &T,
) -> StatementTranslationResult {
    // The context lives only for the duration of this single statement.
    let mut fresh_context = TranslationContext::new(registry);
    translate_statement_with_context(ir, &mut fresh_context, translator)
}
/// Translates an IR statement into a SQL [`Statement`] using a caller-supplied
/// translation context.
///
/// Steps, in order:
/// 1. Each WITH clause pipeline of `ir` is translated and registered on `ctx`
///    as a CTE under the clause's name.
/// 2. The main pipeline is translated.
/// 3. The main query is wrapped with every CTE currently held by `ctx`.
/// 4. Depending on the statement's side effect, the wrapped query is either
///    returned as-is (converted into a `Statement`) or handed to
///    `translate_append_to_dml` to build the append DML.
///
/// # Errors
///
/// Propagates the first error from translating any pipeline, or from building
/// the append DML.
pub fn translate_statement_with_context<T: CommandTranslator>(
    ir: &IRStatement,
    ctx: &mut TranslationContext,
    translator: &T,
) -> StatementTranslationResult {
    // CTEs are registered one at a time, so each clause is translated with
    // the previously registered ones already present in the context.
    for clause in &ir.with_clauses {
        let translated = translate_pipeline(ctx, translator, &clause.pipeline)?;
        ctx.add_cte(clause.name.clone().into(), translated);
    }

    let main_query = translate_pipeline(ctx, translator, &ir.pipeline)?;
    let full_query = wrap_with_ctes(main_query, ctx);

    match &ir.side_effect {
        IRSideEffect::Append { table, distinct_by } => {
            let main_pipeline = &ir.pipeline;
            translate_append_to_dml(
                table,
                distinct_by,
                full_query,
                &main_pipeline.output_schema,
                main_pipeline,
            )
        }
        IRSideEffect::None => Ok(full_query.into()),
    }
}
/// Translates a pipeline into a single [`SQLQuery`] by feeding its commands,
/// in order, through `translator`.
///
/// Each command receives the query produced by the previous command (`None`
/// for the first one) and returns the query that the next command builds on.
///
/// # Errors
///
/// Fails if the pipeline has no commands, if any command fails to translate,
/// or if no query was produced after processing all commands.
fn translate_pipeline<T: CommandTranslator>(
    ctx: &mut TranslationContext,
    translator: &T,
    pipeline: &Arc<IRPipeline>,
) -> Result<SQLQuery, TranslationErrors> {
    let commands = &pipeline.commands;
    if commands.is_empty() {
        return Err(TranslationError::msg(pipeline.as_ref(), "Pipeline has no commands").single());
    }

    // Thread the query-so-far through every command.
    let mut current = None;
    for command in commands {
        let next = translator.translate_command(ctx, command, current)?;
        current = Some(next);
    }

    match current {
        Some(query) => Ok(query),
        None => {
            Err(TranslationError::msg(pipeline.as_ref(), "Pipeline produced no query").single())
        }
    }
}
/// Attaches every CTE held by `ctx` to `query`, in the context's iteration
/// order, and returns the resulting query.
///
/// Names and CTE queries are cloned; the context itself is left untouched.
fn wrap_with_ctes(query: SQLQuery, ctx: &TranslationContext) -> SQLQuery {
    ctx.ctes.iter().fold(query, |wrapped, (name, cte_query)| {
        wrapped.with_cte(name.clone(), cte_query.clone())
    })
}
/// Builds the DML statement for an append side effect.
///
/// * With no `distinct_by` columns, the result is an [`Insert`] of
///   `source_query` into `append_table` using the struct view of
///   `output_schema`.
/// * Otherwise the result is a [`Merge`] whose table is aliased `target` and
///   whose query is aliased `source`. Its search condition is
///   `source.<col> = target.<col>` for every `distinct_by` column, joined
///   with `AND`. Its only WHEN clause is a "when not matched" insert listing
///   every column of the schema.
///
/// `pipeline` is used only as the location attached to an error.
///
/// # Errors
///
/// Returns an error if no search condition could be built for the merge.
fn translate_append_to_dml(
    append_table: &Identifier,
    distinct_by: &[Identifier],
    source_query: SQLQuery,
    output_schema: &Arc<TypeEnvironment>,
    pipeline: &IRPipeline,
) -> StatementTranslationResult {
    let table = TableReference::new(append_table.clone().into());
    let schema = output_schema.as_struct().clone();

    // Without distinct-by columns this is a plain insert.
    if distinct_by.is_empty() {
        let dml: DML = Insert {
            table,
            schema,
            query: source_query,
        }
        .into();
        return Ok(dml.into());
    }

    let target_alias = SQLSimpleIdentifier::new("target");
    let source_alias = SQLSimpleIdentifier::new("source");

    // One `source.col = target.col` comparison per distinct-by column,
    // combined left-to-right with AND.
    let combined = distinct_by
        .iter()
        .map(|key| {
            let column = ColumnReference::new(key.clone().into());
            BinaryOperatorApply::new(
                Operator::Eq,
                column.prefixed_with(source_alias.name.as_str()).into(),
                column.prefixed_with(target_alias.name.as_str()).into(),
            )
        })
        .reduce(|lhs, rhs| BinaryOperatorApply::new(Operator::And, lhs.into(), rhs.into()));

    // `distinct_by` is non-empty on this path, so `None` is not expected;
    // the guard is kept as a defensive check.
    let search_condition = match combined {
        Some(condition) => condition,
        None => {
            return Err(
                TranslationError::msg(pipeline, "DISTINCT BY requires at least one column")
                    .single(),
            );
        }
    };

    // The insert lists every column of the output schema.
    let insert_columns: Vec<ColumnReference> = schema
        .keys()
        .map(|field| {
            let ident: SQLSimpleIdentifier = field.clone().into();
            ColumnReference::new(ident.into())
        })
        .collect();

    let insert_when_missing =
        MergeWhenNotMatched::new(MergeAction::Insert(MergeInsert::new(insert_columns)));

    let dml: DML = Merge {
        table_alias: target_alias,
        query_alias: source_alias,
        table,
        search_condition: search_condition.into(),
        when_clauses: vec![insert_when_missing.into()],
        query: source_query,
    }
    .into();
    Ok(dml.into())
}