Skip to main content

polyglot_sql/optimizer/
normalize_identifiers.rs

1//! Identifier Normalization Module
2//!
3//! This module provides functionality for normalizing identifiers in SQL queries
4//! based on dialect-specific rules for case sensitivity and quoting.
5//!
6//! Ported from sqlglot's optimizer/normalize_identifiers.py
7
8use crate::dialects::DialectType;
9use crate::expressions::{Column, Expression, Identifier};
10
/// Strategy for normalizing identifiers based on dialect rules.
///
/// The default is [`NormalizationStrategy::Lowercase`] (PostgreSQL-like behavior).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NormalizationStrategy {
    /// Unquoted identifiers are lowercased; quoted identifiers keep their case
    /// (e.g., PostgreSQL).
    #[default]
    Lowercase,
    /// Unquoted identifiers are uppercased; quoted identifiers keep their case
    /// (e.g., Oracle, Snowflake).
    Uppercase,
    /// Always case-sensitive, regardless of quotes: identifiers are never rewritten
    /// (e.g., MySQL on Linux).
    CaseSensitive,
    /// Always case-insensitive (lowercase), regardless of quotes: even quoted
    /// identifiers are lowercased (e.g., Spark, BigQuery).
    CaseInsensitive,
    /// Always case-insensitive (uppercase), regardless of quotes: even quoted
    /// identifiers are uppercased.
    CaseInsensitiveUppercase,
}
31
32/// Get the normalization strategy for a dialect
33pub fn get_normalization_strategy(dialect: Option<DialectType>) -> NormalizationStrategy {
34    match dialect {
35        // Uppercase dialects
36        Some(DialectType::Oracle) | Some(DialectType::Snowflake) | Some(DialectType::Exasol) => {
37            NormalizationStrategy::Uppercase
38        }
39        // Case-sensitive dialects
40        Some(DialectType::MySQL) | Some(DialectType::ClickHouse) => {
41            NormalizationStrategy::CaseSensitive
42        }
43        // Case-insensitive dialects (lowercase)
44        Some(DialectType::DuckDB)
45        | Some(DialectType::SQLite)
46        | Some(DialectType::BigQuery)
47        | Some(DialectType::Presto)
48        | Some(DialectType::Trino)
49        | Some(DialectType::Hive)
50        | Some(DialectType::Spark)
51        | Some(DialectType::Databricks)
52        | Some(DialectType::Redshift) => NormalizationStrategy::CaseInsensitive,
53        // Default: lowercase (PostgreSQL-like behavior)
54        _ => NormalizationStrategy::Lowercase,
55    }
56}
57
58/// Normalize identifiers in an expression based on dialect rules.
59///
60/// This transformation reflects how identifiers would be resolved by the engine
61/// corresponding to each SQL dialect. For example:
62/// - `FoO` → `foo` in PostgreSQL (lowercases unquoted)
63/// - `FoO` → `FOO` in Snowflake (uppercases unquoted)
64/// - `"FoO"` → `FoO` preserved when quoted (case-sensitive)
65///
66/// # Arguments
67/// * `expression` - The expression to normalize
68/// * `dialect` - The dialect to use for normalization rules
69///
70/// # Returns
71/// The expression with normalized identifiers
72pub fn normalize_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
73    let strategy = get_normalization_strategy(dialect);
74    normalize_expression(expression, strategy)
75}
76
77/// Normalize a single identifier based on the strategy.
78///
79/// # Arguments
80/// * `identifier` - The identifier to normalize
81/// * `strategy` - The normalization strategy to use
82///
83/// # Returns
84/// The normalized identifier
85pub fn normalize_identifier(identifier: Identifier, strategy: NormalizationStrategy) -> Identifier {
86    // Case-sensitive strategy: never normalize
87    if strategy == NormalizationStrategy::CaseSensitive {
88        return identifier;
89    }
90
91    // If quoted and not case-insensitive, don't normalize
92    if identifier.quoted
93        && strategy != NormalizationStrategy::CaseInsensitive
94        && strategy != NormalizationStrategy::CaseInsensitiveUppercase
95    {
96        return identifier;
97    }
98
99    // Normalize the identifier name
100    let normalized_name = match strategy {
101        NormalizationStrategy::Uppercase | NormalizationStrategy::CaseInsensitiveUppercase => {
102            identifier.name.to_uppercase()
103        }
104        NormalizationStrategy::Lowercase | NormalizationStrategy::CaseInsensitive => {
105            identifier.name.to_lowercase()
106        }
107        NormalizationStrategy::CaseSensitive => identifier.name, // Should not reach here
108    };
109
110    Identifier {
111        name: normalized_name,
112        quoted: identifier.quoted,
113        trailing_comments: identifier.trailing_comments,
114        span: None,
115    }
116}
117
118/// Recursively normalize all identifiers in an expression.
119fn normalize_expression(expression: Expression, strategy: NormalizationStrategy) -> Expression {
120    match expression {
121        Expression::Identifier(id) => Expression::Identifier(normalize_identifier(id, strategy)),
122        Expression::Column(col) => Expression::Column(Column {
123            name: normalize_identifier(col.name, strategy),
124            table: col.table.map(|t| normalize_identifier(t, strategy)),
125            join_mark: col.join_mark,
126            trailing_comments: col.trailing_comments,
127            span: None,
128        }),
129        Expression::Table(mut table) => {
130            table.name = normalize_identifier(table.name, strategy);
131            if let Some(schema) = table.schema {
132                table.schema = Some(normalize_identifier(schema, strategy));
133            }
134            if let Some(catalog) = table.catalog {
135                table.catalog = Some(normalize_identifier(catalog, strategy));
136            }
137            if let Some(alias) = table.alias {
138                table.alias = Some(normalize_identifier(alias, strategy));
139            }
140            table.column_aliases = table
141                .column_aliases
142                .into_iter()
143                .map(|a| normalize_identifier(a, strategy))
144                .collect();
145            Expression::Table(table)
146        }
147        Expression::Select(select) => {
148            let mut select = *select;
149            // Normalize SELECT expressions
150            select.expressions = select
151                .expressions
152                .into_iter()
153                .map(|e| normalize_expression(e, strategy))
154                .collect();
155            // Normalize FROM
156            if let Some(mut from) = select.from {
157                from.expressions = from
158                    .expressions
159                    .into_iter()
160                    .map(|e| normalize_expression(e, strategy))
161                    .collect();
162                select.from = Some(from);
163            }
164            // Normalize JOINs
165            select.joins = select
166                .joins
167                .into_iter()
168                .map(|mut j| {
169                    j.this = normalize_expression(j.this, strategy);
170                    if let Some(on) = j.on {
171                        j.on = Some(normalize_expression(on, strategy));
172                    }
173                    j
174                })
175                .collect();
176            // Normalize WHERE
177            if let Some(mut where_clause) = select.where_clause {
178                where_clause.this = normalize_expression(where_clause.this, strategy);
179                select.where_clause = Some(where_clause);
180            }
181            // Normalize GROUP BY
182            if let Some(mut group_by) = select.group_by {
183                group_by.expressions = group_by
184                    .expressions
185                    .into_iter()
186                    .map(|e| normalize_expression(e, strategy))
187                    .collect();
188                select.group_by = Some(group_by);
189            }
190            // Normalize HAVING
191            if let Some(mut having) = select.having {
192                having.this = normalize_expression(having.this, strategy);
193                select.having = Some(having);
194            }
195            // Normalize ORDER BY
196            if let Some(mut order_by) = select.order_by {
197                order_by.expressions = order_by
198                    .expressions
199                    .into_iter()
200                    .map(|mut o| {
201                        o.this = normalize_expression(o.this, strategy);
202                        o
203                    })
204                    .collect();
205                select.order_by = Some(order_by);
206            }
207            Expression::Select(Box::new(select))
208        }
209        Expression::Alias(alias) => {
210            let mut alias = *alias;
211            alias.this = normalize_expression(alias.this, strategy);
212            alias.alias = normalize_identifier(alias.alias, strategy);
213            Expression::Alias(Box::new(alias))
214        }
215        // Binary operations
216        Expression::And(bin) => normalize_binary(Expression::And, *bin, strategy),
217        Expression::Or(bin) => normalize_binary(Expression::Or, *bin, strategy),
218        Expression::Add(bin) => normalize_binary(Expression::Add, *bin, strategy),
219        Expression::Sub(bin) => normalize_binary(Expression::Sub, *bin, strategy),
220        Expression::Mul(bin) => normalize_binary(Expression::Mul, *bin, strategy),
221        Expression::Div(bin) => normalize_binary(Expression::Div, *bin, strategy),
222        Expression::Mod(bin) => normalize_binary(Expression::Mod, *bin, strategy),
223        Expression::Eq(bin) => normalize_binary(Expression::Eq, *bin, strategy),
224        Expression::Neq(bin) => normalize_binary(Expression::Neq, *bin, strategy),
225        Expression::Lt(bin) => normalize_binary(Expression::Lt, *bin, strategy),
226        Expression::Lte(bin) => normalize_binary(Expression::Lte, *bin, strategy),
227        Expression::Gt(bin) => normalize_binary(Expression::Gt, *bin, strategy),
228        Expression::Gte(bin) => normalize_binary(Expression::Gte, *bin, strategy),
229        Expression::Concat(bin) => normalize_binary(Expression::Concat, *bin, strategy),
230        // Unary operations
231        Expression::Not(un) => {
232            let mut un = *un;
233            un.this = normalize_expression(un.this, strategy);
234            Expression::Not(Box::new(un))
235        }
236        Expression::Neg(un) => {
237            let mut un = *un;
238            un.this = normalize_expression(un.this, strategy);
239            Expression::Neg(Box::new(un))
240        }
241        // Functions
242        Expression::Function(func) => {
243            let mut func = *func;
244            func.args = func
245                .args
246                .into_iter()
247                .map(|e| normalize_expression(e, strategy))
248                .collect();
249            Expression::Function(Box::new(func))
250        }
251        Expression::AggregateFunction(agg) => {
252            let mut agg = *agg;
253            agg.args = agg
254                .args
255                .into_iter()
256                .map(|e| normalize_expression(e, strategy))
257                .collect();
258            Expression::AggregateFunction(Box::new(agg))
259        }
260        // Other expressions with children
261        Expression::Paren(paren) => {
262            let mut paren = *paren;
263            paren.this = normalize_expression(paren.this, strategy);
264            Expression::Paren(Box::new(paren))
265        }
266        Expression::Case(case) => {
267            let mut case = *case;
268            case.operand = case.operand.map(|e| normalize_expression(e, strategy));
269            case.whens = case
270                .whens
271                .into_iter()
272                .map(|(w, t)| {
273                    (
274                        normalize_expression(w, strategy),
275                        normalize_expression(t, strategy),
276                    )
277                })
278                .collect();
279            case.else_ = case.else_.map(|e| normalize_expression(e, strategy));
280            Expression::Case(Box::new(case))
281        }
282        Expression::Cast(cast) => {
283            let mut cast = *cast;
284            cast.this = normalize_expression(cast.this, strategy);
285            Expression::Cast(Box::new(cast))
286        }
287        Expression::In(in_expr) => {
288            let mut in_expr = *in_expr;
289            in_expr.this = normalize_expression(in_expr.this, strategy);
290            in_expr.expressions = in_expr
291                .expressions
292                .into_iter()
293                .map(|e| normalize_expression(e, strategy))
294                .collect();
295            if let Some(q) = in_expr.query {
296                in_expr.query = Some(normalize_expression(q, strategy));
297            }
298            Expression::In(Box::new(in_expr))
299        }
300        Expression::Between(between) => {
301            let mut between = *between;
302            between.this = normalize_expression(between.this, strategy);
303            between.low = normalize_expression(between.low, strategy);
304            between.high = normalize_expression(between.high, strategy);
305            Expression::Between(Box::new(between))
306        }
307        Expression::Subquery(subquery) => {
308            let mut subquery = *subquery;
309            subquery.this = normalize_expression(subquery.this, strategy);
310            if let Some(alias) = subquery.alias {
311                subquery.alias = Some(normalize_identifier(alias, strategy));
312            }
313            Expression::Subquery(Box::new(subquery))
314        }
315        // Set operations
316        Expression::Union(union) => {
317            let mut union = *union;
318            union.left = normalize_expression(union.left, strategy);
319            union.right = normalize_expression(union.right, strategy);
320            Expression::Union(Box::new(union))
321        }
322        Expression::Intersect(intersect) => {
323            let mut intersect = *intersect;
324            intersect.left = normalize_expression(intersect.left, strategy);
325            intersect.right = normalize_expression(intersect.right, strategy);
326            Expression::Intersect(Box::new(intersect))
327        }
328        Expression::Except(except) => {
329            let mut except = *except;
330            except.left = normalize_expression(except.left, strategy);
331            except.right = normalize_expression(except.right, strategy);
332            Expression::Except(Box::new(except))
333        }
334        // Leaf nodes and others - return unchanged
335        _ => expression,
336    }
337}
338
339/// Helper to normalize binary operations
340fn normalize_binary<F>(
341    constructor: F,
342    mut bin: crate::expressions::BinaryOp,
343    strategy: NormalizationStrategy,
344) -> Expression
345where
346    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
347{
348    bin.left = normalize_expression(bin.left, strategy);
349    bin.right = normalize_expression(bin.right, strategy);
350    constructor(Box::new(bin))
351}
352
353/// Check if an identifier contains case-sensitive characters based on dialect rules.
354pub fn is_case_sensitive(text: &str, strategy: NormalizationStrategy) -> bool {
355    match strategy {
356        NormalizationStrategy::CaseInsensitive
357        | NormalizationStrategy::CaseInsensitiveUppercase => false,
358        NormalizationStrategy::Uppercase => text.chars().any(|c| c.is_lowercase()),
359        NormalizationStrategy::Lowercase => text.chars().any(|c| c.is_uppercase()),
360        NormalizationStrategy::CaseSensitive => true,
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression to SQL with the default generator.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql`, normalize its first statement for `dialect`, and render it back.
    fn parse_and_normalize(sql: &str, dialect: Option<DialectType>) -> String {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        let normalized = normalize_identifiers(ast[0].clone(), dialect);
        gen(&normalized)
    }

    /// Build a bare identifier for the unit tests below.
    fn ident(name: &str, quoted: bool) -> Identifier {
        Identifier {
            name: name.to_string(),
            quoted,
            trailing_comments: vec![],
            span: None,
        }
    }

    #[test]
    fn test_default_strategy_is_lowercase() {
        assert_eq!(NormalizationStrategy::default(), NormalizationStrategy::Lowercase);
    }

    #[test]
    fn test_normalize_lowercase() {
        // Default (PostgreSQL-like): unquoted identifiers are lowercased.
        let result = parse_and_normalize("SELECT FoO FROM Bar", None);
        assert!(result.contains("foo"), "unexpected SQL: {}", result);
        assert!(result.contains("bar"), "unexpected SQL: {}", result);
        assert!(!result.contains("FoO"), "unexpected SQL: {}", result);
    }

    #[test]
    fn test_normalize_uppercase() {
        // Snowflake: unquoted identifiers are uppercased. Check the raw output; an
        // uppercased copy would contain "FOO" whether or not normalization happened.
        let result = parse_and_normalize("SELECT foo FROM bar", Some(DialectType::Snowflake));
        assert!(result.contains("FOO"), "unexpected SQL: {}", result);
        assert!(result.contains("BAR"), "unexpected SQL: {}", result);
    }

    #[test]
    fn test_normalize_preserves_quoted() {
        // Quoted identifiers keep their case under the fold-unquoted-only strategies.
        let lower = normalize_identifier(ident("FoO", true), NormalizationStrategy::Lowercase);
        assert_eq!(lower.name, "FoO");
        let upper = normalize_identifier(ident("FoO", true), NormalizationStrategy::Uppercase);
        assert_eq!(upper.name, "FoO");
    }

    #[test]
    fn test_case_insensitive_normalizes_quoted() {
        // In case-insensitive dialects, even quoted identifiers are normalized.
        let lower = normalize_identifier(ident("FoO", true), NormalizationStrategy::CaseInsensitive);
        assert_eq!(lower.name, "foo");
        assert!(lower.quoted);

        let upper = normalize_identifier(
            ident("FoO", true),
            NormalizationStrategy::CaseInsensitiveUppercase,
        );
        assert_eq!(upper.name, "FOO");
        assert!(upper.quoted);
    }

    #[test]
    fn test_case_sensitive_no_normalization() {
        // Case-sensitive dialects never rewrite identifiers, quoted or not.
        for quoted in [false, true] {
            let normalized =
                normalize_identifier(ident("FoO", quoted), NormalizationStrategy::CaseSensitive);
            assert_eq!(normalized.name, "FoO");
        }
    }

    #[test]
    fn test_normalize_column() {
        let col = Expression::Column(Column {
            name: Identifier::new("MyColumn"),
            table: Some(Identifier::new("MyTable")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
        });

        let normalized = normalize_expression(col, NormalizationStrategy::Lowercase);
        let sql = gen(&normalized);
        // Both parts must be normalized, not just one of them.
        assert!(sql.contains("mycolumn"), "unexpected SQL: {}", sql);
        assert!(sql.contains("mytable"), "unexpected SQL: {}", sql);
    }

    #[test]
    fn test_get_normalization_strategy() {
        assert_eq!(get_normalization_strategy(None), NormalizationStrategy::Lowercase);
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Snowflake)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Oracle)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::PostgreSQL)),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::MySQL)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::ClickHouse)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::DuckDB)),
            NormalizationStrategy::CaseInsensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::BigQuery)),
            NormalizationStrategy::CaseInsensitive
        );
    }

    #[test]
    fn test_is_case_sensitive() {
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Lowercase));
        assert!(!is_case_sensitive("foo", NormalizationStrategy::Lowercase));
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("FOO", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("Foo", NormalizationStrategy::CaseInsensitive));
        assert!(!is_case_sensitive("Foo", NormalizationStrategy::CaseInsensitiveUppercase));
        assert!(is_case_sensitive("foo", NormalizationStrategy::CaseSensitive));
    }
}
464            get_normalization_strategy(Some(DialectType::DuckDB)),
465            NormalizationStrategy::CaseInsensitive
466        );
467    }
468}