1use std::collections::HashMap;
7
/// Errors returned by `EmbeddingStore` operations.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers can store, compare and
/// `assert_eq!` on errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoreError {
    /// A vector's length did not match the dimensionality the store expects.
    DimensionMismatch {
        /// Dimensionality currently fixed by the store.
        expected: usize,
        /// Length of the vector that was supplied.
        got: usize,
    },
    /// Carries a label for which no entry could be found.
    LabelNotFound(String),
    /// The operation needs at least one stored entry, but the store is empty.
    EmptyStore,
}

impl std::fmt::Display for StoreError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StoreError::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {expected}, got {got}")
            }
            StoreError::LabelNotFound(label) => write!(f, "label not found: {label}"),
            StoreError::EmptyStore => write!(f, "store is empty"),
        }
    }
}

/// `StoreError` wraps no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for StoreError {}
43
/// A single labelled vector held by an `EmbeddingStore`, plus free-form metadata.
///
/// Derives `PartialEq` (not `Eq`, since it contains `f64`s) so entries can be
/// compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingEntry {
    /// Identifier assigned by the store when the entry was first inserted.
    pub id: usize,
    /// Caller-supplied label; a store holds at most one entry per label.
    pub label: String,
    /// The embedding vector itself.
    pub vector: Vec<f64>,
    /// Arbitrary string key/value annotations supplied at insertion time.
    pub metadata: HashMap<String, String>,
}
60
61pub struct EmbeddingStore {
67 entries: Vec<EmbeddingEntry>,
68 label_index: HashMap<String, usize>, dim: Option<usize>,
70}
71
72impl Default for EmbeddingStore {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl EmbeddingStore {
79 pub fn new() -> Self {
81 Self {
82 entries: Vec::new(),
83 label_index: HashMap::new(),
84 dim: None,
85 }
86 }
87
88 pub fn insert(
95 &mut self,
96 label: impl Into<String>,
97 vector: Vec<f64>,
98 ) -> Result<usize, StoreError> {
99 self.insert_with_meta(label, vector, HashMap::new())
100 }
101
102 pub fn insert_with_meta(
104 &mut self,
105 label: impl Into<String>,
106 vector: Vec<f64>,
107 meta: HashMap<String, String>,
108 ) -> Result<usize, StoreError> {
109 let label = label.into();
110
111 match self.dim {
113 Some(d) if d != vector.len() => {
114 return Err(StoreError::DimensionMismatch {
115 expected: d,
116 got: vector.len(),
117 });
118 }
119 None => {
120 self.dim = Some(vector.len());
121 }
122 Some(_) => {}
123 }
124
125 let id = self.entries.len();
126
127 if let Some(&idx) = self.label_index.get(&label) {
129 self.entries[idx].vector = vector;
131 self.entries[idx].metadata = meta;
132 return Ok(self.entries[idx].id);
133 }
134
135 self.label_index.insert(label.clone(), id);
136 self.entries.push(EmbeddingEntry {
137 id,
138 label,
139 vector,
140 metadata: meta,
141 });
142 Ok(id)
143 }
144
145 pub fn get_by_label(&self, label: &str) -> Option<&EmbeddingEntry> {
147 let idx = self.label_index.get(label)?;
148 self.entries.get(*idx)
149 }
150
151 pub fn get_by_id(&self, id: usize) -> Option<&EmbeddingEntry> {
153 self.entries.iter().find(|e| e.id == id)
154 }
155
156 pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
160 let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
161 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
162 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
163 if norm_a == 0.0 || norm_b == 0.0 {
164 return 0.0;
165 }
166 (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
167 }
168
169 pub fn nearest(
175 &self,
176 query: &[f64],
177 k: usize,
178 ) -> Result<Vec<(&EmbeddingEntry, f64)>, StoreError> {
179 if self.entries.is_empty() {
180 return Err(StoreError::EmptyStore);
181 }
182 if let Some(d) = self.dim {
183 if query.len() != d {
184 return Err(StoreError::DimensionMismatch {
185 expected: d,
186 got: query.len(),
187 });
188 }
189 }
190
191 let mut scored: Vec<(&EmbeddingEntry, f64)> = self
192 .entries
193 .iter()
194 .map(|e| (e, Self::cosine_similarity(query, &e.vector)))
195 .collect();
196
197 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198 scored.truncate(k);
199 Ok(scored)
200 }
201
202 pub fn len(&self) -> usize {
204 self.entries.len()
205 }
206
207 pub fn is_empty(&self) -> bool {
209 self.entries.is_empty()
210 }
211
212 pub fn dim(&self) -> Option<usize> {
215 self.dim
216 }
217
218 pub fn labels(&self) -> Vec<&str> {
220 self.entries.iter().map(|e| e.label.as_str()).collect()
221 }
222
223 pub fn remove(&mut self, label: &str) -> bool {
229 if let Some(idx) = self.label_index.remove(label) {
230 self.entries.remove(idx);
231 self.label_index.clear();
233 for (i, entry) in self.entries.iter().enumerate() {
234 self.label_index.insert(entry.label.clone(), i);
235 }
236 if self.entries.is_empty() {
238 self.dim = None;
239 }
240 true
241 } else {
242 false
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 2-D vector.
    fn v2(x: f64, y: f64) -> Vec<f64> {
        vec![x, y]
    }

    /// Builds a 3-D vector.
    fn v3(x: f64, y: f64, z: f64) -> Vec<f64> {
        vec![x, y, z]
    }

    #[test]
    fn test_new_empty() {
        let store = EmbeddingStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.dim().is_none());
    }

    #[test]
    fn test_insert_first_sets_dim() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v3(1.0, 0.0, 0.0)).unwrap();
        assert_eq!(store.dim(), Some(3));
    }

    #[test]
    fn test_insert_returns_id() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("a", v2(1.0, 0.0)).unwrap();
        assert_eq!(id, 0);
        let id2 = store.insert("b", v2(0.0, 1.0)).unwrap();
        assert_eq!(id2, 1);
    }

    #[test]
    fn test_insert_increments_len() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        assert_eq!(store.len(), 1);
        store.insert("b", v2(0.0, 1.0)).unwrap();
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn test_insert_dim_mismatch_error() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        let result = store.insert("b", v3(0.0, 1.0, 0.0));
        assert!(matches!(
            result,
            Err(StoreError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_insert_update_existing_label() {
        let mut store = EmbeddingStore::new();
        let id1 = store.insert("a", v2(1.0, 0.0)).unwrap();
        let id2 = store.insert("a", v2(0.5, 0.5)).unwrap();
        // Re-inserting a label updates in place: same id, no new entry.
        assert_eq!(id1, id2);
        assert_eq!(store.len(), 1);
        let e = store.get_by_label("a").expect("exists");
        assert!((e.vector[0] - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_insert_update_replaces_metadata() {
        let mut store = EmbeddingStore::new();
        let mut meta = HashMap::new();
        meta.insert("lang".to_string(), "en".to_string());
        store.insert_with_meta("doc", v2(1.0, 0.0), meta).unwrap();
        store.insert("doc", v2(0.0, 1.0)).unwrap();
        let e = store.get_by_label("doc").expect("exists");
        assert!(e.metadata.is_empty());
        assert!((e.vector[1] - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_insert_update_with_wrong_dim_rejected() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        let result = store.insert("a", v3(1.0, 0.0, 0.0));
        assert!(matches!(
            result,
            Err(StoreError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
        // The rejected update must leave the stored vector untouched.
        let e = store.get_by_label("a").expect("still exists");
        assert_eq!(e.vector.len(), 2);
    }

    #[test]
    fn test_insert_with_meta_stores_metadata() {
        let mut store = EmbeddingStore::new();
        let mut meta = HashMap::new();
        meta.insert("lang".to_string(), "en".to_string());
        store.insert_with_meta("doc1", v2(1.0, 0.0), meta).unwrap();
        let e = store.get_by_label("doc1").expect("exists");
        assert_eq!(e.metadata["lang"], "en");
    }

    #[test]
    fn test_get_by_label_existing() {
        let mut store = EmbeddingStore::new();
        store.insert("hello", v2(1.0, 0.0)).unwrap();
        assert!(store.get_by_label("hello").is_some());
    }

    #[test]
    fn test_get_by_label_missing() {
        let store = EmbeddingStore::new();
        assert!(store.get_by_label("missing").is_none());
    }

    #[test]
    fn test_get_by_label_returns_correct_vector() {
        let mut store = EmbeddingStore::new();
        store.insert("x", v2(3.0, 4.0)).unwrap();
        let e = store.get_by_label("x").expect("exists");
        assert!((e.vector[0] - 3.0).abs() < 1e-9);
        assert!((e.vector[1] - 4.0).abs() < 1e-9);
    }

    #[test]
    fn test_get_by_id_existing() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("a", v2(1.0, 0.0)).unwrap();
        assert!(store.get_by_id(id).is_some());
    }

    #[test]
    fn test_get_by_id_missing() {
        let store = EmbeddingStore::new();
        assert!(store.get_by_id(999).is_none());
    }

    #[test]
    fn test_get_by_id_matches_label() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("mykey", v2(1.0, 2.0)).unwrap();
        let e = store.get_by_id(id).expect("exists");
        assert_eq!(e.label, "mykey");
    }

    #[test]
    fn test_cosine_identical_vectors() {
        let a = v3(1.0, 2.0, 3.0);
        let sim = EmbeddingStore::cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_orthogonal_vectors() {
        let a = v2(1.0, 0.0);
        let b = v2(0.0, 1.0);
        let sim = EmbeddingStore::cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_cosine_opposite_vectors() {
        let a = v2(1.0, 0.0);
        let b = v2(-1.0, 0.0);
        let sim = EmbeddingStore::cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_zero_vector_returns_zero() {
        let a = v2(0.0, 0.0);
        let b = v2(1.0, 0.0);
        let sim = EmbeddingStore::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_symmetry() {
        let a = v3(1.0, 2.0, 3.0);
        let b = v3(4.0, 5.0, 6.0);
        let sim_ab = EmbeddingStore::cosine_similarity(&a, &b);
        let sim_ba = EmbeddingStore::cosine_similarity(&b, &a);
        assert!((sim_ab - sim_ba).abs() < 1e-9);
    }

    #[test]
    fn test_nearest_empty_store_error() {
        let store = EmbeddingStore::new();
        assert!(matches!(
            store.nearest(&[1.0, 0.0], 3),
            Err(StoreError::EmptyStore)
        ));
    }

    #[test]
    fn test_nearest_dim_mismatch_error() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        assert!(matches!(
            store.nearest(&[1.0, 0.0, 0.0], 3),
            Err(StoreError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_nearest_returns_k_results() {
        let mut store = EmbeddingStore::new();
        for i in 0..5 {
            store.insert(format!("e{i}"), vec![i as f64, 0.0]).unwrap();
        }
        let results = store.nearest(&[1.0, 0.0], 3).unwrap();
        assert_eq!(results.len(), 3);
        // "e0" is the zero vector (similarity 0.0); the others are parallel to
        // the query (similarity 1.0), so "e0" must not make the top 3.
        assert!(results.iter().all(|(e, _)| e.label != "e0"));
    }

    #[test]
    fn test_nearest_sorted_descending() {
        let mut store = EmbeddingStore::new();
        store.insert("up", v2(0.0, 1.0)).unwrap();
        store.insert("right", v2(1.0, 0.0)).unwrap();
        store.insert("diag", v2(1.0, 1.0)).unwrap();
        let query = v2(1.0, 0.0);
        let results = store.nearest(&query, 3).unwrap();
        let sims: Vec<f64> = results.iter().map(|(_, s)| *s).collect();
        for pair in sims.windows(2) {
            assert!(pair[0] >= pair[1]);
        }
    }

    #[test]
    fn test_nearest_top1_is_most_similar() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        store.insert("c", v2(-1.0, 0.0)).unwrap();
        let results = store.nearest(&[1.0, 0.0], 1).unwrap();
        assert_eq!(results[0].0.label, "a");
    }

    #[test]
    fn test_nearest_k_zero_returns_empty() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        assert!(store.nearest(&[1.0, 0.0], 0).unwrap().is_empty());
    }

    #[test]
    fn test_labels_empty() {
        let store = EmbeddingStore::new();
        assert!(store.labels().is_empty());
    }

    #[test]
    fn test_labels_returns_all() {
        let mut store = EmbeddingStore::new();
        store.insert("alpha", v2(1.0, 0.0)).unwrap();
        store.insert("beta", v2(0.0, 1.0)).unwrap();
        let labels = store.labels();
        assert_eq!(labels.len(), 2);
        assert!(labels.contains(&"alpha"));
        assert!(labels.contains(&"beta"));
    }

    #[test]
    fn test_labels_preserve_insertion_order() {
        let mut store = EmbeddingStore::new();
        store.insert("z", v2(1.0, 0.0)).unwrap();
        store.insert("a", v2(0.0, 1.0)).unwrap();
        store.insert("m", v2(1.0, 1.0)).unwrap();
        assert_eq!(store.labels(), vec!["z", "a", "m"]);
    }

    #[test]
    fn test_remove_existing_returns_true() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        assert!(store.remove("a"));
        assert!(store.is_empty());
    }

    #[test]
    fn test_remove_missing_returns_false() {
        let mut store = EmbeddingStore::new();
        assert!(!store.remove("ghost"));
    }

    #[test]
    fn test_remove_decrements_len() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        store.remove("a");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_remove_remaining_entry_still_accessible() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        store.remove("a");
        assert!(store.get_by_label("b").is_some());
    }

    #[test]
    fn test_remove_middle_keeps_lookups_consistent() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        store.insert("c", v2(1.0, 1.0)).unwrap();
        store.insert("d", v2(-1.0, 0.0)).unwrap();
        assert!(store.remove("b"));
        assert!(store.get_by_label("b").is_none());
        // Entries after the removed one shift position; label lookups must follow.
        for label in ["a", "c", "d"] {
            assert_eq!(store.get_by_label(label).expect("exists").label, label);
        }
        assert_eq!(store.labels(), vec!["a", "c", "d"]);
    }

    #[test]
    fn test_remove_all_resets_dim() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.remove("a");
        assert!(store.dim().is_none());
    }

    #[test]
    fn test_remove_allows_reinsertion_with_different_dim() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.remove("a");
        // Once empty, the store accepts a vector of any length.
        store.insert("a", v3(1.0, 0.0, 0.0)).unwrap();
        assert_eq!(store.dim(), Some(3));
    }

    #[test]
    fn test_default_same_as_new() {
        let store = EmbeddingStore::default();
        assert!(store.is_empty());
    }

    #[test]
    fn test_error_display_dimension_mismatch() {
        let e = StoreError::DimensionMismatch {
            expected: 3,
            got: 2,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 3, got 2");
    }

    #[test]
    fn test_error_display_label_not_found() {
        let e = StoreError::LabelNotFound("ghost".to_string());
        assert_eq!(e.to_string(), "label not found: ghost");
    }

    #[test]
    fn test_error_display_empty_store() {
        let e = StoreError::EmptyStore;
        assert_eq!(e.to_string(), "store is empty");
    }

    #[test]
    fn test_nearest_k_larger_than_store() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        let results = store.nearest(&[1.0, 1.0], 10).unwrap();
        // Asking for more neighbours than exist returns every entry.
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_id_is_stable_for_inserted_entry() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("vec", v2(1.0, 1.0)).unwrap();
        let e = store.get_by_label("vec").expect("exists");
        assert_eq!(e.id, id);
    }

    #[test]
    fn test_entry_label_matches() {
        let mut store = EmbeddingStore::new();
        store.insert("myLabel", v2(0.5, 0.5)).unwrap();
        let e = store.get_by_label("myLabel").expect("exists");
        assert_eq!(e.label, "myLabel");
    }

    #[test]
    fn test_insert_empty_vector_sets_dim_zero() {
        let mut store = EmbeddingStore::new();
        store.insert("empty", vec![]).unwrap();
        assert_eq!(store.dim(), Some(0));
    }

    #[test]
    fn test_cosine_unit_vectors() {
        // 45-degree unit vector against the x axis: cos(45°) = 1/sqrt(2).
        let a = vec![1.0_f64 / 2.0_f64.sqrt(), 1.0_f64 / 2.0_f64.sqrt()];
        let b = vec![1.0, 0.0];
        let sim = EmbeddingStore::cosine_similarity(&a, &b);
        assert!((sim - (1.0_f64 / 2.0_f64.sqrt())).abs() < 1e-9);
    }

    #[test]
    fn test_nearest_returns_fewer_when_store_smaller_than_k() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        let results = store.nearest(&[1.0, 0.0], 100).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_remove_all_entries_allows_new_dim() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        store.remove("a");
        store.remove("b");
        assert_eq!(store.dim(), None);
        // With every entry gone, a different dimensionality is accepted.
        store.insert("c", v3(1.0, 0.0, 0.0)).unwrap();
        assert_eq!(store.dim(), Some(3));
    }

    #[test]
    fn test_get_by_id_after_remove_middle() {
        let mut store = EmbeddingStore::new();
        let id_a = store.insert("a", v2(1.0, 0.0)).unwrap();
        let id_b = store.insert("b", v2(0.0, 1.0)).unwrap();
        let id_c = store.insert("c", v2(0.5, 0.5)).unwrap();
        store.remove("b");
        assert!(store.get_by_id(id_a).is_some());
        assert!(store.get_by_id(id_b).is_none());
        // Ids survive the removal of an earlier entry.
        assert!(store.get_by_id(id_c).is_some());
    }

    #[test]
    fn test_insert_with_meta_empty_meta() {
        let mut store = EmbeddingStore::new();
        store
            .insert_with_meta("doc", v2(1.0, 0.0), HashMap::new())
            .unwrap();
        let e = store.get_by_label("doc").expect("exists");
        assert!(e.metadata.is_empty());
    }

    #[test]
    fn test_nearest_similarity_range() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v3(1.0, 0.0, 0.0)).unwrap();
        store.insert("b", v3(0.0, 1.0, 0.0)).unwrap();
        store.insert("c", v3(0.0, 0.0, 1.0)).unwrap();
        let results = store.nearest(&[1.0, 0.0, 0.0], 3).unwrap();
        for (_, sim) in &results {
            assert!(*sim >= -1.0 && *sim <= 1.0);
        }
    }
}