Skip to main content

polyglot_sql/
resolver.rs

1//! Column Resolver Module
2//!
3//! This module provides functionality for resolving column references to their
4//! source tables. It handles:
5//! - Finding which table a column belongs to
6//! - Resolving ambiguous column references
7//! - Handling join context for disambiguation
8//! - Supporting set operations (UNION, INTERSECT, EXCEPT)
9//!
10//! Based on the Python implementation in `sqlglot/optimizer/resolver.py`.
11
12use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier, TableRef};
14use crate::generator::Generator;
15use crate::schema::{normalize_name, Schema};
16use crate::scope::{Scope, SourceInfo};
17use crate::traversal::ExpressionWalk;
18use std::collections::{HashMap, HashSet};
19use thiserror::Error;
20
21/// Errors that can occur during column resolution
22#[derive(Debug, Error, Clone)]
23pub enum ResolverError {
24    #[error("Unknown table: {0}")]
25    UnknownTable(String),
26
27    #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
28    AmbiguousColumn { column: String, sources: String },
29
30    #[error("Column not found: {0}")]
31    ColumnNotFound(String),
32
33    #[error("Unknown set operation: {0}")]
34    UnknownSetOperation(String),
35}
36
37/// Result type for resolver operations
38pub type ResolverResult<T> = Result<T, ResolverError>;
39
40/// Helper for resolving columns to their source tables.
41///
42/// This is a struct so we can lazily load some things and easily share
43/// them across functions.
44pub struct Resolver<'a> {
45    /// The scope being analyzed
46    pub scope: &'a Scope,
47    /// The schema for table/column information
48    schema: &'a dyn Schema,
49    /// The dialect being used
50    pub dialect: Option<DialectType>,
51    /// Whether to infer schema from context
52    infer_schema: bool,
53    /// Cached source columns: source_name -> column names
54    source_columns_cache: HashMap<String, Vec<String>>,
55    /// Cached unambiguous columns: column_name -> source_name
56    unambiguous_columns_cache: Option<HashMap<String, String>>,
57    /// Cached set of all available columns
58    all_columns_cache: Option<HashSet<String>>,
59}
60
61impl<'a> Resolver<'a> {
62    /// Create a new resolver for a scope
63    pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
64        Self {
65            scope,
66            schema,
67            dialect: schema.dialect(),
68            infer_schema,
69            source_columns_cache: HashMap::new(),
70            unambiguous_columns_cache: None,
71            all_columns_cache: None,
72        }
73    }
74
75    /// Get the table for a column name.
76    ///
77    /// Returns the table name if it can be found/inferred.
78    pub fn get_table(&mut self, column_name: &str) -> Option<String> {
79        // Try to find table from all sources (unambiguous lookup)
80        let table_name = self.get_table_name_from_sources(column_name, None);
81
82        // If we found a table, return it
83        if table_name.is_some() {
84            return table_name;
85        }
86
87        // If schema inference is enabled and exactly one source has no schema,
88        // assume the column belongs to that source
89        if self.infer_schema {
90            let sources_without_schema: Vec<_> = self
91                .get_all_source_columns()
92                .iter()
93                .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
94                .map(|(name, _)| name.clone())
95                .collect();
96
97            if sources_without_schema.len() == 1 {
98                return Some(sources_without_schema[0].clone());
99            }
100        }
101
102        None
103    }
104
105    /// Get the table for a column, returning an Identifier
106    pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
107        self.get_table(column_name).map(Identifier::new)
108    }
109
110    /// Check if a table exists in the schema (not necessarily in the current scope).
111    /// Used to detect correlated references to outer scope tables.
112    pub fn table_exists_in_schema(&self, table_name: &str) -> bool {
113        self.schema.column_names(table_name).is_ok()
114    }
115
116    /// Find the table for a column by searching all schema tables not in the current scope.
117    /// Used for correlated subquery resolution: if an unqualified column can't be resolved
118    /// in the current scope, check if it uniquely belongs to an outer-scope table.
119    /// Returns Some(table_name) if the column is found in exactly one non-local table.
120    pub fn find_column_in_outer_schema_tables(&self, column_name: &str) -> Option<String> {
121        let tables = self.schema.find_tables_for_column(column_name);
122        // Filter to tables NOT in the current scope
123        let outer_tables: Vec<String> = tables
124            .into_iter()
125            .filter(|t| !self.scope.sources.contains_key(t))
126            .collect();
127        // Only return if unambiguous (exactly one outer table has this column)
128        if outer_tables.len() == 1 {
129            Some(outer_tables.into_iter().next().unwrap())
130        } else {
131            None
132        }
133    }
134
135    /// Get all available columns across all sources in this scope
136    pub fn all_columns(&mut self) -> &HashSet<String> {
137        if self.all_columns_cache.is_none() {
138            let mut all = HashSet::new();
139            for columns in self.get_all_source_columns().values() {
140                all.extend(columns.iter().cloned());
141            }
142            self.all_columns_cache = Some(all);
143        }
144        self.all_columns_cache
145            .as_ref()
146            .expect("cache populated above")
147    }
148
149    /// Get column names for a source.
150    ///
151    /// Returns the list of column names available from the given source.
152    pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
153        // Check cache first
154        if let Some(columns) = self.source_columns_cache.get(source_name) {
155            return Ok(columns.clone());
156        }
157
158        // Get the source info
159        let source_info = self
160            .scope
161            .sources
162            .get(source_name)
163            .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
164
165        let columns = self.extract_columns_from_source(source_info)?;
166
167        // Cache the result
168        self.source_columns_cache
169            .insert(source_name.to_string(), columns.clone());
170
171        Ok(columns)
172    }
173
174    /// Extract column names from a source expression
175    fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
176        let columns = match &source_info.expression {
177            Expression::Table(table) => {
178                // For tables, try to get columns from schema.
179                // Build the fully qualified name (catalog.schema.table) to
180                // match how MappingSchema stores hierarchical keys.
181                let table_name = qualified_table_name(table);
182                match self.schema.column_names(&table_name) {
183                    Ok(cols) => cols,
184                    Err(_) => Vec::new(), // Schema might not have this table
185                }
186            }
187            Expression::Subquery(subquery) => {
188                // For subqueries, get named_selects from the inner query
189                self.get_named_selects(&subquery.this)
190            }
191            Expression::Select(select) => {
192                // For derived tables that are SELECT expressions
193                self.get_select_column_names(select)
194            }
195            Expression::Union(union) => {
196                // For UNION, columns come from the set operation
197                self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
198            }
199            Expression::Intersect(intersect) => {
200                self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
201            }
202            Expression::Except(except) => {
203                self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
204            }
205            Expression::Cte(cte) => {
206                if !cte.columns.is_empty() {
207                    cte.columns.iter().map(|c| c.name.clone()).collect()
208                } else {
209                    self.get_named_selects(&cte.this)
210                }
211            }
212            Expression::Pivot(pivot) => self.get_pivot_output_columns(pivot),
213            Expression::Unpivot(unpivot) => self.get_unpivot_output_columns(unpivot),
214            _ => Vec::new(),
215        };
216
217        Ok(columns)
218    }
219
220    /// Get named selects (column names) from an expression
221    fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
222        match expr {
223            Expression::Select(select) => self.get_select_column_names(select),
224            Expression::Union(union) => {
225                // For unions, use the left side's columns
226                self.get_named_selects(&union.left)
227            }
228            Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
229            Expression::Except(except) => self.get_named_selects(&except.left),
230            Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
231            _ => Vec::new(),
232        }
233    }
234
235    /// Get column names from a SELECT expression
236    fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
237        select
238            .expressions
239            .iter()
240            .filter_map(|expr| self.get_expression_alias(expr))
241            .collect()
242    }
243
244    /// Get the alias or name for a select expression
245    fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
246        match expr {
247            Expression::Alias(alias) => Some(alias.alias.name.clone()),
248            Expression::Column(col) => Some(col.name.name.clone()),
249            Expression::Star(_) => Some("*".to_string()),
250            Expression::Identifier(id) => Some(id.name.clone()),
251            _ => None,
252        }
253    }
254
255    fn get_pivot_output_columns(&self, pivot: &crate::expressions::Pivot) -> Vec<String> {
256        if pivot.unpivot {
257            return self.get_pivot_unpivot_output_columns(pivot);
258        }
259
260        let pre_columns = self.get_source_output_columns(&pivot.this);
261        if pre_columns.is_empty() || pre_columns.iter().any(|column| column == "*") {
262            return Vec::new();
263        }
264
265        let excluded = pivot_excluded_source_columns(pivot, self.dialect);
266        let generated = pivot_generated_output_columns(pivot, self.dialect);
267        if excluded.is_empty() || generated.is_empty() {
268            return Vec::new();
269        }
270
271        let mut columns: Vec<String> = pre_columns
272            .into_iter()
273            .filter(|column| !excluded.contains(&normalize_column_name(column, self.dialect)))
274            .collect();
275        columns.extend(generated);
276        apply_alias_columns(columns, &pivot.alias_columns)
277    }
278
279    fn get_pivot_unpivot_output_columns(&self, pivot: &crate::expressions::Pivot) -> Vec<String> {
280        let pre_columns = self.get_source_output_columns(&pivot.this);
281        if pre_columns.is_empty() || pre_columns.iter().any(|column| column == "*") {
282            return Vec::new();
283        }
284
285        let input_columns: HashSet<String> = pivot
286            .expressions
287            .iter()
288            .flat_map(expression_column_names)
289            .map(|column| normalize_column_name(&column, self.dialect))
290            .collect();
291        let mut columns: Vec<String> = pre_columns
292            .into_iter()
293            .filter(|column| !input_columns.contains(&normalize_column_name(column, self.dialect)))
294            .collect();
295
296        if let Some(Expression::UnpivotColumns(unpivot_columns)) = pivot.into.as_deref() {
297            if let Some(name) = expression_name(&unpivot_columns.this) {
298                columns.push(name);
299            }
300            for value_column in &unpivot_columns.expressions {
301                if let Some(name) = expression_name(value_column) {
302                    columns.push(name);
303                }
304            }
305        }
306
307        apply_alias_columns(columns, &pivot.alias_columns)
308    }
309
310    fn get_unpivot_output_columns(&self, unpivot: &crate::expressions::Unpivot) -> Vec<String> {
311        let pre_columns = self.get_source_output_columns(&unpivot.this);
312        if pre_columns.is_empty() || pre_columns.iter().any(|column| column == "*") {
313            return Vec::new();
314        }
315
316        let input_columns: HashSet<String> = unpivot
317            .columns
318            .iter()
319            .flat_map(expression_column_names)
320            .map(|column| normalize_column_name(&column, self.dialect))
321            .collect();
322        let mut columns: Vec<String> = pre_columns
323            .into_iter()
324            .filter(|column| !input_columns.contains(&normalize_column_name(column, self.dialect)))
325            .collect();
326        columns.push(unpivot.name_column.name.clone());
327        columns.push(unpivot.value_column.name.clone());
328        columns.extend(
329            unpivot
330                .extra_value_columns
331                .iter()
332                .map(|column| column.name.clone()),
333        );
334        apply_alias_columns(columns, &unpivot.alias_columns)
335    }
336
337    fn get_source_output_columns(&self, source: &Expression) -> Vec<String> {
338        match source {
339            Expression::Table(table) => {
340                if table.schema.is_none() && table.catalog.is_none() {
341                    if let Some(source) = self.scope.cte_sources.get(&table.name.name) {
342                        return self.extract_columns_from_source(source).unwrap_or_default();
343                    }
344                }
345
346                let table_name = qualified_table_name(table);
347                self.schema.column_names(&table_name).unwrap_or_default()
348            }
349            Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
350            Expression::Select(select) => self.get_select_column_names(select),
351            Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_) => self
352                .get_source_columns_from_set_op(source)
353                .unwrap_or_default(),
354            Expression::Cte(cte) => {
355                if cte.columns.is_empty() {
356                    self.get_named_selects(&cte.this)
357                } else {
358                    cte.columns
359                        .iter()
360                        .map(|column| column.name.clone())
361                        .collect()
362                }
363            }
364            Expression::Paren(paren) => self.get_source_output_columns(&paren.this),
365            _ => Vec::new(),
366        }
367    }
368
369    /// Get columns from a set operation (UNION, INTERSECT, EXCEPT)
370    pub fn get_source_columns_from_set_op(
371        &self,
372        expression: &Expression,
373    ) -> ResolverResult<Vec<String>> {
374        match expression {
375            Expression::Select(select) => Ok(self.get_select_column_names(select)),
376            Expression::Subquery(subquery) => {
377                if matches!(
378                    &subquery.this,
379                    Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
380                ) {
381                    self.get_source_columns_from_set_op(&subquery.this)
382                } else {
383                    Ok(self.get_named_selects(&subquery.this))
384                }
385            }
386            Expression::Union(union) => {
387                // Standard UNION: columns come from the left side
388                self.get_source_columns_from_set_op(&union.left)
389            }
390            Expression::Intersect(intersect) => {
391                self.get_source_columns_from_set_op(&intersect.left)
392            }
393            Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
394            _ => Err(ResolverError::UnknownSetOperation(format!(
395                "{:?}",
396                expression
397            ))),
398        }
399    }
400
401    /// Get all source columns for all sources in the scope
402    fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
403        let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
404
405        let mut result = HashMap::new();
406        for source_name in source_names {
407            if let Ok(columns) = self.get_source_columns(&source_name) {
408                result.insert(source_name, columns);
409            }
410        }
411        result
412    }
413
414    /// Get the table name for a column from the sources
415    fn get_table_name_from_sources(
416        &mut self,
417        column_name: &str,
418        source_columns: Option<&HashMap<String, Vec<String>>>,
419    ) -> Option<String> {
420        let normalized_column_name = normalize_column_name(column_name, self.dialect);
421        let unambiguous = match source_columns {
422            Some(cols) => self.compute_unambiguous_columns(cols),
423            None => {
424                if self.unambiguous_columns_cache.is_none() {
425                    let all_source_columns = self.get_all_source_columns();
426                    self.unambiguous_columns_cache =
427                        Some(self.compute_unambiguous_columns(&all_source_columns));
428                }
429                self.unambiguous_columns_cache
430                    .clone()
431                    .expect("cache populated above")
432            }
433        };
434
435        unambiguous.get(&normalized_column_name).cloned()
436    }
437
438    /// Compute unambiguous columns mapping
439    ///
440    /// A column is unambiguous if it appears in exactly one source.
441    fn compute_unambiguous_columns(
442        &self,
443        source_columns: &HashMap<String, Vec<String>>,
444    ) -> HashMap<String, String> {
445        if source_columns.is_empty() {
446            return HashMap::new();
447        }
448
449        let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
450
451        for (source_name, columns) in source_columns {
452            for column in columns {
453                column_to_sources
454                    .entry(normalize_column_name(column, self.dialect))
455                    .or_default()
456                    .push(source_name.clone());
457            }
458        }
459
460        // Keep only columns that appear in exactly one source
461        column_to_sources
462            .into_iter()
463            .filter(|(_, sources)| sources.len() == 1)
464            .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
465            .collect()
466    }
467
468    /// Check if a column is ambiguous (appears in multiple sources)
469    pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
470        let normalized_column_name = normalize_column_name(column_name, self.dialect);
471        let all_source_columns = self.get_all_source_columns();
472        let sources_with_column: Vec<_> = all_source_columns
473            .iter()
474            .filter(|(_, columns)| {
475                columns.iter().any(|column| {
476                    normalize_column_name(column, self.dialect) == normalized_column_name
477                })
478            })
479            .map(|(name, _)| name.clone())
480            .collect();
481
482        sources_with_column.len() > 1
483    }
484
485    /// Get all sources that contain a given column
486    pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
487        let normalized_column_name = normalize_column_name(column_name, self.dialect);
488        let all_source_columns = self.get_all_source_columns();
489        all_source_columns
490            .iter()
491            .filter(|(_, columns)| {
492                columns.iter().any(|column| {
493                    normalize_column_name(column, self.dialect) == normalized_column_name
494                })
495            })
496            .map(|(name, _)| name.clone())
497            .collect()
498    }
499
500    /// Try to disambiguate a column based on join context
501    ///
502    /// In join conditions, a column can sometimes be disambiguated based on
503    /// which tables have been joined up to that point.
504    pub fn disambiguate_in_join_context(
505        &mut self,
506        column_name: &str,
507        available_sources: &[String],
508    ) -> Option<String> {
509        let normalized_column_name = normalize_column_name(column_name, self.dialect);
510        let mut matching_sources = Vec::new();
511
512        for source_name in available_sources {
513            if let Ok(columns) = self.get_source_columns(source_name) {
514                if columns.iter().any(|column| {
515                    normalize_column_name(column, self.dialect) == normalized_column_name
516                }) {
517                    matching_sources.push(source_name.clone());
518                }
519            }
520        }
521
522        if matching_sources.len() == 1 {
523            Some(matching_sources.remove(0))
524        } else {
525            None
526        }
527    }
528}
529
530fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
531    normalize_name(name, dialect, false, true)
532}
533
534fn apply_alias_columns(mut columns: Vec<String>, alias_columns: &[Identifier]) -> Vec<String> {
535    for (idx, alias) in alias_columns.iter().enumerate() {
536        if let Some(column) = columns.get_mut(idx) {
537            *column = alias.name.clone();
538        }
539    }
540    columns
541}
542
543fn pivot_excluded_source_columns(
544    pivot: &crate::expressions::Pivot,
545    dialect: Option<DialectType>,
546) -> HashSet<String> {
547    pivot
548        .fields
549        .iter()
550        .chain(pivot.expressions.iter())
551        .chain(pivot.using.iter())
552        .flat_map(expression_column_names)
553        .map(|column| normalize_column_name(&column, dialect))
554        .collect()
555}
556
557fn pivot_generated_output_columns(
558    pivot: &crate::expressions::Pivot,
559    _dialect: Option<DialectType>,
560) -> Vec<String> {
561    let fields = pivot_field_output_names(pivot);
562    let aggregations = if pivot.using.is_empty() {
563        &pivot.expressions
564    } else {
565        &pivot.using
566    };
567
568    if fields.is_empty() || aggregations.is_empty() {
569        return Vec::new();
570    }
571
572    let needs_suffix = aggregations.len() > 1;
573    let mut outputs = Vec::new();
574    for field in fields {
575        for aggregation in aggregations {
576            if let Some(suffix) = pivot_aggregation_output_suffix(aggregation, needs_suffix) {
577                outputs.push(format!("{field}_{suffix}"));
578            } else {
579                outputs.push(field.clone());
580            }
581        }
582    }
583    outputs
584}
585
586fn pivot_field_output_names(pivot: &crate::expressions::Pivot) -> Vec<String> {
587    pivot
588        .fields
589        .iter()
590        .filter_map(|field| match field {
591            Expression::In(in_expr) => Some(
592                in_expr
593                    .expressions
594                    .iter()
595                    .filter_map(expression_name)
596                    .collect::<Vec<_>>(),
597            ),
598            _ => None,
599        })
600        .flatten()
601        .collect()
602}
603
604fn pivot_aggregation_output_suffix(expr: &Expression, needs_suffix: bool) -> Option<String> {
605    match expr {
606        Expression::Alias(alias) => Some(alias.alias.name.clone()),
607        _ if needs_suffix => Generator::sql(expr).ok().map(|sql| sql.to_lowercase()),
608        _ => None,
609    }
610}
611
612fn expression_name(expr: &Expression) -> Option<String> {
613    match expr {
614        Expression::PivotAlias(alias) => expression_name(&alias.alias),
615        Expression::Alias(alias) => Some(alias.alias.name.clone()),
616        Expression::Identifier(identifier) => Some(identifier.name.clone()),
617        Expression::Column(column) => Some(column.name.name.clone()),
618        Expression::Literal(literal) => Some(literal.value_str().to_string()),
619        Expression::Var(var) => Some(var.this.clone()),
620        Expression::Tuple(tuple) => tuple.expressions.first().and_then(expression_name),
621        _ => None,
622    }
623}
624
625fn expression_column_names(expr: &Expression) -> Vec<String> {
626    expr.find_all(|node| matches!(node, Expression::Column(_)))
627        .into_iter()
628        .filter_map(|node| match node {
629            Expression::Column(column) => Some(column.name.name.clone()),
630            _ => None,
631        })
632        .collect()
633}
634
635/// Resolve a column to its source table.
636///
637/// This is a convenience function that creates a Resolver and calls get_table.
638pub fn resolve_column(
639    scope: &Scope,
640    schema: &dyn Schema,
641    column_name: &str,
642    infer_schema: bool,
643) -> Option<String> {
644    let mut resolver = Resolver::new(scope, schema, infer_schema);
645    resolver.get_table(column_name)
646}
647
648/// Check if a column is ambiguous in the given scope.
649pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
650    let mut resolver = Resolver::new(scope, schema, true);
651    resolver.is_ambiguous(column_name)
652}
653
654/// Build the fully qualified table name (catalog.schema.table) from a TableRef.
655fn qualified_table_name(table: &TableRef) -> String {
656    let mut parts = Vec::new();
657    if let Some(catalog) = &table.catalog {
658        parts.push(catalog.name.clone());
659    }
660    if let Some(schema) = &table.schema {
661        parts.push(schema.name.clone());
662    }
663    parts.push(table.name.name.clone());
664    parts.join(".")
665}
666
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// INT type used by the fixture columns.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Fixture schema: `users(id, name, email)` and `orders(id, user_id, amount)`.
    /// `id` is intentionally present in both tables.
    fn create_test_schema() -> MappingSchema {
        let users_columns = [
            ("id".to_string(), int_type()),
            ("name".to_string(), DataType::Text),
            ("email".to_string(), DataType::Text),
        ];
        let orders_columns = [
            ("id".to_string(), int_type()),
            ("user_id".to_string(), int_type()),
            (
                "amount".to_string(),
                DataType::Double {
                    precision: None,
                    scale: None,
                },
            ),
        ];

        let mut schema = MappingSchema::new();
        schema.add_table("users", &users_columns, None).unwrap();
        schema.add_table("orders", &orders_columns, None).unwrap();
        schema
    }

    #[test]
    fn test_resolver_basic() {
        let ast = Parser::parse_sql("SELECT id, name FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `users` is the only source, so `name` must come from it.
        assert_eq!(resolver.get_table("name").as_deref(), Some("users"));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let ast =
            Parser::parse_sql("SELECT id FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `id` exists in both tables; `name` only in users; `amount` only in orders.
        for (column, expected) in [("id", true), ("name", false), ("amount", false)] {
            assert_eq!(resolver.is_ambiguous(column), expected, "column {column}");
        }
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let ast = Parser::parse_sql(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
        )
        .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        for (column, expected) in [("name", "users"), ("amount", "orders")] {
            assert_eq!(
                resolver.get_table(column).as_deref(),
                Some(expected),
                "column {column}"
            );
        }
    }

    #[test]
    fn test_resolver_with_alias() {
        let ast = Parser::parse_sql("SELECT u.id FROM users AS u").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // The source is keyed by its alias, not the table name.
        assert!(scope.sources.contains_key("u"));
    }

    #[test]
    fn test_sources_for_column() {
        let ast = Parser::parse_sql("SELECT * FROM users JOIN orders ON users.id = orders.user_id")
            .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let id_sources = resolver.sources_for_column("id");
        assert!(id_sources.iter().any(|source| source == "users"));
        assert!(id_sources.iter().any(|source| source == "orders"));

        assert_eq!(
            resolver.sources_for_column("email"),
            vec!["users".to_string()]
        );
    }

    #[test]
    fn test_all_columns() {
        let ast = Parser::parse_sql("SELECT * FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        for expected in ["id", "name", "email"] {
            assert!(all.contains(expected), "missing column {expected}");
        }
    }

    #[test]
    fn test_resolver_cte_projected_alias_column() {
        let ast = Parser::parse_sql(
            "WITH my_cte AS (SELECT id AS emp_id FROM users) SELECT emp_id FROM my_cte",
        )
        .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // The CTE's projected alias is attributed to the CTE source.
        assert_eq!(resolver.get_table("emp_id").as_deref(), Some("my_cte"));
    }

    #[test]
    fn test_resolve_column_helper() {
        let ast = Parser::parse_sql("SELECT name FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();

        assert_eq!(
            resolve_column(&scope, &schema, "name", true).as_deref(),
            Some("users")
        );
    }

    #[test]
    fn test_resolver_bigquery_mixed_case_column_names() {
        let dialect = Dialect::get(DialectType::BigQuery);
        let expr = dialect
            .parse("SELECT Name AS name FROM teams")
            .unwrap()
            .into_iter()
            .next()
            .expect("expected one expression");
        let scope = build_scope(&expr);

        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
        schema
            .add_table(
                "teams",
                &[("Name".into(), DataType::String { length: None })],
                None,
            )
            .expect("schema setup");

        let mut resolver = Resolver::new(&scope, &schema, true);
        // Both spellings must resolve to the same source under BigQuery normalization.
        for column in ["Name", "name"] {
            assert_eq!(
                resolver.get_table(column).as_deref(),
                Some("teams"),
                "column {column}"
            );
        }
    }
}