// polyglot_sql/optimizer/qualify_tables.rs

//! Table Qualification Module
//!
//! This module qualifies table references in SQL queries with their
//! database and catalog names, and ensures every table source has an alias.
//!
//! Ported from sqlglot's `optimizer/qualify_tables.py`.
7
8use crate::dialects::DialectType;
9use crate::expressions::{Expression, Identifier, Null, Select, TableRef};
10use crate::helper::name_sequence;
11use crate::optimizer::normalize_identifiers::{
12    get_normalization_strategy, normalize_identifier, NormalizationStrategy,
13};
14use std::collections::{HashMap, HashSet};
15
16/// Options for table qualification
17#[derive(Debug, Clone, Default)]
18pub struct QualifyTablesOptions {
19    /// Default database name to add to unqualified tables
20    pub db: Option<String>,
21    /// Default catalog name to add to tables that have a db but no catalog
22    pub catalog: Option<String>,
23    /// The dialect to use for normalization
24    pub dialect: Option<DialectType>,
25    /// Whether to use canonical aliases (_0, _1, ...) instead of table names
26    pub canonicalize_table_aliases: bool,
27}
28
29impl QualifyTablesOptions {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    pub fn with_db(mut self, db: impl Into<String>) -> Self {
35        self.db = Some(db.into());
36        self
37    }
38
39    pub fn with_catalog(mut self, catalog: impl Into<String>) -> Self {
40        self.catalog = Some(catalog.into());
41        self
42    }
43
44    pub fn with_dialect(mut self, dialect: DialectType) -> Self {
45        self.dialect = Some(dialect);
46        self
47    }
48
49    pub fn with_canonical_aliases(mut self) -> Self {
50        self.canonicalize_table_aliases = true;
51        self
52    }
53}
54
55/// Rewrite SQL AST to have fully qualified tables.
56///
57/// This function:
58/// - Adds database/catalog prefixes to table references
59/// - Ensures all tables have aliases
60/// - Optionally canonicalizes aliases to _0, _1, etc.
61///
62/// # Examples
63///
64/// ```ignore
65/// // SELECT 1 FROM tbl -> SELECT 1 FROM db.tbl AS tbl
66/// let options = QualifyTablesOptions::new().with_db("db");
67/// let qualified = qualify_tables(expression, &options);
68/// ```
69///
70/// # Arguments
71/// * `expression` - The expression to qualify
72/// * `options` - Qualification options
73///
74/// # Returns
75/// The qualified expression
76pub fn qualify_tables(expression: Expression, options: &QualifyTablesOptions) -> Expression {
77    let strategy = get_normalization_strategy(options.dialect);
78    let mut next_alias = name_sequence("_");
79
80    match expression {
81        Expression::Select(select) => {
82            let qualified = qualify_select(*select, options, strategy, &mut next_alias);
83            Expression::Select(Box::new(qualified))
84        }
85        Expression::Union(mut union) => {
86            let left = std::mem::replace(&mut union.left, Expression::Null(Null));
87            union.left = qualify_tables(left, options);
88            let right = std::mem::replace(&mut union.right, Expression::Null(Null));
89            union.right = qualify_tables(right, options);
90            Expression::Union(union)
91        }
92        Expression::Intersect(mut intersect) => {
93            let left = std::mem::replace(&mut intersect.left, Expression::Null(Null));
94            intersect.left = qualify_tables(left, options);
95            let right = std::mem::replace(&mut intersect.right, Expression::Null(Null));
96            intersect.right = qualify_tables(right, options);
97            Expression::Intersect(intersect)
98        }
99        Expression::Except(mut except) => {
100            let left = std::mem::replace(&mut except.left, Expression::Null(Null));
101            except.left = qualify_tables(left, options);
102            let right = std::mem::replace(&mut except.right, Expression::Null(Null));
103            except.right = qualify_tables(right, options);
104            Expression::Except(except)
105        }
106        _ => expression,
107    }
108}
109
110/// Qualify a SELECT expression
111fn qualify_select(
112    mut select: Select,
113    options: &QualifyTablesOptions,
114    strategy: NormalizationStrategy,
115    next_alias: &mut impl FnMut() -> String,
116) -> Select {
117    // Collect CTE names to avoid qualifying them
118    let cte_names: HashSet<String> = select
119        .with
120        .as_ref()
121        .map(|w| w.ctes.iter().map(|c| c.alias.name.clone()).collect())
122        .unwrap_or_default();
123
124    // Track canonical aliases if needed
125    let mut canonical_aliases: HashMap<String, String> = HashMap::new();
126
127    // Qualify CTEs first
128    if let Some(ref mut with) = select.with {
129        for cte in &mut with.ctes {
130            cte.this = qualify_tables(cte.this.clone(), options);
131        }
132    }
133
134    // Qualify tables in FROM clause
135    if let Some(ref mut from) = select.from {
136        for expr in &mut from.expressions {
137            *expr = qualify_table_expression(
138                expr.clone(),
139                options,
140                strategy,
141                &cte_names,
142                &mut canonical_aliases,
143                next_alias,
144            );
145        }
146    }
147
148    // Qualify tables in JOINs
149    for join in &mut select.joins {
150        join.this = qualify_table_expression(
151            join.this.clone(),
152            options,
153            strategy,
154            &cte_names,
155            &mut canonical_aliases,
156            next_alias,
157        );
158    }
159
160    // Update column references if using canonical aliases
161    if options.canonicalize_table_aliases && !canonical_aliases.is_empty() {
162        select = update_column_references(select, &canonical_aliases);
163    }
164
165    select
166}
167
168/// Qualify a table expression (Table, Subquery, etc.)
169fn qualify_table_expression(
170    expression: Expression,
171    options: &QualifyTablesOptions,
172    strategy: NormalizationStrategy,
173    cte_names: &HashSet<String>,
174    canonical_aliases: &mut HashMap<String, String>,
175    next_alias: &mut impl FnMut() -> String,
176) -> Expression {
177    match expression {
178        Expression::Table(mut table) => {
179            let table_name = table.name.name.clone();
180
181            // Don't qualify CTEs
182            if cte_names.contains(&table_name) {
183                // Still ensure it has an alias
184                ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
185                return Expression::Table(table);
186            }
187
188            // Add db if specified and not already present
189            if let Some(ref db) = options.db {
190                if table.schema.is_none() {
191                    table.schema =
192                        Some(normalize_identifier(Identifier::new(db.clone()), strategy));
193                }
194            }
195
196            // Add catalog if specified, db is present, and catalog not already present
197            if let Some(ref catalog) = options.catalog {
198                if table.schema.is_some() && table.catalog.is_none() {
199                    table.catalog = Some(normalize_identifier(
200                        Identifier::new(catalog.clone()),
201                        strategy,
202                    ));
203                }
204            }
205
206            // Ensure the table has an alias
207            ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
208
209            Expression::Table(table)
210        }
211        Expression::Subquery(mut subquery) => {
212            // Qualify the inner query
213            subquery.this = qualify_tables(subquery.this, options);
214
215            // Ensure the subquery has an alias
216            if subquery.alias.is_none() || options.canonicalize_table_aliases {
217                let alias_name = if options.canonicalize_table_aliases {
218                    let new_name = next_alias();
219                    if let Some(ref old_alias) = subquery.alias {
220                        canonical_aliases.insert(old_alias.name.clone(), new_name.clone());
221                    }
222                    new_name
223                } else {
224                    subquery
225                        .alias
226                        .as_ref()
227                        .map(|a| a.name.clone())
228                        .unwrap_or_else(|| next_alias())
229                };
230
231                subquery.alias = Some(normalize_identifier(Identifier::new(alias_name), strategy));
232            }
233
234            Expression::Subquery(subquery)
235        }
236        Expression::Paren(mut paren) => {
237            paren.this = qualify_table_expression(
238                paren.this,
239                options,
240                strategy,
241                cte_names,
242                canonical_aliases,
243                next_alias,
244            );
245            Expression::Paren(paren)
246        }
247        _ => expression,
248    }
249}
250
251/// Ensure a table has an alias
252fn ensure_table_alias(
253    table: &mut TableRef,
254    strategy: NormalizationStrategy,
255    canonical_aliases: &mut HashMap<String, String>,
256    next_alias: &mut impl FnMut() -> String,
257    options: &QualifyTablesOptions,
258) {
259    let table_name = table.name.name.clone();
260
261    if options.canonicalize_table_aliases {
262        // Use canonical alias (_0, _1, etc.)
263        let new_alias = next_alias();
264        let old_alias = table
265            .alias
266            .as_ref()
267            .map(|a| a.name.clone())
268            .unwrap_or(table_name.clone());
269        canonical_aliases.insert(old_alias, new_alias.clone());
270        table.alias = Some(normalize_identifier(Identifier::new(new_alias), strategy));
271    } else if table.alias.is_none() {
272        // Use table name as alias
273        table.alias = Some(normalize_identifier(Identifier::new(table_name), strategy));
274    }
275}
276
277/// Update column references to use canonical aliases
278fn update_column_references(
279    mut select: Select,
280    canonical_aliases: &HashMap<String, String>,
281) -> Select {
282    // Update SELECT expressions
283    select.expressions = select
284        .expressions
285        .into_iter()
286        .map(|e| update_column_in_expression(e, canonical_aliases))
287        .collect();
288
289    // Update WHERE
290    if let Some(mut where_clause) = select.where_clause {
291        where_clause.this = update_column_in_expression(where_clause.this, canonical_aliases);
292        select.where_clause = Some(where_clause);
293    }
294
295    // Update GROUP BY
296    if let Some(mut group_by) = select.group_by {
297        group_by.expressions = group_by
298            .expressions
299            .into_iter()
300            .map(|e| update_column_in_expression(e, canonical_aliases))
301            .collect();
302        select.group_by = Some(group_by);
303    }
304
305    // Update HAVING
306    if let Some(mut having) = select.having {
307        having.this = update_column_in_expression(having.this, canonical_aliases);
308        select.having = Some(having);
309    }
310
311    // Update ORDER BY
312    if let Some(mut order_by) = select.order_by {
313        order_by.expressions = order_by
314            .expressions
315            .into_iter()
316            .map(|mut o| {
317                o.this = update_column_in_expression(o.this, canonical_aliases);
318                o
319            })
320            .collect();
321        select.order_by = Some(order_by);
322    }
323
324    // Update JOIN ON conditions
325    for join in &mut select.joins {
326        if let Some(on) = &mut join.on {
327            *on = update_column_in_expression(on.clone(), canonical_aliases);
328        }
329    }
330
331    select
332}
333
334/// Update column references in an expression
335fn update_column_in_expression(
336    expression: Expression,
337    canonical_aliases: &HashMap<String, String>,
338) -> Expression {
339    match expression {
340        Expression::Column(mut col) => {
341            if let Some(ref table) = col.table {
342                if let Some(canonical) = canonical_aliases.get(&table.name) {
343                    col.table = Some(Identifier {
344                        name: canonical.clone(),
345                        quoted: table.quoted,
346                        trailing_comments: table.trailing_comments.clone(),
347                        span: None,
348                    });
349                }
350            }
351            Expression::Column(col)
352        }
353        Expression::And(mut bin) => {
354            bin.left = update_column_in_expression(bin.left, canonical_aliases);
355            bin.right = update_column_in_expression(bin.right, canonical_aliases);
356            Expression::And(bin)
357        }
358        Expression::Or(mut bin) => {
359            bin.left = update_column_in_expression(bin.left, canonical_aliases);
360            bin.right = update_column_in_expression(bin.right, canonical_aliases);
361            Expression::Or(bin)
362        }
363        Expression::Eq(mut bin) => {
364            bin.left = update_column_in_expression(bin.left, canonical_aliases);
365            bin.right = update_column_in_expression(bin.right, canonical_aliases);
366            Expression::Eq(bin)
367        }
368        Expression::Neq(mut bin) => {
369            bin.left = update_column_in_expression(bin.left, canonical_aliases);
370            bin.right = update_column_in_expression(bin.right, canonical_aliases);
371            Expression::Neq(bin)
372        }
373        Expression::Lt(mut bin) => {
374            bin.left = update_column_in_expression(bin.left, canonical_aliases);
375            bin.right = update_column_in_expression(bin.right, canonical_aliases);
376            Expression::Lt(bin)
377        }
378        Expression::Lte(mut bin) => {
379            bin.left = update_column_in_expression(bin.left, canonical_aliases);
380            bin.right = update_column_in_expression(bin.right, canonical_aliases);
381            Expression::Lte(bin)
382        }
383        Expression::Gt(mut bin) => {
384            bin.left = update_column_in_expression(bin.left, canonical_aliases);
385            bin.right = update_column_in_expression(bin.right, canonical_aliases);
386            Expression::Gt(bin)
387        }
388        Expression::Gte(mut bin) => {
389            bin.left = update_column_in_expression(bin.left, canonical_aliases);
390            bin.right = update_column_in_expression(bin.right, canonical_aliases);
391            Expression::Gte(bin)
392        }
393        Expression::Not(mut un) => {
394            un.this = update_column_in_expression(un.this, canonical_aliases);
395            Expression::Not(un)
396        }
397        Expression::Paren(mut paren) => {
398            paren.this = update_column_in_expression(paren.this, canonical_aliases);
399            Expression::Paren(paren)
400        }
401        Expression::Alias(mut alias) => {
402            alias.this = update_column_in_expression(alias.this, canonical_aliases);
403            Expression::Alias(alias)
404        }
405        Expression::Function(mut func) => {
406            func.args = func
407                .args
408                .into_iter()
409                .map(|a| update_column_in_expression(a, canonical_aliases))
410                .collect();
411            Expression::Function(func)
412        }
413        Expression::AggregateFunction(mut agg) => {
414            agg.args = agg
415                .args
416                .into_iter()
417                .map(|a| update_column_in_expression(a, canonical_aliases))
418                .collect();
419            Expression::AggregateFunction(agg)
420        }
421        Expression::Case(mut case) => {
422            case.operand = case
423                .operand
424                .map(|o| update_column_in_expression(o, canonical_aliases));
425            case.whens = case
426                .whens
427                .into_iter()
428                .map(|(w, t)| {
429                    (
430                        update_column_in_expression(w, canonical_aliases),
431                        update_column_in_expression(t, canonical_aliases),
432                    )
433                })
434                .collect();
435            case.else_ = case
436                .else_
437                .map(|e| update_column_in_expression(e, canonical_aliases));
438            Expression::Case(case)
439        }
440        _ => expression,
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text with the default generator.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    #[test]
    fn test_qualify_with_db() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        let sql = gen(&qualified);
        // Should contain mydb.users
        assert!(sql.contains("mydb") && sql.contains("users"));
    }

    #[test]
    fn test_qualify_with_db_and_catalog() {
        let options = QualifyTablesOptions::new()
            .with_db("mydb")
            .with_catalog("mycatalog");
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        let sql = gen(&qualified);
        // Should contain mycatalog.mydb.users
        assert!(sql.contains("mycatalog") && sql.contains("mydb") && sql.contains("users"));
    }

    #[test]
    fn test_preserve_existing_schema() {
        let options = QualifyTablesOptions::new().with_db("default_db");
        let expr = parse("SELECT * FROM other_db.users");
        let qualified = qualify_tables(expr, &options);
        let sql = gen(&qualified);
        // Should preserve other_db, not add default_db
        assert!(sql.contains("other_db"));
        assert!(!sql.contains("default_db"));
    }

    #[test]
    fn test_ensure_table_alias() {
        let options = QualifyTablesOptions::new();
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        let sql = gen(&qualified);
        // Should have alias (AS users)
        // NOTE(review): the second disjunct is true for any rendering of
        // `FROM users`, so this cannot fail when the alias is missing. Tighten
        // it once the generator's exact alias output is confirmed.
        assert!(sql.contains("AS") || sql.to_lowercase().contains(" users"));
    }

    #[test]
    fn test_canonical_aliases() {
        let options = QualifyTablesOptions::new().with_canonical_aliases();
        let expr = parse("SELECT u.id FROM users u");
        let qualified = qualify_tables(expr, &options);
        let sql = gen(&qualified);
        // Should use canonical alias like _0
        assert!(sql.contains("_0"));
    }

    #[test]
    fn test_qualify_join() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let qualified = qualify_tables(expr, &options);
        let sql = gen(&qualified);
        // Both tables should be qualified, so the db name must appear at
        // least once per table.
        assert!(
            sql.matches("mydb").count() >= 2,
            "expected both joined tables to be qualified, got: {sql}"
        );
    }

    #[test]
    fn test_dont_qualify_cte() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("WITH cte AS (SELECT 1) SELECT * FROM cte");
        let qualified = qualify_tables(expr, &options);
        let sql = gen(&qualified);
        assert!(sql.contains("cte"));
        // The CTE body reads no tables and the only FROM source is the CTE
        // itself, so the default db must not appear anywhere.
        assert!(
            !sql.contains("mydb"),
            "CTE reference must not be qualified, got: {sql}"
        );
    }

    #[test]
    fn test_qualify_subquery() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM (SELECT * FROM users) AS sub");
        let qualified = qualify_tables(expr, &options);
        let sql = gen(&qualified);
        // Inner table should be qualified
        assert!(sql.contains("mydb"));
    }
}
541}