forge_runtime/realtime/
websocket.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::{RwLock, mpsc};
6
7use forge_core::cluster::NodeId;
8use forge_core::realtime::{Delta, SessionId, SubscriptionId};
9
10use crate::gateway::websocket::{JobData, WorkflowData};
11
12/// WebSocket server configuration.
13#[derive(Debug, Clone)]
14pub struct WebSocketConfig {
15    /// Maximum subscriptions per connection.
16    pub max_subscriptions_per_connection: usize,
17    /// Subscription timeout.
18    pub subscription_timeout: Duration,
19    /// Rate limit for subscription creation (per minute).
20    pub subscription_rate_limit: usize,
21    /// Heartbeat interval for keepalive.
22    pub heartbeat_interval: Duration,
23    /// Maximum message size in bytes.
24    pub max_message_size: usize,
25    /// Reconnect settings.
26    pub reconnect: ReconnectConfig,
27}
28
29impl Default for WebSocketConfig {
30    fn default() -> Self {
31        Self {
32            max_subscriptions_per_connection: 50,
33            subscription_timeout: Duration::from_secs(30),
34            subscription_rate_limit: 100,
35            heartbeat_interval: Duration::from_secs(30),
36            max_message_size: 1024 * 1024, // 1MB
37            reconnect: ReconnectConfig::default(),
38        }
39    }
40}
41
42/// Reconnection configuration.
43#[derive(Debug, Clone)]
44pub struct ReconnectConfig {
45    /// Whether reconnection is enabled.
46    pub enabled: bool,
47    /// Maximum reconnection attempts.
48    pub max_attempts: usize,
49    /// Initial delay between attempts.
50    pub delay: Duration,
51    /// Maximum delay between attempts.
52    pub max_delay: Duration,
53    /// Backoff strategy.
54    pub backoff: BackoffStrategy,
55}
56
57impl Default for ReconnectConfig {
58    fn default() -> Self {
59        Self {
60            enabled: true,
61            max_attempts: 10,
62            delay: Duration::from_secs(1),
63            max_delay: Duration::from_secs(30),
64            backoff: BackoffStrategy::Exponential,
65        }
66    }
67}
68
/// How the delay between reconnection attempts is chosen.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BackoffStrategy {
    /// Linear backoff between attempts.
    Linear,
    /// Exponential backoff between attempts.
    Exponential,
    /// The same fixed delay for every attempt.
    Fixed,
}
79
80/// Message types for WebSocket communication.
81#[derive(Debug, Clone)]
82pub enum WebSocketMessage {
83    /// Subscribe to a query.
84    Subscribe {
85        id: String,
86        query: String,
87        args: serde_json::Value,
88    },
89    /// Unsubscribe from a subscription.
90    Unsubscribe { subscription_id: SubscriptionId },
91    /// Ping for keepalive.
92    Ping,
93    /// Pong response.
94    Pong,
95    /// Initial data for subscription.
96    Data {
97        subscription_id: SubscriptionId,
98        data: serde_json::Value,
99    },
100    /// Delta update for subscription.
101    DeltaUpdate {
102        subscription_id: SubscriptionId,
103        delta: Delta<serde_json::Value>,
104    },
105    /// Job progress update.
106    JobUpdate { client_sub_id: String, job: JobData },
107    /// Workflow progress update.
108    WorkflowUpdate {
109        client_sub_id: String,
110        workflow: WorkflowData,
111    },
112    /// Error message.
113    Error { code: String, message: String },
114    /// Error message with subscription ID.
115    ErrorWithId {
116        id: String,
117        code: String,
118        message: String,
119    },
120    /// Authentication successful.
121    AuthSuccess,
122    /// Authentication failed.
123    AuthFailed { reason: String },
124}
125
126/// Represents a connected WebSocket client.
127#[derive(Debug)]
128pub struct WebSocketConnection {
129    /// Session ID for this connection.
130    #[allow(dead_code)]
131    pub session_id: SessionId,
132    /// Active subscriptions.
133    pub subscriptions: Vec<SubscriptionId>,
134    /// Sender for outgoing messages.
135    pub sender: mpsc::Sender<WebSocketMessage>,
136    /// When the connection was established.
137    #[allow(dead_code)]
138    pub connected_at: chrono::DateTime<chrono::Utc>,
139    /// Last activity time.
140    pub last_active: chrono::DateTime<chrono::Utc>,
141}
142
143impl WebSocketConnection {
144    /// Create a new connection.
145    pub fn new(session_id: SessionId, sender: mpsc::Sender<WebSocketMessage>) -> Self {
146        let now = chrono::Utc::now();
147        Self {
148            session_id,
149            subscriptions: Vec::new(),
150            sender,
151            connected_at: now,
152            last_active: now,
153        }
154    }
155
156    /// Add a subscription.
157    pub fn add_subscription(&mut self, subscription_id: SubscriptionId) {
158        self.subscriptions.push(subscription_id);
159        self.last_active = chrono::Utc::now();
160    }
161
162    /// Remove a subscription.
163    pub fn remove_subscription(&mut self, subscription_id: SubscriptionId) {
164        self.subscriptions.retain(|id| *id != subscription_id);
165        self.last_active = chrono::Utc::now();
166    }
167
168    /// Send a message to the client.
169    pub async fn send(
170        &self,
171        message: WebSocketMessage,
172    ) -> Result<(), mpsc::error::SendError<WebSocketMessage>> {
173        self.sender.send(message).await
174    }
175}
176
177/// WebSocket server for managing real-time connections.
178pub struct WebSocketServer {
179    #[allow(dead_code)]
180    config: WebSocketConfig,
181    node_id: NodeId,
182    /// Active connections by session ID.
183    connections: Arc<RwLock<HashMap<SessionId, WebSocketConnection>>>,
184    /// Subscription to session mapping for fast lookup.
185    subscription_sessions: Arc<RwLock<HashMap<SubscriptionId, SessionId>>>,
186}
187
188impl WebSocketServer {
189    /// Create a new WebSocket server.
190    pub fn new(node_id: NodeId, config: WebSocketConfig) -> Self {
191        Self {
192            config,
193            node_id,
194            connections: Arc::new(RwLock::new(HashMap::new())),
195            subscription_sessions: Arc::new(RwLock::new(HashMap::new())),
196        }
197    }
198
199    /// Get the node ID.
200    pub fn node_id(&self) -> NodeId {
201        self.node_id
202    }
203
204    /// Get the configuration.
205    pub fn config(&self) -> &WebSocketConfig {
206        &self.config
207    }
208
209    /// Register a new connection.
210    pub async fn register_connection(
211        &self,
212        session_id: SessionId,
213        sender: mpsc::Sender<WebSocketMessage>,
214    ) {
215        let connection = WebSocketConnection::new(session_id, sender);
216        let mut connections = self.connections.write().await;
217        connections.insert(session_id, connection);
218    }
219
220    /// Remove a connection.
221    pub async fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
222        let mut connections = self.connections.write().await;
223        if let Some(conn) = connections.remove(&session_id) {
224            // Clean up subscription mappings
225            let mut sub_sessions = self.subscription_sessions.write().await;
226            for sub_id in &conn.subscriptions {
227                sub_sessions.remove(sub_id);
228            }
229            Some(conn.subscriptions)
230        } else {
231            None
232        }
233    }
234
235    /// Add a subscription to a connection.
236    pub async fn add_subscription(
237        &self,
238        session_id: SessionId,
239        subscription_id: SubscriptionId,
240    ) -> forge_core::Result<()> {
241        let mut connections = self.connections.write().await;
242        let conn = connections
243            .get_mut(&session_id)
244            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
245
246        // Check subscription limit
247        if conn.subscriptions.len() >= self.config.max_subscriptions_per_connection {
248            return Err(forge_core::ForgeError::Validation(format!(
249                "Maximum subscriptions per connection ({}) exceeded",
250                self.config.max_subscriptions_per_connection
251            )));
252        }
253
254        conn.add_subscription(subscription_id);
255
256        // Update subscription to session mapping
257        let mut sub_sessions = self.subscription_sessions.write().await;
258        sub_sessions.insert(subscription_id, session_id);
259
260        Ok(())
261    }
262
263    /// Remove a subscription from a connection.
264    pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
265        let session_id = {
266            let mut sub_sessions = self.subscription_sessions.write().await;
267            sub_sessions.remove(&subscription_id)
268        };
269
270        if let Some(session_id) = session_id {
271            let mut connections = self.connections.write().await;
272            if let Some(conn) = connections.get_mut(&session_id) {
273                conn.remove_subscription(subscription_id);
274            }
275        }
276    }
277
278    /// Send a message to a specific session.
279    pub async fn send_to_session(
280        &self,
281        session_id: SessionId,
282        message: WebSocketMessage,
283    ) -> forge_core::Result<()> {
284        let connections = self.connections.read().await;
285        let conn = connections
286            .get(&session_id)
287            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
288
289        conn.send(message)
290            .await
291            .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
292    }
293
294    /// Send a delta to all sessions subscribed to a subscription.
295    pub async fn broadcast_delta(
296        &self,
297        subscription_id: SubscriptionId,
298        delta: Delta<serde_json::Value>,
299    ) -> forge_core::Result<()> {
300        let session_id = {
301            let sub_sessions = self.subscription_sessions.read().await;
302            sub_sessions.get(&subscription_id).copied()
303        };
304
305        if let Some(session_id) = session_id {
306            let message = WebSocketMessage::DeltaUpdate {
307                subscription_id,
308                delta,
309            };
310            self.send_to_session(session_id, message).await?;
311        }
312
313        Ok(())
314    }
315
316    /// Get connection count.
317    pub async fn connection_count(&self) -> usize {
318        self.connections.read().await.len()
319    }
320
321    /// Get subscription count.
322    pub async fn subscription_count(&self) -> usize {
323        self.subscription_sessions.read().await.len()
324    }
325
326    /// Get server statistics.
327    pub async fn stats(&self) -> WebSocketStats {
328        let connections = self.connections.read().await;
329        let total_subscriptions: usize = connections.values().map(|c| c.subscriptions.len()).sum();
330
331        WebSocketStats {
332            connections: connections.len(),
333            subscriptions: total_subscriptions,
334            node_id: self.node_id,
335        }
336    }
337
338    /// Cleanup stale connections.
339    pub async fn cleanup_stale(&self, max_idle: Duration) {
340        let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_idle).unwrap();
341        let mut connections = self.connections.write().await;
342        let mut sub_sessions = self.subscription_sessions.write().await;
343
344        connections.retain(|_, conn| {
345            if conn.last_active < cutoff {
346                // Clean up subscription mappings
347                for sub_id in &conn.subscriptions {
348                    sub_sessions.remove(sub_id);
349                }
350                false
351            } else {
352                true
353            }
354        });
355    }
356}
357
358/// WebSocket server statistics.
359#[derive(Debug, Clone)]
360pub struct WebSocketStats {
361    /// Number of active connections.
362    pub connections: usize,
363    /// Total subscriptions across all connections.
364    pub subscriptions: usize,
365    /// Node ID.
366    pub node_id: NodeId,
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server from `config` and returns it with the node ID it was given.
    fn server_with(config: WebSocketConfig) -> (NodeId, WebSocketServer) {
        let node_id = NodeId::new();
        (node_id, WebSocketServer::new(node_id, config))
    }

    /// Registers a fresh session on `server`; the receiver is returned so the caller
    /// can keep the channel open for the duration of the test.
    async fn connect(server: &WebSocketServer) -> (SessionId, mpsc::Receiver<WebSocketMessage>) {
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(100);
        server.register_connection(session_id, tx).await;
        (session_id, rx)
    }

    #[test]
    fn test_websocket_config_default() {
        let defaults = WebSocketConfig::default();

        assert_eq!(defaults.max_subscriptions_per_connection, 50);
        assert_eq!(defaults.subscription_rate_limit, 100);
        assert!(defaults.reconnect.enabled);
    }

    #[test]
    fn test_reconnect_config_default() {
        let defaults = ReconnectConfig::default();

        assert!(defaults.enabled);
        assert_eq!(defaults.max_attempts, 10);
        assert_eq!(defaults.backoff, BackoffStrategy::Exponential);
    }

    #[tokio::test]
    async fn test_websocket_server_creation() {
        let (node_id, server) = server_with(WebSocketConfig::default());

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count().await, 0);
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_connection() {
        let (_, server) = server_with(WebSocketConfig::default());

        let (session_id, _rx) = connect(&server).await;
        assert_eq!(server.connection_count().await, 1);

        let removed = server.remove_connection(session_id).await;
        assert!(removed.is_some());
        assert_eq!(server.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_subscription() {
        let (_, server) = server_with(WebSocketConfig::default());
        let subscription_id = SubscriptionId::new();
        let (session_id, _rx) = connect(&server).await;

        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();
        assert_eq!(server.subscription_count().await, 1);

        server.remove_subscription(subscription_id).await;
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_subscription_limit() {
        let (_, server) = server_with(WebSocketConfig {
            max_subscriptions_per_connection: 2,
            ..Default::default()
        });
        let (session_id, _rx) = connect(&server).await;

        // Up to the limit, every subscription is accepted.
        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .await
                .unwrap();
        }

        // One past the limit is rejected.
        let over_limit = server
            .add_subscription(session_id, SubscriptionId::new())
            .await;
        assert!(over_limit.is_err());
    }

    #[tokio::test]
    async fn test_websocket_stats() {
        let (node_id, server) = server_with(WebSocketConfig::default());
        let (session_id, _rx) = connect(&server).await;

        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .await
                .unwrap();
        }

        let stats = server.stats().await;
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}