use std::mem::MaybeUninit;
use std::ptr;
/// Largest range length (in elements) that `quicksort`, `quicksort_recursive`,
/// `merge_sort` and `merge_sort_with_scratch` hand straight to `insertion_sort`
/// instead of partitioning or splitting it further.
const INSERTION_SORT_THRESHOLD: usize = 20;
/// Sorts the elements of `v` in the half-open index range `start..end` using
/// insertion sort.
///
/// `is_less(a, b)` must return `true` when `a` has to be ordered before `b`.
/// An element is only moved past a neighbour that compares strictly greater,
/// so elements that compare equal keep their relative order.
#[inline]
pub fn insertion_sort<T, F>(
    v: &mut impl IndexedAccess<T>,
    start: usize,
    end: usize,
    is_less: &mut F,
) where
    F: FnMut(&T, &T) -> bool,
{
    for sorted_end in (start + 1)..end {
        // Walk the new element towards `start` one swap at a time until the
        // element before it is no longer greater.
        for cursor in ((start + 1)..=sorted_end).rev() {
            let prev = cursor - 1;
            if !is_less(v.get_ref(cursor), v.get_ref(prev)) {
                break;
            }
            v.swap(cursor, prev);
        }
    }
}
/// Sorts the elements of `v` in the half-open index range `start..end` with
/// heapsort.
///
/// `is_less(a, b)` must return `true` when `a` has to be ordered before `b`.
/// The range is first arranged as a max-heap (with respect to `is_less`); the
/// heap root is then repeatedly swapped to the end of the shrinking heap.
#[inline(never)]
pub fn heapsort<T, F>(v: &mut impl IndexedAccess<T>, start: usize, end: usize, is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    let count = end - start;
    if count < 2 {
        return;
    }

    // Heapify: sift down every node that has at least one child, last parent first.
    let mut parent = count / 2;
    while parent > 0 {
        parent -= 1;
        sift_down(v, start, parent, count, is_less);
    }

    // Move the current maximum behind the heap, then repair the smaller heap.
    let mut heap_len = count;
    while heap_len > 1 {
        heap_len -= 1;
        v.swap(start, start + heap_len);
        sift_down(v, start, 0, heap_len, is_less);
    }
}
/// Restores the max-heap property below `node` by sinking it until neither
/// child compares greater.
///
/// The heap occupies `heap_size` elements of `v` beginning at index `start`;
/// `node` and the child positions derived from it are offsets relative to
/// `start`.
#[inline]
fn sift_down<T, F>(
    v: &mut impl IndexedAccess<T>,
    start: usize,
    mut node: usize,
    heap_size: usize,
    is_less: &mut F,
) where
    F: FnMut(&T, &T) -> bool,
{
    let mut left = 2 * node + 1;
    while left < heap_size {
        // Prefer the right child only when it exists and is strictly greater.
        let right = left + 1;
        let larger = if right < heap_size
            && is_less(v.get_ref(start + left), v.get_ref(start + right))
        {
            right
        } else {
            left
        };

        // The parent is already at least as large as its biggest child: done.
        if !is_less(v.get_ref(start + node), v.get_ref(start + larger)) {
            return;
        }

        v.swap(start + node, start + larger);
        node = larger;
        left = 2 * node + 1;
    }
}
/// Sorts the elements of `v` in the half-open index range `start..end`.
///
/// Ranges of at most `INSERTION_SORT_THRESHOLD` elements go straight to
/// `insertion_sort`. Larger ranges are handed to `quicksort_recursive` with a
/// budget of `2 * (bit length of the range length)` partitioning rounds.
///
/// `is_less(a, b)` must return `true` when `a` has to be ordered before `b`.
pub fn quicksort<T, F>(v: &mut impl IndexedAccess<T>, start: usize, end: usize, is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    match end - start {
        // Nothing to order.
        0 | 1 => {}
        short if short <= INSERTION_SORT_THRESHOLD => insertion_sort(v, start, end, is_less),
        long => {
            let round_budget = 2 * (usize::BITS - long.leading_zeros());
            quicksort_recursive(v, start, end, is_less, round_budget);
        }
    }
}
/// Quicksort worker for the half-open index range `start..end` of `v`.
///
/// `limit` is the number of partitioning rounds still allowed; once it reaches
/// zero the remaining range is sorted with `heapsort`. Ranges of at most
/// `INSERTION_SORT_THRESHOLD` elements are finished with `insertion_sort`.
///
/// After each partition the function recurses into the smaller side and keeps
/// looping on the larger one.
fn quicksort_recursive<T, F>(
    v: &mut impl IndexedAccess<T>,
    start: usize,
    end: usize,
    is_less: &mut F,
    mut limit: u32,
) where
    F: FnMut(&T, &T) -> bool,
{
    let (mut lo, mut hi) = (start, end);
    loop {
        let span = hi - lo;
        if span <= INSERTION_SORT_THRESHOLD {
            insertion_sort(v, lo, hi, is_less);
            break;
        }
        if limit == 0 {
            heapsort(v, lo, hi, is_less);
            break;
        }
        limit -= 1;

        // Median of first / middle / last goes to the front, where `partition`
        // expects to find the pivot.
        let candidate = choose_pivot(v, lo, lo + span / 2, hi - 1, is_less);
        v.swap(lo, candidate);
        let pivot = partition(v, lo, hi, is_less);

        let left = (lo, pivot);
        let right = (pivot + 1, hi);
        let (smaller, larger) = if pivot - lo < hi - pivot - 1 {
            (left, right)
        } else {
            (right, left)
        };

        quicksort_recursive(v, smaller.0, smaller.1, is_less, limit);
        lo = larger.0;
        hi = larger.1;
    }
}
/// Returns whichever of the indices `a`, `b`, `c` holds the median of the
/// three elements, judged by `is_less`.
///
/// At most three comparisons are made, evaluated lazily in the order
/// `(a, b)`, then `(b, c)` or `(a, c)`, then the remaining pair if needed.
#[inline]
fn choose_pivot<T, F>(
    v: &impl IndexedAccess<T>,
    a: usize,
    b: usize,
    c: usize,
    is_less: &mut F,
) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    let mut lt = |x: usize, y: usize| is_less(v.get_ref(x), v.get_ref(y));

    if lt(a, b) {
        // a < b: the median is b unless c falls below it.
        if lt(b, c) {
            return b;
        }
        return if lt(a, c) { c } else { a };
    }

    // b <= a: the median is a unless c falls below it.
    if lt(a, c) {
        return a;
    }
    if lt(b, c) {
        c
    } else {
        b
    }
}
/// Partitions the half-open index range `start..end` of `v` around the pivot
/// stored at index `start` and returns the pivot's final index.
///
/// On return every element before the returned index compares less than the
/// pivot, and no element after it does.
fn partition<T, F>(
    v: &mut impl IndexedAccess<T>,
    start: usize,
    end: usize,
    is_less: &mut F,
) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    let mut lo = start + 1;
    let mut hi = end - 1;

    'scan: loop {
        // Skip elements on the left that already belong before the pivot.
        loop {
            if lo > hi {
                break 'scan;
            }
            if !is_less(v.get_ref(lo), v.get_ref(start)) {
                break;
            }
            lo += 1;
        }

        // Skip elements on the right that already belong after the pivot.
        loop {
            if lo > hi {
                break 'scan;
            }
            if is_less(v.get_ref(hi), v.get_ref(start)) {
                break;
            }
            hi -= 1;
        }

        // `lo` and `hi` each point at a misplaced element: exchange them.
        v.swap(lo, hi);
        lo += 1;
        hi -= 1;
    }

    // `hi` is now the last slot of the "less than pivot" side (or `start` itself).
    v.swap(start, hi);
    hi
}
/// Sorts the elements of `v` in the half-open index range `start..end` with a
/// top-down merge sort.
///
/// Ranges of at most `INSERTION_SORT_THRESHOLD` elements are sorted with
/// `insertion_sort`. For larger ranges a scratch buffer of `len / 2 + 1`
/// uninitialised slots is allocated once and shared by every merge step.
///
/// `is_less(a, b)` must return `true` when `a` has to be ordered before `b`.
pub fn merge_sort<T, F>(v: &mut impl IndexedAccess<T>, start: usize, end: usize, is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    match end - start {
        // Nothing to order.
        0 | 1 => {}
        short if short <= INSERTION_SORT_THRESHOLD => insertion_sort(v, start, end, is_less),
        long => {
            let slots = long / 2 + 1;
            // Every slot starts out uninitialised; `MaybeUninit` makes that a
            // valid state, so no element is ever read or dropped from here.
            let mut scratch: Vec<MaybeUninit<T>> = Vec::with_capacity(slots);
            scratch.resize_with(slots, MaybeUninit::uninit);
            merge_sort_with_scratch(v, start, end, &mut scratch, is_less);
        }
    }
}
/// Recursive merge-sort step for the half-open index range `start..end` of `v`.
///
/// The range is split at its midpoint, both halves are sorted recursively and
/// then combined with `merge`, which is skipped when the halves are already
/// in order across the split. `scratch` is the working buffer forwarded to
/// `merge`; it needs room for at least `(end - start) / 2` elements.
fn merge_sort_with_scratch<T, F>(
    v: &mut impl IndexedAccess<T>,
    start: usize,
    end: usize,
    scratch: &mut [MaybeUninit<T>],
    is_less: &mut F,
) where
    F: FnMut(&T, &T) -> bool,
{
    match end - start {
        0 | 1 => {}
        short if short <= INSERTION_SORT_THRESHOLD => insertion_sort(v, start, end, is_less),
        long => {
            let mid = start + long / 2;
            merge_sort_with_scratch(v, start, mid, scratch, is_less);
            merge_sort_with_scratch(v, mid, end, scratch, is_less);

            // Only merge when the first element of the right half must come
            // before the last element of the left half.
            let out_of_order = is_less(v.get_ref(mid), v.get_ref(mid - 1));
            if out_of_order {
                merge(v, start, mid, end, scratch, is_less);
            }
        }
    }
}
#[allow(clippy::needless_range_loop)]
fn merge<T, F>(
v: &mut impl IndexedAccess<T>,
start: usize,
mid: usize,
end: usize,
scratch: &mut [MaybeUninit<T>],
is_less: &mut F,
) where
F: FnMut(&T, &T) -> bool,
{
let left_len = mid - start;
let right_len = end - mid;
if left_len <= right_len {
for i in 0..left_len {
unsafe {
let val = ptr::read(v.get_ptr(start + i));
scratch[i].write(val);
}
}
let mut s = 0; let mut r = mid; let mut w = start;
while s < left_len && r < end {
let take_left = unsafe { !is_less(v.get_ref(r), scratch[s].assume_init_ref()) };
if take_left {
unsafe {
ptr::write(v.get_ptr_mut(w), scratch[s].assume_init_read());
}
s += 1;
} else {
unsafe {
let val = ptr::read(v.get_ptr(r));
ptr::write(v.get_ptr_mut(w), val);
}
r += 1;
}
w += 1;
}
while s < left_len {
unsafe {
ptr::write(v.get_ptr_mut(w), scratch[s].assume_init_read());
}
s += 1;
w += 1;
}
} else {
for i in 0..right_len {
unsafe {
let val = ptr::read(v.get_ptr(mid + i));
scratch[i].write(val);
}
}
let mut s = right_len; let mut l = mid; let mut w = end;
while s > 0 && l > start {
let take_right = unsafe { is_less(scratch[s - 1].assume_init_ref(), v.get_ref(l - 1)) };
w -= 1;
if take_right {
s -= 1;
unsafe {
ptr::write(v.get_ptr_mut(w), scratch[s].assume_init_read());
}
} else {
l -= 1;
unsafe {
let val = ptr::read(v.get_ptr(l));
ptr::write(v.get_ptr_mut(w), val);
}
}
}
while s > 0 {
s -= 1;
w -= 1;
unsafe {
ptr::write(v.get_ptr_mut(w), scratch[s].assume_init_read());
}
}
}
}
/// Index-based element access used by the sorting routines in this file.
///
/// The routines address elements purely through `usize` indices and only ever
/// pass indices inside the `start..end` range they were asked to sort.
/// NOTE(review): what an implementation does for an out-of-range index is not
/// specified here — confirm against the implementors.
pub trait IndexedAccess<T> {
/// Returns a shared reference to the element at `index`.
///
/// The sorting routines pass these references to the `is_less` comparator.
fn get_ref(&self, index: usize) -> &T;
/// Returns a raw const pointer to the element slot at `index`.
///
/// `merge` moves elements out of the collection with `ptr::read` on this
/// pointer, so it has to be valid for reading a `T`.
fn get_ptr(&self, index: usize) -> *const T;
/// Returns a raw mutable pointer to the element slot at `index`.
///
/// `merge` stores elements with `ptr::write` through this pointer (the old
/// contents of the slot are not dropped), so it has to be valid for writing
/// a `T`.
fn get_ptr_mut(&mut self, index: usize) -> *mut T;
/// Exchanges the elements at indices `a` and `b`.
///
/// NOTE(review): `partition` and `quicksort_recursive` can call this with
/// `a == b`; implementations presumably treat that as a no-op — confirm.
fn swap(&mut self, a: usize, b: usize);
}