use std::any::Any;
use std::ops::Range;
use std::sync::Arc;
use crate::aggregate::AggregateFunctionExpr;
use crate::window::standard::add_new_ordering_expr_with_partition_by;
use crate::window::window_expr::{AggregateWindowExpr, WindowFn, filter_array};
use crate::window::{
PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
};
use crate::{EquivalenceProperties, PhysicalExpr};
use arrow::array::ArrayRef;
use arrow::array::BooleanArray;
use arrow::datatypes::FieldRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{Result, ScalarValue, exec_datafusion_err};
use datafusion_expr::{Accumulator, WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
/// A window expression that evaluates an aggregate function over a window
/// frame by only ever feeding new rows to its accumulator (see
/// `get_aggregate_result_inside_range`, which calls `update_batch` but never
/// retracts rows).
///
/// NOTE(review): presumably meant for frames whose start never moves forward;
/// `get_reverse_expr` picks this type only when the reversed frame reports
/// `is_ever_expanding()` — confirm against `SlidingAggregateWindowExpr`.
#[derive(Debug)]
pub struct PlainAggregateWindowExpr {
/// The aggregate function being evaluated; it supplies the output field,
/// the name, the input expressions and the accumulators.
aggregate: Arc<AggregateFunctionExpr>,
/// Partitioning expressions, exposed through `partition_by()`.
partition_by: Vec<Arc<dyn PhysicalExpr>>,
/// Ordering expressions, exposed through `order_by()`.
order_by: Vec<PhysicalSortExpr>,
/// The window frame, exposed through `get_window_frame()`.
window_frame: Arc<WindowFrame>,
/// Cached result of `is_window_constant_in_partition`, computed once in
/// `new` from `order_by` and `window_frame`.
is_constant_in_partition: bool,
/// Optional filter expression, exposed through `filter_expr()`.
/// NOTE(review): presumably the SQL `FILTER (WHERE ...)` predicate — confirm.
filter: Option<Arc<dyn PhysicalExpr>>,
}
impl PlainAggregateWindowExpr {
    /// Builds a new aggregate window expression.
    ///
    /// The partitioning and ordering slices are copied into owned vectors.
    /// Whether the frame is constant within a partition is derived once here
    /// from `order_by` and `window_frame` and cached on the value.
    pub fn new(
        aggregate: Arc<AggregateFunctionExpr>,
        partition_by: &[Arc<dyn PhysicalExpr>],
        order_by: &[PhysicalSortExpr],
        window_frame: Arc<WindowFrame>,
        filter: Option<Arc<dyn PhysicalExpr>>,
    ) -> Self {
        Self {
            // Evaluated first, while `window_frame` is still only borrowed.
            is_constant_in_partition: Self::is_window_constant_in_partition(
                order_by,
                &window_frame,
            ),
            aggregate,
            partition_by: partition_by.to_vec(),
            order_by: order_by.to_vec(),
            window_frame,
            filter,
        }
    }

    /// Returns a reference to the wrapped aggregate function expression.
    pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
        self.aggregate.as_ref()
    }

    /// Registers the ordering of this expression's result, if the aggregate
    /// reports one for `window_expr_index`, into `eq_properties` together
    /// with this expression's partitioning expressions.
    ///
    /// Does nothing (and returns `Ok(())`) when the aggregate reports no
    /// result ordering.
    ///
    /// # Errors
    /// Propagates any error from `add_new_ordering_expr_with_partition_by`.
    pub fn add_equal_orderings(
        &self,
        eq_properties: &mut EquivalenceProperties,
        window_expr_index: usize,
    ) -> Result<()> {
        let Some(result_ordering) = self
            .get_aggregate_expr()
            .get_result_ordering(window_expr_index)
        else {
            return Ok(());
        };
        add_new_ordering_expr_with_partition_by(
            eq_properties,
            result_ordering,
            &self.partition_by,
        )?;
        Ok(())
    }

    /// Returns `true` when both frame bounds are "constant", where a bound
    /// counts as constant if it is either:
    ///
    /// 1. unbounded, or
    /// 2. `CurrentRow` while the frame uses `Range` units and there is no
    ///    ordering expression.
    fn is_window_constant_in_partition(
        order_by: &[PhysicalSortExpr],
        window_frame: &WindowFrame,
    ) -> bool {
        // Condition under which a `CurrentRow` bound is considered constant.
        let current_row_is_constant =
            window_frame.units == WindowFrameUnits::Range && order_by.is_empty();
        let bound_is_constant = |bound: &WindowFrameBound| {
            if matches!(bound, WindowFrameBound::CurrentRow) {
                current_row_is_constant
            } else {
                bound.is_unbounded()
            }
        };

        bound_is_constant(&window_frame.start_bound)
            && bound_is_constant(&window_frame.end_bound)
    }
}
impl WindowExpr for PlainAggregateWindowExpr {
    /// Returns `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the output field reported by the underlying aggregate.
    fn field(&self) -> Result<FieldRef> {
        Ok(self.aggregate.field())
    }

    /// Returns the name reported by the underlying aggregate.
    fn name(&self) -> &str {
        self.aggregate.name()
    }

    /// Returns the input expressions of the underlying aggregate.
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.aggregate.expressions()
    }

    /// Evaluates the window expression over `batch` by delegating to
    /// `aggregate_evaluate`.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.aggregate_evaluate(batch)
    }

    /// Runs the shared stateful evaluation (`aggregate_evaluate_stateful`) and
    /// then, for every partition in `partition_batches`, advances the start of
    /// the stored window frame range when the frame start is unbounded.
    ///
    /// # Errors
    /// Returns an execution error if a partition present in
    /// `partition_batches` has no entry in `window_agg_state`, and propagates
    /// errors from the shared stateful evaluation.
    fn evaluate_stateful(
        &self,
        partition_batches: &PartitionBatches,
        window_agg_state: &mut PartitionWindowAggStates,
    ) -> Result<()> {
        self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?;

        // This expression only ever adds rows to its accumulator (it never
        // retracts), so the recorded range start can be moved up to just
        // before the range end.
        // NOTE(review): presumably this lets the caller discard rows that are
        // no longer needed and keep memory bounded — confirm with the
        // consumer of `window_frame_range`.
        for partition_row in partition_batches.keys() {
            let window_state = window_agg_state
                .get_mut(partition_row)
                .ok_or_else(|| exec_datafusion_err!("Cannot find state"))?;
            let state = &mut window_state.state;
            if self.window_frame.start_bound.is_unbounded() {
                // `saturating_sub` keeps the start at 0 for an empty range.
                state.window_frame_range.start =
                    state.window_frame_range.end.saturating_sub(1);
            }
        }
        Ok(())
    }

    /// Returns the partitioning expressions.
    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
        &self.partition_by
    }

    /// Returns the ordering expressions.
    fn order_by(&self) -> &[PhysicalSortExpr] {
        &self.order_by
    }

    /// Returns the window frame.
    fn get_window_frame(&self) -> &Arc<WindowFrame> {
        &self.window_frame
    }

    /// Builds the reversed counterpart of this expression, or `None` when the
    /// underlying aggregate has no reverse expression.
    ///
    /// The reversed expression uses the reversed aggregate, the same
    /// partitioning and filter, every ordering expression reversed, and the
    /// reversed window frame. It is a `PlainAggregateWindowExpr` when the
    /// reversed frame is ever-expanding and a `SlidingAggregateWindowExpr`
    /// otherwise.
    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
        self.aggregate.reverse_expr().map(|reverse_expr| {
            let reverse_expr = Arc::new(reverse_expr);
            let reverse_order_by: Vec<_> =
                self.order_by.iter().map(|e| e.reverse()).collect();
            // Reverse the frame once and reuse it for both the check and the
            // constructor argument.
            let reverse_window_frame = Arc::new(self.window_frame.reverse());

            // Both constructors take `partition_by` as a slice and copy it
            // themselves, so it is borrowed here rather than cloned.
            if reverse_window_frame.is_ever_expanding() {
                Arc::new(PlainAggregateWindowExpr::new(
                    reverse_expr,
                    &self.partition_by,
                    &reverse_order_by,
                    reverse_window_frame,
                    self.filter.clone(),
                )) as Arc<dyn WindowExpr>
            } else {
                Arc::new(SlidingAggregateWindowExpr::new(
                    reverse_expr,
                    &self.partition_by,
                    &reverse_order_by,
                    reverse_window_frame,
                    self.filter.clone(),
                )) as Arc<dyn WindowExpr>
            }
        })
    }

    /// Returns `true` unless the end bound of the frame is unbounded.
    fn uses_bounded_memory(&self) -> bool {
        !self.window_frame.end_bound.is_unbounded()
    }

    /// Creates the window function state: a fresh accumulator wrapped in
    /// `WindowFn::Aggregate`.
    ///
    /// # Errors
    /// Propagates any error from creating the accumulator.
    fn create_window_fn(&self) -> Result<WindowFn> {
        Ok(WindowFn::Aggregate(self.get_accumulator()?))
    }
}
impl AggregateWindowExpr for PlainAggregateWindowExpr {
    /// Creates a fresh accumulator from the underlying aggregate.
    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        self.aggregate.create_accumulator()
    }

    /// Returns the optional filter expression, if one was supplied.
    fn filter_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
        self.filter.as_ref()
    }

    /// Computes the aggregate value for the frame `cur_range`, given that the
    /// accumulator already holds the rows up to `last_range.end`.
    ///
    /// * When `cur_range` has equal start and end, the aggregate's default
    ///   value for its output data type is returned and the accumulator is
    ///   left untouched.
    /// * Otherwise the rows in `last_range.end..cur_range.end` of every array
    ///   in `value_slice` are fed to the accumulator, after being filtered by
    ///   the matching slice of `filter_mask` when one is given, and the
    ///   accumulator's current value is returned.
    ///
    /// # Errors
    /// Propagates errors from filtering, from `update_batch` and from
    /// `evaluate`.
    fn get_aggregate_result_inside_range(
        &self,
        last_range: &Range<usize>,
        cur_range: &Range<usize>,
        value_slice: &[ArrayRef],
        accumulator: &mut Box<dyn Accumulator>,
        filter_mask: Option<&BooleanArray>,
    ) -> Result<ScalarValue> {
        // Nothing to aggregate: fall back to the aggregate's default value.
        if cur_range.start == cur_range.end {
            return self
                .aggregate
                .default_value(self.aggregate.field().data_type());
        }

        // Rows that entered the frame since the previous evaluation.
        let first_new_row = last_range.end;
        let new_row_count = cur_range.end - first_new_row;

        // Only new rows are ever added; rows are never retracted here.
        if new_row_count > 0 {
            let mask_for_new_rows =
                filter_mask.map(|mask| mask.slice(first_new_row, new_row_count));

            let mut new_values: Vec<ArrayRef> = Vec::with_capacity(value_slice.len());
            for column in value_slice {
                let fresh = column.slice(first_new_row, new_row_count);
                let fresh = match &mask_for_new_rows {
                    Some(mask) => filter_array(&fresh, mask)?,
                    None => fresh,
                };
                new_values.push(fresh);
            }
            accumulator.update_batch(&new_values)?;
        }

        accumulator.evaluate()
    }

    /// Returns the cached flag computed at construction time that tells
    /// whether the window frame is constant within a partition.
    fn is_constant_in_partition(&self) -> bool {
        self.is_constant_in_partition
    }
}