use indicatif::{ProgressBar, ProgressStyle};
use std::cmp;
use std::marker::PhantomData;
use std::mem::{self, MaybeUninit};
use std::ptr;
/// Sorts `data` in parallel, ordering elements by the key that `f` extracts from each one.
///
/// The sort is unstable: elements with equal keys may be reordered. `f` is evaluated for both
/// operands of every comparison, so it runs many times per element and should be cheap.
/// `verbose` is forwarded to `par_quicksort`, which uses it to decide whether to draw a
/// progress bar.
pub fn par_sort_unstable_by_key<T, K, F>(data: &mut [T], f: F, verbose: bool)
where
    T: Send,
    K: Ord + Send,
    F: Fn(&T) -> K + Sync,
{
    // Strict "less than" on the extracted keys (left key is computed first).
    let key_less = |lhs: &T, rhs: &T| -> bool { f(lhs) < f(rhs) };
    par_quicksort(data, key_less, verbose);
}
/// Drop guard that bitwise-copies the value at `src` into `dest` when it goes out of scope.
///
/// The sorting routines in this file take an element out of a slice (leaving a "hole") and
/// use this guard so the element is written back into the hole even if a comparison panics.
#[must_use]
struct CopyOnDrop<'a, T> {
    /// Location the value is copied from.
    src: *const T,
    /// Location the value is copied to; callers may retarget it as the hole moves.
    dest: *mut T,
    /// Ties the guard to the borrow of `src`.
    marker: PhantomData<&'a mut T>,
}

impl<'a, T> CopyOnDrop<'a, T> {
    /// Creates a guard that copies `*src` into `dest` when dropped.
    ///
    /// # Safety
    ///
    /// When the guard is dropped, `dest` must be valid for a write of one `T` and must not
    /// overlap `src`. The copy is bitwise, so the caller must ensure the value is not also
    /// dropped through `src` (for example by keeping it in a `ManuallyDrop`).
    unsafe fn new(src: &'a T, dest: *mut T) -> Self {
        let src: *const T = src;
        Self {
            src,
            dest,
            marker: PhantomData,
        }
    }
}

impl<T> Drop for CopyOnDrop<'_, T> {
    fn drop(&mut self) {
        // SAFETY: `CopyOnDrop::new`'s contract guarantees `dest` is writable and disjoint from
        // `src`, and `src` is still borrowed for the guard's lifetime, so it is readable.
        unsafe { self.dest.copy_from_nonoverlapping(self.src, 1) }
    }
}
/// Shifts the first element of `v` to the right until it reaches an element that is not less
/// than it. The result is sorted only if `v[1..]` was already sorted.
///
/// If a call to `is_less` panics, the `CopyOnDrop` guard writes the saved element into the
/// current hole, so the slice still holds every element exactly once.
fn shift_head<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    let len = v.len();
    // SAFETY: every index used below is `< len`, and the hole is always refilled by `hole`.
    unsafe {
        // Nothing to do unless the first two elements are out of order.
        if len >= 2 && is_less(v.get_unchecked(1), v.get_unchecked(0)) {
            // Bitwise copy of `v[0]`; `ManuallyDrop` prevents it from being dropped twice.
            let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(0)));
            let v = v.as_mut_ptr();
            // The guard tracks the current hole and fills it with `tmp` on drop (also on panic).
            let mut hole = CopyOnDrop::new(&*tmp, v.add(1));
            ptr::copy_nonoverlapping(v.add(1), v.add(0), 1);
            for i in 2..len {
                if !is_less(&*v.add(i), &*tmp) {
                    break;
                }
                // Move `v[i]` one slot to the left; the hole moves to index `i`.
                ptr::copy_nonoverlapping(v.add(i), v.add(i - 1), 1);
                hole.dest = v.add(i);
            }
            // `hole` is dropped here and writes `tmp` into its final position.
        }
    }
}
/// Shifts the last element of `v` to the left until it reaches an element that it is not less
/// than. The result is sorted only if `v[..len - 1]` was already sorted.
///
/// If a call to `is_less` panics, the `CopyOnDrop` guard writes the saved element into the
/// current hole, so the slice still holds every element exactly once.
fn shift_tail<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    let len = v.len();
    // SAFETY: every index used below is `< len`, and the hole is always refilled by `hole`.
    unsafe {
        // Nothing to do unless the last two elements are out of order.
        if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) {
            // Bitwise copy of the last element; `ManuallyDrop` prevents a double drop.
            let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1)));
            let v = v.as_mut_ptr();
            // The guard tracks the current hole and fills it with `tmp` on drop (also on panic).
            let mut hole = CopyOnDrop::new(&*tmp, v.add(len - 2));
            ptr::copy_nonoverlapping(v.add(len - 2), v.add(len - 1), 1);
            for i in (0..len - 2).rev() {
                if !is_less(&*tmp, &*v.add(i)) {
                    break;
                }
                // Move `v[i]` one slot to the right; the hole moves to index `i`.
                ptr::copy_nonoverlapping(v.add(i), v.add(i + 1), 1);
                hole.dest = v.add(i);
            }
            // `hole` is dropped here and writes `tmp` into its final position.
        }
    }
}
/// Tries to finish sorting `v` by repairing a handful of adjacent out-of-order pairs.
///
/// Returns `true` if `v` ends up completely sorted. At most `MAX_STEPS` pairs are repaired.
/// Slices shorter than `SHORTEST_SHIFTING` are only scanned, never modified, and `false` is
/// returned as soon as an inversion is found in them.
#[cold]
fn partial_insertion_sort<T, F>(v: &mut [T], is_less: &F) -> bool
where
    F: Fn(&T, &T) -> bool,
{
    // Upper bound on the number of adjacent out-of-order pairs that get repaired.
    const MAX_STEPS: usize = 5;
    // Below this length no elements are shifted at all.
    const SHORTEST_SHIFTING: usize = 50;

    let n = v.len();
    let mut cursor = 1;
    let mut steps_left = MAX_STEPS;

    while steps_left > 0 {
        steps_left -= 1;

        // Advance past pairs that are already in order.
        // SAFETY: `cursor < n` is checked before each access and `cursor >= 1` always holds,
        // so both `cursor` and `cursor - 1` are in bounds.
        unsafe {
            while cursor < n && !is_less(v.get_unchecked(cursor), v.get_unchecked(cursor - 1)) {
                cursor += 1;
            }
        }

        if cursor == n {
            // Reached the end without an inversion: the slice is sorted.
            return true;
        }
        if n < SHORTEST_SHIFTING {
            // Shifting is not worth it on short slices.
            return false;
        }

        // Order the offending pair, then move the smaller element left and the larger right.
        v.swap(cursor - 1, cursor);
        shift_tail(&mut v[..cursor], is_less);
        shift_head(&mut v[cursor..], is_less);
    }

    // Ran out of steps before the slice was confirmed sorted.
    false
}
/// Sorts `v` with insertion sort: for each prefix `v[..end]` (in increasing `end`), the last
/// element is shifted left into place, keeping the prefix sorted.
fn insertion_sort<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    let n = v.len();
    let mut end = 2;
    while end <= n {
        shift_tail(&mut v[..end], is_less);
        end += 1;
    }
}
/// Sorts `v` with heapsort using `is_less` as the ordering.
///
/// Builds a max-heap in place, then repeatedly swaps the maximum to the end of the shrinking
/// heap. Only safe indexing is used, so even an inconsistent comparator cannot cause UB.
#[cold]
fn heapsort<T, F>(v: &mut [T], is_less: &F)
where
    F: Fn(&T, &T) -> bool,
{
    // Restores the max-heap property for the subtree rooted at `root` within `heap`.
    let sift_down = |heap: &mut [T], mut root: usize| {
        let size = heap.len();
        loop {
            let left = 2 * root + 1;
            if left >= size {
                break;
            }
            let right = left + 1;
            // Pick the greater of the two children.
            let larger = if right < size && is_less(&heap[left], &heap[right]) {
                right
            } else {
                left
            };
            // Stop once the parent is not smaller than its larger child.
            if !is_less(&heap[root], &heap[larger]) {
                break;
            }
            heap.swap(root, larger);
            root = larger;
        }
    };

    let n = v.len();

    // Build the heap bottom-up.
    for root in (0..n / 2).rev() {
        sift_down(v, root);
    }

    // Move the current maximum to the end and shrink the heap by one.
    for end in (1..n).rev() {
        v.swap(0, end);
        sift_down(&mut v[..end], 0);
    }
}
/// Partitions `v` so that every element less than `pivot` comes before every element that is
/// not less than `pivot`, and returns the number of elements less than `pivot`.
///
/// The slice is processed from both ends, one block of at most `BLOCK` elements per side at a
/// time. Each block is first scanned to record, in a byte buffer, the offsets of its misplaced
/// elements (not less than the pivot on the left, less than the pivot on the right). The
/// offset is always written and the buffer end only advances for misplaced elements, so the
/// scan does not branch on the comparison result. Misplaced elements from the two sides are
/// then exchanged in bulk.
///
/// `is_less` is only called during the scans, while no element is being moved, so a panicking
/// comparison leaves `v` a permutation of its contents. The inner `width` helper asserts that
/// `T` is not zero-sized.
fn partition_in_blocks<T, F>(v: &mut [T], pivot: &T, is_less: &F) -> usize
where
    F: Fn(&T, &T) -> bool,
{
    // Maximum block length; every offset inside a block must fit in a `u8`.
    const BLOCK: usize = 128;
    // Left side: the current block starts at `l` and holds `block_l` elements.
    // `start_l..end_l` are the not-yet-exchanged offsets of misplaced elements in it.
    let mut l = v.as_mut_ptr();
    let mut block_l = BLOCK;
    let mut start_l = ptr::null_mut();
    let mut end_l = ptr::null_mut();
    let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK];
    // Right side: the current block ends at `r` (exclusive) and holds `block_r` elements.
    // Offsets in `start_r..end_r` are counted backwards from `r`.
    let mut r = unsafe { l.add(v.len()) };
    let mut block_r = BLOCK;
    let mut start_r = ptr::null_mut();
    let mut end_r = ptr::null_mut();
    let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK];
    // Number of `T`s from `l` (inclusive) to `r` (exclusive).
    fn width<T>(l: *mut T, r: *mut T) -> usize {
        assert!(mem::size_of::<T>() > 0);
        (r as usize - l as usize) / mem::size_of::<T>()
    }
    loop {
        // When at most two blocks' worth of elements remain, this is the final round.
        let is_done = width(l, r) <= 2 * BLOCK;
        if is_done {
            // Elements between `l` and `r` not yet compared. A side with pending offsets has
            // already scanned a full `BLOCK`, so that block is excluded.
            let mut rem = width(l, r);
            if start_l < end_l || start_r < end_r {
                rem -= BLOCK;
            }
            // Size the block(s) still to be scanned so that they exactly cover the gap.
            if start_l < end_l {
                block_r = rem;
            } else if start_r < end_r {
                block_l = rem;
            } else {
                // No pending offsets on either side: split the remainder between them.
                block_l = rem / 2;
                block_r = rem - block_l;
            }
            debug_assert!(block_l <= BLOCK && block_r <= BLOCK);
            debug_assert!(width(l, r) == block_l + block_r);
        }
        if start_l == end_l {
            // Scan the left block (reusing the offset buffer from its start), recording
            // offsets of elements that are not less than the pivot.
            start_l = offsets_l.as_mut_ptr() as *mut u8;
            end_l = start_l;
            let mut elem = l;
            for i in 0..block_l {
                unsafe {
                    *end_l = i as u8;
                    end_l = end_l.offset(!is_less(&*elem, pivot) as isize);
                    elem = elem.offset(1);
                }
            }
        }
        if start_r == end_r {
            // Scan the right block backwards from `r`, recording offsets of elements that are
            // less than the pivot.
            start_r = offsets_r.as_mut_ptr() as *mut u8;
            end_r = start_r;
            let mut elem = r;
            for i in 0..block_r {
                unsafe {
                    elem = elem.offset(-1);
                    *end_r = i as u8;
                    end_r = end_r.offset(is_less(&*elem, pivot) as isize);
                }
            }
        }
        // Number of misplaced pairs that can be exchanged in this round.
        let count = cmp::min(width(start_l, end_l), width(start_r, end_r));
        if count > 0 {
            // Element at the current left offset.
            macro_rules! left {
                () => {
                    l.offset(*start_l as isize)
                };
            }
            // Element at the current right offset (counted backwards from `r`).
            macro_rules! right {
                () => {
                    r.offset(-(*start_r as isize) - 1)
                };
            }
            // Exchange the `count` pairs as one cyclic permutation: hold the first left
            // element in `tmp`, alternate right->left and left->right copies, and finally
            // write `tmp` into the last right slot. Only one temporary is needed.
            unsafe {
                let tmp = ptr::read(left!());
                ptr::copy_nonoverlapping(right!(), left!(), 1);
                for _ in 1..count {
                    start_l = start_l.offset(1);
                    ptr::copy_nonoverlapping(left!(), right!(), 1);
                    start_r = start_r.offset(1);
                    ptr::copy_nonoverlapping(right!(), left!(), 1);
                }
                ptr::copy_nonoverlapping(&tmp, right!(), 1);
                // `tmp`'s bits now live in the slice; do not drop them here.
                mem::forget(tmp);
                start_l = start_l.offset(1);
                start_r = start_r.offset(1);
            }
        }
        // A side whose offsets are all used up moves on to its next block.
        if start_l == end_l {
            l = unsafe { l.add(block_l) };
        }
        if start_r == end_r {
            r = unsafe { r.offset(-(block_r as isize)) };
        }
        if is_done {
            break;
        }
    }
    // At most one side still has misplaced elements, all inside its last block. Move them to
    // the far end of that block so the partition boundary can be reported.
    if start_l < end_l {
        debug_assert_eq!(width(l, r), block_l);
        // Swap each remaining misplaced left element (highest offset first) to just before
        // `r`, moving `r` down by one each time.
        while start_l < end_l {
            unsafe {
                end_l = end_l.offset(-1);
                ptr::swap(l.offset(*end_l as isize), r.offset(-1));
                r = r.offset(-1);
            }
        }
        width(v.as_mut_ptr(), r)
    } else if start_r < end_r {
        debug_assert_eq!(width(l, r), block_r);
        // Swap each remaining misplaced right element (highest offset first) to `l`, moving
        // `l` up by one each time.
        while start_r < end_r {
            unsafe {
                end_r = end_r.offset(-1);
                ptr::swap(l, r.offset(-(*end_r as isize) - 1));
                l = l.offset(1);
            }
        }
        width(v.as_mut_ptr(), l)
    } else {
        // Nothing left over: `l` is the boundary.
        width(v.as_mut_ptr(), l)
    }
}
/// Partitions `v` around the element at index `pivot`.
///
/// Returns `(mid, was_partitioned)`. Afterwards the pivot is at `v[mid]`, every element of
/// `v[..mid]` is less than it, and no element of `v[mid + 1..]` is less than it.
/// `was_partitioned` is `true` when the initial scans from both ends met without finding an
/// out-of-order pair, i.e. the elements were already partitioned around the pivot.
fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> (usize, bool)
where
    F: Fn(&T, &T) -> bool,
{
    let (mid, was_partitioned) = {
        // Move the pivot to the front and work on the rest of the slice.
        v.swap(0, pivot);
        let (pivot, v) = v.split_at_mut(1);
        let pivot = &mut pivot[0];
        // Compare against a bitwise copy of the pivot. The guard writes that copy back into
        // the pivot slot when this block ends (including when `is_less` panics), so changes
        // made to it through the `&T` given to `is_less` (e.g. interior mutability) are kept.
        let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
        let _pivot_guard = unsafe { CopyOnDrop::new(&*tmp, pivot) };
        let pivot = &*tmp;
        let mut l = 0;
        let mut r = v.len();
        // SAFETY: `l < r <= v.len()` is checked before every access.
        unsafe {
            // Skip the prefix that is already less than the pivot.
            while l < r && is_less(v.get_unchecked(l), pivot) {
                l += 1;
            }
            // Skip the suffix that is already not less than the pivot.
            while l < r && !is_less(v.get_unchecked(r - 1), pivot) {
                r -= 1;
            }
        }
        // Only the unresolved middle `v[l..r]` needs block partitioning.
        (
            l + partition_in_blocks(&mut v[l..r], pivot, is_less),
            l >= r,
        )
    };
    // Place the pivot between the two parts.
    v.swap(0, mid);
    (mid, was_partitioned)
}
/// Moves every element `x` for which `is_less(v[pivot], x)` is false to the front of `v`, and
/// returns how many such elements there are, counting the pivot itself (left at index 0).
///
/// This is a true "equal to pivot / greater than pivot" split only when no element of `v` is
/// less than the pivot. Its caller in this file (`recurse`) uses it only when the pivot is not
/// greater than a known lower bound of `v`.
fn partition_equal<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> usize
where
    F: Fn(&T, &T) -> bool,
{
    // Move the pivot to the front and work on the rest of the slice.
    v.swap(0, pivot);
    let (pivot, v) = v.split_at_mut(1);
    let pivot = &mut pivot[0];
    // Compare against a bitwise copy of the pivot. The guard writes it back into the pivot
    // slot when the function returns or when `is_less` panics.
    let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
    let _pivot_guard = unsafe { CopyOnDrop::new(&*tmp, pivot) };
    let pivot = &*tmp;
    let mut l = 0;
    let mut r = v.len();
    loop {
        // SAFETY: `l < r <= v.len()` is checked before every access, and the swap below
        // happens only when `l < r` (after `r -= 1`, `l <= r` with `l` and `r` in bounds).
        unsafe {
            // Skip elements that are not greater than the pivot.
            while l < r && !is_less(pivot, v.get_unchecked(l)) {
                l += 1;
            }
            // Skip trailing elements that are greater than the pivot.
            while l < r && is_less(pivot, v.get_unchecked(r - 1)) {
                r -= 1;
            }
            if l >= r {
                break;
            }
            // Exchange the misplaced pair.
            r -= 1;
            let ptr = v.as_mut_ptr();
            ptr::swap(ptr.add(l), ptr.add(r));
            l += 1;
        }
    }
    // `l` elements of the sub-slice went to the front; `+ 1` counts the pivot at index 0.
    l + 1
}
/// Shuffles three elements around the middle of `v` to disturb patterns that cause
/// consistently bad pivot choices.
///
/// Uses a xorshift generator seeded from `v.len()`, so the result is deterministic for a given
/// length. Slices shorter than 8 elements are left untouched.
#[cold]
fn break_patterns<T>(v: &mut [T]) {
    let len = v.len();
    if len < 8 {
        return;
    }

    // Xorshift32 state, seeded from the slice length.
    let mut state = len as u32;
    let mut next_u32 = || {
        state ^= state << 13;
        state ^= state >> 17;
        state ^= state << 5;
        state
    };
    // On 64-bit targets, two draws are combined (first draw in the high half).
    let mut next_usize = || -> usize {
        if usize::BITS > 32 {
            let hi = next_u32() as u64;
            let lo = next_u32() as u64;
            ((hi << 32) | lo) as usize
        } else {
            next_u32() as usize
        }
    };

    // Reduce modulo a power of two, then fold once into `0..len` (the value is `< 2 * len`).
    let mask = len.next_power_of_two() - 1;
    let centre = len / 4 * 2;
    for offset in 0..3 {
        let mut target = next_usize() & mask;
        if target >= len {
            target -= len;
        }
        v.swap(centre - 1 + offset, target);
    }
}
/// Chooses a pivot index for `v` and reports whether `v` is likely already sorted.
///
/// Takes the median of three candidates at roughly 1/4, 1/2 and 3/4 of the slice. For slices
/// of at least `SHORTEST_MEDIAN_OF_MEDIANS` elements, each candidate is first replaced by the
/// median of itself and its two neighbours. Slices shorter than 8 elements are not examined.
///
/// Returns `(pivot, likely_sorted)`, where `likely_sorted` is `true` if no candidate had to be
/// reordered. If the swap count reaches `MAX_SWAPS` (only possible on the median-of-medians
/// path), the slice is assumed to be descending: it is reversed and the mirrored pivot index
/// is returned together with `true`.
fn choose_pivot<T, F>(v: &mut [T], is_less: &F) -> (usize, bool)
where
    F: Fn(&T, &T) -> bool,
{
    // Minimum length for the median-of-medians method.
    const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
    // Maximum swaps performed here: four 3-element sorts of three comparisons each.
    const MAX_SWAPS: usize = 4 * 3;

    let len = v.len();
    let quarter = len / 4;
    let mut lo = quarter;
    let mut mid = quarter * 2;
    let mut hi = quarter * 3;
    let mut swaps = 0;

    if len >= 8 {
        // Reorders two candidate indices so that `v[*x] <= v[*y]`, counting each swap.
        let mut order2 = |x: &mut usize, y: &mut usize| {
            // SAFETY: with `len >= 8`, every candidate and each of its neighbours used below
            // lies within `0..len`.
            let out_of_order = unsafe { is_less(v.get_unchecked(*y), v.get_unchecked(*x)) };
            if out_of_order {
                mem::swap(x, y);
                swaps += 1;
            }
        };
        // Reorders three candidate indices so that `v[*x] <= v[*y] <= v[*z]`.
        let mut order3 = |x: &mut usize, y: &mut usize, z: &mut usize| {
            order2(x, y);
            order2(y, z);
            order2(x, y);
        };

        if len >= SHORTEST_MEDIAN_OF_MEDIANS {
            // Replace each candidate with the median of its neighbourhood.
            for candidate in [&mut lo, &mut mid, &mut hi] {
                let centre = *candidate;
                order3(&mut (centre - 1), candidate, &mut (centre + 1));
            }
        }

        // Median of the three candidates.
        order3(&mut lo, &mut mid, &mut hi);
    }

    if swaps >= MAX_SWAPS {
        // Every comparison swapped: most likely descending, so reverse to help later passes.
        v.reverse();
        (len - 1 - mid, true)
    } else {
        (mid, swaps == 0)
    }
}
/// Sorts `v` in place. When a partition produces a side longer than `MAX_SEQUENTIAL`, both
/// sides are sorted in parallel with `rayon_core::join`.
///
/// * `is_less` - strict "less than" ordering used for every comparison.
/// * `pred` - when `Some`, an element that no element of `v` is less than (the pivot of an
///   enclosing partition). It is used to detect runs of elements equal to the minimum.
/// * `limit` - how many more imbalanced partitions are tolerated before falling back to
///   `heapsort`.
/// * `pbar` - incremented by one every time the work is split into two parallel tasks.
fn recurse<'a, T, F>(
    mut v: &'a mut [T],
    is_less: &F,
    mut pred: Option<&'a mut T>,
    mut limit: u32,
    pbar: &ProgressBar,
) where
    T: Send,
    F: Fn(&T, &T) -> bool + Sync,
{
    // Slices up to this length are finished with insertion sort.
    const MAX_INSERTION: usize = 20;
    // If both sides of a partition are at most this long, keep working on the current thread.
    const MAX_SEQUENTIAL: usize = 2000;
    // Whether the previous partition was reasonably balanced.
    let mut was_balanced = true;
    // Whether the previous partition found the elements already partitioned.
    let mut was_partitioned = true;
    loop {
        let len = v.len();
        if len <= MAX_INSERTION {
            insertion_sort(v, is_less);
            return;
        }
        // Too many imbalanced partitions: switch to heapsort to bound the worst case.
        if limit == 0 {
            heapsort(v, is_less);
            return;
        }
        // After an imbalanced partition, shuffle a few elements to disturb the pattern.
        if !was_balanced {
            break_patterns(v);
            limit -= 1;
        }
        let (pivot, likely_sorted) = choose_pivot(v, is_less);
        // Cheap attempt to finish slices that already look sorted.
        if was_balanced && was_partitioned && likely_sorted {
            if partial_insertion_sort(v, is_less) {
                return;
            }
        }
        // If the pivot is not greater than `pred` (which bounds `v` from below), it is a
        // minimum of `v`: split off all elements equal to it and continue with the rest.
        if let Some(ref p) = pred {
            if !is_less(p, &v[pivot]) {
                let mid = partition_equal(v, pivot, is_less);
                v = &mut v[mid..];
                continue;
            }
        }
        let (mid, was_p) = partition(v, pivot, is_less);
        // Balanced means the smaller side holds at least 1/8 of the elements.
        was_balanced = cmp::min(mid, len - mid) >= len / 8;
        was_partitioned = was_p;
        // Split into `left` (< pivot), the pivot itself, and `right` (not < pivot).
        let (left, right) = v.split_at_mut(mid);
        let (pivot, right) = right.split_at_mut(1);
        let pivot = &mut pivot[0];
        if cmp::max(left.len(), right.len()) <= MAX_SEQUENTIAL {
            // Recurse into the shorter side and keep looping on the longer one, which keeps
            // the recursion depth small.
            if left.len() < right.len() {
                recurse(left, is_less, pred, limit, pbar);
                v = right;
                // No element of `right` is less than the pivot.
                pred = Some(pivot);
            } else {
                recurse(right, is_less, Some(pivot), limit, pbar);
                v = left;
            }
        } else {
            // Large partition: sort both sides in parallel, then this call is done.
            pbar.inc(1);
            rayon_core::join(
                || recurse(left, is_less, pred, limit, pbar),
                || recurse(right, is_less, Some(pivot), limit, pbar),
            );
            break;
        }
    }
}
/// Sorts `v` in parallel (unstable) according to the strict ordering `is_less`.
///
/// When `verbose` is `true`, a progress bar is drawn that advances once per parallel split;
/// otherwise a hidden bar is used. Slices of zero-sized types are returned unchanged.
pub fn par_quicksort<T, F>(v: &mut [T], is_less: F, verbose: bool)
where
    T: Send,
    F: Fn(&T, &T) -> bool + Sync,
{
    // Zero-sized values carry no data, so there is nothing to reorder.
    if mem::size_of::<T>() == 0 {
        return;
    }

    // floor(log2(len)) + 1 imbalanced partitions are allowed before `recurse` uses heapsort.
    let limit = usize::BITS - v.len().leading_zeros();

    let pbar = if !verbose {
        ProgressBar::hidden()
    } else {
        // Estimated number of parallel splits, based on the 2000-element sequential threshold
        // in `recurse`.
        let bar = ProgressBar::new((v.len() as f64 / 2000.0).ceil() as u64);
        let template = "{elapsed} elapsed (estimated duration {duration}) {bar:80}";
        bar.set_style(ProgressStyle::with_template(template).unwrap());
        bar
    };

    recurse(v, &is_less, None, limit, &pbar);
    pbar.finish();
}
#[cfg(test)]
mod tests {
    use super::heapsort;
    use rand::distributions::Uniform;
    use rand::{thread_rng, Rng};

    /// Checks `heapsort` on random inputs with both orderings, then checks that a comparator
    /// giving random answers does not lose or duplicate elements.
    #[test]
    fn test_heapsort() {
        let rng = &mut thread_rng();
        let lengths = (0..25).chain(500..501);
        for len in lengths {
            for modulus in [5, 10, 100] {
                let dist = Uniform::new(0, modulus);
                for _ in 0..100 {
                    let input: Vec<i32> = rng.sample_iter(&dist).take(len).collect();

                    let mut ascending = input.clone();
                    heapsort(&mut ascending, &|a, b| a < b);
                    assert!(ascending.windows(2).all(|pair| pair[0] <= pair[1]));

                    let mut descending = input.clone();
                    heapsort(&mut descending, &|a, b| a > b);
                    assert!(descending.windows(2).all(|pair| pair[0] >= pair[1]));
                }
            }
        }

        // A random comparator must leave a permutation behind: a proper sort afterwards
        // must recover exactly 0..100.
        let mut values: Vec<usize> = (0..100).collect();
        heapsort(&mut values, &|_, _| thread_rng().gen());
        heapsort(&mut values, &|a, b| a < b);
        for (expected, &actual) in values.iter().enumerate() {
            assert_eq!(actual, expected);
        }
    }
}