rocketmq_rust/
count_down_latch.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17use std::sync::Arc;
18use std::time::Duration;
19
20use tokio::sync::Mutex;
21use tokio::sync::Notify;
22
23/// A synchronization aid that allows one or more tasks to wait until a set of operations being
24/// performed in other tasks completes.
25#[derive(Clone)]
26pub struct CountDownLatch {
27    /// The current count of the latch.
28    count: Arc<Mutex<u32>>,
29    /// A notification mechanism to wake up waiting tasks.
30    notify: Arc<Notify>,
31}
32
33impl CountDownLatch {
34    /// A new `CountDownLatch`.
35    #[inline]
36    pub fn new(count: u32) -> Self {
37        CountDownLatch {
38            count: Arc::new(Mutex::new(count)),
39            notify: Arc::new(Notify::new()),
40        }
41    }
42
43    #[inline]
44    pub async fn count_down(&self) {
45        let mut count = self.count.lock().await;
46        *count -= 1;
47        if *count == 0 {
48            self.notify.notify_waiters();
49        }
50    }
51
52    #[inline]
53    pub async fn wait(&self) {
54        let count = self.count.lock().await;
55        if *count > 0 {
56            drop(count);
57            self.notify.notified().await;
58        }
59    }
60
61    #[inline]
62    pub async fn wait_timeout(&self, timeout: Duration) -> bool {
63        tokio::time::timeout(timeout, self.wait()).await.is_ok()
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed latch reports the count it was created with.
    #[tokio::test]
    async fn count_down_latch_initial_count() {
        let latch = CountDownLatch::new(3);
        let current = *latch.count.lock().await;
        assert_eq!(current, 3);
    }

    /// Once the count is zero, a timed wait succeeds.
    #[tokio::test]
    async fn wait_timeout_reaches_zero_before_timeout() {
        let latch = CountDownLatch::new(1);
        latch.count_down().await;
        assert!(latch.wait_timeout(Duration::from_secs(1)).await);
    }

    /// A timed wait on a latch that never opens reports failure.
    #[tokio::test]
    async fn wait_timeout_exceeds_timeout() {
        let latch = CountDownLatch::new(1);
        let reached_zero = latch.wait_timeout(Duration::from_millis(10)).await;
        assert!(!reached_zero);
    }

    /// Counting down through a clone is visible through the original handle.
    #[tokio::test]
    async fn count_down_latch_count_down() {
        let latch = CountDownLatch::new(3);
        let shared = latch.clone();
        shared.count_down().await;
        assert_eq!(*latch.count.lock().await, 2);
    }

    /// Several waiters are all released once the count reaches zero.
    #[tokio::test]
    async fn count_down_latch_multiple_waiters() {
        let latch = CountDownLatch::new(2);

        let waiters: Vec<_> = (0..2)
            .map(|_| {
                let handle = latch.clone();
                tokio::spawn(async move {
                    handle.wait().await;
                })
            })
            .collect();

        for _ in 0..2 {
            latch.clone().count_down().await;
        }

        for waiter in waiters {
            waiter.await.unwrap();
        }

        assert_eq!(*latch.count.lock().await, 0);
    }
}