async_inspect/channel/
mpsc.rs

1//! Tracked MPSC (Multi-Producer Single-Consumer) channel
2//!
3//! A drop-in replacement for `tokio::sync::mpsc` that automatically tracks
4//! message flow and integrates with async-inspect's visualization.
5
6use crate::channel::{ChannelMetrics, ChannelMetricsTracker, WaitTimer};
7use std::fmt;
8use std::sync::Arc;
9use tokio::sync::mpsc as tokio_mpsc;
10
11/// Create a bounded tracked mpsc channel.
12///
13/// # Arguments
14///
15/// * `capacity` - Maximum number of messages the channel can hold
16/// * `name` - A descriptive name for debugging and metrics
17///
18/// # Example
19///
20/// ```rust,no_run
21/// use async_inspect::channel::mpsc;
22///
23/// #[tokio::main]
24/// async fn main() {
25///     let (tx, mut rx) = mpsc::channel::<i32>(100, "my_channel");
26///
27///     tx.send(42).await.unwrap();
28///     let value = rx.recv().await.unwrap();
29///     assert_eq!(value, 42);
30/// }
31/// ```
32pub fn channel<T>(capacity: usize, name: impl Into<String>) -> (Sender<T>, Receiver<T>) {
33    let (tx, rx) = tokio_mpsc::channel(capacity);
34    let metrics = Arc::new(ChannelMetricsTracker::new());
35    let name = Arc::new(name.into());
36    let capacity = capacity;
37
38    (
39        Sender {
40            inner: tx,
41            metrics: metrics.clone(),
42            name: name.clone(),
43            capacity,
44        },
45        Receiver {
46            inner: rx,
47            metrics,
48            name,
49            capacity,
50        },
51    )
52}
53
54/// Create an unbounded tracked mpsc channel.
55///
56/// # Arguments
57///
58/// * `name` - A descriptive name for debugging and metrics
59///
60/// # Example
61///
62/// ```rust,no_run
63/// use async_inspect::channel::mpsc;
64///
65/// #[tokio::main]
66/// async fn main() {
67///     let (tx, mut rx) = mpsc::unbounded_channel::<String>("events");
68///
69///     tx.send("event1".into()).unwrap();
70///     let event = rx.recv().await.unwrap();
71/// }
72/// ```
73pub fn unbounded_channel<T>(name: impl Into<String>) -> (UnboundedSender<T>, UnboundedReceiver<T>) {
74    let (tx, rx) = tokio_mpsc::unbounded_channel();
75    let metrics = Arc::new(ChannelMetricsTracker::new());
76    let name = Arc::new(name.into());
77
78    (
79        UnboundedSender {
80            inner: tx,
81            metrics: metrics.clone(),
82            name: name.clone(),
83        },
84        UnboundedReceiver {
85            inner: rx,
86            metrics,
87            name,
88        },
89    )
90}
91
92/// Tracked bounded sender half of an mpsc channel.
93pub struct Sender<T> {
94    inner: tokio_mpsc::Sender<T>,
95    metrics: Arc<ChannelMetricsTracker>,
96    name: Arc<String>,
97    capacity: usize,
98}
99
100impl<T> Sender<T> {
101    /// Send a value, waiting if the channel is full.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if the receiver has been dropped.
106    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
107        let timer = WaitTimer::start();
108
109        match self.inner.send(value).await {
110            Ok(()) => {
111                let wait_time = timer.elapsed_if_waited();
112                self.metrics.record_send(wait_time);
113                Ok(())
114            }
115            Err(tokio_mpsc::error::SendError(value)) => {
116                self.metrics.mark_closed();
117                Err(SendError(value))
118            }
119        }
120    }
121
122    /// Try to send a value without waiting.
123    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
124        match self.inner.try_send(value) {
125            Ok(()) => {
126                self.metrics.record_send(None);
127                Ok(())
128            }
129            Err(tokio_mpsc::error::TrySendError::Full(value)) => Err(TrySendError::Full(value)),
130            Err(tokio_mpsc::error::TrySendError::Closed(value)) => {
131                self.metrics.mark_closed();
132                Err(TrySendError::Closed(value))
133            }
134        }
135    }
136
137    /// Check if the channel is closed.
138    #[must_use]
139    pub fn is_closed(&self) -> bool {
140        self.inner.is_closed()
141    }
142
143    /// Get the channel capacity.
144    #[must_use]
145    pub fn capacity(&self) -> usize {
146        self.inner.capacity()
147    }
148
149    /// Get the maximum capacity.
150    #[must_use]
151    pub fn max_capacity(&self) -> usize {
152        self.capacity
153    }
154
155    /// Get the channel name.
156    #[must_use]
157    pub fn name(&self) -> &str {
158        &self.name
159    }
160
161    /// Get current metrics for this channel.
162    #[must_use]
163    pub fn metrics(&self) -> ChannelMetrics {
164        let buffered = (self.capacity - self.inner.capacity()) as u64;
165        self.metrics.get_metrics(buffered)
166    }
167}
168
169impl<T> Clone for Sender<T> {
170    fn clone(&self) -> Self {
171        Self {
172            inner: self.inner.clone(),
173            metrics: self.metrics.clone(),
174            name: self.name.clone(),
175            capacity: self.capacity,
176        }
177    }
178}
179
180impl<T> fmt::Debug for Sender<T> {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        f.debug_struct("Sender")
183            .field("name", &self.name)
184            .field("capacity", &self.capacity)
185            .finish()
186    }
187}
188
189/// Tracked bounded receiver half of an mpsc channel.
190pub struct Receiver<T> {
191    inner: tokio_mpsc::Receiver<T>,
192    metrics: Arc<ChannelMetricsTracker>,
193    name: Arc<String>,
194    capacity: usize,
195}
196
197impl<T> Receiver<T> {
198    /// Receive a value, waiting if the channel is empty.
199    ///
200    /// Returns `None` if the channel is closed and empty.
201    pub async fn recv(&mut self) -> Option<T> {
202        let timer = WaitTimer::start();
203
204        if let Some(value) = self.inner.recv().await {
205            let wait_time = timer.elapsed_if_waited();
206            self.metrics.record_recv(wait_time);
207            Some(value)
208        } else {
209            self.metrics.mark_closed();
210            None
211        }
212    }
213
214    /// Try to receive a value without waiting.
215    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
216        match self.inner.try_recv() {
217            Ok(value) => {
218                self.metrics.record_recv(None);
219                Ok(value)
220            }
221            Err(tokio_mpsc::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
222            Err(tokio_mpsc::error::TryRecvError::Disconnected) => {
223                self.metrics.mark_closed();
224                Err(TryRecvError::Disconnected)
225            }
226        }
227    }
228
229    /// Close the receiver, preventing any new messages.
230    pub fn close(&mut self) {
231        self.inner.close();
232        self.metrics.mark_closed();
233    }
234
235    /// Get the channel name.
236    #[must_use]
237    pub fn name(&self) -> &str {
238        &self.name
239    }
240
241    /// Get current metrics for this channel.
242    #[must_use]
243    pub fn metrics(&self) -> ChannelMetrics {
244        // Approximate buffered count
245        let sent = self.metrics.sent.load(std::sync::atomic::Ordering::Relaxed);
246        let received = self
247            .metrics
248            .received
249            .load(std::sync::atomic::Ordering::Relaxed);
250        let buffered = sent.saturating_sub(received);
251        self.metrics.get_metrics(buffered)
252    }
253}
254
255impl<T> fmt::Debug for Receiver<T> {
256    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
257        f.debug_struct("Receiver")
258            .field("name", &self.name)
259            .field("capacity", &self.capacity)
260            .finish()
261    }
262}
263
264/// Tracked unbounded sender half of an mpsc channel.
265pub struct UnboundedSender<T> {
266    inner: tokio_mpsc::UnboundedSender<T>,
267    metrics: Arc<ChannelMetricsTracker>,
268    name: Arc<String>,
269}
270
271impl<T> UnboundedSender<T> {
272    /// Send a value (never blocks for unbounded channels).
273    ///
274    /// # Errors
275    ///
276    /// Returns an error if the receiver has been dropped.
277    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
278        match self.inner.send(value) {
279            Ok(()) => {
280                self.metrics.record_send(None);
281                Ok(())
282            }
283            Err(tokio_mpsc::error::SendError(value)) => {
284                self.metrics.mark_closed();
285                Err(SendError(value))
286            }
287        }
288    }
289
290    /// Check if the channel is closed.
291    #[must_use]
292    pub fn is_closed(&self) -> bool {
293        self.inner.is_closed()
294    }
295
296    /// Get the channel name.
297    #[must_use]
298    pub fn name(&self) -> &str {
299        &self.name
300    }
301
302    /// Get current metrics for this channel.
303    #[must_use]
304    pub fn metrics(&self) -> ChannelMetrics {
305        let sent = self.metrics.sent.load(std::sync::atomic::Ordering::Relaxed);
306        let received = self
307            .metrics
308            .received
309            .load(std::sync::atomic::Ordering::Relaxed);
310        let buffered = sent.saturating_sub(received);
311        self.metrics.get_metrics(buffered)
312    }
313}
314
315impl<T> Clone for UnboundedSender<T> {
316    fn clone(&self) -> Self {
317        Self {
318            inner: self.inner.clone(),
319            metrics: self.metrics.clone(),
320            name: self.name.clone(),
321        }
322    }
323}
324
325impl<T> fmt::Debug for UnboundedSender<T> {
326    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327        f.debug_struct("UnboundedSender")
328            .field("name", &self.name)
329            .finish()
330    }
331}
332
333/// Tracked unbounded receiver half of an mpsc channel.
334pub struct UnboundedReceiver<T> {
335    inner: tokio_mpsc::UnboundedReceiver<T>,
336    metrics: Arc<ChannelMetricsTracker>,
337    name: Arc<String>,
338}
339
340impl<T> UnboundedReceiver<T> {
341    /// Receive a value, waiting if the channel is empty.
342    pub async fn recv(&mut self) -> Option<T> {
343        let timer = WaitTimer::start();
344
345        if let Some(value) = self.inner.recv().await {
346            let wait_time = timer.elapsed_if_waited();
347            self.metrics.record_recv(wait_time);
348            Some(value)
349        } else {
350            self.metrics.mark_closed();
351            None
352        }
353    }
354
355    /// Try to receive a value without waiting.
356    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
357        match self.inner.try_recv() {
358            Ok(value) => {
359                self.metrics.record_recv(None);
360                Ok(value)
361            }
362            Err(tokio_mpsc::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
363            Err(tokio_mpsc::error::TryRecvError::Disconnected) => {
364                self.metrics.mark_closed();
365                Err(TryRecvError::Disconnected)
366            }
367        }
368    }
369
370    /// Close the receiver.
371    pub fn close(&mut self) {
372        self.inner.close();
373        self.metrics.mark_closed();
374    }
375
376    /// Get the channel name.
377    #[must_use]
378    pub fn name(&self) -> &str {
379        &self.name
380    }
381
382    /// Get current metrics for this channel.
383    #[must_use]
384    pub fn metrics(&self) -> ChannelMetrics {
385        let sent = self.metrics.sent.load(std::sync::atomic::Ordering::Relaxed);
386        let received = self
387            .metrics
388            .received
389            .load(std::sync::atomic::Ordering::Relaxed);
390        let buffered = sent.saturating_sub(received);
391        self.metrics.get_metrics(buffered)
392    }
393}
394
395impl<T> fmt::Debug for UnboundedReceiver<T> {
396    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
397        f.debug_struct("UnboundedReceiver")
398            .field("name", &self.name)
399            .finish()
400    }
401}
402
/// Error returned when sending fails because the receiver was dropped or closed.
///
/// The value that could not be delivered is returned in the public field, so
/// the caller can recover it.
///
/// Like `tokio::sync::mpsc::error::SendError`, the `Debug` output omits the
/// payload and prints `SendError { .. }`. As a result, `Debug` and
/// [`std::error::Error`] are implemented for every `T`.
/// `tx.send(value).await.unwrap()` therefore compiles even when `T` is not
/// `Debug`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct SendError<T>(pub T);

impl<T> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SendError { .. }")
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "channel closed")
    }
}

impl<T> std::error::Error for SendError<T> {}
414
/// Error returned when `try_send` fails.
///
/// Both variants hand back the value that could not be sent. Use
/// [`TrySendError::into_inner`] to recover it regardless of the variant.
///
/// Like tokio's `TrySendError`, the `Debug` output omits the payload and
/// prints `Full(..)` or `Closed(..)`. `Debug` and [`std::error::Error`] are
/// therefore implemented for every `T`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum TrySendError<T> {
    /// Channel is full.
    Full(T),
    /// Channel is closed.
    Closed(T),
}

impl<T> TrySendError<T> {
    /// Consumes the error, returning the value that could not be sent.
    #[must_use]
    pub fn into_inner(self) -> T {
        match self {
            TrySendError::Full(value) | TrySendError::Closed(value) => value,
        }
    }
}

impl<T> fmt::Debug for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TrySendError::Full(_) => f.write_str("Full(..)"),
            TrySendError::Closed(_) => f.write_str("Closed(..)"),
        }
    }
}

impl<T> fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TrySendError::Full(_) => write!(f, "channel full"),
            TrySendError::Closed(_) => write!(f, "channel closed"),
        }
    }
}

impl<T> std::error::Error for TrySendError<T> {}
434
/// Error returned by `try_recv` when no value could be taken from the channel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TryRecvError {
    /// No message is buffered at the moment.
    Empty,
    /// The channel is closed and no further messages will arrive.
    Disconnected,
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "channel empty",
            Self::Disconnected => "channel disconnected",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TryRecvError {}
454
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_bounded_channel() {
        let (tx, mut rx) = channel::<i32>(10, "test");

        tx.send(42).await.unwrap();
        tx.send(43).await.unwrap();

        assert_eq!(rx.recv().await, Some(42));
        assert_eq!(rx.recv().await, Some(43));

        let metrics = rx.metrics();
        assert_eq!(metrics.sent, 2);
        assert_eq!(metrics.received, 2);
    }

    #[tokio::test]
    async fn test_unbounded_channel() {
        let (tx, mut rx) = unbounded_channel::<String>("events");

        tx.send("hello".into()).unwrap();
        tx.send("world".into()).unwrap();

        assert_eq!(rx.recv().await, Some("hello".into()));
        assert_eq!(rx.recv().await, Some("world".into()));

        let metrics = rx.metrics();
        assert_eq!(metrics.sent, 2);
        assert_eq!(metrics.received, 2);
    }

    #[tokio::test]
    async fn test_channel_close() {
        let (tx, mut rx) = channel::<i32>(10, "test");

        tx.send(1).await.unwrap();
        drop(tx);

        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(rx.recv().await, None);

        let metrics = rx.metrics();
        assert!(metrics.closed);
    }

    #[tokio::test]
    async fn test_try_send_recv() {
        let (tx, mut rx) = channel::<i32>(2, "test");

        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();

        // Channel full
        assert!(matches!(tx.try_send(3), Err(TrySendError::Full(3))));

        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv().unwrap(), 2);
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
    }

    #[tokio::test]
    async fn test_send_after_receiver_dropped() {
        let (tx, rx) = channel::<i32>(1, "test");
        drop(rx);

        assert!(tx.is_closed());
        // The rejected value must be handed back to the caller.
        let err = tx.send(7).await.unwrap_err();
        assert_eq!(err.0, 7);
        assert!(matches!(tx.try_send(8), Err(TrySendError::Closed(8))));
        assert!(tx.metrics().closed);
    }

    #[tokio::test]
    async fn test_close_keeps_buffered_messages() {
        let (tx, mut rx) = channel::<i32>(4, "test");

        tx.send(1).await.unwrap();
        rx.close();

        // New sends are rejected, but already-buffered messages still drain.
        assert!(tx.send(2).await.is_err());
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(rx.recv().await, None);
    }

    #[tokio::test]
    async fn test_unbounded_send_after_receiver_dropped() {
        let (tx, rx) = unbounded_channel::<&'static str>("events");
        drop(rx);

        assert!(tx.is_closed());
        let err = tx.send("lost").unwrap_err();
        assert_eq!(err.0, "lost");
    }

    #[test]
    fn test_name_and_capacity_accessors() {
        let (tx, rx) = channel::<()>(1, "named");

        assert_eq!(tx.name(), "named");
        assert_eq!(rx.name(), "named");
        assert_eq!(tx.max_capacity(), 1);
        assert_eq!(tx.capacity(), 1);

        let (utx, urx) = unbounded_channel::<()>("unbounded");
        assert_eq!(utx.name(), "unbounded");
        assert_eq!(urx.name(), "unbounded");
    }
}