Skip to main content

laminar_sql/parser/
analytic_parser.rs

//! Analytic window function detection and extraction
//!
//! Analyzes SQL queries for analytic functions like LAG, LEAD, FIRST_VALUE,
//! LAST_VALUE, and NTH_VALUE with OVER clauses. These are per-row window
//! functions (distinct from GROUP BY aggregate windows like TUMBLE/HOP/SESSION).
6
7use sqlparser::ast::{Expr, SelectItem, SetExpr, Statement};
8
/// The per-row analytic window functions recognized by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnalyticFunctionType {
    /// `LAG(col, offset, default)`: value from `offset` rows back in the partition.
    Lag,
    /// `LEAD(col, offset, default)`: value from `offset` rows ahead in the partition.
    Lead,
    /// `FIRST_VALUE(col) OVER (...)`: first value in the window frame.
    FirstValue,
    /// `LAST_VALUE(col) OVER (...)`: last value in the window frame.
    LastValue,
    /// `NTH_VALUE(col, n) OVER (...)`: n-th value in the window frame.
    NthValue,
}

impl AnalyticFunctionType {
    /// Gives the upper-case SQL spelling of this function.
    #[must_use]
    pub fn sql_name(&self) -> &'static str {
        match *self {
            Self::NthValue => "NTH_VALUE",
            Self::LastValue => "LAST_VALUE",
            Self::FirstValue => "FIRST_VALUE",
            Self::Lead => "LEAD",
            Self::Lag => "LAG",
        }
    }

    /// Reports whether this function needs future events to be buffered.
    ///
    /// Only [`AnalyticFunctionType::Lead`] yields `true`.
    #[must_use]
    pub fn requires_lookahead(&self) -> bool {
        *self == Self::Lead
    }
}
43
44/// Information about a single analytic function call.
45#[derive(Debug, Clone, PartialEq, Eq)]
46pub struct AnalyticFunctionInfo {
47    /// Type of analytic function
48    pub function_type: AnalyticFunctionType,
49    /// Column being referenced (first argument)
50    pub column: String,
51    /// Offset for LAG/LEAD (default 1), or N for NTH_VALUE
52    pub offset: usize,
53    /// Default value expression as string (for LAG/LEAD third argument)
54    pub default_value: Option<String>,
55    /// Output alias (AS name)
56    pub alias: Option<String>,
57}
58
59/// Result of analyzing analytic functions in a query.
60#[derive(Debug, Clone, PartialEq, Eq)]
61pub struct AnalyticWindowAnalysis {
62    /// Analytic functions found in the query
63    pub functions: Vec<AnalyticFunctionInfo>,
64    /// PARTITION BY columns from the OVER clause
65    pub partition_columns: Vec<String>,
66    /// ORDER BY columns from the OVER clause
67    pub order_columns: Vec<String>,
68}
69
70impl AnalyticWindowAnalysis {
71    /// Returns true if any function requires lookahead (LEAD).
72    #[must_use]
73    pub fn has_lookahead(&self) -> bool {
74        self.functions
75            .iter()
76            .any(|f| f.function_type.requires_lookahead())
77    }
78
79    /// Returns the maximum offset across all functions.
80    #[must_use]
81    pub fn max_offset(&self) -> usize {
82        self.functions.iter().map(|f| f.offset).max().unwrap_or(0)
83    }
84}
85
86/// Analyzes a SQL statement for analytic window functions.
87///
88/// Walks SELECT items looking for functions with OVER clauses that match
89/// LAG, LEAD, FIRST_VALUE, LAST_VALUE, or NTH_VALUE. Returns `None` if
90/// no analytic functions are found.
91///
92/// # Arguments
93///
94/// * `stmt` - The SQL statement to analyze
95///
96/// # Returns
97///
98/// An `AnalyticWindowAnalysis` if analytic functions are found, or `None`.
99#[must_use]
100pub fn analyze_analytic_functions(stmt: &Statement) -> Option<AnalyticWindowAnalysis> {
101    let Statement::Query(query) = stmt else {
102        return None;
103    };
104
105    let SetExpr::Select(select) = query.body.as_ref() else {
106        return None;
107    };
108
109    let mut functions = Vec::new();
110    let mut partition_columns = Vec::new();
111    let mut order_columns = Vec::new();
112    let mut first_window = true;
113
114    for item in &select.projection {
115        let (expr, alias) = match item {
116            SelectItem::UnnamedExpr(expr) => (expr, None),
117            SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
118            _ => continue,
119        };
120
121        if let Some(info) = extract_analytic_function(expr, alias, &mut |spec| {
122            if first_window {
123                partition_columns = spec
124                    .partition_by
125                    .iter()
126                    .filter_map(extract_column_name)
127                    .collect();
128                order_columns = spec
129                    .order_by
130                    .iter()
131                    .filter_map(|ob| extract_column_name(&ob.expr))
132                    .collect();
133                first_window = false;
134            }
135        }) {
136            functions.push(info);
137        }
138    }
139
140    if functions.is_empty() {
141        return None;
142    }
143
144    Some(AnalyticWindowAnalysis {
145        functions,
146        partition_columns,
147        order_columns,
148    })
149}
150
151/// Extracts an analytic function from an expression.
152///
153/// Returns function info if the expression is a recognized analytic function
154/// with an OVER clause. Calls `on_window_spec` with the window spec from the
155/// first function found so the caller can extract partition/order columns.
156fn extract_analytic_function(
157    expr: &Expr,
158    alias: Option<String>,
159    on_window_spec: &mut dyn FnMut(&sqlparser::ast::WindowSpec),
160) -> Option<AnalyticFunctionInfo> {
161    let Expr::Function(func) = expr else {
162        return None;
163    };
164
165    let name = func.name.to_string().to_uppercase();
166    let function_type = match name.as_str() {
167        "LAG" => AnalyticFunctionType::Lag,
168        "LEAD" => AnalyticFunctionType::Lead,
169        "FIRST_VALUE" => AnalyticFunctionType::FirstValue,
170        "LAST_VALUE" => AnalyticFunctionType::LastValue,
171        "NTH_VALUE" => AnalyticFunctionType::NthValue,
172        _ => return None,
173    };
174
175    // Must have an OVER clause to be an analytic function
176    let window_spec = func.over.as_ref()?;
177    match window_spec {
178        sqlparser::ast::WindowType::WindowSpec(spec) => {
179            on_window_spec(spec);
180        }
181        sqlparser::ast::WindowType::NamedWindow(_) => {}
182    }
183
184    // Extract arguments
185    let args = extract_function_args(func);
186
187    // First arg is the column
188    let column = args.first().cloned().unwrap_or_default();
189
190    // Second arg is offset (for LAG/LEAD) or N (for NTH_VALUE), default 1
191    let offset = args
192        .get(1)
193        .and_then(|s| s.parse::<usize>().ok())
194        .unwrap_or(1);
195
196    // Third arg is default value (for LAG/LEAD only)
197    let default_value = if matches!(
198        function_type,
199        AnalyticFunctionType::Lag | AnalyticFunctionType::Lead
200    ) {
201        args.get(2).cloned()
202    } else {
203        None
204    };
205
206    Some(AnalyticFunctionInfo {
207        function_type,
208        column,
209        offset,
210        default_value,
211        alias,
212    })
213}
214
215/// Extracts function argument expressions as strings.
216fn extract_function_args(func: &sqlparser::ast::Function) -> Vec<String> {
217    match &func.args {
218        sqlparser::ast::FunctionArguments::List(list) => list
219            .args
220            .iter()
221            .filter_map(|arg| match arg {
222                sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(
223                    expr,
224                )) => Some(expr_to_string(expr)),
225                _ => None,
226            })
227            .collect(),
228        _ => vec![],
229    }
230}
231
232/// Converts an expression to its string representation.
233fn expr_to_string(expr: &Expr) -> String {
234    match expr {
235        Expr::Identifier(ident) => ident.value.clone(),
236        Expr::CompoundIdentifier(parts) => parts.last().map_or(String::new(), |p| p.value.clone()),
237        Expr::Value(value_with_span) => match &value_with_span.value {
238            sqlparser::ast::Value::Number(n, _) => n.clone(),
239            sqlparser::ast::Value::SingleQuotedString(s) => s.clone(),
240            sqlparser::ast::Value::Null => "NULL".to_string(),
241            _ => format!("{}", value_with_span.value),
242        },
243        Expr::UnaryOp {
244            op: sqlparser::ast::UnaryOperator::Minus,
245            expr: inner,
246        } => format!("-{}", expr_to_string(inner)),
247        _ => expr.to_string(),
248    }
249}
250
251/// Extracts a simple column name from an expression.
252fn extract_column_name(expr: &Expr) -> Option<String> {
253    match expr {
254        Expr::Identifier(ident) => Some(ident.value.clone()),
255        Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
256        _ => None,
257    }
258}
259
// --- Window Frame types (F-SQL-006) ---
261
/// Aggregate functions that may be combined with a window frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowFrameFunction {
    /// `AVG(col) OVER (... ROWS BETWEEN ...)`
    Avg,
    /// `SUM(col) OVER (... ROWS BETWEEN ...)`
    Sum,
    /// `MIN(col) OVER (... ROWS BETWEEN ...)`
    Min,
    /// `MAX(col) OVER (... ROWS BETWEEN ...)`
    Max,
    /// `COUNT(*) OVER (... ROWS BETWEEN ...)`
    Count,
    /// `FIRST_VALUE(col) OVER (... ROWS BETWEEN ...)`
    FirstValue,
    /// `LAST_VALUE(col) OVER (... ROWS BETWEEN ...)`
    LastValue,
}

impl WindowFrameFunction {
    /// Gives the upper-case SQL spelling of this function.
    #[must_use]
    pub fn sql_name(&self) -> &'static str {
        match *self {
            Self::LastValue => "LAST_VALUE",
            Self::FirstValue => "FIRST_VALUE",
            Self::Count => "COUNT",
            Self::Max => "MAX",
            Self::Min => "MIN",
            Self::Sum => "SUM",
            Self::Avg => "AVG",
        }
    }
}
296
/// The unit in which a window frame's bounds are expressed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameUnits {
    /// `ROWS BETWEEN`: bounds are physical row offsets.
    Rows,
    /// `RANGE BETWEEN`: bounds are a logical value range.
    Range,
}
305
/// One endpoint (start or end) of a window frame.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `N PRECEDING`, carrying N.
    Preceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `N FOLLOWING`, carrying N.
    Following(u64),
    /// `UNBOUNDED FOLLOWING`
    UnboundedFollowing,
}
320
321/// Information about a single window frame function call.
322#[derive(Debug, Clone, PartialEq, Eq)]
323pub struct WindowFrameInfo {
324    /// Type of aggregate function
325    pub function_type: WindowFrameFunction,
326    /// Column being aggregated (or "*" for COUNT(*))
327    pub column: String,
328    /// Frame unit type (ROWS or RANGE)
329    pub units: FrameUnits,
330    /// Start bound of the frame
331    pub start_bound: FrameBound,
332    /// End bound of the frame
333    pub end_bound: FrameBound,
334    /// Output alias (AS name)
335    pub alias: Option<String>,
336}
337
338/// Result of analyzing window frame functions in a query.
339#[derive(Debug, Clone, PartialEq, Eq)]
340pub struct WindowFrameAnalysis {
341    /// Window frame functions found in the query
342    pub functions: Vec<WindowFrameInfo>,
343    /// PARTITION BY columns from the OVER clause
344    pub partition_columns: Vec<String>,
345    /// ORDER BY columns from the OVER clause
346    pub order_columns: Vec<String>,
347}
348
349impl WindowFrameAnalysis {
350    /// Returns true if any frame uses FOLLOWING bounds.
351    #[must_use]
352    pub fn has_following(&self) -> bool {
353        self.functions.iter().any(|f| {
354            matches!(
355                f.end_bound,
356                FrameBound::Following(_) | FrameBound::UnboundedFollowing
357            ) || matches!(
358                f.start_bound,
359                FrameBound::Following(_) | FrameBound::UnboundedFollowing
360            )
361        })
362    }
363
364    /// Returns the maximum preceding offset across all functions.
365    #[must_use]
366    pub fn max_preceding(&self) -> u64 {
367        self.functions
368            .iter()
369            .filter_map(|f| match &f.start_bound {
370                FrameBound::Preceding(n) => Some(*n),
371                _ => None,
372            })
373            .max()
374            .unwrap_or(0)
375    }
376}
377
378/// Analyzes a SQL statement for window frame aggregate functions.
379///
380/// Walks SELECT items looking for aggregate functions (AVG, SUM, MIN, MAX,
381/// COUNT, FIRST_VALUE, LAST_VALUE) with OVER clauses that contain explicit
382/// ROWS/RANGE frame specifications. Returns `None` if no such functions
383/// are found.
384///
385/// This is distinct from `analyze_analytic_functions()` which handles
386/// per-row offset functions (LAG/LEAD). Window frame functions compute
387/// aggregates over a sliding frame of rows.
388#[must_use]
389pub fn analyze_window_frames(stmt: &Statement) -> Option<WindowFrameAnalysis> {
390    let Statement::Query(query) = stmt else {
391        return None;
392    };
393
394    let SetExpr::Select(select) = query.body.as_ref() else {
395        return None;
396    };
397
398    let mut functions = Vec::new();
399    let mut partition_columns = Vec::new();
400    let mut order_columns = Vec::new();
401    let mut first_window = true;
402
403    for item in &select.projection {
404        let (expr, alias) = match item {
405            SelectItem::UnnamedExpr(expr) => (expr, None),
406            SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
407            _ => continue,
408        };
409
410        if let Some(info) = extract_window_frame_function(expr, alias, &mut |spec| {
411            if first_window {
412                partition_columns = spec
413                    .partition_by
414                    .iter()
415                    .filter_map(extract_column_name)
416                    .collect();
417                order_columns = spec
418                    .order_by
419                    .iter()
420                    .filter_map(|ob| extract_column_name(&ob.expr))
421                    .collect();
422                first_window = false;
423            }
424        }) {
425            functions.push(info);
426        }
427    }
428
429    if functions.is_empty() {
430        return None;
431    }
432
433    Some(WindowFrameAnalysis {
434        functions,
435        partition_columns,
436        order_columns,
437    })
438}
439
440/// Extracts a window frame aggregate function from an expression.
441fn extract_window_frame_function(
442    expr: &Expr,
443    alias: Option<String>,
444    on_window_spec: &mut dyn FnMut(&sqlparser::ast::WindowSpec),
445) -> Option<WindowFrameInfo> {
446    let Expr::Function(func) = expr else {
447        return None;
448    };
449
450    let name = func.name.to_string().to_uppercase();
451    let function_type = match name.as_str() {
452        "AVG" => WindowFrameFunction::Avg,
453        "SUM" => WindowFrameFunction::Sum,
454        "MIN" => WindowFrameFunction::Min,
455        "MAX" => WindowFrameFunction::Max,
456        "COUNT" => WindowFrameFunction::Count,
457        "FIRST_VALUE" => WindowFrameFunction::FirstValue,
458        "LAST_VALUE" => WindowFrameFunction::LastValue,
459        _ => return None,
460    };
461
462    // Must have an OVER clause with an explicit window frame
463    let window_type = func.over.as_ref()?;
464    let spec = match window_type {
465        sqlparser::ast::WindowType::WindowSpec(spec) => spec,
466        sqlparser::ast::WindowType::NamedWindow(_) => return None,
467    };
468
469    // Only match functions with explicit ROWS/RANGE frame specs
470    let frame = spec.window_frame.as_ref()?;
471
472    on_window_spec(spec);
473
474    let units = match frame.units {
475        sqlparser::ast::WindowFrameUnits::Rows => FrameUnits::Rows,
476        sqlparser::ast::WindowFrameUnits::Range => FrameUnits::Range,
477        sqlparser::ast::WindowFrameUnits::Groups => return None,
478    };
479
480    let start_bound = convert_frame_bound(&frame.start_bound);
481    let end_bound = frame
482        .end_bound
483        .as_ref()
484        .map_or(FrameBound::CurrentRow, convert_frame_bound);
485
486    // Extract the column argument
487    let args = extract_function_args(func);
488    let column = args.first().cloned().unwrap_or_else(|| "*".to_string());
489
490    Some(WindowFrameInfo {
491        function_type,
492        column,
493        units,
494        start_bound,
495        end_bound,
496        alias,
497    })
498}
499
500/// Converts a sqlparser `WindowFrameBound` to our `FrameBound`.
501fn convert_frame_bound(bound: &sqlparser::ast::WindowFrameBound) -> FrameBound {
502    match bound {
503        sqlparser::ast::WindowFrameBound::CurrentRow => FrameBound::CurrentRow,
504        sqlparser::ast::WindowFrameBound::Preceding(None) => FrameBound::UnboundedPreceding,
505        sqlparser::ast::WindowFrameBound::Preceding(Some(expr)) => {
506            let n = expr_to_u64(expr).unwrap_or(0);
507            FrameBound::Preceding(n)
508        }
509        sqlparser::ast::WindowFrameBound::Following(None) => FrameBound::UnboundedFollowing,
510        sqlparser::ast::WindowFrameBound::Following(Some(expr)) => {
511            let n = expr_to_u64(expr).unwrap_or(0);
512            FrameBound::Following(n)
513        }
514    }
515}
516
517/// Extracts a u64 value from an expression (numeric literal).
518fn expr_to_u64(expr: &Expr) -> Option<u64> {
519    match expr {
520        Expr::Value(value_with_span) => match &value_with_span.value {
521            sqlparser::ast::Value::Number(n, _) => n.parse().ok(),
522            _ => None,
523        },
524        _ => None,
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_stmt(sql: &str) -> Statement {
        let mut parsed = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        parsed.remove(0)
    }

    /// Parses `sql` and runs the analytic-function analysis, expecting a hit.
    fn analytic(sql: &str) -> AnalyticWindowAnalysis {
        analyze_analytic_functions(&parse_stmt(sql)).unwrap()
    }

    /// Parses `sql` and runs the window-frame analysis, expecting a hit.
    fn frames(sql: &str) -> WindowFrameAnalysis {
        analyze_window_frames(&parse_stmt(sql)).unwrap()
    }

    #[test]
    fn test_lag_basic() {
        let analysis =
            analytic("SELECT price, LAG(price) OVER (ORDER BY ts) AS prev_price FROM trades");
        assert_eq!(analysis.functions.len(), 1);

        let lag = &analysis.functions[0];
        assert_eq!(lag.function_type, AnalyticFunctionType::Lag);
        assert_eq!(lag.column, "price");
        assert_eq!(lag.offset, 1);
        assert_eq!(lag.alias.as_deref(), Some("prev_price"));
    }

    #[test]
    fn test_lag_with_offset() {
        let analysis = analytic("SELECT LAG(price, 3) OVER (ORDER BY ts) AS prev3 FROM trades");
        assert_eq!(analysis.functions[0].offset, 3);
    }

    #[test]
    fn test_lag_with_default() {
        let analysis = analytic("SELECT LAG(price, 1, 0) OVER (ORDER BY ts) AS prev FROM trades");
        let lag = &analysis.functions[0];
        assert_eq!(lag.offset, 1);
        assert_eq!(lag.default_value.as_deref(), Some("0"));
    }

    #[test]
    fn test_lead_basic() {
        let analysis = analytic("SELECT LEAD(price) OVER (ORDER BY ts) AS next_price FROM trades");
        assert_eq!(
            analysis.functions[0].function_type,
            AnalyticFunctionType::Lead
        );
        assert!(analysis.has_lookahead());
    }

    #[test]
    fn test_lead_with_offset_and_default() {
        let analysis =
            analytic("SELECT LEAD(price, 2, -1) OVER (ORDER BY ts) AS next2 FROM trades");
        let lead = &analysis.functions[0];
        assert_eq!(lead.offset, 2);
        assert_eq!(lead.default_value.as_deref(), Some("-1"));
    }

    #[test]
    fn test_partition_by_extraction() {
        let analysis = analytic(
            "SELECT symbol, LAG(price) OVER (PARTITION BY symbol ORDER BY ts) FROM trades",
        );
        assert_eq!(analysis.partition_columns, vec!["symbol".to_string()]);
        assert_eq!(analysis.order_columns, vec!["ts".to_string()]);
    }

    #[test]
    fn test_multiple_analytic_functions() {
        let sql = "SELECT
            LAG(price) OVER (ORDER BY ts) AS prev,
            LEAD(price) OVER (ORDER BY ts) AS next
            FROM trades";
        let analysis = analytic(sql);
        assert_eq!(analysis.functions.len(), 2);

        let kinds: Vec<AnalyticFunctionType> = analysis
            .functions
            .iter()
            .map(|info| info.function_type)
            .collect();
        assert_eq!(kinds[0], AnalyticFunctionType::Lag);
        assert_eq!(kinds[1], AnalyticFunctionType::Lead);
    }

    #[test]
    fn test_first_value() {
        let sql =
            "SELECT FIRST_VALUE(price) OVER (PARTITION BY symbol ORDER BY ts) AS first FROM trades";
        let analysis = analytic(sql);
        let first = &analysis.functions[0];
        assert_eq!(first.function_type, AnalyticFunctionType::FirstValue);
        assert_eq!(first.column, "price");
    }

    #[test]
    fn test_last_value() {
        let analysis = analytic("SELECT LAST_VALUE(price) OVER (ORDER BY ts) FROM trades");
        assert_eq!(
            analysis.functions[0].function_type,
            AnalyticFunctionType::LastValue
        );
    }

    #[test]
    fn test_no_analytic_functions() {
        let stmt = parse_stmt("SELECT price, volume FROM trades WHERE price > 100");
        assert!(analyze_analytic_functions(&stmt).is_none());
    }

    #[test]
    fn test_max_offset() {
        let sql = "SELECT
            LAG(price, 1) OVER (ORDER BY ts) AS p1,
            LAG(price, 5) OVER (ORDER BY ts) AS p5,
            LEAD(price, 3) OVER (ORDER BY ts) AS n3
            FROM trades";
        assert_eq!(analytic(sql).max_offset(), 5);
    }

    // --- Window Frame tests (F-SQL-006) ---

    #[test]
    fn test_frame_rows_preceding_current() {
        let sql = "SELECT AVG(price) OVER (ORDER BY ts \
                    ROWS BETWEEN 9 PRECEDING AND CURRENT ROW) AS ma \
                    FROM trades";
        let analysis = frames(sql);
        assert_eq!(analysis.functions.len(), 1);

        let avg = &analysis.functions[0];
        assert_eq!(avg.function_type, WindowFrameFunction::Avg);
        assert_eq!(avg.column, "price");
        assert_eq!(avg.units, FrameUnits::Rows);
        assert_eq!(avg.start_bound, FrameBound::Preceding(9));
        assert_eq!(avg.end_bound, FrameBound::CurrentRow);
        assert_eq!(avg.alias.as_deref(), Some("ma"));
    }

    #[test]
    fn test_frame_rows_preceding_following() {
        let sql = "SELECT SUM(amount) OVER (ORDER BY id \
                    ROWS BETWEEN 5 PRECEDING AND 3 FOLLOWING) AS total \
                    FROM orders";
        let analysis = frames(sql);
        let sum = &analysis.functions[0];
        assert_eq!(sum.function_type, WindowFrameFunction::Sum);
        assert_eq!(sum.start_bound, FrameBound::Preceding(5));
        assert_eq!(sum.end_bound, FrameBound::Following(3));
    }

    #[test]
    fn test_frame_unbounded_preceding_running_sum() {
        let sql = "SELECT SUM(amount) OVER (ORDER BY id \
                    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running \
                    FROM orders";
        let analysis = frames(sql);
        let running = &analysis.functions[0];
        assert_eq!(running.start_bound, FrameBound::UnboundedPreceding);
        assert_eq!(running.end_bound, FrameBound::CurrentRow);
    }

    #[test]
    fn test_frame_range_units() {
        let sql = "SELECT AVG(price) OVER (ORDER BY ts \
                    RANGE BETWEEN 10 PRECEDING AND CURRENT ROW) AS ra \
                    FROM trades";
        let analysis = frames(sql);
        let ranged = &analysis.functions[0];
        assert_eq!(ranged.units, FrameUnits::Range);
        assert_eq!(ranged.start_bound, FrameBound::Preceding(10));
    }

    #[test]
    fn test_frame_partition_order_columns() {
        let sql = "SELECT AVG(price) OVER (PARTITION BY symbol ORDER BY ts \
                    ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) AS ma \
                    FROM trades";
        let analysis = frames(sql);
        assert_eq!(analysis.partition_columns, vec!["symbol".to_string()]);
        assert_eq!(analysis.order_columns, vec!["ts".to_string()]);
    }

    #[test]
    fn test_frame_multiple_functions() {
        let sql = "SELECT \
                    AVG(price) OVER (ORDER BY ts ROWS BETWEEN 9 PRECEDING AND CURRENT ROW) AS ma, \
                    SUM(volume) OVER (ORDER BY ts ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) AS sv \
                    FROM trades";
        let analysis = frames(sql);
        assert_eq!(analysis.functions.len(), 2);

        let (avg, sum) = (&analysis.functions[0], &analysis.functions[1]);
        assert_eq!(avg.function_type, WindowFrameFunction::Avg);
        assert_eq!(avg.column, "price");
        assert_eq!(sum.function_type, WindowFrameFunction::Sum);
        assert_eq!(sum.column, "volume");
    }

    #[test]
    fn test_frame_no_frame_returns_none() {
        // AVG with OVER but no explicit frame → None
        let stmt = parse_stmt("SELECT AVG(price) OVER (ORDER BY ts) FROM trades");
        assert!(analyze_window_frames(&stmt).is_none());
    }

    #[test]
    fn test_frame_unbounded_following() {
        let sql = "SELECT SUM(amount) OVER (ORDER BY id \
                    ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS rest \
                    FROM orders";
        let analysis = frames(sql);
        let rest = &analysis.functions[0];
        assert_eq!(rest.start_bound, FrameBound::CurrentRow);
        assert_eq!(rest.end_bound, FrameBound::UnboundedFollowing);
        assert!(analysis.has_following());
    }

    #[test]
    fn test_frame_all_function_types() {
        let sql = "SELECT \
                    AVG(a) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS f1, \
                    SUM(b) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS f2, \
                    MIN(c) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS f3, \
                    MAX(d) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS f4, \
                    COUNT(e) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS f5 \
                    FROM t";
        let analysis = frames(sql);
        assert_eq!(analysis.functions.len(), 5);

        let expected = [
            WindowFrameFunction::Avg,
            WindowFrameFunction::Sum,
            WindowFrameFunction::Min,
            WindowFrameFunction::Max,
            WindowFrameFunction::Count,
        ];
        for (info, want) in analysis.functions.iter().zip(expected) {
            assert_eq!(info.function_type, want);
        }
    }

    #[test]
    fn test_frame_max_preceding_helper() {
        let sql = "SELECT \
                    AVG(a) OVER (ORDER BY id ROWS BETWEEN 3 PRECEDING AND CURRENT ROW) AS f1, \
                    SUM(b) OVER (ORDER BY id ROWS BETWEEN 10 PRECEDING AND CURRENT ROW) AS f2 \
                    FROM t";
        let analysis = frames(sql);
        assert_eq!(analysis.max_preceding(), 10);
        assert!(!analysis.has_following());
    }
}