Skip to main content

polyglot_sql/optimizer/
normalize.rs

1//! Boolean Normalization Module
2//!
3//! This module provides functionality for converting SQL boolean expressions
4//! to Conjunctive Normal Form (CNF) or Disjunctive Normal Form (DNF).
5//!
6//! CNF: (a OR b) AND (c OR d) - useful for predicate pushdown
7//! DNF: (a AND b) OR (c AND d) - useful for partition pruning
8//!
9//! Ported from sqlglot's optimizer/normalize.py
10
11use crate::expressions::{BinaryOp, Expression};
12use crate::optimizer::simplify::Simplifier;
13use thiserror::Error;
14
/// Default budget for the estimated normalization distance.
///
/// Intended as the `max_distance` argument of [`normalize`] when the caller has no
/// specific limit in mind.
pub const DEFAULT_MAX_DISTANCE: i64 = 128;
17
18/// Errors that can occur during normalization
19#[derive(Debug, Error, Clone)]
20pub enum NormalizeError {
21    #[error("Normalization distance {distance} exceeds max {max}")]
22    DistanceExceeded { distance: i64, max: i64 },
23}
24
25/// Result type for normalization operations
26pub type NormalizeResult<T> = Result<T, NormalizeError>;
27
28/// Rewrite SQL AST into Conjunctive Normal Form (CNF) or Disjunctive Normal Form (DNF).
29///
30/// CNF (default): (x AND y) OR z => (x OR z) AND (y OR z)
31/// DNF: (x OR y) AND z => (x AND z) OR (y AND z)
32///
33/// # Arguments
34/// * `expression` - Expression to normalize
35/// * `dnf` - If true, convert to DNF; otherwise CNF (default)
36/// * `max_distance` - Maximum estimated distance before giving up
37///
38/// # Returns
39/// The normalized expression, or the original if normalization would be too expensive.
40pub fn normalize(
41    expression: Expression,
42    dnf: bool,
43    max_distance: i64,
44) -> NormalizeResult<Expression> {
45    let simplifier = Simplifier::new(None);
46    normalize_with_simplifier(expression, dnf, max_distance, &simplifier)
47}
48
49/// Normalize with a provided simplifier instance.
50fn normalize_with_simplifier(
51    expression: Expression,
52    dnf: bool,
53    max_distance: i64,
54    simplifier: &Simplifier,
55) -> NormalizeResult<Expression> {
56    let mut result = expression.clone();
57
58    // Walk through all connector nodes (AND/OR)
59    let connectors = collect_connectors(&expression);
60
61    for node in connectors {
62        if normalized(&node, dnf) {
63            continue;
64        }
65
66        // Estimate the cost of normalization
67        let distance = normalization_distance(&node, dnf, max_distance);
68
69        if distance > max_distance {
70            // Too expensive, return original
71            return Ok(expression);
72        }
73
74        // Apply distributive law repeatedly until normalized
75        let normalized_node = apply_distributive_law(&node, dnf, max_distance, simplifier)?;
76
77        // In a real implementation, we would replace the node in the AST
78        // For now, we just return the normalized version if it's the root
79        if is_same_expression(&node, &expression) {
80            result = normalized_node;
81        }
82    }
83
84    Ok(result)
85}
86
87/// Check whether a given expression is in a normal form.
88///
89/// CNF (Conjunctive Normal Form): (A OR B) AND (C OR D)
90///   - Conjunction (AND) of disjunctions (OR)
91///   - An OR cannot have an AND as a descendant
92///
93/// DNF (Disjunctive Normal Form): (A AND B) OR (C AND D)
94///   - Disjunction (OR) of conjunctions (AND)
95///   - An AND cannot have an OR as a descendant
96///
97/// # Arguments
98/// * `expression` - The expression to check
99/// * `dnf` - Whether to check for DNF (true) or CNF (false)
100///
101/// # Returns
102/// True if the expression is in the requested normal form.
103pub fn normalized(expression: &Expression, dnf: bool) -> bool {
104    if dnf {
105        // DNF: An AND cannot have OR as a descendant
106        !has_and_with_or_descendant(expression)
107    } else {
108        // CNF: An OR cannot have AND as a descendant
109        !has_or_with_and_descendant(expression)
110    }
111}
112
113/// Check if any OR in the expression has an AND as a descendant (violates CNF)
114fn has_or_with_and_descendant(expression: &Expression) -> bool {
115    match expression {
116        Expression::Or(bin) => {
117            // Check if either child is an AND, or if children have the violation
118            contains_and(&bin.left)
119                || contains_and(&bin.right)
120                || has_or_with_and_descendant(&bin.left)
121                || has_or_with_and_descendant(&bin.right)
122        }
123        Expression::And(bin) => {
124            has_or_with_and_descendant(&bin.left) || has_or_with_and_descendant(&bin.right)
125        }
126        Expression::Paren(paren) => has_or_with_and_descendant(&paren.this),
127        _ => false,
128    }
129}
130
131/// Check if any AND in the expression has an OR as a descendant (violates DNF)
132fn has_and_with_or_descendant(expression: &Expression) -> bool {
133    match expression {
134        Expression::And(bin) => {
135            // Check if either child is an OR, or if children have the violation
136            contains_or(&bin.left)
137                || contains_or(&bin.right)
138                || has_and_with_or_descendant(&bin.left)
139                || has_and_with_or_descendant(&bin.right)
140        }
141        Expression::Or(bin) => {
142            has_and_with_or_descendant(&bin.left) || has_and_with_or_descendant(&bin.right)
143        }
144        Expression::Paren(paren) => has_and_with_or_descendant(&paren.this),
145        _ => false,
146    }
147}
148
149/// Check if expression contains any AND (at any level)
150fn contains_and(expression: &Expression) -> bool {
151    match expression {
152        Expression::And(_) => true,
153        Expression::Or(bin) => contains_and(&bin.left) || contains_and(&bin.right),
154        Expression::Paren(paren) => contains_and(&paren.this),
155        _ => false,
156    }
157}
158
159/// Check if expression contains any OR (at any level)
160fn contains_or(expression: &Expression) -> bool {
161    match expression {
162        Expression::Or(_) => true,
163        Expression::And(bin) => contains_or(&bin.left) || contains_or(&bin.right),
164        Expression::Paren(paren) => contains_or(&paren.this),
165        _ => false,
166    }
167}
168
169/// Calculate the normalization distance for an expression.
170///
171/// This estimates the cost of converting to normal form.
172/// The conversion is exponential in complexity, so this helps decide
173/// whether to attempt it.
174///
175/// # Arguments
176/// * `expression` - The expression to analyze
177/// * `dnf` - Whether checking distance to DNF (true) or CNF (false)
178/// * `max_distance` - Early exit if distance exceeds this
179///
180/// # Returns
181/// The estimated normalization distance.
182pub fn normalization_distance(expression: &Expression, dnf: bool, max_distance: i64) -> i64 {
183    let connector_count = count_connectors(expression);
184    let mut total: i64 = -(connector_count as i64 + 1);
185
186    for length in predicate_lengths(expression, dnf, max_distance, 0) {
187        total += length;
188        if total > max_distance {
189            return total;
190        }
191    }
192
193    total
194}
195
196/// Calculate predicate lengths when expanded to normalized form.
197///
198/// For example: (A AND B) OR C -> [2, 2] because len(A OR C) = 2, len(B OR C) = 2
199///
200/// In CNF mode (dnf=false): OR distributes over AND
201///   x OR (y AND z) => (x OR y) AND (x OR z)
202///
203/// In DNF mode (dnf=true): AND distributes over OR
204///   x AND (y OR z) => (x AND y) OR (x AND z)
205fn predicate_lengths(expression: &Expression, dnf: bool, max_distance: i64, depth: i64) -> Vec<i64> {
206    if depth > max_distance {
207        return vec![depth];
208    }
209
210    let expr = unwrap_paren(expression);
211
212    match expr {
213        // In CNF mode, OR is the distributing operator (we're breaking up ORs of ANDs)
214        Expression::Or(bin) if !dnf => {
215            // For CNF: OR causes multiplication in the distance calculation
216            let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
217            let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
218
219            let mut result = Vec::new();
220            for a in &left_lengths {
221                for b in &right_lengths {
222                    result.push(a + b);
223                }
224            }
225            result
226        }
227        // In DNF mode, AND is the distributing operator (we're breaking up ANDs of ORs)
228        Expression::And(bin) if dnf => {
229            // For DNF: AND causes multiplication in the distance calculation
230            let left_lengths = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
231            let right_lengths = predicate_lengths(&bin.right, dnf, max_distance, depth + 1);
232
233            let mut result = Vec::new();
234            for a in &left_lengths {
235                for b in &right_lengths {
236                    result.push(a + b);
237                }
238            }
239            result
240        }
241        // Non-distributing connectors: just collect lengths from both sides
242        Expression::And(bin) | Expression::Or(bin) => {
243            let mut result = predicate_lengths(&bin.left, dnf, max_distance, depth + 1);
244            result.extend(predicate_lengths(&bin.right, dnf, max_distance, depth + 1));
245            result
246        }
247        _ => vec![1], // Leaf predicate
248    }
249}
250
251/// Apply the distributive law to normalize an expression.
252///
253/// CNF: x OR (y AND z) => (x OR y) AND (x OR z)
254/// DNF: x AND (y OR z) => (x AND y) OR (x AND z)
255fn apply_distributive_law(
256    expression: &Expression,
257    dnf: bool,
258    max_distance: i64,
259    simplifier: &Simplifier,
260) -> NormalizeResult<Expression> {
261    if normalized(expression, dnf) {
262        return Ok(expression.clone());
263    }
264
265    let distance = normalization_distance(expression, dnf, max_distance);
266    if distance > max_distance {
267        return Err(NormalizeError::DistanceExceeded {
268            distance,
269            max: max_distance,
270        });
271    }
272
273    // Apply distributive law based on mode
274    let result = if dnf {
275        distribute_dnf(expression, simplifier)
276    } else {
277        distribute_cnf(expression, simplifier)
278    };
279
280    // Recursively apply until normalized
281    if !normalized(&result, dnf) {
282        apply_distributive_law(&result, dnf, max_distance, simplifier)
283    } else {
284        Ok(result)
285    }
286}
287
288/// Apply distributive law for CNF conversion.
289/// x OR (y AND z) => (x OR y) AND (x OR z)
290fn distribute_cnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
291    match expression {
292        Expression::Or(bin) => {
293            let left = distribute_cnf(&bin.left, simplifier);
294            let right = distribute_cnf(&bin.right, simplifier);
295
296            // Check if either side is an AND
297            if let Expression::And(and_bin) = &right {
298                // x OR (y AND z) => (x OR y) AND (x OR z)
299                let left_or_y = make_or(left.clone(), and_bin.left.clone());
300                let left_or_z = make_or(left, and_bin.right.clone());
301                return make_and(left_or_y, left_or_z);
302            }
303
304            if let Expression::And(and_bin) = &left {
305                // (y AND z) OR x => (y OR x) AND (z OR x)
306                let y_or_right = make_or(and_bin.left.clone(), right.clone());
307                let z_or_right = make_or(and_bin.right.clone(), right);
308                return make_and(y_or_right, z_or_right);
309            }
310
311            // No AND found, return simplified OR
312            make_or(left, right)
313        }
314        Expression::And(bin) => {
315            // Recurse into AND
316            let left = distribute_cnf(&bin.left, simplifier);
317            let right = distribute_cnf(&bin.right, simplifier);
318            make_and(left, right)
319        }
320        Expression::Paren(paren) => distribute_cnf(&paren.this, simplifier),
321        _ => expression.clone(),
322    }
323}
324
325/// Apply distributive law for DNF conversion.
326/// x AND (y OR z) => (x AND y) OR (x AND z)
327fn distribute_dnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
328    match expression {
329        Expression::And(bin) => {
330            let left = distribute_dnf(&bin.left, simplifier);
331            let right = distribute_dnf(&bin.right, simplifier);
332
333            // Check if either side is an OR
334            if let Expression::Or(or_bin) = &right {
335                // x AND (y OR z) => (x AND y) OR (x AND z)
336                let left_and_y = make_and(left.clone(), or_bin.left.clone());
337                let left_and_z = make_and(left, or_bin.right.clone());
338                return make_or(left_and_y, left_and_z);
339            }
340
341            if let Expression::Or(or_bin) = &left {
342                // (y OR z) AND x => (y AND x) OR (z AND x)
343                let y_and_right = make_and(or_bin.left.clone(), right.clone());
344                let z_and_right = make_and(or_bin.right.clone(), right);
345                return make_or(y_and_right, z_and_right);
346            }
347
348            // No OR found, return simplified AND
349            make_and(left, right)
350        }
351        Expression::Or(bin) => {
352            // Recurse into OR
353            let left = distribute_dnf(&bin.left, simplifier);
354            let right = distribute_dnf(&bin.right, simplifier);
355            make_or(left, right)
356        }
357        Expression::Paren(paren) => distribute_dnf(&paren.this, simplifier),
358        _ => expression.clone(),
359    }
360}
361
362// ============================================================================
363// Helper functions
364// ============================================================================
365
366/// Collect all connector (AND/OR) expressions from an AST
367fn collect_connectors(expression: &Expression) -> Vec<Expression> {
368    let mut result = Vec::new();
369    collect_connectors_recursive(expression, &mut result);
370    result
371}
372
373fn collect_connectors_recursive(expression: &Expression, result: &mut Vec<Expression>) {
374    match expression {
375        Expression::And(bin) => {
376            result.push(expression.clone());
377            collect_connectors_recursive(&bin.left, result);
378            collect_connectors_recursive(&bin.right, result);
379        }
380        Expression::Or(bin) => {
381            result.push(expression.clone());
382            collect_connectors_recursive(&bin.left, result);
383            collect_connectors_recursive(&bin.right, result);
384        }
385        Expression::Paren(paren) => {
386            collect_connectors_recursive(&paren.this, result);
387        }
388        _ => {}
389    }
390}
391
392/// Count the number of connector nodes in an expression
393fn count_connectors(expression: &Expression) -> usize {
394    match expression {
395        Expression::And(bin) | Expression::Or(bin) => {
396            1 + count_connectors(&bin.left) + count_connectors(&bin.right)
397        }
398        Expression::Paren(paren) => count_connectors(&paren.this),
399        _ => 0,
400    }
401}
402
403/// Unwrap parentheses from an expression
404fn unwrap_paren(expression: &Expression) -> &Expression {
405    match expression {
406        Expression::Paren(paren) => unwrap_paren(&paren.this),
407        _ => expression,
408    }
409}
410
411/// Check if two expressions are the same (simple identity check)
412fn is_same_expression(a: &Expression, b: &Expression) -> bool {
413    // Simple identity check - in a real implementation, this would be more sophisticated
414    std::ptr::eq(a as *const _, b as *const _) || format!("{:?}", a) == format!("{:?}", b)
415}
416
417/// Create an AND expression
418fn make_and(left: Expression, right: Expression) -> Expression {
419    Expression::And(Box::new(BinaryOp {
420        left,
421        right,
422        left_comments: vec![],
423        operator_comments: vec![],
424        trailing_comments: vec![],
425    }))
426}
427
428/// Create an OR expression
429fn make_or(left: Expression, right: Expression) -> Expression {
430    Expression::Or(Box::new(BinaryOp {
431        left,
432        right,
433        left_comments: vec![],
434        operator_comments: vec![],
435        trailing_comments: vec![],
436    }))
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parse `sql` as the WHERE predicate of a dummy SELECT and return that predicate.
    fn parse_predicate(sql: &str) -> Expression {
        let full = format!("SELECT 1 WHERE {}", sql);
        let stmt = parse(&full);
        if let Expression::Select(select) = stmt {
            if let Some(where_clause) = select.where_clause {
                return where_clause.this;
            }
        }
        panic!("Failed to extract predicate from: {}", sql);
    }

    #[test]
    fn test_normalized_cnf() {
        // (a OR b) AND (c OR d) is in CNF
        let expr = parse_predicate("(a OR b) AND (c OR d)");
        assert!(normalized(&expr, false)); // CNF
    }

    #[test]
    fn test_normalized_dnf() {
        // (a AND b) OR (c AND d) is in DNF
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalized(&expr, true)); // DNF
    }

    #[test]
    fn test_not_normalized_cnf() {
        // (a AND b) OR c is NOT in CNF (has AND under OR)
        let expr = parse_predicate("(a AND b) OR c");
        assert!(!normalized(&expr, false)); // Not CNF
    }

    #[test]
    fn test_not_normalized_dnf() {
        // (a OR b) AND c is NOT in DNF (has OR under AND)
        let expr = parse_predicate("(a OR b) AND c");
        assert!(!normalized(&expr, true)); // Not DNF
    }

    #[test]
    fn test_simple_literal_is_normalized() {
        // A single comparison has no connectors, so it satisfies both forms
        let expr = parse_predicate("a = 1");
        assert!(normalized(&expr, false)); // CNF
        assert!(normalized(&expr, true)); // DNF
    }

    #[test]
    fn test_normalization_distance_simple() {
        // Simple predicate should have low distance
        let expr = parse_predicate("a = 1");
        let distance = normalization_distance(&expr, false, 128);
        assert!(distance <= 0);
    }

    #[test]
    fn test_normalization_distance_complex() {
        // (a AND b) OR (c AND d) requires expansion
        let expr = parse_predicate("(a AND b) OR (c AND d)");
        let distance = normalization_distance(&expr, false, 128);
        assert!(distance > 0);
    }

    #[test]
    fn test_normalize_to_cnf() {
        // (x AND y) OR z => (x OR z) AND (y OR z)
        let expr = parse_predicate("(x AND y) OR z");
        let result = normalize(expr, false, 128).unwrap();

        // Result should be in CNF
        assert!(normalized(&result, false));
    }

    #[test]
    fn test_normalize_to_dnf() {
        // (x OR y) AND z => (x AND z) OR (y AND z)
        let expr = parse_predicate("(x OR y) AND z");
        let result = normalize(expr, true, 128).unwrap();

        // Result should be in DNF
        assert!(normalized(&result, true));
    }

    #[test]
    fn test_count_connectors() {
        let expr = parse_predicate("a AND b AND c");
        let count = count_connectors(&expr);
        assert_eq!(count, 2); // Two AND connectors
    }

    #[test]
    fn test_predicate_lengths() {
        // Simple case
        let expr = parse_predicate("a = 1");
        let lengths = predicate_lengths(&expr, false, 128, 0);
        assert_eq!(lengths, vec![1]);
    }

    #[test]
    fn test_predicate_lengths_distributes() {
        // In CNF mode (a AND b) OR c expands to (a OR c) AND (b OR c):
        // two clauses of two predicates each
        let expr = parse_predicate("(a AND b) OR c");
        let lengths = predicate_lengths(&expr, false, 128, 0);
        assert_eq!(lengths, vec![2, 2]);
    }
}