1use sqlparser::ast::{
7 CreateView, Cte, Expr, Query, Select, SelectItem, SetExpr, Spanned, Statement, TableFactor,
8 TableWithJoins,
9};
10
11use crate::analyzer::helpers::{infer_expr_type, line_col_to_offset};
12use crate::types::{AstColumnInfo, AstContext, AstTableInfo, CteInfo, SubqueryInfo};
13
/// An explicit `expr AS alias` item found in a SELECT projection, together with byte offsets
/// into the SQL text it was parsed from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LateralAliasInfo {
    /// The alias as written, without surrounding quotes (e.g. `My Total` for `AS "My Total"`).
    pub name: String,
    /// Byte offset in the SQL text corresponding to the end of the alias identifier's span.
    pub definition_end: usize,
    /// Byte offset where the enclosing SELECT's projection list begins.
    pub projection_start: usize,
    /// Byte offset where the enclosing SELECT's projection list ends (exclusive).
    pub projection_end: usize,
}
35
/// Recursion guard for the AST walkers in this module: once a walker's `depth` argument
/// exceeds this value it returns without visiting the node (presumably to bound work on
/// pathologically nested SQL).
const MAX_EXTRACTION_DEPTH: usize = 50;

/// Cap on how many entries `extract_lateral_aliases` collects; the alias walkers stop
/// adding (and stop descending) once this many aliases have been gathered.
const MAX_LATERAL_ALIASES: usize = 1000;
44
45pub(crate) fn extract_ast_context(statements: &[Statement]) -> AstContext {
52 let mut ctx = AstContext::default();
53
54 for stmt in statements {
55 extract_from_statement(stmt, &mut ctx, 0);
56 }
57
58 ctx
59}
60
61fn extract_from_statement(stmt: &Statement, ctx: &mut AstContext, depth: usize) {
63 if depth > MAX_EXTRACTION_DEPTH {
64 return; }
66
67 match stmt {
68 Statement::Query(query) => {
69 extract_from_query(query, ctx, depth);
70 }
71 Statement::Insert(insert) => {
72 if let Some(source) = &insert.source {
74 extract_from_query(source, ctx, depth);
75 }
76 }
77 Statement::CreateTable(ct) => {
78 if let Some(query) = &ct.query {
80 extract_from_query(query, ctx, depth);
81 }
82 }
83 Statement::CreateView(CreateView { query, .. }) => {
84 extract_from_query(query, ctx, depth);
85 }
86 _ => {}
87 }
88}
89
90fn extract_from_query(query: &Query, ctx: &mut AstContext, depth: usize) {
92 if depth > MAX_EXTRACTION_DEPTH {
93 return;
94 }
95
96 if let Some(with) = &query.with {
98 let is_recursive = with.recursive;
99 for cte in &with.cte_tables {
100 if let Some(info) = extract_cte_info(cte, is_recursive) {
101 ctx.cte_definitions.insert(info.name.clone(), info);
102 }
103 }
104 }
105
106 extract_from_set_expr(&query.body, ctx, depth + 1);
108}
109
110fn extract_from_set_expr(set_expr: &SetExpr, ctx: &mut AstContext, depth: usize) {
112 if depth > MAX_EXTRACTION_DEPTH {
113 return;
114 }
115
116 match set_expr {
117 SetExpr::Select(select) => {
118 extract_from_select(select, ctx, depth);
119 }
120 SetExpr::Query(query) => {
121 extract_from_query(query, ctx, depth);
122 }
123 SetExpr::SetOperation { left, right, .. } => {
124 extract_from_set_expr(left, ctx, depth + 1);
125 extract_from_set_expr(right, ctx, depth + 1);
126 }
127 SetExpr::Values(_) => {}
128 SetExpr::Insert(_) => {}
129 SetExpr::Update(_) => {}
130 SetExpr::Table(_) => {}
131 SetExpr::Delete(_) => {}
132 SetExpr::Merge(_) => {}
133 }
134}
135
136fn extract_from_select(select: &Select, ctx: &mut AstContext, depth: usize) {
138 if depth > MAX_EXTRACTION_DEPTH {
139 return;
140 }
141
142 for table_with_joins in &select.from {
144 extract_from_table_with_joins(table_with_joins, ctx, depth);
145 }
146}
147
148fn extract_from_table_with_joins(twj: &TableWithJoins, ctx: &mut AstContext, depth: usize) {
150 if depth > MAX_EXTRACTION_DEPTH {
151 return;
152 }
153
154 extract_from_table_factor(&twj.relation, ctx, depth);
155
156 for join in &twj.joins {
157 extract_from_table_factor(&join.relation, ctx, depth);
158 }
159}
160
161fn extract_from_table_factor(tf: &TableFactor, ctx: &mut AstContext, depth: usize) {
163 if depth > MAX_EXTRACTION_DEPTH {
164 return;
165 }
166
167 match tf {
168 TableFactor::Table { name, alias, .. } => {
169 let table_name = name.to_string();
170 let alias_name = alias.as_ref().map(|a| a.name.value.clone());
171
172 let key = alias_name.clone().unwrap_or_else(|| {
174 name.0
176 .last()
177 .map(|i| i.to_string())
178 .unwrap_or(table_name.clone())
179 });
180
181 ctx.table_aliases.insert(key, AstTableInfo);
182 }
183 TableFactor::Derived {
184 subquery, alias, ..
185 } => {
186 if let Some(alias) = alias {
188 let columns = extract_projected_columns_from_query(subquery);
189 ctx.subquery_aliases.insert(
190 alias.name.value.clone(),
191 SubqueryInfo {
192 projected_columns: columns,
193 },
194 );
195 }
196
197 extract_from_query(subquery, ctx, depth + 1);
199 }
200 TableFactor::NestedJoin {
201 table_with_joins, ..
202 } => {
203 extract_from_table_with_joins(table_with_joins, ctx, depth + 1);
204 }
205 TableFactor::TableFunction { .. } => {}
206 TableFactor::UNNEST {
207 alias: Some(alias), ..
208 } => {
209 ctx.table_aliases
210 .insert(alias.name.value.clone(), AstTableInfo);
211 }
212 _ => {}
213 }
214}
215
216fn extract_cte_info(cte: &Cte, is_recursive: bool) -> Option<CteInfo> {
218 let name = cte.alias.name.value.clone();
219
220 let declared_columns: Vec<String> = cte
222 .alias
223 .columns
224 .iter()
225 .map(|c| c.name.value.clone())
226 .collect();
227
228 let projected_columns = if is_recursive {
230 extract_base_case_columns(&cte.query)
232 } else {
233 extract_projected_columns_from_query(&cte.query)
234 };
235
236 Some(CteInfo {
237 name,
238 declared_columns,
239 projected_columns,
240 })
241}
242
243fn extract_base_case_columns(query: &Query) -> Vec<AstColumnInfo> {
245 match &*query.body {
246 SetExpr::SetOperation { left, .. } => {
247 if let SetExpr::Select(select) = &**left {
249 extract_select_columns(select)
250 } else {
251 vec![]
252 }
253 }
254 SetExpr::Select(select) => extract_select_columns(select),
255 _ => vec![],
256 }
257}
258
259fn extract_projected_columns_from_query(query: &Query) -> Vec<AstColumnInfo> {
261 match &*query.body {
262 SetExpr::Select(select) => extract_select_columns(select),
263 SetExpr::SetOperation { left, .. } => {
264 if let SetExpr::Select(select) = &**left {
266 extract_select_columns(select)
267 } else {
268 vec![]
269 }
270 }
271 _ => vec![],
272 }
273}
274
275fn extract_select_columns(select: &Select) -> Vec<AstColumnInfo> {
277 let mut columns = Vec::new();
278
279 for (idx, item) in select.projection.iter().enumerate() {
280 match item {
281 SelectItem::ExprWithAlias { alias, expr } => {
282 columns.push(AstColumnInfo {
283 name: alias.value.clone(),
284 data_type: infer_data_type(expr),
285 });
286 }
287 SelectItem::UnnamedExpr(expr) => {
288 columns.push(AstColumnInfo {
289 name: derive_column_name(expr, idx),
290 data_type: infer_data_type(expr),
291 });
292 }
293 SelectItem::Wildcard(_) => {
294 columns.push(AstColumnInfo {
295 name: "*".to_string(),
296 data_type: None,
297 });
298 }
299 SelectItem::QualifiedWildcard(name, _) => {
300 columns.push(AstColumnInfo {
301 name: format!("{}.*", name),
302 data_type: None,
303 });
304 }
305 }
306 }
307
308 columns
309}
310
311fn derive_column_name(expr: &Expr, index: usize) -> String {
313 match expr {
314 Expr::Identifier(ident) => ident.value.clone(),
315 Expr::CompoundIdentifier(parts) => parts
316 .last()
317 .map(|i| i.value.clone())
318 .unwrap_or_else(|| format!("col_{}", index)),
319 Expr::Function(func) => func.name.to_string().to_lowercase(),
320 Expr::Cast { .. } => format!("col_{}", index),
321 Expr::Case { .. } => format!("case_{}", index),
322 Expr::Subquery(_) => format!("subquery_{}", index),
323 _ => format!("col_{}", index),
324 }
325}
326
327fn infer_data_type(expr: &Expr) -> Option<String> {
332 infer_expr_type(expr).map(|canonical| canonical.as_uppercase_str().to_string())
333}
334
335pub(crate) fn extract_lateral_aliases(
351 statements: &[Statement],
352 sql: &str,
353) -> Vec<LateralAliasInfo> {
354 let mut aliases = Vec::with_capacity(64); for stmt in statements {
357 if aliases.len() >= MAX_LATERAL_ALIASES {
359 break;
360 }
361
362 if let Statement::Query(query) = stmt {
363 if let Some(with) = &query.with {
365 for cte in &with.cte_tables {
366 if aliases.len() >= MAX_LATERAL_ALIASES {
367 break;
368 }
369 extract_lateral_aliases_from_set_expr(&cte.query.body, sql, &mut aliases, 0);
370 }
371 }
372 if aliases.len() < MAX_LATERAL_ALIASES {
374 extract_lateral_aliases_from_set_expr(&query.body, sql, &mut aliases, 0);
375 }
376 }
377 }
378
379 aliases
380}
381
382fn extract_lateral_aliases_from_set_expr(
388 set_expr: &SetExpr,
389 sql: &str,
390 aliases: &mut Vec<LateralAliasInfo>,
391 depth: usize,
392) {
393 if depth > MAX_EXTRACTION_DEPTH || aliases.len() >= MAX_LATERAL_ALIASES {
394 return;
395 }
396
397 match set_expr {
398 SetExpr::Select(select) => {
399 extract_lateral_aliases_from_select(select, sql, aliases);
400 }
401 SetExpr::Query(query) => {
402 if let Some(with) = &query.with {
404 for cte in &with.cte_tables {
405 if aliases.len() >= MAX_LATERAL_ALIASES {
406 break;
407 }
408 extract_lateral_aliases_from_set_expr(&cte.query.body, sql, aliases, depth + 1);
409 }
410 }
411 if aliases.len() < MAX_LATERAL_ALIASES {
412 extract_lateral_aliases_from_set_expr(&query.body, sql, aliases, depth + 1);
413 }
414 }
415 SetExpr::SetOperation { left, right, .. } => {
416 extract_lateral_aliases_from_set_expr(left, sql, aliases, depth + 1);
417 if aliases.len() < MAX_LATERAL_ALIASES {
418 extract_lateral_aliases_from_set_expr(right, sql, aliases, depth + 1);
419 }
420 }
421 _ => {}
422 }
423}
424
425fn extract_lateral_aliases_from_select(
439 select: &Select,
440 sql: &str,
441 aliases: &mut Vec<LateralAliasInfo>,
442) {
443 if aliases.len() >= MAX_LATERAL_ALIASES {
445 return;
446 }
447
448 let projection_span = compute_projection_span(select, sql);
451 let (projection_start, projection_end) = match projection_span {
452 Some((start, end)) => (start, end),
453 None => return, };
455
456 for item in &select.projection {
457 if aliases.len() >= MAX_LATERAL_ALIASES {
459 break;
460 }
461
462 if let SelectItem::ExprWithAlias { alias, .. } = item {
463 if let Some(end_offset) = line_col_to_offset(
466 sql,
467 alias.span.end.line as usize,
468 alias.span.end.column as usize,
469 ) {
470 if end_offset <= sql.len() && sql.is_char_boundary(end_offset) {
473 aliases.push(LateralAliasInfo {
474 name: alias.value.clone(),
475 definition_end: end_offset,
476 projection_start,
477 projection_end,
478 });
479 }
480 }
481 }
482 }
483}
484
485fn compute_projection_span(select: &Select, sql: &str) -> Option<(usize, usize)> {
496 if select.projection.is_empty() {
497 return None;
498 }
499
500 let first_span = select
505 .projection
506 .iter()
507 .filter_map(select_item_span)
508 .next()
509 .or_else(|| {
510 let span = select.span();
511 if span.start.line > 0 && span.start.column > 0 {
512 Some((span.start.line, span.start.column))
513 } else {
514 None
515 }
516 })?;
517 let start = line_col_to_offset(sql, first_span.0 as usize, first_span.1 as usize)?;
518
519 let end = if let Some(from_item) = select.from.first() {
522 compute_from_clause_start(from_item, sql).unwrap_or_else(|| {
524 select
526 .projection
527 .last()
528 .and_then(|item| {
529 let span = select_item_end_span(item)?;
530 line_col_to_offset(sql, span.0 as usize, span.1 as usize)
531 })
532 .unwrap_or(sql.len())
533 })
534 } else {
535 sql.len()
538 };
539
540 if start <= sql.len() && end <= sql.len() && start <= end {
542 Some((start, end))
543 } else {
544 None
545 }
546}
547
548fn compute_from_clause_start(from_item: &TableWithJoins, sql: &str) -> Option<usize> {
550 let span = table_factor_span(&from_item.relation)?;
552 let table_start = line_col_to_offset(sql, span.0 as usize, span.1 as usize)?;
553
554 let search_start = find_char_boundary_before(sql, table_start.saturating_sub(50));
560 let search_area = &sql[search_start..table_start];
561
562 if let Some(pos) = rfind_ascii_case_insensitive(search_area, b"FROM") {
566 Some(search_start + pos)
567 } else {
568 Some(table_start)
570 }
571}
572
/// Returns the largest UTF-8 char boundary in `s` that is `<= pos`.
///
/// Positions at or past the end of `s` clamp to `s.len()`. Offset 0 is always a boundary,
/// so the walk-back terminates.
fn find_char_boundary_before(s: &str, pos: usize) -> usize {
    if pos >= s.len() {
        return s.len();
    }

    let mut candidate = pos;
    while candidate > 0 && !s.is_char_boundary(candidate) {
        candidate -= 1;
    }
    candidate
}
585
/// Returns the byte index of the last ASCII-case-insensitive occurrence of `needle` in
/// `haystack`.
///
/// Returns `None` for an empty needle or when there is no match.
fn rfind_ascii_case_insensitive(haystack: &str, needle: &[u8]) -> Option<usize> {
    // `windows(0)` would panic, so an empty needle is rejected up front.
    if needle.is_empty() {
        return None;
    }

    haystack
        .as_bytes()
        .windows(needle.len())
        .rposition(|window| window.eq_ignore_ascii_case(needle))
}
615
616fn table_factor_span(tf: &TableFactor) -> Option<(u64, u64)> {
618 match tf {
619 TableFactor::Table { name, .. } => name.0.first().map(|i| {
620 let span = i.span();
621 (span.start.line, span.start.column)
622 }),
623 TableFactor::Derived { subquery, .. } => {
624 let span = subquery.body.span();
626 if span.start.line > 0 {
627 Some((span.start.line, span.start.column))
628 } else {
629 None
630 }
631 }
632 _ => None,
633 }
634}
635
636fn select_item_span(item: &SelectItem) -> Option<(u64, u64)> {
638 match item {
639 SelectItem::ExprWithAlias { expr, .. } | SelectItem::UnnamedExpr(expr) => {
640 expr_start_span(expr)
641 }
642 SelectItem::Wildcard(opts) => {
643 if let Some(exclude) = &opts.opt_exclude {
645 match exclude {
647 sqlparser::ast::ExcludeSelectItem::Single(ident) => {
648 Some((ident.span.start.line, ident.span.start.column))
649 }
650 sqlparser::ast::ExcludeSelectItem::Multiple(idents) => idents
651 .first()
652 .map(|i| (i.span.start.line, i.span.start.column)),
653 }
654 } else {
655 None
656 }
657 }
658 SelectItem::QualifiedWildcard(name, _) => {
659 let span = name.span();
660 Some((span.start.line, span.start.column))
661 }
662 }
663}
664
665fn select_item_end_span(item: &SelectItem) -> Option<(u64, u64)> {
667 match item {
668 SelectItem::ExprWithAlias { alias, .. } => {
669 Some((alias.span.end.line, alias.span.end.column))
670 }
671 SelectItem::UnnamedExpr(expr) => expr_end_span(expr),
672 SelectItem::Wildcard(_) => None, SelectItem::QualifiedWildcard(name, _) => {
674 let span = name.span();
675 Some((span.end.line, span.end.column))
676 }
677 }
678}
679
680fn expr_start_span(expr: &Expr) -> Option<(u64, u64)> {
683 let span = expr.span();
684 if span.start.line > 0 && span.start.column > 0 {
686 Some((span.start.line, span.start.column))
687 } else {
688 None
689 }
690}
691
692fn expr_end_span(expr: &Expr) -> Option<(u64, u64)> {
695 let span = expr.span();
696 if span.end.line > 0 && span.end.column > 0 {
698 Some((span.end.line, span.end.column))
699 } else {
700 None
701 }
702}
703
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect, panicking on syntax errors.
    fn parse_sql(sql: &str) -> Vec<Statement> {
        Parser::parse_sql(&sqlparser::dialect::GenericDialect {}, sql).unwrap()
    }

    #[test]
    fn test_extract_cte() {
        let sql = "WITH cte AS (SELECT id, name FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        assert!(ctx.cte_definitions.contains_key("cte"));
        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.name, "cte");
        assert_eq!(cte.projected_columns.len(), 2);
        assert_eq!(cte.projected_columns[0].name, "id");
        assert_eq!(cte.projected_columns[1].name, "name");
    }

    #[test]
    fn test_extract_cte_with_declared_columns() {
        let sql = "WITH cte(a, b) AS (SELECT id, name FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.declared_columns, vec!["a", "b"]);
    }

    #[test]
    fn test_extract_table_alias() {
        let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        assert!(ctx.table_aliases.contains_key("u"));
        assert!(ctx.table_aliases.contains_key("o"));
    }

    #[test]
    fn test_extract_subquery_alias() {
        let sql = "SELECT * FROM (SELECT a, b FROM t) AS sub WHERE sub.a = 1";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        assert!(ctx.subquery_aliases.contains_key("sub"));
        let sub = &ctx.subquery_aliases["sub"];
        assert_eq!(sub.projected_columns.len(), 2);
        assert_eq!(sub.projected_columns[0].name, "a");
        assert_eq!(sub.projected_columns[1].name, "b");
    }

    #[test]
    fn test_extract_lateral_subquery() {
        let sql = "SELECT * FROM users u, LATERAL (SELECT * FROM orders WHERE user_id = u.id) AS o";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        assert!(ctx.subquery_aliases.contains_key("o"));
    }

    #[test]
    fn test_extract_column_with_alias() {
        let sql =
            "WITH cte AS (SELECT id AS user_id, name AS user_name FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.projected_columns[0].name, "user_id");
        assert_eq!(cte.projected_columns[1].name, "user_name");
    }

    #[test]
    fn test_extract_function_column_name() {
        let sql = "WITH cte AS (SELECT COUNT(*), SUM(amount) FROM orders) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        let cte = &ctx.cte_definitions["cte"];
        assert!(cte.projected_columns[0]
            .name
            .to_lowercase()
            .contains("count"));
    }

    #[test]
    fn test_extract_wildcard() {
        let sql = "WITH cte AS (SELECT * FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.projected_columns[0].name, "*");
    }

    #[test]
    fn test_extract_recursive_cte() {
        let sql = r#"
            WITH RECURSIVE cte AS (
                SELECT 1 AS n
                UNION ALL
                SELECT n + 1 FROM cte WHERE n < 10
            )
            SELECT * FROM cte
        "#;
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.projected_columns.len(), 1);
        assert_eq!(cte.projected_columns[0].name, "n");
    }

    #[test]
    fn test_has_enrichment() {
        let sql = "SELECT * FROM users";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        assert!(ctx.has_enrichment());
    }

    #[test]
    fn test_empty_context() {
        let ctx = AstContext::default();
        assert!(!ctx.has_enrichment());
    }

    #[test]
    fn test_extract_lateral_aliases_single() {
        let sql = "SELECT price * qty AS total FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 1);
        assert_eq!(aliases[0].name, "total");
        assert!(aliases[0].definition_end > 0);
        assert!(aliases[0].definition_end <= sql.len());
    }

    #[test]
    fn test_extract_lateral_aliases_with_leading_wildcard() {
        let sql = "SELECT *, price * qty AS total, discount AS disc FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        let names: Vec<_> = aliases.iter().map(|a| a.name.as_str()).collect();
        assert_eq!(names, vec!["total", "disc"]);
    }

    #[test]
    fn test_extract_lateral_aliases_multiple() {
        let sql = "SELECT a AS x, b AS y, c AS z FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 3);
        assert_eq!(aliases[0].name, "x");
        assert_eq!(aliases[1].name, "y");
        assert_eq!(aliases[2].name, "z");
        assert!(aliases[0].definition_end < aliases[1].definition_end);
        assert!(aliases[1].definition_end < aliases[2].definition_end);
    }

    #[test]
    fn test_extract_lateral_aliases_with_expression() {
        let sql = "SELECT price * qty AS total, total * 0.1 AS tax FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 2);
        assert_eq!(aliases[0].name, "total");
        assert_eq!(aliases[1].name, "tax");
    }

    #[test]
    fn test_extract_lateral_aliases_no_aliases() {
        let sql = "SELECT price, qty FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert!(aliases.is_empty());
    }

    #[test]
    fn test_extract_lateral_aliases_mixed() {
        let sql = "SELECT a, b AS alias_b, c FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 1);
        assert_eq!(aliases[0].name, "alias_b");
    }

    #[test]
    fn test_extract_lateral_aliases_quoted() {
        let sql = r#"SELECT a AS "My Total", b AS "Tax Amount" FROM t"#;
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 2);
        assert_eq!(aliases[0].name, "My Total");
        assert_eq!(aliases[1].name, "Tax Amount");
    }

    #[test]
    fn test_extract_lateral_aliases_subquery_in_from() {
        // Derived tables in FROM are not descended into.
        let sql = "SELECT * FROM (SELECT a AS x, b AS y FROM t) sub";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 0);
    }

    #[test]
    fn test_extract_lateral_aliases_outer_select_with_alias() {
        let sql = "SELECT sub.x AS outer_x FROM (SELECT a AS x FROM t) sub";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 1);
        assert_eq!(aliases[0].name, "outer_x");
    }

    #[test]
    fn test_extract_lateral_aliases_with_unicode() {
        let sql = "SELECT '日本語' AS label, value AS val FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 2);
        assert_eq!(aliases[0].name, "label");
        assert_eq!(aliases[1].name, "val");
    }

    #[test]
    fn test_extract_lateral_aliases_cte_scope_isolation() {
        let sql =
            "WITH cte AS (SELECT a AS inner_alias FROM t) SELECT cte.a AS outer_alias FROM cte";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 2);

        let inner = aliases.iter().find(|a| a.name == "inner_alias").unwrap();
        let outer = aliases.iter().find(|a| a.name == "outer_alias").unwrap();

        assert!(
            inner.projection_start < outer.projection_start,
            "CTE projection should start before outer SELECT projection"
        );

        // The CTE's projection must be fully disjoint from (and precede) the outer one;
        // the previous disjunctive form was implied by the assertion above.
        assert!(
            inner.projection_end <= outer.projection_start,
            "CTE projection should end before the outer SELECT projection starts"
        );
    }

    #[test]
    fn test_extract_lateral_aliases_projection_span_validity() {
        let sql = "SELECT a AS x, b AS y FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 2);

        for alias in &aliases {
            assert!(
                alias.definition_end <= alias.projection_end,
                "Alias definition should be within projection span"
            );
            assert!(
                alias.projection_start < alias.definition_end,
                "Projection should start before alias definition ends"
            );
        }
    }

    #[test]
    fn test_rfind_ascii_case_insensitive() {
        assert_eq!(rfind_ascii_case_insensitive("select a from t", b"FROM"), Some(9));
        assert_eq!(
            rfind_ascii_case_insensitive("SELECT a FROM b FROM c", b"from"),
            Some(16)
        );
        assert_eq!(rfind_ascii_case_insensitive("fro", b"FROM"), None);
        assert_eq!(rfind_ascii_case_insensitive("from", b""), None);
    }

    #[test]
    fn test_find_char_boundary_before() {
        let s = "aé";
        assert_eq!(find_char_boundary_before(s, 2), 1);
        assert_eq!(find_char_boundary_before(s, 1), 1);
        assert_eq!(find_char_boundary_before(s, 10), s.len());
    }
}