use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::sync::Notify;
/// An asynchronous count-down latch: tasks calling [`CountDownLatch::wait`]
/// are released once [`CountDownLatch::count_down`] has been called as many
/// times as the initial count.
///
/// Cloning is cheap (two `Arc` bumps) and every clone shares the same counter
/// and notifier, so one latch can be handed to several tasks.
#[derive(Clone, Debug)]
pub struct CountDownLatch {
    /// Remaining number of `count_down` calls before the latch opens.
    count: Arc<Mutex<u32>>,
    /// Wakes tasks waiting in `wait` once `count` reaches zero.
    notify: Arc<Notify>,
}
impl CountDownLatch {
    /// Creates a latch that opens once [`count_down`](Self::count_down) has
    /// been called `count` times.
    ///
    /// A latch created with `count == 0` is already open, so
    /// [`wait`](Self::wait) returns immediately.
    #[inline]
    pub fn new(count: u32) -> Self {
        CountDownLatch {
            count: Arc::new(Mutex::new(count)),
            notify: Arc::new(Notify::new()),
        }
    }

    /// Decrements the count and wakes every pending waiter when it reaches zero.
    ///
    /// Calling this on a latch whose count is already zero has no effect; the
    /// count never underflows.
    #[inline]
    pub async fn count_down(&self) {
        let mut count = self.count.lock().await;
        // Guard against underflow. Decrementing past zero would panic in debug
        // builds and wrap to `u32::MAX` in release builds, which would re-close
        // an already-open latch so that later `wait` calls never return.
        if *count == 0 {
            return;
        }
        *count -= 1;
        if *count == 0 {
            self.notify.notify_waiters();
        }
    }

    /// Waits until the count reaches zero.
    ///
    /// Returns immediately if the latch is already open.
    #[inline]
    pub async fn wait(&self) {
        // Register interest *before* inspecting the count. A `Notified` future
        // receives `notify_waiters()` wake-ups from the moment it is created,
        // even before it is polled. So a `count_down` that reaches zero after
        // the check below, but before we start awaiting, cannot be missed.
        let notified = self.notify.notified();
        let count = self.count.lock().await;
        if *count == 0 {
            return;
        }
        // Release the lock so `count_down` can make progress while we wait.
        drop(count);
        notified.await;
    }

    /// Waits until the count reaches zero or `timeout` elapses.
    ///
    /// Returns `true` if the latch opened within `timeout`, and `false` if the
    /// timeout expired first.
    #[inline]
    pub async fn wait_timeout(&self, timeout: Duration) -> bool {
        tokio::time::timeout(timeout, self.wait()).await.is_ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot of the latch's counter, taken under its lock.
    async fn current_count(latch: &CountDownLatch) -> u32 {
        *latch.count.lock().await
    }

    #[tokio::test]
    async fn count_down_latch_initial_count() {
        let latch = CountDownLatch::new(3);
        assert_eq!(current_count(&latch).await, 3);
    }

    #[tokio::test]
    async fn wait_timeout_reaches_zero_before_timeout() {
        let latch = CountDownLatch::new(1);
        latch.count_down().await;
        let opened = latch.wait_timeout(Duration::from_secs(1)).await;
        assert!(opened);
    }

    #[tokio::test]
    async fn wait_timeout_exceeds_timeout() {
        let latch = CountDownLatch::new(1);
        let opened = latch.wait_timeout(Duration::from_millis(10)).await;
        assert!(!opened);
    }

    #[tokio::test]
    async fn count_down_latch_count_down() {
        let latch = CountDownLatch::new(3);
        // Counting down through a clone must be visible through the original.
        latch.clone().count_down().await;
        assert_eq!(current_count(&latch).await, 2);
    }

    #[tokio::test]
    async fn count_down_latch_multiple_waiters() {
        let latch = CountDownLatch::new(2);

        let handles: Vec<_> = (0..2)
            .map(|_| {
                let shared = latch.clone();
                tokio::spawn(async move { shared.wait().await })
            })
            .collect();

        for _ in 0..2 {
            latch.clone().count_down().await;
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(current_count(&latch).await, 0);
    }
}