use std::{
fmt::{Debug, Display},
marker::PhantomData,
};
use diskann_wide::{SIMDMask, SIMDPartialOrd, SIMDVector};
use super::Neighbor;
/// Marker trait for the identifier type stored alongside each distance in a
/// neighbor queue.
///
/// It is blanket-implemented for every type that satisfies the listed bounds, so it
/// never needs a manual `impl`.
pub trait NeighborPriorityQueueIdType
where
    Self: Eq + Clone + Copy + Debug + Display + Send + Sync,
{
}

impl<T: Eq + Clone + Copy + Debug + Display + Send + Sync> NeighborPriorityQueueIdType for T {}
/// Common interface of the neighbor queues in this module.
///
/// Entries are addressed by position `0..size()`. `iter` yields entries as
/// [`Neighbor`] values through an exact-size iterator.
pub trait NeighborQueue<I: NeighborPriorityQueueIdType>: Debug + Send + Sync {
    /// Iterator type returned by [`NeighborQueue::iter`].
    type Iter<'a>: ExactSizeIterator<Item = Neighbor<I>> + Send + Sync
    where
        Self: 'a,
        I: 'a;

    // Read-only queries.

    /// Returns the entry stored at position `index`.
    fn get(&self, index: usize) -> Neighbor<I>;

    /// Returns the number of entries currently stored.
    fn size(&self) -> usize;

    /// Returns the current capacity of the queue.
    fn capacity(&self) -> usize;

    /// Returns the configured search list size.
    fn search_l(&self) -> usize;

    /// Reports whether [`NeighborQueue::closest_notvisited`] can still yield an entry.
    fn has_notvisited_node(&self) -> bool;

    /// Returns an iterator over the queue's entries.
    fn iter(&self) -> Self::Iter<'_>;

    // Mutation.

    /// Inserts `nbr` into the queue.
    fn insert(&mut self, nbr: Neighbor<I>);

    /// Returns the closest entry that has not been visited yet, marking it visited,
    /// or `None` if there is no such entry.
    fn closest_notvisited(&mut self) -> Option<Neighbor<I>>;

    /// Removes every entry.
    fn clear(&mut self);
}
/// A queue of neighbors kept sorted by ascending distance, with a per-entry
/// visited flag.
///
/// A fixed-capacity queue (`new`) evicts its farthest entry, or rejects the newcomer,
/// once it is full. An auto-resizable queue (`auto_resizable_with_search_param_l`)
/// grows its capacity instead.
#[derive(Debug, Clone)]
pub struct NeighborPriorityQueue<I: NeighborPriorityQueueIdType> {
    // Number of stored entries; kept equal to `id_visiteds.len()` and `distances.len()`.
    size: usize,
    // Entry limit: reaching it triggers eviction (fixed) or growth (auto-resizable).
    capacity: usize,
    // Position from which `closest_notvisited` resumes; entries before it are visited.
    cursor: usize,
    // `(id, visited)` pairs, index-aligned with `distances`.
    id_visiteds: Vec<(I, bool)>,
    // Distances in ascending order, index-aligned with `id_visiteds`.
    distances: Vec<f32>,
    // When true, a full queue grows on insert instead of evicting.
    auto_resizable: bool,
    // Search list size L: only the first `min(L, size)` entries are iterated/visited.
    search_param_l: usize,
}
impl<I: NeighborPriorityQueueIdType> NeighborPriorityQueue<I> {
    /// Creates a fixed-capacity queue that holds at most `search_param_l` neighbors.
    ///
    /// Once full, [`Self::insert`] either evicts the last (farthest) entry to make room
    /// or drops the new neighbor.
    pub fn new(search_param_l: usize) -> Self {
        Self::with_resizability(search_param_l, false)
    }

    /// Creates a queue that starts with capacity `search_param_l` and grows on demand
    /// instead of evicting entries.
    ///
    /// `search_param_l` still limits how many of the best entries are reported by
    /// [`Self::has_notvisited_node`], [`Self::closest_notvisited`] and [`Self::iter`].
    pub fn auto_resizable_with_search_param_l(search_param_l: usize) -> Self {
        Self::with_resizability(search_param_l, true)
    }

    /// Builds an empty queue with room for `search_param_l` entries.
    fn with_resizability(search_param_l: usize, auto_resizable: bool) -> Self {
        Self {
            size: 0,
            capacity: search_param_l,
            cursor: 0,
            id_visiteds: Vec::with_capacity(search_param_l),
            distances: Vec::with_capacity(search_param_l),
            auto_resizable,
            search_param_l,
        }
    }

    /// Inserts `nbr` as an unvisited entry, keeping entries sorted by ascending distance.
    ///
    /// * A neighbor whose distance is NaN is ignored.
    /// * A full auto-resizable queue first grows its capacity by `max(1, capacity / 2)`.
    /// * A full fixed-capacity queue ignores `nbr` when its last entry compares less
    ///   than `nbr`; otherwise the last entry is evicted. A fixed queue with capacity 0
    ///   ignores every insertion.
    ///
    /// If the new entry lands before the cursor, the cursor moves back to it.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if an entry with the same id is already present.
    pub fn insert(&mut self, nbr: Neighbor<I>) {
        if nbr.distance.is_nan() {
            return;
        }
        self.dbgassert_unique_insert(nbr.id);

        if self.auto_resizable {
            if self.size == self.capacity {
                self.reserve(1.max(self.capacity >> 1));
            }
        } else if self.size == self.capacity {
            // A zero-capacity fixed queue has no last entry to compare against (and
            // `self.size - 1` would underflow), so it can never accept anything.
            if self.size == 0 || self.get_unchecked(self.size - 1) < nbr {
                return;
            }
        }

        let insert_idx = if self.size > 0 {
            self.get_lower_bound(&nbr)
        } else {
            0
        };

        // Only a fixed-capacity queue can still be full here: make room by evicting
        // the last entry.
        if self.size == self.capacity {
            if insert_idx == self.size {
                // `nbr` would sort after the entry being evicted; keep the queue as is.
                // NOTE(review): unreachable when `Neighbor`'s ordering agrees with the
                // distance comparison in `get_lower_bound`; it prevents a `Vec::insert`
                // out-of-bounds panic if the two ever disagree.
                return;
            }
            self.id_visiteds.truncate(self.size - 1);
            self.distances.truncate(self.size - 1);
            self.size -= 1;
        }

        self.id_visiteds.insert(insert_idx, (nbr.id, false));
        self.distances.insert(insert_idx, nbr.distance);
        self.size += 1;
        debug_assert!(self.size == self.id_visiteds.len());
        debug_assert!(self.size == self.distances.len());

        // The new entry is unvisited, so it becomes the next candidate if it lands
        // before the cursor.
        if insert_idx < self.cursor {
            self.cursor = insert_idx;
        }
    }

    /// Removes the `count` closest entries (or every entry, if fewer are stored).
    ///
    /// The cursor is reset to 0 while the visited flags of the remaining entries are
    /// left unchanged. An entry that was already visited can therefore be returned
    /// again by [`Self::closest_notvisited`].
    pub fn drain_best(&mut self, count: usize) {
        let count = count.min(self.size);
        self.id_visiteds.drain(0..count);
        self.distances.drain(0..count);
        self.size -= count;
        self.cursor = 0;
    }

    /// Iterates over the best `min(search_l, size)` entries in ascending distance order.
    pub fn iter(&self) -> BestCandidatesIterator<'_, I, Self> {
        let sz = self.search_param_l.min(self.size);
        BestCandidatesIterator::new(sz, self)
    }

    /// Removes the entry with id `nbr.id`, returning whether one was removed.
    ///
    /// The lookup starts at the first entry whose distance is `>= nbr.distance`. It
    /// checks every entry that shares that entry's distance, so entries tied on
    /// distance can each be removed regardless of their relative order.
    pub fn remove(&mut self, nbr: Neighbor<I>) -> bool {
        if self.size == 0 {
            return false;
        }
        let start = self.get_lower_bound(&nbr);
        let Some(&run_distance) = self.distances.get(start) else {
            return false;
        };
        // Entries with equal distances can appear in any order, so scan the whole run
        // of ties instead of only its first element.
        let found = (start..self.size)
            .take_while(|&i| self.distances[i] == run_distance)
            .find(|&i| self.id_visiteds[i].0 == nbr.id);
        let Some(index) = found else {
            return false;
        };

        self.id_visiteds.remove(index);
        self.distances.remove(index);
        self.size -= 1;
        if index < self.cursor {
            self.cursor -= 1;
        }
        // Removing the entry *at* the cursor slides its successor into place, and that
        // successor may already be visited; move the cursor past any such entries.
        self.skip_visited();
        debug_assert!(self.size == self.id_visiteds.len());
        debug_assert!(self.size == self.distances.len());
        true
    }

    /// Returns the index of the first entry whose distance is `>= nbr.distance`, or
    /// `self.size` when there is none.
    ///
    /// The scan is linear and vectorized. It handles 16 entries per iteration, then at
    /// most one 8-lane step, then a partial load for the remaining `< 8` entries.
    fn get_lower_bound(&self, nbr: &Neighbor<I>) -> usize {
        diskann_wide::alias!(f32s = f32x8);
        let target = f32s::splat(diskann_wide::ARCH, nbr.distance);
        let mut index = 0;
        while index + 16 <= self.size {
            // SAFETY: `index + 16 <= self.size == self.distances.len()`, so the eight
            // lanes starting at `index` are in bounds.
            let search =
                unsafe { f32s::load_simd(diskann_wide::ARCH, self.distances.as_ptr().add(index)) };
            let offset1 = search.ge_simd(target).first();
            // SAFETY: same bound as above; lanes `index + 8 .. index + 16` are in bounds.
            let search = unsafe {
                f32s::load_simd(diskann_wide::ARCH, self.distances.as_ptr().add(index + 8))
            };
            let offset2 = search.ge_simd(target).first();
            match (offset1, offset2) {
                (Some(offset), _) => return index + offset,
                (None, Some(offset)) => return index + 8 + offset,
                _ => (),
            }
            index += 16;
        }
        if index + 8 <= self.size {
            // SAFETY: `index + 8 <= self.size == self.distances.len()`.
            let search =
                unsafe { f32s::load_simd(diskann_wide::ARCH, self.distances.as_ptr().add(index)) };
            if let Some(offset) = search.ge_simd(target).first() {
                return index + offset;
            }
            index += 8;
        }
        if index < self.size {
            // SAFETY: only the first `self.size - index` lanes are loaded, all of which
            // lie within `self.distances`.
            let search = unsafe {
                f32s::load_simd_first(
                    diskann_wide::ARCH,
                    self.distances.as_ptr().add(index),
                    self.size - index,
                )
            };
            if let Some(offset) = search.ge_simd(target).first() {
                return index + offset;
            }
        }
        self.size
    }

    /// Returns the entry at `index` without a release-mode bounds check.
    ///
    /// Callers must guarantee `index < self.size`; this is only checked in debug builds.
    fn get_unchecked(&self, index: usize) -> Neighbor<I> {
        debug_assert!(index < self.size);
        // SAFETY: callers uphold `index < self.size`, and both vectors hold exactly
        // `self.size` elements.
        let id = unsafe { self.id_visiteds.get_unchecked(index).0 };
        // SAFETY: as above.
        let distance = unsafe { *self.distances.get_unchecked(index) };
        Neighbor::new(id, distance)
    }

    /// Returns the entry at position `index` (0 is the closest).
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.size()`.
    pub fn get(&self, index: usize) -> Neighbor<I> {
        assert!(index < self.size, "index out of bounds");
        self.get_unchecked(index)
    }

    /// Returns the entry at the cursor, marks it visited, and advances the cursor to
    /// the next unvisited entry. Returns `None` once
    /// [`Self::has_notvisited_node`] is false.
    pub fn closest_notvisited(&mut self) -> Option<Neighbor<I>> {
        if !self.has_notvisited_node() {
            return None;
        }
        let current = self.cursor;
        self.set_visited(current, true);
        self.cursor += 1;
        self.skip_visited();
        Some(self.get_unchecked(current))
    }

    /// Reports whether the cursor is still within the first `min(search_l, size)` entries.
    pub fn has_notvisited_node(&self) -> bool {
        self.cursor < self.search_param_l.min(self.size)
    }

    /// Returns the number of stored entries.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns the current capacity.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns the search list size `L`.
    pub fn search_l(&self) -> usize {
        self.search_param_l
    }

    /// Sets both the search list size and the capacity to `search_param_l`.
    ///
    /// Entries beyond the new limit are dropped and the cursor is clamped to it. When
    /// the limit grows past the current capacity, backing storage is reserved up front.
    pub fn reconfigure(&mut self, search_param_l: usize) {
        self.search_param_l = search_param_l;
        if search_param_l < self.size {
            self.id_visiteds.truncate(search_param_l);
            self.distances.truncate(search_param_l);
            self.size = search_param_l;
            self.cursor = self.cursor.min(search_param_l);
        } else if search_param_l > self.capacity {
            let additional = search_param_l - self.size;
            self.id_visiteds.reserve(additional);
            self.distances.reserve(additional);
        }
        self.capacity = search_param_l;
    }

    /// Grows the capacity by `additional` entries, reserving storage for them.
    fn reserve(&mut self, additional: usize) {
        self.id_visiteds.reserve(additional);
        self.distances.reserve(additional);
        self.capacity += additional;
    }

    /// Removes every entry and rewinds the cursor; capacity and `L` are kept.
    pub fn clear(&mut self) {
        self.id_visiteds.clear();
        self.distances.clear();
        self.size = 0;
        self.cursor = 0;
    }

    /// Sets the visited flag of the entry at `index`.
    fn set_visited(&mut self, index: usize, flag: bool) {
        assert!(index < self.size);
        assert!(self.size <= self.capacity);
        // SAFETY: `index < self.size == self.id_visiteds.len()` was asserted above.
        unsafe { self.id_visiteds.get_unchecked_mut(index) }.1 = flag;
    }

    /// Returns the visited flag of the entry at `index`.
    fn get_visited(&self, index: usize) -> bool {
        assert!(index < self.size);
        // SAFETY: `index < self.size == self.id_visiteds.len()` was asserted above.
        unsafe { self.id_visiteds.get_unchecked(index).1 }
    }

    /// Advances the cursor past entries that are already visited, stopping at the
    /// first unvisited entry or at `self.size`.
    fn skip_visited(&mut self) {
        while self.cursor < self.size && self.get_visited(self.cursor) {
            self.cursor += 1;
        }
    }

    /// Returns whether the queue grows on demand instead of evicting.
    pub fn is_resizable(&self) -> bool {
        self.auto_resizable
    }

    /// Returns `true` for a fixed-capacity queue holding `capacity` entries;
    /// auto-resizable queues never report full.
    pub fn is_full(&self) -> bool {
        !self.auto_resizable && self.size == self.capacity
    }

    /// Debug-only check that no stored entry already uses `id`.
    #[cfg(debug_assertions)]
    fn dbgassert_unique_insert(&self, id: I) {
        for (existing, _) in &self.id_visiteds {
            debug_assert!(
                *existing != id,
                "Neighbor with ID {} already exists in the priority queue",
                id
            );
        }
    }

    #[cfg(not(debug_assertions))]
    fn dbgassert_unique_insert(&self, _id: I) {}

    /// Keeps only the entries for which `f` returns `true`, preserving their order.
    ///
    /// Every surviving entry is marked unvisited and the cursor rewinds to 0. This
    /// happens even when no entry is removed, so the cursor always agrees with the
    /// cleared flags.
    #[cfg(feature = "experimental_diversity_search")]
    pub fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(&Neighbor<I>) -> bool,
    {
        let mut write_idx = 0;
        for read_idx in 0..self.size {
            let neighbor = self.get_unchecked(read_idx);
            if f(&neighbor) {
                // Compact in place; `write_idx <= read_idx`, so nothing unread is overwritten.
                self.id_visiteds[write_idx] = (neighbor.id, false);
                self.distances[write_idx] = neighbor.distance;
                write_idx += 1;
            }
        }
        self.truncate(write_idx);
        self.cursor = 0;
    }

    /// Shortens the queue to `len` entries; does nothing if `len >= size`.
    ///
    /// When entries are dropped the cursor resets to 0; visited flags of the kept
    /// entries are left unchanged.
    #[cfg(feature = "experimental_diversity_search")]
    pub fn truncate(&mut self, len: usize) {
        let new_size = len;
        if new_size < self.size {
            self.id_visiteds.truncate(new_size);
            self.distances.truncate(new_size);
            self.size = new_size;
            self.cursor = 0;
        }
    }
}
/// Forwards every [`NeighborQueue`] method to the inherent method of the same name.
///
/// Path-style calls (`Self::name(self, ..)`) resolve to inherent methods before trait
/// methods, so these are plain delegations, not recursive calls.
impl<I: NeighborPriorityQueueIdType> NeighborQueue<I> for NeighborPriorityQueue<I> {
    type Iter<'a>
        = BestCandidatesIterator<'a, I, Self>
    where
        Self: 'a,
        I: 'a;

    fn get(&self, index: usize) -> Neighbor<I> {
        Self::get(self, index)
    }

    fn size(&self) -> usize {
        Self::size(self)
    }

    fn capacity(&self) -> usize {
        Self::capacity(self)
    }

    fn search_l(&self) -> usize {
        Self::search_l(self)
    }

    fn has_notvisited_node(&self) -> bool {
        Self::has_notvisited_node(self)
    }

    fn iter(&self) -> Self::Iter<'_> {
        Self::iter(self)
    }

    fn insert(&mut self, nbr: Neighbor<I>) {
        Self::insert(self, nbr)
    }

    fn closest_notvisited(&mut self) -> Option<Neighbor<I>> {
        Self::closest_notvisited(self)
    }

    fn clear(&mut self) {
        Self::clear(self)
    }
}
impl<'a, I> IntoIterator for &'a NeighborPriorityQueue<I>
where
I: NeighborPriorityQueueIdType,
{
type Item = Neighbor<I>;
type IntoIter = BestCandidatesIterator<'a, I, NeighborPriorityQueue<I>>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Iterator over positions `0..size` of a [`NeighborQueue`], yielded in queue order.
pub struct BestCandidatesIterator<'a, I, Q>
where
    I: NeighborPriorityQueueIdType,
    Q: NeighborQueue<I> + ?Sized,
{
    /// Queue being read.
    queue: &'a Q,
    /// Position of the next entry to yield.
    cursor: usize,
    /// Total number of entries to yield; positions `>= size` are never read.
    size: usize,
    /// Ties the otherwise-unused `I` parameter to the struct.
    _phantom: PhantomData<I>,
}
impl<'a, I, Q> BestCandidatesIterator<'a, I, Q>
where
    I: NeighborPriorityQueueIdType,
    Q: NeighborQueue<I> + ?Sized,
{
    /// Creates an iterator that yields positions `0..size` of `queue`.
    ///
    /// `size` is not checked against `queue.size()` here. Each step calls
    /// `queue.get`, which panics on an out-of-range index for `NeighborPriorityQueue`.
    pub fn new(size: usize, queue: &'a Q) -> Self {
        let cursor = 0;
        Self {
            queue,
            size,
            cursor,
            _phantom: PhantomData,
        }
    }
}
impl<I, Q> Iterator for BestCandidatesIterator<'_, I, Q>
where
    I: NeighborPriorityQueueIdType,
    Q: NeighborQueue<I> + ?Sized,
{
    type Item = Neighbor<I>;

    /// Yields the entry at the next position, or `None` after `size` entries.
    fn next(&mut self) -> Option<Self::Item> {
        if self.cursor >= self.size {
            return None;
        }
        let neighbor = self.queue.get(self.cursor);
        self.cursor += 1;
        Some(neighbor)
    }

    /// Exact: the number of positions not yet yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.size - self.cursor;
        (left, Some(left))
    }
}
/// `size_hint` reports an exact count, so `len()` comes from the default implementation.
impl<I: NeighborPriorityQueueIdType, Q: NeighborQueue<I> + ?Sized> ExactSizeIterator
    for BestCandidatesIterator<'_, I, Q>
{
}
#[cfg(test)]
mod neighbor_priority_queue_test {
    use rand::{Rng, SeedableRng};

    use super::*;

    #[test]
    fn test_reconfigure() {
        let mut queue = NeighborPriorityQueue::<u32>::new(10);
        assert_eq!(queue.capacity(), 10);
        assert_eq!(queue.search_l(), 10);

        queue.reconfigure(20);
        assert_eq!(queue.capacity(), 20);
        assert_eq!(queue.search_l(), 20);

        // Reconfiguring to the current value changes nothing.
        queue.reconfigure(20);
        assert_eq!(queue.capacity(), 20);
        assert_eq!(queue.search_l(), 20);

        queue.reconfigure(10);
        assert_eq!(queue.capacity(), 10);
        assert_eq!(queue.search_l(), 10);
    }

    #[test]
    fn test_insert() {
        let mut queue = NeighborPriorityQueue::new(3);
        assert_eq!(queue.size(), 0);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        assert_eq!(queue.size(), 2);
        queue.insert(Neighbor::new(3, 0.9));
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.get(2).id, 1);

        // Full queue: a neighbor farther than every entry leaves it unchanged.
        queue.insert(Neighbor::new(4, 2.0));
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.get(0).id, 2);
        assert_eq!(queue.get(1).id, 3);
        assert_eq!(queue.get(2).id, 1);
    }

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_repeat_insert_panics() {
        let mut queue = NeighborPriorityQueue::new(10);
        assert_eq!(queue.size(), 0);
        queue.insert(Neighbor::new(1, 1.0));
        // Same id again: rejected by the debug-only uniqueness check.
        queue.insert(Neighbor::new(1, 0.5));
    }

    #[test]
    fn test_is_sorted() {
        let mut queue = NeighborPriorityQueue::new(40);
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let data: Vec<f32> = (0..60).map(|_| rng.random_range(-1.0..1.0)).collect();
        for (id, &distance) in (0u32..).zip(&data) {
            queue.insert(Neighbor::new(id, distance));
        }
        for i in 0..39 {
            assert!(queue.get(i).distance <= queue.get(i + 1).distance);
        }
    }

    #[test]
    fn test_index() {
        let mut queue = NeighborPriorityQueue::new(3);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        assert_eq!(queue.get(0).id, 2);
        assert_eq!(queue.get(0).distance, 0.5);
    }

    #[test]
    fn test_visit() {
        let mut queue = NeighborPriorityQueue::new(3);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        assert!(!queue.get_visited(0));
        queue.insert(Neighbor::new(3, 1.5));
        assert!(queue.has_notvisited_node());

        let nbr = queue.closest_notvisited().unwrap();
        assert_eq!(nbr.id, 2);
        assert_eq!(nbr.distance, 0.5);
        assert!(queue.get_visited(0));
        assert!(queue.has_notvisited_node());

        let nbr = queue.closest_notvisited().unwrap();
        assert_eq!(nbr.id, 1);
        assert_eq!(nbr.distance, 1.0);
        assert!(queue.get_visited(1));
        assert!(queue.has_notvisited_node());

        let nbr = queue.closest_notvisited().unwrap();
        assert_eq!(nbr.id, 3);
        assert_eq!(nbr.distance, 1.5);
        assert!(queue.get_visited(2));
        assert!(!queue.has_notvisited_node());
        assert!(queue.closest_notvisited().is_none());
    }

    #[test]
    fn test_closest_notvisited_when_no_notvisited_nodes_left() {
        let mut queue = NeighborPriorityQueue::new(1);
        queue.insert(Neighbor::new(1, 1.0));
        assert!(queue.closest_notvisited().is_some());
        assert!(queue.closest_notvisited().is_none());
    }

    #[test]
    fn test_clear_queue() {
        let mut queue = NeighborPriorityQueue::new(3);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        assert_eq!(queue.size(), 2);
        assert!(queue.has_notvisited_node());
        queue.clear();
        assert_eq!(queue.size(), 0);
        assert!(!queue.has_notvisited_node());
    }

    #[test]
    fn test_reserve() {
        let mut queue = NeighborPriorityQueue::<u32>::new(5);
        queue.reconfigure(10);
        assert_eq!(queue.id_visiteds.len(), 0);
        assert_eq!(queue.distances.len(), 0);
        assert_eq!(queue.capacity, 10);
    }

    #[test]
    fn test_set_capacity() {
        let mut queue = NeighborPriorityQueue::<u32>::new(10);
        queue.reconfigure(5);
        assert_eq!(queue.capacity, 5);
        assert_eq!(queue.id_visiteds.len(), 0);
        assert_eq!(queue.distances.len(), 0);
        queue.reconfigure(11);
        assert_eq!(queue.capacity, 11);
    }

    #[test]
    fn test_resizable_with_initial_capacity() {
        let resizable_queue = NeighborPriorityQueue::<u32>::auto_resizable_with_search_param_l(10);
        assert_eq!(resizable_queue.capacity(), 10);
        assert_eq!(resizable_queue.size(), 0);
        assert!(resizable_queue.auto_resizable);
        assert_eq!(resizable_queue.id_visiteds.len(), 0);
        assert_eq!(resizable_queue.distances.len(), 0);
    }

    #[test]
    fn test_insert_on_full_queue() {
        let mut fixed_queue = NeighborPriorityQueue::new(5);
        fixed_queue.insert(Neighbor::new(5, 0.5));
        fixed_queue.insert(Neighbor::new(2, 0.2));
        fixed_queue.insert(Neighbor::new(4, 0.4));
        fixed_queue.insert(Neighbor::new(1, 0.1));
        fixed_queue.insert(Neighbor::new(3, 0.3));
        // Farther than every entry: rejected.
        fixed_queue.insert(Neighbor::new(6, 0.6));
        assert_eq!(fixed_queue.get(4).id, 5);
        assert_eq!(fixed_queue.capacity(), 5);
        assert_eq!(fixed_queue.size(), 5);

        // Closer than the last entry: the last entry is evicted.
        fixed_queue.insert(Neighbor::new(35, 0.35));
        assert_eq!(fixed_queue.get(4).id, 4);
        assert_eq!(fixed_queue.capacity(), 5);
        assert_eq!(fixed_queue.size(), 5);
    }

    #[test]
    fn test_reconfigure_after_insert() {
        let mut queue = NeighborPriorityQueue::new(5);
        queue.insert(Neighbor::new(5, 0.5));
        queue.insert(Neighbor::new(2, 0.2));
        queue.insert(Neighbor::new(4, 0.4));
        queue.insert(Neighbor::new(1, 0.1));
        queue.insert(Neighbor::new(3, 0.3));
        assert!(queue.closest_notvisited().is_some());
        assert!(queue.closest_notvisited().is_some());
        assert!(queue.closest_notvisited().is_some());
        assert!(queue.closest_notvisited().is_some());
        assert_eq!(queue.capacity(), 5);
        assert_eq!(queue.size(), 5);
        assert_eq!(queue.search_l(), 5);
        assert_eq!(queue.cursor, 4);

        // Shrinking truncates entries and clamps the cursor.
        queue.reconfigure(3);
        assert_eq!(queue.capacity(), 3);
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.search_l(), 3);
        assert_eq!(queue.cursor, 3);

        // Growing keeps the existing entries and cursor.
        queue.reconfigure(5);
        assert_eq!(queue.capacity(), 5);
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.search_l(), 5);
        assert_eq!(queue.cursor, 3);
    }

    #[test]
    fn test_insert_on_resizable_queue() {
        let mut resizable_queue = NeighborPriorityQueue::auto_resizable_with_search_param_l(2);
        resizable_queue.insert(Neighbor::new(1, 1.0));
        resizable_queue.insert(Neighbor::new(2, 0.5));
        assert_eq!(resizable_queue.size(), 2);
        assert_eq!(resizable_queue.capacity(), 2);

        resizable_queue.insert(Neighbor::new(3, 0.9));
        assert_eq!(resizable_queue.size(), 3);
        assert_eq!(resizable_queue.capacity(), 3);

        resizable_queue.insert(Neighbor::new(4, 2.0));
        assert_eq!(resizable_queue.size(), 4);
        assert_eq!(resizable_queue.capacity(), 4);
    }

    #[test]
    fn test_iter() {
        let mut queue = NeighborPriorityQueue::<u32>::auto_resizable_with_search_param_l(3);
        assert_eq!(queue.size(), 0);
        let mut iter = (&queue).into_iter();
        let iter_dyn: &mut dyn ExactSizeIterator<Item = Neighbor<u32>> = &mut iter;
        assert_eq!(iter_dyn.len(), 0);
        assert!(iter_dyn.next().is_none());

        queue.insert(Neighbor::new(1, 1.0));
        assert_eq!(queue.size(), 1);
        let mut iter = queue.iter();
        assert_eq!(iter.len(), 1);
        assert_eq!(iter.next().unwrap().id, 1);
        assert!(iter.next().is_none());

        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 0.1));
        queue.insert(Neighbor::new(4, 5.0));
        queue.insert(Neighbor::new(5, 0.2));
        // Iteration is limited to the best `search_l` (3) entries.
        let mut iter = queue.iter();
        assert_eq!(iter.len(), 3);
        assert_eq!(iter.next().unwrap().id, 3);
        assert_eq!(iter.next().unwrap().id, 5);
        assert_eq!(iter.next().unwrap().id, 2);
        assert!(iter.next().is_none());

        for (i, neighbor) in (&queue).into_iter().enumerate() {
            assert_eq!(neighbor.id, queue.get(i).id);
        }

        queue.clear();
        let mut iter = queue.iter();
        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_has_notvisited_node_fixed_size_queue() {
        let mut queue = NeighborPriorityQueue::new(3);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        assert_queue_size_search_param_l_cursor(&queue, 2, 3, 0);
        assert!(queue.has_notvisited_node());
        assert!(queue.closest_notvisited().is_some());
        assert!(queue.has_notvisited_node());
        assert!(queue.closest_notvisited().is_some());
        assert_queue_size_search_param_l_cursor(&queue, 2, 3, 2);
        assert!(!queue.has_notvisited_node());

        // A closer neighbor pulls the cursor back; the 5.0 one is rejected (queue full).
        queue.insert(Neighbor::new(3, 0.1));
        queue.insert(Neighbor::new(4, 5.0));
        assert_queue_size_search_param_l_cursor(&queue, 3, 3, 0);
        assert!(queue.has_notvisited_node());
        assert!(queue.closest_notvisited().is_some());
        assert_queue_size_search_param_l_cursor(&queue, 3, 3, 3);
        assert!(!queue.has_notvisited_node());
    }

    #[test]
    fn test_has_notvisited_node_fixed_size_queue_with_manual_resize() {
        let mut queue = NeighborPriorityQueue::new(3);
        queue.reconfigure(5);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 0.1));
        queue.insert(Neighbor::new(4, 5.0));
        queue.insert(Neighbor::new(5, 0.2));
        assert_queue_size_search_param_l_cursor(&queue, 5, 5, 0);
        for i in 1..=5 {
            assert!(queue.has_notvisited_node());
            assert!(queue.closest_notvisited().is_some());
            assert_queue_size_search_param_l_cursor(&queue, 5, 5, i);
        }
        assert!(!queue.has_notvisited_node());
    }

    #[test]
    fn test_has_notvisited_auto_resizable_queue() {
        let mut queue = NeighborPriorityQueue::auto_resizable_with_search_param_l(3);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        assert_queue_size_search_param_l_cursor(&queue, 2, 3, 0);
        assert!(queue.has_notvisited_node());
        assert!(queue.closest_notvisited().is_some());
        assert!(queue.has_notvisited_node());
        assert!(queue.closest_notvisited().is_some());
        assert_queue_size_search_param_l_cursor(&queue, 2, 3, 2);
        assert!(!queue.has_notvisited_node());

        // The queue grows past `search_l`, but visiting stays limited to the first 3.
        queue.insert(Neighbor::new(3, 0.1));
        queue.insert(Neighbor::new(4, 5.0));
        assert_queue_size_search_param_l_cursor(&queue, 4, 3, 0);
        assert!(queue.has_notvisited_node());
        assert!(queue.closest_notvisited().is_some());
        assert_queue_size_search_param_l_cursor(&queue, 4, 3, 3);
        assert!(!queue.has_notvisited_node());
    }

    /// Asserts the queue's size, search list size, and cursor in one call.
    fn assert_queue_size_search_param_l_cursor(
        queue: &NeighborPriorityQueue<u32>,
        size: usize,
        search_param_l: usize,
        cursor: usize,
    ) {
        assert_eq!(queue.size(), size);
        assert_eq!(queue.search_param_l, search_param_l);
        assert_eq!(queue.cursor, cursor);
    }

    #[cfg(not(miri))]
    #[test]
    fn insertion_is_in_sorted_order() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let a: Vec<f32> = (0..100)
            .map(|_| rng.random_range(-1000000.0..1000000.0))
            .collect();
        for i in 0..100usize {
            let capacity = i + 1;
            let mut queue = NeighborPriorityQueue::new(capacity);
            for (j, &v) in a.iter().enumerate() {
                queue.insert(Neighbor::new(j as u32, v));
            }
            for j in 0..capacity - 1 {
                assert!(queue.get(j).distance <= queue.get(j + 1).distance);
            }
        }
    }

    #[test]
    fn test_trait_implementation_basic_operations() {
        let mut queue = NeighborPriorityQueue::new(5);
        assert_eq!(queue.size(), 0);
        assert_eq!(queue.capacity(), 5);
        assert_eq!(queue.search_l(), 5);
        assert!(!queue.has_notvisited_node());

        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        assert_eq!(queue.size(), 3);
        assert!(queue.has_notvisited_node());
        assert_eq!(queue.get(0).id, 2);
        assert_eq!(queue.get(1).id, 1);
        assert_eq!(queue.get(2).id, 3);

        let closest = queue.closest_notvisited().unwrap();
        assert_eq!(closest.id, 2);
        assert_eq!(closest.distance, 0.5);

        queue.clear();
        assert_eq!(queue.size(), 0);
        assert!(!queue.has_notvisited_node());

        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        let mut iter = NeighborQueue::iter(&queue);
        assert_eq!(iter.len(), 2);
        assert_eq!(iter.next().unwrap().id, 2);
        assert_eq!(iter.next().unwrap().id, 1);
    }

    #[test]
    fn test_trait_implementation_drain() {
        let mut queue = NeighborPriorityQueue::auto_resizable_with_search_param_l(3);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        queue.insert(Neighbor::new(4, 0.1));
        queue.insert(Neighbor::new(5, 2.0));
        assert_eq!(queue.size(), 5);

        // Drops ids 4, 2 and 1 (the three closest).
        queue.drain_best(3);
        assert_eq!(queue.size(), 2);
        assert_eq!(queue.get(0).id, 3);
        assert_eq!(queue.get(1).id, 5);
    }

    #[test]
    fn test_trait_implementation_reconfigure() {
        let mut queue = NeighborPriorityQueue::new(5);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        assert_eq!(queue.capacity(), 5);
        assert_eq!(queue.search_l(), 5);
        assert_eq!(queue.size(), 3);

        queue.reconfigure(2);
        assert_eq!(queue.capacity(), 2);
        assert_eq!(queue.search_l(), 2);
        assert_eq!(queue.size(), 2);

        // Growing does not bring truncated entries back.
        queue.reconfigure(10);
        assert_eq!(queue.capacity(), 10);
        assert_eq!(queue.search_l(), 10);
        assert_eq!(queue.size(), 2);
    }

    #[test]
    fn test_trait_polymorphism() {
        fn test_queue_operations<T: NeighborQueue<u32>>(mut queue: T) {
            queue.insert(Neighbor::new(1, 1.0));
            queue.insert(Neighbor::new(2, 0.5));
            assert_eq!(queue.size(), 2);
            assert!(queue.has_notvisited_node());
            let closest = queue.closest_notvisited().unwrap();
            assert_eq!(closest.id, 2);
        }
        let fixed_queue = NeighborPriorityQueue::new(5);
        test_queue_operations(fixed_queue);
        let resizable_queue = NeighborPriorityQueue::auto_resizable_with_search_param_l(5);
        test_queue_operations(resizable_queue);
    }

    #[test]
    fn test_remove() {
        let mut queue = NeighborPriorityQueue::new(10);
        assert!(!queue.remove(Neighbor::new(1, 1.0)));

        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        queue.insert(Neighbor::new(4, 0.3));
        queue.insert(Neighbor::new(5, 2.0));
        assert_eq!(queue.size(), 5);

        assert!(queue.remove(Neighbor::new(1, 1.0)));
        assert_eq!(queue.size(), 4);
        assert_eq!(queue.get(0).id, 4);
        assert_eq!(queue.get(1).id, 2);
        assert_eq!(queue.get(2).id, 3);
        assert_eq!(queue.get(3).id, 5);

        assert!(queue.remove(Neighbor::new(4, 0.3)));
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.get(0).id, 2);
        assert_eq!(queue.get(1).id, 3);
        assert_eq!(queue.get(2).id, 5);

        assert!(queue.remove(Neighbor::new(5, 2.0)));
        assert_eq!(queue.size(), 2);
        assert_eq!(queue.get(0).id, 2);
        assert_eq!(queue.get(1).id, 3);

        // Unknown id, and known id with a distance past every entry: both fail.
        assert!(!queue.remove(Neighbor::new(99, 0.5)));
        assert_eq!(queue.size(), 2);
        assert!(!queue.remove(Neighbor::new(2, 99.0)));
        assert_eq!(queue.size(), 2);

        assert!(queue.remove(Neighbor::new(2, 0.5)));
        assert_eq!(queue.size(), 1);
        assert!(queue.remove(Neighbor::new(3, 1.5)));
        assert_eq!(queue.size(), 0);
        assert!(!queue.remove(Neighbor::new(1, 1.0)));
    }

    #[test]
    fn test_remove_with_cursor() {
        let mut queue = NeighborPriorityQueue::new(10);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        queue.insert(Neighbor::new(4, 0.3));
        assert!(queue.closest_notvisited().is_some());
        assert!(queue.closest_notvisited().is_some());
        assert_eq!(queue.cursor, 2);

        // Removing an entry before the cursor shifts the cursor back.
        assert!(queue.remove(Neighbor::new(4, 0.3)));
        assert_eq!(queue.cursor, 1);
        assert_eq!(queue.size(), 3);

        // Removing the entry at the cursor leaves it in place.
        assert!(queue.remove(Neighbor::new(1, 1.0)));
        assert_eq!(queue.cursor, 1);
        assert_eq!(queue.size(), 2);
        assert_eq!(queue.get(0).id, 2);
        assert_eq!(queue.get(1).id, 3);
    }

    #[test]
    fn test_remove_maintains_sorted_order() {
        let mut queue = NeighborPriorityQueue::new(10);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        queue.insert(Neighbor::new(4, 0.3));
        queue.insert(Neighbor::new(5, 2.0));
        queue.insert(Neighbor::new(6, 0.8));
        queue.remove(Neighbor::new(3, 1.5));
        queue.remove(Neighbor::new(4, 0.3));
        for i in 0..queue.size() - 1 {
            assert!(queue.get(i).distance <= queue.get(i + 1).distance);
        }
        assert_eq!(queue.get(0).id, 2);
        assert_eq!(queue.get(1).id, 6);
        assert_eq!(queue.get(2).id, 1);
        assert_eq!(queue.get(3).id, 5);
    }

    #[test]
    fn test_insert_neighbors_with_infinity_distance() {
        let mut queue = NeighborPriorityQueue::new(5);
        assert_eq!(queue.size(), 0);
        assert_eq!(queue.capacity(), 5);
        for id in 0..2 {
            queue.insert(Neighbor::new(id, f32::INFINITY));
        }
        assert_eq!(queue.size(), 2);
        assert_eq!(queue.capacity(), 5);
        for id in 2..10 {
            queue.insert(Neighbor::new(id, f32::INFINITY));
        }
        assert_eq!(queue.size(), 5);
        assert_eq!(queue.capacity(), 5);
        assert!(
            queue.get(0).distance.is_infinite(),
            "First element should be retrievable and keep its infinite distance"
        );
    }

    #[test]
    fn test_normal_distances_should_push_infinity_distances_away_from_queue() {
        let mut queue = NeighborPriorityQueue::new(5);
        assert_eq!(queue.size(), 0);
        assert_eq!(queue.capacity(), 5);
        for id in 0..=4 {
            queue.insert(Neighbor::new(id, f32::INFINITY));
        }
        assert_eq!(queue.size(), 5);
        assert_eq!(queue.capacity(), 5);
        assert!(
            queue.get(0).distance.is_infinite(),
            "First element should be retrievable and keep its infinite distance"
        );

        for id in 5..=7 {
            queue.insert(Neighbor::new(id, id as f32));
        }
        assert_eq!(queue.size(), 5);
        assert_eq!(queue.capacity(), 5);
        assert_eq!(queue.get(0).id, 5);
        assert_eq!(queue.get(1).id, 6);
        assert_eq!(queue.get(2).id, 7);
        assert_eq!(queue.get(3).id, 4);
        assert_eq!(queue.get(4).id, 3);
    }

    #[test]
    fn test_insert_neighbor_with_nan_distance_is_ignored() {
        let mut queue = NeighborPriorityQueue::new(5);
        assert_eq!(queue.size(), 0);
        queue.insert(Neighbor::new(0, f32::NAN));
        assert_eq!(queue.size(), 0);
    }

    #[test]
    #[cfg(feature = "experimental_diversity_search")]
    fn test_retain() {
        let mut queue = NeighborPriorityQueue::<u32>::new(10);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        queue.insert(Neighbor::new(4, 0.3));
        queue.insert(Neighbor::new(5, 2.0));
        queue.insert(Neighbor::new(6, 0.8));
        assert_eq!(queue.size(), 6);

        queue.retain(|nbr| nbr.distance <= 1.0);
        assert_eq!(queue.size(), 4);
        assert_eq!(queue.get(0).id, 4);
        assert_eq!(queue.get(1).id, 2);
        assert_eq!(queue.get(2).id, 6);
        assert_eq!(queue.get(3).id, 1);

        queue.retain(|nbr| nbr.id >= 3);
        assert_eq!(queue.size(), 2);
        assert_eq!(queue.get(0).id, 4);
        assert_eq!(queue.get(1).id, 6);
    }

    #[test]
    #[cfg(feature = "experimental_diversity_search")]
    fn test_retain_empty() {
        let mut queue = NeighborPriorityQueue::<u32>::new(10);
        queue.retain(|_| true);
        assert_eq!(queue.size(), 0);
    }

    #[test]
    #[cfg(feature = "experimental_diversity_search")]
    fn test_retain_remove_all() {
        let mut queue = NeighborPriorityQueue::<u32>::new(10);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        queue.retain(|_| false);
        assert_eq!(queue.size(), 0);
        assert_eq!(queue.cursor, 0);
    }

    #[test]
    #[cfg(feature = "experimental_diversity_search")]
    fn test_retain_remove_none() {
        let mut queue = NeighborPriorityQueue::<u32>::new(10);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        queue.retain(|_| true);
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.get(0).id, 2);
        assert_eq!(queue.get(1).id, 1);
        assert_eq!(queue.get(2).id, 3);
    }

    #[test]
    #[cfg(feature = "experimental_diversity_search")]
    fn test_retain_resets_visited_state() {
        let mut queue = NeighborPriorityQueue::<u32>::new(10);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        queue.insert(Neighbor::new(4, 0.3));
        assert!(queue.closest_notvisited().is_some());
        assert!(queue.closest_notvisited().is_some());
        assert_eq!(queue.cursor, 2);
        assert!(!queue.has_notvisited_node() || queue.cursor < queue.size());

        // Retained entries become unvisited again and the cursor rewinds.
        queue.retain(|nbr| nbr.distance <= 1.0);
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.cursor, 0);
        assert!(queue.has_notvisited_node());

        let first = queue.closest_notvisited().unwrap();
        assert_eq!(first.id, 4);
        assert_eq!(queue.cursor, 1);
        let second = queue.closest_notvisited().unwrap();
        assert_eq!(second.id, 2);
        assert_eq!(queue.cursor, 2);
        let third = queue.closest_notvisited().unwrap();
        assert_eq!(third.id, 1);
        assert_eq!(queue.cursor, 3);
    }

    #[test]
    #[cfg(feature = "experimental_diversity_search")]
    fn test_truncate() {
        let mut queue = NeighborPriorityQueue::<u32>::new(10);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        queue.insert(Neighbor::new(4, 0.3));
        queue.insert(Neighbor::new(5, 2.0));
        assert_eq!(queue.size(), 5);
        queue.truncate(3);
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.get(0).id, 4);
        assert_eq!(queue.get(1).id, 2);
        assert_eq!(queue.get(2).id, 1);
    }

    #[test]
    #[cfg(feature = "experimental_diversity_search")]
    fn test_truncate_larger_size() {
        let mut queue = NeighborPriorityQueue::<u32>::new(10);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.truncate(10);
        assert_eq!(queue.size(), 2);
    }

    #[test]
    #[cfg(feature = "experimental_diversity_search")]
    fn test_truncate_with_cursor() {
        let mut queue = NeighborPriorityQueue::<u32>::new(10);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        queue.insert(Neighbor::new(4, 0.3));
        assert!(queue.closest_notvisited().is_some());
        assert!(queue.closest_notvisited().is_some());
        assert_eq!(queue.cursor, 2);
        queue.truncate(1);
        assert_eq!(queue.size(), 1);
        assert_eq!(queue.cursor, 0);
    }
}