1use sqlparser::ast::{Expr, OrderByKind, Query, SelectItem, SetExpr, Statement};
7
/// Outcome of inspecting a SQL statement's ordering requirements.
///
/// Values of this type are built by [`analyze_order_by`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderAnalysis {
    /// Sort keys of the top-level `ORDER BY`, in clause order.
    ///
    /// Only keys that are plain (optionally qualified) column references are
    /// listed; any other sort expression is left out.
    pub order_columns: Vec<OrderColumn>,
    /// Literal `LIMIT` of the top-level query, when one could be read.
    ///
    /// Always `None` when `pattern` is [`OrderPattern::None`].
    pub limit: Option<usize>,
    /// `true` when the query's `GROUP BY` contains a `TUMBLE`, `HOP` or
    /// `SESSION` call.
    ///
    /// Always `false` when `pattern` is [`OrderPattern::None`].
    pub is_windowed: bool,
    /// Classification of the ordering work the query asks for.
    pub pattern: OrderPattern,
}
20
/// One sort key taken from an `ORDER BY` clause.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderColumn {
    /// Column name; for a qualified reference such as `t.col` only the last
    /// part (`col`) is kept.
    pub column: String,
    /// `true` for an explicit `DESC`; a missing direction counts as ascending.
    pub descending: bool,
    /// `true` only for an explicit `NULLS FIRST`; a missing null ordering is
    /// recorded as `false`.
    pub nulls_first: bool,
}
31
/// Classification of a query's ordering requirement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OrderPattern {
    /// Nothing to order: the statement is not a query, or no sort columns and
    /// no bounded ranking pattern were found.
    None,
    /// NOTE(review): never produced by `analyze_order_by` in this file;
    /// presumably set by callers once the source's own ordering already meets
    /// the requirement (see `is_order_satisfied`) — confirm against callers.
    SourceSatisfied,
    /// `ORDER BY` combined with a literal `LIMIT` on a non-windowed query.
    TopK {
        /// The `LIMIT` value.
        k: usize,
    },
    /// `ORDER BY` on a query whose `GROUP BY` uses `TUMBLE`, `HOP` or
    /// `SESSION`.
    WindowLocal,
    /// A `ROW_NUMBER` / `RANK` / `DENSE_RANK` window function whose result is
    /// bounded, either by an outer filter on its alias or by a `LIMIT`.
    PerGroupTopK {
        /// The bound read from the filter or the `LIMIT`.
        k: usize,
        /// Plain-column names of the window's `PARTITION BY` list.
        partition_columns: Vec<String>,
        /// Which ranking function was used.
        rank_type: RankType,
    },
    /// `ORDER BY` with neither a readable `LIMIT` nor a window; the only
    /// pattern for which `OrderAnalysis::is_streaming_safe` returns `false`.
    Unbounded,
}
58
/// Ranking window function detected in a per-group top-K query.
///
/// The variant is chosen from the function name, compared case-insensitively.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RankType {
    /// The function is named `ROW_NUMBER`.
    RowNumber,
    /// The function is named `RANK`.
    Rank,
    /// The function is named `DENSE_RANK`.
    DenseRank,
}
69
70impl OrderAnalysis {
71 #[must_use]
73 pub fn is_streaming_safe(&self) -> bool {
74 !matches!(self.pattern, OrderPattern::Unbounded)
75 }
76}
77
78#[must_use]
91pub fn analyze_order_by(stmt: &Statement) -> OrderAnalysis {
92 let Statement::Query(query) = stmt else {
93 return OrderAnalysis {
94 order_columns: vec![],
95 limit: None,
96 is_windowed: false,
97 pattern: OrderPattern::None,
98 };
99 };
100
101 let limit = extract_limit(query);
102 let is_windowed = check_is_windowed(query);
103
104 if let Some((k, partition_columns, rank_type)) = detect_row_number_pattern(query) {
108 let order_columns = extract_order_columns(query);
109 return OrderAnalysis {
110 order_columns,
111 limit,
112 is_windowed,
113 pattern: OrderPattern::PerGroupTopK {
114 k,
115 partition_columns,
116 rank_type,
117 },
118 };
119 }
120
121 let order_columns = extract_order_columns(query);
122 if order_columns.is_empty() {
123 return OrderAnalysis {
124 order_columns: vec![],
125 limit: None,
126 is_windowed: false,
127 pattern: OrderPattern::None,
128 };
129 }
130
131 let pattern = if is_windowed {
132 OrderPattern::WindowLocal
133 } else if let Some(k) = limit {
134 OrderPattern::TopK { k }
135 } else {
136 OrderPattern::Unbounded
137 };
138
139 OrderAnalysis {
140 order_columns,
141 limit,
142 is_windowed,
143 pattern,
144 }
145}
146
147#[must_use]
152pub fn is_order_satisfied(
153 required: &[OrderColumn],
154 source: &[crate::datafusion::SortColumn],
155) -> bool {
156 if required.is_empty() {
157 return true;
158 }
159 if source.len() < required.len() {
160 return false;
161 }
162 required.iter().zip(source.iter()).all(|(req, src)| {
163 req.column == src.name
164 && req.descending == src.descending
165 && req.nulls_first == src.nulls_first
166 })
167}
168
169fn extract_order_columns(query: &Query) -> Vec<OrderColumn> {
171 let Some(order_by) = &query.order_by else {
172 return vec![];
173 };
174
175 let OrderByKind::Expressions(exprs) = &order_by.kind else {
176 return vec![]; };
178
179 exprs
180 .iter()
181 .filter_map(|ob_expr| {
182 let column = extract_column_name(&ob_expr.expr)?;
183 let descending = !ob_expr.options.asc.unwrap_or(true);
184 let nulls_first = ob_expr.options.nulls_first.unwrap_or(false);
185 Some(OrderColumn {
186 column,
187 descending,
188 nulls_first,
189 })
190 })
191 .collect()
192}
193
194fn extract_limit(query: &Query) -> Option<usize> {
196 use sqlparser::ast::LimitClause;
197
198 let limit_clause = query.limit_clause.as_ref()?;
199 match limit_clause {
200 LimitClause::LimitOffset { limit, .. } => {
201 let expr = limit.as_ref()?;
202 expr_to_usize(expr)
203 }
204 LimitClause::OffsetCommaLimit { limit, .. } => expr_to_usize(limit),
205 }
206}
207
208fn check_is_windowed(query: &Query) -> bool {
210 if let SetExpr::Select(select) = query.body.as_ref() {
211 use sqlparser::ast::GroupByExpr;
212 match &select.group_by {
213 GroupByExpr::Expressions(exprs, _modifiers) => {
214 exprs.iter().any(is_window_function_call)
215 }
216 GroupByExpr::All(_) => false,
217 }
218 } else {
219 false
220 }
221}
222
223fn detect_row_number_pattern(query: &Query) -> Option<(usize, Vec<String>, RankType)> {
229 if let SetExpr::Select(select) = query.body.as_ref() {
231 for item in &select.projection {
232 if let SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } = item {
233 if let Some((partition_cols, _order_cols, rank_type)) =
234 extract_row_number_info(expr)
235 {
236 if let Some(k) = extract_limit(query) {
238 return Some((k, partition_cols, rank_type));
239 }
240 }
241 }
242 }
243
244 for from in &select.from {
246 if let sqlparser::ast::TableFactor::Derived { subquery, .. } = &from.relation {
247 if let SetExpr::Select(inner_select) = subquery.body.as_ref() {
248 for item in &inner_select.projection {
249 if let SelectItem::ExprWithAlias { expr, alias } = item {
250 if let Some((partition_cols, _order_cols, rank_type)) =
251 extract_row_number_info(expr)
252 {
253 if let Some(k) =
256 extract_rn_filter_limit(select.selection.as_ref(), &alias.value)
257 {
258 return Some((k, partition_cols, rank_type));
259 }
260 }
261 }
262 }
263 }
264 }
265 }
266 }
267 None
268}
269
270fn extract_row_number_info(expr: &Expr) -> Option<(Vec<String>, Vec<String>, RankType)> {
274 if let Expr::Function(func) = expr {
275 let name = func.name.to_string().to_uppercase();
276 let rank_type = match name.as_str() {
277 "ROW_NUMBER" => RankType::RowNumber,
278 "RANK" => RankType::Rank,
279 "DENSE_RANK" => RankType::DenseRank,
280 _ => return None,
281 };
282 if let Some(ref window_spec) = func.over {
283 match window_spec {
284 sqlparser::ast::WindowType::WindowSpec(spec) => {
285 let partition_cols: Vec<String> = spec
286 .partition_by
287 .iter()
288 .filter_map(extract_column_name)
289 .collect();
290 let order_cols: Vec<String> = spec
291 .order_by
292 .iter()
293 .filter_map(|ob| extract_column_name(&ob.expr))
294 .collect();
295 return Some((partition_cols, order_cols, rank_type));
296 }
297 sqlparser::ast::WindowType::NamedWindow(_) => {}
298 }
299 }
300 }
301 None
302}
303
304fn extract_rn_filter_limit(selection: Option<&Expr>, alias: &str) -> Option<usize> {
306 let where_expr = selection?;
307 if let Expr::BinaryOp { left, op, right } = where_expr {
308 use sqlparser::ast::BinaryOperator;
309 match op {
310 BinaryOperator::LtEq => {
311 if extract_column_name(left)? == alias {
313 return expr_to_usize(right);
314 }
315 }
316 BinaryOperator::Lt => {
317 if extract_column_name(left)? == alias {
319 return expr_to_usize(right).map(|n| n.saturating_sub(1));
320 }
321 }
322 _ => {}
323 }
324 }
325 None
326}
327
328fn is_window_function_call(expr: &Expr) -> bool {
330 if let Expr::Function(func) = expr {
331 let name = func.name.to_string().to_uppercase();
332 matches!(name.as_str(), "TUMBLE" | "HOP" | "SESSION")
333 } else {
334 false
335 }
336}
337
338fn extract_column_name(expr: &Expr) -> Option<String> {
340 match expr {
341 Expr::Identifier(ident) => Some(ident.value.clone()),
342 Expr::CompoundIdentifier(parts) => {
343 parts.last().map(|p| p.value.clone())
345 }
346 _ => None,
347 }
348}
349
350fn expr_to_usize(expr: &Expr) -> Option<usize> {
352 match expr {
353 Expr::Value(value_with_span) => match &value_with_span.value {
354 sqlparser::ast::Value::Number(n, _) => n.parse::<usize>().ok(),
355 _ => None,
356 },
357 _ => None,
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::SortColumn;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_stmt(sql: &str) -> Statement {
        let mut stmts = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        stmts.remove(0)
    }

    /// Parses `sql` and runs the analyzer on it.
    fn analyze(sql: &str) -> OrderAnalysis {
        analyze_order_by(&parse_stmt(sql))
    }

    /// Builds the expected `PerGroupTopK` pattern.
    fn per_group(k: usize, partitions: &[&str], rank_type: RankType) -> OrderPattern {
        OrderPattern::PerGroupTopK {
            k,
            partition_columns: partitions.iter().map(|name| (*name).to_string()).collect(),
            rank_type,
        }
    }

    /// Builds a single-column requirement with `NULLS LAST` semantics.
    fn required(column: &str, descending: bool) -> Vec<OrderColumn> {
        vec![OrderColumn {
            column: column.to_string(),
            descending,
            nulls_first: false,
        }]
    }

    // --- Plain ORDER BY extraction ---

    #[test]
    fn test_analyze_simple_order_by() {
        let analysis = analyze("SELECT id, value FROM events ORDER BY id");
        assert_eq!(analysis.order_columns.len(), 1);
        assert_eq!(analysis.order_columns[0].column, "id");
        assert!(!analysis.order_columns[0].descending);
        assert_eq!(analysis.pattern, OrderPattern::Unbounded);
    }

    #[test]
    fn test_analyze_order_by_desc() {
        let analysis = analyze("SELECT * FROM events ORDER BY price DESC");
        assert_eq!(analysis.order_columns.len(), 1);
        assert!(analysis.order_columns[0].descending);
    }

    #[test]
    fn test_analyze_order_by_nulls_first() {
        let analysis = analyze("SELECT * FROM events ORDER BY value ASC NULLS FIRST");
        assert_eq!(analysis.order_columns.len(), 1);
        assert!(!analysis.order_columns[0].descending);
        assert!(analysis.order_columns[0].nulls_first);
    }

    #[test]
    fn test_analyze_order_by_multiple_columns() {
        let analysis =
            analyze("SELECT * FROM events ORDER BY category ASC, price DESC NULLS LAST");
        assert_eq!(analysis.order_columns.len(), 2);

        let (first, second) = (&analysis.order_columns[0], &analysis.order_columns[1]);
        assert_eq!(first.column, "category");
        assert!(!first.descending);
        assert_eq!(second.column, "price");
        assert!(second.descending);
    }

    #[test]
    fn test_analyze_order_by_with_limit() {
        let analysis = analyze("SELECT * FROM events ORDER BY price DESC LIMIT 10");
        assert_eq!(analysis.limit, Some(10));
        assert_eq!(analysis.pattern, OrderPattern::TopK { k: 10 });
    }

    #[test]
    fn test_analyze_order_by_without_limit() {
        let analysis = analyze("SELECT * FROM events ORDER BY id");
        assert!(analysis.limit.is_none());
        assert_eq!(analysis.pattern, OrderPattern::Unbounded);
        assert!(!analysis.is_streaming_safe());
    }

    #[test]
    fn test_analyze_no_order_by() {
        let analysis = analyze("SELECT * FROM events");
        assert_eq!(analysis.pattern, OrderPattern::None);
        assert!(analysis.order_columns.is_empty());
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_analyze_select_star() {
        let analysis = analyze("SELECT * FROM events WHERE id > 5");
        assert_eq!(analysis.pattern, OrderPattern::None);
    }

    // --- ROW_NUMBER detection ---

    #[test]
    fn test_detect_row_number_pattern() {
        let sql = "SELECT * FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY category ORDER BY price DESC) AS rn
            FROM trades
        ) sub WHERE rn <= 5";
        let analysis = analyze(sql);

        assert_eq!(
            analysis.pattern,
            per_group(5, &["category"], RankType::RowNumber)
        );
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_detect_row_number_with_partition() {
        let sql = "SELECT * FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY category ORDER BY price DESC) AS rn
            FROM trades
        ) sub WHERE rn <= 3 ORDER BY category LIMIT 100";
        let analysis = analyze(sql);

        assert_eq!(
            analysis.pattern,
            per_group(3, &["category"], RankType::RowNumber)
        );
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_detect_row_number_without_filter() {
        let sql = "SELECT *, ROW_NUMBER() OVER (ORDER BY price DESC) AS rn FROM trades";
        let analysis = analyze(sql);
        assert_eq!(analysis.pattern, OrderPattern::None);
    }

    #[test]
    fn test_row_number_subquery_no_outer_order() {
        let sql = "SELECT * FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY symbol ORDER BY ts DESC) AS rn
            FROM trades
        ) sub WHERE rn <= 10";
        let analysis = analyze(sql);

        assert_eq!(
            analysis.pattern,
            per_group(10, &["symbol"], RankType::RowNumber)
        );
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_row_number_direct_with_limit() {
        let sql = "SELECT *, ROW_NUMBER() OVER (PARTITION BY cat ORDER BY val DESC) AS rn
            FROM events LIMIT 5";
        let analysis = analyze(sql);

        assert_eq!(
            analysis.pattern,
            per_group(5, &["cat"], RankType::RowNumber)
        );
    }

    // --- RANK / DENSE_RANK detection ---

    #[test]
    fn test_detect_rank_pattern() {
        let sql = "SELECT * FROM (
            SELECT *, RANK() OVER (PARTITION BY category ORDER BY price DESC) AS rn
            FROM trades
        ) sub WHERE rn <= 3";
        let analysis = analyze(sql);

        assert_eq!(
            analysis.pattern,
            per_group(3, &["category"], RankType::Rank)
        );
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_detect_dense_rank_pattern() {
        let sql = "SELECT * FROM (
            SELECT *, DENSE_RANK() OVER (PARTITION BY region ORDER BY revenue DESC) AS rn
            FROM sales
        ) sub WHERE rn <= 5";
        let analysis = analyze(sql);

        assert_eq!(
            analysis.pattern,
            per_group(5, &["region"], RankType::DenseRank)
        );
    }

    #[test]
    fn test_rank_multiple_partition_columns() {
        let sql = "SELECT * FROM (
            SELECT *, RANK() OVER (PARTITION BY region, category ORDER BY sales DESC) AS rn
            FROM revenue
        ) sub WHERE rn <= 3";
        let analysis = analyze(sql);

        let OrderPattern::PerGroupTopK {
            k,
            partition_columns,
            rank_type,
        } = &analysis.pattern
        else {
            panic!("Expected PerGroupTopK, got {:?}", analysis.pattern);
        };
        assert_eq!(*k, 3);
        assert_eq!(
            partition_columns,
            &["region".to_string(), "category".to_string()]
        );
        assert_eq!(*rank_type, RankType::Rank);
    }

    #[test]
    fn test_rank_extracts_order_columns() {
        let sql = "SELECT *, RANK() OVER (PARTITION BY cat ORDER BY price DESC, ts ASC) AS rn
            FROM trades LIMIT 10";
        let analysis = analyze(sql);

        assert!(matches!(
            analysis.pattern,
            OrderPattern::PerGroupTopK {
                rank_type: RankType::Rank,
                ..
            }
        ));
    }

    #[test]
    fn test_rank_pattern_is_streaming_safe() {
        let sql = "SELECT * FROM (
            SELECT *, DENSE_RANK() OVER (PARTITION BY cat ORDER BY val) AS rn
            FROM events
        ) sub WHERE rn <= 5";
        let analysis = analyze(sql);
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_no_ranking_function_none() {
        let sql = "SELECT id, name FROM events WHERE id > 5";
        let analysis = analyze(sql);
        assert_eq!(analysis.pattern, OrderPattern::None);
    }

    // --- Source ordering satisfaction ---

    #[test]
    fn test_order_satisfied_exact_match() {
        let source = vec![SortColumn::ascending("event_time")];
        assert!(is_order_satisfied(&required("event_time", false), &source));
    }

    #[test]
    fn test_order_satisfied_prefix_match() {
        let source = vec![
            SortColumn::ascending("event_time"),
            SortColumn::ascending("id"),
        ];
        assert!(is_order_satisfied(&required("event_time", false), &source));
    }

    #[test]
    fn test_order_not_satisfied_different_direction() {
        let source = vec![SortColumn::ascending("event_time")];
        assert!(!is_order_satisfied(&required("event_time", true), &source));
    }

    #[test]
    fn test_order_not_satisfied_different_columns() {
        let source = vec![SortColumn::ascending("event_time")];
        assert!(!is_order_satisfied(&required("id", false), &source));
    }

    // --- Streaming safety per pattern ---

    #[test]
    fn test_topk_pattern_streaming_safe() {
        let analysis = analyze("SELECT * FROM trades ORDER BY price DESC LIMIT 5");
        assert!(analysis.is_streaming_safe());
        assert_eq!(analysis.pattern, OrderPattern::TopK { k: 5 });
    }

    #[test]
    fn test_unbounded_pattern_not_streaming_safe() {
        let analysis = analyze("SELECT * FROM trades ORDER BY price DESC");
        assert!(!analysis.is_streaming_safe());
        assert_eq!(analysis.pattern, OrderPattern::Unbounded);
    }

    #[test]
    fn test_no_order_by_streaming_safe() {
        let analysis = analyze("SELECT * FROM trades");
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_windowed_order_by() {
        let analysis = analyze(
            "SELECT COUNT(*) FROM events GROUP BY TUMBLE(event_time, INTERVAL '5' MINUTE) ORDER BY event_time",
        );
        assert_eq!(analysis.pattern, OrderPattern::WindowLocal);
        assert!(analysis.is_windowed);
        assert!(analysis.is_streaming_safe());
    }
}