1use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::timeout::TimeoutValue;
7use distributed_lock_core::traits::{DistributedSemaphore, LockHandle};
8use fred::prelude::*;
9use fred::types::CustomCommand;
10use rand::Rng;
11use tokio::sync::watch;
12
13pub struct RedisDistributedSemaphore {
18 key: String,
20 name: String,
22 max_count: u32,
24 client: RedisClient,
26 expiry: Duration,
28 extension_cadence: Duration,
30}
31
32impl RedisDistributedSemaphore {
33 const ACQUIRE_SCRIPT: &'static str = r#"
40 redis.replicate_commands()
41 local nowResult = redis.call('time')
42 local nowMillis = (tonumber(nowResult[1]) * 1000.0) + (tonumber(nowResult[2]) / 1000.0)
43
44 redis.call('zremrangebyscore', KEYS[1], '-inf', nowMillis)
45
46 if redis.call('zcard', KEYS[1]) < tonumber(ARGV[1]) then
47 redis.call('zadd', KEYS[1], nowMillis + tonumber(ARGV[2]), ARGV[3])
48
49 -- Extend key TTL (set to 3x ticket expiry to be safe)
50 local keyTtl = redis.call('pttl', KEYS[1])
51 if keyTtl < tonumber(ARGV[4]) then
52 redis.call('pexpire', KEYS[1], ARGV[4])
53 end
54 return 1
55 end
56 return 0
57 "#;
58
59 pub(crate) fn new(
60 name: String,
61 max_count: u32,
62 client: RedisClient,
63 expiry: Duration,
64 extension_cadence: Duration,
65 ) -> Self {
66 let key = format!("distributed-lock:semaphore:{}", name);
68 Self {
69 key,
70 name,
71 max_count,
72 client,
73 expiry,
74 extension_cadence,
75 }
76 }
77
78 fn generate_lock_id() -> String {
80 let mut rng = rand::thread_rng();
81 format!("{:016x}", rng.r#gen::<u64>())
82 }
83
84 async fn try_acquire_internal(&self) -> LockResult<Option<RedisSemaphoreHandle>> {
86 let lock_id = Self::generate_lock_id();
87 let expiry_millis = self.expiry.as_millis() as u64;
88
89 let set_expiry_millis = expiry_millis * 3;
91
92 let args: Vec<RedisValue> = vec![
93 Self::ACQUIRE_SCRIPT.into(),
94 1_i64.into(), self.key.clone().into(), (self.max_count as i64).into(), (expiry_millis as i64).into(), lock_id.clone().into(), (set_expiry_millis as i64).into(), ];
101
102 let cmd = CustomCommand::new_static("EVAL", None, false);
103 let result: i64 = self.client.custom(cmd, args).await.map_err(|e| {
104 LockError::Backend(Box::new(std::io::Error::other(format!(
105 "Redis custom EVAL (acquire semaphore) failed: {}",
106 e
107 ))))
108 })?;
109
110 if result == 1 {
111 let (sender, receiver) = watch::channel(false);
113 Ok(Some(RedisSemaphoreHandle::new(
114 self.key.clone(),
115 lock_id,
116 self.client.clone(),
117 self.expiry,
118 self.extension_cadence,
119 sender,
120 receiver,
121 )))
122 } else {
123 Ok(None)
124 }
125 }
126}
127
128impl DistributedSemaphore for RedisDistributedSemaphore {
129 type Handle = RedisSemaphoreHandle;
130
131 fn name(&self) -> &str {
132 &self.name
133 }
134
135 fn max_count(&self) -> u32 {
136 self.max_count
137 }
138
139 async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
140 let timeout_value = TimeoutValue::from(timeout);
141 let start = std::time::Instant::now();
142
143 let mut sleep_duration = Duration::from_millis(10);
146 const MAX_SLEEP: Duration = Duration::from_millis(200);
147
148 loop {
149 match self.try_acquire_internal().await {
150 Ok(Some(handle)) => return Ok(handle),
151 Ok(None) => {
152 if !timeout_value.is_infinite()
154 && start.elapsed() >= timeout_value.as_duration().unwrap()
155 {
156 return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
157 }
158
159 tokio::time::sleep(sleep_duration).await;
161 sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
162 }
163 Err(e) => return Err(e),
164 }
165 }
166 }
167
168 async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
169 self.try_acquire_internal().await
170 }
171}
172
173pub struct RedisSemaphoreHandle {
175 key: String,
177 lock_id: String,
179 client: RedisClient,
181 #[allow(dead_code)]
183 expiry: Duration,
184 #[allow(dead_code)]
186 extension_cadence: Duration,
187 lost_receiver: watch::Receiver<bool>,
189 _extension_task: tokio::task::JoinHandle<()>,
191}
192
193impl RedisSemaphoreHandle {
194 const EXTEND_SCRIPT: &'static str = r#"
199 redis.replicate_commands()
200 local nowResult = redis.call('time')
201 local nowMillis = (tonumber(nowResult[1]) * 1000.0) + (tonumber(nowResult[2]) / 1000.0)
202
203 local result = redis.call('zadd', KEYS[1], 'XX', 'CH', nowMillis + tonumber(ARGV[1]), ARGV[2])
204
205 -- Extend key TTL
206 local keyTtl = redis.call('pttl', KEYS[1])
207 if keyTtl < tonumber(ARGV[3]) then
208 redis.call('pexpire', KEYS[1], ARGV[3])
209 end
210 return result
211 "#;
212
213 const RELEASE_SCRIPT: &'static str = r#"
215 return redis.call('zrem', KEYS[1], ARGV[1])
216 "#;
217
218 pub(crate) fn new(
219 key: String,
220 lock_id: String,
221 client: RedisClient,
222 expiry: Duration,
223 extension_cadence: Duration,
224 lost_sender: watch::Sender<bool>,
225 lost_receiver: watch::Receiver<bool>,
226 ) -> Self {
227 let extension_key = key.clone();
228 let extension_lock_id = lock_id.clone();
229 let extension_client = client.clone();
230 let extension_expiry = expiry;
231 let extension_lost_sender = lost_sender.clone();
232
233 let extension_task = tokio::spawn(async move {
235 let mut interval = tokio::time::interval(extension_cadence);
236 let set_expiry_millis = extension_expiry.as_millis() * 3;
238
239 loop {
240 interval.tick().await;
241
242 if extension_lost_sender.is_closed() {
244 break;
245 }
246
247 let expiry_millis = extension_expiry.as_millis() as u64;
248
249 let args: Vec<RedisValue> = vec![
250 Self::EXTEND_SCRIPT.into(),
251 1_i64.into(), extension_key.clone().into(),
253 (expiry_millis as i64).into(),
254 extension_lock_id.clone().into(),
255 (set_expiry_millis as i64).into(),
256 ];
257
258 let cmd = CustomCommand::new_static("EVAL", None, false);
259 let result_op: Result<i64, _> = extension_client.custom(cmd, args).await;
260
261 match result_op {
262 Ok(changed_count) => {
263 if changed_count == 0 {
266 let _ = extension_lost_sender.send(true);
267 break;
268 }
269 }
270 Err(_) => {
271 let _ = extension_lost_sender.send(true);
273 break;
274 }
275 }
276 }
277 });
278
279 Self {
280 key,
281 lock_id,
282 client,
283 expiry,
284 extension_cadence,
285 lost_receiver,
286 _extension_task: extension_task,
287 }
288 }
289}
290
291impl LockHandle for RedisSemaphoreHandle {
292 fn lost_token(&self) -> &watch::Receiver<bool> {
293 &self.lost_receiver
294 }
295
296 async fn release(self) -> LockResult<()> {
297 self._extension_task.abort();
299
300 let args: Vec<RedisValue> = vec![
301 Self::RELEASE_SCRIPT.into(),
302 1_i64.into(), self.key.clone().into(),
304 self.lock_id.clone().into(),
305 ];
306
307 let cmd = CustomCommand::new_static("EVAL", None, false);
308
309 let _: i64 = self.client.custom(cmd, args).await.map_err(|e| {
311 LockError::Backend(Box::new(std::io::Error::other(format!(
312 "failed to release semaphore ticket: {}",
313 e
314 ))))
315 })?;
316
317 Ok(())
318 }
319}
320
321impl Drop for RedisSemaphoreHandle {
322 fn drop(&mut self) {
323 self._extension_task.abort();
325 }
328}