1use crate::context::ExecutionContext;
26use crate::planner::PhysicalPlan;
27use alloc::boxed::Box;
28
29pub struct LimitSkipByIndexPass<'a> {
31 ctx: &'a ExecutionContext,
32}
33
34impl<'a> LimitSkipByIndexPass<'a> {
35 pub fn new(ctx: &'a ExecutionContext) -> Self {
37 Self { ctx }
38 }
39
40 pub fn optimize(&self, plan: PhysicalPlan) -> PhysicalPlan {
42 self.traverse(plan)
43 }
44
45 fn traverse(&self, plan: PhysicalPlan) -> PhysicalPlan {
46 match plan {
47 PhysicalPlan::Limit {
48 input,
49 limit,
50 offset,
51 } => {
52 let optimized_input = self.traverse(*input);
53
54 if let Some(optimized) =
56 self.try_push_to_index_scan(optimized_input.clone(), limit, offset)
57 {
58 return optimized;
59 }
60
61 PhysicalPlan::Limit {
63 input: Box::new(optimized_input),
64 limit,
65 offset,
66 }
67 }
68
69 PhysicalPlan::Filter { input, predicate } => PhysicalPlan::Filter {
71 input: Box::new(self.traverse(*input)),
72 predicate,
73 },
74
75 PhysicalPlan::Project { input, columns } => PhysicalPlan::Project {
76 input: Box::new(self.traverse(*input)),
77 columns,
78 },
79
80 PhysicalPlan::Sort { input, order_by } => PhysicalPlan::Sort {
81 input: Box::new(self.traverse(*input)),
82 order_by,
83 },
84
85 PhysicalPlan::TopN {
86 input,
87 order_by,
88 limit,
89 offset,
90 } => PhysicalPlan::TopN {
91 input: Box::new(self.traverse(*input)),
92 order_by,
93 limit,
94 offset,
95 },
96
97 PhysicalPlan::HashJoin {
98 left,
99 right,
100 condition,
101 join_type,
102 } => PhysicalPlan::HashJoin {
103 left: Box::new(self.traverse(*left)),
104 right: Box::new(self.traverse(*right)),
105 condition,
106 join_type,
107 },
108
109 PhysicalPlan::SortMergeJoin {
110 left,
111 right,
112 condition,
113 join_type,
114 } => PhysicalPlan::SortMergeJoin {
115 left: Box::new(self.traverse(*left)),
116 right: Box::new(self.traverse(*right)),
117 condition,
118 join_type,
119 },
120
121 PhysicalPlan::NestedLoopJoin {
122 left,
123 right,
124 condition,
125 join_type,
126 } => PhysicalPlan::NestedLoopJoin {
127 left: Box::new(self.traverse(*left)),
128 right: Box::new(self.traverse(*right)),
129 condition,
130 join_type,
131 },
132
133 PhysicalPlan::IndexNestedLoopJoin {
134 outer,
135 inner_table,
136 inner_index,
137 condition,
138 join_type,
139 } => PhysicalPlan::IndexNestedLoopJoin {
140 outer: Box::new(self.traverse(*outer)),
141 inner_table,
142 inner_index,
143 condition,
144 join_type,
145 },
146
147 PhysicalPlan::HashAggregate {
148 input,
149 group_by,
150 aggregates,
151 } => PhysicalPlan::HashAggregate {
152 input: Box::new(self.traverse(*input)),
153 group_by,
154 aggregates,
155 },
156
157 PhysicalPlan::CrossProduct { left, right } => PhysicalPlan::CrossProduct {
158 left: Box::new(self.traverse(*left)),
159 right: Box::new(self.traverse(*right)),
160 },
161
162 PhysicalPlan::NoOp { input } => PhysicalPlan::NoOp {
163 input: Box::new(self.traverse(*input)),
164 },
165
166 plan @ (PhysicalPlan::TableScan { .. }
168 | PhysicalPlan::IndexScan { .. }
169 | PhysicalPlan::IndexGet { .. }
170 | PhysicalPlan::IndexInGet { .. }
171 | PhysicalPlan::GinIndexScan { .. }
172 | PhysicalPlan::GinIndexScanMulti { .. }
173 | PhysicalPlan::Empty) => plan,
174 }
175 }
176
177 fn try_push_to_index_scan(
180 &self,
181 plan: PhysicalPlan,
182 limit: usize,
183 offset: usize,
184 ) -> Option<PhysicalPlan> {
185 match plan {
186 PhysicalPlan::IndexScan {
188 table,
189 index,
190 range_start,
191 range_end,
192 include_start,
193 include_end,
194 limit: existing_limit,
195 offset: existing_offset,
196 reverse,
197 } => {
198 let (new_limit, new_offset) = if existing_limit.is_some() {
200 return None;
202 } else {
203 (Some(limit), Some(offset))
204 };
205
206 if existing_offset.is_some() && existing_offset.unwrap() > 0 {
208 return None;
209 }
210
211 Some(PhysicalPlan::IndexScan {
212 table,
213 index,
214 range_start,
215 range_end,
216 include_start,
217 include_end,
218 limit: new_limit,
219 offset: new_offset,
220 reverse,
221 })
222 }
223
224 PhysicalPlan::Project { input, columns } => {
226 let optimized = self.try_push_to_index_scan(*input, limit, offset)?;
227 Some(PhysicalPlan::Project {
228 input: Box::new(optimized),
229 columns,
230 })
231 }
232
233 PhysicalPlan::TableScan { table } => {
235 let pk_index = self.ctx.find_primary_index(&table)?;
237
238 Some(PhysicalPlan::IndexScan {
239 table,
240 index: pk_index.name.clone(),
241 range_start: None,
242 range_end: None,
243 include_start: true,
244 include_end: true,
245 limit: Some(limit),
246 offset: Some(offset),
247 reverse: false,
248 })
249 }
250
251 _ => None,
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Expr, SortOrder};
    use crate::context::{IndexInfo, TableStats};

    /// Builds a context with a single `users` table (1000 rows) whose only
    /// index is `idx_id` on `id`. The `true` flag presumably marks it as the
    /// primary index, which the table-scan tests rely on.
    fn create_test_context() -> ExecutionContext {
        let mut ctx = ExecutionContext::new();

        ctx.register_table(
            "users",
            TableStats {
                row_count: 1000,
                is_sorted: false,
                indexes: alloc::vec![IndexInfo::new(
                    "idx_id",
                    alloc::vec!["id".into()],
                    true
                )],
            },
        );

        ctx
    }

    /// Unbounded forward scan over `users.idx_id` with the given limit/offset.
    fn users_index_scan(limit: Option<usize>, offset: Option<usize>) -> PhysicalPlan {
        PhysicalPlan::IndexScan {
            table: "users".into(),
            index: "idx_id".into(),
            range_start: None,
            range_end: None,
            include_start: true,
            include_end: true,
            limit,
            offset,
            reverse: false,
        }
    }

    /// Wraps `input` in a `Limit` node.
    fn limit_over(input: PhysicalPlan, limit: usize, offset: usize) -> PhysicalPlan {
        PhysicalPlan::Limit {
            input: Box::new(input),
            limit,
            offset,
        }
    }

    /// Returns the input of a `Limit` node, panicking if `plan` is anything else.
    fn expect_limit_input(plan: PhysicalPlan) -> PhysicalPlan {
        match plan {
            PhysicalPlan::Limit { input, .. } => *input,
            other => panic!("Expected Limit, got {:?}", other),
        }
    }

    #[test]
    fn test_limit_pushed_to_index_scan() {
        let ctx = create_test_context();
        let pass = LimitSkipByIndexPass::new(&ctx);

        let plan = limit_over(users_index_scan(None, None), 10, 5);

        match pass.optimize(plan) {
            PhysicalPlan::IndexScan { limit, offset, .. } => {
                assert_eq!(limit, Some(10));
                assert_eq!(offset, Some(5));
            }
            other => panic!("Expected IndexScan, got {:?}", other),
        }
    }

    #[test]
    fn test_limit_pushed_through_project() {
        let ctx = create_test_context();
        let pass = LimitSkipByIndexPass::new(&ctx);

        let plan = limit_over(
            PhysicalPlan::Project {
                input: Box::new(users_index_scan(None, None)),
                columns: alloc::vec![Expr::column("users", "id", 0)],
            },
            10,
            0,
        );

        match pass.optimize(plan) {
            PhysicalPlan::Project { input, .. } => match *input {
                PhysicalPlan::IndexScan { limit, offset, .. } => {
                    assert_eq!(limit, Some(10));
                    assert_eq!(offset, Some(0));
                }
                other => panic!("Expected IndexScan inside Project, got {:?}", other),
            },
            other => panic!("Expected Project, got {:?}", other),
        }
    }

    #[test]
    fn test_limit_pushed_through_project_onto_table_scan() {
        let ctx = create_test_context();
        let pass = LimitSkipByIndexPass::new(&ctx);

        let plan = limit_over(
            PhysicalPlan::Project {
                input: Box::new(PhysicalPlan::table_scan("users")),
                columns: alloc::vec![Expr::column("users", "id", 0)],
            },
            7,
            3,
        );

        match pass.optimize(plan) {
            PhysicalPlan::Project { input, .. } => match *input {
                PhysicalPlan::IndexScan {
                    index,
                    limit,
                    offset,
                    ..
                } => {
                    assert_eq!(index, "idx_id");
                    assert_eq!(limit, Some(7));
                    assert_eq!(offset, Some(3));
                }
                other => panic!("Expected IndexScan inside Project, got {:?}", other),
            },
            other => panic!("Expected Project, got {:?}", other),
        }
    }

    #[test]
    fn test_limit_not_pushed_through_filter() {
        let ctx = create_test_context();
        let pass = LimitSkipByIndexPass::new(&ctx);

        let plan = limit_over(
            PhysicalPlan::Filter {
                input: Box::new(users_index_scan(None, None)),
                predicate: Expr::eq(Expr::column("users", "active", 1), Expr::literal(true)),
            },
            10,
            0,
        );

        let input = expect_limit_input(pass.optimize(plan));
        assert!(matches!(input, PhysicalPlan::Filter { .. }));
    }

    #[test]
    fn test_limit_on_table_scan_uses_primary_key() {
        let ctx = create_test_context();
        let pass = LimitSkipByIndexPass::new(&ctx);

        let plan = limit_over(PhysicalPlan::table_scan("users"), 10, 0);

        match pass.optimize(plan) {
            PhysicalPlan::IndexScan {
                table,
                index,
                limit,
                offset,
                ..
            } => {
                assert_eq!(table, "users");
                assert_eq!(index, "idx_id");
                assert_eq!(limit, Some(10));
                assert_eq!(offset, Some(0));
            }
            other => panic!("Expected IndexScan, got {:?}", other),
        }
    }

    #[test]
    fn test_limit_on_table_scan_no_pk_not_optimized() {
        let mut ctx = ExecutionContext::new();
        ctx.register_table(
            "logs",
            TableStats {
                row_count: 1000,
                is_sorted: false,
                indexes: alloc::vec![],
            },
        );

        let pass = LimitSkipByIndexPass::new(&ctx);

        let plan = limit_over(PhysicalPlan::table_scan("logs"), 10, 0);

        let input = expect_limit_input(pass.optimize(plan));
        assert!(matches!(input, PhysicalPlan::TableScan { .. }));
    }

    #[test]
    fn test_limit_after_sort_not_optimized() {
        let ctx = create_test_context();
        let pass = LimitSkipByIndexPass::new(&ctx);

        let plan = limit_over(
            PhysicalPlan::Sort {
                input: Box::new(users_index_scan(None, None)),
                order_by: alloc::vec![(Expr::column("users", "id", 0), SortOrder::Asc)],
            },
            10,
            0,
        );

        let input = expect_limit_input(pass.optimize(plan));
        assert!(matches!(input, PhysicalPlan::Sort { .. }));
    }

    #[test]
    fn test_limit_after_aggregate_not_optimized() {
        let ctx = create_test_context();
        let pass = LimitSkipByIndexPass::new(&ctx);

        let plan = limit_over(
            PhysicalPlan::HashAggregate {
                input: Box::new(users_index_scan(None, None)),
                group_by: alloc::vec![],
                aggregates: alloc::vec![],
            },
            10,
            0,
        );

        let input = expect_limit_input(pass.optimize(plan));
        assert!(matches!(input, PhysicalPlan::HashAggregate { .. }));
    }

    #[test]
    fn test_existing_limit_not_overridden() {
        let ctx = create_test_context();
        let pass = LimitSkipByIndexPass::new(&ctx);

        let plan = limit_over(users_index_scan(Some(10), None), 5, 0);

        match pass.optimize(plan) {
            PhysicalPlan::Limit {
                input,
                limit,
                offset,
            } => {
                assert_eq!(limit, 5);
                assert_eq!(offset, 0);
                // The inner scan must keep its original bounds untouched.
                match *input {
                    PhysicalPlan::IndexScan { limit, offset, .. } => {
                        assert_eq!(limit, Some(10));
                        assert_eq!(offset, None);
                    }
                    other => panic!("Expected IndexScan under Limit, got {:?}", other),
                }
            }
            other => panic!("Expected Limit, got {:?}", other),
        }
    }

    #[test]
    fn test_existing_offset_not_overridden() {
        let ctx = create_test_context();
        let pass = LimitSkipByIndexPass::new(&ctx);

        let plan = limit_over(users_index_scan(None, Some(3)), 10, 0);

        match expect_limit_input(pass.optimize(plan)) {
            PhysicalPlan::IndexScan { limit, offset, .. } => {
                assert_eq!(limit, None);
                assert_eq!(offset, Some(3));
            }
            other => panic!("Expected IndexScan under Limit, got {:?}", other),
        }
    }
}