1use crate::dialects::DialectType;
9use crate::expressions::{Column, Expression, Identifier};
10
/// How identifier case is folded by [`normalize_identifier`].
///
/// `Hash` is derived alongside `Eq` so strategies can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NormalizationStrategy {
    /// Unquoted identifiers are lowercased; quoted identifiers keep their case.
    Lowercase,
    /// Unquoted identifiers are uppercased; quoted identifiers keep their case.
    Uppercase,
    /// Identifiers are never rewritten, quoted or not.
    CaseSensitive,
    /// Every identifier, quoted or not, is lowercased.
    CaseInsensitive,
    /// Every identifier, quoted or not, is uppercased.
    CaseInsensitiveUppercase,
}

impl Default for NormalizationStrategy {
    /// Lowercase folding is the default (also used when no dialect is given).
    fn default() -> Self {
        NormalizationStrategy::Lowercase
    }
}
31
32pub fn get_normalization_strategy(dialect: Option<DialectType>) -> NormalizationStrategy {
34 match dialect {
35 Some(DialectType::Oracle) | Some(DialectType::Snowflake) | Some(DialectType::Exasol) => {
37 NormalizationStrategy::Uppercase
38 }
39 Some(DialectType::MySQL) | Some(DialectType::ClickHouse) => {
41 NormalizationStrategy::CaseSensitive
42 }
43 Some(DialectType::DuckDB)
45 | Some(DialectType::SQLite)
46 | Some(DialectType::BigQuery)
47 | Some(DialectType::Presto)
48 | Some(DialectType::Trino)
49 | Some(DialectType::Hive)
50 | Some(DialectType::Spark)
51 | Some(DialectType::Databricks)
52 | Some(DialectType::Redshift) => NormalizationStrategy::CaseInsensitive,
53 _ => NormalizationStrategy::Lowercase,
55 }
56}
57
58pub fn normalize_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
73 let strategy = get_normalization_strategy(dialect);
74 normalize_expression(expression, strategy)
75}
76
77pub fn normalize_identifier(identifier: Identifier, strategy: NormalizationStrategy) -> Identifier {
86 if strategy == NormalizationStrategy::CaseSensitive {
88 return identifier;
89 }
90
91 if identifier.quoted
93 && strategy != NormalizationStrategy::CaseInsensitive
94 && strategy != NormalizationStrategy::CaseInsensitiveUppercase
95 {
96 return identifier;
97 }
98
99 let normalized_name = match strategy {
101 NormalizationStrategy::Uppercase | NormalizationStrategy::CaseInsensitiveUppercase => {
102 identifier.name.to_uppercase()
103 }
104 NormalizationStrategy::Lowercase | NormalizationStrategy::CaseInsensitive => {
105 identifier.name.to_lowercase()
106 }
107 NormalizationStrategy::CaseSensitive => identifier.name, };
109
110 Identifier {
111 name: normalized_name,
112 quoted: identifier.quoted,
113 trailing_comments: identifier.trailing_comments,
114 }
115}
116
117fn normalize_expression(expression: Expression, strategy: NormalizationStrategy) -> Expression {
119 match expression {
120 Expression::Identifier(id) => Expression::Identifier(normalize_identifier(id, strategy)),
121 Expression::Column(col) => Expression::Column(Column {
122 name: normalize_identifier(col.name, strategy),
123 table: col.table.map(|t| normalize_identifier(t, strategy)),
124 join_mark: col.join_mark,
125 trailing_comments: col.trailing_comments,
126 }),
127 Expression::Table(mut table) => {
128 table.name = normalize_identifier(table.name, strategy);
129 if let Some(schema) = table.schema {
130 table.schema = Some(normalize_identifier(schema, strategy));
131 }
132 if let Some(catalog) = table.catalog {
133 table.catalog = Some(normalize_identifier(catalog, strategy));
134 }
135 if let Some(alias) = table.alias {
136 table.alias = Some(normalize_identifier(alias, strategy));
137 }
138 table.column_aliases = table
139 .column_aliases
140 .into_iter()
141 .map(|a| normalize_identifier(a, strategy))
142 .collect();
143 Expression::Table(table)
144 }
145 Expression::Select(select) => {
146 let mut select = *select;
147 select.expressions = select
149 .expressions
150 .into_iter()
151 .map(|e| normalize_expression(e, strategy))
152 .collect();
153 if let Some(mut from) = select.from {
155 from.expressions = from
156 .expressions
157 .into_iter()
158 .map(|e| normalize_expression(e, strategy))
159 .collect();
160 select.from = Some(from);
161 }
162 select.joins = select
164 .joins
165 .into_iter()
166 .map(|mut j| {
167 j.this = normalize_expression(j.this, strategy);
168 if let Some(on) = j.on {
169 j.on = Some(normalize_expression(on, strategy));
170 }
171 j
172 })
173 .collect();
174 if let Some(mut where_clause) = select.where_clause {
176 where_clause.this = normalize_expression(where_clause.this, strategy);
177 select.where_clause = Some(where_clause);
178 }
179 if let Some(mut group_by) = select.group_by {
181 group_by.expressions = group_by
182 .expressions
183 .into_iter()
184 .map(|e| normalize_expression(e, strategy))
185 .collect();
186 select.group_by = Some(group_by);
187 }
188 if let Some(mut having) = select.having {
190 having.this = normalize_expression(having.this, strategy);
191 select.having = Some(having);
192 }
193 if let Some(mut order_by) = select.order_by {
195 order_by.expressions = order_by
196 .expressions
197 .into_iter()
198 .map(|mut o| {
199 o.this = normalize_expression(o.this, strategy);
200 o
201 })
202 .collect();
203 select.order_by = Some(order_by);
204 }
205 Expression::Select(Box::new(select))
206 }
207 Expression::Alias(alias) => {
208 let mut alias = *alias;
209 alias.this = normalize_expression(alias.this, strategy);
210 alias.alias = normalize_identifier(alias.alias, strategy);
211 Expression::Alias(Box::new(alias))
212 }
213 Expression::And(bin) => normalize_binary(Expression::And, *bin, strategy),
215 Expression::Or(bin) => normalize_binary(Expression::Or, *bin, strategy),
216 Expression::Add(bin) => normalize_binary(Expression::Add, *bin, strategy),
217 Expression::Sub(bin) => normalize_binary(Expression::Sub, *bin, strategy),
218 Expression::Mul(bin) => normalize_binary(Expression::Mul, *bin, strategy),
219 Expression::Div(bin) => normalize_binary(Expression::Div, *bin, strategy),
220 Expression::Mod(bin) => normalize_binary(Expression::Mod, *bin, strategy),
221 Expression::Eq(bin) => normalize_binary(Expression::Eq, *bin, strategy),
222 Expression::Neq(bin) => normalize_binary(Expression::Neq, *bin, strategy),
223 Expression::Lt(bin) => normalize_binary(Expression::Lt, *bin, strategy),
224 Expression::Lte(bin) => normalize_binary(Expression::Lte, *bin, strategy),
225 Expression::Gt(bin) => normalize_binary(Expression::Gt, *bin, strategy),
226 Expression::Gte(bin) => normalize_binary(Expression::Gte, *bin, strategy),
227 Expression::Concat(bin) => normalize_binary(Expression::Concat, *bin, strategy),
228 Expression::Not(un) => {
230 let mut un = *un;
231 un.this = normalize_expression(un.this, strategy);
232 Expression::Not(Box::new(un))
233 }
234 Expression::Neg(un) => {
235 let mut un = *un;
236 un.this = normalize_expression(un.this, strategy);
237 Expression::Neg(Box::new(un))
238 }
239 Expression::Function(func) => {
241 let mut func = *func;
242 func.args = func
243 .args
244 .into_iter()
245 .map(|e| normalize_expression(e, strategy))
246 .collect();
247 Expression::Function(Box::new(func))
248 }
249 Expression::AggregateFunction(agg) => {
250 let mut agg = *agg;
251 agg.args = agg
252 .args
253 .into_iter()
254 .map(|e| normalize_expression(e, strategy))
255 .collect();
256 Expression::AggregateFunction(Box::new(agg))
257 }
258 Expression::Paren(paren) => {
260 let mut paren = *paren;
261 paren.this = normalize_expression(paren.this, strategy);
262 Expression::Paren(Box::new(paren))
263 }
264 Expression::Case(case) => {
265 let mut case = *case;
266 case.operand = case.operand.map(|e| normalize_expression(e, strategy));
267 case.whens = case
268 .whens
269 .into_iter()
270 .map(|(w, t)| {
271 (
272 normalize_expression(w, strategy),
273 normalize_expression(t, strategy),
274 )
275 })
276 .collect();
277 case.else_ = case.else_.map(|e| normalize_expression(e, strategy));
278 Expression::Case(Box::new(case))
279 }
280 Expression::Cast(cast) => {
281 let mut cast = *cast;
282 cast.this = normalize_expression(cast.this, strategy);
283 Expression::Cast(Box::new(cast))
284 }
285 Expression::In(in_expr) => {
286 let mut in_expr = *in_expr;
287 in_expr.this = normalize_expression(in_expr.this, strategy);
288 in_expr.expressions = in_expr
289 .expressions
290 .into_iter()
291 .map(|e| normalize_expression(e, strategy))
292 .collect();
293 if let Some(q) = in_expr.query {
294 in_expr.query = Some(normalize_expression(q, strategy));
295 }
296 Expression::In(Box::new(in_expr))
297 }
298 Expression::Between(between) => {
299 let mut between = *between;
300 between.this = normalize_expression(between.this, strategy);
301 between.low = normalize_expression(between.low, strategy);
302 between.high = normalize_expression(between.high, strategy);
303 Expression::Between(Box::new(between))
304 }
305 Expression::Subquery(subquery) => {
306 let mut subquery = *subquery;
307 subquery.this = normalize_expression(subquery.this, strategy);
308 if let Some(alias) = subquery.alias {
309 subquery.alias = Some(normalize_identifier(alias, strategy));
310 }
311 Expression::Subquery(Box::new(subquery))
312 }
313 Expression::Union(union) => {
315 let mut union = *union;
316 union.left = normalize_expression(union.left, strategy);
317 union.right = normalize_expression(union.right, strategy);
318 Expression::Union(Box::new(union))
319 }
320 Expression::Intersect(intersect) => {
321 let mut intersect = *intersect;
322 intersect.left = normalize_expression(intersect.left, strategy);
323 intersect.right = normalize_expression(intersect.right, strategy);
324 Expression::Intersect(Box::new(intersect))
325 }
326 Expression::Except(except) => {
327 let mut except = *except;
328 except.left = normalize_expression(except.left, strategy);
329 except.right = normalize_expression(except.right, strategy);
330 Expression::Except(Box::new(except))
331 }
332 _ => expression,
334 }
335}
336
337fn normalize_binary<F>(
339 constructor: F,
340 mut bin: crate::expressions::BinaryOp,
341 strategy: NormalizationStrategy,
342) -> Expression
343where
344 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
345{
346 bin.left = normalize_expression(bin.left, strategy);
347 bin.right = normalize_expression(bin.right, strategy);
348 constructor(Box::new(bin))
349}
350
351pub fn is_case_sensitive(text: &str, strategy: NormalizationStrategy) -> bool {
353 match strategy {
354 NormalizationStrategy::CaseInsensitive
355 | NormalizationStrategy::CaseInsensitiveUppercase => false,
356 NormalizationStrategy::Uppercase => text.chars().any(|c| c.is_lowercase()),
357 NormalizationStrategy::Lowercase => text.chars().any(|c| c.is_uppercase()),
358 NormalizationStrategy::CaseSensitive => true,
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL with the default generator.
    // Not named `gen`: that is a reserved keyword starting with Rust 2024.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, normalizes its first statement for `dialect`, and renders it.
    fn parse_and_normalize(sql: &str, dialect: Option<DialectType>) -> String {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        let normalized = normalize_identifiers(ast[0].clone(), dialect);
        to_sql(&normalized)
    }

    /// Builds an identifier with no trailing comments.
    fn identifier(name: &str, quoted: bool) -> Identifier {
        Identifier {
            name: name.to_string(),
            quoted,
            trailing_comments: vec![],
        }
    }

    // The SQL-level tests check both that the folded names appear AND that the
    // original spellings are gone; checking only presence would also accept
    // un-normalized output.
    #[test]
    fn test_normalize_lowercase() {
        let result = parse_and_normalize("SELECT FoO FROM Bar", None);
        assert!(result.contains("foo") && result.contains("bar"), "got: {}", result);
        assert!(!result.contains("FoO") && !result.contains("Bar"), "got: {}", result);
    }

    #[test]
    fn test_normalize_uppercase() {
        let result = parse_and_normalize("SELECT foo FROM bar", Some(DialectType::Snowflake));
        assert!(result.contains("FOO") && result.contains("BAR"), "got: {}", result);
        assert!(!result.contains("foo") && !result.contains("bar"), "got: {}", result);
    }

    #[test]
    fn test_normalize_preserves_quoted() {
        let lowered = normalize_identifier(identifier("FoO", true), NormalizationStrategy::Lowercase);
        assert_eq!(lowered.name, "FoO");
        assert!(lowered.quoted);

        let raised = normalize_identifier(identifier("FoO", true), NormalizationStrategy::Uppercase);
        assert_eq!(raised.name, "FoO");
        assert!(raised.quoted);
    }

    #[test]
    fn test_case_insensitive_normalizes_quoted() {
        let normalized =
            normalize_identifier(identifier("FoO", true), NormalizationStrategy::CaseInsensitive);
        assert_eq!(normalized.name, "foo");
        assert!(normalized.quoted);
    }

    #[test]
    fn test_case_insensitive_uppercase_normalizes_quoted() {
        let normalized = normalize_identifier(
            identifier("FoO", true),
            NormalizationStrategy::CaseInsensitiveUppercase,
        );
        assert_eq!(normalized.name, "FOO");
    }

    #[test]
    fn test_case_sensitive_no_normalization() {
        let normalized =
            normalize_identifier(identifier("FoO", false), NormalizationStrategy::CaseSensitive);
        assert_eq!(normalized.name, "FoO");
    }

    #[test]
    fn test_normalize_column() {
        let col = Expression::Column(Column {
            name: Identifier::new("MyColumn"),
            table: Some(Identifier::new("MyTable")),
            join_mark: false,
            trailing_comments: vec![],
        });

        // Inspect the AST directly so both the column and table names are checked.
        match normalize_expression(col, NormalizationStrategy::Lowercase) {
            Expression::Column(c) => {
                assert_eq!(c.name.name, "mycolumn");
                assert_eq!(c.table.map(|t| t.name).as_deref(), Some("mytable"));
            }
            _ => panic!("expected Expression::Column after normalization"),
        }
    }

    #[test]
    fn test_is_case_sensitive() {
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Lowercase));
        assert!(!is_case_sensitive("foo", NormalizationStrategy::Lowercase));
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("FOO", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("FoO", NormalizationStrategy::CaseInsensitive));
        assert!(!is_case_sensitive("FoO", NormalizationStrategy::CaseInsensitiveUppercase));
        assert!(is_case_sensitive("foo", NormalizationStrategy::CaseSensitive));
    }

    #[test]
    fn test_default_strategy() {
        assert_eq!(NormalizationStrategy::default(), NormalizationStrategy::Lowercase);
    }

    #[test]
    fn test_get_normalization_strategy() {
        assert_eq!(get_normalization_strategy(None), NormalizationStrategy::Lowercase);
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Snowflake)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Oracle)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::PostgreSQL)),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::MySQL)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::ClickHouse)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::DuckDB)),
            NormalizationStrategy::CaseInsensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::BigQuery)),
            NormalizationStrategy::CaseInsensitive
        );
    }
}