1use crate::graph::NodeId;
48use hnsw_rs::prelude::*;
49use thiserror::Error;
50
51#[derive(Error, Debug)]
53pub enum VectorError {
54 #[error("Index error: {0}")]
55 IndexError(String),
56
57 #[error("IO error: {0}")]
58 Io(#[from] std::io::Error),
59
60 #[error("Dimension mismatch: expected {expected}, got {got}")]
61 DimensionMismatch { expected: usize, got: usize },
62}
63
64pub type VectorResult<T> = Result<T, VectorError>;
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
68pub enum DistanceMetric {
69 L2,
71 Cosine,
73 InnerProduct,
75}
76
77#[derive(Clone, Debug)]
79pub struct VectorPoint {
80 pub node_id: NodeId,
81 pub vector: Vec<f32>,
82}
83
/// Zero-sized HNSW distance functor computing cosine distance, `1 - cos(θ)`.
#[derive(Clone, Copy, Debug, Default)]
pub struct CosineDistance;
87
88impl Distance<f32> for CosineDistance {
89 fn eval(&self, va: &[f32], vb: &[f32]) -> f32 {
90 let mut dot = 0.0;
91 let mut norm_a = 0.0;
92 let mut norm_b = 0.0;
93
94 for (a, b) in va.iter().zip(vb.iter()) {
95 dot += a * b;
96 norm_a += a * a;
97 norm_b += b * b;
98 }
99
100 if norm_a <= 0.0 || norm_b <= 0.0 {
101 return 1.0;
102 }
103
104 let sim = dot / (norm_a.sqrt() * norm_b.sqrt());
106 1.0 - sim
107 }
108}
109
/// Zero-sized HNSW distance functor computing inner-product distance, `1 - a·b`.
#[derive(Clone, Copy, Debug, Default)]
pub struct InnerProductDistance;
113
114impl Distance<f32> for InnerProductDistance {
115 fn eval(&self, va: &[f32], vb: &[f32]) -> f32 {
116 let mut dot = 0.0;
117 for (a, b) in va.iter().zip(vb.iter()) {
118 dot += a * b;
119 }
120 1.0 - dot
122 }
123}
124
125#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
127pub struct StoredVector {
128 pub node_id: u64,
129 pub vector: Vec<f32>,
130}
131
132pub struct VectorIndex {
134 dimensions: usize,
136 metric: DistanceMetric,
138 hnsw: Hnsw<'static, f32, CosineDistance>,
140 stored_vectors: Vec<StoredVector>,
142}
143
144impl std::fmt::Debug for VectorIndex {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 f.debug_struct("VectorIndex")
148 .field("dimensions", &self.dimensions)
149 .field("metric", &self.metric)
150 .finish()
151 }
152}
153
154impl VectorIndex {
155 pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
157 let max_elements = 100_000;
159 let m = 16;
160 let ef_construction = 200;
161
162 let hnsw = Hnsw::new(m, max_elements, 16, ef_construction, CosineDistance);
163
164 Self {
165 dimensions,
166 metric,
167 hnsw,
168 stored_vectors: Vec::new(),
169 }
170 }
171
172 pub fn add(&mut self, node_id: NodeId, vector: &Vec<f32>) -> VectorResult<()> {
174 if vector.len() != self.dimensions {
175 return Err(VectorError::DimensionMismatch {
176 expected: self.dimensions,
177 got: vector.len(),
178 });
179 }
180
181 self.hnsw.insert((vector, node_id.0 as usize));
182
183 self.stored_vectors.push(StoredVector {
185 node_id: node_id.0,
186 vector: vector.clone(),
187 });
188
189 Ok(())
190 }
191
192 pub fn search(&self, query: &[f32], k: usize) -> VectorResult<Vec<(NodeId, f32)>> {
194 if query.len() != self.dimensions {
195 return Err(VectorError::DimensionMismatch {
196 expected: self.dimensions,
197 got: query.len(),
198 });
199 }
200
201 let ef_search = k * 2;
202 let results = self.hnsw.search(query, k, ef_search);
203
204 let mut neighbors = Vec::new();
205 for res in results {
206 neighbors.push((NodeId::new(res.d_id as u64), res.distance));
207 }
208
209 Ok(neighbors)
210 }
211
212 pub fn dimensions(&self) -> usize {
214 self.dimensions
215 }
216
217 pub fn metric(&self) -> DistanceMetric {
219 self.metric
220 }
221
222 pub fn len(&self) -> usize {
224 self.stored_vectors.len()
225 }
226
227 pub fn is_empty(&self) -> bool {
229 self.stored_vectors.is_empty()
230 }
231
232 pub fn dump(&self, path: &std::path::Path) -> VectorResult<()> {
235 let file = std::fs::File::create(path)?;
236 let writer = std::io::BufWriter::new(file);
237 bincode::serialize_into(writer, &self.stored_vectors)
238 .map_err(|e| VectorError::IndexError(format!("serialization error: {}", e)))?;
239 Ok(())
240 }
241
242 pub fn load(
244 path: &std::path::Path,
245 dimensions: usize,
246 metric: DistanceMetric,
247 ) -> VectorResult<Self> {
248 if !path.exists() {
249 return Ok(Self::new(dimensions, metric));
250 }
251 let file = std::fs::File::open(path)?;
252 let reader = std::io::BufReader::new(file);
253 let stored_vectors: Vec<StoredVector> = bincode::deserialize_from(reader)
254 .map_err(|e| VectorError::IndexError(format!("deserialization error: {}", e)))?;
255
256 let max_elements = (stored_vectors.len() + 10_000).max(100_000);
257 let m = 16;
258 let ef_construction = 200;
259 let hnsw = Hnsw::new(m, max_elements, 16, ef_construction, CosineDistance);
260
261 for sv in &stored_vectors {
263 hnsw.insert((&sv.vector, sv.node_id as usize));
264 }
265
266 Ok(Self {
267 dimensions,
268 metric,
269 hnsw,
270 stored_vectors,
271 })
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-dimensional cosine index holding three vectors; node 1 lies on the x axis.
    fn sample_index() -> VectorIndex {
        let mut index = VectorIndex::new(3, DistanceMetric::Cosine);
        index.add(NodeId::new(1), &vec![1.0, 0.0, 0.0]).unwrap();
        index.add(NodeId::new(2), &vec![0.0, 1.0, 0.0]).unwrap();
        index.add(NodeId::new(3), &vec![0.0, 0.1, 0.9]).unwrap();
        index
    }

    #[test]
    fn test_vector_index_basic() {
        let index = sample_index();
        assert_eq!(index.len(), 3);
        assert!(!index.is_empty());

        let results = index.search(&[1.0, 0.1, 0.0], 2).unwrap();
        assert!(!results.is_empty() && results.len() <= 2);
        assert_eq!(results[0].0, NodeId::new(1));
    }

    #[test]
    fn test_vector_index_persistence() {
        let dir = tempfile::TempDir::new().unwrap();
        let dump_path = dir.path().join("test_vectors.bin");

        let index = sample_index();
        assert_eq!(index.len(), 3);
        index.dump(&dump_path).unwrap();

        let loaded = VectorIndex::load(&dump_path, 3, DistanceMetric::Cosine).unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded.dimensions(), 3);
        assert_eq!(loaded.metric(), DistanceMetric::Cosine);

        let results = loaded.search(&[1.0, 0.1, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, NodeId::new(1));
    }

    #[test]
    fn test_load_missing_file_returns_empty_index() {
        let dir = tempfile::TempDir::new().unwrap();
        let missing = dir.path().join("does_not_exist.bin");

        let index = VectorIndex::load(&missing, 4, DistanceMetric::L2).unwrap();
        assert!(index.is_empty());
        assert_eq!(index.dimensions(), 4);
        assert_eq!(index.metric(), DistanceMetric::L2);
    }

    #[test]
    fn test_dimension_mismatch_is_rejected() {
        let mut index = sample_index();

        let err = index.add(NodeId::new(4), &vec![1.0, 2.0]).unwrap_err();
        assert!(matches!(
            err,
            VectorError::DimensionMismatch { expected: 3, got: 2 }
        ));
        // A rejected add must not change the index.
        assert_eq!(index.len(), 3);

        let err = index.search(&[1.0, 2.0, 3.0, 4.0], 1).unwrap_err();
        assert!(matches!(
            err,
            VectorError::DimensionMismatch { expected: 3, got: 4 }
        ));
    }

    #[test]
    fn test_distance_metrics() {
        let v1 = vec![1.0, 0.0];
        let v2 = vec![0.0, 1.0];
        let zero = vec![0.0, 0.0];

        let cosine = CosineDistance;
        assert!((cosine.eval(&v1, &v2) - 1.0).abs() < 1e-6);
        assert!(cosine.eval(&v1, &v1).abs() < 1e-6);
        // A zero-norm input has no defined angle and is scored like an orthogonal vector.
        assert_eq!(cosine.eval(&v1, &zero), 1.0);

        let inner = InnerProductDistance;
        assert!((inner.eval(&v1, &v2) - 1.0).abs() < 1e-6);
        assert!(inner.eval(&v1, &v1).abs() < 1e-6);
    }
}