Skip to main content

cynos_query/optimizer/
topn_pushdown.rs

//! TopN pushdown optimization pass.
//!
//! This pass identifies `Limit -> Sort` patterns and converts them to a single
//! `TopN` operator, which is more efficient for selecting the top K elements.
//!
//! Example:
//! ```text
//! Limit(10, 0)                =>    TopN(order_by, limit=10, offset=0)
//!      |                                  |
//! Sort(order_by)                      input
//!      |
//!   input
//! ```
//!
//! Performance benefit:
//! - Full sort: O(n log n) time, O(n) space
//! - TopN with heap: O(n log k) time, O(k) space (where k = limit + offset)
//!
//! This optimization is safe because:
//! - TopN produces the same result as Sort + Limit
//! - It only applies when Limit is directly above Sort (no intervening operators)
//!
//! In addition, a `Limit` directly above an `IndexGet`, or above an `IndexScan`
//! that has no limit/offset yet, is pushed into the index access so that it can
//! terminate early.
22
23use crate::planner::PhysicalPlan;
24use alloc::boxed::Box;
25
/// TopN pushdown optimization pass.
///
/// Rewrites `Limit -> Sort` into a single `TopN` operator for more efficient
/// top-k selection, and pushes a `Limit` that sits directly above an
/// `IndexGet` / `IndexScan` into the index access itself.
///
/// The pass carries no state, so it is a zero-sized, freely copyable value.
#[derive(Debug, Clone, Copy)]
pub struct TopNPushdown;
30
31impl TopNPushdown {
32    /// Creates a new TopNPushdown pass.
33    pub fn new() -> Self {
34        Self
35    }
36
37    /// Optimizes the physical plan by converting Limit+Sort to TopN.
38    pub fn optimize(&self, plan: PhysicalPlan) -> PhysicalPlan {
39        self.traverse(plan)
40    }
41
42    fn traverse(&self, plan: PhysicalPlan) -> PhysicalPlan {
43        match plan {
44            PhysicalPlan::Limit {
45                input,
46                limit,
47                offset,
48            } => {
49                let optimized_input = self.traverse(*input);
50
51                // Check if input is a Sort - if so, convert to TopN
52                if let PhysicalPlan::Sort { input: sort_input, order_by } = optimized_input {
53                    // Convert to TopN - more efficient for top-k selection
54                    return PhysicalPlan::TopN {
55                        input: sort_input,
56                        order_by,
57                        limit,
58                        offset,
59                    };
60                }
61
62                // Check if input is IndexGet - push limit into it
63                if let PhysicalPlan::IndexGet { table, index, key, limit: _ } = optimized_input {
64                    // Push limit into IndexGet for early termination
65                    // Note: offset is handled by skipping rows after IndexGet
66                    if offset == 0 {
67                        return PhysicalPlan::IndexGet {
68                            table,
69                            index,
70                            key,
71                            limit: Some(limit),
72                        };
73                    } else {
74                        // With offset, we need to fetch limit + offset rows
75                        return PhysicalPlan::Limit {
76                            input: Box::new(PhysicalPlan::IndexGet {
77                                table,
78                                index,
79                                key,
80                                limit: Some(limit + offset),
81                            }),
82                            limit,
83                            offset,
84                        };
85                    }
86                }
87
88                // Check if input is IndexScan without limit - push limit into it
89                if let PhysicalPlan::IndexScan {
90                    table,
91                    index,
92                    range_start,
93                    range_end,
94                    include_start,
95                    include_end,
96                    limit: None,
97                    offset: None,
98                    reverse,
99                } = optimized_input
100                {
101                    // Push limit into IndexScan for early termination
102                    return PhysicalPlan::IndexScan {
103                        table,
104                        index,
105                        range_start,
106                        range_end,
107                        include_start,
108                        include_end,
109                        limit: Some(limit + offset),
110                        offset: Some(offset),
111                        reverse,
112                    };
113                }
114
115                // Not a Sort/IndexGet/IndexScan, keep as Limit
116                PhysicalPlan::Limit {
117                    input: Box::new(optimized_input),
118                    limit,
119                    offset,
120                }
121            }
122
123            // Recursively process other nodes
124            PhysicalPlan::Filter { input, predicate } => PhysicalPlan::Filter {
125                input: Box::new(self.traverse(*input)),
126                predicate,
127            },
128
129            PhysicalPlan::Project { input, columns } => PhysicalPlan::Project {
130                input: Box::new(self.traverse(*input)),
131                columns,
132            },
133
134            PhysicalPlan::Sort { input, order_by } => PhysicalPlan::Sort {
135                input: Box::new(self.traverse(*input)),
136                order_by,
137            },
138
139            PhysicalPlan::TopN {
140                input,
141                order_by,
142                limit,
143                offset,
144            } => PhysicalPlan::TopN {
145                input: Box::new(self.traverse(*input)),
146                order_by,
147                limit,
148                offset,
149            },
150
151            PhysicalPlan::HashJoin {
152                left,
153                right,
154                condition,
155                join_type,
156            } => PhysicalPlan::HashJoin {
157                left: Box::new(self.traverse(*left)),
158                right: Box::new(self.traverse(*right)),
159                condition,
160                join_type,
161            },
162
163            PhysicalPlan::SortMergeJoin {
164                left,
165                right,
166                condition,
167                join_type,
168            } => PhysicalPlan::SortMergeJoin {
169                left: Box::new(self.traverse(*left)),
170                right: Box::new(self.traverse(*right)),
171                condition,
172                join_type,
173            },
174
175            PhysicalPlan::NestedLoopJoin {
176                left,
177                right,
178                condition,
179                join_type,
180            } => PhysicalPlan::NestedLoopJoin {
181                left: Box::new(self.traverse(*left)),
182                right: Box::new(self.traverse(*right)),
183                condition,
184                join_type,
185            },
186
187            PhysicalPlan::IndexNestedLoopJoin {
188                outer,
189                inner_table,
190                inner_index,
191                condition,
192                join_type,
193            } => PhysicalPlan::IndexNestedLoopJoin {
194                outer: Box::new(self.traverse(*outer)),
195                inner_table,
196                inner_index,
197                condition,
198                join_type,
199            },
200
201            PhysicalPlan::HashAggregate {
202                input,
203                group_by,
204                aggregates,
205            } => PhysicalPlan::HashAggregate {
206                input: Box::new(self.traverse(*input)),
207                group_by,
208                aggregates,
209            },
210
211            PhysicalPlan::CrossProduct { left, right } => PhysicalPlan::CrossProduct {
212                left: Box::new(self.traverse(*left)),
213                right: Box::new(self.traverse(*right)),
214            },
215
216            PhysicalPlan::NoOp { input } => PhysicalPlan::NoOp {
217                input: Box::new(self.traverse(*input)),
218            },
219
220            // Leaf nodes - no transformation
221            plan @ (PhysicalPlan::TableScan { .. }
222            | PhysicalPlan::IndexScan { .. }
223            | PhysicalPlan::IndexGet { .. }
224            | PhysicalPlan::IndexInGet { .. }
225            | PhysicalPlan::GinIndexScan { .. }
226            | PhysicalPlan::GinIndexScanMulti { .. }
227            | PhysicalPlan::Empty) => plan,
228        }
229    }
230}
231
232impl Default for TopNPushdown {
233    fn default() -> Self {
234        Self::new()
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Expr, SortOrder};

    /// Fixture: `Sort(users.id <direction>) -> TableScan(users)`.
    fn users_sorted_by_id(direction: SortOrder) -> PhysicalPlan {
        PhysicalPlan::Sort {
            input: Box::new(PhysicalPlan::table_scan("users")),
            order_by: alloc::vec![(Expr::column("users", "id", 0), direction)],
        }
    }

    /// `Limit(10, 5) -> Sort -> TableScan` collapses into `TopN -> TableScan`
    /// with limit, offset and sort keys carried over.
    #[test]
    fn test_limit_sort_converted_to_topn() {
        let plan = PhysicalPlan::Limit {
            input: Box::new(users_sorted_by_id(SortOrder::Asc)),
            limit: 10,
            offset: 5,
        };

        match TopNPushdown::new().optimize(plan) {
            PhysicalPlan::TopN {
                input,
                order_by,
                limit,
                offset,
            } => {
                assert_eq!(limit, 10);
                assert_eq!(offset, 5);
                assert_eq!(order_by.len(), 1);
                assert!(matches!(*input, PhysicalPlan::TableScan { .. }));
            }
            result => panic!("Expected TopN, got {:?}", result),
        }
    }

    /// A `Limit` over a plain `TableScan` (no `Sort`) is left alone.
    #[test]
    fn test_limit_without_sort_unchanged() {
        let plan = PhysicalPlan::Limit {
            input: Box::new(PhysicalPlan::table_scan("users")),
            limit: 10,
            offset: 0,
        };

        let result = TopNPushdown::new().optimize(plan);

        // Still Limit -> TableScan.
        assert!(matches!(result, PhysicalPlan::Limit { .. }));
        if let PhysicalPlan::Limit { input, .. } = result {
            assert!(matches!(*input, PhysicalPlan::TableScan { .. }));
        }
    }

    /// A `Filter` sitting between `Limit` and `Sort` blocks the conversion.
    #[test]
    fn test_limit_filter_sort_not_converted() {
        let filtered_sort = PhysicalPlan::Filter {
            input: Box::new(users_sorted_by_id(SortOrder::Asc)),
            predicate: Expr::eq(Expr::column("users", "active", 1), Expr::literal(true)),
        };
        let plan = PhysicalPlan::Limit {
            input: Box::new(filtered_sort),
            limit: 10,
            offset: 0,
        };

        let result = TopNPushdown::new().optimize(plan);

        // Still Limit -> Filter -> Sort.
        assert!(matches!(result, PhysicalPlan::Limit { .. }));
        if let PhysicalPlan::Limit { input, .. } = result {
            assert!(matches!(*input, PhysicalPlan::Filter { .. }));
        }
    }

    /// The pattern is also found below other operators:
    /// `Project -> Limit -> Sort -> TableScan` => `Project -> TopN -> TableScan`.
    #[test]
    fn test_nested_limit_sort_converted() {
        let limited = PhysicalPlan::Limit {
            input: Box::new(users_sorted_by_id(SortOrder::Desc)),
            limit: 5,
            offset: 0,
        };
        let plan = PhysicalPlan::Project {
            input: Box::new(limited),
            columns: alloc::vec![Expr::column("users", "name", 1)],
        };

        match TopNPushdown::new().optimize(plan) {
            PhysicalPlan::Project { input, .. } => match *input {
                PhysicalPlan::TopN { limit, offset, .. } => {
                    assert_eq!(limit, 5);
                    assert_eq!(offset, 0);
                }
                _ => panic!("Expected TopN inside Project"),
            },
            result => panic!("Expected Project, got {:?}", result),
        }
    }

    /// Every sort key, with its direction, survives the conversion.
    #[test]
    fn test_multiple_sort_columns() {
        // Limit -> Sort(name ASC, id DESC) -> TableScan
        let sort = PhysicalPlan::Sort {
            input: Box::new(PhysicalPlan::table_scan("users")),
            order_by: alloc::vec![
                (Expr::column("users", "name", 1), SortOrder::Asc),
                (Expr::column("users", "id", 0), SortOrder::Desc),
            ],
        };
        let plan = PhysicalPlan::Limit {
            input: Box::new(sort),
            limit: 20,
            offset: 10,
        };

        match TopNPushdown::new().optimize(plan) {
            PhysicalPlan::TopN {
                order_by,
                limit,
                offset,
                ..
            } => {
                assert_eq!(limit, 20);
                assert_eq!(offset, 10);
                assert_eq!(order_by.len(), 2);
                assert_eq!(order_by[0].1, SortOrder::Asc);
                assert_eq!(order_by[1].1, SortOrder::Desc);
            }
            result => panic!("Expected TopN, got {:?}", result),
        }
    }

    /// A `Limit -> Sort` on one side of a join is converted as well.
    #[test]
    fn test_sort_in_subquery_converted() {
        // HashJoin(left: Limit -> Sort -> TableScan(orders), right: TableScan(users))
        let top_orders = PhysicalPlan::Limit {
            input: Box::new(PhysicalPlan::Sort {
                input: Box::new(PhysicalPlan::table_scan("orders")),
                order_by: alloc::vec![(Expr::column("orders", "amount", 1), SortOrder::Desc)],
            }),
            limit: 100,
            offset: 0,
        };
        let plan = PhysicalPlan::HashJoin {
            left: Box::new(top_orders),
            right: Box::new(PhysicalPlan::table_scan("users")),
            condition: Expr::eq(
                Expr::column("orders", "user_id", 2),
                Expr::column("users", "id", 0),
            ),
            join_type: crate::ast::JoinType::Inner,
        };

        match TopNPushdown::new().optimize(plan) {
            PhysicalPlan::HashJoin { left, .. } => {
                assert!(matches!(*left, PhysicalPlan::TopN { .. }));
            }
            result => panic!("Expected HashJoin, got {:?}", result),
        }
    }
}