1use crate::expressions::{BinaryOp, Expression};
12use crate::optimizer::simplify::Simplifier;
13use thiserror::Error;
14
/// Default value for the `max_distance` argument of [`normalize`].
///
/// The functions in this module do not apply it implicitly; callers pass it
/// explicitly when they have no more specific limit.
pub const DEFAULT_MAX_DISTANCE: i64 = 128;

/// Errors produced while rewriting a predicate into CNF/DNF.
#[derive(Debug, Error, Clone)]
pub enum NormalizeError {
    /// The normalization distance of an expression being rewritten exceeded
    /// the configured maximum, so distribution was abandoned.
    #[error("Normalization distance {distance} exceeds max {max}")]
    DistanceExceeded { distance: i64, max: i64 },
}

/// Result alias for fallible normalization operations.
pub type NormalizeResult<T> = Result<T, NormalizeError>;
27
28pub fn normalize(
41 expression: Expression,
42 dnf: bool,
43 max_distance: i64,
44) -> NormalizeResult<Expression> {
45 let simplifier = Simplifier::new(None);
46 normalize_with_simplifier(expression, dnf, max_distance, &simplifier)
47}
48
49fn normalize_with_simplifier(
51 expression: Expression,
52 dnf: bool,
53 max_distance: i64,
54 simplifier: &Simplifier,
55) -> NormalizeResult<Expression> {
56 let mut result = expression.clone();
57
58 let connectors = collect_connectors(&expression);
60
61 for node in connectors {
62 if normalized(&node, dnf) {
63 continue;
64 }
65
66 let distance = normalization_distance(&node, dnf, max_distance);
68
69 if distance > max_distance {
70 return Ok(expression);
72 }
73
74 let normalized_node = apply_distributive_law(&node, dnf, max_distance, simplifier)?;
76
77 if is_same_expression(&node, &expression) {
80 result = normalized_node;
81 }
82 }
83
84 Ok(result)
85}
86
87pub fn normalized(expression: &Expression, dnf: bool) -> bool {
104 if dnf {
105 !has_and_with_or_descendant(expression)
107 } else {
108 !has_or_with_and_descendant(expression)
110 }
111}
112
113fn has_or_with_and_descendant(expression: &Expression) -> bool {
115 match expression {
116 Expression::Or(bin) => {
117 contains_and(&bin.left)
119 || contains_and(&bin.right)
120 || has_or_with_and_descendant(&bin.left)
121 || has_or_with_and_descendant(&bin.right)
122 }
123 Expression::And(bin) => {
124 has_or_with_and_descendant(&bin.left) || has_or_with_and_descendant(&bin.right)
125 }
126 Expression::Paren(paren) => has_or_with_and_descendant(&paren.this),
127 _ => false,
128 }
129}
130
131fn has_and_with_or_descendant(expression: &Expression) -> bool {
133 match expression {
134 Expression::And(bin) => {
135 contains_or(&bin.left)
137 || contains_or(&bin.right)
138 || has_and_with_or_descendant(&bin.left)
139 || has_and_with_or_descendant(&bin.right)
140 }
141 Expression::Or(bin) => {
142 has_and_with_or_descendant(&bin.left) || has_and_with_or_descendant(&bin.right)
143 }
144 Expression::Paren(paren) => has_and_with_or_descendant(&paren.this),
145 _ => false,
146 }
147}
148
149fn contains_and(expression: &Expression) -> bool {
151 match expression {
152 Expression::And(_) => true,
153 Expression::Or(bin) => contains_and(&bin.left) || contains_and(&bin.right),
154 Expression::Paren(paren) => contains_and(&paren.this),
155 _ => false,
156 }
157}
158
159fn contains_or(expression: &Expression) -> bool {
161 match expression {
162 Expression::Or(_) => true,
163 Expression::And(bin) => contains_or(&bin.left) || contains_or(&bin.right),
164 Expression::Paren(paren) => contains_or(&paren.this),
165 _ => false,
166 }
167}
168
169pub fn normalization_distance(expression: &Expression, dnf: bool, max_distance: i64) -> i64 {
183 let connector_count = count_connectors(expression);
184 let mut total: i64 = -(connector_count as i64 + 1);
185
186 for length in predicate_lengths(expression, dnf, max_distance, 0) {
187 total += length;
188 if total > max_distance {
189 return total;
190 }
191 }
192
193 total
194}
195
196fn predicate_lengths(
206 expression: &Expression,
207 dnf: bool,
208 max_distance: i64,
209 depth: i64,
210) -> Vec<i64> {
211 if depth > max_distance {
212 return vec![depth];
213 }
214
215 let expr = unwrap_paren(expression);
216
217 match expr {
218 Expression::Or(bin) if !dnf => {
220 let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
222 let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
223
224 let mut result = Vec::new();
225 for a in &left_lengths {
226 for b in &right_lengths {
227 result.push(a + b);
228 }
229 }
230 result
231 }
232 Expression::And(bin) if dnf => {
234 let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
236 let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
237
238 let mut result = Vec::new();
239 for a in &left_lengths {
240 for b in &right_lengths {
241 result.push(a + b);
242 }
243 }
244 result
245 }
246 Expression::And(bin) | Expression::Or(bin) => {
248 let mut result = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
249 result.extend(predicate_lengths(&bin.right, dnf, max_distance, depth + 1));
250 result
251 }
252 _ => vec![1], }
254}
255
256fn apply_distributive_law(
261 expression: &Expression,
262 dnf: bool,
263 max_distance: i64,
264 simplifier: &Simplifier,
265) -> NormalizeResult<Expression> {
266 if normalized(expression, dnf) {
267 return Ok(expression.clone());
268 }
269
270 let distance = normalization_distance(expression, dnf, max_distance);
271 if distance > max_distance {
272 return Err(NormalizeError::DistanceExceeded {
273 distance,
274 max: max_distance,
275 });
276 }
277
278 let result = if dnf {
280 distribute_dnf(expression, simplifier)
281 } else {
282 distribute_cnf(expression, simplifier)
283 };
284
285 if !normalized(&result, dnf) {
287 apply_distributive_law(&result, dnf, max_distance, simplifier)
288 } else {
289 Ok(result)
290 }
291}
292
293fn distribute_cnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
296 match expression {
297 Expression::Or(bin) => {
298 let left = distribute_cnf(&bin.left, simplifier);
299 let right = distribute_cnf(&bin.right, simplifier);
300
301 if let Expression::And(and_bin) = &right {
303 let left_or_y = make_or(left.clone(), and_bin.left.clone());
305 let left_or_z = make_or(left, and_bin.right.clone());
306 return make_and(left_or_y, left_or_z);
307 }
308
309 if let Expression::And(and_bin) = &left {
310 let y_or_right = make_or(and_bin.left.clone(), right.clone());
312 let z_or_right = make_or(and_bin.right.clone(), right);
313 return make_and(y_or_right, z_or_right);
314 }
315
316 make_or(left, right)
318 }
319 Expression::And(bin) => {
320 let left = distribute_cnf(&bin.left, simplifier);
322 let right = distribute_cnf(&bin.right, simplifier);
323 make_and(left, right)
324 }
325 Expression::Paren(paren) => distribute_cnf(&paren.this, simplifier),
326 _ => expression.clone(),
327 }
328}
329
330fn distribute_dnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
333 match expression {
334 Expression::And(bin) => {
335 let left = distribute_dnf(&bin.left, simplifier);
336 let right = distribute_dnf(&bin.right, simplifier);
337
338 if let Expression::Or(or_bin) = &right {
340 let left_and_y = make_and(left.clone(), or_bin.left.clone());
342 let left_and_z = make_and(left, or_bin.right.clone());
343 return make_or(left_and_y, left_and_z);
344 }
345
346 if let Expression::Or(or_bin) = &left {
347 let y_and_right = make_and(or_bin.left.clone(), right.clone());
349 let z_and_right = make_and(or_bin.right.clone(), right);
350 return make_or(y_and_right, z_and_right);
351 }
352
353 make_and(left, right)
355 }
356 Expression::Or(bin) => {
357 let left = distribute_dnf(&bin.left, simplifier);
359 let right = distribute_dnf(&bin.right, simplifier);
360 make_or(left, right)
361 }
362 Expression::Paren(paren) => distribute_dnf(&paren.this, simplifier),
363 _ => expression.clone(),
364 }
365}
366
367fn collect_connectors(expression: &Expression) -> Vec<Expression> {
373 let mut result = Vec::new();
374 collect_connectors_recursive(expression, &mut result);
375 result
376}
377
378fn collect_connectors_recursive(expression: &Expression, result: &mut Vec<Expression>) {
379 match expression {
380 Expression::And(bin) => {
381 result.push(expression.clone());
382 collect_connectors_recursive(&bin.left, result);
383 collect_connectors_recursive(&bin.right, result);
384 }
385 Expression::Or(bin) => {
386 result.push(expression.clone());
387 collect_connectors_recursive(&bin.left, result);
388 collect_connectors_recursive(&bin.right, result);
389 }
390 Expression::Paren(paren) => {
391 collect_connectors_recursive(&paren.this, result);
392 }
393 _ => {}
394 }
395}
396
397fn count_connectors(expression: &Expression) -> usize {
399 match expression {
400 Expression::And(bin) | Expression::Or(bin) => {
401 1 + count_connectors(&bin.left) + count_connectors(&bin.right)
402 }
403 Expression::Paren(paren) => count_connectors(&paren.this),
404 _ => 0,
405 }
406}
407
408fn unwrap_paren(expression: &Expression) -> &Expression {
410 match expression {
411 Expression::Paren(paren) => unwrap_paren(&paren.this),
412 _ => expression,
413 }
414}
415
416fn is_same_expression(a: &Expression, b: &Expression) -> bool {
418 std::ptr::eq(a as *const _, b as *const _) || format!("{:?}", a) == format!("{:?}", b)
420}
421
422fn make_and(left: Expression, right: Expression) -> Expression {
424 Expression::And(Box::new(BinaryOp {
425 left,
426 right,
427 left_comments: vec![],
428 operator_comments: vec![],
429 trailing_comments: vec![],
430 }))
431}
432
433fn make_or(left: Expression, right: Expression) -> Expression {
435 Expression::Or(Box::new(BinaryOp {
436 left,
437 right,
438 left_comments: vec![],
439 operator_comments: vec![],
440 trailing_comments: vec![],
441 }))
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and returns a clone of its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Embeds `sql` in `SELECT 1 WHERE ...` and returns the WHERE predicate.
    fn parse_predicate(sql: &str) -> Expression {
        let statement = parse(&format!("SELECT 1 WHERE {}", sql));
        if let Expression::Select(select) = statement {
            if let Some(filter) = select.where_clause {
                return filter.this;
            }
        }
        panic!("Failed to extract predicate from: {}", sql);
    }

    #[test]
    fn test_normalized_cnf() {
        // A conjunction of disjunctions is already CNF.
        let expr = parse_predicate("(a OR b) AND (c OR d)");
        assert!(normalized(&expr, false));
    }

    #[test]
    fn test_normalized_dnf() {
        // A disjunction of conjunctions is already DNF.
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalized(&expr, true));
    }

    #[test]
    fn test_not_normalized_cnf() {
        // An AND nested under an OR violates CNF.
        let expr = parse_predicate("(a AND b) OR c");
        assert!(!normalized(&expr, false));
    }

    #[test]
    fn test_not_normalized_dnf() {
        // An OR nested under an AND violates DNF.
        let expr = parse_predicate("(a OR b) AND c");
        assert!(!normalized(&expr, true));
    }

    #[test]
    fn test_simple_literal_is_normalized() {
        // A single comparison is trivially in both forms.
        let expr = parse_predicate("a = 1");
        assert!(normalized(&expr, false));
        assert!(normalized(&expr, true));
    }

    #[test]
    fn test_normalization_distance_simple() {
        let expr = parse_predicate("a = 1");
        let distance = normalization_distance(&expr, false, 128);
        assert!(distance <= 0);
    }

    #[test]
    fn test_normalization_distance_complex() {
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        let distance = normalization_distance(&expr, false, 128);
        assert!(distance > 0);
    }

    #[test]
    fn test_normalize_to_cnf() {
        let expr = parse_predicate("(x AND y) OR z");
        let result = normalize(expr, false, 128).unwrap();
        assert!(normalized(&result, false));
    }

    #[test]
    fn test_normalize_to_dnf() {
        let expr = parse_predicate("(x OR y) AND z");
        let result = normalize(expr, true, 128).unwrap();
        assert!(normalized(&result, true));
    }

    #[test]
    fn test_count_connectors() {
        // `a AND b AND c` contains two AND nodes.
        let expr = parse_predicate("a AND b AND c");
        let count = count_connectors(&expr);
        assert_eq!(count, 2);
    }

    #[test]
    fn test_predicate_lengths() {
        let expr = parse_predicate("a = 1");
        let lengths = predicate_lengths(&expr, false, 128, 0);
        assert_eq!(lengths, vec![1]);
    }
}