1use crate::channel::{ChannelInfo, ChannelType, PresenceMember};
4use crate::config::BroadcastConfig;
5use crate::message::{BroadcastMessage, ServerMessage};
6use crate::Error;
7use dashmap::DashMap;
8use serde::Serialize;
9use std::sync::Arc;
10use tokio::sync::mpsc;
11use tracing::{debug, info, warn};
12
13pub struct Client {
15 pub socket_id: String,
17 pub sender: mpsc::Sender<ServerMessage>,
19 pub channels: Vec<String>,
21}
22
/// State shared by every clone of a [`Broadcaster`] (held behind an `Arc`).
struct BroadcasterInner {
    /// Connected clients, keyed by socket id.
    clients: DashMap<String, Client>,
    /// Known channels, keyed by channel name; a channel is dropped from this map
    /// when its last subscriber unsubscribes.
    channels: DashMap<String, ChannelInfo>,
    /// Authorizer consulted for channels whose type requires auth; `None` when the
    /// broadcaster was built without one.
    authorizer: Option<Arc<dyn ChannelAuthorizer>>,
    /// Limits (channel / subscriber caps) and feature flags such as client events.
    config: BroadcastConfig,
}
34
/// Routes messages between connected clients and the channels they subscribe to.
///
/// Cloning is cheap: all clones share the same clients, channels, authorizer and
/// configuration through one `Arc`.
#[derive(Clone)]
pub struct Broadcaster {
    inner: Arc<BroadcasterInner>,
}
40
41impl Broadcaster {
42 pub fn new() -> Self {
44 Self::with_config(BroadcastConfig::default())
45 }
46
47 pub fn with_config(config: BroadcastConfig) -> Self {
49 Self {
50 inner: Arc::new(BroadcasterInner {
51 clients: DashMap::new(),
52 channels: DashMap::new(),
53 authorizer: None,
54 config,
55 }),
56 }
57 }
58
59 pub fn with_authorizer<A: ChannelAuthorizer + 'static>(self, authorizer: A) -> Self {
61 Self {
62 inner: Arc::new(BroadcasterInner {
63 clients: DashMap::new(),
64 channels: DashMap::new(),
65 authorizer: Some(Arc::new(authorizer)),
66 config: self.inner.config.clone(),
67 }),
68 }
69 }
70
71 pub fn config(&self) -> &BroadcastConfig {
73 &self.inner.config
74 }
75
76 pub fn add_client(&self, socket_id: String, sender: mpsc::Sender<ServerMessage>) {
78 info!(socket_id = %socket_id, "Client connected");
79 self.inner.clients.insert(
80 socket_id.clone(),
81 Client {
82 socket_id,
83 sender,
84 channels: Vec::new(),
85 },
86 );
87 }
88
89 pub fn remove_client(&self, socket_id: &str) {
91 if let Some((_, client)) = self.inner.clients.remove(socket_id) {
92 info!(socket_id = %socket_id, "Client disconnected");
93
94 for channel_name in &client.channels {
96 self.unsubscribe_internal(socket_id, channel_name);
97 }
98 }
99 }
100
101 pub async fn subscribe(
103 &self,
104 socket_id: &str,
105 channel_name: &str,
106 auth: Option<&str>,
107 member_info: Option<PresenceMember>,
108 ) -> Result<(), Error> {
109 let channel_type = ChannelType::from_name(channel_name);
110 let config = &self.inner.config;
111
112 if config.max_channels > 0
114 && !self.inner.channels.contains_key(channel_name)
115 && self.inner.channels.len() >= config.max_channels
116 {
117 warn!(channel = %channel_name, max = config.max_channels, "Max channels limit reached");
118 return Err(Error::ChannelFull);
119 }
120
121 if channel_type.requires_auth() {
123 if let Some(authorizer) = &self.inner.authorizer {
124 let auth_data = AuthData {
125 socket_id: socket_id.to_string(),
126 channel: channel_name.to_string(),
127 auth_token: auth.map(|s| s.to_string()),
128 };
129 if !authorizer.authorize(&auth_data).await {
130 warn!(socket_id = %socket_id, channel = %channel_name, "Authorization failed");
131 return Err(Error::unauthorized("Channel authorization failed"));
132 }
133 } else if auth.is_none() {
134 return Err(Error::unauthorized("Authorization required"));
135 }
136 }
137
138 let mut channel = self
140 .inner
141 .channels
142 .entry(channel_name.to_string())
143 .or_insert_with(|| ChannelInfo::new(channel_name));
144
145 if config.max_subscribers_per_channel > 0
147 && channel.subscriber_count() >= config.max_subscribers_per_channel
148 {
149 warn!(
150 channel = %channel_name,
151 max = config.max_subscribers_per_channel,
152 "Max subscribers per channel limit reached"
153 );
154 return Err(Error::ChannelFull);
155 }
156
157 channel.add_subscriber(socket_id.to_string());
159
160 if channel_type == ChannelType::Presence {
162 if let Some(member) = member_info {
163 channel.add_member(member.clone());
164
165 let msg = ServerMessage::MemberAdded {
167 channel: channel_name.to_string(),
168 user_id: member.user_id.clone(),
169 user_info: member.user_info.clone(),
170 };
171 drop(channel); self.send_to_channel_except(channel_name, socket_id, &msg)
173 .await;
174 }
175 } else {
176 drop(channel);
177 }
178
179 if let Some(mut client) = self.inner.clients.get_mut(socket_id) {
181 if !client.channels.contains(&channel_name.to_string()) {
182 client.channels.push(channel_name.to_string());
183 }
184 }
185
186 debug!(socket_id = %socket_id, channel = %channel_name, "Subscribed to channel");
187 Ok(())
188 }
189
190 pub async fn unsubscribe(&self, socket_id: &str, channel_name: &str) {
192 self.unsubscribe_internal(socket_id, channel_name);
193 }
194
195 fn unsubscribe_internal(&self, socket_id: &str, channel_name: &str) {
196 if let Some(mut channel) = self.inner.channels.get_mut(channel_name) {
198 channel.remove_subscriber(socket_id);
199
200 if channel.channel_type == ChannelType::Presence {
202 if let Some(member) = channel.remove_member(socket_id) {
203 let msg = ServerMessage::MemberRemoved {
204 channel: channel_name.to_string(),
205 user_id: member.user_id,
206 };
207 let channel_name = channel_name.to_string();
209 let broadcaster = self.clone();
210 tokio::spawn(async move {
211 broadcaster.send_to_channel(&channel_name, &msg).await;
212 });
213 }
214 }
215
216 if channel.is_empty() {
218 drop(channel);
219 self.inner.channels.remove(channel_name);
220 }
221 }
222
223 if let Some(mut client) = self.inner.clients.get_mut(socket_id) {
225 client.channels.retain(|c| c != channel_name);
226 }
227
228 debug!(socket_id = %socket_id, channel = %channel_name, "Unsubscribed from channel");
229 }
230
231 pub async fn broadcast<T: Serialize>(
233 &self,
234 channel: &str,
235 event: &str,
236 data: T,
237 ) -> Result<(), Error> {
238 let msg = BroadcastMessage::new(channel, event, data);
239 let server_msg = ServerMessage::Event(msg);
240 self.send_to_channel(channel, &server_msg).await;
241 Ok(())
242 }
243
244 pub async fn whisper(
250 &self,
251 socket_id: &str,
252 channel_name: &str,
253 event: &str,
254 data: serde_json::Value,
255 ) -> Result<(), Error> {
256 if !self.inner.config.allow_client_events {
257 return Err(Error::Other("Client events are not allowed".into()));
258 }
259
260 let channel = self
262 .inner
263 .channels
264 .get(channel_name)
265 .ok_or_else(|| Error::ChannelNotFound(channel_name.to_string()))?;
266 if !channel.subscribers.contains(socket_id) {
267 return Err(Error::ClientNotConnected(format!(
268 "Client {socket_id} is not subscribed to {channel_name}"
269 )));
270 }
271 drop(channel); let msg = BroadcastMessage::with_data(channel_name, event, data);
274 let server_msg = ServerMessage::Event(msg);
275 self.send_to_channel_except(channel_name, socket_id, &server_msg)
276 .await;
277
278 Ok(())
279 }
280
281 pub async fn broadcast_except<T: Serialize>(
283 &self,
284 channel: &str,
285 event: &str,
286 data: T,
287 except_socket_id: &str,
288 ) -> Result<(), Error> {
289 let msg = BroadcastMessage::new(channel, event, data);
290 let server_msg = ServerMessage::Event(msg);
291 self.send_to_channel_except(channel, except_socket_id, &server_msg)
292 .await;
293 Ok(())
294 }
295
296 async fn send_to_channel(&self, channel_name: &str, msg: &ServerMessage) {
298 if let Some(channel) = self.inner.channels.get(channel_name) {
299 for socket_id in channel.subscribers.iter() {
300 self.send_to_client(socket_id, msg.clone()).await;
301 }
302 }
303 }
304
305 async fn send_to_channel_except(
307 &self,
308 channel_name: &str,
309 except_socket_id: &str,
310 msg: &ServerMessage,
311 ) {
312 if let Some(channel) = self.inner.channels.get(channel_name) {
313 for socket_id in channel.subscribers.iter() {
314 if socket_id.as_str() != except_socket_id {
315 self.send_to_client(socket_id, msg.clone()).await;
316 }
317 }
318 }
319 }
320
321 async fn send_to_client(&self, socket_id: &str, msg: ServerMessage) {
323 if let Some(client) = self.inner.clients.get(socket_id) {
324 if let Err(e) = client.sender.send(msg).await {
325 warn!(socket_id = %socket_id, error = %e, "Failed to send message to client");
326 }
327 }
328 }
329
330 pub async fn check_auth(&self, auth_data: &AuthData) -> bool {
343 let channel_type = ChannelType::from_name(&auth_data.channel);
344 if !channel_type.requires_auth() {
345 return true;
346 }
347 if let Some(authorizer) = &self.inner.authorizer {
348 authorizer.authorize(auth_data).await
349 } else {
350 false
351 }
352 }
353
354 pub fn get_channel(&self, name: &str) -> Option<ChannelInfo> {
356 self.inner.channels.get(name).map(|c| c.clone())
357 }
358
359 pub fn client_count(&self) -> usize {
361 self.inner.clients.len()
362 }
363
364 pub fn channel_count(&self) -> usize {
366 self.inner.channels.len()
367 }
368}
369
370impl Default for Broadcaster {
371 fn default() -> Self {
372 Self::new()
373 }
374}
375
/// A request to join a channel that requires authorization.
///
/// `Debug` output redacts `auth_token` so credentials do not end up in logs.
#[derive(Clone, PartialEq, Eq)]
pub struct AuthData {
    /// Socket asking to join.
    pub socket_id: String,
    /// Name of the channel being joined.
    pub channel: String,
    /// Credential supplied by the client, if any.
    pub auth_token: Option<String>,
}

impl std::fmt::Debug for AuthData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthData")
            .field("socket_id", &self.socket_id)
            .field("channel", &self.channel)
            // Show only whether a token was supplied, never its value.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
386
/// Decides whether a socket may join a channel whose type requires authorization.
///
/// The `Send + Sync` bound lets one authorizer be stored as `Arc<dyn ChannelAuthorizer>`
/// inside the shared, cloneable [`Broadcaster`].
#[async_trait::async_trait]
pub trait ChannelAuthorizer: Send + Sync {
    /// Returns `true` to allow the subscription described by `data`, `false` to deny it.
    async fn authorize(&self, data: &AuthData) -> bool;
}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers `socket_id` on `broadcaster` and returns the receiving end of its queue.
    /// Callers keep the receiver bound so the client's queue stays open.
    fn connect(broadcaster: &Broadcaster, socket_id: &str) -> mpsc::Receiver<ServerMessage> {
        let (tx, rx) = mpsc::channel(32);
        broadcaster.add_client(socket_id.into(), tx);
        rx
    }

    /// Builds the `AuthData` that `socket_1` presents when asking for `channel`.
    fn auth_request(channel: &str, auth_token: Option<&str>) -> AuthData {
        AuthData {
            socket_id: "socket_1".to_string(),
            channel: channel.to_string(),
            auth_token: auth_token.map(str::to_string),
        }
    }

    /// Broadcaster whose authorizer approves exactly one channel name.
    fn broadcaster_allowing(channel: &str) -> Broadcaster {
        Broadcaster::new().with_authorizer(MockAuthorizer {
            allowed_channels: vec![channel.to_string()],
        })
    }

    /// Authorizer that approves any channel listed in `allowed_channels`.
    struct MockAuthorizer {
        allowed_channels: Vec<String>,
    }

    #[async_trait::async_trait]
    impl ChannelAuthorizer for MockAuthorizer {
        async fn authorize(&self, data: &AuthData) -> bool {
            self.allowed_channels.iter().any(|name| *name == data.channel)
        }
    }

    #[tokio::test]
    async fn test_broadcaster_basic() {
        let broadcaster = Broadcaster::new();
        let _rx = connect(&broadcaster, "socket_1");
        assert_eq!(broadcaster.client_count(), 1);

        broadcaster.remove_client("socket_1");
        assert_eq!(broadcaster.client_count(), 0);
    }

    #[tokio::test]
    async fn test_subscribe_public_channel() {
        let broadcaster = Broadcaster::new();
        let _rx = connect(&broadcaster, "socket_1");

        broadcaster
            .subscribe("socket_1", "orders", None, None)
            .await
            .expect("public channels need no auth");

        assert_eq!(broadcaster.channel_count(), 1);
        let orders = broadcaster
            .get_channel("orders")
            .expect("channel exists after subscribe");
        assert_eq!(orders.subscriber_count(), 1);
    }

    #[tokio::test]
    async fn test_subscribe_private_requires_auth() {
        let broadcaster = Broadcaster::new();
        let _rx = connect(&broadcaster, "socket_1");

        let outcome = broadcaster
            .subscribe("socket_1", "private-orders.1", None, None)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_whisper_forwards_to_others() {
        let broadcaster = Broadcaster::new();
        let mut rx1 = connect(&broadcaster, "socket_1");
        let mut rx2 = connect(&broadcaster, "socket_2");

        for socket in ["socket_1", "socket_2"] {
            broadcaster
                .subscribe(socket, "chat", None, None)
                .await
                .unwrap();
        }

        let payload = serde_json::json!({"user": "alice"});
        broadcaster
            .whisper("socket_1", "chat", "typing", payload.clone())
            .await
            .unwrap();

        match rx2.try_recv().unwrap() {
            ServerMessage::Event(broadcast_msg) => {
                assert_eq!(broadcast_msg.event, "typing");
                assert_eq!(broadcast_msg.channel, "chat");
                assert_eq!(broadcast_msg.data, payload);
            }
            other => panic!("Expected Event, got {other:?}"),
        }

        // The whisperer must not receive its own event.
        assert!(rx1.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_whisper_rejected_when_disabled() {
        let broadcaster =
            Broadcaster::with_config(BroadcastConfig::new().allow_client_events(false));
        let _rx = connect(&broadcaster, "socket_1");
        broadcaster
            .subscribe("socket_1", "chat", None, None)
            .await
            .unwrap();

        let outcome = broadcaster
            .whisper("socket_1", "chat", "typing", serde_json::json!({}))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_check_auth_public_channel_always_authorized() {
        let broadcaster = Broadcaster::new();
        assert!(broadcaster.check_auth(&auth_request("orders", None)).await);
    }

    #[tokio::test]
    async fn test_check_auth_public_channel_authorized_without_authorizer() {
        let broadcaster = Broadcaster::new();
        let request = auth_request("chat", Some("user_42"));
        assert!(broadcaster.check_auth(&request).await);
    }

    #[tokio::test]
    async fn test_check_auth_private_channel_denied_without_authorizer() {
        let broadcaster = Broadcaster::new();
        let request = auth_request("private-orders", Some("user_42"));
        assert!(!broadcaster.check_auth(&request).await);
    }

    #[tokio::test]
    async fn test_check_auth_private_channel_allowed_by_authorizer() {
        let broadcaster = broadcaster_allowing("private-orders");
        let request = auth_request("private-orders", Some("user_42"));
        assert!(broadcaster.check_auth(&request).await);
    }

    #[tokio::test]
    async fn test_check_auth_private_channel_denied_by_authorizer() {
        let broadcaster = broadcaster_allowing("private-orders");
        let request = auth_request("private-admin", Some("user_42"));
        assert!(!broadcaster.check_auth(&request).await);
    }

    #[tokio::test]
    async fn test_check_auth_presence_channel_denied_without_authorizer() {
        let broadcaster = Broadcaster::new();
        let request = auth_request("presence-chat", Some("user_42"));
        assert!(!broadcaster.check_auth(&request).await);
    }

    #[tokio::test]
    async fn test_check_auth_presence_channel_allowed_by_authorizer() {
        let broadcaster = broadcaster_allowing("presence-chat");
        let request = auth_request("presence-chat", Some("user_42"));
        assert!(broadcaster.check_auth(&request).await);
    }

    #[tokio::test]
    async fn test_whisper_rejected_when_not_subscribed() {
        let broadcaster = Broadcaster::new();
        let _rx1 = connect(&broadcaster, "socket_1");
        let _rx2 = connect(&broadcaster, "socket_2");

        // Only socket_2 joins; socket_1 then tries to whisper into the channel.
        broadcaster
            .subscribe("socket_2", "chat", None, None)
            .await
            .unwrap();

        let outcome = broadcaster
            .whisper("socket_1", "chat", "typing", serde_json::json!({}))
            .await;
        assert!(outcome.is_err());
    }
}