use serde::{Deserialize, Serialize};
use std::ops::{Index, IndexMut};
/// Ordering policy for a bounded `CommonHeap`.
///
/// A policy decides two things: which element ends up at the root (via
/// `should_swap`) and which values a full heap admits (via
/// `should_replace_root`).
pub trait CommonHeapOrder<T> {
    /// Returns `true` when `child` must be moved above `parent` to restore heap order.
    fn should_swap(&self, parent: &T, child: &T) -> bool;

    /// Returns `true` when a full heap should overwrite its current `root` with
    /// `new_value`; returning `false` means `new_value` is discarded.
    fn should_replace_root(&self, root: &T, new_value: &T) -> bool;
}
/// Ordering policy that keeps the smallest element at the root (a min-heap).
///
/// Once the heap is full, a new value replaces the root only if it is strictly
/// greater than it. A bounded heap using this policy therefore retains the
/// largest values pushed into it. This is the policy used by `CommonHeap::new_min`.
///
/// Derives `PartialEq`/`Eq` like `KeepLargest`, so that the derived
/// `PartialEq`/`Eq` of `CommonHeap<T, KeepSmallest>` is available too.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct KeepSmallest;

impl<T: Ord> CommonHeapOrder<T> for KeepSmallest {
    /// A child that is smaller than its parent must rise.
    #[inline(always)]
    fn should_swap(&self, parent: &T, child: &T) -> bool {
        child < parent
    }

    /// Evict the current minimum only for a strictly larger value.
    #[inline(always)]
    fn should_replace_root(&self, root: &T, new_value: &T) -> bool {
        new_value > root
    }
}
/// Ordering policy that keeps the largest element at the root (a max-heap).
///
/// Once the heap is full, a new value replaces the root only if it is strictly
/// smaller than it. A bounded heap using this policy therefore retains the
/// smallest values pushed into it. This is the policy used by `CommonHeap::new_max`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct KeepLargest;

impl<T> CommonHeapOrder<T> for KeepLargest
where
    T: Ord,
{
    #[inline(always)]
    fn should_swap(&self, parent: &T, child: &T) -> bool {
        // A parent that is smaller than its child must sink.
        parent < child
    }

    #[inline(always)]
    fn should_replace_root(&self, root: &T, new_value: &T) -> bool {
        // Admit a newcomer only if it lies strictly below the current maximum.
        root > new_value
    }
}
/// A binary heap with an upper bound on the number of retained elements.
///
/// The ordering policy `O` determines which element sits at the root and which
/// values are accepted once the bound is reached.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CommonHeap<T, O>
where
    O: CommonHeapOrder<T>,
{
    /// Elements in implicit binary-heap layout (children of `i` at `2i+1`, `2i+2`).
    data: Vec<T>,
    /// Maximum number of elements the heap retains.
    size: usize,
    /// Ordering policy.
    order: O,
}
impl<T, O: CommonHeapOrder<T>> CommonHeap<T, O> {
    /// Creates an empty heap that retains at most `capacity` elements, ordered by `order`.
    ///
    /// Storage for `capacity` elements is reserved up front.
    pub fn with_capacity(capacity: usize, order: O) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            size: capacity,
            order,
        }
    }

    /// Returns the number of elements currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if the heap holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns the maximum number of elements the heap retains (the bound given to
    /// `with_capacity`). This is *not* the allocation capacity of the backing `Vec`.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.size
    }

    /// Returns `true` once the heap holds at least `capacity()` elements.
    ///
    /// From then on, `push` only replaces the root instead of growing the heap.
    #[inline]
    pub fn is_full(&self) -> bool {
        self.data.len() >= self.size
    }

    /// Removes every element. The size bound is unchanged.
    #[inline]
    pub fn clear(&mut self) {
        self.data.clear();
    }

    /// Returns the root element (index 0), or `None` if the heap is empty.
    #[inline]
    pub fn peek(&self) -> Option<&T> {
        self.data.first()
    }

    /// Returns mutable access to the root element, or `None` if the heap is empty.
    ///
    /// Modifying the root can break heap order. Call `update_at(0)` afterwards
    /// to restore it.
    #[inline]
    pub fn peek_mut(&mut self) -> Option<&mut T> {
        self.data.first_mut()
    }

    /// Inserts `value` while respecting the size bound.
    ///
    /// - While the heap is not full, `value` is appended and sifted up.
    /// - When full, `value` overwrites the root only if
    ///   `CommonHeapOrder::should_replace_root` approves; otherwise it is dropped.
    ///   With a bound of zero, every value is dropped.
    pub fn push(&mut self, value: T) {
        if self.data.len() < self.size {
            self.data.push(value);
            let last = self.data.len() - 1;
            self.bubble_up(last);
        } else if !self.data.is_empty() && self.order.should_replace_root(&self.data[0], &value) {
            self.data[0] = value;
            self.bubble_down(0);
        }
    }

    /// Removes and returns the root element, or `None` if the heap is empty.
    pub fn pop(&mut self) -> Option<T> {
        if self.data.is_empty() {
            return None;
        }
        // `swap_remove` moves the last element into slot 0 in O(1). With a single
        // element it simply empties the vector, and `bubble_down` is then a no-op.
        let root = self.data.swap_remove(0);
        self.bubble_down(0);
        Some(root)
    }

    /// Restores heap order after the element at `index` was modified in place
    /// (for example through `peek_mut`, `iter_mut` or `IndexMut`).
    ///
    /// Returns `false` if `index` is out of bounds, `true` otherwise.
    #[inline]
    pub fn update_at(&mut self, index: usize) -> bool {
        if index >= self.data.len() {
            return false;
        }
        // For a consistent order, a single changed element violates order with
        // either its parent or its children, not both. So try sinking first,
        // and rise only if it stayed put.
        if !self.bubble_down(index) {
            self.bubble_up(index);
        }
        true
    }

    /// Returns the backing storage in heap (array) order, which is not sorted.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    /// Iterates over the elements in heap (array) order.
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.data.iter()
    }

    /// Iterates mutably over the elements in heap (array) order.
    ///
    /// Mutations can break heap order. Call `update_at` on each modified index
    /// afterwards.
    #[inline]
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> {
        self.data.iter_mut()
    }

    /// Index of the left child of `i` in the implicit binary tree.
    #[inline(always)]
    fn left_child(i: usize) -> usize {
        2 * i + 1
    }

    /// Index of the right child of `i` in the implicit binary tree.
    #[inline(always)]
    fn right_child(i: usize) -> usize {
        2 * i + 2
    }

    /// Index of the parent of `i`. Returns 0 for the root instead of underflowing.
    #[inline(always)]
    fn parent(i: usize) -> usize {
        (i.saturating_sub(1)) / 2
    }

    /// Sifts the element at `idx` down until neither child should be swapped above it.
    ///
    /// Returns `true` if the element moved.
    fn bubble_down(&mut self, mut idx: usize) -> bool {
        let start_idx = idx;
        let len = self.data.len();
        while idx < len {
            let left = Self::left_child(idx);
            let right = Self::right_child(idx);
            // Pick whichever of {self, left, right} belongs on top.
            let mut target = idx;
            if left < len && self.order.should_swap(&self.data[target], &self.data[left]) {
                target = left;
            }
            if right < len && self.order.should_swap(&self.data[target], &self.data[right]) {
                target = right;
            }
            if target == idx {
                break;
            }
            self.data.swap(idx, target);
            idx = target;
        }
        idx != start_idx
    }

    /// Sifts the element at `idx` up while its parent should be swapped below it.
    fn bubble_up(&mut self, mut idx: usize) {
        while idx > 0 {
            let parent_idx = Self::parent(idx);
            if !self.order.should_swap(&self.data[parent_idx], &self.data[idx]) {
                break;
            }
            self.data.swap(parent_idx, idx);
            idx = parent_idx;
        }
    }
}
impl<T> CommonHeap<T, KeepSmallest>
where
    T: Ord,
{
    /// Creates an empty bounded heap whose root is its smallest element.
    ///
    /// Equivalent to `CommonHeap::with_capacity(capacity, KeepSmallest)`.
    #[inline]
    pub fn new_min(capacity: usize) -> Self {
        let order = KeepSmallest;
        Self::with_capacity(capacity, order)
    }
}
impl<T> CommonHeap<T, KeepLargest>
where
    T: Ord,
{
    /// Creates an empty bounded heap whose root is its largest element.
    ///
    /// Equivalent to `CommonHeap::with_capacity(capacity, KeepLargest)`.
    #[inline]
    pub fn new_max(capacity: usize) -> Self {
        let order = KeepLargest;
        Self::with_capacity(capacity, order)
    }
}
impl<T, O> Index<usize> for CommonHeap<T, O>
where
    O: CommonHeapOrder<T>,
{
    type Output = T;

    /// Returns the element stored at `index` in heap-array order (index 0 is the root).
    ///
    /// # Panics
    /// Panics if `index` is out of bounds.
    fn index(&self, index: usize) -> &Self::Output {
        Index::index(&self.data, index)
    }
}
impl<T, O> IndexMut<usize> for CommonHeap<T, O>
where
    O: CommonHeapOrder<T>,
{
    /// Returns mutable access to the element at `index` in heap-array order.
    ///
    /// Writing through this reference can break heap order. Call
    /// `update_at(index)` afterwards to restore it.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        IndexMut::index_mut(&mut self.data, index)
    }
}