// nodedb_vector/codec_index/build.rs

use std::cmp::{Ordering, Reverse};
use std::collections::{BinaryHeap, HashSet};

use nodedb_codec::vector_quant::codec::VectorCodec;

use super::graph::{HnswCodecIndex, NodeC};

/// A search candidate: a node position paired with its distance to the query.
///
/// Candidates are ordered by distance using [`f32::total_cmp`], then by node
/// position. This is a genuine total order even if a codec ever produces a
/// `NaN` distance. NaN then gets a fixed place in the order instead of
/// comparing "equal" to every other value, which would silently corrupt the
/// `BinaryHeap`s built from `Cand`. Equality is defined through the same
/// comparison, so `PartialEq`, `Eq`, `PartialOrd` and `Ord` always agree.
#[derive(Clone, Copy, Debug)]
struct Cand {
    /// Distance from the query node to `idx`, as reported by the codec.
    dist: f32,
    /// Position of the candidate in `HnswCodecIndex::nodes`.
    idx: u32,
}

impl PartialEq for Cand {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Cand {}

impl PartialOrd for Cand {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Cand {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.idx.cmp(&other.idx))
    }
}

40impl<C: VectorCodec> HnswCodecIndex<C> {
41 pub fn insert(&mut self, id: u32, v: &[f32]) {
46 let quantized = self.codec.encode(v);
47 let node_layer = self.random_layer();
48
49 let new_idx = self.nodes.len() as u32;
51
52 let neighbors = vec![Vec::new(); node_layer + 1];
54
55 self.nodes.push(NodeC {
56 id,
57 deleted: false,
58 layer: node_layer,
59 quantized,
60 neighbors,
61 });
62
63 let Some(ep) = self.entry_point else {
64 self.entry_point = Some(new_idx);
66 self.max_layer = node_layer;
67 return;
68 };
69
70 let mut cur_ep = ep;
73 for layer in (node_layer + 1..=self.max_layer).rev() {
74 cur_ep = self.greedy_nearest(new_idx, cur_ep, layer);
75 }
76
77 let ef = self.ef_construction;
79 for layer in (0..=node_layer.min(self.max_layer)).rev() {
80 let candidates = self.search_layer_build(new_idx, cur_ep, ef, layer);
81
82 let max_nb = self.max_neighbors(layer);
84 let chosen: Vec<u32> = candidates
85 .iter()
86 .filter(|c| c.idx != new_idx)
87 .take(max_nb)
88 .map(|c| c.idx)
89 .collect();
90
91 self.nodes[new_idx as usize].neighbors[layer] = chosen.clone();
93
94 for &nb_idx in &chosen {
96 let new_dist = {
97 let nb_q = &self.nodes[nb_idx as usize].quantized as *const C::Quantized;
98 let new_q = &self.nodes[new_idx as usize].quantized as *const C::Quantized;
99 unsafe { self.codec.fast_symmetric_distance(&*nb_q, &*new_q) }
102 };
103
104 if layer < self.nodes[nb_idx as usize].neighbors.len() {
105 let nb_layer = &mut self.nodes[nb_idx as usize].neighbors[layer];
106 if !nb_layer.contains(&new_idx) {
107 nb_layer.push(new_idx);
108 }
109 if nb_layer.len() > max_nb {
111 self.prune_neighbors(nb_idx, layer, max_nb, new_dist, new_idx);
112 }
113 }
114 }
115
116 if let Some(best) = candidates.first() {
118 cur_ep = best.idx;
119 }
120 }
121
122 if node_layer > self.max_layer {
124 self.entry_point = Some(new_idx);
125 self.max_layer = node_layer;
126 }
127 }
128
129 fn greedy_nearest(&self, query_idx: u32, ep_idx: u32, layer: usize) -> u32 {
132 let mut best_idx = ep_idx;
133 let mut best_dist = self.sym_dist(query_idx, ep_idx);
134
135 loop {
136 let mut improved = false;
137 for &nb in self.neighbors_at(best_idx, layer) {
138 if self.nodes[nb as usize].deleted {
139 continue;
140 }
141 let d = self.sym_dist(query_idx, nb);
142 if d < best_dist {
143 best_dist = d;
144 best_idx = nb;
145 improved = true;
146 }
147 }
148 if !improved {
149 break;
150 }
151 }
152
153 best_idx
154 }
155
156 fn search_layer_build(
159 &self,
160 query_idx: u32,
161 ep_idx: u32,
162 ef: usize,
163 layer: usize,
164 ) -> Vec<Cand> {
165 let mut visited: HashSet<u32> = HashSet::new();
166 visited.insert(ep_idx);
167
168 let ep_dist = self.sym_dist(query_idx, ep_idx);
169 let ep_cand = Cand {
170 dist: ep_dist,
171 idx: ep_idx,
172 };
173
174 let mut candidates: BinaryHeap<Reverse<Cand>> = BinaryHeap::new();
175 candidates.push(Reverse(ep_cand));
176
177 let mut results: BinaryHeap<Cand> = BinaryHeap::new();
178 if !self.nodes[ep_idx as usize].deleted {
179 results.push(ep_cand);
180 }
181
182 while let Some(Reverse(cur)) = candidates.pop() {
183 let worst = results.peek().map_or(f32::INFINITY, |w| w.dist);
184 if cur.dist > worst && results.len() >= ef {
185 break;
186 }
187
188 for &nb in self.neighbors_at(cur.idx, layer) {
189 if !visited.insert(nb) {
190 continue;
191 }
192 let d = self.sym_dist(query_idx, nb);
193 let worst_now = results.peek().map_or(f32::INFINITY, |w| w.dist);
194 if d < worst_now || results.len() < ef {
195 candidates.push(Reverse(Cand { dist: d, idx: nb }));
196 }
197 if !self.nodes[nb as usize].deleted {
198 results.push(Cand { dist: d, idx: nb });
199 if results.len() > ef {
200 results.pop();
201 }
202 }
203 }
204 }
205
206 let mut out: Vec<Cand> = results.into_vec();
207 out.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
208 out
209 }
210
211 fn prune_neighbors(
214 &mut self,
215 nb_idx: u32,
216 layer: usize,
217 max_nb: usize,
218 _hint_dist: f32,
219 _hint_id: u32,
220 ) {
221 let nb_list = self.nodes[nb_idx as usize].neighbors[layer].clone();
223 let mut scored: Vec<(f32, u32)> = nb_list
224 .iter()
225 .map(|&cand_idx| {
226 let d = {
227 let a = &self.nodes[nb_idx as usize].quantized as *const C::Quantized;
228 let b = &self.nodes[cand_idx as usize].quantized as *const C::Quantized;
229 unsafe { self.codec.fast_symmetric_distance(&*a, &*b) }
231 };
232 (d, cand_idx)
233 })
234 .collect();
235
236 scored.sort_unstable_by(|a, b| a.0.total_cmp(&b.0));
238 scored.truncate(max_nb);
239
240 self.nodes[nb_idx as usize].neighbors[layer] =
241 scored.into_iter().map(|(_, idx)| idx).collect();
242 }
243
244 #[inline]
246 pub(crate) fn sym_dist(&self, a_idx: u32, b_idx: u32) -> f32 {
247 let a = &self.nodes[a_idx as usize].quantized;
248 let b = &self.nodes[b_idx as usize].quantized;
249 self.codec.fast_symmetric_distance(a, b)
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::Sq8Codec;

    /// Calibrates an SQ8 codec on `n` synthetic `dim`-dimensional vectors
    /// whose components increase in steps of 0.1.
    fn make_sq8(dim: usize, n: usize) -> Sq8Codec {
        let vecs: Vec<Vec<f32>> = (0..n)
            .map(|i| (0..dim).map(|d| (i * dim + d) as f32 * 0.1).collect())
            .collect();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        Sq8Codec::calibrate(&refs, dim)
    }

    /// Inserts `n` 4-dimensional points (ids `0..n`) into a fresh index built
    /// with the same constructor arguments the multi-insert test has used.
    fn build_index(n: u32) -> HnswCodecIndex<Sq8Codec> {
        let codec = make_sq8(4, 30);
        let mut index: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(4, 8, 50, codec, 42);
        for i in 0..n {
            let v: Vec<f32> = (0..4).map(|d| (i as usize * 4 + d) as f32).collect();
            index.insert(i, &v);
        }
        index
    }

    #[test]
    fn insert_sets_entry_point() {
        let codec = make_sq8(4, 10);
        let mut index: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(4, 8, 50, codec, 1);
        index.insert(0, &[0.1, 0.2, 0.3, 0.4]);
        assert!(index.entry_point.is_some());
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn insert_multiple_grows_nodes() {
        let index = build_index(20);
        assert_eq!(index.len(), 20);
        assert!(index.entry_point.is_some());
    }

    /// Checks the structural invariants `insert` is responsible for.
    #[test]
    fn insert_keeps_graph_well_formed() {
        let index = build_index(20);
        let len = index.nodes.len();

        for (pos, node) in index.nodes.iter().enumerate() {
            assert_eq!(
                node.neighbors.len(),
                node.layer + 1,
                "node {pos}: expected one neighbor list per layer"
            );
            for (layer, list) in node.neighbors.iter().enumerate() {
                assert!(
                    list.len() <= index.max_neighbors(layer),
                    "node {pos} layer {layer}: neighbor list over capacity"
                );
                let mut seen = std::collections::HashSet::new();
                for &nb in list {
                    assert!((nb as usize) < len, "node {pos}: neighbor {nb} out of range");
                    assert_ne!(nb as usize, pos, "node {pos}: self-link on layer {layer}");
                    assert!(seen.insert(nb), "node {pos}: duplicate neighbor {nb}");
                }
            }
            // With at least two nodes, every node is linked on the base layer.
            assert!(!node.neighbors[0].is_empty(), "node {pos}: isolated on layer 0");
        }

        let ep = index
            .entry_point
            .expect("a non-empty index has an entry point") as usize;
        assert_eq!(index.nodes[ep].layer, index.max_layer);
    }
}