use crate::error::Result;
use crate::execution::context::TaskContext;
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
};
use crate::physical_plan::{
ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr,
};
use arrow::array::Array;
use arrow::compute::{
concat, concat_batches, lexicographical_partition_ranges, SortColumn,
};
use arrow::{
array::ArrayRef,
datatypes::{Schema, SchemaRef},
record_batch::RecordBatch,
};
use datafusion_common::{DataFusionError, ScalarValue};
use futures::stream::Stream;
use futures::{ready, StreamExt};
use std::any::Any;
use std::cmp::min;
use std::collections::HashMap;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use datafusion_physical_expr::window::{
PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates,
WindowAggState, WindowState,
};
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
use indexmap::IndexMap;
use log::debug;
/// Execution plan node that evaluates window expressions over its input and
/// appends one result column per window expression to the input columns.
#[derive(Debug)]
pub struct BoundedWindowAggExec {
/// Child plan whose output is consumed by this operator.
input: Arc<dyn ExecutionPlan>,
/// Window expressions to evaluate; each contributes one output column.
window_expr: Vec<Arc<dyn WindowExpr>>,
/// Output schema: the input schema's fields followed by one field per
/// window expression (built in `try_new` via `create_schema`).
schema: SchemaRef,
/// Schema of the input, as supplied to `try_new`.
input_schema: SchemaRef,
/// Partition keys; when non-empty they are used as the required hash
/// partitioning of the input, otherwise a single partition is required.
pub partition_keys: Vec<Arc<dyn PhysicalExpr>>,
/// Sort keys reported as the required ordering of the input, if any.
pub sort_keys: Option<Vec<PhysicalSortExpr>>,
/// Metrics collected while executing this operator.
metrics: ExecutionPlanMetricsSet,
}
impl BoundedWindowAggExec {
    /// Creates a new bounded window aggregate plan.
    ///
    /// The output schema is derived from `input_schema` followed by one
    /// field per entry of `window_expr`.
    ///
    /// # Errors
    ///
    /// Returns an error if any window expression fails to report its field.
    pub fn try_new(
        window_expr: Vec<Arc<dyn WindowExpr>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
        partition_keys: Vec<Arc<dyn PhysicalExpr>>,
        sort_keys: Option<Vec<PhysicalSortExpr>>,
    ) -> Result<Self> {
        let schema = create_schema(&input_schema, &window_expr)?;
        let schema = Arc::new(schema);
        Ok(Self {
            input,
            window_expr,
            schema,
            input_schema,
            partition_keys,
            sort_keys,
            metrics: ExecutionPlanMetricsSet::new(),
        })
    }

    /// Window expressions evaluated by this plan.
    pub fn window_expr(&self) -> &[Arc<dyn WindowExpr>] {
        &self.window_expr
    }

    /// Child plan feeding this operator.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Schema of the input (before the window columns are appended).
    pub fn input_schema(&self) -> SchemaRef {
        self.input_schema.clone()
    }

    /// Returns the sort expressions that correspond to the PARTITION BY
    /// expressions of the first window expression, in PARTITION BY order.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the plan holds no window expression, or
    /// if a PARTITION BY expression has no matching entry in `sort_keys`.
    pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
        // Use checked access: an empty `window_expr` must surface as an error
        // rather than an index-out-of-bounds panic.
        let first_expr = self.window_expr.first().ok_or_else(|| {
            DataFusionError::Internal(
                "BoundedWindowAggExec has no window expressions".to_string(),
            )
        })?;
        let mut result = vec![];
        let partition_by = first_expr.partition_by();
        let sort_keys = self.sort_keys.as_deref().unwrap_or(&[]);
        for item in partition_by {
            if let Some(a) = sort_keys.iter().find(|&e| e.expr.eq(item)) {
                result.push(a.clone());
            } else {
                return Err(DataFusionError::Internal(
                    "Partition key not found in sort keys".to_string(),
                ));
            }
        }
        Ok(result)
    }
}
impl ExecutionPlan for BoundedWindowAggExec {
    /// Returns `self` as `Any` so callers can downcast to the concrete plan.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output schema: input fields followed by the window result fields.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// The single child of this plan.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![Arc::clone(&self.input)]
    }

    /// Output partitioning is whatever the input reports.
    fn output_partitioning(&self) -> Partitioning {
        self.input.output_partitioning()
    }

    /// The output is unbounded exactly when the (only) child is unbounded.
    fn unbounded_output(&self, children: &[bool]) -> Result<bool> {
        Ok(children[0])
    }

    /// Output ordering is whatever the input reports.
    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
        self.input.output_ordering()
    }

    /// The input is required to be ordered by `sort_keys`, when present.
    fn required_input_ordering(&self) -> Vec<Option<&[PhysicalSortExpr]>> {
        vec![self.sort_keys.as_deref()]
    }

    /// Requires hash partitioning on the partition keys, or a single
    /// partition when there are no partition keys.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        if !self.partition_keys.is_empty() {
            return vec![Distribution::HashPartitioned(self.partition_keys.clone())];
        }
        debug!("No partition defined for BoundedWindowAggExec!!!");
        vec![Distribution::SinglePartition]
    }

    /// Equivalence properties are forwarded from the input.
    fn equivalence_properties(&self) -> EquivalenceProperties {
        self.input.equivalence_properties()
    }

    /// Reports that the order of the single input is maintained.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    /// Rebuilds this plan on top of `children[0]`, keeping every other
    /// setting of the current plan.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let rebuilt = BoundedWindowAggExec::try_new(
            self.window_expr.clone(),
            children[0].clone(),
            self.input_schema.clone(),
            self.partition_keys.clone(),
            self.sort_keys.clone(),
        )?;
        Ok(Arc::new(rebuilt))
    }

    /// Executes the input for `partition` and wraps it in a
    /// `SortedPartitionByBoundedWindowStream` that emits the window results.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let partition_by_sort_keys = self.partition_by_sort_keys()?;
        let stream = Box::pin(SortedPartitionByBoundedWindowStream::new(
            self.schema.clone(),
            self.window_expr.clone(),
            input_stream,
            baseline_metrics,
            partition_by_sort_keys,
        ));
        Ok(stream)
    }

    /// Writes a one-line description listing every window expression with
    /// its name, field and window frame.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default => {
                write!(f, "BoundedWindowAggExec: ")?;
                let descriptions = self
                    .window_expr
                    .iter()
                    .map(|expr| {
                        format!(
                            "{}: {:?}, frame: {:?}",
                            expr.name(),
                            expr.field(),
                            expr.get_window_frame()
                        )
                    })
                    .collect::<Vec<String>>();
                write!(f, "wdw=[{}]", descriptions.join(", "))?;
            }
        }
        Ok(())
    }

    /// Snapshot of the metrics recorded so far.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Statistics of the output: row count and exactness come from the
    /// input; column statistics are the input's (or defaults when absent)
    /// followed by default statistics for each window column.
    fn statistics(&self) -> Statistics {
        let input_stat = self.input.statistics();
        let n_window_cols = self.window_expr.len();
        let n_input_cols = self.input_schema.fields().len();

        let mut column_statistics = Vec::with_capacity(n_window_cols + n_input_cols);
        match input_stat.column_statistics {
            Some(input_col_stats) => column_statistics.extend(input_col_stats),
            None => column_statistics.extend(
                std::iter::repeat_with(ColumnStatistics::default).take(n_input_cols),
            ),
        }
        // Nothing is known about the freshly computed window columns.
        column_statistics
            .extend(std::iter::repeat_with(ColumnStatistics::default).take(n_window_cols));

        Statistics {
            is_exact: input_stat.is_exact,
            num_rows: input_stat.num_rows,
            column_statistics: Some(column_statistics),
            total_byte_size: None,
        }
    }
}
/// Builds the output schema: every field of `input_schema`, followed by the
/// field reported by each window expression, in order.
///
/// Fails with the first error returned by a window expression's `field()`.
fn create_schema(
    input_schema: &Schema,
    window_expr: &[Arc<dyn WindowExpr>],
) -> Result<Schema> {
    let window_fields = window_expr
        .iter()
        .map(|expr| expr.field())
        .collect::<Result<Vec<_>>>()?;
    let input_fields = input_schema.fields();

    let mut all_fields = Vec::with_capacity(input_fields.len() + window_fields.len());
    all_fields.extend(input_fields.iter().cloned());
    all_fields.extend(window_fields);
    Ok(Schema::new(all_fields))
}
/// Interface for maintaining buffered window state and producing output
/// columns from it as new input batches arrive.
pub trait PartitionByHandler {
/// Builds the columns that can be emitted now, or `None` when no rows are
/// ready to be emitted.
fn calculate_out_columns(&self) -> Result<Option<Vec<ArrayRef>>>;
/// Drops state that is no longer needed after `n_out` rows were emitted.
fn prune_state(&mut self, n_out: usize) -> Result<()>;
/// Incorporates a newly received `record_batch` into the buffered state.
fn update_partition_batch(&mut self, record_batch: RecordBatch) -> Result<()>;
}
/// Stream that consumes input batches, buffers them per partition key and
/// emits batches consisting of input columns plus window result columns.
pub struct SortedPartitionByBoundedWindowStream {
/// Schema of the batches produced by this stream.
schema: SchemaRef,
/// Upstream stream of input batches.
input: SendableRecordBatchStream,
/// Input rows received so far that have not yet been emitted; new batches
/// are appended and emitted rows are sliced off the front.
input_buffer: RecordBatch,
/// Buffered rows for each partition key, together with a flag marking
/// whether the partition has ended.
partition_buffers: PartitionBatches,
/// Per-partition window state, one map per entry of `window_expr`
/// (same order).
window_agg_states: Vec<PartitionWindowAggStates>,
/// Set once the input stream has been exhausted.
finished: bool,
/// Window expressions evaluated by this stream.
window_expr: Vec<Arc<dyn WindowExpr>>,
/// Sort expressions evaluated on each batch to obtain the partition columns.
partition_by_sort_keys: Vec<PhysicalSortExpr>,
/// Metrics updated on every poll of the stream.
baseline_metrics: BaselineMetrics,
}
impl PartitionByHandler for SortedPartitionByBoundedWindowStream {
fn calculate_out_columns(&self) -> Result<Option<Vec<ArrayRef>>> {
let n_out = self.calculate_n_out_row();
if n_out == 0 {
Ok(None)
} else {
self.input_buffer
.columns()
.iter()
.map(|elem| Ok(elem.slice(0, n_out)))
.chain(
self.window_agg_states
.iter()
.map(|elem| get_aggregate_result_out_column(elem, n_out)),
)
.collect::<Result<Vec<_>>>()
.map(Some)
}
}
fn prune_state(&mut self, n_out: usize) -> Result<()> {
self.prune_partition_batches()?;
self.prune_input_batch(n_out)?;
self.prune_out_columns(n_out)?;
Ok(())
}
fn update_partition_batch(&mut self, record_batch: RecordBatch) -> Result<()> {
let partition_columns = self.partition_columns(&record_batch)?;
let num_rows = record_batch.num_rows();
if num_rows > 0 {
let partition_points =
self.evaluate_partition_points(num_rows, &partition_columns)?;
for partition_range in partition_points {
let partition_row = partition_columns
.iter()
.map(|arr| {
ScalarValue::try_from_array(&arr.values, partition_range.start)
})
.collect::<Result<PartitionKey>>()?;
let partition_batch = record_batch.slice(
partition_range.start,
partition_range.end - partition_range.start,
);
if let Some(partition_batch_state) =
self.partition_buffers.get_mut(&partition_row)
{
partition_batch_state.record_batch = concat_batches(
&self.input.schema(),
[&partition_batch_state.record_batch, &partition_batch],
)?;
} else {
let partition_batch_state = PartitionBatchState {
record_batch: partition_batch,
is_end: false,
};
self.partition_buffers
.insert(partition_row, partition_batch_state);
};
}
}
let n_partitions = self.partition_buffers.len();
for (idx, (_, partition_batch_state)) in
self.partition_buffers.iter_mut().enumerate()
{
partition_batch_state.is_end |= idx < n_partitions - 1;
}
self.input_buffer = if self.input_buffer.num_rows() == 0 {
record_batch
} else {
concat_batches(&self.input.schema(), [&self.input_buffer, &record_batch])?
};
Ok(())
}
}
impl Stream for SortedPartitionByBoundedWindowStream {
    type Item = Result<RecordBatch>;

    /// Polls the stream for its next output batch and records the outcome
    /// of the poll in the baseline metrics.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let outcome = this.poll_next_inner(cx);
        this.baseline_metrics.record_poll(outcome)
    }
}
impl SortedPartitionByBoundedWindowStream {
    /// Creates a new stream over `input`.
    ///
    /// `schema` is the schema of the batches this stream emits;
    /// `partition_by_sort_keys` are the expressions used to split incoming
    /// batches into partitions.
    pub fn new(
        schema: SchemaRef,
        window_expr: Vec<Arc<dyn WindowExpr>>,
        input: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
        partition_by_sort_keys: Vec<PhysicalSortExpr>,
    ) -> Self {
        // One (initially empty) per-partition state map per window expression.
        let state = window_expr.iter().map(|_| IndexMap::new()).collect();
        // `input_buffer` accumulates *input* rows, so the initial empty
        // buffer must carry the input schema rather than the output schema.
        let empty_batch = RecordBatch::new_empty(input.schema());
        Self {
            schema,
            input,
            input_buffer: empty_batch,
            partition_buffers: IndexMap::new(),
            window_agg_states: state,
            finished: false,
            window_expr,
            baseline_metrics,
            partition_by_sort_keys,
        }
    }

    /// Evaluates every window expression against the buffered partitions,
    /// then emits whatever rows are ready and prunes the state accordingly.
    ///
    /// Returns an empty batch (with the output schema) when no rows are
    /// ready yet.
    fn compute_aggregates(&mut self) -> Result<RecordBatch> {
        for (cur_window_expr, state) in
            self.window_expr.iter().zip(&mut self.window_agg_states)
        {
            cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?;
        }
        let schema = self.schema.clone();
        let columns_to_show = self.calculate_out_columns()?;
        if let Some(columns_to_show) = columns_to_show {
            // All output columns have the same length; use the first one.
            let n_generated = columns_to_show[0].len();
            self.prune_state(n_generated)?;
            Ok(RecordBatch::try_new(schema, columns_to_show)?)
        } else {
            Ok(RecordBatch::new_empty(schema))
        }
    }

    /// Drives the stream: pulls the next input batch, folds it into the
    /// buffered state and produces the next output batch.
    ///
    /// When the input is exhausted, every partition is marked as ended, one
    /// final batch is produced, and subsequent polls return `None`.
    #[inline]
    fn poll_next_inner(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        if self.finished {
            return Poll::Ready(None);
        }

        let result = match ready!(self.input.poll_next_unpin(cx)) {
            Some(Ok(batch)) => {
                self.update_partition_batch(batch)?;
                self.compute_aggregates()
            }
            Some(Err(e)) => Err(e),
            None => {
                self.finished = true;
                // No more input: every remaining partition is complete.
                for (_, partition_batch_state) in self.partition_buffers.iter_mut() {
                    partition_batch_state.is_end = true;
                }
                self.compute_aggregates()
            }
        };
        Poll::Ready(Some(result))
    }

    /// Number of rows that can be emitted now.
    ///
    /// For each window expression, result lengths are summed over its
    /// partitions in iteration order, stopping after the first partition
    /// that still has missing results. The minimum over all window
    /// expressions is returned (0 when there are none).
    fn calculate_n_out_row(&self) -> usize {
        self.window_agg_states
            .iter()
            .map(|window_agg_state| {
                let mut cur_window_expr_out_result_len = 0;
                for (_, WindowState { state, .. }) in window_agg_state.iter() {
                    cur_window_expr_out_result_len += state.out_col.len();
                    if state.n_row_result_missing > 0 {
                        break;
                    }
                }
                cur_window_expr_out_result_len
            })
            .min()
            .unwrap_or(0)
    }

    /// Removes ended partitions and trims rows from the front of the
    /// remaining partition buffers that no window expression needs anymore.
    ///
    /// # Errors
    ///
    /// Returns an execution error if a partition that has window state has
    /// no matching partition buffer (or vice versa).
    fn prune_partition_batches(&mut self) -> Result<()> {
        self.partition_buffers
            .retain(|_, partition_batch_state| !partition_batch_state.is_end);

        // For every partition, the number of leading rows that can be
        // dropped is the smallest prunable prefix over all window
        // expressions, since the buffer is shared between them.
        let mut n_prune_each_partition: HashMap<PartitionKey, usize> = HashMap::new();
        for window_agg_state in self.window_agg_states.iter_mut() {
            window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end);
            for (partition_row, WindowState { state: value, .. }) in window_agg_state {
                let n_prune =
                    min(value.window_frame_range.start, value.last_calculated_index);
                if let Some(state) = n_prune_each_partition.get_mut(partition_row) {
                    if n_prune < *state {
                        *state = n_prune;
                    }
                } else {
                    n_prune_each_partition.insert(partition_row.clone(), n_prune);
                }
            }
        }

        let err = || DataFusionError::Execution("Expects to have partition".to_string());
        for (partition_row, n_prune) in n_prune_each_partition.iter() {
            let partition_batch_state = self
                .partition_buffers
                .get_mut(partition_row)
                .ok_or_else(err)?;
            let batch = &partition_batch_state.record_batch;
            partition_batch_state.record_batch =
                batch.slice(*n_prune, batch.num_rows() - n_prune);

            // Shift the bookkeeping indices of every window expression so
            // they keep pointing at the same rows of the trimmed buffer.
            for window_agg_state in self.window_agg_states.iter_mut() {
                let window_state =
                    window_agg_state.get_mut(partition_row).ok_or_else(err)?;
                let state = &mut window_state.state;
                state.window_frame_range = Range {
                    start: state.window_frame_range.start - n_prune,
                    end: state.window_frame_range.end - n_prune,
                };
                state.last_calculated_index -= n_prune;
                state.offset_pruned_rows += n_prune;
            }
        }
        Ok(())
    }

    /// Drops the first `n_out` rows of the input buffer.
    ///
    /// # Errors
    ///
    /// Returns an execution error if `n_out` exceeds the number of buffered
    /// rows, instead of underflowing the subtraction.
    fn prune_input_batch(&mut self, n_out: usize) -> Result<()> {
        let n_rows = self.input_buffer.num_rows();
        let n_to_keep = n_rows.checked_sub(n_out).ok_or_else(|| {
            DataFusionError::Execution(format!(
                "Cannot prune {n_out} rows from an input buffer of {n_rows} rows"
            ))
        })?;
        let batch_to_keep = self
            .input_buffer
            .columns()
            .iter()
            .map(|elem| elem.slice(n_out, n_to_keep))
            .collect::<Vec<_>>();
        self.input_buffer =
            RecordBatch::try_new(self.input_buffer.schema(), batch_to_keep)?;
        Ok(())
    }

    /// Drops the first `n_out` buffered result rows of every window
    /// expression, consuming partitions in iteration order.
    fn prune_out_columns(&mut self, n_out: usize) -> Result<()> {
        for partition_window_agg_states in self.window_agg_states.iter_mut() {
            let mut running_length = 0;
            for (
                _,
                WindowState {
                    state: WindowAggState { out_col, .. },
                    ..
                },
            ) in partition_window_agg_states
            {
                if running_length < n_out {
                    let n_to_del = min(out_col.len(), n_out - running_length);
                    let n_to_keep = out_col.len() - n_to_del;
                    *out_col = out_col.slice(n_to_del, n_to_keep);
                    running_length += n_to_del;
                }
            }
        }
        Ok(())
    }

    /// Evaluates the partition-by sort expressions against `batch`.
    pub fn partition_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
        self.partition_by_sort_keys
            .iter()
            .map(|e| e.evaluate_to_sort_column(batch))
            .collect::<Result<Vec<_>>>()
    }

    /// Returns the row ranges of the partitions found in a batch of
    /// `num_rows` rows. Without partition columns the whole batch is a
    /// single range.
    fn evaluate_partition_points(
        &self,
        num_rows: usize,
        partition_columns: &[SortColumn],
    ) -> Result<Vec<Range<usize>>> {
        Ok(if partition_columns.is_empty() {
            vec![Range {
                start: 0,
                end: num_rows,
            }]
        } else {
            lexicographical_partition_ranges(partition_columns)?.collect()
        })
    }
}
impl RecordBatchStream for SortedPartitionByBoundedWindowStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
fn get_aggregate_result_out_column(
partition_window_agg_states: &PartitionWindowAggStates,
len_to_show: usize,
) -> Result<ArrayRef> {
let mut result = None;
let mut running_length = 0;
for (
_,
WindowState {
state: WindowAggState { out_col, .. },
..
},
) in partition_window_agg_states
{
if running_length < len_to_show {
let n_to_use = min(len_to_show - running_length, out_col.len());
let slice_to_use = out_col.slice(0, n_to_use);
result = Some(match result {
Some(arr) => concat(&[&arr, &slice_to_use])?,
None => slice_to_use,
});
running_length += n_to_use;
} else {
break;
}
}
if running_length != len_to_show {
return Err(DataFusionError::Execution(format!(
"Generated row number should be {len_to_show}, it is {running_length}"
)));
}
result
.ok_or_else(|| DataFusionError::Execution("Should contain something".to_string()))
}