1use crate::ast_transforms::get_output_column_names;
9use crate::dialects::{Dialect, DialectType};
10use crate::expressions::{DataType, Expression, TableRef, With};
11use crate::lineage::{lineage_by_index_from_expression, LineageNode};
12use crate::optimizer::annotate_types::annotate_types;
13use crate::optimizer::qualify_columns::{qualify_columns, QualifyColumnsOptions};
14use crate::schema::Schema;
15use crate::scope::{build_scope, Scope, SourceInfo, SourceKind};
16use crate::traversal::{contains_aggregate, ExpressionWalk};
17use crate::validation::{mapping_schema_from_validation_schema, ValidationSchema};
18use crate::{parse_one, Error, Result};
19use serde::{Deserialize, Serialize};
20use std::collections::HashSet;
21
22#[derive(Debug, Clone, Serialize, Deserialize, Default)]
24#[serde(rename_all = "camelCase", default)]
25pub struct AnalyzeQueryOptions {
26 pub dialect: DialectType,
28 pub schema: Option<ValidationSchema>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34#[serde(rename_all = "camelCase")]
35pub struct QueryAnalysis {
36 pub shape: QueryShape,
37 pub ctes: Vec<String>,
38 pub projections: Vec<ProjectionFact>,
39 pub relations: Vec<RelationFact>,
40 pub set_operations: Vec<SetOperationFact>,
41}
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
45#[serde(rename_all = "snake_case")]
46pub enum QueryShape {
47 Select,
48 SetOperation,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(rename_all = "camelCase")]
54pub struct ProjectionFact {
55 pub index: usize,
56 pub name: Option<String>,
57 pub is_star: bool,
58 pub star_table: Option<String>,
59 pub transform_kind: TransformKind,
60 pub cast_type: Option<String>,
61 pub type_hint: Option<String>,
62 pub upstream: Vec<ColumnReferenceFact>,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67#[serde(rename_all = "camelCase")]
68pub struct ColumnReferenceFact {
69 pub source_name: Option<String>,
70 pub source_alias: Option<String>,
71 pub source_kind: SourceKind,
72 pub table: Option<String>,
73 pub column: String,
74 pub unqualified: bool,
75 pub confidence: ReferenceConfidence,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80#[serde(rename_all = "camelCase")]
81pub struct RelationFact {
82 pub name: String,
83 pub alias: Option<String>,
84 pub kind: SourceKind,
85 pub columns: Vec<String>,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90#[serde(rename_all = "camelCase")]
91pub struct SetOperationFact {
92 pub kind: String,
93 pub all: bool,
94 pub distinct: bool,
95 pub output_columns: Vec<String>,
96 pub branches: Vec<SetOperationBranchFact>,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101#[serde(rename_all = "camelCase")]
102pub struct SetOperationBranchFact {
103 pub index: usize,
104 pub projections: Vec<ProjectionFact>,
105}
106
107#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
109#[serde(rename_all = "snake_case")]
110pub enum TransformKind {
111 Direct,
112 Cast,
113 Aggregation,
114 Constant,
115 Expression,
116 Star,
117}
118
119#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
121#[serde(rename_all = "snake_case")]
122pub enum ReferenceConfidence {
123 Resolved,
124 Ambiguous,
125 Unknown,
126}
127
128pub fn analyze_query(sql: &str, options: AnalyzeQueryOptions) -> Result<QueryAnalysis> {
130 let mut expression = parse_one(sql, options.dialect)?;
131 expression = effective_query(expression);
132 ensure_query(&expression)?;
133
134 let mapping_schema = options
135 .schema
136 .as_ref()
137 .map(mapping_schema_from_validation_schema);
138
139 if let Some(schema) = mapping_schema.as_ref() {
140 let qualify_options = QualifyColumnsOptions::new().with_dialect(options.dialect);
141 expression = qualify_columns(expression, schema, &qualify_options)
142 .map_err(|e| Error::internal(format!("query analysis qualification failed: {e}")))?;
143 }
144
145 annotate_types(
146 &mut expression,
147 mapping_schema.as_ref().map(|s| s as _),
148 Some(options.dialect),
149 );
150 crate::lineage::expand_cte_stars(&mut expression, mapping_schema.as_ref().map(|s| s as _));
151
152 let scope = build_scope(&expression);
153 let shape = if is_set_operation(&expression) {
154 QueryShape::SetOperation
155 } else {
156 QueryShape::Select
157 };
158
159 Ok(QueryAnalysis {
160 shape,
161 ctes: collect_cte_names(&expression),
162 projections: projection_facts_for_query(&expression, &scope, options.dialect),
163 relations: relation_facts(&scope, mapping_schema.as_ref()),
164 set_operations: set_operation_facts(&expression, &scope, options.dialect),
165 })
166}
167
168fn effective_query(expression: Expression) -> Expression {
169 match expression {
170 Expression::Prepare(prepare) => prepare.statement,
171 Expression::Subquery(subquery) if subquery.alias.is_none() => subquery.this,
172 other => other,
173 }
174}
175
176fn ensure_query(expression: &Expression) -> Result<()> {
177 if matches!(
178 expression,
179 Expression::Select(_)
180 | Expression::Union(_)
181 | Expression::Intersect(_)
182 | Expression::Except(_)
183 ) {
184 Ok(())
185 } else {
186 Err(Error::internal(
187 "analyze_query requires a SELECT or set operation query",
188 ))
189 }
190}
191
192fn is_set_operation(expression: &Expression) -> bool {
193 matches!(
194 expression,
195 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
196 )
197}
198
199fn collect_cte_names(expression: &Expression) -> Vec<String> {
200 let mut names = Vec::new();
201 let mut seen = HashSet::new();
202 collect_cte_names_inner(expression, &mut names, &mut seen);
203 names
204}
205
206fn collect_cte_names_inner(
207 expression: &Expression,
208 names: &mut Vec<String>,
209 seen: &mut HashSet<String>,
210) {
211 if let Some(with_clause) = with_clause(expression) {
212 collect_with_names(with_clause, names, seen);
213 }
214
215 match expression {
216 Expression::Union(union) => {
217 collect_cte_names_inner(&union.left, names, seen);
218 collect_cte_names_inner(&union.right, names, seen);
219 }
220 Expression::Intersect(intersect) => {
221 collect_cte_names_inner(&intersect.left, names, seen);
222 collect_cte_names_inner(&intersect.right, names, seen);
223 }
224 Expression::Except(except) => {
225 collect_cte_names_inner(&except.left, names, seen);
226 collect_cte_names_inner(&except.right, names, seen);
227 }
228 Expression::Subquery(subquery) => collect_cte_names_inner(&subquery.this, names, seen),
229 _ => {}
230 }
231}
232
233fn collect_with_names(with_clause: &With, names: &mut Vec<String>, seen: &mut HashSet<String>) {
234 for cte in &with_clause.ctes {
235 if seen.insert(cte.alias.name.clone()) {
236 names.push(cte.alias.name.clone());
237 }
238 collect_cte_names_inner(&cte.this, names, seen);
239 }
240}
241
242fn with_clause(expression: &Expression) -> Option<&With> {
243 match expression {
244 Expression::Select(select) => select.with.as_ref(),
245 Expression::Union(union) => union.with.as_ref(),
246 Expression::Intersect(intersect) => intersect.with.as_ref(),
247 Expression::Except(except) => except.with.as_ref(),
248 _ => None,
249 }
250}
251
252fn projection_facts_for_query(
253 expression: &Expression,
254 scope: &Scope,
255 dialect: DialectType,
256) -> Vec<ProjectionFact> {
257 let expressions = select_expressions_for_query(expression);
258 let names = get_output_column_names(expression);
259
260 expressions
261 .iter()
262 .enumerate()
263 .map(|(index, projection)| {
264 projection_fact(
265 index,
266 names
267 .get(index)
268 .cloned()
269 .or_else(|| projection_name(projection)),
270 projection,
271 expression,
272 scope,
273 dialect,
274 )
275 })
276 .collect()
277}
278
279fn select_expressions_for_query(expression: &Expression) -> Vec<&Expression> {
280 match expression {
281 Expression::Select(select) => select.expressions.iter().collect(),
282 Expression::Union(union) => select_expressions_for_query(&union.left),
283 Expression::Intersect(intersect) => select_expressions_for_query(&intersect.left),
284 Expression::Except(except) => select_expressions_for_query(&except.left),
285 Expression::Subquery(subquery) => select_expressions_for_query(&subquery.this),
286 _ => Vec::new(),
287 }
288}
289
290fn projection_fact(
291 index: usize,
292 name: Option<String>,
293 projection: &Expression,
294 query: &Expression,
295 scope: &Scope,
296 dialect: DialectType,
297) -> ProjectionFact {
298 let inner = unwrap_projection_alias(projection);
299 let is_star = projection_is_star(inner);
300 let upstream = lineage_by_index_from_expression(index, query, Some(dialect), false)
301 .map(|node| terminal_references_from_lineage(&node))
302 .ok()
303 .filter(|refs| !refs.is_empty())
304 .unwrap_or_else(|| fallback_column_references(inner, scope));
305
306 ProjectionFact {
307 index,
308 name,
309 is_star,
310 star_table: projection_star_table(inner),
311 transform_kind: transform_kind(inner),
312 cast_type: cast_type(inner, dialect),
313 type_hint: projection
314 .inferred_type()
315 .or_else(|| inner.inferred_type())
316 .and_then(|data_type| render_data_type(data_type, dialect)),
317 upstream,
318 }
319}
320
321fn unwrap_projection_alias(expression: &Expression) -> &Expression {
322 match expression {
323 Expression::Alias(alias) => unwrap_projection_alias(&alias.this),
324 Expression::Annotated(annotated) => unwrap_projection_alias(&annotated.this),
325 Expression::Paren(paren) => unwrap_projection_alias(&paren.this),
326 _ => expression,
327 }
328}
329
330fn projection_name(expression: &Expression) -> Option<String> {
331 match expression {
332 Expression::Alias(alias) => Some(alias.alias.name.clone()),
333 Expression::Column(column) => Some(column.name.name.clone()),
334 Expression::Identifier(identifier) => Some(identifier.name.clone()),
335 Expression::Star(_) => Some("*".to_string()),
336 Expression::Annotated(annotated) => projection_name(&annotated.this),
337 _ => None,
338 }
339}
340
341fn projection_is_star(expression: &Expression) -> bool {
342 matches!(expression, Expression::Star(_))
343 || matches!(expression, Expression::Column(column) if column.name.name == "*")
344}
345
346fn projection_star_table(expression: &Expression) -> Option<String> {
347 match expression {
348 Expression::Star(star) => star
349 .table
350 .as_ref()
351 .map(|identifier| identifier.name.clone()),
352 Expression::Column(column) if column.name.name == "*" => column
353 .table
354 .as_ref()
355 .map(|identifier| identifier.name.clone()),
356 _ => None,
357 }
358}
359
360fn transform_kind(expression: &Expression) -> TransformKind {
361 if projection_is_star(expression) {
362 TransformKind::Star
363 } else if is_cast_expression(expression) {
364 TransformKind::Cast
365 } else if contains_aggregate(expression) {
366 TransformKind::Aggregation
367 } else if matches!(
368 expression,
369 Expression::Column(_) | Expression::Identifier(_)
370 ) {
371 TransformKind::Direct
372 } else if is_simple_constant(expression) {
373 TransformKind::Constant
374 } else {
375 TransformKind::Expression
376 }
377}
378
379fn is_cast_expression(expression: &Expression) -> bool {
380 matches!(
381 expression,
382 Expression::Cast(_) | Expression::TryCast(_) | Expression::SafeCast(_)
383 )
384}
385
386fn cast_type(expression: &Expression, dialect: DialectType) -> Option<String> {
387 match expression {
388 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
389 render_data_type(&cast.to, dialect)
390 }
391 _ => None,
392 }
393}
394
395fn render_data_type(data_type: &DataType, dialect: DialectType) -> Option<String> {
396 Dialect::get(dialect)
397 .generate(&Expression::DataType(data_type.clone()))
398 .ok()
399}
400
401fn is_simple_constant(expression: &Expression) -> bool {
402 match expression {
403 Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_) => true,
404 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
405 is_simple_constant(&cast.this)
406 }
407 Expression::Neg(unary) | Expression::BitwiseNot(unary) => is_simple_constant(&unary.this),
408 _ => false,
409 }
410}
411
412fn terminal_references_from_lineage(node: &LineageNode) -> Vec<ColumnReferenceFact> {
413 let mut refs = Vec::new();
414 collect_terminal_references(node, &mut refs);
415 dedupe_column_refs(refs)
416}
417
418fn collect_terminal_references(node: &LineageNode, refs: &mut Vec<ColumnReferenceFact>) {
419 if node.downstream.is_empty() {
420 if let Some(reference) = column_reference_from_lineage_node(node) {
421 refs.push(reference);
422 }
423 return;
424 }
425
426 for child in &node.downstream {
427 collect_terminal_references(child, refs);
428 }
429}
430
431fn column_reference_from_lineage_node(node: &LineageNode) -> Option<ColumnReferenceFact> {
432 match &node.expression {
433 Expression::Column(column) => {
434 let source_name = non_empty_string(node.source_name.clone());
435 let table =
436 lineage_node_table(node).or_else(|| column.table.as_ref().map(|t| t.name.clone()));
437 let confidence = if node.source_kind == SourceKind::Unknown && source_name.is_none() {
438 ReferenceConfidence::Unknown
439 } else {
440 ReferenceConfidence::Resolved
441 };
442 Some(ColumnReferenceFact {
443 source_name,
444 source_alias: node.source_alias.clone(),
445 source_kind: node.source_kind,
446 table,
447 column: column.name.name.clone(),
448 unqualified: column.table.is_none(),
449 confidence,
450 })
451 }
452 Expression::Star(_) => Some(ColumnReferenceFact {
453 source_name: non_empty_string(node.source_name.clone()),
454 source_alias: node.source_alias.clone(),
455 source_kind: node.source_kind,
456 table: lineage_node_table(node),
457 column: "*".to_string(),
458 unqualified: true,
459 confidence: if node.source_kind == SourceKind::Unknown {
460 ReferenceConfidence::Unknown
461 } else {
462 ReferenceConfidence::Resolved
463 },
464 }),
465 _ => None,
466 }
467}
468
469fn lineage_node_table(node: &LineageNode) -> Option<String> {
470 match &node.source {
471 Expression::Table(table) => Some(table_name(table)),
472 _ => None,
473 }
474}
475
476fn fallback_column_references(expression: &Expression, scope: &Scope) -> Vec<ColumnReferenceFact> {
477 let mut refs = Vec::new();
478 let source_count = scope.sources.len();
479 let single_source = if source_count == 1 {
480 scope.sources.iter().next()
481 } else {
482 None
483 };
484
485 for column_expr in expression.find_all(|candidate| matches!(candidate, Expression::Column(_))) {
486 if let Expression::Column(column) = column_expr {
487 if column.name.name == "*" {
488 continue;
489 }
490 let source = column
491 .table
492 .as_ref()
493 .and_then(|table| scope.sources.get(&table.name));
494 let (source_name, source_alias, source_kind, table, confidence) =
495 if let Some(table_identifier) = &column.table {
496 if let Some(source) = source {
497 (
498 Some(table_identifier.name.clone()),
499 source.alias.clone(),
500 source.kind,
501 source_table_name(source)
502 .or_else(|| Some(table_identifier.name.clone())),
503 ReferenceConfidence::Resolved,
504 )
505 } else {
506 (
507 Some(table_identifier.name.clone()),
508 None,
509 SourceKind::Unknown,
510 Some(table_identifier.name.clone()),
511 ReferenceConfidence::Unknown,
512 )
513 }
514 } else if let Some((name, source)) = single_source {
515 (
516 Some(name.clone()),
517 source.alias.clone(),
518 source.kind,
519 source_table_name(source).or_else(|| Some(name.clone())),
520 ReferenceConfidence::Resolved,
521 )
522 } else if source_count > 1 {
523 (
524 None,
525 None,
526 SourceKind::Unknown,
527 None,
528 ReferenceConfidence::Ambiguous,
529 )
530 } else {
531 (
532 None,
533 None,
534 SourceKind::Unknown,
535 None,
536 ReferenceConfidence::Unknown,
537 )
538 };
539
540 refs.push(ColumnReferenceFact {
541 source_name,
542 source_alias,
543 source_kind,
544 table,
545 column: column.name.name.clone(),
546 unqualified: column.table.is_none(),
547 confidence,
548 });
549 }
550 }
551
552 dedupe_column_refs(refs)
553}
554
555fn dedupe_column_refs(refs: Vec<ColumnReferenceFact>) -> Vec<ColumnReferenceFact> {
556 let mut seen = HashSet::new();
557 let mut deduped = Vec::new();
558
559 for reference in refs {
560 let key = (
561 reference.source_name.clone(),
562 reference.source_alias.clone(),
563 reference.table.clone(),
564 reference.column.clone(),
565 format!("{:?}", reference.source_kind),
566 reference.unqualified,
567 format!("{:?}", reference.confidence),
568 );
569 if seen.insert(key) {
570 deduped.push(reference);
571 }
572 }
573
574 deduped
575}
576
577fn relation_facts(
578 scope: &Scope,
579 mapping_schema: Option<&crate::schema::MappingSchema>,
580) -> Vec<RelationFact> {
581 let mut relations = Vec::new();
582 let mut seen = HashSet::new();
583 collect_relation_facts(scope, mapping_schema, &mut seen, &mut relations);
584
585 relations.sort_by(|left, right| {
586 left.name
587 .cmp(&right.name)
588 .then_with(|| left.alias.cmp(&right.alias))
589 });
590 relations
591}
592
593fn collect_relation_facts(
594 scope: &Scope,
595 mapping_schema: Option<&crate::schema::MappingSchema>,
596 seen: &mut HashSet<String>,
597 relations: &mut Vec<RelationFact>,
598) {
599 for relation in scope
600 .sources
601 .iter()
602 .map(|(source_name, source)| RelationFact {
603 name: source
604 .lineage_name
605 .clone()
606 .or_else(|| source_table_name(source))
607 .unwrap_or_else(|| source_name.clone()),
608 alias: source.alias.clone().or_else(|| source_alias(source)),
609 kind: source.kind,
610 columns: source_columns(source, mapping_schema),
611 })
612 {
613 let key = format!("{:?}|{}|{:?}", relation.kind, relation.name, relation.alias);
614 if seen.insert(key) {
615 relations.push(relation);
616 }
617 }
618
619 for branch_scope in &scope.union_scopes {
620 collect_relation_facts(branch_scope, mapping_schema, seen, relations);
621 }
622}
623
624fn source_columns(
625 source: &SourceInfo,
626 mapping_schema: Option<&crate::schema::MappingSchema>,
627) -> Vec<String> {
628 match &source.expression {
629 Expression::Table(table) => mapping_schema
630 .and_then(|schema| schema.column_names(&table_name(table)).ok())
631 .unwrap_or_default(),
632 Expression::Select(_)
633 | Expression::Union(_)
634 | Expression::Intersect(_)
635 | Expression::Except(_) => get_output_column_names(&source.expression),
636 Expression::Subquery(subquery) => get_output_column_names(&subquery.this),
637 Expression::Cte(cte) if !cte.columns.is_empty() => cte
638 .columns
639 .iter()
640 .map(|column| column.name.clone())
641 .collect(),
642 Expression::Cte(cte) => get_output_column_names(&cte.this),
643 _ => Vec::new(),
644 }
645}
646
647fn source_table_name(source: &SourceInfo) -> Option<String> {
648 match &source.expression {
649 Expression::Table(table) => Some(table_name(table)),
650 _ => None,
651 }
652}
653
654fn source_alias(source: &SourceInfo) -> Option<String> {
655 match &source.expression {
656 Expression::Table(table) => table.alias.as_ref().map(|alias| alias.name.clone()),
657 Expression::Subquery(subquery) => subquery.alias.as_ref().map(|alias| alias.name.clone()),
658 _ => None,
659 }
660}
661
662fn table_name(table: &TableRef) -> String {
663 let mut parts = Vec::new();
664 if let Some(catalog) = &table.catalog {
665 parts.push(catalog.name.clone());
666 }
667 if let Some(schema) = &table.schema {
668 parts.push(schema.name.clone());
669 }
670 parts.push(table.name.name.clone());
671 parts.join(".")
672}
673
674fn set_operation_facts(
675 expression: &Expression,
676 scope: &Scope,
677 dialect: DialectType,
678) -> Vec<SetOperationFact> {
679 let mut facts = Vec::new();
680 collect_set_operation_facts(expression, scope, dialect, &mut facts);
681 facts
682}
683
684fn collect_set_operation_facts(
685 expression: &Expression,
686 scope: &Scope,
687 dialect: DialectType,
688 facts: &mut Vec<SetOperationFact>,
689) {
690 match expression {
691 Expression::Union(union) => {
692 facts.push(SetOperationFact {
693 kind: "union".to_string(),
694 all: union.all,
695 distinct: union.distinct,
696 output_columns: get_output_column_names(expression),
697 branches: set_operation_branches(&union.left, &union.right, scope, dialect),
698 });
699 collect_set_operation_facts(&union.left, scope, dialect, facts);
700 collect_set_operation_facts(&union.right, scope, dialect, facts);
701 }
702 Expression::Intersect(intersect) => {
703 facts.push(SetOperationFact {
704 kind: "intersect".to_string(),
705 all: intersect.all,
706 distinct: intersect.distinct,
707 output_columns: get_output_column_names(expression),
708 branches: set_operation_branches(&intersect.left, &intersect.right, scope, dialect),
709 });
710 collect_set_operation_facts(&intersect.left, scope, dialect, facts);
711 collect_set_operation_facts(&intersect.right, scope, dialect, facts);
712 }
713 Expression::Except(except) => {
714 facts.push(SetOperationFact {
715 kind: "except".to_string(),
716 all: except.all,
717 distinct: except.distinct,
718 output_columns: get_output_column_names(expression),
719 branches: set_operation_branches(&except.left, &except.right, scope, dialect),
720 });
721 collect_set_operation_facts(&except.left, scope, dialect, facts);
722 collect_set_operation_facts(&except.right, scope, dialect, facts);
723 }
724 Expression::Subquery(subquery) => {
725 collect_set_operation_facts(&subquery.this, scope, dialect, facts);
726 }
727 _ => {}
728 }
729}
730
731fn set_operation_branches(
732 left: &Expression,
733 right: &Expression,
734 scope: &Scope,
735 dialect: DialectType,
736) -> Vec<SetOperationBranchFact> {
737 vec![
738 SetOperationBranchFact {
739 index: 0,
740 projections: projection_facts_for_branch(left, scope, dialect),
741 },
742 SetOperationBranchFact {
743 index: 1,
744 projections: projection_facts_for_branch(right, scope, dialect),
745 },
746 ]
747}
748
749fn projection_facts_for_branch(
750 expression: &Expression,
751 root_scope: &Scope,
752 dialect: DialectType,
753) -> Vec<ProjectionFact> {
754 let branch_scope = build_scope(expression);
755 let scope = if branch_scope.sources.is_empty() {
756 root_scope
757 } else {
758 &branch_scope
759 };
760 projection_facts_for_query(expression, scope, dialect)
761}
762
/// Maps an empty string to `None` and wraps any other string in `Some`.
///
/// Only the empty string is treated as absent; whitespace is kept.
fn non_empty_string(value: String) -> Option<String> {
    Some(value).filter(|text| !text.is_empty())
}