1use crate::dialects::DialectType;
9use crate::expressions::{Column, Expression, Identifier};
10
/// How identifier casing is folded when a SQL expression tree is normalized.
///
/// The effect of each variant is implemented by `normalize_identifier`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormalizationStrategy {
    /// Unquoted identifiers are lowercased; quoted identifiers are kept as written.
    Lowercase,
    /// Unquoted identifiers are uppercased; quoted identifiers are kept as written.
    Uppercase,
    /// Identifiers are never modified, quoted or not.
    CaseSensitive,
    /// Every identifier is lowercased, including quoted ones.
    CaseInsensitive,
    /// Every identifier is uppercased, including quoted ones.
    CaseInsensitiveUppercase,
}
25
26impl Default for NormalizationStrategy {
27 fn default() -> Self {
28 Self::Lowercase
29 }
30}
31
32pub fn get_normalization_strategy(dialect: Option<DialectType>) -> NormalizationStrategy {
34 match dialect {
35 Some(DialectType::Oracle) | Some(DialectType::Snowflake) | Some(DialectType::Exasol) => {
37 NormalizationStrategy::Uppercase
38 }
39 Some(DialectType::MySQL) | Some(DialectType::ClickHouse) => {
41 NormalizationStrategy::CaseSensitive
42 }
43 Some(DialectType::DuckDB)
45 | Some(DialectType::SQLite)
46 | Some(DialectType::BigQuery)
47 | Some(DialectType::Presto)
48 | Some(DialectType::Trino)
49 | Some(DialectType::Hive)
50 | Some(DialectType::Spark)
51 | Some(DialectType::Databricks)
52 | Some(DialectType::Redshift) => NormalizationStrategy::CaseInsensitive,
53 _ => NormalizationStrategy::Lowercase,
55 }
56}
57
58pub fn normalize_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
73 let strategy = get_normalization_strategy(dialect);
74 normalize_expression(expression, strategy)
75}
76
77pub fn normalize_identifier(identifier: Identifier, strategy: NormalizationStrategy) -> Identifier {
86 if strategy == NormalizationStrategy::CaseSensitive {
88 return identifier;
89 }
90
91 if identifier.quoted
93 && strategy != NormalizationStrategy::CaseInsensitive
94 && strategy != NormalizationStrategy::CaseInsensitiveUppercase
95 {
96 return identifier;
97 }
98
99 let normalized_name = match strategy {
101 NormalizationStrategy::Uppercase | NormalizationStrategy::CaseInsensitiveUppercase => {
102 identifier.name.to_uppercase()
103 }
104 NormalizationStrategy::Lowercase | NormalizationStrategy::CaseInsensitive => {
105 identifier.name.to_lowercase()
106 }
107 NormalizationStrategy::CaseSensitive => identifier.name, };
109
110 Identifier {
111 name: normalized_name,
112 quoted: identifier.quoted,
113 trailing_comments: identifier.trailing_comments,
114 }
115}
116
117fn normalize_expression(expression: Expression, strategy: NormalizationStrategy) -> Expression {
119 match expression {
120 Expression::Identifier(id) => {
121 Expression::Identifier(normalize_identifier(id, strategy))
122 }
123 Expression::Column(col) => {
124 Expression::Column(Column {
125 name: normalize_identifier(col.name, strategy),
126 table: col.table.map(|t| normalize_identifier(t, strategy)),
127 join_mark: col.join_mark,
128 trailing_comments: col.trailing_comments,
129 })
130 }
131 Expression::Table(mut table) => {
132 table.name = normalize_identifier(table.name, strategy);
133 if let Some(schema) = table.schema {
134 table.schema = Some(normalize_identifier(schema, strategy));
135 }
136 if let Some(catalog) = table.catalog {
137 table.catalog = Some(normalize_identifier(catalog, strategy));
138 }
139 if let Some(alias) = table.alias {
140 table.alias = Some(normalize_identifier(alias, strategy));
141 }
142 table.column_aliases = table
143 .column_aliases
144 .into_iter()
145 .map(|a| normalize_identifier(a, strategy))
146 .collect();
147 Expression::Table(table)
148 }
149 Expression::Select(select) => {
150 let mut select = *select;
151 select.expressions = select
153 .expressions
154 .into_iter()
155 .map(|e| normalize_expression(e, strategy))
156 .collect();
157 if let Some(mut from) = select.from {
159 from.expressions = from
160 .expressions
161 .into_iter()
162 .map(|e| normalize_expression(e, strategy))
163 .collect();
164 select.from = Some(from);
165 }
166 select.joins = select
168 .joins
169 .into_iter()
170 .map(|mut j| {
171 j.this = normalize_expression(j.this, strategy);
172 if let Some(on) = j.on {
173 j.on = Some(normalize_expression(on, strategy));
174 }
175 j
176 })
177 .collect();
178 if let Some(mut where_clause) = select.where_clause {
180 where_clause.this = normalize_expression(where_clause.this, strategy);
181 select.where_clause = Some(where_clause);
182 }
183 if let Some(mut group_by) = select.group_by {
185 group_by.expressions = group_by
186 .expressions
187 .into_iter()
188 .map(|e| normalize_expression(e, strategy))
189 .collect();
190 select.group_by = Some(group_by);
191 }
192 if let Some(mut having) = select.having {
194 having.this = normalize_expression(having.this, strategy);
195 select.having = Some(having);
196 }
197 if let Some(mut order_by) = select.order_by {
199 order_by.expressions = order_by
200 .expressions
201 .into_iter()
202 .map(|mut o| {
203 o.this = normalize_expression(o.this, strategy);
204 o
205 })
206 .collect();
207 select.order_by = Some(order_by);
208 }
209 Expression::Select(Box::new(select))
210 }
211 Expression::Alias(alias) => {
212 let mut alias = *alias;
213 alias.this = normalize_expression(alias.this, strategy);
214 alias.alias = normalize_identifier(alias.alias, strategy);
215 Expression::Alias(Box::new(alias))
216 }
217 Expression::And(bin) => normalize_binary(Expression::And, *bin, strategy),
219 Expression::Or(bin) => normalize_binary(Expression::Or, *bin, strategy),
220 Expression::Add(bin) => normalize_binary(Expression::Add, *bin, strategy),
221 Expression::Sub(bin) => normalize_binary(Expression::Sub, *bin, strategy),
222 Expression::Mul(bin) => normalize_binary(Expression::Mul, *bin, strategy),
223 Expression::Div(bin) => normalize_binary(Expression::Div, *bin, strategy),
224 Expression::Mod(bin) => normalize_binary(Expression::Mod, *bin, strategy),
225 Expression::Eq(bin) => normalize_binary(Expression::Eq, *bin, strategy),
226 Expression::Neq(bin) => normalize_binary(Expression::Neq, *bin, strategy),
227 Expression::Lt(bin) => normalize_binary(Expression::Lt, *bin, strategy),
228 Expression::Lte(bin) => normalize_binary(Expression::Lte, *bin, strategy),
229 Expression::Gt(bin) => normalize_binary(Expression::Gt, *bin, strategy),
230 Expression::Gte(bin) => normalize_binary(Expression::Gte, *bin, strategy),
231 Expression::Concat(bin) => normalize_binary(Expression::Concat, *bin, strategy),
232 Expression::Not(un) => {
234 let mut un = *un;
235 un.this = normalize_expression(un.this, strategy);
236 Expression::Not(Box::new(un))
237 }
238 Expression::Neg(un) => {
239 let mut un = *un;
240 un.this = normalize_expression(un.this, strategy);
241 Expression::Neg(Box::new(un))
242 }
243 Expression::Function(func) => {
245 let mut func = *func;
246 func.args = func
247 .args
248 .into_iter()
249 .map(|e| normalize_expression(e, strategy))
250 .collect();
251 Expression::Function(Box::new(func))
252 }
253 Expression::AggregateFunction(agg) => {
254 let mut agg = *agg;
255 agg.args = agg
256 .args
257 .into_iter()
258 .map(|e| normalize_expression(e, strategy))
259 .collect();
260 Expression::AggregateFunction(Box::new(agg))
261 }
262 Expression::Paren(paren) => {
264 let mut paren = *paren;
265 paren.this = normalize_expression(paren.this, strategy);
266 Expression::Paren(Box::new(paren))
267 }
268 Expression::Case(case) => {
269 let mut case = *case;
270 case.operand = case.operand.map(|e| normalize_expression(e, strategy));
271 case.whens = case
272 .whens
273 .into_iter()
274 .map(|(w, t)| {
275 (
276 normalize_expression(w, strategy),
277 normalize_expression(t, strategy),
278 )
279 })
280 .collect();
281 case.else_ = case.else_.map(|e| normalize_expression(e, strategy));
282 Expression::Case(Box::new(case))
283 }
284 Expression::Cast(cast) => {
285 let mut cast = *cast;
286 cast.this = normalize_expression(cast.this, strategy);
287 Expression::Cast(Box::new(cast))
288 }
289 Expression::In(in_expr) => {
290 let mut in_expr = *in_expr;
291 in_expr.this = normalize_expression(in_expr.this, strategy);
292 in_expr.expressions = in_expr
293 .expressions
294 .into_iter()
295 .map(|e| normalize_expression(e, strategy))
296 .collect();
297 if let Some(q) = in_expr.query {
298 in_expr.query = Some(normalize_expression(q, strategy));
299 }
300 Expression::In(Box::new(in_expr))
301 }
302 Expression::Between(between) => {
303 let mut between = *between;
304 between.this = normalize_expression(between.this, strategy);
305 between.low = normalize_expression(between.low, strategy);
306 between.high = normalize_expression(between.high, strategy);
307 Expression::Between(Box::new(between))
308 }
309 Expression::Subquery(subquery) => {
310 let mut subquery = *subquery;
311 subquery.this = normalize_expression(subquery.this, strategy);
312 if let Some(alias) = subquery.alias {
313 subquery.alias = Some(normalize_identifier(alias, strategy));
314 }
315 Expression::Subquery(Box::new(subquery))
316 }
317 Expression::Union(union) => {
319 let mut union = *union;
320 union.left = normalize_expression(union.left, strategy);
321 union.right = normalize_expression(union.right, strategy);
322 Expression::Union(Box::new(union))
323 }
324 Expression::Intersect(intersect) => {
325 let mut intersect = *intersect;
326 intersect.left = normalize_expression(intersect.left, strategy);
327 intersect.right = normalize_expression(intersect.right, strategy);
328 Expression::Intersect(Box::new(intersect))
329 }
330 Expression::Except(except) => {
331 let mut except = *except;
332 except.left = normalize_expression(except.left, strategy);
333 except.right = normalize_expression(except.right, strategy);
334 Expression::Except(Box::new(except))
335 }
336 _ => expression,
338 }
339}
340
341fn normalize_binary<F>(
343 constructor: F,
344 mut bin: crate::expressions::BinaryOp,
345 strategy: NormalizationStrategy,
346) -> Expression
347where
348 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
349{
350 bin.left = normalize_expression(bin.left, strategy);
351 bin.right = normalize_expression(bin.right, strategy);
352 constructor(Box::new(bin))
353}
354
355pub fn is_case_sensitive(text: &str, strategy: NormalizationStrategy) -> bool {
357 match strategy {
358 NormalizationStrategy::CaseInsensitive | NormalizationStrategy::CaseInsensitiveUppercase => {
359 false
360 }
361 NormalizationStrategy::Uppercase => {
362 text.chars().any(|c| c.is_lowercase())
363 }
364 NormalizationStrategy::Lowercase => {
365 text.chars().any(|c| c.is_uppercase())
366 }
367 NormalizationStrategy::CaseSensitive => true,
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL text with a freshly created generator.
    ///
    /// Not named `gen`, because `gen` is a reserved keyword in the Rust 2024 edition.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, normalizes its first statement for `dialect` and returns
    /// the regenerated SQL text.
    fn parse_and_normalize(sql: &str, dialect: Option<DialectType>) -> String {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        let normalized = normalize_identifiers(ast[0].clone(), dialect);
        generate_sql(&normalized)
    }

    #[test]
    fn test_normalize_lowercase() {
        let result = parse_and_normalize("SELECT FoO FROM Bar", None);
        // NOTE(review): accepting "FOO" as well means this only fails when the
        // mixed-case spelling survives; tighten once generator casing is confirmed.
        assert!(result.contains("foo") || result.contains("FOO"));
    }

    #[test]
    fn test_normalize_uppercase() {
        let result = parse_and_normalize("SELECT foo FROM bar", Some(DialectType::Snowflake));
        // NOTE(review): the output is upper-cased before the check, so this does
        // not verify normalization itself; tighten once generator casing is confirmed.
        assert!(result.to_uppercase().contains("FOO"));
    }

    #[test]
    fn test_normalize_preserves_quoted() {
        let id = Identifier {
            name: "FoO".to_string(),
            quoted: true,
            trailing_comments: vec![],
        };
        let normalized = normalize_identifier(id, NormalizationStrategy::Lowercase);
        assert_eq!(normalized.name, "FoO");
    }

    #[test]
    fn test_case_insensitive_normalizes_quoted() {
        let id = Identifier {
            name: "FoO".to_string(),
            quoted: true,
            trailing_comments: vec![],
        };
        let normalized = normalize_identifier(id, NormalizationStrategy::CaseInsensitive);
        assert_eq!(normalized.name, "foo");
    }

    #[test]
    fn test_case_insensitive_uppercase_normalizes_quoted() {
        let id = Identifier {
            name: "FoO".to_string(),
            quoted: true,
            trailing_comments: vec![],
        };
        let normalized = normalize_identifier(id, NormalizationStrategy::CaseInsensitiveUppercase);
        assert_eq!(normalized.name, "FOO");
        // Normalization changes the spelling only, never the quoting flag.
        assert!(normalized.quoted);
    }

    #[test]
    fn test_uppercase_normalizes_unquoted() {
        let id = Identifier {
            name: "FoO".to_string(),
            quoted: false,
            trailing_comments: vec![],
        };
        let normalized = normalize_identifier(id, NormalizationStrategy::Uppercase);
        assert_eq!(normalized.name, "FOO");
        assert!(!normalized.quoted);
    }

    #[test]
    fn test_case_sensitive_no_normalization() {
        let id = Identifier {
            name: "FoO".to_string(),
            quoted: false,
            trailing_comments: vec![],
        };
        let normalized = normalize_identifier(id, NormalizationStrategy::CaseSensitive);
        assert_eq!(normalized.name, "FoO");
    }

    #[test]
    fn test_normalize_column() {
        let col = Expression::Column(Column {
            name: Identifier::new("MyColumn"),
            table: Some(Identifier::new("MyTable")),
            join_mark: false,
            trailing_comments: vec![],
        });

        // Check the tree directly so that BOTH the column name and the table
        // qualifier must be normalized (the previous `||` on generated SQL
        // passed when only one of them was).
        match normalize_expression(col, NormalizationStrategy::Lowercase) {
            Expression::Column(column) => {
                assert_eq!(column.name.name, "mycolumn");
                assert_eq!(
                    column.table.map(|table| table.name),
                    Some("mytable".to_string())
                );
            }
            _ => panic!("normalizing a column must yield a column"),
        }
    }

    #[test]
    fn test_get_normalization_strategy() {
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Snowflake)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::PostgreSQL)),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::MySQL)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::DuckDB)),
            NormalizationStrategy::CaseInsensitive
        );
    }

    #[test]
    fn test_default_strategy_is_lowercase() {
        assert_eq!(
            NormalizationStrategy::default(),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(None),
            NormalizationStrategy::Lowercase
        );
    }

    #[test]
    fn test_is_case_sensitive() {
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Lowercase));
        assert!(!is_case_sensitive("foo", NormalizationStrategy::Lowercase));
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("FOO", NormalizationStrategy::Uppercase));
        assert!(is_case_sensitive("foo", NormalizationStrategy::CaseSensitive));
        assert!(!is_case_sensitive("Foo", NormalizationStrategy::CaseInsensitive));
        assert!(!is_case_sensitive(
            "Foo",
            NormalizationStrategy::CaseInsensitiveUppercase
        ));
    }
}