Skip to main content

claw_vector/embeddings/
cache.rs

// embeddings/cache.rs — LRU cache for embedding vectors, keyed by a SHA-256 digest of the input text.
2use std::num::NonZeroUsize;
3
4use lru::LruCache;
5use sha2::{Digest, Sha256};
6
/// Point-in-time view of the embedding cache counters.
///
/// The struct is `Copy`, so a snapshot is a plain value that does not change
/// after it has been produced.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct EmbeddingCacheStats {
    /// Lookups that found a cached vector.
    pub hit_count: u64,
    /// Lookups that found nothing cached.
    pub miss_count: u64,
    /// Entries held by the cache when the snapshot was taken.
    pub len: usize,
}
17
18/// LRU cache for embedding vectors keyed by a SHA-256 hash of the input text.
19pub struct EmbeddingCache {
20    /// Inner LRU cache mapping text digests to embeddings.
21    pub inner: LruCache<String, Vec<f32>>,
22    /// Number of cache hits.
23    pub hit_count: u64,
24    /// Number of cache misses.
25    pub miss_count: u64,
26}
27
28impl EmbeddingCache {
29    /// Create a new cache with the given maximum capacity.
30    pub fn new(capacity: usize) -> Self {
31        let capacity = NonZeroUsize::new(capacity.max(1)).unwrap();
32        EmbeddingCache {
33            inner: LruCache::new(capacity),
34            hit_count: 0,
35            miss_count: 0,
36        }
37    }
38
39    /// Compute the cache key for an input text.
40    pub fn key(text: &str) -> String {
41        let digest = Sha256::digest(text.as_bytes());
42        hex::encode(digest)
43    }
44
45    /// Look up a cached vector for `text`, returning `None` on a cache miss.
46    pub fn get(&mut self, text: &str) -> Option<Vec<f32>> {
47        let key = Self::key(text);
48        if let Some(vector) = self.inner.get(&key) {
49            self.hit_count += 1;
50            Some(vector.clone())
51        } else {
52            self.miss_count += 1;
53            None
54        }
55    }
56
57    /// Insert or update a cached vector for `text`.
58    pub fn insert(&mut self, text: &str, vector: Vec<f32>) {
59        let key = Self::key(text);
60        self.inner.put(key, vector);
61    }
62
63    /// Return a snapshot of cache hit/miss counts and entry count.
64    pub fn stats(&self) -> EmbeddingCacheStats {
65        EmbeddingCacheStats {
66            hit_count: self.hit_count,
67            miss_count: self.miss_count,
68            len: self.inner.len(),
69        }
70    }
71}