1use crate::dialects::transform_recursive;
9use crate::dialects::DialectType;
10use crate::expressions::{
11 Alias, BinaryOp, Column, Expression, Identifier, Join, LateralView, Literal, Over, Paren,
12 Select, TableRef, VarArgFunc, With,
13};
14use crate::resolver::{Resolver, ResolverError};
15use crate::schema::{normalize_name, Schema};
16use crate::scope::{build_scope, traverse_scope, Scope};
17use std::cell::RefCell;
18use std::collections::{HashMap, HashSet};
19use thiserror::Error;
20
/// Errors produced by the column-qualification pass.
#[derive(Debug, Error, Clone)]
pub enum QualifyColumnsError {
    /// A table qualifier (as in `t.col` or `t.*`) that does not name a known source.
    #[error("Unknown table: {0}")]
    UnknownTable(String),

    /// A column that is not among its qualifying source's known columns, or an
    /// unqualified column that no source could claim (when partial
    /// qualification is not allowed).
    #[error("Unknown column: {0}")]
    UnknownColumn(String),

    /// Reported by `validate_qualify_columns` for the first column still left
    /// unqualified after qualification.
    #[error("Ambiguous column: {0}")]
    AmbiguousColumn(String),

    /// In this module, also used to carry the message of any error returned by
    /// `transform_recursive` itself.
    #[error("Cannot automatically join: {0}")]
    CannotAutoJoin(String),

    /// A positional reference (e.g. `GROUP BY 3`) outside the projection list.
    #[error("Unknown output column: {0}")]
    UnknownOutputColumn(String),

    /// An external column reference in a scope that is not a correlated
    /// subquery. `for_table` is either empty or ` for table: '<name>'`.
    #[error("Column could not be resolved: {column}{for_table}")]
    ColumnNotResolved { column: String, for_table: String },

    /// An error propagated from the name resolver.
    #[error("Resolver error: {0}")]
    ResolverError(#[from] ResolverError),
}
45
/// Result type used throughout the column-qualification pass.
pub type QualifyColumnsResult<T> = Result<T, QualifyColumnsError>;
48
/// Options controlling `qualify_columns`.
///
/// NOTE(review): the derived `Default` leaves `expand_alias_refs` and
/// `expand_stars` set to `false`, whereas `QualifyColumnsOptions::new()` sets
/// both to `true`. Use `new()` unless the all-off configuration is intended.
#[derive(Debug, Clone, Default)]
pub struct QualifyColumnsOptions {
    /// Replace references to projection aliases (in later projections, WHERE,
    /// GROUP BY, HAVING, QUALIFY and ORDER BY) with the aliased expression.
    pub expand_alias_refs: bool,
    /// Expand `*` / `table.*` projections into explicit column lists when the
    /// source columns are known.
    pub expand_stars: bool,
    /// Whether the resolver may infer schema; `None` means "infer only when the
    /// schema is empty".
    pub infer_schema: Option<bool>,
    /// When `true`, columns that cannot be resolved are left untouched instead
    /// of producing `UnknownColumn`.
    pub allow_partial_qualification: bool,
    /// SQL dialect; `None` falls back to the schema's dialect.
    pub dialect: Option<DialectType>,
}
63
64impl QualifyColumnsOptions {
65 pub fn new() -> Self {
67 Self {
68 expand_alias_refs: true,
69 expand_stars: true,
70 infer_schema: None,
71 allow_partial_qualification: false,
72 dialect: None,
73 }
74 }
75
76 pub fn with_expand_alias_refs(mut self, expand: bool) -> Self {
78 self.expand_alias_refs = expand;
79 self
80 }
81
82 pub fn with_expand_stars(mut self, expand: bool) -> Self {
84 self.expand_stars = expand;
85 self
86 }
87
88 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
90 self.dialect = Some(dialect);
91 self
92 }
93
94 pub fn with_allow_partial(mut self, allow: bool) -> Self {
96 self.allow_partial_qualification = allow;
97 self
98 }
99}
100
101pub fn qualify_columns(
116 expression: Expression,
117 schema: &dyn Schema,
118 options: &QualifyColumnsOptions,
119) -> QualifyColumnsResult<Expression> {
120 let infer_schema = options.infer_schema.unwrap_or(schema.is_empty());
121 let dialect = options.dialect.or_else(|| schema.dialect());
122 let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
123
124 let transformed = transform_recursive(expression, &|node| {
125 if first_error.borrow().is_some() {
126 return Ok(node);
127 }
128
129 match node {
130 Expression::Select(mut select) => {
131 if let Some(with) = &mut select.with {
132 pushdown_cte_alias_columns_with(with);
133 }
134
135 let scope_expr = Expression::Select(select.clone());
136 let scope = build_scope(&scope_expr);
137 let mut resolver = Resolver::new(&scope, schema, infer_schema);
138
139 let column_tables = if first_error.borrow().is_none() {
141 match expand_using(&mut select, &scope, &mut resolver) {
142 Ok(ct) => ct,
143 Err(err) => {
144 *first_error.borrow_mut() = Some(err);
145 HashMap::new()
146 }
147 }
148 } else {
149 HashMap::new()
150 };
151
152 if first_error.borrow().is_none() {
154 if let Err(err) = qualify_columns_in_scope(
155 &mut select,
156 &scope,
157 &mut resolver,
158 options.allow_partial_qualification,
159 ) {
160 *first_error.borrow_mut() = Some(err);
161 }
162 }
163
164 if first_error.borrow().is_none() && options.expand_alias_refs {
166 if let Err(err) = expand_alias_refs(&mut select, &mut resolver, dialect) {
167 *first_error.borrow_mut() = Some(err);
168 }
169 }
170
171 if first_error.borrow().is_none() && options.expand_stars {
173 if let Err(err) =
174 expand_stars(&mut select, &scope, &mut resolver, &column_tables)
175 {
176 *first_error.borrow_mut() = Some(err);
177 }
178 }
179
180 if first_error.borrow().is_none() {
182 if let Err(err) = qualify_outputs_select(&mut select) {
183 *first_error.borrow_mut() = Some(err);
184 }
185 }
186
187 if first_error.borrow().is_none() {
189 if let Err(err) = expand_group_by(&mut select, dialect) {
190 *first_error.borrow_mut() = Some(err);
191 }
192 }
193
194 Ok(Expression::Select(select))
195 }
196 _ => Ok(node),
197 }
198 })
199 .map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
200
201 if let Some(err) = first_error.into_inner() {
202 return Err(err);
203 }
204
205 Ok(transformed)
206}
207
208pub fn validate_qualify_columns(expression: &Expression) -> QualifyColumnsResult<()> {
213 let mut all_unqualified = Vec::new();
214
215 for scope in traverse_scope(expression) {
216 if let Expression::Select(_) = &scope.expression {
217 let unqualified = get_unqualified_columns(&scope);
219
220 let external = get_external_columns(&scope);
222 if !external.is_empty() && !is_correlated_subquery(&scope) {
223 let first = &external[0];
224 let for_table = if first.table.is_some() {
225 format!(" for table: '{}'", first.table.as_ref().unwrap())
226 } else {
227 String::new()
228 };
229 return Err(QualifyColumnsError::ColumnNotResolved {
230 column: first.name.clone(),
231 for_table,
232 });
233 }
234
235 all_unqualified.extend(unqualified);
236 }
237 }
238
239 if !all_unqualified.is_empty() {
240 let first = &all_unqualified[0];
241 return Err(QualifyColumnsError::AmbiguousColumn(first.name.clone()));
242 }
243
244 Ok(())
245}
246
247fn get_source_name(expr: &Expression) -> Option<String> {
249 match expr {
250 Expression::Table(t) => Some(
251 t.alias
252 .as_ref()
253 .map(|a| a.name.clone())
254 .unwrap_or_else(|| t.name.name.clone()),
255 ),
256 Expression::Subquery(sq) => sq.alias.as_ref().map(|a| a.name.clone()),
257 _ => None,
258 }
259}
260
261fn get_ordered_source_names(select: &Select) -> Vec<String> {
264 let mut ordered = Vec::new();
265 if let Some(from) = &select.from {
266 for expr in &from.expressions {
267 if let Some(name) = get_source_name(expr) {
268 ordered.push(name);
269 }
270 }
271 }
272 for join in &select.joins {
273 if let Some(name) = get_source_name(&join.this) {
274 ordered.push(name);
275 }
276 }
277 ordered
278}
279
280fn make_coalesce(column_name: &str, tables: &[String]) -> Expression {
282 let args: Vec<Expression> = tables
283 .iter()
284 .map(|t| Expression::qualified_column(t.as_str(), column_name))
285 .collect();
286 Expression::Coalesce(Box::new(VarArgFunc {
287 expressions: args,
288 original_name: None,
289 inferred_type: None,
290 }))
291}
292
293fn expand_using(
299 select: &mut Select,
300 _scope: &Scope,
301 resolver: &mut Resolver,
302) -> QualifyColumnsResult<HashMap<String, Vec<String>>> {
303 let mut columns: HashMap<String, String> = HashMap::new();
305
306 let mut column_tables: HashMap<String, Vec<String>> = HashMap::new();
308
309 let join_names: HashSet<String> = select
311 .joins
312 .iter()
313 .filter_map(|j| get_source_name(&j.this))
314 .collect();
315
316 let all_ordered = get_ordered_source_names(select);
317 let mut ordered: Vec<String> = all_ordered
318 .iter()
319 .filter(|name| !join_names.contains(name.as_str()))
320 .cloned()
321 .collect();
322
323 if join_names.is_empty() {
324 return Ok(column_tables);
325 }
326
327 fn update_source_columns(
329 source_name: &str,
330 columns: &mut HashMap<String, String>,
331 resolver: &mut Resolver,
332 ) {
333 if let Ok(source_cols) = resolver.get_source_columns(source_name) {
334 for col_name in source_cols {
335 columns
336 .entry(col_name)
337 .or_insert_with(|| source_name.to_string());
338 }
339 }
340 }
341
342 for source_name in &ordered {
344 update_source_columns(source_name, &mut columns, resolver);
345 }
346
347 for i in 0..select.joins.len() {
348 let source_table = ordered.last().cloned().unwrap_or_default();
350 if !source_table.is_empty() {
351 update_source_columns(&source_table, &mut columns, resolver);
352 }
353
354 let join_table = get_source_name(&select.joins[i].this).unwrap_or_default();
356 ordered.push(join_table.clone());
357
358 if select.joins[i].using.is_empty() {
360 continue;
361 }
362
363 let _join_columns: Vec<String> =
364 resolver.get_source_columns(&join_table).unwrap_or_default();
365
366 let using_identifiers: Vec<String> = select.joins[i]
367 .using
368 .iter()
369 .map(|id| id.name.clone())
370 .collect();
371
372 let using_count = using_identifiers.len();
373 let is_semi_or_anti = matches!(
374 select.joins[i].kind,
375 crate::expressions::JoinKind::Semi
376 | crate::expressions::JoinKind::Anti
377 | crate::expressions::JoinKind::LeftSemi
378 | crate::expressions::JoinKind::LeftAnti
379 | crate::expressions::JoinKind::RightSemi
380 | crate::expressions::JoinKind::RightAnti
381 );
382
383 let mut conditions: Vec<Expression> = Vec::new();
384
385 for identifier in &using_identifiers {
386 let table = columns
387 .get(identifier)
388 .cloned()
389 .unwrap_or_else(|| source_table.clone());
390
391 let lhs = if i == 0 || using_count == 1 {
393 Expression::qualified_column(table.as_str(), identifier.as_str())
395 } else {
396 let coalesce_cols: Vec<String> = ordered[..ordered.len() - 1]
399 .iter()
400 .filter(|t| {
401 resolver
402 .get_source_columns(t)
403 .unwrap_or_default()
404 .contains(identifier)
405 })
406 .cloned()
407 .collect();
408
409 if coalesce_cols.len() > 1 {
410 make_coalesce(identifier, &coalesce_cols)
411 } else {
412 Expression::qualified_column(table.as_str(), identifier.as_str())
413 }
414 };
415
416 let rhs = Expression::qualified_column(join_table.as_str(), identifier.as_str());
418
419 conditions.push(Expression::Eq(Box::new(BinaryOp::new(lhs, rhs))));
420
421 if !is_semi_or_anti {
423 let tables = column_tables
424 .entry(identifier.clone())
425 .or_insert_with(Vec::new);
426 if !tables.contains(&table) {
427 tables.push(table.clone());
428 }
429 if !tables.contains(&join_table) {
430 tables.push(join_table.clone());
431 }
432 }
433 }
434
435 let on_condition = conditions
437 .into_iter()
438 .reduce(|acc, cond| Expression::And(Box::new(BinaryOp::new(acc, cond))))
439 .expect("at least one USING column");
440
441 select.joins[i].on = Some(on_condition);
443 select.joins[i].using = vec![];
444 }
445
446 if !column_tables.is_empty() {
448 let mut new_expressions = Vec::with_capacity(select.expressions.len());
450 for expr in &select.expressions {
451 match expr {
452 Expression::Column(col)
453 if col.table.is_none() && column_tables.contains_key(&col.name.name) =>
454 {
455 let tables = &column_tables[&col.name.name];
456 let coalesce = make_coalesce(&col.name.name, tables);
457 new_expressions.push(Expression::Alias(Box::new(Alias {
459 this: coalesce,
460 alias: Identifier::new(&col.name.name),
461 column_aliases: vec![],
462 alias_explicit_as: false,
463 alias_keyword: None,
464 pre_alias_comments: vec![],
465 trailing_comments: vec![],
466 inferred_type: None,
467 })));
468 }
469 _ => {
470 let mut rewritten = expr.clone();
471 rewrite_using_columns_in_expression(&mut rewritten, &column_tables);
472 new_expressions.push(rewritten);
473 }
474 }
475 }
476 select.expressions = new_expressions;
477
478 if let Some(where_clause) = &mut select.where_clause {
480 rewrite_using_columns_in_expression(&mut where_clause.this, &column_tables);
481 }
482
483 if let Some(group_by) = &mut select.group_by {
485 for expr in &mut group_by.expressions {
486 rewrite_using_columns_in_expression(expr, &column_tables);
487 }
488 }
489
490 if let Some(having) = &mut select.having {
492 rewrite_using_columns_in_expression(&mut having.this, &column_tables);
493 }
494
495 if let Some(qualify) = &mut select.qualify {
497 rewrite_using_columns_in_expression(&mut qualify.this, &column_tables);
498 }
499
500 if let Some(order_by) = &mut select.order_by {
502 for ordered in &mut order_by.expressions {
503 rewrite_using_columns_in_expression(&mut ordered.this, &column_tables);
504 }
505 }
506 }
507
508 Ok(column_tables)
509}
510
511fn rewrite_using_columns_in_expression(
513 expr: &mut Expression,
514 column_tables: &HashMap<String, Vec<String>>,
515) {
516 let transformed = transform_recursive(expr.clone(), &|node| match node {
517 Expression::Column(col)
518 if col.table.is_none() && column_tables.contains_key(&col.name.name) =>
519 {
520 let tables = &column_tables[&col.name.name];
521 Ok(make_coalesce(&col.name.name, tables))
522 }
523 other => Ok(other),
524 });
525
526 if let Ok(next) = transformed {
527 *expr = next;
528 }
529}
530
531fn qualify_columns_in_scope(
533 select: &mut Select,
534 scope: &Scope,
535 resolver: &mut Resolver,
536 allow_partial: bool,
537) -> QualifyColumnsResult<()> {
538 for expr in &mut select.expressions {
539 qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
540 }
541 if let Some(where_clause) = &mut select.where_clause {
542 qualify_columns_in_expression(&mut where_clause.this, scope, resolver, allow_partial)?;
543 }
544 if let Some(group_by) = &mut select.group_by {
545 for expr in &mut group_by.expressions {
546 qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
547 }
548 }
549 if let Some(having) = &mut select.having {
550 qualify_columns_in_expression(&mut having.this, scope, resolver, allow_partial)?;
551 }
552 if let Some(qualify) = &mut select.qualify {
553 qualify_columns_in_expression(&mut qualify.this, scope, resolver, allow_partial)?;
554 }
555 if let Some(order_by) = &mut select.order_by {
556 for ordered in &mut order_by.expressions {
557 qualify_columns_in_expression(&mut ordered.this, scope, resolver, allow_partial)?;
558 }
559 }
560 for join in &mut select.joins {
561 qualify_columns_in_expression(&mut join.this, scope, resolver, allow_partial)?;
562 if let Some(on) = &mut join.on {
563 qualify_columns_in_expression(on, scope, resolver, allow_partial)?;
564 }
565 }
566 Ok(())
567}
568
569fn expand_alias_refs(
576 select: &mut Select,
577 _resolver: &mut Resolver,
578 _dialect: Option<DialectType>,
579) -> QualifyColumnsResult<()> {
580 let mut alias_to_expression: HashMap<String, (Expression, usize)> = HashMap::new();
581
582 for (i, expr) in select.expressions.iter_mut().enumerate() {
583 replace_alias_refs_in_expression(expr, &alias_to_expression, false);
584 if let Expression::Alias(alias) = expr {
585 alias_to_expression.insert(alias.alias.name.clone(), (alias.this.clone(), i + 1));
586 }
587 }
588
589 if let Some(where_clause) = &mut select.where_clause {
590 replace_alias_refs_in_expression(&mut where_clause.this, &alias_to_expression, false);
591 }
592 if let Some(group_by) = &mut select.group_by {
593 for expr in &mut group_by.expressions {
594 replace_alias_refs_in_expression(expr, &alias_to_expression, true);
595 }
596 }
597 if let Some(having) = &mut select.having {
598 replace_alias_refs_in_expression(&mut having.this, &alias_to_expression, false);
599 }
600 if let Some(qualify) = &mut select.qualify {
601 replace_alias_refs_in_expression(&mut qualify.this, &alias_to_expression, false);
602 }
603 if let Some(order_by) = &mut select.order_by {
604 for ordered in &mut order_by.expressions {
605 replace_alias_refs_in_expression(&mut ordered.this, &alias_to_expression, false);
606 }
607 }
608
609 Ok(())
610}
611
612fn expand_group_by(select: &mut Select, _dialect: Option<DialectType>) -> QualifyColumnsResult<()> {
619 let projections = select.expressions.clone();
620
621 if let Some(group_by) = &mut select.group_by {
622 for group_expr in &mut group_by.expressions {
623 if let Some(index) = positional_reference(group_expr) {
624 let replacement = select_expression_at_position(&projections, index)?;
625 *group_expr = replacement;
626 }
627 }
628 }
629 Ok(())
630}
631
632fn expand_stars(
642 select: &mut Select,
643 _scope: &Scope,
644 resolver: &mut Resolver,
645 column_tables: &HashMap<String, Vec<String>>,
646) -> QualifyColumnsResult<()> {
647 let mut new_selections: Vec<Expression> = Vec::new();
648 let mut has_star = false;
649 let mut coalesced_columns: HashSet<String> = HashSet::new();
650
651 let ordered_sources = get_ordered_source_names(select);
653
654 for expr in &select.expressions {
655 match expr {
656 Expression::Star(star) => {
657 has_star = true;
658 if let Some(table) = &star.table {
659 let table_name = &table.name;
660 if !ordered_sources.contains(table_name) {
661 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
662 }
663 if let Ok(columns) = resolver.get_source_columns(table_name) {
664 if columns.contains(&"*".to_string()) || columns.is_empty() {
665 return Ok(());
666 }
667 for col_name in &columns {
668 if coalesced_columns.contains(col_name) {
669 continue;
670 }
671 if let Some(tables) = column_tables.get(col_name) {
672 if tables.contains(table_name) {
673 coalesced_columns.insert(col_name.clone());
674 let coalesce = make_coalesce(col_name, tables);
675 new_selections.push(Expression::Alias(Box::new(Alias {
676 this: coalesce,
677 alias: Identifier::new(col_name),
678 column_aliases: vec![],
679 alias_explicit_as: false,
680 alias_keyword: None,
681 pre_alias_comments: vec![],
682 trailing_comments: vec![],
683 inferred_type: None,
684 })));
685 continue;
686 }
687 }
688 new_selections
689 .push(create_qualified_column(col_name, Some(table_name)));
690 }
691 }
692 } else {
693 for source_name in &ordered_sources {
694 if let Ok(columns) = resolver.get_source_columns(source_name) {
695 if columns.contains(&"*".to_string()) || columns.is_empty() {
696 return Ok(());
697 }
698 for col_name in &columns {
699 if coalesced_columns.contains(col_name) {
700 continue;
702 }
703 if let Some(tables) = column_tables.get(col_name) {
704 if tables.contains(source_name) {
705 coalesced_columns.insert(col_name.clone());
707 let coalesce = make_coalesce(col_name, tables);
708 new_selections.push(Expression::Alias(Box::new(Alias {
709 this: coalesce,
710 alias: Identifier::new(col_name),
711 column_aliases: vec![],
712 alias_explicit_as: false,
713 alias_keyword: None,
714 pre_alias_comments: vec![],
715 trailing_comments: vec![],
716 inferred_type: None,
717 })));
718 continue;
719 }
720 }
721 new_selections
722 .push(create_qualified_column(col_name, Some(source_name)));
723 }
724 }
725 }
726 }
727 }
728 Expression::Column(col) if is_star_column(col) => {
729 has_star = true;
730 if let Some(table) = &col.table {
731 let table_name = &table.name;
732 if !ordered_sources.contains(table_name) {
733 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
734 }
735 if let Ok(columns) = resolver.get_source_columns(table_name) {
736 if columns.contains(&"*".to_string()) || columns.is_empty() {
737 return Ok(());
738 }
739 for col_name in &columns {
740 if coalesced_columns.contains(col_name) {
741 continue;
742 }
743 if let Some(tables) = column_tables.get(col_name) {
744 if tables.contains(table_name) {
745 coalesced_columns.insert(col_name.clone());
746 let coalesce = make_coalesce(col_name, tables);
747 new_selections.push(Expression::Alias(Box::new(Alias {
748 this: coalesce,
749 alias: Identifier::new(col_name),
750 column_aliases: vec![],
751 alias_explicit_as: false,
752 alias_keyword: None,
753 pre_alias_comments: vec![],
754 trailing_comments: vec![],
755 inferred_type: None,
756 })));
757 continue;
758 }
759 }
760 new_selections
761 .push(create_qualified_column(col_name, Some(table_name)));
762 }
763 }
764 }
765 }
766 _ => new_selections.push(expr.clone()),
767 }
768 }
769
770 if has_star {
771 select.expressions = new_selections;
772 }
773
774 Ok(())
775}
776
777pub fn qualify_outputs(scope: &Scope) -> QualifyColumnsResult<()> {
784 if let Expression::Select(mut select) = scope.expression.clone() {
785 qualify_outputs_select(&mut select)?;
786 }
787 Ok(())
788}
789
790fn qualify_outputs_select(select: &mut Select) -> QualifyColumnsResult<()> {
791 let mut new_selections: Vec<Expression> = Vec::new();
792
793 for (i, expr) in select.expressions.iter().enumerate() {
794 match expr {
795 Expression::Alias(_) => new_selections.push(expr.clone()),
796 Expression::Column(col) => {
797 new_selections.push(create_alias(expr.clone(), &col.name.name));
798 }
799 Expression::Star(_) => new_selections.push(expr.clone()),
800 _ => {
801 let alias_name = get_output_name(expr).unwrap_or_else(|| format!("_col_{}", i));
802 new_selections.push(create_alias(expr.clone(), &alias_name));
803 }
804 }
805 }
806
807 select.expressions = new_selections;
808 Ok(())
809}
810
811fn qualify_columns_in_expression(
812 expr: &mut Expression,
813 scope: &Scope,
814 resolver: &mut Resolver,
815 allow_partial: bool,
816) -> QualifyColumnsResult<()> {
817 let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
818 let resolver_cell: RefCell<&mut Resolver> = RefCell::new(resolver);
819
820 let transformed = transform_recursive(expr.clone(), &|node| {
821 if first_error.borrow().is_some() {
822 return Ok(node);
823 }
824
825 match node {
826 Expression::Column(mut col) => {
827 if let Err(err) = qualify_single_column(
828 &mut col,
829 scope,
830 &mut resolver_cell.borrow_mut(),
831 allow_partial,
832 ) {
833 *first_error.borrow_mut() = Some(err);
834 }
835 Ok(Expression::Column(col))
836 }
837 _ => Ok(node),
838 }
839 })
840 .map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
841
842 if let Some(err) = first_error.into_inner() {
843 return Err(err);
844 }
845
846 *expr = transformed;
847 Ok(())
848}
849
850fn qualify_single_column(
851 col: &mut Column,
852 scope: &Scope,
853 resolver: &mut Resolver,
854 allow_partial: bool,
855) -> QualifyColumnsResult<()> {
856 if is_star_column(col) {
857 return Ok(());
858 }
859
860 if let Some(table) = &col.table {
861 let table_name = &table.name;
862 if !scope.sources.contains_key(table_name) {
863 if resolver.table_exists_in_schema(table_name) {
867 return Ok(());
868 }
869 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
870 }
871
872 if let Ok(source_columns) = resolver.get_source_columns(table_name) {
873 let normalized_column_name = normalize_column_name(&col.name.name, resolver.dialect);
874 if !allow_partial
875 && !source_columns.is_empty()
876 && !source_columns.iter().any(|column| {
877 normalize_column_name(column, resolver.dialect) == normalized_column_name
878 })
879 && !source_columns.contains(&"*".to_string())
880 {
881 return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
882 }
883 }
884 return Ok(());
885 }
886
887 if let Some(table_name) = resolver.get_table(&col.name.name) {
888 col.table = Some(Identifier::new(table_name));
889 return Ok(());
890 }
891
892 if let Some(outer_table) = resolver.find_column_in_outer_schema_tables(&col.name.name) {
895 col.table = Some(Identifier::new(outer_table));
896 return Ok(());
897 }
898
899 if !allow_partial {
900 return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
901 }
902
903 Ok(())
904}
905
906fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
907 normalize_name(name, dialect, false, true)
908}
909
910fn replace_alias_refs_in_expression(
911 expr: &mut Expression,
912 alias_to_expression: &HashMap<String, (Expression, usize)>,
913 literal_index: bool,
914) {
915 let transformed = transform_recursive(expr.clone(), &|node| match node {
916 Expression::Column(col) if col.table.is_none() => {
917 if let Some((alias_expr, index)) = alias_to_expression.get(&col.name.name) {
918 if literal_index && matches!(alias_expr, Expression::Literal(_)) {
919 return Ok(Expression::number(*index as i64));
920 }
921 return Ok(Expression::Paren(Box::new(Paren {
922 this: alias_expr.clone(),
923 trailing_comments: vec![],
924 })));
925 }
926 Ok(Expression::Column(col))
927 }
928 other => Ok(other),
929 });
930
931 if let Ok(next) = transformed {
932 *expr = next;
933 }
934}
935
936fn positional_reference(expr: &Expression) -> Option<usize> {
937 match expr {
938 Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
939 let Literal::Number(value) = lit.as_ref() else {
940 unreachable!()
941 };
942 value.parse::<usize>().ok()
943 }
944 _ => None,
945 }
946}
947
948fn select_expression_at_position(
949 projections: &[Expression],
950 index: usize,
951) -> QualifyColumnsResult<Expression> {
952 if index == 0 || index > projections.len() {
953 return Err(QualifyColumnsError::UnknownOutputColumn(index.to_string()));
954 }
955
956 let projection = projections[index - 1].clone();
957 Ok(match projection {
958 Expression::Alias(alias) => alias.this.clone(),
959 other => other,
960 })
961}
962
963fn get_reserved_words(dialect: Option<DialectType>) -> HashSet<&'static str> {
966 let mut words: HashSet<&'static str> = [
968 "ADD",
970 "ALL",
971 "ALTER",
972 "AND",
973 "ANY",
974 "AS",
975 "ASC",
976 "BETWEEN",
977 "BY",
978 "CASE",
979 "CAST",
980 "CHECK",
981 "COLUMN",
982 "CONSTRAINT",
983 "CREATE",
984 "CROSS",
985 "CURRENT",
986 "CURRENT_DATE",
987 "CURRENT_TIME",
988 "CURRENT_TIMESTAMP",
989 "CURRENT_USER",
990 "DATABASE",
991 "DEFAULT",
992 "DELETE",
993 "DESC",
994 "DISTINCT",
995 "DROP",
996 "ELSE",
997 "END",
998 "ESCAPE",
999 "EXCEPT",
1000 "EXISTS",
1001 "FALSE",
1002 "FETCH",
1003 "FOR",
1004 "FOREIGN",
1005 "FROM",
1006 "FULL",
1007 "GRANT",
1008 "GROUP",
1009 "HAVING",
1010 "IF",
1011 "IN",
1012 "INDEX",
1013 "INNER",
1014 "INSERT",
1015 "INTERSECT",
1016 "INTO",
1017 "IS",
1018 "JOIN",
1019 "KEY",
1020 "LEFT",
1021 "LIKE",
1022 "LIMIT",
1023 "NATURAL",
1024 "NOT",
1025 "NULL",
1026 "OFFSET",
1027 "ON",
1028 "OR",
1029 "ORDER",
1030 "OUTER",
1031 "PRIMARY",
1032 "REFERENCES",
1033 "REPLACE",
1034 "RETURNING",
1035 "RIGHT",
1036 "ROLLBACK",
1037 "ROW",
1038 "ROWS",
1039 "SELECT",
1040 "SESSION_USER",
1041 "SET",
1042 "SOME",
1043 "TABLE",
1044 "THEN",
1045 "TO",
1046 "TRUE",
1047 "TRUNCATE",
1048 "UNION",
1049 "UNIQUE",
1050 "UPDATE",
1051 "USING",
1052 "VALUES",
1053 "VIEW",
1054 "WHEN",
1055 "WHERE",
1056 "WINDOW",
1057 "WITH",
1058 ]
1059 .iter()
1060 .copied()
1061 .collect();
1062
1063 match dialect {
1065 Some(DialectType::MySQL) => {
1066 words.extend(
1067 [
1068 "ANALYZE",
1069 "BOTH",
1070 "CHANGE",
1071 "CONDITION",
1072 "DATABASES",
1073 "DAY_HOUR",
1074 "DAY_MICROSECOND",
1075 "DAY_MINUTE",
1076 "DAY_SECOND",
1077 "DELAYED",
1078 "DETERMINISTIC",
1079 "DIV",
1080 "DUAL",
1081 "EACH",
1082 "ELSEIF",
1083 "ENCLOSED",
1084 "EXPLAIN",
1085 "FLOAT4",
1086 "FLOAT8",
1087 "FORCE",
1088 "HOUR_MICROSECOND",
1089 "HOUR_MINUTE",
1090 "HOUR_SECOND",
1091 "IGNORE",
1092 "INFILE",
1093 "INT1",
1094 "INT2",
1095 "INT3",
1096 "INT4",
1097 "INT8",
1098 "ITERATE",
1099 "KEYS",
1100 "KILL",
1101 "LEADING",
1102 "LEAVE",
1103 "LINES",
1104 "LOAD",
1105 "LOCK",
1106 "LONG",
1107 "LONGBLOB",
1108 "LONGTEXT",
1109 "LOOP",
1110 "LOW_PRIORITY",
1111 "MATCH",
1112 "MEDIUMBLOB",
1113 "MEDIUMINT",
1114 "MEDIUMTEXT",
1115 "MINUTE_MICROSECOND",
1116 "MINUTE_SECOND",
1117 "MOD",
1118 "MODIFIES",
1119 "NO_WRITE_TO_BINLOG",
1120 "OPTIMIZE",
1121 "OPTIONALLY",
1122 "OUT",
1123 "OUTFILE",
1124 "PURGE",
1125 "READS",
1126 "REGEXP",
1127 "RELEASE",
1128 "RENAME",
1129 "REPEAT",
1130 "REQUIRE",
1131 "RESIGNAL",
1132 "RETURN",
1133 "REVOKE",
1134 "RLIKE",
1135 "SCHEMA",
1136 "SCHEMAS",
1137 "SECOND_MICROSECOND",
1138 "SENSITIVE",
1139 "SEPARATOR",
1140 "SHOW",
1141 "SIGNAL",
1142 "SPATIAL",
1143 "SQL",
1144 "SQLEXCEPTION",
1145 "SQLSTATE",
1146 "SQLWARNING",
1147 "SQL_BIG_RESULT",
1148 "SQL_CALC_FOUND_ROWS",
1149 "SQL_SMALL_RESULT",
1150 "SSL",
1151 "STARTING",
1152 "STRAIGHT_JOIN",
1153 "TERMINATED",
1154 "TINYBLOB",
1155 "TINYINT",
1156 "TINYTEXT",
1157 "TRAILING",
1158 "TRIGGER",
1159 "UNDO",
1160 "UNLOCK",
1161 "UNSIGNED",
1162 "USAGE",
1163 "UTC_DATE",
1164 "UTC_TIME",
1165 "UTC_TIMESTAMP",
1166 "VARBINARY",
1167 "VARCHARACTER",
1168 "WHILE",
1169 "WRITE",
1170 "XOR",
1171 "YEAR_MONTH",
1172 "ZEROFILL",
1173 ]
1174 .iter()
1175 .copied(),
1176 );
1177 }
1178 Some(DialectType::PostgreSQL) | Some(DialectType::CockroachDB) => {
1179 words.extend(
1180 [
1181 "ANALYSE",
1182 "ANALYZE",
1183 "ARRAY",
1184 "AUTHORIZATION",
1185 "BINARY",
1186 "BOTH",
1187 "COLLATE",
1188 "CONCURRENTLY",
1189 "DO",
1190 "FREEZE",
1191 "ILIKE",
1192 "INITIALLY",
1193 "ISNULL",
1194 "LATERAL",
1195 "LEADING",
1196 "LOCALTIME",
1197 "LOCALTIMESTAMP",
1198 "NOTNULL",
1199 "ONLY",
1200 "OVERLAPS",
1201 "PLACING",
1202 "SIMILAR",
1203 "SYMMETRIC",
1204 "TABLESAMPLE",
1205 "TRAILING",
1206 "VARIADIC",
1207 "VERBOSE",
1208 ]
1209 .iter()
1210 .copied(),
1211 );
1212 }
1213 Some(DialectType::BigQuery) => {
1214 words.extend(
1215 [
1216 "ASSERT_ROWS_MODIFIED",
1217 "COLLATE",
1218 "CONTAINS",
1219 "CUBE",
1220 "DEFINE",
1221 "ENUM",
1222 "EXTRACT",
1223 "FOLLOWING",
1224 "GROUPING",
1225 "GROUPS",
1226 "HASH",
1227 "IGNORE",
1228 "LATERAL",
1229 "LOOKUP",
1230 "MERGE",
1231 "NEW",
1232 "NO",
1233 "NULLS",
1234 "OF",
1235 "OVER",
1236 "PARTITION",
1237 "PRECEDING",
1238 "PROTO",
1239 "RANGE",
1240 "RECURSIVE",
1241 "RESPECT",
1242 "ROLLUP",
1243 "STRUCT",
1244 "TABLESAMPLE",
1245 "TREAT",
1246 "UNBOUNDED",
1247 "WITHIN",
1248 ]
1249 .iter()
1250 .copied(),
1251 );
1252 }
1253 Some(DialectType::Snowflake) => {
1254 words.extend(
1255 [
1256 "ACCOUNT",
1257 "BOTH",
1258 "CONNECT",
1259 "FOLLOWING",
1260 "ILIKE",
1261 "INCREMENT",
1262 "ISSUE",
1263 "LATERAL",
1264 "LEADING",
1265 "LOCALTIME",
1266 "LOCALTIMESTAMP",
1267 "MINUS",
1268 "QUALIFY",
1269 "REGEXP",
1270 "RLIKE",
1271 "SOME",
1272 "START",
1273 "TABLESAMPLE",
1274 "TOP",
1275 "TRAILING",
1276 "TRY_CAST",
1277 ]
1278 .iter()
1279 .copied(),
1280 );
1281 }
1282 Some(DialectType::TSQL) | Some(DialectType::Fabric) => {
1283 words.extend(
1284 [
1285 "BACKUP",
1286 "BREAK",
1287 "BROWSE",
1288 "BULK",
1289 "CASCADE",
1290 "CHECKPOINT",
1291 "CLOSE",
1292 "CLUSTERED",
1293 "COALESCE",
1294 "COMPUTE",
1295 "CONTAINS",
1296 "CONTAINSTABLE",
1297 "CONTINUE",
1298 "CONVERT",
1299 "DBCC",
1300 "DEALLOCATE",
1301 "DENY",
1302 "DISK",
1303 "DISTRIBUTED",
1304 "DUMP",
1305 "ERRLVL",
1306 "EXEC",
1307 "EXECUTE",
1308 "EXIT",
1309 "EXTERNAL",
1310 "FILE",
1311 "FILLFACTOR",
1312 "FREETEXT",
1313 "FREETEXTTABLE",
1314 "FUNCTION",
1315 "GOTO",
1316 "HOLDLOCK",
1317 "IDENTITY",
1318 "IDENTITYCOL",
1319 "IDENTITY_INSERT",
1320 "KILL",
1321 "LINENO",
1322 "MERGE",
1323 "NONCLUSTERED",
1324 "NULLIF",
1325 "OF",
1326 "OFF",
1327 "OFFSETS",
1328 "OPEN",
1329 "OPENDATASOURCE",
1330 "OPENQUERY",
1331 "OPENROWSET",
1332 "OPENXML",
1333 "OVER",
1334 "PERCENT",
1335 "PIVOT",
1336 "PLAN",
1337 "PRINT",
1338 "PROC",
1339 "PROCEDURE",
1340 "PUBLIC",
1341 "RAISERROR",
1342 "READ",
1343 "READTEXT",
1344 "RECONFIGURE",
1345 "REPLICATION",
1346 "RESTORE",
1347 "RESTRICT",
1348 "REVERT",
1349 "ROWCOUNT",
1350 "ROWGUIDCOL",
1351 "RULE",
1352 "SAVE",
1353 "SECURITYAUDIT",
1354 "SEMANTICKEYPHRASETABLE",
1355 "SEMANTICSIMILARITYDETAILSTABLE",
1356 "SEMANTICSIMILARITYTABLE",
1357 "SETUSER",
1358 "SHUTDOWN",
1359 "STATISTICS",
1360 "SYSTEM_USER",
1361 "TEXTSIZE",
1362 "TOP",
1363 "TRAN",
1364 "TRANSACTION",
1365 "TRIGGER",
1366 "TSEQUAL",
1367 "UNPIVOT",
1368 "UPDATETEXT",
1369 "WAITFOR",
1370 "WRITETEXT",
1371 ]
1372 .iter()
1373 .copied(),
1374 );
1375 }
1376 Some(DialectType::ClickHouse) => {
1377 words.extend(
1378 [
1379 "ANTI",
1380 "ARRAY",
1381 "ASOF",
1382 "FINAL",
1383 "FORMAT",
1384 "GLOBAL",
1385 "INF",
1386 "KILL",
1387 "MATERIALIZED",
1388 "NAN",
1389 "PREWHERE",
1390 "SAMPLE",
1391 "SEMI",
1392 "SETTINGS",
1393 "TOP",
1394 ]
1395 .iter()
1396 .copied(),
1397 );
1398 }
1399 Some(DialectType::DuckDB) => {
1400 words.extend(
1401 [
1402 "ANALYSE",
1403 "ANALYZE",
1404 "ARRAY",
1405 "BOTH",
1406 "LATERAL",
1407 "LEADING",
1408 "LOCALTIME",
1409 "LOCALTIMESTAMP",
1410 "PLACING",
1411 "QUALIFY",
1412 "SIMILAR",
1413 "TABLESAMPLE",
1414 "TRAILING",
1415 ]
1416 .iter()
1417 .copied(),
1418 );
1419 }
1420 Some(DialectType::Hive) | Some(DialectType::Spark) | Some(DialectType::Databricks) => {
1421 words.extend(
1422 [
1423 "BOTH",
1424 "CLUSTER",
1425 "DISTRIBUTE",
1426 "EXCHANGE",
1427 "EXTENDED",
1428 "FUNCTION",
1429 "LATERAL",
1430 "LEADING",
1431 "MACRO",
1432 "OVER",
1433 "PARTITION",
1434 "PERCENT",
1435 "RANGE",
1436 "READS",
1437 "REDUCE",
1438 "REGEXP",
1439 "REVOKE",
1440 "RLIKE",
1441 "ROLLUP",
1442 "SEMI",
1443 "SORT",
1444 "TABLESAMPLE",
1445 "TRAILING",
1446 "TRANSFORM",
1447 "UNBOUNDED",
1448 "UNIQUEJOIN",
1449 ]
1450 .iter()
1451 .copied(),
1452 );
1453 }
1454 Some(DialectType::Trino) | Some(DialectType::Presto) | Some(DialectType::Athena) => {
1455 words.extend(
1456 [
1457 "CUBE",
1458 "DEALLOCATE",
1459 "DESCRIBE",
1460 "EXECUTE",
1461 "EXTRACT",
1462 "GROUPING",
1463 "LATERAL",
1464 "LOCALTIME",
1465 "LOCALTIMESTAMP",
1466 "NORMALIZE",
1467 "PREPARE",
1468 "ROLLUP",
1469 "SOME",
1470 "TABLESAMPLE",
1471 "UESCAPE",
1472 "UNNEST",
1473 ]
1474 .iter()
1475 .copied(),
1476 );
1477 }
1478 Some(DialectType::Oracle) => {
1479 words.extend(
1480 [
1481 "ACCESS",
1482 "AUDIT",
1483 "CLUSTER",
1484 "COMMENT",
1485 "COMPRESS",
1486 "CONNECT",
1487 "EXCLUSIVE",
1488 "FILE",
1489 "IDENTIFIED",
1490 "IMMEDIATE",
1491 "INCREMENT",
1492 "INITIAL",
1493 "LEVEL",
1494 "LOCK",
1495 "LONG",
1496 "MAXEXTENTS",
1497 "MINUS",
1498 "MODE",
1499 "NOAUDIT",
1500 "NOCOMPRESS",
1501 "NOWAIT",
1502 "NUMBER",
1503 "OF",
1504 "OFFLINE",
1505 "ONLINE",
1506 "PCTFREE",
1507 "PRIOR",
1508 "RAW",
1509 "RENAME",
1510 "RESOURCE",
1511 "REVOKE",
1512 "SHARE",
1513 "SIZE",
1514 "START",
1515 "SUCCESSFUL",
1516 "SYNONYM",
1517 "SYSDATE",
1518 "TRIGGER",
1519 "UID",
1520 "VALIDATE",
1521 "VARCHAR2",
1522 "WHENEVER",
1523 ]
1524 .iter()
1525 .copied(),
1526 );
1527 }
1528 Some(DialectType::Redshift) => {
1529 words.extend(
1530 [
1531 "AZ64",
1532 "BZIP2",
1533 "DELTA",
1534 "DELTA32K",
1535 "DISTSTYLE",
1536 "ENCODE",
1537 "GZIP",
1538 "ILIKE",
1539 "LIMIT",
1540 "LUNS",
1541 "LZO",
1542 "LZOP",
1543 "MOSTLY13",
1544 "MOSTLY32",
1545 "MOSTLY8",
1546 "RAW",
1547 "SIMILAR",
1548 "SNAPSHOT",
1549 "SORTKEY",
1550 "SYSDATE",
1551 "TOP",
1552 "ZSTD",
1553 ]
1554 .iter()
1555 .copied(),
1556 );
1557 }
1558 _ => {
1559 words.extend(
1561 [
1562 "ANALYZE",
1563 "ARRAY",
1564 "BOTH",
1565 "CUBE",
1566 "GROUPING",
1567 "LATERAL",
1568 "LEADING",
1569 "LOCALTIME",
1570 "LOCALTIMESTAMP",
1571 "OVER",
1572 "PARTITION",
1573 "QUALIFY",
1574 "RANGE",
1575 "ROLLUP",
1576 "SIMILAR",
1577 "SOME",
1578 "TABLESAMPLE",
1579 "TRAILING",
1580 ]
1581 .iter()
1582 .copied(),
1583 );
1584 }
1585 }
1586
1587 words
1588}
1589
/// Decide whether `name` must be wrapped in quotes to be emitted as an identifier.
///
/// Returns `false` for the empty string. Otherwise returns `true` when the
/// name starts with an ASCII digit, contains any byte that is not an ASCII
/// letter, digit or `_` (this includes every non-ASCII character), or when its
/// upper-cased form is in `reserved_words`.
fn needs_quoting(name: &str, reserved_words: &HashSet<&str>) -> bool {
    let bytes = name.as_bytes();
    let Some(first) = bytes.first() else {
        return false;
    };

    let leads_with_digit = first.is_ascii_digit();
    let has_special_byte = bytes
        .iter()
        .any(|b| !(b.is_ascii_alphanumeric() || *b == b'_'));

    // Reserved words are stored upper-case, so compare case-insensitively.
    leads_with_digit || has_special_byte || reserved_words.contains(name.to_uppercase().as_str())
}
1616
1617fn maybe_quote(id: &mut Identifier, reserved_words: &HashSet<&str>) {
1619 if id.quoted || id.name.is_empty() || id.name == "*" {
1622 return;
1623 }
1624 if needs_quoting(&id.name, reserved_words) {
1625 id.quoted = true;
1626 }
1627}
1628
1629fn quote_identifiers_recursive(expr: &mut Expression, reserved_words: &HashSet<&str>) {
1631 match expr {
1632 Expression::Identifier(id) => {
1634 maybe_quote(id, reserved_words);
1635 }
1636
1637 Expression::Column(col) => {
1638 maybe_quote(&mut col.name, reserved_words);
1639 if let Some(ref mut table) = col.table {
1640 maybe_quote(table, reserved_words);
1641 }
1642 }
1643
1644 Expression::Table(table_ref) => {
1645 maybe_quote(&mut table_ref.name, reserved_words);
1646 if let Some(ref mut schema) = table_ref.schema {
1647 maybe_quote(schema, reserved_words);
1648 }
1649 if let Some(ref mut catalog) = table_ref.catalog {
1650 maybe_quote(catalog, reserved_words);
1651 }
1652 if let Some(ref mut alias) = table_ref.alias {
1653 maybe_quote(alias, reserved_words);
1654 }
1655 for ca in &mut table_ref.column_aliases {
1656 maybe_quote(ca, reserved_words);
1657 }
1658 for p in &mut table_ref.partitions {
1659 maybe_quote(p, reserved_words);
1660 }
1661 for h in &mut table_ref.hints {
1663 quote_identifiers_recursive(h, reserved_words);
1664 }
1665 if let Some(ref mut ver) = table_ref.version {
1666 quote_identifiers_recursive(&mut ver.this, reserved_words);
1667 if let Some(ref mut e) = ver.expression {
1668 quote_identifiers_recursive(e, reserved_words);
1669 }
1670 }
1671 }
1672
1673 Expression::Star(star) => {
1674 if let Some(ref mut table) = star.table {
1675 maybe_quote(table, reserved_words);
1676 }
1677 if let Some(ref mut except_ids) = star.except {
1678 for id in except_ids {
1679 maybe_quote(id, reserved_words);
1680 }
1681 }
1682 if let Some(ref mut replace_aliases) = star.replace {
1683 for alias in replace_aliases {
1684 maybe_quote(&mut alias.alias, reserved_words);
1685 quote_identifiers_recursive(&mut alias.this, reserved_words);
1686 }
1687 }
1688 if let Some(ref mut rename_pairs) = star.rename {
1689 for (from, to) in rename_pairs {
1690 maybe_quote(from, reserved_words);
1691 maybe_quote(to, reserved_words);
1692 }
1693 }
1694 }
1695
1696 Expression::Alias(alias) => {
1698 maybe_quote(&mut alias.alias, reserved_words);
1699 for ca in &mut alias.column_aliases {
1700 maybe_quote(ca, reserved_words);
1701 }
1702 quote_identifiers_recursive(&mut alias.this, reserved_words);
1703 }
1704
1705 Expression::Select(select) => {
1707 for e in &mut select.expressions {
1708 quote_identifiers_recursive(e, reserved_words);
1709 }
1710 if let Some(ref mut from) = select.from {
1711 for e in &mut from.expressions {
1712 quote_identifiers_recursive(e, reserved_words);
1713 }
1714 }
1715 for join in &mut select.joins {
1716 quote_join(join, reserved_words);
1717 }
1718 for lv in &mut select.lateral_views {
1719 quote_lateral_view(lv, reserved_words);
1720 }
1721 if let Some(ref mut prewhere) = select.prewhere {
1722 quote_identifiers_recursive(prewhere, reserved_words);
1723 }
1724 if let Some(ref mut wh) = select.where_clause {
1725 quote_identifiers_recursive(&mut wh.this, reserved_words);
1726 }
1727 if let Some(ref mut gb) = select.group_by {
1728 for e in &mut gb.expressions {
1729 quote_identifiers_recursive(e, reserved_words);
1730 }
1731 }
1732 if let Some(ref mut hv) = select.having {
1733 quote_identifiers_recursive(&mut hv.this, reserved_words);
1734 }
1735 if let Some(ref mut q) = select.qualify {
1736 quote_identifiers_recursive(&mut q.this, reserved_words);
1737 }
1738 if let Some(ref mut ob) = select.order_by {
1739 for o in &mut ob.expressions {
1740 quote_identifiers_recursive(&mut o.this, reserved_words);
1741 }
1742 }
1743 if let Some(ref mut lim) = select.limit {
1744 quote_identifiers_recursive(&mut lim.this, reserved_words);
1745 }
1746 if let Some(ref mut off) = select.offset {
1747 quote_identifiers_recursive(&mut off.this, reserved_words);
1748 }
1749 if let Some(ref mut with) = select.with {
1750 quote_with(with, reserved_words);
1751 }
1752 if let Some(ref mut windows) = select.windows {
1753 for nw in windows {
1754 maybe_quote(&mut nw.name, reserved_words);
1755 quote_over(&mut nw.spec, reserved_words);
1756 }
1757 }
1758 if let Some(ref mut distinct_on) = select.distinct_on {
1759 for e in distinct_on {
1760 quote_identifiers_recursive(e, reserved_words);
1761 }
1762 }
1763 if let Some(ref mut limit_by) = select.limit_by {
1764 for e in limit_by {
1765 quote_identifiers_recursive(e, reserved_words);
1766 }
1767 }
1768 if let Some(ref mut settings) = select.settings {
1769 for e in settings {
1770 quote_identifiers_recursive(e, reserved_words);
1771 }
1772 }
1773 if let Some(ref mut format) = select.format {
1774 quote_identifiers_recursive(format, reserved_words);
1775 }
1776 }
1777
1778 Expression::Union(u) => {
1780 quote_identifiers_recursive(&mut u.left, reserved_words);
1781 quote_identifiers_recursive(&mut u.right, reserved_words);
1782 if let Some(ref mut with) = u.with {
1783 quote_with(with, reserved_words);
1784 }
1785 if let Some(ref mut ob) = u.order_by {
1786 for o in &mut ob.expressions {
1787 quote_identifiers_recursive(&mut o.this, reserved_words);
1788 }
1789 }
1790 if let Some(ref mut lim) = u.limit {
1791 quote_identifiers_recursive(lim, reserved_words);
1792 }
1793 if let Some(ref mut off) = u.offset {
1794 quote_identifiers_recursive(off, reserved_words);
1795 }
1796 }
1797 Expression::Intersect(i) => {
1798 quote_identifiers_recursive(&mut i.left, reserved_words);
1799 quote_identifiers_recursive(&mut i.right, reserved_words);
1800 if let Some(ref mut with) = i.with {
1801 quote_with(with, reserved_words);
1802 }
1803 if let Some(ref mut ob) = i.order_by {
1804 for o in &mut ob.expressions {
1805 quote_identifiers_recursive(&mut o.this, reserved_words);
1806 }
1807 }
1808 }
1809 Expression::Except(e) => {
1810 quote_identifiers_recursive(&mut e.left, reserved_words);
1811 quote_identifiers_recursive(&mut e.right, reserved_words);
1812 if let Some(ref mut with) = e.with {
1813 quote_with(with, reserved_words);
1814 }
1815 if let Some(ref mut ob) = e.order_by {
1816 for o in &mut ob.expressions {
1817 quote_identifiers_recursive(&mut o.this, reserved_words);
1818 }
1819 }
1820 }
1821
1822 Expression::Subquery(sq) => {
1824 quote_identifiers_recursive(&mut sq.this, reserved_words);
1825 if let Some(ref mut alias) = sq.alias {
1826 maybe_quote(alias, reserved_words);
1827 }
1828 for ca in &mut sq.column_aliases {
1829 maybe_quote(ca, reserved_words);
1830 }
1831 if let Some(ref mut ob) = sq.order_by {
1832 for o in &mut ob.expressions {
1833 quote_identifiers_recursive(&mut o.this, reserved_words);
1834 }
1835 }
1836 }
1837
1838 Expression::Insert(ins) => {
1840 quote_table_ref(&mut ins.table, reserved_words);
1841 for c in &mut ins.columns {
1842 maybe_quote(c, reserved_words);
1843 }
1844 for row in &mut ins.values {
1845 for e in row {
1846 quote_identifiers_recursive(e, reserved_words);
1847 }
1848 }
1849 if let Some(ref mut q) = ins.query {
1850 quote_identifiers_recursive(q, reserved_words);
1851 }
1852 for (id, val) in &mut ins.partition {
1853 maybe_quote(id, reserved_words);
1854 if let Some(ref mut v) = val {
1855 quote_identifiers_recursive(v, reserved_words);
1856 }
1857 }
1858 for e in &mut ins.returning {
1859 quote_identifiers_recursive(e, reserved_words);
1860 }
1861 if let Some(ref mut on_conflict) = ins.on_conflict {
1862 quote_identifiers_recursive(on_conflict, reserved_words);
1863 }
1864 if let Some(ref mut with) = ins.with {
1865 quote_with(with, reserved_words);
1866 }
1867 if let Some(ref mut alias) = ins.alias {
1868 maybe_quote(alias, reserved_words);
1869 }
1870 if let Some(ref mut src_alias) = ins.source_alias {
1871 maybe_quote(src_alias, reserved_words);
1872 }
1873 }
1874
1875 Expression::Update(upd) => {
1876 quote_table_ref(&mut upd.table, reserved_words);
1877 for tr in &mut upd.extra_tables {
1878 quote_table_ref(tr, reserved_words);
1879 }
1880 for join in &mut upd.table_joins {
1881 quote_join(join, reserved_words);
1882 }
1883 for (id, val) in &mut upd.set {
1884 maybe_quote(id, reserved_words);
1885 quote_identifiers_recursive(val, reserved_words);
1886 }
1887 if let Some(ref mut from) = upd.from_clause {
1888 for e in &mut from.expressions {
1889 quote_identifiers_recursive(e, reserved_words);
1890 }
1891 }
1892 for join in &mut upd.from_joins {
1893 quote_join(join, reserved_words);
1894 }
1895 if let Some(ref mut wh) = upd.where_clause {
1896 quote_identifiers_recursive(&mut wh.this, reserved_words);
1897 }
1898 for e in &mut upd.returning {
1899 quote_identifiers_recursive(e, reserved_words);
1900 }
1901 if let Some(ref mut with) = upd.with {
1902 quote_with(with, reserved_words);
1903 }
1904 }
1905
1906 Expression::Delete(del) => {
1907 quote_table_ref(&mut del.table, reserved_words);
1908 if let Some(ref mut alias) = del.alias {
1909 maybe_quote(alias, reserved_words);
1910 }
1911 for tr in &mut del.using {
1912 quote_table_ref(tr, reserved_words);
1913 }
1914 if let Some(ref mut wh) = del.where_clause {
1915 quote_identifiers_recursive(&mut wh.this, reserved_words);
1916 }
1917 if let Some(ref mut with) = del.with {
1918 quote_with(with, reserved_words);
1919 }
1920 }
1921
1922 Expression::And(bin)
1924 | Expression::Or(bin)
1925 | Expression::Eq(bin)
1926 | Expression::Neq(bin)
1927 | Expression::Lt(bin)
1928 | Expression::Lte(bin)
1929 | Expression::Gt(bin)
1930 | Expression::Gte(bin)
1931 | Expression::Add(bin)
1932 | Expression::Sub(bin)
1933 | Expression::Mul(bin)
1934 | Expression::Div(bin)
1935 | Expression::Mod(bin)
1936 | Expression::BitwiseAnd(bin)
1937 | Expression::BitwiseOr(bin)
1938 | Expression::BitwiseXor(bin)
1939 | Expression::Concat(bin)
1940 | Expression::Adjacent(bin)
1941 | Expression::TsMatch(bin)
1942 | Expression::PropertyEQ(bin)
1943 | Expression::ArrayContainsAll(bin)
1944 | Expression::ArrayContainedBy(bin)
1945 | Expression::ArrayOverlaps(bin)
1946 | Expression::JSONBContainsAllTopKeys(bin)
1947 | Expression::JSONBContainsAnyTopKeys(bin)
1948 | Expression::JSONBDeleteAtPath(bin)
1949 | Expression::ExtendsLeft(bin)
1950 | Expression::ExtendsRight(bin)
1951 | Expression::Is(bin)
1952 | Expression::NullSafeEq(bin)
1953 | Expression::NullSafeNeq(bin)
1954 | Expression::Glob(bin)
1955 | Expression::Match(bin)
1956 | Expression::MemberOf(bin)
1957 | Expression::BitwiseLeftShift(bin)
1958 | Expression::BitwiseRightShift(bin) => {
1959 quote_identifiers_recursive(&mut bin.left, reserved_words);
1960 quote_identifiers_recursive(&mut bin.right, reserved_words);
1961 }
1962
1963 Expression::Like(like) | Expression::ILike(like) => {
1965 quote_identifiers_recursive(&mut like.left, reserved_words);
1966 quote_identifiers_recursive(&mut like.right, reserved_words);
1967 if let Some(ref mut esc) = like.escape {
1968 quote_identifiers_recursive(esc, reserved_words);
1969 }
1970 }
1971
1972 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
1974 quote_identifiers_recursive(&mut un.this, reserved_words);
1975 }
1976
1977 Expression::In(in_expr) => {
1979 quote_identifiers_recursive(&mut in_expr.this, reserved_words);
1980 for e in &mut in_expr.expressions {
1981 quote_identifiers_recursive(e, reserved_words);
1982 }
1983 if let Some(ref mut q) = in_expr.query {
1984 quote_identifiers_recursive(q, reserved_words);
1985 }
1986 if let Some(ref mut un) = in_expr.unnest {
1987 quote_identifiers_recursive(un, reserved_words);
1988 }
1989 }
1990
1991 Expression::Between(bw) => {
1992 quote_identifiers_recursive(&mut bw.this, reserved_words);
1993 quote_identifiers_recursive(&mut bw.low, reserved_words);
1994 quote_identifiers_recursive(&mut bw.high, reserved_words);
1995 }
1996
1997 Expression::IsNull(is_null) => {
1998 quote_identifiers_recursive(&mut is_null.this, reserved_words);
1999 }
2000
2001 Expression::IsTrue(is_tf) | Expression::IsFalse(is_tf) => {
2002 quote_identifiers_recursive(&mut is_tf.this, reserved_words);
2003 }
2004
2005 Expression::Exists(ex) => {
2006 quote_identifiers_recursive(&mut ex.this, reserved_words);
2007 }
2008
2009 Expression::Function(func) => {
2011 for arg in &mut func.args {
2012 quote_identifiers_recursive(arg, reserved_words);
2013 }
2014 }
2015
2016 Expression::AggregateFunction(agg) => {
2017 for arg in &mut agg.args {
2018 quote_identifiers_recursive(arg, reserved_words);
2019 }
2020 if let Some(ref mut filter) = agg.filter {
2021 quote_identifiers_recursive(filter, reserved_words);
2022 }
2023 for o in &mut agg.order_by {
2024 quote_identifiers_recursive(&mut o.this, reserved_words);
2025 }
2026 }
2027
2028 Expression::WindowFunction(wf) => {
2029 quote_identifiers_recursive(&mut wf.this, reserved_words);
2030 quote_over(&mut wf.over, reserved_words);
2031 }
2032
2033 Expression::Case(case) => {
2035 if let Some(ref mut operand) = case.operand {
2036 quote_identifiers_recursive(operand, reserved_words);
2037 }
2038 for (when, then) in &mut case.whens {
2039 quote_identifiers_recursive(when, reserved_words);
2040 quote_identifiers_recursive(then, reserved_words);
2041 }
2042 if let Some(ref mut else_) = case.else_ {
2043 quote_identifiers_recursive(else_, reserved_words);
2044 }
2045 }
2046
2047 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
2049 quote_identifiers_recursive(&mut cast.this, reserved_words);
2050 if let Some(ref mut fmt) = cast.format {
2051 quote_identifiers_recursive(fmt, reserved_words);
2052 }
2053 }
2054
2055 Expression::Paren(paren) => {
2057 quote_identifiers_recursive(&mut paren.this, reserved_words);
2058 }
2059
2060 Expression::Annotated(ann) => {
2061 quote_identifiers_recursive(&mut ann.this, reserved_words);
2062 }
2063
2064 Expression::With(with) => {
2066 quote_with(with, reserved_words);
2067 }
2068
2069 Expression::Cte(cte) => {
2070 maybe_quote(&mut cte.alias, reserved_words);
2071 for c in &mut cte.columns {
2072 maybe_quote(c, reserved_words);
2073 }
2074 quote_identifiers_recursive(&mut cte.this, reserved_words);
2075 }
2076
2077 Expression::From(from) => {
2079 for e in &mut from.expressions {
2080 quote_identifiers_recursive(e, reserved_words);
2081 }
2082 }
2083
2084 Expression::Join(join) => {
2085 quote_join(join, reserved_words);
2086 }
2087
2088 Expression::JoinedTable(jt) => {
2089 quote_identifiers_recursive(&mut jt.left, reserved_words);
2090 for join in &mut jt.joins {
2091 quote_join(join, reserved_words);
2092 }
2093 if let Some(ref mut alias) = jt.alias {
2094 maybe_quote(alias, reserved_words);
2095 }
2096 }
2097
2098 Expression::Where(wh) => {
2099 quote_identifiers_recursive(&mut wh.this, reserved_words);
2100 }
2101
2102 Expression::GroupBy(gb) => {
2103 for e in &mut gb.expressions {
2104 quote_identifiers_recursive(e, reserved_words);
2105 }
2106 }
2107
2108 Expression::Having(hv) => {
2109 quote_identifiers_recursive(&mut hv.this, reserved_words);
2110 }
2111
2112 Expression::OrderBy(ob) => {
2113 for o in &mut ob.expressions {
2114 quote_identifiers_recursive(&mut o.this, reserved_words);
2115 }
2116 }
2117
2118 Expression::Ordered(ord) => {
2119 quote_identifiers_recursive(&mut ord.this, reserved_words);
2120 }
2121
2122 Expression::Limit(lim) => {
2123 quote_identifiers_recursive(&mut lim.this, reserved_words);
2124 }
2125
2126 Expression::Offset(off) => {
2127 quote_identifiers_recursive(&mut off.this, reserved_words);
2128 }
2129
2130 Expression::Qualify(q) => {
2131 quote_identifiers_recursive(&mut q.this, reserved_words);
2132 }
2133
2134 Expression::Window(ws) => {
2135 for e in &mut ws.partition_by {
2136 quote_identifiers_recursive(e, reserved_words);
2137 }
2138 for o in &mut ws.order_by {
2139 quote_identifiers_recursive(&mut o.this, reserved_words);
2140 }
2141 }
2142
2143 Expression::Over(over) => {
2144 quote_over(over, reserved_words);
2145 }
2146
2147 Expression::WithinGroup(wg) => {
2148 quote_identifiers_recursive(&mut wg.this, reserved_words);
2149 for o in &mut wg.order_by {
2150 quote_identifiers_recursive(&mut o.this, reserved_words);
2151 }
2152 }
2153
2154 Expression::Pivot(piv) => {
2156 quote_identifiers_recursive(&mut piv.this, reserved_words);
2157 for e in &mut piv.expressions {
2158 quote_identifiers_recursive(e, reserved_words);
2159 }
2160 for f in &mut piv.fields {
2161 quote_identifiers_recursive(f, reserved_words);
2162 }
2163 if let Some(ref mut alias) = piv.alias {
2164 maybe_quote(alias, reserved_words);
2165 }
2166 }
2167
2168 Expression::Unpivot(unpiv) => {
2169 quote_identifiers_recursive(&mut unpiv.this, reserved_words);
2170 maybe_quote(&mut unpiv.value_column, reserved_words);
2171 maybe_quote(&mut unpiv.name_column, reserved_words);
2172 for e in &mut unpiv.columns {
2173 quote_identifiers_recursive(e, reserved_words);
2174 }
2175 if let Some(ref mut alias) = unpiv.alias {
2176 maybe_quote(alias, reserved_words);
2177 }
2178 }
2179
2180 Expression::Values(vals) => {
2182 for tuple in &mut vals.expressions {
2183 for e in &mut tuple.expressions {
2184 quote_identifiers_recursive(e, reserved_words);
2185 }
2186 }
2187 if let Some(ref mut alias) = vals.alias {
2188 maybe_quote(alias, reserved_words);
2189 }
2190 for ca in &mut vals.column_aliases {
2191 maybe_quote(ca, reserved_words);
2192 }
2193 }
2194
2195 Expression::Array(arr) => {
2197 for e in &mut arr.expressions {
2198 quote_identifiers_recursive(e, reserved_words);
2199 }
2200 }
2201
2202 Expression::Struct(st) => {
2203 for (_name, e) in &mut st.fields {
2204 quote_identifiers_recursive(e, reserved_words);
2205 }
2206 }
2207
2208 Expression::Tuple(tup) => {
2209 for e in &mut tup.expressions {
2210 quote_identifiers_recursive(e, reserved_words);
2211 }
2212 }
2213
2214 Expression::Subscript(sub) => {
2216 quote_identifiers_recursive(&mut sub.this, reserved_words);
2217 quote_identifiers_recursive(&mut sub.index, reserved_words);
2218 }
2219
2220 Expression::Dot(dot) => {
2221 quote_identifiers_recursive(&mut dot.this, reserved_words);
2222 maybe_quote(&mut dot.field, reserved_words);
2223 }
2224
2225 Expression::ScopeResolution(sr) => {
2226 if let Some(ref mut this) = sr.this {
2227 quote_identifiers_recursive(this, reserved_words);
2228 }
2229 quote_identifiers_recursive(&mut sr.expression, reserved_words);
2230 }
2231
2232 Expression::Lateral(lat) => {
2234 quote_identifiers_recursive(&mut lat.this, reserved_words);
2235 }
2237
2238 Expression::DPipe(dpipe) => {
2240 quote_identifiers_recursive(&mut dpipe.this, reserved_words);
2241 quote_identifiers_recursive(&mut dpipe.expression, reserved_words);
2242 }
2243
2244 Expression::Merge(merge) => {
2246 quote_identifiers_recursive(&mut merge.this, reserved_words);
2247 quote_identifiers_recursive(&mut merge.using, reserved_words);
2248 if let Some(ref mut on) = merge.on {
2249 quote_identifiers_recursive(on, reserved_words);
2250 }
2251 if let Some(ref mut whens) = merge.whens {
2252 quote_identifiers_recursive(whens, reserved_words);
2253 }
2254 if let Some(ref mut with) = merge.with_ {
2255 quote_identifiers_recursive(with, reserved_words);
2256 }
2257 if let Some(ref mut ret) = merge.returning {
2258 quote_identifiers_recursive(ret, reserved_words);
2259 }
2260 }
2261
2262 Expression::LateralView(lv) => {
2264 quote_lateral_view(lv, reserved_words);
2265 }
2266
2267 Expression::Anonymous(anon) => {
2269 quote_identifiers_recursive(&mut anon.this, reserved_words);
2270 for e in &mut anon.expressions {
2271 quote_identifiers_recursive(e, reserved_words);
2272 }
2273 }
2274
2275 Expression::Filter(filter) => {
2277 quote_identifiers_recursive(&mut filter.this, reserved_words);
2278 quote_identifiers_recursive(&mut filter.expression, reserved_words);
2279 }
2280
2281 Expression::Returning(ret) => {
2283 for e in &mut ret.expressions {
2284 quote_identifiers_recursive(e, reserved_words);
2285 }
2286 }
2287
2288 Expression::BracedWildcard(inner) => {
2290 quote_identifiers_recursive(inner, reserved_words);
2291 }
2292
2293 Expression::ReturnStmt(inner) => {
2295 quote_identifiers_recursive(inner, reserved_words);
2296 }
2297
2298 Expression::Literal(_)
2300 | Expression::Boolean(_)
2301 | Expression::Null(_)
2302 | Expression::DataType(_)
2303 | Expression::Raw(_)
2304 | Expression::Placeholder(_)
2305 | Expression::CurrentDate(_)
2306 | Expression::CurrentTime(_)
2307 | Expression::CurrentTimestamp(_)
2308 | Expression::CurrentTimestampLTZ(_)
2309 | Expression::SessionUser(_)
2310 | Expression::RowNumber(_)
2311 | Expression::Rank(_)
2312 | Expression::DenseRank(_)
2313 | Expression::PercentRank(_)
2314 | Expression::CumeDist(_)
2315 | Expression::Random(_)
2316 | Expression::Pi(_)
2317 | Expression::JSONPathRoot(_) => {
2318 }
2320
2321 _ => {}
2325 }
2326}
2327
2328fn quote_join(join: &mut Join, reserved_words: &HashSet<&str>) {
2330 quote_identifiers_recursive(&mut join.this, reserved_words);
2331 if let Some(ref mut on) = join.on {
2332 quote_identifiers_recursive(on, reserved_words);
2333 }
2334 for id in &mut join.using {
2335 maybe_quote(id, reserved_words);
2336 }
2337 if let Some(ref mut mc) = join.match_condition {
2338 quote_identifiers_recursive(mc, reserved_words);
2339 }
2340 for piv in &mut join.pivots {
2341 quote_identifiers_recursive(piv, reserved_words);
2342 }
2343}
2344
2345fn quote_with(with: &mut With, reserved_words: &HashSet<&str>) {
2347 for cte in &mut with.ctes {
2348 maybe_quote(&mut cte.alias, reserved_words);
2349 for c in &mut cte.columns {
2350 maybe_quote(c, reserved_words);
2351 }
2352 for k in &mut cte.key_expressions {
2353 maybe_quote(k, reserved_words);
2354 }
2355 quote_identifiers_recursive(&mut cte.this, reserved_words);
2356 }
2357}
2358
2359fn quote_over(over: &mut Over, reserved_words: &HashSet<&str>) {
2361 if let Some(ref mut wn) = over.window_name {
2362 maybe_quote(wn, reserved_words);
2363 }
2364 for e in &mut over.partition_by {
2365 quote_identifiers_recursive(e, reserved_words);
2366 }
2367 for o in &mut over.order_by {
2368 quote_identifiers_recursive(&mut o.this, reserved_words);
2369 }
2370 if let Some(ref mut alias) = over.alias {
2371 maybe_quote(alias, reserved_words);
2372 }
2373}
2374
2375fn quote_table_ref(table_ref: &mut TableRef, reserved_words: &HashSet<&str>) {
2377 maybe_quote(&mut table_ref.name, reserved_words);
2378 if let Some(ref mut schema) = table_ref.schema {
2379 maybe_quote(schema, reserved_words);
2380 }
2381 if let Some(ref mut catalog) = table_ref.catalog {
2382 maybe_quote(catalog, reserved_words);
2383 }
2384 if let Some(ref mut alias) = table_ref.alias {
2385 maybe_quote(alias, reserved_words);
2386 }
2387 for ca in &mut table_ref.column_aliases {
2388 maybe_quote(ca, reserved_words);
2389 }
2390 for p in &mut table_ref.partitions {
2391 maybe_quote(p, reserved_words);
2392 }
2393 for h in &mut table_ref.hints {
2394 quote_identifiers_recursive(h, reserved_words);
2395 }
2396}
2397
2398fn quote_lateral_view(lv: &mut LateralView, reserved_words: &HashSet<&str>) {
2400 quote_identifiers_recursive(&mut lv.this, reserved_words);
2401 if let Some(ref mut ta) = lv.table_alias {
2402 maybe_quote(ta, reserved_words);
2403 }
2404 for ca in &mut lv.column_aliases {
2405 maybe_quote(ca, reserved_words);
2406 }
2407}
2408
2409pub fn quote_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
2420 let reserved_words = get_reserved_words(dialect);
2421 let mut result = expression;
2422 quote_identifiers_recursive(&mut result, &reserved_words);
2423 result
2424}
2425
2426pub fn pushdown_cte_alias_columns(_scope: &Scope) {
2431 }
2434
2435fn pushdown_cte_alias_columns_with(with: &mut With) {
2436 for cte in &mut with.ctes {
2437 if cte.columns.is_empty() {
2438 continue;
2439 }
2440
2441 if let Expression::Select(select) = &mut cte.this {
2442 let mut next_expressions = Vec::with_capacity(select.expressions.len());
2443
2444 for (i, projection) in select.expressions.iter().enumerate() {
2445 let Some(alias_name) = cte.columns.get(i) else {
2446 next_expressions.push(projection.clone());
2447 continue;
2448 };
2449
2450 match projection {
2451 Expression::Alias(existing) => {
2452 let mut aliased = existing.clone();
2453 aliased.alias = alias_name.clone();
2454 next_expressions.push(Expression::Alias(aliased));
2455 }
2456 _ => {
2457 next_expressions.push(create_alias(projection.clone(), &alias_name.name));
2458 }
2459 }
2460 }
2461
2462 select.expressions = next_expressions;
2463 }
2464 }
2465}
2466
2467fn get_scope_columns(scope: &Scope) -> Vec<ColumnRef> {
2473 let mut columns = Vec::new();
2474 collect_columns(&scope.expression, &mut columns);
2475 columns
2476}
2477
/// An owned, lightweight record of one column reference found while walking
/// an expression tree (see `collect_columns`).
#[derive(Debug, Clone)]
struct ColumnRef {
    /// Name of the table qualifier as written (the `t` in `t.a`), or `None`
    /// when the column is unqualified.
    table: Option<String>,
    /// The column's own name, without any qualifier.
    name: String,
}
2484
2485fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
2487 match expr {
2488 Expression::Column(col) => {
2489 columns.push(ColumnRef {
2490 table: col.table.as_ref().map(|t| t.name.clone()),
2491 name: col.name.name.clone(),
2492 });
2493 }
2494 Expression::Select(select) => {
2495 for e in &select.expressions {
2496 collect_columns(e, columns);
2497 }
2498 if let Some(from) = &select.from {
2499 for e in &from.expressions {
2500 collect_columns(e, columns);
2501 }
2502 }
2503 if let Some(where_clause) = &select.where_clause {
2504 collect_columns(&where_clause.this, columns);
2505 }
2506 if let Some(group_by) = &select.group_by {
2507 for e in &group_by.expressions {
2508 collect_columns(e, columns);
2509 }
2510 }
2511 if let Some(having) = &select.having {
2512 collect_columns(&having.this, columns);
2513 }
2514 if let Some(order_by) = &select.order_by {
2515 for o in &order_by.expressions {
2516 collect_columns(&o.this, columns);
2517 }
2518 }
2519 for join in &select.joins {
2520 collect_columns(&join.this, columns);
2521 if let Some(on) = &join.on {
2522 collect_columns(on, columns);
2523 }
2524 }
2525 }
2526 Expression::Alias(alias) => {
2527 collect_columns(&alias.this, columns);
2528 }
2529 Expression::Function(func) => {
2530 for arg in &func.args {
2531 collect_columns(arg, columns);
2532 }
2533 }
2534 Expression::AggregateFunction(agg) => {
2535 for arg in &agg.args {
2536 collect_columns(arg, columns);
2537 }
2538 }
2539 Expression::And(bin)
2540 | Expression::Or(bin)
2541 | Expression::Eq(bin)
2542 | Expression::Neq(bin)
2543 | Expression::Lt(bin)
2544 | Expression::Lte(bin)
2545 | Expression::Gt(bin)
2546 | Expression::Gte(bin)
2547 | Expression::Add(bin)
2548 | Expression::Sub(bin)
2549 | Expression::Mul(bin)
2550 | Expression::Div(bin) => {
2551 collect_columns(&bin.left, columns);
2552 collect_columns(&bin.right, columns);
2553 }
2554 Expression::Not(unary) | Expression::Neg(unary) => {
2555 collect_columns(&unary.this, columns);
2556 }
2557 Expression::Paren(paren) => {
2558 collect_columns(&paren.this, columns);
2559 }
2560 Expression::Case(case) => {
2561 if let Some(operand) = &case.operand {
2562 collect_columns(operand, columns);
2563 }
2564 for (when, then) in &case.whens {
2565 collect_columns(when, columns);
2566 collect_columns(then, columns);
2567 }
2568 if let Some(else_) = &case.else_ {
2569 collect_columns(else_, columns);
2570 }
2571 }
2572 Expression::Cast(cast) => {
2573 collect_columns(&cast.this, columns);
2574 }
2575 Expression::In(in_expr) => {
2576 collect_columns(&in_expr.this, columns);
2577 for e in &in_expr.expressions {
2578 collect_columns(e, columns);
2579 }
2580 if let Some(query) = &in_expr.query {
2581 collect_columns(query, columns);
2582 }
2583 }
2584 Expression::Between(between) => {
2585 collect_columns(&between.this, columns);
2586 collect_columns(&between.low, columns);
2587 collect_columns(&between.high, columns);
2588 }
2589 Expression::Subquery(subquery) => {
2590 collect_columns(&subquery.this, columns);
2591 }
2592 _ => {}
2593 }
2594}
2595
2596fn get_unqualified_columns(scope: &Scope) -> Vec<ColumnRef> {
2598 get_scope_columns(scope)
2599 .into_iter()
2600 .filter(|c| c.table.is_none())
2601 .collect()
2602}
2603
2604fn get_external_columns(scope: &Scope) -> Vec<ColumnRef> {
2606 let source_names: HashSet<_> = scope.sources.keys().cloned().collect();
2607
2608 get_scope_columns(scope)
2609 .into_iter()
2610 .filter(|c| {
2611 if let Some(table) = &c.table {
2612 !source_names.contains(table)
2613 } else {
2614 false
2615 }
2616 })
2617 .collect()
2618}
2619
2620fn is_correlated_subquery(scope: &Scope) -> bool {
2622 scope.can_be_correlated && !get_external_columns(scope).is_empty()
2623}
2624
2625fn is_star_column(col: &Column) -> bool {
2627 col.name.name == "*"
2628}
2629
2630fn create_qualified_column(name: &str, table: Option<&str>) -> Expression {
2632 Expression::boxed_column(Column {
2633 name: Identifier::new(name),
2634 table: table.map(Identifier::new),
2635 join_mark: false,
2636 trailing_comments: vec![],
2637 span: None,
2638 inferred_type: None,
2639 })
2640}
2641
2642fn create_alias(expr: Expression, alias_name: &str) -> Expression {
2644 Expression::Alias(Box::new(Alias {
2645 this: expr,
2646 alias: Identifier::new(alias_name),
2647 column_aliases: vec![],
2648 alias_explicit_as: false,
2649 alias_keyword: None,
2650 pre_alias_comments: vec![],
2651 trailing_comments: vec![],
2652 inferred_type: None,
2653 }))
2654}
2655
2656fn get_output_name(expr: &Expression) -> Option<String> {
2658 match expr {
2659 Expression::Column(col) => Some(col.name.name.clone()),
2660 Expression::Alias(alias) => Some(alias.alias.name.clone()),
2661 Expression::Identifier(id) => Some(id.name.clone()),
2662 _ => None,
2663 }
2664}
2665
// Unit tests for the column-qualification helpers in this module and for the
// identifier-quoting helpers (`needs_quoting`, `maybe_quote`,
// `quote_identifiers`, `quote_join`, `get_reserved_words`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DataType;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::scope::build_scope;
    use crate::{MappingSchema, Schema};

    /// Renders `expr` back to SQL with the default generator.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    #[test]
    fn test_qualify_columns_options() {
        let options = QualifyColumnsOptions::new()
            .with_expand_alias_refs(true)
            .with_expand_stars(false)
            .with_dialect(DialectType::PostgreSQL)
            .with_allow_partial(true);

        assert!(options.expand_alias_refs);
        assert!(!options.expand_stars);
        assert_eq!(options.dialect, Some(DialectType::PostgreSQL));
        assert!(options.allow_partial_qualification);
    }

    #[test]
    fn test_get_scope_columns() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let scope = build_scope(&expr);
        let columns = get_scope_columns(&scope);

        assert!(columns.iter().any(|c| c.name == "a"));
        assert!(columns.iter().any(|c| c.name == "b"));
        assert!(columns.iter().any(|c| c.name == "c"));
    }

    #[test]
    fn test_get_unqualified_columns() {
        let expr = parse("SELECT t.a, b FROM t");
        let scope = build_scope(&expr);
        let unqualified = get_unqualified_columns(&scope);

        assert!(unqualified.iter().any(|c| c.name == "b"));
        // `t.a` carries a table qualifier, so it must not be reported.
        assert!(!unqualified.iter().any(|c| c.name == "a"));
    }

    #[test]
    fn test_is_star_column() {
        let col = Column {
            name: Identifier::new("*"),
            table: Some(Identifier::new("t")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        };
        assert!(is_star_column(&col));

        let col2 = Column {
            name: Identifier::new("id"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        };
        assert!(!is_star_column(&col2));
    }

    #[test]
    fn test_create_qualified_column() {
        let expr = create_qualified_column("id", Some("users"));
        let sql = gen(&expr);
        assert!(sql.contains("users"));
        assert!(sql.contains("id"));
    }

    #[test]
    fn test_create_alias() {
        let col = Expression::boxed_column(Column {
            name: Identifier::new("value"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let aliased = create_alias(col, "total");
        let sql = gen(&aliased);
        // The alias name itself must be emitted; the previous
        // `contains("AS") || contains("total")` check also passed when the
        // name was dropped.
        assert!(sql.contains("total"), "alias name should be emitted: {sql}");
    }

    #[test]
    fn test_validate_qualify_columns_success() {
        let expr = parse("SELECT t.a, t.b FROM t");
        // Only checks that validation runs to completion on an already
        // qualified query; the returned value is intentionally not asserted.
        let result = validate_qualify_columns(&expr);
        let _ = result;
    }

    #[test]
    fn test_collect_columns_nested() {
        let expr = parse("SELECT a + b, c FROM t WHERE d > 0 GROUP BY e HAVING f = 1");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
        assert!(names.contains(&"d"));
        assert!(names.contains(&"e"));
        assert!(names.contains(&"f"));
    }

    #[test]
    fn test_collect_columns_in_case() {
        let expr = parse("SELECT CASE WHEN a = 1 THEN b ELSE c END FROM t");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
    }

    #[test]
    fn test_collect_columns_in_subquery() {
        let expr = parse("SELECT a FROM t WHERE b IN (SELECT c FROM s)");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
    }

    #[test]
    fn test_qualify_outputs_basic() {
        let expr = parse("SELECT a, b + c FROM t");
        let scope = build_scope(&expr);
        let result = qualify_outputs(&scope);
        assert!(result.is_ok());
    }

    #[test]
    fn test_qualify_columns_expands_star_with_schema() {
        let expr = parse("SELECT * FROM users");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "users",
                &[
                    (
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    ),
                    ("name".to_string(), DataType::Text),
                    ("email".to_string(), DataType::Text),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(!sql.contains("SELECT *"));
        assert!(sql.contains("users.id"));
        assert!(sql.contains("users.name"));
        assert!(sql.contains("users.email"));
    }

    #[test]
    fn test_qualify_columns_expands_group_by_positions() {
        let expr = parse("SELECT a, b FROM t GROUP BY 1, 2");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "t",
                &[
                    (
                        "a".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    ),
                    (
                        "b".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    ),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(!sql.contains("GROUP BY 1"));
        assert!(!sql.contains("GROUP BY 2"));
        assert!(sql.contains("GROUP BY"));
        assert!(sql.contains("t.a"));
        assert!(sql.contains("t.b"));
    }

    // --- JOIN ... USING expansion ---

    #[test]
    fn test_expand_using_simple() {
        let expr = parse("SELECT x.b FROM x JOIN y USING (b)");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "x",
                &[
                    ("a".to_string(), DataType::BigInt { length: None }),
                    ("b".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "y",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            !sql.contains("USING"),
            "USING should be replaced with ON: {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "ON condition should be x.b = y.b: {sql}"
        );
        assert!(sql.contains("SELECT x.b"), "SELECT should keep x.b: {sql}");
    }

    #[test]
    fn test_expand_using_unqualified_coalesce() {
        let expr = parse("SELECT b FROM x JOIN y USING(b)");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "x",
                &[
                    ("a".to_string(), DataType::BigInt { length: None }),
                    ("b".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "y",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b)"),
            "Unqualified USING column should become COALESCE: {sql}"
        );
        assert!(
            sql.contains("AS b"),
            "COALESCE should be aliased as 'b': {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "ON condition should be generated: {sql}"
        );
    }

    #[test]
    fn test_expand_using_with_where() {
        let expr = parse("SELECT b FROM x JOIN y USING(b) WHERE b = 1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "x",
                &[("b".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "y",
                &[("b".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            sql.contains("WHERE COALESCE(x.b, y.b)"),
            "WHERE should use COALESCE for USING column: {sql}"
        );
    }

    #[test]
    fn test_expand_using_multi_join() {
        let expr = parse("SELECT b FROM x JOIN y USING(b) JOIN z USING(b)");

        let mut schema = MappingSchema::new();
        for table in &["x", "y", "z"] {
            schema
                .add_table(
                    table,
                    &[("b".to_string(), DataType::BigInt { length: None })],
                    None,
                )
                .expect("schema setup");
        }

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        // The shared column folds over every joined table, in join order.
        assert!(
            sql.contains("COALESCE(x.b, y.b, z.b)"),
            "Should have 3-table COALESCE: {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "First join ON condition: {sql}"
        );
    }

    #[test]
    fn test_expand_using_multi_column() {
        let expr = parse("SELECT b, c FROM y JOIN z USING(b, c)");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "y",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "z",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            sql.contains("COALESCE(y.b, z.b)"),
            "column 'b' should get COALESCE: {sql}"
        );
        assert!(
            sql.contains("COALESCE(y.c, z.c)"),
            "column 'c' should get COALESCE: {sql}"
        );
        assert!(
            sql.contains("y.b = z.b") && sql.contains("y.c = z.c"),
            "ON should have both equality conditions: {sql}"
        );
    }

    #[test]
    fn test_expand_using_star() {
        let expr = parse("SELECT * FROM x JOIN y USING(b)");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "x",
                &[
                    ("a".to_string(), DataType::BigInt { length: None }),
                    ("b".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "y",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b) AS b"),
            "USING column should be COALESCE in star expansion: {sql}"
        );
        assert!(sql.contains("x.a"), "non-USING column a from x: {sql}");
        assert!(sql.contains("y.c"), "non-USING column c from y: {sql}");
        // The USING column must not be duplicated once per joined table.
        let coalesce_count = sql.matches("COALESCE").count();
        assert_eq!(
            coalesce_count, 1,
            "b should appear only once as COALESCE: {sql}"
        );
    }

    #[test]
    fn test_expand_using_table_star() {
        let expr = parse("SELECT x.* FROM x JOIN y USING(b)");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "x",
                &[
                    ("a".to_string(), DataType::BigInt { length: None }),
                    ("b".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "y",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b)"),
            "USING column from x.* should become COALESCE: {sql}"
        );
        assert!(sql.contains("x.a"), "non-USING column a: {sql}");
    }

    #[test]
    fn test_qualify_columns_qualified_table_name() {
        let expr = parse("SELECT a FROM raw.t1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "raw.t1",
                &[("a".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            sql.contains("t1.a"),
            "column should be qualified with table name: {sql}"
        );

        // Same schema, column nested inside an aggregate.
        let expr = parse("SELECT MAX(a) FROM raw.t1");
        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);
        assert!(
            sql.contains("t1.a"),
            "column in function should be qualified with table name: {sql}"
        );

        // Same schema, column nested inside a scalar function.
        let expr = parse("SELECT ABS(a) FROM raw.t1");
        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);
        assert!(
            sql.contains("t1.a"),
            "column in function should be qualified with table name: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_count_star() {
        let expr = parse("SELECT COUNT(*) FROM t1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "t1",
                &[("id".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            sql.contains("COUNT(*)"),
            "COUNT(*) should be preserved: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_correlated_scalar_subquery() {
        let expr =
            parse("SELECT id, (SELECT AVG(val) FROM t2 WHERE t2.id = t1.id) AS avg_val FROM t1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "t1",
                &[("id".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "t2",
                &[
                    ("id".to_string(), DataType::BigInt { length: None }),
                    ("val".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            sql.contains("t1.id"),
            "outer column should be qualified: {sql}"
        );
        assert!(
            sql.contains("t2.id"),
            "inner column should be qualified: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_correlated_scalar_subquery_unqualified() {
        let expr =
            parse("SELECT t1_id, (SELECT AVG(val) FROM t2 WHERE t2_id = t1_id) AS avg_val FROM t1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "t1",
                &[("t1_id".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "t2",
                &[
                    ("t2_id".to_string(), DataType::BigInt { length: None }),
                    ("val".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            sql.contains("t1.t1_id"),
            "outer column should be qualified: {sql}"
        );
        assert!(
            sql.contains("t2.t2_id"),
            "inner column should be qualified: {sql}"
        );
        // `t1_id` inside the subquery resolves to the outer table.
        assert!(
            sql.contains("= t1.t1_id"),
            "correlated column should be qualified: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_correlated_exists_subquery() {
        let expr = parse(
            "SELECT o_orderpriority FROM orders \
             WHERE EXISTS (SELECT * FROM lineitem WHERE l_orderkey = o_orderkey)",
        );

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "orders",
                &[
                    ("o_orderpriority".to_string(), DataType::Text),
                    ("o_orderkey".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "lineitem",
                &[("l_orderkey".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = gen(&result);

        assert!(
            sql.contains("orders.o_orderpriority"),
            "outer column should be qualified: {sql}"
        );
        assert!(
            sql.contains("lineitem.l_orderkey"),
            "inner column should be qualified: {sql}"
        );
        assert!(
            sql.contains("orders.o_orderkey"),
            "correlated outer column should be qualified: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_rejects_unknown_table() {
        let expr = parse("SELECT id FROM t1 WHERE nonexistent.col = 1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "t1",
                &[("id".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");

        let result = qualify_columns(expr, &schema, &QualifyColumnsOptions::new());
        assert!(
            result.is_err(),
            "should reject reference to table not in scope or schema"
        );
    }

    // --- Identifier quoting ---

    #[test]
    fn test_needs_quoting_reserved_word() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("select", &reserved));
        assert!(needs_quoting("SELECT", &reserved));
        assert!(needs_quoting("from", &reserved));
        assert!(needs_quoting("WHERE", &reserved));
        assert!(needs_quoting("join", &reserved));
        assert!(needs_quoting("table", &reserved));
    }

    #[test]
    fn test_needs_quoting_normal_identifiers() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("foo", &reserved));
        assert!(!needs_quoting("my_column", &reserved));
        assert!(!needs_quoting("col1", &reserved));
        assert!(!needs_quoting("A", &reserved));
        assert!(!needs_quoting("_hidden", &reserved));
    }

    #[test]
    fn test_needs_quoting_special_characters() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("my column", &reserved));
        assert!(needs_quoting("my-column", &reserved));
        assert!(needs_quoting("my.column", &reserved));
        assert!(needs_quoting("col@name", &reserved));
        assert!(needs_quoting("col#name", &reserved));
    }

    #[test]
    fn test_needs_quoting_starts_with_digit() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("1col", &reserved));
        assert!(needs_quoting("123", &reserved));
        assert!(needs_quoting("0_start", &reserved));
    }

    #[test]
    fn test_needs_quoting_empty() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("", &reserved));
    }

    #[test]
    fn test_maybe_quote_sets_quoted_flag() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("select");
        assert!(!id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
    }

    #[test]
    fn test_maybe_quote_skips_already_quoted() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::quoted("myname");
        assert!(id.quoted);
        maybe_quote(&mut id, &reserved);
        // Still quoted, and the name is left untouched.
        assert!(id.quoted);
        assert_eq!(id.name, "myname");
    }

    #[test]
    fn test_maybe_quote_skips_star() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("*");
        maybe_quote(&mut id, &reserved);
        // `*` is a wildcard, not an identifier, so it must stay unquoted.
        assert!(!id.quoted);
    }

    #[test]
    fn test_maybe_quote_skips_normal() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("normal_col");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    #[test]
    fn test_quote_identifiers_column_with_reserved_name() {
        let expr = Expression::boxed_column(Column {
            name: Identifier::new("select"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column named 'select' should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_column_with_special_chars() {
        let expr = Expression::boxed_column(Column {
            name: Identifier::new("my column"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column with space should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_preserves_normal_column() {
        let expr = Expression::boxed_column(Column {
            name: Identifier::new("normal_col"),
            table: Some(Identifier::new("my_table")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(!col.name.quoted, "Normal column should not be quoted");
            assert!(
                !col.table.as_ref().unwrap().quoted,
                "Normal table should not be quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_table_ref_reserved() {
        let expr = Expression::Table(Box::new(TableRef::new("select")));
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(tr.name.quoted, "Table named 'select' should be quoted");
        } else {
            panic!("Expected Table expression");
        }
    }

    #[test]
    fn test_quote_identifiers_table_ref_schema_and_alias() {
        let mut tr = TableRef::new("my_table");
        tr.schema = Some(Identifier::new("from"));
        tr.alias = Some(Identifier::new("t"));
        let expr = Expression::Table(Box::new(tr));
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(!tr.name.quoted, "Normal table name should not be quoted");
            assert!(
                tr.schema.as_ref().unwrap().quoted,
                "Schema named 'from' should be quoted"
            );
            assert!(
                !tr.alias.as_ref().unwrap().quoted,
                "Normal alias should not be quoted"
            );
        } else {
            panic!("Expected Table expression");
        }
    }

    #[test]
    fn test_quote_identifiers_identifier_node() {
        let expr = Expression::Identifier(Identifier::new("order"));
        let result = quote_identifiers(expr, None);
        if let Expression::Identifier(id) = &result {
            assert!(id.quoted, "Identifier named 'order' should be quoted");
        } else {
            panic!("Expected Identifier expression");
        }
    }

    #[test]
    fn test_quote_identifiers_alias() {
        let inner = Expression::boxed_column(Column {
            name: Identifier::new("val"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let expr = Expression::Alias(Box::new(Alias {
            this: inner,
            alias: Identifier::new("select"),
            column_aliases: vec![Identifier::new("from")],
            alias_explicit_as: false,
            alias_keyword: None,
            pre_alias_comments: vec![],
            trailing_comments: vec![],
            inferred_type: None,
        }));
        let result = quote_identifiers(expr, None);
        if let Expression::Alias(alias) = &result {
            assert!(alias.alias.quoted, "Alias named 'select' should be quoted");
            assert!(
                alias.column_aliases[0].quoted,
                "Column alias named 'from' should be quoted"
            );
            // The aliased column `val` is not reserved and must stay unquoted.
            if let Expression::Column(col) = &alias.this {
                assert!(!col.name.quoted);
            } else {
                panic!("Expected Column inside Alias");
            }
        } else {
            panic!("Expected Alias expression");
        }
    }

    #[test]
    fn test_quote_identifiers_select_recursive() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let result = quote_identifiers(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("a"));
        assert!(sql.contains("b"));
        assert!(sql.contains("t"));
    }

    #[test]
    fn test_quote_identifiers_digit_start() {
        let expr = Expression::boxed_column(Column {
            name: Identifier::new("1col"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(
                col.name.quoted,
                "Column starting with digit should be quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_with_mysql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::MySQL));
        assert!(needs_quoting("KILL", &reserved));
        assert!(needs_quoting("FORCE", &reserved));
    }

    #[test]
    fn test_quote_identifiers_with_postgresql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::PostgreSQL));
        assert!(needs_quoting("ILIKE", &reserved));
        assert!(needs_quoting("VERBOSE", &reserved));
    }

    #[test]
    fn test_quote_identifiers_with_bigquery_dialect() {
        let reserved = get_reserved_words(Some(DialectType::BigQuery));
        assert!(needs_quoting("STRUCT", &reserved));
        assert!(needs_quoting("PROTO", &reserved));
    }

    #[test]
    fn test_quote_identifiers_case_insensitive_reserved() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("Select", &reserved));
        assert!(needs_quoting("sElEcT", &reserved));
        assert!(needs_quoting("FROM", &reserved));
        assert!(needs_quoting("from", &reserved));
    }

    #[test]
    fn test_quote_identifiers_join_using() {
        let mut join = crate::expressions::Join {
            this: Expression::Table(Box::new(TableRef::new("other"))),
            on: None,
            using: vec![Identifier::new("key"), Identifier::new("value")],
            kind: crate::expressions::JoinKind::Inner,
            use_inner_keyword: false,
            use_outer_keyword: false,
            deferred_condition: false,
            join_hint: None,
            match_condition: None,
            pivots: vec![],
            comments: vec![],
            nesting_group: 0,
            directed: false,
        };
        let reserved = get_reserved_words(None);
        quote_join(&mut join, &reserved);
        assert!(
            join.using[0].quoted,
            "USING identifier 'key' should be quoted"
        );
        assert!(
            !join.using[1].quoted,
            "USING identifier 'value' should not be quoted"
        );
    }

    #[test]
    fn test_quote_identifiers_cte() {
        // Exercises `maybe_quote` directly on the CTE alias and column list.
        let mut cte = crate::expressions::Cte {
            alias: Identifier::new("select"),
            this: Expression::boxed_column(Column {
                name: Identifier::new("x"),
                table: None,
                join_mark: false,
                trailing_comments: vec![],
                span: None,
                inferred_type: None,
            }),
            columns: vec![Identifier::new("from"), Identifier::new("normal")],
            materialized: None,
            key_expressions: vec![],
            alias_first: false,
            comments: Vec::new(),
        };
        let reserved = get_reserved_words(None);
        maybe_quote(&mut cte.alias, &reserved);
        for c in &mut cte.columns {
            maybe_quote(c, &reserved);
        }
        assert!(cte.alias.quoted, "CTE alias 'select' should be quoted");
        assert!(cte.columns[0].quoted, "CTE column 'from' should be quoted");
        assert!(
            !cte.columns[1].quoted,
            "CTE column 'normal' should not be quoted"
        );
    }

    #[test]
    fn test_quote_identifiers_binary_ops_recurse() {
        let expr = Expression::Add(Box::new(crate::expressions::BinaryOp::new(
            Expression::boxed_column(Column {
                name: Identifier::new("select"),
                table: None,
                join_mark: false,
                trailing_comments: vec![],
                span: None,
                inferred_type: None,
            }),
            Expression::boxed_column(Column {
                name: Identifier::new("normal"),
                table: None,
                join_mark: false,
                trailing_comments: vec![],
                span: None,
                inferred_type: None,
            }),
        )));
        let result = quote_identifiers(expr, None);
        if let Expression::Add(bin) = &result {
            // Fail loudly if an operand changed shape instead of silently
            // skipping the assertion.
            if let Expression::Column(left) = &bin.left {
                assert!(
                    left.name.quoted,
                    "'select' column should be quoted in binary op"
                );
            } else {
                panic!("Expected Column on the left of Add");
            }
            if let Expression::Column(right) = &bin.right {
                assert!(!right.name.quoted, "'normal' column should not be quoted");
            } else {
                panic!("Expected Column on the right of Add");
            }
        } else {
            panic!("Expected Add expression");
        }
    }

    #[test]
    fn test_quote_identifiers_already_quoted_preserved() {
        let expr = Expression::boxed_column(Column {
            name: Identifier::quoted("normal_name"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(
                col.name.quoted,
                "Already-quoted identifier should remain quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_full_parsed_query() {
        // Builds `SELECT t.order FROM t` by hand so the reserved column name
        // reaches `quote_identifiers` without going through the parser.
        let mut select = crate::expressions::Select::new();
        select.expressions.push(Expression::boxed_column(Column {
            name: Identifier::new("order"),
            table: Some(Identifier::new("t")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        }));
        select.from = Some(crate::expressions::From {
            expressions: vec![Expression::Table(Box::new(TableRef::new("t")))],
        });
        let expr = Expression::Select(Box::new(select));

        let result = quote_identifiers(expr, None);
        if let Expression::Select(sel) = &result {
            if let Expression::Column(col) = &sel.expressions[0] {
                assert!(col.name.quoted, "Column named 'order' should be quoted");
                assert!(
                    !col.table.as_ref().unwrap().quoted,
                    "Table 't' should not be quoted"
                );
            } else {
                panic!("Expected Column in SELECT list");
            }
        } else {
            panic!("Expected Select expression");
        }
    }

    #[test]
    fn test_get_reserved_words_all_dialects() {
        let dialects = [
            None,
            Some(DialectType::Generic),
            Some(DialectType::MySQL),
            Some(DialectType::PostgreSQL),
            Some(DialectType::BigQuery),
            Some(DialectType::Snowflake),
            Some(DialectType::TSQL),
            Some(DialectType::ClickHouse),
            Some(DialectType::DuckDB),
            Some(DialectType::Hive),
            Some(DialectType::Spark),
            Some(DialectType::Trino),
            Some(DialectType::Oracle),
            Some(DialectType::Redshift),
        ];
        for dialect in &dialects {
            let words = get_reserved_words(*dialect);
            assert!(
                words.contains("SELECT"),
                "All dialects should have SELECT as reserved"
            );
            assert!(
                words.contains("FROM"),
                "All dialects should have FROM as reserved"
            );
        }
    }
}