1use crate::quantized::{StoredVector, StoredVectorEntry, quantized_hnsw_distance};
2use anndists::dist::distances::{DistCosine, Distance};
3use contextdb_core::{Error, Result, RowId, VectorIndexRef, VectorQuantization};
4use hnsw_rs::hnsw::Hnsw;
5use parking_lot::RwLock;
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9pub struct HnswIndex {
10 hnsw: HnswInner,
11 id_to_row: RwLock<HashMap<usize, RowId>>,
12 row_to_id: RwLock<HashMap<RowId, usize>>,
13 next_id: AtomicUsize,
14 dimension: usize,
15 quantization: VectorQuantization,
16 ef_search: usize,
17}
18
19enum HnswInner {
20 F32(Hnsw<'static, f32, DistCosine>),
21 Quantized(Hnsw<'static, u8, DistQuantizedCosine>),
22}
23
24#[derive(Debug, Clone, Copy)]
25struct DistQuantizedCosine {
26 quantization: VectorQuantization,
27}
28
29impl Distance<u8> for DistQuantizedCosine {
30 fn eval(&self, va: &[u8], vb: &[u8]) -> f32 {
31 quantized_hnsw_distance(va, vb, self.quantization)
32 }
33}
34
/// Structural snapshot of an HNSW graph, intended for diagnostics and tests.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct HnswGraphStats {
    /// Number of points inserted into the graph.
    pub point_count: usize,
    /// Number of points on layer 0 (`HnswIndex::graph_stats` fills this with `point_count`).
    pub layer0_points: usize,
    /// Sum of the lengths of every point's layer-0 neighbour list (directed edges).
    pub layer0_neighbor_edges: usize,
    /// Highest layer level observed in the graph.
    pub max_level_observed: u8,
    /// Vector dimensionality of the index.
    pub dimension: usize,
}
43
44impl HnswIndex {
45 pub(crate) fn new(
46 entries: &[StoredVectorEntry],
47 dimension: usize,
48 quantization: VectorQuantization,
49 ) -> Self {
50 let (m, ef_construction, ef_search) = select_params(entries.len(), quantization);
51 let max_elements = entries.len().max(1);
52 let hnsw = match quantization {
53 VectorQuantization::F32 => {
54 let mut hnsw = Hnsw::new(m, max_elements, 16, ef_construction, DistCosine);
55 hnsw.set_extend_candidates(true);
56 hnsw.set_keeping_pruned(true);
57 HnswInner::F32(hnsw)
58 }
59 VectorQuantization::SQ8 | VectorQuantization::SQ4 => {
60 let mut hnsw = Hnsw::new(
61 m,
62 max_elements,
63 16,
64 ef_construction,
65 DistQuantizedCosine { quantization },
66 );
67 hnsw.set_extend_candidates(true);
68 hnsw.set_keeping_pruned(true);
69 HnswInner::Quantized(hnsw)
70 }
71 };
72 let id_to_row = RwLock::new(HashMap::with_capacity(entries.len()));
73 let row_to_id = RwLock::new(HashMap::with_capacity(entries.len()));
74 let mut sorted_entries = entries.iter().collect::<Vec<_>>();
75 sorted_entries.sort_by_key(|entry| {
76 (
77 insertion_key(entry),
78 entry.lsn,
79 entry.created_tx,
80 entry.row_id,
81 )
82 });
83
84 match &hnsw {
85 HnswInner::F32(index) => {
86 let data = sorted_entries
87 .iter()
88 .enumerate()
89 .filter_map(|(data_id, entry)| {
90 entry.vector.as_f32_slice().map(|vector| {
91 id_to_row.write().insert(data_id, entry.row_id);
92 row_to_id.write().insert(entry.row_id, data_id);
93 (vector.to_vec(), data_id)
94 })
95 })
96 .collect::<Vec<_>>();
97 let refs = data
98 .iter()
99 .map(|(vector, data_id)| (vector, *data_id))
100 .collect::<Vec<_>>();
101 index.parallel_insert(&refs);
102 }
103 HnswInner::Quantized(index) => {
104 let data = sorted_entries
105 .iter()
106 .enumerate()
107 .filter_map(|(data_id, entry)| {
108 let encoded = entry.vector.to_hnsw_u8();
109 (!encoded.is_empty()).then(|| {
110 id_to_row.write().insert(data_id, entry.row_id);
111 row_to_id.write().insert(entry.row_id, data_id);
112 (encoded, data_id)
113 })
114 })
115 .collect::<Vec<_>>();
116 let refs = data
117 .iter()
118 .map(|(vector, data_id)| (vector, *data_id))
119 .collect::<Vec<_>>();
120 index.parallel_insert(&refs);
121 }
122 }
123
124 Self {
125 hnsw,
126 id_to_row,
127 row_to_id,
128 next_id: AtomicUsize::new(entries.len()),
129 dimension,
130 quantization,
131 ef_search,
132 }
133 }
134
135 pub(crate) fn insert(&self, row_id: RowId, vector: &StoredVector) {
136 let data_id = self.next_id.fetch_add(1, Ordering::Relaxed);
137 insert_into_hnsw(&self.hnsw, vector, data_id);
138 self.id_to_row.write().insert(data_id, row_id);
139 self.row_to_id.write().insert(row_id, data_id);
140 }
141
142 pub fn len(&self) -> usize {
144 self.next_id.load(Ordering::Relaxed)
145 }
146
147 pub fn is_empty(&self) -> bool {
148 self.len() == 0
149 }
150
151 #[doc(hidden)]
152 pub fn graph_stats(&self) -> HnswGraphStats {
153 let (point_count, layer0_neighbor_edges, max_level_observed) = match &self.hnsw {
154 HnswInner::F32(hnsw) => hnsw_stats(hnsw),
155 HnswInner::Quantized(hnsw) => hnsw_stats(hnsw),
156 };
157
158 HnswGraphStats {
159 point_count,
160 layer0_points: point_count,
161 layer0_neighbor_edges,
162 max_level_observed,
163 dimension: self.dimension,
164 }
165 }
166
167 pub fn search(
168 &self,
169 index: &VectorIndexRef,
170 query: &[f32],
171 k: usize,
172 ) -> Result<Vec<(RowId, f32)>> {
173 if k == 0 {
174 return Ok(Vec::new());
175 }
176
177 let got = query.len();
178 if got != self.dimension {
179 return Err(Error::VectorIndexDimensionMismatch {
180 index: index.clone(),
181 expected: self.dimension,
182 actual: got,
183 });
184 }
185
186 let ef = self.ef_search.max(k.saturating_mul(10)).max(1);
187 let neighbors = match &self.hnsw {
188 HnswInner::F32(hnsw) => hnsw.search(query, ef, ef),
189 HnswInner::Quantized(hnsw) => {
190 let encoded = StoredVector::from_f32(query, self.quantization).to_hnsw_u8();
191 hnsw.search(&encoded, ef, ef)
192 }
193 };
194 let id_to_row = self.id_to_row.read();
195
196 Ok(neighbors
197 .into_iter()
198 .filter_map(|neighbor| {
199 id_to_row
200 .get(&neighbor.d_id)
201 .copied()
202 .map(|row_id| (row_id, 1.0 - neighbor.distance))
203 })
204 .collect())
205 }
206}
207
208fn insert_into_hnsw(hnsw: &HnswInner, vector: &StoredVector, data_id: usize) {
209 match hnsw {
210 HnswInner::F32(hnsw) => {
211 let Some(vector) = vector.as_f32_slice() else {
212 return;
213 };
214 hnsw.insert((vector, data_id));
215 }
216 HnswInner::Quantized(hnsw) => {
217 let encoded = vector.to_hnsw_u8();
218 if !encoded.is_empty() {
219 hnsw.insert((&encoded, data_id));
220 }
221 }
222 }
223}
224
225fn hnsw_stats<T, D>(hnsw: &Hnsw<'_, T, D>) -> (usize, usize, u8)
226where
227 T: Clone + Send + Sync,
228 D: Distance<T> + Send + Sync,
229{
230 let indexation = hnsw.get_point_indexation();
231 let layer0_neighbor_edges = indexation
232 .get_layer_iterator(0)
233 .map(|point| {
234 point
235 .get_neighborhood_id()
236 .first()
237 .map_or(0, |neighbors| neighbors.len())
238 })
239 .sum();
240 (
241 hnsw.get_nb_point(),
242 layer0_neighbor_edges,
243 hnsw.get_max_level_observed(),
244 )
245}
246
247fn select_params(count: usize, quantization: VectorQuantization) -> (usize, usize, usize) {
248 if !matches!(quantization, VectorQuantization::F32) {
249 return match count {
250 0..=5000 => (8, 32, 96.min(count.max(32))),
251 5001..=50000 => (12, 64, 128),
252 _ => (12, 64, 128),
253 };
254 }
255 match count {
256 0..=5000 => (16, 200, count.max(200)),
257 5001..=50000 => (24, 400, 400),
258 _ => (16, 200, 200),
259 }
260}
261
262fn insertion_key(entry: &StoredVectorEntry) -> u64 {
263 let mut x = entry.row_id.0 ^ entry.lsn.0 ^ entry.created_tx.0;
264 x = x.wrapping_add(0x9e37_79b9_7f4a_7c15);
265 x = (x ^ (x >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
266 x = (x ^ (x >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
267 x ^ (x >> 31)
268}