hamelin_translation 0.9.7

Lowering and IR for Hamelin query language
Documentation
//! Pipeline pass: NEST lowering.
//!
//! Lowers NEST commands to SELECT with one compound-identifier assignment
//! per input field:
//!
//! ```text
//! FROM events | NEST user
//! ```
//! where events has fields `{a, b, c}` becomes:
//! ```text
//! FROM events | SELECT user.a = a, user.b = b, user.c = c
//! ```
//! The resulting schema nests the fields under `user`, i.e. `user: {a, b, c}`.
//!
//! For compound identifiers:
//! ```text
//! FROM events | NEST user.address
//! ```
//! becomes:
//! ```text
//! FROM events | SELECT user.address.a = a, user.address.b = b, user.address.c = c
//! ```

use std::sync::Arc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
    ast::command::Command,
    builder::{self, field_ref, select_command},
    typed_ast::{
        command::{TypedCommand, TypedCommandKind, TypedNestCommand},
        context::StatementTranslationContext,
        pipeline::TypedPipeline,
    },
};

/// Rewrite every NEST command in `pipeline` as an equivalent SELECT.
///
/// A pipeline without any NEST command is handed back untouched (the very same
/// `Arc`). Otherwise each command is passed through [`lower_command`], the
/// results are assembled into a new pipeline AST that carries the original
/// pipeline's span, and that AST is type-checked again using `ctx`.
///
/// Contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, ...>`
///
/// # Errors
///
/// Propagates the error from `valid_ref()` on the pipeline, as well as any
/// error produced while lowering an individual command.
pub fn lower_nest(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    // Fast path: with nothing to lower, skip the rebuild and the re-typecheck.
    let contains_nest = pipeline
        .valid_ref()?
        .commands
        .iter()
        .any(|candidate| matches!(&candidate.kind, TypedCommandKind::Nest(_)));
    if !contains_nest {
        return Ok(pipeline);
    }

    // Thread the pipeline builder through the commands, stopping at the first
    // command that fails to lower.
    let rebuilt = pipeline
        .valid_ref()?
        .commands
        .iter()
        .try_fold(builder::pipeline(), |acc, cmd| {
            lower_command(cmd, ctx).map(|lowered| acc.command(lowered))
        })?;

    let lowered_ast = Arc::new(rebuilt.build().at(pipeline.ast.span));

    // The rewritten AST has to be type-checked again to obtain a typed pipeline.
    let retyped = TypedPipeline::from_ast_with_context(lowered_ast, ctx);
    Ok(Arc::new(retyped))
}

/// Produce the untyped AST command that replaces `cmd` in the lowered pipeline.
///
/// NEST commands are delegated to [`lower_nest_command`]; every other command
/// kind is returned as a clone of its original AST node.
fn lower_command(
    cmd: &Arc<TypedCommand>,
    ctx: &mut StatementTranslationContext,
) -> Result<Command, Arc<TranslationError>> {
    match &cmd.kind {
        TypedCommandKind::Nest(nest_cmd) => lower_nest_command(nest_cmd, cmd, ctx),
        // Anything that is not a NEST is carried over verbatim.
        _ => Ok(cmd.ast.as_ref().clone()),
    }
}

/// Turn one NEST command into a SELECT made of compound-identifier assignments.
///
/// For a NEST target `user_info` over fields `a` and `b`, the result is
/// `SELECT user_info.a = a, user_info.b = b` rather than
/// `SELECT user_info = {a: a, b: b}`.
///
/// Struct literals are avoided because they get unpacked during type-checking
/// anyway.
///
/// The SELECT is positioned at the span of the original command's AST node.
/// `_ctx` is accepted but not used here.
///
/// # Errors
///
/// Returns the error from `valid_ref()` if the NEST target identifier is not
/// valid (not expected after type-checking).
fn lower_nest_command(
    nest_cmd: &TypedNestCommand,
    cmd: &TypedCommand,
    _ctx: &mut StatementTranslationContext,
) -> Result<Command, Arc<TranslationError>> {
    use hamelin_lib::tree::ast::identifier::SimpleIdentifier;

    // The NEST target acts as the prefix of every assignment below.
    let prefix = nest_cmd.identifier.valid_ref()?;

    // Add one `prefix.field = field` assignment per field of the nested type.
    let assignments = nest_cmd.nested_type.iter().fold(
        select_command().at(cmd.ast.span),
        |acc, (field_name, _field_type)| {
            // `prefix + field` yields the compound identifier via the Add trait.
            let leaf: SimpleIdentifier = field_name.clone().into();
            let target = prefix.clone() + leaf.into();
            acc.named_field(target, field_ref(field_name.clone()))
        },
    );

    Ok(assignments.build())
}

// Unit tests for `lower_nest`: each case pairs an input pipeline with the
// pipeline expected after lowering and the schema expected after re-typechecking.
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::type_check;
    use hamelin_lib::{
        tree::{
            ast::pipeline::Pipeline,
            builder::{field_ref, ident, nest_command, pipeline, select_command},
        },
        types::{struct_type::Struct, INT, STRING},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    // Each case supplies, in order: the input pipeline, the expected lowered
    // pipeline, and the expected output schema of the lowered pipeline.
    #[rstest]
    // Case 1: No NEST commands - the pipeline passes through unchanged, so the
    // expected pipeline is identical to the input.
    #[case::no_nest_passthrough(
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        Struct::default().with_str("a", INT).with_str("b", INT)
    )]
    // Case 2: Simple NEST -> SELECT with compound identifiers (prefix.field = field).
    // `NEST user` over fields `a` and `b` becomes `SELECT user.a = a, user.b = b`,
    // and the output schema holds a single `user` struct with both fields.
    #[case::simple_nest(
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", "hello").build())
            .command(nest_command("user").build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", "hello").build())
            .command(select_command()
                .named_field(ident("user").dot("a"), field_ref("a"))
                .named_field(ident("user").dot("b"), field_ref("b"))
                .build())
            .build(),
        Struct::default().with_str("user", Struct::default().with_str("a", INT).with_str("b", STRING).into())
    )]
    // Case 3: Compound NEST -> SELECT with deeper compound identifiers.
    // `NEST user.address` over field `x` becomes `SELECT user.address.x = x`,
    // giving a schema nested two levels deep: `user: {address: {x}}`.
    #[case::compound_nest(
        pipeline()
            .command(select_command().named_field("x", 42).build())
            .command(nest_command(ident("user").dot("address")).build())
            .build(),
        pipeline()
            .command(select_command().named_field("x", 42).build())
            .command(select_command()
                .named_field(ident("user").dot("address").dot("x"), field_ref("x"))
                .build())
            .build(),
        Struct::default().with_str("user",
            Struct::default().with_str("address",
                Struct::default().with_str("x", INT).into()).into())
    )]
    // Runs `lower_nest` on the type-checked input and checks both the resulting
    // AST and the resulting output schema.
    fn test_lower_nest(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) {
        // Both pipelines go through `type_check`; only the `ast` of the expected
        // one is used in the comparison below.
        let input_typed = type_check(input).output;
        let expected_typed = type_check(expected).output;

        // A default context is enough for the pass in these cases; a lowering
        // error fails the test via `unwrap`.
        let mut ctx = StatementTranslationContext::default();
        let result = lower_nest(Arc::new(input_typed), &mut ctx).unwrap();

        // Compare ASTs
        assert_eq!(result.ast, expected_typed.ast);

        // Verify output schema
        let result_schema = result.environment().as_struct().clone();
        assert_eq!(result_schema, expected_output_schema);
    }
}