use core::{marker::PhantomData, mem, ptr};
use ndarray::{ArrayViewMut1, IndexLonger, s};
#[must_use]
pub(super) struct InsertionHole<'a, T> {
pub(super) src: *const T,
pub(super) dest: *mut T,
marker: PhantomData<&'a mut T>,
}
impl<'a, T> InsertionHole<'a, T> {
pub(super) unsafe fn new(src: &'a T, dest: *mut T) -> Self {
Self {
src,
dest,
marker: PhantomData,
}
}
}
impl<T> Drop for InsertionHole<'_, T> {
fn drop(&mut self) {
unsafe {
ptr::copy_nonoverlapping(self.src, self.dest, 1);
}
}
}
/// Inserts the last element of `v` into the run before it by shifting it left.
///
/// If `v[..v.len() - 1]` is sorted according to `is_less`, all of `v` is sorted afterwards.
/// Elements are only moved while `is_less(tmp, v[j])` holds. Assuming `is_less` is a strict
/// ordering, the inserted element therefore ends up after any elements that compare equal to it.
///
/// # Safety
///
/// The caller must ensure `v.len() >= 2`. This is only checked by a `debug_assert!`, and the
/// element accesses below go through the unchecked `uget`.
///
/// # Panics
///
/// If `is_less` panics, the `InsertionHole` guard is dropped during unwinding and writes the
/// saved element into the current hole. `v` therefore still holds every original element
/// exactly once.
#[warn(unsafe_op_in_unsafe_fn)]
unsafe fn insert_tail<T, F>(mut v: ArrayViewMut1<'_, T>, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
debug_assert!(v.len() >= 2);
let i = v.len() - 1;
unsafe {
// Re-derive a shared view `w` and a mutable view `v` through raw views. This detaches their
// lifetimes from the borrow of the original `v`, so both can be used in this scope.
// NOTE(review): soundness relies on never holding a `&T` from `w` and a `&mut T` from `v`
// to the same element at the same time.
let w = v.view();
let w = w.raw_view().deref_into_view();
let mut v = v.raw_view_mut().deref_into_view_mut();
// Only do work if the last element actually has to move.
if is_less(w.uget(i), w.uget(i - 1)) {
// Bitwise-copy the last element out. `ManuallyDrop` keeps this copy from being dropped
// here, because `hole` writes it back into `v`.
let tmp = mem::ManuallyDrop::new(ptr::read(v.view_mut().uget(i)));
// `hole.dest` tracks the slot holding a stale duplicate. Dropping `hole` (normally or while
// unwinding) fills that slot with `*tmp`.
let mut hole = InsertionHole::new(&*tmp, v.view_mut().uget(i - 1));
// Move `v[i - 1]` up into slot `i`; the hole is now at `i - 1`.
ptr::copy_nonoverlapping(hole.dest, v.view_mut().uget(i), 1);
for j in (0..(i - 1)).rev() {
let j_ptr = v.view_mut().uget(j);
// Stop at the first element that `*tmp` is not less than.
if !is_less(&*tmp, &*j_ptr) {
break;
}
// Move `v[j]` one slot right and walk the hole left to `j`.
ptr::copy_nonoverlapping(j_ptr, hole.dest, 1);
hole.dest = j_ptr;
}
// `hole` is dropped at the end of this scope, writing `*tmp` into the final hole.
}
}
}
/// Inserts the first element of `v` into the run after it by shifting it right.
///
/// If `v[1..]` is sorted according to `is_less`, all of `v` is sorted afterwards. Elements are
/// only moved while `is_less(v[i], tmp)` holds. Assuming `is_less` is a strict ordering, the
/// inserted element therefore stays before any elements that compare equal to it.
///
/// # Safety
///
/// The caller must ensure `v.len() >= 2`. This is only checked by a `debug_assert!`, and the
/// element accesses below go through the unchecked `uget`.
///
/// # Panics
///
/// If `is_less` panics, the `InsertionHole` guard is dropped during unwinding and writes the
/// saved first element into the current hole. `v` therefore still holds every original
/// element exactly once.
#[warn(unsafe_op_in_unsafe_fn)]
unsafe fn insert_head<T, F>(mut v: ArrayViewMut1<'_, T>, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
debug_assert!(v.len() >= 2);
unsafe {
// Shared view whose lifetime is detached (via a raw view) from the borrow of `v`, so it can
// still be read after the mutable view below is created.
// NOTE(review): soundness relies on never holding a `&T` from `w` and a `&mut T` from the
// mutable view to the same element at the same time.
let w = v.view();
let w = w.raw_view().deref_into_view();
// Only do work if `v[0]` actually has to move.
if is_less(w.uget(1), w.uget(0)) {
let mut v = v.raw_view_mut().deref_into_view_mut();
// Bitwise-copy `v[0]` out. `ManuallyDrop` keeps this copy from being dropped here,
// because `hole` writes it back into `v`.
let tmp = mem::ManuallyDrop::new(ptr::read(v.view_mut().uget(0)));
let dest = v.view_mut().uget(1);
// `hole.dest` tracks the slot holding a stale duplicate. Dropping `hole` (normally or while
// unwinding) fills that slot with `*tmp`.
let mut hole = InsertionHole::new(&*tmp, dest);
// Move `v[1]` down into slot 0; the hole is now at 1.
ptr::copy_nonoverlapping(dest, v.view_mut().uget(0), 1);
for i in 2..v.len() {
// Stop at the first element that is not less than `*tmp`.
if !is_less(w.uget(i), &*tmp) {
break;
}
// Move `v[i]` one slot left and advance the hole to `i`.
ptr::copy_nonoverlapping(w.uget(i), v.view_mut().uget(i - 1), 1);
hole.dest = v.view_mut().uget(i) as *mut T;
}
// `hole` is dropped at the end of this scope, writing `*tmp` into the final hole.
}
}
}
/// Sorts `v` by insertion, assuming the prefix `v[..offset]` is already sorted.
///
/// Each element of the unsorted suffix, from left to right, is shifted left into place with
/// `insert_tail`.
///
/// # Panics
///
/// Panics if `offset` is `0` or greater than `v.len()`.
#[inline(never)]
pub(super) fn insertion_sort_shift_left<T, F>(
    mut v: ArrayViewMut1<'_, T>,
    offset: usize,
    is_less: &mut F,
) where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();
    // A single up-front check establishes every precondition used in the loop below.
    assert!(offset != 0 && offset <= len);
    (offset..len).for_each(|last| {
        // SAFETY: `last >= offset >= 1`, so the inclusive prefix `v[..=last]` holds at least
        // two elements, which is what `insert_tail` requires.
        unsafe { insert_tail(v.slice_mut(s![..=last]), is_less) }
    });
}
/// Sorts `v` by insertion, assuming the suffix `v[offset..]` is already sorted.
///
/// The elements `v[offset - 1]`, `v[offset - 2]`, ..., `v[0]` are inserted in that order into
/// the sorted region to their right with `insert_head`.
///
/// # Panics
///
/// Panics if `offset` is `0`, if `offset > v.len()`, or if `v.len() < 2`.
#[inline(never)]
pub(super) fn insertion_sort_shift_right<T, F>(
    mut v: ArrayViewMut1<'_, T>,
    offset: usize,
    is_less: &mut F,
) where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();
    assert!(offset != 0 && offset <= len && len >= 2);
    // A one-element suffix is trivially sorted, so `offset == len` is equivalent to
    // `offset == len - 1`. Clamping is required for soundness. Without it, `offset == len`
    // makes the first iteration hand `insert_head` a one-element view, which violates its
    // `len >= 2` precondition and reads out of bounds.
    let offset = offset.min(len - 1);
    for i in (0..offset).rev() {
        // SAFETY: `i <= offset - 1 <= len - 2`, so `v[i..len]` holds at least two elements.
        unsafe {
            insert_head(v.slice_mut(s![i..len]), is_less);
        }
    }
}
/// Partially sorts `v` by shifting a few out-of-order elements into place.
///
/// The function scans `v` for adjacent out-of-order pairs and handles at most `MAX_STEPS` of
/// them. For each pair it swaps the two elements, shifts the smaller one left into the sorted
/// prefix, and shifts the greater one right through the suffix. Arrays shorter than
/// `SHORTEST_SHIFTING` are never modified; the function only checks whether they are sorted.
///
/// Returns `true` only if `v` is completely sorted on return (an empty `v` yields `false`).
/// A `false` result may just mean the step budget ran out. In every case `v` is left as a
/// permutation of its input.
#[cold]
pub fn partial_insertion_sort<T, F>(mut v: ArrayViewMut1<'_, T>, is_less: &mut F) -> bool
where
    F: FnMut(&T, &T) -> bool,
{
    // Maximum number of adjacent out-of-order pairs that get shifted.
    const MAX_STEPS: usize = 5;
    // Arrays shorter than this are only checked, never shifted.
    const SHORTEST_SHIFTING: usize = 50;
    let len = v.len();
    let mut i = 1;
    for _ in 0..MAX_STEPS {
        // SAFETY: `i < len` is checked before each access and `i >= 1`, so both `i` and
        // `i - 1` are in bounds.
        unsafe {
            let v = v.view();
            // Find the next adjacent out-of-order pair; `v[..i]` stays sorted while scanning.
            while i < len && !is_less(v.uget(i), v.uget(i - 1)) {
                i += 1;
            }
        }
        if i == len {
            return true;
        }
        if len < SHORTEST_SHIFTING {
            return false;
        }
        // Swap the found pair: the smaller element is now at `i - 1`, the greater at `i`.
        v.swap(i - 1, i);
        if i >= 2 {
            // Shift the smaller element left. `v[..i - 1]` is sorted, so `v[..i]` is sorted after.
            insertion_sort_shift_left(v.slice_mut(s![..i]), i - 1, is_less);
        }
        if len - i >= 2 {
            // Shift the greater element right through `v[i..]`. Applying this to `v[..i]`
            // instead would be a no-op, because the call above has just sorted that prefix.
            insertion_sort_shift_right(v.slice_mut(s![i..]), 1, is_less);
        }
    }
    // Didn't manage to sort the array in the limited number of steps.
    false
}
#[cfg(feature = "std")]
#[cfg(test)]
mod test {
    use super::{insertion_sort_shift_left, insertion_sort_shift_right, partial_insertion_sort};
    use ndarray::Array1;
    use quickcheck_macros::quickcheck;

    /// Asserts that `array` is in non-decreasing order.
    fn assert_sorted(array: &Array1<u32>) {
        for i in 1..array.len() {
            assert!(array[i - 1] <= array[i]);
        }
    }

    #[quickcheck]
    fn sorted(xs: Vec<u32>) {
        let mut array = Array1::from_vec(xs);
        if !array.is_empty() {
            insertion_sort_shift_left(array.view_mut(), 1, &mut u32::lt);
        }
        assert_sorted(&array);
    }

    #[quickcheck]
    fn sorted_shift_right(xs: Vec<u32>, offset: usize) {
        let len = xs.len();
        if len < 2 {
            return;
        }
        // Map the arbitrary offset into `1..len`, the range where the presorted suffix
        // `v[offset..]` is non-empty.
        let offset = offset % (len - 1) + 1;
        let mut xs = xs;
        xs[offset..].sort_unstable();
        let mut array = Array1::from_vec(xs);
        insertion_sort_shift_right(array.view_mut(), offset, &mut u32::lt);
        assert_sorted(&array);
    }

    #[quickcheck]
    fn partially_sorted(xs: Vec<u32>) {
        let mut expected = xs.clone();
        expected.sort_unstable();
        let mut array = Array1::from_vec(xs);
        let done = partial_insertion_sort(array.view_mut(), &mut u32::lt);
        // A `true` result promises that the whole array is sorted.
        if done {
            assert_sorted(&array);
        }
        // Whatever the result, the elements may only have been permuted.
        let mut actual = array.to_vec();
        actual.sort_unstable();
        assert_eq!(actual, expected);
    }
}