use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
#[cfg(feature = "cache-redis")]
pub mod redis_backend;
/// Errors returned by [`Cache`] operations and the JSON helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
    /// The backend's storage could not be reached or failed
    /// (e.g. filesystem I/O errors from [`FileCache`]).
    #[error("cache connection error: {0}")]
    Connection(String),
    /// A value could not be encoded to / decoded from its stored form
    /// (JSON in [`get_json`] and [`set_json`]).
    #[error("cache serialization error: {0}")]
    Serialization(String),
}
/// Asynchronous string key/value cache implemented by every backend in this module.
///
/// Values are plain `String`s; use [`get_json`] / [`set_json`] for typed values.
/// When `ttl` is `None` the backend's own policy applies (no expiry for [`FileCache`],
/// the configured default TTL, if any, for [`InMemoryCache`]).
#[async_trait]
pub trait Cache: Send + Sync + 'static {
    /// Fetches the value stored under `key`; `Ok(None)` when it is missing or expired.
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError>;

    /// Stores `value` under `key`, replacing any previous value.
    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError>;

    /// Removes `key`. Removing a key that does not exist is not an error.
    async fn delete(&self, key: &str) -> Result<(), CacheError>;

    /// Reports whether a live (non-expired) value is stored under `key`.
    async fn exists(&self, key: &str) -> Result<bool, CacheError>;

    /// Removes all entries held by this cache.
    async fn clear(&self) -> Result<(), CacheError>;

    /// Adds `by` to the integer stored under `key` and returns the updated value.
    ///
    /// A missing or non-integer value counts as `0`, the addition saturates at the `i64`
    /// bounds, and the result is written back with `ttl`.
    ///
    /// This default is a plain `get` followed by `set`, so concurrent callers can lose
    /// updates; backends able to do better should override it.
    async fn incr(&self, key: &str, by: i64, ttl: Option<Duration>) -> Result<i64, CacheError> {
        let current: i64 = match self.get(key).await? {
            Some(raw) => raw.parse::<i64>().unwrap_or(0),
            None => 0,
        };
        let updated = current.saturating_add(by);
        self.set(key, &updated.to_string(), ttl).await?;
        Ok(updated)
    }
}
/// Shared, type-erased handle to any [`Cache`] backend (the return type of `from_settings`).
pub type BoxedCache = Arc<dyn Cache>;
/// Builds the cache backend selected by `[cache].backend` in the application settings.
///
/// | `backend`           | result                                                  |
/// |---------------------|---------------------------------------------------------|
/// | unset / `"memory"`  | [`InMemoryCache`]                                       |
/// | `"null"` / `"none"` | [`NullCache`]                                           |
/// | `"file"`            | [`FileCache`] at `file_cache_dir` (in-memory if unset)  |
/// | `"redis"`           | [`InMemoryCache`] fallback, with a warning (see below)  |
/// | anything else       | [`InMemoryCache`], with a warning                       |
///
/// `"redis"` is never honored here: either the `cache-redis` feature is not compiled in,
/// or the Redis client needs async construction that this synchronous factory cannot do.
#[cfg(feature = "config")]
#[must_use]
pub fn from_settings(s: &crate::config::CacheSettings) -> BoxedCache {
    let Some(backend) = s.backend.as_deref() else {
        return Arc::new(InMemoryCache::new());
    };
    match backend {
        "memory" => Arc::new(InMemoryCache::new()),
        "null" | "none" => Arc::new(NullCache),
        "file" => file_from_settings_or_warn(s),
        "redis" => {
            warn_redis_unavailable(s);
            Arc::new(InMemoryCache::new())
        }
        other => {
            tracing::warn!(
                target: "rustango::cache",
                backend = %other,
                "unknown cache.backend value; falling back to InMemoryCache",
            );
            Arc::new(InMemoryCache::new())
        }
    }
}

/// Logs why `cache.backend = "redis"` cannot be satisfied by [`from_settings`].
#[cfg(feature = "config")]
fn warn_redis_unavailable(s: &crate::config::CacheSettings) {
    #[cfg(feature = "cache-redis")]
    {
        let url_configured = s.redis_url.as_deref().is_some_and(|u| !u.is_empty());
        if url_configured {
            tracing::warn!(
                target: "rustango::cache",
                "cache.backend = \"redis\" requires async construction; \
                 build `RedisCache::new(url).await?` and pass the Arc \
                 directly. Falling back to InMemoryCache."
            );
        } else {
            tracing::warn!(
                target: "rustango::cache",
                "cache.backend = \"redis\" but redis_url is unset; falling back to InMemoryCache",
            );
        }
    }
    #[cfg(not(feature = "cache-redis"))]
    {
        // `redis_url` is only consulted when the redis feature is enabled.
        let _ = s;
        tracing::warn!(
            target: "rustango::cache",
            "cache.backend = \"redis\" but the `cache-redis` feature isn't compiled in; falling back to InMemoryCache",
        );
    }
}
/// Builds a [`FileCache`] rooted at `[cache].file_cache_dir`, or falls back to an
/// [`InMemoryCache`] (logging a warning) when that directory is unset or empty.
///
/// An empty directory is treated like an unset one, matching how an empty `redis_url`
/// is treated as unset in `from_settings`. A `FileCache` rooted at `""` would resolve
/// every cache file relative to the process's current working directory.
#[cfg(feature = "config")]
fn file_from_settings_or_warn(s: &crate::config::CacheSettings) -> BoxedCache {
    let dir = s
        .file_cache_dir
        .as_deref()
        .filter(|d| !std::path::Path::new(d).as_os_str().is_empty());
    if let Some(dir) = dir {
        return Arc::new(FileCache::new(dir));
    }
    tracing::warn!(
        target: "rustango::cache",
        "cache.backend = \"file\" but [cache].file_cache_dir is unset; \
         falling back to InMemoryCache.",
    );
    Arc::new(InMemoryCache::new())
}
/// Reads `key` from `cache` and decodes the stored string as JSON into `T`.
///
/// Returns `Ok(None)` on a cache miss.
///
/// # Errors
///
/// Propagates any error from [`Cache::get`], and returns [`CacheError::Serialization`]
/// when the stored string is not valid JSON for `T`.
pub async fn get_json<T: serde::de::DeserializeOwned>(
    cache: &dyn Cache,
    key: &str,
) -> Result<Option<T>, CacheError> {
    match cache.get(key).await? {
        None => Ok(None),
        Some(raw) => match serde_json::from_str::<T>(&raw) {
            Ok(decoded) => Ok(Some(decoded)),
            Err(err) => Err(CacheError::Serialization(err.to_string())),
        },
    }
}
/// Encodes `value` as JSON and stores it under `key` with the given `ttl`.
///
/// # Errors
///
/// Returns [`CacheError::Serialization`] if `value` cannot be encoded, and propagates
/// any error from [`Cache::set`].
pub async fn set_json<T: serde::Serialize>(
    cache: &dyn Cache,
    key: &str,
    value: &T,
    ttl: Option<Duration>,
) -> Result<(), CacheError> {
    let encoded = match serde_json::to_string(value) {
        Ok(json) => json,
        Err(err) => return Err(CacheError::Serialization(err.to_string())),
    };
    cache.set(key, &encoded, ttl).await
}
pub async fn get_or_set<T, F, Fut>(
cache: &dyn Cache,
key: &str,
factory: F,
ttl: Option<Duration>,
) -> Result<T, CacheError>
where
T: serde::Serialize + serde::de::DeserializeOwned,
F: FnOnce() -> Fut + Send,
Fut: std::future::Future<Output = T> + Send,
{
if let Some(cached) = get_json::<T>(cache, key).await? {
return Ok(cached);
}
let value = factory().await;
set_json(cache, key, &value, ttl).await?;
Ok(value)
}
/// A [`Cache`] that stores nothing: writes are discarded and every read misses.
///
/// Selected by `cache.backend = "null"` or `"none"`. It disables caching without
/// changing call sites.
#[derive(Debug, Default, Clone, Copy)]
pub struct NullCache;

#[async_trait]
impl Cache for NullCache {
    /// Always a miss.
    async fn get(&self, _key: &str) -> Result<Option<String>, CacheError> {
        Ok(None)
    }

    /// Discards the value.
    async fn set(&self, _key: &str, _value: &str, _ttl: Option<Duration>) -> Result<(), CacheError> {
        Ok(())
    }

    /// Nothing to delete.
    async fn delete(&self, _key: &str) -> Result<(), CacheError> {
        Ok(())
    }

    /// Nothing is ever stored, so nothing exists.
    async fn exists(&self, _key: &str) -> Result<bool, CacheError> {
        Ok(false)
    }

    /// Nothing to clear.
    async fn clear(&self) -> Result<(), CacheError> {
        Ok(())
    }
}
/// A value held by [`InMemoryCache`] together with its optional absolute expiry deadline.
struct CacheEntry {
    value: String,
    /// `None` means the entry never expires.
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Returns `true` once the deadline has been reached.
    ///
    /// Uses `>=` rather than `>` so an entry whose deadline is exactly "now" is already
    /// treated as expired, which matters for zero-TTL writes. This matches the
    /// `now >= expires_at` check used by `FileCache`.
    fn is_expired(&self) -> bool {
        self.expires_at
            .is_some_and(|deadline| Instant::now() >= deadline)
    }
}
/// Process-local [`Cache`] backed by a `HashMap` behind a `tokio::sync::RwLock`.
///
/// Expiry is evaluated when entries are accessed; nothing is removed in the background.
/// Long-running processes that write many short-lived keys should call
/// [`InMemoryCache::purge_expired`] periodically to bound memory use.
pub struct InMemoryCache {
    inner: tokio::sync::RwLock<HashMap<String, CacheEntry>>,
    /// TTL applied to writes made with `ttl: None`.
    default_ttl: Option<Duration>,
}

impl InMemoryCache {
    /// Creates an empty cache whose entries never expire unless a TTL is passed per write.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: tokio::sync::RwLock::new(HashMap::new()),
            default_ttl: None,
        }
    }

    /// Creates an empty cache that applies `default_ttl` to writes made without a TTL.
    #[must_use]
    pub fn with_default_ttl(default_ttl: Duration) -> Self {
        Self {
            inner: tokio::sync::RwLock::new(HashMap::new()),
            default_ttl: Some(default_ttl),
        }
    }

    /// Removes every expired entry and returns how many were dropped.
    pub async fn purge_expired(&self) -> usize {
        let mut map = self.inner.write().await;
        let before = map.len();
        map.retain(|_, entry| !entry.is_expired());
        before - map.len()
    }

    /// Turns a per-call TTL (falling back to `default_ttl`) into an absolute deadline.
    ///
    /// Returns `None` ("never expires") when neither TTL is set. It also returns `None`
    /// when the TTL is too large to add to `Instant::now()` (e.g. `Duration::MAX`), where
    /// plain `Instant + Duration` would panic.
    fn resolve_ttl(&self, ttl: Option<Duration>) -> Option<Instant> {
        let effective = ttl.or(self.default_ttl)?;
        Instant::now().checked_add(effective)
    }
}
impl Default for InMemoryCache {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Cache for InMemoryCache {
    /// Returns the live value for `key`.
    ///
    /// An expired entry is evicted on access (rather than lingering in memory) and
    /// reported as a miss.
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
        {
            let map = self.inner.read().await;
            match map.get(key) {
                None => return Ok(None),
                Some(entry) if !entry.is_expired() => return Ok(Some(entry.value.clone())),
                Some(_) => {}
            }
        }
        // The entry looked expired: evict it under the write lock. Re-check first, because
        // another task may have refreshed the key between releasing the read lock and
        // acquiring the write lock.
        let mut map = self.inner.write().await;
        if map.get(key).is_some_and(|entry| entry.is_expired()) {
            map.remove(key);
            return Ok(None);
        }
        Ok(map.get(key).map(|entry| entry.value.clone()))
    }

    /// Stores `value`; `ttl: None` falls back to the cache's default TTL (if any).
    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError> {
        let expires_at = self.resolve_ttl(ttl);
        let mut map = self.inner.write().await;
        map.insert(
            key.to_owned(),
            CacheEntry {
                value: value.to_owned(),
                expires_at,
            },
        );
        Ok(())
    }

    async fn delete(&self, key: &str) -> Result<(), CacheError> {
        self.inner.write().await.remove(key);
        Ok(())
    }

    async fn exists(&self, key: &str) -> Result<bool, CacheError> {
        let map = self.inner.read().await;
        Ok(map.get(key).is_some_and(|entry| !entry.is_expired()))
    }

    async fn clear(&self) -> Result<(), CacheError> {
        self.inner.write().await.clear();
        Ok(())
    }

    /// Atomic counterpart of the trait's default `incr`.
    ///
    /// The read, add and write all happen under a single write lock, so concurrent
    /// increments are never lost. Otherwise the semantics are unchanged:
    /// - a missing, expired or non-integer value counts as `0`;
    /// - the addition saturates;
    /// - the result is stored with `ttl` (or the default TTL).
    async fn incr(&self, key: &str, by: i64, ttl: Option<Duration>) -> Result<i64, CacheError> {
        let expires_at = self.resolve_ttl(ttl);
        let mut map = self.inner.write().await;
        let current = map
            .get(key)
            .filter(|entry| !entry.is_expired())
            .and_then(|entry| entry.value.parse::<i64>().ok())
            .unwrap_or(0);
        let new = current.saturating_add(by);
        map.insert(
            key.to_owned(),
            CacheEntry {
                value: new.to_string(),
                expires_at,
            },
        );
        Ok(new)
    }
}
/// A [`Cache`] that persists each entry as its own file under a single directory.
///
/// Each key maps to `<dir>/<lowercase hex SHA-256 of key>.cache`. The file holds an 8-byte
/// big-endian Unix timestamp in seconds (`0` = never expires) followed by the UTF-8 value.
/// Expiry therefore has one-second resolution. The directory is created on the first write.
pub struct FileCache {
    dir: std::path::PathBuf,
}

impl FileCache {
    /// Creates a cache rooted at `dir`; nothing is touched on disk until the first write.
    #[must_use]
    pub fn new(dir: impl Into<std::path::PathBuf>) -> Self {
        Self { dir: dir.into() }
    }

    /// Returns the directory cache files are stored in.
    #[must_use]
    pub fn dir(&self) -> &std::path::Path {
        &self.dir
    }

    /// Maps `key` to its file path: lowercase hex SHA-256 of the key plus `.cache`.
    ///
    /// Hashing keeps arbitrary keys (slashes, `..`, very long strings) from producing
    /// paths outside `dir` or invalid file names.
    fn key_path(&self, key: &str) -> std::path::PathBuf {
        use sha2::{Digest, Sha256};
        use std::fmt::Write as _;
        let hash = Sha256::digest(key.as_bytes());
        let mut name = String::with_capacity(64 + 6);
        for b in hash {
            // Writing into a `String` cannot fail.
            let _ = write!(&mut name, "{b:02x}");
        }
        name.push_str(".cache");
        self.dir.join(name)
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    /// Returns `0` if the clock reads earlier than the epoch.
    fn now_unix_secs() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_secs().min(i64::MAX as u64) as i64)
    }

    /// Serializes `value` behind its 8-byte expiry header.
    ///
    /// - The TTL is rounded *up* to whole seconds, so a sub-second TTL is not stored
    ///   already expired.
    /// - TTLs too large for an `i64` timestamp (e.g. `Duration::MAX`) saturate instead of
    ///   wrapping negative and expiring immediately.
    /// - A computed expiry is clamped to at least `1`, so it never collides with the
    ///   `0` = "never expires" sentinel.
    fn encode(value: &str, ttl: Option<Duration>) -> Vec<u8> {
        let expires_at = ttl.map_or(0, |d| {
            let whole_secs = d
                .as_secs()
                .saturating_add(u64::from(d.subsec_nanos() > 0));
            // Clamp before casting so the `as` conversion cannot wrap.
            let secs = whole_secs.min(i64::MAX as u64) as i64;
            Self::now_unix_secs().saturating_add(secs).max(1)
        });
        let mut out = Vec::with_capacity(8 + value.len());
        out.extend_from_slice(&expires_at.to_be_bytes());
        out.extend_from_slice(value.as_bytes());
        out
    }

    /// Parses bytes produced by [`FileCache::encode`] into `(value, expired)`.
    ///
    /// Returns `None` when the header is truncated or the payload is not valid UTF-8.
    fn decode(buf: &[u8]) -> Option<(String, bool /* expired */)> {
        if buf.len() < 8 {
            return None;
        }
        let (header, payload) = buf.split_at(8);
        let mut ts = [0u8; 8];
        ts.copy_from_slice(header);
        let expires_at = i64::from_be_bytes(ts);
        let value = std::str::from_utf8(payload).ok()?.to_owned();
        let expired = expires_at != 0 && Self::now_unix_secs() >= expires_at;
        Some((value, expired))
    }
}
/// Filesystem implementation of [`Cache`].
///
/// Every operation calls blocking `std::fs` functions directly, so it occupies the calling
/// executor thread for the duration of the I/O.
#[async_trait]
impl Cache for FileCache {
    /// Reads and decodes the entry for `key`.
    ///
    /// Expired or undecodable files are deleted (best-effort) and reported as a miss.
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
        let path = self.key_path(key);
        let buf = match std::fs::read(&path) {
            Ok(b) => b,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(e) => return Err(CacheError::Connection(format!("read: {e}"))),
        };
        match Self::decode(&buf) {
            Some((value, false)) => Ok(Some(value)),
            // Expired or corrupt: drop the file so later reads don't re-parse it.
            Some((_, true)) | None => {
                let _ = std::fs::remove_file(&path);
                Ok(None)
            }
        }
    }

    /// Writes the entry for `key` atomically.
    ///
    /// The encoded entry goes to a uniquely named temporary file in the same directory,
    /// which is then renamed over the final path. Concurrent readers (in this or another
    /// process) therefore see either the old file or the complete new one, never a
    /// partially written file.
    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError> {
        // Distinguishes temp files written by concurrent `set` calls within this process;
        // the pid distinguishes processes sharing the directory.
        static TMP_SEQ: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

        std::fs::create_dir_all(&self.dir)
            .map_err(|e| CacheError::Connection(format!("create_dir_all: {e}")))?;
        let path = self.key_path(key);
        let seq = TMP_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        // `<hash>.<pid>.<seq>.tmp`: the `.tmp` extension keeps `clear` from treating it as
        // a cache entry.
        let tmp = path.with_extension(format!("{}.{seq}.tmp", std::process::id()));
        if let Err(e) = std::fs::write(&tmp, Self::encode(value, ttl)) {
            let _ = std::fs::remove_file(&tmp);
            return Err(CacheError::Connection(format!("write: {e}")));
        }
        std::fs::rename(&tmp, &path).map_err(|e| {
            let _ = std::fs::remove_file(&tmp);
            CacheError::Connection(format!("rename: {e}"))
        })
    }

    /// Deletes the entry for `key`; a missing file is not an error.
    async fn delete(&self, key: &str) -> Result<(), CacheError> {
        let path = self.key_path(key);
        match std::fs::remove_file(&path) {
            Ok(()) => Ok(()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(CacheError::Connection(format!("remove_file: {e}"))),
        }
    }

    /// Delegates to [`Cache::get`], so expired or corrupt files are removed as a side effect.
    async fn exists(&self, key: &str) -> Result<bool, CacheError> {
        Ok(self.get(key).await?.is_some())
    }

    /// Removes every `*.cache` file in the directory (best-effort per file).
    ///
    /// Other files in the directory are left alone. A missing directory counts as
    /// already clear.
    async fn clear(&self) -> Result<(), CacheError> {
        let entries = match std::fs::read_dir(&self.dir) {
            Ok(e) => e,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
            Err(e) => return Err(CacheError::Connection(format!("read_dir: {e}"))),
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.extension().and_then(|s| s.to_str()) == Some("cache") {
                let _ = std::fs::remove_file(&path);
            }
        }
        Ok(())
    }
}
#[cfg(all(test, feature = "config"))]
mod settings_tests {
    use super::*;

    /// Default settings with `cache.backend` set to `name`.
    fn with_backend(name: &'static str) -> crate::config::CacheSettings {
        let mut settings = crate::config::CacheSettings::default();
        settings.backend = Some(name.into());
        settings
    }

    /// Asserts that `cache` keeps what is written to it (i.e. it is not a null cache).
    async fn assert_roundtrip(cache: &BoxedCache) {
        cache.set("k", "v", None).await.unwrap();
        assert_eq!(cache.get("k").await.unwrap().as_deref(), Some("v"));
    }

    #[tokio::test]
    async fn unset_backend_returns_inmemory() {
        let cache = from_settings(&crate::config::CacheSettings::default());
        assert_roundtrip(&cache).await;
    }

    #[tokio::test]
    async fn memory_backend_works() {
        let cache = from_settings(&with_backend("memory"));
        assert_roundtrip(&cache).await;
    }

    #[tokio::test]
    async fn null_backend_drops_writes() {
        let cache = from_settings(&with_backend("null"));
        cache.set("k", "v", None).await.unwrap();
        assert!(cache.get("k").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn unknown_backend_falls_back_to_inmemory() {
        let cache = from_settings(&with_backend("typo"));
        assert_roundtrip(&cache).await;
    }

    /// With or without the `cache-redis` feature, a redis backend lacking a URL
    /// degrades to an in-memory cache.
    #[tokio::test]
    async fn redis_without_url_falls_back_to_inmemory() {
        let cache = from_settings(&with_backend("redis"));
        assert_roundtrip(&cache).await;
    }
}