1use crate::channel::{ChannelMetrics, ChannelMetricsTracker, WaitTimer};
6use std::fmt;
7use std::sync::Arc;
8use tokio::sync::broadcast as tokio_broadcast;
9
10pub fn channel<T: Clone>(capacity: usize, name: impl Into<String>) -> (Sender<T>, Receiver<T>) {
34 let (tx, rx) = tokio_broadcast::channel(capacity);
35 let metrics = Arc::new(ChannelMetricsTracker::new());
36 let name = Arc::new(name.into());
37
38 (
39 Sender {
40 inner: tx,
41 metrics: metrics.clone(),
42 name: name.clone(),
43 capacity,
44 },
45 Receiver {
46 inner: rx,
47 metrics,
48 name,
49 },
50 )
51}
52
/// Sending half of a named, metrics-instrumented broadcast channel.
///
/// Wraps a tokio broadcast sender. Every handle of the same channel shares its
/// metrics tracker and name.
pub struct Sender<T> {
    /// The underlying tokio sender that performs the broadcast.
    inner: tokio_broadcast::Sender<T>,
    /// Metrics tracker shared through `Arc` with every sender and receiver of this channel.
    metrics: Arc<ChannelMetricsTracker>,
    /// Channel name shared through `Arc` with every handle, so clones are cheap.
    name: Arc<String>,
    /// Capacity as requested at creation. Tokio may allocate a larger buffer internally.
    capacity: usize,
}
60
61impl<T: Clone> Sender<T> {
62 pub fn send(&self, value: T) -> Result<usize, SendError<T>> {
68 match self.inner.send(value) {
69 Ok(n) => {
70 self.metrics.record_send(None);
71 Ok(n)
72 }
73 Err(tokio_broadcast::error::SendError(value)) => {
74 self.metrics.mark_closed();
75 Err(SendError(value))
76 }
77 }
78 }
79
80 #[must_use]
82 pub fn subscribe(&self) -> Receiver<T> {
83 Receiver {
84 inner: self.inner.subscribe(),
85 metrics: self.metrics.clone(),
86 name: self.name.clone(),
87 }
88 }
89
90 #[must_use]
92 pub fn receiver_count(&self) -> usize {
93 self.inner.receiver_count()
94 }
95
96 #[must_use]
98 pub fn capacity(&self) -> usize {
99 self.capacity
100 }
101
102 #[must_use]
104 pub fn name(&self) -> &str {
105 &self.name
106 }
107
108 #[must_use]
110 pub fn metrics(&self) -> ChannelMetrics {
111 self.metrics.get_metrics(0)
112 }
113}
114
115impl<T> Clone for Sender<T> {
116 fn clone(&self) -> Self {
117 Self {
118 inner: self.inner.clone(),
119 metrics: self.metrics.clone(),
120 name: self.name.clone(),
121 capacity: self.capacity,
122 }
123 }
124}
125
126impl<T: Clone> fmt::Debug for Sender<T> {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 f.debug_struct("broadcast::Sender")
129 .field("name", &self.name)
130 .field("capacity", &self.capacity)
131 .field("receivers", &self.receiver_count())
132 .finish()
133 }
134}
135
/// Receiving half of a named, metrics-instrumented broadcast channel.
///
/// Obtained from [`channel`] or [`Sender::subscribe`]. It shares its name and
/// metrics tracker with every other handle of the same channel.
pub struct Receiver<T> {
    /// The underlying tokio receiver.
    inner: tokio_broadcast::Receiver<T>,
    /// Metrics tracker shared through `Arc` with every sender and receiver of this channel.
    metrics: Arc<ChannelMetricsTracker>,
    /// Channel name shared through `Arc` with every handle.
    name: Arc<String>,
}
142
143impl<T: Clone> Receiver<T> {
144 pub async fn recv(&mut self) -> Result<T, RecvError> {
146 let timer = WaitTimer::start();
147
148 match self.inner.recv().await {
149 Ok(value) => {
150 let wait_time = timer.elapsed_if_waited();
151 self.metrics.record_recv(wait_time);
152 Ok(value)
153 }
154 Err(tokio_broadcast::error::RecvError::Closed) => {
155 self.metrics.mark_closed();
156 Err(RecvError::Closed)
157 }
158 Err(tokio_broadcast::error::RecvError::Lagged(n)) => Err(RecvError::Lagged(n)),
159 }
160 }
161
162 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
164 match self.inner.try_recv() {
165 Ok(value) => {
166 self.metrics.record_recv(None);
167 Ok(value)
168 }
169 Err(tokio_broadcast::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
170 Err(tokio_broadcast::error::TryRecvError::Closed) => {
171 self.metrics.mark_closed();
172 Err(TryRecvError::Closed)
173 }
174 Err(tokio_broadcast::error::TryRecvError::Lagged(n)) => Err(TryRecvError::Lagged(n)),
175 }
176 }
177
178 #[must_use]
180 pub fn name(&self) -> &str {
181 &self.name
182 }
183
184 #[must_use]
186 pub fn metrics(&self) -> ChannelMetrics {
187 self.metrics.get_metrics(0)
188 }
189}
190
191impl<T> fmt::Debug for Receiver<T> {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 f.debug_struct("broadcast::Receiver")
194 .field("name", &self.name)
195 .finish()
196 }
197}
198
/// Error returned by `Sender::send` when tokio rejects a broadcast because the
/// channel has no active receivers.
///
/// The rejected value is handed back in field `.0`, so the caller can retry or
/// reuse it instead of losing it. Like [`RecvError`] and [`TryRecvError`], this
/// type is comparable and copyable whenever `T` is. That lets callers write
/// `assert_eq!(tx.send(v), Err(SendError(v)))`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SendError<T>(pub T);

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "channel closed (no receivers)")
    }
}

impl<T: fmt::Debug> std::error::Error for SendError<T> {}
210
/// Error returned by `Receiver::recv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// Every sender has been dropped, so no further values will arrive.
    Closed,
    /// The receiver fell behind and older values were overwritten. The payload is
    /// the number of values that were skipped.
    Lagged(u64),
}

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Closed => f.write_str("channel closed"),
            Self::Lagged(missed) => write!(f, "receiver lagged, missed {missed} messages"),
        }
    }
}

impl std::error::Error for RecvError {}
230
/// Error returned by `Receiver::try_recv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// No value is available right now. Senders still exist and may send more.
    Empty,
    /// Every sender has been dropped, so no further values will arrive.
    Closed,
    /// The receiver fell behind and older values were overwritten. The payload is
    /// the number of values that were skipped.
    Lagged(u64),
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Empty => f.write_str("channel empty"),
            Self::Closed => f.write_str("channel closed"),
            Self::Lagged(missed) => write!(f, "receiver lagged, missed {missed} messages"),
        }
    }
}

impl std::error::Error for TryRecvError {}
253
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_broadcast_basic() {
        let (tx, mut rx1) = channel::<i32>(16, "test");
        let mut rx2 = tx.subscribe();

        tx.send(42).unwrap();

        assert_eq!(rx1.recv().await.unwrap(), 42);
        assert_eq!(rx2.recv().await.unwrap(), 42);

        let metrics = tx.metrics();
        assert_eq!(metrics.sent, 1);
    }

    #[tokio::test]
    async fn test_broadcast_multiple_sends() {
        let (tx, mut rx) = channel::<i32>(16, "test");

        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap();

        assert_eq!(rx.recv().await.unwrap(), 1);
        assert_eq!(rx.recv().await.unwrap(), 2);
        assert_eq!(rx.recv().await.unwrap(), 3);

        let metrics = rx.metrics();
        assert_eq!(metrics.received, 3);
    }

    #[tokio::test]
    async fn test_broadcast_receiver_count() {
        let (tx, _rx1) = channel::<i32>(16, "test");
        assert_eq!(tx.receiver_count(), 1);

        let _rx2 = tx.subscribe();
        assert_eq!(tx.receiver_count(), 2);

        let _rx3 = tx.subscribe();
        assert_eq!(tx.receiver_count(), 3);
    }

    #[tokio::test]
    async fn test_broadcast_no_receivers() {
        let (tx, rx) = channel::<i32>(16, "test");
        drop(rx);

        // The rejected value must be handed back to the caller, not lost.
        let err = tx.send(42).unwrap_err();
        assert_eq!(err.0, 42);
    }

    #[tokio::test]
    async fn test_broadcast_send_succeeds_after_resubscribe() {
        let (tx, rx) = channel::<i32>(16, "test");
        drop(rx);
        assert!(tx.send(1).is_err());

        // Having no receivers is not terminal for a broadcast channel.
        let mut rx = tx.subscribe();
        assert_eq!(tx.send(2).unwrap(), 1);
        assert_eq!(rx.recv().await.unwrap(), 2);
    }

    #[test]
    fn test_broadcast_try_recv_empty_then_value() {
        let (tx, mut rx) = channel::<i32>(4, "test");
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));

        tx.send(7).unwrap();
        assert_eq!(rx.try_recv(), Ok(7));
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
    }

    #[tokio::test]
    async fn test_broadcast_lagged_receiver() {
        // Capacity 2 is already a power of two, so exactly two values are retained.
        let (tx, mut rx) = channel::<i32>(2, "test");
        tx.send(10).unwrap();
        tx.send(20).unwrap();
        tx.send(30).unwrap();

        assert_eq!(rx.recv().await, Err(RecvError::Lagged(1)));
        assert_eq!(rx.recv().await, Ok(20));
        assert_eq!(rx.recv().await, Ok(30));
    }

    #[tokio::test]
    async fn test_broadcast_closed_after_senders_dropped() {
        let (tx, mut rx) = channel::<i32>(4, "test");
        tx.send(1).unwrap();
        drop(tx);

        // Buffered values are still delivered before the close is reported.
        assert_eq!(rx.recv().await, Ok(1));
        assert_eq!(rx.recv().await, Err(RecvError::Closed));
        assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
    }

    #[test]
    fn test_broadcast_accessors() {
        let (tx, rx) = channel::<i32>(3, "events");

        // The requested capacity is reported, even if tokio rounds its buffer up.
        assert_eq!(tx.capacity(), 3);
        assert_eq!(tx.name(), "events");
        assert_eq!(rx.name(), "events");
        assert_eq!(tx.subscribe().name(), "events");
        assert_eq!(tx.clone().name(), "events");
        assert!(format!("{tx:?}").contains("receivers: 1"));
    }
}