1use crate::dialects::DialectType;
9use crate::expressions::{Column, Expression, Identifier};
10
/// How a SQL dialect folds the case of identifiers.
///
/// The per-variant behaviour is applied by `normalize_identifier` in this module.
/// `Hash` is derived so strategies can be used as map/set keys (e.g. caches
/// keyed by strategy).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NormalizationStrategy {
    /// Unquoted identifiers are lowercased; quoted identifiers are kept as written.
    Lowercase,
    /// Unquoted identifiers are uppercased; quoted identifiers are kept as written.
    Uppercase,
    /// Identifiers are never rewritten, quoted or not.
    CaseSensitive,
    /// Every identifier, quoted ones included, is lowercased.
    CaseInsensitive,
    /// Every identifier, quoted ones included, is uppercased.
    CaseInsensitiveUppercase,
}
25
26impl Default for NormalizationStrategy {
27 fn default() -> Self {
28 Self::Lowercase
29 }
30}
31
32pub fn get_normalization_strategy(dialect: Option<DialectType>) -> NormalizationStrategy {
34 match dialect {
35 Some(DialectType::Oracle) | Some(DialectType::Snowflake) | Some(DialectType::Exasol) => {
37 NormalizationStrategy::Uppercase
38 }
39 Some(DialectType::MySQL) | Some(DialectType::ClickHouse) => {
41 NormalizationStrategy::CaseSensitive
42 }
43 Some(DialectType::DuckDB)
45 | Some(DialectType::SQLite)
46 | Some(DialectType::BigQuery)
47 | Some(DialectType::Presto)
48 | Some(DialectType::Trino)
49 | Some(DialectType::Hive)
50 | Some(DialectType::Spark)
51 | Some(DialectType::Databricks)
52 | Some(DialectType::Redshift) => NormalizationStrategy::CaseInsensitive,
53 _ => NormalizationStrategy::Lowercase,
55 }
56}
57
58pub fn normalize_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
73 let strategy = get_normalization_strategy(dialect);
74 normalize_expression(expression, strategy)
75}
76
77pub fn normalize_identifier(identifier: Identifier, strategy: NormalizationStrategy) -> Identifier {
86 if strategy == NormalizationStrategy::CaseSensitive {
88 return identifier;
89 }
90
91 if identifier.quoted
93 && strategy != NormalizationStrategy::CaseInsensitive
94 && strategy != NormalizationStrategy::CaseInsensitiveUppercase
95 {
96 return identifier;
97 }
98
99 let normalized_name = match strategy {
101 NormalizationStrategy::Uppercase | NormalizationStrategy::CaseInsensitiveUppercase => {
102 identifier.name.to_uppercase()
103 }
104 NormalizationStrategy::Lowercase | NormalizationStrategy::CaseInsensitive => {
105 identifier.name.to_lowercase()
106 }
107 NormalizationStrategy::CaseSensitive => identifier.name, };
109
110 Identifier {
111 name: normalized_name,
112 quoted: identifier.quoted,
113 trailing_comments: identifier.trailing_comments,
114 span: None,
115 }
116}
117
118fn normalize_expression(expression: Expression, strategy: NormalizationStrategy) -> Expression {
120 match expression {
121 Expression::Identifier(id) => Expression::Identifier(normalize_identifier(id, strategy)),
122 Expression::Column(col) => Expression::Column(Column {
123 name: normalize_identifier(col.name, strategy),
124 table: col.table.map(|t| normalize_identifier(t, strategy)),
125 join_mark: col.join_mark,
126 trailing_comments: col.trailing_comments,
127 span: None,
128 inferred_type: None,
129 }),
130 Expression::Table(mut table) => {
131 table.name = normalize_identifier(table.name, strategy);
132 if let Some(schema) = table.schema {
133 table.schema = Some(normalize_identifier(schema, strategy));
134 }
135 if let Some(catalog) = table.catalog {
136 table.catalog = Some(normalize_identifier(catalog, strategy));
137 }
138 if let Some(alias) = table.alias {
139 table.alias = Some(normalize_identifier(alias, strategy));
140 }
141 table.column_aliases = table
142 .column_aliases
143 .into_iter()
144 .map(|a| normalize_identifier(a, strategy))
145 .collect();
146 Expression::Table(table)
147 }
148 Expression::Select(select) => {
149 let mut select = *select;
150 select.expressions = select
152 .expressions
153 .into_iter()
154 .map(|e| normalize_expression(e, strategy))
155 .collect();
156 if let Some(mut from) = select.from {
158 from.expressions = from
159 .expressions
160 .into_iter()
161 .map(|e| normalize_expression(e, strategy))
162 .collect();
163 select.from = Some(from);
164 }
165 select.joins = select
167 .joins
168 .into_iter()
169 .map(|mut j| {
170 j.this = normalize_expression(j.this, strategy);
171 if let Some(on) = j.on {
172 j.on = Some(normalize_expression(on, strategy));
173 }
174 j
175 })
176 .collect();
177 if let Some(mut where_clause) = select.where_clause {
179 where_clause.this = normalize_expression(where_clause.this, strategy);
180 select.where_clause = Some(where_clause);
181 }
182 if let Some(mut group_by) = select.group_by {
184 group_by.expressions = group_by
185 .expressions
186 .into_iter()
187 .map(|e| normalize_expression(e, strategy))
188 .collect();
189 select.group_by = Some(group_by);
190 }
191 if let Some(mut having) = select.having {
193 having.this = normalize_expression(having.this, strategy);
194 select.having = Some(having);
195 }
196 if let Some(mut order_by) = select.order_by {
198 order_by.expressions = order_by
199 .expressions
200 .into_iter()
201 .map(|mut o| {
202 o.this = normalize_expression(o.this, strategy);
203 o
204 })
205 .collect();
206 select.order_by = Some(order_by);
207 }
208 Expression::Select(Box::new(select))
209 }
210 Expression::Alias(alias) => {
211 let mut alias = *alias;
212 alias.this = normalize_expression(alias.this, strategy);
213 alias.alias = normalize_identifier(alias.alias, strategy);
214 Expression::Alias(Box::new(alias))
215 }
216 Expression::And(bin) => normalize_binary(Expression::And, *bin, strategy),
218 Expression::Or(bin) => normalize_binary(Expression::Or, *bin, strategy),
219 Expression::Add(bin) => normalize_binary(Expression::Add, *bin, strategy),
220 Expression::Sub(bin) => normalize_binary(Expression::Sub, *bin, strategy),
221 Expression::Mul(bin) => normalize_binary(Expression::Mul, *bin, strategy),
222 Expression::Div(bin) => normalize_binary(Expression::Div, *bin, strategy),
223 Expression::Mod(bin) => normalize_binary(Expression::Mod, *bin, strategy),
224 Expression::Eq(bin) => normalize_binary(Expression::Eq, *bin, strategy),
225 Expression::Neq(bin) => normalize_binary(Expression::Neq, *bin, strategy),
226 Expression::Lt(bin) => normalize_binary(Expression::Lt, *bin, strategy),
227 Expression::Lte(bin) => normalize_binary(Expression::Lte, *bin, strategy),
228 Expression::Gt(bin) => normalize_binary(Expression::Gt, *bin, strategy),
229 Expression::Gte(bin) => normalize_binary(Expression::Gte, *bin, strategy),
230 Expression::Concat(bin) => normalize_binary(Expression::Concat, *bin, strategy),
231 Expression::Not(un) => {
233 let mut un = *un;
234 un.this = normalize_expression(un.this, strategy);
235 Expression::Not(Box::new(un))
236 }
237 Expression::Neg(un) => {
238 let mut un = *un;
239 un.this = normalize_expression(un.this, strategy);
240 Expression::Neg(Box::new(un))
241 }
242 Expression::Function(func) => {
244 let mut func = *func;
245 func.args = func
246 .args
247 .into_iter()
248 .map(|e| normalize_expression(e, strategy))
249 .collect();
250 Expression::Function(Box::new(func))
251 }
252 Expression::AggregateFunction(agg) => {
253 let mut agg = *agg;
254 agg.args = agg
255 .args
256 .into_iter()
257 .map(|e| normalize_expression(e, strategy))
258 .collect();
259 Expression::AggregateFunction(Box::new(agg))
260 }
261 Expression::Paren(paren) => {
263 let mut paren = *paren;
264 paren.this = normalize_expression(paren.this, strategy);
265 Expression::Paren(Box::new(paren))
266 }
267 Expression::Case(case) => {
268 let mut case = *case;
269 case.operand = case.operand.map(|e| normalize_expression(e, strategy));
270 case.whens = case
271 .whens
272 .into_iter()
273 .map(|(w, t)| {
274 (
275 normalize_expression(w, strategy),
276 normalize_expression(t, strategy),
277 )
278 })
279 .collect();
280 case.else_ = case.else_.map(|e| normalize_expression(e, strategy));
281 Expression::Case(Box::new(case))
282 }
283 Expression::Cast(cast) => {
284 let mut cast = *cast;
285 cast.this = normalize_expression(cast.this, strategy);
286 Expression::Cast(Box::new(cast))
287 }
288 Expression::In(in_expr) => {
289 let mut in_expr = *in_expr;
290 in_expr.this = normalize_expression(in_expr.this, strategy);
291 in_expr.expressions = in_expr
292 .expressions
293 .into_iter()
294 .map(|e| normalize_expression(e, strategy))
295 .collect();
296 if let Some(q) = in_expr.query {
297 in_expr.query = Some(normalize_expression(q, strategy));
298 }
299 Expression::In(Box::new(in_expr))
300 }
301 Expression::Between(between) => {
302 let mut between = *between;
303 between.this = normalize_expression(between.this, strategy);
304 between.low = normalize_expression(between.low, strategy);
305 between.high = normalize_expression(between.high, strategy);
306 Expression::Between(Box::new(between))
307 }
308 Expression::Subquery(subquery) => {
309 let mut subquery = *subquery;
310 subquery.this = normalize_expression(subquery.this, strategy);
311 if let Some(alias) = subquery.alias {
312 subquery.alias = Some(normalize_identifier(alias, strategy));
313 }
314 Expression::Subquery(Box::new(subquery))
315 }
316 Expression::Union(union) => {
318 let mut union = *union;
319 union.left = normalize_expression(union.left, strategy);
320 union.right = normalize_expression(union.right, strategy);
321 Expression::Union(Box::new(union))
322 }
323 Expression::Intersect(intersect) => {
324 let mut intersect = *intersect;
325 intersect.left = normalize_expression(intersect.left, strategy);
326 intersect.right = normalize_expression(intersect.right, strategy);
327 Expression::Intersect(Box::new(intersect))
328 }
329 Expression::Except(except) => {
330 let mut except = *except;
331 except.left = normalize_expression(except.left, strategy);
332 except.right = normalize_expression(except.right, strategy);
333 Expression::Except(Box::new(except))
334 }
335 _ => expression,
337 }
338}
339
340fn normalize_binary<F>(
342 constructor: F,
343 mut bin: crate::expressions::BinaryOp,
344 strategy: NormalizationStrategy,
345) -> Expression
346where
347 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
348{
349 bin.left = normalize_expression(bin.left, strategy);
350 bin.right = normalize_expression(bin.right, strategy);
351 constructor(Box::new(bin))
352}
353
354pub fn is_case_sensitive(text: &str, strategy: NormalizationStrategy) -> bool {
356 match strategy {
357 NormalizationStrategy::CaseInsensitive
358 | NormalizationStrategy::CaseInsensitiveUppercase => false,
359 NormalizationStrategy::Uppercase => text.chars().any(|c| c.is_lowercase()),
360 NormalizationStrategy::Lowercase => text.chars().any(|c| c.is_uppercase()),
361 NormalizationStrategy::CaseSensitive => true,
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    // Named `generate_sql` rather than `gen`, which is a reserved keyword in the
    // Rust 2024 edition.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    fn parse_and_normalize(sql: &str, dialect: Option<DialectType>) -> String {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        let normalized = normalize_identifiers(ast[0].clone(), dialect);
        generate_sql(&normalized)
    }

    /// Builds an identifier with no comments and no span.
    fn ident(name: &str, quoted: bool) -> Identifier {
        Identifier {
            name: name.to_string(),
            quoted,
            trailing_comments: vec![],
            span: None,
        }
    }

    /// Builds the column expression `table.name` from unquoted identifiers.
    fn column(name: &str, table: &str) -> Expression {
        Expression::Column(Column {
            name: Identifier::new(name),
            table: Some(Identifier::new(table)),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        })
    }

    /// Extracts `(column name, table name)` from a column expression.
    fn column_parts(expr: Expression) -> (String, Option<String>) {
        match expr {
            Expression::Column(col) => (col.name.name, col.table.map(|t| t.name)),
            _ => panic!("expected Expression::Column"),
        }
    }

    #[test]
    fn test_normalize_lowercase() {
        let result = parse_and_normalize("SELECT FoO FROM Bar", None);
        assert!(result.contains("foo"), "unexpected SQL: {}", result);
        assert!(!result.contains("FoO"), "column not normalized: {}", result);
        assert!(!result.contains("Bar"), "table not normalized: {}", result);
    }

    #[test]
    fn test_normalize_uppercase() {
        // Checked on the AST: the previous string check upper-cased the output
        // before searching for "FOO", so it could never fail.
        let normalized =
            normalize_identifiers(column("foo", "bar"), Some(DialectType::Snowflake));
        assert_eq!(
            column_parts(normalized),
            ("FOO".to_string(), Some("BAR".to_string()))
        );
    }

    #[test]
    fn test_normalize_preserves_quoted() {
        let lowered = normalize_identifier(ident("FoO", true), NormalizationStrategy::Lowercase);
        assert_eq!(lowered.name, "FoO");
        let raised = normalize_identifier(ident("FoO", true), NormalizationStrategy::Uppercase);
        assert_eq!(raised.name, "FoO");
    }

    #[test]
    fn test_case_insensitive_normalizes_quoted() {
        let lowered =
            normalize_identifier(ident("FoO", true), NormalizationStrategy::CaseInsensitive);
        assert_eq!(lowered.name, "foo");
        let raised = normalize_identifier(
            ident("FoO", true),
            NormalizationStrategy::CaseInsensitiveUppercase,
        );
        assert_eq!(raised.name, "FOO");
    }

    #[test]
    fn test_case_sensitive_no_normalization() {
        let normalized =
            normalize_identifier(ident("FoO", false), NormalizationStrategy::CaseSensitive);
        assert_eq!(normalized.name, "FoO");
    }

    #[test]
    fn test_normalize_column() {
        // Both the column and its table qualifier must be folded.
        let normalized =
            normalize_expression(column("MyColumn", "MyTable"), NormalizationStrategy::Lowercase);
        assert_eq!(
            column_parts(normalized),
            ("mycolumn".to_string(), Some("mytable".to_string()))
        );
    }

    #[test]
    fn test_get_normalization_strategy() {
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Snowflake)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::PostgreSQL)),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::MySQL)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::DuckDB)),
            NormalizationStrategy::CaseInsensitive
        );
        assert_eq!(get_normalization_strategy(None), NormalizationStrategy::Lowercase);
    }

    #[test]
    fn test_is_case_sensitive() {
        assert!(!is_case_sensitive("foo", NormalizationStrategy::Lowercase));
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Lowercase));
        assert!(!is_case_sensitive("FOO", NormalizationStrategy::Uppercase));
        assert!(is_case_sensitive("FOo", NormalizationStrategy::Uppercase));
        assert!(is_case_sensitive("foo", NormalizationStrategy::CaseSensitive));
        assert!(!is_case_sensitive("FoO", NormalizationStrategy::CaseInsensitive));
    }
}