1use crate::dialects::transform_recursive;
9use crate::dialects::DialectType;
10use crate::expressions::{
11 Alias, BinaryOp, Column, Expression, Identifier, Join, LateralView, Literal, Over, Paren,
12 Select, TableRef, VarArgFunc, With,
13};
14use crate::resolver::{Resolver, ResolverError};
15use crate::schema::{normalize_name, Schema};
16use crate::scope::{build_scope, traverse_scope, Scope};
17use std::cell::RefCell;
18use std::collections::{HashMap, HashSet};
19use thiserror::Error;
20
/// Errors raised while qualifying or validating column references.
#[derive(Debug, Error, Clone)]
pub enum QualifyColumnsError {
    /// A table qualifier (the `t` in `t.col` or `t.*`) that could not be matched to a
    /// known source; carries the table name.
    #[error("Unknown table: {0}")]
    UnknownTable(String),

    /// A column that could not be found among the sources visible to it; carries the
    /// column name.
    #[error("Unknown column: {0}")]
    UnknownColumn(String),

    /// A column still unqualified when validation runs; carries the column name.
    #[error("Ambiguous column: {0}")]
    AmbiguousColumn(String),

    /// Intended for `USING` joins that cannot be expanded. NOTE(review): in the code
    /// visible alongside this type it is only produced by wrapping `transform_recursive`
    /// failures (`err.to_string()`) — confirm that is the intended use.
    #[error("Cannot automatically join: {0}")]
    CannotAutoJoin(String),

    /// A positional reference (e.g. `GROUP BY 3`) that is 0 or past the end of the
    /// projection list; carries the position as text.
    #[error("Unknown output column: {0}")]
    UnknownOutputColumn(String),

    /// An outer-scope column reference in a scope that is not a correlated subquery.
    /// `for_table` is either empty or ` for table: '<name>'`.
    #[error("Column could not be resolved: {column}{for_table}")]
    ColumnNotResolved { column: String, for_table: String },

    /// An error propagated from the [`Resolver`] (converted via `From`).
    #[error("Resolver error: {0}")]
    ResolverError(#[from] ResolverError),
}
45
/// Result type returned by the column-qualification functions in this module.
pub type QualifyColumnsResult<T> = Result<T, QualifyColumnsError>;
48
/// Options controlling [`qualify_columns`].
///
/// NOTE(review): the derived `Default` sets `expand_alias_refs` and `expand_stars` to
/// `false`, whereas `QualifyColumnsOptions::new()` sets both to `true`, so
/// `..Default::default()` and `new()` produce different behavior — confirm this is
/// intended.
#[derive(Debug, Clone, Default)]
pub struct QualifyColumnsOptions {
    /// When true, references to projection aliases (in later projections and in
    /// WHERE / GROUP BY / HAVING / QUALIFY / ORDER BY) are replaced by the aliased
    /// expression.
    pub expand_alias_refs: bool,
    /// When true, `*` and `table.*` projections are expanded into explicit columns.
    pub expand_stars: bool,
    /// Schema-inference flag passed to `Resolver::new`; `None` means "enabled iff the
    /// schema is empty".
    pub infer_schema: Option<bool>,
    /// When true, columns that cannot be resolved (or that are missing from a source
    /// with a known column list) are left as-is instead of raising `UnknownColumn`.
    pub allow_partial_qualification: bool,
    /// SQL dialect to use; `None` falls back to the schema's dialect.
    pub dialect: Option<DialectType>,
}
63
64impl QualifyColumnsOptions {
65 pub fn new() -> Self {
67 Self {
68 expand_alias_refs: true,
69 expand_stars: true,
70 infer_schema: None,
71 allow_partial_qualification: false,
72 dialect: None,
73 }
74 }
75
76 pub fn with_expand_alias_refs(mut self, expand: bool) -> Self {
78 self.expand_alias_refs = expand;
79 self
80 }
81
82 pub fn with_expand_stars(mut self, expand: bool) -> Self {
84 self.expand_stars = expand;
85 self
86 }
87
88 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
90 self.dialect = Some(dialect);
91 self
92 }
93
94 pub fn with_allow_partial(mut self, allow: bool) -> Self {
96 self.allow_partial_qualification = allow;
97 self
98 }
99}
100
101pub fn qualify_columns(
116 expression: Expression,
117 schema: &dyn Schema,
118 options: &QualifyColumnsOptions,
119) -> QualifyColumnsResult<Expression> {
120 let infer_schema = options.infer_schema.unwrap_or(schema.is_empty());
121 let dialect = options.dialect.or_else(|| schema.dialect());
122 let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
123
124 let transformed = transform_recursive(expression, &|node| {
125 if first_error.borrow().is_some() {
126 return Ok(node);
127 }
128
129 match node {
130 Expression::Select(mut select) => {
131 if let Some(with) = &mut select.with {
132 pushdown_cte_alias_columns_with(with);
133 }
134
135 let scope_expr = Expression::Select(select.clone());
136 let scope = build_scope(&scope_expr);
137 let mut resolver = Resolver::new(&scope, schema, infer_schema);
138
139 let column_tables = if first_error.borrow().is_none() {
141 match expand_using(&mut select, &scope, &mut resolver) {
142 Ok(ct) => ct,
143 Err(err) => {
144 *first_error.borrow_mut() = Some(err);
145 HashMap::new()
146 }
147 }
148 } else {
149 HashMap::new()
150 };
151
152 if first_error.borrow().is_none() {
154 if let Err(err) = qualify_columns_in_scope(
155 &mut select,
156 &scope,
157 &mut resolver,
158 options.allow_partial_qualification,
159 ) {
160 *first_error.borrow_mut() = Some(err);
161 }
162 }
163
164 if first_error.borrow().is_none() && options.expand_alias_refs {
166 if let Err(err) = expand_alias_refs(&mut select, &mut resolver, dialect) {
167 *first_error.borrow_mut() = Some(err);
168 }
169 }
170
171 if first_error.borrow().is_none() && options.expand_stars {
173 if let Err(err) =
174 expand_stars(&mut select, &scope, &mut resolver, &column_tables)
175 {
176 *first_error.borrow_mut() = Some(err);
177 }
178 }
179
180 if first_error.borrow().is_none() {
182 if let Err(err) = qualify_outputs_select(&mut select) {
183 *first_error.borrow_mut() = Some(err);
184 }
185 }
186
187 if first_error.borrow().is_none() {
189 if let Err(err) = expand_group_by(&mut select, dialect) {
190 *first_error.borrow_mut() = Some(err);
191 }
192 }
193
194 Ok(Expression::Select(select))
195 }
196 _ => Ok(node),
197 }
198 })
199 .map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
200
201 if let Some(err) = first_error.into_inner() {
202 return Err(err);
203 }
204
205 Ok(transformed)
206}
207
208pub fn validate_qualify_columns(expression: &Expression) -> QualifyColumnsResult<()> {
213 let mut all_unqualified = Vec::new();
214
215 for scope in traverse_scope(expression) {
216 if let Expression::Select(_) = &scope.expression {
217 let unqualified = get_unqualified_columns(&scope);
219
220 let external = get_external_columns(&scope);
222 if !external.is_empty() && !is_correlated_subquery(&scope) {
223 let first = &external[0];
224 let for_table = if first.table.is_some() {
225 format!(" for table: '{}'", first.table.as_ref().unwrap())
226 } else {
227 String::new()
228 };
229 return Err(QualifyColumnsError::ColumnNotResolved {
230 column: first.name.clone(),
231 for_table,
232 });
233 }
234
235 all_unqualified.extend(unqualified);
236 }
237 }
238
239 if !all_unqualified.is_empty() {
240 let first = &all_unqualified[0];
241 return Err(QualifyColumnsError::AmbiguousColumn(first.name.clone()));
242 }
243
244 Ok(())
245}
246
247fn get_source_name(expr: &Expression) -> Option<String> {
249 match expr {
250 Expression::Table(t) => Some(
251 t.alias
252 .as_ref()
253 .map(|a| a.name.clone())
254 .unwrap_or_else(|| t.name.name.clone()),
255 ),
256 Expression::Subquery(sq) => sq.alias.as_ref().map(|a| a.name.clone()),
257 _ => None,
258 }
259}
260
261fn get_ordered_source_names(select: &Select) -> Vec<String> {
264 let mut ordered = Vec::new();
265 if let Some(from) = &select.from {
266 for expr in &from.expressions {
267 if let Some(name) = get_source_name(expr) {
268 ordered.push(name);
269 }
270 }
271 }
272 for join in &select.joins {
273 if let Some(name) = get_source_name(&join.this) {
274 ordered.push(name);
275 }
276 }
277 ordered
278}
279
280fn make_coalesce(column_name: &str, tables: &[String]) -> Expression {
282 let args: Vec<Expression> = tables
283 .iter()
284 .map(|t| Expression::qualified_column(t.as_str(), column_name))
285 .collect();
286 Expression::Coalesce(Box::new(VarArgFunc {
287 expressions: args,
288 original_name: None,
289 inferred_type: None,
290 }))
291}
292
293fn expand_using(
299 select: &mut Select,
300 _scope: &Scope,
301 resolver: &mut Resolver,
302) -> QualifyColumnsResult<HashMap<String, Vec<String>>> {
303 let mut columns: HashMap<String, String> = HashMap::new();
305
306 let mut column_tables: HashMap<String, Vec<String>> = HashMap::new();
308
309 let join_names: HashSet<String> = select
311 .joins
312 .iter()
313 .filter_map(|j| get_source_name(&j.this))
314 .collect();
315
316 let all_ordered = get_ordered_source_names(select);
317 let mut ordered: Vec<String> = all_ordered
318 .iter()
319 .filter(|name| !join_names.contains(name.as_str()))
320 .cloned()
321 .collect();
322
323 if join_names.is_empty() {
324 return Ok(column_tables);
325 }
326
327 fn update_source_columns(
329 source_name: &str,
330 columns: &mut HashMap<String, String>,
331 resolver: &mut Resolver,
332 ) {
333 if let Ok(source_cols) = resolver.get_source_columns(source_name) {
334 for col_name in source_cols {
335 columns
336 .entry(col_name)
337 .or_insert_with(|| source_name.to_string());
338 }
339 }
340 }
341
342 for source_name in &ordered {
344 update_source_columns(source_name, &mut columns, resolver);
345 }
346
347 for i in 0..select.joins.len() {
348 let source_table = ordered.last().cloned().unwrap_or_default();
350 if !source_table.is_empty() {
351 update_source_columns(&source_table, &mut columns, resolver);
352 }
353
354 let join_table = get_source_name(&select.joins[i].this).unwrap_or_default();
356 ordered.push(join_table.clone());
357
358 if select.joins[i].using.is_empty() {
360 continue;
361 }
362
363 let _join_columns: Vec<String> =
364 resolver.get_source_columns(&join_table).unwrap_or_default();
365
366 let using_identifiers: Vec<String> = select.joins[i]
367 .using
368 .iter()
369 .map(|id| id.name.clone())
370 .collect();
371
372 let using_count = using_identifiers.len();
373 let is_semi_or_anti = matches!(
374 select.joins[i].kind,
375 crate::expressions::JoinKind::Semi
376 | crate::expressions::JoinKind::Anti
377 | crate::expressions::JoinKind::LeftSemi
378 | crate::expressions::JoinKind::LeftAnti
379 | crate::expressions::JoinKind::RightSemi
380 | crate::expressions::JoinKind::RightAnti
381 );
382
383 let mut conditions: Vec<Expression> = Vec::new();
384
385 for identifier in &using_identifiers {
386 let table = columns
387 .get(identifier)
388 .cloned()
389 .unwrap_or_else(|| source_table.clone());
390
391 let lhs = if i == 0 || using_count == 1 {
393 Expression::qualified_column(table.as_str(), identifier.as_str())
395 } else {
396 let coalesce_cols: Vec<String> = ordered[..ordered.len() - 1]
399 .iter()
400 .filter(|t| {
401 resolver
402 .get_source_columns(t)
403 .unwrap_or_default()
404 .contains(identifier)
405 })
406 .cloned()
407 .collect();
408
409 if coalesce_cols.len() > 1 {
410 make_coalesce(identifier, &coalesce_cols)
411 } else {
412 Expression::qualified_column(table.as_str(), identifier.as_str())
413 }
414 };
415
416 let rhs = Expression::qualified_column(join_table.as_str(), identifier.as_str());
418
419 conditions.push(Expression::Eq(Box::new(BinaryOp::new(lhs, rhs))));
420
421 if !is_semi_or_anti {
423 let tables = column_tables
424 .entry(identifier.clone())
425 .or_insert_with(Vec::new);
426 if !tables.contains(&table) {
427 tables.push(table.clone());
428 }
429 if !tables.contains(&join_table) {
430 tables.push(join_table.clone());
431 }
432 }
433 }
434
435 let on_condition = conditions
437 .into_iter()
438 .reduce(|acc, cond| Expression::And(Box::new(BinaryOp::new(acc, cond))))
439 .expect("at least one USING column");
440
441 select.joins[i].on = Some(on_condition);
443 select.joins[i].using = vec![];
444 }
445
446 if !column_tables.is_empty() {
448 let mut new_expressions = Vec::with_capacity(select.expressions.len());
450 for expr in &select.expressions {
451 match expr {
452 Expression::Column(col)
453 if col.table.is_none() && column_tables.contains_key(&col.name.name) =>
454 {
455 let tables = &column_tables[&col.name.name];
456 let coalesce = make_coalesce(&col.name.name, tables);
457 new_expressions.push(Expression::Alias(Box::new(Alias {
459 this: coalesce,
460 alias: Identifier::new(&col.name.name),
461 column_aliases: vec![],
462 pre_alias_comments: vec![],
463 trailing_comments: vec![],
464 inferred_type: None,
465 })));
466 }
467 _ => {
468 let mut rewritten = expr.clone();
469 rewrite_using_columns_in_expression(&mut rewritten, &column_tables);
470 new_expressions.push(rewritten);
471 }
472 }
473 }
474 select.expressions = new_expressions;
475
476 if let Some(where_clause) = &mut select.where_clause {
478 rewrite_using_columns_in_expression(&mut where_clause.this, &column_tables);
479 }
480
481 if let Some(group_by) = &mut select.group_by {
483 for expr in &mut group_by.expressions {
484 rewrite_using_columns_in_expression(expr, &column_tables);
485 }
486 }
487
488 if let Some(having) = &mut select.having {
490 rewrite_using_columns_in_expression(&mut having.this, &column_tables);
491 }
492
493 if let Some(qualify) = &mut select.qualify {
495 rewrite_using_columns_in_expression(&mut qualify.this, &column_tables);
496 }
497
498 if let Some(order_by) = &mut select.order_by {
500 for ordered in &mut order_by.expressions {
501 rewrite_using_columns_in_expression(&mut ordered.this, &column_tables);
502 }
503 }
504 }
505
506 Ok(column_tables)
507}
508
509fn rewrite_using_columns_in_expression(
511 expr: &mut Expression,
512 column_tables: &HashMap<String, Vec<String>>,
513) {
514 let transformed = transform_recursive(expr.clone(), &|node| match node {
515 Expression::Column(col)
516 if col.table.is_none() && column_tables.contains_key(&col.name.name) =>
517 {
518 let tables = &column_tables[&col.name.name];
519 Ok(make_coalesce(&col.name.name, tables))
520 }
521 other => Ok(other),
522 });
523
524 if let Ok(next) = transformed {
525 *expr = next;
526 }
527}
528
529fn qualify_columns_in_scope(
531 select: &mut Select,
532 scope: &Scope,
533 resolver: &mut Resolver,
534 allow_partial: bool,
535) -> QualifyColumnsResult<()> {
536 for expr in &mut select.expressions {
537 qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
538 }
539 if let Some(where_clause) = &mut select.where_clause {
540 qualify_columns_in_expression(&mut where_clause.this, scope, resolver, allow_partial)?;
541 }
542 if let Some(group_by) = &mut select.group_by {
543 for expr in &mut group_by.expressions {
544 qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
545 }
546 }
547 if let Some(having) = &mut select.having {
548 qualify_columns_in_expression(&mut having.this, scope, resolver, allow_partial)?;
549 }
550 if let Some(qualify) = &mut select.qualify {
551 qualify_columns_in_expression(&mut qualify.this, scope, resolver, allow_partial)?;
552 }
553 if let Some(order_by) = &mut select.order_by {
554 for ordered in &mut order_by.expressions {
555 qualify_columns_in_expression(&mut ordered.this, scope, resolver, allow_partial)?;
556 }
557 }
558 for join in &mut select.joins {
559 qualify_columns_in_expression(&mut join.this, scope, resolver, allow_partial)?;
560 if let Some(on) = &mut join.on {
561 qualify_columns_in_expression(on, scope, resolver, allow_partial)?;
562 }
563 }
564 Ok(())
565}
566
567fn expand_alias_refs(
574 select: &mut Select,
575 _resolver: &mut Resolver,
576 _dialect: Option<DialectType>,
577) -> QualifyColumnsResult<()> {
578 let mut alias_to_expression: HashMap<String, (Expression, usize)> = HashMap::new();
579
580 for (i, expr) in select.expressions.iter_mut().enumerate() {
581 replace_alias_refs_in_expression(expr, &alias_to_expression, false);
582 if let Expression::Alias(alias) = expr {
583 alias_to_expression.insert(alias.alias.name.clone(), (alias.this.clone(), i + 1));
584 }
585 }
586
587 if let Some(where_clause) = &mut select.where_clause {
588 replace_alias_refs_in_expression(&mut where_clause.this, &alias_to_expression, false);
589 }
590 if let Some(group_by) = &mut select.group_by {
591 for expr in &mut group_by.expressions {
592 replace_alias_refs_in_expression(expr, &alias_to_expression, true);
593 }
594 }
595 if let Some(having) = &mut select.having {
596 replace_alias_refs_in_expression(&mut having.this, &alias_to_expression, false);
597 }
598 if let Some(qualify) = &mut select.qualify {
599 replace_alias_refs_in_expression(&mut qualify.this, &alias_to_expression, false);
600 }
601 if let Some(order_by) = &mut select.order_by {
602 for ordered in &mut order_by.expressions {
603 replace_alias_refs_in_expression(&mut ordered.this, &alias_to_expression, false);
604 }
605 }
606
607 Ok(())
608}
609
610fn expand_group_by(select: &mut Select, _dialect: Option<DialectType>) -> QualifyColumnsResult<()> {
617 let projections = select.expressions.clone();
618
619 if let Some(group_by) = &mut select.group_by {
620 for group_expr in &mut group_by.expressions {
621 if let Some(index) = positional_reference(group_expr) {
622 let replacement = select_expression_at_position(&projections, index)?;
623 *group_expr = replacement;
624 }
625 }
626 }
627 Ok(())
628}
629
630fn expand_stars(
640 select: &mut Select,
641 _scope: &Scope,
642 resolver: &mut Resolver,
643 column_tables: &HashMap<String, Vec<String>>,
644) -> QualifyColumnsResult<()> {
645 let mut new_selections: Vec<Expression> = Vec::new();
646 let mut has_star = false;
647 let mut coalesced_columns: HashSet<String> = HashSet::new();
648
649 let ordered_sources = get_ordered_source_names(select);
651
652 for expr in &select.expressions {
653 match expr {
654 Expression::Star(_) => {
655 has_star = true;
656 for source_name in &ordered_sources {
657 if let Ok(columns) = resolver.get_source_columns(source_name) {
658 if columns.contains(&"*".to_string()) || columns.is_empty() {
659 return Ok(());
660 }
661 for col_name in &columns {
662 if coalesced_columns.contains(col_name) {
663 continue;
665 }
666 if let Some(tables) = column_tables.get(col_name) {
667 if tables.contains(source_name) {
668 coalesced_columns.insert(col_name.clone());
670 let coalesce = make_coalesce(col_name, tables);
671 new_selections.push(Expression::Alias(Box::new(Alias {
672 this: coalesce,
673 alias: Identifier::new(col_name),
674 column_aliases: vec![],
675 pre_alias_comments: vec![],
676 trailing_comments: vec![],
677 inferred_type: None,
678 })));
679 continue;
680 }
681 }
682 new_selections
683 .push(create_qualified_column(col_name, Some(source_name)));
684 }
685 }
686 }
687 }
688 Expression::Column(col) if is_star_column(col) => {
689 has_star = true;
690 if let Some(table) = &col.table {
691 let table_name = &table.name;
692 if !ordered_sources.contains(table_name) {
693 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
694 }
695 if let Ok(columns) = resolver.get_source_columns(table_name) {
696 if columns.contains(&"*".to_string()) || columns.is_empty() {
697 return Ok(());
698 }
699 for col_name in &columns {
700 if coalesced_columns.contains(col_name) {
701 continue;
702 }
703 if let Some(tables) = column_tables.get(col_name) {
704 if tables.contains(table_name) {
705 coalesced_columns.insert(col_name.clone());
706 let coalesce = make_coalesce(col_name, tables);
707 new_selections.push(Expression::Alias(Box::new(Alias {
708 this: coalesce,
709 alias: Identifier::new(col_name),
710 column_aliases: vec![],
711 pre_alias_comments: vec![],
712 trailing_comments: vec![],
713 inferred_type: None,
714 })));
715 continue;
716 }
717 }
718 new_selections
719 .push(create_qualified_column(col_name, Some(table_name)));
720 }
721 }
722 }
723 }
724 _ => new_selections.push(expr.clone()),
725 }
726 }
727
728 if has_star {
729 select.expressions = new_selections;
730 }
731
732 Ok(())
733}
734
735pub fn qualify_outputs(scope: &Scope) -> QualifyColumnsResult<()> {
742 if let Expression::Select(mut select) = scope.expression.clone() {
743 qualify_outputs_select(&mut select)?;
744 }
745 Ok(())
746}
747
748fn qualify_outputs_select(select: &mut Select) -> QualifyColumnsResult<()> {
749 let mut new_selections: Vec<Expression> = Vec::new();
750
751 for (i, expr) in select.expressions.iter().enumerate() {
752 match expr {
753 Expression::Alias(_) => new_selections.push(expr.clone()),
754 Expression::Column(col) => {
755 new_selections.push(create_alias(expr.clone(), &col.name.name));
756 }
757 Expression::Star(_) => new_selections.push(expr.clone()),
758 _ => {
759 let alias_name = get_output_name(expr).unwrap_or_else(|| format!("_col_{}", i));
760 new_selections.push(create_alias(expr.clone(), &alias_name));
761 }
762 }
763 }
764
765 select.expressions = new_selections;
766 Ok(())
767}
768
769fn qualify_columns_in_expression(
770 expr: &mut Expression,
771 scope: &Scope,
772 resolver: &mut Resolver,
773 allow_partial: bool,
774) -> QualifyColumnsResult<()> {
775 let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
776 let resolver_cell: RefCell<&mut Resolver> = RefCell::new(resolver);
777
778 let transformed = transform_recursive(expr.clone(), &|node| {
779 if first_error.borrow().is_some() {
780 return Ok(node);
781 }
782
783 match node {
784 Expression::Column(mut col) => {
785 if let Err(err) = qualify_single_column(
786 &mut col,
787 scope,
788 &mut resolver_cell.borrow_mut(),
789 allow_partial,
790 ) {
791 *first_error.borrow_mut() = Some(err);
792 }
793 Ok(Expression::Column(col))
794 }
795 _ => Ok(node),
796 }
797 })
798 .map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
799
800 if let Some(err) = first_error.into_inner() {
801 return Err(err);
802 }
803
804 *expr = transformed;
805 Ok(())
806}
807
808fn qualify_single_column(
809 col: &mut Column,
810 scope: &Scope,
811 resolver: &mut Resolver,
812 allow_partial: bool,
813) -> QualifyColumnsResult<()> {
814 if is_star_column(col) {
815 return Ok(());
816 }
817
818 if let Some(table) = &col.table {
819 let table_name = &table.name;
820 if !scope.sources.contains_key(table_name) {
821 if resolver.table_exists_in_schema(table_name) {
825 return Ok(());
826 }
827 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
828 }
829
830 if let Ok(source_columns) = resolver.get_source_columns(table_name) {
831 let normalized_column_name = normalize_column_name(&col.name.name, resolver.dialect);
832 if !allow_partial
833 && !source_columns.is_empty()
834 && !source_columns.iter().any(|column| {
835 normalize_column_name(column, resolver.dialect) == normalized_column_name
836 })
837 && !source_columns.contains(&"*".to_string())
838 {
839 return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
840 }
841 }
842 return Ok(());
843 }
844
845 if let Some(table_name) = resolver.get_table(&col.name.name) {
846 col.table = Some(Identifier::new(table_name));
847 return Ok(());
848 }
849
850 if let Some(outer_table) = resolver.find_column_in_outer_schema_tables(&col.name.name) {
853 col.table = Some(Identifier::new(outer_table));
854 return Ok(());
855 }
856
857 if !allow_partial {
858 return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
859 }
860
861 Ok(())
862}
863
/// Normalize a column name for comparison under `dialect` by delegating to
/// `normalize_name`.
///
/// NOTE(review): the meaning of the two boolean flags (`false`, `true`) is defined by
/// `normalize_name` and is not visible here. Confirm them before changing this wrapper.
fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
    normalize_name(name, dialect, false, true)
}
867
868fn replace_alias_refs_in_expression(
869 expr: &mut Expression,
870 alias_to_expression: &HashMap<String, (Expression, usize)>,
871 literal_index: bool,
872) {
873 let transformed = transform_recursive(expr.clone(), &|node| match node {
874 Expression::Column(col) if col.table.is_none() => {
875 if let Some((alias_expr, index)) = alias_to_expression.get(&col.name.name) {
876 if literal_index && matches!(alias_expr, Expression::Literal(_)) {
877 return Ok(Expression::number(*index as i64));
878 }
879 return Ok(Expression::Paren(Box::new(Paren {
880 this: alias_expr.clone(),
881 trailing_comments: vec![],
882 })));
883 }
884 Ok(Expression::Column(col))
885 }
886 other => Ok(other),
887 });
888
889 if let Ok(next) = transformed {
890 *expr = next;
891 }
892}
893
894fn positional_reference(expr: &Expression) -> Option<usize> {
895 match expr {
896 Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
897 let Literal::Number(value) = lit.as_ref() else {
898 unreachable!()
899 };
900 value.parse::<usize>().ok()
901 }
902 _ => None,
903 }
904}
905
906fn select_expression_at_position(
907 projections: &[Expression],
908 index: usize,
909) -> QualifyColumnsResult<Expression> {
910 if index == 0 || index > projections.len() {
911 return Err(QualifyColumnsError::UnknownOutputColumn(index.to_string()));
912 }
913
914 let projection = projections[index - 1].clone();
915 Ok(match projection {
916 Expression::Alias(alias) => alias.this.clone(),
917 other => other,
918 })
919}
920
921fn get_reserved_words(dialect: Option<DialectType>) -> HashSet<&'static str> {
924 let mut words: HashSet<&'static str> = [
926 "ADD",
928 "ALL",
929 "ALTER",
930 "AND",
931 "ANY",
932 "AS",
933 "ASC",
934 "BETWEEN",
935 "BY",
936 "CASE",
937 "CAST",
938 "CHECK",
939 "COLUMN",
940 "CONSTRAINT",
941 "CREATE",
942 "CROSS",
943 "CURRENT",
944 "CURRENT_DATE",
945 "CURRENT_TIME",
946 "CURRENT_TIMESTAMP",
947 "CURRENT_USER",
948 "DATABASE",
949 "DEFAULT",
950 "DELETE",
951 "DESC",
952 "DISTINCT",
953 "DROP",
954 "ELSE",
955 "END",
956 "ESCAPE",
957 "EXCEPT",
958 "EXISTS",
959 "FALSE",
960 "FETCH",
961 "FOR",
962 "FOREIGN",
963 "FROM",
964 "FULL",
965 "GRANT",
966 "GROUP",
967 "HAVING",
968 "IF",
969 "IN",
970 "INDEX",
971 "INNER",
972 "INSERT",
973 "INTERSECT",
974 "INTO",
975 "IS",
976 "JOIN",
977 "KEY",
978 "LEFT",
979 "LIKE",
980 "LIMIT",
981 "NATURAL",
982 "NOT",
983 "NULL",
984 "OFFSET",
985 "ON",
986 "OR",
987 "ORDER",
988 "OUTER",
989 "PRIMARY",
990 "REFERENCES",
991 "REPLACE",
992 "RETURNING",
993 "RIGHT",
994 "ROLLBACK",
995 "ROW",
996 "ROWS",
997 "SELECT",
998 "SESSION_USER",
999 "SET",
1000 "SOME",
1001 "TABLE",
1002 "THEN",
1003 "TO",
1004 "TRUE",
1005 "TRUNCATE",
1006 "UNION",
1007 "UNIQUE",
1008 "UPDATE",
1009 "USING",
1010 "VALUES",
1011 "VIEW",
1012 "WHEN",
1013 "WHERE",
1014 "WINDOW",
1015 "WITH",
1016 ]
1017 .iter()
1018 .copied()
1019 .collect();
1020
1021 match dialect {
1023 Some(DialectType::MySQL) => {
1024 words.extend(
1025 [
1026 "ANALYZE",
1027 "BOTH",
1028 "CHANGE",
1029 "CONDITION",
1030 "DATABASES",
1031 "DAY_HOUR",
1032 "DAY_MICROSECOND",
1033 "DAY_MINUTE",
1034 "DAY_SECOND",
1035 "DELAYED",
1036 "DETERMINISTIC",
1037 "DIV",
1038 "DUAL",
1039 "EACH",
1040 "ELSEIF",
1041 "ENCLOSED",
1042 "EXPLAIN",
1043 "FLOAT4",
1044 "FLOAT8",
1045 "FORCE",
1046 "HOUR_MICROSECOND",
1047 "HOUR_MINUTE",
1048 "HOUR_SECOND",
1049 "IGNORE",
1050 "INFILE",
1051 "INT1",
1052 "INT2",
1053 "INT3",
1054 "INT4",
1055 "INT8",
1056 "ITERATE",
1057 "KEYS",
1058 "KILL",
1059 "LEADING",
1060 "LEAVE",
1061 "LINES",
1062 "LOAD",
1063 "LOCK",
1064 "LONG",
1065 "LONGBLOB",
1066 "LONGTEXT",
1067 "LOOP",
1068 "LOW_PRIORITY",
1069 "MATCH",
1070 "MEDIUMBLOB",
1071 "MEDIUMINT",
1072 "MEDIUMTEXT",
1073 "MINUTE_MICROSECOND",
1074 "MINUTE_SECOND",
1075 "MOD",
1076 "MODIFIES",
1077 "NO_WRITE_TO_BINLOG",
1078 "OPTIMIZE",
1079 "OPTIONALLY",
1080 "OUT",
1081 "OUTFILE",
1082 "PURGE",
1083 "READS",
1084 "REGEXP",
1085 "RELEASE",
1086 "RENAME",
1087 "REPEAT",
1088 "REQUIRE",
1089 "RESIGNAL",
1090 "RETURN",
1091 "REVOKE",
1092 "RLIKE",
1093 "SCHEMA",
1094 "SCHEMAS",
1095 "SECOND_MICROSECOND",
1096 "SENSITIVE",
1097 "SEPARATOR",
1098 "SHOW",
1099 "SIGNAL",
1100 "SPATIAL",
1101 "SQL",
1102 "SQLEXCEPTION",
1103 "SQLSTATE",
1104 "SQLWARNING",
1105 "SQL_BIG_RESULT",
1106 "SQL_CALC_FOUND_ROWS",
1107 "SQL_SMALL_RESULT",
1108 "SSL",
1109 "STARTING",
1110 "STRAIGHT_JOIN",
1111 "TERMINATED",
1112 "TINYBLOB",
1113 "TINYINT",
1114 "TINYTEXT",
1115 "TRAILING",
1116 "TRIGGER",
1117 "UNDO",
1118 "UNLOCK",
1119 "UNSIGNED",
1120 "USAGE",
1121 "UTC_DATE",
1122 "UTC_TIME",
1123 "UTC_TIMESTAMP",
1124 "VARBINARY",
1125 "VARCHARACTER",
1126 "WHILE",
1127 "WRITE",
1128 "XOR",
1129 "YEAR_MONTH",
1130 "ZEROFILL",
1131 ]
1132 .iter()
1133 .copied(),
1134 );
1135 }
1136 Some(DialectType::PostgreSQL) | Some(DialectType::CockroachDB) => {
1137 words.extend(
1138 [
1139 "ANALYSE",
1140 "ANALYZE",
1141 "ARRAY",
1142 "AUTHORIZATION",
1143 "BINARY",
1144 "BOTH",
1145 "COLLATE",
1146 "CONCURRENTLY",
1147 "DO",
1148 "FREEZE",
1149 "ILIKE",
1150 "INITIALLY",
1151 "ISNULL",
1152 "LATERAL",
1153 "LEADING",
1154 "LOCALTIME",
1155 "LOCALTIMESTAMP",
1156 "NOTNULL",
1157 "ONLY",
1158 "OVERLAPS",
1159 "PLACING",
1160 "SIMILAR",
1161 "SYMMETRIC",
1162 "TABLESAMPLE",
1163 "TRAILING",
1164 "VARIADIC",
1165 "VERBOSE",
1166 ]
1167 .iter()
1168 .copied(),
1169 );
1170 }
1171 Some(DialectType::BigQuery) => {
1172 words.extend(
1173 [
1174 "ASSERT_ROWS_MODIFIED",
1175 "COLLATE",
1176 "CONTAINS",
1177 "CUBE",
1178 "DEFINE",
1179 "ENUM",
1180 "EXTRACT",
1181 "FOLLOWING",
1182 "GROUPING",
1183 "GROUPS",
1184 "HASH",
1185 "IGNORE",
1186 "LATERAL",
1187 "LOOKUP",
1188 "MERGE",
1189 "NEW",
1190 "NO",
1191 "NULLS",
1192 "OF",
1193 "OVER",
1194 "PARTITION",
1195 "PRECEDING",
1196 "PROTO",
1197 "RANGE",
1198 "RECURSIVE",
1199 "RESPECT",
1200 "ROLLUP",
1201 "STRUCT",
1202 "TABLESAMPLE",
1203 "TREAT",
1204 "UNBOUNDED",
1205 "WITHIN",
1206 ]
1207 .iter()
1208 .copied(),
1209 );
1210 }
1211 Some(DialectType::Snowflake) => {
1212 words.extend(
1213 [
1214 "ACCOUNT",
1215 "BOTH",
1216 "CONNECT",
1217 "FOLLOWING",
1218 "ILIKE",
1219 "INCREMENT",
1220 "ISSUE",
1221 "LATERAL",
1222 "LEADING",
1223 "LOCALTIME",
1224 "LOCALTIMESTAMP",
1225 "MINUS",
1226 "QUALIFY",
1227 "REGEXP",
1228 "RLIKE",
1229 "SOME",
1230 "START",
1231 "TABLESAMPLE",
1232 "TOP",
1233 "TRAILING",
1234 "TRY_CAST",
1235 ]
1236 .iter()
1237 .copied(),
1238 );
1239 }
1240 Some(DialectType::TSQL) | Some(DialectType::Fabric) => {
1241 words.extend(
1242 [
1243 "BACKUP",
1244 "BREAK",
1245 "BROWSE",
1246 "BULK",
1247 "CASCADE",
1248 "CHECKPOINT",
1249 "CLOSE",
1250 "CLUSTERED",
1251 "COALESCE",
1252 "COMPUTE",
1253 "CONTAINS",
1254 "CONTAINSTABLE",
1255 "CONTINUE",
1256 "CONVERT",
1257 "DBCC",
1258 "DEALLOCATE",
1259 "DENY",
1260 "DISK",
1261 "DISTRIBUTED",
1262 "DUMP",
1263 "ERRLVL",
1264 "EXEC",
1265 "EXECUTE",
1266 "EXIT",
1267 "EXTERNAL",
1268 "FILE",
1269 "FILLFACTOR",
1270 "FREETEXT",
1271 "FREETEXTTABLE",
1272 "FUNCTION",
1273 "GOTO",
1274 "HOLDLOCK",
1275 "IDENTITY",
1276 "IDENTITYCOL",
1277 "IDENTITY_INSERT",
1278 "KILL",
1279 "LINENO",
1280 "MERGE",
1281 "NONCLUSTERED",
1282 "NULLIF",
1283 "OF",
1284 "OFF",
1285 "OFFSETS",
1286 "OPEN",
1287 "OPENDATASOURCE",
1288 "OPENQUERY",
1289 "OPENROWSET",
1290 "OPENXML",
1291 "OVER",
1292 "PERCENT",
1293 "PIVOT",
1294 "PLAN",
1295 "PRINT",
1296 "PROC",
1297 "PROCEDURE",
1298 "PUBLIC",
1299 "RAISERROR",
1300 "READ",
1301 "READTEXT",
1302 "RECONFIGURE",
1303 "REPLICATION",
1304 "RESTORE",
1305 "RESTRICT",
1306 "REVERT",
1307 "ROWCOUNT",
1308 "ROWGUIDCOL",
1309 "RULE",
1310 "SAVE",
1311 "SECURITYAUDIT",
1312 "SEMANTICKEYPHRASETABLE",
1313 "SEMANTICSIMILARITYDETAILSTABLE",
1314 "SEMANTICSIMILARITYTABLE",
1315 "SETUSER",
1316 "SHUTDOWN",
1317 "STATISTICS",
1318 "SYSTEM_USER",
1319 "TEXTSIZE",
1320 "TOP",
1321 "TRAN",
1322 "TRANSACTION",
1323 "TRIGGER",
1324 "TSEQUAL",
1325 "UNPIVOT",
1326 "UPDATETEXT",
1327 "WAITFOR",
1328 "WRITETEXT",
1329 ]
1330 .iter()
1331 .copied(),
1332 );
1333 }
1334 Some(DialectType::ClickHouse) => {
1335 words.extend(
1336 [
1337 "ANTI",
1338 "ARRAY",
1339 "ASOF",
1340 "FINAL",
1341 "FORMAT",
1342 "GLOBAL",
1343 "INF",
1344 "KILL",
1345 "MATERIALIZED",
1346 "NAN",
1347 "PREWHERE",
1348 "SAMPLE",
1349 "SEMI",
1350 "SETTINGS",
1351 "TOP",
1352 ]
1353 .iter()
1354 .copied(),
1355 );
1356 }
1357 Some(DialectType::DuckDB) => {
1358 words.extend(
1359 [
1360 "ANALYSE",
1361 "ANALYZE",
1362 "ARRAY",
1363 "BOTH",
1364 "LATERAL",
1365 "LEADING",
1366 "LOCALTIME",
1367 "LOCALTIMESTAMP",
1368 "PLACING",
1369 "QUALIFY",
1370 "SIMILAR",
1371 "TABLESAMPLE",
1372 "TRAILING",
1373 ]
1374 .iter()
1375 .copied(),
1376 );
1377 }
1378 Some(DialectType::Hive) | Some(DialectType::Spark) | Some(DialectType::Databricks) => {
1379 words.extend(
1380 [
1381 "BOTH",
1382 "CLUSTER",
1383 "DISTRIBUTE",
1384 "EXCHANGE",
1385 "EXTENDED",
1386 "FUNCTION",
1387 "LATERAL",
1388 "LEADING",
1389 "MACRO",
1390 "OVER",
1391 "PARTITION",
1392 "PERCENT",
1393 "RANGE",
1394 "READS",
1395 "REDUCE",
1396 "REGEXP",
1397 "REVOKE",
1398 "RLIKE",
1399 "ROLLUP",
1400 "SEMI",
1401 "SORT",
1402 "TABLESAMPLE",
1403 "TRAILING",
1404 "TRANSFORM",
1405 "UNBOUNDED",
1406 "UNIQUEJOIN",
1407 ]
1408 .iter()
1409 .copied(),
1410 );
1411 }
1412 Some(DialectType::Trino) | Some(DialectType::Presto) | Some(DialectType::Athena) => {
1413 words.extend(
1414 [
1415 "CUBE",
1416 "DEALLOCATE",
1417 "DESCRIBE",
1418 "EXECUTE",
1419 "EXTRACT",
1420 "GROUPING",
1421 "LATERAL",
1422 "LOCALTIME",
1423 "LOCALTIMESTAMP",
1424 "NORMALIZE",
1425 "PREPARE",
1426 "ROLLUP",
1427 "SOME",
1428 "TABLESAMPLE",
1429 "UESCAPE",
1430 "UNNEST",
1431 ]
1432 .iter()
1433 .copied(),
1434 );
1435 }
1436 Some(DialectType::Oracle) => {
1437 words.extend(
1438 [
1439 "ACCESS",
1440 "AUDIT",
1441 "CLUSTER",
1442 "COMMENT",
1443 "COMPRESS",
1444 "CONNECT",
1445 "EXCLUSIVE",
1446 "FILE",
1447 "IDENTIFIED",
1448 "IMMEDIATE",
1449 "INCREMENT",
1450 "INITIAL",
1451 "LEVEL",
1452 "LOCK",
1453 "LONG",
1454 "MAXEXTENTS",
1455 "MINUS",
1456 "MODE",
1457 "NOAUDIT",
1458 "NOCOMPRESS",
1459 "NOWAIT",
1460 "NUMBER",
1461 "OF",
1462 "OFFLINE",
1463 "ONLINE",
1464 "PCTFREE",
1465 "PRIOR",
1466 "RAW",
1467 "RENAME",
1468 "RESOURCE",
1469 "REVOKE",
1470 "SHARE",
1471 "SIZE",
1472 "START",
1473 "SUCCESSFUL",
1474 "SYNONYM",
1475 "SYSDATE",
1476 "TRIGGER",
1477 "UID",
1478 "VALIDATE",
1479 "VARCHAR2",
1480 "WHENEVER",
1481 ]
1482 .iter()
1483 .copied(),
1484 );
1485 }
1486 Some(DialectType::Redshift) => {
1487 words.extend(
1488 [
1489 "AZ64",
1490 "BZIP2",
1491 "DELTA",
1492 "DELTA32K",
1493 "DISTSTYLE",
1494 "ENCODE",
1495 "GZIP",
1496 "ILIKE",
1497 "LIMIT",
1498 "LUNS",
1499 "LZO",
1500 "LZOP",
1501 "MOSTLY13",
1502 "MOSTLY32",
1503 "MOSTLY8",
1504 "RAW",
1505 "SIMILAR",
1506 "SNAPSHOT",
1507 "SORTKEY",
1508 "SYSDATE",
1509 "TOP",
1510 "ZSTD",
1511 ]
1512 .iter()
1513 .copied(),
1514 );
1515 }
1516 _ => {
1517 words.extend(
1519 [
1520 "ANALYZE",
1521 "ARRAY",
1522 "BOTH",
1523 "CUBE",
1524 "GROUPING",
1525 "LATERAL",
1526 "LEADING",
1527 "LOCALTIME",
1528 "LOCALTIMESTAMP",
1529 "OVER",
1530 "PARTITION",
1531 "QUALIFY",
1532 "RANGE",
1533 "ROLLUP",
1534 "SIMILAR",
1535 "SOME",
1536 "TABLESAMPLE",
1537 "TRAILING",
1538 ]
1539 .iter()
1540 .copied(),
1541 );
1542 }
1543 }
1544
1545 words
1546}
1547
/// Decide whether the identifier text `name` has to be emitted quoted.
///
/// The empty string never needs quoting. Otherwise quoting is required when the
/// name starts with an ASCII digit, contains any byte other than ASCII letters,
/// ASCII digits or `_` (so any non-ASCII character forces quoting), or when its
/// upper-cased form appears in `reserved_words`.
fn needs_quoting(name: &str, reserved_words: &HashSet<&str>) -> bool {
    let Some(&leading) = name.as_bytes().first() else {
        return false;
    };

    let is_plain_word = !leading.is_ascii_digit()
        && name.bytes().all(|byte| byte == b'_' || byte.is_ascii_alphanumeric());

    !is_plain_word || reserved_words.contains(name.to_uppercase().as_str())
}
1574
1575fn maybe_quote(id: &mut Identifier, reserved_words: &HashSet<&str>) {
1577 if id.quoted || id.name.is_empty() || id.name == "*" {
1580 return;
1581 }
1582 if needs_quoting(&id.name, reserved_words) {
1583 id.quoted = true;
1584 }
1585}
1586
/// Recursively mark identifiers under `expr` as quoted when they need quoting.
///
/// Each identifier goes through `maybe_quote`, which skips identifiers that are
/// already quoted, empty, or `*`, and otherwise sets `quoted = true` when
/// `needs_quoting` reports it (leading ASCII digit, a byte outside
/// `[A-Za-z0-9_]`, or an upper-cased match in `reserved_words`). Nodes that own
/// child expressions are walked; joins, WITH clauses, window specs, DML target
/// tables and lateral views are delegated to `quote_join`, `quote_with`,
/// `quote_over`, `quote_table_ref` and `quote_lateral_view`.
///
/// Variants without an explicit arm reach the trailing `_ => {}` and are left
/// untouched, including any identifiers nested inside them.
fn quote_identifiers_recursive(expr: &mut Expression, reserved_words: &HashSet<&str>) {
    match expr {
        // A bare identifier.
        Expression::Identifier(id) => {
            maybe_quote(id, reserved_words);
        }

        // Column name plus its optional table qualifier.
        Expression::Column(col) => {
            maybe_quote(&mut col.name, reserved_words);
            if let Some(ref mut table) = col.table {
                maybe_quote(table, reserved_words);
            }
        }

        // Table reference expression: every name part, alias, column alias and
        // partition name, plus the expressions inside hints and the version clause.
        Expression::Table(table_ref) => {
            maybe_quote(&mut table_ref.name, reserved_words);
            if let Some(ref mut schema) = table_ref.schema {
                maybe_quote(schema, reserved_words);
            }
            if let Some(ref mut catalog) = table_ref.catalog {
                maybe_quote(catalog, reserved_words);
            }
            if let Some(ref mut alias) = table_ref.alias {
                maybe_quote(alias, reserved_words);
            }
            for ca in &mut table_ref.column_aliases {
                maybe_quote(ca, reserved_words);
            }
            for p in &mut table_ref.partitions {
                maybe_quote(p, reserved_words);
            }
            for h in &mut table_ref.hints {
                quote_identifiers_recursive(h, reserved_words);
            }
            if let Some(ref mut ver) = table_ref.version {
                quote_identifiers_recursive(&mut ver.this, reserved_words);
                if let Some(ref mut e) = ver.expression {
                    quote_identifiers_recursive(e, reserved_words);
                }
            }
        }

        // Star projection: optional table qualifier, plus the identifiers held in
        // its `except`, `replace` and `rename` lists.
        Expression::Star(star) => {
            if let Some(ref mut table) = star.table {
                maybe_quote(table, reserved_words);
            }
            if let Some(ref mut except_ids) = star.except {
                for id in except_ids {
                    maybe_quote(id, reserved_words);
                }
            }
            if let Some(ref mut replace_aliases) = star.replace {
                for alias in replace_aliases {
                    maybe_quote(&mut alias.alias, reserved_words);
                    quote_identifiers_recursive(&mut alias.this, reserved_words);
                }
            }
            if let Some(ref mut rename_pairs) = star.rename {
                for (from, to) in rename_pairs {
                    maybe_quote(from, reserved_words);
                    maybe_quote(to, reserved_words);
                }
            }
        }

        // Alias name, its column aliases, then the aliased expression.
        Expression::Alias(alias) => {
            maybe_quote(&mut alias.alias, reserved_words);
            for ca in &mut alias.column_aliases {
                maybe_quote(ca, reserved_words);
            }
            quote_identifiers_recursive(&mut alias.this, reserved_words);
        }

        // SELECT: every clause field on the node, including `prewhere`,
        // `limit_by`, `settings`, `format`, named windows and the WITH clause.
        Expression::Select(select) => {
            for e in &mut select.expressions {
                quote_identifiers_recursive(e, reserved_words);
            }
            if let Some(ref mut from) = select.from {
                for e in &mut from.expressions {
                    quote_identifiers_recursive(e, reserved_words);
                }
            }
            for join in &mut select.joins {
                quote_join(join, reserved_words);
            }
            for lv in &mut select.lateral_views {
                quote_lateral_view(lv, reserved_words);
            }
            if let Some(ref mut prewhere) = select.prewhere {
                quote_identifiers_recursive(prewhere, reserved_words);
            }
            if let Some(ref mut wh) = select.where_clause {
                quote_identifiers_recursive(&mut wh.this, reserved_words);
            }
            if let Some(ref mut gb) = select.group_by {
                for e in &mut gb.expressions {
                    quote_identifiers_recursive(e, reserved_words);
                }
            }
            if let Some(ref mut hv) = select.having {
                quote_identifiers_recursive(&mut hv.this, reserved_words);
            }
            if let Some(ref mut q) = select.qualify {
                quote_identifiers_recursive(&mut q.this, reserved_words);
            }
            if let Some(ref mut ob) = select.order_by {
                for o in &mut ob.expressions {
                    quote_identifiers_recursive(&mut o.this, reserved_words);
                }
            }
            if let Some(ref mut lim) = select.limit {
                quote_identifiers_recursive(&mut lim.this, reserved_words);
            }
            if let Some(ref mut off) = select.offset {
                quote_identifiers_recursive(&mut off.this, reserved_words);
            }
            if let Some(ref mut with) = select.with {
                quote_with(with, reserved_words);
            }
            if let Some(ref mut windows) = select.windows {
                // Named window definitions: window name plus its OVER spec.
                for nw in windows {
                    maybe_quote(&mut nw.name, reserved_words);
                    quote_over(&mut nw.spec, reserved_words);
                }
            }
            if let Some(ref mut distinct_on) = select.distinct_on {
                for e in distinct_on {
                    quote_identifiers_recursive(e, reserved_words);
                }
            }
            if let Some(ref mut limit_by) = select.limit_by {
                for e in limit_by {
                    quote_identifiers_recursive(e, reserved_words);
                }
            }
            if let Some(ref mut settings) = select.settings {
                for e in settings {
                    quote_identifiers_recursive(e, reserved_words);
                }
            }
            if let Some(ref mut format) = select.format {
                quote_identifiers_recursive(format, reserved_words);
            }
        }

        // Set operations: both operands, the attached WITH and ORDER BY.
        // The Union arm additionally visits `limit` and `offset`.
        Expression::Union(u) => {
            quote_identifiers_recursive(&mut u.left, reserved_words);
            quote_identifiers_recursive(&mut u.right, reserved_words);
            if let Some(ref mut with) = u.with {
                quote_with(with, reserved_words);
            }
            if let Some(ref mut ob) = u.order_by {
                for o in &mut ob.expressions {
                    quote_identifiers_recursive(&mut o.this, reserved_words);
                }
            }
            if let Some(ref mut lim) = u.limit {
                quote_identifiers_recursive(lim, reserved_words);
            }
            if let Some(ref mut off) = u.offset {
                quote_identifiers_recursive(off, reserved_words);
            }
        }
        // NOTE(review): unlike the Union arm, limit/offset are not visited for
        // Intersect/Except — confirm whether those nodes carry such fields.
        Expression::Intersect(i) => {
            quote_identifiers_recursive(&mut i.left, reserved_words);
            quote_identifiers_recursive(&mut i.right, reserved_words);
            if let Some(ref mut with) = i.with {
                quote_with(with, reserved_words);
            }
            if let Some(ref mut ob) = i.order_by {
                for o in &mut ob.expressions {
                    quote_identifiers_recursive(&mut o.this, reserved_words);
                }
            }
        }
        Expression::Except(e) => {
            quote_identifiers_recursive(&mut e.left, reserved_words);
            quote_identifiers_recursive(&mut e.right, reserved_words);
            if let Some(ref mut with) = e.with {
                quote_with(with, reserved_words);
            }
            if let Some(ref mut ob) = e.order_by {
                for o in &mut ob.expressions {
                    quote_identifiers_recursive(&mut o.this, reserved_words);
                }
            }
        }

        // Subquery: inner query, alias, column aliases and ORDER BY.
        Expression::Subquery(sq) => {
            quote_identifiers_recursive(&mut sq.this, reserved_words);
            if let Some(ref mut alias) = sq.alias {
                maybe_quote(alias, reserved_words);
            }
            for ca in &mut sq.column_aliases {
                maybe_quote(ca, reserved_words);
            }
            if let Some(ref mut ob) = sq.order_by {
                for o in &mut ob.expressions {
                    quote_identifiers_recursive(&mut o.this, reserved_words);
                }
            }
        }

        // INSERT: target table, column list, VALUES rows, source query,
        // partition spec, RETURNING, `on_conflict`, WITH and both aliases.
        Expression::Insert(ins) => {
            quote_table_ref(&mut ins.table, reserved_words);
            for c in &mut ins.columns {
                maybe_quote(c, reserved_words);
            }
            for row in &mut ins.values {
                for e in row {
                    quote_identifiers_recursive(e, reserved_words);
                }
            }
            if let Some(ref mut q) = ins.query {
                quote_identifiers_recursive(q, reserved_words);
            }
            for (id, val) in &mut ins.partition {
                maybe_quote(id, reserved_words);
                if let Some(ref mut v) = val {
                    quote_identifiers_recursive(v, reserved_words);
                }
            }
            for e in &mut ins.returning {
                quote_identifiers_recursive(e, reserved_words);
            }
            if let Some(ref mut on_conflict) = ins.on_conflict {
                quote_identifiers_recursive(on_conflict, reserved_words);
            }
            if let Some(ref mut with) = ins.with {
                quote_with(with, reserved_words);
            }
            if let Some(ref mut alias) = ins.alias {
                maybe_quote(alias, reserved_words);
            }
            if let Some(ref mut src_alias) = ins.source_alias {
                maybe_quote(src_alias, reserved_words);
            }
        }

        // UPDATE: target and extra tables, their joins, SET targets and values,
        // FROM clause and its joins, WHERE, RETURNING and WITH.
        Expression::Update(upd) => {
            quote_table_ref(&mut upd.table, reserved_words);
            for tr in &mut upd.extra_tables {
                quote_table_ref(tr, reserved_words);
            }
            for join in &mut upd.table_joins {
                quote_join(join, reserved_words);
            }
            for (id, val) in &mut upd.set {
                maybe_quote(id, reserved_words);
                quote_identifiers_recursive(val, reserved_words);
            }
            if let Some(ref mut from) = upd.from_clause {
                for e in &mut from.expressions {
                    quote_identifiers_recursive(e, reserved_words);
                }
            }
            for join in &mut upd.from_joins {
                quote_join(join, reserved_words);
            }
            if let Some(ref mut wh) = upd.where_clause {
                quote_identifiers_recursive(&mut wh.this, reserved_words);
            }
            for e in &mut upd.returning {
                quote_identifiers_recursive(e, reserved_words);
            }
            if let Some(ref mut with) = upd.with {
                quote_with(with, reserved_words);
            }
        }

        // DELETE: target table, alias, USING tables, WHERE and WITH.
        Expression::Delete(del) => {
            quote_table_ref(&mut del.table, reserved_words);
            if let Some(ref mut alias) = del.alias {
                maybe_quote(alias, reserved_words);
            }
            for tr in &mut del.using {
                quote_table_ref(tr, reserved_words);
            }
            if let Some(ref mut wh) = del.where_clause {
                quote_identifiers_recursive(&mut wh.this, reserved_words);
            }
            if let Some(ref mut with) = del.with {
                quote_with(with, reserved_words);
            }
        }

        // Binary operators that share a `left` / `right` operand layout.
        Expression::And(bin)
        | Expression::Or(bin)
        | Expression::Eq(bin)
        | Expression::Neq(bin)
        | Expression::Lt(bin)
        | Expression::Lte(bin)
        | Expression::Gt(bin)
        | Expression::Gte(bin)
        | Expression::Add(bin)
        | Expression::Sub(bin)
        | Expression::Mul(bin)
        | Expression::Div(bin)
        | Expression::Mod(bin)
        | Expression::BitwiseAnd(bin)
        | Expression::BitwiseOr(bin)
        | Expression::BitwiseXor(bin)
        | Expression::Concat(bin)
        | Expression::Adjacent(bin)
        | Expression::TsMatch(bin)
        | Expression::PropertyEQ(bin)
        | Expression::ArrayContainsAll(bin)
        | Expression::ArrayContainedBy(bin)
        | Expression::ArrayOverlaps(bin)
        | Expression::JSONBContainsAllTopKeys(bin)
        | Expression::JSONBContainsAnyTopKeys(bin)
        | Expression::JSONBDeleteAtPath(bin)
        | Expression::ExtendsLeft(bin)
        | Expression::ExtendsRight(bin)
        | Expression::Is(bin)
        | Expression::NullSafeEq(bin)
        | Expression::NullSafeNeq(bin)
        | Expression::Glob(bin)
        | Expression::Match(bin)
        | Expression::MemberOf(bin)
        | Expression::BitwiseLeftShift(bin)
        | Expression::BitwiseRightShift(bin) => {
            quote_identifiers_recursive(&mut bin.left, reserved_words);
            quote_identifiers_recursive(&mut bin.right, reserved_words);
        }

        // LIKE / ILIKE with an optional escape expression.
        Expression::Like(like) | Expression::ILike(like) => {
            quote_identifiers_recursive(&mut like.left, reserved_words);
            quote_identifiers_recursive(&mut like.right, reserved_words);
            if let Some(ref mut esc) = like.escape {
                quote_identifiers_recursive(esc, reserved_words);
            }
        }

        // Unary operators.
        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
            quote_identifiers_recursive(&mut un.this, reserved_words);
        }

        // IN: tested expression, literal list, subquery and unnest source.
        Expression::In(in_expr) => {
            quote_identifiers_recursive(&mut in_expr.this, reserved_words);
            for e in &mut in_expr.expressions {
                quote_identifiers_recursive(e, reserved_words);
            }
            if let Some(ref mut q) = in_expr.query {
                quote_identifiers_recursive(q, reserved_words);
            }
            if let Some(ref mut un) = in_expr.unnest {
                quote_identifiers_recursive(un, reserved_words);
            }
        }

        Expression::Between(bw) => {
            quote_identifiers_recursive(&mut bw.this, reserved_words);
            quote_identifiers_recursive(&mut bw.low, reserved_words);
            quote_identifiers_recursive(&mut bw.high, reserved_words);
        }

        Expression::IsNull(is_null) => {
            quote_identifiers_recursive(&mut is_null.this, reserved_words);
        }

        Expression::IsTrue(is_tf) | Expression::IsFalse(is_tf) => {
            quote_identifiers_recursive(&mut is_tf.this, reserved_words);
        }

        Expression::Exists(ex) => {
            quote_identifiers_recursive(&mut ex.this, reserved_words);
        }

        // Generic function call: only the arguments are visited; the function
        // name itself is not touched here.
        Expression::Function(func) => {
            for arg in &mut func.args {
                quote_identifiers_recursive(arg, reserved_words);
            }
        }

        // Aggregate call: arguments, FILTER expression and ORDER BY keys.
        Expression::AggregateFunction(agg) => {
            for arg in &mut agg.args {
                quote_identifiers_recursive(arg, reserved_words);
            }
            if let Some(ref mut filter) = agg.filter {
                quote_identifiers_recursive(filter, reserved_words);
            }
            for o in &mut agg.order_by {
                quote_identifiers_recursive(&mut o.this, reserved_words);
            }
        }

        // Window function: the function expression and its OVER spec.
        Expression::WindowFunction(wf) => {
            quote_identifiers_recursive(&mut wf.this, reserved_words);
            quote_over(&mut wf.over, reserved_words);
        }

        // CASE: optional operand, every WHEN/THEN pair and the ELSE branch.
        Expression::Case(case) => {
            if let Some(ref mut operand) = case.operand {
                quote_identifiers_recursive(operand, reserved_words);
            }
            for (when, then) in &mut case.whens {
                quote_identifiers_recursive(when, reserved_words);
                quote_identifiers_recursive(then, reserved_words);
            }
            if let Some(ref mut else_) = case.else_ {
                quote_identifiers_recursive(else_, reserved_words);
            }
        }

        // CAST family: source expression and optional format; the target
        // type is not visited.
        Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
            quote_identifiers_recursive(&mut cast.this, reserved_words);
            if let Some(ref mut fmt) = cast.format {
                quote_identifiers_recursive(fmt, reserved_words);
            }
        }

        // Wrapper nodes: recurse into the wrapped expression.
        Expression::Paren(paren) => {
            quote_identifiers_recursive(&mut paren.this, reserved_words);
        }

        Expression::Annotated(ann) => {
            quote_identifiers_recursive(&mut ann.this, reserved_words);
        }

        Expression::With(with) => {
            quote_with(with, reserved_words);
        }

        // Standalone CTE node: name, column list, body.
        // NOTE(review): `quote_with` also visits `cte.key_expressions`; this arm
        // does not — confirm whether that difference is intended.
        Expression::Cte(cte) => {
            maybe_quote(&mut cte.alias, reserved_words);
            for c in &mut cte.columns {
                maybe_quote(c, reserved_words);
            }
            quote_identifiers_recursive(&mut cte.this, reserved_words);
        }

        // Clause nodes that can appear as standalone expressions.
        Expression::From(from) => {
            for e in &mut from.expressions {
                quote_identifiers_recursive(e, reserved_words);
            }
        }

        Expression::Join(join) => {
            quote_join(join, reserved_words);
        }

        // Parenthesized join tree: left relation, each join, optional alias.
        Expression::JoinedTable(jt) => {
            quote_identifiers_recursive(&mut jt.left, reserved_words);
            for join in &mut jt.joins {
                quote_join(join, reserved_words);
            }
            if let Some(ref mut alias) = jt.alias {
                maybe_quote(alias, reserved_words);
            }
        }

        Expression::Where(wh) => {
            quote_identifiers_recursive(&mut wh.this, reserved_words);
        }

        Expression::GroupBy(gb) => {
            for e in &mut gb.expressions {
                quote_identifiers_recursive(e, reserved_words);
            }
        }

        Expression::Having(hv) => {
            quote_identifiers_recursive(&mut hv.this, reserved_words);
        }

        Expression::OrderBy(ob) => {
            for o in &mut ob.expressions {
                quote_identifiers_recursive(&mut o.this, reserved_words);
            }
        }

        Expression::Ordered(ord) => {
            quote_identifiers_recursive(&mut ord.this, reserved_words);
        }

        Expression::Limit(lim) => {
            quote_identifiers_recursive(&mut lim.this, reserved_words);
        }

        Expression::Offset(off) => {
            quote_identifiers_recursive(&mut off.this, reserved_words);
        }

        Expression::Qualify(q) => {
            quote_identifiers_recursive(&mut q.this, reserved_words);
        }

        // Window spec node: PARTITION BY expressions and ORDER BY keys.
        Expression::Window(ws) => {
            for e in &mut ws.partition_by {
                quote_identifiers_recursive(e, reserved_words);
            }
            for o in &mut ws.order_by {
                quote_identifiers_recursive(&mut o.this, reserved_words);
            }
        }

        Expression::Over(over) => {
            quote_over(over, reserved_words);
        }

        Expression::WithinGroup(wg) => {
            quote_identifiers_recursive(&mut wg.this, reserved_words);
            for o in &mut wg.order_by {
                quote_identifiers_recursive(&mut o.this, reserved_words);
            }
        }

        // PIVOT: source, `expressions`, `fields` and alias.
        Expression::Pivot(piv) => {
            quote_identifiers_recursive(&mut piv.this, reserved_words);
            for e in &mut piv.expressions {
                quote_identifiers_recursive(e, reserved_words);
            }
            for f in &mut piv.fields {
                quote_identifiers_recursive(f, reserved_words);
            }
            if let Some(ref mut alias) = piv.alias {
                maybe_quote(alias, reserved_words);
            }
        }

        // UNPIVOT: source, value/name column identifiers, column list, alias.
        Expression::Unpivot(unpiv) => {
            quote_identifiers_recursive(&mut unpiv.this, reserved_words);
            maybe_quote(&mut unpiv.value_column, reserved_words);
            maybe_quote(&mut unpiv.name_column, reserved_words);
            for e in &mut unpiv.columns {
                quote_identifiers_recursive(e, reserved_words);
            }
            if let Some(ref mut alias) = unpiv.alias {
                maybe_quote(alias, reserved_words);
            }
        }

        // VALUES: every tuple element, plus the alias and its column aliases.
        Expression::Values(vals) => {
            for tuple in &mut vals.expressions {
                for e in &mut tuple.expressions {
                    quote_identifiers_recursive(e, reserved_words);
                }
            }
            if let Some(ref mut alias) = vals.alias {
                maybe_quote(alias, reserved_words);
            }
            for ca in &mut vals.column_aliases {
                maybe_quote(ca, reserved_words);
            }
        }

        Expression::Array(arr) => {
            for e in &mut arr.expressions {
                quote_identifiers_recursive(e, reserved_words);
            }
        }

        // Struct literal: only field values are visited, field names are not.
        Expression::Struct(st) => {
            for (_name, e) in &mut st.fields {
                quote_identifiers_recursive(e, reserved_words);
            }
        }

        Expression::Tuple(tup) => {
            for e in &mut tup.expressions {
                quote_identifiers_recursive(e, reserved_words);
            }
        }

        // Subscript: base expression and index expression.
        Expression::Subscript(sub) => {
            quote_identifiers_recursive(&mut sub.this, reserved_words);
            quote_identifiers_recursive(&mut sub.index, reserved_words);
        }

        // Dot access: base expression and the field identifier.
        Expression::Dot(dot) => {
            quote_identifiers_recursive(&mut dot.this, reserved_words);
            maybe_quote(&mut dot.field, reserved_words);
        }

        Expression::ScopeResolution(sr) => {
            if let Some(ref mut this) = sr.this {
                quote_identifiers_recursive(this, reserved_words);
            }
            quote_identifiers_recursive(&mut sr.expression, reserved_words);
        }

        Expression::Lateral(lat) => {
            quote_identifiers_recursive(&mut lat.this, reserved_words);
        }

        Expression::DPipe(dpipe) => {
            quote_identifiers_recursive(&mut dpipe.this, reserved_words);
            quote_identifiers_recursive(&mut dpipe.expression, reserved_words);
        }

        // MERGE: target, source, ON condition, WHEN clauses, WITH and RETURNING.
        Expression::Merge(merge) => {
            quote_identifiers_recursive(&mut merge.this, reserved_words);
            quote_identifiers_recursive(&mut merge.using, reserved_words);
            if let Some(ref mut on) = merge.on {
                quote_identifiers_recursive(on, reserved_words);
            }
            if let Some(ref mut whens) = merge.whens {
                quote_identifiers_recursive(whens, reserved_words);
            }
            if let Some(ref mut with) = merge.with_ {
                quote_identifiers_recursive(with, reserved_words);
            }
            if let Some(ref mut ret) = merge.returning {
                quote_identifiers_recursive(ret, reserved_words);
            }
        }

        Expression::LateralView(lv) => {
            quote_lateral_view(lv, reserved_words);
        }

        // Anonymous function: the `this` expression and each argument.
        Expression::Anonymous(anon) => {
            quote_identifiers_recursive(&mut anon.this, reserved_words);
            for e in &mut anon.expressions {
                quote_identifiers_recursive(e, reserved_words);
            }
        }

        Expression::Filter(filter) => {
            quote_identifiers_recursive(&mut filter.this, reserved_words);
            quote_identifiers_recursive(&mut filter.expression, reserved_words);
        }

        Expression::Returning(ret) => {
            for e in &mut ret.expressions {
                quote_identifiers_recursive(e, reserved_words);
            }
        }

        Expression::BracedWildcard(inner) => {
            quote_identifiers_recursive(inner, reserved_words);
        }

        Expression::ReturnStmt(inner) => {
            quote_identifiers_recursive(inner, reserved_words);
        }

        // Leaf variants with nothing to quote. Listed explicitly for
        // documentation; the catch-all arm would ignore them as well.
        Expression::Literal(_)
        | Expression::Boolean(_)
        | Expression::Null(_)
        | Expression::DataType(_)
        | Expression::Raw(_)
        | Expression::Placeholder(_)
        | Expression::CurrentDate(_)
        | Expression::CurrentTime(_)
        | Expression::CurrentTimestamp(_)
        | Expression::CurrentTimestampLTZ(_)
        | Expression::SessionUser(_)
        | Expression::RowNumber(_)
        | Expression::Rank(_)
        | Expression::DenseRank(_)
        | Expression::PercentRank(_)
        | Expression::CumeDist(_)
        | Expression::Random(_)
        | Expression::Pi(_)
        | Expression::JSONPathRoot(_) => {}

        // Any other variant is left untouched (identifiers inside are not visited).
        _ => {}
    }
}
2285
2286fn quote_join(join: &mut Join, reserved_words: &HashSet<&str>) {
2288 quote_identifiers_recursive(&mut join.this, reserved_words);
2289 if let Some(ref mut on) = join.on {
2290 quote_identifiers_recursive(on, reserved_words);
2291 }
2292 for id in &mut join.using {
2293 maybe_quote(id, reserved_words);
2294 }
2295 if let Some(ref mut mc) = join.match_condition {
2296 quote_identifiers_recursive(mc, reserved_words);
2297 }
2298 for piv in &mut join.pivots {
2299 quote_identifiers_recursive(piv, reserved_words);
2300 }
2301}
2302
2303fn quote_with(with: &mut With, reserved_words: &HashSet<&str>) {
2305 for cte in &mut with.ctes {
2306 maybe_quote(&mut cte.alias, reserved_words);
2307 for c in &mut cte.columns {
2308 maybe_quote(c, reserved_words);
2309 }
2310 for k in &mut cte.key_expressions {
2311 maybe_quote(k, reserved_words);
2312 }
2313 quote_identifiers_recursive(&mut cte.this, reserved_words);
2314 }
2315}
2316
2317fn quote_over(over: &mut Over, reserved_words: &HashSet<&str>) {
2319 if let Some(ref mut wn) = over.window_name {
2320 maybe_quote(wn, reserved_words);
2321 }
2322 for e in &mut over.partition_by {
2323 quote_identifiers_recursive(e, reserved_words);
2324 }
2325 for o in &mut over.order_by {
2326 quote_identifiers_recursive(&mut o.this, reserved_words);
2327 }
2328 if let Some(ref mut alias) = over.alias {
2329 maybe_quote(alias, reserved_words);
2330 }
2331}
2332
2333fn quote_table_ref(table_ref: &mut TableRef, reserved_words: &HashSet<&str>) {
2335 maybe_quote(&mut table_ref.name, reserved_words);
2336 if let Some(ref mut schema) = table_ref.schema {
2337 maybe_quote(schema, reserved_words);
2338 }
2339 if let Some(ref mut catalog) = table_ref.catalog {
2340 maybe_quote(catalog, reserved_words);
2341 }
2342 if let Some(ref mut alias) = table_ref.alias {
2343 maybe_quote(alias, reserved_words);
2344 }
2345 for ca in &mut table_ref.column_aliases {
2346 maybe_quote(ca, reserved_words);
2347 }
2348 for p in &mut table_ref.partitions {
2349 maybe_quote(p, reserved_words);
2350 }
2351 for h in &mut table_ref.hints {
2352 quote_identifiers_recursive(h, reserved_words);
2353 }
2354}
2355
2356fn quote_lateral_view(lv: &mut LateralView, reserved_words: &HashSet<&str>) {
2358 quote_identifiers_recursive(&mut lv.this, reserved_words);
2359 if let Some(ref mut ta) = lv.table_alias {
2360 maybe_quote(ta, reserved_words);
2361 }
2362 for ca in &mut lv.column_aliases {
2363 maybe_quote(ca, reserved_words);
2364 }
2365}
2366
2367pub fn quote_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
2378 let reserved_words = get_reserved_words(dialect);
2379 let mut result = expression;
2380 quote_identifiers_recursive(&mut result, &reserved_words);
2381 result
2382}
2383
/// Scope-level hook for pushing CTE column aliases into CTE bodies.
///
/// NOTE(review): currently a no-op — the body is empty and the `&Scope` it
/// receives is immutable, so nothing can be rewritten through it. The actual
/// AST rewrite is done by `pushdown_cte_alias_columns_with`, which takes a
/// `&mut With`. Presumably kept for API parity; confirm with callers.
pub fn pushdown_cte_alias_columns(_scope: &Scope) {
}
2392
2393fn pushdown_cte_alias_columns_with(with: &mut With) {
2394 for cte in &mut with.ctes {
2395 if cte.columns.is_empty() {
2396 continue;
2397 }
2398
2399 if let Expression::Select(select) = &mut cte.this {
2400 let mut next_expressions = Vec::with_capacity(select.expressions.len());
2401
2402 for (i, projection) in select.expressions.iter().enumerate() {
2403 let Some(alias_name) = cte.columns.get(i) else {
2404 next_expressions.push(projection.clone());
2405 continue;
2406 };
2407
2408 match projection {
2409 Expression::Alias(existing) => {
2410 let mut aliased = existing.clone();
2411 aliased.alias = alias_name.clone();
2412 next_expressions.push(Expression::Alias(aliased));
2413 }
2414 _ => {
2415 next_expressions.push(create_alias(projection.clone(), &alias_name.name));
2416 }
2417 }
2418 }
2419
2420 select.expressions = next_expressions;
2421 }
2422 }
2423}
2424
2425fn get_scope_columns(scope: &Scope) -> Vec<ColumnRef> {
2431 let mut columns = Vec::new();
2432 collect_columns(&scope.expression, &mut columns);
2433 columns
2434}
2435
/// A column reference found while walking an expression tree: the column name
/// plus the name of the table qualifier it was written with, if any.
#[derive(Debug, Clone)]
struct ColumnRef {
    /// Qualifier name (`t` in `t.a`); `None` for an unqualified column.
    table: Option<String>,
    /// Column name (`a` in `t.a`).
    name: String,
}
2442
2443fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
2445 match expr {
2446 Expression::Column(col) => {
2447 columns.push(ColumnRef {
2448 table: col.table.as_ref().map(|t| t.name.clone()),
2449 name: col.name.name.clone(),
2450 });
2451 }
2452 Expression::Select(select) => {
2453 for e in &select.expressions {
2454 collect_columns(e, columns);
2455 }
2456 if let Some(from) = &select.from {
2457 for e in &from.expressions {
2458 collect_columns(e, columns);
2459 }
2460 }
2461 if let Some(where_clause) = &select.where_clause {
2462 collect_columns(&where_clause.this, columns);
2463 }
2464 if let Some(group_by) = &select.group_by {
2465 for e in &group_by.expressions {
2466 collect_columns(e, columns);
2467 }
2468 }
2469 if let Some(having) = &select.having {
2470 collect_columns(&having.this, columns);
2471 }
2472 if let Some(order_by) = &select.order_by {
2473 for o in &order_by.expressions {
2474 collect_columns(&o.this, columns);
2475 }
2476 }
2477 for join in &select.joins {
2478 collect_columns(&join.this, columns);
2479 if let Some(on) = &join.on {
2480 collect_columns(on, columns);
2481 }
2482 }
2483 }
2484 Expression::Alias(alias) => {
2485 collect_columns(&alias.this, columns);
2486 }
2487 Expression::Function(func) => {
2488 for arg in &func.args {
2489 collect_columns(arg, columns);
2490 }
2491 }
2492 Expression::AggregateFunction(agg) => {
2493 for arg in &agg.args {
2494 collect_columns(arg, columns);
2495 }
2496 }
2497 Expression::And(bin)
2498 | Expression::Or(bin)
2499 | Expression::Eq(bin)
2500 | Expression::Neq(bin)
2501 | Expression::Lt(bin)
2502 | Expression::Lte(bin)
2503 | Expression::Gt(bin)
2504 | Expression::Gte(bin)
2505 | Expression::Add(bin)
2506 | Expression::Sub(bin)
2507 | Expression::Mul(bin)
2508 | Expression::Div(bin) => {
2509 collect_columns(&bin.left, columns);
2510 collect_columns(&bin.right, columns);
2511 }
2512 Expression::Not(unary) | Expression::Neg(unary) => {
2513 collect_columns(&unary.this, columns);
2514 }
2515 Expression::Paren(paren) => {
2516 collect_columns(&paren.this, columns);
2517 }
2518 Expression::Case(case) => {
2519 if let Some(operand) = &case.operand {
2520 collect_columns(operand, columns);
2521 }
2522 for (when, then) in &case.whens {
2523 collect_columns(when, columns);
2524 collect_columns(then, columns);
2525 }
2526 if let Some(else_) = &case.else_ {
2527 collect_columns(else_, columns);
2528 }
2529 }
2530 Expression::Cast(cast) => {
2531 collect_columns(&cast.this, columns);
2532 }
2533 Expression::In(in_expr) => {
2534 collect_columns(&in_expr.this, columns);
2535 for e in &in_expr.expressions {
2536 collect_columns(e, columns);
2537 }
2538 if let Some(query) = &in_expr.query {
2539 collect_columns(query, columns);
2540 }
2541 }
2542 Expression::Between(between) => {
2543 collect_columns(&between.this, columns);
2544 collect_columns(&between.low, columns);
2545 collect_columns(&between.high, columns);
2546 }
2547 Expression::Subquery(subquery) => {
2548 collect_columns(&subquery.this, columns);
2549 }
2550 _ => {}
2551 }
2552}
2553
2554fn get_unqualified_columns(scope: &Scope) -> Vec<ColumnRef> {
2556 get_scope_columns(scope)
2557 .into_iter()
2558 .filter(|c| c.table.is_none())
2559 .collect()
2560}
2561
2562fn get_external_columns(scope: &Scope) -> Vec<ColumnRef> {
2564 let source_names: HashSet<_> = scope.sources.keys().cloned().collect();
2565
2566 get_scope_columns(scope)
2567 .into_iter()
2568 .filter(|c| {
2569 if let Some(table) = &c.table {
2570 !source_names.contains(table)
2571 } else {
2572 false
2573 }
2574 })
2575 .collect()
2576}
2577
2578fn is_correlated_subquery(scope: &Scope) -> bool {
2580 scope.can_be_correlated && !get_external_columns(scope).is_empty()
2581}
2582
2583fn is_star_column(col: &Column) -> bool {
2585 col.name.name == "*"
2586}
2587
2588fn create_qualified_column(name: &str, table: Option<&str>) -> Expression {
2590 Expression::boxed_column(Column {
2591 name: Identifier::new(name),
2592 table: table.map(Identifier::new),
2593 join_mark: false,
2594 trailing_comments: vec![],
2595 span: None,
2596 inferred_type: None,
2597 })
2598}
2599
2600fn create_alias(expr: Expression, alias_name: &str) -> Expression {
2602 Expression::Alias(Box::new(Alias {
2603 this: expr,
2604 alias: Identifier::new(alias_name),
2605 column_aliases: vec![],
2606 pre_alias_comments: vec![],
2607 trailing_comments: vec![],
2608 inferred_type: None,
2609 }))
2610}
2611
2612fn get_output_name(expr: &Expression) -> Option<String> {
2614 match expr {
2615 Expression::Column(col) => Some(col.name.name.clone()),
2616 Expression::Alias(alias) => Some(alias.alias.name.clone()),
2617 Expression::Identifier(id) => Some(id.name.clone()),
2618 _ => None,
2619 }
2620}
2621
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DataType;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::scope::build_scope;
    use crate::{MappingSchema, Schema};

    /// Renders an expression back to SQL with the default generator.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Builds a plain `Column` node: no join mark, comments, span, or inferred type.
    fn column(name: Identifier, table: Option<Identifier>) -> Column {
        Column {
            name,
            table,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        }
    }

    /// `DataType::BigInt` with no explicit length.
    fn bigint() -> DataType {
        DataType::BigInt { length: None }
    }

    /// `DataType::Int` with no explicit length and `integer_spelling: false`.
    fn int() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Builds a `MappingSchema` from `(table, [(column, type)])` fixtures,
    /// registering tables in the given order.
    fn schema_of(tables: Vec<(&str, Vec<(&str, DataType)>)>) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for (table, columns) in tables {
            let columns: Vec<(String, DataType)> = columns
                .into_iter()
                .map(|(name, data_type)| (name.to_string(), data_type))
                .collect();
            schema
                .add_table(table, &columns, None)
                .expect("schema setup");
        }
        schema
    }

    /// Schema shared by several USING tests: `x(a, b)` and `y(b, c)`, all BIGINT.
    fn xy_schema() -> MappingSchema {
        schema_of(vec![
            ("x", vec![("a", bigint()), ("b", bigint())]),
            ("y", vec![("b", bigint()), ("c", bigint())]),
        ])
    }

    /// Parses `sql`, runs `qualify_columns` with default options, and renders the result.
    fn qualify_sql(sql: &str, schema: &MappingSchema) -> String {
        let result = qualify_columns(parse(sql), schema, &QualifyColumnsOptions::new())
            .expect("qualify");
        gen(&result)
    }

    #[test]
    fn test_qualify_columns_options() {
        let options = QualifyColumnsOptions::new()
            .with_expand_alias_refs(true)
            .with_expand_stars(false)
            .with_dialect(DialectType::PostgreSQL)
            .with_allow_partial(true);

        assert!(options.expand_alias_refs);
        assert!(!options.expand_stars);
        assert_eq!(options.dialect, Some(DialectType::PostgreSQL));
        assert!(options.allow_partial_qualification);
    }

    #[test]
    fn test_get_scope_columns() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let scope = build_scope(&expr);
        let columns = get_scope_columns(&scope);

        assert!(columns.iter().any(|c| c.name == "a"));
        assert!(columns.iter().any(|c| c.name == "b"));
        assert!(columns.iter().any(|c| c.name == "c"));
    }

    #[test]
    fn test_get_unqualified_columns() {
        let expr = parse("SELECT t.a, b FROM t");
        let scope = build_scope(&expr);
        let unqualified = get_unqualified_columns(&scope);

        assert!(unqualified.iter().any(|c| c.name == "b"));
        // `t.a` already carries a table qualifier.
        assert!(!unqualified.iter().any(|c| c.name == "a"));
    }

    #[test]
    fn test_is_star_column() {
        let star = column(Identifier::new("*"), Some(Identifier::new("t")));
        assert!(is_star_column(&star));

        let plain = column(Identifier::new("id"), None);
        assert!(!is_star_column(&plain));
    }

    #[test]
    fn test_create_qualified_column() {
        let expr = create_qualified_column("id", Some("users"));
        let sql = gen(&expr);
        assert!(
            sql.contains("users.id"),
            "column should render as table-qualified: {sql}"
        );
    }

    #[test]
    fn test_create_alias() {
        let col = Expression::boxed_column(column(Identifier::new("value"), None));
        let sql = gen(&create_alias(col, "total"));
        assert!(
            sql.contains("AS total"),
            "alias should render as `AS total`: {sql}"
        );
    }

    #[test]
    fn test_validate_qualify_columns_success() {
        let expr = parse("SELECT t.a, t.b FROM t");
        // Smoke test: only checks that validation runs without panicking;
        // the Ok/Err outcome is not asserted here.
        let _ = validate_qualify_columns(&expr);
    }

    #[test]
    fn test_collect_columns_nested() {
        let expr = parse("SELECT a + b, c FROM t WHERE d > 0 GROUP BY e HAVING f = 1");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
        assert!(names.contains(&"d"));
        assert!(names.contains(&"e"));
        assert!(names.contains(&"f"));
    }

    #[test]
    fn test_collect_columns_in_case() {
        let expr = parse("SELECT CASE WHEN a = 1 THEN b ELSE c END FROM t");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
    }

    #[test]
    fn test_collect_columns_in_subquery() {
        let expr = parse("SELECT a FROM t WHERE b IN (SELECT c FROM s)");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
    }

    #[test]
    fn test_qualify_outputs_basic() {
        let expr = parse("SELECT a, b + c FROM t");
        let scope = build_scope(&expr);
        let result = qualify_outputs(&scope);
        assert!(result.is_ok());
    }

    #[test]
    fn test_qualify_columns_expands_star_with_schema() {
        let schema = schema_of(vec![(
            "users",
            vec![
                ("id", int()),
                ("name", DataType::Text),
                ("email", DataType::Text),
            ],
        )]);
        let sql = qualify_sql("SELECT * FROM users", &schema);

        assert!(!sql.contains("SELECT *"));
        assert!(sql.contains("users.id"));
        assert!(sql.contains("users.name"));
        assert!(sql.contains("users.email"));
    }

    #[test]
    fn test_qualify_columns_expands_group_by_positions() {
        let schema = schema_of(vec![("t", vec![("a", int()), ("b", int())])]);
        let sql = qualify_sql("SELECT a, b FROM t GROUP BY 1, 2", &schema);

        assert!(!sql.contains("GROUP BY 1"));
        assert!(!sql.contains("GROUP BY 2"));
        assert!(sql.contains("GROUP BY"));
        assert!(sql.contains("t.a"));
        assert!(sql.contains("t.b"));
    }

    #[test]
    fn test_expand_using_simple() {
        let sql = qualify_sql("SELECT x.b FROM x JOIN y USING (b)", &xy_schema());

        assert!(
            !sql.contains("USING"),
            "USING should be replaced with ON: {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "ON condition should be x.b = y.b: {sql}"
        );
        assert!(sql.contains("SELECT x.b"), "SELECT should keep x.b: {sql}");
    }

    #[test]
    fn test_expand_using_unqualified_coalesce() {
        let sql = qualify_sql("SELECT b FROM x JOIN y USING(b)", &xy_schema());

        assert!(
            sql.contains("COALESCE(x.b, y.b)"),
            "Unqualified USING column should become COALESCE: {sql}"
        );
        assert!(
            sql.contains("AS b"),
            "COALESCE should be aliased as 'b': {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "ON condition should be generated: {sql}"
        );
    }

    #[test]
    fn test_expand_using_with_where() {
        let schema = schema_of(vec![
            ("x", vec![("b", bigint())]),
            ("y", vec![("b", bigint())]),
        ]);
        let sql = qualify_sql("SELECT b FROM x JOIN y USING(b) WHERE b = 1", &schema);

        assert!(
            sql.contains("WHERE COALESCE(x.b, y.b)"),
            "WHERE should use COALESCE for USING column: {sql}"
        );
    }

    #[test]
    fn test_expand_using_multi_join() {
        let schema = schema_of(vec![
            ("x", vec![("b", bigint())]),
            ("y", vec![("b", bigint())]),
            ("z", vec![("b", bigint())]),
        ]);
        let sql = qualify_sql("SELECT b FROM x JOIN y USING(b) JOIN z USING(b)", &schema);

        assert!(
            sql.contains("COALESCE(x.b, y.b, z.b)"),
            "Should have 3-table COALESCE: {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "First join ON condition: {sql}"
        );
    }

    #[test]
    fn test_expand_using_multi_column() {
        let schema = schema_of(vec![
            ("y", vec![("b", bigint()), ("c", bigint())]),
            ("z", vec![("b", bigint()), ("c", bigint())]),
        ]);
        let sql = qualify_sql("SELECT b, c FROM y JOIN z USING(b, c)", &schema);

        assert!(
            sql.contains("COALESCE(y.b, z.b)"),
            "column 'b' should get COALESCE: {sql}"
        );
        assert!(
            sql.contains("COALESCE(y.c, z.c)"),
            "column 'c' should get COALESCE: {sql}"
        );
        assert!(
            sql.contains("y.b = z.b") && sql.contains("y.c = z.c"),
            "ON should have both equality conditions: {sql}"
        );
    }

    #[test]
    fn test_expand_using_star() {
        let sql = qualify_sql("SELECT * FROM x JOIN y USING(b)", &xy_schema());

        assert!(
            sql.contains("COALESCE(x.b, y.b) AS b"),
            "USING column should be COALESCE in star expansion: {sql}"
        );
        assert!(sql.contains("x.a"), "non-USING column a from x: {sql}");
        assert!(sql.contains("y.c"), "non-USING column c from y: {sql}");
        // The shared USING column must be emitted once, not once per side.
        let coalesce_count = sql.matches("COALESCE").count();
        assert_eq!(
            coalesce_count, 1,
            "b should appear only once as COALESCE: {sql}"
        );
    }

    #[test]
    fn test_expand_using_table_star() {
        let sql = qualify_sql("SELECT x.* FROM x JOIN y USING(b)", &xy_schema());

        assert!(
            sql.contains("COALESCE(x.b, y.b)"),
            "USING column from x.* should become COALESCE: {sql}"
        );
        assert!(sql.contains("x.a"), "non-USING column a: {sql}");
    }

    #[test]
    fn test_qualify_columns_qualified_table_name() {
        let schema = schema_of(vec![("raw.t1", vec![("a", bigint())])]);

        let sql = qualify_sql("SELECT a FROM raw.t1", &schema);
        assert!(
            sql.contains("t1.a"),
            "column should be qualified with table name: {sql}"
        );

        // Columns nested inside function calls must be qualified the same way.
        for query in ["SELECT MAX(a) FROM raw.t1", "SELECT ABS(a) FROM raw.t1"] {
            let sql = qualify_sql(query, &schema);
            assert!(
                sql.contains("t1.a"),
                "column in function should be qualified with table name: {sql}"
            );
        }
    }

    #[test]
    fn test_qualify_columns_count_star() {
        let schema = schema_of(vec![("t1", vec![("id", bigint())])]);
        let sql = qualify_sql("SELECT COUNT(*) FROM t1", &schema);

        assert!(
            sql.contains("COUNT(*)"),
            "COUNT(*) should be preserved: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_correlated_scalar_subquery() {
        let schema = schema_of(vec![
            ("t1", vec![("id", bigint())]),
            ("t2", vec![("id", bigint()), ("val", bigint())]),
        ]);
        let sql = qualify_sql(
            "SELECT id, (SELECT AVG(val) FROM t2 WHERE t2.id = t1.id) AS avg_val FROM t1",
            &schema,
        );

        assert!(
            sql.contains("t1.id"),
            "outer column should be qualified: {sql}"
        );
        assert!(
            sql.contains("t2.id"),
            "inner column should be qualified: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_correlated_scalar_subquery_unqualified() {
        let schema = schema_of(vec![
            ("t1", vec![("t1_id", bigint())]),
            ("t2", vec![("t2_id", bigint()), ("val", bigint())]),
        ]);
        let sql = qualify_sql(
            "SELECT t1_id, (SELECT AVG(val) FROM t2 WHERE t2_id = t1_id) AS avg_val FROM t1",
            &schema,
        );

        assert!(
            sql.contains("t1.t1_id"),
            "outer column should be qualified: {sql}"
        );
        assert!(
            sql.contains("t2.t2_id"),
            "inner column should be qualified: {sql}"
        );
        // `t1_id` inside the subquery resolves to the outer table.
        assert!(
            sql.contains("= t1.t1_id"),
            "correlated column should be qualified: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_correlated_exists_subquery() {
        let schema = schema_of(vec![
            (
                "orders",
                vec![
                    ("o_orderpriority", DataType::Text),
                    ("o_orderkey", bigint()),
                ],
            ),
            ("lineitem", vec![("l_orderkey", bigint())]),
        ]);
        let sql = qualify_sql(
            "SELECT o_orderpriority FROM orders \
             WHERE EXISTS (SELECT * FROM lineitem WHERE l_orderkey = o_orderkey)",
            &schema,
        );

        assert!(
            sql.contains("orders.o_orderpriority"),
            "outer column should be qualified: {sql}"
        );
        assert!(
            sql.contains("lineitem.l_orderkey"),
            "inner column should be qualified: {sql}"
        );
        assert!(
            sql.contains("orders.o_orderkey"),
            "correlated outer column should be qualified: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_rejects_unknown_table() {
        let schema = schema_of(vec![("t1", vec![("id", bigint())])]);
        let expr = parse("SELECT id FROM t1 WHERE nonexistent.col = 1");

        let result = qualify_columns(expr, &schema, &QualifyColumnsOptions::new());
        assert!(
            result.is_err(),
            "should reject reference to table not in scope or schema"
        );
    }

    #[test]
    fn test_needs_quoting_reserved_word() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("select", &reserved));
        assert!(needs_quoting("SELECT", &reserved));
        assert!(needs_quoting("from", &reserved));
        assert!(needs_quoting("WHERE", &reserved));
        assert!(needs_quoting("join", &reserved));
        assert!(needs_quoting("table", &reserved));
    }

    #[test]
    fn test_needs_quoting_normal_identifiers() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("foo", &reserved));
        assert!(!needs_quoting("my_column", &reserved));
        assert!(!needs_quoting("col1", &reserved));
        assert!(!needs_quoting("A", &reserved));
        assert!(!needs_quoting("_hidden", &reserved));
    }

    #[test]
    fn test_needs_quoting_special_characters() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("my column", &reserved));
        assert!(needs_quoting("my-column", &reserved));
        assert!(needs_quoting("my.column", &reserved));
        assert!(needs_quoting("col@name", &reserved));
        assert!(needs_quoting("col#name", &reserved));
    }

    #[test]
    fn test_needs_quoting_starts_with_digit() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("1col", &reserved));
        assert!(needs_quoting("123", &reserved));
        assert!(needs_quoting("0_start", &reserved));
    }

    #[test]
    fn test_needs_quoting_empty() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("", &reserved));
    }

    #[test]
    fn test_maybe_quote_sets_quoted_flag() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("select");
        assert!(!id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
    }

    #[test]
    fn test_maybe_quote_skips_already_quoted() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::quoted("myname");
        assert!(id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
        assert_eq!(id.name, "myname");
    }

    #[test]
    fn test_maybe_quote_skips_star() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("*");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    #[test]
    fn test_maybe_quote_skips_normal() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("normal_col");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    #[test]
    fn test_quote_identifiers_column_with_reserved_name() {
        let expr = Expression::boxed_column(column(Identifier::new("select"), None));
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column named 'select' should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_column_with_special_chars() {
        let expr = Expression::boxed_column(column(Identifier::new("my column"), None));
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column with space should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_preserves_normal_column() {
        let expr = Expression::boxed_column(column(
            Identifier::new("normal_col"),
            Some(Identifier::new("my_table")),
        ));
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(!col.name.quoted, "Normal column should not be quoted");
            assert!(
                !col.table.as_ref().unwrap().quoted,
                "Normal table should not be quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_table_ref_reserved() {
        let expr = Expression::Table(Box::new(TableRef::new("select")));
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(tr.name.quoted, "Table named 'select' should be quoted");
        } else {
            panic!("Expected Table expression");
        }
    }

    #[test]
    fn test_quote_identifiers_table_ref_schema_and_alias() {
        let mut tr = TableRef::new("my_table");
        tr.schema = Some(Identifier::new("from"));
        tr.alias = Some(Identifier::new("t"));
        let expr = Expression::Table(Box::new(tr));
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(!tr.name.quoted, "Normal table name should not be quoted");
            assert!(
                tr.schema.as_ref().unwrap().quoted,
                "Schema named 'from' should be quoted"
            );
            assert!(
                !tr.alias.as_ref().unwrap().quoted,
                "Normal alias should not be quoted"
            );
        } else {
            panic!("Expected Table expression");
        }
    }

    #[test]
    fn test_quote_identifiers_identifier_node() {
        let expr = Expression::Identifier(Identifier::new("order"));
        let result = quote_identifiers(expr, None);
        if let Expression::Identifier(id) = &result {
            assert!(id.quoted, "Identifier named 'order' should be quoted");
        } else {
            panic!("Expected Identifier expression");
        }
    }

    #[test]
    fn test_quote_identifiers_alias() {
        let inner = Expression::boxed_column(column(Identifier::new("val"), None));
        let expr = Expression::Alias(Box::new(Alias {
            this: inner,
            alias: Identifier::new("select"),
            column_aliases: vec![Identifier::new("from")],
            pre_alias_comments: vec![],
            trailing_comments: vec![],
            inferred_type: None,
        }));
        let result = quote_identifiers(expr, None);
        if let Expression::Alias(alias) = &result {
            assert!(alias.alias.quoted, "Alias named 'select' should be quoted");
            assert!(
                alias.column_aliases[0].quoted,
                "Column alias named 'from' should be quoted"
            );
            // The aliased expression itself has an ordinary name and stays unquoted.
            match &alias.this {
                Expression::Column(col) => assert!(!col.name.quoted),
                _ => panic!("Expected Column inside Alias"),
            }
        } else {
            panic!("Expected Alias expression");
        }
    }

    #[test]
    fn test_quote_identifiers_select_recursive() {
        // Smoke test: quoting a full SELECT must succeed and keep plain names.
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let result = quote_identifiers(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("a"));
        assert!(sql.contains("b"));
        assert!(sql.contains("t"));
    }

    #[test]
    fn test_quote_identifiers_digit_start() {
        let expr = Expression::boxed_column(column(Identifier::new("1col"), None));
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(
                col.name.quoted,
                "Column starting with digit should be quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_with_mysql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::MySQL));
        assert!(needs_quoting("KILL", &reserved));
        assert!(needs_quoting("FORCE", &reserved));
    }

    #[test]
    fn test_quote_identifiers_with_postgresql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::PostgreSQL));
        assert!(needs_quoting("ILIKE", &reserved));
        assert!(needs_quoting("VERBOSE", &reserved));
    }

    #[test]
    fn test_quote_identifiers_with_bigquery_dialect() {
        let reserved = get_reserved_words(Some(DialectType::BigQuery));
        assert!(needs_quoting("STRUCT", &reserved));
        assert!(needs_quoting("PROTO", &reserved));
    }

    #[test]
    fn test_quote_identifiers_case_insensitive_reserved() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("Select", &reserved));
        assert!(needs_quoting("sElEcT", &reserved));
        assert!(needs_quoting("FROM", &reserved));
        assert!(needs_quoting("from", &reserved));
    }

    #[test]
    fn test_quote_identifiers_join_using() {
        let mut join = crate::expressions::Join {
            this: Expression::Table(Box::new(TableRef::new("other"))),
            on: None,
            using: vec![Identifier::new("key"), Identifier::new("value")],
            kind: crate::expressions::JoinKind::Inner,
            use_inner_keyword: false,
            use_outer_keyword: false,
            deferred_condition: false,
            join_hint: None,
            match_condition: None,
            pivots: vec![],
            comments: vec![],
            nesting_group: 0,
            directed: false,
        };
        let reserved = get_reserved_words(None);
        quote_join(&mut join, &reserved);
        assert!(
            join.using[0].quoted,
            "USING identifier 'key' should be quoted"
        );
        assert!(
            !join.using[1].quoted,
            "USING identifier 'value' should not be quoted"
        );
    }

    #[test]
    fn test_quote_identifiers_cte() {
        let mut cte = crate::expressions::Cte {
            alias: Identifier::new("select"),
            this: Expression::boxed_column(column(Identifier::new("x"), None)),
            columns: vec![Identifier::new("from"), Identifier::new("normal")],
            materialized: None,
            key_expressions: vec![],
            alias_first: false,
            comments: Vec::new(),
        };
        let reserved = get_reserved_words(None);
        maybe_quote(&mut cte.alias, &reserved);
        for c in &mut cte.columns {
            maybe_quote(c, &reserved);
        }
        assert!(cte.alias.quoted, "CTE alias 'select' should be quoted");
        assert!(cte.columns[0].quoted, "CTE column 'from' should be quoted");
        assert!(
            !cte.columns[1].quoted,
            "CTE column 'normal' should not be quoted"
        );
    }

    #[test]
    fn test_quote_identifiers_binary_ops_recurse() {
        let expr = Expression::Add(Box::new(crate::expressions::BinaryOp::new(
            Expression::boxed_column(column(Identifier::new("select"), None)),
            Expression::boxed_column(column(Identifier::new("normal"), None)),
        )));
        let result = quote_identifiers(expr, None);
        let bin = match &result {
            Expression::Add(bin) => bin,
            _ => panic!("Expected Add expression"),
        };
        // Both operands must still be columns; fail loudly instead of skipping the check.
        match &bin.left {
            Expression::Column(left) => assert!(
                left.name.quoted,
                "'select' column should be quoted in binary op"
            ),
            _ => panic!("Expected Column as left operand"),
        }
        match &bin.right {
            Expression::Column(right) => {
                assert!(!right.name.quoted, "'normal' column should not be quoted")
            }
            _ => panic!("Expected Column as right operand"),
        }
    }

    #[test]
    fn test_quote_identifiers_already_quoted_preserved() {
        let expr = Expression::boxed_column(column(Identifier::quoted("normal_name"), None));
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(
                col.name.quoted,
                "Already-quoted identifier should remain quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_full_parsed_query() {
        let mut select = crate::expressions::Select::new();
        select.expressions.push(Expression::boxed_column(column(
            Identifier::new("order"),
            Some(Identifier::new("t")),
        )));
        select.from = Some(crate::expressions::From {
            expressions: vec![Expression::Table(Box::new(TableRef::new("t")))],
        });
        let expr = Expression::Select(Box::new(select));

        let result = quote_identifiers(expr, None);
        if let Expression::Select(sel) = &result {
            if let Expression::Column(col) = &sel.expressions[0] {
                assert!(col.name.quoted, "Column named 'order' should be quoted");
                assert!(
                    !col.table.as_ref().unwrap().quoted,
                    "Table 't' should not be quoted"
                );
            } else {
                panic!("Expected Column in SELECT list");
            }
        } else {
            panic!("Expected Select expression");
        }
    }

    #[test]
    fn test_get_reserved_words_all_dialects() {
        let dialects = [
            None,
            Some(DialectType::Generic),
            Some(DialectType::MySQL),
            Some(DialectType::PostgreSQL),
            Some(DialectType::BigQuery),
            Some(DialectType::Snowflake),
            Some(DialectType::TSQL),
            Some(DialectType::ClickHouse),
            Some(DialectType::DuckDB),
            Some(DialectType::Hive),
            Some(DialectType::Spark),
            Some(DialectType::Trino),
            Some(DialectType::Oracle),
            Some(DialectType::Redshift),
        ];
        for dialect in &dialects {
            let words = get_reserved_words(*dialect);
            assert!(
                words.contains("SELECT"),
                "All dialects should have SELECT as reserved"
            );
            assert!(
                words.contains("FROM"),
                "All dialects should have FROM as reserved"
            );
        }
    }
}