Skip to main content

forge_runtime/realtime/
manager.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use tokio::sync::RwLock;
5
6use forge_core::cluster::NodeId;
7use forge_core::realtime::{
8    Change, ReadSet, SessionId, SessionInfo, SessionStatus, SubscriptionId, SubscriptionInfo,
9};
10
11/// Session manager for tracking WebSocket connections.
12pub struct SessionManager {
13    sessions: Arc<RwLock<HashMap<SessionId, SessionInfo>>>,
14    node_id: NodeId,
15}
16
17impl SessionManager {
18    /// Create a new session manager.
19    pub fn new(node_id: NodeId) -> Self {
20        Self {
21            sessions: Arc::new(RwLock::new(HashMap::new())),
22            node_id,
23        }
24    }
25
26    /// Create a new session.
27    pub async fn create_session(&self) -> SessionInfo {
28        let mut session = SessionInfo::new(self.node_id);
29        session.connect();
30
31        let mut sessions = self.sessions.write().await;
32        sessions.insert(session.id, session.clone());
33
34        session
35    }
36
37    /// Get a session by ID.
38    pub async fn get_session(&self, session_id: SessionId) -> Option<SessionInfo> {
39        let sessions = self.sessions.read().await;
40        sessions.get(&session_id).cloned()
41    }
42
43    /// Update a session.
44    pub async fn update_session(&self, session: SessionInfo) {
45        let mut sessions = self.sessions.write().await;
46        sessions.insert(session.id, session);
47    }
48
49    /// Remove a session.
50    pub async fn remove_session(&self, session_id: SessionId) {
51        let mut sessions = self.sessions.write().await;
52        sessions.remove(&session_id);
53    }
54
55    /// Mark a session as disconnected.
56    pub async fn disconnect_session(&self, session_id: SessionId) {
57        let mut sessions = self.sessions.write().await;
58        if let Some(session) = sessions.get_mut(&session_id) {
59            session.disconnect();
60        }
61    }
62
63    /// Get all connected sessions.
64    pub async fn get_connected_sessions(&self) -> Vec<SessionInfo> {
65        let sessions = self.sessions.read().await;
66        sessions
67            .values()
68            .filter(|s| s.is_connected())
69            .cloned()
70            .collect()
71    }
72
73    /// Count sessions by status.
74    pub async fn count_by_status(&self) -> SessionCounts {
75        let sessions = self.sessions.read().await;
76        let mut counts = SessionCounts::default();
77
78        for session in sessions.values() {
79            match session.status {
80                SessionStatus::Connecting => counts.connecting += 1,
81                SessionStatus::Connected => counts.connected += 1,
82                SessionStatus::Reconnecting => counts.reconnecting += 1,
83                SessionStatus::Disconnected => counts.disconnected += 1,
84            }
85            counts.total += 1;
86        }
87
88        counts
89    }
90
91    /// Clean up disconnected sessions older than the given duration.
92    pub async fn cleanup_old_sessions(&self, max_age: std::time::Duration) {
93        let mut sessions = self.sessions.write().await;
94        let cutoff = chrono::Utc::now()
95            - chrono::Duration::from_std(max_age).expect("duration within chrono range");
96
97        sessions.retain(|_, session| {
98            session.status != SessionStatus::Disconnected || session.last_active_at > cutoff
99        });
100    }
101}
102
/// Session count statistics, as produced by `SessionManager::count_by_status`.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionCounts {
    /// Sessions in the `Connecting` state.
    pub connecting: usize,
    /// Sessions in the `Connected` state.
    pub connected: usize,
    /// Sessions in the `Reconnecting` state.
    pub reconnecting: usize,
    /// Sessions in the `Disconnected` state.
    pub disconnected: usize,
    /// Total sessions counted (the sum of the four fields above).
    pub total: usize,
}
117
118/// Subscription manager for tracking active subscriptions.
119pub struct SubscriptionManager {
120    /// Subscriptions by ID.
121    subscriptions: Arc<RwLock<HashMap<SubscriptionId, SubscriptionInfo>>>,
122    /// Subscriptions by session ID for fast lookup.
123    by_session: Arc<RwLock<HashMap<SessionId, Vec<SubscriptionId>>>>,
124    /// Subscriptions by query hash for deduplication.
125    by_query_hash: Arc<RwLock<HashMap<String, Vec<SubscriptionId>>>>,
126    /// Maximum subscriptions per session.
127    max_per_session: usize,
128}
129
130impl SubscriptionManager {
131    /// Create a new subscription manager.
132    pub fn new(max_per_session: usize) -> Self {
133        Self {
134            subscriptions: Arc::new(RwLock::new(HashMap::new())),
135            by_session: Arc::new(RwLock::new(HashMap::new())),
136            by_query_hash: Arc::new(RwLock::new(HashMap::new())),
137            max_per_session,
138        }
139    }
140
141    /// Create a new subscription.
142    pub async fn create_subscription(
143        &self,
144        session_id: SessionId,
145        query_name: impl Into<String>,
146        args: serde_json::Value,
147    ) -> forge_core::Result<SubscriptionInfo> {
148        // Check limit
149        let by_session = self.by_session.read().await;
150        if let Some(subs) = by_session.get(&session_id)
151            && subs.len() >= self.max_per_session
152        {
153            return Err(forge_core::ForgeError::Validation(format!(
154                "Maximum subscriptions per session ({}) exceeded",
155                self.max_per_session
156            )));
157        }
158        drop(by_session);
159
160        let subscription = SubscriptionInfo::new(session_id, query_name, args);
161
162        // Store subscription
163        let mut subscriptions = self.subscriptions.write().await;
164        subscriptions.insert(subscription.id, subscription.clone());
165
166        // Index by session
167        let mut by_session = self.by_session.write().await;
168        by_session
169            .entry(session_id)
170            .or_default()
171            .push(subscription.id);
172
173        // Index by query hash
174        let mut by_query_hash = self.by_query_hash.write().await;
175        by_query_hash
176            .entry(subscription.query_hash.clone())
177            .or_default()
178            .push(subscription.id);
179
180        Ok(subscription)
181    }
182
183    /// Get a subscription by ID.
184    pub async fn get_subscription(
185        &self,
186        subscription_id: SubscriptionId,
187    ) -> Option<SubscriptionInfo> {
188        let subscriptions = self.subscriptions.read().await;
189        subscriptions.get(&subscription_id).cloned()
190    }
191
192    /// Update a subscription after execution.
193    pub async fn update_subscription(
194        &self,
195        subscription_id: SubscriptionId,
196        read_set: ReadSet,
197        result_hash: String,
198    ) {
199        let mut subscriptions = self.subscriptions.write().await;
200        if let Some(sub) = subscriptions.get_mut(&subscription_id) {
201            sub.record_execution(read_set, result_hash);
202        }
203    }
204
205    /// Remove a subscription.
206    pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
207        let mut subscriptions = self.subscriptions.write().await;
208        if let Some(sub) = subscriptions.remove(&subscription_id) {
209            // Remove from session index
210            let mut by_session = self.by_session.write().await;
211            if let Some(subs) = by_session.get_mut(&sub.session_id) {
212                subs.retain(|id| *id != subscription_id);
213            }
214
215            // Remove from query hash index
216            let mut by_query_hash = self.by_query_hash.write().await;
217            if let Some(subs) = by_query_hash.get_mut(&sub.query_hash) {
218                subs.retain(|id| *id != subscription_id);
219            }
220        }
221    }
222
223    /// Remove all subscriptions for a session.
224    pub async fn remove_session_subscriptions(&self, session_id: SessionId) {
225        let subscription_ids: Vec<SubscriptionId> = {
226            let by_session = self.by_session.read().await;
227            by_session.get(&session_id).cloned().unwrap_or_default()
228        };
229
230        for sub_id in subscription_ids {
231            self.remove_subscription(sub_id).await;
232        }
233
234        // Clean up session entry
235        let mut by_session = self.by_session.write().await;
236        by_session.remove(&session_id);
237    }
238
239    /// Find subscriptions affected by a change.
240    pub async fn find_affected_subscriptions(&self, change: &Change) -> Vec<SubscriptionId> {
241        let subscriptions = self.subscriptions.read().await;
242        subscriptions
243            .iter()
244            .filter(|(_, sub)| sub.should_invalidate(change))
245            .map(|(id, _)| *id)
246            .collect()
247    }
248
249    /// Get subscriptions by query hash (for coalescing).
250    pub async fn get_by_query_hash(&self, query_hash: &str) -> Vec<SubscriptionInfo> {
251        let by_query_hash = self.by_query_hash.read().await;
252        let subscriptions = self.subscriptions.read().await;
253
254        by_query_hash
255            .get(query_hash)
256            .map(|ids| {
257                ids.iter()
258                    .filter_map(|id| subscriptions.get(id).cloned())
259                    .collect()
260            })
261            .unwrap_or_default()
262    }
263
264    /// Get subscription counts.
265    pub async fn counts(&self) -> SubscriptionCounts {
266        let subscriptions = self.subscriptions.read().await;
267        let by_session = self.by_session.read().await;
268
269        SubscriptionCounts {
270            total: subscriptions.len(),
271            unique_queries: self.by_query_hash.read().await.len(),
272            sessions: by_session.len(),
273            memory_bytes: subscriptions.values().map(|s| s.memory_bytes).sum(),
274        }
275    }
276}
277
/// Subscription count statistics, as produced by `SubscriptionManager::counts`.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SubscriptionCounts {
    /// Total subscriptions.
    pub total: usize,
    /// Number of distinct query hashes in the manager's query-hash index
    /// (subscriptions sharing a hash can be coalesced).
    pub unique_queries: usize,
    /// Number of sessions in the manager's per-session index.
    pub sessions: usize,
    /// Memory attributed to subscriptions, in bytes: the sum of each
    /// subscription's `memory_bytes` value.
    pub memory_bytes: usize,
}
290
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Build a session manager bound to a freshly generated node ID.
    fn fresh_session_manager() -> SessionManager {
        SessionManager::new(NodeId::new())
    }

    #[tokio::test]
    async fn test_session_manager_create() {
        let manager = fresh_session_manager();

        let created = manager.create_session().await;
        assert!(created.is_connected());

        // The new session must be retrievable by its ID.
        assert!(manager.get_session(created.id).await.is_some());
    }

    #[tokio::test]
    async fn test_session_manager_disconnect() {
        let manager = fresh_session_manager();
        let created = manager.create_session().await;

        manager.disconnect_session(created.id).await;

        let stored = manager.get_session(created.id).await.unwrap();
        assert!(!stored.is_connected());
    }

    #[tokio::test]
    async fn test_subscription_manager_create() {
        let manager = SubscriptionManager::new(50);

        let created = manager
            .create_subscription(SessionId::new(), "get_projects", serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(created.query_name, "get_projects");

        assert!(manager.get_subscription(created.id).await.is_some());
    }

    #[tokio::test]
    async fn test_subscription_manager_limit() {
        let manager = SubscriptionManager::new(2);
        let session_id = SessionId::new();

        // Both of these fit within the limit of two.
        for query in ["query1", "query2"] {
            manager
                .create_subscription(session_id, query, serde_json::json!({}))
                .await
                .unwrap();
        }

        // One more goes over the limit and must be rejected.
        let overflow = manager
            .create_subscription(session_id, "query3", serde_json::json!({}))
            .await;
        assert!(overflow.is_err());
    }

    #[tokio::test]
    async fn test_subscription_manager_remove_session() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        for query in ["query1", "query2"] {
            manager
                .create_subscription(session_id, query, serde_json::json!({}))
                .await
                .unwrap();
        }
        assert_eq!(manager.counts().await.total, 2);

        manager.remove_session_subscriptions(session_id).await;
        assert_eq!(manager.counts().await.total, 0);
    }
}
380}