Skip to main content

polyglot_sql/optimizer/
normalize_identifiers.rs

1//! Identifier Normalization Module
2//!
3//! This module provides functionality for normalizing identifiers in SQL queries
4//! based on dialect-specific rules for case sensitivity and quoting.
5//!
6//! Ported from sqlglot's optimizer/normalize_identifiers.py
7
8use crate::dialects::DialectType;
9use crate::expressions::{Column, Expression, Identifier};
10
/// Strategy for normalizing identifiers based on dialect rules.
///
/// Each variant describes how an identifier's case is treated, and whether
/// quoting shields the identifier from normalization. The default strategy is
/// [`NormalizationStrategy::Lowercase`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NormalizationStrategy {
    /// Unquoted identifiers are lowercased (e.g., PostgreSQL)
    #[default]
    Lowercase,
    /// Unquoted identifiers are uppercased (e.g., Oracle, Snowflake)
    Uppercase,
    /// Always case-sensitive, regardless of quotes (e.g., MySQL on Linux)
    CaseSensitive,
    /// Always case-insensitive (lowercase), regardless of quotes (e.g., Spark, BigQuery)
    CaseInsensitive,
    /// Always case-insensitive (uppercase), regardless of quotes
    CaseInsensitiveUppercase,
}
31
32/// Get the normalization strategy for a dialect
33pub fn get_normalization_strategy(dialect: Option<DialectType>) -> NormalizationStrategy {
34    match dialect {
35        // Uppercase dialects
36        Some(DialectType::Oracle) | Some(DialectType::Snowflake) | Some(DialectType::Exasol) => {
37            NormalizationStrategy::Uppercase
38        }
39        // Case-sensitive dialects
40        Some(DialectType::MySQL) | Some(DialectType::ClickHouse) => {
41            NormalizationStrategy::CaseSensitive
42        }
43        // Case-insensitive dialects (lowercase)
44        Some(DialectType::DuckDB)
45        | Some(DialectType::SQLite)
46        | Some(DialectType::BigQuery)
47        | Some(DialectType::Presto)
48        | Some(DialectType::Trino)
49        | Some(DialectType::Hive)
50        | Some(DialectType::Spark)
51        | Some(DialectType::Databricks)
52        | Some(DialectType::Redshift) => NormalizationStrategy::CaseInsensitive,
53        // Default: lowercase (PostgreSQL-like behavior)
54        _ => NormalizationStrategy::Lowercase,
55    }
56}
57
58/// Normalize identifiers in an expression based on dialect rules.
59///
60/// This transformation reflects how identifiers would be resolved by the engine
61/// corresponding to each SQL dialect. For example:
62/// - `FoO` → `foo` in PostgreSQL (lowercases unquoted)
63/// - `FoO` → `FOO` in Snowflake (uppercases unquoted)
64/// - `"FoO"` → `FoO` preserved when quoted (case-sensitive)
65///
66/// # Arguments
67/// * `expression` - The expression to normalize
68/// * `dialect` - The dialect to use for normalization rules
69///
70/// # Returns
71/// The expression with normalized identifiers
72pub fn normalize_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
73    let strategy = get_normalization_strategy(dialect);
74    normalize_expression(expression, strategy)
75}
76
77/// Normalize a single identifier based on the strategy.
78///
79/// # Arguments
80/// * `identifier` - The identifier to normalize
81/// * `strategy` - The normalization strategy to use
82///
83/// # Returns
84/// The normalized identifier
85pub fn normalize_identifier(identifier: Identifier, strategy: NormalizationStrategy) -> Identifier {
86    // Case-sensitive strategy: never normalize
87    if strategy == NormalizationStrategy::CaseSensitive {
88        return identifier;
89    }
90
91    // If quoted and not case-insensitive, don't normalize
92    if identifier.quoted
93        && strategy != NormalizationStrategy::CaseInsensitive
94        && strategy != NormalizationStrategy::CaseInsensitiveUppercase
95    {
96        return identifier;
97    }
98
99    // Normalize the identifier name
100    let normalized_name = match strategy {
101        NormalizationStrategy::Uppercase | NormalizationStrategy::CaseInsensitiveUppercase => {
102            identifier.name.to_uppercase()
103        }
104        NormalizationStrategy::Lowercase | NormalizationStrategy::CaseInsensitive => {
105            identifier.name.to_lowercase()
106        }
107        NormalizationStrategy::CaseSensitive => identifier.name, // Should not reach here
108    };
109
110    Identifier {
111        name: normalized_name,
112        quoted: identifier.quoted,
113        trailing_comments: identifier.trailing_comments,
114    }
115}
116
117/// Recursively normalize all identifiers in an expression.
118fn normalize_expression(expression: Expression, strategy: NormalizationStrategy) -> Expression {
119    match expression {
120        Expression::Identifier(id) => Expression::Identifier(normalize_identifier(id, strategy)),
121        Expression::Column(col) => Expression::Column(Column {
122            name: normalize_identifier(col.name, strategy),
123            table: col.table.map(|t| normalize_identifier(t, strategy)),
124            join_mark: col.join_mark,
125            trailing_comments: col.trailing_comments,
126        }),
127        Expression::Table(mut table) => {
128            table.name = normalize_identifier(table.name, strategy);
129            if let Some(schema) = table.schema {
130                table.schema = Some(normalize_identifier(schema, strategy));
131            }
132            if let Some(catalog) = table.catalog {
133                table.catalog = Some(normalize_identifier(catalog, strategy));
134            }
135            if let Some(alias) = table.alias {
136                table.alias = Some(normalize_identifier(alias, strategy));
137            }
138            table.column_aliases = table
139                .column_aliases
140                .into_iter()
141                .map(|a| normalize_identifier(a, strategy))
142                .collect();
143            Expression::Table(table)
144        }
145        Expression::Select(select) => {
146            let mut select = *select;
147            // Normalize SELECT expressions
148            select.expressions = select
149                .expressions
150                .into_iter()
151                .map(|e| normalize_expression(e, strategy))
152                .collect();
153            // Normalize FROM
154            if let Some(mut from) = select.from {
155                from.expressions = from
156                    .expressions
157                    .into_iter()
158                    .map(|e| normalize_expression(e, strategy))
159                    .collect();
160                select.from = Some(from);
161            }
162            // Normalize JOINs
163            select.joins = select
164                .joins
165                .into_iter()
166                .map(|mut j| {
167                    j.this = normalize_expression(j.this, strategy);
168                    if let Some(on) = j.on {
169                        j.on = Some(normalize_expression(on, strategy));
170                    }
171                    j
172                })
173                .collect();
174            // Normalize WHERE
175            if let Some(mut where_clause) = select.where_clause {
176                where_clause.this = normalize_expression(where_clause.this, strategy);
177                select.where_clause = Some(where_clause);
178            }
179            // Normalize GROUP BY
180            if let Some(mut group_by) = select.group_by {
181                group_by.expressions = group_by
182                    .expressions
183                    .into_iter()
184                    .map(|e| normalize_expression(e, strategy))
185                    .collect();
186                select.group_by = Some(group_by);
187            }
188            // Normalize HAVING
189            if let Some(mut having) = select.having {
190                having.this = normalize_expression(having.this, strategy);
191                select.having = Some(having);
192            }
193            // Normalize ORDER BY
194            if let Some(mut order_by) = select.order_by {
195                order_by.expressions = order_by
196                    .expressions
197                    .into_iter()
198                    .map(|mut o| {
199                        o.this = normalize_expression(o.this, strategy);
200                        o
201                    })
202                    .collect();
203                select.order_by = Some(order_by);
204            }
205            Expression::Select(Box::new(select))
206        }
207        Expression::Alias(alias) => {
208            let mut alias = *alias;
209            alias.this = normalize_expression(alias.this, strategy);
210            alias.alias = normalize_identifier(alias.alias, strategy);
211            Expression::Alias(Box::new(alias))
212        }
213        // Binary operations
214        Expression::And(bin) => normalize_binary(Expression::And, *bin, strategy),
215        Expression::Or(bin) => normalize_binary(Expression::Or, *bin, strategy),
216        Expression::Add(bin) => normalize_binary(Expression::Add, *bin, strategy),
217        Expression::Sub(bin) => normalize_binary(Expression::Sub, *bin, strategy),
218        Expression::Mul(bin) => normalize_binary(Expression::Mul, *bin, strategy),
219        Expression::Div(bin) => normalize_binary(Expression::Div, *bin, strategy),
220        Expression::Mod(bin) => normalize_binary(Expression::Mod, *bin, strategy),
221        Expression::Eq(bin) => normalize_binary(Expression::Eq, *bin, strategy),
222        Expression::Neq(bin) => normalize_binary(Expression::Neq, *bin, strategy),
223        Expression::Lt(bin) => normalize_binary(Expression::Lt, *bin, strategy),
224        Expression::Lte(bin) => normalize_binary(Expression::Lte, *bin, strategy),
225        Expression::Gt(bin) => normalize_binary(Expression::Gt, *bin, strategy),
226        Expression::Gte(bin) => normalize_binary(Expression::Gte, *bin, strategy),
227        Expression::Concat(bin) => normalize_binary(Expression::Concat, *bin, strategy),
228        // Unary operations
229        Expression::Not(un) => {
230            let mut un = *un;
231            un.this = normalize_expression(un.this, strategy);
232            Expression::Not(Box::new(un))
233        }
234        Expression::Neg(un) => {
235            let mut un = *un;
236            un.this = normalize_expression(un.this, strategy);
237            Expression::Neg(Box::new(un))
238        }
239        // Functions
240        Expression::Function(func) => {
241            let mut func = *func;
242            func.args = func
243                .args
244                .into_iter()
245                .map(|e| normalize_expression(e, strategy))
246                .collect();
247            Expression::Function(Box::new(func))
248        }
249        Expression::AggregateFunction(agg) => {
250            let mut agg = *agg;
251            agg.args = agg
252                .args
253                .into_iter()
254                .map(|e| normalize_expression(e, strategy))
255                .collect();
256            Expression::AggregateFunction(Box::new(agg))
257        }
258        // Other expressions with children
259        Expression::Paren(paren) => {
260            let mut paren = *paren;
261            paren.this = normalize_expression(paren.this, strategy);
262            Expression::Paren(Box::new(paren))
263        }
264        Expression::Case(case) => {
265            let mut case = *case;
266            case.operand = case.operand.map(|e| normalize_expression(e, strategy));
267            case.whens = case
268                .whens
269                .into_iter()
270                .map(|(w, t)| {
271                    (
272                        normalize_expression(w, strategy),
273                        normalize_expression(t, strategy),
274                    )
275                })
276                .collect();
277            case.else_ = case.else_.map(|e| normalize_expression(e, strategy));
278            Expression::Case(Box::new(case))
279        }
280        Expression::Cast(cast) => {
281            let mut cast = *cast;
282            cast.this = normalize_expression(cast.this, strategy);
283            Expression::Cast(Box::new(cast))
284        }
285        Expression::In(in_expr) => {
286            let mut in_expr = *in_expr;
287            in_expr.this = normalize_expression(in_expr.this, strategy);
288            in_expr.expressions = in_expr
289                .expressions
290                .into_iter()
291                .map(|e| normalize_expression(e, strategy))
292                .collect();
293            if let Some(q) = in_expr.query {
294                in_expr.query = Some(normalize_expression(q, strategy));
295            }
296            Expression::In(Box::new(in_expr))
297        }
298        Expression::Between(between) => {
299            let mut between = *between;
300            between.this = normalize_expression(between.this, strategy);
301            between.low = normalize_expression(between.low, strategy);
302            between.high = normalize_expression(between.high, strategy);
303            Expression::Between(Box::new(between))
304        }
305        Expression::Subquery(subquery) => {
306            let mut subquery = *subquery;
307            subquery.this = normalize_expression(subquery.this, strategy);
308            if let Some(alias) = subquery.alias {
309                subquery.alias = Some(normalize_identifier(alias, strategy));
310            }
311            Expression::Subquery(Box::new(subquery))
312        }
313        // Set operations
314        Expression::Union(union) => {
315            let mut union = *union;
316            union.left = normalize_expression(union.left, strategy);
317            union.right = normalize_expression(union.right, strategy);
318            Expression::Union(Box::new(union))
319        }
320        Expression::Intersect(intersect) => {
321            let mut intersect = *intersect;
322            intersect.left = normalize_expression(intersect.left, strategy);
323            intersect.right = normalize_expression(intersect.right, strategy);
324            Expression::Intersect(Box::new(intersect))
325        }
326        Expression::Except(except) => {
327            let mut except = *except;
328            except.left = normalize_expression(except.left, strategy);
329            except.right = normalize_expression(except.right, strategy);
330            Expression::Except(Box::new(except))
331        }
332        // Leaf nodes and others - return unchanged
333        _ => expression,
334    }
335}
336
337/// Helper to normalize binary operations
338fn normalize_binary<F>(
339    constructor: F,
340    mut bin: crate::expressions::BinaryOp,
341    strategy: NormalizationStrategy,
342) -> Expression
343where
344    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
345{
346    bin.left = normalize_expression(bin.left, strategy);
347    bin.right = normalize_expression(bin.right, strategy);
348    constructor(Box::new(bin))
349}
350
351/// Check if an identifier contains case-sensitive characters based on dialect rules.
352pub fn is_case_sensitive(text: &str, strategy: NormalizationStrategy) -> bool {
353    match strategy {
354        NormalizationStrategy::CaseInsensitive
355        | NormalizationStrategy::CaseInsensitiveUppercase => false,
356        NormalizationStrategy::Uppercase => text.chars().any(|c| c.is_lowercase()),
357        NormalizationStrategy::Lowercase => text.chars().any(|c| c.is_uppercase()),
358        NormalizationStrategy::CaseSensitive => true,
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL with a default generator.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql`, normalize its first statement for `dialect`, and render it.
    fn parse_and_normalize(sql: &str, dialect: Option<DialectType>) -> String {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        let normalized = normalize_identifiers(ast[0].clone(), dialect);
        gen(&normalized)
    }

    /// Build an identifier with no trailing comments.
    fn ident(name: &str, quoted: bool) -> Identifier {
        Identifier {
            name: name.to_string(),
            quoted,
            trailing_comments: vec![],
        }
    }

    #[test]
    fn test_normalize_lowercase() {
        // PostgreSQL-like: lowercase unquoted identifiers
        let result = parse_and_normalize("SELECT FoO FROM Bar", None);
        // Check the exact casing: accepting either "foo" or "FOO" would also
        // pass for an uppercasing strategy.
        assert!(result.contains("foo"), "column not lowercased: {}", result);
        assert!(!result.contains("FoO"), "original casing kept: {}", result);
    }

    #[test]
    fn test_normalize_uppercase() {
        // Snowflake: uppercase unquoted identifiers
        let result = parse_and_normalize("SELECT foo FROM bar", Some(DialectType::Snowflake));
        // Inspect the output as generated; uppercasing it first would make the
        // assertion pass even when nothing was normalized.
        assert!(result.contains("FOO"), "column not uppercased: {}", result);
        assert!(!result.contains("foo"), "original casing kept: {}", result);
    }

    #[test]
    fn test_normalize_preserves_quoted() {
        // Quoted identifiers should be preserved in non-case-insensitive dialects
        let lowered = normalize_identifier(ident("FoO", true), NormalizationStrategy::Lowercase);
        assert_eq!(lowered.name, "FoO"); // Preserved
        assert!(lowered.quoted);

        let uppered = normalize_identifier(ident("FoO", true), NormalizationStrategy::Uppercase);
        assert_eq!(uppered.name, "FoO"); // Preserved
        assert!(uppered.quoted);
    }

    #[test]
    fn test_case_insensitive_normalizes_quoted() {
        // In case-insensitive dialects, even quoted identifiers are normalized
        let normalized =
            normalize_identifier(ident("FoO", true), NormalizationStrategy::CaseInsensitive);
        assert_eq!(normalized.name, "foo"); // Lowercased
        assert!(normalized.quoted); // Quoting flag itself is untouched
    }

    #[test]
    fn test_case_insensitive_uppercase_normalizes_quoted() {
        // The uppercase flavour folds to upper case, quoted or not
        for quoted in [false, true] {
            let normalized = normalize_identifier(
                ident("FoO", quoted),
                NormalizationStrategy::CaseInsensitiveUppercase,
            );
            assert_eq!(normalized.name, "FOO");
            assert_eq!(normalized.quoted, quoted);
        }
    }

    #[test]
    fn test_case_sensitive_no_normalization() {
        // Case-sensitive dialects don't normalize at all
        let normalized =
            normalize_identifier(ident("FoO", false), NormalizationStrategy::CaseSensitive);
        assert_eq!(normalized.name, "FoO"); // Unchanged
    }

    #[test]
    fn test_normalize_column() {
        let col = Expression::Column(Column {
            name: Identifier::new("MyColumn"),
            table: Some(Identifier::new("MyTable")),
            join_mark: false,
            trailing_comments: vec![],
        });

        let normalized = normalize_expression(col, NormalizationStrategy::Lowercase);
        let sql = gen(&normalized);
        // Both parts of the qualified name must be normalized, not just one.
        assert!(sql.contains("mycolumn"), "column name not lowercased: {}", sql);
        assert!(sql.contains("mytable"), "table qualifier not lowercased: {}", sql);
    }

    #[test]
    fn test_get_normalization_strategy() {
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Snowflake)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::PostgreSQL)),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::MySQL)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::DuckDB)),
            NormalizationStrategy::CaseInsensitive
        );
        // No dialect falls back to the default strategy
        assert_eq!(
            get_normalization_strategy(None),
            NormalizationStrategy::default()
        );
    }

    #[test]
    fn test_is_case_sensitive() {
        assert!(!is_case_sensitive("FoO", NormalizationStrategy::CaseInsensitive));
        assert!(!is_case_sensitive(
            "FoO",
            NormalizationStrategy::CaseInsensitiveUppercase
        ));
        assert!(is_case_sensitive("foo", NormalizationStrategy::CaseSensitive));
        assert!(is_case_sensitive("FoO", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("FOO", NormalizationStrategy::Uppercase));
        assert!(is_case_sensitive("FoO", NormalizationStrategy::Lowercase));
        assert!(!is_case_sensitive("foo", NormalizationStrategy::Lowercase));
    }
}