Skip to main content

polyglot_sql/
lineage.rs

//! Column Lineage Tracking
//!
//! This module provides functionality to track column lineage through SQL queries,
//! building a graph of how columns flow from source tables to the result set.
//! Supports UNION/INTERSECT/EXCEPT, CTEs, derived tables, subqueries, and star expansion.
//!
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::scope::{build_scope, Scope};
11use crate::traversal::ExpressionWalk;
12use crate::Result;
13use std::collections::HashSet;
14
/// A node in the column lineage graph.
///
/// A node describes one step of a column's provenance. Its `downstream`
/// children are the columns/expressions this step's value is computed from,
/// so following `downstream` walks from the query result toward the source
/// tables.
#[derive(Debug, Clone)]
pub struct LineageNode {
    /// Name of this lineage step (e.g., "table.column")
    pub name: String,
    /// The expression at this node: typically the traced projection; for a set
    /// operation the whole set operation; for a terminal node a column reference
    pub expression: Expression,
    /// The source expression: the query this step belongs to (possibly trimmed
    /// to the traced projection), or a table expression for terminal nodes
    pub source: Expression,
    /// Nodes this step is derived from (one hop closer to the source tables)
    pub downstream: Vec<LineageNode>,
    /// Optional source name (e.g., the CTE or derived-table alias this step was
    /// reached through); empty when not applicable
    pub source_name: String,
    /// Optional reference node name (e.g., the referencing column when this step
    /// was reached through a CTE or derived table); empty when not applicable
    pub reference_node_name: String,
}
31
32impl LineageNode {
33    /// Create a new lineage node
34    pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
35        Self {
36            name: name.into(),
37            expression,
38            source,
39            downstream: Vec::new(),
40            source_name: String::new(),
41            reference_node_name: String::new(),
42        }
43    }
44
45    /// Iterate over all nodes in the lineage graph using DFS
46    pub fn walk(&self) -> LineageWalker<'_> {
47        LineageWalker {
48            stack: vec![self],
49        }
50    }
51
52    /// Get all downstream column names
53    pub fn downstream_names(&self) -> Vec<String> {
54        self.downstream.iter().map(|n| n.name.clone()).collect()
55    }
56}
57
58/// Iterator for walking the lineage graph
59pub struct LineageWalker<'a> {
60    stack: Vec<&'a LineageNode>,
61}
62
63impl<'a> Iterator for LineageWalker<'a> {
64    type Item = &'a LineageNode;
65
66    fn next(&mut self) -> Option<Self::Item> {
67        if let Some(node) = self.stack.pop() {
68            // Add children in reverse order so they're visited in order
69            for child in node.downstream.iter().rev() {
70                self.stack.push(child);
71            }
72            Some(node)
73        } else {
74            None
75        }
76    }
77}
78
79// ---------------------------------------------------------------------------
80// ColumnRef: name or positional index for column lookup
81// ---------------------------------------------------------------------------
82
/// Column reference for lineage tracing — by name or positional index.
///
/// Positional references are used where columns line up by position rather
/// than by name: set-operation branches and scalar subqueries.
enum ColumnRef<'a> {
    /// Match the projection whose alias (or bare column name) equals this string.
    Name(&'a str),
    /// Select the projection at this zero-based position in the SELECT list.
    Index(usize),
}
88
89// ---------------------------------------------------------------------------
90// Public API
91// ---------------------------------------------------------------------------
92
93/// Build the lineage graph for a column in a SQL query
94///
95/// # Arguments
96/// * `column` - The column name to trace lineage for
97/// * `sql` - The SQL expression (SELECT, UNION, etc.)
98/// * `dialect` - Optional dialect for parsing
99/// * `trim_selects` - If true, trim the source SELECT to only include the target column
100///
101/// # Returns
102/// The root lineage node for the specified column
103///
104/// # Example
105/// ```ignore
106/// use polyglot_sql::lineage::lineage;
107/// use polyglot_sql::parse_one;
108/// use polyglot_sql::DialectType;
109///
110/// let sql = "SELECT a, b + 1 AS c FROM t";
111/// let expr = parse_one(sql, DialectType::Generic).unwrap();
112/// let node = lineage("c", &expr, None, false).unwrap();
113/// ```
114pub fn lineage(
115    column: &str,
116    sql: &Expression,
117    dialect: Option<DialectType>,
118    trim_selects: bool,
119) -> Result<LineageNode> {
120    let scope = build_scope(sql);
121    to_node(
122        ColumnRef::Name(column),
123        &scope,
124        dialect,
125        "",
126        "",
127        "",
128        trim_selects,
129    )
130}
131
132/// Get all source tables from a lineage graph
133pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
134    let mut tables = HashSet::new();
135    collect_source_tables(node, &mut tables);
136    tables
137}
138
139/// Recursively collect source table names from lineage graph
140pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
141    if let Expression::Table(table) = &node.source {
142        tables.insert(table.name.name.clone());
143    }
144    for child in &node.downstream {
145        collect_source_tables(child, tables);
146    }
147}
148
149// ---------------------------------------------------------------------------
150// Core recursive lineage builder
151// ---------------------------------------------------------------------------
152
153/// Recursively build a lineage node for a column in a scope.
154fn to_node(
155    column: ColumnRef<'_>,
156    scope: &Scope,
157    dialect: Option<DialectType>,
158    scope_name: &str,
159    source_name: &str,
160    reference_node_name: &str,
161    trim_selects: bool,
162) -> Result<LineageNode> {
163    to_node_inner(column, scope, dialect, scope_name, source_name, reference_node_name, trim_selects, &[])
164}
165
166fn to_node_inner(
167    column: ColumnRef<'_>,
168    scope: &Scope,
169    dialect: Option<DialectType>,
170    scope_name: &str,
171    source_name: &str,
172    reference_node_name: &str,
173    trim_selects: bool,
174    ancestor_cte_scopes: &[Scope],
175) -> Result<LineageNode> {
176    let scope_expr = &scope.expression;
177
178    // Build combined CTE scopes: current scope's cte_scopes + ancestors
179    let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
180    for s in ancestor_cte_scopes {
181        all_cte_scopes.push(s);
182    }
183
184    // 0. Unwrap CTE scope — CTE scope expressions are Expression::Cte(...)
185    //    but we need the inner query (SELECT/UNION) for column lookup.
186    let effective_expr = match scope_expr {
187        Expression::Cte(cte) => &cte.this,
188        other => other,
189    };
190
191    // 1. Set operations (UNION / INTERSECT / EXCEPT)
192    if matches!(
193        effective_expr,
194        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
195    ) {
196        // For CTE wrapping a set op, create a temporary scope with the inner expression
197        if matches!(scope_expr, Expression::Cte(_)) {
198            let mut inner_scope = Scope::new(effective_expr.clone());
199            inner_scope.union_scopes = scope.union_scopes.clone();
200            inner_scope.sources = scope.sources.clone();
201            inner_scope.cte_sources = scope.cte_sources.clone();
202            inner_scope.cte_scopes = scope.cte_scopes.clone();
203            inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
204            inner_scope.subquery_scopes = scope.subquery_scopes.clone();
205            return handle_set_operation(
206                &column,
207                &inner_scope,
208                dialect,
209                scope_name,
210                source_name,
211                reference_node_name,
212                trim_selects,
213                ancestor_cte_scopes,
214            );
215        }
216        return handle_set_operation(
217            &column,
218            scope,
219            dialect,
220            scope_name,
221            source_name,
222            reference_node_name,
223            trim_selects,
224            ancestor_cte_scopes,
225        );
226    }
227
228    // 2. Find the select expression for this column
229    let select_expr = find_select_expr(effective_expr, &column)?;
230    let column_name = resolve_column_name(&column, &select_expr);
231
232    // 3. Trim source if requested
233    let node_source = if trim_selects {
234        trim_source(effective_expr, &select_expr)
235    } else {
236        effective_expr.clone()
237    };
238
239    // 4. Create the lineage node
240    let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
241    node.source_name = source_name.to_string();
242    node.reference_node_name = reference_node_name.to_string();
243
244    // 5. Star handling — add downstream for each source
245    if matches!(&select_expr, Expression::Star(_)) {
246        for (name, source_info) in &scope.sources {
247            let child = LineageNode::new(
248                format!("{}.*", name),
249                Expression::Star(crate::expressions::Star {
250                    table: None,
251                    except: None,
252                    replace: None,
253                    rename: None,
254                    trailing_comments: vec![],
255                }),
256                source_info.expression.clone(),
257            );
258            node.downstream.push(child);
259        }
260        return Ok(node);
261    }
262
263    // 6. Subqueries in select — trace through scalar subqueries
264    let subqueries: Vec<&Expression> = select_expr.find_all(|e| {
265        matches!(e, Expression::Subquery(sq) if sq.alias.is_none())
266    });
267    for sq_expr in subqueries {
268        if let Expression::Subquery(sq) = sq_expr {
269            for sq_scope in &scope.subquery_scopes {
270                if sq_scope.expression == sq.this {
271                    if let Ok(child) = to_node_inner(
272                        ColumnRef::Index(0),
273                        sq_scope,
274                        dialect,
275                        &column_name,
276                        "",
277                        "",
278                        trim_selects,
279                        ancestor_cte_scopes,
280                    ) {
281                        node.downstream.push(child);
282                    }
283                    break;
284                }
285            }
286        }
287    }
288
289    // 7. Column references — trace each column to its source
290    let col_refs = find_column_refs_in_expr(&select_expr);
291    for col_ref in col_refs {
292        let col_name = &col_ref.column;
293        if let Some(ref table_id) = col_ref.table {
294            let tbl = &table_id.name;
295            resolve_qualified_column(
296                &mut node,
297                scope,
298                dialect,
299                tbl,
300                col_name,
301                &column_name,
302                trim_selects,
303                &all_cte_scopes,
304            );
305        } else {
306            resolve_unqualified_column(
307                &mut node,
308                scope,
309                dialect,
310                col_name,
311                &column_name,
312                trim_selects,
313                &all_cte_scopes,
314            );
315        }
316    }
317
318    Ok(node)
319}
320
321// ---------------------------------------------------------------------------
322// Set operation handling
323// ---------------------------------------------------------------------------
324
325fn handle_set_operation(
326    column: &ColumnRef<'_>,
327    scope: &Scope,
328    dialect: Option<DialectType>,
329    scope_name: &str,
330    source_name: &str,
331    reference_node_name: &str,
332    trim_selects: bool,
333    ancestor_cte_scopes: &[Scope],
334) -> Result<LineageNode> {
335    let scope_expr = &scope.expression;
336
337    // Determine column index
338    let col_index = match column {
339        ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
340        ColumnRef::Index(i) => *i,
341    };
342
343    let col_name = match column {
344        ColumnRef::Name(name) => name.to_string(),
345        ColumnRef::Index(_) => format!("_{col_index}"),
346    };
347
348    let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
349    node.source_name = source_name.to_string();
350    node.reference_node_name = reference_node_name.to_string();
351
352    // Recurse into each union branch
353    for branch_scope in &scope.union_scopes {
354        if let Ok(child) = to_node_inner(
355            ColumnRef::Index(col_index),
356            branch_scope,
357            dialect,
358            scope_name,
359            "",
360            "",
361            trim_selects,
362            ancestor_cte_scopes,
363        ) {
364            node.downstream.push(child);
365        }
366    }
367
368    Ok(node)
369}
370
371// ---------------------------------------------------------------------------
372// Column resolution helpers
373// ---------------------------------------------------------------------------
374
375fn resolve_qualified_column(
376    node: &mut LineageNode,
377    scope: &Scope,
378    dialect: Option<DialectType>,
379    table: &str,
380    col_name: &str,
381    parent_name: &str,
382    trim_selects: bool,
383    all_cte_scopes: &[&Scope],
384) {
385    // Check if table is a CTE reference (cte_sources tracks CTE names)
386    if scope.cte_sources.contains_key(table) {
387        if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
388            // Build ancestor CTE scopes from all_cte_scopes for the recursive call
389            let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
390            if let Ok(child) = to_node_inner(
391                ColumnRef::Name(col_name),
392                child_scope,
393                dialect,
394                parent_name,
395                table,
396                parent_name,
397                trim_selects,
398                &ancestors,
399            ) {
400                node.downstream.push(child);
401                return;
402            }
403        }
404    }
405
406    // Check if table is a derived table (is_scope = true in sources)
407    if let Some(source_info) = scope.sources.get(table) {
408        if source_info.is_scope {
409            if let Some(child_scope) = find_child_scope(scope, table) {
410                let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
411                if let Ok(child) = to_node_inner(
412                    ColumnRef::Name(col_name),
413                    child_scope,
414                    dialect,
415                    parent_name,
416                    table,
417                    parent_name,
418                    trim_selects,
419                    &ancestors,
420                ) {
421                    node.downstream.push(child);
422                    return;
423                }
424            }
425        }
426    }
427
428    // Base table or unresolved — terminal node
429    node.downstream.push(make_table_column_node(table, col_name));
430}
431
432fn resolve_unqualified_column(
433    node: &mut LineageNode,
434    scope: &Scope,
435    dialect: Option<DialectType>,
436    col_name: &str,
437    parent_name: &str,
438    trim_selects: bool,
439    all_cte_scopes: &[&Scope],
440) {
441    // Try to find which source this column belongs to.
442    // Filter to only FROM-clause sources: add_cte_source adds all CTEs to sources
443    // with Expression::Cte, but FROM-clause Table references overwrite with Expression::Table.
444    // So CTE-only entries (not referenced in FROM) have Expression::Cte — exclude those.
445    let from_source_names: Vec<&String> = scope
446        .sources
447        .iter()
448        .filter(|(_, info)| !matches!(info.expression, Expression::Cte(_)))
449        .map(|(name, _)| name)
450        .collect();
451
452    if from_source_names.len() == 1 {
453        let tbl = from_source_names[0];
454        resolve_qualified_column(node, scope, dialect, tbl, col_name, parent_name, trim_selects, all_cte_scopes);
455        return;
456    }
457
458    // Multiple sources — can't resolve without schema info, add unqualified node
459    let child = LineageNode::new(
460        col_name.to_string(),
461        Expression::Column(crate::expressions::Column {
462            name: crate::expressions::Identifier::new(col_name.to_string()),
463            table: None,
464            join_mark: false,
465            trailing_comments: vec![],
466        }),
467        node.source.clone(),
468    );
469    node.downstream.push(child);
470}
471
472// ---------------------------------------------------------------------------
473// Helper functions
474// ---------------------------------------------------------------------------
475
476/// Get the alias or name of an expression
477fn get_alias_or_name(expr: &Expression) -> Option<String> {
478    match expr {
479        Expression::Alias(alias) => Some(alias.alias.name.clone()),
480        Expression::Column(col) => Some(col.name.name.clone()),
481        Expression::Identifier(id) => Some(id.name.clone()),
482        Expression::Star(_) => Some("*".to_string()),
483        _ => None,
484    }
485}
486
487/// Resolve the display name for a column reference.
488fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
489    match column {
490        ColumnRef::Name(n) => n.to_string(),
491        ColumnRef::Index(_) => {
492            get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string())
493        }
494    }
495}
496
497/// Find the select expression matching a column reference.
498fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
499    if let Expression::Select(ref select) = scope_expr {
500        match column {
501            ColumnRef::Name(name) => {
502                for expr in &select.expressions {
503                    if get_alias_or_name(expr).as_deref() == Some(name) {
504                        return Ok(expr.clone());
505                    }
506                }
507                Err(crate::error::Error::Parse(format!(
508                    "Cannot find column '{}' in query",
509                    name
510                )))
511            }
512            ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
513                crate::error::Error::Parse(format!("Column index {} out of range", idx))
514            }),
515        }
516    } else {
517        Err(crate::error::Error::Parse(
518            "Expected SELECT expression for column lookup".to_string(),
519        ))
520    }
521}
522
523/// Find the positional index of a column name in a set operation's first SELECT branch.
524fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
525    let mut expr = set_op_expr;
526    loop {
527        match expr {
528            Expression::Union(u) => expr = &u.left,
529            Expression::Intersect(i) => expr = &i.left,
530            Expression::Except(e) => expr = &e.left,
531            Expression::Select(select) => {
532                for (i, e) in select.expressions.iter().enumerate() {
533                    if get_alias_or_name(e).as_deref() == Some(name) {
534                        return Ok(i);
535                    }
536                }
537                return Err(crate::error::Error::Parse(format!(
538                    "Cannot find column '{}' in set operation",
539                    name
540                )));
541            }
542            _ => {
543                return Err(crate::error::Error::Parse(
544                    "Expected SELECT or set operation".to_string(),
545                ))
546            }
547        }
548    }
549}
550
551/// If trim_selects is enabled, return a copy of the SELECT with only the target column.
552fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
553    if let Expression::Select(select) = select_expr {
554        let mut trimmed = select.as_ref().clone();
555        trimmed.expressions = vec![target_expr.clone()];
556        Expression::Select(Box::new(trimmed))
557    } else {
558        select_expr.clone()
559    }
560}
561
562/// Find the child scope (CTE or derived table) for a given source name.
563fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
564    // Check CTE scopes
565    if scope.cte_sources.contains_key(source_name) {
566        for cte_scope in &scope.cte_scopes {
567            if let Expression::Cte(cte) = &cte_scope.expression {
568                if cte.alias.name == source_name {
569                    return Some(cte_scope);
570                }
571            }
572        }
573    }
574
575    // Check derived table scopes
576    if let Some(source_info) = scope.sources.get(source_name) {
577        if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
578            if let Expression::Subquery(sq) = &source_info.expression {
579                for dt_scope in &scope.derived_table_scopes {
580                    if dt_scope.expression == sq.this {
581                        return Some(dt_scope);
582                    }
583                }
584            }
585        }
586    }
587
588    None
589}
590
591/// Find a CTE scope by name, searching through a combined list of CTE scopes.
592/// This handles nested CTEs where the current scope doesn't have the CTE scope
593/// as a direct child but knows about it via cte_sources.
594fn find_child_scope_in<'a>(
595    all_cte_scopes: &[&'a Scope],
596    scope: &'a Scope,
597    source_name: &str,
598) -> Option<&'a Scope> {
599    // First try the scope's own cte_scopes
600    for cte_scope in &scope.cte_scopes {
601        if let Expression::Cte(cte) = &cte_scope.expression {
602            if cte.alias.name == source_name {
603                return Some(cte_scope);
604            }
605        }
606    }
607
608    // Then search through all ancestor CTE scopes
609    for cte_scope in all_cte_scopes {
610        if let Expression::Cte(cte) = &cte_scope.expression {
611            if cte.alias.name == source_name {
612                return Some(cte_scope);
613            }
614        }
615    }
616
617    // Fall back to derived table scopes
618    if let Some(source_info) = scope.sources.get(source_name) {
619        if source_info.is_scope {
620            if let Expression::Subquery(sq) = &source_info.expression {
621                for dt_scope in &scope.derived_table_scopes {
622                    if dt_scope.expression == sq.this {
623                        return Some(dt_scope);
624                    }
625                }
626            }
627        }
628    }
629
630    None
631}
632
633/// Create a terminal lineage node for a table.column reference.
634fn make_table_column_node(table: &str, column: &str) -> LineageNode {
635    LineageNode::new(
636        format!("{}.{}", table, column),
637        Expression::Column(crate::expressions::Column {
638            name: crate::expressions::Identifier::new(column.to_string()),
639            table: Some(crate::expressions::Identifier::new(table.to_string())),
640            join_mark: false,
641            trailing_comments: vec![],
642        }),
643        Expression::Table(crate::expressions::TableRef::new(table)),
644    )
645}
646
/// Simple column reference extracted from an expression: an optional table
/// qualifier plus the column name.
#[derive(Debug, Clone)]
struct SimpleColumnRef {
    /// Table qualifier as written (e.g. `t` in `t.a`); `None` for unqualified columns.
    table: Option<crate::expressions::Identifier>,
    /// The referenced column's name.
    column: String,
}
653
654/// Find all column references in an expression (does not recurse into subqueries).
655fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
656    let mut refs = Vec::new();
657    collect_column_refs(expr, &mut refs);
658    refs
659}
660
661fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
662    match expr {
663        Expression::Column(col) => {
664            refs.push(SimpleColumnRef {
665                table: col.table.clone(),
666                column: col.name.name.clone(),
667            });
668        }
669        Expression::Alias(alias) => {
670            collect_column_refs(&alias.this, refs);
671        }
672        Expression::And(op)
673        | Expression::Or(op)
674        | Expression::Eq(op)
675        | Expression::Neq(op)
676        | Expression::Lt(op)
677        | Expression::Lte(op)
678        | Expression::Gt(op)
679        | Expression::Gte(op)
680        | Expression::Add(op)
681        | Expression::Sub(op)
682        | Expression::Mul(op)
683        | Expression::Div(op)
684        | Expression::Mod(op)
685        | Expression::BitwiseAnd(op)
686        | Expression::BitwiseOr(op)
687        | Expression::BitwiseXor(op)
688        | Expression::Concat(op) => {
689            collect_column_refs(&op.left, refs);
690            collect_column_refs(&op.right, refs);
691        }
692        Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
693            collect_column_refs(&u.this, refs);
694        }
695        Expression::Function(func) => {
696            for arg in &func.args {
697                collect_column_refs(arg, refs);
698            }
699        }
700        Expression::AggregateFunction(func) => {
701            for arg in &func.args {
702                collect_column_refs(arg, refs);
703            }
704        }
705        Expression::WindowFunction(wf) => {
706            collect_column_refs(&wf.this, refs);
707        }
708        Expression::Case(case) => {
709            if let Some(operand) = &case.operand {
710                collect_column_refs(operand, refs);
711            }
712            for (cond, result) in &case.whens {
713                collect_column_refs(cond, refs);
714                collect_column_refs(result, refs);
715            }
716            if let Some(ref else_expr) = case.else_ {
717                collect_column_refs(else_expr, refs);
718            }
719        }
720        Expression::Cast(cast) => {
721            collect_column_refs(&cast.this, refs);
722        }
723        Expression::Paren(p) => {
724            collect_column_refs(&p.this, refs);
725        }
726        Expression::Coalesce(c) => {
727            for e in &c.expressions {
728                collect_column_refs(e, refs);
729            }
730        }
731        // Don't recurse into subqueries — those are handled separately
732        Expression::Subquery(_) | Expression::Exists(_) => {}
733        _ => {}
734    }
735}
736
737// ---------------------------------------------------------------------------
738// Tests
739// ---------------------------------------------------------------------------
740
741#[cfg(test)]
742mod tests {
743    use super::*;
744    use crate::dialects::{Dialect, DialectType};
745
746    fn parse(sql: &str) -> Expression {
747        let dialect = Dialect::get(DialectType::Generic);
748        let ast = dialect.parse(sql).unwrap();
749        ast.into_iter().next().unwrap()
750    }
751
752    #[test]
753    fn test_simple_lineage() {
754        let expr = parse("SELECT a FROM t");
755        let node = lineage("a", &expr, None, false).unwrap();
756
757        assert_eq!(node.name, "a");
758        assert!(!node.downstream.is_empty(), "Should have downstream nodes");
759        // Should trace to t.a
760        let names = node.downstream_names();
761        assert!(
762            names.iter().any(|n| n == "t.a"),
763            "Expected t.a in downstream, got: {:?}",
764            names
765        );
766    }
767
768    #[test]
769    fn test_lineage_walk() {
770        let root = LineageNode {
771            name: "col_a".to_string(),
772            expression: Expression::Null(crate::expressions::Null),
773            source: Expression::Null(crate::expressions::Null),
774            downstream: vec![LineageNode::new(
775                "t.a",
776                Expression::Null(crate::expressions::Null),
777                Expression::Null(crate::expressions::Null),
778            )],
779            source_name: String::new(),
780            reference_node_name: String::new(),
781        };
782
783        let names: Vec<_> = root.walk().map(|n| n.name.clone()).collect();
784        assert_eq!(names.len(), 2);
785        assert_eq!(names[0], "col_a");
786        assert_eq!(names[1], "t.a");
787    }
788
789    #[test]
790    fn test_aliased_column() {
791        let expr = parse("SELECT a + 1 AS b FROM t");
792        let node = lineage("b", &expr, None, false).unwrap();
793
794        assert_eq!(node.name, "b");
795        // Should trace through the expression to t.a
796        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
797        assert!(
798            all_names.iter().any(|n| n.contains("a")),
799            "Expected to trace to column a, got: {:?}",
800            all_names
801        );
802    }
803
804    #[test]
805    fn test_qualified_column() {
806        let expr = parse("SELECT t.a FROM t");
807        let node = lineage("a", &expr, None, false).unwrap();
808
809        assert_eq!(node.name, "a");
810        let names = node.downstream_names();
811        assert!(
812            names.iter().any(|n| n == "t.a"),
813            "Expected t.a, got: {:?}",
814            names
815        );
816    }
817
818    #[test]
819    fn test_unqualified_column() {
820        let expr = parse("SELECT a FROM t");
821        let node = lineage("a", &expr, None, false).unwrap();
822
823        // Unqualified but single source → resolved to t.a
824        let names = node.downstream_names();
825        assert!(
826            names.iter().any(|n| n == "t.a"),
827            "Expected t.a, got: {:?}",
828            names
829        );
830    }
831
832    #[test]
833    fn test_lineage_join() {
834        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
835
836        let node_a = lineage("a", &expr, None, false).unwrap();
837        let names_a = node_a.downstream_names();
838        assert!(
839            names_a.iter().any(|n| n == "t.a"),
840            "Expected t.a, got: {:?}",
841            names_a
842        );
843
844        let node_b = lineage("b", &expr, None, false).unwrap();
845        let names_b = node_b.downstream_names();
846        assert!(
847            names_b.iter().any(|n| n == "s.b"),
848            "Expected s.b, got: {:?}",
849            names_b
850        );
851    }
852
853    #[test]
854    fn test_lineage_derived_table() {
855        let expr = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
856        let node = lineage("a", &expr, None, false).unwrap();
857
858        assert_eq!(node.name, "a");
859        // Should trace through the derived table to t.a
860        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
861        assert!(
862            all_names.iter().any(|n| n == "t.a"),
863            "Expected to trace through derived table to t.a, got: {:?}",
864            all_names
865        );
866    }
867
868    #[test]
869    fn test_lineage_cte() {
870        let expr = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
871        let node = lineage("a", &expr, None, false).unwrap();
872
873        assert_eq!(node.name, "a");
874        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
875        assert!(
876            all_names.iter().any(|n| n == "t.a"),
877            "Expected to trace through CTE to t.a, got: {:?}",
878            all_names
879        );
880    }
881
882    #[test]
883    fn test_lineage_union() {
884        let expr = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
885        let node = lineage("a", &expr, None, false).unwrap();
886
887        assert_eq!(node.name, "a");
888        // Should have 2 downstream branches
889        assert_eq!(
890            node.downstream.len(),
891            2,
892            "Expected 2 branches for UNION, got {}",
893            node.downstream.len()
894        );
895    }
896
897    #[test]
898    fn test_lineage_cte_union() {
899        let expr = parse(
900            "WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte",
901        );
902        let node = lineage("a", &expr, None, false).unwrap();
903
904        // Should trace through CTE into both UNION branches
905        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
906        assert!(
907            all_names.len() >= 3,
908            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
909            all_names
910        );
911    }
912
913    #[test]
914    fn test_lineage_star() {
915        let expr = parse("SELECT * FROM t");
916        let node = lineage("*", &expr, None, false).unwrap();
917
918        assert_eq!(node.name, "*");
919        // Should have downstream for table t
920        assert!(
921            !node.downstream.is_empty(),
922            "Star should produce downstream nodes"
923        );
924    }
925
926    #[test]
927    fn test_lineage_subquery_in_select() {
928        let expr = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
929        let node = lineage("x", &expr, None, false).unwrap();
930
931        assert_eq!(node.name, "x");
932        // Should have traced into the scalar subquery
933        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
934        assert!(
935            all_names.len() >= 2,
936            "Expected tracing into scalar subquery, got: {:?}",
937            all_names
938        );
939    }
940
941    #[test]
942    fn test_lineage_multiple_columns() {
943        let expr = parse("SELECT a, b FROM t");
944
945        let node_a = lineage("a", &expr, None, false).unwrap();
946        let node_b = lineage("b", &expr, None, false).unwrap();
947
948        assert_eq!(node_a.name, "a");
949        assert_eq!(node_b.name, "b");
950
951        // Each should trace independently
952        let names_a = node_a.downstream_names();
953        let names_b = node_b.downstream_names();
954        assert!(names_a.iter().any(|n| n == "t.a"));
955        assert!(names_b.iter().any(|n| n == "t.b"));
956    }
957
958    #[test]
959    fn test_get_source_tables() {
960        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
961        let node = lineage("a", &expr, None, false).unwrap();
962
963        let tables = get_source_tables(&node);
964        assert!(
965            tables.contains("t"),
966            "Expected source table 't', got: {:?}",
967            tables
968        );
969    }
970
971    #[test]
972    fn test_lineage_column_not_found() {
973        let expr = parse("SELECT a FROM t");
974        let result = lineage("nonexistent", &expr, None, false);
975        assert!(result.is_err());
976    }
977
978    #[test]
979    fn test_lineage_nested_cte() {
980        let expr = parse(
981            "WITH cte1 AS (SELECT a FROM t), \
982             cte2 AS (SELECT a FROM cte1) \
983             SELECT a FROM cte2",
984        );
985        let node = lineage("a", &expr, None, false).unwrap();
986
987        // Should trace through cte2 → cte1 → t
988        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
989        assert!(
990            all_names.len() >= 3,
991            "Expected to trace through nested CTEs, got: {:?}",
992            all_names
993        );
994    }
995
996    #[test]
997    fn test_trim_selects_true() {
998        let expr = parse("SELECT a, b, c FROM t");
999        let node = lineage("a", &expr, None, true).unwrap();
1000
1001        // The source should be trimmed to only include 'a'
1002        if let Expression::Select(select) = &node.source {
1003            assert_eq!(
1004                select.expressions.len(),
1005                1,
1006                "Trimmed source should have 1 expression, got {}",
1007                select.expressions.len()
1008            );
1009        } else {
1010            panic!("Expected Select source");
1011        }
1012    }
1013
1014    #[test]
1015    fn test_trim_selects_false() {
1016        let expr = parse("SELECT a, b, c FROM t");
1017        let node = lineage("a", &expr, None, false).unwrap();
1018
1019        // The source should keep all columns
1020        if let Expression::Select(select) = &node.source {
1021            assert_eq!(
1022                select.expressions.len(),
1023                3,
1024                "Untrimmed source should have 3 expressions"
1025            );
1026        } else {
1027            panic!("Expected Select source");
1028        }
1029    }
1030
1031    #[test]
1032    fn test_lineage_expression_in_select() {
1033        let expr = parse("SELECT a + b AS c FROM t");
1034        let node = lineage("c", &expr, None, false).unwrap();
1035
1036        // Should trace to both a and b from t
1037        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
1038        assert!(
1039            all_names.len() >= 3,
1040            "Expected to trace a + b to both columns, got: {:?}",
1041            all_names
1042        );
1043    }
1044
1045    #[test]
1046    fn test_set_operation_by_index() {
1047        let expr = parse("SELECT a FROM t1 UNION SELECT b FROM t2");
1048
1049        // Trace column "a" which is at index 0
1050        let node = lineage("a", &expr, None, false).unwrap();
1051
1052        // UNION branches should be traced by index
1053        assert_eq!(node.downstream.len(), 2);
1054    }
1055}