#![allow(dead_code)]
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// A `(distance, id)` pair ordered by distance first, then by id.
///
/// Distances are compared with [`f32::total_cmp`], so every value, including
/// NaN and signed zeros, has a well-defined position (for example `-0.0`
/// orders before `+0.0`). Equality is defined through the same total order,
/// which keeps `PartialEq`, `Eq`, `PartialOrd` and `Ord` mutually consistent.
#[derive(Clone, Copy, Debug)]
pub struct DistId {
    /// Distance value; compared with `f32::total_cmp`.
    pub dist: f32,
    /// Identifier; breaks ties between entries whose distances compare equal.
    pub id: usize,
}

impl DistId {
    /// Creates a pair from a distance and an id.
    #[inline]
    pub fn new(dist: f32, id: usize) -> Self {
        Self { dist, id }
    }
}

impl PartialEq for DistId {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp`. A derived impl would use float `==`:
        // NaN would be unequal to itself (violating `Eq` reflexivity) and
        // `0.0 == -0.0` would hold even though `total_cmp` orders them apart.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DistId {}

impl PartialOrd for DistId {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DistId {
    /// Orders by `dist` (via `f32::total_cmp`), then by `id`.
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
/// Min-heap of [`DistId`] candidates: [`CandidateHeap::pop`] returns the
/// entry with the smallest distance first (ties go to the smaller id).
///
/// `BinaryHeap` is a max-heap, so each entry is wrapped in
/// `std::cmp::Reverse` to flip its ordering.
#[derive(Default)]
pub struct CandidateHeap(BinaryHeap<std::cmp::Reverse<DistId>>);

impl CandidateHeap {
    /// Returns an empty heap; no allocation happens until the first push.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns an empty heap with space reserved for at least `cap` entries.
    pub fn with_capacity(cap: usize) -> Self {
        let storage = BinaryHeap::with_capacity(cap);
        Self(storage)
    }

    /// Inserts a candidate with the given distance and id.
    #[inline]
    pub fn push(&mut self, dist: f32, id: usize) {
        let entry = DistId::new(dist, id);
        self.0.push(std::cmp::Reverse(entry));
    }

    /// Removes and returns the closest candidate, or `None` when empty.
    #[inline]
    pub fn pop(&mut self) -> Option<DistId> {
        let std::cmp::Reverse(closest) = self.0.pop()?;
        Some(closest)
    }

    /// Distance of the closest candidate without removing it.
    #[inline]
    pub fn peek_dist(&self) -> Option<f32> {
        self.0.peek().map(|wrapped| wrapped.0.dist)
    }

    /// Number of candidates currently held.
    #[inline]
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// `true` when no candidates are held.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Bounded collection that retains the `capacity` smallest [`DistId`] entries.
///
/// Internally a max-heap, so the root is the *worst* (largest) retained
/// entry. That root is what [`ResultHeap::peek_worst_dist`] reports and what
/// is evicted when a better entry arrives while the heap is full.
#[derive(Debug)]
pub struct ResultHeap {
    inner: BinaryHeap<DistId>,
    capacity: usize,
}

impl ResultHeap {
    /// Creates an empty heap that keeps at most `capacity` entries.
    ///
    /// With `capacity == 0` every [`push`](Self::push) is discarded.
    pub fn new(capacity: usize) -> Self {
        Self {
            // `push` never grows the heap past `capacity`, so no spare slot is needed.
            inner: BinaryHeap::with_capacity(capacity),
            capacity,
        }
    }

    /// Offers `(dist, id)` to the heap.
    ///
    /// While fewer than `capacity` entries are held, the pair is always kept.
    /// Once full, it replaces the current worst entry only if it orders
    /// strictly before it (by distance, then id); otherwise it is dropped.
    /// The retained set matches push-then-pop-worst, but this costs a single
    /// sift and the heap never holds more than `capacity` entries.
    #[inline]
    pub fn push(&mut self, dist: f32, id: usize) {
        let entry = DistId::new(dist, id);
        if self.inner.len() < self.capacity {
            self.inner.push(entry);
        } else if let Some(mut worst) = self.inner.peek_mut() {
            // Overwriting through `PeekMut` re-sifts the root when it is dropped.
            if entry < *worst {
                *worst = entry;
            }
        }
        // Remaining case: `capacity == 0`, so nothing can be retained.
    }

    /// Distance of the worst (largest) retained entry, or `None` when empty.
    #[inline]
    pub fn peek_worst_dist(&self) -> Option<f32> {
        self.inner.peek().map(|x| x.dist)
    }

    /// Removes and returns the worst (largest) retained entry.
    #[inline]
    pub fn pop_worst(&mut self) -> Option<DistId> {
        self.inner.pop()
    }

    /// Number of entries currently retained.
    #[inline]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when no entries are retained.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Consumes the heap and returns its entries in ascending order (best first).
    pub fn into_sorted_vec(self) -> Vec<DistId> {
        self.inner.into_sorted_vec()
    }

    /// Iterates over the retained entries in arbitrary (heap-internal) order.
    pub fn iter(&self) -> impl Iterator<Item = &DistId> {
        self.inner.iter()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn candidate_heap_min_order() {
        let mut h = CandidateHeap::new();
        h.push(3.0, 3);
        h.push(1.0, 1);
        h.push(2.0, 2);
        assert_eq!(h.pop().unwrap().dist, 1.0);
        assert_eq!(h.pop().unwrap().dist, 2.0);
        assert_eq!(h.pop().unwrap().dist, 3.0);
        assert!(h.pop().is_none());
    }

    /// Equal distances must come out in ascending id order.
    #[test]
    fn candidate_heap_breaks_ties_by_id() {
        let mut h = CandidateHeap::with_capacity(3);
        h.push(1.0, 7);
        h.push(1.0, 2);
        h.push(1.0, 5);
        assert_eq!(h.peek_dist(), Some(1.0));
        let ids: Vec<usize> = std::iter::from_fn(|| h.pop()).map(|x| x.id).collect();
        assert_eq!(ids, vec![2, 5, 7]);
        assert!(h.is_empty());
    }

    #[test]
    fn candidate_heap_empty() {
        let h = CandidateHeap::new();
        assert!(h.is_empty());
        assert_eq!(h.len(), 0);
        assert_eq!(h.peek_dist(), None);
    }

    #[test]
    fn result_heap_bounded() {
        let mut h = ResultHeap::new(3);
        for i in 0..6_u32 {
            h.push(i as f32, i as usize);
        }
        assert_eq!(h.len(), 3);
        let v = h.into_sorted_vec();
        assert_eq!(v[0].dist, 0.0);
        assert_eq!(v[1].dist, 1.0);
        assert_eq!(v[2].dist, 2.0);
    }

    /// Descending input forces every push after the heap fills to evict the
    /// current worst entry, which `result_heap_bounded` never exercises.
    #[test]
    fn result_heap_evicts_worst_on_better_insert() {
        let mut h = ResultHeap::new(3);
        for (i, d) in [5.0_f32, 4.0, 3.0, 2.0, 1.0, 6.0].iter().copied().enumerate() {
            h.push(d, i);
        }
        assert_eq!(h.len(), 3);
        assert_eq!(h.peek_worst_dist(), Some(3.0));
        let dists: Vec<f32> = h.into_sorted_vec().iter().map(|x| x.dist).collect();
        assert_eq!(dists, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn result_heap_zero_capacity_discards_everything() {
        let mut h = ResultHeap::new(0);
        h.push(1.0, 1);
        h.push(0.0, 0);
        assert!(h.is_empty());
        assert_eq!(h.peek_worst_dist(), None);
        assert!(h.pop_worst().is_none());
    }

    #[test]
    fn dist_id_orders_by_dist_then_id() {
        assert!(DistId::new(1.0, 9) < DistId::new(2.0, 0));
        assert!(DistId::new(1.0, 0) < DistId::new(1.0, 1));
        assert_eq!(
            DistId::new(1.0, 0).cmp(&DistId::new(1.0, 0)),
            std::cmp::Ordering::Equal
        );
    }
}