rs-zero 0.2.8

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::{future::Future, sync::Arc, time::Duration};

use crate::{
    cache_redis::{
        RedisCacheResult, RedisCommandOutcome, classify_redis_error, redis_breaker_rejection,
        redis_cluster_hash_tag, run_redis_cache_command,
    },
    lock::{LockError, LockGuard, LockResult},
    resil::BreakerCallError,
};

use super::{RedisDistributedLock, connection::RedisLockBackend};

/// Checks that the configured Redis backend answers `PING` with `PONG`.
///
/// A connection is opened on whichever backend the lock was built with (single
/// node or cluster), and `PING` is sent through
/// [`RedisDistributedLock::run_command`] under the `LOCK_PING` operation name.
///
/// # Errors
///
/// Propagates connection and command errors. Returns [`LockError::Backend`]
/// when the reply is anything other than `PONG`.
pub(crate) async fn health_check(lock: &RedisDistributedLock) -> LockResult<()> {
    let reply: String = match &lock.backend {
        RedisLockBackend::Cluster(cluster) => {
            let mut conn = lock.cluster_connection(cluster).await?;
            lock.run_command("LOCK_PING", async {
                redis::cmd("PING").query_async(&mut conn).await
            })
            .await?
        }
        RedisLockBackend::Single(single) => {
            let mut conn = lock.single_connection(single).await?;
            lock.run_command("LOCK_PING", async {
                redis::cmd("PING").query_async(&mut conn).await
            })
            .await?
        }
    };

    match reply.as_str() {
        "PONG" => Ok(()),
        other => Err(LockError::Backend(format!(
            "unexpected PING response: {other}"
        ))),
    }
}

/// Tries to take the lock stored under the rendered key `rendered`.
///
/// A connection is obtained from the backend the lock was built with (cluster
/// or single node). The work is then delegated to
/// [`RedisDistributedLock::acquire_on_connection`], which sets the key to
/// `token` with the given `ttl` only if the key does not already exist.
///
/// Returns `Ok(Some(guard))` when the lock was taken and `Ok(None)` when the
/// key was already present.
///
/// # Errors
///
/// Propagates connection failures and any error reported by
/// `acquire_on_connection`.
pub(crate) async fn acquire(
    lock: &RedisDistributedLock,
    rendered: String,
    token: Arc<str>,
    ttl: Duration,
) -> LockResult<Option<LockGuard>> {
    match &lock.backend {
        RedisLockBackend::Cluster(cluster) => {
            let mut conn = lock.cluster_connection(cluster).await?;
            lock.acquire_on_connection(&mut conn, rendered, token, ttl).await
        }
        RedisLockBackend::Single(single) => {
            let mut conn = lock.single_connection(single).await?;
            lock.acquire_on_connection(&mut conn, rendered, token, ttl).await
        }
    }
}

/// Releases the lock described by `guard` on the configured backend.
///
/// A connection matching `lock.backend` (cluster or single node) is opened,
/// and the work is delegated to [`RedisDistributedLock::release_on_connection`].
///
/// # Errors
///
/// Propagates connection failures and whatever `release_on_connection`
/// reports, including [`LockError::OwnerMismatch`].
pub(crate) async fn release(lock: &RedisDistributedLock, guard: &LockGuard) -> LockResult<()> {
    match &lock.backend {
        RedisLockBackend::Cluster(cluster) => {
            let mut conn = lock.cluster_connection(cluster).await?;
            lock.release_on_connection(&mut conn, guard).await
        }
        RedisLockBackend::Single(single) => {
            let mut conn = lock.single_connection(single).await?;
            lock.release_on_connection(&mut conn, guard).await
        }
    }
}

impl RedisDistributedLock {
    /// Tries to acquire the lock on an already-open connection.
    ///
    /// Issues `SET <rendered> <token> NX PX <ttl_millis>` via
    /// [`Self::run_command`] under the `LOCK_ACQUIRE` operation name.
    ///
    /// The TTL is sent in whole milliseconds:
    /// - Sub-millisecond TTLs are rounded up to 1 ms.
    /// - TTLs whose millisecond count does not fit in `u64` are saturated to
    ///   `u64::MAX` instead of being truncated. Redis should then reject the
    ///   value with an error rather than install a wrong, possibly tiny, expiry.
    ///
    /// # Returns
    ///
    /// * `Ok(Some(guard))` when Redis replied `OK`. The guard carries
    ///   `rendered`, `token` and the original `ttl`.
    /// * `Ok(None)` when Redis replied nil, meaning the key was not set.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::run_command`]. Returns
    /// [`LockError::Backend`] for any reply other than `OK` or nil.
    pub(crate) async fn acquire_on_connection<C>(
        &self,
        connection: &mut C,
        rendered: String,
        token: Arc<str>,
        ttl: Duration,
    ) -> LockResult<Option<LockGuard>>
    where
        C: redis::aio::ConnectionLike + Send,
    {
        // `as_millis` yields a u128. A plain `as u64` cast would keep only the
        // low 64 bits for huge durations, so saturate instead. The `max(1)`
        // keeps sub-millisecond TTLs from becoming `PX 0`, which Redis rejects.
        let ttl_millis = u64::try_from(ttl.as_millis())
            .unwrap_or(u64::MAX)
            .max(1);
        let result: Option<String> = self
            .run_command("LOCK_ACQUIRE", async {
                redis::cmd("SET")
                    .arg(&rendered)
                    .arg(token.as_ref())
                    .arg("NX")
                    .arg("PX")
                    .arg(ttl_millis)
                    .query_async(connection)
                    .await
            })
            .await?;

        match result.as_deref() {
            Some("OK") => Ok(Some(LockGuard::new(rendered, token, ttl))),
            None => Ok(None),
            Some(value) => Err(LockError::Backend(format!(
                "unexpected SET NX response: {value}"
            ))),
        }
    }

    /// Releases the lock held by `guard` on an already-open connection.
    ///
    /// The release script is invoked with `KEYS = [guard.key()]` and
    /// `ARGV = [guard.token()]`, under the `LOCK_RELEASE` operation name. The
    /// call goes through the optional circuit breaker
    /// ([`Self::protect_cache_result`]) and `run_redis_cache_command`, which
    /// receives `config.command_timeout` and the recorder.
    ///
    /// NOTE(review): the script itself lives in `invoke_release_script`. It is
    /// presumably a compare-and-delete that returns 1 only when the stored
    /// value equals the token. Confirm there.
    ///
    /// # Errors
    ///
    /// Propagates command/breaker errors. Returns [`LockError::OwnerMismatch`]
    /// whenever the script result is not `1`.
    pub(crate) async fn release_on_connection<C>(
        &self,
        connection: &mut C,
        guard: &LockGuard,
    ) -> LockResult<()>
    where
        C: redis::aio::ConnectionLike + Send,
    {
        let keys = [guard.key().to_string()];
        let args = [guard.token().to_string()];
        let result = run_redis_cache_command(
            &self.recorder,
            "LOCK_RELEASE",
            self.config.command_timeout,
            self.protect_cache_result("LOCK_RELEASE", async {
                self.invoke_release_script(connection, &keys, &args).await
            }),
        )
        .await?;

        if result == 1 {
            Ok(())
        } else {
            Err(LockError::OwnerMismatch)
        }
    }

    /// Runs a raw Redis command future and converts its error into [`LockError`].
    ///
    /// The future is wrapped by [`Self::protect_redis_result`] (optional circuit
    /// breaker) and passed to `run_redis_cache_command` together with
    /// `operation`, the recorder and `config.command_timeout`. That helper is
    /// presumably responsible for timing and metrics — confirm in
    /// `cache_redis`.
    pub(crate) async fn run_command<T, F>(
        &self,
        operation: &'static str,
        future: F,
    ) -> LockResult<T>
    where
        F: Future<Output = redis::RedisResult<T>>,
    {
        run_redis_cache_command(
            &self.recorder,
            operation,
            self.config.command_timeout,
            self.protect_redis_result(operation, future),
        )
        .await
        .map_err(Into::into)
    }

    /// Adapts a `redis::RedisResult` future to the cache error type and applies
    /// the circuit breaker via [`Self::protect_cache_result`].
    ///
    /// Errors are mapped with `crate::cache_redis::redis_error`.
    async fn protect_redis_result<T, F>(
        &self,
        operation: &'static str,
        future: F,
    ) -> RedisCacheResult<T>
    where
        F: Future<Output = redis::RedisResult<T>>,
    {
        self.protect_cache_result(operation, async {
            future.await.map_err(crate::cache_redis::redis_error)
        })
        .await
    }

    /// Runs `future` through the configured circuit breaker, if any.
    ///
    /// Without a breaker the future is simply awaited. With a breaker, the call
    /// goes through `do_with_acceptable`, using `is_breaker_acceptable_error`
    /// to decide which errors count against it. The outcome is recorded as
    /// follows:
    /// - success is recorded as `Success`;
    /// - a rejection is recorded as `BreakerRejected` and converted with
    ///   `redis_breaker_rejection(operation, ..)`;
    /// - an inner error is recorded using `classify_redis_error` and returned
    ///   unchanged.
    pub(crate) async fn protect_cache_result<T, F>(
        &self,
        operation: &'static str,
        future: F,
    ) -> RedisCacheResult<T>
    where
        F: Future<Output = RedisCacheResult<T>>,
    {
        let Some(breaker) = &self.breaker else {
            return future.await;
        };

        match breaker
            .do_with_acceptable(
                || future,
                crate::cache_redis::store_breaker::is_breaker_acceptable_error,
            )
            .await
        {
            Ok(value) => {
                self.recorder.record_breaker(RedisCommandOutcome::Success);
                Ok(value)
            }
            Err(BreakerCallError::Rejected(error)) => {
                self.recorder
                    .record_breaker(RedisCommandOutcome::BreakerRejected);
                Err(redis_breaker_rejection(operation, error))
            }
            Err(BreakerCallError::Inner(error)) => {
                self.recorder.record_breaker(classify_redis_error(&error));
                Err(error)
            }
        }
    }

    /// Validates an acquire request before any Redis round trip.
    ///
    /// # Errors
    ///
    /// Returns [`LockError::InvalidConfig`] when any of the following holds:
    /// - `key` is empty or only whitespace;
    /// - `ttl` is zero;
    /// - cluster mode is enabled, `require_cluster_hash_tag` is set, and the
    ///   rendered key has no hash tag. "No hash tag" means
    ///   `redis_cluster_hash_tag` returned the whole rendered key.
    pub(crate) fn validate_acquire(&self, key: &str, ttl: Duration) -> LockResult<()> {
        if key.trim().is_empty() {
            return Err(LockError::InvalidConfig("lock key is required".to_string()));
        }
        if ttl.is_zero() {
            return Err(LockError::InvalidConfig(
                "lock ttl must be greater than zero".to_string(),
            ));
        }
        let rendered = self.render_key(key);
        // `redis_cluster_hash_tag` returning the full key is taken to mean
        // that no `{...}` tag was found, so the key could land on any slot.
        if self.config.redis.cluster.enabled
            && self.config.require_cluster_hash_tag
            && redis_cluster_hash_tag(&rendered) == rendered
        {
            return Err(LockError::InvalidConfig(
                "redis cluster lock key must include a non-empty hash tag, for example lock:{order:123}"
                    .to_string(),
            ));
        }
        Ok(())
    }
}