Skip to main content

forge_runtime/realtime/
message.rs

1use std::sync::atomic::{AtomicU32, Ordering};
2use std::time::Duration;
3
4use dashmap::DashMap;
5use serde::Serialize;
6use tokio::sync::mpsc;
7
8use forge_core::cluster::NodeId;
9use forge_core::realtime::{Delta, SessionId, SubscriptionId};
10
/// Tunable limits for the realtime session server.
#[derive(Debug, Clone)]
pub struct RealtimeConfig {
    /// Upper bound on how many subscriptions a single session may hold at once.
    pub max_subscriptions_per_session: usize,
}

impl Default for RealtimeConfig {
    /// Allows 50 subscriptions per session.
    fn default() -> Self {
        const DEFAULT_MAX_SUBSCRIPTIONS: usize = 50;
        RealtimeConfig {
            max_subscriptions_per_session: DEFAULT_MAX_SUBSCRIPTIONS,
        }
    }
}
23
/// Job data sent to client (subset of internal JobRecord).
///
/// Serialized via serde's derive; field order here is the serialized field order.
#[derive(Debug, Clone, Serialize)]
pub struct JobData {
    /// Identifier of the job.
    pub job_id: String,
    /// Job status as a string; the set of possible values is defined outside this file.
    pub status: String,
    /// Reported progress, if any (presumably a 0–100 percentage — confirm with the producer).
    pub progress_percent: Option<i32>,
    /// Optional progress message text.
    pub progress_message: Option<String>,
    /// Job output as arbitrary JSON, if present.
    pub output: Option<serde_json::Value>,
    /// Error text, if present.
    pub error: Option<String>,
}
34
/// Workflow data sent to client.
///
/// Serialized via serde's derive; field order here is the serialized field order.
#[derive(Debug, Clone, Serialize)]
pub struct WorkflowData {
    /// Identifier of the workflow.
    pub workflow_id: String,
    /// Workflow status as a string; the set of possible values is defined outside this file.
    pub status: String,
    /// Name of the current step, if any.
    pub current_step: Option<String>,
    /// Per-step details.
    pub steps: Vec<WorkflowStepData>,
    /// Workflow output as arbitrary JSON, if present.
    pub output: Option<serde_json::Value>,
    /// Error text, if present.
    pub error: Option<String>,
}
45
/// Workflow step data sent to client.
///
/// Serialized via serde's derive; field order here is the serialized field order.
#[derive(Debug, Clone, Serialize)]
pub struct WorkflowStepData {
    /// Step name.
    pub name: String,
    /// Step status as a string; the set of possible values is defined outside this file.
    pub status: String,
    /// Error text for this step, if present.
    pub error: Option<String>,
}
53
/// Message types for real-time communication.
///
/// Values of this type are what a session's outbound `mpsc::Sender<RealtimeMessage>` carries.
/// Some variants (`Subscribe`, `Unsubscribe`) look like client requests rather than pushes.
/// NOTE(review): no serde derive here, so the wire encoding lives elsewhere — confirm where.
#[derive(Debug, Clone)]
pub enum RealtimeMessage {
    /// Request to subscribe to `query` with JSON `args`.
    /// `id` presumably correlates replies such as `ErrorWithId` — confirm with the handler.
    Subscribe {
        id: String,
        query: String,
        args: serde_json::Value,
    },
    /// Request to cancel the subscription `subscription_id`.
    Unsubscribe {
        subscription_id: SubscriptionId,
    },
    /// Ping message (presumably a keep-alive probe).
    Ping,
    /// Pong message (presumably the reply to `Ping`).
    Pong,
    /// JSON payload for a subscription (presumably the full result, as opposed to `DeltaUpdate`).
    Data {
        subscription_id: String,
        data: serde_json::Value,
    },
    /// Incremental change for a subscription. `SessionServer::broadcast_delta` fills
    /// `subscription_id` with the internal `SubscriptionId` rendered via `to_string()`.
    DeltaUpdate {
        subscription_id: String,
        delta: Delta<serde_json::Value>,
    },
    /// Job state update for the client subscription `client_sub_id`.
    JobUpdate {
        client_sub_id: String,
        job: JobData,
    },
    /// Workflow state update for the client subscription `client_sub_id`.
    WorkflowUpdate {
        client_sub_id: String,
        workflow: WorkflowData,
    },
    /// Error with a machine-readable `code` and a `message`, not tied to a request id.
    Error {
        code: String,
        message: String,
    },
    /// Error tied to the request identified by `id`.
    ErrorWithId {
        id: String,
        code: String,
        message: String,
    },
    /// Authentication succeeded.
    AuthSuccess,
    /// Authentication failed, with the `reason`.
    AuthFailed {
        reason: String,
    },
    /// Sent to slow clients before disconnecting them.
    /// Delivered best-effort with `try_send`, so it is lost if the buffer is still full.
    Lagging,
}
99
/// Per-session state with backpressure tracking.
struct SessionEntry {
    /// Channel used to deliver messages to this session.
    sender: mpsc::Sender<RealtimeMessage>,
    /// Subscriptions owned by this session; each is also indexed in
    /// `SessionServer::subscription_sessions`.
    subscriptions: Vec<SubscriptionId>,
    // Recorded at registration but not read anywhere in this file, hence the allow.
    #[allow(dead_code)]
    connected_at: chrono::DateTime<chrono::Utc>,
    /// Last recorded activity; `cleanup_stale` removes sessions whose value is older than
    /// its idle cutoff.
    last_active: chrono::DateTime<chrono::Utc>,
    /// Consecutive `try_send` attempts that found the buffer full. Reset to 0 on a successful send.
    consecutive_drops: AtomicU32,
}

/// Maximum consecutive drops before evicting a slow client.
///
/// `try_send_to_session` tolerates this many consecutive full-buffer drops; the next
/// full attempt after that evicts the session.
const MAX_CONSECUTIVE_DROPS: u32 = 10;
113
/// Tracks the realtime sessions registered with one node.
///
/// For each session it keeps the outbound channel and owned subscriptions, plus a reverse
/// index from subscription to owning session.
pub struct SessionServer {
    /// Limits applied to sessions (e.g. the per-session subscription cap).
    config: RealtimeConfig,
    /// Node identifier given at construction; reported in `SessionStats`.
    node_id: NodeId,
    /// Active connections by session ID. DashMap for concurrent access.
    connections: DashMap<SessionId, SessionEntry>,
    /// Subscription to session mapping for fast reverse lookup.
    subscription_sessions: DashMap<SubscriptionId, SessionId>,
}
122
123impl SessionServer {
124    /// Create a new session server.
125    pub fn new(node_id: NodeId, config: RealtimeConfig) -> Self {
126        Self {
127            config,
128            node_id,
129            connections: DashMap::new(),
130            subscription_sessions: DashMap::new(),
131        }
132    }
133
134    pub fn node_id(&self) -> NodeId {
135        self.node_id
136    }
137
138    pub fn config(&self) -> &RealtimeConfig {
139        &self.config
140    }
141
142    /// Register a new connection.
143    pub fn register_connection(
144        &self,
145        session_id: SessionId,
146        sender: mpsc::Sender<RealtimeMessage>,
147    ) {
148        let now = chrono::Utc::now();
149        let entry = SessionEntry {
150            sender,
151            subscriptions: Vec::new(),
152            connected_at: now,
153            last_active: now,
154            consecutive_drops: AtomicU32::new(0),
155        };
156        self.connections.insert(session_id, entry);
157    }
158
159    /// Remove a connection.
160    pub fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
161        if let Some((_, conn)) = self.connections.remove(&session_id) {
162            for sub_id in &conn.subscriptions {
163                self.subscription_sessions.remove(sub_id);
164            }
165            Some(conn.subscriptions)
166        } else {
167            None
168        }
169    }
170
171    /// Add a subscription to a connection.
172    pub fn add_subscription(
173        &self,
174        session_id: SessionId,
175        subscription_id: SubscriptionId,
176    ) -> forge_core::Result<()> {
177        let mut conn = self
178            .connections
179            .get_mut(&session_id)
180            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
181
182        if conn.subscriptions.len() >= self.config.max_subscriptions_per_session {
183            return Err(forge_core::ForgeError::Validation(format!(
184                "Maximum subscriptions per session ({}) exceeded",
185                self.config.max_subscriptions_per_session
186            )));
187        }
188
189        conn.subscriptions.push(subscription_id);
190        drop(conn);
191
192        self.subscription_sessions
193            .insert(subscription_id, session_id);
194
195        Ok(())
196    }
197
198    /// Remove a subscription from a connection.
199    pub fn remove_subscription(&self, subscription_id: SubscriptionId) {
200        if let Some((_, session_id)) = self.subscription_sessions.remove(&subscription_id)
201            && let Some(mut conn) = self.connections.get_mut(&session_id)
202        {
203            conn.subscriptions.retain(|id| *id != subscription_id);
204        }
205    }
206
207    /// Non-blocking send with backpressure. Returns false if client was evicted.
208    pub fn try_send_to_session(
209        &self,
210        session_id: SessionId,
211        message: RealtimeMessage,
212    ) -> Result<(), SendError> {
213        let conn = self
214            .connections
215            .get(&session_id)
216            .ok_or(SendError::SessionNotFound)?;
217
218        match conn.sender.try_send(message) {
219            Ok(()) => {
220                conn.consecutive_drops.store(0, Ordering::Relaxed);
221                Ok(())
222            }
223            Err(mpsc::error::TrySendError::Full(_)) => {
224                let drops = conn.consecutive_drops.fetch_add(1, Ordering::Relaxed);
225                if drops >= MAX_CONSECUTIVE_DROPS {
226                    // Try to send lagging notification before evicting
227                    let _ = conn.sender.try_send(RealtimeMessage::Lagging);
228                    drop(conn);
229                    self.evict_session(session_id);
230                    Err(SendError::Evicted)
231                } else {
232                    Err(SendError::Full)
233                }
234            }
235            Err(mpsc::error::TrySendError::Closed(_)) => {
236                drop(conn);
237                self.remove_connection(session_id);
238                Err(SendError::Closed)
239            }
240        }
241    }
242
243    /// Blocking send for initial data delivery where we need backpressure.
244    pub async fn send_to_session(
245        &self,
246        session_id: SessionId,
247        message: RealtimeMessage,
248    ) -> forge_core::Result<()> {
249        let sender = {
250            let conn = self.connections.get(&session_id).ok_or_else(|| {
251                forge_core::ForgeError::Validation("Session not found".to_string())
252            })?;
253            conn.sender.clone()
254        };
255
256        sender
257            .send(message)
258            .await
259            .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
260    }
261
262    /// Send a delta to all sessions subscribed to a subscription.
263    pub async fn broadcast_delta(
264        &self,
265        subscription_id: SubscriptionId,
266        delta: Delta<serde_json::Value>,
267    ) -> forge_core::Result<()> {
268        let session_id = self.subscription_sessions.get(&subscription_id).map(|r| *r);
269
270        if let Some(session_id) = session_id {
271            let message = RealtimeMessage::DeltaUpdate {
272                subscription_id: subscription_id.to_string(),
273                delta,
274            };
275            self.send_to_session(session_id, message).await?;
276        }
277
278        Ok(())
279    }
280
281    /// Evict a slow session.
282    fn evict_session(&self, session_id: SessionId) {
283        tracing::warn!(?session_id, "Evicting slow client");
284        self.remove_connection(session_id);
285    }
286
287    /// Get connection count.
288    pub fn connection_count(&self) -> usize {
289        self.connections.len()
290    }
291
292    /// Get subscription count.
293    pub fn subscription_count(&self) -> usize {
294        self.subscription_sessions.len()
295    }
296
297    /// Get server statistics.
298    pub fn stats(&self) -> SessionStats {
299        let total_subscriptions: usize =
300            self.connections.iter().map(|c| c.subscriptions.len()).sum();
301
302        SessionStats {
303            connections: self.connections.len(),
304            subscriptions: total_subscriptions,
305            node_id: self.node_id,
306        }
307    }
308
309    /// Cleanup stale connections.
310    pub fn cleanup_stale(&self, max_idle: Duration) {
311        let cutoff = chrono::Utc::now()
312            - chrono::Duration::from_std(max_idle).unwrap_or(chrono::TimeDelta::MAX);
313
314        let stale: Vec<SessionId> = self
315            .connections
316            .iter()
317            .filter(|entry| entry.last_active < cutoff)
318            .map(|entry| *entry.key())
319            .collect();
320
321        for session_id in stale {
322            self.remove_connection(session_id);
323        }
324    }
325}
326
/// Error returned by [`SessionServer::try_send_to_session`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// No connection is registered for the session.
    SessionNotFound,
    /// The session's buffer was full; the message was dropped.
    Full,
    /// The session's receiver was closed; the session has been removed.
    Closed,
    /// The session stayed full for too many consecutive sends and was evicted.
    Evicted,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::SessionNotFound => "session not found",
            Self::Full => "session send buffer is full",
            Self::Closed => "session channel is closed",
            Self::Evicted => "session evicted for falling behind",
        };
        f.write_str(text)
    }
}

impl std::error::Error for SendError {}
335
/// Session server statistics, as produced by `SessionServer::stats`.
#[derive(Debug, Clone)]
pub struct SessionStats {
    /// Number of registered connections.
    pub connections: usize,
    /// Subscriptions summed over every connection's own subscription list.
    pub subscriptions: usize,
    /// Node identifier of the server that produced these statistics.
    pub node_id: NodeId,
}
343
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Server with the default config and one registered session whose channel holds
    /// `capacity` messages. The receiver is returned so the caller controls its lifetime.
    fn server_with_session(
        capacity: usize,
    ) -> (SessionServer, SessionId, mpsc::Receiver<RealtimeMessage>) {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(capacity);
        server.register_connection(session_id, tx);
        (server, session_id, rx)
    }

    #[test]
    fn test_realtime_config_default() {
        let config = RealtimeConfig::default();
        assert_eq!(config.max_subscriptions_per_session, 50);
    }

    #[test]
    fn test_session_server_creation() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count(), 0);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_session_connection() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx);
        assert_eq!(server.connection_count(), 1);

        let removed = server.remove_connection(session_id);
        assert!(removed.is_some());
        assert_eq!(server.connection_count(), 0);
    }

    #[test]
    fn test_session_subscription() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let subscription_id = SubscriptionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx);
        server
            .add_subscription(session_id, subscription_id)
            .unwrap();

        assert_eq!(server.subscription_count(), 1);

        server.remove_subscription(subscription_id);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_session_subscription_limit() {
        let node_id = NodeId::new();
        let config = RealtimeConfig {
            max_subscriptions_per_session: 2,
        };
        let server = SessionServer::new(node_id, config);
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx);

        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        let result = server.add_subscription(session_id, SubscriptionId::new());
        assert!(result.is_err());
    }

    #[test]
    fn test_add_subscription_unknown_session() {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());

        let result = server.add_subscription(SessionId::new(), SubscriptionId::new());
        assert!(result.is_err());
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_try_send_backpressure() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        // Tiny buffer to trigger backpressure
        let (tx, _rx) = mpsc::channel(1);

        server.register_connection(session_id, tx);

        // First send should succeed
        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(result.is_ok());

        // Second send to full buffer should return Full
        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::Full)));
    }

    #[test]
    fn test_try_send_unknown_session() {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());

        let result = server.try_send_to_session(SessionId::new(), RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::SessionNotFound)));
    }

    #[test]
    fn test_try_send_evicts_after_max_consecutive_drops() {
        let (server, session_id, _rx) = server_with_session(1);

        // Fill the single-slot buffer.
        assert!(server.try_send_to_session(session_id, RealtimeMessage::Ping).is_ok());

        // The first MAX_CONSECUTIVE_DROPS full attempts only drop the message.
        for _ in 0..MAX_CONSECUTIVE_DROPS {
            let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
            assert!(matches!(result, Err(SendError::Full)));
        }

        // The next full attempt evicts the session.
        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::Evicted)));
        assert_eq!(server.connection_count(), 0);

        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::SessionNotFound)));
    }

    #[test]
    fn test_successful_send_resets_drop_counter() {
        let (server, session_id, mut rx) = server_with_session(1);

        assert!(server.try_send_to_session(session_id, RealtimeMessage::Ping).is_ok());
        for _ in 0..MAX_CONSECUTIVE_DROPS {
            let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
            assert!(matches!(result, Err(SendError::Full)));
        }

        // Free the slot so the next send succeeds and clears the drop streak.
        assert!(rx.try_recv().is_ok());
        assert!(server.try_send_to_session(session_id, RealtimeMessage::Ping).is_ok());

        // A fresh streak of MAX_CONSECUTIVE_DROPS drops must not evict.
        for _ in 0..MAX_CONSECUTIVE_DROPS {
            let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
            assert!(matches!(result, Err(SendError::Full)));
        }
        assert_eq!(server.connection_count(), 1);
    }

    #[test]
    fn test_try_send_closed_removes_session() {
        let (server, session_id, rx) = server_with_session(1);
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        drop(rx);

        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::Closed)));
        assert_eq!(server.connection_count(), 0);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_remove_connection_returns_and_unindexes_subscriptions() {
        let (server, session_id, _rx) = server_with_session(100);
        let first = SubscriptionId::new();
        let second = SubscriptionId::new();
        server.add_subscription(session_id, first).unwrap();
        server.add_subscription(session_id, second).unwrap();

        let removed = server.remove_connection(session_id).unwrap();
        assert_eq!(removed, vec![first, second]);
        assert_eq!(server.subscription_count(), 0);

        assert!(server.remove_connection(session_id).is_none());
    }

    #[test]
    fn test_cleanup_stale_keeps_recent_sessions() {
        let (server, session_id, _rx) = server_with_session(100);
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        server.cleanup_stale(Duration::from_secs(3600));

        assert_eq!(server.connection_count(), 1);
        assert_eq!(server.subscription_count(), 1);
    }

    #[test]
    fn test_session_stats() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx);
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        let stats = server.stats();
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}