use std::sync::Arc;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
ast::{command::Command, identifier::Identifier},
builder::{self, agg_command},
typed_ast::{
clause::Projections,
command::{TypedAggCommand, TypedCommand, TypedCommandKind},
context::StatementTranslationContext,
pipeline::TypedPipeline,
},
};
use super::super::compound_lowering::{lower_compound_assignments, UniqueNameGenerator};
/// Rewrites every `agg` command whose aggregate or group-by outputs are named with
/// compound identifiers (e.g. `stats.total`) into a flat form.
///
/// Each affected `agg` is rebuilt so those outputs are written to generated temporary
/// columns (prefixed `__normalize_agg`). The command is then followed by the restore
/// commands returned by `lower_compound_assignments`. In the tests these are a `let` of the
/// compound path from the temporary, followed by a `drop` of the temporary. All other
/// commands are copied through unchanged.
///
/// If no `agg` command needs rewriting, the input `Arc` is returned as-is and no new
/// pipeline is type-checked.
///
/// # Errors
///
/// Propagates the error from `pipeline.valid_ref()` when the pipeline is not valid.
pub fn normalize_agg(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    // Validate once and reuse the borrow for both the fast-path check and the rebuild.
    let valid = pipeline.valid_ref()?;

    // Fast path: nothing to lower, so hand back the original pipeline untouched.
    if !valid.commands.iter().any(agg_has_compounds) {
        return Ok(pipeline);
    }

    // One generator is shared by every command, so temporaries from different `agg`
    // commands are drawn from the same sequence (`__normalize_agg_0`, `_1`, ... in tests).
    let mut name_gen = UniqueNameGenerator::new("__normalize_agg");
    let new_ast = valid
        .commands
        .iter()
        .flat_map(|cmd| normalize_command(cmd, &mut name_gen))
        .fold(builder::pipeline(), |pipe, command| pipe.command(command))
        .build()
        .at(pipeline.ast.span);

    // Re-type the rewritten AST so downstream passes see an up-to-date environment.
    Ok(Arc::new(TypedPipeline::from_ast_with_context(
        Arc::new(new_ast),
        ctx,
    )))
}
/// Returns `true` only for an `agg` command in which at least one aggregate or
/// group-by assignment targets a compound identifier; any other command kind is `false`.
fn agg_has_compounds(cmd: &Arc<TypedCommand>) -> bool {
    if let TypedCommandKind::Agg(agg) = &cmd.kind {
        has_compound_identifiers(&agg.aggregates) || has_compound_identifiers(&agg.group_by)
    } else {
        false
    }
}
/// Reports whether any assignment in `projections` is named by an
/// `Identifier::Compound`.
///
/// An assignment whose identifier fails `valid_ref()` counts as non-compound.
fn has_compound_identifiers(projections: &Projections) -> bool {
    projections
        .assignments
        .iter()
        .map(|assignment| assignment.identifier.valid_ref())
        .any(|ident| {
            ident
                .map(|id| matches!(id, Identifier::Compound(_)))
                .unwrap_or(false)
        })
}
/// Produces the replacement command sequence for a single pipeline command.
///
/// `agg` commands are delegated to `transform_agg`. Every other command is
/// reproduced verbatim by sharing its existing AST node.
fn normalize_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    if let TypedCommandKind::Agg(agg) = &cmd.kind {
        transform_agg(agg, cmd, name_gen)
    } else {
        // Cheap refcount bump; the untouched command keeps its original AST.
        vec![Arc::clone(&cmd.ast)]
    }
}
/// Rebuilds one `agg` command so compound output names become generated temporaries.
///
/// Returns the rebuilt `agg` followed by the restore commands from
/// `lower_compound_assignments`: first those for the aggregates, then those for the
/// group-by keys. The agg's sort expressions are copied over unchanged.
fn transform_agg(
    agg_cmd: &TypedAggCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    let schema = &cmd.input_schema;

    // Lower aggregates before group-by keys: this fixes the order in which
    // temporary names are drawn from `name_gen`.
    let (aggregates, aggregate_restores) =
        lower_compound_assignments(&agg_cmd.aggregates, name_gen, schema);
    let (groups, group_restores) =
        lower_compound_assignments(&agg_cmd.group_by, name_gen, schema);

    let mut agg = aggregates
        .into_iter()
        .fold(agg_command().at(cmd.ast.span), |b, (id, expr)| {
            b.named_aggregate(id, expr)
        });
    agg = groups
        .into_iter()
        .fold(agg, |b, (id, expr)| b.named_group(id, expr));
    for sort_expr in &agg_cmd.sort_by {
        agg = agg.sort_expr(sort_expr.ast.as_ref().clone());
    }

    std::iter::once(Arc::new(agg.build()))
        .chain(aggregate_restores.into_iter().map(Arc::new))
        .chain(group_restores.into_iter().map(Arc::new))
        .collect()
}
#[cfg(test)]
// Table-driven tests for `normalize_agg`. Each case supplies an input pipeline, the
// pipeline the pass is expected to produce, and the expected flattened output schema.
mod tests {
use super::*;
use hamelin_lib::{
tree::{
ast::{
identifier::CompoundIdentifier, pipeline::Pipeline, IntoTyped, TypeCheckExecutor,
},
builder::{
agg_command, call, column_ref, drop_command, let_command, pipeline, select_command,
sort_command,
},
},
types::{struct_type::Struct, INT},
};
use pretty_assertions::assert_eq;
use rstest::rstest;
use std::sync::Arc;
#[rstest]
// A pipeline with no `agg` command is returned unchanged.
#[case::no_agg_passthrough(
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
Struct::default().with_str("a", INT).with_str("b", INT)
)]
// An `agg` whose outputs all use simple identifiers is left as-is.
#[case::agg_simple_ids_unchanged(
pipeline()
.command(select_command().named_field("value", 10).named_field("category", 1).build())
.command(agg_command()
.named_aggregate("total", call("sum").arg(column_ref("value")))
.named_group("category", column_ref("category"))
.build())
.build(),
pipeline()
.command(select_command().named_field("value", 10).named_field("category", 1).build())
.command(agg_command()
.named_aggregate("total", call("sum").arg(column_ref("value")))
.named_group("category", column_ref("category"))
.build())
.build(),
Struct::default().with_str("category", INT).with_str("total", INT)
)]
// A compound aggregate name (`stats.total`) is written to a temporary and then
// restored with `let stats.total = __normalize_agg_0` followed by a `drop` of the temp.
#[case::agg_compound_aggregate(
pipeline()
.command(select_command().named_field("value", 10).named_field("category", 1).build())
.command(agg_command()
.named_aggregate(
CompoundIdentifier::new("stats".into(), "total".into(), vec![]),
call("sum").arg(column_ref("value"))
)
.named_group("category", column_ref("category"))
.build())
.build(),
pipeline()
.command(select_command().named_field("value", 10).named_field("category", 1).build())
.command(agg_command()
.named_aggregate("__normalize_agg_0", call("sum").arg(column_ref("value")))
.named_group("category", column_ref("category"))
.build())
.command(let_command()
.named_field(
CompoundIdentifier::new("stats".into(), "total".into(), vec![]),
column_ref("__normalize_agg_0")
)
.build())
.command(drop_command().field("__normalize_agg_0").build())
.build(),
Struct::default()
.with_str("stats", Struct::default().with_str("total", INT).into())
.with_str("category", INT)
)]
// Same as above, but the agg's sort expressions must be carried over unchanged.
#[case::agg_compound_with_sort_by(
pipeline()
.command(select_command().named_field("value", 10).named_field("category", 1).build())
.command(agg_command()
.named_aggregate(
CompoundIdentifier::new("stats".into(), "total".into(), vec![]),
call("sum").arg(column_ref("value"))
)
.named_group("category", column_ref("category"))
.sort(sort_command().by(column_ref("value")))
.build())
.build(),
pipeline()
.command(select_command().named_field("value", 10).named_field("category", 1).build())
.command(agg_command()
.named_aggregate("__normalize_agg_0", call("sum").arg(column_ref("value")))
.named_group("category", column_ref("category"))
.sort(sort_command().by(column_ref("value")))
.build())
.command(let_command()
.named_field(
CompoundIdentifier::new("stats".into(), "total".into(), vec![]),
column_ref("__normalize_agg_0")
)
.build())
.command(drop_command().field("__normalize_agg_0").build())
.build(),
Struct::default()
.with_str("stats", Struct::default().with_str("total", INT).into())
.with_str("category", INT)
)]
// A compound group-by key (`group.key`) is lowered the same way as an aggregate.
#[case::agg_compound_group_by(
pipeline()
.command(select_command().named_field("value", 10).named_field("cat", 1).build())
.command(agg_command()
.named_aggregate("total", call("sum").arg(column_ref("value")))
.named_group(
CompoundIdentifier::new("group".into(), "key".into(), vec![]),
column_ref("cat")
)
.build())
.build(),
pipeline()
.command(select_command().named_field("value", 10).named_field("cat", 1).build())
.command(agg_command()
.named_aggregate("total", call("sum").arg(column_ref("value")))
.named_group("__normalize_agg_0", column_ref("cat"))
.build())
.command(let_command()
.named_field(
CompoundIdentifier::new("group".into(), "key".into(), vec![]),
column_ref("__normalize_agg_0")
)
.build())
.command(drop_command().field("__normalize_agg_0").build())
.build(),
Struct::default()
.with_str("group", Struct::default().with_str("key", INT).into())
.with_str("total", INT)
)]
// Compound names in both clauses: aggregates are lowered first (temp `_0`), then
// group-by keys (temp `_1`), and their restore commands appear in that same order.
#[case::agg_compound_both(
pipeline()
.command(select_command().named_field("value", 10).named_field("cat", 1).build())
.command(agg_command()
.named_aggregate(
CompoundIdentifier::new("stats".into(), "total".into(), vec![]),
call("sum").arg(column_ref("value"))
)
.named_group(
CompoundIdentifier::new("group".into(), "key".into(), vec![]),
column_ref("cat")
)
.build())
.build(),
pipeline()
.command(select_command().named_field("value", 10).named_field("cat", 1).build())
.command(agg_command()
.named_aggregate("__normalize_agg_0", call("sum").arg(column_ref("value")))
.named_group("__normalize_agg_1", column_ref("cat"))
.build())
.command(let_command()
.named_field(
CompoundIdentifier::new("stats".into(), "total".into(), vec![]),
column_ref("__normalize_agg_0")
)
.build())
.command(drop_command().field("__normalize_agg_0").build())
.command(let_command()
.named_field(
CompoundIdentifier::new("group".into(), "key".into(), vec![]),
column_ref("__normalize_agg_1")
)
.build())
.command(drop_command().field("__normalize_agg_1").build())
.build(),
Struct::default()
.with_str("group", Struct::default().with_str("key", INT).into())
.with_str("stats", Struct::default().with_str("total", INT).into())
)]
// Type-checks both pipelines, runs the pass on the input, then asserts that the
// resulting AST equals the expected AST and that the flattened result environment
// equals the expected schema.
fn test_normalize_agg(
#[case] input: Pipeline,
#[case] expected: Pipeline,
#[case] expected_output_schema: Struct,
) {
let input_typed = input.typed_with().typed();
let expected_typed = expected.typed_with().typed();
let mut ctx = StatementTranslationContext::default();
let result = normalize_agg(Arc::new(input_typed), &mut ctx).unwrap();
assert_eq!(result.ast, expected_typed.ast);
let result_schema = result.environment().flatten();
assert_eq!(result_schema, expected_output_schema);
}
}