use core::marker::PhantomData;
use crate::gap_guard::GapGuard;
use crate::mut_slice::states::{AlwaysInit, Init, Uninit, Weak};
use crate::mut_slice::{Brand, MutSlice, Unbranded};
use crate::tracking::ptr;
use crate::util::*;
/// Gap marker: only merging at the front of the output is allowed
/// (`new_gap_left` places the free gap ahead of `right`).
pub struct GapLeft;

/// Gap marker: only merging at the back of the output is allowed
/// (`new_gap_right` places the free gap after `left`).
pub struct GapRight;

/// Gap marker: merging at both the front and the back of the output is allowed.
pub struct GapBoth;

/// Implemented by gap markers that unlock the front-merging (`*_at_begin`) operations.
pub trait HasLeftGap {}

/// Implemented by gap markers that unlock the back-merging (`*_at_end`) operations.
pub trait HasRightGap {}

impl HasLeftGap for GapLeft {}
impl HasLeftGap for GapBoth {}

impl HasRightGap for GapRight {}
impl HasRightGap for GapBoth {}
/// Branchless merge of two input runs (`left`, `right`) into an output window (`dst`).
///
/// The unmerged part of each input is a half-open raw-pointer range `[begin, end)`.
/// `dst` is the part of the output that has not been written yet. The marker type
/// `G` (`GapLeft`, `GapRight` or `GapBoth`) selects, through the `HasLeftGap` /
/// `HasRightGap` bounds, whether elements may be merged at the front, at the back,
/// or at both ends. The `Drop` impl moves any still-unmerged input elements into `dst`.
pub struct BranchlessMergeState<'l, 'r, 'dst, T, G> {
// Unwritten output window; its brand is erased (`forget_brand`) in `new`.
dst: MutSlice<'dst, Unbranded, T, Weak>,
// Unmerged part of the left input: [left_begin, left_end).
left_begin: *mut T,
left_end: *mut T,
// Unmerged part of the right input: [right_begin, right_end).
right_begin: *mut T,
right_end: *mut T,
// Marker value selecting the permitted merge direction(s).
_gap: G,
// The inputs are only held as raw pointers; this ties the state to mutable
// borrows of lifetimes 'l and 'r.
_lt: PhantomData<(&'l mut (), &'r mut ())>,
}
impl<'l, 'r, 'dst, T, G> BranchlessMergeState<'l, 'r, 'dst, T, G> {
    /// Shared constructor for all gap variants.
    ///
    /// Records the raw `[begin, end)` bounds of both inputs and keeps the output
    /// window with its brand erased. Aborts unless `dst` is exactly as long as
    /// `left` and `right` combined, because the merge writes one output element
    /// per input element.
    fn new<BL, BR, BD>(
        left: MutSlice<'l, BL, T, Weak>,
        right: MutSlice<'r, BR, T, Weak>,
        dst: MutSlice<'dst, BD, T, Weak>,
        gap: G,
    ) -> Self {
        let input_len = left.len() + right.len();
        if input_len != dst.len() {
            abort();
        }

        let (left_begin, left_end) = (left.begin(), left.end());
        let (right_begin, right_end) = (right.begin(), right.end());
        Self {
            dst: dst.weak().forget_brand(),
            left_begin,
            left_end,
            right_begin,
            right_end,
            _gap: gap,
            _lt: PhantomData,
        }
    }
}
impl<'l, 'r, 'dst, T> BranchlessMergeState<'l, 'r, 'dst, T, GapBoth> {
    /// Creates a state that merges two initialized inputs into an uninitialized
    /// destination. Going by the name, the destination is disjoint from both
    /// inputs. Both ends of `dst` may be merged into (`GapBoth`).
    ///
    /// `dst.len()` must equal `left.len() + right.len()`; otherwise `new` aborts.
    pub fn new_disjoint<BL, BR, BD>(
        left: MutSlice<'l, BL, T, Init>,
        right: MutSlice<'r, BR, T, Init>,
        dst: MutSlice<'dst, BD, T, Uninit>,
    ) -> Self {
        let left = left.weak();
        let right = right.weak();
        let dst = dst.weak();
        Self::new(left, right, dst, GapBoth)
    }
}
impl<'l, 'r, T> BranchlessMergeState<'l, 'r, 'r, T, GapLeft> {
/// Builds a front-merging state (`GapLeft`) whose output window is the guard's
/// gap followed by `right` itself. The output therefore lives in `right`'s
/// borrow (`'dst = 'r`).
///
/// NOTE(review): the safety contracts of `gap_weak`, `concat`, `take_disjoint`
/// and `raw` are defined in the `GapGuard` / `MutSlice` modules; confirm there.
pub fn new_gap_left<BL: Brand, BR: Brand>(
left: GapGuard<'l, 'r, BL, BR, T>,
right: MutSlice<'r, BR, T, AlwaysInit>,
) -> Self {
unsafe {
// Output = gap ++ right. This is formed before `take_disjoint` is called on the guard.
let dst = left.gap_weak().concat(right.weak());
// Left input: first component of `take_disjoint()` (presumably the guard's
// out-of-place slice; confirm in `gap_guard`).
let left = left.take_disjoint().0.weak();
let right = right.raw().weak();
Self::new(left, right, dst, GapLeft)
}
}
}
impl<'l, 'r, T> BranchlessMergeState<'l, 'r, 'l, T, GapRight> {
/// Builds a back-merging state (`GapRight`) whose output window is `left`
/// itself followed by the guard's gap. The output therefore lives in `left`'s
/// borrow (`'dst = 'l`).
///
/// NOTE(review): the safety contracts of `gap_weak`, `concat`, `take_disjoint`
/// and `raw` are defined in the `GapGuard` / `MutSlice` modules; confirm there.
pub fn new_gap_right<BL: Brand, BR: Brand>(
left: MutSlice<'l, BL, T, AlwaysInit>,
right: GapGuard<'r, 'l, BR, BL, T>,
) -> Self {
unsafe {
// Output = left ++ gap. This is formed before `take_disjoint` is called on the guard.
let dst = left.weak().concat(right.gap_weak());
let left = left.raw().weak();
// Right input: first component of `take_disjoint()` (presumably the guard's
// out-of-place slice; confirm in `gap_guard`).
let right = right.take_disjoint().0.weak();
Self::new(left, right, dst, GapRight)
}
}
}
impl<'l, 'r, 'dst, T, G: HasLeftGap> BranchlessMergeState<'l, 'r, 'dst, T, G> {
    /// Merges one element into the front of the output.
    ///
    /// Compares the first unmerged elements of `left` and `right`. It copies the
    /// right one only when `is_less(right, left)` is true; otherwise it copies the
    /// left one, so equal elements keep their left-before-right order. It then
    /// advances `dst` and the input the element came from, without branching on
    /// the comparison result.
    ///
    /// # Safety
    ///
    /// Both `left` and `right` must have at least one unmerged element, and `dst`
    /// must have room for one more element. The destination slot must not overlap
    /// either source element, because the copy is `copy_nonoverlapping`.
    #[inline(always)]
    pub unsafe fn branchless_merge_one_at_begin<F: Cmp<T>>(&mut self, is_less: &mut F) {
        unsafe {
            let left_scan = self.left_begin;
            let right_scan = self.right_begin;
            let right_less = is_less(&*right_scan, &*left_scan);

            // Copy the winner.
            let src = select(right_less, right_scan, left_scan);
            ptr::copy_nonoverlapping(src, self.dst.begin(), 1);
            self.dst.add_begin(1);

            // Exactly one input advances. A direct in-bounds `add` replaces the
            // former `wrapping_sub` / `wrapping_add(1)` round-trip, which briefly
            // formed a pointer one before `left_begin`.
            self.right_begin = self.right_begin.add(right_less as usize);
            self.left_begin = self.left_begin.add((!right_less) as usize);
        }
    }

    /// Like `branchless_merge_one_at_begin`, but tolerates one input being empty.
    ///
    /// If `left` is exhausted the element is taken from `right`, and if `right` is
    /// exhausted it is taken from `left`. An exhausted side's scan pointer is
    /// redirected to the other side's head, so the comparison only reads live
    /// elements.
    ///
    /// # Safety
    ///
    /// At least one of `left` / `right` must have an unmerged element, and `dst`
    /// must have room for one more element. The copy uses `ptr::copy`, which
    /// permits the source and destination slot to overlap.
    #[inline]
    pub unsafe fn branchless_merge_one_at_begin_imbalance_guarded<F: Cmp<T>>(
        &mut self,
        is_less: &mut F,
    ) {
        unsafe {
            let left_empty = self.left_begin == self.left_end;
            let right_nonempty = self.right_begin != self.right_end;
            let left_scan = select(left_empty, self.right_begin, self.left_begin);
            let right_scan = select(right_nonempty, self.right_begin, self.left_begin);
            let right_less = is_less(&*right_scan, &*left_scan);

            // Take from `right` if it won a real comparison, or if `left` is exhausted.
            let shrink_right = (right_less & right_nonempty) | left_empty;

            let src = select(right_less, right_scan, left_scan);
            ptr::copy(src, self.dst.begin(), 1);
            self.dst.add_begin(1);

            self.right_begin = self.right_begin.add(shrink_right as usize);
            self.left_begin = self.left_begin.add((!shrink_right) as usize);
        }
    }
}
impl<'l, 'r, 'dst, T, G: HasRightGap> BranchlessMergeState<'l, 'r, 'dst, T, G> {
    /// Merges one element into the back of the output.
    ///
    /// Compares the last unmerged elements of `left` and `right`. It copies the
    /// left one only when `is_less(right, left)` is true; otherwise it copies the
    /// right one, so equal elements keep their left-before-right order. It then
    /// shrinks `dst` from the end and shrinks the input the element came from.
    ///
    /// # Safety
    ///
    /// Both `left` and `right` must have at least one unmerged element, and `dst`
    /// must have room for one more element. The destination slot must not overlap
    /// either source element, because the copy is `copy_nonoverlapping`.
    #[inline(always)]
    pub unsafe fn branchless_merge_one_at_end<F: Cmp<T>>(&mut self, is_less: &mut F) {
        unsafe {
            let left_scan = self.left_end.sub(1);
            let right_scan = self.right_end.sub(1);
            let right_less = is_less(&*right_scan, &*left_scan);

            // Copy the larger element (ties go to `right`).
            let src = select(right_less, left_scan, right_scan);
            self.dst.sub_end(1);
            ptr::copy_nonoverlapping(src, self.dst.end(), 1);

            // Exactly one input shrinks. A direct in-bounds `sub` replaces the
            // former `wrapping_add` / `wrapping_sub(1)` round-trip, which briefly
            // formed a pointer one past `right_end`.
            self.left_end = self.left_end.sub(right_less as usize);
            self.right_end = self.right_end.sub((!right_less) as usize);
        }
    }

    /// Like `branchless_merge_one_at_end`, but tolerates one input being empty.
    ///
    /// If `right` is exhausted the element is taken from `left`, and if `left` is
    /// exhausted it is taken from `right`. An exhausted side's scan pointer is
    /// redirected to the other side's tail, so the comparison only reads live
    /// elements.
    ///
    /// # Safety
    ///
    /// At least one of `left` / `right` must have an unmerged element, and `dst`
    /// must have room for one more element. The copy uses `ptr::copy`, which
    /// permits the source and destination slot to overlap.
    #[inline]
    pub unsafe fn branchless_merge_one_at_end_imbalance_guarded<F: Cmp<T>>(
        &mut self,
        is_less: &mut F,
    ) {
        unsafe {
            let left_nonempty = self.left_begin != self.left_end;
            let right_empty = self.right_begin == self.right_end;
            let left_scan = select(left_nonempty, self.left_end, self.right_end).sub(1);
            let right_scan = select(right_empty, self.left_end, self.right_end).sub(1);
            let right_less = is_less(&*right_scan, &*left_scan);

            // Take from `left` if it won a real comparison, or if `right` is exhausted.
            let shrink_left = (right_less & left_nonempty) | right_empty;

            let src = select(right_less, left_scan, right_scan);
            self.dst.sub_end(1);
            ptr::copy(src, self.dst.end(), 1);

            self.left_end = self.left_end.sub(shrink_left as usize);
            self.right_end = self.right_end.sub((!shrink_left) as usize);
        }
    }
}
impl<'l, 'r, 'dst, T, G> BranchlessMergeState<'l, 'r, 'dst, T, G> {
    /// Returns `true` when no unmerged elements of `left` remain.
    ///
    /// Only `left` is inspected. NOTE(review): this presumably relies on being
    /// called once `dst` is fully written. At that point the `left + right == dst`
    /// length bookkeeping makes an empty `left` imply an empty `right`; confirm
    /// against callers.
    #[inline(always)]
    pub fn symmetric_merge_successful(&self) -> bool {
        self.left_end == self.left_begin
    }

    /// Number of unguarded merge operations that can run back to back without
    /// either input running out. This is the smaller of the two remaining
    /// lengths.
    ///
    /// Aborts if either remaining range has a negative length (crossed pointers).
    pub fn num_safe_merge_ops(&self) -> usize {
        // SAFETY: each begin/end pair was taken from the same input slice in `new`.
        let (left_len, right_len) = unsafe {
            (
                self.left_end.offset_from(self.left_begin),
                self.right_end.offset_from(self.right_begin),
            )
        };
        let shortest = left_len.min(right_len);
        if shortest.is_negative() {
            abort();
        }
        shortest as usize
    }
}
impl<'l, 'r, 'dst, T> BranchlessMergeState<'l, 'r, 'dst, T, GapLeft> {
    /// Merges front to back until `left` or `right` is exhausted.
    ///
    /// Each round issues exactly `num_safe_merge_ops()` front merges, unrolled by
    /// two. Whatever input remains is moved into place when `self` is dropped at
    /// the end of this call.
    #[inline(never)]
    pub fn finish_merge<F: Cmp<T>>(mut self, is_less: &mut F) {
        loop {
            let budget = self.num_safe_merge_ops();
            if budget == 0 {
                break;
            }
            // SAFETY: `budget` operations take at most `budget` elements from each
            // side, and both sides hold at least `budget` unmerged elements.
            unsafe {
                let mut pairs = budget / 2;
                while pairs != 0 {
                    self.branchless_merge_one_at_begin(is_less);
                    self.branchless_merge_one_at_begin(is_less);
                    pairs -= 1;
                }
                if budget % 2 == 1 {
                    self.branchless_merge_one_at_begin(is_less);
                }
            }
        }
    }
}
impl<'l, 'r, 'dst, T> BranchlessMergeState<'l, 'r, 'dst, T, GapRight> {
    /// Merges back to front until `left` or `right` is exhausted.
    ///
    /// Each round issues exactly `num_safe_merge_ops()` back merges, unrolled by
    /// two. Whatever input remains is moved into place when `self` is dropped at
    /// the end of this call.
    #[inline(never)]
    pub fn finish_merge<F: Cmp<T>>(mut self, is_less: &mut F) {
        let mut budget = self.num_safe_merge_ops();
        while budget > 0 {
            // SAFETY: `budget` operations take at most `budget` elements from each
            // side, and both sides hold at least `budget` unmerged elements.
            unsafe {
                for _ in 0..budget / 2 {
                    self.branchless_merge_one_at_end(is_less);
                    self.branchless_merge_one_at_end(is_less);
                }
                if budget % 2 != 0 {
                    self.branchless_merge_one_at_end(is_less);
                }
            }
            budget = self.num_safe_merge_ops();
        }
    }
}
impl<'l, 'r, 'dst, T> BranchlessMergeState<'l, 'r, 'dst, T, GapBoth> {
    /// Merges from both ends of `dst` until one input is exhausted. The remainder
    /// is moved into place when `self` is dropped.
    ///
    /// Each round issues `num_safe_merge_ops()` operations: groups of four
    /// alternating front/back merges, then up to three extra front merges. Front
    /// and back merges together never take more than that many elements from
    /// either input, so the two directions cannot cross.
    #[inline(never)]
    pub fn finish_merge<F: Cmp<T>>(mut self, is_less: &mut F) {
        loop {
            let budget = self.num_safe_merge_ops();
            if budget == 0 {
                return;
            }
            let (quads, leftover) = (budget / 4, budget % 4);
            // SAFETY: the `budget` operations below take at most `budget` elements
            // from each side, and both sides hold at least `budget` unmerged elements.
            unsafe {
                for _ in 0..quads {
                    self.branchless_merge_one_at_begin(is_less);
                    self.branchless_merge_one_at_end(is_less);
                    self.branchless_merge_one_at_begin(is_less);
                    self.branchless_merge_one_at_end(is_less);
                }
                for _ in 0..leftover {
                    self.branchless_merge_one_at_begin(is_less);
                }
            }
        }
    }

    /// Advances two independent merges in lockstep, alternating their operations.
    ///
    /// While both states allow at least two safe operations, each round gives
    /// every state front/back merge pairs, up to the smaller of the two budgets.
    /// Each state is then finished on its own via `finish_merge`.
    /// NOTE(review): interleaving is presumably for instruction-level
    /// parallelism between the two dependency chains.
    #[inline(never)]
    pub fn finish_merge_interleaved<F: Cmp<T>>(mut self, mut other: Self, is_less: &mut F) {
        loop {
            let mine = self.num_safe_merge_ops();
            let theirs = other.num_safe_merge_ops();
            let shared = mine.min(theirs);
            if shared < 2 {
                break;
            }
            // SAFETY: each state performs 2 * (shared / 2) <= shared operations,
            // and both of its inputs hold at least `shared` unmerged elements.
            unsafe {
                for _ in 0..shared / 2 {
                    self.branchless_merge_one_at_begin(is_less);
                    other.branchless_merge_one_at_begin(is_less);
                    self.branchless_merge_one_at_end(is_less);
                    other.branchless_merge_one_at_end(is_less);
                }
            }
        }
        self.finish_merge(is_less);
        other.finish_merge(is_less);
    }
}
impl<'l, 'r, 'dst, T, G> Drop for BranchlessMergeState<'l, 'r, 'dst, T, G> {
    /// Moves every element that was not merged into the remaining `dst` window.
    /// The rest of `left` goes first, immediately followed by the rest of `right`.
    ///
    /// This runs both after a normal finish and when the state is dropped early,
    /// for example while unwinding from a panicking comparison. It bails out via
    /// `unwrap_abort` / `assert_abort` if a remaining range has negative length
    /// or if the two remainders do not exactly fill `dst`.
    fn drop(&mut self) {
        unsafe {
            let rest_left: usize = self
                .left_end
                .offset_from(self.left_begin)
                .try_into()
                .unwrap_abort();
            let rest_right: usize = self
                .right_end
                .offset_from(self.right_begin)
                .try_into()
                .unwrap_abort();
            assert_abort(rest_left + rest_right == self.dst.len());

            // `ptr::copy` (not `_nonoverlapping`) is used, so a remainder that
            // already lies inside `dst` is handled correctly.
            let out = self.dst.begin();
            let right_slot = out.add(rest_left);
            ptr::copy(self.left_begin, out, rest_left);
            ptr::copy(self.right_begin, right_slot, rest_right);
        }
    }
}