use crate::IOKash;
use redis::Pipeline;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::fmt::Display;
use std::marker::PhantomData;
use thiserror::Error;
/// Builder for [`RedisCache`], a synchronous Redis cache backed by an r2d2 connection pool.
///
/// Only the key prefix is required. Pool options left as `None` are not passed to the
/// r2d2 builder, so r2d2's own defaults apply for them.
pub struct RedisCacheBuilder<K, V> {
    /// TTL in seconds for entries written by the built cache; `None` writes without expiry.
    seconds: Option<u64>,
    /// Key namespace; `DEFAULT_NAMESPACE` is used when this is `None`.
    namespace: Option<String>,
    /// Per-cache key prefix, appended after the namespace.
    prefix: String,
    /// Explicit connection string; when `None` the `ENV_KEY` environment variable is read.
    connection_string: Option<String>,
    /// Optional r2d2 `max_size`.
    pool_max_size: Option<u32>,
    /// Optional r2d2 `min_idle`.
    pool_min_idle: Option<u32>,
    /// Optional r2d2 `max_lifetime`.
    pool_max_lifetime: Option<std::time::Duration>,
    /// Optional r2d2 `idle_timeout`.
    pool_idle_timeout: Option<std::time::Duration>,
    /// Ties the builder to the key/value types of the cache it produces.
    _phantom: PhantomData<(K, V)>,
}
/// Environment variable read for the Redis connection string when a builder has no
/// explicit connection string set.
const ENV_KEY: &str = "KASH_REDIS_CONNECTION_STRING";
/// Namespace prepended to every key (before the cache prefix) unless a builder overrides it.
const DEFAULT_NAMESPACE: &str = "kash:";
/// Errors that can occur while building a Redis-backed cache.
#[derive(Error, Debug)]
pub enum RedisCacheBuildError {
    /// A `redis::RedisError` raised while opening the client or its connection.
    #[error("redis connection error")]
    Connection(#[from] redis::RedisError),

    /// The r2d2 connection pool could not be built.
    #[error("redis pool error")]
    Pool(#[from] r2d2::Error),

    /// No connection string was configured and the fallback environment variable
    /// could not be read.
    #[error("Connection string not specified or invalid in env var {env_key:?}: {error:?}")]
    MissingConnectionString {
        /// Name of the environment variable that was consulted.
        env_key: String,
        /// Reason the environment variable could not be read.
        error: std::env::VarError,
    },
}
impl<K, V> RedisCacheBuilder<K, V>
where
K: Display,
V: Serialize + DeserializeOwned,
{
#[must_use]
pub fn new(prefix: &str, seconds: Option<u64>) -> RedisCacheBuilder<K, V> {
Self {
seconds,
namespace: None,
prefix: prefix.to_string(),
connection_string: None,
pool_max_size: None,
pool_min_idle: None,
pool_max_lifetime: None,
pool_idle_timeout: None,
_phantom: PhantomData,
}
}
#[must_use]
pub fn set_ttl(mut self, seconds: u64) -> Self {
self.seconds = Some(seconds);
self
}
#[must_use]
pub fn set_namespace(mut self, namespace: &str) -> Self {
self.namespace = Some(namespace.to_string());
self
}
#[must_use]
pub fn set_prefix(mut self, prefix: &str) -> Self {
self.prefix = prefix.to_string();
self
}
#[must_use]
pub fn set_connection_string(mut self, cs: &str) -> Self {
self.connection_string = Some(cs.to_string());
self
}
#[must_use]
pub fn set_connection_pool_max_size(mut self, max_size: u32) -> Self {
self.pool_max_size = Some(max_size);
self
}
#[must_use]
pub fn set_connection_pool_min_idle(mut self, min_idle: u32) -> Self {
self.pool_min_idle = Some(min_idle);
self
}
#[must_use]
pub fn set_connection_pool_max_lifetime(mut self, max_lifetime: std::time::Duration) -> Self {
self.pool_max_lifetime = Some(max_lifetime);
self
}
#[must_use]
pub fn set_connection_pool_idle_timeout(mut self, idle_timeout: std::time::Duration) -> Self {
self.pool_idle_timeout = Some(idle_timeout);
self
}
pub fn connection_string(&self) -> Result<String, RedisCacheBuildError> {
match self.connection_string {
Some(ref s) => Ok(s.to_string()),
None => {
std::env::var(ENV_KEY).map_err(|e| RedisCacheBuildError::MissingConnectionString {
env_key: ENV_KEY.to_string(),
error: e,
})
}
}
}
fn create_pool(&self) -> Result<r2d2::Pool<redis::Client>, RedisCacheBuildError> {
let s = self.connection_string()?;
let client: redis::Client = redis::Client::open(s)?;
let pool_builder = r2d2::Pool::builder();
let pool_builder = if let Some(max_size) = self.pool_max_size {
pool_builder.max_size(max_size)
} else {
pool_builder
};
let pool_builder = if let Some(min_idle) = self.pool_min_idle {
pool_builder.min_idle(Some(min_idle))
} else {
pool_builder
};
let pool_builder = if let Some(max_lifetime) = self.pool_max_lifetime {
pool_builder.max_lifetime(Some(max_lifetime))
} else {
pool_builder
};
let pool_builder = if let Some(idle_timeout) = self.pool_idle_timeout {
pool_builder.idle_timeout(Some(idle_timeout))
} else {
pool_builder
};
let pool: r2d2::Pool<redis::Client> = pool_builder.build(client)?;
Ok(pool)
}
pub fn build(self) -> Result<RedisCache<K, V>, RedisCacheBuildError> {
let combined_prefix = format!(
"{}{}",
self.namespace.as_deref().unwrap_or(DEFAULT_NAMESPACE),
self.prefix
);
Ok(RedisCache {
seconds: self.seconds,
connection_string: self.connection_string()?,
pool: self.create_pool()?,
combined_prefix,
_phantom: PhantomData,
})
}
}
/// Synchronous Redis-backed cache.
///
/// Values are encoded with `rmp_serde` and stored under keys of the form
/// `<namespace><prefix><key>`. Create one via [`RedisCache::new`], which returns a
/// [`RedisCacheBuilder`].
pub struct RedisCache<K, V> {
    /// TTL in seconds used when writing entries; `None` writes without expiry.
    pub(super) seconds: Option<u64>,
    /// Namespace followed by prefix; prepended to every key.
    combined_prefix: String,
    /// Connection string resolved when the cache was built.
    connection_string: String,
    /// Pool of synchronous Redis connections.
    pool: r2d2::Pool<redis::Client>,
    /// Ties the cache to its key/value types without storing them.
    _phantom: PhantomData<(K, V)>,
}
impl<K, V> RedisCache<K, V>
where
    K: Display,
    V: Serialize + DeserializeOwned,
{
    /// Starts building a cache with the given key prefix and optional TTL in seconds.
    ///
    /// Finish with [`RedisCacheBuilder::build`].
    // Returns the builder rather than `Self` because building opens a connection pool
    // and is therefore fallible.
    #[allow(clippy::new_ret_no_self)]
    #[must_use]
    pub fn new(prefix: &str, seconds: Option<u64>) -> RedisCacheBuilder<K, V> {
        RedisCacheBuilder::<K, V>::new(prefix, seconds)
    }

    /// Returns the full Redis key for `key`: the combined namespace+prefix followed by
    /// the key's `Display` output.
    fn generate_key(&self, key: impl Display) -> String {
        let prefix = &self.combined_prefix;
        format!("{prefix}{key}")
    }

    /// Returns the connection string this cache was built with.
    #[must_use]
    pub fn connection_string(&self) -> &str {
        self.connection_string.as_str()
    }
}
/// Errors returned by Redis cache operations.
#[derive(Error, Debug)]
pub enum RedisCacheError {
    /// A Redis command or pipeline failed.
    #[error("redis error")]
    RedisCacheError(#[from] redis::RedisError),

    /// A connection could not be obtained from the r2d2 pool.
    #[error("redis pool error")]
    PoolError(#[from] r2d2::Error),

    /// Bytes read from Redis could not be decoded into the value type.
    #[error("Error deserializing cached value")]
    CacheDeserializationError(#[from] rmp_serde::decode::Error),

    /// A value could not be encoded before being written to Redis.
    #[error("Error serializing cached value")]
    CacheSerializationError(#[from] rmp_serde::encode::Error),
}
impl<K, V> IOKash<K, V> for RedisCache<K, V>
where
    K: Display,
    V: Serialize + DeserializeOwned,
{
    type Error = RedisCacheError;

    /// Fetches and decodes the value stored for `k`; `Ok(None)` when the key is absent.
    fn get(&self, k: &K) -> Result<Option<V>, RedisCacheError> {
        let mut conn = self.pool.get()?;
        let mut pipe = redis::pipe();
        let key = self.generate_key(k);
        pipe.get(&key);
        let res: (Option<Vec<u8>>,) = pipe.query(&mut *conn)?;
        check_and_get_result(res)
    }

    /// Writes `v` under `k` (with the configured TTL, if any) and returns the previous
    /// value.
    ///
    /// The previous value is read by a `GET` queued ahead of the write in the same
    /// pipeline; the write's own reply is ignored.
    fn set(&self, k: K, v: V) -> Result<Option<V>, RedisCacheError> {
        let mut conn = self.pool.get()?;
        let mut pipe = redis::pipe();
        let key = self.generate_key(&k);
        pipe.get(&key);
        let val = rmp_serde::to_vec(&v)?;
        set_val(self.seconds, &mut pipe, key, &val);
        let res: (Option<Vec<u8>>,) = pipe.query(&mut *conn)?;
        check_and_get_result(res)
    }

    /// Deletes `k` and returns the value it held, read by a `GET` queued before the `DEL`.
    fn remove(&self, k: &K) -> Result<Option<V>, RedisCacheError> {
        let mut conn = self.pool.get()?;
        let mut pipe = redis::pipe();
        let key = self.generate_key(k);
        pipe.get(&key);
        pipe.del(key).ignore();
        let res: (Option<Vec<u8>>,) = pipe.query(&mut *conn)?;
        check_and_get_result(res)
    }

    /// Deletes every key that starts with this cache's namespace+prefix.
    ///
    /// Keys are collected with `SCAN … MATCH <prefix>*` and then removed with a single `DEL`.
    // NOTE(review): the prefix is used verbatim in a MATCH pattern, so glob metacharacters
    // in a namespace/prefix (`*`, `?`, `[`) would widen the match — confirm prefixes never
    // contain them.
    fn clear(&self) -> Result<(), RedisCacheError> {
        use redis::Commands;
        let mut conn = self.pool.get()?;
        let keys = conn
            .scan_match::<_, String>(self.generate_key("*"))?
            .collect::<Vec<_>>();
        // `DEL` with no keys is rejected by Redis as a wrong-arity command, so clearing
        // an already-empty cache must not send it.
        if keys.is_empty() {
            return Ok(());
        }
        conn.del::<_, usize>(keys)?;
        Ok(())
    }

    /// Returns the TTL in seconds applied to new entries.
    fn ttl(&self) -> Option<u64> {
        self.seconds
    }

    /// Sets the TTL for subsequent writes and returns the previous setting.
    fn set_ttl(&mut self, seconds: u64) -> Option<u64> {
        let old = self.seconds;
        self.seconds = Some(seconds);
        old
    }
}
#[cfg(all(feature = "async", feature = "redis_tokio"))]
mod async_redis {
    use super::{
        DEFAULT_NAMESPACE, DeserializeOwned, Display, ENV_KEY, PhantomData, RedisCacheBuildError,
        RedisCacheError, Serialize, check_and_get_result, set_val,
    };
    use crate::IOKashAsync;

    /// Builder for [`AsyncRedisCache`].
    pub struct AsyncRedisCacheBuilder<K, V> {
        /// TTL in seconds for written entries; `None` writes without expiry.
        seconds: Option<u64>,
        /// Key namespace; `DEFAULT_NAMESPACE` is used when this is `None`.
        namespace: Option<String>,
        /// Per-cache key prefix, appended after the namespace.
        prefix: String,
        /// Explicit connection string; when `None` the `ENV_KEY` environment variable is read.
        connection_string: Option<String>,
        _phantom: PhantomData<(K, V)>,
    }

    impl<K, V> AsyncRedisCacheBuilder<K, V>
    where
        K: Display,
        V: Serialize + DeserializeOwned,
    {
        /// Creates a builder with the given key `prefix` and optional TTL in seconds.
        #[must_use]
        pub fn new(prefix: &str, seconds: Option<u64>) -> AsyncRedisCacheBuilder<K, V> {
            Self {
                seconds,
                namespace: None,
                prefix: prefix.to_string(),
                connection_string: None,
                _phantom: PhantomData,
            }
        }

        /// Sets the TTL in seconds, or clears it with `None`.
        #[must_use]
        pub fn set_ttl(mut self, seconds: Option<u64>) -> Self {
            self.seconds = seconds;
            self
        }

        /// Overrides the key namespace (otherwise `DEFAULT_NAMESPACE`).
        #[must_use]
        pub fn set_namespace(mut self, namespace: &str) -> Self {
            self.namespace = Some(namespace.to_string());
            self
        }

        /// Replaces the per-cache key prefix.
        #[must_use]
        pub fn set_prefix(mut self, prefix: &str) -> Self {
            self.prefix = prefix.to_string();
            self
        }

        /// Sets an explicit Redis connection string, bypassing the environment variable.
        #[must_use]
        pub fn set_connection_string(mut self, cs: &str) -> Self {
            self.connection_string = Some(cs.to_string());
            self
        }

        /// Returns the explicitly configured connection string, or else the value of the
        /// `KASH_REDIS_CONNECTION_STRING` environment variable.
        ///
        /// # Errors
        ///
        /// Returns [`RedisCacheBuildError::MissingConnectionString`] when no connection
        /// string was set and the environment variable cannot be read.
        pub fn connection_string(&self) -> Result<String, RedisCacheBuildError> {
            match self.connection_string {
                Some(ref s) => Ok(s.clone()),
                None => std::env::var(ENV_KEY).map_err(|error| {
                    RedisCacheBuildError::MissingConnectionString {
                        env_key: ENV_KEY.to_string(),
                        error,
                    }
                }),
            }
        }

        /// Opens a client for `connection_string` and a multiplexed async connection on it.
        #[cfg(not(feature = "redis_connection_manager"))]
        async fn create_multiplexed_connection(
            connection_string: &str,
        ) -> Result<redis::aio::MultiplexedConnection, RedisCacheBuildError> {
            let client = redis::Client::open(connection_string)?;
            let conn = client.get_multiplexed_async_connection().await?;
            Ok(conn)
        }

        /// Opens a client for `connection_string` and a `ConnectionManager` on it.
        #[cfg(feature = "redis_connection_manager")]
        async fn create_connection_manager(
            connection_string: &str,
        ) -> Result<redis::aio::ConnectionManager, RedisCacheBuildError> {
            let client = redis::Client::open(connection_string)?;
            let conn = redis::aio::ConnectionManager::new(client).await?;
            Ok(conn)
        }

        /// Builds the [`AsyncRedisCache`], opening its async connection.
        ///
        /// The connection string is resolved exactly once. The same value is used to connect
        /// and is stored on the cache, so the two cannot disagree.
        ///
        /// # Errors
        ///
        /// Returns [`RedisCacheBuildError`] if the connection string is missing or the
        /// client/connection cannot be opened.
        pub async fn build(self) -> Result<AsyncRedisCache<K, V>, RedisCacheBuildError> {
            let combined_prefix = format!(
                "{}{}",
                self.namespace.as_deref().unwrap_or(DEFAULT_NAMESPACE),
                self.prefix
            );
            let connection_string = self.connection_string()?;
            #[cfg(not(feature = "redis_connection_manager"))]
            let connection = Self::create_multiplexed_connection(&connection_string).await?;
            #[cfg(feature = "redis_connection_manager")]
            let connection = Self::create_connection_manager(&connection_string).await?;
            Ok(AsyncRedisCache {
                seconds: self.seconds,
                connection_string,
                connection,
                combined_prefix,
                _phantom: PhantomData,
            })
        }
    }

    /// Async Redis-backed cache.
    ///
    /// Values are encoded with `rmp_serde` and stored under keys of the form
    /// `<namespace><prefix><key>`.
    pub struct AsyncRedisCache<K, V> {
        /// TTL in seconds used when writing entries; `None` writes without expiry.
        pub(super) seconds: Option<u64>,
        /// Namespace followed by prefix; prepended to every key.
        combined_prefix: String,
        /// Connection string resolved when the cache was built.
        connection_string: String,
        /// Shared async connection; each operation works on a clone of it.
        #[cfg(not(feature = "redis_connection_manager"))]
        connection: redis::aio::MultiplexedConnection,
        /// Shared async connection; each operation works on a clone of it.
        #[cfg(feature = "redis_connection_manager")]
        connection: redis::aio::ConnectionManager,
        _phantom: PhantomData<(K, V)>,
    }

    impl<K, V> AsyncRedisCache<K, V>
    where
        K: Display + Send + Sync,
        V: Serialize + DeserializeOwned + Send + Sync,
    {
        /// Starts building a cache with the given key prefix and optional TTL in seconds.
        // Returns the builder rather than `Self` because building opens a connection and
        // is therefore fallible (and async).
        #[allow(clippy::new_ret_no_self)]
        #[must_use]
        pub fn new(prefix: &str, seconds: Option<u64>) -> AsyncRedisCacheBuilder<K, V> {
            AsyncRedisCacheBuilder::new(prefix, seconds)
        }

        /// Returns the full Redis key for `key` (combined prefix + `Display` output).
        fn generate_key(&self, key: impl Display) -> String {
            format!("{}{key}", self.combined_prefix)
        }

        /// Returns the connection string this cache was built with.
        #[must_use]
        pub fn connection_string(&self) -> &str {
            &self.connection_string
        }
    }

    #[async_trait::async_trait]
    impl<K, V> IOKashAsync<K, V> for AsyncRedisCache<K, V>
    where
        K: Display + Send + Sync,
        V: Serialize + DeserializeOwned + Send + Sync,
    {
        type Error = RedisCacheError;

        /// Fetches and decodes the value stored for `k`; `Ok(None)` when the key is absent.
        async fn get(&self, k: &K) -> Result<Option<V>, Self::Error> {
            let mut conn = self.connection.clone();
            let mut pipe = redis::pipe();
            let key = self.generate_key(k);
            pipe.get(&key);
            let res: (Option<Vec<u8>>,) = pipe.query_async(&mut conn).await?;
            check_and_get_result(res)
        }

        /// Writes `v` under `k` (with the configured TTL, if any) and returns the previous
        /// value, read by a `GET` queued ahead of the write in the same pipeline.
        async fn set(&self, k: K, v: V) -> Result<Option<V>, Self::Error> {
            let mut conn = self.connection.clone();
            let mut pipe = redis::pipe();
            let key = self.generate_key(&k);
            pipe.get(&key);
            let val = rmp_serde::to_vec(&v)?;
            set_val(self.seconds, &mut pipe, key, &val);
            let res: (Option<Vec<u8>>,) = pipe.query_async(&mut conn).await?;
            check_and_get_result(res)
        }

        /// Deletes `k` and returns the value it held, read by a `GET` queued before the `DEL`.
        async fn remove(&self, k: &K) -> Result<Option<V>, Self::Error> {
            let mut conn = self.connection.clone();
            let mut pipe = redis::pipe();
            let key = self.generate_key(k);
            pipe.get(&key);
            pipe.del(&key).ignore();
            let res: (Option<Vec<u8>>,) = pipe.query_async(&mut conn).await?;
            check_and_get_result(res)
        }

        /// Deletes every key that starts with this cache's namespace+prefix.
        ///
        /// Keys are collected with `SCAN … MATCH <prefix>*` and then removed with a single
        /// `DEL`.
        async fn clear(&self) -> Result<(), RedisCacheError> {
            use futures_util::StreamExt;
            use redis::AsyncCommands;
            let mut conn = self.connection.clone();
            let keys = conn
                .scan_match::<_, String>(self.generate_key("*"))
                .await?
                .collect::<Vec<_>>()
                .await;
            // `DEL` with no keys is rejected by Redis as a wrong-arity command, so clearing
            // an already-empty cache must not send it.
            if keys.is_empty() {
                return Ok(());
            }
            conn.del::<_, usize>(keys).await?;
            Ok(())
        }

        /// Returns the TTL in seconds applied to new entries.
        fn ttl(&self) -> Option<u64> {
            self.seconds
        }

        /// Sets the TTL for subsequent writes and returns the previous setting.
        fn set_ttl(&mut self, seconds: u64) -> Option<u64> {
            let old = self.seconds;
            self.seconds = Some(seconds);
            old
        }
    }

    #[allow(clippy::unwrap_used)]
    #[cfg(test)]
    mod tests {
        use super::*;
        use std::time::Duration;

        fn now_millis() -> u128 {
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_millis()
        }

        #[tokio::test]
        async fn test_async_redis_cache() {
            let mut c: AsyncRedisCache<u32, u32> =
                AsyncRedisCache::new(&format!("{}:async-redis-cache-test", now_millis()), Some(2))
                    .build()
                    .await
                    .unwrap();
            assert!(c.get(&1).await.unwrap().is_none());
            assert!(c.set(1, 100).await.unwrap().is_none());
            assert!(c.get(&1).await.unwrap().is_some());
            // Sleep past the 2s TTL without blocking the runtime thread.
            tokio::time::sleep(Duration::from_millis(2_500)).await;
            assert!(c.get(&1).await.unwrap().is_none());
            let old = c.set_ttl(1).unwrap();
            assert_eq!(2, old);
            assert!(c.set(1, 100).await.unwrap().is_none());
            assert!(c.get(&1).await.unwrap().is_some());
            tokio::time::sleep(Duration::from_millis(1_600)).await;
            assert!(c.get(&1).await.unwrap().is_none());
            c.set_ttl(10).unwrap();
            assert!(c.set(1, 100).await.unwrap().is_none());
            assert!(c.set(2, 100).await.unwrap().is_none());
            assert_eq!(c.get(&1).await.unwrap().unwrap(), 100);
            assert_eq!(c.get(&2).await.unwrap().unwrap(), 100);
        }

        #[tokio::test]
        async fn test_async_redis_cache_clear_empty() {
            let c: AsyncRedisCache<u32, u32> = AsyncRedisCache::new(
                &format!("{}:async-redis-cache-test-clear-empty", now_millis()),
                Some(60),
            )
            .build()
            .await
            .unwrap();
            // Nothing was written under this prefix, so there is nothing to delete.
            c.clear().await.unwrap();
        }
    }
}
/// Decodes the single-reply tuple returned by a cache pipeline.
///
/// Used on pipelines where the only non-ignored command is a `GET`. A missing key (`None`)
/// maps to `Ok(None)`; otherwise the bytes are decoded with `rmp_serde`.
///
/// # Errors
///
/// Returns [`RedisCacheError::CacheDeserializationError`] if the stored bytes cannot be
/// decoded as `V`.
fn check_and_get_result<V>(res: (Option<Vec<u8>>,)) -> Result<Option<V>, RedisCacheError>
where
    // Only decoding happens here, so `Serialize` is not required.
    V: DeserializeOwned,
{
    let (bytes,) = res;
    match bytes {
        None => Ok(None),
        Some(bytes) => Ok(Some(rmp_serde::from_slice(&bytes)?)),
    }
}
/// Queues a write of `val` under `key` on `pipe`, with the write's reply ignored.
///
/// With `Some(ttl)` the write is a `SETEX` with that many seconds; with `None` it is a
/// plain `SET` with no expiry argument.
fn set_val(seconds: Option<u64>, pipe: &mut Pipeline, key: String, val: &[u8]) {
    let queued = match seconds {
        Some(ttl) => pipe.set_ex(key, val, ttl),
        None => pipe.set(key, val),
    };
    queued.ignore();
}
#[cfg(all(feature = "async", feature = "redis_tokio"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "async", feature = "redis_tokio"))))]
pub use async_redis::{AsyncRedisCache, AsyncRedisCacheBuilder};
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    use std::thread::sleep;
    use std::time::Duration;

    use super::*;

    fn now_millis() -> u128 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis()
    }

    #[test]
    fn redis_cache() {
        let mut c: RedisCache<u32, u32> =
            RedisCache::new(&format!("{}:redis-cache-test", now_millis()), Some(2))
                .set_namespace("in-tests:")
                .build()
                .unwrap();
        assert!(c.get(&1).unwrap().is_none());
        assert!(c.set(1, 100).unwrap().is_none());
        assert!(c.get(&1).unwrap().is_some());
        // Sleep comfortably past the 2s TTL (2.5s, not 2s + 0.5ms).
        sleep(Duration::from_millis(2_500));
        assert!(c.get(&1).unwrap().is_none());
        let old = c.set_ttl(1).unwrap();
        assert_eq!(2, old);
        assert!(c.set(1, 100).unwrap().is_none());
        assert!(c.get(&1).unwrap().is_some());
        sleep(Duration::from_millis(1_600));
        assert!(c.get(&1).unwrap().is_none());
        c.set_ttl(10).unwrap();
        assert!(c.set(1, 100).unwrap().is_none());
        assert!(c.set(2, 100).unwrap().is_none());
        assert_eq!(c.get(&1).unwrap().unwrap(), 100);
        assert_eq!(c.get(&2).unwrap().unwrap(), 100);
    }

    #[test]
    fn remove() {
        let c: RedisCache<u32, u32> = RedisCache::new(
            &format!("{}:redis-cache-test-remove", now_millis()),
            Some(3600),
        )
        .build()
        .unwrap();
        assert!(c.set(1, 100).unwrap().is_none());
        assert!(c.set(2, 200).unwrap().is_none());
        assert!(c.set(3, 300).unwrap().is_none());
        assert_eq!(100, c.remove(&1).unwrap().unwrap());
        assert!(c.get(&1).unwrap().is_none());
        // Removing an absent key reports no previous value.
        assert!(c.remove(&1).unwrap().is_none());
        assert_eq!(200, c.get(&2).unwrap().unwrap());
    }

    #[test]
    fn clear() {
        let c: RedisCache<u32, u32> = RedisCache::new(
            &format!("{}:redis-cache-test-clear", now_millis()),
            Some(3600),
        )
        .build()
        .unwrap();
        assert!(c.set(1, 100).unwrap().is_none());
        assert!(c.set(2, 200).unwrap().is_none());
        assert!(c.set(3, 300).unwrap().is_none());
        c.clear().unwrap();
        assert!(c.get(&1).unwrap().is_none());
        assert!(c.get(&2).unwrap().is_none());
        assert!(c.get(&3).unwrap().is_none());
    }

    #[test]
    fn clear_empty() {
        let c: RedisCache<u32, u32> = RedisCache::new(
            &format!("{}:redis-cache-test-clear-empty", now_millis()),
            Some(3600),
        )
        .build()
        .unwrap();
        // Nothing was written under this prefix, so there is nothing to delete.
        c.clear().unwrap();
        // Clearing twice in a row must also succeed.
        c.clear().unwrap();
    }
}