datafusion_physical_optimizer/window_topn.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`WindowTopN`] optimizer rule for per-partition top-K window queries.
19//!
20//! Detects queries of the form:
21//!
22//! ```sql
23//! SELECT * FROM (
24//! SELECT *, ROW_NUMBER() OVER (PARTITION BY pk ORDER BY val) as rn
25//! FROM t
26//! ) WHERE rn <= K;
27//! ```
28//!
29//! And replaces the `FilterExec → BoundedWindowAggExec → SortExec` pipeline
30//! with `BoundedWindowAggExec → PartitionedTopKExec(fetch=K)`, removing both
31//! the `FilterExec` and `SortExec`.
32//!
33//! See [`PartitionedTopKExec`]
34//! for details on the replacement operator.
35
36use std::sync::Arc;
37
38use crate::PhysicalOptimizerRule;
39use arrow::datatypes::DataType;
40use datafusion_common::config::ConfigOptions;
41use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
42use datafusion_common::{Result, ScalarValue};
43use datafusion_expr::Operator;
44use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
45use datafusion_physical_expr::window::StandardWindowExpr;
46use datafusion_physical_plan::ExecutionPlan;
47use datafusion_physical_plan::filter::FilterExec;
48use datafusion_physical_plan::projection::ProjectionExec;
49use datafusion_physical_plan::sorts::partitioned_topk::PartitionedTopKExec;
50use datafusion_physical_plan::sorts::sort::SortExec;
51use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowUDFExpr};
52
/// Physical optimizer rule that converts per-partition `ROW_NUMBER` top-K
/// queries into a more efficient plan using [`PartitionedTopKExec`].
///
/// The rule carries no state (it is a unit struct); whether it runs is
/// decided by the optimizer configuration passed to `optimize`.
///
/// # Pattern Detected
///
/// ```text
/// FilterExec(rn <= K)
///   [optional ProjectionExec]
///     BoundedWindowAggExec(ROW_NUMBER PARTITION BY ... ORDER BY ...)
///       SortExec(partition_keys, order_keys)
/// ```
///
/// # Replacement
///
/// ```text
/// [optional ProjectionExec]
///   BoundedWindowAggExec(ROW_NUMBER PARTITION BY ... ORDER BY ...)
///     PartitionedTopKExec(partition_keys, order_keys, fetch=K)
/// ```
///
/// The `FilterExec` is removed entirely (all output rows have `rn ∈ {1..K}`).
/// The `SortExec` is replaced by `PartitionedTopKExec` which maintains a
/// per-partition top-K heap instead of sorting the entire dataset.
///
/// # Supported Predicates
///
/// - `rn <= K` → fetch = K
/// - `rn < K` → fetch = K - 1
/// - `K >= rn` (flipped) → fetch = K
/// - `K > rn` (flipped) → fetch = K - 1
///
/// # When the Rule Fires
///
/// All of the following must be true:
/// - Config flag `optimizer.enable_window_topn` is `true`
/// - The plan matches `FilterExec → [ProjectionExec] → BoundedWindowAggExec → SortExec`
/// - The `FilterExec` does not carry an embedded projection of its own
/// - The window function is `ROW_NUMBER` (not `RANK`, `DENSE_RANK`, etc.)
/// - `ROW_NUMBER` has a `PARTITION BY` clause (global top-K is already
///   handled by `SortExec` with `fetch`)
/// - The filter predicate compares the window output column to a positive
///   integer literal using `<=`, `<`, `>=`, or `>`
///
/// [`PartitionedTopKExec`]: datafusion_physical_plan::sorts::partitioned_topk::PartitionedTopKExec
#[derive(Default, Clone, Debug)]
pub struct WindowTopN;
98
99impl WindowTopN {
100 pub fn new() -> Self {
101 Self
102 }
103
104 /// Attempt to transform a single plan node.
105 ///
106 /// Returns `Some(new_plan)` if the node matches the
107 /// `FilterExec → [ProjectionExec] → BoundedWindowAggExec → SortExec`
108 /// pattern and can be rewritten, or `None` if the node should be
109 /// left unchanged.
110 fn try_transform(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
111 // Step 1: Match FilterExec at the top
112 let filter = plan.downcast_ref::<FilterExec>()?;
113
114 // Don't handle filters with projections
115 if filter.projection().is_some() {
116 return None;
117 }
118
119 // Step 2: Extract limit from predicate (rn <= K, rn < K, etc.)
120 let (col_idx, limit_n) = extract_window_limit(filter.predicate())?;
121
122 // Step 3: Walk through optional ProjectionExec to find BoundedWindowAggExec
123 let child = filter.input();
124 let (window_exec, proj_between) = find_window_below(child)?;
125
126 // Step 4: Verify col_idx references a ROW_NUMBER window output column
127 let input_field_count = window_exec.input().schema().fields().len();
128 if col_idx < input_field_count {
129 return None; // Filter is on an input column, not a window column
130 }
131 let window_expr_idx = col_idx - input_field_count;
132 let window_exprs = window_exec.window_expr();
133 if window_expr_idx >= window_exprs.len() {
134 return None;
135 }
136 if !is_row_number(&window_exprs[window_expr_idx]) {
137 return None;
138 }
139
140 // Step 5: Verify child of window is SortExec
141 let sort_exec = window_exec.input().downcast_ref::<SortExec>()?;
142 let sort_child = sort_exec.input();
143
144 // Step 6: Determine partition_prefix_len from the window expression
145 let partition_by = window_exprs[window_expr_idx].partition_by();
146 let partition_prefix_len = partition_by.len();
147
148 // Without PARTITION BY, this is just a global top-K which
149 // SortExec with fetch already handles efficiently.
150 if partition_prefix_len == 0 {
151 return None;
152 }
153
154 // Step 7: Build PartitionedTopKExec using SortExec's expressions
155 let partitioned_topk = PartitionedTopKExec::try_new(
156 Arc::clone(sort_child),
157 sort_exec.expr().clone(),
158 partition_prefix_len,
159 limit_n,
160 )
161 .ok()?;
162
163 // Step 8: Rebuild window with new child
164 let new_window = Arc::clone(&child_as_arc(window_exec))
165 .with_new_children(vec![Arc::new(partitioned_topk)])
166 .ok()?;
167
168 // Step 9: If ProjectionExec was between Filter and Window, rebuild it
169 let result = match proj_between {
170 Some(proj) => Arc::clone(&child_as_arc(proj))
171 .with_new_children(vec![new_window])
172 .ok()?,
173 None => new_window,
174 };
175
176 Some(result)
177 }
178}
179
180/// Helper to get an `Arc<dyn ExecutionPlan>` from a reference.
181/// We need this because `with_new_children` takes `Arc<Self>`.
182fn child_as_arc<T: ExecutionPlan + Clone + 'static>(plan: &T) -> Arc<dyn ExecutionPlan> {
183 Arc::new(plan.clone())
184}
185
186impl PhysicalOptimizerRule for WindowTopN {
187 fn optimize(
188 &self,
189 plan: Arc<dyn ExecutionPlan>,
190 config: &ConfigOptions,
191 ) -> Result<Arc<dyn ExecutionPlan>> {
192 if !config.optimizer.enable_window_topn {
193 return Ok(plan);
194 }
195
196 plan.transform_down(|node| {
197 Ok(
198 if let Some(transformed) = WindowTopN::try_transform(&node) {
199 Transformed::yes(transformed)
200 } else {
201 Transformed::no(node)
202 },
203 )
204 })
205 .data()
206 }
207
208 fn name(&self) -> &str {
209 "WindowTopN"
210 }
211
212 fn schema_check(&self) -> bool {
213 true
214 }
215}
216
217/// Extract a window limit from a predicate expression.
218///
219/// Returns `(column_index, fetch)` if the predicate constrains a column
220/// to at most N rows.
221///
222/// # Supported Patterns
223///
224/// | Predicate | Returns |
225/// |-----------|---------|
226/// | `Column(idx) <= Literal(N)` | `(idx, N)` |
227/// | `Column(idx) < Literal(N)` | `(idx, N-1)` |
228/// | `Literal(N) >= Column(idx)` | `(idx, N)` |
229/// | `Literal(N) > Column(idx)` | `(idx, N-1)` |
230///
231/// # Examples
232///
233/// - `rn <= 5` → `Some((2, 5))` (assuming rn is column index 2)
234/// - `rn < 3` → `Some((2, 2))`
235/// - `10 >= rn` → `Some((2, 10))`
236/// - `rn = 1` → `None` (equality not supported)
237/// - `val <= 5` → `Some((1, 5))` (caller must verify it's a window column)
238fn extract_window_limit(
239 predicate: &Arc<dyn datafusion_physical_expr::PhysicalExpr>,
240) -> Option<(usize, usize)> {
241 let binary = predicate.downcast_ref::<BinaryExpr>()?;
242 let op = binary.op();
243 let left = binary.left();
244 let right = binary.right();
245
246 // Try Column op Literal
247 if let (Some(col), Some(lit_val)) = (
248 left.downcast_ref::<Column>(),
249 right.downcast_ref::<Literal>(),
250 ) {
251 let n = scalar_to_usize(lit_val.value())?;
252 return match *op {
253 Operator::LtEq => Some((col.index(), n)),
254 Operator::Lt => Some((col.index(), n - 1)),
255 _ => None,
256 };
257 }
258
259 // Try Literal op Column (flipped)
260 if let (Some(lit_val), Some(col)) = (
261 left.downcast_ref::<Literal>(),
262 right.downcast_ref::<Column>(),
263 ) {
264 let n = scalar_to_usize(lit_val.value())?;
265 return match *op {
266 Operator::GtEq => Some((col.index(), n)),
267 Operator::Gt => Some((col.index(), n - 1)),
268 _ => None,
269 };
270 }
271
272 None
273}
274
275/// Convert a [`ScalarValue`] to `usize` if it's a positive integer.
276///
277/// Returns `None` for null values, zero, negative integers, and
278/// non-integer types (floats, strings, decimals, etc.).
279fn scalar_to_usize(value: &ScalarValue) -> Option<usize> {
280 if !value.data_type().is_integer() {
281 return None;
282 }
283 let casted = value.cast_to(&DataType::UInt64).ok()?;
284 match casted {
285 ScalarValue::UInt64(Some(v)) if v > 0 => usize::try_from(v).ok(),
286 _ => None,
287 }
288}
289
290/// Check if a window expression is `ROW_NUMBER`.
291///
292/// Downcasts through `StandardWindowExpr` → `WindowUDFExpr` and checks
293/// that the UDF name is `"row_number"`. Returns `false` for all other
294/// window functions (e.g., `RANK`, `DENSE_RANK`, `SUM`).
295fn is_row_number(expr: &Arc<dyn datafusion_physical_expr::window::WindowExpr>) -> bool {
296 let Some(swe) = expr.as_any().downcast_ref::<StandardWindowExpr>() else {
297 return false;
298 };
299 let swfe = swe.get_standard_func_expr();
300 let Some(udf) = swfe.as_any().downcast_ref::<WindowUDFExpr>() else {
301 return false;
302 };
303 udf.fun().name() == "row_number"
304}
305
306/// Walk below a plan node looking for a [`BoundedWindowAggExec`].
307///
308/// Handles two cases:
309/// - Direct child: `FilterExec → BoundedWindowAggExec`
310/// - With projection: `FilterExec → ProjectionExec → BoundedWindowAggExec`
311///
312/// Returns the window exec and an optional `ProjectionExec` in between,
313/// or `None` if no `BoundedWindowAggExec` is found within one or two levels.
314fn find_window_below(
315 plan: &Arc<dyn ExecutionPlan>,
316) -> Option<(&BoundedWindowAggExec, Option<&ProjectionExec>)> {
317 // Direct child is BoundedWindowAggExec
318 if let Some(window) = plan.downcast_ref::<BoundedWindowAggExec>() {
319 return Some((window, None));
320 }
321
322 // Child is ProjectionExec with BoundedWindowAggExec below
323 if let Some(proj) = plan.downcast_ref::<ProjectionExec>() {
324 let proj_child = proj.input();
325 if let Some(window) = proj_child.downcast_ref::<BoundedWindowAggExec>() {
326 return Some((window, Some(proj)));
327 }
328 }
329
330 None
331}