// mentedb_embedding/cache.rs
//! LRU cache for computed embeddings to avoid recomputation.

use std::collections::{HashMap, VecDeque};

/// Statistics about cache usage.
///
/// A point-in-time snapshot returned by `EmbeddingCache::stats`; the
/// counters are cumulative over the lifetime of the cache.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
    /// Number of entries currently stored.
    pub size: usize,
    /// Maximum number of entries the cache will hold.
    pub max_size: usize,
    /// Number of entries removed to make room for new ones.
    pub evictions: u64,
}
/// A cached embedding entry.
#[derive(Debug, Clone)]
pub struct CachedEmbedding {
    /// The embedding vector itself.
    pub embedding: Vec<f32>,
    /// Creation time in microseconds since the Unix epoch.
    pub created_at: u64,
    /// Number of times this entry has been returned by a cache lookup.
    pub hit_count: u32,
}
23/// LRU cache for embedding vectors, keyed by hash of (model_name + text).
24pub struct EmbeddingCache {
25    max_size: usize,
26    cache: HashMap<u64, CachedEmbedding>,
27    order: VecDeque<u64>,
28    hits: u64,
29    misses: u64,
30    evictions: u64,
31}
33impl EmbeddingCache {
34    /// Create a new cache with the given maximum number of entries.
35    pub fn new(max_size: usize) -> Self {
36        Self {
37            max_size,
38            cache: HashMap::with_capacity(max_size.min(1024)),
39            order: VecDeque::with_capacity(max_size.min(1024)),
40            hits: 0,
41            misses: 0,
42            evictions: 0,
43        }
44    }
45
46    /// Create a cache with the default size of 10,000 entries.
47    pub fn default_size() -> Self {
48        Self::new(10_000)
49    }
50
51    /// Compute a cache key from model name and text.
52    fn cache_key(model: &str, text: &str) -> u64 {
53        // FNV-1a hash
54        let mut hash: u64 = 0xcbf29ce484222325;
55        let prime: u64 = 0x100000001b3;
56
57        for byte in model.as_bytes() {
58            hash ^= *byte as u64;
59            hash = hash.wrapping_mul(prime);
60        }
61        // Separator
62        hash ^= 0xff;
63        hash = hash.wrapping_mul(prime);
64
65        for byte in text.as_bytes() {
66            hash ^= *byte as u64;
67            hash = hash.wrapping_mul(prime);
68        }
69
70        hash
71    }
72
73    /// Look up a cached embedding. Returns `None` on cache miss.
74    pub fn get(&mut self, text: &str, model: &str) -> Option<&[f32]> {
75        let key = Self::cache_key(model, text);
76
77        if self.cache.contains_key(&key) {
78            self.hits += 1;
79
80            // Move to back (most recently used)
81            self.order.retain(|k| *k != key);
82            self.order.push_back(key);
83
84            let entry = self.cache.get_mut(&key).unwrap();
85            entry.hit_count += 1;
86            Some(&entry.embedding)
87        } else {
88            self.misses += 1;
89            None
90        }
91    }
92
93    /// Insert an embedding into the cache. Evicts the LRU entry if full.
94    pub fn put(&mut self, text: &str, model: &str, embedding: Vec<f32>) {
95        let key = Self::cache_key(model, text);
96
97        // If key already exists, update it
98        if self.cache.contains_key(&key) {
99            self.order.retain(|k| *k != key);
100            self.order.push_back(key);
101            self.cache.insert(
102                key,
103                CachedEmbedding {
104                    embedding,
105                    created_at: Self::now_micros(),
106                    hit_count: 0,
107                },
108            );
109            return;
110        }
111
112        // Evict LRU if at capacity
113        while self.cache.len() >= self.max_size {
114            if let Some(evict_key) = self.order.pop_front() {
115                self.cache.remove(&evict_key);
116                self.evictions += 1;
117            } else {
118                break;
119            }
120        }
121
122        self.cache.insert(
123            key,
124            CachedEmbedding {
125                embedding,
126                created_at: Self::now_micros(),
127                hit_count: 0,
128            },
129        );
130        self.order.push_back(key);
131    }
132
133    /// Get cache statistics.
134    pub fn stats(&self) -> CacheStats {
135        CacheStats {
136            hits: self.hits,
137            misses: self.misses,
138            size: self.cache.len(),
139            max_size: self.max_size,
140            evictions: self.evictions,
141        }
142    }
143
144    /// Clear all cached entries.
145    pub fn clear(&mut self) {
146        self.cache.clear();
147        self.order.clear();
148    }
149
150    fn now_micros() -> u64 {
151        std::time::SystemTime::now()
152            .duration_since(std::time::UNIX_EPOCH)
153            .unwrap_or_default()
154            .as_micros() as u64
155    }
156}
#[cfg(test)]
mod tests {
    use super::*;

    /// A miss then a hit on the same (text, model) pair, with counters
    /// and the stored vector checked.
    #[test]
    fn test_cache_hit_miss() {
        let mut cache = EmbeddingCache::new(10);
        assert!(cache.get("hello", "model").is_none());
        assert_eq!(cache.stats().misses, 1);

        cache.put("hello", "model", vec![1.0, 2.0, 3.0]);
        let result = cache.get("hello", "model");
        assert!(result.is_some());
        assert_eq!(result.unwrap(), &[1.0, 2.0, 3.0]);
        assert_eq!(cache.stats().hits, 1);
    }

    /// Filling a size-3 cache and inserting a fourth entry must evict the
    /// least-recently-used key ("a") and nothing else.
    #[test]
    fn test_lru_eviction() {
        let mut cache = EmbeddingCache::new(3);

        cache.put("a", "m", vec![1.0]);
        cache.put("b", "m", vec![2.0]);
        cache.put("c", "m", vec![3.0]);

        // Cache is full, inserting "d" should evict "a" (LRU)
        cache.put("d", "m", vec![4.0]);

        assert!(cache.get("a", "m").is_none());
        assert!(cache.get("b", "m").is_some());
        assert!(cache.get("c", "m").is_some());
        assert!(cache.get("d", "m").is_some());
        assert_eq!(cache.stats().evictions, 1);
    }
}
192}