async_inspect/channel/
broadcast.rs

1//! Tracked broadcast channel
2//!
3//! A drop-in replacement for `tokio::sync::broadcast` that tracks message flow.
4
5use crate::channel::{ChannelMetrics, ChannelMetricsTracker, WaitTimer};
6use std::fmt;
7use std::sync::Arc;
8use tokio::sync::broadcast as tokio_broadcast;
9
10/// Create a tracked broadcast channel.
11///
12/// # Arguments
13///
14/// * `capacity` - Maximum number of messages the channel can hold
15/// * `name` - A descriptive name for debugging and metrics
16///
17/// # Example
18///
19/// ```rust,no_run
20/// use async_inspect::channel::broadcast;
21///
22/// #[tokio::main]
23/// async fn main() {
24///     let (tx, mut rx1) = broadcast::channel::<String>(16, "events");
25///     let mut rx2 = tx.subscribe();
26///
27///     tx.send("hello".into()).unwrap();
28///
29///     assert_eq!(rx1.recv().await.unwrap(), "hello");
30///     assert_eq!(rx2.recv().await.unwrap(), "hello");
31/// }
32/// ```
33pub fn channel<T: Clone>(capacity: usize, name: impl Into<String>) -> (Sender<T>, Receiver<T>) {
34    let (tx, rx) = tokio_broadcast::channel(capacity);
35    let metrics = Arc::new(ChannelMetricsTracker::new());
36    let name = Arc::new(name.into());
37
38    (
39        Sender {
40            inner: tx,
41            metrics: metrics.clone(),
42            name: name.clone(),
43            capacity,
44        },
45        Receiver {
46            inner: rx,
47            metrics,
48            name,
49        },
50    )
51}
52
53/// Tracked sender half of a broadcast channel.
54pub struct Sender<T> {
55    inner: tokio_broadcast::Sender<T>,
56    metrics: Arc<ChannelMetricsTracker>,
57    name: Arc<String>,
58    capacity: usize,
59}
60
61impl<T: Clone> Sender<T> {
62    /// Send a value to all receivers.
63    ///
64    /// # Errors
65    ///
66    /// Returns an error if there are no receivers.
67    pub fn send(&self, value: T) -> Result<usize, SendError<T>> {
68        match self.inner.send(value) {
69            Ok(n) => {
70                self.metrics.record_send(None);
71                Ok(n)
72            }
73            Err(tokio_broadcast::error::SendError(value)) => {
74                self.metrics.mark_closed();
75                Err(SendError(value))
76            }
77        }
78    }
79
80    /// Create a new receiver subscribed to this sender.
81    #[must_use]
82    pub fn subscribe(&self) -> Receiver<T> {
83        Receiver {
84            inner: self.inner.subscribe(),
85            metrics: self.metrics.clone(),
86            name: self.name.clone(),
87        }
88    }
89
90    /// Get the number of active receivers.
91    #[must_use]
92    pub fn receiver_count(&self) -> usize {
93        self.inner.receiver_count()
94    }
95
96    /// Get the channel capacity.
97    #[must_use]
98    pub fn capacity(&self) -> usize {
99        self.capacity
100    }
101
102    /// Get the channel name.
103    #[must_use]
104    pub fn name(&self) -> &str {
105        &self.name
106    }
107
108    /// Get current metrics for this channel.
109    #[must_use]
110    pub fn metrics(&self) -> ChannelMetrics {
111        self.metrics.get_metrics(0)
112    }
113}
114
115impl<T> Clone for Sender<T> {
116    fn clone(&self) -> Self {
117        Self {
118            inner: self.inner.clone(),
119            metrics: self.metrics.clone(),
120            name: self.name.clone(),
121            capacity: self.capacity,
122        }
123    }
124}
125
126impl<T: Clone> fmt::Debug for Sender<T> {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        f.debug_struct("broadcast::Sender")
129            .field("name", &self.name)
130            .field("capacity", &self.capacity)
131            .field("receivers", &self.receiver_count())
132            .finish()
133    }
134}
135
136/// Tracked receiver half of a broadcast channel.
137pub struct Receiver<T> {
138    inner: tokio_broadcast::Receiver<T>,
139    metrics: Arc<ChannelMetricsTracker>,
140    name: Arc<String>,
141}
142
143impl<T: Clone> Receiver<T> {
144    /// Receive a value, waiting if necessary.
145    pub async fn recv(&mut self) -> Result<T, RecvError> {
146        let timer = WaitTimer::start();
147
148        match self.inner.recv().await {
149            Ok(value) => {
150                let wait_time = timer.elapsed_if_waited();
151                self.metrics.record_recv(wait_time);
152                Ok(value)
153            }
154            Err(tokio_broadcast::error::RecvError::Closed) => {
155                self.metrics.mark_closed();
156                Err(RecvError::Closed)
157            }
158            Err(tokio_broadcast::error::RecvError::Lagged(n)) => Err(RecvError::Lagged(n)),
159        }
160    }
161
162    /// Try to receive without waiting.
163    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
164        match self.inner.try_recv() {
165            Ok(value) => {
166                self.metrics.record_recv(None);
167                Ok(value)
168            }
169            Err(tokio_broadcast::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
170            Err(tokio_broadcast::error::TryRecvError::Closed) => {
171                self.metrics.mark_closed();
172                Err(TryRecvError::Closed)
173            }
174            Err(tokio_broadcast::error::TryRecvError::Lagged(n)) => Err(TryRecvError::Lagged(n)),
175        }
176    }
177
178    /// Get the channel name.
179    #[must_use]
180    pub fn name(&self) -> &str {
181        &self.name
182    }
183
184    /// Get current metrics for this channel.
185    #[must_use]
186    pub fn metrics(&self) -> ChannelMetrics {
187        self.metrics.get_metrics(0)
188    }
189}
190
191impl<T> fmt::Debug for Receiver<T> {
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        f.debug_struct("broadcast::Receiver")
194            .field("name", &self.name)
195            .finish()
196    }
197}
198
/// Error returned by a failed send.
///
/// The wrapped value is the message that could not be delivered, so the
/// caller can recover it.
#[derive(Debug)]
pub struct SendError<T>(pub T);

impl<T> fmt::Display for SendError<T> {
    /// Writes a fixed message; the payload itself is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed (no receivers)")
    }
}

impl<T: fmt::Debug> std::error::Error for SendError<T> {}
210
/// Error returned by a failed `recv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// The channel is closed.
    Closed,
    /// The receiver fell too far behind; the payload is the number of
    /// messages it missed.
    Lagged(u64),
}

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Closed => f.write_str("channel closed"),
            Self::Lagged(missed) => write!(f, "receiver lagged, missed {missed} messages"),
        }
    }
}

impl std::error::Error for RecvError {}
230
/// Error returned by a failed `try_recv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// The channel is empty.
    Empty,
    /// The channel is closed.
    Closed,
    /// The receiver fell too far behind; the payload is the number of
    /// messages it missed.
    Lagged(u64),
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Empty => f.write_str("channel empty"),
            Self::Closed => f.write_str("channel closed"),
            Self::Lagged(missed) => write!(f, "receiver lagged, missed {missed} messages"),
        }
    }
}

impl std::error::Error for TryRecvError {}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Every subscriber sees a sent value, and the send is counted.
    #[tokio::test]
    async fn test_broadcast_basic() {
        let (sender, mut first) = channel::<i32>(16, "test");
        let mut second = sender.subscribe();

        sender.send(42).unwrap();

        for receiver in [&mut first, &mut second] {
            assert_eq!(receiver.recv().await.unwrap(), 42);
        }

        assert_eq!(sender.metrics().sent, 1);
    }

    /// Values arrive in send order and each receive is counted.
    #[tokio::test]
    async fn test_broadcast_multiple_sends() {
        let (sender, mut receiver) = channel::<i32>(16, "test");

        for value in 1..=3 {
            sender.send(value).unwrap();
        }

        for expected in 1..=3 {
            assert_eq!(receiver.recv().await.unwrap(), expected);
        }

        assert_eq!(receiver.metrics().received, 3);
    }

    /// The receiver count grows by one with each live subscription.
    #[tokio::test]
    async fn test_broadcast_receiver_count() {
        let (sender, _initial) = channel::<i32>(16, "test");
        assert_eq!(sender.receiver_count(), 1);

        // Keep the extra receivers alive so they stay counted.
        let mut extras = Vec::new();
        for expected in 2..=3 {
            extras.push(sender.subscribe());
            assert_eq!(sender.receiver_count(), expected);
        }
    }

    /// Sending with no live receiver reports an error.
    #[tokio::test]
    async fn test_broadcast_no_receivers() {
        let (sender, receiver) = channel::<i32>(16, "test");
        drop(receiver);

        let outcome = sender.send(42);
        assert!(outcome.is_err());
    }
}
305        assert!(tx.send(42).is_err());
306    }
307}