Skip to main content

nodedb_vector/hnsw/
build.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! HNSW insert algorithm (Malkov & Yashunin, Algorithm 1).
4
5use crate::error::VectorError;
6use crate::hnsw::graph::{Candidate, HnswIndex, Node};
7use crate::hnsw::search::search_layer;
8
9impl HnswIndex {
10    /// Insert a vector into the index.
11    ///
12    /// 1. Assign a random layer using the exponential distribution
13    /// 2. Greedily descend from the entry point to the new node's layer + 1
14    /// 3. At each layer from the node's layer down to 0, search for nearest
15    ///    neighbors, select via the diversity heuristic, and add bidirectional edges
16    /// 4. Prune over-connected nodes to maintain the M/M0 invariant
17    pub fn insert(&mut self, vector: Vec<f32>) -> Result<(), VectorError> {
18        // Materialize flat neighbor storage on first mutation.
19        self.ensure_mutable_neighbors();
20
21        if vector.len() != self.dim {
22            return Err(VectorError::DimensionMismatch {
23                expected: self.dim,
24                got: vector.len(),
25            });
26        }
27
28        let new_id = self.nodes.len() as u32;
29        let new_layer = self.random_layer();
30
31        let node = Node {
32            vector,
33            neighbors: (0..=new_layer).map(|_| Vec::new()).collect(),
34            deleted: false,
35        };
36        self.nodes.push(node);
37
38        let Some(ep) = self.entry_point else {
39            self.entry_point = Some(new_id);
40            self.max_layer = new_layer;
41            return Ok(());
42        };
43
44        // Clone the query vector to avoid aliasing self.nodes during mutation.
45        let query = self.nodes[new_id as usize].vector.clone();
46
47        let mut current_ep = ep;
48
49        // Phase 1: Greedy descent from top layer to new_layer + 1.
50        if self.max_layer > new_layer {
51            for layer in (new_layer + 1..=self.max_layer).rev() {
52                let results = search_layer(self, &query, current_ep, 1, layer, None, 0);
53                if let Some(nearest) = results.first() {
54                    current_ep = nearest.id;
55                }
56            }
57        }
58
59        // Phase 2: Insert at each layer from min(new_layer, max_layer) down to 0.
60        let insert_top = new_layer.min(self.max_layer);
61        for layer in (0..=insert_top).rev() {
62            let ef = self.params.ef_construction;
63            let candidates = search_layer(self, &query, current_ep, ef, layer, None, 0);
64
65            let m = self.max_neighbors(layer);
66            let selected = select_neighbors_heuristic(self, &candidates, m);
67
68            self.nodes[new_id as usize].neighbors[layer] = selected.iter().map(|c| c.id).collect();
69
70            for neighbor in &selected {
71                let nid = neighbor.id as usize;
72                self.nodes[nid].neighbors[layer].push(new_id);
73
74                if self.nodes[nid].neighbors[layer].len() > m {
75                    let node_vec = self.nodes[nid].vector.clone();
76                    self.prune_neighbors(nid, layer, &node_vec, m);
77                }
78            }
79
80            if let Some(nearest) = candidates.first() {
81                current_ep = nearest.id;
82            }
83        }
84
85        if new_layer > self.max_layer {
86            self.entry_point = Some(new_id);
87            self.max_layer = new_layer;
88        }
89
90        Ok(())
91    }
92
93    /// Prune a node's neighbor list using the diversity heuristic.
94    fn prune_neighbors(&mut self, node_idx: usize, layer: usize, node_vec: &[f32], m: usize) {
95        let neighbor_ids: Vec<u32> = self.nodes[node_idx].neighbors[layer].clone();
96
97        let mut candidates: Vec<Candidate> = neighbor_ids
98            .iter()
99            .map(|&nid| Candidate {
100                id: nid,
101                dist: self.dist_to_node(node_vec, nid),
102            })
103            .collect();
104        candidates.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
105
106        let selected = select_neighbors_heuristic(self, &candidates, m);
107        self.nodes[node_idx].neighbors[layer] = selected.iter().map(|c| c.id).collect();
108    }
109}
110
111/// Heuristic neighbor selection (Malkov & Yashunin, Algorithm 4).
112fn select_neighbors_heuristic(
113    index: &HnswIndex,
114    candidates: &[Candidate],
115    m: usize,
116) -> Vec<Candidate> {
117    // no-governor: hot-path neighbor selection; bounded by m (graph degree ≤ 32)
118    let mut selected: Vec<Candidate> = Vec::with_capacity(m);
119
120    for candidate in candidates {
121        if selected.len() >= m {
122            break;
123        }
124
125        let candidate_vec = &index.nodes[candidate.id as usize].vector;
126        let selected_vecs: Vec<&[f32]> = selected
127            .iter()
128            .map(|s| index.nodes[s.id as usize].vector.as_slice())
129            .collect();
130
131        let is_diverse = crate::batch_distance::is_diverse_batched(
132            candidate_vec,
133            candidate.dist,
134            &selected_vecs,
135            index.params.metric,
136        );
137
138        if is_diverse {
139            selected.push(*candidate);
140        }
141    }
142
143    // Backfill if heuristic was too aggressive.
144    if selected.len() < m {
145        let selected_ids: std::collections::HashSet<u32> = selected.iter().map(|c| c.id).collect();
146        for candidate in candidates {
147            if selected.len() >= m {
148                break;
149            }
150            if !selected_ids.contains(&candidate.id) {
151                selected.push(*candidate);
152            }
153        }
154    }
155
156    selected
157}
158
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::error::VectorError;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Small L2 index over 3-element vectors, built via `with_seed` with a
    /// fixed seed so that runs are reproducible.
    fn make_index() -> HnswIndex {
        HnswIndex::with_seed(
            3,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 32,
                metric: DistanceMetric::L2,
            },
            12345,
        )
    }

    #[test]
    fn insert_single() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 0.0, 0.0]).unwrap();
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.entry_point(), Some(0));
    }

    #[test]
    fn insert_rejects_dimension_mismatch() {
        let mut idx = make_index();
        let err = idx.insert(vec![1.0, 2.0]).unwrap_err();
        assert!(matches!(
            err,
            VectorError::DimensionMismatch {
                expected: 3,
                got: 2
            }
        ));
        // A rejected insert must leave the index empty and still usable.
        assert_eq!(idx.len(), 0);
        assert_eq!(idx.entry_point(), None);

        idx.insert(vec![1.0, 0.0, 0.0]).unwrap();
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.entry_point(), Some(0));
    }

    #[test]
    fn insert_many_maintains_invariants() {
        let mut idx = make_index();
        for i in 0..100 {
            let v = vec![(i as f32) * 0.1, (i as f32) * 0.2, (i as f32) * 0.3];
            idx.insert(v).unwrap();
        }
        assert_eq!(idx.len(), 100);
        assert!(idx.entry_point().is_some());

        for node in &idx.nodes {
            assert!(
                node.neighbors[0].len() <= idx.params.m0,
                "layer 0 neighbor count {} exceeds m0={}",
                node.neighbors[0].len(),
                idx.params.m0
            );
            for (layer, neighbors) in node.neighbors.iter().enumerate().skip(1) {
                assert!(
                    neighbors.len() <= idx.params.m,
                    "layer {layer} neighbor count {} exceeds m={}",
                    neighbors.len(),
                    idx.params.m
                );
            }
        }
    }

    #[test]
    fn all_nodes_reachable_from_entry() {
        let mut idx = make_index();
        for i in 0..20 {
            idx.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }

        for target in 0..20u32 {
            let query = idx.get_vector(target).unwrap().to_vec();
            let results = idx.search(&query, 1, 32);
            assert!(!results.is_empty(), "search for node {target} returned nothing");
            assert_eq!(results[0].id, target, "node {target} not reachable");
        }
    }

    #[test]
    fn compact_removes_tombstones() {
        let mut idx = make_index();
        for i in 0..20u32 {
            idx.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }

        for i in (0..20u32).step_by(2) {
            assert!(idx.delete(i));
        }
        assert_eq!(idx.compact(), 10);
        assert_eq!(idx.len(), 10);

        for target_old_id in (1..20u32).step_by(2) {
            let query = vec![target_old_id as f32, 0.0, 0.0];
            let results = idx.search(&query, 1, 32);
            assert!(!results.is_empty());
            let found_vec = idx.get_vector(results[0].id).unwrap();
            assert_eq!(found_vec[0], target_old_id as f32);
        }
    }
}