1use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
/// Canonicalizes `expression` and returns the rewritten tree.
///
/// This is the public entry point; it delegates directly to
/// `canonicalize_recursive`, passing `dialect` through unchanged.
pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
    canonicalize_recursive(expression, dialect)
}
31
32fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34 let expr = match expression {
35 Expression::Select(mut select) => {
36 select.expressions = select
38 .expressions
39 .into_iter()
40 .map(|e| canonicalize_recursive(e, dialect))
41 .collect();
42
43 if let Some(mut from) = select.from {
45 from.expressions = from
46 .expressions
47 .into_iter()
48 .map(|e| canonicalize_recursive(e, dialect))
49 .collect();
50 select.from = Some(from);
51 }
52
53 if let Some(mut where_clause) = select.where_clause {
55 where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56 where_clause.this = ensure_bools(where_clause.this);
57 select.where_clause = Some(where_clause);
58 }
59
60 if let Some(mut having) = select.having {
62 having.this = canonicalize_recursive(having.this, dialect);
63 having.this = ensure_bools(having.this);
64 select.having = Some(having);
65 }
66
67 if let Some(mut order_by) = select.order_by {
69 order_by.expressions = order_by
70 .expressions
71 .into_iter()
72 .map(|mut o| {
73 o.this = canonicalize_recursive(o.this, dialect);
74 o = remove_ascending_order(o);
75 o
76 })
77 .collect();
78 select.order_by = Some(order_by);
79 }
80
81 select.joins = select
83 .joins
84 .into_iter()
85 .map(|mut j| {
86 j.this = canonicalize_recursive(j.this, dialect);
87 if let Some(on) = j.on {
88 j.on = Some(canonicalize_recursive(on, dialect));
89 }
90 j
91 })
92 .collect();
93
94 Expression::Select(select)
95 }
96
97 Expression::Add(bin) => {
99 let left = canonicalize_recursive(bin.left, dialect);
100 let right = canonicalize_recursive(bin.right, dialect);
101 let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102 left,
103 right,
104 left_comments: bin.left_comments,
105 operator_comments: bin.operator_comments,
106 trailing_comments: bin.trailing_comments,
107 inferred_type: None,
108 }));
109 add_text_to_concat(result)
110 }
111
112 Expression::And(bin) => {
114 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
115 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
116 Expression::And(Box::new(crate::expressions::BinaryOp {
117 left,
118 right,
119 left_comments: bin.left_comments,
120 operator_comments: bin.operator_comments,
121 trailing_comments: bin.trailing_comments,
122 inferred_type: None,
123 }))
124 }
125 Expression::Or(bin) => {
126 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
127 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
128 Expression::Or(Box::new(crate::expressions::BinaryOp {
129 left,
130 right,
131 left_comments: bin.left_comments,
132 operator_comments: bin.operator_comments,
133 trailing_comments: bin.trailing_comments,
134 inferred_type: None,
135 }))
136 }
137
138 Expression::Not(un) => {
139 let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
140 Expression::Not(Box::new(crate::expressions::UnaryOp {
141 this: inner,
142 inferred_type: None,
143 }))
144 }
145
146 Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
148 Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
149 Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
150 Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
151 Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
152 Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
153
154 Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
155 Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
156 Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
157
158 Expression::Cast(cast) => {
160 let inner = canonicalize_recursive(cast.this, dialect);
161 let result = Expression::Cast(Box::new(crate::expressions::Cast {
162 this: inner,
163 to: cast.to,
164 trailing_comments: cast.trailing_comments,
165 double_colon_syntax: cast.double_colon_syntax,
166 format: cast.format,
167 default: cast.default,
168 inferred_type: None,
169 }));
170 remove_redundant_casts(result)
171 }
172
173 Expression::Function(func) => {
175 let args = func
176 .args
177 .into_iter()
178 .map(|e| canonicalize_recursive(e, dialect))
179 .collect();
180 Expression::Function(Box::new(crate::expressions::Function {
181 name: func.name,
182 args,
183 distinct: func.distinct,
184 trailing_comments: func.trailing_comments,
185 use_bracket_syntax: func.use_bracket_syntax,
186 no_parens: func.no_parens,
187 quoted: func.quoted,
188 span: None,
189 inferred_type: None,
190 }))
191 }
192
193 Expression::AggregateFunction(agg) => {
194 let args = agg
195 .args
196 .into_iter()
197 .map(|e| canonicalize_recursive(e, dialect))
198 .collect();
199 Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
200 name: agg.name,
201 args,
202 distinct: agg.distinct,
203 filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
204 order_by: agg.order_by,
205 limit: agg.limit,
206 ignore_nulls: agg.ignore_nulls,
207 inferred_type: None,
208 }))
209 }
210
211 Expression::Alias(alias) => {
213 let inner = canonicalize_recursive(alias.this, dialect);
214 Expression::Alias(Box::new(crate::expressions::Alias {
215 this: inner,
216 alias: alias.alias,
217 column_aliases: alias.column_aliases,
218 pre_alias_comments: alias.pre_alias_comments,
219 trailing_comments: alias.trailing_comments,
220 inferred_type: None,
221 }))
222 }
223
224 Expression::Paren(paren) => {
226 let inner = canonicalize_recursive(paren.this, dialect);
227 Expression::Paren(Box::new(crate::expressions::Paren {
228 this: inner,
229 trailing_comments: paren.trailing_comments,
230 }))
231 }
232
233 Expression::Case(case) => {
235 let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
236 let whens = case
237 .whens
238 .into_iter()
239 .map(|(w, t)| {
240 (
241 canonicalize_recursive(w, dialect),
242 canonicalize_recursive(t, dialect),
243 )
244 })
245 .collect();
246 let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
247 Expression::Case(Box::new(crate::expressions::Case {
248 operand,
249 whens,
250 else_,
251 comments: Vec::new(),
252 inferred_type: None,
253 }))
254 }
255
256 Expression::Between(between) => {
258 let this = canonicalize_recursive(between.this, dialect);
259 let low = canonicalize_recursive(between.low, dialect);
260 let high = canonicalize_recursive(between.high, dialect);
261 Expression::Between(Box::new(crate::expressions::Between {
262 this,
263 low,
264 high,
265 not: between.not,
266 symmetric: between.symmetric,
267 }))
268 }
269
270 Expression::In(in_expr) => {
272 let this = canonicalize_recursive(in_expr.this, dialect);
273 let expressions = in_expr
274 .expressions
275 .into_iter()
276 .map(|e| canonicalize_recursive(e, dialect))
277 .collect();
278 let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
279 Expression::In(Box::new(crate::expressions::In {
280 this,
281 expressions,
282 query,
283 not: in_expr.not,
284 global: in_expr.global,
285 unnest: in_expr.unnest,
286 is_field: in_expr.is_field,
287 }))
288 }
289
290 Expression::Subquery(subquery) => {
292 let this = canonicalize_recursive(subquery.this, dialect);
293 Expression::Subquery(Box::new(crate::expressions::Subquery {
294 this,
295 alias: subquery.alias,
296 column_aliases: subquery.column_aliases,
297 order_by: subquery.order_by,
298 limit: subquery.limit,
299 offset: subquery.offset,
300 distribute_by: subquery.distribute_by,
301 sort_by: subquery.sort_by,
302 cluster_by: subquery.cluster_by,
303 lateral: subquery.lateral,
304 modifiers_inside: subquery.modifiers_inside,
305 trailing_comments: subquery.trailing_comments,
306 inferred_type: None,
307 }))
308 }
309
310 Expression::Union(union) => {
312 let left = canonicalize_recursive(union.left, dialect);
313 let right = canonicalize_recursive(union.right, dialect);
314 Expression::Union(Box::new(crate::expressions::Union {
315 left,
316 right,
317 all: union.all,
318 distinct: union.distinct,
319 with: union.with,
320 order_by: union.order_by,
321 limit: union.limit,
322 offset: union.offset,
323 distribute_by: union.distribute_by,
324 sort_by: union.sort_by,
325 cluster_by: union.cluster_by,
326 by_name: union.by_name,
327 side: union.side,
328 kind: union.kind,
329 corresponding: union.corresponding,
330 strict: union.strict,
331 on_columns: union.on_columns,
332 }))
333 }
334 Expression::Intersect(intersect) => {
335 let left = canonicalize_recursive(intersect.left, dialect);
336 let right = canonicalize_recursive(intersect.right, dialect);
337 Expression::Intersect(Box::new(crate::expressions::Intersect {
338 left,
339 right,
340 all: intersect.all,
341 distinct: intersect.distinct,
342 with: intersect.with,
343 order_by: intersect.order_by,
344 limit: intersect.limit,
345 offset: intersect.offset,
346 distribute_by: intersect.distribute_by,
347 sort_by: intersect.sort_by,
348 cluster_by: intersect.cluster_by,
349 by_name: intersect.by_name,
350 side: intersect.side,
351 kind: intersect.kind,
352 corresponding: intersect.corresponding,
353 strict: intersect.strict,
354 on_columns: intersect.on_columns,
355 }))
356 }
357 Expression::Except(except) => {
358 let left = canonicalize_recursive(except.left, dialect);
359 let right = canonicalize_recursive(except.right, dialect);
360 Expression::Except(Box::new(crate::expressions::Except {
361 left,
362 right,
363 all: except.all,
364 distinct: except.distinct,
365 with: except.with,
366 order_by: except.order_by,
367 limit: except.limit,
368 offset: except.offset,
369 distribute_by: except.distribute_by,
370 sort_by: except.sort_by,
371 cluster_by: except.cluster_by,
372 by_name: except.by_name,
373 side: except.side,
374 kind: except.kind,
375 corresponding: except.corresponding,
376 strict: except.strict,
377 on_columns: except.on_columns,
378 }))
379 }
380
381 other => other,
383 };
384
385 expr
386}
387
/// Hook applied to every rebuilt `Add` node by `canonicalize_recursive`.
///
/// Currently an identity function: `expression` is returned unchanged.
// NOTE(review): the name suggests rewriting `+` on text operands into a concatenation;
// no such rewrite is implemented here — confirm the intended behavior before relying on it.
fn add_text_to_concat(expression: Expression) -> Expression {
    expression
}
397
398fn remove_redundant_casts(expression: Expression) -> Expression {
402 if let Expression::Cast(cast) = &expression {
403 if let Expression::Literal(Literal::String(_)) = &cast.this {
409 if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
410 return cast.this.clone();
411 }
412 }
413 if let Expression::Literal(Literal::Number(_)) = &cast.this {
414 if matches!(
415 &cast.to,
416 DataType::Int { .. }
417 | DataType::BigInt { .. }
418 | DataType::Decimal { .. }
419 | DataType::Float { .. }
420 ) {
421 }
423 }
424 }
425 expression
426}
427
/// Hook applied to WHERE / HAVING predicates and to the operands of `AND`, `OR` and `NOT`
/// by `canonicalize_recursive`.
///
/// Currently an identity function: `expression` is returned unchanged.
// NOTE(review): the name suggests coercing non-boolean predicates into boolean form;
// no such rewrite is implemented here — confirm the intended behavior before relying on it.
fn ensure_bools(expression: Expression) -> Expression {
    expression
}
438
439fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
443 if !ordered.desc && ordered.explicit_asc {
446 ordered.explicit_asc = false;
447 }
448 ordered
449}
450
451fn canonicalize_comparison<F>(
453 constructor: F,
454 bin: crate::expressions::BinaryOp,
455 dialect: Option<DialectType>,
456) -> Expression
457where
458 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
459{
460 let left = canonicalize_recursive(bin.left, dialect);
461 let right = canonicalize_recursive(bin.right, dialect);
462
463 let (left, right) = coerce_date_operands(left, right);
465
466 constructor(Box::new(crate::expressions::BinaryOp {
467 left,
468 right,
469 left_comments: bin.left_comments,
470 operator_comments: bin.operator_comments,
471 trailing_comments: bin.trailing_comments,
472 inferred_type: None,
473 }))
474}
475
476fn canonicalize_binary<F>(
478 constructor: F,
479 bin: crate::expressions::BinaryOp,
480 dialect: Option<DialectType>,
481) -> Expression
482where
483 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
484{
485 let left = canonicalize_recursive(bin.left, dialect);
486 let right = canonicalize_recursive(bin.right, dialect);
487
488 constructor(Box::new(crate::expressions::BinaryOp {
489 left,
490 right,
491 left_comments: bin.left_comments,
492 operator_comments: bin.operator_comments,
493 trailing_comments: bin.trailing_comments,
494 inferred_type: None,
495 }))
496}
497
498fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
503 let left = coerce_date_string(left, &right);
505 let right = coerce_date_string(right, &left);
506 (left, right)
507}
508
509fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
511 if let Expression::Literal(Literal::String(ref s)) = expr {
512 if is_iso_date(s) {
514 } else if is_iso_datetime(s) {
517 }
520 }
521 expr
522}
523
// Smoke tests for `canonicalize`: each test parses a SQL string, canonicalizes it with no
// dialect, regenerates SQL, and checks that an expected keyword or symbol is still present.
// These are substring checks only; they do not pin the exact generated SQL.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL with a default `Generator`; panics if generation fails.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns a clone of the first parsed expression.
    /// Panics if parsing fails or yields no expressions.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    // A minimal SELECT is still a SELECT after canonicalization.
    #[test]
    fn test_canonicalize_simple() {
        let expr = parse("SELECT a FROM t");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("SELECT"));
    }

    // The WHERE clause is kept.
    #[test]
    fn test_canonicalize_preserves_structure() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("WHERE"));
    }

    // Logical connectives: passes if either AND or OR is present in the output.
    #[test]
    fn test_canonicalize_and_or() {
        let expr = parse("SELECT 1 WHERE a AND b OR c");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("AND") || sql.contains("OR"));
    }

    // Comparison operators are kept.
    #[test]
    fn test_canonicalize_comparison() {
        let expr = parse("SELECT 1 WHERE a = 1 AND b > 2");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("=") && sql.contains(">"));
    }

    // CASE ... WHEN is kept.
    #[test]
    fn test_canonicalize_case() {
        let expr = parse("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    // A derived table in FROM keeps its alias.
    #[test]
    fn test_canonicalize_subquery() {
        let expr = parse("SELECT a FROM (SELECT b FROM t) AS sub");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    // ORDER BY is kept.
    #[test]
    fn test_canonicalize_order_by() {
        let expr = parse("SELECT a FROM t ORDER BY a");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("ORDER BY"));
    }

    // UNION is kept.
    #[test]
    fn test_canonicalize_union() {
        let expr = parse("SELECT a FROM t UNION SELECT b FROM s");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("UNION"));
    }

    // Numeric addition keeps its `+` operator.
    #[test]
    fn test_add_text_to_concat_passthrough() {
        let expr = parse("SELECT 1 + 2");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("+"));
    }

    // A function call keeps its name.
    #[test]
    fn test_canonicalize_function() {
        let expr = parse("SELECT MAX(a) FROM t");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("MAX"));
    }

    // BETWEEN is kept.
    #[test]
    fn test_canonicalize_between() {
        let expr = parse("SELECT 1 WHERE a BETWEEN 1 AND 10");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("BETWEEN"));
    }
}