1use datafusion::error::DataFusionError;
2use datafusion::sql::sqlparser::ast::{
3 self, BinaryOperator, Expr, FunctionArg, FunctionArgExpr, FunctionArgumentList, Ident,
4 ObjectNamePart, VisitorMut,
5};
6use std::fmt::Display;
7use std::ops::ControlFlow;
8use std::str::FromStr;
9
/// AST visitor that rewrites `+` / `-` expressions involving an `INTERVAL` literal into
/// calls to SQLite's `date` / `datetime` functions.
///
/// The visitor carries no state; the rewriting happens in its `VisitorMut` implementation.
#[derive(Default)]
pub struct SQLiteIntervalVisitor {}
12
/// The individual components of a parsed interval literal.
///
/// Every component defaults to zero. The integer components are signed; `nanos` is
/// unsigned and holds the fractional part of the seconds component.
#[derive(Default, Debug)]
struct IntervalParts {
    /// Number of years.
    years: i64,
    /// Number of months.
    months: i64,
    /// Number of days.
    days: i64,
    /// Number of hours.
    hours: i64,
    /// Number of minutes.
    minutes: i64,
    /// Whole seconds.
    seconds: i64,
    /// Fractional seconds, expressed in nanoseconds (nine decimal digits).
    nanos: u32,
}
23
/// Which SQLite function an interval expression is rewritten to.
///
/// The `Display` implementation renders the function name.
enum SQLiteIntervalType {
    /// Rendered as `date`.
    Date,
    /// Rendered as `datetime`.
    Datetime,
}
28
29impl Display for SQLiteIntervalType {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 match self {
32 SQLiteIntervalType::Date => write!(f, "date"),
33 SQLiteIntervalType::Datetime => write!(f, "datetime"),
34 }
35 }
36}
37
/// Signature shared by the `i64`-valued `IntervalParts::with_*` builder methods, so they
/// can be stored in a lookup table next to their unit keyword.
type IntervalSetter = fn(IntervalParts, i64) -> IntervalParts;
39
40impl IntervalParts {
41 fn new() -> Self {
42 Self::default()
43 }
44
45 fn intraday(&self) -> bool {
46 self.hours > 0 || self.minutes > 0 || self.seconds > 0 || self.nanos > 0
47 }
48
49 fn negate(mut self) -> Self {
50 self.years = -self.years;
51 self.months = -self.months;
52 self.days = -self.days;
53 self.hours = -self.hours;
54 self.minutes = -self.minutes;
55 self.seconds = -self.seconds;
56 self
57 }
58
59 fn with_years(mut self, years: i64) -> Self {
60 self.years = years;
61 self
62 }
63
64 fn with_months(mut self, months: i64) -> Self {
65 self.months = months;
66 self
67 }
68
69 fn with_days(mut self, days: i64) -> Self {
70 self.days = days;
71 self
72 }
73
74 fn with_hours(mut self, hours: i64) -> Self {
75 self.hours = hours;
76 self
77 }
78
79 fn with_minutes(mut self, minutes: i64) -> Self {
80 self.minutes = minutes;
81 self
82 }
83
84 fn with_seconds(mut self, seconds: i64) -> Self {
85 self.seconds = seconds;
86 self
87 }
88
89 fn with_nanos(mut self, nanos: u32) -> Self {
90 self.nanos = nanos;
91 self
92 }
93}
94
95impl VisitorMut for SQLiteIntervalVisitor {
96 type Break = ();
97
98 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
99 if let Expr::BinaryOp { op, left, right } = expr {
107 if *op != BinaryOperator::Plus && *op != BinaryOperator::Minus {
108 return ControlFlow::Continue(());
109 }
110
111 let (target, interval) = SQLiteIntervalVisitor::normalize_interval_expr(left, right);
112
113 if let Expr::Interval(_) = interval.as_ref() {
114 if let Ok(interval_parts) = SQLiteIntervalVisitor::parse_interval(interval) {
117 let interval_parts = if *op == BinaryOperator::Minus {
119 interval_parts.negate()
120 } else {
121 interval_parts
122 };
123
124 *expr =
125 SQLiteIntervalVisitor::create_datetime_function(target, &interval_parts);
126 }
127 }
128 }
129 ControlFlow::Continue(())
130 }
131}
132
133impl SQLiteIntervalVisitor {
134 fn normalize_interval_expr<'a>(
136 left: &'a mut Box<Expr>,
137 right: &'a mut Box<Expr>,
138 ) -> (&'a mut Box<Expr>, &'a mut Box<Expr>) {
139 if let Expr::Interval { .. } = left.as_ref() {
140 (right, left)
141 } else {
142 (left, right)
143 }
144 }
145
146 fn parse_interval(interval: &Expr) -> Result<IntervalParts, DataFusionError> {
147 if let Expr::Interval(interval_expr) = interval {
148 if let Expr::Value(ast::ValueWithSpan {
149 value: ast::Value::SingleQuotedString(value),
150 span: _,
151 }) = interval_expr.value.as_ref()
152 {
153 return SQLiteIntervalVisitor::parse_interval_string(value);
154 }
155 }
156 Err(DataFusionError::Plan(
157 "Invalid interval expression".to_string(),
158 ))
159 }
160
161 fn parse_interval_string(value: &str) -> Result<IntervalParts, DataFusionError> {
162 let mut parts = IntervalParts::new();
163 let mut remaining = value;
164
165 let components: [(_, IntervalSetter); 5] = [
166 ("YEARS", IntervalParts::with_years),
167 ("MONS", IntervalParts::with_months),
168 ("DAYS", IntervalParts::with_days),
169 ("HOURS", IntervalParts::with_hours),
170 ("MINS", IntervalParts::with_minutes),
171 ];
172
173 for (unit, setter) in &components {
174 if let Some((value, rest)) = remaining.split_once(unit) {
175 let parsed_value: i64 = SQLiteIntervalVisitor::parse_value(value.trim())?;
176 parts = setter(parts, parsed_value);
177 remaining = rest;
178 }
179 }
180
181 if let Some((secs, _)) = remaining.split_once("SECS") {
183 let (seconds, nanos) = SQLiteIntervalVisitor::parse_seconds_and_nanos(secs.trim())?;
184 parts = parts.with_seconds(seconds).with_nanos(nanos);
185 }
186
187 Ok(parts)
188 }
189
190 fn parse_seconds_and_nanos(value: &str) -> Result<(i64, u32), DataFusionError> {
191 let parts: Vec<&str> = value.split('.').collect();
192 let seconds = SQLiteIntervalVisitor::parse_value(parts[0])?;
193 let nanos = if parts.len() > 1 {
194 let nanos_str = format!("{:0<9}", parts[1]);
195 nanos_str[..9].parse().map_err(|_| {
196 DataFusionError::Plan(format!("Failed to parse nanoseconds: {}", parts[1]))
197 })?
198 } else {
199 0
200 };
201 Ok((seconds, nanos))
202 }
203
204 fn parse_value<T: FromStr>(value: &str) -> Result<T, DataFusionError> {
205 value
206 .parse()
207 .map_err(|_| DataFusionError::Plan(format!("Failed to parse interval value: {value}")))
208 }
209
210 fn create_datetime_function(target: &Expr, interval: &IntervalParts) -> Expr {
211 let interval_date_type = if interval.intraday() {
212 SQLiteIntervalType::Datetime
213 } else {
214 SQLiteIntervalType::Date
215 };
216
217 let function_args = vec![
218 Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(target.clone()))),
219 SQLiteIntervalVisitor::create_interval_arg("years", interval.years),
220 SQLiteIntervalVisitor::create_interval_arg("months", interval.months),
221 SQLiteIntervalVisitor::create_interval_arg("days", interval.days),
222 SQLiteIntervalVisitor::create_interval_arg("hours", interval.hours),
223 SQLiteIntervalVisitor::create_interval_arg("minutes", interval.minutes),
224 SQLiteIntervalVisitor::create_interval_arg_with_fraction(
225 "seconds",
226 interval.seconds,
227 interval.nanos,
228 ),
229 ]
230 .into_iter()
231 .flatten() .collect();
233
234 let datetime_function = Expr::Function(ast::Function {
235 name: ast::ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
236 interval_date_type.to_string(),
237 ))]),
238 args: ast::FunctionArguments::List(FunctionArgumentList {
239 duplicate_treatment: None,
240 args: function_args,
241 clauses: Vec::new(),
242 }),
243 filter: None,
244 null_treatment: None,
245 over: None,
246 within_group: Vec::new(),
247 parameters: ast::FunctionArguments::None,
248 uses_odbc_syntax: false,
249 });
250
251 Expr::Cast {
252 expr: Box::new(datetime_function),
253 data_type: ast::DataType::Text,
254 format: None,
255 kind: ast::CastKind::Cast,
256 }
257 }
258
259 fn create_interval_arg(unit: &str, value: i64) -> Option<FunctionArg> {
260 if value == 0 {
261 None
262 } else {
263 Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::value(
264 ast::Value::SingleQuotedString(format!("{value:+} {unit}")),
265 ))))
266 }
267 }
268
269 fn create_interval_arg_with_fraction(
270 unit: &str,
271 value: i64,
272 fraction: u32,
273 ) -> Option<FunctionArg> {
274 if value == 0 && fraction == 0 {
275 None
276 } else {
277 let fraction_str = if fraction > 0 {
278 format!(".{fraction:09}")
279 } else {
280 String::new()
281 };
282
283 Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::value(
284 ast::Value::SingleQuotedString(format!("{value:+}{fraction_str} {unit}")),
285 ))))
286 }
287 }
288}
289
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `input`, panicking when the interval string is rejected.
    fn parse(input: &str) -> IntervalParts {
        SQLiteIntervalVisitor::parse_interval_string(input)
            .expect("interval parts should be parsed")
    }

    /// Asserts every component of `parts`.
    ///
    /// `expected` is ordered as `[years, months, days, hours, minutes, seconds]`.
    fn assert_parts(parts: &IntervalParts, expected: [i64; 6], nanos: u32) {
        let [years, months, days, hours, minutes, seconds] = expected;
        assert_eq!(parts.years, years);
        assert_eq!(parts.months, months);
        assert_eq!(parts.days, days);
        assert_eq!(parts.hours, hours);
        assert_eq!(parts.minutes, minutes);
        assert_eq!(parts.seconds, seconds);
        assert_eq!(parts.nanos, nanos);
    }

    /// Wraps `text` as an unnamed single-quoted string function argument.
    fn text_arg(text: &str) -> FunctionArg {
        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::value(
            ast::Value::SingleQuotedString(text.to_string()),
        )))
    }

    /// Builds the expected `CAST(<function_name>(<args>) AS TEXT)` expression.
    fn expected_call(function_name: &str, args: Vec<FunctionArg>) -> Expr {
        let function = ast::Function {
            name: ast::ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
                function_name,
            ))]),
            args: ast::FunctionArguments::List(FunctionArgumentList {
                duplicate_treatment: None,
                args,
                clauses: Vec::new(),
            }),
            filter: None,
            null_treatment: None,
            over: None,
            within_group: Vec::new(),
            parameters: ast::FunctionArguments::None,
            uses_odbc_syntax: false,
        };

        Expr::Cast {
            expr: Box::new(Expr::Function(function)),
            data_type: ast::DataType::Text,
            format: None,
            kind: ast::CastKind::Cast,
        }
    }

    #[test]
    fn test_interval_parts_parse() {
        let parts = parse("0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS");

        assert_parts(&parts, [0, 0, 1, 0, 0, 0], 0);
    }

    #[test]
    fn test_interval_parts_parse_with_nanos() {
        let parts = parse("0 YEARS 0 MONS 0 DAYS 0 HOURS 0 MINS 0.000000001 SECS");

        assert_parts(&parts, [0, 0, 0, 0, 0, 0], 1);
    }

    #[test]
    fn test_interval_parts_parse_negative() {
        let parts = parse("0 YEARS 0 MONS -1 DAYS 0 HOURS 0 MINS 0.000000000 SECS");

        assert_parts(&parts, [0, 0, -1, 0, 0, 0], 0);
    }

    #[test]
    fn test_interval_parts_parse_intraday() {
        let parts = parse("0 YEARS 0 MONS 0 DAYS 1 HOURS 1 MINS 1.000000001 SECS");

        assert_parts(&parts, [0, 0, 0, 1, 1, 1], 1);
        assert!(parts.intraday());
    }

    #[test]
    fn test_interval_parts_parse_interday() {
        let parts = parse("0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS");

        assert_parts(&parts, [0, 0, 1, 0, 0, 0], 0);
        assert!(!parts.intraday());
    }

    #[test]
    fn test_create_date_function() {
        let target = Expr::value(ast::Value::SingleQuotedString("1995-01-01".to_string()));
        let interval = IntervalParts::new()
            .with_years(1)
            .with_months(2)
            .with_days(3)
            .with_hours(0)
            .with_minutes(0)
            .with_seconds(0)
            .with_nanos(0);

        let datetime_function =
            SQLiteIntervalVisitor::create_datetime_function(&target, &interval);

        let expected = expected_call(
            "date",
            vec![
                text_arg("1995-01-01"),
                text_arg("+1 years"),
                text_arg("+2 months"),
                text_arg("+3 days"),
            ],
        );

        assert_eq!(datetime_function, expected);
    }

    #[test]
    fn test_create_datetime_function() {
        let target = Expr::value(ast::Value::SingleQuotedString("1995-01-01".to_string()));
        let interval = IntervalParts::new()
            .with_years(0)
            .with_months(0)
            .with_days(0)
            .with_hours(1)
            .with_minutes(2)
            .with_seconds(3)
            .with_nanos(0);

        let datetime_function =
            SQLiteIntervalVisitor::create_datetime_function(&target, &interval);

        let expected = expected_call(
            "datetime",
            vec![
                text_arg("1995-01-01"),
                text_arg("+1 hours"),
                text_arg("+2 minutes"),
                text_arg("+3 seconds"),
            ],
        );

        assert_eq!(datetime_function, expected);
    }
}