Skip to main content

laminar_sql/parser/
window_rewriter.rs

//! Extraction and rewriting of streaming window functions in SQL queries.
//!
//! This module recognizes the following window functions:
//! - `TUMBLE(time_col, interval)` - tumbling windows
//! - `HOP(time_col, slide, size)` / `SLIDE(...)` - sliding/hopping windows
//! - `SESSION(time_col, gap)` - session windows
//! - `CUMULATE(time_col, step, max_size)` - cumulating windows
7
8use sqlparser::ast::{
9    Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Query, Select, SelectItem,
10    SetExpr, Statement,
11};
12
13use super::{ParseError, WindowFunction};
14
/// Rewrites streaming window functions (`TUMBLE`, `HOP`/`SLIDE`, `SESSION`,
/// `CUMULATE`) found in SQL queries.
///
/// This is a stateless unit type: all functionality is exposed through its
/// associated functions (e.g. `WindowRewriter::rewrite_statement`).
#[derive(Debug, Clone, Copy, Default)]
pub struct WindowRewriter;
17
18impl WindowRewriter {
19    /// Rewrite a SQL statement to expand window functions.
20    ///
21    /// Transforms window functions like TUMBLE into appropriate table functions
22    /// and adds `window_start`/`window_end` columns.
23    ///
24    /// # Errors
25    ///
26    /// Returns `ParseError::WindowError` if a window function cannot be rewritten.
27    ///
28    /// # Example
29    ///
30    /// ```sql
31    /// -- Input:
32    /// SELECT COUNT(*) FROM events
33    /// GROUP BY TUMBLE(event_time, INTERVAL '5' MINUTE)
34    ///
35    /// -- Output:
36    /// SELECT window_start, window_end, COUNT(*)
37    /// FROM events
38    /// GROUP BY window_start, window_end
39    /// ```
40    pub fn rewrite_statement(stmt: &mut Statement) -> Result<(), ParseError> {
41        if let Statement::Query(query) = stmt {
42            Self::rewrite_query(query)?;
43        }
44        Ok(())
45    }
46
47    /// Rewrite a query
48    fn rewrite_query(query: &mut Query) -> Result<(), ParseError> {
49        if let SetExpr::Select(select) = &mut *query.body {
50            Self::rewrite_select(select)?;
51        }
52        Ok(())
53    }
54
55    /// Rewrite a SELECT statement to expand window functions.
56    ///
57    /// This processes GROUP BY to find window functions and adds
58    /// window_start/window_end to the projection.
59    fn rewrite_select(select: &mut Select) -> Result<(), ParseError> {
60        // Find window function in GROUP BY
61        let window_func = Self::find_window_in_group_by(select)?;
62
63        if let Some(_window) = window_func {
64            // Add window_start and window_end to projection if not already present
65            Self::ensure_window_columns_in_projection(select);
66        }
67
68        Ok(())
69    }
70
71    /// Find window function in GROUP BY clause.
72    fn find_window_in_group_by(select: &Select) -> Result<Option<WindowFunction>, ParseError> {
73        // Check the GROUP BY expressions
74        match &select.group_by {
75            sqlparser::ast::GroupByExpr::Expressions(exprs, _modifiers) => {
76                for expr in exprs {
77                    if let Some(window) = Self::extract_window_function(expr)? {
78                        return Ok(Some(window));
79                    }
80                }
81            }
82            sqlparser::ast::GroupByExpr::All(_) => {}
83        }
84        Ok(None)
85    }
86
87    /// Ensure window_start and window_end columns are in projection.
88    fn ensure_window_columns_in_projection(select: &mut Select) {
89        let has_window_start = Self::has_projection_column(select, "window_start");
90        let has_window_end = Self::has_projection_column(select, "window_end");
91
92        // Add window_start at the beginning if not present
93        if !has_window_start {
94            select.projection.insert(
95                0,
96                SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("window_start"))),
97            );
98        }
99
100        // Add window_end after window_start if not present
101        if !has_window_end {
102            select.projection.insert(
103                1,
104                SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("window_end"))),
105            );
106        }
107    }
108
109    /// Check if a named column exists in the SELECT projection.
110    fn has_projection_column(select: &Select, name: &str) -> bool {
111        select.projection.iter().any(|item| {
112            if let SelectItem::UnnamedExpr(Expr::Identifier(ident)) = item {
113                ident.value.eq_ignore_ascii_case(name)
114            } else if let SelectItem::ExprWithAlias { alias, .. } = item {
115                alias.value.eq_ignore_ascii_case(name)
116            } else {
117                false
118            }
119        })
120    }
121
122    /// Check if expression contains a window function.
123    #[must_use]
124    pub fn contains_window_function(expr: &Expr) -> bool {
125        match expr {
126            Expr::Function(func) => {
127                if let Some(name) = func.name.0.last() {
128                    let func_name = name.to_string().to_uppercase();
129                    matches!(
130                        func_name.as_str(),
131                        "TUMBLE" | "HOP" | "SLIDE" | "SESSION" | "CUMULATE"
132                    )
133                } else {
134                    false
135                }
136            }
137            _ => false,
138        }
139    }
140
141    /// Extract window function details from expression.
142    ///
143    /// Parses the actual arguments from TUMBLE/HOP/SESSION function calls.
144    ///
145    /// # Supported syntax
146    ///
147    /// - `TUMBLE(time_column, interval)` - 2 arguments
148    /// - `HOP(time_column, slide_interval, window_size)` - 3 arguments
149    /// - `SLIDE(time_column, slide_interval, window_size)` - alias for HOP
150    /// - `SESSION(time_column, gap_interval)` - 2 arguments
151    ///
152    /// # Errors
153    ///
154    /// Returns `ParseError::WindowError` if:
155    /// - Function has empty name
156    /// - Wrong number of arguments for window type
157    /// - Arguments cannot be extracted
158    pub fn extract_window_function(expr: &Expr) -> Result<Option<WindowFunction>, ParseError> {
159        match expr {
160            Expr::Function(func) => {
161                let name =
162                    func.name.0.last().ok_or_else(|| {
163                        ParseError::WindowError("Empty function name".to_string())
164                    })?;
165
166                let func_name = name.to_string().to_uppercase();
167
168                // Extract arguments from the function
169                let args = Self::extract_function_args(&func.args)?;
170
171                match func_name.as_str() {
172                    "TUMBLE" => Self::parse_tumble_args(&args),
173                    "HOP" | "SLIDE" => Self::parse_hop_args(&args),
174                    "SESSION" => Self::parse_session_args(&args),
175                    "CUMULATE" => Self::parse_cumulate_args(&args),
176                    _ => Ok(None),
177                }
178            }
179            _ => Ok(None),
180        }
181    }
182
183    /// Extract function arguments as a vector of expressions.
184    fn extract_function_args(args: &FunctionArguments) -> Result<Vec<Expr>, ParseError> {
185        match args {
186            FunctionArguments::List(arg_list) => {
187                let mut result = Vec::new();
188                for arg in &arg_list.args {
189                    if let Some(expr) = Self::extract_arg_expr(arg) {
190                        result.push(expr);
191                    }
192                }
193                Ok(result)
194            }
195            FunctionArguments::None => Ok(vec![]),
196            FunctionArguments::Subquery(_) => Err(ParseError::WindowError(
197                "Subquery arguments not supported for window functions".to_string(),
198            )),
199        }
200    }
201
202    /// Extract expression from a function argument.
203    fn extract_arg_expr(arg: &FunctionArg) -> Option<Expr> {
204        match arg {
205            FunctionArg::Unnamed(arg_expr) => match arg_expr {
206                FunctionArgExpr::Expr(expr) => Some(expr.clone()),
207                FunctionArgExpr::Wildcard | FunctionArgExpr::QualifiedWildcard(_) => None,
208            },
209            FunctionArg::Named { arg, .. } | FunctionArg::ExprNamed { arg, .. } => match arg {
210                FunctionArgExpr::Expr(expr) => Some(expr.clone()),
211                FunctionArgExpr::Wildcard | FunctionArgExpr::QualifiedWildcard(_) => None,
212            },
213        }
214    }
215
216    /// Parse TUMBLE(time_column, interval) arguments.
217    fn parse_tumble_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
218        if args.len() != 2 {
219            return Err(ParseError::WindowError(format!(
220                "TUMBLE requires 2 arguments (time_column, interval), got {}",
221                args.len()
222            )));
223        }
224
225        Ok(Some(WindowFunction::Tumble {
226            time_column: Box::new(args[0].clone()),
227            interval: Box::new(args[1].clone()),
228        }))
229    }
230
231    /// Parse HOP/SLIDE(time_column, slide_interval, window_size) arguments.
232    fn parse_hop_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
233        if args.len() != 3 {
234            return Err(ParseError::WindowError(format!(
235                "HOP/SLIDE requires 3 arguments (time_column, slide_interval, window_size), got {}",
236                args.len()
237            )));
238        }
239
240        Ok(Some(WindowFunction::Hop {
241            time_column: Box::new(args[0].clone()),
242            slide_interval: Box::new(args[1].clone()),
243            window_interval: Box::new(args[2].clone()),
244        }))
245    }
246
247    /// Parse SESSION(time_column, gap_interval) arguments.
248    fn parse_session_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
249        if args.len() != 2 {
250            return Err(ParseError::WindowError(format!(
251                "SESSION requires 2 arguments (time_column, gap_interval), got {}",
252                args.len()
253            )));
254        }
255
256        Ok(Some(WindowFunction::Session {
257            time_column: Box::new(args[0].clone()),
258            gap_interval: Box::new(args[1].clone()),
259        }))
260    }
261
262    /// Parse CUMULATE(time_column, step_interval, max_size_interval) arguments.
263    fn parse_cumulate_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
264        if args.len() != 3 {
265            return Err(ParseError::WindowError(format!(
266                "CUMULATE requires 3 arguments (time_column, step_interval, max_size_interval), got {}",
267                args.len()
268            )));
269        }
270
271        Ok(Some(WindowFunction::Cumulate {
272            time_column: Box::new(args[0].clone()),
273            step_interval: Box::new(args[1].clone()),
274            max_size_interval: Box::new(args[2].clone()),
275        }))
276    }
277
278    /// Extract the time column name from a window function.
279    ///
280    /// Returns the column name as a string if extractable.
281    #[must_use]
282    pub fn get_time_column_name(window: &WindowFunction) -> Option<String> {
283        let expr = match window {
284            WindowFunction::Tumble { time_column, .. }
285            | WindowFunction::Hop { time_column, .. }
286            | WindowFunction::Session { time_column, .. }
287            | WindowFunction::Cumulate { time_column, .. } => time_column.as_ref(),
288        };
289
290        match expr {
291            Expr::Identifier(ident) => Some(ident.value.clone()),
292            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
293            _ => None,
294        }
295    }
296
297    /// Parse an INTERVAL expression to Duration.
298    ///
299    /// Supports: SECOND, MINUTE, HOUR, DAY
300    ///
301    /// # Errors
302    ///
303    /// Returns `ParseError::WindowError` if the expression is not a valid interval.
304    pub fn parse_interval_to_duration(expr: &Expr) -> Result<std::time::Duration, ParseError> {
305        match expr {
306            Expr::Interval(interval) => {
307                // Extract the value
308                let value = Self::extract_interval_value(&interval.value)?;
309
310                // Get the unit (defaults to SECOND)
311                let unit = interval
312                    .leading_field
313                    .clone()
314                    .unwrap_or(sqlparser::ast::DateTimeField::Second);
315
316                match unit {
317                    sqlparser::ast::DateTimeField::Millisecond
318                    | sqlparser::ast::DateTimeField::Milliseconds => {
319                        return Ok(std::time::Duration::from_millis(value));
320                    }
321                    _ => {}
322                }
323
324                let seconds =
325                    match unit {
326                        sqlparser::ast::DateTimeField::Second
327                        | sqlparser::ast::DateTimeField::Seconds => value,
328                        sqlparser::ast::DateTimeField::Minute
329                        | sqlparser::ast::DateTimeField::Minutes => value * 60,
330                        sqlparser::ast::DateTimeField::Hour
331                        | sqlparser::ast::DateTimeField::Hours => value * 3600,
332                        sqlparser::ast::DateTimeField::Day
333                        | sqlparser::ast::DateTimeField::Days => value * 86400,
334                        _ => {
335                            return Err(ParseError::WindowError(format!(
336                                "Unsupported interval unit: {unit:?}"
337                            )))
338                        }
339                    };
340
341                Ok(std::time::Duration::from_secs(seconds))
342            }
343            // Handle string literal intervals like '5 MINUTES'
344            Expr::Value(value_with_span) => {
345                use sqlparser::ast::Value;
346                if let Value::SingleQuotedString(s) = &value_with_span.value {
347                    Self::parse_interval_string(s)
348                } else {
349                    Err(ParseError::WindowError(format!(
350                        "Expected string value, got: {value_with_span:?}"
351                    )))
352                }
353            }
354            // Handle identifier that might be an interval string
355            Expr::Identifier(ident) => Self::parse_interval_string(&ident.value),
356            _ => Err(ParseError::WindowError(format!(
357                "Expected INTERVAL expression, got: {expr:?}"
358            ))),
359        }
360    }
361
362    /// Extract numeric value from interval expression.
363    fn extract_interval_value(expr: &Expr) -> Result<u64, ParseError> {
364        match expr {
365            Expr::Value(value_with_span) => {
366                use sqlparser::ast::Value;
367                match &value_with_span.value {
368                    Value::Number(n, _) => n.parse::<u64>().map_err(|_| {
369                        ParseError::WindowError(format!("Invalid interval value: {n}"))
370                    }),
371                    Value::SingleQuotedString(s) => {
372                        // Handle '5' or '5 MINUTE'
373                        let num_str = s.split_whitespace().next().unwrap_or(s);
374                        num_str.parse::<u64>().map_err(|_| {
375                            ParseError::WindowError(format!("Invalid interval value: {s}"))
376                        })
377                    }
378                    _ => Err(ParseError::WindowError(format!(
379                        "Unsupported value type in interval: {value_with_span:?}"
380                    ))),
381                }
382            }
383            _ => Err(ParseError::WindowError(format!(
384                "Cannot extract interval value from: {expr:?}"
385            ))),
386        }
387    }
388
389    /// Parse an interval string like "5 MINUTES" or "1 HOUR".
390    fn parse_interval_string(s: &str) -> Result<std::time::Duration, ParseError> {
391        let parts: Vec<&str> = s.split_whitespace().collect();
392        if parts.is_empty() {
393            return Err(ParseError::WindowError("Empty interval string".to_string()));
394        }
395
396        let value: u64 = parts[0].parse().map_err(|_| {
397            ParseError::WindowError(format!("Invalid interval value: {}", parts[0]))
398        })?;
399
400        let unit = if parts.len() > 1 {
401            parts[1].to_uppercase()
402        } else {
403            "SECOND".to_string()
404        };
405
406        if matches!(unit.as_str(), "MILLISECOND" | "MILLISECONDS" | "MS") {
407            return Ok(std::time::Duration::from_millis(value));
408        }
409
410        let seconds = match unit.as_str() {
411            "SECOND" | "SECONDS" | "S" => value,
412            "MINUTE" | "MINUTES" | "M" => value * 60,
413            "HOUR" | "HOURS" | "H" => value * 3600,
414            "DAY" | "DAYS" | "D" => value * 86400,
415            _ => {
416                return Err(ParseError::WindowError(format!(
417                    "Unsupported interval unit: {unit}"
418                )))
419            }
420        };
421
422        Ok(std::time::Duration::from_secs(seconds))
423    }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;
    use std::time::Duration;

    /// Parse `sql` with the generic dialect, asserting it is exactly one statement.
    fn parse_single(sql: &str) -> Statement {
        let mut statements = Parser::parse_sql(&GenericDialect {}, sql)
            .unwrap_or_else(|e| panic!("failed to parse {sql:?}: {e}"));
        assert_eq!(statements.len(), 1, "expected one statement in {sql:?}");
        statements.remove(0)
    }

    /// Borrow the top-level SELECT of a query statement.
    ///
    /// Panics on any other shape, so a test can never pass vacuously because the
    /// parse tree looked different from what the test expected.
    fn select_of(stmt: &Statement) -> &Select {
        let Statement::Query(query) = stmt else {
            panic!("expected a query, got: {stmt}");
        };
        let SetExpr::Select(select) = &*query.body else {
            panic!("expected a plain SELECT body, got: {stmt}");
        };
        &**select
    }

    /// First SELECT-list item of `sql`, which must be an unaliased expression.
    fn first_projection_expr(sql: &str) -> Expr {
        let stmt = parse_single(sql);
        match select_of(&stmt).projection.first() {
            Some(SelectItem::UnnamedExpr(expr)) => expr.clone(),
            other => panic!("expected an unnamed expression first in {sql:?}, got: {other:?}"),
        }
    }

    /// First GROUP BY expression of `sql`.
    fn first_group_by_expr(sql: &str) -> Expr {
        let stmt = parse_single(sql);
        let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &select_of(&stmt).group_by
        else {
            panic!("expected explicit GROUP BY expressions in {sql:?}");
        };
        exprs.first().cloned().expect("GROUP BY should not be empty")
    }

    /// Window function extracted from the first SELECT-list item of `sql`.
    fn window_from_projection(sql: &str) -> WindowFunction {
        let expr = first_projection_expr(sql);
        WindowRewriter::extract_window_function(&expr)
            .expect("window arguments should be valid")
            .expect("expression should be a window function")
    }

    /// Error message produced when extracting the first SELECT-list item of `sql`.
    fn window_error_from_projection(sql: &str) -> String {
        let expr = first_projection_expr(sql);
        WindowRewriter::extract_window_function(&expr)
            .expect_err("malformed window call should be rejected")
            .to_string()
    }

    /// SELECT list of `stmt` rendered as SQL strings, in order.
    fn projection_strings(stmt: &Statement) -> Vec<String> {
        select_of(stmt)
            .projection
            .iter()
            .map(ToString::to_string)
            .collect()
    }

    #[test]
    fn test_contains_window_function() {
        let expr = first_projection_expr("SELECT TUMBLE(event_time, INTERVAL '5' MINUTE) FROM events");
        assert!(WindowRewriter::contains_window_function(&expr));
    }

    #[test]
    fn test_rewrite_statement() {
        let mut stmt = parse_single("SELECT COUNT(*) FROM events GROUP BY event_time");

        // Should not fail on standard SQL, and must leave it untouched.
        WindowRewriter::rewrite_statement(&mut stmt).expect("plain SQL should rewrite");
        assert_eq!(projection_strings(&stmt), ["COUNT(*)"]);
    }

    #[test]
    fn test_rewrite_adds_window_columns() {
        let mut stmt =
            parse_single("SELECT COUNT(*) FROM events GROUP BY TUMBLE(ts, INTERVAL '5' MINUTE)");

        WindowRewriter::rewrite_statement(&mut stmt).expect("window SQL should rewrite");
        assert_eq!(
            projection_strings(&stmt),
            ["window_start", "window_end", "COUNT(*)"]
        );
    }

    #[test]
    fn test_rewrite_keeps_existing_window_columns() {
        let mut stmt = parse_single(
            "SELECT window_start, window_end, COUNT(*) FROM events \
             GROUP BY TUMBLE(ts, INTERVAL '1' MINUTE)",
        );

        WindowRewriter::rewrite_statement(&mut stmt).expect("window SQL should rewrite");
        assert_eq!(
            projection_strings(&stmt),
            ["window_start", "window_end", "COUNT(*)"]
        );
    }

    #[test]
    fn test_non_window_function_is_ignored() {
        let expr = first_projection_expr("SELECT UPPER(name) FROM events");
        assert!(!WindowRewriter::contains_window_function(&expr));
        assert!(WindowRewriter::extract_window_function(&expr)
            .expect("non-window function should not error")
            .is_none());
    }

    #[test]
    fn test_extract_tumble_with_actual_args() {
        match window_from_projection("SELECT TUMBLE(order_time, INTERVAL '10' MINUTE) FROM orders") {
            WindowFunction::Tumble {
                time_column,
                interval,
            } => {
                // Verify time column is extracted correctly
                assert_eq!(time_column.to_string(), "order_time");
                // Verify interval is extracted
                assert!(interval.to_string().contains("10"));
            }
            other => panic!("Expected Tumble window, got {other:?}"),
        }
    }

    #[test]
    fn test_extract_hop_with_actual_args() {
        match window_from_projection(
            "SELECT HOP(ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM readings",
        ) {
            WindowFunction::Hop {
                time_column,
                slide_interval,
                window_interval,
            } => {
                assert_eq!(time_column.to_string(), "ts");
                assert!(slide_interval.to_string().contains('1'));
                assert!(window_interval.to_string().contains('5'));
            }
            other => panic!("Expected Hop window, got {other:?}"),
        }
    }

    #[test]
    fn test_extract_session_with_actual_args() {
        match window_from_projection("SELECT SESSION(click_time, INTERVAL '30' MINUTE) FROM clicks") {
            WindowFunction::Session {
                time_column,
                gap_interval,
            } => {
                assert_eq!(time_column.to_string(), "click_time");
                assert!(gap_interval.to_string().contains("30"));
            }
            other => panic!("Expected Session window, got {other:?}"),
        }
    }

    #[test]
    fn test_tumble_wrong_args_count() {
        let err = window_error_from_projection("SELECT TUMBLE(ts) FROM events");
        assert!(err.contains("2 arguments"), "unexpected error: {err}");
    }

    #[test]
    fn test_hop_wrong_args_count() {
        let err = window_error_from_projection("SELECT HOP(ts, INTERVAL '1' MINUTE) FROM events");
        assert!(err.contains("3 arguments"), "unexpected error: {err}");
    }

    #[test]
    fn test_slide_alias_for_hop() {
        let window = window_from_projection(
            "SELECT SLIDE(ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM events",
        );
        // SLIDE should be parsed as Hop
        assert!(matches!(window, WindowFunction::Hop { .. }), "got {window:?}");
    }

    #[test]
    fn test_get_time_column_name() {
        let window =
            window_from_projection("SELECT TUMBLE(my_timestamp, INTERVAL '5' MINUTE) FROM events");
        assert_eq!(
            WindowRewriter::get_time_column_name(&window),
            Some("my_timestamp".to_string())
        );
    }

    #[test]
    fn test_parse_interval_to_duration() {
        // Test parsing from GROUP BY
        let expr = first_group_by_expr(
            "SELECT COUNT(*) FROM events GROUP BY TUMBLE(ts, INTERVAL '5' MINUTE)",
        );
        let window = WindowRewriter::extract_window_function(&expr)
            .expect("window arguments should be valid")
            .expect("GROUP BY should contain a window function");
        let WindowFunction::Tumble { interval, .. } = &window else {
            panic!("Expected Tumble window, got {window:?}");
        };

        let duration = WindowRewriter::parse_interval_to_duration(interval)
            .expect("interval should parse");
        assert_eq!(duration, Duration::from_secs(300));
    }

    #[test]
    fn test_parse_interval_string_formats() {
        // Test various interval string formats
        let cases = [
            ("5 MINUTE", 300),
            ("5 MINUTES", 300),
            ("1 HOUR", 3600),
            ("2 HOURS", 7200),
            ("10 SECOND", 10),
            ("1 DAY", 86400),
        ];

        for (input, expected_secs) in cases {
            let result = WindowRewriter::parse_interval_string(input)
                .unwrap_or_else(|e| panic!("Failed for input {input}: {e}"));
            assert_eq!(
                result,
                Duration::from_secs(expected_secs),
                "Failed for input: {input}"
            );
        }
    }

    #[test]
    fn test_window_in_group_by() {
        let stmt = parse_single(
            "SELECT user_id, COUNT(*) FROM events \
             GROUP BY TUMBLE(event_time, INTERVAL '1' HOUR), user_id",
        );
        let window = WindowRewriter::find_window_in_group_by(select_of(&stmt))
            .expect("window arguments should be valid")
            .expect("GROUP BY should contain a window function");

        let WindowFunction::Tumble { time_column, .. } = &window else {
            panic!("Expected Tumble window, got {window:?}");
        };
        assert_eq!(time_column.to_string(), "event_time");
    }

    #[test]
    fn test_contains_cumulate_window_function() {
        let expr = first_projection_expr(
            "SELECT CUMULATE(ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM events",
        );
        assert!(WindowRewriter::contains_window_function(&expr));
    }

    #[test]
    fn test_extract_cumulate_with_actual_args() {
        match window_from_projection(
            "SELECT CUMULATE(order_time, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM orders",
        ) {
            WindowFunction::Cumulate {
                time_column,
                step_interval,
                max_size_interval,
            } => {
                assert_eq!(time_column.to_string(), "order_time");
                assert!(step_interval.to_string().contains('1'));
                assert!(max_size_interval.to_string().contains('5'));
            }
            other => panic!("Expected Cumulate window, got {other:?}"),
        }
    }

    #[test]
    fn test_cumulate_wrong_args_count() {
        let err =
            window_error_from_projection("SELECT CUMULATE(ts, INTERVAL '1' MINUTE) FROM events");
        assert!(err.contains("3 arguments"), "unexpected error: {err}");
    }

    #[test]
    fn test_cumulate_time_column_name() {
        let window = window_from_projection(
            "SELECT CUMULATE(my_ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM events",
        );
        assert_eq!(
            WindowRewriter::get_time_column_name(&window),
            Some("my_ts".to_string())
        );
    }

    #[test]
    fn test_millisecond_interval() {
        // parse_interval_to_duration should handle MILLISECOND unit
        let window =
            window_from_projection("SELECT TUMBLE(ts, INTERVAL '500' MILLISECOND) FROM events");
        let WindowFunction::Tumble { interval, .. } = &window else {
            panic!("Expected Tumble window, got {window:?}");
        };

        let duration = WindowRewriter::parse_interval_to_duration(interval)
            .expect("millisecond interval should parse");
        assert_eq!(
            duration,
            Duration::from_millis(500),
            "INTERVAL '500' MILLISECOND should parse to 500ms"
        );
    }

    #[test]
    fn test_millisecond_interval_string() {
        // parse_interval_string should handle MS unit
        let cases = [("250 MS", 250), ("100 MILLISECONDS", 100), ("750 MILLISECOND", 750)];

        for (input, expected_ms) in cases {
            let duration = WindowRewriter::parse_interval_string(input)
                .unwrap_or_else(|e| panic!("Failed for input {input}: {e}"));
            assert_eq!(duration, Duration::from_millis(expected_ms), "input: {input}");
        }
    }
}