use std::sync::Arc;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
ast::{
command::Command,
expression::{ExpressionKind, FieldReference},
identifier::{Identifier, SimpleIdentifier},
},
builder::{self, drop_command, field_ref, let_command, window_command},
typed_ast::{
clause::Projections,
command::{TypedCommand, TypedCommandKind, TypedWindowCommand},
context::StatementTranslationContext,
pipeline::TypedPipeline,
},
};
use super::super::compound_lowering::{lower_compound_assignments, UniqueNameGenerator};
/// Normalizes every WINDOW command in `pipeline` that needs it.
///
/// A WINDOW command needs normalization when `window_needs_normalization`
/// says so (compound identifiers in its projections or group-by, or a
/// group-by entry that is not a plain `name = name` reference).
///
/// When no command qualifies, the input `Arc` is handed back untouched.
/// Otherwise a new pipeline AST is assembled command by command (each
/// original command may expand into several) and re-typed through
/// `TypedPipeline::from_ast_with_context` using `ctx`.
///
/// # Errors
///
/// Propagates the error from `pipeline.valid_ref()` and any error produced
/// while rewriting an individual command.
pub fn normalize_window(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    // Fast path: nothing to rewrite, so avoid rebuilding and re-typing.
    let needs_rewrite = pipeline
        .valid_ref()?
        .commands
        .iter()
        .any(window_needs_normalization);
    if !needs_rewrite {
        return Ok(pipeline);
    }

    let valid = pipeline.valid_ref()?;
    // One generator for the whole pipeline so temporaries never repeat
    // across different WINDOW commands.
    let mut name_gen = UniqueNameGenerator::new("__normalize_window");

    let mut rewritten = builder::pipeline();
    for original in &valid.commands {
        let replacements = normalize_command(original, &mut name_gen)?;
        rewritten = replacements
            .into_iter()
            .fold(rewritten, |acc, command| acc.command(command));
    }

    // Keep the span of the original pipeline on the rebuilt AST.
    let new_ast = Arc::new(rewritten.build().at(pipeline.ast.span));
    let retyped = TypedPipeline::from_ast_with_context(new_ast, ctx);
    Ok(Arc::new(retyped))
}
/// Tells whether `cmd` is a WINDOW command that must be rewritten.
///
/// Non-WINDOW commands always yield `false`. A WINDOW command yields `true`
/// when its projections or its group-by contain a compound identifier, or
/// when its group-by contains an entry that is not an identity reference.
fn window_needs_normalization(cmd: &Arc<TypedCommand>) -> bool {
    match &cmd.kind {
        TypedCommandKind::Window(window) => {
            let group_by = &window.group_by;
            has_compound_identifiers(&window.projections)
                || has_compound_identifiers(group_by)
                || has_non_identity_group_by(group_by)
        }
        _ => false,
    }
}
/// Returns `true` when at least one assignment in `projections` is bound to
/// a compound identifier.
///
/// An assignment whose identifier does not resolve through `valid_ref()` is
/// counted as "not compound" rather than reported as an error.
fn has_compound_identifiers(projections: &Projections) -> bool {
    projections.assignments.iter().any(|assignment| {
        matches!(
            assignment.identifier.valid_ref(),
            Ok(Identifier::Compound(_))
        )
    })
}
/// Returns `true` when some simply-named assignment in `projections` is
/// bound to anything other than a field reference to that same name.
///
/// Assignments whose identifier is compound, or does not resolve through
/// `valid_ref()`, are skipped here.
fn has_non_identity_group_by(projections: &Projections) -> bool {
    for assignment in &projections.assignments {
        let Ok(Identifier::Simple(target)) = assignment.identifier.valid_ref() else {
            continue;
        };

        // Identity means the expression is exactly a reference to `target`.
        let is_identity = match &assignment.expression.ast.kind {
            ExpressionKind::FieldReference(FieldReference { field_name }) => field_name
                .valid_ref()
                .map_or(false, |referenced| referenced == target),
            _ => false,
        };

        if !is_identity {
            return true;
        }
    }
    false
}
/// Produces the replacement command list for a single pipeline command.
///
/// Only a WINDOW command that `window_needs_normalization` flags is
/// rewritten (via `transform_window`); every other command is passed through
/// as a one-element list holding a clone of its original AST handle.
///
/// # Errors
///
/// Returns whatever error `transform_window` reports.
fn normalize_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    match &cmd.kind {
        TypedCommandKind::Window(window_cmd) if window_needs_normalization(cmd) => {
            transform_window(window_cmd, cmd, name_gen)
        }
        _ => Ok(vec![Arc::clone(&cmd.ast)]),
    }
}
/// Expands one WINDOW command into an equivalent command sequence.
///
/// The emitted sequence is, in order:
///
/// 1. LET commands that materialise group-by expressions before the window
///    (compound keys go into a fresh temporary, non-identity simple keys are
///    bound under their own name);
/// 2. the rebuilt WINDOW command, whose group-by entries are all
///    `name = name` references, carrying over the original sort expressions
///    and the optional `within` clause;
/// 3. the restore commands returned by `lower_compound_assignments` for the
///    projections;
/// 4. for each compound group-by key, a LET that writes the temporary back
///    to the compound path followed by a DROP of the temporary.
///
/// # Errors
///
/// Fails if a group-by identifier does not resolve through `valid_ref()`.
fn transform_window(
    window_cmd: &TypedWindowCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    // Projections are lowered first, so they draw from `name_gen` before
    // any group-by key does.
    let (window_assignments, window_restores) =
        lower_compound_assignments(&window_cmd.projections, name_gen, &cmd.input_schema);

    let mut pre_lets: Vec<Command> = Vec::new();
    let mut partition_keys: Vec<SimpleIdentifier> = Vec::new();
    let mut post_restores: Vec<Command> = Vec::new();

    for assignment in &window_cmd.group_by.assignments {
        match assignment.identifier.valid_ref()? {
            Identifier::Compound(compound) => {
                // Evaluate the key into a flat temporary ahead of the window.
                let temp = name_gen.next(&cmd.input_schema);
                pre_lets.push(
                    let_command()
                        .named_field(temp.clone(), assignment.expression.ast.as_ref().clone())
                        .build(),
                );
                // After the window: copy the temporary to the compound path,
                // then remove the temporary.
                post_restores.push(
                    let_command()
                        .named_field(
                            Into::<Identifier>::into(compound.clone()),
                            field_ref(temp.as_str()),
                        )
                        .build(),
                );
                post_restores.push(drop_command().field(temp.clone()).build());
                partition_keys.push(temp);
            }
            Identifier::Simple(simple) => {
                let is_identity = match &assignment.expression.ast.kind {
                    ExpressionKind::FieldReference(FieldReference { field_name }) => field_name
                        .valid_ref()
                        .map_or(false, |referenced| referenced == simple),
                    _ => false,
                };
                if !is_identity {
                    // NOTE(review): the key is bound under its own name
                    // before the WINDOW, so anything evaluated afterwards
                    // sees the rewritten column — confirm this is the
                    // intended scoping.
                    pre_lets.push(
                        let_command()
                            .named_field(simple.clone(), assignment.expression.ast.as_ref().clone())
                            .build(),
                    );
                }
                partition_keys.push(simple.clone());
            }
        }
    }

    // Rebuild the WINDOW command at the span of the original.
    let window = window_assignments
        .into_iter()
        .fold(window_command().at(cmd.ast.span), |acc, (id, expr)| {
            acc.named_field(id, expr)
        });
    let mut window = partition_keys.iter().fold(window, |acc, key| {
        acc.group_by(key.clone(), field_ref(key.as_str()))
    });
    for sort_expr in &window_cmd.sort_by {
        window = window.sort_expr(sort_expr.ast.as_ref().clone());
    }
    let window = match &window_cmd.within {
        Some(within) => window.within(within.ast.clone()),
        None => window,
    };

    let commands: Vec<Arc<Command>> = pre_lets
        .into_iter()
        .chain(std::iter::once(window.build()))
        .chain(window_restores)
        .chain(post_restores)
        .map(Arc::new)
        .collect();
    Ok(commands)
}
// Table-driven tests for `normalize_window`.
//
// Every case supplies three values: the input pipeline, the pipeline the
// pass is expected to produce, and the expected output schema of the result.
#[cfg(test)]
mod tests {
use super::*;
use hamelin_lib::type_check;
use hamelin_lib::{
tree::ast::expression::IntervalUnit,
tree::{
ast::{identifier::CompoundIdentifier, pipeline::Pipeline},
builder::{
call, drop_command, field_ref, let_command, pipeline, select_command, sort_command,
window_command, IntervalLiteralBuilder,
},
},
types::{struct_type::Struct, INT},
};
use pretty_assertions::assert_eq;
use rstest::rstest;
use std::sync::Arc;
#[rstest]
// No WINDOW command at all: expected pipeline equals the input.
#[case::no_window_passthrough(
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
Struct::default().with_str("a", INT).with_str("b", INT)
)]
// WINDOW with simple identifiers and an identity group-by: expected pipeline equals the input.
#[case::window_simple_ids_unchanged(
pipeline()
.command(select_command().named_field("value", 10).named_field("category", 1).build())
.command(window_command()
.named_field("running", call("sum").arg(field_ref("value")))
.group_by("category", field_ref("category"))
.build())
.build(),
pipeline()
.command(select_command().named_field("value", 10).named_field("category", 1).build())
.command(window_command()
.named_field("running", call("sum").arg(field_ref("value")))
.group_by("category", field_ref("category"))
.build())
.build(),
// Schema order: projections first, then partition_by, then parent fields
// WINDOW binds partition_by fields directly, so category comes before value
Struct::default()
.with_str("running", INT)
.with_str("category", INT)
.with_str("value", INT)
)]
// Compound projection target: WINDOW writes a temporary, then LET restores the compound path and DROP removes the temporary.
#[case::window_compound_projection(
pipeline()
.command(select_command().named_field("value", 10).named_field("category", 1).build())
.command(window_command()
.named_field(
CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
call("sum").arg(field_ref("value"))
)
.group_by("category", field_ref("category"))
.build())
.build(),
pipeline()
.command(select_command().named_field("value", 10).named_field("category", 1).build())
.command(window_command()
.named_field("__normalize_window_0", call("sum").arg(field_ref("value")))
.group_by("category", field_ref("category"))
.build())
.command(let_command()
.named_field(
CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
field_ref("__normalize_window_0")
)
.build())
.command(drop_command().field("__normalize_window_0").build())
.build(),
// Schema order after LET and DROP:
// LET prepends stats, WINDOW had {__normalize_window_0, category, value}, DROP removes temp
// Result: {stats.running, category, value}
Struct::default()
.with_str("stats", Struct::default().with_str("running", INT).into())
.with_str("category", INT)
.with_str("value", INT)
)]
// Compound group-by key: the key is computed by a LET before the WINDOW and restored afterwards.
#[case::window_compound_group_by(
pipeline()
.command(select_command().named_field("value", 10).named_field("cat", 1).build())
.command(window_command()
.named_field("running", call("sum").arg(field_ref("value")))
.group_by(
CompoundIdentifier::new("group".into(), "key".into(), vec![]),
field_ref("cat")
)
.build())
.build(),
pipeline()
.command(select_command().named_field("value", 10).named_field("cat", 1).build())
// LET-before: project compound partition_by into flat column
.command(let_command()
.named_field("__normalize_window_0", field_ref("cat"))
.build())
// WINDOW with simple column ref in group_by
.command(window_command()
.named_field("running", call("sum").arg(field_ref("value")))
.group_by("__normalize_window_0", field_ref("__normalize_window_0"))
.build())
// LET-after: restore compound path
.command(let_command()
.named_field(
CompoundIdentifier::new("group".into(), "key".into(), vec![]),
field_ref("__normalize_window_0")
)
.build())
.command(drop_command().field("__normalize_window_0").build())
.build(),
// Schema: LET-before adds __normalize_window_0, WINDOW adds running,
// LET-after prepends group.key, DROP removes temp
Struct::default()
.with_str(
"group",
Struct::default()
.with_str("key", INT)
.into(),
)
.with_str("running", INT)
.with_str("value", INT)
.with_str("cat", INT)
)]
// Compound projection and compound group-by together: the projection takes temporary 0, the group-by key takes temporary 1.
#[case::window_compound_both(
pipeline()
.command(select_command().named_field("value", 10).named_field("cat", 1).build())
.command(window_command()
.named_field(
CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
call("sum").arg(field_ref("value"))
)
.group_by(
CompoundIdentifier::new("group".into(), "key".into(), vec![]),
field_ref("cat")
)
.build())
.build(),
pipeline()
.command(select_command().named_field("value", 10).named_field("cat", 1).build())
// LET-before: project compound group_by into flat column
.command(let_command()
.named_field("__normalize_window_1", field_ref("cat"))
.build())
// WINDOW: projection temp + simple column ref group_by
.command(window_command()
.named_field("__normalize_window_0", call("sum").arg(field_ref("value")))
.group_by("__normalize_window_1", field_ref("__normalize_window_1"))
.build())
// Projection restore
.command(let_command()
.named_field(
CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
field_ref("__normalize_window_0")
)
.build())
.command(drop_command().field("__normalize_window_0").build())
// Group_by restore
.command(let_command()
.named_field(
CompoundIdentifier::new("group".into(), "key".into(), vec![]),
field_ref("__normalize_window_1")
)
.build())
.command(drop_command().field("__normalize_window_1").build())
.build(),
Struct::default()
.with_str(
"group",
Struct::default()
.with_str("key", INT)
.into(),
)
.with_str("stats", Struct::default().with_str("running", INT).into())
.with_str("value", INT)
.with_str("cat", INT)
)]
// Compound projection on a WINDOW that also has a sort and a `within` interval: both must survive the rewrite.
#[case::window_compound_within(
pipeline()
.command(select_command()
.named_field("value", 10)
.named_field("category", 1)
.named_field("timestamp", 5)
.build())
.command(window_command()
.named_field(
CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
call("sum").arg(field_ref("value"))
)
.group_by("category", field_ref("category"))
.sort(sort_command().by(field_ref("timestamp")))
.within(IntervalLiteralBuilder::new(-5, IntervalUnit::Hour))
.build())
.build(),
pipeline()
.command(select_command()
.named_field("value", 10)
.named_field("category", 1)
.named_field("timestamp", 5)
.build())
.command(window_command()
.named_field("__normalize_window_0", call("sum").arg(field_ref("value")))
.group_by("category", field_ref("category"))
.sort(sort_command().by(field_ref("timestamp")))
.within(IntervalLiteralBuilder::new(-5, IntervalUnit::Hour))
.build())
.command(let_command()
.named_field(
CompoundIdentifier::new("stats".into(), "running".into(), vec![]),
field_ref("__normalize_window_0")
)
.build())
.command(drop_command().field("__normalize_window_0").build())
.build(),
// Schema order: stats prepended by LET, then WINDOW's {category}, then parent's {value, timestamp}
Struct::default()
.with_str("stats", Struct::default().with_str("running", INT).into())
.with_str("category", INT)
.with_str("value", INT)
.with_str("timestamp", INT)
)]
// Runs the pass on the type-checked input and compares both the resulting
// AST and the resulting output schema against the expectations.
fn test_normalize_window(
#[case] input: Pipeline,
#[case] expected: Pipeline,
#[case] expected_output_schema: Struct,
) {
// Both pipelines go through `type_check`; only the `output` field is used.
let input_typed = type_check(input).output;
let expected_typed = type_check(expected).output;
let mut ctx = StatementTranslationContext::default();
let result = normalize_window(Arc::new(input_typed), &mut ctx).unwrap();
// Structural check: the rewritten AST must equal the expected AST.
assert_eq!(result.ast, expected_typed.ast);
// Schema check: the result's environment, viewed as a struct, must match.
let result_schema = result.environment().as_struct().clone();
assert_eq!(result_schema, expected_output_schema);
}
}