1use crate::dialects::transform_recursive;
9use crate::dialects::DialectType;
10use crate::expressions::{
11 Alias, Column, Expression, Identifier, Join, LateralView, Literal, Over, Paren, Select,
12 TableRef, With,
13};
14use crate::resolver::{Resolver, ResolverError};
15use crate::schema::Schema;
16use crate::scope::{build_scope, traverse_scope, Scope};
17use std::cell::RefCell;
18use std::collections::{HashMap, HashSet};
19use thiserror::Error;
20
21#[derive(Debug, Error, Clone)]
23pub enum QualifyColumnsError {
24 #[error("Unknown table: {0}")]
25 UnknownTable(String),
26
27 #[error("Unknown column: {0}")]
28 UnknownColumn(String),
29
30 #[error("Ambiguous column: {0}")]
31 AmbiguousColumn(String),
32
33 #[error("Cannot automatically join: {0}")]
34 CannotAutoJoin(String),
35
36 #[error("Unknown output column: {0}")]
37 UnknownOutputColumn(String),
38
39 #[error("Column could not be resolved: {column}{for_table}")]
40 ColumnNotResolved { column: String, for_table: String },
41
42 #[error("Resolver error: {0}")]
43 ResolverError(#[from] ResolverError),
44}
45
46pub type QualifyColumnsResult<T> = Result<T, QualifyColumnsError>;
48
49#[derive(Debug, Clone, Default)]
51pub struct QualifyColumnsOptions {
52 pub expand_alias_refs: bool,
54 pub expand_stars: bool,
56 pub infer_schema: Option<bool>,
58 pub allow_partial_qualification: bool,
60 pub dialect: Option<DialectType>,
62}
63
64impl QualifyColumnsOptions {
65 pub fn new() -> Self {
67 Self {
68 expand_alias_refs: true,
69 expand_stars: true,
70 infer_schema: None,
71 allow_partial_qualification: false,
72 dialect: None,
73 }
74 }
75
76 pub fn with_expand_alias_refs(mut self, expand: bool) -> Self {
78 self.expand_alias_refs = expand;
79 self
80 }
81
82 pub fn with_expand_stars(mut self, expand: bool) -> Self {
84 self.expand_stars = expand;
85 self
86 }
87
88 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
90 self.dialect = Some(dialect);
91 self
92 }
93
94 pub fn with_allow_partial(mut self, allow: bool) -> Self {
96 self.allow_partial_qualification = allow;
97 self
98 }
99}
100
101pub fn qualify_columns(
116 expression: Expression,
117 schema: &dyn Schema,
118 options: &QualifyColumnsOptions,
119) -> QualifyColumnsResult<Expression> {
120 let infer_schema = options.infer_schema.unwrap_or(schema.is_empty());
121 let dialect = options.dialect.or_else(|| schema.dialect());
122 let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
123
124 let transformed = transform_recursive(expression, &|node| {
125 if first_error.borrow().is_some() {
126 return Ok(node);
127 }
128
129 match node {
130 Expression::Select(mut select) => {
131 if let Some(with) = &mut select.with {
132 pushdown_cte_alias_columns_with(with);
133 }
134
135 let scope_expr = Expression::Select(select.clone());
136 let scope = build_scope(&scope_expr);
137 let mut resolver = Resolver::new(&scope, schema, infer_schema);
138
139 if let Err(err) = qualify_columns_in_scope(
140 &mut select,
141 &scope,
142 &mut resolver,
143 options.allow_partial_qualification,
144 ) {
145 *first_error.borrow_mut() = Some(err);
146 }
147
148 if first_error.borrow().is_none() && options.expand_alias_refs {
149 if let Err(err) = expand_alias_refs(&mut select, &mut resolver, dialect) {
150 *first_error.borrow_mut() = Some(err);
151 }
152 }
153
154 if first_error.borrow().is_none() && options.expand_stars {
155 if let Err(err) = expand_stars(&mut select, &scope, &mut resolver) {
156 *first_error.borrow_mut() = Some(err);
157 }
158 }
159
160 if first_error.borrow().is_none() {
161 if let Err(err) = qualify_outputs_select(&mut select) {
162 *first_error.borrow_mut() = Some(err);
163 }
164 }
165
166 if first_error.borrow().is_none() {
167 if let Err(err) = expand_group_by(&mut select, dialect) {
168 *first_error.borrow_mut() = Some(err);
169 }
170 }
171
172 Ok(Expression::Select(select))
173 }
174 _ => Ok(node),
175 }
176 })
177 .map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
178
179 if let Some(err) = first_error.into_inner() {
180 return Err(err);
181 }
182
183 Ok(transformed)
184}
185
186pub fn validate_qualify_columns(expression: &Expression) -> QualifyColumnsResult<()> {
191 let mut all_unqualified = Vec::new();
192
193 for scope in traverse_scope(expression) {
194 if let Expression::Select(_) = &scope.expression {
195 let unqualified = get_unqualified_columns(&scope);
197
198 let external = get_external_columns(&scope);
200 if !external.is_empty() && !is_correlated_subquery(&scope) {
201 let first = &external[0];
202 let for_table = if first.table.is_some() {
203 format!(" for table: '{}'", first.table.as_ref().unwrap())
204 } else {
205 String::new()
206 };
207 return Err(QualifyColumnsError::ColumnNotResolved {
208 column: first.name.clone(),
209 for_table,
210 });
211 }
212
213 all_unqualified.extend(unqualified);
214 }
215 }
216
217 if !all_unqualified.is_empty() {
218 let first = &all_unqualified[0];
219 return Err(QualifyColumnsError::AmbiguousColumn(first.name.clone()));
220 }
221
222 Ok(())
223}
224
225fn qualify_columns_in_scope(
227 select: &mut Select,
228 scope: &Scope,
229 resolver: &mut Resolver,
230 allow_partial: bool,
231) -> QualifyColumnsResult<()> {
232 for expr in &mut select.expressions {
233 qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
234 }
235 if let Some(where_clause) = &mut select.where_clause {
236 qualify_columns_in_expression(&mut where_clause.this, scope, resolver, allow_partial)?;
237 }
238 if let Some(group_by) = &mut select.group_by {
239 for expr in &mut group_by.expressions {
240 qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
241 }
242 }
243 if let Some(having) = &mut select.having {
244 qualify_columns_in_expression(&mut having.this, scope, resolver, allow_partial)?;
245 }
246 if let Some(qualify) = &mut select.qualify {
247 qualify_columns_in_expression(&mut qualify.this, scope, resolver, allow_partial)?;
248 }
249 if let Some(order_by) = &mut select.order_by {
250 for ordered in &mut order_by.expressions {
251 qualify_columns_in_expression(&mut ordered.this, scope, resolver, allow_partial)?;
252 }
253 }
254 for join in &mut select.joins {
255 qualify_columns_in_expression(&mut join.this, scope, resolver, allow_partial)?;
256 if let Some(on) = &mut join.on {
257 qualify_columns_in_expression(on, scope, resolver, allow_partial)?;
258 }
259 }
260 Ok(())
261}
262
263fn expand_alias_refs(
270 select: &mut Select,
271 _resolver: &mut Resolver,
272 _dialect: Option<DialectType>,
273) -> QualifyColumnsResult<()> {
274 let mut alias_to_expression: HashMap<String, (Expression, usize)> = HashMap::new();
275
276 for (i, expr) in select.expressions.iter_mut().enumerate() {
277 replace_alias_refs_in_expression(expr, &alias_to_expression, false);
278 if let Expression::Alias(alias) = expr {
279 alias_to_expression.insert(alias.alias.name.clone(), (alias.this.clone(), i + 1));
280 }
281 }
282
283 if let Some(where_clause) = &mut select.where_clause {
284 replace_alias_refs_in_expression(&mut where_clause.this, &alias_to_expression, false);
285 }
286 if let Some(group_by) = &mut select.group_by {
287 for expr in &mut group_by.expressions {
288 replace_alias_refs_in_expression(expr, &alias_to_expression, true);
289 }
290 }
291 if let Some(having) = &mut select.having {
292 replace_alias_refs_in_expression(&mut having.this, &alias_to_expression, false);
293 }
294 if let Some(qualify) = &mut select.qualify {
295 replace_alias_refs_in_expression(&mut qualify.this, &alias_to_expression, false);
296 }
297 if let Some(order_by) = &mut select.order_by {
298 for ordered in &mut order_by.expressions {
299 replace_alias_refs_in_expression(&mut ordered.this, &alias_to_expression, false);
300 }
301 }
302
303 Ok(())
304}
305
306fn expand_group_by(select: &mut Select, _dialect: Option<DialectType>) -> QualifyColumnsResult<()> {
313 let projections = select.expressions.clone();
314
315 if let Some(group_by) = &mut select.group_by {
316 for group_expr in &mut group_by.expressions {
317 if let Some(index) = positional_reference(group_expr) {
318 let replacement = select_expression_at_position(&projections, index)?;
319 *group_expr = replacement;
320 }
321 }
322 }
323 Ok(())
324}
325
326fn expand_stars(
333 select: &mut Select,
334 scope: &Scope,
335 resolver: &mut Resolver,
336) -> QualifyColumnsResult<()> {
337 let mut new_selections: Vec<Expression> = Vec::new();
338 let mut has_star = false;
339
340 for expr in &select.expressions {
341 match expr {
342 Expression::Star(_) => {
343 has_star = true;
344 for source_name in scope.sources.keys() {
345 if let Ok(columns) = resolver.get_source_columns(source_name) {
346 if columns.contains(&"*".to_string()) || columns.is_empty() {
347 return Ok(());
348 }
349 for col_name in columns {
350 new_selections
351 .push(create_qualified_column(&col_name, Some(source_name)));
352 }
353 }
354 }
355 }
356 Expression::Column(col) if is_star_column(col) => {
357 has_star = true;
358 if let Some(table) = &col.table {
359 let table_name = &table.name;
360 if !scope.sources.contains_key(table_name) {
361 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
362 }
363 if let Ok(columns) = resolver.get_source_columns(table_name) {
364 if columns.contains(&"*".to_string()) || columns.is_empty() {
365 return Ok(());
366 }
367 for col_name in columns {
368 new_selections
369 .push(create_qualified_column(&col_name, Some(table_name)));
370 }
371 }
372 }
373 }
374 _ => new_selections.push(expr.clone()),
375 }
376 }
377
378 if has_star {
379 select.expressions = new_selections;
380 }
381
382 Ok(())
383}
384
385pub fn qualify_outputs(scope: &Scope) -> QualifyColumnsResult<()> {
392 if let Expression::Select(mut select) = scope.expression.clone() {
393 qualify_outputs_select(&mut select)?;
394 }
395 Ok(())
396}
397
398fn qualify_outputs_select(select: &mut Select) -> QualifyColumnsResult<()> {
399 let mut new_selections: Vec<Expression> = Vec::new();
400
401 for (i, expr) in select.expressions.iter().enumerate() {
402 match expr {
403 Expression::Alias(_) => new_selections.push(expr.clone()),
404 Expression::Column(col) => {
405 new_selections.push(create_alias(expr.clone(), &col.name.name));
406 }
407 Expression::Star(_) => new_selections.push(expr.clone()),
408 _ => {
409 let alias_name = get_output_name(expr).unwrap_or_else(|| format!("_col_{}", i));
410 new_selections.push(create_alias(expr.clone(), &alias_name));
411 }
412 }
413 }
414
415 select.expressions = new_selections;
416 Ok(())
417}
418
419fn qualify_columns_in_expression(
420 expr: &mut Expression,
421 scope: &Scope,
422 resolver: &mut Resolver,
423 allow_partial: bool,
424) -> QualifyColumnsResult<()> {
425 let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
426 let resolver_cell: RefCell<&mut Resolver> = RefCell::new(resolver);
427
428 let transformed = transform_recursive(expr.clone(), &|node| {
429 if first_error.borrow().is_some() {
430 return Ok(node);
431 }
432
433 match node {
434 Expression::Column(mut col) => {
435 if let Err(err) = qualify_single_column(
436 &mut col,
437 scope,
438 &mut resolver_cell.borrow_mut(),
439 allow_partial,
440 ) {
441 *first_error.borrow_mut() = Some(err);
442 }
443 Ok(Expression::Column(col))
444 }
445 _ => Ok(node),
446 }
447 })
448 .map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
449
450 if let Some(err) = first_error.into_inner() {
451 return Err(err);
452 }
453
454 *expr = transformed;
455 Ok(())
456}
457
458fn qualify_single_column(
459 col: &mut Column,
460 scope: &Scope,
461 resolver: &mut Resolver,
462 allow_partial: bool,
463) -> QualifyColumnsResult<()> {
464 if is_star_column(col) {
465 return Ok(());
466 }
467
468 if let Some(table) = &col.table {
469 let table_name = &table.name;
470 if !scope.sources.contains_key(table_name) {
471 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
472 }
473
474 if let Ok(source_columns) = resolver.get_source_columns(table_name) {
475 if !allow_partial
476 && !source_columns.is_empty()
477 && !source_columns.contains(&col.name.name)
478 && !source_columns.contains(&"*".to_string())
479 {
480 return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
481 }
482 }
483 return Ok(());
484 }
485
486 if let Some(table_name) = resolver.get_table(&col.name.name) {
487 col.table = Some(Identifier::new(table_name));
488 return Ok(());
489 }
490
491 if !allow_partial {
492 return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
493 }
494
495 Ok(())
496}
497
498fn replace_alias_refs_in_expression(
499 expr: &mut Expression,
500 alias_to_expression: &HashMap<String, (Expression, usize)>,
501 literal_index: bool,
502) {
503 let transformed = transform_recursive(expr.clone(), &|node| match node {
504 Expression::Column(col) if col.table.is_none() => {
505 if let Some((alias_expr, index)) = alias_to_expression.get(&col.name.name) {
506 if literal_index && matches!(alias_expr, Expression::Literal(_)) {
507 return Ok(Expression::number(*index as i64));
508 }
509 return Ok(Expression::Paren(Box::new(Paren {
510 this: alias_expr.clone(),
511 trailing_comments: vec![],
512 })));
513 }
514 Ok(Expression::Column(col))
515 }
516 other => Ok(other),
517 });
518
519 if let Ok(next) = transformed {
520 *expr = next;
521 }
522}
523
524fn positional_reference(expr: &Expression) -> Option<usize> {
525 match expr {
526 Expression::Literal(Literal::Number(value)) => value.parse::<usize>().ok(),
527 _ => None,
528 }
529}
530
531fn select_expression_at_position(
532 projections: &[Expression],
533 index: usize,
534) -> QualifyColumnsResult<Expression> {
535 if index == 0 || index > projections.len() {
536 return Err(QualifyColumnsError::UnknownOutputColumn(index.to_string()));
537 }
538
539 let projection = projections[index - 1].clone();
540 Ok(match projection {
541 Expression::Alias(alias) => alias.this.clone(),
542 other => other,
543 })
544}
545
546fn get_reserved_words(dialect: Option<DialectType>) -> HashSet<&'static str> {
549 let mut words: HashSet<&'static str> = [
551 "ADD",
553 "ALL",
554 "ALTER",
555 "AND",
556 "ANY",
557 "AS",
558 "ASC",
559 "BETWEEN",
560 "BY",
561 "CASE",
562 "CAST",
563 "CHECK",
564 "COLUMN",
565 "CONSTRAINT",
566 "CREATE",
567 "CROSS",
568 "CURRENT",
569 "CURRENT_DATE",
570 "CURRENT_TIME",
571 "CURRENT_TIMESTAMP",
572 "CURRENT_USER",
573 "DATABASE",
574 "DEFAULT",
575 "DELETE",
576 "DESC",
577 "DISTINCT",
578 "DROP",
579 "ELSE",
580 "END",
581 "ESCAPE",
582 "EXCEPT",
583 "EXISTS",
584 "FALSE",
585 "FETCH",
586 "FOR",
587 "FOREIGN",
588 "FROM",
589 "FULL",
590 "GRANT",
591 "GROUP",
592 "HAVING",
593 "IF",
594 "IN",
595 "INDEX",
596 "INNER",
597 "INSERT",
598 "INTERSECT",
599 "INTO",
600 "IS",
601 "JOIN",
602 "KEY",
603 "LEFT",
604 "LIKE",
605 "LIMIT",
606 "NATURAL",
607 "NOT",
608 "NULL",
609 "OFFSET",
610 "ON",
611 "OR",
612 "ORDER",
613 "OUTER",
614 "PRIMARY",
615 "REFERENCES",
616 "REPLACE",
617 "RETURNING",
618 "RIGHT",
619 "ROLLBACK",
620 "ROW",
621 "ROWS",
622 "SELECT",
623 "SESSION_USER",
624 "SET",
625 "SOME",
626 "TABLE",
627 "THEN",
628 "TO",
629 "TRUE",
630 "TRUNCATE",
631 "UNION",
632 "UNIQUE",
633 "UPDATE",
634 "USING",
635 "VALUES",
636 "VIEW",
637 "WHEN",
638 "WHERE",
639 "WINDOW",
640 "WITH",
641 ]
642 .iter()
643 .copied()
644 .collect();
645
646 match dialect {
648 Some(DialectType::MySQL) => {
649 words.extend(
650 [
651 "ANALYZE",
652 "BOTH",
653 "CHANGE",
654 "CONDITION",
655 "DATABASES",
656 "DAY_HOUR",
657 "DAY_MICROSECOND",
658 "DAY_MINUTE",
659 "DAY_SECOND",
660 "DELAYED",
661 "DETERMINISTIC",
662 "DIV",
663 "DUAL",
664 "EACH",
665 "ELSEIF",
666 "ENCLOSED",
667 "EXPLAIN",
668 "FLOAT4",
669 "FLOAT8",
670 "FORCE",
671 "HOUR_MICROSECOND",
672 "HOUR_MINUTE",
673 "HOUR_SECOND",
674 "IGNORE",
675 "INFILE",
676 "INT1",
677 "INT2",
678 "INT3",
679 "INT4",
680 "INT8",
681 "ITERATE",
682 "KEYS",
683 "KILL",
684 "LEADING",
685 "LEAVE",
686 "LINES",
687 "LOAD",
688 "LOCK",
689 "LONG",
690 "LONGBLOB",
691 "LONGTEXT",
692 "LOOP",
693 "LOW_PRIORITY",
694 "MATCH",
695 "MEDIUMBLOB",
696 "MEDIUMINT",
697 "MEDIUMTEXT",
698 "MINUTE_MICROSECOND",
699 "MINUTE_SECOND",
700 "MOD",
701 "MODIFIES",
702 "NO_WRITE_TO_BINLOG",
703 "OPTIMIZE",
704 "OPTIONALLY",
705 "OUT",
706 "OUTFILE",
707 "PURGE",
708 "READS",
709 "REGEXP",
710 "RELEASE",
711 "RENAME",
712 "REPEAT",
713 "REQUIRE",
714 "RESIGNAL",
715 "RETURN",
716 "REVOKE",
717 "RLIKE",
718 "SCHEMA",
719 "SCHEMAS",
720 "SECOND_MICROSECOND",
721 "SENSITIVE",
722 "SEPARATOR",
723 "SHOW",
724 "SIGNAL",
725 "SPATIAL",
726 "SQL",
727 "SQLEXCEPTION",
728 "SQLSTATE",
729 "SQLWARNING",
730 "SQL_BIG_RESULT",
731 "SQL_CALC_FOUND_ROWS",
732 "SQL_SMALL_RESULT",
733 "SSL",
734 "STARTING",
735 "STRAIGHT_JOIN",
736 "TERMINATED",
737 "TINYBLOB",
738 "TINYINT",
739 "TINYTEXT",
740 "TRAILING",
741 "TRIGGER",
742 "UNDO",
743 "UNLOCK",
744 "UNSIGNED",
745 "USAGE",
746 "UTC_DATE",
747 "UTC_TIME",
748 "UTC_TIMESTAMP",
749 "VARBINARY",
750 "VARCHARACTER",
751 "WHILE",
752 "WRITE",
753 "XOR",
754 "YEAR_MONTH",
755 "ZEROFILL",
756 ]
757 .iter()
758 .copied(),
759 );
760 }
761 Some(DialectType::PostgreSQL) | Some(DialectType::CockroachDB) => {
762 words.extend(
763 [
764 "ANALYSE",
765 "ANALYZE",
766 "ARRAY",
767 "AUTHORIZATION",
768 "BINARY",
769 "BOTH",
770 "COLLATE",
771 "CONCURRENTLY",
772 "DO",
773 "FREEZE",
774 "ILIKE",
775 "INITIALLY",
776 "ISNULL",
777 "LATERAL",
778 "LEADING",
779 "LOCALTIME",
780 "LOCALTIMESTAMP",
781 "NOTNULL",
782 "ONLY",
783 "OVERLAPS",
784 "PLACING",
785 "SIMILAR",
786 "SYMMETRIC",
787 "TABLESAMPLE",
788 "TRAILING",
789 "VARIADIC",
790 "VERBOSE",
791 ]
792 .iter()
793 .copied(),
794 );
795 }
796 Some(DialectType::BigQuery) => {
797 words.extend(
798 [
799 "ASSERT_ROWS_MODIFIED",
800 "COLLATE",
801 "CONTAINS",
802 "CUBE",
803 "DEFINE",
804 "ENUM",
805 "EXTRACT",
806 "FOLLOWING",
807 "GROUPING",
808 "GROUPS",
809 "HASH",
810 "IGNORE",
811 "LATERAL",
812 "LOOKUP",
813 "MERGE",
814 "NEW",
815 "NO",
816 "NULLS",
817 "OF",
818 "OVER",
819 "PARTITION",
820 "PRECEDING",
821 "PROTO",
822 "RANGE",
823 "RECURSIVE",
824 "RESPECT",
825 "ROLLUP",
826 "STRUCT",
827 "TABLESAMPLE",
828 "TREAT",
829 "UNBOUNDED",
830 "WITHIN",
831 ]
832 .iter()
833 .copied(),
834 );
835 }
836 Some(DialectType::Snowflake) => {
837 words.extend(
838 [
839 "ACCOUNT",
840 "BOTH",
841 "CONNECT",
842 "FOLLOWING",
843 "ILIKE",
844 "INCREMENT",
845 "ISSUE",
846 "LATERAL",
847 "LEADING",
848 "LOCALTIME",
849 "LOCALTIMESTAMP",
850 "MINUS",
851 "QUALIFY",
852 "REGEXP",
853 "RLIKE",
854 "SOME",
855 "START",
856 "TABLESAMPLE",
857 "TOP",
858 "TRAILING",
859 "TRY_CAST",
860 ]
861 .iter()
862 .copied(),
863 );
864 }
865 Some(DialectType::TSQL) | Some(DialectType::Fabric) => {
866 words.extend(
867 [
868 "BACKUP",
869 "BREAK",
870 "BROWSE",
871 "BULK",
872 "CASCADE",
873 "CHECKPOINT",
874 "CLOSE",
875 "CLUSTERED",
876 "COALESCE",
877 "COMPUTE",
878 "CONTAINS",
879 "CONTAINSTABLE",
880 "CONTINUE",
881 "CONVERT",
882 "DBCC",
883 "DEALLOCATE",
884 "DENY",
885 "DISK",
886 "DISTRIBUTED",
887 "DUMP",
888 "ERRLVL",
889 "EXEC",
890 "EXECUTE",
891 "EXIT",
892 "EXTERNAL",
893 "FILE",
894 "FILLFACTOR",
895 "FREETEXT",
896 "FREETEXTTABLE",
897 "FUNCTION",
898 "GOTO",
899 "HOLDLOCK",
900 "IDENTITY",
901 "IDENTITYCOL",
902 "IDENTITY_INSERT",
903 "KILL",
904 "LINENO",
905 "MERGE",
906 "NONCLUSTERED",
907 "NULLIF",
908 "OF",
909 "OFF",
910 "OFFSETS",
911 "OPEN",
912 "OPENDATASOURCE",
913 "OPENQUERY",
914 "OPENROWSET",
915 "OPENXML",
916 "OVER",
917 "PERCENT",
918 "PIVOT",
919 "PLAN",
920 "PRINT",
921 "PROC",
922 "PROCEDURE",
923 "PUBLIC",
924 "RAISERROR",
925 "READ",
926 "READTEXT",
927 "RECONFIGURE",
928 "REPLICATION",
929 "RESTORE",
930 "RESTRICT",
931 "REVERT",
932 "ROWCOUNT",
933 "ROWGUIDCOL",
934 "RULE",
935 "SAVE",
936 "SECURITYAUDIT",
937 "SEMANTICKEYPHRASETABLE",
938 "SEMANTICSIMILARITYDETAILSTABLE",
939 "SEMANTICSIMILARITYTABLE",
940 "SETUSER",
941 "SHUTDOWN",
942 "STATISTICS",
943 "SYSTEM_USER",
944 "TEXTSIZE",
945 "TOP",
946 "TRAN",
947 "TRANSACTION",
948 "TRIGGER",
949 "TSEQUAL",
950 "UNPIVOT",
951 "UPDATETEXT",
952 "WAITFOR",
953 "WRITETEXT",
954 ]
955 .iter()
956 .copied(),
957 );
958 }
959 Some(DialectType::ClickHouse) => {
960 words.extend(
961 [
962 "ANTI",
963 "ARRAY",
964 "ASOF",
965 "FINAL",
966 "FORMAT",
967 "GLOBAL",
968 "INF",
969 "KILL",
970 "MATERIALIZED",
971 "NAN",
972 "PREWHERE",
973 "SAMPLE",
974 "SEMI",
975 "SETTINGS",
976 "TOP",
977 ]
978 .iter()
979 .copied(),
980 );
981 }
982 Some(DialectType::DuckDB) => {
983 words.extend(
984 [
985 "ANALYSE",
986 "ANALYZE",
987 "ARRAY",
988 "BOTH",
989 "LATERAL",
990 "LEADING",
991 "LOCALTIME",
992 "LOCALTIMESTAMP",
993 "PLACING",
994 "QUALIFY",
995 "SIMILAR",
996 "TABLESAMPLE",
997 "TRAILING",
998 ]
999 .iter()
1000 .copied(),
1001 );
1002 }
1003 Some(DialectType::Hive) | Some(DialectType::Spark) | Some(DialectType::Databricks) => {
1004 words.extend(
1005 [
1006 "BOTH",
1007 "CLUSTER",
1008 "DISTRIBUTE",
1009 "EXCHANGE",
1010 "EXTENDED",
1011 "FUNCTION",
1012 "LATERAL",
1013 "LEADING",
1014 "MACRO",
1015 "OVER",
1016 "PARTITION",
1017 "PERCENT",
1018 "RANGE",
1019 "READS",
1020 "REDUCE",
1021 "REGEXP",
1022 "REVOKE",
1023 "RLIKE",
1024 "ROLLUP",
1025 "SEMI",
1026 "SORT",
1027 "TABLESAMPLE",
1028 "TRAILING",
1029 "TRANSFORM",
1030 "UNBOUNDED",
1031 "UNIQUEJOIN",
1032 ]
1033 .iter()
1034 .copied(),
1035 );
1036 }
1037 Some(DialectType::Trino) | Some(DialectType::Presto) | Some(DialectType::Athena) => {
1038 words.extend(
1039 [
1040 "CUBE",
1041 "DEALLOCATE",
1042 "DESCRIBE",
1043 "EXECUTE",
1044 "EXTRACT",
1045 "GROUPING",
1046 "LATERAL",
1047 "LOCALTIME",
1048 "LOCALTIMESTAMP",
1049 "NORMALIZE",
1050 "PREPARE",
1051 "ROLLUP",
1052 "SOME",
1053 "TABLESAMPLE",
1054 "UESCAPE",
1055 "UNNEST",
1056 ]
1057 .iter()
1058 .copied(),
1059 );
1060 }
1061 Some(DialectType::Oracle) => {
1062 words.extend(
1063 [
1064 "ACCESS",
1065 "AUDIT",
1066 "CLUSTER",
1067 "COMMENT",
1068 "COMPRESS",
1069 "CONNECT",
1070 "EXCLUSIVE",
1071 "FILE",
1072 "IDENTIFIED",
1073 "IMMEDIATE",
1074 "INCREMENT",
1075 "INITIAL",
1076 "LEVEL",
1077 "LOCK",
1078 "LONG",
1079 "MAXEXTENTS",
1080 "MINUS",
1081 "MODE",
1082 "NOAUDIT",
1083 "NOCOMPRESS",
1084 "NOWAIT",
1085 "NUMBER",
1086 "OF",
1087 "OFFLINE",
1088 "ONLINE",
1089 "PCTFREE",
1090 "PRIOR",
1091 "RAW",
1092 "RENAME",
1093 "RESOURCE",
1094 "REVOKE",
1095 "SHARE",
1096 "SIZE",
1097 "START",
1098 "SUCCESSFUL",
1099 "SYNONYM",
1100 "SYSDATE",
1101 "TRIGGER",
1102 "UID",
1103 "VALIDATE",
1104 "VARCHAR2",
1105 "WHENEVER",
1106 ]
1107 .iter()
1108 .copied(),
1109 );
1110 }
1111 Some(DialectType::Redshift) => {
1112 words.extend(
1113 [
1114 "AZ64",
1115 "BZIP2",
1116 "DELTA",
1117 "DELTA32K",
1118 "DISTSTYLE",
1119 "ENCODE",
1120 "GZIP",
1121 "ILIKE",
1122 "LIMIT",
1123 "LUNS",
1124 "LZO",
1125 "LZOP",
1126 "MOSTLY13",
1127 "MOSTLY32",
1128 "MOSTLY8",
1129 "RAW",
1130 "SIMILAR",
1131 "SNAPSHOT",
1132 "SORTKEY",
1133 "SYSDATE",
1134 "TOP",
1135 "ZSTD",
1136 ]
1137 .iter()
1138 .copied(),
1139 );
1140 }
1141 _ => {
1142 words.extend(
1144 [
1145 "ANALYZE",
1146 "ARRAY",
1147 "BOTH",
1148 "CUBE",
1149 "GROUPING",
1150 "LATERAL",
1151 "LEADING",
1152 "LOCALTIME",
1153 "LOCALTIMESTAMP",
1154 "OVER",
1155 "PARTITION",
1156 "QUALIFY",
1157 "RANGE",
1158 "ROLLUP",
1159 "SIMILAR",
1160 "SOME",
1161 "TABLESAMPLE",
1162 "TRAILING",
1163 ]
1164 .iter()
1165 .copied(),
1166 );
1167 }
1168 }
1169
1170 words
1171}
1172
/// Decides whether identifier `name` must be quoted.
///
/// An empty name never needs quoting. Otherwise quoting is required when the name starts with
/// an ASCII digit, or contains any byte outside `[A-Za-z0-9_]` (so all non-ASCII text). It is
/// also required when the name's upper-cased form appears in `reserved_words`.
fn needs_quoting(name: &str, reserved_words: &HashSet<&str>) -> bool {
    let mut bytes = name.bytes();
    let first = match bytes.next() {
        None => return false,
        Some(first) => first,
    };

    if first.is_ascii_digit() {
        return true;
    }

    let is_plain = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    if !is_plain(first) || !bytes.all(is_plain) {
        return true;
    }

    reserved_words.contains(name.to_uppercase().as_str())
}
1199
1200fn maybe_quote(id: &mut Identifier, reserved_words: &HashSet<&str>) {
1202 if id.quoted || id.name.is_empty() || id.name == "*" {
1205 return;
1206 }
1207 if needs_quoting(&id.name, reserved_words) {
1208 id.quoted = true;
1209 }
1210}
1211
1212fn quote_identifiers_recursive(expr: &mut Expression, reserved_words: &HashSet<&str>) {
1214 match expr {
1215 Expression::Identifier(id) => {
1217 maybe_quote(id, reserved_words);
1218 }
1219
1220 Expression::Column(col) => {
1221 maybe_quote(&mut col.name, reserved_words);
1222 if let Some(ref mut table) = col.table {
1223 maybe_quote(table, reserved_words);
1224 }
1225 }
1226
1227 Expression::Table(table_ref) => {
1228 maybe_quote(&mut table_ref.name, reserved_words);
1229 if let Some(ref mut schema) = table_ref.schema {
1230 maybe_quote(schema, reserved_words);
1231 }
1232 if let Some(ref mut catalog) = table_ref.catalog {
1233 maybe_quote(catalog, reserved_words);
1234 }
1235 if let Some(ref mut alias) = table_ref.alias {
1236 maybe_quote(alias, reserved_words);
1237 }
1238 for ca in &mut table_ref.column_aliases {
1239 maybe_quote(ca, reserved_words);
1240 }
1241 for p in &mut table_ref.partitions {
1242 maybe_quote(p, reserved_words);
1243 }
1244 for h in &mut table_ref.hints {
1246 quote_identifiers_recursive(h, reserved_words);
1247 }
1248 if let Some(ref mut ver) = table_ref.version {
1249 quote_identifiers_recursive(&mut ver.this, reserved_words);
1250 if let Some(ref mut e) = ver.expression {
1251 quote_identifiers_recursive(e, reserved_words);
1252 }
1253 }
1254 }
1255
1256 Expression::Star(star) => {
1257 if let Some(ref mut table) = star.table {
1258 maybe_quote(table, reserved_words);
1259 }
1260 if let Some(ref mut except_ids) = star.except {
1261 for id in except_ids {
1262 maybe_quote(id, reserved_words);
1263 }
1264 }
1265 if let Some(ref mut replace_aliases) = star.replace {
1266 for alias in replace_aliases {
1267 maybe_quote(&mut alias.alias, reserved_words);
1268 quote_identifiers_recursive(&mut alias.this, reserved_words);
1269 }
1270 }
1271 if let Some(ref mut rename_pairs) = star.rename {
1272 for (from, to) in rename_pairs {
1273 maybe_quote(from, reserved_words);
1274 maybe_quote(to, reserved_words);
1275 }
1276 }
1277 }
1278
1279 Expression::Alias(alias) => {
1281 maybe_quote(&mut alias.alias, reserved_words);
1282 for ca in &mut alias.column_aliases {
1283 maybe_quote(ca, reserved_words);
1284 }
1285 quote_identifiers_recursive(&mut alias.this, reserved_words);
1286 }
1287
1288 Expression::Select(select) => {
1290 for e in &mut select.expressions {
1291 quote_identifiers_recursive(e, reserved_words);
1292 }
1293 if let Some(ref mut from) = select.from {
1294 for e in &mut from.expressions {
1295 quote_identifiers_recursive(e, reserved_words);
1296 }
1297 }
1298 for join in &mut select.joins {
1299 quote_join(join, reserved_words);
1300 }
1301 for lv in &mut select.lateral_views {
1302 quote_lateral_view(lv, reserved_words);
1303 }
1304 if let Some(ref mut prewhere) = select.prewhere {
1305 quote_identifiers_recursive(prewhere, reserved_words);
1306 }
1307 if let Some(ref mut wh) = select.where_clause {
1308 quote_identifiers_recursive(&mut wh.this, reserved_words);
1309 }
1310 if let Some(ref mut gb) = select.group_by {
1311 for e in &mut gb.expressions {
1312 quote_identifiers_recursive(e, reserved_words);
1313 }
1314 }
1315 if let Some(ref mut hv) = select.having {
1316 quote_identifiers_recursive(&mut hv.this, reserved_words);
1317 }
1318 if let Some(ref mut q) = select.qualify {
1319 quote_identifiers_recursive(&mut q.this, reserved_words);
1320 }
1321 if let Some(ref mut ob) = select.order_by {
1322 for o in &mut ob.expressions {
1323 quote_identifiers_recursive(&mut o.this, reserved_words);
1324 }
1325 }
1326 if let Some(ref mut lim) = select.limit {
1327 quote_identifiers_recursive(&mut lim.this, reserved_words);
1328 }
1329 if let Some(ref mut off) = select.offset {
1330 quote_identifiers_recursive(&mut off.this, reserved_words);
1331 }
1332 if let Some(ref mut with) = select.with {
1333 quote_with(with, reserved_words);
1334 }
1335 if let Some(ref mut windows) = select.windows {
1336 for nw in windows {
1337 maybe_quote(&mut nw.name, reserved_words);
1338 quote_over(&mut nw.spec, reserved_words);
1339 }
1340 }
1341 if let Some(ref mut distinct_on) = select.distinct_on {
1342 for e in distinct_on {
1343 quote_identifiers_recursive(e, reserved_words);
1344 }
1345 }
1346 if let Some(ref mut limit_by) = select.limit_by {
1347 for e in limit_by {
1348 quote_identifiers_recursive(e, reserved_words);
1349 }
1350 }
1351 if let Some(ref mut settings) = select.settings {
1352 for e in settings {
1353 quote_identifiers_recursive(e, reserved_words);
1354 }
1355 }
1356 if let Some(ref mut format) = select.format {
1357 quote_identifiers_recursive(format, reserved_words);
1358 }
1359 }
1360
1361 Expression::Union(u) => {
1363 quote_identifiers_recursive(&mut u.left, reserved_words);
1364 quote_identifiers_recursive(&mut u.right, reserved_words);
1365 if let Some(ref mut with) = u.with {
1366 quote_with(with, reserved_words);
1367 }
1368 if let Some(ref mut ob) = u.order_by {
1369 for o in &mut ob.expressions {
1370 quote_identifiers_recursive(&mut o.this, reserved_words);
1371 }
1372 }
1373 if let Some(ref mut lim) = u.limit {
1374 quote_identifiers_recursive(lim, reserved_words);
1375 }
1376 if let Some(ref mut off) = u.offset {
1377 quote_identifiers_recursive(off, reserved_words);
1378 }
1379 }
1380 Expression::Intersect(i) => {
1381 quote_identifiers_recursive(&mut i.left, reserved_words);
1382 quote_identifiers_recursive(&mut i.right, reserved_words);
1383 if let Some(ref mut with) = i.with {
1384 quote_with(with, reserved_words);
1385 }
1386 if let Some(ref mut ob) = i.order_by {
1387 for o in &mut ob.expressions {
1388 quote_identifiers_recursive(&mut o.this, reserved_words);
1389 }
1390 }
1391 }
1392 Expression::Except(e) => {
1393 quote_identifiers_recursive(&mut e.left, reserved_words);
1394 quote_identifiers_recursive(&mut e.right, reserved_words);
1395 if let Some(ref mut with) = e.with {
1396 quote_with(with, reserved_words);
1397 }
1398 if let Some(ref mut ob) = e.order_by {
1399 for o in &mut ob.expressions {
1400 quote_identifiers_recursive(&mut o.this, reserved_words);
1401 }
1402 }
1403 }
1404
1405 Expression::Subquery(sq) => {
1407 quote_identifiers_recursive(&mut sq.this, reserved_words);
1408 if let Some(ref mut alias) = sq.alias {
1409 maybe_quote(alias, reserved_words);
1410 }
1411 for ca in &mut sq.column_aliases {
1412 maybe_quote(ca, reserved_words);
1413 }
1414 if let Some(ref mut ob) = sq.order_by {
1415 for o in &mut ob.expressions {
1416 quote_identifiers_recursive(&mut o.this, reserved_words);
1417 }
1418 }
1419 }
1420
1421 Expression::Insert(ins) => {
1423 quote_table_ref(&mut ins.table, reserved_words);
1424 for c in &mut ins.columns {
1425 maybe_quote(c, reserved_words);
1426 }
1427 for row in &mut ins.values {
1428 for e in row {
1429 quote_identifiers_recursive(e, reserved_words);
1430 }
1431 }
1432 if let Some(ref mut q) = ins.query {
1433 quote_identifiers_recursive(q, reserved_words);
1434 }
1435 for (id, val) in &mut ins.partition {
1436 maybe_quote(id, reserved_words);
1437 if let Some(ref mut v) = val {
1438 quote_identifiers_recursive(v, reserved_words);
1439 }
1440 }
1441 for e in &mut ins.returning {
1442 quote_identifiers_recursive(e, reserved_words);
1443 }
1444 if let Some(ref mut on_conflict) = ins.on_conflict {
1445 quote_identifiers_recursive(on_conflict, reserved_words);
1446 }
1447 if let Some(ref mut with) = ins.with {
1448 quote_with(with, reserved_words);
1449 }
1450 if let Some(ref mut alias) = ins.alias {
1451 maybe_quote(alias, reserved_words);
1452 }
1453 if let Some(ref mut src_alias) = ins.source_alias {
1454 maybe_quote(src_alias, reserved_words);
1455 }
1456 }
1457
1458 Expression::Update(upd) => {
1459 quote_table_ref(&mut upd.table, reserved_words);
1460 for tr in &mut upd.extra_tables {
1461 quote_table_ref(tr, reserved_words);
1462 }
1463 for join in &mut upd.table_joins {
1464 quote_join(join, reserved_words);
1465 }
1466 for (id, val) in &mut upd.set {
1467 maybe_quote(id, reserved_words);
1468 quote_identifiers_recursive(val, reserved_words);
1469 }
1470 if let Some(ref mut from) = upd.from_clause {
1471 for e in &mut from.expressions {
1472 quote_identifiers_recursive(e, reserved_words);
1473 }
1474 }
1475 for join in &mut upd.from_joins {
1476 quote_join(join, reserved_words);
1477 }
1478 if let Some(ref mut wh) = upd.where_clause {
1479 quote_identifiers_recursive(&mut wh.this, reserved_words);
1480 }
1481 for e in &mut upd.returning {
1482 quote_identifiers_recursive(e, reserved_words);
1483 }
1484 if let Some(ref mut with) = upd.with {
1485 quote_with(with, reserved_words);
1486 }
1487 }
1488
1489 Expression::Delete(del) => {
1490 quote_table_ref(&mut del.table, reserved_words);
1491 if let Some(ref mut alias) = del.alias {
1492 maybe_quote(alias, reserved_words);
1493 }
1494 for tr in &mut del.using {
1495 quote_table_ref(tr, reserved_words);
1496 }
1497 if let Some(ref mut wh) = del.where_clause {
1498 quote_identifiers_recursive(&mut wh.this, reserved_words);
1499 }
1500 if let Some(ref mut with) = del.with {
1501 quote_with(with, reserved_words);
1502 }
1503 }
1504
1505 Expression::And(bin)
1507 | Expression::Or(bin)
1508 | Expression::Eq(bin)
1509 | Expression::Neq(bin)
1510 | Expression::Lt(bin)
1511 | Expression::Lte(bin)
1512 | Expression::Gt(bin)
1513 | Expression::Gte(bin)
1514 | Expression::Add(bin)
1515 | Expression::Sub(bin)
1516 | Expression::Mul(bin)
1517 | Expression::Div(bin)
1518 | Expression::Mod(bin)
1519 | Expression::BitwiseAnd(bin)
1520 | Expression::BitwiseOr(bin)
1521 | Expression::BitwiseXor(bin)
1522 | Expression::Concat(bin)
1523 | Expression::Adjacent(bin)
1524 | Expression::TsMatch(bin)
1525 | Expression::PropertyEQ(bin)
1526 | Expression::ArrayContainsAll(bin)
1527 | Expression::ArrayContainedBy(bin)
1528 | Expression::ArrayOverlaps(bin)
1529 | Expression::JSONBContainsAllTopKeys(bin)
1530 | Expression::JSONBContainsAnyTopKeys(bin)
1531 | Expression::JSONBDeleteAtPath(bin)
1532 | Expression::ExtendsLeft(bin)
1533 | Expression::ExtendsRight(bin)
1534 | Expression::Is(bin)
1535 | Expression::NullSafeEq(bin)
1536 | Expression::NullSafeNeq(bin)
1537 | Expression::Glob(bin)
1538 | Expression::Match(bin)
1539 | Expression::MemberOf(bin)
1540 | Expression::BitwiseLeftShift(bin)
1541 | Expression::BitwiseRightShift(bin) => {
1542 quote_identifiers_recursive(&mut bin.left, reserved_words);
1543 quote_identifiers_recursive(&mut bin.right, reserved_words);
1544 }
1545
1546 Expression::Like(like) | Expression::ILike(like) => {
1548 quote_identifiers_recursive(&mut like.left, reserved_words);
1549 quote_identifiers_recursive(&mut like.right, reserved_words);
1550 if let Some(ref mut esc) = like.escape {
1551 quote_identifiers_recursive(esc, reserved_words);
1552 }
1553 }
1554
1555 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
1557 quote_identifiers_recursive(&mut un.this, reserved_words);
1558 }
1559
1560 Expression::In(in_expr) => {
1562 quote_identifiers_recursive(&mut in_expr.this, reserved_words);
1563 for e in &mut in_expr.expressions {
1564 quote_identifiers_recursive(e, reserved_words);
1565 }
1566 if let Some(ref mut q) = in_expr.query {
1567 quote_identifiers_recursive(q, reserved_words);
1568 }
1569 if let Some(ref mut un) = in_expr.unnest {
1570 quote_identifiers_recursive(un, reserved_words);
1571 }
1572 }
1573
1574 Expression::Between(bw) => {
1575 quote_identifiers_recursive(&mut bw.this, reserved_words);
1576 quote_identifiers_recursive(&mut bw.low, reserved_words);
1577 quote_identifiers_recursive(&mut bw.high, reserved_words);
1578 }
1579
1580 Expression::IsNull(is_null) => {
1581 quote_identifiers_recursive(&mut is_null.this, reserved_words);
1582 }
1583
1584 Expression::IsTrue(is_tf) | Expression::IsFalse(is_tf) => {
1585 quote_identifiers_recursive(&mut is_tf.this, reserved_words);
1586 }
1587
1588 Expression::Exists(ex) => {
1589 quote_identifiers_recursive(&mut ex.this, reserved_words);
1590 }
1591
1592 Expression::Function(func) => {
1594 for arg in &mut func.args {
1595 quote_identifiers_recursive(arg, reserved_words);
1596 }
1597 }
1598
1599 Expression::AggregateFunction(agg) => {
1600 for arg in &mut agg.args {
1601 quote_identifiers_recursive(arg, reserved_words);
1602 }
1603 if let Some(ref mut filter) = agg.filter {
1604 quote_identifiers_recursive(filter, reserved_words);
1605 }
1606 for o in &mut agg.order_by {
1607 quote_identifiers_recursive(&mut o.this, reserved_words);
1608 }
1609 }
1610
1611 Expression::WindowFunction(wf) => {
1612 quote_identifiers_recursive(&mut wf.this, reserved_words);
1613 quote_over(&mut wf.over, reserved_words);
1614 }
1615
1616 Expression::Case(case) => {
1618 if let Some(ref mut operand) = case.operand {
1619 quote_identifiers_recursive(operand, reserved_words);
1620 }
1621 for (when, then) in &mut case.whens {
1622 quote_identifiers_recursive(when, reserved_words);
1623 quote_identifiers_recursive(then, reserved_words);
1624 }
1625 if let Some(ref mut else_) = case.else_ {
1626 quote_identifiers_recursive(else_, reserved_words);
1627 }
1628 }
1629
1630 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1632 quote_identifiers_recursive(&mut cast.this, reserved_words);
1633 if let Some(ref mut fmt) = cast.format {
1634 quote_identifiers_recursive(fmt, reserved_words);
1635 }
1636 }
1637
1638 Expression::Paren(paren) => {
1640 quote_identifiers_recursive(&mut paren.this, reserved_words);
1641 }
1642
1643 Expression::Annotated(ann) => {
1644 quote_identifiers_recursive(&mut ann.this, reserved_words);
1645 }
1646
1647 Expression::With(with) => {
1649 quote_with(with, reserved_words);
1650 }
1651
1652 Expression::Cte(cte) => {
1653 maybe_quote(&mut cte.alias, reserved_words);
1654 for c in &mut cte.columns {
1655 maybe_quote(c, reserved_words);
1656 }
1657 quote_identifiers_recursive(&mut cte.this, reserved_words);
1658 }
1659
1660 Expression::From(from) => {
1662 for e in &mut from.expressions {
1663 quote_identifiers_recursive(e, reserved_words);
1664 }
1665 }
1666
1667 Expression::Join(join) => {
1668 quote_join(join, reserved_words);
1669 }
1670
1671 Expression::JoinedTable(jt) => {
1672 quote_identifiers_recursive(&mut jt.left, reserved_words);
1673 for join in &mut jt.joins {
1674 quote_join(join, reserved_words);
1675 }
1676 if let Some(ref mut alias) = jt.alias {
1677 maybe_quote(alias, reserved_words);
1678 }
1679 }
1680
1681 Expression::Where(wh) => {
1682 quote_identifiers_recursive(&mut wh.this, reserved_words);
1683 }
1684
1685 Expression::GroupBy(gb) => {
1686 for e in &mut gb.expressions {
1687 quote_identifiers_recursive(e, reserved_words);
1688 }
1689 }
1690
1691 Expression::Having(hv) => {
1692 quote_identifiers_recursive(&mut hv.this, reserved_words);
1693 }
1694
1695 Expression::OrderBy(ob) => {
1696 for o in &mut ob.expressions {
1697 quote_identifiers_recursive(&mut o.this, reserved_words);
1698 }
1699 }
1700
1701 Expression::Ordered(ord) => {
1702 quote_identifiers_recursive(&mut ord.this, reserved_words);
1703 }
1704
1705 Expression::Limit(lim) => {
1706 quote_identifiers_recursive(&mut lim.this, reserved_words);
1707 }
1708
1709 Expression::Offset(off) => {
1710 quote_identifiers_recursive(&mut off.this, reserved_words);
1711 }
1712
1713 Expression::Qualify(q) => {
1714 quote_identifiers_recursive(&mut q.this, reserved_words);
1715 }
1716
1717 Expression::Window(ws) => {
1718 for e in &mut ws.partition_by {
1719 quote_identifiers_recursive(e, reserved_words);
1720 }
1721 for o in &mut ws.order_by {
1722 quote_identifiers_recursive(&mut o.this, reserved_words);
1723 }
1724 }
1725
1726 Expression::Over(over) => {
1727 quote_over(over, reserved_words);
1728 }
1729
1730 Expression::WithinGroup(wg) => {
1731 quote_identifiers_recursive(&mut wg.this, reserved_words);
1732 for o in &mut wg.order_by {
1733 quote_identifiers_recursive(&mut o.this, reserved_words);
1734 }
1735 }
1736
1737 Expression::Pivot(piv) => {
1739 quote_identifiers_recursive(&mut piv.this, reserved_words);
1740 for e in &mut piv.expressions {
1741 quote_identifiers_recursive(e, reserved_words);
1742 }
1743 for f in &mut piv.fields {
1744 quote_identifiers_recursive(f, reserved_words);
1745 }
1746 if let Some(ref mut alias) = piv.alias {
1747 maybe_quote(alias, reserved_words);
1748 }
1749 }
1750
1751 Expression::Unpivot(unpiv) => {
1752 quote_identifiers_recursive(&mut unpiv.this, reserved_words);
1753 maybe_quote(&mut unpiv.value_column, reserved_words);
1754 maybe_quote(&mut unpiv.name_column, reserved_words);
1755 for e in &mut unpiv.columns {
1756 quote_identifiers_recursive(e, reserved_words);
1757 }
1758 if let Some(ref mut alias) = unpiv.alias {
1759 maybe_quote(alias, reserved_words);
1760 }
1761 }
1762
1763 Expression::Values(vals) => {
1765 for tuple in &mut vals.expressions {
1766 for e in &mut tuple.expressions {
1767 quote_identifiers_recursive(e, reserved_words);
1768 }
1769 }
1770 if let Some(ref mut alias) = vals.alias {
1771 maybe_quote(alias, reserved_words);
1772 }
1773 for ca in &mut vals.column_aliases {
1774 maybe_quote(ca, reserved_words);
1775 }
1776 }
1777
1778 Expression::Array(arr) => {
1780 for e in &mut arr.expressions {
1781 quote_identifiers_recursive(e, reserved_words);
1782 }
1783 }
1784
1785 Expression::Struct(st) => {
1786 for (_name, e) in &mut st.fields {
1787 quote_identifiers_recursive(e, reserved_words);
1788 }
1789 }
1790
1791 Expression::Tuple(tup) => {
1792 for e in &mut tup.expressions {
1793 quote_identifiers_recursive(e, reserved_words);
1794 }
1795 }
1796
1797 Expression::Subscript(sub) => {
1799 quote_identifiers_recursive(&mut sub.this, reserved_words);
1800 quote_identifiers_recursive(&mut sub.index, reserved_words);
1801 }
1802
1803 Expression::Dot(dot) => {
1804 quote_identifiers_recursive(&mut dot.this, reserved_words);
1805 maybe_quote(&mut dot.field, reserved_words);
1806 }
1807
1808 Expression::ScopeResolution(sr) => {
1809 if let Some(ref mut this) = sr.this {
1810 quote_identifiers_recursive(this, reserved_words);
1811 }
1812 quote_identifiers_recursive(&mut sr.expression, reserved_words);
1813 }
1814
1815 Expression::Lateral(lat) => {
1817 quote_identifiers_recursive(&mut lat.this, reserved_words);
1818 }
1820
1821 Expression::DPipe(dpipe) => {
1823 quote_identifiers_recursive(&mut dpipe.this, reserved_words);
1824 quote_identifiers_recursive(&mut dpipe.expression, reserved_words);
1825 }
1826
1827 Expression::Merge(merge) => {
1829 quote_identifiers_recursive(&mut merge.this, reserved_words);
1830 quote_identifiers_recursive(&mut merge.using, reserved_words);
1831 if let Some(ref mut on) = merge.on {
1832 quote_identifiers_recursive(on, reserved_words);
1833 }
1834 if let Some(ref mut whens) = merge.whens {
1835 quote_identifiers_recursive(whens, reserved_words);
1836 }
1837 if let Some(ref mut with) = merge.with_ {
1838 quote_identifiers_recursive(with, reserved_words);
1839 }
1840 if let Some(ref mut ret) = merge.returning {
1841 quote_identifiers_recursive(ret, reserved_words);
1842 }
1843 }
1844
1845 Expression::LateralView(lv) => {
1847 quote_lateral_view(lv, reserved_words);
1848 }
1849
1850 Expression::Anonymous(anon) => {
1852 quote_identifiers_recursive(&mut anon.this, reserved_words);
1853 for e in &mut anon.expressions {
1854 quote_identifiers_recursive(e, reserved_words);
1855 }
1856 }
1857
1858 Expression::Filter(filter) => {
1860 quote_identifiers_recursive(&mut filter.this, reserved_words);
1861 quote_identifiers_recursive(&mut filter.expression, reserved_words);
1862 }
1863
1864 Expression::Returning(ret) => {
1866 for e in &mut ret.expressions {
1867 quote_identifiers_recursive(e, reserved_words);
1868 }
1869 }
1870
1871 Expression::BracedWildcard(inner) => {
1873 quote_identifiers_recursive(inner, reserved_words);
1874 }
1875
1876 Expression::ReturnStmt(inner) => {
1878 quote_identifiers_recursive(inner, reserved_words);
1879 }
1880
1881 Expression::Literal(_)
1883 | Expression::Boolean(_)
1884 | Expression::Null(_)
1885 | Expression::DataType(_)
1886 | Expression::Raw(_)
1887 | Expression::Placeholder(_)
1888 | Expression::CurrentDate(_)
1889 | Expression::CurrentTime(_)
1890 | Expression::CurrentTimestamp(_)
1891 | Expression::CurrentTimestampLTZ(_)
1892 | Expression::SessionUser(_)
1893 | Expression::RowNumber(_)
1894 | Expression::Rank(_)
1895 | Expression::DenseRank(_)
1896 | Expression::PercentRank(_)
1897 | Expression::CumeDist(_)
1898 | Expression::Random(_)
1899 | Expression::Pi(_)
1900 | Expression::JSONPathRoot(_) => {
1901 }
1903
1904 _ => {}
1908 }
1909}
1910
1911fn quote_join(join: &mut Join, reserved_words: &HashSet<&str>) {
1913 quote_identifiers_recursive(&mut join.this, reserved_words);
1914 if let Some(ref mut on) = join.on {
1915 quote_identifiers_recursive(on, reserved_words);
1916 }
1917 for id in &mut join.using {
1918 maybe_quote(id, reserved_words);
1919 }
1920 if let Some(ref mut mc) = join.match_condition {
1921 quote_identifiers_recursive(mc, reserved_words);
1922 }
1923 for piv in &mut join.pivots {
1924 quote_identifiers_recursive(piv, reserved_words);
1925 }
1926}
1927
1928fn quote_with(with: &mut With, reserved_words: &HashSet<&str>) {
1930 for cte in &mut with.ctes {
1931 maybe_quote(&mut cte.alias, reserved_words);
1932 for c in &mut cte.columns {
1933 maybe_quote(c, reserved_words);
1934 }
1935 for k in &mut cte.key_expressions {
1936 maybe_quote(k, reserved_words);
1937 }
1938 quote_identifiers_recursive(&mut cte.this, reserved_words);
1939 }
1940}
1941
1942fn quote_over(over: &mut Over, reserved_words: &HashSet<&str>) {
1944 if let Some(ref mut wn) = over.window_name {
1945 maybe_quote(wn, reserved_words);
1946 }
1947 for e in &mut over.partition_by {
1948 quote_identifiers_recursive(e, reserved_words);
1949 }
1950 for o in &mut over.order_by {
1951 quote_identifiers_recursive(&mut o.this, reserved_words);
1952 }
1953 if let Some(ref mut alias) = over.alias {
1954 maybe_quote(alias, reserved_words);
1955 }
1956}
1957
1958fn quote_table_ref(table_ref: &mut TableRef, reserved_words: &HashSet<&str>) {
1960 maybe_quote(&mut table_ref.name, reserved_words);
1961 if let Some(ref mut schema) = table_ref.schema {
1962 maybe_quote(schema, reserved_words);
1963 }
1964 if let Some(ref mut catalog) = table_ref.catalog {
1965 maybe_quote(catalog, reserved_words);
1966 }
1967 if let Some(ref mut alias) = table_ref.alias {
1968 maybe_quote(alias, reserved_words);
1969 }
1970 for ca in &mut table_ref.column_aliases {
1971 maybe_quote(ca, reserved_words);
1972 }
1973 for p in &mut table_ref.partitions {
1974 maybe_quote(p, reserved_words);
1975 }
1976 for h in &mut table_ref.hints {
1977 quote_identifiers_recursive(h, reserved_words);
1978 }
1979}
1980
1981fn quote_lateral_view(lv: &mut LateralView, reserved_words: &HashSet<&str>) {
1983 quote_identifiers_recursive(&mut lv.this, reserved_words);
1984 if let Some(ref mut ta) = lv.table_alias {
1985 maybe_quote(ta, reserved_words);
1986 }
1987 for ca in &mut lv.column_aliases {
1988 maybe_quote(ca, reserved_words);
1989 }
1990}
1991
1992pub fn quote_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
2003 let reserved_words = get_reserved_words(dialect);
2004 let mut result = expression;
2005 quote_identifiers_recursive(&mut result, &reserved_words);
2006 result
2007}
2008
2009pub fn pushdown_cte_alias_columns(_scope: &Scope) {
2014 }
2017
2018fn pushdown_cte_alias_columns_with(with: &mut With) {
2019 for cte in &mut with.ctes {
2020 if cte.columns.is_empty() {
2021 continue;
2022 }
2023
2024 if let Expression::Select(select) = &mut cte.this {
2025 let mut next_expressions = Vec::with_capacity(select.expressions.len());
2026
2027 for (i, projection) in select.expressions.iter().enumerate() {
2028 let Some(alias_name) = cte.columns.get(i) else {
2029 next_expressions.push(projection.clone());
2030 continue;
2031 };
2032
2033 match projection {
2034 Expression::Alias(existing) => {
2035 let mut aliased = existing.clone();
2036 aliased.alias = alias_name.clone();
2037 next_expressions.push(Expression::Alias(aliased));
2038 }
2039 _ => {
2040 next_expressions.push(create_alias(projection.clone(), &alias_name.name));
2041 }
2042 }
2043 }
2044
2045 select.expressions = next_expressions;
2046 }
2047 }
2048}
2049
2050fn get_scope_columns(scope: &Scope) -> Vec<ColumnRef> {
2056 let mut columns = Vec::new();
2057 collect_columns(&scope.expression, &mut columns);
2058 columns
2059}
2060
/// A lightweight, owned record of a column reference seen in an expression tree.
#[derive(Debug, Clone)]
struct ColumnRef {
    /// Table qualifier text (`t` in `t.a`), or `None` for an unqualified column.
    table: Option<String>,
    /// Bare column name text (`a` in `t.a`).
    name: String,
}
2067
2068fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
2070 match expr {
2071 Expression::Column(col) => {
2072 columns.push(ColumnRef {
2073 table: col.table.as_ref().map(|t| t.name.clone()),
2074 name: col.name.name.clone(),
2075 });
2076 }
2077 Expression::Select(select) => {
2078 for e in &select.expressions {
2079 collect_columns(e, columns);
2080 }
2081 if let Some(from) = &select.from {
2082 for e in &from.expressions {
2083 collect_columns(e, columns);
2084 }
2085 }
2086 if let Some(where_clause) = &select.where_clause {
2087 collect_columns(&where_clause.this, columns);
2088 }
2089 if let Some(group_by) = &select.group_by {
2090 for e in &group_by.expressions {
2091 collect_columns(e, columns);
2092 }
2093 }
2094 if let Some(having) = &select.having {
2095 collect_columns(&having.this, columns);
2096 }
2097 if let Some(order_by) = &select.order_by {
2098 for o in &order_by.expressions {
2099 collect_columns(&o.this, columns);
2100 }
2101 }
2102 for join in &select.joins {
2103 collect_columns(&join.this, columns);
2104 if let Some(on) = &join.on {
2105 collect_columns(on, columns);
2106 }
2107 }
2108 }
2109 Expression::Alias(alias) => {
2110 collect_columns(&alias.this, columns);
2111 }
2112 Expression::Function(func) => {
2113 for arg in &func.args {
2114 collect_columns(arg, columns);
2115 }
2116 }
2117 Expression::AggregateFunction(agg) => {
2118 for arg in &agg.args {
2119 collect_columns(arg, columns);
2120 }
2121 }
2122 Expression::And(bin)
2123 | Expression::Or(bin)
2124 | Expression::Eq(bin)
2125 | Expression::Neq(bin)
2126 | Expression::Lt(bin)
2127 | Expression::Lte(bin)
2128 | Expression::Gt(bin)
2129 | Expression::Gte(bin)
2130 | Expression::Add(bin)
2131 | Expression::Sub(bin)
2132 | Expression::Mul(bin)
2133 | Expression::Div(bin) => {
2134 collect_columns(&bin.left, columns);
2135 collect_columns(&bin.right, columns);
2136 }
2137 Expression::Not(unary) | Expression::Neg(unary) => {
2138 collect_columns(&unary.this, columns);
2139 }
2140 Expression::Paren(paren) => {
2141 collect_columns(&paren.this, columns);
2142 }
2143 Expression::Case(case) => {
2144 if let Some(operand) = &case.operand {
2145 collect_columns(operand, columns);
2146 }
2147 for (when, then) in &case.whens {
2148 collect_columns(when, columns);
2149 collect_columns(then, columns);
2150 }
2151 if let Some(else_) = &case.else_ {
2152 collect_columns(else_, columns);
2153 }
2154 }
2155 Expression::Cast(cast) => {
2156 collect_columns(&cast.this, columns);
2157 }
2158 Expression::In(in_expr) => {
2159 collect_columns(&in_expr.this, columns);
2160 for e in &in_expr.expressions {
2161 collect_columns(e, columns);
2162 }
2163 if let Some(query) = &in_expr.query {
2164 collect_columns(query, columns);
2165 }
2166 }
2167 Expression::Between(between) => {
2168 collect_columns(&between.this, columns);
2169 collect_columns(&between.low, columns);
2170 collect_columns(&between.high, columns);
2171 }
2172 Expression::Subquery(subquery) => {
2173 collect_columns(&subquery.this, columns);
2174 }
2175 _ => {}
2176 }
2177}
2178
2179fn get_unqualified_columns(scope: &Scope) -> Vec<ColumnRef> {
2181 get_scope_columns(scope)
2182 .into_iter()
2183 .filter(|c| c.table.is_none())
2184 .collect()
2185}
2186
2187fn get_external_columns(scope: &Scope) -> Vec<ColumnRef> {
2189 let source_names: HashSet<_> = scope.sources.keys().cloned().collect();
2190
2191 get_scope_columns(scope)
2192 .into_iter()
2193 .filter(|c| {
2194 if let Some(table) = &c.table {
2195 !source_names.contains(table)
2196 } else {
2197 false
2198 }
2199 })
2200 .collect()
2201}
2202
2203fn is_correlated_subquery(scope: &Scope) -> bool {
2205 scope.can_be_correlated && !get_external_columns(scope).is_empty()
2206}
2207
2208fn is_star_column(col: &Column) -> bool {
2210 col.name.name == "*"
2211}
2212
2213fn create_qualified_column(name: &str, table: Option<&str>) -> Expression {
2215 Expression::Column(Column {
2216 name: Identifier::new(name),
2217 table: table.map(Identifier::new),
2218 join_mark: false,
2219 trailing_comments: vec![],
2220 span: None,
2221 })
2222}
2223
2224fn create_alias(expr: Expression, alias_name: &str) -> Expression {
2226 Expression::Alias(Box::new(Alias {
2227 this: expr,
2228 alias: Identifier::new(alias_name),
2229 column_aliases: vec![],
2230 pre_alias_comments: vec![],
2231 trailing_comments: vec![],
2232 }))
2233}
2234
2235fn get_output_name(expr: &Expression) -> Option<String> {
2237 match expr {
2238 Expression::Column(col) => Some(col.name.name.clone()),
2239 Expression::Alias(alias) => Some(alias.alias.name.clone()),
2240 Expression::Identifier(id) => Some(id.name.clone()),
2241 _ => None,
2242 }
2243}
2244
2245#[cfg(test)]
2246mod tests {
2247 use super::*;
2248 use crate::expressions::DataType;
2249 use crate::generator::Generator;
2250 use crate::parser::Parser;
2251 use crate::scope::build_scope;
2252 use crate::{MappingSchema, Schema};
2253
2254 fn gen(expr: &Expression) -> String {
2255 Generator::new().generate(expr).unwrap()
2256 }
2257
2258 fn parse(sql: &str) -> Expression {
2259 Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
2260 }
2261
2262 #[test]
2263 fn test_qualify_columns_options() {
2264 let options = QualifyColumnsOptions::new()
2265 .with_expand_alias_refs(true)
2266 .with_expand_stars(false)
2267 .with_dialect(DialectType::PostgreSQL)
2268 .with_allow_partial(true);
2269
2270 assert!(options.expand_alias_refs);
2271 assert!(!options.expand_stars);
2272 assert_eq!(options.dialect, Some(DialectType::PostgreSQL));
2273 assert!(options.allow_partial_qualification);
2274 }
2275
2276 #[test]
2277 fn test_get_scope_columns() {
2278 let expr = parse("SELECT a, b FROM t WHERE c = 1");
2279 let scope = build_scope(&expr);
2280 let columns = get_scope_columns(&scope);
2281
2282 assert!(columns.iter().any(|c| c.name == "a"));
2283 assert!(columns.iter().any(|c| c.name == "b"));
2284 assert!(columns.iter().any(|c| c.name == "c"));
2285 }
2286
2287 #[test]
2288 fn test_get_unqualified_columns() {
2289 let expr = parse("SELECT t.a, b FROM t");
2290 let scope = build_scope(&expr);
2291 let unqualified = get_unqualified_columns(&scope);
2292
2293 assert!(unqualified.iter().any(|c| c.name == "b"));
2295 assert!(!unqualified.iter().any(|c| c.name == "a"));
2296 }
2297
2298 #[test]
2299 fn test_is_star_column() {
2300 let col = Column {
2301 name: Identifier::new("*"),
2302 table: Some(Identifier::new("t")),
2303 join_mark: false,
2304 trailing_comments: vec![],
2305 span: None,
2306 };
2307 assert!(is_star_column(&col));
2308
2309 let col2 = Column {
2310 name: Identifier::new("id"),
2311 table: None,
2312 join_mark: false,
2313 trailing_comments: vec![],
2314 span: None,
2315 };
2316 assert!(!is_star_column(&col2));
2317 }
2318
2319 #[test]
2320 fn test_create_qualified_column() {
2321 let expr = create_qualified_column("id", Some("users"));
2322 let sql = gen(&expr);
2323 assert!(sql.contains("users"));
2324 assert!(sql.contains("id"));
2325 }
2326
2327 #[test]
2328 fn test_create_alias() {
2329 let col = Expression::Column(Column {
2330 name: Identifier::new("value"),
2331 table: None,
2332 join_mark: false,
2333 trailing_comments: vec![],
2334 span: None,
2335 });
2336 let aliased = create_alias(col, "total");
2337 let sql = gen(&aliased);
2338 assert!(sql.contains("AS") || sql.contains("total"));
2339 }
2340
2341 #[test]
2342 fn test_validate_qualify_columns_success() {
2343 let expr = parse("SELECT t.a, t.b FROM t");
2345 let result = validate_qualify_columns(&expr);
2346 let _ = result;
2349 }
2350
2351 #[test]
2352 fn test_collect_columns_nested() {
2353 let expr = parse("SELECT a + b, c FROM t WHERE d > 0 GROUP BY e HAVING f = 1");
2354 let mut columns = Vec::new();
2355 collect_columns(&expr, &mut columns);
2356
2357 let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
2358 assert!(names.contains(&"a"));
2359 assert!(names.contains(&"b"));
2360 assert!(names.contains(&"c"));
2361 assert!(names.contains(&"d"));
2362 assert!(names.contains(&"e"));
2363 assert!(names.contains(&"f"));
2364 }
2365
2366 #[test]
2367 fn test_collect_columns_in_case() {
2368 let expr = parse("SELECT CASE WHEN a = 1 THEN b ELSE c END FROM t");
2369 let mut columns = Vec::new();
2370 collect_columns(&expr, &mut columns);
2371
2372 let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
2373 assert!(names.contains(&"a"));
2374 assert!(names.contains(&"b"));
2375 assert!(names.contains(&"c"));
2376 }
2377
2378 #[test]
2379 fn test_collect_columns_in_subquery() {
2380 let expr = parse("SELECT a FROM t WHERE b IN (SELECT c FROM s)");
2381 let mut columns = Vec::new();
2382 collect_columns(&expr, &mut columns);
2383
2384 let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
2385 assert!(names.contains(&"a"));
2386 assert!(names.contains(&"b"));
2387 assert!(names.contains(&"c"));
2388 }
2389
2390 #[test]
2391 fn test_qualify_outputs_basic() {
2392 let expr = parse("SELECT a, b + c FROM t");
2393 let scope = build_scope(&expr);
2394 let result = qualify_outputs(&scope);
2395 assert!(result.is_ok());
2396 }
2397
2398 #[test]
2399 fn test_qualify_columns_expands_star_with_schema() {
2400 let expr = parse("SELECT * FROM users");
2401
2402 let mut schema = MappingSchema::new();
2403 schema
2404 .add_table(
2405 "users",
2406 &[
2407 (
2408 "id".to_string(),
2409 DataType::Int {
2410 length: None,
2411 integer_spelling: false,
2412 },
2413 ),
2414 ("name".to_string(), DataType::Text),
2415 ("email".to_string(), DataType::Text),
2416 ],
2417 None,
2418 )
2419 .expect("schema setup");
2420
2421 let result =
2422 qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
2423 let sql = gen(&result);
2424
2425 assert!(!sql.contains("SELECT *"));
2426 assert!(sql.contains("users.id"));
2427 assert!(sql.contains("users.name"));
2428 assert!(sql.contains("users.email"));
2429 }
2430
2431 #[test]
2432 fn test_qualify_columns_expands_group_by_positions() {
2433 let expr = parse("SELECT a, b FROM t GROUP BY 1, 2");
2434
2435 let mut schema = MappingSchema::new();
2436 schema
2437 .add_table(
2438 "t",
2439 &[
2440 (
2441 "a".to_string(),
2442 DataType::Int {
2443 length: None,
2444 integer_spelling: false,
2445 },
2446 ),
2447 (
2448 "b".to_string(),
2449 DataType::Int {
2450 length: None,
2451 integer_spelling: false,
2452 },
2453 ),
2454 ],
2455 None,
2456 )
2457 .expect("schema setup");
2458
2459 let result =
2460 qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
2461 let sql = gen(&result);
2462
2463 assert!(!sql.contains("GROUP BY 1"));
2464 assert!(!sql.contains("GROUP BY 2"));
2465 assert!(sql.contains("GROUP BY"));
2466 assert!(sql.contains("t.a"));
2467 assert!(sql.contains("t.b"));
2468 }
2469
2470 #[test]
2475 fn test_needs_quoting_reserved_word() {
2476 let reserved = get_reserved_words(None);
2477 assert!(needs_quoting("select", &reserved));
2478 assert!(needs_quoting("SELECT", &reserved));
2479 assert!(needs_quoting("from", &reserved));
2480 assert!(needs_quoting("WHERE", &reserved));
2481 assert!(needs_quoting("join", &reserved));
2482 assert!(needs_quoting("table", &reserved));
2483 }
2484
2485 #[test]
2486 fn test_needs_quoting_normal_identifiers() {
2487 let reserved = get_reserved_words(None);
2488 assert!(!needs_quoting("foo", &reserved));
2489 assert!(!needs_quoting("my_column", &reserved));
2490 assert!(!needs_quoting("col1", &reserved));
2491 assert!(!needs_quoting("A", &reserved));
2492 assert!(!needs_quoting("_hidden", &reserved));
2493 }
2494
2495 #[test]
2496 fn test_needs_quoting_special_characters() {
2497 let reserved = get_reserved_words(None);
2498 assert!(needs_quoting("my column", &reserved)); assert!(needs_quoting("my-column", &reserved)); assert!(needs_quoting("my.column", &reserved)); assert!(needs_quoting("col@name", &reserved)); assert!(needs_quoting("col#name", &reserved)); }
2504
2505 #[test]
2506 fn test_needs_quoting_starts_with_digit() {
2507 let reserved = get_reserved_words(None);
2508 assert!(needs_quoting("1col", &reserved));
2509 assert!(needs_quoting("123", &reserved));
2510 assert!(needs_quoting("0_start", &reserved));
2511 }
2512
2513 #[test]
2514 fn test_needs_quoting_empty() {
2515 let reserved = get_reserved_words(None);
2516 assert!(!needs_quoting("", &reserved));
2517 }
2518
2519 #[test]
2520 fn test_maybe_quote_sets_quoted_flag() {
2521 let reserved = get_reserved_words(None);
2522 let mut id = Identifier::new("select");
2523 assert!(!id.quoted);
2524 maybe_quote(&mut id, &reserved);
2525 assert!(id.quoted);
2526 }
2527
2528 #[test]
2529 fn test_maybe_quote_skips_already_quoted() {
2530 let reserved = get_reserved_words(None);
2531 let mut id = Identifier::quoted("myname");
2532 assert!(id.quoted);
2533 maybe_quote(&mut id, &reserved);
2534 assert!(id.quoted); assert_eq!(id.name, "myname"); }
2537
2538 #[test]
2539 fn test_maybe_quote_skips_star() {
2540 let reserved = get_reserved_words(None);
2541 let mut id = Identifier::new("*");
2542 maybe_quote(&mut id, &reserved);
2543 assert!(!id.quoted); }
2545
2546 #[test]
2547 fn test_maybe_quote_skips_normal() {
2548 let reserved = get_reserved_words(None);
2549 let mut id = Identifier::new("normal_col");
2550 maybe_quote(&mut id, &reserved);
2551 assert!(!id.quoted);
2552 }
2553
2554 #[test]
2555 fn test_quote_identifiers_column_with_reserved_name() {
2556 let expr = Expression::Column(Column {
2558 name: Identifier::new("select"),
2559 table: None,
2560 join_mark: false,
2561 trailing_comments: vec![],
2562 span: None,
2563 });
2564 let result = quote_identifiers(expr, None);
2565 if let Expression::Column(col) = &result {
2566 assert!(col.name.quoted, "Column named 'select' should be quoted");
2567 } else {
2568 panic!("Expected Column expression");
2569 }
2570 }
2571
2572 #[test]
2573 fn test_quote_identifiers_column_with_special_chars() {
2574 let expr = Expression::Column(Column {
2575 name: Identifier::new("my column"),
2576 table: None,
2577 join_mark: false,
2578 trailing_comments: vec![],
2579 span: None,
2580 });
2581 let result = quote_identifiers(expr, None);
2582 if let Expression::Column(col) = &result {
2583 assert!(col.name.quoted, "Column with space should be quoted");
2584 } else {
2585 panic!("Expected Column expression");
2586 }
2587 }
2588
2589 #[test]
2590 fn test_quote_identifiers_preserves_normal_column() {
2591 let expr = Expression::Column(Column {
2592 name: Identifier::new("normal_col"),
2593 table: Some(Identifier::new("my_table")),
2594 join_mark: false,
2595 trailing_comments: vec![],
2596 span: None,
2597 });
2598 let result = quote_identifiers(expr, None);
2599 if let Expression::Column(col) = &result {
2600 assert!(!col.name.quoted, "Normal column should not be quoted");
2601 assert!(
2602 !col.table.as_ref().unwrap().quoted,
2603 "Normal table should not be quoted"
2604 );
2605 } else {
2606 panic!("Expected Column expression");
2607 }
2608 }
2609
2610 #[test]
2611 fn test_quote_identifiers_table_ref_reserved() {
2612 let expr = Expression::Table(TableRef::new("select"));
2613 let result = quote_identifiers(expr, None);
2614 if let Expression::Table(tr) = &result {
2615 assert!(tr.name.quoted, "Table named 'select' should be quoted");
2616 } else {
2617 panic!("Expected Table expression");
2618 }
2619 }
2620
#[test]
fn test_quote_identifiers_table_ref_schema_and_alias() {
    // Only the schema part (`from`) is a reserved word here; the table name
    // and the alias are ordinary identifiers and must stay unquoted.
    let mut table_ref = TableRef::new("my_table");
    table_ref.alias = Some(Identifier::new("t"));
    table_ref.schema = Some(Identifier::new("from"));
    match quote_identifiers(Expression::Table(table_ref), None) {
        Expression::Table(quoted_ref) => {
            assert!(!quoted_ref.name.quoted, "Normal table name should not be quoted");
            let schema = quoted_ref.schema.as_ref().unwrap();
            assert!(schema.quoted, "Schema named 'from' should be quoted");
            let alias = quoted_ref.alias.as_ref().unwrap();
            assert!(!alias.quoted, "Normal alias should not be quoted");
        }
        _ => panic!("Expected Table expression"),
    }
}
2642
#[test]
fn test_quote_identifiers_identifier_node() {
    // A bare Identifier node (not wrapped in a Column) is quoted as well.
    let output = quote_identifiers(Expression::Identifier(Identifier::new("order")), None);
    match output {
        Expression::Identifier(ident) => {
            assert!(ident.quoted, "Identifier named 'order' should be quoted");
        }
        _ => panic!("Expected Identifier expression"),
    }
}
2653
#[test]
fn test_quote_identifiers_alias() {
    // Reserved words used as the alias name and as a column alias must be
    // quoted, while the aliased expression itself (`val`) is left alone.
    let inner = Expression::Column(Column {
        name: Identifier::new("val"),
        table: None,
        join_mark: false,
        trailing_comments: vec![],
        span: None,
    });
    let expr = Expression::Alias(Box::new(Alias {
        this: inner,
        alias: Identifier::new("select"),
        column_aliases: vec![Identifier::new("from")],
        pre_alias_comments: vec![],
        trailing_comments: vec![],
    }));
    let result = quote_identifiers(expr, None);
    if let Expression::Alias(alias) = &result {
        assert!(alias.alias.quoted, "Alias named 'select' should be quoted");
        assert!(
            alias.column_aliases[0].quoted,
            "Column alias named 'from' should be quoted"
        );
        // Fail loudly if the aliased expression changed shape; a bare `if let`
        // here would let the inner-column check pass silently.
        if let Expression::Column(col) = &alias.this {
            assert!(!col.name.quoted, "Inner column 'val' should not be quoted");
        } else {
            panic!("Expected Column as the aliased expression");
        }
    } else {
        panic!("Expected Alias expression");
    }
}
2685
#[test]
fn test_quote_identifiers_select_recursive() {
    // None of the identifiers in this query are reserved words, so walking
    // the whole SELECT tree must leave every one of them unquoted.
    let expr = parse("SELECT a, b FROM t WHERE c = 1");
    let result = quote_identifiers(expr, None);
    let sql = gen(&result);
    assert!(sql.contains("a"));
    assert!(sql.contains("b"));
    assert!(sql.contains("t"));
    // The containment checks above hold whether or not the identifiers were
    // quoted, so also require that no identifier quoting was introduced.
    assert!(
        !sql.contains('"') && !sql.contains('`'),
        "Non-reserved identifiers should not be quoted: {}",
        sql
    );
}
2698
#[test]
fn test_quote_identifiers_digit_start() {
    // Unquoted identifiers may not begin with a digit, so `1col` must be
    // quoted.
    let input = Expression::Column(Column {
        name: Identifier::new("1col"),
        table: None,
        join_mark: false,
        trailing_comments: vec![],
        span: None,
    });
    match quote_identifiers(input, None) {
        Expression::Column(column) => {
            assert!(
                column.name.quoted,
                "Column starting with digit should be quoted"
            );
        }
        _ => panic!("Expected Column expression"),
    }
}
2718
#[test]
fn test_quote_identifiers_with_mysql_dialect() {
    // Keywords expected in the MySQL reserved-word set.
    let mysql_reserved = get_reserved_words(Some(DialectType::MySQL));
    for keyword in ["KILL", "FORCE"].iter() {
        assert!(needs_quoting(keyword, &mysql_reserved));
    }
}
2727
#[test]
fn test_quote_identifiers_with_postgresql_dialect() {
    // Keywords expected in the PostgreSQL reserved-word set.
    let postgres_reserved = get_reserved_words(Some(DialectType::PostgreSQL));
    for keyword in ["ILIKE", "VERBOSE"].iter() {
        assert!(needs_quoting(keyword, &postgres_reserved));
    }
}
2736
#[test]
fn test_quote_identifiers_with_bigquery_dialect() {
    // Keywords expected in the BigQuery reserved-word set.
    let bigquery_reserved = get_reserved_words(Some(DialectType::BigQuery));
    for keyword in ["STRUCT", "PROTO"].iter() {
        assert!(needs_quoting(keyword, &bigquery_reserved));
    }
}
2745
#[test]
fn test_quote_identifiers_case_insensitive_reserved() {
    // The reserved-word check must ignore letter case: every spelling below
    // is expected to require quoting with the default word list.
    let default_reserved = get_reserved_words(None);
    let spellings = ["Select", "sElEcT", "FROM", "from"];
    for spelling in spellings.iter() {
        assert!(needs_quoting(spelling, &default_reserved));
    }
}
2754
#[test]
fn test_quote_identifiers_join_using() {
    // `quote_join` must quote USING identifiers selectively: the test expects
    // `key` to be reserved in the default word list and `value` not to be.
    let reserved = get_reserved_words(None);
    let using_columns = vec![Identifier::new("key"), Identifier::new("value")];
    let mut join = crate::expressions::Join {
        this: Expression::Table(TableRef::new("other")),
        on: None,
        using: using_columns,
        kind: crate::expressions::JoinKind::Inner,
        use_inner_keyword: false,
        use_outer_keyword: false,
        deferred_condition: false,
        join_hint: None,
        match_condition: None,
        pivots: vec![],
        comments: vec![],
        nesting_group: 0,
        directed: false,
    };
    quote_join(&mut join, &reserved);

    let key_ident = &join.using[0];
    assert!(key_ident.quoted, "USING identifier 'key' should be quoted");
    let value_ident = &join.using[1];
    assert!(
        !value_ident.quoted,
        "USING identifier 'value' should not be quoted"
    );
}
2785
#[test]
fn test_quote_identifiers_cte() {
    // Applies `maybe_quote` directly to a CTE's alias and to each entry of its
    // column list (this does not go through `quote_identifiers`). Reserved
    // names (`select`, `from`) get quoted; `normal` does not.
    let reserved = get_reserved_words(None);
    let mut cte = crate::expressions::Cte {
        alias: Identifier::new("select"),
        this: Expression::Column(Column {
            name: Identifier::new("x"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
        }),
        columns: vec![Identifier::new("from"), Identifier::new("normal")],
        materialized: None,
        key_expressions: vec![],
        alias_first: false,
        comments: Vec::new(),
    };
    // Alias first, then the column list in order, as two disjoint field borrows.
    for ident in std::iter::once(&mut cte.alias).chain(cte.columns.iter_mut()) {
        maybe_quote(ident, &reserved);
    }

    assert!(cte.alias.quoted, "CTE alias 'select' should be quoted");
    assert!(cte.columns[0].quoted, "CTE column 'from' should be quoted");
    assert!(
        !cte.columns[1].quoted,
        "CTE column 'normal' should not be quoted"
    );
}
2816
#[test]
fn test_quote_identifiers_binary_ops_recurse() {
    // quote_identifiers must descend into both operands of a binary operator:
    // the reserved-word operand (`select`) is quoted, `normal` is not.
    let expr = Expression::Add(Box::new(crate::expressions::BinaryOp::new(
        Expression::Column(Column {
            name: Identifier::new("select"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
        }),
        Expression::Column(Column {
            name: Identifier::new("normal"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
        }),
    )));
    let result = quote_identifiers(expr, None);
    if let Expression::Add(bin) = &result {
        // Explicit `else` panics: if an operand stops being a Column, the test
        // must fail instead of silently skipping its assertion.
        if let Expression::Column(left) = &bin.left {
            assert!(
                left.name.quoted,
                "'select' column should be quoted in binary op"
            );
        } else {
            panic!("Expected Column as left operand");
        }
        if let Expression::Column(right) = &bin.right {
            assert!(!right.name.quoted, "'normal' column should not be quoted");
        } else {
            panic!("Expected Column as right operand");
        }
    } else {
        panic!("Expected Add expression");
    }
}
2852
#[test]
fn test_quote_identifiers_already_quoted_preserved() {
    // An identifier that is already quoted keeps its quoting, even though its
    // text (`normal_name`) would not need quoting on its own.
    let input = Expression::Column(Column {
        name: Identifier::quoted("normal_name"),
        table: None,
        join_mark: false,
        trailing_comments: vec![],
        span: None,
    });
    match quote_identifiers(input, None) {
        Expression::Column(column) => {
            assert!(
                column.name.quoted,
                "Already-quoted identifier should remain quoted"
            );
        }
        _ => panic!("Expected Column expression"),
    }
}
2873
#[test]
fn test_quote_identifiers_full_parsed_query() {
    // Hand-builds the tree for `SELECT t.order FROM t`: the column name
    // `order` collides with a keyword, its qualifier `t` does not.
    let order_column = Expression::Column(Column {
        name: Identifier::new("order"),
        table: Some(Identifier::new("t")),
        join_mark: false,
        trailing_comments: vec![],
        span: None,
    });
    let mut select = crate::expressions::Select::new();
    select.expressions.push(order_column);
    select.from = Some(crate::expressions::From {
        expressions: vec![Expression::Table(TableRef::new("t"))],
    });

    let quoted = quote_identifiers(Expression::Select(Box::new(select)), None);
    let sel = match &quoted {
        Expression::Select(sel) => sel,
        _ => panic!("Expected Select expression"),
    };
    match &sel.expressions[0] {
        Expression::Column(col) => {
            assert!(col.name.quoted, "Column named 'order' should be quoted");
            let qualifier = col.table.as_ref().unwrap();
            assert!(!qualifier.quoted, "Table 't' should not be quoted");
        }
        _ => panic!("Expected Column in SELECT list"),
    }
}
2906
#[test]
fn test_get_reserved_words_all_dialects() {
    // Every listed dialect, plus the `None` fallback, must at least reserve
    // the core keywords SELECT and FROM. The failing dialect is included in
    // the assertion message so a regression points at the right word list.
    let dialects = [
        None,
        Some(DialectType::Generic),
        Some(DialectType::MySQL),
        Some(DialectType::PostgreSQL),
        Some(DialectType::BigQuery),
        Some(DialectType::Snowflake),
        Some(DialectType::TSQL),
        Some(DialectType::ClickHouse),
        Some(DialectType::DuckDB),
        Some(DialectType::Hive),
        Some(DialectType::Spark),
        Some(DialectType::Trino),
        Some(DialectType::Oracle),
        Some(DialectType::Redshift),
    ];
    for &dialect in &dialects {
        let words = get_reserved_words(dialect);
        assert!(
            words.contains("SELECT"),
            "All dialects should have SELECT as reserved (missing for {:?})",
            dialect
        );
        assert!(
            words.contains("FROM"),
            "All dialects should have FROM as reserved (missing for {:?})",
            dialect
        );
    }
}
2939}