use std::sync::Arc;
use crate::physical_plan::aggregates::AggregateExec;
use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
use crate::physical_plan::filter::FilterExec;
use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::ExecutionPlan;
use arrow_schema::DataType;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use itertools::Itertools;
/// Physical optimizer rule that pushes the `fetch` limit of a [`SortExec`] into
/// an [`AggregateExec`] found beneath it, so the aggregation can stop tracking
/// groups that cannot appear in the final, limited output.
///
/// The rule carries no state; see its `PhysicalOptimizerRule` implementation
/// for the entry point.
#[derive(Debug)]
pub struct TopKAggregation {}
impl TopKAggregation {
    /// Create a new `TopKAggregation` rule.
    pub fn new() -> Self {
        Self {}
    }

    /// Try to attach `limit` to a copy of `aggr`.
    ///
    /// Returns `Some(new_aggregate)` only when every one of these holds:
    ///
    /// * `aggr.get_minmax_desc()` yields a field and a direction, and that
    ///   direction equals the direction of `order`;
    /// * the aggregate groups by exactly one expression;
    /// * the data type of that group key is primitive or `Utf8`;
    /// * none of the aggregate expressions carries a filter;
    /// * `order` sorts on a plain [`Column`] whose name equals the name of the
    ///   field returned by `get_minmax_desc()`.
    ///
    /// Otherwise `None` is returned and the aggregate must be left untouched.
    ///
    /// # Panics
    ///
    /// Panics if an `AggregateExec` cannot be rebuilt from the parts of `aggr`.
    /// Those parts come from an already constructed aggregate, so a failure
    /// here indicates a broken invariant rather than bad user input.
    fn transform_agg(
        aggr: &AggregateExec,
        order: &PhysicalSortExpr,
        limit: usize,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        // The sort direction has to agree with the direction reported for the
        // aggregate function.
        let (field, desc) = aggr.get_minmax_desc()?;
        if desc != order.options.descending {
            return None;
        }

        // Only a single group key of a primitive or Utf8 type is supported.
        let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
        let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
        if !kt.is_primitive() && kt != DataType::Utf8 {
            return None;
        }

        // Filtered aggregate expressions are not supported.
        if aggr.filter_expr().iter().any(|e| e.is_some()) {
            return None;
        }

        // The sort must be on the same field the aggregate outputs.
        let col = order.expr.as_any().downcast_ref::<Column>()?;
        if col.name() != field.name() {
            return None;
        }

        // Everything matches: rebuild the aggregate and record the limit on it.
        let new_aggr = AggregateExec::try_new(
            *aggr.mode(),
            aggr.group_expr().clone(),
            aggr.aggr_expr().to_vec(),
            aggr.filter_expr().to_vec(),
            aggr.input().clone(),
            aggr.input_schema(),
        )
        .expect("Unable to copy Aggregate!")
        .with_limit(Some(limit));
        Some(Arc::new(new_aggr))
    }

    /// Try to rewrite the subtree below a limited sort.
    ///
    /// Returns `None` unless `plan` is a [`SortExec`] that has exactly one
    /// child, exactly one output ordering expression and a `fetch` value.
    /// When it is, the child subtree is walked top-down and every
    /// [`AggregateExec`] reached through row-preserving operators is offered to
    /// [`Self::transform_agg`]. The result is a new `SortExec` with the same
    /// sort expressions, fetch and partitioning flag on top of the (possibly
    /// rewritten) child.
    fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
        let sort = plan.as_any().downcast_ref::<SortExec>()?;

        let children = sort.children();
        let child = children.into_iter().exactly_one().ok()?;
        let order = sort.properties().output_ordering()?;
        let order = order.iter().exactly_one().ok()?;
        let limit = sort.fetch()?;

        // Operators the limit may be pushed through. `FilterExec` is
        // deliberately NOT in this list: a filter can drop rows, so limiting an
        // aggregate underneath it could leave fewer than `limit` rows after
        // filtering even though more qualifying groups exist.
        let is_cardinality_preserving = |plan: &Arc<dyn ExecutionPlan>| {
            let node = plan.as_any();
            node.is::<CoalesceBatchesExec>() || node.is::<RepartitionExec>()
        };

        // Flips to `false` at the first node that blocks the push-down; every
        // node visited afterwards is passed through unchanged.
        let mut cardinality_preserved = true;
        let closure = |plan: Arc<dyn ExecutionPlan>| {
            if !cardinality_preserved {
                return Ok(Transformed::no(plan));
            }
            if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
                // Either we reach an aggregate and try to transform it ...
                match Self::transform_agg(aggr, order, limit) {
                    None => cardinality_preserved = false,
                    Some(plan) => return Ok(Transformed::yes(plan)),
                }
            } else if !is_cardinality_preserving(&plan) {
                // ... or we keep descending, but only through whitelisted nodes.
                cardinality_preserved = false;
            }
            Ok(Transformed::no(plan))
        };
        let child = child.clone().transform_down(closure).data().ok()?;

        let sort = SortExec::new(sort.expr().to_vec(), child)
            .with_fetch(sort.fetch())
            .with_preserve_partitioning(sort.preserve_partitioning());
        Some(Arc::new(sort))
    }
}
impl Default for TopKAggregation {
fn default() -> Self {
Self::new()
}
}
impl PhysicalOptimizerRule for TopKAggregation {
    /// Apply the rule to `plan`.
    ///
    /// When `config.optimizer.enable_topk_aggregation` is off the plan is
    /// returned untouched. Otherwise the plan is walked top-down and every node
    /// that `TopKAggregation::transform_sort` accepts is replaced by its
    /// rewritten form.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Guard clause: nothing to do when the optimization is disabled.
        if !config.optimizer.enable_topk_aggregation {
            return Ok(plan);
        }

        plan.transform_down(|node| {
            let outcome = match TopKAggregation::transform_sort(node.clone()) {
                Some(rewritten) => Transformed::yes(rewritten),
                None => Transformed::no(node),
            };
            Ok(outcome)
        })
        .data()
    }

    /// Name under which this rule is reported. Note that the string is
    /// "LimitAggregation", not the name of the type.
    fn name(&self) -> &str {
        "LimitAggregation"
    }

    /// Always `true` for this rule.
    fn schema_check(&self) -> bool {
        true
    }
}