use std::fmt::{self, Formatter};
use std::sync::Arc;
use arrow::array::{RecordBatch, UInt32Array};
use arrow::compute::{BatchCoalescer, take_record_batch};
use arrow::datatypes::SchemaRef;
use arrow::row::{OwnedRow, RowConverter};
use datafusion_common::{HashMap, Result};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use futures::StreamExt;
use futures::TryStreamExt;
use parking_lot::RwLock;
use crate::execution_plan::{Boundedness, EmissionType};
use crate::metrics::ExecutionPlanMetricsSet;
use crate::topk::{TopK, TopKDynamicFilters, build_sort_fields};
use crate::{
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
PlanProperties, SendableRecordBatchStream, stream::RecordBatchStreamAdapter,
};
/// Physical operator that groups its input rows by a partition key and feeds
/// each group into its own `TopK` (built with `fetch`), ordered by the
/// remaining sort expressions.
///
/// `expr` holds both parts in a single ordering: the first
/// `partition_prefix_len` entries are the partition key expressions and the
/// rest are the order-by expressions handed to each per-group `TopK`.
#[derive(Debug, Clone)]
pub struct PartitionedTopKExec {
/// Child plan that produces the rows to be partitioned and ranked.
input: Arc<dyn ExecutionPlan>,
/// Partition key expressions followed by the order-by expressions.
expr: LexOrdering,
/// Number of leading entries of `expr` that form the partition key.
partition_prefix_len: usize,
/// `fetch` value passed to every per-group `TopK`.
fetch: usize,
/// Metrics set handed to every `TopK` created during execution.
metrics_set: ExecutionPlanMetricsSet,
/// Plan properties computed once at construction and returned by `properties()`.
cache: Arc<PlanProperties>,
}
impl PartitionedTopKExec {
pub fn try_new(
input: Arc<dyn ExecutionPlan>,
expr: LexOrdering,
partition_prefix_len: usize,
fetch: usize,
) -> Result<Self> {
let cache = Self::compute_properties(&input, expr.clone())?;
Ok(Self {
input,
expr,
partition_prefix_len,
fetch,
metrics_set: ExecutionPlanMetricsSet::new(),
cache: Arc::new(cache),
})
}
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
}
pub fn expr(&self) -> &LexOrdering {
&self.expr
}
pub fn partition_prefix_len(&self) -> usize {
self.partition_prefix_len
}
pub fn fetch(&self) -> usize {
self.fetch
}
fn compute_properties(
input: &Arc<dyn ExecutionPlan>,
sort_exprs: LexOrdering,
) -> Result<PlanProperties> {
let mut eq_properties = input.equivalence_properties().clone();
eq_properties.reorder(sort_exprs)?;
Ok(PlanProperties::new(
eq_properties,
input.output_partitioning().clone(),
EmissionType::Final,
Boundedness::Bounded,
))
}
}
impl DisplayAs for PartitionedTopKExec {
    /// Renders the operator for plan display.
    ///
    /// `Default` and `Verbose` produce a single line; `TreeRender` produces
    /// one `key=value` line per attribute. Partition keys are printed as bare
    /// expressions, order-by entries as full sort expressions.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
        let split = self.partition_prefix_len;

        // Both output formats show the same two lists, so build them once.
        let partition_list = self.expr[..split]
            .iter()
            .map(|key| key.expr.to_string())
            .collect::<Vec<String>>()
            .join(", ");
        let order_list = self.expr[split..]
            .iter()
            .map(|key| key.to_string())
            .collect::<Vec<String>>()
            .join(", ");

        match t {
            DisplayFormatType::TreeRender => {
                writeln!(f, "fetch={}", self.fetch)?;
                writeln!(f, "partition=[{}]", partition_list)?;
                writeln!(f, "order=[{}]", order_list)
            }
            DisplayFormatType::Default | DisplayFormatType::Verbose => write!(
                f,
                "PartitionedTopKExec: fetch={}, partition=[{}], order=[{}]",
                self.fetch, partition_list, order_list,
            ),
        }
    }
}
impl ExecutionPlan for PartitionedTopKExec {
fn name(&self) -> &'static str {
"PartitionedTopKExec"
}
fn properties(&self) -> &Arc<PlanProperties> {
&self.cache
}
fn required_input_distribution(&self) -> Vec<Distribution> {
let partition_exprs: Vec<Arc<dyn PhysicalExpr>> = self.expr
[..self.partition_prefix_len]
.iter()
.map(|e| Arc::clone(&e.expr))
.collect();
vec![Distribution::HashPartitioned(partition_exprs)]
}
fn maintains_input_order(&self) -> Vec<bool> {
vec![false]
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
assert_eq!(children.len(), 1);
Ok(Arc::new(PartitionedTopKExec::try_new(
Arc::clone(&children[0]),
self.expr.clone(),
self.partition_prefix_len,
self.fetch,
)?))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let input = self.input.execute(partition, Arc::clone(&context))?;
let schema = input.schema();
let partition_sort_fields =
build_sort_fields(&self.expr[..self.partition_prefix_len], &schema)?;
let partition_converter = RowConverter::new(partition_sort_fields)?;
let partition_exprs: Vec<Arc<dyn PhysicalExpr>> = self.expr
[..self.partition_prefix_len]
.iter()
.map(|e| Arc::clone(&e.expr))
.collect();
let order_expr: LexOrdering =
LexOrdering::new(self.expr[self.partition_prefix_len..].iter().cloned())
.expect("PartitionedTopKExec requires at least one order-by expression");
let fetch = self.fetch;
let batch_size = context.session_config().batch_size();
let runtime = Arc::clone(&context.runtime_env());
let metrics_set = self.metrics_set.clone();
let stream = futures::stream::once(async move {
do_partitioned_topk(
input,
schema,
partition_converter,
partition_exprs,
order_expr,
fetch,
batch_size,
runtime,
metrics_set,
)
.await
})
.try_flatten();
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.input.schema(),
stream,
)))
}
}
/// Builds the dynamic-filter handle passed to each per-group `TopK`: a
/// `DynamicFilterPhysicalExpr` with no children whose expression is the
/// literal `true`.
fn create_noop_dynamic_filter() -> Arc<RwLock<TopKDynamicFilters>> {
    let always_true = DynamicFilterPhysicalExpr::new(vec![], lit(true));
    let filters = TopKDynamicFilters::new(Arc::new(always_true));
    Arc::new(RwLock::new(filters))
}
#[expect(clippy::too_many_arguments)]
async fn do_partitioned_topk(
mut input: SendableRecordBatchStream,
schema: SchemaRef,
partition_converter: RowConverter,
partition_exprs: Vec<Arc<dyn PhysicalExpr>>,
order_expr: LexOrdering,
fetch: usize,
batch_size: usize,
runtime: Arc<datafusion_execution::runtime_env::RuntimeEnv>,
metrics_set: ExecutionPlanMetricsSet,
) -> Result<SendableRecordBatchStream> {
let mut partitions: HashMap<OwnedRow, TopK> = HashMap::new();
let mut partition_counter: usize = 0;
macro_rules! new_topk {
() => {{
let id = partition_counter;
partition_counter += 1;
TopK::try_new(
id,
Arc::clone(&schema),
vec![],
order_expr.clone(),
fetch,
batch_size,
Arc::clone(&runtime),
&metrics_set,
create_noop_dynamic_filter(),
)
}};
}
while let Some(batch) = input.next().await {
let batch = batch?;
let num_rows = batch.num_rows();
if num_rows == 0 {
continue;
}
let pk_arrays: Vec<_> = partition_exprs
.iter()
.map(|e| e.evaluate(&batch).and_then(|v| v.into_array(num_rows)))
.collect::<Result<Vec<_>>>()?;
let pk_rows = partition_converter.convert_columns(&pk_arrays)?;
let mut groups: HashMap<OwnedRow, Vec<u32>> = HashMap::new();
for row_idx in 0..num_rows {
let pk = pk_rows.row(row_idx).owned();
groups.entry(pk).or_default().push(row_idx as u32);
}
for (pk, indices) in groups {
if !partitions.contains_key(&pk) {
partitions.insert(pk.clone(), new_topk!()?);
}
let topk = partitions.get_mut(&pk).unwrap();
let indices_array = UInt32Array::from(indices);
let sub_batch = take_record_batch(&batch, &indices_array)?;
topk.insert_batch(sub_batch)?;
}
}
drop(input);
let mut sorted_pks: Vec<OwnedRow> = partitions.keys().cloned().collect();
sorted_pks.sort();
let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), batch_size);
for pk in sorted_pks {
if let Some(topk) = partitions.remove(&pk) {
let mut stream = topk.emit()?;
while let Some(batch) = stream.next().await {
coalescer.push_batch(batch?)?;
}
}
}
coalescer.finish_buffered_batch()?;
let mut output_batches: Vec<RecordBatch> = Vec::new();
while let Some(batch) = coalescer.next_completed_batch() {
output_batches.push(batch);
}
Ok(Box::pin(RecordBatchStreamAdapter::new(
schema,
futures::stream::iter(output_batches.into_iter().map(Ok)),
)))
}