Skip to main content

polyglot_sql/
resolver.rs

1//! Column Resolver Module
2//!
3//! This module provides functionality for resolving column references to their
4//! source tables. It handles:
5//! - Finding which table a column belongs to
6//! - Resolving ambiguous column references
7//! - Handling join context for disambiguation
8//! - Supporting set operations (UNION, INTERSECT, EXCEPT)
9//!
10//! Based on the Python implementation in `sqlglot/optimizer/resolver.py`.
11
12use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier};
14use crate::schema::Schema;
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19/// Errors that can occur during column resolution
20#[derive(Debug, Error, Clone)]
21pub enum ResolverError {
22    #[error("Unknown table: {0}")]
23    UnknownTable(String),
24
25    #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
26    AmbiguousColumn { column: String, sources: String },
27
28    #[error("Column not found: {0}")]
29    ColumnNotFound(String),
30
31    #[error("Unknown set operation: {0}")]
32    UnknownSetOperation(String),
33}
34
35/// Result type for resolver operations
36pub type ResolverResult<T> = Result<T, ResolverError>;
37
38/// Helper for resolving columns to their source tables.
39///
40/// This is a struct so we can lazily load some things and easily share
41/// them across functions.
42pub struct Resolver<'a> {
43    /// The scope being analyzed
44    pub scope: &'a Scope,
45    /// The schema for table/column information
46    schema: &'a dyn Schema,
47    /// The dialect being used
48    pub dialect: Option<DialectType>,
49    /// Whether to infer schema from context
50    infer_schema: bool,
51    /// Cached source columns: source_name -> column names
52    source_columns_cache: HashMap<String, Vec<String>>,
53    /// Cached unambiguous columns: column_name -> source_name
54    unambiguous_columns_cache: Option<HashMap<String, String>>,
55    /// Cached set of all available columns
56    all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60    /// Create a new resolver for a scope
61    pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62        Self {
63            scope,
64            schema,
65            dialect: schema.dialect(),
66            infer_schema,
67            source_columns_cache: HashMap::new(),
68            unambiguous_columns_cache: None,
69            all_columns_cache: None,
70        }
71    }
72
73    /// Get the table for a column name.
74    ///
75    /// Returns the table name if it can be found/inferred.
76    pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77        // Try to find table from all sources (unambiguous lookup)
78        let table_name = self.get_table_name_from_sources(column_name, None);
79
80        // If we found a table, return it
81        if table_name.is_some() {
82            return table_name;
83        }
84
85        // If schema inference is enabled and exactly one source has no schema,
86        // assume the column belongs to that source
87        if self.infer_schema {
88            let sources_without_schema: Vec<_> = self
89                .get_all_source_columns()
90                .iter()
91                .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92                .map(|(name, _)| name.clone())
93                .collect();
94
95            if sources_without_schema.len() == 1 {
96                return Some(sources_without_schema[0].clone());
97            }
98        }
99
100        None
101    }
102
103    /// Get the table for a column, returning an Identifier
104    pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105        self.get_table(column_name).map(Identifier::new)
106    }
107
108    /// Get all available columns across all sources in this scope
109    pub fn all_columns(&mut self) -> &HashSet<String> {
110        if self.all_columns_cache.is_none() {
111            let mut all = HashSet::new();
112            for columns in self.get_all_source_columns().values() {
113                all.extend(columns.iter().cloned());
114            }
115            self.all_columns_cache = Some(all);
116        }
117        self.all_columns_cache
118            .as_ref()
119            .expect("cache populated above")
120    }
121
122    /// Get column names for a source.
123    ///
124    /// Returns the list of column names available from the given source.
125    pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
126        // Check cache first
127        if let Some(columns) = self.source_columns_cache.get(source_name) {
128            return Ok(columns.clone());
129        }
130
131        // Get the source info
132        let source_info = self
133            .scope
134            .sources
135            .get(source_name)
136            .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
137
138        let columns = self.extract_columns_from_source(source_info)?;
139
140        // Cache the result
141        self.source_columns_cache
142            .insert(source_name.to_string(), columns.clone());
143
144        Ok(columns)
145    }
146
147    /// Extract column names from a source expression
148    fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
149        let columns = match &source_info.expression {
150            Expression::Table(table) => {
151                // For tables, try to get columns from schema
152                let table_name = table.name.name.clone();
153                match self.schema.column_names(&table_name) {
154                    Ok(cols) => cols,
155                    Err(_) => Vec::new(), // Schema might not have this table
156                }
157            }
158            Expression::Subquery(subquery) => {
159                // For subqueries, get named_selects from the inner query
160                self.get_named_selects(&subquery.this)
161            }
162            Expression::Select(select) => {
163                // For derived tables that are SELECT expressions
164                self.get_select_column_names(select)
165            }
166            Expression::Union(union) => {
167                // For UNION, columns come from the set operation
168                self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
169            }
170            Expression::Intersect(intersect) => {
171                self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
172            }
173            Expression::Except(except) => {
174                self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
175            }
176            _ => Vec::new(),
177        };
178
179        Ok(columns)
180    }
181
182    /// Get named selects (column names) from an expression
183    fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
184        match expr {
185            Expression::Select(select) => self.get_select_column_names(select),
186            Expression::Union(union) => {
187                // For unions, use the left side's columns
188                self.get_named_selects(&union.left)
189            }
190            Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
191            Expression::Except(except) => self.get_named_selects(&except.left),
192            Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
193            _ => Vec::new(),
194        }
195    }
196
197    /// Get column names from a SELECT expression
198    fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
199        select
200            .expressions
201            .iter()
202            .filter_map(|expr| self.get_expression_alias(expr))
203            .collect()
204    }
205
206    /// Get the alias or name for a select expression
207    fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
208        match expr {
209            Expression::Alias(alias) => Some(alias.alias.name.clone()),
210            Expression::Column(col) => Some(col.name.name.clone()),
211            Expression::Star(_) => Some("*".to_string()),
212            Expression::Identifier(id) => Some(id.name.clone()),
213            _ => None,
214        }
215    }
216
217    /// Get columns from a set operation (UNION, INTERSECT, EXCEPT)
218    pub fn get_source_columns_from_set_op(
219        &self,
220        expression: &Expression,
221    ) -> ResolverResult<Vec<String>> {
222        match expression {
223            Expression::Select(select) => Ok(self.get_select_column_names(select)),
224            Expression::Subquery(subquery) => {
225                if matches!(
226                    &subquery.this,
227                    Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
228                ) {
229                    self.get_source_columns_from_set_op(&subquery.this)
230                } else {
231                    Ok(self.get_named_selects(&subquery.this))
232                }
233            }
234            Expression::Union(union) => {
235                // Standard UNION: columns come from the left side
236                self.get_source_columns_from_set_op(&union.left)
237            }
238            Expression::Intersect(intersect) => {
239                self.get_source_columns_from_set_op(&intersect.left)
240            }
241            Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
242            _ => Err(ResolverError::UnknownSetOperation(format!(
243                "{:?}",
244                expression
245            ))),
246        }
247    }
248
249    /// Get all source columns for all sources in the scope
250    fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
251        let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
252
253        let mut result = HashMap::new();
254        for source_name in source_names {
255            if let Ok(columns) = self.get_source_columns(&source_name) {
256                result.insert(source_name, columns);
257            }
258        }
259        result
260    }
261
262    /// Get the table name for a column from the sources
263    fn get_table_name_from_sources(
264        &mut self,
265        column_name: &str,
266        source_columns: Option<&HashMap<String, Vec<String>>>,
267    ) -> Option<String> {
268        let unambiguous = match source_columns {
269            Some(cols) => self.compute_unambiguous_columns(cols),
270            None => {
271                if self.unambiguous_columns_cache.is_none() {
272                    let all_source_columns = self.get_all_source_columns();
273                    self.unambiguous_columns_cache =
274                        Some(self.compute_unambiguous_columns(&all_source_columns));
275                }
276                self.unambiguous_columns_cache
277                    .clone()
278                    .expect("cache populated above")
279            }
280        };
281
282        unambiguous.get(column_name).cloned()
283    }
284
285    /// Compute unambiguous columns mapping
286    ///
287    /// A column is unambiguous if it appears in exactly one source.
288    fn compute_unambiguous_columns(
289        &self,
290        source_columns: &HashMap<String, Vec<String>>,
291    ) -> HashMap<String, String> {
292        if source_columns.is_empty() {
293            return HashMap::new();
294        }
295
296        let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
297
298        for (source_name, columns) in source_columns {
299            for column in columns {
300                column_to_sources
301                    .entry(column.clone())
302                    .or_default()
303                    .push(source_name.clone());
304            }
305        }
306
307        // Keep only columns that appear in exactly one source
308        column_to_sources
309            .into_iter()
310            .filter(|(_, sources)| sources.len() == 1)
311            .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
312            .collect()
313    }
314
315    /// Check if a column is ambiguous (appears in multiple sources)
316    pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
317        let all_source_columns = self.get_all_source_columns();
318        let sources_with_column: Vec<_> = all_source_columns
319            .iter()
320            .filter(|(_, columns)| columns.contains(&column_name.to_string()))
321            .map(|(name, _)| name.clone())
322            .collect();
323
324        sources_with_column.len() > 1
325    }
326
327    /// Get all sources that contain a given column
328    pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
329        let all_source_columns = self.get_all_source_columns();
330        all_source_columns
331            .iter()
332            .filter(|(_, columns)| columns.contains(&column_name.to_string()))
333            .map(|(name, _)| name.clone())
334            .collect()
335    }
336
337    /// Try to disambiguate a column based on join context
338    ///
339    /// In join conditions, a column can sometimes be disambiguated based on
340    /// which tables have been joined up to that point.
341    pub fn disambiguate_in_join_context(
342        &mut self,
343        column_name: &str,
344        available_sources: &[String],
345    ) -> Option<String> {
346        let mut matching_sources = Vec::new();
347
348        for source_name in available_sources {
349            if let Ok(columns) = self.get_source_columns(source_name) {
350                if columns.contains(&column_name.to_string()) {
351                    matching_sources.push(source_name.clone());
352                }
353            }
354        }
355
356        if matching_sources.len() == 1 {
357            Some(matching_sources.remove(0))
358        } else {
359            None
360        }
361    }
362}
363
364/// Resolve a column to its source table.
365///
366/// This is a convenience function that creates a Resolver and calls get_table.
367pub fn resolve_column(
368    scope: &Scope,
369    schema: &dyn Schema,
370    column_name: &str,
371    infer_schema: bool,
372) -> Option<String> {
373    let mut resolver = Resolver::new(scope, schema, infer_schema);
374    resolver.get_table(column_name)
375}
376
377/// Check if a column is ambiguous in the given scope.
378pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
379    let mut resolver = Resolver::new(scope, schema, true);
380    resolver.is_ambiguous(column_name)
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// The plain `INT` type used for every integer fixture column.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Parse `sql` and build the scope of its first statement.
    fn scope_for(sql: &str) -> crate::scope::Scope {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        build_scope(&statements[0])
    }

    /// Fixture schema:
    /// - `users(id INT, name TEXT, email TEXT)`
    /// - `orders(id INT, user_id INT, amount DOUBLE)`
    fn create_test_schema() -> MappingSchema {
        let users = [
            ("id".to_string(), int_type()),
            ("name".to_string(), DataType::Text),
            ("email".to_string(), DataType::Text),
        ];
        let orders = [
            ("id".to_string(), int_type()),
            ("user_id".to_string(), int_type()),
            (
                "amount".to_string(),
                DataType::Double {
                    precision: None,
                    scale: None,
                },
            ),
        ];

        let mut schema = MappingSchema::new();
        schema.add_table("users", &users, None).unwrap();
        schema.add_table("orders", &orders, None).unwrap();
        schema
    }

    #[test]
    fn test_resolver_basic() {
        let scope = scope_for("SELECT id, name FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `users` is the only source, so `name` must come from it.
        assert_eq!(resolver.get_table("name").as_deref(), Some("users"));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let scope = scope_for("SELECT id FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // Both tables define `id` ...
        assert!(resolver.is_ambiguous("id"));
        // ... whereas `name` (users) and `amount` (orders) each have one owner.
        assert!(!resolver.is_ambiguous("name"));
        assert!(!resolver.is_ambiguous("amount"));
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let scope =
            scope_for("SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("name").as_deref(), Some("users"));
        assert_eq!(resolver.get_table("amount").as_deref(), Some("orders"));
    }

    #[test]
    fn test_resolver_with_alias() {
        let scope = scope_for("SELECT u.id FROM users AS u");
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // The source is keyed by its alias rather than the table name.
        assert!(scope.sources.contains_key("u"));
    }

    #[test]
    fn test_sources_for_column() {
        let scope = scope_for("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let id_sources = resolver.sources_for_column("id");
        for expected in ["users", "orders"] {
            assert!(id_sources.iter().any(|s| s == expected));
        }

        assert_eq!(
            resolver.sources_for_column("email"),
            vec!["users".to_string()]
        );
    }

    #[test]
    fn test_all_columns() {
        let scope = scope_for("SELECT * FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        for column in ["id", "name", "email"] {
            assert!(all.contains(column), "missing column {}", column);
        }
    }

    #[test]
    fn test_resolve_column_helper() {
        let scope = scope_for("SELECT name FROM users");
        let schema = create_test_schema();

        assert_eq!(
            resolve_column(&scope, &schema, "name", true).as_deref(),
            Some("users")
        );
    }
}