1use std::collections::HashMap;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use keyhog_core::VerificationResult;
14use sha2::{Digest, Sha256};
15
16pub struct VerificationCache {
28 entries: DashMap<CacheKey, CacheEntry>,
29 inserts: AtomicUsize,
30 max_entries: usize,
31 ttl: Duration,
32}
33
34#[derive(Hash, Eq, PartialEq, Clone)]
35struct CacheKey {
36 credential_hash: [u8; VerificationCache::HASH_BYTES],
37 detector_id_hash: [u8; VerificationCache::HASH_BYTES],
38 detector_id: Arc<str>,
39}
40
41struct CacheEntry {
42 result: VerificationResult,
43 metadata: HashMap<String, String>,
44 expires_at: Instant,
45}
46
47impl VerificationCache {
48 const DEFAULT_TTL_SECS: u64 = 300;
49 const DEFAULT_MAX_ENTRIES: usize = 10_000;
50 const EVICTION_INTERVAL: usize = 64;
51 pub(crate) const HASH_BYTES: usize = 32;
52 const MAX_DETECTOR_ID_BYTES: usize = 128;
53 const MAX_METADATA_ENTRIES: usize = 16;
54 const MAX_METADATA_KEY_BYTES: usize = 64;
55 const MAX_METADATA_VALUE_BYTES: usize = 256;
56
57 pub fn new(ttl: Duration) -> Self {
69 Self::with_max_entries(ttl, Self::DEFAULT_MAX_ENTRIES)
70 }
71
72 pub fn with_max_entries(ttl: Duration, max_entries: usize) -> Self {
84 Self {
85 entries: DashMap::new(),
86 inserts: AtomicUsize::new(0),
87 max_entries: max_entries.max(1),
88 ttl,
89 }
90 }
91
92 pub fn default_ttl() -> Self {
103 Self::new(Duration::from_secs(Self::DEFAULT_TTL_SECS))
104 }
105
106 pub fn get(
121 &self,
122 credential: &str,
123 detector_id: &str,
124 ) -> Option<(VerificationResult, HashMap<String, String>)> {
125 let key = cache_key(credential, detector_id);
126
127 let now = Instant::now();
128 match self.entries.entry(key) {
129 dashmap::mapref::entry::Entry::Occupied(entry) => {
130 let (result, metadata, expires_at) = {
131 let entry = entry.get();
132 (
133 entry.result.clone(),
134 entry.metadata.clone(),
135 entry.expires_at,
136 )
137 };
138 if now >= expires_at {
139 entry.remove();
140 None
141 } else {
142 Some((result, metadata))
143 }
144 }
145 dashmap::mapref::entry::Entry::Vacant(_) => None,
146 }
147 }
148
149 pub fn put(
164 &self,
165 credential: &str,
166 detector_id: &str,
167 result: VerificationResult,
168 metadata: HashMap<String, String>,
169 ) {
170 let key = cache_key(credential, detector_id);
171
172 let insert_count = self.inserts.fetch_add(1, Ordering::Relaxed) + 1;
173 if insert_count.is_multiple_of(Self::EVICTION_INTERVAL) {
174 self.evict_expired();
178 }
179
180 self.entries.insert(
181 key,
182 CacheEntry {
183 result,
184 metadata: sanitize_metadata(metadata),
185 expires_at: Instant::now() + self.ttl,
186 },
187 );
188
189 if self.entries.len() > self.max_entries {
190 self.evict_one_oldest();
191 }
192 }
193
194 pub fn len(&self) -> usize {
206 self.entries.len()
207 }
208
209 pub fn is_empty(&self) -> bool {
221 self.entries.is_empty()
222 }
223
224 pub fn evict_expired(&self) {
237 let now = Instant::now();
238 self.entries.retain(|_, entry| now < entry.expires_at);
241 }
242
243 fn evict_one_oldest(&self) {
244 let oldest_key = self
248 .entries
249 .iter()
250 .min_by_key(|entry| entry.expires_at)
251 .map(|entry| entry.key().clone());
252
253 if let Some(key) = oldest_key {
254 self.entries.remove(&key);
255 }
256 }
257}
258
259fn hash_credential(credential: &str) -> [u8; VerificationCache::HASH_BYTES] {
260 Sha256::digest(credential.as_bytes()).into()
261}
262
263fn cache_key(credential: &str, detector_id: &str) -> CacheKey {
264 CacheKey {
265 credential_hash: hash_credential(credential),
266 detector_id_hash: hash_credential(detector_id),
267 detector_id: Arc::<str>::from(truncate_to_char_boundary(
268 detector_id,
269 VerificationCache::MAX_DETECTOR_ID_BYTES,
270 )),
271 }
272}
273
274fn sanitize_metadata(metadata: HashMap<String, String>) -> HashMap<String, String> {
275 metadata
276 .into_iter()
277 .take(VerificationCache::MAX_METADATA_ENTRIES)
278 .map(|(key, value)| {
279 (
280 truncate_to_char_boundary(&key, VerificationCache::MAX_METADATA_KEY_BYTES),
281 truncate_to_char_boundary(&value, VerificationCache::MAX_METADATA_VALUE_BYTES),
282 )
283 })
284 .collect()
285}
286
/// Returns an owned copy of `value` cut to at most `max_bytes` bytes.
///
/// If `max_bytes` falls inside a multi-byte character, the cut moves back to
/// the previous UTF-8 character boundary. The result is therefore always valid
/// UTF-8 and may be shorter than `max_bytes`.
fn truncate_to_char_boundary(value: &str, max_bytes: usize) -> String {
    if value.len() > max_bytes {
        // Offset 0 is always a character boundary, so the search always
        // finds one.
        let cut = (0..=max_bytes)
            .rev()
            .find(|&offset| value.is_char_boundary(offset))
            .unwrap_or(0);
        value[..cut].to_owned()
    } else {
        value.to_owned()
    }
}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_hit_and_miss() {
        let cache = VerificationCache::new(Duration::from_secs(60));

        assert!(cache.get("cred1", "detector1").is_none());

        cache.put(
            "cred1",
            "detector1",
            VerificationResult::Live,
            HashMap::from([("user".into(), "alice".into())]),
        );

        let (result, metadata) = cache.get("cred1", "detector1").unwrap();
        assert!(matches!(result, VerificationResult::Live));
        assert_eq!(metadata["user"], "alice");
        assert!(cache.get("cred1", "detector2").is_none());
    }

    #[test]
    fn cache_ttl_expiry() {
        let cache = VerificationCache::new(Duration::from_millis(1));
        cache.put("cred", "det", VerificationResult::Dead, HashMap::new());
        std::thread::sleep(Duration::from_millis(2));
        assert!(cache.get("cred", "det").is_none());
        // A lookup that finds an expired entry also removes it.
        assert!(cache.is_empty());
    }

    #[test]
    fn evict_expired() {
        let cache = VerificationCache::new(Duration::from_millis(1));
        cache.put("cred", "det", VerificationResult::Dead, HashMap::new());
        std::thread::sleep(Duration::from_millis(2));
        cache.evict_expired();
        assert!(cache.is_empty());
    }

    #[test]
    fn evicts_oldest_entry_when_cache_hits_capacity() {
        let cache = VerificationCache::with_max_entries(Duration::from_secs(60), 2);
        cache.put("cred1", "det", VerificationResult::Dead, HashMap::new());
        std::thread::sleep(Duration::from_millis(1));
        cache.put("cred2", "det", VerificationResult::Dead, HashMap::new());
        std::thread::sleep(Duration::from_millis(1));
        cache.put("cred3", "det", VerificationResult::Dead, HashMap::new());

        assert!(cache.get("cred1", "det").is_none());
        assert!(cache.get("cred2", "det").is_some());
        assert!(cache.get("cred3", "det").is_some());
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn long_detector_ids_do_not_collide_after_truncation() {
        let cache = VerificationCache::new(Duration::from_secs(60));
        let shared_prefix = "x".repeat(VerificationCache::MAX_DETECTOR_ID_BYTES);
        let detector_a = format!("{shared_prefix}alpha");
        let detector_b = format!("{shared_prefix}beta");

        cache.put(
            "cred",
            &detector_a,
            VerificationResult::Live,
            HashMap::from([("source".into(), "a".into())]),
        );
        cache.put(
            "cred",
            &detector_b,
            VerificationResult::Dead,
            HashMap::from([("source".into(), "b".into())]),
        );

        let (result_a, metadata_a) = cache.get("cred", &detector_a).unwrap();
        let (result_b, metadata_b) = cache.get("cred", &detector_b).unwrap();
        assert!(matches!(result_a, VerificationResult::Live));
        assert!(matches!(result_b, VerificationResult::Dead));
        assert_eq!(metadata_a["source"], "a");
        assert_eq!(metadata_b["source"], "b");
    }

    #[test]
    fn metadata_is_bounded_before_caching() {
        let cache = VerificationCache::new(Duration::from_secs(60));
        // 40 keys with distinct 3-byte prefixes, each longer than the key cap,
        // paired with values longer than the value cap.
        let oversized: HashMap<String, String> = (0..40)
            .map(|i| (format!("k{i:02}{}", "x".repeat(100)), "v".repeat(1_000)))
            .collect();
        cache.put("cred", "det", VerificationResult::Live, oversized);

        let (_, metadata) = cache.get("cred", "det").unwrap();
        assert_eq!(metadata.len(), VerificationCache::MAX_METADATA_ENTRIES);
        for (key, value) in &metadata {
            assert_eq!(key.len(), VerificationCache::MAX_METADATA_KEY_BYTES);
            assert_eq!(value.len(), VerificationCache::MAX_METADATA_VALUE_BYTES);
        }
    }

    #[test]
    fn truncation_respects_utf8_boundaries() {
        // 'é' occupies bytes 1..3, so a 2-byte cut must back off to byte 1.
        assert_eq!(truncate_to_char_boundary("héllo", 2), "h");
        assert_eq!(truncate_to_char_boundary("héllo", 3), "hé");
        assert_eq!(truncate_to_char_boundary("abc", 3), "abc");
        assert_eq!(truncate_to_char_boundary("abc", 0), "");
    }
}