1pub use async_channel::{TryRecvError, TrySendError};
21
22use crate::metrics::{
23 DROPPED_LABEL, RECEIVED_LABEL, SENT_LABEL, UNBOUNDED_CHANNELS_COUNTER, UNBOUNDED_CHANNELS_SIZE,
24};
25use async_channel::{Receiver, Sender};
26use futures::{
27 stream::{FusedStream, Stream},
28 task::{Context, Poll},
29};
30use log::error;
31use pezsp_arithmetic::traits::SaturatedConversion;
32use std::{
33 backtrace::Backtrace,
34 pin::Pin,
35 sync::{
36 atomic::{AtomicBool, Ordering},
37 Arc,
38 },
39};
40
/// Sending half of a channel created by `tracing_unbounded`.
///
/// Wraps an `async_channel::Sender`, records every successfully sent message in the
/// unbounded-channel metrics and logs a one-time error when the queue grows past
/// `queue_size_warning`.
#[derive(Debug)]
pub struct TracingUnboundedSender<T> {
    /// The wrapped `async_channel` sender.
    inner: Sender<T>,
    /// Channel name; used as the metrics label and in the warning log message.
    name: &'static str,
    /// Queue length at which the one-time warning is logged.
    queue_size_warning: usize,
    /// Set once the warning has been logged; the `Arc` is shared by all clones of this sender.
    warning_fired: Arc<AtomicBool>,
    /// Backtrace captured when the channel was created; included in the warning.
    creation_backtrace: Arc<Backtrace>,
}
51
52impl<T> Clone for TracingUnboundedSender<T> {
54 fn clone(&self) -> Self {
55 Self {
56 inner: self.inner.clone(),
57 name: self.name,
58 queue_size_warning: self.queue_size_warning,
59 warning_fired: self.warning_fired.clone(),
60 creation_backtrace: self.creation_backtrace.clone(),
61 }
62 }
63}
64
/// Receiving half of a channel created by `tracing_unbounded`.
///
/// Wraps an `async_channel::Receiver` and records received (and, on drop, discarded)
/// messages in the unbounded-channel metrics.
#[derive(Debug)]
pub struct TracingUnboundedReceiver<T> {
    /// The wrapped `async_channel` receiver.
    inner: Receiver<T>,
    /// Channel name; used as the metrics label.
    name: &'static str,
}
72
73pub fn tracing_unbounded<T>(
77 name: &'static str,
78 queue_size_warning: usize,
79) -> (TracingUnboundedSender<T>, TracingUnboundedReceiver<T>) {
80 let (s, r) = async_channel::unbounded();
81 let sender = TracingUnboundedSender {
82 inner: s,
83 name,
84 queue_size_warning,
85 warning_fired: Arc::new(AtomicBool::new(false)),
86 creation_backtrace: Arc::new(Backtrace::force_capture()),
87 };
88 let receiver = TracingUnboundedReceiver { inner: r, name: name.into() };
89 (sender, receiver)
90}
91
92impl<T> TracingUnboundedSender<T> {
93 pub fn is_closed(&self) -> bool {
95 self.inner.is_closed()
96 }
97
98 pub fn close(&self) -> bool {
100 self.inner.close()
101 }
102
103 pub fn unbounded_send(&self, msg: T) -> Result<(), TrySendError<T>> {
105 self.inner.try_send(msg).inspect(|_| {
106 UNBOUNDED_CHANNELS_COUNTER.with_label_values(&[self.name, SENT_LABEL]).inc();
107 UNBOUNDED_CHANNELS_SIZE
108 .with_label_values(&[self.name])
109 .set(self.inner.len().saturated_into());
110
111 if self.inner.len() >= self.queue_size_warning
112 && self
113 .warning_fired
114 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
115 .is_ok()
116 {
117 error!(
118 "The number of unprocessed messages in channel `{}` exceeded {}.\n\
119 The channel was created at:\n{}\n
120 Last message was sent from:\n{}",
121 self.name,
122 self.queue_size_warning,
123 self.creation_backtrace,
124 Backtrace::force_capture(),
125 );
126 }
127 })
128 }
129
130 pub fn len(&self) -> usize {
132 self.inner.len()
133 }
134}
135
136impl<T> TracingUnboundedReceiver<T> {
137 pub fn close(&mut self) -> bool {
139 self.inner.close()
140 }
141
142 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
145 self.inner.try_recv().inspect(|_| {
146 UNBOUNDED_CHANNELS_COUNTER.with_label_values(&[self.name, RECEIVED_LABEL]).inc();
147 UNBOUNDED_CHANNELS_SIZE
148 .with_label_values(&[self.name])
149 .set(self.inner.len().saturated_into());
150 })
151 }
152
153 pub fn len(&self) -> usize {
155 self.inner.len()
156 }
157
158 pub fn name(&self) -> &'static str {
160 self.name
161 }
162}
163
164impl<T> Drop for TracingUnboundedReceiver<T> {
165 fn drop(&mut self) {
166 self.close();
168 let count = self.inner.len();
170 if count > 0 {
172 UNBOUNDED_CHANNELS_COUNTER
173 .with_label_values(&[self.name, DROPPED_LABEL])
174 .inc_by(count.saturated_into());
175 }
176 UNBOUNDED_CHANNELS_SIZE.with_label_values(&[self.name]).set(0);
178 while let Ok(_) = self.inner.try_recv() {}
182 }
183}
184
// Explicitly declares the receiver `Unpin`; `poll_next` below relies on this to call
// `Pin::get_mut`. NOTE(review): if every field is already `Unpin` this impl is
// redundant — confirm before removing.
impl<T> Unpin for TracingUnboundedReceiver<T> {}
186
187impl<T> Stream for TracingUnboundedReceiver<T> {
188 type Item = T;
189
190 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
191 let s = self.get_mut();
192 match Pin::new(&mut s.inner).poll_next(cx) {
193 Poll::Ready(msg) => {
194 if msg.is_some() {
195 UNBOUNDED_CHANNELS_COUNTER.with_label_values(&[s.name, RECEIVED_LABEL]).inc();
196 UNBOUNDED_CHANNELS_SIZE
197 .with_label_values(&[s.name])
198 .set(s.inner.len().saturated_into());
199 }
200 Poll::Ready(msg)
201 },
202 Poll::Pending => Poll::Pending,
203 }
204 }
205}
206
impl<T> FusedStream for TracingUnboundedReceiver<T> {
    /// Reports whether the stream has finished; delegates to the inner receiver.
    fn is_terminated(&self) -> bool {
        self.inner.is_terminated()
    }
}
212
#[cfg(test)]
mod tests {
    use super::tracing_unbounded;
    use async_channel::{self, RecvError, TryRecvError};

    /// Dropping the traced receiver must also drop whatever is still queued in it.
    #[test]
    fn test_tracing_unbounded_receiver_drop() {
        let (traced_tx, traced_rx) = tracing_unbounded("test-receiver-drop", 10);
        let (payload_tx, payload_rx) = async_channel::unbounded::<usize>();

        // Park the only sender of the payload channel inside the traced channel's queue.
        traced_tx.unbounded_send(payload_tx).unwrap();
        drop(traced_rx);

        // If the queued sender was dropped, the payload channel is now closed.
        assert_eq!(payload_rx.try_recv(), Err(TryRecvError::Closed));
        assert_eq!(payload_rx.recv_blocking(), Err(RecvError));
    }
}