use crate::utils::assume;
use std::{
cmp::min,
mem::{self, ManuallyDrop, MaybeUninit, size_of},
ptr,
};
pub fn quicksort<T, R>(v: &mut [T], v2: &mut [R])
where
T: Ord,
{
unsafe { assume(v.len() == v2.len()) }
if size_of::<T>() == 0 {
return;
} else if size_of::<R>() == 0 {
v.sort_unstable();
return;
}
let limit = usize::BITS - v.len().leading_zeros();
recurse(v, v2, &mut |k1, k2| k1 < k2, None, limit);
}
/// Sorts `v` according to `is_less`, mirroring every element move in `v2`.
///
/// * `pred` - when `Some`, the pivot of an earlier partition that sits
///   immediately before this slice (it is set whenever the loop continues
///   with, or recurses into, the part to the right of a pivot).
/// * `limit` - number of imbalanced partitions still tolerated; once it
///   reaches zero the remaining slice is handed to `heapsort`.
///
/// Panics if `T` or `R` is zero-sized.
fn recurse<'a, T, R, F>(
mut v: &'a mut [T],
mut v2: &'a mut [R],
is_less: &mut F,
mut pred: Option<&'a T>,
mut limit: u32,
) where
F: FnMut(&T, &T) -> bool,
{
// Slices of up to this length are finished with insertion sort.
const MAX_INSERTION: usize = 20;
// The pointer arithmetic in the partitioning code cannot handle zero-sized types.
assert_ne!(size_of::<T>(), 0);
assert_ne!(size_of::<R>(), 0);
// Whether the previous partition was reasonably balanced.
let mut was_balanced = true;
// Whether the previous partition found the slice already partitioned.
let mut was_partitioned = true;
loop {
// NOTE(review): `assume` lives in `crate::utils`; presumably an unchecked
// optimizer hint, so equal lengths must really hold here - confirm.
unsafe { assume(v.len() == v2.len()) }
let len = v.len();
// Short slices: insertion sort and done.
if len <= MAX_INSERTION {
insertion_sort(v, v2, is_less);
return;
}
// Too many bad pivot choices: fall back to heapsort.
if limit == 0 {
unsafe { heapsort(v, v2, is_less) };
return;
}
// The last partition was imbalanced: shuffle a few elements and spend
// one unit of the imbalance budget.
if !was_balanced {
break_patterns(v, v2);
limit -= 1;
}
let (pivot, likely_sorted) = choose_pivot(v, v2, is_less);
// Everything so far suggests the slice is (nearly) sorted: try to finish
// it with a bounded number of insertion-sort steps.
if was_balanced && was_partitioned && likely_sorted {
if partial_insertion_sort(v, v2, is_less) {
return;
}
}
// The chosen pivot is not greater than the predecessor pivot. Group the
// elements that are not greater than the pivot at the front and keep
// sorting only what follows them.
if let Some(p) = pred
&& !is_less(p, &v[pivot])
{
let mid = partition_equal(v, v2, pivot, is_less);
v = &mut v[mid..];
v2 = &mut v2[mid..];
unsafe { assume(v.len() == v2.len()) }
continue;
}
// Regular partition; `mid` is the final index of the pivot.
let (mid, was_p) = partition(v, v2, pivot, is_less);
// Balanced means the smaller side holds at least an eighth of the slice.
was_balanced = min(mid, len - mid) >= len / 8;
was_partitioned = was_p;
// Split both slices into `left`, the pivot element, and `right`.
let (left, right) = v.split_at_mut(mid);
let (left2, right2) = v2.split_at_mut(mid);
let (pivot, right) = right.split_at_mut(1);
let (_pivot2, right2) = right2.split_at_mut(1);
let pivot = &pivot[0];
unsafe {
assume(left.len() == left2.len());
assume(right.len() == right2.len());
}
// Recurse into the shorter side and continue the loop with the longer one,
// which bounds the recursion depth.
if left.len() < right.len() {
recurse(left, left2, is_less, pred, limit);
v = right;
v2 = right2;
unsafe { assume(v.len() == v2.len()) }
// The loop continues to the right of the pivot, so it becomes the predecessor.
pred = Some(pivot);
} else {
recurse(right, right2, is_less, Some(pivot), limit);
v = left;
v2 = left2;
unsafe { assume(v.len() == v2.len()) }
}
}
}
/// Drop guard that, when it goes out of scope, copies one `T` from `src1`
/// to `dest1` and one `R` from `src2` to `dest2`.
///
/// Whoever builds the guard must keep all four pointers valid for a single
/// element read or write until the guard is dropped, and each source must
/// not overlap its destination.
struct CopyOnDrop<T, R> {
    /// Where the first value is read from on drop.
    src1: *const T,
    /// Where the first value is written on drop.
    dest1: *mut T,
    /// Where the second value is read from on drop.
    src2: *const R,
    /// Where the second value is written on drop.
    dest2: *mut R,
}

impl<T, R> Drop for CopyOnDrop<T, R> {
    fn drop(&mut self) {
        // SAFETY: validity and non-overlap of the pointers is the
        // responsibility of the code that constructed the guard.
        unsafe {
            self.src1.copy_to_nonoverlapping(self.dest1, 1);
            self.src2.copy_to_nonoverlapping(self.dest2, 1);
        }
    }
}
/// Moves the first element of `v` to the right past every following element
/// that is less than it, stopping at the first element that is not.
///
/// The element at index 0 of `v2` is moved to the same final position, and
/// the elements of `v2` in between are shifted exactly like those of `v`.
/// Does nothing when `v` has fewer than two elements or `v[1]` is not less
/// than `v[0]`.
fn shift_head<T, R, F>(v: &mut [T], v2: &mut [R], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// NOTE(review): `assume` comes from `crate::utils`; presumably an unchecked
// hint, so callers must pass slices of equal length - confirm.
unsafe { assume(v.len() == v2.len()) }
let len = v.len();
// NOTE(review): the unchecked accesses below use indices below `len`
// (= `v.len()`); for `v2` they are in bounds only if the lengths are equal.
unsafe {
if len >= 2 && is_less(v.get_unchecked(1), v.get_unchecked(0)) {
// Take bitwise copies of the first elements. `ManuallyDrop` prevents
// them from being dropped here, since they are copied back into the
// slices by the guard below.
let tmp = ManuallyDrop::new(ptr::read(v.get_unchecked(0)));
let tmp2 = ManuallyDrop::new(ptr::read(v2.get_unchecked(0)));
let v = v.as_mut_ptr();
let v2 = v2.as_mut_ptr();
// `hole` tracks the slot that currently has no owner. When it is dropped
// (at the end of this block, or while unwinding if `is_less` panics)
// its `Drop` impl copies `tmp`/`tmp2` into that slot.
let mut hole = CopyOnDrop {
src1: &*tmp,
dest1: v.add(1),
src2: &*tmp2,
dest2: v2.add(1),
};
// Move element 1 into slot 0; slot 1 is now the hole.
ptr::copy_nonoverlapping(v.add(1), v.add(0), 1);
ptr::copy_nonoverlapping(v2.add(1), v2.add(0), 1);
for i in 2..len {
// Stop at the first element that is not less than the saved one.
if !is_less(&*v.add(i), &*tmp) {
break;
}
// Shift element `i` one slot to the left; the hole moves to `i`.
ptr::copy_nonoverlapping(v.add(i), v.add(i - 1), 1);
ptr::copy_nonoverlapping(v2.add(i), v2.add(i - 1), 1);
hole.dest1 = v.add(i);
hole.dest2 = v2.add(i);
}
}
}
}
/// Moves the last element of `v` to the left past every preceding element
/// that it is less than, stopping at the first element for which that is
/// not the case.
///
/// The last element of `v2` is moved to the same final position, and the
/// elements of `v2` in between are shifted exactly like those of `v`.
/// Does nothing when `v` has fewer than two elements or the last element is
/// not less than the one before it.
fn shift_tail<T, R, F>(v: &mut [T], v2: &mut [R], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();
// NOTE(review): the unchecked accesses below use indices below `len`
// (= `v.len()`); for `v2` they are in bounds only if the lengths are equal,
// which `assume` (from `crate::utils`, presumably an unchecked hint) takes
// for granted - confirm callers uphold it.
unsafe {
assume(v.len() == v2.len());
if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) {
// Take bitwise copies of the last elements. `ManuallyDrop` prevents them
// from being dropped here, since they are copied back into the slices
// by the guard below.
let tmp = ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1)));
let tmp2 = ManuallyDrop::new(ptr::read(v2.get_unchecked(len - 1)));
let v = v.as_mut_ptr();
let v2 = v2.as_mut_ptr();
// `hole` tracks the slot that currently has no owner. When it is dropped
// (at the end of this block, or while unwinding if `is_less` panics)
// its `Drop` impl copies `tmp`/`tmp2` into that slot.
let mut hole = CopyOnDrop {
src1: &*tmp,
dest1: v.add(len - 2),
src2: &*tmp2,
dest2: v2.add(len - 2),
};
// Move the second-to-last element into the last slot; the hole is now at `len - 2`.
ptr::copy_nonoverlapping(v.add(len - 2), v.add(len - 1), 1);
ptr::copy_nonoverlapping(v2.add(len - 2), v2.add(len - 1), 1);
for i in (0..len - 2).rev() {
// Stop at the first element the saved one is not less than.
if !is_less(&*tmp, &*v.add(i)) {
break;
}
// Shift element `i` one slot to the right; the hole moves to `i`.
ptr::copy_nonoverlapping(v.add(i), v.add(i + 1), 1);
ptr::copy_nonoverlapping(v2.add(i), v2.add(i + 1), 1);
hole.dest1 = v.add(i);
hole.dest2 = v2.add(i);
}
}
}
}
/// Tries to finish sorting an almost-sorted `v` (mirroring all moves in
/// `v2`) by repairing at most `MAX_STEPS` out-of-order adjacent pairs.
///
/// Returns `true` if the scan reached the end of the slice, i.e. no
/// out-of-order pair is left. Returns `false` if the step budget ran out,
/// or if an out-of-order pair was found in a slice shorter than
/// `SHORTEST_SHIFTING` (in which case nothing is moved).
#[cold]
fn partial_insertion_sort<T, R, F>(v: &mut [T], v2: &mut [R], is_less: &mut F) -> bool
where
    F: FnMut(&T, &T) -> bool,
{
    // Maximum number of out-of-order pairs that will be repaired.
    const MAX_STEPS: usize = 5;
    // Shorter slices are never shifted; they are only checked.
    const SHORTEST_SHIFTING: usize = 50;

    // Callers are expected to pass slices of equal length.
    unsafe { assume(v.len() == v2.len()) }

    let n = v.len();
    let mut cursor = 1;
    let mut steps_left = MAX_STEPS;

    while steps_left > 0 {
        steps_left -= 1;

        // Advance to the next position whose element is smaller than its
        // left neighbour.
        loop {
            if cursor >= n {
                break;
            }
            // SAFETY: `1 <= cursor < n`, so both indices are in bounds.
            let descending =
                unsafe { is_less(v.get_unchecked(cursor), v.get_unchecked(cursor - 1)) };
            if descending {
                break;
            }
            cursor += 1;
        }

        // Reached the end without finding another inversion: sorted.
        if cursor == n {
            return true;
        }

        // Not worth shifting elements around in short slices.
        if n < SHORTEST_SHIFTING {
            return false;
        }

        // Swap the offending pair, then sink the smaller element into the
        // prefix and the larger one into the suffix.
        let prev = cursor - 1;
        v.swap(prev, cursor);
        v2.swap(prev, cursor);

        shift_tail(&mut v[..cursor], &mut v2[..cursor], is_less);
        shift_head(&mut v[cursor..], &mut v2[cursor..], is_less);
    }

    // Step budget exhausted before the scan reached the end.
    false
}
/// Insertion-sorts `v` according to `is_less`, mirroring every move in `v2`.
///
/// Each growing prefix is extended by one element, which `shift_tail` then
/// moves left into its place.
fn insertion_sort<T, R, F>(v: &mut [T], v2: &mut [R], is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    // Callers are expected to pass slices of equal length.
    unsafe { assume(v.len() == v2.len()) }

    // `end` is the exclusive upper bound of the prefix being extended; a
    // prefix of length one is trivially sorted, so start at two.
    for end in 2..=v.len() {
        shift_tail(&mut v[..end], &mut v2[..end], is_less);
    }
}
/// Heapsorts `v` according to `is_less`, applying every swap to `v2` too.
///
/// # Safety
///
/// `v` and `v2` must have the same length; this is passed to `assume`
/// without being checked here.
#[cold]
unsafe fn heapsort<T, R, F>(v: &mut [T], v2: &mut [R], mut is_less: F)
where
    F: FnMut(&T, &T) -> bool,
{
    unsafe { assume(v.len() == v2.len()) }

    let total = v.len();

    // Restores the max-heap property below `parent` within `keys`,
    // repeating each swap on `vals`.
    let mut sift_down = |keys: &mut [T], vals: &mut [R], mut parent: usize| {
        let n = keys.len();
        loop {
            let left = 2 * parent + 1;
            if left >= n {
                return;
            }

            // Pick the greater of the two children (the left one on ties
            // or when there is no right child).
            let right = left + 1;
            let larger = if right < n && is_less(&keys[left], &keys[right]) {
                right
            } else {
                left
            };

            // Heap property already holds at `parent`.
            if !is_less(&keys[parent], &keys[larger]) {
                return;
            }

            keys.swap(parent, larger);
            vals.swap(parent, larger);
            parent = larger;
        }
    };

    // Build the heap bottom-up, starting from the last node with a child.
    for root in (0..total / 2).rev() {
        sift_down(v, v2, root);
    }

    // Repeatedly move the maximum to the end and re-heapify the rest.
    for end in (1..total).rev() {
        v.swap(0, end);
        v2.swap(0, end);
        sift_down(&mut v[..end], &mut v2[..end], 0);
    }
}
/// Partitions `v` around `pivot`, processing it in blocks of up to `BLOCK`
/// elements from each end, and applies the same element moves to `v2`.
///
/// Returns the boundary index: elements of `v` before it compare less than
/// `pivot`, the remaining ones do not.
///
/// For each block the offsets of the misplaced elements are first collected
/// without branching on the comparison result, then misplaced elements from
/// the left and right blocks are exchanged pairwise.
/// NOTE(review): this looks like the block partitioning scheme known from
/// BlockQuicksort / the standard library's pdqsort - confirm.
fn partition_in_blocks<T, R, F>(v: &mut [T], v2: &mut [R], pivot: &T, is_less: &mut F) -> usize
where
F: FnMut(&T, &T) -> bool,
{
// Maximum block size. Offsets are stored as `u8`, so it must not exceed 256.
const BLOCK: usize = 128;
// NOTE(review): `assume` is from `crate::utils`; presumably an unchecked
// hint. The `v2` pointer arithmetic below relies on equal lengths - confirm.
unsafe { assume(v.len() == v2.len()) }
// Left side state:
// `l` - start of the current left block,
// `block_l` - number of elements in that block,
// `start_l..end_l` - pending offsets (relative to `l`) of elements that are
// not less than the pivot, stored in `offsets_l`.
let mut l = v.as_mut_ptr();
let mut block_l = BLOCK;
let mut start_l = ptr::null_mut();
let mut end_l = ptr::null_mut();
let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK];
// Right side state, mirrored:
// `r` - one past the end of the current right block,
// `block_r` - number of elements in that block,
// `start_r..end_r` - pending offsets (counted backwards from `r`) of
// elements that are less than the pivot, stored in `offsets_r`.
let mut r = unsafe { l.add(v.len()) };
let mut block_r = BLOCK;
let mut start_r = ptr::null_mut();
let mut end_r = ptr::null_mut();
let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK];
// Number of elements between `l` (inclusive) and `r` (exclusive).
#[inline(always)]
const fn width<T>(l: *mut T, r: *mut T) -> usize {
assert!(mem::size_of::<T>() != 0);
unsafe { r.offset_from(l) as usize }
}
loop {
// This is the last iteration once the unprocessed gap fits in two blocks.
let is_done = width(l, r) <= 2 * BLOCK;
if is_done {
// Shrink the block sizes so that the two blocks exactly cover what is
// left. A side that still has pending offsets keeps its full block.
let mut rem = width(l, r);
if start_l < end_l || start_r < end_r {
rem -= BLOCK;
}
if start_l < end_l {
block_r = rem;
} else if start_r < end_r {
block_l = rem;
} else {
// Both sides are empty: split the remainder between them.
block_l = rem / 2;
block_r = rem - block_l;
}
debug_assert!(block_l <= BLOCK && block_r <= BLOCK);
debug_assert!(width(l, r) == block_l + block_r);
}
// No pending offsets on the left: scan a fresh left block.
if std::ptr::eq(start_l, end_l) {
start_l = offsets_l.as_mut_ptr().cast();
end_l = start_l;
let mut elem = l;
for i in 0..block_l {
unsafe {
// Always write the offset, but only keep it (advance `end_l`) when
// the element is not less than the pivot. This avoids a branch.
*end_l = i as u8;
end_l = end_l.add(!is_less(&*elem, pivot) as usize);
elem = elem.add(1);
}
}
}
// No pending offsets on the right: scan a fresh right block, backwards.
if std::ptr::eq(start_r, end_r) {
start_r = offsets_r.as_mut_ptr().cast();
end_r = start_r;
let mut elem = r;
for i in 0..block_r {
unsafe {
// Same trick as above; here an offset is kept when the element is
// less than the pivot.
elem = elem.sub(1);
*end_r = i as u8;
end_r = end_r.add(is_less(&*elem, pivot) as usize);
}
}
}
// Number of misplaced pairs that can be exchanged right now.
let count = min(width(start_l, end_l), width(start_r, end_r));
if count > 0 {
// Pointers into `v2` at the same element positions as `l` and `r`.
let l2 = unsafe { v2.as_mut_ptr().offset(l.offset_from(v.as_ptr())) };
let r2 = unsafe { v2.as_mut_ptr().offset(r.offset_from(v.as_ptr())) };
// Addresses of the current misplaced element on each side, in `v`
// (`left!`/`right!`) and in `v2` (`left2!`/`right2!`).
macro_rules! left {
() => {
l.add(usize::from(*start_l))
};
}
macro_rules! left2 {
() => {
l2.add(usize::from(*start_l))
};
}
macro_rules! right {
() => {
r.sub(usize::from(*start_r) + 1)
};
}
macro_rules! right2 {
() => {
r2.sub(usize::from(*start_r) + 1)
};
}
unsafe {
// Exchange the pairs as one cyclic rotation: save the first left
// element, then alternately copy right -> left and left -> right,
// and finally store the saved element in the last right slot.
let tmp = ptr::read(left!());
let tmp2 = ptr::read(left2!());
ptr::copy_nonoverlapping(right!(), left!(), 1);
ptr::copy_nonoverlapping(right2!(), left2!(), 1);
for _ in 1..count {
start_l = start_l.add(1);
ptr::copy_nonoverlapping(left!(), right!(), 1);
ptr::copy_nonoverlapping(left2!(), right2!(), 1);
start_r = start_r.add(1);
ptr::copy_nonoverlapping(right!(), left!(), 1);
ptr::copy_nonoverlapping(right2!(), left2!(), 1);
}
// The saved values were bitwise-copied into the slices, so the locals
// must not be dropped.
ptr::copy_nonoverlapping(&tmp, right!(), 1);
mem::forget(tmp);
ptr::copy_nonoverlapping(&tmp2, right2!(), 1);
mem::forget(tmp2);
start_l = start_l.add(1);
start_r = start_r.add(1);
}
}
// A side whose offsets are all consumed is finished: step past its block.
if std::ptr::eq(start_l, end_l) {
l = unsafe { l.add(block_l) };
}
if std::ptr::eq(start_r, end_r) {
r = unsafe { r.sub(block_r) };
}
if is_done {
break;
}
}
// At most one side can still have pending offsets. Move its misplaced
// elements to the boundary, one by one.
if start_l < end_l {
// Left block has leftovers: swap them, last offset first, to just before `r`.
let l2 = unsafe { v2.as_mut_ptr().offset(l.offset_from(v.as_ptr())) };
let mut r2 = unsafe { v2.as_mut_ptr().offset(r.offset_from(v.as_ptr())) };
debug_assert_eq!(width(l, r), block_l);
debug_assert_eq!(width(l2, r2), block_l);
while start_l < end_l {
unsafe {
end_l = end_l.sub(1);
ptr::swap(l.add(usize::from(*end_l)), r.sub(1));
ptr::swap(l2.add(usize::from(*end_l)), r2.sub(1));
r = r.sub(1);
r2 = r2.sub(1);
}
}
width(v.as_mut_ptr(), r)
} else if start_r < end_r {
// Right block has leftovers: swap them, last offset first, to `l` onwards.
let mut l2 = unsafe { v2.as_mut_ptr().offset(l.offset_from(v.as_ptr())) };
let r2 = unsafe { v2.as_mut_ptr().offset(r.offset_from(v.as_ptr())) };
debug_assert_eq!(width(l, r), block_r);
debug_assert_eq!(width(l2, r2), block_r);
while start_r < end_r {
unsafe {
end_r = end_r.sub(1);
ptr::swap(l, r.sub(usize::from(*end_r) + 1));
ptr::swap(l2, r2.sub(usize::from(*end_r) + 1));
l = l.add(1);
l2 = l2.add(1);
}
}
width(v.as_mut_ptr(), l)
} else {
// Nothing pending on either side.
width(v.as_mut_ptr(), l)
}
}
/// Partitions `v` around the element at index `pivot`, applying the same
/// element moves to `v2`.
///
/// Returns `(mid, was_partitioned)`:
/// * `mid` - the index the pivot ends up at; elements before it compare
///   less than the pivot, elements after it do not.
/// * `was_partitioned` - `true` if the two initial scans met, i.e. no
///   element had to be handed to `partition_in_blocks` for moving.
fn partition<T, R, F>(v: &mut [T], v2: &mut [R], pivot: usize, is_less: &mut F) -> (usize, bool)
where
F: FnMut(&T, &T) -> bool,
{
let (mid, was_partitioned) = {
// NOTE(review): `assume` is from `crate::utils`; presumably an unchecked
// hint, so callers must pass slices of equal length - confirm.
unsafe { assume(v.len() == v2.len()) }
// Park the pivot (and its companion) at index 0.
v.swap(0, pivot);
v2.swap(0, pivot);
// From here on `v`/`v2` are the elements after the pivot.
let (pivot, v) = v.split_at_mut(1);
let (pivot2, v2) = v2.split_at_mut(1);
unsafe { assume(v.len() == v2.len()) }
debug_assert!(pivot.len() == 1 && pivot2.len() == 1);
let (pivot, pivot2) = (&mut pivot[0], &mut pivot2[0]);
// Work with a bitwise copy of the pivot. The guard copies it (and the
// companion value) back into slot 0 when it is dropped, which also
// happens if `is_less` panics.
let tmp = ManuallyDrop::new(unsafe { ptr::read(pivot) });
let tmp2 = ManuallyDrop::new(unsafe { ptr::read(pivot2) });
let _pivot_guard = CopyOnDrop {
src1: &*tmp,
dest1: pivot,
src2: &*tmp2,
dest2: pivot2,
};
let pivot = &*tmp;
// Skip the prefix that is already less than the pivot and the suffix
// that is already not less than it.
let mut l = 0;
let mut r = v.len();
unsafe {
while l < r && is_less(v.get_unchecked(l), pivot) {
l += 1;
}
while l < r && !is_less(v.get_unchecked(r - 1), pivot) {
r -= 1;
}
}
(
// Partition only the unresolved middle part; offset the result by `l`.
l + partition_in_blocks(&mut v[l..r], &mut v2[l..r], pivot, is_less),
l >= r,
)
};
// Here `v`/`v2` are the full slices again with the pivot at index 0.
// Swap it with the last element of the "less than" group.
v.swap(0, mid);
v2.swap(0, mid);
(mid, was_partitioned)
}
/// Moves the element at index `pivot` to the front of `v` and groups behind
/// it every element that the pivot is not less than; elements the pivot is
/// less than end up after that group. The same moves are applied to `v2`.
///
/// Returns the length of the leading group, including the pivot itself.
fn partition_equal<T, R, F>(v: &mut [T], v2: &mut [R], pivot: usize, is_less: &mut F) -> usize
where
F: FnMut(&T, &T) -> bool,
{
// NOTE(review): `assume` is from `crate::utils`; presumably an unchecked
// hint, so callers must pass slices of equal length - confirm.
unsafe { assume(v.len() == v2.len()) }
// Park the pivot (and its companion) at index 0.
v.swap(0, pivot);
v2.swap(0, pivot);
// From here on `v`/`v2` are the elements after the pivot.
let (pivot, v) = v.split_at_mut(1);
let (pivot2, v2) = v2.split_at_mut(1);
unsafe { assume(v.len() == v2.len()) };
debug_assert!(pivot.len() == 1 && pivot2.len() == 1);
let (pivot, pivot2) = (&mut pivot[0], &mut pivot2[0]);
// Work with a bitwise copy of the pivot. The guard copies it (and the
// companion value) back into slot 0 when it is dropped, which also happens
// if `is_less` panics.
let tmp = ManuallyDrop::new(unsafe { ptr::read(pivot) });
let tmp2 = ManuallyDrop::new(unsafe { ptr::read(pivot2) });
let _pivot_guard = CopyOnDrop {
src1: &*tmp,
dest1: pivot,
src2: &*tmp2,
dest2: pivot2,
};
let pivot = &*tmp;
// Unresolved range is `l..r`.
let mut l = 0;
let mut r = v.len();
loop {
unsafe {
// Advance `l` to the first element the pivot is less than.
while l < r && !is_less(pivot, v.get_unchecked(l)) {
l += 1;
}
// Retreat `r` past the elements the pivot is less than.
while l < r && is_less(pivot, v.get_unchecked(r - 1)) {
r -= 1;
}
if l >= r {
break;
}
// `v[l]` and `v[r - 1]` are both on the wrong side: exchange them
// (and their companions) and shrink the range from both ends.
r -= 1;
debug_assert!(l < v.len() && r < v.len());
let ptr = v.as_mut_ptr();
ptr::swap(ptr.add(l), ptr.add(r));
let ptr2 = v2.as_mut_ptr();
ptr::swap(ptr2.add(l), ptr2.add(r));
l += 1;
}
}
// `l` elements after the pivot belong to the group, plus the pivot itself.
l + 1
}
/// Swaps three elements around the middle of `v` with pseudo-randomly
/// chosen positions, mirroring the swaps in `v2`.
///
/// Slices shorter than eight elements are left untouched. The generator is
/// seeded from the slice length, so the result is deterministic.
#[cold]
fn break_patterns<T, R>(v: &mut [T], v2: &mut [R]) {
    /// One xorshift step (13, 17, 5) on a 32-bit state; returns the new state.
    fn next_u32(state: &mut u32) -> u32 {
        *state ^= *state << 13;
        *state ^= *state >> 17;
        *state ^= *state << 5;
        *state
    }

    // Callers are expected to pass slices of equal length.
    unsafe { assume(v.len() == v2.len()) }

    let len = v.len();
    if len < 8 {
        return;
    }

    let mut state = len as u32;

    // Candidates are reduced modulo the next power of two via a mask ...
    let mask = len.next_power_of_two() - 1;
    // ... and the three slots touched are `len / 4 * 2 - 1` and the two after it.
    let first_slot = len / 4 * 2 - 1;

    for step in 0..3 {
        // Draw as many random bits as a `usize` holds (high half first).
        let raw = if usize::BITS <= 32 {
            next_u32(&mut state) as usize
        } else {
            let high = next_u32(&mut state) as u64;
            let low = next_u32(&mut state) as u64;
            ((high << 32) | low) as usize
        };

        // The masked value is below `2 * len`, so one subtraction suffices
        // to bring it into range.
        let masked = raw & mask;
        let target = if masked >= len { masked - len } else { masked };

        let slot = first_slot + step;
        v.swap(slot, target);
        v2.swap(slot, target);
    }
}
/// Picks a pivot index for `v` and reports whether the slice looks sorted.
///
/// Returns `(index, likely_sorted)`. Candidates sit at roughly 1/4, 1/2 and
/// 3/4 of the slice. For slices of at least `SHORTEST_MEDIAN_OF_MEDIANS`
/// elements each candidate is first replaced by the median of itself and
/// its two neighbours. The median of the three candidates is the pivot.
///
/// Only indices are reordered while choosing. If the maximum possible number
/// of index swaps was needed, both `v` and `v2` are reversed and the
/// mirrored index is returned with `likely_sorted == true`. Otherwise
/// `likely_sorted` is `true` exactly when no swap was needed.
fn choose_pivot<T, R, F>(v: &mut [T], v2: &mut [R], is_less: &mut F) -> (usize, bool)
where
    F: FnMut(&T, &T) -> bool,
{
    // Minimum length for the median-of-medians refinement.
    const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
    // Upper bound on the index swaps performed below (4 sorts of 3 swaps).
    const MAX_SWAPS: usize = 4 * 3;

    // Callers are expected to pass slices of equal length.
    unsafe { assume(v.len() == v2.len()) }

    let len = v.len();
    let quarter = len / 4;

    // Candidate indices.
    let mut a = quarter;
    let mut b = quarter * 2;
    let mut c = quarter * 3;

    // Number of index swaps performed so far.
    let mut swaps = 0;

    if len >= 8 {
        // Orders two indices so that `v[*x]` is not greater than `v[*y]`.
        let mut sort2 = |x: &mut usize, y: &mut usize| {
            debug_assert!(*x < v.len() && *y < v.len());
            // SAFETY: every index passed in is derived from `len / 4` (plus
            // or minus one when `len >= 50`) and is therefore in bounds.
            let out_of_order = unsafe { is_less(v.get_unchecked(*y), v.get_unchecked(*x)) };
            if out_of_order {
                mem::swap(x, y);
                swaps += 1;
            }
        };

        // Orders three indices; afterwards the middle one is the median.
        let mut sort3 = |x: &mut usize, y: &mut usize, z: &mut usize| {
            sort2(x, y);
            sort2(y, z);
            sort2(x, y);
        };

        if len >= SHORTEST_MEDIAN_OF_MEDIANS {
            // Replaces `*mid` with the median of itself and its two neighbours.
            let mut sort_adjacent = |mid: &mut usize| {
                let centre = *mid;
                let mut before = centre - 1;
                let mut after = centre + 1;
                sort3(&mut before, mid, &mut after);
            };

            sort_adjacent(&mut a);
            sort_adjacent(&mut b);
            sort_adjacent(&mut c);
        }

        // Median of the three candidates ends up in `b`.
        sort3(&mut a, &mut b, &mut c);
    }

    if swaps >= MAX_SWAPS {
        // Every comparison asked for a swap: the slice is probably
        // descending, so reverse it and mirror the pivot index.
        v.reverse();
        v2.reverse();
        return (len - 1 - b, true);
    }

    (b, swaps == 0)
}