Skip to main content

polyglot_sql/optimizer/
normalize_identifiers.rs

1//! Identifier Normalization Module
2//!
3//! This module provides functionality for normalizing identifiers in SQL queries
4//! based on dialect-specific rules for case sensitivity and quoting.
5//!
6//! Ported from sqlglot's optimizer/normalize_identifiers.py
7
8use crate::dialects::DialectType;
9use crate::expressions::{Column, Expression, Identifier};
10
/// Strategy for normalizing identifiers based on dialect rules.
///
/// The default strategy is [`NormalizationStrategy::Lowercase`], i.e.
/// PostgreSQL-like resolution of unquoted identifiers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NormalizationStrategy {
    /// Unquoted identifiers are lowercased; quoted ones are kept as written (e.g., PostgreSQL)
    #[default]
    Lowercase,
    /// Unquoted identifiers are uppercased; quoted ones are kept as written (e.g., Oracle, Snowflake)
    Uppercase,
    /// Identifiers are never case-folded, quoted or not (e.g., MySQL on Linux)
    CaseSensitive,
    /// Always case-insensitive (lowercase), regardless of quotes (e.g., Spark, BigQuery)
    CaseInsensitive,
    /// Always case-insensitive (uppercase), regardless of quotes
    CaseInsensitiveUppercase,
}
31
32/// Get the normalization strategy for a dialect
33pub fn get_normalization_strategy(dialect: Option<DialectType>) -> NormalizationStrategy {
34    match dialect {
35        // Uppercase dialects
36        Some(DialectType::Oracle) | Some(DialectType::Snowflake) | Some(DialectType::Exasol) => {
37            NormalizationStrategy::Uppercase
38        }
39        // Case-sensitive dialects
40        Some(DialectType::MySQL) | Some(DialectType::ClickHouse) => {
41            NormalizationStrategy::CaseSensitive
42        }
43        // Case-insensitive dialects (lowercase)
44        Some(DialectType::DuckDB)
45        | Some(DialectType::SQLite)
46        | Some(DialectType::BigQuery)
47        | Some(DialectType::Presto)
48        | Some(DialectType::Trino)
49        | Some(DialectType::Hive)
50        | Some(DialectType::Spark)
51        | Some(DialectType::Databricks)
52        | Some(DialectType::Redshift) => NormalizationStrategy::CaseInsensitive,
53        // Default: lowercase (PostgreSQL-like behavior)
54        _ => NormalizationStrategy::Lowercase,
55    }
56}
57
58/// Normalize identifiers in an expression based on dialect rules.
59///
60/// This transformation reflects how identifiers would be resolved by the engine
61/// corresponding to each SQL dialect. For example:
62/// - `FoO` → `foo` in PostgreSQL (lowercases unquoted)
63/// - `FoO` → `FOO` in Snowflake (uppercases unquoted)
64/// - `"FoO"` → `FoO` preserved when quoted (case-sensitive)
65///
66/// # Arguments
67/// * `expression` - The expression to normalize
68/// * `dialect` - The dialect to use for normalization rules
69///
70/// # Returns
71/// The expression with normalized identifiers
72pub fn normalize_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
73    let strategy = get_normalization_strategy(dialect);
74    normalize_expression(expression, strategy)
75}
76
77/// Normalize a single identifier based on the strategy.
78///
79/// # Arguments
80/// * `identifier` - The identifier to normalize
81/// * `strategy` - The normalization strategy to use
82///
83/// # Returns
84/// The normalized identifier
85pub fn normalize_identifier(identifier: Identifier, strategy: NormalizationStrategy) -> Identifier {
86    // Case-sensitive strategy: never normalize
87    if strategy == NormalizationStrategy::CaseSensitive {
88        return identifier;
89    }
90
91    // If quoted and not case-insensitive, don't normalize
92    if identifier.quoted
93        && strategy != NormalizationStrategy::CaseInsensitive
94        && strategy != NormalizationStrategy::CaseInsensitiveUppercase
95    {
96        return identifier;
97    }
98
99    // Normalize the identifier name
100    let normalized_name = match strategy {
101        NormalizationStrategy::Uppercase | NormalizationStrategy::CaseInsensitiveUppercase => {
102            identifier.name.to_uppercase()
103        }
104        NormalizationStrategy::Lowercase | NormalizationStrategy::CaseInsensitive => {
105            identifier.name.to_lowercase()
106        }
107        NormalizationStrategy::CaseSensitive => identifier.name, // Should not reach here
108    };
109
110    Identifier {
111        name: normalized_name,
112        quoted: identifier.quoted,
113        trailing_comments: identifier.trailing_comments,
114    }
115}
116
117/// Recursively normalize all identifiers in an expression.
118fn normalize_expression(expression: Expression, strategy: NormalizationStrategy) -> Expression {
119    match expression {
120        Expression::Identifier(id) => {
121            Expression::Identifier(normalize_identifier(id, strategy))
122        }
123        Expression::Column(col) => {
124            Expression::Column(Column {
125                name: normalize_identifier(col.name, strategy),
126                table: col.table.map(|t| normalize_identifier(t, strategy)),
127                join_mark: col.join_mark,
128                trailing_comments: col.trailing_comments,
129            })
130        }
131        Expression::Table(mut table) => {
132            table.name = normalize_identifier(table.name, strategy);
133            if let Some(schema) = table.schema {
134                table.schema = Some(normalize_identifier(schema, strategy));
135            }
136            if let Some(catalog) = table.catalog {
137                table.catalog = Some(normalize_identifier(catalog, strategy));
138            }
139            if let Some(alias) = table.alias {
140                table.alias = Some(normalize_identifier(alias, strategy));
141            }
142            table.column_aliases = table
143                .column_aliases
144                .into_iter()
145                .map(|a| normalize_identifier(a, strategy))
146                .collect();
147            Expression::Table(table)
148        }
149        Expression::Select(select) => {
150            let mut select = *select;
151            // Normalize SELECT expressions
152            select.expressions = select
153                .expressions
154                .into_iter()
155                .map(|e| normalize_expression(e, strategy))
156                .collect();
157            // Normalize FROM
158            if let Some(mut from) = select.from {
159                from.expressions = from
160                    .expressions
161                    .into_iter()
162                    .map(|e| normalize_expression(e, strategy))
163                    .collect();
164                select.from = Some(from);
165            }
166            // Normalize JOINs
167            select.joins = select
168                .joins
169                .into_iter()
170                .map(|mut j| {
171                    j.this = normalize_expression(j.this, strategy);
172                    if let Some(on) = j.on {
173                        j.on = Some(normalize_expression(on, strategy));
174                    }
175                    j
176                })
177                .collect();
178            // Normalize WHERE
179            if let Some(mut where_clause) = select.where_clause {
180                where_clause.this = normalize_expression(where_clause.this, strategy);
181                select.where_clause = Some(where_clause);
182            }
183            // Normalize GROUP BY
184            if let Some(mut group_by) = select.group_by {
185                group_by.expressions = group_by
186                    .expressions
187                    .into_iter()
188                    .map(|e| normalize_expression(e, strategy))
189                    .collect();
190                select.group_by = Some(group_by);
191            }
192            // Normalize HAVING
193            if let Some(mut having) = select.having {
194                having.this = normalize_expression(having.this, strategy);
195                select.having = Some(having);
196            }
197            // Normalize ORDER BY
198            if let Some(mut order_by) = select.order_by {
199                order_by.expressions = order_by
200                    .expressions
201                    .into_iter()
202                    .map(|mut o| {
203                        o.this = normalize_expression(o.this, strategy);
204                        o
205                    })
206                    .collect();
207                select.order_by = Some(order_by);
208            }
209            Expression::Select(Box::new(select))
210        }
211        Expression::Alias(alias) => {
212            let mut alias = *alias;
213            alias.this = normalize_expression(alias.this, strategy);
214            alias.alias = normalize_identifier(alias.alias, strategy);
215            Expression::Alias(Box::new(alias))
216        }
217        // Binary operations
218        Expression::And(bin) => normalize_binary(Expression::And, *bin, strategy),
219        Expression::Or(bin) => normalize_binary(Expression::Or, *bin, strategy),
220        Expression::Add(bin) => normalize_binary(Expression::Add, *bin, strategy),
221        Expression::Sub(bin) => normalize_binary(Expression::Sub, *bin, strategy),
222        Expression::Mul(bin) => normalize_binary(Expression::Mul, *bin, strategy),
223        Expression::Div(bin) => normalize_binary(Expression::Div, *bin, strategy),
224        Expression::Mod(bin) => normalize_binary(Expression::Mod, *bin, strategy),
225        Expression::Eq(bin) => normalize_binary(Expression::Eq, *bin, strategy),
226        Expression::Neq(bin) => normalize_binary(Expression::Neq, *bin, strategy),
227        Expression::Lt(bin) => normalize_binary(Expression::Lt, *bin, strategy),
228        Expression::Lte(bin) => normalize_binary(Expression::Lte, *bin, strategy),
229        Expression::Gt(bin) => normalize_binary(Expression::Gt, *bin, strategy),
230        Expression::Gte(bin) => normalize_binary(Expression::Gte, *bin, strategy),
231        Expression::Concat(bin) => normalize_binary(Expression::Concat, *bin, strategy),
232        // Unary operations
233        Expression::Not(un) => {
234            let mut un = *un;
235            un.this = normalize_expression(un.this, strategy);
236            Expression::Not(Box::new(un))
237        }
238        Expression::Neg(un) => {
239            let mut un = *un;
240            un.this = normalize_expression(un.this, strategy);
241            Expression::Neg(Box::new(un))
242        }
243        // Functions
244        Expression::Function(func) => {
245            let mut func = *func;
246            func.args = func
247                .args
248                .into_iter()
249                .map(|e| normalize_expression(e, strategy))
250                .collect();
251            Expression::Function(Box::new(func))
252        }
253        Expression::AggregateFunction(agg) => {
254            let mut agg = *agg;
255            agg.args = agg
256                .args
257                .into_iter()
258                .map(|e| normalize_expression(e, strategy))
259                .collect();
260            Expression::AggregateFunction(Box::new(agg))
261        }
262        // Other expressions with children
263        Expression::Paren(paren) => {
264            let mut paren = *paren;
265            paren.this = normalize_expression(paren.this, strategy);
266            Expression::Paren(Box::new(paren))
267        }
268        Expression::Case(case) => {
269            let mut case = *case;
270            case.operand = case.operand.map(|e| normalize_expression(e, strategy));
271            case.whens = case
272                .whens
273                .into_iter()
274                .map(|(w, t)| {
275                    (
276                        normalize_expression(w, strategy),
277                        normalize_expression(t, strategy),
278                    )
279                })
280                .collect();
281            case.else_ = case.else_.map(|e| normalize_expression(e, strategy));
282            Expression::Case(Box::new(case))
283        }
284        Expression::Cast(cast) => {
285            let mut cast = *cast;
286            cast.this = normalize_expression(cast.this, strategy);
287            Expression::Cast(Box::new(cast))
288        }
289        Expression::In(in_expr) => {
290            let mut in_expr = *in_expr;
291            in_expr.this = normalize_expression(in_expr.this, strategy);
292            in_expr.expressions = in_expr
293                .expressions
294                .into_iter()
295                .map(|e| normalize_expression(e, strategy))
296                .collect();
297            if let Some(q) = in_expr.query {
298                in_expr.query = Some(normalize_expression(q, strategy));
299            }
300            Expression::In(Box::new(in_expr))
301        }
302        Expression::Between(between) => {
303            let mut between = *between;
304            between.this = normalize_expression(between.this, strategy);
305            between.low = normalize_expression(between.low, strategy);
306            between.high = normalize_expression(between.high, strategy);
307            Expression::Between(Box::new(between))
308        }
309        Expression::Subquery(subquery) => {
310            let mut subquery = *subquery;
311            subquery.this = normalize_expression(subquery.this, strategy);
312            if let Some(alias) = subquery.alias {
313                subquery.alias = Some(normalize_identifier(alias, strategy));
314            }
315            Expression::Subquery(Box::new(subquery))
316        }
317        // Set operations
318        Expression::Union(union) => {
319            let mut union = *union;
320            union.left = normalize_expression(union.left, strategy);
321            union.right = normalize_expression(union.right, strategy);
322            Expression::Union(Box::new(union))
323        }
324        Expression::Intersect(intersect) => {
325            let mut intersect = *intersect;
326            intersect.left = normalize_expression(intersect.left, strategy);
327            intersect.right = normalize_expression(intersect.right, strategy);
328            Expression::Intersect(Box::new(intersect))
329        }
330        Expression::Except(except) => {
331            let mut except = *except;
332            except.left = normalize_expression(except.left, strategy);
333            except.right = normalize_expression(except.right, strategy);
334            Expression::Except(Box::new(except))
335        }
336        // Leaf nodes and others - return unchanged
337        _ => expression,
338    }
339}
340
341/// Helper to normalize binary operations
342fn normalize_binary<F>(
343    constructor: F,
344    mut bin: crate::expressions::BinaryOp,
345    strategy: NormalizationStrategy,
346) -> Expression
347where
348    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
349{
350    bin.left = normalize_expression(bin.left, strategy);
351    bin.right = normalize_expression(bin.right, strategy);
352    constructor(Box::new(bin))
353}
354
355/// Check if an identifier contains case-sensitive characters based on dialect rules.
356pub fn is_case_sensitive(text: &str, strategy: NormalizationStrategy) -> bool {
357    match strategy {
358        NormalizationStrategy::CaseInsensitive | NormalizationStrategy::CaseInsensitiveUppercase => {
359            false
360        }
361        NormalizationStrategy::Uppercase => {
362            text.chars().any(|c| c.is_lowercase())
363        }
364        NormalizationStrategy::Lowercase => {
365            text.chars().any(|c| c.is_uppercase())
366        }
367        NormalizationStrategy::CaseSensitive => true,
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL with the default generator.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql`, normalize its first statement for `dialect`, and render it.
    fn parse_and_normalize(sql: &str, dialect: Option<DialectType>) -> String {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        let normalized = normalize_identifiers(ast[0].clone(), dialect);
        gen(&normalized)
    }

    /// Build a quoted identifier named `FoO` for the quoting tests.
    fn quoted_foo() -> Identifier {
        Identifier {
            name: "FoO".to_string(),
            quoted: true,
            trailing_comments: vec![],
        }
    }

    #[test]
    fn test_normalize_lowercase() {
        // PostgreSQL-like default: unquoted identifiers are lowercased.
        let result = parse_and_normalize("SELECT FoO FROM Bar", None);
        assert!(result.contains("foo"), "got: {}", result);
        assert!(result.contains("bar"), "got: {}", result);
        assert!(!result.contains("FoO") && !result.contains("Bar"), "got: {}", result);
    }

    #[test]
    fn test_normalize_uppercase() {
        // Snowflake: unquoted identifiers are uppercased.
        let result = parse_and_normalize("SELECT foo FROM bar", Some(DialectType::Snowflake));
        assert!(result.contains("FOO"), "got: {}", result);
        assert!(result.contains("BAR"), "got: {}", result);
        assert!(!result.contains("foo") && !result.contains("bar"), "got: {}", result);
    }

    #[test]
    fn test_normalize_preserves_quoted() {
        // Quoted identifiers keep their spelling in quote-respecting dialects.
        let normalized = normalize_identifier(quoted_foo(), NormalizationStrategy::Lowercase);
        assert_eq!(normalized.name, "FoO");
        assert!(normalized.quoted);
    }

    #[test]
    fn test_uppercase_preserves_quoted() {
        let normalized = normalize_identifier(quoted_foo(), NormalizationStrategy::Uppercase);
        assert_eq!(normalized.name, "FoO");
    }

    #[test]
    fn test_case_insensitive_normalizes_quoted() {
        // In case-insensitive dialects, even quoted identifiers are normalized.
        let normalized = normalize_identifier(quoted_foo(), NormalizationStrategy::CaseInsensitive);
        assert_eq!(normalized.name, "foo");
        assert!(normalized.quoted);
    }

    #[test]
    fn test_case_insensitive_uppercase_normalizes_quoted() {
        let normalized =
            normalize_identifier(quoted_foo(), NormalizationStrategy::CaseInsensitiveUppercase);
        assert_eq!(normalized.name, "FOO");
    }

    #[test]
    fn test_case_sensitive_no_normalization() {
        // Case-sensitive dialects don't normalize at all.
        let id = Identifier {
            name: "FoO".to_string(),
            quoted: false,
            trailing_comments: vec![],
        };
        let normalized = normalize_identifier(id, NormalizationStrategy::CaseSensitive);
        assert_eq!(normalized.name, "FoO");
    }

    #[test]
    fn test_normalize_column() {
        let col = Expression::Column(Column {
            name: Identifier::new("MyColumn"),
            table: Some(Identifier::new("MyTable")),
            join_mark: false,
            trailing_comments: vec![],
        });

        let normalized = normalize_expression(col, NormalizationStrategy::Lowercase);
        let sql = gen(&normalized);
        // Both the column and its table qualifier must be folded.
        assert!(sql.contains("mycolumn"), "got: {}", sql);
        assert!(sql.contains("mytable"), "got: {}", sql);
    }

    #[test]
    fn test_default_strategy_is_lowercase() {
        assert_eq!(NormalizationStrategy::default(), NormalizationStrategy::Lowercase);
    }

    #[test]
    fn test_is_case_sensitive() {
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Lowercase));
        assert!(!is_case_sensitive("foo", NormalizationStrategy::Lowercase));
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("FOO", NormalizationStrategy::Uppercase));
        assert!(is_case_sensitive("foo", NormalizationStrategy::CaseSensitive));
        assert!(!is_case_sensitive("FoO", NormalizationStrategy::CaseInsensitive));
    }

    #[test]
    fn test_get_normalization_strategy() {
        assert_eq!(get_normalization_strategy(None), NormalizationStrategy::Lowercase);
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Snowflake)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Oracle)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::PostgreSQL)),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::MySQL)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::ClickHouse)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::DuckDB)),
            NormalizationStrategy::CaseInsensitive
        );
    }
}