// forge_runtime/realtime/message.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde::Serialize;
6use tokio::sync::{RwLock, mpsc};
7
8use forge_core::cluster::NodeId;
9use forge_core::realtime::{Delta, SessionId, SubscriptionId};
10
/// Default cap on concurrent subscriptions a single session may hold.
const DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION: usize = 50;

/// Tunables for the realtime session server.
#[derive(Clone, Debug)]
pub struct RealtimeConfig {
    /// Upper bound on subscriptions one session may register at once.
    pub max_subscriptions_per_session: usize,
}

impl Default for RealtimeConfig {
    /// Returns a config allowing up to 50 subscriptions per session.
    fn default() -> Self {
        RealtimeConfig {
            max_subscriptions_per_session: DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION,
        }
    }
}
23
24/// Job data sent to client (subset of internal JobRecord).
25#[derive(Debug, Clone, Serialize)]
26pub struct JobData {
27    pub job_id: String,
28    pub status: String,
29    pub progress_percent: Option<i32>,
30    pub progress_message: Option<String>,
31    pub output: Option<serde_json::Value>,
32    pub error: Option<String>,
33}
34
35/// Workflow data sent to client.
36#[derive(Debug, Clone, Serialize)]
37pub struct WorkflowData {
38    pub workflow_id: String,
39    pub status: String,
40    pub current_step: Option<String>,
41    pub steps: Vec<WorkflowStepData>,
42    pub output: Option<serde_json::Value>,
43    pub error: Option<String>,
44}
45
46/// Workflow step data sent to client.
47#[derive(Debug, Clone, Serialize)]
48pub struct WorkflowStepData {
49    pub name: String,
50    pub status: String,
51    pub error: Option<String>,
52}
53
54/// Message types for real-time communication.
55#[derive(Debug, Clone)]
56pub enum RealtimeMessage {
57    /// Subscribe to a query.
58    Subscribe {
59        id: String,
60        query: String,
61        args: serde_json::Value,
62    },
63    /// Unsubscribe from a subscription.
64    Unsubscribe { subscription_id: SubscriptionId },
65    /// Ping for keepalive.
66    Ping,
67    /// Pong response.
68    Pong,
69    /// Initial data for subscription.
70    Data {
71        subscription_id: String,
72        data: serde_json::Value,
73    },
74    /// Delta update for subscription.
75    DeltaUpdate {
76        subscription_id: String,
77        delta: Delta<serde_json::Value>,
78    },
79    /// Job progress update.
80    JobUpdate { client_sub_id: String, job: JobData },
81    /// Workflow progress update.
82    WorkflowUpdate {
83        client_sub_id: String,
84        workflow: WorkflowData,
85    },
86    /// Error message.
87    Error { code: String, message: String },
88    /// Error message with subscription ID.
89    ErrorWithId {
90        id: String,
91        code: String,
92        message: String,
93    },
94    /// Authentication successful.
95    AuthSuccess,
96    /// Authentication failed.
97    AuthFailed { reason: String },
98}
99
100#[derive(Debug)]
101pub struct RealtimeSession {
102    #[allow(dead_code)]
103    pub session_id: SessionId,
104    pub subscriptions: Vec<SubscriptionId>,
105    pub sender: mpsc::Sender<RealtimeMessage>,
106    #[allow(dead_code)]
107    pub connected_at: chrono::DateTime<chrono::Utc>,
108    pub last_active: chrono::DateTime<chrono::Utc>,
109}
110
111impl RealtimeSession {
112    /// Create a new session.
113    pub fn new(session_id: SessionId, sender: mpsc::Sender<RealtimeMessage>) -> Self {
114        let now = chrono::Utc::now();
115        Self {
116            session_id,
117            subscriptions: Vec::new(),
118            sender,
119            connected_at: now,
120            last_active: now,
121        }
122    }
123
124    /// Add a subscription.
125    pub fn add_subscription(&mut self, subscription_id: SubscriptionId) {
126        self.subscriptions.push(subscription_id);
127        self.last_active = chrono::Utc::now();
128    }
129
130    /// Remove a subscription.
131    pub fn remove_subscription(&mut self, subscription_id: SubscriptionId) {
132        self.subscriptions.retain(|id| *id != subscription_id);
133        self.last_active = chrono::Utc::now();
134    }
135
136    /// Send a message to the client.
137    pub async fn send(
138        &self,
139        message: RealtimeMessage,
140    ) -> Result<(), mpsc::error::SendError<RealtimeMessage>> {
141        self.sender.send(message).await
142    }
143}
144
145pub struct SessionServer {
146    config: RealtimeConfig,
147    node_id: NodeId,
148    /// Active connections by session ID.
149    connections: Arc<RwLock<HashMap<SessionId, RealtimeSession>>>,
150    /// Subscription to session mapping for fast lookup.
151    subscription_sessions: Arc<RwLock<HashMap<SubscriptionId, SessionId>>>,
152}
153
154impl SessionServer {
155    /// Create a new session server.
156    pub fn new(node_id: NodeId, config: RealtimeConfig) -> Self {
157        Self {
158            config,
159            node_id,
160            connections: Arc::new(RwLock::new(HashMap::new())),
161            subscription_sessions: Arc::new(RwLock::new(HashMap::new())),
162        }
163    }
164
165    /// Get the node ID.
166    pub fn node_id(&self) -> NodeId {
167        self.node_id
168    }
169
170    /// Get the configuration.
171    pub fn config(&self) -> &RealtimeConfig {
172        &self.config
173    }
174
175    /// Register a new connection.
176    pub async fn register_connection(
177        &self,
178        session_id: SessionId,
179        sender: mpsc::Sender<RealtimeMessage>,
180    ) {
181        let connection = RealtimeSession::new(session_id, sender);
182        let mut connections = self.connections.write().await;
183        connections.insert(session_id, connection);
184    }
185
186    /// Remove a connection.
187    pub async fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
188        let mut connections = self.connections.write().await;
189        if let Some(conn) = connections.remove(&session_id) {
190            let mut sub_sessions = self.subscription_sessions.write().await;
191            for sub_id in &conn.subscriptions {
192                sub_sessions.remove(sub_id);
193            }
194            Some(conn.subscriptions)
195        } else {
196            None
197        }
198    }
199
200    /// Add a subscription to a connection.
201    pub async fn add_subscription(
202        &self,
203        session_id: SessionId,
204        subscription_id: SubscriptionId,
205    ) -> forge_core::Result<()> {
206        let mut connections = self.connections.write().await;
207        let conn = connections
208            .get_mut(&session_id)
209            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
210
211        if conn.subscriptions.len() >= self.config.max_subscriptions_per_session {
212            return Err(forge_core::ForgeError::Validation(format!(
213                "Maximum subscriptions per session ({}) exceeded",
214                self.config.max_subscriptions_per_session
215            )));
216        }
217
218        conn.add_subscription(subscription_id);
219
220        let mut sub_sessions = self.subscription_sessions.write().await;
221        sub_sessions.insert(subscription_id, session_id);
222
223        Ok(())
224    }
225
226    /// Remove a subscription from a connection.
227    pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
228        let session_id = {
229            let mut sub_sessions = self.subscription_sessions.write().await;
230            sub_sessions.remove(&subscription_id)
231        };
232
233        if let Some(session_id) = session_id {
234            let mut connections = self.connections.write().await;
235            if let Some(conn) = connections.get_mut(&session_id) {
236                conn.remove_subscription(subscription_id);
237            }
238        }
239    }
240
241    /// Send a message to a specific session.
242    pub async fn send_to_session(
243        &self,
244        session_id: SessionId,
245        message: RealtimeMessage,
246    ) -> forge_core::Result<()> {
247        let connections = self.connections.read().await;
248        let conn = connections
249            .get(&session_id)
250            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
251
252        conn.send(message)
253            .await
254            .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
255    }
256
257    /// Send a delta to all sessions subscribed to a subscription.
258    pub async fn broadcast_delta(
259        &self,
260        subscription_id: SubscriptionId,
261        delta: Delta<serde_json::Value>,
262    ) -> forge_core::Result<()> {
263        let session_id = {
264            let sub_sessions = self.subscription_sessions.read().await;
265            sub_sessions.get(&subscription_id).copied()
266        };
267
268        if let Some(session_id) = session_id {
269            let message = RealtimeMessage::DeltaUpdate {
270                subscription_id: subscription_id.to_string(),
271                delta,
272            };
273            self.send_to_session(session_id, message).await?;
274        }
275
276        Ok(())
277    }
278
279    /// Get connection count.
280    pub async fn connection_count(&self) -> usize {
281        self.connections.read().await.len()
282    }
283
284    /// Get subscription count.
285    pub async fn subscription_count(&self) -> usize {
286        self.subscription_sessions.read().await.len()
287    }
288
289    /// Get server statistics.
290    pub async fn stats(&self) -> SessionStats {
291        let connections = self.connections.read().await;
292        let total_subscriptions: usize = connections.values().map(|c| c.subscriptions.len()).sum();
293
294        SessionStats {
295            connections: connections.len(),
296            subscriptions: total_subscriptions,
297            node_id: self.node_id,
298        }
299    }
300
301    /// Cleanup stale connections.
302    pub async fn cleanup_stale(&self, max_idle: Duration) {
303        let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_idle).unwrap();
304        let mut connections = self.connections.write().await;
305        let mut sub_sessions = self.subscription_sessions.write().await;
306
307        connections.retain(|_, conn| {
308            if conn.last_active < cutoff {
309                for sub_id in &conn.subscriptions {
310                    sub_sessions.remove(sub_id);
311                }
312                false
313            } else {
314                true
315            }
316        });
317    }
318}
319
320/// Session server statistics.
321#[derive(Debug, Clone)]
322pub struct SessionStats {
323    /// Number of active connections.
324    pub connections: usize,
325    /// Total subscriptions across all connections.
326    pub subscriptions: usize,
327    /// Node ID.
328    pub node_id: NodeId,
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server with default limits on a fresh node.
    fn default_server() -> SessionServer {
        SessionServer::new(NodeId::new(), RealtimeConfig::default())
    }

    #[test]
    fn test_realtime_config_default() {
        let config = RealtimeConfig::default();
        assert_eq!(config.max_subscriptions_per_session, 50);
    }

    #[tokio::test]
    async fn test_session_server_creation() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count().await, 0);
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_connection() {
        let server = default_server();
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        assert_eq!(server.connection_count().await, 1);

        let removed = server.remove_connection(session_id).await;
        assert!(removed.is_some());
        assert_eq!(server.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_remove_unknown_connection_returns_none() {
        let server = default_server();
        assert!(server.remove_connection(SessionId::new()).await.is_none());
    }

    #[tokio::test]
    async fn test_remove_connection_unmaps_subscriptions() {
        let server = default_server();
        let session_id = SessionId::new();
        let subscription_id = SubscriptionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();

        let removed = server.remove_connection(session_id).await;
        assert_eq!(removed, Some(vec![subscription_id]));
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_subscription() {
        let server = default_server();
        let session_id = SessionId::new();
        let subscription_id = SubscriptionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();

        assert_eq!(server.subscription_count().await, 1);

        server.remove_subscription(subscription_id).await;
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_add_subscription_unknown_session_fails() {
        let server = default_server();
        let result = server
            .add_subscription(SessionId::new(), SubscriptionId::new())
            .await;
        assert!(result.is_err());
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_subscription_limit() {
        let config = RealtimeConfig {
            max_subscriptions_per_session: 2,
        };
        let server = SessionServer::new(NodeId::new(), config);
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;

        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        let result = server
            .add_subscription(session_id, SubscriptionId::new())
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_session_stats() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        let stats = server.stats().await;
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }

    #[tokio::test]
    async fn test_send_to_session_delivers_message() {
        let server = default_server();
        let session_id = SessionId::new();
        let (tx, mut rx) = mpsc::channel(8);

        server.register_connection(session_id, tx).await;
        server
            .send_to_session(session_id, RealtimeMessage::Ping)
            .await
            .unwrap();

        assert!(matches!(rx.recv().await, Some(RealtimeMessage::Ping)));
    }

    #[tokio::test]
    async fn test_send_to_unknown_session_fails() {
        let server = default_server();
        let result = server
            .send_to_session(SessionId::new(), RealtimeMessage::Pong)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_to_closed_session_fails() {
        let server = default_server();
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(1);
        drop(rx);

        server.register_connection(session_id, tx).await;
        let result = server
            .send_to_session(session_id, RealtimeMessage::Ping)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_cleanup_stale_keeps_recent_connections() {
        let server = default_server();
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;
        server
            .add_subscription(session_id, SubscriptionId::new())
            .await
            .unwrap();

        server.cleanup_stale(Duration::from_secs(3600)).await;
        assert_eq!(server.connection_count().await, 1);
        assert_eq!(server.subscription_count().await, 1);
    }
}