1use crate::dialects::DialectType;
9use crate::expressions::{Column, Expression, Identifier, Null};
10
/// How identifier names are case-folded during normalization.
///
/// The per-variant behavior described below is the one implemented by
/// `normalize_identifier` in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormalizationStrategy {
    /// Unquoted identifiers are lowercased; quoted identifiers are left as written.
    Lowercase,
    /// Unquoted identifiers are uppercased; quoted identifiers are left as written.
    Uppercase,
    /// Identifiers are never modified, whether quoted or not.
    CaseSensitive,
    /// Every identifier is lowercased, including quoted ones.
    CaseInsensitive,
    /// Every identifier is uppercased, including quoted ones.
    CaseInsensitiveUppercase,
}
25
26impl Default for NormalizationStrategy {
27 fn default() -> Self {
28 Self::Lowercase
29 }
30}
31
32pub fn get_normalization_strategy(dialect: Option<DialectType>) -> NormalizationStrategy {
34 match dialect {
35 Some(DialectType::Oracle) | Some(DialectType::Snowflake) | Some(DialectType::Exasol) => {
37 NormalizationStrategy::Uppercase
38 }
39 Some(DialectType::MySQL) | Some(DialectType::ClickHouse) => {
41 NormalizationStrategy::CaseSensitive
42 }
43 Some(DialectType::DuckDB)
45 | Some(DialectType::SQLite)
46 | Some(DialectType::BigQuery)
47 | Some(DialectType::Presto)
48 | Some(DialectType::Trino)
49 | Some(DialectType::Hive)
50 | Some(DialectType::Spark)
51 | Some(DialectType::Databricks)
52 | Some(DialectType::Redshift) => NormalizationStrategy::CaseInsensitive,
53 _ => NormalizationStrategy::Lowercase,
55 }
56}
57
58pub fn normalize_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
73 let strategy = get_normalization_strategy(dialect);
74 normalize_expression(expression, strategy)
75}
76
77pub fn normalize_identifier(identifier: Identifier, strategy: NormalizationStrategy) -> Identifier {
86 if strategy == NormalizationStrategy::CaseSensitive {
88 return identifier;
89 }
90
91 if identifier.quoted
93 && strategy != NormalizationStrategy::CaseInsensitive
94 && strategy != NormalizationStrategy::CaseInsensitiveUppercase
95 {
96 return identifier;
97 }
98
99 let normalized_name = match strategy {
101 NormalizationStrategy::Uppercase | NormalizationStrategy::CaseInsensitiveUppercase => {
102 identifier.name.to_uppercase()
103 }
104 NormalizationStrategy::Lowercase | NormalizationStrategy::CaseInsensitive => {
105 identifier.name.to_lowercase()
106 }
107 NormalizationStrategy::CaseSensitive => identifier.name, };
109
110 Identifier {
111 name: normalized_name,
112 quoted: identifier.quoted,
113 trailing_comments: identifier.trailing_comments,
114 span: None,
115 }
116}
117
118fn normalize_expression(expression: Expression, strategy: NormalizationStrategy) -> Expression {
120 match expression {
121 Expression::Identifier(id) => Expression::Identifier(normalize_identifier(id, strategy)),
122 Expression::Column(col) => Expression::boxed_column(Column {
123 name: normalize_identifier(col.name, strategy),
124 table: col.table.map(|t| normalize_identifier(t, strategy)),
125 join_mark: col.join_mark,
126 trailing_comments: col.trailing_comments,
127 span: None,
128 inferred_type: None,
129 }),
130 Expression::Table(mut table) => {
131 table.name = normalize_identifier(table.name, strategy);
132 if let Some(schema) = table.schema {
133 table.schema = Some(normalize_identifier(schema, strategy));
134 }
135 if let Some(catalog) = table.catalog {
136 table.catalog = Some(normalize_identifier(catalog, strategy));
137 }
138 if let Some(alias) = table.alias {
139 table.alias = Some(normalize_identifier(alias, strategy));
140 }
141 table.column_aliases = table
142 .column_aliases
143 .into_iter()
144 .map(|a| normalize_identifier(a, strategy))
145 .collect();
146 Expression::Table(table)
147 }
148 Expression::Select(select) => {
149 let mut select = *select;
150 select.expressions = select
152 .expressions
153 .into_iter()
154 .map(|e| normalize_expression(e, strategy))
155 .collect();
156 if let Some(mut from) = select.from {
158 from.expressions = from
159 .expressions
160 .into_iter()
161 .map(|e| normalize_expression(e, strategy))
162 .collect();
163 select.from = Some(from);
164 }
165 select.joins = select
167 .joins
168 .into_iter()
169 .map(|mut j| {
170 j.this = normalize_expression(j.this, strategy);
171 if let Some(on) = j.on {
172 j.on = Some(normalize_expression(on, strategy));
173 }
174 j
175 })
176 .collect();
177 if let Some(mut where_clause) = select.where_clause {
179 where_clause.this = normalize_expression(where_clause.this, strategy);
180 select.where_clause = Some(where_clause);
181 }
182 if let Some(mut group_by) = select.group_by {
184 group_by.expressions = group_by
185 .expressions
186 .into_iter()
187 .map(|e| normalize_expression(e, strategy))
188 .collect();
189 select.group_by = Some(group_by);
190 }
191 if let Some(mut having) = select.having {
193 having.this = normalize_expression(having.this, strategy);
194 select.having = Some(having);
195 }
196 if let Some(mut order_by) = select.order_by {
198 order_by.expressions = order_by
199 .expressions
200 .into_iter()
201 .map(|mut o| {
202 o.this = normalize_expression(o.this, strategy);
203 o
204 })
205 .collect();
206 select.order_by = Some(order_by);
207 }
208 Expression::Select(Box::new(select))
209 }
210 Expression::Alias(alias) => {
211 let mut alias = *alias;
212 alias.this = normalize_expression(alias.this, strategy);
213 alias.alias = normalize_identifier(alias.alias, strategy);
214 Expression::Alias(Box::new(alias))
215 }
216 Expression::And(bin) => normalize_binary(Expression::And, *bin, strategy),
218 Expression::Or(bin) => normalize_binary(Expression::Or, *bin, strategy),
219 Expression::Add(bin) => normalize_binary(Expression::Add, *bin, strategy),
220 Expression::Sub(bin) => normalize_binary(Expression::Sub, *bin, strategy),
221 Expression::Mul(bin) => normalize_binary(Expression::Mul, *bin, strategy),
222 Expression::Div(bin) => normalize_binary(Expression::Div, *bin, strategy),
223 Expression::Mod(bin) => normalize_binary(Expression::Mod, *bin, strategy),
224 Expression::Eq(bin) => normalize_binary(Expression::Eq, *bin, strategy),
225 Expression::Neq(bin) => normalize_binary(Expression::Neq, *bin, strategy),
226 Expression::Lt(bin) => normalize_binary(Expression::Lt, *bin, strategy),
227 Expression::Lte(bin) => normalize_binary(Expression::Lte, *bin, strategy),
228 Expression::Gt(bin) => normalize_binary(Expression::Gt, *bin, strategy),
229 Expression::Gte(bin) => normalize_binary(Expression::Gte, *bin, strategy),
230 Expression::Concat(bin) => normalize_binary(Expression::Concat, *bin, strategy),
231 Expression::Not(un) => {
233 let mut un = *un;
234 un.this = normalize_expression(un.this, strategy);
235 Expression::Not(Box::new(un))
236 }
237 Expression::Neg(un) => {
238 let mut un = *un;
239 un.this = normalize_expression(un.this, strategy);
240 Expression::Neg(Box::new(un))
241 }
242 Expression::Function(func) => {
244 let mut func = *func;
245 func.args = func
246 .args
247 .into_iter()
248 .map(|e| normalize_expression(e, strategy))
249 .collect();
250 Expression::Function(Box::new(func))
251 }
252 Expression::AggregateFunction(agg) => {
253 let mut agg = *agg;
254 agg.args = agg
255 .args
256 .into_iter()
257 .map(|e| normalize_expression(e, strategy))
258 .collect();
259 Expression::AggregateFunction(Box::new(agg))
260 }
261 Expression::Paren(paren) => {
263 let mut paren = *paren;
264 paren.this = normalize_expression(paren.this, strategy);
265 Expression::Paren(Box::new(paren))
266 }
267 Expression::Case(case) => {
268 let mut case = *case;
269 case.operand = case.operand.map(|e| normalize_expression(e, strategy));
270 case.whens = case
271 .whens
272 .into_iter()
273 .map(|(w, t)| {
274 (
275 normalize_expression(w, strategy),
276 normalize_expression(t, strategy),
277 )
278 })
279 .collect();
280 case.else_ = case.else_.map(|e| normalize_expression(e, strategy));
281 Expression::Case(Box::new(case))
282 }
283 Expression::Cast(cast) => {
284 let mut cast = *cast;
285 cast.this = normalize_expression(cast.this, strategy);
286 Expression::Cast(Box::new(cast))
287 }
288 Expression::In(in_expr) => {
289 let mut in_expr = *in_expr;
290 in_expr.this = normalize_expression(in_expr.this, strategy);
291 in_expr.expressions = in_expr
292 .expressions
293 .into_iter()
294 .map(|e| normalize_expression(e, strategy))
295 .collect();
296 if let Some(q) = in_expr.query {
297 in_expr.query = Some(normalize_expression(q, strategy));
298 }
299 Expression::In(Box::new(in_expr))
300 }
301 Expression::Between(between) => {
302 let mut between = *between;
303 between.this = normalize_expression(between.this, strategy);
304 between.low = normalize_expression(between.low, strategy);
305 between.high = normalize_expression(between.high, strategy);
306 Expression::Between(Box::new(between))
307 }
308 Expression::Subquery(subquery) => {
309 let mut subquery = *subquery;
310 subquery.this = normalize_expression(subquery.this, strategy);
311 if let Some(alias) = subquery.alias {
312 subquery.alias = Some(normalize_identifier(alias, strategy));
313 }
314 Expression::Subquery(Box::new(subquery))
315 }
316 Expression::Union(mut union) => {
318 let left = std::mem::replace(&mut union.left, Expression::Null(Null));
319 union.left = normalize_expression(left, strategy);
320 let right = std::mem::replace(&mut union.right, Expression::Null(Null));
321 union.right = normalize_expression(right, strategy);
322 Expression::Union(union)
323 }
324 Expression::Intersect(mut intersect) => {
325 let left = std::mem::replace(&mut intersect.left, Expression::Null(Null));
326 intersect.left = normalize_expression(left, strategy);
327 let right = std::mem::replace(&mut intersect.right, Expression::Null(Null));
328 intersect.right = normalize_expression(right, strategy);
329 Expression::Intersect(intersect)
330 }
331 Expression::Except(mut except) => {
332 let left = std::mem::replace(&mut except.left, Expression::Null(Null));
333 except.left = normalize_expression(left, strategy);
334 let right = std::mem::replace(&mut except.right, Expression::Null(Null));
335 except.right = normalize_expression(right, strategy);
336 Expression::Except(except)
337 }
338 _ => expression,
340 }
341}
342
343fn normalize_binary<F>(
345 constructor: F,
346 mut bin: crate::expressions::BinaryOp,
347 strategy: NormalizationStrategy,
348) -> Expression
349where
350 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
351{
352 bin.left = normalize_expression(bin.left, strategy);
353 bin.right = normalize_expression(bin.right, strategy);
354 constructor(Box::new(bin))
355}
356
357pub fn is_case_sensitive(text: &str, strategy: NormalizationStrategy) -> bool {
359 match strategy {
360 NormalizationStrategy::CaseInsensitive
361 | NormalizationStrategy::CaseInsensitiveUppercase => false,
362 NormalizationStrategy::Uppercase => text.chars().any(|c| c.is_lowercase()),
363 NormalizationStrategy::Lowercase => text.chars().any(|c| c.is_uppercase()),
364 NormalizationStrategy::CaseSensitive => true,
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Generates SQL text for `expr` with a freshly constructed `Generator`.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, normalizes the first parsed statement for `dialect` and
    /// renders the result back to SQL text.
    fn parse_and_normalize(sql: &str, dialect: Option<DialectType>) -> String {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        render(&normalize_identifiers(statements[0].clone(), dialect))
    }

    /// Builds an identifier named `name` with no trailing comments and no span.
    fn ident(name: &str, quoted: bool) -> Identifier {
        Identifier {
            name: name.to_string(),
            quoted,
            trailing_comments: vec![],
            span: None,
        }
    }

    #[test]
    fn test_normalize_lowercase() {
        let sql = parse_and_normalize("SELECT FoO FROM Bar", None);
        // Either casing of the folded name is accepted in the rendered output.
        assert!(sql.contains("foo") || sql.contains("FOO"));
    }

    #[test]
    fn test_normalize_uppercase() {
        let sql = parse_and_normalize("SELECT foo FROM bar", Some(DialectType::Snowflake));
        assert!(sql.to_uppercase().contains("FOO"));
    }

    #[test]
    fn test_normalize_preserves_quoted() {
        let out = normalize_identifier(ident("FoO", true), NormalizationStrategy::Lowercase);
        assert_eq!(out.name, "FoO");
    }

    #[test]
    fn test_case_insensitive_normalizes_quoted() {
        let out = normalize_identifier(ident("FoO", true), NormalizationStrategy::CaseInsensitive);
        assert_eq!(out.name, "foo");
    }

    #[test]
    fn test_case_sensitive_no_normalization() {
        let out = normalize_identifier(ident("FoO", false), NormalizationStrategy::CaseSensitive);
        assert_eq!(out.name, "FoO");
    }

    #[test]
    fn test_normalize_column() {
        let column = Column {
            name: Identifier::new("MyColumn"),
            table: Some(Identifier::new("MyTable")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        };

        let normalized =
            normalize_expression(Expression::boxed_column(column), NormalizationStrategy::Lowercase);
        let sql = render(&normalized);
        assert!(sql.contains("mycolumn") || sql.contains("mytable"));
    }

    #[test]
    fn test_get_normalization_strategy() {
        // One representative dialect per strategy returned by the lookup.
        let expectations = [
            (DialectType::Snowflake, NormalizationStrategy::Uppercase),
            (DialectType::PostgreSQL, NormalizationStrategy::Lowercase),
            (DialectType::MySQL, NormalizationStrategy::CaseSensitive),
            (DialectType::DuckDB, NormalizationStrategy::CaseInsensitive),
        ];
        for (dialect, expected) in expectations {
            assert_eq!(get_normalization_strategy(Some(dialect)), expected);
        }
    }
}