datafusion_optimizer/
eliminate_limit.rs1use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::Result;
23use datafusion_expr::logical_plan::{EmptyRelation, FetchType, LogicalPlan, SkipType};
24use std::sync::Arc;
25
/// Optimizer rule that simplifies `Limit` nodes whose effect is fully
/// determined by their literal `skip` / `fetch` values.
///
/// - A `Limit` with a literal `fetch` of `0` is replaced by an empty relation
///   (no rows) that keeps the input's schema.
/// - A `Limit` with no `fetch` and a literal `skip` of `0` does nothing, so it
///   is removed and its input takes its place.
#[derive(Default, Debug)]
pub struct EliminateLimit;

impl EliminateLimit {
    /// Creates a new [`EliminateLimit`] optimizer rule.
    pub fn new() -> Self {
        Self
    }
}
42
43impl OptimizerRule for EliminateLimit {
44 fn name(&self) -> &str {
45 "eliminate_limit"
46 }
47
48 fn apply_order(&self) -> Option<ApplyOrder> {
49 Some(ApplyOrder::BottomUp)
50 }
51
52 fn supports_rewrite(&self) -> bool {
53 true
54 }
55
56 fn rewrite(
57 &self,
58 plan: LogicalPlan,
59 _config: &dyn OptimizerConfig,
60 ) -> Result<Transformed<LogicalPlan>, datafusion_common::DataFusionError> {
61 match plan {
62 LogicalPlan::Limit(limit) => {
63 let FetchType::Literal(fetch) = limit.get_fetch_type()? else {
65 return Ok(Transformed::no(LogicalPlan::Limit(limit)));
66 };
67
68 if let Some(v) = fetch {
69 if v == 0 {
70 return Ok(Transformed::yes(LogicalPlan::EmptyRelation(
71 EmptyRelation {
72 produce_one_row: false,
73 schema: Arc::clone(limit.input.schema()),
74 },
75 )));
76 }
77 } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) {
78 #[allow(clippy::used_underscore_binding)]
81 return self.rewrite(Arc::unwrap_or_clone(limit.input), _config);
82 }
83 Ok(Transformed::no(LogicalPlan::Limit(limit)))
84 }
85 _ => Ok(Transformed::no(plan)),
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizer::Optimizer;
    use crate::test::*;
    use crate::OptimizerContext;
    use datafusion_common::Column;
    use datafusion_expr::{
        col,
        logical_plan::{builder::LogicalPlanBuilder, JoinType},
    };
    use std::sync::Arc;

    use crate::push_down_limit::PushDownLimit;
    use datafusion_expr::test::function_stub::sum;

    /// No-op observer passed to `Optimizer::optimize`.
    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}

    /// Runs only `EliminateLimit` over `plan` and compares the displayed
    /// result with `expected`.
    fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
        let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]);
        let optimized_plan =
            optimizer.optimize(plan, &OptimizerContext::new(), observe)?;

        let formatted_plan = format!("{optimized_plan}");
        assert_eq!(formatted_plan, expected);
        Ok(())
    }

    /// Runs `PushDownLimit` then `EliminateLimit` for a single pass over
    /// `plan` and compares the displayed result with `expected`.
    fn assert_optimized_plan_eq_with_pushdown(
        plan: LogicalPlan,
        expected: &str,
    ) -> Result<()> {
        let config = OptimizerContext::new().with_max_passes(1);
        let optimizer = Optimizer::with_rules(vec![
            Arc::new(PushDownLimit::new()),
            Arc::new(EliminateLimit::new()),
        ]);
        let optimized_plan = optimizer.optimize(plan, &config, observe)?;
        let formatted_plan = format!("{optimized_plan}");
        assert_eq!(formatted_plan, expected);
        Ok(())
    }

    #[test]
    fn limit_0_root() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .limit(0, Some(0))?
            .build()?;
        // The whole plan collapses to an empty relation.
        let expected = "EmptyRelation";
        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn limit_0_nested() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .limit(0, Some(0))?
            .union(plan1)?
            .build()?;

        // Only the LIMIT 0 side of the union becomes empty.
        let expected = "Union\
            \n  EmptyRelation\
            \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n    TableScan: test";
        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn limit_fetch_with_ancestor_limit_skip() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .limit(0, Some(2))?
            .limit(2, None)?
            .build()?;

        // Skipping 2 of at most 2 rows yields nothing.
        let expected = "EmptyRelation";
        assert_optimized_plan_eq_with_pushdown(plan, expected)
    }

    #[test]
    fn multi_limit_offset_sort_eliminate() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .limit(0, Some(2))?
            .sort_by(vec![col("a")])?
            .limit(2, Some(1))?
            .build()?;

        let expected = "Limit: skip=2, fetch=1\
            \n  Sort: test.a ASC NULLS LAST, fetch=3\
            \n    Limit: skip=0, fetch=2\
            \n      Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n        TableScan: test";
        assert_optimized_plan_eq_with_pushdown(plan, expected)
    }

    #[test]
    fn limit_fetch_with_ancestor_limit_fetch() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .limit(0, Some(2))?
            .sort_by(vec![col("a")])?
            .limit(0, Some(1))?
            .build()?;

        let expected = "Limit: skip=0, fetch=1\
            \n  Sort: test.a ASC NULLS LAST\
            \n    Limit: skip=0, fetch=2\
            \n      Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n        TableScan: test";
        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn limit_with_ancestor_limit() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .limit(2, Some(1))?
            .sort_by(vec![col("a")])?
            .limit(3, Some(1))?
            .build()?;

        let expected = "Limit: skip=3, fetch=1\
            \n  Sort: test.a ASC NULLS LAST\
            \n    Limit: skip=2, fetch=1\
            \n      Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n        TableScan: test";
        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn limit_join_with_ancestor_limit() -> Result<()> {
        let table_scan = test_table_scan()?;
        let table_scan_inner = test_table_scan_with_name("test1")?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .limit(2, Some(1))?
            .join_using(
                table_scan_inner,
                JoinType::Inner,
                vec![Column::from_name("a".to_string())],
            )?
            .limit(3, Some(1))?
            .build()?;

        let expected = "Limit: skip=3, fetch=1\
            \n  Inner Join: Using test.a = test1.a\
            \n    Limit: skip=2, fetch=1\
            \n      TableScan: test\
            \n    TableScan: test1";
        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn remove_zero_offset() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .limit(0, None)?
            .build()?;

        // skip=0 with no fetch is a no-op, so the Limit disappears.
        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n  TableScan: test";
        assert_optimized_plan_eq(plan, expected)
    }
}