1use std::sync::atomic::{AtomicU32, Ordering};
2use std::time::Duration;
3
4use dashmap::DashMap;
5use serde::Serialize;
6use tokio::sync::mpsc;
7
8use forge_core::cluster::NodeId;
9use forge_core::realtime::{Delta, SessionId, SubscriptionId};
10
/// Tunable limits for the realtime `SessionServer`.
#[derive(Debug, Clone)]
pub struct RealtimeConfig {
    /// Upper bound on subscriptions one session may hold; `SessionServer::add_subscription`
    /// rejects new subscriptions once a session has this many.
    pub max_subscriptions_per_session: usize,
}

impl Default for RealtimeConfig {
    /// Allows up to 50 subscriptions per session.
    fn default() -> Self {
        const DEFAULT_SUBSCRIPTION_LIMIT: usize = 50;
        RealtimeConfig {
            max_subscriptions_per_session: DEFAULT_SUBSCRIPTION_LIMIT,
        }
    }
}
23
24#[derive(Debug, Clone, Serialize)]
26pub struct JobData {
27 pub job_id: String,
28 pub status: String,
29 #[serde(rename = "progress")]
30 pub progress_percent: Option<i32>,
31 #[serde(rename = "message")]
32 pub progress_message: Option<String>,
33 pub output: Option<serde_json::Value>,
34 pub error: Option<String>,
35}
36
37#[derive(Debug, Clone, Serialize)]
39pub struct WorkflowData {
40 pub workflow_id: String,
41 pub status: String,
42 #[serde(rename = "step")]
43 pub current_step: Option<String>,
44 pub steps: Vec<WorkflowStepData>,
45 pub output: Option<serde_json::Value>,
46 pub error: Option<String>,
47}
48
49#[derive(Debug, Clone, Serialize)]
51pub struct WorkflowStepData {
52 pub name: String,
53 pub status: String,
54 pub error: Option<String>,
55}
56
57#[derive(Debug, Clone)]
59pub enum RealtimeMessage {
60 Subscribe {
61 id: String,
62 query: String,
63 args: serde_json::Value,
64 },
65 Unsubscribe {
66 subscription_id: SubscriptionId,
67 },
68 Ping,
69 Pong,
70 Data {
71 subscription_id: String,
72 data: serde_json::Value,
73 },
74 DeltaUpdate {
75 subscription_id: String,
76 delta: Delta<serde_json::Value>,
77 },
78 JobUpdate {
79 client_sub_id: String,
80 job: JobData,
81 },
82 WorkflowUpdate {
83 client_sub_id: String,
84 workflow: WorkflowData,
85 },
86 Error {
87 code: String,
88 message: String,
89 },
90 ErrorWithId {
91 id: String,
92 code: String,
93 message: String,
94 },
95 AuthSuccess,
96 AuthFailed {
97 reason: String,
98 },
99 Lagging,
101}
102
103struct SessionEntry {
105 sender: mpsc::Sender<RealtimeMessage>,
106 subscriptions: Vec<SubscriptionId>,
107 connected_at: chrono::DateTime<chrono::Utc>,
108 last_active: chrono::DateTime<chrono::Utc>,
109 consecutive_drops: AtomicU32,
111}
112
/// Number of consecutive full-channel drops a session may accumulate before being evicted.
///
/// `SessionServer::try_send_to_session` evicts on the first full-channel failure that
/// occurs after this many drops have already been counted (i.e. on the 11th in a row).
const MAX_CONSECUTIVE_DROPS: u32 = 10;
115
116pub struct SessionServer {
117 config: RealtimeConfig,
118 node_id: NodeId,
119 connections: DashMap<SessionId, SessionEntry>,
121 subscription_sessions: DashMap<SubscriptionId, SessionId>,
123}
124
125impl SessionServer {
126 pub fn new(node_id: NodeId, config: RealtimeConfig) -> Self {
128 Self {
129 config,
130 node_id,
131 connections: DashMap::new(),
132 subscription_sessions: DashMap::new(),
133 }
134 }
135
136 pub fn node_id(&self) -> NodeId {
137 self.node_id
138 }
139
140 pub fn config(&self) -> &RealtimeConfig {
141 &self.config
142 }
143
144 pub fn register_connection(
146 &self,
147 session_id: SessionId,
148 sender: mpsc::Sender<RealtimeMessage>,
149 ) {
150 let now = chrono::Utc::now();
151 let entry = SessionEntry {
152 sender,
153 subscriptions: Vec::new(),
154 connected_at: now,
155 last_active: now,
156 consecutive_drops: AtomicU32::new(0),
157 };
158 self.connections.insert(session_id, entry);
159 }
160
161 pub fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
163 if let Some((_, conn)) = self.connections.remove(&session_id) {
164 for sub_id in &conn.subscriptions {
165 self.subscription_sessions.remove(sub_id);
166 }
167 Some(conn.subscriptions)
168 } else {
169 None
170 }
171 }
172
173 pub fn add_subscription(
175 &self,
176 session_id: SessionId,
177 subscription_id: SubscriptionId,
178 ) -> forge_core::Result<()> {
179 let mut conn = self
180 .connections
181 .get_mut(&session_id)
182 .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
183
184 if conn.subscriptions.len() >= self.config.max_subscriptions_per_session {
185 return Err(forge_core::ForgeError::Validation(format!(
186 "Maximum subscriptions per session ({}) exceeded",
187 self.config.max_subscriptions_per_session
188 )));
189 }
190
191 conn.subscriptions.push(subscription_id);
192 drop(conn);
193
194 self.subscription_sessions
195 .insert(subscription_id, session_id);
196
197 Ok(())
198 }
199
200 pub fn remove_subscription(&self, subscription_id: SubscriptionId) {
202 if let Some((_, session_id)) = self.subscription_sessions.remove(&subscription_id)
203 && let Some(mut conn) = self.connections.get_mut(&session_id)
204 {
205 conn.subscriptions.retain(|id| *id != subscription_id);
206 }
207 }
208
209 pub fn try_send_to_session(
211 &self,
212 session_id: SessionId,
213 message: RealtimeMessage,
214 ) -> Result<(), SendError> {
215 let conn = self
216 .connections
217 .get(&session_id)
218 .ok_or(SendError::SessionNotFound)?;
219
220 match conn.sender.try_send(message) {
221 Ok(()) => {
222 conn.consecutive_drops.store(0, Ordering::Relaxed);
223 Ok(())
224 }
225 Err(mpsc::error::TrySendError::Full(_)) => {
226 let drops = conn.consecutive_drops.fetch_add(1, Ordering::Relaxed);
227 if drops >= MAX_CONSECUTIVE_DROPS {
228 let _ = conn.sender.try_send(RealtimeMessage::Lagging);
230 drop(conn);
231 self.evict_session(session_id);
232 Err(SendError::Evicted)
233 } else {
234 Err(SendError::Full)
235 }
236 }
237 Err(mpsc::error::TrySendError::Closed(_)) => {
238 drop(conn);
239 self.remove_connection(session_id);
240 Err(SendError::Closed)
241 }
242 }
243 }
244
245 pub async fn send_to_session(
247 &self,
248 session_id: SessionId,
249 message: RealtimeMessage,
250 ) -> forge_core::Result<()> {
251 let sender = {
252 let conn = self.connections.get(&session_id).ok_or_else(|| {
253 forge_core::ForgeError::Validation("Session not found".to_string())
254 })?;
255 conn.sender.clone()
256 };
257
258 sender
259 .send(message)
260 .await
261 .map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
262 }
263
264 pub async fn broadcast_delta(
266 &self,
267 subscription_id: SubscriptionId,
268 delta: Delta<serde_json::Value>,
269 ) -> forge_core::Result<()> {
270 let session_id = self.subscription_sessions.get(&subscription_id).map(|r| *r);
271
272 if let Some(session_id) = session_id {
273 let message = RealtimeMessage::DeltaUpdate {
274 subscription_id: subscription_id.to_string(),
275 delta,
276 };
277 self.send_to_session(session_id, message).await?;
278 }
279
280 Ok(())
281 }
282
283 fn evict_session(&self, session_id: SessionId) {
285 tracing::warn!(?session_id, "Evicting slow client");
286 self.remove_connection(session_id);
287 }
288
289 pub fn connection_count(&self) -> usize {
291 self.connections.len()
292 }
293
294 pub fn subscription_count(&self) -> usize {
296 self.subscription_sessions.len()
297 }
298
299 pub fn stats(&self) -> SessionStats {
301 let total_subscriptions: usize =
302 self.connections.iter().map(|c| c.subscriptions.len()).sum();
303
304 SessionStats {
305 connections: self.connections.len(),
306 subscriptions: total_subscriptions,
307 node_id: self.node_id,
308 }
309 }
310
311 pub fn cleanup_stale(&self, max_idle: Duration) {
313 let cutoff = chrono::Utc::now()
314 - chrono::Duration::from_std(max_idle).unwrap_or(chrono::TimeDelta::MAX);
315
316 let stale: Vec<(SessionId, chrono::DateTime<chrono::Utc>)> = self
317 .connections
318 .iter()
319 .filter(|entry| entry.last_active < cutoff)
320 .map(|entry| (*entry.key(), entry.connected_at))
321 .collect();
322
323 if let Some((_, oldest_connected_at)) =
324 stale.iter().min_by_key(|(_, connected_at)| *connected_at)
325 {
326 tracing::debug!(
327 count = stale.len(),
328 oldest_connected_at = %oldest_connected_at,
329 "Cleaning up stale connections"
330 );
331 }
332
333 for (session_id, _) in stale {
334 self.remove_connection(session_id);
335 }
336 }
337}
338
/// Why `SessionServer::try_send_to_session` did not deliver a message.
#[derive(Debug)]
pub enum SendError {
    /// No session is registered under the given id.
    SessionNotFound,
    /// The session's channel was full; the message was dropped but the session kept.
    Full,
    /// The session's channel was closed; the session has been removed.
    Closed,
    /// The session kept dropping messages and has been evicted.
    Evicted,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            SendError::SessionNotFound => "session not found",
            SendError::Full => "session send queue is full",
            SendError::Closed => "session channel is closed",
            SendError::Evicted => "session evicted after too many dropped messages",
        };
        f.write_str(text)
    }
}

/// Lets `SendError` interoperate with `?`-based error plumbing such as `Box<dyn Error>`.
impl std::error::Error for SendError {}
347
348#[derive(Debug, Clone)]
350pub struct SessionStats {
351 pub connections: usize,
352 pub subscriptions: usize,
353 pub node_id: NodeId,
354}
355
#[cfg(test)]
// Unwraps and panics are acceptable here: a failure is a failed test.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a server for a freshly generated node, returning the node id alongside it.
    fn server_with(config: RealtimeConfig) -> (NodeId, SessionServer) {
        let node = NodeId::new();
        (node, SessionServer::new(node, config))
    }

    /// Registers a new session whose channel holds `capacity` messages.
    ///
    /// The receiver is returned so callers can keep the channel open for the test's duration.
    fn connect(
        server: &SessionServer,
        capacity: usize,
    ) -> (SessionId, mpsc::Receiver<RealtimeMessage>) {
        let session = SessionId::new();
        let (tx, rx) = mpsc::channel(capacity);
        server.register_connection(session, tx);
        (session, rx)
    }

    #[test]
    fn test_realtime_config_default() {
        assert_eq!(RealtimeConfig::default().max_subscriptions_per_session, 50);
    }

    #[test]
    fn test_session_server_creation() {
        let (node, server) = server_with(RealtimeConfig::default());

        assert_eq!(server.node_id(), node);
        assert_eq!(server.connection_count(), 0);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_session_connection() {
        let (_, server) = server_with(RealtimeConfig::default());
        let (session, _rx) = connect(&server, 100);
        assert_eq!(server.connection_count(), 1);

        assert!(server.remove_connection(session).is_some());
        assert_eq!(server.connection_count(), 0);
    }

    #[test]
    fn test_session_subscription() {
        let (_, server) = server_with(RealtimeConfig::default());
        let (session, _rx) = connect(&server, 100);
        let subscription = SubscriptionId::new();

        server.add_subscription(session, subscription).unwrap();
        assert_eq!(server.subscription_count(), 1);

        server.remove_subscription(subscription);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_session_subscription_limit() {
        let limited = RealtimeConfig {
            max_subscriptions_per_session: 2,
        };
        let (_, server) = server_with(limited);
        let (session, _rx) = connect(&server, 100);

        for _ in 0..2 {
            server
                .add_subscription(session, SubscriptionId::new())
                .unwrap();
        }

        let over_limit = server.add_subscription(session, SubscriptionId::new());
        assert!(over_limit.is_err());
    }

    #[test]
    fn test_try_send_backpressure() {
        let (_, server) = server_with(RealtimeConfig::default());
        // Single-slot channel: the first message fills it, so the second is dropped.
        let (session, _rx) = connect(&server, 1);

        let first = server.try_send_to_session(session, RealtimeMessage::Ping);
        assert!(first.is_ok());

        let second = server.try_send_to_session(session, RealtimeMessage::Ping);
        assert!(matches!(second, Err(SendError::Full)));
    }

    #[test]
    fn test_session_stats() {
        let (node, server) = server_with(RealtimeConfig::default());
        let (session, _rx) = connect(&server, 100);

        for _ in 0..2 {
            server
                .add_subscription(session, SubscriptionId::new())
                .unwrap();
        }

        let stats = server.stats();
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node);
    }
}