1use crate::dialects::DialectType;
9use crate::expressions::{Column, Expression, Identifier};
10
/// How the case of SQL identifiers is normalized.
///
/// The strategy for a dialect is chosen by [`get_normalization_strategy`] and
/// applied to individual identifiers by [`normalize_identifier`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormalizationStrategy {
    /// Unquoted identifiers are lowercased; quoted identifiers keep their case.
    Lowercase,
    /// Unquoted identifiers are uppercased; quoted identifiers keep their case.
    Uppercase,
    /// Identifiers are never rewritten.
    CaseSensitive,
    /// Every identifier, quoted or not, is lowercased.
    CaseInsensitive,
    /// Every identifier, quoted or not, is uppercased.
    CaseInsensitiveUppercase,
}

/// The default strategy is [`NormalizationStrategy::Lowercase`].
impl Default for NormalizationStrategy {
    fn default() -> NormalizationStrategy {
        NormalizationStrategy::Lowercase
    }
}
31
32pub fn get_normalization_strategy(dialect: Option<DialectType>) -> NormalizationStrategy {
34 match dialect {
35 Some(DialectType::Oracle) | Some(DialectType::Snowflake) | Some(DialectType::Exasol) => {
37 NormalizationStrategy::Uppercase
38 }
39 Some(DialectType::MySQL) | Some(DialectType::ClickHouse) => {
41 NormalizationStrategy::CaseSensitive
42 }
43 Some(DialectType::DuckDB)
45 | Some(DialectType::SQLite)
46 | Some(DialectType::BigQuery)
47 | Some(DialectType::Presto)
48 | Some(DialectType::Trino)
49 | Some(DialectType::Hive)
50 | Some(DialectType::Spark)
51 | Some(DialectType::Databricks)
52 | Some(DialectType::Redshift) => NormalizationStrategy::CaseInsensitive,
53 _ => NormalizationStrategy::Lowercase,
55 }
56}
57
58pub fn normalize_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
73 let strategy = get_normalization_strategy(dialect);
74 normalize_expression(expression, strategy)
75}
76
77pub fn normalize_identifier(identifier: Identifier, strategy: NormalizationStrategy) -> Identifier {
86 if strategy == NormalizationStrategy::CaseSensitive {
88 return identifier;
89 }
90
91 if identifier.quoted
93 && strategy != NormalizationStrategy::CaseInsensitive
94 && strategy != NormalizationStrategy::CaseInsensitiveUppercase
95 {
96 return identifier;
97 }
98
99 let normalized_name = match strategy {
101 NormalizationStrategy::Uppercase | NormalizationStrategy::CaseInsensitiveUppercase => {
102 identifier.name.to_uppercase()
103 }
104 NormalizationStrategy::Lowercase | NormalizationStrategy::CaseInsensitive => {
105 identifier.name.to_lowercase()
106 }
107 NormalizationStrategy::CaseSensitive => identifier.name, };
109
110 Identifier {
111 name: normalized_name,
112 quoted: identifier.quoted,
113 trailing_comments: identifier.trailing_comments,
114 span: None,
115 }
116}
117
118fn normalize_expression(expression: Expression, strategy: NormalizationStrategy) -> Expression {
120 match expression {
121 Expression::Identifier(id) => Expression::Identifier(normalize_identifier(id, strategy)),
122 Expression::Column(col) => Expression::Column(Column {
123 name: normalize_identifier(col.name, strategy),
124 table: col.table.map(|t| normalize_identifier(t, strategy)),
125 join_mark: col.join_mark,
126 trailing_comments: col.trailing_comments,
127 span: None,
128 }),
129 Expression::Table(mut table) => {
130 table.name = normalize_identifier(table.name, strategy);
131 if let Some(schema) = table.schema {
132 table.schema = Some(normalize_identifier(schema, strategy));
133 }
134 if let Some(catalog) = table.catalog {
135 table.catalog = Some(normalize_identifier(catalog, strategy));
136 }
137 if let Some(alias) = table.alias {
138 table.alias = Some(normalize_identifier(alias, strategy));
139 }
140 table.column_aliases = table
141 .column_aliases
142 .into_iter()
143 .map(|a| normalize_identifier(a, strategy))
144 .collect();
145 Expression::Table(table)
146 }
147 Expression::Select(select) => {
148 let mut select = *select;
149 select.expressions = select
151 .expressions
152 .into_iter()
153 .map(|e| normalize_expression(e, strategy))
154 .collect();
155 if let Some(mut from) = select.from {
157 from.expressions = from
158 .expressions
159 .into_iter()
160 .map(|e| normalize_expression(e, strategy))
161 .collect();
162 select.from = Some(from);
163 }
164 select.joins = select
166 .joins
167 .into_iter()
168 .map(|mut j| {
169 j.this = normalize_expression(j.this, strategy);
170 if let Some(on) = j.on {
171 j.on = Some(normalize_expression(on, strategy));
172 }
173 j
174 })
175 .collect();
176 if let Some(mut where_clause) = select.where_clause {
178 where_clause.this = normalize_expression(where_clause.this, strategy);
179 select.where_clause = Some(where_clause);
180 }
181 if let Some(mut group_by) = select.group_by {
183 group_by.expressions = group_by
184 .expressions
185 .into_iter()
186 .map(|e| normalize_expression(e, strategy))
187 .collect();
188 select.group_by = Some(group_by);
189 }
190 if let Some(mut having) = select.having {
192 having.this = normalize_expression(having.this, strategy);
193 select.having = Some(having);
194 }
195 if let Some(mut order_by) = select.order_by {
197 order_by.expressions = order_by
198 .expressions
199 .into_iter()
200 .map(|mut o| {
201 o.this = normalize_expression(o.this, strategy);
202 o
203 })
204 .collect();
205 select.order_by = Some(order_by);
206 }
207 Expression::Select(Box::new(select))
208 }
209 Expression::Alias(alias) => {
210 let mut alias = *alias;
211 alias.this = normalize_expression(alias.this, strategy);
212 alias.alias = normalize_identifier(alias.alias, strategy);
213 Expression::Alias(Box::new(alias))
214 }
215 Expression::And(bin) => normalize_binary(Expression::And, *bin, strategy),
217 Expression::Or(bin) => normalize_binary(Expression::Or, *bin, strategy),
218 Expression::Add(bin) => normalize_binary(Expression::Add, *bin, strategy),
219 Expression::Sub(bin) => normalize_binary(Expression::Sub, *bin, strategy),
220 Expression::Mul(bin) => normalize_binary(Expression::Mul, *bin, strategy),
221 Expression::Div(bin) => normalize_binary(Expression::Div, *bin, strategy),
222 Expression::Mod(bin) => normalize_binary(Expression::Mod, *bin, strategy),
223 Expression::Eq(bin) => normalize_binary(Expression::Eq, *bin, strategy),
224 Expression::Neq(bin) => normalize_binary(Expression::Neq, *bin, strategy),
225 Expression::Lt(bin) => normalize_binary(Expression::Lt, *bin, strategy),
226 Expression::Lte(bin) => normalize_binary(Expression::Lte, *bin, strategy),
227 Expression::Gt(bin) => normalize_binary(Expression::Gt, *bin, strategy),
228 Expression::Gte(bin) => normalize_binary(Expression::Gte, *bin, strategy),
229 Expression::Concat(bin) => normalize_binary(Expression::Concat, *bin, strategy),
230 Expression::Not(un) => {
232 let mut un = *un;
233 un.this = normalize_expression(un.this, strategy);
234 Expression::Not(Box::new(un))
235 }
236 Expression::Neg(un) => {
237 let mut un = *un;
238 un.this = normalize_expression(un.this, strategy);
239 Expression::Neg(Box::new(un))
240 }
241 Expression::Function(func) => {
243 let mut func = *func;
244 func.args = func
245 .args
246 .into_iter()
247 .map(|e| normalize_expression(e, strategy))
248 .collect();
249 Expression::Function(Box::new(func))
250 }
251 Expression::AggregateFunction(agg) => {
252 let mut agg = *agg;
253 agg.args = agg
254 .args
255 .into_iter()
256 .map(|e| normalize_expression(e, strategy))
257 .collect();
258 Expression::AggregateFunction(Box::new(agg))
259 }
260 Expression::Paren(paren) => {
262 let mut paren = *paren;
263 paren.this = normalize_expression(paren.this, strategy);
264 Expression::Paren(Box::new(paren))
265 }
266 Expression::Case(case) => {
267 let mut case = *case;
268 case.operand = case.operand.map(|e| normalize_expression(e, strategy));
269 case.whens = case
270 .whens
271 .into_iter()
272 .map(|(w, t)| {
273 (
274 normalize_expression(w, strategy),
275 normalize_expression(t, strategy),
276 )
277 })
278 .collect();
279 case.else_ = case.else_.map(|e| normalize_expression(e, strategy));
280 Expression::Case(Box::new(case))
281 }
282 Expression::Cast(cast) => {
283 let mut cast = *cast;
284 cast.this = normalize_expression(cast.this, strategy);
285 Expression::Cast(Box::new(cast))
286 }
287 Expression::In(in_expr) => {
288 let mut in_expr = *in_expr;
289 in_expr.this = normalize_expression(in_expr.this, strategy);
290 in_expr.expressions = in_expr
291 .expressions
292 .into_iter()
293 .map(|e| normalize_expression(e, strategy))
294 .collect();
295 if let Some(q) = in_expr.query {
296 in_expr.query = Some(normalize_expression(q, strategy));
297 }
298 Expression::In(Box::new(in_expr))
299 }
300 Expression::Between(between) => {
301 let mut between = *between;
302 between.this = normalize_expression(between.this, strategy);
303 between.low = normalize_expression(between.low, strategy);
304 between.high = normalize_expression(between.high, strategy);
305 Expression::Between(Box::new(between))
306 }
307 Expression::Subquery(subquery) => {
308 let mut subquery = *subquery;
309 subquery.this = normalize_expression(subquery.this, strategy);
310 if let Some(alias) = subquery.alias {
311 subquery.alias = Some(normalize_identifier(alias, strategy));
312 }
313 Expression::Subquery(Box::new(subquery))
314 }
315 Expression::Union(union) => {
317 let mut union = *union;
318 union.left = normalize_expression(union.left, strategy);
319 union.right = normalize_expression(union.right, strategy);
320 Expression::Union(Box::new(union))
321 }
322 Expression::Intersect(intersect) => {
323 let mut intersect = *intersect;
324 intersect.left = normalize_expression(intersect.left, strategy);
325 intersect.right = normalize_expression(intersect.right, strategy);
326 Expression::Intersect(Box::new(intersect))
327 }
328 Expression::Except(except) => {
329 let mut except = *except;
330 except.left = normalize_expression(except.left, strategy);
331 except.right = normalize_expression(except.right, strategy);
332 Expression::Except(Box::new(except))
333 }
334 _ => expression,
336 }
337}
338
339fn normalize_binary<F>(
341 constructor: F,
342 mut bin: crate::expressions::BinaryOp,
343 strategy: NormalizationStrategy,
344) -> Expression
345where
346 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
347{
348 bin.left = normalize_expression(bin.left, strategy);
349 bin.right = normalize_expression(bin.right, strategy);
350 constructor(Box::new(bin))
351}
352
353pub fn is_case_sensitive(text: &str, strategy: NormalizationStrategy) -> bool {
355 match strategy {
356 NormalizationStrategy::CaseInsensitive
357 | NormalizationStrategy::CaseInsensitiveUppercase => false,
358 NormalizationStrategy::Uppercase => text.chars().any(|c| c.is_lowercase()),
359 NormalizationStrategy::Lowercase => text.chars().any(|c| c.is_uppercase()),
360 NormalizationStrategy::CaseSensitive => true,
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL with the default generator.
    ///
    /// Named `generate_sql` rather than `gen`, which is a reserved keyword in
    /// the Rust 2024 edition.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, normalizes its first statement for `dialect`, and renders it.
    fn parse_and_normalize(sql: &str, dialect: Option<DialectType>) -> String {
        let ast = Parser::parse_sql(sql).expect("Failed to parse");
        let normalized = normalize_identifiers(ast[0].clone(), dialect);
        generate_sql(&normalized)
    }

    /// Builds an identifier with no trailing comments and no span.
    fn ident(name: &str, quoted: bool) -> Identifier {
        Identifier {
            name: name.to_string(),
            quoted,
            trailing_comments: vec![],
            span: None,
        }
    }

    #[test]
    fn test_normalize_lowercase() {
        // No dialect -> Lowercase: both the column and the table are folded.
        let result = parse_and_normalize("SELECT FoO FROM Bar", None);
        assert!(result.contains("foo"), "unexpected SQL: {}", result);
        assert!(result.contains("bar"), "unexpected SQL: {}", result);
        assert!(!result.contains("FoO"), "unexpected SQL: {}", result);
    }

    #[test]
    fn test_normalize_uppercase() {
        // Snowflake uppercases unquoted identifiers. The previous assertion
        // (`result.to_uppercase().contains("FOO")`) held for any output.
        let result = parse_and_normalize("SELECT foo FROM bar", Some(DialectType::Snowflake));
        assert!(result.contains("FOO"), "unexpected SQL: {}", result);
        assert!(result.contains("BAR"), "unexpected SQL: {}", result);
    }

    #[test]
    fn test_normalize_preserves_quoted() {
        let normalized = normalize_identifier(ident("FoO", true), NormalizationStrategy::Lowercase);
        assert_eq!(normalized.name, "FoO");
    }

    #[test]
    fn test_uppercase_preserves_quoted() {
        let normalized = normalize_identifier(ident("FoO", true), NormalizationStrategy::Uppercase);
        assert_eq!(normalized.name, "FoO");
    }

    #[test]
    fn test_case_insensitive_normalizes_quoted() {
        let normalized =
            normalize_identifier(ident("FoO", true), NormalizationStrategy::CaseInsensitive);
        assert_eq!(normalized.name, "foo");
        assert!(normalized.quoted);
    }

    #[test]
    fn test_case_insensitive_uppercase_normalizes_quoted() {
        let normalized = normalize_identifier(
            ident("FoO", true),
            NormalizationStrategy::CaseInsensitiveUppercase,
        );
        assert_eq!(normalized.name, "FOO");
        assert!(normalized.quoted);
    }

    #[test]
    fn test_case_sensitive_no_normalization() {
        let normalized =
            normalize_identifier(ident("FoO", false), NormalizationStrategy::CaseSensitive);
        assert_eq!(normalized.name, "FoO");
    }

    #[test]
    fn test_normalize_column() {
        let col = Expression::Column(Column {
            name: Identifier::new("MyColumn"),
            table: Some(Identifier::new("MyTable")),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
        });

        let normalized = normalize_expression(col, NormalizationStrategy::Lowercase);
        let sql = generate_sql(&normalized);
        // Both the column name and its table qualifier must be folded.
        assert!(sql.contains("mycolumn"), "unexpected SQL: {}", sql);
        assert!(sql.contains("mytable"), "unexpected SQL: {}", sql);
    }

    #[test]
    fn test_get_normalization_strategy() {
        assert_eq!(get_normalization_strategy(None), NormalizationStrategy::Lowercase);
        assert_eq!(
            get_normalization_strategy(Some(DialectType::Snowflake)),
            NormalizationStrategy::Uppercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::PostgreSQL)),
            NormalizationStrategy::Lowercase
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::MySQL)),
            NormalizationStrategy::CaseSensitive
        );
        assert_eq!(
            get_normalization_strategy(Some(DialectType::DuckDB)),
            NormalizationStrategy::CaseInsensitive
        );
    }

    #[test]
    fn test_default_strategy_is_lowercase() {
        assert_eq!(NormalizationStrategy::default(), NormalizationStrategy::Lowercase);
    }

    #[test]
    fn test_is_case_sensitive() {
        assert!(!is_case_sensitive("Foo", NormalizationStrategy::CaseInsensitive));
        assert!(!is_case_sensitive("Foo", NormalizationStrategy::CaseInsensitiveUppercase));
        assert!(is_case_sensitive("foo", NormalizationStrategy::CaseSensitive));
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Uppercase));
        assert!(!is_case_sensitive("FOO", NormalizationStrategy::Uppercase));
        assert!(is_case_sensitive("Foo", NormalizationStrategy::Lowercase));
        assert!(!is_case_sensitive("foo", NormalizationStrategy::Lowercase));
    }
}