1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::scope::{build_scope, Scope};
11use crate::traversal::ExpressionWalk;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct LineageNode {
19 pub name: String,
21 pub expression: Expression,
23 pub source: Expression,
25 pub downstream: Vec<LineageNode>,
27 pub source_name: String,
29 pub reference_node_name: String,
31}
32
33impl LineageNode {
34 pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
36 Self {
37 name: name.into(),
38 expression,
39 source,
40 downstream: Vec::new(),
41 source_name: String::new(),
42 reference_node_name: String::new(),
43 }
44 }
45
46 pub fn walk(&self) -> LineageWalker<'_> {
48 LineageWalker { stack: vec![self] }
49 }
50
51 pub fn downstream_names(&self) -> Vec<String> {
53 self.downstream.iter().map(|n| n.name.clone()).collect()
54 }
55}
56
57pub struct LineageWalker<'a> {
59 stack: Vec<&'a LineageNode>,
60}
61
62impl<'a> Iterator for LineageWalker<'a> {
63 type Item = &'a LineageNode;
64
65 fn next(&mut self) -> Option<Self::Item> {
66 if let Some(node) = self.stack.pop() {
67 for child in node.downstream.iter().rev() {
69 self.stack.push(child);
70 }
71 Some(node)
72 } else {
73 None
74 }
75 }
76}
77
/// How a select-list column is addressed while tracing lineage.
///
/// The type only holds a borrowed string or an index, so it is `Copy` and can
/// be passed around freely; `Debug` makes it usable in diagnostics.
#[derive(Debug, Clone, Copy)]
enum ColumnRef<'a> {
    /// Look the column up by its alias or name.
    Name(&'a str),
    /// Look the column up by its zero-based position in the select list.
    Index(usize),
}
87
88pub fn lineage(
114 column: &str,
115 sql: &Expression,
116 dialect: Option<DialectType>,
117 trim_selects: bool,
118) -> Result<LineageNode> {
119 let scope = build_scope(sql);
120 to_node(
121 ColumnRef::Name(column),
122 &scope,
123 dialect,
124 "",
125 "",
126 "",
127 trim_selects,
128 )
129}
130
131pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
133 let mut tables = HashSet::new();
134 collect_source_tables(node, &mut tables);
135 tables
136}
137
138pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
140 if let Expression::Table(table) = &node.source {
141 tables.insert(table.name.name.clone());
142 }
143 for child in &node.downstream {
144 collect_source_tables(child, tables);
145 }
146}
147
148fn to_node(
154 column: ColumnRef<'_>,
155 scope: &Scope,
156 dialect: Option<DialectType>,
157 scope_name: &str,
158 source_name: &str,
159 reference_node_name: &str,
160 trim_selects: bool,
161) -> Result<LineageNode> {
162 to_node_inner(
163 column,
164 scope,
165 dialect,
166 scope_name,
167 source_name,
168 reference_node_name,
169 trim_selects,
170 &[],
171 )
172}
173
174fn to_node_inner(
175 column: ColumnRef<'_>,
176 scope: &Scope,
177 dialect: Option<DialectType>,
178 scope_name: &str,
179 source_name: &str,
180 reference_node_name: &str,
181 trim_selects: bool,
182 ancestor_cte_scopes: &[Scope],
183) -> Result<LineageNode> {
184 let scope_expr = &scope.expression;
185
186 let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
188 for s in ancestor_cte_scopes {
189 all_cte_scopes.push(s);
190 }
191
192 let effective_expr = match scope_expr {
195 Expression::Cte(cte) => &cte.this,
196 other => other,
197 };
198
199 if matches!(
201 effective_expr,
202 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
203 ) {
204 if matches!(scope_expr, Expression::Cte(_)) {
206 let mut inner_scope = Scope::new(effective_expr.clone());
207 inner_scope.union_scopes = scope.union_scopes.clone();
208 inner_scope.sources = scope.sources.clone();
209 inner_scope.cte_sources = scope.cte_sources.clone();
210 inner_scope.cte_scopes = scope.cte_scopes.clone();
211 inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
212 inner_scope.subquery_scopes = scope.subquery_scopes.clone();
213 return handle_set_operation(
214 &column,
215 &inner_scope,
216 dialect,
217 scope_name,
218 source_name,
219 reference_node_name,
220 trim_selects,
221 ancestor_cte_scopes,
222 );
223 }
224 return handle_set_operation(
225 &column,
226 scope,
227 dialect,
228 scope_name,
229 source_name,
230 reference_node_name,
231 trim_selects,
232 ancestor_cte_scopes,
233 );
234 }
235
236 let select_expr = find_select_expr(effective_expr, &column)?;
238 let column_name = resolve_column_name(&column, &select_expr);
239
240 let node_source = if trim_selects {
242 trim_source(effective_expr, &select_expr)
243 } else {
244 effective_expr.clone()
245 };
246
247 let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
249 node.source_name = source_name.to_string();
250 node.reference_node_name = reference_node_name.to_string();
251
252 if matches!(&select_expr, Expression::Star(_)) {
254 for (name, source_info) in &scope.sources {
255 let child = LineageNode::new(
256 format!("{}.*", name),
257 Expression::Star(crate::expressions::Star {
258 table: None,
259 except: None,
260 replace: None,
261 rename: None,
262 trailing_comments: vec![],
263 }),
264 source_info.expression.clone(),
265 );
266 node.downstream.push(child);
267 }
268 return Ok(node);
269 }
270
271 let subqueries: Vec<&Expression> =
273 select_expr.find_all(|e| matches!(e, Expression::Subquery(sq) if sq.alias.is_none()));
274 for sq_expr in subqueries {
275 if let Expression::Subquery(sq) = sq_expr {
276 for sq_scope in &scope.subquery_scopes {
277 if sq_scope.expression == sq.this {
278 if let Ok(child) = to_node_inner(
279 ColumnRef::Index(0),
280 sq_scope,
281 dialect,
282 &column_name,
283 "",
284 "",
285 trim_selects,
286 ancestor_cte_scopes,
287 ) {
288 node.downstream.push(child);
289 }
290 break;
291 }
292 }
293 }
294 }
295
296 let col_refs = find_column_refs_in_expr(&select_expr);
298 for col_ref in col_refs {
299 let col_name = &col_ref.column;
300 if let Some(ref table_id) = col_ref.table {
301 let tbl = &table_id.name;
302 resolve_qualified_column(
303 &mut node,
304 scope,
305 dialect,
306 tbl,
307 col_name,
308 &column_name,
309 trim_selects,
310 &all_cte_scopes,
311 );
312 } else {
313 resolve_unqualified_column(
314 &mut node,
315 scope,
316 dialect,
317 col_name,
318 &column_name,
319 trim_selects,
320 &all_cte_scopes,
321 );
322 }
323 }
324
325 Ok(node)
326}
327
328fn handle_set_operation(
333 column: &ColumnRef<'_>,
334 scope: &Scope,
335 dialect: Option<DialectType>,
336 scope_name: &str,
337 source_name: &str,
338 reference_node_name: &str,
339 trim_selects: bool,
340 ancestor_cte_scopes: &[Scope],
341) -> Result<LineageNode> {
342 let scope_expr = &scope.expression;
343
344 let col_index = match column {
346 ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
347 ColumnRef::Index(i) => *i,
348 };
349
350 let col_name = match column {
351 ColumnRef::Name(name) => name.to_string(),
352 ColumnRef::Index(_) => format!("_{col_index}"),
353 };
354
355 let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
356 node.source_name = source_name.to_string();
357 node.reference_node_name = reference_node_name.to_string();
358
359 for branch_scope in &scope.union_scopes {
361 if let Ok(child) = to_node_inner(
362 ColumnRef::Index(col_index),
363 branch_scope,
364 dialect,
365 scope_name,
366 "",
367 "",
368 trim_selects,
369 ancestor_cte_scopes,
370 ) {
371 node.downstream.push(child);
372 }
373 }
374
375 Ok(node)
376}
377
378fn resolve_qualified_column(
383 node: &mut LineageNode,
384 scope: &Scope,
385 dialect: Option<DialectType>,
386 table: &str,
387 col_name: &str,
388 parent_name: &str,
389 trim_selects: bool,
390 all_cte_scopes: &[&Scope],
391) {
392 if scope.cte_sources.contains_key(table) {
394 if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
395 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
397 if let Ok(child) = to_node_inner(
398 ColumnRef::Name(col_name),
399 child_scope,
400 dialect,
401 parent_name,
402 table,
403 parent_name,
404 trim_selects,
405 &ancestors,
406 ) {
407 node.downstream.push(child);
408 return;
409 }
410 }
411 }
412
413 if let Some(source_info) = scope.sources.get(table) {
415 if source_info.is_scope {
416 if let Some(child_scope) = find_child_scope(scope, table) {
417 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
418 if let Ok(child) = to_node_inner(
419 ColumnRef::Name(col_name),
420 child_scope,
421 dialect,
422 parent_name,
423 table,
424 parent_name,
425 trim_selects,
426 &ancestors,
427 ) {
428 node.downstream.push(child);
429 return;
430 }
431 }
432 }
433 }
434
435 node.downstream
437 .push(make_table_column_node(table, col_name));
438}
439
440fn resolve_unqualified_column(
441 node: &mut LineageNode,
442 scope: &Scope,
443 dialect: Option<DialectType>,
444 col_name: &str,
445 parent_name: &str,
446 trim_selects: bool,
447 all_cte_scopes: &[&Scope],
448) {
449 let from_source_names: Vec<&String> = scope
454 .sources
455 .iter()
456 .filter(|(_, info)| !matches!(info.expression, Expression::Cte(_)))
457 .map(|(name, _)| name)
458 .collect();
459
460 if from_source_names.len() == 1 {
461 let tbl = from_source_names[0];
462 resolve_qualified_column(
463 node,
464 scope,
465 dialect,
466 tbl,
467 col_name,
468 parent_name,
469 trim_selects,
470 all_cte_scopes,
471 );
472 return;
473 }
474
475 let child = LineageNode::new(
477 col_name.to_string(),
478 Expression::Column(crate::expressions::Column {
479 name: crate::expressions::Identifier::new(col_name.to_string()),
480 table: None,
481 join_mark: false,
482 trailing_comments: vec![],
483 }),
484 node.source.clone(),
485 );
486 node.downstream.push(child);
487}
488
489fn get_alias_or_name(expr: &Expression) -> Option<String> {
495 match expr {
496 Expression::Alias(alias) => Some(alias.alias.name.clone()),
497 Expression::Column(col) => Some(col.name.name.clone()),
498 Expression::Identifier(id) => Some(id.name.clone()),
499 Expression::Star(_) => Some("*".to_string()),
500 _ => None,
501 }
502}
503
504fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
506 match column {
507 ColumnRef::Name(n) => n.to_string(),
508 ColumnRef::Index(_) => get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string()),
509 }
510}
511
512fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
514 if let Expression::Select(ref select) = scope_expr {
515 match column {
516 ColumnRef::Name(name) => {
517 for expr in &select.expressions {
518 if get_alias_or_name(expr).as_deref() == Some(name) {
519 return Ok(expr.clone());
520 }
521 }
522 Err(crate::error::Error::Parse(format!(
523 "Cannot find column '{}' in query",
524 name
525 )))
526 }
527 ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
528 crate::error::Error::Parse(format!("Column index {} out of range", idx))
529 }),
530 }
531 } else {
532 Err(crate::error::Error::Parse(
533 "Expected SELECT expression for column lookup".to_string(),
534 ))
535 }
536}
537
538fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
540 let mut expr = set_op_expr;
541 loop {
542 match expr {
543 Expression::Union(u) => expr = &u.left,
544 Expression::Intersect(i) => expr = &i.left,
545 Expression::Except(e) => expr = &e.left,
546 Expression::Select(select) => {
547 for (i, e) in select.expressions.iter().enumerate() {
548 if get_alias_or_name(e).as_deref() == Some(name) {
549 return Ok(i);
550 }
551 }
552 return Err(crate::error::Error::Parse(format!(
553 "Cannot find column '{}' in set operation",
554 name
555 )));
556 }
557 _ => {
558 return Err(crate::error::Error::Parse(
559 "Expected SELECT or set operation".to_string(),
560 ))
561 }
562 }
563 }
564}
565
566fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
568 if let Expression::Select(select) = select_expr {
569 let mut trimmed = select.as_ref().clone();
570 trimmed.expressions = vec![target_expr.clone()];
571 Expression::Select(Box::new(trimmed))
572 } else {
573 select_expr.clone()
574 }
575}
576
577fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
579 if scope.cte_sources.contains_key(source_name) {
581 for cte_scope in &scope.cte_scopes {
582 if let Expression::Cte(cte) = &cte_scope.expression {
583 if cte.alias.name == source_name {
584 return Some(cte_scope);
585 }
586 }
587 }
588 }
589
590 if let Some(source_info) = scope.sources.get(source_name) {
592 if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
593 if let Expression::Subquery(sq) = &source_info.expression {
594 for dt_scope in &scope.derived_table_scopes {
595 if dt_scope.expression == sq.this {
596 return Some(dt_scope);
597 }
598 }
599 }
600 }
601 }
602
603 None
604}
605
606fn find_child_scope_in<'a>(
610 all_cte_scopes: &[&'a Scope],
611 scope: &'a Scope,
612 source_name: &str,
613) -> Option<&'a Scope> {
614 for cte_scope in &scope.cte_scopes {
616 if let Expression::Cte(cte) = &cte_scope.expression {
617 if cte.alias.name == source_name {
618 return Some(cte_scope);
619 }
620 }
621 }
622
623 for cte_scope in all_cte_scopes {
625 if let Expression::Cte(cte) = &cte_scope.expression {
626 if cte.alias.name == source_name {
627 return Some(cte_scope);
628 }
629 }
630 }
631
632 if let Some(source_info) = scope.sources.get(source_name) {
634 if source_info.is_scope {
635 if let Expression::Subquery(sq) = &source_info.expression {
636 for dt_scope in &scope.derived_table_scopes {
637 if dt_scope.expression == sq.this {
638 return Some(dt_scope);
639 }
640 }
641 }
642 }
643 }
644
645 None
646}
647
648fn make_table_column_node(table: &str, column: &str) -> LineageNode {
650 LineageNode::new(
651 format!("{}.{}", table, column),
652 Expression::Column(crate::expressions::Column {
653 name: crate::expressions::Identifier::new(column.to_string()),
654 table: Some(crate::expressions::Identifier::new(table.to_string())),
655 join_mark: false,
656 trailing_comments: vec![],
657 }),
658 Expression::Table(crate::expressions::TableRef::new(table)),
659 )
660}
661
662#[derive(Debug, Clone)]
664struct SimpleColumnRef {
665 table: Option<crate::expressions::Identifier>,
666 column: String,
667}
668
669fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
671 let mut refs = Vec::new();
672 collect_column_refs(expr, &mut refs);
673 refs
674}
675
676fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
677 match expr {
678 Expression::Column(col) => {
679 refs.push(SimpleColumnRef {
680 table: col.table.clone(),
681 column: col.name.name.clone(),
682 });
683 }
684 Expression::Alias(alias) => {
685 collect_column_refs(&alias.this, refs);
686 }
687 Expression::And(op)
688 | Expression::Or(op)
689 | Expression::Eq(op)
690 | Expression::Neq(op)
691 | Expression::Lt(op)
692 | Expression::Lte(op)
693 | Expression::Gt(op)
694 | Expression::Gte(op)
695 | Expression::Add(op)
696 | Expression::Sub(op)
697 | Expression::Mul(op)
698 | Expression::Div(op)
699 | Expression::Mod(op)
700 | Expression::BitwiseAnd(op)
701 | Expression::BitwiseOr(op)
702 | Expression::BitwiseXor(op)
703 | Expression::Concat(op) => {
704 collect_column_refs(&op.left, refs);
705 collect_column_refs(&op.right, refs);
706 }
707 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
708 collect_column_refs(&u.this, refs);
709 }
710 Expression::Function(func) => {
711 for arg in &func.args {
712 collect_column_refs(arg, refs);
713 }
714 }
715 Expression::AggregateFunction(func) => {
716 for arg in &func.args {
717 collect_column_refs(arg, refs);
718 }
719 }
720 Expression::WindowFunction(wf) => {
721 collect_column_refs(&wf.this, refs);
722 }
723 Expression::Case(case) => {
724 if let Some(operand) = &case.operand {
725 collect_column_refs(operand, refs);
726 }
727 for (cond, result) in &case.whens {
728 collect_column_refs(cond, refs);
729 collect_column_refs(result, refs);
730 }
731 if let Some(ref else_expr) = case.else_ {
732 collect_column_refs(else_expr, refs);
733 }
734 }
735 Expression::Cast(cast) => {
736 collect_column_refs(&cast.this, refs);
737 }
738 Expression::Paren(p) => {
739 collect_column_refs(&p.this, refs);
740 }
741 Expression::Coalesce(c) => {
742 for e in &c.expressions {
743 collect_column_refs(e, refs);
744 }
745 }
746 Expression::Subquery(_) | Expression::Exists(_) => {}
748 _ => {}
749 }
750}
751
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse(sql: &str) -> Expression {
        let dialect = Dialect::get(DialectType::Generic);
        let mut statements = dialect.parse(sql).unwrap().into_iter();
        statements.next().unwrap()
    }

    /// Traces `column` through `sql` without trimming and unwraps the result.
    fn trace(column: &str, sql: &str) -> LineageNode {
        lineage(column, &parse(sql), None, false).unwrap()
    }

    /// Names of all nodes of the tree rooted at `node`, in walk order.
    fn walked_names(node: &LineageNode) -> Vec<String> {
        node.walk().map(|n| n.name.clone()).collect()
    }

    /// True when `names` contains exactly `wanted`.
    fn has_name(names: &[String], wanted: &str) -> bool {
        names.iter().any(|n| n == wanted)
    }

    #[test]
    fn test_simple_lineage() {
        let node = trace("a", "SELECT a FROM t");

        assert_eq!(node.name, "a");
        assert!(!node.downstream.is_empty(), "Should have downstream nodes");
        // The unqualified column must be attributed to the only table.
        let names = node.downstream_names();
        assert!(
            has_name(&names, "t.a"),
            "Expected t.a in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_walk() {
        let null = || Expression::Null(crate::expressions::Null);
        let mut root = LineageNode::new("col_a", null(), null());
        root.downstream.push(LineageNode::new("t.a", null(), null()));

        // Pre-order: the root first, then its child.
        let names = walked_names(&root);
        assert_eq!(names.len(), 2);
        assert_eq!(names[0], "col_a");
        assert_eq!(names[1], "t.a");
    }

    #[test]
    fn test_aliased_column() {
        let node = trace("b", "SELECT a + 1 AS b FROM t");

        assert_eq!(node.name, "b");
        // Some node in the tree has to mention column `a`.
        let all_names = walked_names(&node);
        assert!(
            all_names.iter().any(|n| n.contains("a")),
            "Expected to trace to column a, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_qualified_column() {
        let node = trace("a", "SELECT t.a FROM t");

        assert_eq!(node.name, "a");
        let names = node.downstream_names();
        assert!(has_name(&names, "t.a"), "Expected t.a, got: {:?}", names);
    }

    #[test]
    fn test_unqualified_column() {
        let node = trace("a", "SELECT a FROM t");

        // With a single source the column resolves to that table.
        let names = node.downstream_names();
        assert!(has_name(&names, "t.a"), "Expected t.a, got: {:?}", names);
    }

    #[test]
    fn test_lineage_join() {
        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        let names_a = lineage("a", &expr, None, false)
            .unwrap()
            .downstream_names();
        assert!(has_name(&names_a, "t.a"), "Expected t.a, got: {:?}", names_a);

        let names_b = lineage("b", &expr, None, false)
            .unwrap()
            .downstream_names();
        assert!(has_name(&names_b, "s.b"), "Expected s.b, got: {:?}", names_b);
    }

    #[test]
    fn test_lineage_derived_table() {
        let node = trace("a", "SELECT x.a FROM (SELECT a FROM t) AS x");

        assert_eq!(node.name, "a");
        // The trace has to pass through the derived table down to `t`.
        let all_names = walked_names(&node);
        assert!(
            has_name(&all_names, "t.a"),
            "Expected to trace through derived table to t.a, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_cte() {
        let node = trace("a", "WITH cte AS (SELECT a FROM t) SELECT a FROM cte");

        assert_eq!(node.name, "a");
        let all_names = walked_names(&node);
        assert!(
            has_name(&all_names, "t.a"),
            "Expected to trace through CTE to t.a, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_union() {
        let node = trace("a", "SELECT a FROM t1 UNION SELECT a FROM t2");

        assert_eq!(node.name, "a");
        // One downstream node per UNION branch.
        let branch_count = node.downstream.len();
        assert_eq!(
            branch_count, 2,
            "Expected 2 branches for UNION, got {}",
            branch_count
        );
    }

    #[test]
    fn test_lineage_cte_union() {
        let node = trace(
            "a",
            "WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte",
        );

        // Root, the CTE's set operation and at least one branch.
        let all_names = walked_names(&node);
        assert!(
            all_names.len() >= 3,
            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_star() {
        let node = trace("*", "SELECT * FROM t");

        assert_eq!(node.name, "*");
        // A star expands to one child per source.
        assert!(
            !node.downstream.is_empty(),
            "Star should produce downstream nodes"
        );
    }

    #[test]
    fn test_lineage_subquery_in_select() {
        let node = trace("x", "SELECT (SELECT MAX(b) FROM s) AS x FROM t");

        assert_eq!(node.name, "x");
        // The scalar subquery must contribute at least one more node.
        let all_names = walked_names(&node);
        assert!(
            all_names.len() >= 2,
            "Expected tracing into scalar subquery, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_multiple_columns() {
        let expr = parse("SELECT a, b FROM t");

        let node_a = lineage("a", &expr, None, false).unwrap();
        let node_b = lineage("b", &expr, None, false).unwrap();

        assert_eq!(node_a.name, "a");
        assert_eq!(node_b.name, "b");

        // Each column resolves independently to its own table column.
        assert!(has_name(&node_a.downstream_names(), "t.a"));
        assert!(has_name(&node_b.downstream_names(), "t.b"));
    }

    #[test]
    fn test_get_source_tables() {
        let node = trace("a", "SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        let tables = get_source_tables(&node);
        assert!(
            tables.contains("t"),
            "Expected source table 't', got: {:?}",
            tables
        );
    }

    #[test]
    fn test_lineage_column_not_found() {
        let expr = parse("SELECT a FROM t");
        assert!(lineage("nonexistent", &expr, None, false).is_err());
    }

    #[test]
    fn test_lineage_nested_cte() {
        let node = trace(
            "a",
            "WITH cte1 AS (SELECT a FROM t), \
             cte2 AS (SELECT a FROM cte1) \
             SELECT a FROM cte2",
        );

        // Root -> cte2 -> cte1 (-> t) gives at least three nodes.
        let all_names = walked_names(&node);
        assert!(
            all_names.len() >= 3,
            "Expected to trace through nested CTEs, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_trim_selects_true() {
        let expr = parse("SELECT a, b, c FROM t");
        let node = lineage("a", &expr, None, true).unwrap();

        // With trimming only the traced item stays in the source.
        match &node.source {
            Expression::Select(select) => assert_eq!(
                select.expressions.len(),
                1,
                "Trimmed source should have 1 expression, got {}",
                select.expressions.len()
            ),
            _ => panic!("Expected Select source"),
        }
    }

    #[test]
    fn test_trim_selects_false() {
        let expr = parse("SELECT a, b, c FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        // Without trimming the source keeps its full select list.
        match &node.source {
            Expression::Select(select) => assert_eq!(
                select.expressions.len(),
                3,
                "Untrimmed source should have 3 expressions"
            ),
            _ => panic!("Expected Select source"),
        }
    }

    #[test]
    fn test_lineage_expression_in_select() {
        let node = trace("c", "SELECT a + b AS c FROM t");

        // Root plus one node for each operand.
        let all_names = walked_names(&node);
        assert!(
            all_names.len() >= 3,
            "Expected to trace a + b to both columns, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_set_operation_by_index() {
        // The second branch names its column `b`; it is matched by position.
        let node = trace("a", "SELECT a FROM t1 UNION SELECT b FROM t2");

        assert_eq!(node.downstream.len(), 2);
    }
}