Skip to main content

polyglot_sql/
query_analysis.rs

1//! Compact query analysis facts.
2//!
3//! This module intentionally builds on the existing parser, scope builder, type
4//! annotator, and lineage implementation. It is a convenience API: callers that
//! need the full AST or full lineage graph should continue using those
//! lower-level APIs directly.
7
8use crate::ast_transforms::get_output_column_names;
9use crate::dialects::{Dialect, DialectType};
10use crate::expressions::{DataType, Expression, JoinKind, TableRef, With};
11use crate::lineage::{lineage_by_index_from_expression, LineageNode};
12use crate::optimizer::annotate_types::annotate_types;
13use crate::optimizer::qualify_columns::{qualify_columns, QualifyColumnsOptions};
14use crate::schema::{MappingSchema, Schema};
15use crate::scope::{build_scope, Scope, SourceInfo, SourceKind};
16use crate::traversal::{contains_aggregate, ExpressionWalk};
17use crate::validation::{mapping_schema_from_validation_schema, ValidationSchema};
18use crate::{parse_data_type, parse_one, Error, Result};
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21
/// Options for [`analyze_query`].
///
/// The container-level `#[serde(default)]` together with the `Default` derive
/// means any field missing from the serialized form takes its default value.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase", default)]
pub struct AnalyzeQueryOptions {
    /// SQL dialect used for parsing and dialect-aware rendering.
    pub dialect: DialectType,
    /// Optional validation schema used for qualification and type annotation.
    /// Without it, qualification is skipped and types are annotated without a
    /// schema.
    pub schema: Option<ValidationSchema>,
}
31
/// Compact facts about a query's output shape and data dependencies.
///
/// Produced by [`analyze_query`]; field names serialize in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct QueryAnalysis {
    /// Whether the analyzed root is a plain SELECT or a set operation.
    pub shape: QueryShape,
    /// CTE names in first-seen order, collected from the root query, from
    /// set-operation branches, and from CTE bodies; each name appears once.
    pub ctes: Vec<String>,
    /// Facts about the CTEs of the root `WITH` clause, taken from the query
    /// before qualification rewrites it.
    pub cte_facts: Vec<CteFact>,
    /// One fact per output projection; for a set operation these come from the
    /// left-most branch.
    pub projections: Vec<ProjectionFact>,
    /// Relations visible in the root scope.
    pub relations: Vec<RelationFact>,
    /// NOTE(review): built by `base_table_facts`, which is outside this
    /// excerpt; presumably the subset of relations that are real tables.
    pub base_tables: Vec<RelationFact>,
    /// Star projections found in the query before qualification rewrites it.
    pub star_projections: Vec<StarProjectionFact>,
    /// NOTE(review): built by `set_operation_facts`, which is outside this
    /// excerpt.
    pub set_operations: Vec<SetOperationFact>,
}
45
/// Top-level query shape.
///
/// Serialized in snake_case (`select`, `set_operation`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QueryShape {
    /// The root expression is a single SELECT.
    Select,
    /// The root expression is a UNION, INTERSECT, or EXCEPT.
    SetOperation,
}
53
/// Compact fact about one output projection.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProjectionFact {
    /// Zero-based position of the projection in the select list.
    pub index: usize,
    /// Output column name: the entry reported by `get_output_column_names`
    /// for this index, otherwise a name derived from the projection itself.
    pub name: Option<String>,
    /// `true` when the projection (ignoring aliases and parentheses) is `*`
    /// or `table.*`.
    pub is_star: bool,
    /// Table qualifier of a star projection, when one is present.
    pub star_table: Option<String>,
    /// Coarse classification of what the projection does.
    pub transform_kind: TransformKind,
    /// Target type rendered in the requested dialect when the projection is a
    /// top-level `CAST`, `TRY_CAST`, or `SAFE_CAST`.
    pub cast_type: Option<String>,
    /// Inferred type rendered in the requested dialect, when type annotation
    /// produced one.
    pub type_hint: Option<String>,
    /// Conservative nullability classification.
    pub nullability: ProjectionNullability,
    /// Upstream column references: the terminal nodes of the lineage graph,
    /// or a fallback derived from the projection when lineage yields nothing.
    pub upstream: Vec<ColumnReferenceFact>,
}
68
/// Compact fact about one top-level CTE definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CteFact {
    /// The CTE's alias name.
    pub name: String,
    /// Names from the column list attached to the CTE definition.
    pub columns: Vec<String>,
    /// The CTE body rendered back to SQL in the requested dialect.
    pub body_sql: String,
    /// Output column names of the CTE body as reported by
    /// `get_output_column_names`.
    pub output_columns: Vec<String>,
}
78
/// Compact fact about one original star projection.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StarProjectionFact {
    /// Zero-based position of the star in the select list.
    pub index: usize,
    /// Table qualifier of the star (`t` in `t.*`), or `None` for a bare `*`.
    pub table: Option<String>,
    /// Columns contributed by the sources the star matches, concatenated in
    /// source order.
    pub expanded_columns: Vec<String>,
}
87
/// Compact fact about an upstream column reference.
///
/// NOTE(review): values are filled in by helpers outside this excerpt
/// (`column_reference_from_lineage_node`, `fallback_column_references`); the
/// field notes below are inferred from names and types and need confirming.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ColumnReferenceFact {
    /// Presumably the name of the source the column is read from, if known.
    pub source_name: Option<String>,
    /// Presumably the alias under which that source appears in the query.
    pub source_alias: Option<String>,
    /// Kind of the source, as classified by the scope builder.
    pub source_kind: SourceKind,
    /// Presumably the underlying table name, if known.
    pub table: Option<String>,
    /// Name of the referenced column.
    pub column: String,
    /// Presumably `true` when the reference carried no table qualifier.
    pub unqualified: bool,
    /// How reliably the reference was resolved.
    pub confidence: ReferenceConfidence,
}
100
/// Compact fact about a relation visible in the root scope.
///
/// NOTE(review): built by `relation_facts` / `base_table_facts`, which are
/// outside this excerpt; the field notes are inferred from names and types.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RelationFact {
    /// Name of the relation.
    pub name: String,
    /// Alias used for the relation in the query, if any.
    pub alias: Option<String>,
    /// Kind of source, as classified by the scope builder.
    pub kind: SourceKind,
    /// Column names known for the relation — presumably empty when no schema
    /// information is available; confirm against `relation_facts`.
    pub columns: Vec<String>,
}
110
/// Compact fact about a set operation.
///
/// NOTE(review): built by `set_operation_facts`, which is outside this
/// excerpt; the field notes are inferred from names and types.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SetOperationFact {
    /// Operation kind as a string — presumably union / intersect / except;
    /// confirm the exact spelling against `set_operation_facts`.
    pub kind: String,
    /// Presumably whether the `ALL` modifier was given.
    pub all: bool,
    /// Presumably whether the `DISTINCT` modifier was given.
    pub distinct: bool,
    /// Output column names of the set operation.
    pub output_columns: Vec<String>,
    /// Facts for the operation's immediate branches.
    pub branches: Vec<SetOperationBranchFact>,
}
121
/// Compact facts for one immediate set-operation branch.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SetOperationBranchFact {
    /// Position of the branch — presumably zero-based with the left branch
    /// first; confirm against `set_operation_facts`.
    pub index: usize,
    /// Projection facts computed for this branch.
    pub projections: Vec<ProjectionFact>,
}
129
/// High-level kind of transformation performed by a projection.
///
/// Classification is first-match in this order: `Star`, `Cast`,
/// `Aggregation`, `Direct`, `Constant`, `Expression`. A cast wrapped around
/// an aggregate is therefore reported as `Cast`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TransformKind {
    /// A bare column or identifier.
    Direct,
    /// A top-level `CAST`, `TRY_CAST`, or `SAFE_CAST`.
    Cast,
    /// An expression that contains an aggregate function.
    Aggregation,
    /// A literal, boolean, or NULL, optionally under a cast, negation, or
    /// bitwise NOT.
    Constant,
    /// Anything not covered by the other variants.
    Expression,
    /// `*` or `table.*`.
    Star,
}
141
/// Confidence level for a compact upstream column reference.
///
/// NOTE(review): the variants are assigned outside this excerpt; the notes
/// below are inferred from the names and should be confirmed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReferenceConfidence {
    /// Presumably: the reference was tied to exactly one source.
    Resolved,
    /// Presumably: more than one source could supply the column.
    Ambiguous,
    /// Presumably: no source could be determined.
    Unknown,
}
150
/// Conservative nullability classification for one output projection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProjectionNullability {
    /// Classified as never NULL: literals, booleans, `COUNT`/`COUNT_IF`,
    /// primary-key or non-nullable schema columns, and `COALESCE` with at
    /// least one non-null argument.
    NonNull,
    /// Can be NULL: a NULL literal, a column declared nullable, or a column of
    /// a source made nullable by an outer join.
    Nullable,
    /// Nothing could be established either way.
    Unknown,
}
159
160/// Analyze a single SELECT or set-operation query.
161pub fn analyze_query(sql: &str, options: AnalyzeQueryOptions) -> Result<QueryAnalysis> {
162    let mut expression = parse_one(sql, options.dialect)?;
163    expression = effective_query(expression);
164    ensure_query(&expression)?;
165    let original_expression = expression.clone();
166
167    let mapping_schema = options
168        .schema
169        .as_ref()
170        .map(|schema| analysis_mapping_schema(schema, options.dialect));
171    let schema_info = options.schema.as_ref().map(AnalysisSchemaInfo::from_schema);
172    let cte_facts = top_level_cte_facts(&original_expression, options.dialect)?;
173    let star_projections = star_projection_facts(&original_expression, mapping_schema.as_ref());
174
175    if let Some(schema) = mapping_schema.as_ref() {
176        let qualify_options = QualifyColumnsOptions::new().with_dialect(options.dialect);
177        expression = qualify_columns(expression, schema, &qualify_options)
178            .map_err(|e| Error::internal(format!("query analysis qualification failed: {e}")))?;
179    }
180
181    let annotation_schema = mapping_schema.as_ref().map(|schema| {
182        let mut alias_schema = schema.clone();
183        add_scope_aliases_to_schema(
184            &build_scope(&expression),
185            schema,
186            &mut alias_schema,
187            options.dialect,
188        );
189        alias_schema
190    });
191
192    annotate_types(
193        &mut expression,
194        annotation_schema
195            .as_ref()
196            .map(|schema| schema as &dyn Schema),
197        Some(options.dialect),
198    );
199    crate::lineage::expand_cte_stars(
200        &mut expression,
201        annotation_schema
202            .as_ref()
203            .or(mapping_schema.as_ref())
204            .map(|schema| schema as &dyn Schema),
205    );
206
207    let scope = build_scope(&expression);
208    let nullability_context = NullabilityContext {
209        schema: schema_info.as_ref(),
210        nullable_sources: nullable_source_names(&expression),
211    };
212    let shape = if is_set_operation(&expression) {
213        QueryShape::SetOperation
214    } else {
215        QueryShape::Select
216    };
217
218    Ok(QueryAnalysis {
219        shape,
220        ctes: collect_cte_names(&expression),
221        cte_facts,
222        projections: projection_facts_for_query(
223            &expression,
224            &scope,
225            options.dialect,
226            &nullability_context,
227        ),
228        relations: relation_facts(&scope, mapping_schema.as_ref()),
229        base_tables: base_table_facts(&scope, mapping_schema.as_ref()),
230        star_projections,
231        set_operations: set_operation_facts(&expression, &scope, options.dialect),
232    })
233}
234
235fn analysis_mapping_schema(schema: &ValidationSchema, dialect: DialectType) -> MappingSchema {
236    let broad_schema = mapping_schema_from_validation_schema(schema);
237    let mut mapping_schema = MappingSchema::with_dialect(dialect);
238
239    for table in &schema.tables {
240        let table_names = validation_table_names(table);
241        if table_names.is_empty() {
242            continue;
243        }
244
245        let fallback_table = table_names[0].as_str();
246        let columns: Vec<(String, DataType)> = table
247            .columns
248            .iter()
249            .map(|column| {
250                let data_type = parse_analysis_data_type(&column.data_type, dialect)
251                    .unwrap_or_else(|| {
252                        broad_schema
253                            .get_column_type(fallback_table, &column.name)
254                            .unwrap_or(DataType::Unknown)
255                    });
256                (column.name.to_ascii_lowercase(), data_type)
257            })
258            .collect();
259
260        for table_name in table_names {
261            let _ = mapping_schema.add_table(&table_name, &columns, Some(dialect));
262        }
263    }
264
265    mapping_schema
266}
267
268fn validation_table_names(table: &crate::validation::SchemaTable) -> Vec<String> {
269    let mut names = Vec::new();
270
271    names.push(table.name.to_ascii_lowercase());
272    if let Some(schema_name) = &table.schema {
273        names.push(format!(
274            "{}.{}",
275            schema_name.to_ascii_lowercase(),
276            table.name.to_ascii_lowercase()
277        ));
278    }
279    for alias in &table.aliases {
280        names.push(alias.to_ascii_lowercase());
281    }
282
283    names.sort();
284    names.dedup();
285    names
286}
287
288fn parse_analysis_data_type(data_type: &str, dialect: DialectType) -> Option<DataType> {
289    let trimmed = data_type.trim();
290    if trimmed.is_empty() {
291        return None;
292    }
293    parse_data_type(trimmed, dialect).ok()
294}
295
296fn add_scope_aliases_to_schema(
297    scope: &Scope,
298    source_schema: &MappingSchema,
299    target_schema: &mut MappingSchema,
300    dialect: DialectType,
301) {
302    for child_scope in scope.traverse() {
303        for (source_name, source) in &child_scope.sources {
304            if source.kind != SourceKind::Table {
305                continue;
306            }
307            if let Some(table_name) = source_table_name(source) {
308                if source_name == &table_name {
309                    continue;
310                }
311                if let Ok(column_names) = source_schema.column_names(&table_name) {
312                    let columns: Vec<(String, DataType)> = column_names
313                        .iter()
314                        .map(|column| {
315                            (
316                                column.clone(),
317                                source_schema
318                                    .get_column_type(&table_name, column)
319                                    .unwrap_or(DataType::Unknown),
320                            )
321                        })
322                        .collect();
323                    let _ = target_schema.add_table(source_name, &columns, Some(dialect));
324                }
325            }
326        }
327    }
328}
329
/// Nullability-relevant facts about one schema column.
#[derive(Debug, Clone)]
struct AnalysisColumnInfo {
    /// Nullability copied from the validation schema column; `None` is
    /// treated as unknown by `column_nullability`.
    nullable: Option<bool>,
    /// `true` when the column is flagged as a primary key or is listed in its
    /// table's primary-key list.
    primary_key: bool,
}
335
/// Column facts indexed for nullability lookups.
#[derive(Debug, Clone)]
struct AnalysisSchemaInfo {
    /// Keyed by `(table name, column name)`, both ASCII lower-cased. A table
    /// appears once per lookup name (bare name, schema-qualified name, alias).
    columns: HashMap<(String, String), AnalysisColumnInfo>,
}
340
341impl AnalysisSchemaInfo {
342    fn from_schema(schema: &ValidationSchema) -> Self {
343        let mut columns = HashMap::new();
344
345        for table in &schema.tables {
346            let table_names = validation_table_names(table);
347            let primary_keys: HashSet<String> = table
348                .primary_key
349                .iter()
350                .map(|column| column.to_ascii_lowercase())
351                .collect();
352
353            for column in &table.columns {
354                let info = AnalysisColumnInfo {
355                    nullable: column.nullable,
356                    primary_key: column.primary_key
357                        || primary_keys.contains(&column.name.to_ascii_lowercase()),
358                };
359
360                for table_name in &table_names {
361                    columns.insert(
362                        (
363                            normalize_lookup_name(table_name),
364                            normalize_lookup_name(&column.name),
365                        ),
366                        info.clone(),
367                    );
368                }
369            }
370        }
371
372        Self { columns }
373    }
374
375    fn column(&self, table: &str, column: &str) -> Option<&AnalysisColumnInfo> {
376        self.columns
377            .get(&(normalize_lookup_name(table), normalize_lookup_name(column)))
378    }
379}
380
/// Inputs shared by the nullability helpers during one analysis run.
struct NullabilityContext<'a> {
    /// Per-column schema facts; `None` when the caller supplied no schema.
    schema: Option<&'a AnalysisSchemaInfo>,
    /// ASCII lower-cased names of the sources an outer join makes nullable.
    nullable_sources: HashSet<String>,
}
385
386fn top_level_cte_facts(expression: &Expression, dialect: DialectType) -> Result<Vec<CteFact>> {
387    let Some(with_clause) = with_clause(expression) else {
388        return Ok(Vec::new());
389    };
390
391    with_clause
392        .ctes
393        .iter()
394        .map(|cte| {
395            Ok(CteFact {
396                name: cte.alias.name.clone(),
397                columns: cte
398                    .columns
399                    .iter()
400                    .map(|column| column.name.clone())
401                    .collect(),
402                body_sql: Dialect::get(dialect).generate(&cte.this)?,
403                output_columns: get_output_column_names(&cte.this),
404            })
405        })
406        .collect()
407}
408
409fn star_projection_facts(
410    expression: &Expression,
411    mapping_schema: Option<&MappingSchema>,
412) -> Vec<StarProjectionFact> {
413    let scope = build_scope(expression);
414    let ordered_sources = ordered_source_names_for_query(expression);
415
416    select_expressions_for_query(expression)
417        .iter()
418        .enumerate()
419        .filter_map(|(index, projection)| {
420            let inner = unwrap_projection_alias(projection);
421            if !projection_is_star(inner) {
422                return None;
423            }
424
425            let table = projection_star_table(inner);
426            let expanded_columns =
427                expanded_star_columns(table.as_deref(), &scope, &ordered_sources, mapping_schema);
428
429            Some(StarProjectionFact {
430                index,
431                table,
432                expanded_columns,
433            })
434        })
435        .collect()
436}
437
438fn expanded_star_columns(
439    star_table: Option<&str>,
440    scope: &Scope,
441    ordered_sources: &[String],
442    mapping_schema: Option<&MappingSchema>,
443) -> Vec<String> {
444    let mut columns = Vec::new();
445    let mut source_names: Vec<String> = if ordered_sources.is_empty() {
446        let mut names: Vec<_> = scope.sources.keys().cloned().collect();
447        names.sort();
448        names
449    } else {
450        ordered_sources.to_vec()
451    };
452
453    source_names.dedup();
454
455    for source_name in source_names {
456        let Some(source) = scope.sources.get(&source_name) else {
457            continue;
458        };
459
460        if let Some(star_table) = star_table {
461            let matches = source_name.eq_ignore_ascii_case(star_table)
462                || source
463                    .alias
464                    .as_deref()
465                    .is_some_and(|alias| alias.eq_ignore_ascii_case(star_table))
466                || source_table_name(source)
467                    .is_some_and(|table| table.eq_ignore_ascii_case(star_table));
468
469            if !matches {
470                continue;
471            }
472        }
473
474        columns.extend(source_columns(source, mapping_schema));
475    }
476
477    columns
478}
479
480fn ordered_source_names_for_query(expression: &Expression) -> Vec<String> {
481    match expression {
482        Expression::Select(select) => ordered_source_names_for_select(select),
483        Expression::Union(union) => ordered_source_names_for_query(&union.left),
484        Expression::Intersect(intersect) => ordered_source_names_for_query(&intersect.left),
485        Expression::Except(except) => ordered_source_names_for_query(&except.left),
486        Expression::Subquery(subquery) => ordered_source_names_for_query(&subquery.this),
487        _ => Vec::new(),
488    }
489}
490
491fn ordered_source_names_for_select(select: &crate::expressions::Select) -> Vec<String> {
492    let mut sources = Vec::new();
493
494    if let Some(from) = &select.from {
495        for expression in &from.expressions {
496            if let Some(source_name) = expression_source_name(expression) {
497                sources.push(source_name);
498            }
499        }
500    }
501
502    for join in &select.joins {
503        if let Some(source_name) = expression_source_name(&join.this) {
504            sources.push(source_name);
505        }
506    }
507
508    sources
509}
510
511fn nullable_source_names(expression: &Expression) -> HashSet<String> {
512    match expression {
513        Expression::Select(select) => nullable_source_names_for_select(select),
514        Expression::Union(union) => nullable_source_names(&union.left),
515        Expression::Intersect(intersect) => nullable_source_names(&intersect.left),
516        Expression::Except(except) => nullable_source_names(&except.left),
517        Expression::Subquery(subquery) => nullable_source_names(&subquery.this),
518        _ => HashSet::new(),
519    }
520}
521
522fn nullable_source_names_for_select(select: &crate::expressions::Select) -> HashSet<String> {
523    let mut nullable = HashSet::new();
524    let mut left_sources = Vec::new();
525
526    if let Some(from) = &select.from {
527        for expression in &from.expressions {
528            if let Some(source_name) = expression_source_name(expression) {
529                left_sources.push(source_name);
530            }
531        }
532    }
533
534    for join in &select.joins {
535        let right_source = expression_source_name(&join.this);
536
537        if join_nullable_left(join.kind) {
538            for source_name in &left_sources {
539                nullable.insert(normalize_lookup_name(source_name));
540            }
541        }
542
543        if join_nullable_right(join.kind) {
544            if let Some(source_name) = &right_source {
545                nullable.insert(normalize_lookup_name(source_name));
546            }
547        }
548
549        if let Some(source_name) = right_source {
550            left_sources.push(source_name);
551        }
552    }
553
554    nullable
555}
556
557fn join_nullable_left(kind: JoinKind) -> bool {
558    matches!(
559        kind,
560        JoinKind::Right
561            | JoinKind::NaturalRight
562            | JoinKind::AsOfRight
563            | JoinKind::Full
564            | JoinKind::NaturalFull
565            | JoinKind::Outer
566    )
567}
568
569fn join_nullable_right(kind: JoinKind) -> bool {
570    matches!(
571        kind,
572        JoinKind::Left
573            | JoinKind::NaturalLeft
574            | JoinKind::AsOfLeft
575            | JoinKind::LeftLateral
576            | JoinKind::OuterApply
577            | JoinKind::LeftArray
578            | JoinKind::Full
579            | JoinKind::NaturalFull
580            | JoinKind::Outer
581    )
582}
583
584fn expression_source_name(expression: &Expression) -> Option<String> {
585    match expression {
586        Expression::Table(table) => table
587            .alias
588            .as_ref()
589            .map(|alias| alias.name.clone())
590            .or_else(|| Some(table.name.name.clone())),
591        Expression::Subquery(subquery) => subquery.alias.as_ref().map(|alias| alias.name.clone()),
592        Expression::Alias(alias) => Some(alias.alias.name.clone()),
593        Expression::Cte(cte) => Some(cte.alias.name.clone()),
594        _ => None,
595    }
596}
597
/// Normalize a table, source, or column name for case-insensitive lookups.
///
/// Only ASCII letters are lower-cased; every other character is kept as is.
fn normalize_lookup_name(name: &str) -> String {
    let mut lowered = String::from(name);
    lowered.make_ascii_lowercase();
    lowered
}
601
602fn effective_query(expression: Expression) -> Expression {
603    match expression {
604        Expression::Prepare(prepare) => prepare.statement,
605        Expression::Subquery(subquery) if subquery.alias.is_none() => subquery.this,
606        other => other,
607    }
608}
609
610fn ensure_query(expression: &Expression) -> Result<()> {
611    if matches!(
612        expression,
613        Expression::Select(_)
614            | Expression::Union(_)
615            | Expression::Intersect(_)
616            | Expression::Except(_)
617    ) {
618        Ok(())
619    } else {
620        Err(Error::internal(
621            "analyze_query requires a SELECT or set operation query",
622        ))
623    }
624}
625
626fn is_set_operation(expression: &Expression) -> bool {
627    matches!(
628        expression,
629        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
630    )
631}
632
633fn collect_cte_names(expression: &Expression) -> Vec<String> {
634    let mut names = Vec::new();
635    let mut seen = HashSet::new();
636    collect_cte_names_inner(expression, &mut names, &mut seen);
637    names
638}
639
640fn collect_cte_names_inner(
641    expression: &Expression,
642    names: &mut Vec<String>,
643    seen: &mut HashSet<String>,
644) {
645    if let Some(with_clause) = with_clause(expression) {
646        collect_with_names(with_clause, names, seen);
647    }
648
649    match expression {
650        Expression::Union(union) => {
651            collect_cte_names_inner(&union.left, names, seen);
652            collect_cte_names_inner(&union.right, names, seen);
653        }
654        Expression::Intersect(intersect) => {
655            collect_cte_names_inner(&intersect.left, names, seen);
656            collect_cte_names_inner(&intersect.right, names, seen);
657        }
658        Expression::Except(except) => {
659            collect_cte_names_inner(&except.left, names, seen);
660            collect_cte_names_inner(&except.right, names, seen);
661        }
662        Expression::Subquery(subquery) => collect_cte_names_inner(&subquery.this, names, seen),
663        _ => {}
664    }
665}
666
667fn collect_with_names(with_clause: &With, names: &mut Vec<String>, seen: &mut HashSet<String>) {
668    for cte in &with_clause.ctes {
669        if seen.insert(cte.alias.name.clone()) {
670            names.push(cte.alias.name.clone());
671        }
672        collect_cte_names_inner(&cte.this, names, seen);
673    }
674}
675
676fn with_clause(expression: &Expression) -> Option<&With> {
677    match expression {
678        Expression::Select(select) => select.with.as_ref(),
679        Expression::Union(union) => union.with.as_ref(),
680        Expression::Intersect(intersect) => intersect.with.as_ref(),
681        Expression::Except(except) => except.with.as_ref(),
682        _ => None,
683    }
684}
685
686fn projection_facts_for_query(
687    expression: &Expression,
688    scope: &Scope,
689    dialect: DialectType,
690    nullability_context: &NullabilityContext<'_>,
691) -> Vec<ProjectionFact> {
692    let expressions = select_expressions_for_query(expression);
693    let names = get_output_column_names(expression);
694
695    expressions
696        .iter()
697        .enumerate()
698        .map(|(index, projection)| {
699            projection_fact(
700                index,
701                names
702                    .get(index)
703                    .cloned()
704                    .or_else(|| projection_name(projection)),
705                projection,
706                expression,
707                scope,
708                dialect,
709                nullability_context,
710            )
711        })
712        .collect()
713}
714
715fn select_expressions_for_query(expression: &Expression) -> Vec<&Expression> {
716    match expression {
717        Expression::Select(select) => select.expressions.iter().collect(),
718        Expression::Union(union) => select_expressions_for_query(&union.left),
719        Expression::Intersect(intersect) => select_expressions_for_query(&intersect.left),
720        Expression::Except(except) => select_expressions_for_query(&except.left),
721        Expression::Subquery(subquery) => select_expressions_for_query(&subquery.this),
722        _ => Vec::new(),
723    }
724}
725
726fn projection_fact(
727    index: usize,
728    name: Option<String>,
729    projection: &Expression,
730    query: &Expression,
731    scope: &Scope,
732    dialect: DialectType,
733    nullability_context: &NullabilityContext<'_>,
734) -> ProjectionFact {
735    let inner = unwrap_projection_alias(projection);
736    let is_star = projection_is_star(inner);
737    let upstream = lineage_by_index_from_expression(index, query, Some(dialect), false)
738        .map(|node| terminal_references_from_lineage(&node))
739        .ok()
740        .filter(|refs| !refs.is_empty())
741        .unwrap_or_else(|| fallback_column_references(inner, scope));
742
743    ProjectionFact {
744        index,
745        name,
746        is_star,
747        star_table: projection_star_table(inner),
748        transform_kind: transform_kind(inner),
749        cast_type: cast_type(inner, dialect),
750        type_hint: projection
751            .inferred_type()
752            .or_else(|| inner.inferred_type())
753            .and_then(|data_type| render_data_type(data_type, dialect)),
754        nullability: projection_nullability(inner, scope, nullability_context),
755        upstream,
756    }
757}
758
759fn unwrap_projection_alias(expression: &Expression) -> &Expression {
760    match expression {
761        Expression::Alias(alias) => unwrap_projection_alias(&alias.this),
762        Expression::Annotated(annotated) => unwrap_projection_alias(&annotated.this),
763        Expression::Paren(paren) => unwrap_projection_alias(&paren.this),
764        _ => expression,
765    }
766}
767
768fn projection_name(expression: &Expression) -> Option<String> {
769    match expression {
770        Expression::Alias(alias) => Some(alias.alias.name.clone()),
771        Expression::Column(column) => Some(column.name.name.clone()),
772        Expression::Identifier(identifier) => Some(identifier.name.clone()),
773        Expression::Star(_) => Some("*".to_string()),
774        Expression::Annotated(annotated) => projection_name(&annotated.this),
775        _ => None,
776    }
777}
778
779fn projection_is_star(expression: &Expression) -> bool {
780    matches!(expression, Expression::Star(_))
781        || matches!(expression, Expression::Column(column) if column.name.name == "*")
782}
783
784fn projection_star_table(expression: &Expression) -> Option<String> {
785    match expression {
786        Expression::Star(star) => star
787            .table
788            .as_ref()
789            .map(|identifier| identifier.name.clone()),
790        Expression::Column(column) if column.name.name == "*" => column
791            .table
792            .as_ref()
793            .map(|identifier| identifier.name.clone()),
794        _ => None,
795    }
796}
797
798fn transform_kind(expression: &Expression) -> TransformKind {
799    if projection_is_star(expression) {
800        TransformKind::Star
801    } else if is_cast_expression(expression) {
802        TransformKind::Cast
803    } else if contains_aggregate(expression) {
804        TransformKind::Aggregation
805    } else if matches!(
806        expression,
807        Expression::Column(_) | Expression::Identifier(_)
808    ) {
809        TransformKind::Direct
810    } else if is_simple_constant(expression) {
811        TransformKind::Constant
812    } else {
813        TransformKind::Expression
814    }
815}
816
817fn is_cast_expression(expression: &Expression) -> bool {
818    matches!(
819        expression,
820        Expression::Cast(_) | Expression::TryCast(_) | Expression::SafeCast(_)
821    )
822}
823
824fn cast_type(expression: &Expression, dialect: DialectType) -> Option<String> {
825    match expression {
826        Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
827            render_data_type(&cast.to, dialect)
828        }
829        _ => None,
830    }
831}
832
833fn render_data_type(data_type: &DataType, dialect: DialectType) -> Option<String> {
834    Dialect::get(dialect)
835        .generate(&Expression::DataType(data_type.clone()))
836        .ok()
837}
838
839fn is_simple_constant(expression: &Expression) -> bool {
840    match expression {
841        Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_) => true,
842        Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
843            is_simple_constant(&cast.this)
844        }
845        Expression::Neg(unary) | Expression::BitwiseNot(unary) => is_simple_constant(&unary.this),
846        _ => false,
847    }
848}
849
850fn projection_nullability(
851    expression: &Expression,
852    scope: &Scope,
853    context: &NullabilityContext<'_>,
854) -> ProjectionNullability {
855    match expression {
856        Expression::Alias(alias) => projection_nullability(&alias.this, scope, context),
857        Expression::Annotated(annotated) => projection_nullability(&annotated.this, scope, context),
858        Expression::Paren(paren) => projection_nullability(&paren.this, scope, context),
859        Expression::Literal(_) | Expression::Boolean(_) => ProjectionNullability::NonNull,
860        Expression::Null(_) => ProjectionNullability::Nullable,
861        Expression::Count(_) | Expression::CountIf(_) => ProjectionNullability::NonNull,
862        Expression::Cast(cast) => projection_nullability(&cast.this, scope, context),
863        Expression::TryCast(_) | Expression::SafeCast(_) => ProjectionNullability::Unknown,
864        Expression::Column(column) => column_nullability(
865            &column.name.name,
866            column.table.as_ref().map(|table| table.name.as_str()),
867            scope,
868            context,
869        ),
870        Expression::Identifier(identifier) => {
871            column_nullability(&identifier.name, None, scope, context)
872        }
873        Expression::Coalesce(func) => coalesce_nullability(&func.expressions, scope, context),
874        _ => ProjectionNullability::Unknown,
875    }
876}
877
878fn column_nullability(
879    column_name: &str,
880    source_name: Option<&str>,
881    scope: &Scope,
882    context: &NullabilityContext<'_>,
883) -> ProjectionNullability {
884    let resolved_source_name = source_name
885        .map(str::to_string)
886        .or_else(|| single_scope_source_name(scope));
887
888    if let Some(source_name) = &resolved_source_name {
889        if context
890            .nullable_sources
891            .contains(&normalize_lookup_name(source_name))
892        {
893            return ProjectionNullability::Nullable;
894        }
895    }
896
897    let Some(schema) = context.schema else {
898        return ProjectionNullability::Unknown;
899    };
900
901    let table_name = resolved_source_name
902        .as_ref()
903        .and_then(|name| scope.sources.get(name).and_then(source_table_name))
904        .or(resolved_source_name);
905
906    let Some(table_name) = table_name else {
907        return ProjectionNullability::Unknown;
908    };
909
910    match schema.column(&table_name, column_name) {
911        Some(info) if info.primary_key || info.nullable == Some(false) => {
912            ProjectionNullability::NonNull
913        }
914        Some(info) if info.nullable == Some(true) => ProjectionNullability::Nullable,
915        Some(_) | None => ProjectionNullability::Unknown,
916    }
917}
918
919fn single_scope_source_name(scope: &Scope) -> Option<String> {
920    if scope.sources.len() == 1 {
921        scope.sources.keys().next().cloned()
922    } else {
923        None
924    }
925}
926
927fn coalesce_nullability(
928    expressions: &[Expression],
929    scope: &Scope,
930    context: &NullabilityContext<'_>,
931) -> ProjectionNullability {
932    if expressions.is_empty() {
933        return ProjectionNullability::Unknown;
934    }
935
936    let mut all_nullable = true;
937
938    for expression in expressions {
939        match projection_nullability(unwrap_projection_alias(expression), scope, context) {
940            ProjectionNullability::NonNull => return ProjectionNullability::NonNull,
941            ProjectionNullability::Nullable => {}
942            ProjectionNullability::Unknown => all_nullable = false,
943        }
944    }
945
946    if all_nullable {
947        ProjectionNullability::Nullable
948    } else {
949        ProjectionNullability::Unknown
950    }
951}
952
953fn terminal_references_from_lineage(node: &LineageNode) -> Vec<ColumnReferenceFact> {
954    let mut refs = Vec::new();
955    collect_terminal_references(node, &mut refs);
956    dedupe_column_refs(refs)
957}
958
959fn collect_terminal_references(node: &LineageNode, refs: &mut Vec<ColumnReferenceFact>) {
960    if node.downstream.is_empty() {
961        if let Some(reference) = column_reference_from_lineage_node(node) {
962            refs.push(reference);
963        }
964        return;
965    }
966
967    for child in &node.downstream {
968        collect_terminal_references(child, refs);
969    }
970}
971
972fn column_reference_from_lineage_node(node: &LineageNode) -> Option<ColumnReferenceFact> {
973    match &node.expression {
974        Expression::Column(column) => {
975            let source_name = non_empty_string(node.source_name.clone());
976            let table =
977                lineage_node_table(node).or_else(|| column.table.as_ref().map(|t| t.name.clone()));
978            let confidence = if node.source_kind == SourceKind::Unknown && source_name.is_none() {
979                ReferenceConfidence::Unknown
980            } else {
981                ReferenceConfidence::Resolved
982            };
983            Some(ColumnReferenceFact {
984                source_name,
985                source_alias: node.source_alias.clone(),
986                source_kind: node.source_kind,
987                table,
988                column: column.name.name.clone(),
989                unqualified: column.table.is_none(),
990                confidence,
991            })
992        }
993        Expression::Star(_) => Some(ColumnReferenceFact {
994            source_name: non_empty_string(node.source_name.clone()),
995            source_alias: node.source_alias.clone(),
996            source_kind: node.source_kind,
997            table: lineage_node_table(node),
998            column: "*".to_string(),
999            unqualified: true,
1000            confidence: if node.source_kind == SourceKind::Unknown {
1001                ReferenceConfidence::Unknown
1002            } else {
1003                ReferenceConfidence::Resolved
1004            },
1005        }),
1006        _ => None,
1007    }
1008}
1009
1010fn lineage_node_table(node: &LineageNode) -> Option<String> {
1011    match &node.source {
1012        Expression::Table(table) => Some(table_name(table)),
1013        _ => None,
1014    }
1015}
1016
1017fn fallback_column_references(expression: &Expression, scope: &Scope) -> Vec<ColumnReferenceFact> {
1018    let mut refs = Vec::new();
1019    let source_count = scope.sources.len();
1020    let single_source = if source_count == 1 {
1021        scope.sources.iter().next()
1022    } else {
1023        None
1024    };
1025
1026    for column_expr in expression.find_all(|candidate| matches!(candidate, Expression::Column(_))) {
1027        if let Expression::Column(column) = column_expr {
1028            if column.name.name == "*" {
1029                continue;
1030            }
1031            let source = column
1032                .table
1033                .as_ref()
1034                .and_then(|table| scope.sources.get(&table.name));
1035            let (source_name, source_alias, source_kind, table, confidence) =
1036                if let Some(table_identifier) = &column.table {
1037                    if let Some(source) = source {
1038                        (
1039                            Some(table_identifier.name.clone()),
1040                            source.alias.clone(),
1041                            source.kind,
1042                            source_table_name(source)
1043                                .or_else(|| Some(table_identifier.name.clone())),
1044                            ReferenceConfidence::Resolved,
1045                        )
1046                    } else {
1047                        (
1048                            Some(table_identifier.name.clone()),
1049                            None,
1050                            SourceKind::Unknown,
1051                            Some(table_identifier.name.clone()),
1052                            ReferenceConfidence::Unknown,
1053                        )
1054                    }
1055                } else if let Some((name, source)) = single_source {
1056                    (
1057                        Some(name.clone()),
1058                        source.alias.clone(),
1059                        source.kind,
1060                        source_table_name(source).or_else(|| Some(name.clone())),
1061                        ReferenceConfidence::Resolved,
1062                    )
1063                } else if source_count > 1 {
1064                    (
1065                        None,
1066                        None,
1067                        SourceKind::Unknown,
1068                        None,
1069                        ReferenceConfidence::Ambiguous,
1070                    )
1071                } else {
1072                    (
1073                        None,
1074                        None,
1075                        SourceKind::Unknown,
1076                        None,
1077                        ReferenceConfidence::Unknown,
1078                    )
1079                };
1080
1081            refs.push(ColumnReferenceFact {
1082                source_name,
1083                source_alias,
1084                source_kind,
1085                table,
1086                column: column.name.name.clone(),
1087                unqualified: column.table.is_none(),
1088                confidence,
1089            });
1090        }
1091    }
1092
1093    dedupe_column_refs(refs)
1094}
1095
1096fn dedupe_column_refs(refs: Vec<ColumnReferenceFact>) -> Vec<ColumnReferenceFact> {
1097    let mut seen = HashSet::new();
1098    let mut deduped = Vec::new();
1099
1100    for reference in refs {
1101        let key = (
1102            reference.source_name.clone(),
1103            reference.source_alias.clone(),
1104            reference.table.clone(),
1105            reference.column.clone(),
1106            format!("{:?}", reference.source_kind),
1107            reference.unqualified,
1108            format!("{:?}", reference.confidence),
1109        );
1110        if seen.insert(key) {
1111            deduped.push(reference);
1112        }
1113    }
1114
1115    deduped
1116}
1117
1118fn relation_facts(
1119    scope: &Scope,
1120    mapping_schema: Option<&crate::schema::MappingSchema>,
1121) -> Vec<RelationFact> {
1122    let mut relations = Vec::new();
1123    let mut seen = HashSet::new();
1124    collect_relation_facts(scope, mapping_schema, &mut seen, &mut relations);
1125
1126    relations.sort_by(|left, right| {
1127        left.name
1128            .cmp(&right.name)
1129            .then_with(|| left.alias.cmp(&right.alias))
1130    });
1131    relations
1132}
1133
1134fn collect_relation_facts(
1135    scope: &Scope,
1136    mapping_schema: Option<&crate::schema::MappingSchema>,
1137    seen: &mut HashSet<String>,
1138    relations: &mut Vec<RelationFact>,
1139) {
1140    for relation in scope
1141        .sources
1142        .iter()
1143        .map(|(source_name, source)| RelationFact {
1144            name: source
1145                .lineage_name
1146                .clone()
1147                .or_else(|| source_table_name(source))
1148                .unwrap_or_else(|| source_name.clone()),
1149            alias: source.alias.clone().or_else(|| source_alias(source)),
1150            kind: source.kind,
1151            columns: source_columns(source, mapping_schema),
1152        })
1153    {
1154        let key = format!("{:?}|{}|{:?}", relation.kind, relation.name, relation.alias);
1155        if seen.insert(key) {
1156            relations.push(relation);
1157        }
1158    }
1159
1160    for branch_scope in &scope.union_scopes {
1161        collect_relation_facts(branch_scope, mapping_schema, seen, relations);
1162    }
1163}
1164
1165fn base_table_facts(
1166    scope: &Scope,
1167    mapping_schema: Option<&crate::schema::MappingSchema>,
1168) -> Vec<RelationFact> {
1169    let mut relations = Vec::new();
1170    let mut seen = HashSet::new();
1171
1172    collect_base_table_facts(scope, mapping_schema, &mut seen, &mut relations);
1173
1174    relations.sort_by(|left, right| left.name.cmp(&right.name));
1175    relations
1176}
1177
1178fn collect_base_table_facts(
1179    scope: &Scope,
1180    mapping_schema: Option<&crate::schema::MappingSchema>,
1181    seen: &mut HashSet<String>,
1182    relations: &mut Vec<RelationFact>,
1183) {
1184    for source in scope.sources.values() {
1185        if source.kind != SourceKind::Table {
1186            continue;
1187        }
1188
1189        let Some(table_name) = source_table_name(source) else {
1190            continue;
1191        };
1192
1193        if seen.insert(table_name.clone()) {
1194            relations.push(RelationFact {
1195                name: table_name,
1196                alias: source.alias.clone().or_else(|| source_alias(source)),
1197                kind: SourceKind::Table,
1198                columns: source_columns(source, mapping_schema),
1199            });
1200        }
1201    }
1202
1203    for child_scope in scope
1204        .cte_scopes
1205        .iter()
1206        .chain(scope.union_scopes.iter())
1207        .chain(scope.table_scopes.iter())
1208        .chain(scope.derived_table_scopes.iter())
1209        .chain(scope.subquery_scopes.iter())
1210    {
1211        collect_base_table_facts(child_scope, mapping_schema, seen, relations);
1212    }
1213}
1214
1215fn source_columns(
1216    source: &SourceInfo,
1217    mapping_schema: Option<&crate::schema::MappingSchema>,
1218) -> Vec<String> {
1219    match &source.expression {
1220        Expression::Table(table) => mapping_schema
1221            .and_then(|schema| schema.column_names(&table_name(table)).ok())
1222            .unwrap_or_default(),
1223        Expression::Select(_)
1224        | Expression::Union(_)
1225        | Expression::Intersect(_)
1226        | Expression::Except(_) => get_output_column_names(&source.expression),
1227        Expression::Subquery(subquery) => get_output_column_names(&subquery.this),
1228        Expression::Cte(cte) if !cte.columns.is_empty() => cte
1229            .columns
1230            .iter()
1231            .map(|column| column.name.clone())
1232            .collect(),
1233        Expression::Cte(cte) => get_output_column_names(&cte.this),
1234        _ => Vec::new(),
1235    }
1236}
1237
1238fn source_table_name(source: &SourceInfo) -> Option<String> {
1239    match &source.expression {
1240        Expression::Table(table) => Some(table_name(table)),
1241        _ => None,
1242    }
1243}
1244
1245fn source_alias(source: &SourceInfo) -> Option<String> {
1246    match &source.expression {
1247        Expression::Table(table) => table.alias.as_ref().map(|alias| alias.name.clone()),
1248        Expression::Subquery(subquery) => subquery.alias.as_ref().map(|alias| alias.name.clone()),
1249        _ => None,
1250    }
1251}
1252
1253fn table_name(table: &TableRef) -> String {
1254    let mut parts = Vec::new();
1255    if let Some(catalog) = &table.catalog {
1256        parts.push(catalog.name.clone());
1257    }
1258    if let Some(schema) = &table.schema {
1259        parts.push(schema.name.clone());
1260    }
1261    parts.push(table.name.name.clone());
1262    parts.join(".")
1263}
1264
1265fn set_operation_facts(
1266    expression: &Expression,
1267    scope: &Scope,
1268    dialect: DialectType,
1269) -> Vec<SetOperationFact> {
1270    let mut facts = Vec::new();
1271    collect_set_operation_facts(expression, scope, dialect, &mut facts);
1272    facts
1273}
1274
1275fn collect_set_operation_facts(
1276    expression: &Expression,
1277    scope: &Scope,
1278    dialect: DialectType,
1279    facts: &mut Vec<SetOperationFact>,
1280) {
1281    match expression {
1282        Expression::Union(union) => {
1283            facts.push(SetOperationFact {
1284                kind: "union".to_string(),
1285                all: union.all,
1286                distinct: union.distinct,
1287                output_columns: get_output_column_names(expression),
1288                branches: set_operation_branches(&union.left, &union.right, scope, dialect),
1289            });
1290            collect_set_operation_facts(&union.left, scope, dialect, facts);
1291            collect_set_operation_facts(&union.right, scope, dialect, facts);
1292        }
1293        Expression::Intersect(intersect) => {
1294            facts.push(SetOperationFact {
1295                kind: "intersect".to_string(),
1296                all: intersect.all,
1297                distinct: intersect.distinct,
1298                output_columns: get_output_column_names(expression),
1299                branches: set_operation_branches(&intersect.left, &intersect.right, scope, dialect),
1300            });
1301            collect_set_operation_facts(&intersect.left, scope, dialect, facts);
1302            collect_set_operation_facts(&intersect.right, scope, dialect, facts);
1303        }
1304        Expression::Except(except) => {
1305            facts.push(SetOperationFact {
1306                kind: "except".to_string(),
1307                all: except.all,
1308                distinct: except.distinct,
1309                output_columns: get_output_column_names(expression),
1310                branches: set_operation_branches(&except.left, &except.right, scope, dialect),
1311            });
1312            collect_set_operation_facts(&except.left, scope, dialect, facts);
1313            collect_set_operation_facts(&except.right, scope, dialect, facts);
1314        }
1315        Expression::Subquery(subquery) => {
1316            collect_set_operation_facts(&subquery.this, scope, dialect, facts);
1317        }
1318        _ => {}
1319    }
1320}
1321
1322fn set_operation_branches(
1323    left: &Expression,
1324    right: &Expression,
1325    scope: &Scope,
1326    dialect: DialectType,
1327) -> Vec<SetOperationBranchFact> {
1328    vec![
1329        SetOperationBranchFact {
1330            index: 0,
1331            projections: projection_facts_for_branch(left, scope, dialect),
1332        },
1333        SetOperationBranchFact {
1334            index: 1,
1335            projections: projection_facts_for_branch(right, scope, dialect),
1336        },
1337    ]
1338}
1339
1340fn projection_facts_for_branch(
1341    expression: &Expression,
1342    root_scope: &Scope,
1343    dialect: DialectType,
1344) -> Vec<ProjectionFact> {
1345    let branch_scope = build_scope(expression);
1346    let scope = if branch_scope.sources.is_empty() {
1347        root_scope
1348    } else {
1349        &branch_scope
1350    };
1351    let nullability_context = NullabilityContext {
1352        schema: None,
1353        nullable_sources: nullable_source_names(expression),
1354    };
1355    projection_facts_for_query(expression, scope, dialect, &nullability_context)
1356}
1357
/// Wraps `value` in `Some` unless it is the empty string, in which case
/// `None` is returned. Whitespace-only strings count as non-empty.
fn non_empty_string(value: String) -> Option<String> {
    Some(value).filter(|text| !text.is_empty())
}