use std::{future::Future, sync::Arc, time::Duration};
use crate::{
cache_redis::{
RedisCacheResult, RedisCommandOutcome, classify_redis_error, redis_breaker_rejection,
redis_cluster_hash_tag, run_redis_cache_command,
},
lock::{LockError, LockGuard, LockResult},
resil::BreakerCallError,
};
use super::{RedisDistributedLock, connection::RedisLockBackend};
/// Verifies that the lock's Redis backend answers a `PING`.
///
/// A connection is opened for whichever backend variant is configured, the
/// `PING` command is sent through `RedisDistributedLock::run_command` under the
/// `LOCK_PING` operation name, and the textual reply is compared to `PONG`.
///
/// # Errors
/// Propagates any error raised while opening the connection or running the
/// command, and returns `LockError::Backend` when the reply is not `PONG`.
pub(crate) async fn health_check(lock: &RedisDistributedLock) -> LockResult<()> {
    let pong: String = match &lock.backend {
        RedisLockBackend::Single(single) => {
            let mut conn = lock.single_connection(single).await?;
            let ping = async { redis::cmd("PING").query_async(&mut conn).await };
            lock.run_command("LOCK_PING", ping).await?
        }
        RedisLockBackend::Cluster(cluster) => {
            let mut conn = lock.cluster_connection(cluster).await?;
            let ping = async { redis::cmd("PING").query_async(&mut conn).await };
            lock.run_command("LOCK_PING", ping).await?
        }
    };

    // Guard clause: anything other than the literal reply is a backend fault.
    if pong != "PONG" {
        return Err(LockError::Backend(format!(
            "unexpected PING response: {pong}"
        )));
    }
    Ok(())
}
/// Attempts to acquire the lock stored under the already-rendered key.
///
/// Opens a connection through the helper matching the configured backend and
/// delegates to `RedisDistributedLock::acquire_on_connection`, whose result is
/// returned unchanged.
///
/// # Parameters
/// * `lock` - the lock instance whose backend is used.
/// * `rendered` - the key passed on to `acquire_on_connection`.
/// * `token` - the value written under the key.
/// * `ttl` - the lifetime requested for the lock.
///
/// # Errors
/// Propagates connection errors and any error from `acquire_on_connection`.
pub(crate) async fn acquire(
    lock: &RedisDistributedLock,
    rendered: String,
    token: Arc<str>,
    ttl: Duration,
) -> LockResult<Option<LockGuard>> {
    // Kept as an exhaustive match so a new backend variant fails to compile
    // here instead of being silently ignored.
    match &lock.backend {
        RedisLockBackend::Single(single) => {
            let mut conn = lock.single_connection(single).await?;
            lock.acquire_on_connection(&mut conn, rendered, token, ttl)
                .await
        }
        RedisLockBackend::Cluster(cluster) => {
            let mut conn = lock.cluster_connection(cluster).await?;
            lock.acquire_on_connection(&mut conn, rendered, token, ttl)
                .await
        }
    }
}
/// Releases the lock described by `guard`.
///
/// Opens a connection through the helper matching the configured backend and
/// delegates to `RedisDistributedLock::release_on_connection`, whose result is
/// returned unchanged.
///
/// # Errors
/// Propagates connection errors and any error from `release_on_connection`.
pub(crate) async fn release(lock: &RedisDistributedLock, guard: &LockGuard) -> LockResult<()> {
    // Kept as an exhaustive match so a new backend variant fails to compile
    // here instead of being silently ignored.
    match &lock.backend {
        RedisLockBackend::Single(single) => {
            let mut conn = lock.single_connection(single).await?;
            lock.release_on_connection(&mut conn, guard).await
        }
        RedisLockBackend::Cluster(cluster) => {
            let mut conn = lock.cluster_connection(cluster).await?;
            lock.release_on_connection(&mut conn, guard).await
        }
    }
}
impl RedisDistributedLock {
pub(crate) async fn acquire_on_connection<C>(
&self,
connection: &mut C,
rendered: String,
token: Arc<str>,
ttl: Duration,
) -> LockResult<Option<LockGuard>>
where
C: redis::aio::ConnectionLike + Send,
{
let ttl_millis = ttl.as_millis().max(1) as u64;
let result: Option<String> = self
.run_command("LOCK_ACQUIRE", async {
redis::cmd("SET")
.arg(&rendered)
.arg(token.as_ref())
.arg("NX")
.arg("PX")
.arg(ttl_millis)
.query_async(connection)
.await
})
.await?;
match result.as_deref() {
Some("OK") => Ok(Some(LockGuard::new(rendered, token, ttl))),
None => Ok(None),
Some(value) => Err(LockError::Backend(format!(
"unexpected SET NX response: {value}"
))),
}
}
pub(crate) async fn release_on_connection<C>(
&self,
connection: &mut C,
guard: &LockGuard,
) -> LockResult<()>
where
C: redis::aio::ConnectionLike + Send,
{
let keys = [guard.key().to_string()];
let args = [guard.token().to_string()];
let result = run_redis_cache_command(
&self.recorder,
"LOCK_RELEASE",
self.config.command_timeout,
self.protect_cache_result("LOCK_RELEASE", async {
self.invoke_release_script(connection, &keys, &args).await
}),
)
.await?;
if result == 1 {
Ok(())
} else {
Err(LockError::OwnerMismatch)
}
}
pub(crate) async fn run_command<T, F>(
&self,
operation: &'static str,
future: F,
) -> LockResult<T>
where
F: Future<Output = redis::RedisResult<T>>,
{
run_redis_cache_command(
&self.recorder,
operation,
self.config.command_timeout,
self.protect_redis_result(operation, future),
)
.await
.map_err(Into::into)
}
async fn protect_redis_result<T, F>(
&self,
operation: &'static str,
future: F,
) -> RedisCacheResult<T>
where
F: Future<Output = redis::RedisResult<T>>,
{
self.protect_cache_result(operation, async {
future.await.map_err(crate::cache_redis::redis_error)
})
.await
}
pub(crate) async fn protect_cache_result<T, F>(
&self,
operation: &'static str,
future: F,
) -> RedisCacheResult<T>
where
F: Future<Output = RedisCacheResult<T>>,
{
let Some(breaker) = &self.breaker else {
return future.await;
};
match breaker
.do_with_acceptable(
|| future,
crate::cache_redis::store_breaker::is_breaker_acceptable_error,
)
.await
{
Ok(value) => {
self.recorder.record_breaker(RedisCommandOutcome::Success);
Ok(value)
}
Err(BreakerCallError::Rejected(error)) => {
self.recorder
.record_breaker(RedisCommandOutcome::BreakerRejected);
Err(redis_breaker_rejection(operation, error))
}
Err(BreakerCallError::Inner(error)) => {
self.recorder.record_breaker(classify_redis_error(&error));
Err(error)
}
}
}
pub(crate) fn validate_acquire(&self, key: &str, ttl: Duration) -> LockResult<()> {
if key.trim().is_empty() {
return Err(LockError::InvalidConfig("lock key is required".to_string()));
}
if ttl.is_zero() {
return Err(LockError::InvalidConfig(
"lock ttl must be greater than zero".to_string(),
));
}
let rendered = self.render_key(key);
if self.config.redis.cluster.enabled
&& self.config.require_cluster_hash_tag
&& redis_cluster_hash_tag(&rendered) == rendered
{
return Err(LockError::InvalidConfig(
"redis cluster lock key must include a non-empty hash tag, for example lock:{order:123}"
.to_string(),
));
}
Ok(())
}
}