commonware_utils/
acknowledgement.rs

1//! Utilities for providing acknowledgement.
2
3use core::{
4    fmt::Debug,
5    pin::Pin,
6    sync::atomic::AtomicBool,
7    task::{Context, Poll},
8};
9use futures::task::AtomicWaker;
10use std::{
11    future::Future,
12    sync::{
13        atomic::{AtomicUsize, Ordering},
14        Arc,
15    },
16};
17
18/// Acknowledgement cancellation error.
19#[derive(Debug, thiserror::Error)]
20#[error("acknowledgement was cancelled")]
21pub struct Canceled;
22
/// A handle used to signal that a task has been completed.
///
/// Each handle is created together with a waiter future (see
/// [Acknowledgement::handle]) that resolves once the acknowledgement has been handled.
pub trait Acknowledgement: Clone + Send + Sync + Debug + 'static {
    /// Error yielded by the waiter when the acknowledgement is not handled.
    type Error: Debug + Send + Sync + 'static;

    /// Future that resolves once the acknowledgement has been handled.
    type Waiter: Future<Output = Result<(), Self::Error>> + Send + Sync + Unpin + 'static;

    /// Creates a fresh acknowledgement handle together with its paired waiter.
    fn handle() -> (Self, Self::Waiter);

    /// Marks this handle as fulfilled, consuming it.
    fn acknowledge(self);
}
37
38/// [Acknowledgement] that returns after all instances are acknowledged.
39///
40/// If any acknowledgement is not handled, the acknowledgement will be cancelled.
41pub struct Exact {
42    state: Arc<ExactState>,
43    acknowledged: bool,
44}
45
46impl Debug for Exact {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("Exact")
49            .field("acknowledged", &self.acknowledged)
50            .finish()
51    }
52}
53
54impl Clone for Exact {
55    fn clone(&self) -> Self {
56        // Because acknowledge consumes self, we know that there is no way for there
57        // to remain 0 references before the last acknowledgement has been cloned (i.e.
58        // the acknowledgement won't resolve while we are still creating new clones).
59        self.state.increment();
60
61        // Create a new acknowledgement with acknowledged set to false (the acknowledgement
62        // we are cloning from will also be false because it hasn't been consumed but we do it
63        // manually to be explicit).
64        Self {
65            state: self.state.clone(),
66            acknowledged: false,
67        }
68    }
69}
70
71impl Drop for Exact {
72    fn drop(&mut self) {
73        if self.acknowledged {
74            return;
75        }
76
77        // If not yet acknowledged, cancel the acknowledgement.
78        self.state.cancel();
79        self.acknowledged = true;
80    }
81}
82
83impl Acknowledgement for Exact {
84    type Error = Canceled;
85    type Waiter = ExactWaiter;
86
87    fn handle() -> (Self, Self::Waiter) {
88        // When created, ExactState has a remaining count of 1 already.
89        let state = Arc::new(ExactState::new());
90        (
91            Self {
92                state: state.clone(),
93                acknowledged: false,
94            },
95            ExactWaiter { state },
96        )
97    }
98
99    fn acknowledge(mut self) {
100        self.state.acknowledge();
101        self.acknowledged = true;
102    }
103}
104
105/// Future that waits for an [Exact] acknowledgement to complete or be canceled.
106pub struct ExactWaiter {
107    state: Arc<ExactState>,
108}
109
110impl Unpin for ExactWaiter {}
111
112impl Future for ExactWaiter {
113    type Output = Result<(), Canceled>;
114
115    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
116        self.state.waker.register(cx.waker());
117
118        if self.state.canceled.load(Ordering::Acquire) {
119            return Poll::Ready(Err(Canceled));
120        }
121
122        if self.state.remaining.load(Ordering::Acquire) == 0 {
123            return Poll::Ready(Ok(()));
124        }
125
126        Poll::Pending
127    }
128}
129
130/// State for the [Exact] acknowledgement.
131struct ExactState {
132    remaining: AtomicUsize,
133    canceled: AtomicBool,
134    waker: AtomicWaker,
135}
136
137impl ExactState {
138    /// Create a new acknowledgement state with a remaining count of 1.
139    const fn new() -> Self {
140        Self {
141            remaining: AtomicUsize::new(1),
142            canceled: AtomicBool::new(false),
143            waker: AtomicWaker::new(),
144        }
145    }
146
147    /// Acknowledge the completion of a task.
148    fn acknowledge(&self) {
149        // Decrement the remaining count and check if it was the last acknowledgement.
150        if self.remaining.fetch_sub(1, Ordering::AcqRel) != 1 {
151            return;
152        }
153
154        // On last acknowledgement, wake the waiter.
155        self.waker.wake();
156    }
157
158    /// Increment the remaining count.
159    fn increment(&self) {
160        self.remaining.fetch_add(1, Ordering::AcqRel);
161    }
162
163    /// Cancel the acknowledgement.
164    fn cancel(&self) {
165        self.canceled.store(true, Ordering::Release);
166        self.waker.wake();
167    }
168}
169
#[cfg(test)]
mod tests {
    // Unit tests for the `Exact` acknowledgement and its waiter.
    use super::{Acknowledgement, Exact};
    use futures::FutureExt;
    use std::sync::atomic::Ordering;

    #[test]
    fn acknowledges_after_all_listeners() {
        let (ack1, mut waiter) = Exact::handle();
        let ack2 = ack1.clone();
        ack1.acknowledge();
        // Poll through `&mut` (the waiter is `Unpin`) so it stays usable. It must
        // still be pending because `ack2` has not been handled yet.
        assert!((&mut waiter).now_or_never().is_none());
        ack2.acknowledge();
        assert!(waiter.now_or_never().unwrap().is_ok());
    }

    #[test]
    fn cancels_on_drop() {
        let (ack, waiter) = Exact::handle();
        drop(ack);
        assert!(waiter.now_or_never().unwrap().is_err());
    }

    #[test]
    fn cancels_on_drop_before_acknowledgement() {
        let (ack, waiter) = Exact::handle();
        let ack2 = ack.clone();
        drop(ack2);
        ack.acknowledge();
        assert!(waiter.now_or_never().unwrap().is_err());
    }

    #[test]
    fn cancels_on_drop_after_acknowledgement() {
        let (ack, mut waiter) = Exact::handle();
        let ack2 = ack.clone();
        ack.acknowledge();
        // One handle is still outstanding, so the waiter has not resolved yet.
        assert!((&mut waiter).now_or_never().is_none());
        drop(ack2);
        assert!(waiter.now_or_never().unwrap().is_err());
    }

    #[test]
    fn dropping_waiter_does_not_interfere_with_acknowledgement() {
        let (ack, waiter) = Exact::handle();
        let state = ack.state.clone();
        drop(waiter);

        let ack2 = ack.clone();
        ack.acknowledge();
        ack2.acknowledge();

        assert_eq!(state.remaining.load(Ordering::Acquire), 0);
        assert!(!state.canceled.load(Ordering::Acquire));
    }
}