hamelin_translation 0.3.10

Lowering and IR for Hamelin query language
Documentation
//! Pipeline pass: Align pipeline schema to APPEND target table schema.
//!
//! For pipelines with APPEND whose input schema differs from the target table's schema,
//! inserts a SELECT command immediately before the APPEND that:
//! 1. Widens missing columns with typed NULLs
//! 2. Narrows extra columns by excluding them
//! 3. Reorders columns to match target table schema
//!
//! Example:
//! ```text
//! LET a = 1, b = 2 | APPEND my_table
//! -- my_table schema: {b: Int, c: Int, a: Int}
//! -- pipeline schema: {a: Int, b: Int}
//! ```
//! becomes:
//! ```text
//! LET a = 1, b = 2 | SELECT b, c = CAST(NULL AS Int), a | APPEND my_table
//! ```
//!
//! This pass runs only on the main pipeline (not CTEs) after all other normalizations.

use std::rc::Rc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::pipeline::Pipeline;
use hamelin_lib::tree::builder::{self, select_command};
use hamelin_lib::tree::typed_ast::clause::ResolvedEnvironment;
use hamelin_lib::tree::typed_ast::command::{TypedAppendCommand, TypedCommandKind};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
use hamelin_lib::types::struct_type::Struct;

use super::super::expand_struct::build_widening_expression;

/// Align the pipeline's schema to match the APPEND target table schema.
///
/// This pass inserts a SELECT before APPEND to ensure the pipeline output
/// matches the target table schema exactly (same columns, same order).
///
/// Contract: `Rc<TypedPipeline> -> Result<Rc<TypedPipeline>, Rc<TranslationError>>`
pub fn align_append_schema(
    pipeline: Rc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Rc<TypedPipeline>, Rc<TranslationError>> {
    // Find APPEND command in pipeline
    let Some((append_cmd, append_idx, input_schema)) = find_append_command(&pipeline)? else {
        return Ok(pipeline);
    };

    // Get the target table schema
    let target_schema = match &append_cmd.table.resolved {
        ResolvedEnvironment::Resolved(env) => env.flatten(),
        ResolvedEnvironment::Error(_) => {
            // Table not resolved, skip alignment
            return Ok(pipeline);
        }
    };

    let source_schema = input_schema.flatten();

    // Check if alignment is needed
    if source_schema == target_schema {
        return Ok(pipeline);
    }

    // Build the new pipeline with SELECT inserted before APPEND
    let new_pipeline =
        build_aligned_pipeline(&pipeline, &source_schema, &target_schema, append_idx)?;

    // Re-typecheck
    Ok(Rc::new(TypedPipeline::from_ast_with_context(
        Rc::new(new_pipeline),
        ctx,
    )))
}

/// Locate the first APPEND command in `pipeline`.
///
/// Returns `Ok(None)` when the pipeline contains no APPEND. Otherwise it returns
/// `Ok(Some((append, index, input_schema)))`, where:
/// - `append` is the typed APPEND node,
/// - `index` is its position in the pipeline's command list, and
/// - `input_schema` is the schema flowing into it.
///
/// Any error from `valid_ref()` is propagated.
fn find_append_command(
    pipeline: &TypedPipeline,
) -> Result<Option<(&TypedAppendCommand, usize, &Rc<TypeEnvironment>)>, Rc<TranslationError>> {
    let valid = pipeline.valid_ref()?;

    for (position, command) in valid.commands.iter().enumerate() {
        // Only APPEND commands are of interest; skip everything else.
        let TypedCommandKind::Append(append) = &command.kind else {
            continue;
        };
        return Ok(Some((append, position, &command.input_schema)));
    }

    Ok(None)
}

/// Rebuild the pipeline AST with an aligning SELECT spliced in at `append_idx`.
///
/// Every existing command's AST is copied over in its original order. When the command at
/// `append_idx` is reached, the SELECT from `build_alignment_select` is emitted first.
/// The new pipeline keeps the original pipeline's span.
///
/// # Errors
///
/// Propagates the error from `valid_ref()` when the pipeline is not valid.
fn build_aligned_pipeline(
    pipeline: &TypedPipeline,
    source_schema: &Struct,
    target_schema: &Struct,
    append_idx: usize,
) -> Result<Pipeline, Rc<TranslationError>> {
    let valid = pipeline.valid_ref()?;
    let seed = builder::pipeline().at(pipeline.ast.span.clone());

    let aligned = valid
        .commands
        .iter()
        .enumerate()
        .fold(seed, |acc, (position, command)| {
            // The alignment SELECT must sit immediately in front of the APPEND.
            let acc = if position == append_idx {
                acc.command(build_alignment_select(source_schema, target_schema))
            } else {
                acc
            };
            acc.command(command.ast.clone())
        });

    Ok(aligned.build())
}

/// Build the SELECT that reshapes rows of `source_schema` into `target_schema`.
///
/// One named field is emitted per target column, in target order. As a result:
/// - the output column order follows the target, and
/// - columns that exist only in the source are dropped.
///
/// Each field's expression comes from `build_widening_expression`. It is given the column
/// name, the source column's type (`None` when the source lacks that column), and the
/// target column's type. NOTE(review): typed-NULL filling and any nested-struct widening are
/// delegated entirely to that helper — confirm its behavior there.
fn build_alignment_select(
    source_schema: &Struct,
    target_schema: &Struct,
) -> builder::SelectCommandBuilder {
    target_schema
        .fields
        .iter()
        .fold(select_command(), |select, (column, target_type)| {
            let column_name = column.name.as_str();
            let source_type = source_schema.fields.get(column);
            let value = build_widening_expression(column_name, source_type, target_type);
            select.named_field(column_name, value)
        })
}

#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        provider::EnvironmentProvider,
        sql::{
            expression::identifier::Identifier as SqlIdentifier,
            query::TableReference as SqlTableReference,
        },
        tree::{
            ast::{IntoTyped, TypeCheckExecutor},
            builder::{append_command, let_command, pipeline},
        },
        types::{struct_type::Struct, INT},
    };
    use std::sync::Arc;

    /// Provider that knows exactly one table, `my_table`, with INT columns `b`, `c`, `a`
    /// (in that order). Every other table lookup fails.
    #[derive(Debug)]
    struct MockProvider;

    impl EnvironmentProvider for MockProvider {
        fn reflect_columns(&self, table: SqlTableReference) -> anyhow::Result<Struct> {
            let known_table: SqlIdentifier = "my_table".parse().unwrap();
            if table.name != known_table {
                anyhow::bail!("Table not found: {}", table.name)
            }

            // Ordered differently from the test pipelines and carrying an extra column `c`.
            let mut schema = Struct::default();
            for column in ["b", "c", "a"] {
                schema.fields.insert(column.parse().unwrap(), INT);
            }
            Ok(schema)
        }

        fn reflect_datasets(&self) -> anyhow::Result<Vec<SqlIdentifier>> {
            Ok(vec![])
        }
    }

    #[test]
    fn test_align_append_schema_inserts_select() -> Result<(), Rc<TranslationError>> {
        // LET a = 1, b = 2 | APPEND my_table, where my_table is {b, c, a}:
        // the pass should splice `SELECT b, c = NULL, a` in front of the APPEND.
        let query = pipeline()
            .command(let_command().named_field("a", 1).named_field("b", 2))
            .command(append_command("my_table"))
            .build();
        let typed = query.typed_with().with_provider(Arc::new(MockProvider)).typed();

        let registry = Arc::new(hamelin_lib::func::registry::FunctionRegistry::default());
        let mut ctx = StatementTranslationContext::new(registry, Arc::new(MockProvider));

        let aligned = align_append_schema(Rc::new(typed), &mut ctx)?;
        let valid = aligned.valid_ref()?;

        // Expect LET, then the inserted SELECT, then the original APPEND.
        assert_eq!(valid.commands.len(), 3);
        assert!(matches!(&valid.commands[1].kind, TypedCommandKind::Select(_)));
        assert!(matches!(&valid.commands[2].kind, TypedCommandKind::Append(_)));

        Ok(())
    }

    #[test]
    fn test_no_append_no_change() -> Result<(), Rc<TranslationError>> {
        // Without an APPEND there is nothing to align, so the AST must come back untouched.
        let query = pipeline().command(let_command().named_field("a", 1)).build();
        let typed = query.typed_with().with_provider(Arc::new(MockProvider)).typed();

        let registry = Arc::new(hamelin_lib::func::registry::FunctionRegistry::default());
        let mut ctx = StatementTranslationContext::new(registry, Arc::new(MockProvider));

        let result = align_append_schema(Rc::new(typed.clone()), &mut ctx)?;
        assert_eq!(result.ast, typed.ast);

        Ok(())
    }
}