1use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6
7use crate::types::OcrResult;
8
/// Lookup key for [`VisionCache`].
///
/// Two keys are equal when the image hash, provider, output format and
/// language all match. The derived `Hash`/`Eq` make the type usable as a
/// `HashMap` key.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct CacheKey {
    /// Hash of the image bytes, filled in by `CacheKey::new`; the bytes
    /// themselves are not stored.
    image_hash: u64,
    /// Provider identifier, exactly as passed to `CacheKey::new`.
    provider: String,
    /// Output format identifier, exactly as passed to `CacheKey::new`.
    output_format: String,
    /// Optional language string; `None` and `Some(..)` form distinct keys.
    language: Option<String>,
}
21
22impl CacheKey {
23 pub fn new(
25 image_data: &[u8],
26 provider: &str,
27 output_format: &str,
28 language: Option<&str>,
29 ) -> Self {
30 use std::hash::{Hash, Hasher};
31 let mut hasher = std::collections::hash_map::DefaultHasher::new();
32 image_data.hash(&mut hasher);
33 let image_hash = hasher.finish();
34
35 Self {
36 image_hash,
37 provider: provider.to_string(),
38 output_format: output_format.to_string(),
39 language: language.map(|s| s.to_string()),
40 }
41 }
42}
43
/// A stored OCR result together with the bookkeeping needed to expire it.
#[derive(Debug, Clone)]
struct CachedEntry {
    /// The cached value; a clone of it is handed back on a hit.
    result: OcrResult,
    /// Moment the entry was inserted. Expiry and oldest-first eviction are
    /// both measured from here.
    created_at: Instant,
    /// How long after `created_at` the entry remains valid.
    ttl: Duration,
}
51
52impl CachedEntry {
53 fn is_expired(&self) -> bool {
54 self.created_at.elapsed() > self.ttl
55 }
56}
57
/// In-memory cache of OCR results keyed by [`CacheKey`].
///
/// The map lives behind `Arc<Mutex<..>>`, so a clone of a `VisionCache`
/// refers to the same underlying entries as the value it was cloned from.
#[derive(Debug, Clone)]
pub struct VisionCache {
    /// Shared, lock-protected storage.
    cache: Arc<Mutex<HashMap<CacheKey, CachedEntry>>>,
    /// TTL applied by `set`, which does not take an explicit one.
    default_ttl: Duration,
    /// Entry count at which `set_with_ttl` starts evicting before inserting.
    max_size: usize,
}
65
66impl Default for VisionCache {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl VisionCache {
73 pub fn new() -> Self {
76 Self {
77 cache: Arc::new(Mutex::new(HashMap::new())),
78 default_ttl: Duration::from_secs(3600),
79 max_size: 100,
80 }
81 }
82
83 pub fn with_ttl(ttl: Duration) -> Self {
85 Self {
86 cache: Arc::new(Mutex::new(HashMap::new())),
87 default_ttl: ttl,
88 max_size: 100,
89 }
90 }
91
92 pub fn with_max_size(max_size: usize) -> Self {
94 Self {
95 cache: Arc::new(Mutex::new(HashMap::new())),
96 default_ttl: Duration::from_secs(3600),
97 max_size,
98 }
99 }
100
101 pub fn get(&self, key: &CacheKey) -> Option<OcrResult> {
103 let mut cache = self.cache.lock().unwrap();
104
105 if let Some(entry) = cache.get(key) {
106 if entry.is_expired() {
107 cache.remove(key);
108 return None;
109 }
110 return Some(entry.result.clone());
111 }
112
113 None
114 }
115
116 pub fn set(&self, key: CacheKey, result: OcrResult) {
118 self.set_with_ttl(key, result, self.default_ttl);
119 }
120
121 pub fn set_with_ttl(&self, key: CacheKey, result: OcrResult, ttl: Duration) {
123 let mut cache = self.cache.lock().unwrap();
124
125 if cache.len() >= self.max_size {
127 self.evict_expired(&mut cache);
128 }
129
130 if cache.len() >= self.max_size {
132 if let Some(oldest_key) = cache
133 .iter()
134 .min_by_key(|(_, v)| v.created_at)
135 .map(|(k, _)| k.clone())
136 {
137 cache.remove(&oldest_key);
138 }
139 }
140
141 cache.insert(
142 key,
143 CachedEntry {
144 result,
145 created_at: Instant::now(),
146 ttl,
147 },
148 );
149 }
150
151 fn evict_expired(&self, cache: &mut HashMap<CacheKey, CachedEntry>) {
153 cache.retain(|_, entry| !entry.is_expired());
154 }
155
156 pub fn clear(&self) {
158 let mut cache = self.cache.lock().unwrap();
159 cache.clear();
160 }
161
162 pub fn len(&self) -> usize {
164 let cache = self.cache.lock().unwrap();
165 cache.len()
166 }
167
168 pub fn is_empty(&self) -> bool {
170 self.len() == 0
171 }
172
173 pub fn stats(&self) -> CacheStats {
175 let cache = self.cache.lock().unwrap();
176 let total = cache.len();
177 let expired = cache.values().filter(|e| e.is_expired()).count();
178
179 CacheStats {
180 total_entries: total,
181 expired_entries: expired,
182 active_entries: total - expired,
183 max_size: self.max_size,
184 }
185 }
186}
187
/// Point-in-time entry counts reported by `VisionCache::stats`.
///
/// The type is four plain counters, so it is `Copy` and comparable; two
/// snapshots can be checked with `==` directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of entries currently stored, expired or not.
    pub total_entries: usize,
    /// Stored entries whose TTL had already elapsed when the snapshot was taken.
    pub expired_entries: usize,
    /// `total_entries - expired_entries`.
    pub active_entries: usize,
    /// Capacity of the cache the snapshot was taken from.
    pub max_size: usize,
}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Key shared by the tests that only need a single entry.
    fn sample_key() -> CacheKey {
        CacheKey::new(b"test image", "mock", "markdown", None)
    }

    /// A stored value comes back from a subsequent lookup.
    #[test]
    fn test_cache_basic() {
        let cache = VisionCache::new();
        let key = sample_key();
        let stored = OcrResult::from_text("Hello");

        cache.set(key.clone(), stored.clone());

        let fetched = cache.get(&key).unwrap();
        assert_eq!(fetched.text, "Hello");
    }

    /// Looking up a key that was never stored yields nothing.
    #[test]
    fn test_cache_miss() {
        let cache = VisionCache::new();
        let lookup = cache.get(&sample_key());

        assert!(lookup.is_none());
    }

    /// An entry is no longer returned once its TTL has elapsed.
    #[test]
    fn test_cache_expiration() {
        let cache = VisionCache::with_ttl(Duration::from_millis(1));
        let key = sample_key();

        cache.set(key.clone(), OcrResult::from_text("Hello"));

        // Sleep comfortably past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(10));

        assert!(cache.get(&key).is_none());
    }

    /// Inserting more entries than the capacity never exceeds the capacity.
    #[test]
    fn test_cache_max_size() {
        let cache = VisionCache::with_max_size(2);

        (0..5).for_each(|i| {
            let image = format!("image{}", i);
            let key = CacheKey::new(image.as_bytes(), "mock", "markdown", None);
            cache.set(key, OcrResult::from_text(format!("Result {}", i)));
        });

        assert!(cache.len() <= 2);
    }

    /// `clear` leaves the cache empty.
    #[test]
    fn test_cache_clear() {
        let cache = VisionCache::new();
        cache.set(sample_key(), OcrResult::from_text("Hello"));

        cache.clear();

        assert!(cache.is_empty());
    }
}