datafusion_physical_optimizer/
limit_pushdown_past_window.rs1use crate::PhysicalOptimizerRule;
19use datafusion_common::ScalarValue;
20use datafusion_common::config::ConfigOptions;
21use datafusion_common::tree_node::{Transformed, TreeNode};
22use datafusion_expr::{LimitEffect, WindowFrameBound, WindowFrameUnits};
23use datafusion_physical_expr::window::{
24 PlainAggregateWindowExpr, SlidingAggregateWindowExpr, StandardWindowExpr,
25 StandardWindowFunctionExpr, WindowExpr,
26};
27use datafusion_physical_plan::execution_plan::CardinalityEffect;
28use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
29use datafusion_physical_plan::repartition::RepartitionExec;
30use datafusion_physical_plan::sorts::sort::SortExec;
31use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
32use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowUDFExpr};
33use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
34use std::cmp;
35use std::sync::Arc;
36
/// Physical optimizer rule that looks for a limit sitting above a
/// `BoundedWindowAggExec` and copies it, widened by whatever extra rows the
/// window expressions need, into the `fetch` of a sort operator beneath the
/// window.
///
/// The rule carries no state of its own; all traversal bookkeeping lives in
/// values created per call to `optimize`.
#[derive(Debug, Clone, Default)]
pub struct LimitPushPastWindows;

impl LimitPushPastWindows {
    /// Creates the rule.
    pub fn new() -> Self {
        LimitPushPastWindows
    }
}
49
/// Which stage of the top-down traversal the rule is in.
#[derive(PartialEq, Eq)]
enum Phase {
    /// No window operator has been visited yet: limit-carrying nodes are
    /// still being recorded.
    FindOrGrow,
    /// A window operator has been visited: the next sort operator receives
    /// the tracked limit as its `fetch`.
    Apply,
}
55
/// Bookkeeping threaded through the plan traversal.
#[derive(Default)]
struct TraverseState {
    /// The limit most recently recorded, or `None` when no limit is being
    /// tracked.
    pub limit: Option<usize>,
    /// Number of rows beyond `limit` that must still be produced below the
    /// window operators visited so far.
    pub lookahead: usize,
}

impl TraverseState {
    /// Starts tracking `limit` and discards any lookahead gathered so far.
    pub fn reset_limit(&mut self, limit: Option<usize>) {
        *self = Self {
            limit,
            lookahead: 0,
        };
    }

    /// Raises the lookahead to `new_val` when that is larger than the
    /// current value; a smaller `new_val` leaves it untouched.
    pub fn max_lookahead(&mut self, new_val: usize) {
        if new_val > self.lookahead {
            self.lookahead = new_val;
        }
    }
}
72
73impl PhysicalOptimizerRule for LimitPushPastWindows {
74 fn optimize(
75 &self,
76 original: Arc<dyn ExecutionPlan>,
77 config: &ConfigOptions,
78 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
79 if !config.optimizer.enable_window_limits {
80 return Ok(original);
81 }
82 let mut ctx = TraverseState::default();
83 let mut phase = Phase::FindOrGrow;
84 let result = original.transform_down(|node| {
85 let reset = |node,
87 ctx: &mut TraverseState|
88 -> datafusion_common::Result<
89 Transformed<Arc<dyn ExecutionPlan>>,
90 > {
91 ctx.limit = None;
92 ctx.lookahead = 0;
93 Ok(Transformed::no(node))
94 };
95
96 if node.children().len() > 1 {
98 return reset(node, &mut ctx);
99 }
100
101 if phase == Phase::FindOrGrow && get_limit(&node, &mut ctx) {
103 return Ok(Transformed::no(node));
104 }
105
106 if let Some(window) = node.downcast_ref::<BoundedWindowAggExec>() {
108 phase = Phase::Apply;
109 if !grow_limit(window, &mut ctx) {
110 return reset(node, &mut ctx);
111 }
112 return Ok(Transformed::no(node));
113 }
114
115 if phase == Phase::Apply
117 && let Some(out) = apply_limit(&node, &mut ctx)
118 {
119 return Ok(out);
120 }
121
122 if !node.supports_limit_pushdown() {
124 return reset(node, &mut ctx);
125 }
126 if let Some(part) = node.downcast_ref::<RepartitionExec>() {
127 let output = part.partitioning().partition_count();
128 let input = part.input().output_partitioning().partition_count();
129 if output < input {
130 return reset(node, &mut ctx);
131 }
132 }
133 match node.cardinality_effect() {
134 CardinalityEffect::Unknown => return reset(node, &mut ctx),
135 CardinalityEffect::LowerEqual => return reset(node, &mut ctx),
136 CardinalityEffect::Equal => {}
137 CardinalityEffect::GreaterEqual => {}
138 }
139
140 Ok(Transformed::no(node))
141 })?;
142 Ok(result.data)
143 }
144
145 fn name(&self) -> &str {
146 "LimitPushPastWindows"
147 }
148
149 fn schema_check(&self) -> bool {
150 false }
152}
153
154fn grow_limit(window: &BoundedWindowAggExec, ctx: &mut TraverseState) -> bool {
155 let mut max_rel = 0;
156 for expr in window.window_expr().iter() {
157 match get_limit_effect(expr) {
159 LimitEffect::None => {}
160 LimitEffect::Unknown => return false,
161 LimitEffect::Relative(rel) => max_rel = max_rel.max(rel),
162 LimitEffect::Absolute(val) => {
163 let cur = ctx.limit.unwrap_or(0);
164 ctx.limit = Some(cur.max(val))
165 }
166 }
167
168 let frame = expr.get_window_frame();
170 if frame.units != WindowFrameUnits::Rows {
171 return false; }
173 let Some(end_bound) = bound_to_usize(&frame.end_bound) else {
174 return false; };
176 ctx.max_lookahead(end_bound);
177 }
178
179 ctx.max_lookahead(ctx.lookahead + max_rel);
181 true
182}
183
184fn apply_limit(
185 node: &Arc<dyn ExecutionPlan>,
186 ctx: &mut TraverseState,
187) -> Option<Transformed<Arc<dyn ExecutionPlan>>> {
188 if !node.is::<SortExec>() && !node.is::<SortPreservingMergeExec>() {
189 return None;
190 }
191 let latest = ctx.limit.take();
192 let Some(fetch) = latest else {
193 ctx.limit = None;
194 ctx.lookahead = 0;
195 return Some(Transformed::no(Arc::clone(node)));
196 };
197 let fetch = match node.fetch() {
198 None => fetch + ctx.lookahead,
199 Some(existing) => cmp::min(existing, fetch + ctx.lookahead),
200 };
201 Some(Transformed::complete(node.with_fetch(Some(fetch)).unwrap()))
202}
203
204fn get_limit(node: &Arc<dyn ExecutionPlan>, ctx: &mut TraverseState) -> bool {
205 if let Some(limit) = node.downcast_ref::<GlobalLimitExec>() {
206 ctx.reset_limit(limit.fetch().map(|fetch| fetch + limit.skip()));
207 return true;
208 }
209 if let Some(limit) = node.downcast_ref::<LocalLimitExec>() {
212 ctx.reset_limit(Some(limit.fetch()));
213 return true;
214 }
215 if let Some(limit) = node.downcast_ref::<SortPreservingMergeExec>() {
216 ctx.reset_limit(limit.fetch());
217 return true;
218 }
219 false
220}
221
222fn get_limit_effect(expr: &Arc<dyn WindowExpr>) -> LimitEffect {
235 if expr.as_any().is::<PlainAggregateWindowExpr>()
237 || expr.as_any().is::<SlidingAggregateWindowExpr>()
238 {
239 return LimitEffect::None;
240 }
241
242 let Some(swe) = expr.as_any().downcast_ref::<StandardWindowExpr>() else {
244 return LimitEffect::Unknown; };
246 let swfe = swe.get_standard_func_expr();
247 let Some(udf) = swfe.as_any().downcast_ref::<WindowUDFExpr>() else {
248 return LimitEffect::Unknown; };
250 udf.limit_effect()
251}
252
253fn bound_to_usize(bound: &WindowFrameBound) -> Option<usize> {
254 match bound {
255 WindowFrameBound::Preceding(_) => Some(0),
256 WindowFrameBound::CurrentRow => Some(0),
257 WindowFrameBound::Following(ScalarValue::UInt64(Some(scalar))) => {
258 Some(*scalar as usize)
259 }
260 _ => None,
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::WindowFrame;
    use datafusion_functions_window::row_number::row_number_udwf;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
    use datafusion_physical_plan::InputOrderMode;
    use datafusion_physical_plan::displayable;
    use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
    use datafusion_physical_plan::windows::{
        BoundedWindowAggExec, create_udwf_window_expr,
    };
    use insta::assert_snapshot;

    /// Renders `plan` in the indented display format used by the snapshots.
    fn plan_str(plan: &dyn ExecutionPlan) -> String {
        let rendered = displayable(plan).indent(true);
        rendered.to_string()
    }

    /// Schema with a single non-nullable `Int64` column named `a`.
    fn schema() -> Arc<Schema> {
        let column = Field::new("a", DataType::Int64, false);
        Arc::new(Schema::new(vec![column]))
    }

    /// Builds `Limit -> BoundedWindowAggExec(row_number) -> SortExec ->
    /// PlaceholderRowExec`, where the limit is a `LocalLimitExec` when
    /// `use_local_limit` is set and a `GlobalLimitExec` otherwise. Both
    /// limits fetch 100 rows.
    fn build_window_plan(
        use_local_limit: bool,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        let table_schema = schema();
        let source: Arc<dyn ExecutionPlan> =
            Arc::new(PlaceholderRowExec::new(Arc::clone(&table_schema)));

        let sort_key = PhysicalSortExpr::new_default(col("a", &table_schema)?).asc();
        let ordering = LexOrdering::new(vec![sort_key]).unwrap();

        let sorted: Arc<dyn ExecutionPlan> = Arc::new(
            SortExec::new(ordering.clone(), source).with_preserve_partitioning(true),
        );

        let row_number = create_udwf_window_expr(
            &row_number_udwf(),
            &[],
            &table_schema,
            "row_number".to_string(),
            false,
        )?;
        let frame = WindowFrame::new_bounds(
            WindowFrameUnits::Rows,
            WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
            WindowFrameBound::CurrentRow,
        );
        let window_expr = Arc::new(StandardWindowExpr::new(
            row_number,
            &[],
            ordering.as_ref(),
            Arc::new(frame),
        ));

        let windowed: Arc<dyn ExecutionPlan> = Arc::new(BoundedWindowAggExec::try_new(
            vec![window_expr],
            sorted,
            InputOrderMode::Sorted,
            true,
        )?);

        let limited: Arc<dyn ExecutionPlan> = if use_local_limit {
            Arc::new(LocalLimitExec::new(windowed, 100))
        } else {
            Arc::new(GlobalLimitExec::new(windowed, 0, Some(100)))
        };

        Ok(limited)
    }

    /// Runs the rule on `plan` with `enable_window_limits` switched on.
    fn optimize(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        let mut options = ConfigOptions::new();
        options.optimizer.enable_window_limits = true;
        let rule = LimitPushPastWindows::new();
        rule.optimize(plan, &options).unwrap()
    }

    /// A `GlobalLimitExec` above the window ends up as a fetch on the sort.
    #[test]
    fn global_limit_pushes_past_window() {
        let input = build_window_plan(false).unwrap();
        let output = optimize(input);
        assert_snapshot!(plan_str(output.as_ref()), @r#"
        GlobalLimitExec: skip=0, fetch=100
          BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
            SortExec: TopK(fetch=100), expr=[a@0 ASC], preserve_partitioning=[true]
              PlaceholderRowExec
        "#);
    }

    /// A `LocalLimitExec` above the window ends up as a fetch on the sort.
    #[test]
    fn local_limit_pushes_past_window() {
        let input = build_window_plan(true).unwrap();
        let output = optimize(input);
        assert_snapshot!(plan_str(output.as_ref()), @r#"
        LocalLimitExec: fetch=100
          BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
            SortExec: TopK(fetch=100), expr=[a@0 ASC], preserve_partitioning=[true]
              PlaceholderRowExec
        "#);
    }
}