1use std::sync::atomic::{AtomicU32, Ordering};
2use std::time::Duration;
3
4use dashmap::DashMap;
5use serde::Serialize;
6use tokio::sync::mpsc;
7
8use forge_core::cluster::NodeId;
9use forge_core::realtime::{Delta, SessionId, SubscriptionId};
10
/// Tunables for the realtime session server.
#[derive(Clone, Debug)]
pub struct RealtimeConfig {
    /// Maximum number of subscriptions a single session may hold at once;
    /// `SessionServer::add_subscription` rejects new subscriptions beyond it.
    pub max_subscriptions_per_session: usize,
}

impl Default for RealtimeConfig {
    /// Returns a configuration allowing 50 subscriptions per session.
    fn default() -> Self {
        const DEFAULT_MAX_SUBSCRIPTIONS: usize = 50;
        RealtimeConfig {
            max_subscriptions_per_session: DEFAULT_MAX_SUBSCRIPTIONS,
        }
    }
}
23
24#[derive(Debug, Clone, Serialize)]
26pub struct JobData {
27 pub job_id: String,
28 pub status: String,
29 pub progress_percent: Option<i32>,
30 pub progress_message: Option<String>,
31 pub output: Option<serde_json::Value>,
32 pub error: Option<String>,
33}
34
35#[derive(Debug, Clone, Serialize)]
37pub struct WorkflowData {
38 pub workflow_id: String,
39 pub status: String,
40 pub current_step: Option<String>,
41 pub steps: Vec<WorkflowStepData>,
42 pub output: Option<serde_json::Value>,
43 pub error: Option<String>,
44}
45
46#[derive(Debug, Clone, Serialize)]
48pub struct WorkflowStepData {
49 pub name: String,
50 pub status: String,
51 pub error: Option<String>,
52}
53
54#[derive(Debug, Clone)]
56pub enum RealtimeMessage {
57 Subscribe {
58 id: String,
59 query: String,
60 args: serde_json::Value,
61 },
62 Unsubscribe {
63 subscription_id: SubscriptionId,
64 },
65 Ping,
66 Pong,
67 Data {
68 subscription_id: String,
69 data: serde_json::Value,
70 },
71 DeltaUpdate {
72 subscription_id: String,
73 delta: Delta<serde_json::Value>,
74 },
75 JobUpdate {
76 client_sub_id: String,
77 job: JobData,
78 },
79 WorkflowUpdate {
80 client_sub_id: String,
81 workflow: WorkflowData,
82 },
83 Error {
84 code: String,
85 message: String,
86 },
87 ErrorWithId {
88 id: String,
89 code: String,
90 message: String,
91 },
92 AuthSuccess,
93 AuthFailed {
94 reason: String,
95 },
96 Lagging,
98}
99
100struct SessionEntry {
102 sender: mpsc::Sender<RealtimeMessage>,
103 subscriptions: Vec<SubscriptionId>,
104 #[allow(dead_code)]
105 connected_at: chrono::DateTime<chrono::Utc>,
106 last_active: chrono::DateTime<chrono::Utc>,
107 consecutive_drops: AtomicU32,
109}
110
/// Consecutive full-channel drops tolerated before a session is evicted.
///
/// `try_send_to_session` evicts when the channel is full and this many drops have
/// already been counted since the last successful send. Eviction therefore happens
/// on the `MAX_CONSECUTIVE_DROPS + 1`-th consecutive failure.
const MAX_CONSECUTIVE_DROPS: u32 = 10;
113
114pub struct SessionServer {
115 config: RealtimeConfig,
116 node_id: NodeId,
117 connections: DashMap<SessionId, SessionEntry>,
119 subscription_sessions: DashMap<SubscriptionId, SessionId>,
121}
122
123impl SessionServer {
124 pub fn new(node_id: NodeId, config: RealtimeConfig) -> Self {
126 Self {
127 config,
128 node_id,
129 connections: DashMap::new(),
130 subscription_sessions: DashMap::new(),
131 }
132 }
133
134 pub fn node_id(&self) -> NodeId {
135 self.node_id
136 }
137
138 pub fn config(&self) -> &RealtimeConfig {
139 &self.config
140 }
141
142 pub fn register_connection(
144 &self,
145 session_id: SessionId,
146 sender: mpsc::Sender<RealtimeMessage>,
147 ) {
148 let now = chrono::Utc::now();
149 let entry = SessionEntry {
150 sender,
151 subscriptions: Vec::new(),
152 connected_at: now,
153 last_active: now,
154 consecutive_drops: AtomicU32::new(0),
155 };
156 self.connections.insert(session_id, entry);
157 }
158
159 pub fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
161 if let Some((_, conn)) = self.connections.remove(&session_id) {
162 for sub_id in &conn.subscriptions {
163 self.subscription_sessions.remove(sub_id);
164 }
165 Some(conn.subscriptions)
166 } else {
167 None
168 }
169 }
170
171 pub fn add_subscription(
173 &self,
174 session_id: SessionId,
175 subscription_id: SubscriptionId,
176 ) -> forge_core::Result<()> {
177 let mut conn = self
178 .connections
179 .get_mut(&session_id)
180 .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
181
182 if conn.subscriptions.len() >= self.config.max_subscriptions_per_session {
183 return Err(forge_core::ForgeError::Validation(format!(
184 "Maximum subscriptions per session ({}) exceeded",
185 self.config.max_subscriptions_per_session
186 )));
187 }
188
189 conn.subscriptions.push(subscription_id);
190 drop(conn);
191
192 self.subscription_sessions
193 .insert(subscription_id, session_id);
194
195 Ok(())
196 }
197
198 pub fn remove_subscription(&self, subscription_id: SubscriptionId) {
200 if let Some((_, session_id)) = self.subscription_sessions.remove(&subscription_id)
201 && let Some(mut conn) = self.connections.get_mut(&session_id)
202 {
203 conn.subscriptions.retain(|id| *id != subscription_id);
204 }
205 }
206
207 pub fn try_send_to_session(
209 &self,
210 session_id: SessionId,
211 message: RealtimeMessage,
212 ) -> Result<(), SendError> {
213 let conn = self
214 .connections
215 .get(&session_id)
216 .ok_or(SendError::SessionNotFound)?;
217
218 match conn.sender.try_send(message) {
219 Ok(()) => {
220 conn.consecutive_drops.store(0, Ordering::Relaxed);
221 Ok(())
222 }
223 Err(mpsc::error::TrySendError::Full(_)) => {
224 let drops = conn.consecutive_drops.fetch_add(1, Ordering::Relaxed);
225 if drops >= MAX_CONSECUTIVE_DROPS {
226 let _ = conn.sender.try_send(RealtimeMessage::Lagging);
228 drop(conn);
229 self.evict_session(session_id);
230 Err(SendError::Evicted)
231 } else {
232 Err(SendError::Full)
233 }
234 }
235 Err(mpsc::error::TrySendError::Closed(_)) => {
236 drop(conn);
237 self.remove_connection(session_id);
238 Err(SendError::Closed)
239 }
240 }
241 }
242
243 pub async fn send_to_session(
245 &self,
246 session_id: SessionId,
247 message: RealtimeMessage,
248 ) -> forge_core::Result<()> {
249 let sender = {
250 let conn = self.connections.get(&session_id).ok_or_else(|| {
251 forge_core::ForgeError::Validation("Session not found".to_string())
252 })?;
253 conn.sender.clone()
254 };
255
256 sender
257 .send(message)
258 .await
259 .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
260 }
261
262 pub async fn broadcast_delta(
264 &self,
265 subscription_id: SubscriptionId,
266 delta: Delta<serde_json::Value>,
267 ) -> forge_core::Result<()> {
268 let session_id = self.subscription_sessions.get(&subscription_id).map(|r| *r);
269
270 if let Some(session_id) = session_id {
271 let message = RealtimeMessage::DeltaUpdate {
272 subscription_id: subscription_id.to_string(),
273 delta,
274 };
275 self.send_to_session(session_id, message).await?;
276 }
277
278 Ok(())
279 }
280
281 fn evict_session(&self, session_id: SessionId) {
283 tracing::warn!(?session_id, "Evicting slow client");
284 self.remove_connection(session_id);
285 }
286
287 pub fn connection_count(&self) -> usize {
289 self.connections.len()
290 }
291
292 pub fn subscription_count(&self) -> usize {
294 self.subscription_sessions.len()
295 }
296
297 pub fn stats(&self) -> SessionStats {
299 let total_subscriptions: usize =
300 self.connections.iter().map(|c| c.subscriptions.len()).sum();
301
302 SessionStats {
303 connections: self.connections.len(),
304 subscriptions: total_subscriptions,
305 node_id: self.node_id,
306 }
307 }
308
309 pub fn cleanup_stale(&self, max_idle: Duration) {
311 let cutoff = chrono::Utc::now()
312 - chrono::Duration::from_std(max_idle).unwrap_or(chrono::TimeDelta::MAX);
313
314 let stale: Vec<SessionId> = self
315 .connections
316 .iter()
317 .filter(|entry| entry.last_active < cutoff)
318 .map(|entry| *entry.key())
319 .collect();
320
321 for session_id in stale {
322 self.remove_connection(session_id);
323 }
324 }
325}
326
/// Reasons a non-blocking send through `SessionServer::try_send_to_session` can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// The session id is not registered.
    SessionNotFound,
    /// The session's channel was full; the message was dropped.
    Full,
    /// The session's receiver was dropped; the session has been removed.
    Closed,
    /// The session's channel was full too many times in a row; the session has been removed.
    Evicted,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            SendError::SessionNotFound => "session not found",
            SendError::Full => "session channel is full; message dropped",
            SendError::Closed => "session channel is closed",
            SendError::Evicted => "session evicted for lagging behind",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for SendError {}
335
336#[derive(Debug, Clone)]
338pub struct SessionStats {
339 pub connections: usize,
340 pub subscriptions: usize,
341 pub node_id: NodeId,
342}
343
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a default-config server with one registered session whose channel
    /// holds `capacity` messages. The receiver is returned so the channel stays open.
    fn server_with_session(
        capacity: usize,
    ) -> (SessionServer, SessionId, mpsc::Receiver<RealtimeMessage>) {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(capacity);
        server.register_connection(session_id, tx);
        (server, session_id, rx)
    }

    #[test]
    fn test_realtime_config_default() {
        let config = RealtimeConfig::default();
        assert_eq!(config.max_subscriptions_per_session, 50);
    }

    #[test]
    fn test_session_server_creation() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());

        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count(), 0);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_session_connection() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx);
        assert_eq!(server.connection_count(), 1);

        let removed = server.remove_connection(session_id);
        assert!(removed.is_some());
        assert_eq!(server.connection_count(), 0);
    }

    #[test]
    fn test_session_subscription() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let subscription_id = SubscriptionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx);
        server
            .add_subscription(session_id, subscription_id)
            .unwrap();

        assert_eq!(server.subscription_count(), 1);

        server.remove_subscription(subscription_id);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_session_subscription_limit() {
        let node_id = NodeId::new();
        let config = RealtimeConfig {
            max_subscriptions_per_session: 2,
        };
        let server = SessionServer::new(node_id, config);
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx);

        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        let result = server.add_subscription(session_id, SubscriptionId::new());
        assert!(result.is_err());
    }

    #[test]
    fn test_add_subscription_unknown_session_fails() {
        let server = SessionServer::new(NodeId::new(), RealtimeConfig::default());
        let result = server.add_subscription(SessionId::new(), SubscriptionId::new());
        assert!(result.is_err());
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_remove_connection_clears_subscriptions() {
        let (server, session_id, _rx) = server_with_session(4);
        let subscription_id = SubscriptionId::new();
        server.add_subscription(session_id, subscription_id).unwrap();

        let removed = server.remove_connection(session_id).unwrap();
        assert_eq!(removed.len(), 1);
        assert!(removed.contains(&subscription_id));
        assert_eq!(server.subscription_count(), 0);
        assert!(server.remove_connection(session_id).is_none());
    }

    #[test]
    fn test_remove_unknown_subscription_is_noop() {
        let (server, session_id, _rx) = server_with_session(4);
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        server.remove_subscription(SubscriptionId::new());
        assert_eq!(server.subscription_count(), 1);
        assert_eq!(server.stats().subscriptions, 1);
    }

    #[test]
    fn test_try_send_backpressure() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(1);

        server.register_connection(session_id, tx);

        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(result.is_ok());

        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::Full)));
    }

    #[test]
    fn test_try_send_evicts_after_max_consecutive_drops() {
        let (server, session_id, _rx) = server_with_session(1);

        // Fill the only slot so every following send sees a full channel.
        assert!(server
            .try_send_to_session(session_id, RealtimeMessage::Ping)
            .is_ok());

        // The first MAX_CONSECUTIVE_DROPS full sends are reported as Full only.
        for _ in 0..MAX_CONSECUTIVE_DROPS {
            let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
            assert!(matches!(result, Err(SendError::Full)));
        }
        assert_eq!(server.connection_count(), 1);

        // The next full send evicts the session.
        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::Evicted)));
        assert_eq!(server.connection_count(), 0);

        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::SessionNotFound)));
    }

    #[test]
    fn test_try_send_closed_removes_connection() {
        let (server, session_id, rx) = server_with_session(4);
        drop(rx);

        let result = server.try_send_to_session(session_id, RealtimeMessage::Ping);
        assert!(matches!(result, Err(SendError::Closed)));
        assert_eq!(server.connection_count(), 0);
    }

    #[test]
    fn test_cleanup_stale_keeps_recent_sessions() {
        let (server, _session_id, _rx) = server_with_session(4);
        server.cleanup_stale(Duration::from_secs(3600));
        assert_eq!(server.connection_count(), 1);
    }

    #[test]
    fn test_session_stats() {
        let node_id = NodeId::new();
        let server = SessionServer::new(node_id, RealtimeConfig::default());
        let session_id = SessionId::new();
        let (tx, _rx) = mpsc::channel(100);

        server.register_connection(session_id, tx);
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();
        server
            .add_subscription(session_id, SubscriptionId::new())
            .unwrap();

        let stats = server.stats();
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}