use std::time::Duration;
use distributed_lock_core::error::{LockError, LockResult};
use distributed_lock_core::timeout::TimeoutValue;
use distributed_lock_core::traits::{DistributedSemaphore, LockHandle};
use fred::prelude::*;
use fred::types::CustomCommand;
use rand::Rng;
use tokio::sync::watch;
/// A counting semaphore stored in Redis as a sorted set of tickets.
///
/// Each holder owns one member of the sorted set stored under `key`. The member's score is the
/// Redis-clock time, in milliseconds, at which that ticket expires.
pub struct RedisDistributedSemaphore {
    /// Caller-supplied semaphore name (reported by `DistributedSemaphore::name`).
    name: String,
    /// Redis key of the sorted set; derived from `name` by the constructor.
    key: String,
    /// Maximum number of tickets that may be held at the same time.
    max_count: u32,
    /// How long a ticket lives without being extended.
    expiry: Duration,
    /// How often a held ticket is extended by its handle's background task.
    extension_cadence: Duration,
    /// Client used to run the acquire script; cloned into every handle.
    client: RedisClient,
}
impl RedisDistributedSemaphore {
    /// Lua script that prunes expired tickets and, if room remains, adds a new one.
    ///
    /// - `KEYS[1]`: sorted-set key.
    /// - `ARGV[1]`: maximum ticket count.
    /// - `ARGV[2]`: ticket expiry (ms).
    /// - `ARGV[3]`: ticket id.
    /// - `ARGV[4]`: minimum TTL (ms) to keep on the whole key.
    ///
    /// Scores are Redis-clock milliseconds at which each ticket expires. Returns 1 when a ticket
    /// was added, 0 when the semaphore is full.
    const ACQUIRE_SCRIPT: &'static str = r#"
redis.replicate_commands()
local nowResult = redis.call('time')
local nowMillis = (tonumber(nowResult[1]) * 1000.0) + (tonumber(nowResult[2]) / 1000.0)
redis.call('zremrangebyscore', KEYS[1], '-inf', nowMillis)
if redis.call('zcard', KEYS[1]) < tonumber(ARGV[1]) then
redis.call('zadd', KEYS[1], nowMillis + tonumber(ARGV[2]), ARGV[3])
-- Extend key TTL (set to 3x ticket expiry to be safe)
local keyTtl = redis.call('pttl', KEYS[1])
if keyTtl < tonumber(ARGV[4]) then
redis.call('pexpire', KEYS[1], ARGV[4])
end
return 1
end
return 0
"#;

    /// Upper bound (ms) applied to the ticket expiry before it is sent to Redis.
    ///
    /// Keeps the key TTL (`expiry * 3`) from overflowing `i64` and leaves headroom for the
    /// server adding the current time to the PEXPIRE value.
    const MAX_EXPIRY_MILLIS: i64 = i64::MAX / 4;

    /// Creates a semaphore handle factory for `name`, stored under
    /// `distributed-lock:semaphore:<name>`.
    pub(crate) fn new(
        name: String,
        max_count: u32,
        client: RedisClient,
        expiry: Duration,
        extension_cadence: Duration,
    ) -> Self {
        let key = format!("distributed-lock:semaphore:{}", name);
        Self {
            key,
            name,
            max_count,
            client,
            expiry,
            extension_cadence,
        }
    }

    /// Returns a random 64-bit ticket id, hex-encoded to 16 characters.
    fn generate_lock_id() -> String {
        let mut rng = rand::thread_rng();
        format!("{:016x}", rng.r#gen::<u64>())
    }

    /// Makes one non-blocking attempt to take a ticket.
    ///
    /// Returns a handle (which keeps the ticket alive in the background) on success, or
    /// `Ok(None)` when the semaphore is currently full.
    ///
    /// # Errors
    /// `LockError::Backend` if the EVAL call fails.
    async fn try_acquire_internal(&self) -> LockResult<Option<RedisSemaphoreHandle>> {
        let lock_id = Self::generate_lock_id();
        // Clamp instead of `as`-casting. A huge expiry (e.g. `Duration::MAX`) would otherwise
        // overflow the `* 3` below (a panic in debug builds) or wrap to a negative value. A
        // negative value scores the ticket in the past, so it is pruned by the next acquire.
        let expiry_millis = i64::try_from(self.expiry.as_millis())
            .unwrap_or(i64::MAX)
            .min(Self::MAX_EXPIRY_MILLIS);
        // The key as a whole is kept alive for 3x the ticket expiry.
        let set_expiry_millis = expiry_millis * 3;
        let args: Vec<RedisValue> = vec![
            Self::ACQUIRE_SCRIPT.into(),
            1_i64.into(),
            self.key.clone().into(),
            i64::from(self.max_count).into(),
            expiry_millis.into(),
            lock_id.clone().into(),
            set_expiry_millis.into(),
        ];
        let cmd = CustomCommand::new_static("EVAL", None, false);
        let result: i64 = self.client.custom(cmd, args).await.map_err(|e| {
            LockError::Backend(Box::new(std::io::Error::other(format!(
                "Redis custom EVAL (acquire semaphore) failed: {}",
                e
            ))))
        })?;
        if result != 1 {
            return Ok(None);
        }
        let (sender, receiver) = watch::channel(false);
        Ok(Some(RedisSemaphoreHandle::new(
            self.key.clone(),
            lock_id,
            self.client.clone(),
            self.expiry,
            self.extension_cadence,
            sender,
            receiver,
        )))
    }
}
impl DistributedSemaphore for RedisDistributedSemaphore {
    type Handle = RedisSemaphoreHandle;

    /// Returns the caller-supplied semaphore name.
    fn name(&self) -> &str {
        &self.name
    }

    /// Returns the maximum number of concurrent holders.
    fn max_count(&self) -> u32 {
        self.max_count
    }

    /// Waits for a ticket until one is obtained or `timeout` elapses (`None` means forever).
    ///
    /// Polls with exponential backoff: 10 ms, doubling up to 200 ms. Each sleep is capped at the
    /// time remaining before the deadline, so the call does not overrun the caller's timeout by
    /// a whole backoff step. A final attempt is made at the deadline.
    ///
    /// # Errors
    /// - `LockError::Timeout` if no ticket was obtained in time.
    /// - Any error from an individual attempt, returned immediately.
    async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
        const MAX_SLEEP: Duration = Duration::from_millis(200);

        let timeout_value = TimeoutValue::from(timeout);
        // `None` = wait indefinitely. Finite timeouts are expected to carry a duration.
        let limit = if timeout_value.is_infinite() {
            None
        } else {
            Some(timeout_value.as_duration().unwrap())
        };
        let start = std::time::Instant::now();
        let mut backoff = Duration::from_millis(10);

        loop {
            if let Some(handle) = self.try_acquire_internal().await? {
                return Ok(handle);
            }
            let mut nap = backoff;
            if let Some(limit) = limit {
                let elapsed = start.elapsed();
                if elapsed >= limit {
                    return Err(LockError::Timeout(limit));
                }
                // Never sleep past the deadline.
                nap = nap.min(limit - elapsed);
            }
            tokio::time::sleep(nap).await;
            backoff = (backoff * 2).min(MAX_SLEEP);
        }
    }

    /// Makes a single non-blocking attempt; `Ok(None)` means the semaphore is full.
    async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
        self.try_acquire_internal().await
    }
}
/// Proof of holding one ticket in a Redis-backed semaphore.
///
/// A background task owned by this handle periodically refreshes the ticket. Releasing or
/// dropping the handle aborts that task.
pub struct RedisSemaphoreHandle {
    /// Sorted-set key that holds the ticket.
    key: String,
    /// Member id of this holder's ticket.
    lock_id: String,
    /// Client used to remove the ticket on release.
    client: RedisClient,
    /// Becomes `true` once the background task detects that the ticket is no longer held.
    lost_receiver: watch::Receiver<bool>,
    /// Background extension task; aborted on release and on drop.
    _extension_task: tokio::task::JoinHandle<()>,
    // The two fields below are stored by the constructor but never read in this file,
    // which is why the `dead_code` lint is silenced for them.
    #[allow(dead_code)]
    expiry: Duration,
    #[allow(dead_code)]
    extension_cadence: Duration,
}
impl RedisSemaphoreHandle {
    /// Lua script that pushes an existing ticket's expiry forward.
    ///
    /// - `KEYS[1]`: sorted-set key.
    /// - `ARGV[1]`: ticket expiry (ms).
    /// - `ARGV[2]`: ticket id.
    /// - `ARGV[3]`: minimum TTL (ms) to keep on the whole key.
    ///
    /// `ZADD XX CH` only updates an existing member and returns the number of members changed.
    /// A result of 0 therefore means the ticket was not updated (e.g. it is no longer present).
    const EXTEND_SCRIPT: &'static str = r#"
redis.replicate_commands()
local nowResult = redis.call('time')
local nowMillis = (tonumber(nowResult[1]) * 1000.0) + (tonumber(nowResult[2]) / 1000.0)
local result = redis.call('zadd', KEYS[1], 'XX', 'CH', nowMillis + tonumber(ARGV[1]), ARGV[2])
-- Extend key TTL
local keyTtl = redis.call('pttl', KEYS[1])
if keyTtl < tonumber(ARGV[3]) then
redis.call('pexpire', KEYS[1], ARGV[3])
end
return result
"#;

    /// Lua script that removes the ticket `ARGV[1]` from `KEYS[1]`; returns the number removed.
    const RELEASE_SCRIPT: &'static str = r#"
return redis.call('zrem', KEYS[1], ARGV[1])
"#;

    /// Smallest cadence used for the extension timer. `tokio::time::interval` panics on a
    /// zero period, which would silently kill the extension task.
    const MIN_EXTENSION_CADENCE: Duration = Duration::from_millis(1);

    /// Upper bound (ms) applied to the ticket expiry so `expiry * 3` cannot overflow `i64`.
    const MAX_EXPIRY_MILLIS: i64 = i64::MAX / 4;

    /// Wraps a freshly acquired ticket and spawns its background extension task.
    ///
    /// Every `extension_cadence` the task runs `EXTEND_SCRIPT`. If the ticket was not updated or
    /// the call fails, the task publishes `true` on `lost_sender` and stops.
    ///
    /// Must be called from within a Tokio runtime, because it uses `tokio::spawn`.
    pub(crate) fn new(
        key: String,
        lock_id: String,
        client: RedisClient,
        expiry: Duration,
        extension_cadence: Duration,
        lost_sender: watch::Sender<bool>,
        lost_receiver: watch::Receiver<bool>,
    ) -> Self {
        let extension_key = key.clone();
        let extension_lock_id = lock_id.clone();
        let extension_client = client.clone();
        // Clamp instead of `as`-casting so a huge expiry cannot wrap to a negative value.
        let expiry_millis = i64::try_from(expiry.as_millis())
            .unwrap_or(i64::MAX)
            .min(Self::MAX_EXPIRY_MILLIS);
        // The key as a whole is kept alive for 3x the ticket expiry.
        let set_expiry_millis = expiry_millis * 3;
        let cadence = extension_cadence.max(Self::MIN_EXTENSION_CADENCE);

        let extension_task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(cadence);
            // After a slow round trip, wait a full cadence rather than firing catch-up ticks.
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            // The first tick completes immediately. The caller has just written the ticket, so
            // skip that tick and perform the first refresh one cadence from now.
            interval.tick().await;
            loop {
                interval.tick().await;
                // Every receiver is gone, so nobody can observe the lock any more.
                if lost_sender.is_closed() {
                    break;
                }
                let args: Vec<RedisValue> = vec![
                    Self::EXTEND_SCRIPT.into(),
                    1_i64.into(),
                    extension_key.clone().into(),
                    expiry_millis.into(),
                    extension_lock_id.clone().into(),
                    set_expiry_millis.into(),
                ];
                let cmd = CustomCommand::new_static("EVAL", None, false);
                let result_op: Result<i64, _> = extension_client.custom(cmd, args).await;
                match result_op {
                    Ok(changed) if changed != 0 => {}
                    // The ticket was not updated, or the call failed: report the loss and stop.
                    // A send error only means no receiver is listening any more.
                    Ok(_) | Err(_) => {
                        let _ = lost_sender.send(true);
                        break;
                    }
                }
            }
        });

        Self {
            key,
            lock_id,
            client,
            expiry,
            extension_cadence,
            lost_receiver,
            _extension_task: extension_task,
        }
    }
}
impl LockHandle for RedisSemaphoreHandle {
fn lost_token(&self) -> &watch::Receiver<bool> {
&self.lost_receiver
}
async fn release(self) -> LockResult<()> {
self._extension_task.abort();
let args: Vec<RedisValue> = vec![
Self::RELEASE_SCRIPT.into(),
1_i64.into(), self.key.clone().into(),
self.lock_id.clone().into(),
];
let cmd = CustomCommand::new_static("EVAL", None, false);
let _: i64 = self.client.custom(cmd, args).await.map_err(|e| {
LockError::Backend(Box::new(std::io::Error::other(format!(
"failed to release semaphore ticket: {}",
e
))))
})?;
Ok(())
}
}
impl Drop for RedisSemaphoreHandle {
    /// Aborts the background extension task so an abandoned handle stops refreshing its ticket.
    ///
    /// No release command is sent here. The ticket stays in the sorted set until its score falls
    /// behind the Redis clock and a later acquire prunes it.
    fn drop(&mut self) {
        let extension = &self._extension_task;
        extension.abort();
    }
}