1use std::collections::HashMap;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::{Arc, RwLock};
10use std::time::{Duration, Instant};
11
12use keyhog_core::VerificationResult;
13use sha2::{Digest, Sha256};
14
15pub struct VerificationCache {
27 entries: RwLock<HashMap<CacheKey, CacheEntry>>,
28 inserts: AtomicUsize,
29 max_entries: usize,
30 ttl: Duration,
31}
32
33#[derive(Hash, Eq, PartialEq, Clone)]
34struct CacheKey {
35 credential_hash: [u8; VerificationCache::HASH_BYTES],
36 detector_id_hash: [u8; VerificationCache::HASH_BYTES],
37 detector_id: Arc<str>,
38}
39
40struct CacheEntry {
41 result: VerificationResult,
42 metadata: HashMap<String, String>,
43 expires_at: Instant,
44}
45
46impl VerificationCache {
47 const DEFAULT_TTL_SECS: u64 = 300;
48 const DEFAULT_MAX_ENTRIES: usize = 10_000;
49 const EVICTION_INTERVAL: usize = 64;
50 pub(crate) const HASH_BYTES: usize = 32;
51 const MAX_DETECTOR_ID_BYTES: usize = 128;
52 const MAX_METADATA_ENTRIES: usize = 16;
53 const MAX_METADATA_KEY_BYTES: usize = 64;
54 const MAX_METADATA_VALUE_BYTES: usize = 256;
55
56 pub fn new(ttl: Duration) -> Self {
68 Self::with_max_entries(ttl, Self::DEFAULT_MAX_ENTRIES)
69 }
70
71 pub fn with_max_entries(ttl: Duration, max_entries: usize) -> Self {
83 Self {
84 entries: RwLock::new(HashMap::new()),
85 inserts: AtomicUsize::new(0),
86 max_entries: max_entries.max(1),
87 ttl,
88 }
89 }
90
91 pub fn default_ttl() -> Self {
102 Self::new(Duration::from_secs(Self::DEFAULT_TTL_SECS))
103 }
104
105 pub fn get(
120 &self,
121 credential: &str,
122 detector_id: &str,
123 ) -> Option<(VerificationResult, HashMap<String, String>)> {
124 let key = cache_key(credential, detector_id);
125 let now = Instant::now();
126
127 let mut entries = self.entries.write().unwrap_or_else(|p| p.into_inner());
128 match entries.entry(key) {
129 std::collections::hash_map::Entry::Occupied(entry) => {
130 let (result, metadata, expires_at) = {
131 let entry = entry.get();
132 (
133 entry.result.clone(),
134 entry.metadata.clone(),
135 entry.expires_at,
136 )
137 };
138 if now >= expires_at {
139 entry.remove();
140 None
141 } else {
142 Some((result, metadata))
143 }
144 }
145 std::collections::hash_map::Entry::Vacant(_) => None,
146 }
147 }
148
149 pub fn put(
164 &self,
165 credential: &str,
166 detector_id: &str,
167 result: VerificationResult,
168 metadata: HashMap<String, String>,
169 ) {
170 let key = cache_key(credential, detector_id);
171
172 let insert_count = self.inserts.fetch_add(1, Ordering::Relaxed) + 1;
173 if insert_count.is_multiple_of(Self::EVICTION_INTERVAL) {
174 self.evict_expired();
178 }
179
180 self.entries
181 .write()
182 .unwrap_or_else(|p| p.into_inner())
183 .insert(
184 key,
185 CacheEntry {
186 result,
187 metadata: sanitize_metadata(metadata),
188 expires_at: Instant::now() + self.ttl,
189 },
190 );
191
192 if self.entries.read().unwrap_or_else(|p| p.into_inner()).len() > self.max_entries {
193 self.evict_one_oldest();
194 }
195 }
196
197 pub fn len(&self) -> usize {
209 self.entries.read().unwrap_or_else(|p| p.into_inner()).len()
210 }
211
212 pub fn is_empty(&self) -> bool {
224 self.entries
225 .read()
226 .unwrap_or_else(|p| p.into_inner())
227 .is_empty()
228 }
229
230 pub fn evict_expired(&self) {
243 let now = Instant::now();
244 self.entries
245 .write()
246 .unwrap_or_else(|p| p.into_inner())
247 .retain(|_, entry| now < entry.expires_at);
248 }
249
250 fn evict_one_oldest(&self) {
251 let mut entries = self.entries.write().unwrap_or_else(|p| p.into_inner());
252 let oldest_key = entries
253 .iter()
254 .min_by_key(|entry| entry.1.expires_at)
255 .map(|(key, _)| key.clone());
256
257 if let Some(key) = oldest_key {
258 entries.remove(&key);
259 }
260 }
261}
262
263fn hash_credential(credential: &str) -> [u8; VerificationCache::HASH_BYTES] {
264 Sha256::digest(credential.as_bytes()).into()
265}
266
267fn cache_key(credential: &str, detector_id: &str) -> CacheKey {
268 CacheKey {
269 credential_hash: hash_credential(credential),
270 detector_id_hash: hash_credential(detector_id),
271 detector_id: Arc::<str>::from(truncate_to_char_boundary(
272 detector_id,
273 VerificationCache::MAX_DETECTOR_ID_BYTES,
274 )),
275 }
276}
277
278fn sanitize_metadata(metadata: HashMap<String, String>) -> HashMap<String, String> {
279 metadata
280 .into_iter()
281 .take(VerificationCache::MAX_METADATA_ENTRIES)
282 .map(|(key, value)| {
283 (
284 truncate_to_char_boundary(&key, VerificationCache::MAX_METADATA_KEY_BYTES),
285 truncate_to_char_boundary(&value, VerificationCache::MAX_METADATA_VALUE_BYTES),
286 )
287 })
288 .collect()
289}
290
/// Returns an owned copy of `value` cut to at most `max_bytes` bytes.
///
/// If the byte limit falls inside a multi-byte character, the cut moves back to
/// the previous char boundary, so the result is always valid UTF-8. Inputs that
/// already fit are copied unchanged.
fn truncate_to_char_boundary(value: &str, max_bytes: usize) -> String {
    if max_bytes >= value.len() {
        return value.to_owned();
    }

    // Index 0 is always a char boundary, so the search cannot fail.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| value.is_char_boundary(idx))
        .unwrap_or(0);
    value[..cut].to_owned()
}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_hit_and_miss() {
        let cache = VerificationCache::new(Duration::from_secs(60));

        assert!(cache.get("cred1", "detector1").is_none());

        cache.put(
            "cred1",
            "detector1",
            VerificationResult::Live,
            HashMap::from([("user".into(), "alice".into())]),
        );

        let (result, metadata) = cache.get("cred1", "detector1").unwrap();
        assert!(matches!(result, VerificationResult::Live));
        assert_eq!(metadata["user"], "alice");
        assert!(cache.get("cred1", "detector2").is_none());
    }

    #[test]
    fn cache_ttl_expiry() {
        let cache = VerificationCache::new(Duration::from_millis(1));
        cache.put("cred", "det", VerificationResult::Dead, HashMap::new());
        std::thread::sleep(Duration::from_millis(2));
        assert!(cache.get("cred", "det").is_none());
    }

    #[test]
    fn evict_expired() {
        let cache = VerificationCache::new(Duration::from_millis(1));
        cache.put("cred", "det", VerificationResult::Dead, HashMap::new());
        std::thread::sleep(Duration::from_millis(2));
        cache.evict_expired();
        assert!(cache.is_empty());
    }

    #[test]
    fn evicts_oldest_entry_when_cache_hits_capacity() {
        let cache = VerificationCache::with_max_entries(Duration::from_secs(60), 2);
        cache.put("cred1", "det", VerificationResult::Dead, HashMap::new());
        std::thread::sleep(Duration::from_millis(1));
        cache.put("cred2", "det", VerificationResult::Dead, HashMap::new());
        std::thread::sleep(Duration::from_millis(1));
        cache.put("cred3", "det", VerificationResult::Dead, HashMap::new());

        assert!(cache.get("cred1", "det").is_none());
        assert!(cache.get("cred2", "det").is_some());
        assert!(cache.get("cred3", "det").is_some());
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn long_detector_ids_do_not_collide_after_truncation() {
        let cache = VerificationCache::new(Duration::from_secs(60));
        let shared_prefix = "x".repeat(VerificationCache::MAX_DETECTOR_ID_BYTES);
        let detector_a = format!("{shared_prefix}alpha");
        let detector_b = format!("{shared_prefix}beta");

        cache.put(
            "cred",
            &detector_a,
            VerificationResult::Live,
            HashMap::from([("source".into(), "a".into())]),
        );
        cache.put(
            "cred",
            &detector_b,
            VerificationResult::Dead,
            HashMap::from([("source".into(), "b".into())]),
        );

        let (result_a, metadata_a) = cache.get("cred", &detector_a).unwrap();
        let (result_b, metadata_b) = cache.get("cred", &detector_b).unwrap();
        assert!(matches!(result_a, VerificationResult::Live));
        assert!(matches!(result_b, VerificationResult::Dead));
        assert_eq!(metadata_a["source"], "a");
        assert_eq!(metadata_b["source"], "b");
    }

    #[test]
    fn put_replaces_existing_entry() {
        let cache = VerificationCache::new(Duration::from_secs(60));
        cache.put("cred", "det", VerificationResult::Live, HashMap::new());
        cache.put("cred", "det", VerificationResult::Dead, HashMap::new());

        let (result, _) = cache.get("cred", "det").unwrap();
        assert!(matches!(result, VerificationResult::Dead));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn zero_capacity_is_clamped_to_one() {
        let cache = VerificationCache::with_max_entries(Duration::from_secs(60), 0);
        cache.put("cred1", "det", VerificationResult::Dead, HashMap::new());
        // Distinct timestamps so the earliest-expiry eviction is unambiguous.
        std::thread::sleep(Duration::from_millis(1));
        cache.put("cred2", "det", VerificationResult::Dead, HashMap::new());

        assert_eq!(cache.len(), 1);
        assert!(cache.get("cred2", "det").is_some());
    }

    #[test]
    fn metadata_entry_count_and_values_are_bounded() {
        let cache = VerificationCache::new(Duration::from_secs(60));
        let oversized_value = "v".repeat(VerificationCache::MAX_METADATA_VALUE_BYTES + 10);
        let metadata: HashMap<String, String> = (0..VerificationCache::MAX_METADATA_ENTRIES + 4)
            .map(|i| (format!("key{i:02}"), oversized_value.clone()))
            .collect();

        cache.put("cred", "det", VerificationResult::Live, metadata);

        let (_, stored) = cache.get("cred", "det").unwrap();
        assert_eq!(stored.len(), VerificationCache::MAX_METADATA_ENTRIES);
        assert!(stored
            .values()
            .all(|value| value.len() == VerificationCache::MAX_METADATA_VALUE_BYTES));
    }

    #[test]
    fn long_metadata_keys_are_truncated() {
        let cache = VerificationCache::new(Duration::from_secs(60));
        let long_key = "k".repeat(VerificationCache::MAX_METADATA_KEY_BYTES + 1);

        cache.put(
            "cred",
            "det",
            VerificationResult::Live,
            HashMap::from([(long_key, "v".to_string())]),
        );

        let (_, stored) = cache.get("cred", "det").unwrap();
        assert_eq!(stored.len(), 1);
        assert!(stored
            .keys()
            .all(|key| key.len() == VerificationCache::MAX_METADATA_KEY_BYTES));
    }

    #[test]
    fn truncation_never_splits_a_char() {
        // 'é' is two bytes wide, so a 2-byte limit must back off to the 1-byte boundary.
        assert_eq!(truncate_to_char_boundary("héllo", 2), "h");
        assert_eq!(truncate_to_char_boundary("héllo", 3), "hé");
        assert_eq!(truncate_to_char_boundary("short", 64), "short");
    }
}