1use crate::ast::{BinaryOp, Expr};
21use crate::optimizer::OptimizerPass;
22use crate::planner::LogicalPlan;
23use alloc::boxed::Box;
24use alloc::vec::Vec;
25
/// Optimizer pass that splits conjunctive (`AND`) filter predicates.
///
/// A `Filter` whose predicate is `a AND b AND c` is rewritten into a chain of
/// nested `Filter` nodes, one per conjunct. Predicates that are not `AND`
/// expressions (for example `OR`) are kept as a single filter.
///
/// The type is a stateless unit struct, so it is cheap to copy and can be
/// created with [`Default::default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct AndPredicatePass;
28
29impl OptimizerPass for AndPredicatePass {
30 fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
31 self.traverse(plan)
32 }
33
34 fn name(&self) -> &'static str {
35 "and_predicate"
36 }
37}
38
39impl AndPredicatePass {
40 fn traverse(&self, plan: LogicalPlan) -> LogicalPlan {
42 match plan {
43 LogicalPlan::Filter { input, predicate } => {
44 let optimized_input = self.traverse(*input);
46
47 let predicates = self.break_and_predicate(predicate);
49
50 self.create_filter_chain(optimized_input, predicates)
52 }
53
54 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
55 input: Box::new(self.traverse(*input)),
56 columns,
57 },
58
59 LogicalPlan::Join {
60 left,
61 right,
62 condition,
63 join_type,
64 } => LogicalPlan::Join {
65 left: Box::new(self.traverse(*left)),
66 right: Box::new(self.traverse(*right)),
67 condition,
68 join_type,
69 },
70
71 LogicalPlan::Aggregate {
72 input,
73 group_by,
74 aggregates,
75 } => LogicalPlan::Aggregate {
76 input: Box::new(self.traverse(*input)),
77 group_by,
78 aggregates,
79 },
80
81 LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
82 input: Box::new(self.traverse(*input)),
83 order_by,
84 },
85
86 LogicalPlan::Limit {
87 input,
88 limit,
89 offset,
90 } => LogicalPlan::Limit {
91 input: Box::new(self.traverse(*input)),
92 limit,
93 offset,
94 },
95
96 LogicalPlan::CrossProduct { left, right } => LogicalPlan::CrossProduct {
97 left: Box::new(self.traverse(*left)),
98 right: Box::new(self.traverse(*right)),
99 },
100
101 LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
102 left: Box::new(self.traverse(*left)),
103 right: Box::new(self.traverse(*right)),
104 all,
105 },
106
107 plan @ (LogicalPlan::Scan { .. }
109 | LogicalPlan::IndexScan { .. }
110 | LogicalPlan::IndexGet { .. }
111 | LogicalPlan::IndexInGet { .. }
112 | LogicalPlan::GinIndexScan { .. }
113 | LogicalPlan::GinIndexScanMulti { .. }
114 | LogicalPlan::Empty) => plan,
115 }
116 }
117
118 fn break_and_predicate(&self, predicate: Expr) -> Vec<Expr> {
123 match predicate {
124 Expr::BinaryOp {
125 left,
126 op: BinaryOp::And,
127 right,
128 } => {
129 let mut result = self.break_and_predicate(*left);
130 result.extend(self.break_and_predicate(*right));
131 result
132 }
133 other => alloc::vec![other],
135 }
136 }
137
138 fn create_filter_chain(&self, input: LogicalPlan, predicates: Vec<Expr>) -> LogicalPlan {
141 if predicates.is_empty() {
142 return input;
143 }
144
145 let mut result = input;
147 for predicate in predicates.into_iter().rev() {
148 result = LogicalPlan::Filter {
149 input: Box::new(result),
150 predicate,
151 };
152 }
153 result
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Expr;

    /// Walks down consecutive `Filter` nodes starting at `plan`.
    ///
    /// Returns how many filters were crossed and the first non-filter node.
    fn peel_filters(plan: &LogicalPlan) -> (usize, &LogicalPlan) {
        let mut depth = 0;
        let mut node = plan;
        while let LogicalPlan::Filter { input, .. } = node {
            depth += 1;
            node = input;
        }
        (depth, node)
    }

    /// A filter with a single non-`AND` predicate stays one filter over the scan.
    #[test]
    fn test_simple_filter_unchanged() {
        let plan = LogicalPlan::filter(
            LogicalPlan::scan("users"),
            Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64)),
        );

        let result = AndPredicatePass.optimize(plan);

        let (depth, base) = peel_filters(&result);
        assert_eq!(depth, 1);
        assert!(matches!(base, LogicalPlan::Scan { .. }));
    }

    /// `a AND b` becomes two stacked filters directly above the scan.
    #[test]
    fn test_and_predicate_split() {
        let id_is_one = Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64));
        let name_is_alice = Expr::eq(Expr::column("users", "name", 1), Expr::literal("Alice"));
        let plan = LogicalPlan::filter(
            LogicalPlan::scan("users"),
            Expr::and(id_is_one, name_is_alice),
        );

        let result = AndPredicatePass.optimize(plan);

        let (depth, base) = peel_filters(&result);
        assert_eq!(depth, 2);
        assert!(matches!(base, LogicalPlan::Scan { .. }));
    }

    /// `(a AND b) AND c` is flattened into three stacked filters.
    #[test]
    fn test_nested_and_predicate_flattened() {
        let a = Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64));
        let b = Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64));
        let c = Expr::eq(Expr::column("t", "c", 2), Expr::literal(3i64));
        let plan = LogicalPlan::filter(LogicalPlan::scan("t"), Expr::and(Expr::and(a, b), c));

        let result = AndPredicatePass.optimize(plan);

        let (depth, base) = peel_filters(&result);
        assert_eq!(depth, 3);
        assert!(matches!(base, LogicalPlan::Scan { .. }));
    }

    /// An `OR` predicate is not split and keeps its operator.
    #[test]
    fn test_or_predicate_preserved() {
        let a = Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64));
        let b = Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64));
        let plan = LogicalPlan::filter(LogicalPlan::scan("t"), Expr::or(a, b));

        let result = AndPredicatePass.optimize(plan);

        let (depth, base) = peel_filters(&result);
        assert_eq!(depth, 1);
        assert!(matches!(base, LogicalPlan::Scan { .. }));
        assert!(matches!(
            &result,
            LogicalPlan::Filter {
                predicate: Expr::BinaryOp {
                    op: BinaryOp::Or,
                    ..
                },
                ..
            }
        ));
    }

    /// `a AND (b OR c)` splits only at the `AND`, giving two filters.
    #[test]
    fn test_mixed_and_or_predicate() {
        let a = Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64));
        let b = Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64));
        let c = Expr::eq(Expr::column("t", "c", 2), Expr::literal(3i64));
        let plan = LogicalPlan::filter(LogicalPlan::scan("t"), Expr::and(a, Expr::or(b, c)));

        let result = AndPredicatePass.optimize(plan);

        let (depth, _) = peel_filters(&result);
        assert_eq!(depth, 2);
    }

    /// Exercises `break_and_predicate` directly on flat, nested and non-`AND` input.
    #[test]
    fn test_break_and_predicate() {
        let pass = AndPredicatePass;

        let a = Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64));
        let b = Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64));
        let c = Expr::eq(Expr::column("t", "c", 2), Expr::literal(3i64));

        let flat = pass.break_and_predicate(Expr::and(a.clone(), b.clone()));
        assert_eq!(flat.len(), 2);

        let nested = pass.break_and_predicate(Expr::and(Expr::and(a.clone(), b), c));
        assert_eq!(nested.len(), 3);

        let single = pass.break_and_predicate(a);
        assert_eq!(single.len(), 1);
    }
}