use std::{sync::Arc, time::Duration};
use uuid::Uuid;
use crate::{
cache_redis::{RedisClientFactory, RedisCommandRecorder},
lock::{DistributedLock, LockGuard, LockResult},
};
/// Lua script that releases a lock with a compare-and-delete.
///
/// Deletes `KEYS[1]` only when its current value equals `ARGV[1]`. A key whose
/// value no longer matches is left untouched; for example, it may have expired
/// and been taken by another holder. On a match the script returns the `DEL`
/// reply; otherwise it returns `0`. Redis runs a script atomically, so no other
/// command can execute between the `GET` and the `DEL`.
///
/// NOTE(review): presumably `ARGV[1]` is the per-acquisition UUID token created
/// in `acquire`; confirm against `operations::release`.
const RELEASE_SCRIPT: &str = r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
end
return 0
"#;
mod config;
mod connection;
mod operations;
mod script;
pub use config::RedisLockConfig;
use connection::RedisLockBackend;
/// Redis-backed implementation of [`DistributedLock`].
///
/// Construct it with `new`, `with_client` or `with_cluster_client`; each of these
/// validates the config first. `acquire` prefixes keys with `config.namespace`
/// (via `render_key`) before talking to Redis.
#[derive(Clone)]
pub struct RedisDistributedLock {
/// Configuration this lock was built from.
pub(crate) config: RedisLockConfig,
/// Either a single-node (`RedisLockBackend::Single`) or a cluster
/// (`RedisLockBackend::Cluster`) client.
pub(crate) backend: RedisLockBackend,
/// Wrapper around `RELEASE_SCRIPT`, the compare-and-delete used on release.
pub(crate) release_script: crate::cache_redis::RedisLuaScript,
/// Command recorder. `with_metrics` and `with_shard_name` replace it when the
/// `observability` feature is enabled.
pub(crate) recorder: RedisCommandRecorder,
/// Circuit breaker. Present only when `config.breaker.enabled` was true at
/// construction.
/// NOTE(review): whether clones of the lock share breaker state depends on
/// `SharedCircuitBreaker`'s `Clone` impl; confirm.
pub(crate) breaker: Option<crate::resil::SharedCircuitBreaker>,
}
/// Hand-written `Debug` that prints the config and backend, and reports the
/// breaker only as an on/off flag.
///
/// The release script and command recorder are not printed. `finish_non_exhaustive`
/// marks the output as partial.
impl std::fmt::Debug for RedisDistributedLock {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let breaker_enabled = self.breaker.is_some();
        let mut out = f.debug_struct("RedisDistributedLock");
        out.field("config", &self.config);
        out.field("backend", &self.backend);
        out.field("breaker_enabled", &breaker_enabled);
        out.finish_non_exhaustive()
    }
}
impl RedisDistributedLock {
    /// Builds a lock from `config`, creating the Redis client through
    /// [`RedisClientFactory`].
    ///
    /// Creates a cluster client when `config.redis.cluster.enabled` is set and
    /// a single-node client otherwise.
    ///
    /// # Errors
    ///
    /// Returns an error if `config.validate()` rejects the configuration, or if
    /// the factory or the client cannot be created.
    pub fn new(config: RedisLockConfig) -> LockResult<Self> {
        config.validate()?;
        let factory = RedisClientFactory::new(config.redis.clone())?;
        if config.redis.cluster.enabled {
            Self::from_cluster_client(config, factory.create_cluster_client()?)
        } else {
            Self::from_client(config, factory.create_client()?)
        }
    }

    /// Builds a single-node lock around an existing [`redis::Client`].
    ///
    /// `config.redis.cluster.enabled` is not consulted here; the backend is
    /// always single-node.
    ///
    /// # Errors
    ///
    /// Returns an error if `config.validate()` rejects the configuration.
    pub fn with_client(config: RedisLockConfig, client: redis::Client) -> LockResult<Self> {
        config.validate()?;
        Self::from_client(config, client)
    }

    /// Builds a cluster lock around an existing
    /// [`redis::cluster::ClusterClient`].
    ///
    /// `config.redis.cluster.enabled` is not consulted here; the backend is
    /// always a cluster client.
    ///
    /// # Errors
    ///
    /// Returns an error if `config.validate()` rejects the configuration.
    pub fn with_cluster_client(
        config: RedisLockConfig,
        client: redis::cluster::ClusterClient,
    ) -> LockResult<Self> {
        config.validate()?;
        Self::from_cluster_client(config, client)
    }

    /// Returns this lock with `metrics` attached to its command recorder.
    #[cfg(feature = "observability")]
    #[must_use]
    pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
        // `self` is owned here, so move the recorder out and replace it instead
        // of cloning it first.
        self.recorder = self.recorder.with_metrics(metrics);
        self
    }

    /// Returns this lock with `shard` passed to the command recorder's
    /// `with_shard`.
    #[cfg(feature = "observability")]
    #[must_use]
    pub fn with_shard_name(mut self, shard: impl Into<String>) -> Self {
        // Same as `with_metrics`: move out of the owned `self` and replace it.
        self.recorder = self.recorder.with_shard(shard);
        self
    }

    /// Runs the backend health check implemented in `operations::health_check`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `operations::health_check` reports.
    pub async fn health_check(&self) -> LockResult<()> {
        operations::health_check(self).await
    }

    /// Returns a snapshot of the circuit breaker.
    ///
    /// Returns `None` when no breaker was built, i.e. when
    /// `config.breaker.enabled` was false at construction.
    pub async fn breaker_snapshot(&self) -> Option<crate::resil::BreakerSnapshot> {
        match &self.breaker {
            Some(breaker) => Some(breaker.snapshot().await),
            None => None,
        }
    }

    /// Wraps a single-node client. This currently never fails; it returns
    /// `LockResult` to keep the constructor call sites uniform.
    fn from_client(config: RedisLockConfig, client: redis::Client) -> LockResult<Self> {
        Ok(Self::from_backend(config, RedisLockBackend::Single(client)))
    }

    /// Wraps a cluster client. This currently never fails; it returns
    /// `LockResult` to keep the constructor call sites uniform.
    fn from_cluster_client(
        config: RedisLockConfig,
        client: redis::cluster::ClusterClient,
    ) -> LockResult<Self> {
        Ok(Self::from_backend(config, RedisLockBackend::Cluster(client)))
    }

    /// Assembles the lock from an already-created backend.
    ///
    /// Does not validate `config`. The public constructors in this impl do that
    /// before reaching here.
    fn from_backend(config: RedisLockConfig, backend: RedisLockBackend) -> Self {
        // Build the breaker only when it is enabled in config.
        let breaker = config.breaker.enabled.then(|| {
            crate::resil::SharedCircuitBreaker::with_policy(
                config.breaker.breaker.clone(),
                config.breaker.policy.clone(),
            )
        });
        Self {
            config,
            backend,
            release_script: crate::cache_redis::RedisLuaScript::new(RELEASE_SCRIPT),
            recorder: RedisCommandRecorder::new(),
            breaker,
        }
    }

    /// Prefixes `key` with the configured namespace and a `:` separator.
    ///
    /// For example, namespace `locks` and key `order:1` produce `locks:order:1`.
    pub(crate) fn render_key(&self, key: &str) -> String {
        format!("{}:{key}", self.config.namespace)
    }
}
#[async_trait::async_trait]
impl DistributedLock for RedisDistributedLock {
    /// Attempts to take the lock on `key` for `ttl`.
    ///
    /// Validates the inputs first. It then namespaces the key and generates a
    /// random UUID token for this acquisition, and hands both to
    /// `operations::acquire`.
    async fn acquire(&self, key: &str, ttl: Duration) -> LockResult<Option<LockGuard>> {
        self.validate_acquire(key, ttl)?;
        let namespaced_key = self.render_key(key);
        let owner_token: Arc<str> = Uuid::new_v4().to_string().into();
        operations::acquire(self, namespaced_key, owner_token, ttl).await
    }

    /// Releases the lock described by `guard` through `operations::release`.
    async fn release(&self, guard: &LockGuard) -> LockResult<()> {
        operations::release(self, guard).await
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{DistributedLock, RedisDistributedLock, RedisLockConfig};
    use crate::cache_redis::{RedisCacheConfig, RedisClusterConfig};

    /// A whitespace-only namespace must fail validation, with an error that
    /// names the offending field.
    #[test]
    fn redis_lock_config_requires_namespace() {
        let blank_namespace = RedisLockConfig {
            namespace: " ".to_string(),
            ..RedisLockConfig::default()
        };
        let message = blank_namespace
            .validate()
            .expect_err("invalid namespace")
            .to_string();
        assert!(message.contains("namespace"));
    }

    /// With cluster mode on and an otherwise-default lock config, acquiring a
    /// key without a hash tag (`order:123`) must fail with an error that
    /// mentions the hash tag.
    #[tokio::test]
    async fn redis_cluster_lock_requires_hash_tag_by_default() {
        let cluster = RedisClusterConfig {
            enabled: true,
            ..RedisClusterConfig::default()
        };
        let redis = RedisCacheConfig {
            url: "redis://127.0.0.1:6379".to_string(),
            cluster,
            ..RedisCacheConfig::default()
        };
        let lock = RedisDistributedLock::new(RedisLockConfig {
            redis,
            ..RedisLockConfig::default()
        })
        .expect("lock backend");

        let outcome = lock.acquire("order:123", Duration::from_secs(1)).await;
        let error = outcome.expect_err("missing hash tag");
        assert!(error.to_string().contains("hash tag"));
    }
}