#![deny(unsafe_op_in_unsafe_fn)]
use std::{
fmt,
mem::{ManuallyDrop, swap},
ops::{Deref, DerefMut},
ptr,
};
use super::compare::Compare;
/// A binary heap whose ordering is defined by a comparator `C` instead of by
/// `T: Ord`.
///
/// The element the comparator ranks greatest sits at the top (index 0), so
/// `pop` yields elements from greatest to least according to `cmp`. Supplying a
/// reversed comparator therefore produces a min-heap.
pub struct BinaryHeap<T, C> {
    /// Elements in implicit binary-tree layout: the children of index `i` live
    /// at `2 * i + 1` and `2 * i + 2`, and its parent at `(i - 1) / 2`.
    data: Vec<T>,
    /// Comparator consulted for every ordering decision.
    cmp: C,
}
impl<T, C: Compare<T>> BinaryHeap<T, C> {
    /// Builds a heap from `vec`, ordered by the comparator `cmp`.
    ///
    /// The vector is heapified in place (via `rebuild`); an empty vector is
    /// taken as-is.
    pub fn from_vec_cmp(vec: Vec<T>, cmp: C) -> Self {
        let mut heap = Self { data: vec, cmp };
        if !heap.data.is_empty() {
            heap.rebuild();
        }
        heap
    }

    /// Returns a guard giving mutable access to the top element, or `None` if
    /// the heap is empty.
    ///
    /// If the element is mutated through the guard, the heap order is restored
    /// when the guard is dropped.
    pub fn peek_mut(&mut self) -> Option<PeekMut<'_, T, C>> {
        if self.is_empty() {
            None
        } else {
            Some(PeekMut {
                heap: self,
                sift: false,
            })
        }
    }

    /// Removes and returns the top element (the greatest according to the
    /// comparator), or `None` if the heap is empty.
    pub fn pop(&mut self) -> Option<T> {
        self.data.pop().map(|mut item| {
            if !self.is_empty() {
                // Move the former last element into the root and take the old
                // root out; `item` now holds the value being returned.
                swap(&mut item, &mut self.data[0]);
                // SAFETY: the heap is non-empty here, so index 0 is in bounds.
                unsafe { self.sift_down_to_bottom(0) };
            }
            item
        })
    }

    /// Inserts `item`, then sifts it up to its ordered position.
    pub fn push(&mut self, item: T) {
        let old_len = self.len();
        self.data.push(item);
        // SAFETY: `old_len` is the index of the element just pushed, so it is
        // strictly less than the new length.
        unsafe { self.sift_up(0, old_len) };
    }

    /// Moves the element at `pos` towards the root while the comparator ranks
    /// it above its parent, never rising above index `start`.
    ///
    /// Returns the element's final index.
    ///
    /// # Safety
    ///
    /// The caller must guarantee `pos < self.len()`.
    unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
        // SAFETY: `pos` is in bounds per this function's contract. The hole's
        // `Drop` writes the element back even if the comparator panics.
        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
        while hole.pos() > start {
            let parent = (hole.pos() - 1) / 2;
            // SAFETY (for `get`): `parent < hole.pos()`, so it is in bounds and
            // distinct from the hole.
            if self
                .cmp
                .compares_le(hole.element(), unsafe { hole.get(parent) })
            {
                break;
            }
            // SAFETY: same bounds/distinctness argument as above.
            unsafe { hole.move_to(parent) };
        }
        hole.pos()
    }

    /// Moves the element at `pos` down, always towards the child the
    /// comparator ranks higher, until neither child outranks it. Only indices
    /// `< end` are treated as part of the heap.
    ///
    /// # Safety
    ///
    /// The caller must guarantee `pos < end <= self.len()`.
    unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
        // SAFETY: `pos < end <= len` per this function's contract.
        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
        // Invariant: `child == 2 * hole.pos() + 1` (the hole's left child).
        let mut child = 2 * hole.pos() + 1;
        // Loop while both children lie inside the range (`child + 1 < end`);
        // `saturating_sub` avoids underflow when `end < 2`.
        while child <= end.saturating_sub(2) {
            // Step to the right child when the left one does not outrank it.
            // SAFETY: `child + 1 < end <= len`, and both children differ from
            // the hole's position.
            child += unsafe { self.cmp.compares_le(hole.get(child), hole.get(child + 1)) } as usize;
            // Heap order already holds at this level: stop.
            // SAFETY (for `get`): `child < end <= len` and `child != hole.pos()`.
            if self
                .cmp
                .compares_ge(hole.element(), unsafe { hole.get(child) })
            {
                return;
            }
            // SAFETY: same bounds/distinctness argument as above.
            unsafe { hole.move_to(child) };
            child = 2 * hole.pos() + 1;
        }
        // Handle a lone left child sitting at the last index of the range.
        // `end >= 1` because `pos < end`, so `end - 1` cannot underflow.
        if child == end - 1
            && self
                .cmp
                .compares_lt(hole.element(), unsafe { hole.get(child) })
        {
            // SAFETY: `child == end - 1 < len` and `child > hole.pos()`.
            unsafe { hole.move_to(child) };
        }
    }

    /// Sifts the element at `pos` down over the whole heap.
    ///
    /// # Safety
    ///
    /// The caller must guarantee `pos < self.len()`.
    unsafe fn sift_down(&mut self, pos: usize) {
        let len = self.len();
        // SAFETY: `pos < len` per this function's contract, and `len` is the
        // full length.
        unsafe { self.sift_down_range(pos, len) };
    }

    /// Sifts the element at `pos` all the way down to a leaf, always following
    /// the higher-ranked child and without comparing against the element
    /// itself. It then sifts the element back up, but not above `pos`.
    ///
    /// Used by `pop`, where the new root is the former last element.
    ///
    /// # Safety
    ///
    /// The caller must guarantee `pos < self.len()`.
    unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
        let end = self.len();
        let start = pos;
        // SAFETY: `pos < len` per this function's contract.
        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
        let mut child = 2 * hole.pos() + 1;
        while child <= end.saturating_sub(2) {
            // SAFETY: `child + 1 < end == len`, and both children differ from
            // the hole's position.
            child += unsafe { self.cmp.compares_le(hole.get(child), hole.get(child + 1)) } as usize;
            // SAFETY: `child < len` and `child != hole.pos()`.
            unsafe { hole.move_to(child) };
            child = 2 * hole.pos() + 1;
        }
        // Lone left child at the last index; `end >= 1` since `pos < end`.
        if child == end - 1 {
            // SAFETY: `child == len - 1` and `child > hole.pos()`.
            unsafe { hole.move_to(child) };
        }
        pos = hole.pos();
        // Write the element back before sifting it up again.
        drop(hole);
        // SAFETY: `pos` is the hole's final index, which is in bounds.
        unsafe { self.sift_up(start, pos) };
    }

    /// Restores heap order over the whole vector by bottom-up heapify.
    ///
    /// Every index that has at least one child (`0..len / 2`) is sifted down,
    /// starting from the last such index and ending at the root.
    fn rebuild(&mut self) {
        let mut n = self.len() / 2;
        while n > 0 {
            n -= 1;
            // SAFETY: `n < len / 2 <= len`.
            unsafe { self.sift_down(n) };
        }
    }
}
impl<T, C> BinaryHeap<T, C> {
    /// Returns how many elements the heap currently holds.
    #[must_use]
    pub fn len(&self) -> usize {
        let storage = &self.data;
        storage.len()
    }

    /// Returns `true` when the heap holds no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Removes (and drops) every element in the heap.
    pub fn clear(&mut self) {
        let storage = &mut self.data;
        storage.clear();
    }
}
impl<T: fmt::Debug, C> fmt::Debug for BinaryHeap<T, C> {
    /// Formats the heap as a list of its elements in internal storage order
    /// (the heap layout, not sorted order).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut list = f.debug_list();
        list.entries(&self.data);
        list.finish()
    }
}
impl<T: Clone, C: Clone> Clone for BinaryHeap<T, C> {
    /// Returns an independent copy of the heap: its elements (in the same
    /// internal order) and its comparator.
    fn clone(&self) -> Self {
        Self {
            data: self.data.clone(),
            cmp: self.cmp.clone(),
        }
    }

    /// Overwrites `self` with a copy of `source`.
    ///
    /// Forwarding to `Vec::clone_from` lets the existing buffer be reused
    /// instead of allocating a fresh one, as the default
    /// `*self = source.clone()` would.
    fn clone_from(&mut self, source: &Self) {
        self.data.clone_from(&source.data);
        self.cmp.clone_from(&source.cmp);
    }
}
/// Guard returned by [`BinaryHeap::peek_mut`] that gives mutable access to the
/// top element of a heap.
///
/// Mutating the element through `DerefMut` marks the heap for a re-sift, which
/// runs when the guard is dropped. If the guard is leaked (for example with
/// `std::mem::forget`) after a mutation, that re-sift never happens and the
/// heap order may be left broken.
pub struct PeekMut<'a, T: 'a, C: 'a + Compare<T>> {
    /// The heap being peeked into; it is non-empty when the guard is created.
    heap: &'a mut BinaryHeap<T, C>,
    /// Set once the top element may have been mutated; `Drop` then sifts it
    /// down to restore heap order.
    sift: bool,
}
impl<T: fmt::Debug, C: Compare<T>> fmt::Debug for PeekMut<'_, T, C> {
    /// Formats the guard as `PeekMut(<top element>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let top = &self.heap.data[0];
        let mut tuple = f.debug_tuple("PeekMut");
        tuple.field(top);
        tuple.finish()
    }
}
impl<T, C: Compare<T>> Drop for PeekMut<'_, T, C> {
    /// Re-establishes heap order if the top element may have been mutated.
    fn drop(&mut self) {
        if !self.sift {
            // Nothing was handed out mutably, so the heap is still ordered.
            return;
        }
        // SAFETY: a `PeekMut` is only created for a non-empty heap, and the
        // exclusive borrow prevents other code from emptying it. `PeekMut::pop`,
        // the only path that removes an element through the guard, resets
        // `sift` before the guard is dropped. So index 0 is in bounds here.
        unsafe { self.heap.sift_down(0) };
    }
}
impl<T, C: Compare<T>> Deref for PeekMut<'_, T, C> {
    type Target = T;

    /// Borrows the top element of the heap.
    fn deref(&self) -> &T {
        debug_assert!(!self.heap.is_empty());
        let slots: &[T] = &self.heap.data;
        // SAFETY: a `PeekMut` only exists for a non-empty heap, so index 0 is
        // in bounds.
        unsafe { slots.get_unchecked(0) }
    }
}
impl<T, C: Compare<T>> DerefMut for PeekMut<'_, T, C> {
    /// Mutably borrows the top element and arms the guard so that heap order
    /// is restored on drop.
    fn deref_mut(&mut self) -> &mut T {
        debug_assert!(!self.heap.is_empty());
        // The caller may change the element's rank; re-sift when dropped.
        self.sift = true;
        let slots: &mut [T] = &mut self.heap.data;
        // SAFETY: a `PeekMut` only exists for a non-empty heap, so index 0 is
        // in bounds.
        unsafe { slots.get_unchecked_mut(0) }
    }
}
impl<'a, T, C: Compare<T>> PeekMut<'a, T, C> {
    /// Removes the peeked-at top element from the heap and returns it.
    ///
    /// The returned value is exactly the element exposed by this guard, so any
    /// mutation made through the guard does not require a re-sift.
    ///
    /// # Panics
    ///
    /// Propagates any panic raised by the comparator while the remaining
    /// elements are reordered.
    pub fn pop(mut this: Self) -> T {
        // Disarm the guard *before* popping. `BinaryHeap::pop` removes the
        // (possibly mutated) top element, so no re-sift is needed afterwards.
        // Clearing the flag first also guarantees that if the comparator
        // panics inside `pop`, this guard's `Drop` does not call the
        // comparator again during unwinding; a second panic there would abort.
        this.sift = false;
        this.heap
            .pop()
            .expect("PeekMut is only created for a non-empty heap")
    }
}
/// A slice with a "hole" at `pos`: the element originally stored there has
/// been moved out into `elt`, leaving that slot logically vacant. The slot may
/// only be overwritten, never read as an owned value or dropped.
///
/// On drop, `elt` is written back into the slot at the current `pos`. The slice
/// is therefore fully populated again even if a comparator panics mid-sift.
struct Hole<'a, T: 'a> {
    /// The slice being sifted.
    data: &'a mut [T],
    /// The element taken out of the slice. It is wrapped in `ManuallyDrop`
    /// because ownership goes back to the slice in `Drop`; it is never dropped
    /// here.
    elt: ManuallyDrop<T>,
    /// Current index of the hole within `data`.
    pos: usize,
}
impl<'a, T> Hole<'a, T> {
    /// Creates a hole at `pos` by moving the element there out into `elt`.
    ///
    /// # Safety
    ///
    /// `pos` must be in bounds for `data`; this is only checked by a
    /// `debug_assert!`.
    #[inline]
    unsafe fn new(data: &'a mut [T], pos: usize) -> Self {
        debug_assert!(pos < data.len());
        // SAFETY: `pos` is in bounds per the caller's contract. The bitwise
        // copy does not create a second owner in practice, because the source
        // slot is only ever overwritten (by `move_to` or `Drop`), never dropped.
        let elt = unsafe { ptr::read(data.get_unchecked(pos)) };
        Hole {
            data,
            elt: ManuallyDrop::new(elt),
            pos,
        }
    }

    /// Current index of the hole.
    #[inline]
    fn pos(&self) -> usize {
        self.pos
    }

    /// The element that was removed and is being sifted.
    #[inline]
    fn element(&self) -> &T {
        &self.elt
    }

    /// Returns a reference to the element stored at `index`.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds and must differ from the hole's position; both
    /// are only checked by `debug_assert!`.
    #[inline]
    unsafe fn get(&self, index: usize) -> &T {
        debug_assert!(index != self.pos);
        debug_assert!(index < self.data.len());
        // SAFETY: `index` is in bounds and not the vacant slot, per the
        // caller's contract.
        unsafe { self.data.get_unchecked(index) }
    }

    /// Moves the element at `index` into the hole; the hole then sits at
    /// `index`.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds and must differ from the hole's position; both
    /// are only checked by `debug_assert!`.
    #[inline]
    unsafe fn move_to(&mut self, index: usize) {
        debug_assert!(index != self.pos);
        debug_assert!(index < self.data.len());
        // SAFETY: `index` is in bounds per the contract, and `self.pos` is in
        // bounds (set by `new`, then only ever assigned in-bounds indices).
        // The two indices differ, so the one-element source and destination
        // do not overlap.
        unsafe {
            let ptr = self.data.as_mut_ptr();
            let index_ptr: *const _ = ptr.add(index);
            let hole_ptr = ptr.add(self.pos);
            ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1);
        }
        self.pos = index;
    }
}
impl<T> Drop for Hole<'_, T> {
    /// Fills the hole by writing the removed element back at the hole's
    /// current position, so the slice is fully populated again.
    #[inline]
    fn drop(&mut self) {
        // SAFETY: `self.pos` is always in bounds (established by `Hole::new`,
        // preserved by `move_to`). `elt` is a separate value, so the copy
        // cannot overlap the destination. Because `elt` is `ManuallyDrop`, the
        // value is not also dropped here; ownership passes to the slice.
        unsafe {
            let pos = self.pos;
            ptr::copy_nonoverlapping(
                ptr::from_ref(&*self.elt),
                self.data.get_unchecked_mut(pos),
                1,
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;

    use rstest::rstest;

    use super::*;

    /// Natural order: the heap pops the largest value first.
    #[derive(Clone)]
    struct MaxComparator;

    impl Compare<i32> for MaxComparator {
        fn compare(&self, a: &i32, b: &i32) -> Ordering {
            a.cmp(b)
        }
    }

    /// Reversed order: the heap pops the smallest value first.
    struct MinComparator;

    impl Compare<i32> for MinComparator {
        fn compare(&self, a: &i32, b: &i32) -> Ordering {
            b.cmp(a)
        }
    }

    /// Pops every element, returning them in pop order.
    fn drain<C: Compare<i32>>(heap: &mut BinaryHeap<i32, C>) -> Vec<i32> {
        std::iter::from_fn(|| heap.pop()).collect()
    }

    #[rstest]
    fn test_max_heap() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![], MaxComparator);
        heap.push(3);
        heap.push(1);
        heap.push(5);
        assert_eq!(heap.pop(), Some(5));
        assert_eq!(heap.pop(), Some(3));
        assert_eq!(heap.pop(), Some(1));
        assert_eq!(heap.pop(), None);
    }

    #[rstest]
    fn test_min_heap() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![], MinComparator);
        heap.push(3);
        heap.push(1);
        heap.push(5);
        assert_eq!(heap.pop(), Some(1));
        assert_eq!(heap.pop(), Some(3));
        assert_eq!(heap.pop(), Some(5));
        assert_eq!(heap.pop(), None);
    }

    #[rstest]
    fn test_peek_mut() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![1, 5, 2], MaxComparator);
        {
            let mut top = heap.peek_mut().expect("heap is non-empty");
            assert_eq!(*top, 5);
            *top = 0;
        }
        // Dropping the guard must have re-sifted the lowered root.
        assert_eq!(drain(&mut heap), vec![2, 1, 0]);
    }

    #[rstest]
    fn test_peek_mut_read_only() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![4, 8, 6], MaxComparator);
        {
            let top = heap.peek_mut().expect("heap is non-empty");
            assert_eq!(*top, 8);
        }
        assert_eq!(drain(&mut heap), vec![8, 6, 4]);
    }

    #[rstest]
    fn test_peek_mut_empty() {
        let mut heap = BinaryHeap::from_vec_cmp(Vec::new(), MaxComparator);
        assert!(heap.peek_mut().is_none());
    }

    #[rstest]
    fn test_peek_mut_pop() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![1, 5, 2], MaxComparator);
        let top = heap.peek_mut().expect("heap is non-empty");
        assert_eq!(PeekMut::pop(top), 5);
        assert_eq!(drain(&mut heap), vec![2, 1]);
    }

    #[rstest]
    fn test_peek_mut_pop_after_mutation() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![1, 5, 2], MaxComparator);
        let mut top = heap.peek_mut().expect("heap is non-empty");
        *top = 0;
        // The mutated element itself is returned; the rest stays ordered.
        assert_eq!(PeekMut::pop(top), 0);
        assert_eq!(drain(&mut heap), vec![2, 1]);
    }

    #[rstest]
    fn test_clear() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![1, 2, 3], MaxComparator);
        assert!(!heap.is_empty());
        heap.clear();
        assert!(heap.is_empty());
        assert_eq!(heap.len(), 0);
        assert_eq!(heap.pop(), None);
        heap.push(7);
        assert_eq!(heap.pop(), Some(7));
    }

    #[rstest]
    fn test_from_vec() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![3, 1, 4, 1, 5, 9, 2, 6], MaxComparator);
        assert_eq!(heap.len(), 8);
        assert_eq!(drain(&mut heap), vec![9, 6, 5, 4, 3, 2, 1, 1]);
    }

    #[rstest]
    fn test_matches_sorted_order() {
        // 37 is coprime with the prime 101, so this yields 100 distinct values
        // in a scrambled order.
        let input: Vec<i32> = (0..100).map(|i| (i * 37) % 101).collect();
        let mut ascending = input.clone();
        ascending.sort_unstable();
        let mut descending = ascending.clone();
        descending.reverse();

        let mut min_heap = BinaryHeap::from_vec_cmp(input.clone(), MinComparator);
        assert_eq!(drain(&mut min_heap), ascending);

        let mut max_heap = BinaryHeap::from_vec_cmp(Vec::new(), MaxComparator);
        for &v in &input {
            max_heap.push(v);
        }
        assert_eq!(drain(&mut max_heap), descending);
    }

    #[rstest]
    fn test_clone_is_independent() {
        let mut original = BinaryHeap::from_vec_cmp(vec![2, 7, 4], MaxComparator);
        let mut copy = original.clone();
        assert_eq!(original.pop(), Some(7));
        assert_eq!(drain(&mut copy), vec![7, 4, 2]);

        let mut target = BinaryHeap::from_vec_cmp(vec![100, 200], MaxComparator);
        target.clone_from(&original);
        assert_eq!(drain(&mut target), vec![4, 2]);
    }

    #[rstest]
    fn test_debug() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![5, 1], MaxComparator);
        assert_eq!(format!("{heap:?}"), "[5, 1]");
        let top = heap.peek_mut().expect("heap is non-empty");
        assert_eq!(format!("{top:?}"), "PeekMut(5)");
    }
}