1use crate::dialects::DialectType;
9use crate::expressions::{
10 Alias, Column, Expression, Identifier, Join, LateralView, Over, TableRef, With,
11};
12use crate::resolver::{Resolver, ResolverError};
13use crate::schema::Schema;
14use crate::scope::{traverse_scope, Scope};
15use std::collections::{HashMap, HashSet};
16use thiserror::Error;
17
18#[derive(Debug, Error, Clone)]
20pub enum QualifyColumnsError {
21 #[error("Unknown table: {0}")]
22 UnknownTable(String),
23
24 #[error("Unknown column: {0}")]
25 UnknownColumn(String),
26
27 #[error("Ambiguous column: {0}")]
28 AmbiguousColumn(String),
29
30 #[error("Cannot automatically join: {0}")]
31 CannotAutoJoin(String),
32
33 #[error("Unknown output column: {0}")]
34 UnknownOutputColumn(String),
35
36 #[error("Column could not be resolved: {column}{for_table}")]
37 ColumnNotResolved { column: String, for_table: String },
38
39 #[error("Resolver error: {0}")]
40 ResolverError(#[from] ResolverError),
41}
42
43pub type QualifyColumnsResult<T> = Result<T, QualifyColumnsError>;
45
46#[derive(Debug, Clone, Default)]
48pub struct QualifyColumnsOptions {
49 pub expand_alias_refs: bool,
51 pub expand_stars: bool,
53 pub infer_schema: Option<bool>,
55 pub allow_partial_qualification: bool,
57 pub dialect: Option<DialectType>,
59}
60
61impl QualifyColumnsOptions {
62 pub fn new() -> Self {
64 Self {
65 expand_alias_refs: true,
66 expand_stars: true,
67 infer_schema: None,
68 allow_partial_qualification: false,
69 dialect: None,
70 }
71 }
72
73 pub fn with_expand_alias_refs(mut self, expand: bool) -> Self {
75 self.expand_alias_refs = expand;
76 self
77 }
78
79 pub fn with_expand_stars(mut self, expand: bool) -> Self {
81 self.expand_stars = expand;
82 self
83 }
84
85 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
87 self.dialect = Some(dialect);
88 self
89 }
90
91 pub fn with_allow_partial(mut self, allow: bool) -> Self {
93 self.allow_partial_qualification = allow;
94 self
95 }
96}
97
98pub fn qualify_columns(
113 expression: Expression,
114 schema: &dyn Schema,
115 options: &QualifyColumnsOptions,
116) -> QualifyColumnsResult<Expression> {
117 let infer_schema = options.infer_schema.unwrap_or(schema.is_empty());
118 let dialect = options.dialect.or_else(|| schema.dialect());
119
120 let result = expression.clone();
121
122 for scope in traverse_scope(&expression) {
124 let scope_expression = &scope.expression;
125 let is_select = matches!(scope_expression, Expression::Select(_));
126
127 let mut resolver = Resolver::new(&scope, schema, infer_schema);
129
130 qualify_columns_in_scope(&scope, &mut resolver, options.allow_partial_qualification)?;
132
133 if options.expand_alias_refs {
135 expand_alias_refs(&scope, &mut resolver, dialect)?;
136 }
137
138 if is_select {
139 if options.expand_stars {
141 expand_stars(&scope, &mut resolver)?;
142 }
143
144 qualify_outputs(&scope)?;
146 }
147
148 expand_group_by(&scope, dialect)?;
150 }
151
152 Ok(result)
153}
154
155pub fn validate_qualify_columns(expression: &Expression) -> QualifyColumnsResult<()> {
160 let mut all_unqualified = Vec::new();
161
162 for scope in traverse_scope(expression) {
163 if let Expression::Select(_) = &scope.expression {
164 let unqualified = get_unqualified_columns(&scope);
166
167 let external = get_external_columns(&scope);
169 if !external.is_empty() && !is_correlated_subquery(&scope) {
170 let first = &external[0];
171 let for_table = if first.table.is_some() {
172 format!(" for table: '{}'", first.table.as_ref().unwrap())
173 } else {
174 String::new()
175 };
176 return Err(QualifyColumnsError::ColumnNotResolved {
177 column: first.name.clone(),
178 for_table,
179 });
180 }
181
182 all_unqualified.extend(unqualified);
183 }
184 }
185
186 if !all_unqualified.is_empty() {
187 let first = &all_unqualified[0];
188 return Err(QualifyColumnsError::AmbiguousColumn(first.name.clone()));
189 }
190
191 Ok(())
192}
193
194fn qualify_columns_in_scope(
196 scope: &Scope,
197 resolver: &mut Resolver,
198 allow_partial: bool,
199) -> QualifyColumnsResult<()> {
200 let columns = get_scope_columns(scope);
201
202 for column_ref in columns {
203 let column_table = &column_ref.table;
204 let column_name = &column_ref.name;
205
206 if let Some(table) = column_table {
207 if scope.sources.contains_key(table) {
209 if let Ok(source_columns) = resolver.get_source_columns(table) {
210 if !allow_partial
211 && !source_columns.is_empty()
212 && !source_columns.contains(column_name)
213 && !source_columns.contains(&"*".to_string())
214 {
215 return Err(QualifyColumnsError::UnknownColumn(column_name.clone()));
216 }
217 }
218 }
219 } else {
220 if let Some(table) = resolver.get_table(column_name) {
222 let _ = table;
225 }
226 }
227 }
228
229 Ok(())
230}
231
232fn expand_alias_refs(
239 scope: &Scope,
240 _resolver: &mut Resolver,
241 _dialect: Option<DialectType>,
242) -> QualifyColumnsResult<()> {
243 let expression = &scope.expression;
244
245 if !matches!(expression, Expression::Select(_)) {
246 return Ok(());
247 }
248
249 let mut _alias_to_expression: HashMap<String, (Expression, usize)> = HashMap::new();
251
252 if let Expression::Select(select) = expression {
253 for (i, expr) in select.expressions.iter().enumerate() {
254 if let Expression::Alias(alias) = expr {
255 _alias_to_expression.insert(alias.alias.name.clone(), (alias.this.clone(), i + 1));
256 }
257 }
258 }
259
260 Ok(())
264}
265
266fn expand_group_by(scope: &Scope, _dialect: Option<DialectType>) -> QualifyColumnsResult<()> {
273 if let Expression::Select(select) = &scope.expression {
274 if let Some(_group_by) = &select.group_by {
275 }
278 }
279 Ok(())
280}
281
282fn expand_stars(scope: &Scope, resolver: &mut Resolver) -> QualifyColumnsResult<()> {
289 if let Expression::Select(select) = &scope.expression {
290 let mut _new_selections: Vec<Expression> = Vec::new();
291 let mut _has_star = false;
292
293 for expr in &select.expressions {
294 match expr {
295 Expression::Star(_) => {
296 _has_star = true;
297 for (source_name, _) in &scope.sources {
299 if let Ok(columns) = resolver.get_source_columns(source_name) {
300 if columns.contains(&"*".to_string()) || columns.is_empty() {
301 return Ok(());
303 }
304 for col_name in columns {
305 _new_selections
306 .push(create_qualified_column(&col_name, Some(source_name)));
307 }
308 }
309 }
310 }
311 Expression::Column(col) if is_star_column(col) => {
312 _has_star = true;
313 if let Some(table) = &col.table {
315 let table_name = &table.name;
316 if !scope.sources.contains_key(table_name) {
317 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
318 }
319 if let Ok(columns) = resolver.get_source_columns(table_name) {
320 if columns.contains(&"*".to_string()) || columns.is_empty() {
321 return Ok(());
322 }
323 for col_name in columns {
324 _new_selections
325 .push(create_qualified_column(&col_name, Some(table_name)));
326 }
327 }
328 }
329 }
330 _ => {
331 _new_selections.push(expr.clone());
332 }
333 }
334 }
335
336 }
339
340 Ok(())
341}
342
343pub fn qualify_outputs(scope: &Scope) -> QualifyColumnsResult<()> {
350 if let Expression::Select(select) = &scope.expression {
351 let mut new_selections: Vec<Expression> = Vec::new();
352
353 for (i, expr) in select.expressions.iter().enumerate() {
354 match expr {
355 Expression::Alias(_) => {
356 new_selections.push(expr.clone());
358 }
359 Expression::Column(col) => {
360 new_selections.push(create_alias(expr.clone(), &col.name.name));
362 }
363 Expression::Star(_) => {
364 new_selections.push(expr.clone());
366 }
367 _ => {
368 let alias_name = get_output_name(expr).unwrap_or_else(|| format!("_col_{}", i));
370 new_selections.push(create_alias(expr.clone(), &alias_name));
371 }
372 }
373 }
374
375 }
377
378 Ok(())
379}
380
381fn get_reserved_words(dialect: Option<DialectType>) -> HashSet<&'static str> {
384 let mut words: HashSet<&'static str> = [
386 "ADD",
388 "ALL",
389 "ALTER",
390 "AND",
391 "ANY",
392 "AS",
393 "ASC",
394 "BETWEEN",
395 "BY",
396 "CASE",
397 "CAST",
398 "CHECK",
399 "COLUMN",
400 "CONSTRAINT",
401 "CREATE",
402 "CROSS",
403 "CURRENT",
404 "CURRENT_DATE",
405 "CURRENT_TIME",
406 "CURRENT_TIMESTAMP",
407 "CURRENT_USER",
408 "DATABASE",
409 "DEFAULT",
410 "DELETE",
411 "DESC",
412 "DISTINCT",
413 "DROP",
414 "ELSE",
415 "END",
416 "ESCAPE",
417 "EXCEPT",
418 "EXISTS",
419 "FALSE",
420 "FETCH",
421 "FOR",
422 "FOREIGN",
423 "FROM",
424 "FULL",
425 "GRANT",
426 "GROUP",
427 "HAVING",
428 "IF",
429 "IN",
430 "INDEX",
431 "INNER",
432 "INSERT",
433 "INTERSECT",
434 "INTO",
435 "IS",
436 "JOIN",
437 "KEY",
438 "LEFT",
439 "LIKE",
440 "LIMIT",
441 "NATURAL",
442 "NOT",
443 "NULL",
444 "OFFSET",
445 "ON",
446 "OR",
447 "ORDER",
448 "OUTER",
449 "PRIMARY",
450 "REFERENCES",
451 "REPLACE",
452 "RETURNING",
453 "RIGHT",
454 "ROLLBACK",
455 "ROW",
456 "ROWS",
457 "SELECT",
458 "SESSION_USER",
459 "SET",
460 "SOME",
461 "TABLE",
462 "THEN",
463 "TO",
464 "TRUE",
465 "TRUNCATE",
466 "UNION",
467 "UNIQUE",
468 "UPDATE",
469 "USING",
470 "VALUES",
471 "VIEW",
472 "WHEN",
473 "WHERE",
474 "WINDOW",
475 "WITH",
476 ]
477 .iter()
478 .copied()
479 .collect();
480
481 match dialect {
483 Some(DialectType::MySQL) => {
484 words.extend(
485 [
486 "ANALYZE",
487 "BOTH",
488 "CHANGE",
489 "CONDITION",
490 "DATABASES",
491 "DAY_HOUR",
492 "DAY_MICROSECOND",
493 "DAY_MINUTE",
494 "DAY_SECOND",
495 "DELAYED",
496 "DETERMINISTIC",
497 "DIV",
498 "DUAL",
499 "EACH",
500 "ELSEIF",
501 "ENCLOSED",
502 "EXPLAIN",
503 "FLOAT4",
504 "FLOAT8",
505 "FORCE",
506 "HOUR_MICROSECOND",
507 "HOUR_MINUTE",
508 "HOUR_SECOND",
509 "IGNORE",
510 "INFILE",
511 "INT1",
512 "INT2",
513 "INT3",
514 "INT4",
515 "INT8",
516 "ITERATE",
517 "KEYS",
518 "KILL",
519 "LEADING",
520 "LEAVE",
521 "LINES",
522 "LOAD",
523 "LOCK",
524 "LONG",
525 "LONGBLOB",
526 "LONGTEXT",
527 "LOOP",
528 "LOW_PRIORITY",
529 "MATCH",
530 "MEDIUMBLOB",
531 "MEDIUMINT",
532 "MEDIUMTEXT",
533 "MINUTE_MICROSECOND",
534 "MINUTE_SECOND",
535 "MOD",
536 "MODIFIES",
537 "NO_WRITE_TO_BINLOG",
538 "OPTIMIZE",
539 "OPTIONALLY",
540 "OUT",
541 "OUTFILE",
542 "PURGE",
543 "READS",
544 "REGEXP",
545 "RELEASE",
546 "RENAME",
547 "REPEAT",
548 "REQUIRE",
549 "RESIGNAL",
550 "RETURN",
551 "REVOKE",
552 "RLIKE",
553 "SCHEMA",
554 "SCHEMAS",
555 "SECOND_MICROSECOND",
556 "SENSITIVE",
557 "SEPARATOR",
558 "SHOW",
559 "SIGNAL",
560 "SPATIAL",
561 "SQL",
562 "SQLEXCEPTION",
563 "SQLSTATE",
564 "SQLWARNING",
565 "SQL_BIG_RESULT",
566 "SQL_CALC_FOUND_ROWS",
567 "SQL_SMALL_RESULT",
568 "SSL",
569 "STARTING",
570 "STRAIGHT_JOIN",
571 "TERMINATED",
572 "TINYBLOB",
573 "TINYINT",
574 "TINYTEXT",
575 "TRAILING",
576 "TRIGGER",
577 "UNDO",
578 "UNLOCK",
579 "UNSIGNED",
580 "USAGE",
581 "UTC_DATE",
582 "UTC_TIME",
583 "UTC_TIMESTAMP",
584 "VARBINARY",
585 "VARCHARACTER",
586 "WHILE",
587 "WRITE",
588 "XOR",
589 "YEAR_MONTH",
590 "ZEROFILL",
591 ]
592 .iter()
593 .copied(),
594 );
595 }
596 Some(DialectType::PostgreSQL) | Some(DialectType::CockroachDB) => {
597 words.extend(
598 [
599 "ANALYSE",
600 "ANALYZE",
601 "ARRAY",
602 "AUTHORIZATION",
603 "BINARY",
604 "BOTH",
605 "COLLATE",
606 "CONCURRENTLY",
607 "DO",
608 "FREEZE",
609 "ILIKE",
610 "INITIALLY",
611 "ISNULL",
612 "LATERAL",
613 "LEADING",
614 "LOCALTIME",
615 "LOCALTIMESTAMP",
616 "NOTNULL",
617 "ONLY",
618 "OVERLAPS",
619 "PLACING",
620 "SIMILAR",
621 "SYMMETRIC",
622 "TABLESAMPLE",
623 "TRAILING",
624 "VARIADIC",
625 "VERBOSE",
626 ]
627 .iter()
628 .copied(),
629 );
630 }
631 Some(DialectType::BigQuery) => {
632 words.extend(
633 [
634 "ASSERT_ROWS_MODIFIED",
635 "COLLATE",
636 "CONTAINS",
637 "CUBE",
638 "DEFINE",
639 "ENUM",
640 "EXTRACT",
641 "FOLLOWING",
642 "GROUPING",
643 "GROUPS",
644 "HASH",
645 "IGNORE",
646 "LATERAL",
647 "LOOKUP",
648 "MERGE",
649 "NEW",
650 "NO",
651 "NULLS",
652 "OF",
653 "OVER",
654 "PARTITION",
655 "PRECEDING",
656 "PROTO",
657 "RANGE",
658 "RECURSIVE",
659 "RESPECT",
660 "ROLLUP",
661 "STRUCT",
662 "TABLESAMPLE",
663 "TREAT",
664 "UNBOUNDED",
665 "WITHIN",
666 ]
667 .iter()
668 .copied(),
669 );
670 }
671 Some(DialectType::Snowflake) => {
672 words.extend(
673 [
674 "ACCOUNT",
675 "BOTH",
676 "CONNECT",
677 "FOLLOWING",
678 "ILIKE",
679 "INCREMENT",
680 "ISSUE",
681 "LATERAL",
682 "LEADING",
683 "LOCALTIME",
684 "LOCALTIMESTAMP",
685 "MINUS",
686 "QUALIFY",
687 "REGEXP",
688 "RLIKE",
689 "SOME",
690 "START",
691 "TABLESAMPLE",
692 "TOP",
693 "TRAILING",
694 "TRY_CAST",
695 ]
696 .iter()
697 .copied(),
698 );
699 }
700 Some(DialectType::TSQL) | Some(DialectType::Fabric) => {
701 words.extend(
702 [
703 "BACKUP",
704 "BREAK",
705 "BROWSE",
706 "BULK",
707 "CASCADE",
708 "CHECKPOINT",
709 "CLOSE",
710 "CLUSTERED",
711 "COALESCE",
712 "COMPUTE",
713 "CONTAINS",
714 "CONTAINSTABLE",
715 "CONTINUE",
716 "CONVERT",
717 "DBCC",
718 "DEALLOCATE",
719 "DENY",
720 "DISK",
721 "DISTRIBUTED",
722 "DUMP",
723 "ERRLVL",
724 "EXEC",
725 "EXECUTE",
726 "EXIT",
727 "EXTERNAL",
728 "FILE",
729 "FILLFACTOR",
730 "FREETEXT",
731 "FREETEXTTABLE",
732 "FUNCTION",
733 "GOTO",
734 "HOLDLOCK",
735 "IDENTITY",
736 "IDENTITYCOL",
737 "IDENTITY_INSERT",
738 "KILL",
739 "LINENO",
740 "MERGE",
741 "NONCLUSTERED",
742 "NULLIF",
743 "OF",
744 "OFF",
745 "OFFSETS",
746 "OPEN",
747 "OPENDATASOURCE",
748 "OPENQUERY",
749 "OPENROWSET",
750 "OPENXML",
751 "OVER",
752 "PERCENT",
753 "PIVOT",
754 "PLAN",
755 "PRINT",
756 "PROC",
757 "PROCEDURE",
758 "PUBLIC",
759 "RAISERROR",
760 "READ",
761 "READTEXT",
762 "RECONFIGURE",
763 "REPLICATION",
764 "RESTORE",
765 "RESTRICT",
766 "REVERT",
767 "ROWCOUNT",
768 "ROWGUIDCOL",
769 "RULE",
770 "SAVE",
771 "SECURITYAUDIT",
772 "SEMANTICKEYPHRASETABLE",
773 "SEMANTICSIMILARITYDETAILSTABLE",
774 "SEMANTICSIMILARITYTABLE",
775 "SETUSER",
776 "SHUTDOWN",
777 "STATISTICS",
778 "SYSTEM_USER",
779 "TEXTSIZE",
780 "TOP",
781 "TRAN",
782 "TRANSACTION",
783 "TRIGGER",
784 "TSEQUAL",
785 "UNPIVOT",
786 "UPDATETEXT",
787 "WAITFOR",
788 "WRITETEXT",
789 ]
790 .iter()
791 .copied(),
792 );
793 }
794 Some(DialectType::ClickHouse) => {
795 words.extend(
796 [
797 "ANTI",
798 "ARRAY",
799 "ASOF",
800 "FINAL",
801 "FORMAT",
802 "GLOBAL",
803 "INF",
804 "KILL",
805 "MATERIALIZED",
806 "NAN",
807 "PREWHERE",
808 "SAMPLE",
809 "SEMI",
810 "SETTINGS",
811 "TOP",
812 ]
813 .iter()
814 .copied(),
815 );
816 }
817 Some(DialectType::DuckDB) => {
818 words.extend(
819 [
820 "ANALYSE",
821 "ANALYZE",
822 "ARRAY",
823 "BOTH",
824 "LATERAL",
825 "LEADING",
826 "LOCALTIME",
827 "LOCALTIMESTAMP",
828 "PLACING",
829 "QUALIFY",
830 "SIMILAR",
831 "TABLESAMPLE",
832 "TRAILING",
833 ]
834 .iter()
835 .copied(),
836 );
837 }
838 Some(DialectType::Hive) | Some(DialectType::Spark) | Some(DialectType::Databricks) => {
839 words.extend(
840 [
841 "BOTH",
842 "CLUSTER",
843 "DISTRIBUTE",
844 "EXCHANGE",
845 "EXTENDED",
846 "FUNCTION",
847 "LATERAL",
848 "LEADING",
849 "MACRO",
850 "OVER",
851 "PARTITION",
852 "PERCENT",
853 "RANGE",
854 "READS",
855 "REDUCE",
856 "REGEXP",
857 "REVOKE",
858 "RLIKE",
859 "ROLLUP",
860 "SEMI",
861 "SORT",
862 "TABLESAMPLE",
863 "TRAILING",
864 "TRANSFORM",
865 "UNBOUNDED",
866 "UNIQUEJOIN",
867 ]
868 .iter()
869 .copied(),
870 );
871 }
872 Some(DialectType::Trino) | Some(DialectType::Presto) | Some(DialectType::Athena) => {
873 words.extend(
874 [
875 "CUBE",
876 "DEALLOCATE",
877 "DESCRIBE",
878 "EXECUTE",
879 "EXTRACT",
880 "GROUPING",
881 "LATERAL",
882 "LOCALTIME",
883 "LOCALTIMESTAMP",
884 "NORMALIZE",
885 "PREPARE",
886 "ROLLUP",
887 "SOME",
888 "TABLESAMPLE",
889 "UESCAPE",
890 "UNNEST",
891 ]
892 .iter()
893 .copied(),
894 );
895 }
896 Some(DialectType::Oracle) => {
897 words.extend(
898 [
899 "ACCESS",
900 "AUDIT",
901 "CLUSTER",
902 "COMMENT",
903 "COMPRESS",
904 "CONNECT",
905 "EXCLUSIVE",
906 "FILE",
907 "IDENTIFIED",
908 "IMMEDIATE",
909 "INCREMENT",
910 "INITIAL",
911 "LEVEL",
912 "LOCK",
913 "LONG",
914 "MAXEXTENTS",
915 "MINUS",
916 "MODE",
917 "NOAUDIT",
918 "NOCOMPRESS",
919 "NOWAIT",
920 "NUMBER",
921 "OF",
922 "OFFLINE",
923 "ONLINE",
924 "PCTFREE",
925 "PRIOR",
926 "RAW",
927 "RENAME",
928 "RESOURCE",
929 "REVOKE",
930 "SHARE",
931 "SIZE",
932 "START",
933 "SUCCESSFUL",
934 "SYNONYM",
935 "SYSDATE",
936 "TRIGGER",
937 "UID",
938 "VALIDATE",
939 "VARCHAR2",
940 "WHENEVER",
941 ]
942 .iter()
943 .copied(),
944 );
945 }
946 Some(DialectType::Redshift) => {
947 words.extend(
948 [
949 "AZ64",
950 "BZIP2",
951 "DELTA",
952 "DELTA32K",
953 "DISTSTYLE",
954 "ENCODE",
955 "GZIP",
956 "ILIKE",
957 "LIMIT",
958 "LUNS",
959 "LZO",
960 "LZOP",
961 "MOSTLY13",
962 "MOSTLY32",
963 "MOSTLY8",
964 "RAW",
965 "SIMILAR",
966 "SNAPSHOT",
967 "SORTKEY",
968 "SYSDATE",
969 "TOP",
970 "ZSTD",
971 ]
972 .iter()
973 .copied(),
974 );
975 }
976 _ => {
977 words.extend(
979 [
980 "ANALYZE",
981 "ARRAY",
982 "BOTH",
983 "CUBE",
984 "GROUPING",
985 "LATERAL",
986 "LEADING",
987 "LOCALTIME",
988 "LOCALTIMESTAMP",
989 "OVER",
990 "PARTITION",
991 "QUALIFY",
992 "RANGE",
993 "ROLLUP",
994 "SIMILAR",
995 "SOME",
996 "TABLESAMPLE",
997 "TRAILING",
998 ]
999 .iter()
1000 .copied(),
1001 );
1002 }
1003 }
1004
1005 words
1006}
1007
/// Decides whether the identifier `name` must be quoted.
///
/// Returns `false` for an empty name. Otherwise returns `true` when the name
/// starts with an ASCII digit, contains any byte other than ASCII letters,
/// digits or `_`, or matches an entry of `reserved_words` after upper-casing.
/// `reserved_words` is therefore expected to hold upper-case entries.
fn needs_quoting(name: &str, reserved_words: &HashSet<&str>) -> bool {
    match name.bytes().next() {
        None => false,
        Some(first) if first.is_ascii_digit() => true,
        Some(_) => {
            let plain = name
                .bytes()
                .all(|byte| byte == b'_' || byte.is_ascii_alphanumeric());
            // `plain` guarantees ASCII, so ASCII upper-casing is sufficient.
            !plain || reserved_words.contains(name.to_ascii_uppercase().as_str())
        }
    }
}
1034
1035fn maybe_quote(id: &mut Identifier, reserved_words: &HashSet<&str>) {
1037 if id.quoted || id.name.is_empty() || id.name == "*" {
1040 return;
1041 }
1042 if needs_quoting(&id.name, reserved_words) {
1043 id.quoted = true;
1044 }
1045}
1046
1047fn quote_identifiers_recursive(expr: &mut Expression, reserved_words: &HashSet<&str>) {
1049 match expr {
1050 Expression::Identifier(id) => {
1052 maybe_quote(id, reserved_words);
1053 }
1054
1055 Expression::Column(col) => {
1056 maybe_quote(&mut col.name, reserved_words);
1057 if let Some(ref mut table) = col.table {
1058 maybe_quote(table, reserved_words);
1059 }
1060 }
1061
1062 Expression::Table(table_ref) => {
1063 maybe_quote(&mut table_ref.name, reserved_words);
1064 if let Some(ref mut schema) = table_ref.schema {
1065 maybe_quote(schema, reserved_words);
1066 }
1067 if let Some(ref mut catalog) = table_ref.catalog {
1068 maybe_quote(catalog, reserved_words);
1069 }
1070 if let Some(ref mut alias) = table_ref.alias {
1071 maybe_quote(alias, reserved_words);
1072 }
1073 for ca in &mut table_ref.column_aliases {
1074 maybe_quote(ca, reserved_words);
1075 }
1076 for p in &mut table_ref.partitions {
1077 maybe_quote(p, reserved_words);
1078 }
1079 for h in &mut table_ref.hints {
1081 quote_identifiers_recursive(h, reserved_words);
1082 }
1083 if let Some(ref mut ver) = table_ref.version {
1084 quote_identifiers_recursive(&mut ver.this, reserved_words);
1085 if let Some(ref mut e) = ver.expression {
1086 quote_identifiers_recursive(e, reserved_words);
1087 }
1088 }
1089 }
1090
1091 Expression::Star(star) => {
1092 if let Some(ref mut table) = star.table {
1093 maybe_quote(table, reserved_words);
1094 }
1095 if let Some(ref mut except_ids) = star.except {
1096 for id in except_ids {
1097 maybe_quote(id, reserved_words);
1098 }
1099 }
1100 if let Some(ref mut replace_aliases) = star.replace {
1101 for alias in replace_aliases {
1102 maybe_quote(&mut alias.alias, reserved_words);
1103 quote_identifiers_recursive(&mut alias.this, reserved_words);
1104 }
1105 }
1106 if let Some(ref mut rename_pairs) = star.rename {
1107 for (from, to) in rename_pairs {
1108 maybe_quote(from, reserved_words);
1109 maybe_quote(to, reserved_words);
1110 }
1111 }
1112 }
1113
1114 Expression::Alias(alias) => {
1116 maybe_quote(&mut alias.alias, reserved_words);
1117 for ca in &mut alias.column_aliases {
1118 maybe_quote(ca, reserved_words);
1119 }
1120 quote_identifiers_recursive(&mut alias.this, reserved_words);
1121 }
1122
1123 Expression::Select(select) => {
1125 for e in &mut select.expressions {
1126 quote_identifiers_recursive(e, reserved_words);
1127 }
1128 if let Some(ref mut from) = select.from {
1129 for e in &mut from.expressions {
1130 quote_identifiers_recursive(e, reserved_words);
1131 }
1132 }
1133 for join in &mut select.joins {
1134 quote_join(join, reserved_words);
1135 }
1136 for lv in &mut select.lateral_views {
1137 quote_lateral_view(lv, reserved_words);
1138 }
1139 if let Some(ref mut prewhere) = select.prewhere {
1140 quote_identifiers_recursive(prewhere, reserved_words);
1141 }
1142 if let Some(ref mut wh) = select.where_clause {
1143 quote_identifiers_recursive(&mut wh.this, reserved_words);
1144 }
1145 if let Some(ref mut gb) = select.group_by {
1146 for e in &mut gb.expressions {
1147 quote_identifiers_recursive(e, reserved_words);
1148 }
1149 }
1150 if let Some(ref mut hv) = select.having {
1151 quote_identifiers_recursive(&mut hv.this, reserved_words);
1152 }
1153 if let Some(ref mut q) = select.qualify {
1154 quote_identifiers_recursive(&mut q.this, reserved_words);
1155 }
1156 if let Some(ref mut ob) = select.order_by {
1157 for o in &mut ob.expressions {
1158 quote_identifiers_recursive(&mut o.this, reserved_words);
1159 }
1160 }
1161 if let Some(ref mut lim) = select.limit {
1162 quote_identifiers_recursive(&mut lim.this, reserved_words);
1163 }
1164 if let Some(ref mut off) = select.offset {
1165 quote_identifiers_recursive(&mut off.this, reserved_words);
1166 }
1167 if let Some(ref mut with) = select.with {
1168 quote_with(with, reserved_words);
1169 }
1170 if let Some(ref mut windows) = select.windows {
1171 for nw in windows {
1172 maybe_quote(&mut nw.name, reserved_words);
1173 quote_over(&mut nw.spec, reserved_words);
1174 }
1175 }
1176 if let Some(ref mut distinct_on) = select.distinct_on {
1177 for e in distinct_on {
1178 quote_identifiers_recursive(e, reserved_words);
1179 }
1180 }
1181 if let Some(ref mut limit_by) = select.limit_by {
1182 for e in limit_by {
1183 quote_identifiers_recursive(e, reserved_words);
1184 }
1185 }
1186 if let Some(ref mut settings) = select.settings {
1187 for e in settings {
1188 quote_identifiers_recursive(e, reserved_words);
1189 }
1190 }
1191 if let Some(ref mut format) = select.format {
1192 quote_identifiers_recursive(format, reserved_words);
1193 }
1194 }
1195
1196 Expression::Union(u) => {
1198 quote_identifiers_recursive(&mut u.left, reserved_words);
1199 quote_identifiers_recursive(&mut u.right, reserved_words);
1200 if let Some(ref mut with) = u.with {
1201 quote_with(with, reserved_words);
1202 }
1203 if let Some(ref mut ob) = u.order_by {
1204 for o in &mut ob.expressions {
1205 quote_identifiers_recursive(&mut o.this, reserved_words);
1206 }
1207 }
1208 if let Some(ref mut lim) = u.limit {
1209 quote_identifiers_recursive(lim, reserved_words);
1210 }
1211 if let Some(ref mut off) = u.offset {
1212 quote_identifiers_recursive(off, reserved_words);
1213 }
1214 }
1215 Expression::Intersect(i) => {
1216 quote_identifiers_recursive(&mut i.left, reserved_words);
1217 quote_identifiers_recursive(&mut i.right, reserved_words);
1218 if let Some(ref mut with) = i.with {
1219 quote_with(with, reserved_words);
1220 }
1221 if let Some(ref mut ob) = i.order_by {
1222 for o in &mut ob.expressions {
1223 quote_identifiers_recursive(&mut o.this, reserved_words);
1224 }
1225 }
1226 }
1227 Expression::Except(e) => {
1228 quote_identifiers_recursive(&mut e.left, reserved_words);
1229 quote_identifiers_recursive(&mut e.right, reserved_words);
1230 if let Some(ref mut with) = e.with {
1231 quote_with(with, reserved_words);
1232 }
1233 if let Some(ref mut ob) = e.order_by {
1234 for o in &mut ob.expressions {
1235 quote_identifiers_recursive(&mut o.this, reserved_words);
1236 }
1237 }
1238 }
1239
1240 Expression::Subquery(sq) => {
1242 quote_identifiers_recursive(&mut sq.this, reserved_words);
1243 if let Some(ref mut alias) = sq.alias {
1244 maybe_quote(alias, reserved_words);
1245 }
1246 for ca in &mut sq.column_aliases {
1247 maybe_quote(ca, reserved_words);
1248 }
1249 if let Some(ref mut ob) = sq.order_by {
1250 for o in &mut ob.expressions {
1251 quote_identifiers_recursive(&mut o.this, reserved_words);
1252 }
1253 }
1254 }
1255
1256 Expression::Insert(ins) => {
1258 quote_table_ref(&mut ins.table, reserved_words);
1259 for c in &mut ins.columns {
1260 maybe_quote(c, reserved_words);
1261 }
1262 for row in &mut ins.values {
1263 for e in row {
1264 quote_identifiers_recursive(e, reserved_words);
1265 }
1266 }
1267 if let Some(ref mut q) = ins.query {
1268 quote_identifiers_recursive(q, reserved_words);
1269 }
1270 for (id, val) in &mut ins.partition {
1271 maybe_quote(id, reserved_words);
1272 if let Some(ref mut v) = val {
1273 quote_identifiers_recursive(v, reserved_words);
1274 }
1275 }
1276 for e in &mut ins.returning {
1277 quote_identifiers_recursive(e, reserved_words);
1278 }
1279 if let Some(ref mut on_conflict) = ins.on_conflict {
1280 quote_identifiers_recursive(on_conflict, reserved_words);
1281 }
1282 if let Some(ref mut with) = ins.with {
1283 quote_with(with, reserved_words);
1284 }
1285 if let Some(ref mut alias) = ins.alias {
1286 maybe_quote(alias, reserved_words);
1287 }
1288 if let Some(ref mut src_alias) = ins.source_alias {
1289 maybe_quote(src_alias, reserved_words);
1290 }
1291 }
1292
1293 Expression::Update(upd) => {
1294 quote_table_ref(&mut upd.table, reserved_words);
1295 for tr in &mut upd.extra_tables {
1296 quote_table_ref(tr, reserved_words);
1297 }
1298 for join in &mut upd.table_joins {
1299 quote_join(join, reserved_words);
1300 }
1301 for (id, val) in &mut upd.set {
1302 maybe_quote(id, reserved_words);
1303 quote_identifiers_recursive(val, reserved_words);
1304 }
1305 if let Some(ref mut from) = upd.from_clause {
1306 for e in &mut from.expressions {
1307 quote_identifiers_recursive(e, reserved_words);
1308 }
1309 }
1310 for join in &mut upd.from_joins {
1311 quote_join(join, reserved_words);
1312 }
1313 if let Some(ref mut wh) = upd.where_clause {
1314 quote_identifiers_recursive(&mut wh.this, reserved_words);
1315 }
1316 for e in &mut upd.returning {
1317 quote_identifiers_recursive(e, reserved_words);
1318 }
1319 if let Some(ref mut with) = upd.with {
1320 quote_with(with, reserved_words);
1321 }
1322 }
1323
1324 Expression::Delete(del) => {
1325 quote_table_ref(&mut del.table, reserved_words);
1326 if let Some(ref mut alias) = del.alias {
1327 maybe_quote(alias, reserved_words);
1328 }
1329 for tr in &mut del.using {
1330 quote_table_ref(tr, reserved_words);
1331 }
1332 if let Some(ref mut wh) = del.where_clause {
1333 quote_identifiers_recursive(&mut wh.this, reserved_words);
1334 }
1335 if let Some(ref mut with) = del.with {
1336 quote_with(with, reserved_words);
1337 }
1338 }
1339
1340 Expression::And(bin)
1342 | Expression::Or(bin)
1343 | Expression::Eq(bin)
1344 | Expression::Neq(bin)
1345 | Expression::Lt(bin)
1346 | Expression::Lte(bin)
1347 | Expression::Gt(bin)
1348 | Expression::Gte(bin)
1349 | Expression::Add(bin)
1350 | Expression::Sub(bin)
1351 | Expression::Mul(bin)
1352 | Expression::Div(bin)
1353 | Expression::Mod(bin)
1354 | Expression::BitwiseAnd(bin)
1355 | Expression::BitwiseOr(bin)
1356 | Expression::BitwiseXor(bin)
1357 | Expression::Concat(bin)
1358 | Expression::Adjacent(bin)
1359 | Expression::TsMatch(bin)
1360 | Expression::PropertyEQ(bin)
1361 | Expression::ArrayContainsAll(bin)
1362 | Expression::ArrayContainedBy(bin)
1363 | Expression::ArrayOverlaps(bin)
1364 | Expression::JSONBContainsAllTopKeys(bin)
1365 | Expression::JSONBContainsAnyTopKeys(bin)
1366 | Expression::JSONBDeleteAtPath(bin)
1367 | Expression::ExtendsLeft(bin)
1368 | Expression::ExtendsRight(bin)
1369 | Expression::Is(bin)
1370 | Expression::NullSafeEq(bin)
1371 | Expression::NullSafeNeq(bin)
1372 | Expression::Glob(bin)
1373 | Expression::Match(bin)
1374 | Expression::MemberOf(bin)
1375 | Expression::BitwiseLeftShift(bin)
1376 | Expression::BitwiseRightShift(bin) => {
1377 quote_identifiers_recursive(&mut bin.left, reserved_words);
1378 quote_identifiers_recursive(&mut bin.right, reserved_words);
1379 }
1380
1381 Expression::Like(like) | Expression::ILike(like) => {
1383 quote_identifiers_recursive(&mut like.left, reserved_words);
1384 quote_identifiers_recursive(&mut like.right, reserved_words);
1385 if let Some(ref mut esc) = like.escape {
1386 quote_identifiers_recursive(esc, reserved_words);
1387 }
1388 }
1389
1390 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
1392 quote_identifiers_recursive(&mut un.this, reserved_words);
1393 }
1394
1395 Expression::In(in_expr) => {
1397 quote_identifiers_recursive(&mut in_expr.this, reserved_words);
1398 for e in &mut in_expr.expressions {
1399 quote_identifiers_recursive(e, reserved_words);
1400 }
1401 if let Some(ref mut q) = in_expr.query {
1402 quote_identifiers_recursive(q, reserved_words);
1403 }
1404 if let Some(ref mut un) = in_expr.unnest {
1405 quote_identifiers_recursive(un, reserved_words);
1406 }
1407 }
1408
1409 Expression::Between(bw) => {
1410 quote_identifiers_recursive(&mut bw.this, reserved_words);
1411 quote_identifiers_recursive(&mut bw.low, reserved_words);
1412 quote_identifiers_recursive(&mut bw.high, reserved_words);
1413 }
1414
1415 Expression::IsNull(is_null) => {
1416 quote_identifiers_recursive(&mut is_null.this, reserved_words);
1417 }
1418
1419 Expression::IsTrue(is_tf) | Expression::IsFalse(is_tf) => {
1420 quote_identifiers_recursive(&mut is_tf.this, reserved_words);
1421 }
1422
1423 Expression::Exists(ex) => {
1424 quote_identifiers_recursive(&mut ex.this, reserved_words);
1425 }
1426
1427 Expression::Function(func) => {
1429 for arg in &mut func.args {
1430 quote_identifiers_recursive(arg, reserved_words);
1431 }
1432 }
1433
1434 Expression::AggregateFunction(agg) => {
1435 for arg in &mut agg.args {
1436 quote_identifiers_recursive(arg, reserved_words);
1437 }
1438 if let Some(ref mut filter) = agg.filter {
1439 quote_identifiers_recursive(filter, reserved_words);
1440 }
1441 for o in &mut agg.order_by {
1442 quote_identifiers_recursive(&mut o.this, reserved_words);
1443 }
1444 }
1445
1446 Expression::WindowFunction(wf) => {
1447 quote_identifiers_recursive(&mut wf.this, reserved_words);
1448 quote_over(&mut wf.over, reserved_words);
1449 }
1450
1451 Expression::Case(case) => {
1453 if let Some(ref mut operand) = case.operand {
1454 quote_identifiers_recursive(operand, reserved_words);
1455 }
1456 for (when, then) in &mut case.whens {
1457 quote_identifiers_recursive(when, reserved_words);
1458 quote_identifiers_recursive(then, reserved_words);
1459 }
1460 if let Some(ref mut else_) = case.else_ {
1461 quote_identifiers_recursive(else_, reserved_words);
1462 }
1463 }
1464
1465 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1467 quote_identifiers_recursive(&mut cast.this, reserved_words);
1468 if let Some(ref mut fmt) = cast.format {
1469 quote_identifiers_recursive(fmt, reserved_words);
1470 }
1471 }
1472
1473 Expression::Paren(paren) => {
1475 quote_identifiers_recursive(&mut paren.this, reserved_words);
1476 }
1477
1478 Expression::Annotated(ann) => {
1479 quote_identifiers_recursive(&mut ann.this, reserved_words);
1480 }
1481
1482 Expression::With(with) => {
1484 quote_with(with, reserved_words);
1485 }
1486
1487 Expression::Cte(cte) => {
1488 maybe_quote(&mut cte.alias, reserved_words);
1489 for c in &mut cte.columns {
1490 maybe_quote(c, reserved_words);
1491 }
1492 quote_identifiers_recursive(&mut cte.this, reserved_words);
1493 }
1494
1495 Expression::From(from) => {
1497 for e in &mut from.expressions {
1498 quote_identifiers_recursive(e, reserved_words);
1499 }
1500 }
1501
1502 Expression::Join(join) => {
1503 quote_join(join, reserved_words);
1504 }
1505
1506 Expression::JoinedTable(jt) => {
1507 quote_identifiers_recursive(&mut jt.left, reserved_words);
1508 for join in &mut jt.joins {
1509 quote_join(join, reserved_words);
1510 }
1511 if let Some(ref mut alias) = jt.alias {
1512 maybe_quote(alias, reserved_words);
1513 }
1514 }
1515
1516 Expression::Where(wh) => {
1517 quote_identifiers_recursive(&mut wh.this, reserved_words);
1518 }
1519
1520 Expression::GroupBy(gb) => {
1521 for e in &mut gb.expressions {
1522 quote_identifiers_recursive(e, reserved_words);
1523 }
1524 }
1525
1526 Expression::Having(hv) => {
1527 quote_identifiers_recursive(&mut hv.this, reserved_words);
1528 }
1529
1530 Expression::OrderBy(ob) => {
1531 for o in &mut ob.expressions {
1532 quote_identifiers_recursive(&mut o.this, reserved_words);
1533 }
1534 }
1535
1536 Expression::Ordered(ord) => {
1537 quote_identifiers_recursive(&mut ord.this, reserved_words);
1538 }
1539
1540 Expression::Limit(lim) => {
1541 quote_identifiers_recursive(&mut lim.this, reserved_words);
1542 }
1543
1544 Expression::Offset(off) => {
1545 quote_identifiers_recursive(&mut off.this, reserved_words);
1546 }
1547
1548 Expression::Qualify(q) => {
1549 quote_identifiers_recursive(&mut q.this, reserved_words);
1550 }
1551
1552 Expression::Window(ws) => {
1553 for e in &mut ws.partition_by {
1554 quote_identifiers_recursive(e, reserved_words);
1555 }
1556 for o in &mut ws.order_by {
1557 quote_identifiers_recursive(&mut o.this, reserved_words);
1558 }
1559 }
1560
1561 Expression::Over(over) => {
1562 quote_over(over, reserved_words);
1563 }
1564
1565 Expression::WithinGroup(wg) => {
1566 quote_identifiers_recursive(&mut wg.this, reserved_words);
1567 for o in &mut wg.order_by {
1568 quote_identifiers_recursive(&mut o.this, reserved_words);
1569 }
1570 }
1571
1572 Expression::Pivot(piv) => {
1574 quote_identifiers_recursive(&mut piv.this, reserved_words);
1575 for e in &mut piv.expressions {
1576 quote_identifiers_recursive(e, reserved_words);
1577 }
1578 for f in &mut piv.fields {
1579 quote_identifiers_recursive(f, reserved_words);
1580 }
1581 if let Some(ref mut alias) = piv.alias {
1582 maybe_quote(alias, reserved_words);
1583 }
1584 }
1585
1586 Expression::Unpivot(unpiv) => {
1587 quote_identifiers_recursive(&mut unpiv.this, reserved_words);
1588 maybe_quote(&mut unpiv.value_column, reserved_words);
1589 maybe_quote(&mut unpiv.name_column, reserved_words);
1590 for e in &mut unpiv.columns {
1591 quote_identifiers_recursive(e, reserved_words);
1592 }
1593 if let Some(ref mut alias) = unpiv.alias {
1594 maybe_quote(alias, reserved_words);
1595 }
1596 }
1597
1598 Expression::Values(vals) => {
1600 for tuple in &mut vals.expressions {
1601 for e in &mut tuple.expressions {
1602 quote_identifiers_recursive(e, reserved_words);
1603 }
1604 }
1605 if let Some(ref mut alias) = vals.alias {
1606 maybe_quote(alias, reserved_words);
1607 }
1608 for ca in &mut vals.column_aliases {
1609 maybe_quote(ca, reserved_words);
1610 }
1611 }
1612
1613 Expression::Array(arr) => {
1615 for e in &mut arr.expressions {
1616 quote_identifiers_recursive(e, reserved_words);
1617 }
1618 }
1619
1620 Expression::Struct(st) => {
1621 for (_name, e) in &mut st.fields {
1622 quote_identifiers_recursive(e, reserved_words);
1623 }
1624 }
1625
1626 Expression::Tuple(tup) => {
1627 for e in &mut tup.expressions {
1628 quote_identifiers_recursive(e, reserved_words);
1629 }
1630 }
1631
1632 Expression::Subscript(sub) => {
1634 quote_identifiers_recursive(&mut sub.this, reserved_words);
1635 quote_identifiers_recursive(&mut sub.index, reserved_words);
1636 }
1637
1638 Expression::Dot(dot) => {
1639 quote_identifiers_recursive(&mut dot.this, reserved_words);
1640 maybe_quote(&mut dot.field, reserved_words);
1641 }
1642
1643 Expression::ScopeResolution(sr) => {
1644 if let Some(ref mut this) = sr.this {
1645 quote_identifiers_recursive(this, reserved_words);
1646 }
1647 quote_identifiers_recursive(&mut sr.expression, reserved_words);
1648 }
1649
1650 Expression::Lateral(lat) => {
1652 quote_identifiers_recursive(&mut lat.this, reserved_words);
1653 }
1655
1656 Expression::DPipe(dpipe) => {
1658 quote_identifiers_recursive(&mut dpipe.this, reserved_words);
1659 quote_identifiers_recursive(&mut dpipe.expression, reserved_words);
1660 }
1661
1662 Expression::Merge(merge) => {
1664 quote_identifiers_recursive(&mut merge.this, reserved_words);
1665 quote_identifiers_recursive(&mut merge.using, reserved_words);
1666 if let Some(ref mut on) = merge.on {
1667 quote_identifiers_recursive(on, reserved_words);
1668 }
1669 if let Some(ref mut whens) = merge.whens {
1670 quote_identifiers_recursive(whens, reserved_words);
1671 }
1672 if let Some(ref mut with) = merge.with_ {
1673 quote_identifiers_recursive(with, reserved_words);
1674 }
1675 if let Some(ref mut ret) = merge.returning {
1676 quote_identifiers_recursive(ret, reserved_words);
1677 }
1678 }
1679
1680 Expression::LateralView(lv) => {
1682 quote_lateral_view(lv, reserved_words);
1683 }
1684
1685 Expression::Anonymous(anon) => {
1687 quote_identifiers_recursive(&mut anon.this, reserved_words);
1688 for e in &mut anon.expressions {
1689 quote_identifiers_recursive(e, reserved_words);
1690 }
1691 }
1692
1693 Expression::Filter(filter) => {
1695 quote_identifiers_recursive(&mut filter.this, reserved_words);
1696 quote_identifiers_recursive(&mut filter.expression, reserved_words);
1697 }
1698
1699 Expression::Returning(ret) => {
1701 for e in &mut ret.expressions {
1702 quote_identifiers_recursive(e, reserved_words);
1703 }
1704 }
1705
1706 Expression::BracedWildcard(inner) => {
1708 quote_identifiers_recursive(inner, reserved_words);
1709 }
1710
1711 Expression::ReturnStmt(inner) => {
1713 quote_identifiers_recursive(inner, reserved_words);
1714 }
1715
1716 Expression::Literal(_)
1718 | Expression::Boolean(_)
1719 | Expression::Null(_)
1720 | Expression::DataType(_)
1721 | Expression::Raw(_)
1722 | Expression::Placeholder(_)
1723 | Expression::CurrentDate(_)
1724 | Expression::CurrentTime(_)
1725 | Expression::CurrentTimestamp(_)
1726 | Expression::CurrentTimestampLTZ(_)
1727 | Expression::SessionUser(_)
1728 | Expression::RowNumber(_)
1729 | Expression::Rank(_)
1730 | Expression::DenseRank(_)
1731 | Expression::PercentRank(_)
1732 | Expression::CumeDist(_)
1733 | Expression::Random(_)
1734 | Expression::Pi(_)
1735 | Expression::JSONPathRoot(_) => {
1736 }
1738
1739 _ => {}
1743 }
1744}
1745
1746fn quote_join(join: &mut Join, reserved_words: &HashSet<&str>) {
1748 quote_identifiers_recursive(&mut join.this, reserved_words);
1749 if let Some(ref mut on) = join.on {
1750 quote_identifiers_recursive(on, reserved_words);
1751 }
1752 for id in &mut join.using {
1753 maybe_quote(id, reserved_words);
1754 }
1755 if let Some(ref mut mc) = join.match_condition {
1756 quote_identifiers_recursive(mc, reserved_words);
1757 }
1758 for piv in &mut join.pivots {
1759 quote_identifiers_recursive(piv, reserved_words);
1760 }
1761}
1762
1763fn quote_with(with: &mut With, reserved_words: &HashSet<&str>) {
1765 for cte in &mut with.ctes {
1766 maybe_quote(&mut cte.alias, reserved_words);
1767 for c in &mut cte.columns {
1768 maybe_quote(c, reserved_words);
1769 }
1770 for k in &mut cte.key_expressions {
1771 maybe_quote(k, reserved_words);
1772 }
1773 quote_identifiers_recursive(&mut cte.this, reserved_words);
1774 }
1775}
1776
1777fn quote_over(over: &mut Over, reserved_words: &HashSet<&str>) {
1779 if let Some(ref mut wn) = over.window_name {
1780 maybe_quote(wn, reserved_words);
1781 }
1782 for e in &mut over.partition_by {
1783 quote_identifiers_recursive(e, reserved_words);
1784 }
1785 for o in &mut over.order_by {
1786 quote_identifiers_recursive(&mut o.this, reserved_words);
1787 }
1788 if let Some(ref mut alias) = over.alias {
1789 maybe_quote(alias, reserved_words);
1790 }
1791}
1792
1793fn quote_table_ref(table_ref: &mut TableRef, reserved_words: &HashSet<&str>) {
1795 maybe_quote(&mut table_ref.name, reserved_words);
1796 if let Some(ref mut schema) = table_ref.schema {
1797 maybe_quote(schema, reserved_words);
1798 }
1799 if let Some(ref mut catalog) = table_ref.catalog {
1800 maybe_quote(catalog, reserved_words);
1801 }
1802 if let Some(ref mut alias) = table_ref.alias {
1803 maybe_quote(alias, reserved_words);
1804 }
1805 for ca in &mut table_ref.column_aliases {
1806 maybe_quote(ca, reserved_words);
1807 }
1808 for p in &mut table_ref.partitions {
1809 maybe_quote(p, reserved_words);
1810 }
1811 for h in &mut table_ref.hints {
1812 quote_identifiers_recursive(h, reserved_words);
1813 }
1814}
1815
1816fn quote_lateral_view(lv: &mut LateralView, reserved_words: &HashSet<&str>) {
1818 quote_identifiers_recursive(&mut lv.this, reserved_words);
1819 if let Some(ref mut ta) = lv.table_alias {
1820 maybe_quote(ta, reserved_words);
1821 }
1822 for ca in &mut lv.column_aliases {
1823 maybe_quote(ca, reserved_words);
1824 }
1825}
1826
1827pub fn quote_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
1838 let reserved_words = get_reserved_words(dialect);
1839 let mut result = expression;
1840 quote_identifiers_recursive(&mut result, &reserved_words);
1841 result
1842}
1843
1844pub fn pushdown_cte_alias_columns(_scope: &Scope) {
1849 }
1852
1853fn get_scope_columns(scope: &Scope) -> Vec<ColumnRef> {
1859 let mut columns = Vec::new();
1860 collect_columns(&scope.expression, &mut columns);
1861 columns
1862}
1863
/// Owned summary of a single column reference: its optional table qualifier
/// and its name.
#[derive(Clone, Debug)]
struct ColumnRef {
    /// Name of the qualifying table, or `None` when the column is unqualified.
    table: Option<String>,
    /// Name of the column itself.
    name: String,
}
1870
1871fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
1873 match expr {
1874 Expression::Column(col) => {
1875 columns.push(ColumnRef {
1876 table: col.table.as_ref().map(|t| t.name.clone()),
1877 name: col.name.name.clone(),
1878 });
1879 }
1880 Expression::Select(select) => {
1881 for e in &select.expressions {
1882 collect_columns(e, columns);
1883 }
1884 if let Some(from) = &select.from {
1885 for e in &from.expressions {
1886 collect_columns(e, columns);
1887 }
1888 }
1889 if let Some(where_clause) = &select.where_clause {
1890 collect_columns(&where_clause.this, columns);
1891 }
1892 if let Some(group_by) = &select.group_by {
1893 for e in &group_by.expressions {
1894 collect_columns(e, columns);
1895 }
1896 }
1897 if let Some(having) = &select.having {
1898 collect_columns(&having.this, columns);
1899 }
1900 if let Some(order_by) = &select.order_by {
1901 for o in &order_by.expressions {
1902 collect_columns(&o.this, columns);
1903 }
1904 }
1905 for join in &select.joins {
1906 collect_columns(&join.this, columns);
1907 if let Some(on) = &join.on {
1908 collect_columns(on, columns);
1909 }
1910 }
1911 }
1912 Expression::Alias(alias) => {
1913 collect_columns(&alias.this, columns);
1914 }
1915 Expression::Function(func) => {
1916 for arg in &func.args {
1917 collect_columns(arg, columns);
1918 }
1919 }
1920 Expression::AggregateFunction(agg) => {
1921 for arg in &agg.args {
1922 collect_columns(arg, columns);
1923 }
1924 }
1925 Expression::And(bin)
1926 | Expression::Or(bin)
1927 | Expression::Eq(bin)
1928 | Expression::Neq(bin)
1929 | Expression::Lt(bin)
1930 | Expression::Lte(bin)
1931 | Expression::Gt(bin)
1932 | Expression::Gte(bin)
1933 | Expression::Add(bin)
1934 | Expression::Sub(bin)
1935 | Expression::Mul(bin)
1936 | Expression::Div(bin) => {
1937 collect_columns(&bin.left, columns);
1938 collect_columns(&bin.right, columns);
1939 }
1940 Expression::Not(unary) | Expression::Neg(unary) => {
1941 collect_columns(&unary.this, columns);
1942 }
1943 Expression::Paren(paren) => {
1944 collect_columns(&paren.this, columns);
1945 }
1946 Expression::Case(case) => {
1947 if let Some(operand) = &case.operand {
1948 collect_columns(operand, columns);
1949 }
1950 for (when, then) in &case.whens {
1951 collect_columns(when, columns);
1952 collect_columns(then, columns);
1953 }
1954 if let Some(else_) = &case.else_ {
1955 collect_columns(else_, columns);
1956 }
1957 }
1958 Expression::Cast(cast) => {
1959 collect_columns(&cast.this, columns);
1960 }
1961 Expression::In(in_expr) => {
1962 collect_columns(&in_expr.this, columns);
1963 for e in &in_expr.expressions {
1964 collect_columns(e, columns);
1965 }
1966 if let Some(query) = &in_expr.query {
1967 collect_columns(query, columns);
1968 }
1969 }
1970 Expression::Between(between) => {
1971 collect_columns(&between.this, columns);
1972 collect_columns(&between.low, columns);
1973 collect_columns(&between.high, columns);
1974 }
1975 Expression::Subquery(subquery) => {
1976 collect_columns(&subquery.this, columns);
1977 }
1978 _ => {}
1979 }
1980}
1981
1982fn get_unqualified_columns(scope: &Scope) -> Vec<ColumnRef> {
1984 get_scope_columns(scope)
1985 .into_iter()
1986 .filter(|c| c.table.is_none())
1987 .collect()
1988}
1989
1990fn get_external_columns(scope: &Scope) -> Vec<ColumnRef> {
1992 let source_names: HashSet<_> = scope.sources.keys().cloned().collect();
1993
1994 get_scope_columns(scope)
1995 .into_iter()
1996 .filter(|c| {
1997 if let Some(table) = &c.table {
1998 !source_names.contains(table)
1999 } else {
2000 false
2001 }
2002 })
2003 .collect()
2004}
2005
2006fn is_correlated_subquery(scope: &Scope) -> bool {
2008 scope.can_be_correlated && !get_external_columns(scope).is_empty()
2009}
2010
2011fn is_star_column(col: &Column) -> bool {
2013 col.name.name == "*"
2014}
2015
2016fn create_qualified_column(name: &str, table: Option<&str>) -> Expression {
2018 Expression::Column(Column {
2019 name: Identifier::new(name),
2020 table: table.map(Identifier::new),
2021 join_mark: false,
2022 trailing_comments: vec![],
2023 })
2024}
2025
2026fn create_alias(expr: Expression, alias_name: &str) -> Expression {
2028 Expression::Alias(Box::new(Alias {
2029 this: expr,
2030 alias: Identifier::new(alias_name),
2031 column_aliases: vec![],
2032 pre_alias_comments: vec![],
2033 trailing_comments: vec![],
2034 }))
2035}
2036
2037fn get_output_name(expr: &Expression) -> Option<String> {
2039 match expr {
2040 Expression::Column(col) => Some(col.name.name.clone()),
2041 Expression::Alias(alias) => Some(alias.alias.name.clone()),
2042 Expression::Identifier(id) => Some(id.name.clone()),
2043 _ => None,
2044 }
2045}
2046
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::scope::build_scope;

    /// Renders an expression back to SQL with a default generator.
    ///
    /// Deliberately not called `gen`, which is a reserved keyword from the
    /// 2024 edition onwards.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Builds an unquoted column, optionally qualified by `table`.
    fn column(name: &str, table: Option<&str>) -> Column {
        Column {
            name: Identifier::new(name),
            table: table.map(Identifier::new),
            join_mark: false,
            trailing_comments: vec![],
        }
    }

    /// Builds an unquoted, unqualified column expression.
    fn column_expr(name: &str) -> Expression {
        Expression::Column(column(name, None))
    }

    /// Parses `sql` and returns the names of all columns `collect_columns` finds.
    fn collected_names(sql: &str) -> Vec<String> {
        let expr = parse(sql);
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);
        columns.into_iter().map(|c| c.name).collect()
    }

    /// Asserts that every name in `expected` was collected from `sql`.
    fn assert_collects(sql: &str, expected: &[&str]) {
        let names = collected_names(sql);
        for wanted in expected {
            assert!(
                names.iter().any(|n| n == wanted),
                "column {:?} was not collected from {:?}",
                wanted,
                sql
            );
        }
    }

    #[test]
    fn test_qualify_columns_options() {
        let options = QualifyColumnsOptions::new()
            .with_expand_alias_refs(true)
            .with_expand_stars(false)
            .with_dialect(DialectType::PostgreSQL)
            .with_allow_partial(true);

        assert!(options.expand_alias_refs);
        assert!(!options.expand_stars);
        assert_eq!(options.dialect, Some(DialectType::PostgreSQL));
        assert!(options.allow_partial_qualification);
    }

    #[test]
    fn test_get_scope_columns() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let scope = build_scope(&expr);
        let columns = get_scope_columns(&scope);

        for wanted in &["a", "b", "c"] {
            assert!(
                columns.iter().any(|c| c.name == *wanted),
                "missing column {:?}",
                wanted
            );
        }
    }

    #[test]
    fn test_get_unqualified_columns() {
        let expr = parse("SELECT t.a, b FROM t");
        let scope = build_scope(&expr);
        let unqualified = get_unqualified_columns(&scope);

        assert!(unqualified.iter().any(|c| c.name == "b"));
        // `a` is written as `t.a`, so it must not be reported.
        assert!(!unqualified.iter().any(|c| c.name == "a"));
    }

    #[test]
    fn test_is_star_column() {
        assert!(is_star_column(&column("*", Some("t"))));
        assert!(!is_star_column(&column("id", None)));
    }

    #[test]
    fn test_create_qualified_column() {
        let expr = create_qualified_column("id", Some("users"));
        let sql = to_sql(&expr);
        assert!(sql.contains("users"));
        assert!(sql.contains("id"));
    }

    #[test]
    fn test_create_alias() {
        let aliased = create_alias(column_expr("value"), "total");
        let sql = to_sql(&aliased);
        // Both the aliased expression and the alias name must be rendered; the
        // `AS` keyword itself is not required.
        assert!(sql.contains("value"), "aliased expression missing in {:?}", sql);
        assert!(sql.contains("total"), "alias name missing in {:?}", sql);
    }

    #[test]
    fn test_validate_qualify_columns_success() {
        let expr = parse("SELECT t.a, t.b FROM t");
        // Smoke test only: the outcome is intentionally not asserted, the call
        // just has to complete without panicking.
        let _ = validate_qualify_columns(&expr);
    }

    #[test]
    fn test_collect_columns_nested() {
        assert_collects(
            "SELECT a + b, c FROM t WHERE d > 0 GROUP BY e HAVING f = 1",
            &["a", "b", "c", "d", "e", "f"],
        );
    }

    #[test]
    fn test_collect_columns_in_case() {
        assert_collects(
            "SELECT CASE WHEN a = 1 THEN b ELSE c END FROM t",
            &["a", "b", "c"],
        );
    }

    #[test]
    fn test_collect_columns_in_subquery() {
        assert_collects(
            "SELECT a FROM t WHERE b IN (SELECT c FROM s)",
            &["a", "b", "c"],
        );
    }

    #[test]
    fn test_qualify_outputs_basic() {
        let expr = parse("SELECT a, b + c FROM t");
        let scope = build_scope(&expr);
        let result = qualify_outputs(&scope);
        assert!(result.is_ok());
    }

    #[test]
    fn test_needs_quoting_reserved_word() {
        let reserved = get_reserved_words(None);
        for word in &["select", "SELECT", "from", "WHERE", "join", "table"] {
            assert!(needs_quoting(word, &reserved), "{:?} should need quoting", word);
        }
    }

    #[test]
    fn test_needs_quoting_normal_identifiers() {
        let reserved = get_reserved_words(None);
        for word in &["foo", "my_column", "col1", "A", "_hidden"] {
            assert!(
                !needs_quoting(word, &reserved),
                "{:?} should not need quoting",
                word
            );
        }
    }

    #[test]
    fn test_needs_quoting_special_characters() {
        let reserved = get_reserved_words(None);
        for word in &["my column", "my-column", "my.column", "col@name", "col#name"] {
            assert!(needs_quoting(word, &reserved), "{:?} should need quoting", word);
        }
    }

    #[test]
    fn test_needs_quoting_starts_with_digit() {
        let reserved = get_reserved_words(None);
        for word in &["1col", "123", "0_start"] {
            assert!(needs_quoting(word, &reserved), "{:?} should need quoting", word);
        }
    }

    #[test]
    fn test_needs_quoting_empty() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("", &reserved));
    }

    #[test]
    fn test_maybe_quote_sets_quoted_flag() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("select");
        assert!(!id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
    }

    #[test]
    fn test_maybe_quote_skips_already_quoted() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::quoted("myname");
        assert!(id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
        assert_eq!(id.name, "myname");
    }

    #[test]
    fn test_maybe_quote_skips_star() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("*");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    #[test]
    fn test_maybe_quote_skips_normal() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("normal_col");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    #[test]
    fn test_quote_identifiers_column_with_reserved_name() {
        let result = quote_identifiers(column_expr("select"), None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column named 'select' should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_column_with_special_chars() {
        let result = quote_identifiers(column_expr("my column"), None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column with space should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_preserves_normal_column() {
        let expr = Expression::Column(column("normal_col", Some("my_table")));
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(!col.name.quoted, "Normal column should not be quoted");
            assert!(
                !col.table.as_ref().unwrap().quoted,
                "Normal table should not be quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_table_ref_reserved() {
        let expr = Expression::Table(TableRef::new("select"));
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(tr.name.quoted, "Table named 'select' should be quoted");
        } else {
            panic!("Expected Table expression");
        }
    }

    #[test]
    fn test_quote_identifiers_table_ref_schema_and_alias() {
        let mut tr = TableRef::new("my_table");
        tr.schema = Some(Identifier::new("from"));
        tr.alias = Some(Identifier::new("t"));
        let expr = Expression::Table(tr);
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(!tr.name.quoted, "Normal table name should not be quoted");
            assert!(
                tr.schema.as_ref().unwrap().quoted,
                "Schema named 'from' should be quoted"
            );
            assert!(
                !tr.alias.as_ref().unwrap().quoted,
                "Normal alias should not be quoted"
            );
        } else {
            panic!("Expected Table expression");
        }
    }

    #[test]
    fn test_quote_identifiers_identifier_node() {
        let expr = Expression::Identifier(Identifier::new("order"));
        let result = quote_identifiers(expr, None);
        if let Expression::Identifier(id) = &result {
            assert!(id.quoted, "Identifier named 'order' should be quoted");
        } else {
            panic!("Expected Identifier expression");
        }
    }

    #[test]
    fn test_quote_identifiers_alias() {
        let expr = Expression::Alias(Box::new(Alias {
            this: column_expr("val"),
            alias: Identifier::new("select"),
            column_aliases: vec![Identifier::new("from")],
            pre_alias_comments: vec![],
            trailing_comments: vec![],
        }));
        let result = quote_identifiers(expr, None);
        if let Expression::Alias(alias) = &result {
            assert!(alias.alias.quoted, "Alias named 'select' should be quoted");
            assert!(
                alias.column_aliases[0].quoted,
                "Column alias named 'from' should be quoted"
            );
            // Fail loudly if the aliased expression changed kind instead of
            // silently skipping the check.
            if let Expression::Column(col) = &alias.this {
                assert!(!col.name.quoted);
            } else {
                panic!("Expected aliased Column expression");
            }
        } else {
            panic!("Expected Alias expression");
        }
    }

    #[test]
    fn test_quote_identifiers_select_recursive() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let result = quote_identifiers(expr, None);
        let sql = to_sql(&result);
        assert!(sql.contains("a"));
        assert!(sql.contains("b"));
        assert!(sql.contains("t"));
    }

    #[test]
    fn test_quote_identifiers_digit_start() {
        let result = quote_identifiers(column_expr("1col"), None);
        if let Expression::Column(col) = &result {
            assert!(
                col.name.quoted,
                "Column starting with digit should be quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_with_mysql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::MySQL));
        assert!(needs_quoting("KILL", &reserved));
        assert!(needs_quoting("FORCE", &reserved));
    }

    #[test]
    fn test_quote_identifiers_with_postgresql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::PostgreSQL));
        assert!(needs_quoting("ILIKE", &reserved));
        assert!(needs_quoting("VERBOSE", &reserved));
    }

    #[test]
    fn test_quote_identifiers_with_bigquery_dialect() {
        let reserved = get_reserved_words(Some(DialectType::BigQuery));
        assert!(needs_quoting("STRUCT", &reserved));
        assert!(needs_quoting("PROTO", &reserved));
    }

    #[test]
    fn test_quote_identifiers_case_insensitive_reserved() {
        let reserved = get_reserved_words(None);
        for word in &["Select", "sElEcT", "FROM", "from"] {
            assert!(needs_quoting(word, &reserved), "{:?} should need quoting", word);
        }
    }

    #[test]
    fn test_quote_identifiers_join_using() {
        let mut join = crate::expressions::Join {
            this: Expression::Table(TableRef::new("other")),
            on: None,
            using: vec![Identifier::new("key"), Identifier::new("value")],
            kind: crate::expressions::JoinKind::Inner,
            use_inner_keyword: false,
            use_outer_keyword: false,
            deferred_condition: false,
            join_hint: None,
            match_condition: None,
            pivots: vec![],
            comments: vec![],
            nesting_group: 0,
            directed: false,
        };
        let reserved = get_reserved_words(None);
        quote_join(&mut join, &reserved);
        assert!(
            join.using[0].quoted,
            "USING identifier 'key' should be quoted"
        );
        assert!(
            !join.using[1].quoted,
            "USING identifier 'value' should not be quoted"
        );
    }

    #[test]
    fn test_quote_identifiers_cte() {
        let mut cte = crate::expressions::Cte {
            alias: Identifier::new("select"),
            this: column_expr("x"),
            columns: vec![Identifier::new("from"), Identifier::new("normal")],
            materialized: None,
            key_expressions: vec![],
            alias_first: false,
            comments: Vec::new(),
        };
        let reserved = get_reserved_words(None);
        maybe_quote(&mut cte.alias, &reserved);
        for c in &mut cte.columns {
            maybe_quote(c, &reserved);
        }
        assert!(cte.alias.quoted, "CTE alias 'select' should be quoted");
        assert!(cte.columns[0].quoted, "CTE column 'from' should be quoted");
        assert!(
            !cte.columns[1].quoted,
            "CTE column 'normal' should not be quoted"
        );
    }

    #[test]
    fn test_quote_identifiers_binary_ops_recurse() {
        let expr = Expression::Add(Box::new(crate::expressions::BinaryOp::new(
            column_expr("select"),
            column_expr("normal"),
        )));
        let result = quote_identifiers(expr, None);
        if let Expression::Add(bin) = &result {
            // Both operands were built as columns; anything else is a failure,
            // not a reason to skip the assertion.
            if let Expression::Column(left) = &bin.left {
                assert!(
                    left.name.quoted,
                    "'select' column should be quoted in binary op"
                );
            } else {
                panic!("Expected Column as left operand");
            }
            if let Expression::Column(right) = &bin.right {
                assert!(!right.name.quoted, "'normal' column should not be quoted");
            } else {
                panic!("Expected Column as right operand");
            }
        } else {
            panic!("Expected Add expression");
        }
    }

    #[test]
    fn test_quote_identifiers_already_quoted_preserved() {
        let expr = Expression::Column(Column {
            name: Identifier::quoted("normal_name"),
            ..column("normal_name", None)
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(
                col.name.quoted,
                "Already-quoted identifier should remain quoted"
            );
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_full_parsed_query() {
        // The SELECT is assembled by hand so that the reserved word `order`
        // can appear as an unquoted column name.
        let mut select = crate::expressions::Select::new();
        select
            .expressions
            .push(Expression::Column(column("order", Some("t"))));
        select.from = Some(crate::expressions::From {
            expressions: vec![Expression::Table(TableRef::new("t"))],
        });
        let expr = Expression::Select(Box::new(select));

        let result = quote_identifiers(expr, None);
        if let Expression::Select(sel) = &result {
            if let Expression::Column(col) = &sel.expressions[0] {
                assert!(col.name.quoted, "Column named 'order' should be quoted");
                assert!(
                    !col.table.as_ref().unwrap().quoted,
                    "Table 't' should not be quoted"
                );
            } else {
                panic!("Expected Column in SELECT list");
            }
        } else {
            panic!("Expected Select expression");
        }
    }

    #[test]
    fn test_get_reserved_words_all_dialects() {
        let dialects = [
            None,
            Some(DialectType::Generic),
            Some(DialectType::MySQL),
            Some(DialectType::PostgreSQL),
            Some(DialectType::BigQuery),
            Some(DialectType::Snowflake),
            Some(DialectType::TSQL),
            Some(DialectType::ClickHouse),
            Some(DialectType::DuckDB),
            Some(DialectType::Hive),
            Some(DialectType::Spark),
            Some(DialectType::Trino),
            Some(DialectType::Oracle),
            Some(DialectType::Redshift),
        ];
        for dialect in &dialects {
            let words = get_reserved_words(*dialect);
            assert!(
                words.contains("SELECT"),
                "All dialects should have SELECT as reserved"
            );
            assert!(
                words.contains("FROM"),
                "All dialects should have FROM as reserved"
            );
        }
    }
}