use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
#[cfg(feature = "cache-redis")]
pub mod redis_backend;
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
#[error("cache connection error: {0}")]
Connection(String),
#[error("cache serialization error: {0}")]
Serialization(String),
}
/// Async key/value cache that stores `String` values under `&str` keys.
///
/// Backends must be `Send + Sync + 'static` so one instance can be shared as a
/// [`BoxedCache`]. The typed helpers [`get_json`], [`set_json`] and [`get_or_set`]
/// are built on top of these methods.
#[async_trait]
pub trait Cache: Send + Sync + 'static {
    /// Fetches the value stored under `key`; `Ok(None)` means there is no live value.
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError>;
    /// Stores `value` under `key`.
    ///
    /// `ttl` bounds the entry's lifetime. How `None` is interpreted is up to the
    /// backend ([`InMemoryCache`] falls back to its configured default TTL, if any).
    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError>;
    /// Removes `key` if it is present.
    async fn delete(&self, key: &str) -> Result<(), CacheError>;
    /// Reports whether a live value is stored under `key`.
    async fn exists(&self, key: &str) -> Result<bool, CacheError>;
    /// Removes all entries held by this backend.
    async fn clear(&self) -> Result<(), CacheError>;
    /// Adds `by` to the integer stored under `key` and returns the new value.
    ///
    /// The default implementation is a plain [`Cache::get`] followed by
    /// [`Cache::set`], so it is **not** atomic: concurrent callers on the same key
    /// can lose updates. Backends that can do better should override it.
    ///
    /// A missing value, or one that does not parse as `i64`, counts as `0`. The sum
    /// saturates at the `i64` bounds. The entry is rewritten with `ttl` on every call.
    async fn incr(&self, key: &str, by: i64, ttl: Option<Duration>) -> Result<i64, CacheError> {
        let stored = self.get(key).await?;
        let current = match stored {
            Some(text) => text.parse::<i64>().unwrap_or(0),
            None => 0,
        };
        let updated = current.saturating_add(by);
        self.set(key, &updated.to_string(), ttl).await?;
        Ok(updated)
    }
}
/// Shared, type-erased handle to any [`Cache`] backend.
pub type BoxedCache = Arc<dyn Cache + 'static>;
/// Builds the cache backend selected by `s.backend`.
///
/// The backend name is matched case-insensitively after trimming surrounding
/// whitespace:
///
/// * unset, empty, or `"memory"` → [`InMemoryCache`]
/// * `"null"` / `"none"` → [`NullCache`] (every write is dropped)
/// * `"redis"` → the Redis backend when the `cache-redis` feature is compiled in
///   and `s.redis_url` is non-empty. Otherwise it logs a warning and uses
///   [`InMemoryCache`].
/// * anything else → logs a warning and uses [`InMemoryCache`]
///
/// Never fails: misconfiguration degrades to an in-process cache instead of an error.
#[cfg(feature = "config")]
#[must_use]
pub fn from_settings(s: &crate::config::CacheSettings) -> BoxedCache {
    let raw = s.backend.as_deref().unwrap_or("");
    // Normalize so values like "Redis" or " memory " (common when the setting
    // comes from environment variables) select the intended backend rather than
    // the "unknown" fallback.
    let normalized = raw.trim().to_ascii_lowercase();
    match normalized.as_str() {
        "redis" => {
            #[cfg(feature = "cache-redis")]
            {
                if let Some(url) = s.redis_url.as_deref().filter(|u| !u.is_empty()) {
                    return Arc::new(redis_backend::RedisCache::new(url));
                }
                tracing::warn!(
                    target: "rustango::cache",
                    "cache.backend = \"redis\" but redis_url is unset; falling back to InMemoryCache",
                );
            }
            #[cfg(not(feature = "cache-redis"))]
            {
                tracing::warn!(
                    target: "rustango::cache",
                    "cache.backend = \"redis\" but the `cache-redis` feature isn't compiled in; falling back to InMemoryCache",
                );
            }
            Arc::new(InMemoryCache::new())
        }
        "null" | "none" => Arc::new(NullCache),
        "memory" | "" => Arc::new(InMemoryCache::new()),
        _ => {
            // Log the value exactly as configured, not the normalized form.
            tracing::warn!(
                target: "rustango::cache",
                backend = %raw,
                "unknown cache.backend value; falling back to InMemoryCache",
            );
            Arc::new(InMemoryCache::new())
        }
    }
}
/// Reads `key` from `cache` and decodes the stored JSON string into `T`.
///
/// Returns `Ok(None)` when the backend has no value for `key`.
///
/// # Errors
/// Propagates any error from [`Cache::get`]. Returns
/// [`CacheError::Serialization`] when the stored string is not valid JSON for `T`.
pub async fn get_json<T: serde::de::DeserializeOwned>(
    cache: &dyn Cache,
    key: &str,
) -> Result<Option<T>, CacheError> {
    match cache.get(key).await? {
        None => Ok(None),
        Some(raw) => match serde_json::from_str::<T>(&raw) {
            Ok(decoded) => Ok(Some(decoded)),
            Err(err) => Err(CacheError::Serialization(err.to_string())),
        },
    }
}
/// Encodes `value` as JSON and stores it under `key` with the given `ttl`.
///
/// # Errors
/// Returns [`CacheError::Serialization`] if `value` cannot be encoded. Otherwise
/// it propagates any error from [`Cache::set`].
pub async fn set_json<T: serde::Serialize>(
    cache: &dyn Cache,
    key: &str,
    value: &T,
    ttl: Option<Duration>,
) -> Result<(), CacheError> {
    let encoded = match serde_json::to_string(value) {
        Ok(json) => json,
        Err(err) => return Err(CacheError::Serialization(err.to_string())),
    };
    cache.set(key, &encoded, ttl).await
}
pub async fn get_or_set<T, F, Fut>(
cache: &dyn Cache,
key: &str,
factory: F,
ttl: Option<Duration>,
) -> Result<T, CacheError>
where
T: serde::Serialize + serde::de::DeserializeOwned,
F: FnOnce() -> Fut + Send,
Fut: std::future::Future<Output = T> + Send,
{
if let Some(cached) = get_json::<T>(cache, key).await? {
return Ok(cached);
}
let value = factory().await;
set_json(cache, key, &value, ttl).await?;
Ok(value)
}
/// A [`Cache`] that stores nothing: every write is discarded and every read misses.
///
/// Because nothing is ever stored, the trait's default `incr` always returns `by`.
pub struct NullCache;
#[async_trait]
impl Cache for NullCache {
    async fn get(&self, _k: &str) -> Result<Option<String>, CacheError> {
        // Nothing is ever stored, so every lookup is a miss.
        Ok(None)
    }
    async fn set(&self, _k: &str, _v: &str, _expiry: Option<Duration>) -> Result<(), CacheError> {
        // Intentionally discarded.
        Ok(())
    }
    async fn delete(&self, _k: &str) -> Result<(), CacheError> {
        Ok(())
    }
    async fn exists(&self, _k: &str) -> Result<bool, CacheError> {
        Ok(false)
    }
    async fn clear(&self) -> Result<(), CacheError> {
        Ok(())
    }
}
/// A stored value together with its optional absolute expiry deadline.
struct CacheEntry {
    value: String,
    /// Deadline after which the entry is treated as absent; `None` = never expires.
    expires_at: Option<Instant>,
}
impl CacheEntry {
    /// Whether the entry's deadline has been reached.
    ///
    /// The deadline instant itself counts as expired. With a strict `>` check, an
    /// entry was still served at its exact deadline, so a zero-TTL entry could be
    /// read back while the clock still reported the instant it was written.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => Instant::now() >= deadline,
            None => false,
        }
    }
}
/// Process-local [`Cache`] backed by a `HashMap` behind a `tokio::sync::RwLock`.
///
/// Expired entries are hidden from reads but are not removed automatically. They
/// stay in the map until overwritten, deleted, cleared, or dropped by
/// [`InMemoryCache::purge_expired`]. Long-lived caches with many short-lived keys
/// should call `purge_expired` periodically.
pub struct InMemoryCache {
    inner: tokio::sync::RwLock<HashMap<String, CacheEntry>>,
    /// TTL applied to writes that do not pass one explicitly; `None` = never expire.
    default_ttl: Option<Duration>,
}
impl InMemoryCache {
    /// Creates an empty cache whose entries only expire when a write passes a TTL.
    #[must_use]
    pub fn new() -> Self {
        Self::with_optional_default_ttl(None)
    }
    /// Creates an empty cache that applies `default_ttl` to writes without their own TTL.
    #[must_use]
    pub fn with_default_ttl(default_ttl: Duration) -> Self {
        Self::with_optional_default_ttl(Some(default_ttl))
    }
    /// Shared constructor for [`InMemoryCache::new`] and [`InMemoryCache::with_default_ttl`].
    fn with_optional_default_ttl(default_ttl: Option<Duration>) -> Self {
        Self {
            inner: tokio::sync::RwLock::new(HashMap::new()),
            default_ttl,
        }
    }
    /// Removes every entry whose deadline has been reached and returns how many
    /// were dropped.
    pub async fn purge_expired(&self) -> usize {
        let mut map = self.inner.write().await;
        let before = map.len();
        map.retain(|_, entry| !entry.is_expired());
        before - map.len()
    }
    /// Turns a per-write TTL (or the cache default) into an absolute deadline.
    ///
    /// Returns `None` (never expires) when neither is set. It also returns `None`
    /// when the deadline cannot be represented as an `Instant`: `Instant + Duration`
    /// panics on overflow, so a huge TTL such as `Duration::MAX` would otherwise
    /// crash the write.
    fn resolve_ttl(&self, ttl: Option<Duration>) -> Option<Instant> {
        let effective = ttl.or(self.default_ttl)?;
        Instant::now().checked_add(effective)
    }
}
impl Default for InMemoryCache {
    fn default() -> Self {
        Self::new()
    }
}
/// [`Cache`] over the in-process map.
///
/// Reads take the shared lock and writes take the exclusive lock. Expired entries
/// are treated as absent by `get`, `exists` and `incr`, but remain in the map until
/// they are overwritten, deleted or cleared.
#[async_trait]
impl Cache for InMemoryCache {
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
        let map = self.inner.read().await;
        Ok(map
            .get(key)
            .filter(|entry| !entry.is_expired())
            .map(|entry| entry.value.clone()))
    }
    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError> {
        let expires_at = self.resolve_ttl(ttl);
        let mut map = self.inner.write().await;
        map.insert(
            key.to_owned(),
            CacheEntry {
                value: value.to_owned(),
                expires_at,
            },
        );
        Ok(())
    }
    async fn delete(&self, key: &str) -> Result<(), CacheError> {
        self.inner.write().await.remove(key);
        Ok(())
    }
    async fn exists(&self, key: &str) -> Result<bool, CacheError> {
        let map = self.inner.read().await;
        Ok(map.get(key).map_or(false, |entry| !entry.is_expired()))
    }
    async fn clear(&self) -> Result<(), CacheError> {
        self.inner.write().await.clear();
        Ok(())
    }
    /// Atomic override of [`Cache::incr`].
    ///
    /// The trait default does `get` then `set` under two separate lock
    /// acquisitions, so concurrent increments of one key could overwrite each
    /// other. Here the read, add and write all happen under a single write lock.
    /// Semantics otherwise match the default: a missing, expired or non-integer
    /// value counts as `0`, the sum saturates, and the entry is rewritten with
    /// `ttl` (or the cache's default TTL).
    async fn incr(&self, key: &str, by: i64, ttl: Option<Duration>) -> Result<i64, CacheError> {
        let expires_at = self.resolve_ttl(ttl);
        let mut map = self.inner.write().await;
        let current = map
            .get(key)
            .filter(|entry| !entry.is_expired())
            .and_then(|entry| entry.value.parse::<i64>().ok())
            .unwrap_or(0);
        let updated = current.saturating_add(by);
        map.insert(
            key.to_owned(),
            CacheEntry {
                value: updated.to_string(),
                expires_at,
            },
        );
        Ok(updated)
    }
}
#[cfg(all(test, feature = "config"))]
mod settings_tests {
    use super::*;
    #[tokio::test]
    async fn unset_backend_returns_inmemory() {
        let s = crate::config::CacheSettings::default();
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert_eq!(cache.get("k").await.unwrap().as_deref(), Some("v"));
    }
    #[tokio::test]
    async fn memory_backend_works() {
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("memory".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert_eq!(cache.get("k").await.unwrap().as_deref(), Some("v"));
    }
    #[tokio::test]
    async fn null_backend_drops_writes() {
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("null".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert!(cache.get("k").await.unwrap().is_none());
    }
    #[tokio::test]
    async fn none_backend_alias_drops_writes() {
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("none".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert!(cache.get("k").await.unwrap().is_none());
    }
    #[tokio::test]
    async fn unknown_backend_falls_back_to_inmemory() {
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("typo".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert_eq!(cache.get("k").await.unwrap().as_deref(), Some("v"));
    }
    /// Whether or not `cache-redis` is compiled in, a "redis" backend with no
    /// `redis_url` configured must fall back to the in-memory cache, so the same
    /// assertions apply in both builds.
    #[tokio::test]
    async fn redis_without_url_falls_back_to_inmemory() {
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("redis".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert_eq!(cache.get("k").await.unwrap().as_deref(), Some("v"));
    }
}