use crate::Result;
use ruvix_region::backing::MemoryBacking;
use ruvix_region::slab::{SlabAllocator, SlotHandle};
/// Tuning parameters for an HNSW graph.
///
/// `#[repr(C)]` fixes the field order and layout of this struct.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct HnswConfig {
    /// Maximum number of links for a node on any layer above 0.
    pub m: u16,
    /// Maximum number of links for a node on layer 0. `HnswConfig::new`
    /// derives it from `m`; the `Default` value is 32.
    pub m0: u16,
    /// Presumably the candidate-list width used while inserting nodes
    /// (HNSW `efConstruction`); not read by any code shown here.
    pub ef_construction: u16,
    /// Presumably the candidate-list width used while searching
    /// (HNSW `efSearch`); not read by any code shown here.
    pub ef_search: u16,
    /// Presumably an upper bound on node layers; not read by any code shown
    /// here.
    pub max_layers: u8,
}
impl Default for HnswConfig {
fn default() -> Self {
Self {
m: 16,
m0: 32,
ef_construction: 200,
ef_search: 50,
max_layers: 16,
}
}
}
impl HnswConfig {
    /// Creates a configuration with per-layer link limit `m` and construction
    /// width `ef_construction`.
    ///
    /// The layer-0 limit `m0` is `2 * m`, saturating at `u16::MAX` rather than
    /// overflowing. `ef_search` is set to 50 and `max_layers` to 16.
    #[inline]
    #[must_use]
    pub const fn new(m: u16, ef_construction: u16) -> Self {
        Self {
            m,
            // A plain `m * 2` overflows for m > 32767, which panics in debug
            // builds and const evaluation and silently wraps in release.
            m0: m.saturating_mul(2),
            ef_construction,
            ef_search: 50,
            max_layers: 16,
        }
    }

    /// Returns this configuration with `ef_search` replaced.
    #[inline]
    #[must_use]
    pub const fn with_ef_search(mut self, ef_search: u16) -> Self {
        self.ef_search = ef_search;
        self
    }

    /// Estimated per-node slot size in bytes: a 16-byte header plus 8 bytes
    /// for each of the `m0` layer-0 links.
    ///
    /// NOTE(review): `HnswRegion::new` sizes its slab slots with
    /// `size_of::<HnswNode>()`, not with this value, so the two can disagree.
    /// Confirm which one external callers rely on.
    #[inline]
    #[must_use]
    pub const fn node_slot_size(&self) -> usize {
        16 + (self.m0 as usize) * 8
    }
}
/// One node of the HNSW graph as stored in a slab slot.
///
/// `HnswRegion` copies this struct to and from slab slots as raw bytes, so the
/// `#[repr(C)]` field order is part of the stored format.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct HnswNode {
    /// Layer this node lives on. Layer 0 uses `HnswConfig::m0` as its link
    /// limit; higher layers use `HnswConfig::m`.
    pub layer: u8,
    /// Explicit padding so that `link_count` starts at offset 4, after the
    /// 1-byte `layer`.
    _padding: [u8; 3],
    /// Number of leading entries of `links` that hold real neighbours.
    pub link_count: u32,
    /// Handle of the vector this node indexes; presumably a slot in a separate
    /// vector store — confirm with callers.
    pub vector_slot: SlotHandle,
    /// Inline neighbour storage. Only the first `link_count` entries are
    /// meaningful. `new` fills the array with `SlotHandle::invalid()`.
    ///
    /// NOTE: capacity is fixed at 64, so a config with `m` or `m0` above 64
    /// asks for more links than can be stored here.
    pub links: [SlotHandle; 64],
}
impl HnswNode {
    /// Creates a node on `layer` that refers to the vector at `vector_slot`.
    ///
    /// The node starts with no neighbours: `link_count` is 0 and every entry
    /// of `links` is `SlotHandle::invalid()`.
    #[inline]
    #[must_use]
    pub const fn new(layer: u8, vector_slot: SlotHandle) -> Self {
        Self {
            layer,
            _padding: [0; 3],
            link_count: 0,
            vector_slot,
            links: [SlotHandle::invalid(); 64],
        }
    }

    /// Configured neighbour limit for this node.
    ///
    /// Returns `config.m0` on layer 0 and `config.m` on every higher layer.
    /// The value is not clamped to the inline `links` capacity.
    #[inline]
    #[must_use]
    pub const fn max_links(&self, config: &HnswConfig) -> u16 {
        if self.layer == 0 {
            config.m0
        } else {
            config.m
        }
    }

    /// Adds `target` to this node's neighbour list.
    ///
    /// Returns `true` if `target` is a neighbour after the call, whether it
    /// was newly added or already present. Returns `false` if the list is
    /// full.
    ///
    /// The effective limit is `max_links(config)` clamped to the length of the
    /// inline `links` array. A configuration whose `m`/`m0` exceeds that
    /// capacity therefore cannot index past the end of the array.
    pub fn add_link(&mut self, target: SlotHandle, config: &HnswConfig) -> bool {
        // Check for an existing link first so "already linked" is reported
        // the same way whether or not the list is full.
        if self.links_slice().contains(&target) {
            return true;
        }
        let limit = (self.max_links(config) as usize).min(self.links.len());
        let len = self.links_slice().len();
        if len >= limit {
            return false;
        }
        self.links[len] = target;
        self.link_count += 1;
        true
    }

    /// Removes `target` from the neighbour list, if present.
    ///
    /// The last link is moved into the vacated position, so neighbour order is
    /// not preserved. The freed tail entry is reset to `SlotHandle::invalid()`.
    /// Returns whether a link was removed; the node is untouched when it
    /// returns `false`.
    pub fn remove_link(&mut self, target: SlotHandle) -> bool {
        let len = self.links_slice().len();
        let found = self.links[..len].iter().position(|&h| h == target);
        match found {
            Some(pos) => {
                let last = len - 1;
                self.links[pos] = self.links[last];
                self.links[last] = SlotHandle::invalid();
                self.link_count = last as u32;
                true
            }
            None => false,
        }
    }

    /// The neighbours currently stored in this node.
    ///
    /// `link_count` is a public field and nodes are rebuilt from raw slab
    /// bytes, so the count is clamped to the array length. An out-of-range
    /// value therefore cannot cause a slicing panic.
    #[inline]
    #[must_use]
    pub fn links_slice(&self) -> &[SlotHandle] {
        let len = (self.link_count as usize).min(self.links.len());
        &self.links[..len]
    }
}
/// Slab-backed storage for the nodes of an HNSW graph.
///
/// Every node is kept as the raw bytes of an `HnswNode` in one slab slot. The
/// region also tracks the current entry point, that entry point's layer and
/// the number of live nodes.
pub struct HnswRegion<B: MemoryBacking> {
    /// Parameters applied when adding links between nodes.
    config: HnswConfig,
    /// Node used as the graph's entry point, or `None` when there is none.
    entry_point: Option<SlotHandle>,
    /// Layer of `entry_point`; 0 when there is no entry point.
    current_max_layer: u8,
    /// Number of nodes allocated and not yet freed.
    node_count: u32,
    /// Slab holding one `HnswNode` per slot.
    node_slab: SlabAllocator<B>,
}
impl<B: MemoryBacking> HnswRegion<B> {
    /// Creates an empty region whose nodes live in a slab over `backing`.
    ///
    /// Each slab slot is `size_of::<HnswNode>()` bytes and the slab holds
    /// `capacity` slots.
    ///
    /// # Errors
    /// Propagates any error from `SlabAllocator::new`. Presumably this covers
    /// a `backing` too small for `capacity` slots — confirm against the slab
    /// implementation.
    pub fn new(backing: B, config: HnswConfig, capacity: usize) -> Result<Self> {
        let slot_size = core::mem::size_of::<HnswNode>();
        let node_slab = SlabAllocator::new(backing, slot_size, capacity)?;
        Ok(Self {
            node_slab,
            config,
            entry_point: None,
            current_max_layer: 0,
            node_count: 0,
        })
    }

    /// Views `node` as its raw bytes for storage in a slab slot.
    #[inline]
    fn node_bytes(node: &HnswNode) -> &[u8] {
        // SAFETY: the pointer comes from a live reference, so it is non-null
        // and valid for `size_of::<HnswNode>()` bytes for the borrow's
        // lifetime. `u8` has alignment 1. `HnswNode` is `repr(C)` with
        // explicit padding after `layer`. NOTE(review): this assumes
        // `SlotHandle` has no implicit padding bytes — confirm.
        unsafe {
            core::slice::from_raw_parts(
                (node as *const HnswNode).cast::<u8>(),
                core::mem::size_of::<HnswNode>(),
            )
        }
    }

    /// Allocates a slot and stores a new, link-free node on `layer` that
    /// points at `vector_slot`.
    ///
    /// The node becomes the entry point if the region has no entry point yet,
    /// or if `layer` is above the current entry point's layer. `layer` is not
    /// checked against `config.max_layers`.
    ///
    /// # Errors
    /// Returns the slab's error if no slot can be allocated or the write
    /// fails. If the write fails, the slot is returned to the slab so it is
    /// not leaked.
    pub fn alloc_node(&mut self, layer: u8, vector_slot: SlotHandle) -> Result<SlotHandle> {
        let handle = self.node_slab.alloc()?;
        let node = HnswNode::new(layer, vector_slot);
        let written = self.node_slab.write(handle, Self::node_bytes(&node));
        if written.is_err() {
            // Report the write error; a secondary failure to free is ignored.
            let _ = self.node_slab.free(handle);
        }
        written?;
        self.node_count += 1;
        if self.entry_point.is_none() || layer > self.current_max_layer {
            self.entry_point = Some(handle);
            self.current_max_layer = layer;
        }
        Ok(handle)
    }

    /// Returns `handle`'s slot to the slab and decrements the node count.
    ///
    /// Links in other nodes that point at `handle` are NOT removed. If
    /// `handle` was the entry point, the entry point is cleared and
    /// `current_max_layer` is reset to 0 even when other nodes remain. The
    /// next `alloc_node` then becomes the entry point regardless of its layer.
    ///
    /// # Errors
    /// Propagates the slab's error. In that case the node count and entry
    /// point are left unchanged.
    pub fn free_node(&mut self, handle: SlotHandle) -> Result<()> {
        self.node_slab.free(handle)?;
        self.node_count = self.node_count.saturating_sub(1);
        if self.entry_point == Some(handle) {
            self.entry_point = None;
            self.current_max_layer = 0;
        }
        Ok(())
    }

    /// Reads a copy of the node stored at `handle`.
    ///
    /// # Errors
    /// Propagates the slab's read error.
    pub fn read_node(&self, handle: SlotHandle) -> Result<HnswNode> {
        let mut bytes = [0u8; core::mem::size_of::<HnswNode>()];
        self.node_slab.read(handle, &mut bytes)?;
        // SAFETY: `bytes` is exactly `size_of::<HnswNode>()` bytes long. A
        // `[u8; N]` buffer is only 1-byte aligned, so `read_unaligned` is
        // required; plain `ptr::read` would need `HnswNode`'s alignment.
        // The bytes are presumably ones written by `alloc_node`/`write_node`.
        // NOTE(review): this also assumes every bit pattern of `SlotHandle` is
        // a valid value — confirm.
        Ok(unsafe { core::ptr::read_unaligned(bytes.as_ptr().cast::<HnswNode>()) })
    }

    /// Overwrites the slot at `handle` with `node`.
    ///
    /// # Errors
    /// Propagates the slab's write error.
    pub fn write_node(&mut self, handle: SlotHandle, node: &HnswNode) -> Result<()> {
        self.node_slab.write(handle, Self::node_bytes(node))?;
        Ok(())
    }

    /// Adds a directed link `from -> to`; the reverse link is not added.
    ///
    /// Returns the result of `HnswNode::add_link` for the stored node. The
    /// node is written back only when that call returns `true`; a `false`
    /// result leaves the node unmodified, so there is nothing to persist.
    ///
    /// # Errors
    /// Propagates slab read/write errors.
    pub fn add_link(&mut self, from: SlotHandle, to: SlotHandle) -> Result<bool> {
        let mut node = self.read_node(from)?;
        let added = node.add_link(to, &self.config);
        if added {
            self.write_node(from, &node)?;
        }
        Ok(added)
    }

    /// Removes the directed link `from -> to`, if present.
    ///
    /// Returns whether a link was removed. The node is written back only when
    /// one was.
    ///
    /// # Errors
    /// Propagates slab read/write errors.
    pub fn remove_link(&mut self, from: SlotHandle, to: SlotHandle) -> Result<bool> {
        let mut node = self.read_node(from)?;
        let removed = node.remove_link(to);
        if removed {
            self.write_node(from, &node)?;
        }
        Ok(removed)
    }

    /// Current entry-point node, if any.
    #[inline]
    #[must_use]
    pub const fn entry_point(&self) -> Option<SlotHandle> {
        self.entry_point
    }

    /// Layer of the current entry point (0 when there is none).
    #[inline]
    #[must_use]
    pub const fn current_max_layer(&self) -> u8 {
        self.current_max_layer
    }

    /// Number of nodes allocated and not yet freed.
    #[inline]
    #[must_use]
    pub const fn node_count(&self) -> u32 {
        self.node_count
    }

    /// Parameters this region was created with.
    #[inline]
    #[must_use]
    pub const fn config(&self) -> &HnswConfig {
        &self.config
    }

    /// Whether the region currently holds no nodes.
    #[inline]
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.node_count == 0
    }

    /// Total number of node slots, as reported by the slab.
    #[inline]
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.node_slab.slot_count()
    }

    /// Number of unused node slots, as reported by the slab.
    #[inline]
    #[must_use]
    pub fn free_slots(&self) -> usize {
        self.node_slab.free_count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ruvix_region::backing::StaticBacking;

    /// Region with the default config and room for 10 nodes in 16 KiB.
    fn fresh_region() -> HnswRegion<StaticBacking<16384>> {
        let backing = StaticBacking::<16384>::new();
        HnswRegion::new(backing, HnswConfig::default(), 10).unwrap()
    }

    #[test]
    fn test_hnsw_config_defaults() {
        let HnswConfig {
            m,
            m0,
            ef_construction,
            ..
        } = HnswConfig::default();
        assert_eq!(m, 16);
        assert_eq!(m0, 32);
        assert_eq!(ef_construction, 200);
    }

    #[test]
    fn test_hnsw_node_links() {
        let config = HnswConfig::new(4, 100);
        let mut node = HnswNode::new(1, SlotHandle::new(0, 0));
        // Layer 1 uses `m` (4 here) as its limit.
        assert_eq!(node.max_links(&config), 4);

        // Four distinct neighbours fit; a fifth is rejected.
        for id in 1..=4 {
            assert!(node.add_link(SlotHandle::new(id, 0), &config));
        }
        assert!(!node.add_link(SlotHandle::new(5, 0), &config));

        // Removing one makes room for the previously rejected neighbour.
        assert!(node.remove_link(SlotHandle::new(2, 0)));
        assert_eq!(node.link_count, 3);
        assert!(node.add_link(SlotHandle::new(5, 0), &config));
    }

    #[test]
    fn test_hnsw_region_alloc_free() {
        let mut region = fresh_region();
        assert!(region.is_empty());
        assert_eq!(region.node_count(), 0);

        let vector_slot = SlotHandle::new(0, 0);
        let handle = region.alloc_node(0, vector_slot).unwrap();
        assert_eq!(region.node_count(), 1);
        assert_eq!(region.entry_point(), Some(handle));

        let stored = region.read_node(handle).unwrap();
        assert_eq!(stored.layer, 0);
        assert_eq!(stored.vector_slot, vector_slot);

        region.free_node(handle).unwrap();
        assert_eq!(region.node_count(), 0);
        assert!(region.entry_point().is_none());
    }

    #[test]
    fn test_hnsw_region_links() {
        let mut region = fresh_region();
        let a = region.alloc_node(0, SlotHandle::new(0, 0)).unwrap();
        let b = region.alloc_node(0, SlotHandle::new(1, 0)).unwrap();
        let c = region.alloc_node(0, SlotHandle::new(2, 0)).unwrap();

        // Bidirectional edges a<->b and a<->c.
        for &(from, to) in [(a, b), (b, a), (a, c), (c, a)].iter() {
            assert!(region.add_link(from, to).unwrap());
        }

        let hub = region.read_node(a).unwrap();
        assert_eq!(hub.link_count, 2);
        let neighbours = hub.links_slice();
        assert!(neighbours.contains(&b));
        assert!(neighbours.contains(&c));
    }

    #[test]
    fn test_hnsw_region_entry_point_update() {
        let mut region = fresh_region();

        let bottom = region.alloc_node(0, SlotHandle::new(0, 0)).unwrap();
        assert_eq!(region.entry_point(), Some(bottom));
        assert_eq!(region.current_max_layer(), 0);

        // A higher-layer node takes over as entry point.
        let top = region.alloc_node(2, SlotHandle::new(1, 0)).unwrap();
        assert_eq!(region.entry_point(), Some(top));
        assert_eq!(region.current_max_layer(), 2);

        // A node below the current max layer does not.
        let _middle = region.alloc_node(1, SlotHandle::new(2, 0)).unwrap();
        assert_eq!(region.entry_point(), Some(top));
    }
}