1use crate::distance::distance;
4use crate::error::VectorError;
5use crate::hnsw::{Candidate, HnswIndex, Node};
6use crate::search::search_layer;
7
8impl HnswIndex {
9 pub fn insert(&mut self, vector: Vec<f32>) -> Result<(), VectorError> {
17 if vector.len() != self.dim {
18 return Err(VectorError::DimensionMismatch {
19 expected: self.dim,
20 got: vector.len(),
21 });
22 }
23
24 let new_id = self.nodes.len() as u32;
25 let new_layer = self.random_layer();
26
27 let node = Node {
28 vector,
29 neighbors: (0..=new_layer).map(|_| Vec::new()).collect(),
30 deleted: false,
31 };
32 self.nodes.push(node);
33
34 let Some(ep) = self.entry_point else {
35 self.entry_point = Some(new_id);
36 self.max_layer = new_layer;
37 return Ok(());
38 };
39
40 let query = self.nodes[new_id as usize].vector.clone();
42
43 let mut current_ep = ep;
44
45 if self.max_layer > new_layer {
47 for layer in (new_layer + 1..=self.max_layer).rev() {
48 let results = search_layer(self, &query, current_ep, 1, layer, None);
49 if let Some(nearest) = results.first() {
50 current_ep = nearest.id;
51 }
52 }
53 }
54
55 let insert_top = new_layer.min(self.max_layer);
57 for layer in (0..=insert_top).rev() {
58 let ef = self.params.ef_construction;
59 let candidates = search_layer(self, &query, current_ep, ef, layer, None);
60
61 let m = self.max_neighbors(layer);
62 let selected = select_neighbors_heuristic(self, &candidates, m);
63
64 self.nodes[new_id as usize].neighbors[layer] = selected.iter().map(|c| c.id).collect();
65
66 for neighbor in &selected {
67 let nid = neighbor.id as usize;
68 self.nodes[nid].neighbors[layer].push(new_id);
69
70 if self.nodes[nid].neighbors[layer].len() > m {
71 let node_vec = self.nodes[nid].vector.clone();
72 self.prune_neighbors(nid, layer, &node_vec, m);
73 }
74 }
75
76 if let Some(nearest) = candidates.first() {
77 current_ep = nearest.id;
78 }
79 }
80
81 if new_layer > self.max_layer {
82 self.entry_point = Some(new_id);
83 self.max_layer = new_layer;
84 }
85
86 Ok(())
87 }
88
89 fn prune_neighbors(&mut self, node_idx: usize, layer: usize, node_vec: &[f32], m: usize) {
91 let neighbor_ids: Vec<u32> = self.nodes[node_idx].neighbors[layer].clone();
92
93 let mut candidates: Vec<Candidate> = neighbor_ids
94 .iter()
95 .map(|&nid| Candidate {
96 id: nid,
97 dist: self.dist_to_node(node_vec, nid),
98 })
99 .collect();
100 candidates.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
101
102 let selected = select_neighbors_heuristic(self, &candidates, m);
103 self.nodes[node_idx].neighbors[layer] = selected.iter().map(|c| c.id).collect();
104 }
105}
106
107fn select_neighbors_heuristic(
113 index: &HnswIndex,
114 candidates: &[Candidate],
115 m: usize,
116) -> Vec<Candidate> {
117 let mut selected: Vec<Candidate> = Vec::with_capacity(m);
118
119 for candidate in candidates {
120 if selected.len() >= m {
121 break;
122 }
123
124 let dist_to_query = candidate.dist;
125 let is_diverse = selected.iter().all(|s| {
126 let dist_to_selected = distance(
127 &index.nodes[candidate.id as usize].vector,
128 &index.nodes[s.id as usize].vector,
129 index.params.metric,
130 );
131 dist_to_query <= dist_to_selected
132 });
133
134 if is_diverse {
135 selected.push(*candidate);
136 }
137 }
138
139 if selected.len() < m {
141 let selected_ids: std::collections::HashSet<u32> = selected.iter().map(|c| c.id).collect();
142 for candidate in candidates {
143 if selected.len() >= m {
144 break;
145 }
146 if !selected_ids.contains(&candidate.id) {
147 selected.push(*candidate);
148 }
149 }
150 }
151
152 selected
153}
154
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Builds an empty 3-dimensional L2 index through `with_seed`, so every test starts from
    /// the same configuration.
    fn make_index() -> HnswIndex {
        HnswIndex::with_seed(
            3,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 32,
                metric: DistanceMetric::L2,
            },
            12345,
        )
    }

    /// The first inserted node gets id 0 and becomes the entry point.
    #[test]
    fn insert_single() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 0.0, 0.0]).unwrap();
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.entry_point(), Some(0));
    }

    /// A vector of the wrong length is rejected and leaves the index usable: the next valid
    /// insert still receives id 0.
    #[test]
    fn insert_rejects_wrong_dimension() {
        let mut idx = make_index();
        assert!(idx.insert(vec![1.0, 0.0]).is_err());

        idx.insert(vec![1.0, 0.0, 0.0]).unwrap();
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.entry_point(), Some(0));
    }

    /// After many inserts no adjacency list exceeds its per-layer degree limit
    /// (`m0` on layer 0, `m` on every higher layer).
    #[test]
    fn insert_many_maintains_invariants() {
        let mut idx = make_index();
        for i in 0..100 {
            let v = vec![(i as f32) * 0.1, (i as f32) * 0.2, (i as f32) * 0.3];
            idx.insert(v).unwrap();
        }
        assert_eq!(idx.len(), 100);
        assert!(idx.entry_point().is_some());

        for node in &idx.nodes {
            assert!(
                node.neighbors[0].len() <= idx.params.m0,
                "layer 0 neighbor count {} exceeds m0={}",
                node.neighbors[0].len(),
                idx.params.m0
            );
        }
        for node in &idx.nodes {
            for (layer, neighbors) in node.neighbors.iter().enumerate().skip(1) {
                assert!(
                    neighbors.len() <= idx.params.m,
                    "layer {layer} neighbor count {} exceeds m={}",
                    neighbors.len(),
                    idx.params.m
                );
            }
        }
    }

    /// Searching for each stored vector returns that very node as the top hit.
    #[test]
    fn all_nodes_reachable_from_entry() {
        let mut idx = make_index();
        for i in 0..20 {
            idx.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }

        for target in 0..20u32 {
            let query = idx.get_vector(target).unwrap().to_vec();
            let results = idx.search(&query, 1, 32);
            // Fail with a useful message instead of an index-out-of-bounds panic.
            assert!(
                !results.is_empty(),
                "search for node {target} returned no results"
            );
            assert_eq!(results[0].id, target, "node {target} not reachable");
        }
    }

    /// Deleting every even id and compacting leaves the ten odd vectors searchable.
    #[test]
    fn compact_removes_tombstones() {
        let mut idx = make_index();
        for i in 0..20u32 {
            idx.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }

        for i in (0..20u32).step_by(2) {
            assert!(idx.delete(i));
        }
        assert_eq!(idx.compact(), 10);
        assert_eq!(idx.len(), 10);

        // Ids may change during compaction, so surviving nodes are identified by their
        // first coordinate rather than by id.
        for target_old_id in (1..20u32).step_by(2) {
            let query = vec![target_old_id as f32, 0.0, 0.0];
            let results = idx.search(&query, 1, 32);
            assert!(!results.is_empty());
            let found_vec = idx.get_vector(results[0].id).unwrap();
            assert_eq!(found_vec[0], target_old_id as f32);
        }
    }

    /// An index restored from checkpoint bytes matches the original in size, dimension,
    /// entry point, top layer and search results.
    #[test]
    fn checkpoint_roundtrip() {
        let mut idx = make_index();
        for i in 0..50 {
            idx.insert(vec![(i as f32) * 0.1, (i as f32) * 0.2, (i as f32) * 0.3])
                .unwrap();
        }

        let bytes = idx.checkpoint_to_bytes();
        assert!(!bytes.is_empty());

        let restored = HnswIndex::from_checkpoint(&bytes).unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
        assert_eq!(restored.max_layer(), idx.max_layer());

        let query = vec![1.0, 2.0, 3.0];
        let orig_results = idx.search(&query, 5, 32);
        let rest_results = restored.search(&query, 5, 32);
        assert_eq!(orig_results.len(), rest_results.len());
        for (a, b) in orig_results.iter().zip(rest_results.iter()) {
            assert_eq!(a.id, b.id);
            assert!((a.distance - b.distance).abs() < 1e-6);
        }
    }
}