1use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal, Null};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
/// Rewrites `expression` into the canonical form produced by this pass.
///
/// This is the public entry point. All work is delegated to
/// `canonicalize_recursive`, which walks the tree and applies the node-level
/// rewrites defined in this module (cast removal, `ASC` removal, …).
///
/// `dialect` is forwarded unchanged. Nothing in this module currently
/// branches on its value.
pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
    canonicalize_recursive(expression, dialect)
}
31
32fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34 let expr = match expression {
35 Expression::Select(mut select) => {
36 select.expressions = select
38 .expressions
39 .into_iter()
40 .map(|e| canonicalize_recursive(e, dialect))
41 .collect();
42
43 if let Some(mut from) = select.from {
45 from.expressions = from
46 .expressions
47 .into_iter()
48 .map(|e| canonicalize_recursive(e, dialect))
49 .collect();
50 select.from = Some(from);
51 }
52
53 if let Some(mut where_clause) = select.where_clause {
55 where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56 where_clause.this = ensure_bools(where_clause.this);
57 select.where_clause = Some(where_clause);
58 }
59
60 if let Some(mut having) = select.having {
62 having.this = canonicalize_recursive(having.this, dialect);
63 having.this = ensure_bools(having.this);
64 select.having = Some(having);
65 }
66
67 if let Some(mut order_by) = select.order_by {
69 order_by.expressions = order_by
70 .expressions
71 .into_iter()
72 .map(|mut o| {
73 o.this = canonicalize_recursive(o.this, dialect);
74 o = remove_ascending_order(o);
75 o
76 })
77 .collect();
78 select.order_by = Some(order_by);
79 }
80
81 select.joins = select
83 .joins
84 .into_iter()
85 .map(|mut j| {
86 j.this = canonicalize_recursive(j.this, dialect);
87 if let Some(on) = j.on {
88 j.on = Some(canonicalize_recursive(on, dialect));
89 }
90 j
91 })
92 .collect();
93
94 Expression::Select(select)
95 }
96
97 Expression::Add(bin) => {
99 let left = canonicalize_recursive(bin.left, dialect);
100 let right = canonicalize_recursive(bin.right, dialect);
101 let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102 left,
103 right,
104 left_comments: bin.left_comments,
105 operator_comments: bin.operator_comments,
106 trailing_comments: bin.trailing_comments,
107 inferred_type: None,
108 }));
109 add_text_to_concat(result)
110 }
111
112 Expression::And(bin) => {
114 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
115 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
116 Expression::And(Box::new(crate::expressions::BinaryOp {
117 left,
118 right,
119 left_comments: bin.left_comments,
120 operator_comments: bin.operator_comments,
121 trailing_comments: bin.trailing_comments,
122 inferred_type: None,
123 }))
124 }
125 Expression::Or(bin) => {
126 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
127 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
128 Expression::Or(Box::new(crate::expressions::BinaryOp {
129 left,
130 right,
131 left_comments: bin.left_comments,
132 operator_comments: bin.operator_comments,
133 trailing_comments: bin.trailing_comments,
134 inferred_type: None,
135 }))
136 }
137
138 Expression::Not(un) => {
139 let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
140 Expression::Not(Box::new(crate::expressions::UnaryOp {
141 this: inner,
142 inferred_type: None,
143 }))
144 }
145
146 Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
148 Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
149 Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
150 Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
151 Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
152 Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
153
154 Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
155 Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
156 Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
157
158 Expression::Cast(cast) => {
160 let inner = canonicalize_recursive(cast.this, dialect);
161 let result = Expression::Cast(Box::new(crate::expressions::Cast {
162 this: inner,
163 to: cast.to,
164 trailing_comments: cast.trailing_comments,
165 double_colon_syntax: cast.double_colon_syntax,
166 format: cast.format,
167 default: cast.default,
168 inferred_type: None,
169 }));
170 remove_redundant_casts(result)
171 }
172
173 Expression::Function(func) => {
175 let args = func
176 .args
177 .into_iter()
178 .map(|e| canonicalize_recursive(e, dialect))
179 .collect();
180 Expression::Function(Box::new(crate::expressions::Function {
181 name: func.name,
182 args,
183 distinct: func.distinct,
184 trailing_comments: func.trailing_comments,
185 use_bracket_syntax: func.use_bracket_syntax,
186 no_parens: func.no_parens,
187 quoted: func.quoted,
188 span: None,
189 inferred_type: None,
190 }))
191 }
192
193 Expression::AggregateFunction(agg) => {
194 let args = agg
195 .args
196 .into_iter()
197 .map(|e| canonicalize_recursive(e, dialect))
198 .collect();
199 Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
200 name: agg.name,
201 args,
202 distinct: agg.distinct,
203 filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
204 order_by: agg.order_by,
205 limit: agg.limit,
206 ignore_nulls: agg.ignore_nulls,
207 inferred_type: None,
208 }))
209 }
210
211 Expression::Alias(alias) => {
213 let inner = canonicalize_recursive(alias.this, dialect);
214 Expression::Alias(Box::new(crate::expressions::Alias {
215 this: inner,
216 alias: alias.alias,
217 column_aliases: alias.column_aliases,
218 alias_explicit_as: false,
219 alias_keyword: None,
220 pre_alias_comments: alias.pre_alias_comments,
221 trailing_comments: alias.trailing_comments,
222 inferred_type: None,
223 }))
224 }
225
226 Expression::Paren(paren) => {
228 let inner = canonicalize_recursive(paren.this, dialect);
229 Expression::Paren(Box::new(crate::expressions::Paren {
230 this: inner,
231 trailing_comments: paren.trailing_comments,
232 }))
233 }
234
235 Expression::Case(case) => {
237 let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
238 let whens = case
239 .whens
240 .into_iter()
241 .map(|(w, t)| {
242 (
243 canonicalize_recursive(w, dialect),
244 canonicalize_recursive(t, dialect),
245 )
246 })
247 .collect();
248 let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
249 Expression::Case(Box::new(crate::expressions::Case {
250 operand,
251 whens,
252 else_,
253 comments: Vec::new(),
254 inferred_type: None,
255 }))
256 }
257
258 Expression::Between(between) => {
260 let this = canonicalize_recursive(between.this, dialect);
261 let low = canonicalize_recursive(between.low, dialect);
262 let high = canonicalize_recursive(between.high, dialect);
263 Expression::Between(Box::new(crate::expressions::Between {
264 this,
265 low,
266 high,
267 not: between.not,
268 symmetric: between.symmetric,
269 }))
270 }
271
272 Expression::In(in_expr) => {
274 let this = canonicalize_recursive(in_expr.this, dialect);
275 let expressions = in_expr
276 .expressions
277 .into_iter()
278 .map(|e| canonicalize_recursive(e, dialect))
279 .collect();
280 let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
281 Expression::In(Box::new(crate::expressions::In {
282 this,
283 expressions,
284 query,
285 not: in_expr.not,
286 global: in_expr.global,
287 unnest: in_expr.unnest,
288 is_field: in_expr.is_field,
289 }))
290 }
291
292 Expression::Subquery(subquery) => {
294 let this = canonicalize_recursive(subquery.this, dialect);
295 Expression::Subquery(Box::new(crate::expressions::Subquery {
296 this,
297 alias: subquery.alias,
298 column_aliases: subquery.column_aliases,
299 alias_explicit_as: subquery.alias_explicit_as,
300 alias_keyword: subquery.alias_keyword,
301 order_by: subquery.order_by,
302 limit: subquery.limit,
303 offset: subquery.offset,
304 distribute_by: subquery.distribute_by,
305 sort_by: subquery.sort_by,
306 cluster_by: subquery.cluster_by,
307 lateral: subquery.lateral,
308 modifiers_inside: subquery.modifiers_inside,
309 trailing_comments: subquery.trailing_comments,
310 inferred_type: None,
311 }))
312 }
313
314 Expression::Union(union) => {
316 let mut u = *union;
317 let left = std::mem::replace(&mut u.left, Expression::Null(Null));
318 u.left = canonicalize_recursive(left, dialect);
319 let right = std::mem::replace(&mut u.right, Expression::Null(Null));
320 u.right = canonicalize_recursive(right, dialect);
321 Expression::Union(Box::new(u))
322 }
323 Expression::Intersect(intersect) => {
324 let mut i = *intersect;
325 let left = std::mem::replace(&mut i.left, Expression::Null(Null));
326 i.left = canonicalize_recursive(left, dialect);
327 let right = std::mem::replace(&mut i.right, Expression::Null(Null));
328 i.right = canonicalize_recursive(right, dialect);
329 Expression::Intersect(Box::new(i))
330 }
331 Expression::Except(except) => {
332 let mut e = *except;
333 let left = std::mem::replace(&mut e.left, Expression::Null(Null));
334 e.left = canonicalize_recursive(left, dialect);
335 let right = std::mem::replace(&mut e.right, Expression::Null(Null));
336 e.right = canonicalize_recursive(right, dialect);
337 Expression::Except(Box::new(e))
338 }
339
340 other => other,
342 };
343
344 expr
345}
346
/// Rewrite hook applied to every canonicalized `+` node.
///
/// Currently a pass-through: the expression is returned unchanged.
/// NOTE(review): judging by the name, this is presumably meant to turn a
/// text-typed `a + b` into a string concatenation once operand type
/// information is available. Confirm before implementing.
fn add_text_to_concat(expression: Expression) -> Expression {
    expression
}
356
357fn remove_redundant_casts(expression: Expression) -> Expression {
361 if let Expression::Cast(cast) = &expression {
362 if let Expression::Literal(lit) = &cast.this {
368 if let Literal::String(_) = lit.as_ref() {
369 if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
370 return cast.this.clone();
371 }
372 }
373 }
374 if let Expression::Literal(lit) = &cast.this {
375 if let Literal::Number(_) = lit.as_ref() {
376 if matches!(
377 &cast.to,
378 DataType::Int { .. }
379 | DataType::BigInt { .. }
380 | DataType::Decimal { .. }
381 | DataType::Float { .. }
382 ) {
383 }
385 }
386 }
387 }
388 expression
389}
390
/// Boolean-normalization hook for WHERE/HAVING predicates and for the operands
/// of AND/OR/NOT.
///
/// Currently a pass-through: the expression is returned unchanged.
/// NOTE(review): presumably meant to turn non-boolean operands (e.g. a bare
/// numeric `x`) into explicit predicates such as `x <> 0`. Confirm before
/// implementing.
fn ensure_bools(expression: Expression) -> Expression {
    expression
}
401
402fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
406 if !ordered.desc && ordered.explicit_asc {
409 ordered.explicit_asc = false;
410 }
411 ordered
412}
413
414fn canonicalize_comparison<F>(
416 constructor: F,
417 bin: crate::expressions::BinaryOp,
418 dialect: Option<DialectType>,
419) -> Expression
420where
421 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
422{
423 let left = canonicalize_recursive(bin.left, dialect);
424 let right = canonicalize_recursive(bin.right, dialect);
425
426 let (left, right) = coerce_date_operands(left, right);
428
429 constructor(Box::new(crate::expressions::BinaryOp {
430 left,
431 right,
432 left_comments: bin.left_comments,
433 operator_comments: bin.operator_comments,
434 trailing_comments: bin.trailing_comments,
435 inferred_type: None,
436 }))
437}
438
439fn canonicalize_binary<F>(
441 constructor: F,
442 bin: crate::expressions::BinaryOp,
443 dialect: Option<DialectType>,
444) -> Expression
445where
446 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
447{
448 let left = canonicalize_recursive(bin.left, dialect);
449 let right = canonicalize_recursive(bin.right, dialect);
450
451 constructor(Box::new(crate::expressions::BinaryOp {
452 left,
453 right,
454 left_comments: bin.left_comments,
455 operator_comments: bin.operator_comments,
456 trailing_comments: bin.trailing_comments,
457 inferred_type: None,
458 }))
459}
460
461fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
466 let left = coerce_date_string(left, &right);
468 let right = coerce_date_string(right, &left);
469 (left, right)
470}
471
472fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
474 if let Expression::Literal(ref lit) = expr {
475 if let Literal::String(ref s) = lit.as_ref() {
476 if is_iso_date(s) {
478 } else if is_iso_datetime(s) {
481 }
484 }
485 }
486 expr
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL with the default generator.
    // Deliberately not named `gen`: that identifier is a reserved keyword in the
    // Rust 2024 edition, so it would break on an edition bump.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement, panicking on failure.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("SQL produced no statements")
    }

    /// Parses `sql`, canonicalizes it without a dialect and regenerates SQL.
    fn canonical_sql(sql: &str) -> String {
        generate_sql(&canonicalize(parse(sql), None))
    }

    #[test]
    fn test_canonicalize_simple() {
        let sql = canonical_sql("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let sql = canonical_sql("SELECT a, b FROM t WHERE c = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        // Both connectives must survive canonicalization, not just one of them.
        let sql = canonical_sql("SELECT 1 WHERE a AND b OR c");
        assert!(sql.contains("AND") && sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_removes_explicit_asc() {
        // ASC is the default direction, so the explicit keyword is dropped.
        let sql = canonical_sql("SELECT a FROM t ORDER BY a ASC");
        assert!(sql.contains("ORDER BY"));
        assert!(!sql.contains("ASC"));
    }

    #[test]
    fn test_canonicalize_keeps_desc() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a DESC");
        assert!(sql.contains("DESC"));
    }

    #[test]
    fn test_canonicalize_union() {
        let sql = canonical_sql("SELECT a FROM t UNION SELECT b FROM s");
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        let sql = canonical_sql("SELECT 1 + 2");
        assert!(sql.contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        let sql = canonical_sql("SELECT MAX(a) FROM t");
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let sql = canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10");
        assert!(sql.contains("BETWEEN"));
    }
}