1use crate::dialects::transform_recursive;
9use crate::dialects::DialectType;
10use crate::expressions::{
11 Alias, BinaryOp, Column, Expression, Identifier, Join, LateralView, Literal, Over, Paren,
12 Select, TableRef, VarArgFunc, With,
13};
14use crate::resolver::{Resolver, ResolverError};
15use crate::schema::{normalize_name, Schema};
16use crate::scope::{build_scope, traverse_scope, Scope};
17use std::cell::RefCell;
18use std::collections::{HashMap, HashSet};
19use thiserror::Error;
20
/// Errors produced while qualifying or validating column references.
#[derive(Debug, Error, Clone)]
pub enum QualifyColumnsError {
    /// A table qualifier names neither a source of the current scope nor a
    /// table known to the schema, or a `table.*` projection names a table that
    /// is not among the FROM/JOIN sources.
    #[error("Unknown table: {0}")]
    UnknownTable(String),

    /// A column is not exposed by the table it is qualified with, or an
    /// unqualified column could not be attributed to any source (only raised
    /// when partial qualification is not allowed).
    #[error("Unknown column: {0}")]
    UnknownColumn(String),

    /// Raised by `validate_qualify_columns` when unqualified columns remain
    /// after qualification.
    #[error("Ambiguous column: {0}")]
    AmbiguousColumn(String),

    /// Within this section it is only produced by wrapping failures reported by
    /// `transform_recursive` itself.
    /// NOTE(review): the name suggests a USING-join failure; confirm intent.
    #[error("Cannot automatically join: {0}")]
    CannotAutoJoin(String),

    /// A positional reference (e.g. `GROUP BY 3`) is 0 or past the end of the
    /// projection list.
    #[error("Unknown output column: {0}")]
    UnknownOutputColumn(String),

    /// Raised by `validate_qualify_columns` for an external column in a scope
    /// that is not a correlated subquery. `for_table` is either empty or a
    /// pre-formatted ` for table: '<name>'` suffix.
    #[error("Column could not be resolved: {column}{for_table}")]
    ColumnNotResolved { column: String, for_table: String },

    /// Error propagated from the resolver.
    #[error("Resolver error: {0}")]
    ResolverError(#[from] ResolverError),
}

/// Result alias used throughout the column-qualification pass.
pub type QualifyColumnsResult<T> = Result<T, QualifyColumnsError>;
48
49#[derive(Debug, Clone, Default)]
51pub struct QualifyColumnsOptions {
52 pub expand_alias_refs: bool,
54 pub expand_stars: bool,
56 pub infer_schema: Option<bool>,
58 pub allow_partial_qualification: bool,
60 pub dialect: Option<DialectType>,
62}
63
64impl QualifyColumnsOptions {
65 pub fn new() -> Self {
67 Self {
68 expand_alias_refs: true,
69 expand_stars: true,
70 infer_schema: None,
71 allow_partial_qualification: false,
72 dialect: None,
73 }
74 }
75
76 pub fn with_expand_alias_refs(mut self, expand: bool) -> Self {
78 self.expand_alias_refs = expand;
79 self
80 }
81
82 pub fn with_expand_stars(mut self, expand: bool) -> Self {
84 self.expand_stars = expand;
85 self
86 }
87
88 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
90 self.dialect = Some(dialect);
91 self
92 }
93
94 pub fn with_allow_partial(mut self, allow: bool) -> Self {
96 self.allow_partial_qualification = allow;
97 self
98 }
99}
100
101pub fn qualify_columns(
116 expression: Expression,
117 schema: &dyn Schema,
118 options: &QualifyColumnsOptions,
119) -> QualifyColumnsResult<Expression> {
120 let infer_schema = options.infer_schema.unwrap_or(schema.is_empty());
121 let dialect = options.dialect.or_else(|| schema.dialect());
122 let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
123
124 let transformed = transform_recursive(expression, &|node| {
125 if first_error.borrow().is_some() {
126 return Ok(node);
127 }
128
129 match node {
130 Expression::Select(mut select) => {
131 if let Some(with) = &mut select.with {
132 pushdown_cte_alias_columns_with(with);
133 }
134
135 let scope_expr = Expression::Select(select.clone());
136 let scope = build_scope(&scope_expr);
137 let mut resolver = Resolver::new(&scope, schema, infer_schema);
138
139 let column_tables = if first_error.borrow().is_none() {
141 match expand_using(&mut select, &scope, &mut resolver) {
142 Ok(ct) => ct,
143 Err(err) => {
144 *first_error.borrow_mut() = Some(err);
145 HashMap::new()
146 }
147 }
148 } else {
149 HashMap::new()
150 };
151
152 if first_error.borrow().is_none() {
154 if let Err(err) = qualify_columns_in_scope(
155 &mut select,
156 &scope,
157 &mut resolver,
158 options.allow_partial_qualification,
159 ) {
160 *first_error.borrow_mut() = Some(err);
161 }
162 }
163
164 if first_error.borrow().is_none() && options.expand_alias_refs {
166 if let Err(err) = expand_alias_refs(&mut select, &mut resolver, dialect) {
167 *first_error.borrow_mut() = Some(err);
168 }
169 }
170
171 if first_error.borrow().is_none() && options.expand_stars {
173 if let Err(err) =
174 expand_stars(&mut select, &scope, &mut resolver, &column_tables)
175 {
176 *first_error.borrow_mut() = Some(err);
177 }
178 }
179
180 if first_error.borrow().is_none() {
182 if let Err(err) = qualify_outputs_select(&mut select) {
183 *first_error.borrow_mut() = Some(err);
184 }
185 }
186
187 if first_error.borrow().is_none() {
189 if let Err(err) = expand_group_by(&mut select, dialect) {
190 *first_error.borrow_mut() = Some(err);
191 }
192 }
193
194 Ok(Expression::Select(select))
195 }
196 _ => Ok(node),
197 }
198 })
199 .map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
200
201 if let Some(err) = first_error.into_inner() {
202 return Err(err);
203 }
204
205 Ok(transformed)
206}
207
208pub fn validate_qualify_columns(expression: &Expression) -> QualifyColumnsResult<()> {
213 let mut all_unqualified = Vec::new();
214
215 for scope in traverse_scope(expression) {
216 if let Expression::Select(_) = &scope.expression {
217 let unqualified = get_unqualified_columns(&scope);
219
220 let external = get_external_columns(&scope);
222 if !external.is_empty() && !is_correlated_subquery(&scope) {
223 let first = &external[0];
224 let for_table = if first.table.is_some() {
225 format!(" for table: '{}'", first.table.as_ref().unwrap())
226 } else {
227 String::new()
228 };
229 return Err(QualifyColumnsError::ColumnNotResolved {
230 column: first.name.clone(),
231 for_table,
232 });
233 }
234
235 all_unqualified.extend(unqualified);
236 }
237 }
238
239 if !all_unqualified.is_empty() {
240 let first = &all_unqualified[0];
241 return Err(QualifyColumnsError::AmbiguousColumn(first.name.clone()));
242 }
243
244 Ok(())
245}
246
247fn get_source_name(expr: &Expression) -> Option<String> {
249 match expr {
250 Expression::Table(t) => Some(
251 t.alias
252 .as_ref()
253 .map(|a| a.name.clone())
254 .unwrap_or_else(|| t.name.name.clone()),
255 ),
256 Expression::Subquery(sq) => sq.alias.as_ref().map(|a| a.name.clone()),
257 _ => None,
258 }
259}
260
261fn get_ordered_source_names(select: &Select) -> Vec<String> {
264 let mut ordered = Vec::new();
265 if let Some(from) = &select.from {
266 for expr in &from.expressions {
267 if let Some(name) = get_source_name(expr) {
268 ordered.push(name);
269 }
270 }
271 }
272 for join in &select.joins {
273 if let Some(name) = get_source_name(&join.this) {
274 ordered.push(name);
275 }
276 }
277 ordered
278}
279
280fn make_coalesce(column_name: &str, tables: &[String]) -> Expression {
282 let args: Vec<Expression> = tables
283 .iter()
284 .map(|t| Expression::qualified_column(t.as_str(), column_name))
285 .collect();
286 Expression::Coalesce(Box::new(VarArgFunc {
287 expressions: args,
288 original_name: None,
289 inferred_type: None,
290 }))
291}
292
/// Rewrites `JOIN ... USING (...)` clauses into explicit `ON` equality conditions.
///
/// For every join with a USING list, each listed column `c` becomes
/// `<lhs> = <join_table>.c`, and the conditions are AND-ed into `join.on` while
/// `join.using` is cleared. `<lhs>` is:
/// - `<table>.c`, where `<table>` is the first already-seen source exposing `c`
///   (falling back to the last source seen before this join), for the first
///   join or when the USING list has a single column;
/// - otherwise `COALESCE(s1.c, s2.c, ...)` over all earlier sources exposing
///   `c`, when there is more than one such source.
///
/// Afterwards, unqualified references to USING columns in the projections,
/// WHERE, GROUP BY, HAVING, QUALIFY and ORDER BY are replaced by
/// `COALESCE(t1.c, t2.c, ...)`. A bare projected USING column is additionally
/// aliased back to its own name.
///
/// Returns a map from USING column name to the ordered, de-duplicated list of
/// tables it was joined across. SEMI/ANTI joins do not contribute. The map is
/// later passed to `expand_stars`. The current implementation never returns
/// `Err`.
///
/// NOTE(review): the join table's columns are fetched into `_join_columns` but
/// never used. USING columns missing from either side therefore silently
/// produce an ON condition instead of a "cannot automatically join" error.
/// If no FROM source remains, `source_table` is the empty string and is used
/// as a qualifier; confirm whether that case should be rejected.
fn expand_using(
    select: &mut Select,
    _scope: &Scope,
    resolver: &mut Resolver,
) -> QualifyColumnsResult<HashMap<String, Vec<String>>> {
    // column name -> first source (in query order) that exposes it.
    let mut columns: HashMap<String, String> = HashMap::new();

    // USING column name -> tables it was joined across (returned to the caller).
    let mut column_tables: HashMap<String, Vec<String>> = HashMap::new();

    let join_names: HashSet<String> = select
        .joins
        .iter()
        .filter_map(|j| get_source_name(&j.this))
        .collect();

    // Sources that are not themselves join targets (i.e. the FROM side).
    let all_ordered = get_ordered_source_names(select);
    let mut ordered: Vec<String> = all_ordered
        .iter()
        .filter(|name| !join_names.contains(name.as_str()))
        .cloned()
        .collect();

    if join_names.is_empty() {
        return Ok(column_tables);
    }

    // Records `source_name` as the provider of each of its columns, unless an
    // earlier source already claimed that column name.
    fn update_source_columns(
        source_name: &str,
        columns: &mut HashMap<String, String>,
        resolver: &mut Resolver,
    ) {
        if let Ok(source_cols) = resolver.get_source_columns(source_name) {
            for col_name in source_cols {
                columns
                    .entry(col_name)
                    .or_insert_with(|| source_name.to_string());
            }
        }
    }

    for source_name in &ordered {
        update_source_columns(source_name, &mut columns, resolver);
    }

    for i in 0..select.joins.len() {
        // The most recent source (previous join target, or last FROM source);
        // registering it makes its columns visible to this join's USING list.
        let source_table = ordered.last().cloned().unwrap_or_default();
        if !source_table.is_empty() {
            update_source_columns(&source_table, &mut columns, resolver);
        }

        let join_table = get_source_name(&select.joins[i].this).unwrap_or_default();
        ordered.push(join_table.clone());

        if select.joins[i].using.is_empty() {
            continue;
        }

        let _join_columns: Vec<String> =
            resolver.get_source_columns(&join_table).unwrap_or_default();

        let using_identifiers: Vec<String> = select.joins[i]
            .using
            .iter()
            .map(|id| id.name.clone())
            .collect();

        let using_count = using_identifiers.len();
        let is_semi_or_anti = matches!(
            select.joins[i].kind,
            crate::expressions::JoinKind::Semi
                | crate::expressions::JoinKind::Anti
                | crate::expressions::JoinKind::LeftSemi
                | crate::expressions::JoinKind::LeftAnti
                | crate::expressions::JoinKind::RightSemi
                | crate::expressions::JoinKind::RightAnti
        );

        let mut conditions: Vec<Expression> = Vec::new();

        for identifier in &using_identifiers {
            let table = columns
                .get(identifier)
                .cloned()
                .unwrap_or_else(|| source_table.clone());

            let lhs = if i == 0 || using_count == 1 {
                Expression::qualified_column(table.as_str(), identifier.as_str())
            } else {
                // Every earlier source (excluding the join table just pushed)
                // that exposes this column.
                let coalesce_cols: Vec<String> = ordered[..ordered.len() - 1]
                    .iter()
                    .filter(|t| {
                        resolver
                            .get_source_columns(t)
                            .unwrap_or_default()
                            .contains(identifier)
                    })
                    .cloned()
                    .collect();

                if coalesce_cols.len() > 1 {
                    make_coalesce(identifier, &coalesce_cols)
                } else {
                    Expression::qualified_column(table.as_str(), identifier.as_str())
                }
            };

            let rhs = Expression::qualified_column(join_table.as_str(), identifier.as_str());

            conditions.push(Expression::Eq(Box::new(BinaryOp::new(lhs, rhs))));

            // SEMI/ANTI joins do not expose the right side, so they must not
            // turn later references into COALESCE over both tables.
            if !is_semi_or_anti {
                let tables = column_tables
                    .entry(identifier.clone())
                    .or_insert_with(Vec::new);
                if !tables.contains(&table) {
                    tables.push(table.clone());
                }
                if !tables.contains(&join_table) {
                    tables.push(join_table.clone());
                }
            }
        }

        // `using` was checked to be non-empty above, so there is at least one condition.
        let on_condition = conditions
            .into_iter()
            .reduce(|acc, cond| Expression::And(Box::new(BinaryOp::new(acc, cond))))
            .expect("at least one USING column");

        select.joins[i].on = Some(on_condition);
        select.joins[i].using = vec![];
    }

    if !column_tables.is_empty() {
        let mut new_expressions = Vec::with_capacity(select.expressions.len());
        for expr in &select.expressions {
            match expr {
                // A bare projected USING column keeps its output name via an alias.
                Expression::Column(col)
                    if col.table.is_none() && column_tables.contains_key(&col.name.name) =>
                {
                    let tables = &column_tables[&col.name.name];
                    let coalesce = make_coalesce(&col.name.name, tables);
                    new_expressions.push(Expression::Alias(Box::new(Alias {
                        this: coalesce,
                        alias: Identifier::new(&col.name.name),
                        column_aliases: vec![],
                        pre_alias_comments: vec![],
                        trailing_comments: vec![],
                        inferred_type: None,
                    })));
                }
                _ => {
                    let mut rewritten = expr.clone();
                    rewrite_using_columns_in_expression(&mut rewritten, &column_tables);
                    new_expressions.push(rewritten);
                }
            }
        }
        select.expressions = new_expressions;

        if let Some(where_clause) = &mut select.where_clause {
            rewrite_using_columns_in_expression(&mut where_clause.this, &column_tables);
        }

        if let Some(group_by) = &mut select.group_by {
            for expr in &mut group_by.expressions {
                rewrite_using_columns_in_expression(expr, &column_tables);
            }
        }

        if let Some(having) = &mut select.having {
            rewrite_using_columns_in_expression(&mut having.this, &column_tables);
        }

        if let Some(qualify) = &mut select.qualify {
            rewrite_using_columns_in_expression(&mut qualify.this, &column_tables);
        }

        if let Some(order_by) = &mut select.order_by {
            for ordered in &mut order_by.expressions {
                rewrite_using_columns_in_expression(&mut ordered.this, &column_tables);
            }
        }
    }

    Ok(column_tables)
}
508
509fn rewrite_using_columns_in_expression(
511 expr: &mut Expression,
512 column_tables: &HashMap<String, Vec<String>>,
513) {
514 let transformed = transform_recursive(expr.clone(), &|node| match node {
515 Expression::Column(col)
516 if col.table.is_none() && column_tables.contains_key(&col.name.name) =>
517 {
518 let tables = &column_tables[&col.name.name];
519 Ok(make_coalesce(&col.name.name, tables))
520 }
521 other => Ok(other),
522 });
523
524 if let Ok(next) = transformed {
525 *expr = next;
526 }
527}
528
529fn qualify_columns_in_scope(
531 select: &mut Select,
532 scope: &Scope,
533 resolver: &mut Resolver,
534 allow_partial: bool,
535) -> QualifyColumnsResult<()> {
536 for expr in &mut select.expressions {
537 qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
538 }
539 if let Some(where_clause) = &mut select.where_clause {
540 qualify_columns_in_expression(&mut where_clause.this, scope, resolver, allow_partial)?;
541 }
542 if let Some(group_by) = &mut select.group_by {
543 for expr in &mut group_by.expressions {
544 qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
545 }
546 }
547 if let Some(having) = &mut select.having {
548 qualify_columns_in_expression(&mut having.this, scope, resolver, allow_partial)?;
549 }
550 if let Some(qualify) = &mut select.qualify {
551 qualify_columns_in_expression(&mut qualify.this, scope, resolver, allow_partial)?;
552 }
553 if let Some(order_by) = &mut select.order_by {
554 for ordered in &mut order_by.expressions {
555 qualify_columns_in_expression(&mut ordered.this, scope, resolver, allow_partial)?;
556 }
557 }
558 for join in &mut select.joins {
559 qualify_columns_in_expression(&mut join.this, scope, resolver, allow_partial)?;
560 if let Some(on) = &mut join.on {
561 qualify_columns_in_expression(on, scope, resolver, allow_partial)?;
562 }
563 }
564 Ok(())
565}
566
567fn expand_alias_refs(
574 select: &mut Select,
575 _resolver: &mut Resolver,
576 _dialect: Option<DialectType>,
577) -> QualifyColumnsResult<()> {
578 let mut alias_to_expression: HashMap<String, (Expression, usize)> = HashMap::new();
579
580 for (i, expr) in select.expressions.iter_mut().enumerate() {
581 replace_alias_refs_in_expression(expr, &alias_to_expression, false);
582 if let Expression::Alias(alias) = expr {
583 alias_to_expression.insert(alias.alias.name.clone(), (alias.this.clone(), i + 1));
584 }
585 }
586
587 if let Some(where_clause) = &mut select.where_clause {
588 replace_alias_refs_in_expression(&mut where_clause.this, &alias_to_expression, false);
589 }
590 if let Some(group_by) = &mut select.group_by {
591 for expr in &mut group_by.expressions {
592 replace_alias_refs_in_expression(expr, &alias_to_expression, true);
593 }
594 }
595 if let Some(having) = &mut select.having {
596 replace_alias_refs_in_expression(&mut having.this, &alias_to_expression, false);
597 }
598 if let Some(qualify) = &mut select.qualify {
599 replace_alias_refs_in_expression(&mut qualify.this, &alias_to_expression, false);
600 }
601 if let Some(order_by) = &mut select.order_by {
602 for ordered in &mut order_by.expressions {
603 replace_alias_refs_in_expression(&mut ordered.this, &alias_to_expression, false);
604 }
605 }
606
607 Ok(())
608}
609
610fn expand_group_by(select: &mut Select, _dialect: Option<DialectType>) -> QualifyColumnsResult<()> {
617 let projections = select.expressions.clone();
618
619 if let Some(group_by) = &mut select.group_by {
620 for group_expr in &mut group_by.expressions {
621 if let Some(index) = positional_reference(group_expr) {
622 let replacement = select_expression_at_position(&projections, index)?;
623 *group_expr = replacement;
624 }
625 }
626 }
627 Ok(())
628}
629
630fn expand_stars(
640 select: &mut Select,
641 _scope: &Scope,
642 resolver: &mut Resolver,
643 column_tables: &HashMap<String, Vec<String>>,
644) -> QualifyColumnsResult<()> {
645 let mut new_selections: Vec<Expression> = Vec::new();
646 let mut has_star = false;
647 let mut coalesced_columns: HashSet<String> = HashSet::new();
648
649 let ordered_sources = get_ordered_source_names(select);
651
652 for expr in &select.expressions {
653 match expr {
654 Expression::Star(_) => {
655 has_star = true;
656 for source_name in &ordered_sources {
657 if let Ok(columns) = resolver.get_source_columns(source_name) {
658 if columns.contains(&"*".to_string()) || columns.is_empty() {
659 return Ok(());
660 }
661 for col_name in &columns {
662 if coalesced_columns.contains(col_name) {
663 continue;
665 }
666 if let Some(tables) = column_tables.get(col_name) {
667 if tables.contains(source_name) {
668 coalesced_columns.insert(col_name.clone());
670 let coalesce = make_coalesce(col_name, tables);
671 new_selections.push(Expression::Alias(Box::new(Alias {
672 this: coalesce,
673 alias: Identifier::new(col_name),
674 column_aliases: vec![],
675 pre_alias_comments: vec![],
676 trailing_comments: vec![],
677 inferred_type: None,
678 })));
679 continue;
680 }
681 }
682 new_selections
683 .push(create_qualified_column(col_name, Some(source_name)));
684 }
685 }
686 }
687 }
688 Expression::Column(col) if is_star_column(col) => {
689 has_star = true;
690 if let Some(table) = &col.table {
691 let table_name = &table.name;
692 if !ordered_sources.contains(table_name) {
693 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
694 }
695 if let Ok(columns) = resolver.get_source_columns(table_name) {
696 if columns.contains(&"*".to_string()) || columns.is_empty() {
697 return Ok(());
698 }
699 for col_name in &columns {
700 if coalesced_columns.contains(col_name) {
701 continue;
702 }
703 if let Some(tables) = column_tables.get(col_name) {
704 if tables.contains(table_name) {
705 coalesced_columns.insert(col_name.clone());
706 let coalesce = make_coalesce(col_name, tables);
707 new_selections.push(Expression::Alias(Box::new(Alias {
708 this: coalesce,
709 alias: Identifier::new(col_name),
710 column_aliases: vec![],
711 pre_alias_comments: vec![],
712 trailing_comments: vec![],
713 inferred_type: None,
714 })));
715 continue;
716 }
717 }
718 new_selections
719 .push(create_qualified_column(col_name, Some(table_name)));
720 }
721 }
722 }
723 }
724 _ => new_selections.push(expr.clone()),
725 }
726 }
727
728 if has_star {
729 select.expressions = new_selections;
730 }
731
732 Ok(())
733}
734
735pub fn qualify_outputs(scope: &Scope) -> QualifyColumnsResult<()> {
742 if let Expression::Select(mut select) = scope.expression.clone() {
743 qualify_outputs_select(&mut select)?;
744 }
745 Ok(())
746}
747
748fn qualify_outputs_select(select: &mut Select) -> QualifyColumnsResult<()> {
749 let mut new_selections: Vec<Expression> = Vec::new();
750
751 for (i, expr) in select.expressions.iter().enumerate() {
752 match expr {
753 Expression::Alias(_) => new_selections.push(expr.clone()),
754 Expression::Column(col) => {
755 new_selections.push(create_alias(expr.clone(), &col.name.name));
756 }
757 Expression::Star(_) => new_selections.push(expr.clone()),
758 _ => {
759 let alias_name = get_output_name(expr).unwrap_or_else(|| format!("_col_{}", i));
760 new_selections.push(create_alias(expr.clone(), &alias_name));
761 }
762 }
763 }
764
765 select.expressions = new_selections;
766 Ok(())
767}
768
769fn qualify_columns_in_expression(
770 expr: &mut Expression,
771 scope: &Scope,
772 resolver: &mut Resolver,
773 allow_partial: bool,
774) -> QualifyColumnsResult<()> {
775 let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
776 let resolver_cell: RefCell<&mut Resolver> = RefCell::new(resolver);
777
778 let transformed = transform_recursive(expr.clone(), &|node| {
779 if first_error.borrow().is_some() {
780 return Ok(node);
781 }
782
783 match node {
784 Expression::Column(mut col) => {
785 if let Err(err) = qualify_single_column(
786 &mut col,
787 scope,
788 &mut resolver_cell.borrow_mut(),
789 allow_partial,
790 ) {
791 *first_error.borrow_mut() = Some(err);
792 }
793 Ok(Expression::Column(col))
794 }
795 _ => Ok(node),
796 }
797 })
798 .map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
799
800 if let Some(err) = first_error.into_inner() {
801 return Err(err);
802 }
803
804 *expr = transformed;
805 Ok(())
806}
807
808fn qualify_single_column(
809 col: &mut Column,
810 scope: &Scope,
811 resolver: &mut Resolver,
812 allow_partial: bool,
813) -> QualifyColumnsResult<()> {
814 if is_star_column(col) {
815 return Ok(());
816 }
817
818 if let Some(table) = &col.table {
819 let table_name = &table.name;
820 if !scope.sources.contains_key(table_name) {
821 if resolver.table_exists_in_schema(table_name) {
825 return Ok(());
826 }
827 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
828 }
829
830 if let Ok(source_columns) = resolver.get_source_columns(table_name) {
831 let normalized_column_name = normalize_column_name(&col.name.name, resolver.dialect);
832 if !allow_partial
833 && !source_columns.is_empty()
834 && !source_columns.iter().any(|column| {
835 normalize_column_name(column, resolver.dialect) == normalized_column_name
836 })
837 && !source_columns.contains(&"*".to_string())
838 {
839 return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
840 }
841 }
842 return Ok(());
843 }
844
845 if let Some(table_name) = resolver.get_table(&col.name.name) {
846 col.table = Some(Identifier::new(table_name));
847 return Ok(());
848 }
849
850 if !allow_partial {
851 return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
852 }
853
854 Ok(())
855}
856
/// Normalizes `name` through the crate's `normalize_name` so that query column
/// names and schema column names can be compared for equality (used by
/// `qualify_single_column`).
///
/// NOTE(review): the two boolean arguments are passed through verbatim; their
/// meaning is defined by `normalize_name` — confirm there.
fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
    normalize_name(name, dialect, false, true)
}
860
861fn replace_alias_refs_in_expression(
862 expr: &mut Expression,
863 alias_to_expression: &HashMap<String, (Expression, usize)>,
864 literal_index: bool,
865) {
866 let transformed = transform_recursive(expr.clone(), &|node| match node {
867 Expression::Column(col) if col.table.is_none() => {
868 if let Some((alias_expr, index)) = alias_to_expression.get(&col.name.name) {
869 if literal_index && matches!(alias_expr, Expression::Literal(_)) {
870 return Ok(Expression::number(*index as i64));
871 }
872 return Ok(Expression::Paren(Box::new(Paren {
873 this: alias_expr.clone(),
874 trailing_comments: vec![],
875 })));
876 }
877 Ok(Expression::Column(col))
878 }
879 other => Ok(other),
880 });
881
882 if let Ok(next) = transformed {
883 *expr = next;
884 }
885}
886
887fn positional_reference(expr: &Expression) -> Option<usize> {
888 match expr {
889 Expression::Literal(Literal::Number(value)) => value.parse::<usize>().ok(),
890 _ => None,
891 }
892}
893
894fn select_expression_at_position(
895 projections: &[Expression],
896 index: usize,
897) -> QualifyColumnsResult<Expression> {
898 if index == 0 || index > projections.len() {
899 return Err(QualifyColumnsError::UnknownOutputColumn(index.to_string()));
900 }
901
902 let projection = projections[index - 1].clone();
903 Ok(match projection {
904 Expression::Alias(alias) => alias.this.clone(),
905 other => other,
906 })
907}
908
909fn get_reserved_words(dialect: Option<DialectType>) -> HashSet<&'static str> {
912 let mut words: HashSet<&'static str> = [
914 "ADD",
916 "ALL",
917 "ALTER",
918 "AND",
919 "ANY",
920 "AS",
921 "ASC",
922 "BETWEEN",
923 "BY",
924 "CASE",
925 "CAST",
926 "CHECK",
927 "COLUMN",
928 "CONSTRAINT",
929 "CREATE",
930 "CROSS",
931 "CURRENT",
932 "CURRENT_DATE",
933 "CURRENT_TIME",
934 "CURRENT_TIMESTAMP",
935 "CURRENT_USER",
936 "DATABASE",
937 "DEFAULT",
938 "DELETE",
939 "DESC",
940 "DISTINCT",
941 "DROP",
942 "ELSE",
943 "END",
944 "ESCAPE",
945 "EXCEPT",
946 "EXISTS",
947 "FALSE",
948 "FETCH",
949 "FOR",
950 "FOREIGN",
951 "FROM",
952 "FULL",
953 "GRANT",
954 "GROUP",
955 "HAVING",
956 "IF",
957 "IN",
958 "INDEX",
959 "INNER",
960 "INSERT",
961 "INTERSECT",
962 "INTO",
963 "IS",
964 "JOIN",
965 "KEY",
966 "LEFT",
967 "LIKE",
968 "LIMIT",
969 "NATURAL",
970 "NOT",
971 "NULL",
972 "OFFSET",
973 "ON",
974 "OR",
975 "ORDER",
976 "OUTER",
977 "PRIMARY",
978 "REFERENCES",
979 "REPLACE",
980 "RETURNING",
981 "RIGHT",
982 "ROLLBACK",
983 "ROW",
984 "ROWS",
985 "SELECT",
986 "SESSION_USER",
987 "SET",
988 "SOME",
989 "TABLE",
990 "THEN",
991 "TO",
992 "TRUE",
993 "TRUNCATE",
994 "UNION",
995 "UNIQUE",
996 "UPDATE",
997 "USING",
998 "VALUES",
999 "VIEW",
1000 "WHEN",
1001 "WHERE",
1002 "WINDOW",
1003 "WITH",
1004 ]
1005 .iter()
1006 .copied()
1007 .collect();
1008
1009 match dialect {
1011 Some(DialectType::MySQL) => {
1012 words.extend(
1013 [
1014 "ANALYZE",
1015 "BOTH",
1016 "CHANGE",
1017 "CONDITION",
1018 "DATABASES",
1019 "DAY_HOUR",
1020 "DAY_MICROSECOND",
1021 "DAY_MINUTE",
1022 "DAY_SECOND",
1023 "DELAYED",
1024 "DETERMINISTIC",
1025 "DIV",
1026 "DUAL",
1027 "EACH",
1028 "ELSEIF",
1029 "ENCLOSED",
1030 "EXPLAIN",
1031 "FLOAT4",
1032 "FLOAT8",
1033 "FORCE",
1034 "HOUR_MICROSECOND",
1035 "HOUR_MINUTE",
1036 "HOUR_SECOND",
1037 "IGNORE",
1038 "INFILE",
1039 "INT1",
1040 "INT2",
1041 "INT3",
1042 "INT4",
1043 "INT8",
1044 "ITERATE",
1045 "KEYS",
1046 "KILL",
1047 "LEADING",
1048 "LEAVE",
1049 "LINES",
1050 "LOAD",
1051 "LOCK",
1052 "LONG",
1053 "LONGBLOB",
1054 "LONGTEXT",
1055 "LOOP",
1056 "LOW_PRIORITY",
1057 "MATCH",
1058 "MEDIUMBLOB",
1059 "MEDIUMINT",
1060 "MEDIUMTEXT",
1061 "MINUTE_MICROSECOND",
1062 "MINUTE_SECOND",
1063 "MOD",
1064 "MODIFIES",
1065 "NO_WRITE_TO_BINLOG",
1066 "OPTIMIZE",
1067 "OPTIONALLY",
1068 "OUT",
1069 "OUTFILE",
1070 "PURGE",
1071 "READS",
1072 "REGEXP",
1073 "RELEASE",
1074 "RENAME",
1075 "REPEAT",
1076 "REQUIRE",
1077 "RESIGNAL",
1078 "RETURN",
1079 "REVOKE",
1080 "RLIKE",
1081 "SCHEMA",
1082 "SCHEMAS",
1083 "SECOND_MICROSECOND",
1084 "SENSITIVE",
1085 "SEPARATOR",
1086 "SHOW",
1087 "SIGNAL",
1088 "SPATIAL",
1089 "SQL",
1090 "SQLEXCEPTION",
1091 "SQLSTATE",
1092 "SQLWARNING",
1093 "SQL_BIG_RESULT",
1094 "SQL_CALC_FOUND_ROWS",
1095 "SQL_SMALL_RESULT",
1096 "SSL",
1097 "STARTING",
1098 "STRAIGHT_JOIN",
1099 "TERMINATED",
1100 "TINYBLOB",
1101 "TINYINT",
1102 "TINYTEXT",
1103 "TRAILING",
1104 "TRIGGER",
1105 "UNDO",
1106 "UNLOCK",
1107 "UNSIGNED",
1108 "USAGE",
1109 "UTC_DATE",
1110 "UTC_TIME",
1111 "UTC_TIMESTAMP",
1112 "VARBINARY",
1113 "VARCHARACTER",
1114 "WHILE",
1115 "WRITE",
1116 "XOR",
1117 "YEAR_MONTH",
1118 "ZEROFILL",
1119 ]
1120 .iter()
1121 .copied(),
1122 );
1123 }
1124 Some(DialectType::PostgreSQL) | Some(DialectType::CockroachDB) => {
1125 words.extend(
1126 [
1127 "ANALYSE",
1128 "ANALYZE",
1129 "ARRAY",
1130 "AUTHORIZATION",
1131 "BINARY",
1132 "BOTH",
1133 "COLLATE",
1134 "CONCURRENTLY",
1135 "DO",
1136 "FREEZE",
1137 "ILIKE",
1138 "INITIALLY",
1139 "ISNULL",
1140 "LATERAL",
1141 "LEADING",
1142 "LOCALTIME",
1143 "LOCALTIMESTAMP",
1144 "NOTNULL",
1145 "ONLY",
1146 "OVERLAPS",
1147 "PLACING",
1148 "SIMILAR",
1149 "SYMMETRIC",
1150 "TABLESAMPLE",
1151 "TRAILING",
1152 "VARIADIC",
1153 "VERBOSE",
1154 ]
1155 .iter()
1156 .copied(),
1157 );
1158 }
1159 Some(DialectType::BigQuery) => {
1160 words.extend(
1161 [
1162 "ASSERT_ROWS_MODIFIED",
1163 "COLLATE",
1164 "CONTAINS",
1165 "CUBE",
1166 "DEFINE",
1167 "ENUM",
1168 "EXTRACT",
1169 "FOLLOWING",
1170 "GROUPING",
1171 "GROUPS",
1172 "HASH",
1173 "IGNORE",
1174 "LATERAL",
1175 "LOOKUP",
1176 "MERGE",
1177 "NEW",
1178 "NO",
1179 "NULLS",
1180 "OF",
1181 "OVER",
1182 "PARTITION",
1183 "PRECEDING",
1184 "PROTO",
1185 "RANGE",
1186 "RECURSIVE",
1187 "RESPECT",
1188 "ROLLUP",
1189 "STRUCT",
1190 "TABLESAMPLE",
1191 "TREAT",
1192 "UNBOUNDED",
1193 "WITHIN",
1194 ]
1195 .iter()
1196 .copied(),
1197 );
1198 }
1199 Some(DialectType::Snowflake) => {
1200 words.extend(
1201 [
1202 "ACCOUNT",
1203 "BOTH",
1204 "CONNECT",
1205 "FOLLOWING",
1206 "ILIKE",
1207 "INCREMENT",
1208 "ISSUE",
1209 "LATERAL",
1210 "LEADING",
1211 "LOCALTIME",
1212 "LOCALTIMESTAMP",
1213 "MINUS",
1214 "QUALIFY",
1215 "REGEXP",
1216 "RLIKE",
1217 "SOME",
1218 "START",
1219 "TABLESAMPLE",
1220 "TOP",
1221 "TRAILING",
1222 "TRY_CAST",
1223 ]
1224 .iter()
1225 .copied(),
1226 );
1227 }
1228 Some(DialectType::TSQL) | Some(DialectType::Fabric) => {
1229 words.extend(
1230 [
1231 "BACKUP",
1232 "BREAK",
1233 "BROWSE",
1234 "BULK",
1235 "CASCADE",
1236 "CHECKPOINT",
1237 "CLOSE",
1238 "CLUSTERED",
1239 "COALESCE",
1240 "COMPUTE",
1241 "CONTAINS",
1242 "CONTAINSTABLE",
1243 "CONTINUE",
1244 "CONVERT",
1245 "DBCC",
1246 "DEALLOCATE",
1247 "DENY",
1248 "DISK",
1249 "DISTRIBUTED",
1250 "DUMP",
1251 "ERRLVL",
1252 "EXEC",
1253 "EXECUTE",
1254 "EXIT",
1255 "EXTERNAL",
1256 "FILE",
1257 "FILLFACTOR",
1258 "FREETEXT",
1259 "FREETEXTTABLE",
1260 "FUNCTION",
1261 "GOTO",
1262 "HOLDLOCK",
1263 "IDENTITY",
1264 "IDENTITYCOL",
1265 "IDENTITY_INSERT",
1266 "KILL",
1267 "LINENO",
1268 "MERGE",
1269 "NONCLUSTERED",
1270 "NULLIF",
1271 "OF",
1272 "OFF",
1273 "OFFSETS",
1274 "OPEN",
1275 "OPENDATASOURCE",
1276 "OPENQUERY",
1277 "OPENROWSET",
1278 "OPENXML",
1279 "OVER",
1280 "PERCENT",
1281 "PIVOT",
1282 "PLAN",
1283 "PRINT",
1284 "PROC",
1285 "PROCEDURE",
1286 "PUBLIC",
1287 "RAISERROR",
1288 "READ",
1289 "READTEXT",
1290 "RECONFIGURE",
1291 "REPLICATION",
1292 "RESTORE",
1293 "RESTRICT",
1294 "REVERT",
1295 "ROWCOUNT",
1296 "ROWGUIDCOL",
1297 "RULE",
1298 "SAVE",
1299 "SECURITYAUDIT",
1300 "SEMANTICKEYPHRASETABLE",
1301 "SEMANTICSIMILARITYDETAILSTABLE",
1302 "SEMANTICSIMILARITYTABLE",
1303 "SETUSER",
1304 "SHUTDOWN",
1305 "STATISTICS",
1306 "SYSTEM_USER",
1307 "TEXTSIZE",
1308 "TOP",
1309 "TRAN",
1310 "TRANSACTION",
1311 "TRIGGER",
1312 "TSEQUAL",
1313 "UNPIVOT",
1314 "UPDATETEXT",
1315 "WAITFOR",
1316 "WRITETEXT",
1317 ]
1318 .iter()
1319 .copied(),
1320 );
1321 }
1322 Some(DialectType::ClickHouse) => {
1323 words.extend(
1324 [
1325 "ANTI",
1326 "ARRAY",
1327 "ASOF",
1328 "FINAL",
1329 "FORMAT",
1330 "GLOBAL",
1331 "INF",
1332 "KILL",
1333 "MATERIALIZED",
1334 "NAN",
1335 "PREWHERE",
1336 "SAMPLE",
1337 "SEMI",
1338 "SETTINGS",
1339 "TOP",
1340 ]
1341 .iter()
1342 .copied(),
1343 );
1344 }
1345 Some(DialectType::DuckDB) => {
1346 words.extend(
1347 [
1348 "ANALYSE",
1349 "ANALYZE",
1350 "ARRAY",
1351 "BOTH",
1352 "LATERAL",
1353 "LEADING",
1354 "LOCALTIME",
1355 "LOCALTIMESTAMP",
1356 "PLACING",
1357 "QUALIFY",
1358 "SIMILAR",
1359 "TABLESAMPLE",
1360 "TRAILING",
1361 ]
1362 .iter()
1363 .copied(),
1364 );
1365 }
1366 Some(DialectType::Hive) | Some(DialectType::Spark) | Some(DialectType::Databricks) => {
1367 words.extend(
1368 [
1369 "BOTH",
1370 "CLUSTER",
1371 "DISTRIBUTE",
1372 "EXCHANGE",
1373 "EXTENDED",
1374 "FUNCTION",
1375 "LATERAL",
1376 "LEADING",
1377 "MACRO",
1378 "OVER",
1379 "PARTITION",
1380 "PERCENT",
1381 "RANGE",
1382 "READS",
1383 "REDUCE",
1384 "REGEXP",
1385 "REVOKE",
1386 "RLIKE",
1387 "ROLLUP",
1388 "SEMI",
1389 "SORT",
1390 "TABLESAMPLE",
1391 "TRAILING",
1392 "TRANSFORM",
1393 "UNBOUNDED",
1394 "UNIQUEJOIN",
1395 ]
1396 .iter()
1397 .copied(),
1398 );
1399 }
1400 Some(DialectType::Trino) | Some(DialectType::Presto) | Some(DialectType::Athena) => {
1401 words.extend(
1402 [
1403 "CUBE",
1404 "DEALLOCATE",
1405 "DESCRIBE",
1406 "EXECUTE",
1407 "EXTRACT",
1408 "GROUPING",
1409 "LATERAL",
1410 "LOCALTIME",
1411 "LOCALTIMESTAMP",
1412 "NORMALIZE",
1413 "PREPARE",
1414 "ROLLUP",
1415 "SOME",
1416 "TABLESAMPLE",
1417 "UESCAPE",
1418 "UNNEST",
1419 ]
1420 .iter()
1421 .copied(),
1422 );
1423 }
1424 Some(DialectType::Oracle) => {
1425 words.extend(
1426 [
1427 "ACCESS",
1428 "AUDIT",
1429 "CLUSTER",
1430 "COMMENT",
1431 "COMPRESS",
1432 "CONNECT",
1433 "EXCLUSIVE",
1434 "FILE",
1435 "IDENTIFIED",
1436 "IMMEDIATE",
1437 "INCREMENT",
1438 "INITIAL",
1439 "LEVEL",
1440 "LOCK",
1441 "LONG",
1442 "MAXEXTENTS",
1443 "MINUS",
1444 "MODE",
1445 "NOAUDIT",
1446 "NOCOMPRESS",
1447 "NOWAIT",
1448 "NUMBER",
1449 "OF",
1450 "OFFLINE",
1451 "ONLINE",
1452 "PCTFREE",
1453 "PRIOR",
1454 "RAW",
1455 "RENAME",
1456 "RESOURCE",
1457 "REVOKE",
1458 "SHARE",
1459 "SIZE",
1460 "START",
1461 "SUCCESSFUL",
1462 "SYNONYM",
1463 "SYSDATE",
1464 "TRIGGER",
1465 "UID",
1466 "VALIDATE",
1467 "VARCHAR2",
1468 "WHENEVER",
1469 ]
1470 .iter()
1471 .copied(),
1472 );
1473 }
1474 Some(DialectType::Redshift) => {
1475 words.extend(
1476 [
1477 "AZ64",
1478 "BZIP2",
1479 "DELTA",
1480 "DELTA32K",
1481 "DISTSTYLE",
1482 "ENCODE",
1483 "GZIP",
1484 "ILIKE",
1485 "LIMIT",
1486 "LUNS",
1487 "LZO",
1488 "LZOP",
1489 "MOSTLY13",
1490 "MOSTLY32",
1491 "MOSTLY8",
1492 "RAW",
1493 "SIMILAR",
1494 "SNAPSHOT",
1495 "SORTKEY",
1496 "SYSDATE",
1497 "TOP",
1498 "ZSTD",
1499 ]
1500 .iter()
1501 .copied(),
1502 );
1503 }
1504 _ => {
1505 words.extend(
1507 [
1508 "ANALYZE",
1509 "ARRAY",
1510 "BOTH",
1511 "CUBE",
1512 "GROUPING",
1513 "LATERAL",
1514 "LEADING",
1515 "LOCALTIME",
1516 "LOCALTIMESTAMP",
1517 "OVER",
1518 "PARTITION",
1519 "QUALIFY",
1520 "RANGE",
1521 "ROLLUP",
1522 "SIMILAR",
1523 "SOME",
1524 "TABLESAMPLE",
1525 "TRAILING",
1526 ]
1527 .iter()
1528 .copied(),
1529 );
1530 }
1531 }
1532
1533 words
1534}
1535
1536fn needs_quoting(name: &str, reserved_words: &HashSet<&str>) -> bool {
1544 if name.is_empty() {
1545 return false;
1546 }
1547
1548 if name.as_bytes()[0].is_ascii_digit() {
1550 return true;
1551 }
1552
1553 if !name.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'_') {
1555 return true;
1556 }
1557
1558 let upper = name.to_uppercase();
1560 reserved_words.contains(upper.as_str())
1561}
1562
1563fn maybe_quote(id: &mut Identifier, reserved_words: &HashSet<&str>) {
1565 if id.quoted || id.name.is_empty() || id.name == "*" {
1568 return;
1569 }
1570 if needs_quoting(&id.name, reserved_words) {
1571 id.quoted = true;
1572 }
1573}
1574
1575fn quote_identifiers_recursive(expr: &mut Expression, reserved_words: &HashSet<&str>) {
1577 match expr {
1578 Expression::Identifier(id) => {
1580 maybe_quote(id, reserved_words);
1581 }
1582
1583 Expression::Column(col) => {
1584 maybe_quote(&mut col.name, reserved_words);
1585 if let Some(ref mut table) = col.table {
1586 maybe_quote(table, reserved_words);
1587 }
1588 }
1589
1590 Expression::Table(table_ref) => {
1591 maybe_quote(&mut table_ref.name, reserved_words);
1592 if let Some(ref mut schema) = table_ref.schema {
1593 maybe_quote(schema, reserved_words);
1594 }
1595 if let Some(ref mut catalog) = table_ref.catalog {
1596 maybe_quote(catalog, reserved_words);
1597 }
1598 if let Some(ref mut alias) = table_ref.alias {
1599 maybe_quote(alias, reserved_words);
1600 }
1601 for ca in &mut table_ref.column_aliases {
1602 maybe_quote(ca, reserved_words);
1603 }
1604 for p in &mut table_ref.partitions {
1605 maybe_quote(p, reserved_words);
1606 }
1607 for h in &mut table_ref.hints {
1609 quote_identifiers_recursive(h, reserved_words);
1610 }
1611 if let Some(ref mut ver) = table_ref.version {
1612 quote_identifiers_recursive(&mut ver.this, reserved_words);
1613 if let Some(ref mut e) = ver.expression {
1614 quote_identifiers_recursive(e, reserved_words);
1615 }
1616 }
1617 }
1618
1619 Expression::Star(star) => {
1620 if let Some(ref mut table) = star.table {
1621 maybe_quote(table, reserved_words);
1622 }
1623 if let Some(ref mut except_ids) = star.except {
1624 for id in except_ids {
1625 maybe_quote(id, reserved_words);
1626 }
1627 }
1628 if let Some(ref mut replace_aliases) = star.replace {
1629 for alias in replace_aliases {
1630 maybe_quote(&mut alias.alias, reserved_words);
1631 quote_identifiers_recursive(&mut alias.this, reserved_words);
1632 }
1633 }
1634 if let Some(ref mut rename_pairs) = star.rename {
1635 for (from, to) in rename_pairs {
1636 maybe_quote(from, reserved_words);
1637 maybe_quote(to, reserved_words);
1638 }
1639 }
1640 }
1641
1642 Expression::Alias(alias) => {
1644 maybe_quote(&mut alias.alias, reserved_words);
1645 for ca in &mut alias.column_aliases {
1646 maybe_quote(ca, reserved_words);
1647 }
1648 quote_identifiers_recursive(&mut alias.this, reserved_words);
1649 }
1650
1651 Expression::Select(select) => {
1653 for e in &mut select.expressions {
1654 quote_identifiers_recursive(e, reserved_words);
1655 }
1656 if let Some(ref mut from) = select.from {
1657 for e in &mut from.expressions {
1658 quote_identifiers_recursive(e, reserved_words);
1659 }
1660 }
1661 for join in &mut select.joins {
1662 quote_join(join, reserved_words);
1663 }
1664 for lv in &mut select.lateral_views {
1665 quote_lateral_view(lv, reserved_words);
1666 }
1667 if let Some(ref mut prewhere) = select.prewhere {
1668 quote_identifiers_recursive(prewhere, reserved_words);
1669 }
1670 if let Some(ref mut wh) = select.where_clause {
1671 quote_identifiers_recursive(&mut wh.this, reserved_words);
1672 }
1673 if let Some(ref mut gb) = select.group_by {
1674 for e in &mut gb.expressions {
1675 quote_identifiers_recursive(e, reserved_words);
1676 }
1677 }
1678 if let Some(ref mut hv) = select.having {
1679 quote_identifiers_recursive(&mut hv.this, reserved_words);
1680 }
1681 if let Some(ref mut q) = select.qualify {
1682 quote_identifiers_recursive(&mut q.this, reserved_words);
1683 }
1684 if let Some(ref mut ob) = select.order_by {
1685 for o in &mut ob.expressions {
1686 quote_identifiers_recursive(&mut o.this, reserved_words);
1687 }
1688 }
1689 if let Some(ref mut lim) = select.limit {
1690 quote_identifiers_recursive(&mut lim.this, reserved_words);
1691 }
1692 if let Some(ref mut off) = select.offset {
1693 quote_identifiers_recursive(&mut off.this, reserved_words);
1694 }
1695 if let Some(ref mut with) = select.with {
1696 quote_with(with, reserved_words);
1697 }
1698 if let Some(ref mut windows) = select.windows {
1699 for nw in windows {
1700 maybe_quote(&mut nw.name, reserved_words);
1701 quote_over(&mut nw.spec, reserved_words);
1702 }
1703 }
1704 if let Some(ref mut distinct_on) = select.distinct_on {
1705 for e in distinct_on {
1706 quote_identifiers_recursive(e, reserved_words);
1707 }
1708 }
1709 if let Some(ref mut limit_by) = select.limit_by {
1710 for e in limit_by {
1711 quote_identifiers_recursive(e, reserved_words);
1712 }
1713 }
1714 if let Some(ref mut settings) = select.settings {
1715 for e in settings {
1716 quote_identifiers_recursive(e, reserved_words);
1717 }
1718 }
1719 if let Some(ref mut format) = select.format {
1720 quote_identifiers_recursive(format, reserved_words);
1721 }
1722 }
1723
1724 Expression::Union(u) => {
1726 quote_identifiers_recursive(&mut u.left, reserved_words);
1727 quote_identifiers_recursive(&mut u.right, reserved_words);
1728 if let Some(ref mut with) = u.with {
1729 quote_with(with, reserved_words);
1730 }
1731 if let Some(ref mut ob) = u.order_by {
1732 for o in &mut ob.expressions {
1733 quote_identifiers_recursive(&mut o.this, reserved_words);
1734 }
1735 }
1736 if let Some(ref mut lim) = u.limit {
1737 quote_identifiers_recursive(lim, reserved_words);
1738 }
1739 if let Some(ref mut off) = u.offset {
1740 quote_identifiers_recursive(off, reserved_words);
1741 }
1742 }
1743 Expression::Intersect(i) => {
1744 quote_identifiers_recursive(&mut i.left, reserved_words);
1745 quote_identifiers_recursive(&mut i.right, reserved_words);
1746 if let Some(ref mut with) = i.with {
1747 quote_with(with, reserved_words);
1748 }
1749 if let Some(ref mut ob) = i.order_by {
1750 for o in &mut ob.expressions {
1751 quote_identifiers_recursive(&mut o.this, reserved_words);
1752 }
1753 }
1754 }
1755 Expression::Except(e) => {
1756 quote_identifiers_recursive(&mut e.left, reserved_words);
1757 quote_identifiers_recursive(&mut e.right, reserved_words);
1758 if let Some(ref mut with) = e.with {
1759 quote_with(with, reserved_words);
1760 }
1761 if let Some(ref mut ob) = e.order_by {
1762 for o in &mut ob.expressions {
1763 quote_identifiers_recursive(&mut o.this, reserved_words);
1764 }
1765 }
1766 }
1767
1768 Expression::Subquery(sq) => {
1770 quote_identifiers_recursive(&mut sq.this, reserved_words);
1771 if let Some(ref mut alias) = sq.alias {
1772 maybe_quote(alias, reserved_words);
1773 }
1774 for ca in &mut sq.column_aliases {
1775 maybe_quote(ca, reserved_words);
1776 }
1777 if let Some(ref mut ob) = sq.order_by {
1778 for o in &mut ob.expressions {
1779 quote_identifiers_recursive(&mut o.this, reserved_words);
1780 }
1781 }
1782 }
1783
1784 Expression::Insert(ins) => {
1786 quote_table_ref(&mut ins.table, reserved_words);
1787 for c in &mut ins.columns {
1788 maybe_quote(c, reserved_words);
1789 }
1790 for row in &mut ins.values {
1791 for e in row {
1792 quote_identifiers_recursive(e, reserved_words);
1793 }
1794 }
1795 if let Some(ref mut q) = ins.query {
1796 quote_identifiers_recursive(q, reserved_words);
1797 }
1798 for (id, val) in &mut ins.partition {
1799 maybe_quote(id, reserved_words);
1800 if let Some(ref mut v) = val {
1801 quote_identifiers_recursive(v, reserved_words);
1802 }
1803 }
1804 for e in &mut ins.returning {
1805 quote_identifiers_recursive(e, reserved_words);
1806 }
1807 if let Some(ref mut on_conflict) = ins.on_conflict {
1808 quote_identifiers_recursive(on_conflict, reserved_words);
1809 }
1810 if let Some(ref mut with) = ins.with {
1811 quote_with(with, reserved_words);
1812 }
1813 if let Some(ref mut alias) = ins.alias {
1814 maybe_quote(alias, reserved_words);
1815 }
1816 if let Some(ref mut src_alias) = ins.source_alias {
1817 maybe_quote(src_alias, reserved_words);
1818 }
1819 }
1820
1821 Expression::Update(upd) => {
1822 quote_table_ref(&mut upd.table, reserved_words);
1823 for tr in &mut upd.extra_tables {
1824 quote_table_ref(tr, reserved_words);
1825 }
1826 for join in &mut upd.table_joins {
1827 quote_join(join, reserved_words);
1828 }
1829 for (id, val) in &mut upd.set {
1830 maybe_quote(id, reserved_words);
1831 quote_identifiers_recursive(val, reserved_words);
1832 }
1833 if let Some(ref mut from) = upd.from_clause {
1834 for e in &mut from.expressions {
1835 quote_identifiers_recursive(e, reserved_words);
1836 }
1837 }
1838 for join in &mut upd.from_joins {
1839 quote_join(join, reserved_words);
1840 }
1841 if let Some(ref mut wh) = upd.where_clause {
1842 quote_identifiers_recursive(&mut wh.this, reserved_words);
1843 }
1844 for e in &mut upd.returning {
1845 quote_identifiers_recursive(e, reserved_words);
1846 }
1847 if let Some(ref mut with) = upd.with {
1848 quote_with(with, reserved_words);
1849 }
1850 }
1851
1852 Expression::Delete(del) => {
1853 quote_table_ref(&mut del.table, reserved_words);
1854 if let Some(ref mut alias) = del.alias {
1855 maybe_quote(alias, reserved_words);
1856 }
1857 for tr in &mut del.using {
1858 quote_table_ref(tr, reserved_words);
1859 }
1860 if let Some(ref mut wh) = del.where_clause {
1861 quote_identifiers_recursive(&mut wh.this, reserved_words);
1862 }
1863 if let Some(ref mut with) = del.with {
1864 quote_with(with, reserved_words);
1865 }
1866 }
1867
1868 Expression::And(bin)
1870 | Expression::Or(bin)
1871 | Expression::Eq(bin)
1872 | Expression::Neq(bin)
1873 | Expression::Lt(bin)
1874 | Expression::Lte(bin)
1875 | Expression::Gt(bin)
1876 | Expression::Gte(bin)
1877 | Expression::Add(bin)
1878 | Expression::Sub(bin)
1879 | Expression::Mul(bin)
1880 | Expression::Div(bin)
1881 | Expression::Mod(bin)
1882 | Expression::BitwiseAnd(bin)
1883 | Expression::BitwiseOr(bin)
1884 | Expression::BitwiseXor(bin)
1885 | Expression::Concat(bin)
1886 | Expression::Adjacent(bin)
1887 | Expression::TsMatch(bin)
1888 | Expression::PropertyEQ(bin)
1889 | Expression::ArrayContainsAll(bin)
1890 | Expression::ArrayContainedBy(bin)
1891 | Expression::ArrayOverlaps(bin)
1892 | Expression::JSONBContainsAllTopKeys(bin)
1893 | Expression::JSONBContainsAnyTopKeys(bin)
1894 | Expression::JSONBDeleteAtPath(bin)
1895 | Expression::ExtendsLeft(bin)
1896 | Expression::ExtendsRight(bin)
1897 | Expression::Is(bin)
1898 | Expression::NullSafeEq(bin)
1899 | Expression::NullSafeNeq(bin)
1900 | Expression::Glob(bin)
1901 | Expression::Match(bin)
1902 | Expression::MemberOf(bin)
1903 | Expression::BitwiseLeftShift(bin)
1904 | Expression::BitwiseRightShift(bin) => {
1905 quote_identifiers_recursive(&mut bin.left, reserved_words);
1906 quote_identifiers_recursive(&mut bin.right, reserved_words);
1907 }
1908
1909 Expression::Like(like) | Expression::ILike(like) => {
1911 quote_identifiers_recursive(&mut like.left, reserved_words);
1912 quote_identifiers_recursive(&mut like.right, reserved_words);
1913 if let Some(ref mut esc) = like.escape {
1914 quote_identifiers_recursive(esc, reserved_words);
1915 }
1916 }
1917
1918 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
1920 quote_identifiers_recursive(&mut un.this, reserved_words);
1921 }
1922
1923 Expression::In(in_expr) => {
1925 quote_identifiers_recursive(&mut in_expr.this, reserved_words);
1926 for e in &mut in_expr.expressions {
1927 quote_identifiers_recursive(e, reserved_words);
1928 }
1929 if let Some(ref mut q) = in_expr.query {
1930 quote_identifiers_recursive(q, reserved_words);
1931 }
1932 if let Some(ref mut un) = in_expr.unnest {
1933 quote_identifiers_recursive(un, reserved_words);
1934 }
1935 }
1936
1937 Expression::Between(bw) => {
1938 quote_identifiers_recursive(&mut bw.this, reserved_words);
1939 quote_identifiers_recursive(&mut bw.low, reserved_words);
1940 quote_identifiers_recursive(&mut bw.high, reserved_words);
1941 }
1942
1943 Expression::IsNull(is_null) => {
1944 quote_identifiers_recursive(&mut is_null.this, reserved_words);
1945 }
1946
1947 Expression::IsTrue(is_tf) | Expression::IsFalse(is_tf) => {
1948 quote_identifiers_recursive(&mut is_tf.this, reserved_words);
1949 }
1950
1951 Expression::Exists(ex) => {
1952 quote_identifiers_recursive(&mut ex.this, reserved_words);
1953 }
1954
1955 Expression::Function(func) => {
1957 for arg in &mut func.args {
1958 quote_identifiers_recursive(arg, reserved_words);
1959 }
1960 }
1961
1962 Expression::AggregateFunction(agg) => {
1963 for arg in &mut agg.args {
1964 quote_identifiers_recursive(arg, reserved_words);
1965 }
1966 if let Some(ref mut filter) = agg.filter {
1967 quote_identifiers_recursive(filter, reserved_words);
1968 }
1969 for o in &mut agg.order_by {
1970 quote_identifiers_recursive(&mut o.this, reserved_words);
1971 }
1972 }
1973
1974 Expression::WindowFunction(wf) => {
1975 quote_identifiers_recursive(&mut wf.this, reserved_words);
1976 quote_over(&mut wf.over, reserved_words);
1977 }
1978
1979 Expression::Case(case) => {
1981 if let Some(ref mut operand) = case.operand {
1982 quote_identifiers_recursive(operand, reserved_words);
1983 }
1984 for (when, then) in &mut case.whens {
1985 quote_identifiers_recursive(when, reserved_words);
1986 quote_identifiers_recursive(then, reserved_words);
1987 }
1988 if let Some(ref mut else_) = case.else_ {
1989 quote_identifiers_recursive(else_, reserved_words);
1990 }
1991 }
1992
1993 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1995 quote_identifiers_recursive(&mut cast.this, reserved_words);
1996 if let Some(ref mut fmt) = cast.format {
1997 quote_identifiers_recursive(fmt, reserved_words);
1998 }
1999 }
2000
2001 Expression::Paren(paren) => {
2003 quote_identifiers_recursive(&mut paren.this, reserved_words);
2004 }
2005
2006 Expression::Annotated(ann) => {
2007 quote_identifiers_recursive(&mut ann.this, reserved_words);
2008 }
2009
2010 Expression::With(with) => {
2012 quote_with(with, reserved_words);
2013 }
2014
2015 Expression::Cte(cte) => {
2016 maybe_quote(&mut cte.alias, reserved_words);
2017 for c in &mut cte.columns {
2018 maybe_quote(c, reserved_words);
2019 }
2020 quote_identifiers_recursive(&mut cte.this, reserved_words);
2021 }
2022
2023 Expression::From(from) => {
2025 for e in &mut from.expressions {
2026 quote_identifiers_recursive(e, reserved_words);
2027 }
2028 }
2029
2030 Expression::Join(join) => {
2031 quote_join(join, reserved_words);
2032 }
2033
2034 Expression::JoinedTable(jt) => {
2035 quote_identifiers_recursive(&mut jt.left, reserved_words);
2036 for join in &mut jt.joins {
2037 quote_join(join, reserved_words);
2038 }
2039 if let Some(ref mut alias) = jt.alias {
2040 maybe_quote(alias, reserved_words);
2041 }
2042 }
2043
2044 Expression::Where(wh) => {
2045 quote_identifiers_recursive(&mut wh.this, reserved_words);
2046 }
2047
2048 Expression::GroupBy(gb) => {
2049 for e in &mut gb.expressions {
2050 quote_identifiers_recursive(e, reserved_words);
2051 }
2052 }
2053
2054 Expression::Having(hv) => {
2055 quote_identifiers_recursive(&mut hv.this, reserved_words);
2056 }
2057
2058 Expression::OrderBy(ob) => {
2059 for o in &mut ob.expressions {
2060 quote_identifiers_recursive(&mut o.this, reserved_words);
2061 }
2062 }
2063
2064 Expression::Ordered(ord) => {
2065 quote_identifiers_recursive(&mut ord.this, reserved_words);
2066 }
2067
2068 Expression::Limit(lim) => {
2069 quote_identifiers_recursive(&mut lim.this, reserved_words);
2070 }
2071
2072 Expression::Offset(off) => {
2073 quote_identifiers_recursive(&mut off.this, reserved_words);
2074 }
2075
2076 Expression::Qualify(q) => {
2077 quote_identifiers_recursive(&mut q.this, reserved_words);
2078 }
2079
2080 Expression::Window(ws) => {
2081 for e in &mut ws.partition_by {
2082 quote_identifiers_recursive(e, reserved_words);
2083 }
2084 for o in &mut ws.order_by {
2085 quote_identifiers_recursive(&mut o.this, reserved_words);
2086 }
2087 }
2088
2089 Expression::Over(over) => {
2090 quote_over(over, reserved_words);
2091 }
2092
2093 Expression::WithinGroup(wg) => {
2094 quote_identifiers_recursive(&mut wg.this, reserved_words);
2095 for o in &mut wg.order_by {
2096 quote_identifiers_recursive(&mut o.this, reserved_words);
2097 }
2098 }
2099
2100 Expression::Pivot(piv) => {
2102 quote_identifiers_recursive(&mut piv.this, reserved_words);
2103 for e in &mut piv.expressions {
2104 quote_identifiers_recursive(e, reserved_words);
2105 }
2106 for f in &mut piv.fields {
2107 quote_identifiers_recursive(f, reserved_words);
2108 }
2109 if let Some(ref mut alias) = piv.alias {
2110 maybe_quote(alias, reserved_words);
2111 }
2112 }
2113
2114 Expression::Unpivot(unpiv) => {
2115 quote_identifiers_recursive(&mut unpiv.this, reserved_words);
2116 maybe_quote(&mut unpiv.value_column, reserved_words);
2117 maybe_quote(&mut unpiv.name_column, reserved_words);
2118 for e in &mut unpiv.columns {
2119 quote_identifiers_recursive(e, reserved_words);
2120 }
2121 if let Some(ref mut alias) = unpiv.alias {
2122 maybe_quote(alias, reserved_words);
2123 }
2124 }
2125
2126 Expression::Values(vals) => {
2128 for tuple in &mut vals.expressions {
2129 for e in &mut tuple.expressions {
2130 quote_identifiers_recursive(e, reserved_words);
2131 }
2132 }
2133 if let Some(ref mut alias) = vals.alias {
2134 maybe_quote(alias, reserved_words);
2135 }
2136 for ca in &mut vals.column_aliases {
2137 maybe_quote(ca, reserved_words);
2138 }
2139 }
2140
2141 Expression::Array(arr) => {
2143 for e in &mut arr.expressions {
2144 quote_identifiers_recursive(e, reserved_words);
2145 }
2146 }
2147
2148 Expression::Struct(st) => {
2149 for (_name, e) in &mut st.fields {
2150 quote_identifiers_recursive(e, reserved_words);
2151 }
2152 }
2153
2154 Expression::Tuple(tup) => {
2155 for e in &mut tup.expressions {
2156 quote_identifiers_recursive(e, reserved_words);
2157 }
2158 }
2159
2160 Expression::Subscript(sub) => {
2162 quote_identifiers_recursive(&mut sub.this, reserved_words);
2163 quote_identifiers_recursive(&mut sub.index, reserved_words);
2164 }
2165
2166 Expression::Dot(dot) => {
2167 quote_identifiers_recursive(&mut dot.this, reserved_words);
2168 maybe_quote(&mut dot.field, reserved_words);
2169 }
2170
2171 Expression::ScopeResolution(sr) => {
2172 if let Some(ref mut this) = sr.this {
2173 quote_identifiers_recursive(this, reserved_words);
2174 }
2175 quote_identifiers_recursive(&mut sr.expression, reserved_words);
2176 }
2177
2178 Expression::Lateral(lat) => {
2180 quote_identifiers_recursive(&mut lat.this, reserved_words);
2181 }
2183
2184 Expression::DPipe(dpipe) => {
2186 quote_identifiers_recursive(&mut dpipe.this, reserved_words);
2187 quote_identifiers_recursive(&mut dpipe.expression, reserved_words);
2188 }
2189
2190 Expression::Merge(merge) => {
2192 quote_identifiers_recursive(&mut merge.this, reserved_words);
2193 quote_identifiers_recursive(&mut merge.using, reserved_words);
2194 if let Some(ref mut on) = merge.on {
2195 quote_identifiers_recursive(on, reserved_words);
2196 }
2197 if let Some(ref mut whens) = merge.whens {
2198 quote_identifiers_recursive(whens, reserved_words);
2199 }
2200 if let Some(ref mut with) = merge.with_ {
2201 quote_identifiers_recursive(with, reserved_words);
2202 }
2203 if let Some(ref mut ret) = merge.returning {
2204 quote_identifiers_recursive(ret, reserved_words);
2205 }
2206 }
2207
2208 Expression::LateralView(lv) => {
2210 quote_lateral_view(lv, reserved_words);
2211 }
2212
2213 Expression::Anonymous(anon) => {
2215 quote_identifiers_recursive(&mut anon.this, reserved_words);
2216 for e in &mut anon.expressions {
2217 quote_identifiers_recursive(e, reserved_words);
2218 }
2219 }
2220
2221 Expression::Filter(filter) => {
2223 quote_identifiers_recursive(&mut filter.this, reserved_words);
2224 quote_identifiers_recursive(&mut filter.expression, reserved_words);
2225 }
2226
2227 Expression::Returning(ret) => {
2229 for e in &mut ret.expressions {
2230 quote_identifiers_recursive(e, reserved_words);
2231 }
2232 }
2233
2234 Expression::BracedWildcard(inner) => {
2236 quote_identifiers_recursive(inner, reserved_words);
2237 }
2238
2239 Expression::ReturnStmt(inner) => {
2241 quote_identifiers_recursive(inner, reserved_words);
2242 }
2243
2244 Expression::Literal(_)
2246 | Expression::Boolean(_)
2247 | Expression::Null(_)
2248 | Expression::DataType(_)
2249 | Expression::Raw(_)
2250 | Expression::Placeholder(_)
2251 | Expression::CurrentDate(_)
2252 | Expression::CurrentTime(_)
2253 | Expression::CurrentTimestamp(_)
2254 | Expression::CurrentTimestampLTZ(_)
2255 | Expression::SessionUser(_)
2256 | Expression::RowNumber(_)
2257 | Expression::Rank(_)
2258 | Expression::DenseRank(_)
2259 | Expression::PercentRank(_)
2260 | Expression::CumeDist(_)
2261 | Expression::Random(_)
2262 | Expression::Pi(_)
2263 | Expression::JSONPathRoot(_) => {
2264 }
2266
2267 _ => {}
2271 }
2272}
2273
2274fn quote_join(join: &mut Join, reserved_words: &HashSet<&str>) {
2276 quote_identifiers_recursive(&mut join.this, reserved_words);
2277 if let Some(ref mut on) = join.on {
2278 quote_identifiers_recursive(on, reserved_words);
2279 }
2280 for id in &mut join.using {
2281 maybe_quote(id, reserved_words);
2282 }
2283 if let Some(ref mut mc) = join.match_condition {
2284 quote_identifiers_recursive(mc, reserved_words);
2285 }
2286 for piv in &mut join.pivots {
2287 quote_identifiers_recursive(piv, reserved_words);
2288 }
2289}
2290
2291fn quote_with(with: &mut With, reserved_words: &HashSet<&str>) {
2293 for cte in &mut with.ctes {
2294 maybe_quote(&mut cte.alias, reserved_words);
2295 for c in &mut cte.columns {
2296 maybe_quote(c, reserved_words);
2297 }
2298 for k in &mut cte.key_expressions {
2299 maybe_quote(k, reserved_words);
2300 }
2301 quote_identifiers_recursive(&mut cte.this, reserved_words);
2302 }
2303}
2304
2305fn quote_over(over: &mut Over, reserved_words: &HashSet<&str>) {
2307 if let Some(ref mut wn) = over.window_name {
2308 maybe_quote(wn, reserved_words);
2309 }
2310 for e in &mut over.partition_by {
2311 quote_identifiers_recursive(e, reserved_words);
2312 }
2313 for o in &mut over.order_by {
2314 quote_identifiers_recursive(&mut o.this, reserved_words);
2315 }
2316 if let Some(ref mut alias) = over.alias {
2317 maybe_quote(alias, reserved_words);
2318 }
2319}
2320
2321fn quote_table_ref(table_ref: &mut TableRef, reserved_words: &HashSet<&str>) {
2323 maybe_quote(&mut table_ref.name, reserved_words);
2324 if let Some(ref mut schema) = table_ref.schema {
2325 maybe_quote(schema, reserved_words);
2326 }
2327 if let Some(ref mut catalog) = table_ref.catalog {
2328 maybe_quote(catalog, reserved_words);
2329 }
2330 if let Some(ref mut alias) = table_ref.alias {
2331 maybe_quote(alias, reserved_words);
2332 }
2333 for ca in &mut table_ref.column_aliases {
2334 maybe_quote(ca, reserved_words);
2335 }
2336 for p in &mut table_ref.partitions {
2337 maybe_quote(p, reserved_words);
2338 }
2339 for h in &mut table_ref.hints {
2340 quote_identifiers_recursive(h, reserved_words);
2341 }
2342}
2343
2344fn quote_lateral_view(lv: &mut LateralView, reserved_words: &HashSet<&str>) {
2346 quote_identifiers_recursive(&mut lv.this, reserved_words);
2347 if let Some(ref mut ta) = lv.table_alias {
2348 maybe_quote(ta, reserved_words);
2349 }
2350 for ca in &mut lv.column_aliases {
2351 maybe_quote(ca, reserved_words);
2352 }
2353}
2354
2355pub fn quote_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
2366 let reserved_words = get_reserved_words(dialect);
2367 let mut result = expression;
2368 quote_identifiers_recursive(&mut result, &reserved_words);
2369 result
2370}
2371
2372pub fn pushdown_cte_alias_columns(_scope: &Scope) {
2377 }
2380
2381fn pushdown_cte_alias_columns_with(with: &mut With) {
2382 for cte in &mut with.ctes {
2383 if cte.columns.is_empty() {
2384 continue;
2385 }
2386
2387 if let Expression::Select(select) = &mut cte.this {
2388 let mut next_expressions = Vec::with_capacity(select.expressions.len());
2389
2390 for (i, projection) in select.expressions.iter().enumerate() {
2391 let Some(alias_name) = cte.columns.get(i) else {
2392 next_expressions.push(projection.clone());
2393 continue;
2394 };
2395
2396 match projection {
2397 Expression::Alias(existing) => {
2398 let mut aliased = existing.clone();
2399 aliased.alias = alias_name.clone();
2400 next_expressions.push(Expression::Alias(aliased));
2401 }
2402 _ => {
2403 next_expressions.push(create_alias(projection.clone(), &alias_name.name));
2404 }
2405 }
2406 }
2407
2408 select.expressions = next_expressions;
2409 }
2410 }
2411}
2412
2413fn get_scope_columns(scope: &Scope) -> Vec<ColumnRef> {
2419 let mut columns = Vec::new();
2420 collect_columns(&scope.expression, &mut columns);
2421 columns
2422}
2423
2424#[derive(Debug, Clone)]
2426struct ColumnRef {
2427 table: Option<String>,
2428 name: String,
2429}
2430
2431fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
2433 match expr {
2434 Expression::Column(col) => {
2435 columns.push(ColumnRef {
2436 table: col.table.as_ref().map(|t| t.name.clone()),
2437 name: col.name.name.clone(),
2438 });
2439 }
2440 Expression::Select(select) => {
2441 for e in &select.expressions {
2442 collect_columns(e, columns);
2443 }
2444 if let Some(from) = &select.from {
2445 for e in &from.expressions {
2446 collect_columns(e, columns);
2447 }
2448 }
2449 if let Some(where_clause) = &select.where_clause {
2450 collect_columns(&where_clause.this, columns);
2451 }
2452 if let Some(group_by) = &select.group_by {
2453 for e in &group_by.expressions {
2454 collect_columns(e, columns);
2455 }
2456 }
2457 if let Some(having) = &select.having {
2458 collect_columns(&having.this, columns);
2459 }
2460 if let Some(order_by) = &select.order_by {
2461 for o in &order_by.expressions {
2462 collect_columns(&o.this, columns);
2463 }
2464 }
2465 for join in &select.joins {
2466 collect_columns(&join.this, columns);
2467 if let Some(on) = &join.on {
2468 collect_columns(on, columns);
2469 }
2470 }
2471 }
2472 Expression::Alias(alias) => {
2473 collect_columns(&alias.this, columns);
2474 }
2475 Expression::Function(func) => {
2476 for arg in &func.args {
2477 collect_columns(arg, columns);
2478 }
2479 }
2480 Expression::AggregateFunction(agg) => {
2481 for arg in &agg.args {
2482 collect_columns(arg, columns);
2483 }
2484 }
2485 Expression::And(bin)
2486 | Expression::Or(bin)
2487 | Expression::Eq(bin)
2488 | Expression::Neq(bin)
2489 | Expression::Lt(bin)
2490 | Expression::Lte(bin)
2491 | Expression::Gt(bin)
2492 | Expression::Gte(bin)
2493 | Expression::Add(bin)
2494 | Expression::Sub(bin)
2495 | Expression::Mul(bin)
2496 | Expression::Div(bin) => {
2497 collect_columns(&bin.left, columns);
2498 collect_columns(&bin.right, columns);
2499 }
2500 Expression::Not(unary) | Expression::Neg(unary) => {
2501 collect_columns(&unary.this, columns);
2502 }
2503 Expression::Paren(paren) => {
2504 collect_columns(&paren.this, columns);
2505 }
2506 Expression::Case(case) => {
2507 if let Some(operand) = &case.operand {
2508 collect_columns(operand, columns);
2509 }
2510 for (when, then) in &case.whens {
2511 collect_columns(when, columns);
2512 collect_columns(then, columns);
2513 }
2514 if let Some(else_) = &case.else_ {
2515 collect_columns(else_, columns);
2516 }
2517 }
2518 Expression::Cast(cast) => {
2519 collect_columns(&cast.this, columns);
2520 }
2521 Expression::In(in_expr) => {
2522 collect_columns(&in_expr.this, columns);
2523 for e in &in_expr.expressions {
2524 collect_columns(e, columns);
2525 }
2526 if let Some(query) = &in_expr.query {
2527 collect_columns(query, columns);
2528 }
2529 }
2530 Expression::Between(between) => {
2531 collect_columns(&between.this, columns);
2532 collect_columns(&between.low, columns);
2533 collect_columns(&between.high, columns);
2534 }
2535 Expression::Subquery(subquery) => {
2536 collect_columns(&subquery.this, columns);
2537 }
2538 _ => {}
2539 }
2540}
2541
2542fn get_unqualified_columns(scope: &Scope) -> Vec<ColumnRef> {
2544 get_scope_columns(scope)
2545 .into_iter()
2546 .filter(|c| c.table.is_none())
2547 .collect()
2548}
2549
2550fn get_external_columns(scope: &Scope) -> Vec<ColumnRef> {
2552 let source_names: HashSet<_> = scope.sources.keys().cloned().collect();
2553
2554 get_scope_columns(scope)
2555 .into_iter()
2556 .filter(|c| {
2557 if let Some(table) = &c.table {
2558 !source_names.contains(table)
2559 } else {
2560 false
2561 }
2562 })
2563 .collect()
2564}
2565
2566fn is_correlated_subquery(scope: &Scope) -> bool {
2568 scope.can_be_correlated && !get_external_columns(scope).is_empty()
2569}
2570
2571fn is_star_column(col: &Column) -> bool {
2573 col.name.name == "*"
2574}
2575
2576fn create_qualified_column(name: &str, table: Option<&str>) -> Expression {
2578 Expression::Column(Column {
2579 name: Identifier::new(name),
2580 table: table.map(Identifier::new),
2581 join_mark: false,
2582 trailing_comments: vec![],
2583 span: None,
2584 inferred_type: None,
2585 })
2586}
2587
2588fn create_alias(expr: Expression, alias_name: &str) -> Expression {
2590 Expression::Alias(Box::new(Alias {
2591 this: expr,
2592 alias: Identifier::new(alias_name),
2593 column_aliases: vec![],
2594 pre_alias_comments: vec![],
2595 trailing_comments: vec![],
2596 inferred_type: None,
2597 }))
2598}
2599
2600fn get_output_name(expr: &Expression) -> Option<String> {
2602 match expr {
2603 Expression::Column(col) => Some(col.name.name.clone()),
2604 Expression::Alias(alias) => Some(alias.alias.name.clone()),
2605 Expression::Identifier(id) => Some(id.name.clone()),
2606 _ => None,
2607 }
2608}
2609
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DataType;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::scope::build_scope;
    use crate::{MappingSchema, Schema};

    /// Renders `expr` as SQL with the default generator.
    ///
    /// Named `to_sql` rather than `gen` because `gen` is a reserved keyword
    /// from the Rust 2024 edition onwards.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement, panicking on failure.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Builds a plain column reference with no comments, span, or type info.
    fn make_column(name: &str, table: Option<&str>) -> Column {
        Column {
            name: Identifier::new(name),
            table: table.map(Identifier::new),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        }
    }

    /// Builds a schema in which every listed table has only BIGINT columns.
    fn bigint_schema(tables: &[(&str, &[&str])]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for (table, columns) in tables {
            let columns: Vec<(String, DataType)> = columns
                .iter()
                .map(|name| (name.to_string(), DataType::BigInt { length: None }))
                .collect();
            schema
                .add_table(table, &columns, None)
                .expect("schema setup");
        }
        schema
    }

    #[test]
    fn test_qualify_columns_options() {
        let options = QualifyColumnsOptions::new()
            .with_expand_alias_refs(true)
            .with_expand_stars(false)
            .with_dialect(DialectType::PostgreSQL)
            .with_allow_partial(true);

        assert!(options.expand_alias_refs);
        assert!(!options.expand_stars);
        assert_eq!(options.dialect, Some(DialectType::PostgreSQL));
        assert!(options.allow_partial_qualification);
    }

    #[test]
    fn test_get_scope_columns() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let scope = build_scope(&expr);
        let columns = get_scope_columns(&scope);

        assert!(columns.iter().any(|c| c.name == "a"));
        assert!(columns.iter().any(|c| c.name == "b"));
        assert!(columns.iter().any(|c| c.name == "c"));
    }

    #[test]
    fn test_get_unqualified_columns() {
        let expr = parse("SELECT t.a, b FROM t");
        let scope = build_scope(&expr);
        let unqualified = get_unqualified_columns(&scope);

        assert!(unqualified.iter().any(|c| c.name == "b"));
        // `t.a` already carries a qualifier, so it must not be reported.
        assert!(!unqualified.iter().any(|c| c.name == "a"));
    }

    #[test]
    fn test_is_star_column() {
        assert!(is_star_column(&make_column("*", Some("t"))));
        assert!(!is_star_column(&make_column("id", None)));
    }

    #[test]
    fn test_create_qualified_column() {
        let expr = create_qualified_column("id", Some("users"));
        let sql = to_sql(&expr);
        assert!(sql.contains("users.id"), "expected `users.id`: {sql}");
    }

    #[test]
    fn test_create_alias() {
        let aliased = create_alias(Expression::Column(make_column("value", None)), "total");
        let sql = to_sql(&aliased);
        // Both the aliased expression and the alias must be rendered; an `||`
        // check would pass even if one of them were dropped.
        assert!(sql.contains("value"), "aliased expression missing: {sql}");
        assert!(sql.contains("AS total"), "alias missing: {sql}");
    }

    #[test]
    fn test_validate_qualify_columns_success() {
        let expr = parse("SELECT t.a, t.b FROM t");
        // Smoke test: only checks that validation runs without panicking on a
        // fully-qualified query; the returned result is not asserted here.
        let _ = validate_qualify_columns(&expr);
    }

    #[test]
    fn test_collect_columns_nested() {
        let expr = parse("SELECT a + b, c FROM t WHERE d > 0 GROUP BY e HAVING f = 1");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
        assert!(names.contains(&"d"));
        assert!(names.contains(&"e"));
        assert!(names.contains(&"f"));
    }

    #[test]
    fn test_collect_columns_in_case() {
        let expr = parse("SELECT CASE WHEN a = 1 THEN b ELSE c END FROM t");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
    }

    #[test]
    fn test_collect_columns_in_subquery() {
        let expr = parse("SELECT a FROM t WHERE b IN (SELECT c FROM s)");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
    }

    #[test]
    fn test_qualify_outputs_basic() {
        let expr = parse("SELECT a, b + c FROM t");
        let scope = build_scope(&expr);
        let result = qualify_outputs(&scope);
        assert!(result.is_ok());
    }

    #[test]
    fn test_qualify_columns_expands_star_with_schema() {
        let expr = parse("SELECT * FROM users");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "users",
                &[
                    (
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    ),
                    ("name".to_string(), DataType::Text),
                    ("email".to_string(), DataType::Text),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(!sql.contains("SELECT *"));
        assert!(sql.contains("users.id"));
        assert!(sql.contains("users.name"));
        assert!(sql.contains("users.email"));
    }

    #[test]
    fn test_qualify_columns_expands_group_by_positions() {
        let expr = parse("SELECT a, b FROM t GROUP BY 1, 2");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "t",
                &[
                    (
                        "a".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    ),
                    (
                        "b".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    ),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(!sql.contains("GROUP BY 1"));
        assert!(!sql.contains("GROUP BY 2"));
        assert!(sql.contains("GROUP BY"));
        assert!(sql.contains("t.a"));
        assert!(sql.contains("t.b"));
    }

    #[test]
    fn test_expand_using_simple() {
        let expr = parse("SELECT x.b FROM x JOIN y USING (b)");
        let schema = bigint_schema(&[("x", &["a", "b"]), ("y", &["b", "c"])]);

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            !sql.contains("USING"),
            "USING should be replaced with ON: {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "ON condition should be x.b = y.b: {sql}"
        );
        assert!(sql.contains("SELECT x.b"), "SELECT should keep x.b: {sql}");
    }

    #[test]
    fn test_expand_using_unqualified_coalesce() {
        let expr = parse("SELECT b FROM x JOIN y USING(b)");
        let schema = bigint_schema(&[("x", &["a", "b"]), ("y", &["b", "c"])]);

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b)"),
            "Unqualified USING column should become COALESCE: {sql}"
        );
        assert!(
            sql.contains("AS b"),
            "COALESCE should be aliased as 'b': {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "ON condition should be generated: {sql}"
        );
    }

    #[test]
    fn test_expand_using_with_where() {
        let expr = parse("SELECT b FROM x JOIN y USING(b) WHERE b = 1");
        let schema = bigint_schema(&[("x", &["b"]), ("y", &["b"])]);

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("WHERE COALESCE(x.b, y.b)"),
            "WHERE should use COALESCE for USING column: {sql}"
        );
    }

    #[test]
    fn test_expand_using_multi_join() {
        let expr = parse("SELECT b FROM x JOIN y USING(b) JOIN z USING(b)");
        let schema = bigint_schema(&[("x", &["b"]), ("y", &["b"]), ("z", &["b"])]);

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b, z.b)"),
            "Should have 3-table COALESCE: {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "First join ON condition: {sql}"
        );
    }

    #[test]
    fn test_expand_using_multi_column() {
        let expr = parse("SELECT b, c FROM y JOIN z USING(b, c)");
        let schema = bigint_schema(&[("y", &["b", "c"]), ("z", &["b", "c"])]);

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("COALESCE(y.b, z.b)"),
            "column 'b' should get COALESCE: {sql}"
        );
        assert!(
            sql.contains("COALESCE(y.c, z.c)"),
            "column 'c' should get COALESCE: {sql}"
        );
        assert!(
            sql.contains("y.b = z.b") && sql.contains("y.c = z.c"),
            "ON should have both equality conditions: {sql}"
        );
    }

    #[test]
    fn test_expand_using_star() {
        let expr = parse("SELECT * FROM x JOIN y USING(b)");
        let schema = bigint_schema(&[("x", &["a", "b"]), ("y", &["b", "c"])]);

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b) AS b"),
            "USING column should be COALESCE in star expansion: {sql}"
        );
        assert!(sql.contains("x.a"), "non-USING column a from x: {sql}");
        assert!(sql.contains("y.c"), "non-USING column c from y: {sql}");
        // The shared USING column must be emitted once, not once per table.
        let coalesce_count = sql.matches("COALESCE").count();
        assert_eq!(
            coalesce_count, 1,
            "b should appear only once as COALESCE: {sql}"
        );
    }

    #[test]
    fn test_expand_using_table_star() {
        let expr = parse("SELECT x.* FROM x JOIN y USING(b)");
        let schema = bigint_schema(&[("x", &["a", "b"]), ("y", &["b", "c"])]);

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b)"),
            "USING column from x.* should become COALESCE: {sql}"
        );
        assert!(sql.contains("x.a"), "non-USING column a: {sql}");
    }

    #[test]
    fn test_qualify_columns_qualified_table_name() {
        let expr = parse("SELECT a FROM raw.t1");
        let schema = bigint_schema(&[("raw.t1", &["a"])]);

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("t1.a"),
            "column should be qualified with table name: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_correlated_scalar_subquery() {
        let expr =
            parse("SELECT id, (SELECT AVG(val) FROM t2 WHERE t2.id = t1.id) AS avg_val FROM t1");
        let schema = bigint_schema(&[("t1", &["id"]), ("t2", &["id", "val"])]);

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("t1.id"),
            "outer column should be qualified: {sql}"
        );
        assert!(
            sql.contains("t2.id"),
            "inner column should be qualified: {sql}"
        );
    }

    #[test]
    fn test_qualify_columns_rejects_unknown_table() {
        let expr = parse("SELECT id FROM t1 WHERE nonexistent.col = 1");
        let schema = bigint_schema(&[("t1", &["id"])]);

        let result = qualify_columns(expr, &schema, &QualifyColumnsOptions::new());
        assert!(
            result.is_err(),
            "should reject reference to table not in scope or schema"
        );
    }

    #[test]
    fn test_needs_quoting_reserved_word() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("select", &reserved));
        assert!(needs_quoting("SELECT", &reserved));
        assert!(needs_quoting("from", &reserved));
        assert!(needs_quoting("WHERE", &reserved));
        assert!(needs_quoting("join", &reserved));
        assert!(needs_quoting("table", &reserved));
    }

    #[test]
    fn test_needs_quoting_normal_identifiers() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("foo", &reserved));
        assert!(!needs_quoting("my_column", &reserved));
        assert!(!needs_quoting("col1", &reserved));
        assert!(!needs_quoting("A", &reserved));
        assert!(!needs_quoting("_hidden", &reserved));
    }

    #[test]
    fn test_needs_quoting_special_characters() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("my column", &reserved));
        assert!(needs_quoting("my-column", &reserved));
        assert!(needs_quoting("my.column", &reserved));
        assert!(needs_quoting("col@name", &reserved));
        assert!(needs_quoting("col#name", &reserved));
    }

    #[test]
    fn test_needs_quoting_starts_with_digit() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("1col", &reserved));
        assert!(needs_quoting("123", &reserved));
        assert!(needs_quoting("0_start", &reserved));
    }

    #[test]
    fn test_needs_quoting_empty() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("", &reserved));
    }

    #[test]
    fn test_maybe_quote_sets_quoted_flag() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("select");
        assert!(!id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
    }

    #[test]
    fn test_maybe_quote_skips_already_quoted() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::quoted("myname");
        assert!(id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
        assert_eq!(id.name, "myname");
    }

    #[test]
    fn test_maybe_quote_skips_star() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("*");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    #[test]
    fn test_maybe_quote_skips_normal() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("normal_col");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    #[test]
    fn test_quote_identifiers_column_with_reserved_name() {
        let expr = Expression::Column(make_column("select", None));
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column named 'select' should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_column_with_special_chars() {
        let expr = Expression::Column(make_column("my column", None));
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column with space should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_preserves_normal_column() {
        let expr = Expression::Column(make_column("normal_col", Some("my_table")));
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(!col.name.quoted, "Normal column should not be quoted");
            assert!(
                !col.table.as_ref().unwrap().quoted,
                "Normal table should not be quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_table_ref_reserved() {
        let expr = Expression::Table(TableRef::new("select"));
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(tr.name.quoted, "Table named 'select' should be quoted");
        } else {
            panic!("Expected Table expression");
        }
    }

    #[test]
    fn test_quote_identifiers_table_ref_schema_and_alias() {
        let mut tr = TableRef::new("my_table");
        tr.schema = Some(Identifier::new("from"));
        tr.alias = Some(Identifier::new("t"));
        let expr = Expression::Table(tr);
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(!tr.name.quoted, "Normal table name should not be quoted");
            assert!(
                tr.schema.as_ref().unwrap().quoted,
                "Schema named 'from' should be quoted"
            );
            assert!(
                !tr.alias.as_ref().unwrap().quoted,
                "Normal alias should not be quoted"
            );
        } else {
            panic!("Expected Table expression");
        }
    }

    #[test]
    fn test_quote_identifiers_identifier_node() {
        let expr = Expression::Identifier(Identifier::new("order"));
        let result = quote_identifiers(expr, None);
        if let Expression::Identifier(id) = &result {
            assert!(id.quoted, "Identifier named 'order' should be quoted");
        } else {
            panic!("Expected Identifier expression");
        }
    }

    #[test]
    fn test_quote_identifiers_alias() {
        let expr = Expression::Alias(Box::new(Alias {
            this: Expression::Column(make_column("val", None)),
            alias: Identifier::new("select"),
            column_aliases: vec![Identifier::new("from")],
            pre_alias_comments: vec![],
            trailing_comments: vec![],
            inferred_type: None,
        }));
        let result = quote_identifiers(expr, None);
        if let Expression::Alias(alias) = &result {
            assert!(alias.alias.quoted, "Alias named 'select' should be quoted");
            assert!(
                alias.column_aliases[0].quoted,
                "Column alias named 'from' should be quoted"
            );
            // Fail loudly if the aliased expression changed shape instead of
            // silently skipping the check.
            if let Expression::Column(col) = &alias.this {
                assert!(!col.name.quoted, "Aliased column 'val' should not be quoted");
            } else {
                panic!("Expected Column inside Alias");
            }
        } else {
            panic!("Expected Alias expression");
        }
    }

    #[test]
    fn test_quote_identifiers_select_recursive() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let result = quote_identifiers(expr, None);
        let sql = to_sql(&result);
        assert!(sql.contains("a"));
        assert!(sql.contains("b"));
        assert!(sql.contains("t"));
        // None of these identifiers is reserved or contains special
        // characters, so the pass must not introduce any quoting.
        assert!(!sql.contains('"'), "unexpected quoting: {sql}");
    }

    #[test]
    fn test_quote_identifiers_digit_start() {
        let expr = Expression::Column(make_column("1col", None));
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(
                col.name.quoted,
                "Column starting with digit should be quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_with_mysql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::MySQL));
        assert!(needs_quoting("KILL", &reserved));
        assert!(needs_quoting("FORCE", &reserved));
    }

    #[test]
    fn test_quote_identifiers_with_postgresql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::PostgreSQL));
        assert!(needs_quoting("ILIKE", &reserved));
        assert!(needs_quoting("VERBOSE", &reserved));
    }

    #[test]
    fn test_quote_identifiers_with_bigquery_dialect() {
        let reserved = get_reserved_words(Some(DialectType::BigQuery));
        assert!(needs_quoting("STRUCT", &reserved));
        assert!(needs_quoting("PROTO", &reserved));
    }

    #[test]
    fn test_quote_identifiers_case_insensitive_reserved() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("Select", &reserved));
        assert!(needs_quoting("sElEcT", &reserved));
        assert!(needs_quoting("FROM", &reserved));
        assert!(needs_quoting("from", &reserved));
    }

    #[test]
    fn test_quote_identifiers_join_using() {
        let mut join = crate::expressions::Join {
            this: Expression::Table(TableRef::new("other")),
            on: None,
            using: vec![Identifier::new("key"), Identifier::new("value")],
            kind: crate::expressions::JoinKind::Inner,
            use_inner_keyword: false,
            use_outer_keyword: false,
            deferred_condition: false,
            join_hint: None,
            match_condition: None,
            pivots: vec![],
            comments: vec![],
            nesting_group: 0,
            directed: false,
        };
        let reserved = get_reserved_words(None);
        quote_join(&mut join, &reserved);
        assert!(
            join.using[0].quoted,
            "USING identifier 'key' should be quoted"
        );
        assert!(
            !join.using[1].quoted,
            "USING identifier 'value' should not be quoted"
        );
    }

    #[test]
    fn test_quote_identifiers_cte() {
        // Exercises `maybe_quote` directly on the identifiers a CTE carries.
        let mut cte = crate::expressions::Cte {
            alias: Identifier::new("select"),
            this: Expression::Column(make_column("x", None)),
            columns: vec![Identifier::new("from"), Identifier::new("normal")],
            materialized: None,
            key_expressions: vec![],
            alias_first: false,
            comments: Vec::new(),
        };
        let reserved = get_reserved_words(None);
        maybe_quote(&mut cte.alias, &reserved);
        for c in &mut cte.columns {
            maybe_quote(c, &reserved);
        }
        assert!(cte.alias.quoted, "CTE alias 'select' should be quoted");
        assert!(cte.columns[0].quoted, "CTE column 'from' should be quoted");
        assert!(
            !cte.columns[1].quoted,
            "CTE column 'normal' should not be quoted"
        );
    }

    #[test]
    fn test_quote_identifiers_binary_ops_recurse() {
        let expr = Expression::Add(Box::new(crate::expressions::BinaryOp::new(
            Expression::Column(make_column("select", None)),
            Expression::Column(make_column("normal", None)),
        )));
        let result = quote_identifiers(expr, None);
        if let Expression::Add(bin) = &result {
            // Both operands must still be columns; otherwise the checks
            // below would be skipped and the test would pass vacuously.
            if let Expression::Column(left) = &bin.left {
                assert!(
                    left.name.quoted,
                    "'select' column should be quoted in binary op"
                );
            } else {
                panic!("Expected Column as left operand");
            }
            if let Expression::Column(right) = &bin.right {
                assert!(!right.name.quoted, "'normal' column should not be quoted");
            } else {
                panic!("Expected Column as right operand");
            }
        } else {
            panic!("Expected Add expression");
        }
    }

    #[test]
    fn test_quote_identifiers_already_quoted_preserved() {
        let expr = Expression::Column(Column {
            name: Identifier::quoted("normal_name"),
            ..make_column("normal_name", None)
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(
                col.name.quoted,
                "Already-quoted identifier should remain quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_full_parsed_query() {
        let mut select = crate::expressions::Select::new();
        select
            .expressions
            .push(Expression::Column(make_column("order", Some("t"))));
        select.from = Some(crate::expressions::From {
            expressions: vec![Expression::Table(TableRef::new("t"))],
        });
        let expr = Expression::Select(Box::new(select));

        let result = quote_identifiers(expr, None);
        if let Expression::Select(sel) = &result {
            if let Expression::Column(col) = &sel.expressions[0] {
                assert!(col.name.quoted, "Column named 'order' should be quoted");
                assert!(
                    !col.table.as_ref().unwrap().quoted,
                    "Table 't' should not be quoted"
                );
            } else {
                panic!("Expected Column in SELECT list");
            }
        } else {
            panic!("Expected Select expression");
        }
    }

    #[test]
    fn test_get_reserved_words_all_dialects() {
        let dialects = [
            None,
            Some(DialectType::Generic),
            Some(DialectType::MySQL),
            Some(DialectType::PostgreSQL),
            Some(DialectType::BigQuery),
            Some(DialectType::Snowflake),
            Some(DialectType::TSQL),
            Some(DialectType::ClickHouse),
            Some(DialectType::DuckDB),
            Some(DialectType::Hive),
            Some(DialectType::Spark),
            Some(DialectType::Trino),
            Some(DialectType::Oracle),
            Some(DialectType::Redshift),
        ];
        for dialect in &dialects {
            let words = get_reserved_words(*dialect);
            assert!(
                words.contains("SELECT"),
                "All dialects should have SELECT as reserved"
            );
            assert!(
                words.contains("FROM"),
                "All dialects should have FROM as reserved"
            );
        }
    }
}