Skip to main content

polyglot_sql/
resolver.rs

//! Column Resolver Module
//!
//! This module provides functionality for resolving column references to their
//! source tables. It handles:
//! - Finding which table a column belongs to
//! - Resolving ambiguous column references
//! - Handling join context for disambiguation
//! - Supporting set operations (UNION, INTERSECT, EXCEPT)
//!
//! Based on the Python implementation in `sqlglot/optimizer/resolver.py`.
11
12use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier, TableRef};
14use crate::generator::Generator;
15use crate::schema::{normalize_name, Schema};
16use crate::scope::{Scope, SourceInfo};
17use crate::traversal::ExpressionWalk;
18use std::collections::{HashMap, HashSet};
19use thiserror::Error;
20
21/// Errors that can occur during column resolution
22#[derive(Debug, Error, Clone)]
23pub enum ResolverError {
24    #[error("Unknown table: {0}")]
25    UnknownTable(String),
26
27    #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
28    AmbiguousColumn { column: String, sources: String },
29
30    #[error("Column not found: {0}")]
31    ColumnNotFound(String),
32
33    #[error("Unknown set operation: {0}")]
34    UnknownSetOperation(String),
35}
36
37/// Result type for resolver operations
38pub type ResolverResult<T> = Result<T, ResolverError>;
39
40/// Helper for resolving columns to their source tables.
41///
42/// This is a struct so we can lazily load some things and easily share
43/// them across functions.
44pub struct Resolver<'a> {
45    /// The scope being analyzed
46    pub scope: &'a Scope,
47    /// The schema for table/column information
48    schema: &'a dyn Schema,
49    /// The dialect being used
50    pub dialect: Option<DialectType>,
51    /// Whether to infer schema from context
52    infer_schema: bool,
53    /// Cached source columns: source_name -> column names
54    source_columns_cache: HashMap<String, Vec<String>>,
55    /// Cached unambiguous columns: column_name -> source_name
56    unambiguous_columns_cache: Option<HashMap<String, String>>,
57    /// Cached set of all available columns
58    all_columns_cache: Option<HashSet<String>>,
59}
60
61impl<'a> Resolver<'a> {
62    /// Create a new resolver for a scope
63    pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
64        Self {
65            scope,
66            schema,
67            dialect: schema.dialect(),
68            infer_schema,
69            source_columns_cache: HashMap::new(),
70            unambiguous_columns_cache: None,
71            all_columns_cache: None,
72        }
73    }
74
75    /// Get the table for a column name.
76    ///
77    /// Returns the table name if it can be found/inferred.
78    pub fn get_table(&mut self, column_name: &str) -> Option<String> {
79        // Try to find table from all sources (unambiguous lookup)
80        let table_name = self.get_table_name_from_sources(column_name, None);
81
82        // If we found a table, return it
83        if table_name.is_some() {
84            return table_name;
85        }
86
87        // If schema inference is enabled and exactly one source has no schema,
88        // assume the column belongs to that source
89        if self.infer_schema {
90            let sources_without_schema: Vec<_> = self
91                .get_all_source_columns()
92                .iter()
93                .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
94                .map(|(name, _)| name.clone())
95                .collect();
96
97            if sources_without_schema.len() == 1 {
98                return Some(sources_without_schema[0].clone());
99            }
100        }
101
102        None
103    }
104
105    /// Get the table for a column, returning an Identifier
106    pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
107        self.get_table(column_name).map(Identifier::new)
108    }
109
110    /// Check if a table exists in the schema (not necessarily in the current scope).
111    /// Used to detect correlated references to outer scope tables.
112    pub fn table_exists_in_schema(&self, table_name: &str) -> bool {
113        self.schema.column_names(table_name).is_ok()
114    }
115
116    /// Find the table for a column by searching all schema tables not in the current scope.
117    /// Used for correlated subquery resolution: if an unqualified column can't be resolved
118    /// in the current scope, check if it uniquely belongs to an outer-scope table.
119    /// Returns Some(table_name) if the column is found in exactly one non-local table.
120    pub fn find_column_in_outer_schema_tables(&self, column_name: &str) -> Option<String> {
121        let tables = self.schema.find_tables_for_column(column_name);
122        // Filter to tables NOT in the current scope
123        let outer_tables: Vec<String> = tables
124            .into_iter()
125            .filter(|t| !self.scope.sources.contains_key(t))
126            .collect();
127        // Only return if unambiguous (exactly one outer table has this column)
128        if outer_tables.len() == 1 {
129            Some(outer_tables.into_iter().next().unwrap())
130        } else {
131            None
132        }
133    }
134
135    /// Get all available columns across all sources in this scope
136    pub fn all_columns(&mut self) -> &HashSet<String> {
137        if self.all_columns_cache.is_none() {
138            let mut all = HashSet::new();
139            for columns in self.get_all_source_columns().values() {
140                all.extend(columns.iter().cloned());
141            }
142            self.all_columns_cache = Some(all);
143        }
144        self.all_columns_cache
145            .as_ref()
146            .expect("cache populated above")
147    }
148
149    /// Get column names for a source.
150    ///
151    /// Returns the list of column names available from the given source.
152    pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
153        // Check cache first
154        if let Some(columns) = self.source_columns_cache.get(source_name) {
155            return Ok(columns.clone());
156        }
157
158        // Get the source info
159        let source_info = self
160            .scope
161            .sources
162            .get(source_name)
163            .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
164
165        let columns = self.extract_columns_from_source(source_info)?;
166
167        // Cache the result
168        self.source_columns_cache
169            .insert(source_name.to_string(), columns.clone());
170
171        Ok(columns)
172    }
173
174    /// Extract column names from a source expression
175    fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
176        self.get_source_columns_for_expression(&source_info.expression)
177    }
178
179    fn get_source_columns_for_expression(
180        &self,
181        expression: &Expression,
182    ) -> ResolverResult<Vec<String>> {
183        let columns = match expression {
184            Expression::Table(table) => {
185                // For tables, try to get columns from schema.
186                // Build the fully qualified name (catalog.schema.table) to
187                // match how MappingSchema stores hierarchical keys.
188                let table_name = qualified_table_name(table);
189                match self.schema.column_names(&table_name) {
190                    Ok(cols) => cols,
191                    Err(_) => Vec::new(), // Schema might not have this table
192                }
193            }
194            Expression::Subquery(subquery) => {
195                // For subqueries, get named_selects from the inner query
196                self.get_named_selects(&subquery.this)
197            }
198            Expression::Select(select) => {
199                // For derived tables that are SELECT expressions
200                self.get_select_column_names(select)
201            }
202            Expression::Union(union) => {
203                // For UNION, columns come from the set operation
204                self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
205            }
206            Expression::Intersect(intersect) => {
207                self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
208            }
209            Expression::Except(except) => {
210                self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
211            }
212            Expression::Cte(cte) => {
213                if !cte.columns.is_empty() {
214                    cte.columns.iter().map(|c| c.name.clone()).collect()
215                } else {
216                    self.get_named_selects(&cte.this)
217                }
218            }
219            Expression::Pivot(pivot) => self.get_pivot_output_columns(pivot),
220            Expression::Unpivot(unpivot) => self.get_unpivot_output_columns(unpivot),
221            Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
222                alias_output_columns(alias)
223            }
224            Expression::Alias(alias) => {
225                let columns = self.get_source_columns_for_expression(&alias.this)?;
226                apply_alias_columns(columns, &alias.column_aliases)
227            }
228            Expression::Unnest(unnest) => unnest_output_columns(unnest),
229            Expression::Lateral(lateral) => lateral_output_columns(lateral),
230            Expression::LateralView(lateral_view) => lateral_view_output_columns(lateral_view),
231            Expression::Paren(paren) => self.get_source_columns_for_expression(&paren.this)?,
232            _ => Vec::new(),
233        };
234
235        Ok(columns)
236    }
237
238    /// Get named selects (column names) from an expression
239    fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
240        match expr {
241            Expression::Select(select) => self.get_select_column_names(select),
242            Expression::Union(union) => {
243                // For unions, use the left side's columns
244                self.get_named_selects(&union.left)
245            }
246            Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
247            Expression::Except(except) => self.get_named_selects(&except.left),
248            Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
249            Expression::Alias(alias) => {
250                let columns = self.get_named_selects(&alias.this);
251                apply_alias_columns(columns, &alias.column_aliases)
252            }
253            Expression::Paren(paren) => self.get_named_selects(&paren.this),
254            _ => Vec::new(),
255        }
256    }
257
258    /// Get column names from a SELECT expression
259    fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
260        select
261            .expressions
262            .iter()
263            .filter_map(|expr| self.get_expression_alias(expr))
264            .collect()
265    }
266
267    /// Get the alias or name for a select expression
268    fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
269        match expr {
270            Expression::Alias(alias) => Some(alias.alias.name.clone()),
271            Expression::Column(col) => Some(col.name.name.clone()),
272            Expression::Star(_) => Some("*".to_string()),
273            Expression::Identifier(id) => Some(id.name.clone()),
274            _ => None,
275        }
276    }
277
278    fn get_pivot_output_columns(&self, pivot: &crate::expressions::Pivot) -> Vec<String> {
279        if pivot.unpivot {
280            return self.get_pivot_unpivot_output_columns(pivot);
281        }
282
283        let pre_columns = self.get_source_output_columns(&pivot.this);
284        if pre_columns.is_empty() || pre_columns.iter().any(|column| column == "*") {
285            return Vec::new();
286        }
287
288        let excluded = pivot_excluded_source_columns(pivot, self.dialect);
289        let generated = pivot_generated_output_columns(pivot, self.dialect);
290        if excluded.is_empty() || generated.is_empty() {
291            return Vec::new();
292        }
293
294        let mut columns: Vec<String> = pre_columns
295            .into_iter()
296            .filter(|column| !excluded.contains(&normalize_column_name(column, self.dialect)))
297            .collect();
298        columns.extend(generated);
299        apply_alias_columns(columns, &pivot.alias_columns)
300    }
301
302    fn get_pivot_unpivot_output_columns(&self, pivot: &crate::expressions::Pivot) -> Vec<String> {
303        let pre_columns = self.get_source_output_columns(&pivot.this);
304        if pre_columns.is_empty() || pre_columns.iter().any(|column| column == "*") {
305            return Vec::new();
306        }
307
308        let input_columns: HashSet<String> = pivot
309            .expressions
310            .iter()
311            .flat_map(expression_column_names)
312            .map(|column| normalize_column_name(&column, self.dialect))
313            .collect();
314        let mut columns: Vec<String> = pre_columns
315            .into_iter()
316            .filter(|column| !input_columns.contains(&normalize_column_name(column, self.dialect)))
317            .collect();
318
319        if let Some(Expression::UnpivotColumns(unpivot_columns)) = pivot.into.as_deref() {
320            if let Some(name) = expression_name(&unpivot_columns.this) {
321                columns.push(name);
322            }
323            for value_column in &unpivot_columns.expressions {
324                if let Some(name) = expression_name(value_column) {
325                    columns.push(name);
326                }
327            }
328        }
329
330        apply_alias_columns(columns, &pivot.alias_columns)
331    }
332
333    fn get_unpivot_output_columns(&self, unpivot: &crate::expressions::Unpivot) -> Vec<String> {
334        let pre_columns = self.get_source_output_columns(&unpivot.this);
335        if pre_columns.is_empty() || pre_columns.iter().any(|column| column == "*") {
336            return Vec::new();
337        }
338
339        let input_columns: HashSet<String> = unpivot
340            .columns
341            .iter()
342            .flat_map(expression_column_names)
343            .map(|column| normalize_column_name(&column, self.dialect))
344            .collect();
345        let mut columns: Vec<String> = pre_columns
346            .into_iter()
347            .filter(|column| !input_columns.contains(&normalize_column_name(column, self.dialect)))
348            .collect();
349        columns.push(unpivot.name_column.name.clone());
350        columns.push(unpivot.value_column.name.clone());
351        columns.extend(
352            unpivot
353                .extra_value_columns
354                .iter()
355                .map(|column| column.name.clone()),
356        );
357        apply_alias_columns(columns, &unpivot.alias_columns)
358    }
359
360    fn get_source_output_columns(&self, source: &Expression) -> Vec<String> {
361        match source {
362            Expression::Table(table) => {
363                if table.schema.is_none() && table.catalog.is_none() {
364                    if let Some(source) = self.scope.cte_sources.get(&table.name.name) {
365                        return self.extract_columns_from_source(source).unwrap_or_default();
366                    }
367                }
368
369                let table_name = qualified_table_name(table);
370                self.schema.column_names(&table_name).unwrap_or_default()
371            }
372            Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
373            Expression::Select(select) => self.get_select_column_names(select),
374            Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_) => self
375                .get_source_columns_from_set_op(source)
376                .unwrap_or_default(),
377            Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
378                alias_output_columns(alias)
379            }
380            Expression::Alias(alias) => {
381                let columns = self.get_source_output_columns(&alias.this);
382                apply_alias_columns(columns, &alias.column_aliases)
383            }
384            Expression::Unnest(unnest) => unnest_output_columns(unnest),
385            Expression::Lateral(lateral) => lateral_output_columns(lateral),
386            Expression::LateralView(lateral_view) => lateral_view_output_columns(lateral_view),
387            Expression::Cte(cte) => {
388                if cte.columns.is_empty() {
389                    self.get_named_selects(&cte.this)
390                } else {
391                    cte.columns
392                        .iter()
393                        .map(|column| column.name.clone())
394                        .collect()
395                }
396            }
397            Expression::Paren(paren) => self.get_source_output_columns(&paren.this),
398            _ => Vec::new(),
399        }
400    }
401
402    /// Get columns from a set operation (UNION, INTERSECT, EXCEPT)
403    pub fn get_source_columns_from_set_op(
404        &self,
405        expression: &Expression,
406    ) -> ResolverResult<Vec<String>> {
407        match expression {
408            Expression::Select(select) => Ok(self.get_select_column_names(select)),
409            Expression::Subquery(subquery) => {
410                if matches!(
411                    &subquery.this,
412                    Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
413                ) {
414                    self.get_source_columns_from_set_op(&subquery.this)
415                } else {
416                    Ok(self.get_named_selects(&subquery.this))
417                }
418            }
419            Expression::Alias(alias) => {
420                let columns = self.get_source_columns_from_set_op(&alias.this)?;
421                Ok(apply_alias_columns(columns, &alias.column_aliases))
422            }
423            Expression::Paren(paren) => self.get_source_columns_from_set_op(&paren.this),
424            Expression::Union(union) => {
425                // Standard UNION: columns come from the left side
426                self.get_source_columns_from_set_op(&union.left)
427            }
428            Expression::Intersect(intersect) => {
429                self.get_source_columns_from_set_op(&intersect.left)
430            }
431            Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
432            _ => Err(ResolverError::UnknownSetOperation(format!(
433                "{:?}",
434                expression
435            ))),
436        }
437    }
438
439    /// Get all source columns for all sources in the scope
440    fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
441        let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
442
443        let mut result = HashMap::new();
444        for source_name in source_names {
445            if let Ok(columns) = self.get_source_columns(&source_name) {
446                result.insert(source_name, columns);
447            }
448        }
449        result
450    }
451
452    /// Get the table name for a column from the sources
453    fn get_table_name_from_sources(
454        &mut self,
455        column_name: &str,
456        source_columns: Option<&HashMap<String, Vec<String>>>,
457    ) -> Option<String> {
458        let normalized_column_name = normalize_column_name(column_name, self.dialect);
459        let unambiguous = match source_columns {
460            Some(cols) => self.compute_unambiguous_columns(cols),
461            None => {
462                if self.unambiguous_columns_cache.is_none() {
463                    let all_source_columns = self.get_all_source_columns();
464                    self.unambiguous_columns_cache =
465                        Some(self.compute_unambiguous_columns(&all_source_columns));
466                }
467                self.unambiguous_columns_cache
468                    .clone()
469                    .expect("cache populated above")
470            }
471        };
472
473        unambiguous.get(&normalized_column_name).cloned()
474    }
475
476    /// Compute unambiguous columns mapping
477    ///
478    /// A column is unambiguous if it appears in exactly one source.
479    fn compute_unambiguous_columns(
480        &self,
481        source_columns: &HashMap<String, Vec<String>>,
482    ) -> HashMap<String, String> {
483        if source_columns.is_empty() {
484            return HashMap::new();
485        }
486
487        let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
488
489        for (source_name, columns) in source_columns {
490            for column in columns {
491                column_to_sources
492                    .entry(normalize_column_name(column, self.dialect))
493                    .or_default()
494                    .push(source_name.clone());
495            }
496        }
497
498        // Keep only columns that appear in exactly one source
499        column_to_sources
500            .into_iter()
501            .filter(|(_, sources)| sources.len() == 1)
502            .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
503            .collect()
504    }
505
506    /// Check if a column is ambiguous (appears in multiple sources)
507    pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
508        let normalized_column_name = normalize_column_name(column_name, self.dialect);
509        let all_source_columns = self.get_all_source_columns();
510        let sources_with_column: Vec<_> = all_source_columns
511            .iter()
512            .filter(|(_, columns)| {
513                columns.iter().any(|column| {
514                    normalize_column_name(column, self.dialect) == normalized_column_name
515                })
516            })
517            .map(|(name, _)| name.clone())
518            .collect();
519
520        sources_with_column.len() > 1
521    }
522
523    /// Get all sources that contain a given column
524    pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
525        let normalized_column_name = normalize_column_name(column_name, self.dialect);
526        let all_source_columns = self.get_all_source_columns();
527        all_source_columns
528            .iter()
529            .filter(|(_, columns)| {
530                columns.iter().any(|column| {
531                    normalize_column_name(column, self.dialect) == normalized_column_name
532                })
533            })
534            .map(|(name, _)| name.clone())
535            .collect()
536    }
537
538    /// Try to disambiguate a column based on join context
539    ///
540    /// In join conditions, a column can sometimes be disambiguated based on
541    /// which tables have been joined up to that point.
542    pub fn disambiguate_in_join_context(
543        &mut self,
544        column_name: &str,
545        available_sources: &[String],
546    ) -> Option<String> {
547        let normalized_column_name = normalize_column_name(column_name, self.dialect);
548        let mut matching_sources = Vec::new();
549
550        for source_name in available_sources {
551            if let Ok(columns) = self.get_source_columns(source_name) {
552                if columns.iter().any(|column| {
553                    normalize_column_name(column, self.dialect) == normalized_column_name
554                }) {
555                    matching_sources.push(source_name.clone());
556                }
557            }
558        }
559
560        if matching_sources.len() == 1 {
561            Some(matching_sources.remove(0))
562        } else {
563            None
564        }
565    }
566}
567
568fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
569    normalize_name(name, dialect, false, true)
570}
571
572fn apply_alias_columns(mut columns: Vec<String>, alias_columns: &[Identifier]) -> Vec<String> {
573    for (idx, alias) in alias_columns.iter().enumerate() {
574        if let Some(column) = columns.get_mut(idx) {
575            *column = alias.name.clone();
576        }
577    }
578    columns
579}
580
581fn unnest_output_columns(unnest: &crate::expressions::UnnestFunc) -> Vec<String> {
582    unnest
583        .alias
584        .iter()
585        .map(|alias| alias.name.clone())
586        .chain(unnest.offset_alias.iter().map(|alias| alias.name.clone()))
587        .collect()
588}
589
590fn alias_output_columns(alias: &crate::expressions::Alias) -> Vec<String> {
591    if alias.column_aliases.is_empty() {
592        vec![alias.alias.name.clone()]
593    } else {
594        alias
595            .column_aliases
596            .iter()
597            .map(|column| column.name.clone())
598            .collect()
599    }
600}
601
602fn lateral_output_columns(lateral: &crate::expressions::Lateral) -> Vec<String> {
603    if lateral.column_aliases.is_empty() {
604        default_virtual_output_columns(&lateral.this)
605    } else {
606        lateral.column_aliases.clone()
607    }
608}
609
610fn lateral_view_output_columns(lateral_view: &crate::expressions::LateralView) -> Vec<String> {
611    lateral_view
612        .column_aliases
613        .iter()
614        .map(|column| column.name.clone())
615        .collect()
616}
617
618fn default_virtual_output_columns(expression: &Expression) -> Vec<String> {
619    match expression {
620        Expression::Unnest(unnest) => unnest_output_columns(unnest),
621        Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
622            alias_output_columns(alias)
623        }
624        Expression::Function(function) if function.name.eq_ignore_ascii_case("FLATTEN") => {
625            ["seq", "key", "path", "index", "value", "this"]
626                .into_iter()
627                .map(String::from)
628                .collect()
629        }
630        _ => Vec::new(),
631    }
632}
633
634fn pivot_excluded_source_columns(
635    pivot: &crate::expressions::Pivot,
636    dialect: Option<DialectType>,
637) -> HashSet<String> {
638    pivot
639        .fields
640        .iter()
641        .chain(pivot.expressions.iter())
642        .chain(pivot.using.iter())
643        .flat_map(expression_column_names)
644        .map(|column| normalize_column_name(&column, dialect))
645        .collect()
646}
647
648fn pivot_generated_output_columns(
649    pivot: &crate::expressions::Pivot,
650    _dialect: Option<DialectType>,
651) -> Vec<String> {
652    let fields = pivot_field_output_names(pivot);
653    let aggregations = if pivot.using.is_empty() {
654        &pivot.expressions
655    } else {
656        &pivot.using
657    };
658
659    if fields.is_empty() || aggregations.is_empty() {
660        return Vec::new();
661    }
662
663    let needs_suffix = aggregations.len() > 1;
664    let mut outputs = Vec::new();
665    for field in fields {
666        for aggregation in aggregations {
667            if let Some(suffix) = pivot_aggregation_output_suffix(aggregation, needs_suffix) {
668                outputs.push(format!("{field}_{suffix}"));
669            } else {
670                outputs.push(field.clone());
671            }
672        }
673    }
674    outputs
675}
676
677fn pivot_field_output_names(pivot: &crate::expressions::Pivot) -> Vec<String> {
678    pivot
679        .fields
680        .iter()
681        .filter_map(|field| match field {
682            Expression::In(in_expr) => Some(
683                in_expr
684                    .expressions
685                    .iter()
686                    .filter_map(expression_name)
687                    .collect::<Vec<_>>(),
688            ),
689            _ => None,
690        })
691        .flatten()
692        .collect()
693}
694
695fn pivot_aggregation_output_suffix(expr: &Expression, needs_suffix: bool) -> Option<String> {
696    match expr {
697        Expression::Alias(alias) => Some(alias.alias.name.clone()),
698        _ if needs_suffix => Generator::sql(expr).ok().map(|sql| sql.to_lowercase()),
699        _ => None,
700    }
701}
702
703fn expression_name(expr: &Expression) -> Option<String> {
704    match expr {
705        Expression::PivotAlias(alias) => expression_name(&alias.alias),
706        Expression::Alias(alias) => Some(alias.alias.name.clone()),
707        Expression::Identifier(identifier) => Some(identifier.name.clone()),
708        Expression::Column(column) => Some(column.name.name.clone()),
709        Expression::Literal(literal) => Some(literal.value_str().to_string()),
710        Expression::Var(var) => Some(var.this.clone()),
711        Expression::Tuple(tuple) => tuple.expressions.first().and_then(expression_name),
712        _ => None,
713    }
714}
715
716fn expression_column_names(expr: &Expression) -> Vec<String> {
717    expr.find_all(|node| matches!(node, Expression::Column(_)))
718        .into_iter()
719        .filter_map(|node| match node {
720            Expression::Column(column) => Some(column.name.name.clone()),
721            _ => None,
722        })
723        .collect()
724}
725
726/// Resolve a column to its source table.
727///
728/// This is a convenience function that creates a Resolver and calls get_table.
729pub fn resolve_column(
730    scope: &Scope,
731    schema: &dyn Schema,
732    column_name: &str,
733    infer_schema: bool,
734) -> Option<String> {
735    let mut resolver = Resolver::new(scope, schema, infer_schema);
736    resolver.get_table(column_name)
737}
738
739/// Check if a column is ambiguous in the given scope.
740pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
741    let mut resolver = Resolver::new(scope, schema, true);
742    resolver.is_ambiguous(column_name)
743}
744
745/// Build the fully qualified table name (catalog.schema.table) from a TableRef.
746fn qualified_table_name(table: &TableRef) -> String {
747    let mut parts = Vec::new();
748    if let Some(catalog) = &table.catalog {
749        parts.push(catalog.name.clone());
750    }
751    if let Some(schema) = &table.schema {
752        parts.push(schema.name.clone());
753    }
754    parts.push(table.name.name.clone());
755    parts.join(".")
756}
757
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// The plain integer column type used throughout the fixture tables.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Build the two-table fixture shared by most tests:
    /// `users(id, name, email)` and `orders(id, user_id, amount)`.
    fn create_test_schema() -> MappingSchema {
        let users_columns = [
            ("id".to_string(), int_type()),
            ("name".to_string(), DataType::Text),
            ("email".to_string(), DataType::Text),
        ];
        let orders_columns = [
            ("id".to_string(), int_type()),
            ("user_id".to_string(), int_type()),
            (
                "amount".to_string(),
                DataType::Double {
                    precision: None,
                    scale: None,
                },
            ),
        ];

        let mut fixture = MappingSchema::new();
        fixture.add_table("users", &users_columns, None).unwrap();
        fixture.add_table("orders", &orders_columns, None).unwrap();
        fixture
    }

    /// A column in a single-source query resolves to that source.
    #[test]
    fn test_resolver_basic() {
        let statements = Parser::parse_sql("SELECT id, name FROM users").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `users` is the only source, so `name` must come from it.
        assert_eq!(resolver.get_table("name").as_deref(), Some("users"));
    }

    /// Columns shared by both joined tables are ambiguous; the rest are not.
    #[test]
    fn test_resolver_ambiguous_column() {
        let statements =
            Parser::parse_sql("SELECT id FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `id` exists in both tables; `name` is users-only; `amount` is orders-only.
        let expectations = [("id", true), ("name", false), ("amount", false)];
        for (column, expected) in expectations {
            assert_eq!(
                resolver.is_ambiguous(column),
                expected,
                "unexpected ambiguity verdict for column `{}`",
                column
            );
        }
    }

    /// Columns unique to one side of a join resolve to that side.
    #[test]
    fn test_resolver_unambiguous_column() {
        let statements = Parser::parse_sql(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
        )
        .expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `name` lives only in `users`.
        assert_eq!(resolver.get_table("name").as_deref(), Some("users"));

        // `amount` lives only in `orders`.
        assert_eq!(resolver.get_table("amount").as_deref(), Some("orders"));
    }

    /// An aliased table is registered in the scope under its alias.
    #[test]
    fn test_resolver_with_alias() {
        let statements =
            Parser::parse_sql("SELECT u.id FROM users AS u").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // The source map is keyed by the alias `u`.
        assert!(scope.sources.contains_key("u"));
    }

    /// `sources_for_column` lists every source that exposes the column.
    #[test]
    fn test_sources_for_column() {
        let statements =
            Parser::parse_sql("SELECT * FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `id` is exposed by both joined tables.
        let id_sources = resolver.sources_for_column("id");
        assert!(id_sources.iter().any(|source| source == "users"));
        assert!(id_sources.iter().any(|source| source == "orders"));

        // `email` is exposed by `users` alone.
        let email_sources = resolver.sources_for_column("email");
        assert_eq!(email_sources, vec![String::from("users")]);
    }

    /// `all_columns` includes every column of the single source.
    #[test]
    fn test_all_columns() {
        let statements = Parser::parse_sql("SELECT * FROM users").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        for column in ["id", "name", "email"] {
            assert!(all.contains(column), "missing column `{}`", column);
        }
    }

    /// A column alias projected by a CTE resolves to the CTE name.
    #[test]
    fn test_resolver_cte_projected_alias_column() {
        let statements = Parser::parse_sql(
            "WITH my_cte AS (SELECT id AS emp_id FROM users) SELECT emp_id FROM my_cte",
        )
        .expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("emp_id").as_deref(), Some("my_cte"));
    }

    /// The free-function wrapper gives the same answer as the resolver.
    #[test]
    fn test_resolve_column_helper() {
        let statements = Parser::parse_sql("SELECT name FROM users").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();

        let resolved = resolve_column(&scope, &schema, "name", true);
        assert_eq!(resolved.as_deref(), Some("users"));
    }

    /// With a BigQuery schema, both spellings of a mixed-case column resolve.
    #[test]
    fn test_resolver_bigquery_mixed_case_column_names() {
        let bigquery = Dialect::get(DialectType::BigQuery);
        let expr = bigquery
            .parse("SELECT Name AS name FROM teams")
            .unwrap()
            .into_iter()
            .next()
            .expect("expected one expression");
        let scope = build_scope(&expr);

        let teams_columns = [(String::from("Name"), DataType::String { length: None })];
        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
        schema
            .add_table("teams", &teams_columns, None)
            .expect("schema setup");

        let mut resolver = Resolver::new(&scope, &schema, true);
        for spelling in ["Name", "name"] {
            assert_eq!(
                resolver.get_table(spelling).as_deref(),
                Some("teams"),
                "column spelling `{}` did not resolve to `teams`",
                spelling
            );
        }
    }
}