1use crate::dialects::transform_recursive;
9use crate::dialects::DialectType;
10use crate::expressions::{
11 Alias, BinaryOp, Column, Expression, Identifier, Join, LateralView, Literal, Over, Paren,
12 Select, TableRef, VarArgFunc, With,
13};
14use crate::resolver::{Resolver, ResolverError};
15use crate::schema::{normalize_name, Schema};
16use crate::scope::{build_scope, traverse_scope, Scope};
17use std::cell::RefCell;
18use std::collections::{HashMap, HashSet};
19use thiserror::Error;
20
/// Errors reported while qualifying or validating column references.
#[derive(Debug, Error, Clone)]
pub enum QualifyColumnsError {
    /// A referenced table name is not among the sources available to the query.
    #[error("Unknown table: {0}")]
    UnknownTable(String),

    /// A column could not be matched to any available source.
    #[error("Unknown column: {0}")]
    UnknownColumn(String),

    /// Reported by `validate_qualify_columns` for a column that is still unqualified.
    #[error("Ambiguous column: {0}")]
    AmbiguousColumn(String),

    /// Carries a failure message as text; in this file it also wraps errors
    /// returned by `transform_recursive`.
    #[error("Cannot automatically join: {0}")]
    CannotAutoJoin(String),

    /// A positional reference (for example in `GROUP BY`) is zero or exceeds
    /// the number of projections.
    #[error("Unknown output column: {0}")]
    UnknownOutputColumn(String),

    /// A column refers to something outside its scope and the scope is not a
    /// correlated subquery. `for_table` is either empty or a pre-formatted
    /// `" for table: '<name>'"` suffix.
    #[error("Column could not be resolved: {column}{for_table}")]
    ColumnNotResolved { column: String, for_table: String },

    /// An error propagated from the resolver.
    #[error("Resolver error: {0}")]
    ResolverError(#[from] ResolverError),
}
45
/// Result alias used throughout this module; the error type is [`QualifyColumnsError`].
pub type QualifyColumnsResult<T> = Result<T, QualifyColumnsError>;
48
/// Options controlling which passes `qualify_columns` runs and how strict it is.
///
/// NOTE(review): the derived `Default` leaves `expand_alias_refs` and
/// `expand_stars` as `false`, whereas `QualifyColumnsOptions::new()` enables
/// both — confirm which of the two callers are expected to use.
#[derive(Debug, Clone, Default)]
pub struct QualifyColumnsOptions {
    /// When `true`, `qualify_columns` runs the alias-reference expansion pass.
    pub expand_alias_refs: bool,
    /// When `true`, `qualify_columns` runs the star-expansion pass.
    pub expand_stars: bool,
    /// Explicit schema-inference switch handed to the resolver. When `None`,
    /// `qualify_columns` uses `schema.is_empty()` instead.
    pub infer_schema: Option<bool>,
    /// Forwarded to column qualification; when `true`, columns that cannot be
    /// resolved are left as they are instead of producing an error.
    pub allow_partial_qualification: bool,
    /// Dialect to use. When `None`, `qualify_columns` falls back to `schema.dialect()`.
    pub dialect: Option<DialectType>,
}
63
64impl QualifyColumnsOptions {
65 pub fn new() -> Self {
67 Self {
68 expand_alias_refs: true,
69 expand_stars: true,
70 infer_schema: None,
71 allow_partial_qualification: false,
72 dialect: None,
73 }
74 }
75
76 pub fn with_expand_alias_refs(mut self, expand: bool) -> Self {
78 self.expand_alias_refs = expand;
79 self
80 }
81
82 pub fn with_expand_stars(mut self, expand: bool) -> Self {
84 self.expand_stars = expand;
85 self
86 }
87
88 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
90 self.dialect = Some(dialect);
91 self
92 }
93
94 pub fn with_allow_partial(mut self, allow: bool) -> Self {
96 self.allow_partial_qualification = allow;
97 self
98 }
99}
100
101pub fn qualify_columns(
116 expression: Expression,
117 schema: &dyn Schema,
118 options: &QualifyColumnsOptions,
119) -> QualifyColumnsResult<Expression> {
120 let infer_schema = options.infer_schema.unwrap_or(schema.is_empty());
121 let dialect = options.dialect.or_else(|| schema.dialect());
122 let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
123
124 let transformed = transform_recursive(expression, &|node| {
125 if first_error.borrow().is_some() {
126 return Ok(node);
127 }
128
129 match node {
130 Expression::Select(mut select) => {
131 if let Some(with) = &mut select.with {
132 pushdown_cte_alias_columns_with(with);
133 }
134
135 let scope_expr = Expression::Select(select.clone());
136 let scope = build_scope(&scope_expr);
137 let mut resolver = Resolver::new(&scope, schema, infer_schema);
138
139 let column_tables = if first_error.borrow().is_none() {
141 match expand_using(&mut select, &scope, &mut resolver) {
142 Ok(ct) => ct,
143 Err(err) => {
144 *first_error.borrow_mut() = Some(err);
145 HashMap::new()
146 }
147 }
148 } else {
149 HashMap::new()
150 };
151
152 if first_error.borrow().is_none() {
154 if let Err(err) = qualify_columns_in_scope(
155 &mut select,
156 &scope,
157 &mut resolver,
158 options.allow_partial_qualification,
159 ) {
160 *first_error.borrow_mut() = Some(err);
161 }
162 }
163
164 if first_error.borrow().is_none() && options.expand_alias_refs {
166 if let Err(err) = expand_alias_refs(&mut select, &mut resolver, dialect) {
167 *first_error.borrow_mut() = Some(err);
168 }
169 }
170
171 if first_error.borrow().is_none() && options.expand_stars {
173 if let Err(err) =
174 expand_stars(&mut select, &scope, &mut resolver, &column_tables)
175 {
176 *first_error.borrow_mut() = Some(err);
177 }
178 }
179
180 if first_error.borrow().is_none() {
182 if let Err(err) = qualify_outputs_select(&mut select) {
183 *first_error.borrow_mut() = Some(err);
184 }
185 }
186
187 if first_error.borrow().is_none() {
189 if let Err(err) = expand_group_by(&mut select, dialect) {
190 *first_error.borrow_mut() = Some(err);
191 }
192 }
193
194 Ok(Expression::Select(select))
195 }
196 _ => Ok(node),
197 }
198 })
199 .map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
200
201 if let Some(err) = first_error.into_inner() {
202 return Err(err);
203 }
204
205 Ok(transformed)
206}
207
208pub fn validate_qualify_columns(expression: &Expression) -> QualifyColumnsResult<()> {
213 let mut all_unqualified = Vec::new();
214
215 for scope in traverse_scope(expression) {
216 if let Expression::Select(_) = &scope.expression {
217 let unqualified = get_unqualified_columns(&scope);
219
220 let external = get_external_columns(&scope);
222 if !external.is_empty() && !is_correlated_subquery(&scope) {
223 let first = &external[0];
224 let for_table = if first.table.is_some() {
225 format!(" for table: '{}'", first.table.as_ref().unwrap())
226 } else {
227 String::new()
228 };
229 return Err(QualifyColumnsError::ColumnNotResolved {
230 column: first.name.clone(),
231 for_table,
232 });
233 }
234
235 all_unqualified.extend(unqualified);
236 }
237 }
238
239 if !all_unqualified.is_empty() {
240 let first = &all_unqualified[0];
241 return Err(QualifyColumnsError::AmbiguousColumn(first.name.clone()));
242 }
243
244 Ok(())
245}
246
247fn get_source_name(expr: &Expression) -> Option<String> {
249 match expr {
250 Expression::Table(t) => Some(
251 t.alias
252 .as_ref()
253 .map(|a| a.name.clone())
254 .unwrap_or_else(|| t.name.name.clone()),
255 ),
256 Expression::Subquery(sq) => sq.alias.as_ref().map(|a| a.name.clone()),
257 _ => None,
258 }
259}
260
261fn get_ordered_source_names(select: &Select) -> Vec<String> {
264 let mut ordered = Vec::new();
265 if let Some(from) = &select.from {
266 for expr in &from.expressions {
267 if let Some(name) = get_source_name(expr) {
268 ordered.push(name);
269 }
270 }
271 }
272 for join in &select.joins {
273 if let Some(name) = get_source_name(&join.this) {
274 ordered.push(name);
275 }
276 }
277 ordered
278}
279
280fn make_coalesce(column_name: &str, tables: &[String]) -> Expression {
282 let args: Vec<Expression> = tables
283 .iter()
284 .map(|t| Expression::qualified_column(t.as_str(), column_name))
285 .collect();
286 Expression::Coalesce(Box::new(VarArgFunc {
287 expressions: args,
288 original_name: None,
289 inferred_type: None,
290 }))
291}
292
293fn expand_using(
299 select: &mut Select,
300 _scope: &Scope,
301 resolver: &mut Resolver,
302) -> QualifyColumnsResult<HashMap<String, Vec<String>>> {
303 let mut columns: HashMap<String, String> = HashMap::new();
305
306 let mut column_tables: HashMap<String, Vec<String>> = HashMap::new();
308
309 let join_names: HashSet<String> = select
311 .joins
312 .iter()
313 .filter_map(|j| get_source_name(&j.this))
314 .collect();
315
316 let all_ordered = get_ordered_source_names(select);
317 let mut ordered: Vec<String> = all_ordered
318 .iter()
319 .filter(|name| !join_names.contains(name.as_str()))
320 .cloned()
321 .collect();
322
323 if join_names.is_empty() {
324 return Ok(column_tables);
325 }
326
327 fn update_source_columns(
329 source_name: &str,
330 columns: &mut HashMap<String, String>,
331 resolver: &mut Resolver,
332 ) {
333 if let Ok(source_cols) = resolver.get_source_columns(source_name) {
334 for col_name in source_cols {
335 columns
336 .entry(col_name)
337 .or_insert_with(|| source_name.to_string());
338 }
339 }
340 }
341
342 for source_name in &ordered {
344 update_source_columns(source_name, &mut columns, resolver);
345 }
346
347 for i in 0..select.joins.len() {
348 let source_table = ordered.last().cloned().unwrap_or_default();
350 if !source_table.is_empty() {
351 update_source_columns(&source_table, &mut columns, resolver);
352 }
353
354 let join_table = get_source_name(&select.joins[i].this).unwrap_or_default();
356 ordered.push(join_table.clone());
357
358 if select.joins[i].using.is_empty() {
360 continue;
361 }
362
363 let _join_columns: Vec<String> =
364 resolver.get_source_columns(&join_table).unwrap_or_default();
365
366 let using_identifiers: Vec<String> = select.joins[i]
367 .using
368 .iter()
369 .map(|id| id.name.clone())
370 .collect();
371
372 let using_count = using_identifiers.len();
373 let is_semi_or_anti = matches!(
374 select.joins[i].kind,
375 crate::expressions::JoinKind::Semi
376 | crate::expressions::JoinKind::Anti
377 | crate::expressions::JoinKind::LeftSemi
378 | crate::expressions::JoinKind::LeftAnti
379 | crate::expressions::JoinKind::RightSemi
380 | crate::expressions::JoinKind::RightAnti
381 );
382
383 let mut conditions: Vec<Expression> = Vec::new();
384
385 for identifier in &using_identifiers {
386 let table = columns
387 .get(identifier)
388 .cloned()
389 .unwrap_or_else(|| source_table.clone());
390
391 let lhs = if i == 0 || using_count == 1 {
393 Expression::qualified_column(table.as_str(), identifier.as_str())
395 } else {
396 let coalesce_cols: Vec<String> = ordered[..ordered.len() - 1]
399 .iter()
400 .filter(|t| {
401 resolver
402 .get_source_columns(t)
403 .unwrap_or_default()
404 .contains(identifier)
405 })
406 .cloned()
407 .collect();
408
409 if coalesce_cols.len() > 1 {
410 make_coalesce(identifier, &coalesce_cols)
411 } else {
412 Expression::qualified_column(table.as_str(), identifier.as_str())
413 }
414 };
415
416 let rhs = Expression::qualified_column(join_table.as_str(), identifier.as_str());
418
419 conditions.push(Expression::Eq(Box::new(BinaryOp::new(lhs, rhs))));
420
421 if !is_semi_or_anti {
423 let tables = column_tables
424 .entry(identifier.clone())
425 .or_insert_with(Vec::new);
426 if !tables.contains(&table) {
427 tables.push(table.clone());
428 }
429 if !tables.contains(&join_table) {
430 tables.push(join_table.clone());
431 }
432 }
433 }
434
435 let on_condition = conditions
437 .into_iter()
438 .reduce(|acc, cond| Expression::And(Box::new(BinaryOp::new(acc, cond))))
439 .expect("at least one USING column");
440
441 select.joins[i].on = Some(on_condition);
443 select.joins[i].using = vec![];
444 }
445
446 if !column_tables.is_empty() {
448 let mut new_expressions = Vec::with_capacity(select.expressions.len());
450 for expr in &select.expressions {
451 match expr {
452 Expression::Column(col)
453 if col.table.is_none() && column_tables.contains_key(&col.name.name) =>
454 {
455 let tables = &column_tables[&col.name.name];
456 let coalesce = make_coalesce(&col.name.name, tables);
457 new_expressions.push(Expression::Alias(Box::new(Alias {
459 this: coalesce,
460 alias: Identifier::new(&col.name.name),
461 column_aliases: vec![],
462 pre_alias_comments: vec![],
463 trailing_comments: vec![],
464 inferred_type: None,
465 })));
466 }
467 _ => {
468 let mut rewritten = expr.clone();
469 rewrite_using_columns_in_expression(&mut rewritten, &column_tables);
470 new_expressions.push(rewritten);
471 }
472 }
473 }
474 select.expressions = new_expressions;
475
476 if let Some(where_clause) = &mut select.where_clause {
478 rewrite_using_columns_in_expression(&mut where_clause.this, &column_tables);
479 }
480
481 if let Some(group_by) = &mut select.group_by {
483 for expr in &mut group_by.expressions {
484 rewrite_using_columns_in_expression(expr, &column_tables);
485 }
486 }
487
488 if let Some(having) = &mut select.having {
490 rewrite_using_columns_in_expression(&mut having.this, &column_tables);
491 }
492
493 if let Some(qualify) = &mut select.qualify {
495 rewrite_using_columns_in_expression(&mut qualify.this, &column_tables);
496 }
497
498 if let Some(order_by) = &mut select.order_by {
500 for ordered in &mut order_by.expressions {
501 rewrite_using_columns_in_expression(&mut ordered.this, &column_tables);
502 }
503 }
504 }
505
506 Ok(column_tables)
507}
508
509fn rewrite_using_columns_in_expression(
511 expr: &mut Expression,
512 column_tables: &HashMap<String, Vec<String>>,
513) {
514 let transformed = transform_recursive(expr.clone(), &|node| match node {
515 Expression::Column(col)
516 if col.table.is_none() && column_tables.contains_key(&col.name.name) =>
517 {
518 let tables = &column_tables[&col.name.name];
519 Ok(make_coalesce(&col.name.name, tables))
520 }
521 other => Ok(other),
522 });
523
524 if let Ok(next) = transformed {
525 *expr = next;
526 }
527}
528
529fn qualify_columns_in_scope(
531 select: &mut Select,
532 scope: &Scope,
533 resolver: &mut Resolver,
534 allow_partial: bool,
535) -> QualifyColumnsResult<()> {
536 for expr in &mut select.expressions {
537 qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
538 }
539 if let Some(where_clause) = &mut select.where_clause {
540 qualify_columns_in_expression(&mut where_clause.this, scope, resolver, allow_partial)?;
541 }
542 if let Some(group_by) = &mut select.group_by {
543 for expr in &mut group_by.expressions {
544 qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
545 }
546 }
547 if let Some(having) = &mut select.having {
548 qualify_columns_in_expression(&mut having.this, scope, resolver, allow_partial)?;
549 }
550 if let Some(qualify) = &mut select.qualify {
551 qualify_columns_in_expression(&mut qualify.this, scope, resolver, allow_partial)?;
552 }
553 if let Some(order_by) = &mut select.order_by {
554 for ordered in &mut order_by.expressions {
555 qualify_columns_in_expression(&mut ordered.this, scope, resolver, allow_partial)?;
556 }
557 }
558 for join in &mut select.joins {
559 qualify_columns_in_expression(&mut join.this, scope, resolver, allow_partial)?;
560 if let Some(on) = &mut join.on {
561 qualify_columns_in_expression(on, scope, resolver, allow_partial)?;
562 }
563 }
564 Ok(())
565}
566
567fn expand_alias_refs(
574 select: &mut Select,
575 _resolver: &mut Resolver,
576 _dialect: Option<DialectType>,
577) -> QualifyColumnsResult<()> {
578 let mut alias_to_expression: HashMap<String, (Expression, usize)> = HashMap::new();
579
580 for (i, expr) in select.expressions.iter_mut().enumerate() {
581 replace_alias_refs_in_expression(expr, &alias_to_expression, false);
582 if let Expression::Alias(alias) = expr {
583 alias_to_expression.insert(alias.alias.name.clone(), (alias.this.clone(), i + 1));
584 }
585 }
586
587 if let Some(where_clause) = &mut select.where_clause {
588 replace_alias_refs_in_expression(&mut where_clause.this, &alias_to_expression, false);
589 }
590 if let Some(group_by) = &mut select.group_by {
591 for expr in &mut group_by.expressions {
592 replace_alias_refs_in_expression(expr, &alias_to_expression, true);
593 }
594 }
595 if let Some(having) = &mut select.having {
596 replace_alias_refs_in_expression(&mut having.this, &alias_to_expression, false);
597 }
598 if let Some(qualify) = &mut select.qualify {
599 replace_alias_refs_in_expression(&mut qualify.this, &alias_to_expression, false);
600 }
601 if let Some(order_by) = &mut select.order_by {
602 for ordered in &mut order_by.expressions {
603 replace_alias_refs_in_expression(&mut ordered.this, &alias_to_expression, false);
604 }
605 }
606
607 Ok(())
608}
609
610fn expand_group_by(select: &mut Select, _dialect: Option<DialectType>) -> QualifyColumnsResult<()> {
617 let projections = select.expressions.clone();
618
619 if let Some(group_by) = &mut select.group_by {
620 for group_expr in &mut group_by.expressions {
621 if let Some(index) = positional_reference(group_expr) {
622 let replacement = select_expression_at_position(&projections, index)?;
623 *group_expr = replacement;
624 }
625 }
626 }
627 Ok(())
628}
629
630fn expand_stars(
640 select: &mut Select,
641 _scope: &Scope,
642 resolver: &mut Resolver,
643 column_tables: &HashMap<String, Vec<String>>,
644) -> QualifyColumnsResult<()> {
645 let mut new_selections: Vec<Expression> = Vec::new();
646 let mut has_star = false;
647 let mut coalesced_columns: HashSet<String> = HashSet::new();
648
649 let ordered_sources = get_ordered_source_names(select);
651
652 for expr in &select.expressions {
653 match expr {
654 Expression::Star(_) => {
655 has_star = true;
656 for source_name in &ordered_sources {
657 if let Ok(columns) = resolver.get_source_columns(source_name) {
658 if columns.contains(&"*".to_string()) || columns.is_empty() {
659 return Ok(());
660 }
661 for col_name in &columns {
662 if coalesced_columns.contains(col_name) {
663 continue;
665 }
666 if let Some(tables) = column_tables.get(col_name) {
667 if tables.contains(source_name) {
668 coalesced_columns.insert(col_name.clone());
670 let coalesce = make_coalesce(col_name, tables);
671 new_selections.push(Expression::Alias(Box::new(Alias {
672 this: coalesce,
673 alias: Identifier::new(col_name),
674 column_aliases: vec![],
675 pre_alias_comments: vec![],
676 trailing_comments: vec![],
677 inferred_type: None,
678 })));
679 continue;
680 }
681 }
682 new_selections
683 .push(create_qualified_column(col_name, Some(source_name)));
684 }
685 }
686 }
687 }
688 Expression::Column(col) if is_star_column(col) => {
689 has_star = true;
690 if let Some(table) = &col.table {
691 let table_name = &table.name;
692 if !ordered_sources.contains(table_name) {
693 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
694 }
695 if let Ok(columns) = resolver.get_source_columns(table_name) {
696 if columns.contains(&"*".to_string()) || columns.is_empty() {
697 return Ok(());
698 }
699 for col_name in &columns {
700 if coalesced_columns.contains(col_name) {
701 continue;
702 }
703 if let Some(tables) = column_tables.get(col_name) {
704 if tables.contains(table_name) {
705 coalesced_columns.insert(col_name.clone());
706 let coalesce = make_coalesce(col_name, tables);
707 new_selections.push(Expression::Alias(Box::new(Alias {
708 this: coalesce,
709 alias: Identifier::new(col_name),
710 column_aliases: vec![],
711 pre_alias_comments: vec![],
712 trailing_comments: vec![],
713 inferred_type: None,
714 })));
715 continue;
716 }
717 }
718 new_selections
719 .push(create_qualified_column(col_name, Some(table_name)));
720 }
721 }
722 }
723 }
724 _ => new_selections.push(expr.clone()),
725 }
726 }
727
728 if has_star {
729 select.expressions = new_selections;
730 }
731
732 Ok(())
733}
734
735pub fn qualify_outputs(scope: &Scope) -> QualifyColumnsResult<()> {
742 if let Expression::Select(mut select) = scope.expression.clone() {
743 qualify_outputs_select(&mut select)?;
744 }
745 Ok(())
746}
747
748fn qualify_outputs_select(select: &mut Select) -> QualifyColumnsResult<()> {
749 let mut new_selections: Vec<Expression> = Vec::new();
750
751 for (i, expr) in select.expressions.iter().enumerate() {
752 match expr {
753 Expression::Alias(_) => new_selections.push(expr.clone()),
754 Expression::Column(col) => {
755 new_selections.push(create_alias(expr.clone(), &col.name.name));
756 }
757 Expression::Star(_) => new_selections.push(expr.clone()),
758 _ => {
759 let alias_name = get_output_name(expr).unwrap_or_else(|| format!("_col_{}", i));
760 new_selections.push(create_alias(expr.clone(), &alias_name));
761 }
762 }
763 }
764
765 select.expressions = new_selections;
766 Ok(())
767}
768
769fn qualify_columns_in_expression(
770 expr: &mut Expression,
771 scope: &Scope,
772 resolver: &mut Resolver,
773 allow_partial: bool,
774) -> QualifyColumnsResult<()> {
775 let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
776 let resolver_cell: RefCell<&mut Resolver> = RefCell::new(resolver);
777
778 let transformed = transform_recursive(expr.clone(), &|node| {
779 if first_error.borrow().is_some() {
780 return Ok(node);
781 }
782
783 match node {
784 Expression::Column(mut col) => {
785 if let Err(err) = qualify_single_column(
786 &mut col,
787 scope,
788 &mut resolver_cell.borrow_mut(),
789 allow_partial,
790 ) {
791 *first_error.borrow_mut() = Some(err);
792 }
793 Ok(Expression::Column(col))
794 }
795 _ => Ok(node),
796 }
797 })
798 .map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
799
800 if let Some(err) = first_error.into_inner() {
801 return Err(err);
802 }
803
804 *expr = transformed;
805 Ok(())
806}
807
808fn qualify_single_column(
809 col: &mut Column,
810 scope: &Scope,
811 resolver: &mut Resolver,
812 allow_partial: bool,
813) -> QualifyColumnsResult<()> {
814 if is_star_column(col) {
815 return Ok(());
816 }
817
818 if let Some(table) = &col.table {
819 let table_name = &table.name;
820 if !scope.sources.contains_key(table_name) {
821 if resolver.table_exists_in_schema(table_name) {
825 return Ok(());
826 }
827 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
828 }
829
830 if let Ok(source_columns) = resolver.get_source_columns(table_name) {
831 let normalized_column_name = normalize_column_name(&col.name.name, resolver.dialect);
832 if !allow_partial
833 && !source_columns.is_empty()
834 && !source_columns.iter().any(|column| {
835 normalize_column_name(column, resolver.dialect) == normalized_column_name
836 })
837 && !source_columns.contains(&"*".to_string())
838 {
839 return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
840 }
841 }
842 return Ok(());
843 }
844
845 if let Some(table_name) = resolver.get_table(&col.name.name) {
846 col.table = Some(Identifier::new(table_name));
847 return Ok(());
848 }
849
850 if let Some(outer_table) = resolver.find_column_in_outer_schema_tables(&col.name.name) {
853 col.table = Some(Identifier::new(outer_table));
854 return Ok(());
855 }
856
857 if !allow_partial {
858 return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
859 }
860
861 Ok(())
862}
863
/// Normalizes a column name for comparison by delegating to `normalize_name`
/// with the given dialect.
// NOTE(review): the meaning of the two boolean flags is defined by
// `normalize_name` and is not visible from here — confirm against its definition.
fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
    normalize_name(name, dialect, false, true)
}
867
868fn replace_alias_refs_in_expression(
869 expr: &mut Expression,
870 alias_to_expression: &HashMap<String, (Expression, usize)>,
871 literal_index: bool,
872) {
873 let transformed = transform_recursive(expr.clone(), &|node| match node {
874 Expression::Column(col) if col.table.is_none() => {
875 if let Some((alias_expr, index)) = alias_to_expression.get(&col.name.name) {
876 if literal_index && matches!(alias_expr, Expression::Literal(_)) {
877 return Ok(Expression::number(*index as i64));
878 }
879 return Ok(Expression::Paren(Box::new(Paren {
880 this: alias_expr.clone(),
881 trailing_comments: vec![],
882 })));
883 }
884 Ok(Expression::Column(col))
885 }
886 other => Ok(other),
887 });
888
889 if let Ok(next) = transformed {
890 *expr = next;
891 }
892}
893
894fn positional_reference(expr: &Expression) -> Option<usize> {
895 match expr {
896 Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
897 let Literal::Number(value) = lit.as_ref() else {
898 unreachable!()
899 };
900 value.parse::<usize>().ok()
901 }
902 _ => None,
903 }
904}
905
906fn select_expression_at_position(
907 projections: &[Expression],
908 index: usize,
909) -> QualifyColumnsResult<Expression> {
910 if index == 0 || index > projections.len() {
911 return Err(QualifyColumnsError::UnknownOutputColumn(index.to_string()));
912 }
913
914 let projection = projections[index - 1].clone();
915 Ok(match projection {
916 Expression::Alias(alias) => alias.this.clone(),
917 other => other,
918 })
919}
920
921fn get_reserved_words(dialect: Option<DialectType>) -> HashSet<&'static str> {
924 let mut words: HashSet<&'static str> = [
926 "ADD",
928 "ALL",
929 "ALTER",
930 "AND",
931 "ANY",
932 "AS",
933 "ASC",
934 "BETWEEN",
935 "BY",
936 "CASE",
937 "CAST",
938 "CHECK",
939 "COLUMN",
940 "CONSTRAINT",
941 "CREATE",
942 "CROSS",
943 "CURRENT",
944 "CURRENT_DATE",
945 "CURRENT_TIME",
946 "CURRENT_TIMESTAMP",
947 "CURRENT_USER",
948 "DATABASE",
949 "DEFAULT",
950 "DELETE",
951 "DESC",
952 "DISTINCT",
953 "DROP",
954 "ELSE",
955 "END",
956 "ESCAPE",
957 "EXCEPT",
958 "EXISTS",
959 "FALSE",
960 "FETCH",
961 "FOR",
962 "FOREIGN",
963 "FROM",
964 "FULL",
965 "GRANT",
966 "GROUP",
967 "HAVING",
968 "IF",
969 "IN",
970 "INDEX",
971 "INNER",
972 "INSERT",
973 "INTERSECT",
974 "INTO",
975 "IS",
976 "JOIN",
977 "KEY",
978 "LEFT",
979 "LIKE",
980 "LIMIT",
981 "NATURAL",
982 "NOT",
983 "NULL",
984 "OFFSET",
985 "ON",
986 "OR",
987 "ORDER",
988 "OUTER",
989 "PRIMARY",
990 "REFERENCES",
991 "REPLACE",
992 "RETURNING",
993 "RIGHT",
994 "ROLLBACK",
995 "ROW",
996 "ROWS",
997 "SELECT",
998 "SESSION_USER",
999 "SET",
1000 "SOME",
1001 "TABLE",
1002 "THEN",
1003 "TO",
1004 "TRUE",
1005 "TRUNCATE",
1006 "UNION",
1007 "UNIQUE",
1008 "UPDATE",
1009 "USING",
1010 "VALUES",
1011 "VIEW",
1012 "WHEN",
1013 "WHERE",
1014 "WINDOW",
1015 "WITH",
1016 ]
1017 .iter()
1018 .copied()
1019 .collect();
1020
1021 match dialect {
1023 Some(DialectType::MySQL) => {
1024 words.extend(
1025 [
1026 "ANALYZE",
1027 "BOTH",
1028 "CHANGE",
1029 "CONDITION",
1030 "DATABASES",
1031 "DAY_HOUR",
1032 "DAY_MICROSECOND",
1033 "DAY_MINUTE",
1034 "DAY_SECOND",
1035 "DELAYED",
1036 "DETERMINISTIC",
1037 "DIV",
1038 "DUAL",
1039 "EACH",
1040 "ELSEIF",
1041 "ENCLOSED",
1042 "EXPLAIN",
1043 "FLOAT4",
1044 "FLOAT8",
1045 "FORCE",
1046 "HOUR_MICROSECOND",
1047 "HOUR_MINUTE",
1048 "HOUR_SECOND",
1049 "IGNORE",
1050 "INFILE",
1051 "INT1",
1052 "INT2",
1053 "INT3",
1054 "INT4",
1055 "INT8",
1056 "ITERATE",
1057 "KEYS",
1058 "KILL",
1059 "LEADING",
1060 "LEAVE",
1061 "LINES",
1062 "LOAD",
1063 "LOCK",
1064 "LONG",
1065 "LONGBLOB",
1066 "LONGTEXT",
1067 "LOOP",
1068 "LOW_PRIORITY",
1069 "MATCH",
1070 "MEDIUMBLOB",
1071 "MEDIUMINT",
1072 "MEDIUMTEXT",
1073 "MINUTE_MICROSECOND",
1074 "MINUTE_SECOND",
1075 "MOD",
1076 "MODIFIES",
1077 "NO_WRITE_TO_BINLOG",
1078 "OPTIMIZE",
1079 "OPTIONALLY",
1080 "OUT",
1081 "OUTFILE",
1082 "PURGE",
1083 "READS",
1084 "REGEXP",
1085 "RELEASE",
1086 "RENAME",
1087 "REPEAT",
1088 "REQUIRE",
1089 "RESIGNAL",
1090 "RETURN",
1091 "REVOKE",
1092 "RLIKE",
1093 "SCHEMA",
1094 "SCHEMAS",
1095 "SECOND_MICROSECOND",
1096 "SENSITIVE",
1097 "SEPARATOR",
1098 "SHOW",
1099 "SIGNAL",
1100 "SPATIAL",
1101 "SQL",
1102 "SQLEXCEPTION",
1103 "SQLSTATE",
1104 "SQLWARNING",
1105 "SQL_BIG_RESULT",
1106 "SQL_CALC_FOUND_ROWS",
1107 "SQL_SMALL_RESULT",
1108 "SSL",
1109 "STARTING",
1110 "STRAIGHT_JOIN",
1111 "TERMINATED",
1112 "TINYBLOB",
1113 "TINYINT",
1114 "TINYTEXT",
1115 "TRAILING",
1116 "TRIGGER",
1117 "UNDO",
1118 "UNLOCK",
1119 "UNSIGNED",
1120 "USAGE",
1121 "UTC_DATE",
1122 "UTC_TIME",
1123 "UTC_TIMESTAMP",
1124 "VARBINARY",
1125 "VARCHARACTER",
1126 "WHILE",
1127 "WRITE",
1128 "XOR",
1129 "YEAR_MONTH",
1130 "ZEROFILL",
1131 ]
1132 .iter()
1133 .copied(),
1134 );
1135 }
1136 Some(DialectType::PostgreSQL) | Some(DialectType::CockroachDB) => {
1137 words.extend(
1138 [
1139 "ANALYSE",
1140 "ANALYZE",
1141 "ARRAY",
1142 "AUTHORIZATION",
1143 "BINARY",
1144 "BOTH",
1145 "COLLATE",
1146 "CONCURRENTLY",
1147 "DO",
1148 "FREEZE",
1149 "ILIKE",
1150 "INITIALLY",
1151 "ISNULL",
1152 "LATERAL",
1153 "LEADING",
1154 "LOCALTIME",
1155 "LOCALTIMESTAMP",
1156 "NOTNULL",
1157 "ONLY",
1158 "OVERLAPS",
1159 "PLACING",
1160 "SIMILAR",
1161 "SYMMETRIC",
1162 "TABLESAMPLE",
1163 "TRAILING",
1164 "VARIADIC",
1165 "VERBOSE",
1166 ]
1167 .iter()
1168 .copied(),
1169 );
1170 }
1171 Some(DialectType::BigQuery) => {
1172 words.extend(
1173 [
1174 "ASSERT_ROWS_MODIFIED",
1175 "COLLATE",
1176 "CONTAINS",
1177 "CUBE",
1178 "DEFINE",
1179 "ENUM",
1180 "EXTRACT",
1181 "FOLLOWING",
1182 "GROUPING",
1183 "GROUPS",
1184 "HASH",
1185 "IGNORE",
1186 "LATERAL",
1187 "LOOKUP",
1188 "MERGE",
1189 "NEW",
1190 "NO",
1191 "NULLS",
1192 "OF",
1193 "OVER",
1194 "PARTITION",
1195 "PRECEDING",
1196 "PROTO",
1197 "RANGE",
1198 "RECURSIVE",
1199 "RESPECT",
1200 "ROLLUP",
1201 "STRUCT",
1202 "TABLESAMPLE",
1203 "TREAT",
1204 "UNBOUNDED",
1205 "WITHIN",
1206 ]
1207 .iter()
1208 .copied(),
1209 );
1210 }
1211 Some(DialectType::Snowflake) => {
1212 words.extend(
1213 [
1214 "ACCOUNT",
1215 "BOTH",
1216 "CONNECT",
1217 "FOLLOWING",
1218 "ILIKE",
1219 "INCREMENT",
1220 "ISSUE",
1221 "LATERAL",
1222 "LEADING",
1223 "LOCALTIME",
1224 "LOCALTIMESTAMP",
1225 "MINUS",
1226 "QUALIFY",
1227 "REGEXP",
1228 "RLIKE",
1229 "SOME",
1230 "START",
1231 "TABLESAMPLE",
1232 "TOP",
1233 "TRAILING",
1234 "TRY_CAST",
1235 ]
1236 .iter()
1237 .copied(),
1238 );
1239 }
1240 Some(DialectType::TSQL) | Some(DialectType::Fabric) => {
1241 words.extend(
1242 [
1243 "BACKUP",
1244 "BREAK",
1245 "BROWSE",
1246 "BULK",
1247 "CASCADE",
1248 "CHECKPOINT",
1249 "CLOSE",
1250 "CLUSTERED",
1251 "COALESCE",
1252 "COMPUTE",
1253 "CONTAINS",
1254 "CONTAINSTABLE",
1255 "CONTINUE",
1256 "CONVERT",
1257 "DBCC",
1258 "DEALLOCATE",
1259 "DENY",
1260 "DISK",
1261 "DISTRIBUTED",
1262 "DUMP",
1263 "ERRLVL",
1264 "EXEC",
1265 "EXECUTE",
1266 "EXIT",
1267 "EXTERNAL",
1268 "FILE",
1269 "FILLFACTOR",
1270 "FREETEXT",
1271 "FREETEXTTABLE",
1272 "FUNCTION",
1273 "GOTO",
1274 "HOLDLOCK",
1275 "IDENTITY",
1276 "IDENTITYCOL",
1277 "IDENTITY_INSERT",
1278 "KILL",
1279 "LINENO",
1280 "MERGE",
1281 "NONCLUSTERED",
1282 "NULLIF",
1283 "OF",
1284 "OFF",
1285 "OFFSETS",
1286 "OPEN",
1287 "OPENDATASOURCE",
1288 "OPENQUERY",
1289 "OPENROWSET",
1290 "OPENXML",
1291 "OVER",
1292 "PERCENT",
1293 "PIVOT",
1294 "PLAN",
1295 "PRINT",
1296 "PROC",
1297 "PROCEDURE",
1298 "PUBLIC",
1299 "RAISERROR",
1300 "READ",
1301 "READTEXT",
1302 "RECONFIGURE",
1303 "REPLICATION",
1304 "RESTORE",
1305 "RESTRICT",
1306 "REVERT",
1307 "ROWCOUNT",
1308 "ROWGUIDCOL",
1309 "RULE",
1310 "SAVE",
1311 "SECURITYAUDIT",
1312 "SEMANTICKEYPHRASETABLE",
1313 "SEMANTICSIMILARITYDETAILSTABLE",
1314 "SEMANTICSIMILARITYTABLE",
1315 "SETUSER",
1316 "SHUTDOWN",
1317 "STATISTICS",
1318 "SYSTEM_USER",
1319 "TEXTSIZE",
1320 "TOP",
1321 "TRAN",
1322 "TRANSACTION",
1323 "TRIGGER",
1324 "TSEQUAL",
1325 "UNPIVOT",
1326 "UPDATETEXT",
1327 "WAITFOR",
1328 "WRITETEXT",
1329 ]
1330 .iter()
1331 .copied(),
1332 );
1333 }
1334 Some(DialectType::ClickHouse) => {
1335 words.extend(
1336 [
1337 "ANTI",
1338 "ARRAY",
1339 "ASOF",
1340 "FINAL",
1341 "FORMAT",
1342 "GLOBAL",
1343 "INF",
1344 "KILL",
1345 "MATERIALIZED",
1346 "NAN",
1347 "PREWHERE",
1348 "SAMPLE",
1349 "SEMI",
1350 "SETTINGS",
1351 "TOP",
1352 ]
1353 .iter()
1354 .copied(),
1355 );
1356 }
1357 Some(DialectType::DuckDB) => {
1358 words.extend(
1359 [
1360 "ANALYSE",
1361 "ANALYZE",
1362 "ARRAY",
1363 "BOTH",
1364 "LATERAL",
1365 "LEADING",
1366 "LOCALTIME",
1367 "LOCALTIMESTAMP",
1368 "PLACING",
1369 "QUALIFY",
1370 "SIMILAR",
1371 "TABLESAMPLE",
1372 "TRAILING",
1373 ]
1374 .iter()
1375 .copied(),
1376 );
1377 }
1378 Some(DialectType::Hive) | Some(DialectType::Spark) | Some(DialectType::Databricks) => {
1379 words.extend(
1380 [
1381 "BOTH",
1382 "CLUSTER",
1383 "DISTRIBUTE",
1384 "EXCHANGE",
1385 "EXTENDED",
1386 "FUNCTION",
1387 "LATERAL",
1388 "LEADING",
1389 "MACRO",
1390 "OVER",
1391 "PARTITION",
1392 "PERCENT",
1393 "RANGE",
1394 "READS",
1395 "REDUCE",
1396 "REGEXP",
1397 "REVOKE",
1398 "RLIKE",
1399 "ROLLUP",
1400 "SEMI",
1401 "SORT",
1402 "TABLESAMPLE",
1403 "TRAILING",
1404 "TRANSFORM",
1405 "UNBOUNDED",
1406 "UNIQUEJOIN",
1407 ]
1408 .iter()
1409 .copied(),
1410 );
1411 }
1412 Some(DialectType::Trino) | Some(DialectType::Presto) | Some(DialectType::Athena) => {
1413 words.extend(
1414 [
1415 "CUBE",
1416 "DEALLOCATE",
1417 "DESCRIBE",
1418 "EXECUTE",
1419 "EXTRACT",
1420 "GROUPING",
1421 "LATERAL",
1422 "LOCALTIME",
1423 "LOCALTIMESTAMP",
1424 "NORMALIZE",
1425 "PREPARE",
1426 "ROLLUP",
1427 "SOME",
1428 "TABLESAMPLE",
1429 "UESCAPE",
1430 "UNNEST",
1431 ]
1432 .iter()
1433 .copied(),
1434 );
1435 }
1436 Some(DialectType::Oracle) => {
1437 words.extend(
1438 [
1439 "ACCESS",
1440 "AUDIT",
1441 "CLUSTER",
1442 "COMMENT",
1443 "COMPRESS",
1444 "CONNECT",
1445 "EXCLUSIVE",
1446 "FILE",
1447 "IDENTIFIED",
1448 "IMMEDIATE",
1449 "INCREMENT",
1450 "INITIAL",
1451 "LEVEL",
1452 "LOCK",
1453 "LONG",
1454 "MAXEXTENTS",
1455 "MINUS",
1456 "MODE",
1457 "NOAUDIT",
1458 "NOCOMPRESS",
1459 "NOWAIT",
1460 "NUMBER",
1461 "OF",
1462 "OFFLINE",
1463 "ONLINE",
1464 "PCTFREE",
1465 "PRIOR",
1466 "RAW",
1467 "RENAME",
1468 "RESOURCE",
1469 "REVOKE",
1470 "SHARE",
1471 "SIZE",
1472 "START",
1473 "SUCCESSFUL",
1474 "SYNONYM",
1475 "SYSDATE",
1476 "TRIGGER",
1477 "UID",
1478 "VALIDATE",
1479 "VARCHAR2",
1480 "WHENEVER",
1481 ]
1482 .iter()
1483 .copied(),
1484 );
1485 }
1486 Some(DialectType::Redshift) => {
1487 words.extend(
1488 [
1489 "AZ64",
1490 "BZIP2",
1491 "DELTA",
1492 "DELTA32K",
1493 "DISTSTYLE",
1494 "ENCODE",
1495 "GZIP",
1496 "ILIKE",
1497 "LIMIT",
1498 "LUNS",
1499 "LZO",
1500 "LZOP",
1501 "MOSTLY13",
1502 "MOSTLY32",
1503 "MOSTLY8",
1504 "RAW",
1505 "SIMILAR",
1506 "SNAPSHOT",
1507 "SORTKEY",
1508 "SYSDATE",
1509 "TOP",
1510 "ZSTD",
1511 ]
1512 .iter()
1513 .copied(),
1514 );
1515 }
1516 _ => {
1517 words.extend(
1519 [
1520 "ANALYZE",
1521 "ARRAY",
1522 "BOTH",
1523 "CUBE",
1524 "GROUPING",
1525 "LATERAL",
1526 "LEADING",
1527 "LOCALTIME",
1528 "LOCALTIMESTAMP",
1529 "OVER",
1530 "PARTITION",
1531 "QUALIFY",
1532 "RANGE",
1533 "ROLLUP",
1534 "SIMILAR",
1535 "SOME",
1536 "TABLESAMPLE",
1537 "TRAILING",
1538 ]
1539 .iter()
1540 .copied(),
1541 );
1542 }
1543 }
1544
1545 words
1546}
1547
/// Reports whether `name` has to be quoted to be usable as an identifier.
///
/// An empty name never needs quoting. Otherwise quoting is required when the
/// name starts with an ASCII digit, contains any byte other than an ASCII
/// letter, ASCII digit or `_`, or matches an entry of `reserved_words` after
/// upper-casing.
fn needs_quoting(name: &str, reserved_words: &HashSet<&str>) -> bool {
    let Some(&first) = name.as_bytes().first() else {
        return false;
    };

    let is_plain = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    if first.is_ascii_digit() || !name.bytes().all(is_plain) {
        return true;
    }

    // Only ASCII letters, digits and `_` remain here, so ASCII upper-casing
    // yields the same key as full Unicode upper-casing.
    reserved_words.contains(name.to_ascii_uppercase().as_str())
}
1574
1575fn maybe_quote(id: &mut Identifier, reserved_words: &HashSet<&str>) {
1577 if id.quoted || id.name.is_empty() || id.name == "*" {
1580 return;
1581 }
1582 if needs_quoting(&id.name, reserved_words) {
1583 id.quoted = true;
1584 }
1585}
1586
1587fn quote_identifiers_recursive(expr: &mut Expression, reserved_words: &HashSet<&str>) {
1589 match expr {
1590 Expression::Identifier(id) => {
1592 maybe_quote(id, reserved_words);
1593 }
1594
1595 Expression::Column(col) => {
1596 maybe_quote(&mut col.name, reserved_words);
1597 if let Some(ref mut table) = col.table {
1598 maybe_quote(table, reserved_words);
1599 }
1600 }
1601
1602 Expression::Table(table_ref) => {
1603 maybe_quote(&mut table_ref.name, reserved_words);
1604 if let Some(ref mut schema) = table_ref.schema {
1605 maybe_quote(schema, reserved_words);
1606 }
1607 if let Some(ref mut catalog) = table_ref.catalog {
1608 maybe_quote(catalog, reserved_words);
1609 }
1610 if let Some(ref mut alias) = table_ref.alias {
1611 maybe_quote(alias, reserved_words);
1612 }
1613 for ca in &mut table_ref.column_aliases {
1614 maybe_quote(ca, reserved_words);
1615 }
1616 for p in &mut table_ref.partitions {
1617 maybe_quote(p, reserved_words);
1618 }
1619 for h in &mut table_ref.hints {
1621 quote_identifiers_recursive(h, reserved_words);
1622 }
1623 if let Some(ref mut ver) = table_ref.version {
1624 quote_identifiers_recursive(&mut ver.this, reserved_words);
1625 if let Some(ref mut e) = ver.expression {
1626 quote_identifiers_recursive(e, reserved_words);
1627 }
1628 }
1629 }
1630
1631 Expression::Star(star) => {
1632 if let Some(ref mut table) = star.table {
1633 maybe_quote(table, reserved_words);
1634 }
1635 if let Some(ref mut except_ids) = star.except {
1636 for id in except_ids {
1637 maybe_quote(id, reserved_words);
1638 }
1639 }
1640 if let Some(ref mut replace_aliases) = star.replace {
1641 for alias in replace_aliases {
1642 maybe_quote(&mut alias.alias, reserved_words);
1643 quote_identifiers_recursive(&mut alias.this, reserved_words);
1644 }
1645 }
1646 if let Some(ref mut rename_pairs) = star.rename {
1647 for (from, to) in rename_pairs {
1648 maybe_quote(from, reserved_words);
1649 maybe_quote(to, reserved_words);
1650 }
1651 }
1652 }
1653
1654 Expression::Alias(alias) => {
1656 maybe_quote(&mut alias.alias, reserved_words);
1657 for ca in &mut alias.column_aliases {
1658 maybe_quote(ca, reserved_words);
1659 }
1660 quote_identifiers_recursive(&mut alias.this, reserved_words);
1661 }
1662
1663 Expression::Select(select) => {
1665 for e in &mut select.expressions {
1666 quote_identifiers_recursive(e, reserved_words);
1667 }
1668 if let Some(ref mut from) = select.from {
1669 for e in &mut from.expressions {
1670 quote_identifiers_recursive(e, reserved_words);
1671 }
1672 }
1673 for join in &mut select.joins {
1674 quote_join(join, reserved_words);
1675 }
1676 for lv in &mut select.lateral_views {
1677 quote_lateral_view(lv, reserved_words);
1678 }
1679 if let Some(ref mut prewhere) = select.prewhere {
1680 quote_identifiers_recursive(prewhere, reserved_words);
1681 }
1682 if let Some(ref mut wh) = select.where_clause {
1683 quote_identifiers_recursive(&mut wh.this, reserved_words);
1684 }
1685 if let Some(ref mut gb) = select.group_by {
1686 for e in &mut gb.expressions {
1687 quote_identifiers_recursive(e, reserved_words);
1688 }
1689 }
1690 if let Some(ref mut hv) = select.having {
1691 quote_identifiers_recursive(&mut hv.this, reserved_words);
1692 }
1693 if let Some(ref mut q) = select.qualify {
1694 quote_identifiers_recursive(&mut q.this, reserved_words);
1695 }
1696 if let Some(ref mut ob) = select.order_by {
1697 for o in &mut ob.expressions {
1698 quote_identifiers_recursive(&mut o.this, reserved_words);
1699 }
1700 }
1701 if let Some(ref mut lim) = select.limit {
1702 quote_identifiers_recursive(&mut lim.this, reserved_words);
1703 }
1704 if let Some(ref mut off) = select.offset {
1705 quote_identifiers_recursive(&mut off.this, reserved_words);
1706 }
1707 if let Some(ref mut with) = select.with {
1708 quote_with(with, reserved_words);
1709 }
1710 if let Some(ref mut windows) = select.windows {
1711 for nw in windows {
1712 maybe_quote(&mut nw.name, reserved_words);
1713 quote_over(&mut nw.spec, reserved_words);
1714 }
1715 }
1716 if let Some(ref mut distinct_on) = select.distinct_on {
1717 for e in distinct_on {
1718 quote_identifiers_recursive(e, reserved_words);
1719 }
1720 }
1721 if let Some(ref mut limit_by) = select.limit_by {
1722 for e in limit_by {
1723 quote_identifiers_recursive(e, reserved_words);
1724 }
1725 }
1726 if let Some(ref mut settings) = select.settings {
1727 for e in settings {
1728 quote_identifiers_recursive(e, reserved_words);
1729 }
1730 }
1731 if let Some(ref mut format) = select.format {
1732 quote_identifiers_recursive(format, reserved_words);
1733 }
1734 }
1735
1736 Expression::Union(u) => {
1738 quote_identifiers_recursive(&mut u.left, reserved_words);
1739 quote_identifiers_recursive(&mut u.right, reserved_words);
1740 if let Some(ref mut with) = u.with {
1741 quote_with(with, reserved_words);
1742 }
1743 if let Some(ref mut ob) = u.order_by {
1744 for o in &mut ob.expressions {
1745 quote_identifiers_recursive(&mut o.this, reserved_words);
1746 }
1747 }
1748 if let Some(ref mut lim) = u.limit {
1749 quote_identifiers_recursive(lim, reserved_words);
1750 }
1751 if let Some(ref mut off) = u.offset {
1752 quote_identifiers_recursive(off, reserved_words);
1753 }
1754 }
1755 Expression::Intersect(i) => {
1756 quote_identifiers_recursive(&mut i.left, reserved_words);
1757 quote_identifiers_recursive(&mut i.right, reserved_words);
1758 if let Some(ref mut with) = i.with {
1759 quote_with(with, reserved_words);
1760 }
1761 if let Some(ref mut ob) = i.order_by {
1762 for o in &mut ob.expressions {
1763 quote_identifiers_recursive(&mut o.this, reserved_words);
1764 }
1765 }
1766 }
1767 Expression::Except(e) => {
1768 quote_identifiers_recursive(&mut e.left, reserved_words);
1769 quote_identifiers_recursive(&mut e.right, reserved_words);
1770 if let Some(ref mut with) = e.with {
1771 quote_with(with, reserved_words);
1772 }
1773 if let Some(ref mut ob) = e.order_by {
1774 for o in &mut ob.expressions {
1775 quote_identifiers_recursive(&mut o.this, reserved_words);
1776 }
1777 }
1778 }
1779
1780 Expression::Subquery(sq) => {
1782 quote_identifiers_recursive(&mut sq.this, reserved_words);
1783 if let Some(ref mut alias) = sq.alias {
1784 maybe_quote(alias, reserved_words);
1785 }
1786 for ca in &mut sq.column_aliases {
1787 maybe_quote(ca, reserved_words);
1788 }
1789 if let Some(ref mut ob) = sq.order_by {
1790 for o in &mut ob.expressions {
1791 quote_identifiers_recursive(&mut o.this, reserved_words);
1792 }
1793 }
1794 }
1795
1796 Expression::Insert(ins) => {
1798 quote_table_ref(&mut ins.table, reserved_words);
1799 for c in &mut ins.columns {
1800 maybe_quote(c, reserved_words);
1801 }
1802 for row in &mut ins.values {
1803 for e in row {
1804 quote_identifiers_recursive(e, reserved_words);
1805 }
1806 }
1807 if let Some(ref mut q) = ins.query {
1808 quote_identifiers_recursive(q, reserved_words);
1809 }
1810 for (id, val) in &mut ins.partition {
1811 maybe_quote(id, reserved_words);
1812 if let Some(ref mut v) = val {
1813 quote_identifiers_recursive(v, reserved_words);
1814 }
1815 }
1816 for e in &mut ins.returning {
1817 quote_identifiers_recursive(e, reserved_words);
1818 }
1819 if let Some(ref mut on_conflict) = ins.on_conflict {
1820 quote_identifiers_recursive(on_conflict, reserved_words);
1821 }
1822 if let Some(ref mut with) = ins.with {
1823 quote_with(with, reserved_words);
1824 }
1825 if let Some(ref mut alias) = ins.alias {
1826 maybe_quote(alias, reserved_words);
1827 }
1828 if let Some(ref mut src_alias) = ins.source_alias {
1829 maybe_quote(src_alias, reserved_words);
1830 }
1831 }
1832
1833 Expression::Update(upd) => {
1834 quote_table_ref(&mut upd.table, reserved_words);
1835 for tr in &mut upd.extra_tables {
1836 quote_table_ref(tr, reserved_words);
1837 }
1838 for join in &mut upd.table_joins {
1839 quote_join(join, reserved_words);
1840 }
1841 for (id, val) in &mut upd.set {
1842 maybe_quote(id, reserved_words);
1843 quote_identifiers_recursive(val, reserved_words);
1844 }
1845 if let Some(ref mut from) = upd.from_clause {
1846 for e in &mut from.expressions {
1847 quote_identifiers_recursive(e, reserved_words);
1848 }
1849 }
1850 for join in &mut upd.from_joins {
1851 quote_join(join, reserved_words);
1852 }
1853 if let Some(ref mut wh) = upd.where_clause {
1854 quote_identifiers_recursive(&mut wh.this, reserved_words);
1855 }
1856 for e in &mut upd.returning {
1857 quote_identifiers_recursive(e, reserved_words);
1858 }
1859 if let Some(ref mut with) = upd.with {
1860 quote_with(with, reserved_words);
1861 }
1862 }
1863
1864 Expression::Delete(del) => {
1865 quote_table_ref(&mut del.table, reserved_words);
1866 if let Some(ref mut alias) = del.alias {
1867 maybe_quote(alias, reserved_words);
1868 }
1869 for tr in &mut del.using {
1870 quote_table_ref(tr, reserved_words);
1871 }
1872 if let Some(ref mut wh) = del.where_clause {
1873 quote_identifiers_recursive(&mut wh.this, reserved_words);
1874 }
1875 if let Some(ref mut with) = del.with {
1876 quote_with(with, reserved_words);
1877 }
1878 }
1879
1880 Expression::And(bin)
1882 | Expression::Or(bin)
1883 | Expression::Eq(bin)
1884 | Expression::Neq(bin)
1885 | Expression::Lt(bin)
1886 | Expression::Lte(bin)
1887 | Expression::Gt(bin)
1888 | Expression::Gte(bin)
1889 | Expression::Add(bin)
1890 | Expression::Sub(bin)
1891 | Expression::Mul(bin)
1892 | Expression::Div(bin)
1893 | Expression::Mod(bin)
1894 | Expression::BitwiseAnd(bin)
1895 | Expression::BitwiseOr(bin)
1896 | Expression::BitwiseXor(bin)
1897 | Expression::Concat(bin)
1898 | Expression::Adjacent(bin)
1899 | Expression::TsMatch(bin)
1900 | Expression::PropertyEQ(bin)
1901 | Expression::ArrayContainsAll(bin)
1902 | Expression::ArrayContainedBy(bin)
1903 | Expression::ArrayOverlaps(bin)
1904 | Expression::JSONBContainsAllTopKeys(bin)
1905 | Expression::JSONBContainsAnyTopKeys(bin)
1906 | Expression::JSONBDeleteAtPath(bin)
1907 | Expression::ExtendsLeft(bin)
1908 | Expression::ExtendsRight(bin)
1909 | Expression::Is(bin)
1910 | Expression::NullSafeEq(bin)
1911 | Expression::NullSafeNeq(bin)
1912 | Expression::Glob(bin)
1913 | Expression::Match(bin)
1914 | Expression::MemberOf(bin)
1915 | Expression::BitwiseLeftShift(bin)
1916 | Expression::BitwiseRightShift(bin) => {
1917 quote_identifiers_recursive(&mut bin.left, reserved_words);
1918 quote_identifiers_recursive(&mut bin.right, reserved_words);
1919 }
1920
1921 Expression::Like(like) | Expression::ILike(like) => {
1923 quote_identifiers_recursive(&mut like.left, reserved_words);
1924 quote_identifiers_recursive(&mut like.right, reserved_words);
1925 if let Some(ref mut esc) = like.escape {
1926 quote_identifiers_recursive(esc, reserved_words);
1927 }
1928 }
1929
1930 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
1932 quote_identifiers_recursive(&mut un.this, reserved_words);
1933 }
1934
1935 Expression::In(in_expr) => {
1937 quote_identifiers_recursive(&mut in_expr.this, reserved_words);
1938 for e in &mut in_expr.expressions {
1939 quote_identifiers_recursive(e, reserved_words);
1940 }
1941 if let Some(ref mut q) = in_expr.query {
1942 quote_identifiers_recursive(q, reserved_words);
1943 }
1944 if let Some(ref mut un) = in_expr.unnest {
1945 quote_identifiers_recursive(un, reserved_words);
1946 }
1947 }
1948
1949 Expression::Between(bw) => {
1950 quote_identifiers_recursive(&mut bw.this, reserved_words);
1951 quote_identifiers_recursive(&mut bw.low, reserved_words);
1952 quote_identifiers_recursive(&mut bw.high, reserved_words);
1953 }
1954
1955 Expression::IsNull(is_null) => {
1956 quote_identifiers_recursive(&mut is_null.this, reserved_words);
1957 }
1958
1959 Expression::IsTrue(is_tf) | Expression::IsFalse(is_tf) => {
1960 quote_identifiers_recursive(&mut is_tf.this, reserved_words);
1961 }
1962
1963 Expression::Exists(ex) => {
1964 quote_identifiers_recursive(&mut ex.this, reserved_words);
1965 }
1966
1967 Expression::Function(func) => {
1969 for arg in &mut func.args {
1970 quote_identifiers_recursive(arg, reserved_words);
1971 }
1972 }
1973
1974 Expression::AggregateFunction(agg) => {
1975 for arg in &mut agg.args {
1976 quote_identifiers_recursive(arg, reserved_words);
1977 }
1978 if let Some(ref mut filter) = agg.filter {
1979 quote_identifiers_recursive(filter, reserved_words);
1980 }
1981 for o in &mut agg.order_by {
1982 quote_identifiers_recursive(&mut o.this, reserved_words);
1983 }
1984 }
1985
1986 Expression::WindowFunction(wf) => {
1987 quote_identifiers_recursive(&mut wf.this, reserved_words);
1988 quote_over(&mut wf.over, reserved_words);
1989 }
1990
1991 Expression::Case(case) => {
1993 if let Some(ref mut operand) = case.operand {
1994 quote_identifiers_recursive(operand, reserved_words);
1995 }
1996 for (when, then) in &mut case.whens {
1997 quote_identifiers_recursive(when, reserved_words);
1998 quote_identifiers_recursive(then, reserved_words);
1999 }
2000 if let Some(ref mut else_) = case.else_ {
2001 quote_identifiers_recursive(else_, reserved_words);
2002 }
2003 }
2004
2005 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
2007 quote_identifiers_recursive(&mut cast.this, reserved_words);
2008 if let Some(ref mut fmt) = cast.format {
2009 quote_identifiers_recursive(fmt, reserved_words);
2010 }
2011 }
2012
2013 Expression::Paren(paren) => {
2015 quote_identifiers_recursive(&mut paren.this, reserved_words);
2016 }
2017
2018 Expression::Annotated(ann) => {
2019 quote_identifiers_recursive(&mut ann.this, reserved_words);
2020 }
2021
2022 Expression::With(with) => {
2024 quote_with(with, reserved_words);
2025 }
2026
2027 Expression::Cte(cte) => {
2028 maybe_quote(&mut cte.alias, reserved_words);
2029 for c in &mut cte.columns {
2030 maybe_quote(c, reserved_words);
2031 }
2032 quote_identifiers_recursive(&mut cte.this, reserved_words);
2033 }
2034
2035 Expression::From(from) => {
2037 for e in &mut from.expressions {
2038 quote_identifiers_recursive(e, reserved_words);
2039 }
2040 }
2041
2042 Expression::Join(join) => {
2043 quote_join(join, reserved_words);
2044 }
2045
2046 Expression::JoinedTable(jt) => {
2047 quote_identifiers_recursive(&mut jt.left, reserved_words);
2048 for join in &mut jt.joins {
2049 quote_join(join, reserved_words);
2050 }
2051 if let Some(ref mut alias) = jt.alias {
2052 maybe_quote(alias, reserved_words);
2053 }
2054 }
2055
2056 Expression::Where(wh) => {
2057 quote_identifiers_recursive(&mut wh.this, reserved_words);
2058 }
2059
2060 Expression::GroupBy(gb) => {
2061 for e in &mut gb.expressions {
2062 quote_identifiers_recursive(e, reserved_words);
2063 }
2064 }
2065
2066 Expression::Having(hv) => {
2067 quote_identifiers_recursive(&mut hv.this, reserved_words);
2068 }
2069
2070 Expression::OrderBy(ob) => {
2071 for o in &mut ob.expressions {
2072 quote_identifiers_recursive(&mut o.this, reserved_words);
2073 }
2074 }
2075
2076 Expression::Ordered(ord) => {
2077 quote_identifiers_recursive(&mut ord.this, reserved_words);
2078 }
2079
2080 Expression::Limit(lim) => {
2081 quote_identifiers_recursive(&mut lim.this, reserved_words);
2082 }
2083
2084 Expression::Offset(off) => {
2085 quote_identifiers_recursive(&mut off.this, reserved_words);
2086 }
2087
2088 Expression::Qualify(q) => {
2089 quote_identifiers_recursive(&mut q.this, reserved_words);
2090 }
2091
2092 Expression::Window(ws) => {
2093 for e in &mut ws.partition_by {
2094 quote_identifiers_recursive(e, reserved_words);
2095 }
2096 for o in &mut ws.order_by {
2097 quote_identifiers_recursive(&mut o.this, reserved_words);
2098 }
2099 }
2100
2101 Expression::Over(over) => {
2102 quote_over(over, reserved_words);
2103 }
2104
2105 Expression::WithinGroup(wg) => {
2106 quote_identifiers_recursive(&mut wg.this, reserved_words);
2107 for o in &mut wg.order_by {
2108 quote_identifiers_recursive(&mut o.this, reserved_words);
2109 }
2110 }
2111
2112 Expression::Pivot(piv) => {
2114 quote_identifiers_recursive(&mut piv.this, reserved_words);
2115 for e in &mut piv.expressions {
2116 quote_identifiers_recursive(e, reserved_words);
2117 }
2118 for f in &mut piv.fields {
2119 quote_identifiers_recursive(f, reserved_words);
2120 }
2121 if let Some(ref mut alias) = piv.alias {
2122 maybe_quote(alias, reserved_words);
2123 }
2124 }
2125
2126 Expression::Unpivot(unpiv) => {
2127 quote_identifiers_recursive(&mut unpiv.this, reserved_words);
2128 maybe_quote(&mut unpiv.value_column, reserved_words);
2129 maybe_quote(&mut unpiv.name_column, reserved_words);
2130 for e in &mut unpiv.columns {
2131 quote_identifiers_recursive(e, reserved_words);
2132 }
2133 if let Some(ref mut alias) = unpiv.alias {
2134 maybe_quote(alias, reserved_words);
2135 }
2136 }
2137
2138 Expression::Values(vals) => {
2140 for tuple in &mut vals.expressions {
2141 for e in &mut tuple.expressions {
2142 quote_identifiers_recursive(e, reserved_words);
2143 }
2144 }
2145 if let Some(ref mut alias) = vals.alias {
2146 maybe_quote(alias, reserved_words);
2147 }
2148 for ca in &mut vals.column_aliases {
2149 maybe_quote(ca, reserved_words);
2150 }
2151 }
2152
2153 Expression::Array(arr) => {
2155 for e in &mut arr.expressions {
2156 quote_identifiers_recursive(e, reserved_words);
2157 }
2158 }
2159
2160 Expression::Struct(st) => {
2161 for (_name, e) in &mut st.fields {
2162 quote_identifiers_recursive(e, reserved_words);
2163 }
2164 }
2165
2166 Expression::Tuple(tup) => {
2167 for e in &mut tup.expressions {
2168 quote_identifiers_recursive(e, reserved_words);
2169 }
2170 }
2171
2172 Expression::Subscript(sub) => {
2174 quote_identifiers_recursive(&mut sub.this, reserved_words);
2175 quote_identifiers_recursive(&mut sub.index, reserved_words);
2176 }
2177
2178 Expression::Dot(dot) => {
2179 quote_identifiers_recursive(&mut dot.this, reserved_words);
2180 maybe_quote(&mut dot.field, reserved_words);
2181 }
2182
2183 Expression::ScopeResolution(sr) => {
2184 if let Some(ref mut this) = sr.this {
2185 quote_identifiers_recursive(this, reserved_words);
2186 }
2187 quote_identifiers_recursive(&mut sr.expression, reserved_words);
2188 }
2189
2190 Expression::Lateral(lat) => {
2192 quote_identifiers_recursive(&mut lat.this, reserved_words);
2193 }
2195
2196 Expression::DPipe(dpipe) => {
2198 quote_identifiers_recursive(&mut dpipe.this, reserved_words);
2199 quote_identifiers_recursive(&mut dpipe.expression, reserved_words);
2200 }
2201
2202 Expression::Merge(merge) => {
2204 quote_identifiers_recursive(&mut merge.this, reserved_words);
2205 quote_identifiers_recursive(&mut merge.using, reserved_words);
2206 if let Some(ref mut on) = merge.on {
2207 quote_identifiers_recursive(on, reserved_words);
2208 }
2209 if let Some(ref mut whens) = merge.whens {
2210 quote_identifiers_recursive(whens, reserved_words);
2211 }
2212 if let Some(ref mut with) = merge.with_ {
2213 quote_identifiers_recursive(with, reserved_words);
2214 }
2215 if let Some(ref mut ret) = merge.returning {
2216 quote_identifiers_recursive(ret, reserved_words);
2217 }
2218 }
2219
2220 Expression::LateralView(lv) => {
2222 quote_lateral_view(lv, reserved_words);
2223 }
2224
2225 Expression::Anonymous(anon) => {
2227 quote_identifiers_recursive(&mut anon.this, reserved_words);
2228 for e in &mut anon.expressions {
2229 quote_identifiers_recursive(e, reserved_words);
2230 }
2231 }
2232
2233 Expression::Filter(filter) => {
2235 quote_identifiers_recursive(&mut filter.this, reserved_words);
2236 quote_identifiers_recursive(&mut filter.expression, reserved_words);
2237 }
2238
2239 Expression::Returning(ret) => {
2241 for e in &mut ret.expressions {
2242 quote_identifiers_recursive(e, reserved_words);
2243 }
2244 }
2245
2246 Expression::BracedWildcard(inner) => {
2248 quote_identifiers_recursive(inner, reserved_words);
2249 }
2250
2251 Expression::ReturnStmt(inner) => {
2253 quote_identifiers_recursive(inner, reserved_words);
2254 }
2255
2256 Expression::Literal(_)
2258 | Expression::Boolean(_)
2259 | Expression::Null(_)
2260 | Expression::DataType(_)
2261 | Expression::Raw(_)
2262 | Expression::Placeholder(_)
2263 | Expression::CurrentDate(_)
2264 | Expression::CurrentTime(_)
2265 | Expression::CurrentTimestamp(_)
2266 | Expression::CurrentTimestampLTZ(_)
2267 | Expression::SessionUser(_)
2268 | Expression::RowNumber(_)
2269 | Expression::Rank(_)
2270 | Expression::DenseRank(_)
2271 | Expression::PercentRank(_)
2272 | Expression::CumeDist(_)
2273 | Expression::Random(_)
2274 | Expression::Pi(_)
2275 | Expression::JSONPathRoot(_) => {
2276 }
2278
2279 _ => {}
2283 }
2284}
2285
2286fn quote_join(join: &mut Join, reserved_words: &HashSet<&str>) {
2288 quote_identifiers_recursive(&mut join.this, reserved_words);
2289 if let Some(ref mut on) = join.on {
2290 quote_identifiers_recursive(on, reserved_words);
2291 }
2292 for id in &mut join.using {
2293 maybe_quote(id, reserved_words);
2294 }
2295 if let Some(ref mut mc) = join.match_condition {
2296 quote_identifiers_recursive(mc, reserved_words);
2297 }
2298 for piv in &mut join.pivots {
2299 quote_identifiers_recursive(piv, reserved_words);
2300 }
2301}
2302
2303fn quote_with(with: &mut With, reserved_words: &HashSet<&str>) {
2305 for cte in &mut with.ctes {
2306 maybe_quote(&mut cte.alias, reserved_words);
2307 for c in &mut cte.columns {
2308 maybe_quote(c, reserved_words);
2309 }
2310 for k in &mut cte.key_expressions {
2311 maybe_quote(k, reserved_words);
2312 }
2313 quote_identifiers_recursive(&mut cte.this, reserved_words);
2314 }
2315}
2316
2317fn quote_over(over: &mut Over, reserved_words: &HashSet<&str>) {
2319 if let Some(ref mut wn) = over.window_name {
2320 maybe_quote(wn, reserved_words);
2321 }
2322 for e in &mut over.partition_by {
2323 quote_identifiers_recursive(e, reserved_words);
2324 }
2325 for o in &mut over.order_by {
2326 quote_identifiers_recursive(&mut o.this, reserved_words);
2327 }
2328 if let Some(ref mut alias) = over.alias {
2329 maybe_quote(alias, reserved_words);
2330 }
2331}
2332
2333fn quote_table_ref(table_ref: &mut TableRef, reserved_words: &HashSet<&str>) {
2335 maybe_quote(&mut table_ref.name, reserved_words);
2336 if let Some(ref mut schema) = table_ref.schema {
2337 maybe_quote(schema, reserved_words);
2338 }
2339 if let Some(ref mut catalog) = table_ref.catalog {
2340 maybe_quote(catalog, reserved_words);
2341 }
2342 if let Some(ref mut alias) = table_ref.alias {
2343 maybe_quote(alias, reserved_words);
2344 }
2345 for ca in &mut table_ref.column_aliases {
2346 maybe_quote(ca, reserved_words);
2347 }
2348 for p in &mut table_ref.partitions {
2349 maybe_quote(p, reserved_words);
2350 }
2351 for h in &mut table_ref.hints {
2352 quote_identifiers_recursive(h, reserved_words);
2353 }
2354}
2355
2356fn quote_lateral_view(lv: &mut LateralView, reserved_words: &HashSet<&str>) {
2358 quote_identifiers_recursive(&mut lv.this, reserved_words);
2359 if let Some(ref mut ta) = lv.table_alias {
2360 maybe_quote(ta, reserved_words);
2361 }
2362 for ca in &mut lv.column_aliases {
2363 maybe_quote(ca, reserved_words);
2364 }
2365}
2366
2367pub fn quote_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
2378 let reserved_words = get_reserved_words(dialect);
2379 let mut result = expression;
2380 quote_identifiers_recursive(&mut result, &reserved_words);
2381 result
2382}
2383
2384pub fn pushdown_cte_alias_columns(_scope: &Scope) {
2389 }
2392
2393fn pushdown_cte_alias_columns_with(with: &mut With) {
2394 for cte in &mut with.ctes {
2395 if cte.columns.is_empty() {
2396 continue;
2397 }
2398
2399 if let Expression::Select(select) = &mut cte.this {
2400 let mut next_expressions = Vec::with_capacity(select.expressions.len());
2401
2402 for (i, projection) in select.expressions.iter().enumerate() {
2403 let Some(alias_name) = cte.columns.get(i) else {
2404 next_expressions.push(projection.clone());
2405 continue;
2406 };
2407
2408 match projection {
2409 Expression::Alias(existing) => {
2410 let mut aliased = existing.clone();
2411 aliased.alias = alias_name.clone();
2412 next_expressions.push(Expression::Alias(aliased));
2413 }
2414 _ => {
2415 next_expressions.push(create_alias(projection.clone(), &alias_name.name));
2416 }
2417 }
2418 }
2419
2420 select.expressions = next_expressions;
2421 }
2422 }
2423}
2424
2425fn get_scope_columns(scope: &Scope) -> Vec<ColumnRef> {
2431 let mut columns = Vec::new();
2432 collect_columns(&scope.expression, &mut columns);
2433 columns
2434}
2435
/// A column reference gathered from an expression tree, reduced to plain
/// strings.
#[derive(Clone, Debug)]
struct ColumnRef {
    /// Name of the table qualifier, or `None` when the column is unqualified.
    table: Option<String>,
    /// The column's own name.
    name: String,
}
2442
2443fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
2445 match expr {
2446 Expression::Column(col) => {
2447 columns.push(ColumnRef {
2448 table: col.table.as_ref().map(|t| t.name.clone()),
2449 name: col.name.name.clone(),
2450 });
2451 }
2452 Expression::Select(select) => {
2453 for e in &select.expressions {
2454 collect_columns(e, columns);
2455 }
2456 if let Some(from) = &select.from {
2457 for e in &from.expressions {
2458 collect_columns(e, columns);
2459 }
2460 }
2461 if let Some(where_clause) = &select.where_clause {
2462 collect_columns(&where_clause.this, columns);
2463 }
2464 if let Some(group_by) = &select.group_by {
2465 for e in &group_by.expressions {
2466 collect_columns(e, columns);
2467 }
2468 }
2469 if let Some(having) = &select.having {
2470 collect_columns(&having.this, columns);
2471 }
2472 if let Some(order_by) = &select.order_by {
2473 for o in &order_by.expressions {
2474 collect_columns(&o.this, columns);
2475 }
2476 }
2477 for join in &select.joins {
2478 collect_columns(&join.this, columns);
2479 if let Some(on) = &join.on {
2480 collect_columns(on, columns);
2481 }
2482 }
2483 }
2484 Expression::Alias(alias) => {
2485 collect_columns(&alias.this, columns);
2486 }
2487 Expression::Function(func) => {
2488 for arg in &func.args {
2489 collect_columns(arg, columns);
2490 }
2491 }
2492 Expression::AggregateFunction(agg) => {
2493 for arg in &agg.args {
2494 collect_columns(arg, columns);
2495 }
2496 }
2497 Expression::And(bin)
2498 | Expression::Or(bin)
2499 | Expression::Eq(bin)
2500 | Expression::Neq(bin)
2501 | Expression::Lt(bin)
2502 | Expression::Lte(bin)
2503 | Expression::Gt(bin)
2504 | Expression::Gte(bin)
2505 | Expression::Add(bin)
2506 | Expression::Sub(bin)
2507 | Expression::Mul(bin)
2508 | Expression::Div(bin) => {
2509 collect_columns(&bin.left, columns);
2510 collect_columns(&bin.right, columns);
2511 }
2512 Expression::Not(unary) | Expression::Neg(unary) => {
2513 collect_columns(&unary.this, columns);
2514 }
2515 Expression::Paren(paren) => {
2516 collect_columns(&paren.this, columns);
2517 }
2518 Expression::Case(case) => {
2519 if let Some(operand) = &case.operand {
2520 collect_columns(operand, columns);
2521 }
2522 for (when, then) in &case.whens {
2523 collect_columns(when, columns);
2524 collect_columns(then, columns);
2525 }
2526 if let Some(else_) = &case.else_ {
2527 collect_columns(else_, columns);
2528 }
2529 }
2530 Expression::Cast(cast) => {
2531 collect_columns(&cast.this, columns);
2532 }
2533 Expression::In(in_expr) => {
2534 collect_columns(&in_expr.this, columns);
2535 for e in &in_expr.expressions {
2536 collect_columns(e, columns);
2537 }
2538 if let Some(query) = &in_expr.query {
2539 collect_columns(query, columns);
2540 }
2541 }
2542 Expression::Between(between) => {
2543 collect_columns(&between.this, columns);
2544 collect_columns(&between.low, columns);
2545 collect_columns(&between.high, columns);
2546 }
2547 Expression::Subquery(subquery) => {
2548 collect_columns(&subquery.this, columns);
2549 }
2550 _ => {}
2551 }
2552}
2553
2554fn get_unqualified_columns(scope: &Scope) -> Vec<ColumnRef> {
2556 get_scope_columns(scope)
2557 .into_iter()
2558 .filter(|c| c.table.is_none())
2559 .collect()
2560}
2561
2562fn get_external_columns(scope: &Scope) -> Vec<ColumnRef> {
2564 let source_names: HashSet<_> = scope.sources.keys().cloned().collect();
2565
2566 get_scope_columns(scope)
2567 .into_iter()
2568 .filter(|c| {
2569 if let Some(table) = &c.table {
2570 !source_names.contains(table)
2571 } else {
2572 false
2573 }
2574 })
2575 .collect()
2576}
2577
2578fn is_correlated_subquery(scope: &Scope) -> bool {
2580 scope.can_be_correlated && !get_external_columns(scope).is_empty()
2581}
2582
2583fn is_star_column(col: &Column) -> bool {
2585 col.name.name == "*"
2586}
2587
2588fn create_qualified_column(name: &str, table: Option<&str>) -> Expression {
2590 Expression::boxed_column(Column {
2591 name: Identifier::new(name),
2592 table: table.map(Identifier::new),
2593 join_mark: false,
2594 trailing_comments: vec![],
2595 span: None,
2596 inferred_type: None,
2597 })
2598}
2599
2600fn create_alias(expr: Expression, alias_name: &str) -> Expression {
2602 Expression::Alias(Box::new(Alias {
2603 this: expr,
2604 alias: Identifier::new(alias_name),
2605 column_aliases: vec![],
2606 pre_alias_comments: vec![],
2607 trailing_comments: vec![],
2608 inferred_type: None,
2609 }))
2610}
2611
2612fn get_output_name(expr: &Expression) -> Option<String> {
2614 match expr {
2615 Expression::Column(col) => Some(col.name.name.clone()),
2616 Expression::Alias(alias) => Some(alias.alias.name.clone()),
2617 Expression::Identifier(id) => Some(id.name.clone()),
2618 _ => None,
2619 }
2620}
2621
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DataType;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::scope::build_scope;
    use crate::{MappingSchema, Schema};

    /// Renders an expression back to SQL text with a default `Generator`.
    ///
    /// Deliberately not named `gen`: that identifier is a reserved keyword in the Rust 2024
    /// edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement, panicking if parsing fails.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// The builder methods store the values they are given.
    #[test]
    fn test_qualify_columns_options() {
        let options = QualifyColumnsOptions::new()
            .with_expand_alias_refs(true)
            .with_expand_stars(false)
            .with_dialect(DialectType::PostgreSQL)
            .with_allow_partial(true);

        assert!(options.expand_alias_refs);
        assert!(!options.expand_stars);
        assert_eq!(options.dialect, Some(DialectType::PostgreSQL));
        assert!(options.allow_partial_qualification);
    }

    /// Columns from both the projection and the WHERE clause are reported for a scope.
    #[test]
    fn test_get_scope_columns() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let scope = build_scope(&expr);
        let columns = get_scope_columns(&scope);

        assert!(columns.iter().any(|c| c.name == "a"));
        assert!(columns.iter().any(|c| c.name == "b"));
        assert!(columns.iter().any(|c| c.name == "c"));
    }

    /// Only columns without a table qualifier are reported as unqualified.
    #[test]
    fn test_get_unqualified_columns() {
        let expr = parse("SELECT t.a, b FROM t");
        let scope = build_scope(&expr);
        let unqualified = get_unqualified_columns(&scope);

        assert!(unqualified.iter().any(|c| c.name == "b"));
        assert!(!unqualified.iter().any(|c| c.name == "a"));
    }

    /// A column named `*` is a star column; a regular column name is not.
    #[test]
    fn test_is_star_column() {
        let col = Column {
            name: Identifier::new("*"),
            table: Some(Identifier::new("t")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        };
        assert!(is_star_column(&col));

        let col2 = Column {
            name: Identifier::new("id"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        };
        assert!(!is_star_column(&col2));
    }

    /// A qualified column renders with both its table and its column name.
    #[test]
    fn test_create_qualified_column() {
        let expr = create_qualified_column("id", Some("users"));
        let sql = to_sql(&expr);
        assert!(sql.contains("users"));
        assert!(sql.contains("id"));
    }

    /// An aliased column renders both the aliased expression and the alias name.
    #[test]
    fn test_create_alias() {
        let col = Expression::boxed_column(Column {
            name: Identifier::new("value"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let aliased = create_alias(col, "total");
        let sql = to_sql(&aliased);
        // Both halves must survive; the previous `AS || total` check passed even when the
        // alias name was missing from the output.
        assert!(
            sql.contains("value"),
            "aliased expression should be kept: {sql}"
        );
        assert!(sql.contains("total"), "alias name should be emitted: {sql}");
    }

    /// Smoke test: `validate_qualify_columns` must not panic on a fully qualified query.
    ///
    /// NOTE(review): the result is intentionally discarded, so despite the name this does not
    /// assert success. Tighten it once the expected outcome is confirmed.
    #[test]
    fn test_validate_qualify_columns_success() {
        let expr = parse("SELECT t.a, t.b FROM t");
        let result = validate_qualify_columns(&expr);
        let _ = result;
    }

    /// Columns are collected from every clause of a SELECT.
    #[test]
    fn test_collect_columns_nested() {
        let expr = parse("SELECT a + b, c FROM t WHERE d > 0 GROUP BY e HAVING f = 1");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
        assert!(names.contains(&"d"));
        assert!(names.contains(&"e"));
        assert!(names.contains(&"f"));
    }

    /// Columns inside CASE conditions and branches are collected.
    #[test]
    fn test_collect_columns_in_case() {
        let expr = parse("SELECT CASE WHEN a = 1 THEN b ELSE c END FROM t");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
    }

    /// Columns inside an IN subquery are collected together with the outer ones.
    #[test]
    fn test_collect_columns_in_subquery() {
        let expr = parse("SELECT a FROM t WHERE b IN (SELECT c FROM s)");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
    }

    /// `qualify_outputs` succeeds on a projection mixing a column and an expression.
    #[test]
    fn test_qualify_outputs_basic() {
        let expr = parse("SELECT a, b + c FROM t");
        let scope = build_scope(&expr);
        let result = qualify_outputs(&scope);
        assert!(result.is_ok());
    }

    /// `SELECT *` is expanded into the qualified columns listed in the schema.
    #[test]
    fn test_qualify_columns_expands_star_with_schema() {
        let expr = parse("SELECT * FROM users");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "users",
                &[
                    (
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    ),
                    ("name".to_string(), DataType::Text),
                    ("email".to_string(), DataType::Text),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(!sql.contains("SELECT *"));
        assert!(sql.contains("users.id"));
        assert!(sql.contains("users.name"));
        assert!(sql.contains("users.email"));
    }

    /// Positional GROUP BY references are replaced by the qualified columns they point to.
    #[test]
    fn test_qualify_columns_expands_group_by_positions() {
        let expr = parse("SELECT a, b FROM t GROUP BY 1, 2");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "t",
                &[
                    (
                        "a".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    ),
                    (
                        "b".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    ),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(!sql.contains("GROUP BY 1"));
        assert!(!sql.contains("GROUP BY 2"));
        assert!(sql.contains("GROUP BY"));
        assert!(sql.contains("t.a"));
        assert!(sql.contains("t.b"));
    }

    /// `JOIN ... USING (b)` becomes an explicit ON condition; a qualified select column is kept.
    #[test]
    fn test_expand_using_simple() {
        let expr = parse("SELECT x.b FROM x JOIN y USING (b)");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "x",
                &[
                    ("a".to_string(), DataType::BigInt { length: None }),
                    ("b".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "y",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            !sql.contains("USING"),
            "USING should be replaced with ON: {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "ON condition should be x.b = y.b: {sql}"
        );
        assert!(sql.contains("SELECT x.b"), "SELECT should keep x.b: {sql}");
    }

    /// An unqualified USING column in the projection becomes an aliased COALESCE.
    #[test]
    fn test_expand_using_unqualified_coalesce() {
        let expr = parse("SELECT b FROM x JOIN y USING(b)");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "x",
                &[
                    ("a".to_string(), DataType::BigInt { length: None }),
                    ("b".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "y",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b)"),
            "Unqualified USING column should become COALESCE: {sql}"
        );
        assert!(
            sql.contains("AS b"),
            "COALESCE should be aliased as 'b': {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "ON condition should be generated: {sql}"
        );
    }

    /// An unqualified USING column referenced in WHERE is also rewritten to COALESCE.
    #[test]
    fn test_expand_using_with_where() {
        let expr = parse("SELECT b FROM x JOIN y USING(b) WHERE b = 1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "x",
                &[("b".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "y",
                &[("b".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("WHERE COALESCE(x.b, y.b)"),
            "WHERE should use COALESCE for USING column: {sql}"
        );
    }

    /// Chained USING joins on the same column produce a COALESCE over all joined tables.
    #[test]
    fn test_expand_using_multi_join() {
        let expr = parse("SELECT b FROM x JOIN y USING(b) JOIN z USING(b)");

        let mut schema = MappingSchema::new();
        for table in &["x", "y", "z"] {
            schema
                .add_table(
                    table,
                    &[("b".to_string(), DataType::BigInt { length: None })],
                    None,
                )
                .expect("schema setup");
        }

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b, z.b)"),
            "Should have 3-table COALESCE: {sql}"
        );
        assert!(
            sql.contains("ON x.b = y.b"),
            "First join ON condition: {sql}"
        );
    }

    /// A multi-column USING yields one COALESCE per column and one equality per column.
    #[test]
    fn test_expand_using_multi_column() {
        let expr = parse("SELECT b, c FROM y JOIN z USING(b, c)");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "y",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "z",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("COALESCE(y.b, z.b)"),
            "column 'b' should get COALESCE: {sql}"
        );
        assert!(
            sql.contains("COALESCE(y.c, z.c)"),
            "column 'c' should get COALESCE: {sql}"
        );
        assert!(
            sql.contains("y.b = z.b") && sql.contains("y.c = z.c"),
            "ON should have both equality conditions: {sql}"
        );
    }

    /// Star expansion emits the USING column once, as a COALESCE, plus the other columns.
    #[test]
    fn test_expand_using_star() {
        let expr = parse("SELECT * FROM x JOIN y USING(b)");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "x",
                &[
                    ("a".to_string(), DataType::BigInt { length: None }),
                    ("b".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "y",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b) AS b"),
            "USING column should be COALESCE in star expansion: {sql}"
        );
        assert!(sql.contains("x.a"), "non-USING column a from x: {sql}");
        assert!(sql.contains("y.c"), "non-USING column c from y: {sql}");
        let coalesce_count = sql.matches("COALESCE").count();
        assert_eq!(
            coalesce_count, 1,
            "b should appear only once as COALESCE: {sql}"
        );
    }

    /// A table-qualified star (`x.*`) also rewrites the USING column to COALESCE.
    #[test]
    fn test_expand_using_table_star() {
        let expr = parse("SELECT x.* FROM x JOIN y USING(b)");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "x",
                &[
                    ("a".to_string(), DataType::BigInt { length: None }),
                    ("b".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "y",
                &[
                    ("b".to_string(), DataType::BigInt { length: None }),
                    ("c".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("COALESCE(x.b, y.b)"),
            "USING column from x.* should become COALESCE: {sql}"
        );
        assert!(sql.contains("x.a"), "non-USING column a: {sql}");
    }

    /// Columns of a schema-qualified table (`raw.t1`) are qualified with the table name.
    #[test]
    fn test_qualify_columns_qualified_table_name() {
        let expr = parse("SELECT a FROM raw.t1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "raw.t1",
                &[("a".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("t1.a"),
            "column should be qualified with table name: {sql}"
        );
    }

    /// Already-qualified columns in a correlated scalar subquery are accepted and preserved.
    #[test]
    fn test_qualify_columns_correlated_scalar_subquery() {
        let expr =
            parse("SELECT id, (SELECT AVG(val) FROM t2 WHERE t2.id = t1.id) AS avg_val FROM t1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "t1",
                &[("id".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "t2",
                &[
                    ("id".to_string(), DataType::BigInt { length: None }),
                    ("val".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("t1.id"),
            "outer column should be qualified: {sql}"
        );
        assert!(
            sql.contains("t2.id"),
            "inner column should be qualified: {sql}"
        );
    }

    /// Unqualified columns in a correlated scalar subquery resolve to the right scope.
    #[test]
    fn test_qualify_columns_correlated_scalar_subquery_unqualified() {
        let expr =
            parse("SELECT t1_id, (SELECT AVG(val) FROM t2 WHERE t2_id = t1_id) AS avg_val FROM t1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "t1",
                &[("t1_id".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "t2",
                &[
                    ("t2_id".to_string(), DataType::BigInt { length: None }),
                    ("val".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("t1.t1_id"),
            "outer column should be qualified: {sql}"
        );
        assert!(
            sql.contains("t2.t2_id"),
            "inner column should be qualified: {sql}"
        );
        assert!(
            sql.contains("= t1.t1_id"),
            "correlated column should be qualified: {sql}"
        );
    }

    /// Unqualified columns in a correlated EXISTS subquery resolve to inner or outer tables.
    #[test]
    fn test_qualify_columns_correlated_exists_subquery() {
        let expr = parse(
            "SELECT o_orderpriority FROM orders \
             WHERE EXISTS (SELECT * FROM lineitem WHERE l_orderkey = o_orderkey)",
        );

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "orders",
                &[
                    ("o_orderpriority".to_string(), DataType::Text),
                    ("o_orderkey".to_string(), DataType::BigInt { length: None }),
                ],
                None,
            )
            .expect("schema setup");
        schema
            .add_table(
                "lineitem",
                &[("l_orderkey".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");

        let result =
            qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
        let sql = to_sql(&result);

        assert!(
            sql.contains("orders.o_orderpriority"),
            "outer column should be qualified: {sql}"
        );
        assert!(
            sql.contains("lineitem.l_orderkey"),
            "inner column should be qualified: {sql}"
        );
        assert!(
            sql.contains("orders.o_orderkey"),
            "correlated outer column should be qualified: {sql}"
        );
    }

    /// A column qualified with a table that is neither in scope nor in the schema is an error.
    #[test]
    fn test_qualify_columns_rejects_unknown_table() {
        let expr = parse("SELECT id FROM t1 WHERE nonexistent.col = 1");

        let mut schema = MappingSchema::new();
        schema
            .add_table(
                "t1",
                &[("id".to_string(), DataType::BigInt { length: None })],
                None,
            )
            .expect("schema setup");

        let result = qualify_columns(expr, &schema, &QualifyColumnsOptions::new());
        assert!(
            result.is_err(),
            "should reject reference to table not in scope or schema"
        );
    }

    /// Reserved words need quoting regardless of letter case.
    #[test]
    fn test_needs_quoting_reserved_word() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("select", &reserved));
        assert!(needs_quoting("SELECT", &reserved));
        assert!(needs_quoting("from", &reserved));
        assert!(needs_quoting("WHERE", &reserved));
        assert!(needs_quoting("join", &reserved));
        assert!(needs_quoting("table", &reserved));
    }

    /// Plain identifiers made of letters, digits and underscores do not need quoting.
    #[test]
    fn test_needs_quoting_normal_identifiers() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("foo", &reserved));
        assert!(!needs_quoting("my_column", &reserved));
        assert!(!needs_quoting("col1", &reserved));
        assert!(!needs_quoting("A", &reserved));
        assert!(!needs_quoting("_hidden", &reserved));
    }

    /// Identifiers containing a space, hyphen, dot, `@` or `#` need quoting.
    #[test]
    fn test_needs_quoting_special_characters() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("my column", &reserved));
        assert!(needs_quoting("my-column", &reserved));
        assert!(needs_quoting("my.column", &reserved));
        assert!(needs_quoting("col@name", &reserved));
        assert!(needs_quoting("col#name", &reserved));
    }

    /// Identifiers starting with a digit need quoting.
    #[test]
    fn test_needs_quoting_starts_with_digit() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("1col", &reserved));
        assert!(needs_quoting("123", &reserved));
        assert!(needs_quoting("0_start", &reserved));
    }

    /// The empty string is reported as not needing quoting.
    #[test]
    fn test_needs_quoting_empty() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("", &reserved));
    }

    /// `maybe_quote` flips the `quoted` flag on a reserved-word identifier.
    #[test]
    fn test_maybe_quote_sets_quoted_flag() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("select");
        assert!(!id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
    }

    /// An identifier that is already quoted stays quoted and keeps its name.
    #[test]
    fn test_maybe_quote_skips_already_quoted() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::quoted("myname");
        assert!(id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
        assert_eq!(id.name, "myname");
    }

    /// The star identifier is never quoted.
    #[test]
    fn test_maybe_quote_skips_star() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("*");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    /// An ordinary identifier is left unquoted.
    #[test]
    fn test_maybe_quote_skips_normal() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("normal_col");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    /// A column whose name is a reserved word gets quoted.
    #[test]
    fn test_quote_identifiers_column_with_reserved_name() {
        let expr = Expression::boxed_column(Column {
            name: Identifier::new("select"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column named 'select' should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    /// A column whose name contains a space gets quoted.
    #[test]
    fn test_quote_identifiers_column_with_special_chars() {
        let expr = Expression::boxed_column(Column {
            name: Identifier::new("my column"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column with space should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    /// Ordinary column and table names are left unquoted.
    #[test]
    fn test_quote_identifiers_preserves_normal_column() {
        let expr = Expression::boxed_column(Column {
            name: Identifier::new("normal_col"),
            table: Some(Identifier::new("my_table")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(!col.name.quoted, "Normal column should not be quoted");
            assert!(
                !col.table.as_ref().unwrap().quoted,
                "Normal table should not be quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    /// A table reference whose name is a reserved word gets quoted.
    #[test]
    fn test_quote_identifiers_table_ref_reserved() {
        let expr = Expression::Table(Box::new(TableRef::new("select")));
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(tr.name.quoted, "Table named 'select' should be quoted");
        } else {
            panic!("Expected Table expression");
        }
    }

    /// Name, schema and alias of a table reference are each quoted independently.
    #[test]
    fn test_quote_identifiers_table_ref_schema_and_alias() {
        let mut tr = TableRef::new("my_table");
        tr.schema = Some(Identifier::new("from"));
        tr.alias = Some(Identifier::new("t"));
        let expr = Expression::Table(Box::new(tr));
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(!tr.name.quoted, "Normal table name should not be quoted");
            assert!(
                tr.schema.as_ref().unwrap().quoted,
                "Schema named 'from' should be quoted"
            );
            assert!(
                !tr.alias.as_ref().unwrap().quoted,
                "Normal alias should not be quoted"
            );
        } else {
            panic!("Expected Table expression");
        }
    }

    /// A bare identifier node that is a reserved word gets quoted.
    #[test]
    fn test_quote_identifiers_identifier_node() {
        let expr = Expression::Identifier(Identifier::new("order"));
        let result = quote_identifiers(expr, None);
        if let Expression::Identifier(id) = &result {
            assert!(id.quoted, "Identifier named 'order' should be quoted");
        } else {
            panic!("Expected Identifier expression");
        }
    }

    /// Alias names and column aliases are quoted when reserved; the inner column is untouched.
    #[test]
    fn test_quote_identifiers_alias() {
        let inner = Expression::boxed_column(Column {
            name: Identifier::new("val"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let expr = Expression::Alias(Box::new(Alias {
            this: inner,
            alias: Identifier::new("select"),
            column_aliases: vec![Identifier::new("from")],
            pre_alias_comments: vec![],
            trailing_comments: vec![],
            inferred_type: None,
        }));
        let result = quote_identifiers(expr, None);
        if let Expression::Alias(alias) = &result {
            assert!(alias.alias.quoted, "Alias named 'select' should be quoted");
            assert!(
                alias.column_aliases[0].quoted,
                "Column alias named 'from' should be quoted"
            );
            // Fail loudly if the aliased expression changed shape instead of silently
            // skipping the assertion.
            if let Expression::Column(col) = &alias.this {
                assert!(!col.name.quoted);
            } else {
                panic!("Expected Column inside Alias");
            }
        } else {
            panic!("Expected Alias expression");
        }
    }

    /// Quoting a parsed query without reserved names keeps its identifiers in the output.
    #[test]
    fn test_quote_identifiers_select_recursive() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let result = quote_identifiers(expr, None);
        let sql = to_sql(&result);
        assert!(sql.contains("a"));
        assert!(sql.contains("b"));
        assert!(sql.contains("t"));
    }

    /// A column whose name starts with a digit gets quoted.
    #[test]
    fn test_quote_identifiers_digit_start() {
        let expr = Expression::boxed_column(Column {
            name: Identifier::new("1col"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(
                col.name.quoted,
                "Column starting with digit should be quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    /// The MySQL reserved-word list includes `KILL` and `FORCE`.
    #[test]
    fn test_quote_identifiers_with_mysql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::MySQL));
        assert!(needs_quoting("KILL", &reserved));
        assert!(needs_quoting("FORCE", &reserved));
    }

    /// The PostgreSQL reserved-word list includes `ILIKE` and `VERBOSE`.
    #[test]
    fn test_quote_identifiers_with_postgresql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::PostgreSQL));
        assert!(needs_quoting("ILIKE", &reserved));
        assert!(needs_quoting("VERBOSE", &reserved));
    }

    /// The BigQuery reserved-word list includes `STRUCT` and `PROTO`.
    #[test]
    fn test_quote_identifiers_with_bigquery_dialect() {
        let reserved = get_reserved_words(Some(DialectType::BigQuery));
        assert!(needs_quoting("STRUCT", &reserved));
        assert!(needs_quoting("PROTO", &reserved));
    }

    /// Reserved-word matching ignores letter case.
    #[test]
    fn test_quote_identifiers_case_insensitive_reserved() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("Select", &reserved));
        assert!(needs_quoting("sElEcT", &reserved));
        assert!(needs_quoting("FROM", &reserved));
        assert!(needs_quoting("from", &reserved));
    }

    /// `quote_join` quotes reserved identifiers in a USING list and leaves the others alone.
    #[test]
    fn test_quote_identifiers_join_using() {
        let mut join = crate::expressions::Join {
            this: Expression::Table(Box::new(TableRef::new("other"))),
            on: None,
            using: vec![Identifier::new("key"), Identifier::new("value")],
            kind: crate::expressions::JoinKind::Inner,
            use_inner_keyword: false,
            use_outer_keyword: false,
            deferred_condition: false,
            join_hint: None,
            match_condition: None,
            pivots: vec![],
            comments: vec![],
            nesting_group: 0,
            directed: false,
        };
        let reserved = get_reserved_words(None);
        quote_join(&mut join, &reserved);
        assert!(
            join.using[0].quoted,
            "USING identifier 'key' should be quoted"
        );
        assert!(
            !join.using[1].quoted,
            "USING identifier 'value' should not be quoted"
        );
    }

    /// `maybe_quote` applied to a CTE's alias and column list quotes only the reserved names.
    #[test]
    fn test_quote_identifiers_cte() {
        let mut cte = crate::expressions::Cte {
            alias: Identifier::new("select"),
            this: Expression::boxed_column(Column {
                name: Identifier::new("x"),
                table: None,
                join_mark: false,
                trailing_comments: vec![],
                span: None,
                inferred_type: None,
            }),
            columns: vec![Identifier::new("from"), Identifier::new("normal")],
            materialized: None,
            key_expressions: vec![],
            alias_first: false,
            comments: Vec::new(),
        };
        let reserved = get_reserved_words(None);
        maybe_quote(&mut cte.alias, &reserved);
        for c in &mut cte.columns {
            maybe_quote(c, &reserved);
        }
        assert!(cte.alias.quoted, "CTE alias 'select' should be quoted");
        assert!(cte.columns[0].quoted, "CTE column 'from' should be quoted");
        assert!(
            !cte.columns[1].quoted,
            "CTE column 'normal' should not be quoted"
        );
    }

    /// Quoting recurses into both operands of a binary operation.
    #[test]
    fn test_quote_identifiers_binary_ops_recurse() {
        let expr = Expression::Add(Box::new(crate::expressions::BinaryOp::new(
            Expression::boxed_column(Column {
                name: Identifier::new("select"),
                table: None,
                join_mark: false,
                trailing_comments: vec![],
                span: None,
                inferred_type: None,
            }),
            Expression::boxed_column(Column {
                name: Identifier::new("normal"),
                table: None,
                join_mark: false,
                trailing_comments: vec![],
                span: None,
                inferred_type: None,
            }),
        )));
        let result = quote_identifiers(expr, None);
        if let Expression::Add(bin) = &result {
            // Each operand must still be a column; otherwise the assertions below would be
            // skipped and the test would pass vacuously.
            if let Expression::Column(left) = &bin.left {
                assert!(
                    left.name.quoted,
                    "'select' column should be quoted in binary op"
                );
            } else {
                panic!("Expected Column as left operand");
            }
            if let Expression::Column(right) = &bin.right {
                assert!(!right.name.quoted, "'normal' column should not be quoted");
            } else {
                panic!("Expected Column as right operand");
            }
        } else {
            panic!("Expected Add expression");
        }
    }

    /// An identifier that was already quoted remains quoted even if it would not need it.
    #[test]
    fn test_quote_identifiers_already_quoted_preserved() {
        let expr = Expression::boxed_column(Column {
            name: Identifier::quoted("normal_name"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(
                col.name.quoted,
                "Already-quoted identifier should remain quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    /// Quoting a hand-built SELECT quotes the reserved column name but not the table name.
    #[test]
    fn test_quote_identifiers_full_parsed_query() {
        let mut select = crate::expressions::Select::new();
        select.expressions.push(Expression::boxed_column(Column {
            name: Identifier::new("order"),
            table: Some(Identifier::new("t")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        }));
        select.from = Some(crate::expressions::From {
            expressions: vec![Expression::Table(Box::new(TableRef::new("t")))],
        });
        let expr = Expression::Select(Box::new(select));

        let result = quote_identifiers(expr, None);
        if let Expression::Select(sel) = &result {
            if let Expression::Column(col) = &sel.expressions[0] {
                assert!(col.name.quoted, "Column named 'order' should be quoted");
                assert!(
                    !col.table.as_ref().unwrap().quoted,
                    "Table 't' should not be quoted"
                );
            } else {
                panic!("Expected Column in SELECT list");
            }
        } else {
            panic!("Expected Select expression");
        }
    }

    /// Every listed dialect (and the default) reserves at least `SELECT` and `FROM`.
    #[test]
    fn test_get_reserved_words_all_dialects() {
        let dialects = [
            None,
            Some(DialectType::Generic),
            Some(DialectType::MySQL),
            Some(DialectType::PostgreSQL),
            Some(DialectType::BigQuery),
            Some(DialectType::Snowflake),
            Some(DialectType::TSQL),
            Some(DialectType::ClickHouse),
            Some(DialectType::DuckDB),
            Some(DialectType::Hive),
            Some(DialectType::Spark),
            Some(DialectType::Trino),
            Some(DialectType::Oracle),
            Some(DialectType::Redshift),
        ];
        for dialect in &dialects {
            let words = get_reserved_words(*dialect);
            assert!(
                words.contains("SELECT"),
                "All dialects should have SELECT as reserved"
            );
            assert!(
                words.contains("FROM"),
                "All dialects should have FROM as reserved"
            );
        }
    }
}