use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use crate::internal::CountdownState;
#[cfg(test)]
mod tests;
/// A latch built on a countdown counter.
///
/// The counter starts at the value passed to [`Latch::new`]. It is lowered with
/// [`Latch::count_down`] (by one) or [`Latch::arrive`] (by `n`). Tasks wait on the
/// latch with [`Latch::wait`] (borrowing) or [`Latch::wait_owned`] (via `Arc`), or
/// check it without waiting using [`Latch::try_wait`].
///
/// NOTE(review): release/wake semantics live in `CountdownState` (not visible here).
/// Presumably waiters complete once the count has reached zero; confirm against
/// `crate::internal`.
#[derive(Debug)]
pub struct Latch {
    // Shared counter plus waker registry; all operations delegate to it.
    state: CountdownState,
}
impl Latch {
pub fn new(count: u32) -> Self {
Self {
state: CountdownState::new(count),
}
}
pub fn count(&self) -> u32 {
self.state.state()
}
pub fn count_down(&self) {
if self.state.decrement(1) {
self.state.wake_all();
}
}
pub fn arrive(&self, n: u32) {
if n != 0 && self.state.decrement(n) {
self.state.wake_all();
}
}
pub fn try_wait(&self) -> Result<(), u32> {
self.state.spin_wait(0)
}
pub async fn wait(&self) {
let fut = LatchWait {
idx: None,
latch: self,
};
fut.await
}
pub async fn wait_owned(self: Arc<Self>) {
let fut = OwnedLatchWait {
idx: None,
latch: self,
};
fut.await
}
}
impl Latch {
    /// Poll logic shared by [`LatchWait`] and [`OwnedLatchWait`].
    ///
    /// `idx` is the waiter's slot handle. It is handed to
    /// `CountdownState::register_waker`, which receives it mutably and presumably
    /// fills in or updates the slot.
    fn intern_poll(&self, idx: &mut Option<usize>, cx: &mut Context<'_>) -> Poll<()> {
        // Fast path: spin briefly (the argument 16 appears to be a spin budget)
        // before touching the waker registry at all.
        if self.state.spin_wait(16).is_ok() {
            return Poll::Ready(());
        }

        self.state.register_waker(idx, cx);

        // Check again after registering. A release that happened between the first
        // check and registration may already have run `wake_all`. Returning Pending
        // here could then never be woken (assumes `wake_all` only reaches wakers
        // registered at the time it runs).
        if self.state.spin_wait(0).is_ok() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}
/// Future that resolves once the borrowed [`Latch`] is released.
///
/// Built by [`Latch::wait`]. Each poll delegates to the latch's shared poll logic.
///
/// NOTE(review): this file has no `Drop` impl for this type. A waker registered
/// while the future was pending is therefore not explicitly deregistered if the
/// future is dropped early. Confirm that `CountdownState` tolerates stale entries.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct LatchWait<'a> {
    // Waker slot handle; starts as `None` and is passed to `register_waker`.
    idx: Option<usize>,
    // The latch being waited on.
    latch: &'a Latch,
}
impl fmt::Debug for LatchWait<'_> {
    // Prints `LatchWait { .. }`; the slot handle and latch reference are
    // intentionally hidden.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("LatchWait");
        builder.finish_non_exhaustive()
    }
}
impl Future for LatchWait<'_> {
    type Output = ();

    // The struct only holds an `Option<usize>` and a shared reference, so it is
    // `Unpin`. Unwrapping the `Pin` is therefore safe.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        this.latch.intern_poll(&mut this.idx, cx)
    }
}
/// Future that resolves once the [`Latch`] it co-owns through an `Arc` is released.
///
/// Built by [`Latch::wait_owned`]. Because it holds an `Arc` rather than a borrow,
/// it carries no lifetime tied to the caller.
///
/// NOTE(review): this file has no `Drop` impl for this type. A waker registered
/// while the future was pending is therefore not explicitly deregistered if the
/// future is dropped early. Confirm that `CountdownState` tolerates stale entries.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct OwnedLatchWait {
    // Waker slot handle; starts as `None` and is passed to `register_waker`.
    idx: Option<usize>,
    // Shared ownership of the latch being waited on.
    latch: Arc<Latch>,
}
impl fmt::Debug for OwnedLatchWait {
    // Prints `OwnedLatchWait { .. }`; the slot handle and latch are
    // intentionally hidden.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("OwnedLatchWait");
        builder.finish_non_exhaustive()
    }
}
impl Future for OwnedLatchWait {
    type Output = ();

    // `Option<usize>` and `Arc<Latch>` are both `Unpin`, so the `Pin` can be
    // unwrapped. The latch (through the `Arc`) and the slot handle are disjoint
    // fields, which lets them be borrowed at the same time.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        this.latch.intern_poll(&mut this.idx, cx)
    }
}