use crate::error::VectorError;
use crate::hnsw::graph::{Candidate, HnswIndex, Node};
use crate::hnsw::search::search_layer;
impl HnswIndex {
    /// Adds `vector` to the index as a new node and links it into the graph.
    ///
    /// The node receives the next sequential id (the node count before the
    /// call) and a top layer drawn from `random_layer`. When the index has no
    /// entry point yet, the node simply becomes the entry point. Otherwise the
    /// graph is descended from the current entry point: layers above the
    /// node's top layer are searched with a beam width of 1 to pick a starting
    /// point, and on every layer the node lives on it is linked in both
    /// directions to neighbours chosen by `select_neighbors_heuristic` from an
    /// `ef_construction`-wide search. A neighbour whose list grows beyond
    /// `max_neighbors(layer)` is pruned back down.
    ///
    /// # Errors
    ///
    /// Returns [`VectorError::DimensionMismatch`] when `vector.len()` differs
    /// from the index dimension. The index is not modified in that case.
    ///
    /// # Panics
    ///
    /// Panics if the current node count does not fit in a `u32`, because node
    /// ids are stored as `u32`.
    pub fn insert(&mut self, vector: Vec<f32>) -> Result<(), VectorError> {
        // Validate first so a rejected insert has no side effects on the index.
        if vector.len() != self.dim {
            return Err(VectorError::DimensionMismatch {
                expected: self.dim,
                got: vector.len(),
            });
        }
        self.ensure_mutable_neighbors();

        // A silent wrap-around here would alias an existing node's id and
        // overwrite its adjacency lists, so fail loudly instead.
        let node_count = self.nodes.len();
        assert!(
            node_count <= u32::MAX as usize,
            "HNSW index is full: node ids must fit in u32"
        );
        let new_id = node_count as u32;
        let new_layer = self.random_layer();

        self.nodes.push(Node {
            vector,
            // One (initially empty) adjacency list per layer 0..=new_layer.
            neighbors: (0..=new_layer).map(|_| Vec::new()).collect(),
            deleted: false,
        });

        // First node ever: it is the entry point and there is nothing to link.
        let Some(ep) = self.entry_point else {
            self.entry_point = Some(new_id);
            self.max_layer = new_layer;
            return Ok(());
        };

        // Own a copy of the query: the loops below mutate `self.nodes` while
        // the query vector is still needed.
        let query = self.nodes[new_id as usize].vector.clone();
        let mut current_ep = ep;

        // Greedy descent through the layers above the new node's top layer.
        // The range is empty when the new node reaches or exceeds `max_layer`.
        for layer in (new_layer + 1..=self.max_layer).rev() {
            let results = search_layer(self, &query, current_ep, 1, layer, None);
            if let Some(nearest) = results.first() {
                current_ep = nearest.id;
            }
        }

        // Link the node on every layer it shares with the existing graph,
        // from the top down.
        let insert_top = new_layer.min(self.max_layer);
        for layer in (0..=insert_top).rev() {
            let ef = self.params.ef_construction;
            let candidates = search_layer(self, &query, current_ep, ef, layer, None);
            let m = self.max_neighbors(layer);
            let selected = select_neighbors_heuristic(self, &candidates, m);

            self.nodes[new_id as usize].neighbors[layer] =
                selected.iter().map(|c| c.id).collect();

            // Add the reverse edges, shrinking any list that overflows `m`.
            for neighbor in &selected {
                let nid = neighbor.id as usize;
                self.nodes[nid].neighbors[layer].push(new_id);
                if self.nodes[nid].neighbors[layer].len() > m {
                    // Cloned because `prune_neighbors` borrows `self` mutably.
                    let node_vec = self.nodes[nid].vector.clone();
                    self.prune_neighbors(nid, layer, &node_vec, m);
                }
            }

            // The first search result seeds the search on the next layer down.
            if let Some(nearest) = candidates.first() {
                current_ep = nearest.id;
            }
        }

        // A node taller than the current graph becomes the new entry point.
        if new_layer > self.max_layer {
            self.entry_point = Some(new_id);
            self.max_layer = new_layer;
        }
        Ok(())
    }

    /// Rebuilds the `layer` adjacency list of node `node_idx` so that it holds
    /// at most `m` neighbours.
    ///
    /// Every current neighbour is scored by `dist_to_node(node_vec, id)`, the
    /// scored list is sorted by ascending distance, and
    /// `select_neighbors_heuristic` picks the survivors. `node_vec` is the
    /// vector distances are measured from; `insert` passes the vector of node
    /// `node_idx` itself.
    fn prune_neighbors(&mut self, node_idx: usize, layer: usize, node_vec: &[f32], m: usize) {
        let neighbor_ids: Vec<u32> = self.nodes[node_idx].neighbors[layer].clone();
        let mut candidates: Vec<Candidate> = neighbor_ids
            .iter()
            .map(|&nid| Candidate {
                id: nid,
                dist: self.dist_to_node(node_vec, nid),
            })
            .collect();
        // `total_cmp` gives a total order over f32, so NaN cannot break the sort.
        candidates.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
        let selected = select_neighbors_heuristic(self, &candidates, m);
        self.nodes[node_idx].neighbors[layer] = selected.iter().map(|c| c.id).collect();
    }
}
/// Chooses up to `m` neighbours from `candidates`.
///
/// Selection runs in two passes over `candidates`, in the order given (the
/// callers in this file pass them nearest-first):
///
/// 1. A candidate is accepted when `is_diverse_batched` approves it, given the
///    candidate's vector, its distance, and the vectors accepted so far.
/// 2. If fewer than `m` were accepted, the remaining slots are filled with the
///    candidates that were passed over, again in input order.
///
/// Each candidate id appears at most once in the result, even if the input
/// slice repeats an id. Returns an empty vector when `m` is 0 or `candidates`
/// is empty.
fn select_neighbors_heuristic(
    index: &HnswIndex,
    candidates: &[Candidate],
    m: usize,
) -> Vec<Candidate> {
    let mut selected: Vec<Candidate> = Vec::with_capacity(m);
    // Vectors of the accepted candidates, kept in step with `selected` so the
    // list is not rebuilt (and reallocated) for every candidate.
    let mut selected_vecs: Vec<&[f32]> = Vec::with_capacity(m);

    for candidate in candidates {
        if selected.len() >= m {
            break;
        }
        let candidate_vec = &index.nodes[candidate.id as usize].vector;
        let is_diverse = crate::batch_distance::is_diverse_batched(
            candidate_vec,
            candidate.dist,
            &selected_vecs,
            index.params.metric,
        );
        if is_diverse {
            selected.push(*candidate);
            selected_vecs.push(candidate_vec.as_slice());
        }
    }

    // Backfill with rejected candidates so the node is not left under-connected.
    if selected.len() < m {
        let mut taken: std::collections::HashSet<u32> =
            selected.iter().map(|c| c.id).collect();
        for candidate in candidates {
            if selected.len() >= m {
                break;
            }
            // `insert` returns false for an id that is already present, which
            // also guards against the same id being backfilled twice.
            if taken.insert(candidate.id) {
                selected.push(*candidate);
            }
        }
    }
    selected
}
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Builds the empty index shared by every test: constructed through
    /// `with_seed` with seed 12345, L2 metric, and a leading argument of 3 that
    /// matches the 3-component vectors the tests insert.
    fn make_index() -> HnswIndex {
        let params = HnswParams {
            m: 4,
            m0: 8,
            ef_construction: 32,
            metric: DistanceMetric::L2,
        };
        HnswIndex::with_seed(3, params, 12345)
    }

    /// A lone insert yields a one-element index whose entry point is node 0.
    #[test]
    fn insert_single() {
        let mut index = make_index();
        index.insert(vec![1.0, 0.0, 0.0]).unwrap();
        assert_eq!(index.len(), 1);
        assert_eq!(index.entry_point(), Some(0));
    }

    /// After 100 inserts no adjacency list exceeds its per-layer cap
    /// (`m0` on layer 0, `m` on every higher layer).
    #[test]
    fn insert_many_maintains_invariants() {
        let mut index = make_index();
        for i in 0..100 {
            let t = i as f32;
            index.insert(vec![t * 0.1, t * 0.2, t * 0.3]).unwrap();
        }
        assert_eq!(index.len(), 100);
        assert!(index.entry_point().is_some());

        let (m, m0) = (index.params.m, index.params.m0);
        for node in &index.nodes {
            assert!(node.neighbors[0].len() <= m0);
        }
        for node in &index.nodes {
            for (layer, neighbors) in node.neighbors.iter().enumerate().skip(1) {
                let count = neighbors.len();
                assert!(
                    count <= m,
                    "layer {layer} neighbor count {} exceeds m={}",
                    count,
                    m
                );
            }
        }
    }

    /// Searching for each stored vector returns that very node as the top hit.
    #[test]
    fn all_nodes_reachable_from_entry() {
        let mut index = make_index();
        for x in 0..20 {
            index.insert(vec![x as f32, 0.0, 0.0]).unwrap();
        }
        for target in 0..20u32 {
            let query = index.get_vector(target).unwrap().to_vec();
            let hits = index.search(&query, 1, 32);
            assert_eq!(hits[0].id, target, "node {target} not reachable");
        }
    }

    /// Deleting every even id and compacting leaves the ten odd-valued vectors,
    /// each still found by a search for its own coordinates.
    #[test]
    fn compact_removes_tombstones() {
        let mut index = make_index();
        for x in 0..20u32 {
            index.insert(vec![x as f32, 0.0, 0.0]).unwrap();
        }
        for even_id in (0..20u32).step_by(2) {
            assert!(index.delete(even_id));
        }
        assert_eq!(index.compact(), 10);
        assert_eq!(index.len(), 10);

        for surviving_old_id in (1..20u32).step_by(2) {
            let expected_x = surviving_old_id as f32;
            let query = vec![expected_x, 0.0, 0.0];
            let hits = index.search(&query, 1, 32);
            assert!(!hits.is_empty());
            let found = index.get_vector(hits[0].id).unwrap();
            assert_eq!(found[0], expected_x);
        }
    }
}