flowscope_core/completion/
ast_extractor.rs

1//! AST context extraction for hybrid SQL completion.
2//!
3//! This module extracts completion-relevant context from parsed SQL ASTs,
4//! including CTEs, table aliases, and subquery aliases with their columns.
5
6use sqlparser::ast::{
7    Cte, Expr, Query, Select, SelectItem, SetExpr, Spanned, Statement, TableFactor, TableWithJoins,
8};
9
10use crate::analyzer::helpers::{infer_expr_type, line_col_to_offset};
11use crate::types::{AstColumnInfo, AstContext, AstTableInfo, CteInfo, SubqueryInfo};
12
/// Information about a SELECT alias for lateral reference support.
///
/// Lateral column aliases allow referencing aliases defined earlier in the same
/// SELECT list. This is supported by dialects like DuckDB, BigQuery, and Snowflake.
///
/// # Scope Tracking
///
/// Each alias records the byte offset range of its containing SELECT's
/// projection area. The consumer uses this range when filtering by cursor
/// position, so that aliases from CTEs or subqueries don't leak into other
/// SELECT scopes.
///
/// All offsets are byte offsets into the original SQL source text.
#[derive(Debug, Clone)]
pub struct LateralAliasInfo {
    /// The alias name (the alias identifier's unquoted `value`)
    pub name: String,
    /// Byte offset where the alias definition ends (end of the alias identifier)
    pub definition_end: usize,
    /// Byte offset where the containing SELECT's projection area starts: the
    /// start of the first projection item with a known position, or the start
    /// of the SELECT itself when no item has one (e.g. `SELECT *`)
    pub projection_start: usize,
    /// Byte offset where the containing SELECT's projection area ends: normally
    /// the start of the FROM clause; when the SELECT has no FROM clause this is
    /// the end of the SQL text, so trailing positions like `SELECT a AS x, |`
    /// still count as inside the projection
    pub projection_end: usize,
}
34
/// Maximum recursion depth for AST traversal to prevent stack overflow.
///
/// The recursive extraction functions check this limit and silently stop
/// descending once it is exceeded; for completion, partial context is
/// acceptable. This is a defensive limit - realistic SQL rarely exceeds
/// 10-20 levels of nesting.
const MAX_EXTRACTION_DEPTH: usize = 50;

/// Maximum number of lateral aliases to extract.
///
/// This prevents memory exhaustion from malicious input with thousands of
/// aliases; 1000 aliases is far beyond any realistic SQL query. Once the limit
/// is reached, extraction stops and the aliases gathered so far are returned.
const MAX_LATERAL_ALIASES: usize = 1000;
43
44/// Extract AST context from parsed statements for completion enrichment.
45///
46/// Extracts:
47/// - CTE definitions and their columns
48/// - Table aliases and their resolved names
49/// - Subquery aliases and their projected columns
50pub(crate) fn extract_ast_context(statements: &[Statement]) -> AstContext {
51    let mut ctx = AstContext::default();
52
53    for stmt in statements {
54        extract_from_statement(stmt, &mut ctx, 0);
55    }
56
57    ctx
58}
59
60/// Extract context from a single statement
61fn extract_from_statement(stmt: &Statement, ctx: &mut AstContext, depth: usize) {
62    if depth > MAX_EXTRACTION_DEPTH {
63        return; // Silent truncation is acceptable for completion
64    }
65
66    match stmt {
67        Statement::Query(query) => {
68            extract_from_query(query, ctx, depth);
69        }
70        Statement::Insert(insert) => {
71            // Extract from INSERT ... SELECT
72            if let Some(source) = &insert.source {
73                extract_from_query(source, ctx, depth);
74            }
75        }
76        Statement::CreateTable(ct) => {
77            // Extract from CREATE TABLE ... AS SELECT
78            if let Some(query) = &ct.query {
79                extract_from_query(query, ctx, depth);
80            }
81        }
82        Statement::CreateView { query, .. } => {
83            extract_from_query(query, ctx, depth);
84        }
85        _ => {}
86    }
87}
88
89/// Extract context from a Query (SELECT, UNION, etc.)
90fn extract_from_query(query: &Query, ctx: &mut AstContext, depth: usize) {
91    if depth > MAX_EXTRACTION_DEPTH {
92        return;
93    }
94
95    // Extract CTEs first (they're in scope for the body)
96    if let Some(with) = &query.with {
97        let is_recursive = with.recursive;
98        for cte in &with.cte_tables {
99            if let Some(info) = extract_cte_info(cte, is_recursive) {
100                ctx.cte_definitions.insert(info.name.clone(), info);
101            }
102        }
103    }
104
105    // Extract from the query body
106    extract_from_set_expr(&query.body, ctx, depth + 1);
107}
108
109/// Extract context from a SetExpr (SELECT, UNION, etc.)
110fn extract_from_set_expr(set_expr: &SetExpr, ctx: &mut AstContext, depth: usize) {
111    if depth > MAX_EXTRACTION_DEPTH {
112        return;
113    }
114
115    match set_expr {
116        SetExpr::Select(select) => {
117            extract_from_select(select, ctx, depth);
118        }
119        SetExpr::Query(query) => {
120            extract_from_query(query, ctx, depth);
121        }
122        SetExpr::SetOperation { left, right, .. } => {
123            extract_from_set_expr(left, ctx, depth + 1);
124            extract_from_set_expr(right, ctx, depth + 1);
125        }
126        SetExpr::Values(_) => {}
127        SetExpr::Insert(_) => {}
128        SetExpr::Update(_) => {}
129        SetExpr::Table(_) => {}
130        SetExpr::Delete(_) => {}
131        SetExpr::Merge(_) => {}
132    }
133}
134
135/// Extract context from a SELECT statement
136fn extract_from_select(select: &Select, ctx: &mut AstContext, depth: usize) {
137    if depth > MAX_EXTRACTION_DEPTH {
138        return;
139    }
140
141    // Extract from FROM clause
142    for table_with_joins in &select.from {
143        extract_from_table_with_joins(table_with_joins, ctx, depth);
144    }
145}
146
147/// Extract context from a table with joins
148fn extract_from_table_with_joins(twj: &TableWithJoins, ctx: &mut AstContext, depth: usize) {
149    if depth > MAX_EXTRACTION_DEPTH {
150        return;
151    }
152
153    extract_from_table_factor(&twj.relation, ctx, depth);
154
155    for join in &twj.joins {
156        extract_from_table_factor(&join.relation, ctx, depth);
157    }
158}
159
160/// Extract context from a table factor (table reference)
161fn extract_from_table_factor(tf: &TableFactor, ctx: &mut AstContext, depth: usize) {
162    if depth > MAX_EXTRACTION_DEPTH {
163        return;
164    }
165
166    match tf {
167        TableFactor::Table { name, alias, .. } => {
168            let table_name = name.to_string();
169            let alias_name = alias.as_ref().map(|a| a.name.value.clone());
170
171            // Use alias if present, otherwise use table name
172            let key = alias_name.clone().unwrap_or_else(|| {
173                // Use just the table name part (last component)
174                name.0
175                    .last()
176                    .map(|i| i.to_string())
177                    .unwrap_or(table_name.clone())
178            });
179
180            ctx.table_aliases.insert(key, AstTableInfo);
181        }
182        TableFactor::Derived {
183            subquery, alias, ..
184        } => {
185            // Extract subquery info
186            if let Some(alias) = alias {
187                let columns = extract_projected_columns_from_query(subquery);
188                ctx.subquery_aliases.insert(
189                    alias.name.value.clone(),
190                    SubqueryInfo {
191                        projected_columns: columns,
192                    },
193                );
194            }
195
196            // Recurse into subquery
197            extract_from_query(subquery, ctx, depth + 1);
198        }
199        TableFactor::NestedJoin {
200            table_with_joins, ..
201        } => {
202            extract_from_table_with_joins(table_with_joins, ctx, depth + 1);
203        }
204        TableFactor::TableFunction { .. } => {}
205        TableFactor::UNNEST {
206            alias: Some(alias), ..
207        } => {
208            ctx.table_aliases
209                .insert(alias.name.value.clone(), AstTableInfo);
210        }
211        _ => {}
212    }
213}
214
215/// Extract CTE info from a CTE definition
216fn extract_cte_info(cte: &Cte, is_recursive: bool) -> Option<CteInfo> {
217    let name = cte.alias.name.value.clone();
218
219    // Get declared columns from alias
220    let declared_columns: Vec<String> = cte
221        .alias
222        .columns
223        .iter()
224        .map(|c| c.name.value.clone())
225        .collect();
226
227    // Get projected columns from CTE body
228    let projected_columns = if is_recursive {
229        // For recursive CTEs, only use the base case (first SELECT in UNION)
230        extract_base_case_columns(&cte.query)
231    } else {
232        extract_projected_columns_from_query(&cte.query)
233    };
234
235    Some(CteInfo {
236        name,
237        declared_columns,
238        projected_columns,
239    })
240}
241
242/// Extract columns from the base case of a recursive CTE
243fn extract_base_case_columns(query: &Query) -> Vec<AstColumnInfo> {
244    match &*query.body {
245        SetExpr::SetOperation { left, .. } => {
246            // In UNION, left is typically the base case
247            if let SetExpr::Select(select) = &**left {
248                extract_select_columns(select)
249            } else {
250                vec![]
251            }
252        }
253        SetExpr::Select(select) => extract_select_columns(select),
254        _ => vec![],
255    }
256}
257
258/// Extract projected columns from a query
259fn extract_projected_columns_from_query(query: &Query) -> Vec<AstColumnInfo> {
260    match &*query.body {
261        SetExpr::Select(select) => extract_select_columns(select),
262        SetExpr::SetOperation { left, .. } => {
263            // Use left side's columns for UNION
264            if let SetExpr::Select(select) = &**left {
265                extract_select_columns(select)
266            } else {
267                vec![]
268            }
269        }
270        _ => vec![],
271    }
272}
273
274/// Extract columns from a SELECT's projection
275fn extract_select_columns(select: &Select) -> Vec<AstColumnInfo> {
276    let mut columns = Vec::new();
277
278    for (idx, item) in select.projection.iter().enumerate() {
279        match item {
280            SelectItem::ExprWithAlias { alias, expr } => {
281                columns.push(AstColumnInfo {
282                    name: alias.value.clone(),
283                    data_type: infer_data_type(expr),
284                });
285            }
286            SelectItem::UnnamedExpr(expr) => {
287                columns.push(AstColumnInfo {
288                    name: derive_column_name(expr, idx),
289                    data_type: infer_data_type(expr),
290                });
291            }
292            SelectItem::Wildcard(_) => {
293                columns.push(AstColumnInfo {
294                    name: "*".to_string(),
295                    data_type: None,
296                });
297            }
298            SelectItem::QualifiedWildcard(name, _) => {
299                columns.push(AstColumnInfo {
300                    name: format!("{}.*", name),
301                    data_type: None,
302                });
303            }
304        }
305    }
306
307    columns
308}
309
310/// Derive column name from expression
311fn derive_column_name(expr: &Expr, index: usize) -> String {
312    match expr {
313        Expr::Identifier(ident) => ident.value.clone(),
314        Expr::CompoundIdentifier(parts) => parts
315            .last()
316            .map(|i| i.value.clone())
317            .unwrap_or_else(|| format!("col_{}", index)),
318        Expr::Function(func) => func.name.to_string().to_lowercase(),
319        Expr::Cast { .. } => format!("col_{}", index),
320        Expr::Case { .. } => format!("case_{}", index),
321        Expr::Subquery(_) => format!("subquery_{}", index),
322        _ => format!("col_{}", index),
323    }
324}
325
326/// Infer data type from expression using the analyzer's centralized type inference.
327///
328/// Returns the canonical type name in uppercase (e.g., "TEXT", "INTEGER", "FLOAT").
329/// See [`crate::analyzer::helpers::infer_expr_type`] for supported expressions and behavior.
330fn infer_data_type(expr: &Expr) -> Option<String> {
331    infer_expr_type(expr).map(|canonical| canonical.as_uppercase_str().to_string())
332}
333
334/// Extracts SELECT aliases with their positions from parsed statements.
335///
336/// Returns aliases that appear in SELECT projections, along with the byte offset
337/// where each alias definition ends. This is used for lateral alias completion
338/// in dialects that support referencing earlier aliases in the same SELECT list.
339///
340/// # Arguments
341///
342/// * `statements` - Parsed SQL statements
343/// * `sql` - Original SQL source text (needed for line:column to byte offset conversion)
344///
345/// # Returns
346///
347/// A vector of lateral alias info, ordered by position in the source.
348/// Limited to `MAX_LATERAL_ALIASES` to prevent memory exhaustion.
349pub(crate) fn extract_lateral_aliases(
350    statements: &[Statement],
351    sql: &str,
352) -> Vec<LateralAliasInfo> {
353    let mut aliases = Vec::with_capacity(64); // Reasonable starting capacity
354
355    for stmt in statements {
356        // Stop extraction if we've hit the limit
357        if aliases.len() >= MAX_LATERAL_ALIASES {
358            break;
359        }
360
361        if let Statement::Query(query) = stmt {
362            // Extract from top-level CTEs first
363            if let Some(with) = &query.with {
364                for cte in &with.cte_tables {
365                    if aliases.len() >= MAX_LATERAL_ALIASES {
366                        break;
367                    }
368                    extract_lateral_aliases_from_set_expr(&cte.query.body, sql, &mut aliases, 0);
369                }
370            }
371            // Then extract from the main query body
372            if aliases.len() < MAX_LATERAL_ALIASES {
373                extract_lateral_aliases_from_set_expr(&query.body, sql, &mut aliases, 0);
374            }
375        }
376    }
377
378    aliases
379}
380
381/// Extract lateral aliases from a SetExpr (handles SELECT, nested Query, and set operations).
382///
383/// Note: This intentionally extracts from all SELECT projections, including CTEs.
384/// The filtering by cursor position happens in the consumer (context.rs) using
385/// the projection_start/projection_end fields.
386fn extract_lateral_aliases_from_set_expr(
387    set_expr: &SetExpr,
388    sql: &str,
389    aliases: &mut Vec<LateralAliasInfo>,
390    depth: usize,
391) {
392    if depth > MAX_EXTRACTION_DEPTH || aliases.len() >= MAX_LATERAL_ALIASES {
393        return;
394    }
395
396    match set_expr {
397        SetExpr::Select(select) => {
398            extract_lateral_aliases_from_select(select, sql, aliases);
399        }
400        SetExpr::Query(query) => {
401            // Handle CTEs within nested queries
402            if let Some(with) = &query.with {
403                for cte in &with.cte_tables {
404                    if aliases.len() >= MAX_LATERAL_ALIASES {
405                        break;
406                    }
407                    extract_lateral_aliases_from_set_expr(&cte.query.body, sql, aliases, depth + 1);
408                }
409            }
410            if aliases.len() < MAX_LATERAL_ALIASES {
411                extract_lateral_aliases_from_set_expr(&query.body, sql, aliases, depth + 1);
412            }
413        }
414        SetExpr::SetOperation { left, right, .. } => {
415            extract_lateral_aliases_from_set_expr(left, sql, aliases, depth + 1);
416            if aliases.len() < MAX_LATERAL_ALIASES {
417                extract_lateral_aliases_from_set_expr(right, sql, aliases, depth + 1);
418            }
419        }
420        _ => {}
421    }
422}
423
424/// Extract lateral aliases from a SELECT's projection.
425///
426/// Iterates over the projection items and records aliases with their positions,
427/// including the span of the containing SELECT projection for scope filtering.
428///
429/// # Safety Notes
430///
431/// The function validates that computed byte offsets are:
432/// 1. Within bounds of the SQL string
433/// 2. At valid UTF-8 character boundaries
434///
435/// Aliases with invalid offsets (e.g., from parser bugs or multi-byte char issues)
436/// are silently skipped to prevent panics.
437fn extract_lateral_aliases_from_select(
438    select: &Select,
439    sql: &str,
440    aliases: &mut Vec<LateralAliasInfo>,
441) {
442    // Early return if we've hit the extraction limit
443    if aliases.len() >= MAX_LATERAL_ALIASES {
444        return;
445    }
446
447    // Compute projection span from the first and last projection items
448    // This is used to filter aliases to only those in the same SELECT scope as cursor
449    let projection_span = compute_projection_span(select, sql);
450    let (projection_start, projection_end) = match projection_span {
451        Some((start, end)) => (start, end),
452        None => return, // Can't determine span, skip this SELECT
453    };
454
455    for item in &select.projection {
456        // Check limit on each iteration to stop early
457        if aliases.len() >= MAX_LATERAL_ALIASES {
458            break;
459        }
460
461        if let SelectItem::ExprWithAlias { alias, .. } = item {
462            // Convert line:column to byte offset
463            // sqlparser uses 1-indexed line/column numbers
464            if let Some(end_offset) = line_col_to_offset(
465                sql,
466                alias.span.end.line as usize,
467                alias.span.end.column as usize,
468            ) {
469                // Validate offset is within bounds and at a valid UTF-8 boundary
470                // This prevents panics on multi-byte characters (emoji, unicode identifiers)
471                if end_offset <= sql.len() && sql.is_char_boundary(end_offset) {
472                    aliases.push(LateralAliasInfo {
473                        name: alias.value.clone(),
474                        definition_end: end_offset,
475                        projection_start,
476                        projection_end,
477                    });
478                }
479            }
480        }
481    }
482}
483
484/// Compute the byte offset span of a SELECT's projection area.
485///
486/// The projection area starts at the first projection item and extends to:
487/// - The FROM clause (if present)
488/// - Or the end of the SELECT (if no FROM)
489///
490/// This allows cursor positions after the last alias (like `SELECT a AS x, |`)
491/// to still be considered "within" the projection area for lateral alias purposes.
492///
493/// Returns (start, end) offsets, or None if the span cannot be determined.
494fn compute_projection_span(select: &Select, sql: &str) -> Option<(usize, usize)> {
495    if select.projection.is_empty() {
496        return None;
497    }
498
499    // Find the first projection item that has a usable span. This skips leading
500    // wildcards (plain `*`), which deliberately return None in select_item_span.
501    // When every item lacks a span (e.g., `SELECT *`), fall back to the SELECT
502    // keyword span so we still have a safe projection boundary.
503    let first_span = select
504        .projection
505        .iter()
506        .filter_map(select_item_span)
507        .next()
508        .or_else(|| {
509            let span = select.span();
510            if span.start.line > 0 && span.start.column > 0 {
511                Some((span.start.line, span.start.column))
512            } else {
513                None
514            }
515        })?;
516    let start = line_col_to_offset(sql, first_span.0 as usize, first_span.1 as usize)?;
517
518    // Determine the end of the projection area
519    // Prefer FROM clause position if available, otherwise use last item
520    let end = if let Some(from_item) = select.from.first() {
521        // Use the start of the FROM clause as the end of projection area
522        compute_from_clause_start(from_item, sql).unwrap_or_else(|| {
523            // Fallback to last projection item
524            select
525                .projection
526                .last()
527                .and_then(|item| {
528                    let span = select_item_end_span(item)?;
529                    line_col_to_offset(sql, span.0 as usize, span.1 as usize)
530                })
531                .unwrap_or(sql.len())
532        })
533    } else {
534        // No FROM clause - use a large value to include trailing positions
535        // This handles cases like `SELECT a AS x, |` (no FROM yet)
536        sql.len()
537    };
538
539    // Validate spans
540    if start <= sql.len() && end <= sql.len() && start <= end {
541        Some((start, end))
542    } else {
543        None
544    }
545}
546
547/// Get the byte offset where the FROM clause starts.
548fn compute_from_clause_start(from_item: &TableWithJoins, sql: &str) -> Option<usize> {
549    // Get the span of the first table in FROM
550    let span = table_factor_span(&from_item.relation)?;
551    let table_start = line_col_to_offset(sql, span.0 as usize, span.1 as usize)?;
552
553    // Search backwards from table_start to find "FROM" keyword
554    // This gives us the actual start of the FROM clause
555    //
556    // Safety: We need to find a valid UTF-8 char boundary before slicing.
557    // saturating_sub(50) might land in the middle of a multi-byte character.
558    let search_start = find_char_boundary_before(sql, table_start.saturating_sub(50));
559    let search_area = &sql[search_start..table_start];
560
561    // Find "FROM" (case insensitive) using ASCII-only comparison.
562    // We avoid to_uppercase() because it can change string length for certain
563    // Unicode characters (e.g., German ß -> SS), causing position misalignment.
564    if let Some(pos) = rfind_ascii_case_insensitive(search_area, b"FROM") {
565        Some(search_start + pos)
566    } else {
567        // Fallback to table start if FROM not found
568        Some(table_start)
569    }
570}
571
/// Find the closest UTF-8 char boundary at or before `pos`.
///
/// Positions at or past the end of `s` are clamped to `s.len()`. Index 0 is
/// always a boundary, so the result is 0 in the worst case.
fn find_char_boundary_before(s: &str, pos: usize) -> usize {
    if pos >= s.len() {
        return s.len();
    }

    // Step back until we leave any multi-byte sequence; index 0 always stops
    // the loop, so `candidate` cannot underflow.
    let mut candidate = pos;
    while !s.is_char_boundary(candidate) {
        candidate -= 1;
    }
    candidate
}
584
/// Case-insensitive reverse search for an ASCII pattern in a string slice.
///
/// Returns the byte offset of the last occurrence, or `None` if there is no
/// match. An empty needle never matches.
///
/// This is safe for UTF-8 strings because ASCII bytes (0x00-0x7F) never appear
/// as continuation bytes in multi-byte UTF-8 sequences, so a match always
/// starts on a char boundary.
fn rfind_ascii_case_insensitive(haystack: &str, needle: &[u8]) -> Option<usize> {
    // `windows(0)` would panic, so handle the empty needle up front.
    if needle.is_empty() {
        return None;
    }

    // A haystack shorter than the needle yields no windows, hence `None`.
    haystack
        .as_bytes()
        .windows(needle.len())
        .rposition(|window| window.eq_ignore_ascii_case(needle))
}
614
615/// Get the start position (line, column) of a TableFactor.
616fn table_factor_span(tf: &TableFactor) -> Option<(u64, u64)> {
617    match tf {
618        TableFactor::Table { name, .. } => name.0.first().map(|i| {
619            let span = i.span();
620            (span.start.line, span.start.column)
621        }),
622        TableFactor::Derived { subquery, .. } => {
623            // For derived tables, try to get the span of the subquery
624            let span = subquery.body.span();
625            if span.start.line > 0 {
626                Some((span.start.line, span.start.column))
627            } else {
628                None
629            }
630        }
631        _ => None,
632    }
633}
634
635/// Get the start position (line, column) of a SelectItem.
636fn select_item_span(item: &SelectItem) -> Option<(u64, u64)> {
637    match item {
638        SelectItem::ExprWithAlias { expr, .. } | SelectItem::UnnamedExpr(expr) => {
639            expr_start_span(expr)
640        }
641        SelectItem::Wildcard(opts) => {
642            // Wildcard span comes from the options if available
643            if let Some(exclude) = &opts.opt_exclude {
644                // Use first exclusion's span if available
645                match exclude {
646                    sqlparser::ast::ExcludeSelectItem::Single(ident) => {
647                        Some((ident.span.start.line, ident.span.start.column))
648                    }
649                    sqlparser::ast::ExcludeSelectItem::Multiple(idents) => idents
650                        .first()
651                        .map(|i| (i.span.start.line, i.span.start.column)),
652                }
653            } else {
654                None
655            }
656        }
657        SelectItem::QualifiedWildcard(name, _) => {
658            let span = name.span();
659            Some((span.start.line, span.start.column))
660        }
661    }
662}
663
664/// Get the end position (line, column) of a SelectItem.
665fn select_item_end_span(item: &SelectItem) -> Option<(u64, u64)> {
666    match item {
667        SelectItem::ExprWithAlias { alias, .. } => {
668            Some((alias.span.end.line, alias.span.end.column))
669        }
670        SelectItem::UnnamedExpr(expr) => expr_end_span(expr),
671        SelectItem::Wildcard(_) => None, // Wildcard doesn't have reliable end span
672        SelectItem::QualifiedWildcard(name, _) => {
673            let span = name.span();
674            Some((span.end.line, span.end.column))
675        }
676    }
677}
678
679/// Get the start position (line, column) of an expression.
680/// Uses the Spanned trait for comprehensive coverage of all expression types.
681fn expr_start_span(expr: &Expr) -> Option<(u64, u64)> {
682    let span = expr.span();
683    // Check for valid span (non-zero positions indicate valid span)
684    if span.start.line > 0 && span.start.column > 0 {
685        Some((span.start.line, span.start.column))
686    } else {
687        None
688    }
689}
690
691/// Get the end position (line, column) of an expression.
692/// Uses the Spanned trait for comprehensive coverage of all expression types.
693fn expr_end_span(expr: &Expr) -> Option<(u64, u64)> {
694    let span = expr.span();
695    // Check for valid span (non-zero positions indicate valid span)
696    if span.end.line > 0 && span.end.column > 0 {
697        Some((span.end.line, span.end.column))
698    } else {
699        None
700    }
701}
702
703#[cfg(test)]
704mod tests {
705    use super::*;
706    use sqlparser::parser::Parser;
707
708    fn parse_sql(sql: &str) -> Vec<Statement> {
709        Parser::parse_sql(&sqlparser::dialect::GenericDialect {}, sql).unwrap()
710    }
711
712    #[test]
713    fn test_extract_cte() {
714        let sql = "WITH cte AS (SELECT id, name FROM users) SELECT * FROM cte";
715        let stmts = parse_sql(sql);
716        let ctx = extract_ast_context(&stmts);
717
718        assert!(ctx.cte_definitions.contains_key("cte"));
719        let cte = &ctx.cte_definitions["cte"];
720        assert_eq!(cte.name, "cte");
721        assert_eq!(cte.projected_columns.len(), 2);
722        assert_eq!(cte.projected_columns[0].name, "id");
723        assert_eq!(cte.projected_columns[1].name, "name");
724    }
725
726    #[test]
727    fn test_extract_cte_with_declared_columns() {
728        let sql = "WITH cte(a, b) AS (SELECT id, name FROM users) SELECT * FROM cte";
729        let stmts = parse_sql(sql);
730        let ctx = extract_ast_context(&stmts);
731
732        let cte = &ctx.cte_definitions["cte"];
733        assert_eq!(cte.declared_columns, vec!["a", "b"]);
734    }
735
736    #[test]
737    fn test_extract_table_alias() {
738        let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id";
739        let stmts = parse_sql(sql);
740        let ctx = extract_ast_context(&stmts);
741
742        // Aliases are stored as keys in the table_aliases map
743        assert!(ctx.table_aliases.contains_key("u"));
744        assert!(ctx.table_aliases.contains_key("o"));
745    }
746
747    #[test]
748    fn test_extract_subquery_alias() {
749        let sql = "SELECT * FROM (SELECT a, b FROM t) AS sub WHERE sub.a = 1";
750        let stmts = parse_sql(sql);
751        let ctx = extract_ast_context(&stmts);
752
753        assert!(ctx.subquery_aliases.contains_key("sub"));
754        let sub = &ctx.subquery_aliases["sub"];
755        assert_eq!(sub.projected_columns.len(), 2);
756        assert_eq!(sub.projected_columns[0].name, "a");
757        assert_eq!(sub.projected_columns[1].name, "b");
758    }
759
760    #[test]
761    fn test_extract_lateral_subquery() {
762        let sql = "SELECT * FROM users u, LATERAL (SELECT * FROM orders WHERE user_id = u.id) AS o";
763        let stmts = parse_sql(sql);
764        let ctx = extract_ast_context(&stmts);
765
766        // Lateral subqueries are extracted just like regular derived tables
767        assert!(ctx.subquery_aliases.contains_key("o"));
768    }
769
770    #[test]
771    fn test_extract_column_with_alias() {
772        let sql =
773            "WITH cte AS (SELECT id AS user_id, name AS user_name FROM users) SELECT * FROM cte";
774        let stmts = parse_sql(sql);
775        let ctx = extract_ast_context(&stmts);
776
777        let cte = &ctx.cte_definitions["cte"];
778        assert_eq!(cte.projected_columns[0].name, "user_id");
779        assert_eq!(cte.projected_columns[1].name, "user_name");
780    }
781
782    #[test]
783    fn test_extract_function_column_name() {
784        let sql = "WITH cte AS (SELECT COUNT(*), SUM(amount) FROM orders) SELECT * FROM cte";
785        let stmts = parse_sql(sql);
786        let ctx = extract_ast_context(&stmts);
787
788        let cte = &ctx.cte_definitions["cte"];
789        assert!(cte.projected_columns[0]
790            .name
791            .to_lowercase()
792            .contains("count"));
793    }
794
795    #[test]
796    fn test_extract_wildcard() {
797        let sql = "WITH cte AS (SELECT * FROM users) SELECT * FROM cte";
798        let stmts = parse_sql(sql);
799        let ctx = extract_ast_context(&stmts);
800
801        let cte = &ctx.cte_definitions["cte"];
802        assert_eq!(cte.projected_columns[0].name, "*");
803    }
804
805    #[test]
806    fn test_extract_recursive_cte() {
807        let sql = r#"
808            WITH RECURSIVE cte AS (
809                SELECT 1 AS n
810                UNION ALL
811                SELECT n + 1 FROM cte WHERE n < 10
812            )
813            SELECT * FROM cte
814        "#;
815        let stmts = parse_sql(sql);
816        let ctx = extract_ast_context(&stmts);
817
818        let cte = &ctx.cte_definitions["cte"];
819        // Should have column from base case
820        assert_eq!(cte.projected_columns.len(), 1);
821        assert_eq!(cte.projected_columns[0].name, "n");
822    }
823
824    #[test]
825    fn test_has_enrichment() {
826        let sql = "SELECT * FROM users";
827        let stmts = parse_sql(sql);
828        let ctx = extract_ast_context(&stmts);
829
830        assert!(ctx.has_enrichment()); // Has table alias
831    }
832
833    #[test]
834    fn test_empty_context() {
835        let ctx = AstContext::default();
836        assert!(!ctx.has_enrichment());
837    }
838
839    // Lateral alias extraction tests
840
841    #[test]
842    fn test_extract_lateral_aliases_single() {
843        let sql = "SELECT price * qty AS total FROM orders";
844        let stmts = parse_sql(sql);
845        let aliases = extract_lateral_aliases(&stmts, sql);
846
847        assert_eq!(aliases.len(), 1);
848        assert_eq!(aliases[0].name, "total");
849        // The alias ends after "total" (position should be after the alias)
850        assert!(aliases[0].definition_end > 0);
851        assert!(aliases[0].definition_end <= sql.len());
852    }
853
854    #[test]
855    fn test_extract_lateral_aliases_with_leading_wildcard() {
856        let sql = "SELECT *, price * qty AS total, discount AS disc FROM orders";
857        let stmts = parse_sql(sql);
858        let aliases = extract_lateral_aliases(&stmts, sql);
859
860        let names: Vec<_> = aliases.iter().map(|a| a.name.as_str()).collect();
861        assert_eq!(names, vec!["total", "disc"]);
862    }
863
864    #[test]
865    fn test_extract_lateral_aliases_multiple() {
866        let sql = "SELECT a AS x, b AS y, c AS z FROM t";
867        let stmts = parse_sql(sql);
868        let aliases = extract_lateral_aliases(&stmts, sql);
869
870        assert_eq!(aliases.len(), 3);
871        assert_eq!(aliases[0].name, "x");
872        assert_eq!(aliases[1].name, "y");
873        assert_eq!(aliases[2].name, "z");
874        // Aliases should be ordered by position
875        assert!(aliases[0].definition_end < aliases[1].definition_end);
876        assert!(aliases[1].definition_end < aliases[2].definition_end);
877    }
878
879    #[test]
880    fn test_extract_lateral_aliases_with_expression() {
881        let sql = "SELECT price * qty AS total, total * 0.1 AS tax FROM orders";
882        let stmts = parse_sql(sql);
883        let aliases = extract_lateral_aliases(&stmts, sql);
884
885        assert_eq!(aliases.len(), 2);
886        assert_eq!(aliases[0].name, "total");
887        assert_eq!(aliases[1].name, "tax");
888    }
889
890    #[test]
891    fn test_extract_lateral_aliases_no_aliases() {
892        let sql = "SELECT price, qty FROM orders";
893        let stmts = parse_sql(sql);
894        let aliases = extract_lateral_aliases(&stmts, sql);
895
896        assert!(aliases.is_empty());
897    }
898
899    #[test]
900    fn test_extract_lateral_aliases_mixed() {
901        // Mix of aliased and non-aliased columns
902        let sql = "SELECT a, b AS alias_b, c FROM t";
903        let stmts = parse_sql(sql);
904        let aliases = extract_lateral_aliases(&stmts, sql);
905
906        assert_eq!(aliases.len(), 1);
907        assert_eq!(aliases[0].name, "alias_b");
908    }
909
910    #[test]
911    fn test_extract_lateral_aliases_quoted() {
912        let sql = r#"SELECT a AS "My Total", b AS "Tax Amount" FROM t"#;
913        let stmts = parse_sql(sql);
914        let aliases = extract_lateral_aliases(&stmts, sql);
915
916        assert_eq!(aliases.len(), 2);
917        assert_eq!(aliases[0].name, "My Total");
918        assert_eq!(aliases[1].name, "Tax Amount");
919    }
920
921    #[test]
922    fn test_extract_lateral_aliases_subquery_in_from() {
923        // Aliases in subqueries in FROM clause should NOT be extracted,
924        // because lateral aliases are only visible within the same SELECT list
925        let sql = "SELECT * FROM (SELECT a AS x, b AS y FROM t) sub";
926        let stmts = parse_sql(sql);
927        let aliases = extract_lateral_aliases(&stmts, sql);
928
929        // The outer SELECT has SELECT * which has no aliases
930        assert_eq!(aliases.len(), 0);
931    }
932
933    #[test]
934    fn test_extract_lateral_aliases_outer_select_with_alias() {
935        // If the outer SELECT has aliases, those should be extracted
936        let sql = "SELECT sub.x AS outer_x FROM (SELECT a AS x FROM t) sub";
937        let stmts = parse_sql(sql);
938        let aliases = extract_lateral_aliases(&stmts, sql);
939
940        assert_eq!(aliases.len(), 1);
941        assert_eq!(aliases[0].name, "outer_x");
942    }
943
944    #[test]
945    fn test_extract_lateral_aliases_with_unicode() {
946        // Unicode characters in SQL should not cause panics
947        // The parser may report byte offsets that don't align with char boundaries
948        let sql = "SELECT '日本語' AS label, value AS val FROM t";
949        let stmts = parse_sql(sql);
950        let aliases = extract_lateral_aliases(&stmts, sql);
951
952        // Should successfully extract aliases even with unicode in the SQL
953        assert_eq!(aliases.len(), 2);
954        assert_eq!(aliases[0].name, "label");
955        assert_eq!(aliases[1].name, "val");
956    }
957
958    #[test]
959    fn test_extract_lateral_aliases_cte_scope_isolation() {
960        // Aliases from CTE's SELECT should have different projection spans
961        // than aliases from the outer SELECT
962        let sql =
963            "WITH cte AS (SELECT a AS inner_alias FROM t) SELECT cte.a AS outer_alias FROM cte";
964        let stmts = parse_sql(sql);
965        let aliases = extract_lateral_aliases(&stmts, sql);
966
967        // Should extract both aliases, but they should have different projection spans
968        assert_eq!(aliases.len(), 2);
969
970        let inner = aliases.iter().find(|a| a.name == "inner_alias").unwrap();
971        let outer = aliases.iter().find(|a| a.name == "outer_alias").unwrap();
972
973        // Inner alias projection should be before outer alias projection
974        assert!(
975            inner.projection_start < outer.projection_start,
976            "CTE projection should start before outer SELECT projection"
977        );
978
979        // Projection spans should not overlap significantly
980        assert!(
981            inner.projection_end < outer.projection_start
982                || outer.projection_end < inner.projection_start
983                || inner.projection_start != outer.projection_start,
984            "CTE and outer SELECT projections should have different spans"
985        );
986    }
987
988    #[test]
989    fn test_extract_lateral_aliases_projection_span_validity() {
990        // Verify that projection spans are valid and contain the alias
991        let sql = "SELECT a AS x, b AS y FROM t";
992        let stmts = parse_sql(sql);
993        let aliases = extract_lateral_aliases(&stmts, sql);
994
995        assert_eq!(aliases.len(), 2);
996
997        for alias in &aliases {
998            // Projection span should contain the alias definition
999            assert!(
1000                alias.definition_end <= alias.projection_end,
1001                "Alias definition should be within projection span"
1002            );
1003            assert!(
1004                alias.projection_start < alias.definition_end,
1005                "Projection should start before alias definition ends"
1006            );
1007        }
1008    }
1009}