use std::cmp::Ordering;
/// An element that can report a priority value.
///
/// NOTE(review): `BinaryHeap` in this file orders elements by their own `Ord`
/// implementation and never calls `priority()`; implementors should keep the
/// two consistent so queue order matches the reported priority.
pub trait PriorityElement {
    /// The priority value's type; must be totally ordered.
    type Priority: Ord;
    /// Returns this element's priority.
    fn priority(&self) -> Self::Priority;
}
/// A queue of [`PriorityElement`]s that hands elements out in priority order.
pub trait PriorityQueue<T: PriorityElement> {
    /// Adds `element`; if it cannot be stored it is handed back as `Err(element)`.
    fn enqueue(&mut self, element: T) -> Result<(), T>;
    /// Returns a reference to the next element to be dequeued, if any.
    fn peek(&self) -> Option<&T>;
    /// Removes and returns the next element, or `None` when empty.
    fn dequeue(&mut self) -> Option<T>;
    /// Returns the number of queued elements.
    fn len(&self) -> usize;
    /// Returns `true` when no elements are queued.
    fn is_empty(&self) -> bool;
}
/// Selects which end of the `Ord` ordering sits at the root of a `BinaryHeap`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HeapType {
    /// The greatest element (by `Ord`) is at the root.
    MaxHeap,
    /// The least element (by `Ord`) is at the root.
    MinHeap,
}
/// One slot of a `BinaryHeap`: the stored value plus position bookkeeping.
#[derive(Debug)]
struct HeapEntry<T> {
    /// The caller's value.
    data: T,
    /// Intended to mirror this entry's position in `BinaryHeap::data`.
    /// NOTE(review): this field is written but never read in this file;
    /// either use it (e.g. for handle-based key updates) or remove it.
    index: usize,
}
/// A binary heap stored in a `Vec`, usable as either a max-heap or a min-heap.
///
/// Elements are ordered by their `Ord` implementation; the [`HeapType`]
/// decides whether the greatest or the least element sits at the root. An
/// optional capacity limit bounds how many elements may be stored.
#[derive(Debug)]
pub struct BinaryHeap<T> {
    /// Heap-ordered entries; index 0 is the root.
    data: Vec<HeapEntry<T>>,
    /// Selects max-heap or min-heap ordering.
    heap_type: HeapType,
    /// Maximum number of elements, or `None` for an unbounded heap.
    capacity: Option<usize>,
}
impl<T: Ord> BinaryHeap<T> {
pub fn new(heap_type: HeapType) -> Self {
BinaryHeap {
data: Vec::new(),
heap_type,
capacity: None,
}
}
pub fn with_capacity(heap_type: HeapType, capacity: usize) -> Self {
BinaryHeap {
data: Vec::with_capacity(capacity),
heap_type,
capacity: Some(capacity),
}
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn is_full(&self) -> bool {
if let Some(capacity) = self.capacity {
self.len() >= capacity
} else {
false
}
}
pub fn peek(&self) -> Option<&T> {
self.data.first().map(|entry| &entry.data)
}
pub fn peek_mut(&mut self) -> Option<&mut T> {
self.data.first_mut().map(|entry| &mut entry.data)
}
pub fn push(&mut self, value: T) -> Result<(), T> {
if self.is_full() {
return Err(value);
}
self.data.push(HeapEntry {
data: value,
index: self.len(),
});
self.sift_up(self.len() - 1);
Ok(())
}
pub fn pop(&mut self) -> Option<T> {
if self.is_empty() {
return None;
}
let value = self.data.swap_remove(0);
if !self.is_empty() {
self.sift_down(0);
}
Some(value.data)
}
pub fn clear(&mut self) {
self.data.clear();
}
pub fn from_vec(mut vec: Vec<T>, heap_type: HeapType) -> Self {
let mut heap = BinaryHeap {
data: vec
.into_iter()
.map(|value| HeapEntry {
data: value,
index: 0,
})
.collect(),
heap_type,
capacity: None,
};
heap.heapify();
heap
}
pub fn heap_type(&self) -> HeapType {
self.heap_type
}
fn compare(&self, a: &T, b: &T) -> Ordering {
match self.heap_type {
HeapType::MaxHeap => a.cmp(b),
HeapType::MinHeap => b.cmp(a),
}
}
fn parent(index: usize) -> Option<usize> {
if index > 0 {
Some((index - 1) / 2)
} else {
None
}
}
fn left_child(index: usize) -> usize {
2 * index + 1
}
fn right_child(index: usize) -> usize {
2 * index + 2
}
fn sift_up(&mut self, mut index: usize) {
while let Some(parent) = Self::parent(index) {
if self
.compare(&self.data[index].data, &self.data[parent].data)
.is_gt()
{
self.data.swap(index, parent);
self.data[index].index = index;
self.data[parent].index = parent;
index = parent;
} else {
break;
}
}
}
fn sift_down(&mut self, mut index: usize) {
let len = self.len();
loop {
let left = Self::left_child(index);
let right = Self::right_child(index);
let mut largest = index;
if left < len
&& self
.compare(&self.data[left].data, &self.data[largest].data)
.is_gt()
{
largest = left;
}
if right < len
&& self
.compare(&self.data[right].data, &self.data[largest].data)
.is_gt()
{
largest = right;
}
if largest == index {
break;
}
self.data.swap(index, largest);
self.data[index].index = index;
self.data[largest].index = largest;
index = largest;
}
}
fn heapify(&mut self) {
if self.len() <= 1 {
return;
}
let last_parent = Self::parent(self.len() - 1).unwrap();
for i in (0..=last_parent).rev() {
self.sift_down(i);
}
}
pub fn decrease_key(&mut self, index: usize, new_value: T) -> Result<(), ()> {
if index >= self.len() {
return Err(());
}
let is_valid = match self.heap_type {
HeapType::MinHeap => new_value < self.data[index].data,
HeapType::MaxHeap => new_value > self.data[index].data,
};
if !is_valid {
return Err(());
}
self.data[index].data = new_value;
self.sift_up(index);
Ok(())
}
pub fn increase_key(&mut self, index: usize, new_value: T) -> Result<(), ()> {
if index >= self.len() {
return Err(());
}
if self.compare(&new_value, &self.data[index].data).is_lt() {
return Err(());
}
self.data[index].data = new_value;
self.sift_up(index);
Ok(())
}
pub fn update_key(&mut self, index: usize, new_value: T) -> Result<(), ()> {
if index >= self.len() {
return Err(());
}
match self.compare(&new_value, &self.data[index].data) {
Ordering::Less => self.decrease_key(index, new_value),
Ordering::Greater => self.increase_key(index, new_value),
Ordering::Equal => Ok(()),
}
}
pub fn remove(&mut self, index: usize) -> Option<T> {
if index >= self.len() {
return None;
}
let len = self.len();
if index == len - 1 {
return self.data.pop().map(|entry| entry.data);
}
self.data.swap(index, len - 1);
let removed = self.data.pop().map(|entry| entry.data);
if index < self.len() {
let parent_idx = Self::parent(index);
if let Some(parent) = parent_idx {
let parent_relation = match self.heap_type {
HeapType::MinHeap => self.data[index].data < self.data[parent].data,
HeapType::MaxHeap => self.data[index].data > self.data[parent].data,
};
if parent_relation {
self.sift_up(index);
return removed;
}
}
let left = Self::left_child(index);
let right = Self::right_child(index);
let mut need_sift_down = false;
if left < self.len() {
let left_relation = match self.heap_type {
HeapType::MinHeap => self.data[left].data < self.data[index].data,
HeapType::MaxHeap => self.data[left].data > self.data[index].data,
};
if left_relation {
need_sift_down = true;
}
}
if right < self.len() && !need_sift_down {
let right_relation = match self.heap_type {
HeapType::MinHeap => self.data[right].data < self.data[index].data,
HeapType::MaxHeap => self.data[right].data > self.data[index].data,
};
if right_relation {
need_sift_down = true;
}
}
if need_sift_down {
self.sift_down(index);
}
}
removed
}
pub fn get(&self, index: usize) -> Option<&T> {
self.data.get(index).map(|entry| &entry.data)
}
pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
self.data.get_mut(index).map(|entry| &mut entry.data)
}
pub fn merge(&mut self, other: Self) -> Result<(), Self> {
if let Some(cap) = self.capacity {
if self.len() + other.len() > cap {
return Err(other);
}
}
for entry in other.data {
if let Err(_) = self.push(entry.data) {
break;
}
}
Ok(())
}
fn is_valid_heap(&self) -> bool {
for i in 0..self.len() {
let left = Self::left_child(i);
let right = Self::right_child(i);
if left < self.len() {
if self
.compare(&self.data[left].data, &self.data[i].data)
.is_gt()
{
return false;
}
}
if right < self.len() {
if self
.compare(&self.data[right].data, &self.data[i].data)
.is_gt()
{
return false;
}
}
}
true
}
}
impl<T: Ord + PriorityElement> PriorityQueue<T> for BinaryHeap<T> {
fn enqueue(&mut self, element: T) -> Result<(), T> {
self.push(element)
}
fn peek(&self) -> Option<&T> {
self.peek()
}
fn dequeue(&mut self) -> Option<T> {
self.pop()
}
fn len(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.is_empty()
}
}
/// Collects an iterator into an unbounded **max**-heap.
///
/// Use `BinaryHeap::from_vec` directly to build a min-heap instead.
impl<T: Ord> FromIterator<T> for BinaryHeap<T> {
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        Self::from_vec(iter.into_iter().collect(), HeapType::MaxHeap)
    }
}
/// Pushes every item produced by the iterator.
///
/// `Extend` cannot report failure, so once a capacity-limited heap is full,
/// further items are dropped silently. The iterator is still consumed to the end.
impl<T: Ord> Extend<T> for BinaryHeap<T> {
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        iter.into_iter().for_each(|item| {
            // A rejected item (heap full) is discarded on purpose.
            let _ = self.push(item);
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pops until the heap is empty, returning values in the order they came out.
    fn drain<T: Ord>(heap: &mut BinaryHeap<T>) -> Vec<T> {
        std::iter::from_fn(|| heap.pop()).collect()
    }

    #[test]
    fn test_max_heap() {
        let mut heap = BinaryHeap::new(HeapType::MaxHeap);
        for value in [3, 1, 4, 2] {
            heap.push(value).unwrap();
        }
        assert_eq!(drain(&mut heap), vec![4, 3, 2, 1]);
        assert_eq!(heap.pop(), None);
    }

    #[test]
    fn test_min_heap() {
        let mut heap = BinaryHeap::new(HeapType::MinHeap);
        for value in [3, 1, 4, 2] {
            heap.push(value).unwrap();
        }
        assert_eq!(drain(&mut heap), vec![1, 2, 3, 4]);
        assert_eq!(heap.pop(), None);
    }

    #[test]
    fn test_capacity() {
        let mut heap = BinaryHeap::with_capacity(HeapType::MaxHeap, 3);
        for value in 1..=3 {
            assert!(heap.push(value).is_ok());
        }
        assert!(heap.push(4).is_err());
        heap.pop();
        assert!(heap.push(4).is_ok());
    }

    #[test]
    fn test_from_vec() {
        let mut heap = BinaryHeap::from_vec(vec![3, 1, 4, 2], HeapType::MaxHeap);
        assert_eq!(drain(&mut heap), vec![4, 3, 2, 1]);
    }

    #[test]
    fn test_peek() {
        let mut heap = BinaryHeap::new(HeapType::MaxHeap);
        assert_eq!(heap.peek(), None);
        heap.push(1).unwrap();
        heap.push(2).unwrap();
        assert_eq!(heap.peek(), Some(&2));
        if let Some(top) = heap.peek_mut() {
            *top = 3;
        }
        assert_eq!(heap.pop(), Some(3));
    }

    #[test]
    fn test_clear() {
        let mut heap = BinaryHeap::new(HeapType::MaxHeap);
        heap.push(1).unwrap();
        heap.push(2).unwrap();
        heap.clear();
        assert!(heap.is_empty());
        assert_eq!(heap.pop(), None);
    }

    #[test]
    fn test_from_iterator() {
        let heap: BinaryHeap<i32> = vec![3, 1, 4, 2].into_iter().collect();
        assert_eq!(heap.len(), 4);
        assert_eq!(heap.peek(), Some(&4));
    }

    #[test]
    fn test_extend() {
        let mut heap = BinaryHeap::new(HeapType::MaxHeap);
        heap.push(1).unwrap();
        heap.extend(vec![3, 2, 4]);
        assert_eq!(heap.len(), 4);
        assert_eq!(heap.peek(), Some(&4));
    }

    #[test]
    fn test_decrease_key() {
        let mut heap = BinaryHeap::new(HeapType::MinHeap);
        for value in [5, 3, 7] {
            heap.push(value).unwrap();
        }
        assert!(heap.decrease_key(0, 1).is_ok());
        assert_eq!(heap.peek(), Some(&1));
        assert!(heap.decrease_key(0, 6).is_err());
    }

    #[test]
    fn test_increase_key() {
        let mut heap = BinaryHeap::new(HeapType::MaxHeap);
        for value in [5, 3, 7] {
            heap.push(value).unwrap();
        }
        assert!(heap.increase_key(1, 8).is_ok());
        assert_eq!(heap.peek(), Some(&8));
        assert!(heap.increase_key(0, 4).is_err());
    }

    #[test]
    fn test_remove() {
        let mut heap = BinaryHeap::new(HeapType::MinHeap);
        for value in [5, 3, 7] {
            heap.push(value).unwrap();
        }
        assert_eq!(heap.remove(1), Some(5));
        assert_eq!(heap.len(), 2);
        assert!(heap.is_valid_heap());
    }

    #[test]
    fn test_merge() {
        let mut first = BinaryHeap::new(HeapType::MinHeap);
        first.push(1).unwrap();
        first.push(3).unwrap();
        let mut second = BinaryHeap::new(HeapType::MinHeap);
        second.push(2).unwrap();
        second.push(4).unwrap();
        assert!(first.merge(second).is_ok());
        assert_eq!(first.len(), 4);
        assert_eq!(drain(&mut first), vec![1, 2, 3, 4]);
    }

    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct Task {
        priority: i32,
        id: String,
    }

    impl PriorityElement for Task {
        type Priority = i32;
        fn priority(&self) -> Self::Priority {
            self.priority
        }
    }

    #[test]
    fn test_priority_queue() {
        let mut pq: BinaryHeap<Task> = BinaryHeap::new(HeapType::MaxHeap);
        for (priority, id) in [(1, "low"), (3, "high"), (2, "medium")] {
            pq.enqueue(Task {
                priority,
                id: id.to_string(),
            })
            .unwrap();
        }
        let order: Vec<i32> = std::iter::from_fn(|| pq.dequeue())
            .map(|task| task.priority)
            .collect();
        assert_eq!(order, vec![3, 2, 1]);
        assert!(pq.is_empty());
    }
}