1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::scope::{build_scope, Scope};
11use crate::traversal::ExpressionWalk;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct LineageNode {
19 pub name: String,
21 pub expression: Expression,
23 pub source: Expression,
25 pub downstream: Vec<LineageNode>,
27 pub source_name: String,
29 pub reference_node_name: String,
31}
32
33impl LineageNode {
34 pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
36 Self {
37 name: name.into(),
38 expression,
39 source,
40 downstream: Vec::new(),
41 source_name: String::new(),
42 reference_node_name: String::new(),
43 }
44 }
45
46 pub fn walk(&self) -> LineageWalker<'_> {
48 LineageWalker { stack: vec![self] }
49 }
50
51 pub fn downstream_names(&self) -> Vec<String> {
53 self.downstream.iter().map(|n| n.name.clone()).collect()
54 }
55}
56
57pub struct LineageWalker<'a> {
59 stack: Vec<&'a LineageNode>,
60}
61
62impl<'a> Iterator for LineageWalker<'a> {
63 type Item = &'a LineageNode;
64
65 fn next(&mut self) -> Option<Self::Item> {
66 if let Some(node) = self.stack.pop() {
67 for child in node.downstream.iter().rev() {
69 self.stack.push(child);
70 }
71 Some(node)
72 } else {
73 None
74 }
75 }
76}
77
/// Identifies which output column of a query is being traced.
enum ColumnRef<'a> {
    /// Select the column by its alias or name.
    Name(&'a str),
    /// Select the column by its zero-based position in the select list.
    Index(usize),
}
87
88pub fn lineage(
114 column: &str,
115 sql: &Expression,
116 dialect: Option<DialectType>,
117 trim_selects: bool,
118) -> Result<LineageNode> {
119 let scope = build_scope(sql);
120 to_node(
121 ColumnRef::Name(column),
122 &scope,
123 dialect,
124 "",
125 "",
126 "",
127 trim_selects,
128 )
129}
130
131pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
133 let mut tables = HashSet::new();
134 collect_source_tables(node, &mut tables);
135 tables
136}
137
138pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
140 if let Expression::Table(table) = &node.source {
141 tables.insert(table.name.name.clone());
142 }
143 for child in &node.downstream {
144 collect_source_tables(child, tables);
145 }
146}
147
148fn to_node(
154 column: ColumnRef<'_>,
155 scope: &Scope,
156 dialect: Option<DialectType>,
157 scope_name: &str,
158 source_name: &str,
159 reference_node_name: &str,
160 trim_selects: bool,
161) -> Result<LineageNode> {
162 to_node_inner(
163 column,
164 scope,
165 dialect,
166 scope_name,
167 source_name,
168 reference_node_name,
169 trim_selects,
170 &[],
171 )
172}
173
174fn to_node_inner(
175 column: ColumnRef<'_>,
176 scope: &Scope,
177 dialect: Option<DialectType>,
178 scope_name: &str,
179 source_name: &str,
180 reference_node_name: &str,
181 trim_selects: bool,
182 ancestor_cte_scopes: &[Scope],
183) -> Result<LineageNode> {
184 let scope_expr = &scope.expression;
185
186 let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
188 for s in ancestor_cte_scopes {
189 all_cte_scopes.push(s);
190 }
191
192 let effective_expr = match scope_expr {
195 Expression::Cte(cte) => &cte.this,
196 other => other,
197 };
198
199 if matches!(
201 effective_expr,
202 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
203 ) {
204 if matches!(scope_expr, Expression::Cte(_)) {
206 let mut inner_scope = Scope::new(effective_expr.clone());
207 inner_scope.union_scopes = scope.union_scopes.clone();
208 inner_scope.sources = scope.sources.clone();
209 inner_scope.cte_sources = scope.cte_sources.clone();
210 inner_scope.cte_scopes = scope.cte_scopes.clone();
211 inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
212 inner_scope.subquery_scopes = scope.subquery_scopes.clone();
213 return handle_set_operation(
214 &column,
215 &inner_scope,
216 dialect,
217 scope_name,
218 source_name,
219 reference_node_name,
220 trim_selects,
221 ancestor_cte_scopes,
222 );
223 }
224 return handle_set_operation(
225 &column,
226 scope,
227 dialect,
228 scope_name,
229 source_name,
230 reference_node_name,
231 trim_selects,
232 ancestor_cte_scopes,
233 );
234 }
235
236 let select_expr = find_select_expr(effective_expr, &column)?;
238 let column_name = resolve_column_name(&column, &select_expr);
239
240 let node_source = if trim_selects {
242 trim_source(effective_expr, &select_expr)
243 } else {
244 effective_expr.clone()
245 };
246
247 let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
249 node.source_name = source_name.to_string();
250 node.reference_node_name = reference_node_name.to_string();
251
252 if matches!(&select_expr, Expression::Star(_)) {
254 for (name, source_info) in &scope.sources {
255 let child = LineageNode::new(
256 format!("{}.*", name),
257 Expression::Star(crate::expressions::Star {
258 table: None,
259 except: None,
260 replace: None,
261 rename: None,
262 trailing_comments: vec![],
263 }),
264 source_info.expression.clone(),
265 );
266 node.downstream.push(child);
267 }
268 return Ok(node);
269 }
270
271 let subqueries: Vec<&Expression> =
273 select_expr.find_all(|e| matches!(e, Expression::Subquery(sq) if sq.alias.is_none()));
274 for sq_expr in subqueries {
275 if let Expression::Subquery(sq) = sq_expr {
276 for sq_scope in &scope.subquery_scopes {
277 if sq_scope.expression == sq.this {
278 if let Ok(child) = to_node_inner(
279 ColumnRef::Index(0),
280 sq_scope,
281 dialect,
282 &column_name,
283 "",
284 "",
285 trim_selects,
286 ancestor_cte_scopes,
287 ) {
288 node.downstream.push(child);
289 }
290 break;
291 }
292 }
293 }
294 }
295
296 let col_refs = find_column_refs_in_expr(&select_expr);
298 for col_ref in col_refs {
299 let col_name = &col_ref.column;
300 if let Some(ref table_id) = col_ref.table {
301 let tbl = &table_id.name;
302 resolve_qualified_column(
303 &mut node,
304 scope,
305 dialect,
306 tbl,
307 col_name,
308 &column_name,
309 trim_selects,
310 &all_cte_scopes,
311 );
312 } else {
313 resolve_unqualified_column(
314 &mut node,
315 scope,
316 dialect,
317 col_name,
318 &column_name,
319 trim_selects,
320 &all_cte_scopes,
321 );
322 }
323 }
324
325 Ok(node)
326}
327
328fn handle_set_operation(
333 column: &ColumnRef<'_>,
334 scope: &Scope,
335 dialect: Option<DialectType>,
336 scope_name: &str,
337 source_name: &str,
338 reference_node_name: &str,
339 trim_selects: bool,
340 ancestor_cte_scopes: &[Scope],
341) -> Result<LineageNode> {
342 let scope_expr = &scope.expression;
343
344 let col_index = match column {
346 ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
347 ColumnRef::Index(i) => *i,
348 };
349
350 let col_name = match column {
351 ColumnRef::Name(name) => name.to_string(),
352 ColumnRef::Index(_) => format!("_{col_index}"),
353 };
354
355 let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
356 node.source_name = source_name.to_string();
357 node.reference_node_name = reference_node_name.to_string();
358
359 for branch_scope in &scope.union_scopes {
361 if let Ok(child) = to_node_inner(
362 ColumnRef::Index(col_index),
363 branch_scope,
364 dialect,
365 scope_name,
366 "",
367 "",
368 trim_selects,
369 ancestor_cte_scopes,
370 ) {
371 node.downstream.push(child);
372 }
373 }
374
375 Ok(node)
376}
377
378fn resolve_qualified_column(
383 node: &mut LineageNode,
384 scope: &Scope,
385 dialect: Option<DialectType>,
386 table: &str,
387 col_name: &str,
388 parent_name: &str,
389 trim_selects: bool,
390 all_cte_scopes: &[&Scope],
391) {
392 if scope.cte_sources.contains_key(table) {
394 if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
395 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
397 if let Ok(child) = to_node_inner(
398 ColumnRef::Name(col_name),
399 child_scope,
400 dialect,
401 parent_name,
402 table,
403 parent_name,
404 trim_selects,
405 &ancestors,
406 ) {
407 node.downstream.push(child);
408 return;
409 }
410 }
411 }
412
413 if let Some(source_info) = scope.sources.get(table) {
415 if source_info.is_scope {
416 if let Some(child_scope) = find_child_scope(scope, table) {
417 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
418 if let Ok(child) = to_node_inner(
419 ColumnRef::Name(col_name),
420 child_scope,
421 dialect,
422 parent_name,
423 table,
424 parent_name,
425 trim_selects,
426 &ancestors,
427 ) {
428 node.downstream.push(child);
429 return;
430 }
431 }
432 }
433 }
434
435 if let Some(source_info) = scope.sources.get(table) {
438 if !source_info.is_scope {
439 node.downstream.push(make_table_column_node_from_source(
440 table,
441 col_name,
442 &source_info.expression,
443 ));
444 return;
445 }
446 }
447
448 node.downstream
450 .push(make_table_column_node(table, col_name));
451}
452
453fn resolve_unqualified_column(
454 node: &mut LineageNode,
455 scope: &Scope,
456 dialect: Option<DialectType>,
457 col_name: &str,
458 parent_name: &str,
459 trim_selects: bool,
460 all_cte_scopes: &[&Scope],
461) {
462 let from_source_names = source_names_from_from_join(scope);
466
467 if from_source_names.len() == 1 {
468 let tbl = &from_source_names[0];
469 resolve_qualified_column(
470 node,
471 scope,
472 dialect,
473 tbl,
474 col_name,
475 parent_name,
476 trim_selects,
477 all_cte_scopes,
478 );
479 return;
480 }
481
482 let child = LineageNode::new(
484 col_name.to_string(),
485 Expression::Column(crate::expressions::Column {
486 name: crate::expressions::Identifier::new(col_name.to_string()),
487 table: None,
488 join_mark: false,
489 trailing_comments: vec![],
490 }),
491 node.source.clone(),
492 );
493 node.downstream.push(child);
494}
495
496fn source_names_from_from_join(scope: &Scope) -> Vec<String> {
497 fn source_name(expr: &Expression) -> Option<String> {
498 match expr {
499 Expression::Table(table) => Some(
500 table
501 .alias
502 .as_ref()
503 .map(|a| a.name.clone())
504 .unwrap_or_else(|| table.name.name.clone()),
505 ),
506 Expression::Subquery(subquery) => {
507 subquery.alias.as_ref().map(|alias| alias.name.clone())
508 }
509 Expression::Paren(paren) => source_name(&paren.this),
510 _ => None,
511 }
512 }
513
514 let effective_expr = match &scope.expression {
515 Expression::Cte(cte) => &cte.this,
516 expr => expr,
517 };
518
519 let mut names = Vec::new();
520 let mut seen = std::collections::HashSet::new();
521
522 if let Expression::Select(select) = effective_expr {
523 if let Some(from) = &select.from {
524 for expr in &from.expressions {
525 if let Some(name) = source_name(expr) {
526 if !name.is_empty() && seen.insert(name.clone()) {
527 names.push(name);
528 }
529 }
530 }
531 }
532 for join in &select.joins {
533 if let Some(name) = source_name(&join.this) {
534 if !name.is_empty() && seen.insert(name.clone()) {
535 names.push(name);
536 }
537 }
538 }
539 }
540
541 names
542}
543
544fn get_alias_or_name(expr: &Expression) -> Option<String> {
550 match expr {
551 Expression::Alias(alias) => Some(alias.alias.name.clone()),
552 Expression::Column(col) => Some(col.name.name.clone()),
553 Expression::Identifier(id) => Some(id.name.clone()),
554 Expression::Star(_) => Some("*".to_string()),
555 _ => None,
556 }
557}
558
559fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
561 match column {
562 ColumnRef::Name(n) => n.to_string(),
563 ColumnRef::Index(_) => get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string()),
564 }
565}
566
567fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
569 if let Expression::Select(ref select) = scope_expr {
570 match column {
571 ColumnRef::Name(name) => {
572 for expr in &select.expressions {
573 if get_alias_or_name(expr).as_deref() == Some(name) {
574 return Ok(expr.clone());
575 }
576 }
577 Err(crate::error::Error::parse(
578 format!("Cannot find column '{}' in query", name),
579 0,
580 0,
581 ))
582 }
583 ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
584 crate::error::Error::parse(format!("Column index {} out of range", idx), 0, 0)
585 }),
586 }
587 } else {
588 Err(crate::error::Error::parse(
589 "Expected SELECT expression for column lookup",
590 0,
591 0,
592 ))
593 }
594}
595
596fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
598 let mut expr = set_op_expr;
599 loop {
600 match expr {
601 Expression::Union(u) => expr = &u.left,
602 Expression::Intersect(i) => expr = &i.left,
603 Expression::Except(e) => expr = &e.left,
604 Expression::Select(select) => {
605 for (i, e) in select.expressions.iter().enumerate() {
606 if get_alias_or_name(e).as_deref() == Some(name) {
607 return Ok(i);
608 }
609 }
610 return Err(crate::error::Error::parse(
611 format!("Cannot find column '{}' in set operation", name),
612 0,
613 0,
614 ));
615 }
616 _ => {
617 return Err(crate::error::Error::parse(
618 "Expected SELECT or set operation",
619 0,
620 0,
621 ))
622 }
623 }
624 }
625}
626
627fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
629 if let Expression::Select(select) = select_expr {
630 let mut trimmed = select.as_ref().clone();
631 trimmed.expressions = vec![target_expr.clone()];
632 Expression::Select(Box::new(trimmed))
633 } else {
634 select_expr.clone()
635 }
636}
637
638fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
640 if scope.cte_sources.contains_key(source_name) {
642 for cte_scope in &scope.cte_scopes {
643 if let Expression::Cte(cte) = &cte_scope.expression {
644 if cte.alias.name == source_name {
645 return Some(cte_scope);
646 }
647 }
648 }
649 }
650
651 if let Some(source_info) = scope.sources.get(source_name) {
653 if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
654 if let Expression::Subquery(sq) = &source_info.expression {
655 for dt_scope in &scope.derived_table_scopes {
656 if dt_scope.expression == sq.this {
657 return Some(dt_scope);
658 }
659 }
660 }
661 }
662 }
663
664 None
665}
666
667fn find_child_scope_in<'a>(
671 all_cte_scopes: &[&'a Scope],
672 scope: &'a Scope,
673 source_name: &str,
674) -> Option<&'a Scope> {
675 for cte_scope in &scope.cte_scopes {
677 if let Expression::Cte(cte) = &cte_scope.expression {
678 if cte.alias.name == source_name {
679 return Some(cte_scope);
680 }
681 }
682 }
683
684 for cte_scope in all_cte_scopes {
686 if let Expression::Cte(cte) = &cte_scope.expression {
687 if cte.alias.name == source_name {
688 return Some(cte_scope);
689 }
690 }
691 }
692
693 if let Some(source_info) = scope.sources.get(source_name) {
695 if source_info.is_scope {
696 if let Expression::Subquery(sq) = &source_info.expression {
697 for dt_scope in &scope.derived_table_scopes {
698 if dt_scope.expression == sq.this {
699 return Some(dt_scope);
700 }
701 }
702 }
703 }
704 }
705
706 None
707}
708
709fn make_table_column_node(table: &str, column: &str) -> LineageNode {
711 let mut node = LineageNode::new(
712 format!("{}.{}", table, column),
713 Expression::Column(crate::expressions::Column {
714 name: crate::expressions::Identifier::new(column.to_string()),
715 table: Some(crate::expressions::Identifier::new(table.to_string())),
716 join_mark: false,
717 trailing_comments: vec![],
718 }),
719 Expression::Table(crate::expressions::TableRef::new(table)),
720 );
721 node.source_name = table.to_string();
722 node
723}
724
725fn table_name_from_table_ref(table_ref: &crate::expressions::TableRef) -> String {
726 let mut parts: Vec<String> = Vec::new();
727 if let Some(catalog) = &table_ref.catalog {
728 parts.push(catalog.name.clone());
729 }
730 if let Some(schema) = &table_ref.schema {
731 parts.push(schema.name.clone());
732 }
733 parts.push(table_ref.name.name.clone());
734 parts.join(".")
735}
736
737fn make_table_column_node_from_source(
738 table_alias: &str,
739 column: &str,
740 source: &Expression,
741) -> LineageNode {
742 let mut node = LineageNode::new(
743 format!("{}.{}", table_alias, column),
744 Expression::Column(crate::expressions::Column {
745 name: crate::expressions::Identifier::new(column.to_string()),
746 table: Some(crate::expressions::Identifier::new(table_alias.to_string())),
747 join_mark: false,
748 trailing_comments: vec![],
749 }),
750 source.clone(),
751 );
752
753 if let Expression::Table(table_ref) = source {
754 node.source_name = table_name_from_table_ref(table_ref);
755 } else {
756 node.source_name = table_alias.to_string();
757 }
758
759 node
760}
761
/// A column reference extracted from an expression, reduced to the parts
/// needed for lineage resolution.
#[derive(Debug, Clone)]
struct SimpleColumnRef {
    /// Table qualifier of the reference, if the column was written with one.
    table: Option<crate::expressions::Identifier>,
    /// Name of the referenced column.
    column: String,
}
768
769fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
771 let mut refs = Vec::new();
772 collect_column_refs(expr, &mut refs);
773 refs
774}
775
776fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
777 let mut stack: Vec<&Expression> = vec![expr];
778
779 while let Some(current) = stack.pop() {
780 match current {
781 Expression::Column(col) => {
783 refs.push(SimpleColumnRef {
784 table: col.table.clone(),
785 column: col.name.name.clone(),
786 });
787 }
788
789 Expression::Subquery(_) | Expression::Exists(_) => {}
791
792 Expression::And(op)
794 | Expression::Or(op)
795 | Expression::Eq(op)
796 | Expression::Neq(op)
797 | Expression::Lt(op)
798 | Expression::Lte(op)
799 | Expression::Gt(op)
800 | Expression::Gte(op)
801 | Expression::Add(op)
802 | Expression::Sub(op)
803 | Expression::Mul(op)
804 | Expression::Div(op)
805 | Expression::Mod(op)
806 | Expression::BitwiseAnd(op)
807 | Expression::BitwiseOr(op)
808 | Expression::BitwiseXor(op)
809 | Expression::BitwiseLeftShift(op)
810 | Expression::BitwiseRightShift(op)
811 | Expression::Concat(op)
812 | Expression::Adjacent(op)
813 | Expression::TsMatch(op)
814 | Expression::PropertyEQ(op)
815 | Expression::ArrayContainsAll(op)
816 | Expression::ArrayContainedBy(op)
817 | Expression::ArrayOverlaps(op)
818 | Expression::JSONBContainsAllTopKeys(op)
819 | Expression::JSONBContainsAnyTopKeys(op)
820 | Expression::JSONBDeleteAtPath(op)
821 | Expression::ExtendsLeft(op)
822 | Expression::ExtendsRight(op)
823 | Expression::Is(op)
824 | Expression::MemberOf(op)
825 | Expression::NullSafeEq(op)
826 | Expression::NullSafeNeq(op)
827 | Expression::Glob(op)
828 | Expression::Match(op) => {
829 stack.push(&op.left);
830 stack.push(&op.right);
831 }
832
833 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
835 stack.push(&u.this);
836 }
837
838 Expression::Upper(f)
840 | Expression::Lower(f)
841 | Expression::Length(f)
842 | Expression::LTrim(f)
843 | Expression::RTrim(f)
844 | Expression::Reverse(f)
845 | Expression::Abs(f)
846 | Expression::Sqrt(f)
847 | Expression::Cbrt(f)
848 | Expression::Ln(f)
849 | Expression::Exp(f)
850 | Expression::Sign(f)
851 | Expression::Date(f)
852 | Expression::Time(f)
853 | Expression::DateFromUnixDate(f)
854 | Expression::UnixDate(f)
855 | Expression::UnixSeconds(f)
856 | Expression::UnixMillis(f)
857 | Expression::UnixMicros(f)
858 | Expression::TimeStrToDate(f)
859 | Expression::DateToDi(f)
860 | Expression::DiToDate(f)
861 | Expression::TsOrDiToDi(f)
862 | Expression::TsOrDsToDatetime(f)
863 | Expression::TsOrDsToTimestamp(f)
864 | Expression::YearOfWeek(f)
865 | Expression::YearOfWeekIso(f)
866 | Expression::Initcap(f)
867 | Expression::Ascii(f)
868 | Expression::Chr(f)
869 | Expression::Soundex(f)
870 | Expression::ByteLength(f)
871 | Expression::Hex(f)
872 | Expression::LowerHex(f)
873 | Expression::Unicode(f)
874 | Expression::Radians(f)
875 | Expression::Degrees(f)
876 | Expression::Sin(f)
877 | Expression::Cos(f)
878 | Expression::Tan(f)
879 | Expression::Asin(f)
880 | Expression::Acos(f)
881 | Expression::Atan(f)
882 | Expression::IsNan(f)
883 | Expression::IsInf(f)
884 | Expression::ArrayLength(f)
885 | Expression::ArraySize(f)
886 | Expression::Cardinality(f)
887 | Expression::ArrayReverse(f)
888 | Expression::ArrayDistinct(f)
889 | Expression::ArrayFlatten(f)
890 | Expression::ArrayCompact(f)
891 | Expression::Explode(f)
892 | Expression::ExplodeOuter(f)
893 | Expression::ToArray(f)
894 | Expression::MapFromEntries(f)
895 | Expression::MapKeys(f)
896 | Expression::MapValues(f)
897 | Expression::JsonArrayLength(f)
898 | Expression::JsonKeys(f)
899 | Expression::JsonType(f)
900 | Expression::ParseJson(f)
901 | Expression::ToJson(f)
902 | Expression::Typeof(f)
903 | Expression::BitwiseCount(f)
904 | Expression::Year(f)
905 | Expression::Month(f)
906 | Expression::Day(f)
907 | Expression::Hour(f)
908 | Expression::Minute(f)
909 | Expression::Second(f)
910 | Expression::DayOfWeek(f)
911 | Expression::DayOfWeekIso(f)
912 | Expression::DayOfMonth(f)
913 | Expression::DayOfYear(f)
914 | Expression::WeekOfYear(f)
915 | Expression::Quarter(f)
916 | Expression::Epoch(f)
917 | Expression::EpochMs(f)
918 | Expression::TimeStrToUnix(f)
919 | Expression::SHA(f)
920 | Expression::SHA1Digest(f)
921 | Expression::TimeToUnix(f)
922 | Expression::JSONBool(f)
923 | Expression::Int64(f)
924 | Expression::MD5NumberLower64(f)
925 | Expression::MD5NumberUpper64(f)
926 | Expression::DateStrToDate(f)
927 | Expression::DateToDateStr(f) => {
928 stack.push(&f.this);
929 }
930
931 Expression::Power(f)
933 | Expression::NullIf(f)
934 | Expression::IfNull(f)
935 | Expression::Nvl(f)
936 | Expression::UnixToTimeStr(f)
937 | Expression::Contains(f)
938 | Expression::StartsWith(f)
939 | Expression::EndsWith(f)
940 | Expression::Levenshtein(f)
941 | Expression::ModFunc(f)
942 | Expression::Atan2(f)
943 | Expression::IntDiv(f)
944 | Expression::AddMonths(f)
945 | Expression::MonthsBetween(f)
946 | Expression::NextDay(f)
947 | Expression::ArrayContains(f)
948 | Expression::ArrayPosition(f)
949 | Expression::ArrayAppend(f)
950 | Expression::ArrayPrepend(f)
951 | Expression::ArrayUnion(f)
952 | Expression::ArrayExcept(f)
953 | Expression::ArrayRemove(f)
954 | Expression::StarMap(f)
955 | Expression::MapFromArrays(f)
956 | Expression::MapContainsKey(f)
957 | Expression::ElementAt(f)
958 | Expression::JsonMergePatch(f)
959 | Expression::JSONBContains(f)
960 | Expression::JSONBExtract(f) => {
961 stack.push(&f.this);
962 stack.push(&f.expression);
963 }
964
965 Expression::Greatest(f)
967 | Expression::Least(f)
968 | Expression::Coalesce(f)
969 | Expression::ArrayConcat(f)
970 | Expression::ArrayIntersect(f)
971 | Expression::ArrayZip(f)
972 | Expression::MapConcat(f)
973 | Expression::JsonArray(f) => {
974 for e in &f.expressions {
975 stack.push(e);
976 }
977 }
978
979 Expression::Sum(f)
981 | Expression::Avg(f)
982 | Expression::Min(f)
983 | Expression::Max(f)
984 | Expression::ArrayAgg(f)
985 | Expression::CountIf(f)
986 | Expression::Stddev(f)
987 | Expression::StddevPop(f)
988 | Expression::StddevSamp(f)
989 | Expression::Variance(f)
990 | Expression::VarPop(f)
991 | Expression::VarSamp(f)
992 | Expression::Median(f)
993 | Expression::Mode(f)
994 | Expression::First(f)
995 | Expression::Last(f)
996 | Expression::AnyValue(f)
997 | Expression::ApproxDistinct(f)
998 | Expression::ApproxCountDistinct(f)
999 | Expression::LogicalAnd(f)
1000 | Expression::LogicalOr(f)
1001 | Expression::Skewness(f)
1002 | Expression::ArrayConcatAgg(f)
1003 | Expression::ArrayUniqueAgg(f)
1004 | Expression::BoolXorAgg(f)
1005 | Expression::BitwiseAndAgg(f)
1006 | Expression::BitwiseOrAgg(f)
1007 | Expression::BitwiseXorAgg(f) => {
1008 stack.push(&f.this);
1009 if let Some(ref filter) = f.filter {
1010 stack.push(filter);
1011 }
1012 if let Some((ref expr, _)) = f.having_max {
1013 stack.push(expr);
1014 }
1015 if let Some(ref limit) = f.limit {
1016 stack.push(limit);
1017 }
1018 }
1019
1020 Expression::Function(func) => {
1022 for arg in &func.args {
1023 stack.push(arg);
1024 }
1025 }
1026 Expression::AggregateFunction(func) => {
1027 for arg in &func.args {
1028 stack.push(arg);
1029 }
1030 if let Some(ref filter) = func.filter {
1031 stack.push(filter);
1032 }
1033 if let Some(ref limit) = func.limit {
1034 stack.push(limit);
1035 }
1036 }
1037
1038 Expression::WindowFunction(wf) => {
1040 stack.push(&wf.this);
1041 }
1042
1043 Expression::Alias(a) => {
1045 stack.push(&a.this);
1046 }
1047 Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
1048 stack.push(&c.this);
1049 if let Some(ref fmt) = c.format {
1050 stack.push(fmt);
1051 }
1052 if let Some(ref def) = c.default {
1053 stack.push(def);
1054 }
1055 }
1056 Expression::Paren(p) => {
1057 stack.push(&p.this);
1058 }
1059 Expression::Annotated(a) => {
1060 stack.push(&a.this);
1061 }
1062 Expression::Case(case) => {
1063 if let Some(ref operand) = case.operand {
1064 stack.push(operand);
1065 }
1066 for (cond, result) in &case.whens {
1067 stack.push(cond);
1068 stack.push(result);
1069 }
1070 if let Some(ref else_expr) = case.else_ {
1071 stack.push(else_expr);
1072 }
1073 }
1074 Expression::Collation(c) => {
1075 stack.push(&c.this);
1076 }
1077 Expression::In(i) => {
1078 stack.push(&i.this);
1079 for e in &i.expressions {
1080 stack.push(e);
1081 }
1082 if let Some(ref q) = i.query {
1083 stack.push(q);
1084 }
1085 if let Some(ref u) = i.unnest {
1086 stack.push(u);
1087 }
1088 }
1089 Expression::Between(b) => {
1090 stack.push(&b.this);
1091 stack.push(&b.low);
1092 stack.push(&b.high);
1093 }
1094 Expression::IsNull(n) => {
1095 stack.push(&n.this);
1096 }
1097 Expression::IsTrue(t) | Expression::IsFalse(t) => {
1098 stack.push(&t.this);
1099 }
1100 Expression::IsJson(j) => {
1101 stack.push(&j.this);
1102 }
1103 Expression::Like(l) | Expression::ILike(l) => {
1104 stack.push(&l.left);
1105 stack.push(&l.right);
1106 if let Some(ref esc) = l.escape {
1107 stack.push(esc);
1108 }
1109 }
1110 Expression::SimilarTo(s) => {
1111 stack.push(&s.this);
1112 stack.push(&s.pattern);
1113 if let Some(ref esc) = s.escape {
1114 stack.push(esc);
1115 }
1116 }
1117 Expression::Ordered(o) => {
1118 stack.push(&o.this);
1119 }
1120 Expression::Array(a) => {
1121 for e in &a.expressions {
1122 stack.push(e);
1123 }
1124 }
1125 Expression::Tuple(t) => {
1126 for e in &t.expressions {
1127 stack.push(e);
1128 }
1129 }
1130 Expression::Struct(s) => {
1131 for (_, e) in &s.fields {
1132 stack.push(e);
1133 }
1134 }
1135 Expression::Subscript(s) => {
1136 stack.push(&s.this);
1137 stack.push(&s.index);
1138 }
1139 Expression::Dot(d) => {
1140 stack.push(&d.this);
1141 }
1142 Expression::MethodCall(m) => {
1143 stack.push(&m.this);
1144 for arg in &m.args {
1145 stack.push(arg);
1146 }
1147 }
1148 Expression::ArraySlice(s) => {
1149 stack.push(&s.this);
1150 if let Some(ref start) = s.start {
1151 stack.push(start);
1152 }
1153 if let Some(ref end) = s.end {
1154 stack.push(end);
1155 }
1156 }
1157 Expression::Lambda(l) => {
1158 stack.push(&l.body);
1159 }
1160 Expression::NamedArgument(n) => {
1161 stack.push(&n.value);
1162 }
1163 Expression::BracedWildcard(e) | Expression::ReturnStmt(e) => {
1164 stack.push(e);
1165 }
1166
1167 Expression::Substring(f) => {
1169 stack.push(&f.this);
1170 stack.push(&f.start);
1171 if let Some(ref len) = f.length {
1172 stack.push(len);
1173 }
1174 }
1175 Expression::Trim(f) => {
1176 stack.push(&f.this);
1177 if let Some(ref chars) = f.characters {
1178 stack.push(chars);
1179 }
1180 }
1181 Expression::Replace(f) => {
1182 stack.push(&f.this);
1183 stack.push(&f.old);
1184 stack.push(&f.new);
1185 }
1186 Expression::IfFunc(f) => {
1187 stack.push(&f.condition);
1188 stack.push(&f.true_value);
1189 if let Some(ref fv) = f.false_value {
1190 stack.push(fv);
1191 }
1192 }
1193 Expression::Nvl2(f) => {
1194 stack.push(&f.this);
1195 stack.push(&f.true_value);
1196 stack.push(&f.false_value);
1197 }
1198 Expression::ConcatWs(f) => {
1199 stack.push(&f.separator);
1200 for e in &f.expressions {
1201 stack.push(e);
1202 }
1203 }
1204 Expression::Count(f) => {
1205 if let Some(ref this) = f.this {
1206 stack.push(this);
1207 }
1208 if let Some(ref filter) = f.filter {
1209 stack.push(filter);
1210 }
1211 }
1212 Expression::GroupConcat(f) => {
1213 stack.push(&f.this);
1214 if let Some(ref sep) = f.separator {
1215 stack.push(sep);
1216 }
1217 if let Some(ref filter) = f.filter {
1218 stack.push(filter);
1219 }
1220 }
1221 Expression::StringAgg(f) => {
1222 stack.push(&f.this);
1223 if let Some(ref sep) = f.separator {
1224 stack.push(sep);
1225 }
1226 if let Some(ref filter) = f.filter {
1227 stack.push(filter);
1228 }
1229 if let Some(ref limit) = f.limit {
1230 stack.push(limit);
1231 }
1232 }
1233 Expression::ListAgg(f) => {
1234 stack.push(&f.this);
1235 if let Some(ref sep) = f.separator {
1236 stack.push(sep);
1237 }
1238 if let Some(ref filter) = f.filter {
1239 stack.push(filter);
1240 }
1241 }
1242 Expression::SumIf(f) => {
1243 stack.push(&f.this);
1244 stack.push(&f.condition);
1245 if let Some(ref filter) = f.filter {
1246 stack.push(filter);
1247 }
1248 }
1249 Expression::DateAdd(f) | Expression::DateSub(f) => {
1250 stack.push(&f.this);
1251 stack.push(&f.interval);
1252 }
1253 Expression::DateDiff(f) => {
1254 stack.push(&f.this);
1255 stack.push(&f.expression);
1256 }
1257 Expression::DateTrunc(f) | Expression::TimestampTrunc(f) => {
1258 stack.push(&f.this);
1259 }
1260 Expression::Extract(f) => {
1261 stack.push(&f.this);
1262 }
1263 Expression::Round(f) => {
1264 stack.push(&f.this);
1265 if let Some(ref d) = f.decimals {
1266 stack.push(d);
1267 }
1268 }
1269 Expression::Floor(f) => {
1270 stack.push(&f.this);
1271 if let Some(ref s) = f.scale {
1272 stack.push(s);
1273 }
1274 if let Some(ref t) = f.to {
1275 stack.push(t);
1276 }
1277 }
1278 Expression::Ceil(f) => {
1279 stack.push(&f.this);
1280 if let Some(ref d) = f.decimals {
1281 stack.push(d);
1282 }
1283 if let Some(ref t) = f.to {
1284 stack.push(t);
1285 }
1286 }
1287 Expression::Log(f) => {
1288 stack.push(&f.this);
1289 if let Some(ref b) = f.base {
1290 stack.push(b);
1291 }
1292 }
1293 Expression::AtTimeZone(f) => {
1294 stack.push(&f.this);
1295 stack.push(&f.zone);
1296 }
1297 Expression::Lead(f) | Expression::Lag(f) => {
1298 stack.push(&f.this);
1299 if let Some(ref off) = f.offset {
1300 stack.push(off);
1301 }
1302 if let Some(ref def) = f.default {
1303 stack.push(def);
1304 }
1305 }
1306 Expression::FirstValue(f) | Expression::LastValue(f) => {
1307 stack.push(&f.this);
1308 }
1309 Expression::NthValue(f) => {
1310 stack.push(&f.this);
1311 stack.push(&f.offset);
1312 }
1313 Expression::Position(f) => {
1314 stack.push(&f.substring);
1315 stack.push(&f.string);
1316 if let Some(ref start) = f.start {
1317 stack.push(start);
1318 }
1319 }
1320 Expression::Decode(f) => {
1321 stack.push(&f.this);
1322 for (search, result) in &f.search_results {
1323 stack.push(search);
1324 stack.push(result);
1325 }
1326 if let Some(ref def) = f.default {
1327 stack.push(def);
1328 }
1329 }
1330 Expression::CharFunc(f) => {
1331 for arg in &f.args {
1332 stack.push(arg);
1333 }
1334 }
1335 Expression::ArraySort(f) => {
1336 stack.push(&f.this);
1337 if let Some(ref cmp) = f.comparator {
1338 stack.push(cmp);
1339 }
1340 }
1341 Expression::ArrayJoin(f) | Expression::ArrayToString(f) => {
1342 stack.push(&f.this);
1343 stack.push(&f.separator);
1344 if let Some(ref nr) = f.null_replacement {
1345 stack.push(nr);
1346 }
1347 }
1348 Expression::ArrayFilter(f) => {
1349 stack.push(&f.this);
1350 stack.push(&f.filter);
1351 }
1352 Expression::ArrayTransform(f) => {
1353 stack.push(&f.this);
1354 stack.push(&f.transform);
1355 }
1356 Expression::Sequence(f)
1357 | Expression::Generate(f)
1358 | Expression::ExplodingGenerateSeries(f) => {
1359 stack.push(&f.start);
1360 stack.push(&f.stop);
1361 if let Some(ref step) = f.step {
1362 stack.push(step);
1363 }
1364 }
1365 Expression::JsonExtract(f)
1366 | Expression::JsonExtractScalar(f)
1367 | Expression::JsonQuery(f)
1368 | Expression::JsonValue(f) => {
1369 stack.push(&f.this);
1370 stack.push(&f.path);
1371 }
1372 Expression::JsonExtractPath(f) | Expression::JsonRemove(f) => {
1373 stack.push(&f.this);
1374 for p in &f.paths {
1375 stack.push(p);
1376 }
1377 }
1378 Expression::JsonObject(f) => {
1379 for (k, v) in &f.pairs {
1380 stack.push(k);
1381 stack.push(v);
1382 }
1383 }
1384 Expression::JsonSet(f) | Expression::JsonInsert(f) => {
1385 stack.push(&f.this);
1386 for (path, val) in &f.path_values {
1387 stack.push(path);
1388 stack.push(val);
1389 }
1390 }
1391 Expression::Overlay(f) => {
1392 stack.push(&f.this);
1393 stack.push(&f.replacement);
1394 stack.push(&f.from);
1395 if let Some(ref len) = f.length {
1396 stack.push(len);
1397 }
1398 }
1399 Expression::Convert(f) => {
1400 stack.push(&f.this);
1401 if let Some(ref style) = f.style {
1402 stack.push(style);
1403 }
1404 }
1405 Expression::ApproxPercentile(f) => {
1406 stack.push(&f.this);
1407 stack.push(&f.percentile);
1408 if let Some(ref acc) = f.accuracy {
1409 stack.push(acc);
1410 }
1411 if let Some(ref filter) = f.filter {
1412 stack.push(filter);
1413 }
1414 }
1415 Expression::Percentile(f)
1416 | Expression::PercentileCont(f)
1417 | Expression::PercentileDisc(f) => {
1418 stack.push(&f.this);
1419 stack.push(&f.percentile);
1420 if let Some(ref filter) = f.filter {
1421 stack.push(filter);
1422 }
1423 }
1424 Expression::WithinGroup(f) => {
1425 stack.push(&f.this);
1426 }
1427 Expression::Left(f) | Expression::Right(f) => {
1428 stack.push(&f.this);
1429 stack.push(&f.length);
1430 }
1431 Expression::Repeat(f) => {
1432 stack.push(&f.this);
1433 stack.push(&f.times);
1434 }
1435 Expression::Lpad(f) | Expression::Rpad(f) => {
1436 stack.push(&f.this);
1437 stack.push(&f.length);
1438 if let Some(ref fill) = f.fill {
1439 stack.push(fill);
1440 }
1441 }
1442 Expression::Split(f) => {
1443 stack.push(&f.this);
1444 stack.push(&f.delimiter);
1445 }
1446 Expression::RegexpLike(f) => {
1447 stack.push(&f.this);
1448 stack.push(&f.pattern);
1449 if let Some(ref flags) = f.flags {
1450 stack.push(flags);
1451 }
1452 }
1453 Expression::RegexpReplace(f) => {
1454 stack.push(&f.this);
1455 stack.push(&f.pattern);
1456 stack.push(&f.replacement);
1457 if let Some(ref flags) = f.flags {
1458 stack.push(flags);
1459 }
1460 }
1461 Expression::RegexpExtract(f) => {
1462 stack.push(&f.this);
1463 stack.push(&f.pattern);
1464 if let Some(ref group) = f.group {
1465 stack.push(group);
1466 }
1467 }
1468 Expression::ToDate(f) => {
1469 stack.push(&f.this);
1470 if let Some(ref fmt) = f.format {
1471 stack.push(fmt);
1472 }
1473 }
1474 Expression::ToTimestamp(f) => {
1475 stack.push(&f.this);
1476 if let Some(ref fmt) = f.format {
1477 stack.push(fmt);
1478 }
1479 }
1480 Expression::DateFormat(f) | Expression::FormatDate(f) => {
1481 stack.push(&f.this);
1482 stack.push(&f.format);
1483 }
1484 Expression::LastDay(f) => {
1485 stack.push(&f.this);
1486 }
1487 Expression::FromUnixtime(f) => {
1488 stack.push(&f.this);
1489 if let Some(ref fmt) = f.format {
1490 stack.push(fmt);
1491 }
1492 }
1493 Expression::UnixTimestamp(f) => {
1494 if let Some(ref this) = f.this {
1495 stack.push(this);
1496 }
1497 if let Some(ref fmt) = f.format {
1498 stack.push(fmt);
1499 }
1500 }
1501 Expression::MakeDate(f) => {
1502 stack.push(&f.year);
1503 stack.push(&f.month);
1504 stack.push(&f.day);
1505 }
1506 Expression::MakeTimestamp(f) => {
1507 stack.push(&f.year);
1508 stack.push(&f.month);
1509 stack.push(&f.day);
1510 stack.push(&f.hour);
1511 stack.push(&f.minute);
1512 stack.push(&f.second);
1513 if let Some(ref tz) = f.timezone {
1514 stack.push(tz);
1515 }
1516 }
1517 Expression::TruncFunc(f) => {
1518 stack.push(&f.this);
1519 if let Some(ref d) = f.decimals {
1520 stack.push(d);
1521 }
1522 }
1523 Expression::ArrayFunc(f) => {
1524 for e in &f.expressions {
1525 stack.push(e);
1526 }
1527 }
1528 Expression::Unnest(f) => {
1529 stack.push(&f.this);
1530 for e in &f.expressions {
1531 stack.push(e);
1532 }
1533 }
1534 Expression::StructFunc(f) => {
1535 for (_, e) in &f.fields {
1536 stack.push(e);
1537 }
1538 }
1539 Expression::StructExtract(f) => {
1540 stack.push(&f.this);
1541 }
1542 Expression::NamedStruct(f) => {
1543 for (k, v) in &f.pairs {
1544 stack.push(k);
1545 stack.push(v);
1546 }
1547 }
1548 Expression::MapFunc(f) => {
1549 for k in &f.keys {
1550 stack.push(k);
1551 }
1552 for v in &f.values {
1553 stack.push(v);
1554 }
1555 }
1556 Expression::TransformKeys(f) | Expression::TransformValues(f) => {
1557 stack.push(&f.this);
1558 stack.push(&f.transform);
1559 }
1560 Expression::JsonArrayAgg(f) => {
1561 stack.push(&f.this);
1562 if let Some(ref filter) = f.filter {
1563 stack.push(filter);
1564 }
1565 }
1566 Expression::JsonObjectAgg(f) => {
1567 stack.push(&f.key);
1568 stack.push(&f.value);
1569 if let Some(ref filter) = f.filter {
1570 stack.push(filter);
1571 }
1572 }
1573 Expression::NTile(f) => {
1574 if let Some(ref n) = f.num_buckets {
1575 stack.push(n);
1576 }
1577 }
1578 Expression::Rand(f) => {
1579 if let Some(ref s) = f.seed {
1580 stack.push(s);
1581 }
1582 if let Some(ref lo) = f.lower {
1583 stack.push(lo);
1584 }
1585 if let Some(ref hi) = f.upper {
1586 stack.push(hi);
1587 }
1588 }
1589 Expression::Any(q) | Expression::All(q) => {
1590 stack.push(&q.this);
1591 stack.push(&q.subquery);
1592 }
1593 Expression::Overlaps(o) => {
1594 if let Some(ref this) = o.this {
1595 stack.push(this);
1596 }
1597 if let Some(ref expr) = o.expression {
1598 stack.push(expr);
1599 }
1600 if let Some(ref ls) = o.left_start {
1601 stack.push(ls);
1602 }
1603 if let Some(ref le) = o.left_end {
1604 stack.push(le);
1605 }
1606 if let Some(ref rs) = o.right_start {
1607 stack.push(rs);
1608 }
1609 if let Some(ref re) = o.right_end {
1610 stack.push(re);
1611 }
1612 }
1613 Expression::Interval(i) => {
1614 if let Some(ref this) = i.this {
1615 stack.push(this);
1616 }
1617 }
1618 Expression::TimeStrToTime(f) => {
1619 stack.push(&f.this);
1620 if let Some(ref zone) = f.zone {
1621 stack.push(zone);
1622 }
1623 }
1624 Expression::JSONBExtractScalar(f) => {
1625 stack.push(&f.this);
1626 stack.push(&f.expression);
1627 if let Some(ref jt) = f.json_type {
1628 stack.push(jt);
1629 }
1630 }
1631
1632 _ => {}
1637 }
1638 }
1639}
1640
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    // ---------------------------------------------------------------------
    // Shared helpers
    // ---------------------------------------------------------------------

    /// Parses `sql` with the generic dialect and returns the first statement.
    ///
    /// Panics if parsing fails or if the parser yields no statements.
    fn parse(sql: &str) -> Expression {
        let generic = Dialect::get(DialectType::Generic);
        let mut statements = generic.parse(sql).unwrap().into_iter();
        statements.next().unwrap()
    }

    /// Returns `true` when `names` holds an entry exactly equal to `wanted`.
    fn has_name(names: &[String], wanted: &str) -> bool {
        names.iter().any(|candidate| candidate == wanted)
    }

    /// Collects the name of every node yielded by `node.walk()`, in iteration order.
    fn walk_names(node: &LineageNode) -> Vec<String> {
        node.walk().map(|visited| visited.name.clone()).collect()
    }

    /// Prints `node` and, recursively, its downstream nodes, padding each level
    /// by one more repetition of the pad string. Debug aid for `test_issue18_repro`.
    fn print_node(node: &LineageNode, indent: usize) {
        let pad = " ".repeat(indent);
        println!(
            "{pad}name={:?} source_name={:?}",
            node.name, node.source_name
        );
        node.downstream
            .iter()
            .for_each(|child| print_node(child, indent + 1));
    }

    // ---------------------------------------------------------------------
    // Basic column resolution
    // ---------------------------------------------------------------------

    /// The root node carries the requested column name and has `t.a` downstream.
    #[test]
    fn test_simple_lineage() {
        let query = parse("SELECT a FROM t");
        let root = lineage("a", &query, None, false).unwrap();

        assert_eq!(root.name, "a");
        assert!(!root.downstream.is_empty(), "Should have downstream nodes");

        let children = root.downstream_names();
        assert!(
            has_name(&children, "t.a"),
            "Expected t.a in downstream, got: {:?}",
            children
        );
    }

    /// `walk()` on a hand-built two-node tree yields the root first, then its child.
    #[test]
    fn test_lineage_walk() {
        let null = || Expression::Null(crate::expressions::Null);

        let mut root = LineageNode::new("col_a", null(), null());
        root.downstream.push(LineageNode::new("t.a", null(), null()));

        let names = walk_names(&root);
        assert_eq!(names.len(), 2);
        assert_eq!(names[0], "col_a");
        assert_eq!(names[1], "t.a");
    }

    /// An aliased arithmetic expression is looked up by its alias and traces
    /// to some node whose name mentions `a`.
    #[test]
    fn test_aliased_column() {
        let query = parse("SELECT a + 1 AS b FROM t");
        let root = lineage("b", &query, None, false).unwrap();

        assert_eq!(root.name, "b");

        let visited = walk_names(&root);
        assert!(
            visited.iter().any(|name| name.contains("a")),
            "Expected to trace to column a, got: {:?}",
            visited
        );
    }

    /// A table-qualified select item is addressed by its bare column name.
    #[test]
    fn test_qualified_column() {
        let query = parse("SELECT t.a FROM t");
        let root = lineage("a", &query, None, false).unwrap();

        assert_eq!(root.name, "a");

        let children = root.downstream_names();
        assert!(
            has_name(&children, "t.a"),
            "Expected t.a, got: {:?}",
            children
        );
    }

    /// An unqualified column is reported downstream as `t.a`.
    #[test]
    fn test_unqualified_column() {
        let query = parse("SELECT a FROM t");
        let root = lineage("a", &query, None, false).unwrap();

        let children = root.downstream_names();
        assert!(
            has_name(&children, "t.a"),
            "Expected t.a, got: {:?}",
            children
        );
    }

    // ---------------------------------------------------------------------
    // Joins, aliases, derived tables, CTEs and set operations
    // ---------------------------------------------------------------------

    /// Each side of a join contributes its own column.
    #[test]
    fn test_lineage_join() {
        let query = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        let from_t = lineage("a", &query, None, false).unwrap();
        let t_children = from_t.downstream_names();
        assert!(
            has_name(&t_children, "t.a"),
            "Expected t.a, got: {:?}",
            t_children
        );

        let from_s = lineage("b", &query, None, false).unwrap();
        let s_children = from_s.downstream_names();
        assert!(
            has_name(&s_children, "s.b"),
            "Expected s.b, got: {:?}",
            s_children
        );
    }

    /// The leaf keeps the alias in its node name, while `source_name` and
    /// `source` refer to the underlying table `table1`.
    #[test]
    fn test_lineage_alias_leaf_has_resolved_source_name() {
        let query = parse("SELECT t1.col1 FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id");
        let root = lineage("col1", &query, None, false).unwrap();

        let children = root.downstream_names();
        assert!(
            has_name(&children, "t1.col1"),
            "Expected aliased column edge t1.col1, got: {:?}",
            children
        );

        let leaf = root
            .downstream
            .iter()
            .find(|child| child.name == "t1.col1")
            .expect("Expected t1.col1 leaf");
        assert_eq!(leaf.source_name, "table1");

        if let Expression::Table(table) = &leaf.source {
            assert_eq!(table.name.name, "table1");
        } else {
            panic!("Expected leaf source to be a table expression");
        }
    }

    /// Lineage reaches `t.a` through a derived table in the FROM clause.
    #[test]
    fn test_lineage_derived_table() {
        let query = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
        let root = lineage("a", &query, None, false).unwrap();

        assert_eq!(root.name, "a");

        let visited = walk_names(&root);
        assert!(
            has_name(&visited, "t.a"),
            "Expected to trace through derived table to t.a, got: {:?}",
            visited
        );
    }

    /// Lineage reaches `t.a` through a single CTE.
    #[test]
    fn test_lineage_cte() {
        let query = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
        let root = lineage("a", &query, None, false).unwrap();

        assert_eq!(root.name, "a");

        let visited = walk_names(&root);
        assert!(
            has_name(&visited, "t.a"),
            "Expected to trace through CTE to t.a, got: {:?}",
            visited
        );
    }

    /// A two-branch UNION yields exactly two downstream nodes on the root.
    #[test]
    fn test_lineage_union() {
        let query = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
        let root = lineage("a", &query, None, false).unwrap();

        assert_eq!(root.name, "a");

        let branches = root.downstream.len();
        assert_eq!(
            branches, 2,
            "Expected 2 branches for UNION, got {}",
            branches
        );
    }

    /// A CTE whose body is a UNION produces at least three nodes in total.
    #[test]
    fn test_lineage_cte_union() {
        let query =
            parse("WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte");
        let root = lineage("a", &query, None, false).unwrap();

        let visited = walk_names(&root);
        assert!(
            visited.len() >= 3,
            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
            visited
        );
    }

    /// `*` can be requested as a column and has downstream nodes.
    #[test]
    fn test_lineage_star() {
        let query = parse("SELECT * FROM t");
        let root = lineage("*", &query, None, false).unwrap();

        assert_eq!(root.name, "*");
        assert!(
            !root.downstream.is_empty(),
            "Star should produce downstream nodes"
        );
    }

    /// A scalar subquery in the select list contributes at least one node
    /// beyond the root.
    #[test]
    fn test_lineage_subquery_in_select() {
        let query = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
        let root = lineage("x", &query, None, false).unwrap();

        assert_eq!(root.name, "x");

        let visited = walk_names(&root);
        assert!(
            visited.len() >= 2,
            "Expected tracing into scalar subquery, got: {:?}",
            visited
        );
    }

    /// Two columns of the same statement are resolved independently.
    #[test]
    fn test_lineage_multiple_columns() {
        let query = parse("SELECT a, b FROM t");

        let root_a = lineage("a", &query, None, false).unwrap();
        let root_b = lineage("b", &query, None, false).unwrap();

        assert_eq!(root_a.name, "a");
        assert_eq!(root_b.name, "b");

        let children_a = root_a.downstream_names();
        let children_b = root_b.downstream_names();
        assert!(has_name(&children_a, "t.a"));
        assert!(has_name(&children_b, "t.b"));
    }

    /// `get_source_tables` reports the table that the traced column comes from.
    #[test]
    fn test_get_source_tables() {
        let query = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
        let root = lineage("a", &query, None, false).unwrap();

        let sources = get_source_tables(&root);
        assert!(
            sources.contains("t"),
            "Expected source table 't', got: {:?}",
            sources
        );
    }

    /// Requesting a column that is not in the select list is an error.
    #[test]
    fn test_lineage_column_not_found() {
        let query = parse("SELECT a FROM t");
        let result = lineage("nonexistent", &query, None, false);
        assert!(result.is_err());
    }

    /// Chained CTEs produce at least three nodes in total.
    #[test]
    fn test_lineage_nested_cte() {
        let query = parse(
            "WITH cte1 AS (SELECT a FROM t), \
             cte2 AS (SELECT a FROM cte1) \
             SELECT a FROM cte2",
        );
        let root = lineage("a", &query, None, false).unwrap();

        let visited = walk_names(&root);
        assert!(
            visited.len() >= 3,
            "Expected to trace through nested CTEs, got: {:?}",
            visited
        );
    }

    // ---------------------------------------------------------------------
    // `trim_selects` flag
    // ---------------------------------------------------------------------

    /// With `trim_selects = true` the root's source SELECT keeps one expression.
    #[test]
    fn test_trim_selects_true() {
        let query = parse("SELECT a, b, c FROM t");
        let root = lineage("a", &query, None, true).unwrap();

        match &root.source {
            Expression::Select(select) => {
                assert_eq!(
                    select.expressions.len(),
                    1,
                    "Trimmed source should have 1 expression, got {}",
                    select.expressions.len()
                );
            }
            _ => panic!("Expected Select source"),
        }
    }

    /// With `trim_selects = false` the root's source SELECT keeps all three expressions.
    #[test]
    fn test_trim_selects_false() {
        let query = parse("SELECT a, b, c FROM t");
        let root = lineage("a", &query, None, false).unwrap();

        match &root.source {
            Expression::Select(select) => {
                assert_eq!(
                    select.expressions.len(),
                    3,
                    "Untrimmed source should have 3 expressions"
                );
            }
            _ => panic!("Expected Select source"),
        }
    }

    // ---------------------------------------------------------------------
    // Expressions and set operations
    // ---------------------------------------------------------------------

    /// A binary expression over two columns yields at least three nodes in total.
    #[test]
    fn test_lineage_expression_in_select() {
        let query = parse("SELECT a + b AS c FROM t");
        let root = lineage("c", &query, None, false).unwrap();

        let visited = walk_names(&root);
        assert!(
            visited.len() >= 3,
            "Expected to trace a + b to both columns, got: {:?}",
            visited
        );
    }

    /// The UNION branches name their output column differently (`a` vs `b`),
    /// yet the root still gets two downstream nodes.
    #[test]
    fn test_set_operation_by_index() {
        let query = parse("SELECT a FROM t1 UNION SELECT b FROM t2");

        let root = lineage("a", &query, None, false).unwrap();

        assert_eq!(root.downstream.len(), 2);
    }

    // ---------------------------------------------------------------------
    // Function calls wrapping columns
    // ---------------------------------------------------------------------

    /// Parses with the BigQuery dialect, prints the resulting tree, and checks
    /// that `UPPER(name)` has `users.name` downstream.
    #[test]
    fn test_issue18_repro() {
        let query = "SELECT UPPER(name) as upper_name FROM users";
        println!("Query: {query}\n");

        let bigquery = Dialect::get(DialectType::BigQuery);
        let statements = bigquery.parse(query).unwrap();
        let statement = &statements[0];

        let root = lineage("upper_name", statement, Some(DialectType::BigQuery), false).unwrap();
        println!("lineage(\"upper_name\"):");
        print_node(&root, 1);

        let children = root.downstream_names();
        assert!(
            has_name(&children, "users.name"),
            "Expected users.name in downstream, got: {:?}",
            children
        );
    }

    /// `UPPER(name)` has `users.name` downstream under the generic dialect.
    #[test]
    fn test_lineage_upper_function() {
        let query = parse("SELECT UPPER(name) AS upper_name FROM users");
        let root = lineage("upper_name", &query, None, false).unwrap();

        let children = root.downstream_names();
        assert!(
            has_name(&children, "users.name"),
            "Expected users.name in downstream, got: {:?}",
            children
        );
    }

    /// `ROUND(price, 2)` has `products.price` downstream.
    #[test]
    fn test_lineage_round_function() {
        let query = parse("SELECT ROUND(price, 2) AS rounded FROM products");
        let root = lineage("rounded", &query, None, false).unwrap();

        let children = root.downstream_names();
        assert!(
            has_name(&children, "products.price"),
            "Expected products.price in downstream, got: {:?}",
            children
        );
    }

    /// `COALESCE(a, b)` has both argument columns downstream.
    #[test]
    fn test_lineage_coalesce_function() {
        let query = parse("SELECT COALESCE(a, b) AS val FROM t");
        let root = lineage("val", &query, None, false).unwrap();

        let children = root.downstream_names();
        assert!(
            has_name(&children, "t.a"),
            "Expected t.a in downstream, got: {:?}",
            children
        );
        assert!(
            has_name(&children, "t.b"),
            "Expected t.b in downstream, got: {:?}",
            children
        );
    }

    /// `COUNT(id)` has `t.id` downstream.
    #[test]
    fn test_lineage_count_function() {
        let query = parse("SELECT COUNT(id) AS cnt FROM t");
        let root = lineage("cnt", &query, None, false).unwrap();

        let children = root.downstream_names();
        assert!(
            has_name(&children, "t.id"),
            "Expected t.id in downstream, got: {:?}",
            children
        );
    }

    /// `SUM(amount)` has `t.amount` downstream.
    #[test]
    fn test_lineage_sum_function() {
        let query = parse("SELECT SUM(amount) AS total FROM t");
        let root = lineage("total", &query, None, false).unwrap();

        let children = root.downstream_names();
        assert!(
            has_name(&children, "t.amount"),
            "Expected t.amount in downstream, got: {:?}",
            children
        );
    }

    /// A CASE expression exposes both the column in its condition and the
    /// column wrapped by functions in its branches.
    #[test]
    fn test_lineage_case_with_nested_functions() {
        let query =
            parse("SELECT CASE WHEN x > 0 THEN UPPER(name) ELSE LOWER(name) END AS result FROM t");
        let root = lineage("result", &query, None, false).unwrap();

        let children = root.downstream_names();
        assert!(
            has_name(&children, "t.x"),
            "Expected t.x in downstream, got: {:?}",
            children
        );
        assert!(
            has_name(&children, "t.name"),
            "Expected t.name in downstream, got: {:?}",
            children
        );
    }

    /// `SUBSTRING(name, 1, 3)` has `t.name` downstream.
    #[test]
    fn test_lineage_substring_function() {
        let query = parse("SELECT SUBSTRING(name, 1, 3) AS short FROM t");
        let root = lineage("short", &query, None, false).unwrap();

        let children = root.downstream_names();
        assert!(
            has_name(&children, "t.name"),
            "Expected t.name in downstream, got: {:?}",
            children
        );
    }
}