Skip to main content

laminar_sql/parser/
window_rewriter.rs

//! Production implementation for window function extraction and rewriting.
//!
//! This module handles:
//! - TUMBLE(time_col, interval [, offset]) - tumbling windows
//! - HOP(time_col, slide, size [, offset]) / SLIDE(...) - sliding/hopping windows
//! - SESSION(time_col, gap) - session windows
//! - CUMULATE(time_col, step, max_size) - cumulating windows
7
8use sqlparser::ast::{
9    Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Query, Select, SelectItem,
10    SetExpr, Statement,
11};
12
13use super::{ParseError, WindowFunction};
14
15/// Rewrites window functions in SQL queries
16pub struct WindowRewriter;
17
18impl WindowRewriter {
19    /// Rewrite a SQL statement to expand window functions.
20    ///
21    /// Transforms window functions like TUMBLE into appropriate table functions
22    /// and adds `window_start`/`window_end` columns.
23    ///
24    /// # Errors
25    ///
26    /// Returns `ParseError::WindowError` if a window function cannot be rewritten.
27    ///
28    /// # Example
29    ///
30    /// ```sql
31    /// -- Input:
32    /// SELECT COUNT(*) FROM events
33    /// GROUP BY TUMBLE(event_time, INTERVAL '5' MINUTE)
34    ///
35    /// -- Output:
36    /// SELECT window_start, window_end, COUNT(*)
37    /// FROM events
38    /// GROUP BY window_start, window_end
39    /// ```
40    pub fn rewrite_statement(stmt: &mut Statement) -> Result<(), ParseError> {
41        if let Statement::Query(query) = stmt {
42            Self::rewrite_query(query)?;
43        }
44        Ok(())
45    }
46
47    /// Rewrite a query
48    fn rewrite_query(query: &mut Query) -> Result<(), ParseError> {
49        if let SetExpr::Select(select) = &mut *query.body {
50            Self::rewrite_select(select)?;
51        }
52        Ok(())
53    }
54
55    /// Rewrite a SELECT statement to expand window functions.
56    ///
57    /// This processes GROUP BY to find window functions and adds
58    /// window_start/window_end to the projection.
59    fn rewrite_select(select: &mut Select) -> Result<(), ParseError> {
60        // Find window function in GROUP BY
61        let window_func = Self::find_window_in_group_by(select)?;
62
63        if let Some(_window) = window_func {
64            // Add window_start and window_end to projection if not already present
65            Self::ensure_window_columns_in_projection(select);
66        }
67
68        Ok(())
69    }
70
71    /// Find window function in GROUP BY clause.
72    fn find_window_in_group_by(select: &Select) -> Result<Option<WindowFunction>, ParseError> {
73        // Check the GROUP BY expressions
74        match &select.group_by {
75            sqlparser::ast::GroupByExpr::Expressions(exprs, _modifiers) => {
76                for expr in exprs {
77                    if let Some(window) = Self::extract_window_function(expr)? {
78                        return Ok(Some(window));
79                    }
80                }
81            }
82            sqlparser::ast::GroupByExpr::All(_) => {}
83        }
84        Ok(None)
85    }
86
87    /// Ensure window_start and window_end columns are in projection.
88    fn ensure_window_columns_in_projection(select: &mut Select) {
89        let has_window_start = Self::has_projection_column(select, "window_start");
90        let has_window_end = Self::has_projection_column(select, "window_end");
91
92        // Add window_start at the beginning if not present
93        if !has_window_start {
94            select.projection.insert(
95                0,
96                SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("window_start"))),
97            );
98        }
99
100        // Add window_end after window_start if not present
101        if !has_window_end {
102            select.projection.insert(
103                1,
104                SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("window_end"))),
105            );
106        }
107    }
108
109    /// Check if a named column exists in the SELECT projection.
110    fn has_projection_column(select: &Select, name: &str) -> bool {
111        select.projection.iter().any(|item| {
112            if let SelectItem::UnnamedExpr(Expr::Identifier(ident)) = item {
113                ident.value.eq_ignore_ascii_case(name)
114            } else if let SelectItem::ExprWithAlias { alias, .. } = item {
115                alias.value.eq_ignore_ascii_case(name)
116            } else {
117                false
118            }
119        })
120    }
121
122    /// Check if expression contains a window function.
123    #[must_use]
124    pub fn contains_window_function(expr: &Expr) -> bool {
125        match expr {
126            Expr::Function(func) => {
127                if let Some(name) = func.name.0.last() {
128                    let func_name = name.to_string().to_uppercase();
129                    matches!(
130                        func_name.as_str(),
131                        "TUMBLE" | "HOP" | "SLIDE" | "SESSION" | "CUMULATE"
132                    )
133                } else {
134                    false
135                }
136            }
137            _ => false,
138        }
139    }
140
141    /// Extract window function details from expression.
142    ///
143    /// Parses the actual arguments from TUMBLE/HOP/SESSION function calls.
144    ///
145    /// # Supported syntax
146    ///
147    /// - `TUMBLE(time_column, interval)` - 2 arguments
148    /// - `HOP(time_column, slide_interval, window_size)` - 3 arguments
149    /// - `SLIDE(time_column, slide_interval, window_size)` - alias for HOP
150    /// - `SESSION(time_column, gap_interval)` - 2 arguments
151    ///
152    /// # Errors
153    ///
154    /// Returns `ParseError::WindowError` if:
155    /// - Function has empty name
156    /// - Wrong number of arguments for window type
157    /// - Arguments cannot be extracted
158    pub fn extract_window_function(expr: &Expr) -> Result<Option<WindowFunction>, ParseError> {
159        match expr {
160            Expr::Function(func) => {
161                let name =
162                    func.name.0.last().ok_or_else(|| {
163                        ParseError::WindowError("Empty function name".to_string())
164                    })?;
165
166                let func_name = name.to_string().to_uppercase();
167
168                // Extract arguments from the function
169                let args = Self::extract_function_args(&func.args)?;
170
171                match func_name.as_str() {
172                    "TUMBLE" => Self::parse_tumble_args(&args),
173                    "HOP" | "SLIDE" => Self::parse_hop_args(&args),
174                    "SESSION" => Self::parse_session_args(&args),
175                    "CUMULATE" => Self::parse_cumulate_args(&args),
176                    _ => Ok(None),
177                }
178            }
179            _ => Ok(None),
180        }
181    }
182
183    /// Extract function arguments as a vector of expressions.
184    fn extract_function_args(args: &FunctionArguments) -> Result<Vec<Expr>, ParseError> {
185        match args {
186            FunctionArguments::List(arg_list) => {
187                let mut result = Vec::new();
188                for arg in &arg_list.args {
189                    if let Some(expr) = Self::extract_arg_expr(arg) {
190                        result.push(expr);
191                    }
192                }
193                Ok(result)
194            }
195            FunctionArguments::None => Ok(vec![]),
196            FunctionArguments::Subquery(_) => Err(ParseError::WindowError(
197                "Subquery arguments not supported for window functions".to_string(),
198            )),
199        }
200    }
201
202    /// Extract expression from a function argument.
203    fn extract_arg_expr(arg: &FunctionArg) -> Option<Expr> {
204        match arg {
205            FunctionArg::Unnamed(arg_expr) => match arg_expr {
206                FunctionArgExpr::Expr(expr) => Some(expr.clone()),
207                FunctionArgExpr::Wildcard | FunctionArgExpr::QualifiedWildcard(_) => None,
208            },
209            FunctionArg::Named { arg, .. } | FunctionArg::ExprNamed { arg, .. } => match arg {
210                FunctionArgExpr::Expr(expr) => Some(expr.clone()),
211                FunctionArgExpr::Wildcard | FunctionArgExpr::QualifiedWildcard(_) => None,
212            },
213        }
214    }
215
216    /// Parse TUMBLE(time_column, interval [, offset]) arguments.
217    fn parse_tumble_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
218        if args.len() < 2 || args.len() > 3 {
219            return Err(ParseError::WindowError(format!(
220                "TUMBLE requires 2-3 arguments (time_column, interval [, offset]), got {}",
221                args.len()
222            )));
223        }
224
225        Ok(Some(WindowFunction::Tumble {
226            time_column: Box::new(args[0].clone()),
227            interval: Box::new(args[1].clone()),
228            offset: args.get(2).map(|e| Box::new(e.clone())),
229        }))
230    }
231
232    /// Parse HOP/SLIDE(time_column, slide_interval, window_size [, offset]) arguments.
233    fn parse_hop_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
234        if args.len() < 3 || args.len() > 4 {
235            return Err(ParseError::WindowError(format!(
236                "HOP/SLIDE requires 3-4 arguments (time_column, slide_interval, window_size [, offset]), got {}",
237                args.len()
238            )));
239        }
240
241        Ok(Some(WindowFunction::Hop {
242            time_column: Box::new(args[0].clone()),
243            slide_interval: Box::new(args[1].clone()),
244            window_interval: Box::new(args[2].clone()),
245            offset: args.get(3).map(|e| Box::new(e.clone())),
246        }))
247    }
248
249    /// Parse SESSION(time_column, gap_interval) arguments.
250    fn parse_session_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
251        if args.len() != 2 {
252            return Err(ParseError::WindowError(format!(
253                "SESSION requires 2 arguments (time_column, gap_interval), got {}",
254                args.len()
255            )));
256        }
257
258        Ok(Some(WindowFunction::Session {
259            time_column: Box::new(args[0].clone()),
260            gap_interval: Box::new(args[1].clone()),
261        }))
262    }
263
264    /// Parse CUMULATE(time_column, step_interval, max_size_interval) arguments.
265    fn parse_cumulate_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
266        if args.len() != 3 {
267            return Err(ParseError::WindowError(format!(
268                "CUMULATE requires 3 arguments (time_column, step_interval, max_size_interval), got {}",
269                args.len()
270            )));
271        }
272
273        Ok(Some(WindowFunction::Cumulate {
274            time_column: Box::new(args[0].clone()),
275            step_interval: Box::new(args[1].clone()),
276            max_size_interval: Box::new(args[2].clone()),
277        }))
278    }
279
280    /// Extract the time column name from a window function.
281    ///
282    /// Returns the column name as a string if extractable.
283    #[must_use]
284    pub fn get_time_column_name(window: &WindowFunction) -> Option<String> {
285        let expr = match window {
286            WindowFunction::Tumble { time_column, .. }
287            | WindowFunction::Hop { time_column, .. }
288            | WindowFunction::Session { time_column, .. }
289            | WindowFunction::Cumulate { time_column, .. } => time_column.as_ref(),
290        };
291
292        match expr {
293            Expr::Identifier(ident) => Some(ident.value.clone()),
294            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
295            _ => None,
296        }
297    }
298
299    /// Parse an INTERVAL expression to Duration.
300    ///
301    /// Supports: SECOND, MINUTE, HOUR, DAY
302    ///
303    /// # Errors
304    ///
305    /// Returns `ParseError::WindowError` if the expression is not a valid interval.
306    pub fn parse_interval_to_duration(expr: &Expr) -> Result<std::time::Duration, ParseError> {
307        match expr {
308            Expr::Interval(interval) => {
309                // Extract the value
310                let value = Self::extract_interval_value(&interval.value)?;
311
312                // Get the unit (defaults to SECOND)
313                let unit = interval
314                    .leading_field
315                    .clone()
316                    .unwrap_or(sqlparser::ast::DateTimeField::Second);
317
318                match unit {
319                    sqlparser::ast::DateTimeField::Millisecond
320                    | sqlparser::ast::DateTimeField::Milliseconds => {
321                        return Ok(std::time::Duration::from_millis(value));
322                    }
323                    _ => {}
324                }
325
326                let seconds =
327                    match unit {
328                        sqlparser::ast::DateTimeField::Second
329                        | sqlparser::ast::DateTimeField::Seconds => value,
330                        sqlparser::ast::DateTimeField::Minute
331                        | sqlparser::ast::DateTimeField::Minutes => value * 60,
332                        sqlparser::ast::DateTimeField::Hour
333                        | sqlparser::ast::DateTimeField::Hours => value * 3600,
334                        sqlparser::ast::DateTimeField::Day
335                        | sqlparser::ast::DateTimeField::Days => value * 86400,
336                        _ => {
337                            return Err(ParseError::WindowError(format!(
338                                "Unsupported interval unit: {unit:?}"
339                            )))
340                        }
341                    };
342
343                Ok(std::time::Duration::from_secs(seconds))
344            }
345            // Handle string literal intervals like '5 MINUTES'
346            Expr::Value(value_with_span) => {
347                use sqlparser::ast::Value;
348                if let Value::SingleQuotedString(s) = &value_with_span.value {
349                    Self::parse_interval_string(s)
350                } else {
351                    Err(ParseError::WindowError(format!(
352                        "Expected string value, got: {value_with_span:?}"
353                    )))
354                }
355            }
356            // Handle identifier that might be an interval string
357            Expr::Identifier(ident) => Self::parse_interval_string(&ident.value),
358            _ => Err(ParseError::WindowError(format!(
359                "Expected INTERVAL expression, got: {expr:?}"
360            ))),
361        }
362    }
363
364    /// Extract numeric value from interval expression.
365    fn extract_interval_value(expr: &Expr) -> Result<u64, ParseError> {
366        match expr {
367            Expr::Value(value_with_span) => {
368                use sqlparser::ast::Value;
369                match &value_with_span.value {
370                    Value::Number(n, _) => n.parse::<u64>().map_err(|_| {
371                        ParseError::WindowError(format!("Invalid interval value: {n}"))
372                    }),
373                    Value::SingleQuotedString(s) => {
374                        // Handle '5' or '5 MINUTE'
375                        let num_str = s.split_whitespace().next().unwrap_or(s);
376                        num_str.parse::<u64>().map_err(|_| {
377                            ParseError::WindowError(format!("Invalid interval value: {s}"))
378                        })
379                    }
380                    _ => Err(ParseError::WindowError(format!(
381                        "Unsupported value type in interval: {value_with_span:?}"
382                    ))),
383                }
384            }
385            _ => Err(ParseError::WindowError(format!(
386                "Cannot extract interval value from: {expr:?}"
387            ))),
388        }
389    }
390
391    /// Parse an interval string like "5 MINUTES" or "1 HOUR".
392    fn parse_interval_string(s: &str) -> Result<std::time::Duration, ParseError> {
393        let parts: Vec<&str> = s.split_whitespace().collect();
394        if parts.is_empty() {
395            return Err(ParseError::WindowError("Empty interval string".to_string()));
396        }
397
398        let value: u64 = parts[0].parse().map_err(|_| {
399            ParseError::WindowError(format!("Invalid interval value: {}", parts[0]))
400        })?;
401
402        let unit = if parts.len() > 1 {
403            parts[1].to_uppercase()
404        } else {
405            "SECOND".to_string()
406        };
407
408        if matches!(unit.as_str(), "MILLISECOND" | "MILLISECONDS" | "MS") {
409            return Ok(std::time::Duration::from_millis(value));
410        }
411
412        let seconds = match unit.as_str() {
413            "SECOND" | "SECONDS" | "S" => value,
414            "MINUTE" | "MINUTES" | "M" => value * 60,
415            "HOUR" | "HOURS" | "H" => value * 3600,
416            "DAY" | "DAYS" | "D" => value * 86400,
417            _ => {
418                return Err(ParseError::WindowError(format!(
419                    "Unsupported interval unit: {unit}"
420                )))
421            }
422        };
423
424        Ok(std::time::Duration::from_secs(seconds))
425    }
426}
427
428#[cfg(test)]
429mod tests {
430    use super::*;
431    use sqlparser::dialect::GenericDialect;
432    use sqlparser::parser::Parser;
433
434    #[test]
435    fn test_contains_window_function() {
436        let sql = "SELECT TUMBLE(event_time, INTERVAL '5' MINUTE) FROM events";
437        let dialect = GenericDialect {};
438        let statements = Parser::parse_sql(&dialect, sql).unwrap();
439
440        if let Statement::Query(query) = &statements[0] {
441            if let SetExpr::Select(select) = &*query.body {
442                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
443                    assert!(WindowRewriter::contains_window_function(expr));
444                }
445            }
446        }
447    }
448
449    #[test]
450    fn test_rewrite_statement() {
451        let sql = "SELECT COUNT(*) FROM events GROUP BY event_time";
452        let dialect = GenericDialect {};
453        let mut statements = Parser::parse_sql(&dialect, sql).unwrap();
454
455        // Should not fail on standard SQL
456        assert!(WindowRewriter::rewrite_statement(&mut statements[0]).is_ok());
457    }
458
459    #[test]
460    fn test_extract_tumble_with_actual_args() {
461        let sql = "SELECT TUMBLE(order_time, INTERVAL '10' MINUTE) FROM orders";
462        let dialect = GenericDialect {};
463        let statements = Parser::parse_sql(&dialect, sql).unwrap();
464
465        if let Statement::Query(query) = &statements[0] {
466            if let SetExpr::Select(select) = &*query.body {
467                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
468                    let window = WindowRewriter::extract_window_function(expr)
469                        .unwrap()
470                        .unwrap();
471
472                    match window {
473                        WindowFunction::Tumble {
474                            time_column,
475                            interval,
476                            ..
477                        } => {
478                            // Verify time column is extracted correctly
479                            assert_eq!(time_column.to_string(), "order_time");
480
481                            // Verify interval is extracted
482                            assert!(interval.to_string().contains("10"));
483                        }
484                        _ => panic!("Expected Tumble window"),
485                    }
486                }
487            }
488        }
489    }
490
491    #[test]
492    fn test_extract_hop_with_actual_args() {
493        let sql = "SELECT HOP(ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM readings";
494        let dialect = GenericDialect {};
495        let statements = Parser::parse_sql(&dialect, sql).unwrap();
496
497        if let Statement::Query(query) = &statements[0] {
498            if let SetExpr::Select(select) = &*query.body {
499                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
500                    let window = WindowRewriter::extract_window_function(expr)
501                        .unwrap()
502                        .unwrap();
503
504                    match window {
505                        WindowFunction::Hop {
506                            time_column,
507                            slide_interval,
508                            window_interval,
509                            ..
510                        } => {
511                            assert_eq!(time_column.to_string(), "ts");
512                            assert!(slide_interval.to_string().contains('1'));
513                            assert!(window_interval.to_string().contains('5'));
514                        }
515                        _ => panic!("Expected Hop window"),
516                    }
517                }
518            }
519        }
520    }
521
522    #[test]
523    fn test_extract_session_with_actual_args() {
524        let sql = "SELECT SESSION(click_time, INTERVAL '30' MINUTE) FROM clicks";
525        let dialect = GenericDialect {};
526        let statements = Parser::parse_sql(&dialect, sql).unwrap();
527
528        if let Statement::Query(query) = &statements[0] {
529            if let SetExpr::Select(select) = &*query.body {
530                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
531                    let window = WindowRewriter::extract_window_function(expr)
532                        .unwrap()
533                        .unwrap();
534
535                    match window {
536                        WindowFunction::Session {
537                            time_column,
538                            gap_interval,
539                        } => {
540                            assert_eq!(time_column.to_string(), "click_time");
541                            assert!(gap_interval.to_string().contains("30"));
542                        }
543                        _ => panic!("Expected Session window"),
544                    }
545                }
546            }
547        }
548    }
549
550    #[test]
551    fn test_tumble_wrong_args_count() {
552        let sql = "SELECT TUMBLE(ts) FROM events";
553        let dialect = GenericDialect {};
554        let statements = Parser::parse_sql(&dialect, sql).unwrap();
555
556        if let Statement::Query(query) = &statements[0] {
557            if let SetExpr::Select(select) = &*query.body {
558                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
559                    let result = WindowRewriter::extract_window_function(expr);
560                    assert!(result.is_err());
561                    let err = result.unwrap_err();
562                    assert!(err.to_string().contains("2-3 arguments"));
563                }
564            }
565        }
566    }
567
568    #[test]
569    fn test_hop_wrong_args_count() {
570        let sql = "SELECT HOP(ts, INTERVAL '1' MINUTE) FROM events";
571        let dialect = GenericDialect {};
572        let statements = Parser::parse_sql(&dialect, sql).unwrap();
573
574        if let Statement::Query(query) = &statements[0] {
575            if let SetExpr::Select(select) = &*query.body {
576                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
577                    let result = WindowRewriter::extract_window_function(expr);
578                    assert!(result.is_err());
579                    let err = result.unwrap_err();
580                    assert!(err.to_string().contains("3-4 arguments"));
581                }
582            }
583        }
584    }
585
586    #[test]
587    fn test_slide_alias_for_hop() {
588        let sql = "SELECT SLIDE(ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM events";
589        let dialect = GenericDialect {};
590        let statements = Parser::parse_sql(&dialect, sql).unwrap();
591
592        if let Statement::Query(query) = &statements[0] {
593            if let SetExpr::Select(select) = &*query.body {
594                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
595                    let window = WindowRewriter::extract_window_function(expr)
596                        .unwrap()
597                        .unwrap();
598
599                    // SLIDE should be parsed as Hop
600                    assert!(matches!(window, WindowFunction::Hop { .. }));
601                }
602            }
603        }
604    }
605
606    #[test]
607    fn test_get_time_column_name() {
608        let sql = "SELECT TUMBLE(my_timestamp, INTERVAL '5' MINUTE) FROM events";
609        let dialect = GenericDialect {};
610        let statements = Parser::parse_sql(&dialect, sql).unwrap();
611
612        if let Statement::Query(query) = &statements[0] {
613            if let SetExpr::Select(select) = &*query.body {
614                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
615                    let window = WindowRewriter::extract_window_function(expr)
616                        .unwrap()
617                        .unwrap();
618
619                    let col_name = WindowRewriter::get_time_column_name(&window);
620                    assert_eq!(col_name, Some("my_timestamp".to_string()));
621                }
622            }
623        }
624    }
625
626    #[test]
627    fn test_parse_interval_to_duration() {
628        // Test parsing from GROUP BY
629        let sql = "SELECT COUNT(*) FROM events GROUP BY TUMBLE(ts, INTERVAL '5' MINUTE)";
630        let dialect = GenericDialect {};
631        let statements = Parser::parse_sql(&dialect, sql).unwrap();
632
633        if let Statement::Query(query) = &statements[0] {
634            if let SetExpr::Select(select) = &*query.body {
635                if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &select.group_by {
636                    if let Some(expr) = exprs.first() {
637                        let window = WindowRewriter::extract_window_function(expr)
638                            .unwrap()
639                            .unwrap();
640
641                        if let WindowFunction::Tumble { interval, .. } = window {
642                            let duration =
643                                WindowRewriter::parse_interval_to_duration(&interval).unwrap();
644                            assert_eq!(duration, std::time::Duration::from_secs(300));
645                        }
646                    }
647                }
648            }
649        }
650    }
651
652    #[test]
653    fn test_parse_interval_string_formats() {
654        // Test various interval string formats
655        let cases = [
656            ("5 MINUTE", 300),
657            ("5 MINUTES", 300),
658            ("1 HOUR", 3600),
659            ("2 HOURS", 7200),
660            ("10 SECOND", 10),
661            ("1 DAY", 86400),
662        ];
663
664        for (input, expected_secs) in cases {
665            let result = WindowRewriter::parse_interval_string(input).unwrap();
666            assert_eq!(
667                result,
668                std::time::Duration::from_secs(expected_secs),
669                "Failed for input: {input}"
670            );
671        }
672    }
673
674    #[test]
675    fn test_window_in_group_by() {
676        let sql = "SELECT user_id, COUNT(*) FROM events GROUP BY TUMBLE(event_time, INTERVAL '1' HOUR), user_id";
677        let dialect = GenericDialect {};
678        let statements = Parser::parse_sql(&dialect, sql).unwrap();
679
680        if let Statement::Query(query) = &statements[0] {
681            if let SetExpr::Select(select) = &*query.body {
682                let window = WindowRewriter::find_window_in_group_by(select)
683                    .unwrap()
684                    .unwrap();
685
686                assert!(matches!(window, WindowFunction::Tumble { .. }));
687
688                if let WindowFunction::Tumble { time_column, .. } = window {
689                    assert_eq!(time_column.to_string(), "event_time");
690                }
691            }
692        }
693    }
694
695    #[test]
696    fn test_contains_cumulate_window_function() {
697        let sql = "SELECT CUMULATE(ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM events";
698        let dialect = GenericDialect {};
699        let statements = Parser::parse_sql(&dialect, sql).unwrap();
700
701        if let Statement::Query(query) = &statements[0] {
702            if let SetExpr::Select(select) = &*query.body {
703                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
704                    assert!(WindowRewriter::contains_window_function(expr));
705                }
706            }
707        }
708    }
709
710    #[test]
711    fn test_extract_cumulate_with_actual_args() {
712        let sql =
713            "SELECT CUMULATE(order_time, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM orders";
714        let dialect = GenericDialect {};
715        let statements = Parser::parse_sql(&dialect, sql).unwrap();
716
717        if let Statement::Query(query) = &statements[0] {
718            if let SetExpr::Select(select) = &*query.body {
719                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
720                    let window = WindowRewriter::extract_window_function(expr)
721                        .unwrap()
722                        .unwrap();
723
724                    match window {
725                        WindowFunction::Cumulate {
726                            time_column,
727                            step_interval,
728                            max_size_interval,
729                        } => {
730                            assert_eq!(time_column.to_string(), "order_time");
731                            assert!(step_interval.to_string().contains('1'));
732                            assert!(max_size_interval.to_string().contains('5'));
733                        }
734                        _ => panic!("Expected Cumulate window"),
735                    }
736                }
737            }
738        }
739    }
740
741    #[test]
742    fn test_cumulate_wrong_args_count() {
743        let sql = "SELECT CUMULATE(ts, INTERVAL '1' MINUTE) FROM events";
744        let dialect = GenericDialect {};
745        let statements = Parser::parse_sql(&dialect, sql).unwrap();
746
747        if let Statement::Query(query) = &statements[0] {
748            if let SetExpr::Select(select) = &*query.body {
749                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
750                    let result = WindowRewriter::extract_window_function(expr);
751                    assert!(result.is_err());
752                    let err = result.unwrap_err();
753                    assert!(err.to_string().contains("3 arguments"));
754                }
755            }
756        }
757    }
758
759    #[test]
760    fn test_cumulate_time_column_name() {
761        let sql = "SELECT CUMULATE(my_ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM events";
762        let dialect = GenericDialect {};
763        let statements = Parser::parse_sql(&dialect, sql).unwrap();
764
765        if let Statement::Query(query) = &statements[0] {
766            if let SetExpr::Select(select) = &*query.body {
767                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
768                    let window = WindowRewriter::extract_window_function(expr)
769                        .unwrap()
770                        .unwrap();
771
772                    let col_name = WindowRewriter::get_time_column_name(&window);
773                    assert_eq!(col_name, Some("my_ts".to_string()));
774                }
775            }
776        }
777    }
778
779    #[test]
780    fn test_millisecond_interval() {
781        // parse_interval_to_duration should handle MILLISECOND unit
782        let sql = "SELECT TUMBLE(ts, INTERVAL '500' MILLISECOND) FROM events";
783        let dialect = GenericDialect {};
784        let statements = Parser::parse_sql(&dialect, sql).unwrap();
785
786        if let Statement::Query(query) = &statements[0] {
787            if let SetExpr::Select(select) = &*query.body {
788                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
789                    let window = WindowRewriter::extract_window_function(expr)
790                        .unwrap()
791                        .unwrap();
792
793                    match window {
794                        WindowFunction::Tumble {
795                            time_column: _,
796                            interval,
797                            ..
798                        } => {
799                            let duration =
800                                WindowRewriter::parse_interval_to_duration(&interval).unwrap();
801                            assert_eq!(
802                                duration,
803                                std::time::Duration::from_millis(500),
804                                "INTERVAL '500' MILLISECOND should parse to 500ms"
805                            );
806                        }
807                        _ => panic!("Expected Tumble window"),
808                    }
809                }
810            }
811        }
812    }
813
814    #[test]
815    fn test_millisecond_interval_string() {
816        // parse_interval_string should handle MS unit
817        let duration = WindowRewriter::parse_interval_string("250 MS").unwrap();
818        assert_eq!(duration, std::time::Duration::from_millis(250));
819
820        let duration2 = WindowRewriter::parse_interval_string("100 MILLISECONDS").unwrap();
821        assert_eq!(duration2, std::time::Duration::from_millis(100));
822
823        let duration3 = WindowRewriter::parse_interval_string("750 MILLISECOND").unwrap();
824        assert_eq!(duration3, std::time::Duration::from_millis(750));
825    }
826
827    #[test]
828    fn test_parse_tumble_with_offset() {
829        let sql = "SELECT TUMBLE(ts, INTERVAL '1' HOUR, INTERVAL '30' MINUTE) FROM events";
830        let dialect = GenericDialect {};
831        let statements = Parser::parse_sql(&dialect, sql).unwrap();
832
833        if let Statement::Query(query) = &statements[0] {
834            if let SetExpr::Select(select) = &*query.body {
835                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
836                    let window = WindowRewriter::extract_window_function(expr)
837                        .unwrap()
838                        .unwrap();
839
840                    match window {
841                        WindowFunction::Tumble {
842                            interval, offset, ..
843                        } => {
844                            let dur =
845                                WindowRewriter::parse_interval_to_duration(&interval).unwrap();
846                            assert_eq!(dur, std::time::Duration::from_secs(3600));
847                            assert!(offset.is_some(), "Expected offset to be set");
848                            let off_dur = WindowRewriter::parse_interval_to_duration(
849                                offset.as_ref().unwrap(),
850                            )
851                            .unwrap();
852                            assert_eq!(off_dur, std::time::Duration::from_secs(1800));
853                        }
854                        _ => panic!("Expected Tumble window"),
855                    }
856                }
857            }
858        }
859    }
860
861    #[test]
862    fn test_parse_hop_with_offset() {
863        let sql = "SELECT HOP(ts, INTERVAL '5' MINUTE, INTERVAL '15' MINUTE, INTERVAL '2' MINUTE) FROM events";
864        let dialect = GenericDialect {};
865        let statements = Parser::parse_sql(&dialect, sql).unwrap();
866
867        if let Statement::Query(query) = &statements[0] {
868            if let SetExpr::Select(select) = &*query.body {
869                if let SelectItem::UnnamedExpr(expr) = &select.projection[0] {
870                    let window = WindowRewriter::extract_window_function(expr)
871                        .unwrap()
872                        .unwrap();
873
874                    match window {
875                        WindowFunction::Hop {
876                            slide_interval,
877                            window_interval,
878                            offset,
879                            ..
880                        } => {
881                            let slide = WindowRewriter::parse_interval_to_duration(&slide_interval)
882                                .unwrap();
883                            let size = WindowRewriter::parse_interval_to_duration(&window_interval)
884                                .unwrap();
885                            assert_eq!(slide, std::time::Duration::from_secs(300));
886                            assert_eq!(size, std::time::Duration::from_secs(900));
887                            assert!(offset.is_some(), "Expected offset to be set");
888                            let off_dur = WindowRewriter::parse_interval_to_duration(
889                                offset.as_ref().unwrap(),
890                            )
891                            .unwrap();
892                            assert_eq!(off_dur, std::time::Duration::from_secs(120));
893                        }
894                        _ => panic!("Expected Hop window"),
895                    }
896                }
897            }
898        }
899    }
900}