Skip to main content

polyglot_sql/optimizer/
normalize.rs

1//! Boolean Normalization Module
2//!
3//! This module provides functionality for converting SQL boolean expressions
4//! to Conjunctive Normal Form (CNF) or Disjunctive Normal Form (DNF).
5//!
6//! CNF: (a OR b) AND (c OR d) - useful for predicate pushdown
7//! DNF: (a AND b) OR (c AND d) - useful for partition pruning
8//!
9//! Ported from sqlglot's optimizer/normalize.py
10
11use crate::expressions::{BinaryOp, Expression};
12use crate::optimizer::simplify::Simplifier;
13use thiserror::Error;
14
15/// Maximum default distance for normalization
16pub const DEFAULT_MAX_DISTANCE: i64 = 128;
17
18/// Errors that can occur during normalization
19#[derive(Debug, Error, Clone)]
20pub enum NormalizeError {
21    #[error("Normalization distance {distance} exceeds max {max}")]
22    DistanceExceeded { distance: i64, max: i64 },
23}
24
25/// Result type for normalization operations
26pub type NormalizeResult<T> = Result<T, NormalizeError>;
27
28/// Rewrite SQL AST into Conjunctive Normal Form (CNF) or Disjunctive Normal Form (DNF).
29///
30/// CNF (default): (x AND y) OR z => (x OR z) AND (y OR z)
31/// DNF: (x OR y) AND z => (x AND z) OR (y AND z)
32///
33/// # Arguments
34/// * `expression` - Expression to normalize
35/// * `dnf` - If true, convert to DNF; otherwise CNF (default)
36/// * `max_distance` - Maximum estimated distance before giving up
37///
38/// # Returns
39/// The normalized expression, or the original if normalization would be too expensive.
40pub fn normalize(
41    expression: Expression,
42    dnf: bool,
43    max_distance: i64,
44) -> NormalizeResult<Expression> {
45    let simplifier = Simplifier::new(None);
46    normalize_with_simplifier(expression, dnf, max_distance, &simplifier)
47}
48
49/// Normalize with a provided simplifier instance.
50fn normalize_with_simplifier(
51    expression: Expression,
52    dnf: bool,
53    max_distance: i64,
54    simplifier: &Simplifier,
55) -> NormalizeResult<Expression> {
56    let mut result = expression.clone();
57
58    // Walk through all connector nodes (AND/OR)
59    let connectors = collect_connectors(&expression);
60
61    for node in connectors {
62        if normalized(&node, dnf) {
63            continue;
64        }
65
66        // Estimate the cost of normalization
67        let distance = normalization_distance(&node, dnf, max_distance);
68
69        if distance > max_distance {
70            // Too expensive, return original
71            return Ok(expression);
72        }
73
74        // Apply distributive law repeatedly until normalized
75        let normalized_node = apply_distributive_law(&node, dnf, max_distance, simplifier)?;
76
77        // In a real implementation, we would replace the node in the AST
78        // For now, we just return the normalized version if it's the root
79        if is_same_expression(&node, &expression) {
80            result = normalized_node;
81        }
82    }
83
84    Ok(result)
85}
86
87/// Check whether a given expression is in a normal form.
88///
89/// CNF (Conjunctive Normal Form): (A OR B) AND (C OR D)
90///   - Conjunction (AND) of disjunctions (OR)
91///   - An OR cannot have an AND as a descendant
92///
93/// DNF (Disjunctive Normal Form): (A AND B) OR (C AND D)
94///   - Disjunction (OR) of conjunctions (AND)
95///   - An AND cannot have an OR as a descendant
96///
97/// # Arguments
98/// * `expression` - The expression to check
99/// * `dnf` - Whether to check for DNF (true) or CNF (false)
100///
101/// # Returns
102/// True if the expression is in the requested normal form.
103pub fn normalized(expression: &Expression, dnf: bool) -> bool {
104    if dnf {
105        // DNF: An AND cannot have OR as a descendant
106        !has_and_with_or_descendant(expression)
107    } else {
108        // CNF: An OR cannot have AND as a descendant
109        !has_or_with_and_descendant(expression)
110    }
111}
112
113/// Check if any OR in the expression has an AND as a descendant (violates CNF)
114fn has_or_with_and_descendant(expression: &Expression) -> bool {
115    match expression {
116        Expression::Or(bin) => {
117            // Check if either child is an AND, or if children have the violation
118            contains_and(&bin.left)
119                || contains_and(&bin.right)
120                || has_or_with_and_descendant(&bin.left)
121                || has_or_with_and_descendant(&bin.right)
122        }
123        Expression::And(bin) => {
124            has_or_with_and_descendant(&bin.left) || has_or_with_and_descendant(&bin.right)
125        }
126        Expression::Paren(paren) => has_or_with_and_descendant(&paren.this),
127        _ => false,
128    }
129}
130
131/// Check if any AND in the expression has an OR as a descendant (violates DNF)
132fn has_and_with_or_descendant(expression: &Expression) -> bool {
133    match expression {
134        Expression::And(bin) => {
135            // Check if either child is an OR, or if children have the violation
136            contains_or(&bin.left)
137                || contains_or(&bin.right)
138                || has_and_with_or_descendant(&bin.left)
139                || has_and_with_or_descendant(&bin.right)
140        }
141        Expression::Or(bin) => {
142            has_and_with_or_descendant(&bin.left) || has_and_with_or_descendant(&bin.right)
143        }
144        Expression::Paren(paren) => has_and_with_or_descendant(&paren.this),
145        _ => false,
146    }
147}
148
149/// Check if expression contains any AND (at any level)
150fn contains_and(expression: &Expression) -> bool {
151    match expression {
152        Expression::And(_) => true,
153        Expression::Or(bin) => contains_and(&bin.left) || contains_and(&bin.right),
154        Expression::Paren(paren) => contains_and(&paren.this),
155        _ => false,
156    }
157}
158
159/// Check if expression contains any OR (at any level)
160fn contains_or(expression: &Expression) -> bool {
161    match expression {
162        Expression::Or(_) => true,
163        Expression::And(bin) => contains_or(&bin.left) || contains_or(&bin.right),
164        Expression::Paren(paren) => contains_or(&paren.this),
165        _ => false,
166    }
167}
168
169/// Calculate the normalization distance for an expression.
170///
171/// This estimates the cost of converting to normal form.
172/// The conversion is exponential in complexity, so this helps decide
173/// whether to attempt it.
174///
175/// # Arguments
176/// * `expression` - The expression to analyze
177/// * `dnf` - Whether checking distance to DNF (true) or CNF (false)
178/// * `max_distance` - Early exit if distance exceeds this
179///
180/// # Returns
181/// The estimated normalization distance.
182pub fn normalization_distance(expression: &Expression, dnf: bool, max_distance: i64) -> i64 {
183    let connector_count = count_connectors(expression);
184    let mut total: i64 = -(connector_count as i64 + 1);
185
186    for length in predicate_lengths(expression, dnf, max_distance, 0) {
187        total += length;
188        if total > max_distance {
189            return total;
190        }
191    }
192
193    total
194}
195
196/// Calculate predicate lengths when expanded to normalized form.
197///
198/// For example: (A AND B) OR C -> [2, 2] because len(A OR C) = 2, len(B OR C) = 2
199///
200/// In CNF mode (dnf=false): OR distributes over AND
201///   x OR (y AND z) => (x OR y) AND (x OR z)
202///
203/// In DNF mode (dnf=true): AND distributes over OR
204///   x AND (y OR z) => (x AND y) OR (x AND z)
205fn predicate_lengths(
206    expression: &Expression,
207    dnf: bool,
208    max_distance: i64,
209    depth: i64,
210) -> Vec<i64> {
211    if depth > max_distance {
212        return vec![depth];
213    }
214
215    let expr = unwrap_paren(expression);
216
217    match expr {
218        // In CNF mode, OR is the distributing operator (we're breaking up ORs of ANDs)
219        Expression::Or(bin) if !dnf => {
220            // For CNF: OR causes multiplication in the distance calculation
221            let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
222            let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
223
224            let mut result = Vec::new();
225            for a in &left_lengths {
226                for b in &right_lengths {
227                    result.push(a + b);
228                }
229            }
230            result
231        }
232        // In DNF mode, AND is the distributing operator (we're breaking up ANDs of ORs)
233        Expression::And(bin) if dnf => {
234            // For DNF: AND causes multiplication in the distance calculation
235            let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
236            let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
237
238            let mut result = Vec::new();
239            for a in &left_lengths {
240                for b in &right_lengths {
241                    result.push(a + b);
242                }
243            }
244            result
245        }
246        // Non-distributing connectors: just collect lengths from both sides
247        Expression::And(bin) | Expression::Or(bin) => {
248            let mut result = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
249            result.extend(predicate_lengths(&bin.right, dnf, max_distance, depth + 1));
250            result
251        }
252        _ => vec![1], // Leaf predicate
253    }
254}
255
256/// Apply the distributive law to normalize an expression.
257///
258/// CNF: x OR (y AND z) => (x OR y) AND (x OR z)
259/// DNF: x AND (y OR z) => (x AND y) OR (x AND z)
260fn apply_distributive_law(
261    expression: &Expression,
262    dnf: bool,
263    max_distance: i64,
264    simplifier: &Simplifier,
265) -> NormalizeResult<Expression> {
266    if normalized(expression, dnf) {
267        return Ok(expression.clone());
268    }
269
270    let distance = normalization_distance(expression, dnf, max_distance);
271    if distance > max_distance {
272        return Err(NormalizeError::DistanceExceeded {
273            distance,
274            max: max_distance,
275        });
276    }
277
278    // Apply distributive law based on mode
279    let result = if dnf {
280        distribute_dnf(expression, simplifier)
281    } else {
282        distribute_cnf(expression, simplifier)
283    };
284
285    // Recursively apply until normalized
286    if !normalized(&result, dnf) {
287        apply_distributive_law(&result, dnf, max_distance, simplifier)
288    } else {
289        Ok(result)
290    }
291}
292
293/// Apply distributive law for CNF conversion.
294/// x OR (y AND z) => (x OR y) AND (x OR z)
295fn distribute_cnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
296    match expression {
297        Expression::Or(bin) => {
298            let left = distribute_cnf(&bin.left, simplifier);
299            let right = distribute_cnf(&bin.right, simplifier);
300
301            // Check if either side is an AND
302            if let Expression::And(and_bin) = &right {
303                // x OR (y AND z) => (x OR y) AND (x OR z)
304                let left_or_y = make_or(left.clone(), and_bin.left.clone());
305                let left_or_z = make_or(left, and_bin.right.clone());
306                return make_and(left_or_y, left_or_z);
307            }
308
309            if let Expression::And(and_bin) = &left {
310                // (y AND z) OR x => (y OR x) AND (z OR x)
311                let y_or_right = make_or(and_bin.left.clone(), right.clone());
312                let z_or_right = make_or(and_bin.right.clone(), right);
313                return make_and(y_or_right, z_or_right);
314            }
315
316            // No AND found, return simplified OR
317            make_or(left, right)
318        }
319        Expression::And(bin) => {
320            // Recurse into AND
321            let left = distribute_cnf(&bin.left, simplifier);
322            let right = distribute_cnf(&bin.right, simplifier);
323            make_and(left, right)
324        }
325        Expression::Paren(paren) => distribute_cnf(&paren.this, simplifier),
326        _ => expression.clone(),
327    }
328}
329
330/// Apply distributive law for DNF conversion.
331/// x AND (y OR z) => (x AND y) OR (x AND z)
332fn distribute_dnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
333    match expression {
334        Expression::And(bin) => {
335            let left = distribute_dnf(&bin.left, simplifier);
336            let right = distribute_dnf(&bin.right, simplifier);
337
338            // Check if either side is an OR
339            if let Expression::Or(or_bin) = &right {
340                // x AND (y OR z) => (x AND y) OR (x AND z)
341                let left_and_y = make_and(left.clone(), or_bin.left.clone());
342                let left_and_z = make_and(left, or_bin.right.clone());
343                return make_or(left_and_y, left_and_z);
344            }
345
346            if let Expression::Or(or_bin) = &left {
347                // (y OR z) AND x => (y AND x) OR (z AND x)
348                let y_and_right = make_and(or_bin.left.clone(), right.clone());
349                let z_and_right = make_and(or_bin.right.clone(), right);
350                return make_or(y_and_right, z_and_right);
351            }
352
353            // No OR found, return simplified AND
354            make_and(left, right)
355        }
356        Expression::Or(bin) => {
357            // Recurse into OR
358            let left = distribute_dnf(&bin.left, simplifier);
359            let right = distribute_dnf(&bin.right, simplifier);
360            make_or(left, right)
361        }
362        Expression::Paren(paren) => distribute_dnf(&paren.this, simplifier),
363        _ => expression.clone(),
364    }
365}
366
367// ============================================================================
368// Helper functions
369// ============================================================================
370
371/// Collect all connector (AND/OR) expressions from an AST
372fn collect_connectors(expression: &Expression) -> Vec<Expression> {
373    let mut result = Vec::new();
374    collect_connectors_recursive(expression, &mut result);
375    result
376}
377
378fn collect_connectors_recursive(expression: &Expression, result: &mut Vec<Expression>) {
379    match expression {
380        Expression::And(bin) => {
381            result.push(expression.clone());
382            collect_connectors_recursive(&bin.left, result);
383            collect_connectors_recursive(&bin.right, result);
384        }
385        Expression::Or(bin) => {
386            result.push(expression.clone());
387            collect_connectors_recursive(&bin.left, result);
388            collect_connectors_recursive(&bin.right, result);
389        }
390        Expression::Paren(paren) => {
391            collect_connectors_recursive(&paren.this, result);
392        }
393        _ => {}
394    }
395}
396
397/// Count the number of connector nodes in an expression
398fn count_connectors(expression: &Expression) -> usize {
399    match expression {
400        Expression::And(bin) | Expression::Or(bin) => {
401            1 + count_connectors(&bin.left) + count_connectors(&bin.right)
402        }
403        Expression::Paren(paren) => count_connectors(&paren.this),
404        _ => 0,
405    }
406}
407
408/// Unwrap parentheses from an expression
409fn unwrap_paren(expression: &Expression) -> &Expression {
410    match expression {
411        Expression::Paren(paren) => unwrap_paren(&paren.this),
412        _ => expression,
413    }
414}
415
416/// Check if two expressions are the same (simple identity check)
417fn is_same_expression(a: &Expression, b: &Expression) -> bool {
418    // Simple identity check - in a real implementation, this would be more sophisticated
419    std::ptr::eq(a as *const _, b as *const _) || format!("{:?}", a) == format!("{:?}", b)
420}
421
422/// Create an AND expression
423fn make_and(left: Expression, right: Expression) -> Expression {
424    Expression::And(Box::new(BinaryOp {
425        left,
426        right,
427        left_comments: vec![],
428        operator_comments: vec![],
429        trailing_comments: vec![],
430    }))
431}
432
433/// Create an OR expression
434fn make_or(left: Expression, right: Expression) -> Expression {
435    Expression::Or(Box::new(BinaryOp {
436        left,
437        right,
438        left_comments: vec![],
439        operator_comments: vec![],
440        trailing_comments: vec![],
441    }))
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and return a clone of its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Wrap `sql` in `SELECT 1 WHERE ...` and return the WHERE predicate.
    fn parse_predicate(sql: &str) -> Expression {
        let statement = parse(&format!("SELECT 1 WHERE {}", sql));
        match statement {
            Expression::Select(select) => match select.where_clause {
                Some(where_clause) => where_clause.this,
                None => panic!("Failed to extract predicate from: {}", sql),
            },
            _ => panic!("Failed to extract predicate from: {}", sql),
        }
    }

    #[test]
    fn test_normalized_cnf() {
        // A conjunction of disjunctions is already CNF.
        let expr = parse_predicate("(a OR b) AND (c OR d)");
        assert!(normalized(&expr, false));
    }

    #[test]
    fn test_normalized_dnf() {
        // A disjunction of conjunctions is already DNF.
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalized(&expr, true));
    }

    #[test]
    fn test_not_normalized_cnf() {
        // An AND nested under an OR breaks CNF.
        let expr = parse_predicate("(a AND b) OR c");
        assert!(!normalized(&expr, false));
    }

    #[test]
    fn test_not_normalized_dnf() {
        // An OR nested under an AND breaks DNF.
        let expr = parse_predicate("(a OR b) AND c");
        assert!(!normalized(&expr, true));
    }

    #[test]
    fn test_simple_literal_is_normalized() {
        // A lone comparison is trivially in both normal forms (CNF, then DNF).
        let expr = parse_predicate("a = 1");
        for dnf in [false, true] {
            assert!(normalized(&expr, dnf), "dnf = {}", dnf);
        }
    }

    #[test]
    fn test_normalization_distance_simple() {
        // A single predicate needs no expansion.
        let expr = parse_predicate("a = 1");
        assert!(normalization_distance(&expr, false, DEFAULT_MAX_DISTANCE) <= 0);
    }

    #[test]
    fn test_normalization_distance_complex() {
        // Converting (a AND b) OR (c AND d) to CNF requires expansion.
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalization_distance(&expr, false, DEFAULT_MAX_DISTANCE) > 0);
    }

    #[test]
    fn test_normalize_to_cnf() {
        // (x AND y) OR z => (x OR z) AND (y OR z)
        let expr = parse_predicate("(x AND y) OR z");
        let result = normalize(expr, false, DEFAULT_MAX_DISTANCE).unwrap();
        assert!(normalized(&result, false));
    }

    #[test]
    fn test_normalize_to_dnf() {
        // (x OR y) AND z => (x AND z) OR (y AND z)
        let expr = parse_predicate("(x OR y) AND z");
        let result = normalize(expr, true, DEFAULT_MAX_DISTANCE).unwrap();
        assert!(normalized(&result, true));
    }

    #[test]
    fn test_count_connectors() {
        // `a AND b AND c` contains two AND connectors.
        let expr = parse_predicate("a AND b AND c");
        assert_eq!(count_connectors(&expr), 2);
    }

    #[test]
    fn test_predicate_lengths() {
        // A single predicate expands to exactly one clause of length one.
        let expr = parse_predicate("a = 1");
        assert_eq!(predicate_lengths(&expr, false, 128, 0), vec![1]);
    }
}