hamelin_translation 0.4.2

Lowering and IR for Hamelin query language
Documentation
//! Pipeline pass: UNNEST lowering.
//!
//! Lowers UNNEST commands to simpler primitives (LET, EXPLODE, DROP), routing the
//! unnested value through a temp column so the original column is left unchanged.
//!
//! **Case 1: UNNEST of STRUCT column `x`** (where x: {a, b})
//! ```text
//! UNNEST x
//! ```
//! becomes:
//! ```text
//! LET __unnest_0 = x | LET a = __unnest_0.a, b = __unnest_0.b | DROP __unnest_0
//! ```
//! The original column `x` is preserved unchanged.
//!
//! **Case 2: UNNEST of ARRAY<STRUCT> column `arr`** (where arr: ARRAY<{a, b}>)
//! ```text
//! UNNEST arr
//! ```
//! becomes:
//! ```text
//! LET __unnest_0 = arr | EXPLODE __unnest_0 = __unnest_0 | LET a = __unnest_0.a, b = __unnest_0.b | DROP __unnest_0
//! ```
//! The original column `arr` is preserved as ARRAY<STRUCT>.

use std::sync::Arc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
    ast::{command::Command, identifier::SimpleIdentifier},
    builder::{self, column_ref, drop_command, explode_command, field, let_command},
    typed_ast::{
        command::{TypedCommand, TypedCommandKind, TypedUnnestCommand},
        context::StatementTranslationContext,
        pipeline::TypedPipeline,
    },
};
use hamelin_lib::types::struct_type::Struct;
use hamelin_lib::types::Type;

use super::super::unique::UniqueNameGenerator;

/// Lower every UNNEST command in `pipeline` into LET / EXPLODE / DROP primitives.
///
/// A pipeline with no UNNEST is handed back as the very same `Arc`. Otherwise each command
/// is lowered in order, the results are assembled into a new pipeline AST that keeps the
/// original pipeline's span, and that AST is type-checked again against `ctx`.
///
/// Contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, ...>`
///
/// # Errors
///
/// Propagates the error from `valid_ref()` as well as any error raised while lowering an
/// individual UNNEST (for example, an expression that is neither STRUCT nor ARRAY<STRUCT>).
pub fn lower_unnest(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;

    let has_unnest = valid
        .commands
        .iter()
        .any(|command| matches!(&command.kind, TypedCommandKind::Unnest(_)));
    if !has_unnest {
        // Nothing to rewrite: skip rebuilding and re-typechecking altogether.
        return Ok(pipeline);
    }

    // A single generator for the whole pipeline, so temp columns never share a name.
    let mut name_gen = UniqueNameGenerator::new("__unnest");

    let mut pipe_builder = builder::pipeline();
    for command in &valid.commands {
        pipe_builder = lower_command(command, &mut name_gen, ctx)?
            .into_iter()
            .fold(pipe_builder, |acc, lowered| acc.command(lowered));
    }

    let span = pipeline.ast.span;
    let lowered_ast = Arc::new(pipe_builder.build().at(span));

    // The rewritten AST has not been type-checked yet; do it now so the returned
    // pipeline carries types for the newly introduced commands.
    let retyped = TypedPipeline::from_ast_with_context(lowered_ast, ctx);
    Ok(Arc::new(retyped))
}

/// Lower one command.
///
/// An UNNEST is expanded into its replacement command sequence. Every other command is
/// forwarded untouched, as a one-element vector holding its original AST node.
fn lower_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
    ctx: &mut StatementTranslationContext,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    match &cmd.kind {
        TypedCommandKind::Unnest(unnest_cmd) => {
            lower_unnest_command(unnest_cmd, cmd, name_gen, ctx)
        }
        _ => Ok(vec![Arc::clone(&cmd.ast)]),
    }
}

/// Lower an UNNEST command to LET/DROP (and EXPLODE for arrays).
///
/// The emitted sequence is:
///
/// 1. `LET <tmp> = <expr>` copies the unnested value into a temp column, so a source
///    column used as the expression is left unchanged.
/// 2. `EXPLODE <tmp> = <tmp>` is emitted only for `ARRAY<STRUCT>` and converts the
///    temp column from `ARRAY<STRUCT>` to `STRUCT`.
/// 3. `LET f1 = <tmp>.f1, f2 = <tmp>.f2, ...` has one binding per struct field, in the
///    struct's field order. It is omitted when the struct has no fields, since there is
///    nothing to bind.
/// 4. `DROP <tmp>` removes the temp column again.
///
/// The temp name comes from `name_gen` (given the command's input schema) and is
/// additionally required not to equal any of the struct's field names. Otherwise step 3
/// would bind that field onto the temp column and step 4 would drop it.
///
/// # Errors
///
/// Returns the error emitted through `ctx` when the expression's type is neither `STRUCT`
/// nor `ARRAY<STRUCT>`.
fn lower_unnest_command(
    unnest_cmd: &TypedUnnestCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
    ctx: &mut StatementTranslationContext,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    let expr = &unnest_cmd.expression;
    let expr_type = expr.resolved_type.as_ref();

    // Get the struct type (either directly or from array element)
    let (is_array, struct_type) = extract_struct_type(expr_type, ctx, expr.ast.as_ref())?;

    let span = cmd.ast.span;

    // Always use a temp column to preserve the original column unchanged.
    // This matches legacy behavior where UNNEST doesn't modify the source column.
    //
    // Keep drawing names until one does not collide with a struct field: a colliding
    // field would overwrite the temp in the extraction LET and then be removed by the
    // final DROP. This relies on `next` yielding a new name on every call (successive
    // UNNESTs get `__unnest_0`, `__unnest_1`, ... even after earlier temps are dropped),
    // so the loop ends after at most `fields.len() + 1` draws.
    let temp_name = loop {
        let candidate = name_gen.next(&cmd.input_schema);
        let collides = struct_type
            .fields
            .iter()
            .any(|(field_name, _)| field_name.name.as_str() == candidate.as_str());
        if !collides {
            break candidate;
        }
    };

    // At most: LET temp, EXPLODE, LET fields, DROP.
    let mut result: Vec<Arc<Command>> = Vec::with_capacity(4);

    // LET __unnest_0 = expr
    let let_cmd = let_command()
        .named_field(temp_name.clone(), expr.ast.as_ref().clone())
        .at(span)
        .build();
    result.push(Arc::new(let_cmd));

    // If array, add EXPLODE to convert ARRAY<STRUCT> to STRUCT
    if is_array {
        let explode = explode_command()
            .named_field(temp_name.clone(), column_ref(temp_name.as_str()))
            .at(span)
            .build();
        result.push(Arc::new(explode));
    }

    // Build LET to extract struct fields: LET a = __unnest_0.a, b = __unnest_0.b, ...
    let mut let_builder = let_command().at(span);
    let mut has_fields = false;
    for (field_name, _field_type) in struct_type.fields.iter() {
        let field_expr = field(column_ref(temp_name.as_str()), field_name.name.as_str());
        let field_id: SimpleIdentifier = field_name.clone().into();
        let_builder = let_builder.named_field(field_id, field_expr);
        has_fields = true;
    }
    // An empty struct has nothing to bind; don't emit a LET with zero assignments.
    if has_fields {
        result.push(Arc::new(let_builder.build()));
    }

    // Always DROP the temp column
    let drop = drop_command().field(temp_name).at(span).build();
    result.push(Arc::new(drop));

    Ok(result)
}

/// Resolve the STRUCT that an UNNEST expands.
///
/// A `STRUCT` type yields `(false, struct)`. An `ARRAY` whose element type is a `STRUCT`
/// yields `(true, element_struct)`. Any other type emits an error through `ctx`,
/// positioned at `expr_ast`, and returns that error.
fn extract_struct_type(
    expr_type: &Type,
    ctx: &mut StatementTranslationContext,
    expr_ast: &hamelin_lib::tree::ast::expression::Expression,
) -> Result<(bool, Struct), Arc<TranslationError>> {
    // Peel off at most one ARRAY layer and remember whether we did.
    let (is_array, inner) = match expr_type {
        Type::Array(arr) => (true, arr.element_type.as_ref()),
        other => (false, other),
    };

    if let Type::Struct(s) = inner {
        return Ok((is_array, s.clone()));
    }

    // For arrays, the reported type re-wraps the offending element type in ARRAY<...>.
    let found = if is_array {
        format!("ARRAY<{}>", inner)
    } else {
        inner.to_string()
    };
    Err(ctx
        .error(format!(
            "UNNEST requires STRUCT or ARRAY<STRUCT>, found {}",
            found
        ))
        .at(expr_ast)
        .emit())
}

#[cfg(test)]
mod tests {
    // Table-driven tests for `lower_unnest`. Each case supplies:
    //   1. an input pipeline,
    //   2. the pipeline the lowering is expected to produce, spelled out by hand with
    //      LET / EXPLODE / DROP builders, and
    //   3. the flattened output schema of the lowered pipeline.
    use super::*;
    use hamelin_lib::{
        tree::{
            ast::{pipeline::Pipeline, IntoTyped, TypeCheckExecutor},
            builder::{
                array, column_ref, drop_command, field, let_command, pipeline, select_command,
                struct_literal, unnest_command,
            },
        },
        types::{array::Array, struct_type::Struct, INT, STRING},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    #[rstest]
    // Case 1: No UNNEST commands - passes through unchanged
    #[case::no_unnest_passthrough(
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        Struct::default().with_str("a", INT).with_str("b", INT)
    )]
    // Case 2: UNNEST of STRUCT column → LET (temp) + LET (extract) + DROP (temp), original preserved
    #[case::unnest_struct_column(
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1).field("b", "hello"))
                .build())
            .command(unnest_command(column_ref("x")).build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1).field("b", "hello"))
                .build())
            .command(let_command()
                .named_field("__unnest_0", column_ref("x"))
                .build())
            .command(let_command()
                .named_field("a", field(column_ref("__unnest_0"), "a"))
                .named_field("b", field(column_ref("__unnest_0"), "b"))
                .build())
            .command(drop_command().field("__unnest_0").build())
            .build(),
        Struct::default()
            .with_str("a", INT)
            .with_str("b", STRING)
            .with_str("x", Struct::default().with_str("a", INT).with_str("b", STRING).into())
    )]
    // Case 3: UNNEST of ARRAY<STRUCT> column → LET (temp) + EXPLODE (temp) + LET (extract) + DROP (temp)
    // Original array column preserved unchanged
    #[case::unnest_array_struct_column(
        pipeline()
            .command(select_command()
                .named_field("arr", array().element(struct_literal().field("a", 1).field("b", 2)))
                .build())
            .command(unnest_command(column_ref("arr")).build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("arr", array().element(struct_literal().field("a", 1).field("b", 2)))
                .build())
            .command(let_command()
                .named_field("__unnest_0", column_ref("arr"))
                .build())
            .command(explode_command().named_field("__unnest_0", column_ref("__unnest_0")).build())
            .command(let_command()
                .named_field("a", field(column_ref("__unnest_0"), "a"))
                .named_field("b", field(column_ref("__unnest_0"), "b"))
                .build())
            .command(drop_command().field("__unnest_0").build())
            .build(),
        Struct::default()
            .with_str("a", INT)
            .with_str("b", INT)
            .with_str("arr", Array::new(Struct::default().with_str("a", INT).with_str("b", INT).into()).into())
    )]
    // Case 4: UNNEST of complex expression → LET (assign to temp) + LET (extract) + DROP (temp)
    #[case::unnest_complex_expression(
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1).field("b", 2))
                .build())
            .command(unnest_command(struct_literal().field("a", 10).field("b", 20)).build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1).field("b", 2))
                .build())
            .command(let_command()
                .named_field("__unnest_0", struct_literal().field("a", 10).field("b", 20))
                .build())
            .command(let_command()
                .named_field("a", field(column_ref("__unnest_0"), "a"))
                .named_field("b", field(column_ref("__unnest_0"), "b"))
                .build())
            .command(drop_command().field("__unnest_0").build())
            .build(),
        Struct::default()
            .with_str("a", INT)
            .with_str("b", INT)
            .with_str("x", Struct::default().with_str("a", INT).with_str("b", INT).into())
    )]
    // Case 5: Multiple UNNEST commands - both original columns are preserved
    // The second UNNEST gets `__unnest_1` even though `__unnest_0` was already dropped,
    // i.e. the shared name generator keeps counting across commands.
    // NOTE(review): the expected schema lists `b` before `a`; this presumably mirrors the
    // column order `flatten()` produces and matters for `Struct` equality. TODO confirm.
    #[case::multiple_unnest(
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1))
                .named_field("y", struct_literal().field("b", 2))
                .build())
            .command(unnest_command(column_ref("x")).build())
            .command(unnest_command(column_ref("y")).build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1))
                .named_field("y", struct_literal().field("b", 2))
                .build())
            .command(let_command().named_field("__unnest_0", column_ref("x")).build())
            .command(let_command().named_field("a", field(column_ref("__unnest_0"), "a")).build())
            .command(drop_command().field("__unnest_0").build())
            .command(let_command().named_field("__unnest_1", column_ref("y")).build())
            .command(let_command().named_field("b", field(column_ref("__unnest_1"), "b")).build())
            .command(drop_command().field("__unnest_1").build())
            .build(),
        Struct::default()
            .with_str("b", INT)
            .with_str("a", INT)
            .with_str("x", Struct::default().with_str("a", INT).into())
            .with_str("y", Struct::default().with_str("b", INT).into())
    )]
    fn test_lower_unnest(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) {
        // Build typed versions of both pipelines. The expected one is only used for its AST.
        let input_typed = input.typed_with().typed();
        let expected_typed = expected.typed_with().typed();

        // Run the pass under test with a fresh translation context.
        let mut ctx = StatementTranslationContext::default();
        let result = lower_unnest(Arc::new(input_typed), &mut ctx).unwrap();

        // Compare ASTs
        assert_eq!(result.ast, expected_typed.ast);

        // Verify output schema
        let result_schema = result.environment().flatten();
        assert_eq!(result_schema, expected_output_schema);
    }
}