Skip to main content

claw_vector/index/
hnsw.rs

1// index/hnsw.rs — HNSW index wrapper around the hnsw_rs crate.
2//
3// Uses a "shadow map" of (id → vector) for persistence and migration, avoiding
4// hnsw_rs's own file-format which carries awkward lifetime constraints.
5use std::{
6    collections::{HashMap, HashSet},
7    path::{Path, PathBuf},
8    sync::{
9        atomic::{AtomicUsize, Ordering},
10        RwLock,
11    },
12};
13
14use hnsw_rs::prelude::*;
15use serde::{Deserialize, Serialize};
16use tracing::instrument;
17
18use crate::{
19    config::VectorConfig,
20    error::{VectorError, VectorResult},
21    types::DistanceMetric,
22};
23
24// ─── HnswStats ───────────────────────────────────────────────────────────────
25
26/// Runtime statistics snapshot for a [`HnswIndex`].
27#[derive(Debug, Clone)]
28pub struct HnswStats {
29    /// Number of live (non-deleted) elements.
30    pub element_count: usize,
31    /// Maximum capacity the index was configured for.
32    pub max_elements: usize,
33    /// HNSW `ef_construction` build parameter.
34    pub ef_construction: usize,
35    /// HNSW `M` connections parameter.
36    pub m_connections: usize,
37    /// Number of layers observed in the current graph.
38    pub layers: usize,
39}
40
41// ─── HnswInner (type-erased distance variant) ─────────────────────────────────
42
43enum HnswInner {
44    L2(Hnsw<'static, f32, DistL2>),
45    Cosine(Hnsw<'static, f32, DistCosine>),
46    Dot(Hnsw<'static, f32, DistDot>),
47}
48
49impl HnswInner {
50    fn insert(&self, id: usize, vector: &[f32]) {
51        match self {
52            HnswInner::L2(h) => h.insert((vector, id)),
53            HnswInner::Cosine(h) => h.insert((vector, id)),
54            HnswInner::Dot(h) => h.insert((vector, id)),
55        }
56    }
57
58    fn parallel_insert(&self, refs: &[(&Vec<f32>, usize)]) {
59        match self {
60            HnswInner::L2(h) => h.parallel_insert(refs),
61            HnswInner::Cosine(h) => h.parallel_insert(refs),
62            HnswInner::Dot(h) => h.parallel_insert(refs),
63        }
64    }
65
66    fn search(&self, query: &[f32], top_k: usize, ef_search: usize) -> Vec<Neighbour> {
67        match self {
68            HnswInner::L2(h) => h.search(query, top_k, ef_search),
69            HnswInner::Cosine(h) => h.search(query, top_k, ef_search),
70            HnswInner::Dot(h) => h.search(query, top_k, ef_search),
71        }
72    }
73
74    fn ef_construction(&self) -> usize {
75        match self {
76            HnswInner::L2(h) => h.get_ef_construction(),
77            HnswInner::Cosine(h) => h.get_ef_construction(),
78            HnswInner::Dot(h) => h.get_ef_construction(),
79        }
80    }
81
82    fn max_nb_connection(&self) -> usize {
83        match self {
84            HnswInner::L2(h) => h.get_max_nb_connection() as usize,
85            HnswInner::Cosine(h) => h.get_max_nb_connection() as usize,
86            HnswInner::Dot(h) => h.get_max_nb_connection() as usize,
87        }
88    }
89
90    fn max_level_observed(&self) -> usize {
91        match self {
92            HnswInner::L2(h) => h.get_max_level_observed() as usize,
93            HnswInner::Cosine(h) => h.get_max_level_observed() as usize,
94            HnswInner::Dot(h) => h.get_max_level_observed() as usize,
95        }
96    }
97}
98
99// ─── HnswIndex ───────────────────────────────────────────────────────────────
100
101/// Thread-safe HNSW index with support for three distance metrics.
102pub struct HnswIndex {
103    inner: HnswInner,
104    /// Shadow copy of (id → vector) used for serialisation and migration.
105    points: RwLock<HashMap<usize, Vec<f32>>>,
106    /// Expected vector dimensionality.
107    dimensions: usize,
108    /// Count of live (non-deleted) elements.
109    element_count: AtomicUsize,
110    /// Maximum capacity the index was configured for.
111    max_elements: usize,
112    /// Logically deleted ids (tombstones).
113    deleted: RwLock<HashSet<usize>>,
114}
115
116impl HnswIndex {
117    /// Build a new empty HNSW index from the given config and distance metric.
118    #[instrument(skip(config))]
119    pub fn new(config: &VectorConfig, distance: DistanceMetric) -> VectorResult<Self> {
120        Self::new_with_dimensions(config, distance, config.default_dimensions)
121    }
122
123    /// Build a new empty HNSW index for an explicit `dimensions` count.
124    pub fn new_with_dimensions(
125        config: &VectorConfig,
126        distance: DistanceMetric,
127        dimensions: usize,
128    ) -> VectorResult<Self> {
129        let inner = build_inner(
130            config.m_connections,
131            config.max_elements,
132            16,
133            config.ef_construction,
134            distance,
135        );
136        Ok(HnswIndex {
137            inner,
138            points: RwLock::new(HashMap::new()),
139            dimensions,
140            element_count: AtomicUsize::new(0),
141            max_elements: config.max_elements,
142            deleted: RwLock::new(HashSet::new()),
143        })
144    }
145
146    /// Insert a single vector, validating its dimensionality.
147    #[instrument(skip(self, vector))]
148    pub fn insert(&self, id: usize, vector: &[f32]) -> VectorResult<()> {
149        if vector.len() != self.dimensions {
150            return Err(VectorError::DimensionMismatch {
151                expected: self.dimensions,
152                got: vector.len(),
153            });
154        }
155        self.inner.insert(id, vector);
156        self.points
157            .write()
158            .map_err(|e| VectorError::Index(e.to_string()))?
159            .insert(id, vector.to_vec());
160        self.element_count.fetch_add(1, Ordering::Relaxed);
161        Ok(())
162    }
163
164    /// Insert a batch of vectors in parallel via hnsw_rs `parallel_insert`.
165    #[instrument(skip(self, items))]
166    pub fn insert_batch(&self, items: &[(usize, Vec<f32>)]) -> VectorResult<()> {
167        for (_, v) in items {
168            if v.len() != self.dimensions {
169                return Err(VectorError::DimensionMismatch {
170                    expected: self.dimensions,
171                    got: v.len(),
172                });
173            }
174        }
175        let refs: Vec<(&Vec<f32>, usize)> = items.iter().map(|(id, v)| (v, *id)).collect();
176        self.inner.parallel_insert(&refs);
177        let mut pts = self
178            .points
179            .write()
180            .map_err(|e| VectorError::Index(e.to_string()))?;
181        for (id, v) in items {
182            pts.insert(*id, v.clone());
183        }
184        self.element_count.fetch_add(items.len(), Ordering::Relaxed);
185        Ok(())
186    }
187
188    /// Search for the `top_k` nearest neighbours of `query`.
189    ///
190    /// Returns `(internal_id, distance)` pairs sorted by ascending distance.
191    #[instrument(skip(self, query))]
192    pub fn search(
193        &self,
194        query: &[f32],
195        top_k: usize,
196        ef_search: usize,
197    ) -> VectorResult<Vec<(usize, f32)>> {
198        if query.len() != self.dimensions {
199            return Err(VectorError::DimensionMismatch {
200                expected: self.dimensions,
201                got: query.len(),
202            });
203        }
204        let deleted = self
205            .deleted
206            .read()
207            .map_err(|e| VectorError::Index(e.to_string()))?;
208        let neighbours = self.inner.search(query, top_k + deleted.len(), ef_search);
209        let mut results: Vec<(usize, f32)> = neighbours
210            .into_iter()
211            .filter(|n| !deleted.contains(&n.d_id))
212            .map(|n| (n.d_id, n.distance))
213            .collect();
214        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
215        results.truncate(top_k);
216        Ok(results)
217    }
218
219    /// Mark a vector as deleted (tombstone — hnsw_rs does not support physical removal).
220    #[instrument(skip(self))]
221    pub fn delete(&self, id: usize) -> VectorResult<()> {
222        let mut deleted = self
223            .deleted
224            .write()
225            .map_err(|e| VectorError::Index(e.to_string()))?;
226        if deleted.insert(id) {
227            self.points
228                .write()
229                .map_err(|e| VectorError::Index(e.to_string()))?
230                .remove(&id);
231            self.element_count.fetch_sub(1, Ordering::Relaxed);
232        }
233        Ok(())
234    }
235
236    /// Return the number of live (non-deleted) elements.
237    pub fn len(&self) -> usize {
238        self.element_count.load(Ordering::Relaxed)
239    }
240
241    /// Return `true` if the index contains no live elements.
242    pub fn is_empty(&self) -> bool {
243        self.len() == 0
244    }
245
246    /// Persist the index under `path` using an atomic tmp+rename strategy.
247    #[instrument(skip(self))]
248    pub fn save(&self, path: &Path, collection_id: &str) -> VectorResult<()> {
249        std::fs::create_dir_all(path)?;
250        let pts = self
251            .points
252            .read()
253            .map_err(|e| VectorError::Index(e.to_string()))?;
254
255        // Binary format: [n: u64][(id: u64)(v0..vN: f32) ...]
256        let mut buf = Vec::with_capacity(8 + pts.len() * (8 + self.dimensions * 4));
257        buf.extend_from_slice(&(pts.len() as u64).to_le_bytes());
258        for (&id, vec) in pts.iter() {
259            buf.extend_from_slice(&(id as u64).to_le_bytes());
260            for &v in vec {
261                buf.extend_from_slice(&v.to_le_bytes());
262            }
263        }
264
265        let final_path = index_file(path, collection_id);
266        let tmp_path = tmp_index_file(path, collection_id);
267        std::fs::write(&tmp_path, &buf)?;
268        std::fs::rename(&tmp_path, &final_path)?;
269
270        let checksum = blake3::hash(&buf).to_hex().to_string();
271        let manifest = CollectionManifest {
272            collection_id: collection_id.to_string(),
273            index_type: "hnsw".to_string(),
274            vector_count: pts.len(),
275            dimensions: self.dimensions,
276            saved_at_unix_ms: chrono::Utc::now().timestamp_millis(),
277            index_checksum_blake3: checksum,
278        };
279        std::fs::write(
280            manifest_file(path, collection_id),
281            serde_json::to_string_pretty(&manifest)?,
282        )?;
283        Ok(())
284    }
285
286    /// Reload a previously saved index by re-inserting all persisted points.
287    #[instrument(skip(config))]
288    pub fn load(
289        path: &Path,
290        collection_id: &str,
291        config: &VectorConfig,
292        distance: DistanceMetric,
293    ) -> VectorResult<Self> {
294        let final_path = index_file(path, collection_id);
295        let tmp_path = tmp_index_file(path, collection_id);
296        if tmp_path.exists() {
297            if final_path.exists() {
298                let _ = std::fs::remove_file(&tmp_path);
299            } else {
300                std::fs::rename(&tmp_path, &final_path)?;
301            }
302        }
303
304        let manifest_path = manifest_file(path, collection_id);
305        let manifest: CollectionManifest =
306            serde_json::from_reader(std::fs::File::open(&manifest_path)?)?;
307        let dimensions = manifest.dimensions;
308        let max_elements = config.max_elements;
309
310        let raw = std::fs::read(&final_path)?;
311        let checksum = blake3::hash(&raw).to_hex().to_string();
312        if checksum != manifest.index_checksum_blake3 {
313            tracing::warn!(
314                collection_id = %collection_id,
315                expected = %manifest.index_checksum_blake3,
316                got = %checksum,
317                "HNSW index checksum mismatch; continuing with best-effort load"
318            );
319        }
320        let points = decode_points_bin(&raw, dimensions)?;
321
322        let mut cfg = config.clone();
323        cfg.default_dimensions = dimensions;
324        cfg.max_elements = max_elements;
325        let index = Self::new_with_dimensions(&cfg, distance, dimensions)?;
326        index.insert_batch(&points)?;
327        Ok(index)
328    }
329
330    /// Return a statistics snapshot.
331    pub fn stats(&self) -> HnswStats {
332        HnswStats {
333            element_count: self.element_count.load(Ordering::Relaxed),
334            max_elements: self.max_elements,
335            ef_construction: self.inner.ef_construction(),
336            m_connections: self.inner.max_nb_connection(),
337            layers: self.inner.max_level_observed(),
338        }
339    }
340
341    /// Return a clone of all live points for persistence.
342    pub fn snapshot_points(&self) -> VectorResult<Vec<(usize, Vec<f32>)>> {
343        let points = self
344            .points
345            .read()
346            .map_err(|e| VectorError::Index(e.to_string()))?
347            .iter()
348            .map(|(id, vector)| (*id, vector.clone()))
349            .collect();
350        Ok(points)
351    }
352}
353
354fn index_file(path: &Path, collection_id: &str) -> PathBuf {
355    path.join(format!("{collection_id}.hnsw"))
356}
357
358fn tmp_index_file(path: &Path, collection_id: &str) -> PathBuf {
359    path.join(format!("{collection_id}.hnsw.tmp"))
360}
361
362fn manifest_file(path: &Path, collection_id: &str) -> PathBuf {
363    path.join(format!("{collection_id}.manifest.json"))
364}
365
366fn build_inner(
367    m: usize,
368    max_elem: usize,
369    max_layer: usize,
370    ef_c: usize,
371    distance: DistanceMetric,
372) -> HnswInner {
373    match distance {
374        DistanceMetric::Euclidean => {
375            HnswInner::L2(Hnsw::new(m, max_elem, max_layer, ef_c, DistL2 {}))
376        }
377        DistanceMetric::Cosine => {
378            HnswInner::Cosine(Hnsw::new(m, max_elem, max_layer, ef_c, DistCosine {}))
379        }
380        DistanceMetric::DotProduct => {
381            HnswInner::Dot(Hnsw::new(m, max_elem, max_layer, ef_c, DistDot {}))
382        }
383    }
384}
385
386fn decode_points_bin(raw: &[u8], dimensions: usize) -> VectorResult<Vec<(usize, Vec<f32>)>> {
387    if raw.len() < 8 {
388        return Ok(Vec::new());
389    }
390    let n = u64::from_le_bytes(raw[..8].try_into().unwrap()) as usize;
391    let bpr = 8 + dimensions * 4;
392    if raw.len() < 8 + n * bpr {
393        return Err(VectorError::Index("hnsw.points.bin is truncated".into()));
394    }
395    let mut points = Vec::with_capacity(n);
396    let mut off = 8usize;
397    for _ in 0..n {
398        let id = u64::from_le_bytes(raw[off..off + 8].try_into().unwrap()) as usize;
399        off += 8;
400        let floats: Vec<f32> = raw[off..off + dimensions * 4]
401            .chunks_exact(4)
402            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
403            .collect();
404        off += dimensions * 4;
405        points.push((id, floats));
406    }
407    Ok(points)
408}
409
410// SAFETY: Hnsw<'static, T, D> owns all its data; interior mutation is guarded by parking_lot.
411unsafe impl Send for HnswIndex {}
412unsafe impl Sync for HnswIndex {}
413
414/// Persisted metadata for a saved HNSW index file.
415#[derive(Debug, Clone, Serialize, Deserialize)]
416pub struct CollectionManifest {
417    /// Collection identifier associated with this index artifact.
418    pub collection_id: String,
419    /// Index implementation name (currently always `hnsw`).
420    pub index_type: String,
421    /// Number of vectors present when the index was persisted.
422    pub vector_count: usize,
423    /// Vector dimensionality encoded in the index.
424    pub dimensions: usize,
425    /// Unix timestamp in milliseconds when the index was saved.
426    pub saved_at_unix_ms: i64,
427    /// Blake3 checksum of the `.hnsw` file contents.
428    pub index_checksum_blake3: String,
429}