use std::sync::Arc;
use std::time::Duration;
use moka::notification::RemovalCause;
use moka::sync::Cache;
use crate::audit::{AuditEvent, AuditSink, CacheEvent};
use crate::hardening::HardeningToken;
use crate::SecretString;
/// Identifies one cached secret: the backend `scheme` (e.g. `"env"`, `"file"`)
/// together with a backend-specific `identity`.
///
/// Because the scheme is part of the key, the same identity resolved by two
/// different backends occupies two distinct cache slots.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    /// Name of the backend that resolved the secret.
    pub scheme: &'static str,
    /// Backend-specific lookup key (variable name, path, ...).
    pub identity: String,
}

impl CacheKey {
    /// Creates a key from a backend scheme and any string-like identity.
    pub fn new(scheme: &'static str, identity: impl Into<String>) -> Self {
        let identity: String = identity.into();
        Self { identity, scheme }
    }
}
/// Selects whether, and how, resolved secrets are cached.
#[derive(Debug, Clone, Default)]
pub enum CachePolicy {
/// No caching; this is the `Default`.
#[default]
Disabled,
/// In-process cache bounded by time-to-live and entry count.
Process {
/// How long an entry remains valid after it is inserted.
ttl: Duration,
/// Maximum number of entries held at once.
capacity: u64,
},
/// Persistent-cache settings (requires the `cache-persistent` feature).
/// `ProcessCache::new` currently serves this policy with an in-process cache
/// using the same TTL and capacity.
#[cfg(feature = "cache-persistent")]
Persistent(PersistentPolicy),
}
/// Configuration for the `cache-persistent` cache mode.
///
/// Prefer `PersistentPolicy::defaults()` plus `with_ttl`; assigning `ttl`
/// directly skips the clamp applied in `with_ttl`.
#[cfg(feature = "cache-persistent")]
#[derive(Debug, Clone)]
pub struct PersistentPolicy {
/// Entry lifetime; `with_ttl` caps it at `MAX_TTL`.
pub ttl: Duration,
/// Location of the on-disk cache file (defaults to `<cache dir>/hasp/cache.bin`).
pub path: std::path::PathBuf,
/// Keyring service name (defaults to `"hasp"`).
// NOTE(review): what the keyring entry stores is not visible here — confirm.
pub keyring_service: String,
/// Keyring account name (defaults to `"cache:<user name>"`).
pub keyring_account: String,
/// Maximum number of cached entries.
pub capacity: u64,
}
#[cfg(feature = "cache-persistent")]
impl PersistentPolicy {
    /// Largest TTL accepted by [`Self::with_ttl`]: one hour.
    pub const MAX_TTL: Duration = Duration::from_secs(3600);
    /// TTL used by [`Self::defaults`]: five minutes.
    pub const DEFAULT_TTL: Duration = Duration::from_secs(300);

    /// Builds the default policy rooted in the platform cache directory.
    ///
    /// Returns `None` when `dirs::cache_dir()` cannot determine that directory.
    pub fn defaults() -> Option<Self> {
        let hasp_dir = dirs::cache_dir().map(|base| base.join("hasp"))?;
        let account = format!("cache:{}", whoami_or_unknown());
        Some(Self {
            ttl: Self::DEFAULT_TTL,
            path: hasp_dir.join("cache.bin"),
            keyring_service: String::from("hasp"),
            keyring_account: account,
            capacity: 1024,
        })
    }

    /// Returns `self` with its TTL replaced by `ttl`, clamped to at most
    /// [`Self::MAX_TTL`].
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self {
            ttl: ttl.min(Self::MAX_TTL),
            ..self
        }
    }
}
/// Returns the current user name from `USER`, falling back to `USERNAME`, and
/// finally to the literal `"unknown"` when neither variable is readable.
#[cfg(feature = "cache-persistent")]
fn whoami_or_unknown() -> String {
    ["USER", "USERNAME"]
        .iter()
        .find_map(|name| std::env::var(name).ok())
        .unwrap_or_else(|| String::from("unknown"))
}
impl CachePolicy {
pub fn process_default() -> Self {
Self::Process {
ttl: Duration::from_secs(300),
capacity: 1024,
}
}
}
/// In-process cache of resolved secrets keyed by [`CacheKey`], constructed
/// from a [`CachePolicy`] with `ProcessCache::new`.
///
/// Clones share the same underlying `moka::sync::Cache` storage.
#[derive(Clone)]
pub struct ProcessCache {
// Values are `Arc`-wrapped so a cache hit hands out a shared reference
// rather than a fresh copy of the secret.
inner: Cache<CacheKey, Arc<SecretString>>,
}
impl ProcessCache {
    /// Builds a cache for `policy`, or returns `None` when caching is disabled.
    ///
    /// * `CachePolicy::Disabled` yields `None`.
    /// * `CachePolicy::Process { ttl, capacity }` yields a cache holding at most
    ///   `capacity` entries, each expiring `ttl` after insertion.
    /// * `CachePolicy::Persistent` (feature `cache-persistent`) is currently
    ///   served by the same in-process cache, using the policy's capacity and
    ///   its TTL capped at `PersistentPolicy::MAX_TTL`.
    ///
    /// `_token` is consumed but not inspected. When `audit_sink` is `Some`, a
    /// `CacheEvent::Expire` audit event tagged with the key's scheme is emitted
    /// for each entry removed because its TTL elapsed; other removal causes
    /// (size eviction, explicit invalidation) are not audited.
    pub fn new(
        policy: &CachePolicy,
        _token: HardeningToken,
        audit_sink: Option<Arc<dyn AuditSink>>,
    ) -> Option<Self> {
        match policy {
            #[cfg(feature = "cache-persistent")]
            CachePolicy::Persistent(p) => {
                // `PersistentPolicy` fields are public, so a TTL assigned without
                // going through `with_ttl` could exceed the cap; re-apply it here.
                let process = CachePolicy::Process {
                    ttl: p.ttl.min(PersistentPolicy::MAX_TTL),
                    capacity: p.capacity,
                };
                Self::new(&process, _token, audit_sink)
            }
            CachePolicy::Disabled => None,
            CachePolicy::Process { ttl, capacity } => {
                let inner = Cache::builder()
                    .max_capacity(*capacity)
                    .time_to_live(*ttl)
                    // `audit_sink` is owned and unused afterwards, so move it
                    // straight into the listener instead of cloning it.
                    .eviction_listener(move |k: Arc<CacheKey>, v, cause| {
                        // Drop this reference to the evicted secret before
                        // running the audit sink.
                        drop(v);
                        if matches!(cause, RemovalCause::Expired) {
                            if let Some(sink) = &audit_sink {
                                sink.emit(&AuditEvent::cache(CacheEvent::Expire, k.scheme));
                            }
                        }
                    })
                    .build();
                Some(Self { inner })
            }
        }
    }

    /// Returns the cached secret for `key`, or `None` if it is absent or expired.
    pub fn get(&self, key: &CacheKey) -> Option<Arc<SecretString>> {
        self.inner.get(key)
    }

    /// Stores `value` under `key`, replacing any existing entry.
    pub fn insert(&self, key: CacheKey, value: Arc<SecretString>) {
        self.inner.insert(key, value);
    }

    /// Discards the entry for `key`, if any.
    pub fn invalidate(&self, key: &CacheKey) {
        self.inner.invalidate(key);
    }

    /// Discards every entry inserted before this call.
    pub fn invalidate_all(&self) {
        self.inner.invalidate_all();
    }

    /// Runs moka's pending maintenance (expiration, eviction, eviction-listener
    /// calls) synchronously.
    pub fn run_pending_tasks(&self) {
        self.inner.run_pending_tasks();
    }

    /// Returns moka's entry count. It is approximate and may not reflect recent
    /// changes until `run_pending_tasks` has run.
    pub fn entry_count(&self) -> u64 {
        self.inner.entry_count()
    }
}
// Only the entry count is shown; cached secret values are never formatted.
impl std::fmt::Debug for ProcessCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entries = self.inner.entry_count();
        let mut builder = f.debug_struct("ProcessCache");
        builder.field("entry_count", &entries);
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hardening;
    use secrecy::ExposeSecret;

    /// Obtains a hardening token, panicking if installation fails.
    fn token() -> HardeningToken {
        hardening::install().expect("hardening install should succeed in tests")
    }

    /// Wraps `value` the way callers hand secrets to the cache.
    fn secret(value: &str) -> Arc<SecretString> {
        Arc::new(SecretString::new(value.to_string().into()))
    }

    #[test]
    fn disabled_policy_returns_none() {
        let policy = CachePolicy::Disabled;
        assert!(ProcessCache::new(&policy, token(), None).is_none());
    }

    #[test]
    fn process_policy_returns_some() {
        let policy = CachePolicy::process_default();
        assert!(ProcessCache::new(&policy, token(), None).is_some());
    }

    #[test]
    fn insert_then_get_returns_same_secret_bytes() {
        let cache = ProcessCache::new(&CachePolicy::process_default(), token(), None).unwrap();
        let key = CacheKey::new("env", "USER");
        let value = secret("alice");
        cache.insert(key.clone(), Arc::clone(&value));
        let got = cache.get(&key).expect("cache hit");
        assert_eq!(got.expose_secret(), "alice");
        // A hit hands back the same allocation, not a copy of the secret.
        assert!(Arc::ptr_eq(&got, &value));
    }

    #[test]
    fn invalidate_removes_entry() {
        let cache = ProcessCache::new(&CachePolicy::process_default(), token(), None).unwrap();
        let key = CacheKey::new("env", "USER");
        cache.insert(key.clone(), secret("v"));
        cache.invalidate(&key);
        cache.run_pending_tasks();
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn ttl_expiry_returns_none() {
        let policy = CachePolicy::Process {
            ttl: Duration::from_millis(50),
            capacity: 16,
        };
        let cache = ProcessCache::new(&policy, token(), None).unwrap();
        let key = CacheKey::new("env", "USER");
        cache.insert(key.clone(), secret("v"));
        // Sleep comfortably past the 50 ms TTL.
        std::thread::sleep(Duration::from_millis(120));
        cache.run_pending_tasks();
        assert!(cache.get(&key).is_none());
    }

    /// Audit sink that records the `event` field of every event it receives.
    #[derive(Default)]
    struct TestSink {
        events: std::sync::Mutex<Vec<String>>,
    }

    impl AuditSink for TestSink {
        fn emit(&self, event: &AuditEvent) {
            if let Ok(mut v) = self.events.lock() {
                v.push(event.event.to_string());
            }
        }
    }

    #[test]
    fn ttl_expiry_emits_cache_expire_event() {
        let sink: Arc<TestSink> = Arc::new(TestSink::default());
        let policy = CachePolicy::Process {
            ttl: Duration::from_millis(50),
            capacity: 16,
        };
        let cache = ProcessCache::new(&policy, token(), Some(sink.clone())).unwrap();
        let key = CacheKey::new("env", "EXPIRE_TEST");
        cache.insert(key.clone(), secret("v"));
        std::thread::sleep(Duration::from_millis(120));
        let _ = cache.get(&key);
        cache.run_pending_tasks();
        let events = sink.events.lock().unwrap().clone();
        assert!(
            events.iter().any(|e| e == "cache.expire"),
            "expected a cache.expire event, got {events:?}"
        );
    }

    #[test]
    fn capacity_eviction_bounds_entry_count() {
        let policy = CachePolicy::Process {
            ttl: Duration::from_secs(60),
            capacity: 2,
        };
        let cache = ProcessCache::new(&policy, token(), None).unwrap();
        for (identity, value) in [("A", "a"), ("B", "b"), ("C", "c")] {
            cache.insert(CacheKey::new("env", identity), secret(value));
        }
        cache.run_pending_tasks();
        // moka's default TinyLFU admission policy decides which entry goes
        // (possibly the newcomer), so only the capacity bound is asserted,
        // not that the oldest entry was evicted.
        assert!(cache.entry_count() <= 2);
    }

    #[test]
    fn scheme_namespacing_prevents_cross_backend_alias() {
        let cache = ProcessCache::new(&CachePolicy::process_default(), token(), None).unwrap();
        let k1 = CacheKey::new("env", "DUP");
        let k2 = CacheKey::new("file", "DUP");
        cache.insert(k1.clone(), secret("env-value"));
        cache.insert(k2.clone(), secret("file-value"));
        assert_eq!(cache.get(&k1).unwrap().expose_secret(), "env-value");
        assert_eq!(cache.get(&k2).unwrap().expose_secret(), "file-value");
    }
}