use core::{
cmp::Ordering,
fmt,
marker::PhantomData,
mem::{self, ManuallyDrop},
ops::{Deref, DerefMut},
ptr, slice,
};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
use crate::vec::{OwnedVecStorage, Vec, VecInner, VecStorage, ViewVecStorage};
/// Marker type selecting a min-heap: the smallest element is kept at the top.
pub enum Min {}

/// Marker type selecting a max-heap: the largest element is kept at the top.
pub enum Max {}

mod private {
    /// Supertrait that keeps [`Kind`](super::Kind) from being implemented outside this module.
    pub trait Sealed {}
}

/// Ordering strategy of a binary heap: decides which element sits at the top.
///
/// The trait is sealed; [`Min`] and [`Max`] are its only implementors.
pub trait Kind: private::Sealed {
    /// The result `a.cmp(b)` must produce for `a` to be placed above `b` in the heap.
    #[doc(hidden)]
    fn ordering() -> Ordering;
}

impl private::Sealed for Min {}

impl Kind for Min {
    // Smaller elements rise toward the root.
    fn ordering() -> Ordering {
        Ordering::Less
    }
}

impl private::Sealed for Max {}

impl Kind for Max {
    // Larger elements rise toward the root.
    fn ordering() -> Ordering {
        Ordering::Greater
    }
}
/// Storage-generic binary heap whose order is selected by the marker `K`
/// ([`Min`] or [`Max`]).
///
/// Use the [`BinaryHeap`] alias for a heap that owns storage for `N` elements,
/// and [`BinaryHeapView`] for a reference whose type does not carry `N`.
#[cfg_attr(
feature = "zeroize",
derive(Zeroize),
zeroize(bound = "T: Zeroize, S: Zeroize")
)]
pub struct BinaryHeapInner<T, K, S: VecStorage<T> + ?Sized> {
// Zero-sized marker recording whether this is a min-heap or a max-heap.
pub(crate) _kind: PhantomData<K>,
// Elements in array-encoded binary-tree layout: the children of index `i`
// are at `2 * i + 1` and `2 * i + 2`.
pub(crate) data: VecInner<T, usize, S>,
}
/// A binary heap that owns storage for up to `N` elements.
pub type BinaryHeap<T, K, const N: usize> = BinaryHeapInner<T, K, OwnedVecStorage<T, N>>;
/// A binary heap whose type does not carry the capacity `N`; obtain one from a
/// [`BinaryHeap`] with `as_view` / `as_mut_view`.
pub type BinaryHeapView<T, K> = BinaryHeapInner<T, K, ViewVecStorage<T>>;
impl<T, K, const N: usize> BinaryHeap<T, K, N> {
pub const fn new() -> Self {
Self {
_kind: PhantomData,
data: Vec::new(),
}
}
}
impl<T, K, const N: usize> BinaryHeap<T, K, N> {
pub fn into_vec(self) -> Vec<T, N, usize> {
self.data
}
}
impl<T, K, S: VecStorage<T>> BinaryHeapInner<T, K, S> {
/// Returns this heap as a shared [`BinaryHeapView`], whose type does not carry
/// the capacity `N`. The conversion is delegated to the storage type `S`.
pub fn as_view(&self) -> &BinaryHeapView<T, K> {
S::as_binary_heap_view(self)
}
/// Returns this heap as a mutable [`BinaryHeapView`]. The conversion is
/// delegated to the storage type `S`.
pub fn as_mut_view(&mut self) -> &mut BinaryHeapView<T, K> {
S::as_binary_heap_view_mut(self)
}
}
/// Core heap operations, available for every storage `S` (owned or view).
///
/// `data` is kept in array-encoded binary-tree order: the children of index `i`
/// are at `2 * i + 1` and `2 * i + 2`, and no child belongs above its parent
/// according to `K` ([`Min`] or [`Max`]).
impl<T, K, S: VecStorage<T> + ?Sized> BinaryHeapInner<T, K, S>
where
T: Ord,
K: Kind,
{
/// Returns the maximum number of elements the heap can hold.
pub fn capacity(&self) -> usize {
self.data.capacity()
}
/// Removes all elements from the heap.
pub fn clear(&mut self) {
self.data.clear();
}
/// Returns the number of elements in the heap.
pub fn len(&self) -> usize {
self.data.len()
}
/// Returns `true` if the heap contains no elements.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns `true` if the heap holds `capacity()` elements; `push` then fails.
pub fn is_full(&self) -> bool {
self.len() == self.capacity()
}
/// Iterates over the elements in internal heap order (not sorted order).
pub fn iter(&self) -> slice::Iter<'_, T> {
self.data.as_slice().iter()
}
/// Mutably iterates over the elements in internal heap order.
///
/// Nothing re-establishes the heap order afterwards. Changing an element in a
/// way that alters how it compares with the others leaves the heap
/// inconsistent, and later `pop`s may return elements out of order.
pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> {
self.data.as_mut_slice().iter_mut()
}
/// Returns the top element (smallest for a [`Min`] heap, largest for a
/// [`Max`] heap), or `None` if the heap is empty.
pub fn peek(&self) -> Option<&T> {
self.data.as_slice().first()
}
/// Returns a guard giving mutable access to the top element, or `None` if the
/// heap is empty.
///
/// When the guard is dropped, the top element is sifted back into place, so
/// the heap order holds again even if the element was modified.
pub fn peek_mut(&mut self) -> Option<PeekMutInner<'_, T, K, S>> {
if self.is_empty() {
None
} else {
Some(PeekMutInner {
heap: self,
sift: true,
})
}
}
/// Removes and returns the top element, or `None` if the heap is empty.
pub fn pop(&mut self) -> Option<T> {
if self.is_empty() {
None
} else {
// SAFETY: the heap was just checked to be non-empty.
Some(unsafe { self.pop_unchecked() })
}
}
/// Removes and returns the top element without checking for emptiness.
///
/// # Safety
///
/// The heap must not be empty.
pub unsafe fn pop_unchecked(&mut self) -> T {
// Take the last leaf, then swap it into the root. `item` becomes the old top,
// and the former last leaf is re-sifted down from the root.
let mut item = self.data.pop_unchecked();
if !self.is_empty() {
mem::swap(&mut item, self.data.as_mut_slice().get_unchecked_mut(0));
self.sift_down_to_bottom(0);
}
item
}
/// Pushes `item` onto the heap, or hands it back as `Err(item)` if the heap is
/// full.
pub fn push(&mut self, item: T) -> Result<(), T> {
if self.data.is_full() {
return Err(item);
}
// SAFETY: the heap was just checked not to be full.
unsafe { self.push_unchecked(item) }
Ok(())
}
/// Pushes `item` onto the heap without checking capacity.
///
/// # Safety
///
/// The heap must not be full.
pub unsafe fn push_unchecked(&mut self, item: T) {
let old_len = self.len();
self.data.push_unchecked(item);
// The new element lands at index `old_len` (the last leaf); move it up.
self.sift_up(0, old_len);
}
/// Moves the element at `pos` down to a leaf, always following the child that
/// belongs higher, then sifts it back up (never above `pos`).
///
/// Each level costs one comparison between the two children and none against
/// the moving element. This suits elements expected to end near the bottom,
/// such as the last leaf that `pop_unchecked` places at the root.
/// `pos` must be in bounds.
fn sift_down_to_bottom(&mut self, mut pos: usize) {
let end = self.len();
let start = pos;
// SAFETY: `pos < end`, because both callers (`pop_unchecked` and the `PeekMutInner`
// drop) run on a non-empty heap. Every `child`/`right` index used below is
// checked `< end` and is strictly greater than the hole position.
unsafe {
let mut hole = Hole::new(self.data.as_mut_slice(), pos);
let mut child = 2 * pos + 1;
while child < end {
let right = child + 1;
// Pick the right child unless the left one strictly belongs higher.
if right < end && hole.get(child).cmp(hole.get(right)) != K::ordering() {
child = right;
}
// Pull that child up into the hole; the hole descends one level.
hole.move_to(child);
child = 2 * hole.pos() + 1;
}
pos = hole.pos;
}
// `hole` was dropped at the end of the block above, writing the element
// into the leaf slot `pos`; now move it back up as far as it belongs.
self.sift_up(start, pos);
}
/// Moves the element at `pos` toward the root while it strictly belongs above
/// its parent, never going above index `start`. Returns its final index.
fn sift_up(&mut self, start: usize, pos: usize) -> usize {
// SAFETY: callers pass `pos < len`. `parent < hole.pos()`, so it is in bounds
// and never equals the hole position.
unsafe {
let mut hole = Hole::new(self.data.as_mut_slice(), pos);
while hole.pos() > start {
let parent = (hole.pos() - 1) / 2;
// Stop as soon as the element does not strictly belong above its parent.
if hole.element().cmp(hole.get(parent)) != K::ordering() {
break;
}
hole.move_to(parent);
}
hole.pos()
}
}
}
/// A slice with a "hole" at index `pos`.
///
/// The element originally at `pos` has been read out bitwise into `elt`, so
/// `data[pos]` only holds a stale copy. Dropping the `Hole` writes `elt` back
/// into whatever slot the hole has moved to.
struct Hole<'a, T> {
data: &'a mut [T],
// The element taken out of the slice. `ManuallyDrop` because ownership goes
// back into the slice (by a bitwise copy) in `Drop`, not by dropping it here.
elt: ManuallyDrop<T>,
// Current index of the hole within `data`.
pos: usize,
}
impl<'a, T> Hole<'a, T> {
/// Creates a hole at `pos` by reading the element out of `data[pos]` bitwise.
///
/// The slot `data[pos]` keeps a stale copy. That copy is overwritten when the
/// hole moves, or when the hole is filled again on drop.
///
/// # Safety
///
/// `pos` must be in bounds of `data` (this is only debug-asserted).
#[inline]
unsafe fn new(data: &'a mut [T], pos: usize) -> Self {
debug_assert!(pos < data.len());
let elt = ptr::read(data.get_unchecked(pos));
Hole {
data,
elt: ManuallyDrop::new(elt),
pos,
}
}
/// Returns the current index of the hole.
#[inline]
fn pos(&self) -> usize {
self.pos
}
/// Returns the element that was taken out of the slice.
#[inline]
fn element(&self) -> &T {
&self.elt
}
/// Returns the element at `index`.
///
/// # Safety
///
/// `index` must be in bounds and must not equal the hole position, because
/// the hole slot only holds a stale copy. Both are only debug-asserted.
#[inline]
unsafe fn get(&self, index: usize) -> &T {
debug_assert!(index != self.pos);
debug_assert!(index < self.data.len());
self.data.get_unchecked(index)
}
/// Moves the element at `index` into the hole; the hole is then at `index`.
///
/// # Safety
///
/// `index` must be in bounds and must not equal the hole position (both only
/// debug-asserted).
#[inline]
unsafe fn move_to(&mut self, index: usize) {
debug_assert!(index != self.pos);
debug_assert!(index < self.data.len());
// Raw pointers address two slots of the same slice at once. The slots are
// distinct because `index != self.pos`, so the copy cannot overlap.
let ptr = self.data.as_mut_ptr();
let index_ptr: *const _ = ptr.add(index);
let hole_ptr = ptr.add(self.pos);
ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1);
self.pos = index;
}
}
/// Guard returned by `peek_mut` that gives mutable access to the top element of
/// a heap.
///
/// When the guard is dropped, the top element is sifted back into place, so a
/// modification made through `DerefMut` keeps the heap ordered. Use
/// [`PeekMutInner::pop`] to remove the element instead.
pub struct PeekMutInner<'a, T, K, S>
where
T: Ord,
K: Kind,
S: VecStorage<T> + ?Sized,
{
// The heap being peeked. It is non-empty while the guard exists: the guard is
// only created for a non-empty heap and holds the exclusive borrow.
heap: &'a mut BinaryHeapInner<T, K, S>,
// Whether `Drop` must re-sift the root; cleared by `PeekMutInner::pop`.
sift: bool,
}
/// [`PeekMutInner`] guard for a [`BinaryHeap`] with capacity `N`.
pub type PeekMut<'a, T, K, const N: usize> = PeekMutInner<'a, T, K, OwnedVecStorage<T, N>>;
/// [`PeekMutInner`] guard for a [`BinaryHeapView`].
pub type PeekMutView<'a, T, K> = PeekMutInner<'a, T, K, ViewVecStorage<T>>;
impl<T, K, S> Drop for PeekMutInner<'_, T, K, S>
where
    T: Ord,
    K: Kind,
    S: VecStorage<T> + ?Sized,
{
    /// Restores the heap order by sifting the root back into place, unless the
    /// guard was consumed by `PeekMutInner::pop`.
    fn drop(&mut self) {
        if !self.sift {
            return;
        }
        // The caller may have changed the top element through `DerefMut`.
        self.heap.sift_down_to_bottom(0);
    }
}
impl<T, K, S> Deref for PeekMutInner<'_, T, K, S>
where
    T: Ord,
    K: Kind,
    S: VecStorage<T> + ?Sized,
{
    type Target = T;

    /// Returns the current top element of the heap.
    fn deref(&self) -> &T {
        debug_assert!(!self.heap.is_empty());
        let elements = self.heap.data.as_slice();
        // SAFETY: a guard is only constructed by `peek_mut` for a non-empty heap,
        // and it holds the exclusive borrow, so index 0 is in bounds.
        unsafe { elements.get_unchecked(0) }
    }
}
impl<T, K, S> DerefMut for PeekMutInner<'_, T, K, S>
where
    T: Ord,
    K: Kind,
    S: VecStorage<T> + ?Sized,
{
    /// Returns the top element mutably; the guard's `Drop` re-sifts it afterwards.
    fn deref_mut(&mut self) -> &mut T {
        debug_assert!(!self.heap.is_empty());
        let elements = self.heap.data.as_mut_slice();
        // SAFETY: a guard is only constructed by `peek_mut` for a non-empty heap,
        // and it holds the exclusive borrow, so index 0 is in bounds.
        unsafe { elements.get_unchecked_mut(0) }
    }
}
impl<T, K, S> PeekMutInner<'_, T, K, S>
where
    T: Ord,
    K: Kind,
    S: VecStorage<T> + ?Sized,
{
    /// Removes the peeked (top) element from the heap and returns it.
    ///
    /// The guard is consumed. Its `Drop` does not re-sift, because `pop`
    /// already restores the heap order for the remaining elements.
    pub fn pop(mut this: Self) -> T {
        // Disarm the re-sift *before* popping. If `T::cmp` panics while `pop`
        // restores the heap, this guard's `Drop` runs during unwinding. With the
        // flag still set, it would call into the user's `Ord` again, and a
        // second panic during unwinding aborts the process.
        this.sift = false;
        this.heap
            .pop()
            .expect("PeekMut is only created for a non-empty heap")
    }
}
impl<T> Drop for Hole<'_, T> {
    /// Fills the hole by moving `elt` back into the slice at the hole's final
    /// position.
    #[inline]
    fn drop(&mut self) {
        let pos = self.pos;
        // SAFETY: `pos` is in bounds; callers of `Hole::new`/`move_to` guarantee
        // this, and it is debug-asserted there. `elt` lives outside the slice,
        // so source and destination cannot overlap. Overwriting without dropping
        // is correct because `data[pos]` only holds a stale bitwise copy.
        unsafe {
            ptr::copy_nonoverlapping(&*self.elt, self.data.get_unchecked_mut(pos), 1);
        }
    }
}
impl<T, K, const N: usize> Default for BinaryHeap<T, K, N>
where
T: Ord,
K: Kind,
{
fn default() -> Self {
Self::new()
}
}
impl<T, K, const N: usize> Clone for BinaryHeap<T, K, N>
where
K: Kind,
T: Ord + Clone,
{
fn clone(&self) -> Self {
Self {
_kind: self._kind,
data: self.data.clone(),
}
}
}
impl<T, K, S> fmt::Debug for BinaryHeapInner<T, K, S>
where
    K: Kind,
    T: Ord + fmt::Debug,
    S: VecStorage<T> + ?Sized,
{
    /// Formats the elements as a list, in internal heap order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let elements = self.data.as_slice();
        f.debug_list().entries(elements).finish()
    }
}
impl<'a, T, K, S> IntoIterator for &'a BinaryHeapInner<T, K, S>
where
    K: Kind,
    T: Ord,
    S: VecStorage<T> + ?Sized,
{
    type Item = &'a T;
    type IntoIter = slice::Iter<'a, T>;

    /// Borrows the elements in internal heap order, like `BinaryHeapInner::iter`.
    fn into_iter(self) -> Self::IntoIter {
        self.data.as_slice().iter()
    }
}
#[cfg(test)]
mod tests {
    use static_assertions::assert_not_impl_any;

    use super::{BinaryHeap, BinaryHeapView, Max, Min, PeekMut};

    // A heap of `!Send` elements must not itself be `Send`.
    assert_not_impl_any!(BinaryHeap<*const (), Max, 4>: Send);
    assert_not_impl_any!(BinaryHeap<*const (), Min, 4>: Send);

    #[test]
    fn static_new() {
        // `new` is a `const fn`, so it can initialize a static.
        static mut _B: BinaryHeap<i32, Min, 16> = BinaryHeap::new();
    }

    #[test]
    fn drop() {
        droppable!();

        {
            let mut v: BinaryHeap<Droppable, Max, 2> = BinaryHeap::new();
            v.push(Droppable::new()).ok().unwrap();
            v.push(Droppable::new()).ok().unwrap();
            v.pop().unwrap();
        }

        assert_eq!(Droppable::count(), 0);

        {
            let mut v: BinaryHeap<Droppable, Max, 2> = BinaryHeap::new();
            v.push(Droppable::new()).ok().unwrap();
            v.push(Droppable::new()).ok().unwrap();
        }

        assert_eq!(Droppable::count(), 0);

        {
            let mut v: BinaryHeap<Droppable, Min, 2> = BinaryHeap::new();
            v.push(Droppable::new()).ok().unwrap();
            v.push(Droppable::new()).ok().unwrap();
            v.pop().unwrap();
        }

        assert_eq!(Droppable::count(), 0);

        {
            let mut v: BinaryHeap<Droppable, Min, 2> = BinaryHeap::new();
            v.push(Droppable::new()).ok().unwrap();
            v.push(Droppable::new()).ok().unwrap();
        }

        assert_eq!(Droppable::count(), 0);
    }

    #[test]
    fn into_vec() {
        droppable!();

        let mut h: BinaryHeap<Droppable, Max, 2> = BinaryHeap::new();
        h.push(Droppable::new()).ok().unwrap();
        h.push(Droppable::new()).ok().unwrap();
        h.pop().unwrap();

        assert_eq!(Droppable::count(), 1);

        let v = h.into_vec();

        assert_eq!(Droppable::count(), 1);

        core::mem::drop(v);

        assert_eq!(Droppable::count(), 0);
    }

    #[test]
    fn min() {
        let mut heap = BinaryHeap::<_, Min, 16>::new();
        heap.push(1).unwrap();
        heap.push(2).unwrap();
        heap.push(3).unwrap();
        heap.push(17).unwrap();
        heap.push(19).unwrap();
        heap.push(36).unwrap();
        heap.push(7).unwrap();
        heap.push(25).unwrap();
        heap.push(100).unwrap();

        assert_eq!(
            heap.iter().cloned().collect::<Vec<_>>(),
            [1, 2, 3, 17, 19, 36, 7, 25, 100]
        );

        assert_eq!(heap.pop(), Some(1));

        assert_eq!(
            heap.iter().cloned().collect::<Vec<_>>(),
            [2, 17, 3, 25, 19, 36, 7, 100]
        );

        assert_eq!(heap.pop(), Some(2));
        assert_eq!(heap.pop(), Some(3));
        assert_eq!(heap.pop(), Some(7));
        assert_eq!(heap.pop(), Some(17));
        assert_eq!(heap.pop(), Some(19));
        assert_eq!(heap.pop(), Some(25));
        assert_eq!(heap.pop(), Some(36));
        assert_eq!(heap.pop(), Some(100));
        assert_eq!(heap.pop(), None);

        assert!(heap.peek_mut().is_none());

        heap.push(1).unwrap();
        heap.push(2).unwrap();
        heap.push(10).unwrap();

        {
            let mut val = heap.peek_mut().unwrap();
            *val = 7;
        }

        assert_eq!(heap.pop(), Some(2));
        assert_eq!(heap.pop(), Some(7));
        assert_eq!(heap.pop(), Some(10));
        assert_eq!(heap.pop(), None);
    }

    #[test]
    fn max() {
        let mut heap = BinaryHeap::<_, Max, 16>::new();
        heap.push(1).unwrap();
        heap.push(2).unwrap();
        heap.push(3).unwrap();
        heap.push(17).unwrap();
        heap.push(19).unwrap();
        heap.push(36).unwrap();
        heap.push(7).unwrap();
        heap.push(25).unwrap();
        heap.push(100).unwrap();

        assert_eq!(
            heap.iter().cloned().collect::<Vec<_>>(),
            [100, 36, 19, 25, 3, 2, 7, 1, 17]
        );

        assert_eq!(heap.pop(), Some(100));

        assert_eq!(
            heap.iter().cloned().collect::<Vec<_>>(),
            [36, 25, 19, 17, 3, 2, 7, 1]
        );

        assert_eq!(heap.pop(), Some(36));
        assert_eq!(heap.pop(), Some(25));
        assert_eq!(heap.pop(), Some(19));
        assert_eq!(heap.pop(), Some(17));
        assert_eq!(heap.pop(), Some(7));
        assert_eq!(heap.pop(), Some(3));
        assert_eq!(heap.pop(), Some(2));
        assert_eq!(heap.pop(), Some(1));
        assert_eq!(heap.pop(), None);

        assert!(heap.peek_mut().is_none());

        heap.push(1).unwrap();
        heap.push(9).unwrap();
        heap.push(10).unwrap();

        {
            let mut val = heap.peek_mut().unwrap();
            *val = 7;
        }

        assert_eq!(heap.pop(), Some(9));
        assert_eq!(heap.pop(), Some(7));
        assert_eq!(heap.pop(), Some(1));
        assert_eq!(heap.pop(), None);
    }

    #[test]
    fn push_full() {
        let mut heap = BinaryHeap::<_, Max, 2>::new();
        assert!(heap.is_empty());
        assert_eq!(heap.capacity(), 2);

        heap.push(1).unwrap();
        heap.push(2).unwrap();
        assert!(heap.is_full());

        // A full heap hands the rejected item back unchanged.
        assert_eq!(heap.push(3), Err(3));
        assert_eq!(heap.len(), 2);
        assert_eq!(heap.peek(), Some(&2));

        heap.clear();
        assert!(heap.is_empty());
        assert_eq!(heap.peek(), None);
    }

    #[test]
    fn peek_mut_pop() {
        let mut heap = BinaryHeap::<_, Min, 4>::new();
        heap.push(5).unwrap();
        heap.push(3).unwrap();
        heap.push(8).unwrap();

        let top = heap.peek_mut().unwrap();
        assert_eq!(*top, 3);
        assert_eq!(PeekMut::pop(top), 3);

        assert_eq!(heap.len(), 2);
        assert_eq!(heap.pop(), Some(5));
        assert_eq!(heap.pop(), Some(8));
        assert_eq!(heap.pop(), None);
    }

    #[test]
    fn view_shares_storage() {
        let mut heap = BinaryHeap::<_, Max, 4>::new();
        heap.push(1).unwrap();

        {
            let view: &mut BinaryHeapView<i32, Max> = heap.as_mut_view();
            view.push(5).unwrap();
            assert_eq!(view.peek(), Some(&5));
        }

        assert_eq!(heap.as_view().len(), 2);
        assert_eq!(heap.pop(), Some(5));
        assert_eq!(heap.pop(), Some(1));
    }

    #[test]
    #[cfg(feature = "zeroize")]
    fn test_binary_heap_zeroize() {
        use zeroize::Zeroize;

        let mut heap = BinaryHeap::<u8, Max, 8>::new();
        for i in 0..8 {
            heap.push(i).unwrap();
        }

        assert_eq!(heap.len(), 8);
        assert_eq!(heap.peek(), Some(&7));

        heap.zeroize();

        assert_eq!(heap.len(), 0);
    }

    // Compile-time checks that the heap stays covariant in `T`.
    fn _test_variance<'a: 'b, 'b>(x: BinaryHeap<&'a (), Max, 42>) -> BinaryHeap<&'b (), Max, 42> {
        x
    }

    fn _test_variance_view<'a: 'b, 'b, 'c>(
        x: &'c BinaryHeapView<&'a (), Max>,
    ) -> &'c BinaryHeapView<&'b (), Max> {
        x
    }
}