use crate::partition::reverse;
use core::{
mem::{self, size_of},
ptr,
};
use ndarray::{ArrayView1, ArrayViewMut1, Axis, IndexLonger, s};
use rayon::iter::{ParallelBridge, ParallelIterator};
/// A raw mutable pointer that is allowed to cross thread boundaries.
///
/// Raw pointers are neither `Send` nor `Sync`, so they cannot be captured by closures handed to
/// rayon. This wrapper opts back in whenever the pointee type is `Send`. Whoever uses copies of
/// the pointer on several threads must make sure those threads touch disjoint elements.
struct SendPtr<T>(*mut T);

// SAFETY: only the address travels between threads. In this file each parallel task offsets the
// pointer to its own chunk (`buf.get().add(l)` in `par_merge_sort`), so accesses do not overlap.
unsafe impl<T: Send> Send for SendPtr<T> {}
// SAFETY: see the `Send` impl above; sharing `&SendPtr<T>` only allows copying the address out.
unsafe impl<T: Send> Sync for SendPtr<T> {}

impl<T> SendPtr<T> {
    /// Unwraps and returns the stored raw pointer.
    ///
    /// Consuming `self` (instead of reading `.0` at the call site) presumably ensures that a
    /// `move` closure captures the whole `Send` wrapper rather than only the non-`Send` field
    /// (edition-2021 disjoint closure captures).
    fn get(self) -> *mut T {
        let Self(ptr) = self;
        ptr
    }
}

impl<T> Copy for SendPtr<T> {}

impl<T> Clone for SendPtr<T> {
    fn clone(&self) -> Self {
        Self(self.0)
    }
}
/// Inserts `v[0]` into the tail `v[1..]` so that the whole view becomes sorted.
///
/// Callers in this file only use it when `v[1..]` is already sorted: growing a run leftwards in
/// `merge_sort`, and the backwards insertion-sort loop in `par_merge_sort`.
///
/// Nothing happens when `v` has fewer than two elements or when `v[1]` is not less than `v[0]`.
/// Otherwise, the leading elements of `v[1..]` that compare strictly less than the original
/// `v[0]` are shifted one slot to the left. The scan stops at the first element that is not
/// less, which keeps the insertion stable.
///
/// Panic safety: the original `v[0]` is held in `tmp`. The `InsertionHole` guard copies it into
/// the current gap when dropped, both on normal exit and if `is_less` panics, so every element
/// ends up in `v` exactly once.
fn insert_head<T, F>(mut v: ArrayViewMut1<'_, T>, is_less: &F)
where
F: Fn(&T, &T) -> bool,
{
if v.len() >= 2 && is_less(&v[1], &v[0]) {
unsafe {
// Bitwise-move `v[0]` out. `ManuallyDrop` prevents a double drop, because the guard writes
// this value back into `v`.
let tmp = mem::ManuallyDrop::new(ptr::read(&v[0]));
let mut hole = InsertionHole {
src: &*tmp,
dest: &mut v[1],
};
// Shift `v[1]` into the vacated slot; the gap is now at index 1.
ptr::copy_nonoverlapping(&v[1], &mut v[0], 1);
for i in 2..v.len() {
if !is_less(&v[i], &*tmp) {
break;
}
// Shift `v[i]` one slot left and move the gap to index `i`.
ptr::copy_nonoverlapping(&v[i], &mut v[i - 1], 1);
hole.dest = &mut v[i];
}
// `hole` is dropped at the end of this scope and fills the gap with `tmp`.
}
}
// Drop guard: when dropped, bitwise-copies the element at `src` into the slot at `dest`.
struct InsertionHole<T> {
src: *const T,
dest: *mut T,
}
impl<T> Drop for InsertionHole<T> {
fn drop(&mut self) {
unsafe {
ptr::copy_nonoverlapping(self.src, self.dest, 1);
}
}
}
}
/// Merges the two adjacent sorted runs `v[..mid]` and `v[mid..]` in place, using `buf` as
/// scratch space.
///
/// The shorter run is bitwise-copied into `buf` and then merged back into `v`:
/// - forwards (front to back) when the left run is not longer than the right one;
/// - backwards (back to front) otherwise.
///
/// On equal elements the left run's element is placed first, so the merge is stable.
///
/// # Safety
///
/// - `mid` must not exceed `v.len()`; otherwise `len - mid` underflows.
/// - `buf` must be valid for writes of `min(mid, v.len() - mid)` elements and must not overlap
///   the memory viewed by `v`.
///
/// Panic safety: the `MergeHole` guard copies every buffered element that has not been
/// consumed back into `v` when it is dropped. If `is_less` panics, every element therefore
/// remains in `v` exactly once.
#[warn(unsafe_op_in_unsafe_fn)]
unsafe fn merge<T, F>(v: ArrayViewMut1<'_, T>, mid: usize, buf: *mut T, is_less: &F)
where
F: Fn(&T, &T) -> bool,
{
let len = v.len();
let mut hole;
// The left run is the shorter one: buffer it and merge forwards.
if mid <= len - mid {
unsafe {
for i in 0..mid {
ptr::copy_nonoverlapping(&v[i], buf.add(i), 1);
}
hole = MergeHole {
buf,
start: 0,
end: mid,
dest: 0,
v,
};
}
// `right` is the next unmerged element of the right run (still in `v`), `hole.start` the
// next buffered left element, and `hole.dest` the next output slot in `v`. The loop keeps
// `hole.dest == hole.start + (right - mid)`.
let mut right = mid;
while hole.start < hole.end && right < len {
unsafe {
let w = hole.v.view();
// Take from the right run only if it is strictly less, so ties keep the left
// element first (stability).
let to_copy = if is_less(w.uget(right), &*hole.buf.add(hole.start)) {
let idx = &mut right;
let old = hole.v.view_mut().index(*idx);
*idx += 1; old
} else {
let idx = &mut hole.start;
let old = hole.buf.add(*idx);
*idx += 1; old
};
let idx = &mut hole.dest;
let old = hole.v.view_mut().index(*idx);
*idx += 1; let dst = old;
ptr::copy_nonoverlapping(to_copy, dst, 1);
}
}
} else {
// The right run is the shorter one: buffer it and merge backwards from the end.
unsafe {
for i in 0..len - mid {
ptr::copy_nonoverlapping(&v[mid + i], buf.add(i), 1);
}
hole = MergeHole {
buf,
start: 0,
end: len - mid,
dest: mid,
v,
};
}
// `out` is one past the next output slot, `hole.dest` is one past the last unmerged left
// element, and `hole.end` is one past the last unmerged buffered right element. The loop
// keeps `out == hole.dest + hole.end`, so the leftover buffered elements fit exactly in
// `v[hole.dest..out]`.
let mut out = len;
while 0 < hole.dest && 0 < hole.end {
unsafe {
let w = hole.v.view();
// Take the left element only if the buffered right element is strictly less than
// it. On ties the right element is written (further right) first, which keeps the
// merge stable.
let to_copy = if is_less(&*hole.buf.add(hole.end - 1), w.uget(hole.dest - 1)) {
let idx = &mut hole.dest;
*idx -= 1; hole.v.view_mut().index(*idx)
} else {
let idx = &mut hole.end;
*idx -= 1; hole.buf.add(*idx)
};
let idx = &mut out;
*idx -= 1; let dst = hole.v.view_mut().index(*idx);
ptr::copy_nonoverlapping(to_copy, dst, 1);
}
}
}
// `hole` is dropped when the function returns and copies any unconsumed buffered elements into
// `v[hole.dest..]`.
// Drop guard for `merge`: `buf[start..end]` holds buffered elements that have not been written
// back yet; on drop they are copied, in order, into `v[dest..]`.
struct MergeHole<'a, T> {
buf: *mut T,
start: usize,
end: usize,
v: ArrayViewMut1<'a, T>,
dest: usize,
}
impl<T> Drop for MergeHole<'_, T> {
fn drop(&mut self) {
unsafe {
let len = self.end - self.start; for i in 0..len {
let src = self.buf.add(self.start + i);
let dst = self.v.view_mut().index(self.dest + i);
ptr::copy_nonoverlapping(src, dst, 1);
}
}
}
}
}
/// Outcome of `merge_sort` on a single view.
///
/// `NonDescending` and `Descending` are only reported when the whole view already forms a
/// single run. In both cases `merge_sort` returns before modifying the view. For `Descending`
/// the caller is expected to reverse the view; the run is strictly descending, so reversing it
/// is stable.
#[must_use]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum MergesortResult {
    /// The view was already non-descending (sorted) and has not been modified.
    NonDescending,
    /// The view is strictly descending and has not been modified; reverse it to sort it.
    Descending,
    /// The view was neither, and it has now been sorted in place.
    Sorted,
}
/// A sorted run inside the view being merge sorted, covering the elements
/// `start..start + len`.
#[derive(Clone, Copy, Debug)]
struct Run {
    /// Index of the first element of the run.
    start: usize,
    /// Number of elements in the run.
    len: usize,
}
/// Decides whether two runs on top of the run stack must be merged next.
///
/// `runs` is ordered as `merge_sort` pushes them: the view is scanned back to front, so the
/// last entry is the leftmost run found so far.
///
/// A merge is required when any of these holds:
/// - the newest run starts at index 0, meaning the scan is finished and everything must collapse;
/// - one of the length invariants on the top three or four runs is violated.
///
/// Returns `Some(r)` to request merging `runs[r]` with `runs[r + 1]`, or `None` if no merge is
/// needed.
#[inline]
fn collapse(runs: &[Run]) -> Option<usize> {
    let n = runs.len();
    if n < 2 {
        return None;
    }

    // Length of the k-th run counted from the top of the stack (k = 1 is the newest).
    let len_at = |k: usize| runs[n - k].len;

    let must_merge = runs[n - 1].start == 0
        || len_at(2) <= len_at(1)
        || (n >= 3 && len_at(3) <= len_at(2) + len_at(1))
        || (n >= 4 && len_at(4) <= len_at(3) + len_at(2));
    if !must_merge {
        return None;
    }

    // Merge the deeper pair when the third run is shorter than the newest one.
    let pick_deeper = n >= 3 && len_at(3) < len_at(1);
    Some(if pick_deeper { n - 3 } else { n - 2 })
}
/// Sequential, stable, TimSort-style merge sort of `v` using `buf` as scratch memory.
///
/// The view is scanned from back to front to find natural runs:
/// - strictly descending runs are reversed in place;
/// - runs shorter than `MIN_RUN` are extended leftwards with `insert_head`;
/// - adjacent runs are merged with `merge` whenever `collapse` asks for it.
///
/// Returns:
/// - `Descending` if the whole of `v` is strictly descending. `v` is left untouched and the
///   caller must reverse it.
/// - `NonDescending` if `v` was already sorted. `v` is left untouched.
/// - `Sorted` otherwise. `v` is now sorted.
///
/// `v` is expected to be non-empty: for an empty view no run is pushed, and the final
/// `debug_assert!` (exactly one run remaining) would fail in debug builds.
///
/// # Safety
///
/// `buf` must be valid for writes of at least `v.len() / 2` elements (the most `merge` ever
/// buffers) and must not overlap the memory viewed by `v`.
unsafe fn merge_sort<T, F>(mut v: ArrayViewMut1<'_, T>, buf: *mut T, is_less: &F) -> MergesortResult
where
T: Send,
F: Fn(&T, &T) -> bool,
{
unsafe {
const MIN_RUN: usize = 10;
let len = v.len();
// Stack of sorted runs; the last entry is the leftmost run found so far.
let mut runs = vec![];
let mut end = len;
while end > 0 {
// Find the natural run ending just before `end`.
let mut start = end - 1;
if start > 0 {
start -= 1;
let w = v.view();
if is_less(w.uget(start + 1), w.uget(start)) {
// Strictly descending: extend leftwards while elements keep strictly decreasing.
while start > 0 && is_less(w.uget(start), w.uget(start - 1)) {
start -= 1;
}
if start == 0 && end == len {
// The whole view is strictly descending; leave reversing to the caller.
return MergesortResult::Descending;
} else {
// Reversing a strictly descending run cannot reorder equal elements.
reverse(v.slice_mut(s![start..end]));
}
} else {
// Non-descending: extend leftwards while no element is less than its predecessor.
while start > 0 && !is_less(w.uget(start), w.uget(start - 1)) {
start -= 1;
}
if end - start == len {
// The whole view is already sorted.
return MergesortResult::NonDescending;
}
}
}
// Extend short runs to `MIN_RUN` elements with insertion sort.
while start > 0 && end - start < MIN_RUN {
start -= 1;
insert_head(v.slice_mut(s![start..end]), &is_less);
}
runs.push(Run {
start,
len: end - start,
});
end = start;
// Merge adjacent runs until the stack invariants hold again. `runs[r + 1]` lies directly
// to the left of `runs[r]` in `v`.
while let Some(r) = collapse(&runs) {
let left = runs[r + 1];
let right = runs[r];
merge(
v.slice_mut(s![left.start..right.start + right.len]),
left.len,
buf,
&is_less,
);
runs[r] = Run {
start: left.start,
len: left.len + right.len,
};
runs.remove(r + 1);
}
}
// `collapse` always merges once a run starting at 0 is pushed, so exactly one run remains.
debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);
MergesortResult::Sorted
}
}
/// Chooses split points for a parallel merge of the sorted views `left` and `right`.
///
/// Returns `(l, r)` such that the elements of `left[..l]` and `right[..r]` can all be placed
/// before those of `left[l..]` and `right[r..]` without breaking stability:
///
/// - If `left` is at least as long as `right`, then `l = left.len() / 2`, and `r` is the first
///   index in `right` whose element is not less than `left[l]` (or `right.len()` if there is
///   none).
/// - Otherwise `r = right.len() / 2`, and `l` is the first index in `left` whose element is
///   greater than `right[r]` (or `left.len()` if there is none).
///
/// Both searches are binary searches, so `is_less` is called O(log n) times.
fn split_for_merge<T, F>(
    left: ArrayView1<'_, T>,
    right: ArrayView1<'_, T>,
    is_less: &F,
) -> (usize, usize)
where
    F: Fn(&T, &T) -> bool,
{
    // Smallest index in `0..n` at which `pred` is false, assuming `pred` is true on a prefix of
    // `0..n` and false on the rest (lower-bound bisection).
    fn first_false(n: usize, pred: impl Fn(usize) -> bool) -> usize {
        let (mut lo, mut hi) = (0, n);
        while lo < hi {
            let probe = lo + (hi - lo) / 2;
            if pred(probe) {
                lo = probe + 1;
            } else {
                hi = probe;
            }
        }
        lo
    }

    if left.len() < right.len() {
        // `right` is non-empty here, so indexing its midpoint cannot panic.
        let right_mid = right.len() / 2;
        let pivot = &right[right_mid];
        let left_mid = first_false(left.len(), |m| !is_less(pivot, &left[m]));
        (left_mid, right_mid)
    } else {
        let left_mid = left.len() / 2;
        // `left[left_mid]` is only read inside the search, which never runs when `right` is
        // empty; `left` itself may be empty in that case.
        let right_mid = first_false(right.len(), |m| is_less(&right[m], &left[left_mid]));
        (left_mid, right_mid)
    }
}
/// Stably merges the sorted views `left` and `right` into `dest`, splitting the work with
/// `rayon::join` when the inputs are large.
///
/// The merge runs sequentially when either input is empty or when there are fewer than
/// `MAX_SEQUENTIAL` elements in total (presumably to amortize rayon's task overhead). Otherwise
/// `split_for_merge` picks `(left_mid, right_mid)` so that `left[..left_mid]` and
/// `right[..right_mid]` all go before the remainders. The two halves are then merged
/// recursively, in parallel, into the matching halves of `dest`.
///
/// On equal elements the one from `left` is written first. Elements are moved with bitwise
/// copies (`ptr::copy_nonoverlapping`), so afterwards the slots in `left` and `right` must be
/// treated as moved-from.
///
/// # Safety
///
/// - `dest` must be indexable for at least `left.len() + right.len()` elements and must not
///   overlap `left` or `right`.
/// - The caller must not drop the moved-from elements left behind in `left` and `right`.
///
/// Panic safety: while the `State` guard `s` is alive, dropping it copies every element that
/// has not been merged yet into `dest`. This covers a panic in `is_less` or in
/// `split_for_merge`.
#[warn(unsafe_op_in_unsafe_fn)]
unsafe fn par_merge<T, F>(
mut left: ArrayViewMut1<'_, T>,
mut right: ArrayViewMut1<'_, T>,
mut dest: ArrayViewMut1<'_, T>,
is_less: &F,
) where
T: Send,
F: Fn(&T, &T) -> bool + Sync,
{
const MAX_SEQUENTIAL: usize = 5000;
// Raw views let the parallel branch re-derive mutable views after `s` has taken ownership of
// the originals. The re-derived views alias the ones held by `s`.
let left_raw = left.raw_view_mut();
let right_raw = right.raw_view_mut();
let dest_raw = dest.raw_view_mut();
let left_len = left.len();
let right_len = right.len();
let left = unsafe { left_raw.deref_into_view_mut() };
let right = unsafe { right_raw.deref_into_view_mut() };
let dest = unsafe { dest_raw.deref_into_view_mut() };
let mut s = State {
left,
left_start: 0,
right,
right_start: 0,
dest,
dest_start: 0,
};
if left_len == 0 || right_len == 0 || left_len + right_len < MAX_SEQUENTIAL {
// Sequential merge. Take from `right` only when it is strictly less (stability).
while s.left_start < s.left.len() && s.right_start < s.right.len() {
if is_less(&s.right[s.right_start], &s.left[s.left_start]) {
unsafe {
ptr::copy_nonoverlapping(&s.right[s.right_start], &mut s.dest[s.dest_start], 1)
};
s.right_start += 1;
} else {
unsafe {
ptr::copy_nonoverlapping(&s.left[s.left_start], &mut s.dest[s.dest_start], 1)
};
s.left_start += 1;
};
s.dest_start += 1;
}
} else {
// `s` is still alive here, so a panic in `split_for_merge` copies all of `left` and `right`
// into `dest`.
let left = unsafe { left_raw.deref_into_view_mut() };
let right = unsafe { right_raw.deref_into_view_mut() };
let dest = unsafe { dest_raw.deref_into_view_mut() };
let (left_mid, right_mid) = split_for_merge(left.view(), right.view(), is_less);
let (left_l, left_r) = left.split_at(Axis(0), left_mid);
let (right_l, right_r) = right.split_at(Axis(0), right_mid);
// From here on the recursive calls own all elements. Running `s`'s destructor as well would
// copy them into `dest` a second time.
mem::forget(s);
let (dest_l, dest_r) = dest.split_at(Axis(0), left_l.len() + right_l.len());
rayon::join(
move || unsafe { par_merge(left_l, right_l, dest_l, is_less) },
move || unsafe { par_merge(left_r, right_r, dest_r, is_less) },
);
}
// In the sequential case `s` is dropped when the function returns. One side is exhausted by
// then, so its destructor appends the remaining tail of the other side to `dest`.
// Progress of a merge: `left[left_start..]` and `right[right_start..]` have not been merged yet,
// and `dest[dest_start..]` is where they go.
struct State<'a, T> {
left: ArrayViewMut1<'a, T>,
left_start: usize,
right: ArrayViewMut1<'a, T>,
right_start: usize,
dest: ArrayViewMut1<'a, T>,
dest_start: usize,
}
// On drop, copies the unmerged tail of `left`, followed by the unmerged tail of `right`, into
// `dest` starting at `dest_start`.
impl<T> Drop for State<'_, T> {
fn drop(&mut self) {
unsafe {
let left_len = self.left.len() - self.left_start;
for i in 0..left_len {
ptr::copy_nonoverlapping(
&self.left[i + self.left_start],
&mut self.dest[i + self.dest_start],
1,
);
}
let right_len = self.right.len() - self.right_start;
for i in 0..right_len {
ptr::copy_nonoverlapping(
&self.right[i + self.right_start],
&mut self.dest[i + self.dest_start + left_len],
1,
);
}
}
}
}
}
/// Recursively merges the sorted, adjacent `chunks` of `v`. At each level of recursion `v` and
/// `buf` swap roles as source and destination.
///
/// `chunks` holds contiguous `(start, end)` ranges, each already sorted. When `into_buf` is
/// true the merged result for the covered range ends up in `buf[start..end]`; otherwise it ends
/// up in `v[start..end]`. The two halves of `chunks` are processed in parallel with `!into_buf`,
/// so their results land in the opposite array. `par_merge` then merges them from there into
/// the requested one.
///
/// # Safety
///
/// - `chunks` must be non-empty (debug-asserted) and contiguous: the merge assumes the left
///   half ends exactly where `chunks[len / 2]` starts.
/// - `v` and `buf` must both be indexable over every range in `chunks`.
/// - Elements are moved between `v` and `buf` with bitwise copies. The caller must make sure
///   they are dropped exactly once (for example, `buf` is backed by a `Vec` of length 0 in
///   `par_merge_sort`).
///
/// Panic safety: if a recursive call panics, the `CopyOnDrop` guard copies
/// `src[start..end]` into `dest[start..end]`. The children have already placed their ranges in
/// `src`, so after the copy every element is in the array this call was asked to fill.
#[warn(unsafe_op_in_unsafe_fn)]
unsafe fn recurse<T, F>(
mut v: ArrayViewMut1<'_, T>,
mut buf: ArrayViewMut1<'_, T>,
chunks: &[(usize, usize)],
into_buf: bool,
is_less: &F,
) where
T: Send,
F: Fn(&T, &T) -> bool + Sync,
{
// Raw views used to hand independent (aliasing) mutable views to the parallel subcalls.
let v_raw = v.raw_view_mut();
let buf_raw = buf.raw_view_mut();
let len = chunks.len();
debug_assert!(len > 0);
// Base case: a single sorted chunk, which only needs to be copied when the result must end up
// in `buf`.
if len == 1 {
if into_buf {
let (start, end) = chunks[0];
for i in start..end {
unsafe { ptr::copy_nonoverlapping(&v[i], &mut buf[i], 1) };
}
}
return;
}
let (start, _) = chunks[0];
let (mid, _) = chunks[len / 2];
let (_, end) = chunks[len - 1];
let (left, right) = chunks.split_at(len / 2);
let v = unsafe { v_raw.deref_into_view_mut() };
let buf = unsafe { buf_raw.deref_into_view_mut() };
// The children write into `src`; this level then merges `src` into `dest`.
let (mut src, mut dest) = if into_buf { (v, buf) } else { (buf, v) };
let guard = CopyOnDrop {
src: src.view_mut(),
dest: dest.view_mut(),
src_start: start,
dest_start: start,
len: end - start,
};
// Each subcall gets full-length views of `v` and `buf`. NOTE(review): these views alias each
// other across threads; soundness relies on each call only touching indices inside its own
// `chunks`.
let v_left = unsafe { v_raw.deref_into_view_mut() };
let buf_left = unsafe { buf_raw.deref_into_view_mut() };
let v_right = unsafe { v_raw.deref_into_view_mut() };
let buf_right = unsafe { buf_raw.deref_into_view_mut() };
rayon::join(
move || {
unsafe {
recurse(
v_left, buf_left, left, !into_buf, is_less,
)
}
},
move || {
unsafe {
recurse(
v_right, buf_right, right, !into_buf, is_less,
)
}
},
);
// Both halves completed; disarm the guard, because `par_merge` (with its own guard) now takes
// over the elements.
mem::forget(guard);
// Merge the sorted halves `src[start..mid]` and `src[mid..end]` into `dest` starting at
// `start`.
let (src_left, src_right) = src.multi_slice_mut((s![start..mid], s![mid..end]));
unsafe { par_merge(src_left, src_right, dest.slice_mut(s![start..]), is_less) };
// Drop guard: on drop, copies `len` elements from `src[src_start..]` to `dest[dest_start..]`.
struct CopyOnDrop<'a, T> {
src: ArrayViewMut1<'a, T>,
dest: ArrayViewMut1<'a, T>,
src_start: usize,
dest_start: usize,
len: usize,
}
impl<T> Drop for CopyOnDrop<'_, T> {
fn drop(&mut self) {
unsafe {
for i in 0..self.len {
let a = self.src_start + i;
let b = self.dest_start + i;
ptr::copy_nonoverlapping(&self.src[a], &mut self.dest[b], 1);
}
}
}
}
}
/// Sorts `v` stably with a parallel merge sort, using `is_less(a, b)` as the "`a` goes before
/// `b`" comparison.
///
/// - Zero-sized element types are left untouched.
/// - Views of at most `MAX_INSERTION` (20) elements are sorted in place by insertion sort,
///   without allocating.
/// - Views of at most `CHUNK_LENGTH` (2000) elements are sorted sequentially with `merge_sort`.
/// - Longer views are cut into `CHUNK_LENGTH`-sized chunks that are merge sorted in parallel.
///   Adjacent chunks that were already non-descending (or strictly descending, also across the
///   chunk boundary) are concatenated. The resulting sorted ranges are then merged with
///   `recurse`.
///
/// For views longer than `MAX_INSERTION`, a scratch buffer with capacity `v.len()` is allocated.
///
/// If `is_less` panics, the drop guards in the helpers are designed to leave every element in
/// `v` exactly once, in an unspecified order.
pub fn par_merge_sort<T, F>(mut v: ArrayViewMut1<'_, T>, is_less: F)
where
T: Send,
F: Fn(&T, &T) -> bool + Sync,
{
const MAX_INSERTION: usize = 20;
const CHUNK_LENGTH: usize = 2000;
if size_of::<T>() == 0 {
return;
}
let len = v.len();
// Short views: insertion sort, growing the sorted suffix one element at a time from the back.
if len <= MAX_INSERTION {
if len >= 2 {
for i in (0..len - 1).rev() {
insert_head(v.slice_mut(s![i..]), &is_less);
}
}
return;
}
// Scratch memory. The `Vec` keeps length 0, so the bitwise copies stored in it are never
// dropped; only the allocation is freed when the `Vec` goes out of scope.
let mut buf = Vec::<T>::with_capacity(len);
let buf = buf.as_mut_ptr();
if len <= CHUNK_LENGTH {
let res = unsafe { merge_sort(v.view_mut(), buf, &is_less) };
if res == MergesortResult::Descending {
reverse(v.view_mut());
}
return;
}
// Merge sort every chunk in parallel, each using its own part of `buf`. Descending chunks are
// not reversed yet. Each result is written into `chunks`' spare capacity at the chunk's own
// index, so the order matches the chunk order even though `par_bridge` runs tasks out of order.
let mut iter = {
let buf = SendPtr(buf);
let is_less = &is_less;
let chunks_iter = v.axis_chunks_iter_mut(Axis(0), CHUNK_LENGTH);
let len = chunks_iter.len();
let mut chunks = Vec::with_capacity(len);
chunks_iter
.enumerate()
.zip(chunks.spare_capacity_mut())
.par_bridge()
.for_each(move |((i, chunk), out)| {
let l = CHUNK_LENGTH * i;
let r = l + chunk.len();
unsafe {
let buf = buf.get().add(l);
out.write((l, r, merge_sort(chunk, buf, is_less)));
}
});
// `for_each` returned normally, so all `len` slots have been initialized.
unsafe { chunks.set_len(len) };
chunks.into_iter().peekable()
};
// Concatenate neighbouring chunks that were presorted in the same direction and also line up
// across their boundary, then reverse descending ranges.
let mut chunks = Vec::with_capacity(iter.len());
while let Some((a, mut b, res)) = iter.next() {
if res != MergesortResult::Sorted {
while let Some(&(x, y, r)) = iter.peek() {
// Descending chunks join only if strictly descending across the boundary;
// non-descending chunks join only if not descending across it.
if r == res && (r == MergesortResult::Descending) == is_less(&v[x], &v[x - 1]) {
b = y;
iter.next();
} else {
break;
}
}
}
if res == MergesortResult::Descending {
reverse(v.slice_mut(s![a..b]));
}
chunks.push((a, b));
}
// Merge all sorted ranges back into `v`. NOTE(review): this view covers the uninitialized
// capacity of `buf`, and `recurse`/`par_merge` index it, creating references to uninitialized
// `T` before writing. Confirm this is acceptable under the crate's soundness requirements.
unsafe {
let buf = ArrayViewMut1::from_shape_ptr(len, buf);
recurse(v, buf, &chunks, false, &is_less);
}
}
// Unit tests for the parallel merge sort and its split-point helper.
#[cfg(test)]
mod test {
use super::{par_merge_sort, split_for_merge};
use core::cmp::Ordering;
use ndarray::{Array1, ArrayView1, s};
use quickcheck_macros::quickcheck;
use rand::distr::Uniform;
use rand::{Rng, rng};
// Checks that `split_for_merge` returns split points that keep both halves ordered relative to
// each other, for fixed edge cases (including empty sides) and 100 random sorted inputs.
#[test]
fn split() {
// Asserts that `left[..l]` goes before `right[r..]` (non-strict), and that `right[..r]` goes
// strictly before `left[l..]` (strict, so that equal left elements stay first for stability).
fn check(left: &[u32], right: &[u32]) {
let left = ArrayView1::from_shape(left.len(), left).unwrap();
let right = ArrayView1::from_shape(right.len(), right).unwrap();
let (l, r) = split_for_merge(left, right, &|&a, &b| a < b);
assert!(
left.slice(s![..l])
.iter()
.all(|&x| right.slice(s![r..]).iter().all(|&y| x <= y))
);
assert!(
right
.slice(s![..r])
.iter()
.all(|&x| left.slice(s![l..]).iter().all(|&y| x < y))
);
}
check(&[1, 2, 2, 2, 2, 3], &[1, 2, 2, 2, 2, 3]);
check(&[1, 2, 2, 2, 2, 3], &[]);
check(&[], &[1, 2, 2, 2, 2, 3]);
let rng = &mut rng();
for _ in 0..100 {
// Small value ranges make duplicates, and hence the tie-breaking paths, likely.
let limit: u32 = rng.random_range(1..21);
let left_len: usize = rng.random_range(0..20);
let right_len: usize = rng.random_range(0..20);
let mut left = rng
.sample_iter(&Uniform::new(0, limit).unwrap())
.take(left_len)
.collect::<Vec<_>>();
let mut right = rng
.sample_iter(&Uniform::new(0, limit).unwrap())
.take(right_len)
.collect::<Vec<_>>();
left.sort();
right.sort();
check(&left, &right);
}
}
// Element whose equality and ordering consider only `value`. `index` records the original
// position, so any reordering of equal values (instability) is detected.
#[derive(Debug, Clone, Copy)]
struct Item {
index: usize,
value: u32,
}
impl Eq for Item {}
impl PartialEq for Item {
fn eq(&self, other: &Self) -> bool {
self.value == other.value
}
}
impl Ord for Item {
fn cmp(&self, other: &Self) -> Ordering {
self.value.cmp(&other.value)
}
}
impl PartialOrd for Item {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl From<(usize, u32)> for Item {
fn from((index, value): (usize, u32)) -> Self {
Self { index, value }
}
}
// Property test: `par_merge_sort` must give exactly the same order, including the original
// indices of equal values, as the standard library's stable `sort`. Skipped under Miri.
#[cfg_attr(miri, ignore)]
#[quickcheck]
fn stably_sorted(xs: Vec<u32>) {
let xs = xs
.into_iter()
.enumerate()
.map(Item::from)
.collect::<Vec<Item>>();
let mut sorted = xs.clone();
sorted.sort();
let sorted = Array1::from_vec(sorted);
let mut array = Array1::from_vec(xs);
par_merge_sort(array.view_mut(), Item::lt);
for (a, s) in array.iter().zip(&sorted) {
assert_eq!(a.index, s.index);
assert_eq!(a.value, s.value);
}
}
}