use arrow::array::ArrayRef;
use arrow::compute::kernels::sort::SortOptions;
use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use std::cmp::min;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::ops::Range;
use std::sync::Arc;
/// Per-frame-unit context used to compute the row range of a window frame.
///
/// The variant is selected from `WindowFrame::units` by [`WindowFrameContext::new`].
/// `Range` and `Groups` carry mutable state that is updated across successive
/// `calculate_range` calls; `Rows` only needs the frame definition itself.
#[derive(Debug)]
pub enum WindowFrameContext {
/// `ROWS` frame: boundaries are computed purely from the row index and the
/// frame offsets, so only the frame definition is stored.
Rows(Arc<WindowFrame>),
/// `RANGE` frame: boundaries are found by comparing ORDER BY values.
Range {
/// The frame definition (units and start/end bounds).
window_frame: Arc<WindowFrame>,
/// Holds the sort options used when comparing rows.
state: WindowFrameStateRange,
},
/// `GROUPS` frame: boundaries are expressed in whole groups of equal rows.
Groups {
/// The frame definition (units and start/end bounds).
window_frame: Arc<WindowFrame>,
/// Tracks the group boundaries discovered so far.
state: WindowFrameStateGroups,
},
}
impl WindowFrameContext {
    /// Creates the context that matches `window_frame.units`.
    ///
    /// `sort_options` is only stored for `RANGE` frames; the other two frame
    /// kinds drop it.
    pub fn new(window_frame: Arc<WindowFrame>, sort_options: Vec<SortOptions>) -> Self {
        match window_frame.units {
            WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame),
            WindowFrameUnits::Range => WindowFrameContext::Range {
                window_frame,
                state: WindowFrameStateRange::new(sort_options),
            },
            WindowFrameUnits::Groups => WindowFrameContext::Groups {
                window_frame,
                state: WindowFrameStateGroups::default(),
            },
        }
    }

    /// Computes the half-open row range `[start, end)` of the frame for row `idx`.
    ///
    /// * `range_columns` - ORDER BY columns (unused for `ROWS` frames).
    /// * `last_range` - the range returned for the previous row; `RANGE` frames
    ///   use it as the starting point of their search.
    /// * `length` - number of rows available; results are clamped to it.
    /// * `idx` - index of the current row.
    ///
    /// # Errors
    ///
    /// Propagates the errors of the frame-kind specific implementation, e.g. an
    /// invalid unbounded bound or an offset of an unexpected type.
    pub fn calculate_range(
        &mut self,
        range_columns: &[ArrayRef],
        last_range: &Range<usize>,
        length: usize,
        idx: usize,
    ) -> Result<Range<usize>> {
        match self {
            WindowFrameContext::Rows(window_frame) => {
                Self::calculate_range_rows(window_frame, length, idx)
            }
            WindowFrameContext::Range {
                window_frame,
                state,
            } => state.calculate_range(
                window_frame,
                last_range,
                range_columns,
                length,
                idx,
            ),
            WindowFrameContext::Groups {
                window_frame,
                state,
            } => state.calculate_range(window_frame, range_columns, length, idx),
        }
    }

    /// Computes the half-open row range of a `ROWS` frame for row `idx`.
    ///
    /// All arithmetic saturates, so arbitrarily large offsets (up to
    /// `u64::MAX`) clamp to the partition boundaries instead of overflowing.
    ///
    /// # Errors
    ///
    /// Returns an internal error when the start bound is `UNBOUNDED FOLLOWING`,
    /// the end bound is `UNBOUNDED PRECEDING`, or an offset is not a `UInt64`.
    fn calculate_range_rows(
        window_frame: &Arc<WindowFrame>,
        length: usize,
        idx: usize,
    ) -> Result<Range<usize>> {
        let start = match window_frame.start_bound {
            // UNBOUNDED PRECEDING
            WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0,
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
                // Clamp at the first row when fewer than `n` rows precede `idx`.
                idx.saturating_sub(n as usize)
            }
            WindowFrameBound::CurrentRow => idx,
            // UNBOUNDED FOLLOWING
            WindowFrameBound::Following(ScalarValue::UInt64(None)) => {
                return Err(DataFusionError::Internal(format!(
                    "Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'"
                )))
            }
            WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
                // `idx + n` may exceed `usize::MAX`; saturate before clamping.
                min(idx.saturating_add(n as usize), length)
            }
            WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
                return Err(DataFusionError::Internal("Rows should be Uint".to_string()))
            }
        };
        let end = match window_frame.end_bound {
            // UNBOUNDED PRECEDING
            WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => {
                return Err(DataFusionError::Internal(format!(
                    "Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'"
                )))
            }
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
                // The end is exclusive, hence one past the row `n` before `idx`;
                // if that row does not exist the frame ends at 0.
                match idx.checked_sub(n as usize) {
                    Some(last_row) => last_row + 1,
                    None => 0,
                }
            }
            WindowFrameBound::CurrentRow => idx + 1,
            // UNBOUNDED FOLLOWING
            WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
            WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
                // `idx + n + 1` may exceed `usize::MAX`; saturate before clamping.
                min(idx.saturating_add(n as usize).saturating_add(1), length)
            }
            WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
                return Err(DataFusionError::Internal("Rows should be Uint".to_string()))
            }
        };
        Ok(Range { start, end })
    }
}
/// State used to compute the boundaries of `RANGE` window frames.
#[derive(Debug, Default)]
pub struct WindowFrameStateRange {
/// Sort options of the ORDER BY columns. They are passed to the row
/// comparison, and the first entry decides whether a frame offset is added
/// to or subtracted from the current row's values.
sort_options: Vec<SortOptions>,
}
impl WindowFrameStateRange {
    /// Creates the state from the sort options of the ORDER BY columns.
    fn new(sort_options: Vec<SortOptions>) -> Self {
        Self { sort_options }
    }

    /// Computes the half-open row range of a `RANGE` frame for row `idx`.
    ///
    /// `last_range` is the range computed for the previous row; each boundary
    /// search resumes from the corresponding end of it.
    ///
    /// # Errors
    ///
    /// Returns an internal error when the start bound is `UNBOUNDED FOLLOWING`
    /// or the end bound is `UNBOUNDED PRECEDING` (mirroring `ROWS` frames), and
    /// propagates any error raised while searching for a boundary.
    fn calculate_range(
        &mut self,
        window_frame: &Arc<WindowFrame>,
        last_range: &Range<usize>,
        range_columns: &[ArrayRef],
        length: usize,
        idx: usize,
    ) -> Result<Range<usize>> {
        let start = match window_frame.start_bound {
            WindowFrameBound::Preceding(ref n) => {
                if n.is_null() {
                    // UNBOUNDED PRECEDING
                    0
                } else {
                    self.calculate_index_of_row::<true, true>(
                        range_columns,
                        last_range,
                        idx,
                        Some(n),
                        length,
                    )?
                }
            }
            WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>(
                range_columns,
                last_range,
                idx,
                None,
                length,
            )?,
            WindowFrameBound::Following(ref n) => {
                if n.is_null() {
                    // UNBOUNDED FOLLOWING is not a valid frame start; report it
                    // instead of feeding a NULL offset into the arithmetic.
                    return Err(DataFusionError::Internal(format!(
                        "Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'"
                    )));
                }
                self.calculate_index_of_row::<true, false>(
                    range_columns,
                    last_range,
                    idx,
                    Some(n),
                    length,
                )?
            }
        };
        let end = match window_frame.end_bound {
            WindowFrameBound::Preceding(ref n) => {
                if n.is_null() {
                    // UNBOUNDED PRECEDING is not a valid frame end; report it
                    // instead of feeding a NULL offset into the arithmetic.
                    return Err(DataFusionError::Internal(format!(
                        "Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'"
                    )));
                }
                self.calculate_index_of_row::<false, true>(
                    range_columns,
                    last_range,
                    idx,
                    Some(n),
                    length,
                )?
            }
            WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>(
                range_columns,
                last_range,
                idx,
                None,
                length,
            )?,
            WindowFrameBound::Following(ref n) => {
                if n.is_null() {
                    // UNBOUNDED FOLLOWING
                    length
                } else {
                    self.calculate_index_of_row::<false, false>(
                        range_columns,
                        last_range,
                        idx,
                        Some(n),
                        length,
                    )?
                }
            }
        };
        Ok(Range { start, end })
    }

    /// Finds the row index of one frame boundary for row `idx`.
    ///
    /// * `SIDE` - `true` when computing the frame start, `false` for the frame
    ///   end. It selects where the search resumes (`last_range.start` or
    ///   `last_range.end`) and whether rows are skipped while strictly less
    ///   than (`start`) or less than or equal to (`end`) the target.
    /// * `SEARCH_SIDE` - `true` for a `PRECEDING` offset, `false` for a
    ///   `FOLLOWING` one. Combined with the sort direction it decides whether
    ///   `delta` is added to or subtracted from the current row's values.
    /// * `delta` - the frame offset, or `None` for `CURRENT ROW`.
    ///
    /// # Errors
    ///
    /// Returns an internal error if an offset is given but no sort options are
    /// available, and propagates errors from the scalar arithmetic, the row
    /// extraction and the search.
    fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
        &mut self,
        range_columns: &[ArrayRef],
        last_range: &Range<usize>,
        idx: usize,
        delta: Option<&ScalarValue>,
        length: usize,
    ) -> Result<usize> {
        let current_row_values = get_row_at_idx(range_columns, idx)?;
        // The values a row must be compared against to decide whether it lies
        // before the boundary.
        let end_range = if let Some(delta) = delta {
            let is_descending: bool = self
                .sort_options
                .first()
                .ok_or_else(|| {
                    DataFusionError::Internal(
                        "Sort options unexpectedly absent in a window frame".to_string(),
                    )
                })?
                .descending;
            current_row_values
                .iter()
                .map(|value| {
                    // NULL values are left untouched.
                    if value.is_null() {
                        return Ok(value.clone());
                    }
                    if SEARCH_SIDE == is_descending {
                        // PRECEDING on descending data or FOLLOWING on
                        // ascending data: the boundary value is larger.
                        value.add(delta)
                    } else if value.is_unsigned() && value < delta {
                        // Unsigned subtraction would underflow; `value - value`
                        // gives the smallest reachable value instead.
                        value.sub(value)
                    } else {
                        value.sub(delta)
                    }
                })
                .collect::<Result<Vec<ScalarValue>>>()?
        } else {
            current_row_values
        };
        let search_start = if SIDE {
            last_range.start
        } else {
            last_range.end
        };
        let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
            let cmp = compare_rows(current, target, &self.sort_options)?;
            Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
        };
        search_in_slice(range_columns, &end_range, compare_fn, search_start, length)
    }
}
/// State used to compute the boundaries of `GROUPS` window frames.
///
/// A group is a run of consecutive rows whose ORDER BY values are equal.
#[derive(Debug, Default)]
pub struct WindowFrameStateGroups {
/// Groups discovered so far, in input order. Each entry stores the row
/// values shared by the group and the exclusive end index of the group.
pub group_end_indices: VecDeque<(Vec<ScalarValue>, usize)>,
/// Position in `group_end_indices` of the group containing the row most
/// recently passed to `calculate_range`; it only ever advances.
pub current_group_idx: usize,
}
impl WindowFrameStateGroups {
    /// Computes the half-open row range of a `GROUPS` frame for row `idx`.
    ///
    /// # Errors
    ///
    /// Propagates the errors of `calculate_index_of_row`, including the one
    /// raised for offsets that are not non-NULL `UInt64` values.
    fn calculate_range(
        &mut self,
        window_frame: &Arc<WindowFrame>,
        range_columns: &[ArrayRef],
        length: usize,
        idx: usize,
    ) -> Result<Range<usize>> {
        let start = match window_frame.start_bound {
            WindowFrameBound::Preceding(ref n) => {
                if n.is_null() {
                    // UNBOUNDED PRECEDING
                    0
                } else {
                    self.calculate_index_of_row::<true, true>(
                        range_columns,
                        idx,
                        Some(n),
                        length,
                    )?
                }
            }
            WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>(
                range_columns,
                idx,
                None,
                length,
            )?,
            WindowFrameBound::Following(ref n) => self
                .calculate_index_of_row::<true, false>(
                    range_columns,
                    idx,
                    Some(n),
                    length,
                )?,
        };
        let end = match window_frame.end_bound {
            WindowFrameBound::Preceding(ref n) => self
                .calculate_index_of_row::<false, true>(
                    range_columns,
                    idx,
                    Some(n),
                    length,
                )?,
            WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>(
                range_columns,
                idx,
                None,
                length,
            )?,
            WindowFrameBound::Following(ref n) => {
                if n.is_null() {
                    // UNBOUNDED FOLLOWING
                    length
                } else {
                    self.calculate_index_of_row::<false, false>(
                        range_columns,
                        idx,
                        Some(n),
                        length,
                    )?
                }
            }
        };
        Ok(Range { start, end })
    }

    /// Records the group that begins at `group_start` and returns its
    /// exclusive end index, i.e. the start of the next group.
    fn push_next_group(
        &mut self,
        range_columns: &[ArrayRef],
        group_start: usize,
        length: usize,
    ) -> Result<usize> {
        let group_row = get_row_at_idx(range_columns, group_start)?;
        let group_end = search_in_slice(
            range_columns,
            &group_row,
            check_equality,
            group_start,
            length,
        )?;
        self.group_end_indices.push_back((group_row, group_end));
        Ok(group_end)
    }

    /// Finds the row index of one frame boundary for row `idx`.
    ///
    /// * `SIDE` - `true` when computing the frame start, `false` for the end.
    /// * `SEARCH_SIDE` - `true` for a `PRECEDING` offset, `false` for a
    ///   `FOLLOWING` one.
    /// * `delta` - the offset in groups, or `None` for `CURRENT ROW` (treated
    ///   as an offset of zero groups).
    ///
    /// Group offsets are combined with saturating arithmetic, so arbitrarily
    /// large offsets clamp to the first/last known group instead of
    /// overflowing.
    ///
    /// # Errors
    ///
    /// Returns an internal error when `delta` is anything other than a
    /// non-NULL `UInt64`, and propagates row-extraction and search errors.
    fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
        &mut self,
        range_columns: &[ArrayRef],
        idx: usize,
        delta: Option<&ScalarValue>,
        length: usize,
    ) -> Result<usize> {
        let delta = match delta {
            Some(ScalarValue::UInt64(Some(value))) => *value as usize,
            Some(_) => {
                return Err(DataFusionError::Internal(
                    "Unexpectedly got a non-UInt64 value in a GROUPS mode window frame"
                        .to_string(),
                ));
            }
            None => 0,
        };

        // Resume from the last known group. If more rows have become available
        // since it was recorded (`group_end < length`) and the next row still
        // belongs to it, extend that group first.
        let mut group_start = 0;
        if let Some((group_row, group_end)) = self.group_end_indices.back_mut() {
            if *group_end < length {
                let new_group_row = get_row_at_idx(range_columns, *group_end)?;
                if new_group_row.eq(group_row) {
                    *group_end = search_in_slice(
                        range_columns,
                        group_row,
                        check_equality,
                        *group_end,
                        length,
                    )?;
                }
            }
            group_start = *group_end;
        }

        // Discover groups until the one containing `idx` is known.
        while idx >= group_start {
            group_start = self.push_next_group(range_columns, group_start, length)?;
        }

        // Advance to the group that contains `idx`.
        while self.current_group_idx < self.group_end_indices.len()
            && idx >= self.group_end_indices[self.current_group_idx].1
        {
            self.current_group_idx += 1;
        }

        // Index of the group the boundary refers to.
        let group_idx = if SEARCH_SIDE {
            self.current_group_idx.saturating_sub(delta)
        } else {
            self.current_group_idx.saturating_add(delta)
        };

        // Discover further groups until the target group is known or the
        // available rows are exhausted.
        while self.group_end_indices.len() <= group_idx && group_start < length {
            group_start = self.push_next_group(range_columns, group_start, length)?;
        }

        Ok(match (SIDE, SEARCH_SIDE) {
            // Frame start: the target group begins where the previous one ends.
            (true, _) => {
                let group_idx = min(group_idx, self.group_end_indices.len());
                if group_idx > 0 {
                    self.group_end_indices[group_idx - 1].1
                } else {
                    0
                }
            }
            // Frame end with a PRECEDING offset: the end of the target group,
            // or 0 when fewer than `delta` groups precede the current one.
            (false, true) => {
                if self.current_group_idx >= delta {
                    let group_idx = self.current_group_idx - delta;
                    self.group_end_indices[group_idx].1
                } else {
                    0
                }
            }
            // Frame end with CURRENT ROW / FOLLOWING: the end of the target
            // group, clamped to the last known group.
            (false, false) => {
                let group_idx = min(group_idx, self.group_end_indices.len() - 1);
                self.group_end_indices[group_idx].1
            }
        })
    }
}
/// Row predicate that holds when `current` and `target` have the same number
/// of values and every pair of corresponding values is equal. Always returns
/// `Ok`.
fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result<bool> {
    let same_row = current.len() == target.len()
        && current.iter().zip(target).all(|(lhs, rhs)| lhs == rhs);
    Ok(same_row)
}
#[cfg(test)]
mod tests {
    use crate::window::window_frame_state::WindowFrameStateGroups;
    use arrow::array::{ArrayRef, Float64Array};
    use arrow_schema::SortOptions;
    use datafusion_common::from_slice::FromSlice;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
    use std::ops::Range;
    use std::sync::Arc;

    /// Returns a single ascending ORDER BY column together with its sort
    /// options. The column forms six groups: `[5] [7] [8 8] [9] [10 10 10] [11]`.
    fn get_test_data() -> (Vec<ArrayRef>, Vec<SortOptions>) {
        let values =
            Float64Array::from_slice([5.0, 7.0, 8.0, 8.0, 9., 10., 10., 10., 11.]);
        let range_columns: Vec<ArrayRef> = vec![Arc::new(values)];
        let ascending_nulls_last = SortOptions {
            descending: false,
            nulls_first: false,
        };
        (range_columns, vec![ascending_nulls_last])
    }

    /// Builds a `GROUPS` frame with the given bounds.
    fn groups_frame(
        start_bound: WindowFrameBound,
        end_bound: WindowFrameBound,
    ) -> Arc<WindowFrame> {
        Arc::new(WindowFrame {
            units: WindowFrameUnits::Groups,
            start_bound,
            end_bound,
        })
    }

    /// Evaluates the frame row by row, checking after each row both the
    /// returned range and the group index tracked by the state.
    fn assert_expected(
        expected_results: Vec<(Range<usize>, usize)>,
        window_frame: &Arc<WindowFrame>,
    ) -> Result<()> {
        let (range_columns, _) = get_test_data();
        let n_row = range_columns[0].len();
        let mut state = WindowFrameStateGroups::default();
        for (idx, (expected_range, expected_group_idx)) in
            expected_results.into_iter().enumerate()
        {
            let range = state.calculate_range(window_frame, &range_columns, n_row, idx)?;
            assert_eq!(range, expected_range);
            assert_eq!(state.current_group_idx, expected_group_idx);
        }
        Ok(())
    }

    /// `GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING`
    #[test]
    fn test_window_frame_group_boundaries() -> Result<()> {
        let window_frame = groups_frame(
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
            WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
        );
        let expected_results = vec![
            (0..2, 0),
            (0..4, 1),
            (1..5, 2),
            (1..5, 2),
            (2..8, 3),
            (4..9, 4),
            (4..9, 4),
            (4..9, 4),
            (5..9, 5),
        ];
        assert_expected(expected_results, &window_frame)
    }

    /// `GROUPS BETWEEN 1 FOLLOWING AND 2 FOLLOWING`
    #[test]
    fn test_window_frame_group_boundaries_both_following() -> Result<()> {
        let window_frame = groups_frame(
            WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
            WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
        );
        let expected_results = vec![
            (1..4, 0),
            (2..5, 1),
            (4..8, 2),
            (4..8, 2),
            (5..9, 3),
            (8..9, 4),
            (8..9, 4),
            (8..9, 4),
            (9..9, 5),
        ];
        assert_expected(expected_results, &window_frame)
    }

    /// `GROUPS BETWEEN 2 PRECEDING AND 1 PRECEDING`
    #[test]
    fn test_window_frame_group_boundaries_both_preceding() -> Result<()> {
        let window_frame = groups_frame(
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
        );
        let expected_results = vec![
            (0..0, 0),
            (0..1, 1),
            (0..2, 2),
            (0..2, 2),
            (1..4, 3),
            (2..5, 4),
            (2..5, 4),
            (2..5, 4),
            (4..8, 5),
        ];
        assert_expected(expected_results, &window_frame)
    }
}