use core::{
fmt::Debug,
pin::Pin,
sync::atomic::AtomicBool,
task::{Context, Poll},
};
use futures::task::AtomicWaker;
use std::{
future::Future,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
/// Error yielded by [`ExactWaiter`] when the shared state has been marked as canceled,
/// which happens when an [`Exact`] handle is dropped without being acknowledged.
#[derive(Debug, thiserror::Error)]
#[error("acknowledgement was cancelled")]
pub struct Canceled;
/// A cloneable handle paired with a future that reports the outcome of acknowledgement.
///
/// [`Acknowledgement::handle`] creates the pair; [`Acknowledgement::acknowledge`] consumes
/// a handle to signal it.
pub trait Acknowledgement: Clone + Send + Sync + Debug + 'static {
/// Future returned alongside the handle; resolves to `Ok(())` or `Err(Self::Error)`.
type Waiter: Future<Output = Result<(), Self::Error>> + Send + Sync + Unpin + 'static;
/// Error type carried by the waiter's `Err` result.
type Error: Debug + Send + Sync + 'static;
/// Creates a fresh handle together with the waiter that observes it.
fn handle() -> (Self, Self::Waiter);
/// Consumes this handle, signalling acknowledgement.
fn acknowledge(self);
}
/// Acknowledgement handle backed by a shared counter.
///
/// Cloning increments the counter and acknowledging decrements it. Dropping a handle that
/// was never acknowledged marks the shared state as canceled.
pub struct Exact {
/// State shared with every clone of this handle and with the [`ExactWaiter`].
state: Arc<ExactState>,
/// Set once this handle has been acknowledged, so that `Drop` does not cancel.
acknowledged: bool,
}
impl Debug for Exact {
    /// Formats the handle as `Exact { acknowledged: .. }`; the shared state is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Exact");
        builder.field("acknowledged", &self.acknowledged);
        builder.finish()
    }
}
impl Clone for Exact {
fn clone(&self) -> Self {
self.state.increment();
Self {
state: self.state.clone(),
acknowledged: false,
}
}
}
impl Drop for Exact {
    /// Cancels the shared state if this handle is dropped without having been acknowledged.
    fn drop(&mut self) {
        if !self.acknowledged {
            self.state.cancel();
            self.acknowledged = true;
        }
    }
}
impl Acknowledgement for Exact {
    type Error = Canceled;
    type Waiter = ExactWaiter;

    /// Builds a new handle and its waiter, both pointing at one freshly created state.
    fn handle() -> (Self, Self::Waiter) {
        let shared = Arc::new(ExactState::new());
        let handle = Self {
            state: Arc::clone(&shared),
            acknowledged: false,
        };
        let waiter = ExactWaiter { state: shared };
        (handle, waiter)
    }

    /// Records this handle's acknowledgement in the shared state.
    ///
    /// The flag is set only after the state has been updated; it keeps the `Drop` impl,
    /// which runs when `self` goes out of scope here, from cancelling.
    fn acknowledge(mut self) {
        self.state.acknowledge();
        self.acknowledged = true;
    }
}
/// Future returned by [`Exact::handle`].
///
/// Resolves to `Err(Canceled)` once the shared state is canceled, or to `Ok(())` once the
/// outstanding-handle counter reaches zero.
pub struct ExactWaiter {
/// State shared with the [`Exact`] handles being observed.
state: Arc<ExactState>,
}
// `Acknowledgement::Waiter` requires `Unpin`; the waiter only holds an `Arc`, and `poll`
// never relies on the waiter staying at a fixed address.
impl Unpin for ExactWaiter {}
impl Future for ExactWaiter {
    type Output = Result<(), Canceled>;

    /// Reports the current outcome of the shared state.
    ///
    /// The waker is registered before the flags are read, so a cancel or final
    /// acknowledgement that lands after the reads still wakes this task. Cancellation is
    /// checked first and therefore takes precedence over completion.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let shared = &self.state;
        shared.waker.register(cx.waker());

        if shared.canceled.load(Ordering::Acquire) {
            Poll::Ready(Err(Canceled))
        } else if shared.remaining.load(Ordering::Acquire) == 0 {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }
}
/// State shared between all [`Exact`] handles and their [`ExactWaiter`].
struct ExactState {
/// Number of handles that have not acknowledged yet; starts at 1, is incremented by
/// `increment` and decremented by `acknowledge`.
remaining: AtomicUsize,
/// Set by `cancel`; once true the waiter resolves with `Canceled`.
canceled: AtomicBool,
/// Waker registered by the waiter's `poll`, woken on cancel and on the final acknowledgement.
waker: AtomicWaker,
}
impl ExactState {
    /// Creates the state for a single outstanding handle, not canceled, with no waker yet.
    const fn new() -> Self {
        Self {
            remaining: AtomicUsize::new(1),
            canceled: AtomicBool::new(false),
            waker: AtomicWaker::new(),
        }
    }

    /// Marks one handle as acknowledged and wakes the waiter if it was the last one.
    fn acknowledge(&self) {
        let previous = self.remaining.fetch_sub(1, Ordering::AcqRel);
        // `fetch_sub` returns the value before the decrement, so 1 means the counter has
        // just reached zero.
        if previous == 1 {
            self.waker.wake();
        }
    }

    /// Accounts for one additional outstanding handle.
    fn increment(&self) {
        self.remaining.fetch_add(1, Ordering::AcqRel);
    }

    /// Flags the state as canceled and wakes the waiter so it can observe the flag.
    fn cancel(&self) {
        self.canceled.store(true, Ordering::Release);
        self.waker.wake();
    }
}
#[cfg(test)]
mod tests {
    use super::{Acknowledgement, Exact};
    use futures::{future::FusedFuture, FutureExt};
    use std::sync::atomic::Ordering;

    /// The waiter stays pending until every clone of the handle has been acknowledged.
    #[test]
    fn acknowledges_after_all_listeners() {
        let (ack1, waiter) = Exact::handle();
        let mut waiter = waiter.fuse();
        let ack2 = ack1.clone();
        ack1.acknowledge();
        // Poll the waiter for real: a fused future only reports `is_terminated` after a
        // poll has returned `Ready`, so checking it without polling would always pass.
        assert!((&mut waiter).now_or_never().is_none());
        assert!(!waiter.is_terminated());
        ack2.acknowledge();
        assert!(waiter.now_or_never().unwrap().is_ok());
    }

    /// Dropping the only handle without acknowledging cancels the waiter.
    #[test]
    fn cancels_on_drop() {
        let (ack, waiter) = Exact::handle();
        drop(ack);
        assert!(waiter.now_or_never().unwrap().is_err());
    }

    /// A clone dropped before the other handle acknowledges still cancels the waiter.
    #[test]
    fn cancels_on_drop_before_acknowledgement() {
        let (ack, waiter) = Exact::handle();
        let ack2 = ack.clone();
        drop(ack2);
        ack.acknowledge();
        assert!(waiter.now_or_never().unwrap().is_err());
    }

    /// A clone dropped after the other handle acknowledged cancels the waiter as well.
    #[test]
    fn cancels_on_drop_after_acknowledgement() {
        let (ack, waiter) = Exact::handle();
        let ack2 = ack.clone();
        ack.acknowledge();
        drop(ack2);
        assert!(waiter.now_or_never().unwrap().is_err());
    }

    /// Handles keep updating the shared state even when nobody is waiting on it.
    #[test]
    fn dropping_waiter_does_not_interfere_with_acknowledgement() {
        let (ack, waiter) = Exact::handle();
        let state = ack.state.clone();
        drop(waiter);
        let ack2 = ack.clone();
        ack.acknowledge();
        ack2.acknowledge();
        assert_eq!(state.remaining.load(Ordering::Acquire), 0);
        assert!(!state.canceled.load(Ordering::Acquire));
    }
}