use rayon_core;
use std::cmp;
use std::mem;
use std::ptr;
/// Owns a value that is moved into `dest` (with a raw, non-dropping write) when the guard drops.
///
/// `partition` and `partition_equal` use it to hold a copy of the pivot and put it back into its
/// slot even if `is_less` panics.
struct WriteOnDrop<T> {
    value: Option<T>,
    dest: *mut T,
}

impl<T> Drop for WriteOnDrop<T> {
    fn drop(&mut self) {
        // Panics if `value` has already been taken. The old contents of `dest` are overwritten
        // without running their destructor.
        let owned = self.value.take().unwrap();
        unsafe {
            self.dest.write(owned);
        }
    }
}
/// Holds a value whose destructor must never run: on drop the contained value is leaked with
/// `mem::forget`.
///
/// Used for stack copies of slice elements whose bits are written back into the slice, so the
/// copy itself must not be dropped a second time.
struct NoDrop<T> {
    value: Option<T>,
}

impl<T> Drop for NoDrop<T> {
    fn drop(&mut self) {
        if let Some(inner) = self.value.take() {
            mem::forget(inner);
        }
    }
}
/// When dropped, bitwise-copies one `T` from `src` into `dest`.
///
/// `shift_head` and `shift_tail` use it to fill the "hole" left in the slice with the saved
/// element, both on normal exit and while unwinding from a panicking comparison.
struct CopyOnDrop<T> {
    src: *mut T,
    dest: *mut T,
}

impl<T> Drop for CopyOnDrop<T> {
    fn drop(&mut self) {
        let (from, to) = (self.src, self.dest);
        unsafe {
            to.copy_from_nonoverlapping(from, 1);
        }
    }
}
/// If `v[1] < v[0]`, moves `v[0]` to the right past the run of following elements that are less
/// than it, shifting each of those elements one position to the left.
///
/// The scan stops at the first element that is not less than the saved element. If `is_less`
/// panics, `hole` is still dropped and writes the saved element into the current gap, so `v`
/// remains a permutation of its original contents.
fn shift_head<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    let len = v.len();
    unsafe {
        // Only do anything if the first two elements are out of order.
        if len >= 2 && is_less(v.get_unchecked(1), v.get_unchecked(0)) {
            // Bitwise copy of `v[0]`. `NoDrop` ensures this copy is never dropped; the element's
            // ownership goes back into the slice through `hole`.
            let mut tmp = NoDrop {
                value: Some(ptr::read(v.get_unchecked(0))),
            };
            // When dropped (normally or during unwinding), `hole` copies `tmp` into `hole.dest`,
            // which always points at the current gap. It is declared after `tmp`, so it is dropped
            // first, while `tmp`'s storage is still valid.
            let mut hole = CopyOnDrop {
                src: tmp.value.as_mut().unwrap(),
                dest: v.get_unchecked_mut(1),
            };
            // Move `v[1]` into `v[0]`; the gap is now at index 1.
            ptr::copy_nonoverlapping(v.get_unchecked(1), v.get_unchecked_mut(0), 1);
            for i in 2..len {
                // Stop at the first element that is not less than the saved one.
                if !is_less(v.get_unchecked(i), tmp.value.as_ref().unwrap()) {
                    break;
                }
                // Shift `v[i]` one place left; the gap moves to `i`.
                ptr::copy_nonoverlapping(v.get_unchecked(i), v.get_unchecked_mut(i - 1), 1);
                hole.dest = v.get_unchecked_mut(i);
            }
            // `hole` is dropped here and fills the gap with the saved element.
        }
    }
}
/// If `v[len - 1] < v[len - 2]`, moves the last element to the left past the run of preceding
/// elements that are greater than it, shifting each of those elements one position to the right.
///
/// The scan stops at the first element (going leftwards) that the saved element is not less
/// than. If `is_less` panics, `hole` is still dropped and writes the saved element into the
/// current gap, so `v` remains a permutation of its original contents.
fn shift_tail<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    let len = v.len();
    unsafe {
        // Only do anything if the last two elements are out of order.
        if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) {
            // Bitwise copy of the last element; `NoDrop` keeps this copy from ever being dropped.
            let mut tmp = NoDrop {
                value: Some(ptr::read(v.get_unchecked(len - 1))),
            };
            // When dropped (normally or during unwinding), `hole` copies `tmp` into the current
            // gap. Declared after `tmp`, so it is dropped first.
            let mut hole = CopyOnDrop {
                src: tmp.value.as_mut().unwrap(),
                dest: v.get_unchecked_mut(len - 2),
            };
            // Move `v[len - 2]` into the last slot; the gap is now at `len - 2`.
            ptr::copy_nonoverlapping(v.get_unchecked(len - 2), v.get_unchecked_mut(len - 1), 1);
            for i in (0..len - 2).rev() {
                // Stop at the first element the saved one is not less than.
                if !is_less(&tmp.value.as_ref().unwrap(), v.get_unchecked(i)) {
                    break;
                }
                // Shift `v[i]` one place right; the gap moves to `i`.
                ptr::copy_nonoverlapping(v.get_unchecked(i), v.get_unchecked_mut(i + 1), 1);
                hole.dest = v.get_unchecked_mut(i);
            }
            // `hole` is dropped here and fills the gap with the saved element.
        }
    }
}
/// Tries to finish sorting `v` by repairing a handful of adjacent out-of-order pairs.
///
/// Returns `true` if `v` is sorted on return. It gives up (returning `false`) after
/// `MAX_STEPS` repairs. For slices shorter than `SHORTEST_SHIFTING` it never moves anything and
/// only reports whether the slice is already sorted.
#[cold]
fn partial_insertion_sort<T, F>(v: &mut [T], is_less: &F) -> bool
where
    F: Fn(&T, &T) -> bool,
{
    // How many out-of-order adjacent pairs we are willing to repair.
    const MAX_STEPS: usize = 5;
    // Below this length, shifting elements is not worth its cost.
    const SHORTEST_SHIFTING: usize = 50;

    let len = v.len();
    let mut cursor = 1;

    for _ in 0..MAX_STEPS {
        // Advance to the first index whose element is less than its predecessor.
        // SAFETY: the left operand of `&&` guarantees `1 <= cursor < len` for both reads.
        while cursor < len
            && unsafe { !is_less(v.get_unchecked(cursor), v.get_unchecked(cursor - 1)) }
        {
            cursor += 1;
        }

        if cursor == len {
            return true;
        }
        if len < SHORTEST_SHIFTING {
            return false;
        }

        // Put the offending pair in order, then sink the smaller element into the sorted prefix
        // and float the greater one into the suffix.
        v.swap(cursor - 1, cursor);
        let (front, back) = v.split_at_mut(cursor);
        shift_tail(front, is_less);
        shift_head(back, is_less);
    }

    false
}
/// Sorts `v` with insertion sort (`O(n^2)` worst case), used for short slices.
fn insertion_sort<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    // Grow the sorted prefix one element at a time: after each step, `v[..end]` is sorted.
    for end in 2..=v.len() {
        shift_tail(&mut v[..end], is_less);
    }
}
/// Sorts `v` in place with heapsort, guaranteeing `O(n log n)` comparisons in the worst case.
///
/// `recurse` uses it as a fallback once too many imbalanced partitions have been seen.
#[cold]
fn heapsort<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    // Restores the max-heap property (no child greater than its parent) for the subtree rooted
    // at `root`, assuming both child subtrees already satisfy it.
    let sift_down = |heap: &mut [T], mut root: usize| loop {
        let len = heap.len();
        let left = 2 * root + 1;
        // Pick the greater child; on ties, keep the left one.
        let child = if left + 1 < len && is_less(&heap[left], &heap[left + 1]) {
            left + 1
        } else {
            left
        };
        // Stop once `root` is a leaf or is not less than its greater child.
        if child >= len || !is_less(&heap[root], &heap[child]) {
            break;
        }
        heap.swap(root, child);
        root = child;
    };

    let len = v.len();

    // Heapify bottom-up; indices `len / 2..` are leaves and need no sifting.
    for root in (0..len / 2).rev() {
        sift_down(v, root);
    }

    // Repeatedly move the current maximum just past the shrinking heap.
    for end in (1..len).rev() {
        v.swap(0, end);
        sift_down(&mut v[..end], 0);
    }
}
/// Partitions `v` into elements smaller than `pivot`, followed by elements greater than or equal
/// to `pivot`, and returns the number of elements smaller than `pivot`.
///
/// The slice is processed in blocks of up to `BLOCK` elements taken from both ends (the block
/// partitioning idea from the BlockQuicksort paper). For each block, the offsets of elements on
/// the wrong side are first recorded into a small byte buffer without a data-dependent branch.
/// The recorded elements are then exchanged between the two sides using a cyclic permutation,
/// which needs fewer moves than pairwise swaps.
///
/// Panic safety: `is_less` is only called while filling the offset buffers, when no element has
/// been moved out of `v`, so a panicking comparator leaves `v` a permutation of its input.
///
/// # Panics
///
/// Panics (via the assertion in `width`) if `T` is zero-sized.
fn partition_in_blocks<T, F>(v: &mut [T], pivot: &T, is_less: &F) -> usize
where
    F: Fn(&T, &T) -> bool,
{
    // Number of elements in a full block. Offsets within a block must fit in a `u8`.
    const BLOCK: usize = 128;

    // Left-side state:
    // - `l` is the start of the current left block, and `block_l` is its length.
    // - `start_l..end_l` holds the not-yet-moved offsets (relative to `l`) of elements in that
    //   block that are greater than or equal to the pivot.
    let mut l = v.as_mut_ptr();
    let mut block_l = BLOCK;
    let mut start_l = ptr::null_mut();
    let mut end_l = ptr::null_mut();
    // Zero-initialised instead of `mem::uninitialized()`. That function is deprecated, and an
    // uninitialised integer is undefined behaviour. Only bytes written during a scan are ever
    // read back, so the initial contents are irrelevant, and the fill is negligible next to the
    // partitioning work.
    let mut offsets_l = [0u8; BLOCK];

    // Right-side state, mirroring the left:
    // - `r` points one past the end of the current right block.
    // - The offsets (counted backwards from `r`) mark elements that are smaller than the pivot.
    // SAFETY: `l.add(v.len())` is the one-past-the-end pointer of `v`.
    let mut r = unsafe { l.add(v.len()) };
    let mut block_r = BLOCK;
    let mut start_r = ptr::null_mut();
    let mut end_r = ptr::null_mut();
    let mut offsets_r = [0u8; BLOCK];

    /// Returns the number of `T`s between `l` and `r`; `r` must not precede `l`.
    fn width<T>(l: *mut T, r: *mut T) -> usize {
        assert!(mem::size_of::<T>() > 0);
        (r as usize - l as usize) / mem::size_of::<T>()
    }

    loop {
        // Block-by-block partitioning ends once `l` and `r` are within two blocks of each other;
        // the final blocks are then sized to cover the remaining gap exactly.
        let is_done = width(l, r) <= 2 * BLOCK;

        if is_done {
            // Elements between `l` and `r` that have not been compared with the pivot yet.
            let mut rem = width(l, r);
            // A side with pending offsets still owns its full, already-scanned block.
            if start_l < end_l || start_r < end_r {
                rem -= BLOCK;
            }

            // Size the final blocks so the left and right blocks exactly cover the gap.
            if start_l < end_l {
                block_r = rem;
            } else if start_r < end_r {
                block_l = rem;
            } else {
                block_l = rem / 2;
                block_r = rem - block_l;
            }
            debug_assert!(block_l <= BLOCK && block_r <= BLOCK);
            debug_assert_eq!(width(l, r), block_l + block_r);
        }

        if start_l == end_l {
            // Scan the left block. Every offset is written, but `end_l` only advances past
            // elements that are not less than the pivot (they belong on the right).
            start_l = offsets_l.as_mut_ptr();
            end_l = offsets_l.as_mut_ptr();
            let mut elem = l;

            for i in 0..block_l {
                // SAFETY:
                // - `elem` stays inside the current left block.
                // - `end_l` advances at most `block_l <= BLOCK` times, so writes stay inside
                //   `offsets_l`.
                unsafe {
                    *end_l = i as u8;
                    end_l = end_l.offset(!is_less(&*elem, pivot) as isize);
                    elem = elem.offset(1);
                }
            }
        }

        if start_r == end_r {
            // Scan the right block backwards from `r`, keeping offsets of elements that are less
            // than the pivot (they belong on the left).
            start_r = offsets_r.as_mut_ptr();
            end_r = offsets_r.as_mut_ptr();
            let mut elem = r;

            for i in 0..block_r {
                // SAFETY: mirror of the left scan. `elem` stays inside the current right block,
                // and `end_r` stays inside `offsets_r`.
                unsafe {
                    elem = elem.offset(-1);
                    *end_r = i as u8;
                    end_r = end_r.offset(is_less(&*elem, pivot) as isize);
                }
            }
        }

        // Number of out-of-order elements that can be exchanged between the two sides now.
        let count = cmp::min(width(start_l, end_l), width(start_r, end_r));

        if count > 0 {
            macro_rules! left {
                () => {
                    l.offset(*start_l as isize)
                };
            }
            macro_rules! right {
                () => {
                    r.offset(-(*start_r as isize) - 1)
                };
            }

            // Cyclic permutation instead of `count` swaps: one element is parked in `tmp` and
            // every step performs a single copy. No comparison happens here, so nothing can
            // panic while `tmp` is held outside of `v`.
            // SAFETY: every offset in `start_l..end_l` / `start_r..end_r` was recorded for an
            // element inside the current blocks, and `count` does not exceed either window.
            unsafe {
                let tmp = ptr::read(left!());
                ptr::copy_nonoverlapping(right!(), left!(), 1);

                for _ in 1..count {
                    start_l = start_l.offset(1);
                    ptr::copy_nonoverlapping(left!(), right!(), 1);
                    start_r = start_r.offset(1);
                    ptr::copy_nonoverlapping(right!(), left!(), 1);
                }

                ptr::copy_nonoverlapping(&tmp, right!(), 1);
                mem::forget(tmp);
                start_l = start_l.offset(1);
                start_r = start_r.offset(1);
            }
        }

        if start_l == end_l {
            // Every out-of-order element of the left block was moved; advance to the next block.
            // SAFETY: the block lies within `v`, so its end is at most one past the end of `v`.
            l = unsafe { l.add(block_l) };
        }

        if start_r == end_r {
            // Every out-of-order element of the right block was moved; step to the previous one.
            // SAFETY: the block lies within `v`, so its start is not before the start of `v`.
            r = unsafe { r.sub(block_r) };
        }

        if is_done {
            break;
        }
    }

    // At most one block (left or right) still has out-of-order elements. They are moved to the
    // far end of that block, which completes the partition.
    if start_l < end_l {
        // Left block remains: move its elements that are >= pivot to its right end.
        debug_assert_eq!(width(l, r), block_l);
        while start_l < end_l {
            // SAFETY: offsets lie within the remaining left block `l..r`.
            unsafe {
                end_l = end_l.offset(-1);
                ptr::swap(l.offset(*end_l as isize), r.offset(-1));
                r = r.offset(-1);
            }
        }
        width(v.as_mut_ptr(), r)
    } else if start_r < end_r {
        // Right block remains: move its elements that are < pivot to its left end.
        debug_assert_eq!(width(l, r), block_r);
        while start_r < end_r {
            // SAFETY: offsets lie within the remaining right block `l..r`.
            unsafe {
                end_r = end_r.offset(-1);
                ptr::swap(l, r.offset(-(*end_r as isize) - 1));
                l = l.offset(1);
            }
        }
        width(v.as_mut_ptr(), l)
    } else {
        // Both sides are fully resolved; `l` is the boundary.
        width(v.as_mut_ptr(), l)
    }
}
/// Partitions `v` around the element at index `pivot`.
///
/// On return:
/// - `v[..mid]` holds elements less than the pivot.
/// - `v[mid]` is the pivot.
/// - `v[mid + 1..]` holds elements not less than the pivot.
///
/// Returns `(mid, was_partitioned)`. `was_partitioned` is `true` when the two initial scans met,
/// i.e. the slice was already partitioned and `partition_in_blocks` received an empty range.
fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> (usize, bool)
where
    F: Fn(&T, &T) -> bool,
{
    let (mid, was_partitioned) = {
        // Move the pivot to the front and split it off from the rest of the slice.
        v.swap(0, pivot);
        let (pivot, v) = v.split_at_mut(1);
        let pivot = &mut pivot[0];
        // Copy the pivot into a local. When `write_on_drop` is dropped (at the end of this block,
        // or during unwinding if `is_less` panics), the copy is written back into its slot.
        let write_on_drop = WriteOnDrop {
            value: unsafe { Some(ptr::read(pivot)) },
            dest: pivot,
        };
        let pivot = write_on_drop.value.as_ref().unwrap();
        let mut l = 0;
        let mut r = v.len();
        unsafe {
            // Skip the prefix of elements already less than the pivot.
            while l < r && is_less(v.get_unchecked(l), pivot) {
                l += 1;
            }
            // Skip the suffix of elements already not less than the pivot.
            while l < r && !is_less(v.get_unchecked(r - 1), pivot) {
                r -= 1;
            }
        }
        // Only the unresolved middle `v[l..r]` needs block partitioning.
        (
            l + partition_in_blocks(&mut v[l..r], pivot, is_less),
            l >= r,
        )
        // `write_on_drop` goes out of scope here and restores the pivot to the front slot.
    };
    // Place the pivot between the two partitions.
    v.swap(0, mid);
    (mid, was_partitioned)
}
/// Partitions `v` around `v[pivot]` into two groups:
/// - first, elements `x` for which `is_less(pivot, x)` is false;
/// - then, elements greater than the pivot.
///
/// Returns the size of the first group, counting the pivot itself. The pivot is left at `v[0]`
/// rather than between the groups. `recurse` only calls this when no element of `v` is less than
/// the pivot, so the first group presumably consists of elements equal to it.
fn partition_equal<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> usize
where
    F: Fn(&T, &T) -> bool,
{
    // Move the pivot to the front and split it off from the rest of the slice.
    v.swap(0, pivot);
    let (pivot, v) = v.split_at_mut(1);
    let pivot = &mut pivot[0];
    // Local copy of the pivot, written back into its slot when `write_on_drop` is dropped, even if
    // `is_less` panics.
    let write_on_drop = WriteOnDrop {
        value: unsafe { Some(ptr::read(pivot)) },
        dest: pivot,
    };
    let pivot = write_on_drop.value.as_ref().unwrap();
    let mut l = 0;
    let mut r = v.len();
    loop {
        unsafe {
            // Find the first element greater than the pivot.
            while l < r && !is_less(pivot, v.get_unchecked(l)) {
                l += 1;
            }
            // Find the last element not greater than the pivot.
            while l < r && is_less(pivot, v.get_unchecked(r - 1)) {
                r -= 1;
            }
            if l >= r {
                break;
            }
            // Swap the found out-of-order pair.
            r -= 1;
            ptr::swap(v.get_unchecked_mut(l), v.get_unchecked_mut(r));
            l += 1;
        }
    }
    // `l` elements of the tail belong to the first group; add 1 for the pivot itself.
    l + 1
    // `write_on_drop` goes out of scope here and restores the pivot to `v[0]`.
}
/// Swaps three elements near the middle of `v` with pseudo-randomly chosen positions.
///
/// The aim is to disturb patterns that cause repeated imbalanced partitions. The generator is
/// Marsaglia's xorshift32 seeded with `v.len()`, so the shuffle is deterministic. Slices shorter
/// than 8 are left untouched.
#[cold]
fn break_patterns<T>(v: &mut [T]) {
    let len = v.len();
    if len < 8 {
        return;
    }

    let mut state = len as u32;
    let mut next_u32 = || {
        state ^= state << 13;
        state ^= state >> 17;
        state ^= state << 5;
        state
    };
    // On 64-bit targets two draws are combined: the first supplies the high half, the second the
    // low half.
    let mut next_usize = || {
        if mem::size_of::<usize>() > 4 {
            let hi = u64::from(next_u32());
            let lo = u64::from(next_u32());
            ((hi << 32) | lo) as usize
        } else {
            next_u32() as usize
        }
    };

    // Reducing modulo a power of two >= len yields a value below 2 * len, so one subtraction
    // brings it into `0..len`.
    let mask = len.next_power_of_two() - 1;
    // First of three consecutive positions around the middle, where pivot candidates are taken.
    let base = len / 4 * 2 - 1;

    for offset in 0..3 {
        let mut target = next_usize() & mask;
        if target >= len {
            target -= len;
        }
        v.swap(base + offset, target);
    }
}
fn choose_pivot<T, F>(v: &mut [T], is_less: &F) -> (usize, bool)
where
F: Fn(&T, &T) -> bool,
{
const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
const MAX_SWAPS: usize = 4 * 3;
let len = v.len();
let mut a = len / 4 * 1;
let mut b = len / 4 * 2;
let mut c = len / 4 * 3;
let mut swaps = 0;
if len >= 8 {
let mut sort2 = |a: &mut usize, b: &mut usize| unsafe {
if is_less(v.get_unchecked(*b), v.get_unchecked(*a)) {
ptr::swap(a, b);
swaps += 1;
}
};
let mut sort3 = |a: &mut usize, b: &mut usize, c: &mut usize| {
sort2(a, b);
sort2(b, c);
sort2(a, b);
};
if len >= SHORTEST_MEDIAN_OF_MEDIANS {
let mut sort_adjacent = |a: &mut usize| {
let tmp = *a;
sort3(&mut (tmp - 1), a, &mut (tmp + 1));
};
sort_adjacent(&mut a);
sort_adjacent(&mut b);
sort_adjacent(&mut c);
}
sort3(&mut a, &mut b, &mut c);
}
if swaps < MAX_SWAPS {
(b, swaps == 0)
} else {
v.reverse();
(len - 1 - b, true)
}
}
/// Sorts `v` with pattern-defeating quicksort. When the larger of the two partitions exceeds
/// `MAX_SEQUENTIAL` elements, both sides are handed to `rayon_core::join`.
///
/// * `pred`: `Some(p)` when `v` came from the right-hand side of an earlier partition around
///   `p`, so no element of `v` is less than `p`. If the newly chosen pivot is not greater than
///   `p`, the elements equal to it are split off with `partition_equal`.
/// * `limit`: number of imbalanced partitionings still tolerated before falling back to
///   `heapsort`.
fn recurse<'a, T, F>(mut v: &'a mut [T], is_less: &F, mut pred: Option<&'a mut T>, mut limit: usize)
where
    T: Send,
    F: Fn(&T, &T) -> bool + Sync,
{
    // Slices up to this length are sorted with insertion sort.
    const MAX_INSERTION: usize = 20;
    // If both partitions are at most this long, keep working on the current thread.
    const MAX_SEQUENTIAL: usize = 2000;
    // Whether the previous partitioning was reasonably balanced (smaller side >= len / 8).
    let mut was_balanced = true;
    // Whether the previous partitioning found the slice already partitioned.
    let mut was_partitioned = true;
    loop {
        let len = v.len();
        if len <= MAX_INSERTION {
            insertion_sort(v, is_less);
            return;
        }
        // Too many bad pivots: use heapsort to keep the worst case at O(n log n).
        if limit == 0 {
            heapsort(v, is_less);
            return;
        }
        // After an imbalanced split, shuffle a few elements to disturb adversarial patterns and
        // spend one unit of the imbalance budget.
        if !was_balanced {
            break_patterns(v);
            limit -= 1;
        }
        let (pivot, likely_sorted) = choose_pivot(v, is_less);
        // If everything suggests `v` is nearly sorted, try to finish with a few cheap
        // insertion-sort repairs.
        if was_balanced && was_partitioned && likely_sorted {
            if partial_insertion_sort(v, is_less) {
                return;
            }
        }
        // The pivot is not greater than the predecessor, and nothing in `v` is less than the
        // predecessor, so (for a consistent ordering) the pivot is a minimum of `v`. Split off
        // the run of elements equal to it and continue with the rest.
        if let Some(ref p) = pred {
            if !is_less(p, &v[pivot]) {
                let mid = partition_equal(v, pivot, is_less);
                // `{ v }` moves the reference out so the tail can be re-borrowed for `'a`.
                v = &mut { v }[mid..];
                continue;
            }
        }
        let (mid, was_p) = partition(v, pivot, is_less);
        was_balanced = cmp::min(mid, len - mid) >= len / 8;
        was_partitioned = was_p;
        // Split into `left` (< pivot), the pivot itself, and `right` (>= pivot).
        let (left, right) = { v }.split_at_mut(mid);
        let (pivot, right) = right.split_at_mut(1);
        let pivot = &mut pivot[0];
        if cmp::max(left.len(), right.len()) <= MAX_SEQUENTIAL {
            // Recurse into the shorter side and loop on the longer one, so the recursion goes
            // into the smaller half.
            if left.len() < right.len() {
                recurse(left, is_less, pred, limit);
                v = right;
                pred = Some(pivot);
            } else {
                recurse(right, is_less, Some(pivot), limit);
                v = left;
            }
        } else {
            // At least one side is large: hand both sides to rayon, potentially in parallel.
            rayon_core::join(
                || recurse(left, is_less, pred, limit),
                || recurse(right, is_less, Some(pivot), limit),
            );
            break;
        }
    }
}
/// Sorts `v` in parallel with pattern-defeating quicksort, using `is_less` as the strict
/// "less than" relation.
///
/// Slices of zero-sized types are left as they are, since there is nothing to reorder.
pub(super) fn par_quicksort<T, F>(v: &mut [T], is_less: F)
where
    T: Send,
    F: Fn(&T, &T) -> bool + Sync,
{
    if mem::size_of::<T>() != 0 {
        // Tolerate `floor(log2(len)) + 1` imbalanced partitions (0 for an empty slice) before
        // switching to heapsort.
        let usize_bits = mem::size_of::<usize>() * 8;
        let limit = usize_bits - v.len().leading_zeros() as usize;
        recurse(v, &is_less, None, limit);
    }
}
#[cfg(test)]
mod tests {
    use super::heapsort;
    use rand::distributions::Uniform;
    use rand::{thread_rng, Rng};

    /// Checks `heapsort` on random inputs in both orders. Also checks that an inconsistent
    /// comparator neither panics nor loses elements.
    #[test]
    fn test_heapsort() {
        let mut rng = thread_rng();
        // Short lengths 0..25 plus one longer length (500).
        for len in (0..25).chain(500..501) {
            // Small moduli produce many duplicate keys.
            for &modulus in &[5, 10, 100] {
                let dist = Uniform::new(0, modulus);
                for _ in 0..100 {
                    let v: Vec<i32> = rng.sample_iter(&dist).take(len).collect();
                    // Ascending order.
                    let mut tmp = v.clone();
                    heapsort(&mut tmp, &|a, b| a < b);
                    assert!(tmp.windows(2).all(|w| w[0] <= w[1]));
                    // Descending order.
                    let mut tmp = v.clone();
                    heapsort(&mut tmp, &|a, b| a > b);
                    assert!(tmp.windows(2).all(|w| w[0] >= w[1]));
                }
            }
        }
        // Run with a comparator that answers randomly. The second, correct sort can only restore
        // `0..100` if the first call kept every element exactly once.
        let mut v: Vec<_> = (0..100).collect();
        heapsort(&mut v, &|_, _| thread_rng().gen());
        heapsort(&mut v, &|a, b| a < b);
        for i in 0..v.len() {
            assert_eq!(v[i], i);
        }
    }
}