Skip to main content

polyglot_sql/
resolver.rs

1//! Column Resolver Module
2//!
3//! This module provides functionality for resolving column references to their
4//! source tables. It handles:
5//! - Finding which table a column belongs to
6//! - Resolving ambiguous column references
7//! - Handling join context for disambiguation
8//! - Supporting set operations (UNION, INTERSECT, EXCEPT)
9//!
10//! Based on the Python implementation in `sqlglot/optimizer/resolver.py`.
11
12use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier};
14use crate::schema::Schema;
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19/// Errors that can occur during column resolution
20#[derive(Debug, Error, Clone)]
21pub enum ResolverError {
22    #[error("Unknown table: {0}")]
23    UnknownTable(String),
24
25    #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
26    AmbiguousColumn { column: String, sources: String },
27
28    #[error("Column not found: {0}")]
29    ColumnNotFound(String),
30
31    #[error("Unknown set operation: {0}")]
32    UnknownSetOperation(String),
33}
34
35/// Result type for resolver operations
36pub type ResolverResult<T> = Result<T, ResolverError>;
37
38/// Helper for resolving columns to their source tables.
39///
40/// This is a struct so we can lazily load some things and easily share
41/// them across functions.
42pub struct Resolver<'a> {
43    /// The scope being analyzed
44    pub scope: &'a Scope,
45    /// The schema for table/column information
46    schema: &'a dyn Schema,
47    /// The dialect being used
48    pub dialect: Option<DialectType>,
49    /// Whether to infer schema from context
50    infer_schema: bool,
51    /// Cached source columns: source_name -> column names
52    source_columns_cache: HashMap<String, Vec<String>>,
53    /// Cached unambiguous columns: column_name -> source_name
54    unambiguous_columns_cache: Option<HashMap<String, String>>,
55    /// Cached set of all available columns
56    all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60    /// Create a new resolver for a scope
61    pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62        Self {
63            scope,
64            schema,
65            dialect: schema.dialect(),
66            infer_schema,
67            source_columns_cache: HashMap::new(),
68            unambiguous_columns_cache: None,
69            all_columns_cache: None,
70        }
71    }
72
73    /// Get the table for a column name.
74    ///
75    /// Returns the table name if it can be found/inferred.
76    pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77        // Try to find table from all sources (unambiguous lookup)
78        let table_name = self.get_table_name_from_sources(column_name, None);
79
80        // If we found a table, return it
81        if table_name.is_some() {
82            return table_name;
83        }
84
85        // If schema inference is enabled and exactly one source has no schema,
86        // assume the column belongs to that source
87        if self.infer_schema {
88            let sources_without_schema: Vec<_> = self
89                .get_all_source_columns()
90                .iter()
91                .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92                .map(|(name, _)| name.clone())
93                .collect();
94
95            if sources_without_schema.len() == 1 {
96                return Some(sources_without_schema[0].clone());
97            }
98        }
99
100        None
101    }
102
103    /// Get the table for a column, returning an Identifier
104    pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105        self.get_table(column_name).map(Identifier::new)
106    }
107
108    /// Get all available columns across all sources in this scope
109    pub fn all_columns(&mut self) -> &HashSet<String> {
110        if self.all_columns_cache.is_none() {
111            let mut all = HashSet::new();
112            for columns in self.get_all_source_columns().values() {
113                all.extend(columns.iter().cloned());
114            }
115            self.all_columns_cache = Some(all);
116        }
117        self.all_columns_cache.as_ref().expect("cache populated above")
118    }
119
120    /// Get column names for a source.
121    ///
122    /// Returns the list of column names available from the given source.
123    pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
124        // Check cache first
125        if let Some(columns) = self.source_columns_cache.get(source_name) {
126            return Ok(columns.clone());
127        }
128
129        // Get the source info
130        let source_info = self
131            .scope
132            .sources
133            .get(source_name)
134            .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
135
136        let columns = self.extract_columns_from_source(source_info)?;
137
138        // Cache the result
139        self.source_columns_cache
140            .insert(source_name.to_string(), columns.clone());
141
142        Ok(columns)
143    }
144
145    /// Extract column names from a source expression
146    fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
147        let columns = match &source_info.expression {
148            Expression::Table(table) => {
149                // For tables, try to get columns from schema
150                let table_name = table.name.name.clone();
151                match self.schema.column_names(&table_name) {
152                    Ok(cols) => cols,
153                    Err(_) => Vec::new(), // Schema might not have this table
154                }
155            }
156            Expression::Subquery(subquery) => {
157                // For subqueries, get named_selects from the inner query
158                self.get_named_selects(&subquery.this)
159            }
160            Expression::Select(select) => {
161                // For derived tables that are SELECT expressions
162                self.get_select_column_names(select)
163            }
164            Expression::Union(union) => {
165                // For UNION, columns come from the set operation
166                self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
167            }
168            Expression::Intersect(intersect) => {
169                self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
170            }
171            Expression::Except(except) => {
172                self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
173            }
174            _ => Vec::new(),
175        };
176
177        Ok(columns)
178    }
179
180    /// Get named selects (column names) from an expression
181    fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
182        match expr {
183            Expression::Select(select) => self.get_select_column_names(select),
184            Expression::Union(union) => {
185                // For unions, use the left side's columns
186                self.get_named_selects(&union.left)
187            }
188            Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
189            Expression::Except(except) => self.get_named_selects(&except.left),
190            Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
191            _ => Vec::new(),
192        }
193    }
194
195    /// Get column names from a SELECT expression
196    fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
197        select
198            .expressions
199            .iter()
200            .filter_map(|expr| self.get_expression_alias(expr))
201            .collect()
202    }
203
204    /// Get the alias or name for a select expression
205    fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
206        match expr {
207            Expression::Alias(alias) => Some(alias.alias.name.clone()),
208            Expression::Column(col) => Some(col.name.name.clone()),
209            Expression::Star(_) => Some("*".to_string()),
210            Expression::Identifier(id) => Some(id.name.clone()),
211            _ => None,
212        }
213    }
214
215    /// Get columns from a set operation (UNION, INTERSECT, EXCEPT)
216    pub fn get_source_columns_from_set_op(
217        &self,
218        expression: &Expression,
219    ) -> ResolverResult<Vec<String>> {
220        match expression {
221            Expression::Select(select) => Ok(self.get_select_column_names(select)),
222            Expression::Subquery(subquery) => {
223                if matches!(
224                    &subquery.this,
225                    Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
226                ) {
227                    self.get_source_columns_from_set_op(&subquery.this)
228                } else {
229                    Ok(self.get_named_selects(&subquery.this))
230                }
231            }
232            Expression::Union(union) => {
233                // Standard UNION: columns come from the left side
234                self.get_source_columns_from_set_op(&union.left)
235            }
236            Expression::Intersect(intersect) => {
237                self.get_source_columns_from_set_op(&intersect.left)
238            }
239            Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
240            _ => Err(ResolverError::UnknownSetOperation(format!(
241                "{:?}",
242                expression
243            ))),
244        }
245    }
246
247    /// Get all source columns for all sources in the scope
248    fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
249        let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
250
251        let mut result = HashMap::new();
252        for source_name in source_names {
253            if let Ok(columns) = self.get_source_columns(&source_name) {
254                result.insert(source_name, columns);
255            }
256        }
257        result
258    }
259
260    /// Get the table name for a column from the sources
261    fn get_table_name_from_sources(
262        &mut self,
263        column_name: &str,
264        source_columns: Option<&HashMap<String, Vec<String>>>,
265    ) -> Option<String> {
266        let unambiguous = match source_columns {
267            Some(cols) => self.compute_unambiguous_columns(cols),
268            None => {
269                if self.unambiguous_columns_cache.is_none() {
270                    let all_source_columns = self.get_all_source_columns();
271                    self.unambiguous_columns_cache =
272                        Some(self.compute_unambiguous_columns(&all_source_columns));
273                }
274                self.unambiguous_columns_cache.clone().expect("cache populated above")
275            }
276        };
277
278        unambiguous.get(column_name).cloned()
279    }
280
281    /// Compute unambiguous columns mapping
282    ///
283    /// A column is unambiguous if it appears in exactly one source.
284    fn compute_unambiguous_columns(
285        &self,
286        source_columns: &HashMap<String, Vec<String>>,
287    ) -> HashMap<String, String> {
288        if source_columns.is_empty() {
289            return HashMap::new();
290        }
291
292        let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
293
294        for (source_name, columns) in source_columns {
295            for column in columns {
296                column_to_sources
297                    .entry(column.clone())
298                    .or_default()
299                    .push(source_name.clone());
300            }
301        }
302
303        // Keep only columns that appear in exactly one source
304        column_to_sources
305            .into_iter()
306            .filter(|(_, sources)| sources.len() == 1)
307            .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
308            .collect()
309    }
310
311    /// Check if a column is ambiguous (appears in multiple sources)
312    pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
313        let all_source_columns = self.get_all_source_columns();
314        let sources_with_column: Vec<_> = all_source_columns
315            .iter()
316            .filter(|(_, columns)| columns.contains(&column_name.to_string()))
317            .map(|(name, _)| name.clone())
318            .collect();
319
320        sources_with_column.len() > 1
321    }
322
323    /// Get all sources that contain a given column
324    pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
325        let all_source_columns = self.get_all_source_columns();
326        all_source_columns
327            .iter()
328            .filter(|(_, columns)| columns.contains(&column_name.to_string()))
329            .map(|(name, _)| name.clone())
330            .collect()
331    }
332
333    /// Try to disambiguate a column based on join context
334    ///
335    /// In join conditions, a column can sometimes be disambiguated based on
336    /// which tables have been joined up to that point.
337    pub fn disambiguate_in_join_context(
338        &mut self,
339        column_name: &str,
340        available_sources: &[String],
341    ) -> Option<String> {
342        let mut matching_sources = Vec::new();
343
344        for source_name in available_sources {
345            if let Ok(columns) = self.get_source_columns(source_name) {
346                if columns.contains(&column_name.to_string()) {
347                    matching_sources.push(source_name.clone());
348                }
349            }
350        }
351
352        if matching_sources.len() == 1 {
353            Some(matching_sources.remove(0))
354        } else {
355            None
356        }
357    }
358}
359
360/// Resolve a column to its source table.
361///
362/// This is a convenience function that creates a Resolver and calls get_table.
363pub fn resolve_column(
364    scope: &Scope,
365    schema: &dyn Schema,
366    column_name: &str,
367    infer_schema: bool,
368) -> Option<String> {
369    let mut resolver = Resolver::new(scope, schema, infer_schema);
370    resolver.get_table(column_name)
371}
372
373/// Check if a column is ambiguous in the given scope.
374pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
375    let mut resolver = Resolver::new(scope, schema, true);
376    resolver.is_ambiguous(column_name)
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// Schema with `users(id, name, email)` and `orders(id, user_id, amount)`;
    /// `id` is deliberately shared so ambiguity can be exercised.
    fn create_test_schema() -> MappingSchema {
        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "users",
                &[
                    ("id".to_string(), DataType::Int { length: None, integer_spelling: false }),
                    ("name".to_string(), DataType::Text),
                    ("email".to_string(), DataType::Text),
                ],
                None,
            )
            .unwrap();
        schema
            .add_table(
                "orders",
                &[
                    ("id".to_string(), DataType::Int { length: None, integer_spelling: false }),
                    ("user_id".to_string(), DataType::Int { length: None, integer_spelling: false }),
                    ("amount".to_string(), DataType::Double { precision: None, scale: None }),
                ],
                None,
            )
            .unwrap();
        schema
    }

    #[test]
    fn test_resolver_basic() {
        let ast = Parser::parse_sql("SELECT id, name FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // 'name' should resolve to 'users' since it's the only source
        let table = resolver.get_table("name");
        assert_eq!(table, Some("users".to_string()));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let ast =
            Parser::parse_sql("SELECT id FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // 'id' appears in both tables, so it's ambiguous
        assert!(resolver.is_ambiguous("id"));

        // 'name' only appears in users
        assert!(!resolver.is_ambiguous("name"));

        // 'amount' only appears in orders
        assert!(!resolver.is_ambiguous("amount"));
    }

    #[test]
    fn test_get_table_ambiguous_returns_none() {
        let ast =
            Parser::parse_sql("SELECT id FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // Both sources have known schemas, so inference cannot pick one either.
        assert_eq!(resolver.get_table("id"), None);
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let ast = Parser::parse_sql(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
        )
        .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // 'name' should resolve to 'users'
        let table = resolver.get_table("name");
        assert_eq!(table, Some("users".to_string()));

        // 'amount' should resolve to 'orders'
        let table = resolver.get_table("amount");
        assert_eq!(table, Some("orders".to_string()));
    }

    #[test]
    fn test_get_table_identifier() {
        let ast = Parser::parse_sql("SELECT email FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let identifier = resolver.get_table_identifier("email");
        assert_eq!(identifier.map(|id| id.name), Some("users".to_string()));
    }

    #[test]
    fn test_resolver_with_alias() {
        let ast = Parser::parse_sql("SELECT u.id FROM users AS u").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // Source should be indexed by alias 'u'
        assert!(scope.sources.contains_key("u"));
    }

    #[test]
    fn test_sources_for_column() {
        let ast =
            Parser::parse_sql("SELECT * FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // 'id' should be in both users and orders
        let sources = resolver.sources_for_column("id");
        assert!(sources.contains(&"users".to_string()));
        assert!(sources.contains(&"orders".to_string()));

        // 'email' should only be in users
        let sources = resolver.sources_for_column("email");
        assert_eq!(sources, vec!["users".to_string()]);
    }

    #[test]
    fn test_disambiguate_in_join_context() {
        let ast =
            Parser::parse_sql("SELECT * FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let only_users = vec!["users".to_string()];
        let both = vec!["users".to_string(), "orders".to_string()];

        // Restricting to the sources joined so far makes 'id' unambiguous.
        assert_eq!(
            resolver.disambiguate_in_join_context("id", &only_users),
            Some("users".to_string())
        );
        // With both sources visible, 'id' stays ambiguous.
        assert_eq!(resolver.disambiguate_in_join_context("id", &both), None);
        // A column absent from every visible source resolves to nothing.
        assert_eq!(resolver.disambiguate_in_join_context("amount", &only_users), None);
        assert_eq!(
            resolver.disambiguate_in_join_context("amount", &both),
            Some("orders".to_string())
        );
    }

    #[test]
    fn test_get_source_columns_unknown_source() {
        let ast = Parser::parse_sql("SELECT id FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert!(matches!(
            resolver.get_source_columns("missing"),
            Err(ResolverError::UnknownTable(name)) if name == "missing"
        ));
    }

    #[test]
    fn test_all_columns() {
        let ast = Parser::parse_sql("SELECT * FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        assert!(all.contains("id"));
        assert!(all.contains("name"));
        assert!(all.contains("email"));
    }

    #[test]
    fn test_resolve_column_helper() {
        let ast = Parser::parse_sql("SELECT name FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();

        let table = resolve_column(&scope, &schema, "name", true);
        assert_eq!(table, Some("users".to_string()));
    }

    #[test]
    fn test_is_column_ambiguous_helper() {
        let ast =
            Parser::parse_sql("SELECT * FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();

        assert!(is_column_ambiguous(&scope, &schema, "id"));
        assert!(!is_column_ambiguous(&scope, &schema, "email"));
    }
}