Skip to main content

datafusion_physical_optimizer/
limit_pushdown_past_window.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::PhysicalOptimizerRule;
19use datafusion_common::ScalarValue;
20use datafusion_common::config::ConfigOptions;
21use datafusion_common::tree_node::{Transformed, TreeNode};
22use datafusion_expr::{LimitEffect, WindowFrameBound, WindowFrameUnits};
23use datafusion_physical_expr::window::{
24    PlainAggregateWindowExpr, SlidingAggregateWindowExpr, StandardWindowExpr,
25    StandardWindowFunctionExpr, WindowExpr,
26};
27use datafusion_physical_plan::execution_plan::CardinalityEffect;
28use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
29use datafusion_physical_plan::repartition::RepartitionExec;
30use datafusion_physical_plan::sorts::sort::SortExec;
31use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
32use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowUDFExpr};
33use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
34use std::cmp;
35use std::sync::Arc;
36
/// Physical optimizer rule that recovers fetch limits which `LimitPushdown`
/// could not push down because a [`BoundedWindowAggExec`] sat between the
/// limit and a sort.
///
/// When every window expression uses a [`WindowFrameUnits::Rows`] frame with a
/// statically known end bound, the limit is enlarged by the number of rows the
/// window may read ahead. It is then written into the sort (or sort-preserving
/// merge) below the window, so that operator can stop early.
#[derive(Default, Clone, Debug)]
pub struct LimitPushPastWindows;

impl LimitPushPastWindows {
    /// Creates the rule; equivalent to [`LimitPushPastWindows::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
49
/// Which step of the pushdown the top-down traversal is currently in.
#[derive(PartialEq, Eq)]
enum Phase {
    /// No window passed yet: limit-bearing nodes (re)define the tracked limit.
    FindOrGrow,
    /// A window has been passed: the next sort / sort-preserving merge
    /// receives the grown limit.
    Apply,
}
55
/// Mutable bookkeeping carried through the top-down plan traversal.
#[derive(Default)]
struct TraverseState {
    /// Most recent fetch limit seen above the current node; `None` means no limit.
    pub limit: Option<usize>,
    /// Extra rows that must be produced beyond `limit` so the windows passed
    /// so far can compute their output.
    pub lookahead: usize,
}

impl TraverseState {
    /// Replaces the tracked limit with `limit` and discards any accumulated lookahead.
    pub fn reset_limit(&mut self, limit: Option<usize>) {
        *self = Self {
            limit,
            lookahead: 0,
        };
    }

    /// Raises the lookahead to `new_val` when that is larger; never lowers it.
    pub fn max_lookahead(&mut self, new_val: usize) {
        if new_val > self.lookahead {
            self.lookahead = new_val;
        }
    }
}
72
73impl PhysicalOptimizerRule for LimitPushPastWindows {
74    fn optimize(
75        &self,
76        original: Arc<dyn ExecutionPlan>,
77        config: &ConfigOptions,
78    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
79        if !config.optimizer.enable_window_limits {
80            return Ok(original);
81        }
82        let mut ctx = TraverseState::default();
83        let mut phase = Phase::FindOrGrow;
84        let result = original.transform_down(|node| {
85            // helper closure to DRY out most the early return cases
86            let reset = |node,
87                         ctx: &mut TraverseState|
88             -> datafusion_common::Result<
89                Transformed<Arc<dyn ExecutionPlan>>,
90            > {
91                ctx.limit = None;
92                ctx.lookahead = 0;
93                Ok(Transformed::no(node))
94            };
95
96            // traversing sides of joins will require more thought
97            if node.children().len() > 1 {
98                return reset(node, &mut ctx);
99            }
100
101            // grab the latest limit we see
102            if phase == Phase::FindOrGrow && get_limit(&node, &mut ctx) {
103                return Ok(Transformed::no(node));
104            }
105
106            // grow the limit if we hit a window function
107            if let Some(window) = node.downcast_ref::<BoundedWindowAggExec>() {
108                phase = Phase::Apply;
109                if !grow_limit(window, &mut ctx) {
110                    return reset(node, &mut ctx);
111                }
112                return Ok(Transformed::no(node));
113            }
114
115            // Apply the limit if we hit a sortpreservingmerge node
116            if phase == Phase::Apply
117                && let Some(out) = apply_limit(&node, &mut ctx)
118            {
119                return Ok(out);
120            }
121
122            // nodes along the way
123            if !node.supports_limit_pushdown() {
124                return reset(node, &mut ctx);
125            }
126            if let Some(part) = node.downcast_ref::<RepartitionExec>() {
127                let output = part.partitioning().partition_count();
128                let input = part.input().output_partitioning().partition_count();
129                if output < input {
130                    return reset(node, &mut ctx);
131                }
132            }
133            match node.cardinality_effect() {
134                CardinalityEffect::Unknown => return reset(node, &mut ctx),
135                CardinalityEffect::LowerEqual => return reset(node, &mut ctx),
136                CardinalityEffect::Equal => {}
137                CardinalityEffect::GreaterEqual => {}
138            }
139
140            Ok(Transformed::no(node))
141        })?;
142        Ok(result.data)
143    }
144
145    fn name(&self) -> &str {
146        "LimitPushPastWindows"
147    }
148
149    fn schema_check(&self) -> bool {
150        false // we don't change the schema
151    }
152}
153
154fn grow_limit(window: &BoundedWindowAggExec, ctx: &mut TraverseState) -> bool {
155    let mut max_rel = 0;
156    for expr in window.window_expr().iter() {
157        // grow based on function requirements
158        match get_limit_effect(expr) {
159            LimitEffect::None => {}
160            LimitEffect::Unknown => return false,
161            LimitEffect::Relative(rel) => max_rel = max_rel.max(rel),
162            LimitEffect::Absolute(val) => {
163                let cur = ctx.limit.unwrap_or(0);
164                ctx.limit = Some(cur.max(val))
165            }
166        }
167
168        // grow based on frames
169        let frame = expr.get_window_frame();
170        if frame.units != WindowFrameUnits::Rows {
171            return false; // expression-based limits not statically evaluatable
172        }
173        let Some(end_bound) = bound_to_usize(&frame.end_bound) else {
174            return false; // can't optimize unbounded window expressions
175        };
176        ctx.max_lookahead(end_bound);
177    }
178
179    // finish grow
180    ctx.max_lookahead(ctx.lookahead + max_rel);
181    true
182}
183
184fn apply_limit(
185    node: &Arc<dyn ExecutionPlan>,
186    ctx: &mut TraverseState,
187) -> Option<Transformed<Arc<dyn ExecutionPlan>>> {
188    if !node.is::<SortExec>() && !node.is::<SortPreservingMergeExec>() {
189        return None;
190    }
191    let latest = ctx.limit.take();
192    let Some(fetch) = latest else {
193        ctx.limit = None;
194        ctx.lookahead = 0;
195        return Some(Transformed::no(Arc::clone(node)));
196    };
197    let fetch = match node.fetch() {
198        None => fetch + ctx.lookahead,
199        Some(existing) => cmp::min(existing, fetch + ctx.lookahead),
200    };
201    Some(Transformed::complete(node.with_fetch(Some(fetch)).unwrap()))
202}
203
204fn get_limit(node: &Arc<dyn ExecutionPlan>, ctx: &mut TraverseState) -> bool {
205    if let Some(limit) = node.downcast_ref::<GlobalLimitExec>() {
206        ctx.reset_limit(limit.fetch().map(|fetch| fetch + limit.skip()));
207        return true;
208    }
209    // In distributed execution, GlobalLimitExec becomes LocalLimitExec
210    // per partition. Handle it the same way (LocalLimitExec has no skip).
211    if let Some(limit) = node.downcast_ref::<LocalLimitExec>() {
212        ctx.reset_limit(Some(limit.fetch()));
213        return true;
214    }
215    if let Some(limit) = node.downcast_ref::<SortPreservingMergeExec>() {
216        ctx.reset_limit(limit.fetch());
217        return true;
218    }
219    false
220}
221
222/// Examines the `WindowExpr` and decides:
223/// 1. The expression does not change the window size
224/// 2. The expression grows it by X amount
225/// 3. We don't know
226///
227/// # Arguments
228///
229/// * `expr` the expression to examine
230///
231/// # Returns
232///
233/// The effect on the limit
234fn get_limit_effect(expr: &Arc<dyn WindowExpr>) -> LimitEffect {
235    // White list aggregates
236    if expr.as_any().is::<PlainAggregateWindowExpr>()
237        || expr.as_any().is::<SlidingAggregateWindowExpr>()
238    {
239        return LimitEffect::None;
240    }
241
242    // Grab the window function
243    let Some(swe) = expr.as_any().downcast_ref::<StandardWindowExpr>() else {
244        return LimitEffect::Unknown; // should be only remaining type
245    };
246    let swfe = swe.get_standard_func_expr();
247    let Some(udf) = swfe.as_any().downcast_ref::<WindowUDFExpr>() else {
248        return LimitEffect::Unknown; // should be only remaining type
249    };
250    udf.limit_effect()
251}
252
253fn bound_to_usize(bound: &WindowFrameBound) -> Option<usize> {
254    match bound {
255        WindowFrameBound::Preceding(_) => Some(0),
256        WindowFrameBound::CurrentRow => Some(0),
257        WindowFrameBound::Following(ScalarValue::UInt64(Some(scalar))) => {
258            Some(*scalar as usize)
259        }
260        _ => None,
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::WindowFrame;
    use datafusion_functions_window::row_number::row_number_udwf;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
    use datafusion_physical_plan::InputOrderMode;
    use datafusion_physical_plan::displayable;
    use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
    use datafusion_physical_plan::windows::{
        BoundedWindowAggExec, create_udwf_window_expr,
    };
    use insta::assert_snapshot;

    fn plan_str(plan: &dyn ExecutionPlan) -> String {
        displayable(plan).indent(true).to_string()
    }

    fn schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]))
    }

    /// Build: BoundedWindowAggExec(row_number, ROWS BETWEEN UNBOUNDED PRECEDING
    /// AND `end_bound`) → SortExec(a ASC) → PlaceholderRowExec
    fn build_window(
        end_bound: WindowFrameBound,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        let s = schema();
        let input: Arc<dyn ExecutionPlan> =
            Arc::new(PlaceholderRowExec::new(Arc::clone(&s)));

        let ordering =
            LexOrdering::new(vec![PhysicalSortExpr::new_default(col("a", &s)?).asc()])
                .unwrap();

        let sort: Arc<dyn ExecutionPlan> = Arc::new(
            SortExec::new(ordering.clone(), input).with_preserve_partitioning(true),
        );

        let window_expr = Arc::new(StandardWindowExpr::new(
            create_udwf_window_expr(
                &row_number_udwf(),
                &[],
                &s,
                "row_number".to_string(),
                false,
            )?,
            &[],
            ordering.as_ref(),
            Arc::new(WindowFrame::new_bounds(
                WindowFrameUnits::Rows,
                WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
                end_bound,
            )),
        ));

        let window: Arc<dyn ExecutionPlan> = Arc::new(BoundedWindowAggExec::try_new(
            vec![window_expr],
            sort,
            InputOrderMode::Sorted,
            true,
        )?);

        Ok(window)
    }

    /// Build: LocalLimitExec or GlobalLimitExec → BoundedWindowAggExec(row_number) → SortExec
    fn build_window_plan(
        use_local_limit: bool,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        let window = build_window(WindowFrameBound::CurrentRow)?;

        let limit: Arc<dyn ExecutionPlan> = if use_local_limit {
            Arc::new(LocalLimitExec::new(window, 100))
        } else {
            Arc::new(GlobalLimitExec::new(window, 0, Some(100)))
        };

        Ok(limit)
    }

    fn optimize_with(plan: Arc<dyn ExecutionPlan>, enabled: bool) -> Arc<dyn ExecutionPlan> {
        let mut config = ConfigOptions::new();
        config.optimizer.enable_window_limits = enabled;
        LimitPushPastWindows::new().optimize(plan, &config).unwrap()
    }

    fn optimize(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        optimize_with(plan, true)
    }

    /// Follows first children down from `plan` and returns the `fetch` of the
    /// first `SortExec` encountered.
    fn sort_fetch(plan: &Arc<dyn ExecutionPlan>) -> Option<usize> {
        if plan.is::<SortExec>() {
            return plan.fetch();
        }
        let children = plan.children();
        let child = children.first().expect("plan should contain a SortExec");
        sort_fetch(child)
    }

    /// GlobalLimitExec above a windowed sort should push fetch into the SortExec.
    #[test]
    fn global_limit_pushes_past_window() {
        let plan = build_window_plan(false).unwrap();
        let optimized = optimize(plan);
        assert_snapshot!(plan_str(optimized.as_ref()), @r#"
        GlobalLimitExec: skip=0, fetch=100
          BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
            SortExec: TopK(fetch=100), expr=[a@0 ASC], preserve_partitioning=[true]
              PlaceholderRowExec
        "#);
    }

    /// LocalLimitExec above a windowed sort should also push fetch into the SortExec.
    /// This is the case in distributed execution where GlobalLimitExec becomes LocalLimitExec.
    #[test]
    fn local_limit_pushes_past_window() {
        let plan = build_window_plan(true).unwrap();
        let optimized = optimize(plan);
        assert_snapshot!(plan_str(optimized.as_ref()), @r#"
        LocalLimitExec: fetch=100
          BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
            SortExec: TopK(fetch=100), expr=[a@0 ASC], preserve_partitioning=[true]
              PlaceholderRowExec
        "#);
    }

    /// An `n FOLLOWING` frame end needs `n` extra rows from below the window.
    #[test]
    fn following_frame_grows_pushed_fetch() {
        let window =
            build_window(WindowFrameBound::Following(ScalarValue::UInt64(Some(2)))).unwrap();
        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(GlobalLimitExec::new(window, 0, Some(100)));
        assert_eq!(sort_fetch(&optimize(plan)), Some(102));
    }

    /// Rows skipped by OFFSET must still be produced, so they count toward the fetch.
    #[test]
    fn global_limit_skip_is_included_in_fetch() {
        let window = build_window(WindowFrameBound::CurrentRow).unwrap();
        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(GlobalLimitExec::new(window, 10, Some(100)));
        assert_eq!(sort_fetch(&optimize(plan)), Some(110));
    }

    /// Without any limit above the window, the sort must stay unbounded.
    #[test]
    fn window_without_limit_is_left_alone() {
        let plan = build_window(WindowFrameBound::CurrentRow).unwrap();
        assert_eq!(sort_fetch(&optimize(plan)), None);
    }

    /// With `enable_window_limits` off, the rule must not touch the plan.
    #[test]
    fn disabled_rule_is_a_no_op() {
        let plan = build_window_plan(false).unwrap();
        assert_eq!(sort_fetch(&optimize_with(plan, false)), None);
    }
}