use crate::channel::{ChannelMetrics, ChannelMetricsTracker, WaitTimer};
use std::fmt;
use std::sync::Arc;
use tokio::sync::broadcast as tokio_broadcast;
/// Creates a named broadcast channel whose handles share one metrics tracker.
///
/// The returned [`Sender`] and [`Receiver`] wrap the two halves produced by
/// `tokio::sync::broadcast::channel(capacity)`. Both halves hold the same
/// `Arc<ChannelMetricsTracker>` and the same `Arc<String>` name, so activity
/// on any handle is recorded on a single tracker.
///
/// `capacity` is forwarded to tokio unchanged and is also stored on the
/// sender for [`Sender::capacity`].
///
/// # Panics
///
/// This function performs no validation of `capacity` itself.
/// NOTE(review): tokio documents that `broadcast::channel` panics when the
/// capacity is `0` or larger than `usize::MAX / 2` — confirm against the
/// tokio version in use.
pub fn channel<T: Clone>(capacity: usize, name: impl Into<String>) -> (Sender<T>, Receiver<T>) {
    let (inner_tx, inner_rx) = tokio_broadcast::channel(capacity);
    let tracker = Arc::new(ChannelMetricsTracker::new());
    let label: Arc<String> = Arc::new(name.into());

    let sender = Sender {
        inner: inner_tx,
        metrics: Arc::clone(&tracker),
        name: Arc::clone(&label),
        capacity,
    };
    // The receiver takes ownership of the original `Arc`s; the sender holds
    // the extra reference-counted handles created above.
    let receiver = Receiver {
        inner: inner_rx,
        metrics: tracker,
        name: label,
    };

    (sender, receiver)
}
pub struct Sender<T> {
inner: tokio_broadcast::Sender<T>,
metrics: Arc<ChannelMetricsTracker>,
name: Arc<String>,
capacity: usize,
}
impl<T: Clone> Sender<T> {
pub fn send(&self, value: T) -> Result<usize, SendError<T>> {
match self.inner.send(value) {
Ok(n) => {
self.metrics.record_send(None);
Ok(n)
}
Err(tokio_broadcast::error::SendError(value)) => {
self.metrics.mark_closed();
Err(SendError(value))
}
}
}
#[must_use]
pub fn subscribe(&self) -> Receiver<T> {
Receiver {
inner: self.inner.subscribe(),
metrics: self.metrics.clone(),
name: self.name.clone(),
}
}
#[must_use]
pub fn receiver_count(&self) -> usize {
self.inner.receiver_count()
}
#[must_use]
pub fn capacity(&self) -> usize {
self.capacity
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn metrics(&self) -> ChannelMetrics {
self.metrics.get_metrics(0)
}
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
metrics: self.metrics.clone(),
name: self.name.clone(),
capacity: self.capacity,
}
}
}
/// Debug output shows the channel name, its configured capacity and the
/// current receiver count; the payload type is never inspected.
///
/// No bound is placed on `T`, so the sender can be formatted (and embedded in
/// `#[derive(Debug)]` types) regardless of the payload type.
impl<T> fmt::Debug for Sender<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("broadcast::Sender")
            .field("name", &self.name)
            .field("capacity", &self.capacity)
            // Read the count from the tokio sender directly: the wrapper's
            // own `receiver_count` lives in an `impl<T: Clone>` block and
            // would force that bound onto this impl.
            .field("receivers", &self.inner.receiver_count())
            .finish()
    }
}
pub struct Receiver<T> {
inner: tokio_broadcast::Receiver<T>,
metrics: Arc<ChannelMetricsTracker>,
name: Arc<String>,
}
impl<T: Clone> Receiver<T> {
pub async fn recv(&mut self) -> Result<T, RecvError> {
let timer = WaitTimer::start();
match self.inner.recv().await {
Ok(value) => {
let wait_time = timer.elapsed_if_waited();
self.metrics.record_recv(wait_time);
Ok(value)
}
Err(tokio_broadcast::error::RecvError::Closed) => {
self.metrics.mark_closed();
Err(RecvError::Closed)
}
Err(tokio_broadcast::error::RecvError::Lagged(n)) => Err(RecvError::Lagged(n)),
}
}
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
match self.inner.try_recv() {
Ok(value) => {
self.metrics.record_recv(None);
Ok(value)
}
Err(tokio_broadcast::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
Err(tokio_broadcast::error::TryRecvError::Closed) => {
self.metrics.mark_closed();
Err(TryRecvError::Closed)
}
Err(tokio_broadcast::error::TryRecvError::Lagged(n)) => Err(TryRecvError::Lagged(n)),
}
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn metrics(&self) -> ChannelMetrics {
self.metrics.get_metrics(0)
}
}
/// Debug output shows only the channel name; the payload type is never
/// inspected, so no bound on `T` is needed.
impl<T> fmt::Debug for Receiver<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("broadcast::Receiver");
        out.field("name", &self.name);
        out.finish()
    }
}
/// Error produced when a value could not be broadcast.
///
/// The value that failed to send is handed back in field `0` so the caller
/// can recover it.
#[derive(Debug)]
pub struct SendError<T>(pub T);

/// The message is fixed and never includes the carried value, so no bound on
/// `T` is required for display.
impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed (no receivers)")
    }
}

/// `T: Debug` is needed because `Error` requires `Debug`, and the derived
/// `Debug` impl above is only available when `T: Debug`.
impl<T: fmt::Debug> std::error::Error for SendError<T> {}
/// Error returned by an awaited receive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// The channel was reported closed.
    Closed,
    /// The receiver fell behind; the payload is the number of missed
    /// messages.
    Lagged(u64),
}

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Closed => f.write_str("channel closed"),
            Self::Lagged(n) => write!(f, "receiver lagged, missed {n} messages"),
        }
    }
}

impl std::error::Error for RecvError {}
/// Error returned by a non-waiting receive attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// No value was available at the time of the call.
    Empty,
    /// The channel was reported closed.
    Closed,
    /// The receiver fell behind; the payload is the number of missed
    /// messages.
    Lagged(u64),
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Lagged(n) => write!(f, "receiver lagged, missed {n} messages"),
            Self::Empty => f.write_str("channel empty"),
            Self::Closed => f.write_str("channel closed"),
        }
    }
}

impl std::error::Error for TryRecvError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every subscriber sees a sent value, and the send is counted once.
    #[tokio::test]
    async fn test_broadcast_basic() {
        let (tx, mut rx1) = channel::<i32>(16, "test");
        let mut rx2 = tx.subscribe();
        tx.send(42).unwrap();
        assert_eq!(rx1.recv().await.unwrap(), 42);
        assert_eq!(rx2.recv().await.unwrap(), 42);
        let metrics = tx.metrics();
        assert_eq!(metrics.sent, 1);
    }

    /// Values arrive in send order and each receive is counted.
    #[tokio::test]
    async fn test_broadcast_multiple_sends() {
        let (tx, mut rx) = channel::<i32>(16, "test");
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap();
        assert_eq!(rx.recv().await.unwrap(), 1);
        assert_eq!(rx.recv().await.unwrap(), 2);
        assert_eq!(rx.recv().await.unwrap(), 3);
        let metrics = rx.metrics();
        assert_eq!(metrics.received, 3);
    }

    /// Each `subscribe` call adds one receiver to the count.
    #[tokio::test]
    async fn test_broadcast_receiver_count() {
        let (tx, _rx1) = channel::<i32>(16, "test");
        assert_eq!(tx.receiver_count(), 1);
        let _rx2 = tx.subscribe();
        assert_eq!(tx.receiver_count(), 2);
        let _rx3 = tx.subscribe();
        assert_eq!(tx.receiver_count(), 3);
    }

    /// Sending with no receivers fails and hands the value back.
    #[tokio::test]
    async fn test_broadcast_no_receivers() {
        let (tx, rx) = channel::<i32>(16, "test");
        drop(rx);
        let err = tx.send(42).unwrap_err();
        assert_eq!(err.0, 42);
    }

    /// A failed send does not prevent later sends once a receiver exists.
    #[tokio::test]
    async fn test_broadcast_send_after_resubscribe() {
        let (tx, rx) = channel::<i32>(16, "test");
        drop(rx);
        assert!(tx.send(1).is_err());
        let mut rx = tx.subscribe();
        assert_eq!(tx.send(2).unwrap(), 1);
        assert_eq!(rx.recv().await.unwrap(), 2);
    }

    /// `try_recv` maps the empty, value and closed outcomes.
    #[tokio::test]
    async fn test_broadcast_try_recv() {
        let (tx, mut rx) = channel::<i32>(16, "test");
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        tx.send(7).unwrap();
        assert_eq!(rx.try_recv(), Ok(7));
        drop(tx);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
    }

    /// `recv` reports `Closed` once every sender is gone.
    #[tokio::test]
    async fn test_broadcast_recv_closed() {
        let (tx, mut rx) = channel::<i32>(16, "test");
        drop(tx);
        assert_eq!(rx.recv().await, Err(RecvError::Closed));
    }

    /// Overrunning the capacity surfaces as `Lagged`, after which the
    /// retained values are still delivered.
    #[tokio::test]
    async fn test_broadcast_lagged() {
        let (tx, mut rx) = channel::<i32>(2, "test");
        tx.send(10).unwrap();
        tx.send(20).unwrap();
        tx.send(30).unwrap();
        assert!(matches!(rx.recv().await, Err(RecvError::Lagged(_))));
        assert_eq!(rx.recv().await.unwrap(), 20);
        assert_eq!(rx.recv().await.unwrap(), 30);
    }

    /// Name and capacity accessors echo the values given at creation.
    #[tokio::test]
    async fn test_broadcast_accessors() {
        let (tx, rx) = channel::<i32>(16, "test");
        assert_eq!(tx.name(), "test");
        assert_eq!(rx.name(), "test");
        assert_eq!(tx.capacity(), 16);
        assert_eq!(tx.subscribe().name(), "test");
    }
}