hamelin_translation 0.9.7

Lowering and IR for Hamelin query language
Documentation
//! Pipeline pass: Align pipeline schema to APPEND target table schema.
//!
//! For pipelines with APPEND, inserts a SELECT command before APPEND that:
//! 1. Widens: target columns missing from the pipeline are filled with typed NULLs
//! 2. Narrows: pipeline columns that the target table does not have are dropped
//! 3. Reorders: columns are emitted in the target table's column order
//!
//! Example:
//! ```text
//! SET a = 1, b = 2 | APPEND my_table
//! -- my_table schema: {b: Int, c: Int, a: Int}
//! -- pipeline schema: {a: Int, b: Int}
//! ```
//! becomes:
//! ```text
//! SET a = 1, b = 2 | SELECT b, c = CAST(NULL AS Int), a | APPEND my_table
//! ```
//!
//! This pass runs only on the main pipeline (not CTEs) after all other normalizations.

use std::sync::Arc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::pipeline::Pipeline;
use hamelin_lib::tree::builder::{self, select_command};
use hamelin_lib::tree::typed_ast::clause::ResolvedEnvironment;
use hamelin_lib::tree::typed_ast::command::{TypedAppendCommand, TypedCommandKind};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
use hamelin_lib::types::struct_type::Struct;

use super::super::expand_struct::build_widening_expression;

/// Align the pipeline's schema to match the APPEND target table schema.
///
/// This pass inserts a SELECT before APPEND to ensure the pipeline output
/// matches the target table schema exactly (same columns, same order).
///
/// Contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, Arc<TranslationError>>`
pub fn align_append_schema(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    // Find APPEND command in pipeline
    let Some((append_cmd, append_idx, input_schema)) = find_append_command(&pipeline)? else {
        return Ok(pipeline);
    };

    // Get the target table schema
    let target_schema = match &append_cmd.table.resolved {
        ResolvedEnvironment::Resolved(env) => env.as_struct(),
        ResolvedEnvironment::Error(_) => {
            // Table not resolved, skip alignment
            return Ok(pipeline);
        }
    };

    let source_schema = input_schema.as_struct();

    // Check if alignment is needed
    if source_schema == target_schema {
        return Ok(pipeline);
    }

    // Build the new pipeline with SELECT inserted before APPEND
    let new_pipeline =
        build_aligned_pipeline(&pipeline, &source_schema, &target_schema, append_idx)?;

    // Re-typecheck
    Ok(Arc::new(TypedPipeline::from_ast_with_context(
        Arc::new(new_pipeline),
        ctx,
    )))
}

/// Find the APPEND command in the pipeline.
///
/// Returns the command, its index, and the input schema to the APPEND.
fn find_append_command(
    pipeline: &TypedPipeline,
) -> Result<Option<(&TypedAppendCommand, usize, &Arc<TypeEnvironment>)>, Arc<TranslationError>> {
    let checked = pipeline.valid_ref()?;

    // First APPEND wins; its input schema is the environment flowing into it.
    let first_append = checked
        .commands
        .iter()
        .enumerate()
        .find_map(|(position, command)| match &command.kind {
            TypedCommandKind::Append(append) => Some((append, position, &command.input_schema)),
            _ => None,
        });

    Ok(first_append)
}

/// Build a new pipeline with SELECT inserted before APPEND.
fn build_aligned_pipeline(
    pipeline: &TypedPipeline,
    source_schema: &Struct,
    target_schema: &Struct,
    append_idx: usize,
) -> Result<Pipeline, Arc<TranslationError>> {
    let checked = pipeline.valid_ref()?;

    // Copy every command's AST in order, splicing the alignment SELECT in just ahead of
    // the command at `append_idx`. If `append_idx` is past the end, nothing is spliced.
    let assembled = checked.commands.iter().enumerate().fold(
        builder::pipeline().at(pipeline.ast.span.clone()),
        |acc, (position, command)| {
            let acc = if position == append_idx {
                acc.command(build_alignment_select(source_schema, target_schema))
            } else {
                acc
            };
            acc.command(command.ast.clone())
        },
    );

    Ok(assembled.build())
}

/// Build a SELECT command that aligns source schema to target schema.
///
/// For each field in target schema (in order):
/// - If field exists in source: reference it
/// - If field is missing: insert typed NULL
/// - If field is nested struct with different shape: recursively widen
fn build_alignment_select(
    source_schema: &Struct,
    target_schema: &Struct,
) -> builder::SelectCommandBuilder {
    // One named field per target field, in target order, so the SELECT emits exactly the
    // target's columns in the target's order. Source columns absent from the target are
    // never referenced and are therefore dropped. `build_widening_expression` receives the
    // field name, the source field's type if the source has that field (`lookup` result),
    // and the target field's type. The tests below show it yielding `CAST(NULL AS ...)` for
    // a missing field and a cast to the target type for narrower struct / array fields.
    target_schema
        .iter()
        .fold(select_command(), |select, (field_name, field_type)| {
            let projected = build_widening_expression(
                field_name.name(),
                source_schema.lookup(field_name),
                field_type,
            );
            select.named_field(field_name.name(), projected)
        })
}

#[cfg(test)]
mod tests {
    // Unit tests for `align_append_schema`. Each case type-checks an input pipeline and an
    // expected pipeline against `TestProvider`, runs the pass on the typed input, and
    // compares the resulting AST with the typed expected AST.
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::EnvironmentProvider,
        tree::{
            ast::{
                identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
                pipeline::Pipeline,
            },
            builder::{
                append_command, array, cast, field_ref, null, pipeline, select_command,
                set_command, struct_literal,
            },
            options::TypeCheckOptions,
        },
        type_check_with_options,
        types::{array::Array, struct_type::Struct, Type, INT},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    /// Provider that supports multiple table schemas for testing different widening scenarios.
    ///
    /// Known tables are `my_table`, `nested_table` and `array_table`. Looking up any other
    /// name returns an error.
    #[derive(Debug)]
    struct TestProvider;

    impl EnvironmentProvider for TestProvider {
        fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
            let my_table: Identifier = AstSimpleIdentifier::new("my_table").into();
            let nested_table: Identifier = AstSimpleIdentifier::new("nested_table").into();
            let array_table: Identifier = AstSimpleIdentifier::new("array_table").into();

            if name == &my_table {
                // my_table: {b, c, a} - tests reordering + missing field
                Ok(Struct::default()
                    .with_str("b", INT)
                    .with_str("c", INT)
                    .with_str("a", INT))
            } else if name == &nested_table {
                // nested_table: {nested: {a, b}} - tests nested struct widening
                let nested: Type = Struct::default()
                    .with_str("a", INT)
                    .with_str("b", INT)
                    .into();
                Ok(Struct::default().with_str("nested", nested))
            } else if name == &array_table {
                // array_table: {items: Array<{a, b}>} - tests array-of-structs widening
                let elem: Type = Struct::default()
                    .with_str("a", INT)
                    .with_str("b", INT)
                    .into();
                Ok(Struct::default().with_str("items", Array::new(elem).into()))
            } else {
                anyhow::bail!("Table not found: {}", name)
            }
        }

        // None of these tests enumerate datasets, so report an empty list.
        fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
            Ok(vec![])
        }
    }

    /// Type-check `input` and `expected` with the same registry and `TestProvider`, run
    /// `align_append_schema` on the typed input, and assert that the resulting AST equals
    /// the AST of the type-checked `expected` pipeline.
    ///
    /// An error returned by the pass is propagated, which fails the calling test case.
    fn run_test(input: Pipeline, expected: Pipeline) -> Result<(), Arc<TranslationError>> {
        let provider = Arc::new(TestProvider);
        let registry = Arc::new(FunctionRegistry::default());
        // Fresh options per call; both pipelines share the same registry and provider.
        let tc_opts = || {
            TypeCheckOptions::builder()
                .registry(registry.clone())
                .provider(provider.clone())
                .build()
        };
        let input_typed = type_check_with_options(input, tc_opts()).output;
        let expected_typed = type_check_with_options(expected, tc_opts()).output;

        let mut ctx = StatementTranslationContext::new(registry, provider);

        let result = align_append_schema(Arc::new(input_typed), &mut ctx)?;

        assert_eq!(result.ast, expected_typed.ast);

        Ok(())
    }

    #[rstest]
    // Without an APPEND the pass returns the pipeline unchanged.
    #[case::no_append_passthrough(
        pipeline().command(set_command().named_field("a", 1)).build(),
        pipeline().command(set_command().named_field("a", 1)).build()
    )]
    // my_table is {b, c, a}: the SELECT reorders a/b and fills the missing c with a typed NULL.
    #[case::reorder_and_missing_field(
        pipeline()
            .command(set_command().named_field("a", 1).named_field("b", 2))
            .command(append_command("my_table"))
            .build(),
        pipeline()
            .command(set_command().named_field("a", 1).named_field("b", 2))
            .command(select_command()
                .named_field("b", field_ref("b"))
                .named_field("c", cast(null(), INT))
                .named_field("a", field_ref("a")))
            .command(append_command("my_table"))
            .build()
    )]
    // nested_table.nested is {a, b} while the pipeline's nested struct only has {a}.
    #[case::nested_struct_widening(
        pipeline()
            .command(set_command().named_field("nested", struct_literal().field("a", 1)))
            .command(append_command("nested_table"))
            .build(),
        {
            // The normalizer now generates a cast to the target struct type
            // CastKind::StructExpansion handles adding null fields
            let target_struct: Type = Struct::default()
                .with_str("a", INT)
                .with_str("b", INT)
                .into();
            pipeline()
                .command(set_command().named_field("nested", struct_literal().field("a", 1)))
                .command(select_command()
                    .named_field("nested", cast(field_ref("nested"), target_struct)))
                .command(append_command("nested_table"))
                .build()
        }
    )]
    // array_table.items is Array<{a, b}> while the pipeline's array elements only have {a}.
    #[case::array_of_structs_widening(
        pipeline()
            .command(set_command().named_field("items", array().element(struct_literal().field("a", 1))))
            .command(append_command("array_table"))
            .build(),
        {
            // The normalizer now generates a cast to the target array type
            // CastKind::ArrayElementCast(StructExpansion) handles widening array elements
            let target_elem: Type = Struct::default()
                .with_str("a", INT)
                .with_str("b", INT)
                .into();
            let target_array: Type = Array::new(target_elem).into();
            pipeline()
                .command(set_command().named_field("items", array().element(struct_literal().field("a", 1))))
                .command(select_command()
                    .named_field("items", cast(field_ref("items"), target_array)))
                .command(append_command("array_table"))
                .build()
        }
    )]
    /// Runs each `#[case]` above through `run_test`.
    fn test_align_append_schema(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
    ) -> Result<(), Arc<TranslationError>> {
        run_test(input, expected)
    }
}