1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::scope::{build_scope, Scope};
11use crate::traversal::ExpressionWalk;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15
/// One node of a column-lineage tree.
///
/// The root describes the requested output column; `downstream` children describe the
/// columns, star expansions, subqueries or set-operation branches it was derived from.
/// Leaves for physical tables carry an `Expression::Table` in `source`
/// (see `get_source_tables`).
///
/// Field order matters: serde serializes struct fields in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineageNode {
    /// Display name, e.g. `a` for a requested column, `t.a` for a table column,
    /// `x.*` for a star expansion, or `_N` for a positional set-operation column.
    pub name: String,
    /// The expression this node stands for (a projection item, a set operation,
    /// a column reference, or a star).
    pub expression: Expression,
    /// The expression the node was read from: the (optionally trimmed) SELECT, the set
    /// operation, or the table/source expression for leaves.
    pub source: Expression,
    /// Nodes this one was derived from.
    pub downstream: Vec<LineageNode>,
    /// Name of the CTE / derived table through which this node was reached; empty otherwise.
    pub source_name: String,
    /// Name of the referencing parent node when reached through a CTE / derived table;
    /// empty otherwise.
    pub reference_node_name: String,
}
32
33impl LineageNode {
34 pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
36 Self {
37 name: name.into(),
38 expression,
39 source,
40 downstream: Vec::new(),
41 source_name: String::new(),
42 reference_node_name: String::new(),
43 }
44 }
45
46 pub fn walk(&self) -> LineageWalker<'_> {
48 LineageWalker {
49 stack: vec![self],
50 }
51 }
52
53 pub fn downstream_names(&self) -> Vec<String> {
55 self.downstream.iter().map(|n| n.name.clone()).collect()
56 }
57}
58
59pub struct LineageWalker<'a> {
61 stack: Vec<&'a LineageNode>,
62}
63
64impl<'a> Iterator for LineageWalker<'a> {
65 type Item = &'a LineageNode;
66
67 fn next(&mut self) -> Option<Self::Item> {
68 if let Some(node) = self.stack.pop() {
69 for child in node.downstream.iter().rev() {
71 self.stack.push(child);
72 }
73 Some(node)
74 } else {
75 None
76 }
77 }
78}
79
/// Identifies which projected column of a scope to trace.
enum ColumnRef<'a> {
    /// Match a projection item by its alias or column name (see `get_alias_or_name`).
    Name(&'a str),
    /// Select a projection item by its 0-based position; used for set-operation
    /// branches and for the single column of a scalar subquery.
    Index(usize),
}
89
90pub fn lineage(
116 column: &str,
117 sql: &Expression,
118 dialect: Option<DialectType>,
119 trim_selects: bool,
120) -> Result<LineageNode> {
121 let scope = build_scope(sql);
122 to_node(
123 ColumnRef::Name(column),
124 &scope,
125 dialect,
126 "",
127 "",
128 "",
129 trim_selects,
130 )
131}
132
133pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
135 let mut tables = HashSet::new();
136 collect_source_tables(node, &mut tables);
137 tables
138}
139
140pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
142 if let Expression::Table(table) = &node.source {
143 tables.insert(table.name.name.clone());
144 }
145 for child in &node.downstream {
146 collect_source_tables(child, tables);
147 }
148}
149
150fn to_node(
156 column: ColumnRef<'_>,
157 scope: &Scope,
158 dialect: Option<DialectType>,
159 scope_name: &str,
160 source_name: &str,
161 reference_node_name: &str,
162 trim_selects: bool,
163) -> Result<LineageNode> {
164 to_node_inner(column, scope, dialect, scope_name, source_name, reference_node_name, trim_selects, &[])
165}
166
167fn to_node_inner(
168 column: ColumnRef<'_>,
169 scope: &Scope,
170 dialect: Option<DialectType>,
171 scope_name: &str,
172 source_name: &str,
173 reference_node_name: &str,
174 trim_selects: bool,
175 ancestor_cte_scopes: &[Scope],
176) -> Result<LineageNode> {
177 let scope_expr = &scope.expression;
178
179 let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
181 for s in ancestor_cte_scopes {
182 all_cte_scopes.push(s);
183 }
184
185 let effective_expr = match scope_expr {
188 Expression::Cte(cte) => &cte.this,
189 other => other,
190 };
191
192 if matches!(
194 effective_expr,
195 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
196 ) {
197 if matches!(scope_expr, Expression::Cte(_)) {
199 let mut inner_scope = Scope::new(effective_expr.clone());
200 inner_scope.union_scopes = scope.union_scopes.clone();
201 inner_scope.sources = scope.sources.clone();
202 inner_scope.cte_sources = scope.cte_sources.clone();
203 inner_scope.cte_scopes = scope.cte_scopes.clone();
204 inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
205 inner_scope.subquery_scopes = scope.subquery_scopes.clone();
206 return handle_set_operation(
207 &column,
208 &inner_scope,
209 dialect,
210 scope_name,
211 source_name,
212 reference_node_name,
213 trim_selects,
214 ancestor_cte_scopes,
215 );
216 }
217 return handle_set_operation(
218 &column,
219 scope,
220 dialect,
221 scope_name,
222 source_name,
223 reference_node_name,
224 trim_selects,
225 ancestor_cte_scopes,
226 );
227 }
228
229 let select_expr = find_select_expr(effective_expr, &column)?;
231 let column_name = resolve_column_name(&column, &select_expr);
232
233 let node_source = if trim_selects {
235 trim_source(effective_expr, &select_expr)
236 } else {
237 effective_expr.clone()
238 };
239
240 let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
242 node.source_name = source_name.to_string();
243 node.reference_node_name = reference_node_name.to_string();
244
245 if matches!(&select_expr, Expression::Star(_)) {
247 for (name, source_info) in &scope.sources {
248 let child = LineageNode::new(
249 format!("{}.*", name),
250 Expression::Star(crate::expressions::Star {
251 table: None,
252 except: None,
253 replace: None,
254 rename: None,
255 trailing_comments: vec![],
256 }),
257 source_info.expression.clone(),
258 );
259 node.downstream.push(child);
260 }
261 return Ok(node);
262 }
263
264 let subqueries: Vec<&Expression> = select_expr.find_all(|e| {
266 matches!(e, Expression::Subquery(sq) if sq.alias.is_none())
267 });
268 for sq_expr in subqueries {
269 if let Expression::Subquery(sq) = sq_expr {
270 for sq_scope in &scope.subquery_scopes {
271 if sq_scope.expression == sq.this {
272 if let Ok(child) = to_node_inner(
273 ColumnRef::Index(0),
274 sq_scope,
275 dialect,
276 &column_name,
277 "",
278 "",
279 trim_selects,
280 ancestor_cte_scopes,
281 ) {
282 node.downstream.push(child);
283 }
284 break;
285 }
286 }
287 }
288 }
289
290 let col_refs = find_column_refs_in_expr(&select_expr);
292 for col_ref in col_refs {
293 let col_name = &col_ref.column;
294 if let Some(ref table_id) = col_ref.table {
295 let tbl = &table_id.name;
296 resolve_qualified_column(
297 &mut node,
298 scope,
299 dialect,
300 tbl,
301 col_name,
302 &column_name,
303 trim_selects,
304 &all_cte_scopes,
305 );
306 } else {
307 resolve_unqualified_column(
308 &mut node,
309 scope,
310 dialect,
311 col_name,
312 &column_name,
313 trim_selects,
314 &all_cte_scopes,
315 );
316 }
317 }
318
319 Ok(node)
320}
321
322fn handle_set_operation(
327 column: &ColumnRef<'_>,
328 scope: &Scope,
329 dialect: Option<DialectType>,
330 scope_name: &str,
331 source_name: &str,
332 reference_node_name: &str,
333 trim_selects: bool,
334 ancestor_cte_scopes: &[Scope],
335) -> Result<LineageNode> {
336 let scope_expr = &scope.expression;
337
338 let col_index = match column {
340 ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
341 ColumnRef::Index(i) => *i,
342 };
343
344 let col_name = match column {
345 ColumnRef::Name(name) => name.to_string(),
346 ColumnRef::Index(_) => format!("_{col_index}"),
347 };
348
349 let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
350 node.source_name = source_name.to_string();
351 node.reference_node_name = reference_node_name.to_string();
352
353 for branch_scope in &scope.union_scopes {
355 if let Ok(child) = to_node_inner(
356 ColumnRef::Index(col_index),
357 branch_scope,
358 dialect,
359 scope_name,
360 "",
361 "",
362 trim_selects,
363 ancestor_cte_scopes,
364 ) {
365 node.downstream.push(child);
366 }
367 }
368
369 Ok(node)
370}
371
372fn resolve_qualified_column(
377 node: &mut LineageNode,
378 scope: &Scope,
379 dialect: Option<DialectType>,
380 table: &str,
381 col_name: &str,
382 parent_name: &str,
383 trim_selects: bool,
384 all_cte_scopes: &[&Scope],
385) {
386 if scope.cte_sources.contains_key(table) {
388 if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
389 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
391 if let Ok(child) = to_node_inner(
392 ColumnRef::Name(col_name),
393 child_scope,
394 dialect,
395 parent_name,
396 table,
397 parent_name,
398 trim_selects,
399 &ancestors,
400 ) {
401 node.downstream.push(child);
402 return;
403 }
404 }
405 }
406
407 if let Some(source_info) = scope.sources.get(table) {
409 if source_info.is_scope {
410 if let Some(child_scope) = find_child_scope(scope, table) {
411 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
412 if let Ok(child) = to_node_inner(
413 ColumnRef::Name(col_name),
414 child_scope,
415 dialect,
416 parent_name,
417 table,
418 parent_name,
419 trim_selects,
420 &ancestors,
421 ) {
422 node.downstream.push(child);
423 return;
424 }
425 }
426 }
427 }
428
429 node.downstream.push(make_table_column_node(table, col_name));
431}
432
433fn resolve_unqualified_column(
434 node: &mut LineageNode,
435 scope: &Scope,
436 dialect: Option<DialectType>,
437 col_name: &str,
438 parent_name: &str,
439 trim_selects: bool,
440 all_cte_scopes: &[&Scope],
441) {
442 let from_source_names: Vec<&String> = scope
447 .sources
448 .iter()
449 .filter(|(_, info)| !matches!(info.expression, Expression::Cte(_)))
450 .map(|(name, _)| name)
451 .collect();
452
453 if from_source_names.len() == 1 {
454 let tbl = from_source_names[0];
455 resolve_qualified_column(node, scope, dialect, tbl, col_name, parent_name, trim_selects, all_cte_scopes);
456 return;
457 }
458
459 let child = LineageNode::new(
461 col_name.to_string(),
462 Expression::Column(crate::expressions::Column {
463 name: crate::expressions::Identifier::new(col_name.to_string()),
464 table: None,
465 join_mark: false,
466 trailing_comments: vec![],
467 }),
468 node.source.clone(),
469 );
470 node.downstream.push(child);
471}
472
473fn get_alias_or_name(expr: &Expression) -> Option<String> {
479 match expr {
480 Expression::Alias(alias) => Some(alias.alias.name.clone()),
481 Expression::Column(col) => Some(col.name.name.clone()),
482 Expression::Identifier(id) => Some(id.name.clone()),
483 Expression::Star(_) => Some("*".to_string()),
484 _ => None,
485 }
486}
487
488fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
490 match column {
491 ColumnRef::Name(n) => n.to_string(),
492 ColumnRef::Index(_) => {
493 get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string())
494 }
495 }
496}
497
498fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
500 if let Expression::Select(ref select) = scope_expr {
501 match column {
502 ColumnRef::Name(name) => {
503 for expr in &select.expressions {
504 if get_alias_or_name(expr).as_deref() == Some(name) {
505 return Ok(expr.clone());
506 }
507 }
508 Err(crate::error::Error::Parse(format!(
509 "Cannot find column '{}' in query",
510 name
511 )))
512 }
513 ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
514 crate::error::Error::Parse(format!("Column index {} out of range", idx))
515 }),
516 }
517 } else {
518 Err(crate::error::Error::Parse(
519 "Expected SELECT expression for column lookup".to_string(),
520 ))
521 }
522}
523
524fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
526 let mut expr = set_op_expr;
527 loop {
528 match expr {
529 Expression::Union(u) => expr = &u.left,
530 Expression::Intersect(i) => expr = &i.left,
531 Expression::Except(e) => expr = &e.left,
532 Expression::Select(select) => {
533 for (i, e) in select.expressions.iter().enumerate() {
534 if get_alias_or_name(e).as_deref() == Some(name) {
535 return Ok(i);
536 }
537 }
538 return Err(crate::error::Error::Parse(format!(
539 "Cannot find column '{}' in set operation",
540 name
541 )));
542 }
543 _ => {
544 return Err(crate::error::Error::Parse(
545 "Expected SELECT or set operation".to_string(),
546 ))
547 }
548 }
549 }
550}
551
552fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
554 if let Expression::Select(select) = select_expr {
555 let mut trimmed = select.as_ref().clone();
556 trimmed.expressions = vec![target_expr.clone()];
557 Expression::Select(Box::new(trimmed))
558 } else {
559 select_expr.clone()
560 }
561}
562
563fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
565 if scope.cte_sources.contains_key(source_name) {
567 for cte_scope in &scope.cte_scopes {
568 if let Expression::Cte(cte) = &cte_scope.expression {
569 if cte.alias.name == source_name {
570 return Some(cte_scope);
571 }
572 }
573 }
574 }
575
576 if let Some(source_info) = scope.sources.get(source_name) {
578 if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
579 if let Expression::Subquery(sq) = &source_info.expression {
580 for dt_scope in &scope.derived_table_scopes {
581 if dt_scope.expression == sq.this {
582 return Some(dt_scope);
583 }
584 }
585 }
586 }
587 }
588
589 None
590}
591
592fn find_child_scope_in<'a>(
596 all_cte_scopes: &[&'a Scope],
597 scope: &'a Scope,
598 source_name: &str,
599) -> Option<&'a Scope> {
600 for cte_scope in &scope.cte_scopes {
602 if let Expression::Cte(cte) = &cte_scope.expression {
603 if cte.alias.name == source_name {
604 return Some(cte_scope);
605 }
606 }
607 }
608
609 for cte_scope in all_cte_scopes {
611 if let Expression::Cte(cte) = &cte_scope.expression {
612 if cte.alias.name == source_name {
613 return Some(cte_scope);
614 }
615 }
616 }
617
618 if let Some(source_info) = scope.sources.get(source_name) {
620 if source_info.is_scope {
621 if let Expression::Subquery(sq) = &source_info.expression {
622 for dt_scope in &scope.derived_table_scopes {
623 if dt_scope.expression == sq.this {
624 return Some(dt_scope);
625 }
626 }
627 }
628 }
629 }
630
631 None
632}
633
634fn make_table_column_node(table: &str, column: &str) -> LineageNode {
636 LineageNode::new(
637 format!("{}.{}", table, column),
638 Expression::Column(crate::expressions::Column {
639 name: crate::expressions::Identifier::new(column.to_string()),
640 table: Some(crate::expressions::Identifier::new(table.to_string())),
641 join_mark: false,
642 trailing_comments: vec![],
643 }),
644 Expression::Table(crate::expressions::TableRef::new(table)),
645 )
646}
647
/// A column reference extracted from a projection item.
#[derive(Debug, Clone)]
struct SimpleColumnRef {
    /// Table qualifier as written (`t` in `t.a`), if any.
    table: Option<crate::expressions::Identifier>,
    /// Column name.
    column: String,
}
654
655fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
657 let mut refs = Vec::new();
658 collect_column_refs(expr, &mut refs);
659 refs
660}
661
/// Recursively appends every column reference found in `expr` to `refs`.
///
/// Only the expression kinds matched explicitly below are descended into; every other
/// variant falls into `_ => {}`, so columns nested under unlisted kinds are not reported.
/// NOTE(review): confirm whether other predicate kinds (e.g. IN / BETWEEN / LIKE, if
/// `Expression` has them) should be traversed too.
fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
    match expr {
        // Leaf: record the optional table qualifier and the column name.
        Expression::Column(col) => {
            refs.push(SimpleColumnRef {
                table: col.table.clone(),
                column: col.name.name.clone(),
            });
        }
        // `expr AS alias`: the alias is not a column; inspect the aliased expression.
        Expression::Alias(alias) => {
            collect_column_refs(&alias.this, refs);
        }
        // Binary operators: left operand first, then right.
        Expression::And(op)
        | Expression::Or(op)
        | Expression::Eq(op)
        | Expression::Neq(op)
        | Expression::Lt(op)
        | Expression::Lte(op)
        | Expression::Gt(op)
        | Expression::Gte(op)
        | Expression::Add(op)
        | Expression::Sub(op)
        | Expression::Mul(op)
        | Expression::Div(op)
        | Expression::Mod(op)
        | Expression::BitwiseAnd(op)
        | Expression::BitwiseOr(op)
        | Expression::BitwiseXor(op)
        | Expression::Concat(op) => {
            collect_column_refs(&op.left, refs);
            collect_column_refs(&op.right, refs);
        }
        // Unary operators.
        Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
            collect_column_refs(&u.this, refs);
        }
        Expression::Function(func) => {
            for arg in &func.args {
                collect_column_refs(arg, refs);
            }
        }
        Expression::AggregateFunction(func) => {
            for arg in &func.args {
                collect_column_refs(arg, refs);
            }
        }
        // NOTE(review): only the wrapped function is visited; if `WindowFunction` also
        // carries PARTITION BY / ORDER BY expressions, their columns are not collected.
        Expression::WindowFunction(wf) => {
            collect_column_refs(&wf.this, refs);
        }
        // CASE: optional operand, then each WHEN condition/result pair, then ELSE.
        Expression::Case(case) => {
            if let Some(operand) = &case.operand {
                collect_column_refs(operand, refs);
            }
            for (cond, result) in &case.whens {
                collect_column_refs(cond, refs);
                collect_column_refs(result, refs);
            }
            if let Some(ref else_expr) = case.else_ {
                collect_column_refs(else_expr, refs);
            }
        }
        Expression::Cast(cast) => {
            collect_column_refs(&cast.this, refs);
        }
        Expression::Paren(p) => {
            collect_column_refs(&p.this, refs);
        }
        Expression::Coalesce(c) => {
            for e in &c.expressions {
                collect_column_refs(e, refs);
            }
        }
        // Subqueries are deliberately not descended into: un-aliased scalar subqueries are
        // traced through their own scope by `to_node_inner`. Columns inside EXISTS are not
        // traced by this module.
        Expression::Subquery(_) | Expression::Exists(_) => {}
        _ => {}
    }
}
737
/// Unit tests for the lineage builder; each parses SQL with the generic dialect and
/// inspects the resulting lineage tree.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the generic dialect and returns the first statement.
    fn parse(sql: &str) -> Expression {
        let dialect = Dialect::get(DialectType::Generic);
        let ast = dialect.parse(sql).unwrap();
        ast.into_iter().next().unwrap()
    }

    // A bare column resolves to its single FROM table.
    #[test]
    fn test_simple_lineage() {
        let expr = parse("SELECT a FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        assert!(!node.downstream.is_empty(), "Should have downstream nodes");
        let names = node.downstream_names();
        assert!(
            names.iter().any(|n| n == "t.a"),
            "Expected t.a in downstream, got: {:?}",
            names
        );
    }

    // `walk` visits the parent before its children.
    #[test]
    fn test_lineage_walk() {
        let root = LineageNode {
            name: "col_a".to_string(),
            expression: Expression::Null(crate::expressions::Null),
            source: Expression::Null(crate::expressions::Null),
            downstream: vec![LineageNode::new(
                "t.a",
                Expression::Null(crate::expressions::Null),
                Expression::Null(crate::expressions::Null),
            )],
            source_name: String::new(),
            reference_node_name: String::new(),
        };

        let names: Vec<_> = root.walk().map(|n| n.name.clone()).collect();
        assert_eq!(names.len(), 2);
        assert_eq!(names[0], "col_a");
        assert_eq!(names[1], "t.a");
    }

    // An aliased expression is looked up by its alias and traced to its inputs.
    #[test]
    fn test_aliased_column() {
        let expr = parse("SELECT a + 1 AS b FROM t");
        let node = lineage("b", &expr, None, false).unwrap();

        assert_eq!(node.name, "b");
        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.iter().any(|n| n.contains("a")),
            "Expected to trace to column a, got: {:?}",
            all_names
        );
    }

    // A table-qualified column keeps its qualifier in the leaf name.
    #[test]
    fn test_qualified_column() {
        let expr = parse("SELECT t.a FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        let names = node.downstream_names();
        assert!(
            names.iter().any(|n| n == "t.a"),
            "Expected t.a, got: {:?}",
            names
        );
    }

    // An unqualified column with a single source is attributed to that source.
    #[test]
    fn test_unqualified_column() {
        let expr = parse("SELECT a FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        let names = node.downstream_names();
        assert!(
            names.iter().any(|n| n == "t.a"),
            "Expected t.a, got: {:?}",
            names
        );
    }

    // Each qualified column in a join resolves to its own table.
    #[test]
    fn test_lineage_join() {
        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        let node_a = lineage("a", &expr, None, false).unwrap();
        let names_a = node_a.downstream_names();
        assert!(
            names_a.iter().any(|n| n == "t.a"),
            "Expected t.a, got: {:?}",
            names_a
        );

        let node_b = lineage("b", &expr, None, false).unwrap();
        let names_b = node_b.downstream_names();
        assert!(
            names_b.iter().any(|n| n == "s.b"),
            "Expected s.b, got: {:?}",
            names_b
        );
    }

    // Columns are traced through a derived table down to the base table.
    #[test]
    fn test_lineage_derived_table() {
        let expr = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.iter().any(|n| n == "t.a"),
            "Expected to trace through derived table to t.a, got: {:?}",
            all_names
        );
    }

    // Columns are traced through a CTE down to the base table.
    #[test]
    fn test_lineage_cte() {
        let expr = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.iter().any(|n| n == "t.a"),
            "Expected to trace through CTE to t.a, got: {:?}",
            all_names
        );
    }

    // A UNION produces one downstream node per branch.
    #[test]
    fn test_lineage_union() {
        let expr = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        assert_eq!(
            node.downstream.len(),
            2,
            "Expected 2 branches for UNION, got {}",
            node.downstream.len()
        );
    }

    // A UNION wrapped in a CTE is still expanded into its branches.
    #[test]
    fn test_lineage_cte_union() {
        let expr = parse(
            "WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte",
        );
        let node = lineage("a", &expr, None, false).unwrap();

        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.len() >= 3,
            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
            all_names
        );
    }

    // `SELECT *` yields one child per source.
    #[test]
    fn test_lineage_star() {
        let expr = parse("SELECT * FROM t");
        let node = lineage("*", &expr, None, false).unwrap();

        assert_eq!(node.name, "*");
        assert!(
            !node.downstream.is_empty(),
            "Star should produce downstream nodes"
        );
    }

    // A scalar subquery in the projection is traced into.
    #[test]
    fn test_lineage_subquery_in_select() {
        let expr = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
        let node = lineage("x", &expr, None, false).unwrap();

        assert_eq!(node.name, "x");
        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.len() >= 2,
            "Expected tracing into scalar subquery, got: {:?}",
            all_names
        );
    }

    // Different output columns of the same query trace independently.
    #[test]
    fn test_lineage_multiple_columns() {
        let expr = parse("SELECT a, b FROM t");

        let node_a = lineage("a", &expr, None, false).unwrap();
        let node_b = lineage("b", &expr, None, false).unwrap();

        assert_eq!(node_a.name, "a");
        assert_eq!(node_b.name, "b");

        let names_a = node_a.downstream_names();
        let names_b = node_b.downstream_names();
        assert!(names_a.iter().any(|n| n == "t.a"));
        assert!(names_b.iter().any(|n| n == "t.b"));
    }

    // `get_source_tables` reports the physical table behind the traced column.
    #[test]
    fn test_get_source_tables() {
        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
        let node = lineage("a", &expr, None, false).unwrap();

        let tables = get_source_tables(&node);
        assert!(
            tables.contains("t"),
            "Expected source table 't', got: {:?}",
            tables
        );
    }

    // Requesting a column that is not projected is an error.
    #[test]
    fn test_lineage_column_not_found() {
        let expr = parse("SELECT a FROM t");
        let result = lineage("nonexistent", &expr, None, false);
        assert!(result.is_err());
    }

    // A CTE referencing another CTE is traced through both.
    #[test]
    fn test_lineage_nested_cte() {
        let expr = parse(
            "WITH cte1 AS (SELECT a FROM t), \
             cte2 AS (SELECT a FROM cte1) \
             SELECT a FROM cte2",
        );
        let node = lineage("a", &expr, None, false).unwrap();

        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.len() >= 3,
            "Expected to trace through nested CTEs, got: {:?}",
            all_names
        );
    }

    // With `trim_selects`, the source SELECT keeps only the traced projection.
    #[test]
    fn test_trim_selects_true() {
        let expr = parse("SELECT a, b, c FROM t");
        let node = lineage("a", &expr, None, true).unwrap();

        if let Expression::Select(select) = &node.source {
            assert_eq!(
                select.expressions.len(),
                1,
                "Trimmed source should have 1 expression, got {}",
                select.expressions.len()
            );
        } else {
            panic!("Expected Select source");
        }
    }

    // Without `trim_selects`, the source SELECT is kept whole.
    #[test]
    fn test_trim_selects_false() {
        let expr = parse("SELECT a, b, c FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        if let Expression::Select(select) = &node.source {
            assert_eq!(
                select.expressions.len(),
                3,
                "Untrimmed source should have 3 expressions"
            );
        } else {
            panic!("Expected Select source");
        }
    }

    // Every column used in a computed projection becomes a child.
    #[test]
    fn test_lineage_expression_in_select() {
        let expr = parse("SELECT a + b AS c FROM t");
        let node = lineage("c", &expr, None, false).unwrap();

        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.len() >= 3,
            "Expected to trace a + b to both columns, got: {:?}",
            all_names
        );
    }

    // Set-operation branches are matched by position, not by name (`b` in branch 2).
    #[test]
    fn test_set_operation_by_index() {
        let expr = parse("SELECT a FROM t1 UNION SELECT b FROM t2");

        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.downstream.len(), 2);
    }
}