1use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier, TableRef};
14use crate::schema::{normalize_name, Schema};
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19#[derive(Debug, Error, Clone)]
21pub enum ResolverError {
22 #[error("Unknown table: {0}")]
23 UnknownTable(String),
24
25 #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
26 AmbiguousColumn { column: String, sources: String },
27
28 #[error("Column not found: {0}")]
29 ColumnNotFound(String),
30
31 #[error("Unknown set operation: {0}")]
32 UnknownSetOperation(String),
33}
34
35pub type ResolverResult<T> = Result<T, ResolverError>;
37
38pub struct Resolver<'a> {
43 pub scope: &'a Scope,
45 schema: &'a dyn Schema,
47 pub dialect: Option<DialectType>,
49 infer_schema: bool,
51 source_columns_cache: HashMap<String, Vec<String>>,
53 unambiguous_columns_cache: Option<HashMap<String, String>>,
55 all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60 pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62 Self {
63 scope,
64 schema,
65 dialect: schema.dialect(),
66 infer_schema,
67 source_columns_cache: HashMap::new(),
68 unambiguous_columns_cache: None,
69 all_columns_cache: None,
70 }
71 }
72
73 pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77 let table_name = self.get_table_name_from_sources(column_name, None);
79
80 if table_name.is_some() {
82 return table_name;
83 }
84
85 if self.infer_schema {
88 let sources_without_schema: Vec<_> = self
89 .get_all_source_columns()
90 .iter()
91 .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92 .map(|(name, _)| name.clone())
93 .collect();
94
95 if sources_without_schema.len() == 1 {
96 return Some(sources_without_schema[0].clone());
97 }
98 }
99
100 None
101 }
102
103 pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105 self.get_table(column_name).map(Identifier::new)
106 }
107
108 pub fn table_exists_in_schema(&self, table_name: &str) -> bool {
111 self.schema.column_names(table_name).is_ok()
112 }
113
114 pub fn all_columns(&mut self) -> &HashSet<String> {
116 if self.all_columns_cache.is_none() {
117 let mut all = HashSet::new();
118 for columns in self.get_all_source_columns().values() {
119 all.extend(columns.iter().cloned());
120 }
121 self.all_columns_cache = Some(all);
122 }
123 self.all_columns_cache
124 .as_ref()
125 .expect("cache populated above")
126 }
127
128 pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
132 if let Some(columns) = self.source_columns_cache.get(source_name) {
134 return Ok(columns.clone());
135 }
136
137 let source_info = self
139 .scope
140 .sources
141 .get(source_name)
142 .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
143
144 let columns = self.extract_columns_from_source(source_info)?;
145
146 self.source_columns_cache
148 .insert(source_name.to_string(), columns.clone());
149
150 Ok(columns)
151 }
152
153 fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
155 let columns = match &source_info.expression {
156 Expression::Table(table) => {
157 let table_name = qualified_table_name(table);
161 match self.schema.column_names(&table_name) {
162 Ok(cols) => cols,
163 Err(_) => Vec::new(), }
165 }
166 Expression::Subquery(subquery) => {
167 self.get_named_selects(&subquery.this)
169 }
170 Expression::Select(select) => {
171 self.get_select_column_names(select)
173 }
174 Expression::Union(union) => {
175 self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
177 }
178 Expression::Intersect(intersect) => {
179 self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
180 }
181 Expression::Except(except) => {
182 self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
183 }
184 Expression::Cte(cte) => {
185 if !cte.columns.is_empty() {
186 cte.columns.iter().map(|c| c.name.clone()).collect()
187 } else {
188 self.get_named_selects(&cte.this)
189 }
190 }
191 _ => Vec::new(),
192 };
193
194 Ok(columns)
195 }
196
197 fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
199 match expr {
200 Expression::Select(select) => self.get_select_column_names(select),
201 Expression::Union(union) => {
202 self.get_named_selects(&union.left)
204 }
205 Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
206 Expression::Except(except) => self.get_named_selects(&except.left),
207 Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
208 _ => Vec::new(),
209 }
210 }
211
212 fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
214 select
215 .expressions
216 .iter()
217 .filter_map(|expr| self.get_expression_alias(expr))
218 .collect()
219 }
220
221 fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
223 match expr {
224 Expression::Alias(alias) => Some(alias.alias.name.clone()),
225 Expression::Column(col) => Some(col.name.name.clone()),
226 Expression::Star(_) => Some("*".to_string()),
227 Expression::Identifier(id) => Some(id.name.clone()),
228 _ => None,
229 }
230 }
231
232 pub fn get_source_columns_from_set_op(
234 &self,
235 expression: &Expression,
236 ) -> ResolverResult<Vec<String>> {
237 match expression {
238 Expression::Select(select) => Ok(self.get_select_column_names(select)),
239 Expression::Subquery(subquery) => {
240 if matches!(
241 &subquery.this,
242 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
243 ) {
244 self.get_source_columns_from_set_op(&subquery.this)
245 } else {
246 Ok(self.get_named_selects(&subquery.this))
247 }
248 }
249 Expression::Union(union) => {
250 self.get_source_columns_from_set_op(&union.left)
252 }
253 Expression::Intersect(intersect) => {
254 self.get_source_columns_from_set_op(&intersect.left)
255 }
256 Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
257 _ => Err(ResolverError::UnknownSetOperation(format!(
258 "{:?}",
259 expression
260 ))),
261 }
262 }
263
264 fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
266 let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
267
268 let mut result = HashMap::new();
269 for source_name in source_names {
270 if let Ok(columns) = self.get_source_columns(&source_name) {
271 result.insert(source_name, columns);
272 }
273 }
274 result
275 }
276
277 fn get_table_name_from_sources(
279 &mut self,
280 column_name: &str,
281 source_columns: Option<&HashMap<String, Vec<String>>>,
282 ) -> Option<String> {
283 let normalized_column_name = normalize_column_name(column_name, self.dialect);
284 let unambiguous = match source_columns {
285 Some(cols) => self.compute_unambiguous_columns(cols),
286 None => {
287 if self.unambiguous_columns_cache.is_none() {
288 let all_source_columns = self.get_all_source_columns();
289 self.unambiguous_columns_cache =
290 Some(self.compute_unambiguous_columns(&all_source_columns));
291 }
292 self.unambiguous_columns_cache
293 .clone()
294 .expect("cache populated above")
295 }
296 };
297
298 unambiguous.get(&normalized_column_name).cloned()
299 }
300
301 fn compute_unambiguous_columns(
305 &self,
306 source_columns: &HashMap<String, Vec<String>>,
307 ) -> HashMap<String, String> {
308 if source_columns.is_empty() {
309 return HashMap::new();
310 }
311
312 let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
313
314 for (source_name, columns) in source_columns {
315 for column in columns {
316 column_to_sources
317 .entry(normalize_column_name(column, self.dialect))
318 .or_default()
319 .push(source_name.clone());
320 }
321 }
322
323 column_to_sources
325 .into_iter()
326 .filter(|(_, sources)| sources.len() == 1)
327 .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
328 .collect()
329 }
330
331 pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
333 let normalized_column_name = normalize_column_name(column_name, self.dialect);
334 let all_source_columns = self.get_all_source_columns();
335 let sources_with_column: Vec<_> = all_source_columns
336 .iter()
337 .filter(|(_, columns)| {
338 columns.iter().any(|column| {
339 normalize_column_name(column, self.dialect) == normalized_column_name
340 })
341 })
342 .map(|(name, _)| name.clone())
343 .collect();
344
345 sources_with_column.len() > 1
346 }
347
348 pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
350 let normalized_column_name = normalize_column_name(column_name, self.dialect);
351 let all_source_columns = self.get_all_source_columns();
352 all_source_columns
353 .iter()
354 .filter(|(_, columns)| {
355 columns.iter().any(|column| {
356 normalize_column_name(column, self.dialect) == normalized_column_name
357 })
358 })
359 .map(|(name, _)| name.clone())
360 .collect()
361 }
362
363 pub fn disambiguate_in_join_context(
368 &mut self,
369 column_name: &str,
370 available_sources: &[String],
371 ) -> Option<String> {
372 let normalized_column_name = normalize_column_name(column_name, self.dialect);
373 let mut matching_sources = Vec::new();
374
375 for source_name in available_sources {
376 if let Ok(columns) = self.get_source_columns(source_name) {
377 if columns.iter().any(|column| {
378 normalize_column_name(column, self.dialect) == normalized_column_name
379 }) {
380 matching_sources.push(source_name.clone());
381 }
382 }
383 }
384
385 if matching_sources.len() == 1 {
386 Some(matching_sources.remove(0))
387 } else {
388 None
389 }
390 }
391}
392
393fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
394 normalize_name(name, dialect, false, true)
395}
396
397pub fn resolve_column(
401 scope: &Scope,
402 schema: &dyn Schema,
403 column_name: &str,
404 infer_schema: bool,
405) -> Option<String> {
406 let mut resolver = Resolver::new(scope, schema, infer_schema);
407 resolver.get_table(column_name)
408}
409
410pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
412 let mut resolver = Resolver::new(scope, schema, true);
413 resolver.is_ambiguous(column_name)
414}
415
416fn qualified_table_name(table: &TableRef) -> String {
418 let mut parts = Vec::new();
419 if let Some(catalog) = &table.catalog {
420 parts.push(catalog.name.clone());
421 }
422 if let Some(schema) = &table.schema {
423 parts.push(schema.name.clone());
424 }
425 parts.push(table.name.name.clone());
426 parts.join(".")
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// Plain `INT` column type shared by the fixture tables.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Builds the fixture schema: `users(id, name, email)` and `orders(id, user_id, amount)`.
    fn create_test_schema() -> MappingSchema {
        let users = vec![
            ("id".to_string(), int_type()),
            ("name".to_string(), DataType::Text),
            ("email".to_string(), DataType::Text),
        ];
        let orders = vec![
            ("id".to_string(), int_type()),
            ("user_id".to_string(), int_type()),
            (
                "amount".to_string(),
                DataType::Double {
                    precision: None,
                    scale: None,
                },
            ),
        ];

        let mut schema = MappingSchema::new();
        for (table, columns) in vec![("users", users), ("orders", orders)] {
            schema.add_table(table, &columns, None).unwrap();
        }
        schema
    }

    #[test]
    fn test_resolver_basic() {
        let statements = Parser::parse_sql("SELECT id, name FROM users").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("name").as_deref(), Some("users"));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id";
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `id` exists in both tables; `name` and `amount` only in one each.
        assert!(resolver.is_ambiguous("id"));
        assert!(!resolver.is_ambiguous("name"));
        assert!(!resolver.is_ambiguous("amount"));
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let sql = "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id";
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        for (column, expected) in [("name", "users"), ("amount", "orders")].iter() {
            assert_eq!(resolver.get_table(column).as_deref(), Some(*expected));
        }
    }

    #[test]
    fn test_resolver_with_alias() {
        let statements = Parser::parse_sql("SELECT u.id FROM users AS u").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // The source is registered under its alias, not the table name.
        assert!(scope.sources.contains_key("u"));
    }

    #[test]
    fn test_sources_for_column() {
        let sql = "SELECT * FROM users JOIN orders ON users.id = orders.user_id";
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let id_sources = resolver.sources_for_column("id");
        for expected in ["users", "orders"].iter() {
            assert!(id_sources.iter().any(|source| source == expected));
        }

        assert_eq!(resolver.sources_for_column("email"), vec!["users"]);
    }

    #[test]
    fn test_all_columns() {
        let statements = Parser::parse_sql("SELECT * FROM users").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        for column in ["id", "name", "email"].iter() {
            assert!(all.contains(*column));
        }
    }

    #[test]
    fn test_resolver_cte_projected_alias_column() {
        let sql = "WITH my_cte AS (SELECT id AS emp_id FROM users) SELECT emp_id FROM my_cte";
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("emp_id").as_deref(), Some("my_cte"));
    }

    #[test]
    fn test_resolve_column_helper() {
        let statements = Parser::parse_sql("SELECT name FROM users").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();

        assert_eq!(
            resolve_column(&scope, &schema, "name", true).as_deref(),
            Some("users")
        );
    }

    #[test]
    fn test_resolver_bigquery_mixed_case_column_names() {
        let dialect = Dialect::get(DialectType::BigQuery);
        let expr = dialect
            .parse("SELECT Name AS name FROM teams")
            .unwrap()
            .into_iter()
            .next()
            .expect("expected one expression");
        let scope = build_scope(&expr);

        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
        schema
            .add_table(
                "teams",
                &[("Name".into(), DataType::String { length: None })],
                None,
            )
            .expect("schema setup");

        // Both spellings must resolve to the same source under BigQuery normalization.
        let mut resolver = Resolver::new(&scope, &schema, true);
        for column in ["Name", "name"].iter() {
            assert_eq!(resolver.get_table(column).as_deref(), Some("teams"));
        }
    }
}