forge_runtime/realtime/
websocket.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::{RwLock, mpsc};
6
7use forge_core::cluster::NodeId;
8use forge_core::realtime::{Delta, SessionId, SubscriptionId};
9
10use crate::gateway::websocket::{JobData, WorkflowData};
11
12/// WebSocket server configuration.
13#[derive(Debug, Clone)]
14pub struct WebSocketConfig {
15    /// Maximum subscriptions per connection.
16    pub max_subscriptions_per_connection: usize,
17    /// Subscription timeout.
18    pub subscription_timeout: Duration,
19    /// Rate limit for subscription creation (per minute).
20    pub subscription_rate_limit: usize,
21    /// Heartbeat interval for keepalive.
22    pub heartbeat_interval: Duration,
23    /// Maximum message size in bytes.
24    pub max_message_size: usize,
25    /// Reconnect settings.
26    pub reconnect: ReconnectConfig,
27}
28
29impl Default for WebSocketConfig {
30    fn default() -> Self {
31        Self {
32            max_subscriptions_per_connection: 50,
33            subscription_timeout: Duration::from_secs(30),
34            subscription_rate_limit: 100,
35            heartbeat_interval: Duration::from_secs(30),
36            max_message_size: 1024 * 1024, // 1MB
37            reconnect: ReconnectConfig::default(),
38        }
39    }
40}
41
42/// Reconnection configuration.
43#[derive(Debug, Clone)]
44pub struct ReconnectConfig {
45    /// Whether reconnection is enabled.
46    pub enabled: bool,
47    /// Maximum reconnection attempts.
48    pub max_attempts: usize,
49    /// Initial delay between attempts.
50    pub delay: Duration,
51    /// Maximum delay between attempts.
52    pub max_delay: Duration,
53    /// Backoff strategy.
54    pub backoff: BackoffStrategy,
55}
56
57impl Default for ReconnectConfig {
58    fn default() -> Self {
59        Self {
60            enabled: true,
61            max_attempts: 10,
62            delay: Duration::from_secs(1),
63            max_delay: Duration::from_secs(30),
64            backoff: BackoffStrategy::Exponential,
65        }
66    }
67}
68
/// How the delay between reconnection attempts changes over time.
///
/// The exact delay formula for each variant is not defined in this file;
/// the variant only selects the policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackoffStrategy {
    /// Delay grows linearly with the attempt number.
    Linear,
    /// Delay grows exponentially with the attempt number.
    Exponential,
    /// Delay stays the same for every attempt.
    Fixed,
}
79
80/// Message types for WebSocket communication.
81#[derive(Debug, Clone)]
82pub enum WebSocketMessage {
83    /// Subscribe to a query.
84    Subscribe {
85        id: String,
86        query: String,
87        args: serde_json::Value,
88    },
89    /// Unsubscribe from a subscription.
90    Unsubscribe { subscription_id: SubscriptionId },
91    /// Ping for keepalive.
92    Ping,
93    /// Pong response.
94    Pong,
95    /// Initial data for subscription.
96    Data {
97        subscription_id: SubscriptionId,
98        data: serde_json::Value,
99    },
100    /// Delta update for subscription.
101    DeltaUpdate {
102        subscription_id: SubscriptionId,
103        delta: Delta<serde_json::Value>,
104    },
105    /// Job progress update.
106    JobUpdate { client_sub_id: String, job: JobData },
107    /// Workflow progress update.
108    WorkflowUpdate {
109        client_sub_id: String,
110        workflow: WorkflowData,
111    },
112    /// Error message.
113    Error { code: String, message: String },
114    /// Error message with subscription ID.
115    ErrorWithId {
116        id: String,
117        code: String,
118        message: String,
119    },
120}
121
122/// Represents a connected WebSocket client.
123#[derive(Debug)]
124pub struct WebSocketConnection {
125    /// Session ID for this connection.
126    #[allow(dead_code)]
127    pub session_id: SessionId,
128    /// Active subscriptions.
129    pub subscriptions: Vec<SubscriptionId>,
130    /// Sender for outgoing messages.
131    pub sender: mpsc::Sender<WebSocketMessage>,
132    /// When the connection was established.
133    #[allow(dead_code)]
134    pub connected_at: chrono::DateTime<chrono::Utc>,
135    /// Last activity time.
136    pub last_active: chrono::DateTime<chrono::Utc>,
137}
138
139impl WebSocketConnection {
140    /// Create a new connection.
141    pub fn new(session_id: SessionId, sender: mpsc::Sender<WebSocketMessage>) -> Self {
142        let now = chrono::Utc::now();
143        Self {
144            session_id,
145            subscriptions: Vec::new(),
146            sender,
147            connected_at: now,
148            last_active: now,
149        }
150    }
151
152    /// Add a subscription.
153    pub fn add_subscription(&mut self, subscription_id: SubscriptionId) {
154        self.subscriptions.push(subscription_id);
155        self.last_active = chrono::Utc::now();
156    }
157
158    /// Remove a subscription.
159    pub fn remove_subscription(&mut self, subscription_id: SubscriptionId) {
160        self.subscriptions.retain(|id| *id != subscription_id);
161        self.last_active = chrono::Utc::now();
162    }
163
164    /// Send a message to the client.
165    pub async fn send(
166        &self,
167        message: WebSocketMessage,
168    ) -> Result<(), mpsc::error::SendError<WebSocketMessage>> {
169        self.sender.send(message).await
170    }
171}
172
173/// WebSocket server for managing real-time connections.
174pub struct WebSocketServer {
175    #[allow(dead_code)]
176    config: WebSocketConfig,
177    node_id: NodeId,
178    /// Active connections by session ID.
179    connections: Arc<RwLock<HashMap<SessionId, WebSocketConnection>>>,
180    /// Subscription to session mapping for fast lookup.
181    subscription_sessions: Arc<RwLock<HashMap<SubscriptionId, SessionId>>>,
182}
183
184impl WebSocketServer {
185    /// Create a new WebSocket server.
186    pub fn new(node_id: NodeId, config: WebSocketConfig) -> Self {
187        Self {
188            config,
189            node_id,
190            connections: Arc::new(RwLock::new(HashMap::new())),
191            subscription_sessions: Arc::new(RwLock::new(HashMap::new())),
192        }
193    }
194
195    /// Get the node ID.
196    pub fn node_id(&self) -> NodeId {
197        self.node_id
198    }
199
200    /// Get the configuration.
201    pub fn config(&self) -> &WebSocketConfig {
202        &self.config
203    }
204
205    /// Register a new connection.
206    pub async fn register_connection(
207        &self,
208        session_id: SessionId,
209        sender: mpsc::Sender<WebSocketMessage>,
210    ) {
211        let connection = WebSocketConnection::new(session_id, sender);
212        let mut connections = self.connections.write().await;
213        connections.insert(session_id, connection);
214    }
215
216    /// Remove a connection.
217    pub async fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
218        let mut connections = self.connections.write().await;
219        if let Some(conn) = connections.remove(&session_id) {
220            // Clean up subscription mappings
221            let mut sub_sessions = self.subscription_sessions.write().await;
222            for sub_id in &conn.subscriptions {
223                sub_sessions.remove(sub_id);
224            }
225            Some(conn.subscriptions)
226        } else {
227            None
228        }
229    }
230
231    /// Add a subscription to a connection.
232    pub async fn add_subscription(
233        &self,
234        session_id: SessionId,
235        subscription_id: SubscriptionId,
236    ) -> forge_core::Result<()> {
237        let mut connections = self.connections.write().await;
238        let conn = connections
239            .get_mut(&session_id)
240            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
241
242        // Check subscription limit
243        if conn.subscriptions.len() >= self.config.max_subscriptions_per_connection {
244            return Err(forge_core::ForgeError::Validation(format!(
245                "Maximum subscriptions per connection ({}) exceeded",
246                self.config.max_subscriptions_per_connection
247            )));
248        }
249
250        conn.add_subscription(subscription_id);
251
252        // Update subscription to session mapping
253        let mut sub_sessions = self.subscription_sessions.write().await;
254        sub_sessions.insert(subscription_id, session_id);
255
256        Ok(())
257    }
258
259    /// Remove a subscription from a connection.
260    pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
261        let session_id = {
262            let mut sub_sessions = self.subscription_sessions.write().await;
263            sub_sessions.remove(&subscription_id)
264        };
265
266        if let Some(session_id) = session_id {
267            let mut connections = self.connections.write().await;
268            if let Some(conn) = connections.get_mut(&session_id) {
269                conn.remove_subscription(subscription_id);
270            }
271        }
272    }
273
274    /// Send a message to a specific session.
275    pub async fn send_to_session(
276        &self,
277        session_id: SessionId,
278        message: WebSocketMessage,
279    ) -> forge_core::Result<()> {
280        let connections = self.connections.read().await;
281        let conn = connections
282            .get(&session_id)
283            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
284
285        conn.send(message)
286            .await
287            .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
288    }
289
290    /// Send a delta to all sessions subscribed to a subscription.
291    pub async fn broadcast_delta(
292        &self,
293        subscription_id: SubscriptionId,
294        delta: Delta<serde_json::Value>,
295    ) -> forge_core::Result<()> {
296        let session_id = {
297            let sub_sessions = self.subscription_sessions.read().await;
298            sub_sessions.get(&subscription_id).copied()
299        };
300
301        if let Some(session_id) = session_id {
302            let message = WebSocketMessage::DeltaUpdate {
303                subscription_id,
304                delta,
305            };
306            self.send_to_session(session_id, message).await?;
307        }
308
309        Ok(())
310    }
311
312    /// Get connection count.
313    pub async fn connection_count(&self) -> usize {
314        self.connections.read().await.len()
315    }
316
317    /// Get subscription count.
318    pub async fn subscription_count(&self) -> usize {
319        self.subscription_sessions.read().await.len()
320    }
321
322    /// Get server statistics.
323    pub async fn stats(&self) -> WebSocketStats {
324        let connections = self.connections.read().await;
325        let total_subscriptions: usize = connections.values().map(|c| c.subscriptions.len()).sum();
326
327        WebSocketStats {
328            connections: connections.len(),
329            subscriptions: total_subscriptions,
330            node_id: self.node_id,
331        }
332    }
333
334    /// Cleanup stale connections.
335    pub async fn cleanup_stale(&self, max_idle: Duration) {
336        let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_idle).unwrap();
337        let mut connections = self.connections.write().await;
338        let mut sub_sessions = self.subscription_sessions.write().await;
339
340        connections.retain(|_, conn| {
341            if conn.last_active < cutoff {
342                // Clean up subscription mappings
343                for sub_id in &conn.subscriptions {
344                    sub_sessions.remove(sub_id);
345                }
346                false
347            } else {
348                true
349            }
350        });
351    }
352}
353
354/// WebSocket server statistics.
355#[derive(Debug, Clone)]
356pub struct WebSocketStats {
357    /// Number of active connections.
358    pub connections: usize,
359    /// Total subscriptions across all connections.
360    pub subscriptions: usize,
361    /// Node ID.
362    pub node_id: NodeId,
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server with the default configuration and returns it with its node ID.
    fn default_server() -> (NodeId, WebSocketServer) {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, WebSocketConfig::default());
        (node_id, server)
    }

    #[test]
    fn test_websocket_config_default() {
        let defaults = WebSocketConfig::default();

        assert_eq!(defaults.max_subscriptions_per_connection, 50);
        assert_eq!(defaults.subscription_rate_limit, 100);
        assert!(defaults.reconnect.enabled);
    }

    #[test]
    fn test_reconnect_config_default() {
        let defaults = ReconnectConfig::default();

        assert!(defaults.enabled);
        assert_eq!(defaults.max_attempts, 10);
        assert_eq!(defaults.backoff, BackoffStrategy::Exponential);
    }

    #[tokio::test]
    async fn test_websocket_server_creation() {
        let (node_id, server) = default_server();

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count().await, 0);
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_connection() {
        let (_, server) = default_server();
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        // Registering makes the connection visible in the count.
        server.register_connection(session_id, tx).await;
        assert_eq!(server.connection_count().await, 1);

        // Removing reports the connection and empties the registry.
        let removed = server.remove_connection(session_id).await;
        assert!(removed.is_some());
        assert_eq!(server.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_subscription() {
        let (_, server) = default_server();
        let session_id = SessionId::new();
        let subscription_id = SubscriptionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();
        assert_eq!(server.subscription_count().await, 1);

        server.remove_subscription(subscription_id).await;
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_subscription_limit() {
        let config = WebSocketConfig {
            max_subscriptions_per_connection: 2,
            ..Default::default()
        };
        let server = WebSocketServer::new(NodeId::new(), config);
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;

        // Up to the limit, every subscription is accepted.
        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .await
                .unwrap();
        }

        // One past the limit is rejected.
        let over_limit = server
            .add_subscription(session_id, SubscriptionId::new())
            .await;
        assert!(over_limit.is_err());
    }

    #[tokio::test]
    async fn test_websocket_stats() {
        let (node_id, server) = default_server();
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .await
                .unwrap();
        }

        let stats = server.stats().await;
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}