Skip to main content

cynos_query/optimizer/
and_predicate.rs

//! AND predicate pass: breaks down AND predicates into chained Filter nodes.
//!
//! This pass transforms a single Filter node with a conjunctive (AND) predicate
//! into multiple chained Filter nodes, each holding one conjunct. Expressions
//! that are not a top-level AND (such as OR expressions) are kept whole. The
//! first conjunct becomes the outermost Filter.
//!
//! Example:
//! ```text
//! Filter(a AND b AND c)    =>    Filter(a)
//!        |                          |
//!      Scan                      Filter(b)
//!                                   |
//!                                Filter(c)
//!                                   |
//!                                 Scan
//! ```
//!
//! This transformation enables other optimization passes (like index selection)
//! to work on individual predicates more effectively.

20use crate::ast::{BinaryOp, Expr};
21use crate::optimizer::OptimizerPass;
22use crate::planner::LogicalPlan;
23use alloc::boxed::Box;
24use alloc::vec::Vec;
25
/// Optimizer pass that splits each conjunctive (`AND`) filter predicate into a
/// chain of single-conjunct `Filter` nodes.
///
/// The pass carries no state, so it is a unit struct that is cheap to copy and
/// can be constructed either as `AndPredicatePass` or via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AndPredicatePass;
28
29impl OptimizerPass for AndPredicatePass {
30    fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
31        self.traverse(plan)
32    }
33
34    fn name(&self) -> &'static str {
35        "and_predicate"
36    }
37}
38
39impl AndPredicatePass {
40    /// Recursively traverses the plan tree and transforms AND predicates.
41    fn traverse(&self, plan: LogicalPlan) -> LogicalPlan {
42        match plan {
43            LogicalPlan::Filter { input, predicate } => {
44                // First, recursively optimize the input
45                let optimized_input = self.traverse(*input);
46
47                // Break down the AND predicate into components
48                let predicates = self.break_and_predicate(predicate);
49
50                // Create a chain of Filter nodes
51                self.create_filter_chain(optimized_input, predicates)
52            }
53
54            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
55                input: Box::new(self.traverse(*input)),
56                columns,
57            },
58
59            LogicalPlan::Join {
60                left,
61                right,
62                condition,
63                join_type,
64            } => LogicalPlan::Join {
65                left: Box::new(self.traverse(*left)),
66                right: Box::new(self.traverse(*right)),
67                condition,
68                join_type,
69            },
70
71            LogicalPlan::Aggregate {
72                input,
73                group_by,
74                aggregates,
75            } => LogicalPlan::Aggregate {
76                input: Box::new(self.traverse(*input)),
77                group_by,
78                aggregates,
79            },
80
81            LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
82                input: Box::new(self.traverse(*input)),
83                order_by,
84            },
85
86            LogicalPlan::Limit {
87                input,
88                limit,
89                offset,
90            } => LogicalPlan::Limit {
91                input: Box::new(self.traverse(*input)),
92                limit,
93                offset,
94            },
95
96            LogicalPlan::CrossProduct { left, right } => LogicalPlan::CrossProduct {
97                left: Box::new(self.traverse(*left)),
98                right: Box::new(self.traverse(*right)),
99            },
100
101            LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
102                left: Box::new(self.traverse(*left)),
103                right: Box::new(self.traverse(*right)),
104                all,
105            },
106
107            // Leaf nodes - no transformation needed
108            plan @ (LogicalPlan::Scan { .. }
109            | LogicalPlan::IndexScan { .. }
110            | LogicalPlan::IndexGet { .. }
111            | LogicalPlan::IndexInGet { .. }
112            | LogicalPlan::GinIndexScan { .. }
113            | LogicalPlan::GinIndexScanMulti { .. }
114            | LogicalPlan::Empty) => plan,
115        }
116    }
117
118    /// Recursively breaks down an AND predicate into its components.
119    /// OR predicates and other predicate types are left unchanged.
120    ///
121    /// Example: (a AND (b AND c)) AND (d OR e) becomes [a, b, c, (d OR e)]
122    fn break_and_predicate(&self, predicate: Expr) -> Vec<Expr> {
123        match predicate {
124            Expr::BinaryOp {
125                left,
126                op: BinaryOp::And,
127                right,
128            } => {
129                let mut result = self.break_and_predicate(*left);
130                result.extend(self.break_and_predicate(*right));
131                result
132            }
133            // Non-AND predicates are returned as-is
134            other => alloc::vec![other],
135        }
136    }
137
138    /// Creates a chain of Filter nodes from a list of predicates.
139    /// The first predicate becomes the outermost Filter.
140    fn create_filter_chain(&self, input: LogicalPlan, predicates: Vec<Expr>) -> LogicalPlan {
141        if predicates.is_empty() {
142            return input;
143        }
144
145        // Build the chain from bottom to top
146        let mut result = input;
147        for predicate in predicates.into_iter().rev() {
148            result = LogicalPlan::Filter {
149                input: Box::new(result),
150                predicate,
151            };
152        }
153        result
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Expr;

    /// Builds `t.a = value`, a simple predicate that is neither `AND` nor `OR`.
    fn eq_pred(value: i64) -> Expr {
        Expr::eq(Expr::column("t", "a", 0), Expr::literal(value))
    }

    /// Returns true when `expr` is a top-level `OR`.
    fn is_or(expr: &Expr) -> bool {
        matches!(
            expr,
            Expr::BinaryOp {
                op: BinaryOp::Or,
                ..
            }
        )
    }

    /// Returns true when `expr` is a top-level `AND`.
    fn is_and(expr: &Expr) -> bool {
        matches!(
            expr,
            Expr::BinaryOp {
                op: BinaryOp::And,
                ..
            }
        )
    }

    /// Walks a chain of `Filter` nodes from the top of `plan`, returning their
    /// predicates (outermost first) and the first non-`Filter` node below them.
    fn unwrap_filters(plan: &LogicalPlan) -> (Vec<&Expr>, &LogicalPlan) {
        let mut predicates = Vec::new();
        let mut current = plan;
        while let LogicalPlan::Filter { input, predicate } = current {
            predicates.push(predicate);
            current = input;
        }
        (predicates, current)
    }

    /// Builds `n` simple conjuncts in order, of which only the one at index
    /// `or_at` is an `OR` expression (used as a position marker).
    fn marked_conjuncts(n: i64, or_at: i64) -> Vec<Expr> {
        (0..n)
            .map(|i| {
                if i == or_at {
                    Expr::or(eq_pred(i), eq_pred(-i))
                } else {
                    eq_pred(i)
                }
            })
            .collect()
    }

    /// Indices of the top-level `OR` expressions in `exprs`.
    fn or_positions(exprs: &[Expr]) -> Vec<usize> {
        exprs
            .iter()
            .enumerate()
            .filter(|(_, e)| is_or(e))
            .map(|(i, _)| i)
            .collect()
    }

    #[test]
    fn test_simple_filter_unchanged() {
        let pass = AndPredicatePass;
        let plan = LogicalPlan::filter(
            LogicalPlan::scan("users"),
            Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64)),
        );

        let result = pass.optimize(plan);

        // Single predicate should remain as a single Filter over the scan
        let (predicates, bottom) = unwrap_filters(&result);
        assert_eq!(predicates.len(), 1);
        assert!(matches!(bottom, LogicalPlan::Scan { .. }));
    }

    #[test]
    fn test_and_predicate_split() {
        let pass = AndPredicatePass;

        // Create: Filter(a AND b) -> Scan
        let pred_a = Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64));
        let pred_b = Expr::eq(Expr::column("users", "name", 1), Expr::literal("Alice"));
        let and_pred = Expr::and(pred_a, pred_b);

        let plan = LogicalPlan::filter(LogicalPlan::scan("users"), and_pred);

        let result = pass.optimize(plan);

        // Should become: Filter(a) -> Filter(b) -> Scan
        let (predicates, bottom) = unwrap_filters(&result);
        assert_eq!(predicates.len(), 2);
        assert!(predicates.iter().all(|p| !is_and(p)));
        assert!(matches!(bottom, LogicalPlan::Scan { .. }));
    }

    #[test]
    fn test_nested_and_predicate_flattened() {
        let pass = AndPredicatePass;

        // Create: Filter((a AND b) AND c) -> Scan
        let pred_a = Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64));
        let pred_b = Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64));
        let pred_c = Expr::eq(Expr::column("t", "c", 2), Expr::literal(3i64));
        let nested_and = Expr::and(Expr::and(pred_a, pred_b), pred_c);

        let plan = LogicalPlan::filter(LogicalPlan::scan("t"), nested_and);

        let result = pass.optimize(plan);

        // Should become: Filter(a) -> Filter(b) -> Filter(c) -> Scan
        let (predicates, bottom) = unwrap_filters(&result);
        assert_eq!(predicates.len(), 3);
        assert!(predicates.iter().all(|p| !is_and(p)));
        assert!(matches!(bottom, LogicalPlan::Scan { .. }));
    }

    #[test]
    fn test_or_predicate_preserved() {
        let pass = AndPredicatePass;

        // Create: Filter(a OR b) -> Scan
        let pred_a = Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64));
        let pred_b = Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64));
        let or_pred = Expr::or(pred_a, pred_b);

        let plan = LogicalPlan::filter(LogicalPlan::scan("t"), or_pred);

        let result = pass.optimize(plan);

        // OR predicate should remain as a single, intact Filter over the scan
        let (predicates, bottom) = unwrap_filters(&result);
        assert_eq!(predicates.len(), 1);
        assert!(is_or(predicates[0]));
        assert!(matches!(bottom, LogicalPlan::Scan { .. }));
    }

    #[test]
    fn test_mixed_and_or_predicate() {
        let pass = AndPredicatePass;

        // Create: Filter(a AND (b OR c)) -> Scan
        let pred_a = Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64));
        let pred_b = Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64));
        let pred_c = Expr::eq(Expr::column("t", "c", 2), Expr::literal(3i64));
        let or_pred = Expr::or(pred_b, pred_c);
        let and_pred = Expr::and(pred_a, or_pred);

        let plan = LogicalPlan::filter(LogicalPlan::scan("t"), and_pred);

        let result = pass.optimize(plan);

        // Should become: Filter(a) -> Filter(b OR c) -> Scan, in that order
        let (predicates, bottom) = unwrap_filters(&result);
        assert_eq!(predicates.len(), 2);
        assert!(!is_or(predicates[0]));
        assert!(is_or(predicates[1]));
        assert!(matches!(bottom, LogicalPlan::Scan { .. }));
    }

    #[test]
    fn test_stacked_filters_are_all_split() {
        let pass = AndPredicatePass;

        // Create: Filter(c AND (d OR e)) -> Filter(a AND b) -> Scan
        let lower = LogicalPlan::filter(
            LogicalPlan::scan("t"),
            Expr::and(eq_pred(1), eq_pred(2)),
        );
        let plan = LogicalPlan::filter(
            lower,
            Expr::and(eq_pred(3), Expr::or(eq_pred(4), eq_pred(5))),
        );

        let result = pass.optimize(plan);

        // Should become: Filter(c) -> Filter(d OR e) -> Filter(a) -> Filter(b) -> Scan
        let (predicates, bottom) = unwrap_filters(&result);
        assert_eq!(predicates.len(), 4);
        let or_flags: Vec<bool> = predicates.iter().map(|p| is_or(p)).collect();
        assert_eq!(or_flags, [false, true, false, false]);
        assert!(matches!(bottom, LogicalPlan::Scan { .. }));
    }

    #[test]
    fn test_break_and_predicate() {
        let pass = AndPredicatePass;

        let pred_a = Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64));
        let pred_b = Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64));
        let pred_c = Expr::eq(Expr::column("t", "c", 2), Expr::literal(3i64));

        // Simple AND
        let and_pred = Expr::and(pred_a.clone(), pred_b.clone());
        let result = pass.break_and_predicate(and_pred);
        assert_eq!(result.len(), 2);

        // Nested AND
        let nested = Expr::and(Expr::and(pred_a.clone(), pred_b.clone()), pred_c.clone());
        let result = pass.break_and_predicate(nested);
        assert_eq!(result.len(), 3);

        // An OR conjunct is kept whole and keeps its position:
        // (a AND (b OR c)) AND a  =>  [a, (b OR c), a]
        let mixed = Expr::and(
            Expr::and(pred_a.clone(), Expr::or(pred_b, pred_c)),
            pred_a.clone(),
        );
        let result = pass.break_and_predicate(mixed);
        let or_flags: Vec<bool> = result.iter().map(is_or).collect();
        assert_eq!(or_flags, [false, true, false]);

        // Single predicate
        let result = pass.break_and_predicate(pred_a);
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_long_and_chains_keep_left_to_right_order() {
        let pass = AndPredicatePass;

        // Right-nested: c0 AND (c1 AND (... AND c31)), with c5 an OR marker
        let right_nested = marked_conjuncts(32, 5)
            .into_iter()
            .rev()
            .reduce(|tail, head| Expr::and(head, tail))
            .expect("chain is non-empty");
        let result = pass.break_and_predicate(right_nested);
        assert_eq!(result.len(), 32);
        assert_eq!(or_positions(&result), [5usize]);

        // Left-nested: ((c0 AND c1) AND ...) AND c31, with c5 an OR marker
        let left_nested = marked_conjuncts(32, 5)
            .into_iter()
            .reduce(|acc, next| Expr::and(acc, next))
            .expect("chain is non-empty");
        let result = pass.break_and_predicate(left_nested);
        assert_eq!(result.len(), 32);
        assert_eq!(or_positions(&result), [5usize]);
    }
}