1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::optimizer::annotate_types::annotate_types;
11use crate::optimizer::qualify_columns::{qualify_columns, QualifyColumnsOptions};
12use crate::schema::Schema;
13use crate::scope::{build_scope, Scope};
14use crate::traversal::ExpressionWalk;
15use crate::{Error, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::HashSet;
18
/// A single node in a column-lineage tree.
///
/// A node ties a name to the expression it stands for and the expression it
/// was resolved in, and owns the child nodes it was derived from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineageNode {
    /// Name of this node, e.g. a resolved column name, or `table.column` for
    /// nodes that point at a column of a table.
    pub name: String,
    /// Expression this node represents (for column nodes, the matching
    /// SELECT-list expression; for table-column nodes, a column expression).
    pub expression: Expression,
    /// Expression the node was resolved against: the enclosing query
    /// (possibly trimmed to a single projection) or a table expression.
    pub source: Expression,
    /// Child nodes this node was derived from.
    pub downstream: Vec<LineageNode>,
    /// Name of the source this node was reached through; empty by default.
    pub source_name: String,
    /// Name of the node that referenced this one when descending into a CTE
    /// or derived table; empty by default.
    pub reference_node_name: String,
}
35
36impl LineageNode {
37 pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
39 Self {
40 name: name.into(),
41 expression,
42 source,
43 downstream: Vec::new(),
44 source_name: String::new(),
45 reference_node_name: String::new(),
46 }
47 }
48
49 pub fn walk(&self) -> LineageWalker<'_> {
51 LineageWalker { stack: vec![self] }
52 }
53
54 pub fn downstream_names(&self) -> Vec<String> {
56 self.downstream.iter().map(|n| n.name.clone()).collect()
57 }
58}
59
/// Iterator over a [`LineageNode`] and its descendants, created by
/// [`LineageNode::walk`].
pub struct LineageWalker<'a> {
    /// Nodes still waiting to be yielded; the last element is yielded next.
    stack: Vec<&'a LineageNode>,
}
64
65impl<'a> Iterator for LineageWalker<'a> {
66 type Item = &'a LineageNode;
67
68 fn next(&mut self) -> Option<Self::Item> {
69 if let Some(node) = self.stack.pop() {
70 for child in node.downstream.iter().rev() {
72 self.stack.push(child);
73 }
74 Some(node)
75 } else {
76 None
77 }
78 }
79}
80
/// Identifies the output column whose lineage is being resolved.
enum ColumnRef<'a> {
    /// Look the column up by its alias or name in the SELECT list.
    Name(&'a str),
    /// Look the column up by its zero-based position in the SELECT list.
    Index(usize),
}
90
/// Builds the lineage tree for the output column `column` of the query `sql`.
///
/// The expression is used as given: no schema-based qualification or type
/// annotation is applied here (see [`lineage_with_schema`] for that).
///
/// # Errors
///
/// Propagates any error returned while resolving the column in `sql`.
pub fn lineage(
    column: &str,
    sql: &Expression,
    dialect: Option<DialectType>,
    trim_selects: bool,
) -> Result<LineageNode> {
    lineage_from_expression(column, sql, dialect, trim_selects)
}
124
125pub fn lineage_with_schema(
141 column: &str,
142 sql: &Expression,
143 schema: Option<&dyn Schema>,
144 dialect: Option<DialectType>,
145 trim_selects: bool,
146) -> Result<LineageNode> {
147 let mut qualified_expression = if let Some(schema) = schema {
148 let options = if let Some(dialect_type) = dialect.or_else(|| schema.dialect()) {
149 QualifyColumnsOptions::new().with_dialect(dialect_type)
150 } else {
151 QualifyColumnsOptions::new()
152 };
153
154 qualify_columns(sql.clone(), schema, &options).map_err(|e| {
155 Error::internal(format!("Lineage qualification failed with schema: {}", e))
156 })?
157 } else {
158 sql.clone()
159 };
160
161 annotate_types(&mut qualified_expression, schema, dialect);
163
164 lineage_from_expression(column, &qualified_expression, dialect, trim_selects)
165}
166
167fn lineage_from_expression(
168 column: &str,
169 sql: &Expression,
170 dialect: Option<DialectType>,
171 trim_selects: bool,
172) -> Result<LineageNode> {
173 let scope = build_scope(sql);
174 to_node(
175 ColumnRef::Name(column),
176 &scope,
177 dialect,
178 "",
179 "",
180 "",
181 trim_selects,
182 )
183}
184
185pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
187 let mut tables = HashSet::new();
188 collect_source_tables(node, &mut tables);
189 tables
190}
191
192pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
194 if let Expression::Table(table) = &node.source {
195 tables.insert(table.name.name.clone());
196 }
197 for child in &node.downstream {
198 collect_source_tables(child, tables);
199 }
200}
201
202fn to_node(
208 column: ColumnRef<'_>,
209 scope: &Scope,
210 dialect: Option<DialectType>,
211 scope_name: &str,
212 source_name: &str,
213 reference_node_name: &str,
214 trim_selects: bool,
215) -> Result<LineageNode> {
216 to_node_inner(
217 column,
218 scope,
219 dialect,
220 scope_name,
221 source_name,
222 reference_node_name,
223 trim_selects,
224 &[],
225 )
226}
227
228fn to_node_inner(
229 column: ColumnRef<'_>,
230 scope: &Scope,
231 dialect: Option<DialectType>,
232 scope_name: &str,
233 source_name: &str,
234 reference_node_name: &str,
235 trim_selects: bool,
236 ancestor_cte_scopes: &[Scope],
237) -> Result<LineageNode> {
238 let scope_expr = &scope.expression;
239
240 let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
242 for s in ancestor_cte_scopes {
243 all_cte_scopes.push(s);
244 }
245
246 let effective_expr = match scope_expr {
249 Expression::Cte(cte) => &cte.this,
250 other => other,
251 };
252
253 if matches!(
255 effective_expr,
256 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
257 ) {
258 if matches!(scope_expr, Expression::Cte(_)) {
260 let mut inner_scope = Scope::new(effective_expr.clone());
261 inner_scope.union_scopes = scope.union_scopes.clone();
262 inner_scope.sources = scope.sources.clone();
263 inner_scope.cte_sources = scope.cte_sources.clone();
264 inner_scope.cte_scopes = scope.cte_scopes.clone();
265 inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
266 inner_scope.subquery_scopes = scope.subquery_scopes.clone();
267 return handle_set_operation(
268 &column,
269 &inner_scope,
270 dialect,
271 scope_name,
272 source_name,
273 reference_node_name,
274 trim_selects,
275 ancestor_cte_scopes,
276 );
277 }
278 return handle_set_operation(
279 &column,
280 scope,
281 dialect,
282 scope_name,
283 source_name,
284 reference_node_name,
285 trim_selects,
286 ancestor_cte_scopes,
287 );
288 }
289
290 let select_expr = find_select_expr(effective_expr, &column)?;
292 let column_name = resolve_column_name(&column, &select_expr);
293
294 let node_source = if trim_selects {
296 trim_source(effective_expr, &select_expr)
297 } else {
298 effective_expr.clone()
299 };
300
301 let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
303 node.source_name = source_name.to_string();
304 node.reference_node_name = reference_node_name.to_string();
305
306 if matches!(&select_expr, Expression::Star(_)) {
308 for (name, source_info) in &scope.sources {
309 let child = LineageNode::new(
310 format!("{}.*", name),
311 Expression::Star(crate::expressions::Star {
312 table: None,
313 except: None,
314 replace: None,
315 rename: None,
316 trailing_comments: vec![],
317 span: None,
318 }),
319 source_info.expression.clone(),
320 );
321 node.downstream.push(child);
322 }
323 return Ok(node);
324 }
325
326 let subqueries: Vec<&Expression> =
328 select_expr.find_all(|e| matches!(e, Expression::Subquery(sq) if sq.alias.is_none()));
329 for sq_expr in subqueries {
330 if let Expression::Subquery(sq) = sq_expr {
331 for sq_scope in &scope.subquery_scopes {
332 if sq_scope.expression == sq.this {
333 if let Ok(child) = to_node_inner(
334 ColumnRef::Index(0),
335 sq_scope,
336 dialect,
337 &column_name,
338 "",
339 "",
340 trim_selects,
341 ancestor_cte_scopes,
342 ) {
343 node.downstream.push(child);
344 }
345 break;
346 }
347 }
348 }
349 }
350
351 let col_refs = find_column_refs_in_expr(&select_expr);
353 for col_ref in col_refs {
354 let col_name = &col_ref.column;
355 if let Some(ref table_id) = col_ref.table {
356 let tbl = &table_id.name;
357 resolve_qualified_column(
358 &mut node,
359 scope,
360 dialect,
361 tbl,
362 col_name,
363 &column_name,
364 trim_selects,
365 &all_cte_scopes,
366 );
367 } else {
368 resolve_unqualified_column(
369 &mut node,
370 scope,
371 dialect,
372 col_name,
373 &column_name,
374 trim_selects,
375 &all_cte_scopes,
376 );
377 }
378 }
379
380 Ok(node)
381}
382
383fn handle_set_operation(
388 column: &ColumnRef<'_>,
389 scope: &Scope,
390 dialect: Option<DialectType>,
391 scope_name: &str,
392 source_name: &str,
393 reference_node_name: &str,
394 trim_selects: bool,
395 ancestor_cte_scopes: &[Scope],
396) -> Result<LineageNode> {
397 let scope_expr = &scope.expression;
398
399 let col_index = match column {
401 ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
402 ColumnRef::Index(i) => *i,
403 };
404
405 let col_name = match column {
406 ColumnRef::Name(name) => name.to_string(),
407 ColumnRef::Index(_) => format!("_{col_index}"),
408 };
409
410 let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
411 node.source_name = source_name.to_string();
412 node.reference_node_name = reference_node_name.to_string();
413
414 for branch_scope in &scope.union_scopes {
416 if let Ok(child) = to_node_inner(
417 ColumnRef::Index(col_index),
418 branch_scope,
419 dialect,
420 scope_name,
421 "",
422 "",
423 trim_selects,
424 ancestor_cte_scopes,
425 ) {
426 node.downstream.push(child);
427 }
428 }
429
430 Ok(node)
431}
432
433fn resolve_qualified_column(
438 node: &mut LineageNode,
439 scope: &Scope,
440 dialect: Option<DialectType>,
441 table: &str,
442 col_name: &str,
443 parent_name: &str,
444 trim_selects: bool,
445 all_cte_scopes: &[&Scope],
446) {
447 if scope.cte_sources.contains_key(table) {
449 if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
450 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
452 if let Ok(child) = to_node_inner(
453 ColumnRef::Name(col_name),
454 child_scope,
455 dialect,
456 parent_name,
457 table,
458 parent_name,
459 trim_selects,
460 &ancestors,
461 ) {
462 node.downstream.push(child);
463 return;
464 }
465 }
466 }
467
468 if let Some(source_info) = scope.sources.get(table) {
470 if source_info.is_scope {
471 if let Some(child_scope) = find_child_scope(scope, table) {
472 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
473 if let Ok(child) = to_node_inner(
474 ColumnRef::Name(col_name),
475 child_scope,
476 dialect,
477 parent_name,
478 table,
479 parent_name,
480 trim_selects,
481 &ancestors,
482 ) {
483 node.downstream.push(child);
484 return;
485 }
486 }
487 }
488 }
489
490 if let Some(source_info) = scope.sources.get(table) {
493 if !source_info.is_scope {
494 node.downstream.push(make_table_column_node_from_source(
495 table,
496 col_name,
497 &source_info.expression,
498 ));
499 return;
500 }
501 }
502
503 node.downstream
505 .push(make_table_column_node(table, col_name));
506}
507
508fn resolve_unqualified_column(
509 node: &mut LineageNode,
510 scope: &Scope,
511 dialect: Option<DialectType>,
512 col_name: &str,
513 parent_name: &str,
514 trim_selects: bool,
515 all_cte_scopes: &[&Scope],
516) {
517 let from_source_names = source_names_from_from_join(scope);
521
522 if from_source_names.len() == 1 {
523 let tbl = &from_source_names[0];
524 resolve_qualified_column(
525 node,
526 scope,
527 dialect,
528 tbl,
529 col_name,
530 parent_name,
531 trim_selects,
532 all_cte_scopes,
533 );
534 return;
535 }
536
537 let child = LineageNode::new(
539 col_name.to_string(),
540 Expression::Column(crate::expressions::Column {
541 name: crate::expressions::Identifier::new(col_name.to_string()),
542 table: None,
543 join_mark: false,
544 trailing_comments: vec![],
545 span: None,
546 inferred_type: None,
547 }),
548 node.source.clone(),
549 );
550 node.downstream.push(child);
551}
552
553fn source_names_from_from_join(scope: &Scope) -> Vec<String> {
554 fn source_name(expr: &Expression) -> Option<String> {
555 match expr {
556 Expression::Table(table) => Some(
557 table
558 .alias
559 .as_ref()
560 .map(|a| a.name.clone())
561 .unwrap_or_else(|| table.name.name.clone()),
562 ),
563 Expression::Subquery(subquery) => {
564 subquery.alias.as_ref().map(|alias| alias.name.clone())
565 }
566 Expression::Paren(paren) => source_name(&paren.this),
567 _ => None,
568 }
569 }
570
571 let effective_expr = match &scope.expression {
572 Expression::Cte(cte) => &cte.this,
573 expr => expr,
574 };
575
576 let mut names = Vec::new();
577 let mut seen = std::collections::HashSet::new();
578
579 if let Expression::Select(select) = effective_expr {
580 if let Some(from) = &select.from {
581 for expr in &from.expressions {
582 if let Some(name) = source_name(expr) {
583 if !name.is_empty() && seen.insert(name.clone()) {
584 names.push(name);
585 }
586 }
587 }
588 }
589 for join in &select.joins {
590 if let Some(name) = source_name(&join.this) {
591 if !name.is_empty() && seen.insert(name.clone()) {
592 names.push(name);
593 }
594 }
595 }
596 }
597
598 names
599}
600
601fn get_alias_or_name(expr: &Expression) -> Option<String> {
607 match expr {
608 Expression::Alias(alias) => Some(alias.alias.name.clone()),
609 Expression::Column(col) => Some(col.name.name.clone()),
610 Expression::Identifier(id) => Some(id.name.clone()),
611 Expression::Star(_) => Some("*".to_string()),
612 _ => None,
613 }
614}
615
616fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
618 match column {
619 ColumnRef::Name(n) => n.to_string(),
620 ColumnRef::Index(_) => get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string()),
621 }
622}
623
624fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
626 if let Expression::Select(ref select) = scope_expr {
627 match column {
628 ColumnRef::Name(name) => {
629 for expr in &select.expressions {
630 if get_alias_or_name(expr).as_deref() == Some(name) {
631 return Ok(expr.clone());
632 }
633 }
634 Err(crate::error::Error::parse(
635 format!("Cannot find column '{}' in query", name),
636 0,
637 0,
638 0,
639 0,
640 ))
641 }
642 ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
643 crate::error::Error::parse(format!("Column index {} out of range", idx), 0, 0, 0, 0)
644 }),
645 }
646 } else {
647 Err(crate::error::Error::parse(
648 "Expected SELECT expression for column lookup",
649 0,
650 0,
651 0,
652 0,
653 ))
654 }
655}
656
657fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
659 let mut expr = set_op_expr;
660 loop {
661 match expr {
662 Expression::Union(u) => expr = &u.left,
663 Expression::Intersect(i) => expr = &i.left,
664 Expression::Except(e) => expr = &e.left,
665 Expression::Select(select) => {
666 for (i, e) in select.expressions.iter().enumerate() {
667 if get_alias_or_name(e).as_deref() == Some(name) {
668 return Ok(i);
669 }
670 }
671 return Err(crate::error::Error::parse(
672 format!("Cannot find column '{}' in set operation", name),
673 0,
674 0,
675 0,
676 0,
677 ));
678 }
679 _ => {
680 return Err(crate::error::Error::parse(
681 "Expected SELECT or set operation",
682 0,
683 0,
684 0,
685 0,
686 ))
687 }
688 }
689 }
690}
691
692fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
694 if let Expression::Select(select) = select_expr {
695 let mut trimmed = select.as_ref().clone();
696 trimmed.expressions = vec![target_expr.clone()];
697 Expression::Select(Box::new(trimmed))
698 } else {
699 select_expr.clone()
700 }
701}
702
703fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
705 if scope.cte_sources.contains_key(source_name) {
707 for cte_scope in &scope.cte_scopes {
708 if let Expression::Cte(cte) = &cte_scope.expression {
709 if cte.alias.name == source_name {
710 return Some(cte_scope);
711 }
712 }
713 }
714 }
715
716 if let Some(source_info) = scope.sources.get(source_name) {
718 if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
719 if let Expression::Subquery(sq) = &source_info.expression {
720 for dt_scope in &scope.derived_table_scopes {
721 if dt_scope.expression == sq.this {
722 return Some(dt_scope);
723 }
724 }
725 }
726 }
727 }
728
729 None
730}
731
732fn find_child_scope_in<'a>(
736 all_cte_scopes: &[&'a Scope],
737 scope: &'a Scope,
738 source_name: &str,
739) -> Option<&'a Scope> {
740 for cte_scope in &scope.cte_scopes {
742 if let Expression::Cte(cte) = &cte_scope.expression {
743 if cte.alias.name == source_name {
744 return Some(cte_scope);
745 }
746 }
747 }
748
749 for cte_scope in all_cte_scopes {
751 if let Expression::Cte(cte) = &cte_scope.expression {
752 if cte.alias.name == source_name {
753 return Some(cte_scope);
754 }
755 }
756 }
757
758 if let Some(source_info) = scope.sources.get(source_name) {
760 if source_info.is_scope {
761 if let Expression::Subquery(sq) = &source_info.expression {
762 for dt_scope in &scope.derived_table_scopes {
763 if dt_scope.expression == sq.this {
764 return Some(dt_scope);
765 }
766 }
767 }
768 }
769 }
770
771 None
772}
773
774fn make_table_column_node(table: &str, column: &str) -> LineageNode {
776 let mut node = LineageNode::new(
777 format!("{}.{}", table, column),
778 Expression::Column(crate::expressions::Column {
779 name: crate::expressions::Identifier::new(column.to_string()),
780 table: Some(crate::expressions::Identifier::new(table.to_string())),
781 join_mark: false,
782 trailing_comments: vec![],
783 span: None,
784 inferred_type: None,
785 }),
786 Expression::Table(crate::expressions::TableRef::new(table)),
787 );
788 node.source_name = table.to_string();
789 node
790}
791
792fn table_name_from_table_ref(table_ref: &crate::expressions::TableRef) -> String {
793 let mut parts: Vec<String> = Vec::new();
794 if let Some(catalog) = &table_ref.catalog {
795 parts.push(catalog.name.clone());
796 }
797 if let Some(schema) = &table_ref.schema {
798 parts.push(schema.name.clone());
799 }
800 parts.push(table_ref.name.name.clone());
801 parts.join(".")
802}
803
804fn make_table_column_node_from_source(
805 table_alias: &str,
806 column: &str,
807 source: &Expression,
808) -> LineageNode {
809 let mut node = LineageNode::new(
810 format!("{}.{}", table_alias, column),
811 Expression::Column(crate::expressions::Column {
812 name: crate::expressions::Identifier::new(column.to_string()),
813 table: Some(crate::expressions::Identifier::new(table_alias.to_string())),
814 join_mark: false,
815 trailing_comments: vec![],
816 span: None,
817 inferred_type: None,
818 }),
819 source.clone(),
820 );
821
822 if let Expression::Table(table_ref) = source {
823 node.source_name = table_name_from_table_ref(table_ref);
824 } else {
825 node.source_name = table_alias.to_string();
826 }
827
828 node
829}
830
/// A column reference collected from an expression: the column name plus its
/// optional table qualifier.
#[derive(Debug, Clone)]
struct SimpleColumnRef {
    /// Table qualifier of the reference, if it had one.
    table: Option<crate::expressions::Identifier>,
    /// Name of the referenced column.
    column: String,
}
837
838fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
840 let mut refs = Vec::new();
841 collect_column_refs(expr, &mut refs);
842 refs
843}
844
845fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
846 let mut stack: Vec<&Expression> = vec![expr];
847
848 while let Some(current) = stack.pop() {
849 match current {
850 Expression::Column(col) => {
852 refs.push(SimpleColumnRef {
853 table: col.table.clone(),
854 column: col.name.name.clone(),
855 });
856 }
857
858 Expression::Subquery(_) | Expression::Exists(_) => {}
860
861 Expression::And(op)
863 | Expression::Or(op)
864 | Expression::Eq(op)
865 | Expression::Neq(op)
866 | Expression::Lt(op)
867 | Expression::Lte(op)
868 | Expression::Gt(op)
869 | Expression::Gte(op)
870 | Expression::Add(op)
871 | Expression::Sub(op)
872 | Expression::Mul(op)
873 | Expression::Div(op)
874 | Expression::Mod(op)
875 | Expression::BitwiseAnd(op)
876 | Expression::BitwiseOr(op)
877 | Expression::BitwiseXor(op)
878 | Expression::BitwiseLeftShift(op)
879 | Expression::BitwiseRightShift(op)
880 | Expression::Concat(op)
881 | Expression::Adjacent(op)
882 | Expression::TsMatch(op)
883 | Expression::PropertyEQ(op)
884 | Expression::ArrayContainsAll(op)
885 | Expression::ArrayContainedBy(op)
886 | Expression::ArrayOverlaps(op)
887 | Expression::JSONBContainsAllTopKeys(op)
888 | Expression::JSONBContainsAnyTopKeys(op)
889 | Expression::JSONBDeleteAtPath(op)
890 | Expression::ExtendsLeft(op)
891 | Expression::ExtendsRight(op)
892 | Expression::Is(op)
893 | Expression::MemberOf(op)
894 | Expression::NullSafeEq(op)
895 | Expression::NullSafeNeq(op)
896 | Expression::Glob(op)
897 | Expression::Match(op) => {
898 stack.push(&op.left);
899 stack.push(&op.right);
900 }
901
902 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
904 stack.push(&u.this);
905 }
906
907 Expression::Upper(f)
909 | Expression::Lower(f)
910 | Expression::Length(f)
911 | Expression::LTrim(f)
912 | Expression::RTrim(f)
913 | Expression::Reverse(f)
914 | Expression::Abs(f)
915 | Expression::Sqrt(f)
916 | Expression::Cbrt(f)
917 | Expression::Ln(f)
918 | Expression::Exp(f)
919 | Expression::Sign(f)
920 | Expression::Date(f)
921 | Expression::Time(f)
922 | Expression::DateFromUnixDate(f)
923 | Expression::UnixDate(f)
924 | Expression::UnixSeconds(f)
925 | Expression::UnixMillis(f)
926 | Expression::UnixMicros(f)
927 | Expression::TimeStrToDate(f)
928 | Expression::DateToDi(f)
929 | Expression::DiToDate(f)
930 | Expression::TsOrDiToDi(f)
931 | Expression::TsOrDsToDatetime(f)
932 | Expression::TsOrDsToTimestamp(f)
933 | Expression::YearOfWeek(f)
934 | Expression::YearOfWeekIso(f)
935 | Expression::Initcap(f)
936 | Expression::Ascii(f)
937 | Expression::Chr(f)
938 | Expression::Soundex(f)
939 | Expression::ByteLength(f)
940 | Expression::Hex(f)
941 | Expression::LowerHex(f)
942 | Expression::Unicode(f)
943 | Expression::Radians(f)
944 | Expression::Degrees(f)
945 | Expression::Sin(f)
946 | Expression::Cos(f)
947 | Expression::Tan(f)
948 | Expression::Asin(f)
949 | Expression::Acos(f)
950 | Expression::Atan(f)
951 | Expression::IsNan(f)
952 | Expression::IsInf(f)
953 | Expression::ArrayLength(f)
954 | Expression::ArraySize(f)
955 | Expression::Cardinality(f)
956 | Expression::ArrayReverse(f)
957 | Expression::ArrayDistinct(f)
958 | Expression::ArrayFlatten(f)
959 | Expression::ArrayCompact(f)
960 | Expression::Explode(f)
961 | Expression::ExplodeOuter(f)
962 | Expression::ToArray(f)
963 | Expression::MapFromEntries(f)
964 | Expression::MapKeys(f)
965 | Expression::MapValues(f)
966 | Expression::JsonArrayLength(f)
967 | Expression::JsonKeys(f)
968 | Expression::JsonType(f)
969 | Expression::ParseJson(f)
970 | Expression::ToJson(f)
971 | Expression::Typeof(f)
972 | Expression::BitwiseCount(f)
973 | Expression::Year(f)
974 | Expression::Month(f)
975 | Expression::Day(f)
976 | Expression::Hour(f)
977 | Expression::Minute(f)
978 | Expression::Second(f)
979 | Expression::DayOfWeek(f)
980 | Expression::DayOfWeekIso(f)
981 | Expression::DayOfMonth(f)
982 | Expression::DayOfYear(f)
983 | Expression::WeekOfYear(f)
984 | Expression::Quarter(f)
985 | Expression::Epoch(f)
986 | Expression::EpochMs(f)
987 | Expression::TimeStrToUnix(f)
988 | Expression::SHA(f)
989 | Expression::SHA1Digest(f)
990 | Expression::TimeToUnix(f)
991 | Expression::JSONBool(f)
992 | Expression::Int64(f)
993 | Expression::MD5NumberLower64(f)
994 | Expression::MD5NumberUpper64(f)
995 | Expression::DateStrToDate(f)
996 | Expression::DateToDateStr(f) => {
997 stack.push(&f.this);
998 }
999
1000 Expression::Power(f)
1002 | Expression::NullIf(f)
1003 | Expression::IfNull(f)
1004 | Expression::Nvl(f)
1005 | Expression::UnixToTimeStr(f)
1006 | Expression::Contains(f)
1007 | Expression::StartsWith(f)
1008 | Expression::EndsWith(f)
1009 | Expression::Levenshtein(f)
1010 | Expression::ModFunc(f)
1011 | Expression::Atan2(f)
1012 | Expression::IntDiv(f)
1013 | Expression::AddMonths(f)
1014 | Expression::MonthsBetween(f)
1015 | Expression::NextDay(f)
1016 | Expression::ArrayContains(f)
1017 | Expression::ArrayPosition(f)
1018 | Expression::ArrayAppend(f)
1019 | Expression::ArrayPrepend(f)
1020 | Expression::ArrayUnion(f)
1021 | Expression::ArrayExcept(f)
1022 | Expression::ArrayRemove(f)
1023 | Expression::StarMap(f)
1024 | Expression::MapFromArrays(f)
1025 | Expression::MapContainsKey(f)
1026 | Expression::ElementAt(f)
1027 | Expression::JsonMergePatch(f)
1028 | Expression::JSONBContains(f)
1029 | Expression::JSONBExtract(f) => {
1030 stack.push(&f.this);
1031 stack.push(&f.expression);
1032 }
1033
1034 Expression::Greatest(f)
1036 | Expression::Least(f)
1037 | Expression::Coalesce(f)
1038 | Expression::ArrayConcat(f)
1039 | Expression::ArrayIntersect(f)
1040 | Expression::ArrayZip(f)
1041 | Expression::MapConcat(f)
1042 | Expression::JsonArray(f) => {
1043 for e in &f.expressions {
1044 stack.push(e);
1045 }
1046 }
1047
1048 Expression::Sum(f)
1050 | Expression::Avg(f)
1051 | Expression::Min(f)
1052 | Expression::Max(f)
1053 | Expression::ArrayAgg(f)
1054 | Expression::CountIf(f)
1055 | Expression::Stddev(f)
1056 | Expression::StddevPop(f)
1057 | Expression::StddevSamp(f)
1058 | Expression::Variance(f)
1059 | Expression::VarPop(f)
1060 | Expression::VarSamp(f)
1061 | Expression::Median(f)
1062 | Expression::Mode(f)
1063 | Expression::First(f)
1064 | Expression::Last(f)
1065 | Expression::AnyValue(f)
1066 | Expression::ApproxDistinct(f)
1067 | Expression::ApproxCountDistinct(f)
1068 | Expression::LogicalAnd(f)
1069 | Expression::LogicalOr(f)
1070 | Expression::Skewness(f)
1071 | Expression::ArrayConcatAgg(f)
1072 | Expression::ArrayUniqueAgg(f)
1073 | Expression::BoolXorAgg(f)
1074 | Expression::BitwiseAndAgg(f)
1075 | Expression::BitwiseOrAgg(f)
1076 | Expression::BitwiseXorAgg(f) => {
1077 stack.push(&f.this);
1078 if let Some(ref filter) = f.filter {
1079 stack.push(filter);
1080 }
1081 if let Some((ref expr, _)) = f.having_max {
1082 stack.push(expr);
1083 }
1084 if let Some(ref limit) = f.limit {
1085 stack.push(limit);
1086 }
1087 }
1088
1089 Expression::Function(func) => {
1091 for arg in &func.args {
1092 stack.push(arg);
1093 }
1094 }
1095 Expression::AggregateFunction(func) => {
1096 for arg in &func.args {
1097 stack.push(arg);
1098 }
1099 if let Some(ref filter) = func.filter {
1100 stack.push(filter);
1101 }
1102 if let Some(ref limit) = func.limit {
1103 stack.push(limit);
1104 }
1105 }
1106
1107 Expression::WindowFunction(wf) => {
1109 stack.push(&wf.this);
1110 }
1111
1112 Expression::Alias(a) => {
1114 stack.push(&a.this);
1115 }
1116 Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
1117 stack.push(&c.this);
1118 if let Some(ref fmt) = c.format {
1119 stack.push(fmt);
1120 }
1121 if let Some(ref def) = c.default {
1122 stack.push(def);
1123 }
1124 }
1125 Expression::Paren(p) => {
1126 stack.push(&p.this);
1127 }
1128 Expression::Annotated(a) => {
1129 stack.push(&a.this);
1130 }
1131 Expression::Case(case) => {
1132 if let Some(ref operand) = case.operand {
1133 stack.push(operand);
1134 }
1135 for (cond, result) in &case.whens {
1136 stack.push(cond);
1137 stack.push(result);
1138 }
1139 if let Some(ref else_expr) = case.else_ {
1140 stack.push(else_expr);
1141 }
1142 }
1143 Expression::Collation(c) => {
1144 stack.push(&c.this);
1145 }
1146 Expression::In(i) => {
1147 stack.push(&i.this);
1148 for e in &i.expressions {
1149 stack.push(e);
1150 }
1151 if let Some(ref q) = i.query {
1152 stack.push(q);
1153 }
1154 if let Some(ref u) = i.unnest {
1155 stack.push(u);
1156 }
1157 }
1158 Expression::Between(b) => {
1159 stack.push(&b.this);
1160 stack.push(&b.low);
1161 stack.push(&b.high);
1162 }
1163 Expression::IsNull(n) => {
1164 stack.push(&n.this);
1165 }
1166 Expression::IsTrue(t) | Expression::IsFalse(t) => {
1167 stack.push(&t.this);
1168 }
1169 Expression::IsJson(j) => {
1170 stack.push(&j.this);
1171 }
1172 Expression::Like(l) | Expression::ILike(l) => {
1173 stack.push(&l.left);
1174 stack.push(&l.right);
1175 if let Some(ref esc) = l.escape {
1176 stack.push(esc);
1177 }
1178 }
1179 Expression::SimilarTo(s) => {
1180 stack.push(&s.this);
1181 stack.push(&s.pattern);
1182 if let Some(ref esc) = s.escape {
1183 stack.push(esc);
1184 }
1185 }
1186 Expression::Ordered(o) => {
1187 stack.push(&o.this);
1188 }
1189 Expression::Array(a) => {
1190 for e in &a.expressions {
1191 stack.push(e);
1192 }
1193 }
1194 Expression::Tuple(t) => {
1195 for e in &t.expressions {
1196 stack.push(e);
1197 }
1198 }
1199 Expression::Struct(s) => {
1200 for (_, e) in &s.fields {
1201 stack.push(e);
1202 }
1203 }
1204 Expression::Subscript(s) => {
1205 stack.push(&s.this);
1206 stack.push(&s.index);
1207 }
1208 Expression::Dot(d) => {
1209 stack.push(&d.this);
1210 }
1211 Expression::MethodCall(m) => {
1212 stack.push(&m.this);
1213 for arg in &m.args {
1214 stack.push(arg);
1215 }
1216 }
1217 Expression::ArraySlice(s) => {
1218 stack.push(&s.this);
1219 if let Some(ref start) = s.start {
1220 stack.push(start);
1221 }
1222 if let Some(ref end) = s.end {
1223 stack.push(end);
1224 }
1225 }
1226 Expression::Lambda(l) => {
1227 stack.push(&l.body);
1228 }
1229 Expression::NamedArgument(n) => {
1230 stack.push(&n.value);
1231 }
1232 Expression::BracedWildcard(e) | Expression::ReturnStmt(e) => {
1233 stack.push(e);
1234 }
1235
1236 Expression::Substring(f) => {
1238 stack.push(&f.this);
1239 stack.push(&f.start);
1240 if let Some(ref len) = f.length {
1241 stack.push(len);
1242 }
1243 }
1244 Expression::Trim(f) => {
1245 stack.push(&f.this);
1246 if let Some(ref chars) = f.characters {
1247 stack.push(chars);
1248 }
1249 }
1250 Expression::Replace(f) => {
1251 stack.push(&f.this);
1252 stack.push(&f.old);
1253 stack.push(&f.new);
1254 }
1255 Expression::IfFunc(f) => {
1256 stack.push(&f.condition);
1257 stack.push(&f.true_value);
1258 if let Some(ref fv) = f.false_value {
1259 stack.push(fv);
1260 }
1261 }
1262 Expression::Nvl2(f) => {
1263 stack.push(&f.this);
1264 stack.push(&f.true_value);
1265 stack.push(&f.false_value);
1266 }
1267 Expression::ConcatWs(f) => {
1268 stack.push(&f.separator);
1269 for e in &f.expressions {
1270 stack.push(e);
1271 }
1272 }
1273 Expression::Count(f) => {
1274 if let Some(ref this) = f.this {
1275 stack.push(this);
1276 }
1277 if let Some(ref filter) = f.filter {
1278 stack.push(filter);
1279 }
1280 }
1281 Expression::GroupConcat(f) => {
1282 stack.push(&f.this);
1283 if let Some(ref sep) = f.separator {
1284 stack.push(sep);
1285 }
1286 if let Some(ref filter) = f.filter {
1287 stack.push(filter);
1288 }
1289 }
1290 Expression::StringAgg(f) => {
1291 stack.push(&f.this);
1292 if let Some(ref sep) = f.separator {
1293 stack.push(sep);
1294 }
1295 if let Some(ref filter) = f.filter {
1296 stack.push(filter);
1297 }
1298 if let Some(ref limit) = f.limit {
1299 stack.push(limit);
1300 }
1301 }
1302 Expression::ListAgg(f) => {
1303 stack.push(&f.this);
1304 if let Some(ref sep) = f.separator {
1305 stack.push(sep);
1306 }
1307 if let Some(ref filter) = f.filter {
1308 stack.push(filter);
1309 }
1310 }
1311 Expression::SumIf(f) => {
1312 stack.push(&f.this);
1313 stack.push(&f.condition);
1314 if let Some(ref filter) = f.filter {
1315 stack.push(filter);
1316 }
1317 }
1318 Expression::DateAdd(f) | Expression::DateSub(f) => {
1319 stack.push(&f.this);
1320 stack.push(&f.interval);
1321 }
1322 Expression::DateDiff(f) => {
1323 stack.push(&f.this);
1324 stack.push(&f.expression);
1325 }
1326 Expression::DateTrunc(f) | Expression::TimestampTrunc(f) => {
1327 stack.push(&f.this);
1328 }
1329 Expression::Extract(f) => {
1330 stack.push(&f.this);
1331 }
1332 Expression::Round(f) => {
1333 stack.push(&f.this);
1334 if let Some(ref d) = f.decimals {
1335 stack.push(d);
1336 }
1337 }
1338 Expression::Floor(f) => {
1339 stack.push(&f.this);
1340 if let Some(ref s) = f.scale {
1341 stack.push(s);
1342 }
1343 if let Some(ref t) = f.to {
1344 stack.push(t);
1345 }
1346 }
1347 Expression::Ceil(f) => {
1348 stack.push(&f.this);
1349 if let Some(ref d) = f.decimals {
1350 stack.push(d);
1351 }
1352 if let Some(ref t) = f.to {
1353 stack.push(t);
1354 }
1355 }
1356 Expression::Log(f) => {
1357 stack.push(&f.this);
1358 if let Some(ref b) = f.base {
1359 stack.push(b);
1360 }
1361 }
1362 Expression::AtTimeZone(f) => {
1363 stack.push(&f.this);
1364 stack.push(&f.zone);
1365 }
1366 Expression::Lead(f) | Expression::Lag(f) => {
1367 stack.push(&f.this);
1368 if let Some(ref off) = f.offset {
1369 stack.push(off);
1370 }
1371 if let Some(ref def) = f.default {
1372 stack.push(def);
1373 }
1374 }
1375 Expression::FirstValue(f) | Expression::LastValue(f) => {
1376 stack.push(&f.this);
1377 }
1378 Expression::NthValue(f) => {
1379 stack.push(&f.this);
1380 stack.push(&f.offset);
1381 }
1382 Expression::Position(f) => {
1383 stack.push(&f.substring);
1384 stack.push(&f.string);
1385 if let Some(ref start) = f.start {
1386 stack.push(start);
1387 }
1388 }
1389 Expression::Decode(f) => {
1390 stack.push(&f.this);
1391 for (search, result) in &f.search_results {
1392 stack.push(search);
1393 stack.push(result);
1394 }
1395 if let Some(ref def) = f.default {
1396 stack.push(def);
1397 }
1398 }
1399 Expression::CharFunc(f) => {
1400 for arg in &f.args {
1401 stack.push(arg);
1402 }
1403 }
1404 Expression::ArraySort(f) => {
1405 stack.push(&f.this);
1406 if let Some(ref cmp) = f.comparator {
1407 stack.push(cmp);
1408 }
1409 }
1410 Expression::ArrayJoin(f) | Expression::ArrayToString(f) => {
1411 stack.push(&f.this);
1412 stack.push(&f.separator);
1413 if let Some(ref nr) = f.null_replacement {
1414 stack.push(nr);
1415 }
1416 }
1417 Expression::ArrayFilter(f) => {
1418 stack.push(&f.this);
1419 stack.push(&f.filter);
1420 }
1421 Expression::ArrayTransform(f) => {
1422 stack.push(&f.this);
1423 stack.push(&f.transform);
1424 }
1425 Expression::Sequence(f)
1426 | Expression::Generate(f)
1427 | Expression::ExplodingGenerateSeries(f) => {
1428 stack.push(&f.start);
1429 stack.push(&f.stop);
1430 if let Some(ref step) = f.step {
1431 stack.push(step);
1432 }
1433 }
1434 Expression::JsonExtract(f)
1435 | Expression::JsonExtractScalar(f)
1436 | Expression::JsonQuery(f)
1437 | Expression::JsonValue(f) => {
1438 stack.push(&f.this);
1439 stack.push(&f.path);
1440 }
1441 Expression::JsonExtractPath(f) | Expression::JsonRemove(f) => {
1442 stack.push(&f.this);
1443 for p in &f.paths {
1444 stack.push(p);
1445 }
1446 }
1447 Expression::JsonObject(f) => {
1448 for (k, v) in &f.pairs {
1449 stack.push(k);
1450 stack.push(v);
1451 }
1452 }
1453 Expression::JsonSet(f) | Expression::JsonInsert(f) => {
1454 stack.push(&f.this);
1455 for (path, val) in &f.path_values {
1456 stack.push(path);
1457 stack.push(val);
1458 }
1459 }
1460 Expression::Overlay(f) => {
1461 stack.push(&f.this);
1462 stack.push(&f.replacement);
1463 stack.push(&f.from);
1464 if let Some(ref len) = f.length {
1465 stack.push(len);
1466 }
1467 }
1468 Expression::Convert(f) => {
1469 stack.push(&f.this);
1470 if let Some(ref style) = f.style {
1471 stack.push(style);
1472 }
1473 }
1474 Expression::ApproxPercentile(f) => {
1475 stack.push(&f.this);
1476 stack.push(&f.percentile);
1477 if let Some(ref acc) = f.accuracy {
1478 stack.push(acc);
1479 }
1480 if let Some(ref filter) = f.filter {
1481 stack.push(filter);
1482 }
1483 }
1484 Expression::Percentile(f)
1485 | Expression::PercentileCont(f)
1486 | Expression::PercentileDisc(f) => {
1487 stack.push(&f.this);
1488 stack.push(&f.percentile);
1489 if let Some(ref filter) = f.filter {
1490 stack.push(filter);
1491 }
1492 }
1493 Expression::WithinGroup(f) => {
1494 stack.push(&f.this);
1495 }
1496 Expression::Left(f) | Expression::Right(f) => {
1497 stack.push(&f.this);
1498 stack.push(&f.length);
1499 }
1500 Expression::Repeat(f) => {
1501 stack.push(&f.this);
1502 stack.push(&f.times);
1503 }
1504 Expression::Lpad(f) | Expression::Rpad(f) => {
1505 stack.push(&f.this);
1506 stack.push(&f.length);
1507 if let Some(ref fill) = f.fill {
1508 stack.push(fill);
1509 }
1510 }
1511 Expression::Split(f) => {
1512 stack.push(&f.this);
1513 stack.push(&f.delimiter);
1514 }
1515 Expression::RegexpLike(f) => {
1516 stack.push(&f.this);
1517 stack.push(&f.pattern);
1518 if let Some(ref flags) = f.flags {
1519 stack.push(flags);
1520 }
1521 }
1522 Expression::RegexpReplace(f) => {
1523 stack.push(&f.this);
1524 stack.push(&f.pattern);
1525 stack.push(&f.replacement);
1526 if let Some(ref flags) = f.flags {
1527 stack.push(flags);
1528 }
1529 }
1530 Expression::RegexpExtract(f) => {
1531 stack.push(&f.this);
1532 stack.push(&f.pattern);
1533 if let Some(ref group) = f.group {
1534 stack.push(group);
1535 }
1536 }
1537 Expression::ToDate(f) => {
1538 stack.push(&f.this);
1539 if let Some(ref fmt) = f.format {
1540 stack.push(fmt);
1541 }
1542 }
1543 Expression::ToTimestamp(f) => {
1544 stack.push(&f.this);
1545 if let Some(ref fmt) = f.format {
1546 stack.push(fmt);
1547 }
1548 }
1549 Expression::DateFormat(f) | Expression::FormatDate(f) => {
1550 stack.push(&f.this);
1551 stack.push(&f.format);
1552 }
1553 Expression::LastDay(f) => {
1554 stack.push(&f.this);
1555 }
1556 Expression::FromUnixtime(f) => {
1557 stack.push(&f.this);
1558 if let Some(ref fmt) = f.format {
1559 stack.push(fmt);
1560 }
1561 }
1562 Expression::UnixTimestamp(f) => {
1563 if let Some(ref this) = f.this {
1564 stack.push(this);
1565 }
1566 if let Some(ref fmt) = f.format {
1567 stack.push(fmt);
1568 }
1569 }
1570 Expression::MakeDate(f) => {
1571 stack.push(&f.year);
1572 stack.push(&f.month);
1573 stack.push(&f.day);
1574 }
1575 Expression::MakeTimestamp(f) => {
1576 stack.push(&f.year);
1577 stack.push(&f.month);
1578 stack.push(&f.day);
1579 stack.push(&f.hour);
1580 stack.push(&f.minute);
1581 stack.push(&f.second);
1582 if let Some(ref tz) = f.timezone {
1583 stack.push(tz);
1584 }
1585 }
1586 Expression::TruncFunc(f) => {
1587 stack.push(&f.this);
1588 if let Some(ref d) = f.decimals {
1589 stack.push(d);
1590 }
1591 }
1592 Expression::ArrayFunc(f) => {
1593 for e in &f.expressions {
1594 stack.push(e);
1595 }
1596 }
1597 Expression::Unnest(f) => {
1598 stack.push(&f.this);
1599 for e in &f.expressions {
1600 stack.push(e);
1601 }
1602 }
1603 Expression::StructFunc(f) => {
1604 for (_, e) in &f.fields {
1605 stack.push(e);
1606 }
1607 }
1608 Expression::StructExtract(f) => {
1609 stack.push(&f.this);
1610 }
1611 Expression::NamedStruct(f) => {
1612 for (k, v) in &f.pairs {
1613 stack.push(k);
1614 stack.push(v);
1615 }
1616 }
1617 Expression::MapFunc(f) => {
1618 for k in &f.keys {
1619 stack.push(k);
1620 }
1621 for v in &f.values {
1622 stack.push(v);
1623 }
1624 }
1625 Expression::TransformKeys(f) | Expression::TransformValues(f) => {
1626 stack.push(&f.this);
1627 stack.push(&f.transform);
1628 }
1629 Expression::JsonArrayAgg(f) => {
1630 stack.push(&f.this);
1631 if let Some(ref filter) = f.filter {
1632 stack.push(filter);
1633 }
1634 }
1635 Expression::JsonObjectAgg(f) => {
1636 stack.push(&f.key);
1637 stack.push(&f.value);
1638 if let Some(ref filter) = f.filter {
1639 stack.push(filter);
1640 }
1641 }
1642 Expression::NTile(f) => {
1643 if let Some(ref n) = f.num_buckets {
1644 stack.push(n);
1645 }
1646 }
1647 Expression::Rand(f) => {
1648 if let Some(ref s) = f.seed {
1649 stack.push(s);
1650 }
1651 if let Some(ref lo) = f.lower {
1652 stack.push(lo);
1653 }
1654 if let Some(ref hi) = f.upper {
1655 stack.push(hi);
1656 }
1657 }
1658 Expression::Any(q) | Expression::All(q) => {
1659 stack.push(&q.this);
1660 stack.push(&q.subquery);
1661 }
1662 Expression::Overlaps(o) => {
1663 if let Some(ref this) = o.this {
1664 stack.push(this);
1665 }
1666 if let Some(ref expr) = o.expression {
1667 stack.push(expr);
1668 }
1669 if let Some(ref ls) = o.left_start {
1670 stack.push(ls);
1671 }
1672 if let Some(ref le) = o.left_end {
1673 stack.push(le);
1674 }
1675 if let Some(ref rs) = o.right_start {
1676 stack.push(rs);
1677 }
1678 if let Some(ref re) = o.right_end {
1679 stack.push(re);
1680 }
1681 }
1682 Expression::Interval(i) => {
1683 if let Some(ref this) = i.this {
1684 stack.push(this);
1685 }
1686 }
1687 Expression::TimeStrToTime(f) => {
1688 stack.push(&f.this);
1689 if let Some(ref zone) = f.zone {
1690 stack.push(zone);
1691 }
1692 }
1693 Expression::JSONBExtractScalar(f) => {
1694 stack.push(&f.this);
1695 stack.push(&f.expression);
1696 if let Some(ref jt) = f.json_type {
1697 stack.push(jt);
1698 }
1699 }
1700
1701 _ => {}
1706 }
1707 }
1708}
1709
1710#[cfg(test)]
1715mod tests {
1716 use super::*;
1717 use crate::dialects::{Dialect, DialectType};
1718 use crate::expressions::DataType;
1719 use crate::optimizer::annotate_types::annotate_types;
1720 use crate::schema::{MappingSchema, Schema};
1721
1722 fn parse(sql: &str) -> Expression {
1723 let dialect = Dialect::get(DialectType::Generic);
1724 let ast = dialect.parse(sql).unwrap();
1725 ast.into_iter().next().unwrap()
1726 }
1727
/// `SELECT a FROM t`: the root node is named `a` and has `t.a` among its
/// direct downstream nodes.
#[test]
fn test_simple_lineage() {
    let query = parse("SELECT a FROM t");
    let root = lineage("a", &query, None, false).unwrap();

    assert_eq!(root.name, "a");
    assert!(!root.downstream.is_empty(), "Should have downstream nodes");

    let upstream = root.downstream_names();
    let found = upstream.iter().any(|name| name.as_str() == "t.a");
    assert!(found, "Expected t.a in downstream, got: {:?}", upstream);
}
1743
/// `LineageNode::walk` visits the root first and then its single child.
#[test]
fn test_lineage_walk() {
    // Placeholder expression: this test only inspects node names.
    let null = || Expression::Null(crate::expressions::Null);

    // `LineageNode::new` leaves `source_name` and `reference_node_name` empty,
    // so only the child has to be attached by hand.
    let mut root = LineageNode::new("col_a", null(), null());
    root.downstream.push(LineageNode::new("t.a", null(), null()));

    let visited: Vec<&str> = root.walk().map(|n| n.name.as_str()).collect();
    assert_eq!(visited.len(), 2);
    assert_eq!(visited[0], "col_a");
    assert_eq!(visited[1], "t.a");
}
1764
/// `SELECT a + 1 AS b FROM t`: the root is named after the alias `b`, and some
/// node in the lineage tree has a name containing `a`.
#[test]
fn test_aliased_column() {
    let query = parse("SELECT a + 1 AS b FROM t");
    let root = lineage("b", &query, None, false).unwrap();

    assert_eq!(root.name, "b");

    let visited: Vec<&str> = root.walk().map(|n| n.name.as_str()).collect();
    let mentions_a = visited.iter().any(|name| name.contains("a"));
    assert!(
        mentions_a,
        "Expected to trace to column a, got: {:?}",
        visited
    );
}
1779
/// A table-qualified select item (`t.a`) is looked up by its bare column name
/// and reported as `t.a` downstream.
#[test]
fn test_qualified_column() {
    let query = parse("SELECT t.a FROM t");
    let root = lineage("a", &query, None, false).unwrap();

    assert_eq!(root.name, "a");

    let upstream = root.downstream_names();
    let found = upstream.iter().map(String::as_str).any(|name| name == "t.a");
    assert!(found, "Expected t.a, got: {:?}", upstream);
}
1793
/// An unqualified select item (`a`) is reported downstream as `t.a`.
#[test]
fn test_unqualified_column() {
    let query = parse("SELECT a FROM t");
    let upstream = lineage("a", &query, None, false)
        .unwrap()
        .downstream_names();

    let found = upstream.iter().any(|name| name.as_str() == "t.a");
    assert!(found, "Expected t.a, got: {:?}", upstream);
}
1807
/// Compares the root expressions produced by `lineage` and
/// `lineage_with_schema` for `SELECT name FROM users` (issue 40, per the test
/// name): after `annotate_types`, the former has no inferred type while the
/// latter resolves to `DataType::Text`.
#[test]
fn test_lineage_with_schema_qualifies_root_expression_issue_40() {
    let query = "SELECT name FROM users";
    let dialect = Dialect::get(DialectType::BigQuery);
    let mut statements = dialect.parse(query).unwrap().into_iter();
    let expr = statements.next().expect("expected one expression");

    let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
    schema
        .add_table("users", &[("name".into(), DataType::Text)], None)
        .expect("schema setup");

    // Root expression obtained without passing the schema to lineage.
    let plain_root = lineage("name", &expr, Some(DialectType::BigQuery), false)
        .expect("lineage without schema");
    let mut expr_without = plain_root.expression.clone();
    annotate_types(
        &mut expr_without,
        Some(&schema),
        Some(DialectType::BigQuery),
    );
    assert_eq!(
        expr_without.inferred_type(),
        None,
        "Expected unresolved root type without schema-aware lineage qualification"
    );

    // Root expression obtained with the schema passed to lineage.
    let schema_root = lineage_with_schema(
        "name",
        &expr,
        Some(&schema),
        Some(DialectType::BigQuery),
        false,
    )
    .expect("lineage with schema");
    let mut expr_with = schema_root.expression.clone();
    annotate_types(&mut expr_with, Some(&schema), Some(DialectType::BigQuery));

    assert_eq!(expr_with.inferred_type(), Some(&DataType::Text));
}
1851
/// `lineage_with_schema` succeeds on a query whose select list contains a
/// correlated scalar subquery, and the root for `id` is named `id`.
#[test]
fn test_lineage_with_schema_correlated_scalar_subquery() {
    let query = "SELECT id, (SELECT AVG(val) FROM t2 WHERE t2.id = t1.id) AS avg_val FROM t1";
    let dialect = Dialect::get(DialectType::BigQuery);
    let mut statements = dialect.parse(query).unwrap().into_iter();
    let expr = statements.next().expect("expected one expression");

    // Register both tables referenced by the outer query and the subquery.
    let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
    schema
        .add_table(
            "t1",
            &[("id".into(), DataType::BigInt { length: None })],
            None,
        )
        .expect("schema setup");
    schema
        .add_table(
            "t2",
            &[
                ("id".into(), DataType::BigInt { length: None }),
                ("val".into(), DataType::BigInt { length: None }),
            ],
            None,
        )
        .expect("schema setup");

    let root = lineage_with_schema(
        "id",
        &expr,
        Some(&schema),
        Some(DialectType::BigQuery),
        false,
    )
    .expect("lineage_with_schema should handle correlated scalar subqueries");

    assert_eq!(root.name, "id");
}
1893
/// `lineage_with_schema` succeeds on a `JOIN ... USING(a)` query where both
/// tables expose column `a`, and the root is named `a`.
#[test]
fn test_lineage_with_schema_join_using() {
    let query = "SELECT a FROM t1 JOIN t2 USING(a)";
    let dialect = Dialect::get(DialectType::BigQuery);
    let mut statements = dialect.parse(query).unwrap().into_iter();
    let expr = statements.next().expect("expected one expression");

    // Both joined tables carry the shared column.
    let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
    schema
        .add_table(
            "t1",
            &[("a".into(), DataType::BigInt { length: None })],
            None,
        )
        .expect("schema setup");
    schema
        .add_table(
            "t2",
            &[("a".into(), DataType::BigInt { length: None })],
            None,
        )
        .expect("schema setup");

    let root = lineage_with_schema(
        "a",
        &expr,
        Some(&schema),
        Some(DialectType::BigQuery),
        false,
    )
    .expect("lineage_with_schema should handle JOIN USING");

    assert_eq!(root.name, "a");
}
1932
/// `lineage_with_schema` succeeds when the table is registered and referenced
/// under a dotted name (`raw.t1`), and the root is named `a`.
#[test]
fn test_lineage_with_schema_qualified_table_name() {
    let query = "SELECT a FROM raw.t1";
    let dialect = Dialect::get(DialectType::BigQuery);
    let mut statements = dialect.parse(query).unwrap().into_iter();
    let expr = statements.next().expect("expected one expression");

    let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
    schema
        .add_table(
            "raw.t1",
            &[("a".into(), DataType::BigInt { length: None })],
            None,
        )
        .expect("schema setup");

    let root = lineage_with_schema(
        "a",
        &expr,
        Some(&schema),
        Some(DialectType::BigQuery),
        false,
    )
    .expect("lineage_with_schema should handle dotted schema.table names");

    assert_eq!(root.name, "a");
}
1964
/// With no schema and no dialect, `lineage_with_schema` returns the same root
/// name and downstream names as `lineage`.
#[test]
fn test_lineage_with_schema_none_matches_lineage() {
    let query = parse("SELECT a FROM t");

    let expected = lineage("a", &query, None, false).expect("lineage baseline");
    let actual =
        lineage_with_schema("a", &query, None, None, false).expect("lineage_with_schema");

    assert_eq!(actual.name, expected.name);
    assert_eq!(actual.downstream_names(), expected.downstream_names());
}
1975
/// In a two-table join, each selected column is reported downstream under the
/// table it was qualified with (`t.a` and `s.b`).
#[test]
fn test_lineage_join() {
    let query = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

    // Column `a` comes from `t`.
    let root_a = lineage("a", &query, None, false).unwrap();
    let names_a = root_a.downstream_names();
    let has_t_a = names_a.iter().any(|name| name.as_str() == "t.a");
    assert!(has_t_a, "Expected t.a, got: {:?}", names_a);

    // Column `b` comes from `s`.
    let root_b = lineage("b", &query, None, false).unwrap();
    let names_b = root_b.downstream_names();
    let has_s_b = names_b.iter().any(|name| name.as_str() == "s.b");
    assert!(has_s_b, "Expected s.b, got: {:?}", names_b);
}
1996
/// For an aliased table (`table1 t1`), the downstream node is named with the
/// alias (`t1.col1`) while its `source_name` and `source` table expression
/// both carry the underlying table name `table1`.
#[test]
fn test_lineage_alias_leaf_has_resolved_source_name() {
    let query = parse("SELECT t1.col1 FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id");
    let root = lineage("col1", &query, None, false).unwrap();

    let upstream = root.downstream_names();
    let has_alias_edge = upstream.iter().any(|name| name.as_str() == "t1.col1");
    assert!(
        has_alias_edge,
        "Expected aliased column edge t1.col1, got: {:?}",
        upstream
    );

    let leaf = root
        .downstream
        .iter()
        .find(|child| child.name.as_str() == "t1.col1")
        .expect("Expected t1.col1 leaf");
    assert_eq!(leaf.source_name, "table1");

    if let Expression::Table(table) = &leaf.source {
        assert_eq!(table.name.name, "table1");
    } else {
        panic!("Expected leaf source to be a table expression");
    }
}
2022
/// Lineage through a derived table (`(SELECT a FROM t) AS x`) contains a node
/// named `t.a` somewhere in the tree.
#[test]
fn test_lineage_derived_table() {
    let query = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
    let root = lineage("a", &query, None, false).unwrap();

    assert_eq!(root.name, "a");

    let visited: Vec<&str> = root.walk().map(|n| n.name.as_str()).collect();
    assert!(
        visited.contains(&"t.a"),
        "Expected to trace through derived table to t.a, got: {:?}",
        visited
    );
}
2037
/// Lineage through a single CTE contains a node named `t.a` somewhere in the
/// tree.
#[test]
fn test_lineage_cte() {
    let query = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
    let root = lineage("a", &query, None, false).unwrap();

    assert_eq!(root.name, "a");

    let visited: Vec<&str> = root.walk().map(|n| n.name.as_str()).collect();
    assert!(
        visited.contains(&"t.a"),
        "Expected to trace through CTE to t.a, got: {:?}",
        visited
    );
}
2051
/// A two-branch `UNION` produces a root with exactly two downstream nodes.
#[test]
fn test_lineage_union() {
    let query = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
    let root = lineage("a", &query, None, false).unwrap();

    assert_eq!(root.name, "a");

    let branch_count = root.downstream.len();
    assert_eq!(
        branch_count, 2,
        "Expected 2 branches for UNION, got {}",
        branch_count
    );
}
2066
/// Lineage through a CTE whose body is a `UNION` walks at least three nodes.
#[test]
fn test_lineage_cte_union() {
    let query = parse("WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte");
    let root = lineage("a", &query, None, false).unwrap();

    let visited: Vec<&str> = root.walk().map(|n| n.name.as_str()).collect();
    let enough_nodes = visited.len() >= 3;
    assert!(
        enough_nodes,
        "Expected at least 3 nodes for CTE with UNION, got: {:?}",
        visited
    );
}
2080
/// Requesting lineage for `*` yields a root named `*` that has downstream
/// nodes.
#[test]
fn test_lineage_star() {
    let query = parse("SELECT * FROM t");
    let root = lineage("*", &query, None, false).unwrap();

    assert_eq!(root.name, "*");

    let has_children = !root.downstream.is_empty();
    assert!(has_children, "Star should produce downstream nodes");
}
2093
/// A scalar subquery in the select list contributes nodes: walking the lineage
/// of its alias visits at least two nodes.
#[test]
fn test_lineage_subquery_in_select() {
    let query = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
    let root = lineage("x", &query, None, false).unwrap();

    assert_eq!(root.name, "x");

    let visited: Vec<&str> = root.walk().map(|n| n.name.as_str()).collect();
    let enough_nodes = visited.len() >= 2;
    assert!(
        enough_nodes,
        "Expected tracing into scalar subquery, got: {:?}",
        visited
    );
}
2108
/// Each column of a multi-column select is traced independently to its own
/// source column (`t.a` for `a`, `t.b` for `b`).
#[test]
fn test_lineage_multiple_columns() {
    let query = parse("SELECT a, b FROM t");

    let root_a = lineage("a", &query, None, false).unwrap();
    let root_b = lineage("b", &query, None, false).unwrap();

    assert_eq!(root_a.name, "a");
    assert_eq!(root_b.name, "b");

    let upstream_a = root_a.downstream_names();
    let upstream_b = root_b.downstream_names();
    assert!(upstream_a.iter().any(|name| name.as_str() == "t.a"));
    assert!(upstream_b.iter().any(|name| name.as_str() == "t.b"));
}
2125
/// `get_source_tables` on the lineage of `a` reports table `t`.
#[test]
fn test_get_source_tables() {
    let query = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
    let root = lineage("a", &query, None, false).unwrap();

    let tables = get_source_tables(&root);
    let includes_t = tables.contains("t");
    assert!(
        includes_t,
        "Expected source table 't', got: {:?}",
        tables
    );
}
2138
/// Asking for a column that the query does not select returns an error.
#[test]
fn test_lineage_column_not_found() {
    let query = parse("SELECT a FROM t");
    let outcome = lineage("nonexistent", &query, None, false);
    assert!(outcome.is_err());
}
2145
/// Lineage through two chained CTEs walks at least three nodes.
#[test]
fn test_lineage_nested_cte() {
    let query = parse(
        "WITH cte1 AS (SELECT a FROM t), \
         cte2 AS (SELECT a FROM cte1) \
         SELECT a FROM cte2",
    );
    let root = lineage("a", &query, None, false).unwrap();

    let visited: Vec<&str> = root.walk().map(|n| n.name.as_str()).collect();
    let enough_nodes = visited.len() >= 3;
    assert!(
        enough_nodes,
        "Expected to trace through nested CTEs, got: {:?}",
        visited
    );
}
2163
/// With `trim_selects = true`, the root's source `SELECT` keeps exactly one
/// select expression.
#[test]
fn test_trim_selects_true() {
    let query = parse("SELECT a, b, c FROM t");
    let root = lineage("a", &query, None, true).unwrap();

    match &root.source {
        Expression::Select(select) => assert_eq!(
            select.expressions.len(),
            1,
            "Trimmed source should have 1 expression, got {}",
            select.expressions.len()
        ),
        _ => panic!("Expected Select source"),
    }
}
2181
/// With `trim_selects = false`, the root's source `SELECT` keeps all three
/// select expressions.
#[test]
fn test_trim_selects_false() {
    let query = parse("SELECT a, b, c FROM t");
    let root = lineage("a", &query, None, false).unwrap();

    match &root.source {
        Expression::Select(select) => assert_eq!(
            select.expressions.len(),
            3,
            "Untrimmed source should have 3 expressions"
        ),
        _ => panic!("Expected Select source"),
    }
}
2198
/// A binary expression (`a + b AS c`) walks at least three nodes: the root
/// plus one per operand column.
#[test]
fn test_lineage_expression_in_select() {
    let query = parse("SELECT a + b AS c FROM t");
    let root = lineage("c", &query, None, false).unwrap();

    let visited: Vec<&str> = root.walk().map(|n| n.name.as_str()).collect();
    let enough_nodes = visited.len() >= 3;
    assert!(
        enough_nodes,
        "Expected to trace a + b to both columns, got: {:?}",
        visited
    );
}
2212
/// A `UNION` whose branches use different column names (`a` and `b`) still
/// yields two downstream nodes when queried by the first branch's name.
#[test]
fn test_set_operation_by_index() {
    let query = parse("SELECT a FROM t1 UNION SELECT b FROM t2");
    let root = lineage("a", &query, None, false).unwrap();

    let branch_count = root.downstream.len();
    assert_eq!(branch_count, 2);
}
2223
2224 fn print_node(node: &LineageNode, indent: usize) {
2227 let pad = " ".repeat(indent);
2228 println!(
2229 "{pad}name={:?} source_name={:?}",
2230 node.name, node.source_name
2231 );
2232 for child in &node.downstream {
2233 print_node(child, indent + 1);
2234 }
2235 }
2236
/// Reproduction for issue 18 (per the test name): with the BigQuery dialect,
/// `UPPER(name) as upper_name` has `users.name` among its downstream nodes.
/// The query and the resulting tree are printed for debugging.
#[test]
fn test_issue18_repro() {
    let query = "SELECT UPPER(name) as upper_name FROM users";
    println!("Query: {query}\n");

    let bigquery = Dialect::get(DialectType::BigQuery);
    let exprs = bigquery.parse(query).unwrap();

    let root = lineage(
        "upper_name",
        &exprs[0],
        Some(DialectType::BigQuery),
        false,
    )
    .unwrap();
    println!("lineage(\"upper_name\"):");
    print_node(&root, 1);

    let upstream = root.downstream_names();
    let found = upstream.iter().any(|name| name.as_str() == "users.name");
    assert!(
        found,
        "Expected users.name in downstream, got: {:?}",
        upstream
    );
}
2258
/// `UPPER(name)` has `users.name` among its downstream nodes.
#[test]
fn test_lineage_upper_function() {
    let query = parse("SELECT UPPER(name) AS upper_name FROM users");
    let upstream = lineage("upper_name", &query, None, false)
        .unwrap()
        .downstream_names();

    let found = upstream.iter().any(|name| name.as_str() == "users.name");
    assert!(
        found,
        "Expected users.name in downstream, got: {:?}",
        upstream
    );
}
2271
/// `ROUND(price, 2)` has `products.price` among its downstream nodes.
#[test]
fn test_lineage_round_function() {
    let query = parse("SELECT ROUND(price, 2) AS rounded FROM products");
    let upstream = lineage("rounded", &query, None, false)
        .unwrap()
        .downstream_names();

    let found = upstream
        .iter()
        .any(|name| name.as_str() == "products.price");
    assert!(
        found,
        "Expected products.price in downstream, got: {:?}",
        upstream
    );
}
2284
/// `COALESCE(a, b)` has both `t.a` and `t.b` among its downstream nodes.
#[test]
fn test_lineage_coalesce_function() {
    let query = parse("SELECT COALESCE(a, b) AS val FROM t");
    let upstream = lineage("val", &query, None, false)
        .unwrap()
        .downstream_names();

    let has_a = upstream.iter().any(|name| name.as_str() == "t.a");
    assert!(has_a, "Expected t.a in downstream, got: {:?}", upstream);

    let has_b = upstream.iter().any(|name| name.as_str() == "t.b");
    assert!(has_b, "Expected t.b in downstream, got: {:?}", upstream);
}
2302
/// `COUNT(id)` has `t.id` among its downstream nodes.
#[test]
fn test_lineage_count_function() {
    let query = parse("SELECT COUNT(id) AS cnt FROM t");
    let upstream = lineage("cnt", &query, None, false)
        .unwrap()
        .downstream_names();

    let found = upstream.iter().any(|name| name.as_str() == "t.id");
    assert!(found, "Expected t.id in downstream, got: {:?}", upstream);
}
2315
/// `SUM(amount)` has `t.amount` among its downstream nodes.
#[test]
fn test_lineage_sum_function() {
    let query = parse("SELECT SUM(amount) AS total FROM t");
    let upstream = lineage("total", &query, None, false)
        .unwrap()
        .downstream_names();

    let found = upstream.iter().any(|name| name.as_str() == "t.amount");
    assert!(
        found,
        "Expected t.amount in downstream, got: {:?}",
        upstream
    );
}
2328
/// A `CASE` expression with function calls in its branches has both the
/// condition column (`t.x`) and the branch column (`t.name`) among its
/// downstream nodes.
#[test]
fn test_lineage_case_with_nested_functions() {
    let query =
        parse("SELECT CASE WHEN x > 0 THEN UPPER(name) ELSE LOWER(name) END AS result FROM t");
    let upstream = lineage("result", &query, None, false)
        .unwrap()
        .downstream_names();

    let has_x = upstream.iter().any(|name| name.as_str() == "t.x");
    assert!(has_x, "Expected t.x in downstream, got: {:?}", upstream);

    let has_name = upstream.iter().any(|name| name.as_str() == "t.name");
    assert!(
        has_name,
        "Expected t.name in downstream, got: {:?}",
        upstream
    );
}
2347
/// `SUBSTRING(name, 1, 3)` has `t.name` among its downstream nodes.
#[test]
fn test_lineage_substring_function() {
    let query = parse("SELECT SUBSTRING(name, 1, 3) AS short FROM t");
    let upstream = lineage("short", &query, None, false)
        .unwrap()
        .downstream_names();

    let found = upstream.iter().any(|name| name.as_str() == "t.name");
    assert!(
        found,
        "Expected t.name in downstream, got: {:?}",
        upstream
    );
}
2360}