guts_realtime/
hub.rs

1//! Event hub for managing WebSocket connections and broadcasting.
2
3use crate::client::{create_client, Client, ClientId, ClientReceiver};
4use crate::error::RealtimeError;
5use crate::event::{EventKind, RealtimeEvent};
6use crate::subscription::Channel;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::broadcast;
12use tracing::{debug, info};
13
/// Number of events buffered by the hub's broadcast channel (the one handed out by
/// `EventHub::subscribe_events`).
const BROADCAST_CAPACITY: usize = 1_024;

/// Upper bound on simultaneously connected clients; `EventHub::connect` refuses new
/// clients once this many are registered.
const MAX_CONNECTIONS: usize = 10_000;
19
20/// Event hub manages all WebSocket connections and event broadcasting.
21#[derive(Debug)]
22pub struct EventHub {
23    /// Connected clients indexed by ID.
24    clients: RwLock<HashMap<ClientId, Arc<Client>>>,
25    /// Broadcast channel for events.
26    event_tx: broadcast::Sender<RealtimeEvent>,
27    /// Statistics.
28    stats: RwLock<HubStats>,
29}
30
31impl EventHub {
32    /// Create a new event hub.
33    pub fn new() -> Self {
34        let (event_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
35        Self {
36            clients: RwLock::new(HashMap::new()),
37            event_tx,
38            stats: RwLock::new(HubStats::default()),
39        }
40    }
41
42    /// Connect a new client and return its message receiver.
43    pub fn connect(&self) -> Result<(Arc<Client>, ClientReceiver), RealtimeError> {
44        let clients = self.clients.read();
45        if clients.len() >= MAX_CONNECTIONS {
46            return Err(RealtimeError::SendFailed(
47                "maximum connections reached".to_string(),
48            ));
49        }
50        drop(clients);
51
52        let client_id = uuid::Uuid::new_v4().to_string();
53        let (client, receiver) = create_client(client_id.clone());
54
55        self.clients
56            .write()
57            .insert(client_id.clone(), client.clone());
58        self.stats.write().total_connections += 1;
59
60        info!(client_id = %client_id, "Client connected");
61
62        Ok((client, receiver))
63    }
64
65    /// Disconnect a client.
66    pub fn disconnect(&self, client_id: &str) {
67        if let Some(client) = self.clients.write().remove(client_id) {
68            client.clear_subscriptions();
69            info!(client_id = %client_id, "Client disconnected");
70        }
71    }
72
73    /// Get a client by ID.
74    pub fn get_client(&self, client_id: &str) -> Option<Arc<Client>> {
75        self.clients.read().get(client_id).cloned()
76    }
77
78    /// Handle a client command.
79    pub fn handle_command(
80        &self,
81        client: &Arc<Client>,
82        command: ClientCommand,
83    ) -> Result<ServerMessage, RealtimeError> {
84        match command {
85            ClientCommand::Subscribe { channel } => {
86                let parsed = Channel::parse(&channel)?;
87                let is_new = client.subscribe(parsed)?;
88
89                if is_new {
90                    debug!(client_id = %client.id, channel = %channel, "Client subscribed");
91                    self.stats.write().total_subscriptions += 1;
92                }
93
94                Ok(ServerMessage::Subscribed { channel })
95            }
96            ClientCommand::Unsubscribe { channel } => {
97                let parsed = Channel::parse(&channel)?;
98                let was_subscribed = client.unsubscribe(&parsed);
99
100                if was_subscribed {
101                    debug!(client_id = %client.id, channel = %channel, "Client unsubscribed");
102                }
103
104                Ok(ServerMessage::Unsubscribed { channel })
105            }
106            ClientCommand::Ping => Ok(ServerMessage::Pong),
107        }
108    }
109
110    /// Emit an event to all subscribed clients.
111    pub fn emit(&self, event: RealtimeEvent) {
112        let channel = event.channel.clone();
113        let event_kind = event.event;
114
115        // Count how many clients will receive this
116        let mut recipient_count = 0;
117        let clients = self.clients.read();
118
119        for client in clients.values() {
120            if client.matches_event(&channel) {
121                if let Ok(json) = serde_json::to_string(&event) {
122                    if client.send(json).is_ok() {
123                        recipient_count += 1;
124                    }
125                }
126            }
127        }
128
129        drop(clients);
130
131        // Also send to broadcast channel for any listeners
132        let _ = self.event_tx.send(event);
133
134        self.stats.write().total_events += 1;
135
136        debug!(
137            channel = %channel,
138            event = %event_kind,
139            recipients = recipient_count,
140            "Event broadcast"
141        );
142    }
143
144    /// Emit an event with the given parameters.
145    pub fn emit_event(&self, channel: String, event: EventKind, data: serde_json::Value) {
146        self.emit(RealtimeEvent::new(channel, event, data));
147    }
148
149    /// Subscribe to the broadcast channel for events.
150    pub fn subscribe_events(&self) -> broadcast::Receiver<RealtimeEvent> {
151        self.event_tx.subscribe()
152    }
153
154    /// Get current connection count.
155    pub fn connection_count(&self) -> usize {
156        self.clients.read().len()
157    }
158
159    /// Get hub statistics.
160    pub fn stats(&self) -> HubStats {
161        let mut stats = self.stats.read().clone();
162        stats.current_connections = self.connection_count();
163        stats
164    }
165
166    /// Broadcast a message to all clients (for system announcements).
167    pub fn broadcast_all(&self, message: &str) {
168        let clients = self.clients.read();
169        for client in clients.values() {
170            let _ = client.send(message.to_string());
171        }
172    }
173}
174
175impl Default for EventHub {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
181/// Commands that clients can send.
182#[derive(Debug, Clone, Serialize, Deserialize)]
183#[serde(tag = "type", rename_all = "snake_case")]
184pub enum ClientCommand {
185    /// Subscribe to a channel.
186    Subscribe { channel: String },
187    /// Unsubscribe from a channel.
188    Unsubscribe { channel: String },
189    /// Ping for keepalive.
190    Ping,
191}
192
193/// Messages sent from server to client.
194#[derive(Debug, Clone, Serialize, Deserialize)]
195#[serde(tag = "type", rename_all = "snake_case")]
196pub enum ServerMessage {
197    /// Subscription confirmed.
198    Subscribed { channel: String },
199    /// Unsubscription confirmed.
200    Unsubscribed { channel: String },
201    /// Pong response to ping.
202    Pong,
203    /// Error message.
204    Error { message: String },
205}
206
/// Snapshot of hub statistics.
///
/// Comparable with `==` so callers and tests can check a whole snapshot at once.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HubStats {
    /// Current number of connections.
    pub current_connections: usize,
    /// Total connections since start.
    pub total_connections: u64,
    /// Total subscriptions since start.
    pub total_subscriptions: u64,
    /// Total events broadcast since start.
    pub total_events: u64,
}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Subscribe `client` to `channel` through the hub, panicking if the command fails.
    fn subscribe(hub: &EventHub, client: &Arc<Client>, channel: &str) {
        hub.handle_command(
            client,
            ClientCommand::Subscribe {
                channel: channel.to_string(),
            },
        )
        .unwrap();
    }

    #[tokio::test]
    async fn test_hub_connect() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();

        assert!(!client.id.is_empty());
        assert_eq!(hub.connection_count(), 1);
    }

    #[tokio::test]
    async fn test_hub_disconnect() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();
        let client_id = client.id.clone();

        hub.disconnect(&client_id);
        assert_eq!(hub.connection_count(), 0);
    }

    #[tokio::test]
    async fn test_hub_disconnect_unknown_client_is_noop() {
        let hub = EventHub::new();
        let (_client, _rx) = hub.connect().unwrap();

        hub.disconnect("no-such-client");
        assert_eq!(hub.connection_count(), 1);
    }

    #[tokio::test]
    async fn test_hub_get_client() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();

        let found = hub.get_client(&client.id).unwrap();
        assert!(Arc::ptr_eq(&found, &client));

        hub.disconnect(&client.id);
        assert!(hub.get_client(&client.id).is_none());
    }

    #[tokio::test]
    async fn test_hub_connection_limit() {
        let hub = EventHub::new();
        let mut last = None;
        for _ in 0..MAX_CONNECTIONS {
            last = Some(hub.connect().unwrap());
        }
        assert_eq!(hub.connection_count(), MAX_CONNECTIONS);

        // One past the limit is rejected and does not register a client.
        assert!(hub.connect().is_err());
        assert_eq!(hub.connection_count(), MAX_CONNECTIONS);

        // Freeing a slot allows a new connection again.
        let (client, _rx) = last.expect("MAX_CONNECTIONS is non-zero");
        hub.disconnect(&client.id);
        assert!(hub.connect().is_ok());
    }

    #[tokio::test]
    async fn test_hub_subscribe_command() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();

        let cmd = ClientCommand::Subscribe {
            channel: "repo:alice/myrepo".to_string(),
        };

        let response = hub.handle_command(&client, cmd).unwrap();
        assert!(matches!(response, ServerMessage::Subscribed { .. }));
        assert_eq!(client.subscription_count(), 1);
    }

    #[tokio::test]
    async fn test_hub_unsubscribe_command() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();

        subscribe(&hub, &client, "repo:alice/myrepo");

        let response = hub
            .handle_command(
                &client,
                ClientCommand::Unsubscribe {
                    channel: "repo:alice/myrepo".to_string(),
                },
            )
            .unwrap();

        assert!(matches!(response, ServerMessage::Unsubscribed { .. }));
        assert_eq!(client.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_hub_ping_pong() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();

        let response = hub.handle_command(&client, ClientCommand::Ping).unwrap();
        assert!(matches!(response, ServerMessage::Pong));
    }

    #[tokio::test]
    async fn test_hub_emit_event() {
        let hub = EventHub::new();
        let (client, mut rx) = hub.connect().unwrap();

        subscribe(&hub, &client, "repo:alice/myrepo");

        hub.emit_event(
            "repo:alice/myrepo".to_string(),
            EventKind::Push,
            serde_json::json!({"ref": "refs/heads/main"}),
        );

        // The subscribed client receives the serialized event.
        let msg = rx.try_recv().unwrap();
        assert!(msg.contains("push"));
        assert!(msg.contains("repo:alice/myrepo"));
    }

    #[tokio::test]
    async fn test_hub_emit_reaches_broadcast_listeners() {
        let hub = EventHub::new();
        let mut events = hub.subscribe_events();

        hub.emit_event(
            "repo:alice/myrepo".to_string(),
            EventKind::Push,
            serde_json::json!({}),
        );

        let event = events.try_recv().unwrap();
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("repo:alice/myrepo"));
    }

    #[tokio::test]
    async fn test_hub_emit_filtered() {
        let hub = EventHub::new();
        let (client1, mut rx1) = hub.connect().unwrap();
        let (client2, mut rx2) = hub.connect().unwrap();

        subscribe(&hub, &client1, "repo:alice/myrepo");
        subscribe(&hub, &client2, "repo:bob/otherrepo");

        hub.emit_event(
            "repo:alice/myrepo".to_string(),
            EventKind::Push,
            serde_json::json!({}),
        );

        // Only the client subscribed to alice/myrepo receives it.
        assert!(rx1.try_recv().is_ok());
        assert!(rx2.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_hub_stats() {
        let hub = EventHub::new();

        let (client, _rx) = hub.connect().unwrap();
        subscribe(&hub, &client, "repo:alice/myrepo");
        hub.emit_event(
            "repo:alice/myrepo".to_string(),
            EventKind::Push,
            serde_json::json!({}),
        );

        let stats = hub.stats();
        assert_eq!(stats.current_connections, 1);
        assert_eq!(stats.total_connections, 1);
        assert_eq!(stats.total_subscriptions, 1);
        assert_eq!(stats.total_events, 1);
    }

    #[test]
    fn test_client_command_serialization() {
        let cmd = ClientCommand::Subscribe {
            channel: "repo:alice/myrepo".to_string(),
        };
        let value = serde_json::to_value(&cmd).unwrap();
        assert_eq!(
            value,
            serde_json::json!({"type": "subscribe", "channel": "repo:alice/myrepo"})
        );

        let parsed: ClientCommand = serde_json::from_value(value).unwrap();
        assert!(matches!(
            parsed,
            ClientCommand::Subscribe { ref channel } if channel == "repo:alice/myrepo"
        ));

        let unsub = ClientCommand::Unsubscribe {
            channel: "repo:alice/myrepo".to_string(),
        };
        assert_eq!(
            serde_json::to_value(&unsub).unwrap(),
            serde_json::json!({"type": "unsubscribe", "channel": "repo:alice/myrepo"})
        );

        let ping: ClientCommand = serde_json::from_str(r#"{"type":"ping"}"#).unwrap();
        assert!(matches!(ping, ClientCommand::Ping));
    }

    #[test]
    fn test_server_message_serialization() {
        let msg = ServerMessage::Subscribed {
            channel: "repo:test/repo".to_string(),
        };
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            serde_json::json!({"type": "subscribed", "channel": "repo:test/repo"})
        );

        assert_eq!(
            serde_json::to_value(&ServerMessage::Pong).unwrap(),
            serde_json::json!({"type": "pong"})
        );

        let err = ServerMessage::Error {
            message: "boom".to_string(),
        };
        assert_eq!(
            serde_json::to_value(&err).unwrap(),
            serde_json::json!({"type": "error", "message": "boom"})
        );
    }
}