use crate::{insertion_sort::insertion_sort_shift_left, maybe_grow};
use core::{
cmp::{
self,
Ordering::{Equal, Greater, Less},
},
mem::{self, ManuallyDrop, MaybeUninit},
ptr,
};
use ndarray::{ArrayView1, ArrayViewMut1, Axis, IndexLonger, s};
/// Slices with at most this many elements are finished with insertion sort instead of being
/// partitioned further (see `partition_at_index_loop` and `median_of_medians`).
const MAX_INSERTION: usize = 10;
/// Places every element whose position is listed in `indices` at its sorted position within `v`
/// and reports each placed element to `collection` as `(position, &mut element)`.
///
/// `offset` is the absolute position of `v[0]`. The entries of `indices` are absolute positions,
/// so a top-level caller passes `0`.
///
/// The entries are expected to be sorted ascending with no duplicates; the tests sort and
/// deduplicate them first. A repeated or out-of-order entry may produce an out-of-range index
/// and a panic.
///
/// Elements are reported in ascending position order. The lower half of `indices` is handled
/// recursively before the middle entry is reported, and the upper half is handled by the next
/// loop iteration. The recursion goes through `maybe_grow`, presumably to guard against stack
/// exhaustion.
pub fn partition_at_indices<'a, T, E, F>(
    mut v: ArrayViewMut1<'a, T>,
    mut offset: usize,
    mut indices: ArrayView1<usize>,
    collection: &mut E,
    is_less: &mut F,
) where
    E: Extend<(usize, &'a mut T)>,
    F: FnMut(&T, &T) -> bool,
{
    while !indices.is_empty() {
        let middle_at = indices.len() / 2;
        let (lower, rest) = indices.split_at(Axis(0), middle_at);
        let (middle, upper) = rest.split_at(Axis(0), 1);
        let target = middle[0];

        // Select `target`, splitting `v` into the parts below and above it.
        let (below, element, above) = partition_at_index(v, target - offset, is_less);
        maybe_grow(|| partition_at_indices(below, offset, lower, collection, is_less));
        collection.extend(core::iter::once((target, element)));

        // Continue with the part of `v` to the right of `target`.
        v = above;
        offset = target + 1;
        indices = upper;
    }
}
/// Reorders `v` so that position `index` holds the element a full sort by `is_less` would put
/// there. Elements before it are not greater, and elements after it are not less.
///
/// Returns `(elements before index, &mut element at index, elements after index)`.
///
/// # Panics
///
/// Panics if `index >= v.len()`.
pub fn partition_at_index<'a, T, F>(
    mut v: ArrayViewMut1<'a, T>,
    index: usize,
    is_less: &mut F,
) -> (ArrayViewMut1<'a, T>, &'a mut T, ArrayViewMut1<'a, T>)
where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();
    if index >= len {
        panic!(
            "partition_at_index index {} greater than length of slice {}",
            index, len
        );
    }

    // Zero-sized elements are indistinguishable, so there is nothing to reorder.
    if mem::size_of::<T>() != 0 {
        let last = len - 1;
        if index == last {
            // Selecting the maximum only needs a linear scan.
            let (max_at, _) = v.iter().enumerate().max_by(from_is_less(is_less)).unwrap();
            v.swap(max_at, last);
        } else if index == 0 {
            // Selecting the minimum only needs a linear scan.
            let (min_at, _) = v.iter().enumerate().min_by(from_is_less(is_less)).unwrap();
            v.swap(min_at, 0);
        } else {
            partition_at_index_loop(v.view_mut(), index, is_less, None);
        }
    }

    let (before, rest) = v.split_at(Axis(0), index);
    let (selected, after) = rest.split_at(Axis(0), 1);
    (before, selected.index(0), after)
}
/// Turns a strict "less than" predicate into a comparator over `(index, &element)` pairs, for
/// use with `Iterator::max_by` / `Iterator::min_by` on `iter().enumerate()`.
///
/// The comparator never reports `Equal`. When the left element is not less than the right one,
/// it answers `Greater`.
fn from_is_less<T>(
    is_less: &mut impl FnMut(&T, &T) -> bool,
) -> impl FnMut(&(usize, &T), &(usize, &T)) -> cmp::Ordering + '_ {
    move |&(_, lhs), &(_, rhs)| {
        let lhs_is_smaller = is_less(lhs, rhs);
        if lhs_is_smaller {
            cmp::Ordering::Less
        } else {
            cmp::Ordering::Greater
        }
    }
}
/// Quickselect loop behind `partition_at_index`. It repeatedly partitions `v` around a chosen
/// pivot and continues in the side containing position `index`, until that position holds its
/// sorted element.
///
/// `pred`, when present, is an element that is not greater than any element of `v`: the pivot
/// of an earlier round that `v` lies to the right of. When the new pivot is not greater than
/// `pred`, the run of elements equal to it is skipped in one step via `partition_equal`.
///
/// After an unbalanced split, the slice is shuffled with `break_patterns`. This is allowed 16
/// times. Once that budget is exhausted, `median_of_medians` finishes the selection. Slices of
/// at most `MAX_INSERTION` elements are simply insertion-sorted.
fn partition_at_index_loop<'a, T, F>(
    mut v: ArrayViewMut1<'a, T>,
    mut index: usize,
    is_less: &mut F,
    mut pred: Option<&'a T>,
) where
    F: FnMut(&T, &T) -> bool,
{
    let mut limit = 16;
    // Whether the previous partition split `v` reasonably evenly.
    let mut was_balanced = true;

    loop {
        let len = v.len();
        if len <= MAX_INSERTION {
            if len != 0 {
                insertion_sort_shift_left(v.view_mut(), 1, is_less);
            }
            return;
        }
        if limit == 0 {
            median_of_medians(v.view_mut(), is_less, index);
            return;
        }
        if !was_balanced {
            break_patterns(v.view_mut());
            limit -= 1;
        }

        let (pivot_at, _) = choose_pivot(v.view_mut(), is_less);

        // A pivot that is not greater than `pred` equals the minimum of `v`. Gather every
        // element equal to it at the front, then either stop or skip past them.
        if matches!(pred, Some(p) if !is_less(p, &v[pivot_at])) {
            let equal_len = partition_equal(v.view_mut(), pivot_at, is_less);
            if equal_len > index {
                return;
            }
            v = v.split_at(Axis(0), equal_len).1;
            index -= equal_len;
            pred = None;
            continue;
        }

        let (mid, _) = partition(v.view_mut(), pivot_at, is_less);
        was_balanced = cmp::min(mid, len - mid) >= len / 8;

        let (left, rest) = v.split_at(Axis(0), mid);
        let (pivot_view, right) = rest.split_at(Axis(0), 1);
        let pivot_ref: &'a T = pivot_view.index(0);

        if mid < index {
            // The target lies right of the pivot, and the pivot becomes the new lower bound.
            v = right;
            index -= mid + 1;
            pred = Some(pivot_ref);
        } else if mid > index {
            v = left;
        } else {
            // `partition` put the pivot at its final sorted position, which is `index`.
            return;
        }
    }
}
/// Deterministic selection: reorders `v` so that position `k` holds the element a full sort by
/// `is_less` would put there.
///
/// Each round picks a pivot with `median_of_ninthers`, which also partitions `v`, and continues
/// in the side that contains `k`. It is the fallback `partition_at_index_loop` switches to once
/// its budget of unbalanced partitions is spent. Requires `k < v.len()` and a non-zero-sized `T`.
fn median_of_medians<T, F: FnMut(&T, &T) -> bool>(
    mut v: ArrayViewMut1<'_, T>,
    is_less: &mut F,
    mut k: usize,
) {
    debug_assert!(k < v.len());
    debug_assert!(mem::size_of::<T>() != 0);

    loop {
        let len = v.len();
        if len <= MAX_INSERTION {
            if len >= 2 {
                insertion_sort_shift_left(v.view_mut(), 1, is_less);
            }
            return;
        }

        // The extreme positions are resolved directly with a linear scan.
        if k == len - 1 {
            let (max_at, _) = v.iter().enumerate().max_by(from_is_less(is_less)).unwrap();
            v.swap(max_at, k);
            return;
        }
        if k == 0 {
            let (min_at, _) = v.iter().enumerate().min_by(from_is_less(is_less)).unwrap();
            v.swap(min_at, k);
            return;
        }

        let p = median_of_ninthers(v.view_mut(), is_less);
        if p > k {
            v = v.split_at(Axis(0), p).0;
        } else if p < k {
            // `p < k < len`, so `p + 1` is a valid split point.
            v = v.split_at(Axis(0), p + 1).1;
            k -= p + 1;
        } else {
            return;
        }
    }
}
/// Chooses a pivot for `median_of_medians` and partitions `v` around it, returning the pivot's
/// final index.
///
/// A central window `lo..lo + frac` of `v` is filled with ninthers, placed there by `ninther`
/// from nine positions spread over `v`. The median of that window is then selected recursively
/// and used as the partitioning pivot.
///
/// When `v` has fewer than 12 elements, `frac` is 0 and the window is empty. This happens for
/// the 11-element slices `median_of_medians` can pass, since it only stops at `MAX_INSERTION`.
/// In that case the recursive selection is skipped. Calling `median_of_medians` with an empty
/// slice and `k = 0` would break its `k < v.len()` precondition and trip its `debug_assert!`;
/// without debug assertions that call was a no-op. The element at `v.len() / 2` is then used as
/// the pivot.
fn median_of_ninthers<T, F: FnMut(&T, &T) -> bool>(
    mut v: ArrayViewMut1<'_, T>,
    is_less: &mut F,
) -> usize {
    // `saturating_mul` keeps the threshold from overflowing where `usize` is narrower than
    // 32 bits.
    let frac = if v.len() <= 1024 {
        v.len() / 12
    } else if v.len() <= 128_usize.saturating_mul(1024) {
        v.len() / 64
    } else {
        v.len() / 1024
    };
    let pivot = frac / 2;
    let lo = v.len() / 2 - pivot;
    let hi = frac + lo;
    let gap = (v.len() - 9 * frac) / 4;
    let mut a = lo - 4 * frac - gap;
    let mut b = hi + gap;
    for i in lo..hi {
        ninther(
            v.view_mut(),
            is_less,
            [a, i - frac, b, a + 1, i, b + 1, a + 2, i + frac, b + 2],
        );
        a += 3;
        b += 3;
    }
    if frac > 0 {
        median_of_medians(v.slice_mut(s![lo..lo + frac]), is_less, pivot);
    }
    partition(v, lo + pivot, is_less).0
}
/// Moves the "ninther" of nine positions of `v` into position `e`.
///
/// `n` lists the positions `[a, b, c, d, e, f, g, h, i]`. The ninther is the median (by
/// `is_less`) of the three medians of `(a, b, c)`, `(d, e, f)` and `(g, h, i)`, and it is
/// swapped into `e`. Other elements among the nine may also be swapped along the way.
fn ninther<T, F: FnMut(&T, &T) -> bool>(
mut v: ArrayViewMut1<'_, T>,
is_less: &mut F,
n: [usize; 9],
) {
let [a, mut b, c, mut d, e, mut f, g, mut h, i] = n;
// Medians of the outer triples, tracked by index.
b = median_idx(v.view(), is_less, a, b, c);
h = median_idx(v.view(), is_less, g, h, i);
// Order the outer medians so that v[b] <= v[h].
if is_less(&v[h], &v[b]) {
mem::swap(&mut b, &mut h);
}
// Order the middle triple's outer pair so that v[d] <= v[f].
if is_less(&v[f], &v[d]) {
mem::swap(&mut d, &mut f);
}
if is_less(&v[e], &v[d]) {
// v[e] is the smallest of the middle triple, so its median is `d`.
} else if is_less(&v[f], &v[e]) {
// v[e] is the largest of the middle triple, so its median is `f`.
d = f;
} else {
// `e` already holds the middle triple's median; clamp it between v[b] and v[h].
if is_less(&v[e], &v[b]) {
v.swap(e, b);
} else if is_less(&v[h], &v[e]) {
v.swap(e, h);
}
return;
}
// `d` now holds the middle triple's median. Clamp it between v[b] and v[h] and move the
// result into `e`.
if is_less(&v[d], &v[b]) {
d = b;
} else if is_less(&v[h], &v[d]) {
d = h;
}
v.swap(d, e);
}
fn median_idx<T, F: FnMut(&T, &T) -> bool>(
v: ArrayView1<'_, T>,
is_less: &mut F,
mut a: usize,
b: usize,
mut c: usize,
) -> usize {
if is_less(&v[c], &v[a]) {
mem::swap(&mut a, &mut c);
}
if is_less(&v[c], &v[b]) {
return c;
}
if is_less(&v[b], &v[a]) {
return a;
}
b
}
/// Partitions `v` into elements equal to `v[pivot]` followed by elements greater than it, and
/// returns the number of elements equal to the pivot (the pivot itself included).
///
/// "Equal" is tested as `!is_less(pivot, x)`, so this assumes, as the caller arranges, that no
/// element of `v` is less than the pivot. The pivot is swapped to index 0 and written back there
/// by the guard. If `v` contains only the pivot, 0 is returned.
pub fn partition_equal<T, F>(mut v: ArrayViewMut1<'_, T>, pivot: usize, is_less: &mut F) -> usize
where
F: FnMut(&T, &T) -> bool,
{
// Move the pivot to the front and work on the remaining elements.
v.swap(0, pivot);
let (pivot, mut v) = v.split_at(Axis(0), 1);
let pivot = pivot.index(0);
// Compare against a bitwise copy of the pivot. `_pivot_guard` copies it back into slot 0 when
// the function returns or unwinds (e.g. if `is_less` panics).
let tmp = ManuallyDrop::new(unsafe { ptr::read(pivot) });
let _pivot_guard = InsertionHole {
src: &*tmp,
dest: pivot,
};
let pivot = &*tmp;
let len = v.len();
if len == 0 {
return 0;
}
let mut l = 0;
let mut r = len;
loop {
// SAFETY: `uget(l)` is only evaluated after `l < r <= len` has been checked. `uget(r)` is
// only evaluated when `l < r`, and `r` has just been decremented from a value `<= len`.
// `uswap(l, r)` runs with `l < r < len`.
unsafe {
// Advance `l` past elements equal to the pivot.
while l < r && !is_less(pivot, v.view().uget(l)) {
l += 1;
}
// Move `r` back past elements greater than the pivot.
loop {
r -= 1;
if l >= r || !is_less(pivot, v.view().uget(r)) {
break;
}
}
if l >= r {
break;
}
// v[l] is greater than the pivot and v[r] equals it: exchange them.
v.uswap(l, r);
l += 1;
}
}
// `l` elements after the pivot are equal to it; add one for the pivot itself.
l + 1
}
/// Partitions `v` around the element at index `pivot`.
///
/// Afterwards, elements before the returned index are less than the pivot, the pivot sits at
/// that index, and elements after it are not less than the pivot. Returns
/// `(final pivot index, was_partitioned)`. `was_partitioned` is `true` when the initial scans
/// found no out-of-order pair, so the block partitioning step had nothing to do.
pub fn partition<T, F>(mut v: ArrayViewMut1<'_, T>, pivot: usize, is_less: &mut F) -> (usize, bool)
where
F: FnMut(&T, &T) -> bool,
{
let (mid, was_partitioned) = {
let mut v = v.view_mut();
// Move the pivot to the front and partition the remaining elements.
v.swap(0, pivot);
let (pivot, mut v) = v.split_at(Axis(0), 1);
let pivot = pivot.index(0);
// Compare against a bitwise copy of the pivot. `_pivot_guard` copies it back into slot 0
// at the end of this scope, including when `is_less` panics.
let tmp = ManuallyDrop::new(unsafe { ptr::read(pivot) });
let _pivot_guard = InsertionHole {
src: &*tmp,
dest: pivot,
};
let pivot = &*tmp;
let mut l = 0;
let mut r = v.len();
// SAFETY: both scans check `l < r` before indexing, so `l` and `r - 1` are in `0..v.len()`.
unsafe {
// Skip the prefix that is already less than the pivot.
while l < r && is_less(v.view().uget(l), pivot) {
l += 1;
}
// Skip the suffix that is already not less than the pivot.
while l < r && !is_less(v.view().uget(r - 1), pivot) {
r -= 1;
}
}
(
l + partition_in_blocks(v.slice_mut(s![l..r]), pivot, is_less),
l >= r,
)
};
// Place the pivot between the two halves; `mid` is its final index in the full slice.
v.swap(0, mid);
(mid, was_partitioned)
}
/// Partitions `v` so that all elements less than `pivot` come first, followed by all elements
/// greater than or equal to `pivot`. Returns the number of elements less than `pivot`.
///
/// The work is done block-wise. For a block of up to `BLOCK` elements at each end, the offsets
/// of misplaced elements (`>= pivot` on the left, `< pivot` on the right) are recorded as `u8`s
/// in `offsets_l` / `offsets_r`. Misplaced pairs are then exchanged with one cyclic permutation
/// instead of individual swaps. Once at most `2 * BLOCK` elements remain between the blocks,
/// the final blocks are sized to cover that gap exactly. Offsets left over on one side are then
/// resolved with plain swaps.
///
/// `is_less` is only called while tracing blocks, when every element is in place. The
/// permutation step that moves elements through raw pointers calls no user code.
fn partition_in_blocks<T, F>(mut v: ArrayViewMut1<'_, T>, pivot: &T, is_less: &mut F) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    if v.is_empty() {
        return 0;
    }
    // Maximum elements per block; offsets within a block must fit in a `u8`.
    const BLOCK: usize = 128;

    // Left block: starts at index `l` and spans `block_l` elements. `start_l..end_l` are the
    // pending offsets (relative to `l`) of its elements that belong on the right.
    let mut l = 0;
    let mut block_l = BLOCK;
    let mut start_l = ptr::null_mut();
    let mut end_l = ptr::null_mut();
    let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK];

    // Right block: ends just before index `r` and spans `block_r` elements. `start_r..end_r`
    // are the pending offsets (counted backwards from `r - 1`) of its elements that belong on
    // the left.
    let mut r = v.len();
    let mut block_r = BLOCK;
    let mut start_r = ptr::null_mut();
    let mut end_r = ptr::null_mut();
    let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK];

    /// Number of `T`s between `l` (inclusive) and `r` (exclusive). Both must point into the
    /// same buffer with `l <= r`.
    fn ptr_width<T>(l: *mut T, r: *mut T) -> usize {
        assert!(mem::size_of::<T>() > 0);
        #[cfg(miri)]
        {
            (r.addr() - l.addr()) / mem::size_of::<T>()
        }
        #[cfg(not(miri))]
        {
            (r as usize - l as usize) / mem::size_of::<T>()
        }
    }

    /// Number of indices in `l..r`.
    fn width(l: usize, r: usize) -> usize {
        r - l
    }

    loop {
        // Once the gap fits in two blocks, size the final blocks to cover it exactly and stop
        // after this iteration.
        let is_done = width(l, r) <= 2 * BLOCK;
        if is_done {
            let mut rem = width(l, r);
            // A side that still has pending offsets keeps its full-size block.
            if start_l < end_l || start_r < end_r {
                rem -= BLOCK;
            }
            if start_l < end_l {
                block_r = rem;
            } else if start_r < end_r {
                block_l = rem;
            } else {
                block_l = rem / 2;
                block_r = rem - block_l;
            }
            debug_assert!(block_l <= BLOCK && block_r <= BLOCK);
            debug_assert!(width(l, r) == block_l + block_r);
        }

        if start_l == end_l {
            // Trace the left block. The pointer is derived from the whole array rather than from
            // `offsets_l[0]`: a pointer obtained from a reference to a single element is only
            // valid for that element, but we write to every slot of the buffer.
            start_l = offsets_l.as_mut_ptr().cast::<u8>();
            end_l = start_l;
            let mut elem = l;
            for i in 0..block_l {
                // SAFETY: `end_l` advances at most once per iteration and `i < block_l <= BLOCK`,
                // so the write stays inside `offsets_l`.
                unsafe {
                    *end_l = i as u8;
                    end_l = end_l.add(!is_less(v.view_mut().index(elem), pivot) as usize);
                }
                elem += 1;
            }
        }

        if start_r == end_r {
            // Trace the right block, walking backwards from `r - 1`. See above for why the
            // pointer comes from the whole array.
            start_r = offsets_r.as_mut_ptr().cast::<u8>();
            end_r = start_r;
            let mut elem = r;
            for i in 0..block_r {
                elem -= 1;
                // SAFETY: `end_r` advances at most once per iteration and `i < block_r <= BLOCK`,
                // so the write stays inside `offsets_r`.
                unsafe {
                    *end_r = i as u8;
                    end_r = end_r.add(is_less(v.view_mut().index(elem), pivot) as usize);
                }
            }
        }

        // Number of misplaced pairs to exchange this round.
        let count = cmp::min(ptr_width(start_l, end_l), ptr_width(start_r, end_r));
        if count > 0 {
            macro_rules! left {
                () => {
                    v.view_mut().index(l + usize::from(*start_l)) as *mut T
                };
            }
            macro_rules! right {
                () => {
                    v.view_mut().index(r - (usize::from(*start_r) + 1)) as *mut T
                };
            }
            // Cyclic permutation instead of `count` swaps. `tmp` holds the first left element;
            // each step fills the hole left by the previous copy, and `tmp` fills the last one.
            // SAFETY: every offset read lies in the initialized prefix of its buffer, and the
            // left and right positions are distinct elements of `v`.
            unsafe {
                let tmp = ptr::read(left!());
                ptr::copy_nonoverlapping(right!(), left!(), 1);
                for _ in 1..count {
                    start_l = start_l.add(1);
                    ptr::copy_nonoverlapping(left!(), right!(), 1);
                    start_r = start_r.add(1);
                    ptr::copy_nonoverlapping(right!(), left!(), 1);
                }
                ptr::copy_nonoverlapping(&tmp, right!(), 1);
                mem::forget(tmp);
                start_l = start_l.add(1);
                start_r = start_r.add(1);
            }
        }

        // Advance past blocks whose misplaced elements have all been exchanged.
        if start_l == end_l {
            l += block_l;
        }
        if start_r == end_r {
            r -= block_r;
        }
        if is_done {
            break;
        }
    }

    // At most one side still has pending offsets. Move those elements to the far end of the
    // remaining gap with plain swaps.
    if start_l < end_l {
        debug_assert_eq!(width(l, r), block_l);
        while start_l < end_l {
            // SAFETY: `end_l` stays within the traced prefix of `offsets_l`, and both indices
            // lie inside `l..r`.
            unsafe {
                end_l = end_l.sub(1);
                v.uswap(l + usize::from(*end_l), r - 1);
            }
            r -= 1;
        }
        width(0, r)
    } else if start_r < end_r {
        debug_assert_eq!(width(l, r), block_r);
        while start_r < end_r {
            // SAFETY: `end_r` stays within the traced prefix of `offsets_r`, and both indices
            // lie inside `l..r`.
            unsafe {
                end_r = end_r.sub(1);
                v.uswap(l, r - (usize::from(*end_r) + 1));
            }
            l += 1;
        }
        width(0, l)
    } else {
        width(0, l)
    }
}
/// Returns `true` if `compare(&v[i - 1], &v[i])` holds for every adjacent pair of `v`.
/// Slices with fewer than two elements are vacuously sorted.
///
/// `compare` is applied to the pairs front to back, and evaluation stops at the first pair for
/// which it returns `false`.
pub fn is_sorted<T, F>(v: ArrayView1<'_, T>, mut compare: F) -> bool
where
    F: FnMut(&T, &T) -> bool,
{
    v.iter()
        .zip(v.iter().skip(1))
        .all(|(prev, next)| compare(prev, next))
}
pub fn reverse<T>(v: ArrayViewMut1<'_, T>) {
let len = v.len();
let half_len = v.len() / 2;
let (front_half, back_half) = v.split_at(Axis(0), len - half_len);
let (front_half, _middle_item) = front_half.split_at(Axis(0), half_len);
revswap(front_half, back_half, half_len);
#[inline]
fn revswap<T>(mut a: ArrayViewMut1<'_, T>, mut b: ArrayViewMut1<'_, T>, n: usize) {
debug_assert!(a.len() == n);
debug_assert!(b.len() == n);
let mut i = 0;
while i < n {
mem::swap(a.view_mut().index(i), b.view_mut().index(n - 1 - i));
i += 1;
}
}
}
/// Scatters a few elements around the middle of `v` to break up patterns that keep producing
/// unbalanced partitions.
///
/// Three consecutive elements starting at `len / 4 * 2 - 1` are swapped with pseudo-random
/// positions. The positions come from a xorshift generator seeded with the length, so the
/// shuffle is deterministic. Views shorter than 8 elements are left untouched.
#[cold]
pub fn break_patterns<T>(mut v: ArrayViewMut1<'_, T>) {
    let len = v.len();
    if len < 8 {
        return;
    }

    let mut state = len;
    let mut next_random = || {
        state = if usize::BITS <= 32 {
            let mut x = state as u32;
            x ^= x << 13;
            x ^= x >> 17;
            x ^= x << 5;
            x as usize
        } else {
            let mut x = state as u64;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            x as usize
        };
        state
    };

    // Reduce modulo a power of two first, then subtract `len` once: the masked value is below
    // `2 * len`, so the result lands in `0..len`.
    let mask = len.next_power_of_two() - 1;
    let center = len / 4 * 2;
    for i in 0..3 {
        let mut other = next_random() & mask;
        if other >= len {
            other -= len;
        }
        v.swap(center - 1 + i, other);
    }
}
/// Chooses a pivot index for partitioning `v` and reports whether `v` looks already ordered.
///
/// Candidates are the positions `len / 4`, `len / 4 * 2` and `len / 4 * 3`. For
/// `len >= SHORTEST_MEDIAN_OF_MEDIANS` (50), each candidate is first replaced by the median of
/// itself and its two neighbours. The median of the three candidates is chosen. For `len < 8`,
/// no comparisons are made and `len / 4 * 2` is returned.
///
/// Returns `(index, flag)`. `flag` is `true` when no candidate indices had to be swapped. If all
/// `MAX_SWAPS` possible swaps happened, `v` is assumed to be descending. It is then reversed in
/// place, and the mirrored index is returned with `true`.
pub fn choose_pivot<T, F>(v: ArrayViewMut1<'_, T>, is_less: &mut F) -> (usize, bool)
where
F: FnMut(&T, &T) -> bool,
{
const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
// Four `sort3` calls with at most three swaps each.
const MAX_SWAPS: usize = 4 * 3;
let len = v.len();
let mut a = len / 4;
let mut b = len / 4 * 2;
let mut c = len / 4 * 3;
// Number of index swaps performed while ordering candidates.
let mut swaps = 0;
if len >= 8 {
let v = v.view();
// Swaps the indices so that `v[*a] <= v[*b]`.
// SAFETY: with `len >= 8`, every candidate and its neighbours (`tmp - 1`, `tmp + 1`) are
// in bounds, so the unchecked reads are valid.
let mut sort2 = |a: &mut usize, b: &mut usize| unsafe {
if is_less(v.uget(*b), v.uget(*a)) {
ptr::swap(a, b);
swaps += 1;
}
};
// Orders three indices so that `v[*a] <= v[*b] <= v[*c]`.
let mut sort3 = |a: &mut usize, b: &mut usize, c: &mut usize| {
sort2(a, b);
sort2(b, c);
sort2(a, b);
};
if len >= SHORTEST_MEDIAN_OF_MEDIANS {
// Replaces `*a` with the index of the median of `v[*a - 1]`, `v[*a]`, `v[*a + 1]`.
let mut sort_adjacent = |a: &mut usize| {
let tmp = *a;
sort3(&mut (tmp - 1), a, &mut (tmp + 1));
};
sort_adjacent(&mut a);
sort_adjacent(&mut b);
sort_adjacent(&mut c);
}
sort3(&mut a, &mut b, &mut c);
}
if swaps < MAX_SWAPS {
(b, swaps == 0)
} else {
// Every comparison disagreed with the index order: treat `v` as descending.
reverse(v);
(len - 1 - b, true)
}
}
/// Drop guard that copies one `T` from `src` to `dest` when it goes out of scope.
///
/// `partition` and `partition_equal` use it to write a pivot, temporarily held as a
/// `ManuallyDrop` copy, back into its slot, including when a comparison panics.
///
/// Both pointers are public and are dereferenced without checks in `drop`. Whoever builds an
/// `InsertionHole` must therefore guarantee, as for an `unsafe` operation, that for the
/// guard's whole lifetime `src` is valid for reads, `dest` is valid for writes, and the two do
/// not overlap.
pub struct InsertionHole<T> {
    /// Location the value is copied from.
    pub src: *const T,
    /// Location the value is copied to.
    pub dest: *mut T,
}

impl<T> Drop for InsertionHole<T> {
    fn drop(&mut self) {
        // SAFETY: relies on the construction contract documented on `InsertionHole`.
        unsafe { self.dest.copy_from_nonoverlapping(self.src, 1) }
    }
}
// Property tests (quickcheck) for the selection routines in this module.
#[cfg(feature = "std")]
#[cfg(test)]
mod test {
use super::{partition_at_index, partition_at_indices, reverse};
use crate::{partition_dedup::partition_dedup, quick_sort::quick_sort};
use ndarray::{Array1, arr1};
use quickcheck::TestResult;
use quickcheck_macros::quickcheck;
// `partition_at_indices` must report each requested position exactly once, in ascending order,
// together with the value a full sort places there.
#[quickcheck]
fn at_indices(xs: Vec<u32>) -> TestResult {
if xs.is_empty() {
return TestResult::discard();
}
let mut array = arr1(&xs);
// Reference result: a fully sorted copy.
let mut sorted = arr1(&xs);
quick_sort(sorted.view_mut(), &mut u32::lt);
// Requested positions are sorted and deduplicated before use.
let mut indices = arr1(&[xs.len() - 1, xs.len() / 2, xs.len() / 3, xs.len() / 4, 0]);
quick_sort(indices.view_mut(), &mut usize::lt);
let (indices, _duplicates) = partition_dedup(indices.view_mut(), |a, b| a.eq(&b));
if indices.iter().any(|&index| index >= xs.len()) {
return TestResult::discard();
}
let mut values = Vec::with_capacity(indices.len());
partition_at_indices(
array.view_mut(),
0,
indices.view(),
&mut values,
&mut u32::lt,
);
// Positions come back in the requested (ascending) order...
assert_eq!(
indices,
Array1::from_iter(values.iter().map(|(index, _value)| *index))
);
// ...and each holds the sorted value.
for (index, value) in values {
assert_eq!(*value, sorted[index]);
}
TestResult::passed()
}
// After `partition_at_index`, nothing on the left exceeds the selected value and nothing on the
// right is below it.
#[quickcheck]
fn at_index(xs: Vec<u32>) -> TestResult {
if xs.is_empty() {
return TestResult::discard();
}
let mut array = arr1(&xs);
let (left, value, right) = partition_at_index(array.view_mut(), xs.len() / 3, &mut u32::lt);
for left in left {
assert!(left <= value);
}
for right in right {
assert!(value <= right);
}
TestResult::passed()
}
// `reverse` must produce the input read back to front.
#[quickcheck]
fn reversed(xs: Vec<u32>) -> bool {
let array = Array1::from_vec(xs);
let mut array_rev = array.clone();
reverse(array_rev.view_mut());
array
.iter()
.zip(array_rev.iter().rev())
.all(|(a, b)| a == b)
}
}