hamelin_translation 0.9.7

Lowering and IR for the Hamelin query language
Documentation
//! Pipeline pass: Extract nested aggregates from AGG assignments (scalar wrappers).
//!
//! Transforms AGG commands where aggregate functions appear inside larger expressions
//! into AGG with only top-level aggregates, followed by SET + DROP (same pattern as
//! [`super::extract_window_aggregates`]).
//!
//! Example:
//! ```text
//! AGG avg_duration = from_millis(avg(to_millis(duration)) as int)
//! ```
//! becomes:
//! ```text
//! AGG __agg_extract_0 = avg(to_millis(duration))
//! | SET avg_duration = from_millis(__agg_extract_0 as int)
//! | DROP __agg_extract_0
//! ```

use std::sync::Arc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::expression::Expression;
use hamelin_lib::tree::ast::identifier::{Identifier, SimpleIdentifier};
use hamelin_lib::tree::typed_ast::command::{TypedAggCommand, TypedCommand, TypedCommandKind};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
use hamelin_lib::tree::{
    ast::command::Command,
    builder::{self, agg_command, drop_command, set_command},
};

use crate::normalize::special_function_extraction::{
    has_nested_special_function, ExtractSpecialFunctionsAlgebra,
};
use crate::unique::UniqueNameGenerator;

/// Extract nested aggregates from AGG assignments.
///
/// Every AGG command with an aggregate wrapped inside a larger expression is
/// rewritten into `AGG` (top-level aggregates only), then `SET` (the wrapping
/// expressions), then `DROP` (the synthesized `__agg_extract_*` columns). All
/// other commands are copied through as their original AST nodes. When no command
/// needs rewriting, the input `Arc` is returned as-is without re-typing.
///
/// Contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, ...>`
///
/// # Errors
///
/// Propagates the error returned by `valid_ref()` when the input pipeline is not
/// valid.
pub fn extract_agg_aggregates(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;

    let needs_rewrite = valid.commands.iter().any(agg_has_nested_aggregates);
    if !needs_rewrite {
        // Fast path: nothing to extract, so skip rebuilding and re-typing.
        return Ok(pipeline);
    }

    // A single generator for the whole pipeline keeps synthesized names unique
    // across commands.
    let mut name_gen = UniqueNameGenerator::new("__agg_extract");

    let rebuilt = valid
        .commands
        .iter()
        .flat_map(|cmd| transform_command(cmd, &mut name_gen))
        .fold(builder::pipeline(), |acc, command| acc.command(command))
        .build()
        .at(pipeline.ast.span);

    Ok(Arc::new(TypedPipeline::from_ast_with_context(
        Arc::new(rebuilt),
        ctx,
    )))
}

/// Returns `true` when `cmd` is an AGG command and at least one of its aggregate
/// assignments contains a nested special function, as reported by
/// [`has_nested_special_function`]. Any non-AGG command yields `false`.
fn agg_has_nested_aggregates(cmd: &Arc<TypedCommand>) -> bool {
    match &cmd.kind {
        TypedCommandKind::Agg(agg) => agg
            .aggregates
            .assignments
            .iter()
            .any(|assignment| has_nested_special_function(&assignment.expression)),
        _ => false,
    }
}

/// Rewrite a single command.
///
/// An AGG command with nested aggregates expands into AGG + SET + DROP via
/// [`transform_agg`]. Every other command, including an AGG with only top-level
/// aggregates, is returned as a one-element vector holding its original AST node.
fn transform_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    match &cmd.kind {
        TypedCommandKind::Agg(agg_cmd) if agg_has_nested_aggregates(cmd) => {
            transform_agg(agg_cmd, cmd, name_gen)
        }
        _ => vec![Arc::clone(&cmd.ast)],
    }
}

/// Split one AGG command into AGG (top-level aggregates only), SET, and DROP.
///
/// Each aggregate assignment is handled in one of two ways:
/// - If its expression contains a nested special function, the expression is
///   folded with [`ExtractSpecialFunctionsAlgebra`]. The algebra records each lifted
///   call as a synthesized aggregate and returns the rewritten wrapper expression.
///   A follow-up SET assigns that wrapper back to the original name.
/// - Otherwise the assignment is copied into the new AGG unchanged.
///
/// Group-by assignments and sort expressions are carried over verbatim. A
/// trailing DROP removes the synthesized columns. The DROP is emitted only when
/// something was actually synthesized, so the pass never produces an empty DROP.
///
/// NOTE(review): assignments rewritten through SET are materialized after the AGG,
/// so their position in the output schema may differ from the original AGG order.
/// Confirm whether column order matters downstream.
fn transform_agg(
    agg_cmd: &TypedAggCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    let mut extracted_aggs: Vec<(SimpleIdentifier, Expression)> = Vec::new();
    let mut set_assignments: Vec<(SimpleIdentifier, Expression)> = Vec::new();
    let mut synth_ids_to_drop: Vec<SimpleIdentifier> = Vec::new();

    for assignment in &agg_cmd.aggregates.assignments {
        // Assignments whose identifier did not resolve are skipped. This matches
        // the group-by handling below.
        let Ok(id) = assignment.identifier.valid_ref() else {
            continue;
        };
        // NOTE(review): a compound identifier is flattened into one simple name via
        // its string form. Confirm this is the intended output column.
        let simple_id = match id {
            Identifier::Simple(s) => s.clone(),
            Identifier::Compound(c) => SimpleIdentifier::new(c.to_string()),
        };

        if has_nested_special_function(&assignment.expression) {
            // The input schema is passed to the algebra alongside the shared name
            // generator used for synthesized column names.
            let mut alg = ExtractSpecialFunctionsAlgebra {
                name_gen,
                schema: &cmd.input_schema,
                extractions: Vec::new(),
                synth_ids: Vec::new(),
            };
            let new_ast = assignment.expression.cata(&mut alg);

            extracted_aggs.extend(alg.extractions);
            synth_ids_to_drop.extend(alg.synth_ids);
            set_assignments.push((simple_id, new_ast.as_ref().clone()));
        } else {
            extracted_aggs.push((simple_id, assignment.expression.ast.as_ref().clone()));
        }
    }

    let mut agg_b = agg_command().at(cmd.ast.span);
    for (id, expr) in &extracted_aggs {
        agg_b = agg_b.named_aggregate(id.clone(), expr.clone());
    }

    for assignment in &agg_cmd.group_by.assignments {
        if let Ok(id) = assignment.identifier.valid_ref() {
            agg_b = agg_b.named_group(id.clone(), assignment.expression.ast.as_ref().clone());
        }
    }

    for sort_expr in &agg_cmd.sort_by {
        agg_b = agg_b.sort_expr(sort_expr.ast.as_ref().clone());
    }

    let mut result = vec![Arc::new(agg_b.build())];

    if !set_assignments.is_empty() {
        let mut set_builder = set_command();
        for (id, expr) in &set_assignments {
            set_builder = set_builder.named_field(id.clone(), expr.clone());
        }
        result.push(Arc::new(set_builder.build()));

        // Guard against an empty DROP. If the algebra synthesized nothing, there is
        // nothing to remove.
        if !synth_ids_to_drop.is_empty() {
            let mut drop_builder = drop_command();
            for synth_id in &synth_ids_to_drop {
                drop_builder = drop_builder.field(synth_id.clone());
            }
            result.push(Arc::new(drop_builder.build()));
        }
    }

    result
}

// Table-driven tests for `extract_agg_aggregates`. Each case supplies an input
// pipeline, the pipeline the pass is expected to produce, and the expected output
// schema after the pass.
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::type_check;
    use hamelin_lib::{
        tree::{
            ast::pipeline::Pipeline,
            builder::{
                add, agg_command, call, drop_command, field_ref, pipeline, select_command,
                set_command,
            },
        },
        types::{struct_type::Struct, INT},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    #[rstest]
    // A pipeline with no AGG command is expected back unchanged.
    #[case::no_agg_passthrough(
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        Struct::default().with_str("a", INT).with_str("b", INT)
    )]
    // An AGG whose aggregates are all top-level is expected back unchanged.
    #[case::agg_top_level_unchanged(
        pipeline()
            .command(select_command().named_field("value", 10).named_field("category", 1).build())
            .command(agg_command()
                .named_aggregate("total", call("sum").arg(field_ref("value")))
                .named_group("category", field_ref("category"))
                .build())
            .build(),
        pipeline()
            .command(select_command().named_field("value", 10).named_field("category", 1).build())
            .command(agg_command()
                .named_aggregate("total", call("sum").arg(field_ref("value")))
                .named_group("category", field_ref("category"))
                .build())
            .build(),
        Struct::default()
            .with_str("category", INT)
            .with_str("total", INT)
    )]
    // Two aggregates nested inside `add` are each lifted into a synthesized
    // `__agg_extract_N` column (numbered in visit order). A SET recombines them
    // under the original name, and a DROP removes the synthesized columns.
    #[case::agg_nested_agg_binary(
        pipeline()
            .command(select_command()
                .named_field("left", 1)
                .named_field("right", 2)
                .build())
            .command(agg_command()
                .named_aggregate("both",
                    add(
                        call("sum").arg(field_ref("left")),
                        call("sum").arg(field_ref("right"))
                    )
                )
                .build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("left", 1)
                .named_field("right", 2)
                .build())
            .command(agg_command()
                .named_aggregate("__agg_extract_0", call("sum").arg(field_ref("left")))
                .named_aggregate("__agg_extract_1", call("sum").arg(field_ref("right")))
                .build())
            .command(set_command()
                .named_field("both",
                    add(field_ref("__agg_extract_0"), field_ref("__agg_extract_1"))
                )
                .build())
            .command(drop_command()
                .field("__agg_extract_0")
                .field("__agg_extract_1")
                .build())
            .build(),
        Struct::default().with_str("both", INT)
    )]
    // Runs the pass on the type-checked input. The expected pipeline is also run
    // through `type_check`, and only its `.ast` is compared with the pass output.
    // The environment of the rewritten pipeline must match the expected schema.
    fn test_extract_agg_aggregates(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) {
        let input_typed = type_check(input).output;
        let expected_typed = type_check(expected).output;

        let mut ctx = StatementTranslationContext::default();
        let result = extract_agg_aggregates(Arc::new(input_typed), &mut ctx).unwrap();

        assert_eq!(result.ast, expected_typed.ast);

        let result_schema = result.environment().as_struct().clone();
        assert_eq!(result_schema, expected_output_schema);
    }
}