Skip to main content

ferro_broadcast/
broadcaster.rs

1//! The main broadcaster for managing channels and sending messages.
2
3use crate::channel::{ChannelInfo, ChannelType, PresenceMember};
4use crate::config::BroadcastConfig;
5use crate::message::{BroadcastMessage, ServerMessage};
6use crate::Error;
7use dashmap::DashMap;
8use serde::Serialize;
9use std::sync::Arc;
10use tokio::sync::mpsc;
11use tracing::{debug, info, warn};
12
13/// A connected client.
14pub struct Client {
15    /// Unique socket ID.
16    pub socket_id: String,
17    /// Sender to push messages to this client.
18    pub sender: mpsc::Sender<ServerMessage>,
19    /// Channels this client is subscribed to.
20    pub channels: Vec<String>,
21}
22
23/// Shared state for the broadcaster.
24struct BroadcasterInner {
25    /// Connected clients by socket ID.
26    clients: DashMap<String, Client>,
27    /// Channels by name.
28    channels: DashMap<String, ChannelInfo>,
29    /// Optional authorization callback.
30    authorizer: Option<Arc<dyn ChannelAuthorizer>>,
31    /// Configuration.
32    config: BroadcastConfig,
33}
34
35/// The broadcaster manages channels and client connections.
36#[derive(Clone)]
37pub struct Broadcaster {
38    inner: Arc<BroadcasterInner>,
39}
40
41impl Broadcaster {
42    /// Create a new broadcaster with default configuration.
43    pub fn new() -> Self {
44        Self::with_config(BroadcastConfig::default())
45    }
46
47    /// Create a new broadcaster with the given configuration.
48    pub fn with_config(config: BroadcastConfig) -> Self {
49        Self {
50            inner: Arc::new(BroadcasterInner {
51                clients: DashMap::new(),
52                channels: DashMap::new(),
53                authorizer: None,
54                config,
55            }),
56        }
57    }
58
59    /// Set the channel authorizer.
60    pub fn with_authorizer<A: ChannelAuthorizer + 'static>(self, authorizer: A) -> Self {
61        Self {
62            inner: Arc::new(BroadcasterInner {
63                clients: DashMap::new(),
64                channels: DashMap::new(),
65                authorizer: Some(Arc::new(authorizer)),
66                config: self.inner.config.clone(),
67            }),
68        }
69    }
70
71    /// Get the configuration.
72    pub fn config(&self) -> &BroadcastConfig {
73        &self.inner.config
74    }
75
76    /// Register a new client connection.
77    pub fn add_client(&self, socket_id: String, sender: mpsc::Sender<ServerMessage>) {
78        info!(socket_id = %socket_id, "Client connected");
79        self.inner.clients.insert(
80            socket_id.clone(),
81            Client {
82                socket_id,
83                sender,
84                channels: Vec::new(),
85            },
86        );
87    }
88
89    /// Remove a client and clean up their subscriptions.
90    pub fn remove_client(&self, socket_id: &str) {
91        if let Some((_, client)) = self.inner.clients.remove(socket_id) {
92            info!(socket_id = %socket_id, "Client disconnected");
93
94            // Remove from all subscribed channels
95            for channel_name in &client.channels {
96                self.unsubscribe_internal(socket_id, channel_name);
97            }
98        }
99    }
100
101    /// Subscribe a client to a channel.
102    pub async fn subscribe(
103        &self,
104        socket_id: &str,
105        channel_name: &str,
106        auth: Option<&str>,
107        member_info: Option<PresenceMember>,
108    ) -> Result<(), Error> {
109        let channel_type = ChannelType::from_name(channel_name);
110        let config = &self.inner.config;
111
112        // Check max channels limit
113        if config.max_channels > 0
114            && !self.inner.channels.contains_key(channel_name)
115            && self.inner.channels.len() >= config.max_channels
116        {
117            warn!(channel = %channel_name, max = config.max_channels, "Max channels limit reached");
118            return Err(Error::ChannelFull);
119        }
120
121        // Check authorization for private/presence channels
122        if channel_type.requires_auth() {
123            if let Some(authorizer) = &self.inner.authorizer {
124                let auth_data = AuthData {
125                    socket_id: socket_id.to_string(),
126                    channel: channel_name.to_string(),
127                    auth_token: auth.map(|s| s.to_string()),
128                };
129                if !authorizer.authorize(&auth_data).await {
130                    warn!(socket_id = %socket_id, channel = %channel_name, "Authorization failed");
131                    return Err(Error::unauthorized("Channel authorization failed"));
132                }
133            } else if auth.is_none() {
134                return Err(Error::unauthorized("Authorization required"));
135            }
136        }
137
138        // Get or create channel
139        let mut channel = self
140            .inner
141            .channels
142            .entry(channel_name.to_string())
143            .or_insert_with(|| ChannelInfo::new(channel_name));
144
145        // Check max subscribers limit
146        if config.max_subscribers_per_channel > 0
147            && channel.subscriber_count() >= config.max_subscribers_per_channel
148        {
149            warn!(
150                channel = %channel_name,
151                max = config.max_subscribers_per_channel,
152                "Max subscribers per channel limit reached"
153            );
154            return Err(Error::ChannelFull);
155        }
156
157        // Add subscriber
158        channel.add_subscriber(socket_id.to_string());
159
160        // For presence channels, add member info
161        if channel_type == ChannelType::Presence {
162            if let Some(member) = member_info {
163                channel.add_member(member.clone());
164
165                // Notify other members
166                let msg = ServerMessage::MemberAdded {
167                    channel: channel_name.to_string(),
168                    user_id: member.user_id.clone(),
169                    user_info: member.user_info.clone(),
170                };
171                drop(channel); // Release lock before broadcasting
172                self.send_to_channel_except(channel_name, socket_id, &msg)
173                    .await;
174            }
175        } else {
176            drop(channel);
177        }
178
179        // Update client's channel list
180        if let Some(mut client) = self.inner.clients.get_mut(socket_id) {
181            if !client.channels.contains(&channel_name.to_string()) {
182                client.channels.push(channel_name.to_string());
183            }
184        }
185
186        debug!(socket_id = %socket_id, channel = %channel_name, "Subscribed to channel");
187        Ok(())
188    }
189
190    /// Unsubscribe a client from a channel.
191    pub async fn unsubscribe(&self, socket_id: &str, channel_name: &str) {
192        self.unsubscribe_internal(socket_id, channel_name);
193    }
194
195    fn unsubscribe_internal(&self, socket_id: &str, channel_name: &str) {
196        // Remove from channel
197        if let Some(mut channel) = self.inner.channels.get_mut(channel_name) {
198            channel.remove_subscriber(socket_id);
199
200            // For presence channels, notify about member leaving
201            if channel.channel_type == ChannelType::Presence {
202                if let Some(member) = channel.remove_member(socket_id) {
203                    let msg = ServerMessage::MemberRemoved {
204                        channel: channel_name.to_string(),
205                        user_id: member.user_id,
206                    };
207                    // We can't await here, so we'll spawn a task
208                    let channel_name = channel_name.to_string();
209                    let broadcaster = self.clone();
210                    tokio::spawn(async move {
211                        broadcaster.send_to_channel(&channel_name, &msg).await;
212                    });
213                }
214            }
215
216            // Clean up empty channels
217            if channel.is_empty() {
218                drop(channel);
219                self.inner.channels.remove(channel_name);
220            }
221        }
222
223        // Update client's channel list
224        if let Some(mut client) = self.inner.clients.get_mut(socket_id) {
225            client.channels.retain(|c| c != channel_name);
226        }
227
228        debug!(socket_id = %socket_id, channel = %channel_name, "Unsubscribed from channel");
229    }
230
231    /// Broadcast a message to a channel.
232    pub async fn broadcast<T: Serialize>(
233        &self,
234        channel: &str,
235        event: &str,
236        data: T,
237    ) -> Result<(), Error> {
238        let msg = BroadcastMessage::new(channel, event, data);
239        let server_msg = ServerMessage::Event(msg);
240        self.send_to_channel(channel, &server_msg).await;
241        Ok(())
242    }
243
244    /// Forward a client event (whisper) to other subscribers of the same channel.
245    ///
246    /// Used for client-to-client messaging such as typing indicators and cursor positions.
247    /// Only works when `allow_client_events` is enabled in configuration.
248    /// The sender is excluded from receiving the whispered message.
249    pub async fn whisper(
250        &self,
251        socket_id: &str,
252        channel_name: &str,
253        event: &str,
254        data: serde_json::Value,
255    ) -> Result<(), Error> {
256        if !self.inner.config.allow_client_events {
257            return Err(Error::Other("Client events are not allowed".into()));
258        }
259
260        // Verify client is subscribed to the channel
261        let channel = self
262            .inner
263            .channels
264            .get(channel_name)
265            .ok_or_else(|| Error::ChannelNotFound(channel_name.to_string()))?;
266        if !channel.subscribers.contains(socket_id) {
267            return Err(Error::ClientNotConnected(format!(
268                "Client {socket_id} is not subscribed to {channel_name}"
269            )));
270        }
271        drop(channel); // Release DashMap guard before await
272
273        let msg = BroadcastMessage::with_data(channel_name, event, data);
274        let server_msg = ServerMessage::Event(msg);
275        self.send_to_channel_except(channel_name, socket_id, &server_msg)
276            .await;
277
278        Ok(())
279    }
280
281    /// Broadcast to a channel, excluding a specific client.
282    pub async fn broadcast_except<T: Serialize>(
283        &self,
284        channel: &str,
285        event: &str,
286        data: T,
287        except_socket_id: &str,
288    ) -> Result<(), Error> {
289        let msg = BroadcastMessage::new(channel, event, data);
290        let server_msg = ServerMessage::Event(msg);
291        self.send_to_channel_except(channel, except_socket_id, &server_msg)
292            .await;
293        Ok(())
294    }
295
296    /// Send a message to all subscribers of a channel.
297    async fn send_to_channel(&self, channel_name: &str, msg: &ServerMessage) {
298        if let Some(channel) = self.inner.channels.get(channel_name) {
299            for socket_id in channel.subscribers.iter() {
300                self.send_to_client(socket_id, msg.clone()).await;
301            }
302        }
303    }
304
305    /// Send a message to all subscribers except one.
306    async fn send_to_channel_except(
307        &self,
308        channel_name: &str,
309        except_socket_id: &str,
310        msg: &ServerMessage,
311    ) {
312        if let Some(channel) = self.inner.channels.get(channel_name) {
313            for socket_id in channel.subscribers.iter() {
314                if socket_id.as_str() != except_socket_id {
315                    self.send_to_client(socket_id, msg.clone()).await;
316                }
317            }
318        }
319    }
320
321    /// Send a message to a specific client.
322    async fn send_to_client(&self, socket_id: &str, msg: ServerMessage) {
323        if let Some(client) = self.inner.clients.get(socket_id) {
324            if let Err(e) = client.sender.send(msg).await {
325                warn!(socket_id = %socket_id, error = %e, "Failed to send message to client");
326            }
327        }
328    }
329
330    /// Check if a client would be authorized for a channel.
331    ///
332    /// Returns true if:
333    /// - Channel is public (no auth needed)
334    /// - Channel is private/presence AND authorizer returns true
335    ///
336    /// Returns false if:
337    /// - Channel is private/presence AND no authorizer registered
338    /// - Channel is private/presence AND authorizer denies access
339    ///
340    /// Used by the broadcasting auth HTTP endpoint to validate authorization
341    /// without subscribing the client.
342    pub async fn check_auth(&self, auth_data: &AuthData) -> bool {
343        let channel_type = ChannelType::from_name(&auth_data.channel);
344        if !channel_type.requires_auth() {
345            return true;
346        }
347        if let Some(authorizer) = &self.inner.authorizer {
348            authorizer.authorize(auth_data).await
349        } else {
350            false
351        }
352    }
353
354    /// Get channel info.
355    pub fn get_channel(&self, name: &str) -> Option<ChannelInfo> {
356        self.inner.channels.get(name).map(|c| c.clone())
357    }
358
359    /// Get number of connected clients.
360    pub fn client_count(&self) -> usize {
361        self.inner.clients.len()
362    }
363
364    /// Get number of active channels.
365    pub fn channel_count(&self) -> usize {
366        self.inner.channels.len()
367    }
368}
369
370impl Default for Broadcaster {
371    fn default() -> Self {
372        Self::new()
373    }
374}
375
/// Authorization data for private/presence channels.
///
/// This is passed to a [`ChannelAuthorizer`] when a client subscribes, and to
/// `Broadcaster::check_auth`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthData {
    /// The socket ID requesting access.
    pub socket_id: String,
    /// The channel name.
    pub channel: String,
    /// Optional auth token supplied by the client.
    pub auth_token: Option<String>,
}
386
/// Trait for authorizing access to private/presence channels.
///
/// `Broadcaster` consults the registered implementation (installed via
/// `Broadcaster::with_authorizer`) in two places: `subscribe`, for channels
/// whose type requires auth, and `check_auth`.
#[async_trait::async_trait]
pub trait ChannelAuthorizer: Send + Sync {
    /// Return `true` if `data.socket_id` may access `data.channel`.
    ///
    /// `data.auth_token` carries whatever token the client supplied, if any.
    async fn authorize(&self, data: &AuthData) -> bool;
}
393
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_broadcaster_basic() {
        let broadcaster = Broadcaster::new();
        let (tx, _rx) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx);
        assert_eq!(broadcaster.client_count(), 1);

        broadcaster.remove_client("socket_1");
        assert_eq!(broadcaster.client_count(), 0);
    }

    #[tokio::test]
    async fn test_subscribe_public_channel() {
        let broadcaster = Broadcaster::new();
        let (tx, _rx) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx);
        broadcaster
            .subscribe("socket_1", "orders", None, None)
            .await
            .unwrap();

        assert_eq!(broadcaster.channel_count(), 1);
        let channel = broadcaster.get_channel("orders").unwrap();
        assert_eq!(channel.subscriber_count(), 1);
    }

    #[tokio::test]
    async fn test_subscribe_private_requires_auth() {
        let broadcaster = Broadcaster::new();
        let (tx, _rx) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx);
        let result = broadcaster
            .subscribe("socket_1", "private-orders.1", None, None)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unsubscribe_removes_empty_channel() {
        let broadcaster = Broadcaster::new();
        let (tx, _rx) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx);
        broadcaster
            .subscribe("socket_1", "orders", None, None)
            .await
            .unwrap();
        assert_eq!(broadcaster.channel_count(), 1);

        broadcaster.unsubscribe("socket_1", "orders").await;

        assert_eq!(broadcaster.channel_count(), 0);
        assert!(broadcaster.get_channel("orders").is_none());
    }

    #[tokio::test]
    async fn test_remove_client_cleans_up_subscriptions() {
        let broadcaster = Broadcaster::new();
        let (tx1, _rx1) = mpsc::channel(32);
        let (tx2, _rx2) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx1);
        broadcaster.add_client("socket_2".into(), tx2);
        broadcaster
            .subscribe("socket_1", "chat", None, None)
            .await
            .unwrap();
        broadcaster
            .subscribe("socket_2", "chat", None, None)
            .await
            .unwrap();

        // The channel survives while another subscriber remains...
        broadcaster.remove_client("socket_1");
        let channel = broadcaster.get_channel("chat").unwrap();
        assert_eq!(channel.subscriber_count(), 1);

        // ...and is removed once the last subscriber disconnects.
        broadcaster.remove_client("socket_2");
        assert_eq!(broadcaster.channel_count(), 0);
    }

    #[tokio::test]
    async fn test_broadcast_reaches_all_subscribers() {
        let broadcaster = Broadcaster::new();
        let (tx1, mut rx1) = mpsc::channel(32);
        let (tx2, mut rx2) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx1);
        broadcaster.add_client("socket_2".into(), tx2);
        broadcaster
            .subscribe("socket_1", "orders", None, None)
            .await
            .unwrap();
        broadcaster
            .subscribe("socket_2", "orders", None, None)
            .await
            .unwrap();

        broadcaster
            .broadcast("orders", "order.created", serde_json::json!({"id": 1}))
            .await
            .unwrap();

        for rx in [&mut rx1, &mut rx2] {
            match rx.try_recv().unwrap() {
                ServerMessage::Event(broadcast_msg) => {
                    assert_eq!(broadcast_msg.event, "order.created");
                    assert_eq!(broadcast_msg.channel, "orders");
                }
                other => panic!("Expected Event, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn test_broadcast_except_skips_excluded_client() {
        let broadcaster = Broadcaster::new();
        let (tx1, mut rx1) = mpsc::channel(32);
        let (tx2, mut rx2) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx1);
        broadcaster.add_client("socket_2".into(), tx2);
        broadcaster
            .subscribe("socket_1", "orders", None, None)
            .await
            .unwrap();
        broadcaster
            .subscribe("socket_2", "orders", None, None)
            .await
            .unwrap();

        broadcaster
            .broadcast_except("orders", "order.updated", serde_json::json!({}), "socket_1")
            .await
            .unwrap();

        assert!(rx1.try_recv().is_err());
        assert!(matches!(rx2.try_recv(), Ok(ServerMessage::Event(_))));
    }

    #[tokio::test]
    async fn test_whisper_forwards_to_others() {
        let broadcaster = Broadcaster::new();

        let (tx1, mut rx1) = mpsc::channel(32);
        let (tx2, mut rx2) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx1);
        broadcaster.add_client("socket_2".into(), tx2);

        // Subscribe both to public channel
        broadcaster
            .subscribe("socket_1", "chat", None, None)
            .await
            .unwrap();
        broadcaster
            .subscribe("socket_2", "chat", None, None)
            .await
            .unwrap();

        // Client 1 whispers
        broadcaster
            .whisper(
                "socket_1",
                "chat",
                "typing",
                serde_json::json!({"user": "alice"}),
            )
            .await
            .unwrap();

        // Client 2 receives the whisper
        let msg = rx2.try_recv().unwrap();
        match msg {
            ServerMessage::Event(broadcast_msg) => {
                assert_eq!(broadcast_msg.event, "typing");
                assert_eq!(broadcast_msg.channel, "chat");
                assert_eq!(broadcast_msg.data, serde_json::json!({"user": "alice"}));
            }
            other => panic!("Expected Event, got {other:?}"),
        }

        // Client 1 does NOT receive it
        assert!(rx1.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_whisper_rejected_when_disabled() {
        let config = BroadcastConfig::new().allow_client_events(false);
        let broadcaster = Broadcaster::with_config(config);

        let (tx, _rx) = mpsc::channel(32);
        broadcaster.add_client("socket_1".into(), tx);
        broadcaster
            .subscribe("socket_1", "chat", None, None)
            .await
            .unwrap();

        let result = broadcaster
            .whisper("socket_1", "chat", "typing", serde_json::json!({}))
            .await;

        assert!(result.is_err());
    }

    struct MockAuthorizer {
        allowed_channels: Vec<String>,
    }

    #[async_trait::async_trait]
    impl ChannelAuthorizer for MockAuthorizer {
        async fn authorize(&self, data: &AuthData) -> bool {
            self.allowed_channels.contains(&data.channel)
        }
    }

    #[tokio::test]
    async fn test_check_auth_public_channel_always_authorized() {
        let broadcaster = Broadcaster::new();
        let auth_data = AuthData {
            socket_id: "socket_1".to_string(),
            channel: "orders".to_string(),
            auth_token: None,
        };
        assert!(broadcaster.check_auth(&auth_data).await);
    }

    #[tokio::test]
    async fn test_check_auth_public_channel_authorized_without_authorizer() {
        // Even with no authorizer, public channels pass
        let broadcaster = Broadcaster::new();
        let auth_data = AuthData {
            socket_id: "socket_1".to_string(),
            channel: "chat".to_string(),
            auth_token: Some("user_42".to_string()),
        };
        assert!(broadcaster.check_auth(&auth_data).await);
    }

    #[tokio::test]
    async fn test_check_auth_private_channel_denied_without_authorizer() {
        let broadcaster = Broadcaster::new();
        let auth_data = AuthData {
            socket_id: "socket_1".to_string(),
            channel: "private-orders".to_string(),
            auth_token: Some("user_42".to_string()),
        };
        assert!(!broadcaster.check_auth(&auth_data).await);
    }

    #[tokio::test]
    async fn test_check_auth_private_channel_allowed_by_authorizer() {
        let authorizer = MockAuthorizer {
            allowed_channels: vec!["private-orders".to_string()],
        };
        let broadcaster = Broadcaster::new().with_authorizer(authorizer);
        let auth_data = AuthData {
            socket_id: "socket_1".to_string(),
            channel: "private-orders".to_string(),
            auth_token: Some("user_42".to_string()),
        };
        assert!(broadcaster.check_auth(&auth_data).await);
    }

    #[tokio::test]
    async fn test_check_auth_private_channel_denied_by_authorizer() {
        let authorizer = MockAuthorizer {
            allowed_channels: vec!["private-orders".to_string()],
        };
        let broadcaster = Broadcaster::new().with_authorizer(authorizer);
        let auth_data = AuthData {
            socket_id: "socket_1".to_string(),
            channel: "private-admin".to_string(),
            auth_token: Some("user_42".to_string()),
        };
        assert!(!broadcaster.check_auth(&auth_data).await);
    }

    #[tokio::test]
    async fn test_check_auth_presence_channel_denied_without_authorizer() {
        let broadcaster = Broadcaster::new();
        let auth_data = AuthData {
            socket_id: "socket_1".to_string(),
            channel: "presence-chat".to_string(),
            auth_token: Some("user_42".to_string()),
        };
        assert!(!broadcaster.check_auth(&auth_data).await);
    }

    #[tokio::test]
    async fn test_check_auth_presence_channel_allowed_by_authorizer() {
        let authorizer = MockAuthorizer {
            allowed_channels: vec!["presence-chat".to_string()],
        };
        let broadcaster = Broadcaster::new().with_authorizer(authorizer);
        let auth_data = AuthData {
            socket_id: "socket_1".to_string(),
            channel: "presence-chat".to_string(),
            auth_token: Some("user_42".to_string()),
        };
        assert!(broadcaster.check_auth(&auth_data).await);
    }

    #[tokio::test]
    async fn test_whisper_rejected_when_not_subscribed() {
        let broadcaster = Broadcaster::new();

        let (tx1, _rx1) = mpsc::channel(32);
        let (tx2, _rx2) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx1);
        broadcaster.add_client("socket_2".into(), tx2);

        // Only socket_2 subscribes
        broadcaster
            .subscribe("socket_2", "chat", None, None)
            .await
            .unwrap();

        // socket_1 tries to whisper without subscribing
        let result = broadcaster
            .whisper("socket_1", "chat", "typing", serde_json::json!({}))
            .await;

        assert!(result.is_err());
    }
}