use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::window_frame_state::WindowFrameContext;
use crate::{PhysicalExpr, PhysicalSortExpr};
use arrow::array::{new_empty_array, ArrayRef};
use arrow::compute::kernels::partition::lexicographical_partition_ranges;
use arrow::compute::kernels::sort::SortColumn;
use arrow::compute::{concat, SortOptions};
use arrow::datatypes::Field;
use arrow::record_batch::RecordBatch;
use arrow_schema::DataType;
use datafusion_common::{reverse_sort_options, DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, WindowFrame};
use indexmap::IndexMap;
use std::any::Any;
use std::fmt::Debug;
use std::ops::Range;
use std::sync::Arc;
/// Common behavior shared by every physical window expression.
///
/// Implementations supply the function's arguments, PARTITION BY and
/// ORDER BY clauses; the provided methods derive the columnar inputs a
/// window evaluation needs from a `RecordBatch`.
pub trait WindowExpr: Send + Sync + Debug {
    /// Casts this expression to [`Any`] so callers can downcast to the
    /// concrete implementation.
    fn as_any(&self) -> &dyn Any;

    /// The output field (name and data type) produced by this expression.
    fn field(&self) -> Result<Field>;

    /// Human-readable name, used in error messages (see `evaluate_stateful`).
    fn name(&self) -> &str {
        "WindowExpr: default name"
    }

    /// The argument expressions of the window function.
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;

    /// Evaluates every argument expression against `batch`, materializing
    /// each result as an array of `batch.num_rows()` rows.
    fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
        let num_rows = batch.num_rows();
        self.expressions()
            .iter()
            .map(|arg| Ok(arg.evaluate(batch)?.into_array(num_rows)))
            .collect()
    }

    /// Evaluates the window function over the whole (complete) batch.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;

    /// Incremental (streaming) evaluation; unsupported unless overridden.
    fn evaluate_stateful(
        &self,
        _partition_batches: &PartitionBatches,
        _window_agg_state: &mut PartitionWindowAggStates,
    ) -> Result<()> {
        Err(DataFusionError::Internal(format!(
            "evaluate_stateful is not implemented for {}",
            self.name()
        )))
    }

    /// Splits `num_rows` rows into ranges of equal PARTITION BY values.
    /// With no partition columns, the whole batch is a single partition.
    fn evaluate_partition_points(
        &self,
        num_rows: usize,
        partition_columns: &[SortColumn],
    ) -> Result<Vec<Range<usize>>> {
        if partition_columns.is_empty() {
            return Ok(vec![0..num_rows]);
        }
        let ranges = lexicographical_partition_ranges(partition_columns)
            .map_err(DataFusionError::ArrowError)?;
        Ok(ranges.collect())
    }

    /// The PARTITION BY expressions of this window expression.
    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>];

    /// The ORDER BY expressions of this window expression.
    fn order_by(&self) -> &[PhysicalSortExpr];

    /// Materializes each ORDER BY expression as a `SortColumn` over `batch`.
    fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
        self.order_by()
            .iter()
            .map(|sort_expr| sort_expr.evaluate_to_sort_column(batch))
            .collect()
    }

    /// Sort columns for this expression; currently exactly the ORDER BY
    /// columns.
    fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
        self.order_by_columns(batch)
    }

    /// Returns the argument arrays together with the ORDER BY value arrays
    /// for `record_batch`, in that order.
    fn get_values_orderbys(
        &self,
        record_batch: &RecordBatch,
    ) -> Result<(Vec<ArrayRef>, Vec<ArrayRef>)> {
        let values = self.evaluate_args(record_batch)?;
        let order_bys = self
            .order_by_columns(record_batch)?
            .into_iter()
            .map(|sort_column| sort_column.values)
            .collect();
        Ok((values, order_bys))
    }

    /// The window frame (bounds) this expression is evaluated over.
    fn get_window_frame(&self) -> &Arc<WindowFrame>;

    /// Whether this expression supports bounded-memory (streaming)
    /// evaluation.
    fn uses_bounded_memory(&self) -> bool;

    /// A reversed-ORDER-BY equivalent of this expression, if one exists.
    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
}
/// Extension of [`WindowExpr`] for aggregate functions used with an OVER
/// clause; evaluation is driven by an [`Accumulator`].
pub trait AggregateWindowExpr: WindowExpr {
    /// Returns a fresh accumulator for this aggregate.
    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>>;

    /// Computes the aggregate value for `cur_range`, given that
    /// `accumulator` currently reflects the rows in `last_range`.
    /// Implementations may update the accumulator incrementally from the
    /// difference between the two ranges.
    fn get_aggregate_result_inside_range(
        &self,
        last_range: &Range<usize>,
        cur_range: &Range<usize>,
        value_slice: &[ArrayRef],
        accumulator: &mut Box<dyn Accumulator>,
    ) -> Result<ScalarValue>;

    /// Evaluates the aggregate window function over an entire (complete)
    /// batch, producing one output value per input row.
    fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        let mut window_frame_ctx = WindowFrameContext::new(self.get_window_frame());
        let mut accumulator = self.get_accumulator()?;
        let mut last_range = Range { start: 0, end: 0 };
        let mut idx = 0;
        // `not_end = false`: the batch is complete, so every row can be
        // finalized in this single pass.
        self.get_result_column(
            &mut accumulator,
            batch,
            &mut window_frame_ctx,
            &mut last_range,
            &mut idx,
            false,
        )
    }

    /// Streaming evaluation: advances the per-partition state in
    /// `window_agg_state` using the rows buffered in `partition_batches`,
    /// appending results only for rows whose frame is already complete.
    fn aggregate_evaluate_stateful(
        &self,
        partition_batches: &PartitionBatches,
        window_agg_state: &mut PartitionWindowAggStates,
    ) -> Result<()> {
        let field = self.field()?;
        let out_type = field.data_type();
        for (partition_row, partition_batch_state) in partition_batches.iter() {
            // Lazily create the state (and accumulator) for partitions seen
            // for the first time.
            if !window_agg_state.contains_key(partition_row) {
                let accumulator = self.get_accumulator()?;
                window_agg_state.insert(
                    partition_row.clone(),
                    WindowState {
                        state: WindowAggState::new(out_type)?,
                        window_fn: WindowFn::Aggregate(accumulator),
                    },
                );
            }
            let window_state =
                window_agg_state.get_mut(partition_row).ok_or_else(|| {
                    DataFusionError::Execution("Cannot find state".to_string())
                })?;
            let accumulator = match &mut window_state.window_fn {
                WindowFn::Aggregate(accumulator) => accumulator,
                // This trait only ever stores `WindowFn::Aggregate` entries
                // (see the insert above).
                _ => unreachable!(),
            };
            let state = &mut window_state.state;
            let record_batch = &partition_batch_state.record_batch;
            let mut window_frame_ctx = WindowFrameContext::new(self.get_window_frame());
            // `not_end = !is_end`: while the partition may still grow, rows
            // whose frame reaches the end of the buffered data are deferred.
            let out_col = self.get_result_column(
                accumulator,
                record_batch,
                &mut window_frame_ctx,
                &mut state.window_frame_range,
                &mut state.last_calculated_index,
                !partition_batch_state.is_end,
            )?;
            state.is_end = partition_batch_state.is_end;
            state.out_col = concat(&[&state.out_col, &out_col])?;
            state.n_row_result_missing =
                record_batch.num_rows() - state.last_calculated_index;
        }
        Ok(())
    }

    /// Computes output values for rows starting at `*idx`, advancing `*idx`
    /// past every row that was finalized and `last_range` to the last frame
    /// evaluated. When `not_end` is true, rows whose frame touches the end
    /// of the batch are left for a later call, since more input may arrive.
    fn get_result_column(
        &self,
        accumulator: &mut Box<dyn Accumulator>,
        record_batch: &RecordBatch,
        window_frame_ctx: &mut WindowFrameContext,
        last_range: &mut Range<usize>,
        idx: &mut usize,
        not_end: bool,
    ) -> Result<ArrayRef> {
        let (values, order_bys) = self.get_values_orderbys(record_batch)?;
        // NOTE(review): indexes values[0]; assumes every aggregate window
        // expression has at least one argument column — confirm for
        // zero-argument aggregates.
        let length = values[0].len();
        let sort_options: Vec<SortOptions> =
            self.order_by().iter().map(|o| o.options).collect();
        let mut row_wise_results: Vec<ScalarValue> = vec![];
        while *idx < length {
            let cur_range = window_frame_ctx.calculate_range(
                &order_bys,
                &sort_options,
                length,
                *idx,
                last_range,
            )?;
            // Defer rows whose frame extends to the (possibly still growing)
            // end of the batch.
            if cur_range.end == length && not_end {
                break;
            }
            let value = self.get_aggregate_result_inside_range(
                last_range,
                &cur_range,
                &values,
                accumulator,
            )?;
            last_range.clone_from(&cur_range);
            row_wise_results.push(value);
            *idx += 1;
        }
        if row_wise_results.is_empty() {
            // No row could be finalized: emit an empty, correctly typed array.
            let field = self.field()?;
            let out_type = field.data_type();
            Ok(new_empty_array(out_type))
        } else {
            ScalarValue::iter_to_array(row_wise_results.into_iter())
        }
    }
}
/// Builds the reverse of the given ORDER BY expressions: same expressions,
/// with each one's sort options reversed.
pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec<PhysicalSortExpr> {
    let mut reversed = Vec::with_capacity(order_bys.len());
    for sort_expr in order_bys {
        reversed.push(PhysicalSortExpr {
            expr: sort_expr.expr.clone(),
            options: reverse_sort_options(sort_expr.options),
        });
    }
    reversed
}
/// The evaluator stored per partition: either a built-in window function's
/// [`PartitionEvaluator`] or an aggregate's [`Accumulator`].
#[derive(Debug)]
pub enum WindowFn {
    /// Built-in window function evaluator.
    Builtin(Box<dyn PartitionEvaluator>),
    /// Aggregate function used as a window function.
    Aggregate(Box<dyn Accumulator>),
}
/// Incremental state for RANK-family window functions.
#[derive(Debug, Clone, Default)]
pub struct RankState {
    /// ORDER BY values of the most recently ranked row; a change in these
    /// values starts a new rank — NOTE(review): inferred from field names,
    /// confirm against the Rank implementation.
    pub last_rank_data: Vec<ScalarValue>,
    /// Row index where the current rank (run of equal ORDER BY values) began.
    pub last_rank_boundary: usize,
    /// Number of distinct ranks seen so far.
    pub n_rank: usize,
}
/// Incremental state for row-counting window functions (e.g. ROW_NUMBER).
#[derive(Debug, Clone, Default)]
pub struct NumRowsState {
    /// Number of rows processed so far.
    pub n_rows: usize,
}
/// Which row of the frame the NTH_VALUE family of functions selects.
#[derive(Debug, Copy, Clone)]
pub enum NthValueKind {
    /// FIRST_VALUE: the first row of the frame.
    First,
    /// LAST_VALUE: the last row of the frame.
    Last,
    /// NTH_VALUE(n): the n-th row of the frame — NOTE(review): whether `n`
    /// is 1-based is not visible here; confirm in the NthValue impl.
    Nth(u32),
}
/// Incremental state for the NTH_VALUE family of window functions.
#[derive(Debug, Clone)]
pub struct NthValueState {
    /// Frame range most recently evaluated.
    pub range: Range<usize>,
    /// Once the result for the partition is known not to change with more
    /// input, it is cached here — NOTE(review): inferred from the name;
    /// confirm against the NthValue implementation.
    pub finalized_result: Option<ScalarValue>,
    /// Which row (first / last / n-th) this instance selects.
    pub kind: NthValueKind,
}
/// Incremental state for LEAD/LAG window functions.
#[derive(Debug, Clone, Default)]
pub struct LeadLagState {
    /// Index of the next row to produce a result for.
    pub idx: usize,
}
/// Union of the per-function states a built-in window function may keep
/// between streaming evaluation calls.
#[derive(Debug, Clone, Default)]
pub enum BuiltinWindowState {
    /// State for RANK-family functions.
    Rank(RankState),
    /// State for row-counting functions.
    NumRows(NumRowsState),
    /// State for NTH_VALUE / FIRST_VALUE / LAST_VALUE.
    NthValue(NthValueState),
    /// State for LEAD/LAG.
    LeadLag(LeadLagState),
    /// For functions that keep no state.
    #[default]
    Default,
}
/// Evaluation state of a window function for a single partition, carried
/// across streaming `evaluate_stateful` calls.
#[derive(Debug)]
pub struct WindowAggState {
    /// The frame range most recently evaluated (the rows the accumulator
    /// currently reflects).
    pub window_frame_range: Range<usize>,
    /// Index of the first row whose result has not yet been calculated.
    pub last_calculated_index: usize,
    /// NOTE(review): not updated in this file — presumably the number of
    /// already-emitted rows pruned from the partition buffer; confirm
    /// against the pruning logic elsewhere.
    pub offset_pruned_rows: usize,
    /// Results produced so far, concatenated across calls.
    pub out_col: ArrayRef,
    /// Rows in the buffered batch still lacking a result
    /// (`record_batch.num_rows() - last_calculated_index`).
    pub n_row_result_missing: usize,
    /// Whether the partition's input is complete.
    pub is_end: bool,
}
/// Input rows buffered for one partition during streaming execution.
#[derive(Debug)]
pub struct PartitionBatchState {
    /// Rows of this partition received so far.
    pub record_batch: RecordBatch,
    /// True when no more rows will arrive for this partition.
    pub is_end: bool,
}
/// Key identifying a partition: one `ScalarValue` per PARTITION BY expression.
pub type PartitionKey = Vec<ScalarValue>;
/// A partition's evaluation state paired with its function evaluator.
#[derive(Debug)]
pub struct WindowState {
    /// Progress/result bookkeeping for the partition.
    pub state: WindowAggState,
    /// The evaluator (built-in or aggregate) driving the computation.
    pub window_fn: WindowFn,
}
/// Per-partition window-function state, keyed by partition values.
/// `IndexMap` preserves the order in which partitions were first seen.
pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
/// Buffered input rows for each partition, keyed by partition values.
pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
impl WindowAggState {
pub fn new(out_type: &DataType) -> Result<Self> {
let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0);
Ok(Self {
window_frame_range: Range { start: 0, end: 0 },
last_calculated_index: 0,
offset_pruned_rows: 0,
out_col: empty_out_col,
n_row_result_missing: 0,
is_end: false,
})
}
}