Skip to main content

forge_runtime/realtime/
message.rs

1use std::sync::atomic::{AtomicU32, Ordering};
2use std::time::Duration;
3
4use dashmap::DashMap;
5use serde::Serialize;
6use tokio::sync::mpsc;
7
8use forge_core::cluster::NodeId;
9use forge_core::realtime::{Delta, SessionId, SubscriptionId};
10
/// Tunable limits for the realtime session server.
#[derive(Clone, Debug)]
pub struct RealtimeConfig {
    /// Upper bound on how many subscriptions a single session may hold at once.
    pub max_subscriptions_per_session: usize,
}

/// Value used for [`RealtimeConfig::max_subscriptions_per_session`] by `Default`.
const DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION: usize = 50;

impl Default for RealtimeConfig {
    /// Returns a configuration that allows 50 subscriptions per session.
    fn default() -> Self {
        RealtimeConfig {
            max_subscriptions_per_session: DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION,
        }
    }
}
23
24/// Job data sent to client (subset of internal JobRecord).
25#[derive(Debug, Clone, Serialize)]
26pub struct JobData {
27    pub job_id: String,
28    pub status: String,
29    #[serde(rename = "progress")]
30    pub progress_percent: Option<i32>,
31    #[serde(rename = "message")]
32    pub progress_message: Option<String>,
33    pub output: Option<serde_json::Value>,
34    pub error: Option<String>,
35}
36
37/// Workflow data sent to client.
38#[derive(Debug, Clone, Serialize)]
39pub struct WorkflowData {
40    pub workflow_id: String,
41    pub status: String,
42    #[serde(rename = "step")]
43    pub current_step: Option<String>,
44    pub steps: Vec<WorkflowStepData>,
45    pub output: Option<serde_json::Value>,
46    pub error: Option<String>,
47}
48
49/// Workflow step data sent to client.
50#[derive(Debug, Clone, Serialize)]
51pub struct WorkflowStepData {
52    pub name: String,
53    pub status: String,
54    pub error: Option<String>,
55}
56
57/// Message types for real-time communication.
58#[derive(Debug, Clone)]
59pub enum RealtimeMessage {
60    Subscribe {
61        id: String,
62        query: String,
63        args: serde_json::Value,
64    },
65    Unsubscribe {
66        subscription_id: SubscriptionId,
67    },
68    Ping,
69    Pong,
70    Data {
71        subscription_id: String,
72        data: serde_json::Value,
73    },
74    DeltaUpdate {
75        subscription_id: String,
76        delta: Delta<serde_json::Value>,
77    },
78    JobUpdate {
79        client_sub_id: String,
80        job: JobData,
81    },
82    WorkflowUpdate {
83        client_sub_id: String,
84        workflow: WorkflowData,
85    },
86    Error {
87        code: String,
88        message: String,
89    },
90    ErrorWithId {
91        id: String,
92        code: String,
93        message: String,
94    },
95    AuthSuccess,
96    AuthFailed {
97        reason: String,
98    },
99    /// Sent to slow clients before disconnecting them.
100    Lagging,
101}
102
103/// Per-session state with backpressure tracking.
104struct SessionEntry {
105    sender: mpsc::Sender<RealtimeMessage>,
106    subscriptions: Vec<SubscriptionId>,
107    connected_at: chrono::DateTime<chrono::Utc>,
108    last_active: chrono::DateTime<chrono::Utc>,
109    /// Consecutive failed try_send attempts. Resets on success.
110    consecutive_drops: AtomicU32,
111}
112
113/// Maximum consecutive drops before evicting a slow client.
114const MAX_CONSECUTIVE_DROPS: u32 = 10;
115
116pub struct SessionServer {
117    config: RealtimeConfig,
118    node_id: NodeId,
119    /// Active connections by session ID. DashMap for concurrent access.
120    connections: DashMap<SessionId, SessionEntry>,
121    /// Subscription to session mapping for fast reverse lookup.
122    subscription_sessions: DashMap<SubscriptionId, SessionId>,
123}
124
125impl SessionServer {
126    /// Create a new session server.
127    pub fn new(node_id: NodeId, config: RealtimeConfig) -> Self {
128        Self {
129            config,
130            node_id,
131            connections: DashMap::new(),
132            subscription_sessions: DashMap::new(),
133        }
134    }
135
136    pub fn node_id(&self) -> NodeId {
137        self.node_id
138    }
139
140    pub fn config(&self) -> &RealtimeConfig {
141        &self.config
142    }
143
144    /// Register a new connection.
145    pub fn register_connection(
146        &self,
147        session_id: SessionId,
148        sender: mpsc::Sender<RealtimeMessage>,
149    ) {
150        let now = chrono::Utc::now();
151        let entry = SessionEntry {
152            sender,
153            subscriptions: Vec::new(),
154            connected_at: now,
155            last_active: now,
156            consecutive_drops: AtomicU32::new(0),
157        };
158        self.connections.insert(session_id, entry);
159    }
160
161    /// Remove a connection.
162    pub fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
163        if let Some((_, conn)) = self.connections.remove(&session_id) {
164            for sub_id in &conn.subscriptions {
165                self.subscription_sessions.remove(sub_id);
166            }
167            Some(conn.subscriptions)
168        } else {
169            None
170        }
171    }
172
173    /// Add a subscription to a connection.
174    pub fn add_subscription(
175        &self,
176        session_id: SessionId,
177        subscription_id: SubscriptionId,
178    ) -> forge_core::Result<()> {
179        let mut conn = self
180            .connections
181            .get_mut(&session_id)
182            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
183
184        if conn.subscriptions.len() >= self.config.max_subscriptions_per_session {
185            return Err(forge_core::ForgeError::Validation(format!(
186                "Maximum subscriptions per session ({}) exceeded",
187                self.config.max_subscriptions_per_session
188            )));
189        }
190
191        conn.subscriptions.push(subscription_id);
192        drop(conn);
193
194        self.subscription_sessions
195            .insert(subscription_id, session_id);
196
197        Ok(())
198    }
199
200    /// Remove a subscription from a connection.
201    pub fn remove_subscription(&self, subscription_id: SubscriptionId) {
202        if let Some((_, session_id)) = self.subscription_sessions.remove(&subscription_id)
203            && let Some(mut conn) = self.connections.get_mut(&session_id)
204        {
205            conn.subscriptions.retain(|id| *id != subscription_id);
206        }
207    }
208
209    /// Non-blocking send with backpressure. Returns false if client was evicted.
210    pub fn try_send_to_session(
211        &self,
212        session_id: SessionId,
213        message: RealtimeMessage,
214    ) -> Result<(), SendError> {
215        let conn = self
216            .connections
217            .get(&session_id)
218            .ok_or(SendError::SessionNotFound)?;
219
220        match conn.sender.try_send(message) {
221            Ok(()) => {
222                conn.consecutive_drops.store(0, Ordering::Relaxed);
223                Ok(())
224            }
225            Err(mpsc::error::TrySendError::Full(_)) => {
226                let drops = conn.consecutive_drops.fetch_add(1, Ordering::Relaxed);
227                if drops >= MAX_CONSECUTIVE_DROPS {
228                    // Try to send lagging notification before evicting
229                    let _ = conn.sender.try_send(RealtimeMessage::Lagging);
230                    drop(conn);
231                    self.evict_session(session_id);
232                    Err(SendError::Evicted)
233                } else {
234                    Err(SendError::Full)
235                }
236            }
237            Err(mpsc::error::TrySendError::Closed(_)) => {
238                drop(conn);
239                self.remove_connection(session_id);
240                Err(SendError::Closed)
241            }
242        }
243    }
244
245    /// Blocking send for initial data delivery where we need backpressure.
246    pub async fn send_to_session(
247        &self,
248        session_id: SessionId,
249        message: RealtimeMessage,
250    ) -> forge_core::Result<()> {
251        let sender = {
252            let conn = self.connections.get(&session_id).ok_or_else(|| {
253                forge_core::ForgeError::Validation("Session not found".to_string())
254            })?;
255            conn.sender.clone()
256        };
257
258        sender
259            .send(message)
260            .await
261            .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
262    }
263
264    /// Send a delta to all sessions subscribed to a subscription.
265    pub async fn broadcast_delta(
266        &self,
267        subscription_id: SubscriptionId,
268        delta: Delta<serde_json::Value>,
269    ) -> forge_core::Result<()> {
270        let session_id = self.subscription_sessions.get(&subscription_id).map(|r| *r);
271
272        if let Some(session_id) = session_id {
273            let message = RealtimeMessage::DeltaUpdate {
274                subscription_id: subscription_id.to_string(),
275                delta,
276            };
277            self.send_to_session(session_id, message).await?;
278        }
279
280        Ok(())
281    }
282
283    /// Evict a slow session.
284    fn evict_session(&self, session_id: SessionId) {
285        tracing::warn!(?session_id, "Evicting slow client");
286        self.remove_connection(session_id);
287    }
288
289    /// Get connection count.
290    pub fn connection_count(&self) -> usize {
291        self.connections.len()
292    }
293
294    /// Get subscription count.
295    pub fn subscription_count(&self) -> usize {
296        self.subscription_sessions.len()
297    }
298
299    /// Get server statistics.
300    pub fn stats(&self) -> SessionStats {
301        let total_subscriptions: usize =
302            self.connections.iter().map(|c| c.subscriptions.len()).sum();
303
304        SessionStats {
305            connections: self.connections.len(),
306            subscriptions: total_subscriptions,
307            node_id: self.node_id,
308        }
309    }
310
311    /// Cleanup stale connections.
312    pub fn cleanup_stale(&self, max_idle: Duration) {
313        let cutoff = chrono::Utc::now()
314            - chrono::Duration::from_std(max_idle).unwrap_or(chrono::TimeDelta::MAX);
315
316        let stale: Vec<(SessionId, chrono::DateTime<chrono::Utc>)> = self
317            .connections
318            .iter()
319            .filter(|entry| entry.last_active < cutoff)
320            .map(|entry| (*entry.key(), entry.connected_at))
321            .collect();
322
323        if let Some((_, oldest_connected_at)) =
324            stale.iter().min_by_key(|(_, connected_at)| *connected_at)
325        {
326            tracing::debug!(
327                count = stale.len(),
328                oldest_connected_at = %oldest_connected_at,
329                "Cleaning up stale connections"
330            );
331        }
332
333        for (session_id, _) in stale {
334            self.remove_connection(session_id);
335        }
336    }
337}
338
/// Error returned by `SessionServer::try_send_to_session`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// No connection is registered for the session.
    SessionNotFound,
    /// The session's channel is full; the message was dropped.
    Full,
    /// The session's channel is closed; the connection has been removed.
    Closed,
    /// The session exceeded the consecutive-drop limit and was evicted.
    Evicted,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::SessionNotFound => "session not found",
            Self::Full => "session channel is full; message dropped",
            Self::Closed => "session channel is closed",
            Self::Evicted => "session evicted after too many consecutive dropped messages",
        };
        f.write_str(msg)
    }
}

// Lets `SendError` flow through `?` into `Box<dyn Error>` / `anyhow`-style errors.
impl std::error::Error for SendError {}
347
348/// Session server statistics.
349#[derive(Debug, Clone)]
350pub struct SessionStats {
351    pub connections: usize,
352    pub subscriptions: usize,
353    pub node_id: NodeId,
354}
355
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    //! Unit tests for `SessionServer` bookkeeping and `try_send_to_session` backpressure.

    use super::*;

    /// Build a server with default config for `node_id`.
    fn server_for(node_id: NodeId) -> SessionServer {
        SessionServer::new(node_id, RealtimeConfig::default())
    }

    /// Register a fresh session with a channel of `capacity`.
    ///
    /// Returns the session id and the receiving half. Keep the receiver alive, otherwise
    /// the channel reports `Closed`.
    fn register(
        server: &SessionServer,
        capacity: usize,
    ) -> (SessionId, mpsc::Receiver<RealtimeMessage>) {
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(capacity);
        server.register_connection(session_id, tx);
        (session_id, rx)
    }

    #[test]
    fn test_realtime_config_default() {
        let config = RealtimeConfig::default();
        assert_eq!(config.max_subscriptions_per_session, 50);
    }

    #[test]
    fn test_session_server_creation() {
        let node_id = NodeId::new();
        let server = server_for(node_id);

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count(), 0);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_session_connection() {
        let server = server_for(NodeId::new());
        let (session_id, _rx) = register(&server, 100);
        assert_eq!(server.connection_count(), 1);

        let removed = server.remove_connection(session_id);
        assert!(removed.is_some());
        assert_eq!(server.connection_count(), 0);
        assert!(server.remove_connection(session_id).is_none());
    }

    #[test]
    fn test_session_subscription() {
        let server = server_for(NodeId::new());
        let (session_id, _rx) = register(&server, 100);
        let subscription_id = SubscriptionId::new();

        server
            .add_subscription(session_id, subscription_id)
            .unwrap();
        assert_eq!(server.subscription_count(), 1);

        server.remove_subscription(subscription_id);
        assert_eq!(server.subscription_count(), 0);
        assert_eq!(server.stats().subscriptions, 0);
    }

    #[test]
    fn test_add_subscription_unknown_session() {
        let server = server_for(NodeId::new());
        let result = server.add_subscription(SessionId::new(), SubscriptionId::new());
        assert!(result.is_err());
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_session_subscription_limit() {
        let node_id = NodeId::new();
        let config = RealtimeConfig {
            max_subscriptions_per_session: 2,
        };
        let server = SessionServer::new(node_id, config);
        let (session_id, _rx) = register(&server, 100);

        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        let result = server.add_subscription(session_id, SubscriptionId::new());
        assert!(result.is_err());
        assert_eq!(server.subscription_count(), 2);
    }

    #[test]
    fn test_remove_connection_clears_subscriptions() {
        let server = server_for(NodeId::new());
        let (session_id, _rx) = register(&server, 100);
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        let removed = server.remove_connection(session_id).unwrap();
        assert_eq!(removed.len(), 2);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_try_send_unknown_session() {
        let server = server_for(NodeId::new());
        let result = server.try_send_to_session(SessionId::new(), RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::SessionNotFound)));
    }

    #[test]
    fn test_try_send_backpressure() {
        let server = server_for(NodeId::new());
        // Tiny buffer to trigger backpressure
        let (session_id, _rx) = register(&server, 1);

        // First send should succeed
        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(result.is_ok());

        // Second send to full buffer should return Full
        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::Full)));
    }

    #[test]
    fn test_try_send_evicts_after_max_consecutive_drops() {
        let server = server_for(NodeId::new());
        let (session_id, _rx) = register(&server, 1);
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        // Fill the single-slot buffer.
        assert!(server
            .try_send_to_session(session_id, RealtimeMessage::Ping)
            .is_ok());

        // The first MAX_CONSECUTIVE_DROPS drops are tolerated.
        for _ in 0..MAX_CONSECUTIVE_DROPS {
            let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
            assert!(matches!(result, Err(SendError::Full)));
        }

        // The next drop evicts the session together with its subscriptions.
        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::Evicted)));
        assert_eq!(server.connection_count(), 0);
        assert_eq!(server.subscription_count(), 0);

        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::SessionNotFound)));
    }

    #[test]
    fn test_successful_send_resets_drop_counter() {
        let server = server_for(NodeId::new());
        let (session_id, mut rx) = register(&server, 1);

        assert!(server
            .try_send_to_session(session_id, RealtimeMessage::Ping)
            .is_ok());
        for _ in 0..MAX_CONSECUTIVE_DROPS {
            let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
            assert!(matches!(result, Err(SendError::Full)));
        }

        // Drain the buffered message so the next send succeeds and resets the counter.
        rx.try_recv().unwrap();
        assert!(server
            .try_send_to_session(session_id, RealtimeMessage::Ping)
            .is_ok());

        // A full new run of tolerated drops must not evict the session.
        for _ in 0..MAX_CONSECUTIVE_DROPS {
            let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
            assert!(matches!(result, Err(SendError::Full)));
        }
        assert_eq!(server.connection_count(), 1);
    }

    #[test]
    fn test_try_send_closed_channel_removes_connection() {
        let server = server_for(NodeId::new());
        let (session_id, rx) = register(&server, 1);
        drop(rx);

        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::Closed)));
        assert_eq!(server.connection_count(), 0);
    }

    #[test]
    fn test_cleanup_stale_keeps_recent_connections() {
        let server = server_for(NodeId::new());
        let (_session_id, _rx) = register(&server, 1);

        server.cleanup_stale(Duration::from_secs(3600));
        assert_eq!(server.connection_count(), 1);
    }

    #[test]
    fn test_session_stats() {
        let node_id = NodeId::new();
        let server = server_for(node_id);
        let (session_id, _rx) = register(&server, 100);

        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        let stats = server.stats();
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}