Skip to main content

nodedb_vector/
hnsw.rs

//! HNSW graph structure — nodes, parameters, checkpoint serialization.
//!
//! Production implementation per Malkov & Yashunin (2018).
//! Adapted for edge devices: no SIMD runtime dispatch and no Roaring bitmaps
//! (those are in the Origin). Pure scalar distance functions.
6
7use crate::distance::{DistanceMetric, distance};
8
9// Re-export shared params from nodedb-types.
10pub use nodedb_types::hnsw::HnswParams;
11
/// One hit returned by a k-nearest-neighbour search.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Internal identifier of the matched node (ids follow insertion order).
    pub id: u32,
    /// Distance between the query vector and the matched node.
    pub distance: f32,
}
20
/// A single vertex of the HNSW graph.
pub(crate) struct Node {
    /// The stored vector at full `f32` precision.
    pub vector: Vec<f32>,
    /// Adjacency lists, one entry per layer this node participates in.
    pub neighbors: Vec<Vec<u32>>,
    /// Soft-deletion marker: `true` once the node has been tombstoned.
    pub deleted: bool,
}
30
31/// Hierarchical Navigable Small World graph index.
32///
33/// Production HNSW per Malkov & Yashunin (2018):
34/// - Multi-layer graph with exponential layer assignment
35/// - FP32 construction for structural integrity
36/// - Heuristic neighbor selection (Algorithm 4)
37/// - Beam search with configurable ef parameter
38pub struct HnswIndex {
39    pub(crate) params: HnswParams,
40    pub(crate) dim: usize,
41    pub(crate) nodes: Vec<Node>,
42    pub(crate) entry_point: Option<u32>,
43    pub(crate) max_layer: usize,
44    pub(crate) rng: Xorshift64,
45}
46
/// Minimal xorshift64 pseudo-random generator used for layer assignment.
pub(crate) struct Xorshift64(pub u64);

impl Xorshift64 {
    /// Build a generator from `seed`.
    ///
    /// A zero seed is replaced by 1: an all-zero state would never change
    /// under the xor/shift steps below.
    pub fn new(seed: u64) -> Self {
        let state = if seed == 0 { 1 } else { seed };
        Self(state)
    }

    /// Advance the state by one xorshift step (shifts 13, 7, 17) and return
    /// the new state divided by `u64::MAX`.
    pub fn next_f64(&mut self) -> f64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x as f64 / u64::MAX as f64
    }
}
62
/// Ordered candidate for priority queues during search and construction.
///
/// Candidates are ordered by `dist` ascending, with ties broken by `id`
/// ascending. A NaN distance compares greater than every non-NaN distance
/// (and equal to any other NaN distance), so the comparison is a genuine
/// total order and a NaN can never masquerade as a near neighbor.
/// Equality is defined through the same comparison so that `==` and `cmp`
/// always agree, as the `Eq`/`Ord` contracts require.
#[derive(Clone, Copy)]
pub(crate) struct Candidate {
    /// Distance from the query to this candidate.
    pub dist: f32,
    /// Internal node id of the candidate.
    pub id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        // Handle NaN explicitly: `partial_cmp` returns `None` for it, and
        // treating that as "equal to everything" makes the order
        // non-transitive once the id tie-break kicks in.
        let by_dist = match (self.dist.is_nan(), other.dist.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` is always `Some` here.
            (false, false) => self
                .dist
                .partial_cmp(&other.dist)
                .unwrap_or(Ordering::Equal),
        };
        by_dist.then(self.id.cmp(&other.id))
    }
}
86
87impl HnswIndex {
88    /// Create a new empty HNSW index.
89    pub fn new(dim: usize, params: HnswParams) -> Self {
90        Self {
91            dim,
92            nodes: Vec::new(),
93            entry_point: None,
94            max_layer: 0,
95            rng: Xorshift64::new(42),
96            params,
97        }
98    }
99
100    /// Create with a specific RNG seed (for deterministic testing).
101    pub fn with_seed(dim: usize, params: HnswParams, seed: u64) -> Self {
102        Self {
103            dim,
104            nodes: Vec::new(),
105            entry_point: None,
106            max_layer: 0,
107            rng: Xorshift64::new(seed),
108            params,
109        }
110    }
111
112    pub fn len(&self) -> usize {
113        self.nodes.len()
114    }
115
116    pub fn live_count(&self) -> usize {
117        self.nodes.len() - self.tombstone_count()
118    }
119
120    pub fn tombstone_count(&self) -> usize {
121        self.nodes.iter().filter(|n| n.deleted).count()
122    }
123
124    pub fn is_empty(&self) -> bool {
125        self.live_count() == 0
126    }
127
128    /// Soft-delete a vector by internal node ID.
129    pub fn delete(&mut self, id: u32) -> bool {
130        if let Some(node) = self.nodes.get_mut(id as usize) {
131            if node.deleted {
132                return false;
133            }
134            node.deleted = true;
135            true
136        } else {
137            false
138        }
139    }
140
141    pub fn is_deleted(&self, id: u32) -> bool {
142        self.nodes.get(id as usize).is_none_or(|n| n.deleted)
143    }
144
145    pub fn undelete(&mut self, id: u32) -> bool {
146        if let Some(node) = self.nodes.get_mut(id as usize)
147            && node.deleted
148        {
149            node.deleted = false;
150            return true;
151        }
152        false
153    }
154
155    pub fn dim(&self) -> usize {
156        self.dim
157    }
158
159    pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
160        self.nodes.get(id as usize).map(|n| n.vector.as_slice())
161    }
162
163    pub fn params(&self) -> &HnswParams {
164        &self.params
165    }
166
167    pub fn entry_point(&self) -> Option<u32> {
168        self.entry_point
169    }
170
171    pub fn max_layer(&self) -> usize {
172        self.max_layer
173    }
174
175    /// Serialize the index to MessagePack bytes for storage.
176    ///
177    /// NOTE: The checklist mandates rkyv for engine blobs, but HNSW's
178    /// recursive Vec<Vec<u32>> structure doesn't support rkyv derive
179    /// (same issue as Value). We use MessagePack here — the cold-start
180    /// budget is met because HNSW rebuild from checkpoint is a single
181    /// deserialization, not per-node allocation. For CSR (flat arrays),
182    /// rkyv zero-copy is used.
183    pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
184        use serde::{Deserialize, Serialize};
185
186        #[derive(Serialize, Deserialize)]
187        struct Snapshot {
188            dim: usize,
189            m: usize,
190            m0: usize,
191            ef_construction: usize,
192            metric: u8,
193            entry_point: Option<u32>,
194            max_layer: usize,
195            rng_state: u64,
196            nodes: Vec<NodeSnap>,
197        }
198
199        #[derive(Serialize, Deserialize)]
200        struct NodeSnap {
201            vector: Vec<f32>,
202            neighbors: Vec<Vec<u32>>,
203            deleted: bool,
204        }
205
206        let snapshot = Snapshot {
207            dim: self.dim,
208            m: self.params.m,
209            m0: self.params.m0,
210            ef_construction: self.params.ef_construction,
211            metric: self.params.metric as u8,
212            entry_point: self.entry_point,
213            max_layer: self.max_layer,
214            rng_state: self.rng.0,
215            nodes: self
216                .nodes
217                .iter()
218                .map(|n| NodeSnap {
219                    vector: n.vector.clone(),
220                    neighbors: n.neighbors.clone(),
221                    deleted: n.deleted,
222                })
223                .collect(),
224        };
225        match rmp_serde::to_vec_named(&snapshot) {
226            Ok(bytes) => bytes,
227            Err(e) => {
228                tracing::error!(error = %e, "HNSW checkpoint serialization failed");
229                Vec::new()
230            }
231        }
232    }
233
234    /// Restore an index from a checkpoint snapshot.
235    pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
236        use serde::{Deserialize, Serialize};
237
238        #[derive(Serialize, Deserialize)]
239        struct Snapshot {
240            dim: usize,
241            m: usize,
242            m0: usize,
243            ef_construction: usize,
244            metric: u8,
245            entry_point: Option<u32>,
246            max_layer: usize,
247            rng_state: u64,
248            nodes: Vec<NodeSnap>,
249        }
250
251        #[derive(Serialize, Deserialize)]
252        struct NodeSnap {
253            vector: Vec<f32>,
254            neighbors: Vec<Vec<u32>>,
255            deleted: bool,
256        }
257
258        let snap: Snapshot = rmp_serde::from_slice(bytes).ok()?;
259        let metric = match snap.metric {
260            0 => DistanceMetric::L2,
261            1 => DistanceMetric::Cosine,
262            2 => DistanceMetric::InnerProduct,
263            _ => DistanceMetric::Cosine,
264        };
265
266        let nodes: Vec<Node> = snap
267            .nodes
268            .into_iter()
269            .map(|n| Node {
270                vector: n.vector,
271                neighbors: n.neighbors,
272                deleted: n.deleted,
273            })
274            .collect();
275
276        Some(Self {
277            dim: snap.dim,
278            params: HnswParams {
279                m: snap.m,
280                m0: snap.m0,
281                ef_construction: snap.ef_construction,
282                metric,
283            },
284            nodes,
285            entry_point: snap.entry_point,
286            max_layer: snap.max_layer,
287            rng: Xorshift64::new(snap.rng_state),
288        })
289    }
290
291    /// Assign a random layer using the exponential distribution.
292    pub(crate) fn random_layer(&mut self) -> usize {
293        let ml = 1.0 / (self.params.m as f64).ln();
294        let r = self.rng.next_f64().max(f64::MIN_POSITIVE);
295        (-r.ln() * ml).floor() as usize
296    }
297
298    /// Compute distance between a query vector and a stored node.
299    pub(crate) fn dist_to_node(&self, query: &[f32], node_id: u32) -> f32 {
300        distance(
301            query,
302            &self.nodes[node_id as usize].vector,
303            self.params.metric,
304        )
305    }
306
307    /// Max neighbors allowed at a given layer.
308    pub(crate) fn max_neighbors(&self, layer: usize) -> usize {
309        if layer == 0 {
310            self.params.m0
311        } else {
312            self.params.m
313        }
314    }
315
316    /// Compact the index by removing all tombstoned nodes.
317    pub fn compact(&mut self) -> usize {
318        let tombstone_count = self.tombstone_count();
319        if tombstone_count == 0 {
320            return 0;
321        }
322
323        let mut id_map: Vec<u32> = Vec::with_capacity(self.nodes.len());
324        let mut new_id = 0u32;
325        for node in &self.nodes {
326            if node.deleted {
327                id_map.push(u32::MAX);
328            } else {
329                id_map.push(new_id);
330                new_id += 1;
331            }
332        }
333
334        let mut new_nodes: Vec<Node> = Vec::with_capacity(new_id as usize);
335        for node in self.nodes.drain(..) {
336            if node.deleted {
337                continue;
338            }
339            let remapped_neighbors: Vec<Vec<u32>> = node
340                .neighbors
341                .into_iter()
342                .map(|layer_neighbors| {
343                    layer_neighbors
344                        .into_iter()
345                        .filter_map(|old_nid| {
346                            let new_nid = id_map[old_nid as usize];
347                            if new_nid == u32::MAX {
348                                None
349                            } else {
350                                Some(new_nid)
351                            }
352                        })
353                        .collect()
354                })
355                .collect();
356            new_nodes.push(Node {
357                vector: node.vector,
358                neighbors: remapped_neighbors,
359                deleted: false,
360            });
361        }
362
363        self.entry_point = if let Some(old_ep) = self.entry_point {
364            let new_ep = id_map[old_ep as usize];
365            if new_ep == u32::MAX {
366                new_nodes
367                    .iter()
368                    .enumerate()
369                    .max_by_key(|(_, n)| n.neighbors.len())
370                    .map(|(i, _)| i as u32)
371            } else {
372                Some(new_ep)
373            }
374        } else {
375            None
376        };
377
378        self.max_layer = new_nodes
379            .iter()
380            .map(|n| n.neighbors.len().saturating_sub(1))
381            .max()
382            .unwrap_or(0);
383
384        self.nodes = new_nodes;
385        tombstone_count
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created index holds no nodes and has no entry point.
    #[test]
    fn create_empty_index() {
        let index = HnswIndex::new(3, HnswParams::default());
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        assert!(index.entry_point().is_none());
    }

    /// Pins the default parameter values re-exported from `nodedb-types`.
    #[test]
    fn params_default() {
        let defaults = HnswParams::default();
        assert_eq!(defaults.m, 16);
        assert_eq!(defaults.m0, 32);
        assert_eq!(defaults.ef_construction, 200);
        assert_eq!(defaults.metric, DistanceMetric::Cosine);
    }

    /// A smaller distance orders before a larger one.
    #[test]
    fn candidate_ordering() {
        let (near, far) = (
            Candidate { dist: 0.1, id: 1 },
            Candidate { dist: 0.5, id: 2 },
        );
        assert!(near < far);
    }
}
416}