use crate::channel::{ChannelMetrics, ChannelMetricsTracker, WaitTimer};
use std::fmt;
use std::sync::Arc;
use tokio::sync::mpsc as tokio_mpsc;
/// Creates a bounded MPSC channel whose two halves share one metrics tracker.
///
/// `capacity` is forwarded to [`tokio::sync::mpsc::channel`]. It is also stored
/// on both halves, so the sender can report it through `max_capacity()` and use
/// it in `metrics()`.
///
/// `name` is stored once behind an `Arc` and shared by both halves. It is
/// returned by their `name()` methods and appears in their `Debug` output.
///
/// # Panics
///
/// Panics if `capacity` is 0; the panic comes from
/// [`tokio::sync::mpsc::channel`].
pub fn channel<T>(capacity: usize, name: impl Into<String>) -> (Sender<T>, Receiver<T>) {
    let (tx, rx) = tokio_mpsc::channel(capacity);
    let metrics = Arc::new(ChannelMetricsTracker::new());
    let name = Arc::new(name.into());
    let sender = Sender {
        inner: tx,
        metrics: Arc::clone(&metrics),
        name: Arc::clone(&name),
        capacity,
    };
    let receiver = Receiver {
        inner: rx,
        metrics,
        name,
        capacity,
    };
    (sender, receiver)
}
/// Creates an unbounded MPSC channel whose two halves share one metrics tracker.
///
/// `name` is stored once behind an `Arc` and shared by both halves. It is
/// returned by their `name()` methods and appears in their `Debug` output.
pub fn unbounded_channel<T>(
    name: impl Into<String>,
) -> (UnboundedSender<T>, UnboundedReceiver<T>) {
    let (tokio_tx, tokio_rx) = tokio_mpsc::unbounded_channel();
    let tracker = Arc::new(ChannelMetricsTracker::new());
    let label = Arc::new(name.into());

    let sender = UnboundedSender {
        inner: tokio_tx,
        metrics: Arc::clone(&tracker),
        name: Arc::clone(&label),
    };
    let receiver = UnboundedReceiver {
        inner: tokio_rx,
        metrics: tracker,
        name: label,
    };
    (sender, receiver)
}
/// Sending half of a bounded channel created by [`channel`].
///
/// Wraps a `tokio::sync::mpsc::Sender` and records activity in a metrics
/// tracker that is shared with the matching [`Receiver`] and with every clone
/// of this sender.
pub struct Sender<T> {
/// Underlying tokio sender that carries the values.
inner: tokio_mpsc::Sender<T>,
/// Metrics tracker shared (via `Arc`) by every handle of this channel.
metrics: Arc<ChannelMetricsTracker>,
/// Channel name passed to [`channel`]; shared with the receiver.
name: Arc<String>,
/// Buffer size passed to [`channel`]; read by `max_capacity` and `metrics`.
capacity: usize,
}
impl<T> Sender<T> {
pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
let timer = WaitTimer::start();
match self.inner.send(value).await {
Ok(()) => {
let wait_time = timer.elapsed_if_waited();
self.metrics.record_send(wait_time);
Ok(())
}
Err(tokio_mpsc::error::SendError(value)) => {
self.metrics.mark_closed();
Err(SendError(value))
}
}
}
pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
match self.inner.try_send(value) {
Ok(()) => {
self.metrics.record_send(None);
Ok(())
}
Err(tokio_mpsc::error::TrySendError::Full(value)) => Err(TrySendError::Full(value)),
Err(tokio_mpsc::error::TrySendError::Closed(value)) => {
self.metrics.mark_closed();
Err(TrySendError::Closed(value))
}
}
}
#[must_use]
pub fn is_closed(&self) -> bool {
self.inner.is_closed()
}
#[must_use]
pub fn capacity(&self) -> usize {
self.inner.capacity()
}
#[must_use]
pub fn max_capacity(&self) -> usize {
self.capacity
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn metrics(&self) -> ChannelMetrics {
let buffered = (self.capacity - self.inner.capacity()) as u64;
self.metrics.get_metrics(buffered)
}
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
metrics: self.metrics.clone(),
name: self.name.clone(),
capacity: self.capacity,
}
}
}
// Shows only the channel's name and configured capacity, so `T` does not
// have to implement `Debug`.
impl<T> fmt::Debug for Sender<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Sender");
        builder.field("name", &self.name);
        builder.field("capacity", &self.capacity);
        builder.finish()
    }
}
/// Receiving half of a bounded channel created by [`channel`].
///
/// Wraps a `tokio::sync::mpsc::Receiver`. It shares its metrics tracker and
/// name with every [`Sender`] of the same channel.
pub struct Receiver<T> {
/// Underlying tokio receiver.
inner: tokio_mpsc::Receiver<T>,
/// Metrics tracker shared (via `Arc`) with the senders.
metrics: Arc<ChannelMetricsTracker>,
/// Channel name passed to [`channel`].
name: Arc<String>,
/// Buffer size passed to [`channel`]; in this file it is only read for the
/// `Debug` output.
capacity: usize,
}
impl<T> Receiver<T> {
pub async fn recv(&mut self) -> Option<T> {
let timer = WaitTimer::start();
if let Some(value) = self.inner.recv().await {
let wait_time = timer.elapsed_if_waited();
self.metrics.record_recv(wait_time);
Some(value)
} else {
self.metrics.mark_closed();
None
}
}
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
match self.inner.try_recv() {
Ok(value) => {
self.metrics.record_recv(None);
Ok(value)
}
Err(tokio_mpsc::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
Err(tokio_mpsc::error::TryRecvError::Disconnected) => {
self.metrics.mark_closed();
Err(TryRecvError::Disconnected)
}
}
}
pub fn close(&mut self) {
self.inner.close();
self.metrics.mark_closed();
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn metrics(&self) -> ChannelMetrics {
let sent = self.metrics.sent.load(std::sync::atomic::Ordering::Relaxed);
let received = self
.metrics
.received
.load(std::sync::atomic::Ordering::Relaxed);
let buffered = sent.saturating_sub(received);
self.metrics.get_metrics(buffered)
}
}
// Shows only the channel's name and configured capacity, so `T` does not
// have to implement `Debug`.
impl<T> fmt::Debug for Receiver<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Receiver");
        builder.field("name", &self.name);
        builder.field("capacity", &self.capacity);
        builder.finish()
    }
}
/// Sending half of an unbounded channel created by [`unbounded_channel`].
///
/// Wraps a `tokio::sync::mpsc::UnboundedSender` and records activity in a
/// metrics tracker shared with the matching [`UnboundedReceiver`] and with
/// every clone of this sender.
pub struct UnboundedSender<T> {
/// Underlying tokio sender.
inner: tokio_mpsc::UnboundedSender<T>,
/// Metrics tracker shared (via `Arc`) by every handle of this channel.
metrics: Arc<ChannelMetricsTracker>,
/// Channel name passed to [`unbounded_channel`].
name: Arc<String>,
}
impl<T> UnboundedSender<T> {
pub fn send(&self, value: T) -> Result<(), SendError<T>> {
match self.inner.send(value) {
Ok(()) => {
self.metrics.record_send(None);
Ok(())
}
Err(tokio_mpsc::error::SendError(value)) => {
self.metrics.mark_closed();
Err(SendError(value))
}
}
}
#[must_use]
pub fn is_closed(&self) -> bool {
self.inner.is_closed()
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn metrics(&self) -> ChannelMetrics {
let sent = self.metrics.sent.load(std::sync::atomic::Ordering::Relaxed);
let received = self
.metrics
.received
.load(std::sync::atomic::Ordering::Relaxed);
let buffered = sent.saturating_sub(received);
self.metrics.get_metrics(buffered)
}
}
impl<T> Clone for UnboundedSender<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
metrics: self.metrics.clone(),
name: self.name.clone(),
}
}
}
// Shows only the channel's name, so `T` does not have to implement `Debug`.
impl<T> fmt::Debug for UnboundedSender<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("UnboundedSender");
        builder.field("name", &self.name);
        builder.finish()
    }
}
/// Receiving half of an unbounded channel created by [`unbounded_channel`].
///
/// Wraps a `tokio::sync::mpsc::UnboundedReceiver`. It shares its metrics
/// tracker and name with every [`UnboundedSender`] of the same channel.
pub struct UnboundedReceiver<T> {
/// Underlying tokio receiver.
inner: tokio_mpsc::UnboundedReceiver<T>,
/// Metrics tracker shared (via `Arc`) with the senders.
metrics: Arc<ChannelMetricsTracker>,
/// Channel name passed to [`unbounded_channel`].
name: Arc<String>,
}
impl<T> UnboundedReceiver<T> {
pub async fn recv(&mut self) -> Option<T> {
let timer = WaitTimer::start();
if let Some(value) = self.inner.recv().await {
let wait_time = timer.elapsed_if_waited();
self.metrics.record_recv(wait_time);
Some(value)
} else {
self.metrics.mark_closed();
None
}
}
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
match self.inner.try_recv() {
Ok(value) => {
self.metrics.record_recv(None);
Ok(value)
}
Err(tokio_mpsc::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
Err(tokio_mpsc::error::TryRecvError::Disconnected) => {
self.metrics.mark_closed();
Err(TryRecvError::Disconnected)
}
}
}
pub fn close(&mut self) {
self.inner.close();
self.metrics.mark_closed();
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn metrics(&self) -> ChannelMetrics {
let sent = self.metrics.sent.load(std::sync::atomic::Ordering::Relaxed);
let received = self
.metrics
.received
.load(std::sync::atomic::Ordering::Relaxed);
let buffered = sent.saturating_sub(received);
self.metrics.get_metrics(buffered)
}
}
// Shows only the channel's name, so `T` does not have to implement `Debug`.
impl<T> fmt::Debug for UnboundedReceiver<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("UnboundedReceiver");
        builder.field("name", &self.name);
        builder.finish()
    }
}
/// Error returned when sending on a channel whose receiving half is closed.
///
/// The value that could not be sent is handed back in the public field.
/// `Clone`, `Copy`, `PartialEq` and `Eq` are derived, as they are for
/// `tokio::sync::mpsc::error::SendError`, so send results can be compared with
/// `assert_eq!`. Each derived impl only applies when `T` implements that trait.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SendError<T>(pub T);

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed")
    }
}

impl<T: fmt::Debug> std::error::Error for SendError<T> {}
/// Error returned by a non-waiting `try_send`.
///
/// Both variants hand the rejected value back to the caller; use
/// [`TrySendError::into_inner`] to recover it regardless of the variant.
/// `Clone`, `Copy`, `PartialEq` and `Eq` are derived, as they are for tokio's
/// `TrySendError`. Each derived impl only applies when `T` implements that trait.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrySendError<T> {
    /// The value could not be sent because the channel is currently full.
    Full(T),
    /// The value could not be sent because the receiving half is closed.
    Closed(T),
}

impl<T> TrySendError<T> {
    /// Consumes the error, returning the value that could not be sent.
    #[must_use]
    pub fn into_inner(self) -> T {
        match self {
            Self::Full(value) | Self::Closed(value) => value,
        }
    }
}

impl<T> fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TrySendError::Full(_) => write!(f, "channel full"),
            TrySendError::Closed(_) => write!(f, "channel closed"),
        }
    }
}

impl<T: fmt::Debug> std::error::Error for TrySendError<T> {}
/// Error returned by the non-waiting `try_recv` methods of the receivers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// Nothing is available to receive right now.
    Empty,
    /// The channel is closed and nothing is left to receive.
    Disconnected,
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "channel empty",
            Self::Disconnected => "channel disconnected",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TryRecvError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Values arrive in send order and both counters see every message.
    #[tokio::test]
    async fn test_bounded_channel() {
        let (tx, mut rx) = channel::<i32>(10, "test");
        tx.send(42).await.unwrap();
        tx.send(43).await.unwrap();
        assert_eq!(rx.recv().await, Some(42));
        assert_eq!(rx.recv().await, Some(43));
        let metrics = rx.metrics();
        assert_eq!(metrics.sent, 2);
        assert_eq!(metrics.received, 2);
    }

    /// Same as above for the unbounded flavour.
    #[tokio::test]
    async fn test_unbounded_channel() {
        let (tx, mut rx) = unbounded_channel::<String>("events");
        tx.send("hello".into()).unwrap();
        tx.send("world".into()).unwrap();
        assert_eq!(rx.recv().await, Some("hello".into()));
        assert_eq!(rx.recv().await, Some("world".into()));
        let metrics = rx.metrics();
        assert_eq!(metrics.sent, 2);
        assert_eq!(metrics.received, 2);
    }

    /// Dropping every sender drains the buffer, then yields `None` and marks
    /// the channel closed.
    #[tokio::test]
    async fn test_channel_close() {
        let (tx, mut rx) = channel::<i32>(10, "test");
        tx.send(1).await.unwrap();
        drop(tx);
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(rx.recv().await, None);
        let metrics = rx.metrics();
        assert!(metrics.closed);
    }

    /// Non-waiting operations: a full buffer rejects the value without counting
    /// it, and capacity is released again as values are received.
    #[tokio::test]
    async fn test_try_send_recv() {
        let (tx, mut rx) = channel::<i32>(2, "test");
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        assert!(matches!(tx.try_send(3), Err(TrySendError::Full(3))));
        assert_eq!(tx.max_capacity(), 2);
        assert_eq!(tx.capacity(), 0);
        assert_eq!(rx.metrics().sent, 2);
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv().unwrap(), 2);
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
        assert_eq!(tx.capacity(), 2);
    }

    /// Sending after the receiver is dropped hands the value back on both the
    /// waiting and non-waiting paths, and marks the channel closed.
    #[tokio::test]
    async fn test_send_after_receiver_dropped() {
        let (tx, rx) = channel::<i32>(1, "test");
        drop(rx);
        assert!(tx.is_closed());
        assert!(matches!(tx.try_send(5), Err(TrySendError::Closed(5))));
        assert!(matches!(tx.send(6).await, Err(SendError(6))));
        let metrics = tx.metrics();
        assert!(metrics.closed);
        assert_eq!(metrics.sent, 0);
    }

    /// `try_recv` reports `Disconnected` once all senders are gone and the
    /// buffer is empty.
    #[tokio::test]
    async fn test_try_recv_disconnected() {
        let (tx, mut rx) = channel::<i32>(1, "test");
        drop(tx);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
        assert!(rx.metrics().closed);
    }

    /// Closing an unbounded receiver rejects further sends.
    #[tokio::test]
    async fn test_unbounded_send_after_close() {
        let (tx, mut rx) = unbounded_channel::<u8>("events");
        rx.close();
        assert!(tx.is_closed());
        assert!(matches!(tx.send(1), Err(SendError(1))));
        assert_eq!(rx.recv().await, None);
        assert!(rx.metrics().closed);
    }

    /// `Debug` output exposes the name and capacity but never the values.
    #[tokio::test]
    async fn test_debug_output() {
        let (tx, rx) = channel::<i32>(2, "test");
        assert_eq!(format!("{:?}", tx), r#"Sender { name: "test", capacity: 2 }"#);
        assert_eq!(format!("{:?}", rx), r#"Receiver { name: "test", capacity: 2 }"#);
        assert_eq!(tx.name(), "test");
        assert_eq!(rx.name(), "test");
    }
}