laminar_sql/parser/
interval_rewriter.rs1use sqlparser::ast::{
23 BinaryOperator, DateTimeField, Expr, JoinConstraint, JoinOperator, Query, Select, SelectItem,
24 SetExpr, Statement, Value,
25};
26
27fn interval_to_millis(interval: &sqlparser::ast::Interval) -> Option<i64> {
32 let value = extract_interval_numeric(&interval.value)?;
33 let unit = interval
34 .leading_field
35 .clone()
36 .unwrap_or(DateTimeField::Second);
37
38 let millis = match unit {
39 DateTimeField::Millisecond | DateTimeField::Milliseconds => value,
40 DateTimeField::Second | DateTimeField::Seconds => value.checked_mul(1_000)?,
41 DateTimeField::Minute | DateTimeField::Minutes => value.checked_mul(60_000)?,
42 DateTimeField::Hour | DateTimeField::Hours => value.checked_mul(3_600_000)?,
43 DateTimeField::Day | DateTimeField::Days => value.checked_mul(86_400_000)?,
44 _ => return None,
45 };
46
47 Some(millis)
48}
49
50fn extract_interval_numeric(expr: &Expr) -> Option<i64> {
52 match expr {
53 Expr::Value(vws) => match &vws.value {
54 Value::Number(n, _) => n.parse().ok(),
55 Value::SingleQuotedString(s) => s.split_whitespace().next()?.parse().ok(),
56 _ => None,
57 },
58 _ => None,
59 }
60}
61
62fn make_number_expr(n: i64) -> Expr {
67 use sqlparser::dialect::GenericDialect;
68 let s = n.to_string();
69 let dialect = GenericDialect {};
70 sqlparser::parser::Parser::new(&dialect)
71 .try_with_sql(&s)
72 .expect("number literal should tokenize")
73 .parse_expr()
74 .expect("number literal should parse as Expr")
75}
76
77pub fn rewrite_expr_mut(expr: &mut Expr) {
86 if let Expr::BinaryOp { left, op, right } = expr {
87 let is_add_sub = matches!(*op, BinaryOperator::Plus | BinaryOperator::Minus);
88
89 if is_add_sub {
90 let right_ms: Option<i64> = if let Expr::Interval(interval) = right.as_ref() {
92 interval_to_millis(interval)
93 } else {
94 None
95 };
96
97 if let Some(ms) = right_ms {
98 **right = make_number_expr(ms);
99 rewrite_expr_mut(left);
100 return;
101 }
102
103 if matches!(*op, BinaryOperator::Plus) {
105 let left_ms: Option<i64> = if let Expr::Interval(interval) = left.as_ref() {
106 interval_to_millis(interval)
107 } else {
108 None
109 };
110
111 if let Some(ms) = left_ms {
112 **left = make_number_expr(ms);
113 rewrite_expr_mut(right);
114 return;
115 }
116 }
117 }
118
119 rewrite_expr_mut(left);
121 rewrite_expr_mut(right);
122 return;
123 }
124
125 match expr {
127 Expr::Between {
128 expr: e, low, high, ..
129 } => {
130 rewrite_expr_mut(e);
131 rewrite_expr_mut(low);
132 rewrite_expr_mut(high);
133 }
134 Expr::InList { expr: e, list, .. } => {
135 rewrite_expr_mut(e);
136 for item in list {
137 rewrite_expr_mut(item);
138 }
139 }
140 Expr::Nested(inner)
141 | Expr::UnaryOp { expr: inner, .. }
142 | Expr::Cast { expr: inner, .. }
143 | Expr::IsNull(inner)
144 | Expr::IsNotNull(inner)
145 | Expr::IsFalse(inner)
146 | Expr::IsNotFalse(inner)
147 | Expr::IsTrue(inner)
148 | Expr::IsNotTrue(inner) => rewrite_expr_mut(inner),
149 _ => {}
150 }
151}
152
153pub fn rewrite_interval_arithmetic(stmt: &mut Statement) {
162 if let Statement::Query(query) = stmt {
163 rewrite_query(query);
164 }
165}
166
167fn rewrite_query(query: &mut Query) {
168 rewrite_set_expr(&mut query.body);
169}
170
171fn rewrite_set_expr(body: &mut SetExpr) {
172 match body {
173 SetExpr::Select(select) => rewrite_select(select),
174 SetExpr::Query(query) => rewrite_query(query),
175 SetExpr::SetOperation { left, right, .. } => {
176 rewrite_set_expr(left);
177 rewrite_set_expr(right);
178 }
179 _ => {}
180 }
181}
182
183fn rewrite_select(select: &mut Select) {
184 for item in &mut select.projection {
186 match item {
187 SelectItem::UnnamedExpr(ref mut expr)
188 | SelectItem::ExprWithAlias { ref mut expr, .. } => {
189 rewrite_expr_mut(expr);
190 }
191 _ => {}
192 }
193 }
194
195 if let Some(ref mut where_expr) = select.selection {
197 rewrite_expr_mut(where_expr);
198 }
199
200 if let Some(ref mut having) = select.having {
202 rewrite_expr_mut(having);
203 }
204
205 for table_with_joins in &mut select.from {
207 for join in &mut table_with_joins.joins {
208 rewrite_join_operator(&mut join.join_operator);
209 }
210 }
211}
212
213fn rewrite_join_operator(jo: &mut JoinOperator) {
214 let (JoinOperator::Join(constraint)
215 | JoinOperator::Inner(constraint)
216 | JoinOperator::Left(constraint)
217 | JoinOperator::LeftOuter(constraint)
218 | JoinOperator::Right(constraint)
219 | JoinOperator::RightOuter(constraint)
220 | JoinOperator::FullOuter(constraint)
221 | JoinOperator::StraightJoin(constraint)
222 | JoinOperator::LeftSemi(constraint)
223 | JoinOperator::RightSemi(constraint)
224 | JoinOperator::LeftAnti(constraint)
225 | JoinOperator::RightAnti(constraint)
226 | JoinOperator::Semi(constraint)
227 | JoinOperator::Anti(constraint)) = jo
228 else {
229 return;
230 };
231 if let JoinConstraint::On(expr) = constraint {
232 rewrite_expr_mut(expr);
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::dialect::LaminarDialect;

    /// Parses `sql` with `dialect`, applies `rewrite_interval_arithmetic` to the
    /// first statement, and renders the result back to SQL text.
    fn rewrite_with(dialect: &dyn sqlparser::dialect::Dialect, sql: &str) -> String {
        let mut stmts = sqlparser::parser::Parser::parse_sql(dialect, sql).unwrap();
        assert!(!stmts.is_empty());
        rewrite_interval_arithmetic(&mut stmts[0]);
        stmts[0].to_string()
    }

    /// `rewrite_with` using the project's `LaminarDialect`.
    fn rewrite(sql: &str) -> String {
        rewrite_with(&LaminarDialect::default(), sql)
    }

    /// Asserts that `needle` occurs in `haystack` and is not immediately
    /// followed by another digit. This stops `"ts - 100"` from matching
    /// `"ts - 1000"`, which a plain `contains` check would accept.
    fn assert_term(haystack: &str, needle: &str) {
        let found = haystack.match_indices(needle).any(|(at, _)| {
            !haystack[at + needle.len()..].starts_with(|c: char| c.is_ascii_digit())
        });
        assert!(found, "expected {needle:?} in: {haystack}");
    }

    #[test]
    fn test_subtract_interval_seconds() {
        let result = rewrite("SELECT ts - INTERVAL '10' SECOND FROM events");
        assert_term(&result, "ts - 10000");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    #[test]
    fn test_add_interval_seconds() {
        let result = rewrite("SELECT ts + INTERVAL '5' SECOND FROM events");
        assert_term(&result, "ts + 5000");
    }

    #[test]
    fn test_interval_minutes() {
        let result = rewrite("SELECT ts - INTERVAL '2' MINUTE FROM events");
        assert_term(&result, "ts - 120000");
    }

    #[test]
    fn test_interval_hours() {
        let result = rewrite("SELECT ts + INTERVAL '1' HOUR FROM events");
        assert_term(&result, "ts + 3600000");
    }

    #[test]
    fn test_interval_days() {
        let result = rewrite("SELECT ts - INTERVAL '1' DAY FROM events");
        assert_term(&result, "ts - 86400000");
    }

    #[test]
    fn test_interval_milliseconds() {
        let result = rewrite("SELECT ts - INTERVAL '100' MILLISECOND FROM events");
        assert_term(&result, "ts - 100");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    #[test]
    fn test_where_clause_interval() {
        let result = rewrite("SELECT * FROM events WHERE ts > ts2 - INTERVAL '10' SECOND");
        assert_term(&result, "ts2 - 10000");
    }

    #[test]
    fn test_between_interval() {
        let result = rewrite(
            "SELECT * FROM trades t \
             INNER JOIN orders o ON t.symbol = o.symbol \
             AND o.ts BETWEEN t.ts - INTERVAL '10' SECOND AND t.ts + INTERVAL '10' SECOND",
        );
        assert_term(&result, "t.ts - 10000");
        assert_term(&result, "t.ts + 10000");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    #[test]
    fn test_join_on_interval() {
        let result = rewrite(
            "SELECT * FROM a JOIN b ON a.id = b.id \
             AND b.ts BETWEEN a.ts - INTERVAL '5' MINUTE AND a.ts + INTERVAL '5' MINUTE",
        );
        assert_term(&result, "a.ts - 300000");
        assert_term(&result, "a.ts + 300000");
    }

    #[test]
    fn test_nested_parens() {
        let result = rewrite("SELECT * FROM e WHERE (ts - INTERVAL '1' SECOND) > 0");
        assert_term(&result, "ts - 1000");
    }

    #[test]
    fn test_interval_on_left_side() {
        let result = rewrite("SELECT INTERVAL '10' SECOND + ts FROM events");
        assert_term(&result, "10000 + ts");
    }

    #[test]
    fn test_no_interval_unchanged() {
        let result = rewrite("SELECT ts - 10000 FROM events");
        assert_term(&result, "ts - 10000");
    }

    #[test]
    fn test_interval_explicit_second_unit() {
        let result = rewrite("SELECT ts - INTERVAL '5' SECOND FROM events");
        assert_term(&result, "ts - 5000");
    }

    #[test]
    fn test_interval_default_unit_is_second() {
        // GenericDialect accepts an INTERVAL with no unit qualifier, so this
        // exercises the "no unit means seconds" fallback.
        let result = rewrite_with(
            &sqlparser::dialect::GenericDialect {},
            "SELECT ts - INTERVAL '5' FROM events",
        );
        assert_term(&result, "ts - 5000");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    #[test]
    fn test_multiple_intervals() {
        let result = rewrite(
            "SELECT * FROM events \
             WHERE ts > start_ts - INTERVAL '10' SECOND \
             AND ts < end_ts + INTERVAL '30' SECOND",
        );
        assert_term(&result, "start_ts - 10000");
        assert_term(&result, "end_ts + 30000");
    }

    #[test]
    fn test_having_clause_interval() {
        let result = rewrite(
            "SELECT symbol, COUNT(*) FROM trades \
             GROUP BY symbol \
             HAVING MAX(ts) - MIN(ts) > INTERVAL '1' HOUR",
        );
        assert!(result.contains("HAVING"), "got: {result}");
        // Only `+`/`-` operands are rewritten; an INTERVAL on one side of a
        // comparison is currently left as-is.
        assert!(result.contains("INTERVAL '1' HOUR"), "got: {result}");
    }
}