1use crate::dialects::DialectType;
9use crate::expressions::{Expression, Identifier, Select, TableRef};
10use crate::helper::name_sequence;
11use crate::optimizer::normalize_identifiers::{
12 get_normalization_strategy, normalize_identifier, NormalizationStrategy,
13};
14use std::collections::{HashMap, HashSet};
15
16#[derive(Debug, Clone, Default)]
18pub struct QualifyTablesOptions {
19 pub db: Option<String>,
21 pub catalog: Option<String>,
23 pub dialect: Option<DialectType>,
25 pub canonicalize_table_aliases: bool,
27}
28
29impl QualifyTablesOptions {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn with_db(mut self, db: impl Into<String>) -> Self {
35 self.db = Some(db.into());
36 self
37 }
38
39 pub fn with_catalog(mut self, catalog: impl Into<String>) -> Self {
40 self.catalog = Some(catalog.into());
41 self
42 }
43
44 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
45 self.dialect = Some(dialect);
46 self
47 }
48
49 pub fn with_canonical_aliases(mut self) -> Self {
50 self.canonicalize_table_aliases = true;
51 self
52 }
53}
54
55pub fn qualify_tables(expression: Expression, options: &QualifyTablesOptions) -> Expression {
77 let strategy = get_normalization_strategy(options.dialect);
78 let mut next_alias = name_sequence("_");
79
80 match expression {
81 Expression::Select(select) => {
82 let qualified = qualify_select(*select, options, strategy, &mut next_alias);
83 Expression::Select(Box::new(qualified))
84 }
85 Expression::Union(mut union) => {
86 union.left = qualify_tables(union.left, options);
87 union.right = qualify_tables(union.right, options);
88 Expression::Union(union)
89 }
90 Expression::Intersect(mut intersect) => {
91 intersect.left = qualify_tables(intersect.left, options);
92 intersect.right = qualify_tables(intersect.right, options);
93 Expression::Intersect(intersect)
94 }
95 Expression::Except(mut except) => {
96 except.left = qualify_tables(except.left, options);
97 except.right = qualify_tables(except.right, options);
98 Expression::Except(except)
99 }
100 _ => expression,
101 }
102}
103
104fn qualify_select(
106 mut select: Select,
107 options: &QualifyTablesOptions,
108 strategy: NormalizationStrategy,
109 next_alias: &mut impl FnMut() -> String,
110) -> Select {
111 let cte_names: HashSet<String> = select
113 .with
114 .as_ref()
115 .map(|w| w.ctes.iter().map(|c| c.alias.name.clone()).collect())
116 .unwrap_or_default();
117
118 let mut canonical_aliases: HashMap<String, String> = HashMap::new();
120
121 if let Some(ref mut with) = select.with {
123 for cte in &mut with.ctes {
124 cte.this = qualify_tables(cte.this.clone(), options);
125 }
126 }
127
128 if let Some(ref mut from) = select.from {
130 for expr in &mut from.expressions {
131 *expr = qualify_table_expression(
132 expr.clone(),
133 options,
134 strategy,
135 &cte_names,
136 &mut canonical_aliases,
137 next_alias,
138 );
139 }
140 }
141
142 for join in &mut select.joins {
144 join.this = qualify_table_expression(
145 join.this.clone(),
146 options,
147 strategy,
148 &cte_names,
149 &mut canonical_aliases,
150 next_alias,
151 );
152 }
153
154 if options.canonicalize_table_aliases && !canonical_aliases.is_empty() {
156 select = update_column_references(select, &canonical_aliases);
157 }
158
159 select
160}
161
162fn qualify_table_expression(
164 expression: Expression,
165 options: &QualifyTablesOptions,
166 strategy: NormalizationStrategy,
167 cte_names: &HashSet<String>,
168 canonical_aliases: &mut HashMap<String, String>,
169 next_alias: &mut impl FnMut() -> String,
170) -> Expression {
171 match expression {
172 Expression::Table(mut table) => {
173 let table_name = table.name.name.clone();
174
175 if cte_names.contains(&table_name) {
177 ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
179 return Expression::Table(table);
180 }
181
182 if let Some(ref db) = options.db {
184 if table.schema.is_none() {
185 table.schema =
186 Some(normalize_identifier(Identifier::new(db.clone()), strategy));
187 }
188 }
189
190 if let Some(ref catalog) = options.catalog {
192 if table.schema.is_some() && table.catalog.is_none() {
193 table.catalog = Some(normalize_identifier(
194 Identifier::new(catalog.clone()),
195 strategy,
196 ));
197 }
198 }
199
200 ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
202
203 Expression::Table(table)
204 }
205 Expression::Subquery(mut subquery) => {
206 subquery.this = qualify_tables(subquery.this, options);
208
209 if subquery.alias.is_none() || options.canonicalize_table_aliases {
211 let alias_name = if options.canonicalize_table_aliases {
212 let new_name = next_alias();
213 if let Some(ref old_alias) = subquery.alias {
214 canonical_aliases.insert(old_alias.name.clone(), new_name.clone());
215 }
216 new_name
217 } else {
218 subquery
219 .alias
220 .as_ref()
221 .map(|a| a.name.clone())
222 .unwrap_or_else(|| next_alias())
223 };
224
225 subquery.alias = Some(normalize_identifier(Identifier::new(alias_name), strategy));
226 }
227
228 Expression::Subquery(subquery)
229 }
230 Expression::Paren(mut paren) => {
231 paren.this = qualify_table_expression(
232 paren.this,
233 options,
234 strategy,
235 cte_names,
236 canonical_aliases,
237 next_alias,
238 );
239 Expression::Paren(paren)
240 }
241 _ => expression,
242 }
243}
244
245fn ensure_table_alias(
247 table: &mut TableRef,
248 strategy: NormalizationStrategy,
249 canonical_aliases: &mut HashMap<String, String>,
250 next_alias: &mut impl FnMut() -> String,
251 options: &QualifyTablesOptions,
252) {
253 let table_name = table.name.name.clone();
254
255 if options.canonicalize_table_aliases {
256 let new_alias = next_alias();
258 let old_alias = table
259 .alias
260 .as_ref()
261 .map(|a| a.name.clone())
262 .unwrap_or(table_name.clone());
263 canonical_aliases.insert(old_alias, new_alias.clone());
264 table.alias = Some(normalize_identifier(Identifier::new(new_alias), strategy));
265 } else if table.alias.is_none() {
266 table.alias = Some(normalize_identifier(Identifier::new(table_name), strategy));
268 }
269}
270
271fn update_column_references(
273 mut select: Select,
274 canonical_aliases: &HashMap<String, String>,
275) -> Select {
276 select.expressions = select
278 .expressions
279 .into_iter()
280 .map(|e| update_column_in_expression(e, canonical_aliases))
281 .collect();
282
283 if let Some(mut where_clause) = select.where_clause {
285 where_clause.this = update_column_in_expression(where_clause.this, canonical_aliases);
286 select.where_clause = Some(where_clause);
287 }
288
289 if let Some(mut group_by) = select.group_by {
291 group_by.expressions = group_by
292 .expressions
293 .into_iter()
294 .map(|e| update_column_in_expression(e, canonical_aliases))
295 .collect();
296 select.group_by = Some(group_by);
297 }
298
299 if let Some(mut having) = select.having {
301 having.this = update_column_in_expression(having.this, canonical_aliases);
302 select.having = Some(having);
303 }
304
305 if let Some(mut order_by) = select.order_by {
307 order_by.expressions = order_by
308 .expressions
309 .into_iter()
310 .map(|mut o| {
311 o.this = update_column_in_expression(o.this, canonical_aliases);
312 o
313 })
314 .collect();
315 select.order_by = Some(order_by);
316 }
317
318 for join in &mut select.joins {
320 if let Some(on) = &mut join.on {
321 *on = update_column_in_expression(on.clone(), canonical_aliases);
322 }
323 }
324
325 select
326}
327
328fn update_column_in_expression(
330 expression: Expression,
331 canonical_aliases: &HashMap<String, String>,
332) -> Expression {
333 match expression {
334 Expression::Column(mut col) => {
335 if let Some(ref table) = col.table {
336 if let Some(canonical) = canonical_aliases.get(&table.name) {
337 col.table = Some(Identifier {
338 name: canonical.clone(),
339 quoted: table.quoted,
340 trailing_comments: table.trailing_comments.clone(),
341 span: None,
342 });
343 }
344 }
345 Expression::Column(col)
346 }
347 Expression::And(mut bin) => {
348 bin.left = update_column_in_expression(bin.left, canonical_aliases);
349 bin.right = update_column_in_expression(bin.right, canonical_aliases);
350 Expression::And(bin)
351 }
352 Expression::Or(mut bin) => {
353 bin.left = update_column_in_expression(bin.left, canonical_aliases);
354 bin.right = update_column_in_expression(bin.right, canonical_aliases);
355 Expression::Or(bin)
356 }
357 Expression::Eq(mut bin) => {
358 bin.left = update_column_in_expression(bin.left, canonical_aliases);
359 bin.right = update_column_in_expression(bin.right, canonical_aliases);
360 Expression::Eq(bin)
361 }
362 Expression::Neq(mut bin) => {
363 bin.left = update_column_in_expression(bin.left, canonical_aliases);
364 bin.right = update_column_in_expression(bin.right, canonical_aliases);
365 Expression::Neq(bin)
366 }
367 Expression::Lt(mut bin) => {
368 bin.left = update_column_in_expression(bin.left, canonical_aliases);
369 bin.right = update_column_in_expression(bin.right, canonical_aliases);
370 Expression::Lt(bin)
371 }
372 Expression::Lte(mut bin) => {
373 bin.left = update_column_in_expression(bin.left, canonical_aliases);
374 bin.right = update_column_in_expression(bin.right, canonical_aliases);
375 Expression::Lte(bin)
376 }
377 Expression::Gt(mut bin) => {
378 bin.left = update_column_in_expression(bin.left, canonical_aliases);
379 bin.right = update_column_in_expression(bin.right, canonical_aliases);
380 Expression::Gt(bin)
381 }
382 Expression::Gte(mut bin) => {
383 bin.left = update_column_in_expression(bin.left, canonical_aliases);
384 bin.right = update_column_in_expression(bin.right, canonical_aliases);
385 Expression::Gte(bin)
386 }
387 Expression::Not(mut un) => {
388 un.this = update_column_in_expression(un.this, canonical_aliases);
389 Expression::Not(un)
390 }
391 Expression::Paren(mut paren) => {
392 paren.this = update_column_in_expression(paren.this, canonical_aliases);
393 Expression::Paren(paren)
394 }
395 Expression::Alias(mut alias) => {
396 alias.this = update_column_in_expression(alias.this, canonical_aliases);
397 Expression::Alias(alias)
398 }
399 Expression::Function(mut func) => {
400 func.args = func
401 .args
402 .into_iter()
403 .map(|a| update_column_in_expression(a, canonical_aliases))
404 .collect();
405 Expression::Function(func)
406 }
407 Expression::AggregateFunction(mut agg) => {
408 agg.args = agg
409 .args
410 .into_iter()
411 .map(|a| update_column_in_expression(a, canonical_aliases))
412 .collect();
413 Expression::AggregateFunction(agg)
414 }
415 Expression::Case(mut case) => {
416 case.operand = case
417 .operand
418 .map(|o| update_column_in_expression(o, canonical_aliases));
419 case.whens = case
420 .whens
421 .into_iter()
422 .map(|(w, t)| {
423 (
424 update_column_in_expression(w, canonical_aliases),
425 update_column_in_expression(t, canonical_aliases),
426 )
427 })
428 .collect();
429 case.else_ = case
430 .else_
431 .map(|e| update_column_in_expression(e, canonical_aliases));
432 Expression::Case(case)
433 }
434 _ => expression,
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL with the default generator.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    #[test]
    fn test_qualify_with_db() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let sql = to_sql(&qualify_tables(parse("SELECT * FROM users"), &options));
        assert!(sql.contains("mydb") && sql.contains("users"), "got: {sql}");
    }

    #[test]
    fn test_qualify_with_db_and_catalog() {
        let options = QualifyTablesOptions::new()
            .with_db("mydb")
            .with_catalog("mycatalog");
        let sql = to_sql(&qualify_tables(parse("SELECT * FROM users"), &options));
        assert!(
            sql.contains("mycatalog") && sql.contains("mydb") && sql.contains("users"),
            "got: {sql}"
        );
    }

    #[test]
    fn test_preserve_existing_schema() {
        let options = QualifyTablesOptions::new().with_db("default_db");
        let sql = to_sql(&qualify_tables(parse("SELECT * FROM other_db.users"), &options));
        assert!(sql.contains("other_db"), "got: {sql}");
        // An explicit schema must never be overwritten by the default one.
        assert!(!sql.contains("default_db"), "got: {sql}");
    }

    #[test]
    fn test_ensure_table_alias() {
        let options = QualifyTablesOptions::new();
        let sql = to_sql(&qualify_tables(parse("SELECT * FROM users"), &options));
        // Once for the table name and once for the alias that defaults to it.
        assert_eq!(
            sql.matches("users").count(),
            2,
            "expected an alias equal to the table name, got: {sql}"
        );
    }

    #[test]
    fn test_canonical_aliases() {
        let options = QualifyTablesOptions::new().with_canonical_aliases();
        let sql = to_sql(&qualify_tables(parse("SELECT u.id FROM users u"), &options));
        assert!(sql.contains("_0"), "got: {sql}");
        // Column qualifiers must follow the renamed alias.
        assert!(!sql.contains("u.id"), "stale column qualifier in: {sql}");
    }

    #[test]
    fn test_qualify_join() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let sql = to_sql(&qualify_tables(expr, &options));
        // Both the FROM table and the JOIN target get the schema.
        assert_eq!(sql.matches("mydb").count(), 2, "got: {sql}");
    }

    #[test]
    fn test_dont_qualify_cte() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let sql = to_sql(&qualify_tables(
            parse("WITH cte AS (SELECT 1) SELECT * FROM cte"),
            &options,
        ));
        assert!(sql.contains("cte"), "got: {sql}");
        // The CTE reference is not a real table and must not be schema-qualified.
        assert!(!sql.contains("mydb"), "CTE reference was qualified: {sql}");
    }

    #[test]
    fn test_qualify_subquery() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let sql = to_sql(&qualify_tables(
            parse("SELECT * FROM (SELECT * FROM users) AS sub"),
            &options,
        ));
        assert!(sql.contains("mydb"), "got: {sql}");
        // An existing subquery alias is preserved when not canonicalizing.
        assert!(sql.contains("sub"), "got: {sql}");
    }
}