1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde::Serialize;
6use tokio::sync::{RwLock, mpsc};
7
8use forge_core::cluster::NodeId;
9use forge_core::realtime::{Delta, SessionId, SubscriptionId};
10
/// Tunable limits for the realtime session server.
#[derive(Debug, Clone)]
pub struct RealtimeConfig {
    /// Upper bound on the number of subscriptions a single session may hold;
    /// checked by `SessionServer::add_subscription` before a new one is added.
    pub max_subscriptions_per_session: usize,
}
15
16impl Default for RealtimeConfig {
17 fn default() -> Self {
18 Self {
19 max_subscriptions_per_session: 50,
20 }
21 }
22}
23
/// Snapshot of a background job's state, carried by
/// [`RealtimeMessage::JobUpdate`].
///
/// Serialized with serde's default field naming; the declaration order below
/// is the serialized key order.
#[derive(Debug, Clone, Serialize)]
pub struct JobData {
    /// Identifier of the job this snapshot describes.
    pub job_id: String,
    /// Job status as free-form text (the set of valid values is defined
    /// elsewhere — TODO confirm).
    pub status: String,
    /// Reported progress, if any; presumably a 0–100 percentage given the
    /// name — confirm with the producer.
    pub progress_percent: Option<i32>,
    /// Optional human-readable progress note.
    pub progress_message: Option<String>,
    /// Job output as arbitrary JSON, if any.
    pub output: Option<serde_json::Value>,
    /// Error text, if any (presumably only set when the job failed — confirm).
    pub error: Option<String>,
}
34
/// Snapshot of a workflow's state, carried by
/// [`RealtimeMessage::WorkflowUpdate`].
///
/// Serialized with serde's default field naming; the declaration order below
/// is the serialized key order.
#[derive(Debug, Clone, Serialize)]
pub struct WorkflowData {
    /// Identifier of the workflow this snapshot describes.
    pub workflow_id: String,
    /// Workflow status as free-form text (valid values defined elsewhere —
    /// TODO confirm).
    pub status: String,
    /// Name of the step currently executing, if any.
    pub current_step: Option<String>,
    /// Per-step state (presumably in execution order — confirm).
    pub steps: Vec<WorkflowStepData>,
    /// Workflow output as arbitrary JSON, if any.
    pub output: Option<serde_json::Value>,
    /// Error text, if any.
    pub error: Option<String>,
}
45
/// State of one step inside a [`WorkflowData`] snapshot.
#[derive(Debug, Clone, Serialize)]
pub struct WorkflowStepData {
    /// Step name (presumably matches `WorkflowData::current_step` values —
    /// confirm).
    pub name: String,
    /// Step status as free-form text.
    pub status: String,
    /// Error text for this step, if any.
    pub error: Option<String>,
}
53
/// Messages passed through a realtime session's channel.
///
/// NOTE(review): this enum does not derive `Serialize`, so any wire encoding
/// is presumably done elsewhere — confirm. The direction of each variant
/// (client → server or server → client) is inferred from names below and
/// should be confirmed against the transport layer.
#[derive(Debug, Clone)]
pub enum RealtimeMessage {
    /// Subscribe to `query` with JSON `args`. `id` is presumably a
    /// client-chosen correlation id that replies echo back — confirm.
    Subscribe {
        id: String,
        query: String,
        args: serde_json::Value,
    },
    /// Drop the subscription `subscription_id`.
    Unsubscribe { subscription_id: SubscriptionId },
    /// Presumably a keep-alive probe — confirm.
    Ping,
    /// Presumably the reply to [`RealtimeMessage::Ping`] — confirm.
    Pong,
    /// Full JSON payload for the subscription whose id is given as a string.
    Data {
        subscription_id: String,
        data: serde_json::Value,
    },
    /// Incremental change for a subscription. `SessionServer::broadcast_delta`
    /// fills `subscription_id` with the id's `to_string()` form.
    DeltaUpdate {
        subscription_id: String,
        delta: Delta<serde_json::Value>,
    },
    /// Job snapshot for the client-side subscription `client_sub_id`.
    JobUpdate { client_sub_id: String, job: JobData },
    /// Workflow snapshot for the client-side subscription `client_sub_id`.
    WorkflowUpdate {
        client_sub_id: String,
        workflow: WorkflowData,
    },
    /// Error without an associated request id.
    Error { code: String, message: String },
    /// Error associated with request `id` (presumably a `Subscribe` id —
    /// confirm).
    ErrorWithId {
        id: String,
        code: String,
        message: String,
    },
    /// Authentication succeeded.
    AuthSuccess,
    /// Authentication failed, with a human-readable `reason`.
    AuthFailed { reason: String },
}
99
/// Per-connection state stored by [`SessionServer`].
#[derive(Debug)]
pub struct RealtimeSession {
    /// Id of this session (the server also uses it as the map key).
    // Not read anywhere in this module, hence the lint allowance; kept for
    // `Debug` output. NOTE(review): confirm whether other modules read it.
    #[allow(dead_code)]
    pub session_id: SessionId,
    /// Subscriptions owned by this session; its length is compared against
    /// `RealtimeConfig::max_subscriptions_per_session`.
    pub subscriptions: Vec<SubscriptionId>,
    /// Outbound channel for messages destined for this session's client.
    pub sender: mpsc::Sender<RealtimeMessage>,
    /// Time the session was created.
    // Not read anywhere in this module, hence the lint allowance.
    #[allow(dead_code)]
    pub connected_at: chrono::DateTime<chrono::Utc>,
    /// Set at creation and refreshed whenever a subscription is added or
    /// removed; `SessionServer::cleanup_stale` compares it against its cutoff.
    pub last_active: chrono::DateTime<chrono::Utc>,
}
110
111impl RealtimeSession {
112 pub fn new(session_id: SessionId, sender: mpsc::Sender<RealtimeMessage>) -> Self {
114 let now = chrono::Utc::now();
115 Self {
116 session_id,
117 subscriptions: Vec::new(),
118 sender,
119 connected_at: now,
120 last_active: now,
121 }
122 }
123
124 pub fn add_subscription(&mut self, subscription_id: SubscriptionId) {
126 self.subscriptions.push(subscription_id);
127 self.last_active = chrono::Utc::now();
128 }
129
130 pub fn remove_subscription(&mut self, subscription_id: SubscriptionId) {
132 self.subscriptions.retain(|id| *id != subscription_id);
133 self.last_active = chrono::Utc::now();
134 }
135
136 pub async fn send(
138 &self,
139 message: RealtimeMessage,
140 ) -> Result<(), mpsc::error::SendError<RealtimeMessage>> {
141 self.sender.send(message).await
142 }
143}
144
/// Tracks the realtime connections on one node and routes subscription
/// updates to the session that owns each subscription.
///
/// Methods that lock both maps take `connections` before
/// `subscription_sessions`; keep that order to avoid lock-order deadlocks.
pub struct SessionServer {
    /// Limits enforced by this server.
    config: RealtimeConfig,
    /// Node this server runs on; reported in `SessionStats`.
    node_id: NodeId,
    /// Registered sessions keyed by session id.
    connections: Arc<RwLock<HashMap<SessionId, RealtimeSession>>>,
    /// Reverse index from a subscription to its owning session, intended to
    /// mirror every session's `subscriptions` list.
    subscription_sessions: Arc<RwLock<HashMap<SubscriptionId, SessionId>>>,
}
153
154impl SessionServer {
155 pub fn new(node_id: NodeId, config: RealtimeConfig) -> Self {
157 Self {
158 config,
159 node_id,
160 connections: Arc::new(RwLock::new(HashMap::new())),
161 subscription_sessions: Arc::new(RwLock::new(HashMap::new())),
162 }
163 }
164
165 pub fn node_id(&self) -> NodeId {
167 self.node_id
168 }
169
170 pub fn config(&self) -> &RealtimeConfig {
172 &self.config
173 }
174
175 pub async fn register_connection(
177 &self,
178 session_id: SessionId,
179 sender: mpsc::Sender<RealtimeMessage>,
180 ) {
181 let connection = RealtimeSession::new(session_id, sender);
182 let mut connections = self.connections.write().await;
183 connections.insert(session_id, connection);
184 }
185
186 pub async fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
188 let mut connections = self.connections.write().await;
189 if let Some(conn) = connections.remove(&session_id) {
190 let mut sub_sessions = self.subscription_sessions.write().await;
191 for sub_id in &conn.subscriptions {
192 sub_sessions.remove(sub_id);
193 }
194 Some(conn.subscriptions)
195 } else {
196 None
197 }
198 }
199
200 pub async fn add_subscription(
202 &self,
203 session_id: SessionId,
204 subscription_id: SubscriptionId,
205 ) -> forge_core::Result<()> {
206 let mut connections = self.connections.write().await;
207 let conn = connections
208 .get_mut(&session_id)
209 .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
210
211 if conn.subscriptions.len() >= self.config.max_subscriptions_per_session {
212 return Err(forge_core::ForgeError::Validation(format!(
213 "Maximum subscriptions per session ({}) exceeded",
214 self.config.max_subscriptions_per_session
215 )));
216 }
217
218 conn.add_subscription(subscription_id);
219
220 let mut sub_sessions = self.subscription_sessions.write().await;
221 sub_sessions.insert(subscription_id, session_id);
222
223 Ok(())
224 }
225
226 pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
228 let session_id = {
229 let mut sub_sessions = self.subscription_sessions.write().await;
230 sub_sessions.remove(&subscription_id)
231 };
232
233 if let Some(session_id) = session_id {
234 let mut connections = self.connections.write().await;
235 if let Some(conn) = connections.get_mut(&session_id) {
236 conn.remove_subscription(subscription_id);
237 }
238 }
239 }
240
241 pub async fn send_to_session(
243 &self,
244 session_id: SessionId,
245 message: RealtimeMessage,
246 ) -> forge_core::Result<()> {
247 let connections = self.connections.read().await;
248 let conn = connections
249 .get(&session_id)
250 .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
251
252 conn.send(message)
253 .await
254 .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
255 }
256
257 pub async fn broadcast_delta(
259 &self,
260 subscription_id: SubscriptionId,
261 delta: Delta<serde_json::Value>,
262 ) -> forge_core::Result<()> {
263 let session_id = {
264 let sub_sessions = self.subscription_sessions.read().await;
265 sub_sessions.get(&subscription_id).copied()
266 };
267
268 if let Some(session_id) = session_id {
269 let message = RealtimeMessage::DeltaUpdate {
270 subscription_id: subscription_id.to_string(),
271 delta,
272 };
273 self.send_to_session(session_id, message).await?;
274 }
275
276 Ok(())
277 }
278
279 pub async fn connection_count(&self) -> usize {
281 self.connections.read().await.len()
282 }
283
284 pub async fn subscription_count(&self) -> usize {
286 self.subscription_sessions.read().await.len()
287 }
288
289 pub async fn stats(&self) -> SessionStats {
291 let connections = self.connections.read().await;
292 let total_subscriptions: usize = connections.values().map(|c| c.subscriptions.len()).sum();
293
294 SessionStats {
295 connections: connections.len(),
296 subscriptions: total_subscriptions,
297 node_id: self.node_id,
298 }
299 }
300
301 pub async fn cleanup_stale(&self, max_idle: Duration) {
303 let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_idle).unwrap();
304 let mut connections = self.connections.write().await;
305 let mut sub_sessions = self.subscription_sessions.write().await;
306
307 connections.retain(|_, conn| {
308 if conn.last_active < cutoff {
309 for sub_id in &conn.subscriptions {
310 sub_sessions.remove(sub_id);
311 }
312 false
313 } else {
314 true
315 }
316 });
317 }
318}
319
/// Point-in-time counters produced by `SessionServer::stats`.
#[derive(Debug, Clone)]
pub struct SessionStats {
    /// Number of registered connections.
    pub connections: usize,
    /// Sum of the lengths of every session's subscription list.
    pub subscriptions: usize,
    /// Node the counters were taken on.
    pub node_id: NodeId,
}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server on a fresh node id and returns both.
    fn server_with(config: RealtimeConfig) -> (NodeId, SessionServer) {
        let node_id = NodeId::new();
        (node_id, SessionServer::new(node_id, config))
    }

    /// Builds a server with one registered session. The receiver is returned
    /// so the caller can keep the channel open for the whole test.
    async fn server_with_session(
        config: RealtimeConfig,
    ) -> (
        NodeId,
        SessionServer,
        SessionId,
        mpsc::Receiver<RealtimeMessage>,
    ) {
        let (node_id, server) = server_with(config);
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(100);
        server.register_connection(session_id, tx).await;
        (node_id, server, session_id, rx)
    }

    #[test]
    fn test_realtime_config_default() {
        assert_eq!(RealtimeConfig::default().max_subscriptions_per_session, 50);
    }

    #[tokio::test]
    async fn test_session_server_creation() {
        let (node_id, server) = server_with(RealtimeConfig::default());

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count().await, 0);
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_connection() {
        let (_, server, session_id, _rx) = server_with_session(RealtimeConfig::default()).await;
        assert_eq!(server.connection_count().await, 1);

        assert!(server.remove_connection(session_id).await.is_some());
        assert_eq!(server.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_subscription() {
        let (_, server, session_id, _rx) = server_with_session(RealtimeConfig::default()).await;
        let subscription_id = SubscriptionId::new();

        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();
        assert_eq!(server.subscription_count().await, 1);

        server.remove_subscription(subscription_id).await;
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_subscription_limit() {
        let limit = 2;
        let (_, server, session_id, _rx) = server_with_session(RealtimeConfig {
            max_subscriptions_per_session: limit,
        })
        .await;

        for _ in 0..limit {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .await
                .unwrap();
        }

        let overflow = server
            .add_subscription(session_id, SubscriptionId::new())
            .await;
        assert!(overflow.is_err());
    }

    #[tokio::test]
    async fn test_session_stats() {
        let (node_id, server, session_id, _rx) =
            server_with_session(RealtimeConfig::default()).await;

        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .await
                .unwrap();
        }

        let stats = server.stats().await;
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}