mod helpers;
#[cfg(test)]
mod tests;
mod veclike;
pub use self::veclike::*;
/// Ordering and bookkeeping hooks consulted by the heap operations in this module.
///
/// An element for which [`lt`](Self::lt) reports `true` against its parent is moved
/// towards the root (index 0), so the root ends up being a least element under
/// this ordering.
pub trait BinaryHeapCtx<Element> {
    /// Returns `true` when `x` must sit closer to the root than `y`.
    fn lt(&mut self, x: &Element, y: &Element) -> bool;

    /// Notification that an element has just been stored at slot `new_index` of the
    /// heap's backing slice.
    ///
    /// The default implementation ignores the notification.
    fn on_move(&mut self, _element: &mut Element, _new_index: usize) {}
}

/// The unit context orders elements by their own comparison operators (smallest
/// at the root) and does not track moves.
impl<T: Ord> BinaryHeapCtx<T> for () {
    fn lt(&mut self, x: &T, y: &T) -> bool {
        // Comparing the references delegates to `T`'s `PartialOrd::lt`, exactly like `*x < *y`.
        x < y
    }
}
/// Binary-heap operations on any [`VecLike`] container.
///
/// Every operation takes a [`BinaryHeapCtx`] that supplies the element ordering
/// and receives a notification whenever an element lands in a new slot.
pub trait BinaryHeap: VecLike {
    /// Appends `item` and sifts it towards the root until the heap property holds.
    ///
    /// Returns the index at which `item` finally rests.
    fn heap_push(&mut self, item: Self::Element, ctx: impl BinaryHeapCtx<Self::Element>) -> usize;

    /// Removes and returns the root (index 0), or `None` when the heap is empty.
    fn heap_pop(&mut self, ctx: impl BinaryHeapCtx<Self::Element>) -> Option<Self::Element>;

    /// Removes and returns the element stored at index `i`, or `None` when `i` is
    /// out of bounds. The last element fills the vacated slot and is re-sifted.
    fn heap_remove(
        &mut self,
        i: usize,
        ctx: impl BinaryHeapCtx<Self::Element>,
    ) -> Option<Self::Element>;
}
impl<T: VecLike> BinaryHeap for T {
    fn heap_pop(&mut self, ctx: impl BinaryHeapCtx<Self::Element>) -> Option<Self::Element> {
        // The root always lives at index 0.
        self.heap_remove(0, ctx)
    }

    fn heap_remove(
        &mut self,
        i: usize,
        mut ctx: impl BinaryHeapCtx<Self::Element>,
    ) -> Option<Self::Element> {
        if self.len() <= i {
            return None;
        }
        let mut removed = match self.pop() {
            Some(last) => last,
            None => {
                // `i < len` was checked above, so the container cannot be empty here.
                debug_assert!(false);
                return None;
            }
        };
        let slice = &mut **self;
        if i >= slice.len() {
            // `i` was the final slot, so popping already removed the requested element.
            return Some(removed);
        }
        // Move the former last element into slot `i` and take the requested one out.
        core::mem::swap(&mut removed, &mut slice[i]);
        ctx.on_move(&mut slice[i], i);
        let goes_up = i != 0 && ctx.lt(&slice[i], &slice[(i - 1) / 2]);
        // SAFETY: `i < slice.len()` was established just above.
        unsafe {
            if goes_up {
                sift_up(slice, 0, i, ctx);
            } else {
                sift_down(slice, i, ctx);
            }
        }
        Some(removed)
    }

    fn heap_push(&mut self, item: Self::Element, ctx: impl BinaryHeapCtx<Self::Element>) -> usize {
        let new_index = self.len();
        self.push(item);
        let slice = &mut **self;
        // Checked explicitly because the unsafe `sift_up` below relies on it.
        assert!(new_index < slice.len());
        // SAFETY: the assertion above guarantees `new_index` is in bounds.
        unsafe { sift_up(slice, 0, new_index, ctx) }
    }
}
/// Moves the element at `pos` towards the root while it orders strictly before
/// its parent, never going above index `start`.
///
/// Each parent pushed down a level gets an `on_move` call with its new index. The
/// sifted element gets a final `on_move` call with its resting index, even when it
/// did not move. Returns that resting index.
///
/// # Safety
///
/// `pos` must be in bounds for `this` (`pos < this.len()`). It is passed unchecked
/// to `helpers::Hole::new`, whose exact contract lives in `helpers` (presumably this
/// bound — confirm there).
unsafe fn sift_up<Element>(
    this: &mut [Element],
    start: usize,
    pos: usize,
    mut ctx: impl BinaryHeapCtx<Element>,
) -> usize {
    // SAFETY: the caller guarantees `pos < this.len()`. Every parent index used below
    // is `(p - 1) / 2 < p` for the hole's current position `p > 0`, so it is in bounds
    // and never aliases the hole.
    unsafe {
        let mut hole = helpers::Hole::new(this, pos);
        loop {
            let current = hole.pos();
            if current <= start {
                break;
            }
            let parent = (current - 1) / 2;
            let moves_up = ctx.lt(hole.element(), hole.get(parent));
            if !moves_up {
                break;
            }
            hole.move_to(parent);
            // The former parent now occupies `current`.
            ctx.on_move(hole.get_mut(current), current);
        }
        let resting = hole.pos();
        ctx.on_move(hole.element_mut(), resting);
        resting
    }
}
/// Moves the element at `pos` away from the root until neither child orders
/// strictly before it.
///
/// Each child promoted a level gets an `on_move` call with its new index. The
/// sifted element gets a final `on_move` call with its resting index, even when it
/// did not move.
///
/// # Safety
///
/// `pos` must be in bounds for `this` (`pos < this.len()`). It is passed unchecked
/// to `helpers::Hole::new`, whose exact contract lives in `helpers` (presumably this
/// bound — confirm there).
unsafe fn sift_down<Element>(
    this: &mut [Element],
    pos: usize,
    mut ctx: impl BinaryHeapCtx<Element>,
) {
    // Index of the left child of `parent`, saturating instead of overflowing.
    // A saturated value (`usize::MAX`) is never `< end`, so it correctly means
    // "no child". Plain `2 * parent + 1` overflows for huge slices of zero-sized
    // elements: it panics in debug builds and wraps to a bogus index in release.
    fn left_child(parent: usize) -> usize {
        parent.saturating_mul(2).saturating_add(1)
    }

    let end = this.len();
    // SAFETY: the caller guarantees `pos < this.len()`. Every child index handed to
    // the hole is checked `< end` and is strictly greater than the hole's current
    // position, so it is in bounds and never aliases the hole.
    unsafe {
        let mut hole = helpers::Hole::new(this, pos);
        let mut child = left_child(pos);
        while child < end {
            // Cannot overflow: `child < end <= usize::MAX`.
            let right = child + 1;
            // Follow the child that orders first; ties go to the right child.
            if right < end && !ctx.lt(hole.get(child), hole.get(right)) {
                child = right;
            }
            // Stop once the chosen child does not order strictly before the element.
            if !ctx.lt(hole.get(child), hole.element()) {
                break;
            }
            let prev_pos = hole.pos();
            hole.move_to(child);
            // The promoted child now occupies `prev_pos`.
            ctx.on_move(hole.get_mut(prev_pos), prev_pos);
            child = left_child(hole.pos());
        }
        let pos = hole.pos();
        ctx.on_move(hole.element_mut(), pos);
    }
}