hamelin_translation 0.7.2

Lowering and IR for Hamelin query language
Documentation
//! Pass 1.5: FROM-to-UNION conversion.
//!
//! Converts multi-source FROM commands into UNION commands so downstream
//! passes only need to reason about UNION for schema widening.

use std::sync::Arc;

#[cfg(test)]
use hamelin_lib::type_check_with_options;
use hamelin_lib::{
    err::TranslationError,
    tree::{
        ast::query::Query,
        builder::{self, query},
        typed_ast::{
            clause::TypedFromClause,
            command::{TypedCommandKind, TypedFromCommand},
            context::StatementTranslationContext,
            pipeline::TypedPipeline,
            query::TypedStatement,
        },
    },
};

/// Convert multi-source FROM commands into UNION commands.
///
/// If no pipeline in `statement` contains a FROM with more than one clause, the
/// input `Arc` is returned unchanged and nothing is rebuilt. Otherwise every
/// pipeline (each WITH clause and the main pipeline) is rewritten into a fresh
/// AST query. That query is turned back into a `TypedStatement` through
/// `TypedStatement::from_ast_with_context` using `ctx`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - a pipeline is not valid;
/// - a WITH clause name fails validation;
/// - a FROM being rewritten contains an error clause;
/// - a FROM being rewritten has an invalid table identifier.
pub fn from_to_union(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
    if statement_needs_conversion(&statement)? {
        let rewritten = Arc::new(transform_statement(&statement)?);
        Ok(Arc::new(TypedStatement::from_ast_with_context(rewritten, ctx)))
    } else {
        // Nothing to rewrite: hand the original statement back without rebuilding it.
        Ok(statement)
    }
}

/// Report whether any pipeline of `statement` holds a FROM that must become a UNION.
///
/// There is deliberately no early exit once a match is found. Every pipeline is
/// still inspected, so an invalid pipeline anywhere in the statement surfaces as
/// an `Err`; the first such error encountered is returned.
fn statement_needs_conversion(statement: &TypedStatement) -> Result<bool, Arc<TranslationError>> {
    let mut found = false;
    for pipeline in statement.iter() {
        if pipeline_needs_conversion(pipeline)? {
            found = true;
        }
    }
    Ok(found)
}

/// Report whether `pipeline` has at least one FROM command that needs converting.
///
/// A FROM needs converting when [`from_needs_conversion`] says so. Commands of
/// any other kind never need converting.
///
/// # Errors
///
/// Propagates the error from `valid_ref` when the pipeline is not valid.
fn pipeline_needs_conversion(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;
    let has_multi_source_from = valid.commands.iter().any(|command| match &command.kind {
        TypedCommandKind::From(from_cmd) => from_needs_conversion(from_cmd),
        _ => false,
    });
    Ok(has_multi_source_from)
}

/// A FROM is rewritten into a UNION only when it lists two or more clauses.
///
/// Every clause counts toward that threshold, including alias and error clauses,
/// not just plain table references.
fn from_needs_conversion(cmd: &TypedFromCommand) -> bool {
    let clause_count = cmd.clauses.len();
    clause_count >= 2
}

/// Rebuild `statement` as an AST `Query` in which every pipeline has been passed
/// through [`transform_pipeline`].
///
/// Each WITH clause is re-attached as a CTE under its validated name, in
/// iteration order. The transformed main pipeline is then merged in as the
/// query's main body.
///
/// # Errors
///
/// Returns the first error raised while transforming a pipeline or validating a
/// WITH clause name.
fn transform_statement(statement: &TypedStatement) -> Result<Query, Arc<TranslationError>> {
    let with_ctes = statement.with_clauses.iter().try_fold(
        query(),
        |acc, with_clause| -> Result<_, Arc<TranslationError>> {
            // Transform the body before validating the name so that a pipeline
            // error takes precedence when both would fail.
            let cte_query = transform_pipeline(&with_clause.pipeline)?;
            let cte_name = with_clause.name.clone().valid()?;
            Ok(acc.merge_as_cte(cte_query, cte_name))
        },
    )?;

    let main = transform_pipeline(&statement.pipeline)?;
    Ok(with_ctes.merge_as_main(main))
}

fn transform_pipeline(pipeline: &TypedPipeline) -> Result<Query, Arc<TranslationError>> {
    let query_builder = query();
    let mut pipeline_builder = builder::pipeline().at(pipeline.ast.span.clone());

    for cmd in &pipeline.valid_ref()?.commands {
        match &cmd.kind {
            TypedCommandKind::From(from_cmd) if from_needs_conversion(from_cmd) => {
                let mut union_builder = builder::union_command().at(cmd.ast.span.clone());
                for clause in &from_cmd.clauses {
                    match clause {
                        TypedFromClause::Reference(ref_clause) => {
                            let table_name = ref_clause.ast.identifier.clone().valid()?;
                            union_builder = union_builder.table_reference(table_name);
                        }
                        TypedFromClause::Alias(_) => {
                            continue;
                        }
                        TypedFromClause::Error(e) => return Err(e.clone()),
                    }
                }
                pipeline_builder = pipeline_builder.command(union_builder);
            }
            _ => pipeline_builder = pipeline_builder.command(cmd.ast.clone()),
        }
    }

    Ok(query_builder.main(pipeline_builder.build()).build())
}

#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::{EnvironmentProvider, NoOpProvider},
        tree::{
            ast::identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
            builder::QueryBuilderWithMain,
            builder::{pipeline as pipeline_builder, query},
            options::TypeCheckOptions,
            typed_ast::query::TypedStatement,
        },
        types::{struct_type::Struct, INT, STRING},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    /// Test provider that exposes two fixed datasets.
    ///
    /// `events` has columns `a: INT, b: STRING`, and `logs` has `a: INT, c: INT`.
    /// Every other name is delegated to `NoOpProvider`.
    #[derive(Debug)]
    struct MockProvider;

    impl EnvironmentProvider for MockProvider {
        fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
            let events: Identifier = AstSimpleIdentifier::new("events").into();
            let logs: Identifier = AstSimpleIdentifier::new("logs").into();
            match name {
                n if n == &events => {
                    Ok(Struct::default().with_str("a", INT).with_str("b", STRING))
                }
                n if n == &logs => Ok(Struct::default().with_str("a", INT).with_str("c", INT)),
                other => NoOpProvider::default().reflect_columns(other),
            }
        }

        fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
            Ok(vec![])
        }
    }

    /// Type-check the built query against `MockProvider` and a default registry.
    fn typed_query(builder: QueryBuilderWithMain) -> TypedStatement {
        let ast = builder.build();
        let options = TypeCheckOptions::builder()
            .registry(Arc::new(FunctionRegistry::default()))
            .provider(Arc::new(MockProvider))
            .build();
        type_check_with_options(ast, options).output
    }

    /// Run the pass on each input and compare the resulting AST with the expected query.
    ///
    /// The cases cover three situations:
    /// - a single-source FROM is left alone;
    /// - a two-source FROM becomes a UNION;
    /// - an existing UNION passes through untouched.
    #[rstest]
    #[case::single_from_passthrough(
        query().main(pipeline_builder().from(|f| f.table_reference("events"))),
        query().main(pipeline_builder().from(|f| f.table_reference("events"))),
    )]
    #[case::multi_from_to_union(
        query().main(pipeline_builder().from(|f| f.table_reference("events").table_reference("logs"))),
        query().main(pipeline_builder().union(|u| u.table_reference("events").table_reference("logs"))),
    )]
    #[case::union_passthrough(
        query().main(pipeline_builder().union(|u| u.table_reference("events").table_reference("logs"))),
        query().main(pipeline_builder().union(|u| u.table_reference("events").table_reference("logs"))),
    )]
    fn test_from_to_union(
        #[case] input: QueryBuilderWithMain,
        #[case] expected: QueryBuilderWithMain,
    ) -> Result<(), Arc<TranslationError>> {
        let typed = typed_query(input);
        let mut ctx = StatementTranslationContext::new(
            Arc::new(FunctionRegistry::default()),
            Arc::new(MockProvider),
        );
        let actual = from_to_union(Arc::new(typed), &mut ctx)?;

        let wanted = expected.build();
        assert_eq!(actual.ast.as_ref(), &wanted);
        Ok(())
    }
}