1use crate::channel::{ChannelInfo, ChannelType, PresenceMember};
4use crate::config::BroadcastConfig;
5use crate::message::{BroadcastMessage, ServerMessage};
6use crate::Error;
7use dashmap::DashMap;
8use serde::Serialize;
9use std::sync::Arc;
10use tokio::sync::mpsc;
11use tracing::{debug, info, warn};
12
13pub struct Client {
15 pub socket_id: String,
17 pub sender: mpsc::Sender<ServerMessage>,
19 pub channels: Vec<String>,
21}
22
23struct BroadcasterInner {
25 clients: DashMap<String, Client>,
27 channels: DashMap<String, ChannelInfo>,
29 authorizer: Option<Arc<dyn ChannelAuthorizer>>,
31 config: BroadcastConfig,
33}
34
35#[derive(Clone)]
37pub struct Broadcaster {
38 inner: Arc<BroadcasterInner>,
39}
40
41impl Broadcaster {
42 pub fn new() -> Self {
44 Self::with_config(BroadcastConfig::default())
45 }
46
47 pub fn with_config(config: BroadcastConfig) -> Self {
49 Self {
50 inner: Arc::new(BroadcasterInner {
51 clients: DashMap::new(),
52 channels: DashMap::new(),
53 authorizer: None,
54 config,
55 }),
56 }
57 }
58
59 pub fn with_authorizer<A: ChannelAuthorizer + 'static>(self, authorizer: A) -> Self {
61 Self {
62 inner: Arc::new(BroadcasterInner {
63 clients: DashMap::new(),
64 channels: DashMap::new(),
65 authorizer: Some(Arc::new(authorizer)),
66 config: self.inner.config.clone(),
67 }),
68 }
69 }
70
71 pub fn config(&self) -> &BroadcastConfig {
73 &self.inner.config
74 }
75
76 pub fn add_client(&self, socket_id: String, sender: mpsc::Sender<ServerMessage>) {
78 info!(socket_id = %socket_id, "Client connected");
79 self.inner.clients.insert(
80 socket_id.clone(),
81 Client {
82 socket_id,
83 sender,
84 channels: Vec::new(),
85 },
86 );
87 }
88
89 pub fn remove_client(&self, socket_id: &str) {
91 if let Some((_, client)) = self.inner.clients.remove(socket_id) {
92 info!(socket_id = %socket_id, "Client disconnected");
93
94 for channel_name in &client.channels {
96 self.unsubscribe_internal(socket_id, channel_name);
97 }
98 }
99 }
100
101 pub async fn subscribe(
103 &self,
104 socket_id: &str,
105 channel_name: &str,
106 auth: Option<&str>,
107 member_info: Option<PresenceMember>,
108 ) -> Result<(), Error> {
109 let channel_type = ChannelType::from_name(channel_name);
110 let config = &self.inner.config;
111
112 if config.max_channels > 0
114 && !self.inner.channels.contains_key(channel_name)
115 && self.inner.channels.len() >= config.max_channels
116 {
117 warn!(channel = %channel_name, max = config.max_channels, "Max channels limit reached");
118 return Err(Error::ChannelFull);
119 }
120
121 if channel_type.requires_auth() {
123 if let Some(authorizer) = &self.inner.authorizer {
124 let auth_data = AuthData {
125 socket_id: socket_id.to_string(),
126 channel: channel_name.to_string(),
127 auth_token: auth.map(|s| s.to_string()),
128 };
129 if !authorizer.authorize(&auth_data).await {
130 warn!(socket_id = %socket_id, channel = %channel_name, "Authorization failed");
131 return Err(Error::unauthorized("Channel authorization failed"));
132 }
133 } else if auth.is_none() {
134 return Err(Error::unauthorized("Authorization required"));
135 }
136 }
137
138 let mut channel = self
140 .inner
141 .channels
142 .entry(channel_name.to_string())
143 .or_insert_with(|| ChannelInfo::new(channel_name));
144
145 if config.max_subscribers_per_channel > 0
147 && channel.subscriber_count() >= config.max_subscribers_per_channel
148 {
149 warn!(
150 channel = %channel_name,
151 max = config.max_subscribers_per_channel,
152 "Max subscribers per channel limit reached"
153 );
154 return Err(Error::ChannelFull);
155 }
156
157 channel.add_subscriber(socket_id.to_string());
159
160 if channel_type == ChannelType::Presence {
162 if let Some(member) = member_info {
163 channel.add_member(member.clone());
164
165 let msg = ServerMessage::MemberAdded {
167 channel: channel_name.to_string(),
168 user_id: member.user_id.clone(),
169 user_info: member.user_info.clone(),
170 };
171 drop(channel); self.send_to_channel_except(channel_name, socket_id, &msg)
173 .await;
174 }
175 } else {
176 drop(channel);
177 }
178
179 if let Some(mut client) = self.inner.clients.get_mut(socket_id) {
181 if !client.channels.contains(&channel_name.to_string()) {
182 client.channels.push(channel_name.to_string());
183 }
184 }
185
186 debug!(socket_id = %socket_id, channel = %channel_name, "Subscribed to channel");
187 Ok(())
188 }
189
190 pub async fn unsubscribe(&self, socket_id: &str, channel_name: &str) {
192 self.unsubscribe_internal(socket_id, channel_name);
193 }
194
195 fn unsubscribe_internal(&self, socket_id: &str, channel_name: &str) {
196 if let Some(mut channel) = self.inner.channels.get_mut(channel_name) {
198 channel.remove_subscriber(socket_id);
199
200 if channel.channel_type == ChannelType::Presence {
202 if let Some(member) = channel.remove_member(socket_id) {
203 let msg = ServerMessage::MemberRemoved {
204 channel: channel_name.to_string(),
205 user_id: member.user_id,
206 };
207 let channel_name = channel_name.to_string();
209 let broadcaster = self.clone();
210 tokio::spawn(async move {
211 broadcaster.send_to_channel(&channel_name, &msg).await;
212 });
213 }
214 }
215
216 if channel.is_empty() {
218 drop(channel);
219 self.inner.channels.remove(channel_name);
220 }
221 }
222
223 if let Some(mut client) = self.inner.clients.get_mut(socket_id) {
225 client.channels.retain(|c| c != channel_name);
226 }
227
228 debug!(socket_id = %socket_id, channel = %channel_name, "Unsubscribed from channel");
229 }
230
231 pub async fn broadcast<T: Serialize>(
233 &self,
234 channel: &str,
235 event: &str,
236 data: T,
237 ) -> Result<(), Error> {
238 let msg = BroadcastMessage::new(channel, event, data);
239 let server_msg = ServerMessage::Event(msg);
240 self.send_to_channel(channel, &server_msg).await;
241 Ok(())
242 }
243
244 pub async fn broadcast_except<T: Serialize>(
246 &self,
247 channel: &str,
248 event: &str,
249 data: T,
250 except_socket_id: &str,
251 ) -> Result<(), Error> {
252 let msg = BroadcastMessage::new(channel, event, data);
253 let server_msg = ServerMessage::Event(msg);
254 self.send_to_channel_except(channel, except_socket_id, &server_msg)
255 .await;
256 Ok(())
257 }
258
259 async fn send_to_channel(&self, channel_name: &str, msg: &ServerMessage) {
261 if let Some(channel) = self.inner.channels.get(channel_name) {
262 for socket_id in channel.subscribers.iter() {
263 self.send_to_client(socket_id, msg.clone()).await;
264 }
265 }
266 }
267
268 async fn send_to_channel_except(
270 &self,
271 channel_name: &str,
272 except_socket_id: &str,
273 msg: &ServerMessage,
274 ) {
275 if let Some(channel) = self.inner.channels.get(channel_name) {
276 for socket_id in channel.subscribers.iter() {
277 if socket_id.as_str() != except_socket_id {
278 self.send_to_client(socket_id, msg.clone()).await;
279 }
280 }
281 }
282 }
283
284 async fn send_to_client(&self, socket_id: &str, msg: ServerMessage) {
286 if let Some(client) = self.inner.clients.get(socket_id) {
287 if let Err(e) = client.sender.send(msg).await {
288 warn!(socket_id = %socket_id, error = %e, "Failed to send message to client");
289 }
290 }
291 }
292
293 pub fn get_channel(&self, name: &str) -> Option<ChannelInfo> {
295 self.inner.channels.get(name).map(|c| c.clone())
296 }
297
298 pub fn client_count(&self) -> usize {
300 self.inner.clients.len()
301 }
302
303 pub fn channel_count(&self) -> usize {
305 self.inner.channels.len()
306 }
307}
308
309impl Default for Broadcaster {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
/// Information handed to a [`ChannelAuthorizer`] for a subscription attempt.
#[derive(Debug, Clone)]
pub struct AuthData {
    /// Socket id of the client trying to subscribe.
    pub socket_id: String,
    /// Name of the channel being subscribed to.
    pub channel: String,
    /// Auth token supplied with the request, if any.
    pub auth_token: Option<String>,
}
325
326#[async_trait::async_trait]
328pub trait ChannelAuthorizer: Send + Sync {
329 async fn authorize(&self, data: &AuthData) -> bool;
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding and then removing a client is reflected in `client_count`.
    #[tokio::test]
    async fn test_broadcaster_basic() {
        let hub = Broadcaster::new();
        let (sender, _receiver) = mpsc::channel(32);

        hub.add_client(String::from("socket_1"), sender);
        assert_eq!(hub.client_count(), 1);

        hub.remove_client("socket_1");
        assert_eq!(hub.client_count(), 0);
    }

    /// A public channel can be joined without auth and is created on demand.
    #[tokio::test]
    async fn test_subscribe_public_channel() {
        let hub = Broadcaster::new();
        let (sender, _receiver) = mpsc::channel(32);
        hub.add_client(String::from("socket_1"), sender);

        let outcome = hub.subscribe("socket_1", "orders", None, None).await;
        outcome.unwrap();

        assert_eq!(hub.channel_count(), 1);
        let orders = hub.get_channel("orders").unwrap();
        assert_eq!(orders.subscriber_count(), 1);
    }

    /// Without an authorizer or a token, a private channel is rejected.
    #[tokio::test]
    async fn test_subscribe_private_requires_auth() {
        let hub = Broadcaster::new();
        let (sender, _receiver) = mpsc::channel(32);
        hub.add_client(String::from("socket_1"), sender);

        let outcome = hub
            .subscribe("socket_1", "private-orders.1", None, None)
            .await;

        assert!(outcome.is_err());
    }
}