rs-zero 0.2.6

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::{sync::Arc, time::Duration};

use uuid::Uuid;

use crate::{
    cache_redis::{RedisClientFactory, RedisCommandRecorder},
    lock::{DistributedLock, LockGuard, LockResult},
};

/// Lua source of the compare-and-delete script used when releasing a lock.
///
/// The script deletes `KEYS[1]` only when its current value equals `ARGV[1]`
/// and returns the result of `DEL`; otherwise it returns `0` and leaves the
/// key untouched.
// NOTE(review): presumably `ARGV[1]` is the per-acquisition token generated in
// `acquire` — confirm against the `operations` module.
const RELEASE_SCRIPT: &str = r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
  return redis.call("DEL", KEYS[1])
end
return 0
"#;

mod config;
mod connection;
mod operations;
mod script;

pub use config::RedisLockConfig;

use connection::RedisLockBackend;

/// Redis `SET NX PX` distributed lock backend.
///
/// Instances are built through [`RedisDistributedLock::new`],
/// [`RedisDistributedLock::with_client`] or
/// [`RedisDistributedLock::with_cluster_client`]; the type is `Clone`.
#[derive(Clone)]
pub struct RedisDistributedLock {
    /// Lock configuration; its `namespace` is used as the prefix of every
    /// rendered key and its `breaker` section decides whether a circuit
    /// breaker is created.
    pub(crate) config: RedisLockConfig,
    /// Underlying Redis client, wrapped as either the single-node or the
    /// Cluster variant of `RedisLockBackend`.
    pub(crate) backend: RedisLockBackend,
    /// Script handle constructed from `RELEASE_SCRIPT`.
    pub(crate) release_script: crate::cache_redis::RedisLuaScript,
    /// Command recorder; metrics and shard name can be attached to it when the
    /// `observability` feature is enabled.
    pub(crate) recorder: RedisCommandRecorder,
    /// Circuit breaker, present only when `config.breaker.enabled` was true at
    /// construction time.
    pub(crate) breaker: Option<crate::resil::SharedCircuitBreaker>,
}

impl std::fmt::Debug for RedisDistributedLock {
    /// Formats the lock as `RedisDistributedLock { config, backend, breaker_enabled, .. }`.
    ///
    /// Only the configuration, the backend and a flag telling whether a
    /// breaker is attached are printed; the remaining fields are elided via
    /// the non-exhaustive marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let breaker_enabled = self.breaker.is_some();

        let mut builder = f.debug_struct("RedisDistributedLock");
        builder.field("config", &self.config);
        builder.field("backend", &self.backend);
        builder.field("breaker_enabled", &breaker_enabled);
        builder.finish_non_exhaustive()
    }
}

impl RedisDistributedLock {
    /// Builds a lock backend from configuration alone.
    ///
    /// The configuration is validated first, then a client factory is created
    /// from `config.redis`. A Cluster client is used when
    /// `config.redis.cluster.enabled` is set, a single-node client otherwise.
    ///
    /// # Errors
    ///
    /// Propagates validation, factory-construction and client-creation errors.
    pub fn new(config: RedisLockConfig) -> LockResult<Self> {
        config.validate()?;
        let factory = RedisClientFactory::new(config.redis.clone())?;

        // Single-node is the common path; handle it first and fall through to Cluster.
        if !config.redis.cluster.enabled {
            return Self::from_client(config, factory.create_client()?);
        }
        Self::from_cluster_client(config, factory.create_cluster_client()?)
    }

    /// Builds a lock backend on top of an already constructed single-node client.
    ///
    /// # Errors
    ///
    /// Returns the validation error when `config` is invalid.
    pub fn with_client(config: RedisLockConfig, client: redis::Client) -> LockResult<Self> {
        config.validate()?;
        Self::from_client(config, client)
    }

    /// Builds a lock backend on top of an already constructed Cluster client.
    ///
    /// # Errors
    ///
    /// Returns the validation error when `config` is invalid.
    pub fn with_cluster_client(
        config: RedisLockConfig,
        client: redis::cluster::ClusterClient,
    ) -> LockResult<Self> {
        config.validate()?;
        Self::from_cluster_client(config, client)
    }

    /// Returns this lock with `metrics` attached to its command recorder.
    #[cfg(feature = "observability")]
    pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
        let recorder = self.recorder.clone().with_metrics(metrics);
        self.recorder = recorder;
        self
    }

    /// Returns this lock with the low-cardinality shard name used in Redis
    /// metrics set on its command recorder.
    #[cfg(feature = "observability")]
    pub fn with_shard_name(mut self, shard: impl Into<String>) -> Self {
        let recorder = self.recorder.clone().with_shard(shard);
        self.recorder = recorder;
        self
    }

    /// Checks Redis reachability with PING (delegates to `operations::health_check`).
    pub async fn health_check(&self) -> LockResult<()> {
        operations::health_check(self).await
    }

    /// Returns the breaker's current snapshot, or `None` when no breaker is attached.
    pub async fn breaker_snapshot(&self) -> Option<crate::resil::BreakerSnapshot> {
        let breaker = self.breaker.as_ref()?;
        Some(breaker.snapshot().await)
    }

    /// Wraps a single-node client in the backend enum and finishes construction.
    fn from_client(config: RedisLockConfig, client: redis::Client) -> LockResult<Self> {
        let backend = RedisLockBackend::Single(client);
        Ok(Self::from_backend(config, backend))
    }

    /// Wraps a Cluster client in the backend enum and finishes construction.
    fn from_cluster_client(
        config: RedisLockConfig,
        client: redis::cluster::ClusterClient,
    ) -> LockResult<Self> {
        let backend = RedisLockBackend::Cluster(client);
        Ok(Self::from_backend(config, backend))
    }

    /// Assembles the lock from a validated configuration and a ready backend.
    ///
    /// A circuit breaker is created only when `config.breaker.enabled` is set;
    /// the release script and a fresh command recorder are always created.
    fn from_backend(config: RedisLockConfig, backend: RedisLockBackend) -> Self {
        let breaker = if config.breaker.enabled {
            Some(crate::resil::SharedCircuitBreaker::with_policy(
                config.breaker.breaker.clone(),
                config.breaker.policy.clone(),
            ))
        } else {
            None
        };
        let release_script = crate::cache_redis::RedisLuaScript::new(RELEASE_SCRIPT);
        let recorder = RedisCommandRecorder::new();

        Self {
            config,
            backend,
            release_script,
            recorder,
            breaker,
        }
    }

    /// Prefixes `key` with the configured namespace, yielding `<namespace>:<key>`.
    pub(crate) fn render_key(&self, key: &str) -> String {
        let namespace = &self.config.namespace;
        format!("{namespace}:{key}")
    }
}

#[async_trait::async_trait]
impl DistributedLock for RedisDistributedLock {
    /// Tries to take the lock for `key` with the given time-to-live.
    ///
    /// The request is validated first; the key is then namespaced and a fresh
    /// random UUID (v4) is generated as the token handed to
    /// `operations::acquire`.
    async fn acquire(&self, key: &str, ttl: Duration) -> LockResult<Option<LockGuard>> {
        self.validate_acquire(key, ttl)?;

        let namespaced_key = self.render_key(key);
        let owner_token: Arc<str> = Uuid::new_v4().to_string().into();

        operations::acquire(self, namespaced_key, owner_token, ttl).await
    }

    /// Releases the lock described by `guard` (delegates to `operations::release`).
    async fn release(&self, guard: &LockGuard) -> LockResult<()> {
        operations::release(self, guard).await
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{DistributedLock, RedisDistributedLock, RedisLockConfig};
    use crate::cache_redis::{RedisCacheConfig, RedisClusterConfig};

    /// A whitespace-only namespace must be rejected by configuration validation.
    #[test]
    fn redis_lock_config_requires_namespace() {
        let config = RedisLockConfig {
            namespace: String::from(" "),
            ..RedisLockConfig::default()
        };

        let message = config
            .validate()
            .expect_err("invalid namespace")
            .to_string();

        assert!(message.contains("namespace"));
    }

    /// With Cluster mode enabled and otherwise default settings, acquiring a
    /// key without a hash tag must fail with an error mentioning "hash tag".
    #[tokio::test]
    async fn redis_cluster_lock_requires_hash_tag_by_default() {
        let cluster = RedisClusterConfig {
            enabled: true,
            ..RedisClusterConfig::default()
        };
        let redis = RedisCacheConfig {
            url: String::from("redis://127.0.0.1:6379"),
            cluster,
            ..RedisCacheConfig::default()
        };
        let config = RedisLockConfig {
            redis,
            ..RedisLockConfig::default()
        };
        let lock = RedisDistributedLock::new(config).expect("lock backend");

        let outcome = lock.acquire("order:123", Duration::from_secs(1)).await;
        let message = outcome.expect_err("missing hash tag").to_string();

        assert!(message.contains("hash tag"));
    }
}