use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
pub mod db_backend;
#[cfg(feature = "cache-redis")]
pub mod redis_backend;
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
pub use db_backend::DatabaseCache;
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
#[error("cache connection error: {0}")]
Connection(String),
#[error("cache serialization error: {0}")]
Serialization(String),
}
/// Asynchronous string key/value cache implemented by every backend in this module.
///
/// Implementors provide the five primitives (`get`, `set`, `delete`, `exists`,
/// `clear`). Every other method has a default built only from those primitives.
/// Because the composite defaults issue several independent calls, they are
/// **not atomic**: e.g. two concurrent default `incr`s on one key may both read
/// the same starting value, losing one update. Backends able to do better should
/// override them.
#[async_trait]
pub trait Cache: Send + Sync + 'static {
    /// Returns the value stored under `key`, or `None` on a miss.
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError>;

    /// Stores `value` under `key`, replacing any previous value.
    ///
    /// How `ttl = None` is interpreted is backend-specific (for example
    /// `InMemoryCache` applies its default TTL, `FileCache` stores without expiry).
    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError>;

    /// Removes `key` if present.
    async fn delete(&self, key: &str) -> Result<(), CacheError>;

    /// Reports whether `key` currently holds a value.
    async fn exists(&self, key: &str) -> Result<bool, CacheError>;

    /// Removes every entry held by this cache.
    async fn clear(&self) -> Result<(), CacheError>;

    /// Adds `by` to the integer under `key` and returns the new value.
    ///
    /// A missing key, or a value that does not parse as `i64`, counts as `0`.
    /// The sum saturates at the `i64` bounds. The result is written back with
    /// `ttl`, so every call resets the entry's lifetime. The default is a
    /// non-atomic read-modify-write.
    async fn incr(&self, key: &str, by: i64, ttl: Option<Duration>) -> Result<i64, CacheError> {
        let current = match self.get(key).await? {
            Some(text) => text.parse::<i64>().unwrap_or(0),
            None => 0,
        };
        let updated = current.saturating_add(by);
        self.set(key, &updated.to_string(), ttl).await?;
        Ok(updated)
    }

    /// Stores `value` only if `key` is absent; returns whether it was stored.
    ///
    /// The default checks `exists` and then calls `set`, which is not atomic.
    async fn add(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<bool, CacheError> {
        let already_present = self.exists(key).await?;
        if !already_present {
            self.set(key, value, ttl).await?;
        }
        Ok(!already_present)
    }

    /// Re-applies `ttl` to an existing entry by rewriting its current value.
    ///
    /// Returns `false` (and writes nothing) when `key` is missing. The default
    /// reads and then writes, which is not atomic.
    async fn touch(&self, key: &str, ttl: Option<Duration>) -> Result<bool, CacheError> {
        let Some(current) = self.get(key).await? else {
            return Ok(false);
        };
        self.set(key, &current, ttl).await?;
        Ok(true)
    }

    /// Fetches several keys with one `get` per key.
    ///
    /// Missing keys are simply absent from the returned map. The first error
    /// aborts the whole call.
    async fn get_many(&self, keys: &[&str]) -> Result<HashMap<String, String>, CacheError> {
        let mut found = HashMap::with_capacity(keys.len());
        for &key in keys {
            if let Some(value) = self.get(key).await? {
                found.insert(key.to_owned(), value);
            }
        }
        Ok(found)
    }

    /// Writes every `(key, value)` pair with the same `ttl`, one `set` at a time.
    ///
    /// Stops at the first error; pairs written before it are kept.
    async fn set_many(
        &self,
        entries: &[(&str, &str)],
        ttl: Option<Duration>,
    ) -> Result<(), CacheError> {
        for &(key, value) in entries {
            self.set(key, value, ttl).await?;
        }
        Ok(())
    }

    /// Deletes every key, one `delete` at a time, stopping at the first error.
    async fn delete_many(&self, keys: &[&str]) -> Result<(), CacheError> {
        for &key in keys {
            self.delete(key).await?;
        }
        Ok(())
    }

    /// Alias for [`Cache::exists`].
    async fn has_key(&self, key: &str) -> Result<bool, CacheError> {
        self.exists(key).await
    }

    /// Subtracts `by` from the integer under `key` by delegating to [`Cache::incr`].
    ///
    /// Because it delegates, any backend override of `incr` is used here too.
    /// The negation saturates, so `by = i64::MIN` is applied as `+i64::MAX`.
    async fn decr(&self, key: &str, by: i64, ttl: Option<Duration>) -> Result<i64, CacheError> {
        let delta = by.saturating_neg();
        self.incr(key, delta, ttl).await
    }

    /// Returns the cached value, or an owned copy of `default` on a miss.
    ///
    /// `default` is never written back to the cache.
    async fn get_or(&self, key: &str, default: &str) -> Result<String, CacheError> {
        match self.get(key).await? {
            Some(hit) => Ok(hit),
            None => Ok(default.to_owned()),
        }
    }
}

/// Shared, type-erased handle to any [`Cache`] backend.
pub type BoxedCache = Arc<dyn Cache>;
/// Builds a [`BoxedCache`] from the `[cache]` settings section.
///
/// Recognised `backend` values (matched exactly, case-sensitive):
/// * unset or `"memory"` → [`InMemoryCache`]
/// * `"null"` / `"none"` → [`NullCache`]
/// * `"file"` → [`FileCache`] via the `file_cache_dir` setting (falls back to
///   in-memory with a warning when that is unset)
/// * `"redis"`, `"db"` / `"database"` → these backends need async construction,
///   so a warning is logged and an [`InMemoryCache`] is returned instead; build
///   them yourself and pass the `Arc` directly
/// * anything else → warning + [`InMemoryCache`]
#[cfg(feature = "config")]
#[must_use]
pub fn from_settings(s: &crate::config::CacheSettings) -> BoxedCache {
    let fallback = || -> BoxedCache { Arc::new(InMemoryCache::new()) };

    match s.backend.as_deref() {
        None | Some("memory") => fallback(),
        Some("null" | "none") => Arc::new(NullCache),
        Some("file") => file_from_settings_or_warn(s),
        Some("redis") => {
            #[cfg(not(feature = "cache-redis"))]
            {
                tracing::warn!(
                    target: "rustango::cache",
                    "cache.backend = \"redis\" but the `cache-redis` feature isn't compiled in; falling back to InMemoryCache",
                );
            }
            #[cfg(feature = "cache-redis")]
            {
                let url_missing = s.redis_url.as_deref().map_or(true, |u| u.is_empty());
                if url_missing {
                    tracing::warn!(
                        target: "rustango::cache",
                        "cache.backend = \"redis\" but redis_url is unset; falling back to InMemoryCache",
                    );
                } else {
                    tracing::warn!(
                        target: "rustango::cache",
                        "cache.backend = \"redis\" requires async construction; \
                         build `RedisCache::new(url).await?` and pass the Arc \
                         directly. Falling back to InMemoryCache."
                    );
                }
            }
            fallback()
        }
        Some("db" | "database") => {
            tracing::warn!(
                target: "rustango::cache",
                "cache.backend = \"db\" requires async construction with a `&Pool`; \
                 build `DatabaseCache::new(pool, table)` and call `ensure_table().await` \
                 then pass the Arc directly. Falling back to InMemoryCache."
            );
            fallback()
        }
        Some(other) => {
            tracing::warn!(
                target: "rustango::cache",
                backend = %other,
                "unknown cache.backend value; falling back to InMemoryCache",
            );
            fallback()
        }
    }
}
/// Builds a [`FileCache`] rooted at `[cache].file_cache_dir`.
///
/// Falls back to an [`InMemoryCache`], with a warning, when the directory is
/// unset *or empty*. An empty string is treated like "unset" because
/// `FileCache::new("")` joins every key's file name onto an empty path, which
/// scatters cache files across the process's current working directory.
#[cfg(feature = "config")]
fn file_from_settings_or_warn(s: &crate::config::CacheSettings) -> BoxedCache {
    let dir = s
        .file_cache_dir
        .as_deref()
        // `Path::new` accepts both `&str` and `&Path`, so this works whichever
        // string-like type the setting uses.
        .filter(|d| !std::path::Path::new(*d).as_os_str().is_empty());

    match dir {
        Some(dir) => Arc::new(FileCache::new(dir)),
        None => {
            tracing::warn!(
                target: "rustango::cache",
                "cache.backend = \"file\" but [cache].file_cache_dir is unset or empty; \
                 falling back to InMemoryCache.",
            );
            Arc::new(InMemoryCache::new())
        }
    }
}
/// Reads `key` from `cache` and decodes the stored text as JSON into `T`.
///
/// Returns `Ok(None)` on a cache miss.
///
/// # Errors
///
/// Propagates any error from [`Cache::get`], and returns
/// [`CacheError::Serialization`] when the stored text is not valid JSON for `T`.
pub async fn get_json<T: serde::de::DeserializeOwned>(
    cache: &dyn Cache,
    key: &str,
) -> Result<Option<T>, CacheError> {
    match cache.get(key).await? {
        None => Ok(None),
        Some(raw) => serde_json::from_str::<T>(&raw)
            .map(Some)
            .map_err(|err| CacheError::Serialization(err.to_string())),
    }
}
/// Encodes `value` as JSON and stores it under `key` with the given `ttl`.
///
/// # Errors
///
/// Returns [`CacheError::Serialization`] if `value` cannot be encoded, and
/// otherwise propagates any error from [`Cache::set`].
pub async fn set_json<T: serde::Serialize>(
    cache: &dyn Cache,
    key: &str,
    value: &T,
    ttl: Option<Duration>,
) -> Result<(), CacheError> {
    let encoded = match serde_json::to_string(value) {
        Ok(text) => text,
        Err(err) => return Err(CacheError::Serialization(err.to_string())),
    };
    cache.set(key, &encoded, ttl).await
}
pub async fn get_or_set<T, F, Fut>(
cache: &dyn Cache,
key: &str,
factory: F,
ttl: Option<Duration>,
) -> Result<T, CacheError>
where
T: serde::Serialize + serde::de::DeserializeOwned,
F: FnOnce() -> Fut + Send,
Fut: std::future::Future<Output = T> + Send,
{
if let Some(cached) = get_json::<T>(cache, key).await? {
return Ok(cached);
}
let value = factory().await;
set_json(cache, key, &value, ttl).await?;
Ok(value)
}
/// A [`Cache`] that stores nothing.
///
/// Writes, deletes and `clear` succeed without effect; `get` always misses and
/// `exists` always reports `false`. Useful for disabling caching without
/// touching call sites.
#[derive(Debug, Clone, Copy, Default)]
pub struct NullCache;

#[async_trait]
impl Cache for NullCache {
    /// Always a miss.
    async fn get(&self, _key: &str) -> Result<Option<String>, CacheError> {
        Ok(None)
    }

    /// Discards the value.
    async fn set(
        &self,
        _key: &str,
        _value: &str,
        _ttl: Option<Duration>,
    ) -> Result<(), CacheError> {
        Ok(())
    }

    /// Nothing to delete.
    async fn delete(&self, _key: &str) -> Result<(), CacheError> {
        Ok(())
    }

    /// Nothing is ever stored.
    async fn exists(&self, _key: &str) -> Result<bool, CacheError> {
        Ok(false)
    }

    /// Nothing to clear.
    async fn clear(&self) -> Result<(), CacheError> {
        Ok(())
    }
}
/// One value held by `InMemoryCache`, together with its bookkeeping.
struct CacheEntry {
    /// The stored value.
    value: String,
    /// Deadline after which the entry is treated as absent; `None` = never expires.
    expires_at: Option<Instant>,
    /// Logical timestamp of the last write/read (from the cache's tick counter),
    /// used to pick LRU victims.
    last_used: AtomicU64,
    /// Bytes charged against the byte budget (`key.len() + value.len()`).
    size: usize,
}

impl CacheEntry {
    /// Returns `true` once the deadline has been reached.
    ///
    /// The deadline instant itself already counts as expired (`>=`), matching
    /// `FileCache`'s check. This way a zero TTL reliably produces an entry that
    /// is already gone, instead of one that survives until the clock advances.
    fn is_expired(&self) -> bool {
        self.expires_at
            .is_some_and(|deadline| Instant::now() >= deadline)
    }
}
/// Mutable state of an `InMemoryCache`, kept behind its `RwLock`.
struct Store {
    /// Running sum of `CacheEntry::size` across every entry in `map`.
    used_bytes: usize,
    /// Entries keyed by cache key. Expired entries may linger here until a write
    /// triggers eviction.
    map: HashMap<String, CacheEntry>,
}
/// Default byte budget for `InMemoryCache`: 256 MiB of key + value bytes.
pub const DEFAULT_MAX_BYTES: usize = 256 << 20;
/// Default entry-count budget for `InMemoryCache`.
pub const DEFAULT_MAX_ENTRIES: usize = 100_000;
/// Process-local cache bounded by bytes and entry count, with LRU eviction.
///
/// Two independent budgets apply: total key + value bytes, and number of
/// entries. Either can be disabled by setting it to `0`. When a write leaves
/// the cache over budget, expired entries are dropped first, then
/// least-recently-used ones; at least one entry is always kept.
pub struct InMemoryCache {
    /// Maximum total key + value bytes; `0` disables this budget.
    max_bytes: usize,
    /// Maximum number of entries; `0` disables this budget.
    max_entries: usize,
    /// TTL applied to writes that pass `ttl = None`.
    default_ttl: Option<Duration>,
    /// Counter stamped onto entries on write/read to order them for LRU.
    tick: AtomicU64,
    /// Entries plus the running byte total.
    inner: tokio::sync::RwLock<Store>,
}
impl InMemoryCache {
    /// Creates a cache with no default TTL and the default budgets
    /// ([`DEFAULT_MAX_BYTES`], [`DEFAULT_MAX_ENTRIES`]).
    #[must_use]
    pub fn new() -> Self {
        Self::build(None, DEFAULT_MAX_BYTES, DEFAULT_MAX_ENTRIES)
    }

    /// Like [`InMemoryCache::new`], but writes that pass `ttl = None` expire
    /// after `default_ttl`.
    #[must_use]
    pub fn with_default_ttl(default_ttl: Duration) -> Self {
        Self::build(Some(default_ttl), DEFAULT_MAX_BYTES, DEFAULT_MAX_ENTRIES)
    }

    /// Sets the byte budget (key + value bytes); `0` disables it.
    #[must_use]
    pub fn with_max_bytes(mut self, max_bytes: usize) -> Self {
        self.max_bytes = max_bytes;
        self
    }

    /// Sets the entry-count budget; `0` disables it.
    #[must_use]
    pub fn with_max_entries(mut self, max_entries: usize) -> Self {
        self.max_entries = max_entries;
        self
    }

    /// Shared constructor behind the public builders.
    fn build(default_ttl: Option<Duration>, max_bytes: usize, max_entries: usize) -> Self {
        Self {
            inner: tokio::sync::RwLock::new(Store {
                map: HashMap::new(),
                used_bytes: 0,
            }),
            default_ttl,
            max_bytes,
            max_entries,
            tick: AtomicU64::new(0),
        }
    }

    /// Converts the per-call TTL (or, when `None`, the default TTL) into an
    /// absolute deadline. `None` means "never expires".
    fn resolve_ttl(&self, ttl: Option<Duration>) -> Option<Instant> {
        let effective = ttl.or(self.default_ttl)?;
        // `Instant + Duration` panics on overflow (e.g. `Duration::MAX`). A
        // deadline too far out to represent is effectively "never", so map it
        // to `None`.
        Instant::now().checked_add(effective)
    }

    /// Returns the next LRU timestamp; only relative order matters.
    fn next_tick(&self) -> u64 {
        self.tick.fetch_add(1, Ordering::Relaxed)
    }

    /// True when any enabled (non-zero) budget is exceeded.
    fn over_budget(&self, s: &Store) -> bool {
        (self.max_bytes > 0 && s.used_bytes > self.max_bytes)
            || (self.max_entries > 0 && s.map.len() > self.max_entries)
    }

    /// Brings `store` back under budget; the caller must hold the write lock.
    ///
    /// This is a no-op while within budget. Otherwise it first drops every
    /// expired entry, then evicts least-recently-used entries one at a time
    /// until within budget or only one entry remains.
    fn evict_locked(&self, store: &mut Store) {
        if !self.over_budget(store) {
            return;
        }

        // Phase 1: sweep expired entries in place. This avoids cloning every
        // expired key into a temporary Vec and hashing it a second time.
        let mut freed = 0usize;
        store.map.retain(|_, entry| {
            let expired = entry.is_expired();
            if expired {
                freed += entry.size;
            }
            !expired
        });
        store.used_bytes = store.used_bytes.saturating_sub(freed);

        // Phase 2: LRU. Always keep at least one entry, so a single oversized
        // value is retained rather than evicting everything.
        while self.over_budget(store) && store.map.len() > 1 {
            let Some(victim) = store
                .map
                .iter()
                .min_by_key(|(_, e)| e.last_used.load(Ordering::Relaxed))
                .map(|(k, _)| k.clone())
            else {
                break;
            };
            if let Some(e) = store.map.remove(&victim) {
                store.used_bytes = store.used_bytes.saturating_sub(e.size);
            }
        }
    }
}
impl Default for InMemoryCache {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Cache for InMemoryCache {
async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
let store = self.inner.read().await;
Ok(store.map.get(key).and_then(|e| {
if e.is_expired() {
None
} else {
e.last_used.store(self.next_tick(), Ordering::Relaxed);
Some(e.value.clone())
}
}))
}
async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError> {
let expires_at = self.resolve_ttl(ttl);
let size = key.len() + value.len();
let tick = self.next_tick();
let mut store = self.inner.write().await;
if let Some(old) = store.map.remove(key) {
store.used_bytes = store.used_bytes.saturating_sub(old.size);
}
store.used_bytes += size;
store.map.insert(
key.to_owned(),
CacheEntry {
value: value.to_owned(),
expires_at,
last_used: AtomicU64::new(tick),
size,
},
);
self.evict_locked(&mut store);
Ok(())
}
async fn delete(&self, key: &str) -> Result<(), CacheError> {
let mut store = self.inner.write().await;
if let Some(e) = store.map.remove(key) {
store.used_bytes = store.used_bytes.saturating_sub(e.size);
}
Ok(())
}
async fn exists(&self, key: &str) -> Result<bool, CacheError> {
let store = self.inner.read().await;
Ok(store.map.get(key).map_or(false, |e| !e.is_expired()))
}
async fn clear(&self) -> Result<(), CacheError> {
let mut store = self.inner.write().await;
store.map.clear();
store.used_bytes = 0;
Ok(())
}
}
/// A [`Cache`] that persists each entry as its own file inside a directory.
///
/// File names are the hex SHA-256 of the key plus a `.cache` extension. Each
/// file holds an 8-byte big-endian Unix-seconds expiry (`0` = never expires)
/// followed by the UTF-8 value.
#[derive(Debug, Clone)]
pub struct FileCache {
    dir: std::path::PathBuf,
}

impl FileCache {
    /// Creates a cache rooted at `dir`. Nothing touches the filesystem until
    /// the first write, which creates the directory if needed.
    #[must_use]
    pub fn new(dir: impl Into<std::path::PathBuf>) -> Self {
        Self { dir: dir.into() }
    }

    /// The directory this cache reads and writes.
    #[must_use]
    pub fn dir(&self) -> &std::path::Path {
        &self.dir
    }

    /// Maps `key` to `<dir>/<hex sha256(key)>.cache`. Hashing makes arbitrary
    /// keys filesystem-safe and fixed-length.
    fn key_path(&self, key: &str) -> std::path::PathBuf {
        use sha2::{Digest, Sha256};
        let hash = Sha256::digest(key.as_bytes());
        let mut name = String::with_capacity(64 + 6);
        for b in hash {
            use std::fmt::Write as _;
            let _ = write!(&mut name, "{b:02x}");
        }
        name.push_str(".cache");
        self.dir.join(name)
    }

    /// Converts a second count to `i64`, clamping instead of wrapping.
    fn secs_to_i64(secs: u64) -> i64 {
        // Lossless: the value is clamped into `i64` range before the cast.
        secs.min(i64::MAX as u64) as i64
    }

    /// Current Unix time in whole seconds; `0` if the clock reads before the epoch.
    fn now_unix_secs() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| Self::secs_to_i64(d.as_secs()))
            .unwrap_or(0)
    }

    /// Serialises `value` with its absolute expiry header (`0` when `ttl` is `None`).
    ///
    /// The TTL is rounded *up* to whole seconds, so a sub-second TTL no longer
    /// truncates to 0 and expires immediately. It is also clamped, so huge TTLs
    /// (e.g. `Duration::MAX`, which `as i64` used to wrap to -1) mean "far
    /// future" instead of "already expired".
    fn encode(value: &str, ttl: Option<Duration>) -> Vec<u8> {
        let expires_at = ttl.map_or(0, |d| {
            let whole_secs = d.as_secs().saturating_add(u64::from(d.subsec_nanos() > 0));
            Self::now_unix_secs().saturating_add(Self::secs_to_i64(whole_secs))
        });
        let mut out = Vec::with_capacity(8 + value.len());
        out.extend_from_slice(&expires_at.to_be_bytes());
        out.extend_from_slice(value.as_bytes());
        out
    }

    /// Parses a file produced by [`FileCache::encode`] into `(value, expired)`.
    ///
    /// Returns `None` when the buffer is shorter than the header or the body is
    /// not UTF-8.
    fn decode(buf: &[u8]) -> Option<(String, bool /* expired */)> {
        if buf.len() < 8 {
            return None;
        }
        let mut ts = [0u8; 8];
        ts.copy_from_slice(&buf[..8]);
        let expires_at = i64::from_be_bytes(ts);
        let value = std::str::from_utf8(&buf[8..]).ok()?.to_owned();
        let expired = expires_at != 0 && Self::now_unix_secs() >= expires_at;
        Some((value, expired))
    }
}
#[async_trait]
impl Cache for FileCache {
async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
let path = self.key_path(key);
let buf = match std::fs::read(&path) {
Ok(b) => b,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(e) => return Err(CacheError::Connection(format!("read: {e}"))),
};
match Self::decode(&buf) {
Some((_, true)) => {
let _ = std::fs::remove_file(&path);
Ok(None)
}
Some((v, false)) => Ok(Some(v)),
None => {
let _ = std::fs::remove_file(&path);
Ok(None)
}
}
}
async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError> {
std::fs::create_dir_all(&self.dir)
.map_err(|e| CacheError::Connection(format!("create_dir_all: {e}")))?;
let path = self.key_path(key);
std::fs::write(&path, Self::encode(value, ttl))
.map_err(|e| CacheError::Connection(format!("write: {e}")))?;
Ok(())
}
async fn delete(&self, key: &str) -> Result<(), CacheError> {
let path = self.key_path(key);
match std::fs::remove_file(&path) {
Ok(()) => Ok(()),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
Err(e) => Err(CacheError::Connection(format!("remove_file: {e}"))),
}
}
async fn exists(&self, key: &str) -> Result<bool, CacheError> {
Ok(self.get(key).await?.is_some())
}
async fn clear(&self) -> Result<(), CacheError> {
let entries = match std::fs::read_dir(&self.dir) {
Ok(e) => e,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
Err(e) => return Err(CacheError::Connection(format!("read_dir: {e}"))),
};
for entry in entries.flatten() {
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("cache") {
let _ = std::fs::remove_file(&path);
}
}
Ok(())
}
}
#[cfg(all(test, feature = "config"))]
mod settings_tests {
    // Tests for `from_settings`: each backend value must yield a usable cache.
    use super::*;

    #[tokio::test]
    async fn unset_backend_returns_inmemory() {
        let s = crate::config::CacheSettings::default();
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert_eq!(cache.get("k").await.unwrap().as_deref(), Some("v"));
    }

    #[tokio::test]
    async fn memory_backend_works() {
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("memory".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert_eq!(cache.get("k").await.unwrap().as_deref(), Some("v"));
    }

    #[tokio::test]
    async fn null_backend_drops_writes() {
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("null".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert!(cache.get("k").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn none_alias_is_null_backend() {
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("none".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert!(cache.get("k").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn unknown_backend_falls_back_to_inmemory() {
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("typo".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert_eq!(cache.get("k").await.unwrap().as_deref(), Some("v"));
    }

    #[tokio::test]
    async fn redis_without_url_falls_back_to_inmemory() {
        // With or without the `cache-redis` feature, `from_settings` cannot
        // build a Redis client synchronously, so the expectation is the same.
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("redis".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert_eq!(cache.get("k").await.unwrap().as_deref(), Some("v"));
    }

    #[tokio::test]
    async fn db_backend_falls_back_to_inmemory() {
        let mut s = crate::config::CacheSettings::default();
        s.backend = Some("db".into());
        let cache = from_settings(&s);
        cache.set("k", "v", None).await.unwrap();
        assert_eq!(cache.get("k").await.unwrap().as_deref(), Some("v"));
    }
}
#[cfg(test)]
mod bound_tests {
    // Budget accounting and LRU eviction for `InMemoryCache`.
    use super::*;

    /// A payload of exactly `len` bytes.
    fn val(len: usize) -> String {
        "x".repeat(len)
    }

    #[tokio::test]
    async fn byte_budget_caps_unique_key_flood() {
        const BUDGET: usize = 10 * 1024;
        let cache = InMemoryCache::new()
            .with_max_bytes(BUDGET)
            .with_max_entries(0);
        let payload = val(1024);
        for i in 0..1000 {
            let key = format!("k{i}");
            cache.set(&key, &payload, None).await.unwrap();
        }

        let store = cache.inner.read().await;
        assert!(
            store.used_bytes <= BUDGET,
            "used_bytes {} exceeded byte budget",
            store.used_bytes
        );
        // Each entry costs ~1 KiB, so only about ten can fit.
        assert!(
            store.map.len() <= 12,
            "entry count {} too high",
            store.map.len()
        );
    }

    #[tokio::test]
    async fn entry_budget_caps_count() {
        let cache = InMemoryCache::new().with_max_bytes(0).with_max_entries(5);
        for i in 0..50 {
            let key = format!("k{i}");
            cache.set(&key, "v", None).await.unwrap();
        }
        let count = cache.inner.read().await.map.len();
        assert!(count <= 5);
    }

    #[tokio::test]
    async fn lru_keeps_recently_used() {
        let cache = InMemoryCache::new().with_max_bytes(0).with_max_entries(3);
        for key in ["hot", "a", "b"] {
            cache.set(key, "v", None).await.unwrap();
        }
        // Reading "hot" makes it the most recently used, so the two inserts
        // below must evict "a" and "b" instead.
        let _ = cache.get("hot").await.unwrap();
        for key in ["c", "d"] {
            cache.set(key, "v", None).await.unwrap();
        }
        assert_eq!(cache.get("hot").await.unwrap().as_deref(), Some("v"));
    }

    #[tokio::test]
    async fn zero_budget_is_unbounded() {
        let cache = InMemoryCache::new().with_max_bytes(0).with_max_entries(0);
        for i in 0..1000 {
            let key = format!("k{i}");
            cache.set(&key, "v", None).await.unwrap();
        }
        let count = cache.inner.read().await.map.len();
        assert_eq!(count, 1000);
    }

    #[tokio::test]
    async fn delete_frees_bytes() {
        let cache = InMemoryCache::new();
        cache.set("k", &val(4096), None).await.unwrap();
        assert!(cache.inner.read().await.used_bytes >= 4096);

        cache.delete("k").await.unwrap();
        assert_eq!(cache.inner.read().await.used_bytes, 0);
    }
}