1use crate::dialects::DialectType;
13use crate::expressions::{Expression, Identifier, TableRef};
14use crate::generator::Generator;
15use crate::schema::{normalize_name, Schema};
16use crate::scope::{Scope, SourceInfo};
17use crate::traversal::ExpressionWalk;
18use std::collections::{HashMap, HashSet};
19use thiserror::Error;
20
/// Errors reported while resolving column names to the sources of a scope.
#[derive(Debug, Error, Clone)]
pub enum ResolverError {
    /// A source name was requested that is not registered in the scope.
    #[error("Unknown table: {0}")]
    UnknownTable(String),

    /// A column name matches more than one source; `sources` lists them.
    #[error("Ambiguous column: {column} appears in multiple sources: {sources}")]
    AmbiguousColumn { column: String, sources: String },

    /// A column name could not be matched to any source.
    #[error("Column not found: {0}")]
    ColumnNotFound(String),

    /// An expression handed to set-operation column extraction is not a
    /// supported query shape; the payload is its `Debug` rendering.
    #[error("Unknown set operation: {0}")]
    UnknownSetOperation(String),
}

/// Result alias used by the resolver's fallible operations.
pub type ResolverResult<T> = Result<T, ResolverError>;
39
/// Resolves unqualified column names to the scope source that provides them.
///
/// Per-source column lists and derived lookup maps are memoized on first use,
/// so one `Resolver` is meant to be reused for many lookups in the same scope.
pub struct Resolver<'a> {
    /// Scope whose sources are searched.
    pub scope: &'a Scope,
    /// Schema consulted for the columns of physical tables.
    schema: &'a dyn Schema,
    /// Dialect taken from the schema; passed to every column-name normalization.
    pub dialect: Option<DialectType>,
    /// When true, a column no source claims may be attributed to the single
    /// source whose columns are unknown (empty list or containing `*`).
    infer_schema: bool,
    /// Memoized column lists, keyed by source name.
    source_columns_cache: HashMap<String, Vec<String>>,
    /// Memoized map from normalized column name to its only owning source.
    unambiguous_columns_cache: Option<HashMap<String, String>>,
    /// Memoized union of every source's column names.
    all_columns_cache: Option<HashSet<String>>,
}
60
61impl<'a> Resolver<'a> {
62 pub fn new(scope: &'a Scope, schema: &'a dyn Schema, infer_schema: bool) -> Self {
64 Self {
65 scope,
66 schema,
67 dialect: schema.dialect(),
68 infer_schema,
69 source_columns_cache: HashMap::new(),
70 unambiguous_columns_cache: None,
71 all_columns_cache: None,
72 }
73 }
74
75 pub fn get_table(&mut self, column_name: &str) -> Option<String> {
79 let table_name = self.get_table_name_from_sources(column_name, None);
81
82 if table_name.is_some() {
84 return table_name;
85 }
86
87 if self.infer_schema {
90 let sources_without_schema: Vec<_> = self
91 .get_all_source_columns()
92 .iter()
93 .filter(|(_, columns)| columns.is_empty() || columns.contains(&"*".to_string()))
94 .map(|(name, _)| name.clone())
95 .collect();
96
97 if sources_without_schema.len() == 1 {
98 return Some(sources_without_schema[0].clone());
99 }
100 }
101
102 None
103 }
104
105 pub fn get_table_identifier(&mut self, column_name: &str) -> Option<Identifier> {
107 self.get_table(column_name).map(Identifier::new)
108 }
109
110 pub fn table_exists_in_schema(&self, table_name: &str) -> bool {
113 self.schema.column_names(table_name).is_ok()
114 }
115
116 pub fn find_column_in_outer_schema_tables(&self, column_name: &str) -> Option<String> {
121 let tables = self.schema.find_tables_for_column(column_name);
122 let outer_tables: Vec<String> = tables
124 .into_iter()
125 .filter(|t| !self.scope.sources.contains_key(t))
126 .collect();
127 if outer_tables.len() == 1 {
129 Some(outer_tables.into_iter().next().unwrap())
130 } else {
131 None
132 }
133 }
134
135 pub fn all_columns(&mut self) -> &HashSet<String> {
137 if self.all_columns_cache.is_none() {
138 let mut all = HashSet::new();
139 for columns in self.get_all_source_columns().values() {
140 all.extend(columns.iter().cloned());
141 }
142 self.all_columns_cache = Some(all);
143 }
144 self.all_columns_cache
145 .as_ref()
146 .expect("cache populated above")
147 }
148
149 pub fn get_source_columns(&mut self, source_name: &str) -> ResolverResult<Vec<String>> {
153 if let Some(columns) = self.source_columns_cache.get(source_name) {
155 return Ok(columns.clone());
156 }
157
158 let source_info = self
160 .scope
161 .sources
162 .get(source_name)
163 .ok_or_else(|| ResolverError::UnknownTable(source_name.to_string()))?;
164
165 let columns = self.extract_columns_from_source(source_info)?;
166
167 self.source_columns_cache
169 .insert(source_name.to_string(), columns.clone());
170
171 Ok(columns)
172 }
173
174 fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
176 self.get_source_columns_for_expression(&source_info.expression)
177 }
178
179 fn get_source_columns_for_expression(
180 &self,
181 expression: &Expression,
182 ) -> ResolverResult<Vec<String>> {
183 let columns = match expression {
184 Expression::Table(table) => {
185 let table_name = qualified_table_name(table);
189 match self.schema.column_names(&table_name) {
190 Ok(cols) => cols,
191 Err(_) => Vec::new(), }
193 }
194 Expression::Subquery(subquery) => {
195 self.get_named_selects(&subquery.this)
197 }
198 Expression::Select(select) => {
199 self.get_select_column_names(select)
201 }
202 Expression::Union(union) => {
203 self.get_source_columns_from_set_op(&Expression::Union(union.clone()))?
205 }
206 Expression::Intersect(intersect) => {
207 self.get_source_columns_from_set_op(&Expression::Intersect(intersect.clone()))?
208 }
209 Expression::Except(except) => {
210 self.get_source_columns_from_set_op(&Expression::Except(except.clone()))?
211 }
212 Expression::Cte(cte) => {
213 if !cte.columns.is_empty() {
214 cte.columns.iter().map(|c| c.name.clone()).collect()
215 } else {
216 self.get_named_selects(&cte.this)
217 }
218 }
219 Expression::Pivot(pivot) => self.get_pivot_output_columns(pivot),
220 Expression::Unpivot(unpivot) => self.get_unpivot_output_columns(unpivot),
221 Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
222 alias_output_columns(alias)
223 }
224 Expression::Alias(alias) => {
225 let columns = self.get_source_columns_for_expression(&alias.this)?;
226 apply_alias_columns(columns, &alias.column_aliases)
227 }
228 Expression::Unnest(unnest) => unnest_output_columns(unnest),
229 Expression::Lateral(lateral) => lateral_output_columns(lateral),
230 Expression::LateralView(lateral_view) => lateral_view_output_columns(lateral_view),
231 Expression::Paren(paren) => self.get_source_columns_for_expression(&paren.this)?,
232 _ => Vec::new(),
233 };
234
235 Ok(columns)
236 }
237
238 fn get_named_selects(&self, expr: &Expression) -> Vec<String> {
240 match expr {
241 Expression::Select(select) => self.get_select_column_names(select),
242 Expression::Union(union) => {
243 self.get_named_selects(&union.left)
245 }
246 Expression::Intersect(intersect) => self.get_named_selects(&intersect.left),
247 Expression::Except(except) => self.get_named_selects(&except.left),
248 Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
249 Expression::Alias(alias) => {
250 let columns = self.get_named_selects(&alias.this);
251 apply_alias_columns(columns, &alias.column_aliases)
252 }
253 Expression::Paren(paren) => self.get_named_selects(&paren.this),
254 _ => Vec::new(),
255 }
256 }
257
258 fn get_select_column_names(&self, select: &crate::expressions::Select) -> Vec<String> {
260 select
261 .expressions
262 .iter()
263 .filter_map(|expr| self.get_expression_alias(expr))
264 .collect()
265 }
266
267 fn get_expression_alias(&self, expr: &Expression) -> Option<String> {
269 match expr {
270 Expression::Alias(alias) => Some(alias.alias.name.clone()),
271 Expression::Column(col) => Some(col.name.name.clone()),
272 Expression::Star(_) => Some("*".to_string()),
273 Expression::Identifier(id) => Some(id.name.clone()),
274 _ => None,
275 }
276 }
277
278 fn get_pivot_output_columns(&self, pivot: &crate::expressions::Pivot) -> Vec<String> {
279 if pivot.unpivot {
280 return self.get_pivot_unpivot_output_columns(pivot);
281 }
282
283 let pre_columns = self.get_source_output_columns(&pivot.this);
284 if pre_columns.is_empty() || pre_columns.iter().any(|column| column == "*") {
285 return Vec::new();
286 }
287
288 let excluded = pivot_excluded_source_columns(pivot, self.dialect);
289 let generated = pivot_generated_output_columns(pivot, self.dialect);
290 if excluded.is_empty() || generated.is_empty() {
291 return Vec::new();
292 }
293
294 let mut columns: Vec<String> = pre_columns
295 .into_iter()
296 .filter(|column| !excluded.contains(&normalize_column_name(column, self.dialect)))
297 .collect();
298 columns.extend(generated);
299 apply_alias_columns(columns, &pivot.alias_columns)
300 }
301
302 fn get_pivot_unpivot_output_columns(&self, pivot: &crate::expressions::Pivot) -> Vec<String> {
303 let pre_columns = self.get_source_output_columns(&pivot.this);
304 if pre_columns.is_empty() || pre_columns.iter().any(|column| column == "*") {
305 return Vec::new();
306 }
307
308 let input_columns: HashSet<String> = pivot
309 .expressions
310 .iter()
311 .flat_map(expression_column_names)
312 .map(|column| normalize_column_name(&column, self.dialect))
313 .collect();
314 let mut columns: Vec<String> = pre_columns
315 .into_iter()
316 .filter(|column| !input_columns.contains(&normalize_column_name(column, self.dialect)))
317 .collect();
318
319 if let Some(Expression::UnpivotColumns(unpivot_columns)) = pivot.into.as_deref() {
320 if let Some(name) = expression_name(&unpivot_columns.this) {
321 columns.push(name);
322 }
323 for value_column in &unpivot_columns.expressions {
324 if let Some(name) = expression_name(value_column) {
325 columns.push(name);
326 }
327 }
328 }
329
330 apply_alias_columns(columns, &pivot.alias_columns)
331 }
332
333 fn get_unpivot_output_columns(&self, unpivot: &crate::expressions::Unpivot) -> Vec<String> {
334 let pre_columns = self.get_source_output_columns(&unpivot.this);
335 if pre_columns.is_empty() || pre_columns.iter().any(|column| column == "*") {
336 return Vec::new();
337 }
338
339 let input_columns: HashSet<String> = unpivot
340 .columns
341 .iter()
342 .flat_map(expression_column_names)
343 .map(|column| normalize_column_name(&column, self.dialect))
344 .collect();
345 let mut columns: Vec<String> = pre_columns
346 .into_iter()
347 .filter(|column| !input_columns.contains(&normalize_column_name(column, self.dialect)))
348 .collect();
349 columns.push(unpivot.name_column.name.clone());
350 columns.push(unpivot.value_column.name.clone());
351 columns.extend(
352 unpivot
353 .extra_value_columns
354 .iter()
355 .map(|column| column.name.clone()),
356 );
357 apply_alias_columns(columns, &unpivot.alias_columns)
358 }
359
360 fn get_source_output_columns(&self, source: &Expression) -> Vec<String> {
361 match source {
362 Expression::Table(table) => {
363 if table.schema.is_none() && table.catalog.is_none() {
364 if let Some(source) = self.scope.cte_sources.get(&table.name.name) {
365 return self.extract_columns_from_source(source).unwrap_or_default();
366 }
367 }
368
369 let table_name = qualified_table_name(table);
370 self.schema.column_names(&table_name).unwrap_or_default()
371 }
372 Expression::Subquery(subquery) => self.get_named_selects(&subquery.this),
373 Expression::Select(select) => self.get_select_column_names(select),
374 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_) => self
375 .get_source_columns_from_set_op(source)
376 .unwrap_or_default(),
377 Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
378 alias_output_columns(alias)
379 }
380 Expression::Alias(alias) => {
381 let columns = self.get_source_output_columns(&alias.this);
382 apply_alias_columns(columns, &alias.column_aliases)
383 }
384 Expression::Unnest(unnest) => unnest_output_columns(unnest),
385 Expression::Lateral(lateral) => lateral_output_columns(lateral),
386 Expression::LateralView(lateral_view) => lateral_view_output_columns(lateral_view),
387 Expression::Cte(cte) => {
388 if cte.columns.is_empty() {
389 self.get_named_selects(&cte.this)
390 } else {
391 cte.columns
392 .iter()
393 .map(|column| column.name.clone())
394 .collect()
395 }
396 }
397 Expression::Paren(paren) => self.get_source_output_columns(&paren.this),
398 _ => Vec::new(),
399 }
400 }
401
402 pub fn get_source_columns_from_set_op(
404 &self,
405 expression: &Expression,
406 ) -> ResolverResult<Vec<String>> {
407 match expression {
408 Expression::Select(select) => Ok(self.get_select_column_names(select)),
409 Expression::Subquery(subquery) => {
410 if matches!(
411 &subquery.this,
412 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
413 ) {
414 self.get_source_columns_from_set_op(&subquery.this)
415 } else {
416 Ok(self.get_named_selects(&subquery.this))
417 }
418 }
419 Expression::Alias(alias) => {
420 let columns = self.get_source_columns_from_set_op(&alias.this)?;
421 Ok(apply_alias_columns(columns, &alias.column_aliases))
422 }
423 Expression::Paren(paren) => self.get_source_columns_from_set_op(&paren.this),
424 Expression::Union(union) => {
425 self.get_source_columns_from_set_op(&union.left)
427 }
428 Expression::Intersect(intersect) => {
429 self.get_source_columns_from_set_op(&intersect.left)
430 }
431 Expression::Except(except) => self.get_source_columns_from_set_op(&except.left),
432 _ => Err(ResolverError::UnknownSetOperation(format!(
433 "{:?}",
434 expression
435 ))),
436 }
437 }
438
439 fn get_all_source_columns(&mut self) -> HashMap<String, Vec<String>> {
441 let source_names: Vec<_> = self.scope.sources.keys().cloned().collect();
442
443 let mut result = HashMap::new();
444 for source_name in source_names {
445 if let Ok(columns) = self.get_source_columns(&source_name) {
446 result.insert(source_name, columns);
447 }
448 }
449 result
450 }
451
452 fn get_table_name_from_sources(
454 &mut self,
455 column_name: &str,
456 source_columns: Option<&HashMap<String, Vec<String>>>,
457 ) -> Option<String> {
458 let normalized_column_name = normalize_column_name(column_name, self.dialect);
459 let unambiguous = match source_columns {
460 Some(cols) => self.compute_unambiguous_columns(cols),
461 None => {
462 if self.unambiguous_columns_cache.is_none() {
463 let all_source_columns = self.get_all_source_columns();
464 self.unambiguous_columns_cache =
465 Some(self.compute_unambiguous_columns(&all_source_columns));
466 }
467 self.unambiguous_columns_cache
468 .clone()
469 .expect("cache populated above")
470 }
471 };
472
473 unambiguous.get(&normalized_column_name).cloned()
474 }
475
476 fn compute_unambiguous_columns(
480 &self,
481 source_columns: &HashMap<String, Vec<String>>,
482 ) -> HashMap<String, String> {
483 if source_columns.is_empty() {
484 return HashMap::new();
485 }
486
487 let mut column_to_sources: HashMap<String, Vec<String>> = HashMap::new();
488
489 for (source_name, columns) in source_columns {
490 for column in columns {
491 column_to_sources
492 .entry(normalize_column_name(column, self.dialect))
493 .or_default()
494 .push(source_name.clone());
495 }
496 }
497
498 column_to_sources
500 .into_iter()
501 .filter(|(_, sources)| sources.len() == 1)
502 .map(|(column, sources)| (column, sources.into_iter().next().unwrap()))
503 .collect()
504 }
505
506 pub fn is_ambiguous(&mut self, column_name: &str) -> bool {
508 let normalized_column_name = normalize_column_name(column_name, self.dialect);
509 let all_source_columns = self.get_all_source_columns();
510 let sources_with_column: Vec<_> = all_source_columns
511 .iter()
512 .filter(|(_, columns)| {
513 columns.iter().any(|column| {
514 normalize_column_name(column, self.dialect) == normalized_column_name
515 })
516 })
517 .map(|(name, _)| name.clone())
518 .collect();
519
520 sources_with_column.len() > 1
521 }
522
523 pub fn sources_for_column(&mut self, column_name: &str) -> Vec<String> {
525 let normalized_column_name = normalize_column_name(column_name, self.dialect);
526 let all_source_columns = self.get_all_source_columns();
527 all_source_columns
528 .iter()
529 .filter(|(_, columns)| {
530 columns.iter().any(|column| {
531 normalize_column_name(column, self.dialect) == normalized_column_name
532 })
533 })
534 .map(|(name, _)| name.clone())
535 .collect()
536 }
537
538 pub fn disambiguate_in_join_context(
543 &mut self,
544 column_name: &str,
545 available_sources: &[String],
546 ) -> Option<String> {
547 let normalized_column_name = normalize_column_name(column_name, self.dialect);
548 let mut matching_sources = Vec::new();
549
550 for source_name in available_sources {
551 if let Ok(columns) = self.get_source_columns(source_name) {
552 if columns.iter().any(|column| {
553 normalize_column_name(column, self.dialect) == normalized_column_name
554 }) {
555 matching_sources.push(source_name.clone());
556 }
557 }
558 }
559
560 if matching_sources.len() == 1 {
561 Some(matching_sources.remove(0))
562 } else {
563 None
564 }
565 }
566}
567
568fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
569 normalize_name(name, dialect, false, true)
570}
571
572fn apply_alias_columns(mut columns: Vec<String>, alias_columns: &[Identifier]) -> Vec<String> {
573 for (idx, alias) in alias_columns.iter().enumerate() {
574 if let Some(column) = columns.get_mut(idx) {
575 *column = alias.name.clone();
576 }
577 }
578 columns
579}
580
581fn unnest_output_columns(unnest: &crate::expressions::UnnestFunc) -> Vec<String> {
582 unnest
583 .alias
584 .iter()
585 .map(|alias| alias.name.clone())
586 .chain(unnest.offset_alias.iter().map(|alias| alias.name.clone()))
587 .collect()
588}
589
590fn alias_output_columns(alias: &crate::expressions::Alias) -> Vec<String> {
591 if alias.column_aliases.is_empty() {
592 vec![alias.alias.name.clone()]
593 } else {
594 alias
595 .column_aliases
596 .iter()
597 .map(|column| column.name.clone())
598 .collect()
599 }
600}
601
602fn lateral_output_columns(lateral: &crate::expressions::Lateral) -> Vec<String> {
603 if lateral.column_aliases.is_empty() {
604 default_virtual_output_columns(&lateral.this)
605 } else {
606 lateral.column_aliases.clone()
607 }
608}
609
610fn lateral_view_output_columns(lateral_view: &crate::expressions::LateralView) -> Vec<String> {
611 lateral_view
612 .column_aliases
613 .iter()
614 .map(|column| column.name.clone())
615 .collect()
616}
617
618fn default_virtual_output_columns(expression: &Expression) -> Vec<String> {
619 match expression {
620 Expression::Unnest(unnest) => unnest_output_columns(unnest),
621 Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
622 alias_output_columns(alias)
623 }
624 Expression::Function(function) if function.name.eq_ignore_ascii_case("FLATTEN") => {
625 ["seq", "key", "path", "index", "value", "this"]
626 .into_iter()
627 .map(String::from)
628 .collect()
629 }
630 _ => Vec::new(),
631 }
632}
633
634fn pivot_excluded_source_columns(
635 pivot: &crate::expressions::Pivot,
636 dialect: Option<DialectType>,
637) -> HashSet<String> {
638 pivot
639 .fields
640 .iter()
641 .chain(pivot.expressions.iter())
642 .chain(pivot.using.iter())
643 .flat_map(expression_column_names)
644 .map(|column| normalize_column_name(&column, dialect))
645 .collect()
646}
647
648fn pivot_generated_output_columns(
649 pivot: &crate::expressions::Pivot,
650 _dialect: Option<DialectType>,
651) -> Vec<String> {
652 let fields = pivot_field_output_names(pivot);
653 let aggregations = if pivot.using.is_empty() {
654 &pivot.expressions
655 } else {
656 &pivot.using
657 };
658
659 if fields.is_empty() || aggregations.is_empty() {
660 return Vec::new();
661 }
662
663 let needs_suffix = aggregations.len() > 1;
664 let mut outputs = Vec::new();
665 for field in fields {
666 for aggregation in aggregations {
667 if let Some(suffix) = pivot_aggregation_output_suffix(aggregation, needs_suffix) {
668 outputs.push(format!("{field}_{suffix}"));
669 } else {
670 outputs.push(field.clone());
671 }
672 }
673 }
674 outputs
675}
676
677fn pivot_field_output_names(pivot: &crate::expressions::Pivot) -> Vec<String> {
678 pivot
679 .fields
680 .iter()
681 .filter_map(|field| match field {
682 Expression::In(in_expr) => Some(
683 in_expr
684 .expressions
685 .iter()
686 .filter_map(expression_name)
687 .collect::<Vec<_>>(),
688 ),
689 _ => None,
690 })
691 .flatten()
692 .collect()
693}
694
695fn pivot_aggregation_output_suffix(expr: &Expression, needs_suffix: bool) -> Option<String> {
696 match expr {
697 Expression::Alias(alias) => Some(alias.alias.name.clone()),
698 _ if needs_suffix => Generator::sql(expr).ok().map(|sql| sql.to_lowercase()),
699 _ => None,
700 }
701}
702
703fn expression_name(expr: &Expression) -> Option<String> {
704 match expr {
705 Expression::PivotAlias(alias) => expression_name(&alias.alias),
706 Expression::Alias(alias) => Some(alias.alias.name.clone()),
707 Expression::Identifier(identifier) => Some(identifier.name.clone()),
708 Expression::Column(column) => Some(column.name.name.clone()),
709 Expression::Literal(literal) => Some(literal.value_str().to_string()),
710 Expression::Var(var) => Some(var.this.clone()),
711 Expression::Tuple(tuple) => tuple.expressions.first().and_then(expression_name),
712 _ => None,
713 }
714}
715
716fn expression_column_names(expr: &Expression) -> Vec<String> {
717 expr.find_all(|node| matches!(node, Expression::Column(_)))
718 .into_iter()
719 .filter_map(|node| match node {
720 Expression::Column(column) => Some(column.name.name.clone()),
721 _ => None,
722 })
723 .collect()
724}
725
726pub fn resolve_column(
730 scope: &Scope,
731 schema: &dyn Schema,
732 column_name: &str,
733 infer_schema: bool,
734) -> Option<String> {
735 let mut resolver = Resolver::new(scope, schema, infer_schema);
736 resolver.get_table(column_name)
737}
738
739pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
741 let mut resolver = Resolver::new(scope, schema, true);
742 resolver.is_ambiguous(column_name)
743}
744
745fn qualified_table_name(table: &TableRef) -> String {
747 let mut parts = Vec::new();
748 if let Some(catalog) = &table.catalog {
749 parts.push(catalog.name.clone());
750 }
751 if let Some(schema) = &table.schema {
752 parts.push(schema.name.clone());
753 }
754 parts.push(table.name.name.clone());
755 parts.join(".")
756}
757
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::expressions::DataType;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;
    use crate::scope::build_scope;

    /// Integer type shared by the fixture tables' id columns.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Fixture schema: `users(id, name, email)` and `orders(id, user_id, amount)`.
    fn create_test_schema() -> MappingSchema {
        let users = [
            ("id".to_string(), int_type()),
            ("name".to_string(), DataType::Text),
            ("email".to_string(), DataType::Text),
        ];
        let orders = [
            ("id".to_string(), int_type()),
            ("user_id".to_string(), int_type()),
            (
                "amount".to_string(),
                DataType::Double {
                    precision: None,
                    scale: None,
                },
            ),
        ];

        let mut schema = MappingSchema::new();
        schema.add_table("users", &users, None).unwrap();
        schema.add_table("orders", &orders, None).unwrap();
        schema
    }

    #[test]
    fn test_resolver_basic() {
        let statements = Parser::parse_sql("SELECT id, name FROM users").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("name").as_deref(), Some("users"));
    }

    #[test]
    fn test_resolver_ambiguous_column() {
        let statements =
            Parser::parse_sql("SELECT id FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        // `id` exists in both tables; the others exist in exactly one.
        assert!(resolver.is_ambiguous("id"));
        for column in ["name", "amount"] {
            assert!(!resolver.is_ambiguous(column), "{column} should be unambiguous");
        }
    }

    #[test]
    fn test_resolver_unambiguous_column() {
        let statements = Parser::parse_sql(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
        )
        .expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        for (column, expected) in [("name", "users"), ("amount", "orders")] {
            assert_eq!(resolver.get_table(column).as_deref(), Some(expected));
        }
    }

    #[test]
    fn test_resolver_with_alias() {
        let statements = Parser::parse_sql("SELECT u.id FROM users AS u").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let _resolver = Resolver::new(&scope, &schema, true);

        // Aliased tables are registered under their alias.
        assert!(scope.sources.contains_key("u"));
    }

    #[test]
    fn test_sources_for_column() {
        let statements =
            Parser::parse_sql("SELECT * FROM users JOIN orders ON users.id = orders.user_id")
                .expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let id_sources = resolver.sources_for_column("id");
        assert!(id_sources.iter().any(|source| source == "users"));
        assert!(id_sources.iter().any(|source| source == "orders"));

        assert_eq!(resolver.sources_for_column("email"), ["users"]);
    }

    #[test]
    fn test_all_columns() {
        let statements = Parser::parse_sql("SELECT * FROM users").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        let all = resolver.all_columns();
        for expected in ["id", "name", "email"] {
            assert!(all.contains(expected), "missing column {expected}");
        }
    }

    #[test]
    fn test_resolver_cte_projected_alias_column() {
        let statements = Parser::parse_sql(
            "WITH my_cte AS (SELECT id AS emp_id FROM users) SELECT emp_id FROM my_cte",
        )
        .expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();
        let mut resolver = Resolver::new(&scope, &schema, true);

        assert_eq!(resolver.get_table("emp_id").as_deref(), Some("my_cte"));
    }

    #[test]
    fn test_resolve_column_helper() {
        let statements = Parser::parse_sql("SELECT name FROM users").expect("Failed to parse");
        let scope = build_scope(&statements[0]);
        let schema = create_test_schema();

        assert_eq!(
            resolve_column(&scope, &schema, "name", true).as_deref(),
            Some("users")
        );
    }

    #[test]
    fn test_resolver_bigquery_mixed_case_column_names() {
        let dialect = Dialect::get(DialectType::BigQuery);
        let expr = dialect
            .parse("SELECT Name AS name FROM teams")
            .unwrap()
            .into_iter()
            .next()
            .expect("expected one expression");
        let scope = build_scope(&expr);

        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
        schema
            .add_table(
                "teams",
                &[("Name".into(), DataType::String { length: None })],
                None,
            )
            .expect("schema setup");

        let mut resolver = Resolver::new(&scope, &schema, true);
        // Both spellings must resolve to the same table under BigQuery normalization.
        for column in ["Name", "name"] {
            assert_eq!(resolver.get_table(column).as_deref(), Some("teams"));
        }
    }
}