Skip to main content

polyglot_sql/optimizer/
normalize_identifiers.rs

1//! Identifier Normalization Module
2//!
3//! This module provides functionality for normalizing identifiers in SQL queries
4//! based on dialect-specific rules for case sensitivity and quoting.
5//!
6//! Ported from sqlglot's optimizer/normalize_identifiers.py
7
8use crate::dialects::DialectType;
9use crate::expressions::{Column, Expression, Identifier};
10
/// Strategy describing how a dialect resolves the case of identifiers.
///
/// The default strategy is [`NormalizationStrategy::Lowercase`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NormalizationStrategy {
    /// Unquoted identifiers are lowercased (e.g., PostgreSQL)
    #[default]
    Lowercase,
    /// Unquoted identifiers are uppercased (e.g., Oracle, Snowflake)
    Uppercase,
    /// Always case-sensitive, regardless of quotes (e.g., MySQL on Linux)
    CaseSensitive,
    /// Always case-insensitive (lowercase), regardless of quotes (e.g., Spark, BigQuery)
    CaseInsensitive,
    /// Always case-insensitive (uppercase), regardless of quotes
    CaseInsensitiveUppercase,
}
31
32/// Get the normalization strategy for a dialect
33pub fn get_normalization_strategy(dialect: Option<DialectType>) -> NormalizationStrategy {
34    match dialect {
35        // Uppercase dialects
36        Some(DialectType::Oracle) | Some(DialectType::Snowflake) | Some(DialectType::Exasol) => {
37            NormalizationStrategy::Uppercase
38        }
39        // Case-sensitive dialects
40        Some(DialectType::MySQL) | Some(DialectType::ClickHouse) => {
41            NormalizationStrategy::CaseSensitive
42        }
43        // Case-insensitive dialects (lowercase)
44        Some(DialectType::DuckDB)
45        | Some(DialectType::SQLite)
46        | Some(DialectType::BigQuery)
47        | Some(DialectType::Presto)
48        | Some(DialectType::Trino)
49        | Some(DialectType::Hive)
50        | Some(DialectType::Spark)
51        | Some(DialectType::Databricks)
52        | Some(DialectType::Redshift) => NormalizationStrategy::CaseInsensitive,
53        // Default: lowercase (PostgreSQL-like behavior)
54        _ => NormalizationStrategy::Lowercase,
55    }
56}
57
58/// Normalize identifiers in an expression based on dialect rules.
59///
60/// This transformation reflects how identifiers would be resolved by the engine
61/// corresponding to each SQL dialect. For example:
62/// - `FoO` → `foo` in PostgreSQL (lowercases unquoted)
63/// - `FoO` → `FOO` in Snowflake (uppercases unquoted)
64/// - `"FoO"` → `FoO` preserved when quoted (case-sensitive)
65///
66/// # Arguments
67/// * `expression` - The expression to normalize
68/// * `dialect` - The dialect to use for normalization rules
69///
70/// # Returns
71/// The expression with normalized identifiers
72pub fn normalize_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
73    let strategy = get_normalization_strategy(dialect);
74    normalize_expression(expression, strategy)
75}
76
77/// Normalize a single identifier based on the strategy.
78///
79/// # Arguments
80/// * `identifier` - The identifier to normalize
81/// * `strategy` - The normalization strategy to use
82///
83/// # Returns
84/// The normalized identifier
85pub fn normalize_identifier(identifier: Identifier, strategy: NormalizationStrategy) -> Identifier {
86    // Case-sensitive strategy: never normalize
87    if strategy == NormalizationStrategy::CaseSensitive {
88        return identifier;
89    }
90
91    // If quoted and not case-insensitive, don't normalize
92    if identifier.quoted
93        && strategy != NormalizationStrategy::CaseInsensitive
94        && strategy != NormalizationStrategy::CaseInsensitiveUppercase
95    {
96        return identifier;
97    }
98
99    // Normalize the identifier name
100    let normalized_name = match strategy {
101        NormalizationStrategy::Uppercase | NormalizationStrategy::CaseInsensitiveUppercase => {
102            identifier.name.to_uppercase()
103        }
104        NormalizationStrategy::Lowercase | NormalizationStrategy::CaseInsensitive => {
105            identifier.name.to_lowercase()
106        }
107        NormalizationStrategy::CaseSensitive => identifier.name, // Should not reach here
108    };
109
110    Identifier {
111        name: normalized_name,
112        quoted: identifier.quoted,
113        trailing_comments: identifier.trailing_comments,
114        span: None,
115    }
116}
117
118/// Recursively normalize all identifiers in an expression.
119fn normalize_expression(expression: Expression, strategy: NormalizationStrategy) -> Expression {
120    match expression {
121        Expression::Identifier(id) => Expression::Identifier(normalize_identifier(id, strategy)),
122        Expression::Column(col) => Expression::Column(Column {
123            name: normalize_identifier(col.name, strategy),
124            table: col.table.map(|t| normalize_identifier(t, strategy)),
125            join_mark: col.join_mark,
126            trailing_comments: col.trailing_comments,
127            span: None,
128            inferred_type: None,
129        }),
130        Expression::Table(mut table) => {
131            table.name = normalize_identifier(table.name, strategy);
132            if let Some(schema) = table.schema {
133                table.schema = Some(normalize_identifier(schema, strategy));
134            }
135            if let Some(catalog) = table.catalog {
136                table.catalog = Some(normalize_identifier(catalog, strategy));
137            }
138            if let Some(alias) = table.alias {
139                table.alias = Some(normalize_identifier(alias, strategy));
140            }
141            table.column_aliases = table
142                .column_aliases
143                .into_iter()
144                .map(|a| normalize_identifier(a, strategy))
145                .collect();
146            Expression::Table(table)
147        }
148        Expression::Select(select) => {
149            let mut select = *select;
150            // Normalize SELECT expressions
151            select.expressions = select
152                .expressions
153                .into_iter()
154                .map(|e| normalize_expression(e, strategy))
155                .collect();
156            // Normalize FROM
157            if let Some(mut from) = select.from {
158                from.expressions = from
159                    .expressions
160                    .into_iter()
161                    .map(|e| normalize_expression(e, strategy))
162                    .collect();
163                select.from = Some(from);
164            }
165            // Normalize JOINs
166            select.joins = select
167                .joins
168                .into_iter()
169                .map(|mut j| {
170                    j.this = normalize_expression(j.this, strategy);
171                    if let Some(on) = j.on {
172                        j.on = Some(normalize_expression(on, strategy));
173                    }
174                    j
175                })
176                .collect();
177            // Normalize WHERE
178            if let Some(mut where_clause) = select.where_clause {
179                where_clause.this = normalize_expression(where_clause.this, strategy);
180                select.where_clause = Some(where_clause);
181            }
182            // Normalize GROUP BY
183            if let Some(mut group_by) = select.group_by {
184                group_by.expressions = group_by
185                    .expressions
186                    .into_iter()
187                    .map(|e| normalize_expression(e, strategy))
188                    .collect();
189                select.group_by = Some(group_by);
190            }
191            // Normalize HAVING
192            if let Some(mut having) = select.having {
193                having.this = normalize_expression(having.this, strategy);
194                select.having = Some(having);
195            }
196            // Normalize ORDER BY
197            if let Some(mut order_by) = select.order_by {
198                order_by.expressions = order_by
199                    .expressions
200                    .into_iter()
201                    .map(|mut o| {
202                        o.this = normalize_expression(o.this, strategy);
203                        o
204                    })
205                    .collect();
206                select.order_by = Some(order_by);
207            }
208            Expression::Select(Box::new(select))
209        }
210        Expression::Alias(alias) => {
211            let mut alias = *alias;
212            alias.this = normalize_expression(alias.this, strategy);
213            alias.alias = normalize_identifier(alias.alias, strategy);
214            Expression::Alias(Box::new(alias))
215        }
216        // Binary operations
217        Expression::And(bin) => normalize_binary(Expression::And, *bin, strategy),
218        Expression::Or(bin) => normalize_binary(Expression::Or, *bin, strategy),
219        Expression::Add(bin) => normalize_binary(Expression::Add, *bin, strategy),
220        Expression::Sub(bin) => normalize_binary(Expression::Sub, *bin, strategy),
221        Expression::Mul(bin) => normalize_binary(Expression::Mul, *bin, strategy),
222        Expression::Div(bin) => normalize_binary(Expression::Div, *bin, strategy),
223        Expression::Mod(bin) => normalize_binary(Expression::Mod, *bin, strategy),
224        Expression::Eq(bin) => normalize_binary(Expression::Eq, *bin, strategy),
225        Expression::Neq(bin) => normalize_binary(Expression::Neq, *bin, strategy),
226        Expression::Lt(bin) => normalize_binary(Expression::Lt, *bin, strategy),
227        Expression::Lte(bin) => normalize_binary(Expression::Lte, *bin, strategy),
228        Expression::Gt(bin) => normalize_binary(Expression::Gt, *bin, strategy),
229        Expression::Gte(bin) => normalize_binary(Expression::Gte, *bin, strategy),
230        Expression::Concat(bin) => normalize_binary(Expression::Concat, *bin, strategy),
231        // Unary operations
232        Expression::Not(un) => {
233            let mut un = *un;
234            un.this = normalize_expression(un.this, strategy);
235            Expression::Not(Box::new(un))
236        }
237        Expression::Neg(un) => {
238            let mut un = *un;
239            un.this = normalize_expression(un.this, strategy);
240            Expression::Neg(Box::new(un))
241        }
242        // Functions
243        Expression::Function(func) => {
244            let mut func = *func;
245            func.args = func
246                .args
247                .into_iter()
248                .map(|e| normalize_expression(e, strategy))
249                .collect();
250            Expression::Function(Box::new(func))
251        }
252        Expression::AggregateFunction(agg) => {
253            let mut agg = *agg;
254            agg.args = agg
255                .args
256                .into_iter()
257                .map(|e| normalize_expression(e, strategy))
258                .collect();
259            Expression::AggregateFunction(Box::new(agg))
260        }
261        // Other expressions with children
262        Expression::Paren(paren) => {
263            let mut paren = *paren;
264            paren.this = normalize_expression(paren.this, strategy);
265            Expression::Paren(Box::new(paren))
266        }
267        Expression::Case(case) => {
268            let mut case = *case;
269            case.operand = case.operand.map(|e| normalize_expression(e, strategy));
270            case.whens = case
271                .whens
272                .into_iter()
273                .map(|(w, t)| {
274                    (
275                        normalize_expression(w, strategy),
276                        normalize_expression(t, strategy),
277                    )
278                })
279                .collect();
280            case.else_ = case.else_.map(|e| normalize_expression(e, strategy));
281            Expression::Case(Box::new(case))
282        }
283        Expression::Cast(cast) => {
284            let mut cast = *cast;
285            cast.this = normalize_expression(cast.this, strategy);
286            Expression::Cast(Box::new(cast))
287        }
288        Expression::In(in_expr) => {
289            let mut in_expr = *in_expr;
290            in_expr.this = normalize_expression(in_expr.this, strategy);
291            in_expr.expressions = in_expr
292                .expressions
293                .into_iter()
294                .map(|e| normalize_expression(e, strategy))
295                .collect();
296            if let Some(q) = in_expr.query {
297                in_expr.query = Some(normalize_expression(q, strategy));
298            }
299            Expression::In(Box::new(in_expr))
300        }
301        Expression::Between(between) => {
302            let mut between = *between;
303            between.this = normalize_expression(between.this, strategy);
304            between.low = normalize_expression(between.low, strategy);
305            between.high = normalize_expression(between.high, strategy);
306            Expression::Between(Box::new(between))
307        }
308        Expression::Subquery(subquery) => {
309            let mut subquery = *subquery;
310            subquery.this = normalize_expression(subquery.this, strategy);
311            if let Some(alias) = subquery.alias {
312                subquery.alias = Some(normalize_identifier(alias, strategy));
313            }
314            Expression::Subquery(Box::new(subquery))
315        }
316        // Set operations
317        Expression::Union(union) => {
318            let mut union = *union;
319            union.left = normalize_expression(union.left, strategy);
320            union.right = normalize_expression(union.right, strategy);
321            Expression::Union(Box::new(union))
322        }
323        Expression::Intersect(intersect) => {
324            let mut intersect = *intersect;
325            intersect.left = normalize_expression(intersect.left, strategy);
326            intersect.right = normalize_expression(intersect.right, strategy);
327            Expression::Intersect(Box::new(intersect))
328        }
329        Expression::Except(except) => {
330            let mut except = *except;
331            except.left = normalize_expression(except.left, strategy);
332            except.right = normalize_expression(except.right, strategy);
333            Expression::Except(Box::new(except))
334        }
335        // Leaf nodes and others - return unchanged
336        _ => expression,
337    }
338}
339
340/// Helper to normalize binary operations
341fn normalize_binary<F>(
342    constructor: F,
343    mut bin: crate::expressions::BinaryOp,
344    strategy: NormalizationStrategy,
345) -> Expression
346where
347    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
348{
349    bin.left = normalize_expression(bin.left, strategy);
350    bin.right = normalize_expression(bin.right, strategy);
351    constructor(Box::new(bin))
352}
353
354/// Check if an identifier contains case-sensitive characters based on dialect rules.
355pub fn is_case_sensitive(text: &str, strategy: NormalizationStrategy) -> bool {
356    match strategy {
357        NormalizationStrategy::CaseInsensitive
358        | NormalizationStrategy::CaseInsensitiveUppercase => false,
359        NormalizationStrategy::Uppercase => text.chars().any(|c| c.is_lowercase()),
360        NormalizationStrategy::Lowercase => text.chars().any(|c| c.is_uppercase()),
361        NormalizationStrategy::CaseSensitive => true,
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL with a default generator.
    ///
    /// Named `to_sql` rather than `gen` because `gen` is a reserved keyword in
    /// the Rust 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql`, normalize its first statement for `dialect` and render it again.
    fn parse_and_normalize(sql: &str, dialect: Option<DialectType>) -> String {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        let normalized = normalize_identifiers(ast[0].clone(), dialect);
        to_sql(&normalized)
    }

    /// Build an identifier with the given name and quoting for the unit tests below.
    fn ident(name: &str, quoted: bool) -> Identifier {
        Identifier {
            name: name.to_string(),
            quoted,
            trailing_comments: vec![],
            span: None,
        }
    }

    #[test]
    fn test_normalize_lowercase() {
        // PostgreSQL-like: lowercase unquoted identifiers
        let result = parse_and_normalize("SELECT FoO FROM Bar", None);
        assert!(result.contains("foo"), "column not lowercased: {}", result);
        assert!(result.contains("bar"), "table not lowercased: {}", result);
        assert!(!result.contains("FoO"), "original case survived: {}", result);
    }

    #[test]
    fn test_normalize_uppercase() {
        // Snowflake: uppercase unquoted identifiers.
        // Check the rendered SQL as-is; uppercasing it first would make the
        // assertion pass even when nothing was normalized.
        let result = parse_and_normalize("SELECT foo FROM bar", Some(DialectType::Snowflake));
        assert!(result.contains("FOO"), "column not uppercased: {}", result);
        assert!(result.contains("BAR"), "table not uppercased: {}", result);
    }

    #[test]
    fn test_normalize_preserves_quoted() {
        // Quoted identifiers should be preserved in non-case-insensitive dialects
        let normalized = normalize_identifier(ident("FoO", true), NormalizationStrategy::Lowercase);
        assert_eq!(normalized.name, "FoO"); // Preserved
        assert!(normalized.quoted);
    }

    #[test]
    fn test_case_insensitive_normalizes_quoted() {
        // In case-insensitive dialects, even quoted identifiers are normalized
        let normalized =
            normalize_identifier(ident("FoO", true), NormalizationStrategy::CaseInsensitive);
        assert_eq!(normalized.name, "foo"); // Lowercased
        assert!(normalized.quoted); // Quoting flag is kept
    }

    #[test]
    fn test_case_sensitive_no_normalization() {
        // Case-sensitive dialects don't normalize at all
        let normalized =
            normalize_identifier(ident("FoO", false), NormalizationStrategy::CaseSensitive);
        assert_eq!(normalized.name, "FoO"); // Unchanged
    }

    #[test]
    fn test_normalize_column() {
        let col = Expression::Column(Column {
            name: Identifier::new("MyColumn"),
            table: Some(Identifier::new("MyTable")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });

        let normalized = normalize_expression(col, NormalizationStrategy::Lowercase);
        let sql = to_sql(&normalized);
        // Both the column name and its table qualifier must be normalized.
        assert!(sql.contains("mycolumn"), "column not lowercased: {}", sql);
        assert!(sql.contains("mytable"), "table not lowercased: {}", sql);
    }

    #[test]
    fn test_is_case_sensitive() {
        assert!(is_case_sensitive("FoO", NormalizationStrategy::Lowercase));
        assert!(!is_case_sensitive("foo", NormalizationStrategy::Lowercase));
        assert!(is_case_sensitive("FoO", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("FOO", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("FoO", NormalizationStrategy::CaseInsensitive));
        assert!(!is_case_sensitive(
            "FoO",
            NormalizationStrategy::CaseInsensitiveUppercase
        ));
        assert!(is_case_sensitive("foo", NormalizationStrategy::CaseSensitive));
    }

    #[test]
    fn test_get_normalization_strategy() {
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Snowflake)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::PostgreSQL)),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::MySQL)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::DuckDB)),
            NormalizationStrategy::CaseInsensitive
        );
        // No dialect falls back to the default strategy.
        assert_eq!(
            get_normalization_strategy(None),
            NormalizationStrategy::Lowercase
        );
    }
}