use crate::hnsw::errors::{HnswError, HnswIndexError};
use std::collections::HashSet;
/// One layer of an HNSW graph, stored as an adjacency list over densely numbered nodes.
///
/// Node ids are handed out sequentially (`0, 1, 2, ...`) through [`HnswLayer::add_node`],
/// so a node id is also the index of that node's adjacency set.
#[derive(Debug, Clone)]
pub struct HnswLayer {
    /// Level of this layer; `0` is the base layer.
    level: u8,
    /// Connection limit applied by the pruning methods; also caps the entry-point list.
    max_connections: usize,
    /// Adjacency sets, indexed by node id.
    nodes: Vec<HashSet<u64>>,
    /// Node ids currently designated as entry points for this layer.
    entry_points: Vec<u64>,
    /// Number of nodes added since construction or the last `clear`.
    vector_count: usize,
}

impl HnswLayer {
    /// Creates an empty layer at `level`.
    ///
    /// The connection limit is `base_connections >> level` (halved once per level),
    /// clamped so that it is never below 1.
    pub fn new(level: u8, base_connections: usize) -> Self {
        Self {
            level,
            max_connections: Self::compute_max_connections(level, base_connections),
            nodes: Vec::new(),
            entry_points: Vec::new(),
            vector_count: 0,
        }
    }

    /// Halves `base_connections` once per level, with a floor of 1.
    ///
    /// A shift of the full bit width or more (which `checked_shr` rejects) counts as 0
    /// before the floor is applied.
    fn compute_max_connections(level: u8, base_connections: usize) -> usize {
        base_connections
            .checked_shr(u32::from(level))
            .unwrap_or(0)
            .max(1)
    }

    /// Returns the adjacency index for `node_id`, or `None` if the node is not in this layer.
    fn index_of(&self, node_id: u64) -> Option<usize> {
        if node_id < self.nodes.len() as u64 {
            // Lossless: the id was just shown to be smaller than a `usize` length.
            Some(node_id as usize)
        } else {
            None
        }
    }

    /// Like [`Self::index_of`], but reports a missing node as `NodeNotFound`.
    fn require_node(&self, node_id: u64) -> Result<usize, HnswError> {
        self.index_of(node_id)
            .ok_or_else(|| HnswError::Index(HnswIndexError::NodeNotFound(node_id)))
    }

    /// Validates an edge and returns the adjacency indices of both endpoints.
    ///
    /// Checks run in a fixed order: self-connection first, then `from`, then `to`.
    fn validate_edge(&self, from: u64, to: u64) -> Result<(usize, usize), HnswError> {
        if from == to {
            return Err(HnswError::Index(HnswIndexError::SelfConnection(from)));
        }
        let from_idx = self.require_node(from)?;
        let to_idx = self.require_node(to)?;
        Ok((from_idx, to_idx))
    }

    /// Level of this layer (`0` is the base layer).
    pub fn level(&self) -> u8 {
        self.level
    }

    /// Connection limit applied when pruning a node's adjacency set.
    pub fn max_connections(&self) -> usize {
        self.max_connections
    }

    /// Number of nodes currently stored in the layer.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Number of nodes added since construction or the last [`Self::clear`].
    pub fn vector_count(&self) -> usize {
        self.vector_count
    }

    /// Returns `true` if `node_id` has been added to this layer.
    pub fn contains_node(&self, node_id: u64) -> bool {
        self.index_of(node_id).is_some()
    }

    /// Returns the set of nodes that `node_id` links to.
    ///
    /// # Errors
    ///
    /// `HnswIndexError::NodeNotFound` if `node_id` is not in this layer.
    pub fn get_connections(&self, node_id: u64) -> Result<&HashSet<u64>, HnswError> {
        let idx = self.require_node(node_id)?;
        Ok(&self.nodes[idx])
    }

    /// Appends a node with no connections.
    ///
    /// While the entry-point list holds fewer than `max_connections` ids, the new node
    /// is also recorded as an entry point.
    ///
    /// # Errors
    ///
    /// `HnswIndexError::InvalidNodeId` unless `node_id` equals the current node count
    /// (ids must be added in order, without gaps).
    pub fn add_node(&mut self, node_id: u64) -> Result<(), HnswError> {
        if node_id != self.nodes.len() as u64 {
            return Err(HnswError::Index(HnswIndexError::InvalidNodeId(node_id)));
        }
        self.nodes.push(HashSet::new());
        self.vector_count += 1;
        if self.entry_points.len() < self.max_connections {
            self.entry_points.push(node_id);
        }
        Ok(())
    }

    /// Adds the directed link `from_node -> to_node` without pruning.
    ///
    /// # Errors
    ///
    /// `SelfConnection` if both ids are equal; `NodeNotFound` if either endpoint is
    /// missing (`from_node` is checked first).
    pub(crate) fn add_one_way_connection(
        &mut self,
        from_node: u64,
        to_node: u64,
    ) -> Result<(), HnswError> {
        let (from_idx, _) = self.validate_edge(from_node, to_node)?;
        self.nodes[from_idx].insert(to_node);
        Ok(())
    }

    /// Links `node_a` and `node_b` in both directions, then prunes each endpoint back
    /// to `max_connections` (keeping the lowest neighbour ids).
    ///
    /// Because pruning keeps the lowest ids, the link that was just added may itself
    /// be removed from an endpoint that is already full.
    ///
    /// # Errors
    ///
    /// `SelfConnection` if both ids are equal; `NodeNotFound` if either endpoint is
    /// missing (`node_a` is checked first).
    pub fn add_connection(&mut self, node_a: u64, node_b: u64) -> Result<(), HnswError> {
        let (a_idx, b_idx) = self.validate_edge(node_a, node_b)?;
        self.nodes[a_idx].insert(node_b);
        self.nodes[b_idx].insert(node_a);
        self.prune_connections(node_a);
        self.prune_connections(node_b);
        Ok(())
    }

    /// Trims the adjacency set of `node_id` to `max_connections`, keeping the
    /// numerically smallest neighbour ids. Unknown ids are ignored.
    fn prune_connections(&mut self, node_id: u64) {
        let limit = self.max_connections;
        if let Some(idx) = self.index_of(node_id) {
            let connections = &mut self.nodes[idx];
            if connections.len() > limit {
                let mut ids: Vec<u64> = connections.iter().copied().collect();
                ids.sort_unstable();
                ids.truncate(limit);
                *connections = ids.into_iter().collect();
            }
        }
    }

    /// Trims the adjacency set of `node_id` to the `max_connections` closest
    /// neighbours according to `connection_distances`.
    ///
    /// Nothing happens when `node_id` is unknown or the node is within its limit.
    /// Otherwise neighbours are ranked by ascending distance, ties broken by ascending
    /// id. NaN distances rank after every real distance. Neighbours that have no entry
    /// in `connection_distances` are dropped.
    pub fn prune_connections_by_distance(
        &mut self,
        node_id: u64,
        connection_distances: &std::collections::HashMap<u64, f32>,
    ) {
        let limit = self.max_connections;
        let idx = match self.index_of(node_id) {
            Some(idx) => idx,
            None => return,
        };
        let connections = &mut self.nodes[idx];
        if connections.len() <= limit {
            return;
        }

        let mut ranked: Vec<(u64, f32)> = connections
            .iter()
            .filter_map(|&id| connection_distances.get(&id).map(|&dist| (id, dist)))
            .collect();

        // `partial_cmp` yields `None` only when a NaN is involved. Ordering those cases
        // by "is NaN" keeps the comparator a total order (NaN last); treating them as
        // equal would not be transitive, which makes the sort result unspecified.
        ranked.sort_unstable_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
                .then_with(|| a.0.cmp(&b.0))
        });

        *connections = ranked
            .into_iter()
            .take(limit)
            .map(|(id, _)| id)
            .collect();
    }

    /// Node ids currently designated as entry points.
    pub fn get_entry_points(&self) -> &[u64] {
        &self.entry_points
    }

    /// Recomputes the entry points as the (up to) `max_connections` nodes with the most
    /// connections, ties broken by ascending id.
    ///
    /// `new_node_id` only gates the update: if it is not in the layer, nothing changes.
    pub fn update_entry_points(&mut self, new_node_id: u64) {
        if !self.contains_node(new_node_id) {
            return;
        }
        let mut by_degree: Vec<(u64, usize)> = (0u64..)
            .zip(self.nodes.iter())
            .map(|(id, connections)| (id, connections.len()))
            .collect();
        // Ids are unique, so the comparator is a strict total order and an unstable
        // sort gives the same result as a stable one.
        by_degree.sort_unstable_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
        self.entry_points = by_degree
            .into_iter()
            .take(self.max_connections)
            .map(|(id, _)| id)
            .collect();
    }

    /// Returns `true` for the level-0 layer.
    pub fn is_base_layer(&self) -> bool {
        self.level == 0
    }

    /// Rough size estimate in bytes.
    ///
    /// Counts the struct itself, one `HashSet` header per node, and one `u64` per
    /// stored connection and entry point. It is based on lengths, not allocated
    /// capacity, so hash-table overhead is not included.
    pub fn memory_usage(&self) -> usize {
        let id_size = std::mem::size_of::<u64>();
        let struct_size = std::mem::size_of::<Self>();
        let set_headers = self.nodes.len() * std::mem::size_of::<HashSet<u64>>();
        let stored_links: usize = self.nodes.iter().map(|conns| conns.len() * id_size).sum();
        let entry_points_size = self.entry_points.len() * id_size;
        struct_size + set_headers + stored_links + entry_points_size
    }

    /// Removes all nodes and entry points and resets the vector count.
    /// The level and connection limit are kept.
    pub fn clear(&mut self) {
        self.nodes.clear();
        self.entry_points.clear();
        self.vector_count = 0;
    }

    /// Returns `(node_count, total_connections, average_connections_per_node)`.
    ///
    /// `total_connections` sums the sizes of all adjacency sets, so a link stored in
    /// both directions counts twice. The average is `0.0` for an empty layer.
    pub fn get_statistics(&self) -> (usize, usize, f32) {
        let node_count = self.nodes.len();
        let total_connections: usize = self.nodes.iter().map(HashSet::len).sum();
        let avg_connections = if node_count == 0 {
            0.0
        } else {
            total_connections as f32 / node_count as f32
        };
        (node_count, total_connections, avg_connections)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a layer at `level` holding nodes `0..node_count`.
    fn layer_with_nodes(level: u8, base_connections: usize, node_count: u64) -> HnswLayer {
        let mut layer = HnswLayer::new(level, base_connections);
        for id in 0..node_count {
            layer.add_node(id).unwrap();
        }
        layer
    }

    #[test]
    fn test_layer_creation() {
        let layer = HnswLayer::new(0, 16);
        assert_eq!(layer.level(), 0);
        assert_eq!(layer.max_connections(), 16);
        assert_eq!(layer.node_count(), 0);
        assert_eq!(layer.vector_count(), 0);
        assert!(layer.is_base_layer());
    }

    #[test]
    fn test_layer_level_scaling() {
        // The limit halves with every level.
        assert_eq!(HnswLayer::new(0, 32).max_connections(), 32);
        assert_eq!(HnswLayer::new(1, 32).max_connections(), 16);
        assert_eq!(HnswLayer::new(2, 32).max_connections(), 8);
    }

    #[test]
    fn test_layer_level_scaling_minimum() {
        // The limit never drops below 1, even when the shift clears every bit,
        // exceeds the integer width, or the base is already 0.
        assert_eq!(HnswLayer::new(10, 16).max_connections(), 1);
        assert_eq!(HnswLayer::new(200, 16).max_connections(), 1);
        assert_eq!(HnswLayer::new(0, 0).max_connections(), 1);
    }

    #[test]
    fn test_add_node_sequential() {
        let layer = layer_with_nodes(0, 8, 2);
        assert_eq!(layer.node_count(), 2);
        assert!(layer.contains_node(0));
        assert!(layer.contains_node(1));
        assert!(!layer.contains_node(2));
    }

    #[test]
    fn test_add_node_out_of_order() {
        let mut layer = layer_with_nodes(0, 8, 1);
        let result = layer.add_node(2);
        assert!(matches!(
            result,
            Err(HnswError::Index(HnswIndexError::InvalidNodeId(2)))
        ));
        // A rejected id must not change the layer.
        assert_eq!(layer.node_count(), 1);
        assert_eq!(layer.vector_count(), 1);
    }

    #[test]
    fn test_add_connection_success() {
        let mut layer = layer_with_nodes(0, 4, 3);
        layer.add_connection(0, 1).unwrap();
        assert!(layer.get_connections(0).unwrap().contains(&1));
        assert!(layer.get_connections(1).unwrap().contains(&0));
        assert!(!layer.get_connections(2).unwrap().contains(&0));
    }

    #[test]
    fn test_add_connection_self_connection() {
        let mut layer = layer_with_nodes(0, 4, 1);
        let result = layer.add_connection(0, 0);
        assert!(matches!(
            result,
            Err(HnswError::Index(HnswIndexError::SelfConnection(0)))
        ));
    }

    #[test]
    fn test_add_connection_nonexistent_node() {
        let mut layer = layer_with_nodes(0, 4, 1);
        let result = layer.add_connection(0, 1);
        assert!(matches!(
            result,
            Err(HnswError::Index(HnswIndexError::NodeNotFound(1)))
        ));
        assert!(layer.get_connections(0).unwrap().is_empty());
    }

    #[test]
    fn test_add_one_way_connection() {
        let mut layer = layer_with_nodes(0, 1, 3);
        layer.add_one_way_connection(0, 1).unwrap();
        layer.add_one_way_connection(0, 2).unwrap();

        // One-way links are not mirrored and are not pruned to the limit of 1.
        assert_eq!(layer.get_connections(0).unwrap().len(), 2);
        assert!(layer.get_connections(1).unwrap().is_empty());
        assert!(layer.get_connections(2).unwrap().is_empty());
    }

    #[test]
    fn test_add_one_way_connection_errors() {
        let mut layer = layer_with_nodes(0, 4, 2);
        assert!(matches!(
            layer.add_one_way_connection(1, 1),
            Err(HnswError::Index(HnswIndexError::SelfConnection(1)))
        ));
        assert!(matches!(
            layer.add_one_way_connection(7, 0),
            Err(HnswError::Index(HnswIndexError::NodeNotFound(7)))
        ));
        assert!(matches!(
            layer.add_one_way_connection(0, 7),
            Err(HnswError::Index(HnswIndexError::NodeNotFound(7)))
        ));
    }

    #[test]
    fn test_connection_pruning() {
        let mut layer = layer_with_nodes(0, 2, 4);
        layer.add_connection(0, 1).unwrap();
        layer.add_connection(0, 2).unwrap();
        layer.add_connection(0, 3).unwrap();

        // Pruning keeps the lowest neighbour ids.
        let connections = layer.get_connections(0).unwrap();
        assert_eq!(connections.len(), 2);
        assert!(connections.contains(&1));
        assert!(connections.contains(&2));
        assert!(!connections.contains(&3));
    }

    #[test]
    fn test_prune_connections_by_distance_keeps_closest() {
        let mut layer = layer_with_nodes(0, 2, 5);
        for to in 1..5 {
            layer.add_one_way_connection(0, to).unwrap();
        }
        let distances: HashMap<u64, f32> = [(1, 0.9), (2, 0.1), (3, 0.5), (4, 0.3)]
            .iter()
            .cloned()
            .collect();

        layer.prune_connections_by_distance(0, &distances);

        let kept = layer.get_connections(0).unwrap();
        assert_eq!(kept.len(), 2);
        assert!(kept.contains(&2));
        assert!(kept.contains(&4));
    }

    #[test]
    fn test_prune_connections_by_distance_tie_breaks_by_id() {
        let mut layer = layer_with_nodes(0, 2, 4);
        for to in 1..4 {
            layer.add_one_way_connection(0, to).unwrap();
        }
        let distances: HashMap<u64, f32> = (1..4).map(|id| (id, 0.5)).collect();

        layer.prune_connections_by_distance(0, &distances);

        let kept = layer.get_connections(0).unwrap();
        assert_eq!(kept.len(), 2);
        assert!(kept.contains(&1));
        assert!(kept.contains(&2));
    }

    #[test]
    fn test_prune_connections_by_distance_noop_cases() {
        let mut layer = layer_with_nodes(0, 4, 3);
        layer.add_one_way_connection(0, 1).unwrap();
        layer.add_one_way_connection(0, 2).unwrap();
        let distances: HashMap<u64, f32> = HashMap::new();

        // Within the limit: nothing is removed, even without distance data.
        layer.prune_connections_by_distance(0, &distances);
        assert_eq!(layer.get_connections(0).unwrap().len(), 2);

        // Unknown node: silently ignored.
        layer.prune_connections_by_distance(99, &distances);
        assert_eq!(layer.node_count(), 3);
    }

    #[test]
    fn test_entry_points_initial() {
        let layer = layer_with_nodes(0, 3, 2);
        let entry_points = layer.get_entry_points();
        assert_eq!(entry_points.len(), 2);
        assert!(entry_points.contains(&0));
        assert!(entry_points.contains(&1));
    }

    #[test]
    fn test_update_entry_points() {
        let mut layer = layer_with_nodes(0, 2, 5);
        layer.add_connection(0, 1).unwrap();
        layer.add_connection(0, 2).unwrap();
        layer.add_connection(1, 3).unwrap();
        layer.update_entry_points(4);

        // Nodes 0 and 1 both have two links; every other node has fewer.
        let entry_points = layer.get_entry_points();
        assert_eq!(entry_points.len(), 2);
        assert!(entry_points.contains(&0));
        assert!(entry_points.contains(&1));
    }

    #[test]
    fn test_get_connections_nonexistent() {
        let layer = HnswLayer::new(0, 4);
        let result = layer.get_connections(0);
        assert!(matches!(
            result,
            Err(HnswError::Index(HnswIndexError::NodeNotFound(0)))
        ));
    }

    #[test]
    fn test_memory_usage() {
        let mut layer = HnswLayer::new(0, 4);
        let base_usage = layer.memory_usage();
        assert!(base_usage > 0);

        for i in 0..3 {
            layer.add_node(i).unwrap();
        }
        let with_nodes = layer.memory_usage();
        assert!(with_nodes > base_usage);

        layer.add_connection(0, 1).unwrap();
        layer.add_connection(1, 2).unwrap();
        let with_connections = layer.memory_usage();
        assert!(with_connections > with_nodes);
    }

    #[test]
    fn test_get_statistics() {
        let mut layer = HnswLayer::new(0, 4);
        let (nodes, total, avg) = layer.get_statistics();
        assert_eq!(nodes, 0);
        assert_eq!(total, 0);
        assert_eq!(avg, 0.0);

        for i in 0..3 {
            layer.add_node(i).unwrap();
        }
        layer.add_connection(0, 1).unwrap();
        layer.add_connection(1, 2).unwrap();

        // Two bidirectional links are stored as four directed entries.
        let (nodes, total, avg) = layer.get_statistics();
        assert_eq!(nodes, 3);
        assert_eq!(total, 4);
        assert!((avg - 4.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_clear_layer() {
        let mut layer = layer_with_nodes(0, 4, 3);
        layer.add_connection(0, 1).unwrap();
        assert_eq!(layer.node_count(), 3);
        assert_eq!(layer.vector_count(), 3);
        assert!(!layer.get_entry_points().is_empty());

        layer.clear();

        assert_eq!(layer.node_count(), 0);
        assert_eq!(layer.vector_count(), 0);
        assert!(layer.get_entry_points().is_empty());
    }

    #[test]
    fn test_higher_layer_properties() {
        let layer = HnswLayer::new(3, 16);
        assert_eq!(layer.level(), 3);
        // 16 >> 3 == 2
        assert_eq!(layer.max_connections(), 2);
        assert!(!layer.is_base_layer());
    }
}