use std::ptr;
use std::{
cmp,
cmp::Ordering,
fmt::Debug,
mem::{size_of, MaybeUninit},
ops::Range,
};
/// Swaps `count` bytes at `x` with `count` bytes at `y`; does nothing when both pointers are
/// the same address.
///
/// # Safety
///
/// Both pointers must be valid for reads and writes of `count` bytes, and the two ranges must
/// either start at the same address or not overlap at all.
unsafe fn swap_if_different(x: *mut u8, y: *mut u8, count: usize) {
    if x == y {
        return;
    }
    ptr::swap_nonoverlapping(x, y, count);
}
/// Number of whole `val_size`-byte elements stored in `v`. Trailing bytes that do not form a
/// complete element are not counted.
#[inline]
fn vec_len(v: &[u8], val_size: usize) -> usize {
    let byte_count = v.len();
    byte_count / val_size
}
/// Returns a raw pointer to the first byte of element `index` of `v`.
///
/// # Safety
///
/// `index * val_size` must not exceed `v.len()`, so that the offset stays inside (or one past
/// the end of) the buffer.
#[inline]
unsafe fn vec_index_unchecked(v: &mut [u8], val_size: usize, index: usize) -> *mut u8 {
    let byte_offset = index * val_size;
    v.as_mut_ptr().add(byte_offset)
}
/// Returns the bytes of `v` starting at element `start` (each element is `val_size` bytes).
///
/// Panics if `start * val_size` is past the end of `v`.
#[inline]
fn vec_suffix(v: &mut [u8], val_size: usize, start: usize) -> &mut [u8] {
    let first_byte = start * val_size;
    &mut v[first_byte..]
}
/// Returns the bytes of elements `start..end` of `v` (each element is `val_size` bytes).
///
/// Panics if the computed byte range is out of bounds or reversed.
#[inline]
fn vec_slice(v: &mut [u8], val_size: usize, start: usize, end: usize) -> &mut [u8] {
    let (first_byte, past_last_byte) = (start * val_size, end * val_size);
    &mut v[first_byte..past_last_byte]
}
/// Splits `v` into the bytes of elements `..start` and the bytes of elements `start..`.
///
/// Panics if `start * val_size` is past the end of `v`.
#[inline]
fn vec_split_at(v: &mut [u8], val_size: usize, start: usize) -> (&mut [u8], &mut [u8]) {
    let boundary = start * val_size;
    v.split_at_mut(boundary)
}
/// Swaps elements `index1` and `index2` of `v`, where each element is `val_size` bytes wide.
///
/// Swapping an element with itself is a no-op.
///
/// # Panics
///
/// Panics if either index does not address a whole element inside `v`. This is a safe
/// function, so it must never hand out-of-bounds pointers to the raw swap below; this mirrors
/// the bounds check performed by `slice::swap`.
#[inline]
fn vec_swap(v: &mut [u8], val_size: usize, index1: usize, index2: usize) {
    let byte_len = v.len();
    // `(index + 1) * val_size <= byte_len`, computed without overflow.
    let in_bounds = |index: usize| {
        index
            .checked_add(1)
            .and_then(|count| count.checked_mul(val_size))
            .map_or(false, |end| end <= byte_len)
    };
    assert!(
        in_bounds(index1) && in_bounds(index2),
        "vec_swap: element index out of bounds"
    );
    let base = v.as_mut_ptr();
    // SAFETY: both element ranges lie inside `v` (checked above). `swap_if_different` skips
    // the identical-slot case, and two distinct element slots never overlap.
    unsafe {
        swap_if_different(
            base.add(index1 * val_size),
            base.add(index2 * val_size),
            val_size,
        )
    };
}
/// Reverses the order of the `val_size`-byte elements in `v`, keeping the bytes inside each
/// element in their original order.
///
/// Element `k` from the front is exchanged with element `k` from the back (measured from the
/// end of the buffer) for every `k` below half the element count.
fn vec_reverse(v: &mut [u8], val_size: usize) {
    let count = v.len() / val_size;
    let Range {
        start: front,
        end: back,
    } = v.as_mut_ptr_range();
    for k in 0..count / 2 {
        // SAFETY: `k < count / 2`, so the two `val_size`-byte ranges are inside `v` and
        // disjoint.
        unsafe {
            let lo = front.add(k * val_size);
            let hi = back.sub((k + 1) * val_size);
            ptr::swap_nonoverlapping(lo, hi, val_size);
        }
    }
}
/// Drop guard used by insertion sort: when dropped, it copies `val_size` bytes from `src` (the
/// element currently held aside) into `dest` (the vacant slot in the buffer).
///
/// Because the copy happens in `Drop`, the held element is written back even if a comparison
/// panics while it is outside the buffer.
struct InsertionHole {
    // Location of the held-aside element (a scratch buffer at the call sites in this file).
    src: *const u8,
    // The currently vacant slot in the buffer.
    dest: *mut u8,
    // Width of one element in bytes.
    val_size: usize,
}

impl Drop for InsertionHole {
    fn drop(&mut self) {
        // SAFETY: whoever builds the guard must ensure `src` and `dest` are valid for
        // `val_size` bytes and do not overlap.
        unsafe { self.src.copy_to_nonoverlapping(self.dest, self.val_size) };
    }
}
/// Inserts the last element of `v` into the sorted prefix before it, so that all of `v` becomes
/// sorted.
///
/// `v` is a buffer of `val_size`-byte elements. If the last element is smaller than its left
/// neighbour, it is copied into `scratch`. An `InsertionHole` guard then copies it back into the
/// current hole when the guard is dropped: at the end of the `if` block, or during unwinding if
/// `is_less` panics.
///
/// # Safety
///
/// `v` must contain at least two elements (checked only by `debug_assert!`). `scratch` must be
/// valid for reads and writes of `val_size` bytes and must not overlap `v`.
unsafe fn insert_tail(
    v: &mut [u8],
    val_size: usize,
    is_less: &dyn Fn(*const u8, *const u8) -> bool,
    scratch: *mut u8,
) {
    debug_assert!(v.len() / val_size >= 2);
    let arr_ptr = v.as_mut_ptr();
    // Index of the element being inserted (the last one).
    let i = v.len() / val_size - 1;
    unsafe {
        let i_ptr = arr_ptr.add(i * val_size);
        // Nothing to do unless the last element is smaller than its left neighbour.
        if is_less(i_ptr, i_ptr.sub(val_size)) {
            // Park the element in `scratch`; from here on the guard owns writing it back.
            ptr::copy_nonoverlapping(i_ptr, scratch, val_size);
            let mut hole = InsertionHole {
                src: scratch,
                dest: i_ptr.sub(val_size),
                val_size,
            };
            // Shift `v[i - 1]` one slot right; the hole is now at `i - 1`.
            ptr::copy_nonoverlapping(hole.dest, i_ptr, val_size);
            // Walk left, shifting every element greater than the parked one right by a slot.
            for j in (0..(i - 1)).rev() {
                let j_ptr = arr_ptr.add(j * val_size);
                if !is_less(scratch, j_ptr) {
                    break;
                }
                ptr::copy_nonoverlapping(j_ptr, hole.dest, val_size);
                hole.dest = j_ptr;
            }
            // `hole` is dropped here and writes the parked element into `hole.dest`.
        }
    }
}
unsafe fn insert_head(
v: &mut [u8],
val_size: usize,
is_less: &dyn Fn(*const u8, *const u8) -> bool,
scratch: *mut u8,
) {
debug_assert!(vec_len(v, val_size) >= 2);
unsafe {
if is_less(
vec_index_unchecked(v, val_size, 1),
vec_index_unchecked(v, val_size, 0),
) {
let arr_ptr = v.as_mut_ptr();
ptr::copy_nonoverlapping(arr_ptr, scratch, val_size);
let mut hole = InsertionHole {
src: scratch,
dest: arr_ptr.add(val_size),
val_size,
};
ptr::copy_nonoverlapping(arr_ptr.add(val_size), arr_ptr.add(0), val_size);
for i in 2..vec_len(v, val_size) {
if !is_less(vec_index_unchecked(v, val_size, i), scratch) {
break;
}
ptr::copy_nonoverlapping(
arr_ptr.add(val_size),
arr_ptr.add(val_size * (i - 1)),
val_size,
);
hole.dest = arr_ptr.add(val_size * i);
}
}
}
}
#[inline(never)]
pub(super) fn insertion_sort_shift_left(
v: &mut [u8],
val_size: usize,
offset: usize,
is_less: &dyn Fn(*const u8, *const u8) -> bool,
scratch: *mut u8,
) {
let len = v.len() / val_size;
assert!(offset != 0 && offset <= len);
for i in offset..len {
unsafe {
insert_tail(
&mut v[..=(i + 1) * val_size - 1],
val_size,
is_less,
scratch,
);
}
}
}
#[inline(never)]
fn insertion_sort_shift_right(
v: &mut [u8],
val_size: usize,
offset: usize,
is_less: &dyn Fn(*const u8, *const u8) -> bool,
scratch: *mut u8,
) {
let len = vec_len(v, val_size);
assert!(offset != 0 && offset <= len && len >= 2);
for i in (0..offset).rev() {
unsafe {
insert_head(vec_slice(v, val_size, i, len), val_size, is_less, scratch);
}
}
}
#[cold]
fn partial_insertion_sort(
v: &mut [u8],
val_size: usize,
is_less: &dyn Fn(*const u8, *const u8) -> bool,
scratch: *mut u8,
) -> bool {
const MAX_STEPS: usize = 5;
const SHORTEST_SHIFTING: usize = 50;
let len = vec_len(v, val_size);
let mut i = 1;
for _ in 0..MAX_STEPS {
unsafe {
while i < len
&& !is_less(
vec_index_unchecked(v, val_size, i),
vec_index_unchecked(v, val_size, i - 1),
)
{
i += 1;
}
}
if i == len {
return true;
}
if len < SHORTEST_SHIFTING {
return false;
}
vec_swap(v, val_size, i - 1, i);
if i >= 2 {
insertion_sort_shift_left(
vec_slice(v, val_size, 0, i),
val_size,
i - 1,
is_less,
scratch,
);
insertion_sort_shift_right(vec_slice(v, val_size, 0, i), val_size, 1, is_less, scratch);
}
}
false
}
/// Sorts `v`, a buffer of `val_size`-byte elements, in place using heapsort.
///
/// `is_less(a, b)` receives pointers to the first bytes of two elements and must report whether
/// `a` orders strictly before `b`. Trailing bytes that do not form a whole element are left
/// untouched.
#[cold]
pub fn heapsort(v: &mut [u8], val_size: usize, is_less: &dyn Fn(*const u8, *const u8) -> bool) {
    // Zero-sized elements have no meaningful order. Returning early also avoids the division by
    // zero in `vec_len`, and matches the guard in `quicksort`.
    if val_size == 0 {
        return;
    }
    // Restores the max-heap invariant (`parent >= child`) for the subtree rooted at `node`.
    let sift_down = |v: &mut [u8], mut node| {
        let len = vec_len(v, val_size);
        loop {
            let mut child = 2 * node + 1;
            if child >= len {
                break;
            }
            // Pick the greater of the two children.
            if child + 1 < len {
                // SAFETY: `child + 1 < len`.
                unsafe {
                    child += is_less(
                        vec_index_unchecked(v, val_size, child),
                        vec_index_unchecked(v, val_size, child + 1),
                    ) as usize
                };
            }
            // Stop once the invariant holds at `node`.
            // SAFETY: `node < child < len`.
            if !unsafe {
                is_less(
                    vec_index_unchecked(v, val_size, node),
                    vec_index_unchecked(v, val_size, child),
                )
            } {
                break;
            }
            vec_swap(v, val_size, node, child);
            node = child;
        }
    };
    // Build the heap in linear time.
    for i in (0..vec_len(v, val_size) / 2).rev() {
        sift_down(v, i);
    }
    // Repeatedly move the maximum to the end of the shrinking heap.
    for i in (1..vec_len(v, val_size)).rev() {
        vec_swap(v, val_size, 0, i);
        sift_down(vec_slice(v, val_size, 0, i), 0);
    }
}
/// Partitions `v` into elements smaller than `*pivot` followed by elements greater than or
/// equal to `*pivot`, and returns the number of elements smaller than `*pivot`.
///
/// `v` is a buffer of `val_size`-byte elements, and `pivot` must point to a `val_size`-byte
/// element outside `v`. The range is processed in blocks of up to `BLOCK` elements from each
/// end. For each block, the offsets of misplaced elements are recorded in a `u8` buffer. The
/// misplaced left/right pairs are then exchanged by a cyclic permutation that uses `scratch`
/// (valid for `val_size` bytes, not overlapping `v`) as a one-element temporary. No comparisons
/// are made while an element is parked in `scratch`.
fn partition_in_blocks(
    v: &mut [u8],
    val_size: usize,
    pivot: *const u8,
    is_less: &dyn Fn(*const u8, *const u8) -> bool,
    scratch: *mut u8,
) -> usize {
    // Elements per block. Offsets within a block must fit in a `u8`.
    const BLOCK: usize = 128;

    // Left side: block start `l`, block length, and the recorded offsets `start_l..end_l` of
    // elements in that block that belong on the right.
    let mut l = v.as_mut_ptr();
    let mut block_l = BLOCK;
    let mut start_l = ptr::null_mut();
    let mut end_l = ptr::null_mut();
    let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK];

    // Right side: `r` is one past the block's end; offsets count backwards from `r`.
    let mut r = unsafe { l.add(v.len()) };
    let mut block_r = BLOCK;
    let mut start_r = ptr::null_mut();
    let mut end_r = ptr::null_mut();
    let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK];

    // Number of `type_size`-byte units between `l` and `r`.
    fn width(l: *mut u8, r: *mut u8, type_size: usize) -> usize {
        assert!(type_size > 0);
        (r as usize - l as usize) / type_size
    }

    loop {
        // Once at most two blocks' worth of elements remain, size the final blocks so that,
        // together with any unfinished block, they cover exactly the remaining range.
        let is_done = width(l, r, val_size) <= 2 * BLOCK;
        if is_done {
            let mut rem = width(l, r, val_size);
            if start_l < end_l || start_r < end_r {
                rem -= BLOCK;
            }
            if start_l < end_l {
                block_r = rem;
            } else if start_r < end_r {
                block_l = rem;
            } else {
                block_l = rem / 2;
                block_r = rem - block_l;
            }
            debug_assert!(block_l <= BLOCK && block_r <= BLOCK);
            debug_assert!(width(l, r, val_size) == block_l + block_r);
        }

        // Scan the left block; record offsets of elements that are not less than the pivot.
        if start_l == end_l {
            start_l = offsets_l.as_mut_ptr() as *mut u8;
            end_l = start_l;
            let mut elem = l;
            for i in 0..block_l {
                unsafe {
                    // Always write the offset and advance only on a match, avoiding a branch.
                    // Pass `elem` itself (not `&*elem`): a `&u8` would only grant access to one
                    // byte, but the comparator reads a whole `val_size`-byte element.
                    *end_l = i as u8;
                    end_l = end_l.add(!is_less(elem, pivot) as usize);
                    elem = elem.add(val_size);
                }
            }
        }

        // Scan the right block backwards from `r`; record offsets of elements less than the
        // pivot.
        if start_r == end_r {
            start_r = offsets_r.as_mut_ptr() as *mut u8;
            end_r = start_r;
            let mut elem = r;
            for i in 0..block_r {
                unsafe {
                    elem = elem.sub(val_size);
                    *end_r = i as u8;
                    end_r = end_r.add(is_less(elem, pivot) as usize);
                }
            }
        }

        // Exchange as many misplaced pairs as both offset buffers allow.
        let count = cmp::min(width(start_l, end_l, 1), width(start_r, end_r, 1));
        if count > 0 {
            macro_rules! left {
                () => {
                    l.add(usize::from(*start_l) * val_size)
                };
            }
            macro_rules! right {
                () => {
                    r.sub((usize::from(*start_r) + 1) * val_size)
                };
            }
            // Cyclic permutation: one temporary in `scratch` instead of `count` full swaps.
            unsafe {
                ptr::copy_nonoverlapping(left!(), scratch, val_size);
                ptr::copy_nonoverlapping(right!(), left!(), val_size);
                for _ in 1..count {
                    start_l = start_l.add(1);
                    ptr::copy_nonoverlapping(left!(), right!(), val_size);
                    start_r = start_r.add(1);
                    ptr::copy_nonoverlapping(right!(), left!(), val_size);
                }
                ptr::copy_nonoverlapping(scratch, right!(), val_size);
                start_l = start_l.add(1);
                start_r = start_r.add(1);
            }
        }

        // Advance past any block whose misplaced elements have all been handled.
        if start_l == end_l {
            l = unsafe { l.add(block_l * val_size) };
        }
        if start_r == end_r {
            r = unsafe { r.sub(block_r * val_size) };
        }
        if is_done {
            break;
        }
    }

    if start_l < end_l {
        // Leftover misplaced elements from the left block: move them, last offset first, to
        // the end of the unpartitioned range.
        debug_assert_eq!(width(l, r, val_size), block_l);
        while start_l < end_l {
            unsafe {
                // `end_l` walks the `u8` offset buffer, so it steps by one offset, not by an
                // element width.
                end_l = end_l.sub(1);
                swap_if_different(
                    l.add(usize::from(*end_l) * val_size),
                    r.sub(val_size),
                    val_size,
                );
                r = r.sub(val_size);
            }
        }
        width(v.as_mut_ptr(), r, val_size)
    } else if start_r < end_r {
        // Leftover misplaced elements from the right block: move them to the front of the
        // unpartitioned range.
        debug_assert_eq!(width(l, r, val_size), block_r);
        while start_r < end_r {
            unsafe {
                end_r = end_r.sub(1);
                swap_if_different(l, r.sub(val_size * (usize::from(*end_r) + 1)), val_size);
                l = l.add(val_size);
            }
        }
        width(v.as_mut_ptr(), l, val_size)
    } else {
        width(v.as_mut_ptr(), l, val_size)
    }
}
/// Partitions `v` around the element at index `pivot`.
///
/// Elements less than the pivot end up before it and all others after it. Returns the pivot's
/// final index, and whether the slice was already partitioned (no misplaced pair was found by
/// the initial scans).
pub(super) fn partition(
    v: &mut [u8],
    val_size: usize,
    pivot: usize,
    is_less: &dyn Fn(*const u8, *const u8) -> bool,
    scratch: *mut u8,
) -> (usize, bool) {
    // Park the pivot in slot 0; the elements after it are partitioned around it in place.
    vec_swap(v, val_size, 0, pivot);
    let (mid, was_partitioned) = {
        let (head, rest) = vec_split_at(v, val_size, 1);
        let pivot_ptr: *const u8 = head.as_mut_ptr();
        let mut lo = 0;
        let mut hi = vec_len(rest, val_size);
        // SAFETY: both scans only index `rest` at positions `< hi <= len`.
        unsafe {
            // Skip the leading run of elements already less than the pivot.
            while lo < hi && is_less(vec_index_unchecked(rest, val_size, lo), pivot_ptr) {
                lo += 1;
            }
            // Skip the trailing run of elements already not less than the pivot.
            while lo < hi && !is_less(vec_index_unchecked(rest, val_size, hi - 1), pivot_ptr) {
                hi -= 1;
            }
        }
        let already_partitioned = lo >= hi;
        let split = lo
            + partition_in_blocks(
                vec_slice(rest, val_size, lo, hi),
                val_size,
                pivot_ptr,
                is_less,
                scratch,
            );
        (split, already_partitioned)
    };
    // Move the pivot between the two partitions.
    vec_swap(v, val_size, 0, mid);
    (mid, was_partitioned)
}
/// Moves the element at index `pivot` to the front and partitions the remaining elements into
/// those not greater than the pivot followed by those greater than it.
///
/// Returns the number of elements in the "not greater" group, counting the pivot itself. Used
/// when every element is known to be at least the pivot, so that group is exactly the elements
/// equal to it. Returns 0 when `v` holds only the pivot.
pub(super) fn partition_equal(
    v: &mut [u8],
    val_size: usize,
    pivot: usize,
    is_less: &dyn Fn(*const u8, *const u8) -> bool,
) -> usize {
    vec_swap(v, val_size, 0, pivot);
    let (head, rest) = vec_split_at(v, val_size, 1);
    let pivot_ptr: *const u8 = head.as_mut_ptr();
    let n = vec_len(rest, val_size);
    if n == 0 {
        return 0;
    }
    let mut lo = 0;
    let mut hi = n;
    loop {
        // SAFETY: indices are always `< n`; `hi` is decremented before every access.
        unsafe {
            // Advance `lo` past elements that are not greater than the pivot.
            while lo < hi && !is_less(pivot_ptr, vec_index_unchecked(rest, val_size, lo)) {
                lo += 1;
            }
            // Retreat `hi` to the last element that is not greater than the pivot.
            hi -= 1;
            while lo < hi && is_less(pivot_ptr, vec_index_unchecked(rest, val_size, hi)) {
                hi -= 1;
            }
        }
        if lo >= hi {
            break;
        }
        vec_swap(rest, val_size, lo, hi);
        lo += 1;
    }
    // `lo` elements equal the pivot; add one for the pivot itself.
    lo + 1
}
/// Scatters a few elements of `v` to break up patterns that can cause poor pivot choices.
///
/// For slices of at least 8 elements, the three elements just around the middle are swapped
/// with pseudo-randomly chosen positions. The random sequence is deterministic: it is an
/// xorshift generator seeded with the element count.
#[cold]
pub(super) fn break_patterns(v: &mut [u8], val_size: usize) {
    let len = vec_len(v, val_size);
    if len < 8 {
        return;
    }
    // Xorshift PRNG (Marsaglia), using the 32- or 64-bit variant to match `usize`.
    let mut state = len;
    let mut next_random = || {
        state = if usize::BITS <= 32 {
            let mut x = state as u32;
            x ^= x << 13;
            x ^= x >> 17;
            x ^= x << 5;
            x as usize
        } else {
            let mut x = state as u64;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            x as usize
        };
        state
    };
    // Reduce modulo the next power of two, then fold back once so the result is `< len`.
    let mask = len.next_power_of_two() - 1;
    let base = len / 4 * 2 - 1;
    for offset in 0..3 {
        let mut target = next_random() & mask;
        if target >= len {
            target -= len;
        }
        vec_swap(v, val_size, base + offset, target);
    }
}
/// Chooses a pivot index for partitioning `v` and reports whether `v` looks already sorted.
///
/// For at least 8 elements, the candidate indices near `len/4`, `len/2` and `3*len/4` are
/// ordered by the values they point at, and the middle one is taken. From 50 elements on, each
/// candidate is first replaced by the median of itself and its two neighbours. Every index
/// exchange is counted:
/// - no exchanges: the slice is reported as likely sorted;
/// - `MAX_SWAPS` exchanges: `v` is reversed, and the mirrored index is returned together
///   with `true`.
pub(super) fn choose_pivot(
    v: &mut [u8],
    val_size: usize,
    is_less: &dyn Fn(*const u8, *const u8) -> bool,
) -> (usize, bool) {
    const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
    const MAX_SWAPS: usize = 4 * 3;
    let len = vec_len(v, val_size);
    let quarter = len / 4;
    let mut a = quarter;
    let mut b = quarter * 2;
    let mut c = quarter * 3;
    let mut swaps = 0;
    if len >= 8 {
        // Orders the indices so that `v[*x] <= v[*y]`, counting each exchange.
        // SAFETY: with `len >= 8`, every candidate index and its neighbours are `< len`.
        let mut sort2 = |x: &mut usize, y: &mut usize| unsafe {
            let out_of_order = is_less(
                vec_index_unchecked(v, val_size, *y),
                vec_index_unchecked(v, val_size, *x),
            );
            if out_of_order {
                ptr::swap(x, y);
                swaps += 1;
            }
        };
        // Orders the indices so that `v[*x] <= v[*y] <= v[*z]`.
        let mut sort3 = |x: &mut usize, y: &mut usize, z: &mut usize| {
            sort2(x, y);
            sort2(y, z);
            sort2(x, y);
        };
        if len >= SHORTEST_MEDIAN_OF_MEDIANS {
            // Replace each candidate by the median of its three-element neighbourhood.
            for candidate in [&mut a, &mut b, &mut c] {
                let center = *candidate;
                sort3(&mut (center - 1), candidate, &mut (center + 1));
            }
        }
        sort3(&mut a, &mut b, &mut c);
    }
    if swaps >= MAX_SWAPS {
        // Every comparison was out of order: the slice is probably descending.
        vec_reverse(v, val_size);
        return (len - 1 - b, true);
    }
    (b, swaps == 0)
}
/// Sorts `v` (elements of `val_size` bytes) in place by repeated partitioning.
///
/// After each partition, the smaller side is sorted by a recursive call and the larger side by
/// the next iteration of the loop.
///
/// * `pred`: when `Some`, points at the pivot of an enclosing partition step. `v` lies on that
///   pivot's "not less than" side, so no element of `v` is less than `*pred`.
/// * `limit`: how many more imbalanced partitions are tolerated. Each one decrements it, and at
///   zero the remaining slice is heapsorted.
/// * `scratch`: one-element temporary passed through to the insertion-sort and partition
///   helpers.
fn recurse(
    mut v: &mut [u8],
    val_size: usize,
    is_less: &dyn Fn(*const u8, *const u8) -> bool,
    mut pred: Option<*const u8>,
    mut limit: u32,
    scratch: *mut u8,
) {
    // Slices of at most this many elements are sorted with insertion sort.
    const MAX_INSERTION: usize = 20;
    let mut was_balanced = true;
    let mut was_partitioned = true;
    loop {
        let len = vec_len(v, val_size);
        if len <= MAX_INSERTION {
            if len >= 2 {
                insertion_sort_shift_left(v, val_size, 1, is_less, scratch);
            }
            return;
        }
        // Too many imbalanced partitions: fall back to heapsort.
        if limit == 0 {
            heapsort(v, val_size, is_less);
            return;
        }
        // The previous partition was imbalanced: shuffle a few elements to disturb whatever
        // pattern may be producing bad pivots.
        if !was_balanced {
            break_patterns(v, val_size);
            limit -= 1;
        }
        let (pivot, likely_sorted) = choose_pivot(v, val_size, is_less);
        // The slice looks sorted and the last partition went well: try to finish with a few
        // shifts.
        if was_balanced && was_partitioned && likely_sorted {
            if partial_insertion_sort(v, val_size, is_less, scratch) {
                return;
            }
        }
        // If the chosen pivot is not greater than `*pred`, it equals `*pred` (no element of `v`
        // is smaller). Group the elements equal to it at the front and skip over them.
        if let Some(p) = pred {
            if !unsafe { is_less(p, vec_index_unchecked(v, val_size, pivot)) } {
                let mid = partition_equal(v, val_size, pivot, is_less);
                v = vec_suffix(v, val_size, mid);
                continue;
            }
        }
        // Partition, recording whether the smaller side holds at least `len / 8` elements and
        // whether the slice was already partitioned.
        let (mid, was_p) = partition(v, val_size, pivot, is_less, scratch);
        was_balanced = cmp::min(mid, len - mid) >= len / 8;
        was_partitioned = was_p;
        // Split off the pivot, which is now at its final position `mid`.
        let (left, right) = vec_split_at(v, val_size, mid);
        let (pivot, right) = vec_split_at(right, val_size, 1);
        let pivot = pivot.as_ptr();
        // Recurse into the shorter side; loop on the longer one.
        if left.len() < right.len() {
            recurse(left, val_size, is_less, pred, limit, scratch);
            v = right;
            pred = Some(pivot);
        } else {
            recurse(right, val_size, is_less, Some(pivot), limit, scratch);
            v = left;
        }
    }
}
/// Sorts `v`, a buffer of `val_size`-byte elements, in place with an unstable quicksort.
///
/// The recursion falls back to heapsort once `limit` imbalanced partitions have occurred.
/// `is_less(a, b)` receives pointers to the first bytes of two elements and must report whether
/// `a` orders strictly before `b`. `scratch` must be valid for reads and writes of `val_size`
/// bytes and must not overlap `v`; it serves as a one-element temporary.
///
/// # Panics
///
/// Panics if `val_size` is non-zero and `v.len()` is not a multiple of `val_size`.
pub fn quicksort(
    v: &mut [u8],
    val_size: usize,
    is_less: &dyn Fn(*const u8, *const u8) -> bool,
    scratch: *mut u8,
) {
    // Sorting has no meaningful behavior on zero-sized elements.
    if val_size == 0 {
        return;
    }
    // Element addresses are computed both from the start and from the end of the buffer
    // (`partition_in_blocks` walks backwards from `v.len()`). A trailing partial element would
    // misalign the two views and corrupt the data, so reject it up front.
    assert!(
        v.len() % val_size == 0,
        "quicksort: byte length {} is not a multiple of val_size {}",
        v.len(),
        val_size
    );
    // Limit the number of imbalanced partitions to `floor(log2(len)) + 1`.
    let limit = usize::BITS - vec_len(v, val_size).leading_zeros();
    recurse(v, val_size, is_less, None, limit, scratch);
}
/// Sorts `slice` in ascending order according to `Ord`, without preserving the relative order
/// of equal elements.
pub fn unstable_sort<T: Ord + Debug>(slice: &mut [T]) {
    unstable_sort_by(slice, T::cmp);
}
/// Sorts `slice` in place with the comparator `cmp`, without preserving the relative order of
/// equal elements.
///
/// The slice is viewed as raw bytes and handed to `quicksort`, together with a type-erased
/// `is_less` that reinterprets element pointers as `&T`, and an uninitialized slot for one `T`
/// as scratch space.
pub fn unstable_sort_by<T, F>(slice: &mut [T], cmp: F)
where
    F: Fn(&T, &T) -> Ordering,
{
    let elem_size = size_of::<T>();
    let byte_len = std::mem::size_of_val(slice);
    // SAFETY: the byte view covers exactly the memory of `slice`, which is not used again
    // directly while the view is alive.
    let bytes =
        unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut u8, byte_len) };
    let less_than = |a: *const u8, b: *const u8| -> bool {
        // SAFETY: `quicksort` only compares pointers to element starts within `bytes` or to
        // `scratch` while it holds a copied element.
        let (a, b) = unsafe { (&*(a as *const T), &*(b as *const T)) };
        cmp(a, b) == Ordering::Less
    };
    // Named local so the scratch slot visibly outlives the whole `quicksort` call.
    let mut scratch = MaybeUninit::<T>::uninit();
    quicksort(bytes, elem_size, &less_than, scratch.as_mut_ptr() as *mut u8);
}
// Tests compare `unstable_sort` against the standard library's `sort`, on fixed inputs and on
// randomly generated vectors produced by `proptest`.
#[cfg(test)]
mod test {
    use super::unstable_sort;
    use proptest::{collection::vec, prelude::*};

    // Hand-picked inputs: empty, one and two elements, sorted, shuffled, and a 21-element
    // vector (one more than `MAX_INSERTION` in `recurse`, so it takes the partitioning path).
    #[test]
    fn test_unstable_sort() {
        let corpus: Vec<Vec<i32>> = vec![
            vec![],
            vec![10],
            vec![5, 4],
            vec![1, 2, 3, 4, 5],
            vec![5, 3, 4, 1, 2],
            vec![
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1560281088, 234560113, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            ],
        ];
        for mut vec in corpus.into_iter() {
            let mut expected = vec.clone();
            expected.sort();
            unstable_sort(&mut vec);
            assert_eq!(vec, expected);
        }
    }

    // Strategy: vectors of 0 to 29 arbitrary `u32` values.
    prop_compose! {
        fn short_vec()(batch in vec(any::<u32>(), 0..30)) -> Vec<u32> {
            batch
        }
    }

    // Strategy: vectors of 0 to 299 arbitrary `u32` values.
    prop_compose! {
        fn long_vec()(batch in vec(any::<u32>(), 0..300)) -> Vec<u32> {
            batch
        }
    }

    // Multi-field element type with heap-owning fields (`String`, `Vec<u8>`). The derived
    // `Ord` compares fields lexicographically in declaration order.
    #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct TestStruct {
        f1: u16,
        f2: String,
        f3: i64,
        f4: i8,
        f5: String,
        f6: u8,
        f7: Vec<u8>,
    }

    // Strategy producing arbitrary `TestStruct` values (with `f7` holding 0 to 19 bytes).
    fn test_struct() -> impl Strategy<Value = TestStruct> {
        (
            any::<u16>(),
            any::<String>(),
            any::<i64>(),
            any::<i8>(),
            any::<String>(),
            any::<u8>(),
            vec(any::<u8>(), 0..20),
        )
            .prop_map(|(f1, f2, f3, f4, f5, f6, f7)| TestStruct {
                f1,
                f2,
                f3,
                f4,
                f5,
                f6,
                f7,
            })
    }

    // 10,000 random short `u32` vectors.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10000))]
        #[test]
        fn unstable_sort_small_proptest(mut v in short_vec()) {
            let mut expected = v.clone();
            expected.sort();
            unstable_sort(&mut v);
            assert_eq!(v, expected);
        }
    }

    // Random vectors of 0 to 24 `TestStruct` values (non-`Copy` elements).
    proptest! {
        #[test]
        fn unstable_sort_small_structs_proptest(mut v in vec(test_struct(), 0..25)) {
            let mut expected = v.clone();
            expected.sort();
            unstable_sort(&mut v);
            assert_eq!(v, expected);
        }
    }

    // Random longer `u32` vectors.
    proptest! {
        #[test]
        fn unstable_sort_large_proptest(mut v in long_vec()) {
            let mut expected = v.clone();
            expected.sort();
            unstable_sort(&mut v);
            assert_eq!(v, expected);
        }
    }
}