rs-zero 0.2.6

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::{
    sync::{Arc, Mutex},
    time::Duration,
};

use async_trait::async_trait;
use redis::{aio::MultiplexedConnection, cluster_async::ClusterConnection};
use tokio::time::timeout;

use crate::{
    cache::{CacheKey, CacheResult, CacheStore},
    cache_redis::{
        RedisCacheConfig, RedisCacheError, RedisCacheResult, RedisClientFactory,
        RedisCommandRecorder, RedisDeleteClient, RedisDeleteRetryQueue, RedisDeleteRetryStats,
        RedisOperation, backend::RedisCacheBackend, classify_redis_error, run_redis_cache_command,
    },
    resil::SharedCircuitBreaker,
};

/// Redis-backed cache store.
///
/// Wraps either a single-node or a Redis Cluster client (`RedisCacheBackend`) and
/// implements `CacheStore` on top of it. The cluster slot map sits behind an `Arc`, so
/// clones of the store share it.
#[derive(Clone)]
pub struct RedisCacheStore {
    /// Configuration the store was built with (namespace, connect/command timeouts,
    /// breaker settings).
    pub(super) config: RedisCacheConfig,
    /// The underlying single-node or cluster client.
    backend: RedisCacheBackend,
    /// Receives command and connection-pool events; metrics can be attached when the
    /// `observability` feature is enabled.
    pub(super) recorder: RedisCommandRecorder,
    /// Delete retry queue. Can only be `Some` for single-node backends; cluster stores
    /// are always built with `None`.
    pub(super) delete_retry: Option<RedisDeleteRetryQueue>,
    /// Shared circuit breaker; `Some` only when `config.breaker.enabled` is set.
    pub(super) breaker: Option<SharedCircuitBreaker>,
    // Initialized empty in `from_backend` and not read anywhere in this file, hence the
    // `dead_code` allowance. NOTE(review): presumably reserved for cluster slot routing —
    // confirm, and drop the allow once it is actually used.
    #[allow(dead_code)]
    cluster_slots: Arc<Mutex<crate::cache_redis::RedisClusterSlotMap>>,
}

impl std::fmt::Debug for RedisCacheStore {
    /// Formats the store for diagnostics.
    ///
    /// Shows the configuration and backend, plus on/off flags for the optional delete
    /// retry queue and circuit breaker. All remaining fields are elided, which is marked
    /// with `..` via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let retry_on = self.delete_retry.is_some();
        let breaker_on = self.breaker.is_some();

        let mut out = f.debug_struct("RedisCacheStore");
        out.field("config", &self.config);
        out.field("backend", &self.backend);
        out.field("delete_retry_enabled", &retry_on);
        out.field("breaker_enabled", &breaker_on);
        out.finish_non_exhaustive()
    }
}

impl RedisCacheStore {
    /// Creates a Redis-backed store and validates local configuration.
    ///
    /// A `RedisClientFactory` is built from a clone of `config`. When
    /// `config.cluster.enabled` is set, a Redis Cluster client is created; otherwise a
    /// single-node client is created.
    ///
    /// # Errors
    ///
    /// Returns the error from building the factory or creating the client. For example,
    /// an `http://` URL is rejected with `RedisCacheError::InvalidUrl`.
    pub fn new(config: RedisCacheConfig) -> RedisCacheResult<Self> {
        let factory = RedisClientFactory::new(config.clone())?;
        if config.cluster.enabled {
            Self::from_cluster_client(config, factory.create_cluster_client()?)
        } else {
            Self::from_client(config, factory.create_client()?)
        }
    }

    /// Creates a store from an existing single-node Redis client.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `config.validate()`.
    pub fn with_client(config: RedisCacheConfig, client: redis::Client) -> RedisCacheResult<Self> {
        config.validate()?;
        Self::from_client(config, client)
    }

    /// Creates a store from an existing Redis Cluster client.
    ///
    /// Cluster stores are built without a delete retry queue.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `config.validate()`.
    pub fn with_cluster_client(
        config: RedisCacheConfig,
        client: redis::cluster::ClusterClient,
    ) -> RedisCacheResult<Self> {
        config.validate()?;
        Self::from_cluster_client(config, client)
    }

    /// Builds a single-node store, spawning the delete retry queue for it.
    fn from_client(config: RedisCacheConfig, client: redis::Client) -> RedisCacheResult<Self> {
        // The retry queue gets its own handle on the client plus the same timeouts as
        // the store. `spawn` yields an `Option`; presumably `None` when
        // `config.delete_retry` disables retries — confirm against `RedisDeleteRetryQueue`.
        let delete_retry = RedisDeleteRetryQueue::spawn(
            config.delete_retry,
            RedisDeleteClient::new(
                client.clone(),
                config.connect_timeout,
                config.command_timeout,
            ),
        );
        Ok(Self::from_backend(
            config,
            RedisCacheBackend::Single(client),
            delete_retry,
        ))
    }

    /// Builds a cluster store. No delete retry queue is created for cluster backends.
    fn from_cluster_client(
        config: RedisCacheConfig,
        client: redis::cluster::ClusterClient,
    ) -> RedisCacheResult<Self> {
        Ok(Self::from_backend(
            config,
            RedisCacheBackend::Cluster(client),
            None,
        ))
    }

    /// Assembles the store; the circuit breaker is created only when
    /// `config.breaker.enabled` is set.
    fn from_backend(
        config: RedisCacheConfig,
        backend: RedisCacheBackend,
        delete_retry: Option<RedisDeleteRetryQueue>,
    ) -> Self {
        let breaker = config.breaker.enabled.then(|| {
            SharedCircuitBreaker::with_policy(
                config.breaker.breaker.clone(),
                config.breaker.policy.clone(),
            )
        });
        Self {
            config,
            backend,
            recorder: RedisCommandRecorder::new(),
            delete_retry,
            breaker,
            cluster_slots: Arc::new(Mutex::new(crate::cache_redis::RedisClusterSlotMap::new())),
        }
    }

    /// Attaches a metrics registry to this Redis store.
    #[cfg(feature = "observability")]
    pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
        self.recorder = self.recorder.clone().with_metrics(metrics);
        self
    }

    /// Sets the low-cardinality shard name used in Redis metrics.
    #[cfg(feature = "observability")]
    pub fn with_shard_name(mut self, shard: impl Into<String>) -> Self {
        self.recorder = self.recorder.clone().with_shard(shard);
        self
    }

    /// Returns the configured namespace.
    pub fn namespace(&self) -> &str {
        &self.config.namespace
    }

    /// Returns delete retry counters when retry is enabled (always `None` for cluster
    /// stores, which have no retry queue).
    pub fn delete_retry_stats(&self) -> Option<RedisDeleteRetryStats> {
        self.delete_retry.as_ref().map(RedisDeleteRetryQueue::stats)
    }

    /// Validates that Redis can be reached and accepts commands.
    ///
    /// Opens a connection for the configured backend and sends `PING` through
    /// `run_command`, the same path used by regular cache operations.
    ///
    /// # Errors
    ///
    /// Returns connection or command errors. Returns `RedisCacheError::Backend` when the
    /// server answers with anything other than `PONG`.
    pub async fn health_check(&self) -> RedisCacheResult<()> {
        let pong: String = match &self.backend {
            RedisCacheBackend::Single(client) => {
                let mut connection = self.single_connection(client).await?;
                self.run_command("PING", async {
                    redis::cmd("PING").query_async(&mut connection).await
                })
                .await?
            }
            RedisCacheBackend::Cluster(client) => {
                let mut connection = self.cluster_connection(client).await?;
                self.run_command("PING", async {
                    redis::cmd("PING").query_async(&mut connection).await
                })
                .await?
            }
        };
        if pong == "PONG" {
            Ok(())
        } else {
            Err(RedisCacheError::Backend(format!(
                "unexpected PING response: {pong}"
            )))
        }
    }

    /// Opens a multiplexed connection to a single-node Redis server.
    ///
    /// The attempt is first admitted via `allow_breaker("CONNECT")` and is bounded by
    /// `config.connect_timeout`. Every outcome (success, connection error, or timeout)
    /// is recorded as a pool event and reported through `record_breaker_outcome`.
    ///
    /// # Errors
    ///
    /// Returns the rejection from `allow_breaker`, `RedisCacheError::Timeout("connect")`
    /// when the connect timeout elapses, or `RedisCacheError::Connection` when the
    /// client fails to connect.
    pub(super) async fn single_connection(
        &self,
        client: &redis::Client,
    ) -> RedisCacheResult<MultiplexedConnection> {
        let guard = self.allow_breaker("CONNECT").await?;
        let attempt = timeout(
            self.config.connect_timeout,
            client.get_multiplexed_async_connection(),
        )
        .await;
        // Flatten timeout and connection failures into one result *without* `?`, so a
        // timed-out connect still reaches the pool metrics and the breaker below.
        let result = match attempt {
            Ok(connected) => {
                connected.map_err(|error| RedisCacheError::Connection(error.to_string()))
            }
            Err(_elapsed) => Err(RedisCacheError::Timeout("connect".to_string())),
        };

        self.record_pool_outcome(&result);
        self.record_breaker_outcome(guard, &result).await;
        result
    }

    /// Opens an async connection to a Redis Cluster.
    ///
    /// The attempt is first admitted via `allow_breaker("CLUSTER CONNECT")` and is
    /// bounded by `config.connect_timeout`. Every outcome (success, connection error, or
    /// timeout) is recorded as a pool event and reported through
    /// `record_breaker_outcome`.
    ///
    /// # Errors
    ///
    /// Returns the rejection from `allow_breaker`,
    /// `RedisCacheError::Timeout("cluster connect")` when the connect timeout elapses,
    /// or `RedisCacheError::Connection` when the client fails to connect.
    pub(super) async fn cluster_connection(
        &self,
        client: &redis::cluster::ClusterClient,
    ) -> RedisCacheResult<ClusterConnection> {
        let guard = self.allow_breaker("CLUSTER CONNECT").await?;
        let attempt = timeout(self.config.connect_timeout, client.get_async_connection()).await;
        // Same flattening as `single_connection`: never early-return on timeout, or the
        // failure would bypass pool metrics and the breaker.
        let result = match attempt {
            Ok(connected) => {
                connected.map_err(|error| RedisCacheError::Connection(error.to_string()))
            }
            Err(_elapsed) => Err(RedisCacheError::Timeout("cluster connect".to_string())),
        };

        self.record_pool_outcome(&result);
        self.record_breaker_outcome(guard, &result).await;
        result
    }

    /// Records a connection-pool event. `Ok` is recorded as `Success`; an error is
    /// recorded with the outcome chosen by `classify_redis_error`.
    fn record_pool_outcome<T>(&self, result: &RedisCacheResult<T>) {
        let outcome = match result {
            Ok(_) => crate::cache_redis::RedisCommandOutcome::Success,
            Err(error) => classify_redis_error(error),
        };
        self.recorder
            .record_event(crate::cache_redis::RedisCommandEvent::new(
                crate::cache_redis::RedisCommandEventKind::Pool,
                outcome,
            ));
    }

    /// Runs a Redis command future, recorded under `operation`.
    ///
    /// The future is wrapped by `protect_command` and executed via
    /// `run_redis_cache_command` with `config.command_timeout`. `protect_command` is
    /// defined elsewhere; presumably it applies breaker protection — confirm there.
    pub(super) async fn run_command<T, F>(
        &self,
        operation: &'static str,
        future: F,
    ) -> RedisCacheResult<T>
    where
        F: std::future::Future<Output = redis::RedisResult<T>>,
    {
        run_redis_cache_command(
            &self.recorder,
            operation,
            self.config.command_timeout,
            self.protect_command(operation, future),
        )
        .await
    }

    /// Returns a Redis breaker snapshot when breaker protection is enabled.
    pub async fn breaker_snapshot(&self) -> Option<crate::resil::BreakerSnapshot> {
        match &self.breaker {
            Some(breaker) => Some(breaker.snapshot().await),
            None => None,
        }
    }
}

#[async_trait]
impl CacheStore for RedisCacheStore {
    /// Reads the raw bytes stored under `key`.
    ///
    /// If the breaker check short-circuits with `Ok(())`, the read degrades to a cache
    /// miss (`Ok(None)`). A short-circuit carrying an error is returned unchanged.
    async fn get_raw(&self, key: &CacheKey) -> CacheResult<Option<Vec<u8>>> {
        if let Some(rejection) = self.ensure_breaker_allows(RedisOperation::Get).await.err() {
            return rejection.map(|()| None);
        }
        let backend = &self.backend;
        match backend {
            RedisCacheBackend::Single(single) => self.get_raw_single(single, key).await,
            RedisCacheBackend::Cluster(cluster) => self.get_raw_cluster(cluster, key).await,
        }
    }

    /// Stores `value` under `key`, with an optional `ttl`.
    ///
    /// If the breaker check short-circuits, its `CacheResult` is returned without
    /// touching Redis.
    async fn set_raw(
        &self,
        key: &CacheKey,
        value: Vec<u8>,
        ttl: Option<Duration>,
    ) -> CacheResult<()> {
        if let Some(rejection) = self.ensure_breaker_allows(RedisOperation::Set).await.err() {
            return rejection;
        }
        let backend = &self.backend;
        match backend {
            RedisCacheBackend::Single(single) => {
                self.set_raw_single(single, key, value, ttl).await
            }
            RedisCacheBackend::Cluster(cluster) => {
                self.set_raw_cluster(cluster, key, value, ttl).await
            }
        }
    }

    /// Deletes a single `key`, consulting the breaker with `RedisOperation::Delete`.
    async fn delete(&self, key: &CacheKey) -> CacheResult<()> {
        if let Some(rejection) = self.ensure_breaker_allows(RedisOperation::Delete).await.err() {
            return rejection;
        }
        let backend = &self.backend;
        match backend {
            RedisCacheBackend::Single(single) => self.delete_single(single, key).await,
            RedisCacheBackend::Cluster(cluster) => self.delete_cluster(cluster, key).await,
        }
    }

    /// Deletes every key in `keys`. The breaker is consulted once for the whole batch,
    /// with `RedisOperation::Delete`.
    async fn delete_many(&self, keys: &[CacheKey]) -> CacheResult<()> {
        if let Some(rejection) = self.ensure_breaker_allows(RedisOperation::Delete).await.err() {
            return rejection;
        }
        let backend = &self.backend;
        match backend {
            RedisCacheBackend::Single(single) => self.delete_many_single(single, keys).await,
            RedisCacheBackend::Cluster(cluster) => self.delete_many_cluster(cluster, keys).await,
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::cache_redis::{RedisCacheConfig, RedisCacheError, RedisCacheStore};

    /// `RedisCacheStore::new` must reject an `http://` URL with `InvalidUrl`.
    #[test]
    fn redis_store_rejects_invalid_url() {
        let config = RedisCacheConfig {
            url: "http://127.0.0.1".to_string(),
            ..RedisCacheConfig::default()
        };
        let outcome = RedisCacheStore::new(config);
        assert!(
            matches!(outcome, Err(RedisCacheError::InvalidUrl { .. })),
            "invalid url"
        );
    }
}