1use std::collections::HashMap;
21
/// Similarity measure used to rank search results.
///
/// [`VectorStore::search`] turns every metric into a "higher is better" score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Cosine of the angle between query and stored vector; 0 if either has zero norm.
    Cosine,
    /// Euclidean distance, scored as its negation so that closer vectors rank higher.
    L2,
    /// Raw inner product of query and stored vector.
    DotProduct,
}
32
/// A vector to insert into a [`VectorStore`], together with its id and metadata.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorEntry {
    /// Caller-chosen identifier returned in search results; the store does not enforce uniqueness.
    pub id: u64,
    /// Components; the length must equal the store's configured `dimensions`.
    pub vector: Vec<f32>,
    /// Arbitrary string key/value pairs, kept only when the store's `store_metadata` is set.
    pub metadata: HashMap<String, String>,
}
43
/// One ranked hit returned by [`VectorStore::search`].
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Id of the matching entry.
    pub id: u64,
    /// Similarity score; higher is better. Depending on the configured [`DistanceMetric`]
    /// this is the cosine similarity, the negated Euclidean distance, or the dot product.
    pub score: f32,
    /// The entry's metadata, or `None` when the store does not keep metadata.
    pub metadata: Option<HashMap<String, String>>,
}
54
55#[derive(Debug, Clone)]
57pub struct VectorStoreConfig {
58 pub dimensions: u32,
60 pub metric: DistanceMetric,
62 pub capacity: usize,
64 pub store_metadata: bool,
66}
67
68impl Default for VectorStoreConfig {
69 fn default() -> Self {
70 Self {
71 dimensions: 768,
72 metric: DistanceMetric::Cosine,
73 capacity: 1_000_000,
74 store_metadata: true,
75 }
76 }
77}
78
79pub struct VectorStore {
85 config: VectorStoreConfig,
87 vectors: Vec<f32>,
89 ids: Vec<u64>,
91 metadata: Vec<Option<HashMap<String, String>>>,
93 count: usize,
95}
96
97impl VectorStore {
98 pub fn new(config: VectorStoreConfig) -> Self {
100 let cap = config.capacity;
101 let dim = config.dimensions as usize;
102 Self {
103 config,
104 vectors: Vec::with_capacity(cap * dim),
105 ids: Vec::with_capacity(cap),
106 metadata: Vec::with_capacity(cap),
107 count: 0,
108 }
109 }
110
111 pub fn insert(&mut self, entry: VectorEntry) -> Result<(), VectorStoreError> {
113 if entry.vector.len() != self.config.dimensions as usize {
114 return Err(VectorStoreError::DimensionMismatch {
115 expected: self.config.dimensions as usize,
116 got: entry.vector.len(),
117 });
118 }
119
120 if self.count >= self.config.capacity {
121 return Err(VectorStoreError::CapacityExceeded {
122 capacity: self.config.capacity,
123 });
124 }
125
126 self.vectors.extend_from_slice(&entry.vector);
127 self.ids.push(entry.id);
128 self.metadata.push(if self.config.store_metadata {
129 Some(entry.metadata)
130 } else {
131 None
132 });
133 self.count += 1;
134
135 Ok(())
136 }
137
138 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>, VectorStoreError> {
140 if query.len() != self.config.dimensions as usize {
141 return Err(VectorStoreError::DimensionMismatch {
142 expected: self.config.dimensions as usize,
143 got: query.len(),
144 });
145 }
146
147 let dim = self.config.dimensions as usize;
148 let mut scores: Vec<(usize, f32)> = (0..self.count)
149 .map(|i| {
150 let vec_start = i * dim;
151 let vec_slice = &self.vectors[vec_start..vec_start + dim];
152 let score = match self.config.metric {
153 DistanceMetric::Cosine => cosine_similarity(query, vec_slice),
154 DistanceMetric::L2 => -l2_distance(query, vec_slice), DistanceMetric::DotProduct => dot_product(query, vec_slice),
156 };
157 (i, score)
158 })
159 .collect();
160
161 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
163 scores.truncate(k);
164
165 Ok(scores
166 .into_iter()
167 .map(|(idx, score)| SearchResult {
168 id: self.ids[idx],
169 score,
170 metadata: self.metadata.get(idx).and_then(|m| m.clone()),
171 })
172 .collect())
173 }
174
175 pub fn delete(&mut self, id: u64) -> bool {
177 if let Some(idx) = self.ids.iter().position(|&i| i == id) {
178 let dim = self.config.dimensions as usize;
179 let vec_start = idx * dim;
180
181 self.vectors.drain(vec_start..vec_start + dim);
183 self.ids.remove(idx);
184 self.metadata.remove(idx);
185 self.count -= 1;
186 true
187 } else {
188 false
189 }
190 }
191
192 pub fn len(&self) -> usize {
194 self.count
195 }
196
197 pub fn is_empty(&self) -> bool {
199 self.count == 0
200 }
201
202 pub fn flat_vectors(&self) -> &[f32] {
206 &self.vectors
207 }
208
209 pub fn ids(&self) -> &[u64] {
211 &self.ids
212 }
213
214 pub fn dimensions(&self) -> u32 {
216 self.config.dimensions
217 }
218}
219
/// Cosine of the angle between `a` and `b`: `(a · b) / (|a| |b|)`.
///
/// Returns `0.0` when either vector has zero Euclidean norm instead of dividing by zero.
/// Each norm is taken over its whole slice, while the dot product covers only the
/// overlapping prefix (`zip` stops at the shorter slice). All arithmetic is in `f32`, so
/// components large enough for their squares to overflow yield a NaN result.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|c| c * c).sum::<f32>().sqrt() };
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(p, q)| p * q).sum();
    dot / (mag_a * mag_b)
}
231
/// Euclidean distance `sqrt(Σ (aᵢ - bᵢ)²)` over the overlapping prefix of `a` and `b`
/// (`zip` stops at the shorter slice).
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff * diff
        })
        .sum();
    sum_sq.sqrt()
}
240
/// Inner product `Σ aᵢ·bᵢ` over the overlapping prefix of `a` and `b`
/// (`zip` stops at the shorter slice).
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| lhs * rhs)
        .sum::<f32>()
}
245
/// Errors returned by [`VectorStore`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VectorStoreError {
    /// A vector or query had the wrong number of components.
    DimensionMismatch {
        /// The store's configured dimensionality.
        expected: usize,
        /// The length that was supplied.
        got: usize,
    },
    /// An insert was attempted on a store that already holds `capacity` entries.
    CapacityExceeded {
        /// The configured maximum number of entries.
        capacity: usize,
    },
    /// No vector with the given id was found.
    NotFound(u64),
}

impl std::fmt::Display for VectorStoreError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DimensionMismatch { expected, got } => {
                write!(f, "Dimension mismatch: expected {}, got {}", expected, got)
            }
            Self::CapacityExceeded { capacity } => write!(f, "Capacity exceeded: max {}", capacity),
            Self::NotFound(id) => write!(f, "Vector {} not found", id),
        }
    }
}

impl std::error::Error for VectorStoreError {}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an entry without metadata.
    fn entry(id: u64, vector: Vec<f32>) -> VectorEntry {
        VectorEntry {
            id,
            vector,
            metadata: HashMap::new(),
        }
    }

    /// A store with room for 100 entries and metadata storage disabled.
    fn small_store(dimensions: u32, metric: DistanceMetric) -> VectorStore {
        VectorStore::new(VectorStoreConfig {
            dimensions,
            metric,
            capacity: 100,
            store_metadata: false,
        })
    }

    /// Ids of `results`, in rank order.
    fn ids_of(results: &[SearchResult]) -> Vec<u64> {
        results.iter().map(|r| r.id).collect()
    }

    #[test]
    fn test_default_config() {
        let config = VectorStoreConfig::default();
        assert_eq!(config.dimensions, 768);
        assert_eq!(config.metric, DistanceMetric::Cosine);
        assert_eq!(config.capacity, 1_000_000);
        assert!(config.store_metadata);
    }

    #[test]
    fn test_insert_and_search() {
        let mut store = small_store(4, DistanceMetric::Cosine);
        store.insert(entry(1, vec![1.0, 0.0, 0.0, 0.0])).unwrap();
        store.insert(entry(2, vec![0.9, 0.1, 0.0, 0.0])).unwrap();
        store.insert(entry(3, vec![0.0, 1.0, 0.0, 0.0])).unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(ids_of(&results), vec![1, 2]);
    }

    #[test]
    fn test_l2_distance_search() {
        let mut store = small_store(3, DistanceMetric::L2);
        store.insert(entry(1, vec![0.0, 0.0, 0.0])).unwrap();
        store.insert(entry(2, vec![1.0, 0.0, 0.0])).unwrap();
        store.insert(entry(3, vec![10.0, 10.0, 10.0])).unwrap();

        let results = store.search(&[0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(ids_of(&results), vec![1, 2]);
        // L2 scores are negated distances.
        assert_eq!(results[0].score, 0.0);
        assert_eq!(results[1].score, -1.0);
    }

    #[test]
    fn test_dot_product_search() {
        let mut store = small_store(2, DistanceMetric::DotProduct);
        store.insert(entry(1, vec![1.0, 0.0])).unwrap();
        store.insert(entry(2, vec![3.0, 0.0])).unwrap();
        store.insert(entry(3, vec![0.0, 5.0])).unwrap();

        let results = store.search(&[1.0, 0.0], 3).unwrap();
        assert_eq!(ids_of(&results), vec![2, 1, 3]);
        assert_eq!(results[0].score, 3.0);
    }

    #[test]
    fn test_dimension_mismatch() {
        let mut store = small_store(4, DistanceMetric::Cosine);
        let result = store.insert(entry(1, vec![1.0, 2.0]));
        assert!(matches!(
            result,
            Err(VectorStoreError::DimensionMismatch { expected: 4, got: 2 })
        ));
        assert!(store.is_empty());
    }

    #[test]
    fn test_search_dimension_mismatch() {
        let store = small_store(2, DistanceMetric::Cosine);
        let result = store.search(&[1.0], 1);
        assert!(matches!(
            result,
            Err(VectorStoreError::DimensionMismatch { expected: 2, got: 1 })
        ));
    }

    #[test]
    fn test_capacity_exceeded() {
        let mut store = VectorStore::new(VectorStoreConfig {
            dimensions: 2,
            capacity: 2,
            ..Default::default()
        });
        store.insert(entry(1, vec![1.0, 0.0])).unwrap();
        store.insert(entry(2, vec![0.0, 1.0])).unwrap();

        let result = store.insert(entry(3, vec![1.0, 1.0]));
        assert!(matches!(
            result,
            Err(VectorStoreError::CapacityExceeded { capacity: 2 })
        ));
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn test_delete() {
        let mut store = small_store(2, DistanceMetric::Cosine);
        store.insert(entry(1, vec![1.0, 0.0])).unwrap();
        store.insert(entry(2, vec![0.0, 1.0])).unwrap();
        assert_eq!(store.len(), 2);

        assert!(store.delete(1));
        assert_eq!(store.len(), 1);
        // The id is gone, so a second delete finds nothing.
        assert!(!store.delete(1));
    }

    #[test]
    fn test_delete_keeps_remaining_entries_aligned() {
        let mut store = small_store(2, DistanceMetric::DotProduct);
        store.insert(entry(1, vec![1.0, 0.0])).unwrap();
        store.insert(entry(2, vec![0.0, 1.0])).unwrap();
        store.insert(entry(3, vec![2.0, 2.0])).unwrap();

        assert!(store.delete(2));
        assert_eq!(store.ids(), &[1, 3]);
        assert_eq!(store.flat_vectors(), &[1.0, 0.0, 2.0, 2.0]);

        let results = store.search(&[1.0, 1.0], 2).unwrap();
        assert_eq!(ids_of(&results), vec![3, 1]);
    }

    #[test]
    fn test_metadata_storage() {
        let mut store = VectorStore::new(VectorStoreConfig {
            dimensions: 2,
            capacity: 10,
            store_metadata: true,
            ..Default::default()
        });

        let mut meta = HashMap::new();
        meta.insert("node_type".to_string(), "isa_standard".to_string());
        store
            .insert(VectorEntry {
                id: 42,
                vector: vec![1.0, 0.0],
                metadata: meta,
            })
            .unwrap();

        let results = store.search(&[1.0, 0.0], 1).unwrap();
        assert_eq!(results[0].id, 42);
        let meta = results[0].metadata.as_ref().unwrap();
        assert_eq!(meta.get("node_type").unwrap(), "isa_standard");
    }

    #[test]
    fn test_metadata_disabled() {
        let mut store = small_store(2, DistanceMetric::Cosine);

        let mut meta = HashMap::new();
        meta.insert("node_type".to_string(), "isa_standard".to_string());
        store
            .insert(VectorEntry {
                id: 7,
                vector: vec![1.0, 0.0],
                metadata: meta,
            })
            .unwrap();

        let results = store.search(&[1.0, 0.0], 1).unwrap();
        assert_eq!(results[0].id, 7);
        assert!(results[0].metadata.is_none());
    }

    #[test]
    fn test_search_empty_store() {
        let store = small_store(2, DistanceMetric::Cosine);
        assert!(store.is_empty());
        assert!(store.search(&[1.0, 0.0], 5).unwrap().is_empty());
    }

    #[test]
    fn test_search_k_larger_than_len() {
        let mut store = small_store(2, DistanceMetric::Cosine);
        store.insert(entry(1, vec![1.0, 0.0])).unwrap();
        store.insert(entry(2, vec![0.0, 1.0])).unwrap();

        assert_eq!(store.search(&[1.0, 0.0], 10).unwrap().len(), 2);
        assert!(store.search(&[1.0, 0.0], 0).unwrap().is_empty());
    }

    #[test]
    fn test_search_ties_keep_insertion_order() {
        let mut store = small_store(2, DistanceMetric::Cosine);
        store.insert(entry(7, vec![1.0, 0.0])).unwrap();
        store.insert(entry(3, vec![1.0, 0.0])).unwrap();

        let results = store.search(&[1.0, 0.0], 2).unwrap();
        assert_eq!(ids_of(&results), vec![7, 3]);
    }

    #[test]
    fn test_flat_vectors_for_gpu() {
        let mut store = small_store(3, DistanceMetric::Cosine);
        store.insert(entry(1, vec![1.0, 2.0, 3.0])).unwrap();
        store.insert(entry(2, vec![4.0, 5.0, 6.0])).unwrap();

        let flat = store.flat_vectors();
        assert_eq!(flat.len(), 6);
        assert_eq!(flat, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(store.ids(), &[1, 2]);
        assert_eq!(store.dimensions(), 3);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = [1.0, 2.0, 3.0];
        assert!((cosine_similarity(&a, &a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_l2_distance_basic() {
        assert!((l2_distance(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_basic() {
        assert!((dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]) - 32.0).abs() < 1e-6);
    }
}