commonware_utils/
acknowledgement.rs1use core::{
4 fmt::Debug,
5 pin::Pin,
6 sync::atomic::AtomicBool,
7 task::{Context, Poll},
8};
9use futures::task::AtomicWaker;
10use std::{
11 future::Future,
12 sync::{
13 atomic::{AtomicUsize, Ordering},
14 Arc,
15 },
16};
17
18#[derive(Debug, thiserror::Error)]
20#[error("acknowledgement was cancelled")]
21pub struct Canceled;
22
/// A cloneable handle whose holders report completion by calling
/// [`Acknowledgement::acknowledge`], paired with a [`Acknowledgement::Waiter`] future
/// that lets the creator await the outcome.
///
/// When the waiter resolves is up to the implementation; for example, [`Exact`] waits
/// for every clone to acknowledge and fails if any clone is dropped first.
pub trait Acknowledgement: Clone + Send + Sync + Debug + 'static {
    /// Error the waiter yields when acknowledgement fails.
    type Error: Debug + Send + Sync + 'static;

    /// Future that resolves to `Ok(())` on success or to [`Acknowledgement::Error`]
    /// on failure.
    type Waiter: Future<Output = Result<(), Self::Error>> + Send + Sync + Unpin + 'static;

    /// Creates a new handle together with the waiter that observes it.
    fn handle() -> (Self, Self::Waiter);

    /// Acknowledges through this handle, consuming it.
    fn acknowledge(self);
}
37
38pub struct Exact {
42 state: Arc<ExactState>,
43 acknowledged: bool,
44}
45
46impl Debug for Exact {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("Exact")
49 .field("acknowledged", &self.acknowledged)
50 .finish()
51 }
52}
53
54impl Clone for Exact {
55 fn clone(&self) -> Self {
56 self.state.increment();
60
61 Self {
65 state: self.state.clone(),
66 acknowledged: false,
67 }
68 }
69}
70
71impl Drop for Exact {
72 fn drop(&mut self) {
73 if self.acknowledged {
74 return;
75 }
76
77 self.state.cancel();
79 self.acknowledged = true;
80 }
81}
82
83impl Acknowledgement for Exact {
84 type Error = Canceled;
85 type Waiter = ExactWaiter;
86
87 fn handle() -> (Self, Self::Waiter) {
88 let state = Arc::new(ExactState::new());
90 (
91 Self {
92 state: state.clone(),
93 acknowledged: false,
94 },
95 ExactWaiter { state },
96 )
97 }
98
99 fn acknowledge(mut self) {
100 self.state.acknowledge();
101 self.acknowledged = true;
102 }
103}
104
105pub struct ExactWaiter {
107 state: Arc<ExactState>,
108}
109
110impl Unpin for ExactWaiter {}
111
112impl Future for ExactWaiter {
113 type Output = Result<(), Canceled>;
114
115 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
116 self.state.waker.register(cx.waker());
117
118 if self.state.canceled.load(Ordering::Acquire) {
119 return Poll::Ready(Err(Canceled));
120 }
121
122 if self.state.remaining.load(Ordering::Acquire) == 0 {
123 return Poll::Ready(Ok(()));
124 }
125
126 Poll::Pending
127 }
128}
129
130struct ExactState {
132 remaining: AtomicUsize,
133 canceled: AtomicBool,
134 waker: AtomicWaker,
135}
136
137impl ExactState {
138 const fn new() -> Self {
140 Self {
141 remaining: AtomicUsize::new(1),
142 canceled: AtomicBool::new(false),
143 waker: AtomicWaker::new(),
144 }
145 }
146
147 fn acknowledge(&self) {
149 if self.remaining.fetch_sub(1, Ordering::AcqRel) != 1 {
151 return;
152 }
153
154 self.waker.wake();
156 }
157
158 fn increment(&self) {
160 self.remaining.fetch_add(1, Ordering::AcqRel);
161 }
162
163 fn cancel(&self) {
165 self.canceled.store(true, Ordering::Release);
166 self.waker.wake();
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::{Acknowledgement, Exact};
    use futures::FutureExt;
    use std::{
        future::Future,
        pin::Pin,
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        task::{Context, Poll, Wake, Waker},
    };

    /// Waker that records whether it has been woken.
    struct Flag(AtomicBool);

    impl Wake for Flag {
        fn wake(self: Arc<Self>) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn acknowledges_after_all_listeners() {
        let (ack1, mut waiter) = Exact::handle();
        let ack2 = ack1.clone();
        ack1.acknowledge();
        // Actually poll the waiter (by `&mut`, so it can be polled again) and check it
        // is still pending while one handle is outstanding. A never-polled `Fuse` always
        // reports `!is_terminated()`, so checking that would prove nothing.
        assert!((&mut waiter).now_or_never().is_none());
        ack2.acknowledge();
        assert!(waiter.now_or_never().unwrap().is_ok());
    }

    #[test]
    fn wakes_waiter_on_final_acknowledgement() {
        let (ack, mut waiter) = Exact::handle();
        let flag = Arc::new(Flag(AtomicBool::new(false)));
        let waker = Waker::from(flag.clone());
        let mut cx = Context::from_waker(&waker);

        assert!(Pin::new(&mut waiter).poll(&mut cx).is_pending());
        assert!(!flag.0.load(Ordering::SeqCst));

        ack.acknowledge();
        assert!(flag.0.load(Ordering::SeqCst));
        assert!(matches!(
            Pin::new(&mut waiter).poll(&mut cx),
            Poll::Ready(Ok(()))
        ));
    }

    #[test]
    fn wakes_waiter_on_cancel() {
        let (ack, mut waiter) = Exact::handle();
        let flag = Arc::new(Flag(AtomicBool::new(false)));
        let waker = Waker::from(flag.clone());
        let mut cx = Context::from_waker(&waker);

        assert!(Pin::new(&mut waiter).poll(&mut cx).is_pending());

        drop(ack);
        assert!(flag.0.load(Ordering::SeqCst));
        assert!(matches!(
            Pin::new(&mut waiter).poll(&mut cx),
            Poll::Ready(Err(_))
        ));
    }

    #[test]
    fn cancels_on_drop() {
        let (ack, waiter) = Exact::handle();
        drop(ack);
        assert!(waiter.now_or_never().unwrap().is_err());
    }

    #[test]
    fn cancels_on_drop_before_acknowledgement() {
        let (ack, waiter) = Exact::handle();
        let ack2 = ack.clone();
        drop(ack2);
        ack.acknowledge();
        assert!(waiter.now_or_never().unwrap().is_err());
    }

    #[test]
    fn cancels_on_drop_after_acknowledgement() {
        let (ack, waiter) = Exact::handle();
        let ack2 = ack.clone();
        ack.acknowledge();
        drop(ack2);
        assert!(waiter.now_or_never().unwrap().is_err());
    }

    #[test]
    fn dropping_waiter_does_not_interfere_with_acknowledgement() {
        let (ack, waiter) = Exact::handle();
        let state = ack.state.clone();
        drop(waiter);

        let ack2 = ack.clone();
        ack.acknowledge();
        ack2.acknowledge();

        assert_eq!(state.remaining.load(Ordering::Acquire), 0);
        assert!(!state.canceled.load(Ordering::Acquire));
    }
}