use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::policy::Policy;
/// A stored value together with when it was stored and how long it stays fresh.
#[derive(Clone)]
struct CachedValue<T> {
    value: T,
    /// Moment `value` was stored.
    stored_at: Instant,
    /// How long after `stored_at` the value is still considered fresh.
    ttl: Duration,
}

impl<T> CachedValue<T> {
    /// Wraps `value`, stamping it with the current time.
    ///
    /// The expiry deadline is never materialised as `Instant::now() + ttl`
    /// (which panics on overflow), so very large TTLs such as `Duration::MAX`
    /// are accepted and effectively never expire.
    fn new(value: T, ttl: Duration) -> Self {
        Self {
            value,
            stored_at: Instant::now(),
            ttl,
        }
    }

    /// Returns `true` once at least `ttl` has elapsed since the value was
    /// stored; a zero `ttl` is therefore expired immediately.
    fn is_expired(&self) -> bool {
        self.stored_at.elapsed() >= self.ttl
    }
}
/// TTL-bounded slot for a single value of type `T`, usable as a [`Policy`].
///
/// Cloning a `Cache` yields another handle onto the same shared slot.
pub struct Cache<T> {
    /// Optional stored entry, behind an async `RwLock` and shared via `Arc`.
    cached: Arc<RwLock<Option<CachedValue<T>>>>,
    /// Configured time-to-live.
    /// NOTE(review): no code in this file uses it to stamp an entry for `Cache`.
    ttl: Duration,
}
impl<T> Clone for Cache<T> {
    /// Returns a handle that shares the same underlying slot.
    ///
    /// Only the `Arc` reference count is bumped, so `T` itself is never
    /// cloned and need not implement `Clone`.
    fn clone(&self) -> Self {
        let Self { ttl, cached } = self;
        Self {
            ttl: *ttl,
            cached: Arc::clone(cached),
        }
    }
}
impl<T> Cache<T>
where
    T: Clone + Send + Sync,
{
    /// Builds an empty cache configured with the given time-to-live.
    pub fn new(ttl: Duration) -> Self {
        let cached = Arc::new(RwLock::new(None));
        Self { ttl, cached }
    }

    /// Clears the slot, discarding any stored entry whether or not it expired.
    pub async fn invalidate(&self) {
        *self.cached.write().await = None;
    }

    /// Reports whether the slot currently holds an entry that has not expired.
    ///
    /// NOTE(review): nothing in this file stores an entry into a `Cache`
    /// (its `Policy::execute` is a pass-through), so unless another module
    /// writes the slot this always returns `false`.
    pub async fn has_cached_value(&self) -> bool {
        let slot = self.cached.read().await;
        match &*slot {
            Some(entry) => !entry.is_expired(),
            None => false,
        }
    }
}
#[async_trait::async_trait]
impl<T, E> Policy<E> for Cache<T>
where
    T: Clone + Send + Sync,
    E: Send + Sync,
{
    /// Invokes `f` exactly once and returns its result unchanged.
    ///
    /// NOTE(review): this neither reads nor writes the cache slot. `R` is a
    /// type parameter independent of `T`, so the result cannot be stored in
    /// the `Option<CachedValue<T>>` slot without a conversion; `TypedCache`
    /// in this file is the variant that actually caches.
    async fn execute<F, Fut, R>(&self, f: F) -> Result<R, E>
    where
        F: Fn() -> Fut + Send + Sync,
        Fut: Future<Output = Result<R, E>> + Send,
        R: Send,
    {
        let pending = f();
        pending.await
    }
}
/// Shared TTL cache for the successful result of an async computation
/// producing `T`.
///
/// Clones share the same slot through an `Arc`.
pub struct TypedCache<T> {
    /// Optional stored entry, behind an async `RwLock`.
    cached: Arc<RwLock<Option<CachedValue<T>>>>,
    /// Time-to-live given to each newly stored entry.
    ttl: Duration,
}
/// Cloning shares the same underlying slot; only the `Arc` is cloned, so no
/// `T: Clone` bound is required (matching `Clone for Cache`).
impl<T> Clone for TypedCache<T> {
    fn clone(&self) -> Self {
        Self {
            ttl: self.ttl,
            cached: Arc::clone(&self.cached),
        }
    }
}
impl<T> TypedCache<T>
where
T: Clone + Send + Sync,
{
pub fn new(ttl: Duration) -> Self {
Self {
ttl,
cached: Arc::new(RwLock::new(None)),
}
}
pub async fn execute<F, Fut, E>(&self, f: F) -> Result<T, E>
where
F: Fn() -> Fut + Send + Sync,
Fut: Future<Output = Result<T, E>> + Send,
E: Send + Sync,
{
{
let cached = self.cached.read().await;
if let Some(cv) = &*cached {
if !cv.is_expired() {
return Ok(cv.value.clone());
}
}
}
let result = f().await?;
{
let mut cached = self.cached.write().await;
*cached = Some(CachedValue::new(result.clone(), self.ttl));
}
Ok(result)
}
pub async fn invalidate(&self) {
let mut cached = self.cached.write().await;
*cached = None;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_typed_cache_caches_result() {
        let cache = TypedCache::<String>::new(Duration::from_secs(60));
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let result = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("result".to_string())
                }
            })
            .await;
        assert_eq!(result.unwrap(), "result");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // A second call within the TTL must be served from the cache.
        let cc = Arc::clone(&call_count);
        let result = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("new_result".to_string())
                }
            })
            .await;
        assert_eq!(result.unwrap(), "result");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_typed_cache_invalidate() {
        let cache = TypedCache::<String>::new(Duration::from_secs(60));
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let first = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("first".to_string())
                }
            })
            .await;
        assert_eq!(first.unwrap(), "first");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        cache.invalidate().await;

        let cc = Arc::clone(&call_count);
        let result = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("second".to_string())
                }
            })
            .await;
        assert_eq!(result.unwrap(), "second");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_typed_cache_does_not_cache_errors() {
        let cache = TypedCache::<String>::new(Duration::from_secs(60));
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let failed = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>("boom".to_string())
                }
            })
            .await;
        assert_eq!(failed.unwrap_err(), "boom");

        // The failure must not be remembered: the next call runs `f` again.
        let cc = Arc::clone(&call_count);
        let recovered = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("recovered".to_string())
                }
            })
            .await;
        assert_eq!(recovered.unwrap(), "recovered");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_typed_cache_zero_ttl_always_recomputes() {
        let cache = TypedCache::<usize>::new(Duration::ZERO);
        let call_count = Arc::new(AtomicUsize::new(0));
        for expected in 1..=2 {
            let cc = Arc::clone(&call_count);
            let result = cache
                .execute(|| {
                    let count = Arc::clone(&cc);
                    async move { Ok::<_, String>(count.fetch_add(1, Ordering::SeqCst) + 1) }
                })
                .await;
            assert_eq!(result.unwrap(), expected);
        }
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_cache_has_no_value_when_empty_or_invalidated() {
        let cache = Cache::<String>::new(Duration::from_secs(60));
        assert!(!cache.has_cached_value().await);
        cache.invalidate().await;
        assert!(!cache.has_cached_value().await);
    }
}