1use std::time::{Duration, SystemTime, UNIX_EPOCH};
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::timeout::TimeoutValue;
7use distributed_lock_core::traits::{DistributedSemaphore, LockHandle};
8use fred::prelude::*;
9use rand::Rng;
10use tokio::sync::watch;
11
12pub struct RedisDistributedSemaphore {
17 key: String,
19 name: String,
21 max_count: u32,
23 client: RedisClient,
25 expiry: Duration,
27 extension_cadence: Duration,
29}
30
31impl RedisDistributedSemaphore {
32 pub(crate) fn new(
33 name: String,
34 max_count: u32,
35 client: RedisClient,
36 expiry: Duration,
37 extension_cadence: Duration,
38 ) -> Self {
39 let key = format!("distributed-lock:semaphore:{}", name);
41 Self {
42 key,
43 name,
44 max_count,
45 client,
46 expiry,
47 extension_cadence,
48 }
49 }
50
51 fn generate_lock_id() -> String {
53 let mut rng = rand::thread_rng();
54 format!("{:016x}", rng.gen::<u64>())
55 }
56
57 fn now_millis() -> u64 {
59 SystemTime::now()
60 .duration_since(UNIX_EPOCH)
61 .unwrap()
62 .as_millis() as u64
63 }
64
65 async fn try_acquire_internal(&self) -> LockResult<Option<RedisSemaphoreHandle>> {
67 let lock_id = Self::generate_lock_id();
68 let now_millis = Self::now_millis();
69 let expiry_millis = self.expiry.as_millis() as u64;
70 let expiry_time = now_millis + expiry_millis;
71
72 let _: u32 = self
77 .client
78 .zremrangebyscore(&self.key, 0.0, now_millis as f64)
79 .await
80 .map_err(|e| {
81 LockError::Backend(Box::new(std::io::Error::other(format!(
82 "Redis error: {}",
83 e
84 ))))
85 })?;
86
87 let count: u32 = self.client.zcard(&self.key).await.map_err(|e| {
89 LockError::Backend(Box::new(std::io::Error::other(format!(
90 "Redis error: {}",
91 e
92 ))))
93 })?;
94
95 if count >= self.max_count {
96 return Ok(None);
97 }
98
99 let _: () = self
101 .client
102 .zadd(
103 &self.key,
104 None,
105 None,
106 false,
107 false,
108 (expiry_time as f64, lock_id.clone()),
109 )
110 .await
111 .map_err(|e| {
112 LockError::Backend(Box::new(std::io::Error::other(format!(
113 "Redis error: {}",
114 e
115 ))))
116 })?;
117
118 let set_expiry = expiry_millis * 2;
120 let _: bool = self
121 .client
122 .pexpire(&self.key, set_expiry as i64, None)
123 .await
124 .map_err(|e| {
125 LockError::Backend(Box::new(std::io::Error::other(format!(
126 "Redis error: {}",
127 e
128 ))))
129 })?;
130
131 let (sender, receiver) = watch::channel(false);
133 Ok(Some(RedisSemaphoreHandle::new(
134 self.key.clone(),
135 lock_id,
136 self.client.clone(),
137 self.expiry,
138 self.extension_cadence,
139 sender,
140 receiver,
141 )))
142 }
143}
144
145impl DistributedSemaphore for RedisDistributedSemaphore {
146 type Handle = RedisSemaphoreHandle;
147
148 fn name(&self) -> &str {
149 &self.name
150 }
151
152 fn max_count(&self) -> u32 {
153 self.max_count
154 }
155
156 async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
157 let timeout_value = TimeoutValue::from(timeout);
158 let start = std::time::Instant::now();
159
160 let mut sleep_duration = Duration::from_millis(10);
162 const MAX_SLEEP: Duration = Duration::from_millis(200);
163
164 loop {
165 match self.try_acquire_internal().await {
166 Ok(Some(handle)) => return Ok(handle),
167 Ok(None) => {
168 if !timeout_value.is_infinite()
170 && start.elapsed() >= timeout_value.as_duration().unwrap()
171 {
172 return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
173 }
174
175 tokio::time::sleep(sleep_duration).await;
177 sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
178 }
179 Err(e) => return Err(e),
180 }
181 }
182 }
183
184 async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
185 self.try_acquire_internal().await
186 }
187}
188
189pub struct RedisSemaphoreHandle {
191 key: String,
193 lock_id: String,
195 client: RedisClient,
197 #[allow(dead_code)]
199 expiry: Duration,
200 #[allow(dead_code)]
202 extension_cadence: Duration,
203 lost_receiver: watch::Receiver<bool>,
205 _extension_task: tokio::task::JoinHandle<()>,
207}
208
209impl RedisSemaphoreHandle {
210 pub(crate) fn new(
211 key: String,
212 lock_id: String,
213 client: RedisClient,
214 expiry: Duration,
215 extension_cadence: Duration,
216 lost_sender: watch::Sender<bool>,
217 lost_receiver: watch::Receiver<bool>,
218 ) -> Self {
219 let extension_key = key.clone();
220 let extension_lock_id = lock_id.clone();
221 let extension_client = client.clone();
222 let extension_expiry = expiry;
223 let extension_lost_sender = lost_sender.clone();
224
225 let extension_task = tokio::spawn(async move {
227 let mut interval = tokio::time::interval(extension_cadence);
228 loop {
229 interval.tick().await;
230
231 if extension_lost_sender.is_closed() {
233 break;
234 }
235
236 let now_millis = RedisDistributedSemaphore::now_millis();
238 let expiry_millis = extension_expiry.as_millis() as u64;
239 let expiry_time = now_millis + expiry_millis;
240
241 let _: u32 = match extension_client
243 .zremrangebyscore(&extension_key, 0.0, now_millis as f64)
244 .await
245 {
246 Ok(count) => count,
247 Err(_) => {
248 let _ = extension_lost_sender.send(true);
250 break;
251 }
252 };
253
254 let result: u32 = match extension_client
257 .zadd(
258 &extension_key,
259 None,
260 None,
261 false,
262 false,
263 (expiry_time as f64, extension_lock_id.clone()),
264 )
265 .await
266 {
267 Ok(count) => count,
268 Err(_) => {
269 let _ = extension_lost_sender.send(true);
271 break;
272 }
273 };
274
275 let set_expiry = expiry_millis * 2;
277 let _: bool = match extension_client
278 .pexpire(&extension_key, set_expiry as i64, None)
279 .await
280 {
281 Ok(result) => result,
282 Err(_) => {
283 let _ = extension_lost_sender.send(true);
285 break;
286 }
287 };
288
289 if result == 0 {
292 let _ = extension_lost_sender.send(true);
293 break;
294 }
295 }
296 });
297
298 Self {
299 key,
300 lock_id,
301 client,
302 expiry,
303 extension_cadence,
304 lost_receiver,
305 _extension_task: extension_task,
306 }
307 }
308}
309
310impl LockHandle for RedisSemaphoreHandle {
311 fn lost_token(&self) -> &watch::Receiver<bool> {
312 &self.lost_receiver
313 }
314
315 async fn release(self) -> LockResult<()> {
316 self._extension_task.abort();
318
319 let _: () = self
321 .client
322 .zrem(&self.key, &self.lock_id)
323 .await
324 .map_err(|e| {
325 LockError::Backend(Box::new(std::io::Error::other(format!(
326 "failed to release semaphore ticket: {}",
327 e
328 ))))
329 })?;
330
331 Ok(())
332 }
333}
334
335impl Drop for RedisSemaphoreHandle {
336 fn drop(&mut self) {
337 self._extension_task.abort();
339 }
342}