Skip to main content

nodedb_vector/
build.rs

1//! HNSW insert algorithm (Malkov & Yashunin, Algorithm 1).
2
3use crate::distance::distance;
4use crate::error::VectorError;
5use crate::hnsw::{Candidate, HnswIndex, Node};
6use crate::search::search_layer;
7
8impl HnswIndex {
9    /// Insert a vector into the index.
10    ///
11    /// 1. Assign a random layer using the exponential distribution
12    /// 2. Greedily descend from the entry point to the new node's layer + 1
13    /// 3. At each layer from the node's layer down to 0, search for nearest
14    ///    neighbors, select via the diversity heuristic, and add bidirectional edges
15    /// 4. Prune over-connected nodes to maintain the M/M0 invariant
16    pub fn insert(&mut self, vector: Vec<f32>) -> Result<(), VectorError> {
17        if vector.len() != self.dim {
18            return Err(VectorError::DimensionMismatch {
19                expected: self.dim,
20                got: vector.len(),
21            });
22        }
23
24        let new_id = self.nodes.len() as u32;
25        let new_layer = self.random_layer();
26
27        let node = Node {
28            vector,
29            neighbors: (0..=new_layer).map(|_| Vec::new()).collect(),
30            deleted: false,
31        };
32        self.nodes.push(node);
33
34        let Some(ep) = self.entry_point else {
35            self.entry_point = Some(new_id);
36            self.max_layer = new_layer;
37            return Ok(());
38        };
39
40        // Clone the query vector to avoid aliasing self.nodes during mutation.
41        let query = self.nodes[new_id as usize].vector.clone();
42
43        let mut current_ep = ep;
44
45        // Phase 1: Greedy descent from top layer to new_layer + 1.
46        if self.max_layer > new_layer {
47            for layer in (new_layer + 1..=self.max_layer).rev() {
48                let results = search_layer(self, &query, current_ep, 1, layer, None);
49                if let Some(nearest) = results.first() {
50                    current_ep = nearest.id;
51                }
52            }
53        }
54
55        // Phase 2: Insert at each layer from min(new_layer, max_layer) down to 0.
56        let insert_top = new_layer.min(self.max_layer);
57        for layer in (0..=insert_top).rev() {
58            let ef = self.params.ef_construction;
59            let candidates = search_layer(self, &query, current_ep, ef, layer, None);
60
61            let m = self.max_neighbors(layer);
62            let selected = select_neighbors_heuristic(self, &candidates, m);
63
64            self.nodes[new_id as usize].neighbors[layer] = selected.iter().map(|c| c.id).collect();
65
66            for neighbor in &selected {
67                let nid = neighbor.id as usize;
68                self.nodes[nid].neighbors[layer].push(new_id);
69
70                if self.nodes[nid].neighbors[layer].len() > m {
71                    let node_vec = self.nodes[nid].vector.clone();
72                    self.prune_neighbors(nid, layer, &node_vec, m);
73                }
74            }
75
76            if let Some(nearest) = candidates.first() {
77                current_ep = nearest.id;
78            }
79        }
80
81        if new_layer > self.max_layer {
82            self.entry_point = Some(new_id);
83            self.max_layer = new_layer;
84        }
85
86        Ok(())
87    }
88
89    /// Prune a node's neighbor list using the diversity heuristic.
90    fn prune_neighbors(&mut self, node_idx: usize, layer: usize, node_vec: &[f32], m: usize) {
91        let neighbor_ids: Vec<u32> = self.nodes[node_idx].neighbors[layer].clone();
92
93        let mut candidates: Vec<Candidate> = neighbor_ids
94            .iter()
95            .map(|&nid| Candidate {
96                id: nid,
97                dist: self.dist_to_node(node_vec, nid),
98            })
99            .collect();
100        candidates.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
101
102        let selected = select_neighbors_heuristic(self, &candidates, m);
103        self.nodes[node_idx].neighbors[layer] = selected.iter().map(|c| c.id).collect();
104    }
105}
106
107/// Heuristic neighbor selection (Malkov & Yashunin, Algorithm 4).
108///
109/// Selects up to `m` neighbors that are both close AND diverse.
110/// A candidate is kept only if it is closer to query than to every
111/// already-selected neighbor — producing better long-range connectivity.
112fn select_neighbors_heuristic(
113    index: &HnswIndex,
114    candidates: &[Candidate],
115    m: usize,
116) -> Vec<Candidate> {
117    let mut selected: Vec<Candidate> = Vec::with_capacity(m);
118
119    for candidate in candidates {
120        if selected.len() >= m {
121            break;
122        }
123
124        let dist_to_query = candidate.dist;
125        let is_diverse = selected.iter().all(|s| {
126            let dist_to_selected = distance(
127                &index.nodes[candidate.id as usize].vector,
128                &index.nodes[s.id as usize].vector,
129                index.params.metric,
130            );
131            dist_to_query <= dist_to_selected
132        });
133
134        if is_diverse {
135            selected.push(*candidate);
136        }
137    }
138
139    // Backfill if heuristic was too aggressive.
140    if selected.len() < m {
141        let selected_ids: std::collections::HashSet<u32> = selected.iter().map(|c| c.id).collect();
142        for candidate in candidates {
143            if selected.len() >= m {
144                break;
145            }
146            if !selected_ids.contains(&candidate.id) {
147                selected.push(*candidate);
148            }
149        }
150    }
151
152    selected
153}
154
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Deterministic 3-D L2 index with a small fan-out (m = 4, m0 = 8).
    fn make_index() -> HnswIndex {
        let params = HnswParams {
            m: 4,
            m0: 8,
            ef_construction: 32,
            metric: DistanceMetric::L2,
        };
        HnswIndex::with_seed(3, params, 12345)
    }

    /// Point number `step` along the ray through (0.1, 0.2, 0.3).
    fn ray_point(step: u16) -> Vec<f32> {
        let t = f32::from(step);
        vec![t * 0.1, t * 0.2, t * 0.3]
    }

    /// Point on the x axis at coordinate `x`.
    fn axis_point(x: u32) -> Vec<f32> {
        vec![x as f32, 0.0, 0.0]
    }

    #[test]
    fn insert_single() {
        let mut idx = make_index();
        idx.insert(axis_point(1)).unwrap();

        assert_eq!(idx.len(), 1);
        assert_eq!(idx.entry_point(), Some(0));
    }

    #[test]
    fn insert_many_maintains_invariants() {
        let mut idx = make_index();
        for step in 0..100u16 {
            idx.insert(ray_point(step)).unwrap();
        }
        assert_eq!(idx.len(), 100);
        assert!(idx.entry_point().is_some());

        let base_cap = idx.params.m0;
        let upper_cap = idx.params.m;

        // Layer 0 may hold up to m0 links per node.
        assert!(idx
            .nodes
            .iter()
            .all(|node| node.neighbors[0].len() <= base_cap));

        // Every layer above 0 is capped at m.
        for node in &idx.nodes {
            for (layer, neighbors) in node.neighbors.iter().enumerate().skip(1) {
                assert!(
                    neighbors.len() <= upper_cap,
                    "layer {layer} neighbor count {} exceeds m={}",
                    neighbors.len(),
                    upper_cap
                );
            }
        }
    }

    #[test]
    fn all_nodes_reachable_from_entry() {
        let mut idx = make_index();
        for x in 0..20u32 {
            idx.insert(axis_point(x)).unwrap();
        }

        // Searching for each stored vector must return that very node first.
        for target in 0..20u32 {
            let query = idx.get_vector(target).unwrap().to_vec();
            let hits = idx.search(&query, 1, 32);
            assert_eq!(hits[0].id, target, "node {target} not reachable");
        }
    }

    #[test]
    fn compact_removes_tombstones() {
        let mut idx = make_index();
        for x in 0..20u32 {
            idx.insert(axis_point(x)).unwrap();
        }

        // Tombstone every even id, then compact them away.
        for even in (0..20u32).step_by(2) {
            assert!(idx.delete(even));
        }
        assert_eq!(idx.compact(), 10);
        assert_eq!(idx.len(), 10);

        // Ids may change after compaction, so identify survivors by their x coordinate.
        for odd in (1..20u32).step_by(2) {
            let query = axis_point(odd);
            let hits = idx.search(&query, 1, 32);
            assert!(!hits.is_empty());
            let found = idx.get_vector(hits[0].id).unwrap();
            assert_eq!(found[0], odd as f32);
        }
    }

    #[test]
    fn checkpoint_roundtrip() {
        let mut idx = make_index();
        for step in 0..50u16 {
            idx.insert(ray_point(step)).unwrap();
        }

        let bytes = idx.checkpoint_to_bytes();
        assert!(!bytes.is_empty());

        let restored = HnswIndex::from_checkpoint(&bytes).unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
        assert_eq!(restored.max_layer(), idx.max_layer());

        // The restored graph must answer queries exactly like the original.
        let query = vec![1.0, 2.0, 3.0];
        let expected = idx.search(&query, 5, 32);
        let actual = restored.search(&query, 5, 32);
        assert_eq!(expected.len(), actual.len());
        for (want, got) in expected.iter().zip(&actual) {
            assert_eq!(want.id, got.id);
            assert!((want.distance - got.distance).abs() < 1e-6);
        }
    }
}