1use std::collections::HashMap;
8
/// A single stored vector together with its identity and metadata.
///
/// Entries are looked up by `id` within their `namespace`. Two entries compare
/// equal when every field matches; note that a vector containing NaN never
/// compares equal, because `f32` equality is used.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorEntry {
    /// Identifier of the entry within `namespace`.
    pub id: String,
    /// Namespace the entry is stored under.
    pub namespace: String,
    /// The vector's components.
    pub vector: Vec<f32>,
    /// Free-form string key/value pairs attached to the entry.
    pub metadata: HashMap<String, String>,
    /// Caller-supplied creation stamp; it is stored as-is and never interpreted
    /// in this file (units are not established here).
    pub created_at: u64,
}
27
/// One hit produced by a similarity search.
///
/// Equality compares every field, including `score` with `f32` semantics, so a
/// result whose score is NaN is never equal to anything.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Id of the matching entry.
    pub id: String,
    /// Namespace of the matching entry.
    pub namespace: String,
    /// Similarity score of the matching entry against the query.
    pub score: f32,
    /// Copy of the matching entry's metadata.
    pub metadata: HashMap<String, String>,
}
40
/// Summary counters describing the contents of a vector store.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VectorStoreStats {
    /// Number of vectors stored across all namespaces.
    pub total_vectors: usize,
    /// Number of namespaces that hold at least one vector.
    pub namespace_count: usize,
    /// Vector dimension, or `None` when it is not known.
    pub dimension: Option<usize>,
}
52
/// Reasons a vector can be rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoreError {
    /// The vector has `got` components where `expected` were required.
    DimensionMismatch { expected: usize, got: usize },
    /// The vector has no components at all.
    EmptyVector,
}

/// Human-readable, single-line description of each error.
impl std::fmt::Display for StoreError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::EmptyVector => write!(out, "vector must not be empty"),
            Self::DimensionMismatch { expected, got } => {
                write!(out, "dimension mismatch: expected {expected}, got {got}")
            }
        }
    }
}

// No underlying cause to expose, so the trait's default methods suffice.
impl std::error::Error for StoreError {}
74
75pub struct VectorStore {
81 namespaces: HashMap<String, Vec<VectorEntry>>,
83 dimension: Option<usize>,
85}
86
/// Cosine similarity between `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Only the overlapping prefix is compared when the slices differ in length,
/// because `zip` stops at the shorter one. Returns `0.0` when the product of
/// the two magnitudes is zero. The result is NaN when a compared component is
/// NaN, and may be NaN when components are infinite.
///
/// The sums are accumulated in `f64`: squaring an `f32` can overflow to
/// infinity (yielding a NaN score) or underflow to zero (yielding a spurious
/// `0.0`), whereas the square of any finite `f32` is representable in `f64`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        // Narrowing is lossless in range: the value is already within [-1, 1].
        (dot / denom).clamp(-1.0, 1.0) as f32
    }
}
108
109impl Default for VectorStore {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl VectorStore {
116 pub fn new() -> Self {
118 Self {
119 namespaces: HashMap::new(),
120 dimension: None,
121 }
122 }
123
124 pub fn with_dimension(dim: usize) -> Self {
126 Self {
127 namespaces: HashMap::new(),
128 dimension: Some(dim),
129 }
130 }
131
132 fn validate_vector(&self, vector: &[f32]) -> Result<(), StoreError> {
135 if vector.is_empty() {
136 return Err(StoreError::EmptyVector);
137 }
138 if let Some(dim) = self.dimension {
139 if vector.len() != dim {
140 return Err(StoreError::DimensionMismatch {
141 expected: dim,
142 got: vector.len(),
143 });
144 }
145 }
146 Ok(())
147 }
148
149 pub fn upsert(&mut self, entry: VectorEntry) -> Result<bool, StoreError> {
159 self.validate_vector(&entry.vector)?;
160 let entries = self.namespaces.entry(entry.namespace.clone()).or_default();
161 if let Some(existing) = entries.iter_mut().find(|e| e.id == entry.id) {
162 *existing = entry;
163 return Ok(false);
164 }
165 entries.push(entry);
166 Ok(true)
167 }
168
169 pub fn delete(&mut self, namespace: &str, id: &str) -> bool {
173 match self.namespaces.get_mut(namespace) {
174 Some(entries) => {
175 let before = entries.len();
176 entries.retain(|e| e.id != id);
177 entries.len() < before
178 }
179 None => false,
180 }
181 }
182
183 pub fn get(&self, namespace: &str, id: &str) -> Option<&VectorEntry> {
185 self.namespaces
186 .get(namespace)
187 .and_then(|entries| entries.iter().find(|e| e.id == id))
188 }
189
190 pub fn contains(&self, namespace: &str, id: &str) -> bool {
192 self.get(namespace, id).is_some()
193 }
194
195 pub fn search(&self, namespace: &str, query: &[f32], top_k: usize) -> Vec<SearchResult> {
200 let entries = match self.namespaces.get(namespace) {
201 Some(e) => e,
202 None => return Vec::new(),
203 };
204 let mut scored: Vec<(f32, &VectorEntry)> = entries
205 .iter()
206 .filter(|e| e.vector.len() == query.len())
207 .map(|e| (cosine_similarity(&e.vector, query), e))
208 .collect();
209 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
210 scored
211 .into_iter()
212 .take(top_k)
213 .map(|(score, entry)| SearchResult {
214 id: entry.id.clone(),
215 namespace: entry.namespace.clone(),
216 score,
217 metadata: entry.metadata.clone(),
218 })
219 .collect()
220 }
221
222 pub fn search_all_namespaces(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
224 let mut scored: Vec<(f32, &VectorEntry)> = self
225 .namespaces
226 .values()
227 .flat_map(|entries| entries.iter())
228 .filter(|e| e.vector.len() == query.len())
229 .map(|e| (cosine_similarity(&e.vector, query), e))
230 .collect();
231 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
232 scored
233 .into_iter()
234 .take(top_k)
235 .map(|(score, entry)| SearchResult {
236 id: entry.id.clone(),
237 namespace: entry.namespace.clone(),
238 score,
239 metadata: entry.metadata.clone(),
240 })
241 .collect()
242 }
243
244 pub fn list(&self, namespace: &str) -> Vec<&VectorEntry> {
249 match self.namespaces.get(namespace) {
250 Some(entries) => entries.iter().collect(),
251 None => Vec::new(),
252 }
253 }
254
255 pub fn delete_namespace(&mut self, namespace: &str) -> usize {
261 match self.namespaces.remove(namespace) {
262 Some(entries) => entries.len(),
263 None => 0,
264 }
265 }
266
267 pub fn stats(&self) -> VectorStoreStats {
271 let total_vectors: usize = self.namespaces.values().map(|v| v.len()).sum();
272 let namespace_count = self.namespaces.values().filter(|v| !v.is_empty()).count();
273 let dimension = self.dimension.or_else(|| {
275 self.namespaces
276 .values()
277 .flat_map(|v| v.iter())
278 .next()
279 .map(|e| e.vector.len())
280 });
281 VectorStoreStats {
282 total_vectors,
283 namespace_count,
284 dimension,
285 }
286 }
287}
288
#[cfg(test)]
mod tests {
    //! Unit tests for the vector store: upsert/get/delete, similarity search,
    //! namespace handling, stats and error formatting.
    use super::*;

    /// Builds an entry with empty metadata and a zero `created_at`.
    fn make_entry(id: &str, namespace: &str, vector: Vec<f32>) -> VectorEntry {
        VectorEntry {
            id: id.to_owned(),
            namespace: namespace.to_owned(),
            vector,
            metadata: HashMap::new(),
            created_at: 0,
        }
    }

    /// A `dim`-long vector of zeros with a single `1.0` at index `hot`.
    fn basis(dim: usize, hot: usize) -> Vec<f32> {
        let mut v = vec![0.0_f32; dim];
        v[hot] = 1.0;
        v
    }

    /// A store without a configured dimension, pre-populated with `entries`
    /// in the given order.
    fn store_with(entries: Vec<VectorEntry>) -> VectorStore {
        let mut store = VectorStore::new();
        for e in entries {
            store.upsert(e).unwrap();
        }
        store
    }

    // ---- upsert ----

    #[test]
    fn test_upsert_new_returns_true() {
        let mut store = VectorStore::new();
        let inserted = store.upsert(make_entry("e1", "ns", vec![1.0, 0.0]));
        assert!(inserted.unwrap());
    }

    #[test]
    fn test_upsert_update_returns_false() {
        let mut store = store_with(vec![make_entry("e1", "ns", vec![1.0, 0.0])]);
        let inserted = store
            .upsert(make_entry("e1", "ns", vec![0.0, 1.0]))
            .unwrap();
        assert!(!inserted);
    }

    #[test]
    fn test_upsert_update_replaces_vector() {
        let store = store_with(vec![
            make_entry("e1", "ns", vec![1.0, 0.0]),
            make_entry("e1", "ns", vec![0.0, 1.0]),
        ]);
        let stored = store.get("ns", "e1").unwrap();
        assert_eq!(stored.vector, vec![0.0, 1.0]);
    }

    #[test]
    fn test_upsert_empty_vector_errors() {
        let outcome = VectorStore::new().upsert(make_entry("e1", "ns", vec![]));
        assert_eq!(outcome, Err(StoreError::EmptyVector));
    }

    #[test]
    fn test_upsert_dimension_mismatch_errors() {
        let mut store = VectorStore::with_dimension(3);
        let outcome = store.upsert(make_entry("e1", "ns", vec![1.0, 2.0]));
        let mismatch = StoreError::DimensionMismatch {
            expected: 3,
            got: 2,
        };
        assert_eq!(outcome, Err(mismatch));
    }

    #[test]
    fn test_upsert_correct_dimension_ok() {
        let mut store = VectorStore::with_dimension(2);
        let outcome = store.upsert(make_entry("e1", "ns", vec![1.0, 0.0]));
        assert!(outcome.is_ok());
    }

    // ---- delete ----

    #[test]
    fn test_delete_existing() {
        let mut store = store_with(vec![make_entry("e1", "ns", vec![1.0])]);
        assert!(store.delete("ns", "e1"));
        assert!(store.get("ns", "e1").is_none());
    }

    #[test]
    fn test_delete_nonexistent_returns_false() {
        let removed = VectorStore::new().delete("ns", "ghost");
        assert!(!removed);
    }

    // ---- get ----

    #[test]
    fn test_get_existing() {
        let store = store_with(vec![make_entry("e1", "ns", vec![1.0, 2.0])]);
        assert!(store.get("ns", "e1").is_some());
    }

    #[test]
    fn test_get_nonexistent_none() {
        let store = VectorStore::new();
        assert!(store.get("ns", "missing").is_none());
    }

    // ---- contains ----

    #[test]
    fn test_contains_true() {
        let store = store_with(vec![make_entry("x", "ns", vec![1.0])]);
        assert!(store.contains("ns", "x"));
    }

    #[test]
    fn test_contains_false() {
        let present = VectorStore::new().contains("ns", "x");
        assert!(!present);
    }

    // ---- search ----

    #[test]
    fn test_search_sorted_by_score_descending() {
        let store = store_with(vec![
            make_entry("e1", "ns", basis(3, 0)),
            make_entry("e2", "ns", basis(3, 1)),
        ]);
        let hits = store.search("ns", &basis(3, 0), 2);
        assert_eq!(hits.len(), 2);
        assert!(hits[0].score >= hits[1].score);
        assert_eq!(hits[0].id, "e1");
    }

    #[test]
    fn test_search_top_k_limit() {
        let entries = (0..10)
            .map(|i| make_entry(&i.to_string(), "ns", vec![i as f32]))
            .collect();
        let store = store_with(entries);
        let hits = store.search("ns", &[5.0_f32], 3);
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_search_same_vector_max_score() {
        let v = vec![1.0_f32, 1.0, 1.0];
        let store = store_with(vec![make_entry("e", "ns", v.clone())]);
        let hits = store.search("ns", &v, 1);
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_search_empty_namespace_returns_empty() {
        let hits = VectorStore::new().search("missing-ns", &[1.0_f32], 5);
        assert!(hits.is_empty());
    }

    // ---- search_all_namespaces ----

    #[test]
    fn test_search_all_namespaces_cross_namespace() {
        let store = store_with(vec![
            make_entry("a", "ns1", vec![1.0_f32, 0.0]),
            make_entry("b", "ns2", vec![0.0_f32, 1.0]),
        ]);
        let hits = store.search_all_namespaces(&[1.0_f32, 0.0], 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "a");
    }

    #[test]
    fn test_search_all_namespaces_top_k() {
        let entries = (0..5)
            .map(|i| make_entry(&format!("e{i}"), &format!("ns{i}"), vec![i as f32]))
            .collect();
        let store = store_with(entries);
        let hits = store.search_all_namespaces(&[2.0_f32], 2);
        assert_eq!(hits.len(), 2);
    }

    // ---- list ----

    #[test]
    fn test_list_all_in_namespace() {
        let store = store_with(vec![
            make_entry("a", "ns", vec![1.0]),
            make_entry("b", "ns", vec![2.0]),
        ]);
        assert_eq!(store.list("ns").len(), 2);
    }

    #[test]
    fn test_list_nonexistent_namespace_empty() {
        let store = VectorStore::new();
        let listed = store.list("ghost");
        assert!(listed.is_empty());
    }

    // ---- delete_namespace ----

    #[test]
    fn test_delete_namespace_returns_count() {
        let mut store = store_with(vec![
            make_entry("a", "ns", vec![1.0]),
            make_entry("b", "ns", vec![2.0]),
        ]);
        let dropped = store.delete_namespace("ns");
        assert_eq!(dropped, 2);
    }

    #[test]
    fn test_delete_namespace_nonexistent_returns_zero() {
        let dropped = VectorStore::new().delete_namespace("ghost");
        assert_eq!(dropped, 0);
    }

    #[test]
    fn test_delete_namespace_removes_entries() {
        let mut store = store_with(vec![make_entry("a", "ns", vec![1.0])]);
        store.delete_namespace("ns");
        assert!(store.list("ns").is_empty());
    }

    // ---- stats ----

    #[test]
    fn test_stats_empty_store() {
        let VectorStoreStats {
            total_vectors,
            namespace_count,
            dimension,
        } = VectorStore::new().stats();
        assert_eq!(total_vectors, 0);
        assert_eq!(namespace_count, 0);
        assert!(dimension.is_none());
    }

    #[test]
    fn test_stats_counts_correctly() {
        let store = store_with(vec![
            make_entry("a", "ns1", vec![1.0, 2.0]),
            make_entry("b", "ns1", vec![3.0, 4.0]),
            make_entry("c", "ns2", vec![5.0, 6.0]),
        ]);
        let stats = store.stats();
        assert_eq!(stats.total_vectors, 3);
        assert_eq!(stats.namespace_count, 2);
        assert_eq!(stats.dimension, Some(2));
    }

    #[test]
    fn test_stats_with_configured_dimension() {
        let stats = VectorStore::with_dimension(128).stats();
        assert_eq!(stats.dimension, Some(128));
    }

    // ---- cross-cutting scenarios ----

    #[test]
    fn test_upsert_multiple_namespaces() {
        let store = store_with(vec![
            make_entry("a", "ns1", vec![1.0]),
            make_entry("b", "ns2", vec![2.0]),
        ]);
        assert!(store.contains("ns1", "a"));
        assert!(store.contains("ns2", "b"));
    }

    #[test]
    fn test_delete_from_wrong_namespace() {
        let mut store = store_with(vec![make_entry("a", "ns1", vec![1.0])]);
        // Deleting under another namespace must not touch the real entry.
        assert!(!store.delete("ns2", "a"));
        assert!(store.contains("ns1", "a"));
    }

    #[test]
    fn test_search_returns_correct_namespace() {
        let store = store_with(vec![make_entry("a", "ns1", vec![1.0_f32, 0.0])]);
        let hits = store.search("ns1", &[1.0_f32, 0.0], 1);
        assert_eq!(hits[0].namespace, "ns1");
    }

    #[test]
    fn test_search_all_namespaces_empty_store() {
        let hits = VectorStore::new().search_all_namespaces(&[1.0_f32], 5);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_metadata_stored_and_retrieved() {
        let mut e = make_entry("e1", "ns", vec![1.0]);
        e.metadata.insert("source".into(), "test".into());
        let store = store_with(vec![e]);
        let stored = store.get("ns", "e1").unwrap();
        assert_eq!(
            stored.metadata.get("source").map(String::as_str),
            Some("test")
        );
    }

    #[test]
    fn test_created_at_stored() {
        let mut e = make_entry("e1", "ns", vec![1.0]);
        e.created_at = 12345678;
        let store = store_with(vec![e]);
        let stored = store.get("ns", "e1").unwrap();
        assert_eq!(stored.created_at, 12345678);
    }

    #[test]
    fn test_store_error_display_empty_vector() {
        let text = StoreError::EmptyVector.to_string();
        assert!(!text.is_empty());
    }

    #[test]
    fn test_store_error_display_dimension_mismatch() {
        let text = StoreError::DimensionMismatch {
            expected: 3,
            got: 5,
        }
        .to_string();
        assert!(text.contains("3") && text.contains("5"));
    }

    #[test]
    fn test_stats_namespace_count_after_delete() {
        let mut store = store_with(vec![
            make_entry("a", "ns1", vec![1.0]),
            make_entry("b", "ns2", vec![1.0]),
        ]);
        store.delete_namespace("ns1");
        assert_eq!(store.stats().namespace_count, 1);
    }

    #[test]
    fn test_search_scores_in_range() {
        let store = store_with(vec![
            make_entry("a", "ns", vec![1.0_f32, 0.0]),
            make_entry("b", "ns", vec![0.0_f32, 1.0]),
        ]);
        let hits = store.search("ns", &[0.7_f32, 0.7], 2);
        for hit in &hits {
            assert!((-1.0..=1.0).contains(&hit.score));
        }
    }

    #[test]
    fn test_search_returns_metadata() {
        let mut e = make_entry("e1", "ns", vec![1.0_f32, 0.0]);
        e.metadata.insert("key".into(), "val".into());
        let store = store_with(vec![e]);
        let hits = store.search("ns", &[1.0_f32, 0.0], 1);
        assert_eq!(hits[0].metadata.get("key").map(String::as_str), Some("val"));
    }

    #[test]
    fn test_upsert_different_ids_same_namespace() {
        let store = store_with(vec![
            make_entry("a", "ns", vec![1.0]),
            make_entry("b", "ns", vec![2.0]),
            make_entry("c", "ns", vec![3.0]),
        ]);
        assert_eq!(store.list("ns").len(), 3);
    }

    #[test]
    fn test_delete_reduces_list_count() {
        let mut store = store_with(vec![
            make_entry("a", "ns", vec![1.0]),
            make_entry("b", "ns", vec![2.0]),
        ]);
        store.delete("ns", "a");
        assert_eq!(store.list("ns").len(), 1);
    }

    #[test]
    fn test_search_all_namespaces_result_ids_correct() {
        let store = store_with(vec![
            make_entry("best", "ns1", vec![1.0_f32, 0.0]),
            make_entry("other", "ns2", vec![0.0_f32, 1.0]),
        ]);
        let hits = store.search_all_namespaces(&[1.0_f32, 0.0], 1);
        assert_eq!(hits[0].id, "best");
    }

    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let store = store_with(vec![
            make_entry("a", "ns", vec![1.0_f32, 0.0]),
            make_entry("b", "ns", vec![-1.0_f32, 0.0]),
        ]);
        let hits = store.search("ns", &[1.0_f32, 0.0], 2);
        assert_eq!(hits[0].id, "a");
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn test_with_dimension_rejects_extra_dims() {
        let mut store = VectorStore::with_dimension(2);
        let outcome = store.upsert(make_entry("e", "ns", vec![1.0, 2.0, 3.0]));
        assert!(matches!(
            outcome,
            Err(StoreError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_upsert_returns_new_flag_consistently() {
        let mut store = VectorStore::new();
        let first = store.upsert(make_entry("e", "ns", vec![1.0])).unwrap();
        let second = store.upsert(make_entry("e", "ns", vec![2.0])).unwrap();
        assert!(first);
        assert!(!second);
    }

    #[test]
    fn test_stats_total_includes_all_namespaces() {
        let entries = (0..5)
            .map(|i| make_entry(&i.to_string(), &format!("ns{i}"), vec![i as f32]))
            .collect();
        let store = store_with(entries);
        assert_eq!(store.stats().total_vectors, 5);
    }

    #[test]
    fn test_get_after_update_returns_new_vector() {
        let store = store_with(vec![
            make_entry("e", "ns", vec![1.0, 0.0]),
            make_entry("e", "ns", vec![0.0, 1.0]),
        ]);
        let stored = store.get("ns", "e").unwrap();
        assert_eq!(stored.vector, vec![0.0_f32, 1.0]);
    }

    #[test]
    fn test_search_zero_top_k() {
        let store = store_with(vec![make_entry("a", "ns", vec![1.0_f32])]);
        let hits = store.search("ns", &[1.0_f32], 0);
        assert!(hits.is_empty());
    }
}