// polyglot_sql/resolver.rs

//! Column Resolver Module
//!
//! This module resolves column references to their source tables. It handles:
//! - Finding which table a column belongs to
//! - Resolving ambiguous column references
//! - Handling join context for disambiguation
//! - Supporting set operations (UNION, INTERSECT, EXCEPT)
//!
//! Based on the Python implementation in `sqlglot/optimizer/resolver.py`.

12use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier, TableRef};
14use crate::schema::Schema;
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19/// Errors that can occur during column resolution
20#[derive(Debug, Error, Clone)]
21pub enum ResolverError {
22    #[error("Unknown table: {0}")]
23    UnknownTable(String),
24
25    #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
26    AmbiguousColumn { column: String, sources: String },
27
28    #[error("Column not found: {0}")]
29    ColumnNotFound(String),
30
31    #[error("Unknown set operation: {0}")]
32    UnknownSetOperation(String),
33}
34
35/// Result type for resolver operations
36pub type ResolverResult<T> = Result<T, ResolverError>;
37
38/// Helper for resolving columns to their source tables.
39///
40/// This is a struct so we can lazily load some things and easily share
41/// them across functions.
42pub struct Resolver<'a> {
43    /// The scope being analyzed
44    pub scope: &'a Scope,
45    /// The schema for table/column information
46    schema: &'a dyn Schema,
47    /// The dialect being used
48    pub dialect: Option<DialectType>,
49    /// Whether to infer schema from context
50    infer_schema: bool,
51    /// Cached source columns: source_name -> column names
52    source_columns_cache: HashMap<String, Vec<String>>,
53    /// Cached unambiguous columns: column_name -> source_name
54    unambiguous_columns_cache: Option<HashMap<String, String>>,
55    /// Cached set of all available columns
56    all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60    /// Create a new resolver for a scope
61    pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62        Self {
63            scope,
64            schema,
65            dialect: schema.dialect(),
66            infer_schema,
67            source_columns_cache: HashMap::new(),
68            unambiguous_columns_cache: None,
69            all_columns_cache: None,
70        }
71    }
72
73    /// Get the table for a column name.
74    ///
75    /// Returns the table name if it can be found/inferred.
76    pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77        // Try to find table from all sources (unambiguous lookup)
78        let table_name = self.get_table_name_from_sources(column_name, None);
79
80        // If we found a table, return it
81        if table_name.is_some() {
82            return table_name;
83        }
84
85        // If schema inference is enabled and exactly one source has no schema,
86        // assume the column belongs to that source
87        if self.infer_schema {
88            let sources_without_schema: Vec<_> = self
89                .get_all_source_columns()
90                .iter()
91                .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92                .map(|(name, _)| name.clone())
93                .collect();
94
95            if sources_without_schema.len() == 1 {
96                return Some(sources_without_schema[0].clone());
97            }
98        }
99
100        None
101    }
102
103    /// Get the table for a column, returning an Identifier
104    pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105        self.get_table(column_name).map(Identifier::new)
106    }
107
108    /// Check if a table exists in the schema (not necessarily in the current scope).
109    /// Used to detect correlated references to outer scope tables.
110    pub fn table_exists_in_schema(&self, table_name: &str) -> bool {
111        self.schema.column_names(table_name).is_ok()
112    }
113
114    /// Get all available columns across all sources in this scope
115    pub fn all_columns(&mut self) -> &HashSet<String> {
116        if self.all_columns_cache.is_none() {
117            let mut all = HashSet::new();
118            for columns in self.get_all_source_columns().values() {
119                all.extend(columns.iter().cloned());
120            }
121            self.all_columns_cache = Some(all);
122        }
123        self.all_columns_cache
124            .as_ref()
125            .expect("cache populated above")
126    }
127
128    /// Get column names for a source.
129    ///
130    /// Returns the list of column names available from the given source.
131    pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
132        // Check cache first
133        if let Some(columns) = self.source_columns_cache.get(source_name) {
134            return Ok(columns.clone());
135        }
136
137        // Get the source info
138        let source_info = self
139            .scope
140            .sources
141            .get(source_name)
142            .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
143
144        let columns = self.extract_columns_from_source(source_info)?;
145
146        // Cache the result
147        self.source_columns_cache
148            .insert(source_name.to_string(), columns.clone());
149
150        Ok(columns)
151    }
152
153    /// Extract column names from a source expression
154    fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
155        let columns = match &source_info.expression {
156            Expression::Table(table) => {
157                // For tables, try to get columns from schema.
158                // Build the fully qualified name (catalog.schema.table) to
159                // match how MappingSchema stores hierarchical keys.
160                let table_name = qualified_table_name(table);
161                match self.schema.column_names(&table_name) {
162                    Ok(cols) => cols,
163                    Err(_) => Vec::new(), // Schema might not have this table
164                }
165            }
166            Expression::Subquery(subquery) => {
167                // For subqueries, get named_selects from the inner query
168                self.get_named_selects(&subquery.this)
169            }
170            Expression::Select(select) => {
171                // For derived tables that are SELECT expressions
172                self.get_select_column_names(select)
173            }
174            Expression::Union(union) => {
175                // For UNION, columns come from the set operation
176                self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
177            }
178            Expression::Intersect(intersect) => {
179                self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
180            }
181            Expression::Except(except) => {
182                self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
183            }
184            Expression::Cte(cte) => {
185                if !cte.columns.is_empty() {
186                    cte.columns.iter().map(|c| c.name.clone()).collect()
187                } else {
188                    self.get_named_selects(&cte.this)
189                }
190            }
191            _ => Vec::new(),
192        };
193
194        Ok(columns)
195    }
196
197    /// Get named selects (column names) from an expression
198    fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
199        match expr {
200            Expression::Select(select) => self.get_select_column_names(select),
201            Expression::Union(union) => {
202                // For unions, use the left side's columns
203                self.get_named_selects(&union.left)
204            }
205            Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
206            Expression::Except(except) => self.get_named_selects(&except.left),
207            Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
208            _ => Vec::new(),
209        }
210    }
211
212    /// Get column names from a SELECT expression
213    fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
214        select
215            .expressions
216            .iter()
217            .filter_map(|expr| self.get_expression_alias(expr))
218            .collect()
219    }
220
221    /// Get the alias or name for a select expression
222    fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
223        match expr {
224            Expression::Alias(alias) => Some(alias.alias.name.clone()),
225            Expression::Column(col) => Some(col.name.name.clone()),
226            Expression::Star(_) => Some("*".to_string()),
227            Expression::Identifier(id) => Some(id.name.clone()),
228            _ => None,
229        }
230    }
231
232    /// Get columns from a set operation (UNION, INTERSECT, EXCEPT)
233    pub fn get_source_columns_from_set_op(
234        &self,
235        expression: &Expression,
236    ) -> ResolverResult<Vec<String>> {
237        match expression {
238            Expression::Select(select) => Ok(self.get_select_column_names(select)),
239            Expression::Subquery(subquery) => {
240                if matches!(
241                    &subquery.this,
242                    Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
243                ) {
244                    self.get_source_columns_from_set_op(&subquery.this)
245                } else {
246                    Ok(self.get_named_selects(&subquery.this))
247                }
248            }
249            Expression::Union(union) => {
250                // Standard UNION: columns come from the left side
251                self.get_source_columns_from_set_op(&union.left)
252            }
253            Expression::Intersect(intersect) => {
254                self.get_source_columns_from_set_op(&intersect.left)
255            }
256            Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
257            _ => Err(ResolverError::UnknownSetOperation(format!(
258                "{:?}",
259                expression
260            ))),
261        }
262    }
263
264    /// Get all source columns for all sources in the scope
265    fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
266        let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
267
268        let mut result = HashMap::new();
269        for source_name in source_names {
270            if let Ok(columns) = self.get_source_columns(&source_name) {
271                result.insert(source_name, columns);
272            }
273        }
274        result
275    }
276
277    /// Get the table name for a column from the sources
278    fn get_table_name_from_sources(
279        &mut self,
280        column_name: &str,
281        source_columns: Option<&HashMap<String, Vec<String>>>,
282    ) -> Option<String> {
283        let unambiguous = match source_columns {
284            Some(cols) => self.compute_unambiguous_columns(cols),
285            None => {
286                if self.unambiguous_columns_cache.is_none() {
287                    let all_source_columns = self.get_all_source_columns();
288                    self.unambiguous_columns_cache =
289                        Some(self.compute_unambiguous_columns(&all_source_columns));
290                }
291                self.unambiguous_columns_cache
292                    .clone()
293                    .expect("cache populated above")
294            }
295        };
296
297        unambiguous.get(column_name).cloned()
298    }
299
300    /// Compute unambiguous columns mapping
301    ///
302    /// A column is unambiguous if it appears in exactly one source.
303    fn compute_unambiguous_columns(
304        &self,
305        source_columns: &HashMap<String, Vec<String>>,
306    ) -> HashMap<String, String> {
307        if source_columns.is_empty() {
308            return HashMap::new();
309        }
310
311        let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
312
313        for (source_name, columns) in source_columns {
314            for column in columns {
315                column_to_sources
316                    .entry(column.clone())
317                    .or_default()
318                    .push(source_name.clone());
319            }
320        }
321
322        // Keep only columns that appear in exactly one source
323        column_to_sources
324            .into_iter()
325            .filter(|(_, sources)| sources.len() == 1)
326            .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
327            .collect()
328    }
329
330    /// Check if a column is ambiguous (appears in multiple sources)
331    pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
332        let all_source_columns = self.get_all_source_columns();
333        let sources_with_column: Vec<_> = all_source_columns
334            .iter()
335            .filter(|(_, columns)| columns.contains(&column_name.to_string()))
336            .map(|(name, _)| name.clone())
337            .collect();
338
339        sources_with_column.len() > 1
340    }
341
342    /// Get all sources that contain a given column
343    pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
344        let all_source_columns = self.get_all_source_columns();
345        all_source_columns
346            .iter()
347            .filter(|(_, columns)| columns.contains(&column_name.to_string()))
348            .map(|(name, _)| name.clone())
349            .collect()
350    }
351
352    /// Try to disambiguate a column based on join context
353    ///
354    /// In join conditions, a column can sometimes be disambiguated based on
355    /// which tables have been joined up to that point.
356    pub fn disambiguate_in_join_context(
357        &mut self,
358        column_name: &str,
359        available_sources: &[String],
360    ) -> Option<String> {
361        let mut matching_sources = Vec::new();
362
363        for source_name in available_sources {
364            if let Ok(columns) = self.get_source_columns(source_name) {
365                if columns.contains(&column_name.to_string()) {
366                    matching_sources.push(source_name.clone());
367                }
368            }
369        }
370
371        if matching_sources.len() == 1 {
372            Some(matching_sources.remove(0))
373        } else {
374            None
375        }
376    }
377}
378
379/// Resolve a column to its source table.
380///
381/// This is a convenience function that creates a Resolver and calls get_table.
382pub fn resolve_column(
383    scope: &Scope,
384    schema: &dyn Schema,
385    column_name: &str,
386    infer_schema: bool,
387) -> Option<String> {
388    let mut resolver = Resolver::new(scope, schema, infer_schema);
389    resolver.get_table(column_name)
390}
391
392/// Check if a column is ambiguous in the given scope.
393pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
394    let mut resolver = Resolver::new(scope, schema, true);
395    resolver.is_ambiguous(column_name)
396}
397
398/// Build the fully qualified table name (catalog.schema.table) from a TableRef.
399fn qualified_table_name(table: &TableRef) -> String {
400    let mut parts = Vec::new();
401    if let Some(catalog) = &table.catalog {
402        parts.push(catalog.name.clone());
403    }
404    if let Some(schema) = &table.schema {
405        parts.push(schema.name.clone());
406    }
407    parts.push(table.name.name.clone());
408    parts.join(".")
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// `(name, INT)` column definition.
    fn int_column(name: &str) -> (String, DataType) {
        (
            name.to_string(),
            DataType::Int {
                length: None,
                integer_spelling: false,
            },
        )
    }

    /// `(name, TEXT)` column definition.
    fn text_column(name: &str) -> (String, DataType) {
        (name.to_string(), DataType::Text)
    }

    /// `(name, DOUBLE)` column definition.
    fn double_column(name: &str) -> (String, DataType) {
        (
            name.to_string(),
            DataType::Double {
                precision: None,
                scale: None,
            },
        )
    }

    /// Schema with `users(id, name, email)` and `orders(id, user_id, amount)`.
    fn create_test_schema() -> MappingSchema {
        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "users",
                &[int_column("id"), text_column("name"), text_column("email")],
                None,
            )
            .unwrap();
        schema
            .add_table(
                "orders",
                &[
                    int_column("id"),
                    int_column("user_id"),
                    double_column("amount"),
                ],
                None,
            )
            .unwrap();
        schema
    }

    /// Parse `sql`, build the scope of its first statement, and run `check`
    /// against a schema-inferring resolver over the test schema.
    fn with_resolver<R>(sql: &str, check: impl FnOnce(&mut Resolver<'_>) -> R) -> R {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);
        check(&mut resolver)
    }

    #[test]
    fn test_resolver_basic() {
        with_resolver("SELECT id, name FROM users", |resolver| {
            // `users` is the only source, so `name` must come from it.
            assert_eq!(resolver.get_table("name"), Some("users".to_string()));
        });
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        with_resolver(
            "SELECT id FROM users JOIN orders ON users.id = orders.user_id",
            |resolver| {
                // `id` exists in both tables.
                assert!(resolver.is_ambiguous("id"));
                // `name` is users-only; `amount` is orders-only.
                assert!(!resolver.is_ambiguous("name"));
                assert!(!resolver.is_ambiguous("amount"));
            },
        );
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        with_resolver(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
            |resolver| {
                assert_eq!(resolver.get_table("name"), Some("users".to_string()));
                assert_eq!(resolver.get_table("amount"), Some("orders".to_string()));
            },
        );
    }

    #[test]
    fn test_resolver_with_alias() {
        with_resolver("SELECT u.id FROM users AS u", |resolver| {
            // The source is registered under its alias, not the table name.
            assert!(resolver.scope.sources.contains_key("u"));
        });
    }

    #[test]
    fn test_sources_for_column() {
        with_resolver(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
            |resolver| {
                let id_sources = resolver.sources_for_column("id");
                assert!(id_sources.contains(&"users".to_string()));
                assert!(id_sources.contains(&"orders".to_string()));

                let email_sources = resolver.sources_for_column("email");
                assert_eq!(email_sources, vec!["users".to_string()]);
            },
        );
    }

    #[test]
    fn test_all_columns() {
        with_resolver("SELECT * FROM users", |resolver| {
            let all = resolver.all_columns();
            for expected in ["id", "name", "email"] {
                assert!(all.contains(expected));
            }
        });
    }

    #[test]
    fn test_resolver_cte_projected_alias_column() {
        with_resolver(
            "WITH my_cte AS (SELECT id AS emp_id FROM users) SELECT emp_id FROM my_cte",
            |resolver| {
                assert_eq!(resolver.get_table("emp_id"), Some("my_cte".to_string()));
            },
        );
    }

    #[test]
    fn test_resolve_column_helper() {
        let statements = Parser::parse_sql("SELECT name FROM users").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();

        assert_eq!(
            resolve_column(&scope, &schema, "name", true),
            Some("users".to_string())
        );
    }
}