1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde::Serialize;
6use tokio::sync::{RwLock, mpsc};
7
8use forge_core::cluster::NodeId;
9use forge_core::realtime::{Delta, SessionId, SubscriptionId};
10
/// Tunable limits for the realtime session layer.
#[derive(Debug, Clone)]
pub struct RealtimeConfig {
    /// Upper bound on the number of subscriptions a single session may hold.
    pub max_subscriptions_per_session: usize,
}

impl Default for RealtimeConfig {
    /// Returns the configuration with a limit of 50 subscriptions per session.
    fn default() -> Self {
        const DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION: usize = 50;

        RealtimeConfig {
            max_subscriptions_per_session: DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION,
        }
    }
}
23
24#[derive(Debug, Clone, Serialize)]
26pub struct JobData {
27 pub job_id: String,
28 pub status: String,
29 pub progress_percent: Option<i32>,
30 pub progress_message: Option<String>,
31 pub output: Option<serde_json::Value>,
32 pub error: Option<String>,
33}
34
35#[derive(Debug, Clone, Serialize)]
37pub struct WorkflowData {
38 pub workflow_id: String,
39 pub status: String,
40 pub current_step: Option<String>,
41 pub steps: Vec<WorkflowStepData>,
42 pub output: Option<serde_json::Value>,
43 pub error: Option<String>,
44}
45
46#[derive(Debug, Clone, Serialize)]
48pub struct WorkflowStepData {
49 pub name: String,
50 pub status: String,
51 pub error: Option<String>,
52}
53
54#[derive(Debug, Clone)]
56pub enum RealtimeMessage {
57 Subscribe {
59 id: String,
60 query: String,
61 args: serde_json::Value,
62 },
63 Unsubscribe { subscription_id: SubscriptionId },
65 Ping,
67 Pong,
69 Data {
71 subscription_id: String,
72 data: serde_json::Value,
73 },
74 DeltaUpdate {
76 subscription_id: String,
77 delta: Delta<serde_json::Value>,
78 },
79 JobUpdate { client_sub_id: String, job: JobData },
81 WorkflowUpdate {
83 client_sub_id: String,
84 workflow: WorkflowData,
85 },
86 Error { code: String, message: String },
88 ErrorWithId {
90 id: String,
91 code: String,
92 message: String,
93 },
94 AuthSuccess,
96 AuthFailed { reason: String },
98}
99
100#[derive(Debug)]
101pub struct RealtimeSession {
102 #[allow(dead_code)]
103 pub session_id: SessionId,
104 pub subscriptions: Vec<SubscriptionId>,
105 pub sender: mpsc::Sender<RealtimeMessage>,
106 #[allow(dead_code)]
107 pub connected_at: chrono::DateTime<chrono::Utc>,
108 pub last_active: chrono::DateTime<chrono::Utc>,
109}
110
111impl RealtimeSession {
112 pub fn new(session_id: SessionId, sender: mpsc::Sender<RealtimeMessage>) -> Self {
114 let now = chrono::Utc::now();
115 Self {
116 session_id,
117 subscriptions: Vec::new(),
118 sender,
119 connected_at: now,
120 last_active: now,
121 }
122 }
123
124 pub fn add_subscription(&mut self, subscription_id: SubscriptionId) {
126 self.subscriptions.push(subscription_id);
127 self.last_active = chrono::Utc::now();
128 }
129
130 pub fn remove_subscription(&mut self, subscription_id: SubscriptionId) {
132 self.subscriptions.retain(|id| *id != subscription_id);
133 self.last_active = chrono::Utc::now();
134 }
135
136 pub async fn send(
138 &self,
139 message: RealtimeMessage,
140 ) -> Result<(), mpsc::error::SendError<RealtimeMessage>> {
141 self.sender.send(message).await
142 }
143}
144
145pub struct SessionServer {
146 config: RealtimeConfig,
147 node_id: NodeId,
148 connections: Arc<RwLock<HashMap<SessionId, RealtimeSession>>>,
150 subscription_sessions: Arc<RwLock<HashMap<SubscriptionId, SessionId>>>,
152}
153
154impl SessionServer {
155 pub fn new(node_id: NodeId, config: RealtimeConfig) -> Self {
157 Self {
158 config,
159 node_id,
160 connections: Arc::new(RwLock::new(HashMap::new())),
161 subscription_sessions: Arc::new(RwLock::new(HashMap::new())),
162 }
163 }
164
165 pub fn node_id(&self) -> NodeId {
167 self.node_id
168 }
169
170 pub fn config(&self) -> &RealtimeConfig {
172 &self.config
173 }
174
175 pub async fn register_connection(
177 &self,
178 session_id: SessionId,
179 sender: mpsc::Sender<RealtimeMessage>,
180 ) {
181 let connection = RealtimeSession::new(session_id, sender);
182 let mut connections = self.connections.write().await;
183 connections.insert(session_id, connection);
184 }
185
186 pub async fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
188 let mut connections = self.connections.write().await;
189 if let Some(conn) = connections.remove(&session_id) {
190 let mut sub_sessions = self.subscription_sessions.write().await;
191 for sub_id in &conn.subscriptions {
192 sub_sessions.remove(sub_id);
193 }
194 Some(conn.subscriptions)
195 } else {
196 None
197 }
198 }
199
200 pub async fn add_subscription(
202 &self,
203 session_id: SessionId,
204 subscription_id: SubscriptionId,
205 ) -> forge_core::Result<()> {
206 let mut connections = self.connections.write().await;
207 let conn = connections
208 .get_mut(&session_id)
209 .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
210
211 if conn.subscriptions.len() >= self.config.max_subscriptions_per_session {
212 return Err(forge_core::ForgeError::Validation(format!(
213 "Maximum subscriptions per session ({}) exceeded",
214 self.config.max_subscriptions_per_session
215 )));
216 }
217
218 conn.add_subscription(subscription_id);
219
220 let mut sub_sessions = self.subscription_sessions.write().await;
221 sub_sessions.insert(subscription_id, session_id);
222
223 Ok(())
224 }
225
226 pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
228 let session_id = {
229 let mut sub_sessions = self.subscription_sessions.write().await;
230 sub_sessions.remove(&subscription_id)
231 };
232
233 if let Some(session_id) = session_id {
234 let mut connections = self.connections.write().await;
235 if let Some(conn) = connections.get_mut(&session_id) {
236 conn.remove_subscription(subscription_id);
237 }
238 }
239 }
240
241 pub async fn send_to_session(
243 &self,
244 session_id: SessionId,
245 message: RealtimeMessage,
246 ) -> forge_core::Result<()> {
247 let connections = self.connections.read().await;
248 let conn = connections
249 .get(&session_id)
250 .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
251
252 conn.send(message)
253 .await
254 .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
255 }
256
257 pub async fn broadcast_delta(
259 &self,
260 subscription_id: SubscriptionId,
261 delta: Delta<serde_json::Value>,
262 ) -> forge_core::Result<()> {
263 let session_id = {
264 let sub_sessions = self.subscription_sessions.read().await;
265 sub_sessions.get(&subscription_id).copied()
266 };
267
268 if let Some(session_id) = session_id {
269 let message = RealtimeMessage::DeltaUpdate {
270 subscription_id: subscription_id.to_string(),
271 delta,
272 };
273 self.send_to_session(session_id, message).await?;
274 }
275
276 Ok(())
277 }
278
279 pub async fn connection_count(&self) -> usize {
281 self.connections.read().await.len()
282 }
283
284 pub async fn subscription_count(&self) -> usize {
286 self.subscription_sessions.read().await.len()
287 }
288
289 pub async fn stats(&self) -> SessionStats {
291 let connections = self.connections.read().await;
292 let total_subscriptions: usize = connections.values().map(|c| c.subscriptions.len()).sum();
293
294 SessionStats {
295 connections: connections.len(),
296 subscriptions: total_subscriptions,
297 node_id: self.node_id,
298 }
299 }
300
301 pub async fn cleanup_stale(&self, max_idle: Duration) {
303 let cutoff = chrono::Utc::now()
304 - chrono::Duration::from_std(max_idle).expect("duration within chrono range");
305 let mut connections = self.connections.write().await;
306 let mut sub_sessions = self.subscription_sessions.write().await;
307
308 connections.retain(|_, conn| {
309 if conn.last_active < cutoff {
310 for sub_id in &conn.subscriptions {
311 sub_sessions.remove(sub_id);
312 }
313 false
314 } else {
315 true
316 }
317 });
318 }
319}
320
321#[derive(Debug, Clone)]
323pub struct SessionStats {
324 pub connections: usize,
326 pub subscriptions: usize,
328 pub node_id: NodeId,
330}
331
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a server with `config` and registers one session on it.
    ///
    /// The receiver is returned so the channel stays open for the duration of
    /// the calling test.
    async fn server_with_session(
        config: RealtimeConfig,
    ) -> (
        SessionServer,
        NodeId,
        SessionId,
        mpsc::Receiver<RealtimeMessage>,
    ) {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, config);
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(100);

        server.register_connection(session_id, tx).await;

        (server, node_id, session_id, rx)
    }

    #[test]
    fn test_realtime_config_default() {
        let RealtimeConfig {
            max_subscriptions_per_session,
        } = RealtimeConfig::default();
        assert_eq!(max_subscriptions_per_session, 50);
    }

    #[tokio::test]
    async fn test_session_server_creation() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count().await, 0);
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_connection() {
        let (server, _node_id, session_id, _rx) =
            server_with_session(RealtimeConfig::default()).await;
        assert_eq!(server.connection_count().await, 1);

        let removed = server.remove_connection(session_id).await;
        assert!(removed.is_some());
        assert_eq!(server.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_subscription() {
        let (server, _node_id, session_id, _rx) =
            server_with_session(RealtimeConfig::default()).await;
        let subscription_id = SubscriptionId::new();

        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();
        assert_eq!(server.subscription_count().await, 1);

        server.remove_subscription(subscription_id).await;
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_subscription_limit() {
        let limited = RealtimeConfig {
            max_subscriptions_per_session: 2,
        };
        let (server, _node_id, session_id, _rx) = server_with_session(limited).await;

        // Fill the session up to its limit of two subscriptions.
        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .await
                .unwrap();
        }

        // The third subscription must be rejected.
        let overflow = server
            .add_subscription(session_id, SubscriptionId::new())
            .await;
        assert!(overflow.is_err());
    }

    #[tokio::test]
    async fn test_session_stats() {
        let (server, node_id, session_id, _rx) =
            server_with_session(RealtimeConfig::default()).await;

        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .await
                .unwrap();
        }

        let stats = server.stats().await;
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}