use std::fmt::Debug;
use std::sync::Arc;
use crate::PhysicalOptimizerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::error::Result;
use datafusion_common::tree_node::{Transformed, TreeNodeRecursion};
use datafusion_common::utils::combine_limit;
use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
/// Physical optimizer rule that pushes limit (skip/fetch) requirements down
/// the execution plan.
///
/// The rule itself is stateless: all bookkeeping for a single optimization
/// pass lives in a [`GlobalRequirements`] value created by `optimize`.
#[derive(Debug, Default)]
pub struct LimitPushdown {}
/// Limit information carried down the plan while the pushdown rule walks it.
///
/// A fresh pass starts with no fetch, no skip and `satisfied == false`
/// (which is also what [`Default`] produces). The value is cloned once per
/// child when the traversal fans out.
#[derive(Debug, Clone, Default)]
pub struct GlobalRequirements {
    /// Pending maximum number of rows, or `None` when no fetch is pending.
    fetch: Option<usize>,
    /// Pending number of leading rows to skip.
    skip: usize,
    /// Set once the pending requirement has been accounted for in the plan;
    /// the helper checks it before deciding whether to add another limit.
    satisfied: bool,
}
impl LimitPushdown {
    /// Creates a new `LimitPushdown` optimizer rule.
    ///
    /// The rule has no fields to configure, so the returned value is the same
    /// as the one produced by `LimitPushdown::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
impl PhysicalOptimizerRule for LimitPushdown {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let global_state = GlobalRequirements {
fetch: None,
skip: 0,
satisfied: false,
};
pushdown_limits(plan, global_state)
}
fn name(&self) -> &str {
"LimitPushdown"
}
fn schema_check(&self) -> bool {
true
}
}
/// Uniform wrapper around the two limit operators handled by the pushdown
/// rule, so that their input, fetch and skip can be read through one type.
#[derive(Debug)]
pub enum LimitExec {
/// A `GlobalLimitExec`, which carries both a skip and an optional fetch.
Global(GlobalLimitExec),
/// A `LocalLimitExec`, which carries only a fetch (its skip is reported as 0).
Local(LocalLimitExec),
}
impl LimitExec {
    /// Returns the child plan of the wrapped limit operator.
    fn input(&self) -> &Arc<dyn ExecutionPlan> {
        match self {
            Self::Local(local_limit) => local_limit.input(),
            Self::Global(global_limit) => global_limit.input(),
        }
    }

    /// Returns the fetch of the wrapped operator.
    ///
    /// A local limit always has a fetch, so it is wrapped in `Some`; a global
    /// limit reports whatever optional fetch it holds.
    fn fetch(&self) -> Option<usize> {
        match self {
            Self::Local(local_limit) => Some(local_limit.fetch()),
            Self::Global(global_limit) => global_limit.fetch(),
        }
    }

    /// Returns the skip of the wrapped operator; local limits report `0`.
    fn skip(&self) -> usize {
        if let Self::Global(global_limit) = self {
            global_limit.skip()
        } else {
            0
        }
    }
}
impl From<LimitExec> for Arc<dyn ExecutionPlan> {
    /// Unwraps the enum and hands back the contained limit operator as a
    /// shared, type-erased execution plan.
    fn from(limit_exec: LimitExec) -> Self {
        let plan: Arc<dyn ExecutionPlan> = match limit_exec {
            LimitExec::Local(local_limit) => Arc::new(local_limit),
            LimitExec::Global(global_limit) => Arc::new(global_limit),
        };
        plan
    }
}
/// Performs a single limit-pushdown step on `pushdown_plan`.
///
/// Returns the (possibly rewritten) plan wrapped in [`Transformed`] together
/// with the updated [`GlobalRequirements`] that should be used for whatever
/// is processed next.
///
/// The step does one of the following:
///
/// * If the plan is a `GlobalLimitExec` / `LocalLimitExec`, its skip/fetch is
///   merged into `global_state` with `combine_limit`, the limit node is
///   dropped and its input is returned with `TreeNodeRecursion::Stop`, which
///   `pushdown_limits` uses as the signal to call this helper again on the
///   returned node.
/// * If no fetch is pending, the plan is returned unchanged, unless an
///   unsatisfied skip is pending, in which case a skip-only global limit is
///   placed on top of it.
/// * If a fetch is pending, the plan either lets the requirement pass through
///   to its children, absorbs it through `with_fetch`, or gets a limit
///   operator placed on top of it, depending on `supports_limit_pushdown()`,
///   whether it combines input partitions, and the `satisfied` flag.
///
/// # Errors
///
/// The visible code paths only ever build `Ok` values; the `Result` return
/// type is part of the signature used by callers.
pub fn pushdown_limit_helper(
mut pushdown_plan: Arc<dyn ExecutionPlan>,
mut global_state: GlobalRequirements,
) -> Result<(Transformed<Arc<dyn ExecutionPlan>>, GlobalRequirements)> {
// Case 1: the node itself is a limit operator. Fold its skip/fetch into the
// running requirements and strip it from the plan; later steps decide
// whether (and where) a limit has to be re-inserted.
if let Some(limit_exec) = extract_limit(&pushdown_plan) {
let (skip, fetch) = combine_limit(
global_state.skip,
global_state.fetch,
limit_exec.skip(),
limit_exec.fetch(),
);
global_state.skip = skip;
global_state.fetch = fetch;
// `Stop` tells `pushdown_limits` to run the helper again on the input
// that replaced the removed limit node.
return Ok((
Transformed {
data: Arc::clone(limit_exec.input()),
transformed: true,
tnr: TreeNodeRecursion::Stop,
},
global_state,
));
}
// Case 2: a non-limit operator that already carries its own fetch. Merge
// that fetch (with a skip of 0) into the running requirements.
if pushdown_plan.fetch().is_some() {
// No fetch was pending from above, so the operator's own fetch is the
// only limit in play; flag the requirement as satisfied.
if global_state.fetch.is_none() {
global_state.satisfied = true;
}
(global_state.skip, global_state.fetch) = combine_limit(
global_state.skip,
global_state.fetch,
0,
pushdown_plan.fetch(),
);
}
// Without a pending fetch there is nothing to push down.
let Some(global_fetch) = global_state.fetch else {
return if global_state.skip > 0 && !global_state.satisfied {
// Only a skip is pending and it has not been handled yet: add a
// global limit with that skip and no fetch.
global_state.satisfied = true;
Ok((
Transformed::yes(add_global_limit(
pushdown_plan,
global_state.skip,
None,
)),
global_state,
))
} else {
// Neither fetch nor an unhandled skip: leave the plan untouched.
Ok((Transformed::no(pushdown_plan), global_state))
};
};
// Fetch value used when the requirement is handed to `with_fetch`: the
// pending skip plus the pending fetch.
let skip_and_fetch = Some(global_fetch + global_state.skip);
if pushdown_plan.supports_limit_pushdown() {
if !combines_input_partitions(&pushdown_plan) {
// The operator lets limits pass through and does not merge
// partitions: keep it as is and carry the state on to its children.
Ok((Transformed::no(pushdown_plan), global_state))
} else if let Some(plan_with_fetch) = pushdown_plan.with_fetch(skip_and_fetch) {
// Partition-combining operator that accepted the fetch.
let mut new_plan = plan_with_fetch;
// The fetch handed over above covers skip + fetch rows; the skip
// itself is applied by a global limit placed on top.
if global_state.skip > 0 {
new_plan =
add_global_limit(new_plan, global_state.skip, global_state.fetch);
}
// Below this point only "produce at most skip + fetch rows" is
// still pending, with no further skip.
global_state.fetch = skip_and_fetch;
global_state.skip = 0;
global_state.satisfied = true;
Ok((Transformed::yes(new_plan), global_state))
} else if global_state.satisfied {
// `with_fetch` produced nothing, but the requirement is already
// satisfied: do not add another limit.
Ok((Transformed::no(pushdown_plan), global_state))
} else {
// `with_fetch` produced nothing and the requirement is still open:
// put an explicit limit operator on top of the plan.
global_state.satisfied = true;
Ok((
Transformed::yes(add_limit(
pushdown_plan,
global_state.skip,
global_fetch,
)),
global_state,
))
}
} else {
// The operator does not report limit pushdown support. Remember the
// pending skip locally and clear skip/fetch in the state so that nothing
// is carried further down through this node.
let global_skip = global_state.skip;
global_state.fetch = None;
global_state.skip = 0;
let maybe_fetchable = pushdown_plan.with_fetch(skip_and_fetch);
if global_state.satisfied {
// Already satisfied: use the fetch-carrying version of the plan when
// one is available, otherwise keep the plan unchanged. No limit
// operator is added.
if let Some(plan_with_fetch) = maybe_fetchable {
Ok((Transformed::yes(plan_with_fetch), global_state))
} else {
Ok((Transformed::no(pushdown_plan), global_state))
}
} else {
// Mark the requirement satisfied only when no direct child is itself
// a limit operator; otherwise the flag stays false while that child
// is processed.
if !pushdown_plan
.children()
.iter()
.any(|&child| extract_limit(child).is_some())
{
global_state.satisfied = true;
}
// Enforce the requirement at this node: prefer the fetch-carrying
// plan (plus a global limit when a skip is pending), else add an
// explicit limit operator.
pushdown_plan = if let Some(plan_with_fetch) = maybe_fetchable {
if global_skip > 0 {
add_global_limit(plan_with_fetch, global_skip, Some(global_fetch))
} else {
plan_with_fetch
}
} else {
add_limit(pushdown_plan, global_skip, global_fetch)
};
Ok((Transformed::yes(pushdown_plan), global_state))
}
}
}
/// Recursively applies limit pushdown to `pushdown_plan` and its descendants.
///
/// The helper is invoked on the node, and invoked again for as long as it
/// answers with `TreeNodeRecursion::Stop` (which it does after stripping a
/// limit operator). Each child of the resulting node is then processed with
/// its own clone of the requirement state, and the node is rebuilt from the
/// rewritten children.
///
/// # Errors
///
/// Propagates the first error returned by the helper, by a recursive call, or
/// by `with_new_children`.
pub(crate) fn pushdown_limits(
    pushdown_plan: Arc<dyn ExecutionPlan>,
    global_state: GlobalRequirements,
) -> Result<Arc<dyn ExecutionPlan>> {
    let (mut step, mut requirements) = pushdown_limit_helper(pushdown_plan, global_state)?;

    // Keep folding for as long as the helper reports that it removed a limit.
    while matches!(step.tnr, TreeNodeRecursion::Stop) {
        (step, requirements) = pushdown_limit_helper(step.data, requirements)?;
    }

    let rewritten = step.data;
    let current_children = rewritten.children();
    let mut updated_children = Vec::with_capacity(current_children.len());
    for child in current_children {
        // Every subtree starts from the same requirements, independently.
        let new_child = pushdown_limits(Arc::clone(child), requirements.clone())?;
        updated_children.push(new_child);
    }

    rewritten.with_new_children(updated_children)
}
/// Returns a [`LimitExec`] describing `plan` when it is a `GlobalLimitExec` or
/// a `LocalLimitExec`, and `None` for any other operator.
///
/// The returned value is a freshly constructed limit operator with the same
/// input, skip and fetch as the original; `plan` itself is not modified.
fn extract_limit(plan: &Arc<dyn ExecutionPlan>) -> Option<LimitExec> {
    let node = plan.as_any();

    if let Some(global) = node.downcast_ref::<GlobalLimitExec>() {
        let rebuilt =
            GlobalLimitExec::new(Arc::clone(global.input()), global.skip(), global.fetch());
        return Some(LimitExec::Global(rebuilt));
    }

    // Not a global limit: anything other than a local limit yields `None`.
    let local = node.downcast_ref::<LocalLimitExec>()?;
    let rebuilt = LocalLimitExec::new(Arc::clone(local.input()), local.fetch());
    Some(LimitExec::Local(rebuilt))
}
/// Returns `true` when `plan` is a `CoalescePartitionsExec` or a
/// `SortPreservingMergeExec`, the two operators this rule treats as combining
/// their input partitions.
fn combines_input_partitions(plan: &Arc<dyn ExecutionPlan>) -> bool {
    let node = plan.as_any();
    if node.is::<CoalescePartitionsExec>() {
        return true;
    }
    node.is::<SortPreservingMergeExec>()
}
/// Places a limit operator on top of `pushdown_plan`.
///
/// A `GlobalLimitExec` with `skip` and `fetch` is used when a skip is
/// requested or when the plan has exactly one output partition. Otherwise a
/// `LocalLimitExec` is used with a fetch of `fetch + skip` (where `skip` is
/// necessarily `0`).
fn add_limit(
    pushdown_plan: Arc<dyn ExecutionPlan>,
    skip: usize,
    fetch: usize,
) -> Arc<dyn ExecutionPlan> {
    // The partition count is only inspected when no skip is requested.
    if skip == 0 && pushdown_plan.output_partitioning().partition_count() != 1 {
        let local_limit = LocalLimitExec::new(pushdown_plan, fetch + skip);
        return Arc::new(local_limit);
    }
    add_global_limit(pushdown_plan, skip, Some(fetch))
}
/// Wraps `pushdown_plan` in a `GlobalLimitExec` configured with the given
/// `skip` and optional `fetch`, returned as a type-erased plan.
fn add_global_limit(
    pushdown_plan: Arc<dyn ExecutionPlan>,
    skip: usize,
    fetch: Option<usize>,
) -> Arc<dyn ExecutionPlan> {
    let global_limit = GlobalLimitExec::new(pushdown_plan, skip, fetch);
    Arc::new(global_limit)
}