use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
#[cfg(feature = "cache-redis")]
pub mod redis_backend;
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
#[error("cache connection error: {0}")]
Connection(String),
#[error("cache serialization error: {0}")]
Serialization(String),
}
#[async_trait]
pub trait Cache: Send + Sync + 'static {
async fn get(&self, key: &str) -> Result<Option<String>, CacheError>;
async fn set(
&self,
key: &str,
value: &str,
ttl: Option<Duration>,
) -> Result<(), CacheError>;
async fn delete(&self, key: &str) -> Result<(), CacheError>;
async fn exists(&self, key: &str) -> Result<bool, CacheError>;
async fn clear(&self) -> Result<(), CacheError>;
async fn incr(
&self,
key: &str,
by: i64,
ttl: Option<Duration>,
) -> Result<i64, CacheError> {
let cur = self
.get(key)
.await?
.and_then(|s| s.parse::<i64>().ok())
.unwrap_or(0);
let new = cur.saturating_add(by);
self.set(key, &new.to_string(), ttl).await?;
Ok(new)
}
}
/// Shared, reference-counted handle to any [`Cache`] backend.
///
/// Cloning the handle is cheap. Because `Cache` requires `Send + Sync + 'static`,
/// the handle can be moved between tasks.
pub type BoxedCache = Arc<dyn Cache + 'static>;
/// Reads `key` from `cache` and decodes the stored string as JSON into `T`.
///
/// Returns `Ok(None)` on a cache miss.
///
/// # Errors
///
/// - Propagates any error from [`Cache::get`].
/// - Returns [`CacheError::Serialization`] if the stored string is not valid JSON
///   for `T`.
pub async fn get_json<T: serde::de::DeserializeOwned>(
    cache: &dyn Cache,
    key: &str,
) -> Result<Option<T>, CacheError> {
    match cache.get(key).await? {
        None => Ok(None),
        Some(raw) => {
            let decoded = serde_json::from_str::<T>(&raw)
                .map_err(|err| CacheError::Serialization(err.to_string()))?;
            Ok(Some(decoded))
        }
    }
}
/// Encodes `value` as JSON and stores it under `key` with the given `ttl`.
///
/// # Errors
///
/// - Returns [`CacheError::Serialization`] if `value` cannot be encoded; in that case
///   nothing is written.
/// - Otherwise returns whatever [`Cache::set`] returns.
pub async fn set_json<T: serde::Serialize>(
    cache: &dyn Cache,
    key: &str,
    value: &T,
    ttl: Option<Duration>,
) -> Result<(), CacheError> {
    match serde_json::to_string(value) {
        Ok(encoded) => cache.set(key, &encoded, ttl).await,
        Err(err) => Err(CacheError::Serialization(err.to_string())),
    }
}
/// Returns the JSON-decoded value cached under `key`. On a miss, computes the value
/// with `factory`, caches it with `ttl`, and returns it.
///
/// There is no coordination between callers: concurrent misses on the same key each
/// run `factory` and each write the result.
///
/// # Errors
///
/// - Propagates errors from [`get_json`]. This includes a stored value that fails to
///   decode as `T`; `factory` is not run in that case.
/// - Propagates errors from [`set_json`]. The freshly computed value is then dropped
///   rather than returned.
pub async fn get_or_set<T, F, Fut>(
    cache: &dyn Cache,
    key: &str,
    factory: F,
    ttl: Option<Duration>,
) -> Result<T, CacheError>
where
    T: serde::Serialize + serde::de::DeserializeOwned,
    F: FnOnce() -> Fut + Send,
    Fut: std::future::Future<Output = T> + Send,
{
    match get_json::<T>(cache, key).await? {
        Some(hit) => Ok(hit),
        None => {
            let fresh = factory().await;
            set_json(cache, key, &fresh, ttl).await?;
            Ok(fresh)
        }
    }
}
/// A [`Cache`] that stores nothing: every read misses and every write is discarded.
///
/// Can stand in wherever a `Cache` is required but caching should be a no-op.
#[derive(Debug, Clone, Copy, Default)]
pub struct NullCache;
/// Every operation succeeds without touching any storage.
///
/// `incr` is inherited from the trait default. Because the underlying read always
/// misses, it returns `by` and stores nothing.
#[async_trait]
impl Cache for NullCache {
    async fn get(&self, _k: &str) -> Result<Option<String>, CacheError> {
        // Nothing is ever stored, so every lookup is a miss.
        Ok(None)
    }

    async fn set(&self, _k: &str, _v: &str, _expiry: Option<Duration>) -> Result<(), CacheError> {
        // Writes are accepted and immediately forgotten.
        Ok(())
    }

    async fn delete(&self, _k: &str) -> Result<(), CacheError> {
        Ok(())
    }

    async fn exists(&self, _k: &str) -> Result<bool, CacheError> {
        Ok(false)
    }

    async fn clear(&self) -> Result<(), CacheError> {
        Ok(())
    }
}
/// A stored value together with its optional absolute expiry deadline.
struct CacheEntry {
    value: String,
    /// `None` means the entry never expires.
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Whether the entry has expired as of the current instant.
    fn is_expired(&self) -> bool {
        self.is_expired_at(Instant::now())
    }

    /// Whether the entry has expired as of `now`.
    ///
    /// The deadline is exclusive: an entry is already expired at the exact instant it
    /// falls due. This means an entry written with a zero TTL is never served back,
    /// even on a clock with coarse resolution.
    fn is_expired_at(&self, now: Instant) -> bool {
        matches!(self.expires_at, Some(deadline) if now >= deadline)
    }
}
/// Process-local [`Cache`] backed by a `HashMap` behind an async `RwLock`.
///
/// Expiry is checked when entries are read, not by a background timer.
pub struct InMemoryCache {
    /// TTL applied to writes whose caller passes `ttl = None`. `None` here means such
    /// writes never expire.
    default_ttl: Option<Duration>,
    /// Key → stored value plus optional absolute deadline.
    inner: tokio::sync::RwLock<HashMap<String, CacheEntry>>,
}
impl InMemoryCache {
    /// Creates an empty cache with no default TTL. Entries written with `ttl = None`
    /// never expire.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: tokio::sync::RwLock::new(HashMap::new()),
            default_ttl: None,
        }
    }

    /// Creates an empty cache that applies `default_ttl` to every write whose caller
    /// passes `ttl = None`. An explicit per-call TTL still takes precedence.
    #[must_use]
    pub fn with_default_ttl(default_ttl: Duration) -> Self {
        Self {
            inner: tokio::sync::RwLock::new(HashMap::new()),
            default_ttl: Some(default_ttl),
        }
    }

    /// Removes every expired entry and returns how many were dropped.
    ///
    /// Expired entries that are never read again are not otherwise reclaimed. In
    /// long-running processes with many distinct keys (for example TTL'd counters),
    /// call this periodically to bound memory use.
    pub async fn purge_expired(&self) -> usize {
        let mut map = self.inner.write().await;
        let before = map.len();
        map.retain(|_, entry| !entry.is_expired());
        before - map.len()
    }

    /// Converts a per-call TTL, falling back to the default TTL, into an absolute
    /// deadline.
    ///
    /// Returns `None` ("never expires") in two cases:
    /// - neither TTL is set;
    /// - the deadline would overflow `Instant` (for example `Duration::MAX`).
    ///   Plain `Instant + Duration` would panic here.
    fn resolve_ttl(&self, ttl: Option<Duration>) -> Option<Instant> {
        let effective = ttl.or(self.default_ttl)?;
        Instant::now().checked_add(effective)
    }
}
impl Default for InMemoryCache {
fn default() -> Self {
Self::new()
}
}
/// [`Cache`] implementation for the in-process map.
///
/// Reads take the shared lock; writes take it exclusively. `incr` is overridden so
/// that its read-modify-write happens under a single write lock. The trait default
/// would release the lock between its `get` and `set`, letting concurrent increments
/// overwrite each other.
#[async_trait]
impl Cache for InMemoryCache {
    /// Returns the live value for `key`, evicting the entry if it has expired.
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
        {
            let map = self.inner.read().await;
            match map.get(key) {
                None => return Ok(None),
                Some(entry) if !entry.is_expired() => return Ok(Some(entry.value.clone())),
                // Expired: fall through and evict it under the write lock.
                Some(_) => {}
            }
        }
        // Re-check under the exclusive lock. Another task may have refreshed the key
        // between releasing the read lock and acquiring the write lock.
        let mut map = self.inner.write().await;
        let live = map
            .get(key)
            .filter(|entry| !entry.is_expired())
            .map(|entry| entry.value.clone());
        if live.is_none() {
            map.remove(key);
        }
        Ok(live)
    }

    /// Stores `value`. With `ttl = None`, the cache's default TTL (if any) applies.
    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError> {
        // Compute the deadline before locking so the lock is held as briefly as possible.
        let expires_at = self.resolve_ttl(ttl);
        let mut map = self.inner.write().await;
        map.insert(key.to_owned(), CacheEntry { value: value.to_owned(), expires_at });
        Ok(())
    }

    async fn delete(&self, key: &str) -> Result<(), CacheError> {
        self.inner.write().await.remove(key);
        Ok(())
    }

    /// `true` only if `key` is present and not expired.
    async fn exists(&self, key: &str) -> Result<bool, CacheError> {
        let map = self.inner.read().await;
        Ok(map.get(key).map_or(false, |e| !e.is_expired()))
    }

    async fn clear(&self) -> Result<(), CacheError> {
        self.inner.write().await.clear();
        Ok(())
    }

    /// Atomic version of the trait's default `incr`, with the same semantics.
    ///
    /// - A missing, expired, or non-integer value counts as `0` and is overwritten.
    /// - The sum saturates at the `i64` bounds.
    /// - The entry's deadline is reset from `ttl` (or the default TTL) on every call.
    async fn incr(&self, key: &str, by: i64, ttl: Option<Duration>) -> Result<i64, CacheError> {
        let expires_at = self.resolve_ttl(ttl);
        // Hold the write lock across read and write so concurrent increments serialize.
        let mut map = self.inner.write().await;
        let current = map
            .get(key)
            .filter(|entry| !entry.is_expired())
            .and_then(|entry| entry.value.parse::<i64>().ok())
            .unwrap_or(0);
        let new = current.saturating_add(by);
        map.insert(key.to_owned(), CacheEntry { value: new.to_string(), expires_at });
        Ok(new)
    }
}