use crate::IOCached;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::fmt::Display;
use std::marker::PhantomData;
/// Builder that collects the settings for a synchronous [`RedisCache`].
///
/// Obtain one through `RedisCacheBuilder::new` (or `RedisCache::new`), adjust it with the
/// `set_*` methods and finish with `build`.
pub struct RedisCacheBuilder<K, V> {
/// Lifespan, in seconds, handed to the built cache.
seconds: u64,
/// Whether the built cache re-applies the lifespan to an entry when it is read.
refresh: bool,
/// Leading part of every generated Redis key; starts out as `DEFAULT_NAMESPACE`.
namespace: String,
/// Second part of every generated Redis key, placed between the namespace and the key.
prefix: String,
/// Explicit connection string; when `None` the `CACHED_REDIS_CONNECTION_STRING`
/// environment variable is read instead.
connection_string: Option<String>,
/// Optional `max_size` forwarded to the r2d2 pool builder.
pool_max_size: Option<u32>,
/// Optional `min_idle` forwarded to the r2d2 pool builder.
pool_min_idle: Option<u32>,
/// Optional `max_lifetime` forwarded to the r2d2 pool builder.
pool_max_lifetime: Option<std::time::Duration>,
/// Optional `idle_timeout` forwarded to the r2d2 pool builder.
pool_idle_timeout: Option<std::time::Duration>,
/// Carries the key and value types without storing either.
_phantom: PhantomData<(K, V)>,
}
/// Environment variable consulted for the connection string when none was set on the builder.
const ENV_KEY: &str = "CACHED_REDIS_CONNECTION_STRING";
/// Namespace assigned to freshly created builders; it is the leading part of generated keys.
const DEFAULT_NAMESPACE: &str = "cached-redis-store:";
use thiserror::Error;
/// Errors that can occur while turning a builder into a cache.
#[derive(Error, Debug)]
pub enum RedisCacheBuildError {
/// The `redis` crate reported an error (converted automatically via `From`).
#[error("redis connection error")]
Connection(#[from] redis::RedisError),
/// The r2d2 pool reported an error (converted automatically via `From`).
#[error("redis pool error")]
Pool(#[from] r2d2::Error),
/// No connection string was set on the builder and reading the environment variable failed.
#[error("Connection string not specified or invalid in env var {env_key:?}: {error:?}")]
MissingConnectionString {
/// Name of the environment variable that was read.
env_key: String,
/// The error returned by `std::env::var`.
error: std::env::VarError,
},
}
impl<K, V> RedisCacheBuilder<K, V>
where
K: Display,
V: Serialize + DeserializeOwned,
{
pub fn new<S: AsRef<str>>(prefix: S, seconds: u64) -> RedisCacheBuilder<K, V> {
Self {
seconds,
refresh: false,
namespace: DEFAULT_NAMESPACE.to_string(),
prefix: prefix.as_ref().to_string(),
connection_string: None,
pool_max_size: None,
pool_min_idle: None,
pool_max_lifetime: None,
pool_idle_timeout: None,
_phantom: PhantomData,
}
}
#[must_use]
pub fn set_lifespan(mut self, seconds: u64) -> Self {
self.seconds = seconds;
self
}
#[must_use]
pub fn set_refresh(mut self, refresh: bool) -> Self {
self.refresh = refresh;
self
}
#[must_use]
pub fn set_namespace<S: AsRef<str>>(mut self, namespace: S) -> Self {
self.namespace = namespace.as_ref().to_string();
self
}
#[must_use]
pub fn set_prefix<S: AsRef<str>>(mut self, prefix: S) -> Self {
self.prefix = prefix.as_ref().to_string();
self
}
#[must_use]
pub fn set_connection_string(mut self, cs: &str) -> Self {
self.connection_string = Some(cs.to_string());
self
}
#[must_use]
pub fn set_connection_pool_max_size(mut self, max_size: u32) -> Self {
self.pool_max_size = Some(max_size);
self
}
#[must_use]
pub fn set_connection_pool_min_idle(mut self, min_idle: u32) -> Self {
self.pool_min_idle = Some(min_idle);
self
}
#[must_use]
pub fn set_connection_pool_max_lifetime(mut self, max_lifetime: std::time::Duration) -> Self {
self.pool_max_lifetime = Some(max_lifetime);
self
}
#[must_use]
pub fn set_connection_pool_idle_timeout(mut self, idle_timeout: std::time::Duration) -> Self {
self.pool_idle_timeout = Some(idle_timeout);
self
}
pub fn connection_string(&self) -> Result<String, RedisCacheBuildError> {
match self.connection_string {
Some(ref s) => Ok(s.to_string()),
None => {
std::env::var(ENV_KEY).map_err(|e| RedisCacheBuildError::MissingConnectionString {
env_key: ENV_KEY.to_string(),
error: e,
})
}
}
}
fn create_pool(&self) -> Result<r2d2::Pool<redis::Client>, RedisCacheBuildError> {
let s = self.connection_string()?;
let client: redis::Client = redis::Client::open(s)?;
let pool_builder = r2d2::Pool::builder();
let pool_builder = if let Some(max_size) = self.pool_max_size {
pool_builder.max_size(max_size)
} else {
pool_builder
};
let pool_builder = if let Some(min_idle) = self.pool_min_idle {
pool_builder.min_idle(Some(min_idle))
} else {
pool_builder
};
let pool_builder = if let Some(max_lifetime) = self.pool_max_lifetime {
pool_builder.max_lifetime(Some(max_lifetime))
} else {
pool_builder
};
let pool_builder = if let Some(idle_timeout) = self.pool_idle_timeout {
pool_builder.idle_timeout(Some(idle_timeout))
} else {
pool_builder
};
let pool: r2d2::Pool<redis::Client> = pool_builder.build(client)?;
Ok(pool)
}
pub fn build(self) -> Result<RedisCache<K, V>, RedisCacheBuildError> {
Ok(RedisCache {
seconds: self.seconds,
refresh: self.refresh,
connection_string: self.connection_string()?,
pool: self.create_pool()?,
namespace: self.namespace,
prefix: self.prefix,
_phantom: PhantomData,
})
}
}
/// Cache store backed by Redis, reached through a pool of synchronous connections.
///
/// Values are stored as JSON strings under keys built from namespace, prefix and key.
pub struct RedisCache<K, V> {
/// Lifespan, in seconds, applied to entries when they are written.
pub(super) seconds: u64,
/// When `true`, reading an entry also re-applies the lifespan to it.
pub(super) refresh: bool,
/// Leading part of every generated Redis key.
pub(super) namespace: String,
/// Part of every generated Redis key that follows the namespace.
pub(super) prefix: String,
/// Connection string that was resolved when the cache was built.
connection_string: String,
/// Pool the cache operations draw their connections from.
pool: r2d2::Pool<redis::Client>,
/// Carries the key and value types without storing either.
_phantom: PhantomData<(K, V)>,
}
impl<K, V> RedisCache<K, V>
where
    K: Display,
    V: Serialize + DeserializeOwned,
{
    /// Returns a [`RedisCacheBuilder`] preset with `prefix` and a lifespan of `seconds`.
    ///
    /// Deliberately yields the builder rather than `Self`; call `build` on it to obtain
    /// the cache.
    #[allow(clippy::new_ret_no_self)]
    pub fn new<S: AsRef<str>>(prefix: S, seconds: u64) -> RedisCacheBuilder<K, V> {
        RedisCacheBuilder::<K, V>::new(prefix, seconds)
    }

    /// Builds the Redis key for `key`: namespace, then prefix, then the key's `Display` text.
    fn generate_key(&self, key: &K) -> String {
        let rendered = key.to_string();
        [
            self.namespace.as_str(),
            self.prefix.as_str(),
            rendered.as_str(),
        ]
        .concat()
    }

    /// Returns a copy of the connection string this cache was built with.
    #[must_use]
    pub fn connection_string(&self) -> String {
        String::from(self.connection_string.as_str())
    }
}
/// Errors returned by the cache operations of the Redis-backed stores.
#[derive(Error, Debug)]
pub enum RedisCacheError {
/// The `redis` crate reported an error (converted automatically via `From`).
#[error("redis error")]
RedisCacheError(#[from] redis::RedisError),
/// The r2d2 pool reported an error (converted automatically via `From`).
#[error("redis pool error")]
PoolError(#[from] r2d2::Error),
/// A string read from Redis could not be parsed as a cached value.
#[error("Error deserializing cached value {cached_value:?}: {error:?}")]
CacheDeserializationError {
/// The raw string that failed to parse.
cached_value: String,
/// The underlying `serde_json` error.
error: serde_json::Error,
},
/// A value could not be serialized to JSON before being written.
#[error("Error serializing cached value: {error:?}")]
CacheSerializationError { error: serde_json::Error },
}
/// Envelope that is serialized to JSON and stored in Redis for each cached value.
#[derive(serde::Serialize, serde::Deserialize)]
struct CachedRedisValue<V> {
/// The cached value itself.
pub(crate) value: V,
/// Version tag of the envelope; `CachedRedisValue::new` writes `Some(1)`.
pub(crate) version: Option<u64>,
}
impl<V> CachedRedisValue<V> {
fn new(value: V) -> Self {
Self {
value,
version: Some(1),
}
}
}
impl<K, V> IOCached<K, V> for RedisCache<K, V>
where
K: Display,
V: Serialize + DeserializeOwned,
{
type Error = RedisCacheError;
fn cache_get(&self, key: &K) -> Result<Option<V>, RedisCacheError> {
let mut conn = self.pool.get()?;
let mut pipe = redis::pipe();
let key = self.generate_key(key);
pipe.get(key.clone());
if self.refresh {
pipe.expire(key, self.seconds as i64).ignore();
}
let res: (Option<String>,) = pipe.query(&mut *conn)?;
match res.0 {
None => Ok(None),
Some(s) => {
let v: CachedRedisValue<V> = serde_json::from_str(&s).map_err(|e| {
RedisCacheError::CacheDeserializationError {
cached_value: s,
error: e,
}
})?;
Ok(Some(v.value))
}
}
}
fn cache_set(&self, key: K, val: V) -> Result<Option<V>, RedisCacheError> {
let mut conn = self.pool.get()?;
let mut pipe = redis::pipe();
let key = self.generate_key(&key);
let val = CachedRedisValue::new(val);
pipe.get(key.clone());
pipe.set_ex::<String, String>(
key,
serde_json::to_string(&val)
.map_err(|e| RedisCacheError::CacheSerializationError { error: e })?,
self.seconds,
)
.ignore();
let res: (Option<String>,) = pipe.query(&mut *conn)?;
match res.0 {
None => Ok(None),
Some(s) => {
let v: CachedRedisValue<V> = serde_json::from_str(&s).map_err(|e| {
RedisCacheError::CacheDeserializationError {
cached_value: s,
error: e,
}
})?;
Ok(Some(v.value))
}
}
}
fn cache_remove(&self, key: &K) -> Result<Option<V>, RedisCacheError> {
let mut conn = self.pool.get()?;
let mut pipe = redis::pipe();
let key = self.generate_key(key);
pipe.get(key.clone());
pipe.del::<String>(key).ignore();
let res: (Option<String>,) = pipe.query(&mut *conn)?;
match res.0 {
None => Ok(None),
Some(s) => {
let v: CachedRedisValue<V> = serde_json::from_str(&s).map_err(|e| {
RedisCacheError::CacheDeserializationError {
cached_value: s,
error: e,
}
})?;
Ok(Some(v.value))
}
}
}
fn cache_lifespan(&self) -> Option<u64> {
Some(self.seconds)
}
fn cache_set_lifespan(&mut self, seconds: u64) -> Option<u64> {
let old = self.seconds;
self.seconds = seconds;
Some(old)
}
fn cache_set_refresh(&mut self, refresh: bool) -> bool {
let old = self.refresh;
self.refresh = refresh;
old
}
}
#[cfg(all(
feature = "async",
any(feature = "redis_async_std", feature = "redis_tokio")
))]
mod async_redis {
use super::{
CachedRedisValue, DeserializeOwned, Display, PhantomData, RedisCacheBuildError,
RedisCacheError, Serialize, DEFAULT_NAMESPACE, ENV_KEY,
};
use {crate::IOCachedAsync, async_trait::async_trait};
/// Builder that collects the settings for an `AsyncRedisCache`.
pub struct AsyncRedisCacheBuilder<K, V> {
/// Lifespan, in seconds, handed to the built cache.
seconds: u64,
/// Whether the built cache re-applies the lifespan to an entry when it is read.
refresh: bool,
/// Leading part of every generated Redis key; starts out as `DEFAULT_NAMESPACE`.
namespace: String,
/// Second part of every generated Redis key, placed between the namespace and the key.
prefix: String,
/// Explicit connection string; when `None` the `CACHED_REDIS_CONNECTION_STRING`
/// environment variable is read instead.
connection_string: Option<String>,
/// Carries the key and value types without storing either.
_phantom: PhantomData<(K, V)>,
}
impl<K, V> AsyncRedisCacheBuilder<K, V>
where
K: Display,
V: Serialize + DeserializeOwned,
{
pub fn new<S: AsRef<str>>(prefix: S, seconds: u64) -> AsyncRedisCacheBuilder<K, V> {
Self {
seconds,
refresh: false,
namespace: DEFAULT_NAMESPACE.to_string(),
prefix: prefix.as_ref().to_string(),
connection_string: None,
_phantom: PhantomData,
}
}
#[must_use]
pub fn set_lifespan(mut self, seconds: u64) -> Self {
self.seconds = seconds;
self
}
#[must_use]
pub fn set_refresh(mut self, refresh: bool) -> Self {
self.refresh = refresh;
self
}
#[must_use]
pub fn set_namespace<S: AsRef<str>>(mut self, namespace: S) -> Self {
self.namespace = namespace.as_ref().to_string();
self
}
#[must_use]
pub fn set_prefix<S: AsRef<str>>(mut self, prefix: S) -> Self {
self.prefix = prefix.as_ref().to_string();
self
}
#[must_use]
pub fn set_connection_string(mut self, cs: &str) -> Self {
self.connection_string = Some(cs.to_string());
self
}
pub fn connection_string(&self) -> Result<String, RedisCacheBuildError> {
match self.connection_string {
Some(ref s) => Ok(s.to_string()),
None => std::env::var(ENV_KEY).map_err(|e| {
RedisCacheBuildError::MissingConnectionString {
env_key: ENV_KEY.to_string(),
error: e,
}
}),
}
}
#[cfg(not(feature = "redis_connection_manager"))]
async fn create_multiplexed_connection(
&self,
) -> Result<redis::aio::MultiplexedConnection, RedisCacheBuildError> {
let s = self.connection_string()?;
let client = redis::Client::open(s)?;
let conn = client.get_multiplexed_async_connection().await?;
Ok(conn)
}
#[cfg(feature = "redis_connection_manager")]
async fn create_connection_manager(
&self,
) -> Result<redis::aio::ConnectionManager, RedisCacheBuildError> {
let s = self.connection_string()?;
let client = redis::Client::open(s)?;
let conn = redis::aio::ConnectionManager::new(client).await?;
Ok(conn)
}
pub async fn build(self) -> Result<AsyncRedisCache<K, V>, RedisCacheBuildError> {
Ok(AsyncRedisCache {
seconds: self.seconds,
refresh: self.refresh,
connection_string: self.connection_string()?,
#[cfg(not(feature = "redis_connection_manager"))]
connection: self.create_multiplexed_connection().await?,
#[cfg(feature = "redis_connection_manager")]
connection: self.create_connection_manager().await?,
namespace: self.namespace,
prefix: self.prefix,
_phantom: PhantomData,
})
}
}
/// Cache store backed by Redis, reached through an async connection.
///
/// Values are stored as JSON strings under keys built from namespace, prefix and key.
pub struct AsyncRedisCache<K, V> {
/// Lifespan, in seconds, applied to entries when they are written.
pub(super) seconds: u64,
/// When `true`, reading an entry also re-applies the lifespan to it.
pub(super) refresh: bool,
/// Leading part of every generated Redis key.
pub(super) namespace: String,
/// Part of every generated Redis key that follows the namespace.
pub(super) prefix: String,
/// Connection string that was resolved when the cache was built.
connection_string: String,
/// Connection used when the `redis_connection_manager` feature is off; cloned per operation.
#[cfg(not(feature = "redis_connection_manager"))]
connection: redis::aio::MultiplexedConnection,
/// Connection used when the `redis_connection_manager` feature is on; cloned per operation.
#[cfg(feature = "redis_connection_manager")]
connection: redis::aio::ConnectionManager,
/// Carries the key and value types without storing either.
_phantom: PhantomData<(K, V)>,
}
impl<K, V> AsyncRedisCache<K, V>
where
    K: Display + Send + Sync,
    V: Serialize + DeserializeOwned + Send + Sync,
{
    /// Returns an `AsyncRedisCacheBuilder` preset with `prefix` and a lifespan of `seconds`.
    ///
    /// Deliberately yields the builder rather than `Self`; call `build` on it to obtain
    /// the cache.
    #[allow(clippy::new_ret_no_self)]
    pub fn new<S: AsRef<str>>(prefix: S, seconds: u64) -> AsyncRedisCacheBuilder<K, V> {
        AsyncRedisCacheBuilder::<K, V>::new(prefix, seconds)
    }

    /// Builds the Redis key for `key`: namespace, then prefix, then the key's `Display` text.
    fn generate_key(&self, key: &K) -> String {
        let rendered = key.to_string();
        [
            self.namespace.as_str(),
            self.prefix.as_str(),
            rendered.as_str(),
        ]
        .concat()
    }

    /// Returns a copy of the connection string this cache was built with.
    #[must_use]
    pub fn connection_string(&self) -> String {
        String::from(self.connection_string.as_str())
    }
}
#[async_trait]
impl<K, V> IOCachedAsync<K, V> for AsyncRedisCache<K, V>
where
K: Display + Send + Sync,
V: Serialize + DeserializeOwned + Send + Sync,
{
type Error = RedisCacheError;
async fn cache_get(&self, key: &K) -> Result<Option<V>, Self::Error> {
let mut conn = self.connection.clone();
let mut pipe = redis::pipe();
let key = self.generate_key(key);
pipe.get(key.clone());
if self.refresh {
pipe.expire(key, self.seconds as i64).ignore();
}
let res: (Option<String>,) = pipe.query_async(&mut conn).await?;
match res.0 {
None => Ok(None),
Some(s) => {
let v: CachedRedisValue<V> = serde_json::from_str(&s).map_err(|e| {
RedisCacheError::CacheDeserializationError {
cached_value: s,
error: e,
}
})?;
Ok(Some(v.value))
}
}
}
async fn cache_set(&self, key: K, val: V) -> Result<Option<V>, Self::Error> {
let mut conn = self.connection.clone();
let mut pipe = redis::pipe();
let key = self.generate_key(&key);
let val = CachedRedisValue::new(val);
pipe.get(key.clone());
pipe.set_ex::<String, String>(
key,
serde_json::to_string(&val)
.map_err(|e| RedisCacheError::CacheSerializationError { error: e })?,
self.seconds,
)
.ignore();
let res: (Option<String>,) = pipe.query_async(&mut conn).await?;
match res.0 {
None => Ok(None),
Some(s) => {
let v: CachedRedisValue<V> = serde_json::from_str(&s).map_err(|e| {
RedisCacheError::CacheDeserializationError {
cached_value: s,
error: e,
}
})?;
Ok(Some(v.value))
}
}
}
async fn cache_remove(&self, key: &K) -> Result<Option<V>, Self::Error> {
let mut conn = self.connection.clone();
let mut pipe = redis::pipe();
let key = self.generate_key(key);
pipe.get(key.clone());
pipe.del::<String>(key).ignore();
let res: (Option<String>,) = pipe.query_async(&mut conn).await?;
match res.0 {
None => Ok(None),
Some(s) => {
let v: CachedRedisValue<V> = serde_json::from_str(&s).map_err(|e| {
RedisCacheError::CacheDeserializationError {
cached_value: s,
error: e,
}
})?;
Ok(Some(v.value))
}
}
}
fn cache_set_refresh(&mut self, refresh: bool) -> bool {
let old = self.refresh;
self.refresh = refresh;
old
}
fn cache_lifespan(&self) -> Option<u64> {
Some(self.seconds)
}
fn cache_set_lifespan(&mut self, seconds: u64) -> Option<u64> {
let old = self.seconds;
self.seconds = seconds;
Some(old)
}
}
// Integration test for the async store. No connection string is set on the builder, so the
// cache falls back to the `CACHED_REDIS_CONNECTION_STRING` environment variable.
#[cfg(test)]
mod tests {
use super::*;
use std::thread::sleep;
use std::time::Duration;
// Milliseconds since the Unix epoch; used to give each test run its own key prefix.
fn now_millis() -> u128 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis()
}
#[tokio::test]
async fn test_async_redis_cache() {
// Start with a 2-second lifespan and a prefix unique to this run.
let mut c: AsyncRedisCache<u32, u32> =
AsyncRedisCache::new(format!("{}:async-redis-cache-test", now_millis()), 2)
.build()
.await
.unwrap();
// Miss, then set (no previous value), then hit.
assert!(c.cache_get(&1).await.unwrap().is_none());
assert!(c.cache_set(1, 100).await.unwrap().is_none());
assert!(c.cache_get(&1).await.unwrap().is_some());
// Wait just past the 2-second lifespan (2 s + 500_000 ns); the entry must be gone.
// NOTE(review): `std::thread::sleep` blocks the thread running this async test —
// confirm that is acceptable for the runtime in use.
sleep(Duration::new(2, 500_000));
assert!(c.cache_get(&1).await.unwrap().is_none());
// Shorten the lifespan to 1 second; the previous lifespan is returned.
let old = c.cache_set_lifespan(1).unwrap();
assert_eq!(2, old);
assert!(c.cache_set(1, 100).await.unwrap().is_none());
assert!(c.cache_get(&1).await.unwrap().is_some());
// Wait past the 1-second lifespan (1 s + 600_000 ns); the entry must be gone.
sleep(Duration::new(1, 600_000));
assert!(c.cache_get(&1).await.unwrap().is_none());
// With a 10-second lifespan, repeated reads keep returning the stored value.
c.cache_set_lifespan(10).unwrap();
assert!(c.cache_set(1, 100).await.unwrap().is_none());
assert!(c.cache_set(2, 100).await.unwrap().is_none());
assert_eq!(c.cache_get(&1).await.unwrap().unwrap(), 100);
assert_eq!(c.cache_get(&1).await.unwrap().unwrap(), 100);
}
}
}
#[cfg(all(
feature = "async",
any(feature = "redis_async_std", feature = "redis_tokio")
))]
#[cfg_attr(
docsrs,
doc(cfg(all(
feature = "async",
any(feature = "redis_async_std", feature = "redis_tokio")
)))
)]
pub use async_redis::{AsyncRedisCache, AsyncRedisCacheBuilder};
// Integration tests for the synchronous store. No connection string is set on the builders,
// so the caches fall back to the `CACHED_REDIS_CONNECTION_STRING` environment variable.
#[cfg(test)]
mod tests {
use std::thread::sleep;
use std::time::Duration;
use super::*;
// Milliseconds since the Unix epoch; used to give each test run its own key prefix.
fn now_millis() -> u128 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis()
}
#[test]
fn redis_cache() {
// Start with a 2-second lifespan, a prefix unique to this run and a custom namespace.
let mut c: RedisCache<u32, u32> =
RedisCache::new(format!("{}:redis-cache-test", now_millis()), 2)
.set_namespace("in-tests:")
.build()
.unwrap();
// Miss, then set (no previous value), then hit.
assert!(c.cache_get(&1).unwrap().is_none());
assert!(c.cache_set(1, 100).unwrap().is_none());
assert!(c.cache_get(&1).unwrap().is_some());
// Wait just past the 2-second lifespan (2 s + 500_000 ns); the entry must be gone.
sleep(Duration::new(2, 500_000));
assert!(c.cache_get(&1).unwrap().is_none());
// Shorten the lifespan to 1 second; the previous lifespan is returned.
let old = c.cache_set_lifespan(1).unwrap();
assert_eq!(2, old);
assert!(c.cache_set(1, 100).unwrap().is_none());
assert!(c.cache_get(&1).unwrap().is_some());
// Wait past the 1-second lifespan (1 s + 600_000 ns); the entry must be gone.
sleep(Duration::new(1, 600_000));
assert!(c.cache_get(&1).unwrap().is_none());
// With a 10-second lifespan, repeated reads keep returning the stored value.
c.cache_set_lifespan(10).unwrap();
assert!(c.cache_set(1, 100).unwrap().is_none());
assert!(c.cache_set(2, 100).unwrap().is_none());
assert_eq!(c.cache_get(&1).unwrap().unwrap(), 100);
assert_eq!(c.cache_get(&1).unwrap().unwrap(), 100);
}
#[test]
fn remove() {
// A long lifespan so entries cannot expire while the test runs.
let c: RedisCache<u32, u32> =
RedisCache::new(format!("{}:redis-cache-test-remove", now_millis()), 3600)
.build()
.unwrap();
assert!(c.cache_set(1, 100).unwrap().is_none());
assert!(c.cache_set(2, 200).unwrap().is_none());
assert!(c.cache_set(3, 300).unwrap().is_none());
// Removing a key hands back the value that was stored under it.
assert_eq!(100, c.cache_remove(&1).unwrap().unwrap());
}
}