// laminar_sql/parser/interval_rewriter.rs
use sqlparser::ast::{
23 BinaryOperator, DateTimeField, Expr, JoinConstraint, JoinOperator, Query, Select, SelectItem,
24 SetExpr, Statement, Value,
25};
26
27fn interval_to_millis(interval: &sqlparser::ast::Interval) -> Option<i64> {
32 let value = extract_interval_numeric(&interval.value)?;
33 let unit = interval
34 .leading_field
35 .clone()
36 .unwrap_or(DateTimeField::Second);
37
38 let millis = match unit {
39 DateTimeField::Millisecond | DateTimeField::Milliseconds => value,
40 DateTimeField::Second | DateTimeField::Seconds => value.checked_mul(1_000)?,
41 DateTimeField::Minute | DateTimeField::Minutes => value.checked_mul(60_000)?,
42 DateTimeField::Hour | DateTimeField::Hours => value.checked_mul(3_600_000)?,
43 DateTimeField::Day | DateTimeField::Days => value.checked_mul(86_400_000)?,
44 _ => return None,
45 };
46
47 Some(millis)
48}
49
50fn extract_interval_numeric(expr: &Expr) -> Option<i64> {
52 match expr {
53 Expr::Value(vws) => match &vws.value {
54 Value::Number(n, _) => n.parse().ok(),
55 Value::SingleQuotedString(s) => s.split_whitespace().next()?.parse().ok(),
56 _ => None,
57 },
58 _ => None,
59 }
60}
61
62fn make_number_expr(n: i64) -> Option<Expr> {
67 use sqlparser::dialect::GenericDialect;
68 let s = n.to_string();
69 let dialect = GenericDialect {};
70 sqlparser::parser::Parser::new(&dialect)
71 .try_with_sql(&s)
72 .ok()?
73 .parse_expr()
74 .ok()
75}
76
77pub fn rewrite_expr_mut(expr: &mut Expr) {
86 if let Expr::BinaryOp { left, op, right } = expr {
87 let is_add_sub = matches!(*op, BinaryOperator::Plus | BinaryOperator::Minus);
88
89 if is_add_sub {
90 let right_ms: Option<i64> = if let Expr::Interval(interval) = right.as_ref() {
92 interval_to_millis(interval)
93 } else {
94 None
95 };
96
97 if let Some(ms) = right_ms {
98 if let Some(num_expr) = make_number_expr(ms) {
99 **right = num_expr;
100 rewrite_expr_mut(left);
101 return;
102 }
103 }
104
105 if matches!(*op, BinaryOperator::Plus) {
107 let left_ms: Option<i64> = if let Expr::Interval(interval) = left.as_ref() {
108 interval_to_millis(interval)
109 } else {
110 None
111 };
112
113 if let Some(ms) = left_ms {
114 if let Some(num_expr) = make_number_expr(ms) {
115 **left = num_expr;
116 rewrite_expr_mut(right);
117 return;
118 }
119 }
120 }
121 }
122
123 rewrite_expr_mut(left);
125 rewrite_expr_mut(right);
126 return;
127 }
128
129 match expr {
131 Expr::Between {
132 expr: e, low, high, ..
133 } => {
134 rewrite_expr_mut(e);
135 rewrite_expr_mut(low);
136 rewrite_expr_mut(high);
137 }
138 Expr::InList { expr: e, list, .. } => {
139 rewrite_expr_mut(e);
140 for item in list {
141 rewrite_expr_mut(item);
142 }
143 }
144 Expr::Nested(inner)
145 | Expr::UnaryOp { expr: inner, .. }
146 | Expr::Cast { expr: inner, .. }
147 | Expr::IsNull(inner)
148 | Expr::IsNotNull(inner)
149 | Expr::IsFalse(inner)
150 | Expr::IsNotFalse(inner)
151 | Expr::IsTrue(inner)
152 | Expr::IsNotTrue(inner) => rewrite_expr_mut(inner),
153 _ => {}
154 }
155}
156
157pub fn rewrite_interval_arithmetic(stmt: &mut Statement) {
166 if let Statement::Query(query) = stmt {
167 rewrite_query(query);
168 }
169}
170
171fn rewrite_query(query: &mut Query) {
172 rewrite_set_expr(&mut query.body);
173}
174
175fn rewrite_set_expr(body: &mut SetExpr) {
176 match body {
177 SetExpr::Select(select) => rewrite_select(select),
178 SetExpr::Query(query) => rewrite_query(query),
179 SetExpr::SetOperation { left, right, .. } => {
180 rewrite_set_expr(left);
181 rewrite_set_expr(right);
182 }
183 _ => {}
184 }
185}
186
187fn rewrite_select(select: &mut Select) {
188 for item in &mut select.projection {
190 match item {
191 SelectItem::UnnamedExpr(ref mut expr)
192 | SelectItem::ExprWithAlias { ref mut expr, .. } => {
193 rewrite_expr_mut(expr);
194 }
195 _ => {}
196 }
197 }
198
199 if let Some(ref mut where_expr) = select.selection {
201 rewrite_expr_mut(where_expr);
202 }
203
204 if let Some(ref mut having) = select.having {
206 rewrite_expr_mut(having);
207 }
208
209 for table_with_joins in &mut select.from {
211 for join in &mut table_with_joins.joins {
212 rewrite_join_operator(&mut join.join_operator);
213 }
214 }
215}
216
217fn rewrite_join_operator(jo: &mut JoinOperator) {
218 let (JoinOperator::Join(constraint)
219 | JoinOperator::Inner(constraint)
220 | JoinOperator::Left(constraint)
221 | JoinOperator::LeftOuter(constraint)
222 | JoinOperator::Right(constraint)
223 | JoinOperator::RightOuter(constraint)
224 | JoinOperator::FullOuter(constraint)
225 | JoinOperator::StraightJoin(constraint)
226 | JoinOperator::LeftSemi(constraint)
227 | JoinOperator::RightSemi(constraint)
228 | JoinOperator::LeftAnti(constraint)
229 | JoinOperator::RightAnti(constraint)
230 | JoinOperator::Semi(constraint)
231 | JoinOperator::Anti(constraint)) = jo
232 else {
233 return;
234 };
235 if let JoinConstraint::On(expr) = constraint {
236 rewrite_expr_mut(expr);
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::dialect::LaminarDialect;

    /// Parses `sql` with the Laminar dialect, rewrites the first statement, and renders
    /// it back to SQL text.
    fn rewrite(sql: &str) -> String {
        let dialect = LaminarDialect::default();
        let mut statements = sqlparser::parser::Parser::parse_sql(&dialect, sql).unwrap();
        assert!(!statements.is_empty());
        let first = &mut statements[0];
        rewrite_interval_arithmetic(first);
        first.to_string()
    }

    /// Asserts that `rendered` contains every fragment, in order of checking.
    #[track_caller]
    fn assert_has(rendered: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(rendered.contains(fragment), "got: {rendered}");
        }
    }

    /// Asserts that no `INTERVAL` keyword survived the rewrite.
    #[track_caller]
    fn assert_no_interval(rendered: &str) {
        assert!(!rendered.contains("INTERVAL"), "got: {rendered}");
    }

    #[test]
    fn test_subtract_interval_seconds() {
        let sql = rewrite("SELECT ts - INTERVAL '10' SECOND FROM events");
        assert_has(&sql, &["ts - 10000"]);
        assert_no_interval(&sql);
    }

    #[test]
    fn test_add_interval_seconds() {
        let sql = rewrite("SELECT ts + INTERVAL '5' SECOND FROM events");
        assert_has(&sql, &["ts + 5000"]);
    }

    #[test]
    fn test_interval_minutes() {
        let sql = rewrite("SELECT ts - INTERVAL '2' MINUTE FROM events");
        assert_has(&sql, &["ts - 120000"]);
    }

    #[test]
    fn test_interval_hours() {
        let sql = rewrite("SELECT ts + INTERVAL '1' HOUR FROM events");
        assert_has(&sql, &["ts + 3600000"]);
    }

    #[test]
    fn test_interval_days() {
        let sql = rewrite("SELECT ts - INTERVAL '1' DAY FROM events");
        assert_has(&sql, &["ts - 86400000"]);
    }

    #[test]
    fn test_interval_milliseconds() {
        let sql = rewrite("SELECT ts - INTERVAL '100' MILLISECOND FROM events");
        assert_has(&sql, &["ts - 100"]);
    }

    #[test]
    fn test_where_clause_interval() {
        let sql = rewrite("SELECT * FROM events WHERE ts > ts2 - INTERVAL '10' SECOND");
        assert_has(&sql, &["ts2 - 10000"]);
    }

    #[test]
    fn test_between_interval() {
        let sql = rewrite(
            "SELECT * FROM trades t \
             INNER JOIN orders o ON t.symbol = o.symbol \
             AND o.ts BETWEEN t.ts - INTERVAL '10' SECOND AND t.ts + INTERVAL '10' SECOND",
        );
        assert_has(&sql, &["t.ts - 10000", "t.ts + 10000"]);
        assert_no_interval(&sql);
    }

    #[test]
    fn test_join_on_interval() {
        let sql = rewrite(
            "SELECT * FROM a JOIN b ON a.id = b.id \
             AND b.ts BETWEEN a.ts - INTERVAL '5' MINUTE AND a.ts + INTERVAL '5' MINUTE",
        );
        assert_has(&sql, &["a.ts - 300000", "a.ts + 300000"]);
    }

    #[test]
    fn test_nested_parens() {
        let sql = rewrite("SELECT * FROM e WHERE (ts - INTERVAL '1' SECOND) > 0");
        assert_has(&sql, &["ts - 1000"]);
    }

    #[test]
    fn test_interval_on_left_side() {
        let sql = rewrite("SELECT INTERVAL '10' SECOND + ts FROM events");
        assert_has(&sql, &["10000 + ts"]);
    }

    #[test]
    fn test_no_interval_unchanged() {
        let sql = rewrite("SELECT ts - 10000 FROM events");
        assert_has(&sql, &["ts - 10000"]);
    }

    // NOTE(review): this query spells out `SECOND`, so it does not actually exercise the
    // no-qualifier default; confirm whether the dialect accepts a bare `INTERVAL '5'`.
    #[test]
    fn test_interval_default_unit_is_second() {
        let sql = rewrite("SELECT ts - INTERVAL '5' SECOND FROM events");
        assert_has(&sql, &["ts - 5000"]);
    }

    #[test]
    fn test_multiple_intervals() {
        let sql = rewrite(
            "SELECT * FROM events \
             WHERE ts > start_ts - INTERVAL '10' SECOND \
             AND ts < end_ts + INTERVAL '30' SECOND",
        );
        assert_has(&sql, &["start_ts - 10000", "end_ts + 30000"]);
    }

    #[test]
    fn test_having_clause_interval() {
        let sql = rewrite(
            "SELECT symbol, COUNT(*) FROM trades \
             GROUP BY symbol \
             HAVING MAX(ts) - MIN(ts) > INTERVAL '1' HOUR",
        );
        assert_has(&sql, &["HAVING"]);
    }
}