Skip to main content

oxirs_embed/
embedding_store.rs

//! In-memory embedding store with label-based lookup and cosine similarity search.
//!
//! Provides `O(1)` lookup by label, lookup by sequential id, and
//! nearest-neighbour search that scores all `n` embeddings of dimension `d`
//! in `O(n·d)` time and then sorts the scores.
5
6use std::collections::HashMap;
7
8// ---------------------------------------------------------------------------
9// Error type
10// ---------------------------------------------------------------------------
11
/// Errors that can be returned by [`EmbeddingStore`] operations.
///
/// The type is `Clone` and comparable so callers and tests can match on it
/// with `==` / `assert_eq!` as well as with pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoreError {
    /// The supplied vector has the wrong number of dimensions.
    DimensionMismatch {
        /// The dimension expected by the store.
        expected: usize,
        /// The dimension of the supplied vector.
        got: usize,
    },
    /// No entry exists with the given label.
    LabelNotFound(String),
    /// The store contains no entries.
    EmptyStore,
}

impl std::fmt::Display for StoreError {
    /// Write a short, lower-case, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {expected}, got {got}")
            }
            Self::LabelNotFound(label) => write!(f, "label not found: {label}"),
            Self::EmptyStore => f.write_str("store is empty"),
        }
    }
}

// No underlying source error exists for any variant, so the default
// `Error` methods are sufficient.
impl std::error::Error for StoreError {}
43
44// ---------------------------------------------------------------------------
45// Entry
46// ---------------------------------------------------------------------------
47
/// A stored embedding entry.
///
/// Two entries compare equal when all four fields are equal; because the
/// vector holds `f64` values, an entry containing `NaN` is never equal to
/// anything (including itself), which is why only `PartialEq` is derived.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingEntry {
    /// Sequential identifier assigned at insertion time.
    pub id: usize,
    /// Human-readable label used as the primary key.
    pub label: String,
    /// The embedding vector.
    pub vector: Vec<f64>,
    /// Optional key-value metadata attached to this entry.
    pub metadata: HashMap<String, String>,
}
60
61// ---------------------------------------------------------------------------
62// Store
63// ---------------------------------------------------------------------------
64
65/// In-memory store for labelled embedding vectors.
66pub struct EmbeddingStore {
67    entries: Vec<EmbeddingEntry>,
68    label_index: HashMap<String, usize>, // label → index into `entries`
69    dim: Option<usize>,
70}
71
72impl Default for EmbeddingStore {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78impl EmbeddingStore {
79    /// Create an empty `EmbeddingStore`.
80    pub fn new() -> Self {
81        Self {
82            entries: Vec::new(),
83            label_index: HashMap::new(),
84            dim: None,
85        }
86    }
87
88    /// Insert a new embedding.
89    ///
90    /// All vectors in the store must have the same dimension.  The first
91    /// insertion fixes the dimension for all subsequent insertions.
92    ///
93    /// Returns the assigned `id` on success.
94    pub fn insert(
95        &mut self,
96        label: impl Into<String>,
97        vector: Vec<f64>,
98    ) -> Result<usize, StoreError> {
99        self.insert_with_meta(label, vector, HashMap::new())
100    }
101
102    /// Insert a new embedding with accompanying metadata.
103    pub fn insert_with_meta(
104        &mut self,
105        label: impl Into<String>,
106        vector: Vec<f64>,
107        meta: HashMap<String, String>,
108    ) -> Result<usize, StoreError> {
109        let label = label.into();
110
111        // Check / set dimension
112        match self.dim {
113            Some(d) if d != vector.len() => {
114                return Err(StoreError::DimensionMismatch {
115                    expected: d,
116                    got: vector.len(),
117                });
118            }
119            None => {
120                self.dim = Some(vector.len());
121            }
122            Some(_) => {}
123        }
124
125        let id = self.entries.len();
126
127        // Update or insert
128        if let Some(&idx) = self.label_index.get(&label) {
129            // Update existing entry
130            self.entries[idx].vector = vector;
131            self.entries[idx].metadata = meta;
132            return Ok(self.entries[idx].id);
133        }
134
135        self.label_index.insert(label.clone(), id);
136        self.entries.push(EmbeddingEntry {
137            id,
138            label,
139            vector,
140            metadata: meta,
141        });
142        Ok(id)
143    }
144
145    /// Look up an entry by its label.
146    pub fn get_by_label(&self, label: &str) -> Option<&EmbeddingEntry> {
147        let idx = self.label_index.get(label)?;
148        self.entries.get(*idx)
149    }
150
151    /// Look up an entry by its sequential `id`.
152    pub fn get_by_id(&self, id: usize) -> Option<&EmbeddingEntry> {
153        self.entries.iter().find(|e| e.id == id)
154    }
155
156    /// Compute the cosine similarity between two slices of equal length.
157    ///
158    /// Returns `0.0` when either vector has zero norm.
159    pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
160        let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
161        let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
162        let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
163        if norm_a == 0.0 || norm_b == 0.0 {
164            return 0.0;
165        }
166        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
167    }
168
169    /// Return the `k` nearest entries to `query` by cosine similarity, in
170    /// descending order of similarity.
171    ///
172    /// Returns [`StoreError::EmptyStore`] when the store is empty and
173    /// [`StoreError::DimensionMismatch`] when `query` has the wrong length.
174    pub fn nearest(
175        &self,
176        query: &[f64],
177        k: usize,
178    ) -> Result<Vec<(&EmbeddingEntry, f64)>, StoreError> {
179        if self.entries.is_empty() {
180            return Err(StoreError::EmptyStore);
181        }
182        if let Some(d) = self.dim {
183            if query.len() != d {
184                return Err(StoreError::DimensionMismatch {
185                    expected: d,
186                    got: query.len(),
187                });
188            }
189        }
190
191        let mut scored: Vec<(&EmbeddingEntry, f64)> = self
192            .entries
193            .iter()
194            .map(|e| (e, Self::cosine_similarity(query, &e.vector)))
195            .collect();
196
197        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198        scored.truncate(k);
199        Ok(scored)
200    }
201
202    /// Number of entries currently in the store.
203    pub fn len(&self) -> usize {
204        self.entries.len()
205    }
206
207    /// `true` when the store is empty.
208    pub fn is_empty(&self) -> bool {
209        self.entries.is_empty()
210    }
211
212    /// The fixed dimension of all stored vectors, or `None` if no vectors
213    /// have been inserted yet.
214    pub fn dim(&self) -> Option<usize> {
215        self.dim
216    }
217
218    /// Return a list of all labels in insertion order.
219    pub fn labels(&self) -> Vec<&str> {
220        self.entries.iter().map(|e| e.label.as_str()).collect()
221    }
222
223    /// Remove the entry with the given label.
224    ///
225    /// Returns `true` if the entry was found and removed, `false` otherwise.
226    ///
227    /// Note: removing an entry does **not** reuse or reassign its `id`.
228    pub fn remove(&mut self, label: &str) -> bool {
229        if let Some(idx) = self.label_index.remove(label) {
230            self.entries.remove(idx);
231            // Rebuild the label → index mapping because indices have shifted
232            self.label_index.clear();
233            for (i, entry) in self.entries.iter().enumerate() {
234                self.label_index.insert(entry.label.clone(), i);
235            }
236            // Reset dim if store is now empty
237            if self.entries.is_empty() {
238                self.dim = None;
239            }
240            true
241        } else {
242            false
243        }
244    }
245}
246
247// ---------------------------------------------------------------------------
248// Tests
249// ---------------------------------------------------------------------------
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a two-component vector.
    fn v2(x: f64, y: f64) -> Vec<f64> {
        Vec::from([x, y])
    }

    /// Shorthand for a three-component vector.
    fn v3(x: f64, y: f64, z: f64) -> Vec<f64> {
        Vec::from([x, y, z])
    }

    /// Build a store holding `items`, inserted in the given order.
    fn store_with(items: &[(&str, Vec<f64>)]) -> EmbeddingStore {
        let mut store = EmbeddingStore::new();
        for (label, vector) in items {
            store
                .insert(*label, vector.clone())
                .expect("should succeed");
        }
        store
    }

    /// `true` when `a` and `b` differ by less than the tolerance used in these tests.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    // --- insert / len / dim ---

    #[test]
    fn test_new_empty() {
        let store = EmbeddingStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert_eq!(store.dim(), None);
    }

    #[test]
    fn test_insert_first_sets_dim() {
        let store = store_with(&[("a", v3(1.0, 0.0, 0.0))]);
        assert_eq!(store.dim(), Some(3));
    }

    #[test]
    fn test_insert_returns_id() {
        let mut store = EmbeddingStore::new();
        let first = store.insert("a", v2(1.0, 0.0)).expect("should succeed");
        let second = store.insert("b", v2(0.0, 1.0)).expect("should succeed");
        assert_eq!(first, 0);
        assert_eq!(second, 1);
    }

    #[test]
    fn test_insert_increments_len() {
        let mut store = EmbeddingStore::new();
        for (expected_len, (label, vector)) in
            [("a", v2(1.0, 0.0)), ("b", v2(0.0, 1.0))].into_iter().enumerate()
        {
            store.insert(label, vector).expect("should succeed");
            assert_eq!(store.len(), expected_len + 1);
        }
    }

    #[test]
    fn test_insert_dim_mismatch_error() {
        let mut store = store_with(&[("a", v2(1.0, 0.0))]);
        match store.insert("b", v3(0.0, 1.0, 0.0)) {
            Err(StoreError::DimensionMismatch { expected, got }) => {
                assert_eq!(expected, 2);
                assert_eq!(got, 3);
            }
            other => panic!("expected a dimension mismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_insert_update_existing_label() {
        let mut store = EmbeddingStore::new();
        let before = store.insert("a", v2(1.0, 0.0)).expect("should succeed");
        let after = store.insert("a", v2(0.5, 0.5)).expect("should succeed");
        // Re-inserting a label must keep the id and must not add an entry.
        assert_eq!(before, after);
        assert_eq!(store.len(), 1);
        let entry = store.get_by_label("a").expect("exists");
        assert!(close(entry.vector[0], 0.5));
    }

    // --- insert_with_meta ---

    #[test]
    fn test_insert_with_meta_stores_metadata() {
        let mut store = EmbeddingStore::new();
        let meta = HashMap::from([("lang".to_string(), "en".to_string())]);
        store
            .insert_with_meta("doc1", v2(1.0, 0.0), meta)
            .expect("should succeed");
        let entry = store.get_by_label("doc1").expect("exists");
        assert_eq!(entry.metadata.get("lang").map(String::as_str), Some("en"));
    }

    // --- get_by_label ---

    #[test]
    fn test_get_by_label_existing() {
        let store = store_with(&[("hello", v2(1.0, 0.0))]);
        assert!(store.get_by_label("hello").is_some());
    }

    #[test]
    fn test_get_by_label_missing() {
        assert!(EmbeddingStore::new().get_by_label("missing").is_none());
    }

    #[test]
    fn test_get_by_label_returns_correct_vector() {
        let store = store_with(&[("x", v2(3.0, 4.0))]);
        let entry = store.get_by_label("x").expect("exists");
        assert!(close(entry.vector[0], 3.0));
        assert!(close(entry.vector[1], 4.0));
    }

    // --- get_by_id ---

    #[test]
    fn test_get_by_id_existing() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("a", v2(1.0, 0.0)).expect("should succeed");
        assert!(store.get_by_id(id).is_some());
    }

    #[test]
    fn test_get_by_id_missing() {
        assert!(EmbeddingStore::new().get_by_id(999).is_none());
    }

    #[test]
    fn test_get_by_id_matches_label() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("mykey", v2(1.0, 2.0)).expect("should succeed");
        let entry = store.get_by_id(id).expect("exists");
        assert_eq!(entry.label, "mykey");
    }

    // --- cosine_similarity ---

    #[test]
    fn test_cosine_identical_vectors() {
        let a = v3(1.0, 2.0, 3.0);
        assert!(close(EmbeddingStore::cosine_similarity(&a, &a), 1.0));
    }

    #[test]
    fn test_cosine_orthogonal_vectors() {
        let sim = EmbeddingStore::cosine_similarity(&v2(1.0, 0.0), &v2(0.0, 1.0));
        assert!(close(sim, 0.0));
    }

    #[test]
    fn test_cosine_opposite_vectors() {
        let sim = EmbeddingStore::cosine_similarity(&v2(1.0, 0.0), &v2(-1.0, 0.0));
        assert!(close(sim, -1.0));
    }

    #[test]
    fn test_cosine_zero_vector_returns_zero() {
        let sim = EmbeddingStore::cosine_similarity(&v2(0.0, 0.0), &v2(1.0, 0.0));
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_symmetry() {
        let a = v3(1.0, 2.0, 3.0);
        let b = v3(4.0, 5.0, 6.0);
        let forward = EmbeddingStore::cosine_similarity(&a, &b);
        let backward = EmbeddingStore::cosine_similarity(&b, &a);
        assert!(close(forward, backward));
    }

    // --- nearest ---

    #[test]
    fn test_nearest_empty_store_error() {
        let store = EmbeddingStore::new();
        let outcome = store.nearest(&[1.0, 0.0], 3);
        assert!(matches!(outcome, Err(StoreError::EmptyStore)));
    }

    #[test]
    fn test_nearest_dim_mismatch_error() {
        let store = store_with(&[("a", v2(1.0, 0.0))]);
        let outcome = store.nearest(&[1.0, 0.0, 0.0], 3);
        assert!(matches!(
            outcome,
            Err(StoreError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn test_nearest_returns_k_results() {
        let mut store = EmbeddingStore::new();
        for i in 0..5 {
            store
                .insert(format!("e{i}"), vec![i as f64, 0.0])
                .expect("should succeed");
        }
        let hits = store.nearest(&[1.0, 0.0], 3).expect("should succeed");
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_nearest_sorted_descending() {
        let store = store_with(&[
            ("up", v2(0.0, 1.0)),
            ("right", v2(1.0, 0.0)),
            ("diag", v2(1.0, 1.0)),
        ]);
        let hits = store.nearest(&v2(1.0, 0.0), 3).expect("should succeed");
        assert!(hits.windows(2).all(|w| w[0].1 >= w[1].1));
    }

    #[test]
    fn test_nearest_top1_is_most_similar() {
        let store = store_with(&[
            ("a", v2(1.0, 0.0)),
            ("b", v2(0.0, 1.0)),
            ("c", v2(-1.0, 0.0)),
        ]);
        let hits = store.nearest(&[1.0, 0.0], 1).expect("should succeed");
        assert_eq!(hits[0].0.label, "a");
    }

    // --- labels ---

    #[test]
    fn test_labels_empty() {
        assert!(EmbeddingStore::new().labels().is_empty());
    }

    #[test]
    fn test_labels_returns_all() {
        let store = store_with(&[("alpha", v2(1.0, 0.0)), ("beta", v2(0.0, 1.0))]);
        let labels = store.labels();
        assert_eq!(labels.len(), 2);
        for wanted in ["alpha", "beta"] {
            assert!(labels.contains(&wanted));
        }
    }

    // --- remove ---

    #[test]
    fn test_remove_existing_returns_true() {
        let mut store = store_with(&[("a", v2(1.0, 0.0))]);
        assert!(store.remove("a"));
        assert!(store.is_empty());
    }

    #[test]
    fn test_remove_missing_returns_false() {
        let mut store = EmbeddingStore::new();
        assert!(!store.remove("ghost"));
    }

    #[test]
    fn test_remove_decrements_len() {
        let mut store = store_with(&[("a", v2(1.0, 0.0)), ("b", v2(0.0, 1.0))]);
        store.remove("a");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_remove_remaining_entry_still_accessible() {
        let mut store = store_with(&[("a", v2(1.0, 0.0)), ("b", v2(0.0, 1.0))]);
        store.remove("a");
        assert!(store.get_by_label("b").is_some());
    }

    #[test]
    fn test_remove_all_resets_dim() {
        let mut store = store_with(&[("a", v2(1.0, 0.0))]);
        store.remove("a");
        assert!(store.dim().is_none());
    }

    #[test]
    fn test_remove_allows_reinsertion_with_different_dim() {
        let mut store = store_with(&[("a", v2(1.0, 0.0))]);
        store.remove("a");
        // The store is empty again, so the dimension is no longer fixed.
        store
            .insert("a", v3(1.0, 0.0, 0.0))
            .expect("should succeed");
        assert_eq!(store.dim(), Some(3));
    }

    // --- default ---

    #[test]
    fn test_default_same_as_new() {
        assert!(EmbeddingStore::default().is_empty());
    }

    // --- StoreError display ---

    #[test]
    fn test_error_display_dimension_mismatch() {
        let rendered = StoreError::DimensionMismatch {
            expected: 3,
            got: 2,
        }
        .to_string();
        assert!(!rendered.is_empty());
    }

    #[test]
    fn test_error_display_label_not_found() {
        let rendered = StoreError::LabelNotFound("ghost".to_string()).to_string();
        assert!(rendered.contains("ghost"));
    }

    #[test]
    fn test_error_display_empty_store() {
        let rendered = StoreError::EmptyStore.to_string();
        assert!(!rendered.is_empty());
    }

    // --- additional scenarios ---

    #[test]
    fn test_nearest_k_larger_than_store() {
        let store = store_with(&[("a", v2(1.0, 0.0)), ("b", v2(0.0, 1.0))]);
        let hits = store.nearest(&[1.0, 1.0], 10).expect("should succeed");
        // Only two entries exist, so only two can come back.
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_id_is_stable_for_inserted_entry() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("vec", v2(1.0, 1.0)).expect("should succeed");
        let entry = store.get_by_label("vec").expect("exists");
        assert_eq!(entry.id, id);
    }

    #[test]
    fn test_entry_label_matches() {
        let store = store_with(&[("myLabel", v2(0.5, 0.5))]);
        let entry = store.get_by_label("myLabel").expect("exists");
        assert_eq!(entry.label, "myLabel");
    }

    // --- additional coverage ---

    #[test]
    fn test_insert_empty_vector_sets_dim_zero() {
        let store = store_with(&[("empty", Vec::new())]);
        assert_eq!(store.dim(), Some(0));
    }

    #[test]
    fn test_cosine_unit_vectors() {
        // A unit vector on the diagonal is 45° away from the x axis.
        let inv_sqrt2 = 1.0_f64 / 2.0_f64.sqrt();
        let diagonal = vec![inv_sqrt2, inv_sqrt2];
        let x_axis = vec![1.0, 0.0];
        let sim = EmbeddingStore::cosine_similarity(&diagonal, &x_axis);
        assert!(close(sim, 1.0_f64 / 2.0_f64.sqrt()));
    }

    #[test]
    fn test_nearest_returns_fewer_when_store_smaller_than_k() {
        let store = store_with(&[("a", v2(1.0, 0.0))]);
        let hits = store.nearest(&[1.0, 0.0], 100).expect("should succeed");
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_remove_all_entries_allows_new_dim() {
        let mut store = store_with(&[("a", v2(1.0, 0.0)), ("b", v2(0.0, 1.0))]);
        for label in ["a", "b"] {
            store.remove(label);
        }
        assert_eq!(store.dim(), None);
        // With nothing left, a 3-d vector must now be accepted.
        store
            .insert("c", v3(1.0, 0.0, 0.0))
            .expect("should succeed");
        assert_eq!(store.dim(), Some(3));
    }

    #[test]
    fn test_get_by_id_after_remove_middle() {
        let mut store = EmbeddingStore::new();
        let id_a = store.insert("a", v2(1.0, 0.0)).expect("should succeed");
        store.insert("b", v2(0.0, 1.0)).expect("should succeed");
        let id_c = store.insert("c", v2(0.5, 0.5)).expect("should succeed");
        store.remove("b");
        // The neighbours of the removed entry must remain reachable by id.
        for id in [id_a, id_c] {
            assert!(store.get_by_id(id).is_some());
        }
    }

    #[test]
    fn test_insert_with_meta_empty_meta() {
        let mut store = EmbeddingStore::new();
        store
            .insert_with_meta("doc", v2(1.0, 0.0), HashMap::new())
            .expect("should succeed");
        let entry = store.get_by_label("doc").expect("exists");
        assert!(entry.metadata.is_empty());
    }

    #[test]
    fn test_nearest_similarity_range() {
        let store = store_with(&[
            ("a", v3(1.0, 0.0, 0.0)),
            ("b", v3(0.0, 1.0, 0.0)),
            ("c", v3(0.0, 0.0, 1.0)),
        ]);
        let hits = store.nearest(&[1.0, 0.0, 0.0], 3).expect("should succeed");
        assert!(hits.iter().all(|(_, sim)| (-1.0..=1.0).contains(sim)));
    }
}