Skip to main content

forge_runtime/realtime/
message.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde::Serialize;
6use tokio::sync::{RwLock, mpsc};
7
8use forge_core::cluster::NodeId;
9use forge_core::realtime::{Delta, SessionId, SubscriptionId};
10
/// Tunable limits applied by the realtime session server.
#[derive(Debug, Clone)]
pub struct RealtimeConfig {
    /// Cap on the number of subscriptions one session may hold; once a session
    /// has this many, `SessionServer::add_subscription` returns an error.
    pub max_subscriptions_per_session: usize,
}

/// Per-session subscription cap used by `RealtimeConfig::default()`.
const DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION: usize = 50;

impl Default for RealtimeConfig {
    /// Returns a config allowing `DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION` (50)
    /// subscriptions per session.
    fn default() -> Self {
        RealtimeConfig {
            max_subscriptions_per_session: DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION,
        }
    }
}
23
24/// Job data sent to client (subset of internal JobRecord).
25#[derive(Debug, Clone, Serialize)]
26pub struct JobData {
27    pub job_id: String,
28    pub status: String,
29    pub progress_percent: Option<i32>,
30    pub progress_message: Option<String>,
31    pub output: Option<serde_json::Value>,
32    pub error: Option<String>,
33}
34
35/// Workflow data sent to client.
36#[derive(Debug, Clone, Serialize)]
37pub struct WorkflowData {
38    pub workflow_id: String,
39    pub status: String,
40    pub current_step: Option<String>,
41    pub steps: Vec<WorkflowStepData>,
42    pub output: Option<serde_json::Value>,
43    pub error: Option<String>,
44}
45
46/// Workflow step data sent to client.
47#[derive(Debug, Clone, Serialize)]
48pub struct WorkflowStepData {
49    pub name: String,
50    pub status: String,
51    pub error: Option<String>,
52}
53
54/// Message types for real-time communication.
55#[derive(Debug, Clone)]
56pub enum RealtimeMessage {
57    /// Subscribe to a query.
58    Subscribe {
59        id: String,
60        query: String,
61        args: serde_json::Value,
62    },
63    /// Unsubscribe from a subscription.
64    Unsubscribe { subscription_id: SubscriptionId },
65    /// Ping for keepalive.
66    Ping,
67    /// Pong response.
68    Pong,
69    /// Initial data for subscription.
70    Data {
71        subscription_id: String,
72        data: serde_json::Value,
73    },
74    /// Delta update for subscription.
75    DeltaUpdate {
76        subscription_id: String,
77        delta: Delta<serde_json::Value>,
78    },
79    /// Job progress update.
80    JobUpdate { client_sub_id: String, job: JobData },
81    /// Workflow progress update.
82    WorkflowUpdate {
83        client_sub_id: String,
84        workflow: WorkflowData,
85    },
86    /// Error message.
87    Error { code: String, message: String },
88    /// Error message with subscription ID.
89    ErrorWithId {
90        id: String,
91        code: String,
92        message: String,
93    },
94    /// Authentication successful.
95    AuthSuccess,
96    /// Authentication failed.
97    AuthFailed { reason: String },
98}
99
100#[derive(Debug)]
101pub struct RealtimeSession {
102    #[allow(dead_code)]
103    pub session_id: SessionId,
104    pub subscriptions: Vec<SubscriptionId>,
105    pub sender: mpsc::Sender<RealtimeMessage>,
106    #[allow(dead_code)]
107    pub connected_at: chrono::DateTime<chrono::Utc>,
108    pub last_active: chrono::DateTime<chrono::Utc>,
109}
110
111impl RealtimeSession {
112    /// Create a new session.
113    pub fn new(session_id: SessionId, sender: mpsc::Sender<RealtimeMessage>) -> Self {
114        let now = chrono::Utc::now();
115        Self {
116            session_id,
117            subscriptions: Vec::new(),
118            sender,
119            connected_at: now,
120            last_active: now,
121        }
122    }
123
124    /// Add a subscription.
125    pub fn add_subscription(&mut self, subscription_id: SubscriptionId) {
126        self.subscriptions.push(subscription_id);
127        self.last_active = chrono::Utc::now();
128    }
129
130    /// Remove a subscription.
131    pub fn remove_subscription(&mut self, subscription_id: SubscriptionId) {
132        self.subscriptions.retain(|id| *id != subscription_id);
133        self.last_active = chrono::Utc::now();
134    }
135
136    /// Send a message to the client.
137    pub async fn send(
138        &self,
139        message: RealtimeMessage,
140    ) -> Result<(), mpsc::error::SendError<RealtimeMessage>> {
141        self.sender.send(message).await
142    }
143}
144
145pub struct SessionServer {
146    config: RealtimeConfig,
147    node_id: NodeId,
148    /// Active connections by session ID.
149    connections: Arc<RwLock<HashMap<SessionId, RealtimeSession>>>,
150    /// Subscription to session mapping for fast lookup.
151    subscription_sessions: Arc<RwLock<HashMap<SubscriptionId, SessionId>>>,
152}
153
154impl SessionServer {
155    /// Create a new session server.
156    pub fn new(node_id: NodeId, config: RealtimeConfig) -> Self {
157        Self {
158            config,
159            node_id,
160            connections: Arc::new(RwLock::new(HashMap::new())),
161            subscription_sessions: Arc::new(RwLock::new(HashMap::new())),
162        }
163    }
164
165    /// Get the node ID.
166    pub fn node_id(&self) -> NodeId {
167        self.node_id
168    }
169
170    /// Get the configuration.
171    pub fn config(&self) -> &RealtimeConfig {
172        &self.config
173    }
174
175    /// Register a new connection.
176    pub async fn register_connection(
177        &self,
178        session_id: SessionId,
179        sender: mpsc::Sender<RealtimeMessage>,
180    ) {
181        let connection = RealtimeSession::new(session_id, sender);
182        let mut connections = self.connections.write().await;
183        connections.insert(session_id, connection);
184    }
185
186    /// Remove a connection.
187    pub async fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
188        let mut connections = self.connections.write().await;
189        if let Some(conn) = connections.remove(&session_id) {
190            let mut sub_sessions = self.subscription_sessions.write().await;
191            for sub_id in &conn.subscriptions {
192                sub_sessions.remove(sub_id);
193            }
194            Some(conn.subscriptions)
195        } else {
196            None
197        }
198    }
199
200    /// Add a subscription to a connection.
201    pub async fn add_subscription(
202        &self,
203        session_id: SessionId,
204        subscription_id: SubscriptionId,
205    ) -> forge_core::Result<()> {
206        let mut connections = self.connections.write().await;
207        let conn = connections
208            .get_mut(&session_id)
209            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
210
211        if conn.subscriptions.len() >= self.config.max_subscriptions_per_session {
212            return Err(forge_core::ForgeError::Validation(format!(
213                "Maximum subscriptions per session ({}) exceeded",
214                self.config.max_subscriptions_per_session
215            )));
216        }
217
218        conn.add_subscription(subscription_id);
219
220        let mut sub_sessions = self.subscription_sessions.write().await;
221        sub_sessions.insert(subscription_id, session_id);
222
223        Ok(())
224    }
225
226    /// Remove a subscription from a connection.
227    pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
228        let session_id = {
229            let mut sub_sessions = self.subscription_sessions.write().await;
230            sub_sessions.remove(&subscription_id)
231        };
232
233        if let Some(session_id) = session_id {
234            let mut connections = self.connections.write().await;
235            if let Some(conn) = connections.get_mut(&session_id) {
236                conn.remove_subscription(subscription_id);
237            }
238        }
239    }
240
241    /// Send a message to a specific session.
242    pub async fn send_to_session(
243        &self,
244        session_id: SessionId,
245        message: RealtimeMessage,
246    ) -> forge_core::Result<()> {
247        let connections = self.connections.read().await;
248        let conn = connections
249            .get(&session_id)
250            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
251
252        conn.send(message)
253            .await
254            .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
255    }
256
257    /// Send a delta to all sessions subscribed to a subscription.
258    pub async fn broadcast_delta(
259        &self,
260        subscription_id: SubscriptionId,
261        delta: Delta<serde_json::Value>,
262    ) -> forge_core::Result<()> {
263        let session_id = {
264            let sub_sessions = self.subscription_sessions.read().await;
265            sub_sessions.get(&subscription_id).copied()
266        };
267
268        if let Some(session_id) = session_id {
269            let message = RealtimeMessage::DeltaUpdate {
270                subscription_id: subscription_id.to_string(),
271                delta,
272            };
273            self.send_to_session(session_id, message).await?;
274        }
275
276        Ok(())
277    }
278
279    /// Get connection count.
280    pub async fn connection_count(&self) -> usize {
281        self.connections.read().await.len()
282    }
283
284    /// Get subscription count.
285    pub async fn subscription_count(&self) -> usize {
286        self.subscription_sessions.read().await.len()
287    }
288
289    /// Get server statistics.
290    pub async fn stats(&self) -> SessionStats {
291        let connections = self.connections.read().await;
292        let total_subscriptions: usize = connections.values().map(|c| c.subscriptions.len()).sum();
293
294        SessionStats {
295            connections: connections.len(),
296            subscriptions: total_subscriptions,
297            node_id: self.node_id,
298        }
299    }
300
301    /// Cleanup stale connections.
302    pub async fn cleanup_stale(&self, max_idle: Duration) {
303        let cutoff = chrono::Utc::now()
304            - chrono::Duration::from_std(max_idle).expect("duration within chrono range");
305        let mut connections = self.connections.write().await;
306        let mut sub_sessions = self.subscription_sessions.write().await;
307
308        connections.retain(|_, conn| {
309            if conn.last_active < cutoff {
310                for sub_id in &conn.subscriptions {
311                    sub_sessions.remove(sub_id);
312                }
313                false
314            } else {
315                true
316            }
317        });
318    }
319}
320
321/// Session server statistics.
322#[derive(Debug, Clone)]
323pub struct SessionStats {
324    /// Number of active connections.
325    pub connections: usize,
326    /// Total subscriptions across all connections.
327    pub subscriptions: usize,
328    /// Node ID.
329    pub node_id: NodeId,
330}
331
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_realtime_config_default() {
        let config = RealtimeConfig::default();
        assert_eq!(config.max_subscriptions_per_session, 50);
    }

    #[tokio::test]
    async fn test_session_server_creation() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count().await, 0);
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_connection() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        assert_eq!(server.connection_count().await, 1);

        let removed = server.remove_connection(session_id).await;
        assert!(removed.is_some());
        assert_eq!(server.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_remove_unknown_connection() {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());
        assert!(server.remove_connection(SessionId::new()).await.is_none());
    }

    #[tokio::test]
    async fn test_remove_connection_clears_subscription_index() {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());
        let session_id = SessionId::new();
        let subscription_id = SubscriptionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();

        let removed = server.remove_connection(session_id).await.unwrap();
        assert_eq!(removed, vec![subscription_id]);
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_subscription() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let subscription_id = SubscriptionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();

        assert_eq!(server.subscription_count().await, 1);

        server.remove_subscription(subscription_id).await;
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_add_subscription_unknown_session() {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());
        let result = server
            .add_subscription(SessionId::new(), SubscriptionId::new())
            .await;
        assert!(result.is_err());
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_subscription_limit() {
        let node_id = NodeId::new();
        let config = RealtimeConfig {
            max_subscriptions_per_session: 2,
        };
        let server = SessionServer::new(node_id, config);
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;

        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        let result = server
            .add_subscription(session_id, SubscriptionId::new())
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_to_session_delivers_message() {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, mut rx) = mpsc::channel(1);

        server.register_connection(session_id, tx).await;
        server
            .send_to_session(session_id, RealtimeMessage::Ping)
            .await
            .unwrap();

        assert!(matches!(rx.recv().await, Some(RealtimeMessage::Ping)));
    }

    #[tokio::test]
    async fn test_send_to_unknown_session_errors() {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());
        let result = server
            .send_to_session(SessionId::new(), RealtimeMessage::Ping)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_after_receiver_dropped_errors() {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(1);

        server.register_connection(session_id, tx).await;
        drop(rx);

        let result = server
            .send_to_session(session_id, RealtimeMessage::Pong)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_cleanup_stale_keeps_fresh_connections() {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        server.cleanup_stale(Duration::from_secs(3600)).await;

        assert_eq!(server.connection_count().await, 1);
        assert_eq!(server.subscription_count().await, 1);
    }

    #[tokio::test]
    async fn test_session_stats() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        let stats = server.stats().await;
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}