datafusion_table_providers/sqlite/
sqlite_interval.rs

1use datafusion::error::DataFusionError;
2use datafusion::sql::sqlparser::ast::{
3    self, BinaryOperator, Expr, FunctionArg, FunctionArgExpr, FunctionArgumentList, Ident,
4    ObjectNamePart, VisitorMut,
5};
6use std::fmt::Display;
7use std::ops::ControlFlow;
8use std::str::FromStr;
9
/// AST visitor that rewrites `<expr> +/- INTERVAL '...'` binary expressions
/// into calls to SQLite's `date` / `datetime` functions.
///
/// The visitor carries no state; all of the work happens in its
/// `VisitorMut` implementation and associated helper functions.
#[derive(Default)]
pub struct SQLiteIntervalVisitor {}
12
/// Per-unit breakdown of one interval.
///
/// Every field except `nanos` is signed. `Default` produces an interval in
/// which all components are zero.
#[derive(Debug, Default)]
struct IntervalParts {
    /// Number of years.
    years: i64,
    /// Number of months.
    months: i64,
    /// Number of days.
    days: i64,
    /// Number of hours.
    hours: i64,
    /// Number of minutes.
    minutes: i64,
    /// Whole seconds.
    seconds: i64,
    /// Fractional digits of the seconds component, in nanoseconds. Unsigned,
    /// so it has no sign of its own.
    nanos: u32,
}
23
/// Which SQLite function an interval rewrite should call.
///
/// The `Display` output is the function name itself.
enum SQLiteIntervalType {
    /// Rendered as `date`.
    Date,
    /// Rendered as `datetime`.
    Datetime,
}

impl Display for SQLiteIntervalType {
    /// Writes the lowercase SQLite function name for this variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let function_name = match self {
            Self::Date => "date",
            Self::Datetime => "datetime",
        };
        f.write_str(function_name)
    }
}
37
38type IntervalSetter = fn(IntervalParts, i64) -> IntervalParts;
39
40impl IntervalParts {
41    fn new() -> Self {
42        Self::default()
43    }
44
45    fn intraday(&self) -> bool {
46        self.hours > 0 || self.minutes > 0 || self.seconds > 0 || self.nanos > 0
47    }
48
49    fn negate(mut self) -> Self {
50        self.years = -self.years;
51        self.months = -self.months;
52        self.days = -self.days;
53        self.hours = -self.hours;
54        self.minutes = -self.minutes;
55        self.seconds = -self.seconds;
56        self
57    }
58
59    fn with_years(mut self, years: i64) -> Self {
60        self.years = years;
61        self
62    }
63
64    fn with_months(mut self, months: i64) -> Self {
65        self.months = months;
66        self
67    }
68
69    fn with_days(mut self, days: i64) -> Self {
70        self.days = days;
71        self
72    }
73
74    fn with_hours(mut self, hours: i64) -> Self {
75        self.hours = hours;
76        self
77    }
78
79    fn with_minutes(mut self, minutes: i64) -> Self {
80        self.minutes = minutes;
81        self
82    }
83
84    fn with_seconds(mut self, seconds: i64) -> Self {
85        self.seconds = seconds;
86        self
87    }
88
89    fn with_nanos(mut self, nanos: u32) -> Self {
90        self.nanos = nanos;
91        self
92    }
93}
94
95impl VisitorMut for SQLiteIntervalVisitor {
96    type Break = ();
97
98    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
99        // for each INTERVAL, find the previous (or next, if the INTERVAL is first) expression or column name that is associated with it
100        // e.g. `column_name + INTERVAL '1' DAY``, we should find the `column_name`
101        // then replace the `INTERVAL` with e.g. `datetime(column_name, '+1 day')`
102        // this should also apply to expressions though, like `CAST(column_name AS TEXT) + INTERVAL '1' DAY`
103        // in this example, it would be replaced with `datetime(CAST(column_name AS TEXT), '+1 day')`
104
105        // TODO: figure out nested BinaryOp, e.g. `column_name + INTERVAL '1' DAY + INTERVAL '1' DAY`
106        if let Expr::BinaryOp { op, left, right } = expr {
107            if *op != BinaryOperator::Plus && *op != BinaryOperator::Minus {
108                return ControlFlow::Continue(());
109            }
110
111            let (target, interval) = SQLiteIntervalVisitor::normalize_interval_expr(left, right);
112
113            if let Expr::Interval(_) = interval.as_ref() {
114                // parse the INTERVAL and get the bits out of it
115                // e.g. INTERVAL 0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINUTES 0.000000000 SECS -> IntervalParts { days: 1 }
116                if let Ok(interval_parts) = SQLiteIntervalVisitor::parse_interval(interval) {
117                    // negate the interval parts if the operator is minus
118                    let interval_parts = if *op == BinaryOperator::Minus {
119                        interval_parts.negate()
120                    } else {
121                        interval_parts
122                    };
123
124                    *expr =
125                        SQLiteIntervalVisitor::create_datetime_function(target, &interval_parts);
126                }
127            }
128        }
129        ControlFlow::Continue(())
130    }
131}
132
133impl SQLiteIntervalVisitor {
134    // normalize the sides of the operation to make sure the INTERVAL is always on the right
135    fn normalize_interval_expr<'a>(
136        left: &'a mut Box<Expr>,
137        right: &'a mut Box<Expr>,
138    ) -> (&'a mut Box<Expr>, &'a mut Box<Expr>) {
139        if let Expr::Interval { .. } = left.as_ref() {
140            (right, left)
141        } else {
142            (left, right)
143        }
144    }
145
146    fn parse_interval(interval: &Expr) -> Result<IntervalParts, DataFusionError> {
147        if let Expr::Interval(interval_expr) = interval {
148            if let Expr::Value(ast::ValueWithSpan {
149                value: ast::Value::SingleQuotedString(value),
150                span: _,
151            }) = interval_expr.value.as_ref()
152            {
153                return SQLiteIntervalVisitor::parse_interval_string(value);
154            }
155        }
156        Err(DataFusionError::Plan(
157            "Invalid interval expression".to_string(),
158        ))
159    }
160
161    fn parse_interval_string(value: &str) -> Result<IntervalParts, DataFusionError> {
162        let mut parts = IntervalParts::new();
163        let mut remaining = value;
164
165        let components: [(_, IntervalSetter); 5] = [
166            ("YEARS", IntervalParts::with_years),
167            ("MONS", IntervalParts::with_months),
168            ("DAYS", IntervalParts::with_days),
169            ("HOURS", IntervalParts::with_hours),
170            ("MINS", IntervalParts::with_minutes),
171        ];
172
173        for (unit, setter) in &components {
174            if let Some((value, rest)) = remaining.split_once(unit) {
175                let parsed_value: i64 = SQLiteIntervalVisitor::parse_value(value.trim())?;
176                parts = setter(parts, parsed_value);
177                remaining = rest;
178            }
179        }
180
181        // Parse seconds and nanoseconds separately
182        if let Some((secs, _)) = remaining.split_once("SECS") {
183            let (seconds, nanos) = SQLiteIntervalVisitor::parse_seconds_and_nanos(secs.trim())?;
184            parts = parts.with_seconds(seconds).with_nanos(nanos);
185        }
186
187        Ok(parts)
188    }
189
190    fn parse_seconds_and_nanos(value: &str) -> Result<(i64, u32), DataFusionError> {
191        let parts: Vec<&str> = value.split('.').collect();
192        let seconds = SQLiteIntervalVisitor::parse_value(parts[0])?;
193        let nanos = if parts.len() > 1 {
194            let nanos_str = format!("{:0<9}", parts[1]);
195            nanos_str[..9].parse().map_err(|_| {
196                DataFusionError::Plan(format!("Failed to parse nanoseconds: {}", parts[1]))
197            })?
198        } else {
199            0
200        };
201        Ok((seconds, nanos))
202    }
203
204    fn parse_value<T: FromStr>(value: &str) -> Result<T, DataFusionError> {
205        value
206            .parse()
207            .map_err(|_| DataFusionError::Plan(format!("Failed to parse interval value: {value}")))
208    }
209
210    fn create_datetime_function(target: &Expr, interval: &IntervalParts) -> Expr {
211        let interval_date_type = if interval.intraday() {
212            SQLiteIntervalType::Datetime
213        } else {
214            SQLiteIntervalType::Date
215        };
216
217        let function_args = vec![
218            Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(target.clone()))),
219            SQLiteIntervalVisitor::create_interval_arg("years", interval.years),
220            SQLiteIntervalVisitor::create_interval_arg("months", interval.months),
221            SQLiteIntervalVisitor::create_interval_arg("days", interval.days),
222            SQLiteIntervalVisitor::create_interval_arg("hours", interval.hours),
223            SQLiteIntervalVisitor::create_interval_arg("minutes", interval.minutes),
224            SQLiteIntervalVisitor::create_interval_arg_with_fraction(
225                "seconds",
226                interval.seconds,
227                interval.nanos,
228            ),
229        ]
230        .into_iter()
231        .flatten() // flatten the list of arguments to exclude 0 values
232        .collect();
233
234        let datetime_function = Expr::Function(ast::Function {
235            name: ast::ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
236                interval_date_type.to_string(),
237            ))]),
238            args: ast::FunctionArguments::List(FunctionArgumentList {
239                duplicate_treatment: None,
240                args: function_args,
241                clauses: Vec::new(),
242            }),
243            filter: None,
244            null_treatment: None,
245            over: None,
246            within_group: Vec::new(),
247            parameters: ast::FunctionArguments::None,
248            uses_odbc_syntax: false,
249        });
250
251        Expr::Cast {
252            expr: Box::new(datetime_function),
253            data_type: ast::DataType::Text,
254            format: None,
255            kind: ast::CastKind::Cast,
256        }
257    }
258
259    fn create_interval_arg(unit: &str, value: i64) -> Option<FunctionArg> {
260        if value == 0 {
261            None
262        } else {
263            Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::value(
264                ast::Value::SingleQuotedString(format!("{value:+} {unit}")),
265            ))))
266        }
267    }
268
269    fn create_interval_arg_with_fraction(
270        unit: &str,
271        value: i64,
272        fraction: u32,
273    ) -> Option<FunctionArg> {
274        if value == 0 && fraction == 0 {
275            None
276        } else {
277            let fraction_str = if fraction > 0 {
278                format!(".{fraction:09}")
279            } else {
280                String::new()
281            };
282
283            Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::value(
284                ast::Value::SingleQuotedString(format!("{value:+}{fraction_str} {unit}")),
285            ))))
286        }
287    }
288}
289
#[cfg(test)]
mod test {
    use super::*;

    /// Literal used as the target expression in the function-building tests.
    const TARGET: &str = "1995-01-01";

    /// `(years, months, days, hours, minutes, seconds, nanos)`.
    type Fields = (i64, i64, i64, i64, i64, i64, u32);

    /// Parses `input`, failing the test if it is rejected.
    fn parse(input: &str) -> IntervalParts {
        SQLiteIntervalVisitor::parse_interval_string(input)
            .expect("interval parts should be parsed")
    }

    /// Flattens the parts into a tuple so all fields compare in one assertion.
    fn fields(parts: &IntervalParts) -> Fields {
        (
            parts.years,
            parts.months,
            parts.days,
            parts.hours,
            parts.minutes,
            parts.seconds,
            parts.nanos,
        )
    }

    /// Wraps `text` as an unnamed single-quoted string argument.
    fn quoted_arg(text: &str) -> FunctionArg {
        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::value(
            ast::Value::SingleQuotedString(text.to_string()),
        )))
    }

    /// The expression expected for `function(TARGET, modifiers...)` cast to text.
    fn expected_call(function: &str, modifiers: &[&str]) -> Expr {
        let mut args = vec![quoted_arg(TARGET)];
        args.extend(modifiers.iter().map(|modifier| quoted_arg(modifier)));

        Expr::Cast {
            expr: Box::new(Expr::Function(ast::Function {
                name: ast::ObjectName(vec![ObjectNamePart::Identifier(Ident::new(function))]),
                args: ast::FunctionArguments::List(FunctionArgumentList {
                    duplicate_treatment: None,
                    args,
                    clauses: Vec::new(),
                }),
                filter: None,
                null_treatment: None,
                over: None,
                within_group: Vec::new(),
                parameters: ast::FunctionArguments::None,
                uses_odbc_syntax: false,
            })),
            data_type: ast::DataType::Text,
            format: None,
            kind: ast::CastKind::Cast,
        }
    }

    /// Runs `create_datetime_function` against the shared target literal.
    fn build(interval: &IntervalParts) -> Expr {
        let target = Expr::value(ast::Value::SingleQuotedString(TARGET.to_string()));
        SQLiteIntervalVisitor::create_datetime_function(&target, interval)
    }

    #[test]
    fn test_interval_parts_parse() {
        let parts = parse("0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS");

        assert_eq!(fields(&parts), (0, 0, 1, 0, 0, 0, 0));
    }

    #[test]
    fn test_interval_parts_parse_with_nanos() {
        let parts = parse("0 YEARS 0 MONS 0 DAYS 0 HOURS 0 MINS 0.000000001 SECS");

        assert_eq!(fields(&parts), (0, 0, 0, 0, 0, 0, 1));
    }

    #[test]
    fn test_interval_parts_parse_negative() {
        let parts = parse("0 YEARS 0 MONS -1 DAYS 0 HOURS 0 MINS 0.000000000 SECS");

        assert_eq!(fields(&parts), (0, 0, -1, 0, 0, 0, 0));
    }

    #[test]
    fn test_interval_parts_parse_intraday() {
        let parts = parse("0 YEARS 0 MONS 0 DAYS 1 HOURS 1 MINS 1.000000001 SECS");

        assert_eq!(fields(&parts), (0, 0, 0, 1, 1, 1, 1));
        assert!(parts.intraday());
    }

    #[test]
    fn test_interval_parts_parse_interday() {
        let parts = parse("0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS");

        assert_eq!(fields(&parts), (0, 0, 1, 0, 0, 0, 0));
        assert!(!parts.intraday());
    }

    #[test]
    fn test_create_date_function() {
        let interval = IntervalParts::new()
            .with_years(1)
            .with_months(2)
            .with_days(3)
            .with_hours(0)
            .with_minutes(0)
            .with_seconds(0)
            .with_nanos(0);

        assert_eq!(
            build(&interval),
            expected_call("date", &["+1 years", "+2 months", "+3 days"])
        );
    }

    #[test]
    fn test_create_datetime_function() {
        let interval = IntervalParts::new()
            .with_years(0)
            .with_months(0)
            .with_days(0)
            .with_hours(1)
            .with_minutes(2)
            .with_seconds(3)
            .with_nanos(0);

        assert_eq!(
            build(&interval),
            expected_call("datetime", &["+1 hours", "+2 minutes", "+3 seconds"])
        );
    }
}