use crate::window::partition_evaluator::PartitionEvaluator;
use crate::{PhysicalExpr, PhysicalSortExpr};
use arrow::compute::kernels::partition::lexicographical_partition_ranges;
use arrow::compute::kernels::sort::SortColumn;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
use arrow_schema::DataType;
use datafusion_common::{reverse_sort_options, DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, WindowFrame};
use indexmap::IndexMap;
use std::any::Any;
use std::fmt::Debug;
use std::ops::Range;
use std::sync::Arc;
/// A window function expression as seen by the physical window operators.
///
/// Implementors supply their argument expressions, PARTITION BY / ORDER BY
/// expressions and window frame. The provided (default) methods build on
/// those to evaluate arguments and ordering columns against a [`RecordBatch`].
pub trait WindowExpr: Send + Sync + Debug {
    /// Returns `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns the [`Field`] describing this expression's result.
    fn field(&self) -> Result<Field>;

    /// Human-readable name of the expression.
    ///
    /// The default [`WindowExpr::evaluate_stateful`] includes it in its
    /// error message.
    fn name(&self) -> &str {
        "WindowExpr: default name"
    }

    /// Returns the argument expressions of the window function.
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;

    /// Evaluates every argument expression (see [`WindowExpr::expressions`])
    /// against `batch`.
    ///
    /// Each evaluated value is turned into an array with
    /// `into_array(batch.num_rows())`. Collection short-circuits on, and
    /// returns, the first evaluation error.
    fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
        self.expressions()
            .iter()
            .map(|e| e.evaluate(batch))
            .map(|r| r.map(|v| v.into_array(batch.num_rows())))
            .collect()
    }

    /// Evaluates the window function over the whole of `batch`.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;

    /// Evaluates the window function incrementally.
    ///
    /// Implementations receive the buffered batches of each partition and
    /// may update the matching per-partition entries of `_window_agg_state`.
    ///
    /// # Errors
    /// The default implementation does not support stateful evaluation and
    /// always returns [`DataFusionError::Internal`] naming this expression.
    fn evaluate_stateful(
        &self,
        _partition_batches: &PartitionBatches,
        _window_agg_state: &mut PartitionWindowAggStates,
    ) -> Result<()> {
        Err(DataFusionError::Internal(format!(
            "evaluate_stateful is not implemented for {}",
            self.name()
        )))
    }

    /// Splits `num_rows` rows into contiguous partition ranges.
    ///
    /// With no partition columns the whole input is a single range
    /// `0..num_rows` (also when `num_rows == 0`). Otherwise the ranges come
    /// from arrow's `lexicographical_partition_ranges`, which expects
    /// `partition_columns` to already be sorted.
    ///
    /// # Errors
    /// Arrow errors from the partitioning kernel are wrapped in
    /// [`DataFusionError::ArrowError`].
    fn evaluate_partition_points(
        &self,
        num_rows: usize,
        partition_columns: &[SortColumn],
    ) -> Result<Vec<Range<usize>>> {
        if partition_columns.is_empty() {
            return Ok(vec![0..num_rows]);
        }
        let ranges = lexicographical_partition_ranges(partition_columns)
            .map_err(DataFusionError::ArrowError)?;
        Ok(ranges.collect())
    }

    /// Returns the PARTITION BY expressions.
    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>];

    /// Returns the ORDER BY sort expressions.
    fn order_by(&self) -> &[PhysicalSortExpr];

    /// Evaluates each ORDER BY expression against `batch` into a
    /// [`SortColumn`], returning the first error encountered.
    fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
        self.order_by()
            .iter()
            .map(|e| e.evaluate_to_sort_column(batch))
            .collect::<Result<Vec<SortColumn>>>()
    }

    /// Returns the sort columns for `batch`.
    ///
    /// This is currently identical to [`WindowExpr::order_by_columns`].
    fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
        self.order_by_columns(batch)
    }

    /// Evaluates both the function arguments and the ORDER BY values for
    /// `record_batch`.
    ///
    /// Returns `(argument_arrays, order_by_arrays)`. The arguments are
    /// evaluated first, matching the original evaluation order.
    fn get_values_orderbys(
        &self,
        record_batch: &RecordBatch,
    ) -> Result<(Vec<ArrayRef>, Vec<ArrayRef>)> {
        let values = self.evaluate_args(record_batch)?;
        // The sort columns are owned here, so move each `values` array out
        // instead of cloning it.
        let order_bys: Vec<ArrayRef> = self
            .order_by_columns(record_batch)?
            .into_iter()
            .map(|sort_column| sort_column.values)
            .collect();
        Ok((values, order_bys))
    }

    /// Returns this expression's [`WindowFrame`].
    fn get_window_frame(&self) -> &Arc<WindowFrame>;

    /// Returns `true` if the implementation reports that it can be evaluated
    /// with bounded memory.
    fn uses_bounded_memory(&self) -> bool;

    /// Returns an equivalent expression for reversed ordering, if the
    /// implementation provides one (`None` otherwise).
    ///
    /// NOTE(review): "equivalent under reversed ordering" is inferred from
    /// the name and from `reverse_order_bys`; confirm against implementors.
    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
}
/// Returns a new list of sort expressions with every sort direction flipped.
///
/// The expressions themselves are cloned unchanged. Only their sort options
/// are passed through [`reverse_sort_options`]. Input order is preserved.
pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec<PhysicalSortExpr> {
    let mut reversed = Vec::with_capacity(order_bys.len());
    for sort_expr in order_bys {
        let flipped = reverse_sort_options(sort_expr.options);
        reversed.push(PhysicalSortExpr {
            expr: sort_expr.expr.clone(),
            options: flipped,
        });
    }
    reversed
}
/// The evaluator backing a window expression: either a built-in window
/// function's [`PartitionEvaluator`] or an aggregate's [`Accumulator`].
///
/// One of these is stored per partition inside [`WindowState`].
#[derive(Debug)]
pub enum WindowFn {
    /// Evaluator for a built-in window function.
    Builtin(Box<dyn PartitionEvaluator>),
    /// Accumulator for an aggregate function used as a window function.
    Aggregate(Box<dyn Accumulator>),
}
/// Incremental state carried by `BuiltinWindowState::Rank`.
///
/// NOTE(review): the field meanings below are inferred from their names;
/// confirm them against the rank evaluator.
#[derive(Debug, Clone, Default)]
pub struct RankState {
    /// Presumably the ORDER BY values of the most recent rank group, used to
    /// detect when a new group starts.
    pub last_rank_data: Vec<ScalarValue>,
    /// Presumably the row index at which the current rank group begins.
    pub last_rank_boundary: usize,
    /// Presumably the number of rank groups seen so far.
    pub n_rank: usize,
}
/// Incremental state carried by `BuiltinWindowState::NumRows`.
#[derive(Debug, Clone, Default)]
pub struct NumRowsState {
    /// Row counter; presumably the number of rows processed so far (TODO
    /// confirm against the row-number evaluator).
    pub n_rows: usize,
}
/// Incremental state carried by `BuiltinWindowState::NthValue`.
#[derive(Debug, Clone, Default)]
pub struct NthValueState {
    /// Row range; presumably the most recent window frame range seen by the
    /// evaluator (TODO confirm).
    pub range: Range<usize>,
}
/// Incremental state carried by `BuiltinWindowState::LeadLag`.
#[derive(Debug, Clone, Default)]
pub struct LeadLagState {
    /// Row index; presumably the next row the evaluator will produce output
    /// for (TODO confirm against the lead/lag evaluator).
    pub idx: usize,
}
/// Incremental state for built-in window functions, one variant per
/// state shape.
///
/// `Default` yields the stateless [`BuiltinWindowState::Default`] variant.
#[derive(Debug, Clone, Default)]
pub enum BuiltinWindowState {
    /// State for rank-style functions.
    Rank(RankState),
    /// State holding a running row count.
    NumRows(NumRowsState),
    /// State for nth-value-style functions.
    NthValue(NthValueState),
    /// State for lead/lag-style functions.
    LeadLag(LeadLagState),
    /// No function-specific state.
    #[default]
    Default,
}
/// Function-specific incremental state stored in [`WindowAggState`].
#[derive(Debug)]
pub enum WindowFunctionState {
    /// State for an aggregate window function; presumably the accumulator's
    /// state values (TODO confirm).
    AggregateState(Vec<ScalarValue>),
    /// State for a built-in window function.
    BuiltinWindowState(BuiltinWindowState),
}
/// Per-partition bookkeeping for incremental (stateful) window evaluation.
///
/// NOTE(review): the field descriptions marked "presumably" are inferred
/// from names; `WindowAggState::new` starts all counters at 0, the range at
/// `0..0`, `out_col` empty and `is_end` false.
#[derive(Debug)]
pub struct WindowAggState {
    /// Presumably the row range of the most recently computed window frame.
    pub window_frame_range: Range<usize>,
    /// Presumably the index of the next row whose result must be computed.
    pub last_calculated_index: usize,
    /// Presumably the number of leading rows already pruned from the buffer.
    pub offset_pruned_rows: usize,
    /// Function-specific incremental state.
    pub window_function_state: WindowFunctionState,
    /// Output values produced so far; starts as an empty array of the
    /// output type.
    pub out_col: ArrayRef,
    /// Presumably the number of rows whose result has not been produced yet.
    pub n_row_result_missing: usize,
    /// End-of-input flag; presumably set once the partition has no more data.
    pub is_end: bool,
}
/// Buffered input rows for one partition (stored in [`PartitionBatches`]).
#[derive(Debug)]
pub struct PartitionBatchState {
    /// The buffered rows of this partition.
    pub record_batch: RecordBatch,
    /// Presumably `true` once no more rows will arrive for this partition.
    pub is_end: bool,
}
/// Key identifying a partition; presumably the evaluated PARTITION BY values
/// of its rows (TODO confirm).
pub type PartitionKey = Vec<ScalarValue>;
/// Everything kept per partition for one window expression: its
/// bookkeeping state and the evaluator that produces results.
#[derive(Debug)]
pub struct WindowState {
    /// Incremental bookkeeping for this partition.
    pub state: WindowAggState,
    /// Evaluator (built-in or aggregate) for this partition.
    pub window_fn: WindowFn,
}
/// Per-partition window state, keyed by [`PartitionKey`]; `IndexMap` keeps
/// insertion order.
pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
/// Per-partition buffered input, keyed by [`PartitionKey`]; `IndexMap` keeps
/// insertion order.
pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
impl WindowAggState {
pub fn new(
out_type: &DataType,
window_function_state: WindowFunctionState,
) -> Result<Self> {
let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0);
Ok(Self {
window_frame_range: Range { start: 0, end: 0 },
last_calculated_index: 0,
offset_pruned_rows: 0,
window_function_state,
out_col: empty_out_col,
n_row_result_missing: 0,
is_end: false,
})
}
}