1use crate::dialects::DialectType;
9use crate::expressions::{Expression, Identifier, Select, TableRef};
10use crate::helper::name_sequence;
11use crate::optimizer::normalize_identifiers::{
12 get_normalization_strategy, normalize_identifier, NormalizationStrategy,
13};
14use std::collections::{HashMap, HashSet};
15
16#[derive(Debug, Clone, Default)]
18pub struct QualifyTablesOptions {
19 pub db: Option<String>,
21 pub catalog: Option<String>,
23 pub dialect: Option<DialectType>,
25 pub canonicalize_table_aliases: bool,
27}
28
29impl QualifyTablesOptions {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn with_db(mut self, db: impl Into<String>) -> Self {
35 self.db = Some(db.into());
36 self
37 }
38
39 pub fn with_catalog(mut self, catalog: impl Into<String>) -> Self {
40 self.catalog = Some(catalog.into());
41 self
42 }
43
44 pub fn with_dialect(mut self, dialect: DialectType) -> Self {
45 self.dialect = Some(dialect);
46 self
47 }
48
49 pub fn with_canonical_aliases(mut self) -> Self {
50 self.canonicalize_table_aliases = true;
51 self
52 }
53}
54
55pub fn qualify_tables(expression: Expression, options: &QualifyTablesOptions) -> Expression {
77 let strategy = get_normalization_strategy(options.dialect);
78 let mut next_alias = name_sequence("_");
79
80 match expression {
81 Expression::Select(select) => {
82 let qualified = qualify_select(*select, options, strategy, &mut next_alias);
83 Expression::Select(Box::new(qualified))
84 }
85 Expression::Union(mut union) => {
86 union.left = qualify_tables(union.left, options);
87 union.right = qualify_tables(union.right, options);
88 Expression::Union(union)
89 }
90 Expression::Intersect(mut intersect) => {
91 intersect.left = qualify_tables(intersect.left, options);
92 intersect.right = qualify_tables(intersect.right, options);
93 Expression::Intersect(intersect)
94 }
95 Expression::Except(mut except) => {
96 except.left = qualify_tables(except.left, options);
97 except.right = qualify_tables(except.right, options);
98 Expression::Except(except)
99 }
100 _ => expression,
101 }
102}
103
104fn qualify_select(
106 mut select: Select,
107 options: &QualifyTablesOptions,
108 strategy: NormalizationStrategy,
109 next_alias: &mut impl FnMut() -> String,
110) -> Select {
111 let cte_names: HashSet<String> = select
113 .with
114 .as_ref()
115 .map(|w| w.ctes.iter().map(|c| c.alias.name.clone()).collect())
116 .unwrap_or_default();
117
118 let mut canonical_aliases: HashMap<String, String> = HashMap::new();
120
121 if let Some(ref mut with) = select.with {
123 for cte in &mut with.ctes {
124 cte.this = qualify_tables(cte.this.clone(), options);
125 }
126 }
127
128 if let Some(ref mut from) = select.from {
130 for expr in &mut from.expressions {
131 *expr = qualify_table_expression(
132 expr.clone(),
133 options,
134 strategy,
135 &cte_names,
136 &mut canonical_aliases,
137 next_alias,
138 );
139 }
140 }
141
142 for join in &mut select.joins {
144 join.this = qualify_table_expression(
145 join.this.clone(),
146 options,
147 strategy,
148 &cte_names,
149 &mut canonical_aliases,
150 next_alias,
151 );
152 }
153
154 if options.canonicalize_table_aliases && !canonical_aliases.is_empty() {
156 select = update_column_references(select, &canonical_aliases);
157 }
158
159 select
160}
161
162fn qualify_table_expression(
164 expression: Expression,
165 options: &QualifyTablesOptions,
166 strategy: NormalizationStrategy,
167 cte_names: &HashSet<String>,
168 canonical_aliases: &mut HashMap<String, String>,
169 next_alias: &mut impl FnMut() -> String,
170) -> Expression {
171 match expression {
172 Expression::Table(mut table) => {
173 let table_name = table.name.name.clone();
174
175 if cte_names.contains(&table_name) {
177 ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
179 return Expression::Table(table);
180 }
181
182 if let Some(ref db) = options.db {
184 if table.schema.is_none() {
185 table.schema = Some(normalize_identifier(
186 Identifier::new(db.clone()),
187 strategy,
188 ));
189 }
190 }
191
192 if let Some(ref catalog) = options.catalog {
194 if table.schema.is_some() && table.catalog.is_none() {
195 table.catalog = Some(normalize_identifier(
196 Identifier::new(catalog.clone()),
197 strategy,
198 ));
199 }
200 }
201
202 ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
204
205 Expression::Table(table)
206 }
207 Expression::Subquery(mut subquery) => {
208 subquery.this = qualify_tables(subquery.this, options);
210
211 if subquery.alias.is_none() || options.canonicalize_table_aliases {
213 let alias_name = if options.canonicalize_table_aliases {
214 let new_name = next_alias();
215 if let Some(ref old_alias) = subquery.alias {
216 canonical_aliases.insert(old_alias.name.clone(), new_name.clone());
217 }
218 new_name
219 } else {
220 subquery
221 .alias
222 .as_ref()
223 .map(|a| a.name.clone())
224 .unwrap_or_else(|| next_alias())
225 };
226
227 subquery.alias = Some(normalize_identifier(Identifier::new(alias_name), strategy));
228 }
229
230 Expression::Subquery(subquery)
231 }
232 Expression::Paren(mut paren) => {
233 paren.this = qualify_table_expression(
234 paren.this,
235 options,
236 strategy,
237 cte_names,
238 canonical_aliases,
239 next_alias,
240 );
241 Expression::Paren(paren)
242 }
243 _ => expression,
244 }
245}
246
247fn ensure_table_alias(
249 table: &mut TableRef,
250 strategy: NormalizationStrategy,
251 canonical_aliases: &mut HashMap<String, String>,
252 next_alias: &mut impl FnMut() -> String,
253 options: &QualifyTablesOptions,
254) {
255 let table_name = table.name.name.clone();
256
257 if options.canonicalize_table_aliases {
258 let new_alias = next_alias();
260 let old_alias = table.alias.as_ref().map(|a| a.name.clone()).unwrap_or(table_name.clone());
261 canonical_aliases.insert(old_alias, new_alias.clone());
262 table.alias = Some(normalize_identifier(Identifier::new(new_alias), strategy));
263 } else if table.alias.is_none() {
264 table.alias = Some(normalize_identifier(
266 Identifier::new(table_name),
267 strategy,
268 ));
269 }
270}
271
272fn update_column_references(mut select: Select, canonical_aliases: &HashMap<String, String>) -> Select {
274 select.expressions = select
276 .expressions
277 .into_iter()
278 .map(|e| update_column_in_expression(e, canonical_aliases))
279 .collect();
280
281 if let Some(mut where_clause) = select.where_clause {
283 where_clause.this = update_column_in_expression(where_clause.this, canonical_aliases);
284 select.where_clause = Some(where_clause);
285 }
286
287 if let Some(mut group_by) = select.group_by {
289 group_by.expressions = group_by
290 .expressions
291 .into_iter()
292 .map(|e| update_column_in_expression(e, canonical_aliases))
293 .collect();
294 select.group_by = Some(group_by);
295 }
296
297 if let Some(mut having) = select.having {
299 having.this = update_column_in_expression(having.this, canonical_aliases);
300 select.having = Some(having);
301 }
302
303 if let Some(mut order_by) = select.order_by {
305 order_by.expressions = order_by
306 .expressions
307 .into_iter()
308 .map(|mut o| {
309 o.this = update_column_in_expression(o.this, canonical_aliases);
310 o
311 })
312 .collect();
313 select.order_by = Some(order_by);
314 }
315
316 for join in &mut select.joins {
318 if let Some(on) = &mut join.on {
319 *on = update_column_in_expression(on.clone(), canonical_aliases);
320 }
321 }
322
323 select
324}
325
326fn update_column_in_expression(
328 expression: Expression,
329 canonical_aliases: &HashMap<String, String>,
330) -> Expression {
331 match expression {
332 Expression::Column(mut col) => {
333 if let Some(ref table) = col.table {
334 if let Some(canonical) = canonical_aliases.get(&table.name) {
335 col.table = Some(Identifier {
336 name: canonical.clone(),
337 quoted: table.quoted,
338 trailing_comments: table.trailing_comments.clone(),
339 });
340 }
341 }
342 Expression::Column(col)
343 }
344 Expression::And(mut bin) => {
345 bin.left = update_column_in_expression(bin.left, canonical_aliases);
346 bin.right = update_column_in_expression(bin.right, canonical_aliases);
347 Expression::And(bin)
348 }
349 Expression::Or(mut bin) => {
350 bin.left = update_column_in_expression(bin.left, canonical_aliases);
351 bin.right = update_column_in_expression(bin.right, canonical_aliases);
352 Expression::Or(bin)
353 }
354 Expression::Eq(mut bin) => {
355 bin.left = update_column_in_expression(bin.left, canonical_aliases);
356 bin.right = update_column_in_expression(bin.right, canonical_aliases);
357 Expression::Eq(bin)
358 }
359 Expression::Neq(mut bin) => {
360 bin.left = update_column_in_expression(bin.left, canonical_aliases);
361 bin.right = update_column_in_expression(bin.right, canonical_aliases);
362 Expression::Neq(bin)
363 }
364 Expression::Lt(mut bin) => {
365 bin.left = update_column_in_expression(bin.left, canonical_aliases);
366 bin.right = update_column_in_expression(bin.right, canonical_aliases);
367 Expression::Lt(bin)
368 }
369 Expression::Lte(mut bin) => {
370 bin.left = update_column_in_expression(bin.left, canonical_aliases);
371 bin.right = update_column_in_expression(bin.right, canonical_aliases);
372 Expression::Lte(bin)
373 }
374 Expression::Gt(mut bin) => {
375 bin.left = update_column_in_expression(bin.left, canonical_aliases);
376 bin.right = update_column_in_expression(bin.right, canonical_aliases);
377 Expression::Gt(bin)
378 }
379 Expression::Gte(mut bin) => {
380 bin.left = update_column_in_expression(bin.left, canonical_aliases);
381 bin.right = update_column_in_expression(bin.right, canonical_aliases);
382 Expression::Gte(bin)
383 }
384 Expression::Not(mut un) => {
385 un.this = update_column_in_expression(un.this, canonical_aliases);
386 Expression::Not(un)
387 }
388 Expression::Paren(mut paren) => {
389 paren.this = update_column_in_expression(paren.this, canonical_aliases);
390 Expression::Paren(paren)
391 }
392 Expression::Alias(mut alias) => {
393 alias.this = update_column_in_expression(alias.this, canonical_aliases);
394 Expression::Alias(alias)
395 }
396 Expression::Function(mut func) => {
397 func.args = func
398 .args
399 .into_iter()
400 .map(|a| update_column_in_expression(a, canonical_aliases))
401 .collect();
402 Expression::Function(func)
403 }
404 Expression::AggregateFunction(mut agg) => {
405 agg.args = agg
406 .args
407 .into_iter()
408 .map(|a| update_column_in_expression(a, canonical_aliases))
409 .collect();
410 Expression::AggregateFunction(agg)
411 }
412 Expression::Case(mut case) => {
413 case.operand = case
414 .operand
415 .map(|o| update_column_in_expression(o, canonical_aliases));
416 case.whens = case
417 .whens
418 .into_iter()
419 .map(|(w, t)| {
420 (
421 update_column_in_expression(w, canonical_aliases),
422 update_column_in_expression(t, canonical_aliases),
423 )
424 })
425 .collect();
426 case.else_ = case
427 .else_
428 .map(|e| update_column_in_expression(e, canonical_aliases));
429 Expression::Case(case)
430 }
431 _ => expression,
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Parses `sql`, qualifies its first statement with `options` and renders
    /// the result back to SQL text.
    fn qualified_sql(sql: &str, options: &QualifyTablesOptions) -> String {
        let statement = Parser::parse_sql(sql).expect("Failed to parse")[0].clone();
        let rewritten = qualify_tables(statement, options);
        Generator::new().generate(&rewritten).unwrap()
    }

    #[test]
    fn test_qualify_with_db() {
        let opts = QualifyTablesOptions::new().with_db("mydb");
        let out = qualified_sql("SELECT * FROM users", &opts);
        assert!(out.contains("mydb"));
        assert!(out.contains("users"));
    }

    #[test]
    fn test_qualify_with_db_and_catalog() {
        let opts = QualifyTablesOptions::new()
            .with_db("mydb")
            .with_catalog("mycatalog");
        let out = qualified_sql("SELECT * FROM users", &opts);
        assert!(out.contains("mycatalog"));
        assert!(out.contains("mydb"));
        assert!(out.contains("users"));
    }

    #[test]
    fn test_preserve_existing_schema() {
        let opts = QualifyTablesOptions::new().with_db("default_db");
        let out = qualified_sql("SELECT * FROM other_db.users", &opts);
        // An explicit schema wins over the configured default.
        assert!(out.contains("other_db"));
        assert!(!out.contains("default_db"));
    }

    #[test]
    fn test_ensure_table_alias() {
        let opts = QualifyTablesOptions::new();
        let out = qualified_sql("SELECT * FROM users", &opts);
        assert!(out.contains("AS") || out.to_lowercase().contains(" users"));
    }

    #[test]
    fn test_canonical_aliases() {
        let opts = QualifyTablesOptions::new().with_canonical_aliases();
        let out = qualified_sql("SELECT u.id FROM users u", &opts);
        // The first generated alias is expected to show up in the output.
        assert!(out.contains("_0"));
    }

    #[test]
    fn test_qualify_join() {
        let opts = QualifyTablesOptions::new().with_db("mydb");
        let out = qualified_sql(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
            &opts,
        );
        assert!(out.contains("mydb"));
    }

    #[test]
    fn test_dont_qualify_cte() {
        let opts = QualifyTablesOptions::new().with_db("mydb");
        let out = qualified_sql("WITH cte AS (SELECT 1) SELECT * FROM cte", &opts);
        assert!(out.contains("cte"));
    }

    #[test]
    fn test_qualify_subquery() {
        let opts = QualifyTablesOptions::new().with_db("mydb");
        let out = qualified_sql("SELECT * FROM (SELECT * FROM users) AS sub", &opts);
        // The table inside the derived table must be qualified as well.
        assert!(out.contains("mydb"));
    }
}