use std::sync::Arc;
use hamelin_lib::{
err::TranslationError,
tree::{
ast::query::Query,
builder::{self, query},
typed_ast::{
clause::TypedFromClause,
command::{TypedCommandKind, TypedFromCommand},
context::StatementTranslationContext,
pipeline::TypedPipeline,
query::TypedStatement,
},
},
};
/// Rewrites every `FROM` command that lists more than one clause into a
/// union command.
///
/// When no pipeline of `statement` contains such a command, the very same
/// `Arc` is handed back untouched. Otherwise a new AST is assembled from the
/// statement's pipelines and typed again through
/// [`TypedStatement::from_ast_with_context`] using `ctx`.
///
/// # Errors
///
/// Propagates any [`TranslationError`] raised while inspecting or rebuilding
/// the statement's pipelines.
pub fn from_to_union(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
    if statement_needs_conversion(&statement)? {
        let rewritten_ast = Arc::new(transform_statement(&statement)?);
        let retyped = TypedStatement::from_ast_with_context(rewritten_ast, ctx);
        Ok(Arc::new(retyped))
    } else {
        // Nothing to rewrite: reuse the caller's statement as-is.
        Ok(statement)
    }
}
/// Reports whether any pipeline yielded by `statement.iter()` contains a
/// `FROM` command that must be rewritten.
///
/// Every pipeline is inspected, even after a match has been found, so an
/// error in a later pipeline is still returned to the caller.
fn statement_needs_conversion(statement: &TypedStatement) -> Result<bool, Arc<TranslationError>> {
    let mut needs_conversion = false;
    for pipeline in statement.iter() {
        // Deliberately no early exit on `true`: keep checking for errors.
        if pipeline_needs_conversion(pipeline)? {
            needs_conversion = true;
        }
    }
    Ok(needs_conversion)
}
/// Reports whether `pipeline` holds at least one `FROM` command for which
/// [`from_needs_conversion`] is true.
///
/// # Errors
///
/// Returns the error produced by `pipeline.valid_ref()` when the pipeline is
/// not valid.
fn pipeline_needs_conversion(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;
    let has_multi_from = valid.commands.iter().any(|command| match &command.kind {
        TypedCommandKind::From(from_cmd) => from_needs_conversion(from_cmd),
        _ => false,
    });
    Ok(has_multi_from)
}
/// A `FROM` command needs rewriting exactly when it carries two or more
/// clauses; a single-clause (or empty) `FROM` is left alone.
fn from_needs_conversion(cmd: &TypedFromCommand) -> bool {
    let clause_count = cmd.clauses.len();
    clause_count >= 2
}
/// Rebuilds the whole statement as an untyped [`Query`].
///
/// Each `WITH` clause is transformed with [`transform_pipeline`] and merged
/// into the result as a CTE under its validated name, in declaration order;
/// the statement's main pipeline is transformed last and merged as the main
/// query.
///
/// # Errors
///
/// Stops at the first failing pipeline transformation or CTE name
/// validation and returns that error.
fn transform_statement(statement: &TypedStatement) -> Result<Query, Arc<TranslationError>> {
    let with_ctes = statement.with_clauses.iter().try_fold(
        query(),
        |acc, cte| -> Result<_, Arc<TranslationError>> {
            // Transform the pipeline before validating the name so the
            // first error reported matches the declaration order of work.
            let cte_query = transform_pipeline(&cte.pipeline)?;
            let cte_name = cte.name.clone().valid()?;
            Ok(acc.merge_as_cte(cte_query, cte_name))
        },
    )?;
    transform_pipeline(&statement.pipeline).map(|main_query| with_ctes.merge_as_main(main_query))
}
fn transform_pipeline(pipeline: &TypedPipeline) -> Result<Query, Arc<TranslationError>> {
let query_builder = query();
let mut pipeline_builder = builder::pipeline().at(pipeline.ast.span.clone());
for cmd in &pipeline.valid_ref()?.commands {
match &cmd.kind {
TypedCommandKind::From(from_cmd) if from_needs_conversion(from_cmd) => {
let mut union_builder = builder::union_command().at(cmd.ast.span.clone());
for clause in &from_cmd.clauses {
match clause {
TypedFromClause::Reference(ref_clause) => {
let table_name = ref_clause.ast.identifier.clone().valid()?;
union_builder = union_builder.table_reference(table_name);
}
TypedFromClause::Alias(_) => {
continue;
}
TypedFromClause::Error(e) => return Err(e.clone()),
}
}
pipeline_builder = pipeline_builder.command(union_builder);
}
_ => pipeline_builder = pipeline_builder.command(cmd.ast.clone()),
}
}
Ok(query_builder.main(pipeline_builder.build()).build())
}
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::{EnvironmentProvider, NoOpProvider},
        sql::{
            expression::identifier::Identifier as SqlIdentifier,
            query::TableReference as SqlTableReference,
        },
        tree::{
            ast::{IntoTyped, TypeCheckExecutor},
            builder::{pipeline as pipeline_builder, query},
            builder::{HasMain, QueryBuilder},
            typed_ast::query::TypedStatement,
        },
        types::{struct_type::Struct, INT, STRING},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    /// Environment stub that knows the columns of two tables, `events`
    /// (`a: INT`, `b: STRING`) and `logs` (`a: INT`, `c: INT`), and defers
    /// to [`NoOpProvider`] for anything else.
    #[derive(Debug)]
    struct MockProvider;

    impl EnvironmentProvider for MockProvider {
        fn reflect_columns(&self, table: SqlTableReference) -> anyhow::Result<Struct> {
            let events: SqlIdentifier = "events".parse().unwrap();
            let logs: SqlIdentifier = "logs".parse().unwrap();

            // Pick the column list for the known tables; unknown tables are
            // delegated to the no-op provider.
            let columns = if table.name == events {
                [("a", INT), ("b", STRING)]
            } else if table.name == logs {
                [("a", INT), ("c", INT)]
            } else {
                return NoOpProvider::default().reflect_columns(table);
            };

            let mut schema = Struct::default();
            for (column, ty) in columns {
                schema.fields.insert(column.parse().unwrap(), ty);
            }
            Ok(schema)
        }

        fn reflect_datasets(&self) -> anyhow::Result<Vec<SqlIdentifier>> {
            Ok(Vec::new())
        }
    }

    /// Builds the query and type-checks it against the default function
    /// registry and the [`MockProvider`].
    fn typed_query(builder: QueryBuilder<HasMain>) -> TypedStatement {
        let untyped = builder.build();
        untyped
            .typed_with()
            .with_registry(Arc::new(FunctionRegistry::default()))
            .with_provider(Arc::new(MockProvider))
            .typed()
    }

    /// Runs `from_to_union` on `input` and compares the resulting AST with
    /// the AST built from `expected`.
    #[rstest]
    #[case::single_from_passthrough(
        query().main(pipeline_builder().from(|f| f.table_reference("events"))),
        query().main(pipeline_builder().from(|f| f.table_reference("events"))),
    )]
    #[case::multi_from_to_union(
        query().main(pipeline_builder().from(|f| f.table_reference("events").table_reference("logs"))),
        query().main(pipeline_builder().union(|u| u.table_reference("events").table_reference("logs"))),
    )]
    #[case::union_passthrough(
        query().main(pipeline_builder().union(|u| u.table_reference("events").table_reference("logs"))),
        query().main(pipeline_builder().union(|u| u.table_reference("events").table_reference("logs"))),
    )]
    fn test_from_to_union(
        #[case] input: QueryBuilder<HasMain>,
        #[case] expected: QueryBuilder<HasMain>,
    ) -> Result<(), Arc<TranslationError>> {
        let typed_input = Arc::new(typed_query(input));

        let registry = Arc::new(FunctionRegistry::default());
        let provider = Arc::new(MockProvider);
        let mut ctx = StatementTranslationContext::new(registry, provider);

        let actual = from_to_union(typed_input, &mut ctx)?;
        let wanted = expected.build();
        assert_eq!(actual.ast.as_ref(), &wanted);
        Ok(())
    }
}