use std::sync::Arc;
use datafusion_physical_plan::aggregates::AggregateExec;
use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use crate::PhysicalOptimizerRule;
use itertools::Itertools;
/// Physical optimizer rule that passes the row limit of a `LocalLimitExec` or
/// `GlobalLimitExec` down into the `AggregateExec` nodes directly beneath it.
///
/// Only aggregates accepted by
/// `AggregateExec::is_unordered_unfiltered_group_by_distinct` are rewritten;
/// each such aggregate is rebuilt with the limit attached via `with_limit`.
/// The rule is a no-op unless the
/// `optimizer.enable_distinct_aggregation_soft_limit` config flag is set.
#[derive(Debug)]
pub struct LimitedDistinctAggregation {}
impl LimitedDistinctAggregation {
    /// Create a new `LimitedDistinctAggregation`.
    pub fn new() -> Self {
        Self {}
    }

    /// Rebuild `aggr` with `limit` attached.
    ///
    /// Returns `None` when the aggregate is not an unordered, unfiltered
    /// group-by-distinct, or when the copy cannot be constructed. In both
    /// cases the caller simply leaves the plan untouched.
    fn transform_agg(
        aggr: &AggregateExec,
        limit: usize,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        if !aggr.is_unordered_unfiltered_group_by_distinct() {
            return None;
        }

        // A fresh `AggregateExec` is built because the limit is set through the
        // by-value `with_limit` builder. A construction failure skips the
        // optimization instead of panicking inside the optimizer.
        let new_aggr = AggregateExec::try_new(
            *aggr.mode(),
            aggr.group_expr().clone(),
            aggr.aggr_expr().to_vec(),
            aggr.filter_expr().to_vec(),
            Arc::clone(aggr.input()),
            aggr.input_schema(),
        )
        .ok()?
        .with_limit(Some(limit));
        Some(Arc::new(new_aggr))
    }

    /// Try to push the limit of `plan` into the aggregation(s) directly below it.
    ///
    /// `plan` must be a `LocalLimitExec`, or a `GlobalLimitExec` with a fetch,
    /// that has exactly one child and neither produces nor requires an ordering.
    /// For a global limit the pushed-down value is `fetch + skip`.
    ///
    /// Returns the limit node rebuilt on top of the rewritten subtree, or `None`
    /// when the rule does not apply or no aggregate was rewritten.
    fn transform_limit(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
        // `global` is `Some((skip, fetch))` only when `plan` is a `GlobalLimitExec`.
        let (limit, global): (usize, Option<(usize, usize)>) =
            if let Some(local_limit) = plan.as_any().downcast_ref::<LocalLimitExec>() {
                (local_limit.fetch(), None)
            } else if let Some(global_limit) =
                plan.as_any().downcast_ref::<GlobalLimitExec>()
            {
                // A global limit without a fetch gives nothing to push down.
                let fetch = global_limit.fetch()?;
                let skip = global_limit.skip();
                // The aggregate must emit enough rows to cover the skipped ones
                // too. On overflow, bail out rather than push a wrapped limit.
                (fetch.checked_add(skip)?, Some((skip, fetch)))
            } else {
                return None;
            };

        let child = Arc::clone(plan.children().into_iter().exactly_one().ok()?);

        // Do not rewrite when the limit produces or requires an ordering.
        if plan.output_ordering().is_some()
            || plan
                .required_input_ordering()
                .iter()
                .any(|req| req.is_some())
        {
            return None;
        }

        // The most recently rewritten aggregate, kept in its pre-rewrite form so
        // the next aggregate down can compare group-by expressions against it.
        let mut matched_aggr: Option<Arc<dyn ExecutionPlan>> = None;
        // Cleared at the first node that stops the chain of rewritable aggregates.
        let mut rewrite_applicable = true;

        let closure = |node: Arc<dyn ExecutionPlan>| {
            if !rewrite_applicable {
                return Ok(Transformed::no(node));
            }
            if let Some(aggr) = node.as_any().downcast_ref::<AggregateExec>() {
                if let Some(previous) = &matched_aggr {
                    if let Some(parent_aggr) =
                        previous.as_any().downcast_ref::<AggregateExec>()
                    {
                        // Stacked aggregates that group differently disqualify
                        // rewriting the lower one.
                        if parent_aggr.group_expr() != aggr.group_expr() {
                            rewrite_applicable = false;
                            return Ok(Transformed::no(node));
                        }
                    }
                }
                // Either this aggregate is rewritten, or the rewrite is disabled
                // for everything below it.
                if let Some(new_aggr) = Self::transform_agg(aggr, limit) {
                    matched_aggr = Some(node);
                    return Ok(Transformed::yes(new_aggr));
                }
            }
            rewrite_applicable = false;
            Ok(Transformed::no(node))
        };
        let child = child.transform_down(closure).data().ok()?;

        // Nothing was rewritten: report "no change" rather than an identical,
        // freshly allocated limit node.
        if matched_aggr.is_none() {
            return None;
        }

        if let Some((skip, fetch)) = global {
            return Some(Arc::new(GlobalLimitExec::new(child, skip, Some(fetch))));
        }
        Some(Arc::new(LocalLimitExec::new(child, limit)))
    }
}
impl Default for LimitedDistinctAggregation {
fn default() -> Self {
Self::new()
}
}
impl PhysicalOptimizerRule for LimitedDistinctAggregation {
    /// Walk `plan` top-down and apply `transform_limit` to every node.
    ///
    /// Returns `plan` unchanged when
    /// `optimizer.enable_distinct_aggregation_soft_limit` is disabled.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if !config.optimizer.enable_distinct_aggregation_soft_limit {
            return Ok(plan);
        }

        plan.transform_down(|node| {
            // Hand a cheap `Arc` clone to the rewrite so the original node is
            // still available when the rewrite does not apply.
            let outcome = match Self::transform_limit(Arc::clone(&node)) {
                Some(rewritten) => Transformed::yes(rewritten),
                None => Transformed::no(node),
            };
            Ok(outcome)
        })
        .data()
    }

    /// Name of this optimizer rule.
    fn name(&self) -> &str {
        "LimitedDistinctAggregation"
    }

    /// This rule opts in to the schema check (returns `true`).
    fn schema_check(&self) -> bool {
        true
    }
}