rocketmq_rust/
count_down_latch.rs1use std::sync::Arc;
18use std::time::Duration;
19
20use tokio::sync::Mutex;
21use tokio::sync::Notify;
22
23#[derive(Clone)]
26pub struct CountDownLatch {
27 count: Arc<Mutex<u32>>,
29 notify: Arc<Notify>,
31}
32
33impl CountDownLatch {
34 #[inline]
36 pub fn new(count: u32) -> Self {
37 CountDownLatch {
38 count: Arc::new(Mutex::new(count)),
39 notify: Arc::new(Notify::new()),
40 }
41 }
42
43 #[inline]
44 pub async fn count_down(&self) {
45 let mut count = self.count.lock().await;
46 *count -= 1;
47 if *count == 0 {
48 self.notify.notify_waiters();
49 }
50 }
51
52 #[inline]
53 pub async fn wait(&self) {
54 let count = self.count.lock().await;
55 if *count > 0 {
56 drop(count);
57 self.notify.notified().await;
58 }
59 }
60
61 #[inline]
62 pub async fn wait_timeout(&self, timeout: Duration) -> bool {
63 tokio::time::timeout(timeout, self.wait()).await.is_ok()
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on how long a test waits for a spawned waiter, so that a
    /// lost wake-up makes the test fail instead of hanging the suite.
    const JOIN_TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn count_down_latch_initial_count() {
        let latch = CountDownLatch::new(3);
        assert_eq!(*latch.count.lock().await, 3);
    }

    #[tokio::test]
    async fn wait_returns_immediately_when_count_is_zero() {
        let latch = CountDownLatch::new(0);
        assert!(latch.wait_timeout(Duration::from_secs(1)).await);
    }

    #[tokio::test]
    async fn wait_timeout_reaches_zero_before_timeout() {
        let latch = CountDownLatch::new(1);
        latch.count_down().await;
        assert!(latch.wait_timeout(Duration::from_secs(1)).await);
    }

    #[tokio::test]
    async fn wait_timeout_exceeds_timeout() {
        let latch = CountDownLatch::new(1);
        assert!(!latch.wait_timeout(Duration::from_millis(10)).await);
    }

    #[tokio::test]
    async fn count_down_latch_count_down() {
        let latch = CountDownLatch::new(3);
        latch.count_down().await;
        assert_eq!(*latch.count.lock().await, 2);
    }

    #[tokio::test]
    async fn count_down_latch_multiple_waiters() {
        let latch = CountDownLatch::new(2);

        let waiter1 = tokio::spawn({
            let latch = latch.clone();
            async move { latch.wait().await }
        });
        let waiter2 = tokio::spawn({
            let latch = latch.clone();
            async move { latch.wait().await }
        });

        // On the default current-thread test runtime, spawned tasks only run
        // once this task yields. Without the yield both `count_down` calls would
        // complete first, the waiters would see a zero count, and the wake-up
        // path would never be exercised.
        tokio::task::yield_now().await;

        latch.count_down().await;
        latch.count_down().await;

        tokio::time::timeout(JOIN_TIMEOUT, waiter1)
            .await
            .expect("waiter1 was not woken")
            .unwrap();
        tokio::time::timeout(JOIN_TIMEOUT, waiter2)
            .await
            .expect("waiter2 was not woken")
            .unwrap();

        assert_eq!(*latch.count.lock().await, 0);
    }
}