// nodedb_vector/hnsw/graph.rs

1//! HNSW graph structure — nodes, parameters, core index operations.
2//!
3//! Production implementation per Malkov & Yashunin (2018).
4//! FP32 construction for structural integrity; heuristic neighbor selection.
5
6use crate::distance::distance;
7
8// Re-export shared params from nodedb-types.
9pub use nodedb_types::hnsw::HnswParams;
10
/// Hard cap on the layer assigned to any node during insertion.
/// Standard HNSW practice — prevents pathological RNG draws from inflating
/// `max_layer` and slowing every subsequent search.
/// With the default `m`, layers this high are vanishingly unlikely, so the
/// cap only binds on extreme draws or degenerate parameters (e.g. tiny `m`).
pub const MAX_LAYER_CAP: usize = 16;
15
/// Result of a k-NN search.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Internal node identifier (insertion order).
    pub id: u32,
    /// Distance from the query vector (metric per `HnswParams::metric`).
    pub distance: f32,
}
24
/// A node in the HNSW graph.
pub struct Node {
    /// Full-precision vector data.
    pub vector: Vec<f32>,
    /// Neighbors at each layer this node participates in.
    /// Index 0 is the base layer; `neighbors.len()` is the node's layer count.
    /// NOTE(review): may be empty while the index serves reads from
    /// `flat_neighbors` after a checkpoint restore — confirm against restore path.
    pub neighbors: Vec<Vec<u32>>,
    /// Tombstone flag for soft-deletion; cleared storage-wise only by `compact`.
    pub deleted: bool,
}
34
/// Hierarchical Navigable Small World graph index.
///
/// - FP32 construction for structural integrity
/// - Heuristic neighbor selection (Algorithm 4)
/// - Beam search with configurable ef parameter
pub struct HnswIndex {
    /// Construction/search parameters (m, m0, ef_construction, metric, …).
    pub(crate) params: HnswParams,
    /// Dimensionality every stored vector must have.
    pub(crate) dim: usize,
    /// All nodes, live and tombstoned, indexed by internal id (insertion order).
    pub(crate) nodes: Vec<Node>,
    /// Entry node for the top-layer greedy descent; `None` only when empty.
    pub(crate) entry_point: Option<u32>,
    /// Highest layer currently present in the graph.
    pub(crate) max_layer: usize,
    /// PRNG used solely for layer assignment on insert.
    pub(crate) rng: Xorshift64,
    /// Flat neighbor storage for zero-copy access after checkpoint restore.
    /// When present, `neighbors_at()` reads from here instead of per-node Vecs.
    /// Cleared on first mutation (insert/delete).
    pub(crate) flat_neighbors: Option<crate::hnsw::flat_neighbors::FlatNeighborStore>,
}
52
53impl HnswIndex {
54    /// Get neighbors of a node at a specific layer.
55    /// Uses flat zero-copy storage if available, otherwise per-node Vec.
56    #[inline]
57    pub(crate) fn neighbors_at(&self, node_id: u32, layer: usize) -> &[u32] {
58        if let Some(ref flat) = self.flat_neighbors {
59            return flat.neighbors_at(node_id, layer);
60        }
61        let node = &self.nodes[node_id as usize];
62        if layer < node.neighbors.len() {
63            &node.neighbors[layer]
64        } else {
65            &[]
66        }
67    }
68
69    /// Number of layers a node participates in.
70    #[inline]
71    pub(crate) fn node_num_layers(&self, node_id: u32) -> usize {
72        if let Some(ref flat) = self.flat_neighbors {
73            return flat.num_layers(node_id);
74        }
75        self.nodes[node_id as usize].neighbors.len()
76    }
77
78    /// Ensure mutable per-node neighbor Vecs are available.
79    /// Materializes flat storage back to per-node Vecs if needed.
80    pub(crate) fn ensure_mutable_neighbors(&mut self) {
81        if let Some(flat) = self.flat_neighbors.take() {
82            let nested = flat.to_nested(self.nodes.len());
83            for (i, layers) in nested.into_iter().enumerate() {
84                self.nodes[i].neighbors = layers;
85            }
86        }
87    }
88}
89
/// Lightweight xorshift64 PRNG for layer assignment.
///
/// State is public so snapshots can capture and restore it exactly.
pub struct Xorshift64(pub u64);

impl Xorshift64 {
    /// Build from a seed; zero is mapped to 1 because an all-zero state
    /// would make xorshift emit zeros forever.
    pub fn new(seed: u64) -> Self {
        Self(seed.max(1))
    }

    /// Advance one xorshift64 step (shifts 13, 7, 17) and map the state
    /// onto (0, 1]; the state is never zero, so the result is never 0.0.
    pub fn next_f64(&mut self) -> f64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x as f64 / u64::MAX as f64
    }
}
105
/// Ordered candidate for priority queues during search and construction.
///
/// Ordering is ascending by distance, with `id` as a deterministic
/// tie-breaker so heap/sort results are stable across runs.
#[derive(Debug, Clone, Copy)]
pub struct Candidate {
    /// Distance from the query vector.
    pub dist: f32,
    /// Internal node identifier.
    pub id: u32,
}

impl PartialEq for Candidate {
    /// Kept consistent with `Ord`: equal iff `cmp` says `Equal`.
    /// (The derived impl compared `f32` directly, so `NaN != NaN`
    /// contradicted the manual `Eq` below.)
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

/// Sound because `f32::total_cmp` defines a total order that is
/// reflexive even for NaN payloads.
impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    /// Total order: distance first via IEEE-754 `total_cmp`, then `id`.
    ///
    /// The previous `partial_cmp(..).unwrap_or(Equal)` collapsed NaN
    /// distances to "equal", making the order intransitive (e.g.
    /// (0.5, id 10) > (NaN, id 5) > (1.0, id 1) yet (0.5, id 10) < (1.0, id 1)),
    /// which is undefined behavior territory for `sort`/`BinaryHeap`.
    /// Under `total_cmp`, NaN sorts after +inf, i.e. as "worst candidate".
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then(self.id.cmp(&other.id))
    }
}
129
impl HnswIndex {
    /// Create a new empty HNSW index.
    pub fn new(dim: usize, params: HnswParams) -> Self {
        Self {
            dim,
            nodes: Vec::new(),
            entry_point: None,
            max_layer: 0,
            // Fixed default seed: construction is deterministic unless the
            // caller opts into `with_seed`.
            rng: Xorshift64::new(42),
            flat_neighbors: None,
            params,
        }
    }

    /// Create with a specific RNG seed (for deterministic testing).
    pub fn with_seed(dim: usize, params: HnswParams, seed: u64) -> Self {
        Self {
            dim,
            nodes: Vec::new(),
            entry_point: None,
            max_layer: 0,
            rng: Xorshift64::new(seed),
            flat_neighbors: None,
            params,
        }
    }

    /// Total node count, including tombstoned (soft-deleted) nodes.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Number of live (non-deleted) nodes.
    pub fn live_count(&self) -> usize {
        self.nodes.len() - self.tombstone_count()
    }

    /// Number of tombstoned nodes. O(n): scans every node on each call.
    pub fn tombstone_count(&self) -> usize {
        self.nodes.iter().filter(|n| n.deleted).count()
    }

    /// Tombstone ratio: fraction of nodes that are deleted.
    /// Returns 0.0 for an empty index (avoids 0/0).
    pub fn tombstone_ratio(&self) -> f64 {
        if self.nodes.is_empty() {
            0.0
        } else {
            self.tombstone_count() as f64 / self.nodes.len() as f64
        }
    }

    /// True when no live nodes remain; tombstoned nodes do not count.
    pub fn is_empty(&self) -> bool {
        self.live_count() == 0
    }

    /// Soft-delete a vector by internal node ID.
    ///
    /// Returns `true` if the node existed and was live, `false` for an
    /// unknown id or an already-deleted node. The node's vector and edges
    /// remain until `compact`; `entry_point` is deliberately left alone
    /// even if it points at the deleted node.
    pub fn delete(&mut self, id: u32) -> bool {
        if let Some(node) = self.nodes.get_mut(id as usize) {
            if node.deleted {
                return false;
            }
            node.deleted = true;
            true
        } else {
            false
        }
    }

    /// True for tombstoned nodes — and for out-of-range ids.
    pub fn is_deleted(&self, id: u32) -> bool {
        self.nodes.get(id as usize).is_none_or(|n| n.deleted)
    }

    /// Restore a tombstoned node. Returns `true` only if the node existed
    /// and was currently deleted.
    pub fn undelete(&mut self, id: u32) -> bool {
        if let Some(node) = self.nodes.get_mut(id as usize)
            && node.deleted
        {
            node.deleted = false;
            return true;
        }
        false
    }

    /// Dimensionality every stored vector must have.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Borrow a stored vector by internal id; `None` for out-of-range ids.
    /// Tombstoned nodes still return their vector.
    pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
        self.nodes.get(id as usize).map(|n| n.vector.as_slice())
    }

    /// Index parameters (m, m0, ef_construction, metric, …).
    pub fn params(&self) -> &HnswParams {
        &self.params
    }

    /// Entry node for the top-layer descent; `None` only when empty.
    pub fn entry_point(&self) -> Option<u32> {
        self.entry_point
    }

    /// Highest layer currently present in the graph.
    pub fn max_layer(&self) -> usize {
        self.max_layer
    }

    /// Current RNG state (for snapshot reproducibility).
    pub fn rng_state(&self) -> u64 {
        self.rng.0
    }

    /// Approximate memory usage in bytes (vector data + neighbor lists).
    ///
    /// NOTE(review): neighbor bytes are computed from the per-node Vecs only;
    /// while `flat_neighbors` is active those Vecs may be empty, so this
    /// likely undercounts after a checkpoint restore — confirm whether the
    /// flat store should be included here.
    pub fn memory_usage_bytes(&self) -> usize {
        let vector_bytes = self.nodes.len() * self.dim * std::mem::size_of::<f32>();
        let neighbor_bytes: usize = self
            .nodes
            .iter()
            .map(|n| {
                n.neighbors
                    .iter()
                    // 4 bytes per u32 neighbor id.
                    .map(|layer| layer.len() * 4)
                    .sum::<usize>()
            })
            .sum();
        let node_overhead = self.nodes.len() * std::mem::size_of::<Node>();
        vector_bytes + neighbor_bytes + node_overhead
    }

    /// Export all vectors for snapshot transfer (includes tombstoned nodes,
    /// preserving internal-id order).
    pub fn export_vectors(&self) -> Vec<Vec<f32>> {
        self.nodes.iter().map(|n| n.vector.clone()).collect()
    }

    /// Export all neighbor lists for snapshot transfer.
    ///
    /// NOTE(review): reads the per-node Vecs only; while `flat_neighbors` is
    /// active these may be empty — callers may need `ensure_mutable_neighbors`
    /// first. Verify against the snapshot path.
    pub fn export_neighbors(&self) -> Vec<Vec<Vec<u32>>> {
        self.nodes.iter().map(|n| n.neighbors.clone()).collect()
    }

    /// Assign a random layer using the exponential distribution.
    ///
    /// Capped at `MAX_LAYER_CAP` to prevent pathological RNG draws from
    /// promoting the index's `max_layer` to hundreds or thousands, which
    /// would make every search's Phase-1 greedy descent O(max_layer).
    pub(crate) fn random_layer(&mut self) -> usize {
        // ml = 1/ln(m), the level-normalization factor from the HNSW paper.
        // For m == 1 this is infinite, but the cap below still bounds the result.
        let ml = 1.0 / (self.params.m as f64).ln();
        // Clamp away from 0 so ln(r) stays finite.
        let r = self.rng.next_f64().max(f64::MIN_POSITIVE);
        let layer = (-r.ln() * ml).floor() as usize;
        layer.min(MAX_LAYER_CAP)
    }

    /// Compute distance between a query vector and a stored node.
    /// Panics if `node_id` is out of range (internal callers guarantee it).
    pub(crate) fn dist_to_node(&self, query: &[f32], node_id: u32) -> f32 {
        distance(
            query,
            &self.nodes[node_id as usize].vector,
            self.params.metric,
        )
    }

    /// Max neighbors allowed at a given layer: `m0` at the base layer,
    /// `m` everywhere above it.
    pub(crate) fn max_neighbors(&self, layer: usize) -> usize {
        if layer == 0 {
            self.params.m0
        } else {
            self.params.m
        }
    }

    /// Compact the index by removing all tombstoned nodes.
    ///
    /// Returns the number of removed nodes. See `compact_with_map` for the
    /// variant that also returns the old→new id remapping.
    pub fn compact(&mut self) -> usize {
        self.compact_with_map().0
    }

    /// Compact and return both the removed count and the old→new id map.
    ///
    /// `id_map[old_local]` = new_local, or `u32::MAX` if the node was
    /// tombstoned (removed).
    pub fn compact_with_map(&mut self) -> (usize, Vec<u32>) {
        let tombstone_count = self.tombstone_count();
        // Fast path: nothing to remove — return an identity map untouched.
        if tombstone_count == 0 {
            let identity: Vec<u32> = (0..self.nodes.len() as u32).collect();
            return (0, identity);
        }
        // Compaction rewrites per-node Vecs, so flat storage must be
        // materialized back first.
        self.ensure_mutable_neighbors();

        // Pass 1: assign new dense ids to survivors; u32::MAX marks removal.
        let mut id_map: Vec<u32> = Vec::with_capacity(self.nodes.len());
        let mut new_id = 0u32;
        for node in &self.nodes {
            if node.deleted {
                id_map.push(u32::MAX);
            } else {
                id_map.push(new_id);
                new_id += 1;
            }
        }

        // Pass 2: move survivors over, remapping neighbor ids and dropping
        // edges that pointed at removed nodes.
        let mut new_nodes: Vec<Node> = Vec::with_capacity(new_id as usize);
        for node in self.nodes.drain(..) {
            if node.deleted {
                continue;
            }
            let remapped_neighbors: Vec<Vec<u32>> = node
                .neighbors
                .into_iter()
                .map(|layer_neighbors| {
                    layer_neighbors
                        .into_iter()
                        .filter_map(|old_nid| {
                            let new_nid = id_map[old_nid as usize];
                            if new_nid == u32::MAX {
                                None
                            } else {
                                Some(new_nid)
                            }
                        })
                        .collect()
                })
                .collect();
            new_nodes.push(Node {
                vector: node.vector,
                neighbors: remapped_neighbors,
                deleted: false,
            });
        }

        // Remap the entry point; if it was removed, promote the surviving
        // node participating in the most layers (best top-level coverage).
        self.entry_point = if let Some(old_ep) = self.entry_point {
            let new_ep = id_map[old_ep as usize];
            if new_ep == u32::MAX {
                new_nodes
                    .iter()
                    .enumerate()
                    .max_by_key(|(_, n)| n.neighbors.len())
                    .map(|(i, _)| i as u32)
            } else {
                Some(new_ep)
            }
        } else {
            None
        };

        // Recompute max_layer: a node with k layer lists reaches layer k-1.
        self.max_layer = new_nodes
            .iter()
            .map(|n| n.neighbors.len().saturating_sub(1))
            .max()
            .unwrap_or(0);

        self.nodes = new_nodes;
        (tombstone_count, id_map)
    }
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;

    /// A freshly-created index has no nodes, reports empty, and no entry point.
    #[test]
    fn create_empty_index() {
        let idx = HnswIndex::new(3, HnswParams::default());
        assert_eq!(idx.len(), 0);
        assert!(idx.is_empty());
        assert!(idx.entry_point().is_none());
    }

    /// Pin the defaults re-exported from nodedb-types so drift there is
    /// caught in this crate's tests.
    #[test]
    fn params_default() {
        let p = HnswParams::default();
        assert_eq!(p.m, 16);
        assert_eq!(p.m0, 32);
        assert_eq!(p.ef_construction, 200);
        assert_eq!(p.metric, DistanceMetric::Cosine);
    }

    /// Candidates order ascending by distance.
    #[test]
    fn candidate_ordering() {
        let a = Candidate { dist: 0.1, id: 1 };
        let b = Candidate { dist: 0.5, id: 2 };
        assert!(a < b);
    }
}
405}