Skip to main content

polyglot_sql/optimizer/
qualify_tables.rs

1//! Table Qualification Module
2//!
3//! This module provides functionality for qualifying table references in SQL queries
4//! with their database and catalog names.
5//!
6//! Ported from sqlglot's optimizer/qualify_tables.py
7
8use crate::dialects::DialectType;
9use crate::expressions::{Expression, Identifier, Select, TableRef};
10use crate::helper::name_sequence;
11use crate::optimizer::normalize_identifiers::{
12    get_normalization_strategy, normalize_identifier, NormalizationStrategy,
13};
14use std::collections::{HashMap, HashSet};
15
16/// Options for table qualification
17#[derive(Debug, Clone, Default)]
18pub struct QualifyTablesOptions {
19    /// Default database name to add to unqualified tables
20    pub db: Option<String>,
21    /// Default catalog name to add to tables that have a db but no catalog
22    pub catalog: Option<String>,
23    /// The dialect to use for normalization
24    pub dialect: Option<DialectType>,
25    /// Whether to use canonical aliases (_0, _1, ...) instead of table names
26    pub canonicalize_table_aliases: bool,
27}
28
29impl QualifyTablesOptions {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    pub fn with_db(mut self, db: impl Into<String>) -> Self {
35        self.db = Some(db.into());
36        self
37    }
38
39    pub fn with_catalog(mut self, catalog: impl Into<String>) -> Self {
40        self.catalog = Some(catalog.into());
41        self
42    }
43
44    pub fn with_dialect(mut self, dialect: DialectType) -> Self {
45        self.dialect = Some(dialect);
46        self
47    }
48
49    pub fn with_canonical_aliases(mut self) -> Self {
50        self.canonicalize_table_aliases = true;
51        self
52    }
53}
54
55/// Rewrite SQL AST to have fully qualified tables.
56///
57/// This function:
58/// - Adds database/catalog prefixes to table references
59/// - Ensures all tables have aliases
60/// - Optionally canonicalizes aliases to _0, _1, etc.
61///
62/// # Examples
63///
64/// ```ignore
65/// // SELECT 1 FROM tbl -> SELECT 1 FROM db.tbl AS tbl
66/// let options = QualifyTablesOptions::new().with_db("db");
67/// let qualified = qualify_tables(expression, &options);
68/// ```
69///
70/// # Arguments
71/// * `expression` - The expression to qualify
72/// * `options` - Qualification options
73///
74/// # Returns
75/// The qualified expression
76pub fn qualify_tables(expression: Expression, options: &QualifyTablesOptions) -> Expression {
77    let strategy = get_normalization_strategy(options.dialect);
78    let mut next_alias = name_sequence("_");
79
80    match expression {
81        Expression::Select(select) => {
82            let qualified = qualify_select(*select, options, strategy, &mut next_alias);
83            Expression::Select(Box::new(qualified))
84        }
85        Expression::Union(mut union) => {
86            union.left = qualify_tables(union.left, options);
87            union.right = qualify_tables(union.right, options);
88            Expression::Union(union)
89        }
90        Expression::Intersect(mut intersect) => {
91            intersect.left = qualify_tables(intersect.left, options);
92            intersect.right = qualify_tables(intersect.right, options);
93            Expression::Intersect(intersect)
94        }
95        Expression::Except(mut except) => {
96            except.left = qualify_tables(except.left, options);
97            except.right = qualify_tables(except.right, options);
98            Expression::Except(except)
99        }
100        _ => expression,
101    }
102}
103
104/// Qualify a SELECT expression
105fn qualify_select(
106    mut select: Select,
107    options: &QualifyTablesOptions,
108    strategy: NormalizationStrategy,
109    next_alias: &mut impl FnMut() -> String,
110) -> Select {
111    // Collect CTE names to avoid qualifying them
112    let cte_names: HashSet<String> = select
113        .with
114        .as_ref()
115        .map(|w| w.ctes.iter().map(|c| c.alias.name.clone()).collect())
116        .unwrap_or_default();
117
118    // Track canonical aliases if needed
119    let mut canonical_aliases: HashMap<String, String> = HashMap::new();
120
121    // Qualify CTEs first
122    if let Some(ref mut with) = select.with {
123        for cte in &mut with.ctes {
124            cte.this = qualify_tables(cte.this.clone(), options);
125        }
126    }
127
128    // Qualify tables in FROM clause
129    if let Some(ref mut from) = select.from {
130        for expr in &mut from.expressions {
131            *expr = qualify_table_expression(
132                expr.clone(),
133                options,
134                strategy,
135                &cte_names,
136                &mut canonical_aliases,
137                next_alias,
138            );
139        }
140    }
141
142    // Qualify tables in JOINs
143    for join in &mut select.joins {
144        join.this = qualify_table_expression(
145            join.this.clone(),
146            options,
147            strategy,
148            &cte_names,
149            &mut canonical_aliases,
150            next_alias,
151        );
152    }
153
154    // Update column references if using canonical aliases
155    if options.canonicalize_table_aliases && !canonical_aliases.is_empty() {
156        select = update_column_references(select, &canonical_aliases);
157    }
158
159    select
160}
161
162/// Qualify a table expression (Table, Subquery, etc.)
163fn qualify_table_expression(
164    expression: Expression,
165    options: &QualifyTablesOptions,
166    strategy: NormalizationStrategy,
167    cte_names: &HashSet<String>,
168    canonical_aliases: &mut HashMap<String, String>,
169    next_alias: &mut impl FnMut() -> String,
170) -> Expression {
171    match expression {
172        Expression::Table(mut table) => {
173            let table_name = table.name.name.clone();
174
175            // Don't qualify CTEs
176            if cte_names.contains(&table_name) {
177                // Still ensure it has an alias
178                ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
179                return Expression::Table(table);
180            }
181
182            // Add db if specified and not already present
183            if let Some(ref db) = options.db {
184                if table.schema.is_none() {
185                    table.schema =
186                        Some(normalize_identifier(Identifier::new(db.clone()), strategy));
187                }
188            }
189
190            // Add catalog if specified, db is present, and catalog not already present
191            if let Some(ref catalog) = options.catalog {
192                if table.schema.is_some() && table.catalog.is_none() {
193                    table.catalog = Some(normalize_identifier(
194                        Identifier::new(catalog.clone()),
195                        strategy,
196                    ));
197                }
198            }
199
200            // Ensure the table has an alias
201            ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
202
203            Expression::Table(table)
204        }
205        Expression::Subquery(mut subquery) => {
206            // Qualify the inner query
207            subquery.this = qualify_tables(subquery.this, options);
208
209            // Ensure the subquery has an alias
210            if subquery.alias.is_none() || options.canonicalize_table_aliases {
211                let alias_name = if options.canonicalize_table_aliases {
212                    let new_name = next_alias();
213                    if let Some(ref old_alias) = subquery.alias {
214                        canonical_aliases.insert(old_alias.name.clone(), new_name.clone());
215                    }
216                    new_name
217                } else {
218                    subquery
219                        .alias
220                        .as_ref()
221                        .map(|a| a.name.clone())
222                        .unwrap_or_else(|| next_alias())
223                };
224
225                subquery.alias = Some(normalize_identifier(Identifier::new(alias_name), strategy));
226            }
227
228            Expression::Subquery(subquery)
229        }
230        Expression::Paren(mut paren) => {
231            paren.this = qualify_table_expression(
232                paren.this,
233                options,
234                strategy,
235                cte_names,
236                canonical_aliases,
237                next_alias,
238            );
239            Expression::Paren(paren)
240        }
241        _ => expression,
242    }
243}
244
245/// Ensure a table has an alias
246fn ensure_table_alias(
247    table: &mut TableRef,
248    strategy: NormalizationStrategy,
249    canonical_aliases: &mut HashMap<String, String>,
250    next_alias: &mut impl FnMut() -> String,
251    options: &QualifyTablesOptions,
252) {
253    let table_name = table.name.name.clone();
254
255    if options.canonicalize_table_aliases {
256        // Use canonical alias (_0, _1, etc.)
257        let new_alias = next_alias();
258        let old_alias = table
259            .alias
260            .as_ref()
261            .map(|a| a.name.clone())
262            .unwrap_or(table_name.clone());
263        canonical_aliases.insert(old_alias, new_alias.clone());
264        table.alias = Some(normalize_identifier(Identifier::new(new_alias), strategy));
265    } else if table.alias.is_none() {
266        // Use table name as alias
267        table.alias = Some(normalize_identifier(Identifier::new(table_name), strategy));
268    }
269}
270
271/// Update column references to use canonical aliases
272fn update_column_references(
273    mut select: Select,
274    canonical_aliases: &HashMap<String, String>,
275) -> Select {
276    // Update SELECT expressions
277    select.expressions = select
278        .expressions
279        .into_iter()
280        .map(|e| update_column_in_expression(e, canonical_aliases))
281        .collect();
282
283    // Update WHERE
284    if let Some(mut where_clause) = select.where_clause {
285        where_clause.this = update_column_in_expression(where_clause.this, canonical_aliases);
286        select.where_clause = Some(where_clause);
287    }
288
289    // Update GROUP BY
290    if let Some(mut group_by) = select.group_by {
291        group_by.expressions = group_by
292            .expressions
293            .into_iter()
294            .map(|e| update_column_in_expression(e, canonical_aliases))
295            .collect();
296        select.group_by = Some(group_by);
297    }
298
299    // Update HAVING
300    if let Some(mut having) = select.having {
301        having.this = update_column_in_expression(having.this, canonical_aliases);
302        select.having = Some(having);
303    }
304
305    // Update ORDER BY
306    if let Some(mut order_by) = select.order_by {
307        order_by.expressions = order_by
308            .expressions
309            .into_iter()
310            .map(|mut o| {
311                o.this = update_column_in_expression(o.this, canonical_aliases);
312                o
313            })
314            .collect();
315        select.order_by = Some(order_by);
316    }
317
318    // Update JOIN ON conditions
319    for join in &mut select.joins {
320        if let Some(on) = &mut join.on {
321            *on = update_column_in_expression(on.clone(), canonical_aliases);
322        }
323    }
324
325    select
326}
327
328/// Update column references in an expression
329fn update_column_in_expression(
330    expression: Expression,
331    canonical_aliases: &HashMap<String, String>,
332) -> Expression {
333    match expression {
334        Expression::Column(mut col) => {
335            if let Some(ref table) = col.table {
336                if let Some(canonical) = canonical_aliases.get(&table.name) {
337                    col.table = Some(Identifier {
338                        name: canonical.clone(),
339                        quoted: table.quoted,
340                        trailing_comments: table.trailing_comments.clone(),
341                    });
342                }
343            }
344            Expression::Column(col)
345        }
346        Expression::And(mut bin) => {
347            bin.left = update_column_in_expression(bin.left, canonical_aliases);
348            bin.right = update_column_in_expression(bin.right, canonical_aliases);
349            Expression::And(bin)
350        }
351        Expression::Or(mut bin) => {
352            bin.left = update_column_in_expression(bin.left, canonical_aliases);
353            bin.right = update_column_in_expression(bin.right, canonical_aliases);
354            Expression::Or(bin)
355        }
356        Expression::Eq(mut bin) => {
357            bin.left = update_column_in_expression(bin.left, canonical_aliases);
358            bin.right = update_column_in_expression(bin.right, canonical_aliases);
359            Expression::Eq(bin)
360        }
361        Expression::Neq(mut bin) => {
362            bin.left = update_column_in_expression(bin.left, canonical_aliases);
363            bin.right = update_column_in_expression(bin.right, canonical_aliases);
364            Expression::Neq(bin)
365        }
366        Expression::Lt(mut bin) => {
367            bin.left = update_column_in_expression(bin.left, canonical_aliases);
368            bin.right = update_column_in_expression(bin.right, canonical_aliases);
369            Expression::Lt(bin)
370        }
371        Expression::Lte(mut bin) => {
372            bin.left = update_column_in_expression(bin.left, canonical_aliases);
373            bin.right = update_column_in_expression(bin.right, canonical_aliases);
374            Expression::Lte(bin)
375        }
376        Expression::Gt(mut bin) => {
377            bin.left = update_column_in_expression(bin.left, canonical_aliases);
378            bin.right = update_column_in_expression(bin.right, canonical_aliases);
379            Expression::Gt(bin)
380        }
381        Expression::Gte(mut bin) => {
382            bin.left = update_column_in_expression(bin.left, canonical_aliases);
383            bin.right = update_column_in_expression(bin.right, canonical_aliases);
384            Expression::Gte(bin)
385        }
386        Expression::Not(mut un) => {
387            un.this = update_column_in_expression(un.this, canonical_aliases);
388            Expression::Not(un)
389        }
390        Expression::Paren(mut paren) => {
391            paren.this = update_column_in_expression(paren.this, canonical_aliases);
392            Expression::Paren(paren)
393        }
394        Expression::Alias(mut alias) => {
395            alias.this = update_column_in_expression(alias.this, canonical_aliases);
396            Expression::Alias(alias)
397        }
398        Expression::Function(mut func) => {
399            func.args = func
400                .args
401                .into_iter()
402                .map(|a| update_column_in_expression(a, canonical_aliases))
403                .collect();
404            Expression::Function(func)
405        }
406        Expression::AggregateFunction(mut agg) => {
407            agg.args = agg
408                .args
409                .into_iter()
410                .map(|a| update_column_in_expression(a, canonical_aliases))
411                .collect();
412            Expression::AggregateFunction(agg)
413        }
414        Expression::Case(mut case) => {
415            case.operand = case
416                .operand
417                .map(|o| update_column_in_expression(o, canonical_aliases));
418            case.whens = case
419                .whens
420                .into_iter()
421                .map(|(w, t)| {
422                    (
423                        update_column_in_expression(w, canonical_aliases),
424                        update_column_in_expression(t, canonical_aliases),
425                    )
426                })
427                .collect();
428            case.else_ = case
429                .else_
430                .map(|e| update_column_in_expression(e, canonical_aliases));
431            Expression::Case(case)
432        }
433        _ => expression,
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL. Named `to_sql` rather than `gen`,
    /// which is a reserved keyword from Rust 2024 on.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("no statement parsed")
    }

    #[test]
    fn test_qualify_with_db() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should contain mydb.users
        assert!(sql.contains("mydb") && sql.contains("users"));
    }

    #[test]
    fn test_qualify_with_db_and_catalog() {
        let options = QualifyTablesOptions::new()
            .with_db("mydb")
            .with_catalog("mycatalog");
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should contain mycatalog.mydb.users
        assert!(sql.contains("mycatalog") && sql.contains("mydb") && sql.contains("users"));
    }

    #[test]
    fn test_preserve_existing_schema() {
        let options = QualifyTablesOptions::new().with_db("default_db");
        let expr = parse("SELECT * FROM other_db.users");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should preserve other_db, not add default_db
        assert!(sql.contains("other_db"));
        assert!(!sql.contains("default_db"));
    }

    #[test]
    fn test_ensure_table_alias() {
        let options = QualifyTablesOptions::new();
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // The input has no alias, so an " AS " can only come from the added alias.
        assert!(sql.contains(" AS "), "missing table alias in: {sql}");
    }

    #[test]
    fn test_canonical_aliases() {
        let options = QualifyTablesOptions::new().with_canonical_aliases();
        let expr = parse("SELECT u.id FROM users u");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // The table alias and the column qualifier both become _0.
        assert!(sql.contains("_0"));
        assert!(sql.contains("_0.id"), "column qualifier not rewritten in: {sql}");
    }

    #[test]
    fn test_qualify_join() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Both tables should be qualified, not just the first one.
        assert_eq!(sql.matches("mydb").count(), 2, "unexpected qualification in: {sql}");
    }

    #[test]
    fn test_dont_qualify_cte() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("WITH cte AS (SELECT 1) SELECT * FROM cte");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // The CTE body has no tables and the reference to `cte` must not be
        // qualified, so the default db must not appear anywhere.
        assert!(sql.contains("cte"));
        assert!(!sql.contains("mydb"), "CTE reference was qualified in: {sql}");
    }

    #[test]
    fn test_qualify_subquery() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM (SELECT * FROM users) AS sub");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Inner table should be qualified
        assert!(sql.contains("mydb"));
    }
}