use std::time::Duration;
use async_trait::async_trait;
use redis::AsyncCommands;
use super::{Cache, CacheError};
/// A [`Cache`] implementation backed by a Redis server.
///
/// All commands go through a single `redis::aio::ConnectionManager`; every
/// cache operation clones that handle and issues its command(s) on the clone.
pub struct RedisCache {
    /// Connection handle shared (by cloning) across all cache operations.
    conn: redis::aio::ConnectionManager,
    /// TTL used by writes whose caller passes `None`; `None` here means such
    /// writes do not expire.
    default_ttl: Option<Duration>,
}
impl RedisCache {
pub async fn new(url: &str) -> Result<Self, CacheError> {
Self::with_default_ttl(url, None).await
}
pub async fn with_default_ttl(url: &str, default_ttl: Option<Duration>) -> Result<Self, CacheError> {
let client = redis::Client::open(url)
.map_err(|e| CacheError::Connection(e.to_string()))?;
let conn = redis::aio::ConnectionManager::new(client)
.await
.map_err(|e| CacheError::Connection(e.to_string()))?;
Ok(Self { conn, default_ttl })
}
fn effective_ttl(&self, ttl: Option<Duration>) -> Option<u64> {
ttl.or(self.default_ttl).map(|d| d.as_secs().max(1))
}
}
/// Maps any redis-rs error to the cache error type.
///
/// Every failure, including server replies such as `WRONGTYPE`, is reported as
/// `CacheError::Connection` with the Redis error text.
fn redis_to_cache_error(err: redis::RedisError) -> CacheError {
    CacheError::Connection(err.to_string())
}

#[async_trait]
impl Cache for RedisCache {
    /// Returns the string stored at `key`, or `None` when Redis replies nil
    /// (the key does not exist).
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
        let mut conn = self.conn.clone();
        conn.get::<_, Option<String>>(key)
            .await
            .map_err(redis_to_cache_error)
    }

    /// Stores `value` under `key`.
    ///
    /// If a TTL resolves (explicit `ttl`, else the default), it is written with
    /// `SETEX` in whole seconds. Otherwise a plain `SET` is used. Per Redis
    /// `SET` semantics, that also discards any expiry the key previously had.
    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError> {
        let mut conn = self.conn.clone();
        let result = match self.effective_ttl(ttl) {
            Some(secs) => conn.set_ex::<_, _, ()>(key, value, secs).await,
            None => conn.set::<_, _, ()>(key, value).await,
        };
        result.map_err(redis_to_cache_error)
    }

    /// Deletes `key` with `DEL`; the reply (number of keys removed) is discarded.
    async fn delete(&self, key: &str) -> Result<(), CacheError> {
        let mut conn = self.conn.clone();
        conn.del::<_, ()>(key)
            .await
            .map_err(redis_to_cache_error)
    }

    /// Reports whether `key` exists, via `EXISTS`.
    async fn exists(&self, key: &str) -> Result<bool, CacheError> {
        let mut conn = self.conn.clone();
        conn.exists::<_, bool>(key)
            .await
            .map_err(redis_to_cache_error)
    }

    /// Removes all keys with `FLUSHDB`.
    ///
    /// WARNING: this wipes every key in the connection's selected logical
    /// database, not only keys written through this cache.
    async fn clear(&self) -> Result<(), CacheError> {
        let mut conn = self.conn.clone();
        redis::cmd("FLUSHDB")
            .query_async::<()>(&mut conn)
            .await
            .map_err(redis_to_cache_error)
    }

    /// Adds `by` to the integer at `key` and returns the new value.
    ///
    /// A missing key starts from `0` (Redis `INCRBY` semantics). If a TTL
    /// resolves, it is applied with `EXPIRE ... NX`. That sets an expiry only
    /// when the key has none, so an existing expiry is never extended. The
    /// `NX` option requires Redis 7.0 or newer.
    ///
    /// `INCRBY` and `EXPIRE` are sent together in one `MULTI`/`EXEC`
    /// transaction. The counter therefore cannot be left without its TTL,
    /// which could happen before if this future was dropped (e.g. by a
    /// timeout) or the connection failed between two separate round trips.
    /// If `INCRBY` itself fails at execution time (e.g. the value is not an
    /// integer), Redis still runs the queued `EXPIRE`, and an error is returned.
    async fn incr(
        &self,
        key: &str,
        by: i64,
        ttl: Option<Duration>,
    ) -> Result<i64, CacheError> {
        let mut conn = self.conn.clone();

        let mut pipe = redis::pipe();
        pipe.atomic().cmd("INCRBY").arg(key).arg(by);
        if let Some(secs) = self.effective_ttl(ttl) {
            // Only INCRBY's reply is needed; EXPIRE's 0/1 result is ignored.
            pipe.cmd("EXPIRE").arg(key).arg(secs).arg("NX").ignore();
        }

        let (new,): (i64,) = pipe
            .query_async(&mut conn)
            .await
            .map_err(redis_to_cache_error)?;
        Ok(new)
    }
}