commonware_utils/channels/
tracked.rs1use futures::{
50 channel::mpsc::{self, Receiver as FutReceiver, SendError, Sender as FutSender, TrySendError},
51 SinkExt, Stream, StreamExt,
52};
53use std::{
54 collections::HashMap,
55 hash::Hash,
56 pin::Pin,
57 sync::{Arc, Mutex},
58 task::{Context, Poll},
59};
60
61#[derive(Clone)]
63pub struct Guard<B: Eq + Hash + Clone> {
64 sequence: u64,
65 tracker: Arc<Mutex<State<B>>>,
66
67 batch: Option<B>,
68}
69
70impl<B: Eq + Hash + Clone> Drop for Guard<B> {
71 fn drop(&mut self) {
72 let mut state = self.tracker.lock().unwrap();
74
75 *state.pending.get_mut(&self.sequence).unwrap() = true;
77
78 let mut current_watermark = state.watermark;
80 while let Some(delivered) = state.pending.get(&(current_watermark + 1)) {
81 if !*delivered {
83 break;
84 }
85
86 state.pending.remove(&(current_watermark + 1));
88 current_watermark += 1;
89 state.watermark = current_watermark;
90 }
91
92 if let Some(batch) = &self.batch {
94 let count = state.batches.get_mut(batch).unwrap();
95 if *count > 1 {
96 *count -= 1;
97 } else {
98 state.batches.remove(batch);
99 }
100 }
101 }
102}
103
104pub struct Message<T, B: Eq + Hash + Clone> {
106 pub data: T,
108 pub guard: Arc<Guard<B>>,
112}
113
/// Bookkeeping shared between a tracker and every guard it issues.
struct State<B> {
    /// Sequence number the next guard will be assigned.
    next: u64,
    /// Highest sequence number `w` such that every message `1..=w` has been
    /// delivered.
    watermark: u64,
    /// Outstanding (not yet delivered) message count per batch.
    /// A batch is removed once its count would reach zero.
    batches: HashMap<B, usize>,
    /// Delivery flag for every issued sequence number above the watermark.
    pending: HashMap<u64, bool>,
}
121
122impl<B> Default for State<B> {
123 fn default() -> Self {
124 Self {
125 next: 1,
126 watermark: 0,
127 batches: HashMap::new(),
128 pending: HashMap::new(),
129 }
130 }
131}
132
133#[derive(Clone)]
140struct Tracker<B: Eq + Hash + Clone> {
141 state: Arc<Mutex<State<B>>>,
142}
143
144impl<B: Eq + Hash + Clone> Tracker<B> {
145 fn new() -> Self {
146 Self {
147 state: Arc::new(Mutex::new(State::default())),
148 }
149 }
150
151 fn guard(&self, batch: Option<B>) -> Guard<B> {
152 let mut state = self.state.lock().unwrap();
154
155 let sequence = state.next;
157 state.next += 1;
158
159 state.pending.insert(sequence, false);
161
162 if let Some(batch) = &batch {
164 *state.batches.entry(batch.clone()).or_insert(0) += 1;
165 }
166
167 Guard {
168 sequence,
169 tracker: self.state.clone(),
170
171 batch,
172 }
173 }
174}
175
176#[derive(Clone)]
178pub struct Sender<T, B: Eq + Hash + Clone> {
179 inner: FutSender<Message<T, B>>,
180 tracker: Tracker<B>,
181}
182
183impl<T, B: Eq + Hash + Clone> Sender<T, B> {
184 pub async fn send(&mut self, batch: Option<B>, data: T) -> Result<u64, SendError> {
186 let guard = Arc::new(self.tracker.guard(batch));
188 let watermark = guard.sequence;
189
190 let msg = Message { data, guard };
192 self.inner.send(msg).await?;
193
194 Ok(watermark)
195 }
196
197 pub fn try_send(
199 &mut self,
200 batch: Option<B>,
201 data: T,
202 ) -> Result<u64, TrySendError<Message<T, B>>> {
203 let guard = Arc::new(self.tracker.guard(batch));
205 let watermark = guard.sequence;
206
207 let msg = Message { data, guard };
209 self.inner.try_send(msg)?;
210
211 Ok(watermark)
212 }
213
214 pub fn watermark(&self) -> u64 {
216 self.tracker.state.lock().unwrap().watermark
217 }
218
219 pub fn pending(&self, batch: B) -> usize {
221 self.tracker
222 .state
223 .lock()
224 .unwrap()
225 .batches
226 .get(&batch)
227 .copied()
228 .unwrap_or(0)
229 }
230}
231
232pub struct Receiver<T, B: Eq + Hash + Clone> {
234 inner: FutReceiver<Message<T, B>>,
235}
236
237impl<T, B: Eq + Hash + Clone> Receiver<T, B> {
238 pub async fn recv(&mut self) -> Option<Message<T, B>> {
240 self.inner.next().await
241 }
242
243 pub fn try_recv(&mut self) -> Option<Message<T, B>> {
245 self.inner.try_next().ok().flatten()
246 }
247}
248
249impl<T, B: Eq + Hash + Clone> Stream for Receiver<T, B> {
250 type Item = Message<T, B>;
251
252 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
253 Pin::new(&mut self.inner).poll_next(cx)
254 }
255}
256
257pub fn bounded<T, B: Eq + Hash + Clone>(buffer: usize) -> (Sender<T, B>, Receiver<T, B>) {
259 let (tx, rx) = mpsc::channel(buffer);
260 let sender = Sender {
261 inner: tx,
262 tracker: Tracker::new(),
263 };
264 let receiver = Receiver { inner: rx };
265 (sender, receiver)
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    #[test]
    fn test_basic() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<i32, u64>(10);

            // The first message gets sequence number 1; nothing is delivered yet.
            let sequence = sender.send(None, 42).await.unwrap();
            assert_eq!(sequence, 1);
            assert_eq!(sender.watermark(), 0);

            // Receiving alone does not count as delivery...
            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, 42);
            assert_eq!(sender.watermark(), 0);

            // ...dropping the guard does.
            drop(msg.guard);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_batch_tracking() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<String, u64>(10);

            // Two messages in batch 100, one in batch 200.
            let sequence1 = sender.send(Some(100), "msg1".to_string()).await.unwrap();
            let sequence2 = sender.send(Some(100), "msg2".to_string()).await.unwrap();
            let sequence3 = sender.send(Some(200), "msg3".to_string()).await.unwrap();

            assert_eq!(sequence1, 1);
            assert_eq!(sequence2, 2);
            assert_eq!(sequence3, 3);
            assert_eq!(sender.pending(100), 2);
            assert_eq!(sender.pending(200), 1);
            assert_eq!(sender.pending(300), 0);

            // Delivering one message only decrements its own batch.
            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, "msg1");
            drop(msg1.guard);

            assert_eq!(sender.pending(100), 1);
            assert_eq!(sender.pending(200), 1);

            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();
            drop(msg2.guard);
            drop(msg3.guard);

            assert_eq!(sender.pending(100), 0);
            assert_eq!(sender.pending(200), 0);
        });
    }

    #[test]
    fn test_cloned_guards() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<&str, u64>(10);

            let sequence = sender.send(Some(1), "test").await.unwrap();
            assert_eq!(sequence, 1);

            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, "test");

            // Cloning the Arc keeps the message outstanding until the last clone drops.
            let msg_guard_clone1 = msg.guard.clone();
            let msg_guard_clone2 = msg.guard.clone();

            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);

            drop(msg.guard);
            drop(msg_guard_clone1);
            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);

            drop(msg_guard_clone2);
            assert_eq!(sender.pending(1), 0);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_out_of_order_delivery() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<i32, u64>(10);

            for value in 1..=3 {
                sender.send(None, value).await.unwrap();
            }
            let msg1 = receiver.recv().await.unwrap();
            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();

            // Delivering later messages first must not move the watermark past
            // the gap left by message 1.
            drop(msg3.guard);
            drop(msg2.guard);
            assert_eq!(sender.watermark(), 0);

            // Closing the gap releases the whole contiguous run at once.
            drop(msg1.guard);
            assert_eq!(sender.watermark(), 3);
        });
    }

    #[test]
    fn test_try_send() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<i32, u64>(2);

            let sequence1 = sender.try_send(Some(10), 1).unwrap();
            let sequence2 = sender.try_send(Some(10), 2).unwrap();

            assert_eq!(sender.pending(10), 2);
            assert_eq!(sequence1, 1);
            assert_eq!(sequence2, 2);

            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, 1);
            drop(msg1.guard);

            assert_eq!(sender.pending(10), 1);

            let msg2 = receiver.recv().await.unwrap();
            drop(msg2.guard);

            assert_eq!(sender.pending(10), 0);
        });
    }

    #[test]
    fn test_channel_closure() {
        block_on(async move {
            let (mut sender, receiver) = bounded::<i32, u64>(10);

            // `send` returns the message's sequence number, not a guard.
            let sequence = sender.send(None, 1).await.unwrap();
            assert_eq!(sequence, 1);

            // Once the receiver is gone, further sends fail.
            drop(receiver);
            assert!(sender.send(None, 2).await.is_err());
        });
    }
}