use std::{collections::VecDeque, ops::Range, sync::Arc};
use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use arrow::{
array::ArrayRef,
compute::{concat, concat_batches, SortOptions},
datatypes::{DataType, SchemaRef},
record_batch::RecordBatch,
};
use datafusion_common::{
internal_err,
utils::{compare_rows, get_row_at_idx, search_in_slice},
DataFusionError, Result, ScalarValue,
};
/// Incremental evaluation state for one window aggregate over one partition.
///
/// Tracks how many rows already have results (`last_calculated_index`), the
/// results gathered so far (`out_col`), and how many leading rows have been
/// pruned away (`offset_pruned_rows`).
#[derive(Debug)]
pub struct WindowAggState {
/// Current window frame range (row indices); shifted down by `prune_state`
/// when leading rows are discarded.
pub window_frame_range: Range<usize>,
/// Frame-unit-specific context; `new` leaves it as `None`, and it is not
/// populated anywhere in this file (presumably installed by the caller).
pub window_frame_ctx: Option<WindowFrameContext>,
/// Index one past the last row whose result has been calculated; advanced by
/// `update` and shifted down by `prune_state`.
pub last_calculated_index: usize,
/// Running total of rows pruned from the front of the partition.
pub offset_pruned_rows: usize,
/// Results accumulated so far; `update` concatenates new values onto it.
pub out_col: ArrayRef,
/// Rows of the partition batch that still lack a result, as of the last `update`.
pub n_row_result_missing: usize,
/// Copy of `PartitionBatchState::is_end` taken at the last `update`.
pub is_end: bool,
}
impl WindowAggState {
pub fn prune_state(&mut self, n_prune: usize) {
self.window_frame_range = Range {
start: self.window_frame_range.start - n_prune,
end: self.window_frame_range.end - n_prune,
};
self.last_calculated_index -= n_prune;
self.offset_pruned_rows += n_prune;
match self.window_frame_ctx.as_mut() {
Some(WindowFrameContext::Rows(_)) => {}
Some(WindowFrameContext::Range { .. }) => {}
Some(WindowFrameContext::Groups { state, .. }) => {
let mut n_group_to_del = 0;
for (_, end_idx) in &state.group_end_indices {
if n_prune < *end_idx {
break;
}
n_group_to_del += 1;
}
state.group_end_indices.drain(0..n_group_to_del);
state
.group_end_indices
.iter_mut()
.for_each(|(_, start_idx)| *start_idx -= n_prune);
state.current_group_idx -= n_group_to_del;
}
None => {}
};
}
pub fn update(
&mut self,
out_col: &ArrayRef,
partition_batch_state: &PartitionBatchState,
) -> Result<()> {
self.last_calculated_index += out_col.len();
self.out_col = concat(&[&self.out_col, &out_col])?;
self.n_row_result_missing =
partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
self.is_end = partition_batch_state.is_end;
Ok(())
}
pub fn new(out_type: &DataType) -> Result<Self> {
let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0)?;
Ok(Self {
window_frame_range: Range { start: 0, end: 0 },
window_frame_ctx: None,
last_calculated_index: 0,
offset_pruned_rows: 0,
out_col: empty_out_col,
n_row_result_missing: 0,
is_end: false,
})
}
}
/// Per-unit state used to compute window frame boundaries: one variant for
/// each `WindowFrameUnits` value (`ROWS`, `RANGE`, `GROUPS`).
#[derive(Debug)]
pub enum WindowFrameContext {
/// `ROWS` frames: boundaries are plain row offsets, so only the frame is kept.
Rows(Arc<WindowFrame>),
/// `RANGE` frames: boundaries are found by shifting and comparing the values
/// of the range columns.
Range {
window_frame: Arc<WindowFrame>,
state: WindowFrameStateRange,
},
/// `GROUPS` frames: boundaries are counted in groups of adjacent rows with
/// equal range-column values.
Groups {
window_frame: Arc<WindowFrame>,
state: WindowFrameStateGroups,
},
}
impl WindowFrameContext {
pub fn new(window_frame: Arc<WindowFrame>, sort_options: Vec<SortOptions>) -> Self {
match window_frame.units {
WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame),
WindowFrameUnits::Range => WindowFrameContext::Range {
window_frame,
state: WindowFrameStateRange::new(sort_options),
},
WindowFrameUnits::Groups => WindowFrameContext::Groups {
window_frame,
state: WindowFrameStateGroups::default(),
},
}
}
pub fn calculate_range(
&mut self,
range_columns: &[ArrayRef],
last_range: &Range<usize>,
length: usize,
idx: usize,
) -> Result<Range<usize>> {
match self {
WindowFrameContext::Rows(window_frame) => {
Self::calculate_range_rows(window_frame, length, idx)
}
WindowFrameContext::Range {
window_frame,
ref mut state,
} => state.calculate_range(
window_frame,
last_range,
range_columns,
length,
idx,
),
WindowFrameContext::Groups {
window_frame,
ref mut state,
} => state.calculate_range(window_frame, range_columns, length, idx),
}
}
fn calculate_range_rows(
window_frame: &Arc<WindowFrame>,
length: usize,
idx: usize,
) -> Result<Range<usize>> {
let start = match window_frame.start_bound {
WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0,
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
idx.saturating_sub(n as usize)
}
WindowFrameBound::CurrentRow => idx,
WindowFrameBound::Following(ScalarValue::UInt64(None)) => {
return internal_err!(
"Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'"
)
}
WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
std::cmp::min(idx + n as usize, length)
}
WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
return internal_err!("Rows should be Uint")
}
};
let end = match window_frame.end_bound {
WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => {
return internal_err!(
"Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'"
)
}
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
if idx >= n as usize {
idx - n as usize + 1
} else {
0
}
}
WindowFrameBound::CurrentRow => idx + 1,
WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
std::cmp::min(idx + n as usize + 1, length)
}
WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
return internal_err!("Rows should be Uint")
}
};
Ok(Range { start, end })
}
}
/// Buffered input rows of one partition, plus bookkeeping about its progress.
#[derive(Debug)]
pub struct PartitionBatchState {
/// Buffered rows of this partition; `extend` concatenates new batches onto it.
pub record_batch: RecordBatch,
/// Last row recorded via `set_most_recent_row`, if any.
pub most_recent_row: Option<RecordBatch>,
/// End-of-partition flag. It starts as `false` and is copied into
/// `WindowAggState::is_end` by `update`. Presumably set once no more input can
/// arrive for this partition — TODO confirm with the caller.
pub is_end: bool,
/// Not read or written in this file beyond initialisation to 0; presumably the
/// number of output rows already emitted — verify against the caller.
pub n_out_row: usize,
}
impl PartitionBatchState {
    /// Creates a partition state with an empty buffered batch of `schema`.
    ///
    /// No most-recent row is recorded yet, `is_end` is `false`, and
    /// `n_out_row` is zero.
    pub fn new(schema: SchemaRef) -> Self {
        let record_batch = RecordBatch::new_empty(schema);
        Self {
            record_batch,
            most_recent_row: None,
            is_end: false,
            n_out_row: 0,
        }
    }

    /// Appends the rows of `batch` to the buffered `record_batch`.
    ///
    /// # Errors
    /// Returns an error if `concat_batches` fails to combine the two batches.
    pub fn extend(&mut self, batch: &RecordBatch) -> Result<()> {
        let schema = self.record_batch.schema();
        let merged = concat_batches(&schema, [&self.record_batch, batch])?;
        self.record_batch = merged;
        Ok(())
    }

    /// Stores `batch` as the most recent row, replacing any earlier one.
    pub fn set_most_recent_row(&mut self, batch: RecordBatch) {
        self.most_recent_row = Some(batch);
    }
}
/// State for `RANGE` frames: the sort options used to shift and compare row values.
#[derive(Debug, Default)]
pub struct WindowFrameStateRange {
/// Sort options of the range columns. They are passed to `compare_rows`, and
/// the first entry's `descending` flag decides whether a frame offset is added
/// or subtracted.
sort_options: Vec<SortOptions>,
}
impl WindowFrameStateRange {
    /// Creates a `RANGE` frame state that compares rows using `sort_options`.
    fn new(sort_options: Vec<SortOptions>) -> Self {
        Self { sort_options }
    }

    /// Returns the `[start, end)` row range of the `RANGE` frame for row `idx`.
    ///
    /// A null `PRECEDING` start offset means UNBOUNDED PRECEDING (`0`). A null
    /// `FOLLOWING` end offset means UNBOUNDED FOLLOWING (`length`). Every other
    /// bound is found by searching the range columns, starting from the
    /// corresponding edge of `last_range`.
    ///
    /// # Errors
    /// Propagates failures from reading rows, shifting values by the offset,
    /// comparing rows, or searching.
    fn calculate_range(
        &mut self,
        window_frame: &Arc<WindowFrame>,
        last_range: &Range<usize>,
        range_columns: &[ArrayRef],
        length: usize,
        idx: usize,
    ) -> Result<Range<usize>> {
        let start = match &window_frame.start_bound {
            // UNBOUNDED PRECEDING
            WindowFrameBound::Preceding(offset) if offset.is_null() => 0,
            WindowFrameBound::Preceding(offset) => self
                .calculate_index_of_row::<true, true>(
                    range_columns,
                    last_range,
                    idx,
                    Some(offset),
                    length,
                )?,
            WindowFrameBound::CurrentRow => self
                .calculate_index_of_row::<true, true>(
                    range_columns,
                    last_range,
                    idx,
                    None,
                    length,
                )?,
            WindowFrameBound::Following(offset) => self
                .calculate_index_of_row::<true, false>(
                    range_columns,
                    last_range,
                    idx,
                    Some(offset),
                    length,
                )?,
        };
        let end = match &window_frame.end_bound {
            WindowFrameBound::Preceding(offset) => self
                .calculate_index_of_row::<false, true>(
                    range_columns,
                    last_range,
                    idx,
                    Some(offset),
                    length,
                )?,
            WindowFrameBound::CurrentRow => self
                .calculate_index_of_row::<false, false>(
                    range_columns,
                    last_range,
                    idx,
                    None,
                    length,
                )?,
            // UNBOUNDED FOLLOWING
            WindowFrameBound::Following(offset) if offset.is_null() => length,
            WindowFrameBound::Following(offset) => self
                .calculate_index_of_row::<false, false>(
                    range_columns,
                    last_range,
                    idx,
                    Some(offset),
                    length,
                )?,
        };
        Ok(start..end)
    }

    /// Finds the row index at which a `RANGE` frame boundary falls.
    ///
    /// `SIDE` is `true` for the frame start and `false` for the frame end.
    /// `SEARCH_SIDE` selects the shift direction: `true` is used for
    /// `PRECEDING`, `false` for `FOLLOWING`.
    ///
    /// When `delta` is given, each non-null value of row `idx` is shifted by it.
    /// The value is added when `SEARCH_SIDE` equals the first sort option's
    /// `descending` flag, and subtracted otherwise. An unsigned value smaller
    /// than `delta` becomes a zero of its own type instead of being subtracted.
    ///
    /// The search starts at `last_range.start` (start side) or `last_range.end`
    /// (end side). It advances while rows compare `<` (start) or `<=` (end)
    /// against the target.
    ///
    /// # Errors
    /// Returns an internal error if no sort options are configured and `delta`
    /// is present. Propagates arithmetic, comparison and search errors.
    fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
        &mut self,
        range_columns: &[ArrayRef],
        last_range: &Range<usize>,
        idx: usize,
        delta: Option<&ScalarValue>,
        length: usize,
    ) -> Result<usize> {
        let row_values = get_row_at_idx(range_columns, idx)?;
        let target = match delta {
            None => row_values,
            Some(delta) => {
                let descending = match self.sort_options.first() {
                    Some(options) => options.descending,
                    None => {
                        return Err(DataFusionError::Internal(
                            "Sort options unexpectedly absent in a window frame".to_string(),
                        ))
                    }
                };
                let mut shifted = Vec::with_capacity(row_values.len());
                for value in &row_values {
                    let next = if value.is_null() {
                        value.clone()
                    } else if SEARCH_SIDE == descending {
                        value.add(delta)?
                    } else if value.is_unsigned() && value < delta {
                        // `value - value` yields a zero of the same scalar type.
                        value.sub(value)?
                    } else {
                        value.sub(delta)?
                    };
                    shifted.push(next);
                }
                shifted
            }
        };
        let search_start = if SIDE {
            last_range.start
        } else {
            last_range.end
        };
        let sort_options = &self.sort_options;
        let keep_searching =
            |current: &[ScalarValue], probe: &[ScalarValue]| -> Result<bool> {
                let ordering = compare_rows(current, probe, sort_options)?;
                Ok(if SIDE {
                    ordering.is_lt()
                } else {
                    ordering.is_le()
                })
            };
        search_in_slice(range_columns, &target, keep_searching, search_start, length)
    }
}
/// State for `GROUPS` frames: a cache of the group boundaries discovered so far.
#[derive(Debug, Default)]
pub struct WindowFrameStateGroups {
/// One entry per discovered group: the group's key row values and the
/// exclusive end row index of the group.
pub group_end_indices: VecDeque<(Vec<ScalarValue>, usize)>,
/// Position in `group_end_indices` of the group containing the row currently
/// being processed.
pub current_group_idx: usize,
}
impl WindowFrameStateGroups {
/// Computes the `[start, end)` row range of a `GROUPS` frame for the row at `idx`.
///
/// A null offset on a `PRECEDING` start or a `FOLLOWING` end is treated as
/// unbounded and maps to `0` / `length`. Every other bound is resolved by
/// `calculate_index_of_row`, which also extends the cached group boundaries.
///
/// # Errors
/// Propagates errors from `calculate_index_of_row`.
fn calculate_range(
&mut self,
window_frame: &Arc<WindowFrame>,
range_columns: &[ArrayRef],
length: usize,
idx: usize,
) -> Result<Range<usize>> {
let start = match window_frame.start_bound {
WindowFrameBound::Preceding(ref n) => {
if n.is_null() {
// UNBOUNDED PRECEDING: the frame starts at the first row.
0
} else {
self.calculate_index_of_row::<true, true>(
range_columns,
idx,
Some(n),
length,
)?
}
}
WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>(
range_columns,
idx,
None,
length,
)?,
WindowFrameBound::Following(ref n) => self
.calculate_index_of_row::<true, false>(
range_columns,
idx,
Some(n),
length,
)?,
};
let end = match window_frame.end_bound {
WindowFrameBound::Preceding(ref n) => self
.calculate_index_of_row::<false, true>(
range_columns,
idx,
Some(n),
length,
)?,
WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>(
range_columns,
idx,
None,
length,
)?,
WindowFrameBound::Following(ref n) => {
if n.is_null() {
// UNBOUNDED FOLLOWING: the frame runs to the last available row.
length
} else {
self.calculate_index_of_row::<false, false>(
range_columns,
idx,
Some(n),
length,
)?
}
}
};
Ok(Range { start, end })
}
/// Returns the row index of a `GROUPS` frame boundary for row `idx`.
///
/// `SIDE` is `true` for the frame start and `false` for the frame end.
/// `SEARCH_SIDE` is `true` when moving `delta` groups backwards (PRECEDING) and
/// `false` when moving forwards (FOLLOWING). A `None` delta means 0 groups
/// (CURRENT ROW).
///
/// Side effects: extends `group_end_indices` with newly discovered groups and
/// advances `current_group_idx` to the group that contains `idx`.
///
/// Assumes rows with equal range-column values are adjacent, presumably
/// because the input is sorted on those columns — TODO confirm with the caller.
///
/// # Errors
/// Returns an internal error if `delta` is not a non-null `UInt64`. Also
/// propagates errors from `get_row_at_idx` / `search_in_slice`.
fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
&mut self,
range_columns: &[ArrayRef],
idx: usize,
delta: Option<&ScalarValue>,
length: usize,
) -> Result<usize> {
let delta = if let Some(delta) = delta {
if let ScalarValue::UInt64(Some(value)) = delta {
*value as usize
} else {
return internal_err!(
"Unexpectedly got a non-UInt64 value in a GROUPS mode window frame"
);
}
} else {
0
};
let mut group_start = 0;
let last_group = self.group_end_indices.back_mut();
if let Some((group_row, group_end)) = last_group {
// New rows may have arrived since the last call. If the row right after
// the last cached group still has the same key, that group grows.
if *group_end < length {
let new_group_row = get_row_at_idx(range_columns, *group_end)?;
if new_group_row.eq(group_row) {
// Presumably returns the first index in `[*group_end, length)` where
// `check_equality` stops holding — see `search_in_slice`.
*group_end = search_in_slice(
range_columns,
group_row,
check_equality,
*group_end,
length,
)?;
}
}
// Further group discovery continues from the end of the last cached group.
group_start = *group_end;
}
// Discover groups until the one containing `idx` has been cached.
while idx >= group_start {
let group_row = get_row_at_idx(range_columns, group_start)?;
let group_end = search_in_slice(
range_columns,
&group_row,
check_equality,
group_start,
length,
)?;
self.group_end_indices.push_back((group_row, group_end));
group_start = group_end;
}
// Move `current_group_idx` forward to the group whose end lies after `idx`.
while self.current_group_idx < self.group_end_indices.len()
&& idx >= self.group_end_indices[self.current_group_idx].1
{
self.current_group_idx += 1;
}
// Target group of the boundary: `delta` groups before or after the current one.
let group_idx = if SEARCH_SIDE {
self.current_group_idx.saturating_sub(delta)
} else {
self.current_group_idx + delta
};
// Discover more groups until the target is cached or the rows run out.
while self.group_end_indices.len() <= group_idx && group_start < length {
let group_row = get_row_at_idx(range_columns, group_start)?;
let group_end = search_in_slice(
range_columns,
&group_row,
check_equality,
group_start,
length,
)?;
self.group_end_indices.push_back((group_row, group_end));
group_start = group_end;
}
Ok(match (SIDE, SEARCH_SIDE) {
// Frame start: begin at the end of the group just before the target
// (clamped to the groups discovered), or at row 0 if there is none.
(true, _) => {
let group_idx = std::cmp::min(group_idx, self.group_end_indices.len());
if group_idx > 0 {
self.group_end_indices[group_idx - 1].1
} else {
0
}
}
// Frame end, PRECEDING: stop at the end of group `current - delta`.
// If that group would lie before the first group, the end is 0.
(false, true) => {
if self.current_group_idx >= delta {
let group_idx = self.current_group_idx - delta;
self.group_end_indices[group_idx].1
} else {
0
}
}
// Frame end, FOLLOWING: stop at the end of group `current + delta`,
// clamped to the last discovered group. The cache is non-empty here
// because the discovery loop above always records the group containing `idx`.
(false, false) => {
let group_idx = std::cmp::min(
self.current_group_idx + delta,
self.group_end_indices.len() - 1,
);
self.group_end_indices[group_idx].1
}
})
}
}
/// Reports whether two rows hold equal values, position by position.
///
/// Used as the predicate for `search_in_slice` when scanning for the end of a
/// group of rows that share the same key.
fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result<bool> {
    let rows_match = current.eq(target);
    Ok(rows_match)
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::Float64Array;
/// Single ascending `Float64` range column with duplicate keys.
/// Groups (half-open row ranges): 5.0 -> [0,1), 7.0 -> [1,2), 8.0 -> [2,4),
/// 9.0 -> [4,5), 10.0 -> [5,8), 11.0 -> [8,9).
fn get_test_data() -> (Vec<ArrayRef>, Vec<SortOptions>) {
let range_columns: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
5.0, 7.0, 8.0, 8.0, 9., 10., 10., 10., 11.,
]))];
let sort_options = vec![SortOptions {
descending: false,
nulls_first: false,
}];
(range_columns, sort_options)
}
/// Feeds rows 0..n of the test data through one `WindowFrameStateGroups`, in
/// order. For row `i`, checks the computed frame range against
/// `expected_results[i].0` and `current_group_idx` against `expected_results[i].1`.
fn assert_expected(
expected_results: Vec<(Range<usize>, usize)>,
window_frame: &Arc<WindowFrame>,
) -> Result<()> {
let mut window_frame_groups = WindowFrameStateGroups::default();
let (range_columns, _) = get_test_data();
let n_row = range_columns[0].len();
for (idx, (expected_range, expected_group_idx)) in
expected_results.into_iter().enumerate()
{
let range = window_frame_groups.calculate_range(
window_frame,
&range_columns,
n_row,
idx,
)?;
assert_eq!(range, expected_range);
assert_eq!(window_frame_groups.current_group_idx, expected_group_idx);
}
Ok(())
}
/// GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING.
#[test]
fn test_window_frame_group_boundaries() -> Result<()> {
let window_frame = Arc::new(WindowFrame::new_bounds(
WindowFrameUnits::Groups,
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
));
let expected_results = vec![
(Range { start: 0, end: 2 }, 0),
(Range { start: 0, end: 4 }, 1),
(Range { start: 1, end: 5 }, 2),
(Range { start: 1, end: 5 }, 2),
(Range { start: 2, end: 8 }, 3),
(Range { start: 4, end: 9 }, 4),
(Range { start: 4, end: 9 }, 4),
(Range { start: 4, end: 9 }, 4),
(Range { start: 5, end: 9 }, 5),
];
assert_expected(expected_results, &window_frame)
}
/// GROUPS BETWEEN 1 FOLLOWING AND 2 FOLLOWING; frames past the last group are empty.
#[test]
fn test_window_frame_group_boundaries_both_following() -> Result<()> {
let window_frame = Arc::new(WindowFrame::new_bounds(
WindowFrameUnits::Groups,
WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
));
let expected_results = vec![
(Range::<usize> { start: 1, end: 4 }, 0),
(Range::<usize> { start: 2, end: 5 }, 1),
(Range::<usize> { start: 4, end: 8 }, 2),
(Range::<usize> { start: 4, end: 8 }, 2),
(Range::<usize> { start: 5, end: 9 }, 3),
(Range::<usize> { start: 8, end: 9 }, 4),
(Range::<usize> { start: 8, end: 9 }, 4),
(Range::<usize> { start: 8, end: 9 }, 4),
(Range::<usize> { start: 9, end: 9 }, 5),
];
assert_expected(expected_results, &window_frame)
}
/// GROUPS BETWEEN 2 PRECEDING AND 1 PRECEDING; frames before the first group are empty.
#[test]
fn test_window_frame_group_boundaries_both_preceding() -> Result<()> {
let window_frame = Arc::new(WindowFrame::new_bounds(
WindowFrameUnits::Groups,
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
));
let expected_results = vec![
(Range::<usize> { start: 0, end: 0 }, 0),
(Range::<usize> { start: 0, end: 1 }, 1),
(Range::<usize> { start: 0, end: 2 }, 2),
(Range::<usize> { start: 0, end: 2 }, 2),
(Range::<usize> { start: 1, end: 4 }, 3),
(Range::<usize> { start: 2, end: 5 }, 4),
(Range::<usize> { start: 2, end: 5 }, 4),
(Range::<usize> { start: 2, end: 5 }, 4),
(Range::<usize> { start: 4, end: 8 }, 5),
];
assert_expected(expected_results, &window_frame)
}
}