1use crate::dialects::DialectType;
9use crate::expressions::{Alias, Column, Expression, Identifier, Join, LateralView, Over, TableRef, With};
10use crate::resolver::{Resolver, ResolverError};
11use crate::schema::Schema;
12use crate::scope::{traverse_scope, Scope};
13use std::collections::{HashMap, HashSet};
14use thiserror::Error;
15
16#[derive(Debug, Error, Clone)]
18pub enum QualifyColumnsError {
19 #[error("Unknown table: {0}")]
20 UnknownTable(String),
21
22 #[error("Unknown column: {0}")]
23 UnknownColumn(String),
24
25 #[error("Ambiguous column: {0}")]
26 AmbiguousColumn(String),
27
28 #[error("Cannot automatically join: {0}")]
29 CannotAutoJoin(String),
30
31 #[error("Unknown output column: {0}")]
32 UnknownOutputColumn(String),
33
34 #[error("Column could not be resolved: {column}{for_table}")]
35 ColumnNotResolved { column: String, for_table: String },
36
37 #[error("Resolver error: {0}")]
38 ResolverError(#[from] ResolverError),
39}
40
41pub type QualifyColumnsResult<T> = Result<T, QualifyColumnsError>;
43
44#[derive(Debug, Clone, Default)]
46pub struct QualifyColumnsOptions {
47 pub expand_alias_refs: bool,
49 pub expand_stars: bool,
51 pub infer_schema: Option<bool>,
53 pub allow_partial_qualification: bool,
55 pub dialect: Option<DialectType>,
57}
58
59impl QualifyColumnsOptions {
60 pub fn new() -> Self {
62 Self {
63 expand_alias_refs: true,
64 expand_stars: true,
65 infer_schema: None,
66 allow_partial_qualification: false,
67 dialect: None,
68 }
69 }
70
71 pub fn with_expand_alias_refs(mut self, expand: bool) -> Self {
73 self.expand_alias_refs = expand;
74 self
75 }
76
77 pub fn with_expand_stars(mut self, expand: bool) -> Self {
79 self.expand_stars = expand;
80 self
81 }
82
83 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
85 self.dialect = Some(dialect);
86 self
87 }
88
89 pub fn with_allow_partial(mut self, allow: bool) -> Self {
91 self.allow_partial_qualification = allow;
92 self
93 }
94}
95
96pub fn qualify_columns(
111 expression: Expression,
112 schema: &dyn Schema,
113 options: &QualifyColumnsOptions,
114) -> QualifyColumnsResult<Expression> {
115 let infer_schema = options.infer_schema.unwrap_or(schema.is_empty());
116 let dialect = options.dialect.or_else(|| schema.dialect());
117
118 let result = expression.clone();
119
120 for scope in traverse_scope(&expression) {
122 let scope_expression = &scope.expression;
123 let is_select = matches!(scope_expression, Expression::Select(_));
124
125 let mut resolver = Resolver::new(&scope, schema, infer_schema);
127
128 qualify_columns_in_scope(&scope, &mut resolver, options.allow_partial_qualification)?;
130
131 if options.expand_alias_refs {
133 expand_alias_refs(&scope, &mut resolver, dialect)?;
134 }
135
136 if is_select {
137 if options.expand_stars {
139 expand_stars(&scope, &mut resolver)?;
140 }
141
142 qualify_outputs(&scope)?;
144 }
145
146 expand_group_by(&scope, dialect)?;
148 }
149
150 Ok(result)
151}
152
153pub fn validate_qualify_columns(expression: &Expression) -> QualifyColumnsResult<()> {
158 let mut all_unqualified = Vec::new();
159
160 for scope in traverse_scope(expression) {
161 if let Expression::Select(_) = &scope.expression {
162 let unqualified = get_unqualified_columns(&scope);
164
165 let external = get_external_columns(&scope);
167 if !external.is_empty() && !is_correlated_subquery(&scope) {
168 let first = &external[0];
169 let for_table = if first.table.is_some() {
170 format!(" for table: '{}'", first.table.as_ref().unwrap())
171 } else {
172 String::new()
173 };
174 return Err(QualifyColumnsError::ColumnNotResolved {
175 column: first.name.clone(),
176 for_table,
177 });
178 }
179
180 all_unqualified.extend(unqualified);
181 }
182 }
183
184 if !all_unqualified.is_empty() {
185 let first = &all_unqualified[0];
186 return Err(QualifyColumnsError::AmbiguousColumn(first.name.clone()));
187 }
188
189 Ok(())
190}
191
192fn qualify_columns_in_scope(
194 scope: &Scope,
195 resolver: &mut Resolver,
196 allow_partial: bool,
197) -> QualifyColumnsResult<()> {
198 let columns = get_scope_columns(scope);
199
200 for column_ref in columns {
201 let column_table = &column_ref.table;
202 let column_name = &column_ref.name;
203
204 if let Some(table) = column_table {
205 if scope.sources.contains_key(table) {
207 if let Ok(source_columns) = resolver.get_source_columns(table) {
208 if !allow_partial
209 && !source_columns.is_empty()
210 && !source_columns.contains(column_name)
211 && !source_columns.contains(&"*".to_string())
212 {
213 return Err(QualifyColumnsError::UnknownColumn(column_name.clone()));
214 }
215 }
216 }
217 } else {
218 if let Some(table) = resolver.get_table(column_name) {
220 let _ = table;
223 }
224 }
225 }
226
227 Ok(())
228}
229
230fn expand_alias_refs(
237 scope: &Scope,
238 _resolver: &mut Resolver,
239 _dialect: Option<DialectType>,
240) -> QualifyColumnsResult<()> {
241 let expression = &scope.expression;
242
243 if !matches!(expression, Expression::Select(_)) {
244 return Ok(());
245 }
246
247 let mut _alias_to_expression: HashMap<String, (Expression, usize)> = HashMap::new();
249
250 if let Expression::Select(select) = expression {
251 for (i, expr) in select.expressions.iter().enumerate() {
252 if let Expression::Alias(alias) = expr {
253 _alias_to_expression.insert(alias.alias.name.clone(), (alias.this.clone(), i + 1));
254 }
255 }
256 }
257
258 Ok(())
262}
263
264fn expand_group_by(scope: &Scope, _dialect: Option<DialectType>) -> QualifyColumnsResult<()> {
271 if let Expression::Select(select) = &scope.expression {
272 if let Some(_group_by) = &select.group_by {
273 }
276 }
277 Ok(())
278}
279
280fn expand_stars(scope: &Scope, resolver: &mut Resolver) -> QualifyColumnsResult<()> {
287 if let Expression::Select(select) = &scope.expression {
288 let mut _new_selections: Vec<Expression> = Vec::new();
289 let mut _has_star = false;
290
291 for expr in &select.expressions {
292 match expr {
293 Expression::Star(_) => {
294 _has_star = true;
295 for (source_name, _) in &scope.sources {
297 if let Ok(columns) = resolver.get_source_columns(source_name) {
298 if columns.contains(&"*".to_string()) || columns.is_empty() {
299 return Ok(());
301 }
302 for col_name in columns {
303 _new_selections.push(create_qualified_column(
304 &col_name,
305 Some(source_name),
306 ));
307 }
308 }
309 }
310 }
311 Expression::Column(col) if is_star_column(col) => {
312 _has_star = true;
313 if let Some(table) = &col.table {
315 let table_name = &table.name;
316 if !scope.sources.contains_key(table_name) {
317 return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
318 }
319 if let Ok(columns) = resolver.get_source_columns(table_name) {
320 if columns.contains(&"*".to_string()) || columns.is_empty() {
321 return Ok(());
322 }
323 for col_name in columns {
324 _new_selections.push(create_qualified_column(
325 &col_name,
326 Some(table_name),
327 ));
328 }
329 }
330 }
331 }
332 _ => {
333 _new_selections.push(expr.clone());
334 }
335 }
336 }
337
338 }
341
342 Ok(())
343}
344
345pub fn qualify_outputs(scope: &Scope) -> QualifyColumnsResult<()> {
352 if let Expression::Select(select) = &scope.expression {
353 let mut new_selections: Vec<Expression> = Vec::new();
354
355 for (i, expr) in select.expressions.iter().enumerate() {
356 match expr {
357 Expression::Alias(_) => {
358 new_selections.push(expr.clone());
360 }
361 Expression::Column(col) => {
362 new_selections.push(create_alias(expr.clone(), &col.name.name));
364 }
365 Expression::Star(_) => {
366 new_selections.push(expr.clone());
368 }
369 _ => {
370 let alias_name = get_output_name(expr).unwrap_or_else(|| format!("_col_{}", i));
372 new_selections.push(create_alias(expr.clone(), &alias_name));
373 }
374 }
375 }
376
377 }
379
380 Ok(())
381}
382
383fn get_reserved_words(dialect: Option<DialectType>) -> HashSet<&'static str> {
386 let mut words: HashSet<&'static str> = [
388 "ADD", "ALL", "ALTER", "AND", "ANY", "AS", "ASC", "BETWEEN", "BY",
390 "CASE", "CAST", "CHECK", "COLUMN", "CONSTRAINT", "CREATE", "CROSS",
391 "CURRENT", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP",
392 "CURRENT_USER", "DATABASE", "DEFAULT", "DELETE", "DESC", "DISTINCT",
393 "DROP", "ELSE", "END", "ESCAPE", "EXCEPT", "EXISTS", "FALSE",
394 "FETCH", "FOR", "FOREIGN", "FROM", "FULL", "GRANT", "GROUP",
395 "HAVING", "IF", "IN", "INDEX", "INNER", "INSERT", "INTERSECT",
396 "INTO", "IS", "JOIN", "KEY", "LEFT", "LIKE", "LIMIT", "NATURAL",
397 "NOT", "NULL", "OFFSET", "ON", "OR", "ORDER", "OUTER", "PRIMARY",
398 "REFERENCES", "REPLACE", "RETURNING", "RIGHT", "ROLLBACK", "ROW",
399 "ROWS", "SELECT", "SESSION_USER", "SET", "SOME", "TABLE", "THEN",
400 "TO", "TRUE", "TRUNCATE", "UNION", "UNIQUE", "UPDATE", "USING",
401 "VALUES", "VIEW", "WHEN", "WHERE", "WINDOW", "WITH",
402 ].iter().copied().collect();
403
404 match dialect {
406 Some(DialectType::MySQL) => {
407 words.extend([
408 "ANALYZE", "BOTH", "CHANGE", "CONDITION", "DATABASES",
409 "DAY_HOUR", "DAY_MICROSECOND", "DAY_MINUTE", "DAY_SECOND",
410 "DELAYED", "DETERMINISTIC", "DIV", "DUAL", "EACH",
411 "ELSEIF", "ENCLOSED", "EXPLAIN", "FLOAT4", "FLOAT8",
412 "FORCE", "HOUR_MICROSECOND", "HOUR_MINUTE", "HOUR_SECOND",
413 "IGNORE", "INFILE", "INT1", "INT2", "INT3", "INT4", "INT8",
414 "ITERATE", "KEYS", "KILL", "LEADING", "LEAVE", "LINES",
415 "LOAD", "LOCK", "LONG", "LONGBLOB", "LONGTEXT", "LOOP",
416 "LOW_PRIORITY", "MATCH", "MEDIUMBLOB", "MEDIUMINT",
417 "MEDIUMTEXT", "MINUTE_MICROSECOND", "MINUTE_SECOND", "MOD",
418 "MODIFIES", "NO_WRITE_TO_BINLOG", "OPTIMIZE", "OPTIONALLY",
419 "OUT", "OUTFILE", "PURGE", "READS", "REGEXP", "RELEASE",
420 "RENAME", "REPEAT", "REQUIRE", "RESIGNAL", "RETURN",
421 "REVOKE", "RLIKE", "SCHEMA", "SCHEMAS", "SECOND_MICROSECOND",
422 "SENSITIVE", "SEPARATOR", "SHOW", "SIGNAL", "SPATIAL",
423 "SQL", "SQLEXCEPTION", "SQLSTATE", "SQLWARNING",
424 "SQL_BIG_RESULT", "SQL_CALC_FOUND_ROWS", "SQL_SMALL_RESULT",
425 "SSL", "STARTING", "STRAIGHT_JOIN", "TERMINATED",
426 "TINYBLOB", "TINYINT", "TINYTEXT", "TRAILING", "TRIGGER",
427 "UNDO", "UNLOCK", "UNSIGNED", "USAGE", "UTC_DATE",
428 "UTC_TIME", "UTC_TIMESTAMP", "VARBINARY", "VARCHARACTER",
429 "WHILE", "WRITE", "XOR", "YEAR_MONTH", "ZEROFILL",
430 ].iter().copied());
431 }
432 Some(DialectType::PostgreSQL) | Some(DialectType::CockroachDB) => {
433 words.extend([
434 "ANALYSE", "ANALYZE", "ARRAY", "AUTHORIZATION", "BINARY",
435 "BOTH", "COLLATE", "CONCURRENTLY", "DO", "FREEZE",
436 "ILIKE", "INITIALLY", "ISNULL", "LATERAL", "LEADING",
437 "LOCALTIME", "LOCALTIMESTAMP", "NOTNULL", "ONLY", "OVERLAPS",
438 "PLACING", "SIMILAR", "SYMMETRIC", "TABLESAMPLE",
439 "TRAILING", "VARIADIC", "VERBOSE",
440 ].iter().copied());
441 }
442 Some(DialectType::BigQuery) => {
443 words.extend([
444 "ASSERT_ROWS_MODIFIED", "COLLATE", "CONTAINS", "CUBE",
445 "DEFINE", "ENUM", "EXTRACT", "FOLLOWING", "GROUPING",
446 "GROUPS", "HASH", "IGNORE", "LATERAL", "LOOKUP",
447 "MERGE", "NEW", "NO", "NULLS", "OF", "OVER", "PARTITION",
448 "PRECEDING", "PROTO", "RANGE", "RECURSIVE", "RESPECT",
449 "ROLLUP", "STRUCT", "TABLESAMPLE", "TREAT", "UNBOUNDED",
450 "WITHIN",
451 ].iter().copied());
452 }
453 Some(DialectType::Snowflake) => {
454 words.extend([
455 "ACCOUNT", "BOTH", "CONNECT", "FOLLOWING", "ILIKE",
456 "INCREMENT", "ISSUE", "LATERAL", "LEADING", "LOCALTIME",
457 "LOCALTIMESTAMP", "MINUS", "QUALIFY", "REGEXP", "RLIKE",
458 "SOME", "START", "TABLESAMPLE", "TOP", "TRAILING",
459 "TRY_CAST",
460 ].iter().copied());
461 }
462 Some(DialectType::TSQL) | Some(DialectType::Fabric) => {
463 words.extend([
464 "BACKUP", "BREAK", "BROWSE", "BULK", "CASCADE", "CHECKPOINT",
465 "CLOSE", "CLUSTERED", "COALESCE", "COMPUTE", "CONTAINS",
466 "CONTAINSTABLE", "CONTINUE", "CONVERT", "DBCC",
467 "DEALLOCATE", "DENY", "DISK", "DISTRIBUTED", "DUMP",
468 "ERRLVL", "EXEC", "EXECUTE", "EXIT", "EXTERNAL", "FILE",
469 "FILLFACTOR", "FREETEXT", "FREETEXTTABLE", "FUNCTION",
470 "GOTO", "HOLDLOCK", "IDENTITY", "IDENTITYCOL",
471 "IDENTITY_INSERT", "KILL", "LINENO", "MERGE",
472 "NONCLUSTERED", "NULLIF", "OF", "OFF", "OFFSETS",
473 "OPEN", "OPENDATASOURCE", "OPENQUERY", "OPENROWSET",
474 "OPENXML", "OVER", "PERCENT", "PIVOT", "PLAN", "PRINT",
475 "PROC", "PROCEDURE", "PUBLIC", "RAISERROR", "READ",
476 "READTEXT", "RECONFIGURE", "REPLICATION", "RESTORE",
477 "RESTRICT", "REVERT", "ROWCOUNT", "ROWGUIDCOL", "RULE",
478 "SAVE", "SECURITYAUDIT", "SEMANTICKEYPHRASETABLE",
479 "SEMANTICSIMILARITYDETAILSTABLE",
480 "SEMANTICSIMILARITYTABLE", "SETUSER", "SHUTDOWN",
481 "STATISTICS", "SYSTEM_USER", "TEXTSIZE", "TOP", "TRAN",
482 "TRANSACTION", "TRIGGER", "TSEQUAL", "UNPIVOT",
483 "UPDATETEXT", "WAITFOR", "WRITETEXT",
484 ].iter().copied());
485 }
486 Some(DialectType::ClickHouse) => {
487 words.extend([
488 "ANTI", "ARRAY", "ASOF", "FINAL", "FORMAT", "GLOBAL",
489 "INF", "KILL", "MATERIALIZED", "NAN", "PREWHERE",
490 "SAMPLE", "SEMI", "SETTINGS", "TOP",
491 ].iter().copied());
492 }
493 Some(DialectType::DuckDB) => {
494 words.extend([
495 "ANALYSE", "ANALYZE", "ARRAY", "BOTH", "LATERAL",
496 "LEADING", "LOCALTIME", "LOCALTIMESTAMP", "PLACING",
497 "QUALIFY", "SIMILAR", "TABLESAMPLE", "TRAILING",
498 ].iter().copied());
499 }
500 Some(DialectType::Hive) | Some(DialectType::Spark) | Some(DialectType::Databricks) => {
501 words.extend([
502 "BOTH", "CLUSTER", "DISTRIBUTE", "EXCHANGE", "EXTENDED",
503 "FUNCTION", "LATERAL", "LEADING", "MACRO", "OVER",
504 "PARTITION", "PERCENT", "RANGE", "READS", "REDUCE",
505 "REGEXP", "REVOKE", "RLIKE", "ROLLUP", "SEMI", "SORT",
506 "TABLESAMPLE", "TRAILING", "TRANSFORM", "UNBOUNDED",
507 "UNIQUEJOIN",
508 ].iter().copied());
509 }
510 Some(DialectType::Trino) | Some(DialectType::Presto) | Some(DialectType::Athena) => {
511 words.extend([
512 "CUBE", "DEALLOCATE", "DESCRIBE", "EXECUTE", "EXTRACT",
513 "GROUPING", "LATERAL", "LOCALTIME", "LOCALTIMESTAMP",
514 "NORMALIZE", "PREPARE", "ROLLUP", "SOME",
515 "TABLESAMPLE", "UESCAPE", "UNNEST",
516 ].iter().copied());
517 }
518 Some(DialectType::Oracle) => {
519 words.extend([
520 "ACCESS", "AUDIT", "CLUSTER", "COMMENT", "COMPRESS",
521 "CONNECT", "EXCLUSIVE", "FILE", "IDENTIFIED", "IMMEDIATE",
522 "INCREMENT", "INITIAL", "LEVEL", "LOCK", "LONG",
523 "MAXEXTENTS", "MINUS", "MODE", "NOAUDIT", "NOCOMPRESS",
524 "NOWAIT", "NUMBER", "OF", "OFFLINE", "ONLINE", "PCTFREE",
525 "PRIOR", "RAW", "RENAME", "RESOURCE", "REVOKE",
526 "SHARE", "SIZE", "START", "SUCCESSFUL", "SYNONYM",
527 "SYSDATE", "TRIGGER", "UID", "VALIDATE", "VARCHAR2",
528 "WHENEVER",
529 ].iter().copied());
530 }
531 Some(DialectType::Redshift) => {
532 words.extend([
533 "AZ64", "BZIP2", "DELTA", "DELTA32K", "DISTSTYLE",
534 "ENCODE", "GZIP", "ILIKE", "LIMIT", "LUNS", "LZO",
535 "LZOP", "MOSTLY13", "MOSTLY32", "MOSTLY8", "RAW",
536 "SIMILAR", "SNAPSHOT", "SORTKEY", "SYSDATE", "TOP",
537 "ZSTD",
538 ].iter().copied());
539 }
540 _ => {
541 words.extend([
543 "ANALYZE", "ARRAY", "BOTH", "CUBE", "GROUPING", "LATERAL",
544 "LEADING", "LOCALTIME", "LOCALTIMESTAMP", "OVER", "PARTITION",
545 "QUALIFY", "RANGE", "ROLLUP", "SIMILAR", "SOME",
546 "TABLESAMPLE", "TRAILING",
547 ].iter().copied());
548 }
549 }
550
551 words
552}
553
/// Decide whether an identifier name must be quoted.
///
/// An empty name never needs quoting. Otherwise, quoting is required when any
/// of the following holds:
/// - the name starts with an ASCII digit;
/// - the name contains any byte other than ASCII letters, ASCII digits and
///   `_` (so any non-ASCII character triggers it);
/// - the upper-cased name is in `reserved_words`.
fn needs_quoting(name: &str, reserved_words: &HashSet<&str>) -> bool {
    let bytes = name.as_bytes();
    let first = match bytes.first() {
        Some(&b) => b,
        None => return false,
    };

    let is_word_byte = |b: &u8| b.is_ascii_alphanumeric() || *b == b'_';
    if first.is_ascii_digit() || !bytes.iter().all(is_word_byte) {
        return true;
    }

    reserved_words.contains(name.to_uppercase().as_str())
}
580
581fn maybe_quote(id: &mut Identifier, reserved_words: &HashSet<&str>) {
583 if id.quoted || id.name.is_empty() || id.name == "*" {
586 return;
587 }
588 if needs_quoting(&id.name, reserved_words) {
589 id.quoted = true;
590 }
591}
592
593fn quote_identifiers_recursive(expr: &mut Expression, reserved_words: &HashSet<&str>) {
595 match expr {
596 Expression::Identifier(id) => {
598 maybe_quote(id, reserved_words);
599 }
600
601 Expression::Column(col) => {
602 maybe_quote(&mut col.name, reserved_words);
603 if let Some(ref mut table) = col.table {
604 maybe_quote(table, reserved_words);
605 }
606 }
607
608 Expression::Table(table_ref) => {
609 maybe_quote(&mut table_ref.name, reserved_words);
610 if let Some(ref mut schema) = table_ref.schema {
611 maybe_quote(schema, reserved_words);
612 }
613 if let Some(ref mut catalog) = table_ref.catalog {
614 maybe_quote(catalog, reserved_words);
615 }
616 if let Some(ref mut alias) = table_ref.alias {
617 maybe_quote(alias, reserved_words);
618 }
619 for ca in &mut table_ref.column_aliases {
620 maybe_quote(ca, reserved_words);
621 }
622 for p in &mut table_ref.partitions {
623 maybe_quote(p, reserved_words);
624 }
625 for h in &mut table_ref.hints {
627 quote_identifiers_recursive(h, reserved_words);
628 }
629 if let Some(ref mut ver) = table_ref.version {
630 quote_identifiers_recursive(&mut ver.this, reserved_words);
631 if let Some(ref mut e) = ver.expression {
632 quote_identifiers_recursive(e, reserved_words);
633 }
634 }
635 }
636
637 Expression::Star(star) => {
638 if let Some(ref mut table) = star.table {
639 maybe_quote(table, reserved_words);
640 }
641 if let Some(ref mut except_ids) = star.except {
642 for id in except_ids {
643 maybe_quote(id, reserved_words);
644 }
645 }
646 if let Some(ref mut replace_aliases) = star.replace {
647 for alias in replace_aliases {
648 maybe_quote(&mut alias.alias, reserved_words);
649 quote_identifiers_recursive(&mut alias.this, reserved_words);
650 }
651 }
652 if let Some(ref mut rename_pairs) = star.rename {
653 for (from, to) in rename_pairs {
654 maybe_quote(from, reserved_words);
655 maybe_quote(to, reserved_words);
656 }
657 }
658 }
659
660 Expression::Alias(alias) => {
662 maybe_quote(&mut alias.alias, reserved_words);
663 for ca in &mut alias.column_aliases {
664 maybe_quote(ca, reserved_words);
665 }
666 quote_identifiers_recursive(&mut alias.this, reserved_words);
667 }
668
669 Expression::Select(select) => {
671 for e in &mut select.expressions {
672 quote_identifiers_recursive(e, reserved_words);
673 }
674 if let Some(ref mut from) = select.from {
675 for e in &mut from.expressions {
676 quote_identifiers_recursive(e, reserved_words);
677 }
678 }
679 for join in &mut select.joins {
680 quote_join(join, reserved_words);
681 }
682 for lv in &mut select.lateral_views {
683 quote_lateral_view(lv, reserved_words);
684 }
685 if let Some(ref mut prewhere) = select.prewhere {
686 quote_identifiers_recursive(prewhere, reserved_words);
687 }
688 if let Some(ref mut wh) = select.where_clause {
689 quote_identifiers_recursive(&mut wh.this, reserved_words);
690 }
691 if let Some(ref mut gb) = select.group_by {
692 for e in &mut gb.expressions {
693 quote_identifiers_recursive(e, reserved_words);
694 }
695 }
696 if let Some(ref mut hv) = select.having {
697 quote_identifiers_recursive(&mut hv.this, reserved_words);
698 }
699 if let Some(ref mut q) = select.qualify {
700 quote_identifiers_recursive(&mut q.this, reserved_words);
701 }
702 if let Some(ref mut ob) = select.order_by {
703 for o in &mut ob.expressions {
704 quote_identifiers_recursive(&mut o.this, reserved_words);
705 }
706 }
707 if let Some(ref mut lim) = select.limit {
708 quote_identifiers_recursive(&mut lim.this, reserved_words);
709 }
710 if let Some(ref mut off) = select.offset {
711 quote_identifiers_recursive(&mut off.this, reserved_words);
712 }
713 if let Some(ref mut with) = select.with {
714 quote_with(with, reserved_words);
715 }
716 if let Some(ref mut windows) = select.windows {
717 for nw in windows {
718 maybe_quote(&mut nw.name, reserved_words);
719 quote_over(&mut nw.spec, reserved_words);
720 }
721 }
722 if let Some(ref mut distinct_on) = select.distinct_on {
723 for e in distinct_on {
724 quote_identifiers_recursive(e, reserved_words);
725 }
726 }
727 if let Some(ref mut limit_by) = select.limit_by {
728 for e in limit_by {
729 quote_identifiers_recursive(e, reserved_words);
730 }
731 }
732 if let Some(ref mut settings) = select.settings {
733 for e in settings {
734 quote_identifiers_recursive(e, reserved_words);
735 }
736 }
737 if let Some(ref mut format) = select.format {
738 quote_identifiers_recursive(format, reserved_words);
739 }
740 }
741
742 Expression::Union(u) => {
744 quote_identifiers_recursive(&mut u.left, reserved_words);
745 quote_identifiers_recursive(&mut u.right, reserved_words);
746 if let Some(ref mut with) = u.with {
747 quote_with(with, reserved_words);
748 }
749 if let Some(ref mut ob) = u.order_by {
750 for o in &mut ob.expressions {
751 quote_identifiers_recursive(&mut o.this, reserved_words);
752 }
753 }
754 if let Some(ref mut lim) = u.limit {
755 quote_identifiers_recursive(lim, reserved_words);
756 }
757 if let Some(ref mut off) = u.offset {
758 quote_identifiers_recursive(off, reserved_words);
759 }
760 }
761 Expression::Intersect(i) => {
762 quote_identifiers_recursive(&mut i.left, reserved_words);
763 quote_identifiers_recursive(&mut i.right, reserved_words);
764 if let Some(ref mut with) = i.with {
765 quote_with(with, reserved_words);
766 }
767 if let Some(ref mut ob) = i.order_by {
768 for o in &mut ob.expressions {
769 quote_identifiers_recursive(&mut o.this, reserved_words);
770 }
771 }
772 }
773 Expression::Except(e) => {
774 quote_identifiers_recursive(&mut e.left, reserved_words);
775 quote_identifiers_recursive(&mut e.right, reserved_words);
776 if let Some(ref mut with) = e.with {
777 quote_with(with, reserved_words);
778 }
779 if let Some(ref mut ob) = e.order_by {
780 for o in &mut ob.expressions {
781 quote_identifiers_recursive(&mut o.this, reserved_words);
782 }
783 }
784 }
785
786 Expression::Subquery(sq) => {
788 quote_identifiers_recursive(&mut sq.this, reserved_words);
789 if let Some(ref mut alias) = sq.alias {
790 maybe_quote(alias, reserved_words);
791 }
792 for ca in &mut sq.column_aliases {
793 maybe_quote(ca, reserved_words);
794 }
795 if let Some(ref mut ob) = sq.order_by {
796 for o in &mut ob.expressions {
797 quote_identifiers_recursive(&mut o.this, reserved_words);
798 }
799 }
800 }
801
802 Expression::Insert(ins) => {
804 quote_table_ref(&mut ins.table, reserved_words);
805 for c in &mut ins.columns {
806 maybe_quote(c, reserved_words);
807 }
808 for row in &mut ins.values {
809 for e in row {
810 quote_identifiers_recursive(e, reserved_words);
811 }
812 }
813 if let Some(ref mut q) = ins.query {
814 quote_identifiers_recursive(q, reserved_words);
815 }
816 for (id, val) in &mut ins.partition {
817 maybe_quote(id, reserved_words);
818 if let Some(ref mut v) = val {
819 quote_identifiers_recursive(v, reserved_words);
820 }
821 }
822 for e in &mut ins.returning {
823 quote_identifiers_recursive(e, reserved_words);
824 }
825 if let Some(ref mut on_conflict) = ins.on_conflict {
826 quote_identifiers_recursive(on_conflict, reserved_words);
827 }
828 if let Some(ref mut with) = ins.with {
829 quote_with(with, reserved_words);
830 }
831 if let Some(ref mut alias) = ins.alias {
832 maybe_quote(alias, reserved_words);
833 }
834 if let Some(ref mut src_alias) = ins.source_alias {
835 maybe_quote(src_alias, reserved_words);
836 }
837 }
838
839 Expression::Update(upd) => {
840 quote_table_ref(&mut upd.table, reserved_words);
841 for tr in &mut upd.extra_tables {
842 quote_table_ref(tr, reserved_words);
843 }
844 for join in &mut upd.table_joins {
845 quote_join(join, reserved_words);
846 }
847 for (id, val) in &mut upd.set {
848 maybe_quote(id, reserved_words);
849 quote_identifiers_recursive(val, reserved_words);
850 }
851 if let Some(ref mut from) = upd.from_clause {
852 for e in &mut from.expressions {
853 quote_identifiers_recursive(e, reserved_words);
854 }
855 }
856 for join in &mut upd.from_joins {
857 quote_join(join, reserved_words);
858 }
859 if let Some(ref mut wh) = upd.where_clause {
860 quote_identifiers_recursive(&mut wh.this, reserved_words);
861 }
862 for e in &mut upd.returning {
863 quote_identifiers_recursive(e, reserved_words);
864 }
865 if let Some(ref mut with) = upd.with {
866 quote_with(with, reserved_words);
867 }
868 }
869
870 Expression::Delete(del) => {
871 quote_table_ref(&mut del.table, reserved_words);
872 if let Some(ref mut alias) = del.alias {
873 maybe_quote(alias, reserved_words);
874 }
875 for tr in &mut del.using {
876 quote_table_ref(tr, reserved_words);
877 }
878 if let Some(ref mut wh) = del.where_clause {
879 quote_identifiers_recursive(&mut wh.this, reserved_words);
880 }
881 if let Some(ref mut with) = del.with {
882 quote_with(with, reserved_words);
883 }
884 }
885
886 Expression::And(bin) | Expression::Or(bin) | Expression::Eq(bin) |
888 Expression::Neq(bin) | Expression::Lt(bin) | Expression::Lte(bin) |
889 Expression::Gt(bin) | Expression::Gte(bin) | Expression::Add(bin) |
890 Expression::Sub(bin) | Expression::Mul(bin) | Expression::Div(bin) |
891 Expression::Mod(bin) | Expression::BitwiseAnd(bin) |
892 Expression::BitwiseOr(bin) | Expression::BitwiseXor(bin) |
893 Expression::Concat(bin) | Expression::Adjacent(bin) |
894 Expression::TsMatch(bin) | Expression::PropertyEQ(bin) |
895 Expression::ArrayContainsAll(bin) | Expression::ArrayContainedBy(bin) |
896 Expression::ArrayOverlaps(bin) | Expression::JSONBContainsAllTopKeys(bin) |
897 Expression::JSONBContainsAnyTopKeys(bin) | Expression::JSONBDeleteAtPath(bin) |
898 Expression::ExtendsLeft(bin) | Expression::ExtendsRight(bin) |
899 Expression::Is(bin) | Expression::NullSafeEq(bin) |
900 Expression::NullSafeNeq(bin) | Expression::Glob(bin) |
901 Expression::Match(bin) | Expression::MemberOf(bin) |
902 Expression::BitwiseLeftShift(bin) | Expression::BitwiseRightShift(bin) => {
903 quote_identifiers_recursive(&mut bin.left, reserved_words);
904 quote_identifiers_recursive(&mut bin.right, reserved_words);
905 }
906
907 Expression::Like(like) | Expression::ILike(like) => {
909 quote_identifiers_recursive(&mut like.left, reserved_words);
910 quote_identifiers_recursive(&mut like.right, reserved_words);
911 if let Some(ref mut esc) = like.escape {
912 quote_identifiers_recursive(esc, reserved_words);
913 }
914 }
915
916 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
918 quote_identifiers_recursive(&mut un.this, reserved_words);
919 }
920
921 Expression::In(in_expr) => {
923 quote_identifiers_recursive(&mut in_expr.this, reserved_words);
924 for e in &mut in_expr.expressions {
925 quote_identifiers_recursive(e, reserved_words);
926 }
927 if let Some(ref mut q) = in_expr.query {
928 quote_identifiers_recursive(q, reserved_words);
929 }
930 if let Some(ref mut un) = in_expr.unnest {
931 quote_identifiers_recursive(un, reserved_words);
932 }
933 }
934
935 Expression::Between(bw) => {
936 quote_identifiers_recursive(&mut bw.this, reserved_words);
937 quote_identifiers_recursive(&mut bw.low, reserved_words);
938 quote_identifiers_recursive(&mut bw.high, reserved_words);
939 }
940
941 Expression::IsNull(is_null) => {
942 quote_identifiers_recursive(&mut is_null.this, reserved_words);
943 }
944
945 Expression::IsTrue(is_tf) | Expression::IsFalse(is_tf) => {
946 quote_identifiers_recursive(&mut is_tf.this, reserved_words);
947 }
948
949 Expression::Exists(ex) => {
950 quote_identifiers_recursive(&mut ex.this, reserved_words);
951 }
952
953 Expression::Function(func) => {
955 for arg in &mut func.args {
956 quote_identifiers_recursive(arg, reserved_words);
957 }
958 }
959
960 Expression::AggregateFunction(agg) => {
961 for arg in &mut agg.args {
962 quote_identifiers_recursive(arg, reserved_words);
963 }
964 if let Some(ref mut filter) = agg.filter {
965 quote_identifiers_recursive(filter, reserved_words);
966 }
967 for o in &mut agg.order_by {
968 quote_identifiers_recursive(&mut o.this, reserved_words);
969 }
970 }
971
972 Expression::WindowFunction(wf) => {
973 quote_identifiers_recursive(&mut wf.this, reserved_words);
974 quote_over(&mut wf.over, reserved_words);
975 }
976
977 Expression::Case(case) => {
979 if let Some(ref mut operand) = case.operand {
980 quote_identifiers_recursive(operand, reserved_words);
981 }
982 for (when, then) in &mut case.whens {
983 quote_identifiers_recursive(when, reserved_words);
984 quote_identifiers_recursive(then, reserved_words);
985 }
986 if let Some(ref mut else_) = case.else_ {
987 quote_identifiers_recursive(else_, reserved_words);
988 }
989 }
990
991 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
993 quote_identifiers_recursive(&mut cast.this, reserved_words);
994 if let Some(ref mut fmt) = cast.format {
995 quote_identifiers_recursive(fmt, reserved_words);
996 }
997 }
998
999 Expression::Paren(paren) => {
1001 quote_identifiers_recursive(&mut paren.this, reserved_words);
1002 }
1003
1004 Expression::Annotated(ann) => {
1005 quote_identifiers_recursive(&mut ann.this, reserved_words);
1006 }
1007
1008 Expression::With(with) => {
1010 quote_with(with, reserved_words);
1011 }
1012
1013 Expression::Cte(cte) => {
1014 maybe_quote(&mut cte.alias, reserved_words);
1015 for c in &mut cte.columns {
1016 maybe_quote(c, reserved_words);
1017 }
1018 quote_identifiers_recursive(&mut cte.this, reserved_words);
1019 }
1020
1021 Expression::From(from) => {
1023 for e in &mut from.expressions {
1024 quote_identifiers_recursive(e, reserved_words);
1025 }
1026 }
1027
1028 Expression::Join(join) => {
1029 quote_join(join, reserved_words);
1030 }
1031
1032 Expression::JoinedTable(jt) => {
1033 quote_identifiers_recursive(&mut jt.left, reserved_words);
1034 for join in &mut jt.joins {
1035 quote_join(join, reserved_words);
1036 }
1037 if let Some(ref mut alias) = jt.alias {
1038 maybe_quote(alias, reserved_words);
1039 }
1040 }
1041
1042 Expression::Where(wh) => {
1043 quote_identifiers_recursive(&mut wh.this, reserved_words);
1044 }
1045
1046 Expression::GroupBy(gb) => {
1047 for e in &mut gb.expressions {
1048 quote_identifiers_recursive(e, reserved_words);
1049 }
1050 }
1051
1052 Expression::Having(hv) => {
1053 quote_identifiers_recursive(&mut hv.this, reserved_words);
1054 }
1055
1056 Expression::OrderBy(ob) => {
1057 for o in &mut ob.expressions {
1058 quote_identifiers_recursive(&mut o.this, reserved_words);
1059 }
1060 }
1061
1062 Expression::Ordered(ord) => {
1063 quote_identifiers_recursive(&mut ord.this, reserved_words);
1064 }
1065
1066 Expression::Limit(lim) => {
1067 quote_identifiers_recursive(&mut lim.this, reserved_words);
1068 }
1069
1070 Expression::Offset(off) => {
1071 quote_identifiers_recursive(&mut off.this, reserved_words);
1072 }
1073
1074 Expression::Qualify(q) => {
1075 quote_identifiers_recursive(&mut q.this, reserved_words);
1076 }
1077
1078 Expression::Window(ws) => {
1079 for e in &mut ws.partition_by {
1080 quote_identifiers_recursive(e, reserved_words);
1081 }
1082 for o in &mut ws.order_by {
1083 quote_identifiers_recursive(&mut o.this, reserved_words);
1084 }
1085 }
1086
1087 Expression::Over(over) => {
1088 quote_over(over, reserved_words);
1089 }
1090
1091 Expression::WithinGroup(wg) => {
1092 quote_identifiers_recursive(&mut wg.this, reserved_words);
1093 for o in &mut wg.order_by {
1094 quote_identifiers_recursive(&mut o.this, reserved_words);
1095 }
1096 }
1097
1098 Expression::Pivot(piv) => {
1100 quote_identifiers_recursive(&mut piv.this, reserved_words);
1101 for e in &mut piv.expressions {
1102 quote_identifiers_recursive(e, reserved_words);
1103 }
1104 for f in &mut piv.fields {
1105 quote_identifiers_recursive(f, reserved_words);
1106 }
1107 if let Some(ref mut alias) = piv.alias {
1108 maybe_quote(alias, reserved_words);
1109 }
1110 }
1111
1112 Expression::Unpivot(unpiv) => {
1113 quote_identifiers_recursive(&mut unpiv.this, reserved_words);
1114 maybe_quote(&mut unpiv.value_column, reserved_words);
1115 maybe_quote(&mut unpiv.name_column, reserved_words);
1116 for e in &mut unpiv.columns {
1117 quote_identifiers_recursive(e, reserved_words);
1118 }
1119 if let Some(ref mut alias) = unpiv.alias {
1120 maybe_quote(alias, reserved_words);
1121 }
1122 }
1123
1124 Expression::Values(vals) => {
1126 for tuple in &mut vals.expressions {
1127 for e in &mut tuple.expressions {
1128 quote_identifiers_recursive(e, reserved_words);
1129 }
1130 }
1131 if let Some(ref mut alias) = vals.alias {
1132 maybe_quote(alias, reserved_words);
1133 }
1134 for ca in &mut vals.column_aliases {
1135 maybe_quote(ca, reserved_words);
1136 }
1137 }
1138
1139 Expression::Array(arr) => {
1141 for e in &mut arr.expressions {
1142 quote_identifiers_recursive(e, reserved_words);
1143 }
1144 }
1145
1146 Expression::Struct(st) => {
1147 for (_name, e) in &mut st.fields {
1148 quote_identifiers_recursive(e, reserved_words);
1149 }
1150 }
1151
1152 Expression::Tuple(tup) => {
1153 for e in &mut tup.expressions {
1154 quote_identifiers_recursive(e, reserved_words);
1155 }
1156 }
1157
1158 Expression::Subscript(sub) => {
1160 quote_identifiers_recursive(&mut sub.this, reserved_words);
1161 quote_identifiers_recursive(&mut sub.index, reserved_words);
1162 }
1163
1164 Expression::Dot(dot) => {
1165 quote_identifiers_recursive(&mut dot.this, reserved_words);
1166 maybe_quote(&mut dot.field, reserved_words);
1167 }
1168
1169 Expression::ScopeResolution(sr) => {
1170 if let Some(ref mut this) = sr.this {
1171 quote_identifiers_recursive(this, reserved_words);
1172 }
1173 quote_identifiers_recursive(&mut sr.expression, reserved_words);
1174 }
1175
1176 Expression::Lateral(lat) => {
1178 quote_identifiers_recursive(&mut lat.this, reserved_words);
1179 }
1181
1182 Expression::DPipe(dpipe) => {
1184 quote_identifiers_recursive(&mut dpipe.this, reserved_words);
1185 quote_identifiers_recursive(&mut dpipe.expression, reserved_words);
1186 }
1187
1188 Expression::Merge(merge) => {
1190 quote_identifiers_recursive(&mut merge.this, reserved_words);
1191 quote_identifiers_recursive(&mut merge.using, reserved_words);
1192 if let Some(ref mut on) = merge.on {
1193 quote_identifiers_recursive(on, reserved_words);
1194 }
1195 if let Some(ref mut whens) = merge.whens {
1196 quote_identifiers_recursive(whens, reserved_words);
1197 }
1198 if let Some(ref mut with) = merge.with_ {
1199 quote_identifiers_recursive(with, reserved_words);
1200 }
1201 if let Some(ref mut ret) = merge.returning {
1202 quote_identifiers_recursive(ret, reserved_words);
1203 }
1204 }
1205
1206 Expression::LateralView(lv) => {
1208 quote_lateral_view(lv, reserved_words);
1209 }
1210
1211 Expression::Anonymous(anon) => {
1213 quote_identifiers_recursive(&mut anon.this, reserved_words);
1214 for e in &mut anon.expressions {
1215 quote_identifiers_recursive(e, reserved_words);
1216 }
1217 }
1218
1219 Expression::Filter(filter) => {
1221 quote_identifiers_recursive(&mut filter.this, reserved_words);
1222 quote_identifiers_recursive(&mut filter.expression, reserved_words);
1223 }
1224
1225 Expression::Returning(ret) => {
1227 for e in &mut ret.expressions {
1228 quote_identifiers_recursive(e, reserved_words);
1229 }
1230 }
1231
1232 Expression::BracedWildcard(inner) => {
1234 quote_identifiers_recursive(inner, reserved_words);
1235 }
1236
1237 Expression::ReturnStmt(inner) => {
1239 quote_identifiers_recursive(inner, reserved_words);
1240 }
1241
1242 Expression::Literal(_)
1244 | Expression::Boolean(_)
1245 | Expression::Null(_)
1246 | Expression::DataType(_)
1247 | Expression::Raw(_)
1248 | Expression::Placeholder(_)
1249 | Expression::CurrentDate(_)
1250 | Expression::CurrentTime(_)
1251 | Expression::CurrentTimestamp(_)
1252 | Expression::CurrentTimestampLTZ(_)
1253 | Expression::SessionUser(_)
1254 | Expression::RowNumber(_)
1255 | Expression::Rank(_)
1256 | Expression::DenseRank(_)
1257 | Expression::PercentRank(_)
1258 | Expression::CumeDist(_)
1259 | Expression::Random(_)
1260 | Expression::Pi(_)
1261 | Expression::JSONPathRoot(_) => {
1262 }
1264
1265 _ => {}
1269 }
1270}
1271
1272fn quote_join(join: &mut Join, reserved_words: &HashSet<&str>) {
1274 quote_identifiers_recursive(&mut join.this, reserved_words);
1275 if let Some(ref mut on) = join.on {
1276 quote_identifiers_recursive(on, reserved_words);
1277 }
1278 for id in &mut join.using {
1279 maybe_quote(id, reserved_words);
1280 }
1281 if let Some(ref mut mc) = join.match_condition {
1282 quote_identifiers_recursive(mc, reserved_words);
1283 }
1284 for piv in &mut join.pivots {
1285 quote_identifiers_recursive(piv, reserved_words);
1286 }
1287}
1288
1289fn quote_with(with: &mut With, reserved_words: &HashSet<&str>) {
1291 for cte in &mut with.ctes {
1292 maybe_quote(&mut cte.alias, reserved_words);
1293 for c in &mut cte.columns {
1294 maybe_quote(c, reserved_words);
1295 }
1296 for k in &mut cte.key_expressions {
1297 maybe_quote(k, reserved_words);
1298 }
1299 quote_identifiers_recursive(&mut cte.this, reserved_words);
1300 }
1301}
1302
1303fn quote_over(over: &mut Over, reserved_words: &HashSet<&str>) {
1305 if let Some(ref mut wn) = over.window_name {
1306 maybe_quote(wn, reserved_words);
1307 }
1308 for e in &mut over.partition_by {
1309 quote_identifiers_recursive(e, reserved_words);
1310 }
1311 for o in &mut over.order_by {
1312 quote_identifiers_recursive(&mut o.this, reserved_words);
1313 }
1314 if let Some(ref mut alias) = over.alias {
1315 maybe_quote(alias, reserved_words);
1316 }
1317}
1318
1319fn quote_table_ref(table_ref: &mut TableRef, reserved_words: &HashSet<&str>) {
1321 maybe_quote(&mut table_ref.name, reserved_words);
1322 if let Some(ref mut schema) = table_ref.schema {
1323 maybe_quote(schema, reserved_words);
1324 }
1325 if let Some(ref mut catalog) = table_ref.catalog {
1326 maybe_quote(catalog, reserved_words);
1327 }
1328 if let Some(ref mut alias) = table_ref.alias {
1329 maybe_quote(alias, reserved_words);
1330 }
1331 for ca in &mut table_ref.column_aliases {
1332 maybe_quote(ca, reserved_words);
1333 }
1334 for p in &mut table_ref.partitions {
1335 maybe_quote(p, reserved_words);
1336 }
1337 for h in &mut table_ref.hints {
1338 quote_identifiers_recursive(h, reserved_words);
1339 }
1340}
1341
1342fn quote_lateral_view(lv: &mut LateralView, reserved_words: &HashSet<&str>) {
1344 quote_identifiers_recursive(&mut lv.this, reserved_words);
1345 if let Some(ref mut ta) = lv.table_alias {
1346 maybe_quote(ta, reserved_words);
1347 }
1348 for ca in &mut lv.column_aliases {
1349 maybe_quote(ca, reserved_words);
1350 }
1351}
1352
1353pub fn quote_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
1364 let reserved_words = get_reserved_words(dialect);
1365 let mut result = expression;
1366 quote_identifiers_recursive(&mut result, &reserved_words);
1367 result
1368}
1369
/// Placeholder for pushing CTE column-alias lists down onto CTE projections.
///
/// NOTE(review): this is currently a no-op. The body is empty and `_scope` is
/// never read. The name suggests it should mirror sqlglot's
/// `pushdown_cte_alias_columns`, but it only receives `&Scope`, so it could not
/// rewrite projections in place without a signature change. TODO: confirm that
/// no caller relies on it doing work.
pub fn pushdown_cte_alias_columns(_scope: &Scope) {
}
1378
1379fn get_scope_columns(scope: &Scope) -> Vec<ColumnRef> {
1385 let mut columns = Vec::new();
1386 collect_columns(&scope.expression, &mut columns);
1387 columns
1388}
1389
/// Owned record of one column reference found while walking an expression
/// tree (see `collect_columns`).
#[derive(Debug, Clone)]
struct ColumnRef {
    /// Name of the qualifying table identifier (`t` in `t.a`), or `None` for
    /// an unqualified column.
    table: Option<String>,
    /// The column's own identifier name (`a` in `t.a`).
    name: String,
}
1396
1397fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
1399 match expr {
1400 Expression::Column(col) => {
1401 columns.push(ColumnRef {
1402 table: col.table.as_ref().map(|t| t.name.clone()),
1403 name: col.name.name.clone(),
1404 });
1405 }
1406 Expression::Select(select) => {
1407 for e in &select.expressions {
1408 collect_columns(e, columns);
1409 }
1410 if let Some(from) = &select.from {
1411 for e in &from.expressions {
1412 collect_columns(e, columns);
1413 }
1414 }
1415 if let Some(where_clause) = &select.where_clause {
1416 collect_columns(&where_clause.this, columns);
1417 }
1418 if let Some(group_by) = &select.group_by {
1419 for e in &group_by.expressions {
1420 collect_columns(e, columns);
1421 }
1422 }
1423 if let Some(having) = &select.having {
1424 collect_columns(&having.this, columns);
1425 }
1426 if let Some(order_by) = &select.order_by {
1427 for o in &order_by.expressions {
1428 collect_columns(&o.this, columns);
1429 }
1430 }
1431 for join in &select.joins {
1432 collect_columns(&join.this, columns);
1433 if let Some(on) = &join.on {
1434 collect_columns(on, columns);
1435 }
1436 }
1437 }
1438 Expression::Alias(alias) => {
1439 collect_columns(&alias.this, columns);
1440 }
1441 Expression::Function(func) => {
1442 for arg in &func.args {
1443 collect_columns(arg, columns);
1444 }
1445 }
1446 Expression::AggregateFunction(agg) => {
1447 for arg in &agg.args {
1448 collect_columns(arg, columns);
1449 }
1450 }
1451 Expression::And(bin)
1452 | Expression::Or(bin)
1453 | Expression::Eq(bin)
1454 | Expression::Neq(bin)
1455 | Expression::Lt(bin)
1456 | Expression::Lte(bin)
1457 | Expression::Gt(bin)
1458 | Expression::Gte(bin)
1459 | Expression::Add(bin)
1460 | Expression::Sub(bin)
1461 | Expression::Mul(bin)
1462 | Expression::Div(bin) => {
1463 collect_columns(&bin.left, columns);
1464 collect_columns(&bin.right, columns);
1465 }
1466 Expression::Not(unary) | Expression::Neg(unary) => {
1467 collect_columns(&unary.this, columns);
1468 }
1469 Expression::Paren(paren) => {
1470 collect_columns(&paren.this, columns);
1471 }
1472 Expression::Case(case) => {
1473 if let Some(operand) = &case.operand {
1474 collect_columns(operand, columns);
1475 }
1476 for (when, then) in &case.whens {
1477 collect_columns(when, columns);
1478 collect_columns(then, columns);
1479 }
1480 if let Some(else_) = &case.else_ {
1481 collect_columns(else_, columns);
1482 }
1483 }
1484 Expression::Cast(cast) => {
1485 collect_columns(&cast.this, columns);
1486 }
1487 Expression::In(in_expr) => {
1488 collect_columns(&in_expr.this, columns);
1489 for e in &in_expr.expressions {
1490 collect_columns(e, columns);
1491 }
1492 if let Some(query) = &in_expr.query {
1493 collect_columns(query, columns);
1494 }
1495 }
1496 Expression::Between(between) => {
1497 collect_columns(&between.this, columns);
1498 collect_columns(&between.low, columns);
1499 collect_columns(&between.high, columns);
1500 }
1501 Expression::Subquery(subquery) => {
1502 collect_columns(&subquery.this, columns);
1503 }
1504 _ => {}
1505 }
1506}
1507
1508fn get_unqualified_columns(scope: &Scope) -> Vec<ColumnRef> {
1510 get_scope_columns(scope)
1511 .into_iter()
1512 .filter(|c| c.table.is_none())
1513 .collect()
1514}
1515
1516fn get_external_columns(scope: &Scope) -> Vec<ColumnRef> {
1518 let source_names: HashSet<_> = scope.sources.keys().cloned().collect();
1519
1520 get_scope_columns(scope)
1521 .into_iter()
1522 .filter(|c| {
1523 if let Some(table) = &c.table {
1524 !source_names.contains(table)
1525 } else {
1526 false
1527 }
1528 })
1529 .collect()
1530}
1531
1532fn is_correlated_subquery(scope: &Scope) -> bool {
1534 scope.can_be_correlated && !get_external_columns(scope).is_empty()
1535}
1536
1537fn is_star_column(col: &Column) -> bool {
1539 col.name.name == "*"
1540}
1541
1542fn create_qualified_column(name: &str, table: Option<&str>) -> Expression {
1544 Expression::Column(Column {
1545 name: Identifier::new(name),
1546 table: table.map(Identifier::new),
1547 join_mark: false,
1548 trailing_comments: vec![],
1549 })
1550}
1551
1552fn create_alias(expr: Expression, alias_name: &str) -> Expression {
1554 Expression::Alias(Box::new(Alias {
1555 this: expr,
1556 alias: Identifier::new(alias_name),
1557 column_aliases: vec![],
1558 pre_alias_comments: vec![],
1559 trailing_comments: vec![],
1560 }))
1561}
1562
1563fn get_output_name(expr: &Expression) -> Option<String> {
1565 match expr {
1566 Expression::Column(col) => Some(col.name.name.clone()),
1567 Expression::Alias(alias) => Some(alias.alias.name.clone()),
1568 Expression::Identifier(id) => Some(id.name.clone()),
1569 _ => None,
1570 }
1571}
1572
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::scope::build_scope;

    /// Renders an expression back to SQL with the default generator.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    #[test]
    fn test_qualify_columns_options() {
        let options = QualifyColumnsOptions::new()
            .with_expand_alias_refs(true)
            .with_expand_stars(false)
            .with_dialect(DialectType::PostgreSQL)
            .with_allow_partial(true);

        assert!(options.expand_alias_refs);
        assert!(!options.expand_stars);
        assert_eq!(options.dialect, Some(DialectType::PostgreSQL));
        assert!(options.allow_partial_qualification);
    }

    #[test]
    fn test_get_scope_columns() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let scope = build_scope(&expr);
        let columns = get_scope_columns(&scope);

        assert!(columns.iter().any(|c| c.name == "a"));
        assert!(columns.iter().any(|c| c.name == "b"));
        assert!(columns.iter().any(|c| c.name == "c"));
    }

    #[test]
    fn test_get_unqualified_columns() {
        let expr = parse("SELECT t.a, b FROM t");
        let scope = build_scope(&expr);
        let unqualified = get_unqualified_columns(&scope);

        assert!(unqualified.iter().any(|c| c.name == "b"));
        assert!(!unqualified.iter().any(|c| c.name == "a"));
    }

    #[test]
    fn test_is_star_column() {
        let col = Column {
            name: Identifier::new("*"),
            table: Some(Identifier::new("t")),
            join_mark: false,
            trailing_comments: vec![],
        };
        assert!(is_star_column(&col));

        let col2 = Column {
            name: Identifier::new("id"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        };
        assert!(!is_star_column(&col2));
    }

    #[test]
    fn test_create_qualified_column() {
        let expr = create_qualified_column("id", Some("users"));
        let sql = gen(&expr);
        assert!(sql.contains("users"));
        assert!(sql.contains("id"));
    }

    #[test]
    fn test_create_alias() {
        let col = Expression::Column(Column {
            name: Identifier::new("value"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        });
        let aliased = create_alias(col, "total");
        let sql = gen(&aliased);
        // Both the aliased expression and the alias name must survive
        // generation. The previous `contains("AS") || contains("total")`
        // passed even if the alias name was dropped.
        assert!(sql.contains("value"), "aliased expression missing from {sql}");
        assert!(sql.contains("total"), "alias name missing from {sql}");
    }

    #[test]
    fn test_validate_qualify_columns_success() {
        let expr = parse("SELECT t.a, t.b FROM t");
        // NOTE(review): this only checks that validation runs without
        // panicking; the returned result is deliberately not asserted.
        let result = validate_qualify_columns(&expr);
        let _ = result;
    }

    #[test]
    fn test_collect_columns_nested() {
        let expr = parse("SELECT a + b, c FROM t WHERE d > 0 GROUP BY e HAVING f = 1");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
        assert!(names.contains(&"d"));
        assert!(names.contains(&"e"));
        assert!(names.contains(&"f"));
    }

    #[test]
    fn test_collect_columns_in_case() {
        let expr = parse("SELECT CASE WHEN a = 1 THEN b ELSE c END FROM t");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
    }

    #[test]
    fn test_collect_columns_in_subquery() {
        let expr = parse("SELECT a FROM t WHERE b IN (SELECT c FROM s)");
        let mut columns = Vec::new();
        collect_columns(&expr, &mut columns);

        let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
    }

    #[test]
    fn test_qualify_outputs_basic() {
        let expr = parse("SELECT a, b + c FROM t");
        let scope = build_scope(&expr);
        let result = qualify_outputs(&scope);
        assert!(result.is_ok());
    }

    #[test]
    fn test_needs_quoting_reserved_word() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("select", &reserved));
        assert!(needs_quoting("SELECT", &reserved));
        assert!(needs_quoting("from", &reserved));
        assert!(needs_quoting("WHERE", &reserved));
        assert!(needs_quoting("join", &reserved));
        assert!(needs_quoting("table", &reserved));
    }

    #[test]
    fn test_needs_quoting_normal_identifiers() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("foo", &reserved));
        assert!(!needs_quoting("my_column", &reserved));
        assert!(!needs_quoting("col1", &reserved));
        assert!(!needs_quoting("A", &reserved));
        assert!(!needs_quoting("_hidden", &reserved));
    }

    #[test]
    fn test_needs_quoting_special_characters() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("my column", &reserved));
        assert!(needs_quoting("my-column", &reserved));
        assert!(needs_quoting("my.column", &reserved));
        assert!(needs_quoting("col@name", &reserved));
        assert!(needs_quoting("col#name", &reserved));
    }

    #[test]
    fn test_needs_quoting_starts_with_digit() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("1col", &reserved));
        assert!(needs_quoting("123", &reserved));
        assert!(needs_quoting("0_start", &reserved));
    }

    #[test]
    fn test_needs_quoting_empty() {
        let reserved = get_reserved_words(None);
        assert!(!needs_quoting("", &reserved));
    }

    #[test]
    fn test_maybe_quote_sets_quoted_flag() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("select");
        assert!(!id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
    }

    #[test]
    fn test_maybe_quote_skips_already_quoted() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::quoted("myname");
        assert!(id.quoted);
        maybe_quote(&mut id, &reserved);
        assert!(id.quoted);
        assert_eq!(id.name, "myname");
    }

    #[test]
    fn test_maybe_quote_skips_star() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("*");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    #[test]
    fn test_maybe_quote_skips_normal() {
        let reserved = get_reserved_words(None);
        let mut id = Identifier::new("normal_col");
        maybe_quote(&mut id, &reserved);
        assert!(!id.quoted);
    }

    #[test]
    fn test_quote_identifiers_column_with_reserved_name() {
        let expr = Expression::Column(Column {
            name: Identifier::new("select"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column named 'select' should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_column_with_special_chars() {
        let expr = Expression::Column(Column {
            name: Identifier::new("my column"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column with space should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_preserves_normal_column() {
        let expr = Expression::Column(Column {
            name: Identifier::new("normal_col"),
            table: Some(Identifier::new("my_table")),
            join_mark: false,
            trailing_comments: vec![],
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(!col.name.quoted, "Normal column should not be quoted");
            assert!(!col.table.as_ref().unwrap().quoted, "Normal table should not be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_table_ref_reserved() {
        let expr = Expression::Table(TableRef::new("select"));
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(tr.name.quoted, "Table named 'select' should be quoted");
        } else {
            panic!("Expected Table expression");
        }
    }

    #[test]
    fn test_quote_identifiers_table_ref_schema_and_alias() {
        let mut tr = TableRef::new("my_table");
        tr.schema = Some(Identifier::new("from"));
        tr.alias = Some(Identifier::new("t"));
        let expr = Expression::Table(tr);
        let result = quote_identifiers(expr, None);
        if let Expression::Table(tr) = &result {
            assert!(!tr.name.quoted, "Normal table name should not be quoted");
            assert!(tr.schema.as_ref().unwrap().quoted, "Schema named 'from' should be quoted");
            assert!(!tr.alias.as_ref().unwrap().quoted, "Normal alias should not be quoted");
        } else {
            panic!("Expected Table expression");
        }
    }

    #[test]
    fn test_quote_identifiers_identifier_node() {
        let expr = Expression::Identifier(Identifier::new("order"));
        let result = quote_identifiers(expr, None);
        if let Expression::Identifier(id) = &result {
            assert!(id.quoted, "Identifier named 'order' should be quoted");
        } else {
            panic!("Expected Identifier expression");
        }
    }

    #[test]
    fn test_quote_identifiers_alias() {
        let inner = Expression::Column(Column {
            name: Identifier::new("val"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        });
        let expr = Expression::Alias(Box::new(Alias {
            this: inner,
            alias: Identifier::new("select"),
            column_aliases: vec![Identifier::new("from")],
            pre_alias_comments: vec![],
            trailing_comments: vec![],
        }));
        let result = quote_identifiers(expr, None);
        if let Expression::Alias(alias) = &result {
            assert!(alias.alias.quoted, "Alias named 'select' should be quoted");
            assert!(alias.column_aliases[0].quoted, "Column alias named 'from' should be quoted");
            if let Expression::Column(col) = &alias.this {
                assert!(!col.name.quoted);
            }
        } else {
            panic!("Expected Alias expression");
        }
    }

    #[test]
    fn test_quote_identifiers_select_recursive() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let result = quote_identifiers(expr, None);
        let sql = gen(&result);
        // NOTE(review): these substring checks are weak; they mainly confirm
        // the recursive walk does not drop any of the identifiers.
        assert!(sql.contains("a"));
        assert!(sql.contains("b"));
        assert!(sql.contains("t"));
    }

    #[test]
    fn test_quote_identifiers_digit_start() {
        let expr = Expression::Column(Column {
            name: Identifier::new("1col"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Column starting with digit should be quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_with_mysql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::MySQL));
        assert!(needs_quoting("KILL", &reserved));
        assert!(needs_quoting("FORCE", &reserved));
    }

    #[test]
    fn test_quote_identifiers_with_postgresql_dialect() {
        let reserved = get_reserved_words(Some(DialectType::PostgreSQL));
        assert!(needs_quoting("ILIKE", &reserved));
        assert!(needs_quoting("VERBOSE", &reserved));
    }

    #[test]
    fn test_quote_identifiers_with_bigquery_dialect() {
        let reserved = get_reserved_words(Some(DialectType::BigQuery));
        assert!(needs_quoting("STRUCT", &reserved));
        assert!(needs_quoting("PROTO", &reserved));
    }

    #[test]
    fn test_quote_identifiers_case_insensitive_reserved() {
        let reserved = get_reserved_words(None);
        assert!(needs_quoting("Select", &reserved));
        assert!(needs_quoting("sElEcT", &reserved));
        assert!(needs_quoting("FROM", &reserved));
        assert!(needs_quoting("from", &reserved));
    }

    #[test]
    fn test_quote_identifiers_join_using() {
        let mut join = crate::expressions::Join {
            this: Expression::Table(TableRef::new("other")),
            on: None,
            using: vec![Identifier::new("key"), Identifier::new("value")],
            kind: crate::expressions::JoinKind::Inner,
            use_inner_keyword: false,
            use_outer_keyword: false,
            deferred_condition: false,
            join_hint: None,
            match_condition: None,
            pivots: vec![],
        };
        let reserved = get_reserved_words(None);
        quote_join(&mut join, &reserved);
        assert!(join.using[0].quoted, "USING identifier 'key' should be quoted");
        assert!(!join.using[1].quoted, "USING identifier 'value' should not be quoted");
    }

    #[test]
    fn test_quote_identifiers_cte() {
        let mut cte = crate::expressions::Cte {
            alias: Identifier::new("select"),
            this: Expression::Column(Column {
                name: Identifier::new("x"),
                table: None,
                join_mark: false,
                trailing_comments: vec![],
            }),
            columns: vec![Identifier::new("from"), Identifier::new("normal")],
            materialized: None,
            key_expressions: vec![],
            alias_first: false,
        };
        let reserved = get_reserved_words(None);
        // NOTE(review): this repeats `quote_with`'s per-CTE identifier
        // handling by hand instead of calling `quote_with` on a `With` node.
        maybe_quote(&mut cte.alias, &reserved);
        for c in &mut cte.columns {
            maybe_quote(c, &reserved);
        }
        assert!(cte.alias.quoted, "CTE alias 'select' should be quoted");
        assert!(cte.columns[0].quoted, "CTE column 'from' should be quoted");
        assert!(!cte.columns[1].quoted, "CTE column 'normal' should not be quoted");
    }

    #[test]
    fn test_quote_identifiers_binary_ops_recurse() {
        let expr = Expression::Add(Box::new(crate::expressions::BinaryOp::new(
            Expression::Column(Column {
                name: Identifier::new("select"),
                table: None,
                join_mark: false,
                trailing_comments: vec![],
            }),
            Expression::Column(Column {
                name: Identifier::new("normal"),
                table: None,
                join_mark: false,
                trailing_comments: vec![],
            }),
        )));
        let result = quote_identifiers(expr, None);
        if let Expression::Add(bin) = &result {
            if let Expression::Column(left) = &bin.left {
                assert!(left.name.quoted, "'select' column should be quoted in binary op");
            }
            if let Expression::Column(right) = &bin.right {
                assert!(!right.name.quoted, "'normal' column should not be quoted");
            }
        } else {
            panic!("Expected Add expression");
        }
    }

    #[test]
    fn test_quote_identifiers_already_quoted_preserved() {
        let expr = Expression::Column(Column {
            name: Identifier::quoted("normal_name"),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        });
        let result = quote_identifiers(expr, None);
        if let Expression::Column(col) = &result {
            assert!(col.name.quoted, "Already-quoted identifier should remain quoted");
        } else {
            panic!("Expected Column expression");
        }
    }

    #[test]
    fn test_quote_identifiers_full_parsed_query() {
        let mut select = crate::expressions::Select::new();
        select.expressions.push(Expression::Column(Column {
            name: Identifier::new("order"),
            table: Some(Identifier::new("t")),
            join_mark: false,
            trailing_comments: vec![],
        }));
        select.from = Some(crate::expressions::From {
            expressions: vec![Expression::Table(TableRef::new("t"))],
        });
        let expr = Expression::Select(Box::new(select));

        let result = quote_identifiers(expr, None);
        if let Expression::Select(sel) = &result {
            if let Expression::Column(col) = &sel.expressions[0] {
                assert!(col.name.quoted, "Column named 'order' should be quoted");
                assert!(!col.table.as_ref().unwrap().quoted, "Table 't' should not be quoted");
            } else {
                panic!("Expected Column in SELECT list");
            }
        } else {
            panic!("Expected Select expression");
        }
    }

    #[test]
    fn test_get_reserved_words_all_dialects() {
        let dialects = [
            None,
            Some(DialectType::Generic),
            Some(DialectType::MySQL),
            Some(DialectType::PostgreSQL),
            Some(DialectType::BigQuery),
            Some(DialectType::Snowflake),
            Some(DialectType::TSQL),
            Some(DialectType::ClickHouse),
            Some(DialectType::DuckDB),
            Some(DialectType::Hive),
            Some(DialectType::Spark),
            Some(DialectType::Trino),
            Some(DialectType::Oracle),
            Some(DialectType::Redshift),
        ];
        for dialect in &dialects {
            let words = get_reserved_words(*dialect);
            assert!(words.contains("SELECT"), "All dialects should have SELECT as reserved");
            assert!(words.contains("FROM"), "All dialects should have FROM as reserved");
        }
    }
}