// hamelin_translation 0.7.1
//
// Lowering and IR for Hamelin query language
// Documentation
//! Pipeline pass: EXPLODE normalization.
//!
//! Normalizes EXPLODE commands to the canonical `EXPLODE col = col` form required by IR.
//!
//! **Case 1: Compound identifier**
//! ```text
//! EXPLODE items.expanded = array_field
//! ```
//! becomes:
//! ```text
//! LET __explode_0 = array_field | EXPLODE __explode_0 = __explode_0 | LET items.expanded = __explode_0 | DROP __explode_0
//! ```
//!
//! **Case 2: Simple identifier with different expression**
//! ```text
//! EXPLODE x = some_array_expr
//! ```
//! becomes:
//! ```text
//! LET x = some_array_expr | EXPLODE x = x
//! ```
//!
//! **Case 3: Already canonical** (`EXPLODE col = col`)
//! ```text
//! EXPLODE x = x
//! ```
//! Passes through unchanged.

use std::sync::Arc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
    ast::{command::Command, identifier::Identifier, identifier::SimpleIdentifier},
    builder::{self, drop_command, explode_command, field_ref, let_command},
    typed_ast::{
        command::{TypedCommand, TypedCommandKind, TypedExplodeCommand},
        context::StatementTranslationContext,
        expression::TypedExpressionKind,
        pipeline::TypedPipeline,
    },
};

use crate::unique::UniqueNameGenerator;

/// Rewrite every EXPLODE in `pipeline` into the canonical `EXPLODE col = col` form.
///
/// Contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, ...>`
///
/// When no command needs rewriting the incoming `Arc` is handed back untouched.
/// Otherwise a new pipeline AST is assembled command by command (reusing the
/// original pipeline's span) and type-checked again through `ctx`.
///
/// # Errors
///
/// Propagates the error from `pipeline.valid_ref()` and any error returned while
/// normalizing an individual command.
pub fn normalize_explode(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    // Fast path: nothing to do unless at least one EXPLODE is non-canonical.
    let needs_rewrite = pipeline
        .valid_ref()?
        .commands
        .iter()
        .any(explode_needs_normalization);
    if !needs_rewrite {
        return Ok(pipeline);
    }

    let valid = pipeline.valid_ref()?;

    // One generator for the whole pipeline, shared by every EXPLODE command.
    let mut name_gen = UniqueNameGenerator::new("__explode");

    // Each original command contributes one or more replacement commands.
    let mut rewritten = builder::pipeline();
    for original in &valid.commands {
        let replacements = normalize_command(original, &mut name_gen)?;
        rewritten = replacements
            .into_iter()
            .fold(rewritten, |acc, replacement| acc.command(replacement));
    }

    let rebuilt_ast = Arc::new(rewritten.build().at(pipeline.ast.span));

    // The AST changed, so the typed pipeline has to be derived again.
    let retyped = TypedPipeline::from_ast_with_context(rebuilt_ast, ctx);
    Ok(Arc::new(retyped))
}

/// Predicate: does this command have to be rewritten by the pass?
///
/// Only EXPLODE commands can qualify, and only when they are NOT already in the
/// canonical `col = col` form. Every other command kind yields `false`.
fn explode_needs_normalization(cmd: &Arc<TypedCommand>) -> bool {
    match &cmd.kind {
        TypedCommandKind::Explode(explode_cmd) => !is_canonical_explode(explode_cmd),
        _ => false,
    }
}

/// Report whether an EXPLODE is already in canonical `col = col` form.
///
/// For multi-item EXPLODE (from lowering passes), every item must be canonical:
/// a simple identifier assigned from a field reference with the same name.
/// An identifier or field name that fails `valid_ref()` counts as non-canonical.
fn is_canonical_explode(explode_cmd: &TypedExplodeCommand) -> bool {
    for item in explode_cmd.items.iter() {
        // The assignment target has to be a simple (non-compound) identifier.
        let target = match item.assignment.identifier.valid_ref() {
            Ok(Identifier::Simple(name)) => name,
            _ => return false,
        };

        // The right-hand side has to be a plain field reference.
        let source = match &item.assignment.expression.kind {
            TypedExpressionKind::FieldReference(reference) => reference,
            _ => return false,
        };

        // Finally, both sides have to name the same column.
        match source.field_name.valid_ref() {
            Ok(referenced) if target.as_str() == referenced.as_str() => {}
            _ => return false,
        }
    }
    true
}

/// Produce the replacement command(s) for one pipeline command.
///
/// A non-canonical EXPLODE is expanded by `transform_explode`; anything else
/// (other command kinds, or an EXPLODE that is already canonical) is returned
/// as a single-element vector holding the original AST node.
fn normalize_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    match &cmd.kind {
        TypedCommandKind::Explode(explode_cmd) if !is_canonical_explode(explode_cmd) => {
            transform_explode(explode_cmd, cmd, name_gen)
        }
        // Pass-through: not an EXPLODE, or nothing to rewrite.
        _ => Ok(vec![cmd.ast.clone()]),
    }
}

/// Info about how to normalize a single explode item.
///
/// Every variant carries the simple column name that the item contributes to the
/// canonical `EXPLODE col = col` command.
enum ItemNormalization {
    /// Already canonical - just use `col = col`; no extra commands are emitted.
    Canonical(SimpleIdentifier),
    /// Simple identifier, non-canonical expr - needs `LET col = expr` before, then `col = col`
    SimpleNonCanonical { col_name: SimpleIdentifier },
    /// Compound identifier - needs temp name, then restore after
    /// (the post-LET and post-DROP commands are built inline, we just need temp_name for the EXPLODE)
    Compound { temp_name: SimpleIdentifier },
}

/// Transform an EXPLODE command to canonical form.
///
/// For multi-item EXPLODE, produces:
/// 1. LET commands for non-canonical expressions
/// 2. Single multi-item EXPLODE with all canonical `col = col` forms
/// 3. LET commands to restore compound identifiers
/// 4. DROP commands to clean up temp names
fn transform_explode(
    explode_cmd: &TypedExplodeCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    let mut pre_lets: Vec<Arc<Command>> = Vec::new();
    let mut post_lets: Vec<Arc<Command>> = Vec::new();
    let mut post_drops: Vec<Arc<Command>> = Vec::new();
    let mut normalizations: Vec<ItemNormalization> = Vec::new();

    // Analyze each item and determine what normalization is needed
    for item in &explode_cmd.items {
        let identifier = item.assignment.identifier.valid_ref()?;

        // Check if already canonical
        let is_canonical = if let Identifier::Simple(simple_id) = identifier {
            if let TypedExpressionKind::FieldReference(col_ref) = &item.assignment.expression.kind {
                if let Ok(col_name) = col_ref.field_name.valid_ref() {
                    simple_id.as_str() == col_name.as_str()
                } else {
                    false
                }
            } else {
                false
            }
        } else {
            false
        };

        if is_canonical {
            let Identifier::Simple(simple_id) = identifier else {
                unreachable!()
            };
            normalizations.push(ItemNormalization::Canonical(simple_id.clone()));
        } else {
            match identifier {
                Identifier::Simple(simple_id) => {
                    // Simple identifier but not canonical - need LET before
                    let col_name = simple_id.clone();
                    pre_lets.push(Arc::new(
                        let_command()
                            .named_field(
                                col_name.clone(),
                                item.assignment.expression.ast.as_ref().clone(),
                            )
                            .at(cmd.ast.span)
                            .build(),
                    ));
                    normalizations.push(ItemNormalization::SimpleNonCanonical { col_name });
                }
                Identifier::Compound(compound) => {
                    // Compound identifier - need temp name
                    let temp_name: SimpleIdentifier = name_gen.next(&cmd.input_schema);

                    // LET __temp = expr (before EXPLODE)
                    pre_lets.push(Arc::new(
                        let_command()
                            .named_field(
                                temp_name.clone(),
                                item.assignment.expression.ast.as_ref().clone(),
                            )
                            .at(cmd.ast.span)
                            .build(),
                    ));

                    // LET x.y = __temp (after EXPLODE)
                    let original: Identifier = compound.clone().into();
                    post_lets.push(Arc::new(
                        let_command()
                            .named_field(original.clone(), field_ref(temp_name.as_str()))
                            .at(cmd.ast.span)
                            .build(),
                    ));

                    // DROP __temp (cleanup)
                    post_drops.push(Arc::new(
                        drop_command()
                            .field(temp_name.clone())
                            .at(cmd.ast.span)
                            .build(),
                    ));

                    normalizations.push(ItemNormalization::Compound { temp_name });
                }
            }
        }
    }

    // Collect all column names for the canonical EXPLODE
    let col_names: Vec<SimpleIdentifier> = normalizations
        .iter()
        .map(|norm| match norm {
            ItemNormalization::Canonical(name) => name.clone(),
            ItemNormalization::SimpleNonCanonical { col_name } => col_name.clone(),
            ItemNormalization::Compound { temp_name } => temp_name.clone(),
        })
        .collect();

    // Build the canonical EXPLODE command with all items
    // We need at least one item (guaranteed by caller)
    let first_col = &col_names[0];
    let explode_builder =
        explode_command().named_field(first_col.clone(), field_ref(first_col.as_str()));

    // Add remaining items (if any)
    let explode_builder = col_names[1..].iter().fold(explode_builder, |builder, col| {
        builder.named_field(col.clone(), field_ref(col.as_str()))
    });

    let explode = explode_builder.at(cmd.ast.span).build();

    // Assemble the result: pre_lets + explode + post_lets + post_drops
    let mut result = pre_lets;
    result.push(Arc::new(explode));
    result.extend(post_lets);
    result.extend(post_drops);

    Ok(result)
}

// Table-driven tests for `normalize_explode`. Each rstest case supplies an input
// pipeline, the pipeline the pass is expected to produce, and the expected
// output schema of the normalized pipeline.
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::type_check;
    use hamelin_lib::{
        tree::{
            ast::pipeline::Pipeline,
            builder::{
                array, drop_command, explode_command, field_ref, let_command, pipeline,
                select_command,
            },
        },
        types::{array::Array, struct_type::Struct, INT},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    #[rstest]
    // Case 1: Already canonical (EXPLODE arr = arr) - passes through unchanged
    #[case::canonical_unchanged(
        pipeline()
            .command(select_command().named_field("arr", array().element(1).element(2)).build())
            .command(explode_command().named_field("arr", field_ref("arr")).build())
            .build(),
        pipeline()
            .command(select_command().named_field("arr", array().element(1).element(2)).build())
            .command(explode_command().named_field("arr", field_ref("arr")).build())
            .build(),
        Struct::default().with_str("arr", INT)
    )]
    // Case 2: Simple id, different expr (EXPLODE x = arr) -> LET x = arr | EXPLODE x = x
    #[case::simple_id_different_expr(
        pipeline()
            .command(select_command().named_field("arr", array().element(1).element(2)).build())
            .command(explode_command().named_field("x", field_ref("arr")).build())
            .build(),
        pipeline()
            .command(select_command().named_field("arr", array().element(1).element(2)).build())
            .command(let_command().named_field("x", field_ref("arr")).build())
            .command(explode_command().named_field("x", field_ref("x")).build())
            .build(),
        Struct::default().with_str("x", INT).with_str("arr", Array::new(INT).into())
    )]
    // Case 3: No EXPLODE commands - passes through unchanged
    #[case::no_explode_passthrough(
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        Struct::default().with_str("a", INT).with_str("b", INT)
    )]
    // Case 4: Multiple non-canonical EXPLODEs
    #[case::multiple_explodes(
        pipeline()
            .command(select_command()
                .named_field("arr1", array().element(1))
                .named_field("arr2", array().element(2))
                .build())
            .command(explode_command().named_field("x", field_ref("arr1")).build())
            .command(explode_command().named_field("y", field_ref("arr2")).build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("arr1", array().element(1))
                .named_field("arr2", array().element(2))
                .build())
            .command(let_command().named_field("x", field_ref("arr1")).build())
            .command(explode_command().named_field("x", field_ref("x")).build())
            .command(let_command().named_field("y", field_ref("arr2")).build())
            .command(explode_command().named_field("y", field_ref("y")).build())
            .build(),
        Struct::default()
            .with_str("y", INT)
            .with_str("x", INT)
            .with_str("arr1", Array::new(INT).into())
            .with_str("arr2", Array::new(INT).into())
    )]
    // Case 5: Compound identifier (EXPLODE result.item = arr) -> LET + EXPLODE + LET + DROP
    #[case::compound_id(
        pipeline()
            .command(select_command().named_field("arr", array().element(1)).build())
            .command(explode_command()
                .named_field(
                    hamelin_lib::tree::ast::identifier::CompoundIdentifier::new("result".into(), "item".into(), vec![]),
                    field_ref("arr")
                )
                .build())
            .build(),
        pipeline()
            .command(select_command().named_field("arr", array().element(1)).build())
            .command(let_command().named_field("__explode_0", field_ref("arr")).build())
            .command(explode_command().named_field("__explode_0", field_ref("__explode_0")).build())
            .command(let_command()
                .named_field(
                    hamelin_lib::tree::ast::identifier::CompoundIdentifier::new("result".into(), "item".into(), vec![]),
                    field_ref("__explode_0")
                )
                .build())
            .command(drop_command().field("__explode_0").build())
            .build(),
        Struct::default()
            .with_str("result", Struct::default().with_str("item", INT).into())
            .with_str("arr", Array::new(INT).into())
    )]
    // Shared test body: type-checks both pipelines, runs the pass on the input,
    // then compares the resulting AST and output schema with the expectations.
    fn test_normalize_explode(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) {
        // Both sides go through `type_check` so the ASTs being compared were
        // produced the same way.
        let input_typed = type_check(input).output;
        let expected_typed = type_check(expected).output;

        let mut ctx = StatementTranslationContext::default();
        let result = normalize_explode(Arc::new(input_typed), &mut ctx).unwrap();

        // Compare ASTs
        assert_eq!(result.ast, expected_typed.ast);

        // Verify output schema
        let result_schema = result.environment().as_struct().clone();
        assert_eq!(result_schema, expected_output_schema);
    }
}