use std::sync::Arc;
use crate::PhysicalOptimizerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{Result, Statistics};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::Distribution;
use datafusion_physical_expr_common::sort_expr::OrderingRequirements;
use datafusion_physical_plan::execution_plan::Boundedness;
use datafusion_physical_plan::projection::{
ProjectionExec, make_with_child, update_expr, update_ordering_requirement,
};
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
SendableRecordBatchStream,
};
/// Physical optimizer rule that makes a plan's top-level ordering and
/// distribution requirements explicit, or removes them again.
///
/// The rule operates in one of two modes:
/// - add mode ([`OutputRequirements::new_add_mode`]) wraps the plan's top-most
///   sort operator (found by walking down through single-child,
///   order-preserving operators) in an [`OutputRequirementExec`]. If no such
///   sort is found, the whole plan is wrapped in one with no ordering
///   requirement.
/// - remove mode ([`OutputRequirements::new_remove_mode`]) replaces every
///   [`OutputRequirementExec`] in the plan with its input.
///
/// NOTE(review): presumably an add-mode instance and a remove-mode instance
/// bracket other physical optimizer passes so that those passes see the
/// requirements. Confirm against the optimizer's rule list.
#[derive(Debug)]
pub struct OutputRequirements {
    // Which of the two rewrites `optimize` performs.
    mode: RuleMode,
}

impl OutputRequirements {
    /// Creates the rule in add mode (inserts `OutputRequirementExec` nodes).
    pub fn new_add_mode() -> Self {
        Self::from_mode(RuleMode::Add)
    }

    /// Creates the rule in remove mode (strips `OutputRequirementExec` nodes).
    pub fn new_remove_mode() -> Self {
        Self::from_mode(RuleMode::Remove)
    }

    /// Shared constructor behind the two public mode-specific constructors.
    fn from_mode(mode: RuleMode) -> Self {
        Self { mode }
    }
}

/// Direction in which [`OutputRequirements`] rewrites a plan.
///
/// Variant order is significant for the derived `Ord`/`PartialOrd`.
#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Hash)]
enum RuleMode {
    /// Insert `OutputRequirementExec` nodes.
    Add,
    /// Remove `OutputRequirementExec` nodes.
    Remove,
}
/// Pass-through plan node that carries ordering and distribution requirements
/// for its single input.
///
/// It reports its input's equivalence properties, output partitioning and
/// pipeline behavior unchanged. Its requirements are exposed through
/// `required_input_ordering` / `required_input_distribution`.
///
/// The node is not meant to be executed: its `execute` is `unreachable!()`,
/// and `OutputRequirements` in remove mode replaces it with its input.
#[derive(Debug)]
pub struct OutputRequirementExec {
    // The single child plan.
    input: Arc<dyn ExecutionPlan>,
    // Ordering required of `input`, if any.
    order_requirement: Option<OrderingRequirements>,
    // Distribution required of `input`.
    dist_requirement: Distribution,
    // Plan properties computed once by `compute_properties`.
    cache: Arc<PlanProperties>,
    // Optional fetch (row limit). When set, the node reports itself as bounded.
    fetch: Option<usize>,
}

impl OutputRequirementExec {
    /// Builds a node over `input` that records the given ordering
    /// (`requirements`) and distribution (`dist_requirement`) requirements,
    /// plus an optional `fetch` limit.
    pub fn new(
        input: Arc<dyn ExecutionPlan>,
        requirements: Option<OrderingRequirements>,
        dist_requirement: Distribution,
        fetch: Option<usize>,
    ) -> Self {
        let properties = Arc::new(Self::compute_properties(&input, &fetch));
        Self {
            input,
            order_requirement: requirements,
            dist_requirement,
            cache: properties,
            fetch,
        }
    }

    /// Returns a new handle to the child plan.
    pub fn input(&self) -> Arc<dyn ExecutionPlan> {
        Arc::clone(&self.input)
    }

    /// Derives this node's plan properties from its input.
    ///
    /// Everything is inherited from `input` except boundedness. When `fetch`
    /// is set, the node reports `Boundedness::Bounded` regardless of the input.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        fetch: &Option<usize>,
    ) -> PlanProperties {
        let boundedness = match fetch {
            Some(_) => Boundedness::Bounded,
            None => input.boundedness(),
        };
        let eq_properties = input.equivalence_properties().clone();
        let partitioning = input.output_partitioning().clone();
        let emission = input.pipeline_behavior();
        PlanProperties::new(eq_properties, partitioning, emission, boundedness)
    }

    /// Returns the optional fetch (row limit) carried by this node.
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }
}
impl DisplayAs for OutputRequirementExec {
    /// Formats the node for `EXPLAIN` output.
    ///
    /// Default/verbose output lists only the first alternative of the ordering
    /// requirement, as `(expr, asc|desc|unspecified)` pairs, followed by the
    /// distribution requirement. Tree rendering writes nothing.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::TreeRender => write!(f, ""),
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                let order_cols = match &self.order_requirement {
                    None => "[]".to_string(),
                    Some(reqs) => {
                        let rendered = reqs
                            .first()
                            .iter()
                            .map(|req| {
                                // Requirements without sort options have no direction.
                                let direction = match &req.options {
                                    Some(opt) if opt.descending => "desc",
                                    Some(_) => "asc",
                                    None => "unspecified",
                                };
                                format!("({}, {direction})", req.expr)
                            })
                            .collect::<Vec<String>>()
                            .join(", ");
                        format!("[{}]", rendered)
                    }
                };
                write!(
                    f,
                    "OutputRequirementExec: order_by={}, dist_by={}",
                    order_cols, self.dist_requirement
                )
            }
        }
    }
}
impl ExecutionPlan for OutputRequirementExec {
    fn name(&self) -> &'static str {
        "OutputRequirementExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    /// This node reports no benefit from repartitioning its input.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    /// Exposes the recorded distribution requirement for the single child.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![self.dist_requirement.clone()]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Exposes the recorded ordering requirement for the single child.
    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        vec![self.order_requirement.clone()]
    }

    /// Rebuilds the node over a new child, keeping requirements and fetch.
    ///
    /// Uses `children[0]`; panics if `children` is empty.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let only_child = children.remove(0);
        let rebuilt = Self::new(
            only_child,
            self.order_requirement.clone(),
            self.dist_requirement.clone(),
            self.fetch,
        );
        Ok(Arc::new(rebuilt))
    }

    /// Never called: `OutputRequirements` in remove mode strips this node
    /// from the plan. NOTE(review): this assumes the remove-mode rule always
    /// runs before execution.
    fn execute(
        &self,
        _partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        unreachable!();
    }

    /// Statistics are those of the input, unchanged.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.input.partition_statistics(partition)
    }

    /// Pushes `projection` below this node when the projection narrows the
    /// schema and every required expression survives the projection.
    ///
    /// Returns `Ok(None)` (no swap) when the projection keeps at least as many
    /// columns as its input has, or when any ordering alternative or
    /// hash-partitioning expression cannot be rewritten through it.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        let proj_exprs = projection.expr();
        let input_width = projection.input().schema().fields().len();
        if proj_exprs.len() >= input_width {
            return Ok(None);
        }

        // Rewrite every ordering alternative against the projected columns.
        // The nested collect stops at the first error (propagated) or the
        // first alternative that cannot be rewritten (no swap).
        let requirements = match self.order_requirement.clone() {
            None => None,
            Some(reqs) => {
                let (lexes, soft) = reqs.into_alternatives();
                let rewritten = lexes
                    .into_iter()
                    .map(|lex| update_ordering_requirement(lex, proj_exprs))
                    .collect::<Result<Option<Vec<_>>>>()?;
                let Some(rewritten) = rewritten else {
                    return Ok(None);
                };
                OrderingRequirements::new_alternatives(rewritten, soft)
            }
        };

        // Only hash partitioning references expressions that need rewriting.
        let dist_req = match &self.dist_requirement {
            Distribution::HashPartitioned(exprs) => {
                let rewritten = exprs
                    .iter()
                    .map(|expr| update_expr(expr, proj_exprs, false))
                    .collect::<Result<Option<Vec<_>>>>()?;
                match rewritten {
                    Some(updated) => Distribution::HashPartitioned(updated),
                    None => return Ok(None),
                }
            }
            other => other.clone(),
        };

        make_with_child(projection, &self.input).map(|new_input| {
            let swapped =
                OutputRequirementExec::new(new_input, requirements, dist_req, self.fetch);
            Some(Arc::new(swapped) as Arc<dyn ExecutionPlan>)
        })
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }
}
impl PhysicalOptimizerRule for OutputRequirements {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
match self.mode {
RuleMode::Add => require_top_ordering(plan),
RuleMode::Remove => plan
.transform_up(|plan| {
if let Some(sort_req) =
plan.as_any().downcast_ref::<OutputRequirementExec>()
{
Ok(Transformed::yes(sort_req.input()))
} else {
Ok(Transformed::no(plan))
}
})
.data(),
}
}
fn name(&self) -> &str {
"OutputRequirements"
}
fn schema_check(&self) -> bool {
true
}
}
/// Ensures the plan carries exactly one top-level `OutputRequirementExec`.
///
/// If `require_top_ordering_helper` found a sort operator and wrapped it, that
/// rewritten plan is returned. Otherwise the whole plan is wrapped in an
/// `OutputRequirementExec` with no ordering requirement, an unspecified
/// distribution and no fetch.
fn require_top_ordering(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    let (new_plan, is_changed) = require_top_ordering_helper(plan)?;
    if is_changed {
        return Ok(new_plan);
    }
    let wrapped: Arc<dyn ExecutionPlan> = Arc::new(OutputRequirementExec::new(
        new_plan,
        None,
        Distribution::UnspecifiedDistribution,
        None,
    ));
    Ok(wrapped)
}
/// Searches down from `plan` for the sort operator that defines the global
/// output ordering, and wraps it in an `OutputRequirementExec`.
///
/// The search only descends through operators that have exactly one child,
/// maintain their input order, and impose at most a soft ordering requirement
/// on that child.
///
/// - `SortExec`: wrapped with its sort expressions, its own required input
///   distribution and its fetch.
/// - `SortPreservingMergeExec`: wrapped with its sort expressions, a
///   single-partition distribution and its fetch.
///
/// Returns the (possibly rewritten) plan and whether a wrapper was inserted.
fn require_top_ordering_helper(
    plan: Arc<dyn ExecutionPlan>,
) -> Result<(Arc<dyn ExecutionPlan>, bool)> {
    let mut children = plan.children();
    if children.len() != 1 {
        return Ok((plan, false));
    }

    // Extract (ordering, distribution, fetch) if `plan` is one of the two
    // sort operators that can originate the global ordering.
    let any = plan.as_any();
    let sort_info = if let Some(sort_exec) = any.downcast_ref::<SortExec>() {
        // The sort expressions themselves are used as the requirement, rather
        // than the node's output ordering.
        let distribution = sort_exec.required_input_distribution().swap_remove(0);
        let ordering = OrderingRequirements::from(sort_exec.expr().clone());
        Some((ordering, distribution, sort_exec.fetch()))
    } else if let Some(spm) = any.downcast_ref::<SortPreservingMergeExec>() {
        let ordering = OrderingRequirements::from(spm.expr().clone());
        Some((ordering, Distribution::SinglePartition, spm.fetch()))
    } else {
        None
    };

    if let Some((ordering, distribution, fetch)) = sort_info {
        let wrapped: Arc<dyn ExecutionPlan> = Arc::new(OutputRequirementExec::new(
            plan,
            Some(ordering),
            distribution,
            fetch,
        ));
        return Ok((wrapped, true));
    }

    // Keep descending only while order is preserved and this operator does
    // not itself demand a hard ordering. A sort below such an operator cannot
    // be the originator of the global ordering. `&&` short-circuits so
    // `required_input_ordering` is consulted only when order is maintained.
    let can_descend = plan.maintains_input_order()[0]
        && plan.required_input_ordering()[0]
            .as_ref()
            .is_none_or(|o| matches!(o, OrderingRequirements::Soft(_)));
    if !can_descend {
        return Ok((plan, false));
    }

    let only_child = Arc::clone(children.swap_remove(0));
    let (new_child, is_changed) = require_top_ordering_helper(only_child)?;
    Ok((plan.with_new_children(vec![new_child])?, is_changed))
}