Skip to main content

polyglot_sql/
lineage.rs

1//! Column Lineage Tracking
2//!
3//! This module provides functionality to track column lineage through SQL queries,
4//! building a graph of how columns flow from source tables to the result set.
5//! Supports UNION/INTERSECT/EXCEPT, CTEs, derived tables, subqueries, and star expansion.
6//!
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::optimizer::annotate_types::annotate_types;
11use crate::optimizer::qualify_columns::{qualify_columns, QualifyColumnsOptions};
12use crate::schema::{normalize_name, Schema};
13use crate::scope::{build_scope, Scope};
14use crate::traversal::ExpressionWalk;
15use crate::{Error, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::HashSet;
18
19/// A node in the column lineage graph
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct LineageNode {
22    /// Name of this lineage step (e.g., "table.column")
23    pub name: String,
24    /// The expression at this node
25    pub expression: Expression,
26    /// The source expression (the full query context)
27    pub source: Expression,
28    /// Downstream nodes that depend on this one
29    pub downstream: Vec<LineageNode>,
30    /// Optional source name (e.g., for derived tables)
31    pub source_name: String,
32    /// Optional reference node name (e.g., for CTEs)
33    pub reference_node_name: String,
34}
35
36impl LineageNode {
37    /// Create a new lineage node
38    pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
39        Self {
40            name: name.into(),
41            expression,
42            source,
43            downstream: Vec::new(),
44            source_name: String::new(),
45            reference_node_name: String::new(),
46        }
47    }
48
49    /// Iterate over all nodes in the lineage graph using DFS
50    pub fn walk(&self) -> LineageWalker<'_> {
51        LineageWalker { stack: vec![self] }
52    }
53
54    /// Get all downstream column names
55    pub fn downstream_names(&self) -> Vec<String> {
56        self.downstream.iter().map(|n| n.name.clone()).collect()
57    }
58}
59
60/// Iterator for walking the lineage graph
61pub struct LineageWalker<'a> {
62    stack: Vec<&'a LineageNode>,
63}
64
65impl<'a> Iterator for LineageWalker<'a> {
66    type Item = &'a LineageNode;
67
68    fn next(&mut self) -> Option<Self::Item> {
69        if let Some(node) = self.stack.pop() {
70            // Add children in reverse order so they're visited in order
71            for child in node.downstream.iter().rev() {
72                self.stack.push(child);
73            }
74            Some(node)
75        } else {
76            None
77        }
78    }
79}
80
81// ---------------------------------------------------------------------------
82// ColumnRef: name or positional index for column lookup
83// ---------------------------------------------------------------------------
84
/// Column reference for lineage tracing — by name or positional index.
///
/// The public entry points start from a name; recursion into set-operation
/// branches and scalar subqueries uses a position instead.
enum ColumnRef<'a> {
    /// Look the column up by its alias or name in the SELECT list.
    Name(&'a str),
    /// Zero-based position in the SELECT list.
    Index(usize),
}
90
91// ---------------------------------------------------------------------------
92// Public API
93// ---------------------------------------------------------------------------
94
/// Build the lineage graph for a column in a SQL query
///
/// # Arguments
/// * `column` - The column name to trace lineage for
/// * `sql` - The already-parsed SQL expression (SELECT, UNION, etc.)
/// * `dialect` - Optional dialect, forwarded to the lineage builder (this
///   function does not parse anything itself)
/// * `trim_selects` - If true, trim the source SELECT to only include the target column
///
/// # Returns
/// The root lineage node for the specified column
///
/// # Errors
/// Propagates any error produced while building the lineage graph.
///
/// # Example
/// ```ignore
/// use polyglot_sql::lineage::lineage;
/// use polyglot_sql::parse_one;
/// use polyglot_sql::DialectType;
///
/// let sql = "SELECT a, b + 1 AS c FROM t";
/// let expr = parse_one(sql, DialectType::Generic).unwrap();
/// let node = lineage("c", &expr, None, false).unwrap();
/// ```
pub fn lineage(
    column: &str,
    sql: &Expression,
    dialect: Option<DialectType>,
    trim_selects: bool,
) -> Result<LineageNode> {
    lineage_from_expression(column, sql, dialect, trim_selects)
}
124
125/// Build the lineage graph for a column in a SQL query using optional schema metadata.
126///
127/// When `schema` is provided, the query is first qualified with
128/// `optimizer::qualify_columns`, allowing more accurate lineage for unqualified or
129/// ambiguous column references.
130///
131/// # Arguments
132/// * `column` - The column name to trace lineage for
133/// * `sql` - The SQL expression (SELECT, UNION, etc.)
134/// * `schema` - Optional schema used for qualification
135/// * `dialect` - Optional dialect for qualification and lineage handling
136/// * `trim_selects` - If true, trim the source SELECT to only include the target column
137///
138/// # Returns
139/// The root lineage node for the specified column
140pub fn lineage_with_schema(
141    column: &str,
142    sql: &Expression,
143    schema: Option<&dyn Schema>,
144    dialect: Option<DialectType>,
145    trim_selects: bool,
146) -> Result<LineageNode> {
147    let mut qualified_expression = if let Some(schema) = schema {
148        let options = if let Some(dialect_type) = dialect.or_else(|| schema.dialect()) {
149            QualifyColumnsOptions::new().with_dialect(dialect_type)
150        } else {
151            QualifyColumnsOptions::new()
152        };
153
154        qualify_columns(sql.clone(), schema, &options).map_err(|e| {
155            Error::internal(format!("Lineage qualification failed with schema: {}", e))
156        })?
157    } else {
158        sql.clone()
159    };
160
161    // Annotate types in-place so lineage nodes carry type information
162    annotate_types(&mut qualified_expression, schema, dialect);
163
164    lineage_from_expression(column, &qualified_expression, dialect, trim_selects)
165}
166
167fn lineage_from_expression(
168    column: &str,
169    sql: &Expression,
170    dialect: Option<DialectType>,
171    trim_selects: bool,
172) -> Result<LineageNode> {
173    let scope = build_scope(sql);
174    to_node(
175        ColumnRef::Name(column),
176        &scope,
177        dialect,
178        "",
179        "",
180        "",
181        trim_selects,
182    )
183}
184
185/// Get all source tables from a lineage graph
186pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
187    let mut tables = HashSet::new();
188    collect_source_tables(node, &mut tables);
189    tables
190}
191
192/// Recursively collect source table names from lineage graph
193pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
194    if let Expression::Table(table) = &node.source {
195        tables.insert(table.name.name.clone());
196    }
197    for child in &node.downstream {
198        collect_source_tables(child, tables);
199    }
200}
201
202// ---------------------------------------------------------------------------
203// Core recursive lineage builder
204// ---------------------------------------------------------------------------
205
206/// Recursively build a lineage node for a column in a scope.
207fn to_node(
208    column: ColumnRef<'_>,
209    scope: &Scope,
210    dialect: Option<DialectType>,
211    scope_name: &str,
212    source_name: &str,
213    reference_node_name: &str,
214    trim_selects: bool,
215) -> Result<LineageNode> {
216    to_node_inner(
217        column,
218        scope,
219        dialect,
220        scope_name,
221        source_name,
222        reference_node_name,
223        trim_selects,
224        &[],
225    )
226}
227
228fn to_node_inner(
229    column: ColumnRef<'_>,
230    scope: &Scope,
231    dialect: Option<DialectType>,
232    scope_name: &str,
233    source_name: &str,
234    reference_node_name: &str,
235    trim_selects: bool,
236    ancestor_cte_scopes: &[Scope],
237) -> Result<LineageNode> {
238    let scope_expr = &scope.expression;
239
240    // Build combined CTE scopes: current scope's cte_scopes + ancestors
241    let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
242    for s in ancestor_cte_scopes {
243        all_cte_scopes.push(s);
244    }
245
246    // 0. Unwrap CTE scope — CTE scope expressions are Expression::Cte(...)
247    //    but we need the inner query (SELECT/UNION) for column lookup.
248    let effective_expr = match scope_expr {
249        Expression::Cte(cte) => &cte.this,
250        other => other,
251    };
252
253    // 1. Set operations (UNION / INTERSECT / EXCEPT)
254    if matches!(
255        effective_expr,
256        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
257    ) {
258        // For CTE wrapping a set op, create a temporary scope with the inner expression
259        if matches!(scope_expr, Expression::Cte(_)) {
260            let mut inner_scope = Scope::new(effective_expr.clone());
261            inner_scope.union_scopes = scope.union_scopes.clone();
262            inner_scope.sources = scope.sources.clone();
263            inner_scope.cte_sources = scope.cte_sources.clone();
264            inner_scope.cte_scopes = scope.cte_scopes.clone();
265            inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
266            inner_scope.subquery_scopes = scope.subquery_scopes.clone();
267            return handle_set_operation(
268                &column,
269                &inner_scope,
270                dialect,
271                scope_name,
272                source_name,
273                reference_node_name,
274                trim_selects,
275                ancestor_cte_scopes,
276            );
277        }
278        return handle_set_operation(
279            &column,
280            scope,
281            dialect,
282            scope_name,
283            source_name,
284            reference_node_name,
285            trim_selects,
286            ancestor_cte_scopes,
287        );
288    }
289
290    // 2. Find the select expression for this column
291    let select_expr = find_select_expr(effective_expr, &column, dialect)?;
292    let column_name = resolve_column_name(&column, &select_expr);
293
294    // 3. Trim source if requested
295    let node_source = if trim_selects {
296        trim_source(effective_expr, &select_expr)
297    } else {
298        effective_expr.clone()
299    };
300
301    // 4. Create the lineage node
302    let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
303    node.source_name = source_name.to_string();
304    node.reference_node_name = reference_node_name.to_string();
305
306    // 5. Star handling — add downstream for each source
307    if matches!(&select_expr, Expression::Star(_)) {
308        for (name, source_info) in &scope.sources {
309            let child = LineageNode::new(
310                format!("{}.*", name),
311                Expression::Star(crate::expressions::Star {
312                    table: None,
313                    except: None,
314                    replace: None,
315                    rename: None,
316                    trailing_comments: vec![],
317                    span: None,
318                }),
319                source_info.expression.clone(),
320            );
321            node.downstream.push(child);
322        }
323        return Ok(node);
324    }
325
326    // 6. Subqueries in select — trace through scalar subqueries
327    let subqueries: Vec<&Expression> =
328        select_expr.find_all(|e| matches!(e, Expression::Subquery(sq) if sq.alias.is_none()));
329    for sq_expr in subqueries {
330        if let Expression::Subquery(sq) = sq_expr {
331            for sq_scope in &scope.subquery_scopes {
332                if sq_scope.expression == sq.this {
333                    if let Ok(child) = to_node_inner(
334                        ColumnRef::Index(0),
335                        sq_scope,
336                        dialect,
337                        &column_name,
338                        "",
339                        "",
340                        trim_selects,
341                        ancestor_cte_scopes,
342                    ) {
343                        node.downstream.push(child);
344                    }
345                    break;
346                }
347            }
348        }
349    }
350
351    // 7. Column references — trace each column to its source
352    let col_refs = find_column_refs_in_expr(&select_expr);
353    for col_ref in col_refs {
354        let col_name = &col_ref.column;
355        if let Some(ref table_id) = col_ref.table {
356            let tbl = &table_id.name;
357            resolve_qualified_column(
358                &mut node,
359                scope,
360                dialect,
361                tbl,
362                col_name,
363                &column_name,
364                trim_selects,
365                &all_cte_scopes,
366            );
367        } else {
368            resolve_unqualified_column(
369                &mut node,
370                scope,
371                dialect,
372                col_name,
373                &column_name,
374                trim_selects,
375                &all_cte_scopes,
376            );
377        }
378    }
379
380    Ok(node)
381}
382
383// ---------------------------------------------------------------------------
384// Set operation handling
385// ---------------------------------------------------------------------------
386
387fn handle_set_operation(
388    column: &ColumnRef<'_>,
389    scope: &Scope,
390    dialect: Option<DialectType>,
391    scope_name: &str,
392    source_name: &str,
393    reference_node_name: &str,
394    trim_selects: bool,
395    ancestor_cte_scopes: &[Scope],
396) -> Result<LineageNode> {
397    let scope_expr = &scope.expression;
398
399    // Determine column index
400    let col_index = match column {
401        ColumnRef::Name(name) => column_to_index(scope_expr, name, dialect)?,
402        ColumnRef::Index(i) => *i,
403    };
404
405    let col_name = match column {
406        ColumnRef::Name(name) => name.to_string(),
407        ColumnRef::Index(_) => format!("_{col_index}"),
408    };
409
410    let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
411    node.source_name = source_name.to_string();
412    node.reference_node_name = reference_node_name.to_string();
413
414    // Recurse into each union branch
415    for branch_scope in &scope.union_scopes {
416        if let Ok(child) = to_node_inner(
417            ColumnRef::Index(col_index),
418            branch_scope,
419            dialect,
420            scope_name,
421            "",
422            "",
423            trim_selects,
424            ancestor_cte_scopes,
425        ) {
426            node.downstream.push(child);
427        }
428    }
429
430    Ok(node)
431}
432
433// ---------------------------------------------------------------------------
434// Column resolution helpers
435// ---------------------------------------------------------------------------
436
437fn resolve_qualified_column(
438    node: &mut LineageNode,
439    scope: &Scope,
440    dialect: Option<DialectType>,
441    table: &str,
442    col_name: &str,
443    parent_name: &str,
444    trim_selects: bool,
445    all_cte_scopes: &[&Scope],
446) {
447    // Check if table is a CTE reference (cte_sources tracks CTE names)
448    if scope.cte_sources.contains_key(table) {
449        if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
450            // Build ancestor CTE scopes from all_cte_scopes for the recursive call
451            let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
452            if let Ok(child) = to_node_inner(
453                ColumnRef::Name(col_name),
454                child_scope,
455                dialect,
456                parent_name,
457                table,
458                parent_name,
459                trim_selects,
460                &ancestors,
461            ) {
462                node.downstream.push(child);
463                return;
464            }
465        }
466    }
467
468    // Check if table is a derived table (is_scope = true in sources)
469    if let Some(source_info) = scope.sources.get(table) {
470        if source_info.is_scope {
471            if let Some(child_scope) = find_child_scope(scope, table) {
472                let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
473                if let Ok(child) = to_node_inner(
474                    ColumnRef::Name(col_name),
475                    child_scope,
476                    dialect,
477                    parent_name,
478                    table,
479                    parent_name,
480                    trim_selects,
481                    &ancestors,
482                ) {
483                    node.downstream.push(child);
484                    return;
485                }
486            }
487        }
488    }
489
490    // Base table source found in current scope: preserve alias in the display name
491    // but store the resolved table expression and name for downstream consumers.
492    if let Some(source_info) = scope.sources.get(table) {
493        if !source_info.is_scope {
494            node.downstream.push(make_table_column_node_from_source(
495                table,
496                col_name,
497                &source_info.expression,
498            ));
499            return;
500        }
501    }
502
503    // Base table or unresolved — terminal node
504    node.downstream
505        .push(make_table_column_node(table, col_name));
506}
507
508fn resolve_unqualified_column(
509    node: &mut LineageNode,
510    scope: &Scope,
511    dialect: Option<DialectType>,
512    col_name: &str,
513    parent_name: &str,
514    trim_selects: bool,
515    all_cte_scopes: &[&Scope],
516) {
517    // Try to find which source this column belongs to.
518    // Build the source list from the actual FROM/JOIN clauses to avoid
519    // mixing in CTE definitions that are in scope but not referenced.
520    let from_source_names = source_names_from_from_join(scope);
521
522    if from_source_names.len() == 1 {
523        let tbl = &from_source_names[0];
524        resolve_qualified_column(
525            node,
526            scope,
527            dialect,
528            tbl,
529            col_name,
530            parent_name,
531            trim_selects,
532            all_cte_scopes,
533        );
534        return;
535    }
536
537    // Multiple sources — can't resolve without schema info, add unqualified node
538    let child = LineageNode::new(
539        col_name.to_string(),
540        Expression::Column(crate::expressions::Column {
541            name: crate::expressions::Identifier::new(col_name.to_string()),
542            table: None,
543            join_mark: false,
544            trailing_comments: vec![],
545            span: None,
546            inferred_type: None,
547        }),
548        node.source.clone(),
549    );
550    node.downstream.push(child);
551}
552
553fn source_names_from_from_join(scope: &Scope) -> Vec<String> {
554    fn source_name(expr: &Expression) -> Option<String> {
555        match expr {
556            Expression::Table(table) => Some(
557                table
558                    .alias
559                    .as_ref()
560                    .map(|a| a.name.clone())
561                    .unwrap_or_else(|| table.name.name.clone()),
562            ),
563            Expression::Subquery(subquery) => {
564                subquery.alias.as_ref().map(|alias| alias.name.clone())
565            }
566            Expression::Paren(paren) => source_name(&paren.this),
567            _ => None,
568        }
569    }
570
571    let effective_expr = match &scope.expression {
572        Expression::Cte(cte) => &cte.this,
573        expr => expr,
574    };
575
576    let mut names = Vec::new();
577    let mut seen = std::collections::HashSet::new();
578
579    if let Expression::Select(select) = effective_expr {
580        if let Some(from) = &select.from {
581            for expr in &from.expressions {
582                if let Some(name) = source_name(expr) {
583                    if !name.is_empty() && seen.insert(name.clone()) {
584                        names.push(name);
585                    }
586                }
587            }
588        }
589        for join in &select.joins {
590            if let Some(name) = source_name(&join.this) {
591                if !name.is_empty() && seen.insert(name.clone()) {
592                    names.push(name);
593                }
594            }
595        }
596    }
597
598    names
599}
600
601// ---------------------------------------------------------------------------
602// Helper functions
603// ---------------------------------------------------------------------------
604
605/// Get the alias or name of an expression
606fn get_alias_or_name(expr: &Expression) -> Option<String> {
607    match expr {
608        Expression::Alias(alias) => Some(alias.alias.name.clone()),
609        Expression::Column(col) => Some(col.name.name.clone()),
610        Expression::Identifier(id) => Some(id.name.clone()),
611        Expression::Star(_) => Some("*".to_string()),
612        _ => None,
613    }
614}
615
616/// Resolve the display name for a column reference.
617fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
618    match column {
619        ColumnRef::Name(n) => n.to_string(),
620        ColumnRef::Index(_) => get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string()),
621    }
622}
623
624/// Find the select expression matching a column reference.
625fn find_select_expr(
626    scope_expr: &Expression,
627    column: &ColumnRef<'_>,
628    dialect: Option<DialectType>,
629) -> Result<Expression> {
630    if let Expression::Select(ref select) = scope_expr {
631        match column {
632            ColumnRef::Name(name) => {
633                let normalized_name = normalize_column_name(name, dialect);
634                for expr in &select.expressions {
635                    if let Some(alias_or_name) = get_alias_or_name(expr) {
636                        if normalize_column_name(&alias_or_name, dialect) == normalized_name {
637                            return Ok(expr.clone());
638                        }
639                    }
640                }
641                Err(crate::error::Error::parse(
642                    format!("Cannot find column '{}' in query", name),
643                    0,
644                    0,
645                    0,
646                    0,
647                ))
648            }
649            ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
650                crate::error::Error::parse(format!("Column index {} out of range", idx), 0, 0, 0, 0)
651            }),
652        }
653    } else {
654        Err(crate::error::Error::parse(
655            "Expected SELECT expression for column lookup",
656            0,
657            0,
658            0,
659            0,
660        ))
661    }
662}
663
664/// Find the positional index of a column name in a set operation's first SELECT branch.
665fn column_to_index(
666    set_op_expr: &Expression,
667    name: &str,
668    dialect: Option<DialectType>,
669) -> Result<usize> {
670    let normalized_name = normalize_column_name(name, dialect);
671    let mut expr = set_op_expr;
672    loop {
673        match expr {
674            Expression::Union(u) => expr = &u.left,
675            Expression::Intersect(i) => expr = &i.left,
676            Expression::Except(e) => expr = &e.left,
677            Expression::Select(select) => {
678                for (i, e) in select.expressions.iter().enumerate() {
679                    if let Some(alias_or_name) = get_alias_or_name(e) {
680                        if normalize_column_name(&alias_or_name, dialect) == normalized_name {
681                            return Ok(i);
682                        }
683                    }
684                }
685                return Err(crate::error::Error::parse(
686                    format!("Cannot find column '{}' in set operation", name),
687                    0,
688                    0,
689                    0,
690                    0,
691                ));
692            }
693            _ => {
694                return Err(crate::error::Error::parse(
695                    "Expected SELECT or set operation",
696                    0,
697                    0,
698                    0,
699                    0,
700                ))
701            }
702        }
703    }
704}
705
/// Normalize a column name for comparison by delegating to `normalize_name`
/// with the given dialect.
// NOTE(review): the two positional boolean flags are defined by
// `schema::normalize_name` — confirm their meaning against its signature.
fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
    normalize_name(name, dialect, false, true)
}
709
710/// If trim_selects is enabled, return a copy of the SELECT with only the target column.
711fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
712    if let Expression::Select(select) = select_expr {
713        let mut trimmed = select.as_ref().clone();
714        trimmed.expressions = vec![target_expr.clone()];
715        Expression::Select(Box::new(trimmed))
716    } else {
717        select_expr.clone()
718    }
719}
720
721/// Find the child scope (CTE or derived table) for a given source name.
722fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
723    // Check CTE scopes
724    if scope.cte_sources.contains_key(source_name) {
725        for cte_scope in &scope.cte_scopes {
726            if let Expression::Cte(cte) = &cte_scope.expression {
727                if cte.alias.name == source_name {
728                    return Some(cte_scope);
729                }
730            }
731        }
732    }
733
734    // Check derived table scopes
735    if let Some(source_info) = scope.sources.get(source_name) {
736        if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
737            if let Expression::Subquery(sq) = &source_info.expression {
738                for dt_scope in &scope.derived_table_scopes {
739                    if dt_scope.expression == sq.this {
740                        return Some(dt_scope);
741                    }
742                }
743            }
744        }
745    }
746
747    None
748}
749
750/// Find a CTE scope by name, searching through a combined list of CTE scopes.
751/// This handles nested CTEs where the current scope doesn't have the CTE scope
752/// as a direct child but knows about it via cte_sources.
753fn find_child_scope_in<'a>(
754    all_cte_scopes: &[&'a Scope],
755    scope: &'a Scope,
756    source_name: &str,
757) -> Option<&'a Scope> {
758    // First try the scope's own cte_scopes
759    for cte_scope in &scope.cte_scopes {
760        if let Expression::Cte(cte) = &cte_scope.expression {
761            if cte.alias.name == source_name {
762                return Some(cte_scope);
763            }
764        }
765    }
766
767    // Then search through all ancestor CTE scopes
768    for cte_scope in all_cte_scopes {
769        if let Expression::Cte(cte) = &cte_scope.expression {
770            if cte.alias.name == source_name {
771                return Some(cte_scope);
772            }
773        }
774    }
775
776    // Fall back to derived table scopes
777    if let Some(source_info) = scope.sources.get(source_name) {
778        if source_info.is_scope {
779            if let Expression::Subquery(sq) = &source_info.expression {
780                for dt_scope in &scope.derived_table_scopes {
781                    if dt_scope.expression == sq.this {
782                        return Some(dt_scope);
783                    }
784                }
785            }
786        }
787    }
788
789    None
790}
791
792/// Create a terminal lineage node for a table.column reference.
793fn make_table_column_node(table: &str, column: &str) -> LineageNode {
794    let mut node = LineageNode::new(
795        format!("{}.{}", table, column),
796        Expression::Column(crate::expressions::Column {
797            name: crate::expressions::Identifier::new(column.to_string()),
798            table: Some(crate::expressions::Identifier::new(table.to_string())),
799            join_mark: false,
800            trailing_comments: vec![],
801            span: None,
802            inferred_type: None,
803        }),
804        Expression::Table(crate::expressions::TableRef::new(table)),
805    );
806    node.source_name = table.to_string();
807    node
808}
809
810fn table_name_from_table_ref(table_ref: &crate::expressions::TableRef) -> String {
811    let mut parts: Vec<String> = Vec::new();
812    if let Some(catalog) = &table_ref.catalog {
813        parts.push(catalog.name.clone());
814    }
815    if let Some(schema) = &table_ref.schema {
816        parts.push(schema.name.clone());
817    }
818    parts.push(table_ref.name.name.clone());
819    parts.join(".")
820}
821
822fn make_table_column_node_from_source(
823    table_alias: &str,
824    column: &str,
825    source: &Expression,
826) -> LineageNode {
827    let mut node = LineageNode::new(
828        format!("{}.{}", table_alias, column),
829        Expression::Column(crate::expressions::Column {
830            name: crate::expressions::Identifier::new(column.to_string()),
831            table: Some(crate::expressions::Identifier::new(table_alias.to_string())),
832            join_mark: false,
833            trailing_comments: vec![],
834            span: None,
835            inferred_type: None,
836        }),
837        source.clone(),
838    );
839
840    if let Expression::Table(table_ref) = source {
841        node.source_name = table_name_from_table_ref(table_ref);
842    } else {
843        node.source_name = table_alias.to_string();
844    }
845
846    node
847}
848
/// Simple column reference extracted from an expression
#[derive(Debug, Clone)]
struct SimpleColumnRef {
    /// Table qualifier copied from the column expression; `None` when the
    /// reference is unqualified.
    table: Option<crate::expressions::Identifier>,
    /// Column name copied from the column expression.
    column: String,
}
855
856/// Find all column references in an expression (does not recurse into subqueries).
857fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
858    let mut refs = Vec::new();
859    collect_column_refs(expr, &mut refs);
860    refs
861}
862
/// Append a [`SimpleColumnRef`] to `refs` for every `Expression::Column`
/// reachable from `expr`.
///
/// The walk is iterative: child expressions are pushed onto an explicit stack
/// and popped in LIFO order. As a consequence the last child pushed by an arm
/// is visited first (for a binary operator the right operand is visited before
/// the left), so `refs` is not guaranteed to be in source order. Duplicates are
/// not removed.
///
/// `Expression::Subquery` and `Expression::Exists` act as boundaries and are
/// never descended into. Any variant without an explicit arm falls through to
/// the final catch-all and contributes nothing.
fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
    // Work list of expressions still to be inspected.
    let mut stack: Vec<&Expression> = vec![expr];

    while let Some(current) = stack.pop() {
        match current {
            // === Leaf: collect Column references ===
            Expression::Column(col) => {
                refs.push(SimpleColumnRef {
                    table: col.table.clone(),
                    column: col.name.name.clone(),
                });
            }

            // === Boundary: don't recurse into subqueries (handled separately) ===
            Expression::Subquery(_) | Expression::Exists(_) => {}

            // === BinaryOp variants: left, right ===
            Expression::And(op)
            | Expression::Or(op)
            | Expression::Eq(op)
            | Expression::Neq(op)
            | Expression::Lt(op)
            | Expression::Lte(op)
            | Expression::Gt(op)
            | Expression::Gte(op)
            | Expression::Add(op)
            | Expression::Sub(op)
            | Expression::Mul(op)
            | Expression::Div(op)
            | Expression::Mod(op)
            | Expression::BitwiseAnd(op)
            | Expression::BitwiseOr(op)
            | Expression::BitwiseXor(op)
            | Expression::BitwiseLeftShift(op)
            | Expression::BitwiseRightShift(op)
            | Expression::Concat(op)
            | Expression::Adjacent(op)
            | Expression::TsMatch(op)
            | Expression::PropertyEQ(op)
            | Expression::ArrayContainsAll(op)
            | Expression::ArrayContainedBy(op)
            | Expression::ArrayOverlaps(op)
            | Expression::JSONBContainsAllTopKeys(op)
            | Expression::JSONBContainsAnyTopKeys(op)
            | Expression::JSONBDeleteAtPath(op)
            | Expression::ExtendsLeft(op)
            | Expression::ExtendsRight(op)
            | Expression::Is(op)
            | Expression::MemberOf(op)
            | Expression::NullSafeEq(op)
            | Expression::NullSafeNeq(op)
            | Expression::Glob(op)
            | Expression::Match(op) => {
                stack.push(&op.left);
                stack.push(&op.right);
            }

            // === UnaryOp variants: this ===
            Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
                stack.push(&u.this);
            }

            // === UnaryFunc variants: this ===
            Expression::Upper(f)
            | Expression::Lower(f)
            | Expression::Length(f)
            | Expression::LTrim(f)
            | Expression::RTrim(f)
            | Expression::Reverse(f)
            | Expression::Abs(f)
            | Expression::Sqrt(f)
            | Expression::Cbrt(f)
            | Expression::Ln(f)
            | Expression::Exp(f)
            | Expression::Sign(f)
            | Expression::Date(f)
            | Expression::Time(f)
            | Expression::DateFromUnixDate(f)
            | Expression::UnixDate(f)
            | Expression::UnixSeconds(f)
            | Expression::UnixMillis(f)
            | Expression::UnixMicros(f)
            | Expression::TimeStrToDate(f)
            | Expression::DateToDi(f)
            | Expression::DiToDate(f)
            | Expression::TsOrDiToDi(f)
            | Expression::TsOrDsToDatetime(f)
            | Expression::TsOrDsToTimestamp(f)
            | Expression::YearOfWeek(f)
            | Expression::YearOfWeekIso(f)
            | Expression::Initcap(f)
            | Expression::Ascii(f)
            | Expression::Chr(f)
            | Expression::Soundex(f)
            | Expression::ByteLength(f)
            | Expression::Hex(f)
            | Expression::LowerHex(f)
            | Expression::Unicode(f)
            | Expression::Radians(f)
            | Expression::Degrees(f)
            | Expression::Sin(f)
            | Expression::Cos(f)
            | Expression::Tan(f)
            | Expression::Asin(f)
            | Expression::Acos(f)
            | Expression::Atan(f)
            | Expression::IsNan(f)
            | Expression::IsInf(f)
            | Expression::ArrayLength(f)
            | Expression::ArraySize(f)
            | Expression::Cardinality(f)
            | Expression::ArrayReverse(f)
            | Expression::ArrayDistinct(f)
            | Expression::ArrayFlatten(f)
            | Expression::ArrayCompact(f)
            | Expression::Explode(f)
            | Expression::ExplodeOuter(f)
            | Expression::ToArray(f)
            | Expression::MapFromEntries(f)
            | Expression::MapKeys(f)
            | Expression::MapValues(f)
            | Expression::JsonArrayLength(f)
            | Expression::JsonKeys(f)
            | Expression::JsonType(f)
            | Expression::ParseJson(f)
            | Expression::ToJson(f)
            | Expression::Typeof(f)
            | Expression::BitwiseCount(f)
            | Expression::Year(f)
            | Expression::Month(f)
            | Expression::Day(f)
            | Expression::Hour(f)
            | Expression::Minute(f)
            | Expression::Second(f)
            | Expression::DayOfWeek(f)
            | Expression::DayOfWeekIso(f)
            | Expression::DayOfMonth(f)
            | Expression::DayOfYear(f)
            | Expression::WeekOfYear(f)
            | Expression::Quarter(f)
            | Expression::Epoch(f)
            | Expression::EpochMs(f)
            | Expression::TimeStrToUnix(f)
            | Expression::SHA(f)
            | Expression::SHA1Digest(f)
            | Expression::TimeToUnix(f)
            | Expression::JSONBool(f)
            | Expression::Int64(f)
            | Expression::MD5NumberLower64(f)
            | Expression::MD5NumberUpper64(f)
            | Expression::DateStrToDate(f)
            | Expression::DateToDateStr(f) => {
                stack.push(&f.this);
            }

            // === BinaryFunc variants: this, expression ===
            Expression::Power(f)
            | Expression::NullIf(f)
            | Expression::IfNull(f)
            | Expression::Nvl(f)
            | Expression::UnixToTimeStr(f)
            | Expression::Contains(f)
            | Expression::StartsWith(f)
            | Expression::EndsWith(f)
            | Expression::Levenshtein(f)
            | Expression::ModFunc(f)
            | Expression::Atan2(f)
            | Expression::IntDiv(f)
            | Expression::AddMonths(f)
            | Expression::MonthsBetween(f)
            | Expression::NextDay(f)
            | Expression::ArrayContains(f)
            | Expression::ArrayPosition(f)
            | Expression::ArrayAppend(f)
            | Expression::ArrayPrepend(f)
            | Expression::ArrayUnion(f)
            | Expression::ArrayExcept(f)
            | Expression::ArrayRemove(f)
            | Expression::StarMap(f)
            | Expression::MapFromArrays(f)
            | Expression::MapContainsKey(f)
            | Expression::ElementAt(f)
            | Expression::JsonMergePatch(f)
            | Expression::JSONBContains(f)
            | Expression::JSONBExtract(f) => {
                stack.push(&f.this);
                stack.push(&f.expression);
            }

            // === VarArgFunc variants: expressions ===
            Expression::Greatest(f)
            | Expression::Least(f)
            | Expression::Coalesce(f)
            | Expression::ArrayConcat(f)
            | Expression::ArrayIntersect(f)
            | Expression::ArrayZip(f)
            | Expression::MapConcat(f)
            | Expression::JsonArray(f) => {
                for e in &f.expressions {
                    stack.push(e);
                }
            }

            // === AggFunc variants: this, filter, having_max, limit ===
            Expression::Sum(f)
            | Expression::Avg(f)
            | Expression::Min(f)
            | Expression::Max(f)
            | Expression::ArrayAgg(f)
            | Expression::CountIf(f)
            | Expression::Stddev(f)
            | Expression::StddevPop(f)
            | Expression::StddevSamp(f)
            | Expression::Variance(f)
            | Expression::VarPop(f)
            | Expression::VarSamp(f)
            | Expression::Median(f)
            | Expression::Mode(f)
            | Expression::First(f)
            | Expression::Last(f)
            | Expression::AnyValue(f)
            | Expression::ApproxDistinct(f)
            | Expression::ApproxCountDistinct(f)
            | Expression::LogicalAnd(f)
            | Expression::LogicalOr(f)
            | Expression::Skewness(f)
            | Expression::ArrayConcatAgg(f)
            | Expression::ArrayUniqueAgg(f)
            | Expression::BoolXorAgg(f)
            | Expression::BitwiseAndAgg(f)
            | Expression::BitwiseOrAgg(f)
            | Expression::BitwiseXorAgg(f) => {
                stack.push(&f.this);
                if let Some(ref filter) = f.filter {
                    stack.push(filter);
                }
                // Only the expression half of `having_max` is walked; the
                // second tuple element is ignored here.
                if let Some((ref expr, _)) = f.having_max {
                    stack.push(expr);
                }
                if let Some(ref limit) = f.limit {
                    stack.push(limit);
                }
            }

            // === Generic Function / AggregateFunction: args ===
            Expression::Function(func) => {
                for arg in &func.args {
                    stack.push(arg);
                }
            }
            Expression::AggregateFunction(func) => {
                for arg in &func.args {
                    stack.push(arg);
                }
                if let Some(ref filter) = func.filter {
                    stack.push(filter);
                }
                if let Some(ref limit) = func.limit {
                    stack.push(limit);
                }
            }

            // === WindowFunction: this (skip Over for lineage purposes) ===
            Expression::WindowFunction(wf) => {
                stack.push(&wf.this);
            }

            // === Containers and special expressions ===
            Expression::Alias(a) => {
                stack.push(&a.this);
            }
            Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
                stack.push(&c.this);
                if let Some(ref fmt) = c.format {
                    stack.push(fmt);
                }
                if let Some(ref def) = c.default {
                    stack.push(def);
                }
            }
            Expression::Paren(p) => {
                stack.push(&p.this);
            }
            Expression::Annotated(a) => {
                stack.push(&a.this);
            }
            Expression::Case(case) => {
                if let Some(ref operand) = case.operand {
                    stack.push(operand);
                }
                for (cond, result) in &case.whens {
                    stack.push(cond);
                    stack.push(result);
                }
                if let Some(ref else_expr) = case.else_ {
                    stack.push(else_expr);
                }
            }
            Expression::Collation(c) => {
                stack.push(&c.this);
            }
            Expression::In(i) => {
                stack.push(&i.this);
                for e in &i.expressions {
                    stack.push(e);
                }
                // NOTE(review): whether columns inside `query` are skipped
                // depends on the variant it holds; only Subquery/Exists are
                // explicit boundaries in this match. Confirm against the parser.
                if let Some(ref q) = i.query {
                    stack.push(q);
                }
                if let Some(ref u) = i.unnest {
                    stack.push(u);
                }
            }
            Expression::Between(b) => {
                stack.push(&b.this);
                stack.push(&b.low);
                stack.push(&b.high);
            }
            Expression::IsNull(n) => {
                stack.push(&n.this);
            }
            Expression::IsTrue(t) | Expression::IsFalse(t) => {
                stack.push(&t.this);
            }
            Expression::IsJson(j) => {
                stack.push(&j.this);
            }
            Expression::Like(l) | Expression::ILike(l) => {
                stack.push(&l.left);
                stack.push(&l.right);
                if let Some(ref esc) = l.escape {
                    stack.push(esc);
                }
            }
            Expression::SimilarTo(s) => {
                stack.push(&s.this);
                stack.push(&s.pattern);
                if let Some(ref esc) = s.escape {
                    stack.push(esc);
                }
            }
            Expression::Ordered(o) => {
                stack.push(&o.this);
            }
            Expression::Array(a) => {
                for e in &a.expressions {
                    stack.push(e);
                }
            }
            Expression::Tuple(t) => {
                for e in &t.expressions {
                    stack.push(e);
                }
            }
            Expression::Struct(s) => {
                for (_, e) in &s.fields {
                    stack.push(e);
                }
            }
            Expression::Subscript(s) => {
                stack.push(&s.this);
                stack.push(&s.index);
            }
            Expression::Dot(d) => {
                stack.push(&d.this);
            }
            Expression::MethodCall(m) => {
                stack.push(&m.this);
                for arg in &m.args {
                    stack.push(arg);
                }
            }
            Expression::ArraySlice(s) => {
                stack.push(&s.this);
                if let Some(ref start) = s.start {
                    stack.push(start);
                }
                if let Some(ref end) = s.end {
                    stack.push(end);
                }
            }
            Expression::Lambda(l) => {
                stack.push(&l.body);
            }
            Expression::NamedArgument(n) => {
                stack.push(&n.value);
            }
            Expression::BracedWildcard(e) | Expression::ReturnStmt(e) => {
                stack.push(e);
            }

            // === Custom function structs ===
            Expression::Substring(f) => {
                stack.push(&f.this);
                stack.push(&f.start);
                if let Some(ref len) = f.length {
                    stack.push(len);
                }
            }
            Expression::Trim(f) => {
                stack.push(&f.this);
                if let Some(ref chars) = f.characters {
                    stack.push(chars);
                }
            }
            Expression::Replace(f) => {
                stack.push(&f.this);
                stack.push(&f.old);
                stack.push(&f.new);
            }
            Expression::IfFunc(f) => {
                stack.push(&f.condition);
                stack.push(&f.true_value);
                if let Some(ref fv) = f.false_value {
                    stack.push(fv);
                }
            }
            Expression::Nvl2(f) => {
                stack.push(&f.this);
                stack.push(&f.true_value);
                stack.push(&f.false_value);
            }
            Expression::ConcatWs(f) => {
                stack.push(&f.separator);
                for e in &f.expressions {
                    stack.push(e);
                }
            }
            Expression::Count(f) => {
                if let Some(ref this) = f.this {
                    stack.push(this);
                }
                if let Some(ref filter) = f.filter {
                    stack.push(filter);
                }
            }
            Expression::GroupConcat(f) => {
                stack.push(&f.this);
                if let Some(ref sep) = f.separator {
                    stack.push(sep);
                }
                if let Some(ref filter) = f.filter {
                    stack.push(filter);
                }
            }
            Expression::StringAgg(f) => {
                stack.push(&f.this);
                if let Some(ref sep) = f.separator {
                    stack.push(sep);
                }
                if let Some(ref filter) = f.filter {
                    stack.push(filter);
                }
                if let Some(ref limit) = f.limit {
                    stack.push(limit);
                }
            }
            Expression::ListAgg(f) => {
                stack.push(&f.this);
                if let Some(ref sep) = f.separator {
                    stack.push(sep);
                }
                if let Some(ref filter) = f.filter {
                    stack.push(filter);
                }
            }
            Expression::SumIf(f) => {
                stack.push(&f.this);
                stack.push(&f.condition);
                if let Some(ref filter) = f.filter {
                    stack.push(filter);
                }
            }
            Expression::DateAdd(f) | Expression::DateSub(f) => {
                stack.push(&f.this);
                stack.push(&f.interval);
            }
            Expression::DateDiff(f) => {
                stack.push(&f.this);
                stack.push(&f.expression);
            }
            // Only the operand is walked for truncation/extraction.
            Expression::DateTrunc(f) | Expression::TimestampTrunc(f) => {
                stack.push(&f.this);
            }
            Expression::Extract(f) => {
                stack.push(&f.this);
            }
            Expression::Round(f) => {
                stack.push(&f.this);
                if let Some(ref d) = f.decimals {
                    stack.push(d);
                }
            }
            Expression::Floor(f) => {
                stack.push(&f.this);
                if let Some(ref s) = f.scale {
                    stack.push(s);
                }
                if let Some(ref t) = f.to {
                    stack.push(t);
                }
            }
            Expression::Ceil(f) => {
                stack.push(&f.this);
                if let Some(ref d) = f.decimals {
                    stack.push(d);
                }
                if let Some(ref t) = f.to {
                    stack.push(t);
                }
            }
            Expression::Log(f) => {
                stack.push(&f.this);
                if let Some(ref b) = f.base {
                    stack.push(b);
                }
            }
            Expression::AtTimeZone(f) => {
                stack.push(&f.this);
                stack.push(&f.zone);
            }
            Expression::Lead(f) | Expression::Lag(f) => {
                stack.push(&f.this);
                if let Some(ref off) = f.offset {
                    stack.push(off);
                }
                if let Some(ref def) = f.default {
                    stack.push(def);
                }
            }
            Expression::FirstValue(f) | Expression::LastValue(f) => {
                stack.push(&f.this);
            }
            Expression::NthValue(f) => {
                stack.push(&f.this);
                stack.push(&f.offset);
            }
            Expression::Position(f) => {
                stack.push(&f.substring);
                stack.push(&f.string);
                if let Some(ref start) = f.start {
                    stack.push(start);
                }
            }
            Expression::Decode(f) => {
                stack.push(&f.this);
                for (search, result) in &f.search_results {
                    stack.push(search);
                    stack.push(result);
                }
                if let Some(ref def) = f.default {
                    stack.push(def);
                }
            }
            Expression::CharFunc(f) => {
                for arg in &f.args {
                    stack.push(arg);
                }
            }
            Expression::ArraySort(f) => {
                stack.push(&f.this);
                if let Some(ref cmp) = f.comparator {
                    stack.push(cmp);
                }
            }
            Expression::ArrayJoin(f) | Expression::ArrayToString(f) => {
                stack.push(&f.this);
                stack.push(&f.separator);
                if let Some(ref nr) = f.null_replacement {
                    stack.push(nr);
                }
            }
            Expression::ArrayFilter(f) => {
                stack.push(&f.this);
                stack.push(&f.filter);
            }
            Expression::ArrayTransform(f) => {
                stack.push(&f.this);
                stack.push(&f.transform);
            }
            Expression::Sequence(f)
            | Expression::Generate(f)
            | Expression::ExplodingGenerateSeries(f) => {
                stack.push(&f.start);
                stack.push(&f.stop);
                if let Some(ref step) = f.step {
                    stack.push(step);
                }
            }
            Expression::JsonExtract(f)
            | Expression::JsonExtractScalar(f)
            | Expression::JsonQuery(f)
            | Expression::JsonValue(f) => {
                stack.push(&f.this);
                stack.push(&f.path);
            }
            Expression::JsonExtractPath(f) | Expression::JsonRemove(f) => {
                stack.push(&f.this);
                for p in &f.paths {
                    stack.push(p);
                }
            }
            Expression::JsonObject(f) => {
                for (k, v) in &f.pairs {
                    stack.push(k);
                    stack.push(v);
                }
            }
            Expression::JsonSet(f) | Expression::JsonInsert(f) => {
                stack.push(&f.this);
                for (path, val) in &f.path_values {
                    stack.push(path);
                    stack.push(val);
                }
            }
            Expression::Overlay(f) => {
                stack.push(&f.this);
                stack.push(&f.replacement);
                stack.push(&f.from);
                if let Some(ref len) = f.length {
                    stack.push(len);
                }
            }
            Expression::Convert(f) => {
                stack.push(&f.this);
                if let Some(ref style) = f.style {
                    stack.push(style);
                }
            }
            Expression::ApproxPercentile(f) => {
                stack.push(&f.this);
                stack.push(&f.percentile);
                if let Some(ref acc) = f.accuracy {
                    stack.push(acc);
                }
                if let Some(ref filter) = f.filter {
                    stack.push(filter);
                }
            }
            Expression::Percentile(f)
            | Expression::PercentileCont(f)
            | Expression::PercentileDisc(f) => {
                stack.push(&f.this);
                stack.push(&f.percentile);
                if let Some(ref filter) = f.filter {
                    stack.push(filter);
                }
            }
            // Only `this` is walked for WITHIN GROUP.
            Expression::WithinGroup(f) => {
                stack.push(&f.this);
            }
            Expression::Left(f) | Expression::Right(f) => {
                stack.push(&f.this);
                stack.push(&f.length);
            }
            Expression::Repeat(f) => {
                stack.push(&f.this);
                stack.push(&f.times);
            }
            Expression::Lpad(f) | Expression::Rpad(f) => {
                stack.push(&f.this);
                stack.push(&f.length);
                if let Some(ref fill) = f.fill {
                    stack.push(fill);
                }
            }
            Expression::Split(f) => {
                stack.push(&f.this);
                stack.push(&f.delimiter);
            }
            Expression::RegexpLike(f) => {
                stack.push(&f.this);
                stack.push(&f.pattern);
                if let Some(ref flags) = f.flags {
                    stack.push(flags);
                }
            }
            Expression::RegexpReplace(f) => {
                stack.push(&f.this);
                stack.push(&f.pattern);
                stack.push(&f.replacement);
                if let Some(ref flags) = f.flags {
                    stack.push(flags);
                }
            }
            Expression::RegexpExtract(f) => {
                stack.push(&f.this);
                stack.push(&f.pattern);
                if let Some(ref group) = f.group {
                    stack.push(group);
                }
            }
            Expression::ToDate(f) => {
                stack.push(&f.this);
                if let Some(ref fmt) = f.format {
                    stack.push(fmt);
                }
            }
            Expression::ToTimestamp(f) => {
                stack.push(&f.this);
                if let Some(ref fmt) = f.format {
                    stack.push(fmt);
                }
            }
            Expression::DateFormat(f) | Expression::FormatDate(f) => {
                stack.push(&f.this);
                stack.push(&f.format);
            }
            Expression::LastDay(f) => {
                stack.push(&f.this);
            }
            Expression::FromUnixtime(f) => {
                stack.push(&f.this);
                if let Some(ref fmt) = f.format {
                    stack.push(fmt);
                }
            }
            Expression::UnixTimestamp(f) => {
                if let Some(ref this) = f.this {
                    stack.push(this);
                }
                if let Some(ref fmt) = f.format {
                    stack.push(fmt);
                }
            }
            Expression::MakeDate(f) => {
                stack.push(&f.year);
                stack.push(&f.month);
                stack.push(&f.day);
            }
            Expression::MakeTimestamp(f) => {
                stack.push(&f.year);
                stack.push(&f.month);
                stack.push(&f.day);
                stack.push(&f.hour);
                stack.push(&f.minute);
                stack.push(&f.second);
                if let Some(ref tz) = f.timezone {
                    stack.push(tz);
                }
            }
            Expression::TruncFunc(f) => {
                stack.push(&f.this);
                if let Some(ref d) = f.decimals {
                    stack.push(d);
                }
            }
            Expression::ArrayFunc(f) => {
                for e in &f.expressions {
                    stack.push(e);
                }
            }
            Expression::Unnest(f) => {
                stack.push(&f.this);
                for e in &f.expressions {
                    stack.push(e);
                }
            }
            Expression::StructFunc(f) => {
                for (_, e) in &f.fields {
                    stack.push(e);
                }
            }
            Expression::StructExtract(f) => {
                stack.push(&f.this);
            }
            Expression::NamedStruct(f) => {
                for (k, v) in &f.pairs {
                    stack.push(k);
                    stack.push(v);
                }
            }
            Expression::MapFunc(f) => {
                for k in &f.keys {
                    stack.push(k);
                }
                for v in &f.values {
                    stack.push(v);
                }
            }
            Expression::TransformKeys(f) | Expression::TransformValues(f) => {
                stack.push(&f.this);
                stack.push(&f.transform);
            }
            Expression::JsonArrayAgg(f) => {
                stack.push(&f.this);
                if let Some(ref filter) = f.filter {
                    stack.push(filter);
                }
            }
            Expression::JsonObjectAgg(f) => {
                stack.push(&f.key);
                stack.push(&f.value);
                if let Some(ref filter) = f.filter {
                    stack.push(filter);
                }
            }
            Expression::NTile(f) => {
                if let Some(ref n) = f.num_buckets {
                    stack.push(n);
                }
            }
            Expression::Rand(f) => {
                if let Some(ref s) = f.seed {
                    stack.push(s);
                }
                if let Some(ref lo) = f.lower {
                    stack.push(lo);
                }
                if let Some(ref hi) = f.upper {
                    stack.push(hi);
                }
            }
            // NOTE(review): as with `In`, the pushed `subquery` is only
            // skipped if it is one of the explicit boundary variants above.
            Expression::Any(q) | Expression::All(q) => {
                stack.push(&q.this);
                stack.push(&q.subquery);
            }
            Expression::Overlaps(o) => {
                if let Some(ref this) = o.this {
                    stack.push(this);
                }
                if let Some(ref expr) = o.expression {
                    stack.push(expr);
                }
                if let Some(ref ls) = o.left_start {
                    stack.push(ls);
                }
                if let Some(ref le) = o.left_end {
                    stack.push(le);
                }
                if let Some(ref rs) = o.right_start {
                    stack.push(rs);
                }
                if let Some(ref re) = o.right_end {
                    stack.push(re);
                }
            }
            Expression::Interval(i) => {
                if let Some(ref this) = i.this {
                    stack.push(this);
                }
            }
            Expression::TimeStrToTime(f) => {
                stack.push(&f.this);
                if let Some(ref zone) = f.zone {
                    stack.push(zone);
                }
            }
            Expression::JSONBExtractScalar(f) => {
                stack.push(&f.this);
                stack.push(&f.expression);
                if let Some(ref jt) = f.json_type {
                    stack.push(jt);
                }
            }

            // === True leaves and non-expression-bearing nodes ===
            // Literals, Identifier, Star, DataType, Placeholder, Boolean, Null,
            // CurrentDate/Time/Timestamp, RowNumber, Rank, DenseRank, PercentRank,
            // CumeDist, Random, Pi, SessionUser, DDL statements, clauses, etc.
            _ => {}
        }
    }
}
1727
1728// ---------------------------------------------------------------------------
1729// Tests
1730// ---------------------------------------------------------------------------
1731
1732#[cfg(test)]
1733mod tests {
1734    use super::*;
1735    use crate::dialects::{Dialect, DialectType};
1736    use crate::expressions::DataType;
1737    use crate::optimizer::annotate_types::annotate_types;
1738    use crate::parse_one;
1739    use crate::schema::{MappingSchema, Schema};
1740
1741    fn parse(sql: &str) -> Expression {
1742        let dialect = Dialect::get(DialectType::Generic);
1743        let ast = dialect.parse(sql).unwrap();
1744        ast.into_iter().next().unwrap()
1745    }
1746
1747    #[test]
1748    fn test_simple_lineage() {
1749        let expr = parse("SELECT a FROM t");
1750        let node = lineage("a", &expr, None, false).unwrap();
1751
1752        assert_eq!(node.name, "a");
1753        assert!(!node.downstream.is_empty(), "Should have downstream nodes");
1754        // Should trace to t.a
1755        let names = node.downstream_names();
1756        assert!(
1757            names.iter().any(|n| n == "t.a"),
1758            "Expected t.a in downstream, got: {:?}",
1759            names
1760        );
1761    }
1762
1763    #[test]
1764    fn test_lineage_walk() {
1765        let root = LineageNode {
1766            name: "col_a".to_string(),
1767            expression: Expression::Null(crate::expressions::Null),
1768            source: Expression::Null(crate::expressions::Null),
1769            downstream: vec![LineageNode::new(
1770                "t.a",
1771                Expression::Null(crate::expressions::Null),
1772                Expression::Null(crate::expressions::Null),
1773            )],
1774            source_name: String::new(),
1775            reference_node_name: String::new(),
1776        };
1777
1778        let names: Vec<_> = root.walk().map(|n| n.name.clone()).collect();
1779        assert_eq!(names.len(), 2);
1780        assert_eq!(names[0], "col_a");
1781        assert_eq!(names[1], "t.a");
1782    }
1783
1784    #[test]
1785    fn test_aliased_column() {
1786        let expr = parse("SELECT a + 1 AS b FROM t");
1787        let node = lineage("b", &expr, None, false).unwrap();
1788
1789        assert_eq!(node.name, "b");
1790        // Should trace through the expression to t.a
1791        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
1792        assert!(
1793            all_names.iter().any(|n| n.contains("a")),
1794            "Expected to trace to column a, got: {:?}",
1795            all_names
1796        );
1797    }
1798
1799    #[test]
1800    fn test_qualified_column() {
1801        let expr = parse("SELECT t.a FROM t");
1802        let node = lineage("a", &expr, None, false).unwrap();
1803
1804        assert_eq!(node.name, "a");
1805        let names = node.downstream_names();
1806        assert!(
1807            names.iter().any(|n| n == "t.a"),
1808            "Expected t.a, got: {:?}",
1809            names
1810        );
1811    }
1812
1813    #[test]
1814    fn test_unqualified_column() {
1815        let expr = parse("SELECT a FROM t");
1816        let node = lineage("a", &expr, None, false).unwrap();
1817
1818        // Unqualified but single source → resolved to t.a
1819        let names = node.downstream_names();
1820        assert!(
1821            names.iter().any(|n| n == "t.a"),
1822            "Expected t.a, got: {:?}",
1823            names
1824        );
1825    }
1826
1827    #[test]
1828    fn test_lineage_with_schema_qualifies_root_expression_issue_40() {
1829        let query = "SELECT name FROM users";
1830        let dialect = Dialect::get(DialectType::BigQuery);
1831        let expr = dialect
1832            .parse(query)
1833            .unwrap()
1834            .into_iter()
1835            .next()
1836            .expect("expected one expression");
1837
1838        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
1839        schema
1840            .add_table("users", &[("name".into(), DataType::Text)], None)
1841            .expect("schema setup");
1842
1843        let node_without_schema = lineage("name", &expr, Some(DialectType::BigQuery), false)
1844            .expect("lineage without schema");
1845        let mut expr_without = node_without_schema.expression.clone();
1846        annotate_types(
1847            &mut expr_without,
1848            Some(&schema),
1849            Some(DialectType::BigQuery),
1850        );
1851        assert_eq!(
1852            expr_without.inferred_type(),
1853            None,
1854            "Expected unresolved root type without schema-aware lineage qualification"
1855        );
1856
1857        let node_with_schema = lineage_with_schema(
1858            "name",
1859            &expr,
1860            Some(&schema),
1861            Some(DialectType::BigQuery),
1862            false,
1863        )
1864        .expect("lineage with schema");
1865        let mut expr_with = node_with_schema.expression.clone();
1866        annotate_types(&mut expr_with, Some(&schema), Some(DialectType::BigQuery));
1867
1868        assert_eq!(expr_with.inferred_type(), Some(&DataType::Text));
1869    }
1870
1871    #[test]
1872    fn test_lineage_with_schema_correlated_scalar_subquery() {
1873        let query = "SELECT id, (SELECT AVG(val) FROM t2 WHERE t2.id = t1.id) AS avg_val FROM t1";
1874        let dialect = Dialect::get(DialectType::BigQuery);
1875        let expr = dialect
1876            .parse(query)
1877            .unwrap()
1878            .into_iter()
1879            .next()
1880            .expect("expected one expression");
1881
1882        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
1883        schema
1884            .add_table(
1885                "t1",
1886                &[("id".into(), DataType::BigInt { length: None })],
1887                None,
1888            )
1889            .expect("schema setup");
1890        schema
1891            .add_table(
1892                "t2",
1893                &[
1894                    ("id".into(), DataType::BigInt { length: None }),
1895                    ("val".into(), DataType::BigInt { length: None }),
1896                ],
1897                None,
1898            )
1899            .expect("schema setup");
1900
1901        let node = lineage_with_schema(
1902            "id",
1903            &expr,
1904            Some(&schema),
1905            Some(DialectType::BigQuery),
1906            false,
1907        )
1908        .expect("lineage_with_schema should handle correlated scalar subqueries");
1909
1910        assert_eq!(node.name, "id");
1911    }
1912
1913    #[test]
1914    fn test_lineage_with_schema_join_using() {
1915        let query = "SELECT a FROM t1 JOIN t2 USING(a)";
1916        let dialect = Dialect::get(DialectType::BigQuery);
1917        let expr = dialect
1918            .parse(query)
1919            .unwrap()
1920            .into_iter()
1921            .next()
1922            .expect("expected one expression");
1923
1924        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
1925        schema
1926            .add_table(
1927                "t1",
1928                &[("a".into(), DataType::BigInt { length: None })],
1929                None,
1930            )
1931            .expect("schema setup");
1932        schema
1933            .add_table(
1934                "t2",
1935                &[("a".into(), DataType::BigInt { length: None })],
1936                None,
1937            )
1938            .expect("schema setup");
1939
1940        let node = lineage_with_schema(
1941            "a",
1942            &expr,
1943            Some(&schema),
1944            Some(DialectType::BigQuery),
1945            false,
1946        )
1947        .expect("lineage_with_schema should handle JOIN USING");
1948
1949        assert_eq!(node.name, "a");
1950    }
1951
1952    #[test]
1953    fn test_lineage_with_schema_qualified_table_name() {
1954        let query = "SELECT a FROM raw.t1";
1955        let dialect = Dialect::get(DialectType::BigQuery);
1956        let expr = dialect
1957            .parse(query)
1958            .unwrap()
1959            .into_iter()
1960            .next()
1961            .expect("expected one expression");
1962
1963        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
1964        schema
1965            .add_table(
1966                "raw.t1",
1967                &[("a".into(), DataType::BigInt { length: None })],
1968                None,
1969            )
1970            .expect("schema setup");
1971
1972        let node = lineage_with_schema(
1973            "a",
1974            &expr,
1975            Some(&schema),
1976            Some(DialectType::BigQuery),
1977            false,
1978        )
1979        .expect("lineage_with_schema should handle dotted schema.table names");
1980
1981        assert_eq!(node.name, "a");
1982    }
1983
1984    #[test]
1985    fn test_lineage_with_schema_none_matches_lineage() {
1986        let expr = parse("SELECT a FROM t");
1987        let baseline = lineage("a", &expr, None, false).expect("lineage baseline");
1988        let with_none =
1989            lineage_with_schema("a", &expr, None, None, false).expect("lineage_with_schema");
1990
1991        assert_eq!(with_none.name, baseline.name);
1992        assert_eq!(with_none.downstream_names(), baseline.downstream_names());
1993    }
1994
1995    #[test]
1996    fn test_lineage_with_schema_bigquery_mixed_case_column_names_issue_60() {
1997        let dialect = Dialect::get(DialectType::BigQuery);
1998        let expr = dialect
1999            .parse("SELECT Name AS name FROM teams")
2000            .unwrap()
2001            .into_iter()
2002            .next()
2003            .expect("expected one expression");
2004
2005        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
2006        schema
2007            .add_table(
2008                "teams",
2009                &[("Name".into(), DataType::String { length: None })],
2010                None,
2011            )
2012            .expect("schema setup");
2013
2014        let node = lineage_with_schema(
2015            "name",
2016            &expr,
2017            Some(&schema),
2018            Some(DialectType::BigQuery),
2019            false,
2020        )
2021        .expect("lineage_with_schema should resolve mixed-case BigQuery columns");
2022
2023        let names = node.downstream_names();
2024        assert!(
2025            names.iter().any(|n| n == "teams.Name"),
2026            "Expected teams.Name in downstream, got: {:?}",
2027            names
2028        );
2029    }
2030
2031    #[test]
2032    fn test_lineage_bigquery_mixed_case_alias_lookup() {
2033        let dialect = Dialect::get(DialectType::BigQuery);
2034        let expr = dialect
2035            .parse("SELECT Name AS Name FROM teams")
2036            .unwrap()
2037            .into_iter()
2038            .next()
2039            .expect("expected one expression");
2040
2041        let node = lineage("name", &expr, Some(DialectType::BigQuery), false)
2042            .expect("lineage should resolve mixed-case aliases in BigQuery");
2043
2044        assert_eq!(node.name, "name");
2045    }
2046
2047    #[test]
2048    fn test_lineage_with_schema_snowflake_datediff_date_part_issue_61() {
2049        let expr = parse_one(
2050            "SELECT DATEDIFF(day, date_utc, CURRENT_DATE()) AS recency FROM fact.some_daily_metrics",
2051            DialectType::Snowflake,
2052        )
2053        .expect("parse");
2054
2055        let mut schema = MappingSchema::with_dialect(DialectType::Snowflake);
2056        schema
2057            .add_table(
2058                "fact.some_daily_metrics",
2059                &[("date_utc".to_string(), DataType::Date)],
2060                None,
2061            )
2062            .expect("schema setup");
2063
2064        let node = lineage_with_schema(
2065            "recency",
2066            &expr,
2067            Some(&schema),
2068            Some(DialectType::Snowflake),
2069            false,
2070        )
2071        .expect("lineage_with_schema should not treat date part as a column");
2072
2073        let names = node.downstream_names();
2074        assert!(
2075            names.iter().any(|n| n == "some_daily_metrics.date_utc"),
2076            "Expected some_daily_metrics.date_utc in downstream, got: {:?}",
2077            names
2078        );
2079        assert!(
2080            !names.iter().any(|n| n.ends_with(".day") || n == "day"),
2081            "Did not expect date part to appear as lineage column, got: {:?}",
2082            names
2083        );
2084    }
2085
2086    #[test]
2087    fn test_snowflake_datediff_parses_to_typed_ast() {
2088        let expr = parse_one(
2089            "SELECT DATEDIFF(day, date_utc, CURRENT_DATE()) AS recency FROM fact.some_daily_metrics",
2090            DialectType::Snowflake,
2091        )
2092        .expect("parse");
2093
2094        match expr {
2095            Expression::Select(select) => match &select.expressions[0] {
2096                Expression::Alias(alias) => match &alias.this {
2097                    Expression::DateDiff(f) => {
2098                        assert_eq!(f.unit, Some(crate::expressions::IntervalUnit::Day));
2099                    }
2100                    other => panic!("expected DateDiff, got {other:?}"),
2101                },
2102                other => panic!("expected Alias, got {other:?}"),
2103            },
2104            other => panic!("expected Select, got {other:?}"),
2105        }
2106    }
2107
2108    #[test]
2109    fn test_lineage_with_schema_snowflake_dateadd_date_part_issue_followup() {
2110        let expr = parse_one(
2111            "SELECT DATEADD(day, 1, date_utc) AS next_day FROM fact.some_daily_metrics",
2112            DialectType::Snowflake,
2113        )
2114        .expect("parse");
2115
2116        let mut schema = MappingSchema::with_dialect(DialectType::Snowflake);
2117        schema
2118            .add_table(
2119                "fact.some_daily_metrics",
2120                &[("date_utc".to_string(), DataType::Date)],
2121                None,
2122            )
2123            .expect("schema setup");
2124
2125        let node = lineage_with_schema(
2126            "next_day",
2127            &expr,
2128            Some(&schema),
2129            Some(DialectType::Snowflake),
2130            false,
2131        )
2132        .expect("lineage_with_schema should not treat DATEADD date part as a column");
2133
2134        let names = node.downstream_names();
2135        assert!(
2136            names.iter().any(|n| n == "some_daily_metrics.date_utc"),
2137            "Expected some_daily_metrics.date_utc in downstream, got: {:?}",
2138            names
2139        );
2140        assert!(
2141            !names.iter().any(|n| n.ends_with(".day") || n == "day"),
2142            "Did not expect date part to appear as lineage column, got: {:?}",
2143            names
2144        );
2145    }
2146
2147    #[test]
2148    fn test_lineage_with_schema_snowflake_date_part_identifier_issue_followup() {
2149        let expr = parse_one(
2150            "SELECT DATE_PART(day, date_utc) AS day_part FROM fact.some_daily_metrics",
2151            DialectType::Snowflake,
2152        )
2153        .expect("parse");
2154
2155        let mut schema = MappingSchema::with_dialect(DialectType::Snowflake);
2156        schema
2157            .add_table(
2158                "fact.some_daily_metrics",
2159                &[("date_utc".to_string(), DataType::Date)],
2160                None,
2161            )
2162            .expect("schema setup");
2163
2164        let node = lineage_with_schema(
2165            "day_part",
2166            &expr,
2167            Some(&schema),
2168            Some(DialectType::Snowflake),
2169            false,
2170        )
2171        .expect("lineage_with_schema should not treat DATE_PART identifier as a column");
2172
2173        let names = node.downstream_names();
2174        assert!(
2175            names.iter().any(|n| n == "some_daily_metrics.date_utc"),
2176            "Expected some_daily_metrics.date_utc in downstream, got: {:?}",
2177            names
2178        );
2179        assert!(
2180            !names.iter().any(|n| n.ends_with(".day") || n == "day"),
2181            "Did not expect date part to appear as lineage column, got: {:?}",
2182            names
2183        );
2184    }
2185
2186    #[test]
2187    fn test_lineage_with_schema_snowflake_date_part_string_literal_control() {
2188        let expr = parse_one(
2189            "SELECT DATE_PART('day', date_utc) AS day_part FROM fact.some_daily_metrics",
2190            DialectType::Snowflake,
2191        )
2192        .expect("parse");
2193
2194        let mut schema = MappingSchema::with_dialect(DialectType::Snowflake);
2195        schema
2196            .add_table(
2197                "fact.some_daily_metrics",
2198                &[("date_utc".to_string(), DataType::Date)],
2199                None,
2200            )
2201            .expect("schema setup");
2202
2203        let node = lineage_with_schema(
2204            "day_part",
2205            &expr,
2206            Some(&schema),
2207            Some(DialectType::Snowflake),
2208            false,
2209        )
2210        .expect("quoted DATE_PART should continue to work");
2211
2212        let names = node.downstream_names();
2213        assert!(
2214            names.iter().any(|n| n == "some_daily_metrics.date_utc"),
2215            "Expected some_daily_metrics.date_utc in downstream, got: {:?}",
2216            names
2217        );
2218    }
2219
2220    #[test]
2221    fn test_snowflake_dateadd_date_part_identifier_stays_generic_function() {
2222        let expr = parse_one(
2223            "SELECT DATEADD(day, 1, date_utc) AS next_day FROM fact.some_daily_metrics",
2224            DialectType::Snowflake,
2225        )
2226        .expect("parse");
2227
2228        match expr {
2229            Expression::Select(select) => match &select.expressions[0] {
2230                Expression::Alias(alias) => match &alias.this {
2231                    Expression::Function(f) => {
2232                        assert_eq!(f.name.to_uppercase(), "DATEADD");
2233                        assert!(matches!(&f.args[0], Expression::Var(v) if v.this == "day"));
2234                    }
2235                    other => panic!("expected generic DATEADD function, got {other:?}"),
2236                },
2237                other => panic!("expected Alias, got {other:?}"),
2238            },
2239            other => panic!("expected Select, got {other:?}"),
2240        }
2241    }
2242
2243    #[test]
2244    fn test_snowflake_date_part_identifier_stays_generic_function_with_var_arg() {
2245        let expr = parse_one(
2246            "SELECT DATE_PART(day, date_utc) AS day_part FROM fact.some_daily_metrics",
2247            DialectType::Snowflake,
2248        )
2249        .expect("parse");
2250
2251        match expr {
2252            Expression::Select(select) => match &select.expressions[0] {
2253                Expression::Alias(alias) => match &alias.this {
2254                    Expression::Function(f) => {
2255                        assert_eq!(f.name.to_uppercase(), "DATE_PART");
2256                        assert!(matches!(&f.args[0], Expression::Var(v) if v.this == "day"));
2257                    }
2258                    other => panic!("expected generic DATE_PART function, got {other:?}"),
2259                },
2260                other => panic!("expected Alias, got {other:?}"),
2261            },
2262            other => panic!("expected Select, got {other:?}"),
2263        }
2264    }
2265
2266    #[test]
2267    fn test_snowflake_date_part_string_literal_stays_generic_function() {
2268        let expr = parse_one(
2269            "SELECT DATE_PART('day', date_utc) AS day_part FROM fact.some_daily_metrics",
2270            DialectType::Snowflake,
2271        )
2272        .expect("parse");
2273
2274        match expr {
2275            Expression::Select(select) => match &select.expressions[0] {
2276                Expression::Alias(alias) => match &alias.this {
2277                    Expression::Function(f) => {
2278                        assert_eq!(f.name.to_uppercase(), "DATE_PART");
2279                    }
2280                    other => panic!("expected generic DATE_PART function, got {other:?}"),
2281                },
2282                other => panic!("expected Alias, got {other:?}"),
2283            },
2284            other => panic!("expected Select, got {other:?}"),
2285        }
2286    }
2287
2288    #[test]
2289    fn test_lineage_join() {
2290        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
2291
2292        let node_a = lineage("a", &expr, None, false).unwrap();
2293        let names_a = node_a.downstream_names();
2294        assert!(
2295            names_a.iter().any(|n| n == "t.a"),
2296            "Expected t.a, got: {:?}",
2297            names_a
2298        );
2299
2300        let node_b = lineage("b", &expr, None, false).unwrap();
2301        let names_b = node_b.downstream_names();
2302        assert!(
2303            names_b.iter().any(|n| n == "s.b"),
2304            "Expected s.b, got: {:?}",
2305            names_b
2306        );
2307    }
2308
2309    #[test]
2310    fn test_lineage_alias_leaf_has_resolved_source_name() {
2311        let expr = parse("SELECT t1.col1 FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id");
2312        let node = lineage("col1", &expr, None, false).unwrap();
2313
2314        // Keep alias in the display lineage edge.
2315        let names = node.downstream_names();
2316        assert!(
2317            names.iter().any(|n| n == "t1.col1"),
2318            "Expected aliased column edge t1.col1, got: {:?}",
2319            names
2320        );
2321
2322        // Leaf should expose the resolved base table for consumers.
2323        let leaf = node
2324            .downstream
2325            .iter()
2326            .find(|n| n.name == "t1.col1")
2327            .expect("Expected t1.col1 leaf");
2328        assert_eq!(leaf.source_name, "table1");
2329        match &leaf.source {
2330            Expression::Table(table) => assert_eq!(table.name.name, "table1"),
2331            _ => panic!("Expected leaf source to be a table expression"),
2332        }
2333    }
2334
2335    #[test]
2336    fn test_lineage_derived_table() {
2337        let expr = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
2338        let node = lineage("a", &expr, None, false).unwrap();
2339
2340        assert_eq!(node.name, "a");
2341        // Should trace through the derived table to t.a
2342        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
2343        assert!(
2344            all_names.iter().any(|n| n == "t.a"),
2345            "Expected to trace through derived table to t.a, got: {:?}",
2346            all_names
2347        );
2348    }
2349
2350    #[test]
2351    fn test_lineage_cte() {
2352        let expr = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
2353        let node = lineage("a", &expr, None, false).unwrap();
2354
2355        assert_eq!(node.name, "a");
2356        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
2357        assert!(
2358            all_names.iter().any(|n| n == "t.a"),
2359            "Expected to trace through CTE to t.a, got: {:?}",
2360            all_names
2361        );
2362    }
2363
2364    #[test]
2365    fn test_lineage_union() {
2366        let expr = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
2367        let node = lineage("a", &expr, None, false).unwrap();
2368
2369        assert_eq!(node.name, "a");
2370        // Should have 2 downstream branches
2371        assert_eq!(
2372            node.downstream.len(),
2373            2,
2374            "Expected 2 branches for UNION, got {}",
2375            node.downstream.len()
2376        );
2377    }
2378
2379    #[test]
2380    fn test_lineage_cte_union() {
2381        let expr = parse("WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte");
2382        let node = lineage("a", &expr, None, false).unwrap();
2383
2384        // Should trace through CTE into both UNION branches
2385        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
2386        assert!(
2387            all_names.len() >= 3,
2388            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
2389            all_names
2390        );
2391    }
2392
2393    #[test]
2394    fn test_lineage_star() {
2395        let expr = parse("SELECT * FROM t");
2396        let node = lineage("*", &expr, None, false).unwrap();
2397
2398        assert_eq!(node.name, "*");
2399        // Should have downstream for table t
2400        assert!(
2401            !node.downstream.is_empty(),
2402            "Star should produce downstream nodes"
2403        );
2404    }
2405
2406    #[test]
2407    fn test_lineage_subquery_in_select() {
2408        let expr = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
2409        let node = lineage("x", &expr, None, false).unwrap();
2410
2411        assert_eq!(node.name, "x");
2412        // Should have traced into the scalar subquery
2413        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
2414        assert!(
2415            all_names.len() >= 2,
2416            "Expected tracing into scalar subquery, got: {:?}",
2417            all_names
2418        );
2419    }
2420
2421    #[test]
2422    fn test_lineage_multiple_columns() {
2423        let expr = parse("SELECT a, b FROM t");
2424
2425        let node_a = lineage("a", &expr, None, false).unwrap();
2426        let node_b = lineage("b", &expr, None, false).unwrap();
2427
2428        assert_eq!(node_a.name, "a");
2429        assert_eq!(node_b.name, "b");
2430
2431        // Each should trace independently
2432        let names_a = node_a.downstream_names();
2433        let names_b = node_b.downstream_names();
2434        assert!(names_a.iter().any(|n| n == "t.a"));
2435        assert!(names_b.iter().any(|n| n == "t.b"));
2436    }
2437
2438    #[test]
2439    fn test_get_source_tables() {
2440        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
2441        let node = lineage("a", &expr, None, false).unwrap();
2442
2443        let tables = get_source_tables(&node);
2444        assert!(
2445            tables.contains("t"),
2446            "Expected source table 't', got: {:?}",
2447            tables
2448        );
2449    }
2450
2451    #[test]
2452    fn test_lineage_column_not_found() {
2453        let expr = parse("SELECT a FROM t");
2454        let result = lineage("nonexistent", &expr, None, false);
2455        assert!(result.is_err());
2456    }
2457
2458    #[test]
2459    fn test_lineage_nested_cte() {
2460        let expr = parse(
2461            "WITH cte1 AS (SELECT a FROM t), \
2462             cte2 AS (SELECT a FROM cte1) \
2463             SELECT a FROM cte2",
2464        );
2465        let node = lineage("a", &expr, None, false).unwrap();
2466
2467        // Should trace through cte2 → cte1 → t
2468        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
2469        assert!(
2470            all_names.len() >= 3,
2471            "Expected to trace through nested CTEs, got: {:?}",
2472            all_names
2473        );
2474    }
2475
2476    #[test]
2477    fn test_trim_selects_true() {
2478        let expr = parse("SELECT a, b, c FROM t");
2479        let node = lineage("a", &expr, None, true).unwrap();
2480
2481        // The source should be trimmed to only include 'a'
2482        if let Expression::Select(select) = &node.source {
2483            assert_eq!(
2484                select.expressions.len(),
2485                1,
2486                "Trimmed source should have 1 expression, got {}",
2487                select.expressions.len()
2488            );
2489        } else {
2490            panic!("Expected Select source");
2491        }
2492    }
2493
2494    #[test]
2495    fn test_trim_selects_false() {
2496        let expr = parse("SELECT a, b, c FROM t");
2497        let node = lineage("a", &expr, None, false).unwrap();
2498
2499        // The source should keep all columns
2500        if let Expression::Select(select) = &node.source {
2501            assert_eq!(
2502                select.expressions.len(),
2503                3,
2504                "Untrimmed source should have 3 expressions"
2505            );
2506        } else {
2507            panic!("Expected Select source");
2508        }
2509    }
2510
2511    #[test]
2512    fn test_lineage_expression_in_select() {
2513        let expr = parse("SELECT a + b AS c FROM t");
2514        let node = lineage("c", &expr, None, false).unwrap();
2515
2516        // Should trace to both a and b from t
2517        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
2518        assert!(
2519            all_names.len() >= 3,
2520            "Expected to trace a + b to both columns, got: {:?}",
2521            all_names
2522        );
2523    }
2524
2525    #[test]
2526    fn test_set_operation_by_index() {
2527        let expr = parse("SELECT a FROM t1 UNION SELECT b FROM t2");
2528
2529        // Trace column "a" which is at index 0
2530        let node = lineage("a", &expr, None, false).unwrap();
2531
2532        // UNION branches should be traced by index
2533        assert_eq!(node.downstream.len(), 2);
2534    }
2535
2536    // --- Tests for column lineage inside function calls (issue #18) ---
2537
2538    fn print_node(node: &LineageNode, indent: usize) {
2539        let pad = "  ".repeat(indent);
2540        println!(
2541            "{pad}name={:?} source_name={:?}",
2542            node.name, node.source_name
2543        );
2544        for child in &node.downstream {
2545            print_node(child, indent + 1);
2546        }
2547    }
2548
2549    #[test]
2550    fn test_issue18_repro() {
2551        // Exact scenario from the issue
2552        let query = "SELECT UPPER(name) as upper_name FROM users";
2553        println!("Query: {query}\n");
2554
2555        let dialect = crate::dialects::Dialect::get(DialectType::BigQuery);
2556        let exprs = dialect.parse(query).unwrap();
2557        let expr = &exprs[0];
2558
2559        let node = lineage("upper_name", expr, Some(DialectType::BigQuery), false).unwrap();
2560        println!("lineage(\"upper_name\"):");
2561        print_node(&node, 1);
2562
2563        let names = node.downstream_names();
2564        assert!(
2565            names.iter().any(|n| n == "users.name"),
2566            "Expected users.name in downstream, got: {:?}",
2567            names
2568        );
2569    }
2570
2571    #[test]
2572    fn test_lineage_upper_function() {
2573        let expr = parse("SELECT UPPER(name) AS upper_name FROM users");
2574        let node = lineage("upper_name", &expr, None, false).unwrap();
2575
2576        let names = node.downstream_names();
2577        assert!(
2578            names.iter().any(|n| n == "users.name"),
2579            "Expected users.name in downstream, got: {:?}",
2580            names
2581        );
2582    }
2583
2584    #[test]
2585    fn test_lineage_round_function() {
2586        let expr = parse("SELECT ROUND(price, 2) AS rounded FROM products");
2587        let node = lineage("rounded", &expr, None, false).unwrap();
2588
2589        let names = node.downstream_names();
2590        assert!(
2591            names.iter().any(|n| n == "products.price"),
2592            "Expected products.price in downstream, got: {:?}",
2593            names
2594        );
2595    }
2596
2597    #[test]
2598    fn test_lineage_coalesce_function() {
2599        let expr = parse("SELECT COALESCE(a, b) AS val FROM t");
2600        let node = lineage("val", &expr, None, false).unwrap();
2601
2602        let names = node.downstream_names();
2603        assert!(
2604            names.iter().any(|n| n == "t.a"),
2605            "Expected t.a in downstream, got: {:?}",
2606            names
2607        );
2608        assert!(
2609            names.iter().any(|n| n == "t.b"),
2610            "Expected t.b in downstream, got: {:?}",
2611            names
2612        );
2613    }
2614
2615    #[test]
2616    fn test_lineage_count_function() {
2617        let expr = parse("SELECT COUNT(id) AS cnt FROM t");
2618        let node = lineage("cnt", &expr, None, false).unwrap();
2619
2620        let names = node.downstream_names();
2621        assert!(
2622            names.iter().any(|n| n == "t.id"),
2623            "Expected t.id in downstream, got: {:?}",
2624            names
2625        );
2626    }
2627
2628    #[test]
2629    fn test_lineage_sum_function() {
2630        let expr = parse("SELECT SUM(amount) AS total FROM t");
2631        let node = lineage("total", &expr, None, false).unwrap();
2632
2633        let names = node.downstream_names();
2634        assert!(
2635            names.iter().any(|n| n == "t.amount"),
2636            "Expected t.amount in downstream, got: {:?}",
2637            names
2638        );
2639    }
2640
2641    #[test]
2642    fn test_lineage_case_with_nested_functions() {
2643        let expr =
2644            parse("SELECT CASE WHEN x > 0 THEN UPPER(name) ELSE LOWER(name) END AS result FROM t");
2645        let node = lineage("result", &expr, None, false).unwrap();
2646
2647        let names = node.downstream_names();
2648        assert!(
2649            names.iter().any(|n| n == "t.x"),
2650            "Expected t.x in downstream, got: {:?}",
2651            names
2652        );
2653        assert!(
2654            names.iter().any(|n| n == "t.name"),
2655            "Expected t.name in downstream, got: {:?}",
2656            names
2657        );
2658    }
2659
2660    #[test]
2661    fn test_lineage_substring_function() {
2662        let expr = parse("SELECT SUBSTRING(name, 1, 3) AS short FROM t");
2663        let node = lineage("short", &expr, None, false).unwrap();
2664
2665        let names = node.downstream_names();
2666        assert!(
2667            names.iter().any(|n| n == "t.name"),
2668            "Expected t.name in downstream, got: {:?}",
2669            names
2670        );
2671    }
2672}