1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::scope::{build_scope, Scope};
11use crate::traversal::ExpressionWalk;
12use crate::Result;
13use std::collections::HashSet;
14
/// A node in a column-lineage tree.
///
/// Each node describes one column (or `*`, or one output column of a set
/// operation) and links, through `downstream`, to the nodes it was derived from.
#[derive(Debug, Clone)]
pub struct LineageNode {
    /// Display name of the node: the requested column name, a positional
    /// column's alias/name, or `table.column` for leaf references.
    pub name: String,
    /// The expression producing this column (a select-list item, a set
    /// operation, or a synthesized column reference).
    pub expression: Expression,
    /// The query or table the expression is evaluated against.
    pub source: Expression,
    /// Nodes this one was derived from.
    pub downstream: Vec<LineageNode>,
    /// Alias of the CTE / derived-table source that led to this node; empty
    /// when not applicable.
    pub source_name: String,
    /// Name of the node that referenced this node's source; empty when not
    /// applicable.
    pub reference_node_name: String,
}
31
32impl LineageNode {
33 pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
35 Self {
36 name: name.into(),
37 expression,
38 source,
39 downstream: Vec::new(),
40 source_name: String::new(),
41 reference_node_name: String::new(),
42 }
43 }
44
45 pub fn walk(&self) -> LineageWalker<'_> {
47 LineageWalker {
48 stack: vec![self],
49 }
50 }
51
52 pub fn downstream_names(&self) -> Vec<String> {
54 self.downstream.iter().map(|n| n.name.clone()).collect()
55 }
56}
57
/// Depth-first, pre-order iterator over a [`LineageNode`] tree, created by
/// [`LineageNode::walk`].
pub struct LineageWalker<'a> {
    /// Nodes still to be yielded; the next one is at the end of the vector.
    stack: Vec<&'a LineageNode>,
}
62
63impl<'a> Iterator for LineageWalker<'a> {
64 type Item = &'a LineageNode;
65
66 fn next(&mut self) -> Option<Self::Item> {
67 if let Some(node) = self.stack.pop() {
68 for child in node.downstream.iter().rev() {
70 self.stack.push(child);
71 }
72 Some(node)
73 } else {
74 None
75 }
76 }
77}
78
/// How a column is identified when searching a SELECT list.
enum ColumnRef<'a> {
    /// By output name: matched against an alias, a column name, an
    /// identifier, or `*`.
    Name(&'a str),
    /// By zero-based position in the SELECT list (used when tracing
    /// set-operation branches and scalar subqueries).
    Index(usize),
}
88
89pub fn lineage(
115 column: &str,
116 sql: &Expression,
117 dialect: Option<DialectType>,
118 trim_selects: bool,
119) -> Result<LineageNode> {
120 let scope = build_scope(sql);
121 to_node(
122 ColumnRef::Name(column),
123 &scope,
124 dialect,
125 "",
126 "",
127 "",
128 trim_selects,
129 )
130}
131
132pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
134 let mut tables = HashSet::new();
135 collect_source_tables(node, &mut tables);
136 tables
137}
138
139pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
141 if let Expression::Table(table) = &node.source {
142 tables.insert(table.name.name.clone());
143 }
144 for child in &node.downstream {
145 collect_source_tables(child, tables);
146 }
147}
148
149fn to_node(
155 column: ColumnRef<'_>,
156 scope: &Scope,
157 dialect: Option<DialectType>,
158 scope_name: &str,
159 source_name: &str,
160 reference_node_name: &str,
161 trim_selects: bool,
162) -> Result<LineageNode> {
163 to_node_inner(column, scope, dialect, scope_name, source_name, reference_node_name, trim_selects, &[])
164}
165
166fn to_node_inner(
167 column: ColumnRef<'_>,
168 scope: &Scope,
169 dialect: Option<DialectType>,
170 scope_name: &str,
171 source_name: &str,
172 reference_node_name: &str,
173 trim_selects: bool,
174 ancestor_cte_scopes: &[Scope],
175) -> Result<LineageNode> {
176 let scope_expr = &scope.expression;
177
178 let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
180 for s in ancestor_cte_scopes {
181 all_cte_scopes.push(s);
182 }
183
184 let effective_expr = match scope_expr {
187 Expression::Cte(cte) => &cte.this,
188 other => other,
189 };
190
191 if matches!(
193 effective_expr,
194 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
195 ) {
196 if matches!(scope_expr, Expression::Cte(_)) {
198 let mut inner_scope = Scope::new(effective_expr.clone());
199 inner_scope.union_scopes = scope.union_scopes.clone();
200 inner_scope.sources = scope.sources.clone();
201 inner_scope.cte_sources = scope.cte_sources.clone();
202 inner_scope.cte_scopes = scope.cte_scopes.clone();
203 inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
204 inner_scope.subquery_scopes = scope.subquery_scopes.clone();
205 return handle_set_operation(
206 &column,
207 &inner_scope,
208 dialect,
209 scope_name,
210 source_name,
211 reference_node_name,
212 trim_selects,
213 ancestor_cte_scopes,
214 );
215 }
216 return handle_set_operation(
217 &column,
218 scope,
219 dialect,
220 scope_name,
221 source_name,
222 reference_node_name,
223 trim_selects,
224 ancestor_cte_scopes,
225 );
226 }
227
228 let select_expr = find_select_expr(effective_expr, &column)?;
230 let column_name = resolve_column_name(&column, &select_expr);
231
232 let node_source = if trim_selects {
234 trim_source(effective_expr, &select_expr)
235 } else {
236 effective_expr.clone()
237 };
238
239 let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
241 node.source_name = source_name.to_string();
242 node.reference_node_name = reference_node_name.to_string();
243
244 if matches!(&select_expr, Expression::Star(_)) {
246 for (name, source_info) in &scope.sources {
247 let child = LineageNode::new(
248 format!("{}.*", name),
249 Expression::Star(crate::expressions::Star {
250 table: None,
251 except: None,
252 replace: None,
253 rename: None,
254 trailing_comments: vec![],
255 }),
256 source_info.expression.clone(),
257 );
258 node.downstream.push(child);
259 }
260 return Ok(node);
261 }
262
263 let subqueries: Vec<&Expression> = select_expr.find_all(|e| {
265 matches!(e, Expression::Subquery(sq) if sq.alias.is_none())
266 });
267 for sq_expr in subqueries {
268 if let Expression::Subquery(sq) = sq_expr {
269 for sq_scope in &scope.subquery_scopes {
270 if sq_scope.expression == sq.this {
271 if let Ok(child) = to_node_inner(
272 ColumnRef::Index(0),
273 sq_scope,
274 dialect,
275 &column_name,
276 "",
277 "",
278 trim_selects,
279 ancestor_cte_scopes,
280 ) {
281 node.downstream.push(child);
282 }
283 break;
284 }
285 }
286 }
287 }
288
289 let col_refs = find_column_refs_in_expr(&select_expr);
291 for col_ref in col_refs {
292 let col_name = &col_ref.column;
293 if let Some(ref table_id) = col_ref.table {
294 let tbl = &table_id.name;
295 resolve_qualified_column(
296 &mut node,
297 scope,
298 dialect,
299 tbl,
300 col_name,
301 &column_name,
302 trim_selects,
303 &all_cte_scopes,
304 );
305 } else {
306 resolve_unqualified_column(
307 &mut node,
308 scope,
309 dialect,
310 col_name,
311 &column_name,
312 trim_selects,
313 &all_cte_scopes,
314 );
315 }
316 }
317
318 Ok(node)
319}
320
321fn handle_set_operation(
326 column: &ColumnRef<'_>,
327 scope: &Scope,
328 dialect: Option<DialectType>,
329 scope_name: &str,
330 source_name: &str,
331 reference_node_name: &str,
332 trim_selects: bool,
333 ancestor_cte_scopes: &[Scope],
334) -> Result<LineageNode> {
335 let scope_expr = &scope.expression;
336
337 let col_index = match column {
339 ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
340 ColumnRef::Index(i) => *i,
341 };
342
343 let col_name = match column {
344 ColumnRef::Name(name) => name.to_string(),
345 ColumnRef::Index(_) => format!("_{col_index}"),
346 };
347
348 let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
349 node.source_name = source_name.to_string();
350 node.reference_node_name = reference_node_name.to_string();
351
352 for branch_scope in &scope.union_scopes {
354 if let Ok(child) = to_node_inner(
355 ColumnRef::Index(col_index),
356 branch_scope,
357 dialect,
358 scope_name,
359 "",
360 "",
361 trim_selects,
362 ancestor_cte_scopes,
363 ) {
364 node.downstream.push(child);
365 }
366 }
367
368 Ok(node)
369}
370
371fn resolve_qualified_column(
376 node: &mut LineageNode,
377 scope: &Scope,
378 dialect: Option<DialectType>,
379 table: &str,
380 col_name: &str,
381 parent_name: &str,
382 trim_selects: bool,
383 all_cte_scopes: &[&Scope],
384) {
385 if scope.cte_sources.contains_key(table) {
387 if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
388 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
390 if let Ok(child) = to_node_inner(
391 ColumnRef::Name(col_name),
392 child_scope,
393 dialect,
394 parent_name,
395 table,
396 parent_name,
397 trim_selects,
398 &ancestors,
399 ) {
400 node.downstream.push(child);
401 return;
402 }
403 }
404 }
405
406 if let Some(source_info) = scope.sources.get(table) {
408 if source_info.is_scope {
409 if let Some(child_scope) = find_child_scope(scope, table) {
410 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
411 if let Ok(child) = to_node_inner(
412 ColumnRef::Name(col_name),
413 child_scope,
414 dialect,
415 parent_name,
416 table,
417 parent_name,
418 trim_selects,
419 &ancestors,
420 ) {
421 node.downstream.push(child);
422 return;
423 }
424 }
425 }
426 }
427
428 node.downstream.push(make_table_column_node(table, col_name));
430}
431
432fn resolve_unqualified_column(
433 node: &mut LineageNode,
434 scope: &Scope,
435 dialect: Option<DialectType>,
436 col_name: &str,
437 parent_name: &str,
438 trim_selects: bool,
439 all_cte_scopes: &[&Scope],
440) {
441 let from_source_names: Vec<&String> = scope
446 .sources
447 .iter()
448 .filter(|(_, info)| !matches!(info.expression, Expression::Cte(_)))
449 .map(|(name, _)| name)
450 .collect();
451
452 if from_source_names.len() == 1 {
453 let tbl = from_source_names[0];
454 resolve_qualified_column(node, scope, dialect, tbl, col_name, parent_name, trim_selects, all_cte_scopes);
455 return;
456 }
457
458 let child = LineageNode::new(
460 col_name.to_string(),
461 Expression::Column(crate::expressions::Column {
462 name: crate::expressions::Identifier::new(col_name.to_string()),
463 table: None,
464 join_mark: false,
465 trailing_comments: vec![],
466 }),
467 node.source.clone(),
468 );
469 node.downstream.push(child);
470}
471
472fn get_alias_or_name(expr: &Expression) -> Option<String> {
478 match expr {
479 Expression::Alias(alias) => Some(alias.alias.name.clone()),
480 Expression::Column(col) => Some(col.name.name.clone()),
481 Expression::Identifier(id) => Some(id.name.clone()),
482 Expression::Star(_) => Some("*".to_string()),
483 _ => None,
484 }
485}
486
487fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
489 match column {
490 ColumnRef::Name(n) => n.to_string(),
491 ColumnRef::Index(_) => {
492 get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string())
493 }
494 }
495}
496
497fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
499 if let Expression::Select(ref select) = scope_expr {
500 match column {
501 ColumnRef::Name(name) => {
502 for expr in &select.expressions {
503 if get_alias_or_name(expr).as_deref() == Some(name) {
504 return Ok(expr.clone());
505 }
506 }
507 Err(crate::error::Error::Parse(format!(
508 "Cannot find column '{}' in query",
509 name
510 )))
511 }
512 ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
513 crate::error::Error::Parse(format!("Column index {} out of range", idx))
514 }),
515 }
516 } else {
517 Err(crate::error::Error::Parse(
518 "Expected SELECT expression for column lookup".to_string(),
519 ))
520 }
521}
522
523fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
525 let mut expr = set_op_expr;
526 loop {
527 match expr {
528 Expression::Union(u) => expr = &u.left,
529 Expression::Intersect(i) => expr = &i.left,
530 Expression::Except(e) => expr = &e.left,
531 Expression::Select(select) => {
532 for (i, e) in select.expressions.iter().enumerate() {
533 if get_alias_or_name(e).as_deref() == Some(name) {
534 return Ok(i);
535 }
536 }
537 return Err(crate::error::Error::Parse(format!(
538 "Cannot find column '{}' in set operation",
539 name
540 )));
541 }
542 _ => {
543 return Err(crate::error::Error::Parse(
544 "Expected SELECT or set operation".to_string(),
545 ))
546 }
547 }
548 }
549}
550
551fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
553 if let Expression::Select(select) = select_expr {
554 let mut trimmed = select.as_ref().clone();
555 trimmed.expressions = vec![target_expr.clone()];
556 Expression::Select(Box::new(trimmed))
557 } else {
558 select_expr.clone()
559 }
560}
561
562fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
564 if scope.cte_sources.contains_key(source_name) {
566 for cte_scope in &scope.cte_scopes {
567 if let Expression::Cte(cte) = &cte_scope.expression {
568 if cte.alias.name == source_name {
569 return Some(cte_scope);
570 }
571 }
572 }
573 }
574
575 if let Some(source_info) = scope.sources.get(source_name) {
577 if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
578 if let Expression::Subquery(sq) = &source_info.expression {
579 for dt_scope in &scope.derived_table_scopes {
580 if dt_scope.expression == sq.this {
581 return Some(dt_scope);
582 }
583 }
584 }
585 }
586 }
587
588 None
589}
590
591fn find_child_scope_in<'a>(
595 all_cte_scopes: &[&'a Scope],
596 scope: &'a Scope,
597 source_name: &str,
598) -> Option<&'a Scope> {
599 for cte_scope in &scope.cte_scopes {
601 if let Expression::Cte(cte) = &cte_scope.expression {
602 if cte.alias.name == source_name {
603 return Some(cte_scope);
604 }
605 }
606 }
607
608 for cte_scope in all_cte_scopes {
610 if let Expression::Cte(cte) = &cte_scope.expression {
611 if cte.alias.name == source_name {
612 return Some(cte_scope);
613 }
614 }
615 }
616
617 if let Some(source_info) = scope.sources.get(source_name) {
619 if source_info.is_scope {
620 if let Expression::Subquery(sq) = &source_info.expression {
621 for dt_scope in &scope.derived_table_scopes {
622 if dt_scope.expression == sq.this {
623 return Some(dt_scope);
624 }
625 }
626 }
627 }
628 }
629
630 None
631}
632
633fn make_table_column_node(table: &str, column: &str) -> LineageNode {
635 LineageNode::new(
636 format!("{}.{}", table, column),
637 Expression::Column(crate::expressions::Column {
638 name: crate::expressions::Identifier::new(column.to_string()),
639 table: Some(crate::expressions::Identifier::new(table.to_string())),
640 join_mark: false,
641 trailing_comments: vec![],
642 }),
643 Expression::Table(crate::expressions::TableRef::new(table)),
644 )
645}
646
/// A column reference found in an expression: an optional qualifier plus the
/// column name.
#[derive(Debug, Clone)]
struct SimpleColumnRef {
    /// Table (or alias) qualifier, when written as `qualifier.column`.
    table: Option<crate::expressions::Identifier>,
    /// The referenced column's name.
    column: String,
}
653
654fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
656 let mut refs = Vec::new();
657 collect_column_refs(expr, &mut refs);
658 refs
659}
660
661fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
662 match expr {
663 Expression::Column(col) => {
664 refs.push(SimpleColumnRef {
665 table: col.table.clone(),
666 column: col.name.name.clone(),
667 });
668 }
669 Expression::Alias(alias) => {
670 collect_column_refs(&alias.this, refs);
671 }
672 Expression::And(op)
673 | Expression::Or(op)
674 | Expression::Eq(op)
675 | Expression::Neq(op)
676 | Expression::Lt(op)
677 | Expression::Lte(op)
678 | Expression::Gt(op)
679 | Expression::Gte(op)
680 | Expression::Add(op)
681 | Expression::Sub(op)
682 | Expression::Mul(op)
683 | Expression::Div(op)
684 | Expression::Mod(op)
685 | Expression::BitwiseAnd(op)
686 | Expression::BitwiseOr(op)
687 | Expression::BitwiseXor(op)
688 | Expression::Concat(op) => {
689 collect_column_refs(&op.left, refs);
690 collect_column_refs(&op.right, refs);
691 }
692 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
693 collect_column_refs(&u.this, refs);
694 }
695 Expression::Function(func) => {
696 for arg in &func.args {
697 collect_column_refs(arg, refs);
698 }
699 }
700 Expression::AggregateFunction(func) => {
701 for arg in &func.args {
702 collect_column_refs(arg, refs);
703 }
704 }
705 Expression::WindowFunction(wf) => {
706 collect_column_refs(&wf.this, refs);
707 }
708 Expression::Case(case) => {
709 if let Some(operand) = &case.operand {
710 collect_column_refs(operand, refs);
711 }
712 for (cond, result) in &case.whens {
713 collect_column_refs(cond, refs);
714 collect_column_refs(result, refs);
715 }
716 if let Some(ref else_expr) = case.else_ {
717 collect_column_refs(else_expr, refs);
718 }
719 }
720 Expression::Cast(cast) => {
721 collect_column_refs(&cast.this, refs);
722 }
723 Expression::Paren(p) => {
724 collect_column_refs(&p.this, refs);
725 }
726 Expression::Coalesce(c) => {
727 for e in &c.expressions {
728 collect_column_refs(e, refs);
729 }
730 }
731 Expression::Subquery(_) | Expression::Exists(_) => {}
733 _ => {}
734 }
735}
736
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parse `sql` with the generic dialect and return the first statement.
    fn parse(sql: &str) -> Expression {
        let dialect = Dialect::get(DialectType::Generic);
        let ast = dialect.parse(sql).unwrap();
        ast.into_iter().next().unwrap()
    }

    #[test]
    fn test_simple_lineage() {
        let expr = parse("SELECT a FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        assert!(!node.downstream.is_empty(), "Should have downstream nodes");
        let names = node.downstream_names();
        assert!(
            names.iter().any(|n| n == "t.a"),
            "Expected t.a in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_walk() {
        let root = LineageNode {
            name: "col_a".to_string(),
            expression: Expression::Null(crate::expressions::Null),
            source: Expression::Null(crate::expressions::Null),
            downstream: vec![LineageNode::new(
                "t.a",
                Expression::Null(crate::expressions::Null),
                Expression::Null(crate::expressions::Null),
            )],
            source_name: String::new(),
            reference_node_name: String::new(),
        };

        let names: Vec<_> = root.walk().map(|n| n.name.clone()).collect();
        assert_eq!(names.len(), 2);
        assert_eq!(names[0], "col_a");
        assert_eq!(names[1], "t.a");
    }

    #[test]
    fn test_aliased_column() {
        let expr = parse("SELECT a + 1 AS b FROM t");
        let node = lineage("b", &expr, None, false).unwrap();

        assert_eq!(node.name, "b");
        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.iter().any(|n| n.contains("a")),
            "Expected to trace to column a, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_qualified_column() {
        let expr = parse("SELECT t.a FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        let names = node.downstream_names();
        assert!(
            names.iter().any(|n| n == "t.a"),
            "Expected t.a, got: {:?}",
            names
        );
    }

    #[test]
    fn test_unqualified_column() {
        let expr = parse("SELECT a FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        let names = node.downstream_names();
        assert!(
            names.iter().any(|n| n == "t.a"),
            "Expected t.a, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_join() {
        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        let node_a = lineage("a", &expr, None, false).unwrap();
        let names_a = node_a.downstream_names();
        assert!(
            names_a.iter().any(|n| n == "t.a"),
            "Expected t.a, got: {:?}",
            names_a
        );

        let node_b = lineage("b", &expr, None, false).unwrap();
        let names_b = node_b.downstream_names();
        assert!(
            names_b.iter().any(|n| n == "s.b"),
            "Expected s.b, got: {:?}",
            names_b
        );
    }

    #[test]
    fn test_lineage_derived_table() {
        let expr = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.iter().any(|n| n == "t.a"),
            "Expected to trace through derived table to t.a, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_cte() {
        let expr = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.iter().any(|n| n == "t.a"),
            "Expected to trace through CTE to t.a, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_union() {
        let expr = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        assert_eq!(
            node.downstream.len(),
            2,
            "Expected 2 branches for UNION, got {}",
            node.downstream.len()
        );

        // Each branch must reach its own source table.
        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.iter().any(|n| n == "t1.a"),
            "Expected t1.a in lineage, got: {:?}",
            all_names
        );
        assert!(
            all_names.iter().any(|n| n == "t2.a"),
            "Expected t2.a in lineage, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_cte_union() {
        let expr = parse(
            "WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte",
        );
        let node = lineage("a", &expr, None, false).unwrap();

        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.len() >= 3,
            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_star() {
        let expr = parse("SELECT * FROM t");
        let node = lineage("*", &expr, None, false).unwrap();

        assert_eq!(node.name, "*");
        assert!(
            !node.downstream.is_empty(),
            "Star should produce downstream nodes"
        );
        // The star expands to one `<source>.*` child per source.
        let names = node.downstream_names();
        assert!(
            names.iter().any(|n| n == "t.*"),
            "Expected t.* in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_subquery_in_select() {
        let expr = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
        let node = lineage("x", &expr, None, false).unwrap();

        assert_eq!(node.name, "x");
        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.len() >= 2,
            "Expected tracing into scalar subquery, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_multiple_columns() {
        let expr = parse("SELECT a, b FROM t");

        let node_a = lineage("a", &expr, None, false).unwrap();
        let node_b = lineage("b", &expr, None, false).unwrap();

        assert_eq!(node_a.name, "a");
        assert_eq!(node_b.name, "b");

        let names_a = node_a.downstream_names();
        let names_b = node_b.downstream_names();
        assert!(names_a.iter().any(|n| n == "t.a"));
        assert!(names_b.iter().any(|n| n == "t.b"));
    }

    #[test]
    fn test_get_source_tables() {
        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
        let node = lineage("a", &expr, None, false).unwrap();

        let tables = get_source_tables(&node);
        assert!(
            tables.contains("t"),
            "Expected source table 't', got: {:?}",
            tables
        );
    }

    #[test]
    fn test_lineage_column_not_found() {
        let expr = parse("SELECT a FROM t");
        let result = lineage("nonexistent", &expr, None, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_lineage_nested_cte() {
        let expr = parse(
            "WITH cte1 AS (SELECT a FROM t), \
             cte2 AS (SELECT a FROM cte1) \
             SELECT a FROM cte2",
        );
        let node = lineage("a", &expr, None, false).unwrap();

        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.len() >= 3,
            "Expected to trace through nested CTEs, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_trim_selects_true() {
        let expr = parse("SELECT a, b, c FROM t");
        let node = lineage("a", &expr, None, true).unwrap();

        if let Expression::Select(select) = &node.source {
            assert_eq!(
                select.expressions.len(),
                1,
                "Trimmed source should have 1 expression, got {}",
                select.expressions.len()
            );
        } else {
            panic!("Expected Select source");
        }
    }

    #[test]
    fn test_trim_selects_false() {
        let expr = parse("SELECT a, b, c FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        if let Expression::Select(select) = &node.source {
            assert_eq!(
                select.expressions.len(),
                3,
                "Untrimmed source should have 3 expressions"
            );
        } else {
            panic!("Expected Select source");
        }
    }

    #[test]
    fn test_lineage_expression_in_select() {
        let expr = parse("SELECT a + b AS c FROM t");
        let node = lineage("c", &expr, None, false).unwrap();

        let all_names: Vec<_> = node.walk().map(|n| n.name.clone()).collect();
        assert!(
            all_names.len() >= 3,
            "Expected to trace a + b to both columns, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_set_operation_by_index() {
        let expr = parse("SELECT a FROM t1 UNION SELECT b FROM t2");

        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.downstream.len(), 2);
        // Branches are matched by position, so the second branch contributes
        // its own column `b` even though `a` was requested.
        let names = node.downstream_names();
        assert!(
            names.iter().any(|n| n == "a"),
            "Expected branch column a, got: {:?}",
            names
        );
        assert!(
            names.iter().any(|n| n == "b"),
            "Expected branch column b (matched by position), got: {:?}",
            names
        );
    }
}