// hamelin_translation 0.8.0
//
// Lowering and IR for the Hamelin query language.
// Documentation
//! Pipeline pass: UNNEST lowering.
//!
//! Lowers UNNEST and ROWS commands to simpler primitives. When the expression is a
//! field reference (simple or dotted, e.g. `x` or `data.items`), that source column
//! is dropped to avoid duplicating data. Literal / non-field expressions (typical for
//! `ROWS [...]`) have no column to drop, so no `DROP` is emitted.
//!
//! **Case 1: UNNEST of STRUCT column `x`** (where x: {a, b})
//! ```text
//! UNNEST x
//! ```
//! becomes:
//! ```text
//! SET __unnest_0 = x | DROP x | SET a = __unnest_0.a, b = __unnest_0.b | DROP __unnest_0
//! ```
//!
//! **Case 2: UNNEST of ARRAY<STRUCT> column `arr`** (where arr: ARRAY<{a, b}>)
//! ```text
//! UNNEST arr
//! ```
//! becomes:
//! ```text
//! SET __unnest_0 = arr | EXPLODE __unnest_0 = __unnest_0 | DROP arr | SET a = __unnest_0.a, b = __unnest_0.b | DROP __unnest_0
//! ```

use std::sync::Arc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
    ast::{command::Command, identifier::SimpleIdentifier},
    builder::{self, drop_command, explode_command, field, field_ref, set_command},
    typed_ast::{
        command::{TypedCommand, TypedCommandKind, TypedUnnestCommand},
        context::StatementTranslationContext,
        pipeline::TypedPipeline,
    },
};
use hamelin_lib::types::struct_type::Struct;
use hamelin_lib::types::Type;

use crate::unique::UniqueNameGenerator;

/// Lower UNNEST commands in a pipeline.
///
/// Contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, ...>`
pub fn lower_unnest(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    // Check if any UNNEST command exists
    if !pipeline
        .valid_ref()?
        .commands
        .iter()
        .any(|cmd| matches!(&cmd.kind, TypedCommandKind::Unnest(_)))
    {
        return Ok(pipeline);
    }

    let valid = pipeline.valid_ref()?;

    // Shared name generator for temp columns
    let mut name_gen = UniqueNameGenerator::new("__unnest");

    // Transform commands
    let mut pipe_builder = builder::pipeline();
    for cmd in &valid.commands {
        for c in lower_command(cmd, &mut name_gen, ctx)? {
            pipe_builder = pipe_builder.command(c);
        }
    }

    let new_ast = pipe_builder.build().at(pipeline.ast.span);

    // Re-typecheck
    Ok(Arc::new(TypedPipeline::from_ast_with_context(
        Arc::new(new_ast),
        ctx,
    )))
}

/// Lower one command of the pipeline.
///
/// UNNEST commands are expanded by [`lower_unnest_command`]; every other command is
/// returned as a one-element vector holding its original AST.
fn lower_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
    ctx: &mut StatementTranslationContext,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    match &cmd.kind {
        TypedCommandKind::Unnest(unnest_cmd) => {
            lower_unnest_command(unnest_cmd, cmd, name_gen, ctx)
        }
        // Anything that is not UNNEST passes through unchanged.
        _ => Ok(vec![Arc::clone(&cmd.ast)]),
    }
}

/// Expand one UNNEST command into its SET / EXPLODE / DROP replacement sequence.
///
/// The emitted commands, in order:
/// 1. `SET <temp> = <expr>`: copy the unnested expression into a temp column.
/// 2. `EXPLODE <temp> = <temp>`: only for ARRAY<STRUCT> inputs.
/// 3. `DROP <source>`: only when the expression yields a source identifier to drop.
/// 4. `SET f1 = <temp>.f1, ...`: one assignment per field of the struct type.
/// 5. `DROP <temp>`.
///
/// Every emitted command carries the span of the original UNNEST command.
///
/// # Errors
///
/// Fails when the expression type is neither STRUCT nor ARRAY<STRUCT>.
fn lower_unnest_command(
    unnest_cmd: &TypedUnnestCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
    ctx: &mut StatementTranslationContext,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    let source = &unnest_cmd.expression;
    let span = cmd.ast.span;

    // Validate the type first, so no temp name is consumed for an invalid UNNEST.
    let (needs_explode, fields) =
        extract_struct_type(source.resolved_type.as_ref(), ctx, source.ast.as_ref())?;

    // Temp column that carries the value through the intermediate steps.
    let temp = name_gen.next(&cmd.input_schema);

    let mut lowered: Vec<Arc<Command>> = Vec::with_capacity(5);

    // SET __unnest_N = expr
    lowered.push(Arc::new(
        set_command()
            .named_field(temp.clone(), source.ast.as_ref().clone())
            .at(span)
            .build(),
    ));

    // ARRAY<STRUCT> input: EXPLODE the temp column so it holds a STRUCT.
    if needs_explode {
        lowered.push(Arc::new(
            explode_command()
                .named_field(temp.clone(), field_ref(temp.as_str()))
                .at(span)
                .build(),
        ));
    }

    // The source column must go BEFORE the struct fields are extracted: if the struct
    // has a field named like the source column, a later DROP would remove the freshly
    // created field rather than the original column.
    if let Some(source_id) = source.unnest_source_drop_identifier() {
        lowered.push(Arc::new(
            drop_command().field(source_id).at(span).build(),
        ));
    }

    // SET a = __unnest_N.a, b = __unnest_N.b, ...
    let extract_fields = fields
        .iter()
        .fold(set_command().at(span), |acc, (field_name, _field_type)| {
            let target: SimpleIdentifier = field_name.clone().into();
            let value = field(field_ref(temp.as_str()), field_name.name());
            acc.named_field(target, value)
        })
        .build();
    lowered.push(Arc::new(extract_fields));

    // The temp column is always removed at the end.
    lowered.push(Arc::new(drop_command().field(temp).at(span).build()));

    Ok(lowered)
}

/// Determine the struct type that an UNNEST expression exposes.
///
/// Returns `(is_array, struct_type)`: `is_array` is `false` for a plain STRUCT and
/// `true` for an ARRAY whose element type is a STRUCT.
///
/// # Errors
///
/// Any other type produces an error built through `ctx`, located at `expr_ast`.
fn extract_struct_type(
    expr_type: &Type,
    ctx: &mut StatementTranslationContext,
    expr_ast: &hamelin_lib::tree::ast::expression::Expression,
) -> Result<(bool, Struct), Arc<TranslationError>> {
    // Supported shapes return early; everything else falls through to one emit site.
    let message = match expr_type {
        Type::Struct(s) => return Ok((false, s.clone())),
        Type::Array(arr) => match arr.element_type.as_ref() {
            Type::Struct(s) => return Ok((true, s.clone())),
            other => format!(
                "UNNEST requires STRUCT or ARRAY<STRUCT>, found ARRAY<{}>",
                other
            ),
        },
        other => format!(
            "UNNEST requires STRUCT or ARRAY<STRUCT>, found {}",
            other
        ),
    };

    Err(ctx.error(message).at(expr_ast).emit())
}

// Unit tests for `lower_unnest`.
//
// Each rstest case supplies three values: the input pipeline, the pipeline AST the
// pass is expected to produce, and the expected output schema of the result.
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::type_check;
    use hamelin_lib::{
        tree::{
            ast::{identifier::CompoundIdentifier, pipeline::Pipeline},
            builder::{
                array, drop_command, field, field_ref, pipeline, select_command, set_command,
                struct_literal, unnest_command,
            },
        },
        types::{struct_type::Struct, INT, STRING},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    #[rstest]
    // Case 1: No UNNEST commands - passes through unchanged
    #[case::no_unnest_passthrough(
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        Struct::default().with_str("a", INT).with_str("b", INT)
    )]
    // Case 2: UNNEST of STRUCT column → SET (temp) + DROP (source) + SET (extract) + DROP (temp)
    #[case::unnest_struct_column(
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1).field("b", "hello"))
                .build())
            .command(unnest_command(field_ref("x")).build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1).field("b", "hello"))
                .build())
            .command(set_command()
                .named_field("__unnest_0", field_ref("x"))
                .build())
            .command(drop_command().field("x").build())
            .command(set_command()
                .named_field("a", field(field_ref("__unnest_0"), "a"))
                .named_field("b", field(field_ref("__unnest_0"), "b"))
                .build())
            .command(drop_command().field("__unnest_0").build())
            .build(),
        Struct::default()
            .with_str("a", INT)
            .with_str("b", STRING)
    )]
    // Case 3: UNNEST of ARRAY<STRUCT> column → SET (temp) + EXPLODE (temp) + DROP (source) + SET (extract) + DROP (temp)
    #[case::unnest_array_struct_column(
        pipeline()
            .command(select_command()
                .named_field("arr", array().element(struct_literal().field("a", 1).field("b", 2)))
                .build())
            .command(unnest_command(field_ref("arr")).build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("arr", array().element(struct_literal().field("a", 1).field("b", 2)))
                .build())
            .command(set_command()
                .named_field("__unnest_0", field_ref("arr"))
                .build())
            .command(explode_command().named_field("__unnest_0", field_ref("__unnest_0")).build())
            .command(drop_command().field("arr").build())
            .command(set_command()
                .named_field("a", field(field_ref("__unnest_0"), "a"))
                .named_field("b", field(field_ref("__unnest_0"), "b"))
                .build())
            .command(drop_command().field("__unnest_0").build())
            .build(),
        Struct::default()
            .with_str("a", INT)
            .with_str("b", INT)
    )]
    // Case 4: UNNEST of complex expression → no source column to drop (not a field reference)
    // The pre-existing column `x` is therefore still present in the expected output schema.
    #[case::unnest_complex_expression(
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1).field("b", 2))
                .build())
            .command(unnest_command(struct_literal().field("a", 10).field("b", 20)).build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1).field("b", 2))
                .build())
            .command(set_command()
                .named_field("__unnest_0", struct_literal().field("a", 10).field("b", 20))
                .build())
            .command(set_command()
                .named_field("a", field(field_ref("__unnest_0"), "a"))
                .named_field("b", field(field_ref("__unnest_0"), "b"))
                .build())
            .command(drop_command().field("__unnest_0").build())
            .build(),
        Struct::default()
            .with_str("a", INT)
            .with_str("b", INT)
            .with_str("x", Struct::default().with_str("a", INT).with_str("b", INT).into())
    )]
    // Case 5: Multiple UNNEST commands - both source columns are dropped
    // Each UNNEST gets its own temp column (`__unnest_0`, then `__unnest_1`).
    // NOTE(review): the expected schema lists `b` before `a` — presumably because the
    // most recently SET column is placed first; confirm against the schema ordering rules.
    #[case::multiple_unnest(
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1))
                .named_field("y", struct_literal().field("b", 2))
                .build())
            .command(unnest_command(field_ref("x")).build())
            .command(unnest_command(field_ref("y")).build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("x", struct_literal().field("a", 1))
                .named_field("y", struct_literal().field("b", 2))
                .build())
            .command(set_command().named_field("__unnest_0", field_ref("x")).build())
            .command(drop_command().field("x").build())
            .command(set_command().named_field("a", field(field_ref("__unnest_0"), "a")).build())
            .command(drop_command().field("__unnest_0").build())
            .command(set_command().named_field("__unnest_1", field_ref("y")).build())
            .command(drop_command().field("y").build())
            .command(set_command().named_field("b", field(field_ref("__unnest_1"), "b")).build())
            .command(drop_command().field("__unnest_1").build())
            .build(),
        Struct::default()
            .with_str("b", INT)
            .with_str("a", INT)
    )]
    // Case 6: UNNEST compound field reference → DROP uses full path (e.g. data.items)
    #[case::unnest_compound_field_reference(
        pipeline()
            .command(select_command()
                .named_field("data", struct_literal()
                    .field("items", array().element(struct_literal().field("a", 1).field("b", 2))))
                .build())
            .command(unnest_command(field(field_ref("data"), "items")).build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("data", struct_literal()
                    .field("items", array().element(struct_literal().field("a", 1).field("b", 2))))
                .build())
            .command(set_command()
                .named_field("__unnest_0", field(field_ref("data"), "items"))
                .build())
            .command(explode_command().named_field("__unnest_0", field_ref("__unnest_0")).build())
            .command(drop_command().field(CompoundIdentifier::from_two("data", "items")).build())
            .command(set_command()
                .named_field("a", field(field_ref("__unnest_0"), "a"))
                .named_field("b", field(field_ref("__unnest_0"), "b"))
                .build())
            .command(drop_command().field("__unnest_0").build())
            .build(),
        Struct::default()
            .with_str("a", INT)
            .with_str("b", INT)
    )]
    // Runs `lower_unnest` on the type-checked `input` and checks two things: the
    // resulting AST equals the AST of the type-checked `expected` pipeline, and the
    // resulting environment, viewed as a struct, equals `expected_output_schema`.
    fn test_lower_unnest(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) {
        // Both pipelines go through the type checker; only `.output` is used here.
        let input_typed = type_check(input).output;
        let expected_typed = type_check(expected).output;

        // A fresh default context is used for the pass; lowering must succeed.
        let mut ctx = StatementTranslationContext::default();
        let result = lower_unnest(Arc::new(input_typed), &mut ctx).unwrap();

        // Compare ASTs
        assert_eq!(result.ast, expected_typed.ast);

        // Verify output schema
        let result_schema = result.environment().as_struct().clone();
        assert_eq!(result_schema, expected_output_schema);
    }
}