use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use keyhog_core::VerificationResult;
use sha2::{Digest, Sha256};
/// Concurrent, TTL- and size-bounded cache of credential verification results.
///
/// Entries are keyed by SHA-256 digests of the credential and detector id.
/// Each entry expires `ttl` after it was last written, and the cache is kept
/// at `max_entries`. Every method takes `&self`; shared state lives in a
/// `DashMap` plus an atomic counter.
pub struct VerificationCache {
    /// Lifetime assigned to an entry when it is written.
    ttl: Duration,
    /// Upper bound on the number of stored entries, enforced after inserts.
    max_entries: usize,
    /// Running count of `put` calls, used to schedule periodic expiry sweeps.
    inserts: AtomicUsize,
    /// Stored entries, including expired ones that have not been swept yet.
    entries: DashMap<CacheKey, CacheEntry>,
}
/// Lookup key for one cached verification.
///
/// Holds SHA-256 digests of the credential and of the detector id, so the raw
/// credential is not part of the key. It also holds the detector id itself,
/// which `cache_key` truncates to a bounded length. Every field participates
/// in the derived `Hash`/`Eq`.
#[derive(Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    credential_hash: [u8; VerificationCache::HASH_BYTES],
    detector_id_hash: [u8; VerificationCache::HASH_BYTES],
    detector_id: Arc<str>,
}
/// A cached verification outcome together with its stored metadata.
struct CacheEntry {
    /// Instant from which the entry is treated as expired (`now >= expires_at`).
    expires_at: Instant,
    /// The verification outcome returned to callers on a hit.
    result: VerificationResult,
    /// Metadata as stored by `put`, i.e. after size sanitization.
    metadata: HashMap<String, String>,
}
impl VerificationCache {
    /// TTL used by `default_ttl`, in seconds.
    const DEFAULT_TTL_SECS: u64 = 300;
    /// Capacity used by `new` and `default_ttl`.
    const DEFAULT_MAX_ENTRIES: usize = 10_000;
    /// A full expiry sweep runs on every `EVICTION_INTERVAL`-th `put`.
    const EVICTION_INTERVAL: usize = 64;
    /// Length of a SHA-256 digest in bytes.
    pub(crate) const HASH_BYTES: usize = 32;
    /// Bytes of the detector id kept verbatim in each cache key.
    const MAX_DETECTOR_ID_BYTES: usize = 128;
    /// Maximum number of metadata pairs stored per entry.
    const MAX_METADATA_ENTRIES: usize = 16;
    /// Maximum byte length of a stored metadata key.
    const MAX_METADATA_KEY_BYTES: usize = 64;
    /// Maximum byte length of a stored metadata value.
    const MAX_METADATA_VALUE_BYTES: usize = 256;
    /// Largest TTL actually applied (roughly 100 years).
    ///
    /// `Instant + Duration` panics when the sum cannot be represented. Without
    /// this clamp, a caller passing `Duration::MAX` (a natural way to ask for
    /// "never expire") would make every `put` panic.
    const MAX_TTL: Duration = Duration::from_secs(100 * 365 * 24 * 60 * 60);

    /// Creates a cache with the given `ttl` and the default capacity
    /// (`DEFAULT_MAX_ENTRIES`).
    pub fn new(ttl: Duration) -> Self {
        Self::with_max_entries(ttl, Self::DEFAULT_MAX_ENTRIES)
    }

    /// Creates a cache with the given `ttl` and capacity.
    ///
    /// `max_entries` is raised to at least 1, and `ttl` is clamped to
    /// `MAX_TTL` so that computing an entry's expiry can never overflow.
    pub fn with_max_entries(ttl: Duration, max_entries: usize) -> Self {
        Self {
            entries: DashMap::new(),
            inserts: AtomicUsize::new(0),
            max_entries: max_entries.max(1),
            ttl: ttl.min(Self::MAX_TTL),
        }
    }

    /// Creates a cache with a 300-second TTL and the default capacity.
    pub fn default_ttl() -> Self {
        Self::new(Duration::from_secs(Self::DEFAULT_TTL_SECS))
    }

    /// Looks up the cached result and metadata for `credential` under
    /// `detector_id`.
    ///
    /// On a hit with an unexpired entry, returns clones of the stored values.
    /// If the entry is expired, it is removed and `None` is returned.
    pub fn get(
        &self,
        credential: &str,
        detector_id: &str,
    ) -> Option<(VerificationResult, HashMap<String, String>)> {
        let key = cache_key(credential, detector_id);
        let now = Instant::now();
        // Fast path under a shared guard. The guard is dropped at the end of
        // this `if let` statement, before `entry` below is called. DashMap
        // documents that `entry` may deadlock while a reference into the map
        // is held.
        if let Some(entry) = self.entries.get(&key) {
            if now < entry.expires_at {
                return Some((entry.result.clone(), entry.metadata.clone()));
            }
        } else {
            return None;
        }
        // The entry looked expired. Re-check through the entry API, because
        // another thread may have refreshed or removed it in the meantime.
        if let dashmap::mapref::entry::Entry::Occupied(entry) = self.entries.entry(key) {
            if now >= entry.get().expires_at {
                entry.remove();
            } else {
                let entry = entry.get();
                return Some((entry.result.clone(), entry.metadata.clone()));
            }
        }
        None
    }

    /// Stores `result` and a size-bounded copy of `metadata` for `credential`
    /// under `detector_id`, replacing any previous entry for that pair.
    ///
    /// Every `EVICTION_INTERVAL`-th call first sweeps expired entries. After
    /// the insert, entries with the earliest expiry are dropped until the
    /// cache is back within `max_entries`.
    pub fn put(
        &self,
        credential: &str,
        detector_id: &str,
        result: VerificationResult,
        metadata: HashMap<String, String>,
    ) {
        let key = cache_key(credential, detector_id);
        // `fetch_add` itself wraps on overflow. A plain `+ 1` on its result,
        // however, would panic in debug builds once the counter reaches
        // usize::MAX.
        let insert_count = self
            .inserts
            .fetch_add(1, Ordering::Relaxed)
            .wrapping_add(1);
        if insert_count.is_multiple_of(Self::EVICTION_INTERVAL) {
            self.evict_expired();
        }
        self.entries.insert(
            key,
            CacheEntry {
                result,
                metadata: sanitize_metadata(metadata),
                expires_at: Instant::now() + self.ttl,
            },
        );
        // Loop rather than evicting once. Concurrent `put`s can pick the same
        // oldest key, in which case only one removal succeeds; a single
        // eviction would then leave the cache above its bound.
        while self.entries.len() > self.max_entries {
            if !self.evict_one_oldest() {
                break;
            }
        }
    }

    /// Number of stored entries, including expired entries not yet swept.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when no entries are stored (expired or not).
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Removes every entry whose expiry instant has been reached.
    pub fn evict_expired(&self) {
        let now = Instant::now();
        self.entries.retain(|_, entry| now < entry.expires_at);
    }

    /// Removes the entry with the earliest `expires_at`.
    ///
    /// With a uniform TTL, this is the least recently written entry. Returns
    /// `false` only when the map had no entry to choose. This is a full O(n)
    /// scan over the map.
    fn evict_one_oldest(&self) -> bool {
        let oldest_key = self
            .entries
            .iter()
            .min_by_key(|entry| entry.value().expires_at)
            .map(|entry| entry.key().clone());
        match oldest_key {
            Some(key) => {
                // Another thread may have removed this key between the scan
                // and now. The caller re-checks `len()`, so that race is
                // harmless.
                self.entries.remove(&key);
                true
            }
            None => false,
        }
    }
}
/// Returns the SHA-256 digest of `credential`'s UTF-8 bytes.
///
/// `cache_key` also uses this for detector ids.
fn hash_credential(credential: &str) -> [u8; VerificationCache::HASH_BYTES] {
    let mut hasher = Sha256::new();
    hasher.update(credential.as_bytes());
    hasher.finalize().into()
}
/// Builds the map key for a (credential, detector id) pair.
///
/// Both strings are hashed with SHA-256. The detector id is additionally
/// kept in a form truncated to `MAX_DETECTOR_ID_BYTES` at a char boundary.
fn cache_key(credential: &str, detector_id: &str) -> CacheKey {
    let bounded_id =
        truncate_to_char_boundary(detector_id, VerificationCache::MAX_DETECTOR_ID_BYTES);
    let detector_id_hash = hash_credential(detector_id);
    let credential_hash = hash_credential(credential);
    CacheKey {
        credential_hash,
        detector_id_hash,
        detector_id: Arc::<str>::from(bounded_id),
    }
}
/// Bounds `metadata` before it is stored in the cache.
///
/// Keeps at most `MAX_METADATA_ENTRIES` pairs: the ones with the
/// lexicographically smallest keys. Keys and values are truncated (at char
/// boundaries) to `MAX_METADATA_KEY_BYTES` / `MAX_METADATA_VALUE_BYTES`. If
/// two keys become equal after truncation, the pair whose original key sorts
/// last wins.
fn sanitize_metadata(metadata: HashMap<String, String>) -> HashMap<String, String> {
    // HashMap iteration order is arbitrary and differs between map instances.
    // Sorting first makes the set of retained entries, and the winner of any
    // truncated-key collision, deterministic. Keys are unique, so an unstable
    // sort is sufficient.
    let mut entries: Vec<(String, String)> = metadata.into_iter().collect();
    entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    entries
        .into_iter()
        .take(VerificationCache::MAX_METADATA_ENTRIES)
        .map(|(key, value)| {
            (
                truncate_to_char_boundary(&key, VerificationCache::MAX_METADATA_KEY_BYTES),
                truncate_to_char_boundary(&value, VerificationCache::MAX_METADATA_VALUE_BYTES),
            )
        })
        .collect()
}
/// Returns an owned copy of the longest prefix of `value` that is at most
/// `max_bytes` bytes long and ends on a UTF-8 char boundary.
///
/// A `value` that already fits is copied unchanged.
fn truncate_to_char_boundary(value: &str, max_bytes: usize) -> String {
    let limit = max_bytes.min(value.len());
    // Walk back from the limit to the nearest boundary. Index 0 is always a
    // boundary, so the search always succeeds.
    let cut = (0..=limit)
        .rev()
        .find(|&idx| value.is_char_boundary(idx))
        .unwrap_or(0);
    value[..cut].to_owned()
}