1use std::sync::Arc;
4use std::time::Duration;
5
6use distributed_lock_core::error::{LockError, LockResult};
7use distributed_lock_core::timeout::TimeoutValue;
8use distributed_lock_core::traits::{DistributedReaderWriterLock, LockHandle};
9use fred::prelude::*;
10use tokio::sync::watch;
11
12use crate::redlock::{
13 acquire::acquire_redlock, extend::extend_redlock, helper::RedLockHelper,
14 release::release_redlock, timeouts::RedLockTimeouts,
15};
16
17#[derive(Debug, Clone)]
19pub(crate) struct RedisReadLockState {
20 reader_key: String,
22 writer_key: String,
24 lock_id: String,
26 timeouts: RedLockTimeouts,
28}
29
30impl RedisReadLockState {
31 fn new(reader_key: String, writer_key: String, timeouts: RedLockTimeouts) -> Self {
32 Self {
33 reader_key,
34 writer_key,
35 lock_id: RedLockHelper::create_lock_id(),
36 timeouts,
37 }
38 }
39
40 async fn try_acquire(&self, client: &RedisClient) -> LockResult<bool> {
42 let expiry_millis = self.timeouts.expiry.as_millis() as i64;
43
44 let writer_exists: bool = client.exists(&self.writer_key).await.map_err(|e| {
46 LockError::Backend(Box::new(std::io::Error::other(format!(
47 "Redis EXISTS failed: {}",
48 e
49 ))))
50 })?;
51
52 if writer_exists {
53 return Ok(false);
54 }
55
56 let _: u32 = client
58 .sadd(&self.reader_key, &self.lock_id)
59 .await
60 .map_err(|e| {
61 LockError::Backend(Box::new(std::io::Error::other(format!(
62 "Redis SADD failed: {}",
63 e
64 ))))
65 })?;
66
67 let current_ttl: i64 = client.pttl(&self.reader_key).await.map_err(|e| {
69 LockError::Backend(Box::new(std::io::Error::other(format!(
70 "Redis PTTL failed: {}",
71 e
72 ))))
73 })?;
74
75 if current_ttl < expiry_millis {
76 let _: bool = client
77 .pexpire(&self.reader_key, expiry_millis, None)
78 .await
79 .map_err(|e| {
80 LockError::Backend(Box::new(std::io::Error::other(format!(
81 "Redis PEXPIRE failed: {}",
82 e
83 ))))
84 })?;
85 }
86
87 Ok(true)
88 }
89
90 async fn try_extend(&self, client: &RedisClient) -> LockResult<bool> {
92 let expiry_millis = self.timeouts.expiry.as_millis() as i64;
93
94 let is_member: bool = client
96 .sismember(&self.reader_key, &self.lock_id)
97 .await
98 .map_err(|e| {
99 LockError::Backend(Box::new(std::io::Error::other(format!(
100 "Redis SISMEMBER failed: {}",
101 e
102 ))))
103 })?;
104
105 if !is_member {
106 return Ok(false);
107 }
108
109 let current_ttl: i64 = client.pttl(&self.reader_key).await.map_err(|e| {
111 LockError::Backend(Box::new(std::io::Error::other(format!(
112 "Redis PTTL failed: {}",
113 e
114 ))))
115 })?;
116
117 if current_ttl < expiry_millis {
118 let _: bool = client
119 .pexpire(&self.reader_key, expiry_millis, None)
120 .await
121 .map_err(|e| {
122 LockError::Backend(Box::new(std::io::Error::other(format!(
123 "Redis PEXPIRE failed: {}",
124 e
125 ))))
126 })?;
127 }
128
129 Ok(true)
130 }
131
132 async fn try_release(&self, client: &RedisClient) -> LockResult<()> {
134 let _: u32 = client
136 .srem(&self.reader_key, &self.lock_id)
137 .await
138 .map_err(|e| {
139 LockError::Backend(Box::new(std::io::Error::other(format!(
140 "Redis SREM failed: {}",
141 e
142 ))))
143 })?;
144
145 Ok(())
146 }
147}
148
149#[derive(Debug, Clone)]
151pub(crate) struct RedisWriteLockState {
152 reader_key: String,
154 writer_key: String,
156 lock_id: String,
158 waiting_lock_id: String,
160 timeouts: RedLockTimeouts,
162}
163
/// Suffix appended to a writer's lock id to build the marker that writer stores in the
/// writer key while it waits for existing readers to finish.
const WRITER_WAITING_SUFFIX: &str = "_WRITERWAITING";
165
166impl RedisWriteLockState {
167 fn new(reader_key: String, writer_key: String, timeouts: RedLockTimeouts) -> Self {
168 let lock_id = RedLockHelper::create_lock_id();
169 Self {
170 reader_key,
171 writer_key,
172 waiting_lock_id: format!("{}{}", lock_id, WRITER_WAITING_SUFFIX),
173 lock_id,
174 timeouts,
175 }
176 }
177
178 async fn try_acquire(&self, client: &RedisClient) -> LockResult<bool> {
180 let expiry_millis = self.timeouts.expiry.as_millis() as i64;
181
182 let reader_count: u32 = client.scard(&self.reader_key).await.map_err(|e| {
184 LockError::Backend(Box::new(std::io::Error::other(format!(
185 "Redis SCARD failed: {}",
186 e
187 ))))
188 })?;
189
190 let writer_value: Option<String> = client.get(&self.writer_key).await.map_err(|e| {
192 LockError::Backend(Box::new(std::io::Error::other(format!(
193 "Redis GET failed: {}",
194 e
195 ))))
196 })?;
197
198 if reader_count == 0 {
200 match writer_value {
202 Some(value) => {
203 if value.ends_with(WRITER_WAITING_SUFFIX) {
204 let _: Option<String> = client
206 .set(
207 &self.writer_key,
208 &self.lock_id,
209 Some(Expiration::PX(expiry_millis)),
210 None, false,
212 )
213 .await
214 .map_err(|e| {
215 LockError::Backend(Box::new(std::io::Error::other(format!(
216 "Redis SET failed: {}",
217 e
218 ))))
219 })?;
220 Ok(true)
221 } else {
222 Ok(false)
224 }
225 }
226 None => {
227 let result: Option<String> = client
229 .set(
230 &self.writer_key,
231 &self.lock_id,
232 Some(Expiration::PX(expiry_millis)),
233 Some(SetOptions::NX),
234 false,
235 )
236 .await
237 .map_err(|e| {
238 LockError::Backend(Box::new(std::io::Error::other(format!(
239 "Redis SET NX failed: {}",
240 e
241 ))))
242 })?;
243 Ok(result.is_some())
244 }
245 }
246 } else {
247 match writer_value {
249 Some(value) => {
250 if value == self.waiting_lock_id {
251 let _: bool = client
253 .pexpire(&self.writer_key, expiry_millis, None)
254 .await
255 .map_err(|e| {
256 LockError::Backend(Box::new(std::io::Error::other(format!(
257 "Redis PEXPIRE failed: {}",
258 e
259 ))))
260 })?;
261 Ok(false)
262 } else {
263 Ok(false)
265 }
266 }
267 None => {
268 let _: Option<String> = client
270 .set(
271 &self.writer_key,
272 &self.waiting_lock_id,
273 Some(Expiration::PX(expiry_millis)),
274 Some(SetOptions::NX),
275 false,
276 )
277 .await
278 .map_err(|e| {
279 LockError::Backend(Box::new(std::io::Error::other(format!(
280 "Redis SET NX failed: {}",
281 e
282 ))))
283 })?;
284 Ok(false)
285 }
286 }
287 }
288 }
289
290 async fn try_extend(&self, client: &RedisClient) -> LockResult<bool> {
292 let expiry_millis = self.timeouts.expiry.as_millis() as i64;
293
294 let writer_value: Option<String> = client.get(&self.writer_key).await.map_err(|e| {
296 LockError::Backend(Box::new(std::io::Error::other(format!(
297 "Redis GET failed: {}",
298 e
299 ))))
300 })?;
301
302 match writer_value {
303 Some(value) if value == self.lock_id => {
304 let _: bool = client
306 .pexpire(&self.writer_key, expiry_millis, None)
307 .await
308 .map_err(|e| {
309 LockError::Backend(Box::new(std::io::Error::other(format!(
310 "Redis PEXPIRE failed: {}",
311 e
312 ))))
313 })?;
314 Ok(true)
315 }
316 _ => Ok(false), }
318 }
319
320 async fn try_release(&self, client: &RedisClient) -> LockResult<()> {
322 let writer_value: Option<String> = client.get(&self.writer_key).await.map_err(|e| {
324 LockError::Backend(Box::new(std::io::Error::other(format!(
325 "Redis GET failed: {}",
326 e
327 ))))
328 })?;
329
330 match writer_value {
331 Some(value) if value == self.lock_id => {
332 let _: i64 = client.del(&self.writer_key).await.map_err(|e| {
334 LockError::Backend(Box::new(std::io::Error::other(format!(
335 "Redis DEL failed: {}",
336 e
337 ))))
338 })?;
339 }
340 _ => {
341 }
343 }
344
345 Ok(())
346 }
347}
348
349pub struct RedisDistributedReaderWriterLock {
353 reader_key: String,
355 writer_key: String,
357 clients: Vec<RedisClient>,
359 extension_cadence: Duration,
361 timeouts: RedLockTimeouts,
363}
364
365impl RedisDistributedReaderWriterLock {
366 pub(crate) fn new(
368 name: String,
369 clients: Vec<RedisClient>,
370 expiry: Duration,
371 min_validity: Duration,
372 extension_cadence: Duration,
373 ) -> Self {
374 let reader_key = format!("distributed-lock:{}:readers", name);
375 let writer_key = format!("distributed-lock:{}:writer", name);
376 let timeouts = RedLockTimeouts::new(expiry, min_validity);
377
378 Self {
379 reader_key,
380 writer_key,
381 clients,
382 extension_cadence,
383 timeouts,
384 }
385 }
386
387 pub fn name(&self) -> &str {
389 self.reader_key
391 .strip_prefix("distributed-lock:")
392 .and_then(|s| s.strip_suffix(":readers"))
393 .unwrap_or(&self.reader_key)
394 }
395}
396
397impl DistributedReaderWriterLock for RedisDistributedReaderWriterLock {
398 type ReadHandle = RedisReadLockHandle;
399 type WriteHandle = RedisWriteLockHandle;
400
401 fn name(&self) -> &str {
402 self.name()
403 }
404
405 async fn acquire_read(&self, timeout: Option<Duration>) -> LockResult<Self::ReadHandle> {
406 use tokio::sync::watch;
407
408 let (cancel_sender, cancel_receiver) = watch::channel(false);
410
411 if let Some(timeout_duration) = timeout {
413 let cancel_sender_clone = cancel_sender.clone();
414 tokio::spawn(async move {
415 tokio::time::sleep(timeout_duration).await;
416 let _ = cancel_sender_clone.send(true);
417 });
418 }
419
420 let state = RedisReadLockState::new(
421 self.reader_key.clone(),
422 self.writer_key.clone(),
423 self.timeouts.clone(),
424 );
425 let clients = self.clients.clone();
426 let timeouts = self.timeouts.clone();
427
428 let state_for_acquire = state.clone();
430
431 let acquire_result = acquire_redlock(
433 move |client| {
434 let state = state_for_acquire.clone();
435 let client = client.clone();
436 async move { state.try_acquire(&client).await }
437 },
438 &clients,
439 &timeouts,
440 &cancel_receiver,
441 )
442 .await?;
443
444 let acquire_result = match acquire_result {
445 Some(result) if result.is_successful(clients.len()) => result,
446 _ => {
447 return Err(LockError::Timeout(
448 timeout.unwrap_or(Duration::from_secs(0)),
449 ));
450 }
451 };
452
453 Ok(RedisReadLockHandle::new(
455 state,
456 acquire_result.acquire_results,
457 clients,
458 self.extension_cadence,
459 self.timeouts.expiry,
460 ))
461 }
462
463 async fn try_acquire_read(&self) -> LockResult<Option<Self::ReadHandle>> {
464 use tokio::sync::watch;
465
466 let (_cancel_sender, cancel_receiver) = watch::channel(false);
468
469 let state = RedisReadLockState::new(
470 self.reader_key.clone(),
471 self.writer_key.clone(),
472 self.timeouts.clone(),
473 );
474 let clients = self.clients.clone();
475 let timeouts = self.timeouts.clone();
476
477 let state_for_acquire = state.clone();
479
480 let acquire_result = acquire_redlock(
482 move |client| {
483 let state = state_for_acquire.clone();
484 let client = client.clone();
485 async move { state.try_acquire(&client).await }
486 },
487 &clients,
488 &timeouts,
489 &cancel_receiver,
490 )
491 .await?;
492
493 match acquire_result {
494 Some(result) if result.is_successful(clients.len()) => {
495 Ok(Some(RedisReadLockHandle::new(
496 state,
497 result.acquire_results,
498 clients,
499 self.extension_cadence,
500 self.timeouts.expiry,
501 )))
502 }
503 _ => Ok(None),
504 }
505 }
506
507 async fn acquire_write(&self, timeout: Option<Duration>) -> LockResult<Self::WriteHandle> {
508 use tokio::sync::watch;
509
510 let (cancel_sender, cancel_receiver) = watch::channel(false);
512
513 if let Some(timeout_duration) = timeout {
515 let cancel_sender_clone = cancel_sender.clone();
516 tokio::spawn(async move {
517 tokio::time::sleep(timeout_duration).await;
518 let _ = cancel_sender_clone.send(true);
519 });
520 }
521
522 let state = RedisWriteLockState::new(
523 self.reader_key.clone(),
524 self.writer_key.clone(),
525 self.timeouts.clone(),
526 );
527 let clients = self.clients.clone();
528 let timeouts = self.timeouts.clone();
529
530 let timeout_value = TimeoutValue::from(timeout);
532 let start = std::time::Instant::now();
533 let mut sleep_duration = Duration::from_millis(50);
534 const MAX_SLEEP: Duration = Duration::from_secs(1);
535
536 loop {
537 if !timeout_value.is_infinite()
539 && start.elapsed() >= timeout_value.as_duration().unwrap()
540 {
541 return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
542 }
543
544 if cancel_receiver.has_changed().unwrap_or(false) && *cancel_receiver.borrow() {
546 return Err(LockError::Cancelled);
547 }
548
549 let state_for_acquire = state.clone();
551
552 let acquire_result = acquire_redlock(
554 move |client| {
555 let state = state_for_acquire.clone();
556 let client = client.clone();
557 async move { state.try_acquire(&client).await }
558 },
559 &clients,
560 &timeouts,
561 &cancel_receiver,
562 )
563 .await?;
564
565 match acquire_result {
566 Some(result) if result.is_successful(clients.len()) => {
567 return Ok(RedisWriteLockHandle::new(
569 state,
570 result.acquire_results,
571 clients,
572 self.extension_cadence,
573 self.timeouts.expiry,
574 ));
575 }
576 _ => {
577 tokio::time::sleep(sleep_duration).await;
579 sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
580 }
581 }
582 }
583 }
584
585 async fn try_acquire_write(&self) -> LockResult<Option<Self::WriteHandle>> {
586 use tokio::sync::watch;
587
588 let (_cancel_sender, cancel_receiver) = watch::channel(false);
590
591 let state = RedisWriteLockState::new(
592 self.reader_key.clone(),
593 self.writer_key.clone(),
594 self.timeouts.clone(),
595 );
596 let clients = self.clients.clone();
597 let timeouts = self.timeouts.clone();
598
599 let state_for_acquire = state.clone();
601
602 let acquire_result = acquire_redlock(
604 move |client| {
605 let state = state_for_acquire.clone();
606 let client = client.clone();
607 async move { state.try_acquire(&client).await }
608 },
609 &clients,
610 &timeouts,
611 &cancel_receiver,
612 )
613 .await?;
614
615 match acquire_result {
616 Some(result) if result.is_successful(clients.len()) => {
617 Ok(Some(RedisWriteLockHandle::new(
618 state,
619 result.acquire_results,
620 clients,
621 self.extension_cadence,
622 self.timeouts.expiry,
623 )))
624 }
625 _ => Ok(None),
626 }
627 }
628}
629
630pub struct RedisReadLockHandle {
632 state: Arc<RedisReadLockState>,
634 acquire_results: Arc<Vec<bool>>,
636 clients: Arc<Vec<RedisClient>>,
638 #[allow(dead_code)]
640 extension_cadence: Duration,
641 #[allow(dead_code)]
643 expiry: Duration,
644 lost_receiver: watch::Receiver<bool>,
646 extension_task: tokio::task::JoinHandle<()>,
648}
649
650impl RedisReadLockHandle {
651 pub(crate) fn new(
652 state: RedisReadLockState,
653 acquire_results: Vec<bool>,
654 clients: Vec<RedisClient>,
655 extension_cadence: Duration,
656 expiry: Duration,
657 ) -> Self {
658 let state = Arc::new(state);
659 let acquire_results = Arc::new(acquire_results);
660 let clients = Arc::new(clients);
661 let (lost_sender, lost_receiver) = watch::channel(false);
662
663 let state_clone = state.clone();
665 let acquire_results_clone = acquire_results.clone();
666 let clients_clone = clients.clone();
667 let extension_cadence_clone = extension_cadence;
668
669 let extension_task = tokio::spawn(async move {
671 let mut interval = tokio::time::interval(extension_cadence_clone);
672 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
673
674 loop {
675 interval.tick().await;
676
677 if lost_sender.is_closed() {
679 break;
680 }
681
682 let (_cancel_sender, cancel_receiver) = watch::channel(false);
684
685 let state_for_extend = state_clone.clone();
687 match extend_redlock(
688 move |client| {
689 let state = state_for_extend.clone();
690 let client = client.clone();
691 async move { state.try_extend(&client).await }
692 },
693 &clients_clone,
694 &acquire_results_clone,
695 &state_clone.timeouts,
696 &cancel_receiver,
697 )
698 .await
699 {
700 Ok(Some(true)) => {
701 continue;
703 }
704 Ok(Some(false)) => {
705 let _ = lost_sender.send(true);
707 break;
708 }
709 Ok(None) => {
710 continue;
712 }
713 Err(_) => {
714 let _ = lost_sender.send(true);
716 break;
717 }
718 }
719 }
720 });
721
722 Self {
723 state,
724 acquire_results,
725 clients,
726 extension_cadence,
727 expiry,
728 lost_receiver,
729 extension_task,
730 }
731 }
732}
733
734impl LockHandle for RedisReadLockHandle {
735 fn lost_token(&self) -> &watch::Receiver<bool> {
736 &self.lost_receiver
737 }
738
739 async fn release(self) -> LockResult<()> {
740 self.extension_task.abort();
742 let state = self.state.clone();
746 let clients = self.clients.clone();
747 let acquire_results = self.acquire_results.clone();
748 release_redlock(
749 move |client| {
750 let state = state.clone();
751 let client = client.clone();
752 async move { state.try_release(&client).await }
753 },
754 &clients,
755 &acquire_results,
756 )
757 .await
758 }
759}
760
761impl Drop for RedisReadLockHandle {
762 fn drop(&mut self) {
763 self.extension_task.abort();
765 }
768}
769
770pub struct RedisWriteLockHandle {
772 state: Arc<RedisWriteLockState>,
774 acquire_results: Arc<Vec<bool>>,
776 clients: Arc<Vec<RedisClient>>,
778 #[allow(dead_code)]
780 extension_cadence: Duration,
781 #[allow(dead_code)]
783 expiry: Duration,
784 lost_receiver: watch::Receiver<bool>,
786 extension_task: tokio::task::JoinHandle<()>,
788}
789
790impl RedisWriteLockHandle {
791 pub(crate) fn new(
792 state: RedisWriteLockState,
793 acquire_results: Vec<bool>,
794 clients: Vec<RedisClient>,
795 extension_cadence: Duration,
796 expiry: Duration,
797 ) -> Self {
798 let state = Arc::new(state);
799 let acquire_results = Arc::new(acquire_results);
800 let clients = Arc::new(clients);
801 let (lost_sender, lost_receiver) = watch::channel(false);
802
803 let state_clone = state.clone();
805 let acquire_results_clone = acquire_results.clone();
806 let clients_clone = clients.clone();
807 let extension_cadence_clone = extension_cadence;
808
809 let extension_task = tokio::spawn(async move {
811 let mut interval = tokio::time::interval(extension_cadence_clone);
812 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
813
814 loop {
815 interval.tick().await;
816
817 if lost_sender.is_closed() {
819 break;
820 }
821
822 let (_cancel_sender, cancel_receiver) = watch::channel(false);
824
825 let state_for_extend = state_clone.clone();
827 match extend_redlock(
828 move |client| {
829 let state = state_for_extend.clone();
830 let client = client.clone();
831 async move { state.try_extend(&client).await }
832 },
833 &clients_clone,
834 &acquire_results_clone,
835 &state_clone.timeouts,
836 &cancel_receiver,
837 )
838 .await
839 {
840 Ok(Some(true)) => {
841 continue;
843 }
844 Ok(Some(false)) => {
845 let _ = lost_sender.send(true);
847 break;
848 }
849 Ok(None) => {
850 continue;
852 }
853 Err(_) => {
854 let _ = lost_sender.send(true);
856 break;
857 }
858 }
859 }
860 });
861
862 Self {
863 state,
864 acquire_results,
865 clients,
866 extension_cadence,
867 expiry,
868 lost_receiver,
869 extension_task,
870 }
871 }
872}
873
874impl LockHandle for RedisWriteLockHandle {
875 fn lost_token(&self) -> &watch::Receiver<bool> {
876 &self.lost_receiver
877 }
878
879 async fn release(self) -> LockResult<()> {
880 self.extension_task.abort();
882 let state = self.state.clone();
886 let clients = self.clients.clone();
887 let acquire_results = self.acquire_results.clone();
888 release_redlock(
889 move |client| {
890 let state = state.clone();
891 let client = client.clone();
892 async move { state.try_release(&client).await }
893 },
894 &clients,
895 &acquire_results,
896 )
897 .await
898 }
899}
900
901impl Drop for RedisWriteLockHandle {
902 fn drop(&mut self) {
903 self.extension_task.abort();
905 }
908}