hamelin_translation 0.3.10

Lowering and IR for the Hamelin query language
Documentation
//! Pass 2: UNION schema expansion pass.
//!
//! For UNION commands with differing schemas, generates CTEs that widen each
//! source to the merged output schema.
//!
//! Example:
//! ```text
//! UNION events, logs
//! -- events: {timestamp: Timestamp, event_type: String}
//! -- logs: {timestamp: Timestamp, message: String}
//! -- merged output: {timestamp: Timestamp, event_type: String?, message: String?}
//! ```
//! becomes:
//! ```text
//! WITH __union_0 = (FROM events | SELECT timestamp, event_type, message = CAST(NULL AS String))
//! WITH __union_1 = (FROM logs | SELECT timestamp, event_type = CAST(NULL AS String), message)
//! UNION __union_0, __union_1
//! ```
//!
//! This pass operates at the statement level, taking a TypedStatement and
//! returning a TypedStatement. It uses builders to construct AST, then
//! runs type-checking to derive types.

use std::rc::Rc;

use hamelin_lib::tree::builder::pipeline as pipeline_builder;
use hamelin_lib::{
    err::TranslationError,
    tree::{
        ast::{identifier::Identifier, pipeline::Pipeline, query::Query},
        builder::{self, query, select_command},
        typed_ast::{
            clause::TypedFromClause,
            command::{TypedCommandKind, TypedUnionCommand},
            context::StatementTranslationContext,
            environment::TypeEnvironment,
            pipeline::TypedPipeline,
            query::TypedStatement,
        },
    },
    types::struct_type::Struct,
};

use super::super::{expand_struct::build_widening_expression, unique::UniqueNameGenerator};

/// Expand UNION clauses with differing schemas into CTEs.
///
/// This is Pass 2 of normalization. It transforms UNION commands where
/// sources have different schemas into CTEs that widen each source to the
/// merged output schema.
///
/// When no UNION in the statement needs widening, the input `Rc` is handed back
/// unchanged. Otherwise the statement is rewritten into a new AST query, which is
/// then re-type-checked with `TypedStatement::from_ast_with_context` using `ctx`.
///
/// # Errors
///
/// Propagates errors from invalid pipelines, invalid CTE names or table
/// identifiers, and errored UNION clauses encountered while inspecting or
/// rewriting the statement.
pub fn expand_union_schemas(
    statement: Rc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Rc<TypedStatement>, Rc<TranslationError>> {
    // Fast path: nothing to widen, so return the very same statement.
    let needs_work = statement_needs_expansion(&statement)?;
    if !needs_work {
        return Ok(statement);
    }

    // Generated CTEs are named `__union_<n>`.
    let mut names = UniqueNameGenerator::new("__union");
    let rewritten_ast = Rc::new(transform_statement(&statement, &mut names)?);

    // Re-run type checking so the rewritten statement carries derived types.
    let retyped = TypedStatement::from_ast_with_context(rewritten_ast, ctx);
    Ok(Rc::new(retyped))
}

/// Check if the statement has any UNION commands that need schema expansion.
///
/// Every pipeline yielded by `statement.iter()` is inspected (there is no early
/// exit on the first `true`), and the first pipeline that reports an error aborts
/// the check with that error.
fn statement_needs_expansion(statement: &TypedStatement) -> Result<bool, Rc<TranslationError>> {
    let mut needs_expansion = false;
    for pipeline in statement.iter() {
        needs_expansion |= pipeline_needs_expansion(pipeline)?;
    }
    Ok(needs_expansion)
}

/// Check if a pipeline has any UNION commands that need schema expansion.
///
/// Only `Union` commands are considered; each is judged against its own
/// output schema by `union_needs_expansion`.
///
/// # Errors
///
/// Propagates the error returned by `valid_ref` when the pipeline is not valid.
fn pipeline_needs_expansion(pipeline: &TypedPipeline) -> Result<bool, Rc<TranslationError>> {
    let valid = pipeline.valid_ref()?;
    let found = valid.commands.iter().any(|command| {
        matches!(
            &command.kind,
            TypedCommandKind::Union(union_cmd)
                if union_needs_expansion(union_cmd, &command.output_schema)
        )
    });
    Ok(found)
}

/// Check if a UNION command needs schema widening.
///
/// Returns `true` when the UNION has at least two clauses and at least one
/// clause's flattened environment is not equal to the flattened output schema.
/// A UNION with zero or one clause never needs widening.
fn union_needs_expansion(cmd: &TypedUnionCommand, output_schema: &TypeEnvironment) -> bool {
    // Nothing to reconcile with fewer than two sources.
    if cmd.clauses.len() < 2 {
        return false;
    }

    let merged = output_schema.flatten();
    !cmd
        .clauses
        .iter()
        .all(|clause| clause.environment().flatten() == merged)
}

/// Transform a full statement, processing all pipelines and returning a new Query.
///
/// Each existing WITH clause body is rewritten with `transform_pipeline` and merged
/// back under its validated name via `merge_as_cte`. The main pipeline is then
/// rewritten the same way and merged via `merge_as_main`. The same `name_gen` is
/// threaded through every pipeline, in order.
///
/// # Errors
///
/// Propagates any error from rewriting a pipeline or validating a CTE name.
fn transform_statement(
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Rc<TranslationError>> {
    let mut combined = query();

    for cte in &statement.with_clauses {
        // Rewrite the CTE body, then validate the CTE's name.
        let rewritten_body = transform_pipeline(&cte.pipeline, name_gen)?;
        let cte_name = cte.name.clone().valid()?;
        combined = combined.merge_as_cte(rewritten_body, cte_name);
    }

    let rewritten_main = transform_pipeline(&statement.pipeline, name_gen)?;
    Ok(combined.merge_as_main(rewritten_main))
}

/// Transform a pipeline, generating CTEs for UNION commands that need schema widening.
///
/// Every command is copied through unchanged except a UNION for which
/// `union_needs_expansion` is true. For such a UNION, each table-reference clause
/// gets a generated CTE (named by `name_gen`) that projects the table onto the
/// UNION's merged output schema, and the UNION is rebuilt over those CTE names.
///
/// If the UNION contains an aliased clause, it is kept as written, with no CTEs
/// generated for it. Widening only the plain references would drop the aliased
/// sources from the UNION.
///
/// Returns a Query containing any generated CTEs and the transformed pipeline as main.
///
/// # Errors
///
/// Returns the error if the pipeline is not valid, if a UNION being widened has an
/// errored clause, or if a referenced table identifier is not valid.
fn transform_pipeline(
    pipeline: &TypedPipeline,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Rc<TranslationError>> {
    let mut query_builder = query();
    // Named distinctly from the imported `pipeline_builder` function it is built from.
    let mut main_pipeline = pipeline_builder().at(pipeline.ast.span.clone());

    for cmd in &pipeline.valid_ref()?.commands {
        let union_cmd = match &cmd.kind {
            TypedCommandKind::Union(union_cmd)
                if union_needs_expansion(union_cmd, &cmd.output_schema) =>
            {
                union_cmd
            }
            _ => {
                main_pipeline = main_pipeline.command(cmd.ast.clone());
                continue;
            }
        };

        // Classify every clause before generating anything. This ensures no orphan
        // CTEs are emitted when we decide not to rewrite this UNION.
        let mut references = Vec::with_capacity(union_cmd.clauses.len());
        let mut has_alias = false;
        for clause in &union_cmd.clauses {
            match clause {
                TypedFromClause::Reference(ref_clause) => references.push((clause, ref_clause)),
                TypedFromClause::Alias(_) => has_alias = true,
                TypedFromClause::Error(e) => return Err(e.clone()),
            }
        }

        if has_alias {
            // Skipping aliased clauses would silently remove those sources (and
            // their rows) from the UNION. Keep the original command so every source
            // is still unioned.
            // NOTE(review): aliased clauses are not widened by this pass — confirm
            // whether an earlier pass is expected to rewrite them.
            main_pipeline = main_pipeline.command(cmd.ast.clone());
            continue;
        }

        let output_struct = cmd.output_schema.flatten();
        let mut union_builder = builder::union_command().at(cmd.ast.span.clone());

        for (clause, ref_clause) in references {
            let table_name = ref_clause.ast.identifier.clone().valid()?;
            let clause_struct = clause.environment().flatten();

            let cte_name = name_gen.next();
            let cte_pipeline = build_widening_pipeline(table_name, &clause_struct, &output_struct);

            query_builder = query_builder.with(cte_name.clone(), cte_pipeline);
            union_builder = union_builder.table_reference(cte_name);
        }

        main_pipeline = main_pipeline.command(union_builder);
    }

    Ok(query_builder.main(main_pipeline.build()).build())
}

/// Build a pipeline that widens a single UNION clause to match the target schema.
///
/// Pipeline: FROM <table> | SELECT <field1>, <field2>, <field3> = CAST(NULL AS <type>), ...
///
/// Fields are projected in the target schema's iteration order. Each projected value
/// comes from `build_widening_expression`, which receives the field name, the
/// source's type for that field (`None` when the source lacks it), and the target
/// type. For nested struct fields that differ between source and target, this builds
/// struct literals that recursively widen the nested fields.
fn build_widening_pipeline(
    table_name: Identifier,
    source_struct: &Struct,
    target_struct: &Struct,
) -> Pipeline {
    let projection = target_struct.fields.iter().fold(
        select_command(),
        |select, (target_name, target_type)| {
            let column = target_name.name.as_str();
            let expr = build_widening_expression(
                column,
                source_struct.fields.get(target_name),
                target_type,
            );
            select.named_field(column, expr)
        },
    );

    builder::pipeline()
        .from(|src| src.table_reference(table_name))
        .command(projection)
        .build()
}

#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::EnvironmentProvider,
        sql::{
            expression::identifier::Identifier as SqlIdentifier,
            query::TableReference as SqlTableReference,
        },
        tree::{
            ast::{IntoTyped, TypeCheckExecutor},
            builder::{
                cast, column_ref, field, query, select_command, struct_literal, HasMain,
                NullLiteralBuilder, QueryBuilder,
            },
        },
        types::{struct_type::Struct, Type, INT},
    };
    use std::sync::Arc;

    // Mock provider for tests. It knows two tables whose schemas share `a` but
    // otherwise differ, so a UNION over both needs widening:
    //   events: {a: INT, b: INT}
    //   logs:   {a: INT, c: INT}
    // Any other table name is reported as "Table not found".
    #[derive(Debug)]
    struct MockProvider;

    impl EnvironmentProvider for MockProvider {
        fn reflect_columns(&self, table: SqlTableReference) -> anyhow::Result<Struct> {
            let mut fields = Struct::default();
            let events: SqlIdentifier = "events".parse().unwrap();
            let logs: SqlIdentifier = "logs".parse().unwrap();

            if table.name == events {
                fields.fields.insert("a".parse().unwrap(), INT);
                fields.fields.insert("b".parse().unwrap(), INT);
                Ok(fields)
            } else if table.name == logs {
                fields.fields.insert("a".parse().unwrap(), INT);
                fields.fields.insert("c".parse().unwrap(), INT);
                Ok(fields)
            } else {
                anyhow::bail!("Table not found: {}", table.name)
            }
        }

        // Reports no datasets.
        fn reflect_datasets(&self) -> anyhow::Result<Vec<SqlIdentifier>> {
            Ok(vec![])
        }
    }

    // Build the query and type-check it against `MockProvider`.
    fn typed_query(builder: QueryBuilder<HasMain>) -> TypedStatement {
        builder
            .build()
            .typed_with()
            .with_provider(Arc::new(MockProvider))
            .typed()
    }

    #[test]
    fn test_single_table_no_expansion() -> Result<(), Rc<TranslationError>> {
        // UNION events - single input, no expansion needed
        let q = query().main(pipeline_builder().union(|u| u.table_reference("events")));

        let statement = typed_query(q);

        // Should not need expansion
        assert!(!statement_needs_expansion(&statement)?);
        Ok(())
    }

    #[test]
    fn test_identical_schemas_no_expansion() -> Result<(), Rc<TranslationError>> {
        // UNION events, events - same schema, no expansion needed
        let q = query().main(
            pipeline_builder().union(|u| u.table_reference("events").table_reference("events")),
        );

        let statement = typed_query(q);

        // Should not need expansion (both have same schema)
        assert!(!statement_needs_expansion(&statement)?);
        Ok(())
    }

    #[test]
    fn test_different_schemas_needs_expansion() -> Result<(), Rc<TranslationError>> {
        // UNION events, logs - different schemas, needs expansion
        // events: {a, b}, logs: {a, c}
        let q = query().main(
            pipeline_builder().union(|u| u.table_reference("events").table_reference("logs")),
        );

        let statement = typed_query(q);

        // Should need expansion
        assert!(statement_needs_expansion(&statement)?);

        // Transform it
        let registry = Arc::new(FunctionRegistry::default());
        let provider = Arc::new(MockProvider);
        let mut ctx = StatementTranslationContext::new(registry, provider);
        let transformed = expand_union_schemas(Rc::new(statement), &mut ctx)?;

        // Should now have 2 CTEs (one for each table)
        assert_eq!(transformed.with_clauses.len(), 2);

        // CTE names should be __union_0 and __union_1, in UNION clause order
        let cte_name_0 = transformed.with_clauses[0].name.valid_ref().unwrap();
        let cte_name_1 = transformed.with_clauses[1].name.valid_ref().unwrap();
        assert_eq!(cte_name_0.to_string(), "__union_0");
        assert_eq!(cte_name_1.to_string(), "__union_1");
        Ok(())
    }

    #[test]
    fn test_nested_struct_schema_widening() -> Result<(), Rc<TranslationError>> {
        // Provider where both tables expose one struct column `nested`:
        //   events.nested = {a: INT}
        //   logs.nested   = {a: INT, b: INT}
        // so the merged `nested` must be widened for `events`.
        #[derive(Debug)]
        struct NestedProvider;

        impl EnvironmentProvider for NestedProvider {
            fn reflect_columns(&self, table: SqlTableReference) -> anyhow::Result<Struct> {
                let mut fields = Struct::default();
                let events: SqlIdentifier = "events".parse().unwrap();
                let logs: SqlIdentifier = "logs".parse().unwrap();

                let nested_events: Type = Struct::default().with_str("a", INT).into();
                let nested_logs: Type = Struct::default()
                    .with_str("a", INT)
                    .with_str("b", INT)
                    .into();

                if table.name == events {
                    fields
                        .fields
                        .insert("nested".parse().unwrap(), nested_events);
                    Ok(fields)
                } else if table.name == logs {
                    fields.fields.insert("nested".parse().unwrap(), nested_logs);
                    Ok(fields)
                } else {
                    anyhow::bail!("Table not found: {}", table.name)
                }
            }

            fn reflect_datasets(&self) -> anyhow::Result<Vec<SqlIdentifier>> {
                Ok(vec![])
            }
        }

        let q = query().main(
            pipeline_builder().union(|u| u.table_reference("events").table_reference("logs")),
        );

        let statement = q
            .build()
            .typed_with()
            .with_provider(Arc::new(NestedProvider))
            .typed();

        let registry = Arc::new(FunctionRegistry::default());
        let provider = Arc::new(NestedProvider);
        let mut ctx = StatementTranslationContext::new(registry, provider);
        let transformed = expand_union_schemas(Rc::new(statement), &mut ctx)?;

        assert_eq!(transformed.with_clauses.len(), 2);

        // Expected CTE body for `events`: `nested` is rebuilt as a struct literal
        // that keeps `a` from the source column and fills the missing `b` with
        // CAST(NULL AS INT).
        let expected_events = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .command(
                    select_command()
                        .named_field(
                            "nested",
                            struct_literal()
                                .field("a", field(column_ref("nested"), "a"))
                                .field("b", cast(NullLiteralBuilder::new(), INT)),
                        )
                        .build(),
                ),
        );

        let expected_typed = expected_events
            .build()
            .typed_with()
            .with_provider(Arc::new(NestedProvider))
            .typed();

        // Only the first generated CTE (for `events`) is compared structurally.
        assert_eq!(
            transformed.with_clauses[0].pipeline.ast,
            expected_typed.pipeline.ast
        );
        Ok(())
    }
}