use crate::window::partition_evaluator::find_ranges_in_range;
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
use crate::{window::WindowExpr, AggregateExpr};
use arrow::compute::concat;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_expr::Accumulator;
use datafusion_expr::{WindowFrame, WindowFrameUnits};
use std::any::Any;
use std::iter::IntoIterator;
use std::ops::Range;
use std::sync::Arc;
#[derive(Debug)]
pub struct AggregateWindowExpr {
aggregate: Arc<dyn AggregateExpr>,
partition_by: Vec<Arc<dyn PhysicalExpr>>,
order_by: Vec<PhysicalSortExpr>,
window_frame: Option<WindowFrame>,
}
impl AggregateWindowExpr {
pub fn new(
aggregate: Arc<dyn AggregateExpr>,
partition_by: &[Arc<dyn PhysicalExpr>],
order_by: &[PhysicalSortExpr],
window_frame: Option<WindowFrame>,
) -> Self {
Self {
aggregate,
partition_by: partition_by.to_vec(),
order_by: order_by.to_vec(),
window_frame,
}
}
fn evaluation_mode(&self) -> WindowFrameUnits {
self.window_frame.unwrap_or_default().units
}
fn create_accumulator(&self) -> Result<AggregateWindowAccumulator> {
let accumulator = self.aggregate.create_accumulator()?;
Ok(AggregateWindowAccumulator { accumulator })
}
fn peer_based_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let num_rows = batch.num_rows();
let partition_points =
self.evaluate_partition_points(num_rows, &self.partition_columns(batch)?)?;
let sort_partition_points =
self.evaluate_partition_points(num_rows, &self.sort_columns(batch)?)?;
let values = self.evaluate_args(batch)?;
let results = partition_points
.iter()
.map(|partition_range| {
let sort_partition_points =
find_ranges_in_range(partition_range, &sort_partition_points);
let mut window_accumulators = self.create_accumulator()?;
sort_partition_points
.iter()
.map(|range| window_accumulators.scan_peers(&values, range))
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<Vec<ArrayRef>>>>()?
.into_iter()
.flatten()
.collect::<Vec<ArrayRef>>();
let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
concat(&results).map_err(DataFusionError::ArrowError)
}
fn group_based_evaluate(&self, _batch: &RecordBatch) -> Result<ArrayRef> {
Err(DataFusionError::NotImplemented(format!(
"Group based evaluation for {} is not yet implemented",
self.name()
)))
}
fn row_based_evaluate(&self, _batch: &RecordBatch) -> Result<ArrayRef> {
Err(DataFusionError::NotImplemented(format!(
"Row based evaluation for {} is not yet implemented",
self.name()
)))
}
}
impl WindowExpr for AggregateWindowExpr {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.aggregate.name()
}
fn field(&self) -> Result<Field> {
self.aggregate.field()
}
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
self.aggregate.expressions()
}
fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
&self.partition_by
}
fn order_by(&self) -> &[PhysicalSortExpr] {
&self.order_by
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
match self.evaluation_mode() {
WindowFrameUnits::Range => self.peer_based_evaluate(batch),
WindowFrameUnits::Rows => self.row_based_evaluate(batch),
WindowFrameUnits::Groups => self.group_based_evaluate(batch),
}
}
}
#[derive(Debug)]
struct AggregateWindowAccumulator {
accumulator: Box<dyn Accumulator>,
}
impl AggregateWindowAccumulator {
fn scan_peers(
&mut self,
values: &[ArrayRef],
value_range: &Range<usize>,
) -> Result<ArrayRef> {
if value_range.is_empty() {
return Err(DataFusionError::Internal(
"Value range cannot be empty".to_owned(),
));
}
let len = value_range.end - value_range.start;
let values = values
.iter()
.map(|v| v.slice(value_range.start, len))
.collect::<Vec<_>>();
self.accumulator.update_batch(&values)?;
let value = self.accumulator.evaluate()?;
Ok(value.to_array_of_size(len))
}
}