#![cfg(feature = "alloc")]
use crate::insertion_sort::insertion_sort_shift_left;
use crate::partition::reverse;
use core::{cmp, mem, ptr};
use ndarray::{ArrayView1, ArrayViewMut1, IndexLonger, s};
/// Merges the two adjacent sorted runs `v[..mid]` and `v[mid..]` into a single
/// sorted run, comparing with `is_less` and using `buf` as scratch space.
///
/// The shorter of the two runs is copied into `buf`; the merge then proceeds
/// front-to-back (left run buffered) or back-to-front (right run buffered), so
/// at most `v.len() / 2` elements ever live outside `v`. The nested
/// `MergeHole` guard copies any still-buffered elements back into `v` on drop,
/// so even if `is_less` panics no element is leaked or duplicated.
///
/// # Safety
///
/// The caller must guarantee that `v[..mid]` and `v[mid..]` are each sorted
/// with respect to `is_less`, that `mid <= v.len()`, and that `buf` is valid
/// for reads and writes of at least `min(mid, v.len() - mid)` elements.
#[warn(unsafe_op_in_unsafe_fn)]
unsafe fn merge<T, F>(v: ArrayViewMut1<'_, T>, mid: usize, buf: *mut T, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();
let mut hole;
// Buffer whichever run is shorter so `buf` never needs more than `len / 2`
// slots.
if mid <= len - mid {
unsafe {
// Left run is shorter: stash it in `buf`, then merge front-to-back.
for i in 0..mid {
ptr::copy_nonoverlapping(&v[i], buf.add(i), 1);
}
// From here on `hole` owns the buffered elements: on drop it restores
// `buf[start..end]` into `v[dest..]` (panic safety).
hole = MergeHole {
buf,
start: 0,
end: mid,
dest: 0,
v,
};
}
// `left` indexes into the buffered left run, `right` into the right run
// still in `v`, `out` is the next write position in `v`.
let left = &mut hole.start;
let mut right = mid; let out = &mut hole.dest;
while *left < hole.end && right < len {
unsafe {
let w = hole.v.view();
// Take the right element only when strictly smaller, so equal
// elements keep their left-run-first order (stability).
let is_l = is_less(w.uget(right), &*hole.buf.add(*left));
let to_copy = if is_l {
w.uget(right)
} else {
&*hole.buf.add(*left)
};
ptr::copy_nonoverlapping(to_copy, hole.v.view_mut().index(*out), 1);
*out += 1;
if is_l {
right += 1;
} else {
*left += 1;
}
}
}
} else {
unsafe {
// Right run is shorter: stash it in `buf`, then merge back-to-front.
for i in 0..len - mid {
ptr::copy_nonoverlapping(&v[mid + i], buf.add(i), 1);
}
hole = MergeHole {
buf,
start: 0,
end: len - mid,
dest: mid,
v,
};
}
// `left` counts unmerged left-run elements still in `v[..mid]`; `right`
// counts buffered right-run elements not yet written back; `out` is the
// next write position, filling `v` from the back. The hole's drop
// invariant `out - *left == *right` is preserved by every iteration.
let left = &mut hole.dest;
let right = &mut hole.end;
let mut out = len;
while 0 < *left && 0 < *right {
unsafe {
let w = hole.v.view();
// Take the left element when the right one is strictly smaller;
// on ties the right element is emitted first, which in a
// back-to-front merge keeps equal elements in original order.
let is_l = is_less(&*hole.buf.add(*right - 1), w.uget(*left - 1));
if is_l {
*left -= 1;
} else {
*right -= 1;
}
let to_copy = if is_l {
w.uget(*left)
} else {
&*hole.buf.add(*right)
};
out -= 1;
ptr::copy_nonoverlapping(to_copy, hole.v.view_mut().index(out), 1);
}
}
}
// `hole` is dropped here (or mid-merge if `is_less` panicked); its Drop impl
// copies the remaining buffered elements back into `v`, finishing the merge.
/// Panic-safety guard: while a merge is in progress, `buf[start..end]` holds
/// elements missing from `v`; `Drop` writes them back at `v[dest..]`.
struct MergeHole<'a, T> {
buf: *mut T,
start: usize,
end: usize,
v: ArrayViewMut1<'a, T>,
dest: usize,
}
impl<T> Drop for MergeHole<'_, T> {
fn drop(&mut self) {
unsafe {
// Restore every element still in the buffer so none are leaked.
let len = self.end - self.start; for i in 0..len {
let src = self.buf.add(self.start + i);
let dst = self.v.view_mut().index(self.dest + i);
ptr::copy_nonoverlapping(src, dst, 1);
}
}
}
}
}
/// TimSort-style stable merge sort of `v`, ordered by `is_less`.
///
/// Scratch memory is obtained through the caller-supplied callbacks:
/// `elem_alloc_fn`/`elem_dealloc_fn` provide the element merge buffer
/// (capacity `v.len() / 2` — the merge only ever buffers the shorter run) and
/// `run_alloc_fn`/`run_dealloc_fn` back the stack of pending runs. This keeps
/// the sort itself allocator-agnostic.
///
/// Slices of at most `MAX_INSERTION` elements are handled with insertion sort
/// alone. Zero-sized element types are not supported (debug-asserted).
pub fn merge_sort<T, CmpF, ElemAllocF, ElemDeallocF, RunAllocF, RunDeallocF>(
mut v: ArrayViewMut1<'_, T>,
is_less: &mut CmpF,
elem_alloc_fn: ElemAllocF,
elem_dealloc_fn: ElemDeallocF,
run_alloc_fn: RunAllocF,
run_dealloc_fn: RunDeallocF,
) where
CmpF: FnMut(&T, &T) -> bool,
ElemAllocF: Fn(usize) -> *mut T,
ElemDeallocF: Fn(*mut T, usize),
RunAllocF: Fn(usize) -> *mut TimSortRun,
RunDeallocF: Fn(*mut TimSortRun, usize),
{
// Short slices are sorted entirely by insertion sort; no buffers needed.
const MAX_INSERTION: usize = 20;
// ZSTs would make the raw-pointer buffer arithmetic below meaningless.
debug_assert!(mem::size_of::<T>() > 0);
let len = v.len();
if len <= MAX_INSERTION {
if len >= 2 {
insertion_sort_shift_left(v, 1, is_less);
}
return;
}
// RAII guard frees the element buffer via `elem_dealloc_fn`, even on panic.
let buf = BufGuard::new(len / 2, elem_alloc_fn, elem_dealloc_fn);
let buf_ptr = buf.buf_ptr.as_ptr();
let mut runs = RunVec::new(run_alloc_fn, run_dealloc_fn);
// Scan `v` left to right, pushing each natural run onto the run stack and
// merging whenever the TimSort stack invariants (see `collapse`) demand it.
let mut end = 0;
let mut start = 0;
while end < len {
let (streak_end, was_reversed) = find_streak(v.slice(s![start..]), is_less);
end += streak_end;
if was_reversed {
// A strictly descending streak becomes ascending when reversed
// (strictness is what keeps the overall sort stable).
reverse(v.slice_mut(s![start..end]));
}
// Pad runs that are too short up to a minimum length via insertion sort.
end = provide_sorted_batch(v.view_mut(), start, end, is_less);
runs.push(TimSortRun {
start,
len: end - start,
});
start = end;
while let Some(r) = collapse(runs.as_slice(), len) {
// Merge the adjacent runs at stack indices `r` (left) and `r + 1`
// (right); runs are stored in slice order, so they are contiguous.
let left = runs[r];
let right = runs[r + 1];
let merge_slice = v.slice_mut(s![left.start..right.start + right.len]);
unsafe {
merge(merge_slice, left.len, buf_ptr, is_less);
}
// Record the merged run at `r + 1`, then drop the stale entry at `r`
// so the merged run slides into index `r`.
runs[r + 1] = TimSortRun {
start: left.start,
len: left.len + right.len,
};
runs.remove(r);
}
}
// On exit exactly one run must remain, covering the whole slice.
debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);
/// Examines the top of the run stack and decides whether (and where) to
/// merge, enforcing the TimSort invariants: for the topmost runs
/// `... C, B, A`, we need `C > B + A` and `B > A`. Returns the index of the
/// left run of the pair to merge, or `None` when the stack is balanced.
/// `stop` is the total length, used to detect when the last run reaches the
/// end of the slice and everything must collapse.
#[inline]
fn collapse(runs: &[TimSortRun], stop: usize) -> Option<usize> {
let n = runs.len();
if n >= 2
&& (runs[n - 1].start + runs[n - 1].len == stop
|| runs[n - 2].len <= runs[n - 1].len
|| (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len)
|| (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len))
{
// Merge the smaller neighbor first to keep merges balanced.
if n >= 3 && runs[n - 3].len < runs[n - 1].len {
Some(n - 3)
} else {
Some(n - 2)
}
} else {
None
}
}
/// RAII owner of the element merge buffer: deallocates through
/// `elem_dealloc_fn` on drop, including during unwinding.
struct BufGuard<T, ElemDeallocF>
where
ElemDeallocF: Fn(*mut T, usize),
{
buf_ptr: ptr::NonNull<T>,
capacity: usize,
elem_dealloc_fn: ElemDeallocF,
}
impl<T, ElemDeallocF> BufGuard<T, ElemDeallocF>
where
ElemDeallocF: Fn(*mut T, usize),
{
/// Allocates a buffer of `len` elements via `elem_alloc_fn`.
/// Panics if the allocator returns a null pointer.
fn new<ElemAllocF>(
len: usize,
elem_alloc_fn: ElemAllocF,
elem_dealloc_fn: ElemDeallocF,
) -> Self
where
ElemAllocF: Fn(usize) -> *mut T,
{
Self {
buf_ptr: ptr::NonNull::new(elem_alloc_fn(len)).unwrap(),
capacity: len,
elem_dealloc_fn,
}
}
}
impl<T, ElemDeallocF> Drop for BufGuard<T, ElemDeallocF>
where
ElemDeallocF: Fn(*mut T, usize),
{
fn drop(&mut self) {
(self.elem_dealloc_fn)(self.buf_ptr.as_ptr(), self.capacity);
}
}
/// Minimal growable vector of `TimSortRun`s backed by the caller-supplied
/// run allocator. Only the operations the sort needs: push, remove, index.
/// `TimSortRun` is `Copy`, so raw byte copies below need no drop handling.
struct RunVec<RunAllocF, RunDeallocF>
where
RunAllocF: Fn(usize) -> *mut TimSortRun,
RunDeallocF: Fn(*mut TimSortRun, usize),
{
buf_ptr: ptr::NonNull<TimSortRun>,
capacity: usize,
len: usize,
run_alloc_fn: RunAllocF,
run_dealloc_fn: RunDeallocF,
}
impl<RunAllocF, RunDeallocF> RunVec<RunAllocF, RunDeallocF>
where
RunAllocF: Fn(usize) -> *mut TimSortRun,
RunDeallocF: Fn(*mut TimSortRun, usize),
{
fn new(run_alloc_fn: RunAllocF, run_dealloc_fn: RunDeallocF) -> Self {
// 16 entries cover slices up to 2^16 elements before the first regrow,
// since each stacked run is at least MIN_INSERTION_RUN long.
const START_RUN_CAPACITY: usize = 16;
Self {
buf_ptr: ptr::NonNull::new(run_alloc_fn(START_RUN_CAPACITY)).unwrap(),
capacity: START_RUN_CAPACITY,
len: 0,
run_alloc_fn,
run_dealloc_fn,
}
}
fn push(&mut self, val: TimSortRun) {
if self.len == self.capacity {
// Grow by doubling: allocate, copy the old contents, free the old
// allocation through the caller's deallocator.
let old_capacity = self.capacity;
let old_buf_ptr = self.buf_ptr.as_ptr();
self.capacity *= 2;
self.buf_ptr = ptr::NonNull::new((self.run_alloc_fn)(self.capacity)).unwrap();
unsafe {
ptr::copy_nonoverlapping(old_buf_ptr, self.buf_ptr.as_ptr(), old_capacity);
}
(self.run_dealloc_fn)(old_buf_ptr, old_capacity);
}
unsafe {
self.buf_ptr.as_ptr().add(self.len).write(val);
}
self.len += 1;
}
/// Removes the entry at `index`, shifting later entries left by one.
/// Panics if `index` is out of bounds.
fn remove(&mut self, index: usize) {
if index >= self.len {
panic!("Index out of bounds");
}
unsafe {
// Overlapping shift, hence `ptr::copy` (memmove semantics).
let ptr = self.buf_ptr.as_ptr().add(index);
ptr::copy(ptr.add(1), ptr, self.len - index - 1);
}
self.len -= 1;
}
fn as_slice(&self) -> &[TimSortRun] {
unsafe { &*ptr::slice_from_raw_parts(self.buf_ptr.as_ptr(), self.len) }
}
fn len(&self) -> usize {
self.len
}
}
impl<RunAllocF, RunDeallocF> core::ops::Index<usize> for RunVec<RunAllocF, RunDeallocF>
where
RunAllocF: Fn(usize) -> *mut TimSortRun,
RunDeallocF: Fn(*mut TimSortRun, usize),
{
type Output = TimSortRun;
fn index(&self, index: usize) -> &Self::Output {
if index < self.len {
unsafe {
return &*(self.buf_ptr.as_ptr().add(index));
}
}
panic!("Index out of bounds");
}
}
impl<RunAllocF, RunDeallocF> core::ops::IndexMut<usize> for RunVec<RunAllocF, RunDeallocF>
where
RunAllocF: Fn(usize) -> *mut TimSortRun,
RunDeallocF: Fn(*mut TimSortRun, usize),
{
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
if index < self.len {
unsafe {
return &mut *(self.buf_ptr.as_ptr().add(index));
}
}
panic!("Index out of bounds");
}
}
impl<RunAllocF, RunDeallocF> Drop for RunVec<RunAllocF, RunDeallocF>
where
RunAllocF: Fn(usize) -> *mut TimSortRun,
RunDeallocF: Fn(*mut TimSortRun, usize),
{
fn drop(&mut self) {
(self.run_dealloc_fn)(self.buf_ptr.as_ptr(), self.capacity);
}
}
}
/// A contiguous, already-sorted run inside the slice being sorted:
/// the elements `v[start..start + len]`.
#[derive(Clone, Copy, Debug)]
pub struct TimSortRun {
// Number of elements in the run.
len: usize,
// Index of the run's first element within the full slice.
start: usize,
}
/// Ensures the sorted run `v[start..end]` is at least `MIN_INSERTION_RUN`
/// elements long by absorbing following elements (clamped to the end of `v`)
/// and insertion-sorting them into the run. Returns the (possibly extended)
/// run end index.
///
/// Requires `start <= end <= v.len()` (asserted).
fn provide_sorted_batch<T, F>(
    mut v: ArrayViewMut1<'_, T>,
    start: usize,
    end: usize,
    is_less: &mut F,
) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();
    assert!(end >= start && end <= len);

    // Runs shorter than this are not worth merging on their own; grow them
    // here and fix the extension up with insertion sort instead.
    const MIN_INSERTION_RUN: usize = 10;

    let run_len = end - start;
    if run_len >= MIN_INSERTION_RUN || end >= len {
        // Already long enough, or nothing left in `v` to absorb.
        return end;
    }

    // Extend the run to the minimum length, clamped to the slice end.
    let new_end = cmp::min(start + MIN_INSERTION_RUN, len);
    // At least the first element always counts as a sorted prefix.
    let sorted_prefix = cmp::max(run_len, 1);
    insertion_sort_shift_left(v.slice_mut(s![start..new_end]), sorted_prefix, is_less);
    new_end
}
/// Finds the length of the natural streak at the front of `v`: the longest
/// prefix that is either non-descending or *strictly* descending under
/// `is_less`. Returns `(streak_len, was_descending)`; a descending streak can
/// be reversed by the caller (strictness keeps the sort stable).
fn find_streak<T, F>(v: ArrayView1<'_, T>, is_less: &mut F) -> (usize, bool)
where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();
    if len < 2 {
        // Zero or one element: the whole view is trivially a streak.
        return (len, false);
    }
    // SAFETY: every index passed to `uget` below is < len.
    unsafe {
        // The first pair fixes the direction for the rest of the scan.
        let descending = is_less(v.uget(1), v.uget(0));
        let mut end = 2;
        while end < len {
            let pair_descending = is_less(v.uget(end), v.uget(end - 1));
            // Stop at the first pair that breaks the established direction.
            if pair_descending != descending {
                break;
            }
            end += 1;
        }
        (end, descending)
    }
}