commonware_utils/channel/
tracked.rs1use super::mpsc::{
50 self,
51 error::{SendError, TryRecvError, TrySendError},
52};
53use crate::sync::Mutex;
54use futures::Stream;
55use std::{
56 collections::HashMap,
57 hash::Hash,
58 pin::Pin,
59 sync::Arc,
60 task::{Context, Poll},
61};
62
63#[derive(Clone)]
65pub struct Guard<B: Eq + Hash + Clone> {
66 sequence: u64,
67 tracker: Arc<Mutex<State<B>>>,
68
69 batch: Option<B>,
70}
71
72impl<B: Eq + Hash + Clone> Drop for Guard<B> {
73 fn drop(&mut self) {
74 let mut state = self.tracker.lock();
76
77 *state.pending.get_mut(&self.sequence).unwrap() = true;
79
80 let mut current_watermark = state.watermark;
82 while let Some(delivered) = state.pending.get(&(current_watermark + 1)) {
83 if !*delivered {
85 break;
86 }
87
88 state.pending.remove(&(current_watermark + 1));
90 current_watermark += 1;
91 state.watermark = current_watermark;
92 }
93
94 if let Some(batch) = &self.batch {
96 let count = state.batches.get_mut(batch).unwrap();
97 if *count > 1 {
98 *count -= 1;
99 } else {
100 state.batches.remove(batch);
101 }
102 }
103 }
104}
105
106pub struct Message<T, B: Eq + Hash + Clone> {
108 pub data: T,
110 pub guard: Arc<Guard<B>>,
114}
115
/// Bookkeeping shared between a sender and all guards it has issued.
struct State<B> {
    /// Sequence number that the next guard created will receive.
    next: u64,
    /// Highest sequence `w` such that every sequence in `1..=w` has been delivered.
    watermark: u64,
    /// Number of undropped guards per batch key; keys with no such guards are absent.
    batches: HashMap<B, usize>,
    /// Delivery flag for each issued sequence that the watermark has not yet passed.
    pending: HashMap<u64, bool>,
}

impl<B> Default for State<B> {
    /// Empty state: nothing issued or delivered yet, with numbering starting at 1.
    fn default() -> Self {
        let batches = HashMap::default();
        let pending = HashMap::default();
        Self {
            watermark: 0,
            next: 1,
            pending,
            batches,
        }
    }
}
134
135#[derive(Clone)]
142struct Tracker<B: Eq + Hash + Clone> {
143 state: Arc<Mutex<State<B>>>,
144}
145
146impl<B: Eq + Hash + Clone> Tracker<B> {
147 fn new() -> Self {
148 Self {
149 state: Arc::new(Mutex::new(State::default())),
150 }
151 }
152
153 fn guard(&self, batch: Option<B>) -> Guard<B> {
154 let mut state = self.state.lock();
156
157 let sequence = state.next;
159 state.next += 1;
160
161 state.pending.insert(sequence, false);
163
164 if let Some(batch) = &batch {
166 *state.batches.entry(batch.clone()).or_insert(0) += 1;
167 }
168
169 Guard {
170 sequence,
171 tracker: self.state.clone(),
172
173 batch,
174 }
175 }
176}
177
178#[derive(Clone)]
180pub struct Sender<T, B: Eq + Hash + Clone> {
181 inner: mpsc::Sender<Message<T, B>>,
182 tracker: Tracker<B>,
183}
184
185impl<T, B: Eq + Hash + Clone> Sender<T, B> {
186 pub async fn send(&self, batch: Option<B>, data: T) -> Result<u64, SendError<Message<T, B>>> {
188 let guard = Arc::new(self.tracker.guard(batch));
190 let watermark = guard.sequence;
191
192 let msg = Message { data, guard };
194 self.inner.send(msg).await?;
195
196 Ok(watermark)
197 }
198
199 pub fn try_send(&self, batch: Option<B>, data: T) -> Result<u64, TrySendError<Message<T, B>>> {
201 let guard = Arc::new(self.tracker.guard(batch));
203 let watermark = guard.sequence;
204
205 let msg = Message { data, guard };
207 self.inner.try_send(msg)?;
208
209 Ok(watermark)
210 }
211
212 pub fn watermark(&self) -> u64 {
214 self.tracker.state.lock().watermark
215 }
216
217 pub fn pending(&self, batch: B) -> usize {
219 self.tracker
220 .state
221 .lock()
222 .batches
223 .get(&batch)
224 .copied()
225 .unwrap_or(0)
226 }
227}
228
229pub struct Receiver<T, B: Eq + Hash + Clone> {
231 inner: mpsc::Receiver<Message<T, B>>,
232}
233
234impl<T, B: Eq + Hash + Clone> Receiver<T, B> {
235 pub async fn recv(&mut self) -> Option<Message<T, B>> {
237 self.inner.recv().await
238 }
239
240 pub fn try_recv(&mut self) -> Result<Message<T, B>, TryRecvError> {
242 self.inner.try_recv()
243 }
244}
245
246impl<T, B: Eq + Hash + Clone> Stream for Receiver<T, B> {
247 type Item = Message<T, B>;
248
249 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
250 self.inner.poll_recv(cx)
251 }
252}
253
254pub fn bounded<T, B: Eq + Hash + Clone>(buffer: usize) -> (Sender<T, B>, Receiver<T, B>) {
256 let (tx, rx) = mpsc::channel(buffer);
257 let sender = Sender {
258 inner: tx,
259 tracker: Tracker::new(),
260 };
261 let receiver = Receiver { inner: rx };
262 (sender, receiver)
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    #[test]
    fn test_basic() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(10);

            // The first message receives sequence 1; nothing is delivered yet.
            let watermark = sender.send(None, 42).await.unwrap();
            assert_eq!(watermark, 1);
            assert_eq!(sender.watermark(), 0);

            // Receiving alone does not count as delivery.
            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, 42);
            assert_eq!(sender.watermark(), 0);

            // Dropping the guard does.
            drop(msg.guard);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_batch_tracking() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<String, u64>(10);

            let watermark1 = sender.send(Some(100), "msg1".into()).await.unwrap();
            let watermark2 = sender.send(Some(100), "msg2".into()).await.unwrap();
            let watermark3 = sender.send(Some(200), "msg3".into()).await.unwrap();

            assert_eq!(watermark1, 1);
            assert_eq!(watermark2, 2);
            assert_eq!(watermark3, 3);
            assert_eq!(sender.pending(100), 2);
            assert_eq!(sender.pending(200), 1);
            assert_eq!(sender.pending(300), 0);

            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, "msg1");
            drop(msg1.guard);

            assert_eq!(sender.pending(100), 1);
            assert_eq!(sender.pending(200), 1);

            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();
            drop(msg2.guard);
            drop(msg3.guard);

            assert_eq!(sender.pending(100), 0);
            assert_eq!(sender.pending(200), 0);
        });
    }

    #[test]
    fn test_cloned_guards() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<&str, u64>(10);

            let watermark = sender.send(Some(1), "test").await.unwrap();
            assert_eq!(watermark, 1);

            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, "test");

            // Clone the shared `Arc`, not the guard itself.
            let msg_guard_clone1 = msg.guard.clone();
            let msg_guard_clone2 = msg.guard.clone();

            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);

            // Delivery is only recorded once the last reference is gone.
            drop(msg.guard);
            drop(msg_guard_clone1);
            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);

            drop(msg_guard_clone2);
            assert_eq!(sender.pending(1), 0);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_try_send() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(2);

            let watermark1 = sender.try_send(Some(10), 1).unwrap();
            let watermark2 = sender.try_send(Some(10), 2).unwrap();

            assert_eq!(sender.pending(10), 2);
            assert_eq!(watermark1, 1);
            assert_eq!(watermark2, 2);

            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, 1);
            drop(msg1.guard);

            assert_eq!(sender.pending(10), 1);

            let msg2 = receiver.recv().await.unwrap();
            drop(msg2.guard);

            assert_eq!(sender.pending(10), 0);
        });
    }

    #[test]
    fn test_try_send_full_releases_guard() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(1);

            assert_eq!(sender.try_send(Some(7), 1).unwrap(), 1);

            // The channel is full. The rejected message (sequence 2) comes back in the error,
            // and its guard is released as soon as that error is dropped.
            assert!(sender.try_send(Some(7), 2).is_err());
            assert_eq!(sender.pending(7), 1);
            assert_eq!(sender.watermark(), 0);

            // Delivering sequence 1 lets the watermark pass the already-released sequence 2.
            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, 1);
            drop(msg.guard);
            assert_eq!(sender.watermark(), 2);
            assert_eq!(sender.pending(7), 0);
        });
    }

    #[test]
    fn test_out_of_order_delivery() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(10);
            for value in 1..=3 {
                sender.send(None, value).await.unwrap();
            }

            let msg1 = receiver.recv().await.unwrap();
            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();

            drop(msg1.guard);
            assert_eq!(sender.watermark(), 1);

            // A gap at sequence 2 holds the watermark back even though 3 is delivered.
            drop(msg3.guard);
            assert_eq!(sender.watermark(), 1);

            // Filling the gap advances the watermark over the whole contiguous run.
            drop(msg2.guard);
            assert_eq!(sender.watermark(), 3);
        });
    }

    #[test]
    fn test_channel_closure() {
        block_on(async move {
            let (sender, receiver) = bounded::<i32, u64>(10);

            let _sequence = sender.send(None, 1).await.unwrap();

            // Once the receiver is gone, further sends must fail.
            drop(receiver);
            assert!(sender.send(None, 2).await.is_err());
        });
    }
}