datafusion_optimizer/
eliminate_limit.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateLimit`] eliminates `LIMIT` when possible
19use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::Result;
23use datafusion_expr::logical_plan::{EmptyRelation, FetchType, LogicalPlan, SkipType};
24use std::sync::Arc;
25
/// Optimizer rule that removes `LIMIT` nodes which cannot affect the result.
///
/// Two cases are rewritten, both only when the values involved are literals:
///
/// * a `LIMIT` whose fetch is `0` is replaced by an [`EmptyRelation`] that
///   produces no rows and keeps the schema of the limit's input;
/// * a `LIMIT` without a fetch whose skip (`OFFSET`) is `0` is dropped, leaving
///   its input in place.
///
/// Run together with `push_down_limit`, this also clears out a `LIMIT` whose
/// ancestor `LIMIT` skips at least as many rows as it fetches. The rule is
/// meant to cooperate with `propagate_empty_relation`, which can act on the
/// empty relations produced here.
#[derive(Debug, Default)]
pub struct EliminateLimit;
35
36impl EliminateLimit {
37    #[allow(missing_docs)]
38    pub fn new() -> Self {
39        Self {}
40    }
41}
42
43impl OptimizerRule for EliminateLimit {
44    fn name(&self) -> &str {
45        "eliminate_limit"
46    }
47
48    fn apply_order(&self) -> Option<ApplyOrder> {
49        Some(ApplyOrder::BottomUp)
50    }
51
52    fn supports_rewrite(&self) -> bool {
53        true
54    }
55
56    fn rewrite(
57        &self,
58        plan: LogicalPlan,
59        _config: &dyn OptimizerConfig,
60    ) -> Result<Transformed<LogicalPlan>, datafusion_common::DataFusionError> {
61        match plan {
62            LogicalPlan::Limit(limit) => {
63                // Only supports rewriting for literal fetch
64                let FetchType::Literal(fetch) = limit.get_fetch_type()? else {
65                    return Ok(Transformed::no(LogicalPlan::Limit(limit)));
66                };
67
68                if let Some(v) = fetch {
69                    if v == 0 {
70                        return Ok(Transformed::yes(LogicalPlan::EmptyRelation(
71                            EmptyRelation {
72                                produce_one_row: false,
73                                schema: Arc::clone(limit.input.schema()),
74                            },
75                        )));
76                    }
77                } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) {
78                    // If fetch is `None` and skip is 0, then Limit takes no effect and
79                    // we can remove it. Its input also can be Limit, so we should apply again.
80                    #[allow(clippy::used_underscore_binding)]
81                    return self.rewrite(Arc::unwrap_or_clone(limit.input), _config);
82                }
83                Ok(Transformed::no(LogicalPlan::Limit(limit)))
84            }
85            _ => Ok(Transformed::no(plan)),
86        }
87    }
88}
89
90#[cfg(test)]
91mod tests {
92    use super::*;
93    use crate::optimizer::Optimizer;
94    use crate::test::*;
95    use crate::OptimizerContext;
96    use datafusion_common::Column;
97    use datafusion_expr::{
98        col,
99        logical_plan::{builder::LogicalPlanBuilder, JoinType},
100    };
101    use std::sync::Arc;
102
103    use crate::push_down_limit::PushDownLimit;
104    use datafusion_expr::test::function_stub::sum;
105
106    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
107    fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
108        let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]);
109        let optimized_plan =
110            optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
111
112        let formatted_plan = format!("{optimized_plan}");
113        assert_eq!(formatted_plan, expected);
114        Ok(())
115    }
116
117    fn assert_optimized_plan_eq_with_pushdown(
118        plan: LogicalPlan,
119        expected: &str,
120    ) -> Result<()> {
121        fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
122        let config = OptimizerContext::new().with_max_passes(1);
123        let optimizer = Optimizer::with_rules(vec![
124            Arc::new(PushDownLimit::new()),
125            Arc::new(EliminateLimit::new()),
126        ]);
127        let optimized_plan = optimizer
128            .optimize(plan, &config, observe)
129            .expect("failed to optimize plan");
130        let formatted_plan = format!("{optimized_plan}");
131        assert_eq!(formatted_plan, expected);
132        Ok(())
133    }
134
135    #[test]
136    fn limit_0_root() -> Result<()> {
137        let table_scan = test_table_scan().unwrap();
138        let plan = LogicalPlanBuilder::from(table_scan)
139            .aggregate(vec![col("a")], vec![sum(col("b"))])?
140            .limit(0, Some(0))?
141            .build()?;
142        // No aggregate / scan / limit
143        let expected = "EmptyRelation";
144        assert_optimized_plan_eq(plan, expected)
145    }
146
147    #[test]
148    fn limit_0_nested() -> Result<()> {
149        let table_scan = test_table_scan()?;
150        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
151            .aggregate(vec![col("a")], vec![sum(col("b"))])?
152            .build()?;
153        let plan = LogicalPlanBuilder::from(table_scan)
154            .aggregate(vec![col("a")], vec![sum(col("b"))])?
155            .limit(0, Some(0))?
156            .union(plan1)?
157            .build()?;
158
159        // Left side is removed
160        let expected = "Union\
161            \n  EmptyRelation\
162            \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
163            \n    TableScan: test";
164        assert_optimized_plan_eq(plan, expected)
165    }
166
167    #[test]
168    fn limit_fetch_with_ancestor_limit_skip() -> Result<()> {
169        let table_scan = test_table_scan()?;
170        let plan = LogicalPlanBuilder::from(table_scan)
171            .aggregate(vec![col("a")], vec![sum(col("b"))])?
172            .limit(0, Some(2))?
173            .limit(2, None)?
174            .build()?;
175
176        // No aggregate / scan / limit
177        let expected = "EmptyRelation";
178        assert_optimized_plan_eq_with_pushdown(plan, expected)
179    }
180
181    #[test]
182    fn multi_limit_offset_sort_eliminate() -> Result<()> {
183        let table_scan = test_table_scan()?;
184        let plan = LogicalPlanBuilder::from(table_scan)
185            .aggregate(vec![col("a")], vec![sum(col("b"))])?
186            .limit(0, Some(2))?
187            .sort_by(vec![col("a")])?
188            .limit(2, Some(1))?
189            .build()?;
190
191        // After remove global-state, we don't record the parent <skip, fetch>
192        // So, bottom don't know parent info, so can't eliminate.
193        let expected = "Limit: skip=2, fetch=1\
194        \n  Sort: test.a ASC NULLS LAST, fetch=3\
195        \n    Limit: skip=0, fetch=2\
196        \n      Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
197        \n        TableScan: test";
198        assert_optimized_plan_eq_with_pushdown(plan, expected)
199    }
200
201    #[test]
202    fn limit_fetch_with_ancestor_limit_fetch() -> Result<()> {
203        let table_scan = test_table_scan()?;
204        let plan = LogicalPlanBuilder::from(table_scan)
205            .aggregate(vec![col("a")], vec![sum(col("b"))])?
206            .limit(0, Some(2))?
207            .sort_by(vec![col("a")])?
208            .limit(0, Some(1))?
209            .build()?;
210
211        let expected = "Limit: skip=0, fetch=1\
212            \n  Sort: test.a ASC NULLS LAST\
213            \n    Limit: skip=0, fetch=2\
214            \n      Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
215            \n        TableScan: test";
216        assert_optimized_plan_eq(plan, expected)
217    }
218
219    #[test]
220    fn limit_with_ancestor_limit() -> Result<()> {
221        let table_scan = test_table_scan().unwrap();
222        let plan = LogicalPlanBuilder::from(table_scan)
223            .aggregate(vec![col("a")], vec![sum(col("b"))])?
224            .limit(2, Some(1))?
225            .sort_by(vec![col("a")])?
226            .limit(3, Some(1))?
227            .build()?;
228
229        let expected = "Limit: skip=3, fetch=1\
230        \n  Sort: test.a ASC NULLS LAST\
231        \n    Limit: skip=2, fetch=1\
232        \n      Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
233        \n        TableScan: test";
234        assert_optimized_plan_eq(plan, expected)
235    }
236
237    #[test]
238    fn limit_join_with_ancestor_limit() -> Result<()> {
239        let table_scan = test_table_scan()?;
240        let table_scan_inner = test_table_scan_with_name("test1")?;
241        let plan = LogicalPlanBuilder::from(table_scan)
242            .limit(2, Some(1))?
243            .join_using(
244                table_scan_inner,
245                JoinType::Inner,
246                vec![Column::from_name("a".to_string())],
247            )?
248            .limit(3, Some(1))?
249            .build()?;
250
251        let expected = "Limit: skip=3, fetch=1\
252            \n  Inner Join: Using test.a = test1.a\
253            \n    Limit: skip=2, fetch=1\
254            \n      TableScan: test\
255            \n    TableScan: test1";
256        assert_optimized_plan_eq(plan, expected)
257    }
258
259    #[test]
260    fn remove_zero_offset() -> Result<()> {
261        let table_scan = test_table_scan()?;
262        let plan = LogicalPlanBuilder::from(table_scan)
263            .aggregate(vec![col("a")], vec![sum(col("b"))])?
264            .limit(0, None)?
265            .build()?;
266
267        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
268            \n  TableScan: test";
269        assert_optimized_plan_eq(plan, expected)
270    }
271}