// flowscope_core/completion/ast_extractor.rs

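//! AST context extraction for hybrid SQL completion.
//!
//! This module extracts completion-relevant context from parsed SQL ASTs,
//! including CTEs, table aliases, and subquery aliases with their columns.
//! It also records SELECT-list aliases with their byte positions so that
//! lateral column aliases can be offered as completions.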
5
6use sqlparser::ast::{
7    CreateView, Cte, Expr, Query, Select, SelectItem, SetExpr, Spanned, Statement, TableFactor,
8    TableWithJoins,
9};
10
11use crate::analyzer::helpers::{infer_expr_type, line_col_to_offset};
12use crate::types::{AstColumnInfo, AstContext, AstTableInfo, CteInfo, SubqueryInfo};
13
/// A SELECT-list alias recorded for lateral column alias completion.
///
/// Dialects such as DuckDB, BigQuery, and Snowflake allow a later item in a SELECT
/// list to refer to an alias introduced earlier in that same list.
///
/// # Scope Tracking
///
/// Besides the alias itself, every entry remembers the byte range of the projection
/// it was declared in. Consumers use that range to ignore aliases that belong to a
/// different SELECT (for instance one inside a CTE or subquery) than the one
/// containing the cursor.
#[derive(Debug, Clone)]
pub struct LateralAliasInfo {
    /// The alias name.
    pub name: String,
    /// Byte offset in the SQL text at which the alias definition ends.
    pub definition_end: usize,
    /// Byte offset at which the enclosing SELECT's projection area begins
    /// (after the SELECT keyword).
    pub projection_start: usize,
    /// Byte offset at which the enclosing SELECT's projection area ends
    /// (before FROM/WHERE/etc.).
    pub projection_end: usize,
}
35
/// Upper bound on how deep the AST walkers recurse before giving up.
///
/// Purely defensive: real-world SQL seldom nests beyond 10-20 levels. Reaching the
/// limit silently truncates extraction instead of risking a stack overflow.
const MAX_EXTRACTION_DEPTH: usize = 50;

/// Cap on the number of lateral aliases collected from one input.
///
/// Guards against memory exhaustion from malicious SQL containing thousands of
/// aliases; no realistic query comes anywhere near this many.
const MAX_LATERAL_ALIASES: usize = 1_000;
44
45/// Extract AST context from parsed statements for completion enrichment.
46///
47/// Extracts:
48/// - CTE definitions and their columns
49/// - Table aliases and their resolved names
50/// - Subquery aliases and their projected columns
51pub(crate) fn extract_ast_context(statements: &[Statement]) -> AstContext {
52    let mut ctx = AstContext::default();
53
54    for stmt in statements {
55        extract_from_statement(stmt, &mut ctx, 0);
56    }
57
58    ctx
59}
60
61/// Extract context from a single statement
62fn extract_from_statement(stmt: &Statement, ctx: &mut AstContext, depth: usize) {
63    if depth > MAX_EXTRACTION_DEPTH {
64        return; // Silent truncation is acceptable for completion
65    }
66
67    match stmt {
68        Statement::Query(query) => {
69            extract_from_query(query, ctx, depth);
70        }
71        Statement::Insert(insert) => {
72            // Extract from INSERT ... SELECT
73            if let Some(source) = &insert.source {
74                extract_from_query(source, ctx, depth);
75            }
76        }
77        Statement::CreateTable(ct) => {
78            // Extract from CREATE TABLE ... AS SELECT
79            if let Some(query) = &ct.query {
80                extract_from_query(query, ctx, depth);
81            }
82        }
83        Statement::CreateView(CreateView { query, .. }) => {
84            extract_from_query(query, ctx, depth);
85        }
86        _ => {}
87    }
88}
89
90/// Extract context from a Query (SELECT, UNION, etc.)
91fn extract_from_query(query: &Query, ctx: &mut AstContext, depth: usize) {
92    if depth > MAX_EXTRACTION_DEPTH {
93        return;
94    }
95
96    // Extract CTEs first (they're in scope for the body)
97    if let Some(with) = &query.with {
98        let is_recursive = with.recursive;
99        for cte in &with.cte_tables {
100            if let Some(info) = extract_cte_info(cte, is_recursive) {
101                ctx.cte_definitions.insert(info.name.clone(), info);
102            }
103        }
104    }
105
106    // Extract from the query body
107    extract_from_set_expr(&query.body, ctx, depth + 1);
108}
109
110/// Extract context from a SetExpr (SELECT, UNION, etc.)
111fn extract_from_set_expr(set_expr: &SetExpr, ctx: &mut AstContext, depth: usize) {
112    if depth > MAX_EXTRACTION_DEPTH {
113        return;
114    }
115
116    match set_expr {
117        SetExpr::Select(select) => {
118            extract_from_select(select, ctx, depth);
119        }
120        SetExpr::Query(query) => {
121            extract_from_query(query, ctx, depth);
122        }
123        SetExpr::SetOperation { left, right, .. } => {
124            extract_from_set_expr(left, ctx, depth + 1);
125            extract_from_set_expr(right, ctx, depth + 1);
126        }
127        SetExpr::Values(_) => {}
128        SetExpr::Insert(_) => {}
129        SetExpr::Update(_) => {}
130        SetExpr::Table(_) => {}
131        SetExpr::Delete(_) => {}
132        SetExpr::Merge(_) => {}
133    }
134}
135
136/// Extract context from a SELECT statement
137fn extract_from_select(select: &Select, ctx: &mut AstContext, depth: usize) {
138    if depth > MAX_EXTRACTION_DEPTH {
139        return;
140    }
141
142    // Extract from FROM clause
143    for table_with_joins in &select.from {
144        extract_from_table_with_joins(table_with_joins, ctx, depth);
145    }
146}
147
148/// Extract context from a table with joins
149fn extract_from_table_with_joins(twj: &TableWithJoins, ctx: &mut AstContext, depth: usize) {
150    if depth > MAX_EXTRACTION_DEPTH {
151        return;
152    }
153
154    extract_from_table_factor(&twj.relation, ctx, depth);
155
156    for join in &twj.joins {
157        extract_from_table_factor(&join.relation, ctx, depth);
158    }
159}
160
161/// Extract context from a table factor (table reference)
162fn extract_from_table_factor(tf: &TableFactor, ctx: &mut AstContext, depth: usize) {
163    if depth > MAX_EXTRACTION_DEPTH {
164        return;
165    }
166
167    match tf {
168        TableFactor::Table { name, alias, .. } => {
169            let table_name = name.to_string();
170            let alias_name = alias.as_ref().map(|a| a.name.value.clone());
171
172            // Use alias if present, otherwise use table name
173            let key = alias_name.clone().unwrap_or_else(|| {
174                // Use just the table name part (last component)
175                name.0
176                    .last()
177                    .map(|i| i.to_string())
178                    .unwrap_or(table_name.clone())
179            });
180
181            ctx.table_aliases.insert(key, AstTableInfo);
182        }
183        TableFactor::Derived {
184            subquery, alias, ..
185        } => {
186            // Extract subquery info
187            if let Some(alias) = alias {
188                let columns = extract_projected_columns_from_query(subquery);
189                ctx.subquery_aliases.insert(
190                    alias.name.value.clone(),
191                    SubqueryInfo {
192                        projected_columns: columns,
193                    },
194                );
195            }
196
197            // Recurse into subquery
198            extract_from_query(subquery, ctx, depth + 1);
199        }
200        TableFactor::NestedJoin {
201            table_with_joins, ..
202        } => {
203            extract_from_table_with_joins(table_with_joins, ctx, depth + 1);
204        }
205        TableFactor::TableFunction { .. } => {}
206        TableFactor::UNNEST {
207            alias: Some(alias), ..
208        } => {
209            ctx.table_aliases
210                .insert(alias.name.value.clone(), AstTableInfo);
211        }
212        _ => {}
213    }
214}
215
216/// Extract CTE info from a CTE definition
217fn extract_cte_info(cte: &Cte, is_recursive: bool) -> Option<CteInfo> {
218    let name = cte.alias.name.value.clone();
219
220    // Get declared columns from alias
221    let declared_columns: Vec<String> = cte
222        .alias
223        .columns
224        .iter()
225        .map(|c| c.name.value.clone())
226        .collect();
227
228    // Get projected columns from CTE body
229    let projected_columns = if is_recursive {
230        // For recursive CTEs, only use the base case (first SELECT in UNION)
231        extract_base_case_columns(&cte.query)
232    } else {
233        extract_projected_columns_from_query(&cte.query)
234    };
235
236    Some(CteInfo {
237        name,
238        declared_columns,
239        projected_columns,
240    })
241}
242
243/// Extract columns from the base case of a recursive CTE
244fn extract_base_case_columns(query: &Query) -> Vec<AstColumnInfo> {
245    match &*query.body {
246        SetExpr::SetOperation { left, .. } => {
247            // In UNION, left is typically the base case
248            if let SetExpr::Select(select) = &**left {
249                extract_select_columns(select)
250            } else {
251                vec![]
252            }
253        }
254        SetExpr::Select(select) => extract_select_columns(select),
255        _ => vec![],
256    }
257}
258
259/// Extract projected columns from a query
260fn extract_projected_columns_from_query(query: &Query) -> Vec<AstColumnInfo> {
261    match &*query.body {
262        SetExpr::Select(select) => extract_select_columns(select),
263        SetExpr::SetOperation { left, .. } => {
264            // Use left side's columns for UNION
265            if let SetExpr::Select(select) = &**left {
266                extract_select_columns(select)
267            } else {
268                vec![]
269            }
270        }
271        _ => vec![],
272    }
273}
274
275/// Extract columns from a SELECT's projection
276fn extract_select_columns(select: &Select) -> Vec<AstColumnInfo> {
277    let mut columns = Vec::new();
278
279    for (idx, item) in select.projection.iter().enumerate() {
280        match item {
281            SelectItem::ExprWithAlias { alias, expr } => {
282                columns.push(AstColumnInfo {
283                    name: alias.value.clone(),
284                    data_type: infer_data_type(expr),
285                });
286            }
287            SelectItem::UnnamedExpr(expr) => {
288                columns.push(AstColumnInfo {
289                    name: derive_column_name(expr, idx),
290                    data_type: infer_data_type(expr),
291                });
292            }
293            SelectItem::Wildcard(_) => {
294                columns.push(AstColumnInfo {
295                    name: "*".to_string(),
296                    data_type: None,
297                });
298            }
299            SelectItem::QualifiedWildcard(name, _) => {
300                columns.push(AstColumnInfo {
301                    name: format!("{}.*", name),
302                    data_type: None,
303                });
304            }
305        }
306    }
307
308    columns
309}
310
311/// Derive column name from expression
312fn derive_column_name(expr: &Expr, index: usize) -> String {
313    match expr {
314        Expr::Identifier(ident) => ident.value.clone(),
315        Expr::CompoundIdentifier(parts) => parts
316            .last()
317            .map(|i| i.value.clone())
318            .unwrap_or_else(|| format!("col_{}", index)),
319        Expr::Function(func) => func.name.to_string().to_lowercase(),
320        Expr::Cast { .. } => format!("col_{}", index),
321        Expr::Case { .. } => format!("case_{}", index),
322        Expr::Subquery(_) => format!("subquery_{}", index),
323        _ => format!("col_{}", index),
324    }
325}
326
327/// Infer data type from expression using the analyzer's centralized type inference.
328///
329/// Returns the canonical type name in uppercase (e.g., "TEXT", "INTEGER", "FLOAT").
330/// See [`crate::analyzer::helpers::infer_expr_type`] for supported expressions and behavior.
331fn infer_data_type(expr: &Expr) -> Option<String> {
332    infer_expr_type(expr).map(|canonical| canonical.as_uppercase_str().to_string())
333}
334
335/// Extracts SELECT aliases with their positions from parsed statements.
336///
337/// Returns aliases that appear in SELECT projections, along with the byte offset
338/// where each alias definition ends. This is used for lateral alias completion
339/// in dialects that support referencing earlier aliases in the same SELECT list.
340///
341/// # Arguments
342///
343/// * `statements` - Parsed SQL statements
344/// * `sql` - Original SQL source text (needed for line:column to byte offset conversion)
345///
346/// # Returns
347///
348/// A vector of lateral alias info, ordered by position in the source.
349/// Limited to `MAX_LATERAL_ALIASES` to prevent memory exhaustion.
350pub(crate) fn extract_lateral_aliases(
351    statements: &[Statement],
352    sql: &str,
353) -> Vec<LateralAliasInfo> {
354    let mut aliases = Vec::with_capacity(64); // Reasonable starting capacity
355
356    for stmt in statements {
357        // Stop extraction if we've hit the limit
358        if aliases.len() >= MAX_LATERAL_ALIASES {
359            break;
360        }
361
362        if let Statement::Query(query) = stmt {
363            // Extract from top-level CTEs first
364            if let Some(with) = &query.with {
365                for cte in &with.cte_tables {
366                    if aliases.len() >= MAX_LATERAL_ALIASES {
367                        break;
368                    }
369                    extract_lateral_aliases_from_set_expr(&cte.query.body, sql, &mut aliases, 0);
370                }
371            }
372            // Then extract from the main query body
373            if aliases.len() < MAX_LATERAL_ALIASES {
374                extract_lateral_aliases_from_set_expr(&query.body, sql, &mut aliases, 0);
375            }
376        }
377    }
378
379    aliases
380}
381
382/// Extract lateral aliases from a SetExpr (handles SELECT, nested Query, and set operations).
383///
384/// Note: This intentionally extracts from all SELECT projections, including CTEs.
385/// The filtering by cursor position happens in the consumer (context.rs) using
386/// the projection_start/projection_end fields.
387fn extract_lateral_aliases_from_set_expr(
388    set_expr: &SetExpr,
389    sql: &str,
390    aliases: &mut Vec<LateralAliasInfo>,
391    depth: usize,
392) {
393    if depth > MAX_EXTRACTION_DEPTH || aliases.len() >= MAX_LATERAL_ALIASES {
394        return;
395    }
396
397    match set_expr {
398        SetExpr::Select(select) => {
399            extract_lateral_aliases_from_select(select, sql, aliases);
400        }
401        SetExpr::Query(query) => {
402            // Handle CTEs within nested queries
403            if let Some(with) = &query.with {
404                for cte in &with.cte_tables {
405                    if aliases.len() >= MAX_LATERAL_ALIASES {
406                        break;
407                    }
408                    extract_lateral_aliases_from_set_expr(&cte.query.body, sql, aliases, depth + 1);
409                }
410            }
411            if aliases.len() < MAX_LATERAL_ALIASES {
412                extract_lateral_aliases_from_set_expr(&query.body, sql, aliases, depth + 1);
413            }
414        }
415        SetExpr::SetOperation { left, right, .. } => {
416            extract_lateral_aliases_from_set_expr(left, sql, aliases, depth + 1);
417            if aliases.len() < MAX_LATERAL_ALIASES {
418                extract_lateral_aliases_from_set_expr(right, sql, aliases, depth + 1);
419            }
420        }
421        _ => {}
422    }
423}
424
425/// Extract lateral aliases from a SELECT's projection.
426///
427/// Iterates over the projection items and records aliases with their positions,
428/// including the span of the containing SELECT projection for scope filtering.
429///
430/// # Safety Notes
431///
432/// The function validates that computed byte offsets are:
433/// 1. Within bounds of the SQL string
434/// 2. At valid UTF-8 character boundaries
435///
436/// Aliases with invalid offsets (e.g., from parser bugs or multi-byte char issues)
437/// are silently skipped to prevent panics.
438fn extract_lateral_aliases_from_select(
439    select: &Select,
440    sql: &str,
441    aliases: &mut Vec<LateralAliasInfo>,
442) {
443    // Early return if we've hit the extraction limit
444    if aliases.len() >= MAX_LATERAL_ALIASES {
445        return;
446    }
447
448    // Compute projection span from the first and last projection items
449    // This is used to filter aliases to only those in the same SELECT scope as cursor
450    let projection_span = compute_projection_span(select, sql);
451    let (projection_start, projection_end) = match projection_span {
452        Some((start, end)) => (start, end),
453        None => return, // Can't determine span, skip this SELECT
454    };
455
456    for item in &select.projection {
457        // Check limit on each iteration to stop early
458        if aliases.len() >= MAX_LATERAL_ALIASES {
459            break;
460        }
461
462        if let SelectItem::ExprWithAlias { alias, .. } = item {
463            // Convert line:column to byte offset
464            // sqlparser uses 1-indexed line/column numbers
465            if let Some(end_offset) = line_col_to_offset(
466                sql,
467                alias.span.end.line as usize,
468                alias.span.end.column as usize,
469            ) {
470                // Validate offset is within bounds and at a valid UTF-8 boundary
471                // This prevents panics on multi-byte characters (emoji, unicode identifiers)
472                if end_offset <= sql.len() && sql.is_char_boundary(end_offset) {
473                    aliases.push(LateralAliasInfo {
474                        name: alias.value.clone(),
475                        definition_end: end_offset,
476                        projection_start,
477                        projection_end,
478                    });
479                }
480            }
481        }
482    }
483}
484
485/// Compute the byte offset span of a SELECT's projection area.
486///
487/// The projection area starts at the first projection item and extends to:
488/// - The FROM clause (if present)
489/// - Or the end of the SELECT (if no FROM)
490///
491/// This allows cursor positions after the last alias (like `SELECT a AS x, |`)
492/// to still be considered "within" the projection area for lateral alias purposes.
493///
494/// Returns (start, end) offsets, or None if the span cannot be determined.
495fn compute_projection_span(select: &Select, sql: &str) -> Option<(usize, usize)> {
496    if select.projection.is_empty() {
497        return None;
498    }
499
500    // Find the first projection item that has a usable span. This skips leading
501    // wildcards (plain `*`), which deliberately return None in select_item_span.
502    // When every item lacks a span (e.g., `SELECT *`), fall back to the SELECT
503    // keyword span so we still have a safe projection boundary.
504    let first_span = select
505        .projection
506        .iter()
507        .filter_map(select_item_span)
508        .next()
509        .or_else(|| {
510            let span = select.span();
511            if span.start.line > 0 && span.start.column > 0 {
512                Some((span.start.line, span.start.column))
513            } else {
514                None
515            }
516        })?;
517    let start = line_col_to_offset(sql, first_span.0 as usize, first_span.1 as usize)?;
518
519    // Determine the end of the projection area
520    // Prefer FROM clause position if available, otherwise use last item
521    let end = if let Some(from_item) = select.from.first() {
522        // Use the start of the FROM clause as the end of projection area
523        compute_from_clause_start(from_item, sql).unwrap_or_else(|| {
524            // Fallback to last projection item
525            select
526                .projection
527                .last()
528                .and_then(|item| {
529                    let span = select_item_end_span(item)?;
530                    line_col_to_offset(sql, span.0 as usize, span.1 as usize)
531                })
532                .unwrap_or(sql.len())
533        })
534    } else {
535        // No FROM clause - use a large value to include trailing positions
536        // This handles cases like `SELECT a AS x, |` (no FROM yet)
537        sql.len()
538    };
539
540    // Validate spans
541    if start <= sql.len() && end <= sql.len() && start <= end {
542        Some((start, end))
543    } else {
544        None
545    }
546}
547
548/// Get the byte offset where the FROM clause starts.
549fn compute_from_clause_start(from_item: &TableWithJoins, sql: &str) -> Option<usize> {
550    // Get the span of the first table in FROM
551    let span = table_factor_span(&from_item.relation)?;
552    let table_start = line_col_to_offset(sql, span.0 as usize, span.1 as usize)?;
553
554    // Search backwards from table_start to find "FROM" keyword
555    // This gives us the actual start of the FROM clause
556    //
557    // Safety: We need to find a valid UTF-8 char boundary before slicing.
558    // saturating_sub(50) might land in the middle of a multi-byte character.
559    let search_start = find_char_boundary_before(sql, table_start.saturating_sub(50));
560    let search_area = &sql[search_start..table_start];
561
562    // Find "FROM" (case insensitive) using ASCII-only comparison.
563    // We avoid to_uppercase() because it can change string length for certain
564    // Unicode characters (e.g., German ß -> SS), causing position misalignment.
565    if let Some(pos) = rfind_ascii_case_insensitive(search_area, b"FROM") {
566        Some(search_start + pos)
567    } else {
568        // Fallback to table start if FROM not found
569        Some(table_start)
570    }
571}
572
/// Largest UTF-8 character boundary in `s` that is `<= pos`.
///
/// Positions at or beyond the end of `s` clamp to `s.len()`. Offset 0 is always a
/// boundary, so the result is never below zero.
fn find_char_boundary_before(s: &str, pos: usize) -> usize {
    if pos >= s.len() {
        return s.len();
    }

    let mut boundary = pos;
    while boundary > 0 && !s.is_char_boundary(boundary) {
        boundary -= 1;
    }
    boundary
}
585
/// Byte offset of the last ASCII case-insensitive occurrence of `needle` in `haystack`.
///
/// Returns `None` for an empty needle or when there is no match. Working on raw bytes
/// is sound for UTF-8 input: ASCII bytes (0x00-0x7F) never occur inside multi-byte
/// sequences, so a match of an ASCII needle always starts on a character boundary.
fn rfind_ascii_case_insensitive(haystack: &str, needle: &[u8]) -> Option<usize> {
    // An empty needle has no meaningful match, and `windows(0)` would panic.
    if needle.is_empty() {
        return None;
    }

    haystack
        .as_bytes()
        .windows(needle.len())
        .rposition(|window| window.eq_ignore_ascii_case(needle))
}
615
616/// Get the start position (line, column) of a TableFactor.
617fn table_factor_span(tf: &TableFactor) -> Option<(u64, u64)> {
618    match tf {
619        TableFactor::Table { name, .. } => name.0.first().map(|i| {
620            let span = i.span();
621            (span.start.line, span.start.column)
622        }),
623        TableFactor::Derived { subquery, .. } => {
624            // For derived tables, try to get the span of the subquery
625            let span = subquery.body.span();
626            if span.start.line > 0 {
627                Some((span.start.line, span.start.column))
628            } else {
629                None
630            }
631        }
632        _ => None,
633    }
634}
635
636/// Get the start position (line, column) of a SelectItem.
637fn select_item_span(item: &SelectItem) -> Option<(u64, u64)> {
638    match item {
639        SelectItem::ExprWithAlias { expr, .. } | SelectItem::UnnamedExpr(expr) => {
640            expr_start_span(expr)
641        }
642        SelectItem::Wildcard(opts) => {
643            // Wildcard span comes from the options if available
644            if let Some(exclude) = &opts.opt_exclude {
645                // Use first exclusion's span if available
646                match exclude {
647                    sqlparser::ast::ExcludeSelectItem::Single(ident) => {
648                        Some((ident.span.start.line, ident.span.start.column))
649                    }
650                    sqlparser::ast::ExcludeSelectItem::Multiple(idents) => idents
651                        .first()
652                        .map(|i| (i.span.start.line, i.span.start.column)),
653                }
654            } else {
655                None
656            }
657        }
658        SelectItem::QualifiedWildcard(name, _) => {
659            let span = name.span();
660            Some((span.start.line, span.start.column))
661        }
662    }
663}
664
665/// Get the end position (line, column) of a SelectItem.
666fn select_item_end_span(item: &SelectItem) -> Option<(u64, u64)> {
667    match item {
668        SelectItem::ExprWithAlias { alias, .. } => {
669            Some((alias.span.end.line, alias.span.end.column))
670        }
671        SelectItem::UnnamedExpr(expr) => expr_end_span(expr),
672        SelectItem::Wildcard(_) => None, // Wildcard doesn't have reliable end span
673        SelectItem::QualifiedWildcard(name, _) => {
674            let span = name.span();
675            Some((span.end.line, span.end.column))
676        }
677    }
678}
679
680/// Get the start position (line, column) of an expression.
681/// Uses the Spanned trait for comprehensive coverage of all expression types.
682fn expr_start_span(expr: &Expr) -> Option<(u64, u64)> {
683    let span = expr.span();
684    // Check for valid span (non-zero positions indicate valid span)
685    if span.start.line > 0 && span.start.column > 0 {
686        Some((span.start.line, span.start.column))
687    } else {
688        None
689    }
690}
691
692/// Get the end position (line, column) of an expression.
693/// Uses the Spanned trait for comprehensive coverage of all expression types.
694fn expr_end_span(expr: &Expr) -> Option<(u64, u64)> {
695    let span = expr.span();
696    // Check for valid span (non-zero positions indicate valid span)
697    if span.end.line > 0 && span.end.column > 0 {
698        Some((span.end.line, span.end.column))
699    } else {
700        None
701    }
702}
703
704#[cfg(test)]
705mod tests {
706    use super::*;
707    use sqlparser::parser::Parser;
708
709    fn parse_sql(sql: &str) -> Vec<Statement> {
710        Parser::parse_sql(&sqlparser::dialect::GenericDialect {}, sql).unwrap()
711    }
712
713    #[test]
714    fn test_extract_cte() {
715        let sql = "WITH cte AS (SELECT id, name FROM users) SELECT * FROM cte";
716        let stmts = parse_sql(sql);
717        let ctx = extract_ast_context(&stmts);
718
719        assert!(ctx.cte_definitions.contains_key("cte"));
720        let cte = &ctx.cte_definitions["cte"];
721        assert_eq!(cte.name, "cte");
722        assert_eq!(cte.projected_columns.len(), 2);
723        assert_eq!(cte.projected_columns[0].name, "id");
724        assert_eq!(cte.projected_columns[1].name, "name");
725    }
726
727    #[test]
728    fn test_extract_cte_with_declared_columns() {
729        let sql = "WITH cte(a, b) AS (SELECT id, name FROM users) SELECT * FROM cte";
730        let stmts = parse_sql(sql);
731        let ctx = extract_ast_context(&stmts);
732
733        let cte = &ctx.cte_definitions["cte"];
734        assert_eq!(cte.declared_columns, vec!["a", "b"]);
735    }
736
737    #[test]
738    fn test_extract_table_alias() {
739        let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id";
740        let stmts = parse_sql(sql);
741        let ctx = extract_ast_context(&stmts);
742
743        // Aliases are stored as keys in the table_aliases map
744        assert!(ctx.table_aliases.contains_key("u"));
745        assert!(ctx.table_aliases.contains_key("o"));
746    }
747
748    #[test]
749    fn test_extract_subquery_alias() {
750        let sql = "SELECT * FROM (SELECT a, b FROM t) AS sub WHERE sub.a = 1";
751        let stmts = parse_sql(sql);
752        let ctx = extract_ast_context(&stmts);
753
754        assert!(ctx.subquery_aliases.contains_key("sub"));
755        let sub = &ctx.subquery_aliases["sub"];
756        assert_eq!(sub.projected_columns.len(), 2);
757        assert_eq!(sub.projected_columns[0].name, "a");
758        assert_eq!(sub.projected_columns[1].name, "b");
759    }
760
761    #[test]
762    fn test_extract_lateral_subquery() {
763        let sql = "SELECT * FROM users u, LATERAL (SELECT * FROM orders WHERE user_id = u.id) AS o";
764        let stmts = parse_sql(sql);
765        let ctx = extract_ast_context(&stmts);
766
767        // Lateral subqueries are extracted just like regular derived tables
768        assert!(ctx.subquery_aliases.contains_key("o"));
769    }
770
771    #[test]
772    fn test_extract_column_with_alias() {
773        let sql =
774            "WITH cte AS (SELECT id AS user_id, name AS user_name FROM users) SELECT * FROM cte";
775        let stmts = parse_sql(sql);
776        let ctx = extract_ast_context(&stmts);
777
778        let cte = &ctx.cte_definitions["cte"];
779        assert_eq!(cte.projected_columns[0].name, "user_id");
780        assert_eq!(cte.projected_columns[1].name, "user_name");
781    }
782
783    #[test]
784    fn test_extract_function_column_name() {
785        let sql = "WITH cte AS (SELECT COUNT(*), SUM(amount) FROM orders) SELECT * FROM cte";
786        let stmts = parse_sql(sql);
787        let ctx = extract_ast_context(&stmts);
788
789        let cte = &ctx.cte_definitions["cte"];
790        assert!(cte.projected_columns[0]
791            .name
792            .to_lowercase()
793            .contains("count"));
794    }
795
796    #[test]
797    fn test_extract_wildcard() {
798        let sql = "WITH cte AS (SELECT * FROM users) SELECT * FROM cte";
799        let stmts = parse_sql(sql);
800        let ctx = extract_ast_context(&stmts);
801
802        let cte = &ctx.cte_definitions["cte"];
803        assert_eq!(cte.projected_columns[0].name, "*");
804    }
805
806    #[test]
807    fn test_extract_recursive_cte() {
808        let sql = r#"
809            WITH RECURSIVE cte AS (
810                SELECT 1 AS n
811                UNION ALL
812                SELECT n + 1 FROM cte WHERE n < 10
813            )
814            SELECT * FROM cte
815        "#;
816        let stmts = parse_sql(sql);
817        let ctx = extract_ast_context(&stmts);
818
819        let cte = &ctx.cte_definitions["cte"];
820        // Should have column from base case
821        assert_eq!(cte.projected_columns.len(), 1);
822        assert_eq!(cte.projected_columns[0].name, "n");
823    }
824
825    #[test]
826    fn test_has_enrichment() {
827        let sql = "SELECT * FROM users";
828        let stmts = parse_sql(sql);
829        let ctx = extract_ast_context(&stmts);
830
831        assert!(ctx.has_enrichment()); // Has table alias
832    }
833
834    #[test]
835    fn test_empty_context() {
836        let ctx = AstContext::default();
837        assert!(!ctx.has_enrichment());
838    }
839
840    // Lateral alias extraction tests
841
842    #[test]
843    fn test_extract_lateral_aliases_single() {
844        let sql = "SELECT price * qty AS total FROM orders";
845        let stmts = parse_sql(sql);
846        let aliases = extract_lateral_aliases(&stmts, sql);
847
848        assert_eq!(aliases.len(), 1);
849        assert_eq!(aliases[0].name, "total");
850        // The alias ends after "total" (position should be after the alias)
851        assert!(aliases[0].definition_end > 0);
852        assert!(aliases[0].definition_end <= sql.len());
853    }
854
855    #[test]
856    fn test_extract_lateral_aliases_with_leading_wildcard() {
857        let sql = "SELECT *, price * qty AS total, discount AS disc FROM orders";
858        let stmts = parse_sql(sql);
859        let aliases = extract_lateral_aliases(&stmts, sql);
860
861        let names: Vec<_> = aliases.iter().map(|a| a.name.as_str()).collect();
862        assert_eq!(names, vec!["total", "disc"]);
863    }
864
865    #[test]
866    fn test_extract_lateral_aliases_multiple() {
867        let sql = "SELECT a AS x, b AS y, c AS z FROM t";
868        let stmts = parse_sql(sql);
869        let aliases = extract_lateral_aliases(&stmts, sql);
870
871        assert_eq!(aliases.len(), 3);
872        assert_eq!(aliases[0].name, "x");
873        assert_eq!(aliases[1].name, "y");
874        assert_eq!(aliases[2].name, "z");
875        // Aliases should be ordered by position
876        assert!(aliases[0].definition_end < aliases[1].definition_end);
877        assert!(aliases[1].definition_end < aliases[2].definition_end);
878    }
879
880    #[test]
881    fn test_extract_lateral_aliases_with_expression() {
882        let sql = "SELECT price * qty AS total, total * 0.1 AS tax FROM orders";
883        let stmts = parse_sql(sql);
884        let aliases = extract_lateral_aliases(&stmts, sql);
885
886        assert_eq!(aliases.len(), 2);
887        assert_eq!(aliases[0].name, "total");
888        assert_eq!(aliases[1].name, "tax");
889    }
890
891    #[test]
892    fn test_extract_lateral_aliases_no_aliases() {
893        let sql = "SELECT price, qty FROM orders";
894        let stmts = parse_sql(sql);
895        let aliases = extract_lateral_aliases(&stmts, sql);
896
897        assert!(aliases.is_empty());
898    }
899
900    #[test]
901    fn test_extract_lateral_aliases_mixed() {
902        // Mix of aliased and non-aliased columns
903        let sql = "SELECT a, b AS alias_b, c FROM t";
904        let stmts = parse_sql(sql);
905        let aliases = extract_lateral_aliases(&stmts, sql);
906
907        assert_eq!(aliases.len(), 1);
908        assert_eq!(aliases[0].name, "alias_b");
909    }
910
911    #[test]
912    fn test_extract_lateral_aliases_quoted() {
913        let sql = r#"SELECT a AS "My Total", b AS "Tax Amount" FROM t"#;
914        let stmts = parse_sql(sql);
915        let aliases = extract_lateral_aliases(&stmts, sql);
916
917        assert_eq!(aliases.len(), 2);
918        assert_eq!(aliases[0].name, "My Total");
919        assert_eq!(aliases[1].name, "Tax Amount");
920    }
921
922    #[test]
923    fn test_extract_lateral_aliases_subquery_in_from() {
924        // Aliases in subqueries in FROM clause should NOT be extracted,
925        // because lateral aliases are only visible within the same SELECT list
926        let sql = "SELECT * FROM (SELECT a AS x, b AS y FROM t) sub";
927        let stmts = parse_sql(sql);
928        let aliases = extract_lateral_aliases(&stmts, sql);
929
930        // The outer SELECT has SELECT * which has no aliases
931        assert_eq!(aliases.len(), 0);
932    }
933
934    #[test]
935    fn test_extract_lateral_aliases_outer_select_with_alias() {
936        // If the outer SELECT has aliases, those should be extracted
937        let sql = "SELECT sub.x AS outer_x FROM (SELECT a AS x FROM t) sub";
938        let stmts = parse_sql(sql);
939        let aliases = extract_lateral_aliases(&stmts, sql);
940
941        assert_eq!(aliases.len(), 1);
942        assert_eq!(aliases[0].name, "outer_x");
943    }
944
945    #[test]
946    fn test_extract_lateral_aliases_with_unicode() {
947        // Unicode characters in SQL should not cause panics
948        // The parser may report byte offsets that don't align with char boundaries
949        let sql = "SELECT '日本語' AS label, value AS val FROM t";
950        let stmts = parse_sql(sql);
951        let aliases = extract_lateral_aliases(&stmts, sql);
952
953        // Should successfully extract aliases even with unicode in the SQL
954        assert_eq!(aliases.len(), 2);
955        assert_eq!(aliases[0].name, "label");
956        assert_eq!(aliases[1].name, "val");
957    }
958
959    #[test]
960    fn test_extract_lateral_aliases_cte_scope_isolation() {
961        // Aliases from CTE's SELECT should have different projection spans
962        // than aliases from the outer SELECT
963        let sql =
964            "WITH cte AS (SELECT a AS inner_alias FROM t) SELECT cte.a AS outer_alias FROM cte";
965        let stmts = parse_sql(sql);
966        let aliases = extract_lateral_aliases(&stmts, sql);
967
968        // Should extract both aliases, but they should have different projection spans
969        assert_eq!(aliases.len(), 2);
970
971        let inner = aliases.iter().find(|a| a.name == "inner_alias").unwrap();
972        let outer = aliases.iter().find(|a| a.name == "outer_alias").unwrap();
973
974        // Inner alias projection should be before outer alias projection
975        assert!(
976            inner.projection_start < outer.projection_start,
977            "CTE projection should start before outer SELECT projection"
978        );
979
980        // Projection spans should not overlap significantly
981        assert!(
982            inner.projection_end < outer.projection_start
983                || outer.projection_end < inner.projection_start
984                || inner.projection_start != outer.projection_start,
985            "CTE and outer SELECT projections should have different spans"
986        );
987    }
988
989    #[test]
990    fn test_extract_lateral_aliases_projection_span_validity() {
991        // Verify that projection spans are valid and contain the alias
992        let sql = "SELECT a AS x, b AS y FROM t";
993        let stmts = parse_sql(sql);
994        let aliases = extract_lateral_aliases(&stmts, sql);
995
996        assert_eq!(aliases.len(), 2);
997
998        for alias in &aliases {
999            // Projection span should contain the alias definition
1000            assert!(
1001                alias.definition_end <= alias.projection_end,
1002                "Alias definition should be within projection span"
1003            );
1004            assert!(
1005                alias.projection_start < alias.definition_end,
1006                "Projection should start before alias definition ends"
1007            );
1008        }
1009    }
1010}