1use crate::expressions::{BinaryOp, Expression};
12use crate::optimizer::simplify::Simplifier;
13use thiserror::Error;
14
/// Default budget for the estimated normalization distance accepted by [`normalize`].
pub const DEFAULT_MAX_DISTANCE: i64 = 128;
17
18#[derive(Debug, Error, Clone)]
20pub enum NormalizeError {
21 #[error("Normalization distance {distance} exceeds max {max}")]
22 DistanceExceeded { distance: i64, max: i64 },
23}
24
25pub type NormalizeResult<T> = Result<T, NormalizeError>;
27
28pub fn normalize(
41 expression: Expression,
42 dnf: bool,
43 max_distance: i64,
44) -> NormalizeResult<Expression> {
45 let simplifier = Simplifier::new(None);
46 normalize_with_simplifier(expression, dnf, max_distance, &simplifier)
47}
48
49fn normalize_with_simplifier(
51 expression: Expression,
52 dnf: bool,
53 max_distance: i64,
54 simplifier: &Simplifier,
55) -> NormalizeResult<Expression> {
56 let mut result = expression.clone();
57
58 let connectors = collect_connectors(&expression);
60
61 for node in connectors {
62 if normalized(&node, dnf) {
63 continue;
64 }
65
66 let distance = normalization_distance(&node, dnf, max_distance);
68
69 if distance > max_distance {
70 return Ok(expression);
72 }
73
74 let normalized_node = apply_distributive_law(&node, dnf, max_distance, simplifier)?;
76
77 if is_same_expression(&node, &expression) {
80 result = normalized_node;
81 }
82 }
83
84 Ok(result)
85}
86
87pub fn normalized(expression: &Expression, dnf: bool) -> bool {
104 if dnf {
105 !has_and_with_or_descendant(expression)
107 } else {
108 !has_or_with_and_descendant(expression)
110 }
111}
112
113fn has_or_with_and_descendant(expression: &Expression) -> bool {
115 match expression {
116 Expression::Or(bin) => {
117 contains_and(&bin.left)
119 || contains_and(&bin.right)
120 || has_or_with_and_descendant(&bin.left)
121 || has_or_with_and_descendant(&bin.right)
122 }
123 Expression::And(bin) => {
124 has_or_with_and_descendant(&bin.left) || has_or_with_and_descendant(&bin.right)
125 }
126 Expression::Paren(paren) => has_or_with_and_descendant(&paren.this),
127 _ => false,
128 }
129}
130
131fn has_and_with_or_descendant(expression: &Expression) -> bool {
133 match expression {
134 Expression::And(bin) => {
135 contains_or(&bin.left)
137 || contains_or(&bin.right)
138 || has_and_with_or_descendant(&bin.left)
139 || has_and_with_or_descendant(&bin.right)
140 }
141 Expression::Or(bin) => {
142 has_and_with_or_descendant(&bin.left) || has_and_with_or_descendant(&bin.right)
143 }
144 Expression::Paren(paren) => has_and_with_or_descendant(&paren.this),
145 _ => false,
146 }
147}
148
149fn contains_and(expression: &Expression) -> bool {
151 match expression {
152 Expression::And(_) => true,
153 Expression::Or(bin) => contains_and(&bin.left) || contains_and(&bin.right),
154 Expression::Paren(paren) => contains_and(&paren.this),
155 _ => false,
156 }
157}
158
159fn contains_or(expression: &Expression) -> bool {
161 match expression {
162 Expression::Or(_) => true,
163 Expression::And(bin) => contains_or(&bin.left) || contains_or(&bin.right),
164 Expression::Paren(paren) => contains_or(&paren.this),
165 _ => false,
166 }
167}
168
169pub fn normalization_distance(expression: &Expression, dnf: bool, max_distance: i64) -> i64 {
183 let connector_count = count_connectors(expression);
184 let mut total: i64 = -(connector_count as i64 + 1);
185
186 for length in predicate_lengths(expression, dnf, max_distance, 0) {
187 total += length;
188 if total > max_distance {
189 return total;
190 }
191 }
192
193 total
194}
195
196fn predicate_lengths(expression: &Expression, dnf: bool, max_distance: i64, depth: i64) -> Vec<i64> {
206 if depth > max_distance {
207 return vec![depth];
208 }
209
210 let expr = unwrap_paren(expression);
211
212 match expr {
213 Expression::Or(bin) if !dnf => {
215 let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
217 let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
218
219 let mut result = Vec::new();
220 for a in &left_lengths {
221 for b in &right_lengths {
222 result.push(a + b);
223 }
224 }
225 result
226 }
227 Expression::And(bin) if dnf => {
229 let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
231 let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
232
233 let mut result = Vec::new();
234 for a in &left_lengths {
235 for b in &right_lengths {
236 result.push(a + b);
237 }
238 }
239 result
240 }
241 Expression::And(bin) | Expression::Or(bin) => {
243 let mut result = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
244 result.extend(predicate_lengths(&bin.right, dnf, max_distance, depth + 1));
245 result
246 }
247 _ => vec![1], }
249}
250
251fn apply_distributive_law(
256 expression: &Expression,
257 dnf: bool,
258 max_distance: i64,
259 simplifier: &Simplifier,
260) -> NormalizeResult<Expression> {
261 if normalized(expression, dnf) {
262 return Ok(expression.clone());
263 }
264
265 let distance = normalization_distance(expression, dnf, max_distance);
266 if distance > max_distance {
267 return Err(NormalizeError::DistanceExceeded {
268 distance,
269 max: max_distance,
270 });
271 }
272
273 let result = if dnf {
275 distribute_dnf(expression, simplifier)
276 } else {
277 distribute_cnf(expression, simplifier)
278 };
279
280 if !normalized(&result, dnf) {
282 apply_distributive_law(&result, dnf, max_distance, simplifier)
283 } else {
284 Ok(result)
285 }
286}
287
288fn distribute_cnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
291 match expression {
292 Expression::Or(bin) => {
293 let left = distribute_cnf(&bin.left, simplifier);
294 let right = distribute_cnf(&bin.right, simplifier);
295
296 if let Expression::And(and_bin) = &right {
298 let left_or_y = make_or(left.clone(), and_bin.left.clone());
300 let left_or_z = make_or(left, and_bin.right.clone());
301 return make_and(left_or_y, left_or_z);
302 }
303
304 if let Expression::And(and_bin) = &left {
305 let y_or_right = make_or(and_bin.left.clone(), right.clone());
307 let z_or_right = make_or(and_bin.right.clone(), right);
308 return make_and(y_or_right, z_or_right);
309 }
310
311 make_or(left, right)
313 }
314 Expression::And(bin) => {
315 let left = distribute_cnf(&bin.left, simplifier);
317 let right = distribute_cnf(&bin.right, simplifier);
318 make_and(left, right)
319 }
320 Expression::Paren(paren) => distribute_cnf(&paren.this, simplifier),
321 _ => expression.clone(),
322 }
323}
324
325fn distribute_dnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
328 match expression {
329 Expression::And(bin) => {
330 let left = distribute_dnf(&bin.left, simplifier);
331 let right = distribute_dnf(&bin.right, simplifier);
332
333 if let Expression::Or(or_bin) = &right {
335 let left_and_y = make_and(left.clone(), or_bin.left.clone());
337 let left_and_z = make_and(left, or_bin.right.clone());
338 return make_or(left_and_y, left_and_z);
339 }
340
341 if let Expression::Or(or_bin) = &left {
342 let y_and_right = make_and(or_bin.left.clone(), right.clone());
344 let z_and_right = make_and(or_bin.right.clone(), right);
345 return make_or(y_and_right, z_and_right);
346 }
347
348 make_and(left, right)
350 }
351 Expression::Or(bin) => {
352 let left = distribute_dnf(&bin.left, simplifier);
354 let right = distribute_dnf(&bin.right, simplifier);
355 make_or(left, right)
356 }
357 Expression::Paren(paren) => distribute_dnf(&paren.this, simplifier),
358 _ => expression.clone(),
359 }
360}
361
362fn collect_connectors(expression: &Expression) -> Vec<Expression> {
368 let mut result = Vec::new();
369 collect_connectors_recursive(expression, &mut result);
370 result
371}
372
373fn collect_connectors_recursive(expression: &Expression, result: &mut Vec<Expression>) {
374 match expression {
375 Expression::And(bin) => {
376 result.push(expression.clone());
377 collect_connectors_recursive(&bin.left, result);
378 collect_connectors_recursive(&bin.right, result);
379 }
380 Expression::Or(bin) => {
381 result.push(expression.clone());
382 collect_connectors_recursive(&bin.left, result);
383 collect_connectors_recursive(&bin.right, result);
384 }
385 Expression::Paren(paren) => {
386 collect_connectors_recursive(&paren.this, result);
387 }
388 _ => {}
389 }
390}
391
392fn count_connectors(expression: &Expression) -> usize {
394 match expression {
395 Expression::And(bin) | Expression::Or(bin) => {
396 1 + count_connectors(&bin.left) + count_connectors(&bin.right)
397 }
398 Expression::Paren(paren) => count_connectors(&paren.this),
399 _ => 0,
400 }
401}
402
403fn unwrap_paren(expression: &Expression) -> &Expression {
405 match expression {
406 Expression::Paren(paren) => unwrap_paren(&paren.this),
407 _ => expression,
408 }
409}
410
411fn is_same_expression(a: &Expression, b: &Expression) -> bool {
413 std::ptr::eq(a as *const _, b as *const _) || format!("{:?}", a) == format!("{:?}", b)
415}
416
417fn make_and(left: Expression, right: Expression) -> Expression {
419 Expression::And(Box::new(BinaryOp {
420 left,
421 right,
422 left_comments: vec![],
423 operator_comments: vec![],
424 trailing_comments: vec![],
425 }))
426}
427
428fn make_or(left: Expression, right: Expression) -> Expression {
430 Expression::Or(Box::new(BinaryOp {
431 left,
432 right,
433 left_comments: vec![],
434 operator_comments: vec![],
435 trailing_comments: vec![],
436 }))
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Wraps `sql` in `SELECT 1 WHERE ...` and returns the parsed `WHERE` predicate.
    fn parse_predicate(sql: &str) -> Expression {
        let statement = parse(&format!("SELECT 1 WHERE {}", sql));
        match statement {
            Expression::Select(select) => match select.where_clause {
                Some(filter) => filter.this,
                None => panic!("Failed to extract predicate from: {}", sql),
            },
            _ => panic!("Failed to extract predicate from: {}", sql),
        }
    }

    #[test]
    fn test_normalized_cnf() {
        let predicate = parse_predicate("(a OR b) AND (c OR d)");
        assert!(normalized(&predicate, false));
    }

    #[test]
    fn test_normalized_dnf() {
        let predicate = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalized(&predicate, true));
    }

    #[test]
    fn test_not_normalized_cnf() {
        let predicate = parse_predicate("(a AND b) OR c");
        assert!(!normalized(&predicate, false));
    }

    #[test]
    fn test_not_normalized_dnf() {
        let predicate = parse_predicate("(a OR b) AND c");
        assert!(!normalized(&predicate, true));
    }

    #[test]
    fn test_simple_literal_is_normalized() {
        let predicate = parse_predicate("a = 1");
        assert!(normalized(&predicate, false));
        assert!(normalized(&predicate, true));
    }

    #[test]
    fn test_normalization_distance_simple() {
        let predicate = parse_predicate("a = 1");
        assert!(normalization_distance(&predicate, false, DEFAULT_MAX_DISTANCE) <= 0);
    }

    #[test]
    fn test_normalization_distance_complex() {
        let predicate = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalization_distance(&predicate, false, DEFAULT_MAX_DISTANCE) > 0);
    }

    #[test]
    fn test_normalize_to_cnf() {
        let predicate = parse_predicate("(x AND y) OR z");
        let result = normalize(predicate, false, DEFAULT_MAX_DISTANCE).unwrap();
        assert!(normalized(&result, false));
    }

    #[test]
    fn test_normalize_to_dnf() {
        let predicate = parse_predicate("(x OR y) AND z");
        let result = normalize(predicate, true, DEFAULT_MAX_DISTANCE).unwrap();
        assert!(normalized(&result, true));
    }

    #[test]
    fn test_count_connectors() {
        let predicate = parse_predicate("a AND b AND c");
        assert_eq!(count_connectors(&predicate), 2);
    }

    #[test]
    fn test_predicate_lengths() {
        let predicate = parse_predicate("a = 1");
        assert_eq!(predicate_lengths(&predicate, false, DEFAULT_MAX_DISTANCE, 0), vec![1]);
    }
}