use dashmap::DashMap;
use std::sync::atomic::{fence, AtomicU64, Ordering};
#[derive(Debug, Default)]
pub(super) struct HnswIndex {
pub(super) connections: DashMap<u32, Vec<u32>>,
entry_state: AtomicU64,
layers: DashMap<u32, usize>,
}
impl HnswIndex {
#[inline]
fn pack_state(entry_point: u32, max_layer: u32) -> u64 {
((entry_point as u64) << 32) | (max_layer as u64)
}
#[inline]
fn unpack_state(state: u64) -> (u32, u32) {
let entry_point = (state >> 32) as u32;
let max_layer = (state & 0xFFFFFFFF) as u32;
(entry_point, max_layer)
}
pub(super) fn new() -> Self {
Self {
connections: DashMap::new(),
entry_state: AtomicU64::new(Self::pack_state(u32::MAX, 0)),
layers: DashMap::new(),
}
}
pub(super) fn is_empty(&self) -> bool {
let (entry_point, _) = Self::unpack_state(self.entry_state.load(Ordering::Acquire));
entry_point == u32::MAX
}
pub(super) fn get_entry_point(&self) -> Option<u32> {
let (entry_point, _) = Self::unpack_state(self.entry_state.load(Ordering::Acquire));
if entry_point == u32::MAX {
None
} else {
Some(entry_point)
}
}
pub(super) fn add_vector(&self, vector_id: u32, layer: usize, connections: Vec<u32>) {
self.layers.insert(vector_id, layer);
self.connections.insert(vector_id, connections);
loop {
let current = self.entry_state.load(Ordering::Acquire);
let (current_ep, current_max) = Self::unpack_state(current);
let should_update = current_ep == u32::MAX || layer as u32 >= current_max;
if !should_update {
break;
}
let new_max = current_max.max(layer as u32);
let new_state = Self::pack_state(vector_id, new_max);
match self.entry_state.compare_exchange(
current,
new_state,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break,
Err(_) => continue, }
}
}
pub(super) fn remove_vector(&self, vector_id: u32) {
self.layers.remove(&vector_id);
self.connections.remove(&vector_id);
for mut entry in self.connections.iter_mut() {
let connections = entry.value_mut();
connections.retain(|&id| id != vector_id);
}
loop {
let current = self.entry_state.load(Ordering::Acquire);
let (current_ep, _) = Self::unpack_state(current);
if current_ep != vector_id {
break;
}
let mut new_max_layer: u32 = 0;
let mut new_entry_point: u32 = u32::MAX;
for entry in self.layers.iter() {
let layer = *entry.value() as u32;
if layer >= new_max_layer {
new_max_layer = layer;
new_entry_point = *entry.key();
}
}
let new_state = Self::pack_state(new_entry_point, new_max_layer);
match self.entry_state.compare_exchange(
current,
new_state,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break,
Err(_) => continue, }
}
}
pub(super) fn get_connections(&self, vector_id: u32) -> Option<Vec<u32>> {
self.connections.get(&vector_id).map(|entry| entry.clone())
}
}