Skip to main content

polyglot_sql/
resolver.rs

1//! Column Resolver Module
2//!
3//! This module provides functionality for resolving column references to their
4//! source tables. It handles:
5//! - Finding which table a column belongs to
6//! - Resolving ambiguous column references
7//! - Handling join context for disambiguation
8//! - Supporting set operations (UNION, INTERSECT, EXCEPT)
9//!
10//! Based on the Python implementation in `sqlglot/optimizer/resolver.py`.
11
12use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier};
14use crate::schema::Schema;
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19/// Errors that can occur during column resolution
20#[derive(Debug, Error, Clone)]
21pub enum ResolverError {
22    #[error("Unknown table: {0}")]
23    UnknownTable(String),
24
25    #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
26    AmbiguousColumn { column: String, sources: String },
27
28    #[error("Column not found: {0}")]
29    ColumnNotFound(String),
30
31    #[error("Unknown set operation: {0}")]
32    UnknownSetOperation(String),
33}
34
35/// Result type for resolver operations
36pub type ResolverResult<T> = Result<T, ResolverError>;
37
38/// Helper for resolving columns to their source tables.
39///
40/// This is a struct so we can lazily load some things and easily share
41/// them across functions.
42pub struct Resolver<'a> {
43    /// The scope being analyzed
44    pub scope: &'a Scope,
45    /// The schema for table/column information
46    schema: &'a dyn Schema,
47    /// The dialect being used
48    pub dialect: Option<DialectType>,
49    /// Whether to infer schema from context
50    infer_schema: bool,
51    /// Cached source columns: source_name -> column names
52    source_columns_cache: HashMap<String, Vec<String>>,
53    /// Cached unambiguous columns: column_name -> source_name
54    unambiguous_columns_cache: Option<HashMap<String, String>>,
55    /// Cached set of all available columns
56    all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60    /// Create a new resolver for a scope
61    pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62        Self {
63            scope,
64            schema,
65            dialect: schema.dialect(),
66            infer_schema,
67            source_columns_cache: HashMap::new(),
68            unambiguous_columns_cache: None,
69            all_columns_cache: None,
70        }
71    }
72
73    /// Get the table for a column name.
74    ///
75    /// Returns the table name if it can be found/inferred.
76    pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77        // Try to find table from all sources (unambiguous lookup)
78        let table_name = self.get_table_name_from_sources(column_name, None);
79
80        // If we found a table, return it
81        if table_name.is_some() {
82            return table_name;
83        }
84
85        // If schema inference is enabled and exactly one source has no schema,
86        // assume the column belongs to that source
87        if self.infer_schema {
88            let sources_without_schema: Vec<_> = self
89                .get_all_source_columns()
90                .iter()
91                .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92                .map(|(name, _)| name.clone())
93                .collect();
94
95            if sources_without_schema.len() == 1 {
96                return Some(sources_without_schema[0].clone());
97            }
98        }
99
100        None
101    }
102
103    /// Get the table for a column, returning an Identifier
104    pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105        self.get_table(column_name).map(Identifier::new)
106    }
107
108    /// Get all available columns across all sources in this scope
109    pub fn all_columns(&mut self) -> &HashSet<String> {
110        if self.all_columns_cache.is_none() {
111            let mut all = HashSet::new();
112            for columns in self.get_all_source_columns().values() {
113                all.extend(columns.iter().cloned());
114            }
115            self.all_columns_cache = Some(all);
116        }
117        self.all_columns_cache
118            .as_ref()
119            .expect("cache populated above")
120    }
121
122    /// Get column names for a source.
123    ///
124    /// Returns the list of column names available from the given source.
125    pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
126        // Check cache first
127        if let Some(columns) = self.source_columns_cache.get(source_name) {
128            return Ok(columns.clone());
129        }
130
131        // Get the source info
132        let source_info = self
133            .scope
134            .sources
135            .get(source_name)
136            .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
137
138        let columns = self.extract_columns_from_source(source_info)?;
139
140        // Cache the result
141        self.source_columns_cache
142            .insert(source_name.to_string(), columns.clone());
143
144        Ok(columns)
145    }
146
147    /// Extract column names from a source expression
148    fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
149        let columns = match &source_info.expression {
150            Expression::Table(table) => {
151                // For tables, try to get columns from schema
152                let table_name = table.name.name.clone();
153                match self.schema.column_names(&table_name) {
154                    Ok(cols) => cols,
155                    Err(_) => Vec::new(), // Schema might not have this table
156                }
157            }
158            Expression::Subquery(subquery) => {
159                // For subqueries, get named_selects from the inner query
160                self.get_named_selects(&subquery.this)
161            }
162            Expression::Select(select) => {
163                // For derived tables that are SELECT expressions
164                self.get_select_column_names(select)
165            }
166            Expression::Union(union) => {
167                // For UNION, columns come from the set operation
168                self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
169            }
170            Expression::Intersect(intersect) => {
171                self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
172            }
173            Expression::Except(except) => {
174                self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
175            }
176            Expression::Cte(cte) => {
177                if !cte.columns.is_empty() {
178                    cte.columns.iter().map(|c| c.name.clone()).collect()
179                } else {
180                    self.get_named_selects(&cte.this)
181                }
182            }
183            _ => Vec::new(),
184        };
185
186        Ok(columns)
187    }
188
189    /// Get named selects (column names) from an expression
190    fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
191        match expr {
192            Expression::Select(select) => self.get_select_column_names(select),
193            Expression::Union(union) => {
194                // For unions, use the left side's columns
195                self.get_named_selects(&union.left)
196            }
197            Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
198            Expression::Except(except) => self.get_named_selects(&except.left),
199            Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
200            _ => Vec::new(),
201        }
202    }
203
204    /// Get column names from a SELECT expression
205    fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
206        select
207            .expressions
208            .iter()
209            .filter_map(|expr| self.get_expression_alias(expr))
210            .collect()
211    }
212
213    /// Get the alias or name for a select expression
214    fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
215        match expr {
216            Expression::Alias(alias) => Some(alias.alias.name.clone()),
217            Expression::Column(col) => Some(col.name.name.clone()),
218            Expression::Star(_) => Some("*".to_string()),
219            Expression::Identifier(id) => Some(id.name.clone()),
220            _ => None,
221        }
222    }
223
224    /// Get columns from a set operation (UNION, INTERSECT, EXCEPT)
225    pub fn get_source_columns_from_set_op(
226        &self,
227        expression: &Expression,
228    ) -> ResolverResult<Vec<String>> {
229        match expression {
230            Expression::Select(select) => Ok(self.get_select_column_names(select)),
231            Expression::Subquery(subquery) => {
232                if matches!(
233                    &subquery.this,
234                    Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
235                ) {
236                    self.get_source_columns_from_set_op(&subquery.this)
237                } else {
238                    Ok(self.get_named_selects(&subquery.this))
239                }
240            }
241            Expression::Union(union) => {
242                // Standard UNION: columns come from the left side
243                self.get_source_columns_from_set_op(&union.left)
244            }
245            Expression::Intersect(intersect) => {
246                self.get_source_columns_from_set_op(&intersect.left)
247            }
248            Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
249            _ => Err(ResolverError::UnknownSetOperation(format!(
250                "{:?}",
251                expression
252            ))),
253        }
254    }
255
256    /// Get all source columns for all sources in the scope
257    fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
258        let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
259
260        let mut result = HashMap::new();
261        for source_name in source_names {
262            if let Ok(columns) = self.get_source_columns(&source_name) {
263                result.insert(source_name, columns);
264            }
265        }
266        result
267    }
268
269    /// Get the table name for a column from the sources
270    fn get_table_name_from_sources(
271        &mut self,
272        column_name: &str,
273        source_columns: Option<&HashMap<String, Vec<String>>>,
274    ) -> Option<String> {
275        let unambiguous = match source_columns {
276            Some(cols) => self.compute_unambiguous_columns(cols),
277            None => {
278                if self.unambiguous_columns_cache.is_none() {
279                    let all_source_columns = self.get_all_source_columns();
280                    self.unambiguous_columns_cache =
281                        Some(self.compute_unambiguous_columns(&all_source_columns));
282                }
283                self.unambiguous_columns_cache
284                    .clone()
285                    .expect("cache populated above")
286            }
287        };
288
289        unambiguous.get(column_name).cloned()
290    }
291
292    /// Compute unambiguous columns mapping
293    ///
294    /// A column is unambiguous if it appears in exactly one source.
295    fn compute_unambiguous_columns(
296        &self,
297        source_columns: &HashMap<String, Vec<String>>,
298    ) -> HashMap<String, String> {
299        if source_columns.is_empty() {
300            return HashMap::new();
301        }
302
303        let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
304
305        for (source_name, columns) in source_columns {
306            for column in columns {
307                column_to_sources
308                    .entry(column.clone())
309                    .or_default()
310                    .push(source_name.clone());
311            }
312        }
313
314        // Keep only columns that appear in exactly one source
315        column_to_sources
316            .into_iter()
317            .filter(|(_, sources)| sources.len() == 1)
318            .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
319            .collect()
320    }
321
322    /// Check if a column is ambiguous (appears in multiple sources)
323    pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
324        let all_source_columns = self.get_all_source_columns();
325        let sources_with_column: Vec<_> = all_source_columns
326            .iter()
327            .filter(|(_, columns)| columns.contains(&column_name.to_string()))
328            .map(|(name, _)| name.clone())
329            .collect();
330
331        sources_with_column.len() > 1
332    }
333
334    /// Get all sources that contain a given column
335    pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
336        let all_source_columns = self.get_all_source_columns();
337        all_source_columns
338            .iter()
339            .filter(|(_, columns)| columns.contains(&column_name.to_string()))
340            .map(|(name, _)| name.clone())
341            .collect()
342    }
343
344    /// Try to disambiguate a column based on join context
345    ///
346    /// In join conditions, a column can sometimes be disambiguated based on
347    /// which tables have been joined up to that point.
348    pub fn disambiguate_in_join_context(
349        &mut self,
350        column_name: &str,
351        available_sources: &[String],
352    ) -> Option<String> {
353        let mut matching_sources = Vec::new();
354
355        for source_name in available_sources {
356            if let Ok(columns) = self.get_source_columns(source_name) {
357                if columns.contains(&column_name.to_string()) {
358                    matching_sources.push(source_name.clone());
359                }
360            }
361        }
362
363        if matching_sources.len() == 1 {
364            Some(matching_sources.remove(0))
365        } else {
366            None
367        }
368    }
369}
370
371/// Resolve a column to its source table.
372///
373/// This is a convenience function that creates a Resolver and calls get_table.
374pub fn resolve_column(
375    scope: &Scope,
376    schema: &dyn Schema,
377    column_name: &str,
378    infer_schema: bool,
379) -> Option<String> {
380    let mut resolver = Resolver::new(scope, schema, infer_schema);
381    resolver.get_table(column_name)
382}
383
384/// Check if a column is ambiguous in the given scope.
385pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
386    let mut resolver = Resolver::new(scope, schema, true);
387    resolver.is_ambiguous(column_name)
388}
389
390#[cfg(test)]
391mod tests {
392    use super::*;
393    use crate::expressions::DataType;
394    use crate::parser::Parser;
395    use crate::schema::MappingSchema;
396    use crate::scope::build_scope;
397
398    fn create_test_schema() -> MappingSchema {
399        let mut schema = MappingSchema::new();
400        // Add tables with columns
401        schema
402            .add_table(
403                "users",
404                &[
405                    (
406                        "id".to_string(),
407                        DataType::Int {
408                            length: None,
409                            integer_spelling: false,
410                        },
411                    ),
412                    ("name".to_string(), DataType::Text),
413                    ("email".to_string(), DataType::Text),
414                ],
415                None,
416            )
417            .unwrap();
418        schema
419            .add_table(
420                "orders",
421                &[
422                    (
423                        "id".to_string(),
424                        DataType::Int {
425                            length: None,
426                            integer_spelling: false,
427                        },
428                    ),
429                    (
430                        "user_id".to_string(),
431                        DataType::Int {
432                            length: None,
433                            integer_spelling: false,
434                        },
435                    ),
436                    (
437                        "amount".to_string(),
438                        DataType::Double {
439                            precision: None,
440                            scale: None,
441                        },
442                    ),
443                ],
444                None,
445            )
446            .unwrap();
447        schema
448    }
449
450    #[test]
451    fn test_resolver_basic() {
452        let ast = Parser::parse_sql("SELECT id, name FROM users").expect("Failed to parse");
453        let scope = build_scope(&ast[0]);
454        let schema = create_test_schema();
455        let mut resolver = Resolver::new(&scope, &schema, true);
456
457        // 'name' should resolve to 'users' since it's the only source
458        let table = resolver.get_table("name");
459        assert_eq!(table, Some("users".to_string()));
460    }
461
462    #[test]
463    fn test_resolver_ambiguous_column() {
464        let ast =
465            Parser::parse_sql("SELECT id FROM users JOIN orders ON users.id = orders.user_id")
466                .expect("Failed to parse");
467        let scope = build_scope(&ast[0]);
468        let schema = create_test_schema();
469        let mut resolver = Resolver::new(&scope, &schema, true);
470
471        // 'id' appears in both tables, so it's ambiguous
472        assert!(resolver.is_ambiguous("id"));
473
474        // 'name' only appears in users
475        assert!(!resolver.is_ambiguous("name"));
476
477        // 'amount' only appears in orders
478        assert!(!resolver.is_ambiguous("amount"));
479    }
480
481    #[test]
482    fn test_resolver_unambiguous_column() {
483        let ast = Parser::parse_sql(
484            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
485        )
486        .expect("Failed to parse");
487        let scope = build_scope(&ast[0]);
488        let schema = create_test_schema();
489        let mut resolver = Resolver::new(&scope, &schema, true);
490
491        // 'name' should resolve to 'users'
492        let table = resolver.get_table("name");
493        assert_eq!(table, Some("users".to_string()));
494
495        // 'amount' should resolve to 'orders'
496        let table = resolver.get_table("amount");
497        assert_eq!(table, Some("orders".to_string()));
498    }
499
500    #[test]
501    fn test_resolver_with_alias() {
502        let ast = Parser::parse_sql("SELECT u.id FROM users AS u").expect("Failed to parse");
503        let scope = build_scope(&ast[0]);
504        let schema = create_test_schema();
505        let _resolver = Resolver::new(&scope, &schema, true);
506
507        // Source should be indexed by alias 'u'
508        assert!(scope.sources.contains_key("u"));
509    }
510
511    #[test]
512    fn test_sources_for_column() {
513        let ast = Parser::parse_sql("SELECT * FROM users JOIN orders ON users.id = orders.user_id")
514            .expect("Failed to parse");
515        let scope = build_scope(&ast[0]);
516        let schema = create_test_schema();
517        let mut resolver = Resolver::new(&scope, &schema, true);
518
519        // 'id' should be in both users and orders
520        let sources = resolver.sources_for_column("id");
521        assert!(sources.contains(&"users".to_string()));
522        assert!(sources.contains(&"orders".to_string()));
523
524        // 'email' should only be in users
525        let sources = resolver.sources_for_column("email");
526        assert_eq!(sources, vec!["users".to_string()]);
527    }
528
529    #[test]
530    fn test_all_columns() {
531        let ast = Parser::parse_sql("SELECT * FROM users").expect("Failed to parse");
532        let scope = build_scope(&ast[0]);
533        let schema = create_test_schema();
534        let mut resolver = Resolver::new(&scope, &schema, true);
535
536        let all = resolver.all_columns();
537        assert!(all.contains("id"));
538        assert!(all.contains("name"));
539        assert!(all.contains("email"));
540    }
541
542    #[test]
543    fn test_resolver_cte_projected_alias_column() {
544        let ast = Parser::parse_sql(
545            "WITH my_cte AS (SELECT id AS emp_id FROM users) SELECT emp_id FROM my_cte",
546        )
547        .expect("Failed to parse");
548        let scope = build_scope(&ast[0]);
549        let schema = create_test_schema();
550        let mut resolver = Resolver::new(&scope, &schema, true);
551
552        let table = resolver.get_table("emp_id");
553        assert_eq!(table, Some("my_cte".to_string()));
554    }
555
556    #[test]
557    fn test_resolve_column_helper() {
558        let ast = Parser::parse_sql("SELECT name FROM users").expect("Failed to parse");
559        let scope = build_scope(&ast[0]);
560        let schema = create_test_schema();
561
562        let table = resolve_column(&scope, &schema, "name", true);
563        assert_eq!(table, Some("users".to_string()));
564    }
565}