use std::cmp::Ordering;
use std::mem::size_of;
use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::compute::SortOptions;
use arrow_ord::partition::partition;
use datafusion_common::utils::{compare_rows, get_row_at_idx};
use datafusion_common::{Result, ScalarValue};
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_expr::EmitTo;
/// Tracks which aggregation groups can be emitted early when the input is
/// ordered on the group columns selected by `order_indices`.
///
/// Once the sort key changes, every group created before the first group of
/// the new key is reported as emittable (see `emit_to`). The ordering of the
/// input is presumed, not verified, by this type.
#[derive(Debug)]
pub struct GroupOrderingPartial {
    /// Progress through the input (see [`State`]).
    state: State,
    /// Positions, within the group-value columns passed to `new_groups`, of
    /// the columns that together form the sort key.
    order_indices: Vec<usize>,
}
/// Progress of a [`GroupOrderingPartial`] through its input.
#[derive(Debug, Default, PartialEq)]
enum State {
    /// Placeholder left behind while the state is moved out with
    /// `std::mem::take` (hence it is the `Default` variant). Observing it from
    /// a public method is treated as unreachable.
    #[default]
    Taken,
    /// Initial state: no batch has been recorded yet.
    Start,
    /// At least one batch has been recorded and input has not been marked done.
    InProgress {
        /// Index of the first group that has the latest sort key; groups
        /// before it are reported as emittable.
        current_sort: usize,
        /// The latest sort key, one `ScalarValue` per sort column.
        sort_key: Vec<ScalarValue>,
        /// Index of the last known group (`total_num_groups - 1` as of the
        /// latest batch, shifted down by `remove_groups`).
        current: usize,
    },
    /// Input has been marked done; all groups are emittable.
    Complete,
}
impl State {
fn size(&self) -> usize {
match self {
State::Taken => 0,
State::Start => 0,
State::InProgress { sort_key, .. } => sort_key
.iter()
.map(|scalar_value| scalar_value.size())
.sum(),
State::Complete => 0,
}
}
}
impl GroupOrderingPartial {
    /// Creates a tracker whose sort key is made of the group columns at
    /// `order_indices`.
    ///
    /// `order_indices` must be non-empty; this is only checked in debug
    /// builds. This constructor currently never returns `Err`.
    pub fn try_new(order_indices: Vec<usize>) -> Result<Self> {
        debug_assert!(!order_indices.is_empty());
        Ok(Self {
            state: State::Start,
            order_indices,
        })
    }

    /// Selects the sort-key columns (those at `order_indices`) from
    /// `group_values`.
    ///
    /// Only the `Arc` handles are cloned, not the underlying array data.
    ///
    /// # Panics
    /// Panics if an entry of `order_indices` is out of bounds for `group_values`.
    fn compute_sort_keys(&self, group_values: &[ArrayRef]) -> Vec<ArrayRef> {
        self.order_indices
            .iter()
            .map(|&idx| Arc::clone(&group_values[idx]))
            .collect()
    }

    /// Reports which groups, if any, may be emitted now.
    ///
    /// * Nothing has been recorded yet, or the current sort key starts at
    ///   group 0: `None`.
    /// * Otherwise, while in progress: every group before the first group of
    ///   the current sort key, i.e. `EmitTo::First(current_sort)`.
    /// * After [`Self::input_done`]: `EmitTo::All`.
    ///
    /// # Panics
    /// Panics if the state was left `Taken`.
    pub fn emit_to(&self) -> Option<EmitTo> {
        match &self.state {
            State::Taken => unreachable!("State previously taken"),
            State::Start | State::InProgress { current_sort: 0, .. } => None,
            State::InProgress { current_sort, .. } => Some(EmitTo::First(*current_sort)),
            State::Complete => Some(EmitTo::All),
        }
    }

    /// Adjusts the tracked indices for the removal of the first `n` groups.
    ///
    /// Both `current` and `current_sort` are decremented by `n`.
    ///
    /// # Panics
    /// Panics unless the tracker is in progress with `current >= n` and
    /// `current_sort >= n`.
    pub fn remove_groups(&mut self, n: usize) {
        match &mut self.state {
            State::Taken => unreachable!("State previously taken"),
            State::Start => panic!("invalid state: start"),
            State::InProgress {
                current_sort,
                current,
                sort_key: _,
            } => {
                assert!(*current >= n);
                *current -= n;
                assert!(*current_sort >= n);
                *current_sort -= n;
            }
            State::Complete => panic!("invalid state: complete"),
        }
    }

    /// Marks the input as exhausted. Afterwards [`Self::emit_to`] reports
    /// `EmitTo::All`.
    ///
    /// # Panics
    /// Panics if the state was left `Taken`.
    pub fn input_done(&mut self) {
        self.state = match self.state {
            State::Taken => unreachable!("State previously taken"),
            State::Start | State::InProgress { .. } | State::Complete => State::Complete,
        };
    }

    /// Returns `true` when two sort-key rows compare `Equal` under
    /// `compare_rows` (ascending, nulls last).
    fn same_sort_key(lhs: &[ScalarValue], rhs: &[ScalarValue]) -> Result<bool> {
        let sort_options = vec![SortOptions::new(false, false); lhs.len()];
        Ok(compare_rows(lhs, rhs, &sort_options)? == Ordering::Equal)
    }

    /// Records the groups touched by a new batch of input.
    ///
    /// * `batch_group_values`: the group columns of the batch. The columns at
    ///   `order_indices` form the sort key.
    /// * `group_indices`: the group index assigned to each row of the batch.
    /// * `total_num_groups`: number of groups known after this batch. The last
    ///   group index is `total_num_groups - 1`.
    ///
    /// The batch is partitioned wherever the sort key changes, and the last
    /// partition defines the current sort key. If that partition starts at
    /// row 0 and its key equals the key already tracked, `current_sort` and
    /// the stored key are kept. Otherwise the current key starts at the group
    /// of the partition's first row.
    ///
    /// The state is replaced only after every fallible step has succeeded, so
    /// an `Err` return leaves the tracker unchanged.
    ///
    /// # Errors
    /// Propagates errors from `partition`, `get_row_at_idx` and `compare_rows`.
    ///
    /// # Panics
    /// Panics in any of these cases:
    /// * `total_num_groups` is 0 or `batch_group_values` is empty.
    /// * The sort key yields no partitions.
    /// * `group_indices` is shorter than the batch.
    /// * Input was already marked done.
    /// * The state was left `Taken`.
    pub fn new_groups(
        &mut self,
        batch_group_values: &[ArrayRef],
        group_indices: &[usize],
        total_num_groups: usize,
    ) -> Result<()> {
        assert!(total_num_groups > 0);
        assert!(!batch_group_values.is_empty());

        let max_group_index = total_num_groups - 1;

        // Validate the state without moving it out. Taking it here would
        // strand the tracker in `Taken` if a later `?` returned early.
        match &self.state {
            State::Taken => unreachable!("State previously taken"),
            State::Start | State::InProgress { .. } => {}
            State::Complete => {
                panic!("Saw new group after the end of input");
            }
        }

        let sort_keys = self.compute_sort_keys(batch_group_values);

        // Rows are split wherever the sort key changes. Only the last
        // partition's key can carry on into later batches (given input
        // ordered on these columns).
        let ranges = partition(&sort_keys)?.ranges();
        let last_range = ranges
            .last()
            .expect("new_groups requires a batch that yields at least one partition");

        let range_current_sort = group_indices[last_range.start];
        let range_sort_key = get_row_at_idx(&sort_keys, last_range.start)?;

        // If the key did not change inside the batch, it may still be the key
        // carried over from the previous batch.
        let continues_previous = match &self.state {
            State::InProgress { sort_key, .. } if last_range.start == 0 => {
                Self::same_sort_key(sort_key, &range_sort_key)?
            }
            _ => false,
        };

        // All fallible work is done, so it is now safe to move the state out.
        let (current_sort, sort_key) = match std::mem::take(&mut self.state) {
            State::InProgress {
                current_sort,
                sort_key,
                ..
            } if continues_previous => (current_sort, sort_key),
            _ => (range_current_sort, range_sort_key),
        };

        self.state = State::InProgress {
            current_sort,
            current: max_group_index,
            sort_key,
        };
        Ok(())
    }

    /// Estimated memory used by this tracker. This covers its inline size,
    /// the allocation behind `order_indices`, and any stored sort key.
    pub(crate) fn size(&self) -> usize {
        size_of::<Self>() + self.order_indices.allocated_size() + self.state.size()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;

    /// Builds a two-column batch. Column 0 is the sort column selected by
    /// `order_indices = [0]`; column 1 is an extra group column.
    fn batch(sort_col: Vec<i32>, other_col: Vec<i32>) -> Vec<ArrayRef> {
        vec![
            Arc::new(Int32Array::from(sort_col)),
            Arc::new(Int32Array::from(other_col)),
        ]
    }

    /// Expected `InProgress` state for a single `Int32` sort column.
    fn in_progress(current_sort: usize, key: i32, current: usize) -> State {
        State::InProgress {
            current_sort,
            sort_key: vec![ScalarValue::Int32(Some(key))],
            current,
        }
    }

    #[test]
    fn test_group_ordering_partial() -> Result<()> {
        let mut ordering = GroupOrderingPartial::try_new(vec![0])?;

        // The key changes inside the batch; the last run (key 3) starts at group 2.
        ordering.new_groups(&batch(vec![1, 2, 3], vec![2, 1, 3]), &[0, 1, 2], 3)?;
        assert_eq!(ordering.state, in_progress(2, 3, 2));

        // Every row keeps key 3, so the run still starts at group 2.
        ordering.new_groups(&batch(vec![3, 3, 3], vec![2, 1, 7]), &[3, 4, 5], 6)?;
        assert_eq!(ordering.state, in_progress(2, 3, 5));

        // A new key (4) begins at the first row, i.e. at group 6.
        ordering.new_groups(&batch(vec![4, 4, 4], vec![1, 1, 1]), &[6, 7, 8], 9)?;
        assert_eq!(ordering.state, in_progress(6, 4, 8));

        Ok(())
    }
}