1use crate::ast::{BinaryOp, Expr, UnaryOp};
9use crate::optimizer::OptimizerPass;
10use crate::planner::LogicalPlan;
11use alloc::boxed::Box;
12
/// Optimizer pass that rewrites `NOT` expressions in filter predicates and
/// join conditions into equivalent forms without an explicit `NOT` where
/// possible: double negations are removed, De Morgan's laws are applied to
/// `AND`/`OR`, and comparison, `IN`, `BETWEEN`, `LIKE`, `MATCH` and
/// `IS NULL` operators are replaced by their negated counterparts.
///
/// The pass is stateless, so it is cheap to copy and has a trivial default.
#[derive(Debug, Clone, Copy, Default)]
pub struct NotSimplification;
15
16impl OptimizerPass for NotSimplification {
17 fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
18 self.simplify_plan(plan)
19 }
20
21 fn name(&self) -> &'static str {
22 "not_simplification"
23 }
24}
25
26impl NotSimplification {
27 fn simplify_plan(&self, plan: LogicalPlan) -> LogicalPlan {
28 match plan {
29 LogicalPlan::Filter { input, predicate } => {
30 let simplified_predicate = self.simplify_expr(predicate);
31 LogicalPlan::Filter {
32 input: Box::new(self.simplify_plan(*input)),
33 predicate: simplified_predicate,
34 }
35 }
36 LogicalPlan::Join {
37 left,
38 right,
39 condition,
40 join_type,
41 } => LogicalPlan::Join {
42 left: Box::new(self.simplify_plan(*left)),
43 right: Box::new(self.simplify_plan(*right)),
44 condition: self.simplify_expr(condition),
45 join_type,
46 },
47 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
48 input: Box::new(self.simplify_plan(*input)),
49 columns,
50 },
51 LogicalPlan::Aggregate {
52 input,
53 group_by,
54 aggregates,
55 } => LogicalPlan::Aggregate {
56 input: Box::new(self.simplify_plan(*input)),
57 group_by,
58 aggregates,
59 },
60 LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
61 input: Box::new(self.simplify_plan(*input)),
62 order_by,
63 },
64 LogicalPlan::Limit {
65 input,
66 limit,
67 offset,
68 } => LogicalPlan::Limit {
69 input: Box::new(self.simplify_plan(*input)),
70 limit,
71 offset,
72 },
73 LogicalPlan::CrossProduct { left, right } => LogicalPlan::CrossProduct {
74 left: Box::new(self.simplify_plan(*left)),
75 right: Box::new(self.simplify_plan(*right)),
76 },
77 other => other,
78 }
79 }
80
81 fn simplify_expr(&self, expr: Expr) -> Expr {
82 match expr {
83 Expr::UnaryOp {
85 op: UnaryOp::Not,
86 expr: inner,
87 } => {
88 let simplified_inner = self.simplify_expr(*inner);
89 self.simplify_not(simplified_inner)
90 }
91
92 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
94 left: Box::new(self.simplify_expr(*left)),
95 op,
96 right: Box::new(self.simplify_expr(*right)),
97 },
98
99 Expr::UnaryOp { op, expr } => Expr::UnaryOp {
101 op,
102 expr: Box::new(self.simplify_expr(*expr)),
103 },
104
105 Expr::Between { expr, low, high } => Expr::Between {
107 expr: Box::new(self.simplify_expr(*expr)),
108 low: Box::new(self.simplify_expr(*low)),
109 high: Box::new(self.simplify_expr(*high)),
110 },
111
112 Expr::NotBetween { expr, low, high } => Expr::NotBetween {
113 expr: Box::new(self.simplify_expr(*expr)),
114 low: Box::new(self.simplify_expr(*low)),
115 high: Box::new(self.simplify_expr(*high)),
116 },
117
118 Expr::In { expr, list } => Expr::In {
120 expr: Box::new(self.simplify_expr(*expr)),
121 list: list.into_iter().map(|e| self.simplify_expr(e)).collect(),
122 },
123
124 Expr::NotIn { expr, list } => Expr::NotIn {
125 expr: Box::new(self.simplify_expr(*expr)),
126 list: list.into_iter().map(|e| self.simplify_expr(e)).collect(),
127 },
128
129 Expr::Like { expr, pattern } => Expr::Like {
131 expr: Box::new(self.simplify_expr(*expr)),
132 pattern,
133 },
134
135 Expr::NotLike { expr, pattern } => Expr::NotLike {
136 expr: Box::new(self.simplify_expr(*expr)),
137 pattern,
138 },
139
140 Expr::Match { expr, pattern } => Expr::Match {
142 expr: Box::new(self.simplify_expr(*expr)),
143 pattern,
144 },
145
146 Expr::NotMatch { expr, pattern } => Expr::NotMatch {
147 expr: Box::new(self.simplify_expr(*expr)),
148 pattern,
149 },
150
151 Expr::Function { name, args } => Expr::Function {
153 name,
154 args: args.into_iter().map(|e| self.simplify_expr(e)).collect(),
155 },
156
157 Expr::Aggregate {
159 func,
160 expr,
161 distinct,
162 } => Expr::Aggregate {
163 func,
164 expr: expr.map(|e| Box::new(self.simplify_expr(*e))),
165 distinct,
166 },
167
168 other => other,
170 }
171 }
172
173 fn simplify_not(&self, inner: Expr) -> Expr {
175 match inner {
176 Expr::UnaryOp {
178 op: UnaryOp::Not,
179 expr,
180 } => *expr,
181
182 Expr::BinaryOp {
184 left,
185 op: BinaryOp::And,
186 right,
187 } => Expr::BinaryOp {
188 left: Box::new(self.simplify_not(*left)),
189 op: BinaryOp::Or,
190 right: Box::new(self.simplify_not(*right)),
191 },
192
193 Expr::BinaryOp {
195 left,
196 op: BinaryOp::Or,
197 right,
198 } => Expr::BinaryOp {
199 left: Box::new(self.simplify_not(*left)),
200 op: BinaryOp::And,
201 right: Box::new(self.simplify_not(*right)),
202 },
203
204 Expr::BinaryOp {
206 left,
207 op: BinaryOp::Eq,
208 right,
209 } => Expr::BinaryOp {
210 left,
211 op: BinaryOp::Ne,
212 right,
213 },
214
215 Expr::BinaryOp {
217 left,
218 op: BinaryOp::Ne,
219 right,
220 } => Expr::BinaryOp {
221 left,
222 op: BinaryOp::Eq,
223 right,
224 },
225
226 Expr::BinaryOp {
228 left,
229 op: BinaryOp::Lt,
230 right,
231 } => Expr::BinaryOp {
232 left,
233 op: BinaryOp::Ge,
234 right,
235 },
236
237 Expr::BinaryOp {
239 left,
240 op: BinaryOp::Le,
241 right,
242 } => Expr::BinaryOp {
243 left,
244 op: BinaryOp::Gt,
245 right,
246 },
247
248 Expr::BinaryOp {
250 left,
251 op: BinaryOp::Gt,
252 right,
253 } => Expr::BinaryOp {
254 left,
255 op: BinaryOp::Le,
256 right,
257 },
258
259 Expr::BinaryOp {
261 left,
262 op: BinaryOp::Ge,
263 right,
264 } => Expr::BinaryOp {
265 left,
266 op: BinaryOp::Lt,
267 right,
268 },
269
270 Expr::In { expr, list } => Expr::NotIn { expr, list },
272
273 Expr::NotIn { expr, list } => Expr::In { expr, list },
275
276 Expr::Between { expr, low, high } => Expr::NotBetween { expr, low, high },
278
279 Expr::NotBetween { expr, low, high } => Expr::Between { expr, low, high },
281
282 Expr::Like { expr, pattern } => Expr::NotLike { expr, pattern },
284
285 Expr::NotLike { expr, pattern } => Expr::Like { expr, pattern },
287
288 Expr::Match { expr, pattern } => Expr::NotMatch { expr, pattern },
290
291 Expr::NotMatch { expr, pattern } => Expr::Match { expr, pattern },
293
294 Expr::UnaryOp {
296 op: UnaryOp::IsNull,
297 expr,
298 } => Expr::UnaryOp {
299 op: UnaryOp::IsNotNull,
300 expr,
301 },
302
303 Expr::UnaryOp {
305 op: UnaryOp::IsNotNull,
306 expr,
307 } => Expr::UnaryOp {
308 op: UnaryOp::IsNull,
309 expr,
310 },
311
312 other => Expr::UnaryOp {
314 op: UnaryOp::Not,
315 expr: Box::new(other),
316 },
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use cynos_core::Value;

    /// Builds `left <op> right` directly from the `Expr::BinaryOp` variant.
    fn binary(left: Expr, op: BinaryOp, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    #[test]
    fn test_double_negation() {
        let pass = NotSimplification;

        let expr = Expr::not(Expr::not(Expr::column("t", "c", 0)));
        let simplified = pass.simplify_expr(expr);

        assert!(matches!(simplified, Expr::Column(_)));
    }

    #[test]
    fn test_triple_negation_keeps_single_not() {
        let pass = NotSimplification;

        let expr = Expr::not(Expr::not(Expr::not(Expr::column("t", "c", 0))));

        match pass.simplify_expr(expr) {
            Expr::UnaryOp {
                op: UnaryOp::Not,
                expr,
            } => assert!(matches!(*expr, Expr::Column(_))),
            _ => panic!("expected a single NOT over the column"),
        }
    }

    #[test]
    fn test_not_eq_to_ne() {
        let pass = NotSimplification;

        let expr = Expr::not(Expr::eq(
            Expr::column("t", "a", 0),
            Expr::literal(1i64),
        ));
        let simplified = pass.simplify_expr(expr);

        assert!(matches!(
            simplified,
            Expr::BinaryOp {
                op: BinaryOp::Ne,
                ..
            }
        ));
    }

    #[test]
    fn test_not_comparisons_are_flipped() {
        let pass = NotSimplification;
        // NOT (t.a <op> 1), simplified.
        let negate = |op: BinaryOp| {
            pass.simplify_expr(Expr::not(binary(
                Expr::column("t", "a", 0),
                op,
                Expr::literal(1i64),
            )))
        };

        assert!(matches!(negate(BinaryOp::Ne), Expr::BinaryOp { op: BinaryOp::Eq, .. }));
        assert!(matches!(negate(BinaryOp::Lt), Expr::BinaryOp { op: BinaryOp::Ge, .. }));
        assert!(matches!(negate(BinaryOp::Le), Expr::BinaryOp { op: BinaryOp::Gt, .. }));
        assert!(matches!(negate(BinaryOp::Gt), Expr::BinaryOp { op: BinaryOp::Le, .. }));
        assert!(matches!(negate(BinaryOp::Ge), Expr::BinaryOp { op: BinaryOp::Lt, .. }));
    }

    #[test]
    fn test_not_is_null_to_is_not_null() {
        let pass = NotSimplification;

        let expr = Expr::not(Expr::UnaryOp {
            op: UnaryOp::IsNull,
            expr: Box::new(Expr::column("t", "c", 0)),
        });

        assert!(matches!(
            pass.simplify_expr(expr),
            Expr::UnaryOp {
                op: UnaryOp::IsNotNull,
                ..
            }
        ));
    }

    #[test]
    fn test_not_in_to_not_in() {
        let pass = NotSimplification;

        let expr = Expr::not(Expr::in_list(
            Expr::column("t", "c", 0),
            alloc::vec![Value::Int64(1), Value::Int64(2)],
        ));
        let simplified = pass.simplify_expr(expr);

        assert!(matches!(simplified, Expr::NotIn { .. }));
    }

    #[test]
    fn test_double_not_in_restores_in() {
        let pass = NotSimplification;

        let expr = Expr::not(Expr::not(Expr::in_list(
            Expr::column("t", "c", 0),
            alloc::vec![Value::Int64(1), Value::Int64(2)],
        )));

        assert!(matches!(pass.simplify_expr(expr), Expr::In { .. }));
    }

    #[test]
    fn test_not_between_to_not_between() {
        let pass = NotSimplification;

        let expr = Expr::not(Expr::between(
            Expr::column("t", "c", 0),
            Expr::literal(1i64),
            Expr::literal(10i64),
        ));
        let simplified = pass.simplify_expr(expr);

        assert!(matches!(simplified, Expr::NotBetween { .. }));
    }

    #[test]
    fn test_de_morgan_and() {
        let pass = NotSimplification;

        let expr = Expr::not(Expr::and(
            Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64)),
            Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64)),
        ));

        // NOT (a = 1 AND b = 2)  =>  a <> 1 OR b <> 2: both operands must
        // be negated, not just the connective.
        match pass.simplify_expr(expr) {
            Expr::BinaryOp {
                left,
                op: BinaryOp::Or,
                right,
            } => {
                assert!(matches!(*left, Expr::BinaryOp { op: BinaryOp::Ne, .. }));
                assert!(matches!(*right, Expr::BinaryOp { op: BinaryOp::Ne, .. }));
            }
            _ => panic!("expected OR at the root"),
        }
    }

    #[test]
    fn test_de_morgan_or() {
        let pass = NotSimplification;

        let expr = Expr::not(Expr::or(
            Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64)),
            Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64)),
        ));

        // NOT (a = 1 OR b = 2)  =>  a <> 1 AND b <> 2
        match pass.simplify_expr(expr) {
            Expr::BinaryOp {
                left,
                op: BinaryOp::And,
                right,
            } => {
                assert!(matches!(*left, Expr::BinaryOp { op: BinaryOp::Ne, .. }));
                assert!(matches!(*right, Expr::BinaryOp { op: BinaryOp::Ne, .. }));
            }
            _ => panic!("expected AND at the root"),
        }
    }
}