1use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
/// Canonicalizes an expression tree and returns the rewritten tree.
///
/// This is the public entry point of the pass. It takes ownership of
/// `expression` and hands it, together with `dialect`, straight to
/// `canonicalize_recursive`, whose result is returned as-is.
///
/// # Arguments
///
/// * `expression` - Root of the tree to canonicalize (consumed).
/// * `dialect` - Optional dialect, forwarded unchanged to the recursive walk.
pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
    canonicalize_recursive(expression, dialect)
}
31
32fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34 let expr = match expression {
35 Expression::Select(mut select) => {
36 select.expressions = select
38 .expressions
39 .into_iter()
40 .map(|e| canonicalize_recursive(e, dialect))
41 .collect();
42
43 if let Some(mut from) = select.from {
45 from.expressions = from
46 .expressions
47 .into_iter()
48 .map(|e| canonicalize_recursive(e, dialect))
49 .collect();
50 select.from = Some(from);
51 }
52
53 if let Some(mut where_clause) = select.where_clause {
55 where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56 where_clause.this = ensure_bools(where_clause.this);
57 select.where_clause = Some(where_clause);
58 }
59
60 if let Some(mut having) = select.having {
62 having.this = canonicalize_recursive(having.this, dialect);
63 having.this = ensure_bools(having.this);
64 select.having = Some(having);
65 }
66
67 if let Some(mut order_by) = select.order_by {
69 order_by.expressions = order_by
70 .expressions
71 .into_iter()
72 .map(|mut o| {
73 o.this = canonicalize_recursive(o.this, dialect);
74 o = remove_ascending_order(o);
75 o
76 })
77 .collect();
78 select.order_by = Some(order_by);
79 }
80
81 select.joins = select
83 .joins
84 .into_iter()
85 .map(|mut j| {
86 j.this = canonicalize_recursive(j.this, dialect);
87 if let Some(on) = j.on {
88 j.on = Some(canonicalize_recursive(on, dialect));
89 }
90 j
91 })
92 .collect();
93
94 Expression::Select(select)
95 }
96
97 Expression::Add(bin) => {
99 let left = canonicalize_recursive(bin.left, dialect);
100 let right = canonicalize_recursive(bin.right, dialect);
101 let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102 left,
103 right,
104 left_comments: bin.left_comments,
105 operator_comments: bin.operator_comments,
106 trailing_comments: bin.trailing_comments,
107 }));
108 add_text_to_concat(result)
109 }
110
111 Expression::And(bin) => {
113 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
114 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
115 Expression::And(Box::new(crate::expressions::BinaryOp {
116 left,
117 right,
118 left_comments: bin.left_comments,
119 operator_comments: bin.operator_comments,
120 trailing_comments: bin.trailing_comments,
121 }))
122 }
123 Expression::Or(bin) => {
124 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
125 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
126 Expression::Or(Box::new(crate::expressions::BinaryOp {
127 left,
128 right,
129 left_comments: bin.left_comments,
130 operator_comments: bin.operator_comments,
131 trailing_comments: bin.trailing_comments,
132 }))
133 }
134
135 Expression::Not(un) => {
136 let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
137 Expression::Not(Box::new(crate::expressions::UnaryOp { this: inner }))
138 }
139
140 Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
142 Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
143 Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
144 Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
145 Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
146 Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
147
148 Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
149 Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
150 Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
151
152 Expression::Cast(cast) => {
154 let inner = canonicalize_recursive(cast.this, dialect);
155 let result = Expression::Cast(Box::new(crate::expressions::Cast {
156 this: inner,
157 to: cast.to,
158 trailing_comments: cast.trailing_comments,
159 double_colon_syntax: cast.double_colon_syntax,
160 format: cast.format,
161 default: cast.default,
162 }));
163 remove_redundant_casts(result)
164 }
165
166 Expression::Function(func) => {
168 let args = func
169 .args
170 .into_iter()
171 .map(|e| canonicalize_recursive(e, dialect))
172 .collect();
173 Expression::Function(Box::new(crate::expressions::Function {
174 name: func.name,
175 args,
176 distinct: func.distinct,
177 trailing_comments: func.trailing_comments,
178 use_bracket_syntax: func.use_bracket_syntax,
179 no_parens: func.no_parens,
180 quoted: func.quoted,
181 span: None,
182 }))
183 }
184
185 Expression::AggregateFunction(agg) => {
186 let args = agg
187 .args
188 .into_iter()
189 .map(|e| canonicalize_recursive(e, dialect))
190 .collect();
191 Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
192 name: agg.name,
193 args,
194 distinct: agg.distinct,
195 filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
196 order_by: agg.order_by,
197 limit: agg.limit,
198 ignore_nulls: agg.ignore_nulls,
199 }))
200 }
201
202 Expression::Alias(alias) => {
204 let inner = canonicalize_recursive(alias.this, dialect);
205 Expression::Alias(Box::new(crate::expressions::Alias {
206 this: inner,
207 alias: alias.alias,
208 column_aliases: alias.column_aliases,
209 pre_alias_comments: alias.pre_alias_comments,
210 trailing_comments: alias.trailing_comments,
211 }))
212 }
213
214 Expression::Paren(paren) => {
216 let inner = canonicalize_recursive(paren.this, dialect);
217 Expression::Paren(Box::new(crate::expressions::Paren {
218 this: inner,
219 trailing_comments: paren.trailing_comments,
220 }))
221 }
222
223 Expression::Case(case) => {
225 let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
226 let whens = case
227 .whens
228 .into_iter()
229 .map(|(w, t)| {
230 (
231 canonicalize_recursive(w, dialect),
232 canonicalize_recursive(t, dialect),
233 )
234 })
235 .collect();
236 let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
237 Expression::Case(Box::new(crate::expressions::Case {
238 operand,
239 whens,
240 else_,
241 comments: Vec::new(),
242 }))
243 }
244
245 Expression::Between(between) => {
247 let this = canonicalize_recursive(between.this, dialect);
248 let low = canonicalize_recursive(between.low, dialect);
249 let high = canonicalize_recursive(between.high, dialect);
250 Expression::Between(Box::new(crate::expressions::Between {
251 this,
252 low,
253 high,
254 not: between.not,
255 symmetric: between.symmetric,
256 }))
257 }
258
259 Expression::In(in_expr) => {
261 let this = canonicalize_recursive(in_expr.this, dialect);
262 let expressions = in_expr
263 .expressions
264 .into_iter()
265 .map(|e| canonicalize_recursive(e, dialect))
266 .collect();
267 let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
268 Expression::In(Box::new(crate::expressions::In {
269 this,
270 expressions,
271 query,
272 not: in_expr.not,
273 global: in_expr.global,
274 unnest: in_expr.unnest,
275 is_field: in_expr.is_field,
276 }))
277 }
278
279 Expression::Subquery(subquery) => {
281 let this = canonicalize_recursive(subquery.this, dialect);
282 Expression::Subquery(Box::new(crate::expressions::Subquery {
283 this,
284 alias: subquery.alias,
285 column_aliases: subquery.column_aliases,
286 order_by: subquery.order_by,
287 limit: subquery.limit,
288 offset: subquery.offset,
289 distribute_by: subquery.distribute_by,
290 sort_by: subquery.sort_by,
291 cluster_by: subquery.cluster_by,
292 lateral: subquery.lateral,
293 modifiers_inside: subquery.modifiers_inside,
294 trailing_comments: subquery.trailing_comments,
295 }))
296 }
297
298 Expression::Union(union) => {
300 let left = canonicalize_recursive(union.left, dialect);
301 let right = canonicalize_recursive(union.right, dialect);
302 Expression::Union(Box::new(crate::expressions::Union {
303 left,
304 right,
305 all: union.all,
306 distinct: union.distinct,
307 with: union.with,
308 order_by: union.order_by,
309 limit: union.limit,
310 offset: union.offset,
311 distribute_by: union.distribute_by,
312 sort_by: union.sort_by,
313 cluster_by: union.cluster_by,
314 by_name: union.by_name,
315 side: union.side,
316 kind: union.kind,
317 corresponding: union.corresponding,
318 strict: union.strict,
319 on_columns: union.on_columns,
320 }))
321 }
322 Expression::Intersect(intersect) => {
323 let left = canonicalize_recursive(intersect.left, dialect);
324 let right = canonicalize_recursive(intersect.right, dialect);
325 Expression::Intersect(Box::new(crate::expressions::Intersect {
326 left,
327 right,
328 all: intersect.all,
329 distinct: intersect.distinct,
330 with: intersect.with,
331 order_by: intersect.order_by,
332 limit: intersect.limit,
333 offset: intersect.offset,
334 distribute_by: intersect.distribute_by,
335 sort_by: intersect.sort_by,
336 cluster_by: intersect.cluster_by,
337 by_name: intersect.by_name,
338 side: intersect.side,
339 kind: intersect.kind,
340 corresponding: intersect.corresponding,
341 strict: intersect.strict,
342 on_columns: intersect.on_columns,
343 }))
344 }
345 Expression::Except(except) => {
346 let left = canonicalize_recursive(except.left, dialect);
347 let right = canonicalize_recursive(except.right, dialect);
348 Expression::Except(Box::new(crate::expressions::Except {
349 left,
350 right,
351 all: except.all,
352 distinct: except.distinct,
353 with: except.with,
354 order_by: except.order_by,
355 limit: except.limit,
356 offset: except.offset,
357 distribute_by: except.distribute_by,
358 sort_by: except.sort_by,
359 cluster_by: except.cluster_by,
360 by_name: except.by_name,
361 side: except.side,
362 kind: except.kind,
363 corresponding: except.corresponding,
364 strict: except.strict,
365 on_columns: except.on_columns,
366 }))
367 }
368
369 other => other,
371 };
372
373 expr
374}
375
/// Hook applied to every rebuilt `Add` node.
///
/// Currently a pass-through: the expression is returned unchanged.
///
/// NOTE(review): the name suggests that `+` on text operands is meant to be
/// rewritten into a concatenation; no such rewrite exists here yet — confirm
/// the intended behavior before relying on it.
fn add_text_to_concat(expression: Expression) -> Expression {
    expression
}
385
386fn remove_redundant_casts(expression: Expression) -> Expression {
390 if let Expression::Cast(cast) = &expression {
391 if let Expression::Literal(Literal::String(_)) = &cast.this {
397 if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
398 return cast.this.clone();
399 }
400 }
401 if let Expression::Literal(Literal::Number(_)) = &cast.this {
402 if matches!(
403 &cast.to,
404 DataType::Int { .. }
405 | DataType::BigInt { .. }
406 | DataType::Decimal { .. }
407 | DataType::Float { .. }
408 ) {
409 }
411 }
412 }
413 expression
414}
415
/// Hook applied to `WHERE`/`HAVING` predicates and to the operands of
/// `And`, `Or` and `Not`.
///
/// Currently a pass-through: the expression is returned unchanged.
///
/// NOTE(review): the name suggests non-boolean operands are meant to be
/// turned into explicit boolean expressions; no such rewrite exists here
/// yet — confirm the intended behavior before relying on it.
fn ensure_bools(expression: Expression) -> Expression {
    expression
}
426
427fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
431 if !ordered.desc && ordered.explicit_asc {
434 ordered.explicit_asc = false;
435 }
436 ordered
437}
438
439fn canonicalize_comparison<F>(
441 constructor: F,
442 bin: crate::expressions::BinaryOp,
443 dialect: Option<DialectType>,
444) -> Expression
445where
446 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
447{
448 let left = canonicalize_recursive(bin.left, dialect);
449 let right = canonicalize_recursive(bin.right, dialect);
450
451 let (left, right) = coerce_date_operands(left, right);
453
454 constructor(Box::new(crate::expressions::BinaryOp {
455 left,
456 right,
457 left_comments: bin.left_comments,
458 operator_comments: bin.operator_comments,
459 trailing_comments: bin.trailing_comments,
460 }))
461}
462
463fn canonicalize_binary<F>(
465 constructor: F,
466 bin: crate::expressions::BinaryOp,
467 dialect: Option<DialectType>,
468) -> Expression
469where
470 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
471{
472 let left = canonicalize_recursive(bin.left, dialect);
473 let right = canonicalize_recursive(bin.right, dialect);
474
475 constructor(Box::new(crate::expressions::BinaryOp {
476 left,
477 right,
478 left_comments: bin.left_comments,
479 operator_comments: bin.operator_comments,
480 trailing_comments: bin.trailing_comments,
481 }))
482}
483
484fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
489 let left = coerce_date_string(left, &right);
491 let right = coerce_date_string(right, &left);
492 (left, right)
493}
494
/// Inspects `expr` for an ISO date or datetime string literal.
///
/// Currently a pass-through: `expr` is returned unchanged on every path. The
/// ISO checks below are evaluated, but neither branch performs a rewrite, and
/// `_other` (the opposite operand) is not consulted.
///
/// NOTE(review): the name and the two checks suggest the literal is meant to
/// be cast to a date/datetime type depending on the other operand; no such
/// rewrite exists here yet — confirm the intended behavior.
fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
    if let Expression::Literal(Literal::String(ref s)) = expr {
        if is_iso_date(s) {
            // Placeholder: ISO date literal recognized, nothing is rewritten.
        } else if is_iso_datetime(s) {
            // Placeholder: ISO datetime literal recognized, nothing is rewritten.
        }
    }
    expr
}
509
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Parses `sql`, canonicalizes its first statement without a dialect and
    /// renders the result back to SQL text.
    fn canonical_sql(sql: &str) -> String {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        let canonical = canonicalize(statements[0].clone(), None);
        Generator::new().generate(&canonical).unwrap()
    }

    #[test]
    fn test_canonicalize_simple() {
        let rendered = canonical_sql("SELECT a FROM t");
        assert!(rendered.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let rendered = canonical_sql("SELECT a, b FROM t WHERE c = 1");
        assert!(rendered.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let rendered = canonical_sql("SELECT 1 WHERE a AND b OR c");
        assert!(rendered.contains("AND") || rendered.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let rendered = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(rendered.contains("=") && rendered.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let rendered = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(rendered.contains("CASE") && rendered.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let rendered = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(rendered.contains("SELECT") && rendered.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let rendered = canonical_sql("SELECT a FROM t ORDER BY a");
        assert!(rendered.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_union() {
        let rendered = canonical_sql("SELECT a FROM t UNION SELECT b FROM s");
        assert!(rendered.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        // Numeric addition must survive canonicalization as a `+`.
        let rendered = canonical_sql("SELECT 1 + 2");
        assert!(rendered.contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        let rendered = canonical_sql("SELECT MAX(a) FROM t");
        assert!(rendered.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let rendered = canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10");
        assert!(rendered.contains("BETWEEN"));
    }
}