Skip to main content

commonware_utils/channel/
tracked.rs

//! A channel that tracks message delivery.
//!
//! This channel provides message delivery tracking. Each sent message includes
//! a [Guard] that tracks when the message has been fully processed. When ALL
//! references to the guard are dropped, the message is marked as delivered.
//!
//! # Features
//!
//! - **Watermarks**: Get the highest sequence number where all messages up to it have been delivered
//! - **Batches**: Assign batches to messages and track pending counts per batch
//! - **Clonable Guards**: Guards can be cloned and shared; delivery happens when all clones are dropped
//!
//! # Sequence Number Overflow
//!
//! Uses [u64] for sequence numbers. At 100 messages per nanosecond, overflow occurs after ~5.85 years.
//! Systems requiring more message throughput should implement periodic resets or use external sequence management.
//!
//! # Example
//!
//! ```
//! use futures::executor::block_on;
//! use commonware_utils::channel::tracked;
//! block_on(async {
//!     let (sender, mut receiver) = tracked::bounded::<String, u64>(10);
//!
//!     // Send a message with batch ID
//!     let sequence = sender.send(Some(1), "hello".into()).await.unwrap();
//!
//!     // Check pending messages
//!     assert_eq!(sender.pending(1), 1);
//!     assert_eq!(sender.watermark(), 0);
//!
//!     // Receive and process
//!     let msg = receiver.recv().await.unwrap();
//!     assert_eq!(msg.data, "hello");
//!
//!     // Clone the guard - delivery won't happen until all clones are dropped
//!     let guard_clone = msg.guard.clone();
//!     drop(msg.guard);
//!     assert_eq!(sender.watermark(), 0); // Still not delivered
//!
//!     // Drop the last guard reference to mark as delivered
//!     drop(guard_clone);
//!     assert_eq!(sender.pending(1), 0);
//!     assert_eq!(sender.watermark(), 1);
//! });
//! ```
48
49use super::mpsc::{
50    self,
51    error::{SendError, TryRecvError, TrySendError},
52};
53use crate::sync::Mutex;
54use futures::Stream;
55use std::{
56    collections::HashMap,
57    hash::Hash,
58    pin::Pin,
59    sync::Arc,
60    task::{Context, Poll},
61};
62
63/// A guard that tracks message delivery. When dropped, the message is marked as delivered.
64#[derive(Clone)]
65pub struct Guard<B: Eq + Hash + Clone> {
66    sequence: u64,
67    tracker: Arc<Mutex<State<B>>>,
68
69    batch: Option<B>,
70}
71
72impl<B: Eq + Hash + Clone> Drop for Guard<B> {
73    fn drop(&mut self) {
74        // Get the state
75        let mut state = self.tracker.lock();
76
77        // Mark the message as delivered
78        *state.pending.get_mut(&self.sequence).unwrap() = true;
79
80        // Update watermark if possible
81        let mut current_watermark = state.watermark;
82        while let Some(delivered) = state.pending.get(&(current_watermark + 1)) {
83            // If the next message is not delivered, we can stop
84            if !*delivered {
85                break;
86            }
87
88            // Remove the next message from the pending list
89            state.pending.remove(&(current_watermark + 1));
90            current_watermark += 1;
91            state.watermark = current_watermark;
92        }
93
94        // Update batch count (if necessary)
95        if let Some(batch) = &self.batch {
96            let count = state.batches.get_mut(batch).unwrap();
97            if *count > 1 {
98                *count -= 1;
99            } else {
100                state.batches.remove(batch);
101            }
102        }
103    }
104}
105
106/// A message containing data and a [Guard] that tracks delivery.
107pub struct Message<T, B: Eq + Hash + Clone> {
108    /// The data of the message.
109    pub data: T,
110    /// The [Guard] that tracks delivery.
111    ///
112    /// When no outstanding references to the guard exist, the message is considered delivered.
113    pub guard: Arc<Guard<B>>,
114}
115
/// Bookkeeping shared between a [Tracker] and every [Guard] it creates.
struct State<B> {
    /// Sequence number that the next guard will receive.
    next: u64,
    /// Highest sequence number for which it, and every sequence before it,
    /// has been delivered (`0` until the first delivery).
    watermark: u64,
    /// Count of undelivered messages per batch.
    batches: HashMap<B, usize>,
    /// Delivery flag for each tracked sequence (`true` once delivered).
    pending: HashMap<u64, bool>,
}

impl<B> Default for State<B> {
    /// Creates an empty state: sequences start at `1` and the watermark at `0`.
    fn default() -> Self {
        let batches = HashMap::new();
        let pending = HashMap::new();
        Self {
            next: 1,
            watermark: 0,
            batches,
            pending,
        }
    }
}
134
135/// Tracks delivery state across all messages.
136///
137/// Note on sequence overflow: Using u64 for sequence numbers provides ample headroom.
138/// At 100 messages per nanosecond, it would take ~5.85 years to overflow.
139/// For systems requiring longer uptime without restart, consider implementing
140/// sequence number wrapping with careful watermark handling.
141#[derive(Clone)]
142struct Tracker<B: Eq + Hash + Clone> {
143    state: Arc<Mutex<State<B>>>,
144}
145
146impl<B: Eq + Hash + Clone> Tracker<B> {
147    fn new() -> Self {
148        Self {
149            state: Arc::new(Mutex::new(State::default())),
150        }
151    }
152
153    fn guard(&self, batch: Option<B>) -> Guard<B> {
154        // Get state
155        let mut state = self.state.lock();
156
157        // Get the next sequence
158        let sequence = state.next;
159        state.next += 1;
160
161        // Track this sequence as not yet delivered
162        state.pending.insert(sequence, false);
163
164        // Update batch count if provided
165        if let Some(batch) = &batch {
166            *state.batches.entry(batch.clone()).or_insert(0) += 1;
167        }
168
169        Guard {
170            sequence,
171            tracker: self.state.clone(),
172
173            batch,
174        }
175    }
176}
177
178/// A sender that wraps `Sender` and tracks message delivery.
179#[derive(Clone)]
180pub struct Sender<T, B: Eq + Hash + Clone> {
181    inner: mpsc::Sender<Message<T, B>>,
182    tracker: Tracker<B>,
183}
184
185impl<T, B: Eq + Hash + Clone> Sender<T, B> {
186    /// Sends a message with an optional batch ID and returns a delivery guard.
187    pub async fn send(&self, batch: Option<B>, data: T) -> Result<u64, SendError<Message<T, B>>> {
188        // Create the guard
189        let guard = Arc::new(self.tracker.guard(batch));
190        let watermark = guard.sequence;
191
192        // Send the message
193        let msg = Message { data, guard };
194        self.inner.send(msg).await?;
195
196        Ok(watermark)
197    }
198
199    /// Tries to send a message without blocking.
200    pub fn try_send(&self, batch: Option<B>, data: T) -> Result<u64, TrySendError<Message<T, B>>> {
201        // Create the guard
202        let guard = Arc::new(self.tracker.guard(batch));
203        let watermark = guard.sequence;
204
205        // Send the message
206        let msg = Message { data, guard };
207        self.inner.try_send(msg)?;
208
209        Ok(watermark)
210    }
211
212    /// Returns the current delivery watermark (highest sequence number where all messages up to and including it have been delivered).
213    pub fn watermark(&self) -> u64 {
214        self.tracker.state.lock().watermark
215    }
216
217    /// Returns the number of pending messages for a specific batch.
218    pub fn pending(&self, batch: B) -> usize {
219        self.tracker
220            .state
221            .lock()
222            .batches
223            .get(&batch)
224            .copied()
225            .unwrap_or(0)
226    }
227}
228
229/// A receiver that wraps [mpsc::Receiver] and provides tracked messages.
230pub struct Receiver<T, B: Eq + Hash + Clone> {
231    inner: mpsc::Receiver<Message<T, B>>,
232}
233
234impl<T, B: Eq + Hash + Clone> Receiver<T, B> {
235    /// Receives the next message.
236    pub async fn recv(&mut self) -> Option<Message<T, B>> {
237        self.inner.recv().await
238    }
239
240    /// Tries to receive a message without blocking.
241    pub fn try_recv(&mut self) -> Result<Message<T, B>, TryRecvError> {
242        self.inner.try_recv()
243    }
244}
245
246impl<T, B: Eq + Hash + Clone> Stream for Receiver<T, B> {
247    type Item = Message<T, B>;
248
249    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
250        self.inner.poll_recv(cx)
251    }
252}
253
254/// Create a new bounded channel with delivery tracking.
255pub fn bounded<T, B: Eq + Hash + Clone>(buffer: usize) -> (Sender<T, B>, Receiver<T, B>) {
256    let (tx, rx) = mpsc::channel(buffer);
257    let sender = Sender {
258        inner: tx,
259        tracker: Tracker::new(),
260    };
261    let receiver = Receiver { inner: rx };
262    (sender, receiver)
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    /// A message is delivered only when its guard is dropped, not when it is received.
    #[test]
    fn test_basic() {
        block_on(async {
            let (tx, mut rx) = bounded::<i32, u64>(10);

            // Sending (without a batch) assigns the first sequence number
            let sequence = tx.send(None, 42).await.unwrap();
            assert_eq!(sequence, 1);
            assert_eq!(tx.watermark(), 0);

            // Receiving alone is not delivery: the guard is still alive
            let Message { data, guard } = rx.recv().await.unwrap();
            assert_eq!(data, 42);
            assert_eq!(tx.watermark(), 0);

            // Dropping the guard marks the message as delivered
            drop(guard);
            assert_eq!(tx.watermark(), 1);
        });
    }

    /// Pending counts are kept per batch and fall as guards are dropped.
    #[test]
    fn test_batch_tracking() {
        block_on(async {
            let (tx, mut rx) = bounded::<String, u64>(10);

            // Two messages in batch 100 and one in batch 200
            let first = tx.send(Some(100), "msg1".into()).await.unwrap();
            let second = tx.send(Some(100), "msg2".into()).await.unwrap();
            let third = tx.send(Some(200), "msg3".into()).await.unwrap();

            assert_eq!([first, second, third], [1, 2, 3]);
            assert_eq!(tx.pending(100), 2);
            assert_eq!(tx.pending(200), 1);
            assert_eq!(tx.pending(300), 0);

            // Deliver the first message
            let Message { data, guard } = rx.recv().await.unwrap();
            assert_eq!(data, "msg1");
            drop(guard);

            assert_eq!(tx.pending(100), 1);
            assert_eq!(tx.pending(200), 1);

            // Deliver the remaining two
            let msg2 = rx.recv().await.unwrap();
            let msg3 = rx.recv().await.unwrap();
            drop(msg2.guard);
            drop(msg3.guard);

            assert_eq!(tx.pending(100), 0);
            assert_eq!(tx.pending(200), 0);
        });
    }

    /// Delivery waits until every clone of the guard has been dropped.
    #[test]
    fn test_cloned_guards() {
        block_on(async {
            let (tx, mut rx) = bounded::<&str, u64>(10);

            let sequence = tx.send(Some(1), "test").await.unwrap();
            assert_eq!(sequence, 1);

            // Receive the message immediately
            let Message { data, guard } = rx.recv().await.unwrap();
            assert_eq!(data, "test");

            // Take two extra references to the message's guard
            let extra_a = guard.clone();
            let extra_b = guard.clone();

            assert_eq!(tx.pending(1), 1);
            assert_eq!(tx.watermark(), 0);

            // Dropping the original and one clone still leaves a reference
            drop(guard);
            drop(extra_a);
            assert_eq!(tx.pending(1), 1);
            assert_eq!(tx.watermark(), 0);

            // Dropping the final reference delivers the message
            drop(extra_b);
            assert_eq!(tx.pending(1), 0);
            assert_eq!(tx.watermark(), 1);
        });
    }

    /// `try_send` tracks messages the same way `send` does.
    #[test]
    fn test_try_send() {
        block_on(async {
            let (tx, mut rx) = bounded::<i32, u64>(2);

            // Both sends fit in the buffer
            let first = tx.try_send(Some(10), 1).unwrap();
            let second = tx.try_send(Some(10), 2).unwrap();

            assert_eq!(tx.pending(10), 2);
            assert_eq!(first, 1);
            assert_eq!(second, 2);

            // Deliver the first message
            let Message { data, guard } = rx.recv().await.unwrap();
            assert_eq!(data, 1);
            drop(guard);

            assert_eq!(tx.pending(10), 1);

            // Deliver the second message
            let msg2 = rx.recv().await.unwrap();
            drop(msg2.guard);

            assert_eq!(tx.pending(10), 0);
        });
    }

    /// Sending fails once the receiver has been dropped.
    #[test]
    fn test_channel_closure() {
        block_on(async {
            let (tx, rx) = bounded::<i32, u64>(10);

            let _sequence = tx.send(None, 1).await.unwrap();

            // Close the channel from the receiving side
            drop(rx);

            // Next send should fail
            let outcome = tx.send(None, 2).await;
            assert!(outcome.is_err());
        });
    }
}
400}