use core::mem::MaybeUninit;
use crate::mut_slice::states::{AlwaysInit, Uninit};
use crate::mut_slice::{Brand, MutSlice};
use crate::physical_merges::{physical_merge, physical_quad_merge, physical_triple_merge};
use crate::stable_quicksort::quicksort;
#[cfg(feature = "tracking")]
use crate::tracking::ptr;
use crate::util::*;
use crate::{powersort, small_sort, tracking, SMALL_SORT};
/// A slice of the input together with what is currently known about its order.
///
/// Runs are "logical": combining two of them (see `logical_merge`) may only record
/// structure, such as concatenating unsorted data or remembering where two sorted
/// halves meet. The physical merge is postponed until later.
pub enum LogicalRun<'l, B: Brand, T> {
/// Elements in no known order; sorted later by `quicksort`.
Unsorted(MutSlice<'l, B, T, AlwaysInit>),
/// Elements already in sorted order.
Sorted(MutSlice<'l, B, T, AlwaysInit>),
/// Two sorted halves stored back to back in one slice. The `usize` is the length of
/// the first half, i.e. the index at which the slice is split to recover both halves.
DoubleSorted(MutSlice<'l, B, T, AlwaysInit>, usize),
}
impl<'l, B: Brand, T> LogicalRun<'l, B, T> {
    /// Number of elements covered by this run, whatever its variant.
    fn len(&self) -> usize {
        match self {
            LogicalRun::Unsorted(r) => r.len(),
            LogicalRun::Sorted(r) => r.len(),
            LogicalRun::DoubleSorted(r, _mid) => r.len(),
        }
    }

    /// Splits the next logical run off the front of `el`.
    ///
    /// A monotone prefix is accepted as a `Sorted` run when both hold:
    /// - its length is at least `SMALL_SORT`;
    /// - its squared length is at least `el.len() / 2`.
    ///
    /// A strictly descending prefix is reversed in place first. Otherwise the first
    /// `min(SMALL_SORT, el.len())` elements form the run. With `eager_smallsort` they are
    /// sorted immediately with `small_sort` and returned as `Sorted`; without it they
    /// are returned as `Unsorted`.
    ///
    /// Returns the new run and the remainder of `el` that follows it.
    fn create<F: Cmp<T>>(
        mut el: MutSlice<'l, B, T, AlwaysInit>,
        is_less: &mut F,
        eager_smallsort: bool,
    ) -> (Self, MutSlice<'l, B, T, AlwaysInit>) {
        if el.len() >= SMALL_SORT {
            let (run_length, descending) = run_length_at_start(el.as_mut_slice(), is_less);
            // `run_length * run_length` may overflow `usize` (on 32-bit targets as soon as
            // a run exceeds 65535 elements). That panics in debug builds and wraps in
            // release builds. Saturating keeps the comparison mathematically correct,
            // because an overflowing square is certainly >= `el.len() / 2`.
            let long_enough = run_length >= SMALL_SORT
                && run_length.saturating_mul(run_length) >= el.len() / 2;
            if long_enough {
                if descending {
                    // `run_length_at_start` only reports strictly descending runs, so
                    // reversing them never reorders equal elements.
                    #[cfg(feature = "tracking")]
                    {
                        for i in 0..run_length / 2 {
                            // SAFETY: `i < run_length / 2` implies `i < run_length - 1 - i`,
                            // so the two positions are distinct. Both are below
                            // `run_length <= el.len()`, which keeps them in bounds
                            // (assuming `el.begin()` points at the first element of `el`).
                            unsafe {
                                ptr::swap_nonoverlapping(
                                    el.begin().add(i),
                                    el.begin().add(run_length - 1 - i),
                                    1,
                                );
                            }
                        }
                    }
                    #[cfg(not(feature = "tracking"))]
                    {
                        el.as_mut_slice()[..run_length].reverse();
                    }
                }
                let (run, rest) = el.split_at(run_length).unwrap();
                return (LogicalRun::Sorted(run), rest);
            }
        }
        let skip = SMALL_SORT.min(el.len());
        let (mut run, rest) = el.split_at(skip).unwrap();
        if eager_smallsort {
            small_sort::small_sort(run.borrow(), is_less);
            (LogicalRun::Sorted(run), rest)
        } else {
            (LogicalRun::Unsorted(run), rest)
        }
    }

    /// Combines `self` with `right`, using `scratch` as temporary space.
    ///
    /// NOTE(review): `right` is expected to directly follow `self`, since the
    /// `Sorted`/`Unsorted` cases join both via `concat`. Confirm `concat`'s requirements.
    ///
    /// Physical work is deferred where possible:
    /// - two `Unsorted` runs that together fit in `scratch` are just concatenated;
    /// - any other `Unsorted` operand is quicksorted first, and the combination retried;
    /// - two `Sorted` runs become one `DoubleSorted` run that records the split point;
    /// - any combination with a `DoubleSorted` operand is resolved by a three- or
    ///   four-way physical merge into a single `Sorted` run.
    fn logical_merge<'sc, BS: Brand, F: Cmp<T>>(
        self,
        right: LogicalRun<'l, B, T>,
        mut scratch: MutSlice<'sc, BS, T, Uninit>,
        is_less: &mut F,
    ) -> LogicalRun<'l, B, T> {
        use LogicalRun::*;
        match (self, right) {
            (Unsorted(l), Unsorted(r)) if l.len() + r.len() <= scratch.len() => {
                Unsorted(l.concat(r))
            }
            (Unsorted(l), r) => {
                let l = quicksort(l, scratch.borrow(), is_less);
                Sorted(l).logical_merge(r, scratch, is_less)
            }
            (l, Unsorted(r)) => {
                let r = quicksort(r, scratch.borrow(), is_less);
                l.logical_merge(Sorted(r), scratch, is_less)
            }
            (Sorted(l), Sorted(r)) => {
                let mid = l.len();
                DoubleSorted(l.concat(r), mid)
            }
            (DoubleSorted(l, mid), Sorted(r)) => {
                let (l0, l1) = l.split_at(mid).unwrap();
                Sorted(physical_triple_merge(l0, l1, r, scratch, is_less))
            }
            (Sorted(l), DoubleSorted(r, mid)) => {
                let (r0, r1) = r.split_at(mid).unwrap();
                Sorted(physical_triple_merge(l, r0, r1, scratch, is_less))
            }
            (DoubleSorted(l, lmid), DoubleSorted(r, rmid)) => {
                let (l0, l1) = l.split_at(lmid).unwrap();
                let (r0, r1) = r.split_at(rmid).unwrap();
                Sorted(physical_quad_merge(l0, l1, r0, r1, scratch, is_less))
            }
        }
    }

    /// Turns this logical run into a physically sorted slice.
    ///
    /// `Sorted` is returned unchanged, `Unsorted` is quicksorted, and `DoubleSorted`
    /// has its two halves merged.
    fn physical_sort<'sc, BS: Brand, F: Cmp<T>>(
        self,
        scratch: MutSlice<'sc, BS, T, Uninit>,
        is_less: &mut F,
    ) -> MutSlice<'l, B, T, AlwaysInit> {
        match self {
            LogicalRun::Sorted(run) => run,
            LogicalRun::Unsorted(run) => quicksort(run, scratch, is_less),
            LogicalRun::DoubleSorted(run, mid) => {
                let (left, right) = run.split_at(mid).unwrap();
                physical_merge(left, right, scratch, is_less)
            }
        }
    }
}
/// Fixed-capacity (64-entry) stack of pending merge-tree nodes, used by `glidesort`.
///
/// Each entry pairs a left child with the desired merge depth computed when it was
/// pushed. The left child is a logical run still waiting to be merged with what
/// follows it. Only the first `len` slots of each array are initialized.
struct MergeStack<'l, B: Brand, T> {
    left_children: [MaybeUninit<LogicalRun<'l, B, T>>; 64],
    desired_depths: [MaybeUninit<u8>; 64],
    len: usize,
}

impl<'l, B: Brand, T> MergeStack<'l, B, T> {
    /// Creates an empty stack; no slot is initialized.
    fn new() -> Self {
        // SAFETY: an array of `MaybeUninit<_>` is valid in any bit state, so asserting
        // that an uninitialized such array is initialized is sound.
        let left_children: [MaybeUninit<LogicalRun<'l, B, T>>; 64] =
            unsafe { MaybeUninit::uninit().assume_init() };
        // SAFETY: same argument as for `left_children`.
        let desired_depths: [MaybeUninit<u8>; 64] =
            unsafe { MaybeUninit::uninit().assume_init() };
        Self {
            left_children,
            desired_depths,
            len: 0,
        }
    }

    /// Pushes a node on top of the stack.
    ///
    /// Panics (through the array bounds check) if 64 nodes are already stored.
    fn push_node(&mut self, left_child: LogicalRun<'l, B, T>, desired_depth: u8) {
        let slot = self.len;
        self.left_children[slot] = MaybeUninit::new(left_child);
        self.desired_depths[slot] = MaybeUninit::new(desired_depth);
        self.len = slot + 1;
    }

    /// Removes the topmost node and returns its left child, or `None` when empty.
    fn pop_node(&mut self) -> Option<LogicalRun<'l, B, T>> {
        let top = self.len.checked_sub(1)?;
        self.len = top;
        // SAFETY: `len <= 64` always holds, because `push_node` bounds-checks before
        // incrementing. So `top < 64` is in bounds, and every slot below the old `len`
        // was written by `push_node`. Lowering `len` first means this slot is not
        // read again until a new push overwrites it.
        Some(unsafe { self.left_children.get_unchecked(top).assume_init_read() })
    }

    /// Returns the desired depth of the topmost node without removing it.
    fn peek_desired_depth(&self) -> Option<u8> {
        let top = self.len.checked_sub(1)?;
        // SAFETY: `top < len <= 64`, and every slot below `len` was written by
        // `push_node`.
        Some(unsafe { self.desired_depths.get_unchecked(top).assume_init_read() })
    }
}
/// Sorts `el` in place with glidesort.
///
/// The input is consumed left to right, one logical run at a time via
/// `LogicalRun::create`, and the runs are combined along a powersort merge tree:
/// - every boundary between two consecutive runs gets a desired depth from
///   `powersort::merge_tree_depth`;
/// - pending left children on the merge stack are merged in while their depth is at
///   least the new boundary's depth;
/// - once the input is exhausted, the remaining stack is folded into the rightmost run;
/// - the resulting single logical run is physically sorted.
///
/// `scratch` is temporary space. If it holds fewer than `SMALL_SORT` elements, a buffer
/// with capacity `SMALL_SORT` is allocated and used in its place. `eager_smallsort` is
/// forwarded to `LogicalRun::create`.
pub fn glidesort<'el, 'sc, BE: Brand, BS: Brand, T, F: Cmp<T>>(
    mut el: MutSlice<'el, BE, T, AlwaysInit>,
    mut scratch: MutSlice<'sc, BS, T, Uninit>,
    is_less: &mut F,
    eager_smallsort: bool,
) {
    // Not enough scratch space: retry with a freshly allocated buffer.
    if scratch.len() < SMALL_SORT {
        let mut buffer = Vec::with_capacity(SMALL_SORT);
        let (_, spare) = split_at_spare_mut(&mut buffer);
        return MutSlice::from_maybeuninit_mut_slice(spare, |fresh_scratch| {
            glidesort(el, fresh_scratch.assume_uninit(), is_less, eager_smallsort)
        });
    }

    tracking::register_buffer("input", el.weak());
    tracking::register_buffer("scratch", scratch.weak());

    let scale_factor = powersort::merge_tree_scale_factor(el.len());
    let mut stack = MergeStack::new();

    let (mut current, rest) = LogicalRun::create(el, is_less, eager_smallsort);
    el = rest;
    let mut current_start = 0;

    while el.len() > 0 {
        let upcoming_start = current_start + current.len();
        let (upcoming, rest) = LogicalRun::create(el, is_less, eager_smallsort);
        el = rest;
        let upcoming_end = upcoming_start + upcoming.len();
        let depth =
            powersort::merge_tree_depth(current_start, upcoming_start, upcoming_end, scale_factor);

        // Merge every pending left child that sits at least as deep as this boundary.
        let mut node = current;
        while stack.peek_desired_depth().map_or(false, |top| top >= depth) {
            let left = stack.pop_node().unwrap();
            node = left.logical_merge(node, scratch.borrow(), is_less);
        }
        stack.push_node(node, depth);

        current_start = upcoming_start;
        current = upcoming;
    }

    // Fold whatever is still pending into the final run, top of stack first.
    let mut merged = current;
    while let Some(left) = stack.pop_node() {
        merged = left.logical_merge(merged, scratch.borrow(), is_less);
    }
    merged.physical_sort(scratch, is_less);

    tracking::deregister_buffer("input");
    tracking::deregister_buffer("scratch");
}
/// Measures the monotone run at the start of `v`.
///
/// Returns `(length, descending)`:
/// - when `v[1] < v[0]`, the run is the longest strictly descending prefix and
///   `descending` is `true`;
/// - otherwise it is the longest non-descending prefix and `descending` is `false`.
///
/// Slices shorter than two elements are reported as a non-descending run of their full
/// length without calling `is_less`.
fn run_length_at_start<T, F: Cmp<T>>(v: &[T], is_less: &mut F) -> (usize, bool) {
    if v.len() < 2 {
        return (v.len(), false);
    }
    let descending = is_less(&v[1], &v[0]);
    // Scan the pairs (v[1], v[2]), (v[2], v[3]), ... for the first one whose direction
    // disagrees with the direction of the first pair.
    let break_at = v[1..].windows(2).position(|pair| {
        let steps_down = is_less(&pair[1], &pair[0]);
        steps_down != descending
    });
    match break_at {
        // Window `k` covers v[k + 1] and v[k + 2], so the run ends before index k + 2.
        Some(k) => (k + 2, descending),
        None => (v.len(), descending),
    }
}