use async_trait::async_trait;
use dashmap::DashMap;
use std::sync::Arc;
use std::time::Instant;
/// Errors produced by cache drivers and the [`Cache`] facade.
#[derive(Debug)]
pub enum CacheError {
    /// The backing store failed (connection, command, or protocol failure).
    Driver(String),
    /// A value could not be converted to or from its stored representation.
    Serialization(String),
}

impl std::fmt::Display for CacheError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both variants render as a fixed prefix followed by the carried message.
        match *self {
            Self::Driver(ref msg) => write!(f, "Cache driver error: {}", msg),
            Self::Serialization(ref msg) => write!(f, "Cache serialization error: {}", msg),
        }
    }
}

/// Marker impl so `CacheError` works with `?`, `Box<dyn Error>` and error-reporting crates.
impl std::error::Error for CacheError {}
/// Storage backend used by [`Cache`].
///
/// Implementations must be shareable across tasks (`Send + Sync`). Keys and
/// values are plain strings; any serialization happens above this layer.
#[async_trait]
pub trait CacheDriver: Send + Sync {
    /// Returns the value stored under `key`, or `None` on a miss.
    ///
    /// # Errors
    /// Returns a [`CacheError`] if the backend cannot be queried.
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError>;

    /// Stores `value` under `key`, replacing any previous value.
    ///
    /// `ttl_secs` is the lifetime in seconds; `None` stores the value without expiry.
    ///
    /// # Errors
    /// Returns a [`CacheError`] if the backend rejects the write.
    async fn put(&self, key: &str, value: &str, ttl_secs: Option<u64>) -> Result<(), CacheError>;

    /// Removes `key` if present. Removing a missing key is not an error.
    ///
    /// # Errors
    /// Returns a [`CacheError`] if the backend cannot be reached.
    async fn forget(&self, key: &str) -> Result<(), CacheError>;

    /// Removes every entry managed by this driver.
    ///
    /// # Errors
    /// Returns a [`CacheError`] if the backend cannot be reached.
    async fn flush(&self) -> Result<(), CacheError>;

    /// Reports whether `key` currently holds a value.
    ///
    /// The default implementation delegates to [`CacheDriver::get`]. Drivers with
    /// a cheaper existence check can override it.
    ///
    /// # Errors
    /// Propagates any error from `get`.
    async fn has(&self, key: &str) -> Result<bool, CacheError> {
        Ok(self.get(key).await?.is_some())
    }
}
/// A value held by [`MemoryDriver`], together with its optional deadline.
struct CacheEntry {
    /// The cached payload.
    value: String,
    /// Moment after which the entry counts as expired; `None` means it never expires.
    expires_at: Option<Instant>,
}
pub struct MemoryDriver {
store: DashMap<String, CacheEntry>,
}
impl MemoryDriver {
pub fn new() -> Self {
Self {
store: DashMap::new(),
}
}
}
impl Default for MemoryDriver {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl CacheDriver for MemoryDriver {
async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
if let Some(entry) = self.store.get(key) {
if let Some(expires_at) = entry.expires_at {
if Instant::now() > expires_at {
drop(entry);
self.store.remove(key);
return Ok(None);
}
}
Ok(Some(entry.value.clone()))
} else {
Ok(None)
}
}
async fn put(&self, key: &str, value: &str, ttl_secs: Option<u64>) -> Result<(), CacheError> {
let expires_at = ttl_secs.map(|secs| Instant::now() + std::time::Duration::from_secs(secs));
self.store.insert(
key.to_string(),
CacheEntry {
value: value.to_string(),
expires_at,
},
);
Ok(())
}
async fn forget(&self, key: &str) -> Result<(), CacheError> {
self.store.remove(key);
Ok(())
}
async fn flush(&self) -> Result<(), CacheError> {
self.store.clear();
Ok(())
}
async fn has(&self, key: &str) -> Result<bool, CacheError> {
Ok(self.get(key).await?.is_some())
}
}
#[cfg(feature = "cache-redis")]
pub mod redis_driver {
    use super::*;

    /// `COUNT` hint sent with each `SCAN` during `flush`. Redis treats it as a
    /// hint, so pages may be smaller or larger.
    const FLUSH_SCAN_COUNT: usize = 500;

    /// Redis-backed [`CacheDriver`].
    ///
    /// Every key is stored under the `rullst:cache:` prefix, and `flush` only
    /// deletes keys carrying that prefix.
    pub struct RedisDriver {
        client: redis::Client,
        prefix: String,
    }

    impl RedisDriver {
        /// Builds a driver from a Redis connection URL.
        ///
        /// # Errors
        /// Returns [`CacheError::Driver`] if `redis::Client::open` rejects
        /// `redis_url`. Per the `redis` crate docs, `Client::open` only validates
        /// the connection info and does not connect. An unreachable server
        /// therefore surfaces on the first cache operation instead.
        pub fn new(redis_url: &str) -> Result<Self, CacheError> {
            let client = redis::Client::open(redis_url)
                .map_err(|e| CacheError::Driver(format!("Invalid Redis URL: {}", e)))?;
            Ok(Self {
                client,
                prefix: "rullst:cache:".to_string(),
            })
        }

        /// Namespaces `key` with this driver's prefix.
        fn prefixed_key(&self, key: &str) -> String {
            format!("{}{}", self.prefix, key)
        }

        /// Opens the multiplexed async connection used by a single operation.
        // NOTE(review): this opens a new connection for every cache operation;
        // reusing a cloned `MultiplexedConnection` would avoid the per-call handshake.
        async fn connection(&self) -> Result<redis::aio::MultiplexedConnection, CacheError> {
            self.client
                .get_multiplexed_async_connection()
                .await
                .map_err(|e| CacheError::Driver(format!("Redis connection failed: {}", e)))
        }
    }

    #[async_trait]
    impl CacheDriver for RedisDriver {
        /// `GET` on the prefixed key; a nil reply maps to `None`.
        async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
            let mut con = self.connection().await?;
            let result: Option<String> = redis::cmd("GET")
                .arg(self.prefixed_key(key))
                .query_async(&mut con)
                .await
                .map_err(|e| CacheError::Driver(format!("Redis GET failed: {}", e)))?;
            Ok(result)
        }

        /// `SETEX` when a TTL is given, plain `SET` otherwise.
        async fn put(&self, key: &str, value: &str, ttl_secs: Option<u64>) -> Result<(), CacheError> {
            let mut con = self.connection().await?;
            let pk = self.prefixed_key(key);
            match ttl_secs {
                // The TTL is sent as the u64 it is, with no lossy `as i64` cast.
                // Redis itself rejects expirations it considers invalid.
                Some(ttl) => redis::cmd("SETEX")
                    .arg(&pk)
                    .arg(ttl)
                    .arg(value)
                    .query_async::<()>(&mut con)
                    .await
                    .map_err(|e| CacheError::Driver(format!("Redis SETEX failed: {}", e)))?,
                None => redis::cmd("SET")
                    .arg(&pk)
                    .arg(value)
                    .query_async::<()>(&mut con)
                    .await
                    .map_err(|e| CacheError::Driver(format!("Redis SET failed: {}", e)))?,
            }
            Ok(())
        }

        /// `DEL` on the prefixed key; deleting a missing key is not an error.
        async fn forget(&self, key: &str) -> Result<(), CacheError> {
            let mut con = self.connection().await?;
            redis::cmd("DEL")
                .arg(self.prefixed_key(key))
                .query_async::<i64>(&mut con)
                .await
                .map_err(|e| CacheError::Driver(format!("Redis DEL failed: {}", e)))?;
            Ok(())
        }

        /// Deletes every key under this driver's prefix.
        ///
        /// Uses cursor-based `SCAN` rather than `KEYS`. `KEYS` walks the whole
        /// keyspace in one blocking call. Each SCAN page is removed with a single
        /// multi-key `DEL`. Keys written concurrently with the flush may or may not
        /// be removed.
        async fn flush(&self) -> Result<(), CacheError> {
            let mut con = self.connection().await?;
            let pattern = format!("{}*", self.prefix);
            let mut cursor: u64 = 0;
            loop {
                let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
                    .arg(cursor)
                    .arg("MATCH")
                    .arg(&pattern)
                    .arg("COUNT")
                    .arg(FLUSH_SCAN_COUNT)
                    .query_async(&mut con)
                    .await
                    .map_err(|e| CacheError::Driver(format!("Redis SCAN failed: {}", e)))?;
                // SCAN may return empty pages or repeat keys; DEL tolerates repeats.
                if !keys.is_empty() {
                    redis::cmd("DEL")
                        .arg(&keys)
                        .query_async::<i64>(&mut con)
                        .await
                        .map_err(|e| CacheError::Driver(format!("Redis DEL failed: {}", e)))?;
                }
                // A returned cursor of 0 marks the end of the iteration.
                if next_cursor == 0 {
                    break;
                }
                cursor = next_cursor;
            }
            Ok(())
        }

        /// `EXISTS` on the prefixed key.
        async fn has(&self, key: &str) -> Result<bool, CacheError> {
            let mut con = self.connection().await?;
            let exists: bool = redis::cmd("EXISTS")
                .arg(self.prefixed_key(key))
                .query_async(&mut con)
                .await
                .map_err(|e| CacheError::Driver(format!("Redis EXISTS failed: {}", e)))?;
            Ok(exists)
        }
    }
}
/// Backend-agnostic cache handle that forwards every call to its [`CacheDriver`].
///
/// Cloning is cheap: clones share the same driver through an `Arc`.
#[derive(Clone)]
pub struct Cache {
    // `Arc<dyn _>` directly; wrapping the trait object in an extra `Box` only
    // added a second pointer hop.
    driver: Arc<dyn CacheDriver>,
}

impl Cache {
    /// Cache backed by a fresh in-process [`MemoryDriver`].
    pub fn memory() -> Self {
        Self {
            driver: Arc::new(MemoryDriver::new()),
        }
    }

    /// Cache backed by Redis at `redis_url`.
    ///
    /// # Errors
    /// Returns [`CacheError::Driver`] if the driver cannot be constructed.
    #[cfg(feature = "cache-redis")]
    pub fn redis(redis_url: &str) -> Result<Self, CacheError> {
        let driver = redis_driver::RedisDriver::new(redis_url)?;
        Ok(Self {
            driver: Arc::new(driver),
        })
    }

    /// Cache backed by a caller-supplied driver.
    pub fn custom(driver: Box<dyn CacheDriver>) -> Self {
        Self {
            driver: Arc::from(driver),
        }
    }

    /// Looks up `key`; `None` on a miss.
    pub async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
        self.driver.get(key).await
    }

    /// Stores `value` under `key`; `ttl_secs = None` means no expiry.
    pub async fn put(&self, key: &str, value: &str, ttl_secs: Option<u64>) -> Result<(), CacheError> {
        self.driver.put(key, value, ttl_secs).await
    }

    /// Removes `key`.
    pub async fn forget(&self, key: &str) -> Result<(), CacheError> {
        self.driver.forget(key).await
    }

    /// Removes every entry managed by the driver.
    pub async fn flush(&self) -> Result<(), CacheError> {
        self.driver.flush().await
    }

    /// Reports whether `key` currently holds a value.
    pub async fn has(&self, key: &str) -> Result<bool, CacheError> {
        self.driver.has(key).await
    }

    /// Returns the cached value for `key`. On a miss, it runs `f`, stores the
    /// result for `ttl_secs` seconds, and returns it.
    ///
    /// This is not atomic. Concurrent callers that miss at the same time may
    /// each run `f`, and the last write wins.
    ///
    /// # Errors
    /// Propagates errors from the driver or from `f`. If `f` fails, nothing is stored.
    pub async fn remember<F, Fut>(&self, key: &str, ttl_secs: u64, f: F) -> Result<String, CacheError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<String, CacheError>>,
    {
        if let Some(cached) = self.get(key).await? {
            return Ok(cached);
        }
        let value = f().await?;
        self.put(key, &value, Some(ttl_secs)).await?;
        Ok(value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each test gets its own in-memory cache, so no state leaks between tests.
    fn fresh() -> Cache {
        Cache::memory()
    }

    #[tokio::test]
    async fn test_memory_cache_put_get() {
        let store = fresh();
        store.put("key1", "value1", None).await.unwrap();
        let fetched = store.get("key1").await.unwrap();
        assert_eq!(fetched, Some("value1".to_string()));
    }

    #[tokio::test]
    async fn test_memory_cache_miss() {
        let store = fresh();
        assert_eq!(store.get("nonexistent").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_memory_cache_forget() {
        let store = fresh();
        store.put("key1", "value1", None).await.unwrap();
        store.forget("key1").await.unwrap();
        assert_eq!(store.get("key1").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_memory_cache_flush() {
        let store = fresh();
        for (k, v) in [("a", "1"), ("b", "2")] {
            store.put(k, v, None).await.unwrap();
        }
        store.flush().await.unwrap();
        for k in ["a", "b"] {
            assert_eq!(store.get(k).await.unwrap(), None);
        }
    }

    #[tokio::test]
    async fn test_memory_cache_has() {
        let store = fresh();
        assert_eq!(store.has("key1").await.unwrap(), false);
        store.put("key1", "value1", None).await.unwrap();
        assert_eq!(store.has("key1").await.unwrap(), true);
    }

    #[tokio::test]
    async fn test_memory_cache_remember_miss() {
        let store = fresh();
        let computed = store
            .remember("computed", 60, || async { Ok("hello".to_string()) })
            .await
            .unwrap();
        assert_eq!(computed, "hello");
        assert_eq!(store.get("computed").await.unwrap(), Some("hello".to_string()));
    }

    #[tokio::test]
    async fn test_memory_cache_remember_hit() {
        let store = fresh();
        store.put("existing", "already_cached", Some(300)).await.unwrap();
        let hit = store
            .remember("existing", 60, || async {
                panic!("This closure should NOT be called on cache hit");
            })
            .await
            .unwrap();
        assert_eq!(hit, "already_cached");
    }

    #[tokio::test]
    async fn test_memory_cache_overwrite() {
        let store = fresh();
        for v in ["v1", "v2"] {
            store.put("key", v, None).await.unwrap();
        }
        assert_eq!(store.get("key").await.unwrap(), Some("v2".to_string()));
    }
}