use std::ptr;
use std::{
alloc, cmp,
cmp::Ordering,
mem::{MaybeUninit, align_of, size_of},
ops::Range,
};
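/// Reverses a type-erased slice in place: `v` is treated as a sequence of
/// elements that are each `val_size` bytes wide, and whole elements (not
/// individual bytes) are swapped end to end.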
fn vec_reverse(v: &mut [u8], val_size: usize) {
let len = v.len() / val_size;
let half_len = len / 2;
let Range { start, end } = v.as_mut_ptr_range();
let mut i = 0;
while i < half_len {
unsafe {
ptr::swap_nonoverlapping(
start.add(i * val_size),
end.sub((1 + i) * val_size),
val_size,
);
}
i += 1;
}
}
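/// Panic guard for `insert_tail`: when dropped it copies `val_size` bytes from
/// `src` (the scratch copy of the element being inserted) into `dest` (the
/// current gap in the slice), so the slice never ends up missing an element.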
struct InsertionHole {
src: *const u8,
dest: *mut u8,
val_size: usize,
}
impl Drop for InsertionHole {
fn drop(&mut self) {
unsafe {
ptr::copy_nonoverlapping(self.src, self.dest, self.val_size);
}
}
}
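/// Inserts the last element of `v` (viewed as elements of `val_size` bytes)
/// into the sorted prefix `v[..len - 1]`. If it is smaller than its
/// predecessor it is copied to `scratch`, larger prefix elements are shifted
/// one slot right, and the `InsertionHole` guard writes it back into the gap,
/// even if a comparison panics.
///
/// SAFETY: the caller must ensure that `v` holds at least two elements and
/// that `scratch` can hold one element of `val_size` bytes.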
unsafe fn insert_tail(
v: &mut [u8],
val_size: usize,
is_less: &dyn Fn(*const u8, *const u8) -> bool,
scratch: *mut u8,
) {
debug_assert!(v.len() / val_size >= 2);
let arr_ptr = v.as_mut_ptr();
let i = v.len() / val_size - 1;
unsafe {
let i_ptr = arr_ptr.add(i * val_size);
if is_less(i_ptr, i_ptr.sub(val_size)) {
ptr::copy_nonoverlapping(i_ptr, scratch, val_size);
let mut hole = InsertionHole {
src: scratch,
dest: i_ptr.sub(val_size),
val_size,
};
ptr::copy_nonoverlapping(hole.dest, i_ptr, val_size);
for j in (0..(i - 1)).rev() {
let j_ptr = arr_ptr.add(j * val_size);
if !is_less(scratch, j_ptr) {
break;
}
ptr::copy_nonoverlapping(j_ptr, hole.dest, val_size);
hole.dest = j_ptr;
}
}
}
}
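/// Insertion sort for a type-erased slice whose first `offset` elements are
/// already sorted: each element from `offset` onwards is inserted into the
/// growing sorted prefix via `insert_tail`.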
#[inline(never)]
pub(super) fn insertion_sort_shift_left(
v: &mut [u8],
val_size: usize,
offset: usize,
is_less: &dyn Fn(*const u8, *const u8) -> bool,
scratch: *mut u8,
) {
let len = v.len() / val_size;
assert!(offset != 0 && offset <= len);
for i in offset..len {
unsafe {
insert_tail(
&mut v[..(i + 1) * val_size],
val_size,
is_less,
scratch,
);
}
}
}
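/// Merges the two adjacent sorted runs `v[..mid]` and `v[mid..]` (counted in
/// elements of `val_size` bytes) into one sorted run, copying the shorter run
/// into `buf` and merging back into `v`. Ties are resolved in favour of the
/// left run, which keeps the sort stable; the nested `MergeHole` guard copies
/// whatever remains of the buffered run back into the hole, both on panic and
/// on normal exit.
///
/// SAFETY: the caller must ensure that `mid` splits `v` into two non-empty
/// runs and that `buf` has room for `min(mid, len - mid)` elements.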
unsafe fn merge(
v: &mut [u8],
val_size: usize,
mid: usize,
buf: *mut u8,
is_less: &dyn Fn(*const u8, *const u8) -> bool,
) {
let len = v.len() / val_size;
let v = v.as_mut_ptr();
let (v_mid, v_end) = unsafe { (v.add(mid * val_size), v.add(len * val_size)) };
let mut hole;
if mid <= len - mid {
unsafe {
ptr::copy_nonoverlapping(v, buf, mid * val_size);
hole = MergeHole {
start: buf,
end: buf.add(mid * val_size),
dest: v,
};
}
let left = &mut hole.start;
let mut right = v_mid;
let out = &mut hole.dest;
while *left < hole.end && right < v_end {
unsafe {
let is_l = is_less(right, *left);
let to_copy = if is_l { right } else { *left };
ptr::copy_nonoverlapping(to_copy, *out, val_size);
*out = out.add(val_size);
right = right.add((is_l as usize) * val_size);
*left = left.add((!is_l as usize) * val_size);
}
}
} else {
unsafe {
ptr::copy_nonoverlapping(v_mid, buf, (len - mid) * val_size);
hole = MergeHole {
start: buf,
end: buf.add((len - mid) * val_size),
dest: v_mid,
};
}
let left = &mut hole.dest;
let right = &mut hole.end;
let mut out = v_end;
while v < *left && buf < *right {
unsafe {
let is_l = is_less(right.sub(val_size), left.sub(val_size));
*left = left.sub((is_l as usize) * val_size);
*right = right.sub((!is_l as usize) * val_size);
let to_copy = if is_l { *left } else { *right };
out = out.sub(val_size);
ptr::copy_nonoverlapping(to_copy, out, val_size);
}
}
}
struct MergeHole {
start: *mut u8,
end: *mut u8,
dest: *mut u8,
}
impl Drop for MergeHole {
fn drop(&mut self) {
unsafe {
let len = self.end as usize - self.start as usize;
ptr::copy_nonoverlapping(self.start, self.dest, len);
}
}
}
}
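/// Type-erased stable merge sort, modelled on the standard library's TimSort.
///
/// `v` is a byte slice holding `v.len() / val_size` elements, each `val_size`
/// bytes wide with alignment `align`. `is_less` compares two elements through
/// raw pointers and `scratch` must point to writable space for one element
/// (it is used by the insertion-sort paths). Slices of at most
/// `MAX_INSERTION` elements are sorted entirely by insertion sort; longer
/// slices are split into natural runs that are pushed onto a run stack and
/// merged whenever `collapse` says the stack invariants are violated.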
pub fn merge_sort(
v: &mut [u8],
val_size: usize,
align: usize,
is_less: &dyn Fn(*const u8, *const u8) -> bool,
scratch: *mut u8,
) {
const MAX_INSERTION: usize = 20;
if val_size == 0 {
return;
}
debug_assert_eq!(v.len() % val_size, 0);
let len = v.len() / val_size;
if len <= MAX_INSERTION {
if len >= 2 {
insertion_sort_shift_left(v, val_size, 1, is_less, scratch);
}
return;
}
let buf = BufGuard::new(len / 2, val_size, align);
let buf_ptr = buf.buf_ptr.as_ptr();
let mut runs = RunVec::new();
let mut end = 0;
let mut start = 0;
while end < len {
let (streak_end, was_reversed) = find_streak(&v[start * val_size..], val_size, is_less);
end += streak_end;
if was_reversed {
vec_reverse(&mut v[start * val_size..end * val_size], val_size);
}
end = provide_sorted_batch(v, val_size, start, end, is_less, scratch);
runs.push(TimSortRun {
start,
len: end - start,
});
start = end;
while let Some(r) = collapse(runs.as_slice(), len) {
let left = runs[r];
let right = runs[r + 1];
let merge_slice = &mut v[left.start * val_size..(right.start + right.len) * val_size];
unsafe {
merge(merge_slice, val_size, left.len, buf_ptr, is_less);
}
runs[r + 1] = TimSortRun {
start: left.start,
len: left.len + right.len,
};
runs.remove(r);
}
}
debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);
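// Examines the stack of runs and returns the index of the left run of the
// next pair to merge, or `None` if the TimSort stack invariants already hold
// (this mirrors the `collapse` rule used by the standard library's TimSort).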
#[inline]
fn collapse(runs: &[TimSortRun], stop: usize) -> Option<usize> {
let n = runs.len();
if n >= 2
&& (runs[n - 1].start + runs[n - 1].len == stop
|| runs[n - 2].len <= runs[n - 1].len
|| (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len)
|| (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len))
{
if n >= 3 && runs[n - 3].len < runs[n - 1].len {
Some(n - 3)
} else {
Some(n - 2)
}
} else {
None
}
}
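// RAII guard for the merge buffer: allocates room for `capacity` elements of
// `val_size` bytes with the given alignment and frees it on drop.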
struct BufGuard {
buf_ptr: ptr::NonNull<u8>,
capacity: usize,
val_size: usize,
align: usize,
}
impl BufGuard {
fn new(len: usize, val_size: usize, align: usize) -> Self {
let buf_ptr = unsafe {
alloc::alloc(alloc::Layout::from_size_align_unchecked(
val_size * len,
align,
))
};
Self {
buf_ptr: ptr::NonNull::new(buf_ptr).unwrap(),
capacity: len,
val_size,
align,
}
}
}
impl Drop for BufGuard {
fn drop(&mut self) {
unsafe {
alloc::dealloc(
self.buf_ptr.as_ptr(),
alloc::Layout::from_size_align_unchecked(
self.val_size * self.capacity,
self.align,
),
)
}
}
}
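// Minimal growable vector of `TimSortRun`s backed by the global allocator;
// it only supports the operations the driver needs: push, remove and indexing.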
struct RunVec {
buf_ptr: ptr::NonNull<TimSortRun>,
capacity: usize,
len: usize,
}
impl RunVec {
fn new() -> Self {
const START_RUN_CAPACITY: usize = 16;
Self {
buf_ptr: ptr::NonNull::new(run_alloc(START_RUN_CAPACITY)).unwrap(),
capacity: START_RUN_CAPACITY,
len: 0,
}
}
fn push(&mut self, val: TimSortRun) {
if self.len == self.capacity {
let old_capacity = self.capacity;
let old_buf_ptr = self.buf_ptr.as_ptr();
self.capacity *= 2;
self.buf_ptr = ptr::NonNull::new(run_alloc(self.capacity)).unwrap();
unsafe {
ptr::copy_nonoverlapping(old_buf_ptr, self.buf_ptr.as_ptr(), old_capacity);
}
run_dealloc(old_buf_ptr, old_capacity);
}
unsafe {
self.buf_ptr.as_ptr().add(self.len).write(val);
}
self.len += 1;
}
fn remove(&mut self, index: usize) {
if index >= self.len {
panic!("Index out of bounds");
}
unsafe {
let ptr = self.buf_ptr.as_ptr().add(index);
ptr::copy(ptr.add(1), ptr, self.len - index - 1);
}
self.len -= 1;
}
fn as_slice(&self) -> &[TimSortRun] {
unsafe { &*ptr::slice_from_raw_parts(self.buf_ptr.as_ptr(), self.len) }
}
fn len(&self) -> usize {
self.len
}
}
impl core::ops::Index<usize> for RunVec {
type Output = TimSortRun;
fn index(&self, index: usize) -> &Self::Output {
if index < self.len {
unsafe {
return &*(self.buf_ptr.as_ptr().add(index));
}
}
panic!("Index out of bounds");
}
}
impl core::ops::IndexMut<usize> for RunVec {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
if index < self.len {
unsafe {
return &mut *(self.buf_ptr.as_ptr().add(index));
}
}
panic!("Index out of bounds");
}
}
impl Drop for RunVec {
fn drop(&mut self) {
run_dealloc(self.buf_ptr.as_ptr(), self.capacity);
}
}
}
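/// A run of already-sorted elements: `len` elements beginning at element index
/// `start` (indices are in elements, not bytes).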
#[derive(Clone, Copy, Debug)]
pub struct TimSortRun {
len: usize,
start: usize,
}
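/// Ensures the run `v[start..end]` (element indices) is long enough to be
/// worth merging: if it is shorter than `MIN_INSERTION_RUN` and more elements
/// remain, the run is extended and the new tail is insertion-sorted onto it.
/// Returns the (possibly extended) end index.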
fn provide_sorted_batch(
v: &mut [u8],
val_size: usize,
start: usize,
mut end: usize,
is_less: &dyn Fn(*const u8, *const u8) -> bool,
scratch: *mut u8,
) -> usize {
debug_assert_eq!(v.len() % val_size, 0);
let len = v.len() / val_size;
assert!(end >= start && end <= len);
const MIN_INSERTION_RUN: usize = 10;
let start_end_diff = end - start;
if start_end_diff < MIN_INSERTION_RUN && end < len {
end = cmp::min(start + MIN_INSERTION_RUN, len);
let presorted_start = cmp::max(start_end_diff, 1);
insertion_sort_shift_left(
&mut v[start * val_size..end * val_size],
val_size,
presorted_start,
is_less,
scratch,
);
}
end
}
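/// Finds the longest streak at the start of `v`: either weakly ascending or
/// strictly descending (strict, so that reversing it keeps the sort stable).
/// Returns the streak length in elements and whether it was descending.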
fn find_streak(
v: &[u8],
val_size: usize,
is_less: &dyn Fn(*const u8, *const u8) -> bool,
) -> (usize, bool) {
debug_assert_eq!(v.len() % val_size, 0);
let len = v.len() / val_size;
if len < 2 {
return (len, false);
}
let v = v.as_ptr();
let mut end = 2;
unsafe {
let assume_reverse = is_less(v.add(val_size), v);
if assume_reverse {
while end < len && is_less(v.add(end * val_size), v.add((end - 1) * val_size)) {
end += 1;
}
(end, true)
} else {
while end < len && !is_less(v.add(end * val_size), v.add((end - 1) * val_size)) {
end += 1;
}
(end, false)
}
}
}
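// Allocates uninitialized storage for `len` `TimSortRun`s via the global
// allocator; it is released again with `run_dealloc`.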
fn run_alloc(len: usize) -> *mut TimSortRun {
unsafe {
alloc::alloc(alloc::Layout::array::<TimSortRun>(len).unwrap_unchecked()) as *mut TimSortRun
}
}
fn run_dealloc(buf_ptr: *mut TimSortRun, len: usize) {
unsafe {
alloc::dealloc(
buf_ptr as *mut u8,
alloc::Layout::array::<TimSortRun>(len).unwrap_unchecked(),
);
}
}
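/// Sorts `slice` in ascending order according to `Ord`, preserving the
/// relative order of equal elements.
///
/// Usage (illustrative, mirrors the tests below):
///
/// ```ignore
/// let mut v = vec![5, 4, 1, 3, 2];
/// stable_sort(&mut v);
/// assert_eq!(v, [1, 2, 3, 4, 5]);
/// ```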
pub fn stable_sort<T: Ord>(slice: &mut [T]) {
stable_sort_by(slice, |x, y| x.cmp(y))
}
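/// Sorts `slice` with a caller-supplied comparison function, preserving the
/// relative order of elements that compare `Equal`.
///
/// The slice is reinterpreted as raw bytes and handed to the type-erased
/// [`merge_sort`] together with `T`'s size and alignment; the comparison is
/// wrapped in a pointer-based closure that casts back to `&T`. For example
/// (illustrative), `stable_sort_by(&mut v, |a, b| b.cmp(a))` sorts in
/// descending order.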
pub fn stable_sort_by<T, F>(slice: &mut [T], cmp: F)
where
F: Fn(&T, &T) -> Ordering,
{
let byte_slice = unsafe {
std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut u8, std::mem::size_of_val(slice))
};
let is_less = |x: *const u8, y: *const u8| {
// SAFETY: `merge_sort` only calls the comparison with pointers to properly
// aligned, initialized `T`s: elements of `byte_slice` or the scratch copy.
let x = unsafe { &*(x as *const T) };
let y = unsafe { &*(y as *const T) };
cmp(x, y) == Ordering::Less
};
// One element of scratch space; the insertion-sort path uses it to hold the
// element that is currently being shifted into place.
let mut scratch = MaybeUninit::<T>::uninit();
merge_sort(
byte_slice,
size_of::<T>(),
align_of::<T>(),
&is_less,
scratch.as_mut_ptr() as *mut u8,
);
}
#[cfg(test)]
mod test {
use super::stable_sort;
use proptest::{collection::vec, prelude::*};
#[test]
fn test_stable_sort() {
let long_sorted: Vec<i32> = (0..1000).collect();
let long_reversed: Vec<i32> = (0..1000).rev().collect();
let corpus: Vec<Vec<i32>> = vec![
vec![],
vec![10],
vec![5, 4],
vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
vec![16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 3, 4, 1, 2],
vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1560281088, 234560113, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
long_sorted,
long_reversed,
];
for mut vec in corpus.into_iter() {
let mut expected = vec.clone();
expected.sort();
stable_sort(&mut vec);
assert_eq!(vec, expected);
}
}
prop_compose! {
fn short_vec()(batch in vec(any::<u32>(), 0..30)) -> Vec<u32> {
batch
}
}
prop_compose! {
fn long_vec()(batch in vec(any::<u32>(), 0..300)) -> Vec<u32> {
batch
}
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct TestStruct {
f1: u16,
f2: String,
f3: i64,
f4: i8,
f5: String,
f6: u8,
f7: Vec<u8>,
}
fn test_struct() -> impl Strategy<Value = TestStruct> {
(
any::<u16>(),
any::<String>(),
any::<i64>(),
any::<i8>(),
any::<String>(),
any::<u8>(),
vec(any::<u8>(), 0..20),
)
.prop_map(|(f1, f2, f3, f4, f5, f6, f7)| TestStruct {
f1,
f2,
f3,
f4,
f5,
f6,
f7,
})
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10000))]
#[test]
fn stable_sort_small_proptest(mut v in short_vec()) {
let mut expected = v.clone();
expected.sort();
stable_sort(&mut v);
assert_eq!(v, expected);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(1000))]
#[test]
fn stable_sort_small_structs_proptest(mut v in vec(test_struct(), 0..25)) {
let mut expected = v.clone();
expected.sort();
stable_sort(&mut v);
assert_eq!(v, expected);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(1000))]
#[test]
fn stable_sort_large_proptest(mut v in long_vec()) {
let mut expected = v.clone();
expected.sort();
stable_sort(&mut v);
assert_eq!(v, expected);
}
}
}