use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use devboy_core::Result;
use secrecy::SecretString;
use crate::CredentialStore;
/// One cached secret plus the deadline after which it must be refetched.
struct CachedEntry {
/// The cached secret value; cloned out on every cache hit.
value: SecretString,
/// Deadline computed when the entry is created; the entry is fresh strictly before it.
expires_at: Instant,
}
impl CachedEntry {
    /// Creates an entry holding `value` that stays fresh for `ttl` from now.
    ///
    /// `Instant + Duration` panics on overflow, so an enormous `ttl` (for example
    /// `Duration::MAX`, a natural way to ask for "never expire") is saturated via
    /// `saturating_deadline` instead of panicking.
    fn new(value: SecretString, ttl: Duration) -> Self {
        Self {
            value,
            expires_at: Self::saturating_deadline(Instant::now(), ttl),
        }
    }

    /// Returns `now + ttl`, or, when that overflows `Instant`, `now + ttl / 2^k`
    /// for the smallest `k` that fits (i.e. the largest such offset).
    ///
    /// Always terminates: halving eventually reaches `Duration::ZERO`, and
    /// `now + 0` is always representable.
    fn saturating_deadline(now: Instant, ttl: Duration) -> Instant {
        let mut offset = ttl;
        loop {
            if let Some(deadline) = now.checked_add(offset) {
                return deadline;
            }
            offset /= 2;
        }
    }

    /// Returns `true` while the current time is strictly before `expires_at`.
    fn is_fresh(&self) -> bool {
        Instant::now() < self.expires_at
    }
}
/// Read-through, time-limited in-memory cache in front of another [`CredentialStore`].
///
/// Values successfully fetched from `inner` are served from memory for `ttl`;
/// `store`/`delete` through this wrapper evict the affected key, and missing
/// keys are never cached. A zero `ttl` turns the wrapper into a pass-through.
pub struct CachedStore<S: CredentialStore> {
/// The wrapped store; consulted on every cache miss and for all writes.
inner: S,
/// How long a fetched value is served from memory; `Duration::ZERO` disables caching.
ttl: Duration,
/// Cached entries keyed by credential key, behind a lock so `&self` methods can mutate it.
entries: RwLock<HashMap<String, CachedEntry>>,
}
impl<S: CredentialStore> CachedStore<S> {
    /// Wraps `inner` in a read-through cache whose entries live for `ttl`.
    ///
    /// Passing [`Duration::ZERO`] disables caching: every read goes straight to
    /// `inner` and nothing is kept in memory.
    pub fn new(inner: S, ttl: Duration) -> Self {
        Self {
            inner,
            ttl,
            entries: RwLock::new(HashMap::new()),
        }
    }

    /// Evicts every cached entry so the next read of any key hits `inner`.
    ///
    /// Best-effort: if the lock is poisoned this does nothing, but in that state
    /// cached reads are bypassed as well, so no stale value can be served.
    pub fn invalidate_all(&self) {
        if let Ok(mut entries) = self.entries.write() {
            entries.clear();
        }
    }

    /// Evicts the cached entry for `key`, if any, so the next read of it hits
    /// `inner`. Same best-effort poisoning behaviour as [`Self::invalidate_all`].
    pub fn invalidate(&self, key: &str) {
        if let Ok(mut entries) = self.entries.write() {
            entries.remove(key);
        }
    }

    /// A zero TTL means the wrapper acts as a pure pass-through.
    fn caching_disabled(&self) -> bool {
        self.ttl.is_zero()
    }

    /// Returns a copy of the cached value for `key` if it is present and not yet
    /// expired. A poisoned lock is treated as a miss.
    fn lookup_fresh(&self, key: &str) -> Option<SecretString> {
        let entries = self.entries.read().ok()?;
        entries
            .get(key)
            .filter(|entry| entry.is_fresh())
            .map(|entry| entry.value.clone())
    }

    /// Caches `value` under `key` for `self.ttl`, replacing any previous entry.
    /// Does nothing if the lock is poisoned.
    fn insert(&self, key: &str, value: &SecretString) {
        // Build the entry (secret clone + deadline computation) before taking the
        // write lock: it keeps the exclusive section minimal, and a panic while
        // building the entry can no longer poison the lock, which would otherwise
        // silently disable the cache for the rest of its life.
        let entry = CachedEntry::new(value.clone(), self.ttl);
        let Ok(mut entries) = self.entries.write() else {
            return;
        };
        entries.insert(key.to_owned(), entry);
    }

    /// Removes every entry whose deadline has passed (a full O(n) sweep).
    ///
    /// Despite the `_locked` suffix, this acquires the write lock itself, so it
    /// must not be called while the caller already holds a guard on `entries`.
    fn purge_expired_locked(&self) {
        let Ok(mut entries) = self.entries.write() else {
            return;
        };
        let now = Instant::now();
        entries.retain(|_, entry| entry.expires_at > now);
    }
}
impl<S: CredentialStore> std::fmt::Debug for CachedStore<S> {
    /// Formats the cache without exposing any secret: only the TTL (in whole
    /// seconds) and the current entry count are printed, values are redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A poisoned lock is reported as an empty cache rather than an error.
        let cached_entries = match self.entries.read() {
            Ok(guard) => guard.len(),
            Err(_) => 0,
        };
        let mut out = f.debug_struct("CachedStore");
        out.field("ttl_secs", &self.ttl.as_secs());
        out.field("cached_entries", &cached_entries);
        out.field("values", &"<redacted>");
        out.finish()
    }
}
impl<S: CredentialStore> CredentialStore for CachedStore<S> {
    /// Writes through to the inner store, then evicts `key` from the cache
    /// whether or not the write succeeded, so the next `get` refetches it.
    ///
    /// NOTE(review): a `get` that misses concurrently with this call can read the
    /// old value from `inner` before the write lands and insert it after the
    /// eviction, leaving a stale entry for up to `ttl`. Closing that window needs
    /// extra state (e.g. a per-key generation counter).
    fn store(&self, key: &str, value: &SecretString) -> Result<()> {
        let outcome = self.inner.store(key, value);
        self.invalidate(key);
        outcome
    }

    /// Returns the value for `key`, serving it from memory while fresh.
    ///
    /// On a miss (or stale entry) expired entries are swept, `inner` is asked,
    /// and a found value is cached; `None` results are not cached. Errors from
    /// `inner` are propagated unchanged. With a zero TTL this is a pure
    /// pass-through to `inner`.
    fn get(&self, key: &str) -> Result<Option<SecretString>> {
        if self.caching_disabled() {
            return self.inner.get(key);
        }

        if let hit @ Some(_) = self.lookup_fresh(key) {
            return Ok(hit);
        }

        self.purge_expired_locked();
        let fetched = self.inner.get(key)?;
        if let Some(value) = fetched.as_ref() {
            self.insert(key, value);
        }
        Ok(fetched)
    }

    /// Deletes from the inner store, then evicts `key` from the cache whether or
    /// not the delete succeeded. Same concurrency caveat as `store`.
    fn delete(&self, key: &str) -> Result<()> {
        let outcome = self.inner.delete(key);
        self.invalidate(key);
        outcome
    }

    /// Delegates to the inner store.
    fn is_available(&self) -> bool {
        self.inner.is_available()
    }

    /// Delegates to the inner store.
    fn is_writable(&self) -> bool {
        self.inner.is_writable()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MemoryStore;
    use secrecy::ExposeSecret;
    use std::thread;

    /// Builds an inner store pre-populated with a single credential.
    fn store_with_entry(k: &str, v: &str) -> MemoryStore {
        MemoryStore::with_credentials([(k.to_string(), v.to_string())])
    }

    fn secret(s: &str) -> SecretString {
        SecretString::from(s.to_string())
    }

    fn exposed(s: &Option<SecretString>) -> Option<&str> {
        s.as_ref().map(|v| v.expose_secret())
    }

    /// Changes a value in the wrapped store directly, bypassing the cache's own
    /// invalidation, so a test can tell a cache hit apart from a refetch.
    fn rotate_behind_cache(cache: &CachedStore<MemoryStore>, key: &str, value: &str) {
        cache.inner.store(key, &secret(value)).unwrap();
    }

    #[test]
    fn test_cache_hit_returns_value_without_hitting_inner() {
        let cache = CachedStore::new(
            store_with_entry("a/b", "secret-A"),
            Duration::from_secs(60),
        );
        assert_eq!(exposed(&cache.get("a/b").unwrap()), Some("secret-A"));
        // If `inner` were consulted again, the rotated value would come back.
        rotate_behind_cache(&cache, "a/b", "rotated");
        assert_eq!(exposed(&cache.get("a/b").unwrap()), Some("secret-A"));
        let dbg = format!("{:?}", cache);
        assert!(dbg.contains("cached_entries: 1"));
        assert!(!dbg.contains("secret-A"));
    }

    #[test]
    fn test_cache_respects_ttl_and_refetches() {
        let cache = CachedStore::new(
            store_with_entry("a/b", "secret-A"),
            Duration::from_millis(50),
        );
        assert_eq!(exposed(&cache.get("a/b").unwrap()), Some("secret-A"));
        rotate_behind_cache(&cache, "a/b", "secret-B");
        thread::sleep(Duration::from_millis(80));
        // Past the TTL the stale entry must be dropped and the new value fetched.
        assert_eq!(exposed(&cache.get("a/b").unwrap()), Some("secret-B"));
    }

    #[test]
    fn test_cache_zero_ttl_disables_caching() {
        let cache = CachedStore::new(store_with_entry("a/b", "v"), Duration::ZERO);
        assert_eq!(exposed(&cache.get("a/b").unwrap()), Some("v"));
        // With caching off, changes in `inner` are visible immediately.
        rotate_behind_cache(&cache, "a/b", "v2");
        assert_eq!(exposed(&cache.get("a/b").unwrap()), Some("v2"));
        let dbg = format!("{:?}", cache);
        assert!(dbg.contains("cached_entries: 0"));
    }

    #[test]
    fn test_store_invalidates_cache_entry() {
        let inner = MemoryStore::new();
        inner.store("k", &secret("v1")).unwrap();
        let cache = CachedStore::new(inner, Duration::from_secs(60));
        assert_eq!(exposed(&cache.get("k").unwrap()), Some("v1"));
        cache.store("k", &secret("v2")).unwrap();
        assert_eq!(exposed(&cache.get("k").unwrap()), Some("v2"));
    }

    #[test]
    fn test_delete_invalidates_cache_entry() {
        let inner = MemoryStore::new();
        inner.store("k", &secret("v1")).unwrap();
        let cache = CachedStore::new(inner, Duration::from_secs(60));
        assert_eq!(exposed(&cache.get("k").unwrap()), Some("v1"));
        cache.delete("k").unwrap();
        assert!(cache.get("k").unwrap().is_none());
    }

    #[test]
    fn test_missing_keys_not_cached() {
        let cache = CachedStore::new(MemoryStore::new(), Duration::from_secs(60));
        assert!(cache.get("k").unwrap().is_none());
        // Write behind the cache's back (a write through the cache would
        // invalidate anyway): if the miss had been cached, this value would stay
        // hidden until the TTL expired.
        rotate_behind_cache(&cache, "k", "later");
        assert_eq!(exposed(&cache.get("k").unwrap()), Some("later"));
    }

    #[test]
    fn test_invalidate_drops_only_that_key() {
        let inner = MemoryStore::with_credentials([
            ("a".to_string(), "1".to_string()),
            ("b".to_string(), "2".to_string()),
        ]);
        let cache = CachedStore::new(inner, Duration::from_secs(60));
        cache.get("a").unwrap();
        cache.get("b").unwrap();
        rotate_behind_cache(&cache, "a", "1b");
        rotate_behind_cache(&cache, "b", "2b");
        cache.invalidate("a");
        assert_eq!(exposed(&cache.get("a").unwrap()), Some("1b"));
        assert_eq!(exposed(&cache.get("b").unwrap()), Some("2"));
    }

    #[test]
    fn test_invalidate_all_drops_every_entry() {
        let inner = MemoryStore::with_credentials([
            ("a".to_string(), "1".to_string()),
            ("b".to_string(), "2".to_string()),
        ]);
        let cache = CachedStore::new(inner, Duration::from_secs(60));
        cache.get("a").unwrap();
        cache.get("b").unwrap();
        cache.invalidate_all();
        let dbg = format!("{:?}", cache);
        assert!(dbg.contains("cached_entries: 0"));
    }

    #[test]
    fn test_writable_and_available_delegate_to_inner() {
        let cache = CachedStore::new(MemoryStore::new(), Duration::from_secs(10));
        assert!(cache.is_writable());
        assert!(cache.is_available());
    }
}