// graphmind/vector/index.rs
//! # HNSW Vector Index Implementation
//!
//! ## How HNSW works
//!
//! HNSW (Hierarchical Navigable Small World) builds a proximity graph with multiple
//! layers. Each node is assigned a random maximum layer (exponentially distributed —
//! most nodes live only on layer 0, few reach the top). Insertion connects the new
//! point to its nearest neighbors on each layer. Search starts at the top layer's
//! entry point and greedily descends, refining the candidate set at each level.
//!
//! ## Key parameters
//!
//! - **`m`** (max connections per node): Controls graph density. Higher m = better recall
//!   but more memory and slower insertion. Typical values: 12-48. Layer 0 uses `2*m`
//!   connections.
//! - **`ef_construction`** (search width during insertion): How many candidates to
//!   consider when connecting a new node. Higher = better graph quality but slower build.
//!   Typical values: 100-400.
//! - **`ef_search`** (search width during query): How many candidates to track during
//!   search. Higher = better recall but slower queries. Must be >= k (number of results).
//!   This is the main recall-vs-speed knob at query time.
//!
//! ## Distance trait
//!
//! Rust's trait system enables polymorphic distance computation. The `hnsw_rs` crate
//! defines a `Distance<T>` trait, and this module implements it with `CosineDistance`
//! and `InnerProductDistance` structs. This allows the same HNSW data structure to work
//! with different distance metrics without runtime dispatch overhead (monomorphization).
//!
//! ## Cosine distance formula
//!
//! `cosine_distance(a, b) = 1 - (a . b) / (||a|| * ||b||)`
//!
//! This measures angular distance between vectors:
//! - **0** = identical direction (parallel vectors)
//! - **1** = orthogonal (perpendicular, no similarity)
//! - **2** = opposite direction (anti-correlated)
//!
//! ## Persistence strategy
//!
//! HNSW indices (from `hnsw_rs`) don't expose an iterator over stored vectors.
//! To support persistence, all inserted vectors are also stored in a `Vec<StoredVector>`
//! alongside the HNSW structure. On serialization, this vector list is saved via
//! `bincode`. On load, a fresh HNSW index is constructed and all stored vectors are
//! re-inserted. This trades load-time speed for implementation simplicity.
use crate::graph::NodeId;
use hnsw_rs::prelude::*;
use std::io::Write;
use thiserror::Error;
50
51/// Vector index errors
52#[derive(Error, Debug)]
53pub enum VectorError {
54    #[error("Index error: {0}")]
55    IndexError(String),
56
57    #[error("IO error: {0}")]
58    Io(#[from] std::io::Error),
59
60    #[error("Dimension mismatch: expected {expected}, got {got}")]
61    DimensionMismatch { expected: usize, got: usize },
62}
63
64pub type VectorResult<T> = Result<T, VectorError>;
65
66/// Distance metric for vector search
67#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
68pub enum DistanceMetric {
69    /// L2 (Euclidean) distance
70    L2,
71    /// Cosine similarity
72    Cosine,
73    /// Inner product
74    InnerProduct,
75}
76
77/// A point in the vector space, associated with a NodeId
78#[derive(Clone, Debug)]
79pub struct VectorPoint {
80    pub node_id: NodeId,
81    pub vector: Vec<f32>,
82}
83
84/// Cosine distance implementation for hnsw_rs
85#[derive(Clone, Copy, Debug, Default)]
86pub struct CosineDistance;
87
88impl Distance<f32> for CosineDistance {
89    fn eval(&self, va: &[f32], vb: &[f32]) -> f32 {
90        let mut dot = 0.0;
91        let mut norm_a = 0.0;
92        let mut norm_b = 0.0;
93
94        for (a, b) in va.iter().zip(vb.iter()) {
95            dot += a * b;
96            norm_a += a * a;
97            norm_b += b * b;
98        }
99
100        if norm_a <= 0.0 || norm_b <= 0.0 {
101            return 1.0;
102        }
103
104        // Cosine distance = 1.0 - cosine similarity
105        let sim = dot / (norm_a.sqrt() * norm_b.sqrt());
106        1.0 - sim
107    }
108}
109
110/// Inner Product distance implementation for hnsw_rs
111#[derive(Clone, Copy, Debug, Default)]
112pub struct InnerProductDistance;
113
114impl Distance<f32> for InnerProductDistance {
115    fn eval(&self, va: &[f32], vb: &[f32]) -> f32 {
116        let mut dot = 0.0;
117        for (a, b) in va.iter().zip(vb.iter()) {
118            dot += a * b;
119        }
120        // Inner product distance = 1.0 - dot product (for normalized vectors)
121        1.0 - dot
122    }
123}
124
125/// Stored vector entry for persistence
126#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
127pub struct StoredVector {
128    pub node_id: u64,
129    pub vector: Vec<f32>,
130}
131
132/// Wrapper around HNSW index
133pub struct VectorIndex {
134    /// Number of dimensions
135    dimensions: usize,
136    /// Distance metric
137    metric: DistanceMetric,
138    /// The actual HNSW index
139    hnsw: Hnsw<'static, f32, CosineDistance>,
140    /// All inserted vectors (for persistence — HNSW doesn't expose iteration)
141    stored_vectors: Vec<StoredVector>,
142}
143
144// Implement Debug manually because Hnsw doesn't implement it
145impl std::fmt::Debug for VectorIndex {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        f.debug_struct("VectorIndex")
148            .field("dimensions", &self.dimensions)
149            .field("metric", &self.metric)
150            .finish()
151    }
152}
153
154impl VectorIndex {
155    /// Create a new vector index
156    pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
157        // HNSW parameters
158        let max_elements = 100_000;
159        let m = 16;
160        let ef_construction = 200;
161
162        let hnsw = Hnsw::new(m, max_elements, 16, ef_construction, CosineDistance);
163
164        Self {
165            dimensions,
166            metric,
167            hnsw,
168            stored_vectors: Vec::new(),
169        }
170    }
171
172    /// Add a vector to the index
173    pub fn add(&mut self, node_id: NodeId, vector: &Vec<f32>) -> VectorResult<()> {
174        if vector.len() != self.dimensions {
175            return Err(VectorError::DimensionMismatch {
176                expected: self.dimensions,
177                got: vector.len(),
178            });
179        }
180
181        self.hnsw.insert((vector, node_id.0 as usize));
182
183        // Store vector for persistence
184        self.stored_vectors.push(StoredVector {
185            node_id: node_id.0,
186            vector: vector.clone(),
187        });
188
189        Ok(())
190    }
191
192    /// Search for nearest neighbors
193    pub fn search(&self, query: &[f32], k: usize) -> VectorResult<Vec<(NodeId, f32)>> {
194        if query.len() != self.dimensions {
195            return Err(VectorError::DimensionMismatch {
196                expected: self.dimensions,
197                got: query.len(),
198            });
199        }
200
201        let ef_search = k * 2;
202        let results = self.hnsw.search(query, k, ef_search);
203
204        let mut neighbors = Vec::new();
205        for res in results {
206            neighbors.push((NodeId::new(res.d_id as u64), res.distance));
207        }
208
209        Ok(neighbors)
210    }
211
212    /// Get dimensions
213    pub fn dimensions(&self) -> usize {
214        self.dimensions
215    }
216
217    /// Get metric
218    pub fn metric(&self) -> DistanceMetric {
219        self.metric
220    }
221
222    /// Get count of stored vectors
223    pub fn len(&self) -> usize {
224        self.stored_vectors.len()
225    }
226
227    /// Check if index is empty
228    pub fn is_empty(&self) -> bool {
229        self.stored_vectors.is_empty()
230    }
231
232    /// Save index to disk by serializing stored vectors via bincode.
233    /// On load, vectors are re-inserted into a fresh HNSW index.
234    pub fn dump(&self, path: &std::path::Path) -> VectorResult<()> {
235        let file = std::fs::File::create(path)?;
236        let writer = std::io::BufWriter::new(file);
237        bincode::serialize_into(writer, &self.stored_vectors)
238            .map_err(|e| VectorError::IndexError(format!("serialization error: {}", e)))?;
239        Ok(())
240    }
241
242    /// Load index from disk: deserialize stored vectors and re-insert into HNSW.
243    pub fn load(
244        path: &std::path::Path,
245        dimensions: usize,
246        metric: DistanceMetric,
247    ) -> VectorResult<Self> {
248        if !path.exists() {
249            return Ok(Self::new(dimensions, metric));
250        }
251        let file = std::fs::File::open(path)?;
252        let reader = std::io::BufReader::new(file);
253        let stored_vectors: Vec<StoredVector> = bincode::deserialize_from(reader)
254            .map_err(|e| VectorError::IndexError(format!("deserialization error: {}", e)))?;
255
256        let max_elements = (stored_vectors.len() + 10_000).max(100_000);
257        let m = 16;
258        let ef_construction = 200;
259        let hnsw = Hnsw::new(m, max_elements, 16, ef_construction, CosineDistance);
260
261        // Re-insert all vectors
262        for sv in &stored_vectors {
263            hnsw.insert((&sv.vector, sv.node_id as usize));
264        }
265
266        Ok(Self {
267            dimensions,
268            metric,
269            hnsw,
270            stored_vectors,
271        })
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_index_basic() {
        let mut index = VectorIndex::new(3, DistanceMetric::Cosine);

        // Add some vectors
        index.add(NodeId::new(1), &vec![1.0, 0.0, 0.0]).unwrap();
        index.add(NodeId::new(2), &vec![0.0, 1.0, 0.0]).unwrap();
        index.add(NodeId::new(3), &vec![0.0, 0.1, 0.9]).unwrap();

        // HNSW is approximate and may return fewer than k results on very
        // small graphs.
        let results = index.search(&[1.0, 0.1, 0.0], 2).unwrap();
        assert!((1..=2).contains(&results.len()));
        assert_eq!(results[0].0, NodeId::new(1));
    }

    #[test]
    fn test_vector_index_persistence() {
        let dir = tempfile::TempDir::new().unwrap();
        let dump_path = dir.path().join("test_vectors.bin");

        // Create and populate index
        let mut index = VectorIndex::new(3, DistanceMetric::Cosine);
        index.add(NodeId::new(1), &vec![1.0, 0.0, 0.0]).unwrap();
        index.add(NodeId::new(2), &vec![0.0, 1.0, 0.0]).unwrap();
        index.add(NodeId::new(3), &vec![0.0, 0.1, 0.9]).unwrap();
        assert_eq!(index.len(), 3);

        // Dump to disk
        index.dump(&dump_path).unwrap();

        // Load from disk
        let loaded = VectorIndex::load(&dump_path, 3, DistanceMetric::Cosine).unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded.dimensions(), 3);

        // Verify search still works after reload. Same caveat as above:
        // the approximate search may return fewer than k results on a
        // 3-point graph, so don't assert an exact count.
        let results = loaded.search(&[1.0, 0.1, 0.0], 2).unwrap();
        assert!((1..=2).contains(&results.len()));
        assert_eq!(results[0].0, NodeId::new(1));
    }

    #[test]
    fn test_distance_metrics() {
        let v1 = vec![1.0, 0.0];
        let v2 = vec![0.0, 1.0];

        let cosine = CosineDistance;
        // Orthogonal vectors -> distance 1.
        assert!((cosine.eval(&v1, &v2) - 1.0).abs() < 1e-6);
        // Identical vectors -> distance 0.
        assert!(cosine.eval(&v1, &v1).abs() < 1e-6);

        let inner = InnerProductDistance;
        // dot = 0, so inner-product distance = 1.0 - 0.0 = 1.0.
        assert!((inner.eval(&v1, &v2) - 1.0).abs() < 1e-6);
    }
}