use std::sync::Arc;
use crate::PhysicalOptimizerRule;
use arrow::datatypes::DataType;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::window::StandardWindowExpr;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::filter::FilterExec;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::sorts::partitioned_topk::PartitionedTopKExec;
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowUDFExpr};
/// Physical optimizer rule that rewrites a filter on a `ROW_NUMBER()` window
/// output (such as `rn <= N`) into a per-partition top-K (`PartitionedTopKExec`)
/// placed below the window, removing the filter from the plan.
///
/// The rule only runs when the `optimizer.enable_window_topn` configuration
/// option is set.
#[derive(Default, Clone, Debug)]
pub struct WindowTopN;
impl WindowTopN {
pub fn new() -> Self {
Self
}
fn try_transform(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
let filter = plan.downcast_ref::<FilterExec>()?;
if filter.projection().is_some() {
return None;
}
let (col_idx, limit_n) = extract_window_limit(filter.predicate())?;
let child = filter.input();
let (window_exec, proj_between) = find_window_below(child)?;
let input_field_count = window_exec.input().schema().fields().len();
if col_idx < input_field_count {
return None; }
let window_expr_idx = col_idx - input_field_count;
let window_exprs = window_exec.window_expr();
if window_expr_idx >= window_exprs.len() {
return None;
}
if !is_row_number(&window_exprs[window_expr_idx]) {
return None;
}
let sort_exec = window_exec.input().downcast_ref::<SortExec>()?;
let sort_child = sort_exec.input();
let partition_by = window_exprs[window_expr_idx].partition_by();
let partition_prefix_len = partition_by.len();
if partition_prefix_len == 0 {
return None;
}
let partitioned_topk = PartitionedTopKExec::try_new(
Arc::clone(sort_child),
sort_exec.expr().clone(),
partition_prefix_len,
limit_n,
)
.ok()?;
let new_window = Arc::clone(&child_as_arc(window_exec))
.with_new_children(vec![Arc::new(partitioned_topk)])
.ok()?;
let result = match proj_between {
Some(proj) => Arc::clone(&child_as_arc(proj))
.with_new_children(vec![new_window])
.ok()?,
None => new_window,
};
Some(result)
}
}
/// Clones a concrete plan node into a fresh `Arc<dyn ExecutionPlan>`, so that
/// trait methods taking `self: Arc<Self>` (such as `with_new_children`) can be
/// called on a node only reachable by reference.
fn child_as_arc<T: ExecutionPlan + Clone + 'static>(plan: &T) -> Arc<dyn ExecutionPlan> {
    let owned: T = T::clone(plan);
    Arc::new(owned)
}
impl PhysicalOptimizerRule for WindowTopN {
    /// Visits the plan top-down and applies [`WindowTopN::try_transform`] to every
    /// node. Nodes the rewrite does not apply to are kept as they are.
    ///
    /// When `optimizer.enable_window_topn` is off, the input plan is returned
    /// untouched.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if config.optimizer.enable_window_topn {
            plan.transform_down(|node| {
                let outcome = match WindowTopN::try_transform(&node) {
                    Some(rewritten) => Transformed::yes(rewritten),
                    None => Transformed::no(node),
                };
                Ok(outcome)
            })
            .data()
        } else {
            Ok(plan)
        }
    }

    /// Name of this rule.
    fn name(&self) -> &str {
        "WindowTopN"
    }

    /// The optimizer should check that this rule preserves the plan schema.
    fn schema_check(&self) -> bool {
        true
    }
}
/// Recognizes a predicate that bounds a single column from above by a positive
/// integer literal. Returns `(column_index, limit)` such that, for ROW_NUMBER
/// values (which start at 1), the predicate is equivalent to `column <= limit`.
///
/// Accepted shapes, in either operand order:
/// - `col <= N` / `N >= col` gives limit `N`
/// - `col < N`  / `N > col`  gives limit `N - 1`
/// - `col = 1`  / `1 = col`  gives limit `1`
///
/// Returns `None` in these cases, which leaves such filters untouched:
/// - any other predicate shape;
/// - a literal that is not a positive integer;
/// - a resulting limit of zero (`col < 1`).
fn extract_window_limit(
    predicate: &Arc<dyn datafusion_physical_expr::PhysicalExpr>,
) -> Option<(usize, usize)> {
    let binary = predicate.downcast_ref::<BinaryExpr>()?;
    let (left, right) = (binary.left(), binary.right());

    // Normalize to `column <op> literal`. When the literal is on the left, mirror
    // the comparison so that e.g. `N >= col` is handled as `col <= N`.
    let (col, lit_val, op) = match (
        left.downcast_ref::<Column>(),
        right.downcast_ref::<Literal>(),
    ) {
        (Some(col), Some(lit_val)) => (col, lit_val, *binary.op()),
        _ => {
            let lit_val = left.downcast_ref::<Literal>()?;
            let col = right.downcast_ref::<Column>()?;
            let mirrored = match *binary.op() {
                Operator::GtEq => Operator::LtEq,
                Operator::Gt => Operator::Lt,
                Operator::Eq => Operator::Eq,
                _ => return None,
            };
            (col, lit_val, mirrored)
        }
    };

    // `scalar_to_usize` only returns values > 0, so `n - 1` below cannot underflow.
    let n = scalar_to_usize(lit_val.value())?;
    let limit = match op {
        Operator::LtEq => n,
        Operator::Lt => n - 1,
        // ROW_NUMBER starts at 1, so `rn = 1` keeps exactly the rows `rn <= 1` keeps.
        Operator::Eq if n == 1 => 1,
        _ => return None,
    };

    // A zero limit (`rn < 1`) selects nothing. Keep the original filter rather
    // than building a top-K with no rows.
    (limit > 0).then_some((col.index(), limit))
}
/// Converts an integer-typed literal to a strictly positive `usize`.
///
/// Returns `None` in these cases:
/// - the literal is not of an integer type;
/// - the literal cannot be cast to `UInt64`;
/// - the literal is NULL or zero;
/// - the value does not fit in `usize`.
fn scalar_to_usize(value: &ScalarValue) -> Option<usize> {
    if value.data_type().is_integer() {
        match value.cast_to(&DataType::UInt64) {
            Ok(ScalarValue::UInt64(Some(v))) if v != 0 => usize::try_from(v).ok(),
            _ => None,
        }
    } else {
        None
    }
}
/// Returns `true` when `expr` is a `StandardWindowExpr` whose underlying function
/// is the UDF named `row_number`, and `false` for any other window expression.
fn is_row_number(expr: &Arc<dyn datafusion_physical_expr::window::WindowExpr>) -> bool {
    expr.as_any()
        .downcast_ref::<StandardWindowExpr>()
        .and_then(|standard| {
            standard
                .get_standard_func_expr()
                .as_any()
                .downcast_ref::<WindowUDFExpr>()
                .map(|udf_expr| udf_expr.fun().name() == "row_number")
        })
        .unwrap_or(false)
}
/// Locates the `BoundedWindowAggExec` feeding a filter.
///
/// The window is either `plan` itself, or the direct child of a `ProjectionExec`
/// that is `plan`. Returns the window together with the intermediate
/// projection, if any, or `None` when neither shape matches.
fn find_window_below(
    plan: &Arc<dyn ExecutionPlan>,
) -> Option<(&BoundedWindowAggExec, Option<&ProjectionExec>)> {
    match plan.downcast_ref::<BoundedWindowAggExec>() {
        Some(direct_window) => Some((direct_window, None)),
        None => {
            let projection = plan.downcast_ref::<ProjectionExec>()?;
            let nested_window = projection
                .input()
                .downcast_ref::<BoundedWindowAggExec>()?;
            Some((nested_window, Some(projection)))
        }
    }
}