nodedb_vector/hnsw/
build.rs1use crate::error::VectorError;
4use crate::hnsw::graph::{Candidate, HnswIndex, Node};
5use crate::hnsw::search::search_layer;
6
7impl HnswIndex {
8 pub fn insert(&mut self, vector: Vec<f32>) -> Result<(), VectorError> {
16 self.ensure_mutable_neighbors();
18
19 if vector.len() != self.dim {
20 return Err(VectorError::DimensionMismatch {
21 expected: self.dim,
22 got: vector.len(),
23 });
24 }
25
26 let new_id = self.nodes.len() as u32;
27 let new_layer = self.random_layer();
28
29 let node = Node {
30 vector,
31 neighbors: (0..=new_layer).map(|_| Vec::new()).collect(),
32 deleted: false,
33 };
34 self.nodes.push(node);
35
36 let Some(ep) = self.entry_point else {
37 self.entry_point = Some(new_id);
38 self.max_layer = new_layer;
39 return Ok(());
40 };
41
42 let query = self.nodes[new_id as usize].vector.clone();
44
45 let mut current_ep = ep;
46
47 if self.max_layer > new_layer {
49 for layer in (new_layer + 1..=self.max_layer).rev() {
50 let results = search_layer(self, &query, current_ep, 1, layer, None, 0);
51 if let Some(nearest) = results.first() {
52 current_ep = nearest.id;
53 }
54 }
55 }
56
57 let insert_top = new_layer.min(self.max_layer);
59 for layer in (0..=insert_top).rev() {
60 let ef = self.params.ef_construction;
61 let candidates = search_layer(self, &query, current_ep, ef, layer, None, 0);
62
63 let m = self.max_neighbors(layer);
64 let selected = select_neighbors_heuristic(self, &candidates, m);
65
66 self.nodes[new_id as usize].neighbors[layer] = selected.iter().map(|c| c.id).collect();
67
68 for neighbor in &selected {
69 let nid = neighbor.id as usize;
70 self.nodes[nid].neighbors[layer].push(new_id);
71
72 if self.nodes[nid].neighbors[layer].len() > m {
73 let node_vec = self.nodes[nid].vector.clone();
74 self.prune_neighbors(nid, layer, &node_vec, m);
75 }
76 }
77
78 if let Some(nearest) = candidates.first() {
79 current_ep = nearest.id;
80 }
81 }
82
83 if new_layer > self.max_layer {
84 self.entry_point = Some(new_id);
85 self.max_layer = new_layer;
86 }
87
88 Ok(())
89 }
90
91 fn prune_neighbors(&mut self, node_idx: usize, layer: usize, node_vec: &[f32], m: usize) {
93 let neighbor_ids: Vec<u32> = self.nodes[node_idx].neighbors[layer].clone();
94
95 let mut candidates: Vec<Candidate> = neighbor_ids
96 .iter()
97 .map(|&nid| Candidate {
98 id: nid,
99 dist: self.dist_to_node(node_vec, nid),
100 })
101 .collect();
102 candidates.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
103
104 let selected = select_neighbors_heuristic(self, &candidates, m);
105 self.nodes[node_idx].neighbors[layer] = selected.iter().map(|c| c.id).collect();
106 }
107}
108
109fn select_neighbors_heuristic(
111 index: &HnswIndex,
112 candidates: &[Candidate],
113 m: usize,
114) -> Vec<Candidate> {
115 let mut selected: Vec<Candidate> = Vec::with_capacity(m);
116
117 for candidate in candidates {
118 if selected.len() >= m {
119 break;
120 }
121
122 let candidate_vec = &index.nodes[candidate.id as usize].vector;
123 let selected_vecs: Vec<&[f32]> = selected
124 .iter()
125 .map(|s| index.nodes[s.id as usize].vector.as_slice())
126 .collect();
127
128 let is_diverse = crate::batch_distance::is_diverse_batched(
129 candidate_vec,
130 candidate.dist,
131 &selected_vecs,
132 index.params.metric,
133 );
134
135 if is_diverse {
136 selected.push(*candidate);
137 }
138 }
139
140 if selected.len() < m {
142 let selected_ids: std::collections::HashSet<u32> = selected.iter().map(|c| c.id).collect();
143 for candidate in candidates {
144 if selected.len() >= m {
145 break;
146 }
147 if !selected_ids.contains(&candidate.id) {
148 selected.push(*candidate);
149 }
150 }
151 }
152
153 selected
154}
155
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::error::VectorError;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Small 3-D L2 index with a fixed seed so layer assignment is reproducible.
    fn make_index() -> HnswIndex {
        HnswIndex::with_seed(
            3,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 32,
                metric: DistanceMetric::L2,
            },
            12345,
        )
    }

    #[test]
    fn insert_single() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 0.0, 0.0]).unwrap();
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.entry_point(), Some(0));
    }

    #[test]
    fn insert_rejects_dimension_mismatch() {
        let mut idx = make_index();
        let err = idx.insert(vec![1.0, 0.0]).unwrap_err();
        assert!(
            matches!(err, VectorError::DimensionMismatch { expected: 3, got: 2 }),
            "unexpected error: {err:?}"
        );
        assert_eq!(idx.len(), 0);
        assert_eq!(idx.entry_point(), None);

        // A rejected vector must not disturb subsequent inserts.
        idx.insert(vec![1.0, 0.0, 0.0]).unwrap();
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.entry_point(), Some(0));
    }

    #[test]
    fn insert_many_maintains_invariants() {
        let mut idx = make_index();
        for i in 0..100 {
            let v = vec![(i as f32) * 0.1, (i as f32) * 0.2, (i as f32) * 0.3];
            idx.insert(v).unwrap();
        }
        assert_eq!(idx.len(), 100);
        assert!(idx.entry_point().is_some());

        for node in &idx.nodes {
            assert!(node.neighbors[0].len() <= idx.params.m0);
        }
        for node in &idx.nodes {
            for (layer, neighbors) in node.neighbors.iter().enumerate().skip(1) {
                assert!(
                    neighbors.len() <= idx.params.m,
                    "layer {layer} neighbor count {} exceeds m={}",
                    neighbors.len(),
                    idx.params.m
                );
            }
        }

        // Every edge must point at an existing node and never back at its owner.
        for (id, node) in idx.nodes.iter().enumerate() {
            for (layer, neighbors) in node.neighbors.iter().enumerate() {
                for &n in neighbors.iter() {
                    assert!(
                        (n as usize) < idx.nodes.len(),
                        "node {id} layer {layer} links to out-of-range id {n}"
                    );
                    assert_ne!(n as usize, id, "node {id} links to itself on layer {layer}");
                }
            }
        }
    }

    #[test]
    fn all_nodes_reachable_from_entry() {
        let mut idx = make_index();
        for i in 0..20 {
            idx.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }

        for target in 0..20u32 {
            let query = idx.get_vector(target).unwrap().to_vec();
            let results = idx.search(&query, 1, 32);
            let nearest = results.first().map(|r| r.id);
            assert_eq!(nearest, Some(target), "node {target} not reachable");
        }
    }

    #[test]
    fn compact_removes_tombstones() {
        let mut idx = make_index();
        for i in 0..20u32 {
            idx.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }

        for i in (0..20u32).step_by(2) {
            assert!(idx.delete(i));
        }
        assert_eq!(idx.compact(), 10);
        assert_eq!(idx.len(), 10);

        for target_old_id in (1..20u32).step_by(2) {
            let query = vec![target_old_id as f32, 0.0, 0.0];
            let results = idx.search(&query, 1, 32);
            assert!(!results.is_empty());
            let found_vec = idx.get_vector(results[0].id).unwrap();
            assert_eq!(found_vec[0], target_old_id as f32);
        }
    }
}