commonware_utils/channels/
tracked.rs

//! A channel that tracks message delivery.
//!
//! This channel provides message delivery tracking. Each sent message carries
//! a [Guard] that tracks when the message has been fully processed. When ALL
//! references to the guard are dropped, the message is marked as delivered.
//!
//! # Features
//!
//! - **Watermarks**: Get the highest sequence number such that every message up to and including it has been delivered
//! - **Batches**: Assign messages to batches and track the pending count per batch
//! - **Shareable Guards**: Each message's guard is an `Arc<Guard>` that can be cloned and shared; delivery happens when the last clone is dropped
//!
//! # Sequence Number Overflow
//!
//! Uses [u64] for sequence numbers. At 100 messages per nanosecond, overflow occurs after ~5.85 years.
//! Systems that need higher throughput (or longer uptime) should implement periodic resets or use external sequence management.
//!
//! # Example
//!
//! ```
//! use futures::executor::block_on;
//! use commonware_utils::channels::tracked;
//! block_on(async {
//!     let (mut sender, mut receiver) = tracked::bounded::<String, u64>(10);
//!
//!     // Send a message with batch ID
//!     let sequence = sender.send(Some(1), "hello".to_string()).await.unwrap();
//!
//!     // Check pending messages
//!     assert_eq!(sender.pending(1), 1);
//!     assert_eq!(sender.watermark(), 0);
//!
//!     // Receive and process
//!     let msg = receiver.recv().await.unwrap();
//!     assert_eq!(msg.data, "hello");
//!
//!     // Clone the guard - delivery won't happen until all clones are dropped
//!     let guard_clone = msg.guard.clone();
//!     drop(msg.guard);
//!     assert_eq!(sender.watermark(), 0); // Still not delivered
//!
//!     // Drop the last guard reference to mark as delivered
//!     drop(guard_clone);
//!     assert_eq!(sender.pending(1), 0);
//!     assert_eq!(sender.watermark(), 1);
//! });
//! ```
48
49use futures::{
50    channel::mpsc::{self, Receiver as FutReceiver, SendError, Sender as FutSender, TrySendError},
51    SinkExt, Stream, StreamExt,
52};
53use std::{
54    collections::HashMap,
55    hash::Hash,
56    pin::Pin,
57    sync::{Arc, Mutex},
58    task::{Context, Poll},
59};
60
/// A guard that tracks message delivery. When dropped, the message is marked as delivered.
///
/// Receivers get the guard wrapped in an [Arc] (see [Message::guard]), so delivery is recorded
/// once the last `Arc` reference is dropped. Cloning the `Guard` value itself (rather than the
/// `Arc`) yields an independent copy for the same sequence: whichever copy is dropped first
/// records delivery, and dropping any other copy afterwards is a no-op.
#[derive(Clone)]
pub struct Guard<B: Eq + Hash + Clone> {
    /// Sequence number assigned to the tracked message.
    sequence: u64,
    /// Shared delivery state, updated when the guard is dropped.
    tracker: Arc<Mutex<State<B>>>,

    /// Batch the message belongs to, if any.
    batch: Option<B>,
}

impl<B: Eq + Hash + Clone> Drop for Guard<B> {
    /// Records delivery of this guard's sequence, advances the watermark across any contiguous
    /// run of delivered sequences, and decrements the outstanding count of the guard's batch.
    ///
    /// Dropping a copy of a guard whose sequence was already delivered changes nothing.
    ///
    /// # Panics
    ///
    /// Panics if the tracker mutex is poisoned.
    fn drop(&mut self) {
        let mut locked = self.tracker.lock().unwrap();
        // Reborrow once so disjoint fields (`pending`, `watermark`) can be borrowed together.
        let state = &mut *locked;

        // Mark the message as delivered. The entry may already be `true`, or already removed
        // because the watermark moved past it, when another copy of this guard was dropped first
        // (`Guard` is `Clone`). In that case do nothing: unwrapping a missing entry here would
        // panic while holding the lock (poisoning it), and re-running the batch bookkeeping would
        // decrement the batch count twice.
        let first_delivery = match state.pending.get_mut(&self.sequence) {
            Some(delivered) => !std::mem::replace(delivered, true),
            None => false,
        };
        if !first_delivery {
            return;
        }

        // Advance the watermark over every contiguous delivered sequence that follows it,
        // pruning those entries from `pending`.
        while state.pending.get(&(state.watermark + 1)) == Some(&true) {
            state.watermark += 1;
            state.pending.remove(&state.watermark);
        }

        // Update the batch count, removing the entry once it would reach zero. The count was
        // incremented when this sequence was allocated, so it is present on first delivery.
        if let Some(batch) = &self.batch {
            if let Some(count) = state.batches.get_mut(batch) {
                if *count > 1 {
                    *count -= 1;
                } else {
                    state.batches.remove(batch);
                }
            }
        }
    }
}

/// A message containing data and a [Guard] that tracks delivery.
pub struct Message<T, B: Eq + Hash + Clone> {
    /// The data of the message.
    pub data: T,
    /// The [Guard] that tracks delivery.
    ///
    /// When no outstanding references to the guard exist, the message is considered delivered.
    pub guard: Arc<Guard<B>>,
}

/// The state of the [Tracker].
struct State<B> {
    /// Next sequence number to hand out (sequences start at 1).
    next: u64,
    /// Highest sequence such that every sequence up to and including it has been delivered
    /// (`0` when nothing has been delivered yet).
    watermark: u64,
    /// Number of not-yet-delivered messages per batch; entries are removed when they reach zero.
    batches: HashMap<B, usize>,
    /// Sequences above the watermark that are still tracked, mapped to whether each has been
    /// delivered.
    pending: HashMap<u64, bool>,
}

impl<B> Default for State<B> {
    fn default() -> Self {
        Self {
            next: 1,
            watermark: 0,
            batches: HashMap::new(),
            pending: HashMap::new(),
        }
    }
}
132
133/// Tracks delivery state across all messages.
134///
135/// Note on sequence overflow: Using u64 for sequence numbers provides ample headroom.
136/// At 100 messages per nanosecond, it would take ~5.85 years to overflow.
137/// For systems requiring longer uptime without restart, consider implementing
138/// sequence number wrapping with careful watermark handling.
139#[derive(Clone)]
140struct Tracker<B: Eq + Hash + Clone> {
141    state: Arc<Mutex<State<B>>>,
142}
143
144impl<B: Eq + Hash + Clone> Tracker<B> {
145    fn new() -> Self {
146        Self {
147            state: Arc::new(Mutex::new(State::default())),
148        }
149    }
150
151    fn guard(&self, batch: Option<B>) -> Guard<B> {
152        // Get state
153        let mut state = self.state.lock().unwrap();
154
155        // Get the next sequence
156        let sequence = state.next;
157        state.next += 1;
158
159        // Track this sequence as not yet delivered
160        state.pending.insert(sequence, false);
161
162        // Update batch count if provided
163        if let Some(batch) = &batch {
164            *state.batches.entry(batch.clone()).or_insert(0) += 1;
165        }
166
167        Guard {
168            sequence,
169            tracker: self.state.clone(),
170
171            batch,
172        }
173    }
174}
175
176/// A sender that wraps `Sender` and tracks message delivery.
177#[derive(Clone)]
178pub struct Sender<T, B: Eq + Hash + Clone> {
179    inner: FutSender<Message<T, B>>,
180    tracker: Tracker<B>,
181}
182
183impl<T, B: Eq + Hash + Clone> Sender<T, B> {
184    /// Sends a message with an optional batch ID and returns a delivery guard.
185    pub async fn send(&mut self, batch: Option<B>, data: T) -> Result<u64, SendError> {
186        // Create the guard
187        let guard = Arc::new(self.tracker.guard(batch));
188        let watermark = guard.sequence;
189
190        // Send the message
191        let msg = Message { data, guard };
192        self.inner.send(msg).await?;
193
194        Ok(watermark)
195    }
196
197    /// Tries to send a message without blocking.
198    pub fn try_send(
199        &mut self,
200        batch: Option<B>,
201        data: T,
202    ) -> Result<u64, TrySendError<Message<T, B>>> {
203        // Create the guard
204        let guard = Arc::new(self.tracker.guard(batch));
205        let watermark = guard.sequence;
206
207        // Send the message
208        let msg = Message { data, guard };
209        self.inner.try_send(msg)?;
210
211        Ok(watermark)
212    }
213
214    /// Returns the current delivery watermark (highest sequence number where all messages up to and including it have been delivered).
215    pub fn watermark(&self) -> u64 {
216        self.tracker.state.lock().unwrap().watermark
217    }
218
219    /// Returns the number of pending messages for a specific batch.
220    pub fn pending(&self, batch: B) -> usize {
221        self.tracker
222            .state
223            .lock()
224            .unwrap()
225            .batches
226            .get(&batch)
227            .copied()
228            .unwrap_or(0)
229    }
230}
231
232/// A receiver that wraps [FutReceiver] and provides tracked messages.
233pub struct Receiver<T, B: Eq + Hash + Clone> {
234    inner: FutReceiver<Message<T, B>>,
235}
236
237impl<T, B: Eq + Hash + Clone> Receiver<T, B> {
238    /// Receives the next message.
239    pub async fn recv(&mut self) -> Option<Message<T, B>> {
240        self.inner.next().await
241    }
242
243    /// Tries to receive a message without blocking.
244    pub fn try_recv(&mut self) -> Option<Message<T, B>> {
245        self.inner.try_next().ok().flatten()
246    }
247}
248
249impl<T, B: Eq + Hash + Clone> Stream for Receiver<T, B> {
250    type Item = Message<T, B>;
251
252    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
253        Pin::new(&mut self.inner).poll_next(cx)
254    }
255}
256
257/// Create a new bounded channel with delivery tracking.
258pub fn bounded<T, B: Eq + Hash + Clone>(buffer: usize) -> (Sender<T, B>, Receiver<T, B>) {
259    let (tx, rx) = mpsc::channel(buffer);
260    let sender = Sender {
261        inner: tx,
262        tracker: Tracker::new(),
263    };
264    let receiver = Receiver { inner: rx };
265    (sender, receiver)
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    #[test]
    fn test_basic() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<i32, u64>(10);

            // Send a message without batch ID; `send` returns the assigned sequence number
            let sequence = sender.send(None, 42).await.unwrap();
            assert_eq!(sequence, 1);
            assert_eq!(sender.watermark(), 0);

            // Receive the message but don't drop the guard yet
            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, 42);
            assert_eq!(sender.watermark(), 0);

            // Drop the guard to mark as delivered
            drop(msg.guard);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_batch_tracking() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<String, u64>(10);

            // Send messages with different batch IDs
            let sequence1 = sender.send(Some(100), "msg1".to_string()).await.unwrap();
            let sequence2 = sender.send(Some(100), "msg2".to_string()).await.unwrap();
            let sequence3 = sender.send(Some(200), "msg3".to_string()).await.unwrap();

            assert_eq!(sequence1, 1);
            assert_eq!(sequence2, 2);
            assert_eq!(sequence3, 3);
            assert_eq!(sender.pending(100), 2);
            assert_eq!(sender.pending(200), 1);
            assert_eq!(sender.pending(300), 0);

            // Receive and process first message
            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, "msg1");
            drop(msg1.guard);

            assert_eq!(sender.pending(100), 1);
            assert_eq!(sender.pending(200), 1);

            // Receive and process remaining messages
            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();
            drop(msg2.guard);
            drop(msg3.guard);

            assert_eq!(sender.pending(100), 0);
            assert_eq!(sender.pending(200), 0);
        });
    }

    #[test]
    fn test_cloned_guards() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<&str, u64>(10);

            let sequence = sender.send(Some(1), "test").await.unwrap();
            assert_eq!(sequence, 1);

            // Receive the message immediately
            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, "test");

            // Clone the shared guard handle; every clone refers to the same delivery
            let msg_guard_clone1 = msg.guard.clone();
            let msg_guard_clone2 = msg.guard.clone();

            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);

            // Drop original and one clone
            drop(msg.guard);
            drop(msg_guard_clone1);
            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);

            // Drop last clone
            drop(msg_guard_clone2);
            assert_eq!(sender.pending(1), 0);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_out_of_order_delivery() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<u8, u64>(10);

            for value in 0..3 {
                sender.send(Some(5), value).await.unwrap();
            }
            let msg1 = receiver.recv().await.unwrap();
            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();

            // Delivering the last message first must not move the watermark past the gap
            drop(msg3.guard);
            assert_eq!(sender.watermark(), 0);
            assert_eq!(sender.pending(5), 2);

            drop(msg1.guard);
            assert_eq!(sender.watermark(), 1);

            // Filling the gap advances the watermark over every contiguous delivery
            drop(msg2.guard);
            assert_eq!(sender.watermark(), 3);
            assert_eq!(sender.pending(5), 0);
        });
    }

    #[test]
    fn test_try_send() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<i32, u64>(2);

            // Try send should work when buffer has space
            let sequence1 = sender.try_send(Some(10), 1).unwrap();
            let sequence2 = sender.try_send(Some(10), 2).unwrap();

            assert_eq!(sender.pending(10), 2);
            assert_eq!(sequence1, 1);
            assert_eq!(sequence2, 2);

            // Receive messages
            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, 1);
            drop(msg1.guard);

            assert_eq!(sender.pending(10), 1);

            let msg2 = receiver.recv().await.unwrap();
            drop(msg2.guard);

            assert_eq!(sender.pending(10), 0);
        });
    }

    #[test]
    fn test_try_recv() {
        block_on(async move {
            let (mut sender, mut receiver) = bounded::<i32, u64>(10);

            // Nothing queued yet
            assert!(receiver.try_recv().is_none());

            sender.send(None, 7).await.unwrap();
            let msg = receiver.try_recv().unwrap();
            assert_eq!(msg.data, 7);
        });
    }

    #[test]
    fn test_channel_closure() {
        block_on(async move {
            let (mut sender, receiver) = bounded::<i32, u64>(10);

            let sequence = sender.send(None, 1).await.unwrap();
            assert_eq!(sequence, 1);

            // Drop receiver
            drop(receiver);

            // Next send should fail
            assert!(sender.send(None, 2).await.is_err());
        });
    }
}