use std::collections::HashMap;
use std::sync::Mutex;
use crate::AttestationVerdict;
/// In-memory cache of [`AttestationVerdict`]s with a per-entry expiry.
///
/// Entries are keyed by the verdict's `model` string. The cache never reads a
/// clock itself: callers pass the current time (`now_unix`) into every
/// operation, and the same time base must be used for inserts and lookups
/// (presumably Unix seconds, going by the parameter names).
///
/// All state sits behind a [`Mutex`], so the cache is used through `&self`.
#[derive(Default)]
pub struct AttestationCache {
// Model string -> cached verdict plus its absolute expiry time.
inner: Mutex<HashMap<String, Entry>>,
}
/// A single cached verdict together with the moment it stops being valid.
struct Entry {
// The verdict handed back (cloned) on a cache hit.
verdict: AttestationVerdict,
// Absolute expiry: the entry is served only while `now_unix` is strictly
// less than this value.
expires_at_unix: u64,
}
impl AttestationCache {
pub fn new() -> Self {
Self::default()
}
pub fn get(&self, model: &str, now_unix: u64) -> Option<AttestationVerdict> {
let map = self.inner.lock().ok()?;
let entry = map.get(model)?;
(now_unix < entry.expires_at_unix).then(|| entry.verdict.clone())
}
pub fn put(&self, verdict: AttestationVerdict, ttl_seconds: u64, now_unix: u64) {
if let Ok(mut map) = self.inner.lock() {
map.insert(
verdict.model.clone(),
Entry {
expires_at_unix: now_unix.saturating_add(ttl_seconds),
verdict,
},
);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unverified verdict for `model` with a nonce derived from the
    /// model name.
    fn sample_verdict(model: &str) -> AttestationVerdict {
        let nonce = format!("test-nonce-{model}");
        AttestationVerdict::unverified(model, nonce, 0)
    }

    /// Inserted at t=1000 with a 600s TTL: hits up to t=1599, misses from
    /// t=1600 onwards.
    #[test]
    fn returns_a_cached_verdict_within_ttl_and_drops_it_after() {
        let cache = AttestationCache::new();
        cache.put(sample_verdict("m"), 600, 1_000);
        let lookup_at = |now: u64| cache.get("m", now);

        let at_insert = lookup_at(1_000);
        assert!(at_insert.is_some());

        let last_fresh_second = lookup_at(1_599);
        assert!(
            last_fresh_second.is_some(),
            "still fresh just before expiry"
        );

        let at_expiry = lookup_at(1_600);
        assert!(at_expiry.is_none(), "expired at exactly ttl");

        let long_after = lookup_at(2_000);
        assert!(long_after.is_none());
    }

    /// A model that was never inserted is a plain miss.
    #[test]
    fn miss_for_an_unknown_model() {
        let cache = AttestationCache::new();
        let lookup = cache.get("never-cached", 0);
        assert!(lookup.is_none());
    }
}