1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::optimizer::qualify_columns::{qualify_columns, QualifyColumnsOptions};
11use crate::schema::Schema;
12use crate::scope::{build_scope, Scope};
13use crate::traversal::ExpressionWalk;
14use crate::{Error, Result};
15use serde::{Deserialize, Serialize};
16use std::collections::HashSet;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct LineageNode {
21 pub name: String,
23 pub expression: Expression,
25 pub source: Expression,
27 pub downstream: Vec<LineageNode>,
29 pub source_name: String,
31 pub reference_node_name: String,
33}
34
35impl LineageNode {
36 pub fn new(name: impl Into<String>, expression: Expression, source: Expression) -> Self {
38 Self {
39 name: name.into(),
40 expression,
41 source,
42 downstream: Vec::new(),
43 source_name: String::new(),
44 reference_node_name: String::new(),
45 }
46 }
47
48 pub fn walk(&self) -> LineageWalker<'_> {
50 LineageWalker { stack: vec![self] }
51 }
52
53 pub fn downstream_names(&self) -> Vec<String> {
55 self.downstream.iter().map(|n| n.name.clone()).collect()
56 }
57}
58
59pub struct LineageWalker<'a> {
61 stack: Vec<&'a LineageNode>,
62}
63
64impl<'a> Iterator for LineageWalker<'a> {
65 type Item = &'a LineageNode;
66
67 fn next(&mut self) -> Option<Self::Item> {
68 if let Some(node) = self.stack.pop() {
69 for child in node.downstream.iter().rev() {
71 self.stack.push(child);
72 }
73 Some(node)
74 } else {
75 None
76 }
77 }
78}
79
/// How a lookup identifies its target column in a projection list.
enum ColumnRef<'a> {
    /// Match by output name (alias or column name).
    Name(&'a str),
    /// Match by zero-based position in the projection list.
    Index(usize),
}
89
90pub fn lineage(
116 column: &str,
117 sql: &Expression,
118 dialect: Option<DialectType>,
119 trim_selects: bool,
120) -> Result<LineageNode> {
121 lineage_from_expression(column, sql, dialect, trim_selects)
122}
123
124pub fn lineage_with_schema(
140 column: &str,
141 sql: &Expression,
142 schema: Option<&dyn Schema>,
143 dialect: Option<DialectType>,
144 trim_selects: bool,
145) -> Result<LineageNode> {
146 let qualified_expression = if let Some(schema) = schema {
147 let options = if let Some(dialect_type) = dialect.or_else(|| schema.dialect()) {
148 QualifyColumnsOptions::new().with_dialect(dialect_type)
149 } else {
150 QualifyColumnsOptions::new()
151 };
152
153 qualify_columns(sql.clone(), schema, &options).map_err(|e| {
154 Error::internal(format!("Lineage qualification failed with schema: {}", e))
155 })?
156 } else {
157 sql.clone()
158 };
159
160 lineage_from_expression(column, &qualified_expression, dialect, trim_selects)
161}
162
163fn lineage_from_expression(
164 column: &str,
165 sql: &Expression,
166 dialect: Option<DialectType>,
167 trim_selects: bool,
168) -> Result<LineageNode> {
169 let scope = build_scope(sql);
170 to_node(
171 ColumnRef::Name(column),
172 &scope,
173 dialect,
174 "",
175 "",
176 "",
177 trim_selects,
178 )
179}
180
181pub fn get_source_tables(node: &LineageNode) -> HashSet<String> {
183 let mut tables = HashSet::new();
184 collect_source_tables(node, &mut tables);
185 tables
186}
187
188pub fn collect_source_tables(node: &LineageNode, tables: &mut HashSet<String>) {
190 if let Expression::Table(table) = &node.source {
191 tables.insert(table.name.name.clone());
192 }
193 for child in &node.downstream {
194 collect_source_tables(child, tables);
195 }
196}
197
198fn to_node(
204 column: ColumnRef<'_>,
205 scope: &Scope,
206 dialect: Option<DialectType>,
207 scope_name: &str,
208 source_name: &str,
209 reference_node_name: &str,
210 trim_selects: bool,
211) -> Result<LineageNode> {
212 to_node_inner(
213 column,
214 scope,
215 dialect,
216 scope_name,
217 source_name,
218 reference_node_name,
219 trim_selects,
220 &[],
221 )
222}
223
224fn to_node_inner(
225 column: ColumnRef<'_>,
226 scope: &Scope,
227 dialect: Option<DialectType>,
228 scope_name: &str,
229 source_name: &str,
230 reference_node_name: &str,
231 trim_selects: bool,
232 ancestor_cte_scopes: &[Scope],
233) -> Result<LineageNode> {
234 let scope_expr = &scope.expression;
235
236 let mut all_cte_scopes: Vec<&Scope> = scope.cte_scopes.iter().collect();
238 for s in ancestor_cte_scopes {
239 all_cte_scopes.push(s);
240 }
241
242 let effective_expr = match scope_expr {
245 Expression::Cte(cte) => &cte.this,
246 other => other,
247 };
248
249 if matches!(
251 effective_expr,
252 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
253 ) {
254 if matches!(scope_expr, Expression::Cte(_)) {
256 let mut inner_scope = Scope::new(effective_expr.clone());
257 inner_scope.union_scopes = scope.union_scopes.clone();
258 inner_scope.sources = scope.sources.clone();
259 inner_scope.cte_sources = scope.cte_sources.clone();
260 inner_scope.cte_scopes = scope.cte_scopes.clone();
261 inner_scope.derived_table_scopes = scope.derived_table_scopes.clone();
262 inner_scope.subquery_scopes = scope.subquery_scopes.clone();
263 return handle_set_operation(
264 &column,
265 &inner_scope,
266 dialect,
267 scope_name,
268 source_name,
269 reference_node_name,
270 trim_selects,
271 ancestor_cte_scopes,
272 );
273 }
274 return handle_set_operation(
275 &column,
276 scope,
277 dialect,
278 scope_name,
279 source_name,
280 reference_node_name,
281 trim_selects,
282 ancestor_cte_scopes,
283 );
284 }
285
286 let select_expr = find_select_expr(effective_expr, &column)?;
288 let column_name = resolve_column_name(&column, &select_expr);
289
290 let node_source = if trim_selects {
292 trim_source(effective_expr, &select_expr)
293 } else {
294 effective_expr.clone()
295 };
296
297 let mut node = LineageNode::new(&column_name, select_expr.clone(), node_source);
299 node.source_name = source_name.to_string();
300 node.reference_node_name = reference_node_name.to_string();
301
302 if matches!(&select_expr, Expression::Star(_)) {
304 for (name, source_info) in &scope.sources {
305 let child = LineageNode::new(
306 format!("{}.*", name),
307 Expression::Star(crate::expressions::Star {
308 table: None,
309 except: None,
310 replace: None,
311 rename: None,
312 trailing_comments: vec![],
313 span: None,
314 }),
315 source_info.expression.clone(),
316 );
317 node.downstream.push(child);
318 }
319 return Ok(node);
320 }
321
322 let subqueries: Vec<&Expression> =
324 select_expr.find_all(|e| matches!(e, Expression::Subquery(sq) if sq.alias.is_none()));
325 for sq_expr in subqueries {
326 if let Expression::Subquery(sq) = sq_expr {
327 for sq_scope in &scope.subquery_scopes {
328 if sq_scope.expression == sq.this {
329 if let Ok(child) = to_node_inner(
330 ColumnRef::Index(0),
331 sq_scope,
332 dialect,
333 &column_name,
334 "",
335 "",
336 trim_selects,
337 ancestor_cte_scopes,
338 ) {
339 node.downstream.push(child);
340 }
341 break;
342 }
343 }
344 }
345 }
346
347 let col_refs = find_column_refs_in_expr(&select_expr);
349 for col_ref in col_refs {
350 let col_name = &col_ref.column;
351 if let Some(ref table_id) = col_ref.table {
352 let tbl = &table_id.name;
353 resolve_qualified_column(
354 &mut node,
355 scope,
356 dialect,
357 tbl,
358 col_name,
359 &column_name,
360 trim_selects,
361 &all_cte_scopes,
362 );
363 } else {
364 resolve_unqualified_column(
365 &mut node,
366 scope,
367 dialect,
368 col_name,
369 &column_name,
370 trim_selects,
371 &all_cte_scopes,
372 );
373 }
374 }
375
376 Ok(node)
377}
378
379fn handle_set_operation(
384 column: &ColumnRef<'_>,
385 scope: &Scope,
386 dialect: Option<DialectType>,
387 scope_name: &str,
388 source_name: &str,
389 reference_node_name: &str,
390 trim_selects: bool,
391 ancestor_cte_scopes: &[Scope],
392) -> Result<LineageNode> {
393 let scope_expr = &scope.expression;
394
395 let col_index = match column {
397 ColumnRef::Name(name) => column_to_index(scope_expr, name)?,
398 ColumnRef::Index(i) => *i,
399 };
400
401 let col_name = match column {
402 ColumnRef::Name(name) => name.to_string(),
403 ColumnRef::Index(_) => format!("_{col_index}"),
404 };
405
406 let mut node = LineageNode::new(&col_name, scope_expr.clone(), scope_expr.clone());
407 node.source_name = source_name.to_string();
408 node.reference_node_name = reference_node_name.to_string();
409
410 for branch_scope in &scope.union_scopes {
412 if let Ok(child) = to_node_inner(
413 ColumnRef::Index(col_index),
414 branch_scope,
415 dialect,
416 scope_name,
417 "",
418 "",
419 trim_selects,
420 ancestor_cte_scopes,
421 ) {
422 node.downstream.push(child);
423 }
424 }
425
426 Ok(node)
427}
428
429fn resolve_qualified_column(
434 node: &mut LineageNode,
435 scope: &Scope,
436 dialect: Option<DialectType>,
437 table: &str,
438 col_name: &str,
439 parent_name: &str,
440 trim_selects: bool,
441 all_cte_scopes: &[&Scope],
442) {
443 if scope.cte_sources.contains_key(table) {
445 if let Some(child_scope) = find_child_scope_in(all_cte_scopes, scope, table) {
446 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
448 if let Ok(child) = to_node_inner(
449 ColumnRef::Name(col_name),
450 child_scope,
451 dialect,
452 parent_name,
453 table,
454 parent_name,
455 trim_selects,
456 &ancestors,
457 ) {
458 node.downstream.push(child);
459 return;
460 }
461 }
462 }
463
464 if let Some(source_info) = scope.sources.get(table) {
466 if source_info.is_scope {
467 if let Some(child_scope) = find_child_scope(scope, table) {
468 let ancestors: Vec<Scope> = all_cte_scopes.iter().map(|s| (*s).clone()).collect();
469 if let Ok(child) = to_node_inner(
470 ColumnRef::Name(col_name),
471 child_scope,
472 dialect,
473 parent_name,
474 table,
475 parent_name,
476 trim_selects,
477 &ancestors,
478 ) {
479 node.downstream.push(child);
480 return;
481 }
482 }
483 }
484 }
485
486 if let Some(source_info) = scope.sources.get(table) {
489 if !source_info.is_scope {
490 node.downstream.push(make_table_column_node_from_source(
491 table,
492 col_name,
493 &source_info.expression,
494 ));
495 return;
496 }
497 }
498
499 node.downstream
501 .push(make_table_column_node(table, col_name));
502}
503
504fn resolve_unqualified_column(
505 node: &mut LineageNode,
506 scope: &Scope,
507 dialect: Option<DialectType>,
508 col_name: &str,
509 parent_name: &str,
510 trim_selects: bool,
511 all_cte_scopes: &[&Scope],
512) {
513 let from_source_names = source_names_from_from_join(scope);
517
518 if from_source_names.len() == 1 {
519 let tbl = &from_source_names[0];
520 resolve_qualified_column(
521 node,
522 scope,
523 dialect,
524 tbl,
525 col_name,
526 parent_name,
527 trim_selects,
528 all_cte_scopes,
529 );
530 return;
531 }
532
533 let child = LineageNode::new(
535 col_name.to_string(),
536 Expression::Column(crate::expressions::Column {
537 name: crate::expressions::Identifier::new(col_name.to_string()),
538 table: None,
539 join_mark: false,
540 trailing_comments: vec![],
541 span: None,
542 }),
543 node.source.clone(),
544 );
545 node.downstream.push(child);
546}
547
548fn source_names_from_from_join(scope: &Scope) -> Vec<String> {
549 fn source_name(expr: &Expression) -> Option<String> {
550 match expr {
551 Expression::Table(table) => Some(
552 table
553 .alias
554 .as_ref()
555 .map(|a| a.name.clone())
556 .unwrap_or_else(|| table.name.name.clone()),
557 ),
558 Expression::Subquery(subquery) => {
559 subquery.alias.as_ref().map(|alias| alias.name.clone())
560 }
561 Expression::Paren(paren) => source_name(&paren.this),
562 _ => None,
563 }
564 }
565
566 let effective_expr = match &scope.expression {
567 Expression::Cte(cte) => &cte.this,
568 expr => expr,
569 };
570
571 let mut names = Vec::new();
572 let mut seen = std::collections::HashSet::new();
573
574 if let Expression::Select(select) = effective_expr {
575 if let Some(from) = &select.from {
576 for expr in &from.expressions {
577 if let Some(name) = source_name(expr) {
578 if !name.is_empty() && seen.insert(name.clone()) {
579 names.push(name);
580 }
581 }
582 }
583 }
584 for join in &select.joins {
585 if let Some(name) = source_name(&join.this) {
586 if !name.is_empty() && seen.insert(name.clone()) {
587 names.push(name);
588 }
589 }
590 }
591 }
592
593 names
594}
595
596fn get_alias_or_name(expr: &Expression) -> Option<String> {
602 match expr {
603 Expression::Alias(alias) => Some(alias.alias.name.clone()),
604 Expression::Column(col) => Some(col.name.name.clone()),
605 Expression::Identifier(id) => Some(id.name.clone()),
606 Expression::Star(_) => Some("*".to_string()),
607 _ => None,
608 }
609}
610
611fn resolve_column_name(column: &ColumnRef<'_>, select_expr: &Expression) -> String {
613 match column {
614 ColumnRef::Name(n) => n.to_string(),
615 ColumnRef::Index(_) => get_alias_or_name(select_expr).unwrap_or_else(|| "?".to_string()),
616 }
617}
618
619fn find_select_expr(scope_expr: &Expression, column: &ColumnRef<'_>) -> Result<Expression> {
621 if let Expression::Select(ref select) = scope_expr {
622 match column {
623 ColumnRef::Name(name) => {
624 for expr in &select.expressions {
625 if get_alias_or_name(expr).as_deref() == Some(name) {
626 return Ok(expr.clone());
627 }
628 }
629 Err(crate::error::Error::parse(
630 format!("Cannot find column '{}' in query", name),
631 0,
632 0,
633 0,
634 0,
635 ))
636 }
637 ColumnRef::Index(idx) => select.expressions.get(*idx).cloned().ok_or_else(|| {
638 crate::error::Error::parse(format!("Column index {} out of range", idx), 0, 0, 0, 0)
639 }),
640 }
641 } else {
642 Err(crate::error::Error::parse(
643 "Expected SELECT expression for column lookup",
644 0,
645 0,
646 0,
647 0,
648 ))
649 }
650}
651
652fn column_to_index(set_op_expr: &Expression, name: &str) -> Result<usize> {
654 let mut expr = set_op_expr;
655 loop {
656 match expr {
657 Expression::Union(u) => expr = &u.left,
658 Expression::Intersect(i) => expr = &i.left,
659 Expression::Except(e) => expr = &e.left,
660 Expression::Select(select) => {
661 for (i, e) in select.expressions.iter().enumerate() {
662 if get_alias_or_name(e).as_deref() == Some(name) {
663 return Ok(i);
664 }
665 }
666 return Err(crate::error::Error::parse(
667 format!("Cannot find column '{}' in set operation", name),
668 0,
669 0,
670 0,
671 0,
672 ));
673 }
674 _ => {
675 return Err(crate::error::Error::parse(
676 "Expected SELECT or set operation",
677 0,
678 0,
679 0,
680 0,
681 ))
682 }
683 }
684 }
685}
686
687fn trim_source(select_expr: &Expression, target_expr: &Expression) -> Expression {
689 if let Expression::Select(select) = select_expr {
690 let mut trimmed = select.as_ref().clone();
691 trimmed.expressions = vec![target_expr.clone()];
692 Expression::Select(Box::new(trimmed))
693 } else {
694 select_expr.clone()
695 }
696}
697
698fn find_child_scope<'a>(scope: &'a Scope, source_name: &str) -> Option<&'a Scope> {
700 if scope.cte_sources.contains_key(source_name) {
702 for cte_scope in &scope.cte_scopes {
703 if let Expression::Cte(cte) = &cte_scope.expression {
704 if cte.alias.name == source_name {
705 return Some(cte_scope);
706 }
707 }
708 }
709 }
710
711 if let Some(source_info) = scope.sources.get(source_name) {
713 if source_info.is_scope && !scope.cte_sources.contains_key(source_name) {
714 if let Expression::Subquery(sq) = &source_info.expression {
715 for dt_scope in &scope.derived_table_scopes {
716 if dt_scope.expression == sq.this {
717 return Some(dt_scope);
718 }
719 }
720 }
721 }
722 }
723
724 None
725}
726
727fn find_child_scope_in<'a>(
731 all_cte_scopes: &[&'a Scope],
732 scope: &'a Scope,
733 source_name: &str,
734) -> Option<&'a Scope> {
735 for cte_scope in &scope.cte_scopes {
737 if let Expression::Cte(cte) = &cte_scope.expression {
738 if cte.alias.name == source_name {
739 return Some(cte_scope);
740 }
741 }
742 }
743
744 for cte_scope in all_cte_scopes {
746 if let Expression::Cte(cte) = &cte_scope.expression {
747 if cte.alias.name == source_name {
748 return Some(cte_scope);
749 }
750 }
751 }
752
753 if let Some(source_info) = scope.sources.get(source_name) {
755 if source_info.is_scope {
756 if let Expression::Subquery(sq) = &source_info.expression {
757 for dt_scope in &scope.derived_table_scopes {
758 if dt_scope.expression == sq.this {
759 return Some(dt_scope);
760 }
761 }
762 }
763 }
764 }
765
766 None
767}
768
769fn make_table_column_node(table: &str, column: &str) -> LineageNode {
771 let mut node = LineageNode::new(
772 format!("{}.{}", table, column),
773 Expression::Column(crate::expressions::Column {
774 name: crate::expressions::Identifier::new(column.to_string()),
775 table: Some(crate::expressions::Identifier::new(table.to_string())),
776 join_mark: false,
777 trailing_comments: vec![],
778 span: None,
779 }),
780 Expression::Table(crate::expressions::TableRef::new(table)),
781 );
782 node.source_name = table.to_string();
783 node
784}
785
786fn table_name_from_table_ref(table_ref: &crate::expressions::TableRef) -> String {
787 let mut parts: Vec<String> = Vec::new();
788 if let Some(catalog) = &table_ref.catalog {
789 parts.push(catalog.name.clone());
790 }
791 if let Some(schema) = &table_ref.schema {
792 parts.push(schema.name.clone());
793 }
794 parts.push(table_ref.name.name.clone());
795 parts.join(".")
796}
797
798fn make_table_column_node_from_source(
799 table_alias: &str,
800 column: &str,
801 source: &Expression,
802) -> LineageNode {
803 let mut node = LineageNode::new(
804 format!("{}.{}", table_alias, column),
805 Expression::Column(crate::expressions::Column {
806 name: crate::expressions::Identifier::new(column.to_string()),
807 table: Some(crate::expressions::Identifier::new(table_alias.to_string())),
808 join_mark: false,
809 trailing_comments: vec![],
810 span: None,
811 }),
812 source.clone(),
813 );
814
815 if let Expression::Table(table_ref) = source {
816 node.source_name = table_name_from_table_ref(table_ref);
817 } else {
818 node.source_name = table_alias.to_string();
819 }
820
821 node
822}
823
824#[derive(Debug, Clone)]
826struct SimpleColumnRef {
827 table: Option<crate::expressions::Identifier>,
828 column: String,
829}
830
831fn find_column_refs_in_expr(expr: &Expression) -> Vec<SimpleColumnRef> {
833 let mut refs = Vec::new();
834 collect_column_refs(expr, &mut refs);
835 refs
836}
837
838fn collect_column_refs(expr: &Expression, refs: &mut Vec<SimpleColumnRef>) {
839 let mut stack: Vec<&Expression> = vec![expr];
840
841 while let Some(current) = stack.pop() {
842 match current {
843 Expression::Column(col) => {
845 refs.push(SimpleColumnRef {
846 table: col.table.clone(),
847 column: col.name.name.clone(),
848 });
849 }
850
851 Expression::Subquery(_) | Expression::Exists(_) => {}
853
854 Expression::And(op)
856 | Expression::Or(op)
857 | Expression::Eq(op)
858 | Expression::Neq(op)
859 | Expression::Lt(op)
860 | Expression::Lte(op)
861 | Expression::Gt(op)
862 | Expression::Gte(op)
863 | Expression::Add(op)
864 | Expression::Sub(op)
865 | Expression::Mul(op)
866 | Expression::Div(op)
867 | Expression::Mod(op)
868 | Expression::BitwiseAnd(op)
869 | Expression::BitwiseOr(op)
870 | Expression::BitwiseXor(op)
871 | Expression::BitwiseLeftShift(op)
872 | Expression::BitwiseRightShift(op)
873 | Expression::Concat(op)
874 | Expression::Adjacent(op)
875 | Expression::TsMatch(op)
876 | Expression::PropertyEQ(op)
877 | Expression::ArrayContainsAll(op)
878 | Expression::ArrayContainedBy(op)
879 | Expression::ArrayOverlaps(op)
880 | Expression::JSONBContainsAllTopKeys(op)
881 | Expression::JSONBContainsAnyTopKeys(op)
882 | Expression::JSONBDeleteAtPath(op)
883 | Expression::ExtendsLeft(op)
884 | Expression::ExtendsRight(op)
885 | Expression::Is(op)
886 | Expression::MemberOf(op)
887 | Expression::NullSafeEq(op)
888 | Expression::NullSafeNeq(op)
889 | Expression::Glob(op)
890 | Expression::Match(op) => {
891 stack.push(&op.left);
892 stack.push(&op.right);
893 }
894
895 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
897 stack.push(&u.this);
898 }
899
900 Expression::Upper(f)
902 | Expression::Lower(f)
903 | Expression::Length(f)
904 | Expression::LTrim(f)
905 | Expression::RTrim(f)
906 | Expression::Reverse(f)
907 | Expression::Abs(f)
908 | Expression::Sqrt(f)
909 | Expression::Cbrt(f)
910 | Expression::Ln(f)
911 | Expression::Exp(f)
912 | Expression::Sign(f)
913 | Expression::Date(f)
914 | Expression::Time(f)
915 | Expression::DateFromUnixDate(f)
916 | Expression::UnixDate(f)
917 | Expression::UnixSeconds(f)
918 | Expression::UnixMillis(f)
919 | Expression::UnixMicros(f)
920 | Expression::TimeStrToDate(f)
921 | Expression::DateToDi(f)
922 | Expression::DiToDate(f)
923 | Expression::TsOrDiToDi(f)
924 | Expression::TsOrDsToDatetime(f)
925 | Expression::TsOrDsToTimestamp(f)
926 | Expression::YearOfWeek(f)
927 | Expression::YearOfWeekIso(f)
928 | Expression::Initcap(f)
929 | Expression::Ascii(f)
930 | Expression::Chr(f)
931 | Expression::Soundex(f)
932 | Expression::ByteLength(f)
933 | Expression::Hex(f)
934 | Expression::LowerHex(f)
935 | Expression::Unicode(f)
936 | Expression::Radians(f)
937 | Expression::Degrees(f)
938 | Expression::Sin(f)
939 | Expression::Cos(f)
940 | Expression::Tan(f)
941 | Expression::Asin(f)
942 | Expression::Acos(f)
943 | Expression::Atan(f)
944 | Expression::IsNan(f)
945 | Expression::IsInf(f)
946 | Expression::ArrayLength(f)
947 | Expression::ArraySize(f)
948 | Expression::Cardinality(f)
949 | Expression::ArrayReverse(f)
950 | Expression::ArrayDistinct(f)
951 | Expression::ArrayFlatten(f)
952 | Expression::ArrayCompact(f)
953 | Expression::Explode(f)
954 | Expression::ExplodeOuter(f)
955 | Expression::ToArray(f)
956 | Expression::MapFromEntries(f)
957 | Expression::MapKeys(f)
958 | Expression::MapValues(f)
959 | Expression::JsonArrayLength(f)
960 | Expression::JsonKeys(f)
961 | Expression::JsonType(f)
962 | Expression::ParseJson(f)
963 | Expression::ToJson(f)
964 | Expression::Typeof(f)
965 | Expression::BitwiseCount(f)
966 | Expression::Year(f)
967 | Expression::Month(f)
968 | Expression::Day(f)
969 | Expression::Hour(f)
970 | Expression::Minute(f)
971 | Expression::Second(f)
972 | Expression::DayOfWeek(f)
973 | Expression::DayOfWeekIso(f)
974 | Expression::DayOfMonth(f)
975 | Expression::DayOfYear(f)
976 | Expression::WeekOfYear(f)
977 | Expression::Quarter(f)
978 | Expression::Epoch(f)
979 | Expression::EpochMs(f)
980 | Expression::TimeStrToUnix(f)
981 | Expression::SHA(f)
982 | Expression::SHA1Digest(f)
983 | Expression::TimeToUnix(f)
984 | Expression::JSONBool(f)
985 | Expression::Int64(f)
986 | Expression::MD5NumberLower64(f)
987 | Expression::MD5NumberUpper64(f)
988 | Expression::DateStrToDate(f)
989 | Expression::DateToDateStr(f) => {
990 stack.push(&f.this);
991 }
992
993 Expression::Power(f)
995 | Expression::NullIf(f)
996 | Expression::IfNull(f)
997 | Expression::Nvl(f)
998 | Expression::UnixToTimeStr(f)
999 | Expression::Contains(f)
1000 | Expression::StartsWith(f)
1001 | Expression::EndsWith(f)
1002 | Expression::Levenshtein(f)
1003 | Expression::ModFunc(f)
1004 | Expression::Atan2(f)
1005 | Expression::IntDiv(f)
1006 | Expression::AddMonths(f)
1007 | Expression::MonthsBetween(f)
1008 | Expression::NextDay(f)
1009 | Expression::ArrayContains(f)
1010 | Expression::ArrayPosition(f)
1011 | Expression::ArrayAppend(f)
1012 | Expression::ArrayPrepend(f)
1013 | Expression::ArrayUnion(f)
1014 | Expression::ArrayExcept(f)
1015 | Expression::ArrayRemove(f)
1016 | Expression::StarMap(f)
1017 | Expression::MapFromArrays(f)
1018 | Expression::MapContainsKey(f)
1019 | Expression::ElementAt(f)
1020 | Expression::JsonMergePatch(f)
1021 | Expression::JSONBContains(f)
1022 | Expression::JSONBExtract(f) => {
1023 stack.push(&f.this);
1024 stack.push(&f.expression);
1025 }
1026
1027 Expression::Greatest(f)
1029 | Expression::Least(f)
1030 | Expression::Coalesce(f)
1031 | Expression::ArrayConcat(f)
1032 | Expression::ArrayIntersect(f)
1033 | Expression::ArrayZip(f)
1034 | Expression::MapConcat(f)
1035 | Expression::JsonArray(f) => {
1036 for e in &f.expressions {
1037 stack.push(e);
1038 }
1039 }
1040
1041 Expression::Sum(f)
1043 | Expression::Avg(f)
1044 | Expression::Min(f)
1045 | Expression::Max(f)
1046 | Expression::ArrayAgg(f)
1047 | Expression::CountIf(f)
1048 | Expression::Stddev(f)
1049 | Expression::StddevPop(f)
1050 | Expression::StddevSamp(f)
1051 | Expression::Variance(f)
1052 | Expression::VarPop(f)
1053 | Expression::VarSamp(f)
1054 | Expression::Median(f)
1055 | Expression::Mode(f)
1056 | Expression::First(f)
1057 | Expression::Last(f)
1058 | Expression::AnyValue(f)
1059 | Expression::ApproxDistinct(f)
1060 | Expression::ApproxCountDistinct(f)
1061 | Expression::LogicalAnd(f)
1062 | Expression::LogicalOr(f)
1063 | Expression::Skewness(f)
1064 | Expression::ArrayConcatAgg(f)
1065 | Expression::ArrayUniqueAgg(f)
1066 | Expression::BoolXorAgg(f)
1067 | Expression::BitwiseAndAgg(f)
1068 | Expression::BitwiseOrAgg(f)
1069 | Expression::BitwiseXorAgg(f) => {
1070 stack.push(&f.this);
1071 if let Some(ref filter) = f.filter {
1072 stack.push(filter);
1073 }
1074 if let Some((ref expr, _)) = f.having_max {
1075 stack.push(expr);
1076 }
1077 if let Some(ref limit) = f.limit {
1078 stack.push(limit);
1079 }
1080 }
1081
1082 Expression::Function(func) => {
1084 for arg in &func.args {
1085 stack.push(arg);
1086 }
1087 }
1088 Expression::AggregateFunction(func) => {
1089 for arg in &func.args {
1090 stack.push(arg);
1091 }
1092 if let Some(ref filter) = func.filter {
1093 stack.push(filter);
1094 }
1095 if let Some(ref limit) = func.limit {
1096 stack.push(limit);
1097 }
1098 }
1099
1100 Expression::WindowFunction(wf) => {
1102 stack.push(&wf.this);
1103 }
1104
1105 Expression::Alias(a) => {
1107 stack.push(&a.this);
1108 }
1109 Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
1110 stack.push(&c.this);
1111 if let Some(ref fmt) = c.format {
1112 stack.push(fmt);
1113 }
1114 if let Some(ref def) = c.default {
1115 stack.push(def);
1116 }
1117 }
1118 Expression::Paren(p) => {
1119 stack.push(&p.this);
1120 }
1121 Expression::Annotated(a) => {
1122 stack.push(&a.this);
1123 }
1124 Expression::Case(case) => {
1125 if let Some(ref operand) = case.operand {
1126 stack.push(operand);
1127 }
1128 for (cond, result) in &case.whens {
1129 stack.push(cond);
1130 stack.push(result);
1131 }
1132 if let Some(ref else_expr) = case.else_ {
1133 stack.push(else_expr);
1134 }
1135 }
1136 Expression::Collation(c) => {
1137 stack.push(&c.this);
1138 }
1139 Expression::In(i) => {
1140 stack.push(&i.this);
1141 for e in &i.expressions {
1142 stack.push(e);
1143 }
1144 if let Some(ref q) = i.query {
1145 stack.push(q);
1146 }
1147 if let Some(ref u) = i.unnest {
1148 stack.push(u);
1149 }
1150 }
1151 Expression::Between(b) => {
1152 stack.push(&b.this);
1153 stack.push(&b.low);
1154 stack.push(&b.high);
1155 }
1156 Expression::IsNull(n) => {
1157 stack.push(&n.this);
1158 }
1159 Expression::IsTrue(t) | Expression::IsFalse(t) => {
1160 stack.push(&t.this);
1161 }
1162 Expression::IsJson(j) => {
1163 stack.push(&j.this);
1164 }
1165 Expression::Like(l) | Expression::ILike(l) => {
1166 stack.push(&l.left);
1167 stack.push(&l.right);
1168 if let Some(ref esc) = l.escape {
1169 stack.push(esc);
1170 }
1171 }
1172 Expression::SimilarTo(s) => {
1173 stack.push(&s.this);
1174 stack.push(&s.pattern);
1175 if let Some(ref esc) = s.escape {
1176 stack.push(esc);
1177 }
1178 }
1179 Expression::Ordered(o) => {
1180 stack.push(&o.this);
1181 }
1182 Expression::Array(a) => {
1183 for e in &a.expressions {
1184 stack.push(e);
1185 }
1186 }
1187 Expression::Tuple(t) => {
1188 for e in &t.expressions {
1189 stack.push(e);
1190 }
1191 }
1192 Expression::Struct(s) => {
1193 for (_, e) in &s.fields {
1194 stack.push(e);
1195 }
1196 }
1197 Expression::Subscript(s) => {
1198 stack.push(&s.this);
1199 stack.push(&s.index);
1200 }
1201 Expression::Dot(d) => {
1202 stack.push(&d.this);
1203 }
1204 Expression::MethodCall(m) => {
1205 stack.push(&m.this);
1206 for arg in &m.args {
1207 stack.push(arg);
1208 }
1209 }
1210 Expression::ArraySlice(s) => {
1211 stack.push(&s.this);
1212 if let Some(ref start) = s.start {
1213 stack.push(start);
1214 }
1215 if let Some(ref end) = s.end {
1216 stack.push(end);
1217 }
1218 }
1219 Expression::Lambda(l) => {
1220 stack.push(&l.body);
1221 }
1222 Expression::NamedArgument(n) => {
1223 stack.push(&n.value);
1224 }
1225 Expression::BracedWildcard(e) | Expression::ReturnStmt(e) => {
1226 stack.push(e);
1227 }
1228
1229 Expression::Substring(f) => {
1231 stack.push(&f.this);
1232 stack.push(&f.start);
1233 if let Some(ref len) = f.length {
1234 stack.push(len);
1235 }
1236 }
1237 Expression::Trim(f) => {
1238 stack.push(&f.this);
1239 if let Some(ref chars) = f.characters {
1240 stack.push(chars);
1241 }
1242 }
1243 Expression::Replace(f) => {
1244 stack.push(&f.this);
1245 stack.push(&f.old);
1246 stack.push(&f.new);
1247 }
1248 Expression::IfFunc(f) => {
1249 stack.push(&f.condition);
1250 stack.push(&f.true_value);
1251 if let Some(ref fv) = f.false_value {
1252 stack.push(fv);
1253 }
1254 }
1255 Expression::Nvl2(f) => {
1256 stack.push(&f.this);
1257 stack.push(&f.true_value);
1258 stack.push(&f.false_value);
1259 }
1260 Expression::ConcatWs(f) => {
1261 stack.push(&f.separator);
1262 for e in &f.expressions {
1263 stack.push(e);
1264 }
1265 }
1266 Expression::Count(f) => {
1267 if let Some(ref this) = f.this {
1268 stack.push(this);
1269 }
1270 if let Some(ref filter) = f.filter {
1271 stack.push(filter);
1272 }
1273 }
1274 Expression::GroupConcat(f) => {
1275 stack.push(&f.this);
1276 if let Some(ref sep) = f.separator {
1277 stack.push(sep);
1278 }
1279 if let Some(ref filter) = f.filter {
1280 stack.push(filter);
1281 }
1282 }
1283 Expression::StringAgg(f) => {
1284 stack.push(&f.this);
1285 if let Some(ref sep) = f.separator {
1286 stack.push(sep);
1287 }
1288 if let Some(ref filter) = f.filter {
1289 stack.push(filter);
1290 }
1291 if let Some(ref limit) = f.limit {
1292 stack.push(limit);
1293 }
1294 }
1295 Expression::ListAgg(f) => {
1296 stack.push(&f.this);
1297 if let Some(ref sep) = f.separator {
1298 stack.push(sep);
1299 }
1300 if let Some(ref filter) = f.filter {
1301 stack.push(filter);
1302 }
1303 }
1304 Expression::SumIf(f) => {
1305 stack.push(&f.this);
1306 stack.push(&f.condition);
1307 if let Some(ref filter) = f.filter {
1308 stack.push(filter);
1309 }
1310 }
1311 Expression::DateAdd(f) | Expression::DateSub(f) => {
1312 stack.push(&f.this);
1313 stack.push(&f.interval);
1314 }
1315 Expression::DateDiff(f) => {
1316 stack.push(&f.this);
1317 stack.push(&f.expression);
1318 }
1319 Expression::DateTrunc(f) | Expression::TimestampTrunc(f) => {
1320 stack.push(&f.this);
1321 }
1322 Expression::Extract(f) => {
1323 stack.push(&f.this);
1324 }
1325 Expression::Round(f) => {
1326 stack.push(&f.this);
1327 if let Some(ref d) = f.decimals {
1328 stack.push(d);
1329 }
1330 }
1331 Expression::Floor(f) => {
1332 stack.push(&f.this);
1333 if let Some(ref s) = f.scale {
1334 stack.push(s);
1335 }
1336 if let Some(ref t) = f.to {
1337 stack.push(t);
1338 }
1339 }
1340 Expression::Ceil(f) => {
1341 stack.push(&f.this);
1342 if let Some(ref d) = f.decimals {
1343 stack.push(d);
1344 }
1345 if let Some(ref t) = f.to {
1346 stack.push(t);
1347 }
1348 }
1349 Expression::Log(f) => {
1350 stack.push(&f.this);
1351 if let Some(ref b) = f.base {
1352 stack.push(b);
1353 }
1354 }
1355 Expression::AtTimeZone(f) => {
1356 stack.push(&f.this);
1357 stack.push(&f.zone);
1358 }
1359 Expression::Lead(f) | Expression::Lag(f) => {
1360 stack.push(&f.this);
1361 if let Some(ref off) = f.offset {
1362 stack.push(off);
1363 }
1364 if let Some(ref def) = f.default {
1365 stack.push(def);
1366 }
1367 }
1368 Expression::FirstValue(f) | Expression::LastValue(f) => {
1369 stack.push(&f.this);
1370 }
1371 Expression::NthValue(f) => {
1372 stack.push(&f.this);
1373 stack.push(&f.offset);
1374 }
1375 Expression::Position(f) => {
1376 stack.push(&f.substring);
1377 stack.push(&f.string);
1378 if let Some(ref start) = f.start {
1379 stack.push(start);
1380 }
1381 }
1382 Expression::Decode(f) => {
1383 stack.push(&f.this);
1384 for (search, result) in &f.search_results {
1385 stack.push(search);
1386 stack.push(result);
1387 }
1388 if let Some(ref def) = f.default {
1389 stack.push(def);
1390 }
1391 }
1392 Expression::CharFunc(f) => {
1393 for arg in &f.args {
1394 stack.push(arg);
1395 }
1396 }
1397 Expression::ArraySort(f) => {
1398 stack.push(&f.this);
1399 if let Some(ref cmp) = f.comparator {
1400 stack.push(cmp);
1401 }
1402 }
1403 Expression::ArrayJoin(f) | Expression::ArrayToString(f) => {
1404 stack.push(&f.this);
1405 stack.push(&f.separator);
1406 if let Some(ref nr) = f.null_replacement {
1407 stack.push(nr);
1408 }
1409 }
1410 Expression::ArrayFilter(f) => {
1411 stack.push(&f.this);
1412 stack.push(&f.filter);
1413 }
1414 Expression::ArrayTransform(f) => {
1415 stack.push(&f.this);
1416 stack.push(&f.transform);
1417 }
1418 Expression::Sequence(f)
1419 | Expression::Generate(f)
1420 | Expression::ExplodingGenerateSeries(f) => {
1421 stack.push(&f.start);
1422 stack.push(&f.stop);
1423 if let Some(ref step) = f.step {
1424 stack.push(step);
1425 }
1426 }
1427 Expression::JsonExtract(f)
1428 | Expression::JsonExtractScalar(f)
1429 | Expression::JsonQuery(f)
1430 | Expression::JsonValue(f) => {
1431 stack.push(&f.this);
1432 stack.push(&f.path);
1433 }
1434 Expression::JsonExtractPath(f) | Expression::JsonRemove(f) => {
1435 stack.push(&f.this);
1436 for p in &f.paths {
1437 stack.push(p);
1438 }
1439 }
1440 Expression::JsonObject(f) => {
1441 for (k, v) in &f.pairs {
1442 stack.push(k);
1443 stack.push(v);
1444 }
1445 }
1446 Expression::JsonSet(f) | Expression::JsonInsert(f) => {
1447 stack.push(&f.this);
1448 for (path, val) in &f.path_values {
1449 stack.push(path);
1450 stack.push(val);
1451 }
1452 }
1453 Expression::Overlay(f) => {
1454 stack.push(&f.this);
1455 stack.push(&f.replacement);
1456 stack.push(&f.from);
1457 if let Some(ref len) = f.length {
1458 stack.push(len);
1459 }
1460 }
1461 Expression::Convert(f) => {
1462 stack.push(&f.this);
1463 if let Some(ref style) = f.style {
1464 stack.push(style);
1465 }
1466 }
1467 Expression::ApproxPercentile(f) => {
1468 stack.push(&f.this);
1469 stack.push(&f.percentile);
1470 if let Some(ref acc) = f.accuracy {
1471 stack.push(acc);
1472 }
1473 if let Some(ref filter) = f.filter {
1474 stack.push(filter);
1475 }
1476 }
1477 Expression::Percentile(f)
1478 | Expression::PercentileCont(f)
1479 | Expression::PercentileDisc(f) => {
1480 stack.push(&f.this);
1481 stack.push(&f.percentile);
1482 if let Some(ref filter) = f.filter {
1483 stack.push(filter);
1484 }
1485 }
1486 Expression::WithinGroup(f) => {
1487 stack.push(&f.this);
1488 }
1489 Expression::Left(f) | Expression::Right(f) => {
1490 stack.push(&f.this);
1491 stack.push(&f.length);
1492 }
1493 Expression::Repeat(f) => {
1494 stack.push(&f.this);
1495 stack.push(&f.times);
1496 }
1497 Expression::Lpad(f) | Expression::Rpad(f) => {
1498 stack.push(&f.this);
1499 stack.push(&f.length);
1500 if let Some(ref fill) = f.fill {
1501 stack.push(fill);
1502 }
1503 }
1504 Expression::Split(f) => {
1505 stack.push(&f.this);
1506 stack.push(&f.delimiter);
1507 }
1508 Expression::RegexpLike(f) => {
1509 stack.push(&f.this);
1510 stack.push(&f.pattern);
1511 if let Some(ref flags) = f.flags {
1512 stack.push(flags);
1513 }
1514 }
1515 Expression::RegexpReplace(f) => {
1516 stack.push(&f.this);
1517 stack.push(&f.pattern);
1518 stack.push(&f.replacement);
1519 if let Some(ref flags) = f.flags {
1520 stack.push(flags);
1521 }
1522 }
1523 Expression::RegexpExtract(f) => {
1524 stack.push(&f.this);
1525 stack.push(&f.pattern);
1526 if let Some(ref group) = f.group {
1527 stack.push(group);
1528 }
1529 }
1530 Expression::ToDate(f) => {
1531 stack.push(&f.this);
1532 if let Some(ref fmt) = f.format {
1533 stack.push(fmt);
1534 }
1535 }
1536 Expression::ToTimestamp(f) => {
1537 stack.push(&f.this);
1538 if let Some(ref fmt) = f.format {
1539 stack.push(fmt);
1540 }
1541 }
1542 Expression::DateFormat(f) | Expression::FormatDate(f) => {
1543 stack.push(&f.this);
1544 stack.push(&f.format);
1545 }
1546 Expression::LastDay(f) => {
1547 stack.push(&f.this);
1548 }
1549 Expression::FromUnixtime(f) => {
1550 stack.push(&f.this);
1551 if let Some(ref fmt) = f.format {
1552 stack.push(fmt);
1553 }
1554 }
1555 Expression::UnixTimestamp(f) => {
1556 if let Some(ref this) = f.this {
1557 stack.push(this);
1558 }
1559 if let Some(ref fmt) = f.format {
1560 stack.push(fmt);
1561 }
1562 }
1563 Expression::MakeDate(f) => {
1564 stack.push(&f.year);
1565 stack.push(&f.month);
1566 stack.push(&f.day);
1567 }
1568 Expression::MakeTimestamp(f) => {
1569 stack.push(&f.year);
1570 stack.push(&f.month);
1571 stack.push(&f.day);
1572 stack.push(&f.hour);
1573 stack.push(&f.minute);
1574 stack.push(&f.second);
1575 if let Some(ref tz) = f.timezone {
1576 stack.push(tz);
1577 }
1578 }
1579 Expression::TruncFunc(f) => {
1580 stack.push(&f.this);
1581 if let Some(ref d) = f.decimals {
1582 stack.push(d);
1583 }
1584 }
1585 Expression::ArrayFunc(f) => {
1586 for e in &f.expressions {
1587 stack.push(e);
1588 }
1589 }
1590 Expression::Unnest(f) => {
1591 stack.push(&f.this);
1592 for e in &f.expressions {
1593 stack.push(e);
1594 }
1595 }
1596 Expression::StructFunc(f) => {
1597 for (_, e) in &f.fields {
1598 stack.push(e);
1599 }
1600 }
1601 Expression::StructExtract(f) => {
1602 stack.push(&f.this);
1603 }
1604 Expression::NamedStruct(f) => {
1605 for (k, v) in &f.pairs {
1606 stack.push(k);
1607 stack.push(v);
1608 }
1609 }
1610 Expression::MapFunc(f) => {
1611 for k in &f.keys {
1612 stack.push(k);
1613 }
1614 for v in &f.values {
1615 stack.push(v);
1616 }
1617 }
1618 Expression::TransformKeys(f) | Expression::TransformValues(f) => {
1619 stack.push(&f.this);
1620 stack.push(&f.transform);
1621 }
1622 Expression::JsonArrayAgg(f) => {
1623 stack.push(&f.this);
1624 if let Some(ref filter) = f.filter {
1625 stack.push(filter);
1626 }
1627 }
1628 Expression::JsonObjectAgg(f) => {
1629 stack.push(&f.key);
1630 stack.push(&f.value);
1631 if let Some(ref filter) = f.filter {
1632 stack.push(filter);
1633 }
1634 }
1635 Expression::NTile(f) => {
1636 if let Some(ref n) = f.num_buckets {
1637 stack.push(n);
1638 }
1639 }
1640 Expression::Rand(f) => {
1641 if let Some(ref s) = f.seed {
1642 stack.push(s);
1643 }
1644 if let Some(ref lo) = f.lower {
1645 stack.push(lo);
1646 }
1647 if let Some(ref hi) = f.upper {
1648 stack.push(hi);
1649 }
1650 }
1651 Expression::Any(q) | Expression::All(q) => {
1652 stack.push(&q.this);
1653 stack.push(&q.subquery);
1654 }
1655 Expression::Overlaps(o) => {
1656 if let Some(ref this) = o.this {
1657 stack.push(this);
1658 }
1659 if let Some(ref expr) = o.expression {
1660 stack.push(expr);
1661 }
1662 if let Some(ref ls) = o.left_start {
1663 stack.push(ls);
1664 }
1665 if let Some(ref le) = o.left_end {
1666 stack.push(le);
1667 }
1668 if let Some(ref rs) = o.right_start {
1669 stack.push(rs);
1670 }
1671 if let Some(ref re) = o.right_end {
1672 stack.push(re);
1673 }
1674 }
1675 Expression::Interval(i) => {
1676 if let Some(ref this) = i.this {
1677 stack.push(this);
1678 }
1679 }
1680 Expression::TimeStrToTime(f) => {
1681 stack.push(&f.this);
1682 if let Some(ref zone) = f.zone {
1683 stack.push(zone);
1684 }
1685 }
1686 Expression::JSONBExtractScalar(f) => {
1687 stack.push(&f.this);
1688 stack.push(&f.expression);
1689 if let Some(ref jt) = f.json_type {
1690 stack.push(jt);
1691 }
1692 }
1693
1694 _ => {}
1699 }
1700 }
1701}
1702
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};
    use crate::expressions::DataType;
    use crate::optimizer::annotate_types::annotate_types;
    use crate::schema::{MappingSchema, Schema};

    /// Parses `sql` with the generic dialect and returns the first parsed expression.
    ///
    /// Panics when parsing fails or produces no expression.
    fn parse(sql: &str) -> Expression {
        Dialect::get(DialectType::Generic)
            .parse(sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
    }

    /// Runs `lineage` for `column` with no dialect and `trim_selects = false`,
    /// panicking if lineage resolution fails.
    fn lineage_root(column: &str, ast: &Expression) -> LineageNode {
        lineage(column, ast, None, false).unwrap()
    }

    /// Collects the name of every node yielded by `LineageNode::walk`, in iteration order.
    fn walked_names(root: &LineageNode) -> Vec<String> {
        root.walk().map(|n| n.name.clone()).collect()
    }

    /// Returns `true` when `names` contains an entry exactly equal to `wanted`.
    fn has_name(names: &[String], wanted: &str) -> bool {
        names.iter().any(|n| n == wanted)
    }

    /// A bare projected column keeps its name on the root node and lists the
    /// table-qualified column `t.a` among its direct downstream nodes.
    #[test]
    fn test_simple_lineage() {
        let ast = parse("SELECT a FROM t");
        let root = lineage_root("a", &ast);

        assert_eq!(root.name, "a");
        assert!(!root.downstream.is_empty(), "Should have downstream nodes");

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "t.a"),
            "Expected t.a in downstream, got: {:?}",
            downstream
        );
    }

    /// `walk()` on a hand-built two-node graph yields the root first, then its child.
    #[test]
    fn test_lineage_walk() {
        let null = || Expression::Null(crate::expressions::Null);

        // `LineageNode::new` leaves `downstream` empty and both name fields blank,
        // so only the single child has to be attached by hand.
        let mut root = LineageNode::new("col_a", null(), null());
        root.downstream
            .push(LineageNode::new("t.a", null(), null()));

        let visited = walked_names(&root);
        assert_eq!(visited.len(), 2);
        assert_eq!(visited[0], "col_a");
        assert_eq!(visited[1], "t.a");
    }

    /// An aliased arithmetic expression is looked up by its alias and some node in
    /// the walked graph has a name containing "a".
    #[test]
    fn test_aliased_column() {
        let ast = parse("SELECT a + 1 AS b FROM t");
        let root = lineage_root("b", &ast);

        assert_eq!(root.name, "b");

        let visited = walked_names(&root);
        assert!(
            visited.iter().any(|n| n.contains("a")),
            "Expected to trace to column a, got: {:?}",
            visited
        );
    }

    /// An explicitly table-qualified column resolves to `t.a`.
    #[test]
    fn test_qualified_column() {
        let ast = parse("SELECT t.a FROM t");
        let root = lineage_root("a", &ast);

        assert_eq!(root.name, "a");

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "t.a"),
            "Expected t.a, got: {:?}",
            downstream
        );
    }

    /// An unqualified column is still reported as `t.a` in the downstream names.
    #[test]
    fn test_unqualified_column() {
        let ast = parse("SELECT a FROM t");
        let root = lineage_root("a", &ast);

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "t.a"),
            "Expected t.a, got: {:?}",
            downstream
        );
    }

    /// Issue 40: the root expression returned by `lineage_with_schema` can be typed
    /// by `annotate_types` (as `DataType::Text`), whereas the root returned by plain
    /// `lineage` yields no type.
    #[test]
    fn test_lineage_with_schema_qualifies_root_expression_issue_40() {
        let query = "SELECT name FROM users";
        let ast = Dialect::get(DialectType::BigQuery)
            .parse(query)
            .unwrap()
            .into_iter()
            .next()
            .expect("expected one expression");

        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
        schema
            .add_table("users", &[("name".into(), DataType::Text)], None)
            .expect("schema setup");

        // Baseline: lineage computed without the schema.
        let plain_root = lineage("name", &ast, Some(DialectType::BigQuery), false)
            .expect("lineage without schema");
        let plain_type = annotate_types(
            &plain_root.expression,
            Some(&schema),
            Some(DialectType::BigQuery),
        );
        assert_eq!(
            plain_type, None,
            "Expected unresolved root type without schema-aware lineage qualification"
        );

        // Same query, this time handing the schema to the lineage entry point.
        let schema_root = lineage_with_schema(
            "name",
            &ast,
            Some(&schema),
            Some(DialectType::BigQuery),
            false,
        )
        .expect("lineage with schema");
        let schema_type = annotate_types(
            &schema_root.expression,
            Some(&schema),
            Some(DialectType::BigQuery),
        );

        assert_eq!(schema_type, Some(DataType::Text));
    }

    /// With no schema and no dialect, `lineage_with_schema` reports the same root
    /// name and downstream names as `lineage`.
    #[test]
    fn test_lineage_with_schema_none_matches_lineage() {
        let ast = parse("SELECT a FROM t");

        let expected = lineage("a", &ast, None, false).expect("lineage baseline");
        let actual =
            lineage_with_schema("a", &ast, None, None, false).expect("lineage_with_schema");

        assert_eq!(actual.name, expected.name);
        assert_eq!(actual.downstream_names(), expected.downstream_names());
    }

    /// Columns projected from both sides of a join resolve to their own tables.
    #[test]
    fn test_lineage_join() {
        let ast = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        let from_t = lineage_root("a", &ast).downstream_names();
        assert!(
            has_name(&from_t, "t.a"),
            "Expected t.a, got: {:?}",
            from_t
        );

        let from_s = lineage_root("b", &ast).downstream_names();
        assert!(
            has_name(&from_s, "s.b"),
            "Expected s.b, got: {:?}",
            from_s
        );
    }

    /// A column reached through a table alias keeps the alias in its node name
    /// (`t1.col1`) while `source_name` and `source` refer to the table `table1`.
    #[test]
    fn test_lineage_alias_leaf_has_resolved_source_name() {
        let ast = parse("SELECT t1.col1 FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id");
        let root = lineage_root("col1", &ast);

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "t1.col1"),
            "Expected aliased column edge t1.col1, got: {:?}",
            downstream
        );

        let leaf = root
            .downstream
            .iter()
            .find(|n| n.name == "t1.col1")
            .expect("Expected t1.col1 leaf");
        assert_eq!(leaf.source_name, "table1");

        if let Expression::Table(table) = &leaf.source {
            assert_eq!(table.name.name, "table1");
        } else {
            panic!("Expected leaf source to be a table expression");
        }
    }

    /// Lineage passes through a derived table in `FROM` and reaches `t.a`.
    #[test]
    fn test_lineage_derived_table() {
        let ast = parse("SELECT x.a FROM (SELECT a FROM t) AS x");
        let root = lineage_root("a", &ast);

        assert_eq!(root.name, "a");

        let visited = walked_names(&root);
        assert!(
            has_name(&visited, "t.a"),
            "Expected to trace through derived table to t.a, got: {:?}",
            visited
        );
    }

    /// Lineage passes through a CTE and reaches `t.a`.
    #[test]
    fn test_lineage_cte() {
        let ast = parse("WITH cte AS (SELECT a FROM t) SELECT a FROM cte");
        let root = lineage_root("a", &ast);

        assert_eq!(root.name, "a");

        let visited = walked_names(&root);
        assert!(
            has_name(&visited, "t.a"),
            "Expected to trace through CTE to t.a, got: {:?}",
            visited
        );
    }

    /// A two-branch `UNION` yields exactly two downstream nodes on the root.
    #[test]
    fn test_lineage_union() {
        let ast = parse("SELECT a FROM t1 UNION SELECT a FROM t2");
        let root = lineage_root("a", &ast);

        assert_eq!(root.name, "a");

        let branches = root.downstream.len();
        assert_eq!(
            branches, 2,
            "Expected 2 branches for UNION, got {}",
            branches
        );
    }

    /// A CTE whose body is a `UNION` produces a walked graph of at least three nodes.
    #[test]
    fn test_lineage_cte_union() {
        let ast = parse("WITH cte AS (SELECT a FROM t1 UNION SELECT a FROM t2) SELECT a FROM cte");
        let root = lineage_root("a", &ast);

        let visited = walked_names(&root);
        assert!(
            visited.len() >= 3,
            "Expected at least 3 nodes for CTE with UNION, got: {:?}",
            visited
        );
    }

    /// Requesting lineage for `*` succeeds and yields downstream nodes.
    #[test]
    fn test_lineage_star() {
        let ast = parse("SELECT * FROM t");
        let root = lineage_root("*", &ast);

        assert_eq!(root.name, "*");
        assert!(
            !root.downstream.is_empty(),
            "Star should produce downstream nodes"
        );
    }

    /// A scalar subquery in the projection contributes at least one node beyond the root.
    #[test]
    fn test_lineage_subquery_in_select() {
        let ast = parse("SELECT (SELECT MAX(b) FROM s) AS x FROM t");
        let root = lineage_root("x", &ast);

        assert_eq!(root.name, "x");

        let visited = walked_names(&root);
        assert!(
            visited.len() >= 2,
            "Expected tracing into scalar subquery, got: {:?}",
            visited
        );
    }

    /// Each projected column of the same query can be traced independently.
    #[test]
    fn test_lineage_multiple_columns() {
        let ast = parse("SELECT a, b FROM t");

        let root_a = lineage_root("a", &ast);
        let root_b = lineage_root("b", &ast);

        assert_eq!(root_a.name, "a");
        assert_eq!(root_b.name, "b");

        let downstream_a = root_a.downstream_names();
        let downstream_b = root_b.downstream_names();
        assert!(has_name(&downstream_a, "t.a"));
        assert!(has_name(&downstream_b, "t.b"));
    }

    /// `get_source_tables` reports table `t` for a column selected from `t`.
    #[test]
    fn test_get_source_tables() {
        let ast = parse("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
        let root = lineage_root("a", &ast);

        let tables = get_source_tables(&root);
        assert!(
            tables.contains("t"),
            "Expected source table 't', got: {:?}",
            tables
        );
    }

    /// Asking for a column the query does not project returns an error.
    #[test]
    fn test_lineage_column_not_found() {
        let ast = parse("SELECT a FROM t");
        let outcome = lineage("nonexistent", &ast, None, false);
        assert!(outcome.is_err());
    }

    /// Two chained CTEs produce a walked graph of at least three nodes.
    #[test]
    fn test_lineage_nested_cte() {
        let ast = parse(
            "WITH cte1 AS (SELECT a FROM t), \
             cte2 AS (SELECT a FROM cte1) \
             SELECT a FROM cte2",
        );
        let root = lineage_root("a", &ast);

        let visited = walked_names(&root);
        assert!(
            visited.len() >= 3,
            "Expected to trace through nested CTEs, got: {:?}",
            visited
        );
    }

    /// With `trim_selects = true` the root's source `SELECT` keeps only one projection.
    #[test]
    fn test_trim_selects_true() {
        let ast = parse("SELECT a, b, c FROM t");
        let root = lineage("a", &ast, None, true).unwrap();

        match &root.source {
            Expression::Select(select) => {
                let kept = select.expressions.len();
                assert_eq!(
                    kept, 1,
                    "Trimmed source should have 1 expression, got {}",
                    kept
                );
            }
            _ => panic!("Expected Select source"),
        }
    }

    /// With `trim_selects = false` the root's source `SELECT` keeps all three projections.
    #[test]
    fn test_trim_selects_false() {
        let ast = parse("SELECT a, b, c FROM t");
        let root = lineage_root("a", &ast);

        match &root.source {
            Expression::Select(select) => {
                assert_eq!(
                    select.expressions.len(),
                    3,
                    "Untrimmed source should have 3 expressions"
                );
            }
            _ => panic!("Expected Select source"),
        }
    }

    /// A binary expression over two columns yields at least three walked nodes.
    #[test]
    fn test_lineage_expression_in_select() {
        let ast = parse("SELECT a + b AS c FROM t");
        let root = lineage_root("c", &ast);

        let visited = walked_names(&root);
        assert!(
            visited.len() >= 3,
            "Expected to trace a + b to both columns, got: {:?}",
            visited
        );
    }

    /// A `UNION` whose branches project differently named columns (`a` and `b`)
    /// still yields two downstream nodes for `a`.
    // NOTE(review): the test name suggests branches are matched by position — confirm.
    #[test]
    fn test_set_operation_by_index() {
        let ast = parse("SELECT a FROM t1 UNION SELECT b FROM t2");

        let root = lineage_root("a", &ast);

        assert_eq!(root.downstream.len(), 2);
    }

    /// Prints `node` and, recursively, its downstream nodes to stdout, prefixing
    /// each line with `indent` repetitions of the padding string.
    fn print_node(node: &LineageNode, indent: usize) {
        let pad = " ".repeat(indent);
        println!(
            "{pad}name={:?} source_name={:?}",
            node.name, node.source_name
        );
        node.downstream
            .iter()
            .for_each(|child| print_node(child, indent + 1));
    }

    /// Issue 18 reproduction: with the BigQuery dialect, `UPPER(name)` traces to
    /// `users.name`. The lineage tree is also printed to stdout.
    #[test]
    fn test_issue18_repro() {
        let query = "SELECT UPPER(name) as upper_name FROM users";
        println!("Query: {query}\n");

        let exprs = Dialect::get(DialectType::BigQuery).parse(query).unwrap();
        let expr = &exprs[0];

        let root = lineage("upper_name", expr, Some(DialectType::BigQuery), false).unwrap();
        println!("lineage(\"upper_name\"):");
        print_node(&root, 1);

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "users.name"),
            "Expected users.name in downstream, got: {:?}",
            downstream
        );
    }

    /// `UPPER(name)` traces to `users.name` with the generic dialect.
    #[test]
    fn test_lineage_upper_function() {
        let ast = parse("SELECT UPPER(name) AS upper_name FROM users");
        let root = lineage_root("upper_name", &ast);

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "users.name"),
            "Expected users.name in downstream, got: {:?}",
            downstream
        );
    }

    /// `ROUND(price, 2)` traces to `products.price`.
    #[test]
    fn test_lineage_round_function() {
        let ast = parse("SELECT ROUND(price, 2) AS rounded FROM products");
        let root = lineage_root("rounded", &ast);

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "products.price"),
            "Expected products.price in downstream, got: {:?}",
            downstream
        );
    }

    /// `COALESCE(a, b)` traces to both of its column arguments.
    #[test]
    fn test_lineage_coalesce_function() {
        let ast = parse("SELECT COALESCE(a, b) AS val FROM t");
        let root = lineage_root("val", &ast);

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "t.a"),
            "Expected t.a in downstream, got: {:?}",
            downstream
        );
        assert!(
            has_name(&downstream, "t.b"),
            "Expected t.b in downstream, got: {:?}",
            downstream
        );
    }

    /// `COUNT(id)` traces to `t.id`.
    #[test]
    fn test_lineage_count_function() {
        let ast = parse("SELECT COUNT(id) AS cnt FROM t");
        let root = lineage_root("cnt", &ast);

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "t.id"),
            "Expected t.id in downstream, got: {:?}",
            downstream
        );
    }

    /// `SUM(amount)` traces to `t.amount`.
    #[test]
    fn test_lineage_sum_function() {
        let ast = parse("SELECT SUM(amount) AS total FROM t");
        let root = lineage_root("total", &ast);

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "t.amount"),
            "Expected t.amount in downstream, got: {:?}",
            downstream
        );
    }

    /// A `CASE` expression traces both the column in its condition (`t.x`) and the
    /// column wrapped by functions in its branches (`t.name`).
    #[test]
    fn test_lineage_case_with_nested_functions() {
        let ast =
            parse("SELECT CASE WHEN x > 0 THEN UPPER(name) ELSE LOWER(name) END AS result FROM t");
        let root = lineage_root("result", &ast);

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "t.x"),
            "Expected t.x in downstream, got: {:?}",
            downstream
        );
        assert!(
            has_name(&downstream, "t.name"),
            "Expected t.name in downstream, got: {:?}",
            downstream
        );
    }

    /// `SUBSTRING(name, 1, 3)` traces to `t.name`.
    #[test]
    fn test_lineage_substring_function() {
        let ast = parse("SELECT SUBSTRING(name, 1, 3) AS short FROM t");
        let root = lineage_root("short", &ast);

        let downstream = root.downstream_names();
        assert!(
            has_name(&downstream, "t.name"),
            "Expected t.name in downstream, got: {:?}",
            downstream
        );
    }
}