use crate::{
insertion_sort::InsertionHole,
maybe_grow,
par::insertion_sort::insertion_sort_shift_left,
partition::{break_patterns, reverse},
};
use core::{
cmp::{
self,
Ordering::{Equal, Greater, Less},
},
mem::{self, ManuallyDrop, MaybeUninit},
ptr,
};
use ndarray::{ArrayView1, ArrayViewMut1, Axis, IndexLonger, s};
/// Views of at most this many elements are no longer partitioned by the
/// selection routines in this file; they are finished with an insertion sort.
const MAX_INSERTION: usize = 10;
/// Partitions `v` at every position listed in `indices`, possibly in parallel.
///
/// Each entry of `indices` has `offset` subtracted before it is applied to
/// `v`, and the mutable reference to the element selected for `indices[i]` is
/// written to `values[i]`.
///
/// The middle index is handled first; the indices before it are then resolved
/// in the part of `v` left of the selected element and the indices after it in
/// the part to the right. Both halves run through `rayon::join` unless there
/// is nothing on the left or the left part is at most `MAX_SEQUENTIAL`
/// elements long, in which case the left half is handled by a plain recursive
/// call and the right half by the next loop iteration.
///
/// NOTE(review): this presumably expects `indices` to be ascending, unique and
/// not smaller than `offset`, and `values` to be at least as long as
/// `indices` — confirm against the callers.
pub fn par_partition_at_indices<'a, T, F>(
    mut v: ArrayViewMut1<'a, T>,
    mut offset: usize,
    mut indices: ArrayView1<usize>,
    mut values: &mut [MaybeUninit<&'a mut T>],
    is_less: &F,
) where
    T: Send,
    F: Fn(&T, &T) -> bool + Sync,
{
    const MAX_SEQUENTIAL: usize = 2000;

    while !indices.is_empty() {
        // Pick the middle index so both halves get about the same number of indices.
        let middle = indices.len() / 2;
        let (lower_indices, tail_indices) = indices.split_at(Axis(0), middle);
        let (selected, upper_indices) = tail_indices.split_at(Axis(0), 1);
        let pivot = *selected.index(0);
        // Position of the selected element relative to the current view.
        let local = pivot - offset;

        let (lower, element, upper) = partition_at_index(v, local, is_less);
        values[middle].write(element);
        let (lower_values, tail_values) = values.split_at_mut(middle);
        let upper_values = &mut tail_values[1..];

        if middle != 0 && local > MAX_SEQUENTIAL {
            // Both halves have work and the left part is large: fork.
            rayon::join(
                || {
                    maybe_grow(|| {
                        par_partition_at_indices(
                            lower,
                            offset,
                            lower_indices,
                            lower_values,
                            is_less,
                        )
                    })
                },
                || {
                    maybe_grow(|| {
                        par_partition_at_indices(
                            upper,
                            pivot + 1,
                            upper_indices,
                            upper_values,
                            is_less,
                        )
                    })
                },
            );
            return;
        }

        // Sequential path: recurse into the left half, iterate on the right half.
        maybe_grow(|| par_partition_at_indices(lower, offset, lower_indices, lower_values, is_less));
        v = upper;
        offset = pivot + 1;
        indices = upper_indices;
        values = upper_values;
    }
}
/// Reorders `v` so that the element at `index` is the one that would be there
/// if `v` were sorted by `is_less`, then splits `v` around it.
///
/// Returns `(left, element, right)`: the view before `index`, a mutable
/// reference to the element at `index`, and the view after it.
///
/// # Panics
///
/// Panics if `index >= v.len()`.
pub fn partition_at_index<'a, T, F>(
    mut v: ArrayViewMut1<'a, T>,
    index: usize,
    is_less: &F,
) -> (ArrayViewMut1<'a, T>, &'a mut T, ArrayViewMut1<'a, T>)
where
    F: Fn(&T, &T) -> bool,
{
    let len = v.len();
    if index >= len {
        panic!(
            "partition_at_index index {} greater than length of slice {}",
            index, len
        );
    }

    // Zero-sized element types are not reordered at all; the view is only split.
    if mem::size_of::<T>() != 0 {
        let last = len - 1;
        match index {
            // Selecting the last position only needs the maximum moved there.
            i if i == last => {
                let (largest, _) = v.iter().enumerate().max_by(from_is_less(is_less)).unwrap();
                v.swap(largest, index);
            }
            // Selecting the first position only needs the minimum moved there.
            0 => {
                let (smallest, _) = v.iter().enumerate().min_by(from_is_less(is_less)).unwrap();
                v.swap(smallest, index);
            }
            _ => partition_at_index_loop(v.view_mut(), index, is_less, None),
        }
    }

    let (before, rest) = v.split_at(Axis(0), index);
    let (selected, after) = rest.split_at(Axis(0), 1);
    (before, selected.index(0), after)
}
/// Turns a strict "less than" predicate into a comparator over enumerated
/// items, as expected by `Iterator::max_by` / `Iterator::min_by`.
///
/// The position part of each pair is ignored. The comparator never reports
/// `Equal`: it yields `Less` when `is_less(x, y)` holds and `Greater`
/// otherwise.
fn from_is_less<T>(
    is_less: &impl Fn(&T, &T) -> bool,
) -> impl Fn(&(usize, &T), &(usize, &T)) -> cmp::Ordering + '_ {
    move |&(_, lhs), &(_, rhs)| match is_less(lhs, rhs) {
        true => cmp::Ordering::Less,
        false => cmp::Ordering::Greater,
    }
}
/// Quickselect loop behind `partition_at_index`.
///
/// Repeatedly partitions `v` around a chosen pivot and narrows `v` to the side
/// containing `index`, adjusting `index` to stay relative to the narrowed
/// view. `pred`, when present, is the pivot of the partition that produced the
/// current view as its right-hand side.
///
/// Views of at most `MAX_INSERTION` elements are finished by insertion sort.
/// After 16 unbalanced partitions the work is handed to `median_of_medians`.
fn partition_at_index_loop<'a, T, F>(
    mut v: ArrayViewMut1<'a, T>,
    mut index: usize,
    is_less: &F,
    mut pred: Option<&'a T>,
) where
    F: Fn(&T, &T) -> bool,
{
    // Number of unbalanced partitions still tolerated before falling back.
    let mut limit = 16;
    // Whether the previous partition split the view reasonably evenly.
    let mut was_balanced = true;

    loop {
        let len = v.len();
        if len <= MAX_INSERTION {
            if len != 0 {
                insertion_sort_shift_left(v.view_mut(), 1, is_less);
            }
            return;
        }
        if limit == 0 {
            median_of_medians(v.view_mut(), is_less, index);
            return;
        }
        if !was_balanced {
            // Shuffle a few elements in the hope of a better pivot this time.
            break_patterns(v.view_mut());
            limit -= 1;
        }

        let (candidate, _) = choose_pivot(v.view_mut(), is_less);

        // If the predecessor is not smaller than the pivot, split off the run
        // of elements that are not greater than the pivot instead.
        let pivot_matches_pred = pred.map_or(false, |p| !is_less(p, &v[candidate]));
        if pivot_matches_pred {
            let mid = partition_equal(v.view_mut(), candidate, is_less);
            if mid > index {
                return;
            }
            v = v.split_at(Axis(0), mid).1;
            index -= mid;
            pred = None;
            continue;
        }

        let (mid, _) = partition(v.view_mut(), candidate, is_less);
        was_balanced = cmp::min(mid, len - mid) >= len / 8;

        let (lower, rest) = v.split_at(Axis(0), mid);
        let (separator, upper) = rest.split_at(Axis(0), 1);
        if mid == index {
            // The pivot landed exactly on the requested position.
            return;
        }
        if mid < index {
            let separator: &'a T = separator.index(0);
            v = upper;
            index -= mid + 1;
            pred = Some(separator);
        } else {
            v = lower;
        }
    }
}
/// Fallback selection used once `partition_at_index_loop` has exhausted its
/// budget: places the `k`-th element of `v` by repeatedly partitioning around
/// the pivot returned by `median_of_ninthers`.
///
/// Debug builds assert that `k < v.len()` and that `T` is not zero-sized.
fn median_of_medians<T, F: Fn(&T, &T) -> bool>(
    mut v: ArrayViewMut1<'_, T>,
    is_less: &F,
    mut k: usize,
) {
    debug_assert!(k < v.len());
    debug_assert!(mem::size_of::<T>() != 0);

    loop {
        let len = v.len();
        if len <= MAX_INSERTION {
            if len > 1 {
                insertion_sort_shift_left(v.view_mut(), 1, is_less);
            }
            return;
        }

        // The two extreme positions are resolved by a linear scan before
        // `median_of_ninthers` is ever called.
        if k == len - 1 {
            let (largest, _) = v.iter().enumerate().max_by(from_is_less(is_less)).unwrap();
            v.swap(largest, k);
            return;
        }
        if k == 0 {
            let (smallest, _) = v.iter().enumerate().min_by(from_is_less(is_less)).unwrap();
            v.swap(smallest, k);
            return;
        }

        let p = median_of_ninthers(v.view_mut(), is_less);
        if p == k {
            return;
        }
        if p > k {
            // Target lies left of the pivot.
            v = v.split_at(Axis(0), p).0;
        } else {
            // Target lies right of the pivot; re-base `k` onto the right part.
            v = v.split_at(Axis(0), p + 1).1;
            k -= p + 1;
        }
    }
}
/// Chooses a pivot by running `ninther` over a sample around the middle of
/// `v`, selecting the median of the resulting middle band with
/// `median_of_medians`, and finally partitioning `v` around it.
///
/// The band is `len / 12`, `len / 64` or `len / 1024` elements wide depending
/// on the length of `v`. Returns the pivot position reported by `partition`.
fn median_of_ninthers<T, F: Fn(&T, &T) -> bool>(mut v: ArrayViewMut1<'_, T>, is_less: &F) -> usize {
    let len = v.len();
    // `saturating_mul` keeps the threshold computable on narrow `usize`.
    let divisor = if len <= 1024 {
        12
    } else if len <= 128_usize.saturating_mul(1024) {
        64
    } else {
        1024
    };
    let frac = len / divisor;

    let pivot = frac / 2;
    let lo = len / 2 - pivot;
    let hi = lo + frac;
    let gap = (len - 9 * frac) / 4;
    let left_start = lo - 4 * frac - gap;
    let right_start = hi + gap;

    for (step, mid) in (lo..hi).enumerate() {
        // The outer sample positions advance by three per middle element.
        let left = left_start + 3 * step;
        let right = right_start + 3 * step;
        ninther(
            v.view_mut(),
            is_less,
            [
                left,
                mid - frac,
                right,
                left + 1,
                mid,
                right + 1,
                left + 2,
                mid + frac,
                right + 2,
            ],
        );
    }

    median_of_medians(v.slice_mut(s![lo..lo + frac]), is_less, pivot);
    partition(v, lo + pivot, is_less).0
}
/// Moves the median of the nine elements at positions `n` into the fifth
/// position (`n[4]`) by at most one swap.
///
/// The nine positions are treated as three triples: `n[0..3]`, `n[3..6]` and
/// `n[6..9]`.
fn ninther<T, F: Fn(&T, &T) -> bool>(mut v: ArrayViewMut1<'_, T>, is_less: &F, n: [usize; 9]) {
    let [a, b, c, mut d, e, mut f, g, h, i] = n;

    // Medians of the first and last triple, ordered so that `low <= high`.
    let mut low = median_idx(v.view(), is_less, a, b, c);
    let mut high = median_idx(v.view(), is_less, g, h, i);
    if is_less(&v[high], &v[low]) {
        mem::swap(&mut low, &mut high);
    }
    // Order the outer pair of the middle triple.
    if is_less(&v[f], &v[d]) {
        mem::swap(&mut d, &mut f);
    }

    if !is_less(&v[e], &v[d]) {
        if !is_less(&v[f], &v[e]) {
            // `e` lies between `d` and `f`; it only has to be clamped between
            // the medians of the two outer triples.
            if is_less(&v[e], &v[low]) {
                v.swap(e, low);
            } else if is_less(&v[high], &v[e]) {
                v.swap(e, high);
            }
            return;
        }
        d = f;
    }

    // `d` is now the median candidate of the middle triple; clamp it as well.
    if is_less(&v[d], &v[low]) {
        d = low;
    } else if is_less(&v[high], &v[d]) {
        d = high;
    }
    v.swap(d, e);
}
fn median_idx<T, F: Fn(&T, &T) -> bool>(
v: ArrayView1<'_, T>,
is_less: &F,
mut a: usize,
b: usize,
mut c: usize,
) -> usize {
if is_less(&v[c], &v[a]) {
mem::swap(&mut a, &mut c);
}
if is_less(&v[c], &v[b]) {
return c;
}
if is_less(&v[b], &v[a]) {
return a;
}
b
}
/// Partitions `v` into elements that are not greater than `v[pivot]` followed
/// by elements that are greater than `v[pivot]`.
///
/// Returns the length of the first group, counting the pivot itself.
///
/// NOTE(review): the only caller in this file invokes it when the predecessor
/// is not smaller than the pivot, so `v` presumably holds no element smaller
/// than the pivot — confirm for other callers.
pub fn partition_equal<T, F>(mut v: ArrayViewMut1<'_, T>, pivot: usize, is_less: &F) -> usize
where
F: Fn(&T, &T) -> bool,
{
// Move the pivot to the front and split it off from the rest of the view.
v.swap(0, pivot);
let (pivot, mut v) = v.split_at(Axis(0), 1);
let pivot = pivot.index(0);
// Work on a bitwise copy of the pivot; `ManuallyDrop` keeps the copy from
// being dropped. NOTE(review): `InsertionHole` is defined elsewhere and
// presumably writes the copy back into the view when the guard is dropped.
let tmp = ManuallyDrop::new(unsafe { ptr::read(pivot) });
let _pivot_guard = unsafe { InsertionHole::new(&*tmp, pivot) };
let pivot = &*tmp;
let len = v.len();
// Nothing but the pivot: this path reports 0 rather than 1.
if len == 0 {
return 0;
}
let mut l = 0;
let mut r = len;
loop {
// The unchecked accesses below stay within `0..len` because `l < r <= len`
// holds whenever an element is read or swapped.
unsafe {
// Advance `l` to the first element that is greater than the pivot.
while l < r && !is_less(pivot, v.view().uget(l)) {
l += 1;
}
// Retreat `r` to the last element that is not greater than the pivot.
loop {
r -= 1;
if l >= r || !is_less(pivot, v.view().uget(r)) {
break;
}
}
// The scans met: everything is in place.
if l >= r {
break;
}
// Exchange the out-of-order pair and continue scanning.
v.uswap(l, r);
l += 1;
}
}
// `l` elements of the remainder are not greater than the pivot; add the pivot.
l + 1
}
/// Partitions `v` into elements smaller than `v[pivot]` followed by elements
/// that are not smaller, with the pivot placed between the two groups.
///
/// Returns a tuple of:
///
/// 1. the final position of the pivot, which equals the number of elements
///    smaller than it;
/// 2. `true` if the initial scans from both ends already met, i.e. no element
///    had to be moved by `partition_in_blocks`.
pub fn partition<T, F>(mut v: ArrayViewMut1<'_, T>, pivot: usize, is_less: &F) -> (usize, bool)
where
F: Fn(&T, &T) -> bool,
{
let (mid, was_partitioned) = {
let mut v = v.view_mut();
// Move the pivot to the front and split it off from the rest of the view.
v.swap(0, pivot);
let (pivot, mut v) = v.split_at(Axis(0), 1);
let pivot = pivot.index(0);
// Work on a bitwise copy of the pivot; `ManuallyDrop` keeps the copy from
// being dropped. NOTE(review): `InsertionHole` is defined elsewhere and
// presumably writes the copy back into the view when this scope ends.
let tmp = ManuallyDrop::new(unsafe { ptr::read(pivot) });
let _pivot_guard = unsafe { InsertionHole::new(&*tmp, pivot) };
let pivot = &*tmp;
let mut l = 0;
let mut r = v.len();
// The unchecked reads are guarded by `l < r <= v.len()`.
unsafe {
// Skip the prefix that is already smaller than the pivot.
while l < r && is_less(v.view().uget(l), pivot) {
l += 1;
}
// Skip the suffix that is already not smaller than the pivot.
while l < r && !is_less(v.view().uget(r - 1), pivot) {
r -= 1;
}
}
(
// Only the unsettled middle `l..r` goes through the block partitioning.
l + partition_in_blocks(v.slice_mut(s![l..r]), pivot, is_less),
l >= r,
)
};
// Put the pivot between the two groups.
v.swap(0, mid);
(mid, was_partitioned)
}
/// Partitions `v` into elements for which `is_less(elem, pivot)` holds,
/// followed by the remaining elements, and returns the length of the first
/// group.
///
/// The view is scanned in blocks of at most `BLOCK` elements from both ends.
/// For each block the offsets of the misplaced elements are collected into a
/// small `u8` buffer, and the misplaced elements of a left and a right block
/// are then exchanged pairwise.
fn partition_in_blocks<T, F>(mut v: ArrayViewMut1<'_, T>, pivot: &T, is_less: &F) -> usize
where
F: Fn(&T, &T) -> bool,
{
if v.is_empty() {
return 0;
}
// Block length; offsets within a block are stored as `u8`, so it must not exceed 256.
const BLOCK: usize = 128;
// Left side: `l` is the start of the current left block, `block_l` its
// length, and `start_l..end_l` the not yet consumed part of `offsets_l`.
let mut l = 0; let mut block_l = BLOCK;
let mut start_l = ptr::null_mut();
let mut end_l = ptr::null_mut();
let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK];
// Right side: `r` is one past the end of the current right block; the
// remaining variables mirror the left side.
let mut r = v.len(); let mut block_r = BLOCK;
let mut start_r = ptr::null_mut();
let mut end_r = ptr::null_mut();
let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK];
// Number of `T`-sized steps between two pointers with `l <= r`.
fn ptr_width<T>(l: *mut T, r: *mut T) -> usize {
assert!(mem::size_of::<T>() > 0);
#[cfg(miri)]
{
(r.addr() - l.addr()) / mem::size_of::<T>()
}
#[cfg(not(miri))]
{
(r as usize - l as usize) / mem::size_of::<T>()
}
}
// Number of elements between two positions with `l <= r`.
fn width(l: usize, r: usize) -> usize {
r - l
}
loop {
// This is the last round once at most two blocks' worth of elements remain.
let is_done = width(l, r) <= 2 * BLOCK;
if is_done {
// Shrink the blocks so that together they cover exactly the remainder.
let mut rem = width(l, r);
// A block with unconsumed offsets still occupies `BLOCK` elements.
if start_l < end_l || start_r < end_r {
rem -= BLOCK;
}
if start_l < end_l {
block_r = rem;
} else if start_r < end_r {
block_l = rem;
} else {
// Both offset buffers are empty: split the remainder between the sides.
block_l = rem / 2;
block_r = rem - block_l;
}
debug_assert!(block_l <= BLOCK && block_r <= BLOCK);
debug_assert!(width(l, r) == block_l + block_r);
}
if start_l == end_l {
// The left offsets are used up: scan a new left block.
#[cfg(miri)]
{
start_l = MaybeUninit::slice_as_mut_ptr(&mut offsets_l);
}
#[cfg(not(miri))]
{
start_l = offsets_l[0].as_mut_ptr();
}
end_l = start_l;
let mut elem = l;
for i in 0..block_l {
unsafe {
// The offset is always written, but `end_l` only advances when the
// element is not less than the pivot, so only those offsets are kept.
*end_l = i as u8;
end_l = end_l.add(!is_less(v.view_mut().index(elem), pivot) as usize);
elem += 1; }
}
}
if start_r == end_r {
// The right offsets are used up: scan a new right block, back to front.
#[cfg(miri)]
{
start_r = MaybeUninit::slice_as_mut_ptr(&mut offsets_r);
}
#[cfg(not(miri))]
{
start_r = offsets_r[0].as_mut_ptr();
}
end_r = start_r;
let mut elem = r;
for i in 0..block_r {
unsafe {
elem -= 1;
// Same scheme as on the left, keeping offsets of elements that are less
// than the pivot; offsets count backwards from `r`.
*end_r = i as u8;
end_r = end_r.add(is_less(v.view_mut().index(elem), pivot) as usize);
}
}
}
// Number of misplaced pairs that can be exchanged in this round.
let count = cmp::min(ptr_width(start_l, end_l), ptr_width(start_r, end_r));
if count > 0 {
// Pointer to the misplaced left element named by the current left offset.
macro_rules! left {
() => {
v.view_mut().index(l + usize::from(*start_l)) as *mut T };
}
// Pointer to the misplaced right element named by the current right offset.
macro_rules! right {
() => {
v.view_mut().index(r - (usize::from(*start_r) + 1)) as *mut T };
}
// Exchange the pairs as one cyclic permutation: the first left element is
// held in `tmp`, each following copy fills the slot vacated by the previous
// one, and `tmp` finally goes into the last vacated right slot.
unsafe {
let tmp = ptr::read(left!());
ptr::copy_nonoverlapping(right!(), left!(), 1);
for _ in 1..count {
start_l = start_l.add(1);
ptr::copy_nonoverlapping(left!(), right!(), 1);
start_r = start_r.add(1);
ptr::copy_nonoverlapping(right!(), left!(), 1);
}
ptr::copy_nonoverlapping(&tmp, right!(), 1);
// `tmp` was copied back into the view, so it must not be dropped here.
mem::forget(tmp);
start_l = start_l.add(1);
start_r = start_r.add(1);
}
}
// A block whose offsets are all consumed is fully in place: step past it.
if start_l == end_l {
l += block_l; }
if start_r == end_r {
r -= block_r; }
if is_done {
break;
}
}
// At most one side still has unconsumed offsets; move those elements across
// the boundary one at a time.
if start_l < end_l {
debug_assert_eq!(width(l, r), block_l);
// Remaining misplaced left elements go to the far right end, last first.
while start_l < end_l {
unsafe {
end_l = end_l.sub(1);
v.uswap(l + usize::from(*end_l), r - 1); r -= 1; }
}
width(0, r) } else if start_r < end_r {
debug_assert_eq!(width(l, r), block_r);
// Remaining misplaced right elements go to the far left end, last first.
while start_r < end_r {
unsafe {
end_r = end_r.sub(1);
v.uswap(l, r - (usize::from(*end_r) + 1)); l += 1; }
}
width(0, l) } else {
// Nothing left to move: `l` marks the boundary.
width(0, l) }
}
/// Picks a pivot position in `v` and reports whether `v` looks already sorted.
///
/// Three candidate positions at one, two and three quarters of the length are
/// ordered by value. For views of at least `SHORTEST_MEDIAN_OF_MEDIANS`
/// elements each candidate is first replaced by the median of itself and its
/// two neighbours. Views shorter than 8 elements are not inspected at all.
///
/// Returns `(pivot, likely_sorted)`. If ordering the candidates took
/// `MAX_SWAPS` swaps, `v` is reversed in place and the mirrored position is
/// returned together with `true`; otherwise the flag is `true` only when no
/// swap was needed.
pub fn choose_pivot<T, F>(v: ArrayViewMut1<'_, T>, is_less: &F) -> (usize, bool)
where
    F: Fn(&T, &T) -> bool,
{
    const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
    const MAX_SWAPS: usize = 4 * 3;

    let len = v.len();
    let quarter = len / 4;
    let mut a = quarter;
    let mut b = quarter * 2;
    let mut c = quarter * 3;
    let mut swaps = 0;

    if len >= 8 {
        let view = v.view();

        // Orders two positions by the values they refer to, counting swaps.
        let mut sort2 = |lo: &mut usize, hi: &mut usize| {
            // SAFETY: every position passed in is derived from `len / 4` multiples
            // (optionally +/- 1 when `len >= 50`) and is therefore below `len`.
            let out_of_order = unsafe { is_less(view.uget(*hi), view.uget(*lo)) };
            if out_of_order {
                mem::swap(lo, hi);
                swaps += 1;
            }
        };

        // Orders three positions by value with three pairwise passes.
        let mut sort3 = |x: &mut usize, y: &mut usize, z: &mut usize| {
            sort2(x, y);
            sort2(y, z);
            sort2(x, y);
        };

        if len >= SHORTEST_MEDIAN_OF_MEDIANS {
            // Replaces a position by the median of it and its two neighbours.
            let mut sort_adjacent = |center: &mut usize| {
                let at = *center;
                sort3(&mut (at - 1), center, &mut (at + 1));
            };
            sort_adjacent(&mut a);
            sort_adjacent(&mut b);
            sort_adjacent(&mut c);
        }

        sort3(&mut a, &mut b, &mut c);
    }

    if swaps >= MAX_SWAPS {
        // Every comparison triggered a swap: reverse the view and mirror the pivot.
        reverse(v);
        return (len - 1 - b, true);
    }
    (b, swaps == 0)
}
#[cfg(test)]
mod test {
    use super::{par_partition_at_indices, partition_at_index};
    use crate::{par::quick_sort::par_quick_sort, partition_dedup::partition_dedup};
    use ndarray::arr1;
    use quickcheck::TestResult;
    use quickcheck_macros::quickcheck;

    /// Selecting at several indices at once must yield, for every index, the
    /// element a full sort would place there.
    #[cfg_attr(miri, ignore)]
    #[quickcheck]
    fn at_indices(xs: Vec<u32>) -> TestResult {
        if xs.is_empty() {
            return TestResult::discard();
        }
        let mut array = arr1(&xs);
        let mut sorted = arr1(&xs);
        par_quick_sort(sorted.view_mut(), u32::lt);

        // Candidate positions, sorted and with duplicates moved out of the way.
        let mut indices = arr1(&[xs.len() - 1, xs.len() / 2, xs.len() / 3, xs.len() / 4, 0]);
        par_quick_sort(indices.view_mut(), usize::lt);
        let (indices, _duplicates) = partition_dedup(indices.view_mut(), |a, b| a.eq(&b));
        if indices.iter().any(|&index| index >= xs.len()) {
            return TestResult::discard();
        }

        let mut collection: Vec<&mut u32> = Vec::with_capacity(indices.len());
        // `Vec::with_capacity` may reserve more than requested, so take exactly
        // one slot per index instead of asserting on the spare capacity.
        let values = &mut collection.spare_capacity_mut()[..indices.len()];
        par_partition_at_indices(array.view_mut(), 0, indices.view(), values, &u32::lt);
        // SAFETY: `par_partition_at_indices` writes one reference per index, so
        // the first `indices.len()` slots of the spare capacity are initialised.
        unsafe { collection.set_len(collection.len() + indices.len()) };

        for (index, value) in indices.into_iter().zip(collection.into_iter()) {
            assert_eq!(*value, sorted[*index]);
        }
        TestResult::passed()
    }

    /// Selecting at a single index must leave nothing greater on its left and
    /// nothing smaller on its right.
    #[quickcheck]
    fn at_index(xs: Vec<u32>) -> TestResult {
        if xs.is_empty() {
            return TestResult::discard();
        }
        let mut array = arr1(&xs);
        let (left, value, right) = partition_at_index(array.view_mut(), xs.len() / 3, &u32::lt);
        for smaller in left {
            assert!(smaller <= value);
        }
        for greater in right {
            assert!(value <= greater);
        }
        TestResult::passed()
    }
}