use std::{
sync::{Arc, Mutex},
time::Duration,
};
use async_trait::async_trait;
use redis::{aio::MultiplexedConnection, cluster_async::ClusterConnection};
use tokio::time::timeout;
use crate::{
cache::{CacheKey, CacheResult, CacheStore},
cache_redis::{
RedisCacheConfig, RedisCacheError, RedisCacheResult, RedisClientFactory,
RedisCommandRecorder, RedisDeleteClient, RedisDeleteRetryQueue, RedisDeleteRetryStats,
RedisOperation, backend::RedisCacheBackend, classify_redis_error, run_redis_cache_command,
},
resil::SharedCircuitBreaker,
};
/// Redis-backed cache store supporting both single-node and cluster backends.
#[derive(Clone)]
pub struct RedisCacheStore {
    /// Validated configuration: namespace, timeouts, breaker and delete-retry settings.
    pub(super) config: RedisCacheConfig,
    /// Active backend: a single-node client or a cluster client.
    backend: RedisCacheBackend,
    /// Records per-command events and outcomes (optionally metrics-backed).
    pub(super) recorder: RedisCommandRecorder,
    /// Background retry queue for failed deletes; only spawned on the
    /// single-node construction path (`from_client`).
    pub(super) delete_retry: Option<RedisDeleteRetryQueue>,
    /// Circuit breaker; present only when `config.breaker.enabled` is set.
    pub(super) breaker: Option<SharedCircuitBreaker>,
    // Not read anywhere in this file — presumably reserved for cluster slot
    // tracking (TODO confirm intended use); the `allow` silences dead-code.
    #[allow(dead_code)]
    cluster_slots: Arc<Mutex<crate::cache_redis::RedisClusterSlotMap>>,
}
impl std::fmt::Debug for RedisCacheStore {
    // Hand-rolled so optional components are reported as presence booleans
    // instead of dumping their internals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RedisCacheStore");
        out.field("config", &self.config);
        out.field("backend", &self.backend);
        out.field("delete_retry_enabled", &self.delete_retry.is_some());
        out.field("breaker_enabled", &self.breaker.is_some());
        out.finish_non_exhaustive()
    }
}
impl RedisCacheStore {
    /// Builds a store from `config`, selecting the cluster backend when
    /// `config.cluster.enabled` is set and a single-node backend otherwise.
    ///
    /// # Errors
    /// Propagates client-factory construction/validation failures.
    pub fn new(config: RedisCacheConfig) -> RedisCacheResult<Self> {
        let factory = RedisClientFactory::new(config.clone())?;
        if config.cluster.enabled {
            Self::from_cluster_client(config, factory.create_cluster_client()?)
        } else {
            Self::from_client(config, factory.create_client()?)
        }
    }

    /// Builds a single-node store around an externally constructed client.
    ///
    /// # Errors
    /// Returns an error when `config.validate()` fails.
    pub fn with_client(config: RedisCacheConfig, client: redis::Client) -> RedisCacheResult<Self> {
        config.validate()?;
        Self::from_client(config, client)
    }

    /// Builds a cluster store around an externally constructed cluster client.
    ///
    /// # Errors
    /// Returns an error when `config.validate()` fails.
    pub fn with_cluster_client(
        config: RedisCacheConfig,
        client: redis::cluster::ClusterClient,
    ) -> RedisCacheResult<Self> {
        config.validate()?;
        Self::from_cluster_client(config, client)
    }

    /// Single-node construction path; also spawns the background delete-retry
    /// queue with its own client handle and the configured timeouts.
    fn from_client(config: RedisCacheConfig, client: redis::Client) -> RedisCacheResult<Self> {
        let delete_retry = RedisDeleteRetryQueue::spawn(
            config.delete_retry,
            RedisDeleteClient::new(
                client.clone(),
                config.connect_timeout,
                config.command_timeout,
            ),
        );
        Ok(Self::from_backend(
            config,
            RedisCacheBackend::Single(client),
            delete_retry,
        ))
    }

    /// Cluster construction path. No delete-retry queue is spawned here —
    /// only the single-node path wires one up.
    fn from_cluster_client(
        config: RedisCacheConfig,
        client: redis::cluster::ClusterClient,
    ) -> RedisCacheResult<Self> {
        Ok(Self::from_backend(
            config,
            RedisCacheBackend::Cluster(client),
            None,
        ))
    }

    /// Common tail of every constructor: builds the command recorder, the
    /// optional circuit breaker, and an empty cluster slot map.
    fn from_backend(
        config: RedisCacheConfig,
        backend: RedisCacheBackend,
        delete_retry: Option<RedisDeleteRetryQueue>,
    ) -> Self {
        // Only materialize the breaker when it is enabled in config.
        let breaker = config.breaker.enabled.then(|| {
            SharedCircuitBreaker::with_policy(
                config.breaker.breaker.clone(),
                config.breaker.policy.clone(),
            )
        });
        Self {
            config,
            backend,
            recorder: RedisCommandRecorder::new(),
            delete_retry,
            breaker,
            cluster_slots: Arc::new(Mutex::new(crate::cache_redis::RedisClusterSlotMap::new())),
        }
    }

    /// Attaches a metrics registry to the command recorder (builder style).
    #[cfg(feature = "observability")]
    pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
        // `self` is owned here, so the recorder can be moved out of the field
        // and replaced directly — no clone needed.
        self.recorder = self.recorder.with_metrics(metrics);
        self
    }

    /// Tags recorded command events with a shard name (builder style).
    #[cfg(feature = "observability")]
    pub fn with_shard_name(mut self, shard: impl Into<String>) -> Self {
        // Same as `with_metrics`: move out and replace, no clone needed.
        self.recorder = self.recorder.with_shard(shard);
        self
    }

    /// Namespace that this store's cache keys are scoped under.
    pub fn namespace(&self) -> &str {
        &self.config.namespace
    }

    /// Current delete-retry queue statistics, or `None` when no queue was
    /// spawned (retry disabled or cluster backend).
    pub fn delete_retry_stats(&self) -> Option<RedisDeleteRetryStats> {
        self.delete_retry.as_ref().map(RedisDeleteRetryQueue::stats)
    }

    /// Round-trips a `PING` through the active backend.
    ///
    /// # Errors
    /// Propagates connection/command failures, and returns
    /// `RedisCacheError::Backend` if the server answers anything but `PONG`.
    pub async fn health_check(&self) -> RedisCacheResult<()> {
        let pong: String = match &self.backend {
            RedisCacheBackend::Single(client) => {
                let mut connection = self.single_connection(client).await?;
                self.run_command("PING", async {
                    redis::cmd("PING").query_async(&mut connection).await
                })
                .await?
            }
            RedisCacheBackend::Cluster(client) => {
                let mut connection = self.cluster_connection(client).await?;
                self.run_command("PING", async {
                    redis::cmd("PING").query_async(&mut connection).await
                })
                .await?
            }
        };
        if pong == "PONG" {
            Ok(())
        } else {
            Err(RedisCacheError::Backend(format!(
                "unexpected PING response: {pong}"
            )))
        }
    }

    /// Opens a multiplexed connection to a single-node backend, bounded by
    /// `connect_timeout` and gated by the circuit breaker.
    pub(super) async fn single_connection(
        &self,
        client: &redis::Client,
    ) -> RedisCacheResult<MultiplexedConnection> {
        // Ask the breaker for permission before attempting the connect.
        let guard = self.allow_breaker("CONNECT").await?;
        let result = timeout(
            self.config.connect_timeout,
            client.get_multiplexed_async_connection(),
        )
        .await
        .map_err(|_| RedisCacheError::Timeout("connect".to_string()))?
        .map_err(|error| RedisCacheError::Connection(error.to_string()));
        // Record the pool event first, then feed the outcome back through
        // the breaker guard.
        self.record_pool_outcome(&result);
        self.record_breaker_outcome(guard, &result).await;
        result
    }

    /// Cluster analogue of [`Self::single_connection`].
    pub(super) async fn cluster_connection(
        &self,
        client: &redis::cluster::ClusterClient,
    ) -> RedisCacheResult<ClusterConnection> {
        let guard = self.allow_breaker("CLUSTER CONNECT").await?;
        let result = timeout(self.config.connect_timeout, client.get_async_connection())
            .await
            .map_err(|_| RedisCacheError::Timeout("cluster connect".to_string()))?
            .map_err(|error| RedisCacheError::Connection(error.to_string()));
        self.record_pool_outcome(&result);
        self.record_breaker_outcome(guard, &result).await;
        result
    }

    /// Emits a `Pool` command event classifying `result` as success or a
    /// specific Redis failure kind.
    fn record_pool_outcome<T>(&self, result: &RedisCacheResult<T>) {
        let outcome = match result {
            Ok(_) => crate::cache_redis::RedisCommandOutcome::Success,
            Err(error) => classify_redis_error(error),
        };
        self.recorder
            .record_event(crate::cache_redis::RedisCommandEvent::new(
                crate::cache_redis::RedisCommandEventKind::Pool,
                outcome,
            ));
    }

    /// Runs a Redis command future under the command timeout, the recorder,
    /// and breaker protection (`protect_command`).
    pub(super) async fn run_command<T, F>(
        &self,
        operation: &'static str,
        future: F,
    ) -> RedisCacheResult<T>
    where
        F: std::future::Future<Output = redis::RedisResult<T>>,
    {
        run_redis_cache_command(
            &self.recorder,
            operation,
            self.config.command_timeout,
            self.protect_command(operation, future),
        )
        .await
    }

    /// Snapshot of the circuit-breaker state, or `None` when the breaker is
    /// disabled. (A plain `match` because `.map` cannot `await`.)
    pub async fn breaker_snapshot(&self) -> Option<crate::resil::BreakerSnapshot> {
        match &self.breaker {
            Some(breaker) => Some(breaker.snapshot().await),
            None => None,
        }
    }
}
#[async_trait]
impl CacheStore for RedisCacheStore {
    /// Fetches the raw bytes stored under `key`, dispatching on the backend.
    ///
    /// When the circuit breaker rejects the operation, the inner result
    /// decides the outcome: `Ok(())` degrades the call to a cache miss
    /// (`Ok(None)`), while an inner error is propagated to the caller.
    async fn get_raw(&self, key: &CacheKey) -> CacheResult<Option<Vec<u8>>> {
        if let Err(short_circuit) = self.ensure_breaker_allows(RedisOperation::Get).await {
            // Degrade `Ok(())` to a miss; pass errors through unchanged.
            return short_circuit.map(|()| None);
        }
        match &self.backend {
            RedisCacheBackend::Single(client) => self.get_raw_single(client, key).await,
            RedisCacheBackend::Cluster(client) => self.get_raw_cluster(client, key).await,
        }
    }

    /// Stores `value` under `key`, optionally with a TTL.
    ///
    /// Breaker rejections short-circuit with the inner result, which already
    /// encodes whether the caller should observe success or an error.
    async fn set_raw(
        &self,
        key: &CacheKey,
        value: Vec<u8>,
        ttl: Option<Duration>,
    ) -> CacheResult<()> {
        if let Err(short_circuit) = self.ensure_breaker_allows(RedisOperation::Set).await {
            return short_circuit;
        }
        match &self.backend {
            RedisCacheBackend::Single(client) => self.set_raw_single(client, key, value, ttl).await,
            RedisCacheBackend::Cluster(client) => {
                self.set_raw_cluster(client, key, value, ttl).await
            }
        }
    }

    /// Deletes `key` from the backend (breaker-gated like `set_raw`).
    async fn delete(&self, key: &CacheKey) -> CacheResult<()> {
        if let Err(short_circuit) = self.ensure_breaker_allows(RedisOperation::Delete).await {
            return short_circuit;
        }
        match &self.backend {
            RedisCacheBackend::Single(client) => self.delete_single(client, key).await,
            RedisCacheBackend::Cluster(client) => self.delete_cluster(client, key).await,
        }
    }

    /// Deletes every key in `keys` from the backend (breaker-gated).
    async fn delete_many(&self, keys: &[CacheKey]) -> CacheResult<()> {
        if let Err(short_circuit) = self.ensure_breaker_allows(RedisOperation::Delete).await {
            return short_circuit;
        }
        match &self.backend {
            RedisCacheBackend::Single(client) => self.delete_many_single(client, keys).await,
            RedisCacheBackend::Cluster(client) => self.delete_many_cluster(client, keys).await,
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::cache_redis::{RedisCacheConfig, RedisCacheError, RedisCacheStore};

    /// Construction must fail up front when the URL scheme is not a Redis one.
    #[test]
    fn redis_store_rejects_invalid_url() {
        let config = RedisCacheConfig {
            url: "http://127.0.0.1".to_string(),
            ..RedisCacheConfig::default()
        };
        let error = RedisCacheStore::new(config).expect_err("invalid url");
        assert!(matches!(error, RedisCacheError::InvalidUrl { .. }));
    }
}