Skip to main content

laminar_sql/parser/
analytic_parser.rs

//! Analytic window function detection and extraction.
//!
//! Analyzes SQL queries for the analytic functions LAG, LEAD, FIRST_VALUE,
//! LAST_VALUE, and NTH_VALUE used with an OVER clause. These are per-row window
//! functions, distinct from GROUP BY aggregate windows such as TUMBLE/HOP/SESSION.
6
7use sqlparser::ast::{Expr, SelectItem, SetExpr, Statement};
8
/// The analytic (per-row) window functions recognized by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnalyticFunctionType {
    /// `LAG(col, offset, default)`: the value `offset` rows before the current
    /// row within its partition.
    Lag,
    /// `LEAD(col, offset, default)`: the value `offset` rows after the current
    /// row within its partition.
    Lead,
    /// `FIRST_VALUE(col) OVER (...)`: the first value in the window frame.
    FirstValue,
    /// `LAST_VALUE(col) OVER (...)`: the last value in the window frame.
    LastValue,
    /// `NTH_VALUE(col, n) OVER (...)`: the n-th value in the window frame.
    NthValue,
}

impl AnalyticFunctionType {
    /// Upper-case SQL spelling of this function, e.g. `"LAG"`.
    #[must_use]
    pub fn sql_name(&self) -> &'static str {
        match *self {
            AnalyticFunctionType::Lag => "LAG",
            AnalyticFunctionType::Lead => "LEAD",
            AnalyticFunctionType::FirstValue => "FIRST_VALUE",
            AnalyticFunctionType::LastValue => "LAST_VALUE",
            AnalyticFunctionType::NthValue => "NTH_VALUE",
        }
    }

    /// Whether this function is classified as needing rows that arrive after
    /// the current one (buffering future events). Only `LEAD` is.
    #[must_use]
    pub fn requires_lookahead(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide its lookahead needs.
        match *self {
            AnalyticFunctionType::Lead => true,
            AnalyticFunctionType::Lag
            | AnalyticFunctionType::FirstValue
            | AnalyticFunctionType::LastValue
            | AnalyticFunctionType::NthValue => false,
        }
    }
}
43
44/// Information about a single analytic function call.
45#[derive(Debug, Clone, PartialEq, Eq)]
46pub struct AnalyticFunctionInfo {
47    /// Type of analytic function
48    pub function_type: AnalyticFunctionType,
49    /// Column being referenced (first argument)
50    pub column: String,
51    /// Offset for LAG/LEAD (default 1), or N for NTH_VALUE
52    pub offset: usize,
53    /// Default value expression as string (for LAG/LEAD third argument)
54    pub default_value: Option<String>,
55    /// Output alias (AS name)
56    pub alias: Option<String>,
57}
58
59/// Result of analyzing analytic functions in a query.
60#[derive(Debug, Clone, PartialEq, Eq)]
61pub struct AnalyticWindowAnalysis {
62    /// Analytic functions found in the query
63    pub functions: Vec<AnalyticFunctionInfo>,
64    /// PARTITION BY columns from the OVER clause
65    pub partition_columns: Vec<String>,
66    /// ORDER BY columns from the OVER clause
67    pub order_columns: Vec<String>,
68}
69
70impl AnalyticWindowAnalysis {
71    /// Returns true if any function requires lookahead (LEAD).
72    #[must_use]
73    pub fn has_lookahead(&self) -> bool {
74        self.functions
75            .iter()
76            .any(|f| f.function_type.requires_lookahead())
77    }
78
79    /// Returns the maximum offset across all functions.
80    #[must_use]
81    pub fn max_offset(&self) -> usize {
82        self.functions.iter().map(|f| f.offset).max().unwrap_or(0)
83    }
84}
85
86/// Analyzes a SQL statement for analytic window functions.
87///
88/// Walks SELECT items looking for functions with OVER clauses that match
89/// LAG, LEAD, FIRST_VALUE, LAST_VALUE, or NTH_VALUE. Returns `None` if
90/// no analytic functions are found.
91///
92/// # Arguments
93///
94/// * `stmt` - The SQL statement to analyze
95///
96/// # Returns
97///
98/// An `AnalyticWindowAnalysis` if analytic functions are found, or `None`.
99#[must_use]
100pub fn analyze_analytic_functions(stmt: &Statement) -> Option<AnalyticWindowAnalysis> {
101    let Statement::Query(query) = stmt else {
102        return None;
103    };
104
105    let SetExpr::Select(select) = query.body.as_ref() else {
106        return None;
107    };
108
109    let mut functions = Vec::new();
110    let mut partition_columns = Vec::new();
111    let mut order_columns = Vec::new();
112    let mut first_window = true;
113
114    for item in &select.projection {
115        let (expr, alias) = match item {
116            SelectItem::UnnamedExpr(expr) => (expr, None),
117            SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
118            _ => continue,
119        };
120
121        if let Some(info) = extract_analytic_function(expr, alias, &mut |spec| {
122            if first_window {
123                partition_columns = spec
124                    .partition_by
125                    .iter()
126                    .filter_map(extract_column_name)
127                    .collect();
128                order_columns = spec
129                    .order_by
130                    .iter()
131                    .filter_map(|ob| extract_column_name(&ob.expr))
132                    .collect();
133                first_window = false;
134            }
135        }) {
136            functions.push(info);
137        }
138    }
139
140    if functions.is_empty() {
141        return None;
142    }
143
144    Some(AnalyticWindowAnalysis {
145        functions,
146        partition_columns,
147        order_columns,
148    })
149}
150
151/// Extracts an analytic function from an expression.
152///
153/// Returns function info if the expression is a recognized analytic function
154/// with an OVER clause. Calls `on_window_spec` with the window spec from the
155/// first function found so the caller can extract partition/order columns.
156fn extract_analytic_function(
157    expr: &Expr,
158    alias: Option<String>,
159    on_window_spec: &mut dyn FnMut(&sqlparser::ast::WindowSpec),
160) -> Option<AnalyticFunctionInfo> {
161    let Expr::Function(func) = expr else {
162        return None;
163    };
164
165    let name = func.name.to_string().to_uppercase();
166    let function_type = match name.as_str() {
167        "LAG" => AnalyticFunctionType::Lag,
168        "LEAD" => AnalyticFunctionType::Lead,
169        "FIRST_VALUE" => AnalyticFunctionType::FirstValue,
170        "LAST_VALUE" => AnalyticFunctionType::LastValue,
171        "NTH_VALUE" => AnalyticFunctionType::NthValue,
172        _ => return None,
173    };
174
175    // Must have an OVER clause to be an analytic function
176    let window_spec = func.over.as_ref()?;
177    match window_spec {
178        sqlparser::ast::WindowType::WindowSpec(spec) => {
179            on_window_spec(spec);
180        }
181        sqlparser::ast::WindowType::NamedWindow(_) => {}
182    }
183
184    // Extract arguments
185    let args = extract_function_args(func);
186
187    // First arg is the column
188    let column = args.first().cloned().unwrap_or_default();
189
190    // Second arg is offset (for LAG/LEAD) or N (for NTH_VALUE), default 1
191    let offset = args
192        .get(1)
193        .and_then(|s| s.parse::<usize>().ok())
194        .unwrap_or(1);
195
196    // Third arg is default value (for LAG/LEAD only)
197    let default_value = if matches!(
198        function_type,
199        AnalyticFunctionType::Lag | AnalyticFunctionType::Lead
200    ) {
201        args.get(2).cloned()
202    } else {
203        None
204    };
205
206    Some(AnalyticFunctionInfo {
207        function_type,
208        column,
209        offset,
210        default_value,
211        alias,
212    })
213}
214
215/// Extracts function argument expressions as strings.
216fn extract_function_args(func: &sqlparser::ast::Function) -> Vec<String> {
217    match &func.args {
218        sqlparser::ast::FunctionArguments::List(list) => list
219            .args
220            .iter()
221            .filter_map(|arg| match arg {
222                sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(
223                    expr,
224                )) => Some(expr_to_string(expr)),
225                _ => None,
226            })
227            .collect(),
228        _ => vec![],
229    }
230}
231
232/// Converts an expression to its string representation.
233fn expr_to_string(expr: &Expr) -> String {
234    match expr {
235        Expr::Identifier(ident) => ident.value.clone(),
236        Expr::CompoundIdentifier(parts) => parts.last().map_or(String::new(), |p| p.value.clone()),
237        Expr::Value(value_with_span) => match &value_with_span.value {
238            sqlparser::ast::Value::Number(n, _) => n.clone(),
239            sqlparser::ast::Value::SingleQuotedString(s) => s.clone(),
240            sqlparser::ast::Value::Null => "NULL".to_string(),
241            _ => format!("{}", value_with_span.value),
242        },
243        Expr::UnaryOp {
244            op: sqlparser::ast::UnaryOperator::Minus,
245            expr: inner,
246        } => format!("-{}", expr_to_string(inner)),
247        _ => expr.to_string(),
248    }
249}
250
251/// Extracts a simple column name from an expression.
252fn extract_column_name(expr: &Expr) -> Option<String> {
253    match expr {
254        Expr::Identifier(ident) => Some(ident.value.clone()),
255        Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
256        _ => None,
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_stmt(sql: &str) -> Statement {
        let statements = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        statements.into_iter().next().unwrap()
    }

    /// Parses and analyzes `sql`, panicking if no analytic function is found.
    fn analyze(sql: &str) -> AnalyticWindowAnalysis {
        analyze_analytic_functions(&parse_stmt(sql)).unwrap()
    }

    #[test]
    fn test_lag_basic() {
        let sql = "SELECT price, LAG(price) OVER (ORDER BY ts) AS prev_price FROM trades";
        let analysis = analyze(sql);
        assert_eq!(analysis.functions.len(), 1);
        let lag = &analysis.functions[0];
        assert_eq!(lag.function_type, AnalyticFunctionType::Lag);
        assert_eq!(lag.column, "price");
        assert_eq!(lag.offset, 1);
        assert_eq!(lag.alias.as_deref(), Some("prev_price"));
    }

    #[test]
    fn test_lag_with_offset() {
        let sql = "SELECT LAG(price, 3) OVER (ORDER BY ts) AS prev3 FROM trades";
        assert_eq!(analyze(sql).functions[0].offset, 3);
    }

    #[test]
    fn test_lag_with_default() {
        let sql = "SELECT LAG(price, 1, 0) OVER (ORDER BY ts) AS prev FROM trades";
        let analysis = analyze(sql);
        let lag = &analysis.functions[0];
        assert_eq!(lag.offset, 1);
        assert_eq!(lag.default_value.as_deref(), Some("0"));
    }

    #[test]
    fn test_lead_basic() {
        let sql = "SELECT LEAD(price) OVER (ORDER BY ts) AS next_price FROM trades";
        let analysis = analyze(sql);
        assert_eq!(
            analysis.functions[0].function_type,
            AnalyticFunctionType::Lead
        );
        assert!(analysis.has_lookahead());
    }

    #[test]
    fn test_lead_with_offset_and_default() {
        let sql = "SELECT LEAD(price, 2, -1) OVER (ORDER BY ts) AS next2 FROM trades";
        let analysis = analyze(sql);
        let lead = &analysis.functions[0];
        assert_eq!(lead.offset, 2);
        assert_eq!(lead.default_value.as_deref(), Some("-1"));
    }

    #[test]
    fn test_partition_by_extraction() {
        let sql = "SELECT symbol, LAG(price) OVER (PARTITION BY symbol ORDER BY ts) FROM trades";
        let analysis = analyze(sql);
        assert_eq!(analysis.partition_columns, ["symbol"]);
        assert_eq!(analysis.order_columns, ["ts"]);
    }

    #[test]
    fn test_multiple_analytic_functions() {
        let sql = "SELECT
            LAG(price) OVER (ORDER BY ts) AS prev,
            LEAD(price) OVER (ORDER BY ts) AS next
            FROM trades";
        let kinds: Vec<AnalyticFunctionType> = analyze(sql)
            .functions
            .iter()
            .map(|f| f.function_type)
            .collect();
        assert_eq!(
            kinds,
            [AnalyticFunctionType::Lag, AnalyticFunctionType::Lead]
        );
    }

    #[test]
    fn test_first_value() {
        let sql =
            "SELECT FIRST_VALUE(price) OVER (PARTITION BY symbol ORDER BY ts) AS first FROM trades";
        let analysis = analyze(sql);
        let first = &analysis.functions[0];
        assert_eq!(first.function_type, AnalyticFunctionType::FirstValue);
        assert_eq!(first.column, "price");
    }

    #[test]
    fn test_last_value() {
        let sql = "SELECT LAST_VALUE(price) OVER (ORDER BY ts) FROM trades";
        assert_eq!(
            analyze(sql).functions[0].function_type,
            AnalyticFunctionType::LastValue
        );
    }

    #[test]
    fn test_no_analytic_functions() {
        let sql = "SELECT price, volume FROM trades WHERE price > 100";
        assert!(analyze_analytic_functions(&parse_stmt(sql)).is_none());
    }

    #[test]
    fn test_max_offset() {
        let sql = "SELECT
            LAG(price, 1) OVER (ORDER BY ts) AS p1,
            LAG(price, 5) OVER (ORDER BY ts) AS p5,
            LEAD(price, 3) OVER (ORDER BY ts) AS n3
            FROM trades";
        assert_eq!(analyze(sql).max_offset(), 5);
    }
}