use core::cmp::Ordering;
use core::marker::PhantomData;
use core::mem::ManuallyDrop;
use core::{mem, ptr, slice};
use generic_array::ArrayLength;
use Vec;
/// Type-level marker selecting min-heap behaviour for `BinaryHeap`.
///
/// The enum has no variants, so no value of it can exist; it is only used as a type parameter.
pub enum Min {}
/// Type-level marker selecting max-heap behaviour for `BinaryHeap`.
///
/// The enum has no variants, so no value of it can exist; it is only used as a type parameter.
pub enum Max {}
/// The kind of a `BinaryHeap`; implemented in this file by `Min` and `Max`.
///
/// NOTE(review): the trait is declared `unsafe`, presumably to discourage implementations
/// other than `Min` and `Max` -- confirm the intended contract.
pub unsafe trait Kind {
#[doc(hidden)]
// The `Ordering` that `element.cmp(parent)` has to produce for the heap to move `element`
// above `parent` (this is how the sift routines of `BinaryHeap` use the value).
fn ordering() -> Ordering;
}
// `Min`: an element rises towards the root while it compares `Less` than its parent,
// so the smallest element ends up at the root.
unsafe impl Kind for Min {
fn ordering() -> Ordering {
Ordering::Less
}
}
// `Max`: an element rises towards the root while it compares `Greater` than its parent,
// so the largest element ends up at the root.
unsafe impl Kind for Max {
fn ordering() -> Ordering {
Ordering::Greater
}
}
/// A priority queue implemented as a binary heap on top of a `Vec<T, N>`.
///
/// `KIND` is either `Min` or `Max` and decides which element is kept at the root:
/// the smallest for `Min`, the largest for `Max`. `push` reports failure instead of
/// growing once the backing vector is full.
pub struct BinaryHeap<T, N, KIND>
where
T: Ord,
N: ArrayLength<T>,
KIND: Kind,
{
// Zero-sized marker that ties the heap to its `KIND`; no `KIND` value is ever stored.
_kind: PhantomData<KIND>,
// Backing storage in implicit binary-tree layout: index 0 is the root and the children
// of index `i` live at `2 * i + 1` and `2 * i + 2`.
data: Vec<T, N>,
}
impl<T, N, K> BinaryHeap<T, N, K>
where
    T: Ord,
    N: ArrayLength<T>,
    K: Kind,
{
    /// Creates a new, empty heap.
    pub const fn new() -> Self {
        BinaryHeap {
            data: Vec::new(),
            _kind: PhantomData,
        }
    }

    /// Returns the capacity reported by the backing vector.
    pub fn capacity(&self) -> usize {
        self.data.capacity()
    }

    /// Removes every element from the heap.
    pub fn clear(&mut self) {
        self.data.clear();
    }

    /// Returns the number of elements currently stored.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` when the heap holds no elements.
    pub fn is_empty(&self) -> bool {
        self.data.len() == 0
    }

    /// Iterates over the elements in storage order (not in sorted order).
    pub fn iter(&self) -> slice::Iter<T> {
        self.data.iter()
    }

    /// Iterates mutably over the elements in storage order.
    ///
    /// This method does not re-establish the heap order after elements are modified.
    pub fn iter_mut(&mut self) -> slice::IterMut<T> {
        self.data.iter_mut()
    }

    /// Returns a reference to the root element, or `None` if the heap is empty.
    pub fn peek(&self) -> Option<&T> {
        self.data.first()
    }

    /// Removes the root element and returns it, or `None` if the heap is empty.
    pub fn pop(&mut self) -> Option<T> {
        if self.is_empty() {
            return None;
        }
        // SAFETY: the heap was just checked to be non-empty.
        Some(unsafe { self.pop_unchecked() })
    }

    /// Removes the root element and returns it without checking that the heap is non-empty.
    ///
    /// # Safety
    ///
    /// The heap must contain at least one element.
    pub unsafe fn pop_unchecked(&mut self) -> T {
        let last = self.data.pop_unchecked();
        if self.is_empty() {
            // `last` was the only element, hence also the root.
            return last;
        }
        // Put the former last element at the root, take the old root out, then repair the order.
        let root = mem::replace(&mut self.data[0], last);
        self.sift_down_to_bottom(0);
        root
    }

    /// Inserts `item` into the heap.
    ///
    /// Returns `Err(item)`, handing the value back, when the backing vector is full.
    pub fn push(&mut self, item: T) -> Result<(), T> {
        if self.data.is_full() {
            Err(item)
        } else {
            // SAFETY: the backing vector was just checked to have room for one more element.
            unsafe { self.push_unchecked(item) };
            Ok(())
        }
    }

    /// Inserts `item` into the heap without checking for free capacity.
    ///
    /// # Safety
    ///
    /// The backing vector must not be full.
    pub unsafe fn push_unchecked(&mut self, item: T) {
        // The new element is appended at index `len` and then moved up towards the root.
        let inserted_at = self.data.len();
        self.data.push_unchecked(item);
        self.sift_up(0, inserted_at);
    }

    /// Moves the element at `pos` all the way down to a leaf, always following the preferred
    /// child, and then sifts it back up (no higher than `pos`) to its final position.
    fn sift_down_to_bottom(&mut self, pos: usize) {
        let len = self.len();
        let bottom = unsafe {
            let mut hole = Hole::new(&mut self.data, pos);
            loop {
                let left = 2 * hole.pos() + 1;
                if left >= len {
                    break;
                }
                let right = left + 1;
                // Follow the right child unless the left one compares as `K::ordering()`
                // against it (or there is no right child).
                let target =
                    if right < len && hole.get(left).cmp(hole.get(right)) != K::ordering() {
                        right
                    } else {
                        left
                    };
                hole.move_to(target);
            }
            // The hole is refilled when it is dropped at the end of this block.
            hole.pos()
        };
        self.sift_up(pos, bottom);
    }

    /// Moves the element at `pos` towards the root while it compares as `K::ordering()` against
    /// its parent, never going above index `start`. Returns the element's final index.
    fn sift_up(&mut self, start: usize, pos: usize) -> usize {
        unsafe {
            let mut hole = Hole::new(&mut self.data, pos);
            loop {
                let here = hole.pos();
                if here <= start {
                    break;
                }
                let parent = (here - 1) / 2;
                if hole.element().cmp(hole.get(parent)) == K::ordering() {
                    hole.move_to(parent);
                } else {
                    break;
                }
            }
            hole.pos()
        }
    }
}
/// A slice with one logically vacated slot (the "hole").
///
/// The element that used to live at `pos` is held in `elt`; other elements can be moved into
/// the hole one at a time with `move_to`. When the `Hole` is dropped, `elt` is written back into
/// whatever slot is vacant at that point, so the slice is fully initialised again.
struct Hole<'a, T: 'a> {
    data: &'a mut [T],
    // Wrapped in `ManuallyDrop` because ownership is handed back to `data` in `Drop`.
    elt: ManuallyDrop<T>,
    pos: usize,
}

impl<'a, T> Hole<'a, T> {
    /// Creates a hole at index `pos`, taking the element stored there out of the slice.
    ///
    /// # Safety
    ///
    /// `pos` must be within `data`.
    #[inline]
    unsafe fn new(data: &'a mut [T], pos: usize) -> Self {
        debug_assert!(pos < data.len());
        // The slot at `pos` is treated as vacant from here on; `Drop` fills it again.
        let elt = ptr::read(data.get_unchecked(pos));
        Hole {
            data,
            elt: ManuallyDrop::new(elt),
            pos,
        }
    }

    /// Returns the index of the currently vacant slot.
    #[inline]
    fn pos(&self) -> usize {
        self.pos
    }

    /// Returns a reference to the element that was taken out of the slice.
    #[inline]
    fn element(&self) -> &T {
        &self.elt
    }

    /// Returns a reference to the element at `index`.
    ///
    /// # Safety
    ///
    /// `index` must be within the slice and must not be the hole position.
    #[inline]
    unsafe fn get(&self, index: usize) -> &T {
        debug_assert!(index != self.pos);
        debug_assert!(index < self.data.len());
        self.data.get_unchecked(index)
    }

    /// Moves the element at `index` into the hole; the hole is then located at `index`.
    ///
    /// # Safety
    ///
    /// `index` must be within the slice and must not be the hole position.
    #[inline]
    unsafe fn move_to(&mut self, index: usize) {
        debug_assert!(index != self.pos);
        debug_assert!(index < self.data.len());
        // Derive both pointers from a single `as_mut_ptr()`. Taking a `*const` through a shared
        // reborrow and then calling `get_unchecked_mut` would mutably reborrow the whole slice
        // in between, invalidating the first pointer under Rust's aliasing rules.
        let base = self.data.as_mut_ptr();
        let index_ptr: *const T = base.add(index);
        let hole_ptr = base.add(self.pos);
        // `index != self.pos`, so source and destination never overlap.
        ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1);
        self.pos = index;
    }
}

impl<'a, T> Drop for Hole<'a, T> {
    /// Fills the hole again with the element that was taken out in `new`.
    #[inline]
    fn drop(&mut self) {
        unsafe {
            let pos = self.pos;
            // `ptr::write` does not drop the stale bits in the vacant slot, and `elt` is
            // `ManuallyDrop`, so the element ends up owned by the slice exactly once.
            ptr::write(self.data.get_unchecked_mut(pos), ptr::read(&*self.elt));
        }
    }
}
/// Lets `&BinaryHeap` be used directly in a `for` loop; elements are yielded by reference
/// in storage order, exactly as `BinaryHeap::iter` does.
impl<'a, T, N, K> IntoIterator for &'a BinaryHeap<T, N, K>
where
    T: Ord,
    N: ArrayLength<T>,
    K: Kind,
{
    type IntoIter = slice::Iter<'a, T>;
    type Item = &'a T;

    fn into_iter(self) -> Self::IntoIter {
        self.data.iter()
    }
}
#[cfg(test)]
mod tests {
    use std::vec::Vec;
    use binary_heap::{self, BinaryHeap, Min};
    use consts::*;

    /// Values pushed, in this order, by both tests.
    const INPUT: [i32; 9] = [1, 2, 3, 17, 19, 36, 7, 25, 100];

    /// Checks storage layout and pop order of a min-heap.
    #[test]
    fn min() {
        let mut heap = BinaryHeap::<_, U16, Min>::new();
        for &value in INPUT.iter() {
            heap.push(value).unwrap();
        }

        // Storage layout right after the pushes.
        let layout: Vec<_> = heap.iter().cloned().collect();
        assert_eq!(layout, [1, 2, 3, 17, 19, 36, 7, 25, 100]);

        assert_eq!(heap.pop(), Some(1));

        // Storage layout after removing the root once.
        let layout: Vec<_> = heap.iter().cloned().collect();
        assert_eq!(layout, [2, 17, 3, 25, 19, 36, 7, 100]);

        // The remaining elements come out in ascending order.
        for &expected in [2, 3, 7, 17, 19, 25, 36, 100].iter() {
            assert_eq!(heap.pop(), Some(expected));
        }
        assert_eq!(heap.pop(), None);
    }

    /// Checks storage layout and pop order of a max-heap.
    #[test]
    fn max() {
        let mut heap = BinaryHeap::<_, U16, binary_heap::Max>::new();
        for &value in INPUT.iter() {
            heap.push(value).unwrap();
        }

        // Storage layout right after the pushes.
        let layout: Vec<_> = heap.iter().cloned().collect();
        assert_eq!(layout, [100, 36, 19, 25, 3, 2, 7, 1, 17]);

        assert_eq!(heap.pop(), Some(100));

        // Storage layout after removing the root once.
        let layout: Vec<_> = heap.iter().cloned().collect();
        assert_eq!(layout, [36, 25, 19, 17, 3, 2, 7, 1]);

        // The remaining elements come out in descending order.
        for &expected in [36, 25, 19, 17, 7, 3, 2, 1].iter() {
            assert_eq!(heap.pop(), Some(expected));
        }
        assert_eq!(heap.pop(), None);
    }
}