hamelin_translation 0.4.3

Lowering and IR for Hamelin query language
Documentation
//! Pipeline pass: Extract nested aggregates from WINDOW expressions.
//!
//! Transforms WINDOW commands where aggregate functions are buried in expressions
//! into WINDOW commands with only top-level aggregates, followed by a LET that
//! computes the final expression and a DROP that removes the synthetic
//! aggregate columns.
//!
//! Example:
//! ```text
//! WINDOW crazy_number = sum(left) + sum(right) + 3 SORT BY order
//! ```
//! becomes:
//! ```text
//! WINDOW __window_agg_0 = sum(left), __window_agg_1 = sum(right) SORT BY order
//! | LET crazy_number = __window_agg_0 + __window_agg_1 + 3
//! | DROP __window_agg_0, __window_agg_1
//! ```
//!
//! This ensures the IR's WINDOW command only contains top-level aggregate functions,
//! making backend translation straightforward.

use std::sync::Arc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::func::def::SpecialPosition;
use hamelin_lib::tree::ast::expression::Expression;
use hamelin_lib::tree::ast::identifier::SimpleIdentifier;
use hamelin_lib::tree::typed_ast::command::{TypedCommand, TypedCommandKind, TypedWindowCommand};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::tree::typed_ast::expression::{
    MapExpressionAlgebra, TypedApply, TypedExpression, TypedExpressionKind,
};
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
use hamelin_lib::tree::{
    ast::command::Command,
    builder::{self, column_ref, drop_command, let_command, window_command, ExpressionBuilder},
};

use super::super::unique::UniqueNameGenerator;

/// Rewrite every WINDOW command whose projections bury aggregate calls inside
/// larger expressions, then re-typecheck the rebuilt pipeline.
///
/// A pipeline that contains no such WINDOW command is handed back unchanged
/// (the very same `Arc`).
///
/// Contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, ...>`
///
/// # Errors
///
/// Propagates the error returned by `pipeline.valid_ref()`.
pub fn extract_window_aggregates(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    // Fast path: no WINDOW needs rewriting, so return the input untouched.
    let needs_rewrite = pipeline
        .valid_ref()?
        .commands
        .iter()
        .any(window_has_nested_aggregates);
    if !needs_rewrite {
        return Ok(pipeline);
    }

    let valid = pipeline.valid_ref()?;

    // A single generator is shared by every WINDOW command in the pipeline.
    let mut name_gen = UniqueNameGenerator::new("__window_agg");

    // Each input command expands to one or more output commands; append them
    // to the new pipeline in order.
    let rebuilt = valid
        .commands
        .iter()
        .flat_map(|cmd| transform_command(cmd, &mut name_gen))
        .fold(builder::pipeline(), |acc, command| acc.command(command))
        .build()
        .at(pipeline.ast.span);

    // The AST changed, so the typed pipeline is derived again from scratch.
    let retyped = TypedPipeline::from_ast_with_context(Arc::new(rebuilt), ctx);
    Ok(Arc::new(retyped))
}

/// Report whether `cmd` is a WINDOW command with at least one projection whose
/// expression contains an aggregate below its top level.
///
/// Any non-WINDOW command yields `false`.
fn window_has_nested_aggregates(cmd: &Arc<TypedCommand>) -> bool {
    match &cmd.kind {
        TypedCommandKind::Window(window_cmd) => window_cmd
            .projections
            .assignments
            .iter()
            .any(|assignment| has_nested_aggregate(&assignment.expression)),
        _ => false,
    }
}

/// Report whether `expr` contains an aggregate call without being one itself.
///
/// An expression that is an aggregate call at its root always yields `false`;
/// otherwise the result is whether `find` locates any aggregate call.
fn has_nested_aggregate(expr: &Arc<TypedExpression>) -> bool {
    let root_is_aggregate = is_aggregate_call(expr);
    // Short-circuit keeps the search from running when the root already is
    // an aggregate.
    !root_is_aggregate && expr.find(&mut is_aggregate_call).is_some()
}

/// Report whether `expr` is a function application whose function definition
/// sits in the `SpecialPosition::Agg` position.
fn is_aggregate_call(expr: &TypedExpression) -> bool {
    match &expr.kind {
        TypedExpressionKind::Apply(apply) => {
            apply.function_def.special_position() == Some(SpecialPosition::Agg)
        }
        _ => false,
    }
}

/// Map one typed command to the AST command(s) that replace it.
///
/// A WINDOW command with nested aggregates is expanded by `transform_window`;
/// every other command is passed through as a single-element vector holding
/// its original AST.
fn transform_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    match &cmd.kind {
        TypedCommandKind::Window(window_cmd) if window_has_nested_aggregates(cmd) => {
            transform_window(window_cmd, cmd, name_gen)
        }
        _ => vec![cmd.ast.clone()],
    }
}

/// Algebra for extracting aggregate expressions from WINDOW projections.
///
/// Replaces aggregate calls with column references to synthetic variables,
/// collecting the aggregates for extraction into the WINDOW command.
struct ExtractAggregatesAlgebra<'a> {
    /// Source of the synthetic identifier handed out for each extracted aggregate.
    name_gen: &'a mut UniqueNameGenerator,
    /// Passed to `name_gen.next` every time a new synthetic identifier is requested.
    schema: &'a TypeEnvironment,
    /// `(synthetic identifier, original aggregate AST)` pairs, in extraction order.
    extractions: Vec<(SimpleIdentifier, Expression)>,
    /// The synthetic identifiers alone, pushed in the same order as `extractions`.
    synth_ids: Vec<SimpleIdentifier>,
}

impl MapExpressionAlgebra for ExtractAggregatesAlgebra<'_> {
    /// Rewrite one function application.
    ///
    /// Aggregate calls are recorded under a fresh synthetic identifier and
    /// replaced by a column reference to it. Every other application is
    /// rebuilt from its already-transformed `children`.
    fn apply(
        &mut self,
        node: &TypedApply,
        expr: &TypedExpression,
        children: hamelin_lib::func::def::ParameterBinding<Arc<Expression>>,
    ) -> Arc<Expression> {
        let is_aggregate = node.function_def.special_position() == Some(SpecialPosition::Agg);
        if !is_aggregate {
            // Ordinary call: keep the node, swap in the rewritten children.
            return node.replace_children_ast(expr, children);
        }

        // Aggregate call: remember the original AST under a synthetic name.
        let synth_id = self.name_gen.next(self.schema);
        self.synth_ids.push(synth_id.clone());
        self.extractions
            .push((synth_id.clone(), expr.ast.as_ref().clone()));

        // The surrounding expression now reads the aggregate by name.
        Arc::new(column_ref(synth_id).build())
    }
}

/// Split one WINDOW command into a WINDOW that holds only top-level
/// aggregates plus, when anything was extracted, a LET and a DROP.
///
/// Projections without a nested aggregate stay in the WINDOW as written. For
/// the others, every aggregate call becomes a synthetic WINDOW field and the
/// surrounding expression moves into the LET, reading those fields by name.
/// The DROP lists the synthetic fields. GROUP BY, SORT BY and WITHIN are
/// copied over from the original command.
///
/// Projections (and GROUP BY entries) whose identifier is not valid are
/// skipped.
fn transform_window(
    window_cmd: &TypedWindowCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    // Fields of the rewritten WINDOW, in emission order.
    let mut window_fields: Vec<(SimpleIdentifier, Expression)> = Vec::new();
    // Final expressions computed by the trailing LET.
    let mut final_exprs: Vec<(SimpleIdentifier, Expression)> = Vec::new();
    // Synthetic columns removed by the trailing DROP.
    let mut temporaries: Vec<SimpleIdentifier> = Vec::new();

    for assignment in &window_cmd.projections.assignments {
        let target = match assignment.identifier.valid_ref() {
            Ok(hamelin_lib::tree::ast::identifier::Identifier::Simple(simple)) => simple.clone(),
            Ok(hamelin_lib::tree::ast::identifier::Identifier::Compound(compound)) => {
                // Should have been normalized by normalize_window already
                SimpleIdentifier::new(compound.to_string())
            }
            Err(_) => continue,
        };

        if !has_nested_aggregate(&assignment.expression) {
            // Top-level aggregate or no aggregate - stays in the WINDOW.
            window_fields.push((target, assignment.expression.ast.as_ref().clone()));
            continue;
        }

        // Walk the expression, pulling each aggregate out under a synthetic name.
        let mut extractor = ExtractAggregatesAlgebra {
            name_gen: &mut *name_gen,
            schema: &cmd.input_schema,
            extractions: Vec::new(),
            synth_ids: Vec::new(),
        };
        let rewritten = assignment.expression.cata(&mut extractor);

        window_fields.extend(extractor.extractions);
        temporaries.extend(extractor.synth_ids);
        final_exprs.push((target, rewritten.as_ref().clone()));
    }

    // WINDOW: projections first, then the clauses carried over unchanged.
    let mut window_builder = window_fields.into_iter().fold(
        window_command().at(cmd.ast.span),
        |acc, (id, expr)| acc.named_field(id, expr),
    );

    for grouping in &window_cmd.group_by.assignments {
        let Ok(key) = grouping.identifier.valid_ref() else {
            continue;
        };
        window_builder =
            window_builder.group_by(key.clone(), grouping.expression.ast.as_ref().clone());
    }

    for sort_expr in &window_cmd.sort_by {
        window_builder = window_builder.sort_expr(sort_expr.ast.as_ref().clone());
    }

    if let Some(within) = &window_cmd.within {
        window_builder = window_builder.within(within.ast.clone());
    }

    let mut commands = vec![Arc::new(window_builder.build())];

    // No projection was rewritten, so the WINDOW alone is the whole result.
    if final_exprs.is_empty() {
        return commands;
    }

    // LET computes the final expressions from the synthetic aggregates.
    let let_cmd = final_exprs
        .into_iter()
        .fold(let_command(), |acc, (id, expr)| acc.named_field(id, expr))
        .build();
    commands.push(Arc::new(let_cmd));

    // DROP removes the synthetic aggregate columns again.
    let drop_cmd = temporaries
        .into_iter()
        .fold(drop_command(), |acc, id| acc.field(id))
        .build();
    commands.push(Arc::new(drop_cmd));

    commands
}

#[cfg(test)]
mod tests {
    // Table-driven tests for `extract_window_aggregates`: each rstest case
    // supplies an input pipeline, the pipeline expected after the pass, and
    // the expected flattened output schema.
    use super::*;
    use hamelin_lib::{
        tree::{
            ast::{pipeline::Pipeline, IntoTyped, TypeCheckExecutor},
            builder::{
                add, call, column_ref, drop_command, let_command, pipeline, select_command,
                sort_command, window_command,
            },
        },
        types::{struct_type::Struct, INT},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    #[rstest]
    // Case 1: No WINDOW commands - passes through unchanged
    #[case::no_window_passthrough(
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        Struct::default().with_str("a", INT).with_str("b", INT)
    )]
    // Case 2: WINDOW with top-level aggregate - passes through unchanged
    #[case::window_toplevel_agg_unchanged(
        pipeline()
            .command(select_command().named_field("value", 10).named_field("order", 1).build())
            .command(window_command()
                .named_field("total", call("sum").arg(column_ref("value")))
                .sort(sort_command().by(column_ref("order")))
                .build())
            .build(),
        pipeline()
            .command(select_command().named_field("value", 10).named_field("order", 1).build())
            .command(window_command()
                .named_field("total", call("sum").arg(column_ref("value")))
                .sort(sort_command().by(column_ref("order")))
                .build())
            .build(),
        Struct::default()
            .with_str("total", INT)
            .with_str("value", INT)
            .with_str("order", INT)
    )]
    // Case 3: WINDOW with nested aggregates in binary expression
    // Input: WINDOW crazy_number = sum(left) + sum(right) + 3
    // Output: WINDOW __window_agg_0 = sum(left), __window_agg_1 = sum(right)
    //       | LET crazy_number = __window_agg_0 + __window_agg_1 + 3
    //       | DROP __window_agg_0, __window_agg_1
    #[case::window_nested_agg_binary(
        pipeline()
            .command(select_command()
                .named_field("left", 1)
                .named_field("right", 2)
                .named_field("order", 1)
                .build())
            .command(window_command()
                .named_field("crazy_number",
                    add(
                        add(
                            call("sum").arg(column_ref("left")),
                            call("sum").arg(column_ref("right"))
                        ),
                        3
                    )
                )
                .sort(sort_command().by(column_ref("order")))
                .build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("left", 1)
                .named_field("right", 2)
                .named_field("order", 1)
                .build())
            .command(window_command()
                .named_field("__window_agg_0", call("sum").arg(column_ref("left")))
                .named_field("__window_agg_1", call("sum").arg(column_ref("right")))
                .sort(sort_command().by(column_ref("order")))
                .build())
            .command(let_command()
                .named_field("crazy_number",
                    add(
                        add(
                            column_ref("__window_agg_0"),
                            column_ref("__window_agg_1")
                        ),
                        3
                    )
                )
                .build())
            .command(drop_command()
                .field("__window_agg_0")
                .field("__window_agg_1")
                .build())
            .build(),
        Struct::default()
            .with_str("crazy_number", INT)
            .with_str("left", INT)
            .with_str("right", INT)
            .with_str("order", INT)
    )]
    // Runs the pass on `input` and checks both the resulting AST (against the
    // AST of the type-checked `expected` pipeline) and the flattened output
    // schema (against `expected_output_schema`).
    fn test_extract_window_aggregates(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) {
        // Both pipelines go through the type checker so their ASTs are
        // produced the same way before being compared.
        let input_typed = input.typed_with().typed();
        let expected_typed = expected.typed_with().typed();

        // The pass must succeed; a returned error fails the test here.
        let mut ctx = StatementTranslationContext::default();
        let result = extract_window_aggregates(Arc::new(input_typed), &mut ctx).unwrap();

        // Compare ASTs
        assert_eq!(result.ast, expected_typed.ast);

        // Verify output schema
        let result_schema = result.environment().flatten();
        assert_eq!(result_schema, expected_output_schema);
    }
}