pezsc_utils/
mpsc.rs

1// This file is part of Bizinikiwi.
2
3// Copyright (C) Parity Technologies (UK) Ltd. and Dijital Kurdistan Tech Institute
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Code to meter unbounded channels.
19
20pub use async_channel::{TryRecvError, TrySendError};
21
22use crate::metrics::{
23	DROPPED_LABEL, RECEIVED_LABEL, SENT_LABEL, UNBOUNDED_CHANNELS_COUNTER, UNBOUNDED_CHANNELS_SIZE,
24};
25use async_channel::{Receiver, Sender};
26use futures::{
27	stream::{FusedStream, Stream},
28	task::{Context, Poll},
29};
30use log::error;
31use pezsp_arithmetic::traits::SaturatedConversion;
32use std::{
33	backtrace::Backtrace,
34	pin::Pin,
35	sync::{
36		atomic::{AtomicBool, Ordering},
37		Arc,
38	},
39};
40
/// Sending half of a [`tracing_unbounded`] channel.
///
/// Wraps [`async_channel::Sender`]. On every successful send it records the message in the
/// channel metrics and compares the queue length against `queue_size_warning`.
#[derive(Debug)]
pub struct TracingUnboundedSender<T> {
	/// The wrapped `async_channel` sender.
	inner: Sender<T>,
	/// Label under which this channel is reported in the metrics and in the warning log.
	name: &'static str,
	/// Queue length at which the "too many unprocessed messages" error is logged.
	queue_size_warning: usize,
	/// Set once the warning has been logged; shared (via `Arc`) by all clones of this sender
	/// so the warning fires at most once per channel.
	warning_fired: Arc<AtomicBool>,
	/// Backtrace captured when the channel was created, included in the warning log.
	creation_backtrace: Arc<Backtrace>,
}
51
52// Strangely, deriving `Clone` requires that `T` is also `Clone`.
53impl<T> Clone for TracingUnboundedSender<T> {
54	fn clone(&self) -> Self {
55		Self {
56			inner: self.inner.clone(),
57			name: self.name,
58			queue_size_warning: self.queue_size_warning,
59			warning_fired: self.warning_fired.clone(),
60			creation_backtrace: self.creation_backtrace.clone(),
61		}
62	}
63}
64
/// Receiving half of a [`tracing_unbounded`] channel.
///
/// Wraps [`async_channel::Receiver`]. Every message taken out (via `try_recv` or the `Stream`
/// impl) bumps the channel's received counter and refreshes its size gauge. Dropping the
/// receiver closes the channel and counts the still-queued messages as dropped.
#[derive(Debug)]
pub struct TracingUnboundedReceiver<T> {
	/// The wrapped `async_channel` receiver.
	inner: Receiver<T>,
	/// Label under which this channel is reported in the metrics.
	name: &'static str,
}
72
73/// Wrapper around [`async_channel::unbounded`] that tracks the in- and outflow via
74/// `UNBOUNDED_CHANNELS_COUNTER` and warns if the message queue grows
75/// above the warning threshold.
76pub fn tracing_unbounded<T>(
77	name: &'static str,
78	queue_size_warning: usize,
79) -> (TracingUnboundedSender<T>, TracingUnboundedReceiver<T>) {
80	let (s, r) = async_channel::unbounded();
81	let sender = TracingUnboundedSender {
82		inner: s,
83		name,
84		queue_size_warning,
85		warning_fired: Arc::new(AtomicBool::new(false)),
86		creation_backtrace: Arc::new(Backtrace::force_capture()),
87	};
88	let receiver = TracingUnboundedReceiver { inner: r, name: name.into() };
89	(sender, receiver)
90}
91
92impl<T> TracingUnboundedSender<T> {
93	/// Proxy function to [`async_channel::Sender`].
94	pub fn is_closed(&self) -> bool {
95		self.inner.is_closed()
96	}
97
98	/// Proxy function to [`async_channel::Sender`].
99	pub fn close(&self) -> bool {
100		self.inner.close()
101	}
102
103	/// Proxy function to `async_channel::Sender::try_send`.
104	pub fn unbounded_send(&self, msg: T) -> Result<(), TrySendError<T>> {
105		self.inner.try_send(msg).inspect(|_| {
106			UNBOUNDED_CHANNELS_COUNTER.with_label_values(&[self.name, SENT_LABEL]).inc();
107			UNBOUNDED_CHANNELS_SIZE
108				.with_label_values(&[self.name])
109				.set(self.inner.len().saturated_into());
110
111			if self.inner.len() >= self.queue_size_warning
112				&& self
113					.warning_fired
114					.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
115					.is_ok()
116			{
117				error!(
118					"The number of unprocessed messages in channel `{}` exceeded {}.\n\
119					 The channel was created at:\n{}\n
120					 Last message was sent from:\n{}",
121					self.name,
122					self.queue_size_warning,
123					self.creation_backtrace,
124					Backtrace::force_capture(),
125				);
126			}
127		})
128	}
129
130	/// The number of elements in the channel (proxy function to [`async_channel::Sender`]).
131	pub fn len(&self) -> usize {
132		self.inner.len()
133	}
134}
135
136impl<T> TracingUnboundedReceiver<T> {
137	/// Proxy function to [`async_channel::Receiver`].
138	pub fn close(&mut self) -> bool {
139		self.inner.close()
140	}
141
142	/// Proxy function to [`async_channel::Receiver`]
143	/// that discounts the messages taken out.
144	pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
145		self.inner.try_recv().inspect(|_| {
146			UNBOUNDED_CHANNELS_COUNTER.with_label_values(&[self.name, RECEIVED_LABEL]).inc();
147			UNBOUNDED_CHANNELS_SIZE
148				.with_label_values(&[self.name])
149				.set(self.inner.len().saturated_into());
150		})
151	}
152
153	/// The number of elements in the channel (proxy function to [`async_channel::Receiver`]).
154	pub fn len(&self) -> usize {
155		self.inner.len()
156	}
157
158	/// The name of this receiver
159	pub fn name(&self) -> &'static str {
160		self.name
161	}
162}
163
164impl<T> Drop for TracingUnboundedReceiver<T> {
165	fn drop(&mut self) {
166		// Close the channel to prevent any further messages to be sent into the channel
167		self.close();
168		// The number of messages about to be dropped
169		let count = self.inner.len();
170		// Discount the messages
171		if count > 0 {
172			UNBOUNDED_CHANNELS_COUNTER
173				.with_label_values(&[self.name, DROPPED_LABEL])
174				.inc_by(count.saturated_into());
175		}
176		// Reset the size metric to 0
177		UNBOUNDED_CHANNELS_SIZE.with_label_values(&[self.name]).set(0);
178		// Drain all the pending messages in the channel since they can never be accessed,
179		// this can be removed once https://github.com/smol-rs/async-channel/issues/23 is
180		// resolved
181		while let Ok(_) = self.inner.try_recv() {}
182	}
183}
184
// `poll_next` obtains `&mut Self` through `Pin::get_mut`, which requires `Self: Unpin`.
// Implemented without a `T: Unpin` bound so the receiver is `Unpin` for every message type.
impl<T> Unpin for TracingUnboundedReceiver<T> {}
186
187impl<T> Stream for TracingUnboundedReceiver<T> {
188	type Item = T;
189
190	fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
191		let s = self.get_mut();
192		match Pin::new(&mut s.inner).poll_next(cx) {
193			Poll::Ready(msg) => {
194				if msg.is_some() {
195					UNBOUNDED_CHANNELS_COUNTER.with_label_values(&[s.name, RECEIVED_LABEL]).inc();
196					UNBOUNDED_CHANNELS_SIZE
197						.with_label_values(&[s.name])
198						.set(s.inner.len().saturated_into());
199				}
200				Poll::Ready(msg)
201			},
202			Poll::Pending => Poll::Pending,
203		}
204	}
205}
206
207impl<T> FusedStream for TracingUnboundedReceiver<T> {
208	fn is_terminated(&self) -> bool {
209		self.inner.is_terminated()
210	}
211}
212
#[cfg(test)]
mod tests {
	use super::tracing_unbounded;
	use async_channel::{self, RecvError, TryRecvError, TrySendError};

	/// Dropping the receiver must drop the messages still queued in the channel.
	#[test]
	fn test_tracing_unbounded_receiver_drop() {
		let (tracing_unbounded_sender, tracing_unbounded_receiver) =
			tracing_unbounded("test-receiver-drop", 10);
		let (tx, rx) = async_channel::unbounded::<usize>();

		// Park `tx` inside the tracing channel; it is only dropped if the queue is drained.
		tracing_unbounded_sender.unbounded_send(tx).unwrap();
		drop(tracing_unbounded_receiver);

		// `tx` was dropped along with the queued message, so `rx` sees a closed channel.
		assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
		assert_eq!(rx.recv_blocking(), Err(RecvError));
	}

	/// Once the receiver is gone, the sender reports the channel as closed and rejects messages,
	/// handing them back to the caller.
	#[test]
	fn test_tracing_unbounded_send_after_receiver_drop() {
		let (sender, receiver) = tracing_unbounded::<usize>("test-send-after-receiver-drop", 10);
		drop(receiver);

		assert!(sender.is_closed());
		assert!(matches!(sender.unbounded_send(42), Err(TrySendError::Closed(42))));
	}
}