use std::sync::Arc;
use hamelin_lib::err::TranslationError;
use hamelin_lib::func::def::SpecialPosition;
use hamelin_lib::tree::ast::expression::Expression;
use hamelin_lib::tree::ast::identifier::SimpleIdentifier;
use hamelin_lib::tree::typed_ast::command::{TypedCommand, TypedCommandKind, TypedWindowCommand};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::tree::typed_ast::expression::{
MapExpressionAlgebra, TypedApply, TypedExpression, TypedExpressionKind,
};
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
use hamelin_lib::tree::{
ast::command::Command,
builder::{self, column_ref, drop_command, let_command, window_command, ExpressionBuilder},
};
use super::super::unique::UniqueNameGenerator;
/// Rewrites every window command whose projections contain an aggregate call
/// nested inside a larger expression.
///
/// Each such window command is replaced by the commands produced by
/// [`transform_command`]; all other commands are carried over as-is. The
/// resulting AST is re-typed through `TypedPipeline::from_ast_with_context`.
///
/// When no window command needs rewriting, the input `pipeline` is returned
/// untouched (same `Arc`).
///
/// # Errors
///
/// Propagates the error from `pipeline.valid_ref()`.
pub fn extract_window_aggregates(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    // Fast path: nothing to rewrite, hand the original pipeline back.
    let needs_rewrite = pipeline
        .valid_ref()?
        .commands
        .iter()
        .any(window_has_nested_aggregates);
    if !needs_rewrite {
        return Ok(pipeline);
    }

    // One generator for the whole pipeline so synthetic names never repeat
    // across different window commands.
    let mut name_gen = UniqueNameGenerator::new("__window_agg");
    let rewritten_ast = pipeline
        .valid_ref()?
        .commands
        .iter()
        .flat_map(|cmd| transform_command(cmd, &mut name_gen))
        .fold(builder::pipeline(), |acc, command| acc.command(command))
        .build()
        .at(pipeline.ast.span);

    let retyped = TypedPipeline::from_ast_with_context(Arc::new(rewritten_ast), ctx);
    Ok(Arc::new(retyped))
}
/// Returns `true` when `cmd` is a window command and at least one of its
/// projection assignments satisfies [`has_nested_aggregate`].
///
/// Any non-window command yields `false`.
fn window_has_nested_aggregates(cmd: &Arc<TypedCommand>) -> bool {
    match &cmd.kind {
        TypedCommandKind::Window(window) => window
            .projections
            .assignments
            .iter()
            .any(|assignment| has_nested_aggregate(&assignment.expression)),
        _ => false,
    }
}
/// Returns `true` when `expr` is not itself an aggregate call but `find`
/// locates an aggregate call somewhere within it.
///
/// An expression whose root is an aggregate call always yields `false`.
fn has_nested_aggregate(expr: &Arc<TypedExpression>) -> bool {
    !is_aggregate_call(expr) && expr.find(&mut is_aggregate_call).is_some()
}
/// Returns `true` when `expr` is a function application whose definition
/// reports `SpecialPosition::Agg` as its special position.
fn is_aggregate_call(expr: &TypedExpression) -> bool {
    match &expr.kind {
        TypedExpressionKind::Apply(apply) => {
            apply.function_def.special_position() == Some(SpecialPosition::Agg)
        }
        _ => false,
    }
}
/// Produces the replacement command(s) for a single pipeline command.
///
/// Window commands that contain nested aggregates are expanded by
/// [`transform_window`]; every other command is returned unchanged as a
/// one-element vector holding a clone of its original AST.
fn transform_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    match &cmd.kind {
        TypedCommandKind::Window(window_cmd) if window_has_nested_aggregates(cmd) => {
            transform_window(window_cmd, cmd, name_gen)
        }
        _ => vec![cmd.ast.clone()],
    }
}
/// Expression-rewriting algebra that replaces aggregate calls with references
/// to freshly named synthetic columns, recording what it replaced.
struct ExtractAggregatesAlgebra<'a> {
/// Source of fresh synthetic column names; shared with the caller so names
/// stay unique across multiple rewrites.
name_gen: &'a mut UniqueNameGenerator,
/// Type environment handed to `name_gen.next` each time a name is requested.
schema: &'a TypeEnvironment,
/// `(synthetic name, original aggregate call AST)` pairs, in the order the
/// aggregates were encountered.
extractions: Vec<(SimpleIdentifier, Expression)>,
/// The synthetic names generated so far, in the same order as `extractions`.
synth_ids: Vec<SimpleIdentifier>,
}
impl MapExpressionAlgebra for ExtractAggregatesAlgebra<'_> {
    /// Rewrites one function application.
    ///
    /// * Non-aggregate applications are rebuilt from their already-rewritten
    ///   `children` via `replace_children_ast`.
    /// * Aggregate applications are swapped for a column reference to a fresh
    ///   synthetic name; the original (un-rewritten) call AST is recorded in
    ///   `extractions` and the name in `synth_ids`.
    fn apply(
        &mut self,
        node: &TypedApply,
        expr: &TypedExpression,
        children: hamelin_lib::func::def::ParameterBinding<Arc<Expression>>,
    ) -> Arc<Expression> {
        let is_aggregate = node.function_def.special_position() == Some(SpecialPosition::Agg);
        if !is_aggregate {
            return node.replace_children_ast(expr, children);
        }

        let synth_id = self.name_gen.next(self.schema);
        self.extractions
            .push((synth_id.clone(), expr.ast.as_ref().clone()));
        self.synth_ids.push(synth_id.clone());
        Arc::new(column_ref(synth_id).build())
    }
}
/// Splits a window command containing nested aggregates into up to three
/// commands:
///
/// 1. a window command that computes every top-level projection plus one
///    synthetic column per extracted aggregate, keeping the original group-by,
///    sort and `within` clauses;
/// 2. a `let` command that computes each originally-nested projection from the
///    synthetic columns;
/// 3. a `drop` command that removes the synthetic columns.
///
/// Commands 2 and 3 are only emitted when at least one projection was
/// rewritten. Projections and group-by entries whose identifier fails
/// `valid_ref()` are skipped. Compound projection identifiers are converted to
/// a simple identifier built from their string form.
fn transform_window(
    window_cmd: &TypedWindowCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    use hamelin_lib::tree::ast::identifier::Identifier;

    // Fields that stay in (or are added to) the window command.
    let mut window_fields: Vec<(SimpleIdentifier, Expression)> = Vec::new();
    // Projections recomputed after the window from the synthetic columns.
    let mut post_lets: Vec<(SimpleIdentifier, Expression)> = Vec::new();
    // Synthetic columns to remove once the `let` has consumed them.
    let mut temporaries: Vec<SimpleIdentifier> = Vec::new();

    for assignment in &window_cmd.projections.assignments {
        let Ok(target) = assignment.identifier.valid_ref() else {
            continue;
        };
        let target = match target {
            Identifier::Simple(simple) => simple.clone(),
            Identifier::Compound(compound) => SimpleIdentifier::new(compound.to_string()),
        };

        // Projections without nested aggregates stay in the window verbatim.
        if !has_nested_aggregate(&assignment.expression) {
            window_fields.push((target, assignment.expression.ast.as_ref().clone()));
            continue;
        }

        let mut extractor = ExtractAggregatesAlgebra {
            name_gen: &mut *name_gen,
            schema: &cmd.input_schema,
            extractions: Vec::new(),
            synth_ids: Vec::new(),
        };
        let rewritten = assignment.expression.cata(&mut extractor);
        let ExtractAggregatesAlgebra {
            extractions,
            synth_ids,
            ..
        } = extractor;

        window_fields.extend(extractions);
        temporaries.extend(synth_ids);
        post_lets.push((target, rewritten.as_ref().clone()));
    }

    let mut window_builder = window_fields
        .into_iter()
        .fold(window_command().at(cmd.ast.span), |acc, (id, expr)| {
            acc.named_field(id, expr)
        });
    for assignment in &window_cmd.group_by.assignments {
        let Ok(id) = assignment.identifier.valid_ref() else {
            continue;
        };
        window_builder =
            window_builder.group_by(id.clone(), assignment.expression.ast.as_ref().clone());
    }
    for sort_expr in &window_cmd.sort_by {
        window_builder = window_builder.sort_expr(sort_expr.ast.as_ref().clone());
    }
    if let Some(within) = &window_cmd.within {
        window_builder = window_builder.within(within.ast.clone());
    }

    let mut commands = vec![Arc::new(window_builder.build())];
    if post_lets.is_empty() {
        return commands;
    }

    let let_builder = post_lets
        .into_iter()
        .fold(let_command(), |acc, (id, expr)| acc.named_field(id, expr));
    commands.push(Arc::new(let_builder.build()));

    let drop_builder = temporaries
        .into_iter()
        .fold(drop_command(), |acc, id| acc.field(id));
    commands.push(Arc::new(drop_builder.build()));

    commands
}
// Table-driven tests for `extract_window_aggregates`. Each case supplies an
// input pipeline, the pipeline expected after the pass, and the expected
// flattened output schema.
#[cfg(test)]
mod tests {
use super::*;
use hamelin_lib::{
tree::{
ast::{pipeline::Pipeline, IntoTyped, TypeCheckExecutor},
builder::{
add, call, column_ref, drop_command, let_command, pipeline, select_command,
sort_command, window_command,
},
},
types::{struct_type::Struct, INT},
};
use pretty_assertions::assert_eq;
use rstest::rstest;
use std::sync::Arc;
#[rstest]
// A pipeline with no window command: expected output equals the input.
#[case::no_window_passthrough(
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
Struct::default().with_str("a", INT).with_str("b", INT)
)]
// A window whose only projection is a bare `sum(value)` call: the aggregate
// is already top-level, so the expected output equals the input.
#[case::window_toplevel_agg_unchanged(
pipeline()
.command(select_command().named_field("value", 10).named_field("order", 1).build())
.command(window_command()
.named_field("total", call("sum").arg(column_ref("value")))
.sort(sort_command().by(column_ref("order")))
.build())
.build(),
pipeline()
.command(select_command().named_field("value", 10).named_field("order", 1).build())
.command(window_command()
.named_field("total", call("sum").arg(column_ref("value")))
.sort(sort_command().by(column_ref("order")))
.build())
.build(),
Struct::default()
.with_str("total", INT)
.with_str("value", INT)
.with_str("order", INT)
)]
// A window projection of the form `(sum(left) + sum(right)) + 3`: both sums
// are expected to move into synthetic window fields `__window_agg_0` and
// `__window_agg_1`, followed by a `let` that rebuilds `crazy_number` from
// them and a `drop` that removes the synthetic fields.
#[case::window_nested_agg_binary(
pipeline()
.command(select_command()
.named_field("left", 1)
.named_field("right", 2)
.named_field("order", 1)
.build())
.command(window_command()
.named_field("crazy_number",
add(
add(
call("sum").arg(column_ref("left")),
call("sum").arg(column_ref("right"))
),
3
)
)
.sort(sort_command().by(column_ref("order")))
.build())
.build(),
pipeline()
.command(select_command()
.named_field("left", 1)
.named_field("right", 2)
.named_field("order", 1)
.build())
.command(window_command()
.named_field("__window_agg_0", call("sum").arg(column_ref("left")))
.named_field("__window_agg_1", call("sum").arg(column_ref("right")))
.sort(sort_command().by(column_ref("order")))
.build())
.command(let_command()
.named_field("crazy_number",
add(
add(
column_ref("__window_agg_0"),
column_ref("__window_agg_1")
),
3
)
)
.build())
.command(drop_command()
.field("__window_agg_0")
.field("__window_agg_1")
.build())
.build(),
Struct::default()
.with_str("crazy_number", INT)
.with_str("left", INT)
.with_str("right", INT)
.with_str("order", INT)
)]
// Types both pipelines, runs the pass on the input, then checks that the
// resulting AST matches the expected AST and that the flattened result
// environment matches the expected schema.
fn test_extract_window_aggregates(
#[case] input: Pipeline,
#[case] expected: Pipeline,
#[case] expected_output_schema: Struct,
) {
let input_typed = input.typed_with().typed();
let expected_typed = expected.typed_with().typed();
let mut ctx = StatementTranslationContext::default();
// `unwrap` is acceptable here: a failure of the pass is a test failure.
let result = extract_window_aggregates(Arc::new(input_typed), &mut ctx).unwrap();
assert_eq!(result.ast, expected_typed.ast);
let result_schema = result.environment().flatten();
assert_eq!(result_schema, expected_output_schema);
}
}