use std::{
collections::HashMap,
fmt::{Debug, Display},
hash::Hash,
num::NonZeroUsize,
sync::Arc,
};
use crate::neighbor::{
Neighbor,
queue::{
BestCandidatesIterator, NeighborPriorityQueue, NeighborPriorityQueueIdType, NeighborQueue,
},
};
/// Bound alias for values usable as a diversity attribute.
///
/// Any type that is hashable, comparable, cheaply copyable, has a default, can be
/// printed with both `{:?}` and `{}`, and is thread-safe automatically qualifies via
/// the blanket impl below; nothing needs to implement it by hand.
pub trait Attribute: Hash + Eq + Copy + Default + Debug + Display + Send + Sync {}
impl<T: Hash + Eq + Copy + Default + Debug + Display + Send + Sync> Attribute for T {}
/// A vector id paired with its attribute value.
///
/// This is the element id stored in the global queue of [`DiverseNeighborQueue`], so
/// the attribute of any global entry can be read back without another provider lookup.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct VectorIdWithAttribute<I: NeighborPriorityQueueIdType, A: Attribute> {
    /// The vector id.
    pub id: I,
    /// The attribute value associated with `id`.
    pub attribute: A,
}
impl<I, A> VectorIdWithAttribute<I, A>
where
I: NeighborPriorityQueueIdType,
A: Attribute,
{
fn new(id: I, attribute: A) -> Self {
Self { id, attribute }
}
}
impl<I: NeighborPriorityQueueIdType, A: Attribute> Display for VectorIdWithAttribute<I, A> {
    /// Renders the pair as `(id, attribute)`, e.g. `(42, 7)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { id, attribute } = self;
        write!(f, "({id}, {attribute})")
    }
}
/// Candidate queue that limits how many results may share one attribute value.
///
/// Every admitted candidate is tracked twice: in a single global queue ordered by
/// distance, and in a bounded queue dedicated to the candidate's attribute value.
#[derive(Debug, Clone)]
pub struct DiverseNeighborQueue<P: AttributeValueProvider> {
    /// All admitted candidates, each tagged with its attribute value.
    global_queue: NeighborPriorityQueue<VectorIdWithAttribute<P::Id, P::Value>>,
    /// One bounded queue per attribute value encountered so far.
    local_queue_map: HashMap<P::Value, NeighborPriorityQueue<P::Id>>,
    /// Source of the attribute value for each candidate id.
    attribute_provider: Arc<P>,
    /// Capacity used when creating each per-attribute queue.
    diverse_results_l: usize,
    /// Number of entries per attribute value that `post_process` keeps.
    diverse_results_k: usize,
}
impl<P> DiverseNeighborQueue<P>
where
    P: AttributeValueProvider,
{
    /// Creates an empty diverse queue.
    ///
    /// * `l_value` - capacity of the global candidate queue.
    /// * `k_value` - number of results requested; only used to form the `l / k` ratio.
    /// * `diverse_results_k` - number of entries per attribute value retained by
    ///   [`post_process`](Self::post_process).
    /// * `attribute_provider` - looks up the attribute value of each candidate id.
    ///
    /// Each per-attribute queue is created with capacity
    /// `diverse_results_k * l_value / k_value` (integer division), i.e. it is
    /// over-provisioned by the same `l / k` factor as the global queue, with a floor of 1.
    pub fn new(
        l_value: usize,
        k_value: NonZeroUsize,
        diverse_results_k: usize,
        attribute_provider: Arc<P>,
    ) -> Self {
        // Integer division truncates: e.g. l = 3, k = 5, diverse k = 1 yields 0. A
        // zero-capacity per-attribute queue could never hold a candidate (and the
        // insert path indexes its last slot), so always keep at least one slot.
        let diverse_results_l = (diverse_results_k * l_value / k_value.get()).max(1);
        Self {
            global_queue: NeighborPriorityQueue::new(l_value),
            local_queue_map: HashMap::new(),
            attribute_provider,
            diverse_results_l,
            diverse_results_k,
        }
    }

    /// Applies the final per-attribute limit.
    ///
    /// Truncates every per-attribute queue to its best `diverse_results_k` entries and
    /// removes the truncated ids from the global queue, so the global queue keeps no id
    /// that was cut from its attribute's queue.
    pub fn post_process(&mut self) {
        use hashbrown::HashSet;
        let mut removed_items = HashSet::new();
        for local_queue in self.local_queue_map.values_mut() {
            if local_queue.size() > self.diverse_results_k {
                // Everything past the first `diverse_results_k` entries is dropped.
                removed_items.extend(
                    local_queue
                        .iter()
                        .skip(self.diverse_results_k)
                        .map(|n| n.id),
                );
                local_queue.truncate(self.diverse_results_k);
            }
        }
        // Skip the O(n) global pass when no bucket exceeded the limit.
        if !removed_items.is_empty() {
            self.global_queue
                .retain(|neighbor| !removed_items.contains(&neighbor.id.id));
        }
    }
}
impl<P> NeighborQueue<P::Id> for DiverseNeighborQueue<P>
where
    P: AttributeValueProvider,
{
    type Iter<'a>
        = BestCandidatesIterator<'a, P::Id, Self>
    where
        Self: 'a,
        P::Id: 'a;

    /// Offers `nbr` to the queue while enforcing the per-attribute capacity.
    ///
    /// Candidates whose attribute value the provider does not know (`None`) are ignored.
    /// The branches are written so that an entry leaves the global queue exactly when it
    /// leaves its attribute's queue:
    ///
    /// 1. The attribute's queue is full: `nbr` is admitted only if it beats that queue's
    ///    worst entry, which is then dropped from both queues.
    /// 2. The attribute's queue and the global queue both have room: `nbr` goes into both.
    /// 3. The attribute's queue has room but the global queue is full: `nbr` is admitted
    ///    only if it beats the global worst entry; the global insert evicts that entry,
    ///    and it is also removed from its own attribute's queue.
    fn insert(&mut self, nbr: Neighbor<P::Id>) {
        let Some(attribute_value) = self.attribute_provider.get(nbr.id) else {
            return;
        };
        let local_queue = self
            .local_queue_map
            .entry(attribute_value)
            .or_insert_with(|| NeighborPriorityQueue::new(self.diverse_results_l));
        let nbr_with_attribute = Neighbor::new(
            VectorIdWithAttribute::new(nbr.id, attribute_value),
            nbr.distance,
        );

        if local_queue.is_full() {
            // Index the actual last element instead of `diverse_results_l - 1`: they
            // agree for a full queue, but this cannot underflow for a zero-capacity
            // queue (which can never admit anything).
            let Some(last) = local_queue.size().checked_sub(1) else {
                return;
            };
            let worst_local = local_queue.get(last);
            if nbr.distance < worst_local.distance {
                self.global_queue.remove(Neighbor::new(
                    VectorIdWithAttribute::new(worst_local.id, attribute_value),
                    worst_local.distance,
                ));
                // The local queue is full, so this insert evicts `worst_local`.
                local_queue.insert(nbr);
                self.global_queue.insert(nbr_with_attribute);
            }
        } else if !self.global_queue.is_full() {
            local_queue.insert(nbr);
            self.global_queue.insert(nbr_with_attribute);
        } else {
            // Inserting into a full queue drops the element at `size() - 1`. Use that
            // index rather than `search_l() - 1`, which only names the same element
            // while the search window equals the capacity; `checked_sub` also guards a
            // zero-capacity global queue.
            let Some(last) = self.global_queue.size().checked_sub(1) else {
                return;
            };
            let worst_global = self.global_queue.get(last);
            if nbr.distance < worst_global.distance {
                local_queue.insert(nbr);
                self.global_queue.insert(nbr_with_attribute);
                // Keep the evicted entry's attribute queue in sync with the global queue.
                if let Some(evicted_from) =
                    self.local_queue_map.get_mut(&worst_global.id.attribute)
                {
                    evicted_from.remove(Neighbor::new(worst_global.id.id, worst_global.distance));
                }
            }
        }
    }

    /// Returns the `index`-th best global entry with its attribute stripped.
    fn get(&self, index: usize) -> Neighbor<P::Id> {
        let neighbor_with_attribute = self.global_queue.get(index);
        Neighbor::new(
            neighbor_with_attribute.id.id,
            neighbor_with_attribute.distance,
        )
    }

    /// Delegates to the global queue's `closest_notvisited`, stripping the attribute.
    fn closest_notvisited(&mut self) -> Option<Neighbor<P::Id>> {
        let neighbor_with_attribute = self.global_queue.closest_notvisited()?;
        Some(Neighbor::new(
            neighbor_with_attribute.id.id,
            neighbor_with_attribute.distance,
        ))
    }

    /// Delegates to the global queue.
    fn has_notvisited_node(&self) -> bool {
        self.global_queue.has_notvisited_node()
    }

    /// Number of entries in the global queue.
    fn size(&self) -> usize {
        self.global_queue.size()
    }

    /// Capacity of the global queue.
    fn capacity(&self) -> usize {
        self.global_queue.capacity()
    }

    /// Search list size of the global queue.
    fn search_l(&self) -> usize {
        self.global_queue.search_l()
    }

    /// Empties the global queue and discards every per-attribute queue.
    fn clear(&mut self) {
        self.global_queue.clear();
        self.local_queue_map.clear();
    }

    /// Iterates the best `min(search_l, size)` global entries, attributes stripped.
    fn iter(&self) -> BestCandidatesIterator<'_, P::Id, Self> {
        let sz = self.global_queue.search_l().min(self.global_queue.size());
        BestCandidatesIterator::new(sz, self)
    }
}
pub trait AttributeValueProvider: crate::provider::HasId + Send + Sync + std::fmt::Debug {
type Value: Attribute;
fn get(&self, id: Self::Id) -> Option<Self::Value>;
}
#[cfg(test)]
mod diverse_priority_queue_test {
    use super::*;

    /// Test double: attribute values stored in a plain `HashMap<id, attribute>`.
    #[derive(Debug, Clone)]
    struct TestAttributeValueProvider {
        attributes: HashMap<u32, u32>,
    }

    impl TestAttributeValueProvider {
        fn new() -> Self {
            Self {
                attributes: HashMap::new(),
            }
        }

        fn insert(&mut self, vector_id: u32, attribute_value: u32) {
            self.attributes.insert(vector_id, attribute_value);
        }
    }

    impl crate::provider::HasId for TestAttributeValueProvider {
        type Id = u32;
    }

    impl AttributeValueProvider for TestAttributeValueProvider {
        type Value = u32;

        fn get(&self, id: Self::Id) -> Option<Self::Value> {
            self.attributes.get(&id).copied()
        }
    }

    impl Default for TestAttributeValueProvider {
        fn default() -> Self {
            Self::new()
        }
    }

    type TestDiverseQueue = DiverseNeighborQueue<TestAttributeValueProvider>;

    /// Ids `0..20` with attribute `id / 3`: ids 0-2 -> 0, 3-5 -> 1, 6-8 -> 2, ...
    fn create_test_attribute_provider() -> Arc<TestAttributeValueProvider> {
        let mut provider = TestAttributeValueProvider::new();
        for i in 0..20 {
            provider.insert(i, i / 3);
        }
        Arc::new(provider)
    }

    #[test]
    fn test_new() {
        let attribute_provider = create_test_attribute_provider();
        let queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        assert_eq!(queue.size(), 0);
        assert_eq!(queue.capacity(), 10);
        assert_eq!(queue.search_l(), 10);
        // 5 * 10 / 5
        assert_eq!(queue.diverse_results_l, 10);
    }

    #[test]
    fn test_insert_single_attribute() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(1, 0.5));
        queue.insert(Neighbor::new(2, 1.5));
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.local_queue_map.len(), 1);
        assert!(queue.local_queue_map.contains_key(&0));
    }

    #[test]
    fn test_insert_multiple_attributes() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        // Ids 0, 3 and 6 carry attributes 0, 1 and 2 respectively.
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(3, 0.8));
        queue.insert(Neighbor::new(6, 1.2));
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.local_queue_map.len(), 3);
        assert!(queue.local_queue_map.contains_key(&0));
        assert!(queue.local_queue_map.contains_key(&1));
        assert!(queue.local_queue_map.contains_key(&2));
    }

    #[test]
    fn test_insert_maintains_order() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(1, 0.5));
        queue.insert(Neighbor::new(2, 1.5));
        assert_eq!(queue.get(0).id, 1);
        assert_eq!(queue.get(1).id, 0);
        assert_eq!(queue.get(2).id, 2);
    }

    #[test]
    fn test_insert_local_queue_full() {
        let mut attribute_provider = TestAttributeValueProvider::new();
        for i in 10..=15 {
            attribute_provider.insert(i, 0);
        }
        // l == k, so each per-attribute queue holds exactly diverse_results_k = 3 entries.
        let mut queue = TestDiverseQueue::new(
            20,
            NonZeroUsize::new(20).unwrap(),
            3,
            Arc::new(attribute_provider),
        );
        queue.insert(Neighbor::new(10, 1.0));
        queue.insert(Neighbor::new(11, 0.8));
        queue.insert(Neighbor::new(12, 1.2));
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.local_queue_map[&0].size(), 3);

        // Beats the bucket's worst entry (id 12), which must leave both queues.
        queue.insert(Neighbor::new(13, 0.5));
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.get(0).id, 13);
        assert_eq!(queue.local_queue_map[&0].size(), 3);
        assert!(queue.iter().all(|n| n.id != 12));
    }

    #[test]
    fn test_insert_global_queue_full() {
        let attribute_provider = create_test_attribute_provider();
        // Global capacity 3; per-attribute capacity 5 * 3 / 5 = 3.
        let mut queue =
            TestDiverseQueue::new(3, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(3, 0.8));
        queue.insert(Neighbor::new(6, 1.2));
        assert_eq!(queue.size(), 3);

        // Beats the global worst (id 6, attribute 2), which must leave both queues.
        queue.insert(Neighbor::new(9, 0.5));
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.get(0).id, 9);
        assert_eq!(queue.local_queue_map[&2].size(), 0);
        assert_eq!(queue.local_queue_map[&3].size(), 1);
    }

    #[test]
    fn test_get() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(1, 0.5));
        let nbr = queue.get(0);
        assert_eq!(nbr.id, 1);
        assert_eq!(nbr.distance, 0.5);
        let nbr = queue.get(1);
        assert_eq!(nbr.id, 0);
        assert_eq!(nbr.distance, 1.0);
    }

    #[test]
    fn test_closest_notvisited() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(1, 0.5));
        queue.insert(Neighbor::new(2, 1.5));
        assert!(queue.has_notvisited_node());
        let nbr = queue.closest_notvisited().unwrap();
        assert_eq!(nbr.id, 1);
        assert_eq!(nbr.distance, 0.5);
        assert!(queue.has_notvisited_node());
        let nbr = queue.closest_notvisited().unwrap();
        assert_eq!(nbr.id, 0);
        let nbr = queue.closest_notvisited().unwrap();
        assert_eq!(nbr.id, 2);
        assert!(!queue.has_notvisited_node());
        assert!(queue.closest_notvisited().is_none());
    }

    #[test]
    fn test_has_notvisited_node() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        assert!(!queue.has_notvisited_node());
        queue.insert(Neighbor::new(0, 1.0));
        assert!(queue.has_notvisited_node());
        assert!(queue.closest_notvisited().is_some());
        assert!(!queue.has_notvisited_node());
        assert!(queue.closest_notvisited().is_none());
    }

    #[test]
    fn test_size() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        assert_eq!(queue.size(), 0);
        queue.insert(Neighbor::new(0, 1.0));
        assert_eq!(queue.size(), 1);
        queue.insert(Neighbor::new(1, 0.5));
        assert_eq!(queue.size(), 2);
    }

    #[test]
    fn test_capacity() {
        let attribute_provider = create_test_attribute_provider();
        let queue = TestDiverseQueue::new(15, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        assert_eq!(queue.capacity(), 15);
    }

    #[test]
    fn test_search_l() {
        let attribute_provider = create_test_attribute_provider();
        let queue = TestDiverseQueue::new(20, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        assert_eq!(queue.search_l(), 20);
    }

    #[test]
    fn test_clear() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(3, 0.5));
        queue.insert(Neighbor::new(6, 1.5));
        assert_eq!(queue.size(), 3);
        assert_eq!(queue.local_queue_map.len(), 3);
        queue.clear();
        assert_eq!(queue.size(), 0);
        assert_eq!(queue.local_queue_map.len(), 0);
    }

    #[test]
    fn test_iter_candidates() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(1, 0.5));
        queue.insert(Neighbor::new(2, 1.5));
        let candidates: Vec<_> = queue.iter().collect();
        assert_eq!(candidates.len(), 3);
        assert_eq!(candidates[0].id, 1);
        assert_eq!(candidates[1].id, 0);
        assert_eq!(candidates[2].id, 2);
    }

    #[test]
    fn test_global_queue_direct_access() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
        queue.insert(Neighbor::new(0, 1.0));
        assert_eq!(queue.global_queue.size(), 1);
        queue.global_queue.clear();
        assert_eq!(queue.global_queue.size(), 0);
    }

    #[test]
    fn test_vector_id_with_attribute() {
        let vid_attr = VectorIdWithAttribute::new(42u32, 7);
        assert_eq!(vid_attr.id, 42);
        assert_eq!(vid_attr.attribute, 7);
        let formatted = format!("{}", vid_attr);
        assert_eq!(formatted, "(42, 7)");
    }

    #[test]
    fn test_attribute_value_provider() {
        let mut provider = TestAttributeValueProvider::new();
        assert_eq!(provider.get(0), None);
        provider.insert(0, 10);
        assert_eq!(provider.get(0), Some(10));
        provider.insert(5, 20);
        assert_eq!(provider.get(5), Some(20));
        provider.insert(0, 15);
        assert_eq!(provider.get(0), Some(15));
    }

    #[test]
    fn test_attribute_value_provider_default() {
        let provider = TestAttributeValueProvider::default();
        assert_eq!(provider.get(0), None);
    }

    #[test]
    fn test_diverse_queue_complex_scenario() {
        let attribute_provider = create_test_attribute_provider();
        let mut queue =
            TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 3, attribute_provider);
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(1, 0.5));
        queue.insert(Neighbor::new(2, 1.5));
        queue.insert(Neighbor::new(3, 0.8));
        queue.insert(Neighbor::new(4, 1.2));
        queue.insert(Neighbor::new(6, 0.7));
        assert_eq!(queue.size(), 6);

        // Same layout, except id 17 is moved into attribute bucket 0.
        let mut attribute_provider_updated = TestAttributeValueProvider::new();
        for i in 0..20 {
            attribute_provider_updated.insert(i, i / 3);
        }
        attribute_provider_updated.insert(17, 0);
        let mut queue2 = TestDiverseQueue::new(
            10,
            NonZeroUsize::new(5).unwrap(),
            3,
            Arc::new(attribute_provider_updated),
        );
        queue2.insert(Neighbor::new(0, 1.0));
        queue2.insert(Neighbor::new(1, 0.5));
        queue2.insert(Neighbor::new(2, 1.5));
        queue2.insert(Neighbor::new(3, 0.8));
        queue2.insert(Neighbor::new(4, 1.2));
        queue2.insert(Neighbor::new(6, 0.7));
        queue2.insert(Neighbor::new(17, 0.3));
        assert_eq!(queue2.get(0).id, 17);
        assert_eq!(queue2.get(0).distance, 0.3);
    }

    #[test]
    fn test_post_process() {
        let mut attribute_provider = TestAttributeValueProvider::new();
        for i in 0..9 {
            attribute_provider.insert(i, i / 3);
        }
        let mut queue = TestDiverseQueue::new(
            20,
            NonZeroUsize::new(5).unwrap(),
            2,
            Arc::new(attribute_provider),
        );
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(1, 0.5));
        queue.insert(Neighbor::new(2, 1.5));
        queue.insert(Neighbor::new(3, 0.8));
        queue.insert(Neighbor::new(4, 1.2));
        queue.insert(Neighbor::new(5, 0.6));
        queue.insert(Neighbor::new(6, 0.7));
        queue.insert(Neighbor::new(7, 1.1));
        queue.insert(Neighbor::new(8, 0.9));
        assert_eq!(queue.size(), 9);
        assert_eq!(queue.local_queue_map[&0].size(), 3);
        assert_eq!(queue.local_queue_map[&1].size(), 3);
        assert_eq!(queue.local_queue_map[&2].size(), 3);

        queue.post_process();
        assert_eq!(queue.local_queue_map[&0].size(), 2);
        assert_eq!(queue.local_queue_map[&1].size(), 2);
        assert_eq!(queue.local_queue_map[&2].size(), 2);
        assert_eq!(queue.size(), 6);
        assert_eq!(queue.local_queue_map[&0].get(0).id, 1);
        assert_eq!(queue.local_queue_map[&0].get(0).distance, 0.5);
        assert_eq!(queue.local_queue_map[&0].get(1).id, 0);
        assert_eq!(queue.local_queue_map[&0].get(1).distance, 1.0);
        assert_eq!(queue.local_queue_map[&1].get(0).id, 5);
        assert_eq!(queue.local_queue_map[&1].get(0).distance, 0.6);
        assert_eq!(queue.local_queue_map[&1].get(1).id, 3);
        assert_eq!(queue.local_queue_map[&1].get(1).distance, 0.8);
        assert_eq!(queue.local_queue_map[&2].get(0).id, 6);
        assert_eq!(queue.local_queue_map[&2].get(0).distance, 0.7);
        assert_eq!(queue.local_queue_map[&2].get(1).id, 8);
        assert_eq!(queue.local_queue_map[&2].get(1).distance, 0.9);

        // Global order by distance after dropping ids 2, 4 and 7.
        assert_eq!(queue.get(0).id, 1);
        assert_eq!(queue.get(1).id, 5);
        assert_eq!(queue.get(2).id, 6);
        assert_eq!(queue.get(3).id, 3);
        assert_eq!(queue.get(4).id, 8);
        assert_eq!(queue.get(5).id, 0);
    }

    #[test]
    fn test_skip_neighbors_without_attributes() {
        let mut attribute_provider = TestAttributeValueProvider::new();
        attribute_provider.insert(0, 0);
        attribute_provider.insert(1, 0);
        attribute_provider.insert(2, 1);
        attribute_provider.insert(4, 0);
        let mut queue = TestDiverseQueue::new(
            10,
            NonZeroUsize::new(5).unwrap(),
            5,
            Arc::new(attribute_provider),
        );
        queue.insert(Neighbor::new(0, 1.0));
        queue.insert(Neighbor::new(1, 0.5));
        queue.insert(Neighbor::new(2, 0.8));
        // Id 3 has no attribute and must be skipped.
        queue.insert(Neighbor::new(3, 0.3));
        queue.insert(Neighbor::new(4, 1.2));
        assert_eq!(queue.size(), 4, "Expected 4 items, ID 3 should be skipped");
        assert_eq!(
            queue.local_queue_map[&0].size(),
            3,
            "Attribute 0 should have 3 items"
        );
        assert_eq!(
            queue.local_queue_map[&1].size(),
            1,
            "Attribute 1 should have 1 item"
        );
        let ids: Vec<u32> = queue.iter().map(|n| n.id).collect();
        assert!(!ids.contains(&3), "ID 3 should not be in the queue");
        assert_eq!(
            ids,
            vec![1, 2, 0, 4],
            "Queue should contain IDs 1,2,0,4 in order of distance"
        );
    }

    #[test]
    fn test_attribute_zero_vs_missing_attribute() {
        let mut attribute_provider = TestAttributeValueProvider::new();
        attribute_provider.insert(0, 0);
        attribute_provider.insert(2, 0);
        let mut queue = TestDiverseQueue::new(
            10,
            NonZeroUsize::new(5).unwrap(),
            5,
            Arc::new(attribute_provider),
        );
        queue.insert(Neighbor::new(0, 1.0));
        // Id 1 has no attribute at all (distinct from attribute value 0).
        queue.insert(Neighbor::new(1, 0.5));
        queue.insert(Neighbor::new(2, 0.8));
        assert_eq!(queue.size(), 2);
        assert_eq!(queue.local_queue_map.len(), 1);
        assert!(queue.local_queue_map.contains_key(&0));
        assert_eq!(queue.local_queue_map[&0].size(), 2);
        let ids: Vec<u32> = queue.iter().map(|n| n.id).collect();
        assert_eq!(ids, vec![2, 0], "Queue should only contain IDs 2 and 0");
    }
}