Skip to main content

polyglot_sql/optimizer/
normalize_identifiers.rs

1//! Identifier Normalization Module
2//!
3//! This module provides functionality for normalizing identifiers in SQL queries
4//! based on dialect-specific rules for case sensitivity and quoting.
5//!
6//! Ported from sqlglot's optimizer/normalize_identifiers.py
7
8use crate::dialects::DialectType;
9use crate::expressions::{Column, Expression, Identifier, Null};
10
/// Strategy for normalizing identifiers based on dialect rules.
///
/// The strategy decides two things when an identifier is normalized:
/// which case unquoted names are folded to, and whether quoting protects a
/// name from being folded at all.
///
/// The default strategy is [`NormalizationStrategy::Lowercase`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NormalizationStrategy {
    /// Unquoted identifiers are lowercased (e.g., PostgreSQL).
    /// Quoted identifiers are left as written.
    #[default]
    Lowercase,
    /// Unquoted identifiers are uppercased (e.g., Oracle, Snowflake).
    /// Quoted identifiers are left as written.
    Uppercase,
    /// Always case-sensitive, regardless of quotes (e.g., MySQL on Linux).
    /// Identifiers are never rewritten.
    CaseSensitive,
    /// Always case-insensitive (lowercase), regardless of quotes (e.g., Spark, BigQuery).
    /// Quoted and unquoted identifiers are both lowercased.
    CaseInsensitive,
    /// Always case-insensitive (uppercase), regardless of quotes.
    /// Quoted and unquoted identifiers are both uppercased.
    CaseInsensitiveUppercase,
}
31
32/// Get the normalization strategy for a dialect
33pub fn get_normalization_strategy(dialect: Option<DialectType>) -> NormalizationStrategy {
34    match dialect {
35        // Uppercase dialects
36        Some(DialectType::Oracle) | Some(DialectType::Snowflake) | Some(DialectType::Exasol) => {
37            NormalizationStrategy::Uppercase
38        }
39        // Case-sensitive dialects
40        Some(DialectType::MySQL) | Some(DialectType::ClickHouse) => {
41            NormalizationStrategy::CaseSensitive
42        }
43        // Case-insensitive dialects (lowercase)
44        Some(DialectType::DuckDB)
45        | Some(DialectType::SQLite)
46        | Some(DialectType::BigQuery)
47        | Some(DialectType::Presto)
48        | Some(DialectType::Trino)
49        | Some(DialectType::Hive)
50        | Some(DialectType::Spark)
51        | Some(DialectType::Databricks)
52        | Some(DialectType::Redshift) => NormalizationStrategy::CaseInsensitive,
53        // Default: lowercase (PostgreSQL-like behavior)
54        _ => NormalizationStrategy::Lowercase,
55    }
56}
57
58/// Normalize identifiers in an expression based on dialect rules.
59///
60/// This transformation reflects how identifiers would be resolved by the engine
61/// corresponding to each SQL dialect. For example:
62/// - `FoO` → `foo` in PostgreSQL (lowercases unquoted)
63/// - `FoO` → `FOO` in Snowflake (uppercases unquoted)
64/// - `"FoO"` → `FoO` preserved when quoted (case-sensitive)
65///
66/// # Arguments
67/// * `expression` - The expression to normalize
68/// * `dialect` - The dialect to use for normalization rules
69///
70/// # Returns
71/// The expression with normalized identifiers
72pub fn normalize_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
73    let strategy = get_normalization_strategy(dialect);
74    normalize_expression(expression, strategy)
75}
76
77/// Normalize a single identifier based on the strategy.
78///
79/// # Arguments
80/// * `identifier` - The identifier to normalize
81/// * `strategy` - The normalization strategy to use
82///
83/// # Returns
84/// The normalized identifier
85pub fn normalize_identifier(identifier: Identifier, strategy: NormalizationStrategy) -> Identifier {
86    // Case-sensitive strategy: never normalize
87    if strategy == NormalizationStrategy::CaseSensitive {
88        return identifier;
89    }
90
91    // If quoted and not case-insensitive, don't normalize
92    if identifier.quoted
93        && strategy != NormalizationStrategy::CaseInsensitive
94        && strategy != NormalizationStrategy::CaseInsensitiveUppercase
95    {
96        return identifier;
97    }
98
99    // Normalize the identifier name
100    let normalized_name = match strategy {
101        NormalizationStrategy::Uppercase | NormalizationStrategy::CaseInsensitiveUppercase => {
102            identifier.name.to_uppercase()
103        }
104        NormalizationStrategy::Lowercase | NormalizationStrategy::CaseInsensitive => {
105            identifier.name.to_lowercase()
106        }
107        NormalizationStrategy::CaseSensitive => identifier.name, // Should not reach here
108    };
109
110    Identifier {
111        name: normalized_name,
112        quoted: identifier.quoted,
113        trailing_comments: identifier.trailing_comments,
114        span: None,
115    }
116}
117
118/// Recursively normalize all identifiers in an expression.
119fn normalize_expression(expression: Expression, strategy: NormalizationStrategy) -> Expression {
120    match expression {
121        Expression::Identifier(id) => Expression::Identifier(normalize_identifier(id, strategy)),
122        Expression::Column(col) => Expression::boxed_column(Column {
123            name: normalize_identifier(col.name, strategy),
124            table: col.table.map(|t| normalize_identifier(t, strategy)),
125            join_mark: col.join_mark,
126            trailing_comments: col.trailing_comments,
127            span: None,
128            inferred_type: None,
129        }),
130        Expression::Table(mut table) => {
131            table.name = normalize_identifier(table.name, strategy);
132            if let Some(schema) = table.schema {
133                table.schema = Some(normalize_identifier(schema, strategy));
134            }
135            if let Some(catalog) = table.catalog {
136                table.catalog = Some(normalize_identifier(catalog, strategy));
137            }
138            if let Some(alias) = table.alias {
139                table.alias = Some(normalize_identifier(alias, strategy));
140            }
141            table.column_aliases = table
142                .column_aliases
143                .into_iter()
144                .map(|a| normalize_identifier(a, strategy))
145                .collect();
146            Expression::Table(table)
147        }
148        Expression::Select(select) => {
149            let mut select = *select;
150            // Normalize SELECT expressions
151            select.expressions = select
152                .expressions
153                .into_iter()
154                .map(|e| normalize_expression(e, strategy))
155                .collect();
156            // Normalize FROM
157            if let Some(mut from) = select.from {
158                from.expressions = from
159                    .expressions
160                    .into_iter()
161                    .map(|e| normalize_expression(e, strategy))
162                    .collect();
163                select.from = Some(from);
164            }
165            // Normalize JOINs
166            select.joins = select
167                .joins
168                .into_iter()
169                .map(|mut j| {
170                    j.this = normalize_expression(j.this, strategy);
171                    if let Some(on) = j.on {
172                        j.on = Some(normalize_expression(on, strategy));
173                    }
174                    j
175                })
176                .collect();
177            // Normalize WHERE
178            if let Some(mut where_clause) = select.where_clause {
179                where_clause.this = normalize_expression(where_clause.this, strategy);
180                select.where_clause = Some(where_clause);
181            }
182            // Normalize GROUP BY
183            if let Some(mut group_by) = select.group_by {
184                group_by.expressions = group_by
185                    .expressions
186                    .into_iter()
187                    .map(|e| normalize_expression(e, strategy))
188                    .collect();
189                select.group_by = Some(group_by);
190            }
191            // Normalize HAVING
192            if let Some(mut having) = select.having {
193                having.this = normalize_expression(having.this, strategy);
194                select.having = Some(having);
195            }
196            // Normalize ORDER BY
197            if let Some(mut order_by) = select.order_by {
198                order_by.expressions = order_by
199                    .expressions
200                    .into_iter()
201                    .map(|mut o| {
202                        o.this = normalize_expression(o.this, strategy);
203                        o
204                    })
205                    .collect();
206                select.order_by = Some(order_by);
207            }
208            Expression::Select(Box::new(select))
209        }
210        Expression::Alias(alias) => {
211            let mut alias = *alias;
212            alias.this = normalize_expression(alias.this, strategy);
213            alias.alias = normalize_identifier(alias.alias, strategy);
214            Expression::Alias(Box::new(alias))
215        }
216        // Binary operations
217        Expression::And(bin) => normalize_binary(Expression::And, *bin, strategy),
218        Expression::Or(bin) => normalize_binary(Expression::Or, *bin, strategy),
219        Expression::Add(bin) => normalize_binary(Expression::Add, *bin, strategy),
220        Expression::Sub(bin) => normalize_binary(Expression::Sub, *bin, strategy),
221        Expression::Mul(bin) => normalize_binary(Expression::Mul, *bin, strategy),
222        Expression::Div(bin) => normalize_binary(Expression::Div, *bin, strategy),
223        Expression::Mod(bin) => normalize_binary(Expression::Mod, *bin, strategy),
224        Expression::Eq(bin) => normalize_binary(Expression::Eq, *bin, strategy),
225        Expression::Neq(bin) => normalize_binary(Expression::Neq, *bin, strategy),
226        Expression::Lt(bin) => normalize_binary(Expression::Lt, *bin, strategy),
227        Expression::Lte(bin) => normalize_binary(Expression::Lte, *bin, strategy),
228        Expression::Gt(bin) => normalize_binary(Expression::Gt, *bin, strategy),
229        Expression::Gte(bin) => normalize_binary(Expression::Gte, *bin, strategy),
230        Expression::Concat(bin) => normalize_binary(Expression::Concat, *bin, strategy),
231        // Unary operations
232        Expression::Not(un) => {
233            let mut un = *un;
234            un.this = normalize_expression(un.this, strategy);
235            Expression::Not(Box::new(un))
236        }
237        Expression::Neg(un) => {
238            let mut un = *un;
239            un.this = normalize_expression(un.this, strategy);
240            Expression::Neg(Box::new(un))
241        }
242        // Functions
243        Expression::Function(func) => {
244            let mut func = *func;
245            func.args = func
246                .args
247                .into_iter()
248                .map(|e| normalize_expression(e, strategy))
249                .collect();
250            Expression::Function(Box::new(func))
251        }
252        Expression::AggregateFunction(agg) => {
253            let mut agg = *agg;
254            agg.args = agg
255                .args
256                .into_iter()
257                .map(|e| normalize_expression(e, strategy))
258                .collect();
259            Expression::AggregateFunction(Box::new(agg))
260        }
261        // Other expressions with children
262        Expression::Paren(paren) => {
263            let mut paren = *paren;
264            paren.this = normalize_expression(paren.this, strategy);
265            Expression::Paren(Box::new(paren))
266        }
267        Expression::Case(case) => {
268            let mut case = *case;
269            case.operand = case.operand.map(|e| normalize_expression(e, strategy));
270            case.whens = case
271                .whens
272                .into_iter()
273                .map(|(w, t)| {
274                    (
275                        normalize_expression(w, strategy),
276                        normalize_expression(t, strategy),
277                    )
278                })
279                .collect();
280            case.else_ = case.else_.map(|e| normalize_expression(e, strategy));
281            Expression::Case(Box::new(case))
282        }
283        Expression::Cast(cast) => {
284            let mut cast = *cast;
285            cast.this = normalize_expression(cast.this, strategy);
286            Expression::Cast(Box::new(cast))
287        }
288        Expression::In(in_expr) => {
289            let mut in_expr = *in_expr;
290            in_expr.this = normalize_expression(in_expr.this, strategy);
291            in_expr.expressions = in_expr
292                .expressions
293                .into_iter()
294                .map(|e| normalize_expression(e, strategy))
295                .collect();
296            if let Some(q) = in_expr.query {
297                in_expr.query = Some(normalize_expression(q, strategy));
298            }
299            Expression::In(Box::new(in_expr))
300        }
301        Expression::Between(between) => {
302            let mut between = *between;
303            between.this = normalize_expression(between.this, strategy);
304            between.low = normalize_expression(between.low, strategy);
305            between.high = normalize_expression(between.high, strategy);
306            Expression::Between(Box::new(between))
307        }
308        Expression::Subquery(subquery) => {
309            let mut subquery = *subquery;
310            subquery.this = normalize_expression(subquery.this, strategy);
311            if let Some(alias) = subquery.alias {
312                subquery.alias = Some(normalize_identifier(alias, strategy));
313            }
314            Expression::Subquery(Box::new(subquery))
315        }
316        // Set operations
317        Expression::Union(mut union) => {
318            let left = std::mem::replace(&mut union.left, Expression::Null(Null));
319            union.left = normalize_expression(left, strategy);
320            let right = std::mem::replace(&mut union.right, Expression::Null(Null));
321            union.right = normalize_expression(right, strategy);
322            Expression::Union(union)
323        }
324        Expression::Intersect(mut intersect) => {
325            let left = std::mem::replace(&mut intersect.left, Expression::Null(Null));
326            intersect.left = normalize_expression(left, strategy);
327            let right = std::mem::replace(&mut intersect.right, Expression::Null(Null));
328            intersect.right = normalize_expression(right, strategy);
329            Expression::Intersect(intersect)
330        }
331        Expression::Except(mut except) => {
332            let left = std::mem::replace(&mut except.left, Expression::Null(Null));
333            except.left = normalize_expression(left, strategy);
334            let right = std::mem::replace(&mut except.right, Expression::Null(Null));
335            except.right = normalize_expression(right, strategy);
336            Expression::Except(except)
337        }
338        // Leaf nodes and others - return unchanged
339        _ => expression,
340    }
341}
342
343/// Helper to normalize binary operations
344fn normalize_binary<F>(
345    constructor: F,
346    mut bin: crate::expressions::BinaryOp,
347    strategy: NormalizationStrategy,
348) -> Expression
349where
350    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
351{
352    bin.left = normalize_expression(bin.left, strategy);
353    bin.right = normalize_expression(bin.right, strategy);
354    constructor(Box::new(bin))
355}
356
357/// Check if an identifier contains case-sensitive characters based on dialect rules.
358pub fn is_case_sensitive(text: &str, strategy: NormalizationStrategy) -> bool {
359    match strategy {
360        NormalizationStrategy::CaseInsensitive
361        | NormalizationStrategy::CaseInsensitiveUppercase => false,
362        NormalizationStrategy::Uppercase => text.chars().any(|c| c.is_lowercase()),
363        NormalizationStrategy::Lowercase => text.chars().any(|c| c.is_uppercase()),
364        NormalizationStrategy::CaseSensitive => true,
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL with a default generator.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql`, normalize its first statement for `dialect`, and render it.
    fn parse_and_normalize(sql: &str, dialect: Option<DialectType>) -> String {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        let normalized = normalize_identifiers(ast[0].clone(), dialect);
        generate_sql(&normalized)
    }

    /// Build an identifier with the given name and quoting, no comments, no span.
    fn identifier(name: &str, quoted: bool) -> Identifier {
        Identifier {
            name: name.to_string(),
            quoted,
            trailing_comments: vec![],
            span: None,
        }
    }

    #[test]
    fn test_normalize_lowercase() {
        // PostgreSQL-like: lowercase unquoted identifiers
        let result = parse_and_normalize("SELECT FoO FROM Bar", None);
        assert!(result.contains("foo"), "column not lowercased: {}", result);
        assert!(result.contains("bar"), "table not lowercased: {}", result);
    }

    #[test]
    fn test_normalize_uppercase() {
        // Snowflake: uppercase unquoted identifiers.
        // Compare against the generated text itself; uppercasing the output
        // first would make the check pass even without any normalization.
        let result = parse_and_normalize("SELECT foo FROM bar", Some(DialectType::Snowflake));
        assert!(result.contains("FOO"), "column not uppercased: {}", result);
        assert!(result.contains("BAR"), "table not uppercased: {}", result);
    }

    #[test]
    fn test_normalize_preserves_quoted() {
        // Quoted identifiers should be preserved in non-case-insensitive dialects
        let normalized =
            normalize_identifier(identifier("FoO", true), NormalizationStrategy::Lowercase);
        assert_eq!(normalized.name, "FoO"); // Preserved
        assert!(normalized.quoted);
    }

    #[test]
    fn test_case_insensitive_normalizes_quoted() {
        // In case-insensitive dialects, even quoted identifiers are normalized
        let normalized = normalize_identifier(
            identifier("FoO", true),
            NormalizationStrategy::CaseInsensitive,
        );
        assert_eq!(normalized.name, "foo"); // Lowercased
        assert!(normalized.quoted); // Quoting flag is carried over
    }

    #[test]
    fn test_case_insensitive_uppercase_normalizes_quoted() {
        let normalized = normalize_identifier(
            identifier("FoO", true),
            NormalizationStrategy::CaseInsensitiveUppercase,
        );
        assert_eq!(normalized.name, "FOO"); // Uppercased
    }

    #[test]
    fn test_case_sensitive_no_normalization() {
        // Case-sensitive dialects don't normalize at all
        let normalized = normalize_identifier(
            identifier("FoO", false),
            NormalizationStrategy::CaseSensitive,
        );
        assert_eq!(normalized.name, "FoO"); // Unchanged
    }

    #[test]
    fn test_normalize_column() {
        let col = Expression::boxed_column(Column {
            name: Identifier::new("MyColumn"),
            table: Some(Identifier::new("MyTable")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });

        let normalized = normalize_expression(col, NormalizationStrategy::Lowercase);
        let sql = generate_sql(&normalized);
        // Both the column name and its table qualifier must be normalized.
        assert!(sql.contains("mycolumn"), "column not lowercased: {}", sql);
        assert!(sql.contains("mytable"), "table not lowercased: {}", sql);
    }

    #[test]
    fn test_is_case_sensitive() {
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Lowercase));
        assert!(!is_case_sensitive("foo", NormalizationStrategy::Lowercase));
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("FOO", NormalizationStrategy::Uppercase));
        assert!(is_case_sensitive("foo", NormalizationStrategy::CaseSensitive));
        assert!(!is_case_sensitive("FoO", NormalizationStrategy::CaseInsensitive));
        assert!(!is_case_sensitive(
            "FoO",
            NormalizationStrategy::CaseInsensitiveUppercase
        ));
    }

    #[test]
    fn test_default_strategy() {
        assert_eq!(
            NormalizationStrategy::default(),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(None),
            NormalizationStrategy::Lowercase
        );
    }

    #[test]
    fn test_get_normalization_strategy() {
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Snowflake)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::PostgreSQL)),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::MySQL)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::DuckDB)),
            NormalizationStrategy::CaseInsensitive
        );
    }
}