use crate::Message;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::broadcast;
/// A cloneable fan-out hub: every message passed to [`Broadcast::send`] is
/// delivered to each live [`BroadcastReceiver`] obtained from
/// [`Broadcast::subscribe`].
///
/// Clones share the same underlying tokio broadcast channel and the same
/// subscriber counter (it is held behind an `Arc`).
#[derive(Clone, Debug)]
pub struct Broadcast {
    /// Sending half of the underlying tokio broadcast channel.
    sender: broadcast::Sender<Message>,
    /// Number of outstanding `BroadcastReceiver`s, shared with each receiver.
    subscriber_count: Arc<AtomicUsize>,
}
impl Broadcast {
    /// Creates a broadcast hub whose channel buffers up to 100 messages.
    pub fn new() -> Self {
        Self::with_capacity(100)
    }

    /// Creates a broadcast hub whose underlying channel buffers `capacity`
    /// messages.
    ///
    /// A receiver that falls behind the buffer sees
    /// [`BroadcastRecvError::Lagged`] from its next `recv`/`try_recv`.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0 or larger than `usize::MAX / 2`. Both limits
    /// are enforced by `tokio::sync::broadcast::channel`.
    pub fn with_capacity(capacity: usize) -> Self {
        // The initial receiver is dropped right away; subscribers are created
        // on demand through `subscribe`.
        let (sender, _) = broadcast::channel(capacity);
        Self {
            sender,
            subscriber_count: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Registers a new receiver that will observe every message sent after
    /// this call.
    #[must_use = "dropping the receiver immediately unsubscribes it"]
    pub fn subscribe(&self) -> BroadcastReceiver {
        let inner = self.sender.subscribe();
        // Count the subscriber only once its receiver exists, so the counter
        // does not run ahead of the channel's real receiver set.
        self.subscriber_count.fetch_add(1, Ordering::SeqCst);
        BroadcastReceiver {
            inner,
            subscriber_count: Arc::clone(&self.subscriber_count),
        }
    }

    /// Sends `msg` to every current receiver.
    ///
    /// Returns the number of receivers the message was queued for. When no
    /// receiver exists, the message is discarded and 0 is returned.
    pub fn send(&self, msg: Message) -> usize {
        self.sender.send(msg).unwrap_or(0)
    }

    /// Convenience wrapper that sends `Message::text(text)`.
    ///
    /// The return value has the same meaning as for [`Broadcast::send`].
    pub fn send_text(&self, text: impl Into<String>) -> usize {
        self.send(Message::text(text))
    }

    /// Builds a message with `Message::json(value)` and sends it.
    ///
    /// On success, returns the same count as [`Broadcast::send`].
    ///
    /// # Errors
    ///
    /// Returns the `WebSocketError` produced by `Message::json` if the
    /// message cannot be built (presumably a serialization failure). In that
    /// case nothing is sent.
    pub fn send_json<T: serde::Serialize>(
        &self,
        value: &T,
    ) -> Result<usize, crate::WebSocketError> {
        let msg = Message::json(value)?;
        Ok(self.send(msg))
    }

    /// Number of `BroadcastReceiver`s created by [`Broadcast::subscribe`]
    /// that have not yet been dropped.
    ///
    /// This is a snapshot. Receivers may be added or dropped concurrently.
    pub fn subscriber_count(&self) -> usize {
        self.subscriber_count.load(Ordering::SeqCst)
    }

    /// Returns `true` when [`Broadcast::subscriber_count`] is non-zero.
    pub fn has_subscribers(&self) -> bool {
        self.subscriber_count() > 0
    }
}
impl Default for Broadcast {
fn default() -> Self {
Self::new()
}
}
/// Receiving half handed out by `Broadcast::subscribe`.
///
/// Each receiver holds a handle to the hub's shared subscriber counter, which
/// it releases when dropped.
#[derive(Debug)]
pub struct BroadcastReceiver {
    /// Underlying tokio broadcast receiver.
    inner: broadcast::Receiver<Message>,
    /// Counter shared with the `Broadcast` this receiver was created from.
    subscriber_count: Arc<AtomicUsize>,
}
impl BroadcastReceiver {
pub async fn recv(&mut self) -> Option<Result<Message, BroadcastRecvError>> {
match self.inner.recv().await {
Ok(msg) => Some(Ok(msg)),
Err(broadcast::error::RecvError::Closed) => None,
Err(broadcast::error::RecvError::Lagged(count)) => {
Some(Err(BroadcastRecvError::Lagged(count)))
}
}
}
pub fn try_recv(&mut self) -> Option<Result<Message, BroadcastRecvError>> {
match self.inner.try_recv() {
Ok(msg) => Some(Ok(msg)),
Err(broadcast::error::TryRecvError::Empty) => None,
Err(broadcast::error::TryRecvError::Closed) => None,
Err(broadcast::error::TryRecvError::Lagged(count)) => {
Some(Err(BroadcastRecvError::Lagged(count)))
}
}
}
}
/// Releases this receiver's slot in the shared subscriber counter, so
/// `Broadcast::subscriber_count` no longer includes it.
impl Drop for BroadcastReceiver {
    fn drop(&mut self) {
        let counter = &self.subscriber_count;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
/// Error reported by `BroadcastReceiver::recv` and `BroadcastReceiver::try_recv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BroadcastRecvError {
    /// The receiver fell behind the channel buffer. The payload is the number
    /// of messages that were skipped, as reported by tokio.
    Lagged(u64),
}
/// Human-readable description of a receive error.
impl std::fmt::Display for BroadcastRecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum has a single variant, so this destructuring is irrefutable.
        let Self::Lagged(skipped) = self;
        write!(f, "Lagged behind by {} messages", skipped)
    }
}
impl std::error::Error for BroadcastRecvError {}