1use crate::planner::PhysicalPlan;
24use alloc::boxed::Box;
25
/// Optimizer pass that turns `Limit` over `Sort` into a bounded `TopN` and pushes
/// row limits down into index lookups/scans.
///
/// The pass holds no state, so it is cheap to copy and print.
#[derive(Debug, Clone, Copy)]
pub struct TopNPushdown;
30
31impl TopNPushdown {
32 pub fn new() -> Self {
34 Self
35 }
36
37 pub fn optimize(&self, plan: PhysicalPlan) -> PhysicalPlan {
39 self.traverse(plan)
40 }
41
42 fn traverse(&self, plan: PhysicalPlan) -> PhysicalPlan {
43 match plan {
44 PhysicalPlan::Limit {
45 input,
46 limit,
47 offset,
48 } => {
49 let optimized_input = self.traverse(*input);
50
51 if let PhysicalPlan::Sort { input: sort_input, order_by } = optimized_input {
53 return PhysicalPlan::TopN {
55 input: sort_input,
56 order_by,
57 limit,
58 offset,
59 };
60 }
61
62 if let PhysicalPlan::IndexGet { table, index, key, limit: _ } = optimized_input {
64 if offset == 0 {
67 return PhysicalPlan::IndexGet {
68 table,
69 index,
70 key,
71 limit: Some(limit),
72 };
73 } else {
74 return PhysicalPlan::Limit {
76 input: Box::new(PhysicalPlan::IndexGet {
77 table,
78 index,
79 key,
80 limit: Some(limit + offset),
81 }),
82 limit,
83 offset,
84 };
85 }
86 }
87
88 if let PhysicalPlan::IndexScan {
90 table,
91 index,
92 range_start,
93 range_end,
94 include_start,
95 include_end,
96 limit: None,
97 offset: None,
98 reverse,
99 } = optimized_input
100 {
101 return PhysicalPlan::IndexScan {
103 table,
104 index,
105 range_start,
106 range_end,
107 include_start,
108 include_end,
109 limit: Some(limit + offset),
110 offset: Some(offset),
111 reverse,
112 };
113 }
114
115 PhysicalPlan::Limit {
117 input: Box::new(optimized_input),
118 limit,
119 offset,
120 }
121 }
122
123 PhysicalPlan::Filter { input, predicate } => PhysicalPlan::Filter {
125 input: Box::new(self.traverse(*input)),
126 predicate,
127 },
128
129 PhysicalPlan::Project { input, columns } => PhysicalPlan::Project {
130 input: Box::new(self.traverse(*input)),
131 columns,
132 },
133
134 PhysicalPlan::Sort { input, order_by } => PhysicalPlan::Sort {
135 input: Box::new(self.traverse(*input)),
136 order_by,
137 },
138
139 PhysicalPlan::TopN {
140 input,
141 order_by,
142 limit,
143 offset,
144 } => PhysicalPlan::TopN {
145 input: Box::new(self.traverse(*input)),
146 order_by,
147 limit,
148 offset,
149 },
150
151 PhysicalPlan::HashJoin {
152 left,
153 right,
154 condition,
155 join_type,
156 } => PhysicalPlan::HashJoin {
157 left: Box::new(self.traverse(*left)),
158 right: Box::new(self.traverse(*right)),
159 condition,
160 join_type,
161 },
162
163 PhysicalPlan::SortMergeJoin {
164 left,
165 right,
166 condition,
167 join_type,
168 } => PhysicalPlan::SortMergeJoin {
169 left: Box::new(self.traverse(*left)),
170 right: Box::new(self.traverse(*right)),
171 condition,
172 join_type,
173 },
174
175 PhysicalPlan::NestedLoopJoin {
176 left,
177 right,
178 condition,
179 join_type,
180 } => PhysicalPlan::NestedLoopJoin {
181 left: Box::new(self.traverse(*left)),
182 right: Box::new(self.traverse(*right)),
183 condition,
184 join_type,
185 },
186
187 PhysicalPlan::IndexNestedLoopJoin {
188 outer,
189 inner_table,
190 inner_index,
191 condition,
192 join_type,
193 } => PhysicalPlan::IndexNestedLoopJoin {
194 outer: Box::new(self.traverse(*outer)),
195 inner_table,
196 inner_index,
197 condition,
198 join_type,
199 },
200
201 PhysicalPlan::HashAggregate {
202 input,
203 group_by,
204 aggregates,
205 } => PhysicalPlan::HashAggregate {
206 input: Box::new(self.traverse(*input)),
207 group_by,
208 aggregates,
209 },
210
211 PhysicalPlan::CrossProduct { left, right } => PhysicalPlan::CrossProduct {
212 left: Box::new(self.traverse(*left)),
213 right: Box::new(self.traverse(*right)),
214 },
215
216 PhysicalPlan::NoOp { input } => PhysicalPlan::NoOp {
217 input: Box::new(self.traverse(*input)),
218 },
219
220 plan @ (PhysicalPlan::TableScan { .. }
222 | PhysicalPlan::IndexScan { .. }
223 | PhysicalPlan::IndexGet { .. }
224 | PhysicalPlan::IndexInGet { .. }
225 | PhysicalPlan::GinIndexScan { .. }
226 | PhysicalPlan::GinIndexScanMulti { .. }
227 | PhysicalPlan::Empty) => plan,
228 }
229 }
230}
231
232impl Default for TopNPushdown {
233 fn default() -> Self {
234 Self::new()
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Expr, SortOrder};

    #[test]
    fn test_limit_sort_converted_to_topn() {
        let optimizer = TopNPushdown::new();

        // LIMIT 10 OFFSET 5 directly over ORDER BY users.id.
        let sorted = PhysicalPlan::Sort {
            input: Box::new(PhysicalPlan::table_scan("users")),
            order_by: alloc::vec![(Expr::column("users", "id", 0), SortOrder::Asc)],
        };
        let plan = PhysicalPlan::Limit {
            input: Box::new(sorted),
            limit: 10,
            offset: 5,
        };

        match optimizer.optimize(plan) {
            PhysicalPlan::TopN {
                input: child,
                order_by: keys,
                limit: row_limit,
                offset: row_offset,
            } => {
                assert_eq!(row_limit, 10);
                assert_eq!(row_offset, 5);
                assert_eq!(keys.len(), 1);
                assert!(matches!(*child, PhysicalPlan::TableScan { .. }));
            }
            other => panic!("Expected TopN, got {:?}", other),
        }
    }

    #[test]
    fn test_limit_without_sort_unchanged() {
        let optimizer = TopNPushdown::new();

        let scan = PhysicalPlan::table_scan("users");
        let plan = PhysicalPlan::Limit {
            input: Box::new(scan),
            limit: 10,
            offset: 0,
        };

        let rewritten = optimizer.optimize(plan);

        assert!(matches!(rewritten, PhysicalPlan::Limit { .. }));
        if let PhysicalPlan::Limit { input: child, .. } = rewritten {
            assert!(matches!(*child, PhysicalPlan::TableScan { .. }));
        }
    }

    #[test]
    fn test_limit_filter_sort_not_converted() {
        let optimizer = TopNPushdown::new();

        // The Filter between Limit and Sort blocks the TopN rewrite.
        let sorted = PhysicalPlan::Sort {
            input: Box::new(PhysicalPlan::table_scan("users")),
            order_by: alloc::vec![(Expr::column("users", "id", 0), SortOrder::Asc)],
        };
        let filtered = PhysicalPlan::Filter {
            input: Box::new(sorted),
            predicate: Expr::eq(Expr::column("users", "active", 1), Expr::literal(true)),
        };
        let plan = PhysicalPlan::Limit {
            input: Box::new(filtered),
            limit: 10,
            offset: 0,
        };

        let rewritten = optimizer.optimize(plan);

        assert!(matches!(rewritten, PhysicalPlan::Limit { .. }));
        if let PhysicalPlan::Limit { input: child, .. } = rewritten {
            assert!(matches!(*child, PhysicalPlan::Filter { .. }));
        }
    }

    #[test]
    fn test_nested_limit_sort_converted() {
        let optimizer = TopNPushdown::new();

        let sorted = PhysicalPlan::Sort {
            input: Box::new(PhysicalPlan::table_scan("users")),
            order_by: alloc::vec![(Expr::column("users", "id", 0), SortOrder::Desc)],
        };
        let limited = PhysicalPlan::Limit {
            input: Box::new(sorted),
            limit: 5,
            offset: 0,
        };
        let plan = PhysicalPlan::Project {
            input: Box::new(limited),
            columns: alloc::vec![Expr::column("users", "name", 1)],
        };

        match optimizer.optimize(plan) {
            PhysicalPlan::Project { input: child, .. } => match *child {
                PhysicalPlan::TopN {
                    limit: row_limit,
                    offset: row_offset,
                    ..
                } => {
                    assert_eq!(row_limit, 5);
                    assert_eq!(row_offset, 0);
                }
                _ => panic!("Expected TopN inside Project"),
            },
            other => panic!("Expected Project, got {:?}", other),
        }
    }

    #[test]
    fn test_multiple_sort_columns() {
        let optimizer = TopNPushdown::new();

        let sorted = PhysicalPlan::Sort {
            input: Box::new(PhysicalPlan::table_scan("users")),
            order_by: alloc::vec![
                (Expr::column("users", "name", 1), SortOrder::Asc),
                (Expr::column("users", "id", 0), SortOrder::Desc),
            ],
        };
        let plan = PhysicalPlan::Limit {
            input: Box::new(sorted),
            limit: 20,
            offset: 10,
        };

        match optimizer.optimize(plan) {
            PhysicalPlan::TopN {
                order_by: keys,
                limit: row_limit,
                offset: row_offset,
                ..
            } => {
                assert_eq!(row_limit, 20);
                assert_eq!(row_offset, 10);
                assert_eq!(keys.len(), 2);
                assert_eq!(keys[0].1, SortOrder::Asc);
                assert_eq!(keys[1].1, SortOrder::Desc);
            }
            other => panic!("Expected TopN, got {:?}", other),
        }
    }

    #[test]
    fn test_sort_in_subquery_converted() {
        let optimizer = TopNPushdown::new();

        let top_orders = PhysicalPlan::Limit {
            input: Box::new(PhysicalPlan::Sort {
                input: Box::new(PhysicalPlan::table_scan("orders")),
                order_by: alloc::vec![(Expr::column("orders", "amount", 1), SortOrder::Desc)],
            }),
            limit: 100,
            offset: 0,
        };
        let plan = PhysicalPlan::HashJoin {
            left: Box::new(top_orders),
            right: Box::new(PhysicalPlan::table_scan("users")),
            condition: Expr::eq(
                Expr::column("orders", "user_id", 2),
                Expr::column("users", "id", 0),
            ),
            join_type: crate::ast::JoinType::Inner,
        };

        match optimizer.optimize(plan) {
            PhysicalPlan::HashJoin { left: probe, .. } => {
                assert!(matches!(*probe, PhysicalPlan::TopN { .. }));
            }
            other => panic!("Expected HashJoin, got {:?}", other),
        }
    }
}