1use std::collections::HashMap;
7
/// Errors reported by [`EmbeddingStore`] operations.
///
/// The type is `Clone` and comparable so callers and tests can use
/// `assert_eq!` on returned errors instead of pattern matching only.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoreError {
    /// A vector's length differs from the dimensionality the store expects.
    DimensionMismatch {
        /// Dimensionality the store currently holds.
        expected: usize,
        /// Length of the vector that was actually supplied.
        got: usize,
    },
    /// No entry exists for the carried label.
    LabelNotFound(String),
    /// The operation needs at least one stored entry, but the store has none.
    EmptyStore,
}
27
28impl std::fmt::Display for StoreError {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 match self {
31 StoreError::DimensionMismatch { expected, got } => {
32 write!(f, "dimension mismatch: expected {expected}, got {got}")
33 }
34 StoreError::LabelNotFound(label) => {
35 write!(f, "label not found: {label}")
36 }
37 StoreError::EmptyStore => write!(f, "store is empty"),
38 }
39 }
40}
41
42impl std::error::Error for StoreError {}
43
/// One stored embedding: a labelled vector plus free-form string metadata.
///
/// `PartialEq` is derived so entries can be compared directly (e.g. in tests);
/// `Eq` is not, because the vector holds `f64` values.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingEntry {
    /// Identifier the store assigned when the entry was first inserted.
    pub id: usize,
    /// Label the entry is stored and looked up under.
    pub label: String,
    /// The embedding values.
    pub vector: Vec<f64>,
    /// Arbitrary key/value annotations attached to the entry.
    pub metadata: HashMap<String, String>,
}
60
61pub struct EmbeddingStore {
67 entries: Vec<EmbeddingEntry>,
68 label_index: HashMap<String, usize>, dim: Option<usize>,
70}
71
72impl Default for EmbeddingStore {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl EmbeddingStore {
79 pub fn new() -> Self {
81 Self {
82 entries: Vec::new(),
83 label_index: HashMap::new(),
84 dim: None,
85 }
86 }
87
88 pub fn insert(
95 &mut self,
96 label: impl Into<String>,
97 vector: Vec<f64>,
98 ) -> Result<usize, StoreError> {
99 self.insert_with_meta(label, vector, HashMap::new())
100 }
101
102 pub fn insert_with_meta(
104 &mut self,
105 label: impl Into<String>,
106 vector: Vec<f64>,
107 meta: HashMap<String, String>,
108 ) -> Result<usize, StoreError> {
109 let label = label.into();
110
111 match self.dim {
113 Some(d) if d != vector.len() => {
114 return Err(StoreError::DimensionMismatch {
115 expected: d,
116 got: vector.len(),
117 });
118 }
119 None => {
120 self.dim = Some(vector.len());
121 }
122 Some(_) => {}
123 }
124
125 let id = self.entries.len();
126
127 if let Some(&idx) = self.label_index.get(&label) {
129 self.entries[idx].vector = vector;
131 self.entries[idx].metadata = meta;
132 return Ok(self.entries[idx].id);
133 }
134
135 self.label_index.insert(label.clone(), id);
136 self.entries.push(EmbeddingEntry {
137 id,
138 label,
139 vector,
140 metadata: meta,
141 });
142 Ok(id)
143 }
144
145 pub fn get_by_label(&self, label: &str) -> Option<&EmbeddingEntry> {
147 let idx = self.label_index.get(label)?;
148 self.entries.get(*idx)
149 }
150
151 pub fn get_by_id(&self, id: usize) -> Option<&EmbeddingEntry> {
153 self.entries.iter().find(|e| e.id == id)
154 }
155
156 pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
160 let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
161 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
162 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
163 if norm_a == 0.0 || norm_b == 0.0 {
164 return 0.0;
165 }
166 (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
167 }
168
169 pub fn nearest(
175 &self,
176 query: &[f64],
177 k: usize,
178 ) -> Result<Vec<(&EmbeddingEntry, f64)>, StoreError> {
179 if self.entries.is_empty() {
180 return Err(StoreError::EmptyStore);
181 }
182 if let Some(d) = self.dim {
183 if query.len() != d {
184 return Err(StoreError::DimensionMismatch {
185 expected: d,
186 got: query.len(),
187 });
188 }
189 }
190
191 let mut scored: Vec<(&EmbeddingEntry, f64)> = self
192 .entries
193 .iter()
194 .map(|e| (e, Self::cosine_similarity(query, &e.vector)))
195 .collect();
196
197 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198 scored.truncate(k);
199 Ok(scored)
200 }
201
202 pub fn len(&self) -> usize {
204 self.entries.len()
205 }
206
207 pub fn is_empty(&self) -> bool {
209 self.entries.is_empty()
210 }
211
212 pub fn dim(&self) -> Option<usize> {
215 self.dim
216 }
217
218 pub fn labels(&self) -> Vec<&str> {
220 self.entries.iter().map(|e| e.label.as_str()).collect()
221 }
222
223 pub fn remove(&mut self, label: &str) -> bool {
229 if let Some(idx) = self.label_index.remove(label) {
230 self.entries.remove(idx);
231 self.label_index.clear();
233 for (i, entry) in self.entries.iter().enumerate() {
234 self.label_index.insert(entry.label.clone(), i);
235 }
236 if self.entries.is_empty() {
238 self.dim = None;
239 }
240 true
241 } else {
242 false
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    //! Unit tests for `EmbeddingStore`, `EmbeddingEntry` and `StoreError`.

    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const TOLERANCE: f64 = 1e-9;

    /// Returns `true` when `actual` lies within [`TOLERANCE`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Builds a store by inserting `items` in order; any rejected insert fails the test.
    fn store_with<L: Into<String>>(items: Vec<(L, Vec<f64>)>) -> EmbeddingStore {
        let mut store = EmbeddingStore::new();
        for (label, vector) in items {
            store.insert(label, vector).expect("should succeed");
        }
        store
    }

    // ---- construction -------------------------------------------------

    #[test]
    fn test_new_empty() {
        let store = EmbeddingStore::new();
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
        assert_eq!(store.dim(), None);
    }

    // ---- insert -------------------------------------------------------

    #[test]
    fn test_insert_first_sets_dim() {
        let store = store_with(vec![("a", vec![1.0, 0.0, 0.0])]);
        assert_eq!(store.dim(), Some(3));
    }

    #[test]
    fn test_insert_returns_id() {
        let mut store = EmbeddingStore::new();
        let first = store.insert("a", vec![1.0, 0.0]).expect("should succeed");
        let second = store.insert("b", vec![0.0, 1.0]).expect("should succeed");
        assert_eq!(first, 0);
        assert_eq!(second, 1);
    }

    #[test]
    fn test_insert_increments_len() {
        let mut store = EmbeddingStore::new();
        let items = vec![("a", vec![1.0, 0.0]), ("b", vec![0.0, 1.0])];
        for (already_inserted, (label, vector)) in items.into_iter().enumerate() {
            store.insert(label, vector).expect("should succeed");
            assert_eq!(store.len(), already_inserted + 1);
        }
    }

    #[test]
    fn test_insert_dim_mismatch_error() {
        let mut store = store_with(vec![("a", vec![1.0, 0.0])]);
        let outcome = store.insert("b", vec![0.0, 1.0, 0.0]);
        assert!(matches!(
            outcome,
            Err(StoreError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_insert_update_existing_label() {
        let mut store = EmbeddingStore::new();
        let original_id = store.insert("a", vec![1.0, 0.0]).expect("should succeed");
        let updated_id = store.insert("a", vec![0.5, 0.5]).expect("should succeed");
        assert_eq!(original_id, updated_id);
        assert_eq!(store.len(), 1);
        let entry = store.get_by_label("a").expect("exists");
        assert!(close(entry.vector[0], 0.5));
    }

    // ---- insert_with_meta ----------------------------------------------

    #[test]
    fn test_insert_with_meta_stores_metadata() {
        let meta: HashMap<String, String> =
            std::iter::once(("lang".to_string(), "en".to_string())).collect();
        let mut store = EmbeddingStore::new();
        store
            .insert_with_meta("doc1", vec![1.0, 0.0], meta)
            .expect("should succeed");
        let entry = store.get_by_label("doc1").expect("exists");
        assert_eq!(entry.metadata["lang"], "en");
    }

    // ---- get_by_label ---------------------------------------------------

    #[test]
    fn test_get_by_label_existing() {
        let store = store_with(vec![("hello", vec![1.0, 0.0])]);
        assert!(store.get_by_label("hello").is_some());
    }

    #[test]
    fn test_get_by_label_missing() {
        assert!(EmbeddingStore::new().get_by_label("missing").is_none());
    }

    #[test]
    fn test_get_by_label_returns_correct_vector() {
        let store = store_with(vec![("x", vec![3.0, 4.0])]);
        let entry = store.get_by_label("x").expect("exists");
        assert!(close(entry.vector[0], 3.0));
        assert!(close(entry.vector[1], 4.0));
    }

    // ---- get_by_id ------------------------------------------------------

    #[test]
    fn test_get_by_id_existing() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("a", vec![1.0, 0.0]).expect("should succeed");
        assert!(store.get_by_id(id).is_some());
    }

    #[test]
    fn test_get_by_id_missing() {
        assert!(EmbeddingStore::new().get_by_id(999).is_none());
    }

    #[test]
    fn test_get_by_id_matches_label() {
        let mut store = EmbeddingStore::new();
        let id = store
            .insert("mykey", vec![1.0, 2.0])
            .expect("should succeed");
        let entry = store.get_by_id(id).expect("exists");
        assert_eq!(entry.label, "mykey");
    }

    // ---- cosine_similarity ----------------------------------------------

    #[test]
    fn test_cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        assert!(close(EmbeddingStore::cosine_similarity(&v, &v), 1.0));
    }

    #[test]
    fn test_cosine_orthogonal_vectors() {
        let sim = EmbeddingStore::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(close(sim, 0.0));
    }

    #[test]
    fn test_cosine_opposite_vectors() {
        let sim = EmbeddingStore::cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!(close(sim, -1.0));
    }

    #[test]
    fn test_cosine_zero_vector_returns_zero() {
        let sim = EmbeddingStore::cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_symmetry() {
        let (a, b) = (vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]);
        let forward = EmbeddingStore::cosine_similarity(&a, &b);
        let backward = EmbeddingStore::cosine_similarity(&b, &a);
        assert!(close(forward, backward));
    }

    // ---- nearest --------------------------------------------------------

    #[test]
    fn test_nearest_empty_store_error() {
        let store = EmbeddingStore::new();
        assert!(matches!(
            store.nearest(&[1.0, 0.0], 3),
            Err(StoreError::EmptyStore)
        ));
    }

    #[test]
    fn test_nearest_dim_mismatch_error() {
        let store = store_with(vec![("a", vec![1.0, 0.0])]);
        assert!(matches!(
            store.nearest(&[1.0, 0.0, 0.0], 3),
            Err(StoreError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn test_nearest_returns_k_results() {
        let mut store = EmbeddingStore::new();
        for n in 0..5 {
            let label = format!("e{n}");
            store
                .insert(label, vec![n as f64, 0.0])
                .expect("should succeed");
        }
        let hits = store.nearest(&[1.0, 0.0], 3).expect("should succeed");
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_nearest_sorted_descending() {
        let store = store_with(vec![
            ("up", vec![0.0, 1.0]),
            ("right", vec![1.0, 0.0]),
            ("diag", vec![1.0, 1.0]),
        ]);
        let hits = store.nearest(&[1.0, 0.0], 3).expect("should succeed");
        // Every adjacent pair must be in non-increasing similarity order.
        assert!(hits.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    #[test]
    fn test_nearest_top1_is_most_similar() {
        let store = store_with(vec![
            ("a", vec![1.0, 0.0]),
            ("b", vec![0.0, 1.0]),
            ("c", vec![-1.0, 0.0]),
        ]);
        let hits = store.nearest(&[1.0, 0.0], 1).expect("should succeed");
        let (best, _) = &hits[0];
        assert_eq!(best.label, "a");
    }

    // ---- labels ---------------------------------------------------------

    #[test]
    fn test_labels_empty() {
        assert!(EmbeddingStore::new().labels().is_empty());
    }

    #[test]
    fn test_labels_returns_all() {
        let store = store_with(vec![("alpha", vec![1.0, 0.0]), ("beta", vec![0.0, 1.0])]);
        let labels = store.labels();
        assert_eq!(labels.len(), 2);
        for expected in ["alpha", "beta"].iter() {
            assert!(labels.contains(expected));
        }
    }

    // ---- remove ---------------------------------------------------------

    #[test]
    fn test_remove_existing_returns_true() {
        let mut store = store_with(vec![("a", vec![1.0, 0.0])]);
        let removed = store.remove("a");
        assert!(removed);
        assert!(store.is_empty());
    }

    #[test]
    fn test_remove_missing_returns_false() {
        let mut store = EmbeddingStore::new();
        let removed = store.remove("ghost");
        assert!(!removed);
    }

    #[test]
    fn test_remove_decrements_len() {
        let mut store = store_with(vec![("a", vec![1.0, 0.0]), ("b", vec![0.0, 1.0])]);
        store.remove("a");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_remove_remaining_entry_still_accessible() {
        let mut store = store_with(vec![("a", vec![1.0, 0.0]), ("b", vec![0.0, 1.0])]);
        store.remove("a");
        assert!(store.get_by_label("b").is_some());
    }

    #[test]
    fn test_remove_all_resets_dim() {
        let mut store = store_with(vec![("a", vec![1.0, 0.0])]);
        store.remove("a");
        assert!(store.dim().is_none());
    }

    #[test]
    fn test_remove_allows_reinsertion_with_different_dim() {
        let mut store = store_with(vec![("a", vec![1.0, 0.0])]);
        store.remove("a");
        store
            .insert("a", vec![1.0, 0.0, 0.0])
            .expect("should succeed");
        assert_eq!(store.dim(), Some(3));
    }

    // ---- Default ----------------------------------------------------------

    #[test]
    fn test_default_same_as_new() {
        assert!(EmbeddingStore::default().is_empty());
    }

    // ---- StoreError Display -------------------------------------------------

    #[test]
    fn test_error_display_dimension_mismatch() {
        let rendered = StoreError::DimensionMismatch {
            expected: 3,
            got: 2,
        }
        .to_string();
        assert!(!rendered.is_empty());
    }

    #[test]
    fn test_error_display_label_not_found() {
        let rendered = StoreError::LabelNotFound("ghost".to_string()).to_string();
        assert!(rendered.contains("ghost"));
    }

    #[test]
    fn test_error_display_empty_store() {
        let rendered = StoreError::EmptyStore.to_string();
        assert!(!rendered.is_empty());
    }

    // ---- edge cases ---------------------------------------------------------

    #[test]
    fn test_nearest_k_larger_than_store() {
        let store = store_with(vec![("a", vec![1.0, 0.0]), ("b", vec![0.0, 1.0])]);
        let hits = store.nearest(&[1.0, 1.0], 10).expect("should succeed");
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_id_is_stable_for_inserted_entry() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("vec", vec![1.0, 1.0]).expect("should succeed");
        let entry = store.get_by_label("vec").expect("exists");
        assert_eq!(entry.id, id);
    }

    #[test]
    fn test_entry_label_matches() {
        let store = store_with(vec![("myLabel", vec![0.5, 0.5])]);
        let entry = store.get_by_label("myLabel").expect("exists");
        assert_eq!(entry.label, "myLabel");
    }

    #[test]
    fn test_insert_empty_vector_sets_dim_zero() {
        let store = store_with(vec![("empty", Vec::new())]);
        assert_eq!(store.dim(), Some(0));
    }

    #[test]
    fn test_cosine_unit_vectors() {
        let inv_sqrt2 = 1.0_f64 / 2.0_f64.sqrt();
        let diagonal = vec![inv_sqrt2, inv_sqrt2];
        let x_axis = vec![1.0, 0.0];
        let sim = EmbeddingStore::cosine_similarity(&diagonal, &x_axis);
        assert!(close(sim, inv_sqrt2));
    }

    #[test]
    fn test_nearest_returns_fewer_when_store_smaller_than_k() {
        let store = store_with(vec![("a", vec![1.0, 0.0])]);
        let hits = store.nearest(&[1.0, 0.0], 100).expect("should succeed");
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_remove_all_entries_allows_new_dim() {
        let mut store = store_with(vec![("a", vec![1.0, 0.0]), ("b", vec![0.0, 1.0])]);
        store.remove("a");
        store.remove("b");
        assert_eq!(store.dim(), None);
        store
            .insert("c", vec![1.0, 0.0, 0.0])
            .expect("should succeed");
        assert_eq!(store.dim(), Some(3));
    }

    #[test]
    fn test_get_by_id_after_remove_middle() {
        let mut store = EmbeddingStore::new();
        let first_id = store.insert("a", vec![1.0, 0.0]).expect("should succeed");
        store.insert("b", vec![0.0, 1.0]).expect("should succeed");
        let last_id = store.insert("c", vec![0.5, 0.5]).expect("should succeed");
        store.remove("b");
        assert!(store.get_by_id(first_id).is_some());
        assert!(store.get_by_id(last_id).is_some());
    }

    #[test]
    fn test_insert_with_meta_empty_meta() {
        let mut store = EmbeddingStore::new();
        store
            .insert_with_meta("doc", vec![1.0, 0.0], HashMap::new())
            .expect("should succeed");
        let entry = store.get_by_label("doc").expect("exists");
        assert!(entry.metadata.is_empty());
    }

    #[test]
    fn test_nearest_similarity_range() {
        let store = store_with(vec![
            ("a", vec![1.0, 0.0, 0.0]),
            ("b", vec![0.0, 1.0, 0.0]),
            ("c", vec![0.0, 0.0, 1.0]),
        ]);
        let hits = store.nearest(&[1.0, 0.0, 0.0], 3).expect("should succeed");
        assert!(hits
            .iter()
            .all(|(_, sim)| (-1.0_f64..=1.0).contains(sim)));
    }
}