1use std::{
6 collections::{HashMap, HashSet},
7 path::{Path, PathBuf},
8 sync::{
9 atomic::{AtomicUsize, Ordering},
10 RwLock,
11 },
12};
13
14use hnsw_rs::prelude::*;
15use serde::{Deserialize, Serialize};
16use tracing::instrument;
17
18use crate::{
19 config::VectorConfig,
20 error::{VectorError, VectorResult},
21 types::DistanceMetric,
22};
23
/// Point-in-time summary of an [`HnswIndex`], as produced by `HnswIndex::stats`.
///
/// The type is plain data, so it also derives `PartialEq`/`Eq` to let callers
/// and tests compare two snapshots directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HnswStats {
    /// Value of the index's element counter when the snapshot was taken.
    pub element_count: usize,
    /// The `max_elements` value the index was created with.
    pub max_elements: usize,
    /// `ef_construction` as reported by the underlying graph.
    pub ef_construction: usize,
    /// Maximum number of connections per node as reported by the underlying graph.
    pub m_connections: usize,
    /// Highest layer the underlying graph reports having observed.
    pub layers: usize,
}
40
/// The concrete `hnsw_rs` graph behind an index, one variant per distance
/// function, so the rest of the file can hold a single non-generic type.
enum HnswInner {
    /// Graph built with `DistL2` (selected for `DistanceMetric::Euclidean`).
    L2(Hnsw<'static, f32, DistL2>),
    /// Graph built with `DistCosine` (selected for `DistanceMetric::Cosine`).
    Cosine(Hnsw<'static, f32, DistCosine>),
    /// Graph built with `DistDot` (selected for `DistanceMetric::DotProduct`).
    Dot(Hnsw<'static, f32, DistDot>),
}
48
49impl HnswInner {
50 fn insert(&self, id: usize, vector: &[f32]) {
51 match self {
52 HnswInner::L2(h) => h.insert((vector, id)),
53 HnswInner::Cosine(h) => h.insert((vector, id)),
54 HnswInner::Dot(h) => h.insert((vector, id)),
55 }
56 }
57
58 fn parallel_insert(&self, refs: &[(&Vec<f32>, usize)]) {
59 match self {
60 HnswInner::L2(h) => h.parallel_insert(refs),
61 HnswInner::Cosine(h) => h.parallel_insert(refs),
62 HnswInner::Dot(h) => h.parallel_insert(refs),
63 }
64 }
65
66 fn search(&self, query: &[f32], top_k: usize, ef_search: usize) -> Vec<Neighbour> {
67 match self {
68 HnswInner::L2(h) => h.search(query, top_k, ef_search),
69 HnswInner::Cosine(h) => h.search(query, top_k, ef_search),
70 HnswInner::Dot(h) => h.search(query, top_k, ef_search),
71 }
72 }
73
74 fn ef_construction(&self) -> usize {
75 match self {
76 HnswInner::L2(h) => h.get_ef_construction(),
77 HnswInner::Cosine(h) => h.get_ef_construction(),
78 HnswInner::Dot(h) => h.get_ef_construction(),
79 }
80 }
81
82 fn max_nb_connection(&self) -> usize {
83 match self {
84 HnswInner::L2(h) => h.get_max_nb_connection() as usize,
85 HnswInner::Cosine(h) => h.get_max_nb_connection() as usize,
86 HnswInner::Dot(h) => h.get_max_nb_connection() as usize,
87 }
88 }
89
90 fn max_level_observed(&self) -> usize {
91 match self {
92 HnswInner::L2(h) => h.get_max_level_observed() as usize,
93 HnswInner::Cosine(h) => h.get_max_level_observed() as usize,
94 HnswInner::Dot(h) => h.get_max_level_observed() as usize,
95 }
96 }
97}
98
/// An HNSW graph plus the bookkeeping needed to persist it and to hide
/// deleted ids from search results.
pub struct HnswIndex {
    /// The underlying graph, specialised for the chosen distance metric.
    inner: HnswInner,
    /// Copy of every stored vector keyed by id; this map (not the graph) is
    /// what `save` writes out and `snapshot_points` returns.
    points: RwLock<HashMap<usize, Vec<f32>>>,
    /// Length every inserted vector and every query must have.
    dimensions: usize,
    /// Counter behind `len()`; adjusted by the insert and delete methods.
    element_count: AtomicUsize,
    /// The `max_elements` value taken from the config at construction time.
    max_elements: usize,
    /// Ids recorded by `delete`; `search` drops any neighbour whose id is here.
    deleted: RwLock<HashSet<usize>>,
}
115
116impl HnswIndex {
117 #[instrument(skip(config))]
119 pub fn new(config: &VectorConfig, distance: DistanceMetric) -> VectorResult<Self> {
120 Self::new_with_dimensions(config, distance, config.default_dimensions)
121 }
122
123 pub fn new_with_dimensions(
125 config: &VectorConfig,
126 distance: DistanceMetric,
127 dimensions: usize,
128 ) -> VectorResult<Self> {
129 let inner = build_inner(
130 config.m_connections,
131 config.max_elements,
132 16,
133 config.ef_construction,
134 distance,
135 );
136 Ok(HnswIndex {
137 inner,
138 points: RwLock::new(HashMap::new()),
139 dimensions,
140 element_count: AtomicUsize::new(0),
141 max_elements: config.max_elements,
142 deleted: RwLock::new(HashSet::new()),
143 })
144 }
145
146 #[instrument(skip(self, vector))]
148 pub fn insert(&self, id: usize, vector: &[f32]) -> VectorResult<()> {
149 if vector.len() != self.dimensions {
150 return Err(VectorError::DimensionMismatch {
151 expected: self.dimensions,
152 got: vector.len(),
153 });
154 }
155 self.inner.insert(id, vector);
156 self.points
157 .write()
158 .map_err(|e| VectorError::Index(e.to_string()))?
159 .insert(id, vector.to_vec());
160 self.element_count.fetch_add(1, Ordering::Relaxed);
161 Ok(())
162 }
163
164 #[instrument(skip(self, items))]
166 pub fn insert_batch(&self, items: &[(usize, Vec<f32>)]) -> VectorResult<()> {
167 for (_, v) in items {
168 if v.len() != self.dimensions {
169 return Err(VectorError::DimensionMismatch {
170 expected: self.dimensions,
171 got: v.len(),
172 });
173 }
174 }
175 let refs: Vec<(&Vec<f32>, usize)> = items.iter().map(|(id, v)| (v, *id)).collect();
176 self.inner.parallel_insert(&refs);
177 let mut pts = self
178 .points
179 .write()
180 .map_err(|e| VectorError::Index(e.to_string()))?;
181 for (id, v) in items {
182 pts.insert(*id, v.clone());
183 }
184 self.element_count.fetch_add(items.len(), Ordering::Relaxed);
185 Ok(())
186 }
187
188 #[instrument(skip(self, query))]
192 pub fn search(
193 &self,
194 query: &[f32],
195 top_k: usize,
196 ef_search: usize,
197 ) -> VectorResult<Vec<(usize, f32)>> {
198 if query.len() != self.dimensions {
199 return Err(VectorError::DimensionMismatch {
200 expected: self.dimensions,
201 got: query.len(),
202 });
203 }
204 let deleted = self
205 .deleted
206 .read()
207 .map_err(|e| VectorError::Index(e.to_string()))?;
208 let neighbours = self.inner.search(query, top_k + deleted.len(), ef_search);
209 let mut results: Vec<(usize, f32)> = neighbours
210 .into_iter()
211 .filter(|n| !deleted.contains(&n.d_id))
212 .map(|n| (n.d_id, n.distance))
213 .collect();
214 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
215 results.truncate(top_k);
216 Ok(results)
217 }
218
219 #[instrument(skip(self))]
221 pub fn delete(&self, id: usize) -> VectorResult<()> {
222 let mut deleted = self
223 .deleted
224 .write()
225 .map_err(|e| VectorError::Index(e.to_string()))?;
226 if deleted.insert(id) {
227 self.points
228 .write()
229 .map_err(|e| VectorError::Index(e.to_string()))?
230 .remove(&id);
231 self.element_count.fetch_sub(1, Ordering::Relaxed);
232 }
233 Ok(())
234 }
235
236 pub fn len(&self) -> usize {
238 self.element_count.load(Ordering::Relaxed)
239 }
240
241 pub fn is_empty(&self) -> bool {
243 self.len() == 0
244 }
245
246 #[instrument(skip(self))]
248 pub fn save(&self, path: &Path, collection_id: &str) -> VectorResult<()> {
249 std::fs::create_dir_all(path)?;
250 let pts = self
251 .points
252 .read()
253 .map_err(|e| VectorError::Index(e.to_string()))?;
254
255 let mut buf = Vec::with_capacity(8 + pts.len() * (8 + self.dimensions * 4));
257 buf.extend_from_slice(&(pts.len() as u64).to_le_bytes());
258 for (&id, vec) in pts.iter() {
259 buf.extend_from_slice(&(id as u64).to_le_bytes());
260 for &v in vec {
261 buf.extend_from_slice(&v.to_le_bytes());
262 }
263 }
264
265 let final_path = index_file(path, collection_id);
266 let tmp_path = tmp_index_file(path, collection_id);
267 std::fs::write(&tmp_path, &buf)?;
268 std::fs::rename(&tmp_path, &final_path)?;
269
270 let checksum = blake3::hash(&buf).to_hex().to_string();
271 let manifest = CollectionManifest {
272 collection_id: collection_id.to_string(),
273 index_type: "hnsw".to_string(),
274 vector_count: pts.len(),
275 dimensions: self.dimensions,
276 saved_at_unix_ms: chrono::Utc::now().timestamp_millis(),
277 index_checksum_blake3: checksum,
278 };
279 std::fs::write(
280 manifest_file(path, collection_id),
281 serde_json::to_string_pretty(&manifest)?,
282 )?;
283 Ok(())
284 }
285
286 #[instrument(skip(config))]
288 pub fn load(
289 path: &Path,
290 collection_id: &str,
291 config: &VectorConfig,
292 distance: DistanceMetric,
293 ) -> VectorResult<Self> {
294 let final_path = index_file(path, collection_id);
295 let tmp_path = tmp_index_file(path, collection_id);
296 if tmp_path.exists() {
297 if final_path.exists() {
298 let _ = std::fs::remove_file(&tmp_path);
299 } else {
300 std::fs::rename(&tmp_path, &final_path)?;
301 }
302 }
303
304 let manifest_path = manifest_file(path, collection_id);
305 let manifest: CollectionManifest =
306 serde_json::from_reader(std::fs::File::open(&manifest_path)?)?;
307 let dimensions = manifest.dimensions;
308 let max_elements = config.max_elements;
309
310 let raw = std::fs::read(&final_path)?;
311 let checksum = blake3::hash(&raw).to_hex().to_string();
312 if checksum != manifest.index_checksum_blake3 {
313 tracing::warn!(
314 collection_id = %collection_id,
315 expected = %manifest.index_checksum_blake3,
316 got = %checksum,
317 "HNSW index checksum mismatch; continuing with best-effort load"
318 );
319 }
320 let points = decode_points_bin(&raw, dimensions)?;
321
322 let mut cfg = config.clone();
323 cfg.default_dimensions = dimensions;
324 cfg.max_elements = max_elements;
325 let index = Self::new_with_dimensions(&cfg, distance, dimensions)?;
326 index.insert_batch(&points)?;
327 Ok(index)
328 }
329
330 pub fn stats(&self) -> HnswStats {
332 HnswStats {
333 element_count: self.element_count.load(Ordering::Relaxed),
334 max_elements: self.max_elements,
335 ef_construction: self.inner.ef_construction(),
336 m_connections: self.inner.max_nb_connection(),
337 layers: self.inner.max_level_observed(),
338 }
339 }
340
341 pub fn snapshot_points(&self) -> VectorResult<Vec<(usize, Vec<f32>)>> {
343 let points = self
344 .points
345 .read()
346 .map_err(|e| VectorError::Index(e.to_string()))?
347 .iter()
348 .map(|(id, vector)| (*id, vector.clone()))
349 .collect();
350 Ok(points)
351 }
352}
353
/// Location of the index file for `collection_id` inside `path`:
/// `<path>/<collection_id>.hnsw`.
fn index_file(path: &Path, collection_id: &str) -> PathBuf {
    let file_name = [collection_id, ".hnsw"].concat();
    path.join(file_name)
}
357
/// Location of the temporary index file for `collection_id` inside `path`:
/// `<path>/<collection_id>.hnsw.tmp`.
fn tmp_index_file(path: &Path, collection_id: &str) -> PathBuf {
    const SUFFIX: &str = ".hnsw.tmp";
    let mut file_name = String::with_capacity(collection_id.len() + SUFFIX.len());
    file_name.push_str(collection_id);
    file_name.push_str(SUFFIX);
    path.join(file_name)
}
361
/// Location of the JSON manifest for `collection_id` inside `path`:
/// `<path>/<collection_id>.manifest.json`.
fn manifest_file(path: &Path, collection_id: &str) -> PathBuf {
    let mut target = path.to_path_buf();
    target.push(format!("{collection_id}.manifest.json"));
    target
}
365
366fn build_inner(
367 m: usize,
368 max_elem: usize,
369 max_layer: usize,
370 ef_c: usize,
371 distance: DistanceMetric,
372) -> HnswInner {
373 match distance {
374 DistanceMetric::Euclidean => {
375 HnswInner::L2(Hnsw::new(m, max_elem, max_layer, ef_c, DistL2 {}))
376 }
377 DistanceMetric::Cosine => {
378 HnswInner::Cosine(Hnsw::new(m, max_elem, max_layer, ef_c, DistCosine {}))
379 }
380 DistanceMetric::DotProduct => {
381 HnswInner::Dot(Hnsw::new(m, max_elem, max_layer, ef_c, DistDot {}))
382 }
383 }
384}
385
386fn decode_points_bin(raw: &[u8], dimensions: usize) -> VectorResult<Vec<(usize, Vec<f32>)>> {
387 if raw.len() < 8 {
388 return Ok(Vec::new());
389 }
390 let n = u64::from_le_bytes(raw[..8].try_into().unwrap()) as usize;
391 let bpr = 8 + dimensions * 4;
392 if raw.len() < 8 + n * bpr {
393 return Err(VectorError::Index("hnsw.points.bin is truncated".into()));
394 }
395 let mut points = Vec::with_capacity(n);
396 let mut off = 8usize;
397 for _ in 0..n {
398 let id = u64::from_le_bytes(raw[off..off + 8].try_into().unwrap()) as usize;
399 off += 8;
400 let floats: Vec<f32> = raw[off..off + dimensions * 4]
401 .chunks_exact(4)
402 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
403 .collect();
404 off += dimensions * 4;
405 points.push((id, floats));
406 }
407 Ok(points)
408}
409
// SAFETY: NOTE(review) — not verified from this file. Every field of
// `HnswIndex` other than `inner` is a std type that is already thread-safe
// (`RwLock`, `AtomicUsize`, `usize`), so these impls only matter for the
// wrapped `hnsw_rs` graph: they assert it may be moved to and shared between
// threads. Confirm that against `hnsw_rs`; if `Hnsw` is already `Send + Sync`
// on its own, both impls are redundant and should be removed.
unsafe impl Send for HnswIndex {}
unsafe impl Sync for HnswIndex {}
413
/// JSON sidecar written next to the index file by `HnswIndex::save` and read
/// back by `HnswIndex::load`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectionManifest {
    /// Collection id the files were saved under.
    pub collection_id: String,
    /// Index kind; `save` always writes `"hnsw"`.
    pub index_type: String,
    /// Number of vectors in the point map when the index was saved.
    pub vector_count: usize,
    /// Vector length of the saved index; `load` uses it to decode the index file.
    pub dimensions: usize,
    /// Save time in milliseconds since the Unix epoch (UTC).
    pub saved_at_unix_ms: i64,
    /// Hex-encoded BLAKE3 hash of the index file's bytes.
    pub index_checksum_blake3: String,
}