hamelin_translation 0.4.2

Lowering and IR for Hamelin query language
Documentation
//! Pipeline pass: WINDOW compound identifier normalization.
//!
//! Lowers compound identifiers in WINDOW commands to flat identifiers + LET/DROP restoration.
//!
//! Example:
//! ```text
//! FROM events | WINDOW stats.running = sum(value) BY category
//! ```
//! becomes:
//! ```text
//! FROM events | WINDOW __normalize_window_0 = sum(value) BY category
//!             | LET stats.running = __normalize_window_0
//!             | DROP __normalize_window_0
//! ```

use std::sync::Arc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
    ast::{command::Command, identifier::Identifier},
    builder::{self, window_command},
    typed_ast::{
        clause::Projections,
        command::{TypedCommand, TypedCommandKind, TypedWindowCommand},
        context::StatementTranslationContext,
        pipeline::TypedPipeline,
    },
};

use super::super::compound_lowering::{lower_compound_assignments, UniqueNameGenerator};

/// Rewrite the WINDOW commands of `pipeline` so they only assign to flat identifiers.
///
/// When no WINDOW command carries a compound identifier in its projections or group-by
/// clause, the incoming `Arc` is handed back untouched. Otherwise every command is run
/// through [`normalize_command`], the emitted commands are assembled into a fresh pipeline
/// AST that keeps the original pipeline span, and that AST is type-checked again with `ctx`.
///
/// Contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, ...>`
///
/// # Errors
///
/// Propagates the error produced by `pipeline.valid_ref()`.
pub fn normalize_window(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;

    // Fast path: nothing to lower, so skip the rebuild and the re-typecheck entirely.
    let needs_rewrite = valid.commands.iter().any(window_has_compounds);
    if !needs_rewrite {
        return Ok(pipeline);
    }

    // One generator for the whole pipeline keeps temp names unique across WINDOW commands.
    let mut temp_names = UniqueNameGenerator::new("__normalize_window");

    // Expand each command into its replacement(s) and append them in order.
    let rebuilt = valid
        .commands
        .iter()
        .flat_map(|typed_cmd| normalize_command(typed_cmd, &mut temp_names))
        .fold(builder::pipeline(), |acc, lowered| acc.command(lowered));

    let rewritten_ast = rebuilt.build().at(pipeline.ast.span);

    // The AST changed shape, so the typed pipeline must be derived again.
    let retyped = TypedPipeline::from_ast_with_context(Arc::new(rewritten_ast), ctx);
    Ok(Arc::new(retyped))
}

/// Return `true` only for WINDOW commands that assign to a compound identifier.
///
/// Both the projections and the group-by clause are inspected; any command kind other
/// than WINDOW yields `false`.
fn window_has_compounds(cmd: &Arc<TypedCommand>) -> bool {
    if let TypedCommandKind::Window(window) = &cmd.kind {
        has_compound_identifiers(&window.projections)
            || has_compound_identifiers(&window.group_by)
    } else {
        false
    }
}

/// Report whether `projections` holds at least one assignment targeting an
/// [`Identifier::Compound`].
///
/// An assignment whose identifier is not available through `valid_ref()` is treated as
/// non-compound.
fn has_compound_identifiers(projections: &Projections) -> bool {
    for assignment in projections.assignments.iter() {
        let is_compound = assignment
            .identifier
            .valid_ref()
            .map_or(false, |ident| matches!(ident, Identifier::Compound(_)));
        if is_compound {
            return true;
        }
    }
    false
}

/// Produce the replacement command list for a single typed command.
///
/// WINDOW commands are delegated to [`transform_window`]; every other command is forwarded
/// as a one-element list holding its original AST.
fn normalize_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    if let TypedCommandKind::Window(window_cmd) = &cmd.kind {
        transform_window(window_cmd, cmd, name_gen)
    } else {
        // Not a WINDOW: reuse the existing AST node unchanged.
        vec![Arc::clone(&cmd.ast)]
    }
}

/// Rebuild one WINDOW command from lowered assignments and append its restore commands.
///
/// The projections are lowered first and the group-by clause second, both drawing names
/// from the same `name_gen`. The rebuilt WINDOW keeps the original command span, sort
/// expressions and optional WITHIN clause. The returned list starts with the rebuilt WINDOW,
/// followed by the restore commands for the projections and then those for the group-by.
fn transform_window(
    window_cmd: &TypedWindowCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    // Order matters: the generator is stateful, so projections must be lowered before group_by.
    let (projection_pairs, projection_restores) =
        lower_compound_assignments(&window_cmd.projections, name_gen, &cmd.input_schema);
    let (partition_pairs, partition_restores) =
        lower_compound_assignments(&window_cmd.group_by, name_gen, &cmd.input_schema);

    // Assemble the replacement WINDOW piece by piece, starting from the original span.
    let with_fields = projection_pairs
        .into_iter()
        .fold(window_command().at(cmd.ast.span), |acc, (id, expr)| {
            acc.named_field(id, expr)
        });
    let with_partitions = partition_pairs
        .into_iter()
        .fold(with_fields, |acc, (id, expr)| acc.group_by(id, expr));
    let with_sorts = (&window_cmd.sort_by)
        .into_iter()
        .fold(with_partitions, |acc, sort_expr| {
            acc.sort_expr(sort_expr.ast.as_ref().clone())
        });
    let lowered_window = match &window_cmd.within {
        Some(within) => with_sorts.within(within.ast.clone()),
        None => with_sorts,
    };

    // Output order: WINDOW, projection restores, group-by restores.
    let mut commands: Vec<Arc<Command>> = Vec::new();
    commands.push(Arc::new(lowered_window.build()));
    commands.extend(
        projection_restores
            .into_iter()
            .chain(partition_restores)
            .map(Arc::new),
    );
    commands
}

// Table-driven tests for `normalize_window`: each case supplies an input pipeline, the
// pipeline expected after normalization, and the expected flattened output schema.
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        tree::ast::expression::IntervalUnit,
        tree::{
            ast::{
                identifier::CompoundIdentifier, pipeline::Pipeline, IntoTyped, TypeCheckExecutor,
            },
            builder::{
                call, column_ref, drop_command, let_command, pipeline, select_command,
                sort_command, window_command, IntervalLiteralBuilder,
            },
        },
        types::{struct_type::Struct, INT},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    #[rstest]
    // Case 1: No WINDOW commands - passes through unchanged
    #[case::no_window_passthrough(
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        Struct::default().with_str("a", INT).with_str("b", INT)
    )]
    // Case 2: WINDOW with simple identifiers - passes through unchanged
    #[case::window_simple_ids_unchanged(
        pipeline()
            .command(select_command().named_field("value", 10).named_field("category", 1).build())
            .command(window_command()
                .named_field("running", call("sum").arg(column_ref("value")))
                .group_by("category", column_ref("category"))
                .build())
            .build(),
        pipeline()
            .command(select_command().named_field("value", 10).named_field("category", 1).build())
            .command(window_command()
                .named_field("running", call("sum").arg(column_ref("value")))
                .group_by("category", column_ref("category"))
                .build())
            .build(),
        // Schema order: projections first, then partition_by, then parent fields
        // WINDOW binds partition_by fields directly, so category comes before value
        Struct::default()
            .with_str("running", INT)
            .with_str("category", INT)
            .with_str("value", INT)
    )]
    // Case 3: WINDOW with compound identifier in projection → WINDOW + LET + DROP
    // The temp column name is the generator prefix plus a counter: `__normalize_window_0`.
    #[case::window_compound_projection(
        pipeline()
            .command(select_command().named_field("value", 10).named_field("category", 1).build())
            .command(window_command()
                .named_field(
                    CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
                    call("sum").arg(column_ref("value"))
                )
                .group_by("category", column_ref("category"))
                .build())
            .build(),
        pipeline()
            .command(select_command().named_field("value", 10).named_field("category", 1).build())
            .command(window_command()
                .named_field("__normalize_window_0", call("sum").arg(column_ref("value")))
                .group_by("category", column_ref("category"))
                .build())
            .command(let_command()
                .named_field(
                    CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
                    column_ref("__normalize_window_0")
                )
                .build())
            .command(drop_command().field("__normalize_window_0").build())
            .build(),
        // Schema order after LET and DROP:
        // LET prepends stats, WINDOW had {__normalize_window_0, category, value}, DROP removes temp
        // Result: {stats.running, category, value}
        Struct::default()
            .with_str("stats", Struct::default().with_str("running", INT).into())
            .with_str("category", INT)
            .with_str("value", INT)
    )]
    // Case 4: WINDOW with compound identifier in group_by → WINDOW + LET + DROP
    // The group_by compound identifier is lowered to a temp, then LET restores the compound path.
    // The LET correctly infers the type from the temp column reference.
    #[case::window_compound_group_by(
        pipeline()
            .command(select_command().named_field("value", 10).named_field("cat", 1).build())
            .command(window_command()
                .named_field("running", call("sum").arg(column_ref("value")))
                .group_by(
                    CompoundIdentifier::new("group".into(), "key".into(), vec![]),
                    column_ref("cat")
                )
                .build())
            .build(),
        pipeline()
            .command(select_command().named_field("value", 10).named_field("cat", 1).build())
            .command(window_command()
                .named_field("running", call("sum").arg(column_ref("value")))
                .group_by("__normalize_window_0", column_ref("cat"))
                .build())
            .command(let_command()
                .named_field(
                    CompoundIdentifier::new("group".into(), "key".into(), vec![]),
                    column_ref("__normalize_window_0")
                )
                .build())
            .command(drop_command().field("__normalize_window_0").build())
            .build(),
        // Schema order after LET and DROP:
        // LET prepends group, WINDOW had {running, __normalize_window_0, value, cat}, DROP removes temp
        // Result: {group.key, running, value, cat}
        Struct::default()
            .with_str(
                "group",
                Struct::default()
                    .with_str("key", INT)
                    .into(),
            )
            .with_str("running", INT)
            .with_str("value", INT)
            .with_str("cat", INT)
    )]
    // Case 5: WINDOW with compound identifiers in both projection and group_by
    // Both compound identifiers are lowered to temps, then LET restores the compound paths.
    // Projections are lowered before group_by, so the projection temp is `_0` and the
    // group_by temp is `_1`; the projection's LET/DROP pair is emitted first.
    #[case::window_compound_both(
        pipeline()
            .command(select_command().named_field("value", 10).named_field("cat", 1).build())
            .command(window_command()
                .named_field(
                    CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
                    call("sum").arg(column_ref("value"))
                )
                .group_by(
                    CompoundIdentifier::new("group".into(), "key".into(), vec![]),
                    column_ref("cat")
                )
                .build())
            .build(),
        pipeline()
            .command(select_command().named_field("value", 10).named_field("cat", 1).build())
            .command(window_command()
                .named_field("__normalize_window_0", call("sum").arg(column_ref("value")))
                .group_by("__normalize_window_1", column_ref("cat"))
                .build())
            .command(let_command()
                .named_field(
                    CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
                    column_ref("__normalize_window_0")
                )
                .build())
            .command(drop_command().field("__normalize_window_0").build())
            .command(let_command()
                .named_field(
                    CompoundIdentifier::new("group".into(), "key".into(), vec![]),
                    column_ref("__normalize_window_1")
                )
                .build())
            .command(drop_command().field("__normalize_window_1").build())
            .build(),
        // Schema order after both LET/DROP pairs:
        // the second LET (group) lands ahead of the first (stats), both temps are dropped
        // Result: {group.key, stats.running, value, cat}
        Struct::default()
            .with_str(
                "group",
                Struct::default()
                    .with_str("key", INT)
                    .into(),
            )
            .with_str("stats", Struct::default().with_str("running", INT).into())
            .with_str("value", INT)
            .with_str("cat", INT)
    )]
    // Case 6: WINDOW with compound projection and WITHIN preserved
    // The sort expression and the WITHIN interval must survive the rewrite unchanged.
    #[case::window_compound_within(
        pipeline()
            .command(select_command()
                .named_field("value", 10)
                .named_field("category", 1)
                .named_field("timestamp", 5)
                .build())
            .command(window_command()
                .named_field(
                    CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
                    call("sum").arg(column_ref("value"))
                )
                .group_by("category", column_ref("category"))
                .sort(sort_command().by(column_ref("timestamp")))
                .within(IntervalLiteralBuilder::new(-5, IntervalUnit::Hour))
                .build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("value", 10)
                .named_field("category", 1)
                .named_field("timestamp", 5)
                .build())
            .command(window_command()
                .named_field("__normalize_window_0", call("sum").arg(column_ref("value")))
                .group_by("category", column_ref("category"))
                .sort(sort_command().by(column_ref("timestamp")))
                .within(IntervalLiteralBuilder::new(-5, IntervalUnit::Hour))
                .build())
            .command(let_command()
                .named_field(
                    CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
                    column_ref("__normalize_window_0")
                )
                .build())
            .command(drop_command().field("__normalize_window_0").build())
            .build(),
        // Schema order: stats prepended by LET, then WINDOW's {category}, then parent's {value, timestamp}
        Struct::default()
            .with_str("stats", Struct::default().with_str("running", INT).into())
            .with_str("category", INT)
            .with_str("value", INT)
            .with_str("timestamp", INT)
    )]
    // Runs `normalize_window` on `input` and checks two things: the resulting AST equals the
    // AST of `expected`, and the flattened result environment equals `expected_output_schema`.
    fn test_normalize_window(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) {
        // Build typed pipelines from both the input and the expected AST.
        let input_typed = input.typed_with().typed();
        let expected_typed = expected.typed_with().typed();

        // A default context is used for the re-typecheck inside the pass; `unwrap` fails the
        // test if the pass returns an error.
        let mut ctx = StatementTranslationContext::default();
        let result = normalize_window(Arc::new(input_typed), &mut ctx).unwrap();

        // Compare ASTs
        assert_eq!(result.ast, expected_typed.ast);

        // Verify output schema
        let result_schema = result.environment().flatten();
        assert_eq!(result_schema, expected_output_schema);
    }
}