Skip to main content

polyglot_sql/
lineage.rs

1//! Column Lineage Tracking
2//!
3//! This module provides functionality to track column lineage through SQL queries,
4//! building a graph of how columns flow from source tables to the result set.
5//! Supports UNION/INTERSECT/EXCEPT, CTEs, derived tables, subqueries, and star expansion.
6//!
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::scope::{build_scope, Scope};
11use crate::traversal::ExpressionWalk;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15
16/// A node in the column lineage graph
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct LineageNode {
19    /// Name of this lineage step (e.g., "table.column")
20    pub name: String,
21    /// The expression at this node
22    pub expression: Expression,
23    /// The source expression (the full query context)
24    pub source: Expression,
25    /// Downstream nodes that depend on this one
26    pub downstream: Vec<LineageNode>,
27    /// Optional source name (e.g., for derived tables)
28    pub source_name: String,
29    /// Optional reference node name (e.g., for CTEs)
30    pub reference_node_name: String,
31}
32
33impl LineageNode {
34    /// Create a new lineage node
35    pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
36        Self {
37            name: name.into(),
38            expression,
39            source,
40            downstream: Vec::new(),
41            source_name: String::new(),
42            reference_node_name: String::new(),
43        }
44    }
45
46    /// Iterate over all nodes in the lineage graph using DFS
47    pub fn walk(&self) -> LineageWalker<'_> {
48        LineageWalker { stack: vec![self] }
49    }
50
51    /// Get all downstream column names
52    pub fn downstream_names(&self) -> Vec<String> {
53        self.downstream.iter().map(|n| n.name.clone()).collect()
54    }
55}
56
57/// Iterator for walking the lineage graph
58pub struct LineageWalker<'a> {
59    stack: Vec<&'a LineageNode>,
60}
61
62impl<'a> Iterator for LineageWalker<'a> {
63    type Item = &'a LineageNode;
64
65    fn next(&mut self) -> Option<Self::Item> {
66        if let Some(node) = self.stack.pop() {
67            // Add children in reverse order so they're visited in order
68            for child in node.downstream.iter().rev() {
69                self.stack.push(child);
70            }
71            Some(node)
72        } else {
73            None
74        }
75    }
76}
77
78// ---------------------------------------------------------------------------
79// ColumnRef: name or positional index for column lookup
80// ---------------------------------------------------------------------------
81
/// How a target column is identified while tracing lineage.
enum ColumnRef<'a> {
    /// Look the column up by its alias or name.
    Name(&'a str),
    /// Look the column up by its zero-based position in the select list.
    Index(usize),
}
87
88// ---------------------------------------------------------------------------
89// Public API
90// ---------------------------------------------------------------------------
91
92/// Build the lineage graph for a column in a SQL query
93///
94/// # Arguments
95/// * `column` - The column name to trace lineage for
96/// * `sql` - The SQL expression (SELECT, UNION, etc.)
97/// * `dialect` - Optional dialect for parsing
98/// * `trim_selects` - If true, trim the source SELECT to only include the target column
99///
100/// # Returns
101/// The root lineage node for the specified column
102///
103/// # Example
104/// ```ignore
105/// use polyglot_sql::lineage::lineage;
106/// use polyglot_sql::parse_one;
107/// use polyglot_sql::DialectType;
108///
109/// let sql = "SELECT a, b + 1 AS c FROM t";
110/// let expr = parse_one(sql, DialectType::Generic).unwrap();
111/// let node = lineage("c", &expr, None, false).unwrap();
112/// ```
113pub fn lineage(
114    column: &str,
115    sql: &Expression,
116    dialect: Option<DialectType>,
117    trim_selects: bool,
118) -> Result<LineageNode> {
119    let scope = build_scope(sql);
120    to_node(
121        ColumnRef::Name(column),
122        &scope,
123        dialect,
124        "",
125        "",
126        "",
127        trim_selects,
128    )
129}
130
131/// Get all source tables from a lineage graph
132pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
133    let mut tables = HashSet::new();
134    collect_source_tables(node, &mut tables);
135    tables
136}
137
138/// Recursively collect source table names from lineage graph
139pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
140    if let Expression::Table(table) = &node.source {
141        tables.insert(table.name.name.clone());
142    }
143    for child in &node.downstream {
144        collect_source_tables(child, tables);
145    }
146}
147
148// ---------------------------------------------------------------------------
149// Core recursive lineage builder
150// ---------------------------------------------------------------------------
151
152/// Recursively build a lineage node for a column in a scope.
153fn to_node(
154    column: ColumnRef<'_>,
155    scope: &Scope,
156    dialect: Option<DialectType>,
157    scope_name: &str,
158    source_name: &str,
159    reference_node_name: &str,
160    trim_selects: bool,
161) -> Result<LineageNode> {
162    to_node_inner(
163        column,
164        scope,
165        dialect,
166        scope_name,
167        source_name,
168        reference_node_name,
169        trim_selects,
170        &[],
171    )
172}
173
174fn to_node_inner(
175    column: ColumnRef<'_>,
176    scope: &Scope,
177    dialect: Option<DialectType>,
178    scope_name: &str,
179    source_name: &str,
180    reference_node_name: &str,
181    trim_selects: bool,
182    ancestor_cte_scopes: &[Scope],
183) -> Result<LineageNode> {
184    let scope_expr = &scope.expression;
185
186    // Build combined CTE scopes: current scope's cte_scopes + ancestors
187    let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
188    for s in ancestor_cte_scopes {
189        all_cte_scopes.push(s);
190    }
191
192    // 0. Unwrap CTE scope — CTE scope expressions are Expression::Cte(...)
193    //    but we need the inner query (SELECT/UNION) for column lookup.
194    let effective_expr = match scope_expr {
195        Expression::Cte(cte) => &cte.this,
196        other => other,
197    };
198
199    // 1. Set operations (UNION / INTERSECT / EXCEPT)
200    if matches!(
201        effective_expr,
202        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
203    ) {
204        // For CTE wrapping a set op, create a temporary scope with the inner expression
205        if matches!(scope_expr, Expression::Cte(_)) {
206            let mut inner_scope = Scope::new(effective_expr.clone());
207            inner_scope.union_scopes = scope.union_scopes.clone();
208            inner_scope.sources = scope.sources.clone();
209            inner_scope.cte_sources = scope.cte_sources.clone();
210            inner_scope.cte_scopes = scope.cte_scopes.clone();
211            inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
212            inner_scope.subquery_scopes = scope.subquery_scopes.clone();
213            return handle_set_operation(
214                &column,
215                &inner_scope,
216                dialect,
217                scope_name,
218                source_name,
219                reference_node_name,
220                trim_selects,
221                ancestor_cte_scopes,
222            );
223        }
224        return handle_set_operation(
225            &column,
226            scope,
227            dialect,
228            scope_name,
229            source_name,
230            reference_node_name,
231            trim_selects,
232            ancestor_cte_scopes,
233        );
234    }
235
236    // 2. Find the select expression for this column
237    let select_expr = find_select_expr(effective_expr, &column)?;
238    let column_name = resolve_column_name(&column, &select_expr);
239
240    // 3. Trim source if requested
241    let node_source = if trim_selects {
242        trim_source(effective_expr, &select_expr)
243    } else {
244        effective_expr.clone()
245    };
246
247    // 4. Create the lineage node
248    let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
249    node.source_name = source_name.to_string();
250    node.reference_node_name = reference_node_name.to_string();
251
252    // 5. Star handling — add downstream for each source
253    if matches!(&select_expr, Expression::Star(_)) {
254        for (name, source_info) in &scope.sources {
255            let child = LineageNode::new(
256                format!("{}.*", name),
257                Expression::Star(crate::expressions::Star {
258                    table: None,
259                    except: None,
260                    replace: None,
261                    rename: None,
262                    trailing_comments: vec![],
263                }),
264                source_info.expression.clone(),
265            );
266            node.downstream.push(child);
267        }
268        return Ok(node);
269    }
270
271    // 6. Subqueries in select — trace through scalar subqueries
272    let subqueries: Vec<&Expression> =
273        select_expr.find_all(|e| matches!(e, Expression::Subquery(sq) if sq.alias.is_none()));
274    for sq_expr in subqueries {
275        if let Expression::Subquery(sq) = sq_expr {
276            for sq_scope in &scope.subquery_scopes {
277                if sq_scope.expression == sq.this {
278                    if let Ok(child) = to_node_inner(
279                        ColumnRef::Index(0),
280                        sq_scope,
281                        dialect,
282                        &column_name,
283                        "",
284                        "",
285                        trim_selects,
286                        ancestor_cte_scopes,
287                    ) {
288                        node.downstream.push(child);
289                    }
290                    break;
291                }
292            }
293        }
294    }
295
296    // 7. Column references — trace each column to its source
297    let col_refs = find_column_refs_in_expr(&select_expr);
298    for col_ref in col_refs {
299        let col_name = &col_ref.column;
300        if let Some(ref table_id) = col_ref.table {
301            let tbl = &table_id.name;
302            resolve_qualified_column(
303                &mut node,
304                scope,
305                dialect,
306                tbl,
307                col_name,
308                &column_name,
309                trim_selects,
310                &all_cte_scopes,
311            );
312        } else {
313            resolve_unqualified_column(
314                &mut node,
315                scope,
316                dialect,
317                col_name,
318                &column_name,
319                trim_selects,
320                &all_cte_scopes,
321            );
322        }
323    }
324
325    Ok(node)
326}
327
328// ---------------------------------------------------------------------------
329// Set operation handling
330// ---------------------------------------------------------------------------
331
332fn handle_set_operation(
333    column: &ColumnRef<'_>,
334    scope: &Scope,
335    dialect: Option<DialectType>,
336    scope_name: &str,
337    source_name: &str,
338    reference_node_name: &str,
339    trim_selects: bool,
340    ancestor_cte_scopes: &[Scope],
341) -> Result<LineageNode> {
342    let scope_expr = &scope.expression;
343
344    // Determine column index
345    let col_index = match column {
346        ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
347        ColumnRef::Index(i) => *i,
348    };
349
350    let col_name = match column {
351        ColumnRef::Name(name) => name.to_string(),
352        ColumnRef::Index(_) => format!("_{col_index}"),
353    };
354
355    let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
356    node.source_name = source_name.to_string();
357    node.reference_node_name = reference_node_name.to_string();
358
359    // Recurse into each union branch
360    for branch_scope in &scope.union_scopes {
361        if let Ok(child) = to_node_inner(
362            ColumnRef::Index(col_index),
363            branch_scope,
364            dialect,
365            scope_name,
366            "",
367            "",
368            trim_selects,
369            ancestor_cte_scopes,
370        ) {
371            node.downstream.push(child);
372        }
373    }
374
375    Ok(node)
376}
377
378// ---------------------------------------------------------------------------
379// Column resolution helpers
380// ---------------------------------------------------------------------------
381
382fn resolve_qualified_column(
383    node: &mut LineageNode,
384    scope: &Scope,
385    dialect: Option<DialectType>,
386    table: &str,
387    col_name: &str,
388    parent_name: &str,
389    trim_selects: bool,
390    all_cte_scopes: &[&Scope],
391) {
392    // Check if table is a CTE reference (cte_sources tracks CTE names)
393    if scope.cte_sources.contains_key(table) {
394        if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
395            // Build ancestor CTE scopes from all_cte_scopes for the recursive call
396            let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
397            if let Ok(child) = to_node_inner(
398                ColumnRef::Name(col_name),
399                child_scope,
400                dialect,
401                parent_name,
402                table,
403                parent_name,
404                trim_selects,
405                &ancestors,
406            ) {
407                node.downstream.push(child);
408                return;
409            }
410        }
411    }
412
413    // Check if table is a derived table (is_scope = true in sources)
414    if let Some(source_info) = scope.sources.get(table) {
415        if source_info.is_scope {
416            if let Some(child_scope) = find_child_scope(scope, table) {
417                let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
418                if let Ok(child) = to_node_inner(
419                    ColumnRef::Name(col_name),
420                    child_scope,
421                    dialect,
422                    parent_name,
423                    table,
424                    parent_name,
425                    trim_selects,
426                    &ancestors,
427                ) {
428                    node.downstream.push(child);
429                    return;
430                }
431            }
432        }
433    }
434
435    // Base table or unresolved — terminal node
436    node.downstream
437        .push(make_table_column_node(table, col_name));
438}
439
440fn resolve_unqualified_column(
441    node: &mut LineageNode,
442    scope: &Scope,
443    dialect: Option<DialectType>,
444    col_name: &str,
445    parent_name: &str,
446    trim_selects: bool,
447    all_cte_scopes: &[&Scope],
448) {
449    // Try to find which source this column belongs to.
450    // Filter to only FROM-clause sources: add_cte_source adds all CTEs to sources
451    // with Expression::Cte, but FROM-clause Table references overwrite with Expression::Table.
452    // So CTE-only entries (not referenced in FROM) have Expression::Cte — exclude those.
453    let from_source_names: Vec<&String> = scope
454        .sources
455        .iter()
456        .filter(|(_, info)| !matches!(info.expression, Expression::Cte(_)))
457        .map(|(name, _)| name)
458        .collect();
459
460    if from_source_names.len() == 1 {
461        let tbl = from_source_names[0];
462        resolve_qualified_column(
463            node,
464            scope,
465            dialect,
466            tbl,
467            col_name,
468            parent_name,
469            trim_selects,
470            all_cte_scopes,
471        );
472        return;
473    }
474
475    // Multiple sources — can't resolve without schema info, add unqualified node
476    let child = LineageNode::new(
477        col_name.to_string(),
478        Expression::Column(crate::expressions::Column {
479            name: crate::expressions::Identifier::new(col_name.to_string()),
480            table: None,
481            join_mark: false,
482            trailing_comments: vec![],
483        }),
484        node.source.clone(),
485    );
486    node.downstream.push(child);
487}
488
489// ---------------------------------------------------------------------------
490// Helper functions
491// ---------------------------------------------------------------------------
492
493/// Get the alias or name of an expression
494fn get_alias_or_name(expr: &Expression) -> Option<String> {
495    match expr {
496        Expression::Alias(alias) => Some(alias.alias.name.clone()),
497        Expression::Column(col) => Some(col.name.name.clone()),
498        Expression::Identifier(id) => Some(id.name.clone()),
499        Expression::Star(_) => Some("*".to_string()),
500        _ => None,
501    }
502}
503
504/// Resolve the display name for a column reference.
505fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
506    match column {
507        ColumnRef::Name(n) => n.to_string(),
508        ColumnRef::Index(_) => get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string()),
509    }
510}
511
512/// Find the select expression matching a column reference.
513fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
514    if let Expression::Select(ref select) = scope_expr {
515        match column {
516            ColumnRef::Name(name) => {
517                for expr in &select.expressions {
518                    if get_alias_or_name(expr).as_deref() == Some(name) {
519                        return Ok(expr.clone());
520                    }
521                }
522                Err(crate::error::Error::Parse(format!(
523                    "Cannot find column '{}' in query",
524                    name
525                )))
526            }
527            ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
528                crate::error::Error::Parse(format!("Column index {} out of range", idx))
529            }),
530        }
531    } else {
532        Err(crate::error::Error::Parse(
533            "Expected SELECT expression for column lookup".to_string(),
534        ))
535    }
536}
537
538/// Find the positional index of a column name in a set operation's first SELECT branch.
539fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
540    let mut expr = set_op_expr;
541    loop {
542        match expr {
543            Expression::Union(u) => expr = &u.left,
544            Expression::Intersect(i) => expr = &i.left,
545            Expression::Except(e) => expr = &e.left,
546            Expression::Select(select) => {
547                for (i, e) in select.expressions.iter().enumerate() {
548                    if get_alias_or_name(e).as_deref() == Some(name) {
549                        return Ok(i);
550                    }
551                }
552                return Err(crate::error::Error::Parse(format!(
553                    "Cannot find column '{}' in set operation",
554                    name
555                )));
556            }
557            _ => {
558                return Err(crate::error::Error::Parse(
559                    "Expected SELECT or set operation".to_string(),
560                ))
561            }
562        }
563    }
564}
565
566/// If trim_selects is enabled, return a copy of the SELECT with only the target column.
567fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
568    if let Expression::Select(select) = select_expr {
569        let mut trimmed = select.as_ref().clone();
570        trimmed.expressions = vec![target_expr.clone()];
571        Expression::Select(Box::new(trimmed))
572    } else {
573        select_expr.clone()
574    }
575}
576
577/// Find the child scope (CTE or derived table) for a given source name.
578fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
579    // Check CTE scopes
580    if scope.cte_sources.contains_key(source_name) {
581        for cte_scope in &scope.cte_scopes {
582            if let Expression::Cte(cte) = &cte_scope.expression {
583                if cte.alias.name == source_name {
584                    return Some(cte_scope);
585                }
586            }
587        }
588    }
589
590    // Check derived table scopes
591    if let Some(source_info) = scope.sources.get(source_name) {
592        if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
593            if let Expression::Subquery(sq) = &source_info.expression {
594                for dt_scope in &scope.derived_table_scopes {
595                    if dt_scope.expression == sq.this {
596                        return Some(dt_scope);
597                    }
598                }
599            }
600        }
601    }
602
603    None
604}
605
606/// Find a CTE scope by name, searching through a combined list of CTE scopes.
607/// This handles nested CTEs where the current scope doesn't have the CTE scope
608/// as a direct child but knows about it via cte_sources.
609fn find_child_scope_in<'a>(
610    all_cte_scopes: &[&'a Scope],
611    scope: &'a Scope,
612    source_name: &str,
613) -> Option<&'a Scope> {
614    // First try the scope's own cte_scopes
615    for cte_scope in &scope.cte_scopes {
616        if let Expression::Cte(cte) = &cte_scope.expression {
617            if cte.alias.name == source_name {
618                return Some(cte_scope);
619            }
620        }
621    }
622
623    // Then search through all ancestor CTE scopes
624    for cte_scope in all_cte_scopes {
625        if let Expression::Cte(cte) = &cte_scope.expression {
626            if cte.alias.name == source_name {
627                return Some(cte_scope);
628            }
629        }
630    }
631
632    // Fall back to derived table scopes
633    if let Some(source_info) = scope.sources.get(source_name) {
634        if source_info.is_scope {
635            if let Expression::Subquery(sq) = &source_info.expression {
636                for dt_scope in &scope.derived_table_scopes {
637                    if dt_scope.expression == sq.this {
638                        return Some(dt_scope);
639                    }
640                }
641            }
642        }
643    }
644
645    None
646}
647
648/// Create a terminal lineage node for a table.column reference.
649fn make_table_column_node(table: &str, column: &str) -> LineageNode {
650    LineageNode::new(
651        format!("{}.{}", table, column),
652        Expression::Column(crate::expressions::Column {
653            name: crate::expressions::Identifier::new(column.to_string()),
654            table: Some(crate::expressions::Identifier::new(table.to_string())),
655            join_mark: false,
656            trailing_comments: vec![],
657        }),
658        Expression::Table(crate::expressions::TableRef::new(table)),
659    )
660}
661
662/// Simple column reference extracted from an expression
663#[derive(Debug, Clone)]
664struct SimpleColumnRef {
665    table: Option<crate::expressions::Identifier>,
666    column: String,
667}
668
669/// Find all column references in an expression (does not recurse into subqueries).
670fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
671    let mut refs = Vec::new();
672    collect_column_refs(expr, &mut refs);
673    refs
674}
675
676fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
677    match expr {
678        Expression::Column(col) => {
679            refs.push(SimpleColumnRef {
680                table: col.table.clone(),
681                column: col.name.name.clone(),
682            });
683        }
684        Expression::Alias(alias) => {
685            collect_column_refs(&alias.this, refs);
686        }
687        Expression::And(op)
688        | Expression::Or(op)
689        | Expression::Eq(op)
690        | Expression::Neq(op)
691        | Expression::Lt(op)
692        | Expression::Lte(op)
693        | Expression::Gt(op)
694        | Expression::Gte(op)
695        | Expression::Add(op)
696        | Expression::Sub(op)
697        | Expression::Mul(op)
698        | Expression::Div(op)
699        | Expression::Mod(op)
700        | Expression::BitwiseAnd(op)
701        | Expression::BitwiseOr(op)
702        | Expression::BitwiseXor(op)
703        | Expression::Concat(op) => {
704            collect_column_refs(&op.left, refs);
705            collect_column_refs(&op.right, refs);
706        }
707        Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
708            collect_column_refs(&u.this, refs);
709        }
710        Expression::Function(func) => {
711            for arg in &func.args {
712                collect_column_refs(arg, refs);
713            }
714        }
715        Expression::AggregateFunction(func) => {
716            for arg in &func.args {
717                collect_column_refs(arg, refs);
718            }
719        }
720        Expression::WindowFunction(wf) => {
721            collect_column_refs(&wf.this, refs);
722        }
723        Expression::Case(case) => {
724            if let Some(operand) = &case.operand {
725                collect_column_refs(operand, refs);
726            }
727            for (cond, result) in &case.whens {
728                collect_column_refs(cond, refs);
729                collect_column_refs(result, refs);
730            }
731            if let Some(ref else_expr) = case.else_ {
732                collect_column_refs(else_expr, refs);
733            }
734        }
735        Expression::Cast(cast) => {
736            collect_column_refs(&cast.this, refs);
737        }
738        Expression::Paren(p) => {
739            collect_column_refs(&p.this, refs);
740        }
741        Expression::Coalesce(c) => {
742            for e in &c.expressions {
743                collect_column_refs(e, refs);
744            }
745        }
746        // Don't recurse into subqueries — those are handled separately
747        Expression::Subquery(_) | Expression::Exists(_) => {}
748        _ => {}
749    }
750}
751
752// ---------------------------------------------------------------------------
753// Tests
754// ---------------------------------------------------------------------------
755
756#[cfg(test)]
757mod tests {
758    use super::*;
759    use crate::dialects::{Dialect, DialectType};
760
761    fn parse(sql: &str) -> Expression {
762        let dialect = Dialect::get(DialectType::Generic);
763        let ast = dialect.parse(sql).unwrap();
764        ast.into_iter().next().unwrap()
765    }
766
767    #[test]
768    fn test_simple_lineage() {
769        let expr = parse("SELECT a FROM t");
770        let node = lineage("a", &expr, None, false).unwrap();
771
772        assert_eq!(node.name, "a");
773        assert!(!node.downstream.is_empty(), "Should have downstream nodes");
774        // Should trace to t.a
775        let names = node.downstream_names();
776        assert!(
777            names.iter().any(|n| n == "t.a"),
778            "Expected t.a in downstream, got: {:?}",
779            names
780        );
781    }
782
783    #[test]
784    fn test_lineage_walk() {
785        let root = LineageNode {
786            name: "col_a".to_string(),
787            expression: Expression::Null(crate::expressions::Null),
788            source: Expression::Null(crate::expressions::Null),
789            downstream: vec![LineageNode::new(
790                "t.a",
791                Expression::Null(crate::expressions::Null),
792                Expression::Null(crate::expressions::Null),
793            )],
794            source_name: String::new(),
795            reference_node_name: String::new(),
796        };
797
798        let names: Vec<_> = root.walk().map(|n| n.name.clone()).collect();
799        assert_eq!(names.len(), 2);
800        assert_eq!(names[0], "col_a");
801        assert_eq!(names[1], "t.a");
802    }
803
804    #[test]
805    fn test_aliased_column() {
806        let expr = parse("SELECT a + 1 AS b FROM t");
807        let node = lineage("b", &expr, None, false).unwrap();
808
809        assert_eq!(node.name, "b");
810        // Should trace through the expression to t.a
811        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
812        assert!(
813            all_names.iter().any(|n| n.contains("a")),
814            "Expected to trace to column a, got: {:?}",
815            all_names
816        );
817    }
818
819    #[test]
820    fn test_qualified_column() {
821        let expr = parse("SELECT t.a FROM t");
822        let node = lineage("a", &expr, None, false).unwrap();
823
824        assert_eq!(node.name, "a");
825        let names = node.downstream_names();
826        assert!(
827            names.iter().any(|n| n == "t.a"),
828            "Expected t.a, got: {:?}",
829            names
830        );
831    }
832
833    #[test]
834    fn test_unqualified_column() {
835        let expr = parse("SELECT a FROM t");
836        let node = lineage("a", &expr, None, false).unwrap();
837
838        // Unqualified but single source → resolved to t.a
839        let names = node.downstream_names();
840        assert!(
841            names.iter().any(|n| n == "t.a"),
842            "Expected t.a, got: {:?}",
843            names
844        );
845    }
846
847    #[test]
848    fn test_lineage_join() {
849        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
850
851        let node_a = lineage("a", &expr, None, false).unwrap();
852        let names_a = node_a.downstream_names();
853        assert!(
854            names_a.iter().any(|n| n == "t.a"),
855            "Expected t.a, got: {:?}",
856            names_a
857        );
858
859        let node_b = lineage("b", &expr, None, false).unwrap();
860        let names_b = node_b.downstream_names();
861        assert!(
862            names_b.iter().any(|n| n == "s.b"),
863            "Expected s.b, got: {:?}",
864            names_b
865        );
866    }
867
868    #[test]
869    fn test_lineage_derived_table() {
870        let expr = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
871        let node = lineage("a", &expr, None, false).unwrap();
872
873        assert_eq!(node.name, "a");
874        // Should trace through the derived table to t.a
875        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
876        assert!(
877            all_names.iter().any(|n| n == "t.a"),
878            "Expected to trace through derived table to t.a, got: {:?}",
879            all_names
880        );
881    }
882
883    #[test]
884    fn test_lineage_cte() {
885        let expr = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
886        let node = lineage("a", &expr, None, false).unwrap();
887
888        assert_eq!(node.name, "a");
889        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
890        assert!(
891            all_names.iter().any(|n| n == "t.a"),
892            "Expected to trace through CTE to t.a, got: {:?}",
893            all_names
894        );
895    }
896
897    #[test]
898    fn test_lineage_union() {
899        let expr = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
900        let node = lineage("a", &expr, None, false).unwrap();
901
902        assert_eq!(node.name, "a");
903        // Should have 2 downstream branches
904        assert_eq!(
905            node.downstream.len(),
906            2,
907            "Expected 2 branches for UNION, got {}",
908            node.downstream.len()
909        );
910    }
911
912    #[test]
913    fn test_lineage_cte_union() {
914        let expr = parse("WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte");
915        let node = lineage("a", &expr, None, false).unwrap();
916
917        // Should trace through CTE into both UNION branches
918        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
919        assert!(
920            all_names.len() >= 3,
921            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
922            all_names
923        );
924    }
925
926    #[test]
927    fn test_lineage_star() {
928        let expr = parse("SELECT * FROM t");
929        let node = lineage("*", &expr, None, false).unwrap();
930
931        assert_eq!(node.name, "*");
932        // Should have downstream for table t
933        assert!(
934            !node.downstream.is_empty(),
935            "Star should produce downstream nodes"
936        );
937    }
938
939    #[test]
940    fn test_lineage_subquery_in_select() {
941        let expr = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
942        let node = lineage("x", &expr, None, false).unwrap();
943
944        assert_eq!(node.name, "x");
945        // Should have traced into the scalar subquery
946        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
947        assert!(
948            all_names.len() >= 2,
949            "Expected tracing into scalar subquery, got: {:?}",
950            all_names
951        );
952    }
953
954    #[test]
955    fn test_lineage_multiple_columns() {
956        let expr = parse("SELECT a, b FROM t");
957
958        let node_a = lineage("a", &expr, None, false).unwrap();
959        let node_b = lineage("b", &expr, None, false).unwrap();
960
961        assert_eq!(node_a.name, "a");
962        assert_eq!(node_b.name, "b");
963
964        // Each should trace independently
965        let names_a = node_a.downstream_names();
966        let names_b = node_b.downstream_names();
967        assert!(names_a.iter().any(|n| n == "t.a"));
968        assert!(names_b.iter().any(|n| n == "t.b"));
969    }
970
971    #[test]
972    fn test_get_source_tables() {
973        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
974        let node = lineage("a", &expr, None, false).unwrap();
975
976        let tables = get_source_tables(&node);
977        assert!(
978            tables.contains("t"),
979            "Expected source table 't', got: {:?}",
980            tables
981        );
982    }
983
984    #[test]
985    fn test_lineage_column_not_found() {
986        let expr = parse("SELECT a FROM t");
987        let result = lineage("nonexistent", &expr, None, false);
988        assert!(result.is_err());
989    }
990
991    #[test]
992    fn test_lineage_nested_cte() {
993        let expr = parse(
994            "WITH cte1 AS (SELECT a FROM t), \
995             cte2 AS (SELECT a FROM cte1) \
996             SELECT a FROM cte2",
997        );
998        let node = lineage("a", &expr, None, false).unwrap();
999
1000        // Should trace through cte2 → cte1 → t
1001        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
1002        assert!(
1003            all_names.len() >= 3,
1004            "Expected to trace through nested CTEs, got: {:?}",
1005            all_names
1006        );
1007    }
1008
1009    #[test]
1010    fn test_trim_selects_true() {
1011        let expr = parse("SELECT a, b, c FROM t");
1012        let node = lineage("a", &expr, None, true).unwrap();
1013
1014        // The source should be trimmed to only include 'a'
1015        if let Expression::Select(select) = &node.source {
1016            assert_eq!(
1017                select.expressions.len(),
1018                1,
1019                "Trimmed source should have 1 expression, got {}",
1020                select.expressions.len()
1021            );
1022        } else {
1023            panic!("Expected Select source");
1024        }
1025    }
1026
1027    #[test]
1028    fn test_trim_selects_false() {
1029        let expr = parse("SELECT a, b, c FROM t");
1030        let node = lineage("a", &expr, None, false).unwrap();
1031
1032        // The source should keep all columns
1033        if let Expression::Select(select) = &node.source {
1034            assert_eq!(
1035                select.expressions.len(),
1036                3,
1037                "Untrimmed source should have 3 expressions"
1038            );
1039        } else {
1040            panic!("Expected Select source");
1041        }
1042    }
1043
1044    #[test]
1045    fn test_lineage_expression_in_select() {
1046        let expr = parse("SELECT a + b AS c FROM t");
1047        let node = lineage("c", &expr, None, false).unwrap();
1048
1049        // Should trace to both a and b from t
1050        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
1051        assert!(
1052            all_names.len() >= 3,
1053            "Expected to trace a + b to both columns, got: {:?}",
1054            all_names
1055        );
1056    }
1057
1058    #[test]
1059    fn test_set_operation_by_index() {
1060        let expr = parse("SELECT a FROM t1 UNION SELECT b FROM t2");
1061
1062        // Trace column "a" which is at index 0
1063        let node = lineage("a", &expr, None, false).unwrap();
1064
1065        // UNION branches should be traced by index
1066        assert_eq!(node.downstream.len(), 2);
1067    }
1068}