1use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
/// Canonicalizes `expression` and returns the rewritten tree.
///
/// This is the public entry point; it forwards both arguments unchanged to
/// `canonicalize_recursive`, which performs the actual traversal.
///
/// * `expression` - the expression tree to canonicalize (consumed).
/// * `dialect` - optional dialect, passed through to the recursive walk.
pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
    canonicalize_recursive(expression, dialect)
}
31
32fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34 let expr = match expression {
35 Expression::Select(mut select) => {
36 select.expressions = select
38 .expressions
39 .into_iter()
40 .map(|e| canonicalize_recursive(e, dialect))
41 .collect();
42
43 if let Some(mut from) = select.from {
45 from.expressions = from
46 .expressions
47 .into_iter()
48 .map(|e| canonicalize_recursive(e, dialect))
49 .collect();
50 select.from = Some(from);
51 }
52
53 if let Some(mut where_clause) = select.where_clause {
55 where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56 where_clause.this = ensure_bools(where_clause.this);
57 select.where_clause = Some(where_clause);
58 }
59
60 if let Some(mut having) = select.having {
62 having.this = canonicalize_recursive(having.this, dialect);
63 having.this = ensure_bools(having.this);
64 select.having = Some(having);
65 }
66
67 if let Some(mut order_by) = select.order_by {
69 order_by.expressions = order_by
70 .expressions
71 .into_iter()
72 .map(|mut o| {
73 o.this = canonicalize_recursive(o.this, dialect);
74 o = remove_ascending_order(o);
75 o
76 })
77 .collect();
78 select.order_by = Some(order_by);
79 }
80
81 select.joins = select
83 .joins
84 .into_iter()
85 .map(|mut j| {
86 j.this = canonicalize_recursive(j.this, dialect);
87 if let Some(on) = j.on {
88 j.on = Some(canonicalize_recursive(on, dialect));
89 }
90 j
91 })
92 .collect();
93
94 Expression::Select(select)
95 }
96
97 Expression::Add(bin) => {
99 let left = canonicalize_recursive(bin.left, dialect);
100 let right = canonicalize_recursive(bin.right, dialect);
101 let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102 left,
103 right,
104 left_comments: bin.left_comments,
105 operator_comments: bin.operator_comments,
106 trailing_comments: bin.trailing_comments,
107 }));
108 add_text_to_concat(result)
109 }
110
111 Expression::And(bin) => {
113 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
114 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
115 Expression::And(Box::new(crate::expressions::BinaryOp {
116 left,
117 right,
118 left_comments: bin.left_comments,
119 operator_comments: bin.operator_comments,
120 trailing_comments: bin.trailing_comments,
121 }))
122 }
123 Expression::Or(bin) => {
124 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
125 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
126 Expression::Or(Box::new(crate::expressions::BinaryOp {
127 left,
128 right,
129 left_comments: bin.left_comments,
130 operator_comments: bin.operator_comments,
131 trailing_comments: bin.trailing_comments,
132 }))
133 }
134
135 Expression::Not(un) => {
136 let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
137 Expression::Not(Box::new(crate::expressions::UnaryOp { this: inner }))
138 }
139
140 Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
142 Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
143 Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
144 Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
145 Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
146 Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
147
148 Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
149 Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
150 Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
151
152 Expression::Cast(cast) => {
154 let inner = canonicalize_recursive(cast.this, dialect);
155 let result = Expression::Cast(Box::new(crate::expressions::Cast {
156 this: inner,
157 to: cast.to,
158 trailing_comments: cast.trailing_comments,
159 double_colon_syntax: cast.double_colon_syntax,
160 format: cast.format,
161 default: cast.default,
162 }));
163 remove_redundant_casts(result)
164 }
165
166 Expression::Function(func) => {
168 let args = func
169 .args
170 .into_iter()
171 .map(|e| canonicalize_recursive(e, dialect))
172 .collect();
173 Expression::Function(Box::new(crate::expressions::Function {
174 name: func.name,
175 args,
176 distinct: func.distinct,
177 trailing_comments: func.trailing_comments,
178 use_bracket_syntax: func.use_bracket_syntax,
179 no_parens: func.no_parens,
180 quoted: func.quoted,
181 }))
182 }
183
184 Expression::AggregateFunction(agg) => {
185 let args = agg
186 .args
187 .into_iter()
188 .map(|e| canonicalize_recursive(e, dialect))
189 .collect();
190 Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
191 name: agg.name,
192 args,
193 distinct: agg.distinct,
194 filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
195 order_by: agg.order_by,
196 limit: agg.limit,
197 ignore_nulls: agg.ignore_nulls,
198 }))
199 }
200
201 Expression::Alias(alias) => {
203 let inner = canonicalize_recursive(alias.this, dialect);
204 Expression::Alias(Box::new(crate::expressions::Alias {
205 this: inner,
206 alias: alias.alias,
207 column_aliases: alias.column_aliases,
208 pre_alias_comments: alias.pre_alias_comments,
209 trailing_comments: alias.trailing_comments,
210 }))
211 }
212
213 Expression::Paren(paren) => {
215 let inner = canonicalize_recursive(paren.this, dialect);
216 Expression::Paren(Box::new(crate::expressions::Paren {
217 this: inner,
218 trailing_comments: paren.trailing_comments,
219 }))
220 }
221
222 Expression::Case(case) => {
224 let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
225 let whens = case
226 .whens
227 .into_iter()
228 .map(|(w, t)| {
229 (
230 canonicalize_recursive(w, dialect),
231 canonicalize_recursive(t, dialect),
232 )
233 })
234 .collect();
235 let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
236 Expression::Case(Box::new(crate::expressions::Case {
237 operand,
238 whens,
239 else_,
240 comments: Vec::new(),
241 }))
242 }
243
244 Expression::Between(between) => {
246 let this = canonicalize_recursive(between.this, dialect);
247 let low = canonicalize_recursive(between.low, dialect);
248 let high = canonicalize_recursive(between.high, dialect);
249 Expression::Between(Box::new(crate::expressions::Between {
250 this,
251 low,
252 high,
253 not: between.not,
254 symmetric: between.symmetric,
255 }))
256 }
257
258 Expression::In(in_expr) => {
260 let this = canonicalize_recursive(in_expr.this, dialect);
261 let expressions = in_expr
262 .expressions
263 .into_iter()
264 .map(|e| canonicalize_recursive(e, dialect))
265 .collect();
266 let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
267 Expression::In(Box::new(crate::expressions::In {
268 this,
269 expressions,
270 query,
271 not: in_expr.not,
272 global: in_expr.global,
273 unnest: in_expr.unnest,
274 is_field: in_expr.is_field,
275 }))
276 }
277
278 Expression::Subquery(subquery) => {
280 let this = canonicalize_recursive(subquery.this, dialect);
281 Expression::Subquery(Box::new(crate::expressions::Subquery {
282 this,
283 alias: subquery.alias,
284 column_aliases: subquery.column_aliases,
285 order_by: subquery.order_by,
286 limit: subquery.limit,
287 offset: subquery.offset,
288 distribute_by: subquery.distribute_by,
289 sort_by: subquery.sort_by,
290 cluster_by: subquery.cluster_by,
291 lateral: subquery.lateral,
292 modifiers_inside: subquery.modifiers_inside,
293 trailing_comments: subquery.trailing_comments,
294 }))
295 }
296
297 Expression::Union(union) => {
299 let left = canonicalize_recursive(union.left, dialect);
300 let right = canonicalize_recursive(union.right, dialect);
301 Expression::Union(Box::new(crate::expressions::Union {
302 left,
303 right,
304 all: union.all,
305 distinct: union.distinct,
306 with: union.with,
307 order_by: union.order_by,
308 limit: union.limit,
309 offset: union.offset,
310 distribute_by: union.distribute_by,
311 sort_by: union.sort_by,
312 cluster_by: union.cluster_by,
313 by_name: union.by_name,
314 side: union.side,
315 kind: union.kind,
316 corresponding: union.corresponding,
317 strict: union.strict,
318 on_columns: union.on_columns,
319 }))
320 }
321 Expression::Intersect(intersect) => {
322 let left = canonicalize_recursive(intersect.left, dialect);
323 let right = canonicalize_recursive(intersect.right, dialect);
324 Expression::Intersect(Box::new(crate::expressions::Intersect {
325 left,
326 right,
327 all: intersect.all,
328 distinct: intersect.distinct,
329 with: intersect.with,
330 order_by: intersect.order_by,
331 limit: intersect.limit,
332 offset: intersect.offset,
333 distribute_by: intersect.distribute_by,
334 sort_by: intersect.sort_by,
335 cluster_by: intersect.cluster_by,
336 by_name: intersect.by_name,
337 side: intersect.side,
338 kind: intersect.kind,
339 corresponding: intersect.corresponding,
340 strict: intersect.strict,
341 on_columns: intersect.on_columns,
342 }))
343 }
344 Expression::Except(except) => {
345 let left = canonicalize_recursive(except.left, dialect);
346 let right = canonicalize_recursive(except.right, dialect);
347 Expression::Except(Box::new(crate::expressions::Except {
348 left,
349 right,
350 all: except.all,
351 distinct: except.distinct,
352 with: except.with,
353 order_by: except.order_by,
354 limit: except.limit,
355 offset: except.offset,
356 distribute_by: except.distribute_by,
357 sort_by: except.sort_by,
358 cluster_by: except.cluster_by,
359 by_name: except.by_name,
360 side: except.side,
361 kind: except.kind,
362 corresponding: except.corresponding,
363 strict: except.strict,
364 on_columns: except.on_columns,
365 }))
366 }
367
368 other => other,
370 };
371
372 expr
373}
374
/// Hook applied to every rebuilt `Add` node.
///
/// Currently an identity function: the expression is returned unchanged.
// NOTE(review): the name suggests rewriting `+` on text operands into a
// concatenation; presumably that needs type information not available
// here -- confirm before relying on it.
fn add_text_to_concat(expression: Expression) -> Expression {
    expression
}
384
385fn remove_redundant_casts(expression: Expression) -> Expression {
389 if let Expression::Cast(cast) = &expression {
390 if let Expression::Literal(Literal::String(_)) = &cast.this {
396 if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
397 return cast.this.clone();
398 }
399 }
400 if let Expression::Literal(Literal::Number(_)) = &cast.this {
401 if matches!(
402 &cast.to,
403 DataType::Int { .. }
404 | DataType::BigInt { .. }
405 | DataType::Decimal { .. }
406 | DataType::Float { .. }
407 ) {
408 }
410 }
411 }
412 expression
413}
414
/// Hook applied to expressions used in boolean position (`WHERE`, `HAVING`,
/// and the operands of `AND` / `OR` / `NOT`).
///
/// Currently an identity function: the expression is returned unchanged.
// NOTE(review): presumably intended to coerce non-boolean operands into
// explicit boolean tests once type information is available -- confirm.
fn ensure_bools(expression: Expression) -> Expression {
    expression
}
425
426fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
430 if !ordered.desc && ordered.explicit_asc {
433 ordered.explicit_asc = false;
434 }
435 ordered
436}
437
438fn canonicalize_comparison<F>(
440 constructor: F,
441 bin: crate::expressions::BinaryOp,
442 dialect: Option<DialectType>,
443) -> Expression
444where
445 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
446{
447 let left = canonicalize_recursive(bin.left, dialect);
448 let right = canonicalize_recursive(bin.right, dialect);
449
450 let (left, right) = coerce_date_operands(left, right);
452
453 constructor(Box::new(crate::expressions::BinaryOp {
454 left,
455 right,
456 left_comments: bin.left_comments,
457 operator_comments: bin.operator_comments,
458 trailing_comments: bin.trailing_comments,
459 }))
460}
461
462fn canonicalize_binary<F>(
464 constructor: F,
465 bin: crate::expressions::BinaryOp,
466 dialect: Option<DialectType>,
467) -> Expression
468where
469 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
470{
471 let left = canonicalize_recursive(bin.left, dialect);
472 let right = canonicalize_recursive(bin.right, dialect);
473
474 constructor(Box::new(crate::expressions::BinaryOp {
475 left,
476 right,
477 left_comments: bin.left_comments,
478 operator_comments: bin.operator_comments,
479 trailing_comments: bin.trailing_comments,
480 }))
481}
482
483fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
488 let left = coerce_date_string(left, &right);
490 let right = coerce_date_string(right, &left);
491 (left, right)
492}
493
494fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
496 if let Expression::Literal(Literal::String(ref s)) = expr {
497 if is_iso_date(s) {
499 } else if is_iso_datetime(s) {
502 }
505 }
506 expr
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL text with a default generator.
    ///
    /// Named `to_sql` rather than `gen` because `gen` is a reserved keyword
    /// starting with the Rust 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parses `sql`, canonicalizes it with no dialect, and renders the result.
    fn canonical_sql(sql: &str) -> String {
        to_sql(&canonicalize(parse(sql), None))
    }

    #[test]
    fn test_canonicalize_simple() {
        let sql = canonical_sql("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let sql = canonical_sql("SELECT a, b FROM t WHERE c = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let sql = canonical_sql("SELECT 1 WHERE a AND b OR c");
        // Both connectives must survive canonicalization, not just one.
        assert!(sql.contains("AND") && sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_union() {
        let sql = canonical_sql("SELECT a FROM t UNION SELECT b FROM s");
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        let sql = canonical_sql("SELECT 1 + 2");
        assert!(sql.contains("+"));
    }

    #[test]
    fn test_canonicalize_subtraction_passthrough() {
        let sql = canonical_sql("SELECT a - b FROM t");
        assert!(sql.contains("-"));
    }

    #[test]
    fn test_canonicalize_function() {
        let sql = canonical_sql("SELECT MAX(a) FROM t");
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let sql = canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10");
        assert!(sql.contains("BETWEEN"));
    }
}