/// A fixed-capacity max-heap that retains the `limit` entries with the smallest
/// distance, as scored by the `dist` closure.
///
/// The root is always the *farthest* retained entry, so a new entry only has to be
/// compared against the root to decide whether it belongs in the retained set.
///
/// Distances are compared with a total order in which NaN (of either sign) counts
/// as farther than every number, so a NaN score is the first to be evicted and can
/// never block closer entries from being kept.
///
/// `dist` is re-evaluated during heap maintenance rather than cached, so it should
/// be cheap and return the same value for the same entry.
pub struct BoundedHeap<T, D: Fn(&T) -> f64> {
    /// Heap-ordered storage; `data[0]` holds the entry with the largest distance.
    data: Vec<T>,
    /// Maximum number of entries retained.
    limit: usize,
    /// Scoring function mapping an entry to its distance.
    dist: D,
}

/// Total order over distances: ordinary numeric order, with every NaN comparing
/// equal to other NaNs and greater than any number.
///
/// A total order is needed both for heap correctness and because `sort_by` may
/// panic when given a comparator that is not a total order.
fn compare_distances(a: f64, b: f64) -> std::cmp::Ordering {
    // `partial_cmp` only returns `None` when at least one side is NaN; in that case
    // rank by "is NaN" so NaN sorts after every number.
    a.partial_cmp(&b)
        .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
}

impl<T, D: Fn(&T) -> f64> BoundedHeap<T, D> {
    /// Creates an empty heap that will retain at most `limit` entries.
    ///
    /// A `limit` of zero yields a heap that discards every insertion.
    #[must_use]
    pub fn new(limit: usize, dist: D) -> Self {
        Self {
            data: Vec::with_capacity(limit),
            limit,
            dist,
        }
    }

    /// Offers `entry` to the heap.
    ///
    /// While fewer than `limit` entries are held, `entry` is always kept. Once full,
    /// `entry` replaces the current farthest entry only if it is strictly closer;
    /// on a tie the existing entry is kept.
    pub fn insert(&mut self, entry: T) {
        if self.data.len() < self.limit {
            self.data.push(entry);
            let last = self.data.len() - 1;
            self.sift_up(last);
            return;
        }
        let Some(root) = self.data.first() else {
            // Full but empty: only possible when `limit == 0`.
            return;
        };
        if compare_distances((self.dist)(&entry), (self.dist)(root)).is_lt() {
            self.data[0] = entry;
            self.sift_down(0);
        }
    }

    /// Returns the farthest retained entry (the heap root), or `None` if empty.
    #[must_use]
    pub fn peek(&self) -> Option<&T> {
        self.data.first()
    }

    /// Returns references to all retained entries, ordered nearest to farthest.
    ///
    /// Each entry's distance is computed once. The sort is stable, so entries with
    /// equal distance keep their relative heap order; NaN distances sort last.
    #[must_use]
    pub fn to_sorted(&self) -> Vec<&T> {
        let mut keyed: Vec<(f64, &T)> = self
            .data
            .iter()
            .map(|entry| ((self.dist)(entry), entry))
            .collect();
        keyed.sort_by(|a, b| compare_distances(a.0, b.0));
        keyed.into_iter().map(|(_, entry)| entry).collect()
    }

    /// Number of entries currently retained (never exceeds `limit`).
    #[must_use]
    pub const fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if no entries are retained.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Distance of the entry stored at heap index `i`.
    fn dist_at(&self, i: usize) -> f64 {
        (self.dist)(&self.data[i])
    }

    /// Moves the entry at index `i` toward the root until its parent is at least
    /// as far as it is.
    fn sift_up(&mut self, mut i: usize) {
        while i > 0 {
            let parent = (i - 1) / 2;
            if compare_distances(self.dist_at(parent), self.dist_at(i)).is_ge() {
                break;
            }
            self.data.swap(parent, i);
            i = parent;
        }
    }

    /// Moves the entry at index `i` toward the leaves until neither child is
    /// farther than it is.
    fn sift_down(&mut self, mut i: usize) {
        let n = self.data.len();
        loop {
            let mut largest = i;
            let l = 2 * i + 1;
            let r = 2 * i + 2;
            if l < n && compare_distances(self.dist_at(l), self.dist_at(largest)).is_gt() {
                largest = l;
            }
            if r < n && compare_distances(self.dist_at(r), self.dist_at(largest)).is_gt() {
                largest = r;
            }
            if largest == i {
                break;
            }
            self.data.swap(i, largest);
            i = largest;
        }
    }
}
#[cfg(test)]
// Tests deliberately panic on broken invariants; `expect` gives a descriptive message.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// Minimal payload: `id` identifies the entry and `dist` is its score.
    #[derive(Debug, Clone)]
    struct Item {
        id: u32,
        dist: f64,
    }

    /// Only the `limit` smallest distances survive, and `to_sorted` is nearest-first.
    #[test]
    fn keeps_k_nearest() {
        let mut h = BoundedHeap::new(3, |x: &Item| x.dist);
        h.insert(Item { id: 1, dist: 10.0 });
        h.insert(Item { id: 2, dist: 2.0 });
        h.insert(Item { id: 3, dist: 7.0 });
        h.insert(Item { id: 4, dist: 1.0 });
        h.insert(Item { id: 5, dist: 50.0 });
        assert_eq!(h.len(), 3);
        let sorted: Vec<u32> = h.to_sorted().iter().map(|x| x.id).collect();
        assert_eq!(sorted, vec![4, 2, 3]);
    }

    /// `peek` exposes the farthest retained entry.
    #[test]
    fn peek_is_max() {
        let mut h = BoundedHeap::new(2, |x: &Item| x.dist);
        h.insert(Item { id: 1, dist: 5.0 });
        h.insert(Item { id: 2, dist: 3.0 });
        assert_eq!(h.peek().expect("Heap should not be empty").id, 1);
    }

    /// With capacity one, a closer entry replaces the sole retained entry.
    #[test]
    fn limit_one() {
        let mut h = BoundedHeap::new(1, |x: &Item| x.dist);
        h.insert(Item { id: 1, dist: 9.0 });
        h.insert(Item { id: 2, dist: 3.0 });
        assert_eq!(h.len(), 1);
        assert_eq!(h.peek().expect("Heap should not be empty").id, 2);
    }

    /// With capacity zero, every insertion is discarded.
    #[test]
    fn limit_zero_keeps_nothing() {
        let mut h = BoundedHeap::new(0, |x: &Item| x.dist);
        h.insert(Item { id: 1, dist: 1.0 });
        assert!(h.is_empty());
        assert!(h.peek().is_none());
    }

    /// A freshly constructed heap is empty.
    #[test]
    fn is_empty_check() {
        let h: BoundedHeap<Item, _> = BoundedHeap::new(5, |x: &Item| x.dist);
        assert!(h.is_empty());
    }

    /// A capturing closure drives selection: entries nearest to `target` are kept.
    #[test]
    fn works_with_closure() {
        let target = 4.0;
        let mut h = BoundedHeap::new(2, move |x: &Item| (x.dist - target).abs());
        h.insert(Item { id: 1, dist: 5.5 }); // score 1.5
        h.insert(Item { id: 2, dist: 3.0 }); // score 1.0
        h.insert(Item { id: 3, dist: 4.5 }); // score 0.5
        h.insert(Item { id: 4, dist: 10.0 }); // score 6.0
        assert_eq!(h.len(), 2);
        let sorted: Vec<u32> = h.to_sorted().iter().map(|x| x.id).collect();
        assert_eq!(sorted, vec![3, 2]);
    }
}