use crate::PhysicalOptimizerRule;
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_expr::{LimitEffect, WindowFrameBound, WindowFrameUnits};
use datafusion_physical_expr::window::{
PlainAggregateWindowExpr, SlidingAggregateWindowExpr, StandardWindowExpr,
StandardWindowFunctionExpr, WindowExpr,
};
use datafusion_physical_plan::execution_plan::CardinalityEffect;
use datafusion_physical_plan::limit::GlobalLimitExec;
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowUDFExpr};
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use std::cmp;
use std::sync::Arc;
/// Physical optimizer rule that carries a row limit found above a
/// `BoundedWindowAggExec` down into the sort operator feeding that window,
/// enlarged by however many extra rows the window expressions need to see.
#[derive(Default, Clone, Debug)]
pub struct LimitPushPastWindows;

impl LimitPushPastWindows {
    /// Creates the rule; identical to `LimitPushPastWindows::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
/// Stage of the top-down traversal relative to window operators.
///
/// The traversal starts in `FindOrGrow` and moves to `Apply` as soon as a
/// `BoundedWindowAggExec` is visited; nothing in this rule moves it back.
#[derive(PartialEq, Eq)]
enum Phase {
    /// Limits may still be picked up from limit-providing nodes.
    FindOrGrow,
    /// A window has been passed; sort nodes below receive the tracked limit.
    Apply,
}
/// Mutable bookkeeping threaded through the plan traversal.
#[derive(Default)]
struct TraverseState {
    /// Row limit currently being pushed down; `None` means no limit is tracked.
    pub limit: Option<usize>,
    /// Extra rows beyond `limit` that window expressions need to observe.
    pub lookahead: usize,
}

impl TraverseState {
    /// Begins tracking `limit` and discards any lookahead gathered so far.
    pub fn reset_limit(&mut self, limit: Option<usize>) {
        *self = Self {
            limit,
            lookahead: 0,
        };
    }

    /// Raises the lookahead to `new_val` when it exceeds the current value.
    pub fn max_lookahead(&mut self, new_val: usize) {
        if new_val > self.lookahead {
            self.lookahead = new_val;
        }
    }
}
impl PhysicalOptimizerRule for LimitPushPastWindows {
    /// Rewrites `original` so that sorts feeding a window receive a `fetch`.
    ///
    /// The plan is walked top-down. Until the first `BoundedWindowAggExec` is
    /// visited, a limit may be picked up from a `GlobalLimitExec` or a
    /// `SortPreservingMergeExec` (`get_limit`). Each window widens it
    /// (`grow_limit`). After a window, the next `SortExec` /
    /// `SortPreservingMergeExec` receives it (`apply_limit`).
    ///
    /// Tracking is cleared whenever a node:
    /// - has more than one child,
    /// - does not support limit pushdown,
    /// - is a `RepartitionExec` with fewer output than input partitions, or
    /// - has an `Unknown`/`LowerEqual` cardinality effect.
    ///
    /// The plan is returned untouched when `optimizer.enable_window_limits` is
    /// disabled.
    fn optimize(
        &self,
        original: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        if !config.optimizer.enable_window_limits {
            return Ok(original);
        }

        let mut state = TraverseState::default();
        let mut phase = Phase::FindOrGrow;

        let rewritten = original.transform_down(|node| {
            // Forget whatever limit is being tracked and keep `node` unchanged.
            let give_up = |node: Arc<dyn ExecutionPlan>,
                           tracker: &mut TraverseState|
             -> datafusion_common::Result<Transformed<Arc<dyn ExecutionPlan>>> {
                tracker.reset_limit(None);
                Ok(Transformed::no(node))
            };

            // Nodes with several inputs (e.g. joins) end tracking.
            if node.children().len() > 1 {
                return give_up(node, &mut state);
            }

            // Before any window: start tracking a limit if this node provides one.
            if phase == Phase::FindOrGrow && get_limit(&node, &mut state) {
                return Ok(Transformed::no(node));
            }

            // At a window: widen the limit for what the window needs, or abort.
            if let Some(window) = node.as_any().downcast_ref::<BoundedWindowAggExec>() {
                phase = Phase::Apply;
                return if grow_limit(window, &mut state) {
                    Ok(Transformed::no(node))
                } else {
                    give_up(node, &mut state)
                };
            }

            // After a window: a sort / merge node receives the accumulated fetch.
            let applied = if phase == Phase::Apply {
                apply_limit(&node, &mut state)
            } else {
                None
            };
            if let Some(out) = applied {
                return Ok(out);
            }

            if !node.supports_limit_pushdown() {
                return give_up(node, &mut state);
            }

            // A repartition that reduces the partition count ends tracking.
            let shrinks_partitions = node
                .as_any()
                .downcast_ref::<RepartitionExec>()
                .is_some_and(|repart| {
                    repart.partitioning().partition_count()
                        < repart.input().output_partitioning().partition_count()
                });
            if shrinks_partitions {
                return give_up(node, &mut state);
            }

            // Only nodes that keep or add rows let the limit pass through.
            let effect = node.cardinality_effect();
            match effect {
                CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => {
                    give_up(node, &mut state)
                }
                CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {
                    Ok(Transformed::no(node))
                }
            }
        })?;

        Ok(rewritten.data)
    }

    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "LimitPushPastWindows"
    }

    /// Returns `false`: this rule does not request the post-rule schema check.
    fn schema_check(&self) -> bool {
        false
    }
}
/// Widens the tracked limit so that `window` still sees every row it needs.
///
/// For each window expression of `window`:
/// * its `LimitEffect` is folded in. `Relative(n)` contributes up to `n`
///   rows of extra lookahead (the maximum over all expressions is used).
///   `Absolute(n)` raises an already-tracked limit to at least `n`.
/// * its frame must be `ROWS`-based with an end bound accepted by
///   `bound_to_usize`; that end offset becomes a lookahead candidate.
///
/// After all expressions, the lookahead grows by the largest relative effect.
///
/// Returns `false` when the window cannot be reasoned about: an `Unknown`
/// limit effect, a non-`ROWS` frame, or an unsupported end bound. The caller
/// in this file then resets the traversal state.
fn grow_limit(window: &BoundedWindowAggExec, ctx: &mut TraverseState) -> bool {
    let mut max_rel = 0;
    for expr in window.window_expr().iter() {
        match get_limit_effect(expr) {
            LimitEffect::None => {}
            LimitEffect::Unknown => return false,
            LimitEffect::Relative(rel) => max_rel = max_rel.max(rel),
            LimitEffect::Absolute(val) => {
                // `None` means no limit is in force, so there is nothing to
                // raise. Turning it into `Some(val)` would push a fetch into
                // the sort below and drop rows the query must still return.
                ctx.limit = ctx.limit.map(|cur| cur.max(val));
            }
        }

        let frame = expr.get_window_frame();
        if frame.units != WindowFrameUnits::Rows {
            return false;
        }
        let Some(end_bound) = bound_to_usize(&frame.end_bound) else {
            return false;
        };
        ctx.max_lookahead(end_bound);
    }

    // Saturate so very large frame offsets cannot wrap to a tiny lookahead.
    ctx.max_lookahead(ctx.lookahead.saturating_add(max_rel));
    true
}
/// Pushes the tracked limit (plus lookahead) into `node` if it is a sort.
///
/// Returns `None`, leaving `ctx` untouched, when `node` is neither a `SortExec`
/// nor a `SortPreservingMergeExec`. Otherwise the tracked limit is consumed:
/// * with no limit tracked, `ctx` is cleared and `node` is returned unchanged;
/// * with a limit, the node's fetch becomes `limit + lookahead`, or its
///   existing fetch if that is smaller.
fn apply_limit(
    node: &Arc<dyn ExecutionPlan>,
    ctx: &mut TraverseState,
) -> Option<Transformed<Arc<dyn ExecutionPlan>>> {
    if !node.as_any().is::<SortExec>() && !node.as_any().is::<SortPreservingMergeExec>() {
        return None;
    }
    let Some(limit) = ctx.limit.take() else {
        ctx.reset_limit(None);
        return Some(Transformed::no(Arc::clone(node)));
    };
    // Saturate: `limit + lookahead` can exceed `usize` (huge LIMIT/OFFSET or
    // frame offsets, or a 32-bit target). A wrapped sum would be a tiny fetch
    // that silently drops rows.
    let wanted = limit.saturating_add(ctx.lookahead);
    let fetch = match node.fetch() {
        None => wanted,
        Some(existing) => cmp::min(existing, wanted),
    };
    // NOTE(review): `with_fetch` presumably always succeeds for these two sort
    // operators. If it ever returns `None`, keep the node as-is rather than
    // panicking inside the optimizer.
    let updated = node
        .with_fetch(Some(fetch))
        .map_or_else(|| Transformed::no(Arc::clone(node)), Transformed::complete);
    Some(updated)
}
/// Starts tracking a new limit if `node` is a limit-providing operator.
///
/// * `GlobalLimitExec`: `skip + fetch` rows must survive. A missing fetch
///   means no limit is tracked.
/// * `SortPreservingMergeExec`: its own fetch, if any.
///
/// In both cases any previous lookahead is discarded and `true` is returned.
/// For every other node, `ctx` is left alone and `false` is returned.
fn get_limit(node: &Arc<dyn ExecutionPlan>, ctx: &mut TraverseState) -> bool {
    let any = node.as_any();
    if let Some(limit) = any.downcast_ref::<GlobalLimitExec>() {
        // Saturate: `fetch + skip` can overflow for very large LIMIT/OFFSET,
        // and a wrapped sum would cut the result short.
        ctx.reset_limit(limit.fetch().map(|fetch| fetch.saturating_add(limit.skip())));
        return true;
    }
    if let Some(merge) = any.downcast_ref::<SortPreservingMergeExec>() {
        ctx.reset_limit(merge.fetch());
        return true;
    }
    false
}
/// Reports how a single window expression constrains the pushed-down limit.
///
/// * Plain and sliding aggregate window expressions report `LimitEffect::None`.
/// * Standard window expressions backed by a window UDF defer to that UDF's
///   `limit_effect`.
/// * Every other expression is `LimitEffect::Unknown`.
fn get_limit_effect(expr: &Arc<dyn WindowExpr>) -> LimitEffect {
    let any = expr.as_any();
    let is_aggregate =
        any.is::<PlainAggregateWindowExpr>() || any.is::<SlidingAggregateWindowExpr>();
    if is_aggregate {
        return LimitEffect::None;
    }
    any.downcast_ref::<StandardWindowExpr>()
        .and_then(|standard| {
            standard
                .get_standard_func_expr()
                .as_any()
                .downcast_ref::<WindowUDFExpr>()
        })
        .map_or(LimitEffect::Unknown, |udf| udf.limit_effect())
}
/// Number of rows after the current row that a frame end bound reaches.
///
/// * `Preceding(_)` and `CurrentRow` ends need no following rows: `Some(0)`.
/// * `Following` with a non-null `UInt64` offset `n` yields `Some(n)`, but
///   only if `n` fits in `usize`.
/// * Any other bound returns `None`, meaning the caller cannot optimize.
fn bound_to_usize(bound: &WindowFrameBound) -> Option<usize> {
    match bound {
        WindowFrameBound::Preceding(_) | WindowFrameBound::CurrentRow => Some(0),
        // `try_from` instead of `as`: on 32-bit targets a truncated offset
        // would under-count the lookahead and make the sort drop needed rows.
        WindowFrameBound::Following(ScalarValue::UInt64(Some(offset))) => {
            usize::try_from(*offset).ok()
        }
        _ => None,
    }
}