use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::collections::HashMap;
use std::mem;
use crate::hnsw::{
config::HnswConfig,
errors::{HnswError, HnswMultiLayerError},
};
/// Bidirectional table between global vector ids and per-layer local ids.
///
/// Both directions are stored so that a lookup is a single hash probe either
/// way; `validate_consistency` checks that the two directions agree.
#[derive(Debug, Clone)]
pub struct LayerMappings {
/// For each global id, a vector indexed by layer id holding the local id
/// the vector has on that layer (`None` when it is absent from the layer).
global_to_local: HashMap<u64, Vec<Option<u64>>>,
/// One map per layer (indexed by layer id) from local id to global id.
local_to_global: Vec<HashMap<u64, u64>>,
/// Per-layer counter that supplies the local id when `add_mapping` is
/// called without an explicit one.
next_local_id: Vec<usize>,
}
impl LayerMappings {
    /// Creates an empty table with `max_layers` layers pre-allocated.
    ///
    /// Layers beyond `max_layers` are created on demand by [`Self::add_mapping`].
    pub fn new(max_layers: usize) -> Self {
        Self {
            global_to_local: HashMap::new(),
            local_to_global: (0..max_layers).map(|_| HashMap::new()).collect(),
            next_local_id: vec![0; max_layers],
        }
    }

    /// Registers `global_id` on `layer_id`.
    ///
    /// Local ids are handed out sequentially per layer and are never reused,
    /// so removing a vector leaves a hole instead of shifting its neighbours.
    /// Pass `None` to take the next id automatically, or `Some(id)` to assert
    /// which id the caller expects to receive.
    ///
    /// # Errors
    ///
    /// Returns `HnswMultiLayerError::LayerMappingConflict` when
    /// * `global_id` is already mapped on `layer_id` (`expected` then carries
    ///   the local id it already owns), or
    /// * an explicit `local_id` is not the next sequential id of the layer
    ///   (`expected` then carries that next id).
    ///
    /// The table is left untouched whenever an error is returned.
    pub fn add_mapping(
        &mut self,
        global_id: u64,
        layer_id: usize,
        local_id: Option<u64>,
    ) -> Result<(), HnswError> {
        if layer_id >= self.local_to_global.len() {
            self.extend_layers(layer_id + 1)?;
        }

        // The counter, not the current map size, is the source of truth:
        // after a removal the map is smaller than the number of ids that
        // have been handed out, and comparing against its size would either
        // reject every new insert or overwrite a live entry.
        let next_id = self.next_local_id[layer_id] as u64;
        let local_id = local_id.unwrap_or(next_id);

        // Overwriting an existing mapping would orphan its reverse entry.
        if let Some(existing) = self.get_local_id(global_id, layer_id) {
            return Err(HnswError::MultiLayer(
                HnswMultiLayerError::LayerMappingConflict {
                    global_id,
                    layer_id,
                    local_id,
                    expected: existing,
                },
            ));
        }

        if local_id != next_id {
            return Err(HnswError::MultiLayer(
                HnswMultiLayerError::LayerMappingConflict {
                    global_id,
                    layer_id,
                    local_id,
                    expected: next_id,
                },
            ));
        }

        let slots = self.global_to_local.entry(global_id).or_default();
        if slots.len() <= layer_id {
            slots.resize(layer_id + 1, None);
        }
        slots[layer_id] = Some(local_id);
        self.local_to_global[layer_id].insert(local_id, global_id);

        // Advance only after the mapping is stored, so a rejected call never
        // burns an id, and explicit ids keep the counter in step as well.
        self.next_local_id[layer_id] += 1;
        Ok(())
    }

    /// Returns the local id of `global_id` on `layer_id`, if it is mapped there.
    pub fn get_local_id(&self, global_id: u64, layer_id: usize) -> Option<u64> {
        self.global_to_local
            .get(&global_id)
            .and_then(|slots| slots.get(layer_id).copied().flatten())
    }

    /// Returns the global id stored under `local_id` on `layer_id`, if any.
    pub fn get_global_id(&self, layer_id: usize, local_id: u64) -> Option<u64> {
        self.local_to_global
            .get(layer_id)
            .and_then(|layer| layer.get(&local_id).copied())
    }

    /// Removes `global_id` from every layer it is mapped on.
    ///
    /// Removing an unknown id is a no-op. The per-layer id counters are not
    /// rewound, so the freed local ids are not handed out again.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub fn remove_global_id(&mut self, global_id: u64) -> Result<(), HnswError> {
        if let Some(slots) = self.global_to_local.remove(&global_id) {
            for (layer_id, slot) in slots.into_iter().enumerate() {
                if let (Some(local_id), Some(layer)) =
                    (slot, self.local_to_global.get_mut(layer_id))
                {
                    layer.remove(&local_id);
                }
            }
        }
        Ok(())
    }

    /// Returns the global ids present on `layer_id`, sorted ascending.
    ///
    /// An unknown layer yields an empty vector.
    pub fn get_layer_vectors(&self, layer_id: usize) -> Vec<u64> {
        let mut vectors: Vec<u64> = self
            .local_to_global
            .get(layer_id)
            .map(|layer| layer.values().copied().collect())
            .unwrap_or_default();
        vectors.sort_unstable();
        vectors
    }

    /// Checks that both mapping directions and the id counters agree.
    ///
    /// # Errors
    ///
    /// * `InconsistentMapping` when a forward entry has no reverse entry or
    ///   the reverse entry names another global id (`mapped_global` is the id
    ///   found, or `u64::MAX` when nothing was found), and likewise when a
    ///   reverse entry is not backed by a matching forward entry.
    /// * `InconsistentLayerState` when a layer stores a local id that its
    ///   counter has not handed out yet; `expected_nodes` is the counter value
    ///   and `actual_nodes` the number of entries stored on the layer. Having
    ///   fewer entries than the counter is fine, since removals leave holes.
    pub fn validate_consistency(&self) -> Result<(), HnswError> {
        // Forward direction: every global -> local entry must round-trip.
        for (&global_id, slots) in &self.global_to_local {
            for (layer_id, slot) in slots.iter().enumerate() {
                if let Some(local_id) = *slot {
                    let mapped = self.get_global_id(layer_id, local_id);
                    if mapped != Some(global_id) {
                        return Err(HnswError::MultiLayer(
                            HnswMultiLayerError::InconsistentMapping {
                                global_id,
                                layer_id,
                                local_id,
                                // u64::MAX marks "no reverse entry at all".
                                mapped_global: mapped.unwrap_or(u64::MAX),
                            },
                        ));
                    }
                }
            }
        }

        // Reverse direction: every local -> global entry must round-trip.
        for (layer_id, layer) in self.local_to_global.iter().enumerate() {
            for (&local_id, &global_id) in layer {
                let mapped_global = match self.get_local_id(global_id, layer_id) {
                    Some(mapped_local) if mapped_local == local_id => continue,
                    Some(_) => global_id,
                    None => u64::MAX,
                };
                return Err(HnswError::MultiLayer(
                    HnswMultiLayerError::InconsistentMapping {
                        global_id,
                        layer_id,
                        local_id,
                        mapped_global,
                    },
                ));
            }
        }

        // Counters: no stored local id may lie beyond what has been allocated.
        let layers = self.local_to_global.iter().zip(&self.next_local_id);
        for (layer_id, (layer, &allocated)) in layers.enumerate() {
            let within_allocation = layer.len() <= allocated
                && layer.keys().all(|&local_id| local_id < allocated as u64);
            if !within_allocation {
                return Err(HnswError::MultiLayer(
                    HnswMultiLayerError::InconsistentLayerState {
                        layer_id,
                        expected_nodes: allocated,
                        actual_nodes: layer.len(),
                    },
                ));
            }
        }
        Ok(())
    }

    /// Estimates the number of bytes held by the table.
    ///
    /// The estimate counts keys, values and vector headers by their
    /// `size_of`; hash-map bucket overhead and spare capacity are not included.
    pub fn memory_usage(&self) -> usize {
        let base_overhead = mem::size_of::<Self>();

        // Per entry: the key plus the vector header ...
        let forward_entries = self.global_to_local.len()
            * (mem::size_of::<u64>() + mem::size_of::<Vec<Option<u64>>>());
        // ... and, counted once rather than once per entry, the payloads.
        let forward_payload: usize = self
            .global_to_local
            .values()
            .map(|slots| slots.len() * mem::size_of::<Option<u64>>())
            .sum();

        let reverse_entries: usize = self
            .local_to_global
            .iter()
            .map(|layer| layer.len() * 2 * mem::size_of::<u64>())
            .sum();
        let counters = self.next_local_id.len() * mem::size_of::<usize>();

        base_overhead + forward_entries + forward_payload + reverse_entries + counters
    }

    /// Drops every mapping and resets all id counters, keeping the layers.
    pub fn clear(&mut self) {
        self.global_to_local.clear();
        self.local_to_global.iter_mut().for_each(HashMap::clear);
        self.next_local_id.iter_mut().for_each(|next_id| *next_id = 0);
    }

    /// Grows the table to `required_layers` layers; never shrinks it.
    ///
    /// Existing forward entries are padded with `None` so that every one of
    /// them covers the new layers.
    fn extend_layers(&mut self, required_layers: usize) -> Result<(), HnswError> {
        if required_layers <= self.local_to_global.len() {
            return Ok(());
        }
        self.local_to_global
            .resize_with(required_layers, HashMap::new);
        if self.next_local_id.len() < required_layers {
            self.next_local_id.resize(required_layers, 0);
        }
        for slots in self.global_to_local.values_mut() {
            if slots.len() < required_layers {
                slots.resize(required_layers, None);
            }
        }
        Ok(())
    }
}
/// Samples the highest layer a newly inserted vector is assigned to.
#[derive(Debug, Clone)]
pub struct LevelDistributor {
/// Branching factor: a sample is promoted to the next level while a
/// uniform draw is below `1.0 / base_m`.
base_m: f64,
/// Number of layers; sampled levels are capped at `max_layers - 1`.
max_layers: usize,
/// Internal generator, used when the caller does not supply its own.
rng: StdRng,
}
impl LevelDistributor {
    /// Creates a distributor whose internal generator is seeded from entropy.
    ///
    /// `base_m` is the branching factor (promotion probability `1 / base_m`)
    /// and `max_layers` the number of layers levels are drawn from.
    pub fn new(base_m: f64, max_layers: usize) -> Self {
        Self {
            base_m,
            max_layers,
            rng: StdRng::from_entropy(),
        }
    }

    /// Replaces the internal generator with one seeded from `seed`, which
    /// makes [`Self::sample_level_internal`] reproducible.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.rng = StdRng::seed_from_u64(seed);
        self
    }

    /// Samples a level using `rng` when given, the internal generator otherwise.
    ///
    /// The result lies in `0..max_layers`, or is `0` when `max_layers` is `0`.
    pub fn sample_level(&mut self, rng: Option<&mut impl Rng>) -> usize {
        let promote_probability = 1.0 / self.base_m;
        match rng {
            Some(external) => Self::draw_level(external, promote_probability, self.max_layers),
            None => Self::draw_level(&mut self.rng, promote_probability, self.max_layers),
        }
    }

    /// Returns `base_m^-level`, the chance that a sample reaches at least
    /// `level` (meaningful for `base_m >= 1`); `0.0` for levels that do not exist.
    pub fn level_probability(&self, level: usize) -> f64 {
        if level >= self.max_layers {
            return 0.0;
        }
        self.base_m.powf(-(level as f64))
    }

    /// Expected number of vectors present on `level` out of `total_vectors`.
    pub fn expected_vectors_at_level(&self, total_vectors: usize, level: usize) -> f64 {
        if level >= self.max_layers {
            return 0.0;
        }
        total_vectors as f64 * self.level_probability(level)
    }

    /// Returns [`Self::level_probability`] for every level in `0..max_layers`.
    pub fn all_level_probabilities(&self) -> Vec<f64> {
        (0..self.max_layers)
            .map(|level| self.level_probability(level))
            .collect()
    }

    /// Samples a level from the internal generator.
    ///
    /// The result lies in `0..max_layers`, or is `0` when `max_layers` is `0`.
    pub fn sample_level_internal(&mut self) -> usize {
        Self::draw_level(&mut self.rng, 1.0 / self.base_m, self.max_layers)
    }

    /// Counts consecutive promotions, stopping at the top layer.
    ///
    /// The draw is made before the cap is checked, so the number of values
    /// consumed from `rng` is the same as it has always been and seeded
    /// sequences stay stable. The cap is written as `level + 1 < max_layers`
    /// because `max_layers - 1` underflows for an empty layer set.
    fn draw_level<R: Rng>(rng: &mut R, promote_probability: f64, max_layers: usize) -> usize {
        let mut level = 0;
        while rng.r#gen::<f64>() < promote_probability && level + 1 < max_layers {
            level += 1;
        }
        level
    }
}
/// Assigns vectors to layers and keeps their global/local id mappings.
#[derive(Debug)]
pub struct MultiLayerNodeManager {
/// Global <-> local id tables for every layer.
mappings: LayerMappings,
/// Source of the highest level drawn for each inserted vector.
distributor: LevelDistributor,
/// Configuration the manager was built from (`ml` is read as the layer
/// count, `m` as the distributor's branching factor).
config: HnswConfig,
/// Highest level assigned to each vector, keyed by global vector id.
vector_levels: HashMap<u64, usize>,
}
impl MultiLayerNodeManager {
pub fn new(config: HnswConfig) -> Result<Self, HnswError> {
let max_layers = config.ml as usize;
let mappings = LayerMappings::new(max_layers);
let distributor = LevelDistributor::new(config.m as f64, max_layers);
let vector_levels = HashMap::new();
Ok(Self {
mappings,
distributor,
config,
vector_levels,
})
}
pub fn insert_vector(
&mut self,
vector_id: u64,
) -> Result<(usize, Vec<(usize, u64)>), HnswError> {
let highest_level = self.distributor.sample_level_internal();
let mut layer_assignments = Vec::new();
for level in (0..=highest_level).rev() {
let local_id = self.mappings.next_local_id[level]; self.mappings.add_mapping(vector_id, level, None)?; layer_assignments.push((level, local_id as u64));
}
self.vector_levels.insert(vector_id, highest_level);
Ok((highest_level, layer_assignments))
}
pub fn remove_vector(&mut self, vector_id: u64) -> Result<(), HnswError> {
self.mappings.remove_global_id(vector_id)?;
self.vector_levels.remove(&vector_id);
Ok(())
}
pub fn get_local_id(&self, vector_id: u64, layer_id: usize) -> Option<u64> {
self.mappings.get_local_id(vector_id, layer_id)
}
pub fn get_global_id(&self, layer_id: usize, local_id: u64) -> Option<u64> {
self.mappings.get_global_id(layer_id, local_id)
}
pub fn get_layer_vectors(&self, layer_id: usize) -> Vec<u64> {
self.mappings.get_layer_vectors(layer_id)
}
pub fn get_vector_level(&self, vector_id: u64) -> Option<usize> {
self.vector_levels.get(&vector_id).copied()
}
pub fn get_statistics(&self) -> (usize, Vec<usize>, usize) {
let total_vectors = self.vector_levels.len();
let max_layers = self.config.ml as usize;
let mut layer_counts = vec![0; max_layers];
for (level, count) in layer_counts.iter_mut().enumerate().take(max_layers) {
*count = self.mappings.get_layer_vectors(level).len();
}
let memory_usage = self.mappings.memory_usage();
(total_vectors, layer_counts, memory_usage)
}
pub fn validate_consistency(&self) -> Result<(), HnswError> {
self.mappings.validate_consistency()
}
pub fn clear(&mut self) -> Result<(), HnswError> {
self.mappings.clear();
self.vector_levels.clear();
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::hnsw_config;
    use rand::rngs::StdRng;

    /// Explicitly numbered mappings resolve in both directions.
    #[test]
    fn test_layer_mappings_basic_operations() {
        let mut table = LayerMappings::new(3);
        for &(global, layer, local) in &[(1u64, 0usize, 0u64), (1, 1, 0), (2, 0, 1)] {
            table.add_mapping(global, layer, Some(local)).unwrap();
        }

        let forward = [
            (1u64, 0usize, Some(0u64)),
            (1, 1, Some(0)),
            (2, 0, Some(1)),
            (3, 0, None),
        ];
        for &(global, layer, expected) in &forward {
            assert_eq!(table.get_local_id(global, layer), expected);
        }

        let reverse = [(0usize, 0u64, Some(1u64)), (0, 1, Some(2)), (1, 0, Some(1))];
        for &(layer, local, expected) in &reverse {
            assert_eq!(table.get_global_id(layer, local), expected);
        }
    }

    /// Without an explicit id, local ids are handed out in insertion order.
    #[test]
    fn test_layer_mappings_sequential_assignment() {
        let mut table = LayerMappings::new(2);
        for global in 1..=3u64 {
            table.add_mapping(global, 0, None).unwrap();
        }
        assert_eq!(table.get_local_id(1, 0), Some(0));
        assert_eq!(table.get_local_id(2, 0), Some(1));
        assert_eq!(table.get_layer_vectors(0), vec![1, 2, 3]);
    }

    /// Skipping ahead in the local id sequence is refused and stores nothing.
    #[test]
    fn test_layer_mappings_sequential_violation() {
        let mut table = LayerMappings::new(2);
        table.add_mapping(1, 0, Some(0)).unwrap();
        let skipped = table.add_mapping(2, 0, Some(2));
        assert!(skipped.is_err());
        assert_eq!(table.get_local_id(2, 0), None);
    }

    /// Two distributors with the same seed produce the same level histogram.
    #[test]
    fn test_level_distributor_deterministic() {
        let seed = 42u64;
        let mut first = LevelDistributor::new(16.0, 5).with_seed(seed);
        let mut second = LevelDistributor::new(16.0, 5).with_seed(seed);
        let mut histogram_first = vec![0; 5];
        let mut histogram_second = vec![0; 5];
        for _ in 0..1000 {
            histogram_first[first.sample_level(None::<&mut StdRng>)] += 1;
            histogram_second[second.sample_level(None::<&mut StdRng>)] += 1;
        }
        assert_eq!(histogram_first, histogram_second);
    }

    /// Level probabilities are powers of `1 / base_m` and expectations shrink
    /// with the level.
    #[test]
    fn test_level_distributor_mathematical_properties() {
        let distributor = LevelDistributor::new(16.0, 4);
        let expected_probabilities = [1.0, 1.0 / 16.0, 1.0 / 256.0, 1.0 / 4096.0];
        for (level, &probability) in expected_probabilities.iter().enumerate() {
            assert_eq!(distributor.level_probability(level), probability);
        }

        let total_vectors = 10000;
        let expected_l0 = distributor.expected_vectors_at_level(total_vectors, 0);
        let expected_l1 = distributor.expected_vectors_at_level(total_vectors, 1);
        let expected_l2 = distributor.expected_vectors_at_level(total_vectors, 2);
        assert!(expected_l0 > expected_l1);
        assert!(expected_l1 > expected_l2);

        // Geometric series limit: total * m / (m - 1).
        let expected_total_slots = total_vectors as f64 * (16.0 / 15.0);
        let sum = expected_l0 + expected_l1 + expected_l2;
        assert!((sum - expected_total_slots).abs() < 200.0);
    }

    /// Each insert yields one assignment per layer up to the drawn level, and
    /// the very first vector gets local id 0 everywhere.
    #[test]
    fn test_multilayer_node_manager_basic_operations() {
        let config = hnsw_config()
            .m_connections(16)
            .max_layers(8)
            .build()
            .unwrap();
        let mut manager = MultiLayerNodeManager::new(config).unwrap();

        let (first_level, first_assignments) = manager.insert_vector(1).unwrap();
        let (second_level, second_assignments) = manager.insert_vector(2).unwrap();
        let (third_level, third_assignments) = manager.insert_vector(3).unwrap();

        assert!(first_level <= 7);
        assert_eq!(first_assignments.len(), first_level + 1);
        assert_eq!(second_assignments.len(), second_level + 1);
        assert_eq!(third_assignments.len(), third_level + 1);

        for &(_, local_id) in &first_assignments {
            assert_eq!(local_id, 0u64);
        }
    }

    /// Statistics report every vector on layer 0 and non-increasing counts
    /// towards the upper layers.
    #[test]
    fn test_multilayer_node_manager_statistics() {
        let config = hnsw_config().max_layers(4).build().unwrap();
        let mut manager = MultiLayerNodeManager::new(config).unwrap();
        for vector_id in 1..=20 {
            manager.insert_vector(vector_id).unwrap();
        }

        let (total, layer_counts, memory) = manager.get_statistics();
        assert_eq!(total, 20);
        assert_eq!(layer_counts.len(), 4);
        assert!(layer_counts.windows(2).all(|pair| pair[0] >= pair[1]));
        assert_eq!(layer_counts[0], 20);
        assert!(memory > 0);
    }

    /// Validation passes on an untouched manager and fails once a reverse
    /// entry is deleted behind its back.
    #[test]
    fn test_multilayer_node_manager_consistency() {
        let config = hnsw_config().build().unwrap();
        let mut manager = MultiLayerNodeManager::new(config).unwrap();
        for vector_id in 1..=3 {
            manager.insert_vector(vector_id).unwrap();
        }
        assert!(manager.validate_consistency().is_ok());

        manager.mappings.local_to_global[0].remove(&1);
        assert!(manager.validate_consistency().is_err());
    }

    /// Removing a vector leaves the local ids of the others untouched.
    #[test]
    fn test_multilayer_node_manager_removal() {
        let config = hnsw_config().build().unwrap();
        let mut manager = MultiLayerNodeManager::new(config).unwrap();
        for vector_id in 1..=3 {
            let _assignment = manager.insert_vector(vector_id).unwrap();
        }

        manager.remove_vector(2).unwrap();

        assert_eq!(manager.get_local_id(2, 0), None);
        assert_eq!(manager.get_local_id(1, 0), Some(0));
        assert_eq!(manager.get_local_id(3, 0), Some(2));
        let (remaining, _, _) = manager.get_statistics();
        assert_eq!(remaining, 2);
    }
}