use std::sync::atomic::{AtomicBool, AtomicIsize, Ordering};
/// A concurrent countdown: [`Countdown::countdown`] returns `false` for the first
/// `counter` calls (as passed to [`Countdown::new`] or [`Countdown::reset`]) and
/// `true` for every call after that.
///
/// Once the budget is exhausted the result is latched in `finish`, so later calls
/// cost a single atomic load instead of a contended read-modify-write on `counter`.
#[derive(Debug)]
pub struct Countdown {
    /// Set by the first call that finds the budget exhausted; cleared by `reset`.
    finish: AtomicBool,
    /// Remaining budget; reaches zero (or one below it) once exhausted.
    counter: AtomicIsize,
}

impl Countdown {
    /// Creates a countdown that lets exactly `counter` calls through.
    ///
    /// # Panics
    ///
    /// Panics if `counter` exceeds [`isize::MAX`].
    pub fn new(counter: usize) -> Self {
        Self {
            finish: AtomicBool::new(false),
            counter: AtomicIsize::new(Self::checked_budget(counter)),
        }
    }

    /// Consumes one unit of budget.
    ///
    /// Returns `false` while budget remains and `true` once it is exhausted. Between
    /// resets, exactly `min(calls, counter)` calls return `false` no matter how many
    /// threads race, because every decrement is a single atomic RMW on `counter`.
    ///
    /// The counter is only touched with `Ordering::Relaxed`, so this call does not
    /// synchronize any other memory between callers.
    pub fn countdown(&self) -> bool {
        // Fast path: exhaustion already latched. Acquire pairs with the Release in
        // `reset`, so a caller that observes the cleared latch also observes the new
        // budget stored just before it.
        if self.finish.load(Ordering::Acquire) {
            return true;
        }
        if self.counter.fetch_sub(1, Ordering::Relaxed) > 0 {
            return false;
        }
        // Latch exhaustion so later calls stop decrementing the counter (which would
        // otherwise keep contending and could eventually wrap around).
        self.finish.store(true, Ordering::Relaxed);
        true
    }

    /// Restarts the countdown with a fresh budget of `counter`.
    ///
    /// The new budget is stored before the latch is cleared, so no caller can see
    /// an unlatched countdown still carrying the old, exhausted budget.
    ///
    /// `reset` is not atomic with respect to in-flight `countdown` calls: a call
    /// that is exhausting the *previous* budget while `reset` runs may re-latch
    /// `finish` afterwards, in which case `countdown` keeps returning `true` until
    /// the next `reset`. Reset while no `countdown` is in flight when exact
    /// behavior matters.
    ///
    /// # Panics
    ///
    /// Panics if `counter` exceeds [`isize::MAX`].
    pub fn reset(&self, counter: usize) {
        self.counter
            .store(Self::checked_budget(counter), Ordering::Relaxed);
        self.finish.store(false, Ordering::Release);
    }

    /// Converts a caller-supplied budget into the signed counter representation.
    fn checked_budget(counter: usize) -> isize {
        isize::try_from(counter).expect("`counter` must NOT exceed `isize::MAX`.")
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures_util::future::join_all;

    use super::*;

    /// Races `concurrency` `countdown` calls against a fresh `Countdown::new(counter)`
    /// and checks that exactly `counter` of them returned `false`.
    async fn case(counter: usize, concurrency: usize) {
        let cd = Countdown::new(counter);
        let res = join_all((0..concurrency).map(|_| async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            cd.countdown()
        }))
        .await;
        assert_eq!(counter, res.into_iter().filter(|b| !b).count());
    }

    #[tokio::test]
    async fn test_countdown() {
        for counter in [1, 4, 8, 16] {
            for concurrency in [16, 32, 64, 128] {
                case(counter, concurrency).await;
            }
        }
    }

    /// Deterministic coverage of the paths the concurrent test does not pin down:
    /// staying exhausted, `reset` restoring a budget, and a zero budget.
    #[test]
    fn test_countdown_sequential_and_reset() {
        let cd = Countdown::new(2);
        assert!(!cd.countdown());
        assert!(!cd.countdown());
        // Once exhausted, every further call reports `true`.
        for _ in 0..8 {
            assert!(cd.countdown());
        }

        cd.reset(1);
        assert!(!cd.countdown());
        assert!(cd.countdown());

        // A zero budget is exhausted from the very first call.
        cd.reset(0);
        assert!(cd.countdown());
    }
}