1use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier};
14use crate::schema::Schema;
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19#[derive(Debug, Error, Clone)]
21pub enum ResolverError {
22 #[error("Unknown table: {0}")]
23 UnknownTable(String),
24
25 #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
26 AmbiguousColumn { column: String, sources: String },
27
28 #[error("Column not found: {0}")]
29 ColumnNotFound(String),
30
31 #[error("Unknown set operation: {0}")]
32 UnknownSetOperation(String),
33}
34
35pub type ResolverResult<T> = Result<T, ResolverError>;
37
38pub struct Resolver<'a> {
43 pub scope: &'a Scope,
45 schema: &'a dyn Schema,
47 pub dialect: Option<DialectType>,
49 infer_schema: bool,
51 source_columns_cache: HashMap<String, Vec<String>>,
53 unambiguous_columns_cache: Option<HashMap<String, String>>,
55 all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60 pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62 Self {
63 scope,
64 schema,
65 dialect: schema.dialect(),
66 infer_schema,
67 source_columns_cache: HashMap::new(),
68 unambiguous_columns_cache: None,
69 all_columns_cache: None,
70 }
71 }
72
73 pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77 let table_name = self.get_table_name_from_sources(column_name, None);
79
80 if table_name.is_some() {
82 return table_name;
83 }
84
85 if self.infer_schema {
88 let sources_without_schema: Vec<_> = self
89 .get_all_source_columns()
90 .iter()
91 .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92 .map(|(name, _)| name.clone())
93 .collect();
94
95 if sources_without_schema.len() == 1 {
96 return Some(sources_without_schema[0].clone());
97 }
98 }
99
100 None
101 }
102
103 pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105 self.get_table(column_name).map(Identifier::new)
106 }
107
108 pub fn all_columns(&mut self) -> &HashSet<String> {
110 if self.all_columns_cache.is_none() {
111 let mut all = HashSet::new();
112 for columns in self.get_all_source_columns().values() {
113 all.extend(columns.iter().cloned());
114 }
115 self.all_columns_cache = Some(all);
116 }
117 self.all_columns_cache
118 .as_ref()
119 .expect("cache populated above")
120 }
121
122 pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
126 if let Some(columns) = self.source_columns_cache.get(source_name) {
128 return Ok(columns.clone());
129 }
130
131 let source_info = self
133 .scope
134 .sources
135 .get(source_name)
136 .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
137
138 let columns = self.extract_columns_from_source(source_info)?;
139
140 self.source_columns_cache
142 .insert(source_name.to_string(), columns.clone());
143
144 Ok(columns)
145 }
146
147 fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
149 let columns = match &source_info.expression {
150 Expression::Table(table) => {
151 let table_name = table.name.name.clone();
153 match self.schema.column_names(&table_name) {
154 Ok(cols) => cols,
155 Err(_) => Vec::new(), }
157 }
158 Expression::Subquery(subquery) => {
159 self.get_named_selects(&subquery.this)
161 }
162 Expression::Select(select) => {
163 self.get_select_column_names(select)
165 }
166 Expression::Union(union) => {
167 self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
169 }
170 Expression::Intersect(intersect) => {
171 self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
172 }
173 Expression::Except(except) => {
174 self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
175 }
176 Expression::Cte(cte) => {
177 if !cte.columns.is_empty() {
178 cte.columns.iter().map(|c| c.name.clone()).collect()
179 } else {
180 self.get_named_selects(&cte.this)
181 }
182 }
183 _ => Vec::new(),
184 };
185
186 Ok(columns)
187 }
188
189 fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
191 match expr {
192 Expression::Select(select) => self.get_select_column_names(select),
193 Expression::Union(union) => {
194 self.get_named_selects(&union.left)
196 }
197 Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
198 Expression::Except(except) => self.get_named_selects(&except.left),
199 Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
200 _ => Vec::new(),
201 }
202 }
203
204 fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
206 select
207 .expressions
208 .iter()
209 .filter_map(|expr| self.get_expression_alias(expr))
210 .collect()
211 }
212
213 fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
215 match expr {
216 Expression::Alias(alias) => Some(alias.alias.name.clone()),
217 Expression::Column(col) => Some(col.name.name.clone()),
218 Expression::Star(_) => Some("*".to_string()),
219 Expression::Identifier(id) => Some(id.name.clone()),
220 _ => None,
221 }
222 }
223
224 pub fn get_source_columns_from_set_op(
226 &self,
227 expression: &Expression,
228 ) -> ResolverResult<Vec<String>> {
229 match expression {
230 Expression::Select(select) => Ok(self.get_select_column_names(select)),
231 Expression::Subquery(subquery) => {
232 if matches!(
233 &subquery.this,
234 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
235 ) {
236 self.get_source_columns_from_set_op(&subquery.this)
237 } else {
238 Ok(self.get_named_selects(&subquery.this))
239 }
240 }
241 Expression::Union(union) => {
242 self.get_source_columns_from_set_op(&union.left)
244 }
245 Expression::Intersect(intersect) => {
246 self.get_source_columns_from_set_op(&intersect.left)
247 }
248 Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
249 _ => Err(ResolverError::UnknownSetOperation(format!(
250 "{:?}",
251 expression
252 ))),
253 }
254 }
255
256 fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
258 let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
259
260 let mut result = HashMap::new();
261 for source_name in source_names {
262 if let Ok(columns) = self.get_source_columns(&source_name) {
263 result.insert(source_name, columns);
264 }
265 }
266 result
267 }
268
269 fn get_table_name_from_sources(
271 &mut self,
272 column_name: &str,
273 source_columns: Option<&HashMap<String, Vec<String>>>,
274 ) -> Option<String> {
275 let unambiguous = match source_columns {
276 Some(cols) => self.compute_unambiguous_columns(cols),
277 None => {
278 if self.unambiguous_columns_cache.is_none() {
279 let all_source_columns = self.get_all_source_columns();
280 self.unambiguous_columns_cache =
281 Some(self.compute_unambiguous_columns(&all_source_columns));
282 }
283 self.unambiguous_columns_cache
284 .clone()
285 .expect("cache populated above")
286 }
287 };
288
289 unambiguous.get(column_name).cloned()
290 }
291
292 fn compute_unambiguous_columns(
296 &self,
297 source_columns: &HashMap<String, Vec<String>>,
298 ) -> HashMap<String, String> {
299 if source_columns.is_empty() {
300 return HashMap::new();
301 }
302
303 let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
304
305 for (source_name, columns) in source_columns {
306 for column in columns {
307 column_to_sources
308 .entry(column.clone())
309 .or_default()
310 .push(source_name.clone());
311 }
312 }
313
314 column_to_sources
316 .into_iter()
317 .filter(|(_, sources)| sources.len() == 1)
318 .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
319 .collect()
320 }
321
322 pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
324 let all_source_columns = self.get_all_source_columns();
325 let sources_with_column: Vec<_> = all_source_columns
326 .iter()
327 .filter(|(_, columns)| columns.contains(&column_name.to_string()))
328 .map(|(name, _)| name.clone())
329 .collect();
330
331 sources_with_column.len() > 1
332 }
333
334 pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
336 let all_source_columns = self.get_all_source_columns();
337 all_source_columns
338 .iter()
339 .filter(|(_, columns)| columns.contains(&column_name.to_string()))
340 .map(|(name, _)| name.clone())
341 .collect()
342 }
343
344 pub fn disambiguate_in_join_context(
349 &mut self,
350 column_name: &str,
351 available_sources: &[String],
352 ) -> Option<String> {
353 let mut matching_sources = Vec::new();
354
355 for source_name in available_sources {
356 if let Ok(columns) = self.get_source_columns(source_name) {
357 if columns.contains(&column_name.to_string()) {
358 matching_sources.push(source_name.clone());
359 }
360 }
361 }
362
363 if matching_sources.len() == 1 {
364 Some(matching_sources.remove(0))
365 } else {
366 None
367 }
368 }
369}
370
371pub fn resolve_column(
375 scope: &Scope,
376 schema: &dyn Schema,
377 column_name: &str,
378 infer_schema: bool,
379) -> Option<String> {
380 let mut resolver = Resolver::new(scope, schema, infer_schema);
381 resolver.get_table(column_name)
382}
383
384pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
386 let mut resolver = Resolver::new(scope, schema, true);
387 resolver.is_ambiguous(column_name)
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// Integer column type shared by the fixture tables.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Schema with two tables: `users(id, name, email)` and
    /// `orders(id, user_id, amount)`.
    fn create_test_schema() -> MappingSchema {
        let users = [
            ("id".to_string(), int_type()),
            ("name".to_string(), DataType::Text),
            ("email".to_string(), DataType::Text),
        ];
        let orders = [
            ("id".to_string(), int_type()),
            ("user_id".to_string(), int_type()),
            (
                "amount".to_string(),
                DataType::Double {
                    precision: None,
                    scale: None,
                },
            ),
        ];

        let mut schema = MappingSchema::new();
        schema.add_table("users", &users, None).unwrap();
        schema.add_table("orders", &orders, None).unwrap();
        schema
    }

    /// Parses `sql` and builds the scope of its first statement.
    fn scope_of(sql: &str) -> crate::scope::Scope {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        build_scope(&ast[0])
    }

    #[test]
    fn test_resolver_basic() {
        let scope = scope_of("SELECT id, name FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("name"), Some("users".to_string()));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let scope = scope_of("SELECT id FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `id` exists in both tables; `name` and `amount` exist in one each.
        assert!(resolver.is_ambiguous("id"));
        assert!(!resolver.is_ambiguous("name"));
        assert!(!resolver.is_ambiguous("amount"));
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let scope =
            scope_of("SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("name"), Some("users".to_string()));
        assert_eq!(resolver.get_table("amount"), Some("orders".to_string()));
    }

    #[test]
    fn test_resolver_with_alias() {
        let scope = scope_of("SELECT u.id FROM users AS u");
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // The aliased table is registered in the scope under its alias.
        assert!(scope.sources.contains_key("u"));
    }

    #[test]
    fn test_sources_for_column() {
        let scope = scope_of("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let id_sources = resolver.sources_for_column("id");
        assert!(id_sources.contains(&"users".to_string()));
        assert!(id_sources.contains(&"orders".to_string()));

        let email_sources = resolver.sources_for_column("email");
        assert_eq!(email_sources, vec!["users".to_string()]);
    }

    #[test]
    fn test_all_columns() {
        let scope = scope_of("SELECT * FROM users");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        assert!(all.contains("id"));
        assert!(all.contains("name"));
        assert!(all.contains("email"));
    }

    #[test]
    fn test_resolver_cte_projected_alias_column() {
        let scope =
            scope_of("WITH my_cte AS (SELECT id AS emp_id FROM users) SELECT emp_id FROM my_cte");
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("emp_id"), Some("my_cte".to_string()));
    }

    #[test]
    fn test_resolve_column_helper() {
        let scope = scope_of("SELECT name FROM users");
        let schema = create_test_schema();

        let table = resolve_column(&scope, &schema, "name", true);
        assert_eq!(table, Some("users".to_string()));
    }
}