use ahash::AHashMap;
use std::cmp::Ordering;
use std::iter::once;
use itertools::Itertools;
use crate::array::{
growable::make_growable,
ord::{build_compare, DynComparator},
Array,
};
pub use crate::compute::sort::SortOptions;
use crate::error::Result;
/// A run of rows `(array_index, start, len)`: `len` consecutive rows of `arrays[array_index]`
/// beginning at row `start`.
pub type MergeSlice = (usize, usize, usize);
/// Concatenates the row runs described by `slices` (taken from `arrays`) into one array.
///
/// When `limit` is `Some(n)`, at most `n` rows are copied: the run that crosses the limit is
/// truncated and all later runs are ignored. A limit larger than the total number of rows
/// in `arrays` behaves like `None`.
pub fn take_arrays<I: IntoIterator<Item = MergeSlice>>(
    arrays: &[&dyn Array],
    slices: I,
    limit: Option<usize>,
) -> Box<dyn Array> {
    let total: usize = arrays.iter().map(|array| array.len()).sum();
    let capacity = limit.map_or(total, |requested| requested.min(total));
    let mut growable = make_growable(arrays, false, capacity);

    if capacity == total {
        // No effective limit: copy every run verbatim.
        slices
            .into_iter()
            .for_each(|(index, start, len)| growable.extend(index, start, len));
    } else {
        let mut taken = 0;
        for (index, start, len) in slices {
            let remaining = capacity - taken;
            if len >= remaining {
                // This run reaches the limit: copy only what still fits and stop.
                growable.extend(index, start, remaining);
                break;
            }
            growable.extend(index, start, len);
            taken += len;
        }
    }
    growable.as_box()
}
pub fn merge_sort(
lhs: &dyn Array,
rhs: &dyn Array,
options: &SortOptions,
limit: Option<usize>,
) -> Result<Box<dyn Array>> {
let arrays = &[lhs, rhs];
let pairs: &[(&[&dyn Array], &SortOptions)] = &[(arrays, options)];
let comparator = build_comparator(pairs)?;
let lhs = (0, 0, lhs.len());
let rhs = (1, 0, rhs.len());
let slices = merge_sort_slices(once(&lhs), once(&rhs), &comparator);
Ok(take_arrays(arrays, slices, limit))
}
/// Computes the merged order of the arrays in `pairs` as a list of [`MergeSlice`]s.
///
/// Every array of the first column contributes one run spanning all of its rows, so
/// merging only yields a sorted result if each array is already sorted per its options.
///
/// # Panics
/// Panics if `pairs` is empty.
///
/// # Errors
/// Propagates any error from building the comparator.
pub fn slices(pairs: &[(&[&dyn Array], &SortOptions)]) -> Result<Vec<MergeSlice>> {
    assert!(!pairs.is_empty());
    let comparator = build_comparator(pairs)?;

    let (first_column_arrays, _) = pairs[0];
    let runs: Vec<Vec<MergeSlice>> = first_column_arrays
        .iter()
        .enumerate()
        .map(|(index, array)| vec![(index, 0, array.len())])
        .collect();
    let run_views: Vec<&[MergeSlice]> = runs.iter().map(Vec::as_slice).collect();

    Ok(recursive_merge_sort(&run_views, &comparator))
}
/// Merges a list of sorted runs into a single sorted run by divide and conquer.
///
/// Each element of `slices` is a sequence of [`MergeSlice`]s whose concatenation is sorted
/// according to `comparator`; the result is their merge. An empty input yields an empty
/// output (previously it recursed on an empty sub-slice forever and overflowed the stack).
fn recursive_merge_sort(slices: &[&[MergeSlice]], comparator: &Comparator) -> Vec<MergeSlice> {
    match slices {
        [] => Vec::new(),
        [only] => only.to_vec(),
        [lhs, rhs] => merge_sort_slices(lhs.iter(), rhs.iter(), comparator).collect(),
        _ => {
            // Left half gets the lower array indices, so every comparison made while merging
            // the halves compares a lower index (left) with a higher index (right).
            let (left_half, right_half) = slices.split_at(slices.len() / 2);
            let lhs = recursive_merge_sort(left_half, comparator);
            let rhs = recursive_merge_sort(right_half, comparator);
            merge_sort_slices(lhs.iter(), rhs.iter(), comparator).collect()
        }
    }
}
/// Iterator that lazily merges two streams of [`MergeSlice`]s, emitting runs of rows
/// (each taken from a single input slice) in the order dictated by `comparator`.
pub struct MergeSortSlices<'a, L, R>
where
    L: Iterator<Item = &'a MergeSlice>,
    R: Iterator<Item = &'a MergeSlice>,
{
    /// Source of left-hand slices.
    lhs: L,
    /// Source of right-hand slices.
    rhs: R,
    /// Called as `(left_array, left_row, right_array, right_row)`.
    comparator: &'a Comparator<'a>,
    /// Left slice currently being consumed, paired with its next unconsumed row.
    left: Option<(MergeSlice, usize)>,
    /// Right slice currently being consumed, paired with its next unconsumed row.
    right: Option<(MergeSlice, usize)>,
    /// Tracks whether `next` has performed its initial setup (pulling the first slices and
    /// choosing the starting side).
    has_started: bool,
    /// First row of the run currently being accumulated.
    current_start: usize,
    /// Number of rows accumulated in the current run.
    current_len: usize,
    /// Whether the run currently being accumulated comes from the left side.
    current_is_left: bool,
}
impl<'a, L, R> MergeSortSlices<'a, L, R>
where
    L: Iterator<Item = &'a MergeSlice>,
    R: Iterator<Item = &'a MergeSlice>,
{
    /// Creates a merger of `lhs` and `rhs`; no slice is pulled until the first call to `next`.
    fn new(lhs: L, rhs: R, comparator: &'a Comparator<'a>) -> Self {
        Self {
            lhs,
            rhs,
            comparator,
            left: None,
            right: None,
            has_started: false,
            current_start: 0,
            current_len: 0,
            current_is_left: true,
        }
    }

    /// Moves to the next left-hand slice (or `None` once `lhs` is exhausted), placing the
    /// cursor and `current_start` at the slice's first row.
    fn next_left(&mut self) {
        self.left = self.lhs.next().map(|slice| (*slice, slice.1));
        if let Some((slice, _)) = self.left {
            self.current_start = slice.1;
        }
    }

    /// Moves to the next right-hand slice (or `None` once `rhs` is exhausted), placing the
    /// cursor and `current_start` at the slice's first row.
    fn next_right(&mut self) {
        self.right = self.rhs.next().map(|slice| (*slice, slice.1));
        if let Some((slice, _)) = self.right {
            self.current_start = slice.1;
        }
    }

    /// Collects all merged slices; when `limit` is `Some(n)` the output is truncated so it
    /// covers at most `n` rows (the slice that reaches the limit is shortened).
    ///
    /// The vector grows on demand instead of being pre-sized with `limit`: `limit` counts
    /// rows, not slices, so pre-sizing over-allocates and panics with "capacity overflow"
    /// for very large limits such as `usize::MAX`.
    // `allow` rather than re-stating the default `warn` level: this helper may be unused.
    #[allow(dead_code)]
    pub fn to_vec(self, limit: Option<usize>) -> Vec<MergeSlice> {
        match limit {
            Some(limit) => {
                let mut v = Vec::new();
                let mut current_len = 0;
                for (index, start, len) in self {
                    if len + current_len >= limit {
                        v.push((index, start, limit - current_len));
                        break;
                    }
                    v.push((index, start, len));
                    current_len += len;
                }
                v
            }
            None => self.collect(),
        }
    }
}
impl<'a, L, R> Iterator for MergeSortSlices<'a, L, R>
where
    L: Iterator<Item = &'a MergeSlice>,
    R: Iterator<Item = &'a MergeSlice>,
{
    type Item = MergeSlice;

    /// Emits the next run of consecutive rows taken from a single input slice.
    ///
    /// Concatenating the emitted slices yields the rows of both inputs ordered by the
    /// comparator, provided each input stream is itself sorted.
    fn next(&mut self) -> Option<Self::Item> {
        /// The part of `slice` from row `cursor` (inclusive) to its end.
        fn remainder(slice: MergeSlice, cursor: usize) -> MergeSlice {
            (slice.0, cursor, slice.2 - (cursor - slice.1))
        }

        if !self.has_started {
            // Pull the first slice of each side exactly once. `has_started` must be set even
            // when one side is empty; otherwise every later call would pull again from both
            // iterators and silently drop every other slice of the non-empty side.
            self.next_left();
            self.next_right();
            if let (Some((left_slice, left_index)), Some((right_slice, right_index))) =
                (self.left, self.right)
            {
                // Start the first run on the side holding the smaller row (left on ties).
                let ordering =
                    (self.comparator)(left_slice.0, left_index, right_slice.0, right_index);
                if ordering == Ordering::Greater {
                    self.current_is_left = false;
                    self.current_start = right_index;
                } else {
                    self.current_is_left = true;
                    self.current_start = left_index;
                }
            }
            self.has_started = true;
        }

        match (self.left, self.right) {
            // Both sides exhausted.
            (None, None) => None,
            // Right side exhausted: emit the unconsumed rest of the left slice, then whole slices.
            (Some((left_slice, left_index)), None) => {
                self.next_left();
                Some(remainder(left_slice, left_index))
            }
            // Left side exhausted: emit the unconsumed rest of the right slice, then whole slices.
            (None, Some((right_slice, right_index))) => {
                self.next_right();
                Some(remainder(right_slice, right_index))
            }
            (Some((left_slice, mut left_index)), Some((right_slice, mut right_index))) => {
                // Grow the current run while its side keeps winning (ties stay on the current
                // side); when the other side wins, emit the run and switch sides.
                while (left_index < left_slice.1 + left_slice.2)
                    && (right_index < right_slice.1 + right_slice.2)
                {
                    match (
                        (self.comparator)(left_slice.0, left_index, right_slice.0, right_index),
                        self.current_is_left,
                    ) {
                        (Ordering::Less, true) | (Ordering::Equal, true) => {
                            self.current_len += 1;
                            left_index += 1;
                        }
                        (Ordering::Greater, false) | (Ordering::Equal, false) => {
                            self.current_len += 1;
                            right_index += 1;
                        }
                        (Ordering::Less, false) => {
                            let start = self.current_start;
                            let len = self.current_len;
                            self.current_is_left = true;
                            self.current_len = 0;
                            self.current_start = left_index;
                            if len > 0 {
                                // Save both cursors so the next call resumes from here.
                                self.left = Some((left_slice, left_index));
                                self.right = Some((right_slice, right_index));
                                return Some((right_slice.0, start, len));
                            }
                        }
                        (Ordering::Greater, true) => {
                            let start = self.current_start;
                            let len = self.current_len;
                            self.current_is_left = false;
                            self.current_len = 0;
                            self.current_start = right_index;
                            if len > 0 {
                                // Save both cursors so the next call resumes from here.
                                self.left = Some((left_slice, left_index));
                                self.right = Some((right_slice, right_index));
                                return Some((left_slice.0, start, len));
                            }
                        }
                    }
                }
                // One slice ran out: emit the pending run and advance that side.
                let start = self.current_start;
                let len = self.current_len;
                if left_index == left_slice.1 + left_slice.2 {
                    self.current_len = 0;
                    self.next_left();
                    Some((left_slice.0, start, len))
                } else {
                    debug_assert_eq!(right_index, right_slice.1 + right_slice.2);
                    self.current_len = 0;
                    self.next_right();
                    Some((right_slice.0, start, len))
                }
            }
        }
    }
}
/// Returns an iterator lazily merging the sorted slice streams `lhs` and `rhs` using
/// `comparator`; see [`MergeSortSlices`].
pub fn merge_sort_slices<'a, L, R>(
    lhs: L,
    rhs: R,
    comparator: &'a Comparator,
) -> MergeSortSlices<'a, L, R>
where
    L: Iterator<Item = &'a MergeSlice>,
    R: Iterator<Item = &'a MergeSlice>,
{
    MergeSortSlices::new(lhs, rhs, comparator)
}
/// Row comparator called as `(left_index, left_row, right_index, right_row)`: compares row
/// `left_row` of array `left_index` against row `right_row` of array `right_index`.
type Comparator<'a> = Box<dyn Fn(usize, usize, usize, usize) -> Ordering + 'a>;
/// Validity check for the rows of one array: `false` means the row is treated as null.
type IsValid<'a> = Box<dyn Fn(usize) -> bool + 'a>;
/// Builds a multi-column row [`Comparator`] for the arrays in `pairs`, comparing values
/// with the default [`build_compare`].
///
/// # Errors
/// Propagates any error raised while building a per-column value comparator.
pub fn build_comparator<'a>(
    pairs: &'a [(&'a [&'a dyn Array], &SortOptions)],
) -> Result<Comparator<'a>> {
    let default_value_compare = &build_compare;
    build_comparator_impl(pairs, default_value_compare)
}
/// Builds a multi-column row [`Comparator`] over the arrays in `pairs`.
///
/// Each element of `pairs` is one sort column: a list of arrays (the same number in every
/// column; index `i` refers to the `i`-th input) plus that column's [`SortOptions`]. Rows are
/// compared column by column until one differs. Valid values use `build_compare_fn`
/// (reversed when `descending`). Nulls sort before valid values if `nulls_first`, else after.
/// Two nulls compare equal.
///
/// Value comparators are prebuilt for every array pair `(i, j)` with `i < j`. When the
/// returned closure is called with `left_index > right_index`, the swapped comparison is
/// evaluated and its result reversed.
///
/// # Panics
/// Panics if `pairs` is empty. The returned closure panics when both indices are equal or
/// out of range.
///
/// # Errors
/// Propagates the first error returned by `build_compare_fn`.
pub fn build_comparator_impl<'a>(
    pairs: &'a [(&'a [&'a dyn Array], &SortOptions)],
    build_compare_fn: &dyn Fn(&dyn Array, &dyn Array) -> Result<DynComparator>,
) -> Result<Comparator<'a>> {
    /// Lexicographically compares `left_row` and `right_row` using the prebuilt
    /// per-column `(left validity, right validity, value comparator)` triples in `columns`.
    fn compare_rows(
        pairs: &[(&[&dyn Array], &SortOptions)],
        columns: &[(IsValid<'_>, IsValid<'_>, DynComparator)],
        left_row: usize,
        right_row: usize,
    ) -> Ordering {
        for ((_, options), (l_is_valid, r_is_valid, value_comparator)) in
            pairs.iter().zip(columns.iter())
        {
            let result = match ((l_is_valid)(left_row), (r_is_valid)(right_row)) {
                (true, true) => {
                    let result = (value_comparator)(left_row, right_row);
                    if options.descending {
                        result.reverse()
                    } else {
                        result
                    }
                }
                (false, true) => {
                    if options.nulls_first {
                        Ordering::Less
                    } else {
                        Ordering::Greater
                    }
                }
                (true, false) => {
                    if options.nulls_first {
                        Ordering::Greater
                    } else {
                        Ordering::Less
                    }
                }
                (false, false) => Ordering::Equal,
            };
            if result != Ordering::Equal {
                return result;
            }
        }
        Ordering::Equal
    }

    assert!(
        !pairs.is_empty(),
        "build_comparator requires at least one (arrays, options) pair"
    );
    let indices_pairs = (0..pairs[0].0.len())
        .combinations(2)
        .map(|indices| (indices[0], indices[1]));
    let data = indices_pairs
        .map(|(lhs_index, rhs_index)| {
            let multi_column_comparator = pairs
                .iter()
                .map(move |(arrays, _)| {
                    Ok((
                        Box::new(move |row| arrays[lhs_index].is_valid(row)) as IsValid<'a>,
                        Box::new(move |row| arrays[rhs_index].is_valid(row)) as IsValid<'a>,
                        build_compare_fn(arrays[lhs_index], arrays[rhs_index])?,
                    ))
                })
                .collect::<Result<Vec<_>>>()?;
            Ok(((lhs_index, rhs_index), multi_column_comparator))
        })
        .collect::<Result<AHashMap<(usize, usize), Vec<(IsValid, IsValid, DynComparator)>>>>()?;

    let cmp = move |left_index: usize, left_row: usize, right_index: usize, right_row: usize| {
        if let Some(columns) = data.get(&(left_index, right_index)) {
            compare_rows(pairs, columns, left_row, right_row)
        } else if let Some(columns) = data.get(&(right_index, left_index)) {
            // Only `(i, j)` with `i < j` is prebuilt; every per-column comparison is
            // antisymmetric, so comparing the swapped operands and reversing is exact.
            compare_rows(pairs, columns, right_row, left_row).reverse()
        } else {
            panic!(
                "no comparator for array indices ({}, {})",
                left_index, right_index
            )
        }
    };
    Ok(Box::new(cmp))
}