Skip to main content

commonware_utils/channel/
tracked.rs

//! A channel that tracks message delivery.
//!
//! This channel provides message delivery tracking. Each sent message includes
//! a [Guard] that tracks when the message has been fully processed. When ALL
//! references to the guard are dropped, the message is marked as delivered.
//!
//! # Features
//!
//! - **Watermarks**: Get the highest sequence number such that all messages up to and including it have been delivered
//! - **Batches**: Assign batches to messages and track pending counts per batch
//! - **Cloneable Guards**: Guards can be shared by cloning their `Arc`; delivery happens when all clones are dropped
//!
//! # Sequence Number Overflow
//!
//! Uses [u64] for sequence numbers. At 100 messages per nanosecond, overflow occurs after ~5.85 years.
//! Systems that could exhaust this space (higher throughput or longer uptime) should implement periodic resets or use external sequence management.
//!
//! # Example
//!
//! ```
//! use futures::executor::block_on;
//! use commonware_utils::channel::tracked;
//! block_on(async {
//!     let (sender, mut receiver) = tracked::bounded::<String, u64>(10);
//!
//!     // Send a message with batch ID
//!     let sequence = sender.send(Some(1), "hello".to_string()).await.unwrap();
//!
//!     // Check pending messages
//!     assert_eq!(sender.pending(1), 1);
//!     assert_eq!(sender.watermark(), 0);
//!
//!     // Receive and process
//!     let msg = receiver.recv().await.unwrap();
//!     assert_eq!(msg.data, "hello");
//!
//!     // Clone the guard - delivery won't happen until all clones are dropped
//!     let guard_clone = msg.guard.clone();
//!     drop(msg.guard);
//!     assert_eq!(sender.watermark(), 0); // Still not delivered
//!
//!     // Drop the last guard reference to mark as delivered
//!     drop(guard_clone);
//!     assert_eq!(sender.pending(1), 0);
//!     assert_eq!(sender.watermark(), 1);
//! });
//! ```
48
49use super::mpsc::{
50    self,
51    error::{SendError, TryRecvError, TrySendError},
52};
53use futures::Stream;
54use std::{
55    collections::HashMap,
56    hash::Hash,
57    pin::Pin,
58    sync::{Arc, Mutex},
59    task::{Context, Poll},
60};
61
62/// A guard that tracks message delivery. When dropped, the message is marked as delivered.
63#[derive(Clone)]
64pub struct Guard<B: Eq + Hash + Clone> {
65    sequence: u64,
66    tracker: Arc<Mutex<State<B>>>,
67
68    batch: Option<B>,
69}
70
71impl<B: Eq + Hash + Clone> Drop for Guard<B> {
72    fn drop(&mut self) {
73        // Get the state
74        let mut state = self.tracker.lock().unwrap();
75
76        // Mark the message as delivered
77        *state.pending.get_mut(&self.sequence).unwrap() = true;
78
79        // Update watermark if possible
80        let mut current_watermark = state.watermark;
81        while let Some(delivered) = state.pending.get(&(current_watermark + 1)) {
82            // If the next message is not delivered, we can stop
83            if !*delivered {
84                break;
85            }
86
87            // Remove the next message from the pending list
88            state.pending.remove(&(current_watermark + 1));
89            current_watermark += 1;
90            state.watermark = current_watermark;
91        }
92
93        // Update batch count (if necessary)
94        if let Some(batch) = &self.batch {
95            let count = state.batches.get_mut(batch).unwrap();
96            if *count > 1 {
97                *count -= 1;
98            } else {
99                state.batches.remove(batch);
100            }
101        }
102    }
103}
104
105/// A message containing data and a [Guard] that tracks delivery.
106pub struct Message<T, B: Eq + Hash + Clone> {
107    /// The data of the message.
108    pub data: T,
109    /// The [Guard] that tracks delivery.
110    ///
111    /// When no outstanding references to the guard exist, the message is considered delivered.
112    pub guard: Arc<Guard<B>>,
113}
114
/// Bookkeeping shared by the [Tracker] and every [Guard] it hands out.
struct State<B> {
    /// Sequence number the next guard will receive.
    next: u64,
    /// Highest sequence `s` such that every message `1..=s` has been delivered.
    watermark: u64,
    /// Count of undelivered messages per batch; entries are removed at zero.
    batches: HashMap<B, usize>,
    /// Sequences above the watermark that are still tracked, mapped to whether
    /// each has been delivered.
    pending: HashMap<u64, bool>,
}
122
123impl<B> Default for State<B> {
124    fn default() -> Self {
125        Self {
126            next: 1,
127            watermark: 0,
128            batches: HashMap::new(),
129            pending: HashMap::new(),
130        }
131    }
132}
133
134/// Tracks delivery state across all messages.
135///
136/// Note on sequence overflow: Using u64 for sequence numbers provides ample headroom.
137/// At 100 messages per nanosecond, it would take ~5.85 years to overflow.
138/// For systems requiring longer uptime without restart, consider implementing
139/// sequence number wrapping with careful watermark handling.
140#[derive(Clone)]
141struct Tracker<B: Eq + Hash + Clone> {
142    state: Arc<Mutex<State<B>>>,
143}
144
145impl<B: Eq + Hash + Clone> Tracker<B> {
146    fn new() -> Self {
147        Self {
148            state: Arc::new(Mutex::new(State::default())),
149        }
150    }
151
152    fn guard(&self, batch: Option<B>) -> Guard<B> {
153        // Get state
154        let mut state = self.state.lock().unwrap();
155
156        // Get the next sequence
157        let sequence = state.next;
158        state.next += 1;
159
160        // Track this sequence as not yet delivered
161        state.pending.insert(sequence, false);
162
163        // Update batch count if provided
164        if let Some(batch) = &batch {
165            *state.batches.entry(batch.clone()).or_insert(0) += 1;
166        }
167
168        Guard {
169            sequence,
170            tracker: self.state.clone(),
171
172            batch,
173        }
174    }
175}
176
177/// A sender that wraps `Sender` and tracks message delivery.
178#[derive(Clone)]
179pub struct Sender<T, B: Eq + Hash + Clone> {
180    inner: mpsc::Sender<Message<T, B>>,
181    tracker: Tracker<B>,
182}
183
184impl<T, B: Eq + Hash + Clone> Sender<T, B> {
185    /// Sends a message with an optional batch ID and returns a delivery guard.
186    pub async fn send(&self, batch: Option<B>, data: T) -> Result<u64, SendError<Message<T, B>>> {
187        // Create the guard
188        let guard = Arc::new(self.tracker.guard(batch));
189        let watermark = guard.sequence;
190
191        // Send the message
192        let msg = Message { data, guard };
193        self.inner.send(msg).await?;
194
195        Ok(watermark)
196    }
197
198    /// Tries to send a message without blocking.
199    pub fn try_send(&self, batch: Option<B>, data: T) -> Result<u64, TrySendError<Message<T, B>>> {
200        // Create the guard
201        let guard = Arc::new(self.tracker.guard(batch));
202        let watermark = guard.sequence;
203
204        // Send the message
205        let msg = Message { data, guard };
206        self.inner.try_send(msg)?;
207
208        Ok(watermark)
209    }
210
211    /// Returns the current delivery watermark (highest sequence number where all messages up to and including it have been delivered).
212    pub fn watermark(&self) -> u64 {
213        self.tracker.state.lock().unwrap().watermark
214    }
215
216    /// Returns the number of pending messages for a specific batch.
217    pub fn pending(&self, batch: B) -> usize {
218        self.tracker
219            .state
220            .lock()
221            .unwrap()
222            .batches
223            .get(&batch)
224            .copied()
225            .unwrap_or(0)
226    }
227}
228
229/// A receiver that wraps [mpsc::Receiver] and provides tracked messages.
230pub struct Receiver<T, B: Eq + Hash + Clone> {
231    inner: mpsc::Receiver<Message<T, B>>,
232}
233
234impl<T, B: Eq + Hash + Clone> Receiver<T, B> {
235    /// Receives the next message.
236    pub async fn recv(&mut self) -> Option<Message<T, B>> {
237        self.inner.recv().await
238    }
239
240    /// Tries to receive a message without blocking.
241    pub fn try_recv(&mut self) -> Result<Message<T, B>, TryRecvError> {
242        self.inner.try_recv()
243    }
244}
245
246impl<T, B: Eq + Hash + Clone> Stream for Receiver<T, B> {
247    type Item = Message<T, B>;
248
249    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
250        self.inner.poll_recv(cx)
251    }
252}
253
254/// Create a new bounded channel with delivery tracking.
255pub fn bounded<T, B: Eq + Hash + Clone>(buffer: usize) -> (Sender<T, B>, Receiver<T, B>) {
256    let (tx, rx) = mpsc::channel(buffer);
257    let sender = Sender {
258        inner: tx,
259        tracker: Tracker::new(),
260    };
261    let receiver = Receiver { inner: rx };
262    (sender, receiver)
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    #[test]
    fn test_basic() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(10);

            // Send a message without batch ID; the first sequence number is 1
            let sequence = sender.send(None, 42).await.unwrap();
            assert_eq!(sequence, 1);
            assert_eq!(sender.watermark(), 0);

            // Receive the message but don't drop the guard yet
            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, 42);
            assert_eq!(sender.watermark(), 0);

            // Drop the guard to mark as delivered
            drop(msg.guard);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_batch_tracking() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<String, u64>(10);

            // Send messages with different batch IDs
            let sequence1 = sender.send(Some(100), "msg1".to_string()).await.unwrap();
            let sequence2 = sender.send(Some(100), "msg2".to_string()).await.unwrap();
            let sequence3 = sender.send(Some(200), "msg3".to_string()).await.unwrap();

            assert_eq!(sequence1, 1);
            assert_eq!(sequence2, 2);
            assert_eq!(sequence3, 3);
            assert_eq!(sender.pending(100), 2);
            assert_eq!(sender.pending(200), 1);
            assert_eq!(sender.pending(300), 0);

            // Receive and process first message
            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, "msg1");
            drop(msg1.guard);

            assert_eq!(sender.pending(100), 1);
            assert_eq!(sender.pending(200), 1);

            // Receive and process remaining messages
            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();
            drop(msg2.guard);
            drop(msg3.guard);

            assert_eq!(sender.pending(100), 0);
            assert_eq!(sender.pending(200), 0);
        });
    }

    #[test]
    fn test_cloned_guards() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<&str, u64>(10);

            let sequence = sender.send(Some(1), "test").await.unwrap();
            assert_eq!(sequence, 1);

            // Receive the message immediately
            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, "test");

            // Share the guard by cloning the surrounding `Arc`
            let msg_guard_clone1 = msg.guard.clone();
            let msg_guard_clone2 = msg.guard.clone();

            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);

            // Drop original and one clone
            drop(msg.guard);
            drop(msg_guard_clone1);
            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);

            // Drop last clone
            drop(msg_guard_clone2);
            assert_eq!(sender.pending(1), 0);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_out_of_order_delivery() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<u8, u64>(10);
            for value in 1..=3u8 {
                sender.send(Some(9), value).await.unwrap();
            }
            let first = receiver.recv().await.unwrap();
            let second = receiver.recv().await.unwrap();
            let third = receiver.recv().await.unwrap();
            assert_eq!((first.data, second.data, third.data), (1, 2, 3));

            // Delivering later messages first must not move the watermark
            drop(second.guard);
            assert_eq!(sender.watermark(), 0);
            assert_eq!(sender.pending(9), 2);
            drop(third.guard);
            assert_eq!(sender.watermark(), 0);
            assert_eq!(sender.pending(9), 1);

            // Delivering the oldest message releases the whole contiguous run
            drop(first.guard);
            assert_eq!(sender.watermark(), 3);
            assert_eq!(sender.pending(9), 0);
        });
    }

    #[test]
    fn test_gap_holds_watermark() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<u8, u64>(10);
            for value in 1..=3u8 {
                sender.send(None, value).await.unwrap();
            }
            let first = receiver.recv().await.unwrap();
            let second = receiver.recv().await.unwrap();
            let third = receiver.recv().await.unwrap();

            // Watermark stops at the first undelivered sequence
            drop(first.guard);
            drop(third.guard);
            assert_eq!(sender.watermark(), 1);

            // Filling the gap advances past the already-delivered sequence 3
            drop(second.guard);
            assert_eq!(sender.watermark(), 3);
        });
    }

    #[test]
    fn test_try_send() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(2);

            // Try send should work when buffer has space
            let sequence1 = sender.try_send(Some(10), 1).unwrap();
            let sequence2 = sender.try_send(Some(10), 2).unwrap();

            assert_eq!(sender.pending(10), 2);
            assert_eq!(sequence1, 1);
            assert_eq!(sequence2, 2);

            // Receive messages
            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, 1);
            drop(msg1.guard);

            assert_eq!(sender.pending(10), 1);

            let msg2 = receiver.recv().await.unwrap();
            drop(msg2.guard);

            assert_eq!(sender.pending(10), 0);
        });
    }

    #[test]
    fn test_channel_closure() {
        block_on(async move {
            let (sender, receiver) = bounded::<i32, u64>(10);

            let _sequence = sender.send(None, 1).await.unwrap();

            // Drop receiver
            drop(receiver);

            // Next send should fail
            assert!(sender.send(None, 2).await.is_err());
        });
    }

    #[test]
    fn test_failed_send_releases_guard() {
        block_on(async move {
            let (sender, receiver) = bounded::<i32, u64>(10);
            drop(receiver);

            // The rejected message (and its guard) is dropped with the error,
            // so the batch must not be left with a phantom pending message.
            assert!(sender.send(Some(7), 1).await.is_err());
            assert_eq!(sender.pending(7), 0);
        });
    }
}