// Source: polyglot_sql/optimizer/qualify_tables.rs

1//! Table Qualification Module
2//!
3//! This module provides functionality for qualifying table references in SQL queries
4//! with their database and catalog names.
5//!
6//! Ported from sqlglot's optimizer/qualify_tables.py
7
8use crate::dialects::DialectType;
9use crate::expressions::{Expression, Identifier, Select, TableRef};
10use crate::helper::name_sequence;
11use crate::optimizer::normalize_identifiers::{
12    get_normalization_strategy, normalize_identifier, NormalizationStrategy,
13};
14use std::collections::{HashMap, HashSet};
15
16/// Options for table qualification
17#[derive(Debug, Clone, Default)]
18pub struct QualifyTablesOptions {
19    /// Default database name to add to unqualified tables
20    pub db: Option<String>,
21    /// Default catalog name to add to tables that have a db but no catalog
22    pub catalog: Option<String>,
23    /// The dialect to use for normalization
24    pub dialect: Option<DialectType>,
25    /// Whether to use canonical aliases (_0, _1, ...) instead of table names
26    pub canonicalize_table_aliases: bool,
27}
28
29impl QualifyTablesOptions {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    pub fn with_db(mut self, db: impl Into<String>) -> Self {
35        self.db = Some(db.into());
36        self
37    }
38
39    pub fn with_catalog(mut self, catalog: impl Into<String>) -> Self {
40        self.catalog = Some(catalog.into());
41        self
42    }
43
44    pub fn with_dialect(mut self, dialect: DialectType) -> Self {
45        self.dialect = Some(dialect);
46        self
47    }
48
49    pub fn with_canonical_aliases(mut self) -> Self {
50        self.canonicalize_table_aliases = true;
51        self
52    }
53}
54
55/// Rewrite SQL AST to have fully qualified tables.
56///
57/// This function:
58/// - Adds database/catalog prefixes to table references
59/// - Ensures all tables have aliases
60/// - Optionally canonicalizes aliases to _0, _1, etc.
61///
62/// # Examples
63///
64/// ```ignore
65/// // SELECT 1 FROM tbl -> SELECT 1 FROM db.tbl AS tbl
66/// let options = QualifyTablesOptions::new().with_db("db");
67/// let qualified = qualify_tables(expression, &options);
68/// ```
69///
70/// # Arguments
71/// * `expression` - The expression to qualify
72/// * `options` - Qualification options
73///
74/// # Returns
75/// The qualified expression
76pub fn qualify_tables(expression: Expression, options: &QualifyTablesOptions) -> Expression {
77    let strategy = get_normalization_strategy(options.dialect);
78    let mut next_alias = name_sequence("_");
79
80    match expression {
81        Expression::Select(select) => {
82            let qualified = qualify_select(*select, options, strategy, &mut next_alias);
83            Expression::Select(Box::new(qualified))
84        }
85        Expression::Union(mut union) => {
86            union.left = qualify_tables(union.left, options);
87            union.right = qualify_tables(union.right, options);
88            Expression::Union(union)
89        }
90        Expression::Intersect(mut intersect) => {
91            intersect.left = qualify_tables(intersect.left, options);
92            intersect.right = qualify_tables(intersect.right, options);
93            Expression::Intersect(intersect)
94        }
95        Expression::Except(mut except) => {
96            except.left = qualify_tables(except.left, options);
97            except.right = qualify_tables(except.right, options);
98            Expression::Except(except)
99        }
100        _ => expression,
101    }
102}
103
104/// Qualify a SELECT expression
105fn qualify_select(
106    mut select: Select,
107    options: &QualifyTablesOptions,
108    strategy: NormalizationStrategy,
109    next_alias: &mut impl FnMut() -> String,
110) -> Select {
111    // Collect CTE names to avoid qualifying them
112    let cte_names: HashSet<String> = select
113        .with
114        .as_ref()
115        .map(|w| w.ctes.iter().map(|c| c.alias.name.clone()).collect())
116        .unwrap_or_default();
117
118    // Track canonical aliases if needed
119    let mut canonical_aliases: HashMap<String, String> = HashMap::new();
120
121    // Qualify CTEs first
122    if let Some(ref mut with) = select.with {
123        for cte in &mut with.ctes {
124            cte.this = qualify_tables(cte.this.clone(), options);
125        }
126    }
127
128    // Qualify tables in FROM clause
129    if let Some(ref mut from) = select.from {
130        for expr in &mut from.expressions {
131            *expr = qualify_table_expression(
132                expr.clone(),
133                options,
134                strategy,
135                &cte_names,
136                &mut canonical_aliases,
137                next_alias,
138            );
139        }
140    }
141
142    // Qualify tables in JOINs
143    for join in &mut select.joins {
144        join.this = qualify_table_expression(
145            join.this.clone(),
146            options,
147            strategy,
148            &cte_names,
149            &mut canonical_aliases,
150            next_alias,
151        );
152    }
153
154    // Update column references if using canonical aliases
155    if options.canonicalize_table_aliases && !canonical_aliases.is_empty() {
156        select = update_column_references(select, &canonical_aliases);
157    }
158
159    select
160}
161
162/// Qualify a table expression (Table, Subquery, etc.)
163fn qualify_table_expression(
164    expression: Expression,
165    options: &QualifyTablesOptions,
166    strategy: NormalizationStrategy,
167    cte_names: &HashSet<String>,
168    canonical_aliases: &mut HashMap<String, String>,
169    next_alias: &mut impl FnMut() -> String,
170) -> Expression {
171    match expression {
172        Expression::Table(mut table) => {
173            let table_name = table.name.name.clone();
174
175            // Don't qualify CTEs
176            if cte_names.contains(&table_name) {
177                // Still ensure it has an alias
178                ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
179                return Expression::Table(table);
180            }
181
182            // Add db if specified and not already present
183            if let Some(ref db) = options.db {
184                if table.schema.is_none() {
185                    table.schema = Some(normalize_identifier(
186                        Identifier::new(db.clone()),
187                        strategy,
188                    ));
189                }
190            }
191
192            // Add catalog if specified, db is present, and catalog not already present
193            if let Some(ref catalog) = options.catalog {
194                if table.schema.is_some() && table.catalog.is_none() {
195                    table.catalog = Some(normalize_identifier(
196                        Identifier::new(catalog.clone()),
197                        strategy,
198                    ));
199                }
200            }
201
202            // Ensure the table has an alias
203            ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
204
205            Expression::Table(table)
206        }
207        Expression::Subquery(mut subquery) => {
208            // Qualify the inner query
209            subquery.this = qualify_tables(subquery.this, options);
210
211            // Ensure the subquery has an alias
212            if subquery.alias.is_none() || options.canonicalize_table_aliases {
213                let alias_name = if options.canonicalize_table_aliases {
214                    let new_name = next_alias();
215                    if let Some(ref old_alias) = subquery.alias {
216                        canonical_aliases.insert(old_alias.name.clone(), new_name.clone());
217                    }
218                    new_name
219                } else {
220                    subquery
221                        .alias
222                        .as_ref()
223                        .map(|a| a.name.clone())
224                        .unwrap_or_else(|| next_alias())
225                };
226
227                subquery.alias = Some(normalize_identifier(Identifier::new(alias_name), strategy));
228            }
229
230            Expression::Subquery(subquery)
231        }
232        Expression::Paren(mut paren) => {
233            paren.this = qualify_table_expression(
234                paren.this,
235                options,
236                strategy,
237                cte_names,
238                canonical_aliases,
239                next_alias,
240            );
241            Expression::Paren(paren)
242        }
243        _ => expression,
244    }
245}
246
247/// Ensure a table has an alias
248fn ensure_table_alias(
249    table: &mut TableRef,
250    strategy: NormalizationStrategy,
251    canonical_aliases: &mut HashMap<String, String>,
252    next_alias: &mut impl FnMut() -> String,
253    options: &QualifyTablesOptions,
254) {
255    let table_name = table.name.name.clone();
256
257    if options.canonicalize_table_aliases {
258        // Use canonical alias (_0, _1, etc.)
259        let new_alias = next_alias();
260        let old_alias = table.alias.as_ref().map(|a| a.name.clone()).unwrap_or(table_name.clone());
261        canonical_aliases.insert(old_alias, new_alias.clone());
262        table.alias = Some(normalize_identifier(Identifier::new(new_alias), strategy));
263    } else if table.alias.is_none() {
264        // Use table name as alias
265        table.alias = Some(normalize_identifier(
266            Identifier::new(table_name),
267            strategy,
268        ));
269    }
270}
271
272/// Update column references to use canonical aliases
273fn update_column_references(mut select: Select, canonical_aliases: &HashMap<String, String>) -> Select {
274    // Update SELECT expressions
275    select.expressions = select
276        .expressions
277        .into_iter()
278        .map(|e| update_column_in_expression(e, canonical_aliases))
279        .collect();
280
281    // Update WHERE
282    if let Some(mut where_clause) = select.where_clause {
283        where_clause.this = update_column_in_expression(where_clause.this, canonical_aliases);
284        select.where_clause = Some(where_clause);
285    }
286
287    // Update GROUP BY
288    if let Some(mut group_by) = select.group_by {
289        group_by.expressions = group_by
290            .expressions
291            .into_iter()
292            .map(|e| update_column_in_expression(e, canonical_aliases))
293            .collect();
294        select.group_by = Some(group_by);
295    }
296
297    // Update HAVING
298    if let Some(mut having) = select.having {
299        having.this = update_column_in_expression(having.this, canonical_aliases);
300        select.having = Some(having);
301    }
302
303    // Update ORDER BY
304    if let Some(mut order_by) = select.order_by {
305        order_by.expressions = order_by
306            .expressions
307            .into_iter()
308            .map(|mut o| {
309                o.this = update_column_in_expression(o.this, canonical_aliases);
310                o
311            })
312            .collect();
313        select.order_by = Some(order_by);
314    }
315
316    // Update JOIN ON conditions
317    for join in &mut select.joins {
318        if let Some(on) = &mut join.on {
319            *on = update_column_in_expression(on.clone(), canonical_aliases);
320        }
321    }
322
323    select
324}
325
326/// Update column references in an expression
327fn update_column_in_expression(
328    expression: Expression,
329    canonical_aliases: &HashMap<String, String>,
330) -> Expression {
331    match expression {
332        Expression::Column(mut col) => {
333            if let Some(ref table) = col.table {
334                if let Some(canonical) = canonical_aliases.get(&table.name) {
335                    col.table = Some(Identifier {
336                        name: canonical.clone(),
337                        quoted: table.quoted,
338                        trailing_comments: table.trailing_comments.clone(),
339                    });
340                }
341            }
342            Expression::Column(col)
343        }
344        Expression::And(mut bin) => {
345            bin.left = update_column_in_expression(bin.left, canonical_aliases);
346            bin.right = update_column_in_expression(bin.right, canonical_aliases);
347            Expression::And(bin)
348        }
349        Expression::Or(mut bin) => {
350            bin.left = update_column_in_expression(bin.left, canonical_aliases);
351            bin.right = update_column_in_expression(bin.right, canonical_aliases);
352            Expression::Or(bin)
353        }
354        Expression::Eq(mut bin) => {
355            bin.left = update_column_in_expression(bin.left, canonical_aliases);
356            bin.right = update_column_in_expression(bin.right, canonical_aliases);
357            Expression::Eq(bin)
358        }
359        Expression::Neq(mut bin) => {
360            bin.left = update_column_in_expression(bin.left, canonical_aliases);
361            bin.right = update_column_in_expression(bin.right, canonical_aliases);
362            Expression::Neq(bin)
363        }
364        Expression::Lt(mut bin) => {
365            bin.left = update_column_in_expression(bin.left, canonical_aliases);
366            bin.right = update_column_in_expression(bin.right, canonical_aliases);
367            Expression::Lt(bin)
368        }
369        Expression::Lte(mut bin) => {
370            bin.left = update_column_in_expression(bin.left, canonical_aliases);
371            bin.right = update_column_in_expression(bin.right, canonical_aliases);
372            Expression::Lte(bin)
373        }
374        Expression::Gt(mut bin) => {
375            bin.left = update_column_in_expression(bin.left, canonical_aliases);
376            bin.right = update_column_in_expression(bin.right, canonical_aliases);
377            Expression::Gt(bin)
378        }
379        Expression::Gte(mut bin) => {
380            bin.left = update_column_in_expression(bin.left, canonical_aliases);
381            bin.right = update_column_in_expression(bin.right, canonical_aliases);
382            Expression::Gte(bin)
383        }
384        Expression::Not(mut un) => {
385            un.this = update_column_in_expression(un.this, canonical_aliases);
386            Expression::Not(un)
387        }
388        Expression::Paren(mut paren) => {
389            paren.this = update_column_in_expression(paren.this, canonical_aliases);
390            Expression::Paren(paren)
391        }
392        Expression::Alias(mut alias) => {
393            alias.this = update_column_in_expression(alias.this, canonical_aliases);
394            Expression::Alias(alias)
395        }
396        Expression::Function(mut func) => {
397            func.args = func
398                .args
399                .into_iter()
400                .map(|a| update_column_in_expression(a, canonical_aliases))
401                .collect();
402            Expression::Function(func)
403        }
404        Expression::AggregateFunction(mut agg) => {
405            agg.args = agg
406                .args
407                .into_iter()
408                .map(|a| update_column_in_expression(a, canonical_aliases))
409                .collect();
410            Expression::AggregateFunction(agg)
411        }
412        Expression::Case(mut case) => {
413            case.operand = case
414                .operand
415                .map(|o| update_column_in_expression(o, canonical_aliases));
416            case.whens = case
417                .whens
418                .into_iter()
419                .map(|(w, t)| {
420                    (
421                        update_column_in_expression(w, canonical_aliases),
422                        update_column_in_expression(t, canonical_aliases),
423                    )
424                })
425                .collect();
426            case.else_ = case
427                .else_
428                .map(|e| update_column_in_expression(e, canonical_aliases));
429            Expression::Case(case)
430        }
431        _ => expression,
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL with the default generator.
    // Named `to_sql` rather than `gen`, which is a reserved keyword in Rust 2024.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("Parser returned no statements")
    }

    #[test]
    fn test_qualify_with_db() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should contain mydb.users
        assert!(sql.contains("mydb") && sql.contains("users"));
    }

    #[test]
    fn test_qualify_with_db_and_catalog() {
        let options = QualifyTablesOptions::new()
            .with_db("mydb")
            .with_catalog("mycatalog");
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should contain mycatalog.mydb.users
        assert!(sql.contains("mycatalog") && sql.contains("mydb") && sql.contains("users"));
    }

    #[test]
    fn test_preserve_existing_schema() {
        let options = QualifyTablesOptions::new().with_db("default_db");
        let expr = parse("SELECT * FROM other_db.users");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should preserve other_db, not add default_db
        assert!(sql.contains("other_db"));
        assert!(!sql.contains("default_db"));
    }

    #[test]
    fn test_ensure_table_alias() {
        let options = QualifyTablesOptions::new();
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);

        // Inspect the AST: a substring check on the SQL cannot tell the table
        // name apart from an alias of the same name.
        let table = match &qualified {
            Expression::Select(select) => {
                match select.from.as_ref().and_then(|f| f.expressions.iter().next()) {
                    Some(Expression::Table(table)) => table,
                    _ => panic!("expected a table in the FROM clause"),
                }
            }
            _ => panic!("expected a SELECT expression"),
        };
        let alias = table.alias.as_ref().expect("table should have received an alias");
        // Case-insensitive because the alias goes through identifier normalization
        assert!(alias.name.eq_ignore_ascii_case("users"));
    }

    #[test]
    fn test_canonical_aliases() {
        let options = QualifyTablesOptions::new().with_canonical_aliases();
        let expr = parse("SELECT u.id FROM users u");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should use canonical alias like _0
        assert!(sql.contains("_0"));
        // The column qualifier must follow the renamed alias
        assert!(sql.contains("_0.id"));
        assert!(!sql.contains("u.id"));
    }

    #[test]
    fn test_qualify_join() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Both tables should be qualified, not just the FROM table
        assert!(sql.matches("mydb").count() >= 2);
    }

    #[test]
    fn test_dont_qualify_cte() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("WITH cte AS (SELECT 1) SELECT * FROM cte");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        assert!(sql.contains("cte"));
        // The query has no physical table, so the default db must not appear
        assert!(!sql.contains("mydb"));
    }

    #[test]
    fn test_qualify_subquery() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM (SELECT * FROM users) AS sub");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Inner table should be qualified
        assert!(sql.contains("mydb"));
    }
}