use std::cmp::Ordering;
use std::iter::FusedIterator;
use std::mem::{self, ManuallyDrop, swap};
use std::num::NonZero;
use std::ops::{Deref, DerefMut};
use std::{fmt, ptr};
use std::collections::TryReserveError;
use std::slice;
use std::vec::{self, Vec};
/// A binary max-heap whose ordering is supplied by a comparator value `cmp`
/// instead of an `Ord` implementation.
///
/// The greatest element (according to `cmp`) is stored at index 0 of `data`.
pub struct BinaryHeap<T, C> {
    // Heap-ordered storage. Invariant (maintained by the heap methods and
    // required of callers of `from_vec_unchecked`): for every index `i > 0`,
    // `cmp(&data[(i - 1) / 2], &data[i])` is not `Ordering::Less`.
    data: Vec<T>,
    // The comparator; the methods that maintain the heap require
    // `C: Fn(&T, &T) -> Ordering`. Field order also fixes drop order
    // (elements first, then the comparator).
    cmp: C,
}
/// Mutable handle to the greatest element of a [`BinaryHeap`], returned by
/// `BinaryHeap::peek_mut`.
///
/// If the element is accessed mutably, the heap is repaired (the root is
/// sifted down) when the handle is dropped or refreshed.
pub struct PeekMut<'a, T: 'a, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    heap: &'a mut BinaryHeap<T, C>,
    // `Some(len)` after `deref_mut` has temporarily truncated `heap.data` to a
    // single element: the length that must be restored. `None` while the heap
    // has not been truncated.
    original_len: Option<NonZero<usize>>,
}
impl<T, C> fmt::Debug for PeekMut<'_, T, C>
where
    T: fmt::Debug,
    C: Fn(&T, &T) -> Ordering,
{
    /// Formats the handle as `PeekMut(<root element>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let top = &self.heap.data[0];
        f.debug_tuple("PeekMut").field(top).finish()
    }
}
impl<T, C> Drop for PeekMut<'_, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    /// Undoes the truncation done by `deref_mut` (if any) and restores the heap
    /// property, because the root may have been modified through the handle.
    fn drop(&mut self) {
        if let Some(original_len) = self.original_len {
            // SAFETY: `original_len` is the length `data` had when `deref_mut`
            // truncated it to 1. The elements past index 0 were not dropped by
            // `set_len(1)`, so they are still initialized.
            unsafe { self.heap.data.set_len(original_len.get()) };
            // SAFETY: a `PeekMut` is only created for a non-empty heap
            // (`peek_mut` returns `None` otherwise), so index 0 exists.
            unsafe { self.heap.sift_down(0) };
        }
    }
}
impl<T, C> Deref for PeekMut<'_, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    type Target = T;

    /// Returns a shared reference to the heap's root (greatest) element.
    fn deref(&self) -> &T {
        let heap = &*self.heap;
        debug_assert!(!heap.is_empty());
        // SAFETY: `PeekMut` is only handed out for a non-empty heap, and
        // `deref_mut` never truncates below one element, so index 0 is valid.
        unsafe { heap.data.get_unchecked(0) }
    }
}
impl<T, C> DerefMut for PeekMut<'_, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    /// Returns a mutable reference to the root element.
    ///
    /// Since the caller may change the root's ordering, the heap is first
    /// truncated to a single element. If this `PeekMut` is leaked (never
    /// dropped), the remaining elements are leaked with it, presumably so the
    /// heap is never observed in an invalid order. `Drop` and `refresh` restore
    /// the saved length and sift the root down.
    fn deref_mut(&mut self) -> &mut T {
        debug_assert!(!self.heap.is_empty());
        let len = self.heap.len();
        // While truncated, `len` is 1, so a repeated call does not overwrite the
        // saved `original_len`.
        if len > 1 {
            unsafe {
                // SAFETY: `len > 1`, so it is non-zero.
                self.original_len = Some(NonZero::new_unchecked(len));
                // SAFETY: `1 <= len`; shrinking keeps index 0 initialized and
                // does not drop the elements past it.
                self.heap.data.set_len(1);
            }
        }
        // SAFETY: the heap is non-empty, so index 0 is in bounds.
        unsafe { self.heap.data.get_unchecked_mut(0) }
    }
}
impl<'a, T, C> PeekMut<'a, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    /// Repairs the heap now (instead of at drop) if the root was accessed
    /// mutably since the last repair.
    ///
    /// Returns `true` if the root element moved away from index 0, meaning the
    /// handle now refers to a different element. Returns `false` if no repair
    /// was needed or the element stayed at the root.
    #[must_use = "is equivalent to dropping and getting a new PeekMut except for return information"]
    pub fn refresh(&mut self) -> bool {
        if let Some(original_len) = self.original_len.take() {
            // SAFETY: restores the length saved by `deref_mut`; the elements
            // beyond index 0 are still initialized.
            unsafe { self.heap.data.set_len(original_len.get()) };
            // SAFETY: the heap is non-empty. `sift_down` returns the element's
            // final index, so a non-zero result means it left the root.
            (unsafe { self.heap.sift_down(0) }) != 0
        } else {
            false
        }
    }

    /// Removes the root element from the heap and returns it.
    pub fn pop(mut this: PeekMut<'a, T, C>) -> T {
        // Taking `original_len` also turns the later `Drop` of `this` into a no-op.
        if let Some(original_len) = this.original_len.take() {
            // SAFETY: restores the length saved by `deref_mut`. No sift is
            // needed: the possibly modified root is removed on the next line,
            // and `pop` never compares the old root's value.
            unsafe { this.heap.data.set_len(original_len.get()) };
        }
        // SAFETY: a `PeekMut` exists only for a non-empty heap, so `pop` returns `Some`.
        unsafe { this.heap.pop().unwrap_unchecked() }
    }
}
impl<T: Clone, C: Clone> Clone for BinaryHeap<T, C> {
    /// Clones both the elements and the comparator.
    fn clone(&self) -> Self {
        BinaryHeap {
            data: self.data.clone(),
            cmp: self.cmp.clone(),
        }
    }

    /// Overwrites `self` with a copy of `source`, reusing `self`'s element
    /// allocation where possible.
    ///
    /// The comparator is copied too. `source.data` is only guaranteed to be
    /// heap-ordered under `source.cmp`. Keeping `self`'s old comparator, which
    /// may capture different state, could silently break the heap invariant.
    fn clone_from(&mut self, source: &Self) {
        self.data.clone_from(&source.data);
        self.cmp.clone_from(&source.cmp);
    }
}
impl<T, C> fmt::Debug for BinaryHeap<T, C>
where
    T: fmt::Debug,
{
    /// Formats the elements as a list in storage (heap) order. The comparator
    /// is not shown and need not implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut list = f.debug_list();
        list.entries(self.iter());
        list.finish()
    }
}
/// Guard that repairs the heap when dropped by running
/// `rebuild_tail(rebuild_from)`. Elements at `rebuild_from` and later are
/// treated as possibly out of heap order. Because the repair happens in
/// `Drop`, it also runs if the operation holding the guard (`retain`,
/// `extend`) panics.
struct RebuildOnDrop<'a, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    heap: &'a mut BinaryHeap<T, C>,
    // First index that may violate the heap property; a value equal to
    // `heap.len()` means there is nothing to repair.
    rebuild_from: usize,
}
impl<T, C> Drop for RebuildOnDrop<'_, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    /// Re-establishes the heap property for the suffix starting at `rebuild_from`.
    fn drop(&mut self) {
        let start = self.rebuild_from;
        self.heap.rebuild_tail(start);
    }
}
impl<T, C> BinaryHeap<T, C>
where
C: Fn(&T, &T) -> Ordering,
{
pub const fn new(cmp: C) -> BinaryHeap<T, C> {
BinaryHeap { data: vec![], cmp }
}
pub fn from_vec(vec: Vec<T>, cmp: C) -> BinaryHeap<T, C> {
let mut heap = BinaryHeap { data: vec, cmp };
heap.rebuild();
heap
}
pub unsafe fn from_vec_unchecked(vec: Vec<T>, cmp: C) -> BinaryHeap<T, C> {
BinaryHeap { data: vec, cmp }
}
#[must_use]
pub fn with_capacity(capacity: usize, cmp: C) -> BinaryHeap<T, C> {
BinaryHeap {
data: Vec::with_capacity(capacity),
cmp,
}
}
}
impl<T, C> BinaryHeap<T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    /// Returns a mutable handle to the greatest element, or `None` if the heap
    /// is empty.
    ///
    /// If the element is modified through the handle, the heap property is
    /// restored when the handle is dropped (or on `PeekMut::refresh`).
    pub fn peek_mut(&mut self) -> Option<PeekMut<'_, T, C>> {
        if self.is_empty() {
            None
        } else {
            Some(PeekMut {
                heap: self,
                original_len: None,
            })
        }
    }

    /// Removes the greatest element and returns it, or `None` if the heap is empty.
    pub fn pop(&mut self) -> Option<T> {
        self.data.pop().map(|mut item| {
            if !self.is_empty() {
                // Move the former last element to the root and take the old root out.
                swap(&mut item, &mut self.data[0]);
                // SAFETY: the heap is non-empty, so index 0 is in bounds.
                unsafe { self.sift_down_to_bottom(0) };
            }
            item
        })
    }

    /// Overwrites the element at `pos` with `item`, then sifts `item` down.
    /// The old element is dropped.
    ///
    /// # Safety
    ///
    /// `item` must not compare greater than the element currently at `pos`.
    /// Only downward movement is performed, so a greater `item` would break
    /// the heap property. This is checked only by a `debug_assert!`.
    ///
    /// # Panics
    ///
    /// Panics if `pos >= self.len()`.
    pub unsafe fn update_pos_sift_down(&mut self, pos: usize, item: T) {
        debug_assert!((self.cmp)(&item, &self.data[pos]) != Ordering::Greater);
        self.data[pos] = item;
        // SAFETY: the indexing above proves `pos < self.len()`.
        unsafe { self.sift_down(pos) };
    }

    /// Calls `cb(index, element)` for every element that compares equal to the
    /// greatest element, starting with the root at index 0.
    ///
    /// Performs a breadth-first walk from the root that only descends through
    /// nodes equal to the maximum. By the heap property (assuming `cmp` is a
    /// total order), every ancestor of such a node also equals the maximum,
    /// so no match is missed. `queue` is a scratch buffer that can be reused
    /// across calls to avoid allocation; it is cleared on entry.
    pub fn peek_all(&self, mut cb: impl FnMut(usize, &T), queue: &mut Vec<usize>) {
        queue.clear();
        let Some(max) = self.peek() else {
            return;
        };
        cb(0, max);
        let len = self.data.len();
        // Seed the walk with the root's existing children (indices 1 and 2).
        queue.extend(1..len.min(3));
        let mut head = 0usize;
        while let Some(&index) = queue.get(head) {
            head += 1;
            // SAFETY: only indices `< len` are ever pushed onto `queue`.
            let node = unsafe { self.data.get_unchecked(index) };
            if (self.cmp)(node, max) != Ordering::Equal {
                continue;
            }
            cb(index, node);
            let left = index * 2 + 1;
            if left < len {
                queue.push(left);
            }
            let right = left + 1;
            if right < len {
                queue.push(right);
            }
        }
    }

    /// Removes the element stored at `index` and restores the heap property.
    /// `index` is a position in the underlying storage, e.g. as reported by
    /// `peek_all`. The removed element is dropped.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`; the heap is left unchanged in that case.
    pub fn remove(&mut self, index: usize) {
        let len = self.len();
        assert!(
            index < len,
            "removal index (is {}) should be < len (is {})",
            index,
            len
        );
        let last = self
            .data
            .pop()
            .expect("heap is non-empty after the bounds check");
        if index == self.data.len() {
            // The removed element was the last one; the rest is still a heap.
            return;
        }
        // Fill the gap with the former last element. The removed element is
        // kept alive until the heap is consistent again, so a panicking `Drop`
        // of `T` cannot leave the heap out of order.
        let removed = mem::replace(&mut self.data[index], last);
        // The replacement came from a different subtree. It may be greater
        // than its new parent, or smaller than its new children, so move it
        // whichever way is needed.
        // SAFETY: `index < self.len()`: it was `< len` and is not the popped slot.
        let pos = unsafe { self.sift_up(0, index) };
        if pos == index {
            // SAFETY: as above.
            unsafe { self.sift_down(index) };
        }
        drop(removed);
    }

    /// Pushes `item` onto the heap.
    pub fn push(&mut self, item: T) {
        let old_len = self.len();
        self.data.push(item);
        // SAFETY: `old_len` is the index of the element just pushed.
        unsafe { self.sift_up(0, old_len) };
    }

    /// Consumes the heap and returns its elements sorted in ascending order
    /// according to `cmp` (greatest element last), using in-place heapsort.
    pub fn into_sorted_vec(mut self) -> Vec<T> {
        let mut end = self.len();
        while end > 1 {
            end -= 1;
            // SAFETY: `0 < end < self.len()`, so both pointers are in bounds and distinct.
            unsafe {
                let ptr = self.data.as_mut_ptr();
                ptr::swap(ptr, ptr.add(end));
            }
            // SAFETY: `0 < end <= self.len()`.
            unsafe { self.sift_down_range(0, end) };
        }
        self.into_vec()
    }

    /// Moves the element at `pos` towards the root while it compares greater
    /// than its parent and its index is greater than `start`. Returns its final index.
    ///
    /// # Safety
    ///
    /// `pos < self.len()`.
    unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
        // SAFETY: the caller guarantees `pos < self.len()`.
        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
        while hole.pos() > start {
            let parent = (hole.pos() - 1) / 2;
            // SAFETY: `parent < hole.pos()`, so it is in bounds and not the hole.
            if (self.cmp)(hole.element(), unsafe { hole.get(parent) }) != Ordering::Greater {
                break;
            }
            // SAFETY: as above.
            unsafe { hole.move_to(parent) };
        }
        hole.pos()
    }

    /// Moves the element at `pos` down, considering only indices `< end`,
    /// until neither child compares greater than it. Returns its final index.
    ///
    /// # Safety
    ///
    /// `pos < end <= self.len()`.
    unsafe fn sift_down_range(&mut self, pos: usize, end: usize) -> usize {
        // SAFETY: the caller guarantees `pos < end <= self.len()`.
        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
        let mut child = 2 * hole.pos() + 1;
        // Loop while both children `child` and `child + 1` are below `end`.
        while child <= end.saturating_sub(2) {
            // Pick the greater child (the right one on ties).
            // SAFETY: `child + 1 < end`, and neither child is the hole.
            child += ((self.cmp)(unsafe { hole.get(child) }, unsafe { hole.get(child + 1) })
                != Ordering::Greater) as usize;
            // Stop once the element is not less than its greater child.
            // SAFETY: `child < end`.
            if (self.cmp)(hole.element(), unsafe { hole.get(child) }) != Ordering::Less {
                return hole.pos();
            }
            // SAFETY: as above.
            unsafe { hole.move_to(child) };
            child = 2 * hole.pos() + 1;
        }
        // A lone left child at `end - 1` still needs one comparison.
        // SAFETY: `child == end - 1 < end`.
        if child == end - 1
            && (self.cmp)(hole.element(), unsafe { hole.get(child) }) == Ordering::Less
        {
            unsafe { hole.move_to(child) };
        }
        hole.pos()
    }

    /// Sifts the element at `pos` down within the whole heap and returns its
    /// final index.
    ///
    /// # Safety
    ///
    /// `pos < self.len()`.
    pub unsafe fn sift_down(&mut self, pos: usize) -> usize {
        let len = self.len();
        // SAFETY: forwarded from the caller: `pos < len`.
        unsafe { self.sift_down_range(pos, len) }
    }

    /// Moves the element at `pos` all the way down, always following the
    /// greater child without comparing against the element itself. It then
    /// sifts the element back up, but not above index `pos`.
    ///
    /// Intended for an element expected to belong near the bottom, as after
    /// `pop` moves the last element to the root.
    ///
    /// # Safety
    ///
    /// `pos < self.len()`.
    pub unsafe fn sift_down_to_bottom(&mut self, pos: usize) {
        let end = self.len();
        let start = pos;
        // SAFETY: the caller guarantees `pos < end`.
        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
        let mut child = 2 * hole.pos() + 1;
        while child <= end.saturating_sub(2) {
            // SAFETY: `child + 1 < end`, and neither child is the hole.
            child += ((self.cmp)(unsafe { hole.get(child) }, unsafe { hole.get(child + 1) })
                != Ordering::Greater) as usize;
            // SAFETY: `child < end`.
            unsafe { hole.move_to(child) };
            child = 2 * hole.pos() + 1;
        }
        if child == end - 1 {
            // SAFETY: `child < end`.
            unsafe { hole.move_to(child) };
        }
        let bottom = hole.pos();
        // Dropping the hole writes the element back at `bottom` before sifting up.
        drop(hole);
        // SAFETY: `start <= bottom < end`.
        unsafe { self.sift_up(start, bottom) };
    }

    /// Restores the heap property when `0..start` is already a heap but the
    /// elements at `start..len` may be out of order (e.g. just appended).
    ///
    /// Chooses between a full `rebuild` and sifting up each tail element
    /// individually, based on a rough comparison-count estimate.
    fn rebuild_tail(&mut self, start: usize) {
        if start == self.len() {
            return;
        }
        let tail_len = self.len() - start;

        #[inline(always)]
        fn log2_fast(x: usize) -> usize {
            (usize::BITS - x.leading_zeros() - 1) as usize
        }

        // A rebuild costs about `2 * len` comparisons; sifting up each tail
        // element costs about `tail_len * log2(start)`. For larger heaps a
        // fixed factor of 11 is used instead of computing the logarithm.
        // `log2_fast(start)` is only reached when `start >= tail_len >= 1`.
        let better_to_rebuild = if start < tail_len {
            true
        } else if self.len() <= 2048 {
            2 * self.len() < tail_len * log2_fast(start)
        } else {
            2 * self.len() < tail_len * 11
        };
        if better_to_rebuild {
            self.rebuild();
        } else {
            for i in start..self.len() {
                // SAFETY: `i < self.len()`.
                unsafe { self.sift_up(0, i) };
            }
        }
    }

    /// Heapifies the whole vector bottom-up: sifts down every node that has at
    /// least one child, from the last such node back to the root.
    fn rebuild(&mut self) {
        for n in (0..self.len() / 2).rev() {
            // SAFETY: `n < self.len() / 2 <= self.len()`.
            unsafe { self.sift_down(n) };
        }
    }

    /// Moves all elements of `other` into `self`, leaving `other` empty.
    ///
    /// If `other` is longer, the two heaps are swapped wholesale first, so
    /// that the larger one is kept as the base. The swap includes the
    /// comparators: afterwards `self` uses what was `other`'s comparator.
    /// NOTE(review): this is only meaningful when both comparators define the
    /// same ordering.
    pub fn append(&mut self, other: &mut Self) {
        if self.len() < other.len() {
            swap(self, other);
        }
        let start = self.data.len();
        self.data.append(&mut other.data);
        self.rebuild_tail(start);
    }

    /// Returns an iterator that removes elements greatest-first. Elements not
    /// yet yielded are removed when the iterator is dropped.
    #[inline]
    pub fn drain_sorted(&mut self) -> DrainSorted<'_, T, C> {
        DrainSorted { inner: self }
    }

    /// Keeps only the elements for which `f` returns `true`, visiting them in
    /// storage order. The heap is repaired afterwards, also if `f` panics.
    pub fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(&T) -> bool,
    {
        let mut guard = RebuildOnDrop {
            rebuild_from: self.len(),
            heap: self,
        };
        // Elements before the first removed one keep their positions, so only
        // the suffix starting there needs repair. Borrowing the field
        // explicitly keeps this independent of the edition's closure-capture rules.
        let first_removed = &mut guard.rebuild_from;
        let mut i = 0;
        guard.heap.data.retain(|e| {
            let keep = f(e);
            if !keep && i < *first_removed {
                *first_removed = i;
            }
            i += 1;
            keep
        });
    }
}
impl<T, C> BinaryHeap<T, C> {
    /// Returns an iterator over the elements in storage (heap) order, which is
    /// not sorted order.
    pub fn iter(&self) -> Iter<'_, T> {
        let iter = self.data.iter();
        Iter { iter }
    }

    /// Consumes the heap and returns an iterator that pops elements greatest-first.
    pub fn into_iter_sorted(self) -> IntoIterSorted<T, C> {
        IntoIterSorted { inner: self }
    }

    /// Returns the greatest element (the root), or `None` if the heap is empty.
    #[must_use]
    pub fn peek(&self) -> Option<&T> {
        self.as_slice().first()
    }

    /// Number of elements the heap can hold without reallocating.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.data.capacity()
    }

    /// Reserves capacity for exactly `additional` more elements (see `Vec::reserve_exact`).
    pub fn reserve_exact(&mut self, additional: usize) {
        self.data.reserve_exact(additional);
    }

    /// Reserves capacity for at least `additional` more elements (see `Vec::reserve`).
    pub fn reserve(&mut self, additional: usize) {
        self.data.reserve(additional);
    }

    /// Like `reserve_exact`, but reports capacity overflow or allocation
    /// failure as an error.
    pub fn try_reserve_exact(&mut self, additional: usize) -> Result<(), TryReserveError> {
        self.data.try_reserve_exact(additional)
    }

    /// Like `reserve`, but reports capacity overflow or allocation failure as an error.
    pub fn try_reserve(&mut self, additional: usize) -> Result<(), TryReserveError> {
        self.data.try_reserve(additional)
    }

    /// Shrinks the capacity as much as possible.
    pub fn shrink_to_fit(&mut self) {
        self.data.shrink_to_fit();
    }

    /// Shrinks the capacity to no less than `min_capacity` (and no less than the length).
    #[inline]
    pub fn shrink_to(&mut self, min_capacity: usize) {
        self.data.shrink_to(min_capacity)
    }

    /// Borrows the elements as a slice in storage (heap) order.
    #[must_use]
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    /// Consumes the heap and returns the backing vector in storage (heap) order.
    #[must_use = "`self` will be dropped if the result is not used"]
    pub fn into_vec(self) -> Vec<T> {
        Vec::from(self)
    }

    /// Number of elements in the heap.
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if the heap holds no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Removes all elements, yielding them in storage (heap) order.
    #[inline]
    pub fn drain(&mut self) -> Drain<'_, T> {
        let iter = self.data.drain(..);
        Drain { iter }
    }

    /// Removes and drops all elements.
    pub fn clear(&mut self) {
        drop(self.drain());
    }
}
/// A "hole" in a slice: the element at `pos` has been moved out into `elt`,
/// leaving that slot logically vacant while other elements are shifted into
/// it. On drop, `elt` is written back at the current `pos`, so the slice is
/// complete again even if a comparator panics in the middle of a sift.
struct Hole<'a, T: 'a> {
    data: &'a mut [T],
    // The element taken out of the slice. Wrapped in `ManuallyDrop` because
    // ownership goes back to the slice in `Drop`, not to this field.
    elt: ManuallyDrop<T>,
    // Index of the currently vacant slot.
    pos: usize,
}
impl<'a, T> Hole<'a, T> {
    /// Creates a hole at `pos` by moving the element there out into `elt`.
    ///
    /// # Safety
    ///
    /// `pos` must be in bounds for `data`; this is checked only by `debug_assert!`.
    #[inline]
    unsafe fn new(data: &'a mut [T], pos: usize) -> Self {
        debug_assert!(pos < data.len());
        // SAFETY: `pos` is in bounds per the caller. The bitwise read leaves
        // `data[pos]` logically moved-out; `Drop` writes an element back.
        let elt = unsafe { ptr::read(data.get_unchecked(pos)) };
        Hole {
            data,
            elt: ManuallyDrop::new(elt),
            pos,
        }
    }

    /// Current index of the vacant slot.
    #[inline]
    fn pos(&self) -> usize {
        self.pos
    }

    /// The element that was taken out of the slice.
    #[inline]
    fn element(&self) -> &T {
        &self.elt
    }

    /// Returns a reference to the element at `index`.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds and must not be the hole's position, whose
    /// slot is logically moved-out. Both are checked only by `debug_assert!`.
    #[inline]
    unsafe fn get(&self, index: usize) -> &T {
        debug_assert!(index != self.pos);
        debug_assert!(index < self.data.len());
        unsafe { self.data.get_unchecked(index) }
    }

    /// Moves the element at `index` into the hole; `index` becomes the new hole.
    ///
    /// # Safety
    ///
    /// Same requirements as `get`.
    #[inline]
    unsafe fn move_to(&mut self, index: usize) {
        debug_assert!(index != self.pos);
        debug_assert!(index < self.data.len());
        // SAFETY: both indices are in bounds and distinct, so the two
        // one-element ranges do not overlap.
        unsafe {
            let ptr = self.data.as_mut_ptr();
            let index_ptr: *const _ = ptr.add(index);
            let hole_ptr = ptr.add(self.pos);
            ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1);
        }
        self.pos = index;
    }
}
impl<T> Drop for Hole<'_, T> {
    /// Fills the hole by writing `elt` back into the slot at `pos`.
    #[inline]
    fn drop(&mut self) {
        // SAFETY: `pos` is in bounds (established by `new`/`move_to`) and the
        // slot is vacant, so it is overwritten without dropping its stale
        // contents. `elt` itself is not dropped because it is `ManuallyDrop`.
        unsafe {
            let pos = self.pos;
            ptr::copy_nonoverlapping(&*self.elt, self.data.get_unchecked_mut(pos), 1);
        }
    }
}
/// Iterator over shared references to a heap's elements in storage (heap)
/// order, which is not sorted order. Created by `BinaryHeap::iter`.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Iter<'a, T: 'a> {
    // Iterator over the heap's backing slice.
    iter: slice::Iter<'a, T>,
}
impl<T> Default for Iter<'_, T> {
fn default() -> Self {
Iter {
iter: Default::default(),
}
}
}
impl<T> fmt::Debug for Iter<'_, T>
where
    T: fmt::Debug,
{
    /// Formats the not-yet-yielded elements as `Iter([..])`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let remaining = self.iter.as_slice();
        f.debug_tuple("Iter").field(&remaining).finish()
    }
}
impl<T> Clone for Iter<'_, T> {
fn clone(&self) -> Self {
Iter {
iter: self.iter.clone(),
}
}
}
impl<'a, T> Iterator for Iter<'a, T> {
type Item = &'a T;
#[inline]
fn next(&mut self) -> Option<&'a T> {
self.iter.next()
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
#[inline]
fn last(self) -> Option<&'a T> {
self.iter.last()
}
}
impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
    /// Yields the last not-yet-visited element in storage order.
    #[inline]
    fn next_back(&mut self) -> Option<&'a T> {
        self.iter.next_back()
    }
}

// Delegates to `slice::Iter`, which keeps returning `None` once exhausted.
impl<'a, T> FusedIterator for Iter<'a, T> {}
/// Owning iterator over a heap's elements in storage (heap) order, which is
/// not sorted order. Created by `IntoIterator for BinaryHeap`.
#[derive(Clone)]
pub struct IntoIter<T> {
    // The heap's backing vector, consumed front to back.
    iter: vec::IntoIter<T>,
}
impl<T> fmt::Debug for IntoIter<T>
where
    T: fmt::Debug,
{
    /// Formats the not-yet-yielded elements as `IntoIter([..])`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let remaining = self.iter.as_slice();
        f.debug_tuple("IntoIter").field(&remaining).finish()
    }
}
impl<T> Iterator for IntoIter<T> {
type Item = T;
#[inline]
fn next(&mut self) -> Option<T> {
self.iter.next()
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
}
impl<T> DoubleEndedIterator for IntoIter<T> {
#[inline]
fn next_back(&mut self) -> Option<T> {
self.iter.next_back()
}
}
impl<T> FusedIterator for IntoIter<T> {}
impl<T> Default for IntoIter<T> {
fn default() -> Self {
IntoIter {
iter: Default::default(),
}
}
}
/// Owning iterator that yields a heap's elements greatest-first by repeatedly
/// popping. Created by `BinaryHeap::into_iter_sorted`.
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[derive(Clone, Debug)]
pub struct IntoIterSorted<T, C> {
    // The heap being consumed.
    inner: BinaryHeap<T, C>,
}
impl<T, C> Iterator for IntoIterSorted<T, C>
where
C: Fn(&T, &T) -> Ordering,
{
type Item = T;
#[inline]
fn next(&mut self) -> Option<T> {
self.inner.pop()
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let exact = self.inner.len();
(exact, Some(exact))
}
}
/// Draining iterator that yields a heap's elements in storage (heap) order,
/// which is not sorted order. Created by `BinaryHeap::drain`.
#[derive(Debug)]
pub struct Drain<'a, T: 'a> {
    // Drain over the whole backing vector.
    iter: vec::Drain<'a, T>,
}
impl<'a, T> Iterator for Drain<'a, T> {
    type Item = T;

    /// Removes and returns the next element in storage order.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    /// Exact hint forwarded from the underlying `vec::Drain`.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
impl<'a, T> DoubleEndedIterator for Drain<'a, T> {
    /// Removes and returns the last not-yet-yielded element in storage order.
    #[inline]
    fn next_back(&mut self) -> Option<T> {
        self.iter.next_back()
    }
}

// `vec::Drain` knows its exact remaining length and stays exhausted once done.
impl<'a, T> ExactSizeIterator for Drain<'a, T> {}
impl<'a, T> FusedIterator for Drain<'a, T> {}
/// Draining iterator that yields a heap's elements greatest-first. Created by
/// `BinaryHeap::drain_sorted`. Elements not yet yielded are popped and dropped
/// when the iterator is dropped.
#[derive(Debug)]
pub struct DrainSorted<'a, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    // The heap being drained.
    inner: &'a mut BinaryHeap<T, C>,
}
impl<'a, T, C> Drop for DrainSorted<'a, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    /// Pops and drops every element still in the heap.
    fn drop(&mut self) {
        // If dropping an element panics, this guard's `Drop` keeps popping
        // during unwinding, so the remaining elements are still removed.
        struct DropGuard<'r, 'a, T, C>(&'r mut DrainSorted<'a, T, C>)
        where
            C: Fn(&T, &T) -> Ordering;
        impl<'r, 'a, T, C> Drop for DropGuard<'r, 'a, T, C>
        where
            C: Fn(&T, &T) -> Ordering,
        {
            fn drop(&mut self) {
                while self.0.inner.pop().is_some() {}
            }
        }
        while let Some(item) = self.inner.pop() {
            // The guard is armed only while `item` is dropped. It is disarmed
            // with `mem::forget` once that drop returns normally.
            let guard = DropGuard(self);
            drop(item);
            mem::forget(guard);
        }
    }
}
impl<'a, T, C> Iterator for DrainSorted<'a, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    type Item = T;

    /// Removes and returns the current greatest element.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.inner.pop()
    }

    /// Exact: every remaining heap element will be yielded.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.inner.len();
        (remaining, Some(remaining))
    }
}
// `size_hint` is exact, and `pop` keeps returning `None` on an empty heap.
impl<'a, T, C> ExactSizeIterator for DrainSorted<'a, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
}

impl<'a, T, C> FusedIterator for DrainSorted<'a, T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
}
impl<T, C> From<BinaryHeap<T, C>> for Vec<T> {
fn from(heap: BinaryHeap<T, C>) -> Vec<T> {
heap.data
}
}
impl<T, C> IntoIterator for BinaryHeap<T, C> {
    type Item = T;
    type IntoIter = IntoIter<T>;

    /// Consumes the heap, yielding its elements in storage (heap) order, which
    /// is not sorted order.
    fn into_iter(self) -> Self::IntoIter {
        let iter = self.data.into_iter();
        IntoIter { iter }
    }
}
impl<'a, T, C> IntoIterator for &'a BinaryHeap<T, C> {
type Item = &'a T;
type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Iter<'a, T> {
self.iter()
}
}
impl<T, C> Extend<T> for BinaryHeap<T, C>
where
    C: Fn(&T, &T) -> Ordering,
{
    /// Appends every item from `iter`. The heap property for the new tail is
    /// restored when the guard is dropped, which also happens if `iter` panics.
    #[inline]
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        let first_new = self.len();
        let guard = RebuildOnDrop {
            heap: self,
            rebuild_from: first_new,
        };
        guard.heap.data.extend(iter);
    }
}
impl<'a, T, C> Extend<&'a T> for BinaryHeap<T, C>
where
    T: Clone,
    C: Fn(&T, &T) -> Ordering,
{
    /// Clones each referenced item and pushes the clones via `Extend<T>`.
    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
        let owned = iter.into_iter().cloned();
        self.extend(owned);
    }
}
#[cfg(test)]
mod tests {
    use super::BinaryHeap;
    use std::cmp::Ordering;

    // Natural order on `i32`, which makes the heap a max-heap.
    fn natural(a: &i32, b: &i32) -> Ordering {
        a.cmp(b)
    }

    // Runs `peek_all` and returns the reported values in visit order.
    fn collect_max<C>(heap: &BinaryHeap<i32, C>) -> Vec<i32>
    where
        C: Fn(&i32, &i32) -> Ordering,
    {
        let mut found = Vec::new();
        heap.peek_all(|_, &x| found.push(x), &mut Vec::new());
        found
    }

    #[test]
    fn peek_all_returns_all_max_elements() {
        let heap = BinaryHeap::from_vec(vec![3, 5, 5, 2, 5, 4], natural);
        assert_eq!(collect_max(&heap), vec![5, 5, 5]);
    }

    #[test]
    fn peek_all_on_empty_heap_invokes_nothing() {
        let heap = BinaryHeap::<i32, _>::new(natural);
        let mut called = false;
        heap.peek_all(|_, _| called = true, &mut Vec::new());
        assert!(!called);
    }

    #[test]
    fn peek_all_on_single_element_heap_returns_that_element() {
        let heap = BinaryHeap::<i32, _>::from_vec(vec![42], natural);
        assert_eq!(collect_max(&heap), vec![42]);
    }

    #[test]
    fn peek_all_returns_every_element_when_all_are_equal() {
        let heap = BinaryHeap::<i32, _>::from_vec(vec![7; 8], natural);
        let found = collect_max(&heap);
        assert_eq!(found.len(), heap.len());
        assert!(found.iter().all(|&x| x == 7));
    }

    #[test]
    fn peek_all_returns_only_one_when_maximum_is_unique() {
        let heap = BinaryHeap::<i32, _>::from_vec((1..=9).rev().collect(), natural);
        assert_eq!(collect_max(&heap), vec![9]);
    }
}