1use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier, TableRef};
14use crate::schema::{normalize_name, Schema};
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
/// Errors produced while resolving columns against the sources of a scope.
#[derive(Debug, Error, Clone)]
pub enum ResolverError {
    /// A source name was looked up that is not present in the scope's `sources`.
    #[error("Unknown table: {0}")]
    UnknownTable(String),

    /// A column matched more than one source; `sources` is rendered verbatim in the message.
    #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
    AmbiguousColumn { column: String, sources: String },

    /// A column could not be attributed to any source.
    #[error("Column not found: {0}")]
    ColumnNotFound(String),

    /// Set-operation column extraction was handed an expression it does not understand;
    /// the payload is the `Debug` rendering of that expression.
    #[error("Unknown set operation: {0}")]
    UnknownSetOperation(String),
}

/// Convenience alias for results whose error type is [`ResolverError`].
pub type ResolverResult<T> = Result<T, ResolverError>;
37
38pub struct Resolver<'a> {
43 pub scope: &'a Scope,
45 schema: &'a dyn Schema,
47 pub dialect: Option<DialectType>,
49 infer_schema: bool,
51 source_columns_cache: HashMap<String, Vec<String>>,
53 unambiguous_columns_cache: Option<HashMap<String, String>>,
55 all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60 pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62 Self {
63 scope,
64 schema,
65 dialect: schema.dialect(),
66 infer_schema,
67 source_columns_cache: HashMap::new(),
68 unambiguous_columns_cache: None,
69 all_columns_cache: None,
70 }
71 }
72
73 pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77 let table_name = self.get_table_name_from_sources(column_name, None);
79
80 if table_name.is_some() {
82 return table_name;
83 }
84
85 if self.infer_schema {
88 let sources_without_schema: Vec<_> = self
89 .get_all_source_columns()
90 .iter()
91 .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92 .map(|(name, _)| name.clone())
93 .collect();
94
95 if sources_without_schema.len() == 1 {
96 return Some(sources_without_schema[0].clone());
97 }
98 }
99
100 None
101 }
102
103 pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105 self.get_table(column_name).map(Identifier::new)
106 }
107
108 pub fn table_exists_in_schema(&self, table_name: &str) -> bool {
111 self.schema.column_names(table_name).is_ok()
112 }
113
114 pub fn find_column_in_outer_schema_tables(&self, column_name: &str) -> Option<String> {
119 let tables = self.schema.find_tables_for_column(column_name);
120 let outer_tables: Vec<String> = tables
122 .into_iter()
123 .filter(|t| !self.scope.sources.contains_key(t))
124 .collect();
125 if outer_tables.len() == 1 {
127 Some(outer_tables.into_iter().next().unwrap())
128 } else {
129 None
130 }
131 }
132
133 pub fn all_columns(&mut self) -> &HashSet<String> {
135 if self.all_columns_cache.is_none() {
136 let mut all = HashSet::new();
137 for columns in self.get_all_source_columns().values() {
138 all.extend(columns.iter().cloned());
139 }
140 self.all_columns_cache = Some(all);
141 }
142 self.all_columns_cache
143 .as_ref()
144 .expect("cache populated above")
145 }
146
147 pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
151 if let Some(columns) = self.source_columns_cache.get(source_name) {
153 return Ok(columns.clone());
154 }
155
156 let source_info = self
158 .scope
159 .sources
160 .get(source_name)
161 .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
162
163 let columns = self.extract_columns_from_source(source_info)?;
164
165 self.source_columns_cache
167 .insert(source_name.to_string(), columns.clone());
168
169 Ok(columns)
170 }
171
172 fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
174 let columns = match &source_info.expression {
175 Expression::Table(table) => {
176 let table_name = qualified_table_name(table);
180 match self.schema.column_names(&table_name) {
181 Ok(cols) => cols,
182 Err(_) => Vec::new(), }
184 }
185 Expression::Subquery(subquery) => {
186 self.get_named_selects(&subquery.this)
188 }
189 Expression::Select(select) => {
190 self.get_select_column_names(select)
192 }
193 Expression::Union(union) => {
194 self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
196 }
197 Expression::Intersect(intersect) => {
198 self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
199 }
200 Expression::Except(except) => {
201 self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
202 }
203 Expression::Cte(cte) => {
204 if !cte.columns.is_empty() {
205 cte.columns.iter().map(|c| c.name.clone()).collect()
206 } else {
207 self.get_named_selects(&cte.this)
208 }
209 }
210 _ => Vec::new(),
211 };
212
213 Ok(columns)
214 }
215
216 fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
218 match expr {
219 Expression::Select(select) => self.get_select_column_names(select),
220 Expression::Union(union) => {
221 self.get_named_selects(&union.left)
223 }
224 Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
225 Expression::Except(except) => self.get_named_selects(&except.left),
226 Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
227 _ => Vec::new(),
228 }
229 }
230
231 fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
233 select
234 .expressions
235 .iter()
236 .filter_map(|expr| self.get_expression_alias(expr))
237 .collect()
238 }
239
240 fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
242 match expr {
243 Expression::Alias(alias) => Some(alias.alias.name.clone()),
244 Expression::Column(col) => Some(col.name.name.clone()),
245 Expression::Star(_) => Some("*".to_string()),
246 Expression::Identifier(id) => Some(id.name.clone()),
247 _ => None,
248 }
249 }
250
251 pub fn get_source_columns_from_set_op(
253 &self,
254 expression: &Expression,
255 ) -> ResolverResult<Vec<String>> {
256 match expression {
257 Expression::Select(select) => Ok(self.get_select_column_names(select)),
258 Expression::Subquery(subquery) => {
259 if matches!(
260 &subquery.this,
261 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
262 ) {
263 self.get_source_columns_from_set_op(&subquery.this)
264 } else {
265 Ok(self.get_named_selects(&subquery.this))
266 }
267 }
268 Expression::Union(union) => {
269 self.get_source_columns_from_set_op(&union.left)
271 }
272 Expression::Intersect(intersect) => {
273 self.get_source_columns_from_set_op(&intersect.left)
274 }
275 Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
276 _ => Err(ResolverError::UnknownSetOperation(format!(
277 "{:?}",
278 expression
279 ))),
280 }
281 }
282
283 fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
285 let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
286
287 let mut result = HashMap::new();
288 for source_name in source_names {
289 if let Ok(columns) = self.get_source_columns(&source_name) {
290 result.insert(source_name, columns);
291 }
292 }
293 result
294 }
295
296 fn get_table_name_from_sources(
298 &mut self,
299 column_name: &str,
300 source_columns: Option<&HashMap<String, Vec<String>>>,
301 ) -> Option<String> {
302 let normalized_column_name = normalize_column_name(column_name, self.dialect);
303 let unambiguous = match source_columns {
304 Some(cols) => self.compute_unambiguous_columns(cols),
305 None => {
306 if self.unambiguous_columns_cache.is_none() {
307 let all_source_columns = self.get_all_source_columns();
308 self.unambiguous_columns_cache =
309 Some(self.compute_unambiguous_columns(&all_source_columns));
310 }
311 self.unambiguous_columns_cache
312 .clone()
313 .expect("cache populated above")
314 }
315 };
316
317 unambiguous.get(&normalized_column_name).cloned()
318 }
319
320 fn compute_unambiguous_columns(
324 &self,
325 source_columns: &HashMap<String, Vec<String>>,
326 ) -> HashMap<String, String> {
327 if source_columns.is_empty() {
328 return HashMap::new();
329 }
330
331 let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
332
333 for (source_name, columns) in source_columns {
334 for column in columns {
335 column_to_sources
336 .entry(normalize_column_name(column, self.dialect))
337 .or_default()
338 .push(source_name.clone());
339 }
340 }
341
342 column_to_sources
344 .into_iter()
345 .filter(|(_, sources)| sources.len() == 1)
346 .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
347 .collect()
348 }
349
350 pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
352 let normalized_column_name = normalize_column_name(column_name, self.dialect);
353 let all_source_columns = self.get_all_source_columns();
354 let sources_with_column: Vec<_> = all_source_columns
355 .iter()
356 .filter(|(_, columns)| {
357 columns.iter().any(|column| {
358 normalize_column_name(column, self.dialect) == normalized_column_name
359 })
360 })
361 .map(|(name, _)| name.clone())
362 .collect();
363
364 sources_with_column.len() > 1
365 }
366
367 pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
369 let normalized_column_name = normalize_column_name(column_name, self.dialect);
370 let all_source_columns = self.get_all_source_columns();
371 all_source_columns
372 .iter()
373 .filter(|(_, columns)| {
374 columns.iter().any(|column| {
375 normalize_column_name(column, self.dialect) == normalized_column_name
376 })
377 })
378 .map(|(name, _)| name.clone())
379 .collect()
380 }
381
382 pub fn disambiguate_in_join_context(
387 &mut self,
388 column_name: &str,
389 available_sources: &[String],
390 ) -> Option<String> {
391 let normalized_column_name = normalize_column_name(column_name, self.dialect);
392 let mut matching_sources = Vec::new();
393
394 for source_name in available_sources {
395 if let Ok(columns) = self.get_source_columns(source_name) {
396 if columns.iter().any(|column| {
397 normalize_column_name(column, self.dialect) == normalized_column_name
398 }) {
399 matching_sources.push(source_name.clone());
400 }
401 }
402 }
403
404 if matching_sources.len() == 1 {
405 Some(matching_sources.remove(0))
406 } else {
407 None
408 }
409 }
410}
411
/// Normalizes a column identifier so that names can be compared under `dialect`.
///
/// Delegates to `normalize_name` with the fixed flag pair `(false, true)`; the meaning of
/// those flags is defined by `crate::schema::normalize_name` (NOTE(review): presumably
/// "not a table name" / "apply normalization" — confirm there).
fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
    normalize_name(name, dialect, false, true)
}
415
416pub fn resolve_column(
420 scope: &Scope,
421 schema: &dyn Schema,
422 column_name: &str,
423 infer_schema: bool,
424) -> Option<String> {
425 let mut resolver = Resolver::new(scope, schema, infer_schema);
426 resolver.get_table(column_name)
427}
428
429pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
431 let mut resolver = Resolver::new(scope, schema, true);
432 resolver.is_ambiguous(column_name)
433}
434
435fn qualified_table_name(table: &TableRef) -> String {
437 let mut parts = Vec::new();
438 if let Some(catalog) = &table.catalog {
439 parts.push(catalog.name.clone());
440 }
441 if let Some(schema) = &table.schema {
442 parts.push(schema.name.clone());
443 }
444 parts.push(table.name.name.clone());
445 parts.join(".")
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// Integer column type shared by the fixture tables.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Schema fixture: `users(id, name, email)` and `orders(id, user_id, amount)`.
    fn create_test_schema() -> MappingSchema {
        let users_columns = [
            ("id".to_string(), int_type()),
            ("name".to_string(), DataType::Text),
            ("email".to_string(), DataType::Text),
        ];
        let orders_columns = [
            ("id".to_string(), int_type()),
            ("user_id".to_string(), int_type()),
            (
                "amount".to_string(),
                DataType::Double {
                    precision: None,
                    scale: None,
                },
            ),
        ];

        let mut schema = MappingSchema::new();
        schema.add_table("users", &users_columns, None).unwrap();
        schema.add_table("orders", &orders_columns, None).unwrap();
        schema
    }

    /// Parses `sql` with the default parser and builds the scope of its first statement.
    fn scope_for(sql: &str) -> Scope {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        build_scope(&statements[0])
    }

    #[test]
    fn test_resolver_basic() {
        let scope = scope_for("SELECT id, name FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("name"), Some("users".to_string()));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let scope = scope_for("SELECT id FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `id` exists in both joined tables; `name` and `amount` in only one each.
        assert!(resolver.is_ambiguous("id"));
        assert!(!resolver.is_ambiguous("name"));
        assert!(!resolver.is_ambiguous("amount"));
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let scope = scope_for(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
        );
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("name"), Some("users".to_string()));
        assert_eq!(resolver.get_table("amount"), Some("orders".to_string()));
    }

    #[test]
    fn test_resolver_with_alias() {
        let scope = scope_for("SELECT u.id FROM users AS u");
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // Aliased tables are registered under their alias.
        assert!(scope.sources.contains_key("u"));
    }

    #[test]
    fn test_sources_for_column() {
        let scope = scope_for("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let id_sources = resolver.sources_for_column("id");
        assert!(id_sources.contains(&"users".to_string()));
        assert!(id_sources.contains(&"orders".to_string()));

        assert_eq!(
            resolver.sources_for_column("email"),
            vec!["users".to_string()]
        );
    }

    #[test]
    fn test_all_columns() {
        let scope = scope_for("SELECT * FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        for expected in ["id", "name", "email"] {
            assert!(all.contains(expected));
        }
    }

    #[test]
    fn test_resolver_cte_projected_alias_column() {
        let scope = scope_for(
            "WITH my_cte AS (SELECT id AS emp_id FROM users) SELECT emp_id FROM my_cte",
        );
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("emp_id"), Some("my_cte".to_string()));
    }

    #[test]
    fn test_resolve_column_helper() {
        let scope = scope_for("SELECT name FROM users");
        let schema = create_test_schema();

        assert_eq!(
            resolve_column(&scope, &schema, "name", true),
            Some("users".to_string())
        );
    }

    #[test]
    fn test_resolver_bigquery_mixed_case_column_names() {
        let dialect = Dialect::get(DialectType::BigQuery);
        let expr = dialect
            .parse("SELECT Name AS name FROM teams")
            .unwrap()
            .into_iter()
            .next()
            .expect("expected one expression");
        let scope = build_scope(&expr);

        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
        schema
            .add_table(
                "teams",
                &[("Name".into(), DataType::String { length: None })],
                None,
            )
            .expect("schema setup");

        let mut resolver = Resolver::new(&scope, &schema, true);
        // BigQuery column names resolve regardless of the spelling used in the query.
        for spelling in ["Name", "name"] {
            assert_eq!(resolver.get_table(spelling), Some("teams".to_string()));
        }
    }
}