use std::sync::Arc;
use hamelin_lib::{
err::TranslationError,
tree::{
ast::query::Query,
builder::{self, field_ref, from_command, query, select_command, struct_literal},
typed_ast::{
clause::TypedFromClause, command::TypedCommandKind,
context::StatementTranslationContext, pipeline::TypedPipeline, query::TypedStatement,
},
},
};
use crate::unique::UniqueNameGenerator;
use hamelin_lib::tree::builder::pipeline as pipeline_builder;
/// Rewrites `FROM` alias clauses in `statement` into generated pipeline definitions (CTEs).
///
/// Each alias clause that binds `alias` to `table` is replaced by a reference to a new CTE.
/// The CTE's name comes from a [`UniqueNameGenerator`] with the `__alias` prefix. It selects
/// the table's columns and also exposes them nested in a struct column named `alias`.
/// The rewritten AST is then turned back into a typed statement with `ctx`.
///
/// If no pipeline contains an alias clause, the input `Arc` is returned untouched.
///
/// # Errors
///
/// Propagates any error hit while scanning or rewriting the statement, such as an invalid
/// pipeline, name, or `FROM` clause. The statement produced by
/// `TypedStatement::from_ast_with_context` is not inspected for errors here.
pub fn nest_from_aliases(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
    if statement_has_aliases(&statement)? {
        let mut name_gen = UniqueNameGenerator::new("__alias");
        let rewritten_ast = Arc::new(transform_statement(&statement, &mut name_gen)?);
        Ok(Arc::new(TypedStatement::from_ast_with_context(
            rewritten_ast,
            ctx,
        )))
    } else {
        // Nothing to rewrite: hand back the caller's statement as-is.
        Ok(statement)
    }
}
/// Reports whether any pipeline yielded by `statement.iter()` has a `FROM` alias clause.
///
/// `statement.iter()` presumably covers the named pipeline definitions and the main
/// pipeline; TODO confirm. Every pipeline is checked even after an alias has been found, so
/// an invalid pipeline anywhere in the statement still surfaces as an error.
fn statement_has_aliases(statement: &TypedStatement) -> Result<bool, Arc<TranslationError>> {
    let mut found = false;
    for pipeline in statement.iter() {
        // Evaluate first, then combine, so a later error is never skipped.
        let this_one = pipeline_has_aliases(pipeline)?;
        found = this_one || found;
    }
    Ok(found)
}
/// Reports whether any `FROM` command in `pipeline` contains an alias clause.
///
/// Commands other than `FROM` never count.
///
/// # Errors
///
/// Returns the error from `pipeline.valid_ref()` if the pipeline is not valid.
fn pipeline_has_aliases(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    let found = pipeline.valid_ref()?.commands.iter().any(|command| {
        if let TypedCommandKind::From(from_cmd) = &command.kind {
            from_cmd
                .clauses
                .iter()
                .any(|clause| matches!(clause, TypedFromClause::Alias(_)))
        } else {
            false
        }
    });
    Ok(found)
}
/// Rebuilds `statement` as a [`Query`] with every `FROM` alias clause expanded.
///
/// Processing order:
/// - Scalar definitions are copied through unchanged.
/// - Each named pipeline definition is rewritten with `transform_pipeline` and merged back
///   as a CTE under its original name.
/// - The main pipeline is rewritten last and merged as the query's main pipeline.
///
/// The alias CTEs that `transform_pipeline` defines are presumably carried into the outer
/// query by `merge_as_cte` / `merge_as_main`; TODO confirm.
///
/// # Errors
///
/// Fails if a definition name is invalid or if rewriting any pipeline fails.
fn transform_statement(
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Arc<TranslationError>> {
    let mut builder = query();

    // Scalar definitions hold no FROM clauses, so they need no rewriting.
    for scalar in &statement.scalar_defs {
        let scalar_name = scalar.name.valid_ref()?.clone();
        builder = builder.def_expression(scalar_name, scalar.expression.ast.clone());
    }

    // Rewrite first, then validate the name, matching the order errors are reported in.
    for def in &statement.pipeline_defs {
        let rewritten = transform_pipeline(&def.pipeline, statement, name_gen)?;
        let def_name = def.name.clone().valid()?;
        builder = builder.merge_as_cte(rewritten, def_name);
    }

    let main = transform_pipeline(&statement.pipeline, statement, name_gen)?;
    Ok(builder.merge_as_main(main))
}
/// Rebuilds `pipeline` as a [`Query`], expanding every `FROM` alias clause into a CTE.
///
/// For an alias clause that binds `alias` to `table`, a CTE is defined under a name taken
/// from `name_gen`:
///
/// ```text
/// FROM table | SELECT <every column not named `alias`>, alias = {col: col, ...}
/// ```
///
/// The clause is then replaced by a plain reference to that CTE, so the table's columns are
/// reachable both directly and nested under `alias`. Reference clauses inside the same
/// `FROM` are kept as references. `FROM` commands without any alias clause, and all other
/// commands, are copied through unchanged.
///
/// `statement` is only passed to `name_gen.next`. It is presumably used there to avoid name
/// collisions; TODO confirm.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the pipeline is not valid;
/// - an alias, table identifier, or reference identifier is not valid;
/// - a `FROM` command that contains an alias also contains an error clause.
fn transform_pipeline(
    pipeline: &TypedPipeline,
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Arc<TranslationError>> {
    let mut query_builder = query();
    // Named so it does not shadow the imported `pipeline_builder` function.
    let mut main_pipeline = pipeline_builder().at(pipeline.ast.span.clone());
    for cmd in &pipeline.valid_ref()?.commands {
        match &cmd.kind {
            TypedCommandKind::From(from_cmd)
                if from_cmd
                    .clauses
                    .iter()
                    .any(|c| matches!(c, TypedFromClause::Alias(_))) =>
            {
                let mut from_builder = from_command().at(cmd.ast.span.clone());
                for clause in &from_cmd.clauses {
                    match clause {
                        TypedFromClause::Alias(alias_clause) => {
                            let alias = alias_clause.alias.valid_ref()?;
                            let cte_name = name_gen.next(statement);
                            let table_name =
                                alias_clause.ast.table.identifier.valid_ref()?.clone();
                            let table_env = alias_clause.resolved.environment();
                            let table_schema = table_env.as_struct();

                            // A single pass over the schema fills both builders, each in
                            // schema order. The nested struct gets every column. The flat
                            // projection skips a column whose name equals the alias, because
                            // that name is taken by the struct column added below.
                            let mut struct_builder = struct_literal();
                            let mut select_builder = select_command();
                            for (field_name, _) in table_schema.iter() {
                                struct_builder = struct_builder
                                    .field(field_name.name(), field_ref(field_name.name()));
                                if field_name.name() != alias.as_str() {
                                    select_builder =
                                        select_builder.field(field_name.to_string());
                                }
                            }
                            select_builder =
                                select_builder.named_field(alias.clone(), struct_builder);

                            let cte_pipeline = builder::pipeline()
                                .from(|f| f.table_reference(table_name))
                                .command(select_builder)
                                .build();
                            query_builder =
                                query_builder.def_pipeline(cte_name.clone(), cte_pipeline);
                            from_builder = from_builder.table_reference(cte_name);
                        }
                        TypedFromClause::Reference(ref_clause) => {
                            from_builder = from_builder
                                .table_reference(ref_clause.ast.identifier.clone().valid()?);
                        }
                        TypedFromClause::Error(e) => return Err(e.clone()),
                    }
                }
                main_pipeline = main_pipeline.command(from_builder);
            }
            _ => {
                main_pipeline = main_pipeline.command(cmd.ast.clone());
            }
        }
    }
    Ok(query_builder.main(main_pipeline.build()).build())
}
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::EnvironmentProvider,
        tree::{
            ast::identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
            builder::{eq, field, field_ref, query, QueryBuilderWithMain},
        },
        type_check_with_provider,
        types::{struct_type::Struct, INT},
    };
    use std::sync::Arc;

    /// Provider that knows exactly one table, `events`, with INT columns `a` and `b`.
    #[derive(Debug)]
    struct MockProvider;

    impl EnvironmentProvider for MockProvider {
        fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
            let events: Identifier = AstSimpleIdentifier::new("events").into();
            if name != &events {
                anyhow::bail!("Table not found: {}", name)
            }
            Ok(Struct::default().with_str("a", INT).with_str("b", INT))
        }

        fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
            Ok(Vec::new())
        }
    }

    /// Type-checks the built query against [`MockProvider`].
    fn typed_query(builder: QueryBuilderWithMain) -> TypedStatement {
        type_check_with_provider(builder.build(), Arc::new(MockProvider)).output
    }

    /// Runs [`nest_from_aliases`] on `statement` with a freshly built translation context.
    fn nest(statement: TypedStatement) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
        let registry = Arc::new(FunctionRegistry::default());
        let provider = Arc::new(MockProvider);
        let mut ctx = StatementTranslationContext::new(registry, provider);
        nest_from_aliases(Arc::new(statement), &mut ctx)
    }

    /// Rendered names of the statement's pipeline definitions, in order.
    fn cte_names(statement: &TypedStatement) -> Vec<String> {
        statement
            .pipeline_defs
            .iter()
            .map(|def| def.name.valid_ref().unwrap().to_string())
            .collect()
    }

    #[test]
    fn test_no_aliases_passthrough() -> Result<(), Arc<TranslationError>> {
        let statement = typed_query(query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .where_cmd(eq(field_ref("a"), 10)),
        ));
        assert!(!statement_has_aliases(&statement)?);
        Ok(())
    }

    #[test]
    fn test_single_alias_generates_cte() -> Result<(), Arc<TranslationError>> {
        let statement = typed_query(query().main(
            pipeline_builder()
                .from(|f| f.table_alias("x", "events"))
                .where_cmd(eq(field(field_ref("x"), "a"), 10)),
        ));
        assert!(statement_has_aliases(&statement)?);
        let transformed = nest(statement)?;
        assert_eq!(cte_names(&transformed), vec!["__alias_0"]);
        Ok(())
    }

    #[test]
    fn test_multiple_aliases_generate_multiple_ctes() -> Result<(), Arc<TranslationError>> {
        let statement = typed_query(query().main(
            pipeline_builder().from(|f| f.table_alias("x", "events").table_alias("y", "events")),
        ));
        let transformed = nest(statement)?;
        assert_eq!(cte_names(&transformed), vec!["__alias_0", "__alias_1"]);
        Ok(())
    }
}