Skip to main content

polyglot_sql/
lineage.rs

1//! Column Lineage Tracking
2//!
3//! This module provides functionality to track column lineage through SQL queries,
4//! building a graph of how columns flow from source tables to the result set.
5//! Supports UNION/INTERSECT/EXCEPT, CTEs, derived tables, subqueries, and star expansion.
6//!
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::optimizer::qualify_columns::{qualify_columns, QualifyColumnsOptions};
11use crate::schema::Schema;
12use crate::scope::{build_scope, Scope};
13use crate::traversal::ExpressionWalk;
14use crate::{Error, Result};
15use serde::{Deserialize, Serialize};
16use std::collections::HashSet;
17
/// A node in the column lineage graph.
///
/// Each node describes one column at one step of the trace. The tree is rooted at the
/// requested output column and fans out, through `downstream`, towards the columns and
/// tables that column is computed from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineageNode {
    /// Name of this lineage step (e.g., "table.column")
    pub name: String,
    /// The expression at this node
    pub expression: Expression,
    /// The expression this node was resolved against: the enclosing query (possibly
    /// trimmed to the target column) or, for terminal nodes, the table expression
    pub source: Expression,
    /// Nodes this one is derived from. Despite the field name, the builders in this
    /// module push the *inputs* of a column here (scalar subquery outputs, CTE and
    /// derived-table columns, and finally base-table columns).
    pub downstream: Vec<LineageNode>,
    /// Source name (e.g., for derived tables); empty when not set
    pub source_name: String,
    /// Reference node name (e.g., for CTEs); empty when not set
    pub reference_node_name: String,
}
34
35impl LineageNode {
36    /// Create a new lineage node
37    pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
38        Self {
39            name: name.into(),
40            expression,
41            source,
42            downstream: Vec::new(),
43            source_name: String::new(),
44            reference_node_name: String::new(),
45        }
46    }
47
48    /// Iterate over all nodes in the lineage graph using DFS
49    pub fn walk(&self) -> LineageWalker<'_> {
50        LineageWalker { stack: vec![self] }
51    }
52
53    /// Get all downstream column names
54    pub fn downstream_names(&self) -> Vec<String> {
55        self.downstream.iter().map(|n| n.name.clone()).collect()
56    }
57}
58
59/// Iterator for walking the lineage graph
60pub struct LineageWalker<'a> {
61    stack: Vec<&'a LineageNode>,
62}
63
64impl<'a> Iterator for LineageWalker<'a> {
65    type Item = &'a LineageNode;
66
67    fn next(&mut self) -> Option<Self::Item> {
68        if let Some(node) = self.stack.pop() {
69            // Add children in reverse order so they're visited in order
70            for child in node.downstream.iter().rev() {
71                self.stack.push(child);
72            }
73            Some(node)
74        } else {
75            None
76        }
77    }
78}
79
80// ---------------------------------------------------------------------------
81// ColumnRef: name or positional index for column lookup
82// ---------------------------------------------------------------------------
83
/// Column reference for lineage tracing — by name or positional index.
enum ColumnRef<'a> {
    /// Look the column up by its output name (alias or plain column name).
    Name(&'a str),
    /// Look the column up by its zero-based position in the select list.
    Index(usize),
}
89
90// ---------------------------------------------------------------------------
91// Public API
92// ---------------------------------------------------------------------------
93
/// Build the lineage graph for a column in a SQL query
///
/// # Arguments
/// * `column` - The column name to trace lineage for
/// * `sql` - The SQL expression (SELECT, UNION, etc.)
/// * `dialect` - Optional dialect, passed through to the lineage builder
/// * `trim_selects` - If true, trim the source SELECT to only include the target column
///
/// # Returns
/// The root lineage node for the specified column
///
/// # Errors
/// Returns an error when `column` cannot be found in the select list (for set
/// operations: in the left-most SELECT branch), or when the expression is neither a
/// SELECT nor a set operation.
///
/// # Example
/// ```ignore
/// use polyglot_sql::lineage::lineage;
/// use polyglot_sql::parse_one;
/// use polyglot_sql::DialectType;
///
/// let sql = "SELECT a, b + 1 AS c FROM t";
/// let expr = parse_one(sql, DialectType::Generic).unwrap();
/// let node = lineage("c", &expr, None, false).unwrap();
/// ```
pub fn lineage(
    column: &str,
    sql: &Expression,
    dialect: Option<DialectType>,
    trim_selects: bool,
) -> Result<LineageNode> {
    lineage_from_expression(column, sql, dialect, trim_selects)
}
123
124/// Build the lineage graph for a column in a SQL query using optional schema metadata.
125///
126/// When `schema` is provided, the query is first qualified with
127/// `optimizer::qualify_columns`, allowing more accurate lineage for unqualified or
128/// ambiguous column references.
129///
130/// # Arguments
131/// * `column` - The column name to trace lineage for
132/// * `sql` - The SQL expression (SELECT, UNION, etc.)
133/// * `schema` - Optional schema used for qualification
134/// * `dialect` - Optional dialect for qualification and lineage handling
135/// * `trim_selects` - If true, trim the source SELECT to only include the target column
136///
137/// # Returns
138/// The root lineage node for the specified column
139pub fn lineage_with_schema(
140    column: &str,
141    sql: &Expression,
142    schema: Option<&dyn Schema>,
143    dialect: Option<DialectType>,
144    trim_selects: bool,
145) -> Result<LineageNode> {
146    let qualified_expression = if let Some(schema) = schema {
147        let options = if let Some(dialect_type) = dialect.or_else(|| schema.dialect()) {
148            QualifyColumnsOptions::new().with_dialect(dialect_type)
149        } else {
150            QualifyColumnsOptions::new()
151        };
152
153        qualify_columns(sql.clone(), schema, &options).map_err(|e| {
154            Error::internal(format!("Lineage qualification failed with schema: {}", e))
155        })?
156    } else {
157        sql.clone()
158    };
159
160    lineage_from_expression(column, &qualified_expression, dialect, trim_selects)
161}
162
/// Shared entry point behind the public lineage functions: builds the scope tree for
/// `sql` and traces `column` by name from the root scope.
///
/// The three empty strings are the initial scope name, source name and reference node
/// name for the root node.
fn lineage_from_expression(
    column: &str,
    sql: &Expression,
    dialect: Option<DialectType>,
    trim_selects: bool,
) -> Result<LineageNode> {
    let scope = build_scope(sql);
    to_node(
        ColumnRef::Name(column),
        &scope,
        dialect,
        "",
        "",
        "",
        trim_selects,
    )
}
180
181/// Get all source tables from a lineage graph
182pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
183    let mut tables = HashSet::new();
184    collect_source_tables(node, &mut tables);
185    tables
186}
187
188/// Recursively collect source table names from lineage graph
189pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
190    if let Expression::Table(table) = &node.source {
191        tables.insert(table.name.name.clone());
192    }
193    for child in &node.downstream {
194        collect_source_tables(child, tables);
195    }
196}
197
198// ---------------------------------------------------------------------------
199// Core recursive lineage builder
200// ---------------------------------------------------------------------------
201
/// Recursively build a lineage node for a column in a scope.
///
/// Convenience wrapper around `to_node_inner` for the top-level call, where there are
/// no ancestor CTE scopes yet (hence the empty slice).
fn to_node(
    column: ColumnRef<'_>,
    scope: &Scope,
    dialect: Option<DialectType>,
    scope_name: &str,
    source_name: &str,
    reference_node_name: &str,
    trim_selects: bool,
) -> Result<LineageNode> {
    to_node_inner(
        column,
        scope,
        dialect,
        scope_name,
        source_name,
        reference_node_name,
        trim_selects,
        &[],
    )
}
223
/// Core of the lineage builder: resolve `column` within `scope` and recursively attach
/// the nodes it is derived from.
///
/// * `column` - the column to resolve, by name or by select-list position
/// * `scope` - scope whose expression is searched; a CTE scope is unwrapped to its query
/// * `scope_name` - only forwarded to set-operation handling in this function
/// * `source_name` / `reference_node_name` - copied verbatim onto the created node
/// * `trim_selects` - when true, the node's `source` keeps only the target column
/// * `ancestor_cte_scopes` - CTE scopes inherited from enclosing scopes, searched in
///   addition to `scope.cte_scopes` when a CTE reference is resolved
///
/// Returns an error when the column cannot be located in this scope. Failures while
/// tracing scalar subqueries are ignored: the branch is simply not attached.
fn to_node_inner(
    column: ColumnRef<'_>,
    scope: &Scope,
    dialect: Option<DialectType>,
    scope_name: &str,
    source_name: &str,
    reference_node_name: &str,
    trim_selects: bool,
    ancestor_cte_scopes: &[Scope],
) -> Result<LineageNode> {
    let scope_expr = &scope.expression;

    // Build combined CTE scopes: current scope's cte_scopes + ancestors
    let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
    for s in ancestor_cte_scopes {
        all_cte_scopes.push(s);
    }

    // 0. Unwrap CTE scope — CTE scope expressions are Expression::Cte(...)
    //    but we need the inner query (SELECT/UNION) for column lookup.
    let effective_expr = match scope_expr {
        Expression::Cte(cte) => &cte.this,
        other => other,
    };

    // 1. Set operations (UNION / INTERSECT / EXCEPT)
    if matches!(
        effective_expr,
        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
    ) {
        // For CTE wrapping a set op, create a temporary scope with the inner expression
        if matches!(scope_expr, Expression::Cte(_)) {
            // Same child scopes and sources as `scope`; only the expression differs.
            let mut inner_scope = Scope::new(effective_expr.clone());
            inner_scope.union_scopes = scope.union_scopes.clone();
            inner_scope.sources = scope.sources.clone();
            inner_scope.cte_sources = scope.cte_sources.clone();
            inner_scope.cte_scopes = scope.cte_scopes.clone();
            inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
            inner_scope.subquery_scopes = scope.subquery_scopes.clone();
            return handle_set_operation(
                &column,
                &inner_scope,
                dialect,
                scope_name,
                source_name,
                reference_node_name,
                trim_selects,
                ancestor_cte_scopes,
            );
        }
        return handle_set_operation(
            &column,
            scope,
            dialect,
            scope_name,
            source_name,
            reference_node_name,
            trim_selects,
            ancestor_cte_scopes,
        );
    }

    // 2. Find the select expression for this column
    let select_expr = find_select_expr(effective_expr, &column)?;
    let column_name = resolve_column_name(&column, &select_expr);

    // 3. Trim source if requested
    let node_source = if trim_selects {
        trim_source(effective_expr, &select_expr)
    } else {
        effective_expr.clone()
    };

    // 4. Create the lineage node
    let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
    node.source_name = source_name.to_string();
    node.reference_node_name = reference_node_name.to_string();

    // 5. Star handling — add downstream for each source
    //    (one "<source>.*" child per source; the star is not expanded further here)
    if matches!(&select_expr, Expression::Star(_)) {
        for (name, source_info) in &scope.sources {
            let child = LineageNode::new(
                format!("{}.*", name),
                Expression::Star(crate::expressions::Star {
                    table: None,
                    except: None,
                    replace: None,
                    rename: None,
                    trailing_comments: vec![],
                    span: None,
                }),
                source_info.expression.clone(),
            );
            node.downstream.push(child);
        }
        return Ok(node);
    }

    // 6. Subqueries in select — trace through scalar subqueries
    //    (only un-aliased subqueries; each is matched to its scope by comparing the
    //    scope expression with the subquery body, and traced via its first column)
    let subqueries: Vec<&Expression> =
        select_expr.find_all(|e| matches!(e, Expression::Subquery(sq) if sq.alias.is_none()));
    for sq_expr in subqueries {
        if let Expression::Subquery(sq) = sq_expr {
            for sq_scope in &scope.subquery_scopes {
                if sq_scope.expression == sq.this {
                    // Best effort: a subquery that cannot be traced is left out.
                    if let Ok(child) = to_node_inner(
                        ColumnRef::Index(0),
                        sq_scope,
                        dialect,
                        &column_name,
                        "",
                        "",
                        trim_selects,
                        ancestor_cte_scopes,
                    ) {
                        node.downstream.push(child);
                    }
                    break;
                }
            }
        }
    }

    // 7. Column references — trace each column to its source
    //    (these children are appended after any subquery children from step 6)
    let col_refs = find_column_refs_in_expr(&select_expr);
    for col_ref in col_refs {
        let col_name = &col_ref.column;
        if let Some(ref table_id) = col_ref.table {
            let tbl = &table_id.name;
            resolve_qualified_column(
                &mut node,
                scope,
                dialect,
                tbl,
                col_name,
                &column_name,
                trim_selects,
                &all_cte_scopes,
            );
        } else {
            resolve_unqualified_column(
                &mut node,
                scope,
                dialect,
                col_name,
                &column_name,
                trim_selects,
                &all_cte_scopes,
            );
        }
    }

    Ok(node)
}
378
379// ---------------------------------------------------------------------------
380// Set operation handling
381// ---------------------------------------------------------------------------
382
383fn handle_set_operation(
384    column: &ColumnRef<'_>,
385    scope: &Scope,
386    dialect: Option<DialectType>,
387    scope_name: &str,
388    source_name: &str,
389    reference_node_name: &str,
390    trim_selects: bool,
391    ancestor_cte_scopes: &[Scope],
392) -> Result<LineageNode> {
393    let scope_expr = &scope.expression;
394
395    // Determine column index
396    let col_index = match column {
397        ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
398        ColumnRef::Index(i) => *i,
399    };
400
401    let col_name = match column {
402        ColumnRef::Name(name) => name.to_string(),
403        ColumnRef::Index(_) => format!("_{col_index}"),
404    };
405
406    let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
407    node.source_name = source_name.to_string();
408    node.reference_node_name = reference_node_name.to_string();
409
410    // Recurse into each union branch
411    for branch_scope in &scope.union_scopes {
412        if let Ok(child) = to_node_inner(
413            ColumnRef::Index(col_index),
414            branch_scope,
415            dialect,
416            scope_name,
417            "",
418            "",
419            trim_selects,
420            ancestor_cte_scopes,
421        ) {
422            node.downstream.push(child);
423        }
424    }
425
426    Ok(node)
427}
428
429// ---------------------------------------------------------------------------
430// Column resolution helpers
431// ---------------------------------------------------------------------------
432
433fn resolve_qualified_column(
434    node: &mut LineageNode,
435    scope: &Scope,
436    dialect: Option<DialectType>,
437    table: &str,
438    col_name: &str,
439    parent_name: &str,
440    trim_selects: bool,
441    all_cte_scopes: &[&Scope],
442) {
443    // Check if table is a CTE reference (cte_sources tracks CTE names)
444    if scope.cte_sources.contains_key(table) {
445        if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
446            // Build ancestor CTE scopes from all_cte_scopes for the recursive call
447            let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
448            if let Ok(child) = to_node_inner(
449                ColumnRef::Name(col_name),
450                child_scope,
451                dialect,
452                parent_name,
453                table,
454                parent_name,
455                trim_selects,
456                &ancestors,
457            ) {
458                node.downstream.push(child);
459                return;
460            }
461        }
462    }
463
464    // Check if table is a derived table (is_scope = true in sources)
465    if let Some(source_info) = scope.sources.get(table) {
466        if source_info.is_scope {
467            if let Some(child_scope) = find_child_scope(scope, table) {
468                let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
469                if let Ok(child) = to_node_inner(
470                    ColumnRef::Name(col_name),
471                    child_scope,
472                    dialect,
473                    parent_name,
474                    table,
475                    parent_name,
476                    trim_selects,
477                    &ancestors,
478                ) {
479                    node.downstream.push(child);
480                    return;
481                }
482            }
483        }
484    }
485
486    // Base table source found in current scope: preserve alias in the display name
487    // but store the resolved table expression and name for downstream consumers.
488    if let Some(source_info) = scope.sources.get(table) {
489        if !source_info.is_scope {
490            node.downstream.push(make_table_column_node_from_source(
491                table,
492                col_name,
493                &source_info.expression,
494            ));
495            return;
496        }
497    }
498
499    // Base table or unresolved — terminal node
500    node.downstream
501        .push(make_table_column_node(table, col_name));
502}
503
504fn resolve_unqualified_column(
505    node: &mut LineageNode,
506    scope: &Scope,
507    dialect: Option<DialectType>,
508    col_name: &str,
509    parent_name: &str,
510    trim_selects: bool,
511    all_cte_scopes: &[&Scope],
512) {
513    // Try to find which source this column belongs to.
514    // Build the source list from the actual FROM/JOIN clauses to avoid
515    // mixing in CTE definitions that are in scope but not referenced.
516    let from_source_names = source_names_from_from_join(scope);
517
518    if from_source_names.len() == 1 {
519        let tbl = &from_source_names[0];
520        resolve_qualified_column(
521            node,
522            scope,
523            dialect,
524            tbl,
525            col_name,
526            parent_name,
527            trim_selects,
528            all_cte_scopes,
529        );
530        return;
531    }
532
533    // Multiple sources — can't resolve without schema info, add unqualified node
534    let child = LineageNode::new(
535        col_name.to_string(),
536        Expression::Column(crate::expressions::Column {
537            name: crate::expressions::Identifier::new(col_name.to_string()),
538            table: None,
539            join_mark: false,
540            trailing_comments: vec![],
541            span: None,
542        }),
543        node.source.clone(),
544    );
545    node.downstream.push(child);
546}
547
548fn source_names_from_from_join(scope: &Scope) -> Vec<String> {
549    fn source_name(expr: &Expression) -> Option<String> {
550        match expr {
551            Expression::Table(table) => Some(
552                table
553                    .alias
554                    .as_ref()
555                    .map(|a| a.name.clone())
556                    .unwrap_or_else(|| table.name.name.clone()),
557            ),
558            Expression::Subquery(subquery) => {
559                subquery.alias.as_ref().map(|alias| alias.name.clone())
560            }
561            Expression::Paren(paren) => source_name(&paren.this),
562            _ => None,
563        }
564    }
565
566    let effective_expr = match &scope.expression {
567        Expression::Cte(cte) => &cte.this,
568        expr => expr,
569    };
570
571    let mut names = Vec::new();
572    let mut seen = std::collections::HashSet::new();
573
574    if let Expression::Select(select) = effective_expr {
575        if let Some(from) = &select.from {
576            for expr in &from.expressions {
577                if let Some(name) = source_name(expr) {
578                    if !name.is_empty() && seen.insert(name.clone()) {
579                        names.push(name);
580                    }
581                }
582            }
583        }
584        for join in &select.joins {
585            if let Some(name) = source_name(&join.this) {
586                if !name.is_empty() && seen.insert(name.clone()) {
587                    names.push(name);
588                }
589            }
590        }
591    }
592
593    names
594}
595
596// ---------------------------------------------------------------------------
597// Helper functions
598// ---------------------------------------------------------------------------
599
600/// Get the alias or name of an expression
601fn get_alias_or_name(expr: &Expression) -> Option<String> {
602    match expr {
603        Expression::Alias(alias) => Some(alias.alias.name.clone()),
604        Expression::Column(col) => Some(col.name.name.clone()),
605        Expression::Identifier(id) => Some(id.name.clone()),
606        Expression::Star(_) => Some("*".to_string()),
607        _ => None,
608    }
609}
610
611/// Resolve the display name for a column reference.
612fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
613    match column {
614        ColumnRef::Name(n) => n.to_string(),
615        ColumnRef::Index(_) => get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string()),
616    }
617}
618
619/// Find the select expression matching a column reference.
620fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
621    if let Expression::Select(ref select) = scope_expr {
622        match column {
623            ColumnRef::Name(name) => {
624                for expr in &select.expressions {
625                    if get_alias_or_name(expr).as_deref() == Some(name) {
626                        return Ok(expr.clone());
627                    }
628                }
629                Err(crate::error::Error::parse(
630                    format!("Cannot find column '{}' in query", name),
631                    0,
632                    0,
633                    0,
634                    0,
635                ))
636            }
637            ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
638                crate::error::Error::parse(format!("Column index {} out of range", idx), 0, 0, 0, 0)
639            }),
640        }
641    } else {
642        Err(crate::error::Error::parse(
643            "Expected SELECT expression for column lookup",
644            0,
645            0,
646            0,
647            0,
648        ))
649    }
650}
651
652/// Find the positional index of a column name in a set operation's first SELECT branch.
653fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
654    let mut expr = set_op_expr;
655    loop {
656        match expr {
657            Expression::Union(u) => expr = &u.left,
658            Expression::Intersect(i) => expr = &i.left,
659            Expression::Except(e) => expr = &e.left,
660            Expression::Select(select) => {
661                for (i, e) in select.expressions.iter().enumerate() {
662                    if get_alias_or_name(e).as_deref() == Some(name) {
663                        return Ok(i);
664                    }
665                }
666                return Err(crate::error::Error::parse(
667                    format!("Cannot find column '{}' in set operation", name),
668                    0,
669                    0,
670                    0,
671                    0,
672                ));
673            }
674            _ => {
675                return Err(crate::error::Error::parse(
676                    "Expected SELECT or set operation",
677                    0,
678                    0,
679                    0,
680                    0,
681                ))
682            }
683        }
684    }
685}
686
687/// If trim_selects is enabled, return a copy of the SELECT with only the target column.
688fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
689    if let Expression::Select(select) = select_expr {
690        let mut trimmed = select.as_ref().clone();
691        trimmed.expressions = vec![target_expr.clone()];
692        Expression::Select(Box::new(trimmed))
693    } else {
694        select_expr.clone()
695    }
696}
697
698/// Find the child scope (CTE or derived table) for a given source name.
699fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
700    // Check CTE scopes
701    if scope.cte_sources.contains_key(source_name) {
702        for cte_scope in &scope.cte_scopes {
703            if let Expression::Cte(cte) = &cte_scope.expression {
704                if cte.alias.name == source_name {
705                    return Some(cte_scope);
706                }
707            }
708        }
709    }
710
711    // Check derived table scopes
712    if let Some(source_info) = scope.sources.get(source_name) {
713        if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
714            if let Expression::Subquery(sq) = &source_info.expression {
715                for dt_scope in &scope.derived_table_scopes {
716                    if dt_scope.expression == sq.this {
717                        return Some(dt_scope);
718                    }
719                }
720            }
721        }
722    }
723
724    None
725}
726
727/// Find a CTE scope by name, searching through a combined list of CTE scopes.
728/// This handles nested CTEs where the current scope doesn't have the CTE scope
729/// as a direct child but knows about it via cte_sources.
730fn find_child_scope_in<'a>(
731    all_cte_scopes: &[&'a Scope],
732    scope: &'a Scope,
733    source_name: &str,
734) -> Option<&'a Scope> {
735    // First try the scope's own cte_scopes
736    for cte_scope in &scope.cte_scopes {
737        if let Expression::Cte(cte) = &cte_scope.expression {
738            if cte.alias.name == source_name {
739                return Some(cte_scope);
740            }
741        }
742    }
743
744    // Then search through all ancestor CTE scopes
745    for cte_scope in all_cte_scopes {
746        if let Expression::Cte(cte) = &cte_scope.expression {
747            if cte.alias.name == source_name {
748                return Some(cte_scope);
749            }
750        }
751    }
752
753    // Fall back to derived table scopes
754    if let Some(source_info) = scope.sources.get(source_name) {
755        if source_info.is_scope {
756            if let Expression::Subquery(sq) = &source_info.expression {
757                for dt_scope in &scope.derived_table_scopes {
758                    if dt_scope.expression == sq.this {
759                        return Some(dt_scope);
760                    }
761                }
762            }
763        }
764    }
765
766    None
767}
768
769/// Create a terminal lineage node for a table.column reference.
770fn make_table_column_node(table: &str, column: &str) -> LineageNode {
771    let mut node = LineageNode::new(
772        format!("{}.{}", table, column),
773        Expression::Column(crate::expressions::Column {
774            name: crate::expressions::Identifier::new(column.to_string()),
775            table: Some(crate::expressions::Identifier::new(table.to_string())),
776            join_mark: false,
777            trailing_comments: vec![],
778            span: None,
779        }),
780        Expression::Table(crate::expressions::TableRef::new(table)),
781    );
782    node.source_name = table.to_string();
783    node
784}
785
786fn table_name_from_table_ref(table_ref: &crate::expressions::TableRef) -> String {
787    let mut parts: Vec<String> = Vec::new();
788    if let Some(catalog) = &table_ref.catalog {
789        parts.push(catalog.name.clone());
790    }
791    if let Some(schema) = &table_ref.schema {
792        parts.push(schema.name.clone());
793    }
794    parts.push(table_ref.name.name.clone());
795    parts.join(".")
796}
797
798fn make_table_column_node_from_source(
799    table_alias: &str,
800    column: &str,
801    source: &Expression,
802) -> LineageNode {
803    let mut node = LineageNode::new(
804        format!("{}.{}", table_alias, column),
805        Expression::Column(crate::expressions::Column {
806            name: crate::expressions::Identifier::new(column.to_string()),
807            table: Some(crate::expressions::Identifier::new(table_alias.to_string())),
808            join_mark: false,
809            trailing_comments: vec![],
810            span: None,
811        }),
812        source.clone(),
813    );
814
815    if let Expression::Table(table_ref) = source {
816        node.source_name = table_name_from_table_ref(table_ref);
817    } else {
818        node.source_name = table_alias.to_string();
819    }
820
821    node
822}
823
/// Simple column reference extracted from an expression
#[derive(Debug, Clone)]
struct SimpleColumnRef {
    /// Table qualifier of the reference, if it had one
    table: Option<crate::expressions::Identifier>,
    /// Column name
    column: String,
}
830
/// Find all column references in an expression (does not recurse into subqueries).
///
/// Thin wrapper that allocates the result vector and delegates the traversal to
/// `collect_column_refs`.
fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
    let mut refs = Vec::new();
    collect_column_refs(expr, &mut refs);
    refs
}
837
838fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
839    let mut stack: Vec<&Expression> = vec![expr];
840
841    while let Some(current) = stack.pop() {
842        match current {
843            // === Leaf: collect Column references ===
844            Expression::Column(col) => {
845                refs.push(SimpleColumnRef {
846                    table: col.table.clone(),
847                    column: col.name.name.clone(),
848                });
849            }
850
851            // === Boundary: don't recurse into subqueries (handled separately) ===
852            Expression::Subquery(_) | Expression::Exists(_) => {}
853
854            // === BinaryOp variants: left, right ===
855            Expression::And(op)
856            | Expression::Or(op)
857            | Expression::Eq(op)
858            | Expression::Neq(op)
859            | Expression::Lt(op)
860            | Expression::Lte(op)
861            | Expression::Gt(op)
862            | Expression::Gte(op)
863            | Expression::Add(op)
864            | Expression::Sub(op)
865            | Expression::Mul(op)
866            | Expression::Div(op)
867            | Expression::Mod(op)
868            | Expression::BitwiseAnd(op)
869            | Expression::BitwiseOr(op)
870            | Expression::BitwiseXor(op)
871            | Expression::BitwiseLeftShift(op)
872            | Expression::BitwiseRightShift(op)
873            | Expression::Concat(op)
874            | Expression::Adjacent(op)
875            | Expression::TsMatch(op)
876            | Expression::PropertyEQ(op)
877            | Expression::ArrayContainsAll(op)
878            | Expression::ArrayContainedBy(op)
879            | Expression::ArrayOverlaps(op)
880            | Expression::JSONBContainsAllTopKeys(op)
881            | Expression::JSONBContainsAnyTopKeys(op)
882            | Expression::JSONBDeleteAtPath(op)
883            | Expression::ExtendsLeft(op)
884            | Expression::ExtendsRight(op)
885            | Expression::Is(op)
886            | Expression::MemberOf(op)
887            | Expression::NullSafeEq(op)
888            | Expression::NullSafeNeq(op)
889            | Expression::Glob(op)
890            | Expression::Match(op) => {
891                stack.push(&op.left);
892                stack.push(&op.right);
893            }
894
895            // === UnaryOp variants: this ===
896            Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
897                stack.push(&u.this);
898            }
899
900            // === UnaryFunc variants: this ===
901            Expression::Upper(f)
902            | Expression::Lower(f)
903            | Expression::Length(f)
904            | Expression::LTrim(f)
905            | Expression::RTrim(f)
906            | Expression::Reverse(f)
907            | Expression::Abs(f)
908            | Expression::Sqrt(f)
909            | Expression::Cbrt(f)
910            | Expression::Ln(f)
911            | Expression::Exp(f)
912            | Expression::Sign(f)
913            | Expression::Date(f)
914            | Expression::Time(f)
915            | Expression::DateFromUnixDate(f)
916            | Expression::UnixDate(f)
917            | Expression::UnixSeconds(f)
918            | Expression::UnixMillis(f)
919            | Expression::UnixMicros(f)
920            | Expression::TimeStrToDate(f)
921            | Expression::DateToDi(f)
922            | Expression::DiToDate(f)
923            | Expression::TsOrDiToDi(f)
924            | Expression::TsOrDsToDatetime(f)
925            | Expression::TsOrDsToTimestamp(f)
926            | Expression::YearOfWeek(f)
927            | Expression::YearOfWeekIso(f)
928            | Expression::Initcap(f)
929            | Expression::Ascii(f)
930            | Expression::Chr(f)
931            | Expression::Soundex(f)
932            | Expression::ByteLength(f)
933            | Expression::Hex(f)
934            | Expression::LowerHex(f)
935            | Expression::Unicode(f)
936            | Expression::Radians(f)
937            | Expression::Degrees(f)
938            | Expression::Sin(f)
939            | Expression::Cos(f)
940            | Expression::Tan(f)
941            | Expression::Asin(f)
942            | Expression::Acos(f)
943            | Expression::Atan(f)
944            | Expression::IsNan(f)
945            | Expression::IsInf(f)
946            | Expression::ArrayLength(f)
947            | Expression::ArraySize(f)
948            | Expression::Cardinality(f)
949            | Expression::ArrayReverse(f)
950            | Expression::ArrayDistinct(f)
951            | Expression::ArrayFlatten(f)
952            | Expression::ArrayCompact(f)
953            | Expression::Explode(f)
954            | Expression::ExplodeOuter(f)
955            | Expression::ToArray(f)
956            | Expression::MapFromEntries(f)
957            | Expression::MapKeys(f)
958            | Expression::MapValues(f)
959            | Expression::JsonArrayLength(f)
960            | Expression::JsonKeys(f)
961            | Expression::JsonType(f)
962            | Expression::ParseJson(f)
963            | Expression::ToJson(f)
964            | Expression::Typeof(f)
965            | Expression::BitwiseCount(f)
966            | Expression::Year(f)
967            | Expression::Month(f)
968            | Expression::Day(f)
969            | Expression::Hour(f)
970            | Expression::Minute(f)
971            | Expression::Second(f)
972            | Expression::DayOfWeek(f)
973            | Expression::DayOfWeekIso(f)
974            | Expression::DayOfMonth(f)
975            | Expression::DayOfYear(f)
976            | Expression::WeekOfYear(f)
977            | Expression::Quarter(f)
978            | Expression::Epoch(f)
979            | Expression::EpochMs(f)
980            | Expression::TimeStrToUnix(f)
981            | Expression::SHA(f)
982            | Expression::SHA1Digest(f)
983            | Expression::TimeToUnix(f)
984            | Expression::JSONBool(f)
985            | Expression::Int64(f)
986            | Expression::MD5NumberLower64(f)
987            | Expression::MD5NumberUpper64(f)
988            | Expression::DateStrToDate(f)
989            | Expression::DateToDateStr(f) => {
990                stack.push(&f.this);
991            }
992
993            // === BinaryFunc variants: this, expression ===
994            Expression::Power(f)
995            | Expression::NullIf(f)
996            | Expression::IfNull(f)
997            | Expression::Nvl(f)
998            | Expression::UnixToTimeStr(f)
999            | Expression::Contains(f)
1000            | Expression::StartsWith(f)
1001            | Expression::EndsWith(f)
1002            | Expression::Levenshtein(f)
1003            | Expression::ModFunc(f)
1004            | Expression::Atan2(f)
1005            | Expression::IntDiv(f)
1006            | Expression::AddMonths(f)
1007            | Expression::MonthsBetween(f)
1008            | Expression::NextDay(f)
1009            | Expression::ArrayContains(f)
1010            | Expression::ArrayPosition(f)
1011            | Expression::ArrayAppend(f)
1012            | Expression::ArrayPrepend(f)
1013            | Expression::ArrayUnion(f)
1014            | Expression::ArrayExcept(f)
1015            | Expression::ArrayRemove(f)
1016            | Expression::StarMap(f)
1017            | Expression::MapFromArrays(f)
1018            | Expression::MapContainsKey(f)
1019            | Expression::ElementAt(f)
1020            | Expression::JsonMergePatch(f)
1021            | Expression::JSONBContains(f)
1022            | Expression::JSONBExtract(f) => {
1023                stack.push(&f.this);
1024                stack.push(&f.expression);
1025            }
1026
1027            // === VarArgFunc variants: expressions ===
1028            Expression::Greatest(f)
1029            | Expression::Least(f)
1030            | Expression::Coalesce(f)
1031            | Expression::ArrayConcat(f)
1032            | Expression::ArrayIntersect(f)
1033            | Expression::ArrayZip(f)
1034            | Expression::MapConcat(f)
1035            | Expression::JsonArray(f) => {
1036                for e in &f.expressions {
1037                    stack.push(e);
1038                }
1039            }
1040
1041            // === AggFunc variants: this, filter, having_max, limit ===
1042            Expression::Sum(f)
1043            | Expression::Avg(f)
1044            | Expression::Min(f)
1045            | Expression::Max(f)
1046            | Expression::ArrayAgg(f)
1047            | Expression::CountIf(f)
1048            | Expression::Stddev(f)
1049            | Expression::StddevPop(f)
1050            | Expression::StddevSamp(f)
1051            | Expression::Variance(f)
1052            | Expression::VarPop(f)
1053            | Expression::VarSamp(f)
1054            | Expression::Median(f)
1055            | Expression::Mode(f)
1056            | Expression::First(f)
1057            | Expression::Last(f)
1058            | Expression::AnyValue(f)
1059            | Expression::ApproxDistinct(f)
1060            | Expression::ApproxCountDistinct(f)
1061            | Expression::LogicalAnd(f)
1062            | Expression::LogicalOr(f)
1063            | Expression::Skewness(f)
1064            | Expression::ArrayConcatAgg(f)
1065            | Expression::ArrayUniqueAgg(f)
1066            | Expression::BoolXorAgg(f)
1067            | Expression::BitwiseAndAgg(f)
1068            | Expression::BitwiseOrAgg(f)
1069            | Expression::BitwiseXorAgg(f) => {
1070                stack.push(&f.this);
1071                if let Some(ref filter) = f.filter {
1072                    stack.push(filter);
1073                }
1074                if let Some((ref expr, _)) = f.having_max {
1075                    stack.push(expr);
1076                }
1077                if let Some(ref limit) = f.limit {
1078                    stack.push(limit);
1079                }
1080            }
1081
1082            // === Generic Function / AggregateFunction: args ===
1083            Expression::Function(func) => {
1084                for arg in &func.args {
1085                    stack.push(arg);
1086                }
1087            }
1088            Expression::AggregateFunction(func) => {
1089                for arg in &func.args {
1090                    stack.push(arg);
1091                }
1092                if let Some(ref filter) = func.filter {
1093                    stack.push(filter);
1094                }
1095                if let Some(ref limit) = func.limit {
1096                    stack.push(limit);
1097                }
1098            }
1099
1100            // === WindowFunction: this (skip Over for lineage purposes) ===
1101            Expression::WindowFunction(wf) => {
1102                stack.push(&wf.this);
1103            }
1104
1105            // === Containers and special expressions ===
1106            Expression::Alias(a) => {
1107                stack.push(&a.this);
1108            }
1109            Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
1110                stack.push(&c.this);
1111                if let Some(ref fmt) = c.format {
1112                    stack.push(fmt);
1113                }
1114                if let Some(ref def) = c.default {
1115                    stack.push(def);
1116                }
1117            }
1118            Expression::Paren(p) => {
1119                stack.push(&p.this);
1120            }
1121            Expression::Annotated(a) => {
1122                stack.push(&a.this);
1123            }
1124            Expression::Case(case) => {
1125                if let Some(ref operand) = case.operand {
1126                    stack.push(operand);
1127                }
1128                for (cond, result) in &case.whens {
1129                    stack.push(cond);
1130                    stack.push(result);
1131                }
1132                if let Some(ref else_expr) = case.else_ {
1133                    stack.push(else_expr);
1134                }
1135            }
1136            Expression::Collation(c) => {
1137                stack.push(&c.this);
1138            }
1139            Expression::In(i) => {
1140                stack.push(&i.this);
1141                for e in &i.expressions {
1142                    stack.push(e);
1143                }
1144                if let Some(ref q) = i.query {
1145                    stack.push(q);
1146                }
1147                if let Some(ref u) = i.unnest {
1148                    stack.push(u);
1149                }
1150            }
1151            Expression::Between(b) => {
1152                stack.push(&b.this);
1153                stack.push(&b.low);
1154                stack.push(&b.high);
1155            }
1156            Expression::IsNull(n) => {
1157                stack.push(&n.this);
1158            }
1159            Expression::IsTrue(t) | Expression::IsFalse(t) => {
1160                stack.push(&t.this);
1161            }
1162            Expression::IsJson(j) => {
1163                stack.push(&j.this);
1164            }
1165            Expression::Like(l) | Expression::ILike(l) => {
1166                stack.push(&l.left);
1167                stack.push(&l.right);
1168                if let Some(ref esc) = l.escape {
1169                    stack.push(esc);
1170                }
1171            }
1172            Expression::SimilarTo(s) => {
1173                stack.push(&s.this);
1174                stack.push(&s.pattern);
1175                if let Some(ref esc) = s.escape {
1176                    stack.push(esc);
1177                }
1178            }
1179            Expression::Ordered(o) => {
1180                stack.push(&o.this);
1181            }
1182            Expression::Array(a) => {
1183                for e in &a.expressions {
1184                    stack.push(e);
1185                }
1186            }
1187            Expression::Tuple(t) => {
1188                for e in &t.expressions {
1189                    stack.push(e);
1190                }
1191            }
1192            Expression::Struct(s) => {
1193                for (_, e) in &s.fields {
1194                    stack.push(e);
1195                }
1196            }
1197            Expression::Subscript(s) => {
1198                stack.push(&s.this);
1199                stack.push(&s.index);
1200            }
1201            Expression::Dot(d) => {
1202                stack.push(&d.this);
1203            }
1204            Expression::MethodCall(m) => {
1205                stack.push(&m.this);
1206                for arg in &m.args {
1207                    stack.push(arg);
1208                }
1209            }
1210            Expression::ArraySlice(s) => {
1211                stack.push(&s.this);
1212                if let Some(ref start) = s.start {
1213                    stack.push(start);
1214                }
1215                if let Some(ref end) = s.end {
1216                    stack.push(end);
1217                }
1218            }
1219            Expression::Lambda(l) => {
1220                stack.push(&l.body);
1221            }
1222            Expression::NamedArgument(n) => {
1223                stack.push(&n.value);
1224            }
1225            Expression::BracedWildcard(e) | Expression::ReturnStmt(e) => {
1226                stack.push(e);
1227            }
1228
1229            // === Custom function structs ===
1230            Expression::Substring(f) => {
1231                stack.push(&f.this);
1232                stack.push(&f.start);
1233                if let Some(ref len) = f.length {
1234                    stack.push(len);
1235                }
1236            }
1237            Expression::Trim(f) => {
1238                stack.push(&f.this);
1239                if let Some(ref chars) = f.characters {
1240                    stack.push(chars);
1241                }
1242            }
1243            Expression::Replace(f) => {
1244                stack.push(&f.this);
1245                stack.push(&f.old);
1246                stack.push(&f.new);
1247            }
1248            Expression::IfFunc(f) => {
1249                stack.push(&f.condition);
1250                stack.push(&f.true_value);
1251                if let Some(ref fv) = f.false_value {
1252                    stack.push(fv);
1253                }
1254            }
1255            Expression::Nvl2(f) => {
1256                stack.push(&f.this);
1257                stack.push(&f.true_value);
1258                stack.push(&f.false_value);
1259            }
1260            Expression::ConcatWs(f) => {
1261                stack.push(&f.separator);
1262                for e in &f.expressions {
1263                    stack.push(e);
1264                }
1265            }
1266            Expression::Count(f) => {
1267                if let Some(ref this) = f.this {
1268                    stack.push(this);
1269                }
1270                if let Some(ref filter) = f.filter {
1271                    stack.push(filter);
1272                }
1273            }
1274            Expression::GroupConcat(f) => {
1275                stack.push(&f.this);
1276                if let Some(ref sep) = f.separator {
1277                    stack.push(sep);
1278                }
1279                if let Some(ref filter) = f.filter {
1280                    stack.push(filter);
1281                }
1282            }
1283            Expression::StringAgg(f) => {
1284                stack.push(&f.this);
1285                if let Some(ref sep) = f.separator {
1286                    stack.push(sep);
1287                }
1288                if let Some(ref filter) = f.filter {
1289                    stack.push(filter);
1290                }
1291                if let Some(ref limit) = f.limit {
1292                    stack.push(limit);
1293                }
1294            }
1295            Expression::ListAgg(f) => {
1296                stack.push(&f.this);
1297                if let Some(ref sep) = f.separator {
1298                    stack.push(sep);
1299                }
1300                if let Some(ref filter) = f.filter {
1301                    stack.push(filter);
1302                }
1303            }
1304            Expression::SumIf(f) => {
1305                stack.push(&f.this);
1306                stack.push(&f.condition);
1307                if let Some(ref filter) = f.filter {
1308                    stack.push(filter);
1309                }
1310            }
1311            Expression::DateAdd(f) | Expression::DateSub(f) => {
1312                stack.push(&f.this);
1313                stack.push(&f.interval);
1314            }
1315            Expression::DateDiff(f) => {
1316                stack.push(&f.this);
1317                stack.push(&f.expression);
1318            }
1319            Expression::DateTrunc(f) | Expression::TimestampTrunc(f) => {
1320                stack.push(&f.this);
1321            }
1322            Expression::Extract(f) => {
1323                stack.push(&f.this);
1324            }
1325            Expression::Round(f) => {
1326                stack.push(&f.this);
1327                if let Some(ref d) = f.decimals {
1328                    stack.push(d);
1329                }
1330            }
1331            Expression::Floor(f) => {
1332                stack.push(&f.this);
1333                if let Some(ref s) = f.scale {
1334                    stack.push(s);
1335                }
1336                if let Some(ref t) = f.to {
1337                    stack.push(t);
1338                }
1339            }
1340            Expression::Ceil(f) => {
1341                stack.push(&f.this);
1342                if let Some(ref d) = f.decimals {
1343                    stack.push(d);
1344                }
1345                if let Some(ref t) = f.to {
1346                    stack.push(t);
1347                }
1348            }
1349            Expression::Log(f) => {
1350                stack.push(&f.this);
1351                if let Some(ref b) = f.base {
1352                    stack.push(b);
1353                }
1354            }
1355            Expression::AtTimeZone(f) => {
1356                stack.push(&f.this);
1357                stack.push(&f.zone);
1358            }
1359            Expression::Lead(f) | Expression::Lag(f) => {
1360                stack.push(&f.this);
1361                if let Some(ref off) = f.offset {
1362                    stack.push(off);
1363                }
1364                if let Some(ref def) = f.default {
1365                    stack.push(def);
1366                }
1367            }
1368            Expression::FirstValue(f) | Expression::LastValue(f) => {
1369                stack.push(&f.this);
1370            }
1371            Expression::NthValue(f) => {
1372                stack.push(&f.this);
1373                stack.push(&f.offset);
1374            }
1375            Expression::Position(f) => {
1376                stack.push(&f.substring);
1377                stack.push(&f.string);
1378                if let Some(ref start) = f.start {
1379                    stack.push(start);
1380                }
1381            }
1382            Expression::Decode(f) => {
1383                stack.push(&f.this);
1384                for (search, result) in &f.search_results {
1385                    stack.push(search);
1386                    stack.push(result);
1387                }
1388                if let Some(ref def) = f.default {
1389                    stack.push(def);
1390                }
1391            }
1392            Expression::CharFunc(f) => {
1393                for arg in &f.args {
1394                    stack.push(arg);
1395                }
1396            }
1397            Expression::ArraySort(f) => {
1398                stack.push(&f.this);
1399                if let Some(ref cmp) = f.comparator {
1400                    stack.push(cmp);
1401                }
1402            }
1403            Expression::ArrayJoin(f) | Expression::ArrayToString(f) => {
1404                stack.push(&f.this);
1405                stack.push(&f.separator);
1406                if let Some(ref nr) = f.null_replacement {
1407                    stack.push(nr);
1408                }
1409            }
1410            Expression::ArrayFilter(f) => {
1411                stack.push(&f.this);
1412                stack.push(&f.filter);
1413            }
1414            Expression::ArrayTransform(f) => {
1415                stack.push(&f.this);
1416                stack.push(&f.transform);
1417            }
1418            Expression::Sequence(f)
1419            | Expression::Generate(f)
1420            | Expression::ExplodingGenerateSeries(f) => {
1421                stack.push(&f.start);
1422                stack.push(&f.stop);
1423                if let Some(ref step) = f.step {
1424                    stack.push(step);
1425                }
1426            }
1427            Expression::JsonExtract(f)
1428            | Expression::JsonExtractScalar(f)
1429            | Expression::JsonQuery(f)
1430            | Expression::JsonValue(f) => {
1431                stack.push(&f.this);
1432                stack.push(&f.path);
1433            }
1434            Expression::JsonExtractPath(f) | Expression::JsonRemove(f) => {
1435                stack.push(&f.this);
1436                for p in &f.paths {
1437                    stack.push(p);
1438                }
1439            }
1440            Expression::JsonObject(f) => {
1441                for (k, v) in &f.pairs {
1442                    stack.push(k);
1443                    stack.push(v);
1444                }
1445            }
1446            Expression::JsonSet(f) | Expression::JsonInsert(f) => {
1447                stack.push(&f.this);
1448                for (path, val) in &f.path_values {
1449                    stack.push(path);
1450                    stack.push(val);
1451                }
1452            }
1453            Expression::Overlay(f) => {
1454                stack.push(&f.this);
1455                stack.push(&f.replacement);
1456                stack.push(&f.from);
1457                if let Some(ref len) = f.length {
1458                    stack.push(len);
1459                }
1460            }
1461            Expression::Convert(f) => {
1462                stack.push(&f.this);
1463                if let Some(ref style) = f.style {
1464                    stack.push(style);
1465                }
1466            }
1467            Expression::ApproxPercentile(f) => {
1468                stack.push(&f.this);
1469                stack.push(&f.percentile);
1470                if let Some(ref acc) = f.accuracy {
1471                    stack.push(acc);
1472                }
1473                if let Some(ref filter) = f.filter {
1474                    stack.push(filter);
1475                }
1476            }
1477            Expression::Percentile(f)
1478            | Expression::PercentileCont(f)
1479            | Expression::PercentileDisc(f) => {
1480                stack.push(&f.this);
1481                stack.push(&f.percentile);
1482                if let Some(ref filter) = f.filter {
1483                    stack.push(filter);
1484                }
1485            }
1486            Expression::WithinGroup(f) => {
1487                stack.push(&f.this);
1488            }
1489            Expression::Left(f) | Expression::Right(f) => {
1490                stack.push(&f.this);
1491                stack.push(&f.length);
1492            }
1493            Expression::Repeat(f) => {
1494                stack.push(&f.this);
1495                stack.push(&f.times);
1496            }
1497            Expression::Lpad(f) | Expression::Rpad(f) => {
1498                stack.push(&f.this);
1499                stack.push(&f.length);
1500                if let Some(ref fill) = f.fill {
1501                    stack.push(fill);
1502                }
1503            }
1504            Expression::Split(f) => {
1505                stack.push(&f.this);
1506                stack.push(&f.delimiter);
1507            }
1508            Expression::RegexpLike(f) => {
1509                stack.push(&f.this);
1510                stack.push(&f.pattern);
1511                if let Some(ref flags) = f.flags {
1512                    stack.push(flags);
1513                }
1514            }
1515            Expression::RegexpReplace(f) => {
1516                stack.push(&f.this);
1517                stack.push(&f.pattern);
1518                stack.push(&f.replacement);
1519                if let Some(ref flags) = f.flags {
1520                    stack.push(flags);
1521                }
1522            }
1523            Expression::RegexpExtract(f) => {
1524                stack.push(&f.this);
1525                stack.push(&f.pattern);
1526                if let Some(ref group) = f.group {
1527                    stack.push(group);
1528                }
1529            }
1530            Expression::ToDate(f) => {
1531                stack.push(&f.this);
1532                if let Some(ref fmt) = f.format {
1533                    stack.push(fmt);
1534                }
1535            }
1536            Expression::ToTimestamp(f) => {
1537                stack.push(&f.this);
1538                if let Some(ref fmt) = f.format {
1539                    stack.push(fmt);
1540                }
1541            }
1542            Expression::DateFormat(f) | Expression::FormatDate(f) => {
1543                stack.push(&f.this);
1544                stack.push(&f.format);
1545            }
1546            Expression::LastDay(f) => {
1547                stack.push(&f.this);
1548            }
1549            Expression::FromUnixtime(f) => {
1550                stack.push(&f.this);
1551                if let Some(ref fmt) = f.format {
1552                    stack.push(fmt);
1553                }
1554            }
1555            Expression::UnixTimestamp(f) => {
1556                if let Some(ref this) = f.this {
1557                    stack.push(this);
1558                }
1559                if let Some(ref fmt) = f.format {
1560                    stack.push(fmt);
1561                }
1562            }
1563            Expression::MakeDate(f) => {
1564                stack.push(&f.year);
1565                stack.push(&f.month);
1566                stack.push(&f.day);
1567            }
1568            Expression::MakeTimestamp(f) => {
1569                stack.push(&f.year);
1570                stack.push(&f.month);
1571                stack.push(&f.day);
1572                stack.push(&f.hour);
1573                stack.push(&f.minute);
1574                stack.push(&f.second);
1575                if let Some(ref tz) = f.timezone {
1576                    stack.push(tz);
1577                }
1578            }
1579            Expression::TruncFunc(f) => {
1580                stack.push(&f.this);
1581                if let Some(ref d) = f.decimals {
1582                    stack.push(d);
1583                }
1584            }
1585            Expression::ArrayFunc(f) => {
1586                for e in &f.expressions {
1587                    stack.push(e);
1588                }
1589            }
1590            Expression::Unnest(f) => {
1591                stack.push(&f.this);
1592                for e in &f.expressions {
1593                    stack.push(e);
1594                }
1595            }
1596            Expression::StructFunc(f) => {
1597                for (_, e) in &f.fields {
1598                    stack.push(e);
1599                }
1600            }
1601            Expression::StructExtract(f) => {
1602                stack.push(&f.this);
1603            }
1604            Expression::NamedStruct(f) => {
1605                for (k, v) in &f.pairs {
1606                    stack.push(k);
1607                    stack.push(v);
1608                }
1609            }
1610            Expression::MapFunc(f) => {
1611                for k in &f.keys {
1612                    stack.push(k);
1613                }
1614                for v in &f.values {
1615                    stack.push(v);
1616                }
1617            }
1618            Expression::TransformKeys(f) | Expression::TransformValues(f) => {
1619                stack.push(&f.this);
1620                stack.push(&f.transform);
1621            }
1622            Expression::JsonArrayAgg(f) => {
1623                stack.push(&f.this);
1624                if let Some(ref filter) = f.filter {
1625                    stack.push(filter);
1626                }
1627            }
1628            Expression::JsonObjectAgg(f) => {
1629                stack.push(&f.key);
1630                stack.push(&f.value);
1631                if let Some(ref filter) = f.filter {
1632                    stack.push(filter);
1633                }
1634            }
1635            Expression::NTile(f) => {
1636                if let Some(ref n) = f.num_buckets {
1637                    stack.push(n);
1638                }
1639            }
1640            Expression::Rand(f) => {
1641                if let Some(ref s) = f.seed {
1642                    stack.push(s);
1643                }
1644                if let Some(ref lo) = f.lower {
1645                    stack.push(lo);
1646                }
1647                if let Some(ref hi) = f.upper {
1648                    stack.push(hi);
1649                }
1650            }
1651            Expression::Any(q) | Expression::All(q) => {
1652                stack.push(&q.this);
1653                stack.push(&q.subquery);
1654            }
1655            Expression::Overlaps(o) => {
1656                if let Some(ref this) = o.this {
1657                    stack.push(this);
1658                }
1659                if let Some(ref expr) = o.expression {
1660                    stack.push(expr);
1661                }
1662                if let Some(ref ls) = o.left_start {
1663                    stack.push(ls);
1664                }
1665                if let Some(ref le) = o.left_end {
1666                    stack.push(le);
1667                }
1668                if let Some(ref rs) = o.right_start {
1669                    stack.push(rs);
1670                }
1671                if let Some(ref re) = o.right_end {
1672                    stack.push(re);
1673                }
1674            }
1675            Expression::Interval(i) => {
1676                if let Some(ref this) = i.this {
1677                    stack.push(this);
1678                }
1679            }
1680            Expression::TimeStrToTime(f) => {
1681                stack.push(&f.this);
1682                if let Some(ref zone) = f.zone {
1683                    stack.push(zone);
1684                }
1685            }
1686            Expression::JSONBExtractScalar(f) => {
1687                stack.push(&f.this);
1688                stack.push(&f.expression);
1689                if let Some(ref jt) = f.json_type {
1690                    stack.push(jt);
1691                }
1692            }
1693
1694            // === True leaves and non-expression-bearing nodes ===
1695            // Literals, Identifier, Star, DataType, Placeholder, Boolean, Null,
1696            // CurrentDate/Time/Timestamp, RowNumber, Rank, DenseRank, PercentRank,
1697            // CumeDist, Random, Pi, SessionUser, DDL statements, clauses, etc.
1698            _ => {}
1699        }
1700    }
1701}
1702
1703// ---------------------------------------------------------------------------
1704// Tests
1705// ---------------------------------------------------------------------------
1706
1707#[cfg(test)]
1708mod tests {
1709    use super::*;
1710    use crate::dialects::{Dialect, DialectType};
1711    use crate::expressions::DataType;
1712    use crate::optimizer::annotate_types::annotate_types;
1713    use crate::schema::{MappingSchema, Schema};
1714
1715    fn parse(sql: &str) -> Expression {
1716        let dialect = Dialect::get(DialectType::Generic);
1717        let ast = dialect.parse(sql).unwrap();
1718        ast.into_iter().next().unwrap()
1719    }
1720
1721    #[test]
1722    fn test_simple_lineage() {
1723        let expr = parse("SELECT a FROM t");
1724        let node = lineage("a", &expr, None, false).unwrap();
1725
1726        assert_eq!(node.name, "a");
1727        assert!(!node.downstream.is_empty(), "Should have downstream nodes");
1728        // Should trace to t.a
1729        let names = node.downstream_names();
1730        assert!(
1731            names.iter().any(|n| n == "t.a"),
1732            "Expected t.a in downstream, got: {:?}",
1733            names
1734        );
1735    }
1736
1737    #[test]
1738    fn test_lineage_walk() {
1739        let root = LineageNode {
1740            name: "col_a".to_string(),
1741            expression: Expression::Null(crate::expressions::Null),
1742            source: Expression::Null(crate::expressions::Null),
1743            downstream: vec![LineageNode::new(
1744                "t.a",
1745                Expression::Null(crate::expressions::Null),
1746                Expression::Null(crate::expressions::Null),
1747            )],
1748            source_name: String::new(),
1749            reference_node_name: String::new(),
1750        };
1751
1752        let names: Vec<_> = root.walk().map(|n| n.name.clone()).collect();
1753        assert_eq!(names.len(), 2);
1754        assert_eq!(names[0], "col_a");
1755        assert_eq!(names[1], "t.a");
1756    }
1757
1758    #[test]
1759    fn test_aliased_column() {
1760        let expr = parse("SELECT a + 1 AS b FROM t");
1761        let node = lineage("b", &expr, None, false).unwrap();
1762
1763        assert_eq!(node.name, "b");
1764        // Should trace through the expression to t.a
1765        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
1766        assert!(
1767            all_names.iter().any(|n| n.contains("a")),
1768            "Expected to trace to column a, got: {:?}",
1769            all_names
1770        );
1771    }
1772
1773    #[test]
1774    fn test_qualified_column() {
1775        let expr = parse("SELECT t.a FROM t");
1776        let node = lineage("a", &expr, None, false).unwrap();
1777
1778        assert_eq!(node.name, "a");
1779        let names = node.downstream_names();
1780        assert!(
1781            names.iter().any(|n| n == "t.a"),
1782            "Expected t.a, got: {:?}",
1783            names
1784        );
1785    }
1786
1787    #[test]
1788    fn test_unqualified_column() {
1789        let expr = parse("SELECT a FROM t");
1790        let node = lineage("a", &expr, None, false).unwrap();
1791
1792        // Unqualified but single source → resolved to t.a
1793        let names = node.downstream_names();
1794        assert!(
1795            names.iter().any(|n| n == "t.a"),
1796            "Expected t.a, got: {:?}",
1797            names
1798        );
1799    }
1800
1801    #[test]
1802    fn test_lineage_with_schema_qualifies_root_expression_issue_40() {
1803        let query = "SELECT name FROM users";
1804        let dialect = Dialect::get(DialectType::BigQuery);
1805        let expr = dialect
1806            .parse(query)
1807            .unwrap()
1808            .into_iter()
1809            .next()
1810            .expect("expected one expression");
1811
1812        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
1813        schema
1814            .add_table("users", &[("name".into(), DataType::Text)], None)
1815            .expect("schema setup");
1816
1817        let node_without_schema = lineage("name", &expr, Some(DialectType::BigQuery), false)
1818            .expect("lineage without schema");
1819        let root_without_schema = annotate_types(
1820            &node_without_schema.expression,
1821            Some(&schema),
1822            Some(DialectType::BigQuery),
1823        );
1824        assert_eq!(
1825            root_without_schema, None,
1826            "Expected unresolved root type without schema-aware lineage qualification"
1827        );
1828
1829        let node_with_schema = lineage_with_schema(
1830            "name",
1831            &expr,
1832            Some(&schema),
1833            Some(DialectType::BigQuery),
1834            false,
1835        )
1836        .expect("lineage with schema");
1837        let root_with_schema = annotate_types(
1838            &node_with_schema.expression,
1839            Some(&schema),
1840            Some(DialectType::BigQuery),
1841        );
1842
1843        assert_eq!(root_with_schema, Some(DataType::Text));
1844    }
1845
1846    #[test]
1847    fn test_lineage_with_schema_none_matches_lineage() {
1848        let expr = parse("SELECT a FROM t");
1849        let baseline = lineage("a", &expr, None, false).expect("lineage baseline");
1850        let with_none =
1851            lineage_with_schema("a", &expr, None, None, false).expect("lineage_with_schema");
1852
1853        assert_eq!(with_none.name, baseline.name);
1854        assert_eq!(with_none.downstream_names(), baseline.downstream_names());
1855    }
1856
1857    #[test]
1858    fn test_lineage_join() {
1859        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
1860
1861        let node_a = lineage("a", &expr, None, false).unwrap();
1862        let names_a = node_a.downstream_names();
1863        assert!(
1864            names_a.iter().any(|n| n == "t.a"),
1865            "Expected t.a, got: {:?}",
1866            names_a
1867        );
1868
1869        let node_b = lineage("b", &expr, None, false).unwrap();
1870        let names_b = node_b.downstream_names();
1871        assert!(
1872            names_b.iter().any(|n| n == "s.b"),
1873            "Expected s.b, got: {:?}",
1874            names_b
1875        );
1876    }
1877
1878    #[test]
1879    fn test_lineage_alias_leaf_has_resolved_source_name() {
1880        let expr = parse("SELECT t1.col1 FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id");
1881        let node = lineage("col1", &expr, None, false).unwrap();
1882
1883        // Keep alias in the display lineage edge.
1884        let names = node.downstream_names();
1885        assert!(
1886            names.iter().any(|n| n == "t1.col1"),
1887            "Expected aliased column edge t1.col1, got: {:?}",
1888            names
1889        );
1890
1891        // Leaf should expose the resolved base table for consumers.
1892        let leaf = node
1893            .downstream
1894            .iter()
1895            .find(|n| n.name == "t1.col1")
1896            .expect("Expected t1.col1 leaf");
1897        assert_eq!(leaf.source_name, "table1");
1898        match &leaf.source {
1899            Expression::Table(table) => assert_eq!(table.name.name, "table1"),
1900            _ => panic!("Expected leaf source to be a table expression"),
1901        }
1902    }
1903
1904    #[test]
1905    fn test_lineage_derived_table() {
1906        let expr = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
1907        let node = lineage("a", &expr, None, false).unwrap();
1908
1909        assert_eq!(node.name, "a");
1910        // Should trace through the derived table to t.a
1911        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
1912        assert!(
1913            all_names.iter().any(|n| n == "t.a"),
1914            "Expected to trace through derived table to t.a, got: {:?}",
1915            all_names
1916        );
1917    }
1918
1919    #[test]
1920    fn test_lineage_cte() {
1921        let expr = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
1922        let node = lineage("a", &expr, None, false).unwrap();
1923
1924        assert_eq!(node.name, "a");
1925        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
1926        assert!(
1927            all_names.iter().any(|n| n == "t.a"),
1928            "Expected to trace through CTE to t.a, got: {:?}",
1929            all_names
1930        );
1931    }
1932
1933    #[test]
1934    fn test_lineage_union() {
1935        let expr = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
1936        let node = lineage("a", &expr, None, false).unwrap();
1937
1938        assert_eq!(node.name, "a");
1939        // Should have 2 downstream branches
1940        assert_eq!(
1941            node.downstream.len(),
1942            2,
1943            "Expected 2 branches for UNION, got {}",
1944            node.downstream.len()
1945        );
1946    }
1947
1948    #[test]
1949    fn test_lineage_cte_union() {
1950        let expr = parse("WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte");
1951        let node = lineage("a", &expr, None, false).unwrap();
1952
1953        // Should trace through CTE into both UNION branches
1954        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
1955        assert!(
1956            all_names.len() >= 3,
1957            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
1958            all_names
1959        );
1960    }
1961
1962    #[test]
1963    fn test_lineage_star() {
1964        let expr = parse("SELECT * FROM t");
1965        let node = lineage("*", &expr, None, false).unwrap();
1966
1967        assert_eq!(node.name, "*");
1968        // Should have downstream for table t
1969        assert!(
1970            !node.downstream.is_empty(),
1971            "Star should produce downstream nodes"
1972        );
1973    }
1974
1975    #[test]
1976    fn test_lineage_subquery_in_select() {
1977        let expr = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
1978        let node = lineage("x", &expr, None, false).unwrap();
1979
1980        assert_eq!(node.name, "x");
1981        // Should have traced into the scalar subquery
1982        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
1983        assert!(
1984            all_names.len() >= 2,
1985            "Expected tracing into scalar subquery, got: {:?}",
1986            all_names
1987        );
1988    }
1989
1990    #[test]
1991    fn test_lineage_multiple_columns() {
1992        let expr = parse("SELECT a, b FROM t");
1993
1994        let node_a = lineage("a", &expr, None, false).unwrap();
1995        let node_b = lineage("b", &expr, None, false).unwrap();
1996
1997        assert_eq!(node_a.name, "a");
1998        assert_eq!(node_b.name, "b");
1999
2000        // Each should trace independently
2001        let names_a = node_a.downstream_names();
2002        let names_b = node_b.downstream_names();
2003        assert!(names_a.iter().any(|n| n == "t.a"));
2004        assert!(names_b.iter().any(|n| n == "t.b"));
2005    }
2006
2007    #[test]
2008    fn test_get_source_tables() {
2009        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
2010        let node = lineage("a", &expr, None, false).unwrap();
2011
2012        let tables = get_source_tables(&node);
2013        assert!(
2014            tables.contains("t"),
2015            "Expected source table 't', got: {:?}",
2016            tables
2017        );
2018    }
2019
2020    #[test]
2021    fn test_lineage_column_not_found() {
2022        let expr = parse("SELECT a FROM t");
2023        let result = lineage("nonexistent", &expr, None, false);
2024        assert!(result.is_err());
2025    }
2026
2027    #[test]
2028    fn test_lineage_nested_cte() {
2029        let expr = parse(
2030            "WITH cte1 AS (SELECT a FROM t), \
2031             cte2 AS (SELECT a FROM cte1) \
2032             SELECT a FROM cte2",
2033        );
2034        let node = lineage("a", &expr, None, false).unwrap();
2035
2036        // Should trace through cte2 → cte1 → t
2037        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
2038        assert!(
2039            all_names.len() >= 3,
2040            "Expected to trace through nested CTEs, got: {:?}",
2041            all_names
2042        );
2043    }
2044
2045    #[test]
2046    fn test_trim_selects_true() {
2047        let expr = parse("SELECT a, b, c FROM t");
2048        let node = lineage("a", &expr, None, true).unwrap();
2049
2050        // The source should be trimmed to only include 'a'
2051        if let Expression::Select(select) = &node.source {
2052            assert_eq!(
2053                select.expressions.len(),
2054                1,
2055                "Trimmed source should have 1 expression, got {}",
2056                select.expressions.len()
2057            );
2058        } else {
2059            panic!("Expected Select source");
2060        }
2061    }
2062
2063    #[test]
2064    fn test_trim_selects_false() {
2065        let expr = parse("SELECT a, b, c FROM t");
2066        let node = lineage("a", &expr, None, false).unwrap();
2067
2068        // The source should keep all columns
2069        if let Expression::Select(select) = &node.source {
2070            assert_eq!(
2071                select.expressions.len(),
2072                3,
2073                "Untrimmed source should have 3 expressions"
2074            );
2075        } else {
2076            panic!("Expected Select source");
2077        }
2078    }
2079
2080    #[test]
2081    fn test_lineage_expression_in_select() {
2082        let expr = parse("SELECT a + b AS c FROM t");
2083        let node = lineage("c", &expr, None, false).unwrap();
2084
2085        // Should trace to both a and b from t
2086        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
2087        assert!(
2088            all_names.len() >= 3,
2089            "Expected to trace a + b to both columns, got: {:?}",
2090            all_names
2091        );
2092    }
2093
2094    #[test]
2095    fn test_set_operation_by_index() {
2096        let expr = parse("SELECT a FROM t1 UNION SELECT b FROM t2");
2097
2098        // Trace column "a" which is at index 0
2099        let node = lineage("a", &expr, None, false).unwrap();
2100
2101        // UNION branches should be traced by index
2102        assert_eq!(node.downstream.len(), 2);
2103    }
2104
2105    // --- Tests for column lineage inside function calls (issue #18) ---
2106
2107    fn print_node(node: &LineageNode, indent: usize) {
2108        let pad = "  ".repeat(indent);
2109        println!(
2110            "{pad}name={:?} source_name={:?}",
2111            node.name, node.source_name
2112        );
2113        for child in &node.downstream {
2114            print_node(child, indent + 1);
2115        }
2116    }
2117
2118    #[test]
2119    fn test_issue18_repro() {
2120        // Exact scenario from the issue
2121        let query = "SELECT UPPER(name) as upper_name FROM users";
2122        println!("Query: {query}\n");
2123
2124        let dialect = crate::dialects::Dialect::get(DialectType::BigQuery);
2125        let exprs = dialect.parse(query).unwrap();
2126        let expr = &exprs[0];
2127
2128        let node = lineage("upper_name", expr, Some(DialectType::BigQuery), false).unwrap();
2129        println!("lineage(\"upper_name\"):");
2130        print_node(&node, 1);
2131
2132        let names = node.downstream_names();
2133        assert!(
2134            names.iter().any(|n| n == "users.name"),
2135            "Expected users.name in downstream, got: {:?}",
2136            names
2137        );
2138    }
2139
2140    #[test]
2141    fn test_lineage_upper_function() {
2142        let expr = parse("SELECT UPPER(name) AS upper_name FROM users");
2143        let node = lineage("upper_name", &expr, None, false).unwrap();
2144
2145        let names = node.downstream_names();
2146        assert!(
2147            names.iter().any(|n| n == "users.name"),
2148            "Expected users.name in downstream, got: {:?}",
2149            names
2150        );
2151    }
2152
2153    #[test]
2154    fn test_lineage_round_function() {
2155        let expr = parse("SELECT ROUND(price, 2) AS rounded FROM products");
2156        let node = lineage("rounded", &expr, None, false).unwrap();
2157
2158        let names = node.downstream_names();
2159        assert!(
2160            names.iter().any(|n| n == "products.price"),
2161            "Expected products.price in downstream, got: {:?}",
2162            names
2163        );
2164    }
2165
2166    #[test]
2167    fn test_lineage_coalesce_function() {
2168        let expr = parse("SELECT COALESCE(a, b) AS val FROM t");
2169        let node = lineage("val", &expr, None, false).unwrap();
2170
2171        let names = node.downstream_names();
2172        assert!(
2173            names.iter().any(|n| n == "t.a"),
2174            "Expected t.a in downstream, got: {:?}",
2175            names
2176        );
2177        assert!(
2178            names.iter().any(|n| n == "t.b"),
2179            "Expected t.b in downstream, got: {:?}",
2180            names
2181        );
2182    }
2183
2184    #[test]
2185    fn test_lineage_count_function() {
2186        let expr = parse("SELECT COUNT(id) AS cnt FROM t");
2187        let node = lineage("cnt", &expr, None, false).unwrap();
2188
2189        let names = node.downstream_names();
2190        assert!(
2191            names.iter().any(|n| n == "t.id"),
2192            "Expected t.id in downstream, got: {:?}",
2193            names
2194        );
2195    }
2196
2197    #[test]
2198    fn test_lineage_sum_function() {
2199        let expr = parse("SELECT SUM(amount) AS total FROM t");
2200        let node = lineage("total", &expr, None, false).unwrap();
2201
2202        let names = node.downstream_names();
2203        assert!(
2204            names.iter().any(|n| n == "t.amount"),
2205            "Expected t.amount in downstream, got: {:?}",
2206            names
2207        );
2208    }
2209
2210    #[test]
2211    fn test_lineage_case_with_nested_functions() {
2212        let expr =
2213            parse("SELECT CASE WHEN x > 0 THEN UPPER(name) ELSE LOWER(name) END AS result FROM t");
2214        let node = lineage("result", &expr, None, false).unwrap();
2215
2216        let names = node.downstream_names();
2217        assert!(
2218            names.iter().any(|n| n == "t.x"),
2219            "Expected t.x in downstream, got: {:?}",
2220            names
2221        );
2222        assert!(
2223            names.iter().any(|n| n == "t.name"),
2224            "Expected t.name in downstream, got: {:?}",
2225            names
2226        );
2227    }
2228
2229    #[test]
2230    fn test_lineage_substring_function() {
2231        let expr = parse("SELECT SUBSTRING(name, 1, 3) AS short FROM t");
2232        let node = lineage("short", &expr, None, false).unwrap();
2233
2234        let names = node.downstream_names();
2235        assert!(
2236            names.iter().any(|n| n == "t.name"),
2237            "Expected t.name in downstream, got: {:?}",
2238            names
2239        );
2240    }
2241}