1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::{RwLock, mpsc};
6
7use forge_core::cluster::NodeId;
8use forge_core::realtime::{Delta, SessionId, SubscriptionId};
9
10use crate::gateway::websocket::{JobData, WorkflowData};
11
/// Tunables for the WebSocket gateway.
///
/// Within this file (outside the tests) only `max_subscriptions_per_connection`
/// is read, by `WebSocketServer::add_subscription`. The other fields are
/// presumably consumed by the connection handler elsewhere — TODO confirm.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Maximum number of subscriptions one connection may hold;
    /// `WebSocketServer::add_subscription` rejects additions beyond it.
    pub max_subscriptions_per_connection: usize,
    /// Subscription timeout. NOTE(review): not read in this file.
    pub subscription_timeout: Duration,
    /// Subscription rate limit. NOTE(review): unit/window not visible here,
    /// and not read in this file outside the tests.
    pub subscription_rate_limit: usize,
    /// Heartbeat interval. NOTE(review): not read in this file.
    pub heartbeat_interval: Duration,
    /// Maximum message size — presumably in bytes (default `1024 * 1024`);
    /// not read in this file.
    pub max_message_size: usize,
    /// Reconnection settings.
    pub reconnect: ReconnectConfig,
}
28
29impl Default for WebSocketConfig {
30 fn default() -> Self {
31 Self {
32 max_subscriptions_per_connection: 50,
33 subscription_timeout: Duration::from_secs(30),
34 subscription_rate_limit: 100,
35 heartbeat_interval: Duration::from_secs(30),
36 max_message_size: 1024 * 1024, reconnect: ReconnectConfig::default(),
38 }
39 }
40}
41
/// Reconnection policy settings.
///
/// NOTE(review): nothing in this file reads these fields outside the tests;
/// the descriptions below follow the field names and should be confirmed
/// against the reconnect implementation.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Presumably whether reconnection is attempted at all.
    pub enabled: bool,
    /// Presumably the maximum number of reconnection attempts.
    pub max_attempts: usize,
    /// Presumably the initial delay before retrying.
    pub delay: Duration,
    /// Presumably the upper bound on the delay between attempts.
    pub max_delay: Duration,
    /// Strategy used to vary the delay between attempts.
    pub backoff: BackoffStrategy,
}
56
57impl Default for ReconnectConfig {
58 fn default() -> Self {
59 Self {
60 enabled: true,
61 max_attempts: 10,
62 delay: Duration::from_secs(1),
63 max_delay: Duration::from_secs(30),
64 backoff: BackoffStrategy::Exponential,
65 }
66 }
67}
68
/// Strategy stored in `ReconnectConfig::backoff`.
///
/// NOTE(review): the delay computation is not in this file, so the
/// per-variant descriptions below are assumptions to confirm against the
/// reconnect implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackoffStrategy {
    /// Presumably increases the delay by a constant step per attempt.
    Linear,
    /// Presumably multiplies the delay each attempt, up to `max_delay`.
    Exponential,
    /// Presumably uses the same `delay` for every attempt.
    Fixed,
}
79
/// Messages carried over a WebSocket connection.
///
/// Outgoing messages are queued to a connection through its
/// `mpsc::Sender<WebSocketMessage>`. NOTE(review): the message direction of
/// each variant is not visible in this file; the hints below are assumptions.
#[derive(Debug, Clone)]
pub enum WebSocketMessage {
    /// Request to subscribe to `query` with JSON `args`; `id` is presumably a
    /// client-chosen correlation id — confirm.
    Subscribe {
        id: String,
        query: String,
        args: serde_json::Value,
    },
    /// Request to cancel `subscription_id`.
    Unsubscribe { subscription_id: SubscriptionId },
    /// Keep-alive probe (presumably related to `heartbeat_interval`).
    Ping,
    /// Reply to `Ping`.
    Pong,
    /// Full JSON payload for `subscription_id` (presumably the initial result).
    Data {
        subscription_id: SubscriptionId,
        data: serde_json::Value,
    },
    /// Incremental change for `subscription_id`; built by
    /// `WebSocketServer::broadcast_delta`.
    DeltaUpdate {
        subscription_id: SubscriptionId,
        delta: Delta<serde_json::Value>,
    },
    /// Job update addressed by a client-side subscription id string rather
    /// than a `SubscriptionId`.
    JobUpdate { client_sub_id: String, job: JobData },
    /// Workflow update addressed by a client-side subscription id string
    /// rather than a `SubscriptionId`.
    WorkflowUpdate {
        client_sub_id: String,
        workflow: WorkflowData,
    },
    /// Error with a code and message that is not tied to a request id
    /// (compare `ErrorWithId`).
    Error { code: String, message: String },
    /// Error correlated with a request `id` — presumably a `Subscribe` id.
    ErrorWithId {
        id: String,
        code: String,
        message: String,
    },
    /// Authentication succeeded.
    AuthSuccess,
    /// Authentication failed, with a human-readable `reason`.
    AuthFailed { reason: String },
}
125
/// Server-side record for one registered WebSocket session.
#[derive(Debug)]
pub struct WebSocketConnection {
    /// Session this record belongs to. Not read anywhere in this file, hence
    /// the `dead_code` allowance.
    #[allow(dead_code)]
    pub session_id: SessionId,
    /// Subscriptions attached to this connection; `WebSocketServer` updates it
    /// alongside its subscription → session index.
    pub subscriptions: Vec<SubscriptionId>,
    /// Channel on which outgoing messages for this session are queued.
    pub sender: mpsc::Sender<WebSocketMessage>,
    /// When this record was created. Not read anywhere in this file, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    pub connected_at: chrono::DateTime<chrono::Utc>,
    /// Last recorded activity: set on creation and on subscription add/remove
    /// in this file, and compared against the cutoff in
    /// `WebSocketServer::cleanup_stale`.
    pub last_active: chrono::DateTime<chrono::Utc>,
}
142
143impl WebSocketConnection {
144 pub fn new(session_id: SessionId, sender: mpsc::Sender<WebSocketMessage>) -> Self {
146 let now = chrono::Utc::now();
147 Self {
148 session_id,
149 subscriptions: Vec::new(),
150 sender,
151 connected_at: now,
152 last_active: now,
153 }
154 }
155
156 pub fn add_subscription(&mut self, subscription_id: SubscriptionId) {
158 self.subscriptions.push(subscription_id);
159 self.last_active = chrono::Utc::now();
160 }
161
162 pub fn remove_subscription(&mut self, subscription_id: SubscriptionId) {
164 self.subscriptions.retain(|id| *id != subscription_id);
165 self.last_active = chrono::Utc::now();
166 }
167
168 pub async fn send(
170 &self,
171 message: WebSocketMessage,
172 ) -> Result<(), mpsc::error::SendError<WebSocketMessage>> {
173 self.sender.send(message).await
174 }
175}
176
/// Per-node registry of WebSocket connections and the subscriptions they own.
pub struct WebSocketServer {
    // NOTE(review): `config` is read by `add_subscription` and `config()`, so
    // this `dead_code` allowance looks stale — consider removing it.
    #[allow(dead_code)]
    config: WebSocketConfig,
    /// Node this server was created for.
    node_id: NodeId,
    /// Registered connections keyed by session.
    connections: Arc<RwLock<HashMap<SessionId, WebSocketConnection>>>,
    /// Reverse index: the session that owns each subscription.
    subscription_sessions: Arc<RwLock<HashMap<SubscriptionId, SessionId>>>,
}
187
188impl WebSocketServer {
189 pub fn new(node_id: NodeId, config: WebSocketConfig) -> Self {
191 Self {
192 config,
193 node_id,
194 connections: Arc::new(RwLock::new(HashMap::new())),
195 subscription_sessions: Arc::new(RwLock::new(HashMap::new())),
196 }
197 }
198
199 pub fn node_id(&self) -> NodeId {
201 self.node_id
202 }
203
204 pub fn config(&self) -> &WebSocketConfig {
206 &self.config
207 }
208
209 pub async fn register_connection(
211 &self,
212 session_id: SessionId,
213 sender: mpsc::Sender<WebSocketMessage>,
214 ) {
215 let connection = WebSocketConnection::new(session_id, sender);
216 let mut connections = self.connections.write().await;
217 connections.insert(session_id, connection);
218 }
219
220 pub async fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
222 let mut connections = self.connections.write().await;
223 if let Some(conn) = connections.remove(&session_id) {
224 let mut sub_sessions = self.subscription_sessions.write().await;
226 for sub_id in &conn.subscriptions {
227 sub_sessions.remove(sub_id);
228 }
229 Some(conn.subscriptions)
230 } else {
231 None
232 }
233 }
234
235 pub async fn add_subscription(
237 &self,
238 session_id: SessionId,
239 subscription_id: SubscriptionId,
240 ) -> forge_core::Result<()> {
241 let mut connections = self.connections.write().await;
242 let conn = connections
243 .get_mut(&session_id)
244 .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
245
246 if conn.subscriptions.len() >= self.config.max_subscriptions_per_connection {
248 return Err(forge_core::ForgeError::Validation(format!(
249 "Maximum subscriptions per connection ({}) exceeded",
250 self.config.max_subscriptions_per_connection
251 )));
252 }
253
254 conn.add_subscription(subscription_id);
255
256 let mut sub_sessions = self.subscription_sessions.write().await;
258 sub_sessions.insert(subscription_id, session_id);
259
260 Ok(())
261 }
262
263 pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
265 let session_id = {
266 let mut sub_sessions = self.subscription_sessions.write().await;
267 sub_sessions.remove(&subscription_id)
268 };
269
270 if let Some(session_id) = session_id {
271 let mut connections = self.connections.write().await;
272 if let Some(conn) = connections.get_mut(&session_id) {
273 conn.remove_subscription(subscription_id);
274 }
275 }
276 }
277
278 pub async fn send_to_session(
280 &self,
281 session_id: SessionId,
282 message: WebSocketMessage,
283 ) -> forge_core::Result<()> {
284 let connections = self.connections.read().await;
285 let conn = connections
286 .get(&session_id)
287 .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
288
289 conn.send(message)
290 .await
291 .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
292 }
293
294 pub async fn broadcast_delta(
296 &self,
297 subscription_id: SubscriptionId,
298 delta: Delta<serde_json::Value>,
299 ) -> forge_core::Result<()> {
300 let session_id = {
301 let sub_sessions = self.subscription_sessions.read().await;
302 sub_sessions.get(&subscription_id).copied()
303 };
304
305 if let Some(session_id) = session_id {
306 let message = WebSocketMessage::DeltaUpdate {
307 subscription_id,
308 delta,
309 };
310 self.send_to_session(session_id, message).await?;
311 }
312
313 Ok(())
314 }
315
316 pub async fn connection_count(&self) -> usize {
318 self.connections.read().await.len()
319 }
320
321 pub async fn subscription_count(&self) -> usize {
323 self.subscription_sessions.read().await.len()
324 }
325
326 pub async fn stats(&self) -> WebSocketStats {
328 let connections = self.connections.read().await;
329 let total_subscriptions: usize = connections.values().map(|c| c.subscriptions.len()).sum();
330
331 WebSocketStats {
332 connections: connections.len(),
333 subscriptions: total_subscriptions,
334 node_id: self.node_id,
335 }
336 }
337
338 pub async fn cleanup_stale(&self, max_idle: Duration) {
340 let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_idle).unwrap();
341 let mut connections = self.connections.write().await;
342 let mut sub_sessions = self.subscription_sessions.write().await;
343
344 connections.retain(|_, conn| {
345 if conn.last_active < cutoff {
346 for sub_id in &conn.subscriptions {
348 sub_sessions.remove(sub_id);
349 }
350 false
351 } else {
352 true
353 }
354 });
355 }
356}
357
/// Point-in-time snapshot produced by `WebSocketServer::stats`.
#[derive(Debug, Clone)]
pub struct WebSocketStats {
    /// Number of registered connections.
    pub connections: usize,
    /// Sum of the subscription list lengths across all connections.
    pub subscriptions: usize,
    /// Node id of the server that produced the snapshot.
    pub node_id: NodeId,
}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server with default config and one registered session,
    /// returning the receiving half of that session's channel.
    async fn server_with_session() -> (
        WebSocketServer,
        SessionId,
        mpsc::Receiver<WebSocketMessage>,
    ) {
        let server = WebSocketServer::new(NodeId::new(), WebSocketConfig::default());
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(100);
        server.register_connection(session_id, tx).await;
        (server, session_id, rx)
    }

    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketConfig::default();
        assert_eq!(config.max_subscriptions_per_connection, 50);
        assert_eq!(config.subscription_rate_limit, 100);
        assert!(config.reconnect.enabled);
    }

    #[test]
    fn test_reconnect_config_default() {
        let config = ReconnectConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_attempts, 10);
        assert_eq!(config.backoff, BackoffStrategy::Exponential);
    }

    #[tokio::test]
    async fn test_websocket_server_creation() {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, WebSocketConfig::default());

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count().await, 0);
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_connection() {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, WebSocketConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        assert_eq!(server.connection_count().await, 1);

        let removed = server.remove_connection(session_id).await;
        assert!(removed.is_some());
        assert_eq!(server.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_remove_unknown_connection_returns_none() {
        let server = WebSocketServer::new(NodeId::new(), WebSocketConfig::default());
        assert!(server.remove_connection(SessionId::new()).await.is_none());
    }

    #[tokio::test]
    async fn test_remove_connection_clears_subscription_index() {
        let (server, session_id, _rx) = server_with_session().await;
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        let removed = server.remove_connection(session_id).await.unwrap();
        assert_eq!(removed.len(), 2);
        assert_eq!(server.subscription_count().await, 0);
        assert_eq!(server.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_subscription() {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, WebSocketConfig::default());
        let session_id = SessionId::new();
        let subscription_id = SubscriptionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();

        assert_eq!(server.subscription_count().await, 1);

        server.remove_subscription(subscription_id).await;
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_add_subscription_unknown_session_fails() {
        let server = WebSocketServer::new(NodeId::new(), WebSocketConfig::default());
        let result = server
            .add_subscription(SessionId::new(), SubscriptionId::new())
            .await;
        assert!(result.is_err());
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_remove_unknown_subscription_is_noop() {
        let (server, session_id, _rx) = server_with_session().await;
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        server.remove_subscription(SubscriptionId::new()).await;
        assert_eq!(server.subscription_count().await, 1);
        assert_eq!(server.stats().await.subscriptions, 1);
    }

    #[tokio::test]
    async fn test_websocket_subscription_limit() {
        let node_id = NodeId::new();
        let config = WebSocketConfig {
            max_subscriptions_per_connection: 2,
            ..Default::default()
        };
        let server = WebSocketServer::new(node_id, config);
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;

        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        let result = server
            .add_subscription(session_id, SubscriptionId::new())
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_to_session_delivers_message() {
        let (server, session_id, mut rx) = server_with_session().await;

        server
            .send_to_session(session_id, WebSocketMessage::Ping)
            .await
            .unwrap();
        assert!(matches!(rx.recv().await, Some(WebSocketMessage::Ping)));
    }

    #[tokio::test]
    async fn test_send_to_unknown_session_fails() {
        let server = WebSocketServer::new(NodeId::new(), WebSocketConfig::default());
        let result = server
            .send_to_session(SessionId::new(), WebSocketMessage::Ping)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_to_closed_channel_fails() {
        let (server, session_id, rx) = server_with_session().await;
        drop(rx);

        let result = server
            .send_to_session(session_id, WebSocketMessage::Pong)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_cleanup_stale_keeps_active_connections() {
        let (server, session_id, _rx) = server_with_session().await;
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        server.cleanup_stale(Duration::from_secs(3600)).await;
        assert_eq!(server.connection_count().await, 1);
        assert_eq!(server.subscription_count().await, 1);
    }

    #[tokio::test]
    async fn test_websocket_stats() {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, WebSocketConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        let stats = server.stats().await;
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}