ferro_broadcast/
broadcaster.rs

1//! The main broadcaster for managing channels and sending messages.
2
3use crate::channel::{ChannelInfo, ChannelType, PresenceMember};
4use crate::config::BroadcastConfig;
5use crate::message::{BroadcastMessage, ServerMessage};
6use crate::Error;
7use dashmap::DashMap;
8use serde::Serialize;
9use std::sync::Arc;
10use tokio::sync::mpsc;
11use tracing::{debug, info, warn};
12
13/// A connected client.
14pub struct Client {
15    /// Unique socket ID.
16    pub socket_id: String,
17    /// Sender to push messages to this client.
18    pub sender: mpsc::Sender<ServerMessage>,
19    /// Channels this client is subscribed to.
20    pub channels: Vec<String>,
21}
22
/// Shared state for the broadcaster.
///
/// Held behind an `Arc` by [`Broadcaster`], so every clone of a `Broadcaster`
/// observes the same clients, channels, authorizer and configuration.
struct BroadcasterInner {
    /// Connected clients, keyed by socket ID.
    clients: DashMap<String, Client>,
    /// Active channels, keyed by channel name.
    channels: DashMap<String, ChannelInfo>,
    /// Optional authorization callback for channels whose type requires auth.
    /// When `None`, such channels only require that *some* auth token is supplied.
    authorizer: Option<Arc<dyn ChannelAuthorizer>>,
    /// Configuration; `max_channels` and `max_subscribers_per_channel` are
    /// enforced when subscribing.
    config: BroadcastConfig,
}
34
/// The broadcaster manages channels and client connections.
///
/// Cloning is cheap: the derived `Clone` only clones the inner `Arc`, so all
/// clones share the same clients and channels.
#[derive(Clone)]
pub struct Broadcaster {
    /// Shared state common to all clones.
    inner: Arc<BroadcasterInner>,
}
40
41impl Broadcaster {
42    /// Create a new broadcaster with default configuration.
43    pub fn new() -> Self {
44        Self::with_config(BroadcastConfig::default())
45    }
46
47    /// Create a new broadcaster with the given configuration.
48    pub fn with_config(config: BroadcastConfig) -> Self {
49        Self {
50            inner: Arc::new(BroadcasterInner {
51                clients: DashMap::new(),
52                channels: DashMap::new(),
53                authorizer: None,
54                config,
55            }),
56        }
57    }
58
59    /// Set the channel authorizer.
60    pub fn with_authorizer<A: ChannelAuthorizer + 'static>(self, authorizer: A) -> Self {
61        Self {
62            inner: Arc::new(BroadcasterInner {
63                clients: DashMap::new(),
64                channels: DashMap::new(),
65                authorizer: Some(Arc::new(authorizer)),
66                config: self.inner.config.clone(),
67            }),
68        }
69    }
70
71    /// Get the configuration.
72    pub fn config(&self) -> &BroadcastConfig {
73        &self.inner.config
74    }
75
76    /// Register a new client connection.
77    pub fn add_client(&self, socket_id: String, sender: mpsc::Sender<ServerMessage>) {
78        info!(socket_id = %socket_id, "Client connected");
79        self.inner.clients.insert(
80            socket_id.clone(),
81            Client {
82                socket_id,
83                sender,
84                channels: Vec::new(),
85            },
86        );
87    }
88
89    /// Remove a client and clean up their subscriptions.
90    pub fn remove_client(&self, socket_id: &str) {
91        if let Some((_, client)) = self.inner.clients.remove(socket_id) {
92            info!(socket_id = %socket_id, "Client disconnected");
93
94            // Remove from all subscribed channels
95            for channel_name in &client.channels {
96                self.unsubscribe_internal(socket_id, channel_name);
97            }
98        }
99    }
100
101    /// Subscribe a client to a channel.
102    pub async fn subscribe(
103        &self,
104        socket_id: &str,
105        channel_name: &str,
106        auth: Option<&str>,
107        member_info: Option<PresenceMember>,
108    ) -> Result<(), Error> {
109        let channel_type = ChannelType::from_name(channel_name);
110        let config = &self.inner.config;
111
112        // Check max channels limit
113        if config.max_channels > 0
114            && !self.inner.channels.contains_key(channel_name)
115            && self.inner.channels.len() >= config.max_channels
116        {
117            warn!(channel = %channel_name, max = config.max_channels, "Max channels limit reached");
118            return Err(Error::ChannelFull);
119        }
120
121        // Check authorization for private/presence channels
122        if channel_type.requires_auth() {
123            if let Some(authorizer) = &self.inner.authorizer {
124                let auth_data = AuthData {
125                    socket_id: socket_id.to_string(),
126                    channel: channel_name.to_string(),
127                    auth_token: auth.map(|s| s.to_string()),
128                };
129                if !authorizer.authorize(&auth_data).await {
130                    warn!(socket_id = %socket_id, channel = %channel_name, "Authorization failed");
131                    return Err(Error::unauthorized("Channel authorization failed"));
132                }
133            } else if auth.is_none() {
134                return Err(Error::unauthorized("Authorization required"));
135            }
136        }
137
138        // Get or create channel
139        let mut channel = self
140            .inner
141            .channels
142            .entry(channel_name.to_string())
143            .or_insert_with(|| ChannelInfo::new(channel_name));
144
145        // Check max subscribers limit
146        if config.max_subscribers_per_channel > 0
147            && channel.subscriber_count() >= config.max_subscribers_per_channel
148        {
149            warn!(
150                channel = %channel_name,
151                max = config.max_subscribers_per_channel,
152                "Max subscribers per channel limit reached"
153            );
154            return Err(Error::ChannelFull);
155        }
156
157        // Add subscriber
158        channel.add_subscriber(socket_id.to_string());
159
160        // For presence channels, add member info
161        if channel_type == ChannelType::Presence {
162            if let Some(member) = member_info {
163                channel.add_member(member.clone());
164
165                // Notify other members
166                let msg = ServerMessage::MemberAdded {
167                    channel: channel_name.to_string(),
168                    user_id: member.user_id.clone(),
169                    user_info: member.user_info.clone(),
170                };
171                drop(channel); // Release lock before broadcasting
172                self.send_to_channel_except(channel_name, socket_id, &msg)
173                    .await;
174            }
175        } else {
176            drop(channel);
177        }
178
179        // Update client's channel list
180        if let Some(mut client) = self.inner.clients.get_mut(socket_id) {
181            if !client.channels.contains(&channel_name.to_string()) {
182                client.channels.push(channel_name.to_string());
183            }
184        }
185
186        debug!(socket_id = %socket_id, channel = %channel_name, "Subscribed to channel");
187        Ok(())
188    }
189
190    /// Unsubscribe a client from a channel.
191    pub async fn unsubscribe(&self, socket_id: &str, channel_name: &str) {
192        self.unsubscribe_internal(socket_id, channel_name);
193    }
194
195    fn unsubscribe_internal(&self, socket_id: &str, channel_name: &str) {
196        // Remove from channel
197        if let Some(mut channel) = self.inner.channels.get_mut(channel_name) {
198            channel.remove_subscriber(socket_id);
199
200            // For presence channels, notify about member leaving
201            if channel.channel_type == ChannelType::Presence {
202                if let Some(member) = channel.remove_member(socket_id) {
203                    let msg = ServerMessage::MemberRemoved {
204                        channel: channel_name.to_string(),
205                        user_id: member.user_id,
206                    };
207                    // We can't await here, so we'll spawn a task
208                    let channel_name = channel_name.to_string();
209                    let broadcaster = self.clone();
210                    tokio::spawn(async move {
211                        broadcaster.send_to_channel(&channel_name, &msg).await;
212                    });
213                }
214            }
215
216            // Clean up empty channels
217            if channel.is_empty() {
218                drop(channel);
219                self.inner.channels.remove(channel_name);
220            }
221        }
222
223        // Update client's channel list
224        if let Some(mut client) = self.inner.clients.get_mut(socket_id) {
225            client.channels.retain(|c| c != channel_name);
226        }
227
228        debug!(socket_id = %socket_id, channel = %channel_name, "Unsubscribed from channel");
229    }
230
231    /// Broadcast a message to a channel.
232    pub async fn broadcast<T: Serialize>(
233        &self,
234        channel: &str,
235        event: &str,
236        data: T,
237    ) -> Result<(), Error> {
238        let msg = BroadcastMessage::new(channel, event, data);
239        let server_msg = ServerMessage::Event(msg);
240        self.send_to_channel(channel, &server_msg).await;
241        Ok(())
242    }
243
244    /// Broadcast to a channel, excluding a specific client.
245    pub async fn broadcast_except<T: Serialize>(
246        &self,
247        channel: &str,
248        event: &str,
249        data: T,
250        except_socket_id: &str,
251    ) -> Result<(), Error> {
252        let msg = BroadcastMessage::new(channel, event, data);
253        let server_msg = ServerMessage::Event(msg);
254        self.send_to_channel_except(channel, except_socket_id, &server_msg)
255            .await;
256        Ok(())
257    }
258
259    /// Send a message to all subscribers of a channel.
260    async fn send_to_channel(&self, channel_name: &str, msg: &ServerMessage) {
261        if let Some(channel) = self.inner.channels.get(channel_name) {
262            for socket_id in channel.subscribers.iter() {
263                self.send_to_client(socket_id, msg.clone()).await;
264            }
265        }
266    }
267
268    /// Send a message to all subscribers except one.
269    async fn send_to_channel_except(
270        &self,
271        channel_name: &str,
272        except_socket_id: &str,
273        msg: &ServerMessage,
274    ) {
275        if let Some(channel) = self.inner.channels.get(channel_name) {
276            for socket_id in channel.subscribers.iter() {
277                if socket_id.as_str() != except_socket_id {
278                    self.send_to_client(socket_id, msg.clone()).await;
279                }
280            }
281        }
282    }
283
284    /// Send a message to a specific client.
285    async fn send_to_client(&self, socket_id: &str, msg: ServerMessage) {
286        if let Some(client) = self.inner.clients.get(socket_id) {
287            if let Err(e) = client.sender.send(msg).await {
288                warn!(socket_id = %socket_id, error = %e, "Failed to send message to client");
289            }
290        }
291    }
292
293    /// Get channel info.
294    pub fn get_channel(&self, name: &str) -> Option<ChannelInfo> {
295        self.inner.channels.get(name).map(|c| c.clone())
296    }
297
298    /// Get number of connected clients.
299    pub fn client_count(&self) -> usize {
300        self.inner.clients.len()
301    }
302
303    /// Get number of active channels.
304    pub fn channel_count(&self) -> usize {
305        self.inner.channels.len()
306    }
307}
308
309impl Default for Broadcaster {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
/// Authorization data for private/presence channels.
///
/// `Debug` is implemented by hand so the client-supplied `auth_token` is
/// redacted and cannot leak into logs through `{:?}` formatting.
#[derive(Clone)]
pub struct AuthData {
    /// The socket ID requesting access.
    pub socket_id: String,
    /// The channel name.
    pub channel: String,
    /// Optional auth token from the client.
    pub auth_token: Option<String>,
}

impl std::fmt::Debug for AuthData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthData")
            .field("socket_id", &self.socket_id)
            .field("channel", &self.channel)
            // Reveal only whether a token was supplied, never its value.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
325
/// Trait for authorizing channel access.
///
/// Consulted by `Broadcaster::subscribe` only for channels whose
/// `ChannelType::requires_auth()` is true. Implementations must be `Send + Sync`
/// because the authorizer lives in the broadcaster's shared state, which is
/// moved into spawned Tokio tasks.
#[async_trait::async_trait]
pub trait ChannelAuthorizer: Send + Sync {
    /// Check if access should be granted.
    ///
    /// Return `true` to allow the subscription; returning `false` makes
    /// `subscribe` fail with an unauthorized error.
    async fn authorize(&self, data: &AuthData) -> bool;
}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Authorizer that always returns the same decision.
    struct StaticAuthorizer(bool);

    #[async_trait::async_trait]
    impl ChannelAuthorizer for StaticAuthorizer {
        async fn authorize(&self, _data: &AuthData) -> bool {
            self.0
        }
    }

    #[tokio::test]
    async fn test_broadcaster_basic() {
        let broadcaster = Broadcaster::new();
        let (tx, _rx) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx);
        assert_eq!(broadcaster.client_count(), 1);

        broadcaster.remove_client("socket_1");
        assert_eq!(broadcaster.client_count(), 0);
    }

    #[tokio::test]
    async fn test_subscribe_public_channel() {
        let broadcaster = Broadcaster::new();
        let (tx, _rx) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx);
        broadcaster
            .subscribe("socket_1", "orders", None, None)
            .await
            .unwrap();

        assert_eq!(broadcaster.channel_count(), 1);
        let channel = broadcaster.get_channel("orders").unwrap();
        assert_eq!(channel.subscriber_count(), 1);
    }

    #[tokio::test]
    async fn test_subscribe_private_requires_auth() {
        let broadcaster = Broadcaster::new();
        let (tx, _rx) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx);
        let result = broadcaster
            .subscribe("socket_1", "private-orders.1", None, None)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_authorizer_decides_private_access() {
        let allow = Broadcaster::new().with_authorizer(StaticAuthorizer(true));
        let (tx, _rx) = mpsc::channel(32);
        allow.add_client("socket_1".into(), tx);
        assert!(allow
            .subscribe("socket_1", "private-orders.1", None, None)
            .await
            .is_ok());

        let deny = Broadcaster::new().with_authorizer(StaticAuthorizer(false));
        let (tx, _rx2) = mpsc::channel(32);
        deny.add_client("socket_1".into(), tx);
        assert!(deny
            .subscribe("socket_1", "private-orders.1", Some("token"), None)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_unsubscribe_removes_empty_channel() {
        let broadcaster = Broadcaster::new();
        let (tx, _rx) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx);
        broadcaster
            .subscribe("socket_1", "orders", None, None)
            .await
            .unwrap();
        broadcaster.unsubscribe("socket_1", "orders").await;

        assert_eq!(broadcaster.channel_count(), 0);
        assert!(broadcaster.get_channel("orders").is_none());
    }

    #[tokio::test]
    async fn test_remove_client_cleans_up_subscriptions() {
        let broadcaster = Broadcaster::new();
        let (tx, _rx) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx);
        broadcaster
            .subscribe("socket_1", "orders", None, None)
            .await
            .unwrap();
        broadcaster
            .subscribe("socket_1", "invoices", None, None)
            .await
            .unwrap();
        assert_eq!(broadcaster.channel_count(), 2);

        broadcaster.remove_client("socket_1");
        assert_eq!(broadcaster.client_count(), 0);
        assert_eq!(broadcaster.channel_count(), 0);
    }

    #[tokio::test]
    async fn test_broadcast_reaches_subscribers() {
        let broadcaster = Broadcaster::new();
        let (tx, mut rx) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx);
        broadcaster
            .subscribe("socket_1", "orders", None, None)
            .await
            .unwrap();
        broadcaster
            .broadcast("orders", "created", "payload")
            .await
            .unwrap();

        assert!(matches!(rx.recv().await, Some(ServerMessage::Event(_))));
    }

    #[tokio::test]
    async fn test_broadcast_except_skips_excluded_client() {
        let broadcaster = Broadcaster::new();
        let (tx1, mut rx1) = mpsc::channel(32);
        let (tx2, mut rx2) = mpsc::channel(32);

        broadcaster.add_client("socket_1".into(), tx1);
        broadcaster.add_client("socket_2".into(), tx2);
        broadcaster
            .subscribe("socket_1", "orders", None, None)
            .await
            .unwrap();
        broadcaster
            .subscribe("socket_2", "orders", None, None)
            .await
            .unwrap();

        broadcaster
            .broadcast_except("orders", "created", "payload", "socket_1")
            .await
            .unwrap();

        assert!(matches!(rx2.recv().await, Some(ServerMessage::Event(_))));
        assert!(rx1.try_recv().is_err());
    }
}