Skip to main content

polyglot_sql/
resolver.rs

1//! Column Resolver Module
2//!
3//! This module provides functionality for resolving column references to their
4//! source tables. It handles:
5//! - Finding which table a column belongs to
6//! - Resolving ambiguous column references
7//! - Handling join context for disambiguation
8//! - Supporting set operations (UNION, INTERSECT, EXCEPT)
9//!
10//! Based on the Python implementation in `sqlglot/optimizer/resolver.py`.
11
12use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier, TableRef};
14use crate::schema::{normalize_name, Schema};
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19/// Errors that can occur during column resolution
20#[derive(Debug, Error, Clone)]
21pub enum ResolverError {
22    #[error("Unknown table: {0}")]
23    UnknownTable(String),
24
25    #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
26    AmbiguousColumn { column: String, sources: String },
27
28    #[error("Column not found: {0}")]
29    ColumnNotFound(String),
30
31    #[error("Unknown set operation: {0}")]
32    UnknownSetOperation(String),
33}
34
35/// Result type for resolver operations
36pub type ResolverResult<T> = Result<T, ResolverError>;
37
38/// Helper for resolving columns to their source tables.
39///
40/// This is a struct so we can lazily load some things and easily share
41/// them across functions.
42pub struct Resolver<'a> {
43    /// The scope being analyzed
44    pub scope: &'a Scope,
45    /// The schema for table/column information
46    schema: &'a dyn Schema,
47    /// The dialect being used
48    pub dialect: Option<DialectType>,
49    /// Whether to infer schema from context
50    infer_schema: bool,
51    /// Cached source columns: source_name -> column names
52    source_columns_cache: HashMap<String, Vec<String>>,
53    /// Cached unambiguous columns: column_name -> source_name
54    unambiguous_columns_cache: Option<HashMap<String, String>>,
55    /// Cached set of all available columns
56    all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60    /// Create a new resolver for a scope
61    pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62        Self {
63            scope,
64            schema,
65            dialect: schema.dialect(),
66            infer_schema,
67            source_columns_cache: HashMap::new(),
68            unambiguous_columns_cache: None,
69            all_columns_cache: None,
70        }
71    }
72
73    /// Get the table for a column name.
74    ///
75    /// Returns the table name if it can be found/inferred.
76    pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77        // Try to find table from all sources (unambiguous lookup)
78        let table_name = self.get_table_name_from_sources(column_name, None);
79
80        // If we found a table, return it
81        if table_name.is_some() {
82            return table_name;
83        }
84
85        // If schema inference is enabled and exactly one source has no schema,
86        // assume the column belongs to that source
87        if self.infer_schema {
88            let sources_without_schema: Vec<_> = self
89                .get_all_source_columns()
90                .iter()
91                .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92                .map(|(name, _)| name.clone())
93                .collect();
94
95            if sources_without_schema.len() == 1 {
96                return Some(sources_without_schema[0].clone());
97            }
98        }
99
100        None
101    }
102
103    /// Get the table for a column, returning an Identifier
104    pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105        self.get_table(column_name).map(Identifier::new)
106    }
107
108    /// Check if a table exists in the schema (not necessarily in the current scope).
109    /// Used to detect correlated references to outer scope tables.
110    pub fn table_exists_in_schema(&self, table_name: &str) -> bool {
111        self.schema.column_names(table_name).is_ok()
112    }
113
114    /// Get all available columns across all sources in this scope
115    pub fn all_columns(&mut self) -> &HashSet<String> {
116        if self.all_columns_cache.is_none() {
117            let mut all = HashSet::new();
118            for columns in self.get_all_source_columns().values() {
119                all.extend(columns.iter().cloned());
120            }
121            self.all_columns_cache = Some(all);
122        }
123        self.all_columns_cache
124            .as_ref()
125            .expect("cache populated above")
126    }
127
128    /// Get column names for a source.
129    ///
130    /// Returns the list of column names available from the given source.
131    pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
132        // Check cache first
133        if let Some(columns) = self.source_columns_cache.get(source_name) {
134            return Ok(columns.clone());
135        }
136
137        // Get the source info
138        let source_info = self
139            .scope
140            .sources
141            .get(source_name)
142            .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
143
144        let columns = self.extract_columns_from_source(source_info)?;
145
146        // Cache the result
147        self.source_columns_cache
148            .insert(source_name.to_string(), columns.clone());
149
150        Ok(columns)
151    }
152
153    /// Extract column names from a source expression
154    fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
155        let columns = match &source_info.expression {
156            Expression::Table(table) => {
157                // For tables, try to get columns from schema.
158                // Build the fully qualified name (catalog.schema.table) to
159                // match how MappingSchema stores hierarchical keys.
160                let table_name = qualified_table_name(table);
161                match self.schema.column_names(&table_name) {
162                    Ok(cols) => cols,
163                    Err(_) => Vec::new(), // Schema might not have this table
164                }
165            }
166            Expression::Subquery(subquery) => {
167                // For subqueries, get named_selects from the inner query
168                self.get_named_selects(&subquery.this)
169            }
170            Expression::Select(select) => {
171                // For derived tables that are SELECT expressions
172                self.get_select_column_names(select)
173            }
174            Expression::Union(union) => {
175                // For UNION, columns come from the set operation
176                self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
177            }
178            Expression::Intersect(intersect) => {
179                self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
180            }
181            Expression::Except(except) => {
182                self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
183            }
184            Expression::Cte(cte) => {
185                if !cte.columns.is_empty() {
186                    cte.columns.iter().map(|c| c.name.clone()).collect()
187                } else {
188                    self.get_named_selects(&cte.this)
189                }
190            }
191            _ => Vec::new(),
192        };
193
194        Ok(columns)
195    }
196
197    /// Get named selects (column names) from an expression
198    fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
199        match expr {
200            Expression::Select(select) => self.get_select_column_names(select),
201            Expression::Union(union) => {
202                // For unions, use the left side's columns
203                self.get_named_selects(&union.left)
204            }
205            Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
206            Expression::Except(except) => self.get_named_selects(&except.left),
207            Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
208            _ => Vec::new(),
209        }
210    }
211
212    /// Get column names from a SELECT expression
213    fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
214        select
215            .expressions
216            .iter()
217            .filter_map(|expr| self.get_expression_alias(expr))
218            .collect()
219    }
220
221    /// Get the alias or name for a select expression
222    fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
223        match expr {
224            Expression::Alias(alias) => Some(alias.alias.name.clone()),
225            Expression::Column(col) => Some(col.name.name.clone()),
226            Expression::Star(_) => Some("*".to_string()),
227            Expression::Identifier(id) => Some(id.name.clone()),
228            _ => None,
229        }
230    }
231
232    /// Get columns from a set operation (UNION, INTERSECT, EXCEPT)
233    pub fn get_source_columns_from_set_op(
234        &self,
235        expression: &Expression,
236    ) -> ResolverResult<Vec<String>> {
237        match expression {
238            Expression::Select(select) => Ok(self.get_select_column_names(select)),
239            Expression::Subquery(subquery) => {
240                if matches!(
241                    &subquery.this,
242                    Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
243                ) {
244                    self.get_source_columns_from_set_op(&subquery.this)
245                } else {
246                    Ok(self.get_named_selects(&subquery.this))
247                }
248            }
249            Expression::Union(union) => {
250                // Standard UNION: columns come from the left side
251                self.get_source_columns_from_set_op(&union.left)
252            }
253            Expression::Intersect(intersect) => {
254                self.get_source_columns_from_set_op(&intersect.left)
255            }
256            Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
257            _ => Err(ResolverError::UnknownSetOperation(format!(
258                "{:?}",
259                expression
260            ))),
261        }
262    }
263
264    /// Get all source columns for all sources in the scope
265    fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
266        let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
267
268        let mut result = HashMap::new();
269        for source_name in source_names {
270            if let Ok(columns) = self.get_source_columns(&source_name) {
271                result.insert(source_name, columns);
272            }
273        }
274        result
275    }
276
277    /// Get the table name for a column from the sources
278    fn get_table_name_from_sources(
279        &mut self,
280        column_name: &str,
281        source_columns: Option<&HashMap<String, Vec<String>>>,
282    ) -> Option<String> {
283        let normalized_column_name = normalize_column_name(column_name, self.dialect);
284        let unambiguous = match source_columns {
285            Some(cols) => self.compute_unambiguous_columns(cols),
286            None => {
287                if self.unambiguous_columns_cache.is_none() {
288                    let all_source_columns = self.get_all_source_columns();
289                    self.unambiguous_columns_cache =
290                        Some(self.compute_unambiguous_columns(&all_source_columns));
291                }
292                self.unambiguous_columns_cache
293                    .clone()
294                    .expect("cache populated above")
295            }
296        };
297
298        unambiguous.get(&normalized_column_name).cloned()
299    }
300
301    /// Compute unambiguous columns mapping
302    ///
303    /// A column is unambiguous if it appears in exactly one source.
304    fn compute_unambiguous_columns(
305        &self,
306        source_columns: &HashMap<String, Vec<String>>,
307    ) -> HashMap<String, String> {
308        if source_columns.is_empty() {
309            return HashMap::new();
310        }
311
312        let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
313
314        for (source_name, columns) in source_columns {
315            for column in columns {
316                column_to_sources
317                    .entry(normalize_column_name(column, self.dialect))
318                    .or_default()
319                    .push(source_name.clone());
320            }
321        }
322
323        // Keep only columns that appear in exactly one source
324        column_to_sources
325            .into_iter()
326            .filter(|(_, sources)| sources.len() == 1)
327            .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
328            .collect()
329    }
330
331    /// Check if a column is ambiguous (appears in multiple sources)
332    pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
333        let normalized_column_name = normalize_column_name(column_name, self.dialect);
334        let all_source_columns = self.get_all_source_columns();
335        let sources_with_column: Vec<_> = all_source_columns
336            .iter()
337            .filter(|(_, columns)| {
338                columns.iter().any(|column| {
339                    normalize_column_name(column, self.dialect) == normalized_column_name
340                })
341            })
342            .map(|(name, _)| name.clone())
343            .collect();
344
345        sources_with_column.len() > 1
346    }
347
348    /// Get all sources that contain a given column
349    pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
350        let normalized_column_name = normalize_column_name(column_name, self.dialect);
351        let all_source_columns = self.get_all_source_columns();
352        all_source_columns
353            .iter()
354            .filter(|(_, columns)| {
355                columns.iter().any(|column| {
356                    normalize_column_name(column, self.dialect) == normalized_column_name
357                })
358            })
359            .map(|(name, _)| name.clone())
360            .collect()
361    }
362
363    /// Try to disambiguate a column based on join context
364    ///
365    /// In join conditions, a column can sometimes be disambiguated based on
366    /// which tables have been joined up to that point.
367    pub fn disambiguate_in_join_context(
368        &mut self,
369        column_name: &str,
370        available_sources: &[String],
371    ) -> Option<String> {
372        let normalized_column_name = normalize_column_name(column_name, self.dialect);
373        let mut matching_sources = Vec::new();
374
375        for source_name in available_sources {
376            if let Ok(columns) = self.get_source_columns(source_name) {
377                if columns.iter().any(|column| {
378                    normalize_column_name(column, self.dialect) == normalized_column_name
379                }) {
380                    matching_sources.push(source_name.clone());
381                }
382            }
383        }
384
385        if matching_sources.len() == 1 {
386            Some(matching_sources.remove(0))
387        } else {
388            None
389        }
390    }
391}
392
393fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
394    normalize_name(name, dialect, false, true)
395}
396
397/// Resolve a column to its source table.
398///
399/// This is a convenience function that creates a Resolver and calls get_table.
400pub fn resolve_column(
401    scope: &Scope,
402    schema: &dyn Schema,
403    column_name: &str,
404    infer_schema: bool,
405) -> Option<String> {
406    let mut resolver = Resolver::new(scope, schema, infer_schema);
407    resolver.get_table(column_name)
408}
409
410/// Check if a column is ambiguous in the given scope.
411pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
412    let mut resolver = Resolver::new(scope, schema, true);
413    resolver.is_ambiguous(column_name)
414}
415
416/// Build the fully qualified table name (catalog.schema.table) from a TableRef.
417fn qualified_table_name(table: &TableRef) -> String {
418    let mut parts = Vec::new();
419    if let Some(catalog) = &table.catalog {
420        parts.push(catalog.name.clone());
421    }
422    if let Some(schema) = &table.schema {
423        parts.push(schema.name.clone());
424    }
425    parts.push(table.name.name.clone());
426    parts.join(".")
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// The `INT` type used by the fixture tables.
    fn int() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Fixture schema: `users(id, name, email)` and `orders(id, user_id, amount)`.
    fn create_test_schema() -> MappingSchema {
        let users = [
            ("id".to_string(), int()),
            ("name".to_string(), DataType::Text),
            ("email".to_string(), DataType::Text),
        ];
        let orders = [
            ("id".to_string(), int()),
            ("user_id".to_string(), int()),
            (
                "amount".to_string(),
                DataType::Double {
                    precision: None,
                    scale: None,
                },
            ),
        ];

        let mut schema = MappingSchema::new();
        schema.add_table("users", &users, None).unwrap();
        schema.add_table("orders", &orders, None).unwrap();
        schema
    }

    /// Parse `sql` with the default parser and build the scope of its first statement.
    fn scope_of(sql: &str) -> Scope {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        build_scope(&statements[0])
    }

    #[test]
    fn test_resolver_basic() {
        let scope = scope_of("SELECT id, name FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `users` is the only source, so `name` must come from it.
        assert_eq!(resolver.get_table("name").as_deref(), Some("users"));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let scope = scope_of("SELECT id FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // Both tables define `id`.
        assert!(resolver.is_ambiguous("id"));
        // Each of these lives in exactly one table.
        for unique in ["name", "amount"] {
            assert!(!resolver.is_ambiguous(unique));
        }
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let scope =
            scope_of("SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        for (column, expected) in [("name", "users"), ("amount", "orders")] {
            assert_eq!(resolver.get_table(column).as_deref(), Some(expected));
        }
    }

    #[test]
    fn test_resolver_with_alias() {
        let scope = scope_of("SELECT u.id FROM users AS u");
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // Aliased sources are keyed by their alias.
        assert!(scope.sources.contains_key("u"));
    }

    #[test]
    fn test_sources_for_column() {
        let scope = scope_of("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let id_sources = resolver.sources_for_column("id");
        assert!(id_sources.iter().any(|s| s == "users"));
        assert!(id_sources.iter().any(|s| s == "orders"));

        assert_eq!(resolver.sources_for_column("email"), ["users"]);
    }

    #[test]
    fn test_all_columns() {
        let scope = scope_of("SELECT * FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        for column in ["id", "name", "email"] {
            assert!(all.contains(column));
        }
    }

    #[test]
    fn test_resolver_cte_projected_alias_column() {
        let scope =
            scope_of("WITH my_cte AS (SELECT id AS emp_id FROM users) SELECT emp_id FROM my_cte");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("emp_id").as_deref(), Some("my_cte"));
    }

    #[test]
    fn test_resolve_column_helper() {
        let scope = scope_of("SELECT name FROM users");
        let schema = create_test_schema();

        assert_eq!(
            resolve_column(&scope, &schema, "name", true).as_deref(),
            Some("users")
        );
    }

    #[test]
    fn test_resolver_bigquery_mixed_case_column_names() {
        let expr = Dialect::get(DialectType::BigQuery)
            .parse("SELECT Name AS name FROM teams")
            .unwrap()
            .into_iter()
            .next()
            .expect("expected one expression");
        let scope = build_scope(&expr);

        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
        schema
            .add_table(
                "teams",
                &[("Name".into(), DataType::String { length: None })],
                None,
            )
            .expect("schema setup");

        let mut resolver = Resolver::new(&scope, &schema, true);
        // BigQuery column names are case-insensitive: both spellings resolve.
        for spelling in ["Name", "name"] {
            assert_eq!(resolver.get_table(spelling).as_deref(), Some("teams"));
        }
    }
}