Skip to main content

laminar_sql/parser/
interval_rewriter.rs

1//! Interval arithmetic rewriter for BIGINT timestamp columns.
2//!
3//! LaminarDB uses BIGINT millisecond timestamps for event time. DataFusion
4//! cannot natively evaluate `Int64 ± INTERVAL`, so this module rewrites
5//! INTERVAL expressions in arithmetic operations to equivalent millisecond
6//! integer literals before the SQL reaches DataFusion.
7//!
8//! # Example
9//!
10//! ```sql
11//! -- Before rewrite:
12//! SELECT * FROM trades t
13//! INNER JOIN orders o ON t.symbol = o.symbol
14//!   AND o.ts BETWEEN t.ts - INTERVAL '10' SECOND AND t.ts + INTERVAL '10' SECOND
15//!
16//! -- After rewrite:
17//! SELECT * FROM trades t
18//! INNER JOIN orders o ON t.symbol = o.symbol
19//!   AND o.ts BETWEEN t.ts - 10000 AND t.ts + 10000
20//! ```
21
22use sqlparser::ast::{
23    BinaryOperator, DateTimeField, Expr, JoinConstraint, JoinOperator, Query, Select, SelectItem,
24    SetExpr, Statement, Value,
25};
26
27/// Convert an [`Interval`](sqlparser::ast::Interval) to its equivalent milliseconds value.
28///
29/// Returns `None` if the interval cannot be converted (unsupported unit or
30/// non-numeric value).
31fn interval_to_millis(interval: &sqlparser::ast::Interval) -> Option<i64> {
32    let value = extract_interval_numeric(&interval.value)?;
33    let unit = interval
34        .leading_field
35        .clone()
36        .unwrap_or(DateTimeField::Second);
37
38    let millis = match unit {
39        DateTimeField::Millisecond | DateTimeField::Milliseconds => value,
40        DateTimeField::Second | DateTimeField::Seconds => value.checked_mul(1_000)?,
41        DateTimeField::Minute | DateTimeField::Minutes => value.checked_mul(60_000)?,
42        DateTimeField::Hour | DateTimeField::Hours => value.checked_mul(3_600_000)?,
43        DateTimeField::Day | DateTimeField::Days => value.checked_mul(86_400_000)?,
44        _ => return None,
45    };
46
47    Some(millis)
48}
49
50/// Extract a numeric value from an interval's value expression.
51fn extract_interval_numeric(expr: &Expr) -> Option<i64> {
52    match expr {
53        Expr::Value(vws) => match &vws.value {
54            Value::Number(n, _) => n.parse().ok(),
55            Value::SingleQuotedString(s) => s.split_whitespace().next()?.parse().ok(),
56            _ => None,
57        },
58        _ => None,
59    }
60}
61
62/// Create a numeric literal `Expr` from a milliseconds value.
63///
64/// Uses sqlparser's own parser to construct the AST node, ensuring
65/// correct internal representation.
66fn make_number_expr(n: i64) -> Expr {
67    use sqlparser::dialect::GenericDialect;
68    let s = n.to_string();
69    let dialect = GenericDialect {};
70    sqlparser::parser::Parser::new(&dialect)
71        .try_with_sql(&s)
72        .expect("number literal should tokenize")
73        .parse_expr()
74        .expect("number literal should parse as Expr")
75}
76
77// ---------------------------------------------------------------------------
78// Expression rewriter
79// ---------------------------------------------------------------------------
80
81/// Rewrite INTERVAL arithmetic in an expression tree, in place.
82///
83/// Converts patterns like `col ± INTERVAL 'N' UNIT` to `col ± <millis>` so
84/// that DataFusion can evaluate the expression when the column is `Int64`.
85pub fn rewrite_expr_mut(expr: &mut Expr) {
86    if let Expr::BinaryOp { left, op, right } = expr {
87        let is_add_sub = matches!(*op, BinaryOperator::Plus | BinaryOperator::Minus);
88
89        if is_add_sub {
90            // Check right side for INTERVAL: col ± INTERVAL → col ± millis
91            let right_ms: Option<i64> = if let Expr::Interval(interval) = right.as_ref() {
92                interval_to_millis(interval)
93            } else {
94                None
95            };
96
97            if let Some(ms) = right_ms {
98                **right = make_number_expr(ms);
99                rewrite_expr_mut(left);
100                return;
101            }
102
103            // Check left side: INTERVAL + col → millis + col (only addition)
104            if matches!(*op, BinaryOperator::Plus) {
105                let left_ms: Option<i64> = if let Expr::Interval(interval) = left.as_ref() {
106                    interval_to_millis(interval)
107                } else {
108                    None
109                };
110
111                if let Some(ms) = left_ms {
112                    **left = make_number_expr(ms);
113                    rewrite_expr_mut(right);
114                    return;
115                }
116            }
117        }
118
119        // Default: recurse into both sides
120        rewrite_expr_mut(left);
121        rewrite_expr_mut(right);
122        return;
123    }
124
125    // Handle other expression types that can contain sub-expressions
126    match expr {
127        Expr::Between {
128            expr: e, low, high, ..
129        } => {
130            rewrite_expr_mut(e);
131            rewrite_expr_mut(low);
132            rewrite_expr_mut(high);
133        }
134        Expr::InList { expr: e, list, .. } => {
135            rewrite_expr_mut(e);
136            for item in list {
137                rewrite_expr_mut(item);
138            }
139        }
140        Expr::Nested(inner)
141        | Expr::UnaryOp { expr: inner, .. }
142        | Expr::Cast { expr: inner, .. }
143        | Expr::IsNull(inner)
144        | Expr::IsNotNull(inner)
145        | Expr::IsFalse(inner)
146        | Expr::IsNotFalse(inner)
147        | Expr::IsTrue(inner)
148        | Expr::IsNotTrue(inner) => rewrite_expr_mut(inner),
149        _ => {}
150    }
151}
152
153// ---------------------------------------------------------------------------
154// Statement / query walker
155// ---------------------------------------------------------------------------
156
157/// Rewrite all INTERVAL arithmetic in a SQL [`Statement`].
158///
159/// Walks the full AST and converts `expr ± INTERVAL 'N' UNIT` patterns
160/// to `expr ± <milliseconds>` for BIGINT timestamp compatibility.
161pub fn rewrite_interval_arithmetic(stmt: &mut Statement) {
162    if let Statement::Query(query) = stmt {
163        rewrite_query(query);
164    }
165}
166
167fn rewrite_query(query: &mut Query) {
168    rewrite_set_expr(&mut query.body);
169}
170
171fn rewrite_set_expr(body: &mut SetExpr) {
172    match body {
173        SetExpr::Select(select) => rewrite_select(select),
174        SetExpr::Query(query) => rewrite_query(query),
175        SetExpr::SetOperation { left, right, .. } => {
176            rewrite_set_expr(left);
177            rewrite_set_expr(right);
178        }
179        _ => {}
180    }
181}
182
183fn rewrite_select(select: &mut Select) {
184    // Rewrite SELECT projection expressions
185    for item in &mut select.projection {
186        match item {
187            SelectItem::UnnamedExpr(ref mut expr)
188            | SelectItem::ExprWithAlias { ref mut expr, .. } => {
189                rewrite_expr_mut(expr);
190            }
191            _ => {}
192        }
193    }
194
195    // Rewrite WHERE clause
196    if let Some(ref mut where_expr) = select.selection {
197        rewrite_expr_mut(where_expr);
198    }
199
200    // Rewrite HAVING clause
201    if let Some(ref mut having) = select.having {
202        rewrite_expr_mut(having);
203    }
204
205    // Rewrite JOIN ON conditions
206    for table_with_joins in &mut select.from {
207        for join in &mut table_with_joins.joins {
208            rewrite_join_operator(&mut join.join_operator);
209        }
210    }
211}
212
213fn rewrite_join_operator(jo: &mut JoinOperator) {
214    let (JoinOperator::Join(constraint)
215    | JoinOperator::Inner(constraint)
216    | JoinOperator::LeftOuter(constraint)
217    | JoinOperator::RightOuter(constraint)
218    | JoinOperator::FullOuter(constraint)
219    | JoinOperator::LeftSemi(constraint)
220    | JoinOperator::RightSemi(constraint)
221    | JoinOperator::LeftAnti(constraint)
222    | JoinOperator::RightAnti(constraint)) = jo
223    else {
224        return;
225    };
226    if let JoinConstraint::On(expr) = constraint {
227        rewrite_expr_mut(expr);
228    }
229}
230
231// ---------------------------------------------------------------------------
232// Tests
233// ---------------------------------------------------------------------------
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::dialect::LaminarDialect;

    /// Helper: parse SQL with the LaminarDB dialect, rewrite intervals, and
    /// return the rewritten SQL string.
    fn rewrite(sql: &str) -> String {
        let dialect = LaminarDialect::default();
        let mut stmts = sqlparser::parser::Parser::parse_sql(&dialect, sql).unwrap();
        assert!(!stmts.is_empty());
        rewrite_interval_arithmetic(&mut stmts[0]);
        stmts[0].to_string()
    }

    // -- Basic arithmetic --

    #[test]
    fn test_subtract_interval_seconds() {
        let result = rewrite("SELECT ts - INTERVAL '10' SECOND FROM events");
        assert!(result.contains("ts - 10000"), "got: {result}");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    #[test]
    fn test_add_interval_seconds() {
        let result = rewrite("SELECT ts + INTERVAL '5' SECOND FROM events");
        assert!(result.contains("ts + 5000"), "got: {result}");
    }

    #[test]
    fn test_interval_minutes() {
        let result = rewrite("SELECT ts - INTERVAL '2' MINUTE FROM events");
        assert!(result.contains("ts - 120000"), "got: {result}");
    }

    #[test]
    fn test_interval_hours() {
        let result = rewrite("SELECT ts + INTERVAL '1' HOUR FROM events");
        assert!(result.contains("ts + 3600000"), "got: {result}");
    }

    #[test]
    fn test_interval_days() {
        let result = rewrite("SELECT ts - INTERVAL '1' DAY FROM events");
        assert!(result.contains("ts - 86400000"), "got: {result}");
    }

    #[test]
    fn test_interval_milliseconds() {
        let result = rewrite("SELECT ts - INTERVAL '100' MILLISECOND FROM events");
        // Anchor on the following keyword so e.g. "ts - 1000" cannot match.
        assert!(result.contains("ts - 100 FROM"), "got: {result}");
    }

    // -- WHERE clause --

    #[test]
    fn test_where_clause_interval() {
        let result = rewrite("SELECT * FROM events WHERE ts > ts2 - INTERVAL '10' SECOND");
        assert!(result.contains("ts2 - 10000"), "got: {result}");
    }

    // -- BETWEEN (from issue example) --

    #[test]
    fn test_between_interval() {
        let result = rewrite(
            "SELECT * FROM trades t \
             INNER JOIN orders o ON t.symbol = o.symbol \
             AND o.ts BETWEEN t.ts - INTERVAL '10' SECOND AND t.ts + INTERVAL '10' SECOND",
        );
        assert!(result.contains("t.ts - 10000"), "got: {result}");
        assert!(result.contains("t.ts + 10000"), "got: {result}");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    // -- JOIN ON condition --

    #[test]
    fn test_join_on_interval() {
        let result = rewrite(
            "SELECT * FROM a JOIN b ON a.id = b.id \
             AND b.ts BETWEEN a.ts - INTERVAL '5' MINUTE AND a.ts + INTERVAL '5' MINUTE",
        );
        assert!(result.contains("a.ts - 300000"), "got: {result}");
        assert!(result.contains("a.ts + 300000"), "got: {result}");
    }

    // -- Nested expressions --

    #[test]
    fn test_nested_parens() {
        let result = rewrite("SELECT * FROM e WHERE (ts - INTERVAL '1' SECOND) > 0");
        assert!(result.contains("ts - 1000"), "got: {result}");
    }

    // -- Left-side INTERVAL (INTERVAL + col) --

    #[test]
    fn test_interval_on_left_side() {
        let result = rewrite("SELECT INTERVAL '10' SECOND + ts FROM events");
        assert!(result.contains("10000 + ts"), "got: {result}");
    }

    // -- Set operations --

    #[test]
    fn test_union_rewrites_both_arms() {
        let result = rewrite(
            "SELECT ts - INTERVAL '1' SECOND FROM a \
             UNION ALL SELECT ts + INTERVAL '2' SECOND FROM b",
        );
        assert!(result.contains("ts - 1000"), "got: {result}");
        assert!(result.contains("ts + 2000"), "got: {result}");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    // -- No-op cases (should not be modified) --

    #[test]
    fn test_no_interval_unchanged() {
        let result = rewrite("SELECT ts - 10000 FROM events");
        assert!(result.contains("ts - 10000"), "got: {result}");
    }

    #[test]
    fn test_unsupported_unit_unchanged() {
        // MONTH has no fixed millisecond length, so it must not be rewritten.
        let result = rewrite("SELECT ts - INTERVAL '1' MONTH FROM events");
        assert!(result.contains("INTERVAL '1' MONTH"), "got: {result}");
    }

    #[test]
    fn test_interval_default_unit_is_second() {
        // No unit at all: the rewriter treats the value as seconds. The generic
        // dialect is used because it accepts an INTERVAL without a qualifier.
        let dialect = sqlparser::dialect::GenericDialect {};
        let mut stmts = sqlparser::parser::Parser::parse_sql(
            &dialect,
            "SELECT ts - INTERVAL '5' FROM events",
        )
        .unwrap();
        assert!(!stmts.is_empty());
        rewrite_interval_arithmetic(&mut stmts[0]);
        let result = stmts[0].to_string();
        assert!(result.contains("ts - 5000"), "got: {result}");
    }

    // -- Multiple intervals in same query --

    #[test]
    fn test_multiple_intervals() {
        let result = rewrite(
            "SELECT * FROM events \
             WHERE ts > start_ts - INTERVAL '10' SECOND \
             AND ts < end_ts + INTERVAL '30' SECOND",
        );
        assert!(result.contains("start_ts - 10000"), "got: {result}");
        assert!(result.contains("end_ts + 30000"), "got: {result}");
    }

    // -- HAVING clause --

    #[test]
    fn test_having_clause_interval() {
        let result = rewrite(
            "SELECT symbol, COUNT(*) FROM trades \
             GROUP BY symbol \
             HAVING MAX(ts) - MIN(ts) > INTERVAL '1' HOUR",
        );
        // The INTERVAL is an operand of `>` (a comparison, not +/-), so the
        // rewriter deliberately leaves it untouched.
        assert!(result.contains("HAVING"), "got: {result}");
        assert!(result.contains("INTERVAL '1' HOUR"), "got: {result}");
    }
}