1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::{mpsc, RwLock};
6
7use forge_core::cluster::NodeId;
8use forge_core::realtime::{Delta, SessionId, SubscriptionId};
9
10use crate::gateway::websocket::{JobData, WorkflowData};
11
12#[derive(Debug, Clone)]
14pub struct WebSocketConfig {
15 pub max_subscriptions_per_connection: usize,
17 pub subscription_timeout: Duration,
19 pub subscription_rate_limit: usize,
21 pub heartbeat_interval: Duration,
23 pub max_message_size: usize,
25 pub reconnect: ReconnectConfig,
27}
28
29impl Default for WebSocketConfig {
30 fn default() -> Self {
31 Self {
32 max_subscriptions_per_connection: 50,
33 subscription_timeout: Duration::from_secs(30),
34 subscription_rate_limit: 100,
35 heartbeat_interval: Duration::from_secs(30),
36 max_message_size: 1024 * 1024, reconnect: ReconnectConfig::default(),
38 }
39 }
40}
41
42#[derive(Debug, Clone)]
44pub struct ReconnectConfig {
45 pub enabled: bool,
47 pub max_attempts: usize,
49 pub delay: Duration,
51 pub max_delay: Duration,
53 pub backoff: BackoffStrategy,
55}
56
57impl Default for ReconnectConfig {
58 fn default() -> Self {
59 Self {
60 enabled: true,
61 max_attempts: 10,
62 delay: Duration::from_secs(1),
63 max_delay: Duration::from_secs(30),
64 backoff: BackoffStrategy::Exponential,
65 }
66 }
67}
68
/// How the delay between reconnection attempts is meant to evolve.
///
/// The actual delay schedule is computed outside this module; the variant
/// descriptions below follow the names only.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackoffStrategy {
    /// Linearly growing delay (as named).
    Linear,
    /// Exponentially growing delay (as named).
    Exponential,
    /// Constant delay (as named).
    Fixed,
}
79
80#[derive(Debug, Clone)]
82pub enum WebSocketMessage {
83 Subscribe {
85 id: String,
86 query: String,
87 args: serde_json::Value,
88 },
89 Unsubscribe { subscription_id: SubscriptionId },
91 Ping,
93 Pong,
95 Data {
97 subscription_id: SubscriptionId,
98 data: serde_json::Value,
99 },
100 DeltaUpdate {
102 subscription_id: SubscriptionId,
103 delta: Delta<serde_json::Value>,
104 },
105 JobUpdate { client_sub_id: String, job: JobData },
107 WorkflowUpdate {
109 client_sub_id: String,
110 workflow: WorkflowData,
111 },
112 Error { code: String, message: String },
114 ErrorWithId {
116 id: String,
117 code: String,
118 message: String,
119 },
120}
121
122#[derive(Debug)]
124pub struct WebSocketConnection {
125 #[allow(dead_code)]
127 pub session_id: SessionId,
128 pub subscriptions: Vec<SubscriptionId>,
130 pub sender: mpsc::Sender<WebSocketMessage>,
132 #[allow(dead_code)]
134 pub connected_at: chrono::DateTime<chrono::Utc>,
135 pub last_active: chrono::DateTime<chrono::Utc>,
137}
138
139impl WebSocketConnection {
140 pub fn new(session_id: SessionId, sender: mpsc::Sender<WebSocketMessage>) -> Self {
142 let now = chrono::Utc::now();
143 Self {
144 session_id,
145 subscriptions: Vec::new(),
146 sender,
147 connected_at: now,
148 last_active: now,
149 }
150 }
151
152 pub fn add_subscription(&mut self, subscription_id: SubscriptionId) {
154 self.subscriptions.push(subscription_id);
155 self.last_active = chrono::Utc::now();
156 }
157
158 pub fn remove_subscription(&mut self, subscription_id: SubscriptionId) {
160 self.subscriptions.retain(|id| *id != subscription_id);
161 self.last_active = chrono::Utc::now();
162 }
163
164 pub async fn send(
166 &self,
167 message: WebSocketMessage,
168 ) -> Result<(), mpsc::error::SendError<WebSocketMessage>> {
169 self.sender.send(message).await
170 }
171}
172
173pub struct WebSocketServer {
175 #[allow(dead_code)]
176 config: WebSocketConfig,
177 node_id: NodeId,
178 connections: Arc<RwLock<HashMap<SessionId, WebSocketConnection>>>,
180 subscription_sessions: Arc<RwLock<HashMap<SubscriptionId, SessionId>>>,
182}
183
184impl WebSocketServer {
185 pub fn new(node_id: NodeId, config: WebSocketConfig) -> Self {
187 Self {
188 config,
189 node_id,
190 connections: Arc::new(RwLock::new(HashMap::new())),
191 subscription_sessions: Arc::new(RwLock::new(HashMap::new())),
192 }
193 }
194
195 pub fn node_id(&self) -> NodeId {
197 self.node_id
198 }
199
200 pub fn config(&self) -> &WebSocketConfig {
202 &self.config
203 }
204
205 pub async fn register_connection(
207 &self,
208 session_id: SessionId,
209 sender: mpsc::Sender<WebSocketMessage>,
210 ) {
211 let connection = WebSocketConnection::new(session_id, sender);
212 let mut connections = self.connections.write().await;
213 connections.insert(session_id, connection);
214 }
215
216 pub async fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
218 let mut connections = self.connections.write().await;
219 if let Some(conn) = connections.remove(&session_id) {
220 let mut sub_sessions = self.subscription_sessions.write().await;
222 for sub_id in &conn.subscriptions {
223 sub_sessions.remove(sub_id);
224 }
225 Some(conn.subscriptions)
226 } else {
227 None
228 }
229 }
230
231 pub async fn add_subscription(
233 &self,
234 session_id: SessionId,
235 subscription_id: SubscriptionId,
236 ) -> forge_core::Result<()> {
237 let mut connections = self.connections.write().await;
238 let conn = connections
239 .get_mut(&session_id)
240 .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
241
242 if conn.subscriptions.len() >= self.config.max_subscriptions_per_connection {
244 return Err(forge_core::ForgeError::Validation(format!(
245 "Maximum subscriptions per connection ({}) exceeded",
246 self.config.max_subscriptions_per_connection
247 )));
248 }
249
250 conn.add_subscription(subscription_id);
251
252 let mut sub_sessions = self.subscription_sessions.write().await;
254 sub_sessions.insert(subscription_id, session_id);
255
256 Ok(())
257 }
258
259 pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
261 let session_id = {
262 let mut sub_sessions = self.subscription_sessions.write().await;
263 sub_sessions.remove(&subscription_id)
264 };
265
266 if let Some(session_id) = session_id {
267 let mut connections = self.connections.write().await;
268 if let Some(conn) = connections.get_mut(&session_id) {
269 conn.remove_subscription(subscription_id);
270 }
271 }
272 }
273
274 pub async fn send_to_session(
276 &self,
277 session_id: SessionId,
278 message: WebSocketMessage,
279 ) -> forge_core::Result<()> {
280 let connections = self.connections.read().await;
281 let conn = connections
282 .get(&session_id)
283 .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
284
285 conn.send(message)
286 .await
287 .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
288 }
289
290 pub async fn broadcast_delta(
292 &self,
293 subscription_id: SubscriptionId,
294 delta: Delta<serde_json::Value>,
295 ) -> forge_core::Result<()> {
296 let session_id = {
297 let sub_sessions = self.subscription_sessions.read().await;
298 sub_sessions.get(&subscription_id).copied()
299 };
300
301 if let Some(session_id) = session_id {
302 let message = WebSocketMessage::DeltaUpdate {
303 subscription_id,
304 delta,
305 };
306 self.send_to_session(session_id, message).await?;
307 }
308
309 Ok(())
310 }
311
312 pub async fn connection_count(&self) -> usize {
314 self.connections.read().await.len()
315 }
316
317 pub async fn subscription_count(&self) -> usize {
319 self.subscription_sessions.read().await.len()
320 }
321
322 pub async fn stats(&self) -> WebSocketStats {
324 let connections = self.connections.read().await;
325 let total_subscriptions: usize = connections.values().map(|c| c.subscriptions.len()).sum();
326
327 WebSocketStats {
328 connections: connections.len(),
329 subscriptions: total_subscriptions,
330 node_id: self.node_id,
331 }
332 }
333
334 pub async fn cleanup_stale(&self, max_idle: Duration) {
336 let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_idle).unwrap();
337 let mut connections = self.connections.write().await;
338 let mut sub_sessions = self.subscription_sessions.write().await;
339
340 connections.retain(|_, conn| {
341 if conn.last_active < cutoff {
342 for sub_id in &conn.subscriptions {
344 sub_sessions.remove(sub_id);
345 }
346 false
347 } else {
348 true
349 }
350 });
351 }
352}
353
354#[derive(Debug, Clone)]
356pub struct WebSocketStats {
357 pub connections: usize,
359 pub subscriptions: usize,
361 pub node_id: NodeId,
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketConfig::default();
        assert_eq!(config.max_subscriptions_per_connection, 50);
        assert_eq!(config.subscription_rate_limit, 100);
        assert!(config.reconnect.enabled);
    }

    #[test]
    fn test_reconnect_config_default() {
        let config = ReconnectConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_attempts, 10);
        assert_eq!(config.backoff, BackoffStrategy::Exponential);
    }

    #[tokio::test]
    async fn test_websocket_server_creation() {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, WebSocketConfig::default());

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count().await, 0);
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_connection() {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, WebSocketConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        assert_eq!(server.connection_count().await, 1);

        let removed = server.remove_connection(session_id).await;
        assert!(removed.is_some());
        assert_eq!(server.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_remove_unknown_connection_returns_none() {
        let server = WebSocketServer::new(NodeId::new(), WebSocketConfig::default());
        assert!(server.remove_connection(SessionId::new()).await.is_none());
    }

    #[tokio::test]
    async fn test_remove_connection_drops_subscriptions() {
        let server = WebSocketServer::new(NodeId::new(), WebSocketConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();
        assert_eq!(server.subscription_count().await, 1);

        let removed = server.remove_connection(session_id).await;
        assert_eq!(removed.map(|subs| subs.len()), Some(1));
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_subscription() {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, WebSocketConfig::default());
        let session_id = SessionId::new();
        let subscription_id = SubscriptionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();

        assert_eq!(server.subscription_count().await, 1);

        server.remove_subscription(subscription_id).await;
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_add_subscription_unknown_session_fails() {
        let server = WebSocketServer::new(NodeId::new(), WebSocketConfig::default());

        let result = server
            .add_subscription(SessionId::new(), SubscriptionId::new())
            .await;
        assert!(result.is_err());
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_subscription_limit() {
        let node_id = NodeId::new();
        let config = WebSocketConfig {
            max_subscriptions_per_connection: 2,
            ..Default::default()
        };
        let server = WebSocketServer::new(node_id, config);
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;

        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        let result = server
            .add_subscription(session_id, SubscriptionId::new())
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_to_session_delivers_message() {
        let server = WebSocketServer::new(NodeId::new(), WebSocketConfig::default());
        let session_id = SessionId::new();
        let (tx, mut rx) = mpsc::channel(1);

        server.register_connection(session_id, tx).await;
        server
            .send_to_session(session_id, WebSocketMessage::Ping)
            .await
            .unwrap();

        assert!(matches!(rx.try_recv(), Ok(WebSocketMessage::Ping)));
    }

    #[tokio::test]
    async fn test_send_to_unknown_session_fails() {
        let server = WebSocketServer::new(NodeId::new(), WebSocketConfig::default());

        let result = server
            .send_to_session(SessionId::new(), WebSocketMessage::Pong)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_cleanup_stale_removes_idle_connections() {
        let server = WebSocketServer::new(NodeId::new(), WebSocketConfig::default());
        let stale = SessionId::new();
        let fresh = SessionId::new();
        let (stale_tx, _stale_rx) = mpsc::channel(1);
        let (fresh_tx, _fresh_rx) = mpsc::channel(1);

        server.register_connection(stale, stale_tx).await;
        server.register_connection(fresh, fresh_tx).await;
        server
            .add_subscription(stale, SubscriptionId::new())
            .await
            .unwrap();

        // Backdate one connection well past the idle limit used below.
        {
            let mut connections = server.connections.write().await;
            let conn = connections.get_mut(&stale).unwrap();
            conn.last_active = chrono::Utc::now()
                - chrono::Duration::from_std(Duration::from_secs(3600)).unwrap();
        }

        server.cleanup_stale(Duration::from_secs(60)).await;

        assert_eq!(server.connection_count().await, 1);
        assert_eq!(server.subscription_count().await, 0);
        assert!(server.remove_connection(fresh).await.is_some());
    }

    #[tokio::test]
    async fn test_websocket_stats() {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, WebSocketConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        let stats = server.stats().await;
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}