laminar_sql/parser/
interval_rewriter.rs1use sqlparser::ast::{
23 BinaryOperator, DateTimeField, Expr, JoinConstraint, JoinOperator, Query, Select, SelectItem,
24 SetExpr, Statement, Value,
25};
26
27fn interval_to_millis(interval: &sqlparser::ast::Interval) -> Option<i64> {
32 let value = extract_interval_numeric(&interval.value)?;
33 let unit = interval
34 .leading_field
35 .clone()
36 .unwrap_or(DateTimeField::Second);
37
38 let millis = match unit {
39 DateTimeField::Millisecond | DateTimeField::Milliseconds => value,
40 DateTimeField::Second | DateTimeField::Seconds => value.checked_mul(1_000)?,
41 DateTimeField::Minute | DateTimeField::Minutes => value.checked_mul(60_000)?,
42 DateTimeField::Hour | DateTimeField::Hours => value.checked_mul(3_600_000)?,
43 DateTimeField::Day | DateTimeField::Days => value.checked_mul(86_400_000)?,
44 _ => return None,
45 };
46
47 Some(millis)
48}
49
50fn extract_interval_numeric(expr: &Expr) -> Option<i64> {
52 match expr {
53 Expr::Value(vws) => match &vws.value {
54 Value::Number(n, _) => n.parse().ok(),
55 Value::SingleQuotedString(s) => s.split_whitespace().next()?.parse().ok(),
56 _ => None,
57 },
58 _ => None,
59 }
60}
61
62fn make_number_expr(n: i64) -> Expr {
67 use sqlparser::dialect::GenericDialect;
68 let s = n.to_string();
69 let dialect = GenericDialect {};
70 sqlparser::parser::Parser::new(&dialect)
71 .try_with_sql(&s)
72 .expect("number literal should tokenize")
73 .parse_expr()
74 .expect("number literal should parse as Expr")
75}
76
77pub fn rewrite_expr_mut(expr: &mut Expr) {
86 if let Expr::BinaryOp { left, op, right } = expr {
87 let is_add_sub = matches!(*op, BinaryOperator::Plus | BinaryOperator::Minus);
88
89 if is_add_sub {
90 let right_ms: Option<i64> = if let Expr::Interval(interval) = right.as_ref() {
92 interval_to_millis(interval)
93 } else {
94 None
95 };
96
97 if let Some(ms) = right_ms {
98 **right = make_number_expr(ms);
99 rewrite_expr_mut(left);
100 return;
101 }
102
103 if matches!(*op, BinaryOperator::Plus) {
105 let left_ms: Option<i64> = if let Expr::Interval(interval) = left.as_ref() {
106 interval_to_millis(interval)
107 } else {
108 None
109 };
110
111 if let Some(ms) = left_ms {
112 **left = make_number_expr(ms);
113 rewrite_expr_mut(right);
114 return;
115 }
116 }
117 }
118
119 rewrite_expr_mut(left);
121 rewrite_expr_mut(right);
122 return;
123 }
124
125 match expr {
127 Expr::Between {
128 expr: e, low, high, ..
129 } => {
130 rewrite_expr_mut(e);
131 rewrite_expr_mut(low);
132 rewrite_expr_mut(high);
133 }
134 Expr::InList { expr: e, list, .. } => {
135 rewrite_expr_mut(e);
136 for item in list {
137 rewrite_expr_mut(item);
138 }
139 }
140 Expr::Nested(inner)
141 | Expr::UnaryOp { expr: inner, .. }
142 | Expr::Cast { expr: inner, .. }
143 | Expr::IsNull(inner)
144 | Expr::IsNotNull(inner)
145 | Expr::IsFalse(inner)
146 | Expr::IsNotFalse(inner)
147 | Expr::IsTrue(inner)
148 | Expr::IsNotTrue(inner) => rewrite_expr_mut(inner),
149 _ => {}
150 }
151}
152
153pub fn rewrite_interval_arithmetic(stmt: &mut Statement) {
162 if let Statement::Query(query) = stmt {
163 rewrite_query(query);
164 }
165}
166
167fn rewrite_query(query: &mut Query) {
168 rewrite_set_expr(&mut query.body);
169}
170
171fn rewrite_set_expr(body: &mut SetExpr) {
172 match body {
173 SetExpr::Select(select) => rewrite_select(select),
174 SetExpr::Query(query) => rewrite_query(query),
175 SetExpr::SetOperation { left, right, .. } => {
176 rewrite_set_expr(left);
177 rewrite_set_expr(right);
178 }
179 _ => {}
180 }
181}
182
183fn rewrite_select(select: &mut Select) {
184 for item in &mut select.projection {
186 match item {
187 SelectItem::UnnamedExpr(ref mut expr)
188 | SelectItem::ExprWithAlias { ref mut expr, .. } => {
189 rewrite_expr_mut(expr);
190 }
191 _ => {}
192 }
193 }
194
195 if let Some(ref mut where_expr) = select.selection {
197 rewrite_expr_mut(where_expr);
198 }
199
200 if let Some(ref mut having) = select.having {
202 rewrite_expr_mut(having);
203 }
204
205 for table_with_joins in &mut select.from {
207 for join in &mut table_with_joins.joins {
208 rewrite_join_operator(&mut join.join_operator);
209 }
210 }
211}
212
213fn rewrite_join_operator(jo: &mut JoinOperator) {
214 let (JoinOperator::Join(constraint)
215 | JoinOperator::Inner(constraint)
216 | JoinOperator::LeftOuter(constraint)
217 | JoinOperator::RightOuter(constraint)
218 | JoinOperator::FullOuter(constraint)
219 | JoinOperator::LeftSemi(constraint)
220 | JoinOperator::RightSemi(constraint)
221 | JoinOperator::LeftAnti(constraint)
222 | JoinOperator::RightAnti(constraint)) = jo
223 else {
224 return;
225 };
226 if let JoinConstraint::On(expr) = constraint {
227 rewrite_expr_mut(expr);
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::dialect::LaminarDialect;

    /// Parses `sql` with the Laminar dialect, rewrites the first statement and
    /// renders it back to SQL text.
    fn rewrite(sql: &str) -> String {
        let dialect = LaminarDialect::default();
        let mut stmts = sqlparser::parser::Parser::parse_sql(&dialect, sql).unwrap();
        assert!(!stmts.is_empty());
        rewrite_interval_arithmetic(&mut stmts[0]);
        stmts[0].to_string()
    }

    #[test]
    fn test_subtract_interval_seconds() {
        let result = rewrite("SELECT ts - INTERVAL '10' SECOND FROM events");
        assert!(result.contains("ts - 10000"), "got: {result}");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    #[test]
    fn test_add_interval_seconds() {
        let result = rewrite("SELECT ts + INTERVAL '5' SECOND FROM events");
        assert!(result.contains("ts + 5000"), "got: {result}");
    }

    #[test]
    fn test_interval_minutes() {
        let result = rewrite("SELECT ts - INTERVAL '2' MINUTE FROM events");
        assert!(result.contains("ts - 120000"), "got: {result}");
    }

    #[test]
    fn test_interval_hours() {
        let result = rewrite("SELECT ts + INTERVAL '1' HOUR FROM events");
        assert!(result.contains("ts + 3600000"), "got: {result}");
    }

    #[test]
    fn test_interval_days() {
        let result = rewrite("SELECT ts - INTERVAL '1' DAY FROM events");
        assert!(result.contains("ts - 86400000"), "got: {result}");
    }

    #[test]
    fn test_interval_milliseconds() {
        let result = rewrite("SELECT ts - INTERVAL '100' MILLISECOND FROM events");
        assert!(result.contains("ts - 100"), "got: {result}");
    }

    #[test]
    fn test_where_clause_interval() {
        let result = rewrite("SELECT * FROM events WHERE ts > ts2 - INTERVAL '10' SECOND");
        assert!(result.contains("ts2 - 10000"), "got: {result}");
    }

    #[test]
    fn test_between_interval() {
        let result = rewrite(
            "SELECT * FROM trades t \
             INNER JOIN orders o ON t.symbol = o.symbol \
             AND o.ts BETWEEN t.ts - INTERVAL '10' SECOND AND t.ts + INTERVAL '10' SECOND",
        );
        assert!(result.contains("t.ts - 10000"), "got: {result}");
        assert!(result.contains("t.ts + 10000"), "got: {result}");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    #[test]
    fn test_join_on_interval() {
        let result = rewrite(
            "SELECT * FROM a JOIN b ON a.id = b.id \
             AND b.ts BETWEEN a.ts - INTERVAL '5' MINUTE AND a.ts + INTERVAL '5' MINUTE",
        );
        assert!(result.contains("a.ts - 300000"), "got: {result}");
        assert!(result.contains("a.ts + 300000"), "got: {result}");
    }

    #[test]
    fn test_nested_parens() {
        let result = rewrite("SELECT * FROM e WHERE (ts - INTERVAL '1' SECOND) > 0");
        assert!(result.contains("ts - 1000"), "got: {result}");
    }

    #[test]
    fn test_interval_on_left_side() {
        let result = rewrite("SELECT INTERVAL '10' SECOND + ts FROM events");
        assert!(result.contains("10000 + ts"), "got: {result}");
    }

    #[test]
    fn test_no_interval_unchanged() {
        let result = rewrite("SELECT ts - 10000 FROM events");
        assert!(result.contains("ts - 10000"), "got: {result}");
    }

    // NOTE(review): the SQL below spells out `SECOND`, so this exercises the
    // explicit-qualifier path, not the omitted-unit default the name suggests.
    // Switch to `INTERVAL '5'` if LaminarDialect accepts unqualified intervals.
    #[test]
    fn test_interval_default_unit_is_second() {
        let result = rewrite("SELECT ts - INTERVAL '5' SECOND FROM events");
        assert!(result.contains("ts - 5000"), "got: {result}");
    }

    #[test]
    fn test_multiple_intervals() {
        let result = rewrite(
            "SELECT * FROM events \
             WHERE ts > start_ts - INTERVAL '10' SECOND \
             AND ts < end_ts + INTERVAL '30' SECOND",
        );
        assert!(result.contains("start_ts - 10000"), "got: {result}");
        assert!(result.contains("end_ts + 30000"), "got: {result}");
    }

    // The bare interval here is an operand of `>`, not of `+`/`-`, so the
    // rewriter leaves it alone. This test only checks that the HAVING clause
    // survives the round trip.
    #[test]
    fn test_having_clause_interval() {
        let result = rewrite(
            "SELECT symbol, COUNT(*) FROM trades \
             GROUP BY symbol \
             HAVING MAX(ts) - MIN(ts) > INTERVAL '1' HOUR",
        );
        assert!(result.contains("HAVING"), "got: {result}");
    }

    // Calendar units have no fixed millisecond length, so the interval must be
    // kept rather than converted.
    #[test]
    fn test_unsupported_unit_left_unchanged() {
        let result = rewrite("SELECT ts - INTERVAL '1' MONTH FROM events");
        assert!(result.contains("INTERVAL"), "got: {result}");
    }
}