1use crate::channel::{ChannelMetrics, ChannelMetricsTracker, WaitTimer};
7use std::fmt;
8use std::sync::Arc;
9use tokio::sync::mpsc as tokio_mpsc;
10
11pub fn channel<T>(capacity: usize, name: impl Into<String>) -> (Sender<T>, Receiver<T>) {
33 let (tx, rx) = tokio_mpsc::channel(capacity);
34 let metrics = Arc::new(ChannelMetricsTracker::new());
35 let name = Arc::new(name.into());
36 let capacity = capacity;
37
38 (
39 Sender {
40 inner: tx,
41 metrics: metrics.clone(),
42 name: name.clone(),
43 capacity,
44 },
45 Receiver {
46 inner: rx,
47 metrics,
48 name,
49 capacity,
50 },
51 )
52}
53
54pub fn unbounded_channel<T>(name: impl Into<String>) -> (UnboundedSender<T>, UnboundedReceiver<T>) {
74 let (tx, rx) = tokio_mpsc::unbounded_channel();
75 let metrics = Arc::new(ChannelMetricsTracker::new());
76 let name = Arc::new(name.into());
77
78 (
79 UnboundedSender {
80 inner: tx,
81 metrics: metrics.clone(),
82 name: name.clone(),
83 },
84 UnboundedReceiver {
85 inner: rx,
86 metrics,
87 name,
88 },
89 )
90}
91
92pub struct Sender<T> {
94 inner: tokio_mpsc::Sender<T>,
95 metrics: Arc<ChannelMetricsTracker>,
96 name: Arc<String>,
97 capacity: usize,
98}
99
100impl<T> Sender<T> {
101 pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
107 let timer = WaitTimer::start();
108
109 match self.inner.send(value).await {
110 Ok(()) => {
111 let wait_time = timer.elapsed_if_waited();
112 self.metrics.record_send(wait_time);
113 Ok(())
114 }
115 Err(tokio_mpsc::error::SendError(value)) => {
116 self.metrics.mark_closed();
117 Err(SendError(value))
118 }
119 }
120 }
121
122 pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
124 match self.inner.try_send(value) {
125 Ok(()) => {
126 self.metrics.record_send(None);
127 Ok(())
128 }
129 Err(tokio_mpsc::error::TrySendError::Full(value)) => Err(TrySendError::Full(value)),
130 Err(tokio_mpsc::error::TrySendError::Closed(value)) => {
131 self.metrics.mark_closed();
132 Err(TrySendError::Closed(value))
133 }
134 }
135 }
136
137 #[must_use]
139 pub fn is_closed(&self) -> bool {
140 self.inner.is_closed()
141 }
142
143 #[must_use]
145 pub fn capacity(&self) -> usize {
146 self.inner.capacity()
147 }
148
149 #[must_use]
151 pub fn max_capacity(&self) -> usize {
152 self.capacity
153 }
154
155 #[must_use]
157 pub fn name(&self) -> &str {
158 &self.name
159 }
160
161 #[must_use]
163 pub fn metrics(&self) -> ChannelMetrics {
164 let buffered = (self.capacity - self.inner.capacity()) as u64;
165 self.metrics.get_metrics(buffered)
166 }
167}
168
169impl<T> Clone for Sender<T> {
170 fn clone(&self) -> Self {
171 Self {
172 inner: self.inner.clone(),
173 metrics: self.metrics.clone(),
174 name: self.name.clone(),
175 capacity: self.capacity,
176 }
177 }
178}
179
180impl<T> fmt::Debug for Sender<T> {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 f.debug_struct("Sender")
183 .field("name", &self.name)
184 .field("capacity", &self.capacity)
185 .finish()
186 }
187}
188
189pub struct Receiver<T> {
191 inner: tokio_mpsc::Receiver<T>,
192 metrics: Arc<ChannelMetricsTracker>,
193 name: Arc<String>,
194 capacity: usize,
195}
196
197impl<T> Receiver<T> {
198 pub async fn recv(&mut self) -> Option<T> {
202 let timer = WaitTimer::start();
203
204 if let Some(value) = self.inner.recv().await {
205 let wait_time = timer.elapsed_if_waited();
206 self.metrics.record_recv(wait_time);
207 Some(value)
208 } else {
209 self.metrics.mark_closed();
210 None
211 }
212 }
213
214 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
216 match self.inner.try_recv() {
217 Ok(value) => {
218 self.metrics.record_recv(None);
219 Ok(value)
220 }
221 Err(tokio_mpsc::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
222 Err(tokio_mpsc::error::TryRecvError::Disconnected) => {
223 self.metrics.mark_closed();
224 Err(TryRecvError::Disconnected)
225 }
226 }
227 }
228
229 pub fn close(&mut self) {
231 self.inner.close();
232 self.metrics.mark_closed();
233 }
234
235 #[must_use]
237 pub fn name(&self) -> &str {
238 &self.name
239 }
240
241 #[must_use]
243 pub fn metrics(&self) -> ChannelMetrics {
244 let sent = self.metrics.sent.load(std::sync::atomic::Ordering::Relaxed);
246 let received = self
247 .metrics
248 .received
249 .load(std::sync::atomic::Ordering::Relaxed);
250 let buffered = sent.saturating_sub(received);
251 self.metrics.get_metrics(buffered)
252 }
253}
254
255impl<T> fmt::Debug for Receiver<T> {
256 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
257 f.debug_struct("Receiver")
258 .field("name", &self.name)
259 .field("capacity", &self.capacity)
260 .finish()
261 }
262}
263
264pub struct UnboundedSender<T> {
266 inner: tokio_mpsc::UnboundedSender<T>,
267 metrics: Arc<ChannelMetricsTracker>,
268 name: Arc<String>,
269}
270
271impl<T> UnboundedSender<T> {
272 pub fn send(&self, value: T) -> Result<(), SendError<T>> {
278 match self.inner.send(value) {
279 Ok(()) => {
280 self.metrics.record_send(None);
281 Ok(())
282 }
283 Err(tokio_mpsc::error::SendError(value)) => {
284 self.metrics.mark_closed();
285 Err(SendError(value))
286 }
287 }
288 }
289
290 #[must_use]
292 pub fn is_closed(&self) -> bool {
293 self.inner.is_closed()
294 }
295
296 #[must_use]
298 pub fn name(&self) -> &str {
299 &self.name
300 }
301
302 #[must_use]
304 pub fn metrics(&self) -> ChannelMetrics {
305 let sent = self.metrics.sent.load(std::sync::atomic::Ordering::Relaxed);
306 let received = self
307 .metrics
308 .received
309 .load(std::sync::atomic::Ordering::Relaxed);
310 let buffered = sent.saturating_sub(received);
311 self.metrics.get_metrics(buffered)
312 }
313}
314
315impl<T> Clone for UnboundedSender<T> {
316 fn clone(&self) -> Self {
317 Self {
318 inner: self.inner.clone(),
319 metrics: self.metrics.clone(),
320 name: self.name.clone(),
321 }
322 }
323}
324
325impl<T> fmt::Debug for UnboundedSender<T> {
326 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327 f.debug_struct("UnboundedSender")
328 .field("name", &self.name)
329 .finish()
330 }
331}
332
333pub struct UnboundedReceiver<T> {
335 inner: tokio_mpsc::UnboundedReceiver<T>,
336 metrics: Arc<ChannelMetricsTracker>,
337 name: Arc<String>,
338}
339
340impl<T> UnboundedReceiver<T> {
341 pub async fn recv(&mut self) -> Option<T> {
343 let timer = WaitTimer::start();
344
345 if let Some(value) = self.inner.recv().await {
346 let wait_time = timer.elapsed_if_waited();
347 self.metrics.record_recv(wait_time);
348 Some(value)
349 } else {
350 self.metrics.mark_closed();
351 None
352 }
353 }
354
355 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
357 match self.inner.try_recv() {
358 Ok(value) => {
359 self.metrics.record_recv(None);
360 Ok(value)
361 }
362 Err(tokio_mpsc::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
363 Err(tokio_mpsc::error::TryRecvError::Disconnected) => {
364 self.metrics.mark_closed();
365 Err(TryRecvError::Disconnected)
366 }
367 }
368 }
369
370 pub fn close(&mut self) {
372 self.inner.close();
373 self.metrics.mark_closed();
374 }
375
376 #[must_use]
378 pub fn name(&self) -> &str {
379 &self.name
380 }
381
382 #[must_use]
384 pub fn metrics(&self) -> ChannelMetrics {
385 let sent = self.metrics.sent.load(std::sync::atomic::Ordering::Relaxed);
386 let received = self
387 .metrics
388 .received
389 .load(std::sync::atomic::Ordering::Relaxed);
390 let buffered = sent.saturating_sub(received);
391 self.metrics.get_metrics(buffered)
392 }
393}
394
395impl<T> fmt::Debug for UnboundedReceiver<T> {
396 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
397 f.debug_struct("UnboundedReceiver")
398 .field("name", &self.name)
399 .finish()
400 }
401}
402
/// Error returned by `send` when the receiving half has been dropped or closed.
///
/// The undelivered value is handed back in field `0`.
pub struct SendError<T>(pub T);

// Implemented by hand instead of derived so that `SendError<T>` is `Debug`,
// and therefore usable with `unwrap`/`expect` or as a boxed `std::error::Error`,
// even when the payload type `T` is not `Debug`. The payload is elided,
// matching tokio's own `SendError`.
impl<T> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SendError(..)")
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "channel closed")
    }
}

// No `T: Debug` bound: the `Debug` impl above covers every `T`.
impl<T> std::error::Error for SendError<T> {}
414
/// Error returned by `try_send` when the value could not be enqueued.
///
/// Both variants hand the rejected value back to the caller.
pub enum TrySendError<T> {
    /// The buffer has no free slot right now.
    Full(T),
    /// The receiving half has been dropped or closed.
    Closed(T),
}

impl<T> TrySendError<T> {
    /// Consumes the error and returns the value that could not be sent.
    pub fn into_inner(self) -> T {
        match self {
            Self::Full(value) | Self::Closed(value) => value,
        }
    }
}

// Implemented by hand instead of derived so that `TrySendError<T>` is `Debug`,
// and therefore usable with `unwrap`/`expect` or as a boxed `std::error::Error`,
// even when `T` is not `Debug`. The payload is elided, as in tokio.
impl<T> fmt::Debug for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TrySendError::Full(_) => f.write_str("Full(..)"),
            TrySendError::Closed(_) => f.write_str("Closed(..)"),
        }
    }
}

impl<T> fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TrySendError::Full(_) => write!(f, "channel full"),
            TrySendError::Closed(_) => write!(f, "channel closed"),
        }
    }
}

// No `T: Debug` bound: the `Debug` impl above covers every `T`.
impl<T> std::error::Error for TrySendError<T> {}
434
/// Error returned by `try_recv` when no message could be taken.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// No message is buffered right now; the channel is still open.
    Empty,
    /// The channel is closed and every buffered message has been received.
    Disconnected,
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            Self::Empty => "channel empty",
            Self::Disconnected => "channel disconnected",
        };
        f.write_str(description)
    }
}

impl std::error::Error for TryRecvError {}
454
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_bounded_channel() {
        let (sender, mut receiver) = channel::<i32>(10, "test");

        for value in [42, 43] {
            sender.send(value).await.unwrap();
        }
        for expected in [42, 43] {
            assert_eq!(receiver.recv().await, Some(expected));
        }

        let snapshot = receiver.metrics();
        assert_eq!(snapshot.sent, 2);
        assert_eq!(snapshot.received, 2);
    }

    #[tokio::test]
    async fn test_unbounded_channel() {
        let (sender, mut receiver) = unbounded_channel::<String>("events");
        let words = ["hello", "world"];

        for word in words {
            sender.send(word.to_string()).unwrap();
        }
        for word in words {
            assert_eq!(receiver.recv().await, Some(word.to_string()));
        }

        let snapshot = receiver.metrics();
        assert_eq!(snapshot.sent, 2);
        assert_eq!(snapshot.received, 2);
    }

    #[tokio::test]
    async fn test_channel_close() {
        let (sender, mut receiver) = channel::<i32>(10, "test");
        sender.send(1).await.unwrap();
        // Dropping the only sender closes the channel once the buffer drains.
        drop(sender);

        assert_eq!(receiver.recv().await, Some(1));
        assert_eq!(receiver.recv().await, None);
        assert!(receiver.metrics().closed);
    }

    #[tokio::test]
    async fn test_try_send_recv() {
        let (sender, mut receiver) = channel::<i32>(2, "test");

        for value in 1..=2 {
            sender.try_send(value).unwrap();
        }
        // The two-slot buffer is now full.
        assert!(matches!(sender.try_send(3), Err(TrySendError::Full(3))));

        for expected in 1..=2 {
            assert_eq!(receiver.try_recv().unwrap(), expected);
        }
        assert!(matches!(receiver.try_recv(), Err(TryRecvError::Empty)));
    }
}