Skip to main content

laminar_sql/parser/
window_rewriter.rs

1//! Production implementation for window function extraction and rewriting
2//!
3//! This module handles:
4//! - TUMBLE(time_col, interval) - tumbling windows
5//! - HOP(time_col, slide, size) / SLIDE(...) - sliding/hopping windows
6//! - SESSION(time_col, gap) - session windows
7
8use sqlparser::ast::{
9    Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Query, Select, SelectItem,
10    SetExpr, Statement,
11};
12
13use super::{ParseError, WindowFunction};
14
15/// Rewrites window functions in SQL queries
16pub struct WindowRewriter;
17
18impl WindowRewriter {
19    /// Rewrite a SQL statement to expand window functions.
20    ///
21    /// Transforms window functions like TUMBLE into appropriate table functions
22    /// and adds `window_start`/`window_end` columns.
23    ///
24    /// # Errors
25    ///
26    /// Returns `ParseError::WindowError` if a window function cannot be rewritten.
27    ///
28    /// # Example
29    ///
30    /// ```sql
31    /// -- Input:
32    /// SELECT COUNT(*) FROM events
33    /// GROUP BY TUMBLE(event_time, INTERVAL '5' MINUTE)
34    ///
35    /// -- Output:
36    /// SELECT window_start, window_end, COUNT(*)
37    /// FROM events
38    /// GROUP BY window_start, window_end
39    /// ```
40    pub fn rewrite_statement(stmt: &mut Statement) -> Result<(), ParseError> {
41        if let Statement::Query(query) = stmt {
42            Self::rewrite_query(query)?;
43        }
44        Ok(())
45    }
46
47    /// Rewrite a query
48    fn rewrite_query(query: &mut Query) -> Result<(), ParseError> {
49        if let SetExpr::Select(select) = &mut *query.body {
50            Self::rewrite_select(select)?;
51        }
52        Ok(())
53    }
54
55    /// Rewrite a SELECT statement to expand window functions.
56    ///
57    /// This processes GROUP BY to find window functions and adds
58    /// window_start/window_end to the projection.
59    fn rewrite_select(select: &mut Select) -> Result<(), ParseError> {
60        // Find window function in GROUP BY
61        let window_func = Self::find_window_in_group_by(select)?;
62
63        if let Some(_window) = window_func {
64            // Add window_start and window_end to projection if not already present
65            Self::ensure_window_columns_in_projection(select);
66        }
67
68        Ok(())
69    }
70
71    /// Find window function in GROUP BY clause.
72    fn find_window_in_group_by(select: &Select) -> Result<Option<WindowFunction>, ParseError> {
73        // Check the GROUP BY expressions
74        match &select.group_by {
75            sqlparser::ast::GroupByExpr::Expressions(exprs, _modifiers) => {
76                for expr in exprs {
77                    if let Some(window) = Self::extract_window_function(expr)? {
78                        return Ok(Some(window));
79                    }
80                }
81            }
82            sqlparser::ast::GroupByExpr::All(_) => {}
83        }
84        Ok(None)
85    }
86
87    /// Ensure window_start and window_end columns are in projection.
88    fn ensure_window_columns_in_projection(select: &mut Select) {
89        let has_window_start = Self::has_projection_column(select, "window_start");
90        let has_window_end = Self::has_projection_column(select, "window_end");
91
92        // Add window_start at the beginning if not present
93        if !has_window_start {
94            select.projection.insert(
95                0,
96                SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("window_start"))),
97            );
98        }
99
100        // Add window_end after window_start if not present
101        if !has_window_end {
102            select.projection.insert(
103                1,
104                SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("window_end"))),
105            );
106        }
107    }
108
109    /// Check if a named column exists in the SELECT projection.
110    fn has_projection_column(select: &Select, name: &str) -> bool {
111        select.projection.iter().any(|item| {
112            if let SelectItem::UnnamedExpr(Expr::Identifier(ident)) = item {
113                ident.value.eq_ignore_ascii_case(name)
114            } else if let SelectItem::ExprWithAlias { alias, .. } = item {
115                alias.value.eq_ignore_ascii_case(name)
116            } else {
117                false
118            }
119        })
120    }
121
122    /// Check if expression contains a window function.
123    #[must_use]
124    pub fn contains_window_function(expr: &Expr) -> bool {
125        match expr {
126            Expr::Function(func) => {
127                if let Some(name) = func.name.0.last() {
128                    let func_name = name.to_string().to_uppercase();
129                    matches!(func_name.as_str(), "TUMBLE" | "HOP" | "SLIDE" | "SESSION")
130                } else {
131                    false
132                }
133            }
134            _ => false,
135        }
136    }
137
138    /// Extract window function details from expression.
139    ///
140    /// Parses the actual arguments from TUMBLE/HOP/SESSION function calls.
141    ///
142    /// # Supported syntax
143    ///
144    /// - `TUMBLE(time_column, interval)` - 2 arguments
145    /// - `HOP(time_column, slide_interval, window_size)` - 3 arguments
146    /// - `SLIDE(time_column, slide_interval, window_size)` - alias for HOP
147    /// - `SESSION(time_column, gap_interval)` - 2 arguments
148    ///
149    /// # Errors
150    ///
151    /// Returns `ParseError::WindowError` if:
152    /// - Function has empty name
153    /// - Wrong number of arguments for window type
154    /// - Arguments cannot be extracted
155    pub fn extract_window_function(expr: &Expr) -> Result<Option<WindowFunction>, ParseError> {
156        match expr {
157            Expr::Function(func) => {
158                let name =
159                    func.name.0.last().ok_or_else(|| {
160                        ParseError::WindowError("Empty function name".to_string())
161                    })?;
162
163                let func_name = name.to_string().to_uppercase();
164
165                // Extract arguments from the function
166                let args = Self::extract_function_args(&func.args)?;
167
168                match func_name.as_str() {
169                    "TUMBLE" => Self::parse_tumble_args(&args),
170                    "HOP" | "SLIDE" => Self::parse_hop_args(&args),
171                    "SESSION" => Self::parse_session_args(&args),
172                    _ => Ok(None),
173                }
174            }
175            _ => Ok(None),
176        }
177    }
178
179    /// Extract function arguments as a vector of expressions.
180    fn extract_function_args(args: &FunctionArguments) -> Result<Vec<Expr>, ParseError> {
181        match args {
182            FunctionArguments::List(arg_list) => {
183                let mut result = Vec::new();
184                for arg in &arg_list.args {
185                    if let Some(expr) = Self::extract_arg_expr(arg) {
186                        result.push(expr);
187                    }
188                }
189                Ok(result)
190            }
191            FunctionArguments::None => Ok(vec![]),
192            FunctionArguments::Subquery(_) => Err(ParseError::WindowError(
193                "Subquery arguments not supported for window functions".to_string(),
194            )),
195        }
196    }
197
198    /// Extract expression from a function argument.
199    fn extract_arg_expr(arg: &FunctionArg) -> Option<Expr> {
200        match arg {
201            FunctionArg::Unnamed(arg_expr) => match arg_expr {
202                FunctionArgExpr::Expr(expr) => Some(expr.clone()),
203                FunctionArgExpr::Wildcard | FunctionArgExpr::QualifiedWildcard(_) => None,
204            },
205            FunctionArg::Named { arg, .. } | FunctionArg::ExprNamed { arg, .. } => match arg {
206                FunctionArgExpr::Expr(expr) => Some(expr.clone()),
207                FunctionArgExpr::Wildcard | FunctionArgExpr::QualifiedWildcard(_) => None,
208            },
209        }
210    }
211
212    /// Parse TUMBLE(time_column, interval) arguments.
213    fn parse_tumble_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
214        if args.len() != 2 {
215            return Err(ParseError::WindowError(format!(
216                "TUMBLE requires 2 arguments (time_column, interval), got {}",
217                args.len()
218            )));
219        }
220
221        Ok(Some(WindowFunction::Tumble {
222            time_column: Box::new(args[0].clone()),
223            interval: Box::new(args[1].clone()),
224        }))
225    }
226
227    /// Parse HOP/SLIDE(time_column, slide_interval, window_size) arguments.
228    fn parse_hop_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
229        if args.len() != 3 {
230            return Err(ParseError::WindowError(format!(
231                "HOP/SLIDE requires 3 arguments (time_column, slide_interval, window_size), got {}",
232                args.len()
233            )));
234        }
235
236        Ok(Some(WindowFunction::Hop {
237            time_column: Box::new(args[0].clone()),
238            slide_interval: Box::new(args[1].clone()),
239            window_interval: Box::new(args[2].clone()),
240        }))
241    }
242
243    /// Parse SESSION(time_column, gap_interval) arguments.
244    fn parse_session_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
245        if args.len() != 2 {
246            return Err(ParseError::WindowError(format!(
247                "SESSION requires 2 arguments (time_column, gap_interval), got {}",
248                args.len()
249            )));
250        }
251
252        Ok(Some(WindowFunction::Session {
253            time_column: Box::new(args[0].clone()),
254            gap_interval: Box::new(args[1].clone()),
255        }))
256    }
257
258    /// Extract the time column name from a window function.
259    ///
260    /// Returns the column name as a string if extractable.
261    #[must_use]
262    pub fn get_time_column_name(window: &WindowFunction) -> Option<String> {
263        let expr = match window {
264            WindowFunction::Tumble { time_column, .. }
265            | WindowFunction::Hop { time_column, .. }
266            | WindowFunction::Session { time_column, .. } => time_column.as_ref(),
267        };
268
269        match expr {
270            Expr::Identifier(ident) => Some(ident.value.clone()),
271            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
272            _ => None,
273        }
274    }
275
276    /// Parse an INTERVAL expression to Duration.
277    ///
278    /// Supports: SECOND, MINUTE, HOUR, DAY
279    ///
280    /// # Errors
281    ///
282    /// Returns `ParseError::WindowError` if the expression is not a valid interval.
283    pub fn parse_interval_to_duration(expr: &Expr) -> Result<std::time::Duration, ParseError> {
284        match expr {
285            Expr::Interval(interval) => {
286                // Extract the value
287                let value = Self::extract_interval_value(&interval.value)?;
288
289                // Get the unit (defaults to SECOND)
290                let unit = interval
291                    .leading_field
292                    .clone()
293                    .unwrap_or(sqlparser::ast::DateTimeField::Second);
294
295                let seconds =
296                    match unit {
297                        sqlparser::ast::DateTimeField::Second
298                        | sqlparser::ast::DateTimeField::Seconds => value,
299                        sqlparser::ast::DateTimeField::Minute
300                        | sqlparser::ast::DateTimeField::Minutes => value * 60,
301                        sqlparser::ast::DateTimeField::Hour
302                        | sqlparser::ast::DateTimeField::Hours => value * 3600,
303                        sqlparser::ast::DateTimeField::Day
304                        | sqlparser::ast::DateTimeField::Days => value * 86400,
305                        _ => {
306                            return Err(ParseError::WindowError(format!(
307                                "Unsupported interval unit: {unit:?}"
308                            )))
309                        }
310                    };
311
312                Ok(std::time::Duration::from_secs(seconds))
313            }
314            // Handle string literal intervals like '5 MINUTES'
315            Expr::Value(value_with_span) => {
316                use sqlparser::ast::Value;
317                if let Value::SingleQuotedString(s) = &value_with_span.value {
318                    Self::parse_interval_string(s)
319                } else {
320                    Err(ParseError::WindowError(format!(
321                        "Expected string value, got: {value_with_span:?}"
322                    )))
323                }
324            }
325            // Handle identifier that might be an interval string
326            Expr::Identifier(ident) => Self::parse_interval_string(&ident.value),
327            _ => Err(ParseError::WindowError(format!(
328                "Expected INTERVAL expression, got: {expr:?}"
329            ))),
330        }
331    }
332
333    /// Extract numeric value from interval expression.
334    fn extract_interval_value(expr: &Expr) -> Result<u64, ParseError> {
335        match expr {
336            Expr::Value(value_with_span) => {
337                use sqlparser::ast::Value;
338                match &value_with_span.value {
339                    Value::Number(n, _) => n.parse::<u64>().map_err(|_| {
340                        ParseError::WindowError(format!("Invalid interval value: {n}"))
341                    }),
342                    Value::SingleQuotedString(s) => {
343                        // Handle '5' or '5 MINUTE'
344                        let num_str = s.split_whitespace().next().unwrap_or(s);
345                        num_str.parse::<u64>().map_err(|_| {
346                            ParseError::WindowError(format!("Invalid interval value: {s}"))
347                        })
348                    }
349                    _ => Err(ParseError::WindowError(format!(
350                        "Unsupported value type in interval: {value_with_span:?}"
351                    ))),
352                }
353            }
354            _ => Err(ParseError::WindowError(format!(
355                "Cannot extract interval value from: {expr:?}"
356            ))),
357        }
358    }
359
360    /// Parse an interval string like "5 MINUTES" or "1 HOUR".
361    fn parse_interval_string(s: &str) -> Result<std::time::Duration, ParseError> {
362        let parts: Vec<&str> = s.split_whitespace().collect();
363        if parts.is_empty() {
364            return Err(ParseError::WindowError("Empty interval string".to_string()));
365        }
366
367        let value: u64 = parts[0].parse().map_err(|_| {
368            ParseError::WindowError(format!("Invalid interval value: {}", parts[0]))
369        })?;
370
371        let unit = if parts.len() > 1 {
372            parts[1].to_uppercase()
373        } else {
374            "SECOND".to_string()
375        };
376
377        let seconds = match unit.as_str() {
378            "SECOND" | "SECONDS" | "S" => value,
379            "MINUTE" | "MINUTES" | "M" => value * 60,
380            "HOUR" | "HOURS" | "H" => value * 3600,
381            "DAY" | "DAYS" | "D" => value * 86400,
382            _ => {
383                return Err(ParseError::WindowError(format!(
384                    "Unsupported interval unit: {unit}"
385                )))
386            }
387        };
388
389        Ok(std::time::Duration::from_secs(seconds))
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    // The helpers below panic when the parsed SQL does not have the expected
    // shape, so a test can never pass without reaching its assertions.

    /// Parse `sql`, which must contain exactly one statement.
    fn parse_statement(sql: &str) -> Statement {
        let mut statements =
            Parser::parse_sql(&GenericDialect {}, sql).expect("test SQL should parse");
        assert_eq!(statements.len(), 1, "expected exactly one statement in: {sql}");
        statements.remove(0)
    }

    /// Unwrap a statement into its plain `SELECT` body.
    fn into_select(stmt: Statement) -> Select {
        let query = match stmt {
            Statement::Query(query) => *query,
            other => panic!("expected a query statement, got: {other:?}"),
        };
        let body = *query.body;
        match body {
            SetExpr::Select(select) => *select,
            other => panic!("expected a plain SELECT body, got: {other:?}"),
        }
    }

    /// Parse `sql` and return its plain `SELECT` body.
    fn parse_select(sql: &str) -> Select {
        into_select(parse_statement(sql))
    }

    /// Parse `sql` and return the first projection item as an expression.
    fn first_projection_expr(sql: &str) -> Expr {
        match parse_select(sql).projection.into_iter().next() {
            Some(SelectItem::UnnamedExpr(expr)) => expr,
            other => panic!("expected an unnamed expression as first column, got: {other:?}"),
        }
    }

    /// Parse `sql` and extract the window function from its first column.
    fn window_from_projection(sql: &str) -> WindowFunction {
        WindowRewriter::extract_window_function(&first_projection_expr(sql))
            .expect("window function should be valid")
            .expect("first column should be a window function")
    }

    /// Render the projection of `select` as strings for easy comparison.
    fn projection_strings(select: &Select) -> Vec<String> {
        select.projection.iter().map(ToString::to_string).collect()
    }

    #[test]
    fn test_contains_window_function() {
        let expr =
            first_projection_expr("SELECT TUMBLE(event_time, INTERVAL '5' MINUTE) FROM events");
        assert!(WindowRewriter::contains_window_function(&expr));
    }

    #[test]
    fn test_rewrite_statement() {
        let mut stmt = parse_statement("SELECT COUNT(*) FROM events GROUP BY event_time");

        // Should not fail on standard SQL
        assert!(WindowRewriter::rewrite_statement(&mut stmt).is_ok());

        // Without a window function the projection must stay untouched.
        assert_eq!(into_select(stmt).projection.len(), 1);
    }

    #[test]
    fn test_rewrite_adds_window_columns() {
        let mut stmt = parse_statement(
            "SELECT COUNT(*) FROM events GROUP BY TUMBLE(event_time, INTERVAL '5' MINUTE)",
        );
        WindowRewriter::rewrite_statement(&mut stmt).unwrap();

        let columns = projection_strings(&into_select(stmt));
        assert_eq!(columns.len(), 3);
        assert_eq!(columns[0], "window_start");
        assert_eq!(columns[1], "window_end");
    }

    #[test]
    fn test_rewrite_does_not_duplicate_window_columns() {
        let mut stmt = parse_statement(
            "SELECT window_start, window_end, COUNT(*) FROM events \
             GROUP BY TUMBLE(event_time, INTERVAL '5' MINUTE)",
        );
        WindowRewriter::rewrite_statement(&mut stmt).unwrap();

        let columns = projection_strings(&into_select(stmt));
        assert_eq!(columns.len(), 3);
        assert_eq!(columns[0], "window_start");
        assert_eq!(columns[1], "window_end");
    }

    #[test]
    fn test_extract_tumble_with_actual_args() {
        let window =
            window_from_projection("SELECT TUMBLE(order_time, INTERVAL '10' MINUTE) FROM orders");

        match window {
            WindowFunction::Tumble {
                time_column,
                interval,
            } => {
                // Verify time column is extracted correctly
                assert_eq!(time_column.to_string(), "order_time");

                // Verify interval is extracted
                assert!(interval.to_string().contains("10"));
            }
            _ => panic!("Expected Tumble window"),
        }
    }

    #[test]
    fn test_extract_hop_with_actual_args() {
        let window = window_from_projection(
            "SELECT HOP(ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM readings",
        );

        match window {
            WindowFunction::Hop {
                time_column,
                slide_interval,
                window_interval,
            } => {
                assert_eq!(time_column.to_string(), "ts");
                assert!(slide_interval.to_string().contains('1'));
                assert!(window_interval.to_string().contains('5'));
            }
            _ => panic!("Expected Hop window"),
        }
    }

    #[test]
    fn test_extract_session_with_actual_args() {
        let window =
            window_from_projection("SELECT SESSION(click_time, INTERVAL '30' MINUTE) FROM clicks");

        match window {
            WindowFunction::Session {
                time_column,
                gap_interval,
            } => {
                assert_eq!(time_column.to_string(), "click_time");
                assert!(gap_interval.to_string().contains("30"));
            }
            _ => panic!("Expected Session window"),
        }
    }

    #[test]
    fn test_tumble_wrong_args_count() {
        let expr = first_projection_expr("SELECT TUMBLE(ts) FROM events");

        let err = WindowRewriter::extract_window_function(&expr)
            .expect_err("TUMBLE with one argument must be rejected");
        assert!(err.to_string().contains("2 arguments"));
    }

    #[test]
    fn test_hop_wrong_args_count() {
        let expr = first_projection_expr("SELECT HOP(ts, INTERVAL '1' MINUTE) FROM events");

        let err = WindowRewriter::extract_window_function(&expr)
            .expect_err("HOP with two arguments must be rejected");
        assert!(err.to_string().contains("3 arguments"));
    }

    #[test]
    fn test_slide_alias_for_hop() {
        let window = window_from_projection(
            "SELECT SLIDE(ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM events",
        );

        // SLIDE should be parsed as Hop
        assert!(matches!(window, WindowFunction::Hop { .. }));
    }

    #[test]
    fn test_get_time_column_name() {
        let window =
            window_from_projection("SELECT TUMBLE(my_timestamp, INTERVAL '5' MINUTE) FROM events");

        let col_name = WindowRewriter::get_time_column_name(&window);
        assert_eq!(col_name, Some("my_timestamp".to_string()));
    }

    #[test]
    fn test_parse_interval_to_duration() {
        // Test parsing from GROUP BY
        let select =
            parse_select("SELECT COUNT(*) FROM events GROUP BY TUMBLE(ts, INTERVAL '5' MINUTE)");

        let window = WindowRewriter::find_window_in_group_by(&select)
            .unwrap()
            .expect("GROUP BY should contain a window function");

        match window {
            WindowFunction::Tumble { interval, .. } => {
                let duration = WindowRewriter::parse_interval_to_duration(&interval).unwrap();
                assert_eq!(duration, std::time::Duration::from_secs(300));
            }
            _ => panic!("Expected Tumble window"),
        }
    }

    #[test]
    fn test_parse_interval_string_formats() {
        // Test various interval string formats
        let cases = [
            ("5 MINUTE", 300),
            ("5 MINUTES", 300),
            ("1 HOUR", 3600),
            ("2 HOURS", 7200),
            ("10 SECOND", 10),
            ("1 DAY", 86400),
        ];

        for (input, expected_secs) in cases {
            let result = WindowRewriter::parse_interval_string(input).unwrap();
            assert_eq!(
                result,
                std::time::Duration::from_secs(expected_secs),
                "Failed for input: {input}"
            );
        }
    }

    #[test]
    fn test_window_in_group_by() {
        let select = parse_select(
            "SELECT user_id, COUNT(*) FROM events \
             GROUP BY TUMBLE(event_time, INTERVAL '1' HOUR), user_id",
        );

        let window = WindowRewriter::find_window_in_group_by(&select)
            .unwrap()
            .expect("GROUP BY should contain a window function");

        match window {
            WindowFunction::Tumble { time_column, .. } => {
                assert_eq!(time_column.to_string(), "event_time");
            }
            _ => panic!("Expected Tumble window"),
        }
    }
}
657}