llkv_expr/
normalization.rs

//! Expression normalization logic.
//!
//! This module provides functions to normalize expressions, such as:
//! - Flattening nested AND/OR expressions.
//! - Pushing NOT operators down toward the leaves (De Morgan's laws).
//! - Normalizing comparisons (e.g., converting `Compare` to `Pred` where possible).

8use crate::expr::{CompareOp, Expr, Filter, Operator, ScalarExpr};
9use crate::literal::Literal;
10
11/// Normalize a predicate expression.
12///
13/// This applies several simplifications:
14/// - Flattens nested AND/OR chains.
15/// - Pushes NOT down to the leaves (De Morgan's laws).
16/// - Converts `Compare` expressions to `Pred` (Filter) when one side is a column and the other is a literal.
17/// - Simplifies boolean literals.
18pub fn normalize_predicate<'expr, F: Clone>(expr: Expr<'expr, F>) -> Expr<'expr, F> {
19    normalize_expr(expr)
20}
21
22fn normalize_expr<'expr, F: Clone>(expr: Expr<'expr, F>) -> Expr<'expr, F> {
23    match expr {
24        Expr::And(children) => {
25            let mut normalized = Vec::with_capacity(children.len());
26            for child in children {
27                let child = normalize_expr(child);
28                match child {
29                    Expr::And(nested) => normalized.extend(nested),
30                    other => normalized.push(other),
31                }
32            }
33            Expr::And(normalized)
34        }
35        Expr::Or(children) => {
36            let mut normalized = Vec::with_capacity(children.len());
37            for child in children {
38                let child = normalize_expr(child);
39                match child {
40                    Expr::Or(nested) => normalized.extend(nested),
41                    other => normalized.push(other),
42                }
43            }
44            Expr::Or(normalized)
45        }
46        Expr::Not(inner) => normalize_negated(*inner),
47        Expr::Compare { left, op, right } => normalize_compare(left, op, right),
48        other => other,
49    }
50}
51
52fn normalize_compare<'expr, F: Clone>(
53    left: ScalarExpr<F>,
54    op: CompareOp,
55    right: ScalarExpr<F>,
56) -> Expr<'expr, F> {
57    match (left, right) {
58        (ScalarExpr::Column(field_id), ScalarExpr::Literal(lit))
59            if !matches!(lit, Literal::Null) =>
60        {
61            match op {
62                CompareOp::Eq => Expr::Pred(Filter {
63                    field_id,
64                    op: Operator::Equals(lit),
65                }),
66                CompareOp::Gt => Expr::Pred(Filter {
67                    field_id,
68                    op: Operator::GreaterThan(lit),
69                }),
70                CompareOp::GtEq => Expr::Pred(Filter {
71                    field_id,
72                    op: Operator::GreaterThanOrEquals(lit),
73                }),
74                CompareOp::Lt => Expr::Pred(Filter {
75                    field_id,
76                    op: Operator::LessThan(lit),
77                }),
78                CompareOp::LtEq => Expr::Pred(Filter {
79                    field_id,
80                    op: Operator::LessThanOrEquals(lit),
81                }),
82                CompareOp::NotEq => Expr::Not(Box::new(Expr::Pred(Filter {
83                    field_id,
84                    op: Operator::Equals(lit),
85                }))),
86            }
87        }
88        (ScalarExpr::Literal(lit), ScalarExpr::Column(field_id))
89            if !matches!(lit, Literal::Null) =>
90        {
91            match op {
92                CompareOp::Eq => Expr::Pred(Filter {
93                    field_id,
94                    op: Operator::Equals(lit),
95                }),
96                CompareOp::Gt => Expr::Pred(Filter {
97                    field_id,
98                    op: Operator::LessThan(lit),
99                }),
100                CompareOp::GtEq => Expr::Pred(Filter {
101                    field_id,
102                    op: Operator::LessThanOrEquals(lit),
103                }),
104                CompareOp::Lt => Expr::Pred(Filter {
105                    field_id,
106                    op: Operator::GreaterThan(lit),
107                }),
108                CompareOp::LtEq => Expr::Pred(Filter {
109                    field_id,
110                    op: Operator::GreaterThanOrEquals(lit),
111                }),
112                CompareOp::NotEq => Expr::Not(Box::new(Expr::Pred(Filter {
113                    field_id,
114                    op: Operator::Equals(lit),
115                }))),
116            }
117        }
118        (left, right) => Expr::Compare { left, op, right },
119    }
120}
121
122fn normalize_negated<'expr, F: Clone>(inner: Expr<'expr, F>) -> Expr<'expr, F> {
123    match inner {
124        Expr::Not(nested) => normalize_expr(*nested),
125        Expr::And(children) => {
126            let mapped = children
127                .into_iter()
128                .map(|child| normalize_expr(Expr::Not(Box::new(child))))
129                .collect();
130            Expr::Or(mapped)
131        }
132        Expr::Or(children) => {
133            let mapped = children
134                .into_iter()
135                .map(|child| normalize_expr(Expr::Not(Box::new(child))))
136                .collect();
137            Expr::And(mapped)
138        }
139        Expr::Compare { left, op, right } => {
140            let negated_op = match op {
141                CompareOp::Eq => CompareOp::NotEq,
142                CompareOp::NotEq => CompareOp::Eq,
143                CompareOp::Gt => CompareOp::LtEq,
144                CompareOp::GtEq => CompareOp::Lt,
145                CompareOp::Lt => CompareOp::GtEq,
146                CompareOp::LtEq => CompareOp::Gt,
147            };
148            normalize_compare(left, negated_op, right)
149        }
150        Expr::Literal(value) => Expr::Literal(!value),
151        Expr::IsNull { expr, negated } => Expr::IsNull {
152            expr,
153            negated: !negated,
154        },
155        other => Expr::Not(Box::new(normalize_expr(other))),
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Literal;

    /// Plain integer field identifier used by these tests.
    type TestId = usize;

    /// `NOT (col >= 5 AND col <= NULL)` must become an OR of the two negated
    /// comparisons. The literal-backed one is lowered to a `Pred`, while the
    /// NULL-backed one stays a `Compare` with its operator flipped.
    #[test]
    fn normalize_not_between_expands_to_or() {
        let field: TestId = 7;
        let lower_bound = Expr::Compare {
            left: ScalarExpr::Column(field),
            op: CompareOp::GtEq,
            right: ScalarExpr::Literal(Literal::Int128(5)),
        };
        let upper_bound = Expr::Compare {
            left: ScalarExpr::Column(field),
            op: CompareOp::LtEq,
            right: ScalarExpr::Literal(Literal::Null),
        };
        let negated_between = Expr::Not(Box::new(Expr::And(vec![lower_bound, upper_bound])));

        let Expr::Or(branches) = normalize_predicate(negated_between) else {
            panic!("expected OR after normalization");
        };
        assert_eq!(branches.len(), 2);

        let first = &branches[0];
        if !matches!(
            first,
            Expr::Pred(Filter {
                op: Operator::LessThan(_),
                ..
            })
        ) {
            panic!("left branch should be Pred(LessThan), got {first:?}");
        }

        let second = &branches[1];
        if !matches!(
            second,
            Expr::Compare {
                op: CompareOp::Gt,
                right: ScalarExpr::Literal(Literal::Null),
                ..
            }
        ) {
            panic!("right branch should be Compare(Gt, Null), got {second:?}");
        }
    }

    /// Negating a boolean literal folds it to the opposite literal.
    #[test]
    fn normalize_flips_literal_bool() {
        let negated_true = Expr::<TestId>::Not(Box::new(Expr::Literal(true)));
        let folded = normalize_predicate(negated_true);
        assert!(matches!(folded, Expr::Literal(false)));
    }
}