1use crate::expr::{CompareOp, Expr, Filter, Operator, ScalarExpr};
9use crate::literal::Literal;
10
11pub fn normalize_predicate<'expr, F: Clone>(expr: Expr<'expr, F>) -> Expr<'expr, F> {
19 normalize_expr(expr)
20}
21
22fn normalize_expr<'expr, F: Clone>(expr: Expr<'expr, F>) -> Expr<'expr, F> {
23 match expr {
24 Expr::And(children) => {
25 let mut normalized = Vec::with_capacity(children.len());
26 for child in children {
27 let child = normalize_expr(child);
28 match child {
29 Expr::And(nested) => normalized.extend(nested),
30 other => normalized.push(other),
31 }
32 }
33 Expr::And(normalized)
34 }
35 Expr::Or(children) => {
36 let mut normalized = Vec::with_capacity(children.len());
37 for child in children {
38 let child = normalize_expr(child);
39 match child {
40 Expr::Or(nested) => normalized.extend(nested),
41 other => normalized.push(other),
42 }
43 }
44 Expr::Or(normalized)
45 }
46 Expr::Not(inner) => normalize_negated(*inner),
47 Expr::Compare { left, op, right } => normalize_compare(left, op, right),
48 other => other,
49 }
50}
51
52fn normalize_compare<'expr, F: Clone>(
53 left: ScalarExpr<F>,
54 op: CompareOp,
55 right: ScalarExpr<F>,
56) -> Expr<'expr, F> {
57 match (left, right) {
58 (ScalarExpr::Column(field_id), ScalarExpr::Literal(lit))
59 if !matches!(lit, Literal::Null) =>
60 {
61 match op {
62 CompareOp::Eq => Expr::Pred(Filter {
63 field_id,
64 op: Operator::Equals(lit),
65 }),
66 CompareOp::Gt => Expr::Pred(Filter {
67 field_id,
68 op: Operator::GreaterThan(lit),
69 }),
70 CompareOp::GtEq => Expr::Pred(Filter {
71 field_id,
72 op: Operator::GreaterThanOrEquals(lit),
73 }),
74 CompareOp::Lt => Expr::Pred(Filter {
75 field_id,
76 op: Operator::LessThan(lit),
77 }),
78 CompareOp::LtEq => Expr::Pred(Filter {
79 field_id,
80 op: Operator::LessThanOrEquals(lit),
81 }),
82 CompareOp::NotEq => Expr::Not(Box::new(Expr::Pred(Filter {
83 field_id,
84 op: Operator::Equals(lit),
85 }))),
86 }
87 }
88 (ScalarExpr::Literal(lit), ScalarExpr::Column(field_id))
89 if !matches!(lit, Literal::Null) =>
90 {
91 match op {
92 CompareOp::Eq => Expr::Pred(Filter {
93 field_id,
94 op: Operator::Equals(lit),
95 }),
96 CompareOp::Gt => Expr::Pred(Filter {
97 field_id,
98 op: Operator::LessThan(lit),
99 }),
100 CompareOp::GtEq => Expr::Pred(Filter {
101 field_id,
102 op: Operator::LessThanOrEquals(lit),
103 }),
104 CompareOp::Lt => Expr::Pred(Filter {
105 field_id,
106 op: Operator::GreaterThan(lit),
107 }),
108 CompareOp::LtEq => Expr::Pred(Filter {
109 field_id,
110 op: Operator::GreaterThanOrEquals(lit),
111 }),
112 CompareOp::NotEq => Expr::Not(Box::new(Expr::Pred(Filter {
113 field_id,
114 op: Operator::Equals(lit),
115 }))),
116 }
117 }
118 (left, right) => Expr::Compare { left, op, right },
119 }
120}
121
122fn normalize_negated<'expr, F: Clone>(inner: Expr<'expr, F>) -> Expr<'expr, F> {
123 match inner {
124 Expr::Not(nested) => normalize_expr(*nested),
125 Expr::And(children) => {
126 let mapped = children
127 .into_iter()
128 .map(|child| normalize_expr(Expr::Not(Box::new(child))))
129 .collect();
130 Expr::Or(mapped)
131 }
132 Expr::Or(children) => {
133 let mapped = children
134 .into_iter()
135 .map(|child| normalize_expr(Expr::Not(Box::new(child))))
136 .collect();
137 Expr::And(mapped)
138 }
139 Expr::Compare { left, op, right } => {
140 let negated_op = match op {
141 CompareOp::Eq => CompareOp::NotEq,
142 CompareOp::NotEq => CompareOp::Eq,
143 CompareOp::Gt => CompareOp::LtEq,
144 CompareOp::GtEq => CompareOp::Lt,
145 CompareOp::Lt => CompareOp::GtEq,
146 CompareOp::LtEq => CompareOp::Gt,
147 };
148 normalize_compare(left, negated_op, right)
149 }
150 Expr::Literal(value) => Expr::Literal(!value),
151 Expr::IsNull { expr, negated } => Expr::IsNull {
152 expr,
153 negated: !negated,
154 },
155 other => Expr::Not(Box::new(normalize_expr(other))),
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Literal;

    type TestId = usize;

    /// `NOT (col >= 5 AND col <= NULL)` must come out as a two-armed `Or`:
    /// the first arm is rewritten to a `LessThan` filter, while the arm with
    /// the `Null` bound stays a plain `Compare` with the complemented
    /// operator (`Gt`).
    #[test]
    fn normalize_not_between_expands_to_or() {
        let field: TestId = 7;
        let column = ScalarExpr::Column(field);

        let at_least_lower = Expr::Compare {
            left: column.clone(),
            op: CompareOp::GtEq,
            right: ScalarExpr::Literal(Literal::Int128(5)),
        };
        let at_most_upper = Expr::Compare {
            left: column,
            op: CompareOp::LtEq,
            right: ScalarExpr::Literal(Literal::Null),
        };
        let negated_between =
            Expr::Not(Box::new(Expr::And(vec![at_least_lower, at_most_upper])));

        let children = match normalize_predicate(negated_between) {
            Expr::Or(children) => children,
            _ => panic!("expected OR after normalization"),
        };
        assert_eq!(children.len(), 2);

        let other = &children[0];
        assert!(
            matches!(
                other,
                Expr::Pred(Filter {
                    op: Operator::LessThan(_),
                    ..
                })
            ),
            "left branch should be Pred(LessThan), got {other:?}"
        );

        let other = &children[1];
        assert!(
            matches!(
                other,
                Expr::Compare {
                    op: CompareOp::Gt,
                    right: ScalarExpr::Literal(Literal::Null),
                    ..
                }
            ),
            "right branch should be Compare(Gt, Null), got {other:?}"
        );
    }

    /// Negating a boolean literal folds into the opposite literal.
    #[test]
    fn normalize_flips_literal_bool() {
        let negated_true = Expr::<TestId>::Not(Box::new(Expr::Literal(true)));
        let normalized = normalize_predicate(negated_true);
        assert!(matches!(normalized, Expr::Literal(false)));
    }
}