Skip to main content

oxirs_embed/
embedding_store.rs

1//! In-memory embedding store with label-based lookup and cosine similarity search.
2//!
3//! Provides `O(1)` label and id access and `O(n·d)` nearest-neighbour search
4//! over `n` embeddings of dimension `d`.
5
6use std::collections::HashMap;
7
8// ---------------------------------------------------------------------------
9// Error type
10// ---------------------------------------------------------------------------
11
/// Errors that can be returned by [`EmbeddingStore`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoreError {
    /// The supplied vector has the wrong number of dimensions.
    DimensionMismatch {
        /// The dimension expected by the store.
        expected: usize,
        /// The dimension of the supplied vector.
        got: usize,
    },
    /// No entry exists with the given label.
    LabelNotFound(String),
    /// The store contains no entries.
    EmptyStore,
}

impl std::fmt::Display for StoreError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StoreError::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {expected}, got {got}")
            }
            StoreError::LabelNotFound(label) => write!(f, "label not found: {label}"),
            StoreError::EmptyStore => write!(f, "store is empty"),
        }
    }
}

impl std::error::Error for StoreError {}

// ---------------------------------------------------------------------------
// Entry
// ---------------------------------------------------------------------------

/// A stored embedding entry.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingEntry {
    /// Unique identifier assigned when the entry is first inserted.
    ///
    /// Identifiers increase monotonically and are never reused, even after
    /// the entry is removed. Updating an existing label keeps its id.
    pub id: usize,
    /// Human-readable label used as the primary key.
    pub label: String,
    /// The embedding vector.
    pub vector: Vec<f64>,
    /// Optional key-value metadata attached to this entry.
    pub metadata: HashMap<String, String>,
}

// ---------------------------------------------------------------------------
// Store
// ---------------------------------------------------------------------------

/// In-memory store for labelled embedding vectors.
///
/// Entries are kept in insertion order. Two hash indices map labels and ids
/// to positions in that list, giving average `O(1)` lookup by either key.
#[derive(Debug, Clone)]
pub struct EmbeddingStore {
    /// Entries in insertion order.
    entries: Vec<EmbeddingEntry>,
    /// label → index into `entries`.
    label_index: HashMap<String, usize>,
    /// id → index into `entries`.
    id_index: HashMap<usize, usize>,
    /// Dimension shared by every stored vector; `None` while the store is empty.
    dim: Option<usize>,
    /// Next id to hand out. Only ever increases, so ids are never reused.
    next_id: usize,
}

impl Default for EmbeddingStore {
    fn default() -> Self {
        Self::new()
    }
}

impl EmbeddingStore {
    /// Create an empty `EmbeddingStore`.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            label_index: HashMap::new(),
            id_index: HashMap::new(),
            dim: None,
            next_id: 0,
        }
    }

    /// Insert a new embedding (or replace the vector of an existing label).
    ///
    /// All vectors in the store must have the same dimension. The first
    /// insertion into an empty store fixes the dimension for all subsequent
    /// insertions.
    ///
    /// Returns the entry's `id` on success.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::DimensionMismatch`] when `vector` does not match
    /// the store's dimension.
    pub fn insert(
        &mut self,
        label: impl Into<String>,
        vector: Vec<f64>,
    ) -> Result<usize, StoreError> {
        self.insert_with_meta(label, vector, HashMap::new())
    }

    /// Insert a new embedding with accompanying metadata.
    ///
    /// If `label` already exists, its vector and metadata are replaced in
    /// place and its existing id is returned. Otherwise the entry is appended
    /// and receives a fresh, never-before-used id.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::DimensionMismatch`] when `vector` does not match
    /// the store's dimension. The store is left unchanged in that case.
    pub fn insert_with_meta(
        &mut self,
        label: impl Into<String>,
        vector: Vec<f64>,
        meta: HashMap<String, String>,
    ) -> Result<usize, StoreError> {
        let label = label.into();

        // Check / set dimension.
        match self.dim {
            Some(d) if d != vector.len() => {
                return Err(StoreError::DimensionMismatch {
                    expected: d,
                    got: vector.len(),
                });
            }
            Some(_) => {}
            None => self.dim = Some(vector.len()),
        }

        // Update an existing entry in place, keeping its id.
        if let Some(&idx) = self.label_index.get(&label) {
            let entry = &mut self.entries[idx];
            entry.vector = vector;
            entry.metadata = meta;
            return Ok(entry.id);
        }

        // Ids come from a monotonically increasing counter rather than
        // `entries.len()`. After a removal the length shrinks, so reusing it
        // would hand out an id that still belongs to a live entry.
        let id = self.next_id;
        self.next_id += 1;

        let idx = self.entries.len();
        self.label_index.insert(label.clone(), idx);
        self.id_index.insert(id, idx);
        self.entries.push(EmbeddingEntry {
            id,
            label,
            vector,
            metadata: meta,
        });
        Ok(id)
    }

    /// Look up an entry by its label.
    pub fn get_by_label(&self, label: &str) -> Option<&EmbeddingEntry> {
        let idx = self.label_index.get(label)?;
        self.entries.get(*idx)
    }

    /// Look up an entry by its `id`.
    pub fn get_by_id(&self, id: usize) -> Option<&EmbeddingEntry> {
        let idx = self.id_index.get(&id)?;
        self.entries.get(*idx)
    }

    /// Compute the cosine similarity between two slices of equal length.
    ///
    /// Returns `0.0` when either vector has zero norm. The result is clamped
    /// to `[-1.0, 1.0]` to absorb rounding error. It is NaN if either input
    /// contains NaN.
    pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
        Self::cosine_from_parts(Self::dot(a, b), Self::norm(a), Self::norm(b))
    }

    /// Dot product over the common prefix of `a` and `b`.
    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Euclidean (L2) norm of `v`.
    fn norm(v: &[f64]) -> f64 {
        v.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Combine a precomputed dot product and the two norms into a cosine
    /// similarity, using the same zero-norm and clamping rules as
    /// [`Self::cosine_similarity`].
    fn cosine_from_parts(dot: f64, norm_a: f64, norm_b: f64) -> f64 {
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }

    /// Return the `k` nearest entries to `query` by cosine similarity, in
    /// descending order of similarity.
    ///
    /// Ties keep insertion order. Entries whose similarity is NaN (for
    /// example, vectors containing NaN) are ranked after all others.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::EmptyStore`] when the store is empty and
    /// [`StoreError::DimensionMismatch`] when `query` has the wrong length.
    pub fn nearest(
        &self,
        query: &[f64],
        k: usize,
    ) -> Result<Vec<(&EmbeddingEntry, f64)>, StoreError> {
        if self.entries.is_empty() {
            return Err(StoreError::EmptyStore);
        }
        if let Some(d) = self.dim {
            if query.len() != d {
                return Err(StoreError::DimensionMismatch {
                    expected: d,
                    got: query.len(),
                });
            }
        }

        // The query norm is the same for every candidate, so compute it once.
        let query_norm = Self::norm(query);
        let mut scored: Vec<(&EmbeddingEntry, f64)> = self
            .entries
            .iter()
            .map(|e| {
                let sim = Self::cosine_from_parts(
                    Self::dot(query, &e.vector),
                    query_norm,
                    Self::norm(&e.vector),
                );
                (e, sim)
            })
            .collect();

        // Mapping NaN comparisons to `Equal` is not a total order, which
        // leaves `sort_by`'s result unspecified (and may panic). Instead, put
        // NaN scores last and order all other scores descending.
        scored.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        });
        scored.truncate(k);
        Ok(scored)
    }

    /// Number of entries currently in the store.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when the store is empty.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// The fixed dimension of all stored vectors, or `None` if no vectors
    /// have been inserted yet.
    pub fn dim(&self) -> Option<usize> {
        self.dim
    }

    /// Return a list of all labels in insertion order.
    pub fn labels(&self) -> Vec<&str> {
        self.entries.iter().map(|e| e.label.as_str()).collect()
    }

    /// Remove the entry with the given label.
    ///
    /// Returns `true` if the entry was found and removed, `false` otherwise.
    /// Removing the last entry resets the store's dimension.
    ///
    /// Note: removing an entry does **not** reuse or reassign its `id`.
    pub fn remove(&mut self, label: &str) -> bool {
        let idx = match self.label_index.remove(label) {
            Some(idx) => idx,
            None => return false,
        };

        // `Vec::remove` keeps insertion order, which `labels()` relies on.
        let removed = self.entries.remove(idx);
        self.id_index.remove(&removed.id);

        // Every entry after `idx` moved down by one slot. Patch both indices
        // in place instead of rebuilding them, which would re-clone every label.
        for pos in self
            .label_index
            .values_mut()
            .chain(self.id_index.values_mut())
        {
            if *pos > idx {
                *pos -= 1;
            }
        }

        // Reset dim if the store is now empty.
        if self.entries.is_empty() {
            self.dim = None;
        }
        true
    }
}
246
247// ---------------------------------------------------------------------------
248// Tests
249// ---------------------------------------------------------------------------
250
#[cfg(test)]
mod tests {
    use super::*;

    fn v2(x: f64, y: f64) -> Vec<f64> {
        vec![x, y]
    }

    fn v3(x: f64, y: f64, z: f64) -> Vec<f64> {
        vec![x, y, z]
    }

    // --- insert / len / dim ---

    #[test]
    fn test_new_empty() {
        let store = EmbeddingStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.dim().is_none());
    }

    #[test]
    fn test_insert_first_sets_dim() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v3(1.0, 0.0, 0.0)).unwrap();
        assert_eq!(store.dim(), Some(3));
    }

    #[test]
    fn test_insert_returns_id() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("a", v2(1.0, 0.0)).unwrap();
        assert_eq!(id, 0);
        let id2 = store.insert("b", v2(0.0, 1.0)).unwrap();
        assert_eq!(id2, 1);
    }

    #[test]
    fn test_insert_increments_len() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        assert_eq!(store.len(), 1);
        store.insert("b", v2(0.0, 1.0)).unwrap();
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn test_insert_dim_mismatch_error() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        let result = store.insert("b", v3(0.0, 1.0, 0.0));
        assert!(matches!(
            result,
            Err(StoreError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
        // A rejected insert must leave the store untouched.
        assert_eq!(store.len(), 1);
        assert_eq!(store.dim(), Some(2));
        assert!(store.get_by_label("b").is_none());
    }

    #[test]
    fn test_insert_update_existing_label() {
        let mut store = EmbeddingStore::new();
        let id1 = store.insert("a", v2(1.0, 0.0)).unwrap();
        let id2 = store.insert("a", v2(0.5, 0.5)).unwrap();
        // Same id, same len
        assert_eq!(id1, id2);
        assert_eq!(store.len(), 1);
        let e = store.get_by_label("a").expect("exists");
        assert!((e.vector[0] - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_update_existing_label_dim_mismatch_keeps_old_vector() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        assert!(matches!(
            store.insert("a", v3(0.0, 0.0, 1.0)),
            Err(StoreError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
        let e = store.get_by_label("a").expect("exists");
        assert_eq!(e.vector, v2(1.0, 0.0));
    }

    // --- insert_with_meta ---

    #[test]
    fn test_insert_with_meta_stores_metadata() {
        let mut store = EmbeddingStore::new();
        let mut meta = HashMap::new();
        meta.insert("lang".to_string(), "en".to_string());
        store.insert_with_meta("doc1", v2(1.0, 0.0), meta).unwrap();
        let e = store.get_by_label("doc1").expect("exists");
        assert_eq!(e.metadata["lang"], "en");
    }

    #[test]
    fn test_insert_with_meta_update_replaces_metadata() {
        let mut store = EmbeddingStore::new();
        let mut meta = HashMap::new();
        meta.insert("lang".to_string(), "en".to_string());
        store.insert_with_meta("doc", v2(1.0, 0.0), meta).unwrap();
        // A plain `insert` on an existing label replaces metadata with an empty map.
        store.insert("doc", v2(0.0, 1.0)).unwrap();
        let e = store.get_by_label("doc").expect("exists");
        assert!(e.metadata.is_empty());
        assert_eq!(e.vector, v2(0.0, 1.0));
    }

    // --- get_by_label ---

    #[test]
    fn test_get_by_label_existing() {
        let mut store = EmbeddingStore::new();
        store.insert("hello", v2(1.0, 0.0)).unwrap();
        assert!(store.get_by_label("hello").is_some());
    }

    #[test]
    fn test_get_by_label_missing() {
        let store = EmbeddingStore::new();
        assert!(store.get_by_label("missing").is_none());
    }

    #[test]
    fn test_get_by_label_returns_correct_vector() {
        let mut store = EmbeddingStore::new();
        store.insert("x", v2(3.0, 4.0)).unwrap();
        let e = store.get_by_label("x").expect("exists");
        assert!((e.vector[0] - 3.0).abs() < 1e-9);
        assert!((e.vector[1] - 4.0).abs() < 1e-9);
    }

    // --- get_by_id ---

    #[test]
    fn test_get_by_id_existing() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("a", v2(1.0, 0.0)).unwrap();
        assert!(store.get_by_id(id).is_some());
    }

    #[test]
    fn test_get_by_id_missing() {
        let store = EmbeddingStore::new();
        assert!(store.get_by_id(999).is_none());
    }

    #[test]
    fn test_get_by_id_matches_label() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("mykey", v2(1.0, 2.0)).unwrap();
        let e = store.get_by_id(id).expect("exists");
        assert_eq!(e.label, "mykey");
    }

    // --- cosine_similarity ---

    #[test]
    fn test_cosine_identical_vectors() {
        let a = v3(1.0, 2.0, 3.0);
        let sim = EmbeddingStore::cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_orthogonal_vectors() {
        let a = v2(1.0, 0.0);
        let b = v2(0.0, 1.0);
        let sim = EmbeddingStore::cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_cosine_opposite_vectors() {
        let a = v2(1.0, 0.0);
        let b = v2(-1.0, 0.0);
        let sim = EmbeddingStore::cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_zero_vector_returns_zero() {
        let a = v2(0.0, 0.0);
        let b = v2(1.0, 0.0);
        let sim = EmbeddingStore::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_symmetry() {
        let a = v3(1.0, 2.0, 3.0);
        let b = v3(4.0, 5.0, 6.0);
        let sim_ab = EmbeddingStore::cosine_similarity(&a, &b);
        let sim_ba = EmbeddingStore::cosine_similarity(&b, &a);
        assert!((sim_ab - sim_ba).abs() < 1e-9);
    }

    // --- nearest ---

    #[test]
    fn test_nearest_empty_store_error() {
        let store = EmbeddingStore::new();
        assert!(matches!(
            store.nearest(&[1.0, 0.0], 3),
            Err(StoreError::EmptyStore)
        ));
    }

    #[test]
    fn test_nearest_dim_mismatch_error() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        assert!(matches!(
            store.nearest(&[1.0, 0.0, 0.0], 3),
            Err(StoreError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_nearest_returns_k_results() {
        let mut store = EmbeddingStore::new();
        for i in 0..5 {
            store.insert(format!("e{i}"), vec![i as f64, 0.0]).unwrap();
        }
        let results = store.nearest(&[1.0, 0.0], 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_nearest_k_zero_returns_empty() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        let results = store.nearest(&[1.0, 0.0], 0).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_nearest_sorted_descending() {
        let mut store = EmbeddingStore::new();
        store.insert("up", v2(0.0, 1.0)).unwrap();
        store.insert("right", v2(1.0, 0.0)).unwrap();
        store.insert("diag", v2(1.0, 1.0)).unwrap();
        let query = v2(1.0, 0.0);
        let results = store.nearest(&query, 3).unwrap();
        let sims: Vec<f64> = results.iter().map(|(_, s)| *s).collect();
        for pair in sims.windows(2) {
            assert!(pair[0] >= pair[1]);
        }
        let order: Vec<&str> = results.iter().map(|(e, _)| e.label.as_str()).collect();
        assert_eq!(order, vec!["right", "diag", "up"]);
    }

    #[test]
    fn test_nearest_top1_is_most_similar() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        store.insert("c", v2(-1.0, 0.0)).unwrap();
        let results = store.nearest(&[1.0, 0.0], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0.label, "a");
        assert!((results[0].1 - 1.0).abs() < 1e-9);
    }

    // --- labels ---

    #[test]
    fn test_labels_empty() {
        let store = EmbeddingStore::new();
        assert!(store.labels().is_empty());
    }

    #[test]
    fn test_labels_returns_all() {
        let mut store = EmbeddingStore::new();
        store.insert("alpha", v2(1.0, 0.0)).unwrap();
        store.insert("beta", v2(0.0, 1.0)).unwrap();
        // `labels()` is documented to return insertion order.
        assert_eq!(store.labels(), vec!["alpha", "beta"]);
    }

    #[test]
    fn test_labels_preserve_order_after_remove() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        store.insert("c", v2(1.0, 1.0)).unwrap();
        store.remove("b");
        assert_eq!(store.labels(), vec!["a", "c"]);
    }

    // --- remove ---

    #[test]
    fn test_remove_existing_returns_true() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        assert!(store.remove("a"));
        assert!(store.is_empty());
        assert!(store.get_by_label("a").is_none());
        // A second removal of the same label finds nothing.
        assert!(!store.remove("a"));
    }

    #[test]
    fn test_remove_missing_returns_false() {
        let mut store = EmbeddingStore::new();
        assert!(!store.remove("ghost"));
    }

    #[test]
    fn test_remove_decrements_len() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        store.remove("a");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_remove_remaining_entry_still_accessible() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        store.remove("a");
        let e = store.get_by_label("b").expect("exists");
        assert_eq!(e.vector, v2(0.0, 1.0));
    }

    #[test]
    fn test_remove_all_resets_dim() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.remove("a");
        assert!(store.dim().is_none());
    }

    #[test]
    fn test_remove_allows_reinsertion_with_different_dim() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.remove("a");
        // After removing the only entry, dim is reset, so new dimension is allowed
        store.insert("a", v3(1.0, 0.0, 0.0)).unwrap();
        assert_eq!(store.dim(), Some(3));
    }

    // --- default ---

    #[test]
    fn test_default_same_as_new() {
        let store = EmbeddingStore::default();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.dim().is_none());
    }

    // --- StoreError display ---

    #[test]
    fn test_error_display_dimension_mismatch() {
        let e = StoreError::DimensionMismatch {
            expected: 3,
            got: 2,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 3, got 2");
    }

    #[test]
    fn test_error_display_label_not_found() {
        let e = StoreError::LabelNotFound("ghost".to_string());
        assert_eq!(e.to_string(), "label not found: ghost");
    }

    #[test]
    fn test_error_display_empty_store() {
        let e = StoreError::EmptyStore;
        assert_eq!(e.to_string(), "store is empty");
    }

    // --- additional scenarios ---

    #[test]
    fn test_nearest_k_larger_than_store() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        let results = store.nearest(&[1.0, 1.0], 10).unwrap();
        // Cannot return more than what's in the store
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_id_is_stable_for_inserted_entry() {
        let mut store = EmbeddingStore::new();
        let id = store.insert("vec", v2(1.0, 1.0)).unwrap();
        let e = store.get_by_label("vec").expect("exists");
        assert_eq!(e.id, id);
    }

    #[test]
    fn test_entry_label_matches() {
        let mut store = EmbeddingStore::new();
        store.insert("myLabel", v2(0.5, 0.5)).unwrap();
        let e = store.get_by_label("myLabel").expect("exists");
        assert_eq!(e.label, "myLabel");
    }

    // --- additional coverage ---

    #[test]
    fn test_insert_empty_vector_sets_dim_zero() {
        let mut store = EmbeddingStore::new();
        store.insert("empty", vec![]).unwrap();
        assert_eq!(store.dim(), Some(0));
    }

    #[test]
    fn test_cosine_unit_vectors() {
        // Two unit vectors at 45° apart
        let a = vec![1.0_f64 / 2.0_f64.sqrt(), 1.0_f64 / 2.0_f64.sqrt()];
        let b = vec![1.0, 0.0];
        let sim = EmbeddingStore::cosine_similarity(&a, &b);
        assert!((sim - (1.0_f64 / 2.0_f64.sqrt())).abs() < 1e-9);
    }

    #[test]
    fn test_nearest_returns_fewer_when_store_smaller_than_k() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        let results = store.nearest(&[1.0, 0.0], 100).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_remove_all_entries_allows_new_dim() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v2(1.0, 0.0)).unwrap();
        store.insert("b", v2(0.0, 1.0)).unwrap();
        store.remove("a");
        store.remove("b");
        assert_eq!(store.dim(), None);
        // Should accept a 3-d vector now
        store.insert("c", v3(1.0, 0.0, 0.0)).unwrap();
        assert_eq!(store.dim(), Some(3));
    }

    #[test]
    fn test_get_by_id_after_remove_middle() {
        let mut store = EmbeddingStore::new();
        let id_a = store.insert("a", v2(1.0, 0.0)).unwrap();
        let id_b = store.insert("b", v2(0.0, 1.0)).unwrap();
        let id_c = store.insert("c", v2(0.5, 0.5)).unwrap();
        store.remove("b");
        // a and c stay reachable by their original ids; b's id is gone.
        assert_eq!(store.get_by_id(id_a).expect("exists").label, "a");
        assert_eq!(store.get_by_id(id_c).expect("exists").label, "c");
        assert!(store.get_by_id(id_b).is_none());
    }

    #[test]
    fn test_insert_with_meta_empty_meta() {
        let mut store = EmbeddingStore::new();
        store
            .insert_with_meta("doc", v2(1.0, 0.0), HashMap::new())
            .unwrap();
        let e = store.get_by_label("doc").expect("exists");
        assert!(e.metadata.is_empty());
    }

    #[test]
    fn test_nearest_similarity_range() {
        let mut store = EmbeddingStore::new();
        store.insert("a", v3(1.0, 0.0, 0.0)).unwrap();
        store.insert("b", v3(0.0, 1.0, 0.0)).unwrap();
        store.insert("c", v3(0.0, 0.0, 1.0)).unwrap();
        let results = store.nearest(&[1.0, 0.0, 0.0], 3).unwrap();
        for (_, sim) in &results {
            assert!(*sim >= -1.0 && *sim <= 1.0);
        }
    }
}