use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Newtype that inverts the ordering of `T`.
///
/// `BinaryHeap` always pops its greatest element. Wrapping values in
/// `MinHeapItem` flips every comparison, so the heap yields the smallest `T`
/// first. Pairs for which `T::partial_cmp` returns `None` (e.g. NaN) are
/// reported as `Ordering::Equal`.
#[derive(Debug, Clone, PartialEq)]
pub struct MinHeapItem<T>(pub T);

// NOTE(review): `Eq` is asserted for any `T: PartialEq`, including floats
// where reflexivity does not hold for NaN; `BinaryHeap` only needs the bound.
impl<T: PartialEq> Eq for MinHeapItem<T> {}

impl<T: PartialOrd + PartialEq> Ord for MinHeapItem<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands deliberately swapped: a smaller inner value ranks higher.
        match other.0.partial_cmp(&self.0) {
            Some(order) => order,
            None => Ordering::Equal,
        }
    }
}

impl<T: PartialOrd + PartialEq> PartialOrd for MinHeapItem<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to the total `Ord` implementation above.
        Some(Ord::cmp(self, other))
    }
}
/// A min-heap: [`pop`](Self::pop) and [`peek`](Self::peek) yield the smallest element.
///
/// Implemented as a [`BinaryHeap`] (which is a max-heap) of [`MinHeapItem`],
/// whose ordering is the reverse of `T`'s. Elements only need `PartialOrd`;
/// pairs that are incomparable (e.g. NaN) are treated as equal, so their
/// relative order in the output is unspecified.
#[derive(Debug, Clone)]
pub struct MinHeap<T> {
    heap: BinaryHeap<MinHeapItem<T>>,
}

impl<T: PartialOrd + PartialEq> MinHeap<T> {
    /// Creates an empty heap.
    pub fn new() -> Self {
        Self {
            heap: BinaryHeap::new(),
        }
    }

    /// Creates an empty heap with room for at least `capacity` elements
    /// before reallocating.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            heap: BinaryHeap::with_capacity(capacity),
        }
    }

    /// Adds `item` to the heap.
    pub fn push(&mut self, item: T) {
        self.heap.push(MinHeapItem(item));
    }

    /// Removes and returns the smallest element, or `None` if the heap is empty.
    pub fn pop(&mut self) -> Option<T> {
        self.heap.pop().map(|MinHeapItem(item)| item)
    }

    /// Returns a reference to the smallest element without removing it.
    pub fn peek(&self) -> Option<&T> {
        self.heap.peek().map(|MinHeapItem(item)| item)
    }

    /// Number of elements currently stored.
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Returns `true` if the heap holds no elements.
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Removes all elements, keeping the allocated capacity.
    pub fn clear(&mut self) {
        self.heap.clear();
    }
}
impl<T: PartialOrd + PartialEq> Default for MinHeap<T> {
fn default() -> Self {
Self::new()
}
}
use std::collections::HashMap;
use std::hash::Hash;
/// Heap entry pairing a key with the priority it was pushed at.
///
/// Equality and ordering look only at `priority`. The ordering is reversed so
/// that `BinaryHeap` (a max-heap) pops the smallest priority first. It is built
/// on [`f64::total_cmp`], so it stays a genuine total order even when a
/// priority is NaN: a positive NaN ranks after `+inf` and is popped last.
/// A partial order that calls NaN "equal" to everything is not transitive and
/// can silently break the heap invariant for every other entry.
#[derive(Debug, Clone)]
struct IndexedItem<K> {
    key: K,
    priority: f64,
}

impl<K> PartialEq for IndexedItem<K> {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` always agree.
        self.cmp(other) == Ordering::Equal
    }
}

impl<K> Eq for IndexedItem<K> {}

impl<K> Ord for IndexedItem<K> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped: the smaller priority compares as "greater".
        other.priority.total_cmp(&self.priority)
    }
}

impl<K> PartialOrd for IndexedItem<K> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Min-priority queue over keys of type `K`, using lazy deletion.
///
/// `priorities` holds the current priority of every live key. Setting a
/// priority pushes a new heap entry instead of updating the old one in place.
/// Entries whose priority no longer matches `priorities` are stale and are
/// skipped by [`extract_min`](Self::extract_min). As a result the internal heap
/// may hold more entries than [`len`](Self::len) reports.
pub struct IndexedMinHeap<K: Clone + Hash + Eq> {
    heap: BinaryHeap<IndexedItem<K>>,
    priorities: HashMap<K, f64>,
}

impl<K: Clone + Hash + Eq> IndexedMinHeap<K> {
    /// Creates an empty heap.
    pub fn new() -> Self {
        Self {
            heap: BinaryHeap::new(),
            priorities: HashMap::new(),
        }
    }

    /// Inserts `key` with `priority`, replacing any priority it already had.
    pub fn insert(&mut self, key: K, priority: f64) {
        self.priorities.insert(key.clone(), priority);
        self.heap.push(IndexedItem { key, priority });
    }

    /// Removes and returns the live key with the smallest priority, or `None`
    /// when no live keys remain.
    ///
    /// Stale entries popped along the way are discarded. An entry is current
    /// only if its priority is bit-for-bit identical to the stored one. Every
    /// entry is an exact copy of the value written to `priorities` when it was
    /// pushed, so exact comparison is the correct test. A tolerance check would
    /// confuse nearby priorities and could never match `±inf` or NaN
    /// (`inf - inf` is NaN).
    pub fn extract_min(&mut self) -> Option<K> {
        while let Some(item) = self.heap.pop() {
            let is_current = match self.priorities.get(&item.key) {
                Some(current) => current.to_bits() == item.priority.to_bits(),
                // Key was already extracted; this entry is a leftover.
                None => false,
            };
            if is_current {
                self.priorities.remove(&item.key);
                return Some(item.key);
            }
        }
        None
    }

    /// Sets `key`'s priority to `new_priority`, inserting the key if absent.
    ///
    /// Despite the name, this does not check that the priority actually
    /// decreases; it behaves exactly like [`insert`](Self::insert).
    pub fn decrease_priority(&mut self, key: K, new_priority: f64) {
        self.insert(key, new_priority);
    }

    /// Returns `true` if no live keys remain.
    pub fn is_empty(&self) -> bool {
        self.priorities.is_empty()
    }

    /// Number of live keys (stale heap entries are not counted).
    pub fn len(&self) -> usize {
        self.priorities.len()
    }

    /// Returns `true` if `key` is currently live in the heap.
    pub fn contains(&self, key: &K) -> bool {
        self.priorities.contains_key(key)
    }

    /// Current priority of `key`, or `None` if it is not live.
    pub fn get_priority(&self, key: &K) -> Option<f64> {
        self.priorities.get(key).copied()
    }
}
impl<K: Clone + Hash + Eq> Default for IndexedMinHeap<K> {
fn default() -> Self {
Self::new()
}
}
/// Max-heap alias: the standard binary heap already pops its greatest element first.
pub type MaxHeap<T> = std::collections::BinaryHeap<T>;
/// A payload tagged with an `f32` score, compared by score alone.
///
/// Equality and ordering ignore `item` and compare `score` with
/// [`f32::total_cmp`]. This gives a total order that is consistent between
/// `Eq` and `Ord`: `-0.0` sorts before `+0.0`, a positive NaN sorts above
/// `+inf`, and a negative NaN sorts below `-inf`. A partial order that treats
/// NaN as equal to every score is not transitive and can corrupt heap or sort
/// order.
#[derive(Debug, Clone)]
pub struct ScoredItem<T> {
    /// Ranking key.
    pub score: f32,
    /// Payload carried alongside the score.
    pub item: T,
}

impl<T> ScoredItem<T> {
    /// Pairs `item` with `score`.
    pub fn new(score: f32, item: T) -> Self {
        Self { score, item }
    }
}

impl<T> PartialEq for ScoredItem<T> {
    fn eq(&self, other: &Self) -> bool {
        // Same relation as `Ord::cmp`, so `a == b` iff `a.cmp(&b) == Equal`.
        self.score.total_cmp(&other.score).is_eq()
    }
}

impl<T> Eq for ScoredItem<T> {}

impl<T> Ord for ScoredItem<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score.total_cmp(&other.score)
    }
}

impl<T> PartialOrd for ScoredItem<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
pub type MinScoredHeap<T> = MinHeap<ScoredItem<T>>;
pub type MaxScoredHeap<T> = BinaryHeap<ScoredItem<T>>;
pub struct KNearestHeap<T> {
heap: BinaryHeap<ScoredItem<T>>,
k: usize,
}
impl<T> KNearestHeap<T> {
pub fn new(k: usize) -> Self {
Self {
heap: BinaryHeap::with_capacity(k + 1),
k,
}
}
pub fn push(&mut self, score: f32, item: T) {
if self.heap.len() < self.k {
self.heap.push(ScoredItem::new(score, item));
} else if let Some(worst) = self.heap.peek() {
if score < worst.score {
self.heap.pop();
self.heap.push(ScoredItem::new(score, item));
}
}
}
pub fn worst_score(&self) -> Option<f32> {
self.heap.peek().map(|s| s.score)
}
pub fn is_full(&self) -> bool {
self.heap.len() >= self.k
}
pub fn len(&self) -> usize {
self.heap.len()
}
pub fn is_empty(&self) -> bool {
self.heap.is_empty()
}
pub fn into_sorted(self) -> Vec<(f32, T)> {
let mut items: Vec<_> = self.heap.into_iter().map(|s| (s.score, s.item)).collect();
items.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
items
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_min_heap() {
        let mut heap = MinHeap::new();
        heap.push(5);
        heap.push(3);
        heap.push(7);
        heap.push(1);
        assert_eq!(heap.pop(), Some(1));
        assert_eq!(heap.pop(), Some(3));
        assert_eq!(heap.pop(), Some(5));
        assert_eq!(heap.pop(), Some(7));
        assert_eq!(heap.pop(), None);
    }

    #[test]
    fn test_min_heap_peek_len_clear() {
        let mut heap = MinHeap::with_capacity(4);
        assert!(heap.is_empty());
        assert_eq!(heap.peek(), None);
        heap.push(2.5);
        heap.push(-1.0);
        assert_eq!(heap.len(), 2);
        assert_eq!(heap.peek(), Some(&-1.0));
        heap.clear();
        assert!(heap.is_empty());
        assert_eq!(heap.pop(), None);
    }

    #[test]
    fn test_min_scored_heap() {
        let mut heap: MinScoredHeap<&str> = MinHeap::new();
        heap.push(ScoredItem::new(2.0, "b"));
        heap.push(ScoredItem::new(1.0, "a"));
        heap.push(ScoredItem::new(3.0, "c"));
        let order: Vec<&str> = std::iter::from_fn(|| heap.pop()).map(|s| s.item).collect();
        assert_eq!(order, ["a", "b", "c"]);
    }

    #[test]
    fn test_indexed_min_heap() {
        let mut heap = IndexedMinHeap::new();
        heap.insert("a", 5.0);
        heap.insert("b", 3.0);
        heap.insert("c", 8.0);
        heap.decrease_priority("c", 1.0);
        assert_eq!(heap.len(), 3);
        assert_eq!(heap.get_priority(&"c"), Some(1.0));
        assert_eq!(heap.extract_min(), Some("c"));
        assert!(!heap.contains(&"c"));
        assert_eq!(heap.extract_min(), Some("b"));
        assert_eq!(heap.extract_min(), Some("a"));
        assert_eq!(heap.extract_min(), None);
        assert!(heap.is_empty());
    }

    #[test]
    fn test_k_nearest_heap() {
        let mut heap = KNearestHeap::new(3);
        heap.push(5.0, "e");
        heap.push(2.0, "b");
        heap.push(8.0, "h");
        heap.push(1.0, "a");
        heap.push(3.0, "c");
        assert_eq!(heap.len(), 3);
        let results = heap.into_sorted();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].1, "a");
        assert_eq!(results[1].1, "b");
        assert_eq!(results[2].1, "c");
    }

    #[test]
    fn test_k_nearest_heap_cutoff() {
        let mut heap = KNearestHeap::new(2);
        heap.push(10.0, "bad");
        heap.push(1.0, "good");
        heap.push(5.0, "ok");
        assert!(heap.is_full());
        assert_eq!(heap.worst_score(), Some(5.0));
        let results = heap.into_sorted();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|(s, _)| *s < 10.0));
    }

    #[test]
    fn test_k_nearest_heap_zero_k() {
        let mut heap = KNearestHeap::new(0);
        heap.push(1.0, "a");
        assert!(heap.is_empty());
        assert!(heap.is_full());
        assert_eq!(heap.worst_score(), None);
        assert!(heap.into_sorted().is_empty());
    }
}