Skip to main content

polyglot_sql/optimizer/
normalize.rs

1//! Boolean Normalization Module
2//!
3//! This module provides functionality for converting SQL boolean expressions
4//! to Conjunctive Normal Form (CNF) or Disjunctive Normal Form (DNF).
5//!
6//! CNF: (a OR b) AND (c OR d) - useful for predicate pushdown
7//! DNF: (a AND b) OR (c AND d) - useful for partition pruning
8//!
9//! Ported from sqlglot's optimizer/normalize.py
10
11use crate::expressions::{BinaryOp, Expression};
12use crate::optimizer::simplify::Simplifier;
13use thiserror::Error;
14
/// Default ceiling for the estimated normalization distance.
///
/// Intended as the `max_distance` argument of [`normalize`] when the caller
/// has no reason to pick a different budget.
pub const DEFAULT_MAX_DISTANCE: i64 = 128;
17
18/// Errors that can occur during normalization
19#[derive(Debug, Error, Clone)]
20pub enum NormalizeError {
21    #[error("Normalization distance {distance} exceeds max {max}")]
22    DistanceExceeded { distance: i64, max: i64 },
23}
24
25/// Result type for normalization operations
26pub type NormalizeResult<T> = Result<T, NormalizeError>;
27
28/// Rewrite SQL AST into Conjunctive Normal Form (CNF) or Disjunctive Normal Form (DNF).
29///
30/// CNF (default): (x AND y) OR z => (x OR z) AND (y OR z)
31/// DNF: (x OR y) AND z => (x AND z) OR (y AND z)
32///
33/// # Arguments
34/// * `expression` - Expression to normalize
35/// * `dnf` - If true, convert to DNF; otherwise CNF (default)
36/// * `max_distance` - Maximum estimated distance before giving up
37///
38/// # Returns
39/// The normalized expression, or the original if normalization would be too expensive.
40pub fn normalize(
41    expression: Expression,
42    dnf: bool,
43    max_distance: i64,
44) -> NormalizeResult<Expression> {
45    let simplifier = Simplifier::new(None);
46    normalize_with_simplifier(expression, dnf, max_distance, &simplifier)
47}
48
49/// Normalize with a provided simplifier instance.
50fn normalize_with_simplifier(
51    expression: Expression,
52    dnf: bool,
53    max_distance: i64,
54    simplifier: &Simplifier,
55) -> NormalizeResult<Expression> {
56    if normalized(&expression, dnf) {
57        return Ok(expression);
58    }
59
60    // Estimate full-tree cost first to avoid expensive expansion.
61    let distance = normalization_distance(&expression, dnf, max_distance);
62    if distance > max_distance {
63        return Ok(expression);
64    }
65
66    apply_distributive_law(&expression, dnf, max_distance, simplifier)
67}
68
69/// Check whether a given expression is in a normal form.
70///
71/// CNF (Conjunctive Normal Form): (A OR B) AND (C OR D)
72///   - Conjunction (AND) of disjunctions (OR)
73///   - An OR cannot have an AND as a descendant
74///
75/// DNF (Disjunctive Normal Form): (A AND B) OR (C AND D)
76///   - Disjunction (OR) of conjunctions (AND)
77///   - An AND cannot have an OR as a descendant
78///
79/// # Arguments
80/// * `expression` - The expression to check
81/// * `dnf` - Whether to check for DNF (true) or CNF (false)
82///
83/// # Returns
84/// True if the expression is in the requested normal form.
85pub fn normalized(expression: &Expression, dnf: bool) -> bool {
86    if dnf {
87        // DNF: An AND cannot have OR as a descendant
88        !has_and_with_or_descendant(expression)
89    } else {
90        // CNF: An OR cannot have AND as a descendant
91        !has_or_with_and_descendant(expression)
92    }
93}
94
95/// Check if any OR in the expression has an AND as a descendant (violates CNF)
96fn has_or_with_and_descendant(expression: &Expression) -> bool {
97    match expression {
98        Expression::Or(bin) => {
99            // Check if either child is an AND, or if children have the violation
100            contains_and(&bin.left)
101                || contains_and(&bin.right)
102                || has_or_with_and_descendant(&bin.left)
103                || has_or_with_and_descendant(&bin.right)
104        }
105        Expression::And(bin) => {
106            has_or_with_and_descendant(&bin.left) || has_or_with_and_descendant(&bin.right)
107        }
108        Expression::Paren(paren) => has_or_with_and_descendant(&paren.this),
109        _ => false,
110    }
111}
112
113/// Check if any AND in the expression has an OR as a descendant (violates DNF)
114fn has_and_with_or_descendant(expression: &Expression) -> bool {
115    match expression {
116        Expression::And(bin) => {
117            // Check if either child is an OR, or if children have the violation
118            contains_or(&bin.left)
119                || contains_or(&bin.right)
120                || has_and_with_or_descendant(&bin.left)
121                || has_and_with_or_descendant(&bin.right)
122        }
123        Expression::Or(bin) => {
124            has_and_with_or_descendant(&bin.left) || has_and_with_or_descendant(&bin.right)
125        }
126        Expression::Paren(paren) => has_and_with_or_descendant(&paren.this),
127        _ => false,
128    }
129}
130
131/// Check if expression contains any AND (at any level)
132fn contains_and(expression: &Expression) -> bool {
133    match expression {
134        Expression::And(_) => true,
135        Expression::Or(bin) => contains_and(&bin.left) || contains_and(&bin.right),
136        Expression::Paren(paren) => contains_and(&paren.this),
137        _ => false,
138    }
139}
140
141/// Check if expression contains any OR (at any level)
142fn contains_or(expression: &Expression) -> bool {
143    match expression {
144        Expression::Or(_) => true,
145        Expression::And(bin) => contains_or(&bin.left) || contains_or(&bin.right),
146        Expression::Paren(paren) => contains_or(&paren.this),
147        _ => false,
148    }
149}
150
151/// Calculate the normalization distance for an expression.
152///
153/// This estimates the cost of converting to normal form.
154/// The conversion is exponential in complexity, so this helps decide
155/// whether to attempt it.
156///
157/// # Arguments
158/// * `expression` - The expression to analyze
159/// * `dnf` - Whether checking distance to DNF (true) or CNF (false)
160/// * `max_distance` - Early exit if distance exceeds this
161///
162/// # Returns
163/// The estimated normalization distance.
164pub fn normalization_distance(expression: &Expression, dnf: bool, max_distance: i64) -> i64 {
165    let connector_count = count_connectors(expression);
166    let mut total: i64 = -(connector_count as i64 + 1);
167
168    for length in predicate_lengths(expression, dnf, max_distance, 0) {
169        total += length;
170        if total > max_distance {
171            return total;
172        }
173    }
174
175    total
176}
177
178/// Calculate predicate lengths when expanded to normalized form.
179///
180/// For example: (A AND B) OR C -> [2, 2] because len(A OR C) = 2, len(B OR C) = 2
181///
182/// In CNF mode (dnf=false): OR distributes over AND
183///   x OR (y AND z) => (x OR y) AND (x OR z)
184///
185/// In DNF mode (dnf=true): AND distributes over OR
186///   x AND (y OR z) => (x AND y) OR (x AND z)
187fn predicate_lengths(
188    expression: &Expression,
189    dnf: bool,
190    max_distance: i64,
191    depth: i64,
192) -> Vec<i64> {
193    if depth > max_distance {
194        return vec![depth];
195    }
196
197    let expr = unwrap_paren(expression);
198
199    match expr {
200        // In CNF mode, OR is the distributing operator (we're breaking up ORs of ANDs)
201        Expression::Or(bin) if !dnf => {
202            // For CNF: OR causes multiplication in the distance calculation
203            let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
204            let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
205
206            let mut result = Vec::new();
207            for a in &left_lengths {
208                for b in &right_lengths {
209                    result.push(a + b);
210                }
211            }
212            result
213        }
214        // In DNF mode, AND is the distributing operator (we're breaking up ANDs of ORs)
215        Expression::And(bin) if dnf => {
216            // For DNF: AND causes multiplication in the distance calculation
217            let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
218            let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
219
220            let mut result = Vec::new();
221            for a in &left_lengths {
222                for b in &right_lengths {
223                    result.push(a + b);
224                }
225            }
226            result
227        }
228        // Non-distributing connectors: just collect lengths from both sides
229        Expression::And(bin) | Expression::Or(bin) => {
230            let mut result = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
231            result.extend(predicate_lengths(&bin.right, dnf, max_distance, depth + 1));
232            result
233        }
234        _ => vec![1], // Leaf predicate
235    }
236}
237
238/// Apply the distributive law to normalize an expression.
239///
240/// CNF: x OR (y AND z) => (x OR y) AND (x OR z)
241/// DNF: x AND (y OR z) => (x AND y) OR (x AND z)
242fn apply_distributive_law(
243    expression: &Expression,
244    dnf: bool,
245    max_distance: i64,
246    simplifier: &Simplifier,
247) -> NormalizeResult<Expression> {
248    if normalized(expression, dnf) {
249        return Ok(expression.clone());
250    }
251
252    let distance = normalization_distance(expression, dnf, max_distance);
253    if distance > max_distance {
254        return Err(NormalizeError::DistanceExceeded {
255            distance,
256            max: max_distance,
257        });
258    }
259
260    // Apply distributive law based on mode
261    let result = if dnf {
262        distribute_dnf(expression, simplifier)
263    } else {
264        distribute_cnf(expression, simplifier)
265    };
266
267    // Recursively apply until normalized
268    if !normalized(&result, dnf) {
269        apply_distributive_law(&result, dnf, max_distance, simplifier)
270    } else {
271        Ok(result)
272    }
273}
274
275/// Apply distributive law for CNF conversion.
276/// x OR (y AND z) => (x OR y) AND (x OR z)
277fn distribute_cnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
278    match expression {
279        Expression::Or(bin) => {
280            let left = distribute_cnf(&bin.left, simplifier);
281            let right = distribute_cnf(&bin.right, simplifier);
282
283            // Check if either side is an AND
284            if let Expression::And(and_bin) = &right {
285                // x OR (y AND z) => (x OR y) AND (x OR z)
286                let left_or_y = make_or(left.clone(), and_bin.left.clone());
287                let left_or_z = make_or(left, and_bin.right.clone());
288                return make_and(left_or_y, left_or_z);
289            }
290
291            if let Expression::And(and_bin) = &left {
292                // (y AND z) OR x => (y OR x) AND (z OR x)
293                let y_or_right = make_or(and_bin.left.clone(), right.clone());
294                let z_or_right = make_or(and_bin.right.clone(), right);
295                return make_and(y_or_right, z_or_right);
296            }
297
298            // No AND found, return simplified OR
299            make_or(left, right)
300        }
301        Expression::And(bin) => {
302            // Recurse into AND
303            let left = distribute_cnf(&bin.left, simplifier);
304            let right = distribute_cnf(&bin.right, simplifier);
305            make_and(left, right)
306        }
307        Expression::Paren(paren) => distribute_cnf(&paren.this, simplifier),
308        _ => expression.clone(),
309    }
310}
311
312/// Apply distributive law for DNF conversion.
313/// x AND (y OR z) => (x AND y) OR (x AND z)
314fn distribute_dnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
315    match expression {
316        Expression::And(bin) => {
317            let left = distribute_dnf(&bin.left, simplifier);
318            let right = distribute_dnf(&bin.right, simplifier);
319
320            // Check if either side is an OR
321            if let Expression::Or(or_bin) = &right {
322                // x AND (y OR z) => (x AND y) OR (x AND z)
323                let left_and_y = make_and(left.clone(), or_bin.left.clone());
324                let left_and_z = make_and(left, or_bin.right.clone());
325                return make_or(left_and_y, left_and_z);
326            }
327
328            if let Expression::Or(or_bin) = &left {
329                // (y OR z) AND x => (y AND x) OR (z AND x)
330                let y_and_right = make_and(or_bin.left.clone(), right.clone());
331                let z_and_right = make_and(or_bin.right.clone(), right);
332                return make_or(y_and_right, z_and_right);
333            }
334
335            // No OR found, return simplified AND
336            make_and(left, right)
337        }
338        Expression::Or(bin) => {
339            // Recurse into OR
340            let left = distribute_dnf(&bin.left, simplifier);
341            let right = distribute_dnf(&bin.right, simplifier);
342            make_or(left, right)
343        }
344        Expression::Paren(paren) => distribute_dnf(&paren.this, simplifier),
345        _ => expression.clone(),
346    }
347}
348
349// ============================================================================
350// Helper functions
351// ============================================================================
352
353/// Count the number of connector nodes in an expression
354fn count_connectors(expression: &Expression) -> usize {
355    match expression {
356        Expression::And(bin) | Expression::Or(bin) => {
357            1 + count_connectors(&bin.left) + count_connectors(&bin.right)
358        }
359        Expression::Paren(paren) => count_connectors(&paren.this),
360        _ => 0,
361    }
362}
363
364/// Unwrap parentheses from an expression
365fn unwrap_paren(expression: &Expression) -> &Expression {
366    match expression {
367        Expression::Paren(paren) => unwrap_paren(&paren.this),
368        _ => expression,
369    }
370}
371
372/// Create an AND expression
373fn make_and(left: Expression, right: Expression) -> Expression {
374    Expression::And(Box::new(BinaryOp {
375        left,
376        right,
377        left_comments: vec![],
378        operator_comments: vec![],
379        trailing_comments: vec![],
380        inferred_type: None,
381    }))
382}
383
384/// Create an OR expression
385fn make_or(left: Expression, right: Expression) -> Expression {
386    Expression::Or(Box::new(BinaryOp {
387        left,
388        right,
389        left_comments: vec![],
390        operator_comments: vec![],
391        trailing_comments: vec![],
392        inferred_type: None,
393    }))
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Parse `sql` as the WHERE clause of a dummy SELECT and return the predicate.
    fn parse_predicate(sql: &str) -> Expression {
        let full = format!("SELECT 1 WHERE {}", sql);
        let where_clause = match parse(&full) {
            Expression::Select(select) => select.where_clause,
            _ => None,
        };
        match where_clause {
            Some(clause) => clause.this,
            None => panic!("Failed to extract predicate from: {}", sql),
        }
    }

    #[test]
    fn test_normalized_cnf() {
        // A conjunction of disjunctions is CNF.
        let predicate = parse_predicate("(a OR b) AND (c OR d)");
        assert!(normalized(&predicate, false));
    }

    #[test]
    fn test_normalized_dnf() {
        // A disjunction of conjunctions is DNF.
        let predicate = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalized(&predicate, true));
    }

    #[test]
    fn test_not_normalized_cnf() {
        // An AND below an OR breaks CNF.
        let predicate = parse_predicate("(a AND b) OR c");
        assert!(!normalized(&predicate, false));
    }

    #[test]
    fn test_not_normalized_dnf() {
        // An OR below an AND breaks DNF.
        let predicate = parse_predicate("(a OR b) AND c");
        assert!(!normalized(&predicate, true));
    }

    #[test]
    fn test_simple_literal_is_normalized() {
        // A lone comparison satisfies both forms.
        let predicate = parse_predicate("a = 1");
        assert!(normalized(&predicate, false));
        assert!(normalized(&predicate, true));
    }

    #[test]
    fn test_normalization_distance_simple() {
        // Nothing to distribute, so the estimate must not be positive.
        let predicate = parse_predicate("a = 1");
        assert!(normalization_distance(&predicate, false, 128) <= 0);
    }

    #[test]
    fn test_normalization_distance_complex() {
        // Converting (a AND b) OR (c AND d) to CNF requires expansion.
        let predicate = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalization_distance(&predicate, false, 128) > 0);
    }

    #[test]
    fn test_normalize_to_cnf() {
        // (x AND y) OR z => (x OR z) AND (y OR z)
        let predicate = parse_predicate("(x AND y) OR z");
        let rewritten = normalize(predicate, false, 128).unwrap();
        assert!(normalized(&rewritten, false));
    }

    #[test]
    fn test_normalize_to_dnf() {
        // (x OR y) AND z => (x AND z) OR (y AND z)
        let predicate = parse_predicate("(x OR y) AND z");
        let rewritten = normalize(predicate, true, 128).unwrap();
        assert!(normalized(&rewritten, true));
    }

    #[test]
    fn test_count_connectors() {
        // `a AND b AND c` holds two AND connectors.
        let predicate = parse_predicate("a AND b AND c");
        assert_eq!(count_connectors(&predicate), 2);
    }

    #[test]
    fn test_predicate_lengths() {
        // A single predicate expands to exactly one clause of length 1.
        let predicate = parse_predicate("a = 1");
        assert_eq!(predicate_lengths(&predicate, false, 128, 0), vec![1]);
    }
}