1use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier};
14use crate::schema::Schema;
15use crate::scope::{Scope, SourceInfo};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19#[derive(Debug, Error, Clone)]
21pub enum ResolverError {
22 #[error("Unknown table: {0}")]
23 UnknownTable(String),
24
25 #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
26 AmbiguousColumn { column: String, sources: String },
27
28 #[error("Column not found: {0}")]
29 ColumnNotFound(String),
30
31 #[error("Unknown set operation: {0}")]
32 UnknownSetOperation(String),
33}
34
35pub type ResolverResult<T> = Result<T, ResolverError>;
37
38pub struct Resolver<'a> {
43 pub scope: &'a Scope,
45 schema: &'a dyn Schema,
47 pub dialect: Option<DialectType>,
49 infer_schema: bool,
51 source_columns_cache: HashMap<String, Vec<String>>,
53 unambiguous_columns_cache: Option<HashMap<String, String>>,
55 all_columns_cache: Option<HashSet<String>>,
57}
58
59impl<'a> Resolver<'a> {
60 pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
62 Self {
63 scope,
64 schema,
65 dialect: schema.dialect(),
66 infer_schema,
67 source_columns_cache: HashMap::new(),
68 unambiguous_columns_cache: None,
69 all_columns_cache: None,
70 }
71 }
72
73 pub fn get_table(&mut self, column_name: &str) -> Option<String> {
77 let table_name = self.get_table_name_from_sources(column_name, None);
79
80 if table_name.is_some() {
82 return table_name;
83 }
84
85 if self.infer_schema {
88 let sources_without_schema: Vec<_> = self
89 .get_all_source_columns()
90 .iter()
91 .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
92 .map(|(name, _)| name.clone())
93 .collect();
94
95 if sources_without_schema.len() == 1 {
96 return Some(sources_without_schema[0].clone());
97 }
98 }
99
100 None
101 }
102
103 pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
105 self.get_table(column_name).map(Identifier::new)
106 }
107
108 pub fn all_columns(&mut self) -> &HashSet<String> {
110 if self.all_columns_cache.is_none() {
111 let mut all = HashSet::new();
112 for columns in self.get_all_source_columns().values() {
113 all.extend(columns.iter().cloned());
114 }
115 self.all_columns_cache = Some(all);
116 }
117 self.all_columns_cache.as_ref().expect("cache populated above")
118 }
119
120 pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
124 if let Some(columns) = self.source_columns_cache.get(source_name) {
126 return Ok(columns.clone());
127 }
128
129 let source_info = self
131 .scope
132 .sources
133 .get(source_name)
134 .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
135
136 let columns = self.extract_columns_from_source(source_info)?;
137
138 self.source_columns_cache
140 .insert(source_name.to_string(), columns.clone());
141
142 Ok(columns)
143 }
144
145 fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
147 let columns = match &source_info.expression {
148 Expression::Table(table) => {
149 let table_name = table.name.name.clone();
151 match self.schema.column_names(&table_name) {
152 Ok(cols) => cols,
153 Err(_) => Vec::new(), }
155 }
156 Expression::Subquery(subquery) => {
157 self.get_named_selects(&subquery.this)
159 }
160 Expression::Select(select) => {
161 self.get_select_column_names(select)
163 }
164 Expression::Union(union) => {
165 self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
167 }
168 Expression::Intersect(intersect) => {
169 self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
170 }
171 Expression::Except(except) => {
172 self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
173 }
174 _ => Vec::new(),
175 };
176
177 Ok(columns)
178 }
179
180 fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
182 match expr {
183 Expression::Select(select) => self.get_select_column_names(select),
184 Expression::Union(union) => {
185 self.get_named_selects(&union.left)
187 }
188 Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
189 Expression::Except(except) => self.get_named_selects(&except.left),
190 Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
191 _ => Vec::new(),
192 }
193 }
194
195 fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
197 select
198 .expressions
199 .iter()
200 .filter_map(|expr| self.get_expression_alias(expr))
201 .collect()
202 }
203
204 fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
206 match expr {
207 Expression::Alias(alias) => Some(alias.alias.name.clone()),
208 Expression::Column(col) => Some(col.name.name.clone()),
209 Expression::Star(_) => Some("*".to_string()),
210 Expression::Identifier(id) => Some(id.name.clone()),
211 _ => None,
212 }
213 }
214
215 pub fn get_source_columns_from_set_op(
217 &self,
218 expression: &Expression,
219 ) -> ResolverResult<Vec<String>> {
220 match expression {
221 Expression::Select(select) => Ok(self.get_select_column_names(select)),
222 Expression::Subquery(subquery) => {
223 if matches!(
224 &subquery.this,
225 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
226 ) {
227 self.get_source_columns_from_set_op(&subquery.this)
228 } else {
229 Ok(self.get_named_selects(&subquery.this))
230 }
231 }
232 Expression::Union(union) => {
233 self.get_source_columns_from_set_op(&union.left)
235 }
236 Expression::Intersect(intersect) => {
237 self.get_source_columns_from_set_op(&intersect.left)
238 }
239 Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
240 _ => Err(ResolverError::UnknownSetOperation(format!(
241 "{:?}",
242 expression
243 ))),
244 }
245 }
246
247 fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
249 let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
250
251 let mut result = HashMap::new();
252 for source_name in source_names {
253 if let Ok(columns) = self.get_source_columns(&source_name) {
254 result.insert(source_name, columns);
255 }
256 }
257 result
258 }
259
260 fn get_table_name_from_sources(
262 &mut self,
263 column_name: &str,
264 source_columns: Option<&HashMap<String, Vec<String>>>,
265 ) -> Option<String> {
266 let unambiguous = match source_columns {
267 Some(cols) => self.compute_unambiguous_columns(cols),
268 None => {
269 if self.unambiguous_columns_cache.is_none() {
270 let all_source_columns = self.get_all_source_columns();
271 self.unambiguous_columns_cache =
272 Some(self.compute_unambiguous_columns(&all_source_columns));
273 }
274 self.unambiguous_columns_cache.clone().expect("cache populated above")
275 }
276 };
277
278 unambiguous.get(column_name).cloned()
279 }
280
281 fn compute_unambiguous_columns(
285 &self,
286 source_columns: &HashMap<String, Vec<String>>,
287 ) -> HashMap<String, String> {
288 if source_columns.is_empty() {
289 return HashMap::new();
290 }
291
292 let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
293
294 for (source_name, columns) in source_columns {
295 for column in columns {
296 column_to_sources
297 .entry(column.clone())
298 .or_default()
299 .push(source_name.clone());
300 }
301 }
302
303 column_to_sources
305 .into_iter()
306 .filter(|(_, sources)| sources.len() == 1)
307 .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
308 .collect()
309 }
310
311 pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
313 let all_source_columns = self.get_all_source_columns();
314 let sources_with_column: Vec<_> = all_source_columns
315 .iter()
316 .filter(|(_, columns)| columns.contains(&column_name.to_string()))
317 .map(|(name, _)| name.clone())
318 .collect();
319
320 sources_with_column.len() > 1
321 }
322
323 pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
325 let all_source_columns = self.get_all_source_columns();
326 all_source_columns
327 .iter()
328 .filter(|(_, columns)| columns.contains(&column_name.to_string()))
329 .map(|(name, _)| name.clone())
330 .collect()
331 }
332
333 pub fn disambiguate_in_join_context(
338 &mut self,
339 column_name: &str,
340 available_sources: &[String],
341 ) -> Option<String> {
342 let mut matching_sources = Vec::new();
343
344 for source_name in available_sources {
345 if let Ok(columns) = self.get_source_columns(source_name) {
346 if columns.contains(&column_name.to_string()) {
347 matching_sources.push(source_name.clone());
348 }
349 }
350 }
351
352 if matching_sources.len() == 1 {
353 Some(matching_sources.remove(0))
354 } else {
355 None
356 }
357 }
358}
359
360pub fn resolve_column(
364 scope: &Scope,
365 schema: &dyn Schema,
366 column_name: &str,
367 infer_schema: bool,
368) -> Option<String> {
369 let mut resolver = Resolver::new(scope, schema, infer_schema);
370 resolver.get_table(column_name)
371}
372
373pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
375 let mut resolver = Resolver::new(scope, schema, true);
376 resolver.is_ambiguous(column_name)
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// Builds a schema with `users(id, name, email)` and `orders(id, user_id, amount)`.
    /// `id` is deliberately present in both tables so ambiguity can be exercised.
    fn create_test_schema() -> MappingSchema {
        let int = || DataType::Int { length: None, integer_spelling: false };
        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "users",
                &[
                    ("id".to_string(), int()),
                    ("name".to_string(), DataType::Text),
                    ("email".to_string(), DataType::Text),
                ],
                None,
            )
            .unwrap();
        schema
            .add_table(
                "orders",
                &[
                    ("id".to_string(), int()),
                    ("user_id".to_string(), int()),
                    ("amount".to_string(), DataType::Double { precision: None, scale: None }),
                ],
                None,
            )
            .unwrap();
        schema
    }

    #[test]
    fn test_resolver_basic() {
        let ast = Parser::parse_sql("SELECT id, name FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let table = resolver.get_table("name");
        assert_eq!(table, Some("users".to_string()));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let ast =
            Parser::parse_sql("SELECT id FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `id` exists in both tables; `name` and `amount` in exactly one.
        assert!(resolver.is_ambiguous("id"));
        assert!(!resolver.is_ambiguous("name"));
        assert!(!resolver.is_ambiguous("amount"));
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let ast = Parser::parse_sql(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
        )
        .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let table = resolver.get_table("name");
        assert_eq!(table, Some("users".to_string()));

        let table = resolver.get_table("amount");
        assert_eq!(table, Some("orders".to_string()));
    }

    #[test]
    fn test_resolver_with_alias() {
        let ast = Parser::parse_sql("SELECT u.id FROM users AS u").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // The source is registered under its alias, so resolution must report `u`.
        assert!(scope.sources.contains_key("u"));
        assert_eq!(resolver.get_table("name"), Some("u".to_string()));
    }

    #[test]
    fn test_sources_for_column() {
        let ast =
            Parser::parse_sql("SELECT * FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let sources = resolver.sources_for_column("id");
        assert!(sources.contains(&"users".to_string()));
        assert!(sources.contains(&"orders".to_string()));

        let sources = resolver.sources_for_column("email");
        assert_eq!(sources, vec!["users".to_string()]);
    }

    #[test]
    fn test_all_columns() {
        let ast = Parser::parse_sql("SELECT * FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        assert!(all.contains("id"));
        assert!(all.contains("name"));
        assert!(all.contains("email"));
    }

    #[test]
    fn test_resolve_column_helper() {
        let ast = Parser::parse_sql("SELECT name FROM users").expect("Failed to parse");
        let scope = build_scope(&ast[0]);
        let schema = create_test_schema();

        let table = resolve_column(&scope, &schema, "name", true);
        assert_eq!(table, Some("users".to_string()));
    }
}