use std::sync::Arc;
use arc_swap::ArcSwapOption;
use bon::bon;
use crate::{
platform::{Duration, SystemTime},
secrets::{Secret, SecretOutput},
};
/// A snapshot of a secret's output together with the time it was stored.
struct CachedEntry<T>
where
    T: Clone,
{
    /// The value returned by the wrapped secret's `get_secret_value`.
    output: SecretOutput<T>,
    /// `SystemTime::now()` at the moment this entry was created; compared against the TTL.
    cached_at: SystemTime,
}
/// State shared (through an `Arc`) by every clone of a `CachedSecret`.
struct CachedSecretInner<S>
where
    S: Secret,
{
    /// The underlying secret source, queried whenever the cache is empty or stale.
    secret: S,
    /// Maximum age of a cached entry; `None` means entries never expire on their own.
    ttl: Option<Duration>,
    /// The most recently fetched entry, if any; replaced atomically via `ArcSwapOption`.
    cached: ArcSwapOption<CachedEntry<S::Output>>,
    /// Held while refreshing so concurrent cache misses fetch from `secret` one at a time.
    refresh_lock: tokio::sync::Mutex<()>,
}
/// A [`Secret`] wrapper that caches the wrapped secret's output, optionally for a limited TTL.
///
/// Cloning is cheap: every clone shares the same cache and refresh lock through an [`Arc`].
pub struct CachedSecret<S: Secret> {
    inner: Arc<CachedSecretInner<S>>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add an implicit
// `S: Clone` bound, even though cloning only bumps the `Arc` refcount and never clones `S`.
impl<S: Secret> Clone for CachedSecret<S> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<S> std::fmt::Debug for CachedSecret<S>
where
    S: Secret,
{
    /// Formats as `CachedSecret { .. }`; no fields (in particular no cached value) are printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = formatter.debug_struct("CachedSecret");
        debug.finish_non_exhaustive()
    }
}
#[bon]
impl<S: Secret> CachedSecret<S> {
    /// Creates a cache around `secret`; exposed through a `bon`-generated builder.
    ///
    /// With `ttl == None`, a cached value is reused until [`CachedSecret::invalidate`] is
    /// called. Otherwise it is refetched once it is at least `ttl` old.
    #[builder]
    pub fn new(
        secret: S,
        ttl: Option<Duration>,
    ) -> CachedSecret<S> {
        CachedSecret {
            inner: Arc::new(CachedSecretInner {
                secret,
                ttl,
                cached: ArcSwapOption::empty(),
                refresh_lock: tokio::sync::Mutex::new(()),
            }),
        }
    }

    /// Fetches a fresh value from the wrapped secret, stores it in the cache and returns it.
    ///
    /// If the fetch fails, the error is returned and the existing cache entry (if any) is
    /// left untouched. `get_secret_value` calls this while holding `refresh_lock`.
    async fn reload(&self) -> Result<SecretOutput<S::Output>, S::Error> {
        let output = self.inner.secret.get_secret_value().await?;
        self.inner.cached.store(Some(Arc::new(CachedEntry {
            output: output.clone(),
            cached_at: SystemTime::now(),
        })));
        Ok(output)
    }

    /// Drops the cached value so the next `get_secret_value` call refetches it.
    ///
    /// This does not take `refresh_lock`. A reload that is already in flight may therefore
    /// store its result after this call and repopulate the cache.
    pub fn invalidate(&self) {
        self.inner.cached.store(None);
    }

    /// Returns `true` once `entry` is at least `ttl` old.
    ///
    /// If `cached_at + ttl` is not representable (e.g. `ttl == Duration::MAX`), the entry is
    /// treated as never expiring. Plain `+` would panic on that overflow.
    ///
    /// NOTE(review): `SystemTime` is wall-clock time, so clock adjustments can shorten or
    /// lengthen the effective TTL.
    fn is_expired(entry: &CachedEntry<S::Output>, ttl: Duration) -> bool {
        entry
            .cached_at
            .checked_add(ttl)
            .is_some_and(|expires_at| SystemTime::now() >= expires_at)
    }

    /// Returns `true` if `entry` may still be served.
    ///
    /// This is always the case when `ttl` is `None`; otherwise it holds only until the entry
    /// expires.
    fn is_valid(entry: &CachedEntry<S::Output>, ttl: Option<Duration>) -> bool {
        ttl.is_none_or(|ttl| !Self::is_expired(entry, ttl))
    }
}
impl<S> Secret for CachedSecret<S>
where
    S: Secret,
{
    type Error = S::Error;
    type Output = S::Output;

    /// Returns the cached output while it is still valid; otherwise fetches a fresh one.
    ///
    /// The cache is checked once without locking. If that misses, it is checked again after
    /// acquiring `refresh_lock`, so callers that waited behind another caller's refresh reuse
    /// that result instead of fetching again.
    async fn get_secret_value(&self) -> Result<SecretOutput<Self::Output>, Self::Error> {
        // Clone of the current entry's output, but only if the entry is still valid.
        let valid_cached_output = || {
            self.inner
                .cached
                .load_full()
                .filter(|entry| Self::is_valid(entry, self.inner.ttl))
                .map(|entry| entry.output.clone())
        };

        if let Some(output) = valid_cached_output() {
            return Ok(output);
        }

        // The guard must stay bound (not `_`) so the lock is held through the reload below.
        let _refresh_guard = self.inner.refresh_lock.lock().await;
        if let Some(output) = valid_cached_output() {
            return Ok(output);
        }

        self.reload().await
    }
}