1use crate::dialects::DialectType;
9use crate::expressions::{Expression, Identifier, Null, Select, TableRef};
10use crate::helper::name_sequence;
11use crate::optimizer::normalize_identifiers::{
12 get_normalization_strategy, normalize_identifier, NormalizationStrategy,
13};
14use std::collections::{HashMap, HashSet};
15
16#[derive(Debug, Clone, Default)]
18pub struct QualifyTablesOptions {
19 pub db: Option<String>,
21 pub catalog: Option<String>,
23 pub dialect: Option<DialectType>,
25 pub canonicalize_table_aliases: bool,
27}
28
29impl QualifyTablesOptions {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn with_db(mut self, db: impl Into<String>) -> Self {
35 self.db = Some(db.into());
36 self
37 }
38
39 pub fn with_catalog(mut self, catalog: impl Into<String>) -> Self {
40 self.catalog = Some(catalog.into());
41 self
42 }
43
44 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
45 self.dialect = Some(dialect);
46 self
47 }
48
49 pub fn with_canonical_aliases(mut self) -> Self {
50 self.canonicalize_table_aliases = true;
51 self
52 }
53}
54
55pub fn qualify_tables(expression: Expression, options: &QualifyTablesOptions) -> Expression {
77 let strategy = get_normalization_strategy(options.dialect);
78 let mut next_alias = name_sequence("_");
79
80 match expression {
81 Expression::Select(select) => {
82 let qualified = qualify_select(*select, options, strategy, &mut next_alias);
83 Expression::Select(Box::new(qualified))
84 }
85 Expression::Union(mut union) => {
86 let left = std::mem::replace(&mut union.left, Expression::Null(Null));
87 union.left = qualify_tables(left, options);
88 let right = std::mem::replace(&mut union.right, Expression::Null(Null));
89 union.right = qualify_tables(right, options);
90 Expression::Union(union)
91 }
92 Expression::Intersect(mut intersect) => {
93 let left = std::mem::replace(&mut intersect.left, Expression::Null(Null));
94 intersect.left = qualify_tables(left, options);
95 let right = std::mem::replace(&mut intersect.right, Expression::Null(Null));
96 intersect.right = qualify_tables(right, options);
97 Expression::Intersect(intersect)
98 }
99 Expression::Except(mut except) => {
100 let left = std::mem::replace(&mut except.left, Expression::Null(Null));
101 except.left = qualify_tables(left, options);
102 let right = std::mem::replace(&mut except.right, Expression::Null(Null));
103 except.right = qualify_tables(right, options);
104 Expression::Except(except)
105 }
106 _ => expression,
107 }
108}
109
110fn qualify_select(
112 mut select: Select,
113 options: &QualifyTablesOptions,
114 strategy: NormalizationStrategy,
115 next_alias: &mut impl FnMut() -> String,
116) -> Select {
117 let cte_names: HashSet<String> = select
119 .with
120 .as_ref()
121 .map(|w| w.ctes.iter().map(|c| c.alias.name.clone()).collect())
122 .unwrap_or_default();
123
124 let mut canonical_aliases: HashMap<String, String> = HashMap::new();
126
127 if let Some(ref mut with) = select.with {
129 for cte in &mut with.ctes {
130 cte.this = qualify_tables(cte.this.clone(), options);
131 }
132 }
133
134 if let Some(ref mut from) = select.from {
136 for expr in &mut from.expressions {
137 *expr = qualify_table_expression(
138 expr.clone(),
139 options,
140 strategy,
141 &cte_names,
142 &mut canonical_aliases,
143 next_alias,
144 );
145 }
146 }
147
148 for join in &mut select.joins {
150 join.this = qualify_table_expression(
151 join.this.clone(),
152 options,
153 strategy,
154 &cte_names,
155 &mut canonical_aliases,
156 next_alias,
157 );
158 }
159
160 if options.canonicalize_table_aliases && !canonical_aliases.is_empty() {
162 select = update_column_references(select, &canonical_aliases);
163 }
164
165 select
166}
167
168fn qualify_table_expression(
170 expression: Expression,
171 options: &QualifyTablesOptions,
172 strategy: NormalizationStrategy,
173 cte_names: &HashSet<String>,
174 canonical_aliases: &mut HashMap<String, String>,
175 next_alias: &mut impl FnMut() -> String,
176) -> Expression {
177 match expression {
178 Expression::Table(mut table) => {
179 let table_name = table.name.name.clone();
180
181 if cte_names.contains(&table_name) {
183 ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
185 return Expression::Table(table);
186 }
187
188 if let Some(ref db) = options.db {
190 if table.schema.is_none() {
191 table.schema =
192 Some(normalize_identifier(Identifier::new(db.clone()), strategy));
193 }
194 }
195
196 if let Some(ref catalog) = options.catalog {
198 if table.schema.is_some() && table.catalog.is_none() {
199 table.catalog = Some(normalize_identifier(
200 Identifier::new(catalog.clone()),
201 strategy,
202 ));
203 }
204 }
205
206 ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
208
209 Expression::Table(table)
210 }
211 Expression::Subquery(mut subquery) => {
212 subquery.this = qualify_tables(subquery.this, options);
214
215 if subquery.alias.is_none() || options.canonicalize_table_aliases {
217 let alias_name = if options.canonicalize_table_aliases {
218 let new_name = next_alias();
219 if let Some(ref old_alias) = subquery.alias {
220 canonical_aliases.insert(old_alias.name.clone(), new_name.clone());
221 }
222 new_name
223 } else {
224 subquery
225 .alias
226 .as_ref()
227 .map(|a| a.name.clone())
228 .unwrap_or_else(|| next_alias())
229 };
230
231 subquery.alias = Some(normalize_identifier(Identifier::new(alias_name), strategy));
232 }
233
234 Expression::Subquery(subquery)
235 }
236 Expression::Paren(mut paren) => {
237 paren.this = qualify_table_expression(
238 paren.this,
239 options,
240 strategy,
241 cte_names,
242 canonical_aliases,
243 next_alias,
244 );
245 Expression::Paren(paren)
246 }
247 _ => expression,
248 }
249}
250
251fn ensure_table_alias(
253 table: &mut TableRef,
254 strategy: NormalizationStrategy,
255 canonical_aliases: &mut HashMap<String, String>,
256 next_alias: &mut impl FnMut() -> String,
257 options: &QualifyTablesOptions,
258) {
259 let table_name = table.name.name.clone();
260
261 if options.canonicalize_table_aliases {
262 let new_alias = next_alias();
264 let old_alias = table
265 .alias
266 .as_ref()
267 .map(|a| a.name.clone())
268 .unwrap_or(table_name.clone());
269 canonical_aliases.insert(old_alias, new_alias.clone());
270 table.alias = Some(normalize_identifier(Identifier::new(new_alias), strategy));
271 } else if table.alias.is_none() {
272 table.alias = Some(normalize_identifier(Identifier::new(table_name), strategy));
274 }
275}
276
277fn update_column_references(
279 mut select: Select,
280 canonical_aliases: &HashMap<String, String>,
281) -> Select {
282 select.expressions = select
284 .expressions
285 .into_iter()
286 .map(|e| update_column_in_expression(e, canonical_aliases))
287 .collect();
288
289 if let Some(mut where_clause) = select.where_clause {
291 where_clause.this = update_column_in_expression(where_clause.this, canonical_aliases);
292 select.where_clause = Some(where_clause);
293 }
294
295 if let Some(mut group_by) = select.group_by {
297 group_by.expressions = group_by
298 .expressions
299 .into_iter()
300 .map(|e| update_column_in_expression(e, canonical_aliases))
301 .collect();
302 select.group_by = Some(group_by);
303 }
304
305 if let Some(mut having) = select.having {
307 having.this = update_column_in_expression(having.this, canonical_aliases);
308 select.having = Some(having);
309 }
310
311 if let Some(mut order_by) = select.order_by {
313 order_by.expressions = order_by
314 .expressions
315 .into_iter()
316 .map(|mut o| {
317 o.this = update_column_in_expression(o.this, canonical_aliases);
318 o
319 })
320 .collect();
321 select.order_by = Some(order_by);
322 }
323
324 for join in &mut select.joins {
326 if let Some(on) = &mut join.on {
327 *on = update_column_in_expression(on.clone(), canonical_aliases);
328 }
329 }
330
331 select
332}
333
334fn update_column_in_expression(
336 expression: Expression,
337 canonical_aliases: &HashMap<String, String>,
338) -> Expression {
339 match expression {
340 Expression::Column(mut col) => {
341 if let Some(ref table) = col.table {
342 if let Some(canonical) = canonical_aliases.get(&table.name) {
343 col.table = Some(Identifier {
344 name: canonical.clone(),
345 quoted: table.quoted,
346 trailing_comments: table.trailing_comments.clone(),
347 span: None,
348 });
349 }
350 }
351 Expression::Column(col)
352 }
353 Expression::And(mut bin) => {
354 bin.left = update_column_in_expression(bin.left, canonical_aliases);
355 bin.right = update_column_in_expression(bin.right, canonical_aliases);
356 Expression::And(bin)
357 }
358 Expression::Or(mut bin) => {
359 bin.left = update_column_in_expression(bin.left, canonical_aliases);
360 bin.right = update_column_in_expression(bin.right, canonical_aliases);
361 Expression::Or(bin)
362 }
363 Expression::Eq(mut bin) => {
364 bin.left = update_column_in_expression(bin.left, canonical_aliases);
365 bin.right = update_column_in_expression(bin.right, canonical_aliases);
366 Expression::Eq(bin)
367 }
368 Expression::Neq(mut bin) => {
369 bin.left = update_column_in_expression(bin.left, canonical_aliases);
370 bin.right = update_column_in_expression(bin.right, canonical_aliases);
371 Expression::Neq(bin)
372 }
373 Expression::Lt(mut bin) => {
374 bin.left = update_column_in_expression(bin.left, canonical_aliases);
375 bin.right = update_column_in_expression(bin.right, canonical_aliases);
376 Expression::Lt(bin)
377 }
378 Expression::Lte(mut bin) => {
379 bin.left = update_column_in_expression(bin.left, canonical_aliases);
380 bin.right = update_column_in_expression(bin.right, canonical_aliases);
381 Expression::Lte(bin)
382 }
383 Expression::Gt(mut bin) => {
384 bin.left = update_column_in_expression(bin.left, canonical_aliases);
385 bin.right = update_column_in_expression(bin.right, canonical_aliases);
386 Expression::Gt(bin)
387 }
388 Expression::Gte(mut bin) => {
389 bin.left = update_column_in_expression(bin.left, canonical_aliases);
390 bin.right = update_column_in_expression(bin.right, canonical_aliases);
391 Expression::Gte(bin)
392 }
393 Expression::Not(mut un) => {
394 un.this = update_column_in_expression(un.this, canonical_aliases);
395 Expression::Not(un)
396 }
397 Expression::Paren(mut paren) => {
398 paren.this = update_column_in_expression(paren.this, canonical_aliases);
399 Expression::Paren(paren)
400 }
401 Expression::Alias(mut alias) => {
402 alias.this = update_column_in_expression(alias.this, canonical_aliases);
403 Expression::Alias(alias)
404 }
405 Expression::Function(mut func) => {
406 func.args = func
407 .args
408 .into_iter()
409 .map(|a| update_column_in_expression(a, canonical_aliases))
410 .collect();
411 Expression::Function(func)
412 }
413 Expression::AggregateFunction(mut agg) => {
414 agg.args = agg
415 .args
416 .into_iter()
417 .map(|a| update_column_in_expression(a, canonical_aliases))
418 .collect();
419 Expression::AggregateFunction(agg)
420 }
421 Expression::Case(mut case) => {
422 case.operand = case
423 .operand
424 .map(|o| update_column_in_expression(o, canonical_aliases));
425 case.whens = case
426 .whens
427 .into_iter()
428 .map(|(w, t)| {
429 (
430 update_column_in_expression(w, canonical_aliases),
431 update_column_in_expression(t, canonical_aliases),
432 )
433 })
434 .collect();
435 case.else_ = case
436 .else_
437 .map(|e| update_column_in_expression(e, canonical_aliases));
438 Expression::Case(case)
439 }
440 _ => expression,
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Parses `sql` and returns its first statement, panicking if parsing fails.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Renders `expr` back to SQL with the default generator.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, runs `qualify_tables` with `options`, and renders the result.
    fn qualify_sql(sql: &str, options: &QualifyTablesOptions) -> String {
        render(&qualify_tables(parse(sql), options))
    }

    #[test]
    fn test_qualify_with_db() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let out = qualify_sql("SELECT * FROM users", &options);
        assert!(out.contains("mydb") && out.contains("users"));
    }

    #[test]
    fn test_qualify_with_db_and_catalog() {
        let options = QualifyTablesOptions::new()
            .with_db("mydb")
            .with_catalog("mycatalog");
        let out = qualify_sql("SELECT * FROM users", &options);
        assert!(out.contains("mycatalog") && out.contains("mydb") && out.contains("users"));
    }

    #[test]
    fn test_preserve_existing_schema() {
        let options = QualifyTablesOptions::new().with_db("default_db");
        let out = qualify_sql("SELECT * FROM other_db.users", &options);
        // An explicit schema wins over the configured default.
        assert!(out.contains("other_db"));
        assert!(!out.contains("default_db"));
    }

    #[test]
    fn test_ensure_table_alias() {
        let out = qualify_sql("SELECT * FROM users", &QualifyTablesOptions::new());
        // NOTE(review): the second disjunct always holds for this input, so this currently
        // only checks that qualification runs; tighten once the generator's alias output
        // format is pinned down.
        assert!(out.contains("AS") || out.to_lowercase().contains(" users"));
    }

    #[test]
    fn test_canonical_aliases() {
        let options = QualifyTablesOptions::new().with_canonical_aliases();
        let out = qualify_sql("SELECT u.id FROM users u", &options);
        assert!(out.contains("_0"));
    }

    #[test]
    fn test_qualify_join() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let out = qualify_sql(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
            &options,
        );
        assert!(out.contains("mydb"));
    }

    #[test]
    fn test_dont_qualify_cte() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let out = qualify_sql("WITH cte AS (SELECT 1) SELECT * FROM cte", &options);
        assert!(out.contains("cte"));
    }

    #[test]
    fn test_qualify_subquery() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let out = qualify_sql("SELECT * FROM (SELECT * FROM users) AS sub", &options);
        assert!(out.contains("mydb"));
    }
}