1use crate::distance::{DistanceMetric, distance};
8
9pub use nodedb_types::hnsw::HnswParams;
11
/// A node id paired with a distance value.
///
/// NOTE(review): presumably one hit produced by a nearest-neighbour query,
/// with `distance` measured under the index's configured metric; the code
/// that builds these values lives outside this file section.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Id of the node this result refers to.
    pub id: u32,
    /// Distance associated with `id`.
    pub distance: f32,
}
20
/// One stored vector together with its graph links.
pub(crate) struct Node {
    /// The vector held by this node.
    pub vector: Vec<f32>,
    /// Neighbour ids, one list per layer; the number of lists minus one is
    /// treated as the node's top layer when the index is compacted.
    pub neighbors: Vec<Vec<u32>>,
    /// Tombstone flag: `true` once the node has been soft-deleted.
    pub deleted: bool,
}
30
31pub struct HnswIndex {
39 pub(crate) params: HnswParams,
40 pub(crate) dim: usize,
41 pub(crate) nodes: Vec<Node>,
42 pub(crate) entry_point: Option<u32>,
43 pub(crate) max_layer: usize,
44 pub(crate) rng: Xorshift64,
45}
46
/// Minimal xorshift64 pseudo-random generator; the tuple field is the raw state.
pub(crate) struct Xorshift64(pub u64);

impl Xorshift64 {
    /// Builds a generator from `seed`.
    ///
    /// A zero seed is replaced by 1: zero is a fixed point of the xorshift
    /// update, so a zero state would never change.
    pub fn new(seed: u64) -> Self {
        let state = if seed == 0 { 1 } else { seed };
        Xorshift64(state)
    }

    /// Advances the state by one xorshift step (shifts 13, 7, 17) and returns
    /// the new state divided by `u64::MAX` as an `f64`.
    pub fn next_f64(&mut self) -> f64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x as f64 / u64::MAX as f64
    }
}
62
/// A `(distance, id)` pair with a total order, so it can be stored in ordered
/// collections such as `BinaryHeap` or sorted directly.
///
/// Candidates are ordered by ascending distance, with ties broken by
/// ascending id. A NaN distance sorts after every numeric distance, and two
/// NaN distances tie with each other, so the order stays total and transitive
/// even when a distance computation produced NaN.
#[derive(Clone, Copy)]
pub(crate) struct Candidate {
    /// Distance used as the primary sort key.
    pub dist: f32,
    /// Node id, used as the tie-breaker.
    pub id: u32,
}

// Equality is defined through `cmp` so that `PartialEq`, `Eq` and `Ord` always
// agree; a derived `PartialEq` would report a NaN candidate as unequal to
// itself, which `Eq` forbids.
impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let by_distance = match self.dist.partial_cmp(&other.dist) {
            Some(ordering) => ordering,
            // `partial_cmp` only fails when at least one side is NaN.
            // Comparing the `is_nan` flags (false < true) ranks NaN after
            // every number and makes two NaNs tie.
            None => self.dist.is_nan().cmp(&other.dist.is_nan()),
        };
        by_distance.then_with(|| self.id.cmp(&other.id))
    }
}
86
87impl HnswIndex {
88 pub fn new(dim: usize, params: HnswParams) -> Self {
90 Self {
91 dim,
92 nodes: Vec::new(),
93 entry_point: None,
94 max_layer: 0,
95 rng: Xorshift64::new(42),
96 params,
97 }
98 }
99
100 pub fn with_seed(dim: usize, params: HnswParams, seed: u64) -> Self {
102 Self {
103 dim,
104 nodes: Vec::new(),
105 entry_point: None,
106 max_layer: 0,
107 rng: Xorshift64::new(seed),
108 params,
109 }
110 }
111
112 pub fn len(&self) -> usize {
113 self.nodes.len()
114 }
115
116 pub fn live_count(&self) -> usize {
117 self.nodes.len() - self.tombstone_count()
118 }
119
120 pub fn tombstone_count(&self) -> usize {
121 self.nodes.iter().filter(|n| n.deleted).count()
122 }
123
124 pub fn is_empty(&self) -> bool {
125 self.live_count() == 0
126 }
127
128 pub fn delete(&mut self, id: u32) -> bool {
130 if let Some(node) = self.nodes.get_mut(id as usize) {
131 if node.deleted {
132 return false;
133 }
134 node.deleted = true;
135 true
136 } else {
137 false
138 }
139 }
140
141 pub fn is_deleted(&self, id: u32) -> bool {
142 self.nodes.get(id as usize).is_none_or(|n| n.deleted)
143 }
144
145 pub fn undelete(&mut self, id: u32) -> bool {
146 if let Some(node) = self.nodes.get_mut(id as usize)
147 && node.deleted
148 {
149 node.deleted = false;
150 return true;
151 }
152 false
153 }
154
155 pub fn dim(&self) -> usize {
156 self.dim
157 }
158
159 pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
160 self.nodes.get(id as usize).map(|n| n.vector.as_slice())
161 }
162
163 pub fn params(&self) -> &HnswParams {
164 &self.params
165 }
166
167 pub fn entry_point(&self) -> Option<u32> {
168 self.entry_point
169 }
170
171 pub fn max_layer(&self) -> usize {
172 self.max_layer
173 }
174
175 pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
184 use serde::{Deserialize, Serialize};
185
186 #[derive(Serialize, Deserialize)]
187 struct Snapshot {
188 dim: usize,
189 m: usize,
190 m0: usize,
191 ef_construction: usize,
192 metric: u8,
193 entry_point: Option<u32>,
194 max_layer: usize,
195 rng_state: u64,
196 nodes: Vec<NodeSnap>,
197 }
198
199 #[derive(Serialize, Deserialize)]
200 struct NodeSnap {
201 vector: Vec<f32>,
202 neighbors: Vec<Vec<u32>>,
203 deleted: bool,
204 }
205
206 let snapshot = Snapshot {
207 dim: self.dim,
208 m: self.params.m,
209 m0: self.params.m0,
210 ef_construction: self.params.ef_construction,
211 metric: self.params.metric as u8,
212 entry_point: self.entry_point,
213 max_layer: self.max_layer,
214 rng_state: self.rng.0,
215 nodes: self
216 .nodes
217 .iter()
218 .map(|n| NodeSnap {
219 vector: n.vector.clone(),
220 neighbors: n.neighbors.clone(),
221 deleted: n.deleted,
222 })
223 .collect(),
224 };
225 match rmp_serde::to_vec_named(&snapshot) {
226 Ok(bytes) => bytes,
227 Err(e) => {
228 tracing::error!(error = %e, "HNSW checkpoint serialization failed");
229 Vec::new()
230 }
231 }
232 }
233
234 pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
236 use serde::{Deserialize, Serialize};
237
238 #[derive(Serialize, Deserialize)]
239 struct Snapshot {
240 dim: usize,
241 m: usize,
242 m0: usize,
243 ef_construction: usize,
244 metric: u8,
245 entry_point: Option<u32>,
246 max_layer: usize,
247 rng_state: u64,
248 nodes: Vec<NodeSnap>,
249 }
250
251 #[derive(Serialize, Deserialize)]
252 struct NodeSnap {
253 vector: Vec<f32>,
254 neighbors: Vec<Vec<u32>>,
255 deleted: bool,
256 }
257
258 let snap: Snapshot = rmp_serde::from_slice(bytes).ok()?;
259 let metric = match snap.metric {
260 0 => DistanceMetric::L2,
261 1 => DistanceMetric::Cosine,
262 2 => DistanceMetric::InnerProduct,
263 _ => DistanceMetric::Cosine,
264 };
265
266 let nodes: Vec<Node> = snap
267 .nodes
268 .into_iter()
269 .map(|n| Node {
270 vector: n.vector,
271 neighbors: n.neighbors,
272 deleted: n.deleted,
273 })
274 .collect();
275
276 Some(Self {
277 dim: snap.dim,
278 params: HnswParams {
279 m: snap.m,
280 m0: snap.m0,
281 ef_construction: snap.ef_construction,
282 metric,
283 },
284 nodes,
285 entry_point: snap.entry_point,
286 max_layer: snap.max_layer,
287 rng: Xorshift64::new(snap.rng_state),
288 })
289 }
290
291 pub(crate) fn random_layer(&mut self) -> usize {
293 let ml = 1.0 / (self.params.m as f64).ln();
294 let r = self.rng.next_f64().max(f64::MIN_POSITIVE);
295 (-r.ln() * ml).floor() as usize
296 }
297
298 pub(crate) fn dist_to_node(&self, query: &[f32], node_id: u32) -> f32 {
300 distance(
301 query,
302 &self.nodes[node_id as usize].vector,
303 self.params.metric,
304 )
305 }
306
307 pub(crate) fn max_neighbors(&self, layer: usize) -> usize {
309 if layer == 0 {
310 self.params.m0
311 } else {
312 self.params.m
313 }
314 }
315
316 pub fn compact(&mut self) -> usize {
318 let tombstone_count = self.tombstone_count();
319 if tombstone_count == 0 {
320 return 0;
321 }
322
323 let mut id_map: Vec<u32> = Vec::with_capacity(self.nodes.len());
324 let mut new_id = 0u32;
325 for node in &self.nodes {
326 if node.deleted {
327 id_map.push(u32::MAX);
328 } else {
329 id_map.push(new_id);
330 new_id += 1;
331 }
332 }
333
334 let mut new_nodes: Vec<Node> = Vec::with_capacity(new_id as usize);
335 for node in self.nodes.drain(..) {
336 if node.deleted {
337 continue;
338 }
339 let remapped_neighbors: Vec<Vec<u32>> = node
340 .neighbors
341 .into_iter()
342 .map(|layer_neighbors| {
343 layer_neighbors
344 .into_iter()
345 .filter_map(|old_nid| {
346 let new_nid = id_map[old_nid as usize];
347 if new_nid == u32::MAX {
348 None
349 } else {
350 Some(new_nid)
351 }
352 })
353 .collect()
354 })
355 .collect();
356 new_nodes.push(Node {
357 vector: node.vector,
358 neighbors: remapped_neighbors,
359 deleted: false,
360 });
361 }
362
363 self.entry_point = if let Some(old_ep) = self.entry_point {
364 let new_ep = id_map[old_ep as usize];
365 if new_ep == u32::MAX {
366 new_nodes
367 .iter()
368 .enumerate()
369 .max_by_key(|(_, n)| n.neighbors.len())
370 .map(|(i, _)| i as u32)
371 } else {
372 Some(new_ep)
373 }
374 } else {
375 None
376 };
377
378 self.max_layer = new_nodes
379 .iter()
380 .map(|n| n.neighbors.len().saturating_sub(1))
381 .max()
382 .unwrap_or(0);
383
384 self.nodes = new_nodes;
385 tombstone_count
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built index has no nodes, no live nodes and no entry point.
    #[test]
    fn create_empty_index() {
        let index = HnswIndex::new(3, HnswParams::default());
        assert!(index.entry_point().is_none());
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
    }

    /// Pins the default parameter values this module relies on.
    #[test]
    fn params_default() {
        let defaults = HnswParams::default();
        assert_eq!(defaults.metric, DistanceMetric::Cosine);
        assert_eq!(defaults.ef_construction, 200);
        assert_eq!(defaults.m0, 32);
        assert_eq!(defaults.m, 16);
    }

    /// A smaller distance orders before a larger one.
    #[test]
    fn candidate_ordering() {
        let nearer = Candidate { dist: 0.1, id: 1 };
        let farther = Candidate { dist: 0.5, id: 2 };
        assert!(nearer < farther);
    }
}