use std::sync::Arc;
use chrono::Duration;
use hamelin_eval::value::{RangeValue, TimestampValue, Value};
use hamelin_eval::{eval, Environment};
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::expression::{Expression, ExpressionKind};
use hamelin_lib::tree::ast::ops::{BinaryOp, UnaryPostfixOp, UnaryPrefixOp};
use hamelin_lib::tree::{
ast::command::Command,
builder::{
self, add, and, call, cast, field, gte, lt, lte, string, where_command, BoolLiteralBuilder,
ExpressionBuilder, IntoExpressionBuilder,
},
typed_ast::{
command::{TypedCommand, TypedCommandKind, TypedWithinCommand},
context::StatementTranslationContext,
pipeline::TypedPipeline,
},
};
use hamelin_lib::types::{range::Range, Type, TIMESTAMP};
/// Rewrites every `WITHIN` command of `pipeline` into a `WHERE` filter on the context's
/// timestamp field.
///
/// A pipeline that contains no `WITHIN` command is returned as-is (the very same `Arc`), so
/// the common case does no rebuilding. Otherwise every command is expanded through
/// `normalize_command` (non-`WITHIN` commands are copied through unchanged). The rewritten
/// AST keeps the original pipeline's span and is turned back into a `TypedPipeline` via
/// `TypedPipeline::from_ast_with_context` using `ctx`.
///
/// # Errors
///
/// Propagates the error returned by `valid_ref()` on the input pipeline.
pub fn normalize_within(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;
    let has_within = valid
        .commands
        .iter()
        .any(|command| matches!(&command.kind, TypedCommandKind::Within(_)));
    if !has_within {
        // Nothing to rewrite: hand back the original pipeline untouched.
        return Ok(pipeline);
    }

    // Each source command expands to zero or more replacement commands; splice them, in
    // order, into a fresh pipeline builder.
    let rebuilt = valid
        .commands
        .iter()
        .flat_map(|command| normalize_command(command, ctx))
        .fold(builder::pipeline(), |acc, replacement| acc.command(replacement));

    let new_ast = rebuilt.build().at(pipeline.ast.span);
    Ok(Arc::new(TypedPipeline::from_ast_with_context(
        Arc::new(new_ast),
        ctx,
    )))
}
/// Expands one typed command into the commands that replace it in the normalized pipeline.
///
/// A `WITHIN` command is delegated to `transform_within`, which yields either a single `WHERE`
/// command or nothing at all. Every other command is passed through as a one-element vector
/// holding its original AST node.
fn normalize_command(
    cmd: &Arc<TypedCommand>,
    ctx: &mut StatementTranslationContext,
) -> Vec<Arc<Command>> {
    match &cmd.kind {
        TypedCommandKind::Within(within_cmd) => transform_within(within_cmd, ctx, cmd),
        _ => vec![cmd.ast.clone()],
    }
}
fn transform_within(
within_cmd: &TypedWithinCommand,
ctx: &mut StatementTranslationContext,
cmd: &TypedCommand,
) -> Vec<Arc<Command>> {
let env = Environment::new();
let eval_result = eval(&within_cmd.duration, &env);
let condition: Box<dyn ExpressionBuilder> = match eval_result {
Ok(Value::Interval(duration)) => {
build_interval_condition(ctx, &within_cmd.duration.ast, duration < Duration::zero())
}
Ok(Value::CalendarInterval(months)) => {
build_interval_condition(ctx, &within_cmd.duration.ast, months < 0)
}
Ok(Value::Range(range)) => {
let end_inclusive =
matches!(&*within_cmd.duration.resolved_type, Type::RangeInclusive(_));
let has_interval_bounds = matches!(
(&range.lower, &range.upper),
(Some(Value::Interval(_)), _)
| (_, Some(Value::Interval(_)))
| (Some(Value::CalendarInterval(_)), _)
| (_, Some(Value::CalendarInterval(_)))
);
if has_interval_bounds {
build_interval_range_condition(ctx, &within_cmd.duration.ast, end_inclusive)
} else {
build_timestamp_range_condition(ctx, &range, end_inclusive)
}
}
Ok(Value::Timestamp(ts)) => {
let now = chrono::Utc::now();
build_timestamp_condition(ctx, &ts, *ts.instant() > now)
}
Ok(Value::Null) => return vec![],
_ => {
let end_inclusive =
matches!(&*within_cmd.duration.resolved_type, Type::RangeInclusive(_));
let timestamp_range_type: Type = if end_inclusive {
Type::RangeInclusive(Range::new(TIMESTAMP))
} else {
Range::new(TIMESTAMP).into()
};
let range_expr = cast(
within_cmd.duration.ast.as_ref().clone(),
timestamp_range_type,
)
.build();
build_dynamic_condition(ctx, range_expr, end_inclusive)
}
};
vec![Arc::new(where_command(condition).at(cmd.ast.span).build())]
}
/// Builds `lower <= timestamp < upper` for a constant range whose bounds evaluated to
/// timestamps. When `end_inclusive` is set, the upper comparison is `<=` instead.
///
/// A bound held as a `Value::Timestamp` is emitted as a `ts("<rfc3339>")` literal. A bound
/// that is absent, or holds any other kind of value, is replaced by a runtime `now()` call.
fn build_timestamp_range_condition(
    ctx: &StatementTranslationContext,
    range: &RangeValue,
    end_inclusive: bool,
) -> Box<dyn ExpressionBuilder> {
    let column = ctx.timestamp_field.clone().into_expression_builder();
    let lower_ts = range.lower.as_ref().and_then(extract_timestamp);
    let upper_ts = range.upper.as_ref().and_then(extract_timestamp);

    let lower_cmp: Box<dyn ExpressionBuilder> = if let Some(ts) = lower_ts {
        Box::new(gte(column.build(), timestamp_literal(&ts)))
    } else {
        Box::new(gte(column.build(), call("now")))
    };

    let upper_cmp: Box<dyn ExpressionBuilder> = match (upper_ts, end_inclusive) {
        (Some(ts), true) => Box::new(lte(column.build(), timestamp_literal(&ts))),
        (Some(ts), false) => Box::new(lt(column.build(), timestamp_literal(&ts))),
        (None, true) => Box::new(lte(column.build(), call("now"))),
        (None, false) => Box::new(lt(column.build(), call("now"))),
    };

    Box::new(and(lower_cmp, upper_cmp))
}
/// Builds a window of length `original_interval_expr` anchored at a runtime `now()`.
///
/// When `is_negative` is false, the window reaches forward: `now() <= timestamp < now() + interval`.
/// Otherwise it reaches backward: `now() + interval <= timestamp < now()`. The interval is
/// re-emitted as its original expression, and `now()` is kept as a call rather than folded
/// into a constant.
fn build_interval_condition(
    ctx: &StatementTranslationContext,
    original_interval_expr: &Expression,
    is_negative: bool,
) -> Box<dyn ExpressionBuilder> {
    let column = ctx.timestamp_field.clone().into_expression_builder();

    if !is_negative {
        // Forward-looking window: [now, now + interval).
        return Box::new(and(
            gte(column.build(), call("now")),
            lt(
                column.build(),
                add(call("now"), original_interval_expr.clone()),
            ),
        ));
    }

    // Backward-looking window: [now + interval, now).
    Box::new(and(
        gte(
            column.build(),
            add(call("now").build(), original_interval_expr.clone()),
        ),
        lt(column.build(), call("now")),
    ))
}
/// Builds `now() + lower <= timestamp < now() + upper` for a range whose bounds are
/// intervals. When `end_inclusive` is set, the upper comparison is `<=` instead.
///
/// The bound sub-expressions are taken syntactically from `original_range_expr`:
/// * a binary `Range` / `RangeInclusive` operator supplies both bounds;
/// * a postfix `Range` operator supplies only the lower bound;
/// * a prefix `Range` / `RangeInclusive` operator supplies only the upper bound.
///
/// A missing side becomes `now()` itself. Any other expression shape falls back to a
/// runtime cast to a timestamp range, filtered by `build_dynamic_condition`.
fn build_interval_range_condition(
    ctx: &StatementTranslationContext,
    original_range_expr: &Expression,
    end_inclusive: bool,
) -> Box<dyn ExpressionBuilder> {
    let (lower_offset, upper_offset) = match &original_range_expr.kind {
        ExpressionKind::BinaryOperator(op)
            if op.operator == BinaryOp::Range || op.operator == BinaryOp::RangeInclusive =>
        {
            (Some(op.left.as_ref()), Some(op.right.as_ref()))
        }
        ExpressionKind::UnaryPostfixOperator(op) if op.operator == UnaryPostfixOp::Range => {
            (Some(op.operand.as_ref()), None)
        }
        ExpressionKind::UnaryPrefixOperator(op)
            if op.operator == UnaryPrefixOp::Range
                || op.operator == UnaryPrefixOp::RangeInclusive =>
        {
            (None, Some(op.operand.as_ref()))
        }
        _ => {
            // Not a recognizable range literal: let a runtime cast produce the bounds.
            let target_type: Type = if end_inclusive {
                Type::RangeInclusive(Range::new(TIMESTAMP))
            } else {
                Range::new(TIMESTAMP).into()
            };
            let as_range = cast(original_range_expr.clone(), target_type).build();
            return build_dynamic_condition(ctx, as_range, end_inclusive);
        }
    };

    // Each present offset becomes `now() + offset`; an absent one is just `now()`.
    let lower_bound = lower_offset.map_or_else(
        || call("now").build(),
        |offset| add(call("now"), offset.clone()).build(),
    );
    let upper_bound = upper_offset.map_or_else(
        || call("now").build(),
        |offset| add(call("now"), offset.clone()).build(),
    );

    let column = ctx.timestamp_field.clone().into_expression_builder();
    let lower_cmp = gte(column.build(), lower_bound);
    let upper_cmp = if end_inclusive {
        lte(column.build(), upper_bound)
    } else {
        lt(column.build(), upper_bound)
    };
    Box::new(and(lower_cmp, upper_cmp))
}
/// Builds the window between the constant timestamp `ts` and a runtime `now()`.
///
/// When `is_future` is set, the window is `now() <= timestamp < ts`. Otherwise it is
/// `ts <= timestamp < now()`. In both cases the upper end is exclusive.
fn build_timestamp_condition(
    ctx: &StatementTranslationContext,
    ts: &TimestampValue,
    is_future: bool,
) -> Box<dyn ExpressionBuilder> {
    let column = ctx.timestamp_field.clone().into_expression_builder();

    if !is_future {
        // Anchor is in the past: [ts, now).
        return Box::new(and(
            gte(column.build(), timestamp_literal(ts)),
            lt(column.build(), call("now")),
        ));
    }

    // Anchor is in the future: [now, ts).
    Box::new(and(
        gte(column.build(), call("now")),
        lt(column.build(), timestamp_literal(ts)),
    ))
}
/// Builds a runtime filter against `range_expr`. Both callers pass a cast of the `WITHIN`
/// duration to a timestamp range type.
///
/// The result is:
///
/// `coalesce(timestamp >= range_expr.begin, true) AND coalesce(timestamp < range_expr.end, true)`
///
/// When `end_inclusive` is set, the end comparison uses `<=`. The `coalesce(..., true)`
/// wrapping turns a null comparison result into `true`. Presumably this means a null bound
/// imposes no restriction, assuming comparisons against null yield null in the target
/// language.
fn build_dynamic_condition(
    ctx: &StatementTranslationContext,
    range_expr: Expression,
    end_inclusive: bool,
) -> Box<dyn ExpressionBuilder> {
    let column = ctx.timestamp_field.clone().into_expression_builder();
    let range_begin = field(range_expr.clone(), "begin");
    let range_end = field(range_expr, "end");

    let after_begin = call("coalesce")
        .arg(gte(column.build(), range_begin))
        .arg(BoolLiteralBuilder::new(true));
    let before_end = match end_inclusive {
        true => call("coalesce")
            .arg(lte(column.build(), range_end))
            .arg(BoolLiteralBuilder::new(true)),
        false => call("coalesce")
            .arg(lt(column.build(), range_end))
            .arg(BoolLiteralBuilder::new(true)),
    };

    Box::new(and(after_begin, before_end))
}
/// Returns an owned copy of the timestamp held by `v`, or `None` for any other kind of value.
fn extract_timestamp(v: &Value) -> Option<TimestampValue> {
    if let Value::Timestamp(ts) = v {
        Some(ts.clone())
    } else {
        None
    }
}
/// Renders `ts` as a `ts("<RFC 3339 string>")` call expression.
///
/// The string comes from `to_rfc3339`, so the emitted literal carries an explicit offset.
/// The tests expect `+00:00` rather than `Z`.
fn timestamp_literal(ts: &TimestampValue) -> impl IntoExpressionBuilder {
    let rfc3339 = ts.instant().to_rfc3339();
    call("ts").arg(string(rfc3339))
}
// Table-driven tests for `normalize_within`. Each case supplies an input pipeline, the pipeline
// it should normalize into, and the output schema the normalized pipeline must produce.
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::type_check;
    use hamelin_lib::{
        tree::ast::expression::IntervalUnit,
        tree::{
            ast::pipeline::Pipeline,
            builder::{
                add, and, call, field_ref, gte, lt, lte, null, pipeline, select_command, string,
                where_command, IntervalLiteralBuilder,
            },
        },
        types::{struct_type::Struct, INT, TIMESTAMP},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    /// Builds an interval literal of `value` hours; `value` may be negative.
    fn interval_hours(value: i64) -> IntervalLiteralBuilder {
        IntervalLiteralBuilder::new(value, IntervalUnit::Hour)
    }

    #[rstest]
    // A pipeline with no WITHIN command comes back unchanged.
    #[case::no_within_passthrough(
        pipeline()
            .command(select_command().named_field("a", 1).build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).build())
            .build(),
        Struct::default().with_str("a", INT)
    )]
    // A constant timestamp range becomes a half-open WHERE over `ts(...)` literals. The
    // literals are re-rendered with `to_rfc3339`, hence the `+00:00` offsets in the expected
    // pipeline.
    #[case::constant_range_to_literal_where(
        pipeline()
            .command(select_command()
                .named_field("timestamp", call("ts").arg(string("2024-01-01T00:00:00Z")))
                .build())
            .within(call("ts").arg(string("2024-01-01T00:00:00Z"))
                ..call("ts").arg(string("2024-01-02T00:00:00Z")))
            .build(),
        pipeline()
            .command(select_command()
                .named_field("timestamp", call("ts").arg(string("2024-01-01T00:00:00Z")))
                .build())
            .command(where_command(and(
                gte(
                    field_ref("timestamp"),
                    call("ts").arg(string("2024-01-01T00:00:00+00:00")),
                ),
                lt(
                    field_ref("timestamp"),
                    call("ts").arg(string("2024-01-02T00:00:00+00:00")),
                ),
            )))
            .build(),
        Struct::default().with_str("timestamp", TIMESTAMP)
    )]
    // A range taken from a column is lowered to the dynamic form. The column is cast to a
    // timestamp range, and each bound check is wrapped in `coalesce(..., true)`.
    #[case::column_range_to_dynamic_where(
        pipeline()
            .command(select_command()
                .named_field("timestamp", call("ts").arg(string("2024-01-01T00:00:00Z")))
                .named_field(
                    "time_range",
                    call("ts").arg(string("2024-01-01T00:00:00Z"))
                        ..call("ts").arg(string("2024-01-02T00:00:00Z")),
                )
                .build())
            .within(field_ref("time_range"))
            .build(),
        {
            let timestamp_range: Type = Range::new(TIMESTAMP).into();
            let range_expr = cast(field_ref("time_range"), timestamp_range).build();
            pipeline()
                .command(select_command()
                    .named_field("timestamp", call("ts").arg(string("2024-01-01T00:00:00Z")))
                    .named_field(
                        "time_range",
                        call("ts").arg(string("2024-01-01T00:00:00Z"))
                            ..call("ts").arg(string("2024-01-02T00:00:00Z")),
                    )
                    .build())
                .command(where_command(and(
                    call("coalesce")
                        .arg(gte(
                            field_ref("timestamp"),
                            field(range_expr.clone(), "begin"),
                        ))
                        .arg(BoolLiteralBuilder::new(true)),
                    call("coalesce")
                        .arg(lt(field_ref("timestamp"), field(range_expr, "end")))
                        .arg(BoolLiteralBuilder::new(true)),
                )))
                .build()
        },
        Struct::default()
            .with_str("timestamp", TIMESTAMP)
            .with_str("time_range", Range::new(TIMESTAMP).into())
    )]
    // A negative interval yields `now() + (-5h) <= timestamp < now()`, with `now()` kept as a
    // runtime call.
    #[case::negative_interval_preserves_now(
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .within(interval_hours(-5))
            .build(),
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .command(where_command(and(
                gte(
                    field_ref("timestamp"),
                    add(call("now"), interval_hours(-5)),
                ),
                lt(field_ref("timestamp"), call("now")),
            )))
            .build(),
        Struct::default().with_str("timestamp", TIMESTAMP)
    )]
    // A positive interval yields `now() <= timestamp < now() + 5h`.
    #[case::positive_interval_preserves_now(
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .within(interval_hours(5))
            .build(),
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .command(where_command(and(
                gte(field_ref("timestamp"), call("now")),
                lt(
                    field_ref("timestamp"),
                    add(call("now"), interval_hours(5)),
                ),
            )))
            .build(),
        Struct::default().with_str("timestamp", TIMESTAMP)
    )]
    // An interval range spanning now: both bounds become offsets from `now()`.
    #[case::mixed_interval_range_preserves_now(
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .within(interval_hours(-5)..interval_hours(2))
            .build(),
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .command(where_command(and(
                gte(
                    field_ref("timestamp"),
                    add(call("now"), interval_hours(-5)),
                ),
                lt(
                    field_ref("timestamp"),
                    add(call("now"), interval_hours(2)),
                ),
            )))
            .build(),
        Struct::default().with_str("timestamp", TIMESTAMP)
    )]
    // An interval range entirely in the past: both bounds are still offsets from `now()`.
    #[case::negative_interval_range_preserves_now(
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .within(interval_hours(-5)..interval_hours(-2))
            .build(),
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .command(where_command(and(
                gte(
                    field_ref("timestamp"),
                    add(call("now"), interval_hours(-5)),
                ),
                lt(
                    field_ref("timestamp"),
                    add(call("now"), interval_hours(-2)),
                ),
            )))
            .build(),
        Struct::default().with_str("timestamp", TIMESTAMP)
    )]
    // An inclusive range (`..=`) turns the upper comparison into `<=`.
    #[case::inclusive_interval_range_preserves_now(
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .within(interval_hours(-5)..=interval_hours(2))
            .build(),
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .command(where_command(and(
                gte(
                    field_ref("timestamp"),
                    add(call("now"), interval_hours(-5)),
                ),
                lte(
                    field_ref("timestamp"),
                    add(call("now"), interval_hours(2)),
                ),
            )))
            .build(),
        Struct::default().with_str("timestamp", TIMESTAMP)
    )]
    // `WITHIN null` is removed from the pipeline outright.
    #[case::null_time_range_drops_within(
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .within(null())
            .build(),
        pipeline()
            .command(select_command().named_field("timestamp", call("now")).build())
            .build(),
        Struct::default().with_str("timestamp", TIMESTAMP)
    )]
    // Type-checks both pipelines and normalizes the input. Then asserts that the normalized
    // AST equals the expected pipeline's AST and that the output schema matches.
    fn test_normalize_within(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) -> Result<(), Arc<TranslationError>> {
        let input_typed = type_check(input).output;
        let expected_typed = type_check(expected).output;
        let mut ctx = StatementTranslationContext::default();
        let result = normalize_within(Arc::new(input_typed), &mut ctx)?;
        assert_eq!(result.ast, expected_typed.ast);
        let result_schema = result.environment().as_struct().clone();
        assert_eq!(result_schema, expected_output_schema);
        Ok(())
    }
}