Skip to main content

quill_sql/optimizer/rule/
push_down_limit.rs

1use crate::error::QuillSQLResult;
2use crate::optimizer::logical_optimizer::ApplyOrder;
3use crate::optimizer::LogicalOptimizerRule;
4use crate::plan::logical_plan::{LogicalPlan, Sort};
5
/// Logical optimizer rule that copies the row bound of a `Limit` node into
/// the `Sort` node directly beneath it.
///
/// The rule carries no state, so it is a unit struct; the standard derives
/// make it printable in diagnostics and trivially copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct PushDownLimit;
7
8impl LogicalOptimizerRule for PushDownLimit {
9    fn try_optimize(&self, plan: &LogicalPlan) -> QuillSQLResult<Option<LogicalPlan>> {
10        let LogicalPlan::Limit(limit) = plan else {
11            return Ok(None);
12        };
13
14        let Some(limit_value) = limit.limit else {
15            return Ok(None);
16        };
17
18        match limit.input.as_ref() {
19            LogicalPlan::Sort(sort) => {
20                let new_limit = {
21                    let sort_limit = limit.offset + limit_value;
22                    Some(sort.limit.map(|f| f.min(sort_limit)).unwrap_or(sort_limit))
23                };
24                if new_limit == sort.limit {
25                    Ok(None)
26                } else {
27                    let new_sort = LogicalPlan::Sort(Sort {
28                        order_by: sort.order_by.clone(),
29                        input: sort.input.clone(),
30                        limit: new_limit,
31                    });
32                    plan.with_new_inputs(&[new_sort]).map(Some)
33                }
34            }
35            _ => Ok(None),
36        }
37    }
38
39    fn name(&self) -> &str {
40        "PushDownLimit"
41    }
42
43    fn apply_order(&self) -> Option<ApplyOrder> {
44        Some(ApplyOrder::TopDown)
45    }
46}
47
#[cfg(test)]
mod tests {
    use crate::database::Database;
    use crate::optimizer::rule::PushDownLimit;
    use crate::optimizer::LogicalOptimizer;
    use crate::plan::logical_plan::{LogicalPlan, Sort};
    use std::sync::Arc;

    /// Builds an optimizer that runs only the `PushDownLimit` rule.
    fn build_optimizer() -> LogicalOptimizer {
        LogicalOptimizer::with_rules(vec![Arc::new(PushDownLimit)])
    }

    /// A `LIMIT 10` above an `ORDER BY` must leave the `Limit` node on top
    /// and copy the bound into the `Sort` node beneath it.
    #[test]
    fn push_down_limit() {
        let mut db = Database::new_temp().unwrap();
        db.run("create table t1 (a int)").unwrap();

        let plan = db
            .create_logical_plan("select a from t1 order by a limit 10")
            .unwrap();
        let optimized_plan = build_optimizer().optimize(&plan).unwrap();

        let LogicalPlan::Limit(limit) = optimized_plan else {
            panic!("the first node should be limit");
        };
        let LogicalPlan::Sort(Sort {
            limit: sort_limit, ..
        }) = limit.input.as_ref()
        else {
            panic!("the second node should be sort");
        };
        assert_eq!(sort_limit, &Some(10));
    }
}