use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{OnceCell, RwLock};
/// Cache holding a single value together with its expiration time.
///
/// The value/expiry pair lives in a `OnceCell` behind an `Arc<RwLock<..>>`;
/// cloning the cache clones only the `Arc`, so all clones share one entry.
#[derive(Debug)]
pub(crate) struct ExpiringCache<T, E> {
// How long before the stored expiry the value is already treated as expired.
buffer_time: Duration,
// The cached value paired with its expiry; empty until first loaded and
// reset to empty when an expired value is cleared.
value: Arc<RwLock<OnceCell<(T, SystemTime)>>>,
// Ties the loader's error type `E` to the cache type without storing an `E`.
_phantom: PhantomData<E>,
}
impl<T, E> Clone for ExpiringCache<T, E> {
fn clone(&self) -> Self {
Self {
buffer_time: self.buffer_time,
value: self.value.clone(),
_phantom: Default::default(),
}
}
}
impl<T, E> ExpiringCache<T, E>
where
T: Clone,
{
pub fn new(buffer_time: Duration) -> Self {
ExpiringCache {
buffer_time,
value: Arc::new(RwLock::new(OnceCell::new())),
_phantom: Default::default(),
}
}
#[cfg(test)]
async fn get(&self) -> Option<T>
where
T: Clone,
{
self.value
.read()
.await
.get()
.cloned()
.map(|(creds, _expiry)| creds)
}
pub async fn get_or_load<F, Fut>(&self, f: F) -> Result<T, E>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<(T, SystemTime), E>>,
{
let lock = self.value.read().await;
let future = lock.get_or_try_init(f);
future.await.map(|(value, _expiry)| value.clone())
}
pub async fn yield_or_clear_if_expired(&self, now: SystemTime) -> Option<T> {
if let Some((value, expiry)) = self.value.read().await.get() {
if !expired(*expiry, self.buffer_time, now) {
return Some(value.clone());
}
}
let mut lock = self.value.write().await;
if let Some((_value, expiration)) = lock.get() {
if expired(*expiration, self.buffer_time, now) {
*lock = OnceCell::new();
}
}
None
}
}
/// Returns `true` once `now` has reached `buffer_time` before `expiration`.
///
/// Values are treated as expired slightly early, so they can be refreshed before
/// they actually lapse. If `expiration - buffer_time` cannot be represented as a
/// `SystemTime` (for example with an enormous buffer), the threshold lies before
/// every representable `now`. In that case the value is reported as expired
/// rather than panicking, as plain `SystemTime - Duration` subtraction would.
fn expired(expiration: SystemTime, buffer_time: Duration, now: SystemTime) -> bool {
    match expiration.checked_sub(buffer_time) {
        Some(refresh_at) => now >= refresh_at,
        // Threshold precedes the earliest representable time: already expired.
        None => true,
    }
}
#[cfg(test)]
mod tests {
    use super::{expired, ExpiringCache};
    use aws_types::credentials::CredentialsError;
    use aws_types::Credentials;
    use std::time::{Duration, SystemTime};
    use tracing_test::traced_test;

    /// Builds a successful load result whose credentials expire `expired_secs`
    /// seconds after the Unix epoch.
    fn credentials(expired_secs: u64) -> Result<(Credentials, SystemTime), CredentialsError> {
        let expiry = epoch_secs(expired_secs);
        Ok((
            Credentials::new("test", "test", None, Some(expiry), "test"),
            expiry,
        ))
    }

    /// The instant `secs` seconds after the Unix epoch.
    fn epoch_secs(secs: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
    }

    #[test]
    fn expired_check() {
        let expiry = epoch_secs(100);
        let buffer = Duration::from_secs(10);
        // Well past the expiry.
        assert!(expired(expiry, buffer, epoch_secs(1000)));
        // Exactly at the start of the buffer window counts as expired.
        assert!(expired(expiry, buffer, epoch_secs(90)));
        // Long before the buffer window.
        assert!(!expired(expiry, buffer, epoch_secs(10)));
    }

    #[traced_test]
    #[tokio::test]
    async fn cache_clears_if_expired_only() {
        let cache = ExpiringCache::new(Duration::from_secs(10));

        // Nothing has been loaded yet, so nothing can be yielded.
        let initial = cache.yield_or_clear_if_expired(epoch_secs(100)).await;
        assert!(initial.is_none());

        cache
            .get_or_load(|| async { credentials(100) })
            .await
            .unwrap();
        let loaded = cache.get().await.unwrap();
        assert_eq!(Some(epoch_secs(100)), loaded.expiry());

        // Before the buffer window: the cached credentials are returned.
        let fresh = cache
            .yield_or_clear_if_expired(epoch_secs(10))
            .await
            .unwrap();
        assert_eq!(Some(epoch_secs(100)), fresh.expiry());

        // Past expiry: nothing is yielded and the entry is cleared.
        let stale = cache.yield_or_clear_if_expired(epoch_secs(500)).await;
        assert!(stale.is_none());
        assert!(cache.get().await.is_none());
    }
}