1use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier};
14use crate::schema::Schema;
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19#[derive(Debug, Error, Clone)]
21pub enum ResolverError {
22 #[error("Unknown table: {0}")]
23 UnknownTable(String),
24
25 #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
26 AmbiguousColumn { column: String, sources: String },
27
28 #[error("Column not found: {0}")]
29 ColumnNotFound(String),
30
31 #[error("Unknown set operation: {0}")]
32 UnknownSetOperation(String),
33}
34
35pub type ResolverResult<T> = Result<T, ResolverError>;
37
38pub struct Resolver<'a> {
43 pub scope: &'a Scope,
45 schema: &'a dyn Schema,
47 pub dialect: Option<DialectType>,
49 infer_schema: bool,
51 source_columns_cache: HashMap<String, Vec<String>>,
53 unambiguous_columns_cache: Option<HashMap<String, String>>,
55 all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60 pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62 Self {
63 scope,
64 schema,
65 dialect: schema.dialect(),
66 infer_schema,
67 source_columns_cache: HashMap::new(),
68 unambiguous_columns_cache: None,
69 all_columns_cache: None,
70 }
71 }
72
73 pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77 let table_name = self.get_table_name_from_sources(column_name, None);
79
80 if table_name.is_some() {
82 return table_name;
83 }
84
85 if self.infer_schema {
88 let sources_without_schema: Vec<_> = self
89 .get_all_source_columns()
90 .iter()
91 .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92 .map(|(name, _)| name.clone())
93 .collect();
94
95 if sources_without_schema.len() == 1 {
96 return Some(sources_without_schema[0].clone());
97 }
98 }
99
100 None
101 }
102
103 pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105 self.get_table(column_name).map(Identifier::new)
106 }
107
108 pub fn all_columns(&mut self) -> &HashSet<String> {
110 if self.all_columns_cache.is_none() {
111 let mut all = HashSet::new();
112 for columns in self.get_all_source_columns().values() {
113 all.extend(columns.iter().cloned());
114 }
115 self.all_columns_cache = Some(all);
116 }
117 self.all_columns_cache
118 .as_ref()
119 .expect("cache populated above")
120 }
121
122 pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
126 if let Some(columns) = self.source_columns_cache.get(source_name) {
128 return Ok(columns.clone());
129 }
130
131 let source_info = self
133 .scope
134 .sources
135 .get(source_name)
136 .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
137
138 let columns = self.extract_columns_from_source(source_info)?;
139
140 self.source_columns_cache
142 .insert(source_name.to_string(), columns.clone());
143
144 Ok(columns)
145 }
146
147 fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
149 let columns = match &source_info.expression {
150 Expression::Table(table) => {
151 let table_name = table.name.name.clone();
153 match self.schema.column_names(&table_name) {
154 Ok(cols) => cols,
155 Err(_) => Vec::new(), }
157 }
158 Expression::Subquery(subquery) => {
159 self.get_named_selects(&subquery.this)
161 }
162 Expression::Select(select) => {
163 self.get_select_column_names(select)
165 }
166 Expression::Union(union) => {
167 self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
169 }
170 Expression::Intersect(intersect) => {
171 self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
172 }
173 Expression::Except(except) => {
174 self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
175 }
176 _ => Vec::new(),
177 };
178
179 Ok(columns)
180 }
181
182 fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
184 match expr {
185 Expression::Select(select) => self.get_select_column_names(select),
186 Expression::Union(union) => {
187 self.get_named_selects(&union.left)
189 }
190 Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
191 Expression::Except(except) => self.get_named_selects(&except.left),
192 Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
193 _ => Vec::new(),
194 }
195 }
196
197 fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
199 select
200 .expressions
201 .iter()
202 .filter_map(|expr| self.get_expression_alias(expr))
203 .collect()
204 }
205
206 fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
208 match expr {
209 Expression::Alias(alias) => Some(alias.alias.name.clone()),
210 Expression::Column(col) => Some(col.name.name.clone()),
211 Expression::Star(_) => Some("*".to_string()),
212 Expression::Identifier(id) => Some(id.name.clone()),
213 _ => None,
214 }
215 }
216
217 pub fn get_source_columns_from_set_op(
219 &self,
220 expression: &Expression,
221 ) -> ResolverResult<Vec<String>> {
222 match expression {
223 Expression::Select(select) => Ok(self.get_select_column_names(select)),
224 Expression::Subquery(subquery) => {
225 if matches!(
226 &subquery.this,
227 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
228 ) {
229 self.get_source_columns_from_set_op(&subquery.this)
230 } else {
231 Ok(self.get_named_selects(&subquery.this))
232 }
233 }
234 Expression::Union(union) => {
235 self.get_source_columns_from_set_op(&union.left)
237 }
238 Expression::Intersect(intersect) => {
239 self.get_source_columns_from_set_op(&intersect.left)
240 }
241 Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
242 _ => Err(ResolverError::UnknownSetOperation(format!(
243 "{:?}",
244 expression
245 ))),
246 }
247 }
248
249 fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
251 let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
252
253 let mut result = HashMap::new();
254 for source_name in source_names {
255 if let Ok(columns) = self.get_source_columns(&source_name) {
256 result.insert(source_name, columns);
257 }
258 }
259 result
260 }
261
262 fn get_table_name_from_sources(
264 &mut self,
265 column_name: &str,
266 source_columns: Option<&HashMap<String, Vec<String>>>,
267 ) -> Option<String> {
268 let unambiguous = match source_columns {
269 Some(cols) => self.compute_unambiguous_columns(cols),
270 None => {
271 if self.unambiguous_columns_cache.is_none() {
272 let all_source_columns = self.get_all_source_columns();
273 self.unambiguous_columns_cache =
274 Some(self.compute_unambiguous_columns(&all_source_columns));
275 }
276 self.unambiguous_columns_cache
277 .clone()
278 .expect("cache populated above")
279 }
280 };
281
282 unambiguous.get(column_name).cloned()
283 }
284
285 fn compute_unambiguous_columns(
289 &self,
290 source_columns: &HashMap<String, Vec<String>>,
291 ) -> HashMap<String, String> {
292 if source_columns.is_empty() {
293 return HashMap::new();
294 }
295
296 let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
297
298 for (source_name, columns) in source_columns {
299 for column in columns {
300 column_to_sources
301 .entry(column.clone())
302 .or_default()
303 .push(source_name.clone());
304 }
305 }
306
307 column_to_sources
309 .into_iter()
310 .filter(|(_, sources)| sources.len() == 1)
311 .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
312 .collect()
313 }
314
315 pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
317 let all_source_columns = self.get_all_source_columns();
318 let sources_with_column: Vec<_> = all_source_columns
319 .iter()
320 .filter(|(_, columns)| columns.contains(&column_name.to_string()))
321 .map(|(name, _)| name.clone())
322 .collect();
323
324 sources_with_column.len() > 1
325 }
326
327 pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
329 let all_source_columns = self.get_all_source_columns();
330 all_source_columns
331 .iter()
332 .filter(|(_, columns)| columns.contains(&column_name.to_string()))
333 .map(|(name, _)| name.clone())
334 .collect()
335 }
336
337 pub fn disambiguate_in_join_context(
342 &mut self,
343 column_name: &str,
344 available_sources: &[String],
345 ) -> Option<String> {
346 let mut matching_sources = Vec::new();
347
348 for source_name in available_sources {
349 if let Ok(columns) = self.get_source_columns(source_name) {
350 if columns.contains(&column_name.to_string()) {
351 matching_sources.push(source_name.clone());
352 }
353 }
354 }
355
356 if matching_sources.len() == 1 {
357 Some(matching_sources.remove(0))
358 } else {
359 None
360 }
361 }
362}
363
364pub fn resolve_column(
368 scope: &Scope,
369 schema: &dyn Schema,
370 column_name: &str,
371 infer_schema: bool,
372) -> Option<String> {
373 let mut resolver = Resolver::new(scope, schema, infer_schema);
374 resolver.get_table(column_name)
375}
376
377pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
379 let mut resolver = Resolver::new(scope, schema, true);
380 resolver.is_ambiguous(column_name)
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// Integer column type shared by the fixture tables.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Schema with two tables: `users(id, name, email)` and
    /// `orders(id, user_id, amount)`. Only `id` is shared between them.
    fn create_test_schema() -> MappingSchema {
        let users = [
            ("id".to_string(), int_type()),
            ("name".to_string(), DataType::Text),
            ("email".to_string(), DataType::Text),
        ];
        let orders = [
            ("id".to_string(), int_type()),
            ("user_id".to_string(), int_type()),
            (
                "amount".to_string(),
                DataType::Double {
                    precision: None,
                    scale: None,
                },
            ),
        ];

        let mut schema = MappingSchema::new();
        schema.add_table("users", &users, None).unwrap();
        schema.add_table("orders", &orders, None).unwrap();
        schema
    }

    #[test]
    fn test_resolver_basic() {
        let parsed = Parser::parse_sql("SELECT id, name FROM users").expect("Failed to parse");
        let scope = build_scope(&parsed[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("name").as_deref(), Some("users"));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let parsed =
            Parser::parse_sql("SELECT id FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&parsed[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `id` lives in both tables; `name` and `amount` each live in one.
        let expectations = [("id", true), ("name", false), ("amount", false)];
        for (column, ambiguous) in expectations.iter() {
            assert_eq!(resolver.is_ambiguous(column), *ambiguous);
        }
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let parsed = Parser::parse_sql(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
        )
        .expect("Failed to parse");
        let scope = build_scope(&parsed[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("name").as_deref(), Some("users"));
        assert_eq!(resolver.get_table("amount").as_deref(), Some("orders"));
    }

    #[test]
    fn test_resolver_with_alias() {
        let parsed = Parser::parse_sql("SELECT u.id FROM users AS u").expect("Failed to parse");
        let scope = build_scope(&parsed[0]);
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // The aliased table is registered in the scope under its alias.
        assert!(scope.sources.contains_key("u"));
    }

    #[test]
    fn test_sources_for_column() {
        let parsed =
            Parser::parse_sql("SELECT * FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&parsed[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let id_sources = resolver.sources_for_column("id");
        for expected in ["users", "orders"].iter() {
            assert!(id_sources.contains(&expected.to_string()));
        }

        let email_sources = resolver.sources_for_column("email");
        assert_eq!(email_sources, ["users"]);
    }

    #[test]
    fn test_all_columns() {
        let parsed = Parser::parse_sql("SELECT * FROM users").expect("Failed to parse");
        let scope = build_scope(&parsed[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        for expected in ["id", "name", "email"].iter() {
            assert!(all.contains(*expected));
        }
    }

    #[test]
    fn test_resolve_column_helper() {
        let parsed = Parser::parse_sql("SELECT name FROM users").expect("Failed to parse");
        let scope = build_scope(&parsed[0]);
        let schema = create_test_schema();

        let owner = resolve_column(&scope, &schema, "name", true);
        assert_eq!(owner.as_deref(), Some("users"));
    }
}