Skip to main content

nodedb_vector/hnsw/
build.rs

1//! HNSW insert algorithm (Malkov & Yashunin, Algorithm 1).
2
3use crate::error::VectorError;
4use crate::hnsw::graph::{Candidate, HnswIndex, Node};
5use crate::hnsw::search::search_layer;
6
7impl HnswIndex {
8    /// Insert a vector into the index.
9    ///
10    /// 1. Assign a random layer using the exponential distribution
11    /// 2. Greedily descend from the entry point to the new node's layer + 1
12    /// 3. At each layer from the node's layer down to 0, search for nearest
13    ///    neighbors, select via the diversity heuristic, and add bidirectional edges
14    /// 4. Prune over-connected nodes to maintain the M/M0 invariant
15    pub fn insert(&mut self, vector: Vec<f32>) -> Result<(), VectorError> {
16        // Materialize flat neighbor storage on first mutation.
17        self.ensure_mutable_neighbors();
18
19        if vector.len() != self.dim {
20            return Err(VectorError::DimensionMismatch {
21                expected: self.dim,
22                got: vector.len(),
23            });
24        }
25
26        let new_id = self.nodes.len() as u32;
27        let new_layer = self.random_layer();
28
29        let node = Node {
30            vector,
31            neighbors: (0..=new_layer).map(|_| Vec::new()).collect(),
32            deleted: false,
33        };
34        self.nodes.push(node);
35
36        let Some(ep) = self.entry_point else {
37            self.entry_point = Some(new_id);
38            self.max_layer = new_layer;
39            return Ok(());
40        };
41
42        // Clone the query vector to avoid aliasing self.nodes during mutation.
43        let query = self.nodes[new_id as usize].vector.clone();
44
45        let mut current_ep = ep;
46
47        // Phase 1: Greedy descent from top layer to new_layer + 1.
48        if self.max_layer > new_layer {
49            for layer in (new_layer + 1..=self.max_layer).rev() {
50                let results = search_layer(self, &query, current_ep, 1, layer, None, 0);
51                if let Some(nearest) = results.first() {
52                    current_ep = nearest.id;
53                }
54            }
55        }
56
57        // Phase 2: Insert at each layer from min(new_layer, max_layer) down to 0.
58        let insert_top = new_layer.min(self.max_layer);
59        for layer in (0..=insert_top).rev() {
60            let ef = self.params.ef_construction;
61            let candidates = search_layer(self, &query, current_ep, ef, layer, None, 0);
62
63            let m = self.max_neighbors(layer);
64            let selected = select_neighbors_heuristic(self, &candidates, m);
65
66            self.nodes[new_id as usize].neighbors[layer] = selected.iter().map(|c| c.id).collect();
67
68            for neighbor in &selected {
69                let nid = neighbor.id as usize;
70                self.nodes[nid].neighbors[layer].push(new_id);
71
72                if self.nodes[nid].neighbors[layer].len() > m {
73                    let node_vec = self.nodes[nid].vector.clone();
74                    self.prune_neighbors(nid, layer, &node_vec, m);
75                }
76            }
77
78            if let Some(nearest) = candidates.first() {
79                current_ep = nearest.id;
80            }
81        }
82
83        if new_layer > self.max_layer {
84            self.entry_point = Some(new_id);
85            self.max_layer = new_layer;
86        }
87
88        Ok(())
89    }
90
91    /// Prune a node's neighbor list using the diversity heuristic.
92    fn prune_neighbors(&mut self, node_idx: usize, layer: usize, node_vec: &[f32], m: usize) {
93        let neighbor_ids: Vec<u32> = self.nodes[node_idx].neighbors[layer].clone();
94
95        let mut candidates: Vec<Candidate> = neighbor_ids
96            .iter()
97            .map(|&nid| Candidate {
98                id: nid,
99                dist: self.dist_to_node(node_vec, nid),
100            })
101            .collect();
102        candidates.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
103
104        let selected = select_neighbors_heuristic(self, &candidates, m);
105        self.nodes[node_idx].neighbors[layer] = selected.iter().map(|c| c.id).collect();
106    }
107}
108
109/// Heuristic neighbor selection (Malkov & Yashunin, Algorithm 4).
110fn select_neighbors_heuristic(
111    index: &HnswIndex,
112    candidates: &[Candidate],
113    m: usize,
114) -> Vec<Candidate> {
115    let mut selected: Vec<Candidate> = Vec::with_capacity(m);
116
117    for candidate in candidates {
118        if selected.len() >= m {
119            break;
120        }
121
122        let candidate_vec = &index.nodes[candidate.id as usize].vector;
123        let selected_vecs: Vec<&[f32]> = selected
124            .iter()
125            .map(|s| index.nodes[s.id as usize].vector.as_slice())
126            .collect();
127
128        let is_diverse = crate::batch_distance::is_diverse_batched(
129            candidate_vec,
130            candidate.dist,
131            &selected_vecs,
132            index.params.metric,
133        );
134
135        if is_diverse {
136            selected.push(*candidate);
137        }
138    }
139
140    // Backfill if heuristic was too aggressive.
141    if selected.len() < m {
142        let selected_ids: std::collections::HashSet<u32> = selected.iter().map(|c| c.id).collect();
143        for candidate in candidates {
144            if selected.len() >= m {
145                break;
146            }
147            if !selected_ids.contains(&candidate.id) {
148                selected.push(*candidate);
149            }
150        }
151    }
152
153    selected
154}
155
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Build the small, seeded 3-dimensional L2 index shared by every test.
    fn make_index() -> HnswIndex {
        let params = HnswParams {
            m: 4,
            m0: 8,
            ef_construction: 32,
            metric: DistanceMetric::L2,
        };
        HnswIndex::with_seed(3, params, 12345)
    }

    /// A lone insert yields a one-element index whose entry point is node 0.
    #[test]
    fn insert_single() {
        let mut index = make_index();
        index.insert(vec![1.0, 0.0, 0.0]).unwrap();

        assert_eq!(index.len(), 1);
        assert_eq!(index.entry_point(), Some(0));
    }

    /// After 100 inserts no node exceeds m0 links on layer 0 or m links above it.
    #[test]
    fn insert_many_maintains_invariants() {
        let mut index = make_index();
        for i in 0..100 {
            let point = vec![(i as f32) * 0.1, (i as f32) * 0.2, (i as f32) * 0.3];
            index.insert(point).unwrap();
        }
        assert_eq!(index.len(), 100);
        assert!(index.entry_point().is_some());

        // Base layer is bounded by m0.
        let base_limit = index.params.m0;
        for node in index.nodes.iter() {
            assert!(node.neighbors[0].len() <= base_limit);
        }

        // Every layer above the base is bounded by m.
        for node in index.nodes.iter() {
            for (layer, neighbors) in node.neighbors.iter().enumerate() {
                if layer == 0 {
                    continue;
                }
                assert!(
                    neighbors.len() <= index.params.m,
                    "layer {layer} neighbor count {} exceeds m={}",
                    neighbors.len(),
                    index.params.m
                );
            }
        }
    }

    /// Searching for each stored vector returns that same node as the top hit.
    #[test]
    fn all_nodes_reachable_from_entry() {
        let mut index = make_index();
        for i in 0..20 {
            let point = vec![i as f32, 0.0, 0.0];
            index.insert(point).unwrap();
        }

        for target in 0..20u32 {
            let stored = index.get_vector(target).unwrap().to_vec();
            let hits = index.search(&stored, 1, 32);
            assert_eq!(hits[0].id, target, "node {target} not reachable");
        }
    }

    /// Compaction drops deleted nodes while the survivors stay searchable.
    #[test]
    fn compact_removes_tombstones() {
        let mut index = make_index();
        for i in 0..20u32 {
            let point = vec![i as f32, 0.0, 0.0];
            index.insert(point).unwrap();
        }

        // Tombstone every even id, then compact them away.
        let even_ids = (0..20u32).step_by(2);
        for id in even_ids {
            assert!(index.delete(id));
        }
        assert_eq!(index.compact(), 10);
        assert_eq!(index.len(), 10);

        // Each odd vector must still be found by its coordinates.
        let odd_ids = (1..20u32).step_by(2);
        for target_old_id in odd_ids {
            let probe = vec![target_old_id as f32, 0.0, 0.0];
            let hits = index.search(&probe, 1, 32);
            assert!(!hits.is_empty());

            let top_vec = index.get_vector(hits[0].id).unwrap();
            assert_eq!(top_vec[0], target_old_id as f32);
        }
    }
}