1use crate::Message;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use tokio::sync::broadcast;
7
8#[derive(Clone)]
31pub struct Broadcast {
32 sender: broadcast::Sender<Message>,
33 subscriber_count: Arc<AtomicUsize>,
34}
35
36impl Broadcast {
37 pub fn new() -> Self {
39 Self::with_capacity(100)
40 }
41
42 pub fn with_capacity(capacity: usize) -> Self {
44 let (sender, _) = broadcast::channel(capacity);
45 Self {
46 sender,
47 subscriber_count: Arc::new(AtomicUsize::new(0)),
48 }
49 }
50
51 pub fn subscribe(&self) -> BroadcastReceiver {
53 self.subscriber_count.fetch_add(1, Ordering::SeqCst);
54 BroadcastReceiver {
55 inner: self.sender.subscribe(),
56 subscriber_count: self.subscriber_count.clone(),
57 }
58 }
59
60 pub fn send(&self, msg: Message) -> usize {
65 self.sender.send(msg).unwrap_or(0)
66 }
67
68 pub fn send_text(&self, text: impl Into<String>) -> usize {
70 self.send(Message::text(text))
71 }
72
73 pub fn send_json<T: serde::Serialize>(
75 &self,
76 value: &T,
77 ) -> Result<usize, crate::WebSocketError> {
78 let msg = Message::json(value)?;
79 Ok(self.send(msg))
80 }
81
82 pub fn subscriber_count(&self) -> usize {
84 self.subscriber_count.load(Ordering::SeqCst)
85 }
86
87 pub fn has_subscribers(&self) -> bool {
89 self.subscriber_count() > 0
90 }
91}
92
93impl Default for Broadcast {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99pub struct BroadcastReceiver {
101 inner: broadcast::Receiver<Message>,
102 subscriber_count: Arc<AtomicUsize>,
103}
104
105impl BroadcastReceiver {
106 pub async fn recv(&mut self) -> Option<Result<Message, BroadcastRecvError>> {
111 match self.inner.recv().await {
112 Ok(msg) => Some(Ok(msg)),
113 Err(broadcast::error::RecvError::Closed) => None,
114 Err(broadcast::error::RecvError::Lagged(count)) => {
115 Some(Err(BroadcastRecvError::Lagged(count)))
116 }
117 }
118 }
119
120 pub fn try_recv(&mut self) -> Option<Result<Message, BroadcastRecvError>> {
122 match self.inner.try_recv() {
123 Ok(msg) => Some(Ok(msg)),
124 Err(broadcast::error::TryRecvError::Empty) => None,
125 Err(broadcast::error::TryRecvError::Closed) => None,
126 Err(broadcast::error::TryRecvError::Lagged(count)) => {
127 Some(Err(BroadcastRecvError::Lagged(count)))
128 }
129 }
130 }
131}
132
133impl Drop for BroadcastReceiver {
134 fn drop(&mut self) {
135 self.subscriber_count.fetch_sub(1, Ordering::SeqCst);
136 }
137}
138
/// Non-fatal error reported by a broadcast receiver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BroadcastRecvError {
    /// The receiver fell behind. The contained count is the number of
    /// messages that were overwritten before they could be read. Receiving
    /// again resumes at the oldest message still retained.
    Lagged(u64),
}

impl std::fmt::Display for BroadcastRecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Lagged(count) => write!(f, "Lagged behind by {} messages", count),
        }
    }
}

impl std::error::Error for BroadcastRecvError {}