use crate::eval::environment::Environment;
use crate::eval::timestamp::{next_truncation_boundary, truncate_timestamp};
use crate::reverse_eval::domain::Constraint;
use crate::reverse_eval::reverse::reverse_eval;
use crate::value::{TimestampValue, Value};
use chrono::{TimeZone, Utc};
use hamelin_lib::tree::ast::expression::TruncUnit;
use hamelin_lib::tree::builder::{at, call, field_ref, ExpressionBuilder};
use hamelin_lib::tree::options::ExpressionTypeCheckOptions;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::type_check_expression;
use hamelin_lib::types::{INT, TIMESTAMP};
use pretty_assertions::assert_eq;
use rstest::rstest;
use std::sync::Arc;
// Reversing `at(ts, unit) == t` (timestamp truncation) should produce a range on
// `ts` that starts at the truncated instant `t` and ends at the next `unit`
// boundary supplied by each case.
#[rstest]
#[case::day(
    TruncUnit::Day,
    Utc.with_ymd_and_hms(2024, 1, 15, 0, 0, 0).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 16, 0, 0, 0).unwrap()
)]
#[case::hour(
    TruncUnit::Hour,
    Utc.with_ymd_and_hms(2024, 1, 15, 14, 0, 0).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 15, 15, 0, 0).unwrap()
)]
#[case::week(
    TruncUnit::Week,
    Utc.with_ymd_and_hms(2024, 1, 15, 0, 0, 0).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 22, 0, 0, 0).unwrap()
)]
#[case::month(
    TruncUnit::Month,
    Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
    Utc.with_ymd_and_hms(2024, 2, 1, 0, 0, 0).unwrap()
)]
#[case::year(
    TruncUnit::Year,
    Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
    Utc.with_ymd_and_hms(2025, 1, 1, 0, 0, 0).unwrap()
)]
#[case::quarter(
    TruncUnit::Quarter,
    Utc.with_ymd_and_hms(2024, 4, 1, 0, 0, 0).unwrap(),
    Utc.with_ymd_and_hms(2024, 7, 1, 0, 0, 0).unwrap()
)]
fn test_reverse_ts_trunc_equals(
    #[case] unit: TruncUnit,
    #[case] truncated_ts: chrono::DateTime<Utc>,
    #[case] next_boundary: chrono::DateTime<Utc>,
) {
    let eval_env = Environment::new();

    // `ts` is declared as a TIMESTAMP column so the truncation type-checks.
    let mut type_env = TypeEnvironment::default();
    type_env.bind_str("ts", TIMESTAMP);
    let shared_bindings = Arc::new(type_env);

    let truncation = at(field_ref("ts"), unit).build();
    let options = ExpressionTypeCheckOptions::builder()
        .bindings(shared_bindings)
        .build();
    let typed_expr = type_check_expression(truncation, options).output;

    // Constrain the truncated output to exactly `truncated_ts` and reverse it.
    let on_output = Constraint::Equals(TimestampValue::utc(truncated_ts).into());
    let reversed = reverse_eval(&typed_expr, on_output, &eval_env).unwrap();

    let expected_input = Constraint::Range {
        min: Some(TimestampValue::utc(truncated_ts).into()),
        max: Some(TimestampValue::utc(next_boundary).into()),
    };
    assert_eq!(reversed, Some(expected_input));
}
// Reversing a range constraint [min_output, max_output] on `at(ts, Day)` should
// yield an input range on `ts` that starts at `min_output`. Each case supplies the
// expected upper bound, which is the day boundary after `max_output`.
#[rstest]
#[case::multi_day_range(
    Utc.with_ymd_and_hms(2024, 1, 15, 0, 0, 0).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 17, 0, 0, 0).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 18, 0, 0, 0).unwrap()
)]
#[case::single_day_boundary(
    Utc.with_ymd_and_hms(2024, 1, 15, 0, 0, 0).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 16, 0, 0, 0).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 17, 0, 0, 0).unwrap()
)]
fn test_reverse_ts_trunc_day_range(
    #[case] min_output: chrono::DateTime<Utc>,
    #[case] max_output: chrono::DateTime<Utc>,
    #[case] expected_max_input: chrono::DateTime<Utc>,
) {
    let eval_env = Environment::new();

    let mut type_env = TypeEnvironment::default();
    type_env.bind_str("ts", TIMESTAMP);
    let shared_bindings = Arc::new(type_env);

    let day_truncation = at(field_ref("ts"), TruncUnit::Day).build();
    let options = ExpressionTypeCheckOptions::builder()
        .bindings(shared_bindings)
        .build();
    let typed_expr = type_check_expression(day_truncation, options).output;

    let on_output = Constraint::Range {
        min: Some(TimestampValue::utc(min_output).into()),
        max: Some(TimestampValue::utc(max_output).into()),
    };
    let reversed = reverse_eval(&typed_expr, on_output, &eval_env).unwrap();

    let expected_input = Constraint::Range {
        min: Some(TimestampValue::utc(min_output).into()),
        max: Some(TimestampValue::utc(expected_max_input).into()),
    };
    assert_eq!(reversed, Some(expected_input));
}
// For each truncation unit, a fixed sample instant must be at or after its
// truncated value and strictly before the boundary that `next_truncation_boundary`
// computes from that truncated value.
#[test]
fn test_truncate_and_next_boundary_consistency() {
    let sample = Utc.with_ymd_and_hms(2024, 1, 15, 14, 35, 42).unwrap();
    let units = [
        TruncUnit::Second,
        TruncUnit::Minute,
        TruncUnit::Hour,
        TruncUnit::Day,
        TruncUnit::Week,
        TruncUnit::Month,
        TruncUnit::Quarter,
        TruncUnit::Year,
    ];

    for unit in &units {
        // NOTE(review): the trailing `1` is passed to both helpers unchanged;
        // presumably a unit count — confirm against the helper signatures.
        let floor = truncate_timestamp(&TimestampValue::utc(sample).into(), unit, 1).unwrap();
        let ceiling = next_truncation_boundary(&floor, unit, 1).unwrap();

        assert!(
            sample >= *floor.instant(),
            "Original timestamp should be >= truncated for unit {:?}",
            unit
        );
        assert!(
            sample < *ceiling.instant(),
            "Original timestamp should be < next boundary for unit {:?}",
            unit
        );
    }
}
// `from_unixtime_seconds(epoch_col) == t` should reverse to
// `epoch_col == <seconds of t>`, covering a modern instant, the epoch itself,
// and an instant before the epoch.
#[rstest]
#[case::seconds(1703505600, "2023-12-25T12:00:00+00:00")]
#[case::zero_epoch(0, "1970-01-01T00:00:00+00:00")]
#[case::negative(-86400, "1969-12-31T00:00:00+00:00")]
fn test_reverse_from_unixtime_seconds_equals(
    #[case] expected_seconds: i64,
    #[case] timestamp_str: &str,
) {
    let eval_env = Environment::new();

    // `epoch_col` is an INT column fed into the seconds conversion.
    let mut type_env = TypeEnvironment::default();
    type_env.bind_str("epoch_col", INT);
    let shared_bindings = Arc::new(type_env);

    let conversion = call("from_unixtime_seconds")
        .arg(field_ref("epoch_col"))
        .build();
    let options = ExpressionTypeCheckOptions::builder()
        .bindings(shared_bindings)
        .build();
    let typed_expr = type_check_expression(conversion, options).output;

    let target = timestamp_str.parse::<chrono::DateTime<Utc>>().unwrap();
    let on_output = Constraint::Equals(TimestampValue::utc(target).into());
    let reversed = reverse_eval(&typed_expr, on_output, &eval_env).unwrap();

    let expected_input = Constraint::Equals(Value::Int(expected_seconds));
    assert_eq!(reversed, Some(expected_input));
}
// Reversing a range constraint on `from_unixtime_seconds(epoch_col)` should yield
// the corresponding range of whole-second epoch integers for `epoch_col`.
#[test]
fn test_reverse_from_unixtime_seconds_range() {
    let env = Environment::new();
    let mut trans_env = TypeEnvironment::default();
    trans_env.bind_str("epoch_col", INT);
    let bindings = Arc::new(trans_env);
    let typed_expr = type_check_expression(
        call("from_unixtime_seconds")
            .arg(field_ref("epoch_col"))
            .build(),
        ExpressionTypeCheckOptions::builder()
            .bindings(bindings)
            .build(),
    )
    .output;

    let min_ts: chrono::DateTime<Utc> = "2024-01-01T00:00:00+00:00".parse().unwrap();
    let max_ts: chrono::DateTime<Utc> = "2024-01-02T00:00:00+00:00".parse().unwrap();
    let output_constraint = Constraint::Range {
        min: Some(TimestampValue::utc(min_ts).into()),
        max: Some(TimestampValue::utc(max_ts).into()),
    };
    let result = reverse_eval(&typed_expr, output_constraint, &env).unwrap();

    assert_eq!(
        result,
        Some(Constraint::Range {
            // 2024-01-01T00:00:00Z as Unix seconds.
            min: Some(Value::Int(1704067200)),
            // 2024-01-02T00:00:00Z as Unix seconds (one day = 86_400 s later).
            max: Some(Value::Int(1704153600)),
        })
    );
}
// `from_unixtime_millis(epoch_col) == t` should reverse to
// `epoch_col == <milliseconds of t>`, including sub-second precision.
#[rstest]
#[case::millis(1703505600123i64, "2023-12-25T12:00:00.123+00:00")]
#[case::zero_epoch(0i64, "1970-01-01T00:00:00+00:00")]
fn test_reverse_from_unixtime_millis_equals(
    #[case] expected_millis: i64,
    #[case] timestamp_str: &str,
) {
    let eval_env = Environment::new();

    let mut type_env = TypeEnvironment::default();
    type_env.bind_str("epoch_col", INT);
    let shared_bindings = Arc::new(type_env);

    let conversion = call("from_unixtime_millis")
        .arg(field_ref("epoch_col"))
        .build();
    let options = ExpressionTypeCheckOptions::builder()
        .bindings(shared_bindings)
        .build();
    let typed_expr = type_check_expression(conversion, options).output;

    let target = timestamp_str.parse::<chrono::DateTime<Utc>>().unwrap();
    let on_output = Constraint::Equals(TimestampValue::utc(target).into());
    let reversed = reverse_eval(&typed_expr, on_output, &eval_env).unwrap();

    let expected_input = Constraint::Equals(Value::Int(expected_millis));
    assert_eq!(reversed, Some(expected_input));
}
// `from_unixtime_micros(epoch_col) == t` should reverse to
// `epoch_col == <microseconds of t>`, including sub-second precision.
#[rstest]
#[case::micros(1703505600123456i64, "2023-12-25T12:00:00.123456+00:00")]
#[case::zero_epoch(0i64, "1970-01-01T00:00:00+00:00")]
fn test_reverse_from_unixtime_micros_equals(
    #[case] expected_micros: i64,
    #[case] timestamp_str: &str,
) {
    let eval_env = Environment::new();

    let mut type_env = TypeEnvironment::default();
    type_env.bind_str("epoch_col", INT);
    let shared_bindings = Arc::new(type_env);

    let conversion = call("from_unixtime_micros")
        .arg(field_ref("epoch_col"))
        .build();
    let options = ExpressionTypeCheckOptions::builder()
        .bindings(shared_bindings)
        .build();
    let typed_expr = type_check_expression(conversion, options).output;

    let target = timestamp_str.parse::<chrono::DateTime<Utc>>().unwrap();
    let on_output = Constraint::Equals(TimestampValue::utc(target).into());
    let reversed = reverse_eval(&typed_expr, on_output, &eval_env).unwrap();

    let expected_input = Constraint::Equals(Value::Int(expected_micros));
    assert_eq!(reversed, Some(expected_input));
}
// `from_unixtime_nanos(epoch_col) == t` should reverse to
// `epoch_col == <nanoseconds of t>`, down to full nanosecond precision.
#[rstest]
#[case::nanos(1703505600123456789i64, "2023-12-25T12:00:00.123456789+00:00")]
#[case::zero_epoch(0i64, "1970-01-01T00:00:00+00:00")]
fn test_reverse_from_unixtime_nanos_equals(
    #[case] expected_nanos: i64,
    #[case] timestamp_str: &str,
) {
    let eval_env = Environment::new();

    let mut type_env = TypeEnvironment::default();
    type_env.bind_str("epoch_col", INT);
    let shared_bindings = Arc::new(type_env);

    let conversion = call("from_unixtime_nanos")
        .arg(field_ref("epoch_col"))
        .build();
    let options = ExpressionTypeCheckOptions::builder()
        .bindings(shared_bindings)
        .build();
    let typed_expr = type_check_expression(conversion, options).output;

    let target = timestamp_str.parse::<chrono::DateTime<Utc>>().unwrap();
    let on_output = Constraint::Equals(TimestampValue::utc(target).into());
    let reversed = reverse_eval(&typed_expr, on_output, &eval_env).unwrap();

    let expected_input = Constraint::Equals(Value::Int(expected_nanos));
    assert_eq!(reversed, Some(expected_input));
}
// `to_unixtime(timestamp_col) == s` (a Double of epoch seconds) should reverse to
// `timestamp_col == <instant at s>`, including fractional seconds.
#[rstest]
#[case::seconds(1703505600.0, "2023-12-25T12:00:00+00:00")]
#[case::with_fractional(1703505600.123456, "2023-12-25T12:00:00.123456+00:00")]
#[case::zero_epoch(0.0, "1970-01-01T00:00:00+00:00")]
fn test_reverse_to_unixtime_equals(#[case] expected_seconds: f64, #[case] timestamp_str: &str) {
    let eval_env = Environment::new();

    // `timestamp_col` is a TIMESTAMP column converted to epoch seconds.
    let mut type_env = TypeEnvironment::default();
    type_env.bind_str("timestamp_col", TIMESTAMP);
    let shared_bindings = Arc::new(type_env);

    let conversion = call("to_unixtime").arg(field_ref("timestamp_col")).build();
    let options = ExpressionTypeCheckOptions::builder()
        .bindings(shared_bindings)
        .build();
    let typed_expr = type_check_expression(conversion, options).output;

    let expected_instant = timestamp_str.parse::<chrono::DateTime<Utc>>().unwrap();
    let on_output = Constraint::Equals(Value::Double(expected_seconds));
    let reversed = reverse_eval(&typed_expr, on_output, &eval_env).unwrap();

    let expected_input = Constraint::Equals(TimestampValue::utc(expected_instant).into());
    assert_eq!(reversed, Some(expected_input));
}
// Reversing a range constraint on `to_unixtime(timestamp_col)` (Double epoch
// seconds) should yield the corresponding range of instants for `timestamp_col`.
#[test]
fn test_reverse_to_unixtime_range() {
    let env = Environment::new();
    let mut trans_env = TypeEnvironment::default();
    trans_env.bind_str("timestamp_col", TIMESTAMP);
    let bindings = Arc::new(trans_env);
    let typed_expr = type_check_expression(
        call("to_unixtime").arg(field_ref("timestamp_col")).build(),
        ExpressionTypeCheckOptions::builder()
            .bindings(bindings)
            .build(),
    )
    .output;

    // 2024-01-01T00:00:00Z as Unix seconds.
    let min_seconds = 1704067200.0;
    // 2024-01-02T00:00:00Z as Unix seconds (one day = 86_400 s later).
    let max_seconds = 1704153600.0;
    let output_constraint = Constraint::Range {
        min: Some(Value::Double(min_seconds)),
        max: Some(Value::Double(max_seconds)),
    };
    let result = reverse_eval(&typed_expr, output_constraint, &env).unwrap();

    let min_ts: chrono::DateTime<Utc> = "2024-01-01T00:00:00+00:00".parse().unwrap();
    let max_ts: chrono::DateTime<Utc> = "2024-01-02T00:00:00+00:00".parse().unwrap();
    assert_eq!(
        result,
        Some(Constraint::Range {
            min: Some(TimestampValue::utc(min_ts).into()),
            max: Some(TimestampValue::utc(max_ts).into()),
        })
    );
}