Skip to main content

nodedb_vector/codec_index/
build.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Insertion algorithm for `HnswCodecIndex<C>`.
4//!
5//! Implements the standard HNSW insert (Malkov & Yashunin 2018, Algorithm 1)
6//! using `codec.fast_symmetric_distance` for all neighbor-selection passes.
7
8use std::cmp::Reverse;
9use std::collections::{BinaryHeap, HashSet};
10
11use nodedb_codec::vector_quant::codec::VectorCodec;
12
13use super::graph::{HnswCodecIndex, NodeC};
14
/// A search candidate: a distance paired with a node index.
///
/// Candidates are ordered by `dist` using [`f32::total_cmp`], with `idx` as a
/// tie-breaker. This is a true total order (NaN and signed zeros get a fixed
/// position), which keeps `BinaryHeap<Cand>` (max-heap) and
/// `BinaryHeap<Reverse<Cand>>` (min-heap) well-behaved even if a codec ever
/// returns a NaN distance.
#[derive(Clone, Copy, Debug)]
struct Cand {
    /// Distance from the query to this node (smaller is closer).
    dist: f32,
    /// Index into `HnswCodecIndex::nodes`.
    idx: u32,
}

// Equality is derived from `Ord` so that `a == b` exactly when
// `a.cmp(&b) == Ordering::Equal`, as the `Eq`/`Ord` contracts require
// (a derived `PartialEq` on `f32` would disagree for NaN and -0.0/+0.0).
impl PartialEq for Cand {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Cand {}

impl PartialOrd for Cand {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Cand {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.idx.cmp(&other.idx))
    }
}
39
40impl<C: VectorCodec> HnswCodecIndex<C> {
41    /// Insert a vector with the given caller-supplied `id`.
42    ///
43    /// Encodes `v` via `codec.encode`, assigns a random layer, and runs the
44    /// standard HNSW neighbour-selection algorithm.
45    pub fn insert(&mut self, id: u32, v: &[f32]) {
46        let quantized = self.codec.encode(v);
47        let node_layer = self.random_layer();
48
49        // Allocate the node first so we can use its index in the graph wiring.
50        let new_idx = self.nodes.len() as u32;
51
52        // Build empty neighbor lists: one per layer 0..=node_layer.
53        let neighbors = vec![Vec::new(); node_layer + 1];
54
55        self.nodes.push(NodeC {
56            id,
57            deleted: false,
58            layer: node_layer,
59            quantized,
60            neighbors,
61        });
62
63        let Some(ep) = self.entry_point else {
64            // First node: it becomes the entry point.
65            self.entry_point = Some(new_idx);
66            self.max_layer = node_layer;
67            return;
68        };
69
70        // Phase 1: greedy descent from max_layer down to node_layer + 1.
71        // Carry a single nearest candidate per layer (ef = 1).
72        let mut cur_ep = ep;
73        for layer in (node_layer + 1..=self.max_layer).rev() {
74            cur_ep = self.greedy_nearest(new_idx, cur_ep, layer);
75        }
76
77        // Phase 2: ef_construction search from node_layer down to 0.
78        let ef = self.ef_construction;
79        for layer in (0..=node_layer.min(self.max_layer)).rev() {
80            let candidates = self.search_layer_build(new_idx, cur_ep, ef, layer);
81
82            // Choose the m (or m0 at layer 0) nearest as neighbours.
83            let max_nb = self.max_neighbors(layer);
84            let chosen: Vec<u32> = candidates
85                .iter()
86                .filter(|c| c.idx != new_idx)
87                .take(max_nb)
88                .map(|c| c.idx)
89                .collect();
90
91            // Set new node's neighbours at this layer.
92            self.nodes[new_idx as usize].neighbors[layer] = chosen.clone();
93
94            // Update chosen neighbours reciprocally.
95            for &nb_idx in &chosen {
96                let new_dist = {
97                    let nb_q = &self.nodes[nb_idx as usize].quantized as *const C::Quantized;
98                    let new_q = &self.nodes[new_idx as usize].quantized as *const C::Quantized;
99                    // SAFETY: we hold exclusive access to `self`; the two
100                    // borrows are to distinct nodes.
101                    unsafe { self.codec.fast_symmetric_distance(&*nb_q, &*new_q) }
102                };
103
104                if layer < self.nodes[nb_idx as usize].neighbors.len() {
105                    let nb_layer = &mut self.nodes[nb_idx as usize].neighbors[layer];
106                    if !nb_layer.contains(&new_idx) {
107                        nb_layer.push(new_idx);
108                    }
109                    // Prune if over capacity.
110                    if nb_layer.len() > max_nb {
111                        self.prune_neighbors(nb_idx, layer, max_nb, new_dist, new_idx);
112                    }
113                }
114            }
115
116            // Update entry point for the next lower layer.
117            if let Some(best) = candidates.first() {
118                cur_ep = best.idx;
119            }
120        }
121
122        // If the new node's layer exceeds the current max, promote it.
123        if node_layer > self.max_layer {
124            self.entry_point = Some(new_idx);
125            self.max_layer = node_layer;
126        }
127    }
128
129    /// Greedy descent: starting at `ep_idx`, find the single nearest node to
130    /// `query_idx` at the given `layer`.
131    fn greedy_nearest(&self, query_idx: u32, ep_idx: u32, layer: usize) -> u32 {
132        let mut best_idx = ep_idx;
133        let mut best_dist = self.sym_dist(query_idx, ep_idx);
134
135        loop {
136            let mut improved = false;
137            for &nb in self.neighbors_at(best_idx, layer) {
138                if self.nodes[nb as usize].deleted {
139                    continue;
140                }
141                let d = self.sym_dist(query_idx, nb);
142                if d < best_dist {
143                    best_dist = d;
144                    best_idx = nb;
145                    improved = true;
146                }
147            }
148            if !improved {
149                break;
150            }
151        }
152
153        best_idx
154    }
155
156    /// Beam search over `layer` with the given `ef`, returning candidates
157    /// sorted by ascending distance from `query_idx` (excluding deleted nodes).
158    fn search_layer_build(
159        &self,
160        query_idx: u32,
161        ep_idx: u32,
162        ef: usize,
163        layer: usize,
164    ) -> Vec<Cand> {
165        let mut visited: HashSet<u32> = HashSet::new();
166        visited.insert(ep_idx);
167
168        let ep_dist = self.sym_dist(query_idx, ep_idx);
169        let ep_cand = Cand {
170            dist: ep_dist,
171            idx: ep_idx,
172        };
173
174        let mut candidates: BinaryHeap<Reverse<Cand>> = BinaryHeap::new();
175        candidates.push(Reverse(ep_cand));
176
177        let mut results: BinaryHeap<Cand> = BinaryHeap::new();
178        if !self.nodes[ep_idx as usize].deleted {
179            results.push(ep_cand);
180        }
181
182        while let Some(Reverse(cur)) = candidates.pop() {
183            let worst = results.peek().map_or(f32::INFINITY, |w| w.dist);
184            if cur.dist > worst && results.len() >= ef {
185                break;
186            }
187
188            for &nb in self.neighbors_at(cur.idx, layer) {
189                if !visited.insert(nb) {
190                    continue;
191                }
192                let d = self.sym_dist(query_idx, nb);
193                let worst_now = results.peek().map_or(f32::INFINITY, |w| w.dist);
194                if d < worst_now || results.len() < ef {
195                    candidates.push(Reverse(Cand { dist: d, idx: nb }));
196                }
197                if !self.nodes[nb as usize].deleted {
198                    results.push(Cand { dist: d, idx: nb });
199                    if results.len() > ef {
200                        results.pop();
201                    }
202                }
203            }
204        }
205
206        let mut out: Vec<Cand> = results.into_vec();
207        out.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
208        out
209    }
210
211    /// Prune the neighbor list of `nb_idx` at `layer` to `max_nb` entries,
212    /// removing the farthest neighbours (simple distance-based strategy).
213    fn prune_neighbors(
214        &mut self,
215        nb_idx: u32,
216        layer: usize,
217        max_nb: usize,
218        _hint_dist: f32,
219        _hint_id: u32,
220    ) {
221        // Collect (dist_to_nb, idx) for every current neighbour.
222        let nb_list = self.nodes[nb_idx as usize].neighbors[layer].clone();
223        let mut scored: Vec<(f32, u32)> = nb_list
224            .iter()
225            .map(|&cand_idx| {
226                let d = {
227                    let a = &self.nodes[nb_idx as usize].quantized as *const C::Quantized;
228                    let b = &self.nodes[cand_idx as usize].quantized as *const C::Quantized;
229                    // SAFETY: a and b point to distinct nodes in `self.nodes`.
230                    unsafe { self.codec.fast_symmetric_distance(&*a, &*b) }
231                };
232                (d, cand_idx)
233            })
234            .collect();
235
236        // Keep the `max_nb` nearest.
237        scored.sort_unstable_by(|a, b| a.0.total_cmp(&b.0));
238        scored.truncate(max_nb);
239
240        self.nodes[nb_idx as usize].neighbors[layer] =
241            scored.into_iter().map(|(_, idx)| idx).collect();
242    }
243
244    /// Symmetric distance between two nodes identified by their dense indices.
245    #[inline]
246    pub(crate) fn sym_dist(&self, a_idx: u32, b_idx: u32) -> f32 {
247        let a = &self.nodes[a_idx as usize].quantized;
248        let b = &self.nodes[b_idx as usize].quantized;
249        // Both borrows are read-only; safe even through raw pointers if
250        // called for distinct indices, but here we just borrow directly.
251        self.codec.fast_symmetric_distance(a, b)
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::Sq8Codec;

    /// Calibrate an SQ8 codec on `n` synthetic `dim`-dimensional vectors whose
    /// components increase linearly (0.0, 0.1, 0.2, ...).
    fn make_sq8(dim: usize, n: usize) -> Sq8Codec {
        let vecs: Vec<Vec<f32>> = (0..n)
            .map(|i| (0..dim).map(|d| (i * dim + d) as f32 * 0.1).collect())
            .collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        Sq8Codec::calibrate(&refs, dim)
    }

    #[test]
    fn insert_sets_entry_point() {
        let codec = make_sq8(4, 10);
        let mut idx: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(4, 8, 50, codec, 1);
        idx.insert(0, &[0.1, 0.2, 0.3, 0.4]);
        assert!(idx.entry_point.is_some());
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn insert_multiple_grows_nodes() {
        let codec = make_sq8(4, 30);
        let mut idx: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(4, 8, 50, codec, 42);
        for i in 0..20u32 {
            let v: Vec<f32> = (0..4).map(|d| (i as usize * 4 + d) as f32).collect();
            idx.insert(i, &v);
        }
        assert_eq!(idx.len(), 20);
        assert!(idx.entry_point.is_some());
    }

    /// After many inserts, every adjacency list must respect the per-layer
    /// capacity, reference only valid node indices, and contain no self-loops
    /// or duplicate edges; the entry point must sit on `max_layer`.
    #[test]
    fn insert_preserves_graph_invariants() {
        let codec = make_sq8(4, 60);
        let mut idx: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(4, 8, 50, codec, 7);
        for i in 0..60u32 {
            // Wrapping pattern so some vectors repeat (exercises distance ties).
            let v: Vec<f32> = (0..4usize)
                .map(|d| ((i as usize * 7 + d * 3) % 23) as f32)
                .collect();
            idx.insert(i, &v);
        }
        assert_eq!(idx.len(), 60);

        let n = idx.nodes.len();
        for (node_idx, node) in idx.nodes.iter().enumerate() {
            assert_eq!(node.id, node_idx as u32, "ids are stored in insertion order");
            assert_eq!(node.neighbors.len(), node.layer + 1);
            for (layer, list) in node.neighbors.iter().enumerate() {
                assert!(
                    list.len() <= idx.max_neighbors(layer),
                    "node {node_idx} over capacity at layer {layer}"
                );
                assert!(
                    !list.contains(&(node_idx as u32)),
                    "self-loop at node {node_idx}, layer {layer}"
                );
                assert!(list.iter().all(|&nb| (nb as usize) < n));
                let unique: std::collections::HashSet<u32> = list.iter().copied().collect();
                assert_eq!(
                    unique.len(),
                    list.len(),
                    "duplicate edge at node {node_idx}, layer {layer}"
                );
            }
        }

        let entry = idx.entry_point.expect("non-empty index has an entry point") as usize;
        assert_eq!(idx.nodes[entry].layer, idx.max_layer);
    }
}