1use sqlparser::ast::{Expr, SelectItem, SetExpr, Statement};
8
/// Analytic (navigation) window functions recognized by [`analyze_analytic_functions`].
///
/// The canonical SQL spelling of each variant is available via
/// [`AnalyticFunctionType::sql_name`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AnalyticFunctionType {
    /// `LAG(expr [, offset [, default]])`.
    Lag,
    /// `LEAD(expr [, offset [, default]])`; the only variant classified as
    /// needing look-ahead.
    Lead,
    /// `FIRST_VALUE(expr)`.
    FirstValue,
    /// `LAST_VALUE(expr)`.
    LastValue,
    /// `NTH_VALUE(expr, n)`.
    NthValue,
}
23
24impl AnalyticFunctionType {
25 #[must_use]
27 pub fn sql_name(&self) -> &'static str {
28 match self {
29 Self::Lag => "LAG",
30 Self::Lead => "LEAD",
31 Self::FirstValue => "FIRST_VALUE",
32 Self::LastValue => "LAST_VALUE",
33 Self::NthValue => "NTH_VALUE",
34 }
35 }
36
37 #[must_use]
39 pub fn requires_lookahead(&self) -> bool {
40 matches!(self, Self::Lead)
41 }
42}
43
44#[derive(Debug, Clone, PartialEq, Eq)]
46pub struct AnalyticFunctionInfo {
47 pub function_type: AnalyticFunctionType,
49 pub column: String,
51 pub offset: usize,
53 pub default_value: Option<String>,
55 pub alias: Option<String>,
57}
58
59#[derive(Debug, Clone, PartialEq, Eq)]
61pub struct AnalyticWindowAnalysis {
62 pub functions: Vec<AnalyticFunctionInfo>,
64 pub partition_columns: Vec<String>,
66 pub order_columns: Vec<String>,
68}
69
70impl AnalyticWindowAnalysis {
71 #[must_use]
73 pub fn has_lookahead(&self) -> bool {
74 self.functions
75 .iter()
76 .any(|f| f.function_type.requires_lookahead())
77 }
78
79 #[must_use]
81 pub fn max_offset(&self) -> usize {
82 self.functions.iter().map(|f| f.offset).max().unwrap_or(0)
83 }
84}
85
86#[must_use]
100pub fn analyze_analytic_functions(stmt: &Statement) -> Option<AnalyticWindowAnalysis> {
101 let Statement::Query(query) = stmt else {
102 return None;
103 };
104
105 let SetExpr::Select(select) = query.body.as_ref() else {
106 return None;
107 };
108
109 let mut functions = Vec::new();
110 let mut partition_columns = Vec::new();
111 let mut order_columns = Vec::new();
112 let mut first_window = true;
113
114 for item in &select.projection {
115 let (expr, alias) = match item {
116 SelectItem::UnnamedExpr(expr) => (expr, None),
117 SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
118 _ => continue,
119 };
120
121 if let Some(info) = extract_analytic_function(expr, alias, &mut |spec| {
122 if first_window {
123 partition_columns = spec
124 .partition_by
125 .iter()
126 .filter_map(extract_column_name)
127 .collect();
128 order_columns = spec
129 .order_by
130 .iter()
131 .filter_map(|ob| extract_column_name(&ob.expr))
132 .collect();
133 first_window = false;
134 }
135 }) {
136 functions.push(info);
137 }
138 }
139
140 if functions.is_empty() {
141 return None;
142 }
143
144 Some(AnalyticWindowAnalysis {
145 functions,
146 partition_columns,
147 order_columns,
148 })
149}
150
151fn extract_analytic_function(
157 expr: &Expr,
158 alias: Option<String>,
159 on_window_spec: &mut dyn FnMut(&sqlparser::ast::WindowSpec),
160) -> Option<AnalyticFunctionInfo> {
161 let Expr::Function(func) = expr else {
162 return None;
163 };
164
165 let name = func.name.to_string().to_uppercase();
166 let function_type = match name.as_str() {
167 "LAG" => AnalyticFunctionType::Lag,
168 "LEAD" => AnalyticFunctionType::Lead,
169 "FIRST_VALUE" => AnalyticFunctionType::FirstValue,
170 "LAST_VALUE" => AnalyticFunctionType::LastValue,
171 "NTH_VALUE" => AnalyticFunctionType::NthValue,
172 _ => return None,
173 };
174
175 let window_spec = func.over.as_ref()?;
177 match window_spec {
178 sqlparser::ast::WindowType::WindowSpec(spec) => {
179 on_window_spec(spec);
180 }
181 sqlparser::ast::WindowType::NamedWindow(_) => {}
182 }
183
184 let args = extract_function_args(func);
186
187 let column = args.first().cloned().unwrap_or_default();
189
190 let offset = args
192 .get(1)
193 .and_then(|s| s.parse::<usize>().ok())
194 .unwrap_or(1);
195
196 let default_value = if matches!(
198 function_type,
199 AnalyticFunctionType::Lag | AnalyticFunctionType::Lead
200 ) {
201 args.get(2).cloned()
202 } else {
203 None
204 };
205
206 Some(AnalyticFunctionInfo {
207 function_type,
208 column,
209 offset,
210 default_value,
211 alias,
212 })
213}
214
215fn extract_function_args(func: &sqlparser::ast::Function) -> Vec<String> {
217 match &func.args {
218 sqlparser::ast::FunctionArguments::List(list) => list
219 .args
220 .iter()
221 .filter_map(|arg| match arg {
222 sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(
223 expr,
224 )) => Some(expr_to_string(expr)),
225 _ => None,
226 })
227 .collect(),
228 _ => vec![],
229 }
230}
231
232fn expr_to_string(expr: &Expr) -> String {
234 match expr {
235 Expr::Identifier(ident) => ident.value.clone(),
236 Expr::CompoundIdentifier(parts) => parts.last().map_or(String::new(), |p| p.value.clone()),
237 Expr::Value(value_with_span) => match &value_with_span.value {
238 sqlparser::ast::Value::Number(n, _) => n.clone(),
239 sqlparser::ast::Value::SingleQuotedString(s) => s.clone(),
240 sqlparser::ast::Value::Null => "NULL".to_string(),
241 _ => format!("{}", value_with_span.value),
242 },
243 Expr::UnaryOp {
244 op: sqlparser::ast::UnaryOperator::Minus,
245 expr: inner,
246 } => format!("-{}", expr_to_string(inner)),
247 _ => expr.to_string(),
248 }
249}
250
251fn extract_column_name(expr: &Expr) -> Option<String> {
253 match expr {
254 Expr::Identifier(ident) => Some(ident.value.clone()),
255 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
256 _ => None,
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns the first statement.
    fn parse_stmt(sql: &str) -> Statement {
        let dialect = GenericDialect {};
        let mut stmts = Parser::parse_sql(&dialect, sql).unwrap();
        stmts.remove(0)
    }

    #[test]
    fn test_lag_basic() {
        let sql = "SELECT price, LAG(price) OVER (ORDER BY ts) AS prev_price FROM trades";
        let stmt = parse_stmt(sql);
        let analysis = analyze_analytic_functions(&stmt).unwrap();
        assert_eq!(analysis.functions.len(), 1);
        assert_eq!(
            analysis.functions[0].function_type,
            AnalyticFunctionType::Lag
        );
        assert_eq!(analysis.functions[0].column, "price");
        assert_eq!(analysis.functions[0].offset, 1);
        assert_eq!(analysis.functions[0].alias.as_deref(), Some("prev_price"));
    }

    #[test]
    fn test_lag_with_offset() {
        let sql = "SELECT LAG(price, 3) OVER (ORDER BY ts) AS prev3 FROM trades";
        let stmt = parse_stmt(sql);
        let analysis = analyze_analytic_functions(&stmt).unwrap();
        assert_eq!(analysis.functions[0].offset, 3);
    }

    #[test]
    fn test_lag_with_default() {
        let sql = "SELECT LAG(price, 1, 0) OVER (ORDER BY ts) AS prev FROM trades";
        let stmt = parse_stmt(sql);
        let analysis = analyze_analytic_functions(&stmt).unwrap();
        assert_eq!(analysis.functions[0].offset, 1);
        assert_eq!(analysis.functions[0].default_value.as_deref(), Some("0"));
    }

    #[test]
    fn test_lead_basic() {
        let sql = "SELECT LEAD(price) OVER (ORDER BY ts) AS next_price FROM trades";
        let stmt = parse_stmt(sql);
        let analysis = analyze_analytic_functions(&stmt).unwrap();
        assert_eq!(
            analysis.functions[0].function_type,
            AnalyticFunctionType::Lead
        );
        assert!(analysis.has_lookahead());
    }

    #[test]
    fn test_lead_with_offset_and_default() {
        let sql = "SELECT LEAD(price, 2, -1) OVER (ORDER BY ts) AS next2 FROM trades";
        let stmt = parse_stmt(sql);
        let analysis = analyze_analytic_functions(&stmt).unwrap();
        assert_eq!(analysis.functions[0].offset, 2);
        assert_eq!(analysis.functions[0].default_value.as_deref(), Some("-1"));
    }

    #[test]
    fn test_partition_by_extraction() {
        let sql = "SELECT symbol, LAG(price) OVER (PARTITION BY symbol ORDER BY ts) FROM trades";
        let stmt = parse_stmt(sql);
        let analysis = analyze_analytic_functions(&stmt).unwrap();
        assert_eq!(analysis.partition_columns, vec!["symbol".to_string()]);
        assert_eq!(analysis.order_columns, vec!["ts".to_string()]);
    }

    #[test]
    fn test_multiple_analytic_functions() {
        let sql = "SELECT
            LAG(price) OVER (ORDER BY ts) AS prev,
            LEAD(price) OVER (ORDER BY ts) AS next
            FROM trades";
        let stmt = parse_stmt(sql);
        let analysis = analyze_analytic_functions(&stmt).unwrap();
        assert_eq!(analysis.functions.len(), 2);
        assert_eq!(
            analysis.functions[0].function_type,
            AnalyticFunctionType::Lag
        );
        assert_eq!(
            analysis.functions[1].function_type,
            AnalyticFunctionType::Lead
        );
    }

    #[test]
    fn test_first_value() {
        let sql =
            "SELECT FIRST_VALUE(price) OVER (PARTITION BY symbol ORDER BY ts) AS first FROM trades";
        let stmt = parse_stmt(sql);
        let analysis = analyze_analytic_functions(&stmt).unwrap();
        assert_eq!(
            analysis.functions[0].function_type,
            AnalyticFunctionType::FirstValue
        );
        assert_eq!(analysis.functions[0].column, "price");
    }

    #[test]
    fn test_last_value() {
        let sql = "SELECT LAST_VALUE(price) OVER (ORDER BY ts) FROM trades";
        let stmt = parse_stmt(sql);
        let analysis = analyze_analytic_functions(&stmt).unwrap();
        assert_eq!(
            analysis.functions[0].function_type,
            AnalyticFunctionType::LastValue
        );
    }

    #[test]
    fn test_no_analytic_functions() {
        let sql = "SELECT price, volume FROM trades WHERE price > 100";
        let stmt = parse_stmt(sql);
        assert!(analyze_analytic_functions(&stmt).is_none());
    }

    #[test]
    fn test_max_offset() {
        let sql = "SELECT
            LAG(price, 1) OVER (ORDER BY ts) AS p1,
            LAG(price, 5) OVER (ORDER BY ts) AS p5,
            LEAD(price, 3) OVER (ORDER BY ts) AS n3
            FROM trades";
        let stmt = parse_stmt(sql);
        let analysis = analyze_analytic_functions(&stmt).unwrap();
        assert_eq!(analysis.max_offset(), 5);
    }

    #[test]
    fn test_nth_value() {
        let sql = "SELECT NTH_VALUE(price, 3) OVER (ORDER BY ts) AS third FROM trades";
        let analysis = analyze_analytic_functions(&parse_stmt(sql)).unwrap();
        assert_eq!(
            analysis.functions[0].function_type,
            AnalyticFunctionType::NthValue
        );
        assert_eq!(analysis.functions[0].column, "price");
        assert_eq!(analysis.functions[0].offset, 3);
        // A default value is only captured for LAG/LEAD.
        assert_eq!(analysis.functions[0].default_value, None);
        assert!(!analysis.has_lookahead());
    }

    #[test]
    fn test_function_name_is_case_insensitive() {
        let sql = "SELECT lag(price) OVER (ORDER BY ts) FROM trades";
        let analysis = analyze_analytic_functions(&parse_stmt(sql)).unwrap();
        assert_eq!(
            analysis.functions[0].function_type,
            AnalyticFunctionType::Lag
        );
    }

    #[test]
    fn test_compound_identifiers_use_last_segment() {
        let sql =
            "SELECT LAG(t.price) OVER (PARTITION BY t.symbol ORDER BY t.ts) FROM trades t";
        let analysis = analyze_analytic_functions(&parse_stmt(sql)).unwrap();
        assert_eq!(analysis.functions[0].column, "price");
        assert_eq!(analysis.partition_columns, vec!["symbol".to_string()]);
        assert_eq!(analysis.order_columns, vec!["ts".to_string()]);
    }

    #[test]
    fn test_first_window_spec_wins() {
        let sql = "SELECT
            LAG(price) OVER (PARTITION BY symbol ORDER BY ts) AS prev,
            LEAD(price) OVER (PARTITION BY venue ORDER BY seq) AS next
            FROM trades";
        let analysis = analyze_analytic_functions(&parse_stmt(sql)).unwrap();
        assert_eq!(analysis.functions.len(), 2);
        assert_eq!(analysis.partition_columns, vec!["symbol".to_string()]);
        assert_eq!(analysis.order_columns, vec!["ts".to_string()]);
    }

    #[test]
    fn test_call_without_over_is_ignored() {
        let sql = "SELECT LAG(price) FROM trades";
        assert!(analyze_analytic_functions(&parse_stmt(sql)).is_none());
    }

    #[test]
    fn test_other_window_functions_are_ignored() {
        let sql = "SELECT ROW_NUMBER() OVER (ORDER BY ts) FROM trades";
        assert!(analyze_analytic_functions(&parse_stmt(sql)).is_none());
    }

    #[test]
    fn test_non_select_statements_return_none() {
        let insert = parse_stmt("INSERT INTO trades (price) VALUES (1)");
        assert!(analyze_analytic_functions(&insert).is_none());

        // Set operations are not a plain SELECT body.
        let union = parse_stmt(
            "SELECT LAG(price) OVER (ORDER BY ts) FROM a UNION ALL SELECT price FROM b",
        );
        assert!(analyze_analytic_functions(&union).is_none());
    }

    #[test]
    fn test_empty_analysis_helpers() {
        let analysis = AnalyticWindowAnalysis {
            functions: vec![],
            partition_columns: vec![],
            order_columns: vec![],
        };
        assert_eq!(analysis.max_offset(), 0);
        assert!(!analysis.has_lookahead());
    }

    #[test]
    fn test_sql_name() {
        let expected = [
            (AnalyticFunctionType::Lag, "LAG"),
            (AnalyticFunctionType::Lead, "LEAD"),
            (AnalyticFunctionType::FirstValue, "FIRST_VALUE"),
            (AnalyticFunctionType::LastValue, "LAST_VALUE"),
            (AnalyticFunctionType::NthValue, "NTH_VALUE"),
        ];
        for (kind, name) in expected {
            assert_eq!(kind.sql_name(), name);
        }
    }
}