1use crate::dialects::DialectType;
9use crate::expressions::{Expression, Identifier, Select, TableRef};
10use crate::helper::name_sequence;
11use crate::optimizer::normalize_identifiers::{
12 get_normalization_strategy, normalize_identifier, NormalizationStrategy,
13};
14use std::collections::{HashMap, HashSet};
15
16#[derive(Debug, Clone, Default)]
18pub struct QualifyTablesOptions {
19 pub db: Option<String>,
21 pub catalog: Option<String>,
23 pub dialect: Option<DialectType>,
25 pub canonicalize_table_aliases: bool,
27}
28
29impl QualifyTablesOptions {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn with_db(mut self, db: impl Into<String>) -> Self {
35 self.db = Some(db.into());
36 self
37 }
38
39 pub fn with_catalog(mut self, catalog: impl Into<String>) -> Self {
40 self.catalog = Some(catalog.into());
41 self
42 }
43
44 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
45 self.dialect = Some(dialect);
46 self
47 }
48
49 pub fn with_canonical_aliases(mut self) -> Self {
50 self.canonicalize_table_aliases = true;
51 self
52 }
53}
54
55pub fn qualify_tables(expression: Expression, options: &QualifyTablesOptions) -> Expression {
77 let strategy = get_normalization_strategy(options.dialect);
78 let mut next_alias = name_sequence("_");
79
80 match expression {
81 Expression::Select(select) => {
82 let qualified = qualify_select(*select, options, strategy, &mut next_alias);
83 Expression::Select(Box::new(qualified))
84 }
85 Expression::Union(mut union) => {
86 union.left = qualify_tables(union.left, options);
87 union.right = qualify_tables(union.right, options);
88 Expression::Union(union)
89 }
90 Expression::Intersect(mut intersect) => {
91 intersect.left = qualify_tables(intersect.left, options);
92 intersect.right = qualify_tables(intersect.right, options);
93 Expression::Intersect(intersect)
94 }
95 Expression::Except(mut except) => {
96 except.left = qualify_tables(except.left, options);
97 except.right = qualify_tables(except.right, options);
98 Expression::Except(except)
99 }
100 _ => expression,
101 }
102}
103
104fn qualify_select(
106 mut select: Select,
107 options: &QualifyTablesOptions,
108 strategy: NormalizationStrategy,
109 next_alias: &mut impl FnMut() -> String,
110) -> Select {
111 let cte_names: HashSet<String> = select
113 .with
114 .as_ref()
115 .map(|w| w.ctes.iter().map(|c| c.alias.name.clone()).collect())
116 .unwrap_or_default();
117
118 let mut canonical_aliases: HashMap<String, String> = HashMap::new();
120
121 if let Some(ref mut with) = select.with {
123 for cte in &mut with.ctes {
124 cte.this = qualify_tables(cte.this.clone(), options);
125 }
126 }
127
128 if let Some(ref mut from) = select.from {
130 for expr in &mut from.expressions {
131 *expr = qualify_table_expression(
132 expr.clone(),
133 options,
134 strategy,
135 &cte_names,
136 &mut canonical_aliases,
137 next_alias,
138 );
139 }
140 }
141
142 for join in &mut select.joins {
144 join.this = qualify_table_expression(
145 join.this.clone(),
146 options,
147 strategy,
148 &cte_names,
149 &mut canonical_aliases,
150 next_alias,
151 );
152 }
153
154 if options.canonicalize_table_aliases && !canonical_aliases.is_empty() {
156 select = update_column_references(select, &canonical_aliases);
157 }
158
159 select
160}
161
162fn qualify_table_expression(
164 expression: Expression,
165 options: &QualifyTablesOptions,
166 strategy: NormalizationStrategy,
167 cte_names: &HashSet<String>,
168 canonical_aliases: &mut HashMap<String, String>,
169 next_alias: &mut impl FnMut() -> String,
170) -> Expression {
171 match expression {
172 Expression::Table(mut table) => {
173 let table_name = table.name.name.clone();
174
175 if cte_names.contains(&table_name) {
177 ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
179 return Expression::Table(table);
180 }
181
182 if let Some(ref db) = options.db {
184 if table.schema.is_none() {
185 table.schema =
186 Some(normalize_identifier(Identifier::new(db.clone()), strategy));
187 }
188 }
189
190 if let Some(ref catalog) = options.catalog {
192 if table.schema.is_some() && table.catalog.is_none() {
193 table.catalog = Some(normalize_identifier(
194 Identifier::new(catalog.clone()),
195 strategy,
196 ));
197 }
198 }
199
200 ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
202
203 Expression::Table(table)
204 }
205 Expression::Subquery(mut subquery) => {
206 subquery.this = qualify_tables(subquery.this, options);
208
209 if subquery.alias.is_none() || options.canonicalize_table_aliases {
211 let alias_name = if options.canonicalize_table_aliases {
212 let new_name = next_alias();
213 if let Some(ref old_alias) = subquery.alias {
214 canonical_aliases.insert(old_alias.name.clone(), new_name.clone());
215 }
216 new_name
217 } else {
218 subquery
219 .alias
220 .as_ref()
221 .map(|a| a.name.clone())
222 .unwrap_or_else(|| next_alias())
223 };
224
225 subquery.alias = Some(normalize_identifier(Identifier::new(alias_name), strategy));
226 }
227
228 Expression::Subquery(subquery)
229 }
230 Expression::Paren(mut paren) => {
231 paren.this = qualify_table_expression(
232 paren.this,
233 options,
234 strategy,
235 cte_names,
236 canonical_aliases,
237 next_alias,
238 );
239 Expression::Paren(paren)
240 }
241 _ => expression,
242 }
243}
244
245fn ensure_table_alias(
247 table: &mut TableRef,
248 strategy: NormalizationStrategy,
249 canonical_aliases: &mut HashMap<String, String>,
250 next_alias: &mut impl FnMut() -> String,
251 options: &QualifyTablesOptions,
252) {
253 let table_name = table.name.name.clone();
254
255 if options.canonicalize_table_aliases {
256 let new_alias = next_alias();
258 let old_alias = table
259 .alias
260 .as_ref()
261 .map(|a| a.name.clone())
262 .unwrap_or(table_name.clone());
263 canonical_aliases.insert(old_alias, new_alias.clone());
264 table.alias = Some(normalize_identifier(Identifier::new(new_alias), strategy));
265 } else if table.alias.is_none() {
266 table.alias = Some(normalize_identifier(Identifier::new(table_name), strategy));
268 }
269}
270
271fn update_column_references(
273 mut select: Select,
274 canonical_aliases: &HashMap<String, String>,
275) -> Select {
276 select.expressions = select
278 .expressions
279 .into_iter()
280 .map(|e| update_column_in_expression(e, canonical_aliases))
281 .collect();
282
283 if let Some(mut where_clause) = select.where_clause {
285 where_clause.this = update_column_in_expression(where_clause.this, canonical_aliases);
286 select.where_clause = Some(where_clause);
287 }
288
289 if let Some(mut group_by) = select.group_by {
291 group_by.expressions = group_by
292 .expressions
293 .into_iter()
294 .map(|e| update_column_in_expression(e, canonical_aliases))
295 .collect();
296 select.group_by = Some(group_by);
297 }
298
299 if let Some(mut having) = select.having {
301 having.this = update_column_in_expression(having.this, canonical_aliases);
302 select.having = Some(having);
303 }
304
305 if let Some(mut order_by) = select.order_by {
307 order_by.expressions = order_by
308 .expressions
309 .into_iter()
310 .map(|mut o| {
311 o.this = update_column_in_expression(o.this, canonical_aliases);
312 o
313 })
314 .collect();
315 select.order_by = Some(order_by);
316 }
317
318 for join in &mut select.joins {
320 if let Some(on) = &mut join.on {
321 *on = update_column_in_expression(on.clone(), canonical_aliases);
322 }
323 }
324
325 select
326}
327
328fn update_column_in_expression(
330 expression: Expression,
331 canonical_aliases: &HashMap<String, String>,
332) -> Expression {
333 match expression {
334 Expression::Column(mut col) => {
335 if let Some(ref table) = col.table {
336 if let Some(canonical) = canonical_aliases.get(&table.name) {
337 col.table = Some(Identifier {
338 name: canonical.clone(),
339 quoted: table.quoted,
340 trailing_comments: table.trailing_comments.clone(),
341 });
342 }
343 }
344 Expression::Column(col)
345 }
346 Expression::And(mut bin) => {
347 bin.left = update_column_in_expression(bin.left, canonical_aliases);
348 bin.right = update_column_in_expression(bin.right, canonical_aliases);
349 Expression::And(bin)
350 }
351 Expression::Or(mut bin) => {
352 bin.left = update_column_in_expression(bin.left, canonical_aliases);
353 bin.right = update_column_in_expression(bin.right, canonical_aliases);
354 Expression::Or(bin)
355 }
356 Expression::Eq(mut bin) => {
357 bin.left = update_column_in_expression(bin.left, canonical_aliases);
358 bin.right = update_column_in_expression(bin.right, canonical_aliases);
359 Expression::Eq(bin)
360 }
361 Expression::Neq(mut bin) => {
362 bin.left = update_column_in_expression(bin.left, canonical_aliases);
363 bin.right = update_column_in_expression(bin.right, canonical_aliases);
364 Expression::Neq(bin)
365 }
366 Expression::Lt(mut bin) => {
367 bin.left = update_column_in_expression(bin.left, canonical_aliases);
368 bin.right = update_column_in_expression(bin.right, canonical_aliases);
369 Expression::Lt(bin)
370 }
371 Expression::Lte(mut bin) => {
372 bin.left = update_column_in_expression(bin.left, canonical_aliases);
373 bin.right = update_column_in_expression(bin.right, canonical_aliases);
374 Expression::Lte(bin)
375 }
376 Expression::Gt(mut bin) => {
377 bin.left = update_column_in_expression(bin.left, canonical_aliases);
378 bin.right = update_column_in_expression(bin.right, canonical_aliases);
379 Expression::Gt(bin)
380 }
381 Expression::Gte(mut bin) => {
382 bin.left = update_column_in_expression(bin.left, canonical_aliases);
383 bin.right = update_column_in_expression(bin.right, canonical_aliases);
384 Expression::Gte(bin)
385 }
386 Expression::Not(mut un) => {
387 un.this = update_column_in_expression(un.this, canonical_aliases);
388 Expression::Not(un)
389 }
390 Expression::Paren(mut paren) => {
391 paren.this = update_column_in_expression(paren.this, canonical_aliases);
392 Expression::Paren(paren)
393 }
394 Expression::Alias(mut alias) => {
395 alias.this = update_column_in_expression(alias.this, canonical_aliases);
396 Expression::Alias(alias)
397 }
398 Expression::Function(mut func) => {
399 func.args = func
400 .args
401 .into_iter()
402 .map(|a| update_column_in_expression(a, canonical_aliases))
403 .collect();
404 Expression::Function(func)
405 }
406 Expression::AggregateFunction(mut agg) => {
407 agg.args = agg
408 .args
409 .into_iter()
410 .map(|a| update_column_in_expression(a, canonical_aliases))
411 .collect();
412 Expression::AggregateFunction(agg)
413 }
414 Expression::Case(mut case) => {
415 case.operand = case
416 .operand
417 .map(|o| update_column_in_expression(o, canonical_aliases));
418 case.whens = case
419 .whens
420 .into_iter()
421 .map(|(w, t)| {
422 (
423 update_column_in_expression(w, canonical_aliases),
424 update_column_in_expression(t, canonical_aliases),
425 )
426 })
427 .collect();
428 case.else_ = case
429 .else_
430 .map(|e| update_column_in_expression(e, canonical_aliases));
431 Expression::Case(case)
432 }
433 _ => expression,
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parses `sql`, runs `qualify_tables` with `options`, and renders the result.
    fn qualify_sql(sql: &str, options: &QualifyTablesOptions) -> String {
        gen(&qualify_tables(parse(sql), options))
    }

    #[test]
    fn test_qualify_with_db() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let sql = qualify_sql("SELECT * FROM users", &options);
        assert!(sql.contains("mydb") && sql.contains("users"));
    }

    #[test]
    fn test_qualify_with_db_and_catalog() {
        let options = QualifyTablesOptions::new()
            .with_db("mydb")
            .with_catalog("mycatalog");
        let sql = qualify_sql("SELECT * FROM users", &options);
        assert!(sql.contains("mycatalog") && sql.contains("mydb") && sql.contains("users"));
    }

    #[test]
    fn test_preserve_existing_schema() {
        let options = QualifyTablesOptions::new().with_db("default_db");
        let sql = qualify_sql("SELECT * FROM other_db.users", &options);
        assert!(sql.contains("other_db"));
        assert!(!sql.contains("default_db"));
    }

    #[test]
    fn test_ensure_table_alias() {
        // Inspect the tree directly: a textual check on " users" would always pass.
        let qualified = qualify_tables(parse("SELECT * FROM users"), &QualifyTablesOptions::new());
        match &qualified {
            Expression::Select(select) => {
                let from = select.from.as_ref().expect("FROM clause should be kept");
                match &from.expressions[0] {
                    Expression::Table(table) => {
                        let alias = table.alias.as_ref().map(|a| a.name.to_lowercase());
                        assert_eq!(alias, Some("users".to_string()));
                    }
                    _ => panic!("expected the FROM source to be a table"),
                }
            }
            _ => panic!("expected a SELECT"),
        }
    }

    #[test]
    fn test_canonical_aliases() {
        let options = QualifyTablesOptions::new().with_canonical_aliases();
        let sql = qualify_sql("SELECT u.id FROM users u", &options);
        assert!(sql.contains("_0"));
        // The column reference must follow the table to its canonical alias.
        assert!(!sql.contains("u.id"));
    }

    #[test]
    fn test_qualify_join() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let sql = qualify_sql(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
            &options,
        );
        // Both the FROM table and the joined table are qualified.
        assert_eq!(sql.to_lowercase().matches("mydb").count(), 2);
    }

    #[test]
    fn test_dont_qualify_cte() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let sql = qualify_sql("WITH cte AS (SELECT 1) SELECT * FROM cte", &options);
        assert!(sql.contains("cte"));
        // The only table reference is the CTE, so the default db must not appear.
        assert!(!sql.to_lowercase().contains("mydb"));
    }

    #[test]
    fn test_qualify_subquery() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let sql = qualify_sql("SELECT * FROM (SELECT * FROM users) AS sub", &options);
        assert!(sql.contains("mydb"));
        // An existing subquery alias is preserved when not canonicalizing.
        assert!(sql.contains("sub"));
    }
}