1use crate::expressions::{BinaryOp, Expression};
12use crate::optimizer::simplify::Simplifier;
13use thiserror::Error;
14
15pub const DEFAULT_MAX_DISTANCE: i64 = 128;
17
18#[derive(Debug, Error, Clone)]
20pub enum NormalizeError {
21 #[error("Normalization distance {distance} exceeds max {max}")]
22 DistanceExceeded { distance: i64, max: i64 },
23}
24
25pub type NormalizeResult<T> = Result<T, NormalizeError>;
27
28pub fn normalize(
41 expression: Expression,
42 dnf: bool,
43 max_distance: i64,
44) -> NormalizeResult<Expression> {
45 let simplifier = Simplifier::new(None);
46 normalize_with_simplifier(expression, dnf, max_distance, &simplifier)
47}
48
49fn normalize_with_simplifier(
51 expression: Expression,
52 dnf: bool,
53 max_distance: i64,
54 simplifier: &Simplifier,
55) -> NormalizeResult<Expression> {
56 let mut result = expression.clone();
57
58 let connectors = collect_connectors(&expression);
60
61 for node in connectors {
62 if normalized(&node, dnf) {
63 continue;
64 }
65
66 let distance = normalization_distance(&node, dnf, max_distance);
68
69 if distance > max_distance {
70 return Ok(expression);
72 }
73
74 let normalized_node = apply_distributive_law(&node, dnf, max_distance, simplifier)?;
76
77 if is_same_expression(&node, &expression) {
80 result = normalized_node;
81 }
82 }
83
84 Ok(result)
85}
86
87pub fn normalized(expression: &Expression, dnf: bool) -> bool {
104 if dnf {
105 !has_and_with_or_descendant(expression)
107 } else {
108 !has_or_with_and_descendant(expression)
110 }
111}
112
113fn has_or_with_and_descendant(expression: &Expression) -> bool {
115 match expression {
116 Expression::Or(bin) => {
117 contains_and(&bin.left)
119 || contains_and(&bin.right)
120 || has_or_with_and_descendant(&bin.left)
121 || has_or_with_and_descendant(&bin.right)
122 }
123 Expression::And(bin) => {
124 has_or_with_and_descendant(&bin.left) || has_or_with_and_descendant(&bin.right)
125 }
126 Expression::Paren(paren) => has_or_with_and_descendant(&paren.this),
127 _ => false,
128 }
129}
130
131fn has_and_with_or_descendant(expression: &Expression) -> bool {
133 match expression {
134 Expression::And(bin) => {
135 contains_or(&bin.left)
137 || contains_or(&bin.right)
138 || has_and_with_or_descendant(&bin.left)
139 || has_and_with_or_descendant(&bin.right)
140 }
141 Expression::Or(bin) => {
142 has_and_with_or_descendant(&bin.left) || has_and_with_or_descendant(&bin.right)
143 }
144 Expression::Paren(paren) => has_and_with_or_descendant(&paren.this),
145 _ => false,
146 }
147}
148
149fn contains_and(expression: &Expression) -> bool {
151 match expression {
152 Expression::And(_) => true,
153 Expression::Or(bin) => contains_and(&bin.left) || contains_and(&bin.right),
154 Expression::Paren(paren) => contains_and(&paren.this),
155 _ => false,
156 }
157}
158
159fn contains_or(expression: &Expression) -> bool {
161 match expression {
162 Expression::Or(_) => true,
163 Expression::And(bin) => contains_or(&bin.left) || contains_or(&bin.right),
164 Expression::Paren(paren) => contains_or(&paren.this),
165 _ => false,
166 }
167}
168
169pub fn normalization_distance(expression: &Expression, dnf: bool, max_distance: i64) -> i64 {
183 let connector_count = count_connectors(expression);
184 let mut total: i64 = -(connector_count as i64 + 1);
185
186 for length in predicate_lengths(expression, dnf, max_distance, 0) {
187 total += length;
188 if total > max_distance {
189 return total;
190 }
191 }
192
193 total
194}
195
196fn predicate_lengths(expression: &Expression, dnf: bool, max_distance: i64, depth: i64) -> Vec<i64> {
206 if depth > max_distance {
207 return vec![depth];
208 }
209
210 let expr = unwrap_paren(expression);
211
212 match expr {
213 Expression::Or(bin) if !dnf => {
215 let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
217 let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
218
219 let mut result = Vec::new();
220 for a in &left_lengths {
221 for b in &right_lengths {
222 result.push(a + b);
223 }
224 }
225 result
226 }
227 Expression::And(bin) if dnf => {
229 let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
231 let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
232
233 let mut result = Vec::new();
234 for a in &left_lengths {
235 for b in &right_lengths {
236 result.push(a + b);
237 }
238 }
239 result
240 }
241 Expression::And(bin) | Expression::Or(bin) => {
243 let mut result = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
244 result.extend(predicate_lengths(&bin.right, dnf, max_distance, depth + 1));
245 result
246 }
247 _ => vec![1], }
249}
250
251fn apply_distributive_law(
256 expression: &Expression,
257 dnf: bool,
258 max_distance: i64,
259 simplifier: &Simplifier,
260) -> NormalizeResult<Expression> {
261 if normalized(expression, dnf) {
262 return Ok(expression.clone());
263 }
264
265 let distance = normalization_distance(expression, dnf, max_distance);
266 if distance > max_distance {
267 return Err(NormalizeError::DistanceExceeded {
268 distance,
269 max: max_distance,
270 });
271 }
272
273 let result = if dnf {
275 distribute_dnf(expression, simplifier)
276 } else {
277 distribute_cnf(expression, simplifier)
278 };
279
280 if !normalized(&result, dnf) {
282 apply_distributive_law(&result, dnf, max_distance, simplifier)
283 } else {
284 Ok(result)
285 }
286}
287
288fn distribute_cnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
291 match expression {
292 Expression::Or(bin) => {
293 let left = distribute_cnf(&bin.left, simplifier);
294 let right = distribute_cnf(&bin.right, simplifier);
295
296 if let Expression::And(and_bin) = &right {
298 let left_or_y = make_or(left.clone(), and_bin.left.clone());
300 let left_or_z = make_or(left, and_bin.right.clone());
301 return make_and(left_or_y, left_or_z);
302 }
303
304 if let Expression::And(and_bin) = &left {
305 let y_or_right = make_or(and_bin.left.clone(), right.clone());
307 let z_or_right = make_or(and_bin.right.clone(), right);
308 return make_and(y_or_right, z_or_right);
309 }
310
311 make_or(left, right)
313 }
314 Expression::And(bin) => {
315 let left = distribute_cnf(&bin.left, simplifier);
317 let right = distribute_cnf(&bin.right, simplifier);
318 make_and(left, right)
319 }
320 Expression::Paren(paren) => distribute_cnf(&paren.this, simplifier),
321 _ => expression.clone(),
322 }
323}
324
325fn distribute_dnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
328 match expression {
329 Expression::And(bin) => {
330 let left = distribute_dnf(&bin.left, simplifier);
331 let right = distribute_dnf(&bin.right, simplifier);
332
333 if let Expression::Or(or_bin) = &right {
335 let left_and_y = make_and(left.clone(), or_bin.left.clone());
337 let left_and_z = make_and(left, or_bin.right.clone());
338 return make_or(left_and_y, left_and_z);
339 }
340
341 if let Expression::Or(or_bin) = &left {
342 let y_and_right = make_and(or_bin.left.clone(), right.clone());
344 let z_and_right = make_and(or_bin.right.clone(), right);
345 return make_or(y_and_right, z_and_right);
346 }
347
348 make_and(left, right)
350 }
351 Expression::Or(bin) => {
352 let left = distribute_dnf(&bin.left, simplifier);
354 let right = distribute_dnf(&bin.right, simplifier);
355 make_or(left, right)
356 }
357 Expression::Paren(paren) => distribute_dnf(&paren.this, simplifier),
358 _ => expression.clone(),
359 }
360}
361
362fn collect_connectors(expression: &Expression) -> Vec<Expression> {
368 let mut result = Vec::new();
369 collect_connectors_recursive(expression, &mut result);
370 result
371}
372
373fn collect_connectors_recursive(expression: &Expression, result: &mut Vec<Expression>) {
374 match expression {
375 Expression::And(bin) => {
376 result.push(expression.clone());
377 collect_connectors_recursive(&bin.left, result);
378 collect_connectors_recursive(&bin.right, result);
379 }
380 Expression::Or(bin) => {
381 result.push(expression.clone());
382 collect_connectors_recursive(&bin.left, result);
383 collect_connectors_recursive(&bin.right, result);
384 }
385 Expression::Paren(paren) => {
386 collect_connectors_recursive(&paren.this, result);
387 }
388 _ => {}
389 }
390}
391
392fn count_connectors(expression: &Expression) -> usize {
394 match expression {
395 Expression::And(bin) | Expression::Or(bin) => {
396 1 + count_connectors(&bin.left) + count_connectors(&bin.right)
397 }
398 Expression::Paren(paren) => count_connectors(&paren.this),
399 _ => 0,
400 }
401}
402
403fn unwrap_paren(expression: &Expression) -> &Expression {
405 match expression {
406 Expression::Paren(paren) => unwrap_paren(&paren.this),
407 _ => expression,
408 }
409}
410
411fn is_same_expression(a: &Expression, b: &Expression) -> bool {
413 std::ptr::eq(a as *const _, b as *const _) || format!("{:?}", a) == format!("{:?}", b)
415}
416
417fn make_and(left: Expression, right: Expression) -> Expression {
419 Expression::And(Box::new(BinaryOp {
420 left,
421 right,
422 left_comments: vec![],
423 operator_comments: vec![],
424 trailing_comments: vec![],
425 }))
426}
427
428fn make_or(left: Expression, right: Expression) -> Expression {
430 Expression::Or(Box::new(BinaryOp {
431 left,
432 right,
433 left_comments: vec![],
434 operator_comments: vec![],
435 trailing_comments: vec![],
436 }))
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL text.
    ///
    /// Named `generate_sql` rather than `gen`, which is a reserved keyword from
    /// the Rust 2024 edition onward.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses one SQL statement and returns it.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parses `sql` as the WHERE predicate of a trivial SELECT and returns the
    /// predicate expression.
    fn parse_predicate(sql: &str) -> Expression {
        let full = format!("SELECT 1 WHERE {}", sql);
        let stmt = parse(&full);
        if let Expression::Select(select) = stmt {
            if let Some(where_clause) = select.where_clause {
                return where_clause.this;
            }
        }
        panic!("Failed to extract predicate from: {}", sql);
    }

    /// Structural fingerprint used to compare trees without requiring `PartialEq`.
    fn debug_repr(expr: &Expression) -> String {
        format!("{:?}", expr)
    }

    #[test]
    fn test_normalized_cnf() {
        let expr = parse_predicate("(a OR b) AND (c OR d)");
        assert!(normalized(&expr, false));
    }

    #[test]
    fn test_normalized_dnf() {
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalized(&expr, true));
    }

    #[test]
    fn test_not_normalized_cnf() {
        let expr = parse_predicate("(a AND b) OR c");
        assert!(!normalized(&expr, false));
    }

    #[test]
    fn test_not_normalized_dnf() {
        let expr = parse_predicate("(a OR b) AND c");
        assert!(!normalized(&expr, true));
    }

    #[test]
    fn test_simple_literal_is_normalized() {
        let expr = parse_predicate("a = 1");
        assert!(normalized(&expr, false));
        assert!(normalized(&expr, true));
    }

    #[test]
    fn test_normalization_distance_simple() {
        let expr = parse_predicate("a = 1");
        let distance = normalization_distance(&expr, false, 128);
        assert!(distance <= 0);
    }

    #[test]
    fn test_normalization_distance_complex() {
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        let distance = normalization_distance(&expr, false, 128);
        assert!(distance > 0);
    }

    #[test]
    fn test_normalize_to_cnf() {
        let expr = parse_predicate("(x AND y) OR z");
        let result = normalize(expr, false, 128).unwrap();

        assert!(normalized(&result, false));
        // The rewritten tree must still be renderable as SQL.
        assert!(!generate_sql(&result).is_empty());
    }

    #[test]
    fn test_normalize_to_dnf() {
        let expr = parse_predicate("(x OR y) AND z");
        let result = normalize(expr, true, 128).unwrap();

        assert!(normalized(&result, true));
        assert!(!generate_sql(&result).is_empty());
    }

    #[test]
    fn test_normalize_leaves_normalized_input_unchanged() {
        let expr = parse_predicate("(a OR b) AND c");
        let result = normalize(expr.clone(), false, 128).unwrap();
        assert_eq!(debug_repr(&result), debug_repr(&expr));
    }

    #[test]
    fn test_normalize_returns_input_when_distance_exceeds_max() {
        // With a limit of 1 the CNF expansion of this predicate is over budget,
        // so the predicate must come back untouched instead of being expanded.
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalization_distance(&expr, false, 1) > 1);

        let result = normalize(expr.clone(), false, 1).unwrap();
        assert_eq!(debug_repr(&result), debug_repr(&expr));
    }

    #[test]
    fn test_count_connectors() {
        let expr = parse_predicate("a AND b AND c");
        let count = count_connectors(&expr);
        assert_eq!(count, 2);
    }

    #[test]
    fn test_predicate_lengths() {
        let expr = parse_predicate("a = 1");
        let lengths = predicate_lengths(&expr, false, 128, 0);
        assert_eq!(lengths, vec![1]);
    }
}