Skip to main content

polyglot_sql/
resolver.rs

1//! Column Resolver Module
2//!
3//! This module provides functionality for resolving column references to their
4//! source tables. It handles:
5//! - Finding which table a column belongs to
6//! - Resolving ambiguous column references
7//! - Handling join context for disambiguation
8//! - Supporting set operations (UNION, INTERSECT, EXCEPT)
9//!
10//! Based on the Python implementation in `sqlglot/optimizer/resolver.py`.
11
12use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier, TableRef};
14use crate::schema::{normalize_name, Schema};
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19/// Errors that can occur during column resolution
20#[derive(Debug, Error, Clone)]
21pub enum ResolverError {
22    #[error("Unknown table: {0}")]
23    UnknownTable(String),
24
25    #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
26    AmbiguousColumn { column: String, sources: String },
27
28    #[error("Column not found: {0}")]
29    ColumnNotFound(String),
30
31    #[error("Unknown set operation: {0}")]
32    UnknownSetOperation(String),
33}
34
35/// Result type for resolver operations
36pub type ResolverResult<T> = Result<T, ResolverError>;
37
38/// Helper for resolving columns to their source tables.
39///
40/// This is a struct so we can lazily load some things and easily share
41/// them across functions.
42pub struct Resolver<'a> {
43    /// The scope being analyzed
44    pub scope: &'a Scope,
45    /// The schema for table/column information
46    schema: &'a dyn Schema,
47    /// The dialect being used
48    pub dialect: Option<DialectType>,
49    /// Whether to infer schema from context
50    infer_schema: bool,
51    /// Cached source columns: source_name -> column names
52    source_columns_cache: HashMap<String, Vec<String>>,
53    /// Cached unambiguous columns: column_name -> source_name
54    unambiguous_columns_cache: Option<HashMap<String, String>>,
55    /// Cached set of all available columns
56    all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60    /// Create a new resolver for a scope
61    pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62        Self {
63            scope,
64            schema,
65            dialect: schema.dialect(),
66            infer_schema,
67            source_columns_cache: HashMap::new(),
68            unambiguous_columns_cache: None,
69            all_columns_cache: None,
70        }
71    }
72
73    /// Get the table for a column name.
74    ///
75    /// Returns the table name if it can be found/inferred.
76    pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77        // Try to find table from all sources (unambiguous lookup)
78        let table_name = self.get_table_name_from_sources(column_name, None);
79
80        // If we found a table, return it
81        if table_name.is_some() {
82            return table_name;
83        }
84
85        // If schema inference is enabled and exactly one source has no schema,
86        // assume the column belongs to that source
87        if self.infer_schema {
88            let sources_without_schema: Vec<_> = self
89                .get_all_source_columns()
90                .iter()
91                .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92                .map(|(name, _)| name.clone())
93                .collect();
94
95            if sources_without_schema.len() == 1 {
96                return Some(sources_without_schema[0].clone());
97            }
98        }
99
100        None
101    }
102
103    /// Get the table for a column, returning an Identifier
104    pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105        self.get_table(column_name).map(Identifier::new)
106    }
107
108    /// Check if a table exists in the schema (not necessarily in the current scope).
109    /// Used to detect correlated references to outer scope tables.
110    pub fn table_exists_in_schema(&self, table_name: &str) -> bool {
111        self.schema.column_names(table_name).is_ok()
112    }
113
114    /// Find the table for a column by searching all schema tables not in the current scope.
115    /// Used for correlated subquery resolution: if an unqualified column can't be resolved
116    /// in the current scope, check if it uniquely belongs to an outer-scope table.
117    /// Returns Some(table_name) if the column is found in exactly one non-local table.
118    pub fn find_column_in_outer_schema_tables(&self, column_name: &str) -> Option<String> {
119        let tables = self.schema.find_tables_for_column(column_name);
120        // Filter to tables NOT in the current scope
121        let outer_tables: Vec<String> = tables
122            .into_iter()
123            .filter(|t| !self.scope.sources.contains_key(t))
124            .collect();
125        // Only return if unambiguous (exactly one outer table has this column)
126        if outer_tables.len() == 1 {
127            Some(outer_tables.into_iter().next().unwrap())
128        } else {
129            None
130        }
131    }
132
133    /// Get all available columns across all sources in this scope
134    pub fn all_columns(&mut self) -> &HashSet<String> {
135        if self.all_columns_cache.is_none() {
136            let mut all = HashSet::new();
137            for columns in self.get_all_source_columns().values() {
138                all.extend(columns.iter().cloned());
139            }
140            self.all_columns_cache = Some(all);
141        }
142        self.all_columns_cache
143            .as_ref()
144            .expect("cache populated above")
145    }
146
147    /// Get column names for a source.
148    ///
149    /// Returns the list of column names available from the given source.
150    pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
151        // Check cache first
152        if let Some(columns) = self.source_columns_cache.get(source_name) {
153            return Ok(columns.clone());
154        }
155
156        // Get the source info
157        let source_info = self
158            .scope
159            .sources
160            .get(source_name)
161            .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
162
163        let columns = self.extract_columns_from_source(source_info)?;
164
165        // Cache the result
166        self.source_columns_cache
167            .insert(source_name.to_string(), columns.clone());
168
169        Ok(columns)
170    }
171
172    /// Extract column names from a source expression
173    fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
174        let columns = match &source_info.expression {
175            Expression::Table(table) => {
176                // For tables, try to get columns from schema.
177                // Build the fully qualified name (catalog.schema.table) to
178                // match how MappingSchema stores hierarchical keys.
179                let table_name = qualified_table_name(table);
180                match self.schema.column_names(&table_name) {
181                    Ok(cols) => cols,
182                    Err(_) => Vec::new(), // Schema might not have this table
183                }
184            }
185            Expression::Subquery(subquery) => {
186                // For subqueries, get named_selects from the inner query
187                self.get_named_selects(&subquery.this)
188            }
189            Expression::Select(select) => {
190                // For derived tables that are SELECT expressions
191                self.get_select_column_names(select)
192            }
193            Expression::Union(union) => {
194                // For UNION, columns come from the set operation
195                self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
196            }
197            Expression::Intersect(intersect) => {
198                self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
199            }
200            Expression::Except(except) => {
201                self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
202            }
203            Expression::Cte(cte) => {
204                if !cte.columns.is_empty() {
205                    cte.columns.iter().map(|c| c.name.clone()).collect()
206                } else {
207                    self.get_named_selects(&cte.this)
208                }
209            }
210            _ => Vec::new(),
211        };
212
213        Ok(columns)
214    }
215
216    /// Get named selects (column names) from an expression
217    fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
218        match expr {
219            Expression::Select(select) => self.get_select_column_names(select),
220            Expression::Union(union) => {
221                // For unions, use the left side's columns
222                self.get_named_selects(&union.left)
223            }
224            Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
225            Expression::Except(except) => self.get_named_selects(&except.left),
226            Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
227            _ => Vec::new(),
228        }
229    }
230
231    /// Get column names from a SELECT expression
232    fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
233        select
234            .expressions
235            .iter()
236            .filter_map(|expr| self.get_expression_alias(expr))
237            .collect()
238    }
239
240    /// Get the alias or name for a select expression
241    fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
242        match expr {
243            Expression::Alias(alias) => Some(alias.alias.name.clone()),
244            Expression::Column(col) => Some(col.name.name.clone()),
245            Expression::Star(_) => Some("*".to_string()),
246            Expression::Identifier(id) => Some(id.name.clone()),
247            _ => None,
248        }
249    }
250
251    /// Get columns from a set operation (UNION, INTERSECT, EXCEPT)
252    pub fn get_source_columns_from_set_op(
253        &self,
254        expression: &Expression,
255    ) -> ResolverResult<Vec<String>> {
256        match expression {
257            Expression::Select(select) => Ok(self.get_select_column_names(select)),
258            Expression::Subquery(subquery) => {
259                if matches!(
260                    &subquery.this,
261                    Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
262                ) {
263                    self.get_source_columns_from_set_op(&subquery.this)
264                } else {
265                    Ok(self.get_named_selects(&subquery.this))
266                }
267            }
268            Expression::Union(union) => {
269                // Standard UNION: columns come from the left side
270                self.get_source_columns_from_set_op(&union.left)
271            }
272            Expression::Intersect(intersect) => {
273                self.get_source_columns_from_set_op(&intersect.left)
274            }
275            Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
276            _ => Err(ResolverError::UnknownSetOperation(format!(
277                "{:?}",
278                expression
279            ))),
280        }
281    }
282
283    /// Get all source columns for all sources in the scope
284    fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
285        let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
286
287        let mut result = HashMap::new();
288        for source_name in source_names {
289            if let Ok(columns) = self.get_source_columns(&source_name) {
290                result.insert(source_name, columns);
291            }
292        }
293        result
294    }
295
296    /// Get the table name for a column from the sources
297    fn get_table_name_from_sources(
298        &mut self,
299        column_name: &str,
300        source_columns: Option<&HashMap<String, Vec<String>>>,
301    ) -> Option<String> {
302        let normalized_column_name = normalize_column_name(column_name, self.dialect);
303        let unambiguous = match source_columns {
304            Some(cols) => self.compute_unambiguous_columns(cols),
305            None => {
306                if self.unambiguous_columns_cache.is_none() {
307                    let all_source_columns = self.get_all_source_columns();
308                    self.unambiguous_columns_cache =
309                        Some(self.compute_unambiguous_columns(&all_source_columns));
310                }
311                self.unambiguous_columns_cache
312                    .clone()
313                    .expect("cache populated above")
314            }
315        };
316
317        unambiguous.get(&normalized_column_name).cloned()
318    }
319
320    /// Compute unambiguous columns mapping
321    ///
322    /// A column is unambiguous if it appears in exactly one source.
323    fn compute_unambiguous_columns(
324        &self,
325        source_columns: &HashMap<String, Vec<String>>,
326    ) -> HashMap<String, String> {
327        if source_columns.is_empty() {
328            return HashMap::new();
329        }
330
331        let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
332
333        for (source_name, columns) in source_columns {
334            for column in columns {
335                column_to_sources
336                    .entry(normalize_column_name(column, self.dialect))
337                    .or_default()
338                    .push(source_name.clone());
339            }
340        }
341
342        // Keep only columns that appear in exactly one source
343        column_to_sources
344            .into_iter()
345            .filter(|(_, sources)| sources.len() == 1)
346            .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
347            .collect()
348    }
349
350    /// Check if a column is ambiguous (appears in multiple sources)
351    pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
352        let normalized_column_name = normalize_column_name(column_name, self.dialect);
353        let all_source_columns = self.get_all_source_columns();
354        let sources_with_column: Vec<_> = all_source_columns
355            .iter()
356            .filter(|(_, columns)| {
357                columns.iter().any(|column| {
358                    normalize_column_name(column, self.dialect) == normalized_column_name
359                })
360            })
361            .map(|(name, _)| name.clone())
362            .collect();
363
364        sources_with_column.len() > 1
365    }
366
367    /// Get all sources that contain a given column
368    pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
369        let normalized_column_name = normalize_column_name(column_name, self.dialect);
370        let all_source_columns = self.get_all_source_columns();
371        all_source_columns
372            .iter()
373            .filter(|(_, columns)| {
374                columns.iter().any(|column| {
375                    normalize_column_name(column, self.dialect) == normalized_column_name
376                })
377            })
378            .map(|(name, _)| name.clone())
379            .collect()
380    }
381
382    /// Try to disambiguate a column based on join context
383    ///
384    /// In join conditions, a column can sometimes be disambiguated based on
385    /// which tables have been joined up to that point.
386    pub fn disambiguate_in_join_context(
387        &mut self,
388        column_name: &str,
389        available_sources: &[String],
390    ) -> Option<String> {
391        let normalized_column_name = normalize_column_name(column_name, self.dialect);
392        let mut matching_sources = Vec::new();
393
394        for source_name in available_sources {
395            if let Ok(columns) = self.get_source_columns(source_name) {
396                if columns.iter().any(|column| {
397                    normalize_column_name(column, self.dialect) == normalized_column_name
398                }) {
399                    matching_sources.push(source_name.clone());
400                }
401            }
402        }
403
404        if matching_sources.len() == 1 {
405            Some(matching_sources.remove(0))
406        } else {
407            None
408        }
409    }
410}
411
412fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
413    normalize_name(name, dialect, false, true)
414}
415
416/// Resolve a column to its source table.
417///
418/// This is a convenience function that creates a Resolver and calls get_table.
419pub fn resolve_column(
420    scope: &Scope,
421    schema: &dyn Schema,
422    column_name: &str,
423    infer_schema: bool,
424) -> Option<String> {
425    let mut resolver = Resolver::new(scope, schema, infer_schema);
426    resolver.get_table(column_name)
427}
428
429/// Check if a column is ambiguous in the given scope.
430pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
431    let mut resolver = Resolver::new(scope, schema, true);
432    resolver.is_ambiguous(column_name)
433}
434
435/// Build the fully qualified table name (catalog.schema.table) from a TableRef.
436fn qualified_table_name(table: &TableRef) -> String {
437    let mut parts = Vec::new();
438    if let Some(catalog) = &table.catalog {
439        parts.push(catalog.name.clone());
440    }
441    if let Some(schema) = &table.schema {
442        parts.push(schema.name.clone());
443    }
444    parts.push(table.name.name.clone());
445    parts.join(".")
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// The `INT` column type used throughout the test schema.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Parse `sql` with the default parser and build the scope of its first statement.
    fn scope_for(sql: &str) -> Scope {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        build_scope(&ast[0])
    }

    /// Schema with `users(id, name, email)` and `orders(id, user_id, amount)`.
    fn create_test_schema() -> MappingSchema {
        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "users",
                &[
                    ("id".to_string(), int_type()),
                    ("name".to_string(), DataType::Text),
                    ("email".to_string(), DataType::Text),
                ],
                None,
            )
            .unwrap();
        schema
            .add_table(
                "orders",
                &[
                    ("id".to_string(), int_type()),
                    ("user_id".to_string(), int_type()),
                    (
                        "amount".to_string(),
                        DataType::Double {
                            precision: None,
                            scale: None,
                        },
                    ),
                ],
                None,
            )
            .unwrap();
        schema
    }

    #[test]
    fn test_resolver_basic() {
        let scope = scope_for("SELECT id, name FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // 'name' should resolve to 'users' since it's the only source
        assert_eq!(resolver.get_table("name"), Some("users".to_string()));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let scope = scope_for("SELECT id FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // 'id' appears in both tables, so it's ambiguous
        assert!(resolver.is_ambiguous("id"));
        // 'name' only appears in users
        assert!(!resolver.is_ambiguous("name"));
        // 'amount' only appears in orders
        assert!(!resolver.is_ambiguous("amount"));
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let scope =
            scope_for("SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("name"), Some("users".to_string()));
        assert_eq!(resolver.get_table("amount"), Some("orders".to_string()));
    }

    #[test]
    fn test_resolver_with_alias() {
        let scope = scope_for("SELECT u.id FROM users AS u");
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // Source should be indexed by alias 'u'
        assert!(scope.sources.contains_key("u"));
    }

    #[test]
    fn test_sources_for_column() {
        let scope = scope_for("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // 'id' should be in both users and orders
        let sources = resolver.sources_for_column("id");
        assert!(sources.contains(&"users".to_string()));
        assert!(sources.contains(&"orders".to_string()));

        // 'email' should only be in users
        let sources = resolver.sources_for_column("email");
        assert_eq!(sources, vec!["users".to_string()]);
    }

    #[test]
    fn test_all_columns() {
        let scope = scope_for("SELECT * FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        assert!(all.contains("id"));
        assert!(all.contains("name"));
        assert!(all.contains("email"));
    }

    #[test]
    fn test_resolver_cte_projected_alias_column() {
        let scope =
            scope_for("WITH my_cte AS (SELECT id AS emp_id FROM users) SELECT emp_id FROM my_cte");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("emp_id"), Some("my_cte".to_string()));
    }

    #[test]
    fn test_resolve_column_helper() {
        let scope = scope_for("SELECT name FROM users");
        let schema = create_test_schema();

        let table = resolve_column(&scope, &schema, "name", true);
        assert_eq!(table, Some("users".to_string()));
    }

    #[test]
    fn test_is_column_ambiguous_helper() {
        let scope = scope_for("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();

        assert!(is_column_ambiguous(&scope, &schema, "id"));
        assert!(!is_column_ambiguous(&scope, &schema, "email"));
    }

    #[test]
    fn test_get_source_columns_unknown_source() {
        let scope = scope_for("SELECT name FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert!(matches!(
            resolver.get_source_columns("missing"),
            Err(ResolverError::UnknownTable(name)) if name == "missing"
        ));
    }

    #[test]
    fn test_disambiguate_in_join_context() {
        let scope = scope_for("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let only_users = vec!["users".to_string()];
        let both = vec!["users".to_string(), "orders".to_string()];

        // With only `users` joined so far, `id` is unambiguous.
        assert_eq!(
            resolver.disambiguate_in_join_context("id", &only_users),
            Some("users".to_string())
        );
        // Once `orders` is joined too, `id` matches both sources.
        assert_eq!(resolver.disambiguate_in_join_context("id", &both), None);
        // `amount` only exists on `orders`.
        assert_eq!(
            resolver.disambiguate_in_join_context("amount", &both),
            Some("orders".to_string())
        );
        // Names that are not sources of the scope are skipped rather than failing.
        let with_unknown = vec!["missing".to_string(), "users".to_string()];
        assert_eq!(
            resolver.disambiguate_in_join_context("email", &with_unknown),
            Some("users".to_string())
        );
    }

    #[test]
    fn test_resolver_bigquery_mixed_case_column_names() {
        let dialect = Dialect::get(DialectType::BigQuery);
        let expr = dialect
            .parse("SELECT Name AS name FROM teams")
            .unwrap()
            .into_iter()
            .next()
            .expect("expected one expression");
        let scope = build_scope(&expr);

        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
        schema
            .add_table(
                "teams",
                &[("Name".into(), DataType::String { length: None })],
                None,
            )
            .expect("schema setup");

        let mut resolver = Resolver::new(&scope, &schema, true);
        assert_eq!(resolver.get_table("Name"), Some("teams".to_string()));
        assert_eq!(resolver.get_table("name"), Some("teams".to_string()));
    }
}