use std::any::Any;
use std::fmt::Debug;
use std::ops::Range;
use std::sync::Arc;
use crate::PhysicalExpr;
use arrow::array::BooleanArray;
use arrow::array::{Array, ArrayRef, new_empty_array};
use arrow::compute::SortOptions;
use arrow::compute::filter as arrow_filter;
use arrow::compute::kernels::sort::SortColumn;
use arrow::datatypes::FieldRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::utils::compare_rows;
use datafusion_common::{
Result, ScalarValue, arrow_datafusion_err, exec_datafusion_err, internal_err,
};
use datafusion_expr::window_state::{
PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups,
};
use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound};
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
use indexmap::IndexMap;
/// Common interface implemented by every physical window expression.
pub trait WindowExpr: Send + Sync + Debug {
    /// Returns `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns the output field of this window expression.
    fn field(&self) -> Result<FieldRef>;

    /// Human-readable name of the expression. The default is a placeholder
    /// that implementors are expected to override.
    fn name(&self) -> &str {
        "WindowExpr: default name"
    }

    /// Argument expressions of the window function.
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;

    /// Evaluates every argument expression against `batch`, producing one
    /// array per argument in the order returned by [`Self::expressions`].
    fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
        let arg_exprs = self.expressions();
        evaluate_expressions_to_arrays(&arg_exprs, batch)
    }

    /// Evaluates the expression over a complete batch in one call.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;

    /// Evaluates the expression incrementally from per-partition buffered
    /// input, updating the per-partition state map in place.
    ///
    /// The default implementation returns an internal error; implementors that
    /// support incremental evaluation override it.
    fn evaluate_stateful(
        &self,
        _partition_batches: &PartitionBatches,
        _window_agg_state: &mut PartitionWindowAggStates,
    ) -> Result<()> {
        internal_err!("evaluate_stateful is not implemented for {}", self.name())
    }

    /// `PARTITION BY` expressions of the window.
    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>];

    /// `ORDER BY` expressions of the window, with their sort options.
    fn order_by(&self) -> &[PhysicalSortExpr];

    /// Evaluates each `ORDER BY` expression against `batch` into a
    /// [`SortColumn`], stopping at the first error.
    fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
        let sort_exprs = self.order_by();
        let mut columns = Vec::with_capacity(sort_exprs.len());
        for sort_expr in sort_exprs {
            columns.push(sort_expr.evaluate_to_sort_column(batch)?);
        }
        Ok(columns)
    }

    /// Frame specification (ROWS / RANGE / GROUPS and its bounds).
    fn get_window_frame(&self) -> &Arc<WindowFrame>;

    /// Reports whether this expression can be evaluated with bounded memory
    /// (exact criteria are defined by implementors).
    fn uses_bounded_memory(&self) -> bool;

    /// Returns the expression to use when the input ordering is reversed,
    /// presumably `None` when no such equivalent exists.
    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;

    /// Creates the runtime evaluator ([`WindowFn`]) for this expression.
    fn create_window_fn(&self) -> Result<WindowFn>;

    /// Collects the argument, `PARTITION BY` and `ORDER BY` expressions.
    /// Sort options are dropped from the `ORDER BY` list.
    fn all_expressions(&self) -> WindowPhysicalExpressions {
        WindowPhysicalExpressions {
            args: self.expressions(),
            partition_by_exprs: self.partition_by().to_vec(),
            order_by_exprs: self
                .order_by()
                .iter()
                .map(|sort_expr| Arc::clone(&sort_expr.expr))
                .collect(),
        }
    }

    /// Returns a copy of this expression with its expressions replaced.
    /// The default returns `None`, meaning replacement is not supported.
    fn with_new_expressions(
        &self,
        _args: Vec<Arc<dyn PhysicalExpr>>,
        _partition_bys: Vec<Arc<dyn PhysicalExpr>>,
        _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Option<Arc<dyn WindowExpr>> {
        None
    }
}
/// The physical expressions referenced by a [`WindowExpr`], grouped by role.
///
/// Produced by [`WindowExpr::all_expressions`]. The three lists correspond to
/// the arguments of [`WindowExpr::with_new_expressions`].
#[derive(Debug, Clone)]
pub struct WindowPhysicalExpressions {
    /// Arguments of the window function.
    pub args: Vec<Arc<dyn PhysicalExpr>>,
    /// `PARTITION BY` expressions.
    pub partition_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
    /// `ORDER BY` expressions, without their sort options.
    pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
}
pub trait AggregateWindowExpr: WindowExpr {
fn get_accumulator(&self) -> Result<Box<dyn Accumulator>>;
fn filter_expr(&self) -> Option<&Arc<dyn PhysicalExpr>>;
fn get_aggregate_result_inside_range(
&self,
last_range: &Range<usize>,
cur_range: &Range<usize>,
value_slice: &[ArrayRef],
accumulator: &mut Box<dyn Accumulator>,
filter_mask: Option<&BooleanArray>,
) -> Result<ScalarValue>;
fn is_constant_in_partition(&self) -> bool;
fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let mut accumulator = self.get_accumulator()?;
let mut last_range = Range { start: 0, end: 0 };
let sort_options = self.order_by().iter().map(|o| o.options).collect();
let mut window_frame_ctx =
WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options);
self.get_result_column(
&mut accumulator,
batch,
None,
&mut last_range,
&mut window_frame_ctx,
0,
false,
)
}
fn aggregate_evaluate_stateful(
&self,
partition_batches: &PartitionBatches,
window_agg_state: &mut PartitionWindowAggStates,
) -> Result<()> {
let field = self.field()?;
let out_type = field.data_type();
for (partition_row, partition_batch_state) in partition_batches.iter() {
if !window_agg_state.contains_key(partition_row) {
let accumulator = self.get_accumulator()?;
window_agg_state.insert(
partition_row.clone(),
WindowState {
state: WindowAggState::new(out_type)?,
window_fn: WindowFn::Aggregate(accumulator),
},
);
};
let window_state = window_agg_state
.get_mut(partition_row)
.ok_or_else(|| exec_datafusion_err!("Cannot find state"))?;
let accumulator = match &mut window_state.window_fn {
WindowFn::Aggregate(accumulator) => accumulator,
_ => unreachable!(),
};
let state = &mut window_state.state;
let record_batch = &partition_batch_state.record_batch;
let most_recent_row = partition_batch_state.most_recent_row.as_ref();
let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| {
let sort_options = self.order_by().iter().map(|o| o.options).collect();
WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options)
});
let out_col = self.get_result_column(
accumulator,
record_batch,
most_recent_row,
&mut state.window_frame_range,
window_frame_ctx,
state.last_calculated_index,
!partition_batch_state.is_end,
)?;
state.update(&out_col, partition_batch_state)?;
}
Ok(())
}
#[expect(clippy::too_many_arguments)]
fn get_result_column(
&self,
accumulator: &mut Box<dyn Accumulator>,
record_batch: &RecordBatch,
most_recent_row: Option<&RecordBatch>,
last_range: &mut Range<usize>,
window_frame_ctx: &mut WindowFrameContext,
mut idx: usize,
not_end: bool,
) -> Result<ArrayRef> {
let values = self.evaluate_args(record_batch)?;
let filter_mask_arr: Option<ArrayRef> = match self.filter_expr() {
Some(expr) => {
let value = expr.evaluate(record_batch)?;
Some(value.into_array(record_batch.num_rows())?)
}
None => None,
};
let filter_mask: Option<&BooleanArray> = match filter_mask_arr.as_deref() {
Some(arr) => Some(as_boolean_array(arr)?),
None => None,
};
if self.is_constant_in_partition() {
if not_end {
let field = self.field()?;
let out_type = field.data_type();
return Ok(new_empty_array(out_type));
}
let values = if let Some(mask) = filter_mask {
filter_arrays(&values, mask)?
} else {
values
};
accumulator.update_batch(&values)?;
let value = accumulator.evaluate()?;
return value.to_array_of_size(record_batch.num_rows());
}
let order_bys = get_orderby_values(self.order_by_columns(record_batch)?);
let most_recent_row_order_bys = most_recent_row
.map(|batch| self.order_by_columns(batch))
.transpose()?
.map(get_orderby_values);
let length = values[0].len();
let mut row_wise_results: Vec<ScalarValue> = vec![];
let is_causal = self.get_window_frame().is_causal();
while idx < length {
let cur_range =
window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?;
if cur_range.end == length
&& !is_causal
&& not_end
&& !is_end_bound_safe(
window_frame_ctx,
&order_bys,
most_recent_row_order_bys.as_deref(),
self.order_by(),
idx,
)?
{
break;
}
let value = self.get_aggregate_result_inside_range(
last_range,
&cur_range,
&values,
accumulator,
filter_mask,
)?;
*last_range = cur_range;
row_wise_results.push(value);
idx += 1;
}
if row_wise_results.is_empty() {
let field = self.field()?;
let out_type = field.data_type();
Ok(new_empty_array(out_type))
} else {
ScalarValue::iter_to_array(row_wise_results)
}
}
}
/// Keeps the entries of `array` selected by `mask`, using arrow's filter
/// kernel. Any arrow error is converted into a DataFusion error.
pub(crate) fn filter_array(array: &ArrayRef, mask: &BooleanArray) -> Result<ArrayRef> {
    match arrow_filter(array.as_ref(), mask) {
        Ok(filtered) => Ok(filtered),
        Err(arrow_error) => Err(arrow_datafusion_err!(arrow_error)),
    }
}
/// Applies [`filter_array`] with the same `mask` to each array, preserving
/// order and stopping at the first error.
pub(crate) fn filter_arrays(
    arrays: &[ArrayRef],
    mask: &BooleanArray,
) -> Result<Vec<ArrayRef>> {
    let mut filtered = Vec::with_capacity(arrays.len());
    for array in arrays {
        filtered.push(filter_array(array, mask)?);
    }
    Ok(filtered)
}
/// Decides whether the result for row `idx` can be emitted now, even though
/// its frame currently ends at the last buffered row.
///
/// "Safe" means rows arriving later cannot extend the frame. Dispatches on
/// the frame kind, and only the leading `ORDER BY` column (and its sort
/// options) is consulted. Returns `false` when there is no `ORDER BY`.
pub(crate) fn is_end_bound_safe(
    window_frame_ctx: &WindowFrameContext,
    order_bys: &[ArrayRef],
    most_recent_order_bys: Option<&[ArrayRef]>,
    sort_exprs: &[PhysicalSortExpr],
    idx: usize,
) -> Result<bool> {
    // With no ORDER BY, conservatively report "not safe".
    let Some(leading_sort) = sort_exprs.first() else {
        return Ok(false);
    };
    let leading_options = &leading_sort.options;

    match window_frame_ctx {
        WindowFrameContext::Rows(frame) => is_end_bound_safe_for_rows(&frame.end_bound),
        WindowFrameContext::Range {
            window_frame: frame,
            ..
        } => is_end_bound_safe_for_range(
            &frame.end_bound,
            &order_bys[0],
            most_recent_order_bys.map(|cols| &cols[0]),
            leading_options,
            idx,
        ),
        WindowFrameContext::Groups {
            window_frame: frame,
            state,
        } => is_end_bound_safe_for_groups(
            &frame.end_bound,
            state,
            &order_bys[0],
            most_recent_order_bys.map(|cols| &cols[0]),
            leading_options,
        ),
    }
}
/// End-bound safety for ROWS frames.
///
/// Any bound other than `FOLLOWING` is safe. A `FOLLOWING` bound is safe only
/// when its offset equals zero. If a zero value cannot be built for the
/// offset's type, it is treated as not safe.
fn is_end_bound_safe_for_rows(end_bound: &WindowFrameBound) -> Result<bool> {
    let WindowFrameBound::Following(offset) = end_bound else {
        return Ok(true);
    };
    let offset_is_zero = match ScalarValue::new_zero(&offset.data_type()) {
        Ok(zero) => offset.eq(&zero),
        Err(_) => false,
    };
    Ok(offset_is_zero)
}
/// End-bound safety for RANGE frames.
///
/// * `PRECEDING n` with non-zero `n` is always safe.
/// * `PRECEDING 0` and `CURRENT ROW` are safe when the most recent row sorts
///   strictly after the last buffered row (see [`is_row_ahead`]).
/// * `FOLLOWING delta` is safe when the most recent row's order value lies
///   strictly beyond the current row's value shifted by `delta`. The shift is
///   `-delta` for descending order and `+delta` for ascending. Without a most
///   recent row it is not safe.
fn is_end_bound_safe_for_range(
    end_bound: &WindowFrameBound,
    orderby_col: &ArrayRef,
    most_recent_ob_col: Option<&ArrayRef>,
    sort_options: &SortOptions,
    idx: usize,
) -> Result<bool> {
    let delta = match end_bound {
        WindowFrameBound::CurrentRow => {
            return is_row_ahead(orderby_col, most_recent_ob_col, sort_options);
        }
        WindowFrameBound::Preceding(offset) => {
            let zero = ScalarValue::new_zero(&offset.data_type())?;
            return if offset.eq(&zero) {
                is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
            } else {
                Ok(true)
            };
        }
        WindowFrameBound::Following(delta) => delta,
    };

    let Some(recent_col) = most_recent_ob_col else {
        return Ok(false);
    };
    let most_recent = ScalarValue::try_from_array(recent_col, 0)?;
    let current = ScalarValue::try_from_array(orderby_col, idx)?;
    if sort_options.descending {
        let frame_end = current.sub(delta)?;
        Ok(frame_end > most_recent)
    } else {
        let frame_end = current.add(delta)?;
        Ok(most_recent > frame_end)
    }
}
/// End-bound safety for GROUPS frames.
///
/// * `PRECEDING n` with non-zero `n` is always safe.
/// * `PRECEDING 0` and `CURRENT ROW` defer to [`is_row_ahead`].
/// * `FOLLOWING n` with a `UInt64` offset is safe only when exactly `n + 1`
///   groups are tracked from the current group onward, and the most recent
///   row sorts strictly after the last buffered one.
/// * Any other bound is not safe.
fn is_end_bound_safe_for_groups(
    end_bound: &WindowFrameBound,
    state: &WindowFrameStateGroups,
    orderby_col: &ArrayRef,
    most_recent_ob_col: Option<&ArrayRef>,
    sort_options: &SortOptions,
) -> Result<bool> {
    match end_bound {
        WindowFrameBound::CurrentRow => {
            is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
        }
        WindowFrameBound::Preceding(offset) => {
            let zero = ScalarValue::new_zero(&offset.data_type())?;
            if !offset.eq(&zero) {
                return Ok(true);
            }
            is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
        }
        WindowFrameBound::Following(ScalarValue::UInt64(Some(offset))) => {
            let tracked_groups = state.group_end_indices.len() - state.current_group_idx;
            let required_groups = (*offset as usize) + 1;
            if tracked_groups != required_groups {
                return Ok(false);
            }
            is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
        }
        _ => Ok(false),
    }
}
/// Returns `true` when the first value of `current_col` sorts strictly after
/// the last value of `old_col` under `sort_options`.
///
/// Returns `false` when `current_col` is `None` or either column is empty.
fn is_row_ahead(
    old_col: &ArrayRef,
    current_col: Option<&ArrayRef>,
    sort_options: &SortOptions,
) -> Result<bool> {
    match current_col {
        Some(incoming) if !old_col.is_empty() && !incoming.is_empty() => {
            let previous_last = ScalarValue::try_from_array(old_col, old_col.len() - 1)?;
            let incoming_first = ScalarValue::try_from_array(incoming, 0)?;
            let ordering =
                compare_rows(&[incoming_first], &[previous_last], &[*sort_options])?;
            Ok(ordering.is_gt())
        }
        _ => Ok(false),
    }
}
pub(crate) fn get_orderby_values(order_by_columns: Vec<SortColumn>) -> Vec<ArrayRef> {
order_by_columns.into_iter().map(|s| s.values).collect()
}
/// Key identifying one partition. This is presumably the evaluated
/// `PARTITION BY` values of the partition's rows; confirm at the call sites.
pub type PartitionKey = Vec<ScalarValue>;

/// Buffered input for each partition, keyed by [`PartitionKey`].
pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;

/// Evaluation state of one window expression for each partition, keyed by
/// [`PartitionKey`].
pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;

/// Runtime evaluator backing a window expression.
#[derive(Debug)]
pub enum WindowFn {
    /// Evaluator for a built-in (non-aggregate) window function.
    Builtin(Box<dyn PartitionEvaluator>),
    /// Accumulator for an aggregate function used as a window function.
    Aggregate(Box<dyn Accumulator>),
}

/// Evaluation state of one window expression for a single partition.
#[derive(Debug)]
pub struct WindowState {
    /// Frame bookkeeping (frame context, last frame range, last calculated
    /// index), updated after each processed batch.
    pub state: WindowAggState,
    /// Evaluator carried across batches for this partition.
    pub window_fn: WindowFn,
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::window::window_expr::is_row_ahead;
    use arrow::array::{ArrayRef, Float64Array};
    use arrow::compute::SortOptions;
    use datafusion_common::Result;

    fn ascending() -> SortOptions {
        SortOptions {
            descending: false,
            nulls_first: false,
        }
    }

    fn descending() -> SortOptions {
        SortOptions {
            descending: true,
            nulls_first: false,
        }
    }

    fn floats(values: Vec<f64>) -> ArrayRef {
        Arc::new(Float64Array::from(values))
    }

    #[test]
    fn test_is_row_ahead() -> Result<()> {
        let old_values = floats(vec![5.0, 7.0, 8.0, 9., 10.]);
        let new_values1 = floats(vec![11.0]);
        let new_values2 = floats(vec![10.0]);
        // Strictly greater than the last buffered value: ahead.
        assert!(is_row_ahead(&old_values, Some(&new_values1), &ascending())?);
        // Equal to the last buffered value: not ahead.
        assert!(!is_row_ahead(&old_values, Some(&new_values2), &ascending())?);
        Ok(())
    }

    #[test]
    fn test_is_row_ahead_missing_or_empty_input() -> Result<()> {
        let old_values = floats(vec![5.0, 10.0]);
        let empty = floats(vec![]);
        // No most-recent row yet.
        assert!(!is_row_ahead(&old_values, None, &ascending())?);
        // Either side empty.
        assert!(!is_row_ahead(&old_values, Some(&empty), &ascending())?);
        assert!(!is_row_ahead(&empty, Some(&old_values), &ascending())?);
        Ok(())
    }

    #[test]
    fn test_is_row_ahead_descending() -> Result<()> {
        let old_values = floats(vec![10.0, 9.0, 8.0]);
        let smaller = floats(vec![7.0]);
        let equal = floats(vec![8.0]);
        let larger = floats(vec![9.5]);
        // In descending order a smaller value comes later, so it is ahead.
        assert!(is_row_ahead(&old_values, Some(&smaller), &descending())?);
        assert!(!is_row_ahead(&old_values, Some(&equal), &descending())?);
        assert!(!is_row_ahead(&old_values, Some(&larger), &descending())?);
        Ok(())
    }
}