1use crate::ast_transforms::get_output_column_names;
9use crate::dialects::{Dialect, DialectType};
10use crate::expressions::{DataType, Expression, JoinKind, TableRef, With};
11use crate::lineage::{lineage_by_index_from_expression, LineageNode};
12use crate::optimizer::annotate_types::annotate_types;
13use crate::optimizer::qualify_columns::{qualify_columns, QualifyColumnsOptions};
14use crate::schema::{MappingSchema, Schema};
15use crate::scope::{build_scope, Scope, SourceInfo, SourceKind};
16use crate::traversal::{contains_aggregate, ExpressionWalk};
17use crate::validation::{mapping_schema_from_validation_schema, ValidationSchema};
18use crate::{parse_data_type, parse_one, Error, Result};
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21
/// Options controlling `analyze_query`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase", default)]
pub struct AnalyzeQueryOptions {
    /// Dialect used to parse the SQL and to render data types and CTE bodies.
    pub dialect: DialectType,
    /// Optional schema. When present it is used for column qualification, type
    /// annotation, star expansion, relation column lists, and nullability facts.
    pub schema: Option<ValidationSchema>,
}
31
/// Structural facts extracted from a single query by `analyze_query`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct QueryAnalysis {
    /// Whether the top-level query is a plain SELECT or a set operation.
    pub shape: QueryShape,
    /// CTE names declared on the top-level query, its set-operation branches, and
    /// (recursively) inside CTE bodies; deduplicated, in first-seen order.
    pub ctes: Vec<String>,
    /// Details for the CTEs declared in the top-level WITH clause only.
    pub cte_facts: Vec<CteFact>,
    /// One entry per item of the query's select list. For set operations the
    /// leftmost branch's select list is used.
    pub projections: Vec<ProjectionFact>,
    /// Sources of the top-level scope and its set-operation branch scopes,
    /// deduplicated and sorted by name, then alias.
    pub relations: Vec<RelationFact>,
    /// Table-kind sources, deduplicated and sorted by name.
    pub base_tables: Vec<RelationFact>,
    /// Star projections (`*` / `t.*`) with their expanded column lists. These are
    /// computed on the tree before column qualification.
    pub star_projections: Vec<StarProjectionFact>,
    /// Set-operation details, built by `set_operation_facts` (not shown here);
    /// presumably empty for a plain SELECT — TODO confirm.
    pub set_operations: Vec<SetOperationFact>,
}
45
/// Top-level shape of an analyzed query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QueryShape {
    /// A single SELECT statement.
    Select,
    /// A UNION, INTERSECT, or EXCEPT at the top level.
    SetOperation,
}
53
/// Facts about one item of a select list.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProjectionFact {
    /// Zero-based position in the select list.
    pub index: usize,
    /// Output column name. It comes from the query's output column names when
    /// available, otherwise from the alias, column, or identifier name.
    pub name: Option<String>,
    /// True for `*` or `t.*` (after unwrapping aliases, annotations, and parens).
    pub is_star: bool,
    /// The `t` qualifier of a `t.*` projection.
    pub star_table: Option<String>,
    /// How the projection derives its value from its inputs.
    pub transform_kind: TransformKind,
    /// Rendered target type when the projection is a CAST / TRY_CAST / SAFE_CAST.
    pub cast_type: Option<String>,
    /// Rendered inferred type from type annotation, if any.
    pub type_hint: Option<String>,
    /// Best-effort nullability of the projected value.
    pub nullability: ProjectionNullability,
    /// Terminal column references from lineage, or from a scan of the expression
    /// when lineage yields nothing.
    pub upstream: Vec<ColumnReferenceFact>,
}
68
/// Facts about one CTE in the top-level WITH clause.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CteFact {
    /// CTE name (its alias).
    pub name: String,
    /// Explicit column list, as in `WITH x(a, b) AS (...)`; empty when absent.
    pub columns: Vec<String>,
    /// The CTE body regenerated as SQL in the requested dialect.
    pub body_sql: String,
    /// Output column names of the CTE body.
    pub output_columns: Vec<String>,
}
78
/// A `*` or `t.*` projection together with the columns it expands to.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StarProjectionFact {
    /// Zero-based position of the star in the select list.
    pub index: usize,
    /// The `t` qualifier for `t.*`; `None` for a bare `*`.
    pub table: Option<String>,
    /// Columns of the matching sources in FROM/JOIN order. Column resolution is
    /// delegated to `source_columns` (not shown), so this may be empty when
    /// columns are unknown — TODO confirm.
    pub expanded_columns: Vec<String>,
}
87
/// An upstream column that a projection reads.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ColumnReferenceFact {
    /// Name of the scope source the column resolves to, when known.
    pub source_name: Option<String>,
    /// Alias of that source, when it has one.
    pub source_alias: Option<String>,
    /// Kind of that source (`Unknown` when unresolved).
    pub source_kind: SourceKind,
    /// Underlying table name when resolvable; otherwise the qualifier or
    /// source name as written.
    pub table: Option<String>,
    /// Column name, or `"*"` for a star reference.
    pub column: String,
    /// True when the reference carried no table qualifier in the SQL.
    pub unqualified: bool,
    /// How certain the source attribution is.
    pub confidence: ReferenceConfidence,
}
100
/// A relation (table, CTE, subquery, ...) referenced by the query.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RelationFact {
    /// The lineage name if set, else the underlying table name, else the scope key.
    pub name: String,
    /// Alias of the relation, when one is known.
    pub alias: Option<String>,
    /// Source kind of the relation.
    pub kind: SourceKind,
    /// Column names of the relation, as reported by `source_columns`.
    pub columns: Vec<String>,
}
110
/// Facts about a set operation (UNION / INTERSECT / EXCEPT).
///
/// NOTE(review): populated by `set_operation_facts`, which is outside this view;
/// the field descriptions below are inferred from names — TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SetOperationFact {
    /// Presumably the operator name (e.g. union / intersect / except).
    pub kind: String,
    /// Presumably true for the `ALL` variant.
    pub all: bool,
    /// Presumably true for the `DISTINCT` variant.
    pub distinct: bool,
    /// Presumably the output column names of the set operation.
    pub output_columns: Vec<String>,
    /// Per-branch projection facts.
    pub branches: Vec<SetOperationBranchFact>,
}
121
/// One branch of a set operation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SetOperationBranchFact {
    /// Presumably the zero-based branch position — TODO confirm.
    pub index: usize,
    /// Projection facts for this branch's select list.
    pub projections: Vec<ProjectionFact>,
}
129
/// Classification of how a projection is computed (see `transform_kind`).
/// The checks are ordered: Star, Cast, Aggregation, Direct, Constant, Expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TransformKind {
    /// A bare column or identifier.
    Direct,
    /// A top-level CAST / TRY_CAST / SAFE_CAST. This wins even if the cast wraps
    /// an aggregate or a constant.
    Cast,
    /// Any expression containing an aggregate.
    Aggregation,
    /// A literal, boolean, or NULL, possibly under unary negation or bitwise NOT.
    Constant,
    /// Anything else.
    Expression,
    /// `*` or `t.*`.
    Star,
}
141
/// How confidently a column reference was attributed to a source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReferenceConfidence {
    /// The source was identified.
    Resolved,
    /// Unqualified column with several candidate sources in scope.
    Ambiguous,
    /// No source could be identified.
    Unknown,
}
150
/// Best-effort nullability of a projected value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProjectionNullability {
    /// The value can never be NULL.
    NonNull,
    /// The value may be NULL.
    Nullable,
    /// Nullability could not be determined.
    Unknown,
}
159
/// Parses `sql` and extracts structural facts about the query.
///
/// Steps:
/// 1. Parse with `options.dialect`, strip `PREPARE` / unaliased-subquery wrappers,
///    and reject anything that is not a SELECT or set operation.
/// 2. Compute CTE facts and star-projection facts on the tree *before* qualification.
/// 3. If a schema is given, qualify columns. Then annotate types (with aliased
///    tables also registered under their aliases) and expand stars over CTEs.
/// 4. Build scopes on the processed tree and derive projections, relations, base
///    tables, and set-operation facts from it.
///
/// # Errors
/// Fails if parsing fails, if the statement is not a query, if column qualification
/// fails, or if regenerating a CTE body as SQL fails.
pub fn analyze_query(sql: &str, options: AnalyzeQueryOptions) -> Result<QueryAnalysis> {
    let mut expression = parse_one(sql, options.dialect)?;
    expression = effective_query(expression);
    ensure_query(&expression)?;
    // Snapshot of the tree before qualification/annotation rewrite it.
    let original_expression = expression.clone();

    // Dialect-aware type schema (qualification, annotation, column lists).
    let mapping_schema = options
        .schema
        .as_ref()
        .map(|schema| analysis_mapping_schema(schema, options.dialect));
    // Separate nullability / primary-key lookup built from the same schema.
    let schema_info = options.schema.as_ref().map(AnalysisSchemaInfo::from_schema);
    let cte_facts = top_level_cte_facts(&original_expression, options.dialect)?;
    let star_projections = star_projection_facts(&original_expression, mapping_schema.as_ref());

    if let Some(schema) = mapping_schema.as_ref() {
        let qualify_options = QualifyColumnsOptions::new().with_dialect(options.dialect);
        expression = qualify_columns(expression, schema, &qualify_options)
            .map_err(|e| Error::internal(format!("query analysis qualification failed: {e}")))?;
    }

    // Copy of the schema where every aliased table source is also registered under
    // its alias, so alias-qualified columns can be typed.
    let annotation_schema = mapping_schema.as_ref().map(|schema| {
        let mut alias_schema = schema.clone();
        add_scope_aliases_to_schema(
            &build_scope(&expression),
            schema,
            &mut alias_schema,
            options.dialect,
        );
        alias_schema
    });

    annotate_types(
        &mut expression,
        annotation_schema
            .as_ref()
            .map(|schema| schema as &dyn Schema),
        Some(options.dialect),
    );
    // `annotation_schema` is Some exactly when `mapping_schema` is Some, so the
    // `.or(...)` fallback below never changes which schema is used.
    crate::lineage::expand_cte_stars(
        &mut expression,
        annotation_schema
            .as_ref()
            .or(mapping_schema.as_ref())
            .map(|schema| schema as &dyn Schema),
    );

    // Scope is rebuilt after the tree has been rewritten above.
    let scope = build_scope(&expression);
    let nullability_context = NullabilityContext {
        schema: schema_info.as_ref(),
        nullable_sources: nullable_source_names(&expression),
    };
    let shape = if is_set_operation(&expression) {
        QueryShape::SetOperation
    } else {
        QueryShape::Select
    };

    Ok(QueryAnalysis {
        shape,
        ctes: collect_cte_names(&expression),
        cte_facts,
        projections: projection_facts_for_query(
            &expression,
            &scope,
            options.dialect,
            &nullability_context,
        ),
        relations: relation_facts(&scope, mapping_schema.as_ref()),
        base_tables: base_table_facts(&scope, mapping_schema.as_ref()),
        star_projections,
        set_operations: set_operation_facts(&expression, &scope, options.dialect),
    })
}
234
235fn analysis_mapping_schema(schema: &ValidationSchema, dialect: DialectType) -> MappingSchema {
236 let broad_schema = mapping_schema_from_validation_schema(schema);
237 let mut mapping_schema = MappingSchema::with_dialect(dialect);
238
239 for table in &schema.tables {
240 let table_names = validation_table_names(table);
241 if table_names.is_empty() {
242 continue;
243 }
244
245 let fallback_table = table_names[0].as_str();
246 let columns: Vec<(String, DataType)> = table
247 .columns
248 .iter()
249 .map(|column| {
250 let data_type = parse_analysis_data_type(&column.data_type, dialect)
251 .unwrap_or_else(|| {
252 broad_schema
253 .get_column_type(fallback_table, &column.name)
254 .unwrap_or(DataType::Unknown)
255 });
256 (column.name.to_ascii_lowercase(), data_type)
257 })
258 .collect();
259
260 for table_name in table_names {
261 let _ = mapping_schema.add_table(&table_name, &columns, Some(dialect));
262 }
263 }
264
265 mapping_schema
266}
267
268fn validation_table_names(table: &crate::validation::SchemaTable) -> Vec<String> {
269 let mut names = Vec::new();
270
271 names.push(table.name.to_ascii_lowercase());
272 if let Some(schema_name) = &table.schema {
273 names.push(format!(
274 "{}.{}",
275 schema_name.to_ascii_lowercase(),
276 table.name.to_ascii_lowercase()
277 ));
278 }
279 for alias in &table.aliases {
280 names.push(alias.to_ascii_lowercase());
281 }
282
283 names.sort();
284 names.dedup();
285 names
286}
287
288fn parse_analysis_data_type(data_type: &str, dialect: DialectType) -> Option<DataType> {
289 let trimmed = data_type.trim();
290 if trimmed.is_empty() {
291 return None;
292 }
293 parse_data_type(trimmed, dialect).ok()
294}
295
296fn add_scope_aliases_to_schema(
297 scope: &Scope,
298 source_schema: &MappingSchema,
299 target_schema: &mut MappingSchema,
300 dialect: DialectType,
301) {
302 for child_scope in scope.traverse() {
303 for (source_name, source) in &child_scope.sources {
304 if source.kind != SourceKind::Table {
305 continue;
306 }
307 if let Some(table_name) = source_table_name(source) {
308 if source_name == &table_name {
309 continue;
310 }
311 if let Ok(column_names) = source_schema.column_names(&table_name) {
312 let columns: Vec<(String, DataType)> = column_names
313 .iter()
314 .map(|column| {
315 (
316 column.clone(),
317 source_schema
318 .get_column_type(&table_name, column)
319 .unwrap_or(DataType::Unknown),
320 )
321 })
322 .collect();
323 let _ = target_schema.add_table(source_name, &columns, Some(dialect));
324 }
325 }
326 }
327 }
328}
329
/// Nullability-related facts for one schema column.
#[derive(Debug, Clone)]
struct AnalysisColumnInfo {
    /// Declared nullability; `None` when the schema does not say.
    nullable: Option<bool>,
    /// True if the column is flagged as a primary key or listed in the table's
    /// `primary_key`.
    primary_key: bool,
}
335
/// Column facts indexed for nullability analysis.
#[derive(Debug, Clone)]
struct AnalysisSchemaInfo {
    /// Keyed by (lowercased table name, lowercased column name). Each column is
    /// stored under every name its table can be referenced by.
    columns: HashMap<(String, String), AnalysisColumnInfo>,
}
340
341impl AnalysisSchemaInfo {
342 fn from_schema(schema: &ValidationSchema) -> Self {
343 let mut columns = HashMap::new();
344
345 for table in &schema.tables {
346 let table_names = validation_table_names(table);
347 let primary_keys: HashSet<String> = table
348 .primary_key
349 .iter()
350 .map(|column| column.to_ascii_lowercase())
351 .collect();
352
353 for column in &table.columns {
354 let info = AnalysisColumnInfo {
355 nullable: column.nullable,
356 primary_key: column.primary_key
357 || primary_keys.contains(&column.name.to_ascii_lowercase()),
358 };
359
360 for table_name in &table_names {
361 columns.insert(
362 (
363 normalize_lookup_name(table_name),
364 normalize_lookup_name(&column.name),
365 ),
366 info.clone(),
367 );
368 }
369 }
370 }
371
372 Self { columns }
373 }
374
375 fn column(&self, table: &str, column: &str) -> Option<&AnalysisColumnInfo> {
376 self.columns
377 .get(&(normalize_lookup_name(table), normalize_lookup_name(column)))
378 }
379}
380
/// Inputs shared by the nullability helpers.
struct NullabilityContext<'a> {
    /// Schema column facts, when a schema was supplied.
    schema: Option<&'a AnalysisSchemaInfo>,
    /// Lowercased names of sources on the nullable side of an outer join in the
    /// (leftmost) top-level SELECT.
    nullable_sources: HashSet<String>,
}
385
386fn top_level_cte_facts(expression: &Expression, dialect: DialectType) -> Result<Vec<CteFact>> {
387 let Some(with_clause) = with_clause(expression) else {
388 return Ok(Vec::new());
389 };
390
391 with_clause
392 .ctes
393 .iter()
394 .map(|cte| {
395 Ok(CteFact {
396 name: cte.alias.name.clone(),
397 columns: cte
398 .columns
399 .iter()
400 .map(|column| column.name.clone())
401 .collect(),
402 body_sql: Dialect::get(dialect).generate(&cte.this)?,
403 output_columns: get_output_column_names(&cte.this),
404 })
405 })
406 .collect()
407}
408
409fn star_projection_facts(
410 expression: &Expression,
411 mapping_schema: Option<&MappingSchema>,
412) -> Vec<StarProjectionFact> {
413 let scope = build_scope(expression);
414 let ordered_sources = ordered_source_names_for_query(expression);
415
416 select_expressions_for_query(expression)
417 .iter()
418 .enumerate()
419 .filter_map(|(index, projection)| {
420 let inner = unwrap_projection_alias(projection);
421 if !projection_is_star(inner) {
422 return None;
423 }
424
425 let table = projection_star_table(inner);
426 let expanded_columns =
427 expanded_star_columns(table.as_deref(), &scope, &ordered_sources, mapping_schema);
428
429 Some(StarProjectionFact {
430 index,
431 table,
432 expanded_columns,
433 })
434 })
435 .collect()
436}
437
438fn expanded_star_columns(
439 star_table: Option<&str>,
440 scope: &Scope,
441 ordered_sources: &[String],
442 mapping_schema: Option<&MappingSchema>,
443) -> Vec<String> {
444 let mut columns = Vec::new();
445 let mut source_names: Vec<String> = if ordered_sources.is_empty() {
446 let mut names: Vec<_> = scope.sources.keys().cloned().collect();
447 names.sort();
448 names
449 } else {
450 ordered_sources.to_vec()
451 };
452
453 source_names.dedup();
454
455 for source_name in source_names {
456 let Some(source) = scope.sources.get(&source_name) else {
457 continue;
458 };
459
460 if let Some(star_table) = star_table {
461 let matches = source_name.eq_ignore_ascii_case(star_table)
462 || source
463 .alias
464 .as_deref()
465 .is_some_and(|alias| alias.eq_ignore_ascii_case(star_table))
466 || source_table_name(source)
467 .is_some_and(|table| table.eq_ignore_ascii_case(star_table));
468
469 if !matches {
470 continue;
471 }
472 }
473
474 columns.extend(source_columns(source, mapping_schema));
475 }
476
477 columns
478}
479
480fn ordered_source_names_for_query(expression: &Expression) -> Vec<String> {
481 match expression {
482 Expression::Select(select) => ordered_source_names_for_select(select),
483 Expression::Union(union) => ordered_source_names_for_query(&union.left),
484 Expression::Intersect(intersect) => ordered_source_names_for_query(&intersect.left),
485 Expression::Except(except) => ordered_source_names_for_query(&except.left),
486 Expression::Subquery(subquery) => ordered_source_names_for_query(&subquery.this),
487 _ => Vec::new(),
488 }
489}
490
491fn ordered_source_names_for_select(select: &crate::expressions::Select) -> Vec<String> {
492 let mut sources = Vec::new();
493
494 if let Some(from) = &select.from {
495 for expression in &from.expressions {
496 if let Some(source_name) = expression_source_name(expression) {
497 sources.push(source_name);
498 }
499 }
500 }
501
502 for join in &select.joins {
503 if let Some(source_name) = expression_source_name(&join.this) {
504 sources.push(source_name);
505 }
506 }
507
508 sources
509}
510
511fn nullable_source_names(expression: &Expression) -> HashSet<String> {
512 match expression {
513 Expression::Select(select) => nullable_source_names_for_select(select),
514 Expression::Union(union) => nullable_source_names(&union.left),
515 Expression::Intersect(intersect) => nullable_source_names(&intersect.left),
516 Expression::Except(except) => nullable_source_names(&except.left),
517 Expression::Subquery(subquery) => nullable_source_names(&subquery.this),
518 _ => HashSet::new(),
519 }
520}
521
522fn nullable_source_names_for_select(select: &crate::expressions::Select) -> HashSet<String> {
523 let mut nullable = HashSet::new();
524 let mut left_sources = Vec::new();
525
526 if let Some(from) = &select.from {
527 for expression in &from.expressions {
528 if let Some(source_name) = expression_source_name(expression) {
529 left_sources.push(source_name);
530 }
531 }
532 }
533
534 for join in &select.joins {
535 let right_source = expression_source_name(&join.this);
536
537 if join_nullable_left(join.kind) {
538 for source_name in &left_sources {
539 nullable.insert(normalize_lookup_name(source_name));
540 }
541 }
542
543 if join_nullable_right(join.kind) {
544 if let Some(source_name) = &right_source {
545 nullable.insert(normalize_lookup_name(source_name));
546 }
547 }
548
549 if let Some(source_name) = right_source {
550 left_sources.push(source_name);
551 }
552 }
553
554 nullable
555}
556
557fn join_nullable_left(kind: JoinKind) -> bool {
558 matches!(
559 kind,
560 JoinKind::Right
561 | JoinKind::NaturalRight
562 | JoinKind::AsOfRight
563 | JoinKind::Full
564 | JoinKind::NaturalFull
565 | JoinKind::Outer
566 )
567}
568
569fn join_nullable_right(kind: JoinKind) -> bool {
570 matches!(
571 kind,
572 JoinKind::Left
573 | JoinKind::NaturalLeft
574 | JoinKind::AsOfLeft
575 | JoinKind::LeftLateral
576 | JoinKind::OuterApply
577 | JoinKind::LeftArray
578 | JoinKind::Full
579 | JoinKind::NaturalFull
580 | JoinKind::Outer
581 )
582}
583
584fn expression_source_name(expression: &Expression) -> Option<String> {
585 match expression {
586 Expression::Table(table) => table
587 .alias
588 .as_ref()
589 .map(|alias| alias.name.clone())
590 .or_else(|| Some(table.name.name.clone())),
591 Expression::Subquery(subquery) => subquery.alias.as_ref().map(|alias| alias.name.clone()),
592 Expression::Alias(alias) => Some(alias.alias.name.clone()),
593 Expression::Cte(cte) => Some(cte.alias.name.clone()),
594 _ => None,
595 }
596}
597
/// Folds an identifier to ASCII lowercase for case-insensitive lookups.
/// Non-ASCII characters are left unchanged.
fn normalize_lookup_name(name: &str) -> String {
    let mut folded = name.to_owned();
    folded.make_ascii_lowercase();
    folded
}
601
602fn effective_query(expression: Expression) -> Expression {
603 match expression {
604 Expression::Prepare(prepare) => prepare.statement,
605 Expression::Subquery(subquery) if subquery.alias.is_none() => subquery.this,
606 other => other,
607 }
608}
609
610fn ensure_query(expression: &Expression) -> Result<()> {
611 if matches!(
612 expression,
613 Expression::Select(_)
614 | Expression::Union(_)
615 | Expression::Intersect(_)
616 | Expression::Except(_)
617 ) {
618 Ok(())
619 } else {
620 Err(Error::internal(
621 "analyze_query requires a SELECT or set operation query",
622 ))
623 }
624}
625
626fn is_set_operation(expression: &Expression) -> bool {
627 matches!(
628 expression,
629 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
630 )
631}
632
633fn collect_cte_names(expression: &Expression) -> Vec<String> {
634 let mut names = Vec::new();
635 let mut seen = HashSet::new();
636 collect_cte_names_inner(expression, &mut names, &mut seen);
637 names
638}
639
640fn collect_cte_names_inner(
641 expression: &Expression,
642 names: &mut Vec<String>,
643 seen: &mut HashSet<String>,
644) {
645 if let Some(with_clause) = with_clause(expression) {
646 collect_with_names(with_clause, names, seen);
647 }
648
649 match expression {
650 Expression::Union(union) => {
651 collect_cte_names_inner(&union.left, names, seen);
652 collect_cte_names_inner(&union.right, names, seen);
653 }
654 Expression::Intersect(intersect) => {
655 collect_cte_names_inner(&intersect.left, names, seen);
656 collect_cte_names_inner(&intersect.right, names, seen);
657 }
658 Expression::Except(except) => {
659 collect_cte_names_inner(&except.left, names, seen);
660 collect_cte_names_inner(&except.right, names, seen);
661 }
662 Expression::Subquery(subquery) => collect_cte_names_inner(&subquery.this, names, seen),
663 _ => {}
664 }
665}
666
667fn collect_with_names(with_clause: &With, names: &mut Vec<String>, seen: &mut HashSet<String>) {
668 for cte in &with_clause.ctes {
669 if seen.insert(cte.alias.name.clone()) {
670 names.push(cte.alias.name.clone());
671 }
672 collect_cte_names_inner(&cte.this, names, seen);
673 }
674}
675
676fn with_clause(expression: &Expression) -> Option<&With> {
677 match expression {
678 Expression::Select(select) => select.with.as_ref(),
679 Expression::Union(union) => union.with.as_ref(),
680 Expression::Intersect(intersect) => intersect.with.as_ref(),
681 Expression::Except(except) => except.with.as_ref(),
682 _ => None,
683 }
684}
685
686fn projection_facts_for_query(
687 expression: &Expression,
688 scope: &Scope,
689 dialect: DialectType,
690 nullability_context: &NullabilityContext<'_>,
691) -> Vec<ProjectionFact> {
692 let expressions = select_expressions_for_query(expression);
693 let names = get_output_column_names(expression);
694
695 expressions
696 .iter()
697 .enumerate()
698 .map(|(index, projection)| {
699 projection_fact(
700 index,
701 names
702 .get(index)
703 .cloned()
704 .or_else(|| projection_name(projection)),
705 projection,
706 expression,
707 scope,
708 dialect,
709 nullability_context,
710 )
711 })
712 .collect()
713}
714
715fn select_expressions_for_query(expression: &Expression) -> Vec<&Expression> {
716 match expression {
717 Expression::Select(select) => select.expressions.iter().collect(),
718 Expression::Union(union) => select_expressions_for_query(&union.left),
719 Expression::Intersect(intersect) => select_expressions_for_query(&intersect.left),
720 Expression::Except(except) => select_expressions_for_query(&except.left),
721 Expression::Subquery(subquery) => select_expressions_for_query(&subquery.this),
722 _ => Vec::new(),
723 }
724}
725
726fn projection_fact(
727 index: usize,
728 name: Option<String>,
729 projection: &Expression,
730 query: &Expression,
731 scope: &Scope,
732 dialect: DialectType,
733 nullability_context: &NullabilityContext<'_>,
734) -> ProjectionFact {
735 let inner = unwrap_projection_alias(projection);
736 let is_star = projection_is_star(inner);
737 let upstream = lineage_by_index_from_expression(index, query, Some(dialect), false)
738 .map(|node| terminal_references_from_lineage(&node))
739 .ok()
740 .filter(|refs| !refs.is_empty())
741 .unwrap_or_else(|| fallback_column_references(inner, scope));
742
743 ProjectionFact {
744 index,
745 name,
746 is_star,
747 star_table: projection_star_table(inner),
748 transform_kind: transform_kind(inner),
749 cast_type: cast_type(inner, dialect),
750 type_hint: projection
751 .inferred_type()
752 .or_else(|| inner.inferred_type())
753 .and_then(|data_type| render_data_type(data_type, dialect)),
754 nullability: projection_nullability(inner, scope, nullability_context),
755 upstream,
756 }
757}
758
759fn unwrap_projection_alias(expression: &Expression) -> &Expression {
760 match expression {
761 Expression::Alias(alias) => unwrap_projection_alias(&alias.this),
762 Expression::Annotated(annotated) => unwrap_projection_alias(&annotated.this),
763 Expression::Paren(paren) => unwrap_projection_alias(&paren.this),
764 _ => expression,
765 }
766}
767
768fn projection_name(expression: &Expression) -> Option<String> {
769 match expression {
770 Expression::Alias(alias) => Some(alias.alias.name.clone()),
771 Expression::Column(column) => Some(column.name.name.clone()),
772 Expression::Identifier(identifier) => Some(identifier.name.clone()),
773 Expression::Star(_) => Some("*".to_string()),
774 Expression::Annotated(annotated) => projection_name(&annotated.this),
775 _ => None,
776 }
777}
778
779fn projection_is_star(expression: &Expression) -> bool {
780 matches!(expression, Expression::Star(_))
781 || matches!(expression, Expression::Column(column) if column.name.name == "*")
782}
783
784fn projection_star_table(expression: &Expression) -> Option<String> {
785 match expression {
786 Expression::Star(star) => star
787 .table
788 .as_ref()
789 .map(|identifier| identifier.name.clone()),
790 Expression::Column(column) if column.name.name == "*" => column
791 .table
792 .as_ref()
793 .map(|identifier| identifier.name.clone()),
794 _ => None,
795 }
796}
797
798fn transform_kind(expression: &Expression) -> TransformKind {
799 if projection_is_star(expression) {
800 TransformKind::Star
801 } else if is_cast_expression(expression) {
802 TransformKind::Cast
803 } else if contains_aggregate(expression) {
804 TransformKind::Aggregation
805 } else if matches!(
806 expression,
807 Expression::Column(_) | Expression::Identifier(_)
808 ) {
809 TransformKind::Direct
810 } else if is_simple_constant(expression) {
811 TransformKind::Constant
812 } else {
813 TransformKind::Expression
814 }
815}
816
817fn is_cast_expression(expression: &Expression) -> bool {
818 matches!(
819 expression,
820 Expression::Cast(_) | Expression::TryCast(_) | Expression::SafeCast(_)
821 )
822}
823
824fn cast_type(expression: &Expression, dialect: DialectType) -> Option<String> {
825 match expression {
826 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
827 render_data_type(&cast.to, dialect)
828 }
829 _ => None,
830 }
831}
832
833fn render_data_type(data_type: &DataType, dialect: DialectType) -> Option<String> {
834 Dialect::get(dialect)
835 .generate(&Expression::DataType(data_type.clone()))
836 .ok()
837}
838
839fn is_simple_constant(expression: &Expression) -> bool {
840 match expression {
841 Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_) => true,
842 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
843 is_simple_constant(&cast.this)
844 }
845 Expression::Neg(unary) | Expression::BitwiseNot(unary) => is_simple_constant(&unary.this),
846 _ => false,
847 }
848}
849
850fn projection_nullability(
851 expression: &Expression,
852 scope: &Scope,
853 context: &NullabilityContext<'_>,
854) -> ProjectionNullability {
855 match expression {
856 Expression::Alias(alias) => projection_nullability(&alias.this, scope, context),
857 Expression::Annotated(annotated) => projection_nullability(&annotated.this, scope, context),
858 Expression::Paren(paren) => projection_nullability(&paren.this, scope, context),
859 Expression::Literal(_) | Expression::Boolean(_) => ProjectionNullability::NonNull,
860 Expression::Null(_) => ProjectionNullability::Nullable,
861 Expression::Count(_) | Expression::CountIf(_) => ProjectionNullability::NonNull,
862 Expression::Cast(cast) => projection_nullability(&cast.this, scope, context),
863 Expression::TryCast(_) | Expression::SafeCast(_) => ProjectionNullability::Unknown,
864 Expression::Column(column) => column_nullability(
865 &column.name.name,
866 column.table.as_ref().map(|table| table.name.as_str()),
867 scope,
868 context,
869 ),
870 Expression::Identifier(identifier) => {
871 column_nullability(&identifier.name, None, scope, context)
872 }
873 Expression::Coalesce(func) => coalesce_nullability(&func.expressions, scope, context),
874 _ => ProjectionNullability::Unknown,
875 }
876}
877
878fn column_nullability(
879 column_name: &str,
880 source_name: Option<&str>,
881 scope: &Scope,
882 context: &NullabilityContext<'_>,
883) -> ProjectionNullability {
884 let resolved_source_name = source_name
885 .map(str::to_string)
886 .or_else(|| single_scope_source_name(scope));
887
888 if let Some(source_name) = &resolved_source_name {
889 if context
890 .nullable_sources
891 .contains(&normalize_lookup_name(source_name))
892 {
893 return ProjectionNullability::Nullable;
894 }
895 }
896
897 let Some(schema) = context.schema else {
898 return ProjectionNullability::Unknown;
899 };
900
901 let table_name = resolved_source_name
902 .as_ref()
903 .and_then(|name| scope.sources.get(name).and_then(source_table_name))
904 .or(resolved_source_name);
905
906 let Some(table_name) = table_name else {
907 return ProjectionNullability::Unknown;
908 };
909
910 match schema.column(&table_name, column_name) {
911 Some(info) if info.primary_key || info.nullable == Some(false) => {
912 ProjectionNullability::NonNull
913 }
914 Some(info) if info.nullable == Some(true) => ProjectionNullability::Nullable,
915 Some(_) | None => ProjectionNullability::Unknown,
916 }
917}
918
919fn single_scope_source_name(scope: &Scope) -> Option<String> {
920 if scope.sources.len() == 1 {
921 scope.sources.keys().next().cloned()
922 } else {
923 None
924 }
925}
926
927fn coalesce_nullability(
928 expressions: &[Expression],
929 scope: &Scope,
930 context: &NullabilityContext<'_>,
931) -> ProjectionNullability {
932 if expressions.is_empty() {
933 return ProjectionNullability::Unknown;
934 }
935
936 let mut all_nullable = true;
937
938 for expression in expressions {
939 match projection_nullability(unwrap_projection_alias(expression), scope, context) {
940 ProjectionNullability::NonNull => return ProjectionNullability::NonNull,
941 ProjectionNullability::Nullable => {}
942 ProjectionNullability::Unknown => all_nullable = false,
943 }
944 }
945
946 if all_nullable {
947 ProjectionNullability::Nullable
948 } else {
949 ProjectionNullability::Unknown
950 }
951}
952
953fn terminal_references_from_lineage(node: &LineageNode) -> Vec<ColumnReferenceFact> {
954 let mut refs = Vec::new();
955 collect_terminal_references(node, &mut refs);
956 dedupe_column_refs(refs)
957}
958
959fn collect_terminal_references(node: &LineageNode, refs: &mut Vec<ColumnReferenceFact>) {
960 if node.downstream.is_empty() {
961 if let Some(reference) = column_reference_from_lineage_node(node) {
962 refs.push(reference);
963 }
964 return;
965 }
966
967 for child in &node.downstream {
968 collect_terminal_references(child, refs);
969 }
970}
971
972fn column_reference_from_lineage_node(node: &LineageNode) -> Option<ColumnReferenceFact> {
973 match &node.expression {
974 Expression::Column(column) => {
975 let source_name = non_empty_string(node.source_name.clone());
976 let table =
977 lineage_node_table(node).or_else(|| column.table.as_ref().map(|t| t.name.clone()));
978 let confidence = if node.source_kind == SourceKind::Unknown && source_name.is_none() {
979 ReferenceConfidence::Unknown
980 } else {
981 ReferenceConfidence::Resolved
982 };
983 Some(ColumnReferenceFact {
984 source_name,
985 source_alias: node.source_alias.clone(),
986 source_kind: node.source_kind,
987 table,
988 column: column.name.name.clone(),
989 unqualified: column.table.is_none(),
990 confidence,
991 })
992 }
993 Expression::Star(_) => Some(ColumnReferenceFact {
994 source_name: non_empty_string(node.source_name.clone()),
995 source_alias: node.source_alias.clone(),
996 source_kind: node.source_kind,
997 table: lineage_node_table(node),
998 column: "*".to_string(),
999 unqualified: true,
1000 confidence: if node.source_kind == SourceKind::Unknown {
1001 ReferenceConfidence::Unknown
1002 } else {
1003 ReferenceConfidence::Resolved
1004 },
1005 }),
1006 _ => None,
1007 }
1008}
1009
1010fn lineage_node_table(node: &LineageNode) -> Option<String> {
1011 match &node.source {
1012 Expression::Table(table) => Some(table_name(table)),
1013 _ => None,
1014 }
1015}
1016
1017fn fallback_column_references(expression: &Expression, scope: &Scope) -> Vec<ColumnReferenceFact> {
1018 let mut refs = Vec::new();
1019 let source_count = scope.sources.len();
1020 let single_source = if source_count == 1 {
1021 scope.sources.iter().next()
1022 } else {
1023 None
1024 };
1025
1026 for column_expr in expression.find_all(|candidate| matches!(candidate, Expression::Column(_))) {
1027 if let Expression::Column(column) = column_expr {
1028 if column.name.name == "*" {
1029 continue;
1030 }
1031 let source = column
1032 .table
1033 .as_ref()
1034 .and_then(|table| scope.sources.get(&table.name));
1035 let (source_name, source_alias, source_kind, table, confidence) =
1036 if let Some(table_identifier) = &column.table {
1037 if let Some(source) = source {
1038 (
1039 Some(table_identifier.name.clone()),
1040 source.alias.clone(),
1041 source.kind,
1042 source_table_name(source)
1043 .or_else(|| Some(table_identifier.name.clone())),
1044 ReferenceConfidence::Resolved,
1045 )
1046 } else {
1047 (
1048 Some(table_identifier.name.clone()),
1049 None,
1050 SourceKind::Unknown,
1051 Some(table_identifier.name.clone()),
1052 ReferenceConfidence::Unknown,
1053 )
1054 }
1055 } else if let Some((name, source)) = single_source {
1056 (
1057 Some(name.clone()),
1058 source.alias.clone(),
1059 source.kind,
1060 source_table_name(source).or_else(|| Some(name.clone())),
1061 ReferenceConfidence::Resolved,
1062 )
1063 } else if source_count > 1 {
1064 (
1065 None,
1066 None,
1067 SourceKind::Unknown,
1068 None,
1069 ReferenceConfidence::Ambiguous,
1070 )
1071 } else {
1072 (
1073 None,
1074 None,
1075 SourceKind::Unknown,
1076 None,
1077 ReferenceConfidence::Unknown,
1078 )
1079 };
1080
1081 refs.push(ColumnReferenceFact {
1082 source_name,
1083 source_alias,
1084 source_kind,
1085 table,
1086 column: column.name.name.clone(),
1087 unqualified: column.table.is_none(),
1088 confidence,
1089 });
1090 }
1091 }
1092
1093 dedupe_column_refs(refs)
1094}
1095
1096fn dedupe_column_refs(refs: Vec<ColumnReferenceFact>) -> Vec<ColumnReferenceFact> {
1097 let mut seen = HashSet::new();
1098 let mut deduped = Vec::new();
1099
1100 for reference in refs {
1101 let key = (
1102 reference.source_name.clone(),
1103 reference.source_alias.clone(),
1104 reference.table.clone(),
1105 reference.column.clone(),
1106 format!("{:?}", reference.source_kind),
1107 reference.unqualified,
1108 format!("{:?}", reference.confidence),
1109 );
1110 if seen.insert(key) {
1111 deduped.push(reference);
1112 }
1113 }
1114
1115 deduped
1116}
1117
1118fn relation_facts(
1119 scope: &Scope,
1120 mapping_schema: Option<&crate::schema::MappingSchema>,
1121) -> Vec<RelationFact> {
1122 let mut relations = Vec::new();
1123 let mut seen = HashSet::new();
1124 collect_relation_facts(scope, mapping_schema, &mut seen, &mut relations);
1125
1126 relations.sort_by(|left, right| {
1127 left.name
1128 .cmp(&right.name)
1129 .then_with(|| left.alias.cmp(&right.alias))
1130 });
1131 relations
1132}
1133
1134fn collect_relation_facts(
1135 scope: &Scope,
1136 mapping_schema: Option<&crate::schema::MappingSchema>,
1137 seen: &mut HashSet<String>,
1138 relations: &mut Vec<RelationFact>,
1139) {
1140 for relation in scope
1141 .sources
1142 .iter()
1143 .map(|(source_name, source)| RelationFact {
1144 name: source
1145 .lineage_name
1146 .clone()
1147 .or_else(|| source_table_name(source))
1148 .unwrap_or_else(|| source_name.clone()),
1149 alias: source.alias.clone().or_else(|| source_alias(source)),
1150 kind: source.kind,
1151 columns: source_columns(source, mapping_schema),
1152 })
1153 {
1154 let key = format!("{:?}|{}|{:?}", relation.kind, relation.name, relation.alias);
1155 if seen.insert(key) {
1156 relations.push(relation);
1157 }
1158 }
1159
1160 for branch_scope in &scope.union_scopes {
1161 collect_relation_facts(branch_scope, mapping_schema, seen, relations);
1162 }
1163}
1164
1165fn base_table_facts(
1166 scope: &Scope,
1167 mapping_schema: Option<&crate::schema::MappingSchema>,
1168) -> Vec<RelationFact> {
1169 let mut relations = Vec::new();
1170 let mut seen = HashSet::new();
1171
1172 collect_base_table_facts(scope, mapping_schema, &mut seen, &mut relations);
1173
1174 relations.sort_by(|left, right| left.name.cmp(&right.name));
1175 relations
1176}
1177
1178fn collect_base_table_facts(
1179 scope: &Scope,
1180 mapping_schema: Option<&crate::schema::MappingSchema>,
1181 seen: &mut HashSet<String>,
1182 relations: &mut Vec<RelationFact>,
1183) {
1184 for source in scope.sources.values() {
1185 if source.kind != SourceKind::Table {
1186 continue;
1187 }
1188
1189 let Some(table_name) = source_table_name(source) else {
1190 continue;
1191 };
1192
1193 if seen.insert(table_name.clone()) {
1194 relations.push(RelationFact {
1195 name: table_name,
1196 alias: source.alias.clone().or_else(|| source_alias(source)),
1197 kind: SourceKind::Table,
1198 columns: source_columns(source, mapping_schema),
1199 });
1200 }
1201 }
1202
1203 for child_scope in scope
1204 .cte_scopes
1205 .iter()
1206 .chain(scope.union_scopes.iter())
1207 .chain(scope.table_scopes.iter())
1208 .chain(scope.derived_table_scopes.iter())
1209 .chain(scope.subquery_scopes.iter())
1210 {
1211 collect_base_table_facts(child_scope, mapping_schema, seen, relations);
1212 }
1213}
1214
1215fn source_columns(
1216 source: &SourceInfo,
1217 mapping_schema: Option<&crate::schema::MappingSchema>,
1218) -> Vec<String> {
1219 match &source.expression {
1220 Expression::Table(table) => mapping_schema
1221 .and_then(|schema| schema.column_names(&table_name(table)).ok())
1222 .unwrap_or_default(),
1223 Expression::Select(_)
1224 | Expression::Union(_)
1225 | Expression::Intersect(_)
1226 | Expression::Except(_) => get_output_column_names(&source.expression),
1227 Expression::Subquery(subquery) => get_output_column_names(&subquery.this),
1228 Expression::Cte(cte) if !cte.columns.is_empty() => cte
1229 .columns
1230 .iter()
1231 .map(|column| column.name.clone())
1232 .collect(),
1233 Expression::Cte(cte) => get_output_column_names(&cte.this),
1234 _ => Vec::new(),
1235 }
1236}
1237
1238fn source_table_name(source: &SourceInfo) -> Option<String> {
1239 match &source.expression {
1240 Expression::Table(table) => Some(table_name(table)),
1241 _ => None,
1242 }
1243}
1244
1245fn source_alias(source: &SourceInfo) -> Option<String> {
1246 match &source.expression {
1247 Expression::Table(table) => table.alias.as_ref().map(|alias| alias.name.clone()),
1248 Expression::Subquery(subquery) => subquery.alias.as_ref().map(|alias| alias.name.clone()),
1249 _ => None,
1250 }
1251}
1252
1253fn table_name(table: &TableRef) -> String {
1254 let mut parts = Vec::new();
1255 if let Some(catalog) = &table.catalog {
1256 parts.push(catalog.name.clone());
1257 }
1258 if let Some(schema) = &table.schema {
1259 parts.push(schema.name.clone());
1260 }
1261 parts.push(table.name.name.clone());
1262 parts.join(".")
1263}
1264
1265fn set_operation_facts(
1266 expression: &Expression,
1267 scope: &Scope,
1268 dialect: DialectType,
1269) -> Vec<SetOperationFact> {
1270 let mut facts = Vec::new();
1271 collect_set_operation_facts(expression, scope, dialect, &mut facts);
1272 facts
1273}
1274
1275fn collect_set_operation_facts(
1276 expression: &Expression,
1277 scope: &Scope,
1278 dialect: DialectType,
1279 facts: &mut Vec<SetOperationFact>,
1280) {
1281 match expression {
1282 Expression::Union(union) => {
1283 facts.push(SetOperationFact {
1284 kind: "union".to_string(),
1285 all: union.all,
1286 distinct: union.distinct,
1287 output_columns: get_output_column_names(expression),
1288 branches: set_operation_branches(&union.left, &union.right, scope, dialect),
1289 });
1290 collect_set_operation_facts(&union.left, scope, dialect, facts);
1291 collect_set_operation_facts(&union.right, scope, dialect, facts);
1292 }
1293 Expression::Intersect(intersect) => {
1294 facts.push(SetOperationFact {
1295 kind: "intersect".to_string(),
1296 all: intersect.all,
1297 distinct: intersect.distinct,
1298 output_columns: get_output_column_names(expression),
1299 branches: set_operation_branches(&intersect.left, &intersect.right, scope, dialect),
1300 });
1301 collect_set_operation_facts(&intersect.left, scope, dialect, facts);
1302 collect_set_operation_facts(&intersect.right, scope, dialect, facts);
1303 }
1304 Expression::Except(except) => {
1305 facts.push(SetOperationFact {
1306 kind: "except".to_string(),
1307 all: except.all,
1308 distinct: except.distinct,
1309 output_columns: get_output_column_names(expression),
1310 branches: set_operation_branches(&except.left, &except.right, scope, dialect),
1311 });
1312 collect_set_operation_facts(&except.left, scope, dialect, facts);
1313 collect_set_operation_facts(&except.right, scope, dialect, facts);
1314 }
1315 Expression::Subquery(subquery) => {
1316 collect_set_operation_facts(&subquery.this, scope, dialect, facts);
1317 }
1318 _ => {}
1319 }
1320}
1321
1322fn set_operation_branches(
1323 left: &Expression,
1324 right: &Expression,
1325 scope: &Scope,
1326 dialect: DialectType,
1327) -> Vec<SetOperationBranchFact> {
1328 vec![
1329 SetOperationBranchFact {
1330 index: 0,
1331 projections: projection_facts_for_branch(left, scope, dialect),
1332 },
1333 SetOperationBranchFact {
1334 index: 1,
1335 projections: projection_facts_for_branch(right, scope, dialect),
1336 },
1337 ]
1338}
1339
1340fn projection_facts_for_branch(
1341 expression: &Expression,
1342 root_scope: &Scope,
1343 dialect: DialectType,
1344) -> Vec<ProjectionFact> {
1345 let branch_scope = build_scope(expression);
1346 let scope = if branch_scope.sources.is_empty() {
1347 root_scope
1348 } else {
1349 &branch_scope
1350 };
1351 let nullability_context = NullabilityContext {
1352 schema: None,
1353 nullable_sources: nullable_source_names(expression),
1354 };
1355 projection_facts_for_query(expression, scope, dialect, &nullability_context)
1356}
1357
/// Wraps `value` in `Some` unless it is the empty string. Whitespace-only
/// strings are not empty and are kept.
fn non_empty_string(value: String) -> Option<String> {
    (!value.is_empty()).then_some(value)
}