nodedb_vector/hnsw/
graph.rs1use std::cell::RefCell;
9
10use crate::distance::distance;
11use crate::hnsw::arena::BeamSearchArena;
12
13pub use nodedb_types::hnsw::HnswParams;
15
/// Floor for the initial capacity passed to the beam-search arena: the
/// constructors use `max(params.ef_construction, ARENA_INITIAL_CAPACITY)`.
pub(crate) const ARENA_INITIAL_CAPACITY: usize = 256;

/// Upper bound applied to the layer drawn by `HnswIndex::random_layer`.
pub const MAX_LAYER_CAP: usize = 16;
26
/// Pairs a node id with a distance value.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Id of the node this result refers to.
    pub id: u32,
    /// Distance associated with that node.
    pub distance: f32,
}
35
/// One vertex of the graph, as stored in `HnswIndex::nodes`.
pub struct Node {
    /// Vector payload of this node.
    pub vector: Vec<f32>,
    /// Adjacency lists indexed by layer: `neighbors[layer]` holds neighbour ids.
    pub neighbors: Vec<Vec<u32>>,
    /// Tombstone flag, toggled by `HnswIndex::delete` / `HnswIndex::undelete`.
    pub deleted: bool,
}
45
46pub struct HnswIndex {
52 pub(crate) params: HnswParams,
53 pub(crate) dim: usize,
54 pub(crate) nodes: Vec<Node>,
55 pub(crate) entry_point: Option<u32>,
56 pub(crate) max_layer: usize,
57 pub(crate) rng: Xorshift64,
58 pub(crate) flat_neighbors: Option<crate::hnsw::flat_neighbors::FlatNeighborStore>,
62 pub(crate) arena: RefCell<BeamSearchArena>,
70}
71
72impl HnswIndex {
73 #[inline]
76 pub(crate) fn neighbors_at(&self, node_id: u32, layer: usize) -> &[u32] {
77 if let Some(ref flat) = self.flat_neighbors {
78 return flat.neighbors_at(node_id, layer);
79 }
80 let node = &self.nodes[node_id as usize];
81 if layer < node.neighbors.len() {
82 &node.neighbors[layer]
83 } else {
84 &[]
85 }
86 }
87
88 #[inline]
90 pub(crate) fn node_num_layers(&self, node_id: u32) -> usize {
91 if let Some(ref flat) = self.flat_neighbors {
92 return flat.num_layers(node_id);
93 }
94 self.nodes[node_id as usize].neighbors.len()
95 }
96
97 pub(crate) fn ensure_mutable_neighbors(&mut self) {
100 if let Some(flat) = self.flat_neighbors.take() {
101 let nested = flat.to_nested(self.nodes.len());
102 for (i, layers) in nested.into_iter().enumerate() {
103 self.nodes[i].neighbors = layers;
104 }
105 }
106 }
107}
108
/// Small xorshift64 pseudo-random generator; the tuple field is the raw state.
pub struct Xorshift64(pub u64);

impl Xorshift64 {
    /// Builds a generator from `seed`.
    ///
    /// A zero seed is replaced by 1: the all-zero state maps to itself under
    /// the xorshift round, so it would never produce anything else.
    pub fn new(seed: u64) -> Self {
        let state = if seed == 0 { 1 } else { seed };
        Self(state)
    }

    /// Runs one 13/7/17 xorshift round, stores the new state, and returns it
    /// scaled by `u64::MAX` as an `f64`.
    pub fn next_f64(&mut self) -> f64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        (x as f64) / (u64::MAX as f64)
    }
}
124
/// A `(distance, id)` pair with a total order: ascending distance, then
/// ascending id.
///
/// NaN distances sort after every non-NaN distance (regardless of the NaN's
/// sign bit) and compare equal to each other, so the order stays transitive
/// and consistent with `==` even when a distance computation produces NaN.
#[derive(Clone, Copy)]
pub struct Candidate {
    /// Distance value used as the primary sort key.
    pub dist: f32,
    /// Node id, used as the tie-breaker.
    pub id: u32,
}

impl PartialEq for Candidate {
    /// Equality is defined through [`Ord::cmp`] so that `==`, `partial_cmp`
    /// and `cmp` can never disagree (a derived `PartialEq` would report
    /// `NaN != NaN` while `cmp` says `Equal`).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let by_dist = match self.dist.partial_cmp(&other.dist) {
            Some(ordering) => ordering,
            // `partial_cmp` is `None` only when at least one side is NaN.
            // `false < true`, so the NaN side is ordered last and two NaNs tie.
            None => self.dist.is_nan().cmp(&other.dist.is_nan()),
        };
        by_dist.then_with(|| self.id.cmp(&other.id))
    }
}
148
149impl HnswIndex {
150 pub fn new(dim: usize, params: HnswParams) -> Self {
152 let initial_capacity = params.ef_construction.max(ARENA_INITIAL_CAPACITY);
153 Self {
154 dim,
155 nodes: Vec::new(),
156 entry_point: None,
157 max_layer: 0,
158 rng: Xorshift64::new(42),
159 flat_neighbors: None,
160 arena: RefCell::new(BeamSearchArena::new(initial_capacity)),
161 params,
162 }
163 }
164
165 pub fn with_seed(dim: usize, params: HnswParams, seed: u64) -> Self {
167 let initial_capacity = params.ef_construction.max(ARENA_INITIAL_CAPACITY);
168 Self {
169 dim,
170 nodes: Vec::new(),
171 entry_point: None,
172 max_layer: 0,
173 rng: Xorshift64::new(seed),
174 flat_neighbors: None,
175 arena: RefCell::new(BeamSearchArena::new(initial_capacity)),
176 params,
177 }
178 }
179
180 pub fn len(&self) -> usize {
181 self.nodes.len()
182 }
183
184 pub fn live_count(&self) -> usize {
185 self.nodes.len() - self.tombstone_count()
186 }
187
188 pub fn tombstone_count(&self) -> usize {
189 self.nodes.iter().filter(|n| n.deleted).count()
190 }
191
192 pub fn tombstone_ratio(&self) -> f64 {
194 if self.nodes.is_empty() {
195 0.0
196 } else {
197 self.tombstone_count() as f64 / self.nodes.len() as f64
198 }
199 }
200
201 pub fn is_empty(&self) -> bool {
202 self.live_count() == 0
203 }
204
205 pub fn delete(&mut self, id: u32) -> bool {
207 if let Some(node) = self.nodes.get_mut(id as usize) {
208 if node.deleted {
209 return false;
210 }
211 node.deleted = true;
212 true
213 } else {
214 false
215 }
216 }
217
218 pub fn is_deleted(&self, id: u32) -> bool {
219 self.nodes.get(id as usize).is_none_or(|n| n.deleted)
220 }
221
222 pub fn undelete(&mut self, id: u32) -> bool {
223 if let Some(node) = self.nodes.get_mut(id as usize)
224 && node.deleted
225 {
226 node.deleted = false;
227 return true;
228 }
229 false
230 }
231
232 pub fn dim(&self) -> usize {
233 self.dim
234 }
235
236 pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
237 self.nodes.get(id as usize).map(|n| n.vector.as_slice())
238 }
239
240 pub fn params(&self) -> &HnswParams {
241 &self.params
242 }
243
244 pub fn entry_point(&self) -> Option<u32> {
245 self.entry_point
246 }
247
248 pub fn max_layer(&self) -> usize {
249 self.max_layer
250 }
251
252 pub fn rng_state(&self) -> u64 {
254 self.rng.0
255 }
256
257 pub fn memory_usage_bytes(&self) -> usize {
259 let vector_bytes = self.nodes.len() * self.dim * std::mem::size_of::<f32>();
260 let neighbor_bytes: usize = self
261 .nodes
262 .iter()
263 .map(|n| {
264 n.neighbors
265 .iter()
266 .map(|layer| layer.len() * 4)
267 .sum::<usize>()
268 })
269 .sum();
270 let node_overhead = self.nodes.len() * std::mem::size_of::<Node>();
271 vector_bytes + neighbor_bytes + node_overhead
272 }
273
274 pub fn export_vectors(&self) -> Vec<Vec<f32>> {
276 self.nodes.iter().map(|n| n.vector.clone()).collect()
277 }
278
279 pub fn export_neighbors(&self) -> Vec<Vec<Vec<u32>>> {
281 self.nodes.iter().map(|n| n.neighbors.clone()).collect()
282 }
283
284 pub(crate) fn random_layer(&mut self) -> usize {
290 let ml = 1.0 / (self.params.m as f64).ln();
291 let r = self.rng.next_f64().max(f64::MIN_POSITIVE);
292 let layer = (-r.ln() * ml).floor() as usize;
293 layer.min(MAX_LAYER_CAP)
294 }
295
296 pub(crate) fn dist_to_node(&self, query: &[f32], node_id: u32) -> f32 {
298 distance(
299 query,
300 &self.nodes[node_id as usize].vector,
301 self.params.metric,
302 )
303 }
304
305 pub(crate) fn max_neighbors(&self, layer: usize) -> usize {
307 if layer == 0 {
308 self.params.m0
309 } else {
310 self.params.m
311 }
312 }
313
314 pub fn compact(&mut self) -> usize {
319 self.compact_with_map().0
320 }
321
322 pub fn compact_with_map(&mut self) -> (usize, Vec<u32>) {
327 let tombstone_count = self.tombstone_count();
328 if tombstone_count == 0 {
329 let identity: Vec<u32> = (0..self.nodes.len() as u32).collect();
330 return (0, identity);
331 }
332 self.ensure_mutable_neighbors();
333
334 let mut id_map: Vec<u32> = Vec::with_capacity(self.nodes.len());
336 let mut new_id = 0u32;
337 for node in &self.nodes {
338 if node.deleted {
339 id_map.push(u32::MAX);
340 } else {
341 id_map.push(new_id);
342 new_id += 1;
343 }
344 }
345
346 let mut new_nodes: Vec<Node> = Vec::with_capacity(new_id as usize);
348 for node in self.nodes.drain(..) {
349 if node.deleted {
350 continue;
351 }
352 let remapped_neighbors: Vec<Vec<u32>> = node
353 .neighbors
354 .into_iter()
355 .map(|layer_neighbors| {
356 layer_neighbors
357 .into_iter()
358 .filter_map(|old_nid| {
359 let new_nid = id_map[old_nid as usize];
360 if new_nid == u32::MAX {
361 None
362 } else {
363 Some(new_nid)
364 }
365 })
366 .collect()
367 })
368 .collect();
369 new_nodes.push(Node {
370 vector: node.vector,
371 neighbors: remapped_neighbors,
372 deleted: false,
373 });
374 }
375
376 self.entry_point = if let Some(old_ep) = self.entry_point {
377 let new_ep = id_map[old_ep as usize];
378 if new_ep == u32::MAX {
379 new_nodes
380 .iter()
381 .enumerate()
382 .max_by_key(|(_, n)| n.neighbors.len())
383 .map(|(i, _)| i as u32)
384 } else {
385 Some(new_ep)
386 }
387 } else {
388 None
389 };
390
391 self.max_layer = new_nodes
392 .iter()
393 .map(|n| n.neighbors.len().saturating_sub(1))
394 .max()
395 .unwrap_or(0);
396
397 self.nodes = new_nodes;
398 (tombstone_count, id_map)
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;

    /// A freshly constructed index holds nothing and has no entry point.
    #[test]
    fn create_empty_index() {
        let index = HnswIndex::new(3, HnswParams::default());
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        assert!(index.entry_point().is_none());
    }

    /// Pins the default parameter values this module relies on.
    #[test]
    fn params_default() {
        let defaults = HnswParams::default();
        assert_eq!(defaults.m, 16);
        assert_eq!(defaults.m0, 32);
        assert_eq!(defaults.ef_construction, 200);
        assert_eq!(defaults.metric, DistanceMetric::Cosine);
    }

    /// The candidate with the smaller distance orders first.
    #[test]
    fn candidate_ordering() {
        let nearer = Candidate { dist: 0.1, id: 1 };
        let farther = Candidate { dist: 0.5, id: 2 };
        assert!(nearer < farther);
    }
}