commonware_utils/channel/
tracked.rs1use super::mpsc::{
50 self,
51 error::{SendError, TryRecvError, TrySendError},
52};
53use futures::Stream;
54use std::{
55 collections::HashMap,
56 hash::Hash,
57 pin::Pin,
58 sync::{Arc, Mutex},
59 task::{Context, Poll},
60};
61
62#[derive(Clone)]
64pub struct Guard<B: Eq + Hash + Clone> {
65 sequence: u64,
66 tracker: Arc<Mutex<State<B>>>,
67
68 batch: Option<B>,
69}
70
71impl<B: Eq + Hash + Clone> Drop for Guard<B> {
72 fn drop(&mut self) {
73 let mut state = self.tracker.lock().unwrap();
75
76 *state.pending.get_mut(&self.sequence).unwrap() = true;
78
79 let mut current_watermark = state.watermark;
81 while let Some(delivered) = state.pending.get(&(current_watermark + 1)) {
82 if !*delivered {
84 break;
85 }
86
87 state.pending.remove(&(current_watermark + 1));
89 current_watermark += 1;
90 state.watermark = current_watermark;
91 }
92
93 if let Some(batch) = &self.batch {
95 let count = state.batches.get_mut(batch).unwrap();
96 if *count > 1 {
97 *count -= 1;
98 } else {
99 state.batches.remove(batch);
100 }
101 }
102 }
103}
104
105pub struct Message<T, B: Eq + Hash + Clone> {
107 pub data: T,
109 pub guard: Arc<Guard<B>>,
113}
114
/// Bookkeeping shared between a tracked sender and every outstanding guard.
struct State<B> {
    /// Highest sequence number such that it and every sequence below it has been delivered.
    watermark: u64,
    /// Sequence number that will be handed to the next message.
    next: u64,
    /// Sequence numbers above the watermark, mapped to whether they have been delivered yet.
    pending: HashMap<u64, bool>,
    /// Number of undelivered messages per batch; batches with none left are absent.
    batches: HashMap<B, usize>,
}
122
123impl<B> Default for State<B> {
124 fn default() -> Self {
125 Self {
126 next: 1,
127 watermark: 0,
128 batches: HashMap::new(),
129 pending: HashMap::new(),
130 }
131 }
132}
133
134#[derive(Clone)]
141struct Tracker<B: Eq + Hash + Clone> {
142 state: Arc<Mutex<State<B>>>,
143}
144
145impl<B: Eq + Hash + Clone> Tracker<B> {
146 fn new() -> Self {
147 Self {
148 state: Arc::new(Mutex::new(State::default())),
149 }
150 }
151
152 fn guard(&self, batch: Option<B>) -> Guard<B> {
153 let mut state = self.state.lock().unwrap();
155
156 let sequence = state.next;
158 state.next += 1;
159
160 state.pending.insert(sequence, false);
162
163 if let Some(batch) = &batch {
165 *state.batches.entry(batch.clone()).or_insert(0) += 1;
166 }
167
168 Guard {
169 sequence,
170 tracker: self.state.clone(),
171
172 batch,
173 }
174 }
175}
176
177#[derive(Clone)]
179pub struct Sender<T, B: Eq + Hash + Clone> {
180 inner: mpsc::Sender<Message<T, B>>,
181 tracker: Tracker<B>,
182}
183
184impl<T, B: Eq + Hash + Clone> Sender<T, B> {
185 pub async fn send(&self, batch: Option<B>, data: T) -> Result<u64, SendError<Message<T, B>>> {
187 let guard = Arc::new(self.tracker.guard(batch));
189 let watermark = guard.sequence;
190
191 let msg = Message { data, guard };
193 self.inner.send(msg).await?;
194
195 Ok(watermark)
196 }
197
198 pub fn try_send(&self, batch: Option<B>, data: T) -> Result<u64, TrySendError<Message<T, B>>> {
200 let guard = Arc::new(self.tracker.guard(batch));
202 let watermark = guard.sequence;
203
204 let msg = Message { data, guard };
206 self.inner.try_send(msg)?;
207
208 Ok(watermark)
209 }
210
211 pub fn watermark(&self) -> u64 {
213 self.tracker.state.lock().unwrap().watermark
214 }
215
216 pub fn pending(&self, batch: B) -> usize {
218 self.tracker
219 .state
220 .lock()
221 .unwrap()
222 .batches
223 .get(&batch)
224 .copied()
225 .unwrap_or(0)
226 }
227}
228
229pub struct Receiver<T, B: Eq + Hash + Clone> {
231 inner: mpsc::Receiver<Message<T, B>>,
232}
233
234impl<T, B: Eq + Hash + Clone> Receiver<T, B> {
235 pub async fn recv(&mut self) -> Option<Message<T, B>> {
237 self.inner.recv().await
238 }
239
240 pub fn try_recv(&mut self) -> Result<Message<T, B>, TryRecvError> {
242 self.inner.try_recv()
243 }
244}
245
246impl<T, B: Eq + Hash + Clone> Stream for Receiver<T, B> {
247 type Item = Message<T, B>;
248
249 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
250 self.inner.poll_recv(cx)
251 }
252}
253
254pub fn bounded<T, B: Eq + Hash + Clone>(buffer: usize) -> (Sender<T, B>, Receiver<T, B>) {
256 let (tx, rx) = mpsc::channel(buffer);
257 let sender = Sender {
258 inner: tx,
259 tracker: Tracker::new(),
260 };
261 let receiver = Receiver { inner: rx };
262 (sender, receiver)
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    #[test]
    fn test_basic() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(10);

            let watermark = sender.send(None, 42).await.unwrap();
            assert_eq!(watermark, 1);
            assert_eq!(sender.watermark(), 0);

            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, 42);
            assert_eq!(sender.watermark(), 0);

            drop(msg.guard);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_batch_tracking() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<String, u64>(10);

            let watermark1 = sender.send(Some(100), "msg1".to_string()).await.unwrap();
            let watermark2 = sender.send(Some(100), "msg2".to_string()).await.unwrap();
            let watermark3 = sender.send(Some(200), "msg3".to_string()).await.unwrap();

            assert_eq!(watermark1, 1);
            assert_eq!(watermark2, 2);
            assert_eq!(watermark3, 3);
            assert_eq!(sender.pending(100), 2);
            assert_eq!(sender.pending(200), 1);
            assert_eq!(sender.pending(300), 0);

            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, "msg1");
            drop(msg1.guard);

            assert_eq!(sender.pending(100), 1);
            assert_eq!(sender.pending(200), 1);

            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();
            drop(msg2.guard);
            drop(msg3.guard);

            assert_eq!(sender.pending(100), 0);
            assert_eq!(sender.pending(200), 0);
        });
    }

    #[test]
    fn test_cloned_guards() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<&str, u64>(10);

            let watermark = sender.send(Some(1), "test").await.unwrap();
            assert_eq!(watermark, 1);

            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, "test");

            // These clone the `Arc`, not the guard itself.
            let msg_guard_clone1 = msg.guard.clone();
            let msg_guard_clone2 = msg.guard.clone();

            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);

            drop(msg.guard);
            drop(msg_guard_clone1);
            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);

            drop(msg_guard_clone2);
            assert_eq!(sender.pending(1), 0);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_try_send() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(2);

            let watermark1 = sender.try_send(Some(10), 1).unwrap();
            let watermark2 = sender.try_send(Some(10), 2).unwrap();

            assert_eq!(sender.pending(10), 2);
            assert_eq!(watermark1, 1);
            assert_eq!(watermark2, 2);

            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, 1);
            drop(msg1.guard);

            assert_eq!(sender.pending(10), 1);

            let msg2 = receiver.recv().await.unwrap();
            drop(msg2.guard);

            assert_eq!(sender.pending(10), 0);
        });
    }

    #[test]
    fn test_channel_closure() {
        block_on(async move {
            let (sender, receiver) = bounded::<i32, u64>(10);

            // `send` returns the sequence number assigned to the message, not a guard.
            let sequence = sender.send(None, 1).await.unwrap();
            assert_eq!(sequence, 1);

            drop(receiver);

            assert!(sender.send(None, 2).await.is_err());
        });
    }

    #[test]
    fn test_out_of_order_delivery() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(10);

            for value in 0..3 {
                sender.send(None, value).await.unwrap();
            }
            let msg1 = receiver.recv().await.unwrap();
            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();

            // Later deliveries cannot move the watermark past an undelivered earlier message.
            drop(msg3.guard);
            assert_eq!(sender.watermark(), 0);
            drop(msg2.guard);
            assert_eq!(sender.watermark(), 0);

            // Delivering the gap advances the watermark over the whole contiguous run.
            drop(msg1.guard);
            assert_eq!(sender.watermark(), 3);
        });
    }

    #[test]
    fn test_rejected_messages_release_guards() {
        block_on(async move {
            // A full channel rejects the message; dropping the error releases its guard.
            let (sender, _receiver) = bounded::<i32, u64>(1);
            assert_eq!(sender.try_send(Some(5), 1).unwrap(), 1);
            assert!(sender.try_send(Some(5), 2).is_err());
            assert_eq!(sender.pending(5), 1);
            // Sequence 2 is released, but sequence 1 is still queued, so the watermark holds.
            assert_eq!(sender.watermark(), 0);

            // A closed channel rejects the message as well, without stalling the watermark.
            let (sender, receiver) = bounded::<i32, u64>(10);
            drop(receiver);
            assert!(sender.send(Some(7), 1).await.is_err());
            assert_eq!(sender.pending(7), 0);
            assert_eq!(sender.watermark(), 1);
        });
    }
}