use core::cmp::Ordering;
use core::mem::{self, ManuallyDrop};
use core::ptr;
extern crate alloc;
use alloc::alloc::{alloc, dealloc, Layout};
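/// Sorts the slice in ascending order.
///
/// This sort is stable: elements that compare equal keep their original
/// relative order. Sorting allocates a scratch buffer of `v.len()` elements.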
#[inline(always)]
pub fn sort<T: Ord>(v: &mut [T]) {
stable_sort(v, |a, b| a.lt(b))
}
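/// Sorts the slice in ascending order as defined by the comparator `compare`.
///
/// The comparator should define a total order. The sort is stable: elements
/// that compare equal keep their original relative order.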
#[inline(always)]
pub fn sort_by<T, F: FnMut(&T, &T) -> Ordering>(v: &mut [T], mut compare: F) {
stable_sort(v, |a, b| compare(a, b) == Ordering::Less);
}
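/// Sorts the slice in ascending order of the keys returned by `f`.
///
/// The sort is stable: elements with equal keys keep their original relative
/// order. Note that `f` is re-evaluated for every comparison, so it should be
/// cheap.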
#[inline(always)]
pub fn sort_by_key<T, K, F>(v: &mut [T], mut f: F)
where
F: FnMut(&T) -> K,
K: Ord,
{
stable_sort(v, |a, b| f(a).lt(&f(b)));
}
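// Shared entry point for the public wrappers. Zero-sized types need no work
// (all values of a ZST are indistinguishable) and slices with fewer than two
// elements are already sorted, so both cases return before any allocation.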
#[inline(always)]
fn stable_sort<T, F: FnMut(&T, &T) -> bool>(v: &mut [T], mut is_less: F) {
if mem::size_of::<T>() == 0 {
return;
}
let len = v.len();
if len < 2 {
return;
}
unsafe {
mergesort_main(v, &mut is_less);
}
}
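// Allocates a scratch buffer of `v.len()` elements and runs the recursive
// merge sort. The buffer is freed when `buf` goes out of scope.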
#[inline(never)]
unsafe fn mergesort_main<T, F: FnMut(&T, &T) -> bool>(v: &mut [T], is_less: &mut F) {
let buf = unsafe { BufGuard::new(v.len()) };
unsafe {
mergesort_core(v, buf.buf_ptr.as_ptr(), is_less);
}
}
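// Recursive top-down merge sort.
//
// SAFETY: The caller must ensure that `scratch_ptr` is valid for `v.len()`
// writes.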
#[inline(always)]
unsafe fn mergesort_core<T, F: FnMut(&T, &T) -> bool>(
v: &mut [T],
scratch_ptr: *mut T,
is_less: &mut F,
) {
let len = v.len();
if len > 2 {
unsafe {
let mid = len / 2;
mergesort_core(v.get_unchecked_mut(..mid), scratch_ptr, is_less);
mergesort_core(v.get_unchecked_mut(mid..), scratch_ptr, is_less);
merge(v, scratch_ptr, is_less, mid);
}
} else if len == 2 {
// Sort the pair with a single comparison; ties are not swapped, which keeps
// the sort stable.
let should_swap = is_less(&v[1], &v[0]);
// SAFETY: len == 2, so both pointers are in bounds and distinct.
unsafe {
let arr_ptr = v.as_mut_ptr();
branchless_swap(arr_ptr.add(1), arr_ptr, should_swap);
}
}
}
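// Merges the sorted runs `v[..mid]` and `v[mid..]` via the scratch buffer and
// copies the merged result back into `v`.
//
// SAFETY: The caller must ensure that `mid` is in `1..v.len()` and that
// `scratch_ptr` is valid for `v.len()` writes.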
#[inline(always)]
unsafe fn merge<T, F>(v: &mut [T], scratch_ptr: *mut T, is_less: &mut F, mid: usize)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();
debug_assert!(mid > 0 && mid < len);
// `l` walks the left run `v[..mid]`, `r` walks the right run `v[mid..]`.
let mut l = 0;
let mut r = mid;
unsafe {
let arr_ptr = v.as_ptr();
for i in 0..len {
let left_ptr = arr_ptr.add(l);
let right_ptr = arr_ptr.add(r);
// Take from the left run unless the right element is strictly smaller;
// preferring the left side on ties keeps the merge stable.
let take_left = !is_less(&*right_ptr, &*left_ptr);
let copy_ptr = if take_left { left_ptr } else { right_ptr };
ptr::copy_nonoverlapping(copy_ptr, scratch_ptr.add(i), 1);
l += take_left as usize;
r += !take_left as usize;
// Stop as soon as either run is exhausted; the rest is copied in bulk below.
if (l == mid) | (r == len) {
break;
}
}
// Copy the tail of whichever run was not exhausted, then copy the fully
// merged result from the scratch buffer back into `v`.
let copy_ptr = if l == mid {
arr_ptr.add(r)
} else {
arr_ptr.add(l)
};
// `i` is the number of elements already merged into the scratch buffer.
let i = l + (r - mid);
ptr::copy_nonoverlapping(copy_ptr, scratch_ptr.add(i), len - i);
ptr::copy_nonoverlapping(scratch_ptr, v.as_mut_ptr(), len);
}
}
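// Conditionally swaps the values behind `x` and `y`. Both the swap and the
// no-swap case perform the same sequence of reads and writes, so the compiler
// can lower the pointer selects to conditional moves instead of a branch. The
// temporary is wrapped in `ManuallyDrop` so the bitwise copy made by
// `ptr::read` is never dropped.
//
// SAFETY: `x` and `y` must be valid for reads and writes and properly aligned.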
#[inline(always)]
unsafe fn branchless_swap<T>(x: *mut T, y: *mut T, should_swap: bool) {
unsafe {
let x_swap = if should_swap { y } else { x };
let y_swap = if should_swap { x } else { y };
let y_swap_copy = ManuallyDrop::new(ptr::read(y_swap));
ptr::copy(x_swap, x, 1);
ptr::copy_nonoverlapping(&*y_swap_copy, y, 1);
}
}
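// Extracts the value from `Some` without any panic path.
//
// SAFETY: Calling this with `None` is undefined behavior.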
unsafe fn unwrap_unchecked<T>(opt_val: Option<T>) -> T {
match opt_val {
Some(val) => val,
None => {
unsafe {
core::hint::unreachable_unchecked();
}
}
}
}
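// Owns the raw heap allocation used as merge scratch space. The guard only
// manages the allocation itself; the `T` values copied into it are bitwise
// copies and are never dropped through the guard.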
struct BufGuard<T> {
buf_ptr: ptr::NonNull<T>,
capacity: usize,
}
impl<T> BufGuard<T> {
// SAFETY: The caller must pass the length of an existing slice of `T`, with
// `len > 0` and `T` not zero-sized.
unsafe fn new(len: usize) -> Self {
debug_assert!(len > 0 && mem::size_of::<T>() > 0);
// SAFETY: `len` elements of `T` already exist as a slice, so `Layout::array`
// cannot overflow and cannot fail.
let layout = unsafe { unwrap_unchecked(Layout::array::<T>(len).ok()) };
// SAFETY: The layout has a non-zero size because `len > 0` and `T` is not
// zero-sized.
let buf_ptr = unsafe { alloc(layout) as *mut T };
if buf_ptr.is_null() {
panic!("allocation failure");
}
Self {
buf_ptr: ptr::NonNull::new(buf_ptr).unwrap(),
capacity: len,
}
}
}
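// Frees the scratch allocation. Any `T` values copied into the buffer are
// deliberately not dropped here; the slice being sorted still owns them.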
impl<T> Drop for BufGuard<T> {
fn drop(&mut self) {
unsafe {
dealloc(
self.buf_ptr.as_ptr() as *mut u8,
Layout::array::<T>(self.capacity).unwrap(),
);
}
}
}
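// A minimal sanity-check test module, added as an illustration of the public
// API. It assumes a standard test harness (`cargo test`) and uses fixed-size
// arrays only, so it does not rely on `std` collections.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sorts_integers() {
        let mut v = [5, 3, 8, 1, 9, 2, 7];
        sort(&mut v);
        assert_eq!(v, [1, 2, 3, 5, 7, 8, 9]);
    }

    #[test]
    fn sorts_descending_with_comparator() {
        let mut v = [1, 4, 2, 3];
        sort_by(&mut v, |a, b| b.cmp(a));
        assert_eq!(v, [4, 3, 2, 1]);
    }

    #[test]
    fn sort_by_key_is_stable() {
        // Elements with equal keys must keep their original relative order.
        let mut v = [(2, 'a'), (1, 'b'), (2, 'c'), (1, 'd')];
        sort_by_key(&mut v, |&(k, _)| k);
        assert_eq!(v, [(1, 'b'), (1, 'd'), (2, 'a'), (2, 'c')]);
    }
}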