nodedb_vector/hnsw/
build.rs1use crate::error::VectorError;
6use crate::hnsw::graph::{Candidate, HnswIndex, Node};
7use crate::hnsw::search::search_layer;
8
9impl HnswIndex {
10 pub fn insert(&mut self, vector: Vec<f32>) -> Result<(), VectorError> {
18 self.ensure_mutable_neighbors();
20
21 if vector.len() != self.dim {
22 return Err(VectorError::DimensionMismatch {
23 expected: self.dim,
24 got: vector.len(),
25 });
26 }
27
28 let new_id = self.nodes.len() as u32;
29 let new_layer = self.random_layer();
30
31 let node = Node {
32 vector,
33 neighbors: (0..=new_layer).map(|_| Vec::new()).collect(),
34 deleted: false,
35 };
36 self.nodes.push(node);
37
38 let Some(ep) = self.entry_point else {
39 self.entry_point = Some(new_id);
40 self.max_layer = new_layer;
41 return Ok(());
42 };
43
44 let query = self.nodes[new_id as usize].vector.clone();
46
47 let mut current_ep = ep;
48
49 if self.max_layer > new_layer {
51 for layer in (new_layer + 1..=self.max_layer).rev() {
52 let results = search_layer(self, &query, current_ep, 1, layer, None, 0);
53 if let Some(nearest) = results.first() {
54 current_ep = nearest.id;
55 }
56 }
57 }
58
59 let insert_top = new_layer.min(self.max_layer);
61 for layer in (0..=insert_top).rev() {
62 let ef = self.params.ef_construction;
63 let candidates = search_layer(self, &query, current_ep, ef, layer, None, 0);
64
65 let m = self.max_neighbors(layer);
66 let selected = select_neighbors_heuristic(self, &candidates, m);
67
68 self.nodes[new_id as usize].neighbors[layer] = selected.iter().map(|c| c.id).collect();
69
70 for neighbor in &selected {
71 let nid = neighbor.id as usize;
72 self.nodes[nid].neighbors[layer].push(new_id);
73
74 if self.nodes[nid].neighbors[layer].len() > m {
75 let node_vec = self.nodes[nid].vector.clone();
76 self.prune_neighbors(nid, layer, &node_vec, m);
77 }
78 }
79
80 if let Some(nearest) = candidates.first() {
81 current_ep = nearest.id;
82 }
83 }
84
85 if new_layer > self.max_layer {
86 self.entry_point = Some(new_id);
87 self.max_layer = new_layer;
88 }
89
90 Ok(())
91 }
92
93 fn prune_neighbors(&mut self, node_idx: usize, layer: usize, node_vec: &[f32], m: usize) {
95 let neighbor_ids: Vec<u32> = self.nodes[node_idx].neighbors[layer].clone();
96
97 let mut candidates: Vec<Candidate> = neighbor_ids
98 .iter()
99 .map(|&nid| Candidate {
100 id: nid,
101 dist: self.dist_to_node(node_vec, nid),
102 })
103 .collect();
104 candidates.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
105
106 let selected = select_neighbors_heuristic(self, &candidates, m);
107 self.nodes[node_idx].neighbors[layer] = selected.iter().map(|c| c.id).collect();
108 }
109}
110
111fn select_neighbors_heuristic(
113 index: &HnswIndex,
114 candidates: &[Candidate],
115 m: usize,
116) -> Vec<Candidate> {
117 let mut selected: Vec<Candidate> = Vec::with_capacity(m);
119
120 for candidate in candidates {
121 if selected.len() >= m {
122 break;
123 }
124
125 let candidate_vec = &index.nodes[candidate.id as usize].vector;
126 let selected_vecs: Vec<&[f32]> = selected
127 .iter()
128 .map(|s| index.nodes[s.id as usize].vector.as_slice())
129 .collect();
130
131 let is_diverse = crate::batch_distance::is_diverse_batched(
132 candidate_vec,
133 candidate.dist,
134 &selected_vecs,
135 index.params.metric,
136 );
137
138 if is_diverse {
139 selected.push(*candidate);
140 }
141 }
142
143 if selected.len() < m {
145 let selected_ids: std::collections::HashSet<u32> = selected.iter().map(|c| c.id).collect();
146 for candidate in candidates {
147 if selected.len() >= m {
148 break;
149 }
150 if !selected_ids.contains(&candidate.id) {
151 selected.push(*candidate);
152 }
153 }
154 }
155
156 selected
157}
158
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Builds the 3-dimensional, L2, fixed-seed index shared by every test.
    fn make_index() -> HnswIndex {
        let params = HnswParams {
            m: 4,
            m0: 8,
            ef_construction: 32,
            metric: DistanceMetric::L2,
        };
        HnswIndex::with_seed(3, params, 12345)
    }

    /// A single insert yields a one-element index whose entry point is node 0.
    #[test]
    fn insert_single() {
        let mut index = make_index();
        index.insert(vec![1.0, 0.0, 0.0]).unwrap();

        assert_eq!(index.len(), 1);
        assert_eq!(index.entry_point(), Some(0));
    }

    /// After 100 inserts no neighbour list exceeds its per-layer limit.
    #[test]
    fn insert_many_maintains_invariants() {
        let mut index = make_index();
        for i in 0..100 {
            let scale = i as f32;
            index
                .insert(vec![scale * 0.1, scale * 0.2, scale * 0.3])
                .unwrap();
        }
        assert_eq!(index.len(), 100);
        assert!(index.entry_point().is_some());

        // Layer 0 is bounded by `m0`.
        for node in &index.nodes {
            assert!(node.neighbors[0].len() <= index.params.m0);
        }

        // Every layer above 0 is bounded by `m`.
        for node in &index.nodes {
            let upper_layers = node.neighbors.iter().enumerate().skip(1);
            for (layer, neighbors) in upper_layers {
                assert!(
                    neighbors.len() <= index.params.m,
                    "layer {layer} neighbor count {} exceeds m={}",
                    neighbors.len(),
                    index.params.m
                );
            }
        }
    }

    /// Searching with a stored vector returns that very node as the top hit.
    #[test]
    fn all_nodes_reachable_from_entry() {
        let mut index = make_index();
        for i in 0..20 {
            index.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }

        for target in 0..20u32 {
            let query = index.get_vector(target).unwrap().to_vec();
            let hits = index.search(&query, 1, 32);
            assert_eq!(hits[0].id, target, "node {target} not reachable");
        }
    }

    /// Deleting the even ids and compacting leaves the ten odd vectors findable.
    #[test]
    fn compact_removes_tombstones() {
        let mut index = make_index();
        for i in 0..20u32 {
            index.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }

        // Tombstone every even id.
        for even_id in (0..20u32).step_by(2) {
            assert!(index.delete(even_id));
        }
        assert_eq!(index.compact(), 10);
        assert_eq!(index.len(), 10);

        // Ids may have changed after compaction, so compare by vector content.
        for target_old_id in (1..20u32).step_by(2) {
            let expected_x = target_old_id as f32;
            let hits = index.search(&[expected_x, 0.0, 0.0], 1, 32);
            assert!(!hits.is_empty());

            let found_vec = index.get_vector(hits[0].id).unwrap();
            assert_eq!(found_vec[0], expected_x);
        }
    }
}