1use sqlparser::ast::{
7 Cte, Expr, Query, Select, SelectItem, SetExpr, Spanned, Statement, TableFactor, TableWithJoins,
8};
9
10use crate::analyzer::helpers::{infer_expr_type, line_col_to_offset};
11use crate::types::{AstColumnInfo, AstContext, AstTableInfo, CteInfo, SubqueryInfo};
12
/// A column alias declared in a `SELECT` projection (`expr AS alias`),
/// recorded for lateral column alias resolution.
///
/// All offsets are byte offsets into the SQL text that was passed to
/// `extract_lateral_aliases`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LateralAliasInfo {
    /// The alias name as parsed (the identifier value, without quotes).
    pub name: String,
    /// Byte offset of the end position of the alias identifier.
    pub definition_end: usize,
    /// Byte offset where the enclosing projection list begins.
    pub projection_start: usize,
    /// Byte offset where the enclosing projection list ends (normally the
    /// start of the `FROM` keyword).
    pub projection_end: usize,
}
34
/// Recursion limit for the AST walkers in this module. Nodes nested more
/// deeply than this are silently skipped rather than visited.
const MAX_EXTRACTION_DEPTH: usize = 50;

/// Maximum number of lateral aliases collected by `extract_lateral_aliases`
/// across all statements of one input.
const MAX_LATERAL_ALIASES: usize = 1_000;
43
44pub(crate) fn extract_ast_context(statements: &[Statement]) -> AstContext {
51 let mut ctx = AstContext::default();
52
53 for stmt in statements {
54 extract_from_statement(stmt, &mut ctx, 0);
55 }
56
57 ctx
58}
59
60fn extract_from_statement(stmt: &Statement, ctx: &mut AstContext, depth: usize) {
62 if depth > MAX_EXTRACTION_DEPTH {
63 return; }
65
66 match stmt {
67 Statement::Query(query) => {
68 extract_from_query(query, ctx, depth);
69 }
70 Statement::Insert(insert) => {
71 if let Some(source) = &insert.source {
73 extract_from_query(source, ctx, depth);
74 }
75 }
76 Statement::CreateTable(ct) => {
77 if let Some(query) = &ct.query {
79 extract_from_query(query, ctx, depth);
80 }
81 }
82 Statement::CreateView { query, .. } => {
83 extract_from_query(query, ctx, depth);
84 }
85 _ => {}
86 }
87}
88
89fn extract_from_query(query: &Query, ctx: &mut AstContext, depth: usize) {
91 if depth > MAX_EXTRACTION_DEPTH {
92 return;
93 }
94
95 if let Some(with) = &query.with {
97 let is_recursive = with.recursive;
98 for cte in &with.cte_tables {
99 if let Some(info) = extract_cte_info(cte, is_recursive) {
100 ctx.cte_definitions.insert(info.name.clone(), info);
101 }
102 }
103 }
104
105 extract_from_set_expr(&query.body, ctx, depth + 1);
107}
108
109fn extract_from_set_expr(set_expr: &SetExpr, ctx: &mut AstContext, depth: usize) {
111 if depth > MAX_EXTRACTION_DEPTH {
112 return;
113 }
114
115 match set_expr {
116 SetExpr::Select(select) => {
117 extract_from_select(select, ctx, depth);
118 }
119 SetExpr::Query(query) => {
120 extract_from_query(query, ctx, depth);
121 }
122 SetExpr::SetOperation { left, right, .. } => {
123 extract_from_set_expr(left, ctx, depth + 1);
124 extract_from_set_expr(right, ctx, depth + 1);
125 }
126 SetExpr::Values(_) => {}
127 SetExpr::Insert(_) => {}
128 SetExpr::Update(_) => {}
129 SetExpr::Table(_) => {}
130 SetExpr::Delete(_) => {}
131 SetExpr::Merge(_) => {}
132 }
133}
134
135fn extract_from_select(select: &Select, ctx: &mut AstContext, depth: usize) {
137 if depth > MAX_EXTRACTION_DEPTH {
138 return;
139 }
140
141 for table_with_joins in &select.from {
143 extract_from_table_with_joins(table_with_joins, ctx, depth);
144 }
145}
146
147fn extract_from_table_with_joins(twj: &TableWithJoins, ctx: &mut AstContext, depth: usize) {
149 if depth > MAX_EXTRACTION_DEPTH {
150 return;
151 }
152
153 extract_from_table_factor(&twj.relation, ctx, depth);
154
155 for join in &twj.joins {
156 extract_from_table_factor(&join.relation, ctx, depth);
157 }
158}
159
160fn extract_from_table_factor(tf: &TableFactor, ctx: &mut AstContext, depth: usize) {
162 if depth > MAX_EXTRACTION_DEPTH {
163 return;
164 }
165
166 match tf {
167 TableFactor::Table { name, alias, .. } => {
168 let table_name = name.to_string();
169 let alias_name = alias.as_ref().map(|a| a.name.value.clone());
170
171 let key = alias_name.clone().unwrap_or_else(|| {
173 name.0
175 .last()
176 .map(|i| i.to_string())
177 .unwrap_or(table_name.clone())
178 });
179
180 ctx.table_aliases.insert(key, AstTableInfo);
181 }
182 TableFactor::Derived {
183 subquery, alias, ..
184 } => {
185 if let Some(alias) = alias {
187 let columns = extract_projected_columns_from_query(subquery);
188 ctx.subquery_aliases.insert(
189 alias.name.value.clone(),
190 SubqueryInfo {
191 projected_columns: columns,
192 },
193 );
194 }
195
196 extract_from_query(subquery, ctx, depth + 1);
198 }
199 TableFactor::NestedJoin {
200 table_with_joins, ..
201 } => {
202 extract_from_table_with_joins(table_with_joins, ctx, depth + 1);
203 }
204 TableFactor::TableFunction { .. } => {}
205 TableFactor::UNNEST {
206 alias: Some(alias), ..
207 } => {
208 ctx.table_aliases
209 .insert(alias.name.value.clone(), AstTableInfo);
210 }
211 _ => {}
212 }
213}
214
215fn extract_cte_info(cte: &Cte, is_recursive: bool) -> Option<CteInfo> {
217 let name = cte.alias.name.value.clone();
218
219 let declared_columns: Vec<String> = cte
221 .alias
222 .columns
223 .iter()
224 .map(|c| c.name.value.clone())
225 .collect();
226
227 let projected_columns = if is_recursive {
229 extract_base_case_columns(&cte.query)
231 } else {
232 extract_projected_columns_from_query(&cte.query)
233 };
234
235 Some(CteInfo {
236 name,
237 declared_columns,
238 projected_columns,
239 })
240}
241
242fn extract_base_case_columns(query: &Query) -> Vec<AstColumnInfo> {
244 match &*query.body {
245 SetExpr::SetOperation { left, .. } => {
246 if let SetExpr::Select(select) = &**left {
248 extract_select_columns(select)
249 } else {
250 vec![]
251 }
252 }
253 SetExpr::Select(select) => extract_select_columns(select),
254 _ => vec![],
255 }
256}
257
258fn extract_projected_columns_from_query(query: &Query) -> Vec<AstColumnInfo> {
260 match &*query.body {
261 SetExpr::Select(select) => extract_select_columns(select),
262 SetExpr::SetOperation { left, .. } => {
263 if let SetExpr::Select(select) = &**left {
265 extract_select_columns(select)
266 } else {
267 vec![]
268 }
269 }
270 _ => vec![],
271 }
272}
273
274fn extract_select_columns(select: &Select) -> Vec<AstColumnInfo> {
276 let mut columns = Vec::new();
277
278 for (idx, item) in select.projection.iter().enumerate() {
279 match item {
280 SelectItem::ExprWithAlias { alias, expr } => {
281 columns.push(AstColumnInfo {
282 name: alias.value.clone(),
283 data_type: infer_data_type(expr),
284 });
285 }
286 SelectItem::UnnamedExpr(expr) => {
287 columns.push(AstColumnInfo {
288 name: derive_column_name(expr, idx),
289 data_type: infer_data_type(expr),
290 });
291 }
292 SelectItem::Wildcard(_) => {
293 columns.push(AstColumnInfo {
294 name: "*".to_string(),
295 data_type: None,
296 });
297 }
298 SelectItem::QualifiedWildcard(name, _) => {
299 columns.push(AstColumnInfo {
300 name: format!("{}.*", name),
301 data_type: None,
302 });
303 }
304 }
305 }
306
307 columns
308}
309
310fn derive_column_name(expr: &Expr, index: usize) -> String {
312 match expr {
313 Expr::Identifier(ident) => ident.value.clone(),
314 Expr::CompoundIdentifier(parts) => parts
315 .last()
316 .map(|i| i.value.clone())
317 .unwrap_or_else(|| format!("col_{}", index)),
318 Expr::Function(func) => func.name.to_string().to_lowercase(),
319 Expr::Cast { .. } => format!("col_{}", index),
320 Expr::Case { .. } => format!("case_{}", index),
321 Expr::Subquery(_) => format!("subquery_{}", index),
322 _ => format!("col_{}", index),
323 }
324}
325
326fn infer_data_type(expr: &Expr) -> Option<String> {
331 infer_expr_type(expr).map(|canonical| canonical.as_uppercase_str().to_string())
332}
333
334pub(crate) fn extract_lateral_aliases(
350 statements: &[Statement],
351 sql: &str,
352) -> Vec<LateralAliasInfo> {
353 let mut aliases = Vec::with_capacity(64); for stmt in statements {
356 if aliases.len() >= MAX_LATERAL_ALIASES {
358 break;
359 }
360
361 if let Statement::Query(query) = stmt {
362 if let Some(with) = &query.with {
364 for cte in &with.cte_tables {
365 if aliases.len() >= MAX_LATERAL_ALIASES {
366 break;
367 }
368 extract_lateral_aliases_from_set_expr(&cte.query.body, sql, &mut aliases, 0);
369 }
370 }
371 if aliases.len() < MAX_LATERAL_ALIASES {
373 extract_lateral_aliases_from_set_expr(&query.body, sql, &mut aliases, 0);
374 }
375 }
376 }
377
378 aliases
379}
380
381fn extract_lateral_aliases_from_set_expr(
387 set_expr: &SetExpr,
388 sql: &str,
389 aliases: &mut Vec<LateralAliasInfo>,
390 depth: usize,
391) {
392 if depth > MAX_EXTRACTION_DEPTH || aliases.len() >= MAX_LATERAL_ALIASES {
393 return;
394 }
395
396 match set_expr {
397 SetExpr::Select(select) => {
398 extract_lateral_aliases_from_select(select, sql, aliases);
399 }
400 SetExpr::Query(query) => {
401 if let Some(with) = &query.with {
403 for cte in &with.cte_tables {
404 if aliases.len() >= MAX_LATERAL_ALIASES {
405 break;
406 }
407 extract_lateral_aliases_from_set_expr(&cte.query.body, sql, aliases, depth + 1);
408 }
409 }
410 if aliases.len() < MAX_LATERAL_ALIASES {
411 extract_lateral_aliases_from_set_expr(&query.body, sql, aliases, depth + 1);
412 }
413 }
414 SetExpr::SetOperation { left, right, .. } => {
415 extract_lateral_aliases_from_set_expr(left, sql, aliases, depth + 1);
416 if aliases.len() < MAX_LATERAL_ALIASES {
417 extract_lateral_aliases_from_set_expr(right, sql, aliases, depth + 1);
418 }
419 }
420 _ => {}
421 }
422}
423
424fn extract_lateral_aliases_from_select(
438 select: &Select,
439 sql: &str,
440 aliases: &mut Vec<LateralAliasInfo>,
441) {
442 if aliases.len() >= MAX_LATERAL_ALIASES {
444 return;
445 }
446
447 let projection_span = compute_projection_span(select, sql);
450 let (projection_start, projection_end) = match projection_span {
451 Some((start, end)) => (start, end),
452 None => return, };
454
455 for item in &select.projection {
456 if aliases.len() >= MAX_LATERAL_ALIASES {
458 break;
459 }
460
461 if let SelectItem::ExprWithAlias { alias, .. } = item {
462 if let Some(end_offset) = line_col_to_offset(
465 sql,
466 alias.span.end.line as usize,
467 alias.span.end.column as usize,
468 ) {
469 if end_offset <= sql.len() && sql.is_char_boundary(end_offset) {
472 aliases.push(LateralAliasInfo {
473 name: alias.value.clone(),
474 definition_end: end_offset,
475 projection_start,
476 projection_end,
477 });
478 }
479 }
480 }
481 }
482}
483
484fn compute_projection_span(select: &Select, sql: &str) -> Option<(usize, usize)> {
495 if select.projection.is_empty() {
496 return None;
497 }
498
499 let first_span = select
504 .projection
505 .iter()
506 .filter_map(select_item_span)
507 .next()
508 .or_else(|| {
509 let span = select.span();
510 if span.start.line > 0 && span.start.column > 0 {
511 Some((span.start.line, span.start.column))
512 } else {
513 None
514 }
515 })?;
516 let start = line_col_to_offset(sql, first_span.0 as usize, first_span.1 as usize)?;
517
518 let end = if let Some(from_item) = select.from.first() {
521 compute_from_clause_start(from_item, sql).unwrap_or_else(|| {
523 select
525 .projection
526 .last()
527 .and_then(|item| {
528 let span = select_item_end_span(item)?;
529 line_col_to_offset(sql, span.0 as usize, span.1 as usize)
530 })
531 .unwrap_or(sql.len())
532 })
533 } else {
534 sql.len()
537 };
538
539 if start <= sql.len() && end <= sql.len() && start <= end {
541 Some((start, end))
542 } else {
543 None
544 }
545}
546
547fn compute_from_clause_start(from_item: &TableWithJoins, sql: &str) -> Option<usize> {
549 let span = table_factor_span(&from_item.relation)?;
551 let table_start = line_col_to_offset(sql, span.0 as usize, span.1 as usize)?;
552
553 let search_start = find_char_boundary_before(sql, table_start.saturating_sub(50));
559 let search_area = &sql[search_start..table_start];
560
561 if let Some(pos) = rfind_ascii_case_insensitive(search_area, b"FROM") {
565 Some(search_start + pos)
566 } else {
567 Some(table_start)
569 }
570}
571
/// Returns the largest UTF-8 char boundary of `s` that is `<= pos`.
///
/// Positions at or past the end are clamped to `s.len()`.
fn find_char_boundary_before(s: &str, pos: usize) -> usize {
    let mut boundary = pos.min(s.len());
    // Offset 0 is always a boundary, so this loop terminates without underflow.
    while !s.is_char_boundary(boundary) {
        boundary -= 1;
    }
    boundary
}
584
/// Returns the byte offset of the last occurrence of `needle` in `haystack`,
/// comparing ASCII letters case-insensitively.
///
/// An empty needle never matches.
fn rfind_ascii_case_insensitive(haystack: &str, needle: &[u8]) -> Option<usize> {
    // `windows(0)` would panic, and an empty needle is defined as "no match".
    if needle.is_empty() {
        return None;
    }

    haystack
        .as_bytes()
        .windows(needle.len())
        .rposition(|window| window.eq_ignore_ascii_case(needle))
}
614
615fn table_factor_span(tf: &TableFactor) -> Option<(u64, u64)> {
617 match tf {
618 TableFactor::Table { name, .. } => name.0.first().map(|i| {
619 let span = i.span();
620 (span.start.line, span.start.column)
621 }),
622 TableFactor::Derived { subquery, .. } => {
623 let span = subquery.body.span();
625 if span.start.line > 0 {
626 Some((span.start.line, span.start.column))
627 } else {
628 None
629 }
630 }
631 _ => None,
632 }
633}
634
635fn select_item_span(item: &SelectItem) -> Option<(u64, u64)> {
637 match item {
638 SelectItem::ExprWithAlias { expr, .. } | SelectItem::UnnamedExpr(expr) => {
639 expr_start_span(expr)
640 }
641 SelectItem::Wildcard(opts) => {
642 if let Some(exclude) = &opts.opt_exclude {
644 match exclude {
646 sqlparser::ast::ExcludeSelectItem::Single(ident) => {
647 Some((ident.span.start.line, ident.span.start.column))
648 }
649 sqlparser::ast::ExcludeSelectItem::Multiple(idents) => idents
650 .first()
651 .map(|i| (i.span.start.line, i.span.start.column)),
652 }
653 } else {
654 None
655 }
656 }
657 SelectItem::QualifiedWildcard(name, _) => {
658 let span = name.span();
659 Some((span.start.line, span.start.column))
660 }
661 }
662}
663
664fn select_item_end_span(item: &SelectItem) -> Option<(u64, u64)> {
666 match item {
667 SelectItem::ExprWithAlias { alias, .. } => {
668 Some((alias.span.end.line, alias.span.end.column))
669 }
670 SelectItem::UnnamedExpr(expr) => expr_end_span(expr),
671 SelectItem::Wildcard(_) => None, SelectItem::QualifiedWildcard(name, _) => {
673 let span = name.span();
674 Some((span.end.line, span.end.column))
675 }
676 }
677}
678
679fn expr_start_span(expr: &Expr) -> Option<(u64, u64)> {
682 let span = expr.span();
683 if span.start.line > 0 && span.start.column > 0 {
685 Some((span.start.line, span.start.column))
686 } else {
687 None
688 }
689}
690
691fn expr_end_span(expr: &Expr) -> Option<(u64, u64)> {
694 let span = expr.span();
695 if span.end.line > 0 && span.end.column > 0 {
697 Some((span.end.line, span.end.column))
698 } else {
699 None
700 }
701}
702
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect, panicking on syntax errors.
    fn parse_sql(sql: &str) -> Vec<Statement> {
        Parser::parse_sql(&sqlparser::dialect::GenericDialect {}, sql).unwrap()
    }

    #[test]
    fn test_extract_cte() {
        let sql = "WITH cte AS (SELECT id, name FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        assert!(ctx.cte_definitions.contains_key("cte"));
        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.name, "cte");
        assert_eq!(cte.projected_columns.len(), 2);
        assert_eq!(cte.projected_columns[0].name, "id");
        assert_eq!(cte.projected_columns[1].name, "name");
    }

    #[test]
    fn test_extract_cte_with_declared_columns() {
        let sql = "WITH cte(a, b) AS (SELECT id, name FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.declared_columns, vec!["a", "b"]);
    }

    #[test]
    fn test_extract_table_alias() {
        let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        assert!(ctx.table_aliases.contains_key("u"));
        assert!(ctx.table_aliases.contains_key("o"));
    }

    #[test]
    fn test_extract_subquery_alias() {
        let sql = "SELECT * FROM (SELECT a, b FROM t) AS sub WHERE sub.a = 1";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        assert!(ctx.subquery_aliases.contains_key("sub"));
        let sub = &ctx.subquery_aliases["sub"];
        assert_eq!(sub.projected_columns.len(), 2);
        assert_eq!(sub.projected_columns[0].name, "a");
        assert_eq!(sub.projected_columns[1].name, "b");
    }

    #[test]
    fn test_extract_lateral_subquery() {
        let sql = "SELECT * FROM users u, LATERAL (SELECT * FROM orders WHERE user_id = u.id) AS o";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        assert!(ctx.subquery_aliases.contains_key("o"));
    }

    #[test]
    fn test_extract_column_with_alias() {
        let sql =
            "WITH cte AS (SELECT id AS user_id, name AS user_name FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.projected_columns[0].name, "user_id");
        assert_eq!(cte.projected_columns[1].name, "user_name");
    }

    #[test]
    fn test_extract_function_column_name() {
        let sql = "WITH cte AS (SELECT COUNT(*), SUM(amount) FROM orders) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        let cte = &ctx.cte_definitions["cte"];
        assert!(cte.projected_columns[0]
            .name
            .to_lowercase()
            .contains("count"));
    }

    #[test]
    fn test_extract_wildcard() {
        let sql = "WITH cte AS (SELECT * FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.projected_columns[0].name, "*");
    }

    #[test]
    fn test_extract_recursive_cte() {
        let sql = r#"
            WITH RECURSIVE cte AS (
                SELECT 1 AS n
                UNION ALL
                SELECT n + 1 FROM cte WHERE n < 10
            )
            SELECT * FROM cte
        "#;
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        let cte = &ctx.cte_definitions["cte"];
        // Columns come from the base case (`SELECT 1 AS n`).
        assert_eq!(cte.projected_columns.len(), 1);
        assert_eq!(cte.projected_columns[0].name, "n");
    }

    #[test]
    fn test_has_enrichment() {
        let sql = "SELECT * FROM users";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);

        assert!(ctx.has_enrichment());
    }

    #[test]
    fn test_empty_context() {
        let ctx = AstContext::default();
        assert!(!ctx.has_enrichment());
    }

    #[test]
    fn test_extract_lateral_aliases_single() {
        let sql = "SELECT price * qty AS total FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 1);
        assert_eq!(aliases[0].name, "total");
        assert!(aliases[0].definition_end > 0);
        assert!(aliases[0].definition_end <= sql.len());
    }

    #[test]
    fn test_extract_lateral_aliases_with_leading_wildcard() {
        let sql = "SELECT *, price * qty AS total, discount AS disc FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        let names: Vec<_> = aliases.iter().map(|a| a.name.as_str()).collect();
        assert_eq!(names, vec!["total", "disc"]);
    }

    #[test]
    fn test_extract_lateral_aliases_multiple() {
        let sql = "SELECT a AS x, b AS y, c AS z FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 3);
        assert_eq!(aliases[0].name, "x");
        assert_eq!(aliases[1].name, "y");
        assert_eq!(aliases[2].name, "z");
        // Aliases are reported in source order.
        assert!(aliases[0].definition_end < aliases[1].definition_end);
        assert!(aliases[1].definition_end < aliases[2].definition_end);
    }

    #[test]
    fn test_extract_lateral_aliases_with_expression() {
        let sql = "SELECT price * qty AS total, total * 0.1 AS tax FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 2);
        assert_eq!(aliases[0].name, "total");
        assert_eq!(aliases[1].name, "tax");
    }

    #[test]
    fn test_extract_lateral_aliases_no_aliases() {
        let sql = "SELECT price, qty FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert!(aliases.is_empty());
    }

    #[test]
    fn test_extract_lateral_aliases_mixed() {
        let sql = "SELECT a, b AS alias_b, c FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 1);
        assert_eq!(aliases[0].name, "alias_b");
    }

    #[test]
    fn test_extract_lateral_aliases_quoted() {
        let sql = r#"SELECT a AS "My Total", b AS "Tax Amount" FROM t"#;
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 2);
        assert_eq!(aliases[0].name, "My Total");
        assert_eq!(aliases[1].name, "Tax Amount");
    }

    #[test]
    fn test_extract_lateral_aliases_subquery_in_from() {
        // Subqueries in FROM are not walked for lateral aliases.
        let sql = "SELECT * FROM (SELECT a AS x, b AS y FROM t) sub";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 0);
    }

    #[test]
    fn test_extract_lateral_aliases_outer_select_with_alias() {
        let sql = "SELECT sub.x AS outer_x FROM (SELECT a AS x FROM t) sub";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 1);
        assert_eq!(aliases[0].name, "outer_x");
    }

    #[test]
    fn test_extract_lateral_aliases_with_unicode() {
        // Multi-byte characters before the aliases must not break offsets.
        let sql = "SELECT '日本語' AS label, value AS val FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 2);
        assert_eq!(aliases[0].name, "label");
        assert_eq!(aliases[1].name, "val");
    }

    #[test]
    fn test_extract_lateral_aliases_cte_scope_isolation() {
        // Aliases from a CTE and from the outer SELECT must carry distinct,
        // non-overlapping projection spans.
        let sql =
            "WITH cte AS (SELECT a AS inner_alias FROM t) SELECT cte.a AS outer_alias FROM cte";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 2);

        let inner = aliases.iter().find(|a| a.name == "inner_alias").unwrap();
        let outer = aliases.iter().find(|a| a.name == "outer_alias").unwrap();

        assert!(
            inner.projection_start < outer.projection_start,
            "CTE projection should start before outer SELECT projection"
        );

        assert!(
            inner.projection_end <= outer.projection_start,
            "CTE projection should end before the outer SELECT projection starts"
        );
    }

    #[test]
    fn test_extract_lateral_aliases_projection_span_validity() {
        let sql = "SELECT a AS x, b AS y FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);

        assert_eq!(aliases.len(), 2);

        for alias in &aliases {
            assert!(
                alias.definition_end <= alias.projection_end,
                "Alias definition should be within projection span"
            );
            assert!(
                alias.projection_start < alias.definition_end,
                "Projection should start before alias definition ends"
            );
        }
    }
}