1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::scope::{build_scope, Scope};
11use crate::traversal::ExpressionWalk;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct LineageNode {
19 pub name: String,
21 pub expression: Expression,
23 pub source: Expression,
25 pub downstream: Vec<LineageNode>,
27 pub source_name: String,
29 pub reference_node_name: String,
31}
32
33impl LineageNode {
34 pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
36 Self {
37 name: name.into(),
38 expression,
39 source,
40 downstream: Vec::new(),
41 source_name: String::new(),
42 reference_node_name: String::new(),
43 }
44 }
45
46 pub fn walk(&self) -> LineageWalker<'_> {
48 LineageWalker { stack: vec![self] }
49 }
50
51 pub fn downstream_names(&self) -> Vec<String> {
53 self.downstream.iter().map(|n| n.name.clone()).collect()
54 }
55}
56
57pub struct LineageWalker<'a> {
59 stack: Vec<&'a LineageNode>,
60}
61
62impl<'a> Iterator for LineageWalker<'a> {
63 type Item = &'a LineageNode;
64
65 fn next(&mut self) -> Option<Self::Item> {
66 if let Some(node) = self.stack.pop() {
67 for child in node.downstream.iter().rev() {
69 self.stack.push(child);
70 }
71 Some(node)
72 } else {
73 None
74 }
75 }
76}
77
/// Identifies a column of a query's select list.
enum ColumnRef<'a> {
    /// Lookup by output name (alias or column name).
    Name(&'a str),
    /// Lookup by zero-based position in the select list.
    Index(usize),
}
87
88pub fn lineage(
114 column: &str,
115 sql: &Expression,
116 dialect: Option<DialectType>,
117 trim_selects: bool,
118) -> Result<LineageNode> {
119 let scope = build_scope(sql);
120 to_node(
121 ColumnRef::Name(column),
122 &scope,
123 dialect,
124 "",
125 "",
126 "",
127 trim_selects,
128 )
129}
130
131pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
133 let mut tables = HashSet::new();
134 collect_source_tables(node, &mut tables);
135 tables
136}
137
138pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
140 if let Expression::Table(table) = &node.source {
141 tables.insert(table.name.name.clone());
142 }
143 for child in &node.downstream {
144 collect_source_tables(child, tables);
145 }
146}
147
148fn to_node(
154 column: ColumnRef<'_>,
155 scope: &Scope,
156 dialect: Option<DialectType>,
157 scope_name: &str,
158 source_name: &str,
159 reference_node_name: &str,
160 trim_selects: bool,
161) -> Result<LineageNode> {
162 to_node_inner(
163 column,
164 scope,
165 dialect,
166 scope_name,
167 source_name,
168 reference_node_name,
169 trim_selects,
170 &[],
171 )
172}
173
174fn to_node_inner(
175 column: ColumnRef<'_>,
176 scope: &Scope,
177 dialect: Option<DialectType>,
178 scope_name: &str,
179 source_name: &str,
180 reference_node_name: &str,
181 trim_selects: bool,
182 ancestor_cte_scopes: &[Scope],
183) -> Result<LineageNode> {
184 let scope_expr = &scope.expression;
185
186 let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
188 for s in ancestor_cte_scopes {
189 all_cte_scopes.push(s);
190 }
191
192 let effective_expr = match scope_expr {
195 Expression::Cte(cte) => &cte.this,
196 other => other,
197 };
198
199 if matches!(
201 effective_expr,
202 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
203 ) {
204 if matches!(scope_expr, Expression::Cte(_)) {
206 let mut inner_scope = Scope::new(effective_expr.clone());
207 inner_scope.union_scopes = scope.union_scopes.clone();
208 inner_scope.sources = scope.sources.clone();
209 inner_scope.cte_sources = scope.cte_sources.clone();
210 inner_scope.cte_scopes = scope.cte_scopes.clone();
211 inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
212 inner_scope.subquery_scopes = scope.subquery_scopes.clone();
213 return handle_set_operation(
214 &column,
215 &inner_scope,
216 dialect,
217 scope_name,
218 source_name,
219 reference_node_name,
220 trim_selects,
221 ancestor_cte_scopes,
222 );
223 }
224 return handle_set_operation(
225 &column,
226 scope,
227 dialect,
228 scope_name,
229 source_name,
230 reference_node_name,
231 trim_selects,
232 ancestor_cte_scopes,
233 );
234 }
235
236 let select_expr = find_select_expr(effective_expr, &column)?;
238 let column_name = resolve_column_name(&column, &select_expr);
239
240 let node_source = if trim_selects {
242 trim_source(effective_expr, &select_expr)
243 } else {
244 effective_expr.clone()
245 };
246
247 let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
249 node.source_name = source_name.to_string();
250 node.reference_node_name = reference_node_name.to_string();
251
252 if matches!(&select_expr, Expression::Star(_)) {
254 for (name, source_info) in &scope.sources {
255 let child = LineageNode::new(
256 format!("{}.*", name),
257 Expression::Star(crate::expressions::Star {
258 table: None,
259 except: None,
260 replace: None,
261 rename: None,
262 trailing_comments: vec![],
263 span: None,
264 }),
265 source_info.expression.clone(),
266 );
267 node.downstream.push(child);
268 }
269 return Ok(node);
270 }
271
272 let subqueries: Vec<&Expression> =
274 select_expr.find_all(|e| matches!(e, Expression::Subquery(sq) if sq.alias.is_none()));
275 for sq_expr in subqueries {
276 if let Expression::Subquery(sq) = sq_expr {
277 for sq_scope in &scope.subquery_scopes {
278 if sq_scope.expression == sq.this {
279 if let Ok(child) = to_node_inner(
280 ColumnRef::Index(0),
281 sq_scope,
282 dialect,
283 &column_name,
284 "",
285 "",
286 trim_selects,
287 ancestor_cte_scopes,
288 ) {
289 node.downstream.push(child);
290 }
291 break;
292 }
293 }
294 }
295 }
296
297 let col_refs = find_column_refs_in_expr(&select_expr);
299 for col_ref in col_refs {
300 let col_name = &col_ref.column;
301 if let Some(ref table_id) = col_ref.table {
302 let tbl = &table_id.name;
303 resolve_qualified_column(
304 &mut node,
305 scope,
306 dialect,
307 tbl,
308 col_name,
309 &column_name,
310 trim_selects,
311 &all_cte_scopes,
312 );
313 } else {
314 resolve_unqualified_column(
315 &mut node,
316 scope,
317 dialect,
318 col_name,
319 &column_name,
320 trim_selects,
321 &all_cte_scopes,
322 );
323 }
324 }
325
326 Ok(node)
327}
328
329fn handle_set_operation(
334 column: &ColumnRef<'_>,
335 scope: &Scope,
336 dialect: Option<DialectType>,
337 scope_name: &str,
338 source_name: &str,
339 reference_node_name: &str,
340 trim_selects: bool,
341 ancestor_cte_scopes: &[Scope],
342) -> Result<LineageNode> {
343 let scope_expr = &scope.expression;
344
345 let col_index = match column {
347 ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
348 ColumnRef::Index(i) => *i,
349 };
350
351 let col_name = match column {
352 ColumnRef::Name(name) => name.to_string(),
353 ColumnRef::Index(_) => format!("_{col_index}"),
354 };
355
356 let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
357 node.source_name = source_name.to_string();
358 node.reference_node_name = reference_node_name.to_string();
359
360 for branch_scope in &scope.union_scopes {
362 if let Ok(child) = to_node_inner(
363 ColumnRef::Index(col_index),
364 branch_scope,
365 dialect,
366 scope_name,
367 "",
368 "",
369 trim_selects,
370 ancestor_cte_scopes,
371 ) {
372 node.downstream.push(child);
373 }
374 }
375
376 Ok(node)
377}
378
379fn resolve_qualified_column(
384 node: &mut LineageNode,
385 scope: &Scope,
386 dialect: Option<DialectType>,
387 table: &str,
388 col_name: &str,
389 parent_name: &str,
390 trim_selects: bool,
391 all_cte_scopes: &[&Scope],
392) {
393 if scope.cte_sources.contains_key(table) {
395 if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
396 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
398 if let Ok(child) = to_node_inner(
399 ColumnRef::Name(col_name),
400 child_scope,
401 dialect,
402 parent_name,
403 table,
404 parent_name,
405 trim_selects,
406 &ancestors,
407 ) {
408 node.downstream.push(child);
409 return;
410 }
411 }
412 }
413
414 if let Some(source_info) = scope.sources.get(table) {
416 if source_info.is_scope {
417 if let Some(child_scope) = find_child_scope(scope, table) {
418 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
419 if let Ok(child) = to_node_inner(
420 ColumnRef::Name(col_name),
421 child_scope,
422 dialect,
423 parent_name,
424 table,
425 parent_name,
426 trim_selects,
427 &ancestors,
428 ) {
429 node.downstream.push(child);
430 return;
431 }
432 }
433 }
434 }
435
436 if let Some(source_info) = scope.sources.get(table) {
439 if !source_info.is_scope {
440 node.downstream.push(make_table_column_node_from_source(
441 table,
442 col_name,
443 &source_info.expression,
444 ));
445 return;
446 }
447 }
448
449 node.downstream
451 .push(make_table_column_node(table, col_name));
452}
453
454fn resolve_unqualified_column(
455 node: &mut LineageNode,
456 scope: &Scope,
457 dialect: Option<DialectType>,
458 col_name: &str,
459 parent_name: &str,
460 trim_selects: bool,
461 all_cte_scopes: &[&Scope],
462) {
463 let from_source_names = source_names_from_from_join(scope);
467
468 if from_source_names.len() == 1 {
469 let tbl = &from_source_names[0];
470 resolve_qualified_column(
471 node,
472 scope,
473 dialect,
474 tbl,
475 col_name,
476 parent_name,
477 trim_selects,
478 all_cte_scopes,
479 );
480 return;
481 }
482
483 let child = LineageNode::new(
485 col_name.to_string(),
486 Expression::Column(crate::expressions::Column {
487 name: crate::expressions::Identifier::new(col_name.to_string()),
488 table: None,
489 join_mark: false,
490 trailing_comments: vec![],
491 span: None,
492 }),
493 node.source.clone(),
494 );
495 node.downstream.push(child);
496}
497
498fn source_names_from_from_join(scope: &Scope) -> Vec<String> {
499 fn source_name(expr: &Expression) -> Option<String> {
500 match expr {
501 Expression::Table(table) => Some(
502 table
503 .alias
504 .as_ref()
505 .map(|a| a.name.clone())
506 .unwrap_or_else(|| table.name.name.clone()),
507 ),
508 Expression::Subquery(subquery) => {
509 subquery.alias.as_ref().map(|alias| alias.name.clone())
510 }
511 Expression::Paren(paren) => source_name(&paren.this),
512 _ => None,
513 }
514 }
515
516 let effective_expr = match &scope.expression {
517 Expression::Cte(cte) => &cte.this,
518 expr => expr,
519 };
520
521 let mut names = Vec::new();
522 let mut seen = std::collections::HashSet::new();
523
524 if let Expression::Select(select) = effective_expr {
525 if let Some(from) = &select.from {
526 for expr in &from.expressions {
527 if let Some(name) = source_name(expr) {
528 if !name.is_empty() && seen.insert(name.clone()) {
529 names.push(name);
530 }
531 }
532 }
533 }
534 for join in &select.joins {
535 if let Some(name) = source_name(&join.this) {
536 if !name.is_empty() && seen.insert(name.clone()) {
537 names.push(name);
538 }
539 }
540 }
541 }
542
543 names
544}
545
546fn get_alias_or_name(expr: &Expression) -> Option<String> {
552 match expr {
553 Expression::Alias(alias) => Some(alias.alias.name.clone()),
554 Expression::Column(col) => Some(col.name.name.clone()),
555 Expression::Identifier(id) => Some(id.name.clone()),
556 Expression::Star(_) => Some("*".to_string()),
557 _ => None,
558 }
559}
560
561fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
563 match column {
564 ColumnRef::Name(n) => n.to_string(),
565 ColumnRef::Index(_) => get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string()),
566 }
567}
568
569fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
571 if let Expression::Select(ref select) = scope_expr {
572 match column {
573 ColumnRef::Name(name) => {
574 for expr in &select.expressions {
575 if get_alias_or_name(expr).as_deref() == Some(name) {
576 return Ok(expr.clone());
577 }
578 }
579 Err(crate::error::Error::parse(
580 format!("Cannot find column '{}' in query", name),
581 0,
582 0,
583 0,
584 0,
585 ))
586 }
587 ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
588 crate::error::Error::parse(format!("Column index {} out of range", idx), 0, 0, 0, 0)
589 }),
590 }
591 } else {
592 Err(crate::error::Error::parse(
593 "Expected SELECT expression for column lookup",
594 0,
595 0,
596 0,
597 0,
598 ))
599 }
600}
601
602fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
604 let mut expr = set_op_expr;
605 loop {
606 match expr {
607 Expression::Union(u) => expr = &u.left,
608 Expression::Intersect(i) => expr = &i.left,
609 Expression::Except(e) => expr = &e.left,
610 Expression::Select(select) => {
611 for (i, e) in select.expressions.iter().enumerate() {
612 if get_alias_or_name(e).as_deref() == Some(name) {
613 return Ok(i);
614 }
615 }
616 return Err(crate::error::Error::parse(
617 format!("Cannot find column '{}' in set operation", name),
618 0,
619 0,
620 0,
621 0,
622 ));
623 }
624 _ => {
625 return Err(crate::error::Error::parse(
626 "Expected SELECT or set operation",
627 0,
628 0,
629 0,
630 0,
631 ))
632 }
633 }
634 }
635}
636
637fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
639 if let Expression::Select(select) = select_expr {
640 let mut trimmed = select.as_ref().clone();
641 trimmed.expressions = vec![target_expr.clone()];
642 Expression::Select(Box::new(trimmed))
643 } else {
644 select_expr.clone()
645 }
646}
647
648fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
650 if scope.cte_sources.contains_key(source_name) {
652 for cte_scope in &scope.cte_scopes {
653 if let Expression::Cte(cte) = &cte_scope.expression {
654 if cte.alias.name == source_name {
655 return Some(cte_scope);
656 }
657 }
658 }
659 }
660
661 if let Some(source_info) = scope.sources.get(source_name) {
663 if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
664 if let Expression::Subquery(sq) = &source_info.expression {
665 for dt_scope in &scope.derived_table_scopes {
666 if dt_scope.expression == sq.this {
667 return Some(dt_scope);
668 }
669 }
670 }
671 }
672 }
673
674 None
675}
676
677fn find_child_scope_in<'a>(
681 all_cte_scopes: &[&'a Scope],
682 scope: &'a Scope,
683 source_name: &str,
684) -> Option<&'a Scope> {
685 for cte_scope in &scope.cte_scopes {
687 if let Expression::Cte(cte) = &cte_scope.expression {
688 if cte.alias.name == source_name {
689 return Some(cte_scope);
690 }
691 }
692 }
693
694 for cte_scope in all_cte_scopes {
696 if let Expression::Cte(cte) = &cte_scope.expression {
697 if cte.alias.name == source_name {
698 return Some(cte_scope);
699 }
700 }
701 }
702
703 if let Some(source_info) = scope.sources.get(source_name) {
705 if source_info.is_scope {
706 if let Expression::Subquery(sq) = &source_info.expression {
707 for dt_scope in &scope.derived_table_scopes {
708 if dt_scope.expression == sq.this {
709 return Some(dt_scope);
710 }
711 }
712 }
713 }
714 }
715
716 None
717}
718
719fn make_table_column_node(table: &str, column: &str) -> LineageNode {
721 let mut node = LineageNode::new(
722 format!("{}.{}", table, column),
723 Expression::Column(crate::expressions::Column {
724 name: crate::expressions::Identifier::new(column.to_string()),
725 table: Some(crate::expressions::Identifier::new(table.to_string())),
726 join_mark: false,
727 trailing_comments: vec![],
728 span: None,
729 }),
730 Expression::Table(crate::expressions::TableRef::new(table)),
731 );
732 node.source_name = table.to_string();
733 node
734}
735
736fn table_name_from_table_ref(table_ref: &crate::expressions::TableRef) -> String {
737 let mut parts: Vec<String> = Vec::new();
738 if let Some(catalog) = &table_ref.catalog {
739 parts.push(catalog.name.clone());
740 }
741 if let Some(schema) = &table_ref.schema {
742 parts.push(schema.name.clone());
743 }
744 parts.push(table_ref.name.name.clone());
745 parts.join(".")
746}
747
748fn make_table_column_node_from_source(
749 table_alias: &str,
750 column: &str,
751 source: &Expression,
752) -> LineageNode {
753 let mut node = LineageNode::new(
754 format!("{}.{}", table_alias, column),
755 Expression::Column(crate::expressions::Column {
756 name: crate::expressions::Identifier::new(column.to_string()),
757 table: Some(crate::expressions::Identifier::new(table_alias.to_string())),
758 join_mark: false,
759 trailing_comments: vec![],
760 span: None,
761 }),
762 source.clone(),
763 );
764
765 if let Expression::Table(table_ref) = source {
766 node.source_name = table_name_from_table_ref(table_ref);
767 } else {
768 node.source_name = table_alias.to_string();
769 }
770
771 node
772}
773
774#[derive(Debug, Clone)]
776struct SimpleColumnRef {
777 table: Option<crate::expressions::Identifier>,
778 column: String,
779}
780
781fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
783 let mut refs = Vec::new();
784 collect_column_refs(expr, &mut refs);
785 refs
786}
787
788fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
789 let mut stack: Vec<&Expression> = vec![expr];
790
791 while let Some(current) = stack.pop() {
792 match current {
793 Expression::Column(col) => {
795 refs.push(SimpleColumnRef {
796 table: col.table.clone(),
797 column: col.name.name.clone(),
798 });
799 }
800
801 Expression::Subquery(_) | Expression::Exists(_) => {}
803
804 Expression::And(op)
806 | Expression::Or(op)
807 | Expression::Eq(op)
808 | Expression::Neq(op)
809 | Expression::Lt(op)
810 | Expression::Lte(op)
811 | Expression::Gt(op)
812 | Expression::Gte(op)
813 | Expression::Add(op)
814 | Expression::Sub(op)
815 | Expression::Mul(op)
816 | Expression::Div(op)
817 | Expression::Mod(op)
818 | Expression::BitwiseAnd(op)
819 | Expression::BitwiseOr(op)
820 | Expression::BitwiseXor(op)
821 | Expression::BitwiseLeftShift(op)
822 | Expression::BitwiseRightShift(op)
823 | Expression::Concat(op)
824 | Expression::Adjacent(op)
825 | Expression::TsMatch(op)
826 | Expression::PropertyEQ(op)
827 | Expression::ArrayContainsAll(op)
828 | Expression::ArrayContainedBy(op)
829 | Expression::ArrayOverlaps(op)
830 | Expression::JSONBContainsAllTopKeys(op)
831 | Expression::JSONBContainsAnyTopKeys(op)
832 | Expression::JSONBDeleteAtPath(op)
833 | Expression::ExtendsLeft(op)
834 | Expression::ExtendsRight(op)
835 | Expression::Is(op)
836 | Expression::MemberOf(op)
837 | Expression::NullSafeEq(op)
838 | Expression::NullSafeNeq(op)
839 | Expression::Glob(op)
840 | Expression::Match(op) => {
841 stack.push(&op.left);
842 stack.push(&op.right);
843 }
844
845 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
847 stack.push(&u.this);
848 }
849
850 Expression::Upper(f)
852 | Expression::Lower(f)
853 | Expression::Length(f)
854 | Expression::LTrim(f)
855 | Expression::RTrim(f)
856 | Expression::Reverse(f)
857 | Expression::Abs(f)
858 | Expression::Sqrt(f)
859 | Expression::Cbrt(f)
860 | Expression::Ln(f)
861 | Expression::Exp(f)
862 | Expression::Sign(f)
863 | Expression::Date(f)
864 | Expression::Time(f)
865 | Expression::DateFromUnixDate(f)
866 | Expression::UnixDate(f)
867 | Expression::UnixSeconds(f)
868 | Expression::UnixMillis(f)
869 | Expression::UnixMicros(f)
870 | Expression::TimeStrToDate(f)
871 | Expression::DateToDi(f)
872 | Expression::DiToDate(f)
873 | Expression::TsOrDiToDi(f)
874 | Expression::TsOrDsToDatetime(f)
875 | Expression::TsOrDsToTimestamp(f)
876 | Expression::YearOfWeek(f)
877 | Expression::YearOfWeekIso(f)
878 | Expression::Initcap(f)
879 | Expression::Ascii(f)
880 | Expression::Chr(f)
881 | Expression::Soundex(f)
882 | Expression::ByteLength(f)
883 | Expression::Hex(f)
884 | Expression::LowerHex(f)
885 | Expression::Unicode(f)
886 | Expression::Radians(f)
887 | Expression::Degrees(f)
888 | Expression::Sin(f)
889 | Expression::Cos(f)
890 | Expression::Tan(f)
891 | Expression::Asin(f)
892 | Expression::Acos(f)
893 | Expression::Atan(f)
894 | Expression::IsNan(f)
895 | Expression::IsInf(f)
896 | Expression::ArrayLength(f)
897 | Expression::ArraySize(f)
898 | Expression::Cardinality(f)
899 | Expression::ArrayReverse(f)
900 | Expression::ArrayDistinct(f)
901 | Expression::ArrayFlatten(f)
902 | Expression::ArrayCompact(f)
903 | Expression::Explode(f)
904 | Expression::ExplodeOuter(f)
905 | Expression::ToArray(f)
906 | Expression::MapFromEntries(f)
907 | Expression::MapKeys(f)
908 | Expression::MapValues(f)
909 | Expression::JsonArrayLength(f)
910 | Expression::JsonKeys(f)
911 | Expression::JsonType(f)
912 | Expression::ParseJson(f)
913 | Expression::ToJson(f)
914 | Expression::Typeof(f)
915 | Expression::BitwiseCount(f)
916 | Expression::Year(f)
917 | Expression::Month(f)
918 | Expression::Day(f)
919 | Expression::Hour(f)
920 | Expression::Minute(f)
921 | Expression::Second(f)
922 | Expression::DayOfWeek(f)
923 | Expression::DayOfWeekIso(f)
924 | Expression::DayOfMonth(f)
925 | Expression::DayOfYear(f)
926 | Expression::WeekOfYear(f)
927 | Expression::Quarter(f)
928 | Expression::Epoch(f)
929 | Expression::EpochMs(f)
930 | Expression::TimeStrToUnix(f)
931 | Expression::SHA(f)
932 | Expression::SHA1Digest(f)
933 | Expression::TimeToUnix(f)
934 | Expression::JSONBool(f)
935 | Expression::Int64(f)
936 | Expression::MD5NumberLower64(f)
937 | Expression::MD5NumberUpper64(f)
938 | Expression::DateStrToDate(f)
939 | Expression::DateToDateStr(f) => {
940 stack.push(&f.this);
941 }
942
943 Expression::Power(f)
945 | Expression::NullIf(f)
946 | Expression::IfNull(f)
947 | Expression::Nvl(f)
948 | Expression::UnixToTimeStr(f)
949 | Expression::Contains(f)
950 | Expression::StartsWith(f)
951 | Expression::EndsWith(f)
952 | Expression::Levenshtein(f)
953 | Expression::ModFunc(f)
954 | Expression::Atan2(f)
955 | Expression::IntDiv(f)
956 | Expression::AddMonths(f)
957 | Expression::MonthsBetween(f)
958 | Expression::NextDay(f)
959 | Expression::ArrayContains(f)
960 | Expression::ArrayPosition(f)
961 | Expression::ArrayAppend(f)
962 | Expression::ArrayPrepend(f)
963 | Expression::ArrayUnion(f)
964 | Expression::ArrayExcept(f)
965 | Expression::ArrayRemove(f)
966 | Expression::StarMap(f)
967 | Expression::MapFromArrays(f)
968 | Expression::MapContainsKey(f)
969 | Expression::ElementAt(f)
970 | Expression::JsonMergePatch(f)
971 | Expression::JSONBContains(f)
972 | Expression::JSONBExtract(f) => {
973 stack.push(&f.this);
974 stack.push(&f.expression);
975 }
976
977 Expression::Greatest(f)
979 | Expression::Least(f)
980 | Expression::Coalesce(f)
981 | Expression::ArrayConcat(f)
982 | Expression::ArrayIntersect(f)
983 | Expression::ArrayZip(f)
984 | Expression::MapConcat(f)
985 | Expression::JsonArray(f) => {
986 for e in &f.expressions {
987 stack.push(e);
988 }
989 }
990
991 Expression::Sum(f)
993 | Expression::Avg(f)
994 | Expression::Min(f)
995 | Expression::Max(f)
996 | Expression::ArrayAgg(f)
997 | Expression::CountIf(f)
998 | Expression::Stddev(f)
999 | Expression::StddevPop(f)
1000 | Expression::StddevSamp(f)
1001 | Expression::Variance(f)
1002 | Expression::VarPop(f)
1003 | Expression::VarSamp(f)
1004 | Expression::Median(f)
1005 | Expression::Mode(f)
1006 | Expression::First(f)
1007 | Expression::Last(f)
1008 | Expression::AnyValue(f)
1009 | Expression::ApproxDistinct(f)
1010 | Expression::ApproxCountDistinct(f)
1011 | Expression::LogicalAnd(f)
1012 | Expression::LogicalOr(f)
1013 | Expression::Skewness(f)
1014 | Expression::ArrayConcatAgg(f)
1015 | Expression::ArrayUniqueAgg(f)
1016 | Expression::BoolXorAgg(f)
1017 | Expression::BitwiseAndAgg(f)
1018 | Expression::BitwiseOrAgg(f)
1019 | Expression::BitwiseXorAgg(f) => {
1020 stack.push(&f.this);
1021 if let Some(ref filter) = f.filter {
1022 stack.push(filter);
1023 }
1024 if let Some((ref expr, _)) = f.having_max {
1025 stack.push(expr);
1026 }
1027 if let Some(ref limit) = f.limit {
1028 stack.push(limit);
1029 }
1030 }
1031
1032 Expression::Function(func) => {
1034 for arg in &func.args {
1035 stack.push(arg);
1036 }
1037 }
1038 Expression::AggregateFunction(func) => {
1039 for arg in &func.args {
1040 stack.push(arg);
1041 }
1042 if let Some(ref filter) = func.filter {
1043 stack.push(filter);
1044 }
1045 if let Some(ref limit) = func.limit {
1046 stack.push(limit);
1047 }
1048 }
1049
1050 Expression::WindowFunction(wf) => {
1052 stack.push(&wf.this);
1053 }
1054
1055 Expression::Alias(a) => {
1057 stack.push(&a.this);
1058 }
1059 Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
1060 stack.push(&c.this);
1061 if let Some(ref fmt) = c.format {
1062 stack.push(fmt);
1063 }
1064 if let Some(ref def) = c.default {
1065 stack.push(def);
1066 }
1067 }
1068 Expression::Paren(p) => {
1069 stack.push(&p.this);
1070 }
1071 Expression::Annotated(a) => {
1072 stack.push(&a.this);
1073 }
1074 Expression::Case(case) => {
1075 if let Some(ref operand) = case.operand {
1076 stack.push(operand);
1077 }
1078 for (cond, result) in &case.whens {
1079 stack.push(cond);
1080 stack.push(result);
1081 }
1082 if let Some(ref else_expr) = case.else_ {
1083 stack.push(else_expr);
1084 }
1085 }
1086 Expression::Collation(c) => {
1087 stack.push(&c.this);
1088 }
1089 Expression::In(i) => {
1090 stack.push(&i.this);
1091 for e in &i.expressions {
1092 stack.push(e);
1093 }
1094 if let Some(ref q) = i.query {
1095 stack.push(q);
1096 }
1097 if let Some(ref u) = i.unnest {
1098 stack.push(u);
1099 }
1100 }
1101 Expression::Between(b) => {
1102 stack.push(&b.this);
1103 stack.push(&b.low);
1104 stack.push(&b.high);
1105 }
1106 Expression::IsNull(n) => {
1107 stack.push(&n.this);
1108 }
1109 Expression::IsTrue(t) | Expression::IsFalse(t) => {
1110 stack.push(&t.this);
1111 }
1112 Expression::IsJson(j) => {
1113 stack.push(&j.this);
1114 }
1115 Expression::Like(l) | Expression::ILike(l) => {
1116 stack.push(&l.left);
1117 stack.push(&l.right);
1118 if let Some(ref esc) = l.escape {
1119 stack.push(esc);
1120 }
1121 }
1122 Expression::SimilarTo(s) => {
1123 stack.push(&s.this);
1124 stack.push(&s.pattern);
1125 if let Some(ref esc) = s.escape {
1126 stack.push(esc);
1127 }
1128 }
1129 Expression::Ordered(o) => {
1130 stack.push(&o.this);
1131 }
1132 Expression::Array(a) => {
1133 for e in &a.expressions {
1134 stack.push(e);
1135 }
1136 }
1137 Expression::Tuple(t) => {
1138 for e in &t.expressions {
1139 stack.push(e);
1140 }
1141 }
1142 Expression::Struct(s) => {
1143 for (_, e) in &s.fields {
1144 stack.push(e);
1145 }
1146 }
1147 Expression::Subscript(s) => {
1148 stack.push(&s.this);
1149 stack.push(&s.index);
1150 }
1151 Expression::Dot(d) => {
1152 stack.push(&d.this);
1153 }
1154 Expression::MethodCall(m) => {
1155 stack.push(&m.this);
1156 for arg in &m.args {
1157 stack.push(arg);
1158 }
1159 }
1160 Expression::ArraySlice(s) => {
1161 stack.push(&s.this);
1162 if let Some(ref start) = s.start {
1163 stack.push(start);
1164 }
1165 if let Some(ref end) = s.end {
1166 stack.push(end);
1167 }
1168 }
1169 Expression::Lambda(l) => {
1170 stack.push(&l.body);
1171 }
1172 Expression::NamedArgument(n) => {
1173 stack.push(&n.value);
1174 }
1175 Expression::BracedWildcard(e) | Expression::ReturnStmt(e) => {
1176 stack.push(e);
1177 }
1178
1179 Expression::Substring(f) => {
1181 stack.push(&f.this);
1182 stack.push(&f.start);
1183 if let Some(ref len) = f.length {
1184 stack.push(len);
1185 }
1186 }
1187 Expression::Trim(f) => {
1188 stack.push(&f.this);
1189 if let Some(ref chars) = f.characters {
1190 stack.push(chars);
1191 }
1192 }
1193 Expression::Replace(f) => {
1194 stack.push(&f.this);
1195 stack.push(&f.old);
1196 stack.push(&f.new);
1197 }
1198 Expression::IfFunc(f) => {
1199 stack.push(&f.condition);
1200 stack.push(&f.true_value);
1201 if let Some(ref fv) = f.false_value {
1202 stack.push(fv);
1203 }
1204 }
1205 Expression::Nvl2(f) => {
1206 stack.push(&f.this);
1207 stack.push(&f.true_value);
1208 stack.push(&f.false_value);
1209 }
1210 Expression::ConcatWs(f) => {
1211 stack.push(&f.separator);
1212 for e in &f.expressions {
1213 stack.push(e);
1214 }
1215 }
1216 Expression::Count(f) => {
1217 if let Some(ref this) = f.this {
1218 stack.push(this);
1219 }
1220 if let Some(ref filter) = f.filter {
1221 stack.push(filter);
1222 }
1223 }
1224 Expression::GroupConcat(f) => {
1225 stack.push(&f.this);
1226 if let Some(ref sep) = f.separator {
1227 stack.push(sep);
1228 }
1229 if let Some(ref filter) = f.filter {
1230 stack.push(filter);
1231 }
1232 }
1233 Expression::StringAgg(f) => {
1234 stack.push(&f.this);
1235 if let Some(ref sep) = f.separator {
1236 stack.push(sep);
1237 }
1238 if let Some(ref filter) = f.filter {
1239 stack.push(filter);
1240 }
1241 if let Some(ref limit) = f.limit {
1242 stack.push(limit);
1243 }
1244 }
1245 Expression::ListAgg(f) => {
1246 stack.push(&f.this);
1247 if let Some(ref sep) = f.separator {
1248 stack.push(sep);
1249 }
1250 if let Some(ref filter) = f.filter {
1251 stack.push(filter);
1252 }
1253 }
1254 Expression::SumIf(f) => {
1255 stack.push(&f.this);
1256 stack.push(&f.condition);
1257 if let Some(ref filter) = f.filter {
1258 stack.push(filter);
1259 }
1260 }
1261 Expression::DateAdd(f) | Expression::DateSub(f) => {
1262 stack.push(&f.this);
1263 stack.push(&f.interval);
1264 }
1265 Expression::DateDiff(f) => {
1266 stack.push(&f.this);
1267 stack.push(&f.expression);
1268 }
1269 Expression::DateTrunc(f) | Expression::TimestampTrunc(f) => {
1270 stack.push(&f.this);
1271 }
1272 Expression::Extract(f) => {
1273 stack.push(&f.this);
1274 }
1275 Expression::Round(f) => {
1276 stack.push(&f.this);
1277 if let Some(ref d) = f.decimals {
1278 stack.push(d);
1279 }
1280 }
1281 Expression::Floor(f) => {
1282 stack.push(&f.this);
1283 if let Some(ref s) = f.scale {
1284 stack.push(s);
1285 }
1286 if let Some(ref t) = f.to {
1287 stack.push(t);
1288 }
1289 }
1290 Expression::Ceil(f) => {
1291 stack.push(&f.this);
1292 if let Some(ref d) = f.decimals {
1293 stack.push(d);
1294 }
1295 if let Some(ref t) = f.to {
1296 stack.push(t);
1297 }
1298 }
1299 Expression::Log(f) => {
1300 stack.push(&f.this);
1301 if let Some(ref b) = f.base {
1302 stack.push(b);
1303 }
1304 }
1305 Expression::AtTimeZone(f) => {
1306 stack.push(&f.this);
1307 stack.push(&f.zone);
1308 }
1309 Expression::Lead(f) | Expression::Lag(f) => {
1310 stack.push(&f.this);
1311 if let Some(ref off) = f.offset {
1312 stack.push(off);
1313 }
1314 if let Some(ref def) = f.default {
1315 stack.push(def);
1316 }
1317 }
1318 Expression::FirstValue(f) | Expression::LastValue(f) => {
1319 stack.push(&f.this);
1320 }
1321 Expression::NthValue(f) => {
1322 stack.push(&f.this);
1323 stack.push(&f.offset);
1324 }
1325 Expression::Position(f) => {
1326 stack.push(&f.substring);
1327 stack.push(&f.string);
1328 if let Some(ref start) = f.start {
1329 stack.push(start);
1330 }
1331 }
1332 Expression::Decode(f) => {
1333 stack.push(&f.this);
1334 for (search, result) in &f.search_results {
1335 stack.push(search);
1336 stack.push(result);
1337 }
1338 if let Some(ref def) = f.default {
1339 stack.push(def);
1340 }
1341 }
1342 Expression::CharFunc(f) => {
1343 for arg in &f.args {
1344 stack.push(arg);
1345 }
1346 }
1347 Expression::ArraySort(f) => {
1348 stack.push(&f.this);
1349 if let Some(ref cmp) = f.comparator {
1350 stack.push(cmp);
1351 }
1352 }
1353 Expression::ArrayJoin(f) | Expression::ArrayToString(f) => {
1354 stack.push(&f.this);
1355 stack.push(&f.separator);
1356 if let Some(ref nr) = f.null_replacement {
1357 stack.push(nr);
1358 }
1359 }
1360 Expression::ArrayFilter(f) => {
1361 stack.push(&f.this);
1362 stack.push(&f.filter);
1363 }
1364 Expression::ArrayTransform(f) => {
1365 stack.push(&f.this);
1366 stack.push(&f.transform);
1367 }
1368 Expression::Sequence(f)
1369 | Expression::Generate(f)
1370 | Expression::ExplodingGenerateSeries(f) => {
1371 stack.push(&f.start);
1372 stack.push(&f.stop);
1373 if let Some(ref step) = f.step {
1374 stack.push(step);
1375 }
1376 }
1377 Expression::JsonExtract(f)
1378 | Expression::JsonExtractScalar(f)
1379 | Expression::JsonQuery(f)
1380 | Expression::JsonValue(f) => {
1381 stack.push(&f.this);
1382 stack.push(&f.path);
1383 }
1384 Expression::JsonExtractPath(f) | Expression::JsonRemove(f) => {
1385 stack.push(&f.this);
1386 for p in &f.paths {
1387 stack.push(p);
1388 }
1389 }
1390 Expression::JsonObject(f) => {
1391 for (k, v) in &f.pairs {
1392 stack.push(k);
1393 stack.push(v);
1394 }
1395 }
1396 Expression::JsonSet(f) | Expression::JsonInsert(f) => {
1397 stack.push(&f.this);
1398 for (path, val) in &f.path_values {
1399 stack.push(path);
1400 stack.push(val);
1401 }
1402 }
1403 Expression::Overlay(f) => {
1404 stack.push(&f.this);
1405 stack.push(&f.replacement);
1406 stack.push(&f.from);
1407 if let Some(ref len) = f.length {
1408 stack.push(len);
1409 }
1410 }
1411 Expression::Convert(f) => {
1412 stack.push(&f.this);
1413 if let Some(ref style) = f.style {
1414 stack.push(style);
1415 }
1416 }
1417 Expression::ApproxPercentile(f) => {
1418 stack.push(&f.this);
1419 stack.push(&f.percentile);
1420 if let Some(ref acc) = f.accuracy {
1421 stack.push(acc);
1422 }
1423 if let Some(ref filter) = f.filter {
1424 stack.push(filter);
1425 }
1426 }
1427 Expression::Percentile(f)
1428 | Expression::PercentileCont(f)
1429 | Expression::PercentileDisc(f) => {
1430 stack.push(&f.this);
1431 stack.push(&f.percentile);
1432 if let Some(ref filter) = f.filter {
1433 stack.push(filter);
1434 }
1435 }
1436 Expression::WithinGroup(f) => {
1437 stack.push(&f.this);
1438 }
1439 Expression::Left(f) | Expression::Right(f) => {
1440 stack.push(&f.this);
1441 stack.push(&f.length);
1442 }
1443 Expression::Repeat(f) => {
1444 stack.push(&f.this);
1445 stack.push(&f.times);
1446 }
1447 Expression::Lpad(f) | Expression::Rpad(f) => {
1448 stack.push(&f.this);
1449 stack.push(&f.length);
1450 if let Some(ref fill) = f.fill {
1451 stack.push(fill);
1452 }
1453 }
1454 Expression::Split(f) => {
1455 stack.push(&f.this);
1456 stack.push(&f.delimiter);
1457 }
1458 Expression::RegexpLike(f) => {
1459 stack.push(&f.this);
1460 stack.push(&f.pattern);
1461 if let Some(ref flags) = f.flags {
1462 stack.push(flags);
1463 }
1464 }
1465 Expression::RegexpReplace(f) => {
1466 stack.push(&f.this);
1467 stack.push(&f.pattern);
1468 stack.push(&f.replacement);
1469 if let Some(ref flags) = f.flags {
1470 stack.push(flags);
1471 }
1472 }
1473 Expression::RegexpExtract(f) => {
1474 stack.push(&f.this);
1475 stack.push(&f.pattern);
1476 if let Some(ref group) = f.group {
1477 stack.push(group);
1478 }
1479 }
1480 Expression::ToDate(f) => {
1481 stack.push(&f.this);
1482 if let Some(ref fmt) = f.format {
1483 stack.push(fmt);
1484 }
1485 }
1486 Expression::ToTimestamp(f) => {
1487 stack.push(&f.this);
1488 if let Some(ref fmt) = f.format {
1489 stack.push(fmt);
1490 }
1491 }
1492 Expression::DateFormat(f) | Expression::FormatDate(f) => {
1493 stack.push(&f.this);
1494 stack.push(&f.format);
1495 }
1496 Expression::LastDay(f) => {
1497 stack.push(&f.this);
1498 }
1499 Expression::FromUnixtime(f) => {
1500 stack.push(&f.this);
1501 if let Some(ref fmt) = f.format {
1502 stack.push(fmt);
1503 }
1504 }
1505 Expression::UnixTimestamp(f) => {
1506 if let Some(ref this) = f.this {
1507 stack.push(this);
1508 }
1509 if let Some(ref fmt) = f.format {
1510 stack.push(fmt);
1511 }
1512 }
1513 Expression::MakeDate(f) => {
1514 stack.push(&f.year);
1515 stack.push(&f.month);
1516 stack.push(&f.day);
1517 }
1518 Expression::MakeTimestamp(f) => {
1519 stack.push(&f.year);
1520 stack.push(&f.month);
1521 stack.push(&f.day);
1522 stack.push(&f.hour);
1523 stack.push(&f.minute);
1524 stack.push(&f.second);
1525 if let Some(ref tz) = f.timezone {
1526 stack.push(tz);
1527 }
1528 }
1529 Expression::TruncFunc(f) => {
1530 stack.push(&f.this);
1531 if let Some(ref d) = f.decimals {
1532 stack.push(d);
1533 }
1534 }
1535 Expression::ArrayFunc(f) => {
1536 for e in &f.expressions {
1537 stack.push(e);
1538 }
1539 }
1540 Expression::Unnest(f) => {
1541 stack.push(&f.this);
1542 for e in &f.expressions {
1543 stack.push(e);
1544 }
1545 }
1546 Expression::StructFunc(f) => {
1547 for (_, e) in &f.fields {
1548 stack.push(e);
1549 }
1550 }
1551 Expression::StructExtract(f) => {
1552 stack.push(&f.this);
1553 }
1554 Expression::NamedStruct(f) => {
1555 for (k, v) in &f.pairs {
1556 stack.push(k);
1557 stack.push(v);
1558 }
1559 }
1560 Expression::MapFunc(f) => {
1561 for k in &f.keys {
1562 stack.push(k);
1563 }
1564 for v in &f.values {
1565 stack.push(v);
1566 }
1567 }
1568 Expression::TransformKeys(f) | Expression::TransformValues(f) => {
1569 stack.push(&f.this);
1570 stack.push(&f.transform);
1571 }
1572 Expression::JsonArrayAgg(f) => {
1573 stack.push(&f.this);
1574 if let Some(ref filter) = f.filter {
1575 stack.push(filter);
1576 }
1577 }
1578 Expression::JsonObjectAgg(f) => {
1579 stack.push(&f.key);
1580 stack.push(&f.value);
1581 if let Some(ref filter) = f.filter {
1582 stack.push(filter);
1583 }
1584 }
1585 Expression::NTile(f) => {
1586 if let Some(ref n) = f.num_buckets {
1587 stack.push(n);
1588 }
1589 }
1590 Expression::Rand(f) => {
1591 if let Some(ref s) = f.seed {
1592 stack.push(s);
1593 }
1594 if let Some(ref lo) = f.lower {
1595 stack.push(lo);
1596 }
1597 if let Some(ref hi) = f.upper {
1598 stack.push(hi);
1599 }
1600 }
1601 Expression::Any(q) | Expression::All(q) => {
1602 stack.push(&q.this);
1603 stack.push(&q.subquery);
1604 }
1605 Expression::Overlaps(o) => {
1606 if let Some(ref this) = o.this {
1607 stack.push(this);
1608 }
1609 if let Some(ref expr) = o.expression {
1610 stack.push(expr);
1611 }
1612 if let Some(ref ls) = o.left_start {
1613 stack.push(ls);
1614 }
1615 if let Some(ref le) = o.left_end {
1616 stack.push(le);
1617 }
1618 if let Some(ref rs) = o.right_start {
1619 stack.push(rs);
1620 }
1621 if let Some(ref re) = o.right_end {
1622 stack.push(re);
1623 }
1624 }
1625 Expression::Interval(i) => {
1626 if let Some(ref this) = i.this {
1627 stack.push(this);
1628 }
1629 }
1630 Expression::TimeStrToTime(f) => {
1631 stack.push(&f.this);
1632 if let Some(ref zone) = f.zone {
1633 stack.push(zone);
1634 }
1635 }
1636 Expression::JSONBExtractScalar(f) => {
1637 stack.push(&f.this);
1638 stack.push(&f.expression);
1639 if let Some(ref jt) = f.json_type {
1640 stack.push(jt);
1641 }
1642 }
1643
1644 _ => {}
1649 }
1650 }
1651}
1652
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the requested dialect and returns the first statement.
    ///
    /// Panics when parsing fails or when no statement is produced.
    fn parse_with(dialect_type: DialectType, sql: &str) -> Expression {
        let dialect = Dialect::get(dialect_type);
        let statements = dialect.parse(sql).unwrap();
        statements.into_iter().next().unwrap()
    }

    /// Parses `sql` with the generic dialect and returns the first statement.
    fn parse(sql: &str) -> Expression {
        parse_with(DialectType::Generic, sql)
    }

    /// Parses `sql` and computes the lineage of `column` with no dialect and
    /// `trim_selects` disabled, panicking if the lineage call fails.
    fn lineage_of(column: &str, sql: &str) -> LineageNode {
        let statement = parse(sql);
        lineage(column, &statement, None, false).unwrap()
    }

    /// Collects the name of every node yielded by `LineageNode::walk`, in walk order.
    fn walked_names(node: &LineageNode) -> Vec<String> {
        node.walk().map(|visited| visited.name.clone()).collect()
    }

    /// Returns `true` when `names` holds an entry exactly equal to `wanted`.
    fn has_name(names: &[String], wanted: &str) -> bool {
        names.iter().any(|candidate| candidate == wanted)
    }

    /// A bare column selected from a single table resolves to `t.a`.
    #[test]
    fn test_simple_lineage() {
        let node = lineage_of("a", "SELECT a FROM t");

        assert_eq!(node.name, "a");
        assert!(!node.downstream.is_empty(), "Should have downstream nodes");

        let names = node.downstream_names();
        assert!(
            has_name(&names, "t.a"),
            "Expected t.a in downstream, got: {:?}",
            names
        );
    }

    /// `walk` yields the root first and then its downstream child.
    #[test]
    fn test_lineage_walk() {
        let null = || Expression::Null(crate::expressions::Null);

        // `LineageNode::new` leaves `downstream` empty and both name fields blank,
        // so attaching the child afterwards builds the same two-node tree.
        let mut root = LineageNode::new("col_a", null(), null());
        root.downstream.push(LineageNode::new("t.a", null(), null()));

        let names = walked_names(&root);
        assert_eq!(names.len(), 2);
        assert_eq!(names[0], "col_a");
        assert_eq!(names[1], "t.a");
    }

    /// An aliased arithmetic expression still traces to a node mentioning `a`.
    #[test]
    fn test_aliased_column() {
        let node = lineage_of("b", "SELECT a + 1 AS b FROM t");

        assert_eq!(node.name, "b");

        let all_names = walked_names(&node);
        assert!(
            all_names.iter().any(|n| n.contains("a")),
            "Expected to trace to column a, got: {:?}",
            all_names
        );
    }

    /// A table-qualified column resolves to `t.a`.
    #[test]
    fn test_qualified_column() {
        let node = lineage_of("a", "SELECT t.a FROM t");

        assert_eq!(node.name, "a");

        let names = node.downstream_names();
        assert!(
            has_name(&names, "t.a"),
            "Expected t.a, got: {:?}",
            names
        );
    }

    /// An unqualified column resolves to `t.a`.
    #[test]
    fn test_unqualified_column() {
        let node = lineage_of("a", "SELECT a FROM t");

        let names = node.downstream_names();
        assert!(
            has_name(&names, "t.a"),
            "Expected t.a, got: {:?}",
            names
        );
    }

    /// Columns selected from both sides of a join resolve to their own tables.
    #[test]
    fn test_lineage_join() {
        let statement = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        let names_a = lineage("a", &statement, None, false)
            .unwrap()
            .downstream_names();
        assert!(
            has_name(&names_a, "t.a"),
            "Expected t.a, got: {:?}",
            names_a
        );

        let names_b = lineage("b", &statement, None, false)
            .unwrap()
            .downstream_names();
        assert!(
            has_name(&names_b, "s.b"),
            "Expected s.b, got: {:?}",
            names_b
        );
    }

    /// A leaf reached through a table alias keeps the alias in its name while
    /// `source_name` and `source` point at the underlying table.
    #[test]
    fn test_lineage_alias_leaf_has_resolved_source_name() {
        let node = lineage_of(
            "col1",
            "SELECT t1.col1 FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id",
        );

        let names = node.downstream_names();
        assert!(
            has_name(&names, "t1.col1"),
            "Expected aliased column edge t1.col1, got: {:?}",
            names
        );

        let leaf = node
            .downstream
            .iter()
            .find(|candidate| candidate.name == "t1.col1")
            .expect("Expected t1.col1 leaf");
        assert_eq!(leaf.source_name, "table1");

        if let Expression::Table(table) = &leaf.source {
            assert_eq!(table.name.name, "table1");
        } else {
            panic!("Expected leaf source to be a table expression");
        }
    }

    /// Lineage passes through a derived table down to `t.a`.
    #[test]
    fn test_lineage_derived_table() {
        let node = lineage_of("a", "SELECT x.a FROM (SELECT a FROM t) AS x");

        assert_eq!(node.name, "a");

        let all_names = walked_names(&node);
        assert!(
            has_name(&all_names, "t.a"),
            "Expected to trace through derived table to t.a, got: {:?}",
            all_names
        );
    }

    /// Lineage passes through a CTE down to `t.a`.
    #[test]
    fn test_lineage_cte() {
        let node = lineage_of("a", "WITH cte AS (SELECT a FROM t) SELECT a FROM cte");

        assert_eq!(node.name, "a");

        let all_names = walked_names(&node);
        assert!(
            has_name(&all_names, "t.a"),
            "Expected to trace through CTE to t.a, got: {:?}",
            all_names
        );
    }

    /// A UNION produces one downstream branch per side.
    #[test]
    fn test_lineage_union() {
        let node = lineage_of("a", "SELECT a FROM t1 UNION SELECT a FROM t2");

        assert_eq!(node.name, "a");

        let branches = node.downstream.len();
        assert_eq!(
            branches,
            2,
            "Expected 2 branches for UNION, got {}",
            branches
        );
    }

    /// A CTE whose body is a UNION yields at least three nodes in the walk.
    #[test]
    fn test_lineage_cte_union() {
        let node = lineage_of(
            "a",
            "WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte",
        );

        let all_names = walked_names(&node);
        assert!(
            all_names.len() >= 3,
            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
            all_names
        );
    }

    /// Asking for `*` yields a node named `*` with downstream entries.
    #[test]
    fn test_lineage_star() {
        let node = lineage_of("*", "SELECT * FROM t");

        assert_eq!(node.name, "*");
        assert!(
            !node.downstream.is_empty(),
            "Star should produce downstream nodes"
        );
    }

    /// A scalar subquery in the select list is traced into.
    #[test]
    fn test_lineage_subquery_in_select() {
        let node = lineage_of("x", "SELECT (SELECT MAX(b) FROM s) AS x FROM t");

        assert_eq!(node.name, "x");

        let all_names = walked_names(&node);
        assert!(
            all_names.len() >= 2,
            "Expected tracing into scalar subquery, got: {:?}",
            all_names
        );
    }

    /// Each selected column can be traced independently from the same statement.
    #[test]
    fn test_lineage_multiple_columns() {
        let statement = parse("SELECT a, b FROM t");

        let node_a = lineage("a", &statement, None, false).unwrap();
        let node_b = lineage("b", &statement, None, false).unwrap();

        assert_eq!(node_a.name, "a");
        assert_eq!(node_b.name, "b");

        let names_a = node_a.downstream_names();
        let names_b = node_b.downstream_names();
        assert!(has_name(&names_a, "t.a"));
        assert!(has_name(&names_b, "t.b"));
    }

    /// `get_source_tables` reports the table feeding the traced column.
    #[test]
    fn test_get_source_tables() {
        let node = lineage_of("a", "SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        let tables = get_source_tables(&node);
        assert!(
            tables.contains("t"),
            "Expected source table 't', got: {:?}",
            tables
        );
    }

    /// Requesting a column that is not selected is an error.
    #[test]
    fn test_lineage_column_not_found() {
        let statement = parse("SELECT a FROM t");
        assert!(lineage("nonexistent", &statement, None, false).is_err());
    }

    /// Chained CTEs yield at least three nodes in the walk.
    #[test]
    fn test_lineage_nested_cte() {
        let node = lineage_of(
            "a",
            "WITH cte1 AS (SELECT a FROM t), \
             cte2 AS (SELECT a FROM cte1) \
             SELECT a FROM cte2",
        );

        let all_names = walked_names(&node);
        assert!(
            all_names.len() >= 3,
            "Expected to trace through nested CTEs, got: {:?}",
            all_names
        );
    }

    /// With `trim_selects` enabled the node's source select keeps one expression.
    #[test]
    fn test_trim_selects_true() {
        let statement = parse("SELECT a, b, c FROM t");
        let node = lineage("a", &statement, None, true).unwrap();

        match &node.source {
            Expression::Select(select) => assert_eq!(
                select.expressions.len(),
                1,
                "Trimmed source should have 1 expression, got {}",
                select.expressions.len()
            ),
            _ => panic!("Expected Select source"),
        }
    }

    /// With `trim_selects` disabled the node's source select keeps all three expressions.
    #[test]
    fn test_trim_selects_false() {
        let node = lineage_of("a", "SELECT a, b, c FROM t");

        match &node.source {
            Expression::Select(select) => assert_eq!(
                select.expressions.len(),
                3,
                "Untrimmed source should have 3 expressions"
            ),
            _ => panic!("Expected Select source"),
        }
    }

    /// A binary expression yields at least three nodes in the walk.
    #[test]
    fn test_lineage_expression_in_select() {
        let node = lineage_of("c", "SELECT a + b AS c FROM t");

        let all_names = walked_names(&node);
        assert!(
            all_names.len() >= 3,
            "Expected to trace a + b to both columns, got: {:?}",
            all_names
        );
    }

    /// A UNION whose sides name their column differently still yields two branches.
    #[test]
    fn test_set_operation_by_index() {
        let node = lineage_of("a", "SELECT a FROM t1 UNION SELECT b FROM t2");

        assert_eq!(node.downstream.len(), 2);
    }

    /// Prints `node` and all of its descendants in pre-order, padding each line
    /// according to its depth (starting from `indent`).
    fn print_node(node: &LineageNode, indent: usize) {
        let mut pending = vec![(node, indent)];
        while let Some((current, depth)) = pending.pop() {
            let pad = " ".repeat(depth);
            println!(
                "{pad}name={:?} source_name={:?}",
                current.name, current.source_name
            );
            // Push children in reverse so the first child is popped (printed) first.
            pending.extend(
                current
                    .downstream
                    .iter()
                    .rev()
                    .map(|child| (child, depth + 1)),
            );
        }
    }

    /// Under the BigQuery dialect, `UPPER(name)` traces to `users.name`.
    #[test]
    fn test_issue18_repro() {
        let query = "SELECT UPPER(name) as upper_name FROM users";
        println!("Query: {query}\n");

        let statement = parse_with(DialectType::BigQuery, query);
        let node = lineage(
            "upper_name",
            &statement,
            Some(DialectType::BigQuery),
            false,
        )
        .unwrap();
        println!("lineage(\"upper_name\"):");
        print_node(&node, 1);

        let names = node.downstream_names();
        assert!(
            has_name(&names, "users.name"),
            "Expected users.name in downstream, got: {:?}",
            names
        );
    }

    /// `UPPER(name)` traces to `users.name`.
    #[test]
    fn test_lineage_upper_function() {
        let node = lineage_of("upper_name", "SELECT UPPER(name) AS upper_name FROM users");

        let names = node.downstream_names();
        assert!(
            has_name(&names, "users.name"),
            "Expected users.name in downstream, got: {:?}",
            names
        );
    }

    /// `ROUND(price, 2)` traces to `products.price`.
    #[test]
    fn test_lineage_round_function() {
        let node = lineage_of("rounded", "SELECT ROUND(price, 2) AS rounded FROM products");

        let names = node.downstream_names();
        assert!(
            has_name(&names, "products.price"),
            "Expected products.price in downstream, got: {:?}",
            names
        );
    }

    /// `COALESCE(a, b)` traces to both of its column arguments.
    #[test]
    fn test_lineage_coalesce_function() {
        let node = lineage_of("val", "SELECT COALESCE(a, b) AS val FROM t");

        let names = node.downstream_names();
        assert!(
            has_name(&names, "t.a"),
            "Expected t.a in downstream, got: {:?}",
            names
        );
        assert!(
            has_name(&names, "t.b"),
            "Expected t.b in downstream, got: {:?}",
            names
        );
    }

    /// `COUNT(id)` traces to `t.id`.
    #[test]
    fn test_lineage_count_function() {
        let node = lineage_of("cnt", "SELECT COUNT(id) AS cnt FROM t");

        let names = node.downstream_names();
        assert!(
            has_name(&names, "t.id"),
            "Expected t.id in downstream, got: {:?}",
            names
        );
    }

    /// `SUM(amount)` traces to `t.amount`.
    #[test]
    fn test_lineage_sum_function() {
        let node = lineage_of("total", "SELECT SUM(amount) AS total FROM t");

        let names = node.downstream_names();
        assert!(
            has_name(&names, "t.amount"),
            "Expected t.amount in downstream, got: {:?}",
            names
        );
    }

    /// A CASE expression traces both its condition column and the columns
    /// used inside the functions of its branches.
    #[test]
    fn test_lineage_case_with_nested_functions() {
        let node = lineage_of(
            "result",
            "SELECT CASE WHEN x > 0 THEN UPPER(name) ELSE LOWER(name) END AS result FROM t",
        );

        let names = node.downstream_names();
        assert!(
            has_name(&names, "t.x"),
            "Expected t.x in downstream, got: {:?}",
            names
        );
        assert!(
            has_name(&names, "t.name"),
            "Expected t.name in downstream, got: {:?}",
            names
        );
    }

    /// `SUBSTRING(name, 1, 3)` traces to `t.name`.
    #[test]
    fn test_lineage_substring_function() {
        let node = lineage_of("short", "SELECT SUBSTRING(name, 1, 3) AS short FROM t");

        let names = node.downstream_names();
        assert!(
            has_name(&names, "t.name"),
            "Expected t.name in downstream, got: {:?}",
            names
        );
    }
}