Skip to main content

polyglot_sql/optimizer/
normalize.rs

1//! Boolean Normalization Module
2//!
3//! This module provides functionality for converting SQL boolean expressions
4//! to Conjunctive Normal Form (CNF) or Disjunctive Normal Form (DNF).
5//!
6//! CNF: (a OR b) AND (c OR d) - useful for predicate pushdown
7//! DNF: (a AND b) OR (c AND d) - useful for partition pruning
8//!
9//! Ported from sqlglot's optimizer/normalize.py
10
11use crate::expressions::{BinaryOp, Expression};
12use crate::optimizer::simplify::Simplifier;
13use thiserror::Error;
14
15/// Maximum default distance for normalization
16pub const DEFAULT_MAX_DISTANCE: i64 = 128;
17
18/// Errors that can occur during normalization
19#[derive(Debug, Error, Clone)]
20pub enum NormalizeError {
21    #[error("Normalization distance {distance} exceeds max {max}")]
22    DistanceExceeded { distance: i64, max: i64 },
23}
24
25/// Result type for normalization operations
26pub type NormalizeResult<T> = Result<T, NormalizeError>;
27
28/// Rewrite SQL AST into Conjunctive Normal Form (CNF) or Disjunctive Normal Form (DNF).
29///
30/// CNF (default): (x AND y) OR z => (x OR z) AND (y OR z)
31/// DNF: (x OR y) AND z => (x AND z) OR (y AND z)
32///
33/// # Arguments
34/// * `expression` - Expression to normalize
35/// * `dnf` - If true, convert to DNF; otherwise CNF (default)
36/// * `max_distance` - Maximum estimated distance before giving up
37///
38/// # Returns
39/// The normalized expression, or the original if normalization would be too expensive.
40pub fn normalize(
41    expression: Expression,
42    dnf: bool,
43    max_distance: i64,
44) -> NormalizeResult<Expression> {
45    let simplifier = Simplifier::new(None);
46    normalize_with_simplifier(expression, dnf, max_distance, &simplifier)
47}
48
49/// Normalize with a provided simplifier instance.
50fn normalize_with_simplifier(
51    expression: Expression,
52    dnf: bool,
53    max_distance: i64,
54    simplifier: &Simplifier,
55) -> NormalizeResult<Expression> {
56    if normalized(&expression, dnf) {
57        return Ok(expression);
58    }
59
60    // Estimate full-tree cost first to avoid expensive expansion.
61    let distance = normalization_distance(&expression, dnf, max_distance);
62    if distance > max_distance {
63        return Ok(expression);
64    }
65
66    apply_distributive_law(&expression, dnf, max_distance, simplifier)
67}
68
69/// Check whether a given expression is in a normal form.
70///
71/// CNF (Conjunctive Normal Form): (A OR B) AND (C OR D)
72///   - Conjunction (AND) of disjunctions (OR)
73///   - An OR cannot have an AND as a descendant
74///
75/// DNF (Disjunctive Normal Form): (A AND B) OR (C AND D)
76///   - Disjunction (OR) of conjunctions (AND)
77///   - An AND cannot have an OR as a descendant
78///
79/// # Arguments
80/// * `expression` - The expression to check
81/// * `dnf` - Whether to check for DNF (true) or CNF (false)
82///
83/// # Returns
84/// True if the expression is in the requested normal form.
85pub fn normalized(expression: &Expression, dnf: bool) -> bool {
86    if dnf {
87        // DNF: An AND cannot have OR as a descendant
88        !has_and_with_or_descendant(expression)
89    } else {
90        // CNF: An OR cannot have AND as a descendant
91        !has_or_with_and_descendant(expression)
92    }
93}
94
95/// Check if any OR in the expression has an AND as a descendant (violates CNF)
96fn has_or_with_and_descendant(expression: &Expression) -> bool {
97    match expression {
98        Expression::Or(bin) => {
99            // Check if either child is an AND, or if children have the violation
100            contains_and(&bin.left)
101                || contains_and(&bin.right)
102                || has_or_with_and_descendant(&bin.left)
103                || has_or_with_and_descendant(&bin.right)
104        }
105        Expression::And(bin) => {
106            has_or_with_and_descendant(&bin.left) || has_or_with_and_descendant(&bin.right)
107        }
108        Expression::Paren(paren) => has_or_with_and_descendant(&paren.this),
109        _ => false,
110    }
111}
112
113/// Check if any AND in the expression has an OR as a descendant (violates DNF)
114fn has_and_with_or_descendant(expression: &Expression) -> bool {
115    match expression {
116        Expression::And(bin) => {
117            // Check if either child is an OR, or if children have the violation
118            contains_or(&bin.left)
119                || contains_or(&bin.right)
120                || has_and_with_or_descendant(&bin.left)
121                || has_and_with_or_descendant(&bin.right)
122        }
123        Expression::Or(bin) => {
124            has_and_with_or_descendant(&bin.left) || has_and_with_or_descendant(&bin.right)
125        }
126        Expression::Paren(paren) => has_and_with_or_descendant(&paren.this),
127        _ => false,
128    }
129}
130
131/// Check if expression contains any AND (at any level)
132fn contains_and(expression: &Expression) -> bool {
133    match expression {
134        Expression::And(_) => true,
135        Expression::Or(bin) => contains_and(&bin.left) || contains_and(&bin.right),
136        Expression::Paren(paren) => contains_and(&paren.this),
137        _ => false,
138    }
139}
140
141/// Check if expression contains any OR (at any level)
142fn contains_or(expression: &Expression) -> bool {
143    match expression {
144        Expression::Or(_) => true,
145        Expression::And(bin) => contains_or(&bin.left) || contains_or(&bin.right),
146        Expression::Paren(paren) => contains_or(&paren.this),
147        _ => false,
148    }
149}
150
151/// Calculate the normalization distance for an expression.
152///
153/// This estimates the cost of converting to normal form.
154/// The conversion is exponential in complexity, so this helps decide
155/// whether to attempt it.
156///
157/// # Arguments
158/// * `expression` - The expression to analyze
159/// * `dnf` - Whether checking distance to DNF (true) or CNF (false)
160/// * `max_distance` - Early exit if distance exceeds this
161///
162/// # Returns
163/// The estimated normalization distance.
164pub fn normalization_distance(expression: &Expression, dnf: bool, max_distance: i64) -> i64 {
165    let connector_count = count_connectors(expression);
166    let mut total: i64 = -(connector_count as i64 + 1);
167
168    for length in predicate_lengths(expression, dnf, max_distance, 0) {
169        total += length;
170        if total > max_distance {
171            return total;
172        }
173    }
174
175    total
176}
177
178/// Calculate predicate lengths when expanded to normalized form.
179///
180/// For example: (A AND B) OR C -> [2, 2] because len(A OR C) = 2, len(B OR C) = 2
181///
182/// In CNF mode (dnf=false): OR distributes over AND
183///   x OR (y AND z) => (x OR y) AND (x OR z)
184///
185/// In DNF mode (dnf=true): AND distributes over OR
186///   x AND (y OR z) => (x AND y) OR (x AND z)
187fn predicate_lengths(
188    expression: &Expression,
189    dnf: bool,
190    max_distance: i64,
191    depth: i64,
192) -> Vec<i64> {
193    if depth > max_distance {
194        return vec![depth];
195    }
196
197    let expr = unwrap_paren(expression);
198
199    match expr {
200        // In CNF mode, OR is the distributing operator (we're breaking up ORs of ANDs)
201        Expression::Or(bin) if !dnf => {
202            // For CNF: OR causes multiplication in the distance calculation
203            let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
204            let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
205
206            let mut result = Vec::new();
207            for a in &left_lengths {
208                for b in &right_lengths {
209                    result.push(a + b);
210                }
211            }
212            result
213        }
214        // In DNF mode, AND is the distributing operator (we're breaking up ANDs of ORs)
215        Expression::And(bin) if dnf => {
216            // For DNF: AND causes multiplication in the distance calculation
217            let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
218            let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
219
220            let mut result = Vec::new();
221            for a in &left_lengths {
222                for b in &right_lengths {
223                    result.push(a + b);
224                }
225            }
226            result
227        }
228        // Non-distributing connectors: just collect lengths from both sides
229        Expression::And(bin) | Expression::Or(bin) => {
230            let mut result = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
231            result.extend(predicate_lengths(&bin.right, dnf, max_distance, depth + 1));
232            result
233        }
234        _ => vec![1], // Leaf predicate
235    }
236}
237
238/// Apply the distributive law to normalize an expression.
239///
240/// CNF: x OR (y AND z) => (x OR y) AND (x OR z)
241/// DNF: x AND (y OR z) => (x AND y) OR (x AND z)
242fn apply_distributive_law(
243    expression: &Expression,
244    dnf: bool,
245    max_distance: i64,
246    simplifier: &Simplifier,
247) -> NormalizeResult<Expression> {
248    if normalized(expression, dnf) {
249        return Ok(expression.clone());
250    }
251
252    let distance = normalization_distance(expression, dnf, max_distance);
253    if distance > max_distance {
254        return Err(NormalizeError::DistanceExceeded {
255            distance,
256            max: max_distance,
257        });
258    }
259
260    // Apply distributive law based on mode
261    let result = if dnf {
262        distribute_dnf(expression, simplifier)
263    } else {
264        distribute_cnf(expression, simplifier)
265    };
266
267    // Recursively apply until normalized
268    if !normalized(&result, dnf) {
269        apply_distributive_law(&result, dnf, max_distance, simplifier)
270    } else {
271        Ok(result)
272    }
273}
274
275/// Apply distributive law for CNF conversion.
276/// x OR (y AND z) => (x OR y) AND (x OR z)
277fn distribute_cnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
278    match expression {
279        Expression::Or(bin) => {
280            let left = distribute_cnf(&bin.left, simplifier);
281            let right = distribute_cnf(&bin.right, simplifier);
282
283            // Check if either side is an AND
284            if let Expression::And(and_bin) = &right {
285                // x OR (y AND z) => (x OR y) AND (x OR z)
286                let left_or_y = make_or(left.clone(), and_bin.left.clone());
287                let left_or_z = make_or(left, and_bin.right.clone());
288                return make_and(left_or_y, left_or_z);
289            }
290
291            if let Expression::And(and_bin) = &left {
292                // (y AND z) OR x => (y OR x) AND (z OR x)
293                let y_or_right = make_or(and_bin.left.clone(), right.clone());
294                let z_or_right = make_or(and_bin.right.clone(), right);
295                return make_and(y_or_right, z_or_right);
296            }
297
298            // No AND found, return simplified OR
299            make_or(left, right)
300        }
301        Expression::And(bin) => {
302            // Recurse into AND
303            let left = distribute_cnf(&bin.left, simplifier);
304            let right = distribute_cnf(&bin.right, simplifier);
305            make_and(left, right)
306        }
307        Expression::Paren(paren) => distribute_cnf(&paren.this, simplifier),
308        _ => expression.clone(),
309    }
310}
311
312/// Apply distributive law for DNF conversion.
313/// x AND (y OR z) => (x AND y) OR (x AND z)
314fn distribute_dnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
315    match expression {
316        Expression::And(bin) => {
317            let left = distribute_dnf(&bin.left, simplifier);
318            let right = distribute_dnf(&bin.right, simplifier);
319
320            // Check if either side is an OR
321            if let Expression::Or(or_bin) = &right {
322                // x AND (y OR z) => (x AND y) OR (x AND z)
323                let left_and_y = make_and(left.clone(), or_bin.left.clone());
324                let left_and_z = make_and(left, or_bin.right.clone());
325                return make_or(left_and_y, left_and_z);
326            }
327
328            if let Expression::Or(or_bin) = &left {
329                // (y OR z) AND x => (y AND x) OR (z AND x)
330                let y_and_right = make_and(or_bin.left.clone(), right.clone());
331                let z_and_right = make_and(or_bin.right.clone(), right);
332                return make_or(y_and_right, z_and_right);
333            }
334
335            // No OR found, return simplified AND
336            make_and(left, right)
337        }
338        Expression::Or(bin) => {
339            // Recurse into OR
340            let left = distribute_dnf(&bin.left, simplifier);
341            let right = distribute_dnf(&bin.right, simplifier);
342            make_or(left, right)
343        }
344        Expression::Paren(paren) => distribute_dnf(&paren.this, simplifier),
345        _ => expression.clone(),
346    }
347}
348
349// ============================================================================
350// Helper functions
351// ============================================================================
352
353/// Count the number of connector nodes in an expression
354fn count_connectors(expression: &Expression) -> usize {
355    match expression {
356        Expression::And(bin) | Expression::Or(bin) => {
357            1 + count_connectors(&bin.left) + count_connectors(&bin.right)
358        }
359        Expression::Paren(paren) => count_connectors(&paren.this),
360        _ => 0,
361    }
362}
363
364/// Unwrap parentheses from an expression
365fn unwrap_paren(expression: &Expression) -> &Expression {
366    match expression {
367        Expression::Paren(paren) => unwrap_paren(&paren.this),
368        _ => expression,
369    }
370}
371
372/// Create an AND expression
373fn make_and(left: Expression, right: Expression) -> Expression {
374    Expression::And(Box::new(BinaryOp {
375        left,
376        right,
377        left_comments: vec![],
378        operator_comments: vec![],
379        trailing_comments: vec![],
380    }))
381}
382
383/// Create an OR expression
384fn make_or(left: Expression, right: Expression) -> Expression {
385    Expression::Or(Box::new(BinaryOp {
386        left,
387        right,
388        left_comments: vec![],
389        operator_comments: vec![],
390        trailing_comments: vec![],
391    }))
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Readable names for the `dnf` flag taken by the functions under test.
    const CNF: bool = false;
    const DNF: bool = true;

    /// Parse `sql` and return a clone of the first parsed statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Parse `sql` as the WHERE clause of a dummy SELECT and return the predicate.
    fn parse_predicate(sql: &str) -> Expression {
        let predicate = match parse(&format!("SELECT 1 WHERE {}", sql)) {
            Expression::Select(select) => select.where_clause.map(|clause| clause.this),
            _ => None,
        };
        predicate.unwrap_or_else(|| panic!("Failed to extract predicate from: {}", sql))
    }

    #[test]
    fn test_normalized_cnf() {
        // A conjunction of disjunctions is already CNF.
        let predicate = parse_predicate("(a OR b) AND (c OR d)");
        assert!(normalized(&predicate, CNF));
    }

    #[test]
    fn test_normalized_dnf() {
        // A disjunction of conjunctions is already DNF.
        let predicate = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalized(&predicate, DNF));
    }

    #[test]
    fn test_not_normalized_cnf() {
        // An AND underneath an OR breaks CNF.
        let predicate = parse_predicate("(a AND b) OR c");
        assert!(!normalized(&predicate, CNF));
    }

    #[test]
    fn test_not_normalized_dnf() {
        // An OR underneath an AND breaks DNF.
        let predicate = parse_predicate("(a OR b) AND c");
        assert!(!normalized(&predicate, DNF));
    }

    #[test]
    fn test_simple_literal_is_normalized() {
        // A single comparison satisfies both normal forms.
        let predicate = parse_predicate("a = 1");
        assert!(normalized(&predicate, CNF));
        assert!(normalized(&predicate, DNF));
    }

    #[test]
    fn test_normalization_distance_simple() {
        // A lone predicate needs no expansion.
        let predicate = parse_predicate("a = 1");
        assert!(normalization_distance(&predicate, CNF, 128) <= 0);
    }

    #[test]
    fn test_normalization_distance_complex() {
        // Converting a DNF-shaped predicate to CNF requires expansion.
        let predicate = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalization_distance(&predicate, CNF, 128) > 0);
    }

    #[test]
    fn test_normalize_to_cnf() {
        // (x AND y) OR z => (x OR z) AND (y OR z)
        let predicate = parse_predicate("(x AND y) OR z");
        let rewritten = normalize(predicate, CNF, 128).unwrap();
        assert!(normalized(&rewritten, CNF));
    }

    #[test]
    fn test_normalize_to_dnf() {
        // (x OR y) AND z => (x AND z) OR (y AND z)
        let predicate = parse_predicate("(x OR y) AND z");
        let rewritten = normalize(predicate, DNF, 128).unwrap();
        assert!(normalized(&rewritten, DNF));
    }

    #[test]
    fn test_count_connectors() {
        // Three operands chained with AND contain two connectors.
        let predicate = parse_predicate("a AND b AND c");
        assert_eq!(count_connectors(&predicate), 2);
    }

    #[test]
    fn test_predicate_lengths() {
        // A leaf predicate has exactly one length of 1.
        let predicate = parse_predicate("a = 1");
        assert_eq!(predicate_lengths(&predicate, CNF, 128, 0), vec![1]);
    }
}