use std::fmt::Debug;
use std::sync::Arc;
use crate::PhysicalOptimizerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::error::Result;
use datafusion_common::stats::Precision;
use datafusion_common::tree_node::{Transformed, TreeNodeRecursion};
use datafusion_common::utils::combine_limit;
use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion_physical_plan::empty::EmptyExec;
use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
/// Physical optimizer rule that pushes limit requirements (skip / fetch) down
/// an execution plan.
///
/// The rule carries no configuration of its own: its `optimize` implementation
/// simply starts `pushdown_limits` at the plan root with empty requirements.
#[derive(Default, Debug)]
pub struct LimitPushdown {}
/// Limit requirements accumulated while `pushdown_limits` walks down a plan.
///
/// The state is updated by `pushdown_limit_helper` at each node, and a clone
/// of it is handed to every child of that node during the traversal.
#[derive(Default, Clone, Debug)]
pub struct GlobalRequirements {
/// Number of rows still to be fetched, or `None` when no fetch is pending.
fetch: Option<usize>,
/// Number of leading rows still to be skipped.
skip: usize,
/// Whether the pending requirements are already enforced (by a limit node,
/// by a node-level fetch, or by the input's known row count). When `true`,
/// `pushdown_limit_helper` avoids inserting another limit node.
satisfied: bool,
/// Whether the limit node these requirements came from reported a required
/// ordering; forwarded to `with_preserve_order` when a fetch is pushed into
/// a node.
preserve_order: bool,
}
impl LimitPushdown {
    /// Creates a new [`LimitPushdown`] rule.
    ///
    /// The rule is stateless, so this is equivalent to
    /// `LimitPushdown::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
impl PhysicalOptimizerRule for LimitPushdown {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let global_state = GlobalRequirements {
fetch: None,
skip: 0,
satisfied: false,
preserve_order: false,
};
pushdown_limits(plan, global_state)
}
fn name(&self) -> &str {
"LimitPushdown"
}
fn schema_check(&self) -> bool {
true
}
}
/// Details pulled out of a `GlobalLimitExec` or `LocalLimitExec` node by
/// `extract_limit`.
struct LimitInfo {
/// Input (child) plan of the limit node.
input: Arc<dyn ExecutionPlan>,
/// Fetch of the limit node; always `Some` when built from a local limit.
fetch: Option<usize>,
/// Rows skipped by the limit node; `0` when built from a local limit.
skip: usize,
/// `true` when the limit node reports a required ordering.
preserve_order: bool,
}
/// Performs one step of limit pushdown on `pushdown_plan`, threading the
/// accumulated requirements through `global_state`.
///
/// Depending on the node, the step does one of the following:
///
/// - **Limit node** (`GlobalLimitExec` / `LocalLimitExec`): its skip/fetch are
///   merged into `global_state` with `combine_limit`, the node is removed and
///   its input is returned with [`TreeNodeRecursion::Stop`], which makes
///   `pushdown_limits` call this helper again on that input.
/// - **No pending fetch**: a skip-only `GlobalLimitExec` is added if a skip is
///   pending and not yet satisfied; otherwise the plan is returned unchanged.
/// - **Node supports limit pushdown**: the requirements keep flowing down,
///   except that nodes merging partitions receive the fetch themselves (or a
///   limit node on top when they cannot carry one).
/// - **Node does not support limit pushdown**: the requirements are enforced
///   here (node-level fetch and/or a limit node) and reset for the subtree.
///
/// Returns the (possibly rewritten) plan together with the updated state.
///
/// # Errors
///
/// Propagates any error raised while inspecting the limit's input via
/// `limit_satisfied_by_input`.
pub fn pushdown_limit_helper(
    mut pushdown_plan: Arc<dyn ExecutionPlan>,
    mut global_state: GlobalRequirements,
) -> Result<(Transformed<Arc<dyn ExecutionPlan>>, GlobalRequirements)> {
    // A limit node: absorb it into the global state and drop it from the
    // plan. Whether a limit has to be re-added is decided on later steps.
    if let Some(limit_info) = extract_limit(&pushdown_plan) {
        let (skip, fetch) = combine_limit(
            global_state.skip,
            global_state.fetch,
            limit_info.skip,
            limit_info.fetch,
        );
        global_state.skip = skip;
        global_state.fetch = fetch;
        global_state.preserve_order = limit_info.preserve_order;
        // The limit is already satisfied only when a fetch exists and the
        // input is known to produce no more rows than that fetch.
        global_state.satisfied = match fetch {
            Some(fetch) => limit_satisfied_by_input(&limit_info.input, skip, fetch)?,
            None => false,
        };
        return Ok((
            Transformed {
                data: limit_info.input,
                transformed: true,
                tnr: TreeNodeRecursion::Stop,
            },
            global_state,
        ));
    }

    // A non-limit node that already carries a fetch: fold that fetch into the
    // global state. Without a pending skip the node itself enforces the limit.
    if pushdown_plan.fetch().is_some() {
        if global_state.skip == 0 {
            global_state.satisfied = true;
        }
        (global_state.skip, global_state.fetch) = combine_limit(
            global_state.skip,
            global_state.fetch,
            0,
            pushdown_plan.fetch(),
        );
    }

    let Some(global_fetch) = global_state.fetch else {
        // No fetch is pending.
        return if global_state.skip > 0 && !global_state.satisfied {
            // Only a skip is pending: enforce it with a fetch-less global limit.
            global_state.satisfied = true;
            Ok((
                Transformed::yes(add_global_limit(
                    pushdown_plan,
                    global_state.skip,
                    None,
                )),
                global_state,
            ))
        } else {
            // Neither skip nor fetch to enforce: nothing to do.
            Ok((Transformed::no(pushdown_plan), global_state))
        };
    };

    // Rows a node below a skip must produce. Saturate instead of overflowing:
    // a wrapped sum would turn a huge limit into a tiny one.
    let skip_and_fetch = Some(global_fetch.saturating_add(global_state.skip));

    if pushdown_plan.supports_limit_pushdown() {
        if !combines_input_partitions(&pushdown_plan) {
            // The requirements pass through this node untouched.
            Ok((Transformed::no(pushdown_plan), global_state))
        } else if let Some(plan_with_fetch) = pushdown_plan.with_fetch(skip_and_fetch) {
            // A partition-combining node that can carry the fetch itself.
            let mut new_plan = plan_with_fetch;
            // The node-level fetch does not cover the skip, so a pending skip
            // still needs a global limit on top.
            if global_state.skip > 0 {
                new_plan =
                    add_global_limit(new_plan, global_state.skip, global_state.fetch);
            }
            global_state.fetch = skip_and_fetch;
            global_state.skip = 0;
            global_state.satisfied = true;
            Ok((Transformed::yes(new_plan), global_state))
        } else if global_state.satisfied {
            // Already enforced above: do not add another limit.
            Ok((Transformed::no(pushdown_plan), global_state))
        } else {
            global_state.satisfied = true;
            Ok((
                Transformed::yes(add_limit(
                    pushdown_plan,
                    global_state.skip,
                    global_fetch,
                )),
                global_state,
            ))
        }
    } else {
        // The node neither is a limit nor lets one pass through, so the
        // requirements must be enforced here and reset for the subtree below.
        let global_skip = global_state.skip;
        global_state.fetch = None;
        global_state.skip = 0;

        let maybe_fetchable = pushdown_plan.with_fetch(skip_and_fetch);
        if global_state.satisfied {
            // Already enforced above: at most attach the fetch to the node.
            if let Some(plan_with_fetch) = maybe_fetchable {
                let plan_with_preserve_order = plan_with_fetch
                    .with_preserve_order(global_state.preserve_order)
                    .unwrap_or(plan_with_fetch);
                Ok((Transformed::yes(plan_with_preserve_order), global_state))
            } else {
                Ok((Transformed::no(pushdown_plan), global_state))
            }
        } else {
            global_state.satisfied = true;
            pushdown_plan = if let Some(plan_with_fetch) = maybe_fetchable {
                let plan_with_preserve_order = plan_with_fetch
                    .with_preserve_order(global_state.preserve_order)
                    .unwrap_or(plan_with_fetch);

                if global_skip > 0 {
                    // The node-level fetch does not cover the skip.
                    add_global_limit(
                        plan_with_preserve_order,
                        global_skip,
                        Some(global_fetch),
                    )
                } else {
                    plan_with_preserve_order
                }
            } else {
                add_limit(pushdown_plan, global_skip, global_fetch)
            };
            Ok((Transformed::yes(pushdown_plan), global_state))
        }
    }
}
/// Reports whether a limit with the given `skip` / `fetch` is already
/// guaranteed by `plan` itself.
///
/// Returns `Ok(true)` only when all of the following hold:
///
/// - `skip` is zero,
/// - `plan` has exactly one output partition, and
/// - `limit_eliminable_exact_num_rows` yields a row count that does not
///   exceed `fetch`.
///
/// # Errors
///
/// Propagates any error from `limit_eliminable_exact_num_rows`.
fn limit_satisfied_by_input(
    plan: &Arc<dyn ExecutionPlan>,
    skip: usize,
    fetch: usize,
) -> Result<bool> {
    // `&&` short-circuits, so the partitioning is only consulted when there
    // is no skip.
    let eligible = skip == 0 && plan.output_partitioning().partition_count() == 1;
    if !eligible {
        return Ok(false);
    }

    let known_rows = limit_eliminable_exact_num_rows(plan)?;
    Ok(known_rows.is_some_and(|rows| rows <= fetch))
}
/// Returns the exact number of rows `plan` produces, for the few cases this
/// rule recognizes, or `None` otherwise.
///
/// Any chain of `ProjectionExec` nodes on top is looked through first. The
/// node found underneath then yields:
///
/// - `Some(0)` for an `EmptyExec`,
/// - `Some(1)` for a `PlaceholderRowExec`,
/// - `Some(0)` when its statistics report exactly zero rows,
/// - `None` in every other case.
///
/// # Errors
///
/// Propagates any error from `partition_statistics`.
fn limit_eliminable_exact_num_rows(
    plan: &Arc<dyn ExecutionPlan>,
) -> Result<Option<usize>> {
    // Descend through stacked projections to the first non-projection node.
    let mut node = plan;
    while let Some(projection) = node.downcast_ref::<ProjectionExec>() {
        node = projection.input();
    }

    let exact_rows = if node.is::<EmptyExec>() {
        Some(0)
    } else if node.is::<PlaceholderRowExec>() {
        Some(1)
    } else {
        // Statistics are only consulted for the remaining node kinds, and only
        // an exact zero is accepted.
        match node.partition_statistics(None)?.num_rows {
            Precision::Exact(0) => Some(0),
            _ => None,
        }
    };
    Ok(exact_rows)
}
/// Recursively pushes limits down the plan rooted at `pushdown_plan`.
///
/// `pushdown_limit_helper` is applied to the current node, and re-applied for
/// as long as it answers with [`TreeNodeRecursion::Stop`]. Every child of the
/// resulting node is then rewritten with a clone of the updated state. The
/// node is rebuilt with `with_new_children` only if at least one child came
/// back as a different `Arc`.
///
/// # Errors
///
/// Propagates errors from `pushdown_limit_helper`, from the recursive calls
/// and from `with_new_children`.
pub(crate) fn pushdown_limits(
    pushdown_plan: Arc<dyn ExecutionPlan>,
    global_state: GlobalRequirements,
) -> Result<Arc<dyn ExecutionPlan>> {
    let (mut step, mut state) = pushdown_limit_helper(pushdown_plan, global_state)?;
    while step.tnr == TreeNodeRecursion::Stop {
        (step, state) = pushdown_limit_helper(step.data, state)?;
    }

    // Once the requirements are enforced at this level, no skip is passed on
    // to the children.
    if state.satisfied {
        state.skip = 0;
    }

    let plan = step.data;
    let children = plan.children();
    let mut rewritten_children = Vec::with_capacity(children.len());
    let mut any_child_changed = false;
    for child in children {
        let rewritten = pushdown_limits(Arc::clone(child), state.clone())?;
        any_child_changed |= !Arc::ptr_eq(child, &rewritten);
        rewritten_children.push(rewritten);
    }

    if any_child_changed {
        plan.with_new_children(rewritten_children)
    } else {
        Ok(plan)
    }
}
/// Extracts the limit details from `plan` if it is a `GlobalLimitExec` or a
/// `LocalLimitExec`; returns `None` for any other node.
///
/// A local limit is reported with a skip of `0` and its fetch wrapped in
/// `Some`.
fn extract_limit(plan: &Arc<dyn ExecutionPlan>) -> Option<LimitInfo> {
    if let Some(global) = plan.downcast_ref::<GlobalLimitExec>() {
        return Some(LimitInfo {
            input: Arc::clone(global.input()),
            fetch: global.fetch(),
            skip: global.skip(),
            preserve_order: global.required_ordering().is_some(),
        });
    }

    // Not a global limit: either it is a local limit or there is nothing to
    // extract.
    let local = plan.downcast_ref::<LocalLimitExec>()?;
    Some(LimitInfo {
        input: Arc::clone(local.input()),
        fetch: Some(local.fetch()),
        skip: 0,
        preserve_order: local.required_ordering().is_some(),
    })
}
/// Returns `true` when `plan` is a `CoalescePartitionsExec` or a
/// `SortPreservingMergeExec`, the two node kinds this rule treats as merging
/// their input partitions.
fn combines_input_partitions(plan: &Arc<dyn ExecutionPlan>) -> bool {
    if plan.is::<CoalescePartitionsExec>() {
        return true;
    }
    plan.is::<SortPreservingMergeExec>()
}
/// Wraps `pushdown_plan` in the limit node appropriate for `skip` / `fetch`.
///
/// A `GlobalLimitExec` is used when there is a skip or when the plan has a
/// single output partition. Otherwise a `LocalLimitExec` with a fetch of
/// `fetch + skip` is used.
fn add_limit(
    pushdown_plan: Arc<dyn ExecutionPlan>,
    skip: usize,
    fetch: usize,
) -> Arc<dyn ExecutionPlan> {
    // `&&` short-circuits, so the partitioning is only consulted when there
    // is no skip.
    let local_limit_suffices =
        skip == 0 && pushdown_plan.output_partitioning().partition_count() != 1;
    if local_limit_suffices {
        return Arc::new(LocalLimitExec::new(pushdown_plan, fetch + skip));
    }
    add_global_limit(pushdown_plan, skip, Some(fetch))
}
/// Wraps `pushdown_plan` in a `GlobalLimitExec` built from `skip` and `fetch`.
fn add_global_limit(
    pushdown_plan: Arc<dyn ExecutionPlan>,
    skip: usize,
    fetch: Option<usize>,
) -> Arc<dyn ExecutionPlan> {
    let limit: Arc<dyn ExecutionPlan> =
        Arc::new(GlobalLimitExec::new(pushdown_plan, skip, fetch));
    limit
}