mod assignment_tree;
mod freeze;
use assignment_tree::AssignmentTree;
use freeze::FreezeAlgebra;
use std::collections::HashMap;
use std::sync::Arc;
use ordermap::OrderMap;
use hamelin_eval::{eval, Environment};
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::clause::SortOrder;
use hamelin_lib::tree::ast::identifier::{Identifier, SimpleIdentifier};
use hamelin_lib::tree::ast::node::Span;
use hamelin_lib::tree::typed_ast::clause::{Projections, TypedFromClause};
use hamelin_lib::tree::typed_ast::command::{
SideEffect, TypedAggCommand, TypedCommand, TypedCommandKind, TypedExplodeCommand,
TypedFromCommand, TypedJoinCommand, TypedLimitCommand, TypedLookupCommand, TypedSelectCommand,
TypedSortCommand, TypedSortExpression, TypedUnionCommand, TypedWhereCommand,
TypedWindowCommand,
};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::tree::typed_ast::expression::{TypedExpression, TypedExpressionKind};
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
use hamelin_lib::tree::typed_ast::query::TypedStatement;
use hamelin_lib::types::Type;
use crate::window_frame::WindowFrame;
/// Side effect a statement performs in addition to producing rows.
#[derive(Debug, Clone)]
pub enum IRSideEffect {
    /// Pure query: no side effect beyond the row output.
    None,
    /// `APPEND`: write the pipeline output into `table`, de-duplicating on
    /// `distinct_by` columns when that list is non-empty.
    Append {
        table: Identifier,
        distinct_by: Vec<SimpleIdentifier>,
    },
}

/// A fully lowered statement: its CTEs, main pipeline, and side effect.
#[derive(Debug, Clone)]
pub struct IRStatement {
    /// `WITH` clauses in declaration order; later clauses may reference
    /// earlier ones (see `IRStatement::from_typed`).
    pub with_clauses: Vec<IRWithClause>,
    pub pipeline: Arc<IRPipeline>,
    pub side_effect: IRSideEffect,
}

/// A named CTE (`WITH name AS (...)`) holding its lowered pipeline.
#[derive(Debug, Clone)]
pub struct IRWithClause {
    pub name: SimpleIdentifier,
    pub pipeline: Arc<IRPipeline>,
}

/// A lowered pipeline: ordered commands plus the schema of the rows it emits.
#[derive(Debug, Clone)]
pub struct IRPipeline {
    pub commands: Vec<IRCommand>,
    pub output_schema: Arc<TypeEnvironment>,
}

/// One lowered command, with its source span and the schema after it runs.
#[derive(Debug, Clone)]
pub struct IRCommand {
    pub kind: IRCommandKind,
    pub span: Span,
    pub output_schema: Arc<TypeEnvironment>,
}

/// The command set that survives lowering. Other typed commands (LET, DROP,
/// WITHIN, PARSE, UNNEST, MATCH, NEST, APPEND) are expected to have been
/// eliminated by normalization before IR construction — see
/// `IRCommandKind::from_typed` for the corresponding internal errors.
#[derive(Debug, Clone)]
pub enum IRCommandKind {
    From(IRFromCommand),
    Where(IRWhereCommand),
    Select(IRSelectCommand),
    Agg(IRAggCommand),
    Window(IRWindowCommand),
    Sort(IRSortCommand),
    Limit(IRLimitCommand),
    Explode(IRExplodeCommand),
    Join(IRJoinCommand),
}
/// `FROM` (or a lowered `UNION`): the input relations feeding the pipeline.
#[derive(Debug, Clone)]
pub struct IRFromCommand {
    pub inputs: Vec<IRInput>,
}

/// A single pipeline input: a physical table, or a CTE resolved by name to
/// its already-lowered pipeline.
#[derive(Debug, Clone)]
pub enum IRInput {
    Table(Identifier),
    With(SimpleIdentifier, Arc<IRPipeline>),
}

/// `WHERE`: keep rows for which `predicate` holds.
#[derive(Debug, Clone)]
pub struct IRWhereCommand {
    pub predicate: IRExpression,
}

/// `SELECT`: the output assignments. Compound targets (`x.a`, `x.b`) have
/// been packed into one struct-valued assignment per root by
/// `convert_projections`.
#[derive(Debug, Clone)]
pub struct IRSelectCommand {
    pub assignments: Vec<IRAssignment>,
}

/// `AGG`: aggregate expressions, grouping keys, and optional ordering.
#[derive(Debug, Clone)]
pub struct IRAggCommand {
    pub aggregates: Vec<IRAssignment>,
    pub group_by: Vec<IRAssignment>,
    pub sort_by: Vec<IRSortExpression>,
}

/// `WINDOW`: windowed projections over partitions, with optional ordering
/// and an optional constant frame.
#[derive(Debug, Clone)]
pub struct IRWindowCommand {
    pub projections: Vec<IRAssignment>,
    pub partition_by: Vec<IRAssignment>,
    pub sort_by: Vec<IRSortExpression>,
    /// Frame evaluated from a constant `WITHIN` expression, when present.
    pub frame: Option<WindowFrame>,
}

/// `SORT`: order rows by the given expressions.
#[derive(Debug, Clone)]
pub struct IRSortCommand {
    pub sort_by: Vec<IRSortExpression>,
}

/// One sort key: an expression and its direction.
#[derive(Debug, Clone)]
pub struct IRSortExpression {
    pub expression: IRExpression,
    pub order: SortOrder,
}

/// `LIMIT`: keep at most `count` rows; `count` stays an unevaluated expression.
#[derive(Debug, Clone)]
pub struct IRLimitCommand {
    pub count: IRExpression,
}

/// `EXPLODE`: flatten the named column, one output row per element.
#[derive(Debug, Clone)]
pub struct IRExplodeCommand {
    pub column: SimpleIdentifier,
}

/// Join flavor: `JOIN` lowers to `Inner`, `LOOKUP` lowers to `Left`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    Inner,
    Left,
}

/// `JOIN`/`LOOKUP`: combine with the named right-hand relation on `condition`.
#[derive(Debug, Clone)]
pub struct IRJoinCommand {
    pub join_type: JoinType,
    pub right: SimpleIdentifier,
    pub condition: IRExpression,
}

/// A named output column and the expression that computes it.
#[derive(Debug, Clone)]
pub struct IRAssignment {
    pub identifier: SimpleIdentifier,
    pub expression: IRExpression,
}

/// Thin wrapper over a typed expression, shared by reference throughout the IR.
#[derive(Debug, Clone)]
pub struct IRExpression(pub Arc<TypedExpression>);
impl IRExpression {
    /// Wraps an already-typed expression without copying it.
    pub fn new(expr: Arc<TypedExpression>) -> Self {
        Self(expr)
    }

    /// Borrows the underlying typed expression.
    pub fn inner(&self) -> &TypedExpression {
        self.0.as_ref()
    }

    /// The type this expression resolved to during type checking.
    pub fn resolved_type(&self) -> &Type {
        &self.inner().resolved_type
    }

    /// Source span of the originating AST node.
    pub fn span(&self) -> &Span {
        &self.inner().ast.span
    }

    /// Folds constant subexpressions via `FreezeAlgebra`; when nothing can be
    /// folded, the original expression is returned unchanged.
    pub fn freeze(&self) -> Self {
        let mut algebra = FreezeAlgebra;
        let folded = match self.0.cata(&mut algebra) {
            Ok(value) => value.into(),
            Err(unchanged) => unchanged,
        };
        IRExpression::new(folded)
    }
}
impl IRStatement {
    /// Lowers a typed statement into IR.
    ///
    /// `WITH` clauses are lowered first, in declaration order, so each CTE
    /// pipeline can resolve references to earlier CTEs through `cte_map`;
    /// the main pipeline is then lowered against the complete map. An
    /// `APPEND` side effect is captured here (the APPEND command itself is
    /// skipped during pipeline lowering).
    ///
    /// # Errors
    /// Returns a `TranslationError` when a CTE name or an `APPEND DISTINCT BY`
    /// column is not a simple identifier, or when any nested lowering fails.
    pub fn from_typed(
        statement: Arc<TypedStatement>,
        ctx: &mut StatementTranslationContext,
    ) -> Result<Self, Arc<TranslationError>> {
        let mut with_clauses = Vec::new();
        // Name -> lowered pipeline, used to resolve CTE references in later
        // pipelines of this same statement.
        let mut cte_map: HashMap<String, Arc<IRPipeline>> = HashMap::new();
        for wc in &statement.with_clauses {
            let name = wc
                .name
                .valid_ref()?
                .clone()
                .try_unwrap_simple()
                .map_err(|id| {
                    ctx.error(format!("CTE name must be simple identifier, got: {}", id))
                        .emit()
                })?;
            let pipeline = Arc::new(IRPipeline::from_typed(wc.pipeline.clone(), ctx, &cte_map)?);
            cte_map.insert(name.as_str().to_string(), pipeline.clone());
            with_clauses.push(IRWithClause { name, pipeline });
        }
        let pipeline = IRPipeline::from_typed(statement.pipeline.clone(), ctx, &cte_map)?;
        let side_effect = match &statement.side_effect {
            SideEffect::None => IRSideEffect::None,
            SideEffect::Append { table, distinct_by } => {
                let table_id = table.ast.identifier.valid_ref()?.clone();
                let mut lowered_distinct_by = Vec::new();
                for selection in distinct_by {
                    let id = selection
                        .ast
                        .identifier
                        .valid_ref()?
                        .clone()
                        .try_unwrap_simple()
                        .map_err(|id| {
                            ctx.error(format!(
                                "APPEND DISTINCT BY must use simple identifiers, got: {}",
                                id
                            ))
                            .emit()
                        })?;
                    lowered_distinct_by.push(id);
                }
                IRSideEffect::Append {
                    table: table_id,
                    distinct_by: lowered_distinct_by,
                }
            }
        };
        Ok(Self {
            with_clauses,
            pipeline: Arc::new(pipeline),
            side_effect,
        })
    }
}
impl IRPipeline {
    /// Lowers a typed pipeline command-by-command, dropping `APPEND` commands
    /// (those are surfaced as the statement's `IRSideEffect` instead).
    pub fn from_typed(
        pipeline: Arc<TypedPipeline>,
        ctx: &mut StatementTranslationContext,
        cte_map: &HashMap<String, Arc<IRPipeline>>,
    ) -> Result<Self, Arc<TranslationError>> {
        let valid = pipeline.valid_ref()?;
        let mut commands = Vec::with_capacity(valid.commands.len());
        for cmd in &valid.commands {
            // APPEND never becomes a pipeline command.
            if !matches!(cmd.kind, TypedCommandKind::Append(_)) {
                commands.push(IRCommand::from_typed(cmd, ctx, cte_map)?);
            }
        }
        Ok(Self {
            commands,
            output_schema: valid.final_schema.clone(),
        })
    }
}
impl IRCommand {
    /// Lowers one typed command, carrying over its span and output schema.
    pub fn from_typed(
        cmd: &Arc<TypedCommand>,
        ctx: &mut StatementTranslationContext,
        cte_map: &HashMap<String, Arc<IRPipeline>>,
    ) -> Result<Self, Arc<TranslationError>> {
        Ok(Self {
            kind: IRCommandKind::from_typed(&cmd.kind, ctx, cte_map)?,
            span: cmd.ast.span,
            output_schema: cmd.output_schema.clone(),
        })
    }
}
impl IRCommandKind {
    /// Dispatches one typed command kind to its IR lowering.
    ///
    /// Command kinds that the normalization passes are required to have
    /// eliminated (LET, DROP, WITHIN, PARSE, UNNEST, APPEND, MATCH, NEST)
    /// are reported as internal translation errors rather than silently
    /// ignored, so a missing normalization step fails loudly.
    fn from_typed(
        kind: &TypedCommandKind,
        ctx: &mut StatementTranslationContext,
        cte_map: &HashMap<String, Arc<IRPipeline>>,
    ) -> Result<Self, Arc<TranslationError>> {
        match kind {
            TypedCommandKind::From(from_cmd) => {
                Ok(IRFromCommand::from_typed(from_cmd, ctx, cte_map)?.into())
            }
            TypedCommandKind::Where(where_cmd) => Ok(IRWhereCommand::from_typed(where_cmd).into()),
            TypedCommandKind::Select(select_cmd) => {
                Ok(IRSelectCommand::from_typed(select_cmd)?.into())
            }
            TypedCommandKind::Agg(agg_cmd) => Ok(IRAggCommand::from_typed(agg_cmd)?.into()),
            TypedCommandKind::Window(window_cmd) => {
                Ok(IRWindowCommand::from_typed(window_cmd)?.into())
            }
            TypedCommandKind::Sort(sort_cmd) => Ok(IRSortCommand::from_typed(sort_cmd).into()),
            TypedCommandKind::Limit(limit_cmd) => Ok(IRLimitCommand::from_typed(limit_cmd).into()),
            TypedCommandKind::Explode(explode_cmd) => {
                Ok(IRExplodeCommand::from_typed(explode_cmd, ctx)?.into())
            }
            TypedCommandKind::Let(_) => Err(ctx
                .error("LET command should have been fused into SELECT during normalization")
                .emit()),
            TypedCommandKind::Drop(_) => Err(ctx
                .error("DROP command should have been fused into SELECT during normalization")
                .emit()),
            TypedCommandKind::Within(_) => Err(ctx
                .error("WITHIN command should have been converted to WHERE during normalization")
                .emit()),
            TypedCommandKind::Parse(_) => Err(ctx
                .error("PARSE command should have been lowered to LET + WHERE during normalization")
                .emit()),
            TypedCommandKind::Unnest(_) => Err(ctx
                .error("UNNEST command should have been lowered during normalization")
                .emit()),
            TypedCommandKind::Join(join_cmd) => {
                Ok(IRJoinCommand::from_typed_join(join_cmd, ctx)?.into())
            }
            // LOOKUP lowers to a left join.
            TypedCommandKind::Lookup(lookup_cmd) => {
                Ok(IRJoinCommand::from_typed_lookup(lookup_cmd, ctx)?.into())
            }
            TypedCommandKind::Append(_) => Err(ctx
                .error("APPEND command should be skipped during pipeline lowering (captured as IRSideEffect)")
                .emit()),
            // UNION shares the FROM lowering: each branch becomes an input.
            TypedCommandKind::Union(union_cmd) => {
                Ok(IRFromCommand::from_union(union_cmd, ctx, cte_map)?.into())
            }
            TypedCommandKind::Match(_) => Err(ctx
                .error("MATCH command should have been lowered during normalization")
                .emit()),
            TypedCommandKind::Nest(_) => Err(ctx
                .error("NEST command should have been lowered to SELECT during normalization")
                .emit()),
            // Errors recorded during type checking propagate unchanged.
            TypedCommandKind::Error(err) => Err(err.clone()),
        }
    }
}
impl IRFromCommand {
fn from_typed(
from_cmd: &TypedFromCommand,
ctx: &mut StatementTranslationContext,
cte_map: &HashMap<String, Arc<IRPipeline>>,
) -> Result<Self, Arc<TranslationError>> {
let mut inputs = Vec::new();
for clause in &from_cmd.clauses {
inputs.push(IRInput::from_typed(clause, ctx, cte_map)?);
}
Ok(Self { inputs })
}
fn from_union(
union_cmd: &TypedUnionCommand,
ctx: &mut StatementTranslationContext,
cte_map: &HashMap<String, Arc<IRPipeline>>,
) -> Result<Self, Arc<TranslationError>> {
let mut inputs = Vec::new();
for clause in &union_cmd.clauses {
inputs.push(IRInput::from_typed(clause, ctx, cte_map)?);
}
Ok(Self { inputs })
}
}
impl IRInput {
    /// Lowers a single FROM clause. A simple name that matches a CTE resolves
    /// to that CTE's lowered pipeline and shadows any physical table.
    fn from_typed(
        clause: &TypedFromClause,
        ctx: &mut StatementTranslationContext,
        cte_map: &HashMap<String, Arc<IRPipeline>>,
    ) -> Result<Self, Arc<TranslationError>> {
        match clause {
            TypedFromClause::Reference(table_ref) => {
                let identifier = table_ref.ast.identifier.valid_ref()?.clone();
                let cte_input = match &identifier {
                    Identifier::Simple(simple) => cte_map
                        .get(simple.as_str())
                        .map(|pipeline| IRInput::With(simple.clone(), pipeline.clone())),
                    _ => None,
                };
                Ok(cte_input.unwrap_or(IRInput::Table(identifier)))
            }
            TypedFromClause::Alias(_) => Err(ctx
                .error("FROM aliases should have been converted to CTEs during normalization")
                .emit()),
            TypedFromClause::Error(err) => Err(err.clone()),
        }
    }
}
impl IRWhereCommand {
fn from_typed(where_cmd: &TypedWhereCommand) -> Self {
Self {
predicate: IRExpression::new(where_cmd.predicate.clone()),
}
}
}
impl IRSelectCommand {
    /// Lowers SELECT projections, packing compound targets into struct values.
    fn from_typed(select_cmd: &TypedSelectCommand) -> Result<Self, Arc<TranslationError>> {
        Ok(Self {
            assignments: convert_projections(&select_cmd.projections)?,
        })
    }
}
impl IRAggCommand {
    /// Lowers an AGG command: aggregates, grouping keys, and optional sort.
    fn from_typed(agg_cmd: &TypedAggCommand) -> Result<Self, Arc<TranslationError>> {
        Ok(Self {
            aggregates: convert_projections(&agg_cmd.aggregates)?,
            group_by: convert_projections(&agg_cmd.group_by)?,
            sort_by: convert_sort_expressions(&agg_cmd.sort_by),
        })
    }
}
impl IRWindowCommand {
    /// Lowers a WINDOW command. A `WITHIN` expression, when present, must be
    /// constant and is evaluated eagerly into a `WindowFrame`.
    fn from_typed(window_cmd: &TypedWindowCommand) -> Result<Self, Arc<TranslationError>> {
        let projections = convert_projections(&window_cmd.projections)?;
        let partition_by = convert_projections(&window_cmd.group_by)?;
        let sort_by = convert_sort_expressions(&window_cmd.sort_by);
        let frame = match &window_cmd.within {
            Some(within_expr) => Some(eval_within_to_frame(within_expr)?),
            None => None,
        };
        Ok(Self {
            projections,
            partition_by,
            sort_by,
            frame,
        })
    }
}
/// Evaluates a constant `WITHIN` expression into a window frame.
///
/// Rejects, in order: expressions containing embedded type-check errors
/// (propagated verbatim), calls to non-deterministic functions, and
/// expressions that fail constant evaluation (e.g. column references).
fn eval_within_to_frame(
    within_expr: &TypedExpression,
) -> Result<WindowFrame, Arc<TranslationError>> {
    use hamelin_lib::err::Context;
    use hamelin_lib::tree::typed_ast::expression::{TypedErrorExpression, TypedExpressionKind};

    // Builds a TranslationError anchored at `expr`'s source span.
    fn err_at(expr: &TypedExpression, message: &str) -> Arc<TranslationError> {
        let span = expr.ast.span.to_range().unwrap_or(0..=0);
        TranslationError::new(Context::new(span, message)).into()
    }

    // Any error node embedded in the expression wins over our own diagnostics.
    if let Some(err_expr) =
        within_expr.find(&mut |e| matches!(&e.kind, TypedExpressionKind::Error(_)))
    {
        if let TypedExpressionKind::Error(TypedErrorExpression { error }) = &err_expr.kind {
            return Err(error.clone());
        }
    }
    // Non-deterministic calls cannot appear in a constant frame.
    if let Some(bad_func_expr) = within_expr.find(&mut |e| {
        matches!(&e.kind, TypedExpressionKind::Apply(apply) if !apply.function_def.is_deterministic())
    }) {
        if let TypedExpressionKind::Apply(apply) = &bad_func_expr.kind {
            return Err(err_at(
                bad_func_expr,
                &format!(
                    "WITHIN expression cannot use non-deterministic function '{}' - window frames must be constant",
                    apply.function_def.name()
                ),
            ));
        }
    }
    // Evaluate against an empty environment: any reference to row data fails.
    match eval(within_expr, &Environment::default()) {
        Ok(value) => WindowFrame::from_value(value)
            .map_err(|msg| err_at(within_expr, &format!("Invalid window frame: {}", msg))),
        Err(eval_err) => Err(err_at(
            within_expr,
            &format!(
                "WITHIN expression must be constant (cannot reference columns): {}",
                eval_err
            ),
        )),
    }
}
impl IRSortCommand {
    /// Lowers a SORT command's expression list.
    fn from_typed(sort_cmd: &TypedSortCommand) -> Self {
        let sort_by = convert_sort_expressions(&sort_cmd.expressions);
        Self { sort_by }
    }
}
impl IRLimitCommand {
fn from_typed(limit_cmd: &TypedLimitCommand) -> Self {
Self {
count: IRExpression::new(limit_cmd.count.clone()),
}
}
}
impl IRExplodeCommand {
    /// Lowers an EXPLODE command.
    ///
    /// Normalization is expected to have rewritten every EXPLODE into the
    /// canonical `EXPLODE col = col` shape, so the expression must be a plain
    /// column reference to the exploded column itself.
    fn from_typed(
        explode_cmd: &TypedExplodeCommand,
        ctx: &mut StatementTranslationContext,
    ) -> Result<Self, Arc<TranslationError>> {
        let column = explode_cmd
            .identifier
            .valid_ref()?
            .clone()
            .try_unwrap_simple()
            .map_err(|id| {
                ctx.error(format!(
                    "EXPLODE identifier must be simple after normalization, got: {}",
                    id
                ))
                .emit()
            })?;
        // The expression must reference exactly the exploded column.
        let references_column = matches!(
            &explode_cmd.expression.kind,
            TypedExpressionKind::ColumnReference(col_ref)
                if col_ref.column_name.valid_ref()
                    .is_ok_and(|name| name.as_str() == column.as_str())
        );
        if references_column {
            Ok(Self { column })
        } else {
            Err(ctx
                .error(format!(
                    "EXPLODE must be in canonical form (EXPLODE {0} = {0}) after normalization, \
                    but expression is not a column reference to '{0}'",
                    column
                ))
                .emit())
        }
    }
}
impl IRJoinCommand {
    /// Lowers a JOIN command (inner join).
    fn from_typed_join(
        join_cmd: &TypedJoinCommand,
        ctx: &mut StatementTranslationContext,
    ) -> Result<Self, Arc<TranslationError>> {
        Self::from_typed_inner(JoinType::Inner, &join_cmd.right, &join_cmd.condition, ctx)
    }

    /// Lowers a LOOKUP command (left join).
    fn from_typed_lookup(
        lookup_cmd: &TypedLookupCommand,
        ctx: &mut StatementTranslationContext,
    ) -> Result<Self, Arc<TranslationError>> {
        Self::from_typed_inner(JoinType::Left, &lookup_cmd.right, &lookup_cmd.condition, ctx)
    }

    /// Shared JOIN/LOOKUP lowering. The right side must be a simple name and
    /// the condition must be present — both guaranteed by the `lower_joins`
    /// normalization pass; violations are internal errors.
    fn from_typed_inner(
        join_type: JoinType,
        right: &hamelin_lib::tree::typed_ast::clause::TypedTableAlias,
        condition: &Option<Arc<TypedExpression>>,
        ctx: &mut StatementTranslationContext,
    ) -> Result<Self, Arc<TranslationError>> {
        let right_id = right
            .ast
            .table
            .identifier
            .valid_ref()?
            .clone()
            .try_unwrap_simple()
            .map_err(|id| {
                ctx.error(format!(
                    "JOIN right side must be simple identifier after lowering, got: {}",
                    id
                ))
                .emit()
            })?;
        let Some(condition) = condition.as_ref() else {
            return Err(ctx
                .error("JOIN condition should be present after lower_joins normalization")
                .emit());
        };
        Ok(Self {
            join_type,
            right: right_id,
            condition: IRExpression::new(Arc::clone(condition)),
        })
    }
}
/// Groups projection assignments by their root identifier, packing compound
/// targets (`x.a = ...`, `x.b = ...`) into a single struct-valued assignment
/// per root while preserving first-appearance order (hence `OrderMap`).
fn convert_projections(
    projections: &Projections,
) -> Result<Vec<IRAssignment>, Arc<TranslationError>> {
    let mut groups: OrderMap<SimpleIdentifier, AssignmentTree> = OrderMap::new();
    for assignment in &projections.assignments {
        let expression = assignment.expression.clone();
        match assignment.identifier.valid_ref()? {
            Identifier::Simple(simple) => {
                let tree = groups.entry(simple.clone()).or_default();
                tree.insert_leaf(expression);
            }
            Identifier::Compound(compound) => {
                // Root becomes the assignment name; the remaining path is the
                // location inside the packed struct.
                let tree = groups.entry(compound.first()).or_default();
                tree.insert_at_path(&compound.parts[1..], expression);
            }
        }
    }
    let assignments = groups
        .into_iter()
        .map(|(identifier, tree)| IRAssignment {
            identifier,
            expression: tree.into_ir_expression(),
        })
        .collect();
    Ok(assignments)
}
/// Lowers typed sort expressions one-to-one, keeping each sort direction.
fn convert_sort_expressions(exprs: &[TypedSortExpression]) -> Vec<IRSortExpression> {
    let mut lowered = Vec::with_capacity(exprs.len());
    for sort_expr in exprs {
        lowered.push(IRSortExpression {
            expression: IRExpression::new(sort_expr.expression.clone()),
            order: sort_expr.order.clone(),
        });
    }
    lowered
}
// Ergonomic conversions so each lowered command can be returned as an
// `IRCommandKind` via `.into()` (used throughout `IRCommandKind::from_typed`).
impl From<IRFromCommand> for IRCommandKind {
    fn from(cmd: IRFromCommand) -> Self {
        IRCommandKind::From(cmd)
    }
}
impl From<IRWhereCommand> for IRCommandKind {
    fn from(cmd: IRWhereCommand) -> Self {
        IRCommandKind::Where(cmd)
    }
}
impl From<IRSelectCommand> for IRCommandKind {
    fn from(cmd: IRSelectCommand) -> Self {
        IRCommandKind::Select(cmd)
    }
}
impl From<IRAggCommand> for IRCommandKind {
    fn from(cmd: IRAggCommand) -> Self {
        IRCommandKind::Agg(cmd)
    }
}
impl From<IRWindowCommand> for IRCommandKind {
    fn from(cmd: IRWindowCommand) -> Self {
        IRCommandKind::Window(cmd)
    }
}
impl From<IRSortCommand> for IRCommandKind {
    fn from(cmd: IRSortCommand) -> Self {
        IRCommandKind::Sort(cmd)
    }
}
impl From<IRLimitCommand> for IRCommandKind {
    fn from(cmd: IRLimitCommand) -> Self {
        IRCommandKind::Limit(cmd)
    }
}
impl From<IRExplodeCommand> for IRCommandKind {
    fn from(cmd: IRExplodeCommand) -> Self {
        IRCommandKind::Explode(cmd)
    }
}
impl From<IRJoinCommand> for IRCommandKind {
    fn from(cmd: IRJoinCommand) -> Self {
        IRCommandKind::Join(cmd)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_eval::{eval, value::Value, Environment};
    use hamelin_lib::tree::{
        ast::{pipeline::Pipeline, IntoTyped, TypeCheckExecutor},
        builder::{ident, pipeline, select_command},
        typed_ast::expression::TypedExpressionKind,
    };
    use hamelin_lib::types::{struct_type::Struct, INT, STRING};
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Type-checks `pipeline` and extracts its first command as a lowered SELECT.
    fn get_ir_select(pipeline: Pipeline) -> IRSelectCommand {
        let typed = pipeline.typed_with().typed();
        let select_cmd = typed.valid_ref().unwrap().commands[0].clone();
        if let TypedCommandKind::Select(select) = &select_cmd.kind {
            IRSelectCommand::from_typed(select).unwrap()
        } else {
            panic!("Expected SELECT command");
        }
    }

    /// Convenience accessor: (column name, resolved type) of an assignment.
    fn assignment_info(assignment: &IRAssignment) -> (String, &Type) {
        (
            assignment.identifier.to_string(),
            assignment.expression.resolved_type(),
        )
    }

    #[rstest]
    #[case::simple_assignments(
        pipeline()
            .command(select_command()
                .named_field("a", 1)
                .named_field("b", "hello")
                .build())
            .build(),
        vec![("a", INT.clone()), ("b", STRING.clone())]
    )]
    fn test_simple_assignments_no_packing(
        #[case] input: Pipeline,
        #[case] expected: Vec<(&str, Type)>,
    ) {
        let ir_select = get_ir_select(input);
        assert_eq!(ir_select.assignments.len(), expected.len());
        for (assignment, (expected_name, expected_type)) in
            ir_select.assignments.iter().zip(expected.iter())
        {
            let (name, resolved) = assignment_info(assignment);
            assert_eq!(name, *expected_name);
            assert_eq!(resolved, expected_type);
        }
    }

    #[rstest]
    #[case::compound_same_root_packs_to_struct(
        pipeline()
            .command(select_command()
                .named_field(ident("x").dot("a"), 1)
                .named_field(ident("x").dot("b"), "hello")
                .build())
            .build(),
        "x",
        Struct::default().with_str("a", INT).with_str("b", STRING).into()
    )]
    #[case::single_compound_packs(
        pipeline()
            .command(select_command()
                .named_field(ident("user").dot("id"), 42)
                .build())
            .build(),
        "user",
        Struct::default().with_str("id", INT).into()
    )]
    fn test_compound_identifiers_pack_to_struct(
        #[case] input: Pipeline,
        #[case] expected_name: &str,
        #[case] expected_type: Type,
    ) {
        let ir_select = get_ir_select(input);
        assert_eq!(ir_select.assignments.len(), 1);
        let (name, resolved) = assignment_info(&ir_select.assignments[0]);
        assert_eq!(name, expected_name);
        assert_eq!(resolved, &expected_type);
    }

    #[rstest]
    #[case::deep_nesting(
        pipeline()
            .command(select_command()
                .named_field(ident("a").dot("b").dot("c"), 1)
                .build())
            .build(),
        "a",
        Struct::default().with_str("b",
            Struct::default().with_str("c", INT).into()).into()
    )]
    fn test_deep_nesting_packs_correctly(
        #[case] input: Pipeline,
        #[case] expected_name: &str,
        #[case] expected_type: Type,
    ) {
        let ir_select = get_ir_select(input);
        assert_eq!(ir_select.assignments.len(), 1);
        let (name, resolved) = assignment_info(&ir_select.assignments[0]);
        assert_eq!(name, expected_name);
        assert_eq!(resolved, &expected_type);
    }

    #[rstest]
    #[case::mixed_simple_and_compound(
        pipeline()
            .command(select_command()
                .named_field("simple", 1)
                .named_field(ident("nested").dot("field"), 2)
                .build())
            .build(),
        vec![
            ("simple", INT.clone()),
            ("nested", Struct::default().with_str("field", INT).into()),
        ]
    )]
    #[case::multiple_roots(
        pipeline()
            .command(select_command()
                .named_field(ident("a").dot("x"), 1)
                .named_field(ident("b").dot("y"), 2)
                .build())
            .build(),
        vec![
            ("a", Struct::default().with_str("x", INT).into()),
            ("b", Struct::default().with_str("y", INT).into()),
        ]
    )]
    fn test_mixed_assignments(#[case] input: Pipeline, #[case] expected: Vec<(&str, Type)>) {
        let ir_select = get_ir_select(input);
        assert_eq!(ir_select.assignments.len(), expected.len());
        for (assignment, (expected_name, expected_type)) in
            ir_select.assignments.iter().zip(expected.iter())
        {
            let (name, resolved) = assignment_info(assignment);
            assert_eq!(name, *expected_name);
            assert_eq!(resolved, expected_type);
        }
    }

    #[test]
    fn test_order_preserved() {
        let input = pipeline()
            .command(
                select_command()
                    .named_field(ident("z").dot("first"), 1)
                    .named_field(ident("a").dot("second"), 2)
                    .named_field(ident("z").dot("third"), 3)
                    .build(),
            )
            .build();
        let ir_select = get_ir_select(input);
        // Roots keep first-appearance order: "z" was seen before "a".
        assert_eq!(ir_select.assignments.len(), 2);
        assert_eq!(ir_select.assignments[0].identifier.to_string(), "z");
        assert_eq!(ir_select.assignments[1].identifier.to_string(), "a");
        let z_type = ir_select.assignments[0].expression.resolved_type();
        let expected_z = Struct::default()
            .with_str("first", INT)
            .with_str("third", INT);
        assert_eq!(z_type, &Type::from(expected_z));
    }

    #[test]
    fn test_expressions_preserved_in_packed_structs() {
        let input = pipeline()
            .command(
                select_command()
                    .named_field(ident("nested").dot("int_field"), 42)
                    .named_field(ident("nested").dot("str_field"), "hello")
                    .build(),
            )
            .build();
        let ir_select = get_ir_select(input);
        assert_eq!(ir_select.assignments.len(), 1);
        let nested = &ir_select.assignments[0];
        assert_eq!(nested.identifier.to_string(), "nested");
        let expected_type: Type = Struct::default()
            .with_str("int_field", INT)
            .with_str("str_field", STRING)
            .into();
        assert_eq!(nested.expression.resolved_type(), &expected_type);
    }

    #[test]
    fn test_freeze_resolves_now_in_within() {
        use crate::lower::lower;
        use hamelin_lib::tree::builder::{call, hours, string};
        let input = pipeline()
            .command(
                select_command()
                    .named_field("timestamp", call("ts").arg(string("2024-01-15T12:00:00Z")))
                    .build(),
            )
            .within(hours(-5))
            .build();
        let typed = input.typed_with().typed();
        let query = hamelin_lib::tree::builder::query()
            .main(Arc::new(typed))
            .build();
        let typed_query = query.typed_with().typed();
        let ir = lower(Arc::new(typed_query)).expect("lowering should succeed");
        let where_cmd = ir
            .pipeline
            .commands
            .iter()
            .find(|cmd| matches!(cmd.kind, IRCommandKind::Where(_)))
            .expect("should have WHERE command");
        let IRCommandKind::Where(where_cmd) = &where_cmd.kind else {
            panic!("expected WHERE command");
        };
        let frozen = where_cmd.predicate.freeze();

        // True when `expr` contains a call to the function named `name`
        // anywhere in its tree. Uses `TypedExpression::find`, the same
        // subtree search `eval_within_to_frame` relies on, instead of the
        // two hand-rolled traversals this test previously duplicated.
        fn has_apply_named(expr: &TypedExpression, name: &str) -> bool {
            expr.find(&mut |e| {
                matches!(&e.kind, TypedExpressionKind::Apply(apply)
                    if apply.function_def.name() == name)
            })
            .is_some()
        }

        assert!(
            has_apply_named(where_cmd.predicate.inner(), "now"),
            "before freeze should contain now() calls"
        );
        assert!(
            !has_apply_named(frozen.inner(), "now"),
            "frozen expression should not contain now() calls"
        );
        assert!(
            has_apply_named(frozen.inner(), "ts"),
            "frozen expression should contain ts() calls for resolved timestamps"
        );

        // Collects the evaluated instant of every `ts(...)` call in `expr`
        // (a full traversal is still needed here: `find` stops at the first
        // match, and we want all of them).
        fn collect_ts_timestamps(expr: &TypedExpression) -> Vec<chrono::DateTime<chrono::Utc>> {
            fn walk(expr: &TypedExpression, out: &mut Vec<chrono::DateTime<chrono::Utc>>) {
                match &expr.kind {
                    TypedExpressionKind::Apply(apply) => {
                        if apply.function_def.name() == "ts" {
                            let env = Environment::default();
                            if let Ok(Value::Timestamp(ts)) = eval(expr, &env) {
                                out.push(*ts.instant());
                            }
                        }
                        for arg in apply.parameter_binding.iter() {
                            walk(arg, out);
                        }
                    }
                    TypedExpressionKind::ArrayLiteral(arr) => {
                        for e in &arr.elements {
                            walk(e, out);
                        }
                    }
                    TypedExpressionKind::TupleLiteral(tup) => {
                        for e in &tup.elements {
                            walk(e, out);
                        }
                    }
                    TypedExpressionKind::StructLiteral(s) => {
                        for (_, e) in &s.fields {
                            walk(e, out);
                        }
                    }
                    TypedExpressionKind::VariantIndexAccess(v) => walk(&v.value, out),
                    TypedExpressionKind::FieldLookup(f) => walk(&f.value, out),
                    TypedExpressionKind::Cast(c) => walk(&c.value, out),
                    TypedExpressionKind::TsTrunc(t) => walk(&t.expression, out),
                    TypedExpressionKind::BroadcastApply(b) => {
                        for arg in b.parameter_binding.iter() {
                            walk(arg, out);
                        }
                    }
                    TypedExpressionKind::ColumnReference(_)
                    | TypedExpressionKind::Leaf
                    | TypedExpressionKind::Lambda(_)
                    | TypedExpressionKind::Error(_) => {}
                }
            }
            let mut timestamps = Vec::new();
            walk(expr, &mut timestamps);
            timestamps
        }

        let timestamps = collect_ts_timestamps(frozen.inner());
        assert_eq!(
            timestamps.len(),
            2,
            "should have exactly 2 timestamp literals"
        );
        let now = chrono::Utc::now();
        let five_hours = chrono::Duration::hours(5);
        let one_minute = chrono::Duration::minutes(1);
        let (earlier, later) = if timestamps[0] < timestamps[1] {
            (timestamps[0], timestamps[1])
        } else {
            (timestamps[1], timestamps[0])
        };
        let expected_earlier = now - five_hours;
        let expected_later = now;
        assert!(
            (earlier - expected_earlier).abs() < one_minute,
            "earlier timestamp {:?} should be within 1 minute of now - 5h ({:?})",
            earlier,
            expected_earlier
        );
        assert!(
            (later - expected_later).abs() < one_minute,
            "later timestamp {:?} should be within 1 minute of now ({:?})",
            later,
            expected_later
        );
    }
}