Skip to main content

polyglot_sql/
lineage.rs

1//! Column Lineage Tracking
2//!
3//! This module provides functionality to track column lineage through SQL queries,
4//! building a graph of how columns flow from source tables to the result set.
5//! Supports UNION/INTERSECT/EXCEPT, CTEs, derived tables, subqueries, and star expansion.
6//!
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::scope::{build_scope, Scope};
11use crate::traversal::ExpressionWalk;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15
/// A node in the column lineage graph.
///
/// Each node is one step in a column's derivation. It records:
/// - the expression found at that step;
/// - the query (or table) that expression lives in;
/// - the nodes it was derived from, in `downstream`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineageNode {
    /// Name of this lineage step (e.g., "table.column").
    ///
    /// It takes one of these forms:
    /// - the requested column name at the root;
    /// - `table.column` for terminal table columns;
    /// - `source.*` for star expansion;
    /// - `_N` for set-operation nodes reached by position.
    pub name: String,
    /// The expression at this node.
    ///
    /// This is the matching select-list item, or a synthesized column/star reference
    /// for terminal nodes.
    pub expression: Expression,
    /// The source expression: the query this step was found in.
    ///
    /// For terminal nodes this is a table expression instead. When `trim_selects`
    /// is requested, the SELECT is trimmed to just this column.
    pub source: Expression,
    /// Nodes this one is derived from, i.e. its inputs, one step closer to the
    /// source tables.
    pub downstream: Vec<LineageNode>,
    /// Name of the CTE / derived table / table through which this step was reached.
    /// Empty when not applicable.
    pub source_name: String,
    /// Name of the parent lineage step that referenced this CTE or derived table.
    /// Empty when not applicable.
    pub reference_node_name: String,
}
32
33impl LineageNode {
34    /// Create a new lineage node
35    pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
36        Self {
37            name: name.into(),
38            expression,
39            source,
40            downstream: Vec::new(),
41            source_name: String::new(),
42            reference_node_name: String::new(),
43        }
44    }
45
46    /// Iterate over all nodes in the lineage graph using DFS
47    pub fn walk(&self) -> LineageWalker<'_> {
48        LineageWalker {
49            stack: vec![self],
50        }
51    }
52
53    /// Get all downstream column names
54    pub fn downstream_names(&self) -> Vec<String> {
55        self.downstream.iter().map(|n| n.name.clone()).collect()
56    }
57}
58
59/// Iterator for walking the lineage graph
60pub struct LineageWalker<'a> {
61    stack: Vec<&'a LineageNode>,
62}
63
64impl<'a> Iterator for LineageWalker<'a> {
65    type Item = &'a LineageNode;
66
67    fn next(&mut self) -> Option<Self::Item> {
68        if let Some(node) = self.stack.pop() {
69            // Add children in reverse order so they're visited in order
70            for child in node.downstream.iter().rev() {
71                self.stack.push(child);
72            }
73            Some(node)
74        } else {
75            None
76        }
77    }
78}
79
80// ---------------------------------------------------------------------------
81// ColumnRef: name or positional index for column lookup
82// ---------------------------------------------------------------------------
83
/// Column reference for lineage tracing — by name or positional index.
///
/// The caller's requested column is looked up by name. Positions are used where
/// names are not reliable:
/// - aligning set-operation branches;
/// - following a scalar subquery's first output (index 0).
enum ColumnRef<'a> {
    /// Match a select-list item by its alias or column name.
    Name(&'a str),
    /// The select-list item at this zero-based position.
    Index(usize),
}
89
90// ---------------------------------------------------------------------------
91// Public API
92// ---------------------------------------------------------------------------
93
94/// Build the lineage graph for a column in a SQL query
95///
96/// # Arguments
97/// * `column` - The column name to trace lineage for
98/// * `sql` - The SQL expression (SELECT, UNION, etc.)
99/// * `dialect` - Optional dialect for parsing
100/// * `trim_selects` - If true, trim the source SELECT to only include the target column
101///
102/// # Returns
103/// The root lineage node for the specified column
104///
105/// # Example
106/// ```ignore
107/// use polyglot_sql::lineage::lineage;
108/// use polyglot_sql::parse_one;
109/// use polyglot_sql::DialectType;
110///
111/// let sql = "SELECT a, b + 1 AS c FROM t";
112/// let expr = parse_one(sql, DialectType::Generic).unwrap();
113/// let node = lineage("c", &expr, None, false).unwrap();
114/// ```
115pub fn lineage(
116    column: &str,
117    sql: &Expression,
118    dialect: Option<DialectType>,
119    trim_selects: bool,
120) -> Result<LineageNode> {
121    let scope = build_scope(sql);
122    to_node(
123        ColumnRef::Name(column),
124        &scope,
125        dialect,
126        "",
127        "",
128        "",
129        trim_selects,
130    )
131}
132
133/// Get all source tables from a lineage graph
134pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
135    let mut tables = HashSet::new();
136    collect_source_tables(node, &mut tables);
137    tables
138}
139
140/// Recursively collect source table names from lineage graph
141pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
142    if let Expression::Table(table) = &node.source {
143        tables.insert(table.name.name.clone());
144    }
145    for child in &node.downstream {
146        collect_source_tables(child, tables);
147    }
148}
149
150// ---------------------------------------------------------------------------
151// Core recursive lineage builder
152// ---------------------------------------------------------------------------
153
154/// Recursively build a lineage node for a column in a scope.
155fn to_node(
156    column: ColumnRef<'_>,
157    scope: &Scope,
158    dialect: Option<DialectType>,
159    scope_name: &str,
160    source_name: &str,
161    reference_node_name: &str,
162    trim_selects: bool,
163) -> Result<LineageNode> {
164    to_node_inner(column, scope, dialect, scope_name, source_name, reference_node_name, trim_selects, &[])
165}
166
167fn to_node_inner(
168    column: ColumnRef<'_>,
169    scope: &Scope,
170    dialect: Option<DialectType>,
171    scope_name: &str,
172    source_name: &str,
173    reference_node_name: &str,
174    trim_selects: bool,
175    ancestor_cte_scopes: &[Scope],
176) -> Result<LineageNode> {
177    let scope_expr = &scope.expression;
178
179    // Build combined CTE scopes: current scope's cte_scopes + ancestors
180    let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
181    for s in ancestor_cte_scopes {
182        all_cte_scopes.push(s);
183    }
184
185    // 0. Unwrap CTE scope — CTE scope expressions are Expression::Cte(...)
186    //    but we need the inner query (SELECT/UNION) for column lookup.
187    let effective_expr = match scope_expr {
188        Expression::Cte(cte) => &cte.this,
189        other => other,
190    };
191
192    // 1. Set operations (UNION / INTERSECT / EXCEPT)
193    if matches!(
194        effective_expr,
195        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
196    ) {
197        // For CTE wrapping a set op, create a temporary scope with the inner expression
198        if matches!(scope_expr, Expression::Cte(_)) {
199            let mut inner_scope = Scope::new(effective_expr.clone());
200            inner_scope.union_scopes = scope.union_scopes.clone();
201            inner_scope.sources = scope.sources.clone();
202            inner_scope.cte_sources = scope.cte_sources.clone();
203            inner_scope.cte_scopes = scope.cte_scopes.clone();
204            inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
205            inner_scope.subquery_scopes = scope.subquery_scopes.clone();
206            return handle_set_operation(
207                &column,
208                &inner_scope,
209                dialect,
210                scope_name,
211                source_name,
212                reference_node_name,
213                trim_selects,
214                ancestor_cte_scopes,
215            );
216        }
217        return handle_set_operation(
218            &column,
219            scope,
220            dialect,
221            scope_name,
222            source_name,
223            reference_node_name,
224            trim_selects,
225            ancestor_cte_scopes,
226        );
227    }
228
229    // 2. Find the select expression for this column
230    let select_expr = find_select_expr(effective_expr, &column)?;
231    let column_name = resolve_column_name(&column, &select_expr);
232
233    // 3. Trim source if requested
234    let node_source = if trim_selects {
235        trim_source(effective_expr, &select_expr)
236    } else {
237        effective_expr.clone()
238    };
239
240    // 4. Create the lineage node
241    let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
242    node.source_name = source_name.to_string();
243    node.reference_node_name = reference_node_name.to_string();
244
245    // 5. Star handling — add downstream for each source
246    if matches!(&select_expr, Expression::Star(_)) {
247        for (name, source_info) in &scope.sources {
248            let child = LineageNode::new(
249                format!("{}.*", name),
250                Expression::Star(crate::expressions::Star {
251                    table: None,
252                    except: None,
253                    replace: None,
254                    rename: None,
255                    trailing_comments: vec![],
256                }),
257                source_info.expression.clone(),
258            );
259            node.downstream.push(child);
260        }
261        return Ok(node);
262    }
263
264    // 6. Subqueries in select — trace through scalar subqueries
265    let subqueries: Vec<&Expression> = select_expr.find_all(|e| {
266        matches!(e, Expression::Subquery(sq) if sq.alias.is_none())
267    });
268    for sq_expr in subqueries {
269        if let Expression::Subquery(sq) = sq_expr {
270            for sq_scope in &scope.subquery_scopes {
271                if sq_scope.expression == sq.this {
272                    if let Ok(child) = to_node_inner(
273                        ColumnRef::Index(0),
274                        sq_scope,
275                        dialect,
276                        &column_name,
277                        "",
278                        "",
279                        trim_selects,
280                        ancestor_cte_scopes,
281                    ) {
282                        node.downstream.push(child);
283                    }
284                    break;
285                }
286            }
287        }
288    }
289
290    // 7. Column references — trace each column to its source
291    let col_refs = find_column_refs_in_expr(&select_expr);
292    for col_ref in col_refs {
293        let col_name = &col_ref.column;
294        if let Some(ref table_id) = col_ref.table {
295            let tbl = &table_id.name;
296            resolve_qualified_column(
297                &mut node,
298                scope,
299                dialect,
300                tbl,
301                col_name,
302                &column_name,
303                trim_selects,
304                &all_cte_scopes,
305            );
306        } else {
307            resolve_unqualified_column(
308                &mut node,
309                scope,
310                dialect,
311                col_name,
312                &column_name,
313                trim_selects,
314                &all_cte_scopes,
315            );
316        }
317    }
318
319    Ok(node)
320}
321
322// ---------------------------------------------------------------------------
323// Set operation handling
324// ---------------------------------------------------------------------------
325
326fn handle_set_operation(
327    column: &ColumnRef<'_>,
328    scope: &Scope,
329    dialect: Option<DialectType>,
330    scope_name: &str,
331    source_name: &str,
332    reference_node_name: &str,
333    trim_selects: bool,
334    ancestor_cte_scopes: &[Scope],
335) -> Result<LineageNode> {
336    let scope_expr = &scope.expression;
337
338    // Determine column index
339    let col_index = match column {
340        ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
341        ColumnRef::Index(i) => *i,
342    };
343
344    let col_name = match column {
345        ColumnRef::Name(name) => name.to_string(),
346        ColumnRef::Index(_) => format!("_{col_index}"),
347    };
348
349    let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
350    node.source_name = source_name.to_string();
351    node.reference_node_name = reference_node_name.to_string();
352
353    // Recurse into each union branch
354    for branch_scope in &scope.union_scopes {
355        if let Ok(child) = to_node_inner(
356            ColumnRef::Index(col_index),
357            branch_scope,
358            dialect,
359            scope_name,
360            "",
361            "",
362            trim_selects,
363            ancestor_cte_scopes,
364        ) {
365            node.downstream.push(child);
366        }
367    }
368
369    Ok(node)
370}
371
372// ---------------------------------------------------------------------------
373// Column resolution helpers
374// ---------------------------------------------------------------------------
375
376fn resolve_qualified_column(
377    node: &mut LineageNode,
378    scope: &Scope,
379    dialect: Option<DialectType>,
380    table: &str,
381    col_name: &str,
382    parent_name: &str,
383    trim_selects: bool,
384    all_cte_scopes: &[&Scope],
385) {
386    // Check if table is a CTE reference (cte_sources tracks CTE names)
387    if scope.cte_sources.contains_key(table) {
388        if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
389            // Build ancestor CTE scopes from all_cte_scopes for the recursive call
390            let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
391            if let Ok(child) = to_node_inner(
392                ColumnRef::Name(col_name),
393                child_scope,
394                dialect,
395                parent_name,
396                table,
397                parent_name,
398                trim_selects,
399                &ancestors,
400            ) {
401                node.downstream.push(child);
402                return;
403            }
404        }
405    }
406
407    // Check if table is a derived table (is_scope = true in sources)
408    if let Some(source_info) = scope.sources.get(table) {
409        if source_info.is_scope {
410            if let Some(child_scope) = find_child_scope(scope, table) {
411                let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
412                if let Ok(child) = to_node_inner(
413                    ColumnRef::Name(col_name),
414                    child_scope,
415                    dialect,
416                    parent_name,
417                    table,
418                    parent_name,
419                    trim_selects,
420                    &ancestors,
421                ) {
422                    node.downstream.push(child);
423                    return;
424                }
425            }
426        }
427    }
428
429    // Base table or unresolved — terminal node
430    node.downstream.push(make_table_column_node(table, col_name));
431}
432
433fn resolve_unqualified_column(
434    node: &mut LineageNode,
435    scope: &Scope,
436    dialect: Option<DialectType>,
437    col_name: &str,
438    parent_name: &str,
439    trim_selects: bool,
440    all_cte_scopes: &[&Scope],
441) {
442    // Try to find which source this column belongs to.
443    // Filter to only FROM-clause sources: add_cte_source adds all CTEs to sources
444    // with Expression::Cte, but FROM-clause Table references overwrite with Expression::Table.
445    // So CTE-only entries (not referenced in FROM) have Expression::Cte — exclude those.
446    let from_source_names: Vec<&String> = scope
447        .sources
448        .iter()
449        .filter(|(_, info)| !matches!(info.expression, Expression::Cte(_)))
450        .map(|(name, _)| name)
451        .collect();
452
453    if from_source_names.len() == 1 {
454        let tbl = from_source_names[0];
455        resolve_qualified_column(node, scope, dialect, tbl, col_name, parent_name, trim_selects, all_cte_scopes);
456        return;
457    }
458
459    // Multiple sources — can't resolve without schema info, add unqualified node
460    let child = LineageNode::new(
461        col_name.to_string(),
462        Expression::Column(crate::expressions::Column {
463            name: crate::expressions::Identifier::new(col_name.to_string()),
464            table: None,
465            join_mark: false,
466            trailing_comments: vec![],
467        }),
468        node.source.clone(),
469    );
470    node.downstream.push(child);
471}
472
473// ---------------------------------------------------------------------------
474// Helper functions
475// ---------------------------------------------------------------------------
476
477/// Get the alias or name of an expression
478fn get_alias_or_name(expr: &Expression) -> Option<String> {
479    match expr {
480        Expression::Alias(alias) => Some(alias.alias.name.clone()),
481        Expression::Column(col) => Some(col.name.name.clone()),
482        Expression::Identifier(id) => Some(id.name.clone()),
483        Expression::Star(_) => Some("*".to_string()),
484        _ => None,
485    }
486}
487
488/// Resolve the display name for a column reference.
489fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
490    match column {
491        ColumnRef::Name(n) => n.to_string(),
492        ColumnRef::Index(_) => {
493            get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string())
494        }
495    }
496}
497
498/// Find the select expression matching a column reference.
499fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
500    if let Expression::Select(ref select) = scope_expr {
501        match column {
502            ColumnRef::Name(name) => {
503                for expr in &select.expressions {
504                    if get_alias_or_name(expr).as_deref() == Some(name) {
505                        return Ok(expr.clone());
506                    }
507                }
508                Err(crate::error::Error::Parse(format!(
509                    "Cannot find column '{}' in query",
510                    name
511                )))
512            }
513            ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
514                crate::error::Error::Parse(format!("Column index {} out of range", idx))
515            }),
516        }
517    } else {
518        Err(crate::error::Error::Parse(
519            "Expected SELECT expression for column lookup".to_string(),
520        ))
521    }
522}
523
524/// Find the positional index of a column name in a set operation's first SELECT branch.
525fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
526    let mut expr = set_op_expr;
527    loop {
528        match expr {
529            Expression::Union(u) => expr = &u.left,
530            Expression::Intersect(i) => expr = &i.left,
531            Expression::Except(e) => expr = &e.left,
532            Expression::Select(select) => {
533                for (i, e) in select.expressions.iter().enumerate() {
534                    if get_alias_or_name(e).as_deref() == Some(name) {
535                        return Ok(i);
536                    }
537                }
538                return Err(crate::error::Error::Parse(format!(
539                    "Cannot find column '{}' in set operation",
540                    name
541                )));
542            }
543            _ => {
544                return Err(crate::error::Error::Parse(
545                    "Expected SELECT or set operation".to_string(),
546                ))
547            }
548        }
549    }
550}
551
552/// If trim_selects is enabled, return a copy of the SELECT with only the target column.
553fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
554    if let Expression::Select(select) = select_expr {
555        let mut trimmed = select.as_ref().clone();
556        trimmed.expressions = vec![target_expr.clone()];
557        Expression::Select(Box::new(trimmed))
558    } else {
559        select_expr.clone()
560    }
561}
562
563/// Find the child scope (CTE or derived table) for a given source name.
564fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
565    // Check CTE scopes
566    if scope.cte_sources.contains_key(source_name) {
567        for cte_scope in &scope.cte_scopes {
568            if let Expression::Cte(cte) = &cte_scope.expression {
569                if cte.alias.name == source_name {
570                    return Some(cte_scope);
571                }
572            }
573        }
574    }
575
576    // Check derived table scopes
577    if let Some(source_info) = scope.sources.get(source_name) {
578        if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
579            if let Expression::Subquery(sq) = &source_info.expression {
580                for dt_scope in &scope.derived_table_scopes {
581                    if dt_scope.expression == sq.this {
582                        return Some(dt_scope);
583                    }
584                }
585            }
586        }
587    }
588
589    None
590}
591
592/// Find a CTE scope by name, searching through a combined list of CTE scopes.
593/// This handles nested CTEs where the current scope doesn't have the CTE scope
594/// as a direct child but knows about it via cte_sources.
595fn find_child_scope_in<'a>(
596    all_cte_scopes: &[&'a Scope],
597    scope: &'a Scope,
598    source_name: &str,
599) -> Option<&'a Scope> {
600    // First try the scope's own cte_scopes
601    for cte_scope in &scope.cte_scopes {
602        if let Expression::Cte(cte) = &cte_scope.expression {
603            if cte.alias.name == source_name {
604                return Some(cte_scope);
605            }
606        }
607    }
608
609    // Then search through all ancestor CTE scopes
610    for cte_scope in all_cte_scopes {
611        if let Expression::Cte(cte) = &cte_scope.expression {
612            if cte.alias.name == source_name {
613                return Some(cte_scope);
614            }
615        }
616    }
617
618    // Fall back to derived table scopes
619    if let Some(source_info) = scope.sources.get(source_name) {
620        if source_info.is_scope {
621            if let Expression::Subquery(sq) = &source_info.expression {
622                for dt_scope in &scope.derived_table_scopes {
623                    if dt_scope.expression == sq.this {
624                        return Some(dt_scope);
625                    }
626                }
627            }
628        }
629    }
630
631    None
632}
633
634/// Create a terminal lineage node for a table.column reference.
635fn make_table_column_node(table: &str, column: &str) -> LineageNode {
636    LineageNode::new(
637        format!("{}.{}", table, column),
638        Expression::Column(crate::expressions::Column {
639            name: crate::expressions::Identifier::new(column.to_string()),
640            table: Some(crate::expressions::Identifier::new(table.to_string())),
641            join_mark: false,
642            trailing_comments: vec![],
643        }),
644        Expression::Table(crate::expressions::TableRef::new(table)),
645    )
646}
647
/// Simple column reference extracted from an expression: the optional table
/// qualifier plus the column name.
#[derive(Debug, Clone)]
struct SimpleColumnRef {
    /// Table qualifier as written in the query (e.g. `t` in `t.a`); `None` if unqualified.
    table: Option<crate::expressions::Identifier>,
    /// Column name.
    column: String,
}
654
655/// Find all column references in an expression (does not recurse into subqueries).
656fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
657    let mut refs = Vec::new();
658    collect_column_refs(expr, &mut refs);
659    refs
660}
661
662fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
663    match expr {
664        Expression::Column(col) => {
665            refs.push(SimpleColumnRef {
666                table: col.table.clone(),
667                column: col.name.name.clone(),
668            });
669        }
670        Expression::Alias(alias) => {
671            collect_column_refs(&alias.this, refs);
672        }
673        Expression::And(op)
674        | Expression::Or(op)
675        | Expression::Eq(op)
676        | Expression::Neq(op)
677        | Expression::Lt(op)
678        | Expression::Lte(op)
679        | Expression::Gt(op)
680        | Expression::Gte(op)
681        | Expression::Add(op)
682        | Expression::Sub(op)
683        | Expression::Mul(op)
684        | Expression::Div(op)
685        | Expression::Mod(op)
686        | Expression::BitwiseAnd(op)
687        | Expression::BitwiseOr(op)
688        | Expression::BitwiseXor(op)
689        | Expression::Concat(op) => {
690            collect_column_refs(&op.left, refs);
691            collect_column_refs(&op.right, refs);
692        }
693        Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
694            collect_column_refs(&u.this, refs);
695        }
696        Expression::Function(func) => {
697            for arg in &func.args {
698                collect_column_refs(arg, refs);
699            }
700        }
701        Expression::AggregateFunction(func) => {
702            for arg in &func.args {
703                collect_column_refs(arg, refs);
704            }
705        }
706        Expression::WindowFunction(wf) => {
707            collect_column_refs(&wf.this, refs);
708        }
709        Expression::Case(case) => {
710            if let Some(operand) = &case.operand {
711                collect_column_refs(operand, refs);
712            }
713            for (cond, result) in &case.whens {
714                collect_column_refs(cond, refs);
715                collect_column_refs(result, refs);
716            }
717            if let Some(ref else_expr) = case.else_ {
718                collect_column_refs(else_expr, refs);
719            }
720        }
721        Expression::Cast(cast) => {
722            collect_column_refs(&cast.this, refs);
723        }
724        Expression::Paren(p) => {
725            collect_column_refs(&p.this, refs);
726        }
727        Expression::Coalesce(c) => {
728            for e in &c.expressions {
729                collect_column_refs(e, refs);
730            }
731        }
732        // Don't recurse into subqueries — those are handled separately
733        Expression::Subquery(_) | Expression::Exists(_) => {}
734        _ => {}
735    }
736}
737
738// ---------------------------------------------------------------------------
739// Tests
740// ---------------------------------------------------------------------------
741
742#[cfg(test)]
743mod tests {
744    use super::*;
745    use crate::dialects::{Dialect, DialectType};
746
747    fn parse(sql: &str) -> Expression {
748        let dialect = Dialect::get(DialectType::Generic);
749        let ast = dialect.parse(sql).unwrap();
750        ast.into_iter().next().unwrap()
751    }
752
753    #[test]
754    fn test_simple_lineage() {
755        let expr = parse("SELECT a FROM t");
756        let node = lineage("a", &expr, None, false).unwrap();
757
758        assert_eq!(node.name, "a");
759        assert!(!node.downstream.is_empty(), "Should have downstream nodes");
760        // Should trace to t.a
761        let names = node.downstream_names();
762        assert!(
763            names.iter().any(|n| n == "t.a"),
764            "Expected t.a in downstream, got: {:?}",
765            names
766        );
767    }
768
769    #[test]
770    fn test_lineage_walk() {
771        let root = LineageNode {
772            name: "col_a".to_string(),
773            expression: Expression::Null(crate::expressions::Null),
774            source: Expression::Null(crate::expressions::Null),
775            downstream: vec![LineageNode::new(
776                "t.a",
777                Expression::Null(crate::expressions::Null),
778                Expression::Null(crate::expressions::Null),
779            )],
780            source_name: String::new(),
781            reference_node_name: String::new(),
782        };
783
784        let names: Vec<_> = root.walk().map(|n| n.name.clone()).collect();
785        assert_eq!(names.len(), 2);
786        assert_eq!(names[0], "col_a");
787        assert_eq!(names[1], "t.a");
788    }
789
790    #[test]
791    fn test_aliased_column() {
792        let expr = parse("SELECT a + 1 AS b FROM t");
793        let node = lineage("b", &expr, None, false).unwrap();
794
795        assert_eq!(node.name, "b");
796        // Should trace through the expression to t.a
797        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
798        assert!(
799            all_names.iter().any(|n| n.contains("a")),
800            "Expected to trace to column a, got: {:?}",
801            all_names
802        );
803    }
804
805    #[test]
806    fn test_qualified_column() {
807        let expr = parse("SELECT t.a FROM t");
808        let node = lineage("a", &expr, None, false).unwrap();
809
810        assert_eq!(node.name, "a");
811        let names = node.downstream_names();
812        assert!(
813            names.iter().any(|n| n == "t.a"),
814            "Expected t.a, got: {:?}",
815            names
816        );
817    }
818
819    #[test]
820    fn test_unqualified_column() {
821        let expr = parse("SELECT a FROM t");
822        let node = lineage("a", &expr, None, false).unwrap();
823
824        // Unqualified but single source → resolved to t.a
825        let names = node.downstream_names();
826        assert!(
827            names.iter().any(|n| n == "t.a"),
828            "Expected t.a, got: {:?}",
829            names
830        );
831    }
832
833    #[test]
834    fn test_lineage_join() {
835        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
836
837        let node_a = lineage("a", &expr, None, false).unwrap();
838        let names_a = node_a.downstream_names();
839        assert!(
840            names_a.iter().any(|n| n == "t.a"),
841            "Expected t.a, got: {:?}",
842            names_a
843        );
844
845        let node_b = lineage("b", &expr, None, false).unwrap();
846        let names_b = node_b.downstream_names();
847        assert!(
848            names_b.iter().any(|n| n == "s.b"),
849            "Expected s.b, got: {:?}",
850            names_b
851        );
852    }
853
854    #[test]
855    fn test_lineage_derived_table() {
856        let expr = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
857        let node = lineage("a", &expr, None, false).unwrap();
858
859        assert_eq!(node.name, "a");
860        // Should trace through the derived table to t.a
861        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
862        assert!(
863            all_names.iter().any(|n| n == "t.a"),
864            "Expected to trace through derived table to t.a, got: {:?}",
865            all_names
866        );
867    }
868
869    #[test]
870    fn test_lineage_cte() {
871        let expr = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
872        let node = lineage("a", &expr, None, false).unwrap();
873
874        assert_eq!(node.name, "a");
875        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
876        assert!(
877            all_names.iter().any(|n| n == "t.a"),
878            "Expected to trace through CTE to t.a, got: {:?}",
879            all_names
880        );
881    }
882
883    #[test]
884    fn test_lineage_union() {
885        let expr = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
886        let node = lineage("a", &expr, None, false).unwrap();
887
888        assert_eq!(node.name, "a");
889        // Should have 2 downstream branches
890        assert_eq!(
891            node.downstream.len(),
892            2,
893            "Expected 2 branches for UNION, got {}",
894            node.downstream.len()
895        );
896    }
897
898    #[test]
899    fn test_lineage_cte_union() {
900        let expr = parse(
901            "WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte",
902        );
903        let node = lineage("a", &expr, None, false).unwrap();
904
905        // Should trace through CTE into both UNION branches
906        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
907        assert!(
908            all_names.len() >= 3,
909            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
910            all_names
911        );
912    }
913
914    #[test]
915    fn test_lineage_star() {
916        let expr = parse("SELECT * FROM t");
917        let node = lineage("*", &expr, None, false).unwrap();
918
919        assert_eq!(node.name, "*");
920        // Should have downstream for table t
921        assert!(
922            !node.downstream.is_empty(),
923            "Star should produce downstream nodes"
924        );
925    }
926
927    #[test]
928    fn test_lineage_subquery_in_select() {
929        let expr = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
930        let node = lineage("x", &expr, None, false).unwrap();
931
932        assert_eq!(node.name, "x");
933        // Should have traced into the scalar subquery
934        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
935        assert!(
936            all_names.len() >= 2,
937            "Expected tracing into scalar subquery, got: {:?}",
938            all_names
939        );
940    }
941
942    #[test]
943    fn test_lineage_multiple_columns() {
944        let expr = parse("SELECT a, b FROM t");
945
946        let node_a = lineage("a", &expr, None, false).unwrap();
947        let node_b = lineage("b", &expr, None, false).unwrap();
948
949        assert_eq!(node_a.name, "a");
950        assert_eq!(node_b.name, "b");
951
952        // Each should trace independently
953        let names_a = node_a.downstream_names();
954        let names_b = node_b.downstream_names();
955        assert!(names_a.iter().any(|n| n == "t.a"));
956        assert!(names_b.iter().any(|n| n == "t.b"));
957    }
958
959    #[test]
960    fn test_get_source_tables() {
961        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
962        let node = lineage("a", &expr, None, false).unwrap();
963
964        let tables = get_source_tables(&node);
965        assert!(
966            tables.contains("t"),
967            "Expected source table 't', got: {:?}",
968            tables
969        );
970    }
971
972    #[test]
973    fn test_lineage_column_not_found() {
974        let expr = parse("SELECT a FROM t");
975        let result = lineage("nonexistent", &expr, None, false);
976        assert!(result.is_err());
977    }
978
979    #[test]
980    fn test_lineage_nested_cte() {
981        let expr = parse(
982            "WITH cte1 AS (SELECT a FROM t), \
983             cte2 AS (SELECT a FROM cte1) \
984             SELECT a FROM cte2",
985        );
986        let node = lineage("a", &expr, None, false).unwrap();
987
988        // Should trace through cte2 → cte1 → t
989        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
990        assert!(
991            all_names.len() >= 3,
992            "Expected to trace through nested CTEs, got: {:?}",
993            all_names
994        );
995    }
996
997    #[test]
998    fn test_trim_selects_true() {
999        let expr = parse("SELECT a, b, c FROM t");
1000        let node = lineage("a", &expr, None, true).unwrap();
1001
1002        // The source should be trimmed to only include 'a'
1003        if let Expression::Select(select) = &node.source {
1004            assert_eq!(
1005                select.expressions.len(),
1006                1,
1007                "Trimmed source should have 1 expression, got {}",
1008                select.expressions.len()
1009            );
1010        } else {
1011            panic!("Expected Select source");
1012        }
1013    }
1014
1015    #[test]
1016    fn test_trim_selects_false() {
1017        let expr = parse("SELECT a, b, c FROM t");
1018        let node = lineage("a", &expr, None, false).unwrap();
1019
1020        // The source should keep all columns
1021        if let Expression::Select(select) = &node.source {
1022            assert_eq!(
1023                select.expressions.len(),
1024                3,
1025                "Untrimmed source should have 3 expressions"
1026            );
1027        } else {
1028            panic!("Expected Select source");
1029        }
1030    }
1031
1032    #[test]
1033    fn test_lineage_expression_in_select() {
1034        let expr = parse("SELECT a + b AS c FROM t");
1035        let node = lineage("c", &expr, None, false).unwrap();
1036
1037        // Should trace to both a and b from t
1038        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
1039        assert!(
1040            all_names.len() >= 3,
1041            "Expected to trace a + b to both columns, got: {:?}",
1042            all_names
1043        );
1044    }
1045
1046    #[test]
1047    fn test_set_operation_by_index() {
1048        let expr = parse("SELECT a FROM t1 UNION SELECT b FROM t2");
1049
1050        // Trace column "a" which is at index 0
1051        let node = lineage("a", &expr, None, false).unwrap();
1052
1053        // UNION branches should be traced by index
1054        assert_eq!(node.downstream.len(), 2);
1055    }
1056}