1use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
17pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
29 canonicalize_recursive(expression, dialect)
30}
31
32fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34 let expr = match expression {
35 Expression::Select(mut select) => {
36 select.expressions = select
38 .expressions
39 .into_iter()
40 .map(|e| canonicalize_recursive(e, dialect))
41 .collect();
42
43 if let Some(mut from) = select.from {
45 from.expressions = from
46 .expressions
47 .into_iter()
48 .map(|e| canonicalize_recursive(e, dialect))
49 .collect();
50 select.from = Some(from);
51 }
52
53 if let Some(mut where_clause) = select.where_clause {
55 where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56 where_clause.this = ensure_bools(where_clause.this);
57 select.where_clause = Some(where_clause);
58 }
59
60 if let Some(mut having) = select.having {
62 having.this = canonicalize_recursive(having.this, dialect);
63 having.this = ensure_bools(having.this);
64 select.having = Some(having);
65 }
66
67 if let Some(mut order_by) = select.order_by {
69 order_by.expressions = order_by
70 .expressions
71 .into_iter()
72 .map(|mut o| {
73 o.this = canonicalize_recursive(o.this, dialect);
74 o = remove_ascending_order(o);
75 o
76 })
77 .collect();
78 select.order_by = Some(order_by);
79 }
80
81 select.joins = select
83 .joins
84 .into_iter()
85 .map(|mut j| {
86 j.this = canonicalize_recursive(j.this, dialect);
87 if let Some(on) = j.on {
88 j.on = Some(canonicalize_recursive(on, dialect));
89 }
90 j
91 })
92 .collect();
93
94 Expression::Select(select)
95 }
96
97 Expression::Add(bin) => {
99 let left = canonicalize_recursive(bin.left, dialect);
100 let right = canonicalize_recursive(bin.right, dialect);
101 let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102 left,
103 right,
104 left_comments: bin.left_comments,
105 operator_comments: bin.operator_comments,
106 trailing_comments: bin.trailing_comments,
107 inferred_type: None,
108 }));
109 add_text_to_concat(result)
110 }
111
112 Expression::And(bin) => {
114 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
115 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
116 Expression::And(Box::new(crate::expressions::BinaryOp {
117 left,
118 right,
119 left_comments: bin.left_comments,
120 operator_comments: bin.operator_comments,
121 trailing_comments: bin.trailing_comments,
122 inferred_type: None,
123 }))
124 }
125 Expression::Or(bin) => {
126 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
127 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
128 Expression::Or(Box::new(crate::expressions::BinaryOp {
129 left,
130 right,
131 left_comments: bin.left_comments,
132 operator_comments: bin.operator_comments,
133 trailing_comments: bin.trailing_comments,
134 inferred_type: None,
135 }))
136 }
137
138 Expression::Not(un) => {
139 let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
140 Expression::Not(Box::new(crate::expressions::UnaryOp {
141 this: inner,
142 inferred_type: None,
143 }))
144 }
145
146 Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
148 Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
149 Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
150 Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
151 Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
152 Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
153
154 Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
155 Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
156 Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
157
158 Expression::Cast(cast) => {
160 let inner = canonicalize_recursive(cast.this, dialect);
161 let result = Expression::Cast(Box::new(crate::expressions::Cast {
162 this: inner,
163 to: cast.to,
164 trailing_comments: cast.trailing_comments,
165 double_colon_syntax: cast.double_colon_syntax,
166 format: cast.format,
167 default: cast.default,
168 inferred_type: None,
169 }));
170 remove_redundant_casts(result)
171 }
172
173 Expression::Function(func) => {
175 let args = func
176 .args
177 .into_iter()
178 .map(|e| canonicalize_recursive(e, dialect))
179 .collect();
180 Expression::Function(Box::new(crate::expressions::Function {
181 name: func.name,
182 args,
183 distinct: func.distinct,
184 trailing_comments: func.trailing_comments,
185 use_bracket_syntax: func.use_bracket_syntax,
186 no_parens: func.no_parens,
187 quoted: func.quoted,
188 span: None,
189 inferred_type: None,
190 }))
191 }
192
193 Expression::AggregateFunction(agg) => {
194 let args = agg
195 .args
196 .into_iter()
197 .map(|e| canonicalize_recursive(e, dialect))
198 .collect();
199 Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
200 name: agg.name,
201 args,
202 distinct: agg.distinct,
203 filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
204 order_by: agg.order_by,
205 limit: agg.limit,
206 ignore_nulls: agg.ignore_nulls,
207 inferred_type: None,
208 }))
209 }
210
211 Expression::Alias(alias) => {
213 let inner = canonicalize_recursive(alias.this, dialect);
214 Expression::Alias(Box::new(crate::expressions::Alias {
215 this: inner,
216 alias: alias.alias,
217 column_aliases: alias.column_aliases,
218 pre_alias_comments: alias.pre_alias_comments,
219 trailing_comments: alias.trailing_comments,
220 inferred_type: None,
221 }))
222 }
223
224 Expression::Paren(paren) => {
226 let inner = canonicalize_recursive(paren.this, dialect);
227 Expression::Paren(Box::new(crate::expressions::Paren {
228 this: inner,
229 trailing_comments: paren.trailing_comments,
230 }))
231 }
232
233 Expression::Case(case) => {
235 let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
236 let whens = case
237 .whens
238 .into_iter()
239 .map(|(w, t)| {
240 (
241 canonicalize_recursive(w, dialect),
242 canonicalize_recursive(t, dialect),
243 )
244 })
245 .collect();
246 let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
247 Expression::Case(Box::new(crate::expressions::Case {
248 operand,
249 whens,
250 else_,
251 comments: Vec::new(),
252 inferred_type: None,
253 }))
254 }
255
256 Expression::Between(between) => {
258 let this = canonicalize_recursive(between.this, dialect);
259 let low = canonicalize_recursive(between.low, dialect);
260 let high = canonicalize_recursive(between.high, dialect);
261 Expression::Between(Box::new(crate::expressions::Between {
262 this,
263 low,
264 high,
265 not: between.not,
266 symmetric: between.symmetric,
267 }))
268 }
269
270 Expression::In(in_expr) => {
272 let this = canonicalize_recursive(in_expr.this, dialect);
273 let expressions = in_expr
274 .expressions
275 .into_iter()
276 .map(|e| canonicalize_recursive(e, dialect))
277 .collect();
278 let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
279 Expression::In(Box::new(crate::expressions::In {
280 this,
281 expressions,
282 query,
283 not: in_expr.not,
284 global: in_expr.global,
285 unnest: in_expr.unnest,
286 is_field: in_expr.is_field,
287 }))
288 }
289
290 Expression::Subquery(subquery) => {
292 let this = canonicalize_recursive(subquery.this, dialect);
293 Expression::Subquery(Box::new(crate::expressions::Subquery {
294 this,
295 alias: subquery.alias,
296 column_aliases: subquery.column_aliases,
297 order_by: subquery.order_by,
298 limit: subquery.limit,
299 offset: subquery.offset,
300 distribute_by: subquery.distribute_by,
301 sort_by: subquery.sort_by,
302 cluster_by: subquery.cluster_by,
303 lateral: subquery.lateral,
304 modifiers_inside: subquery.modifiers_inside,
305 trailing_comments: subquery.trailing_comments,
306 inferred_type: None,
307 }))
308 }
309
310 Expression::Union(union) => {
312 let left = canonicalize_recursive(union.left, dialect);
313 let right = canonicalize_recursive(union.right, dialect);
314 Expression::Union(Box::new(crate::expressions::Union {
315 left,
316 right,
317 all: union.all,
318 distinct: union.distinct,
319 with: union.with,
320 order_by: union.order_by,
321 limit: union.limit,
322 offset: union.offset,
323 distribute_by: union.distribute_by,
324 sort_by: union.sort_by,
325 cluster_by: union.cluster_by,
326 by_name: union.by_name,
327 side: union.side,
328 kind: union.kind,
329 corresponding: union.corresponding,
330 strict: union.strict,
331 on_columns: union.on_columns,
332 }))
333 }
334 Expression::Intersect(intersect) => {
335 let left = canonicalize_recursive(intersect.left, dialect);
336 let right = canonicalize_recursive(intersect.right, dialect);
337 Expression::Intersect(Box::new(crate::expressions::Intersect {
338 left,
339 right,
340 all: intersect.all,
341 distinct: intersect.distinct,
342 with: intersect.with,
343 order_by: intersect.order_by,
344 limit: intersect.limit,
345 offset: intersect.offset,
346 distribute_by: intersect.distribute_by,
347 sort_by: intersect.sort_by,
348 cluster_by: intersect.cluster_by,
349 by_name: intersect.by_name,
350 side: intersect.side,
351 kind: intersect.kind,
352 corresponding: intersect.corresponding,
353 strict: intersect.strict,
354 on_columns: intersect.on_columns,
355 }))
356 }
357 Expression::Except(except) => {
358 let left = canonicalize_recursive(except.left, dialect);
359 let right = canonicalize_recursive(except.right, dialect);
360 Expression::Except(Box::new(crate::expressions::Except {
361 left,
362 right,
363 all: except.all,
364 distinct: except.distinct,
365 with: except.with,
366 order_by: except.order_by,
367 limit: except.limit,
368 offset: except.offset,
369 distribute_by: except.distribute_by,
370 sort_by: except.sort_by,
371 cluster_by: except.cluster_by,
372 by_name: except.by_name,
373 side: except.side,
374 kind: except.kind,
375 corresponding: except.corresponding,
376 strict: except.strict,
377 on_columns: except.on_columns,
378 }))
379 }
380
381 other => other,
383 };
384
385 expr
386}
387
388fn add_text_to_concat(expression: Expression) -> Expression {
393 expression
396}
397
398fn remove_redundant_casts(expression: Expression) -> Expression {
402 if let Expression::Cast(cast) = &expression {
403 if let Expression::Literal(lit) = &cast.this {
409 if let Literal::String(_) = lit.as_ref() {
410 if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
411 return cast.this.clone();
412 }
413 }
414 }
415 if let Expression::Literal(lit) = &cast.this {
416 if let Literal::Number(_) = lit.as_ref() {
417 if matches!(
418 &cast.to,
419 DataType::Int { .. }
420 | DataType::BigInt { .. }
421 | DataType::Decimal { .. }
422 | DataType::Float { .. }
423 ) {
424 }
426 }
427 }
428 }
429 expression
430}
431
432fn ensure_bools(expression: Expression) -> Expression {
437 expression
441}
442
443fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
447 if !ordered.desc && ordered.explicit_asc {
450 ordered.explicit_asc = false;
451 }
452 ordered
453}
454
455fn canonicalize_comparison<F>(
457 constructor: F,
458 bin: crate::expressions::BinaryOp,
459 dialect: Option<DialectType>,
460) -> Expression
461where
462 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
463{
464 let left = canonicalize_recursive(bin.left, dialect);
465 let right = canonicalize_recursive(bin.right, dialect);
466
467 let (left, right) = coerce_date_operands(left, right);
469
470 constructor(Box::new(crate::expressions::BinaryOp {
471 left,
472 right,
473 left_comments: bin.left_comments,
474 operator_comments: bin.operator_comments,
475 trailing_comments: bin.trailing_comments,
476 inferred_type: None,
477 }))
478}
479
480fn canonicalize_binary<F>(
482 constructor: F,
483 bin: crate::expressions::BinaryOp,
484 dialect: Option<DialectType>,
485) -> Expression
486where
487 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
488{
489 let left = canonicalize_recursive(bin.left, dialect);
490 let right = canonicalize_recursive(bin.right, dialect);
491
492 constructor(Box::new(crate::expressions::BinaryOp {
493 left,
494 right,
495 left_comments: bin.left_comments,
496 operator_comments: bin.operator_comments,
497 trailing_comments: bin.trailing_comments,
498 inferred_type: None,
499 }))
500}
501
502fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
507 let left = coerce_date_string(left, &right);
509 let right = coerce_date_string(right, &left);
510 (left, right)
511}
512
513fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
515 if let Expression::Literal(ref lit) = expr {
516 if let Literal::String(ref s) = lit.as_ref() {
517 if is_iso_date(s) {
519 } else if is_iso_datetime(s) {
522 }
525 }
526 }
527 expr
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL text with the default generator.
    // Named `generate_sql` rather than `gen`: `gen` is a reserved keyword in
    // the Rust 2024 edition.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new()
            .generate(expr)
            .expect("failed to generate SQL")
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .first()
            .cloned()
            .expect("parser returned no statements")
    }

    #[test]
    fn test_canonicalize_simple() {
        let expr = parse("SELECT a FROM t");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let expr = parse("SELECT 1 WHERE a AND b OR c");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("AND") || sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let expr = parse("SELECT 1 WHERE a = 1 AND b > 2");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let expr = parse("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let expr = parse("SELECT a FROM (SELECT b FROM t) AS sub");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let expr = parse("SELECT a FROM t ORDER BY a");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_removes_explicit_asc() {
        let expr = parse("SELECT a FROM t ORDER BY a ASC");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result).to_uppercase();
        assert!(sql.contains("ORDER BY"));
        assert!(!sql.contains(" ASC"), "explicit ASC should be dropped: {sql}");
    }

    #[test]
    fn test_canonicalize_keeps_desc() {
        let expr = parse("SELECT a FROM t ORDER BY a DESC");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result).to_uppercase();
        assert!(sql.contains("DESC"));
    }

    #[test]
    fn test_canonicalize_union() {
        let expr = parse("SELECT a FROM t UNION SELECT b FROM s");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        let expr = parse("SELECT 1 + 2");
        // `add_text_to_concat` is currently an identity hook, so `+` survives.
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("+"));
    }

    #[test]
    fn test_remove_redundant_string_cast() {
        let expr = parse("SELECT CAST('a' AS VARCHAR)");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result).to_uppercase();
        assert!(!sql.contains("CAST"), "redundant cast should be removed: {sql}");
    }

    #[test]
    fn test_canonicalize_function() {
        let expr = parse("SELECT MAX(a) FROM t");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let expr = parse("SELECT 1 WHERE a BETWEEN 1 AND 10");
        let result = canonicalize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("BETWEEN"));
    }
}