use super::{Driver, Error as CacheError};
use async_trait::async_trait;
use redis::AsyncCommands;
use std::{fmt::Debug, time::Duration};
/// Connection settings used to build a [`RedisDriver`].
#[derive(Clone)]
pub struct Config {
    /// Text prepended verbatim to every cache key before it reaches Redis.
    pub prefix: String,
    /// Connection URL passed to `redis::Client::open`.
    pub redis_url: String,
}

impl Default for Config {
    /// No key prefix, connecting to `redis://localhost`.
    fn default() -> Self {
        let redis_url = String::from("redis://localhost");
        let prefix = String::default();
        Self { prefix, redis_url }
    }
}
/// Cache driver that stores entries in Redis under an optional key prefix.
// The type name deliberately repeats the module name; silence the pedantic
// lint rather than rename a public type.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct RedisDriver {
    prefix: String,
    client: redis::Client,
}

impl RedisDriver {
    /// Builds a driver from `config`, keeping its prefix and opening a
    /// `redis::Client` for its URL.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `redis::Client::open` if it rejects
    /// `config.redis_url`.
    pub fn new(config: Config) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(config.redis_url)?;
        Ok(Self {
            prefix: config.prefix,
            client,
        })
    }

    /// Returns `key` with the configured prefix glued to its front.
    fn prefixed_key(&self, key: &str) -> String {
        [self.prefix.as_str(), key].concat()
    }
}
/// Boxes a Redis client error into `CacheError::Other` so it can be returned
/// through the `Driver` interface.
fn map_redis_err(err: redis::RedisError) -> CacheError {
    let boxed = Box::new(err);
    CacheError::Other(boxed)
}
// Every method opens its own async connection from `self.client` and converts
// Redis failures with `map_redis_err`.
// NOTE(review): a connection per call presumably means a fresh connect on every
// operation; consider reusing a (multiplexed) connection instead.
#[async_trait]
impl Driver for RedisDriver {
    /// Fetches the raw bytes stored under the prefixed `key`, or `None` when the
    /// key does not exist.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, CacheError> {
        let mut conn = self
            .client
            .get_async_connection()
            .await
            .map_err(map_redis_err)?;
        conn.get(self.prefixed_key(key))
            .await
            .map_err(map_redis_err)
    }

    /// Reports whether the prefixed `key` exists (`EXISTS`).
    async fn has(&self, key: &str) -> Result<bool, CacheError> {
        let mut conn = self
            .client
            .get_async_connection()
            .await
            .map_err(map_redis_err)?;
        conn.exists(self.prefixed_key(key))
            .await
            .map_err(map_redis_err)
    }

    /// Stores `value` under the prefixed `key`.
    ///
    /// With `Some(expiry)` the key is written with a TTL in milliseconds
    /// (`PSETEX`), rounded up to the next whole millisecond and never shorter
    /// than 1 ms. With `None` it is written without a TTL (`SET`).
    async fn put(
        &self,
        key: &str,
        value: Vec<u8>,
        expiry: Option<Duration>,
    ) -> Result<(), CacheError> {
        let mut conn = self
            .client
            .get_async_connection()
            .await
            .map_err(map_redis_err)?;
        if let Some(expiry) = expiry {
            // Sending `expiry.as_secs()` truncated any sub-second expiry to 0,
            // which Redis rejects as an invalid expire time. It also dropped the
            // fractional part of longer expiries, so keys expired early.
            // `as_nanos() + 999_999` cannot overflow u128 for any Duration.
            let millis = (expiry.as_nanos() + 999_999) / 1_000_000;
            let millis = u64::try_from(millis.max(1)).unwrap_or(u64::MAX);
            conn.pset_ex::<_, _, ()>(self.prefixed_key(key), value, millis)
                .await
                .map_err(map_redis_err)?;
        } else {
            conn.set::<_, _, ()>(self.prefixed_key(key), value)
                .await
                .map_err(map_redis_err)?;
        }
        Ok(())
    }

    /// Deletes the prefixed `key` (`DEL`); the reply is discarded.
    async fn forget(&self, key: &str) -> Result<(), CacheError> {
        let mut conn = self
            .client
            .get_async_connection()
            .await
            .map_err(map_redis_err)?;
        conn.del::<_, ()>(self.prefixed_key(key))
            .await
            .map_err(map_redis_err)
    }

    /// Clears the cache by issuing `FLUSHDB`.
    ///
    /// NOTE(review): `FLUSHDB` empties the whole selected Redis database and
    /// ignores `prefix`. Keys written by anything else sharing that database are
    /// removed too.
    async fn flush(&self) -> Result<(), CacheError> {
        let mut conn = self
            .client
            .get_async_connection()
            .await
            .map_err(map_redis_err)?;
        redis::cmd("FLUSHDB")
            .query_async::<()>(&mut conn)
            .await
            .map_err(map_redis_err)
    }
}
#[cfg(test)]
mod tests {
    use std::env;

    use super::*;
    use crate::Cache;

    /// Round-trips a value through the Redis server named by `REDIS_URL`.
    #[tokio::test]
    async fn test_redis_driver() {
        let cache = Cache::new(
            RedisDriver::new(Config {
                // Keep the test's keys in their own namespace so it neither
                // depends on nor clobbers unrelated data stored under `foo`.
                prefix: "test_redis_driver:".to_string(),
                redis_url: env::var("REDIS_URL").expect("REDIS_URL not set"),
            })
            .unwrap(),
        );

        assert_eq!(cache.get::<String>("foo").await.unwrap(), None);
        assert!(!cache.has("foo").await.unwrap());

        cache
            .put("foo", &"bar", Duration::from_secs(1))
            .await
            .unwrap();
        assert_eq!(
            cache.get::<String>("foo").await.unwrap(),
            Some("bar".to_string())
        );
        assert!(cache.has("foo").await.unwrap());

        cache.forget("foo").await.unwrap();
        assert_eq!(cache.get::<String>("foo").await.unwrap(), None);
        assert!(!cache.has("foo").await.unwrap());
    }
}