1use crate::expressions::{BinaryOp, Expression};
12use crate::optimizer::simplify::Simplifier;
13use thiserror::Error;
14
/// Default budget intended for the `max_distance` argument of [`normalize`].
///
/// NOTE(review): `normalize` itself does not read this constant; presumably callers pass it
/// explicitly — confirm.
pub const DEFAULT_MAX_DISTANCE: i64 = 128;

/// Errors reported by the normalization routines in this module.
#[derive(Debug, Error, Clone)]
pub enum NormalizeError {
    /// The estimated normalization distance (see [`normalization_distance`]) exceeded the
    /// allowed maximum while distribution was in progress.
    #[error("Normalization distance {distance} exceeds max {max}")]
    DistanceExceeded { distance: i64, max: i64 },
}

/// Result alias used by [`normalize`] and its helpers.
pub type NormalizeResult<T> = Result<T, NormalizeError>;
27
28pub fn normalize(
41 expression: Expression,
42 dnf: bool,
43 max_distance: i64,
44) -> NormalizeResult<Expression> {
45 let simplifier = Simplifier::new(None);
46 normalize_with_simplifier(expression, dnf, max_distance, &simplifier)
47}
48
49fn normalize_with_simplifier(
51 expression: Expression,
52 dnf: bool,
53 max_distance: i64,
54 simplifier: &Simplifier,
55) -> NormalizeResult<Expression> {
56 if normalized(&expression, dnf) {
57 return Ok(expression);
58 }
59
60 let distance = normalization_distance(&expression, dnf, max_distance);
62 if distance > max_distance {
63 return Ok(expression);
64 }
65
66 apply_distributive_law(&expression, dnf, max_distance, simplifier)
67}
68
69pub fn normalized(expression: &Expression, dnf: bool) -> bool {
86 if dnf {
87 !has_and_with_or_descendant(expression)
89 } else {
90 !has_or_with_and_descendant(expression)
92 }
93}
94
95fn has_or_with_and_descendant(expression: &Expression) -> bool {
97 match expression {
98 Expression::Or(bin) => {
99 contains_and(&bin.left)
101 || contains_and(&bin.right)
102 || has_or_with_and_descendant(&bin.left)
103 || has_or_with_and_descendant(&bin.right)
104 }
105 Expression::And(bin) => {
106 has_or_with_and_descendant(&bin.left) || has_or_with_and_descendant(&bin.right)
107 }
108 Expression::Paren(paren) => has_or_with_and_descendant(&paren.this),
109 _ => false,
110 }
111}
112
113fn has_and_with_or_descendant(expression: &Expression) -> bool {
115 match expression {
116 Expression::And(bin) => {
117 contains_or(&bin.left)
119 || contains_or(&bin.right)
120 || has_and_with_or_descendant(&bin.left)
121 || has_and_with_or_descendant(&bin.right)
122 }
123 Expression::Or(bin) => {
124 has_and_with_or_descendant(&bin.left) || has_and_with_or_descendant(&bin.right)
125 }
126 Expression::Paren(paren) => has_and_with_or_descendant(&paren.this),
127 _ => false,
128 }
129}
130
131fn contains_and(expression: &Expression) -> bool {
133 match expression {
134 Expression::And(_) => true,
135 Expression::Or(bin) => contains_and(&bin.left) || contains_and(&bin.right),
136 Expression::Paren(paren) => contains_and(&paren.this),
137 _ => false,
138 }
139}
140
141fn contains_or(expression: &Expression) -> bool {
143 match expression {
144 Expression::Or(_) => true,
145 Expression::And(bin) => contains_or(&bin.left) || contains_or(&bin.right),
146 Expression::Paren(paren) => contains_or(&paren.this),
147 _ => false,
148 }
149}
150
151pub fn normalization_distance(expression: &Expression, dnf: bool, max_distance: i64) -> i64 {
165 let connector_count = count_connectors(expression);
166 let mut total: i64 = -(connector_count as i64 + 1);
167
168 for length in predicate_lengths(expression, dnf, max_distance, 0) {
169 total += length;
170 if total > max_distance {
171 return total;
172 }
173 }
174
175 total
176}
177
178fn predicate_lengths(
188 expression: &Expression,
189 dnf: bool,
190 max_distance: i64,
191 depth: i64,
192) -> Vec<i64> {
193 if depth > max_distance {
194 return vec![depth];
195 }
196
197 let expr = unwrap_paren(expression);
198
199 match expr {
200 Expression::Or(bin) if !dnf => {
202 let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
204 let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
205
206 let mut result = Vec::new();
207 for a in &left_lengths {
208 for b in &right_lengths {
209 result.push(a + b);
210 }
211 }
212 result
213 }
214 Expression::And(bin) if dnf => {
216 let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
218 let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
219
220 let mut result = Vec::new();
221 for a in &left_lengths {
222 for b in &right_lengths {
223 result.push(a + b);
224 }
225 }
226 result
227 }
228 Expression::And(bin) | Expression::Or(bin) => {
230 let mut result = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
231 result.extend(predicate_lengths(&bin.right, dnf, max_distance, depth + 1));
232 result
233 }
234 _ => vec![1], }
236}
237
238fn apply_distributive_law(
243 expression: &Expression,
244 dnf: bool,
245 max_distance: i64,
246 simplifier: &Simplifier,
247) -> NormalizeResult<Expression> {
248 if normalized(expression, dnf) {
249 return Ok(expression.clone());
250 }
251
252 let distance = normalization_distance(expression, dnf, max_distance);
253 if distance > max_distance {
254 return Err(NormalizeError::DistanceExceeded {
255 distance,
256 max: max_distance,
257 });
258 }
259
260 let result = if dnf {
262 distribute_dnf(expression, simplifier)
263 } else {
264 distribute_cnf(expression, simplifier)
265 };
266
267 if !normalized(&result, dnf) {
269 apply_distributive_law(&result, dnf, max_distance, simplifier)
270 } else {
271 Ok(result)
272 }
273}
274
275fn distribute_cnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
278 match expression {
279 Expression::Or(bin) => {
280 let left = distribute_cnf(&bin.left, simplifier);
281 let right = distribute_cnf(&bin.right, simplifier);
282
283 if let Expression::And(and_bin) = &right {
285 let left_or_y = make_or(left.clone(), and_bin.left.clone());
287 let left_or_z = make_or(left, and_bin.right.clone());
288 return make_and(left_or_y, left_or_z);
289 }
290
291 if let Expression::And(and_bin) = &left {
292 let y_or_right = make_or(and_bin.left.clone(), right.clone());
294 let z_or_right = make_or(and_bin.right.clone(), right);
295 return make_and(y_or_right, z_or_right);
296 }
297
298 make_or(left, right)
300 }
301 Expression::And(bin) => {
302 let left = distribute_cnf(&bin.left, simplifier);
304 let right = distribute_cnf(&bin.right, simplifier);
305 make_and(left, right)
306 }
307 Expression::Paren(paren) => distribute_cnf(&paren.this, simplifier),
308 _ => expression.clone(),
309 }
310}
311
312fn distribute_dnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
315 match expression {
316 Expression::And(bin) => {
317 let left = distribute_dnf(&bin.left, simplifier);
318 let right = distribute_dnf(&bin.right, simplifier);
319
320 if let Expression::Or(or_bin) = &right {
322 let left_and_y = make_and(left.clone(), or_bin.left.clone());
324 let left_and_z = make_and(left, or_bin.right.clone());
325 return make_or(left_and_y, left_and_z);
326 }
327
328 if let Expression::Or(or_bin) = &left {
329 let y_and_right = make_and(or_bin.left.clone(), right.clone());
331 let z_and_right = make_and(or_bin.right.clone(), right);
332 return make_or(y_and_right, z_and_right);
333 }
334
335 make_and(left, right)
337 }
338 Expression::Or(bin) => {
339 let left = distribute_dnf(&bin.left, simplifier);
341 let right = distribute_dnf(&bin.right, simplifier);
342 make_or(left, right)
343 }
344 Expression::Paren(paren) => distribute_dnf(&paren.this, simplifier),
345 _ => expression.clone(),
346 }
347}
348
349fn count_connectors(expression: &Expression) -> usize {
355 match expression {
356 Expression::And(bin) | Expression::Or(bin) => {
357 1 + count_connectors(&bin.left) + count_connectors(&bin.right)
358 }
359 Expression::Paren(paren) => count_connectors(&paren.this),
360 _ => 0,
361 }
362}
363
364fn unwrap_paren(expression: &Expression) -> &Expression {
366 match expression {
367 Expression::Paren(paren) => unwrap_paren(&paren.this),
368 _ => expression,
369 }
370}
371
372fn make_and(left: Expression, right: Expression) -> Expression {
374 Expression::And(Box::new(BinaryOp {
375 left,
376 right,
377 left_comments: vec![],
378 operator_comments: vec![],
379 trailing_comments: vec![],
380 inferred_type: None,
381 }))
382}
383
384fn make_or(left: Expression, right: Expression) -> Expression {
386 Expression::Or(Box::new(BinaryOp {
387 left,
388 right,
389 left_comments: vec![],
390 operator_comments: vec![],
391 trailing_comments: vec![],
392 inferred_type: None,
393 }))
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    fn parse_predicate(sql: &str) -> Expression {
        let full = format!("SELECT 1 WHERE {}", sql);
        let stmt = parse(&full);
        if let Expression::Select(select) = stmt {
            if let Some(where_clause) = select.where_clause {
                return where_clause.this;
            }
        }
        panic!("Failed to extract predicate from: {}", sql);
    }

    /// Builds `(a0 AND b0) OR (a1 AND b1) OR ...` with `terms` disjuncts.
    fn wide_or_of_ands(terms: usize) -> String {
        (0..terms)
            .map(|i| format!("(a{} AND b{})", i, i))
            .collect::<Vec<_>>()
            .join(" OR ")
    }

    #[test]
    fn test_normalized_cnf() {
        let expr = parse_predicate("(a OR b) AND (c OR d)");
        assert!(normalized(&expr, false));
    }

    #[test]
    fn test_normalized_dnf() {
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalized(&expr, true));
    }

    #[test]
    fn test_not_normalized_cnf() {
        let expr = parse_predicate("(a AND b) OR c");
        assert!(!normalized(&expr, false));
    }

    #[test]
    fn test_not_normalized_dnf() {
        let expr = parse_predicate("(a OR b) AND c");
        assert!(!normalized(&expr, true));
    }

    #[test]
    fn test_simple_literal_is_normalized() {
        let expr = parse_predicate("a = 1");
        assert!(normalized(&expr, false));
        assert!(normalized(&expr, true));
    }

    #[test]
    fn test_normalization_distance_simple() {
        let expr = parse_predicate("a = 1");
        let distance = normalization_distance(&expr, false, 128);
        assert!(distance <= 0);
    }

    #[test]
    fn test_normalization_distance_complex() {
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        let distance = normalization_distance(&expr, false, 128);
        assert!(distance > 0);
    }

    #[test]
    fn test_normalization_distance_exact_values() {
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        // CNF has four 2-literal clauses (8 literals) versus 4 predicates today.
        assert_eq!(normalization_distance(&expr, false, 128), 4);
        // Already DNF: two 2-literal clauses, same size as today.
        assert_eq!(normalization_distance(&expr, true, 128), 0);
    }

    #[test]
    fn test_normalize_to_cnf() {
        let expr = parse_predicate("(x AND y) OR z");
        let result = normalize(expr, false, 128).unwrap();

        assert!(normalized(&result, false));
    }

    #[test]
    fn test_normalize_to_cnf_shape() {
        // (x AND y) OR z  =>  (x OR z) AND (y OR z)
        let expr = parse_predicate("(x AND y) OR z");
        let result = normalize(expr, false, DEFAULT_MAX_DISTANCE).unwrap();

        assert!(matches!(result, Expression::And(_)));
        assert_eq!(count_connectors(&result), 3);
    }

    #[test]
    fn test_normalize_to_dnf() {
        let expr = parse_predicate("(x OR y) AND z");
        let result = normalize(expr, true, 128).unwrap();

        assert!(normalized(&result, true));
    }

    #[test]
    fn test_normalize_to_dnf_shape() {
        // (x OR y) AND z  =>  (x AND z) OR (y AND z)
        let expr = parse_predicate("(x OR y) AND z");
        let result = normalize(expr, true, DEFAULT_MAX_DISTANCE).unwrap();

        assert!(matches!(result, Expression::Or(_)));
        assert_eq!(count_connectors(&result), 3);
    }

    #[test]
    fn test_normalize_already_normalized_is_unchanged() {
        let expr = parse_predicate("(a OR b) AND (c OR d)");
        let result = normalize(expr, false, DEFAULT_MAX_DISTANCE).unwrap();

        assert!(normalized(&result, false));
        assert_eq!(count_connectors(&result), 3);
    }

    #[test]
    fn test_normalize_returns_input_when_distance_exceeded() {
        // Distance is 4, above the budget of 2, so the input is returned as-is.
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        let result = normalize(expr, false, 2).unwrap();

        assert!(!normalized(&result, false));
    }

    #[test]
    fn test_wide_disjunction_exceeds_default_budget() {
        // CNF would have 2^12 clauses of 12 literals each; the estimate must exceed the budget.
        let expr = parse_predicate(&wide_or_of_ands(12));
        assert!(normalization_distance(&expr, false, DEFAULT_MAX_DISTANCE) > DEFAULT_MAX_DISTANCE);

        let result = normalize(expr, false, DEFAULT_MAX_DISTANCE).unwrap();
        assert!(!normalized(&result, false));
    }

    #[test]
    fn test_count_connectors() {
        let expr = parse_predicate("a AND b AND c");
        let count = count_connectors(&expr);
        assert_eq!(count, 2);
    }

    #[test]
    fn test_predicate_lengths() {
        let expr = parse_predicate("a = 1");
        let lengths = predicate_lengths(&expr, false, 128, 0);
        assert_eq!(lengths, vec![1]);
    }

    #[test]
    fn test_predicate_lengths_cross_product() {
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        // CNF distributes the OR: every left clause pairs with every right clause.
        assert_eq!(predicate_lengths(&expr, false, 128, 0), vec![2, 2, 2, 2]);
        // DNF distributes each AND: one 2-literal clause per side.
        assert_eq!(predicate_lengths(&expr, true, 128, 0), vec![2, 2]);
    }
}