Skip to main content

laminar_sql/parser/
interval_rewriter.rs

1//! Interval arithmetic rewriter for BIGINT timestamp columns.
2//!
3//! LaminarDB uses BIGINT millisecond timestamps for event time. DataFusion
4//! cannot natively evaluate `Int64 ± INTERVAL`, so this module rewrites
5//! INTERVAL expressions in arithmetic operations to equivalent millisecond
6//! integer literals before the SQL reaches DataFusion.
7//!
8//! # Example
9//!
10//! ```sql
11//! -- Before rewrite:
12//! SELECT * FROM trades t
13//! INNER JOIN orders o ON t.symbol = o.symbol
14//!   AND o.ts BETWEEN t.ts - INTERVAL '10' SECOND AND t.ts + INTERVAL '10' SECOND
15//!
16//! -- After rewrite:
17//! SELECT * FROM trades t
18//! INNER JOIN orders o ON t.symbol = o.symbol
19//!   AND o.ts BETWEEN t.ts - 10000 AND t.ts + 10000
20//! ```
21
22use sqlparser::ast::{
23    BinaryOperator, DateTimeField, Expr, JoinConstraint, JoinOperator, Query, Select, SelectItem,
24    SetExpr, Statement, Value,
25};
26
27/// Convert an [`Interval`](sqlparser::ast::Interval) to its equivalent milliseconds value.
28///
29/// Returns `None` if the interval cannot be converted (unsupported unit or
30/// non-numeric value).
31fn interval_to_millis(interval: &sqlparser::ast::Interval) -> Option<i64> {
32    let value = extract_interval_numeric(&interval.value)?;
33    let unit = interval
34        .leading_field
35        .clone()
36        .unwrap_or(DateTimeField::Second);
37
38    let millis = match unit {
39        DateTimeField::Millisecond | DateTimeField::Milliseconds => value,
40        DateTimeField::Second | DateTimeField::Seconds => value.checked_mul(1_000)?,
41        DateTimeField::Minute | DateTimeField::Minutes => value.checked_mul(60_000)?,
42        DateTimeField::Hour | DateTimeField::Hours => value.checked_mul(3_600_000)?,
43        DateTimeField::Day | DateTimeField::Days => value.checked_mul(86_400_000)?,
44        _ => return None,
45    };
46
47    Some(millis)
48}
49
50/// Extract a numeric value from an interval's value expression.
51fn extract_interval_numeric(expr: &Expr) -> Option<i64> {
52    match expr {
53        Expr::Value(vws) => match &vws.value {
54            Value::Number(n, _) => n.parse().ok(),
55            Value::SingleQuotedString(s) => s.split_whitespace().next()?.parse().ok(),
56            _ => None,
57        },
58        _ => None,
59    }
60}
61
62/// Create a numeric literal `Expr` from a milliseconds value.
63///
64/// Uses sqlparser's own parser to construct the AST node, ensuring
65/// correct internal representation.
66fn make_number_expr(n: i64) -> Expr {
67    use sqlparser::dialect::GenericDialect;
68    let s = n.to_string();
69    let dialect = GenericDialect {};
70    sqlparser::parser::Parser::new(&dialect)
71        .try_with_sql(&s)
72        .expect("number literal should tokenize")
73        .parse_expr()
74        .expect("number literal should parse as Expr")
75}
76
77// ---------------------------------------------------------------------------
78// Expression rewriter
79// ---------------------------------------------------------------------------
80
81/// Rewrite INTERVAL arithmetic in an expression tree, in place.
82///
83/// Converts patterns like `col ± INTERVAL 'N' UNIT` to `col ± <millis>` so
84/// that DataFusion can evaluate the expression when the column is `Int64`.
85pub fn rewrite_expr_mut(expr: &mut Expr) {
86    if let Expr::BinaryOp { left, op, right } = expr {
87        let is_add_sub = matches!(*op, BinaryOperator::Plus | BinaryOperator::Minus);
88
89        if is_add_sub {
90            // Check right side for INTERVAL: col ± INTERVAL → col ± millis
91            let right_ms: Option<i64> = if let Expr::Interval(interval) = right.as_ref() {
92                interval_to_millis(interval)
93            } else {
94                None
95            };
96
97            if let Some(ms) = right_ms {
98                **right = make_number_expr(ms);
99                rewrite_expr_mut(left);
100                return;
101            }
102
103            // Check left side: INTERVAL + col → millis + col (only addition)
104            if matches!(*op, BinaryOperator::Plus) {
105                let left_ms: Option<i64> = if let Expr::Interval(interval) = left.as_ref() {
106                    interval_to_millis(interval)
107                } else {
108                    None
109                };
110
111                if let Some(ms) = left_ms {
112                    **left = make_number_expr(ms);
113                    rewrite_expr_mut(right);
114                    return;
115                }
116            }
117        }
118
119        // Default: recurse into both sides
120        rewrite_expr_mut(left);
121        rewrite_expr_mut(right);
122        return;
123    }
124
125    // Handle other expression types that can contain sub-expressions
126    match expr {
127        Expr::Between {
128            expr: e, low, high, ..
129        } => {
130            rewrite_expr_mut(e);
131            rewrite_expr_mut(low);
132            rewrite_expr_mut(high);
133        }
134        Expr::InList { expr: e, list, .. } => {
135            rewrite_expr_mut(e);
136            for item in list {
137                rewrite_expr_mut(item);
138            }
139        }
140        Expr::Nested(inner)
141        | Expr::UnaryOp { expr: inner, .. }
142        | Expr::Cast { expr: inner, .. }
143        | Expr::IsNull(inner)
144        | Expr::IsNotNull(inner)
145        | Expr::IsFalse(inner)
146        | Expr::IsNotFalse(inner)
147        | Expr::IsTrue(inner)
148        | Expr::IsNotTrue(inner) => rewrite_expr_mut(inner),
149        _ => {}
150    }
151}
152
153// ---------------------------------------------------------------------------
154// Statement / query walker
155// ---------------------------------------------------------------------------
156
157/// Rewrite all INTERVAL arithmetic in a SQL [`Statement`].
158///
159/// Walks the full AST and converts `expr ± INTERVAL 'N' UNIT` patterns
160/// to `expr ± <milliseconds>` for BIGINT timestamp compatibility.
161pub fn rewrite_interval_arithmetic(stmt: &mut Statement) {
162    if let Statement::Query(query) = stmt {
163        rewrite_query(query);
164    }
165}
166
167fn rewrite_query(query: &mut Query) {
168    rewrite_set_expr(&mut query.body);
169}
170
171fn rewrite_set_expr(body: &mut SetExpr) {
172    match body {
173        SetExpr::Select(select) => rewrite_select(select),
174        SetExpr::Query(query) => rewrite_query(query),
175        SetExpr::SetOperation { left, right, .. } => {
176            rewrite_set_expr(left);
177            rewrite_set_expr(right);
178        }
179        _ => {}
180    }
181}
182
183fn rewrite_select(select: &mut Select) {
184    // Rewrite SELECT projection expressions
185    for item in &mut select.projection {
186        match item {
187            SelectItem::UnnamedExpr(ref mut expr)
188            | SelectItem::ExprWithAlias { ref mut expr, .. } => {
189                rewrite_expr_mut(expr);
190            }
191            _ => {}
192        }
193    }
194
195    // Rewrite WHERE clause
196    if let Some(ref mut where_expr) = select.selection {
197        rewrite_expr_mut(where_expr);
198    }
199
200    // Rewrite HAVING clause
201    if let Some(ref mut having) = select.having {
202        rewrite_expr_mut(having);
203    }
204
205    // Rewrite JOIN ON conditions
206    for table_with_joins in &mut select.from {
207        for join in &mut table_with_joins.joins {
208            rewrite_join_operator(&mut join.join_operator);
209        }
210    }
211}
212
213fn rewrite_join_operator(jo: &mut JoinOperator) {
214    let (JoinOperator::Join(constraint)
215    | JoinOperator::Inner(constraint)
216    | JoinOperator::Left(constraint)
217    | JoinOperator::LeftOuter(constraint)
218    | JoinOperator::Right(constraint)
219    | JoinOperator::RightOuter(constraint)
220    | JoinOperator::FullOuter(constraint)
221    | JoinOperator::StraightJoin(constraint)
222    | JoinOperator::LeftSemi(constraint)
223    | JoinOperator::RightSemi(constraint)
224    | JoinOperator::LeftAnti(constraint)
225    | JoinOperator::RightAnti(constraint)
226    | JoinOperator::Semi(constraint)
227    | JoinOperator::Anti(constraint)) = jo
228    else {
229        return;
230    };
231    if let JoinConstraint::On(expr) = constraint {
232        rewrite_expr_mut(expr);
233    }
234}
235
236// ---------------------------------------------------------------------------
237// Tests
238// ---------------------------------------------------------------------------
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::dialect::LaminarDialect;

    /// Helper: parse SQL, rewrite intervals, return the rewritten SQL string.
    fn rewrite(sql: &str) -> String {
        let dialect = LaminarDialect::default();
        let mut stmts = sqlparser::parser::Parser::parse_sql(&dialect, sql).unwrap();
        assert!(!stmts.is_empty());
        rewrite_interval_arithmetic(&mut stmts[0]);
        stmts[0].to_string()
    }

    // -- Basic arithmetic --

    #[test]
    fn test_subtract_interval_seconds() {
        let result = rewrite("SELECT ts - INTERVAL '10' SECOND FROM events");
        assert!(result.contains("ts - 10000"), "got: {result}");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    #[test]
    fn test_add_interval_seconds() {
        let result = rewrite("SELECT ts + INTERVAL '5' SECOND FROM events");
        assert!(result.contains("ts + 5000"), "got: {result}");
    }

    #[test]
    fn test_interval_minutes() {
        let result = rewrite("SELECT ts - INTERVAL '2' MINUTE FROM events");
        assert!(result.contains("ts - 120000"), "got: {result}");
    }

    #[test]
    fn test_interval_hours() {
        let result = rewrite("SELECT ts + INTERVAL '1' HOUR FROM events");
        assert!(result.contains("ts + 3600000"), "got: {result}");
    }

    #[test]
    fn test_interval_days() {
        let result = rewrite("SELECT ts - INTERVAL '1' DAY FROM events");
        assert!(result.contains("ts - 86400000"), "got: {result}");
    }

    #[test]
    fn test_interval_milliseconds() {
        let result = rewrite("SELECT ts - INTERVAL '100' MILLISECOND FROM events");
        assert!(result.contains("ts - 100"), "got: {result}");
    }

    // -- WHERE clause --

    #[test]
    fn test_where_clause_interval() {
        let result = rewrite("SELECT * FROM events WHERE ts > ts2 - INTERVAL '10' SECOND");
        assert!(result.contains("ts2 - 10000"), "got: {result}");
    }

    // -- BETWEEN (from issue example) --

    #[test]
    fn test_between_interval() {
        let result = rewrite(
            "SELECT * FROM trades t \
             INNER JOIN orders o ON t.symbol = o.symbol \
             AND o.ts BETWEEN t.ts - INTERVAL '10' SECOND AND t.ts + INTERVAL '10' SECOND",
        );
        assert!(result.contains("t.ts - 10000"), "got: {result}");
        assert!(result.contains("t.ts + 10000"), "got: {result}");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    // -- JOIN ON condition --

    #[test]
    fn test_join_on_interval() {
        let result = rewrite(
            "SELECT * FROM a JOIN b ON a.id = b.id \
             AND b.ts BETWEEN a.ts - INTERVAL '5' MINUTE AND a.ts + INTERVAL '5' MINUTE",
        );
        assert!(result.contains("a.ts - 300000"), "got: {result}");
        assert!(result.contains("a.ts + 300000"), "got: {result}");
    }

    // -- Nested expressions --

    #[test]
    fn test_nested_parens() {
        let result = rewrite("SELECT * FROM e WHERE (ts - INTERVAL '1' SECOND) > 0");
        assert!(result.contains("ts - 1000"), "got: {result}");
    }

    // -- Left-side INTERVAL (INTERVAL + col) --

    #[test]
    fn test_interval_on_left_side() {
        let result = rewrite("SELECT INTERVAL '10' SECOND + ts FROM events");
        assert!(result.contains("10000 + ts"), "got: {result}");
    }

    // -- No-op cases (should not be modified) --

    #[test]
    fn test_no_interval_unchanged() {
        let result = rewrite("SELECT ts - 10000 FROM events");
        assert!(result.contains("ts - 10000"), "got: {result}");
    }

    #[test]
    fn test_interval_minus_column_unchanged() {
        // `INTERVAL - col` is only rewritten for addition, never subtraction.
        let result = rewrite("SELECT INTERVAL '10' SECOND - ts FROM events");
        assert!(result.contains("INTERVAL '10' SECOND - ts"), "got: {result}");
    }

    #[test]
    fn test_variable_length_unit_unchanged() {
        // MONTH has no fixed millisecond length, so it is left as an INTERVAL.
        let result = rewrite("SELECT ts - INTERVAL '1' MONTH FROM events");
        assert!(result.contains("INTERVAL '1' MONTH"), "got: {result}");
    }

    #[test]
    fn test_interval_default_unit_is_second() {
        // No unit keyword: the rewriter interprets the value as seconds.
        let result = rewrite("SELECT ts - INTERVAL '5' FROM events");
        assert!(result.contains("ts - 5000"), "got: {result}");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    // -- Multiple intervals in same query --

    #[test]
    fn test_multiple_intervals() {
        let result = rewrite(
            "SELECT * FROM events \
             WHERE ts > start_ts - INTERVAL '10' SECOND \
             AND ts < end_ts + INTERVAL '30' SECOND",
        );
        assert!(result.contains("start_ts - 10000"), "got: {result}");
        assert!(result.contains("end_ts + 30000"), "got: {result}");
    }

    // -- HAVING clause --

    #[test]
    fn test_having_clause_interval() {
        let result = rewrite(
            "SELECT symbol, COUNT(*) FROM trades \
             GROUP BY symbol \
             HAVING MAX(ts) - MIN(ts) > INTERVAL '1' HOUR",
        );
        // The INTERVAL is an operand of `>`, not of `+`/`-`, so the rewriter
        // leaves it untouched.
        // NOTE(review): with BIGINT timestamps DataFusion may reject
        // `Int64 > INTERVAL`; confirm whether comparisons should be rewritten too.
        assert!(result.contains("HAVING"), "got: {result}");
        assert!(result.contains("INTERVAL '1' HOUR"), "got: {result}");
    }
}
381        // right side of `>` which is a comparison, not +/-, so untouched
382        // is correct.
383        assert!(result.contains("HAVING"), "got: {result}");
384    }
385}