use std::time::Duration;
use async_trait::async_trait;
use redis::AsyncCommands;
use super::{Cache, CacheError};
/// [`Cache`] backend that keeps string values in Redis.
///
/// Every operation takes `&self` and clones `conn` for the duration of the call.
pub struct RedisCache {
    /// Expiry used when a write passes `ttl: None`. When this is also `None`,
    /// such writes are stored without any expiry.
    default_ttl: Option<Duration>,
    /// Redis connection handle; cloned per operation.
    conn: redis::aio::ConnectionManager,
}
impl RedisCache {
    /// Connects to the Redis server at `url` with no default expiry.
    ///
    /// # Errors
    /// Returns [`CacheError::Connection`] if the client cannot be created from
    /// `url` or the connection manager fails to connect.
    pub async fn new(url: &str) -> Result<Self, CacheError> {
        Self::with_default_ttl(url, None).await
    }

    /// Connects to `url` and uses `default_ttl` for writes that don't specify one.
    ///
    /// The connection manager is created (and awaited) here, so connection
    /// problems surface from this constructor rather than from the first call.
    ///
    /// # Errors
    /// Returns [`CacheError::Connection`] if `redis::Client::open` rejects `url`
    /// or creating the `ConnectionManager` fails.
    pub async fn with_default_ttl(
        url: &str,
        default_ttl: Option<Duration>,
    ) -> Result<Self, CacheError> {
        // Non-capturing closure, so it is `Copy` and can be reused for both calls.
        let to_cache_err = |err: redis::RedisError| CacheError::Connection(err.to_string());

        let client = redis::Client::open(url).map_err(to_cache_err)?;
        let conn = redis::aio::ConnectionManager::new(client)
            .await
            .map_err(to_cache_err)?;

        Ok(Self { conn, default_ttl })
    }

    /// Resolves the expiry, in whole seconds, for a single write.
    ///
    /// An explicit `ttl` wins over the configured default. The chosen duration
    /// keeps only its whole seconds (the fractional part is dropped). A result
    /// of 0 is raised to 1, presumably because Redis rejects a zero expiry for
    /// `SETEX` (confirm). Returns `None` when neither value is set.
    fn effective_ttl(&self, ttl: Option<Duration>) -> Option<u64> {
        let chosen = match ttl {
            Some(explicit) => Some(explicit),
            None => self.default_ttl,
        };
        chosen.map(|d| match d.as_secs() {
            0 => 1,
            whole => whole,
        })
    }
}
/// Converts a Redis client error into the cache's connection-error variant.
fn redis_err(err: redis::RedisError) -> CacheError {
    CacheError::Connection(err.to_string())
}

#[async_trait]
impl Cache for RedisCache {
    /// Returns the string stored at `key`, or `None` when the key is absent.
    ///
    /// # Errors
    /// Any Redis failure is reported as [`CacheError::Connection`].
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
        let mut conn = self.conn.clone();
        conn.get::<_, Option<String>>(key).await.map_err(redis_err)
    }

    /// Stores `value` under `key`.
    ///
    /// The expiry is `ttl`, falling back to the cache's default TTL. When one
    /// applies, the write is a `SETEX`; otherwise it is a plain `SET` with no
    /// expiry.
    ///
    /// # Errors
    /// Any Redis failure is reported as [`CacheError::Connection`].
    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError> {
        let mut conn = self.conn.clone();
        let result = match self.effective_ttl(ttl) {
            Some(secs) => conn.set_ex::<_, _, ()>(key, value, secs).await,
            None => conn.set::<_, _, ()>(key, value).await,
        };
        result.map_err(redis_err)
    }

    /// Removes `key` via `DEL`; the removed-key count Redis returns is discarded.
    ///
    /// # Errors
    /// Any Redis failure is reported as [`CacheError::Connection`].
    async fn delete(&self, key: &str) -> Result<(), CacheError> {
        let mut conn = self.conn.clone();
        conn.del::<_, ()>(key).await.map_err(redis_err)
    }

    /// Reports whether `key` currently exists.
    ///
    /// # Errors
    /// Any Redis failure is reported as [`CacheError::Connection`].
    async fn exists(&self, key: &str) -> Result<bool, CacheError> {
        let mut conn = self.conn.clone();
        conn.exists::<_, bool>(key).await.map_err(redis_err)
    }

    /// Empties the cache by issuing `FLUSHDB`.
    ///
    /// WARNING: `FLUSHDB` deletes every key in the connection's selected
    /// database, including keys that were not written through this cache.
    /// Point the cache at a dedicated database if that matters.
    ///
    /// # Errors
    /// Any Redis failure is reported as [`CacheError::Connection`].
    async fn clear(&self) -> Result<(), CacheError> {
        let mut conn = self.conn.clone();
        redis::cmd("FLUSHDB")
            .query_async::<()>(&mut conn)
            .await
            .map_err(redis_err)
    }

    /// Adds `by` to the integer stored at `key` (`INCRBY`) and returns the new value.
    ///
    /// If a TTL applies (`ttl`, else the cache default), the key is given that
    /// expiry only when it has none yet (`EXPIRE ... NX`). Later increments
    /// therefore do not push the expiry back. `EXPIRE NX` requires Redis >= 7.0.
    ///
    /// The increment and the expiry are sent as one `MULTI`/`EXEC` transaction,
    /// which gives two guarantees:
    /// - A connection lost mid-way can no longer leave a counter without a TTL.
    /// - An `EXPIRE` rejected at queue time (e.g. `NX` on an older server)
    ///   aborts the whole transaction. The caller never sees an error for an
    ///   increment that was nonetheless applied.
    ///
    /// # Errors
    /// Any Redis failure is reported as [`CacheError::Connection`].
    async fn incr(&self, key: &str, by: i64, ttl: Option<Duration>) -> Result<i64, CacheError> {
        let mut conn = self.conn.clone();
        match self.effective_ttl(ttl) {
            None => redis::cmd("INCRBY")
                .arg(key)
                .arg(by)
                .query_async::<i64>(&mut conn)
                .await
                .map_err(redis_err),
            Some(secs) => {
                // EXPIRE's reply is ignored, so EXEC yields just the INCRBY result.
                let (new,): (i64,) = redis::pipe()
                    .atomic()
                    .cmd("INCRBY")
                    .arg(key)
                    .arg(by)
                    .cmd("EXPIRE")
                    .arg(key)
                    .arg(secs)
                    .arg("NX")
                    .ignore()
                    .query_async(&mut conn)
                    .await
                    .map_err(redis_err)?;
                Ok(new)
            }
        }
    }
}