1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::scope::{build_scope, Scope};
11use crate::traversal::ExpressionWalk;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct LineageNode {
19 pub name: String,
21 pub expression: Expression,
23 pub source: Expression,
25 pub downstream: Vec<LineageNode>,
27 pub source_name: String,
29 pub reference_node_name: String,
31}
32
33impl LineageNode {
34 pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
36 Self {
37 name: name.into(),
38 expression,
39 source,
40 downstream: Vec::new(),
41 source_name: String::new(),
42 reference_node_name: String::new(),
43 }
44 }
45
46 pub fn walk(&self) -> LineageWalker<'_> {
48 LineageWalker { stack: vec![self] }
49 }
50
51 pub fn downstream_names(&self) -> Vec<String> {
53 self.downstream.iter().map(|n| n.name.clone()).collect()
54 }
55}
56
57pub struct LineageWalker<'a> {
59 stack: Vec<&'a LineageNode>,
60}
61
62impl<'a> Iterator for LineageWalker<'a> {
63 type Item = &'a LineageNode;
64
65 fn next(&mut self) -> Option<Self::Item> {
66 if let Some(node) = self.stack.pop() {
67 for child in node.downstream.iter().rev() {
69 self.stack.push(child);
70 }
71 Some(node)
72 } else {
73 None
74 }
75 }
76}
77
/// How the target column of a lineage step is identified.
enum ColumnRef<'a> {
    /// By output name (alias or column name) in the SELECT list.
    Name(&'a str),
    /// By zero-based position in the SELECT list (used for set-operation
    /// branches and scalar subqueries).
    Index(usize),
}
87
88pub fn lineage(
114 column: &str,
115 sql: &Expression,
116 dialect: Option<DialectType>,
117 trim_selects: bool,
118) -> Result<LineageNode> {
119 let scope = build_scope(sql);
120 to_node(
121 ColumnRef::Name(column),
122 &scope,
123 dialect,
124 "",
125 "",
126 "",
127 trim_selects,
128 )
129}
130
131pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
133 let mut tables = HashSet::new();
134 collect_source_tables(node, &mut tables);
135 tables
136}
137
138pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
140 if let Expression::Table(table) = &node.source {
141 tables.insert(table.name.name.clone());
142 }
143 for child in &node.downstream {
144 collect_source_tables(child, tables);
145 }
146}
147
148fn to_node(
154 column: ColumnRef<'_>,
155 scope: &Scope,
156 dialect: Option<DialectType>,
157 scope_name: &str,
158 source_name: &str,
159 reference_node_name: &str,
160 trim_selects: bool,
161) -> Result<LineageNode> {
162 to_node_inner(
163 column,
164 scope,
165 dialect,
166 scope_name,
167 source_name,
168 reference_node_name,
169 trim_selects,
170 &[],
171 )
172}
173
174fn to_node_inner(
175 column: ColumnRef<'_>,
176 scope: &Scope,
177 dialect: Option<DialectType>,
178 scope_name: &str,
179 source_name: &str,
180 reference_node_name: &str,
181 trim_selects: bool,
182 ancestor_cte_scopes: &[Scope],
183) -> Result<LineageNode> {
184 let scope_expr = &scope.expression;
185
186 let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
188 for s in ancestor_cte_scopes {
189 all_cte_scopes.push(s);
190 }
191
192 let effective_expr = match scope_expr {
195 Expression::Cte(cte) => &cte.this,
196 other => other,
197 };
198
199 if matches!(
201 effective_expr,
202 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
203 ) {
204 if matches!(scope_expr, Expression::Cte(_)) {
206 let mut inner_scope = Scope::new(effective_expr.clone());
207 inner_scope.union_scopes = scope.union_scopes.clone();
208 inner_scope.sources = scope.sources.clone();
209 inner_scope.cte_sources = scope.cte_sources.clone();
210 inner_scope.cte_scopes = scope.cte_scopes.clone();
211 inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
212 inner_scope.subquery_scopes = scope.subquery_scopes.clone();
213 return handle_set_operation(
214 &column,
215 &inner_scope,
216 dialect,
217 scope_name,
218 source_name,
219 reference_node_name,
220 trim_selects,
221 ancestor_cte_scopes,
222 );
223 }
224 return handle_set_operation(
225 &column,
226 scope,
227 dialect,
228 scope_name,
229 source_name,
230 reference_node_name,
231 trim_selects,
232 ancestor_cte_scopes,
233 );
234 }
235
236 let select_expr = find_select_expr(effective_expr, &column)?;
238 let column_name = resolve_column_name(&column, &select_expr);
239
240 let node_source = if trim_selects {
242 trim_source(effective_expr, &select_expr)
243 } else {
244 effective_expr.clone()
245 };
246
247 let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
249 node.source_name = source_name.to_string();
250 node.reference_node_name = reference_node_name.to_string();
251
252 if matches!(&select_expr, Expression::Star(_)) {
254 for (name, source_info) in &scope.sources {
255 let child = LineageNode::new(
256 format!("{}.*", name),
257 Expression::Star(crate::expressions::Star {
258 table: None,
259 except: None,
260 replace: None,
261 rename: None,
262 trailing_comments: vec![],
263 }),
264 source_info.expression.clone(),
265 );
266 node.downstream.push(child);
267 }
268 return Ok(node);
269 }
270
271 let subqueries: Vec<&Expression> =
273 select_expr.find_all(|e| matches!(e, Expression::Subquery(sq) if sq.alias.is_none()));
274 for sq_expr in subqueries {
275 if let Expression::Subquery(sq) = sq_expr {
276 for sq_scope in &scope.subquery_scopes {
277 if sq_scope.expression == sq.this {
278 if let Ok(child) = to_node_inner(
279 ColumnRef::Index(0),
280 sq_scope,
281 dialect,
282 &column_name,
283 "",
284 "",
285 trim_selects,
286 ancestor_cte_scopes,
287 ) {
288 node.downstream.push(child);
289 }
290 break;
291 }
292 }
293 }
294 }
295
296 let col_refs = find_column_refs_in_expr(&select_expr);
298 for col_ref in col_refs {
299 let col_name = &col_ref.column;
300 if let Some(ref table_id) = col_ref.table {
301 let tbl = &table_id.name;
302 resolve_qualified_column(
303 &mut node,
304 scope,
305 dialect,
306 tbl,
307 col_name,
308 &column_name,
309 trim_selects,
310 &all_cte_scopes,
311 );
312 } else {
313 resolve_unqualified_column(
314 &mut node,
315 scope,
316 dialect,
317 col_name,
318 &column_name,
319 trim_selects,
320 &all_cte_scopes,
321 );
322 }
323 }
324
325 Ok(node)
326}
327
328fn handle_set_operation(
333 column: &ColumnRef<'_>,
334 scope: &Scope,
335 dialect: Option<DialectType>,
336 scope_name: &str,
337 source_name: &str,
338 reference_node_name: &str,
339 trim_selects: bool,
340 ancestor_cte_scopes: &[Scope],
341) -> Result<LineageNode> {
342 let scope_expr = &scope.expression;
343
344 let col_index = match column {
346 ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
347 ColumnRef::Index(i) => *i,
348 };
349
350 let col_name = match column {
351 ColumnRef::Name(name) => name.to_string(),
352 ColumnRef::Index(_) => format!("_{col_index}"),
353 };
354
355 let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
356 node.source_name = source_name.to_string();
357 node.reference_node_name = reference_node_name.to_string();
358
359 for branch_scope in &scope.union_scopes {
361 if let Ok(child) = to_node_inner(
362 ColumnRef::Index(col_index),
363 branch_scope,
364 dialect,
365 scope_name,
366 "",
367 "",
368 trim_selects,
369 ancestor_cte_scopes,
370 ) {
371 node.downstream.push(child);
372 }
373 }
374
375 Ok(node)
376}
377
378fn resolve_qualified_column(
383 node: &mut LineageNode,
384 scope: &Scope,
385 dialect: Option<DialectType>,
386 table: &str,
387 col_name: &str,
388 parent_name: &str,
389 trim_selects: bool,
390 all_cte_scopes: &[&Scope],
391) {
392 if scope.cte_sources.contains_key(table) {
394 if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
395 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
397 if let Ok(child) = to_node_inner(
398 ColumnRef::Name(col_name),
399 child_scope,
400 dialect,
401 parent_name,
402 table,
403 parent_name,
404 trim_selects,
405 &ancestors,
406 ) {
407 node.downstream.push(child);
408 return;
409 }
410 }
411 }
412
413 if let Some(source_info) = scope.sources.get(table) {
415 if source_info.is_scope {
416 if let Some(child_scope) = find_child_scope(scope, table) {
417 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
418 if let Ok(child) = to_node_inner(
419 ColumnRef::Name(col_name),
420 child_scope,
421 dialect,
422 parent_name,
423 table,
424 parent_name,
425 trim_selects,
426 &ancestors,
427 ) {
428 node.downstream.push(child);
429 return;
430 }
431 }
432 }
433 }
434
435 if let Some(source_info) = scope.sources.get(table) {
438 if !source_info.is_scope {
439 node.downstream.push(make_table_column_node_from_source(
440 table,
441 col_name,
442 &source_info.expression,
443 ));
444 return;
445 }
446 }
447
448 node.downstream
450 .push(make_table_column_node(table, col_name));
451}
452
453fn resolve_unqualified_column(
454 node: &mut LineageNode,
455 scope: &Scope,
456 dialect: Option<DialectType>,
457 col_name: &str,
458 parent_name: &str,
459 trim_selects: bool,
460 all_cte_scopes: &[&Scope],
461) {
462 let from_source_names: Vec<&String> = scope
467 .sources
468 .iter()
469 .filter(|(_, info)| !matches!(info.expression, Expression::Cte(_)))
470 .map(|(name, _)| name)
471 .collect();
472
473 if from_source_names.len() == 1 {
474 let tbl = from_source_names[0];
475 resolve_qualified_column(
476 node,
477 scope,
478 dialect,
479 tbl,
480 col_name,
481 parent_name,
482 trim_selects,
483 all_cte_scopes,
484 );
485 return;
486 }
487
488 let child = LineageNode::new(
490 col_name.to_string(),
491 Expression::Column(crate::expressions::Column {
492 name: crate::expressions::Identifier::new(col_name.to_string()),
493 table: None,
494 join_mark: false,
495 trailing_comments: vec![],
496 }),
497 node.source.clone(),
498 );
499 node.downstream.push(child);
500}
501
502fn get_alias_or_name(expr: &Expression) -> Option<String> {
508 match expr {
509 Expression::Alias(alias) => Some(alias.alias.name.clone()),
510 Expression::Column(col) => Some(col.name.name.clone()),
511 Expression::Identifier(id) => Some(id.name.clone()),
512 Expression::Star(_) => Some("*".to_string()),
513 _ => None,
514 }
515}
516
517fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
519 match column {
520 ColumnRef::Name(n) => n.to_string(),
521 ColumnRef::Index(_) => get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string()),
522 }
523}
524
525fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
527 if let Expression::Select(ref select) = scope_expr {
528 match column {
529 ColumnRef::Name(name) => {
530 for expr in &select.expressions {
531 if get_alias_or_name(expr).as_deref() == Some(name) {
532 return Ok(expr.clone());
533 }
534 }
535 Err(crate::error::Error::parse(
536 format!("Cannot find column '{}' in query", name),
537 0,
538 0,
539 ))
540 }
541 ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
542 crate::error::Error::parse(format!("Column index {} out of range", idx), 0, 0)
543 }),
544 }
545 } else {
546 Err(crate::error::Error::parse(
547 "Expected SELECT expression for column lookup",
548 0,
549 0,
550 ))
551 }
552}
553
554fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
556 let mut expr = set_op_expr;
557 loop {
558 match expr {
559 Expression::Union(u) => expr = &u.left,
560 Expression::Intersect(i) => expr = &i.left,
561 Expression::Except(e) => expr = &e.left,
562 Expression::Select(select) => {
563 for (i, e) in select.expressions.iter().enumerate() {
564 if get_alias_or_name(e).as_deref() == Some(name) {
565 return Ok(i);
566 }
567 }
568 return Err(crate::error::Error::parse(
569 format!("Cannot find column '{}' in set operation", name),
570 0,
571 0,
572 ));
573 }
574 _ => {
575 return Err(crate::error::Error::parse(
576 "Expected SELECT or set operation",
577 0,
578 0,
579 ))
580 }
581 }
582 }
583}
584
585fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
587 if let Expression::Select(select) = select_expr {
588 let mut trimmed = select.as_ref().clone();
589 trimmed.expressions = vec![target_expr.clone()];
590 Expression::Select(Box::new(trimmed))
591 } else {
592 select_expr.clone()
593 }
594}
595
596fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
598 if scope.cte_sources.contains_key(source_name) {
600 for cte_scope in &scope.cte_scopes {
601 if let Expression::Cte(cte) = &cte_scope.expression {
602 if cte.alias.name == source_name {
603 return Some(cte_scope);
604 }
605 }
606 }
607 }
608
609 if let Some(source_info) = scope.sources.get(source_name) {
611 if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
612 if let Expression::Subquery(sq) = &source_info.expression {
613 for dt_scope in &scope.derived_table_scopes {
614 if dt_scope.expression == sq.this {
615 return Some(dt_scope);
616 }
617 }
618 }
619 }
620 }
621
622 None
623}
624
625fn find_child_scope_in<'a>(
629 all_cte_scopes: &[&'a Scope],
630 scope: &'a Scope,
631 source_name: &str,
632) -> Option<&'a Scope> {
633 for cte_scope in &scope.cte_scopes {
635 if let Expression::Cte(cte) = &cte_scope.expression {
636 if cte.alias.name == source_name {
637 return Some(cte_scope);
638 }
639 }
640 }
641
642 for cte_scope in all_cte_scopes {
644 if let Expression::Cte(cte) = &cte_scope.expression {
645 if cte.alias.name == source_name {
646 return Some(cte_scope);
647 }
648 }
649 }
650
651 if let Some(source_info) = scope.sources.get(source_name) {
653 if source_info.is_scope {
654 if let Expression::Subquery(sq) = &source_info.expression {
655 for dt_scope in &scope.derived_table_scopes {
656 if dt_scope.expression == sq.this {
657 return Some(dt_scope);
658 }
659 }
660 }
661 }
662 }
663
664 None
665}
666
667fn make_table_column_node(table: &str, column: &str) -> LineageNode {
669 let mut node = LineageNode::new(
670 format!("{}.{}", table, column),
671 Expression::Column(crate::expressions::Column {
672 name: crate::expressions::Identifier::new(column.to_string()),
673 table: Some(crate::expressions::Identifier::new(table.to_string())),
674 join_mark: false,
675 trailing_comments: vec![],
676 }),
677 Expression::Table(crate::expressions::TableRef::new(table)),
678 );
679 node.source_name = table.to_string();
680 node
681}
682
683fn table_name_from_table_ref(table_ref: &crate::expressions::TableRef) -> String {
684 let mut parts: Vec<String> = Vec::new();
685 if let Some(catalog) = &table_ref.catalog {
686 parts.push(catalog.name.clone());
687 }
688 if let Some(schema) = &table_ref.schema {
689 parts.push(schema.name.clone());
690 }
691 parts.push(table_ref.name.name.clone());
692 parts.join(".")
693}
694
695fn make_table_column_node_from_source(
696 table_alias: &str,
697 column: &str,
698 source: &Expression,
699) -> LineageNode {
700 let mut node = LineageNode::new(
701 format!("{}.{}", table_alias, column),
702 Expression::Column(crate::expressions::Column {
703 name: crate::expressions::Identifier::new(column.to_string()),
704 table: Some(crate::expressions::Identifier::new(table_alias.to_string())),
705 join_mark: false,
706 trailing_comments: vec![],
707 }),
708 source.clone(),
709 );
710
711 if let Expression::Table(table_ref) = source {
712 node.source_name = table_name_from_table_ref(table_ref);
713 } else {
714 node.source_name = table_alias.to_string();
715 }
716
717 node
718}
719
720#[derive(Debug, Clone)]
722struct SimpleColumnRef {
723 table: Option<crate::expressions::Identifier>,
724 column: String,
725}
726
727fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
729 let mut refs = Vec::new();
730 collect_column_refs(expr, &mut refs);
731 refs
732}
733
734fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
735 let mut stack: Vec<&Expression> = vec![expr];
736
737 while let Some(current) = stack.pop() {
738 match current {
739 Expression::Column(col) => {
741 refs.push(SimpleColumnRef {
742 table: col.table.clone(),
743 column: col.name.name.clone(),
744 });
745 }
746
747 Expression::Subquery(_) | Expression::Exists(_) => {}
749
750 Expression::And(op)
752 | Expression::Or(op)
753 | Expression::Eq(op)
754 | Expression::Neq(op)
755 | Expression::Lt(op)
756 | Expression::Lte(op)
757 | Expression::Gt(op)
758 | Expression::Gte(op)
759 | Expression::Add(op)
760 | Expression::Sub(op)
761 | Expression::Mul(op)
762 | Expression::Div(op)
763 | Expression::Mod(op)
764 | Expression::BitwiseAnd(op)
765 | Expression::BitwiseOr(op)
766 | Expression::BitwiseXor(op)
767 | Expression::BitwiseLeftShift(op)
768 | Expression::BitwiseRightShift(op)
769 | Expression::Concat(op)
770 | Expression::Adjacent(op)
771 | Expression::TsMatch(op)
772 | Expression::PropertyEQ(op)
773 | Expression::ArrayContainsAll(op)
774 | Expression::ArrayContainedBy(op)
775 | Expression::ArrayOverlaps(op)
776 | Expression::JSONBContainsAllTopKeys(op)
777 | Expression::JSONBContainsAnyTopKeys(op)
778 | Expression::JSONBDeleteAtPath(op)
779 | Expression::ExtendsLeft(op)
780 | Expression::ExtendsRight(op)
781 | Expression::Is(op)
782 | Expression::MemberOf(op)
783 | Expression::NullSafeEq(op)
784 | Expression::NullSafeNeq(op)
785 | Expression::Glob(op)
786 | Expression::Match(op) => {
787 stack.push(&op.left);
788 stack.push(&op.right);
789 }
790
791 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
793 stack.push(&u.this);
794 }
795
796 Expression::Upper(f)
798 | Expression::Lower(f)
799 | Expression::Length(f)
800 | Expression::LTrim(f)
801 | Expression::RTrim(f)
802 | Expression::Reverse(f)
803 | Expression::Abs(f)
804 | Expression::Sqrt(f)
805 | Expression::Cbrt(f)
806 | Expression::Ln(f)
807 | Expression::Exp(f)
808 | Expression::Sign(f)
809 | Expression::Date(f)
810 | Expression::Time(f)
811 | Expression::DateFromUnixDate(f)
812 | Expression::UnixDate(f)
813 | Expression::UnixSeconds(f)
814 | Expression::UnixMillis(f)
815 | Expression::UnixMicros(f)
816 | Expression::TimeStrToDate(f)
817 | Expression::DateToDi(f)
818 | Expression::DiToDate(f)
819 | Expression::TsOrDiToDi(f)
820 | Expression::TsOrDsToDatetime(f)
821 | Expression::TsOrDsToTimestamp(f)
822 | Expression::YearOfWeek(f)
823 | Expression::YearOfWeekIso(f)
824 | Expression::Initcap(f)
825 | Expression::Ascii(f)
826 | Expression::Chr(f)
827 | Expression::Soundex(f)
828 | Expression::ByteLength(f)
829 | Expression::Hex(f)
830 | Expression::LowerHex(f)
831 | Expression::Unicode(f)
832 | Expression::Radians(f)
833 | Expression::Degrees(f)
834 | Expression::Sin(f)
835 | Expression::Cos(f)
836 | Expression::Tan(f)
837 | Expression::Asin(f)
838 | Expression::Acos(f)
839 | Expression::Atan(f)
840 | Expression::IsNan(f)
841 | Expression::IsInf(f)
842 | Expression::ArrayLength(f)
843 | Expression::ArraySize(f)
844 | Expression::Cardinality(f)
845 | Expression::ArrayReverse(f)
846 | Expression::ArrayDistinct(f)
847 | Expression::ArrayFlatten(f)
848 | Expression::ArrayCompact(f)
849 | Expression::Explode(f)
850 | Expression::ExplodeOuter(f)
851 | Expression::ToArray(f)
852 | Expression::MapFromEntries(f)
853 | Expression::MapKeys(f)
854 | Expression::MapValues(f)
855 | Expression::JsonArrayLength(f)
856 | Expression::JsonKeys(f)
857 | Expression::JsonType(f)
858 | Expression::ParseJson(f)
859 | Expression::ToJson(f)
860 | Expression::Typeof(f)
861 | Expression::BitwiseCount(f)
862 | Expression::Year(f)
863 | Expression::Month(f)
864 | Expression::Day(f)
865 | Expression::Hour(f)
866 | Expression::Minute(f)
867 | Expression::Second(f)
868 | Expression::DayOfWeek(f)
869 | Expression::DayOfWeekIso(f)
870 | Expression::DayOfMonth(f)
871 | Expression::DayOfYear(f)
872 | Expression::WeekOfYear(f)
873 | Expression::Quarter(f)
874 | Expression::Epoch(f)
875 | Expression::EpochMs(f)
876 | Expression::TimeStrToUnix(f)
877 | Expression::SHA(f)
878 | Expression::SHA1Digest(f)
879 | Expression::TimeToUnix(f)
880 | Expression::JSONBool(f)
881 | Expression::Int64(f)
882 | Expression::MD5NumberLower64(f)
883 | Expression::MD5NumberUpper64(f)
884 | Expression::DateStrToDate(f)
885 | Expression::DateToDateStr(f) => {
886 stack.push(&f.this);
887 }
888
889 Expression::Power(f)
891 | Expression::NullIf(f)
892 | Expression::IfNull(f)
893 | Expression::Nvl(f)
894 | Expression::UnixToTimeStr(f)
895 | Expression::Contains(f)
896 | Expression::StartsWith(f)
897 | Expression::EndsWith(f)
898 | Expression::Levenshtein(f)
899 | Expression::ModFunc(f)
900 | Expression::Atan2(f)
901 | Expression::IntDiv(f)
902 | Expression::AddMonths(f)
903 | Expression::MonthsBetween(f)
904 | Expression::NextDay(f)
905 | Expression::ArrayContains(f)
906 | Expression::ArrayPosition(f)
907 | Expression::ArrayAppend(f)
908 | Expression::ArrayPrepend(f)
909 | Expression::ArrayUnion(f)
910 | Expression::ArrayExcept(f)
911 | Expression::ArrayRemove(f)
912 | Expression::StarMap(f)
913 | Expression::MapFromArrays(f)
914 | Expression::MapContainsKey(f)
915 | Expression::ElementAt(f)
916 | Expression::JsonMergePatch(f)
917 | Expression::JSONBContains(f)
918 | Expression::JSONBExtract(f) => {
919 stack.push(&f.this);
920 stack.push(&f.expression);
921 }
922
923 Expression::Greatest(f)
925 | Expression::Least(f)
926 | Expression::Coalesce(f)
927 | Expression::ArrayConcat(f)
928 | Expression::ArrayIntersect(f)
929 | Expression::ArrayZip(f)
930 | Expression::MapConcat(f)
931 | Expression::JsonArray(f) => {
932 for e in &f.expressions {
933 stack.push(e);
934 }
935 }
936
937 Expression::Sum(f)
939 | Expression::Avg(f)
940 | Expression::Min(f)
941 | Expression::Max(f)
942 | Expression::ArrayAgg(f)
943 | Expression::CountIf(f)
944 | Expression::Stddev(f)
945 | Expression::StddevPop(f)
946 | Expression::StddevSamp(f)
947 | Expression::Variance(f)
948 | Expression::VarPop(f)
949 | Expression::VarSamp(f)
950 | Expression::Median(f)
951 | Expression::Mode(f)
952 | Expression::First(f)
953 | Expression::Last(f)
954 | Expression::AnyValue(f)
955 | Expression::ApproxDistinct(f)
956 | Expression::ApproxCountDistinct(f)
957 | Expression::LogicalAnd(f)
958 | Expression::LogicalOr(f)
959 | Expression::Skewness(f)
960 | Expression::ArrayConcatAgg(f)
961 | Expression::ArrayUniqueAgg(f)
962 | Expression::BoolXorAgg(f)
963 | Expression::BitwiseAndAgg(f)
964 | Expression::BitwiseOrAgg(f)
965 | Expression::BitwiseXorAgg(f) => {
966 stack.push(&f.this);
967 if let Some(ref filter) = f.filter {
968 stack.push(filter);
969 }
970 if let Some((ref expr, _)) = f.having_max {
971 stack.push(expr);
972 }
973 if let Some(ref limit) = f.limit {
974 stack.push(limit);
975 }
976 }
977
978 Expression::Function(func) => {
980 for arg in &func.args {
981 stack.push(arg);
982 }
983 }
984 Expression::AggregateFunction(func) => {
985 for arg in &func.args {
986 stack.push(arg);
987 }
988 if let Some(ref filter) = func.filter {
989 stack.push(filter);
990 }
991 if let Some(ref limit) = func.limit {
992 stack.push(limit);
993 }
994 }
995
996 Expression::WindowFunction(wf) => {
998 stack.push(&wf.this);
999 }
1000
1001 Expression::Alias(a) => {
1003 stack.push(&a.this);
1004 }
1005 Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
1006 stack.push(&c.this);
1007 if let Some(ref fmt) = c.format {
1008 stack.push(fmt);
1009 }
1010 if let Some(ref def) = c.default {
1011 stack.push(def);
1012 }
1013 }
1014 Expression::Paren(p) => {
1015 stack.push(&p.this);
1016 }
1017 Expression::Annotated(a) => {
1018 stack.push(&a.this);
1019 }
1020 Expression::Case(case) => {
1021 if let Some(ref operand) = case.operand {
1022 stack.push(operand);
1023 }
1024 for (cond, result) in &case.whens {
1025 stack.push(cond);
1026 stack.push(result);
1027 }
1028 if let Some(ref else_expr) = case.else_ {
1029 stack.push(else_expr);
1030 }
1031 }
1032 Expression::Collation(c) => {
1033 stack.push(&c.this);
1034 }
1035 Expression::In(i) => {
1036 stack.push(&i.this);
1037 for e in &i.expressions {
1038 stack.push(e);
1039 }
1040 if let Some(ref q) = i.query {
1041 stack.push(q);
1042 }
1043 if let Some(ref u) = i.unnest {
1044 stack.push(u);
1045 }
1046 }
1047 Expression::Between(b) => {
1048 stack.push(&b.this);
1049 stack.push(&b.low);
1050 stack.push(&b.high);
1051 }
1052 Expression::IsNull(n) => {
1053 stack.push(&n.this);
1054 }
1055 Expression::IsTrue(t) | Expression::IsFalse(t) => {
1056 stack.push(&t.this);
1057 }
1058 Expression::IsJson(j) => {
1059 stack.push(&j.this);
1060 }
1061 Expression::Like(l) | Expression::ILike(l) => {
1062 stack.push(&l.left);
1063 stack.push(&l.right);
1064 if let Some(ref esc) = l.escape {
1065 stack.push(esc);
1066 }
1067 }
1068 Expression::SimilarTo(s) => {
1069 stack.push(&s.this);
1070 stack.push(&s.pattern);
1071 if let Some(ref esc) = s.escape {
1072 stack.push(esc);
1073 }
1074 }
1075 Expression::Ordered(o) => {
1076 stack.push(&o.this);
1077 }
1078 Expression::Array(a) => {
1079 for e in &a.expressions {
1080 stack.push(e);
1081 }
1082 }
1083 Expression::Tuple(t) => {
1084 for e in &t.expressions {
1085 stack.push(e);
1086 }
1087 }
1088 Expression::Struct(s) => {
1089 for (_, e) in &s.fields {
1090 stack.push(e);
1091 }
1092 }
1093 Expression::Subscript(s) => {
1094 stack.push(&s.this);
1095 stack.push(&s.index);
1096 }
1097 Expression::Dot(d) => {
1098 stack.push(&d.this);
1099 }
1100 Expression::MethodCall(m) => {
1101 stack.push(&m.this);
1102 for arg in &m.args {
1103 stack.push(arg);
1104 }
1105 }
1106 Expression::ArraySlice(s) => {
1107 stack.push(&s.this);
1108 if let Some(ref start) = s.start {
1109 stack.push(start);
1110 }
1111 if let Some(ref end) = s.end {
1112 stack.push(end);
1113 }
1114 }
1115 Expression::Lambda(l) => {
1116 stack.push(&l.body);
1117 }
1118 Expression::NamedArgument(n) => {
1119 stack.push(&n.value);
1120 }
1121 Expression::BracedWildcard(e) | Expression::ReturnStmt(e) => {
1122 stack.push(e);
1123 }
1124
1125 Expression::Substring(f) => {
1127 stack.push(&f.this);
1128 stack.push(&f.start);
1129 if let Some(ref len) = f.length {
1130 stack.push(len);
1131 }
1132 }
1133 Expression::Trim(f) => {
1134 stack.push(&f.this);
1135 if let Some(ref chars) = f.characters {
1136 stack.push(chars);
1137 }
1138 }
1139 Expression::Replace(f) => {
1140 stack.push(&f.this);
1141 stack.push(&f.old);
1142 stack.push(&f.new);
1143 }
1144 Expression::IfFunc(f) => {
1145 stack.push(&f.condition);
1146 stack.push(&f.true_value);
1147 if let Some(ref fv) = f.false_value {
1148 stack.push(fv);
1149 }
1150 }
1151 Expression::Nvl2(f) => {
1152 stack.push(&f.this);
1153 stack.push(&f.true_value);
1154 stack.push(&f.false_value);
1155 }
1156 Expression::ConcatWs(f) => {
1157 stack.push(&f.separator);
1158 for e in &f.expressions {
1159 stack.push(e);
1160 }
1161 }
1162 Expression::Count(f) => {
1163 if let Some(ref this) = f.this {
1164 stack.push(this);
1165 }
1166 if let Some(ref filter) = f.filter {
1167 stack.push(filter);
1168 }
1169 }
1170 Expression::GroupConcat(f) => {
1171 stack.push(&f.this);
1172 if let Some(ref sep) = f.separator {
1173 stack.push(sep);
1174 }
1175 if let Some(ref filter) = f.filter {
1176 stack.push(filter);
1177 }
1178 }
1179 Expression::StringAgg(f) => {
1180 stack.push(&f.this);
1181 if let Some(ref sep) = f.separator {
1182 stack.push(sep);
1183 }
1184 if let Some(ref filter) = f.filter {
1185 stack.push(filter);
1186 }
1187 if let Some(ref limit) = f.limit {
1188 stack.push(limit);
1189 }
1190 }
1191 Expression::ListAgg(f) => {
1192 stack.push(&f.this);
1193 if let Some(ref sep) = f.separator {
1194 stack.push(sep);
1195 }
1196 if let Some(ref filter) = f.filter {
1197 stack.push(filter);
1198 }
1199 }
1200 Expression::SumIf(f) => {
1201 stack.push(&f.this);
1202 stack.push(&f.condition);
1203 if let Some(ref filter) = f.filter {
1204 stack.push(filter);
1205 }
1206 }
1207 Expression::DateAdd(f) | Expression::DateSub(f) => {
1208 stack.push(&f.this);
1209 stack.push(&f.interval);
1210 }
1211 Expression::DateDiff(f) => {
1212 stack.push(&f.this);
1213 stack.push(&f.expression);
1214 }
1215 Expression::DateTrunc(f) | Expression::TimestampTrunc(f) => {
1216 stack.push(&f.this);
1217 }
1218 Expression::Extract(f) => {
1219 stack.push(&f.this);
1220 }
1221 Expression::Round(f) => {
1222 stack.push(&f.this);
1223 if let Some(ref d) = f.decimals {
1224 stack.push(d);
1225 }
1226 }
1227 Expression::Floor(f) => {
1228 stack.push(&f.this);
1229 if let Some(ref s) = f.scale {
1230 stack.push(s);
1231 }
1232 if let Some(ref t) = f.to {
1233 stack.push(t);
1234 }
1235 }
1236 Expression::Ceil(f) => {
1237 stack.push(&f.this);
1238 if let Some(ref d) = f.decimals {
1239 stack.push(d);
1240 }
1241 if let Some(ref t) = f.to {
1242 stack.push(t);
1243 }
1244 }
1245 Expression::Log(f) => {
1246 stack.push(&f.this);
1247 if let Some(ref b) = f.base {
1248 stack.push(b);
1249 }
1250 }
1251 Expression::AtTimeZone(f) => {
1252 stack.push(&f.this);
1253 stack.push(&f.zone);
1254 }
1255 Expression::Lead(f) | Expression::Lag(f) => {
1256 stack.push(&f.this);
1257 if let Some(ref off) = f.offset {
1258 stack.push(off);
1259 }
1260 if let Some(ref def) = f.default {
1261 stack.push(def);
1262 }
1263 }
1264 Expression::FirstValue(f) | Expression::LastValue(f) => {
1265 stack.push(&f.this);
1266 }
1267 Expression::NthValue(f) => {
1268 stack.push(&f.this);
1269 stack.push(&f.offset);
1270 }
1271 Expression::Position(f) => {
1272 stack.push(&f.substring);
1273 stack.push(&f.string);
1274 if let Some(ref start) = f.start {
1275 stack.push(start);
1276 }
1277 }
1278 Expression::Decode(f) => {
1279 stack.push(&f.this);
1280 for (search, result) in &f.search_results {
1281 stack.push(search);
1282 stack.push(result);
1283 }
1284 if let Some(ref def) = f.default {
1285 stack.push(def);
1286 }
1287 }
1288 Expression::CharFunc(f) => {
1289 for arg in &f.args {
1290 stack.push(arg);
1291 }
1292 }
1293 Expression::ArraySort(f) => {
1294 stack.push(&f.this);
1295 if let Some(ref cmp) = f.comparator {
1296 stack.push(cmp);
1297 }
1298 }
1299 Expression::ArrayJoin(f) | Expression::ArrayToString(f) => {
1300 stack.push(&f.this);
1301 stack.push(&f.separator);
1302 if let Some(ref nr) = f.null_replacement {
1303 stack.push(nr);
1304 }
1305 }
1306 Expression::ArrayFilter(f) => {
1307 stack.push(&f.this);
1308 stack.push(&f.filter);
1309 }
1310 Expression::ArrayTransform(f) => {
1311 stack.push(&f.this);
1312 stack.push(&f.transform);
1313 }
1314 Expression::Sequence(f)
1315 | Expression::Generate(f)
1316 | Expression::ExplodingGenerateSeries(f) => {
1317 stack.push(&f.start);
1318 stack.push(&f.stop);
1319 if let Some(ref step) = f.step {
1320 stack.push(step);
1321 }
1322 }
1323 Expression::JsonExtract(f)
1324 | Expression::JsonExtractScalar(f)
1325 | Expression::JsonQuery(f)
1326 | Expression::JsonValue(f) => {
1327 stack.push(&f.this);
1328 stack.push(&f.path);
1329 }
1330 Expression::JsonExtractPath(f) | Expression::JsonRemove(f) => {
1331 stack.push(&f.this);
1332 for p in &f.paths {
1333 stack.push(p);
1334 }
1335 }
1336 Expression::JsonObject(f) => {
1337 for (k, v) in &f.pairs {
1338 stack.push(k);
1339 stack.push(v);
1340 }
1341 }
1342 Expression::JsonSet(f) | Expression::JsonInsert(f) => {
1343 stack.push(&f.this);
1344 for (path, val) in &f.path_values {
1345 stack.push(path);
1346 stack.push(val);
1347 }
1348 }
1349 Expression::Overlay(f) => {
1350 stack.push(&f.this);
1351 stack.push(&f.replacement);
1352 stack.push(&f.from);
1353 if let Some(ref len) = f.length {
1354 stack.push(len);
1355 }
1356 }
1357 Expression::Convert(f) => {
1358 stack.push(&f.this);
1359 if let Some(ref style) = f.style {
1360 stack.push(style);
1361 }
1362 }
1363 Expression::ApproxPercentile(f) => {
1364 stack.push(&f.this);
1365 stack.push(&f.percentile);
1366 if let Some(ref acc) = f.accuracy {
1367 stack.push(acc);
1368 }
1369 if let Some(ref filter) = f.filter {
1370 stack.push(filter);
1371 }
1372 }
1373 Expression::Percentile(f)
1374 | Expression::PercentileCont(f)
1375 | Expression::PercentileDisc(f) => {
1376 stack.push(&f.this);
1377 stack.push(&f.percentile);
1378 if let Some(ref filter) = f.filter {
1379 stack.push(filter);
1380 }
1381 }
1382 Expression::WithinGroup(f) => {
1383 stack.push(&f.this);
1384 }
1385 Expression::Left(f) | Expression::Right(f) => {
1386 stack.push(&f.this);
1387 stack.push(&f.length);
1388 }
1389 Expression::Repeat(f) => {
1390 stack.push(&f.this);
1391 stack.push(&f.times);
1392 }
1393 Expression::Lpad(f) | Expression::Rpad(f) => {
1394 stack.push(&f.this);
1395 stack.push(&f.length);
1396 if let Some(ref fill) = f.fill {
1397 stack.push(fill);
1398 }
1399 }
1400 Expression::Split(f) => {
1401 stack.push(&f.this);
1402 stack.push(&f.delimiter);
1403 }
1404 Expression::RegexpLike(f) => {
1405 stack.push(&f.this);
1406 stack.push(&f.pattern);
1407 if let Some(ref flags) = f.flags {
1408 stack.push(flags);
1409 }
1410 }
1411 Expression::RegexpReplace(f) => {
1412 stack.push(&f.this);
1413 stack.push(&f.pattern);
1414 stack.push(&f.replacement);
1415 if let Some(ref flags) = f.flags {
1416 stack.push(flags);
1417 }
1418 }
1419 Expression::RegexpExtract(f) => {
1420 stack.push(&f.this);
1421 stack.push(&f.pattern);
1422 if let Some(ref group) = f.group {
1423 stack.push(group);
1424 }
1425 }
1426 Expression::ToDate(f) => {
1427 stack.push(&f.this);
1428 if let Some(ref fmt) = f.format {
1429 stack.push(fmt);
1430 }
1431 }
1432 Expression::ToTimestamp(f) => {
1433 stack.push(&f.this);
1434 if let Some(ref fmt) = f.format {
1435 stack.push(fmt);
1436 }
1437 }
1438 Expression::DateFormat(f) | Expression::FormatDate(f) => {
1439 stack.push(&f.this);
1440 stack.push(&f.format);
1441 }
1442 Expression::LastDay(f) => {
1443 stack.push(&f.this);
1444 }
1445 Expression::FromUnixtime(f) => {
1446 stack.push(&f.this);
1447 if let Some(ref fmt) = f.format {
1448 stack.push(fmt);
1449 }
1450 }
1451 Expression::UnixTimestamp(f) => {
1452 if let Some(ref this) = f.this {
1453 stack.push(this);
1454 }
1455 if let Some(ref fmt) = f.format {
1456 stack.push(fmt);
1457 }
1458 }
1459 Expression::MakeDate(f) => {
1460 stack.push(&f.year);
1461 stack.push(&f.month);
1462 stack.push(&f.day);
1463 }
1464 Expression::MakeTimestamp(f) => {
1465 stack.push(&f.year);
1466 stack.push(&f.month);
1467 stack.push(&f.day);
1468 stack.push(&f.hour);
1469 stack.push(&f.minute);
1470 stack.push(&f.second);
1471 if let Some(ref tz) = f.timezone {
1472 stack.push(tz);
1473 }
1474 }
1475 Expression::TruncFunc(f) => {
1476 stack.push(&f.this);
1477 if let Some(ref d) = f.decimals {
1478 stack.push(d);
1479 }
1480 }
1481 Expression::ArrayFunc(f) => {
1482 for e in &f.expressions {
1483 stack.push(e);
1484 }
1485 }
1486 Expression::Unnest(f) => {
1487 stack.push(&f.this);
1488 for e in &f.expressions {
1489 stack.push(e);
1490 }
1491 }
1492 Expression::StructFunc(f) => {
1493 for (_, e) in &f.fields {
1494 stack.push(e);
1495 }
1496 }
1497 Expression::StructExtract(f) => {
1498 stack.push(&f.this);
1499 }
1500 Expression::NamedStruct(f) => {
1501 for (k, v) in &f.pairs {
1502 stack.push(k);
1503 stack.push(v);
1504 }
1505 }
1506 Expression::MapFunc(f) => {
1507 for k in &f.keys {
1508 stack.push(k);
1509 }
1510 for v in &f.values {
1511 stack.push(v);
1512 }
1513 }
1514 Expression::TransformKeys(f) | Expression::TransformValues(f) => {
1515 stack.push(&f.this);
1516 stack.push(&f.transform);
1517 }
1518 Expression::JsonArrayAgg(f) => {
1519 stack.push(&f.this);
1520 if let Some(ref filter) = f.filter {
1521 stack.push(filter);
1522 }
1523 }
1524 Expression::JsonObjectAgg(f) => {
1525 stack.push(&f.key);
1526 stack.push(&f.value);
1527 if let Some(ref filter) = f.filter {
1528 stack.push(filter);
1529 }
1530 }
1531 Expression::NTile(f) => {
1532 if let Some(ref n) = f.num_buckets {
1533 stack.push(n);
1534 }
1535 }
1536 Expression::Rand(f) => {
1537 if let Some(ref s) = f.seed {
1538 stack.push(s);
1539 }
1540 if let Some(ref lo) = f.lower {
1541 stack.push(lo);
1542 }
1543 if let Some(ref hi) = f.upper {
1544 stack.push(hi);
1545 }
1546 }
1547 Expression::Any(q) | Expression::All(q) => {
1548 stack.push(&q.this);
1549 stack.push(&q.subquery);
1550 }
1551 Expression::Overlaps(o) => {
1552 if let Some(ref this) = o.this {
1553 stack.push(this);
1554 }
1555 if let Some(ref expr) = o.expression {
1556 stack.push(expr);
1557 }
1558 if let Some(ref ls) = o.left_start {
1559 stack.push(ls);
1560 }
1561 if let Some(ref le) = o.left_end {
1562 stack.push(le);
1563 }
1564 if let Some(ref rs) = o.right_start {
1565 stack.push(rs);
1566 }
1567 if let Some(ref re) = o.right_end {
1568 stack.push(re);
1569 }
1570 }
1571 Expression::Interval(i) => {
1572 if let Some(ref this) = i.this {
1573 stack.push(this);
1574 }
1575 }
1576 Expression::TimeStrToTime(f) => {
1577 stack.push(&f.this);
1578 if let Some(ref zone) = f.zone {
1579 stack.push(zone);
1580 }
1581 }
1582 Expression::JSONBExtractScalar(f) => {
1583 stack.push(&f.this);
1584 stack.push(&f.expression);
1585 if let Some(ref jt) = f.json_type {
1586 stack.push(jt);
1587 }
1588 }
1589
1590 _ => {}
1595 }
1596 }
1597}
1598
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse(sql: &str) -> Expression {
        let generic = Dialect::get(DialectType::Generic);
        let statements = generic.parse(sql).unwrap();
        statements.into_iter().next().unwrap()
    }

    /// Reports whether `names` holds an entry exactly equal to `wanted`.
    fn has_name(names: &[String], wanted: &str) -> bool {
        names.iter().any(|candidate| candidate == wanted)
    }

    /// Names of every node reachable from `root`, in `walk()` order (root first).
    fn walk_names(root: &LineageNode) -> Vec<String> {
        root.walk().map(|visited| visited.name.clone()).collect()
    }

    /// Traces `column` through `sql` (generic dialect, no trimming) and returns
    /// the names of the root's direct downstream nodes.
    fn downstream_of(column: &str, sql: &str) -> Vec<String> {
        lineage(column, &parse(sql), None, false)
            .unwrap()
            .downstream_names()
    }

    /// Traces column `a` of `SELECT a, b, c FROM t` and returns how many
    /// projections the root's source SELECT carries.
    fn projection_count(trim_selects: bool) -> usize {
        let expr = parse("SELECT a, b, c FROM t");
        let node = lineage("a", &expr, None, trim_selects).unwrap();
        match &node.source {
            Expression::Select(select) => select.expressions.len(),
            _ => panic!("Expected Select source"),
        }
    }

    #[test]
    fn test_simple_lineage() {
        let expr = parse("SELECT a FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        assert!(!node.downstream.is_empty(), "Should have downstream nodes");
        let names = node.downstream_names();
        assert!(
            has_name(&names, "t.a"),
            "Expected t.a in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_walk() {
        let null = || Expression::Null(crate::expressions::Null);
        let mut root = LineageNode::new("col_a", null(), null());
        root.downstream
            .push(LineageNode::new("t.a", null(), null()));

        let names = walk_names(&root);
        assert_eq!(names.len(), 2);
        assert_eq!(names[0], "col_a");
        assert_eq!(names[1], "t.a");
    }

    #[test]
    fn test_aliased_column() {
        let expr = parse("SELECT a + 1 AS b FROM t");
        let node = lineage("b", &expr, None, false).unwrap();

        assert_eq!(node.name, "b");
        let all_names = walk_names(&node);
        assert!(
            all_names.iter().any(|name| name.contains("a")),
            "Expected to trace to column a, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_qualified_column() {
        let expr = parse("SELECT t.a FROM t");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        let names = node.downstream_names();
        assert!(has_name(&names, "t.a"), "Expected t.a, got: {:?}", names);
    }

    #[test]
    fn test_unqualified_column() {
        let names = downstream_of("a", "SELECT a FROM t");
        assert!(has_name(&names, "t.a"), "Expected t.a, got: {:?}", names);
    }

    #[test]
    fn test_lineage_join() {
        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        let names_a = lineage("a", &expr, None, false)
            .unwrap()
            .downstream_names();
        assert!(has_name(&names_a, "t.a"), "Expected t.a, got: {:?}", names_a);

        let names_b = lineage("b", &expr, None, false)
            .unwrap()
            .downstream_names();
        assert!(has_name(&names_b, "s.b"), "Expected s.b, got: {:?}", names_b);
    }

    #[test]
    fn test_lineage_alias_leaf_has_resolved_source_name() {
        let expr = parse("SELECT t1.col1 FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id");
        let node = lineage("col1", &expr, None, false).unwrap();

        let names = node.downstream_names();
        assert!(
            has_name(&names, "t1.col1"),
            "Expected aliased column edge t1.col1, got: {:?}",
            names
        );

        let leaf = node
            .downstream
            .iter()
            .find(|child| child.name == "t1.col1")
            .expect("Expected t1.col1 leaf");
        assert_eq!(leaf.source_name, "table1");
        if let Expression::Table(table) = &leaf.source {
            assert_eq!(table.name.name, "table1");
        } else {
            panic!("Expected leaf source to be a table expression");
        }
    }

    #[test]
    fn test_lineage_derived_table() {
        let expr = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        let all_names = walk_names(&node);
        assert!(
            has_name(&all_names, "t.a"),
            "Expected to trace through derived table to t.a, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_cte() {
        let expr = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        let all_names = walk_names(&node);
        assert!(
            has_name(&all_names, "t.a"),
            "Expected to trace through CTE to t.a, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_union() {
        let expr = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
        let node = lineage("a", &expr, None, false).unwrap();

        assert_eq!(node.name, "a");
        let branch_count = node.downstream.len();
        assert_eq!(
            branch_count,
            2,
            "Expected 2 branches for UNION, got {}",
            branch_count
        );
    }

    #[test]
    fn test_lineage_cte_union() {
        let expr = parse("WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte");
        let node = lineage("a", &expr, None, false).unwrap();

        let all_names = walk_names(&node);
        assert!(
            all_names.len() >= 3,
            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_star() {
        let expr = parse("SELECT * FROM t");
        let node = lineage("*", &expr, None, false).unwrap();

        assert_eq!(node.name, "*");
        assert!(
            !node.downstream.is_empty(),
            "Star should produce downstream nodes"
        );
    }

    #[test]
    fn test_lineage_subquery_in_select() {
        let expr = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
        let node = lineage("x", &expr, None, false).unwrap();

        assert_eq!(node.name, "x");
        let all_names = walk_names(&node);
        assert!(
            all_names.len() >= 2,
            "Expected tracing into scalar subquery, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_lineage_multiple_columns() {
        let expr = parse("SELECT a, b FROM t");

        let node_a = lineage("a", &expr, None, false).unwrap();
        let node_b = lineage("b", &expr, None, false).unwrap();

        assert_eq!(node_a.name, "a");
        assert_eq!(node_b.name, "b");

        assert!(has_name(&node_a.downstream_names(), "t.a"));
        assert!(has_name(&node_b.downstream_names(), "t.b"));
    }

    #[test]
    fn test_get_source_tables() {
        let expr = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
        let node = lineage("a", &expr, None, false).unwrap();

        let tables = get_source_tables(&node);
        assert!(
            tables.contains("t"),
            "Expected source table 't', got: {:?}",
            tables
        );
    }

    #[test]
    fn test_lineage_column_not_found() {
        let expr = parse("SELECT a FROM t");
        assert!(lineage("nonexistent", &expr, None, false).is_err());
    }

    #[test]
    fn test_lineage_nested_cte() {
        let expr = parse(
            "WITH cte1 AS (SELECT a FROM t), cte2 AS (SELECT a FROM cte1) SELECT a FROM cte2",
        );
        let node = lineage("a", &expr, None, false).unwrap();

        let all_names = walk_names(&node);
        assert!(
            all_names.len() >= 3,
            "Expected to trace through nested CTEs, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_trim_selects_true() {
        let kept = projection_count(true);
        assert_eq!(
            kept,
            1,
            "Trimmed source should have 1 expression, got {}",
            kept
        );
    }

    #[test]
    fn test_trim_selects_false() {
        let kept = projection_count(false);
        assert_eq!(kept, 3, "Untrimmed source should have 3 expressions");
    }

    #[test]
    fn test_lineage_expression_in_select() {
        let expr = parse("SELECT a + b AS c FROM t");
        let node = lineage("c", &expr, None, false).unwrap();

        let all_names = walk_names(&node);
        assert!(
            all_names.len() >= 3,
            "Expected to trace a + b to both columns, got: {:?}",
            all_names
        );
    }

    #[test]
    fn test_set_operation_by_index() {
        let expr = parse("SELECT a FROM t1 UNION SELECT b FROM t2");
        let node = lineage("a", &expr, None, false).unwrap();
        assert_eq!(node.downstream.len(), 2);
    }

    /// Debug helper: prints the lineage tree rooted at `node` in pre-order,
    /// indenting each level by one more pad unit than its parent.
    fn print_node(node: &LineageNode, indent: usize) {
        let mut pending = vec![(node, indent)];
        while let Some((current, depth)) = pending.pop() {
            let pad = " ".repeat(depth);
            println!(
                "{pad}name={:?} source_name={:?}",
                current.name, current.source_name
            );
            // Push children reversed so the first child is popped (printed) first.
            pending.extend(
                current
                    .downstream
                    .iter()
                    .rev()
                    .map(|child| (child, depth + 1)),
            );
        }
    }

    #[test]
    fn test_issue18_repro() {
        let query = "SELECT UPPER(name) as upper_name FROM users";
        println!("Query: {query}\n");

        let bigquery = Dialect::get(DialectType::BigQuery);
        let statements = bigquery.parse(query).unwrap();
        let first = &statements[0];

        let node = lineage("upper_name", first, Some(DialectType::BigQuery), false).unwrap();
        println!("lineage(\"upper_name\"):");
        print_node(&node, 1);

        let names = node.downstream_names();
        assert!(
            has_name(&names, "users.name"),
            "Expected users.name in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_upper_function() {
        let names = downstream_of("upper_name", "SELECT UPPER(name) AS upper_name FROM users");
        assert!(
            has_name(&names, "users.name"),
            "Expected users.name in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_round_function() {
        let names = downstream_of("rounded", "SELECT ROUND(price, 2) AS rounded FROM products");
        assert!(
            has_name(&names, "products.price"),
            "Expected products.price in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_coalesce_function() {
        let names = downstream_of("val", "SELECT COALESCE(a, b) AS val FROM t");
        assert!(
            has_name(&names, "t.a"),
            "Expected t.a in downstream, got: {:?}",
            names
        );
        assert!(
            has_name(&names, "t.b"),
            "Expected t.b in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_count_function() {
        let names = downstream_of("cnt", "SELECT COUNT(id) AS cnt FROM t");
        assert!(
            has_name(&names, "t.id"),
            "Expected t.id in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_sum_function() {
        let names = downstream_of("total", "SELECT SUM(amount) AS total FROM t");
        assert!(
            has_name(&names, "t.amount"),
            "Expected t.amount in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_case_with_nested_functions() {
        let names = downstream_of(
            "result",
            "SELECT CASE WHEN x > 0 THEN UPPER(name) ELSE LOWER(name) END AS result FROM t",
        );
        assert!(
            has_name(&names, "t.x"),
            "Expected t.x in downstream, got: {:?}",
            names
        );
        assert!(
            has_name(&names, "t.name"),
            "Expected t.name in downstream, got: {:?}",
            names
        );
    }

    #[test]
    fn test_lineage_substring_function() {
        let names = downstream_of("short", "SELECT SUBSTRING(name, 1, 3) AS short FROM t");
        assert!(
            has_name(&names, "t.name"),
            "Expected t.name in downstream, got: {:?}",
            names
        );
    }
}