use std::future::Future;
use std::ops::{Add, AddAssign, Sub, SubAssign};
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
/// A shared counter that can be awaited until it reaches a target value.
///
/// The current value and the waker slot both live behind `Arc`, so every clone
/// of a `Counter` observes and mutates the same count; `target` is copied into
/// each clone. There is a single waker slot, so only the most recently
/// registered waker is retained.
#[derive(Debug, Clone)]
pub struct Counter {
    /// Current count, shared between all clones.
    value: Arc<AtomicUsize>,
    /// Threshold the count is compared against when the counter is polled.
    target: usize,
    /// Waker registered by the latest pending poll, shared between all clones.
    waker: Arc<Mutex<Option<Waker>>>,
}

impl Counter {
    /// Panic message used when the waker mutex is poisoned.
    const MUST_LOCK: &'static str = "Counter inner mutex must lock";

    /// Creates a counter that starts at `from` and aims for `target`.
    pub fn new(from: usize, target: usize) -> Self {
        Self {
            value: Arc::new(AtomicUsize::new(from)),
            target,
            waker: Arc::new(Mutex::new(None)),
        }
    }

    /// Creates a counter that starts at zero and aims for `target`.
    pub fn to(target: usize) -> Self {
        Self::new(0, target)
    }

    /// Returns the current count.
    pub fn value(&self) -> usize {
        self.value.load(Ordering::Relaxed)
    }

    /// Returns the target this handle was created with.
    pub fn target(&self) -> usize {
        self.target
    }

    /// Adds `rhs` to the count, saturating at `usize::MAX`, then wakes the
    /// registered waker (if any).
    ///
    /// Saturation keeps an overflow from wrapping the count back below the
    /// target.
    fn inc(&self, rhs: usize) {
        // The closure always returns `Some`, so the update cannot fail.
        let _ = self
            .value
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(rhs))
            });
        self.notify();
    }

    /// Subtracts `rhs` from the count, saturating at zero, then wakes the
    /// registered waker (if any).
    ///
    /// Saturation keeps an underflow from wrapping to a huge value that would
    /// spuriously satisfy the target.
    fn dec(&self, rhs: usize) {
        // The closure always returns `Some`, so the update cannot fail.
        let _ = self
            .value
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_sub(rhs))
            });
        self.notify();
    }

    /// Overwrites the count with `rhs` and wakes the registered waker (if any).
    ///
    /// # Panics
    ///
    /// Panics if the internal waker mutex is poisoned.
    pub fn set(&self, rhs: usize) {
        self.value.store(rhs, Ordering::Relaxed);
        self.notify();
    }

    /// Takes the registered waker out of the slot and wakes it.
    ///
    /// The waker is extracted in its own statement so the mutex guard is
    /// dropped before `wake` runs; a waker that touches the slot again must
    /// not find the lock still held by this thread.
    fn notify(&self) {
        let pending = self.waker.lock().expect(Self::MUST_LOCK).take();
        if let Some(waker) = pending {
            waker.wake();
        }
    }
}
/// Awaiting a `Counter` resolves with the current count once it is at least
/// `target`.
///
/// All clones share one waker slot, so a pending poll replaces whatever waker
/// was stored before it.
impl Future for Counter {
    type Output = usize;

    /// Returns `Poll::Ready(value)` when the count has reached the target,
    /// otherwise registers the task's waker and returns `Poll::Pending`.
    ///
    /// # Panics
    ///
    /// Panics if the internal waker mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Fast path: target already reached, no need to take the lock.
        let value = self.value.load(Ordering::Relaxed);
        if value >= self.target {
            return Poll::Ready(value);
        }

        // Re-check while holding the waker lock. Mutators update the value
        // first and take the waker afterwards, so either they acquire the lock
        // after us and find the waker stored below, or they released it before
        // us and this load observes their update. Checking only before
        // registering would let an update slip in between and lose the wakeup.
        let mut slot = self.waker.lock().expect(Self::MUST_LOCK);
        let value = self.value.load(Ordering::Relaxed);
        if value >= self.target {
            return Poll::Ready(value);
        }
        *slot = Some(cx.waker().clone());
        Poll::Pending
    }
}
/// `counter += n` raises the shared count by `n` through the private `inc`.
impl AddAssign<usize> for Counter {
    fn add_assign(&mut self, rhs: usize) {
        Self::inc(self, rhs);
    }
}
/// `counter -= n` lowers the shared count by `n` through the private `dec`.
impl SubAssign<usize> for Counter {
    fn sub_assign(&mut self, rhs: usize) {
        Self::dec(self, rhs);
    }
}
/// `counter + n` raises the shared count by `n` and hands the same handle back.
///
/// The count lives behind an `Arc`, so this mutates the value seen by every
/// clone rather than producing an independent counter.
impl Add<usize> for Counter {
    type Output = Self;

    fn add(self, rhs: usize) -> Self::Output {
        self.inc(rhs);
        self
    }
}
/// `counter - n` lowers the shared count by `n` and hands the same handle back.
///
/// The count lives behind an `Arc`, so this mutates the value seen by every
/// clone rather than producing an independent counter.
impl Sub<usize> for Counter {
    type Output = Self;

    fn sub(self, rhs: usize) -> Self::Output {
        self.dec(rhs);
        self
    }
}
#[cfg(test)]
mod tests {
    use crate::Counter;
    use log::debug;
    use std::time::Duration;
    use tokio::time;

    /// A background task bumps the counter by 5 per tick; awaiting the counter
    /// must yield exactly the target before the timeout expires.
    #[tokio::test]
    async fn counter_counts_up() {
        let _ = pretty_env_logger::try_init();
        let tick = Duration::from_millis(10);
        let target = 10;
        let counter = Counter::to(target);
        let mut handle = counter.clone();

        tokio::spawn(async move {
            for i in 0u8..20 {
                time::sleep(tick).await;
                debug!("Tick {i}");
                // Deliberately uses the by-value `Add` impl.
                handle = handle + 5;
            }
        });

        let outcome = time::timeout(tick * 20, counter).await;
        assert!(matches!(outcome, Ok(reached) if reached == target));
        debug!("Counter target is reached!");
    }

    /// The background task goes 5, 10, 15, then 9, then 12; the first await
    /// must observe the target and the second the final value.
    #[tokio::test]
    async fn counter_counts_up_and_down() {
        let _ = pretty_env_logger::try_init();
        let tick = Duration::from_millis(10);
        let target = 10;
        let counter = Counter::to(target);
        let mut handle = counter.clone();

        tokio::spawn(async move {
            for i in 0u8..3 {
                time::sleep(tick).await;
                debug!("Tick {i}");
                handle += 5;
            }
            handle -= 6;
            handle += 3;
        });

        let first = time::timeout(tick * 20, counter.clone()).await;
        assert!(matches!(first, Ok(reached) if reached == target));

        // Give the background task time to finish its remaining updates.
        time::sleep(tick * 2).await;

        let r = time::timeout(tick * 20, counter).await;
        debug!("{r:?}");
        assert!(matches!(r, Ok(reached) if reached == 12));
        debug!("Counter target is reached!");
    }
}