Skip to main content

forge_runtime/realtime/
manager.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use tokio::sync::RwLock;
5
6use forge_core::cluster::NodeId;
7use forge_core::realtime::{
8    Change, ReadSet, SessionId, SessionInfo, SessionStatus, SubscriptionId, SubscriptionInfo,
9};
10
11/// Session manager for tracking WebSocket connections.
12pub struct SessionManager {
13    sessions: Arc<RwLock<HashMap<SessionId, SessionInfo>>>,
14    node_id: NodeId,
15}
16
17impl SessionManager {
18    /// Create a new session manager.
19    pub fn new(node_id: NodeId) -> Self {
20        Self {
21            sessions: Arc::new(RwLock::new(HashMap::new())),
22            node_id,
23        }
24    }
25
26    /// Create a new session.
27    pub async fn create_session(&self) -> SessionInfo {
28        let mut session = SessionInfo::new(self.node_id);
29        session.connect();
30
31        let mut sessions = self.sessions.write().await;
32        sessions.insert(session.id, session.clone());
33
34        session
35    }
36
37    /// Get a session by ID.
38    pub async fn get_session(&self, session_id: SessionId) -> Option<SessionInfo> {
39        let sessions = self.sessions.read().await;
40        sessions.get(&session_id).cloned()
41    }
42
43    /// Update a session.
44    pub async fn update_session(&self, session: SessionInfo) {
45        let mut sessions = self.sessions.write().await;
46        sessions.insert(session.id, session);
47    }
48
49    /// Remove a session.
50    pub async fn remove_session(&self, session_id: SessionId) {
51        let mut sessions = self.sessions.write().await;
52        sessions.remove(&session_id);
53    }
54
55    /// Mark a session as disconnected.
56    pub async fn disconnect_session(&self, session_id: SessionId) {
57        let mut sessions = self.sessions.write().await;
58        if let Some(session) = sessions.get_mut(&session_id) {
59            session.disconnect();
60        }
61    }
62
63    /// Get all connected sessions.
64    pub async fn get_connected_sessions(&self) -> Vec<SessionInfo> {
65        let sessions = self.sessions.read().await;
66        sessions
67            .values()
68            .filter(|s| s.is_connected())
69            .cloned()
70            .collect()
71    }
72
73    /// Count sessions by status.
74    pub async fn count_by_status(&self) -> SessionCounts {
75        let sessions = self.sessions.read().await;
76        let mut counts = SessionCounts::default();
77
78        for session in sessions.values() {
79            match session.status {
80                SessionStatus::Connecting => counts.connecting += 1,
81                SessionStatus::Connected => counts.connected += 1,
82                SessionStatus::Reconnecting => counts.reconnecting += 1,
83                SessionStatus::Disconnected => counts.disconnected += 1,
84            }
85            counts.total += 1;
86        }
87
88        counts
89    }
90
91    /// Clean up disconnected sessions older than the given duration.
92    pub async fn cleanup_old_sessions(&self, max_age: std::time::Duration) {
93        let mut sessions = self.sessions.write().await;
94        let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_age).unwrap();
95
96        sessions.retain(|_, session| {
97            session.status != SessionStatus::Disconnected || session.last_active_at > cutoff
98        });
99    }
100}
101
/// Session count statistics, broken down by `SessionStatus`.
///
/// Produced by `SessionManager::count_by_status`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionCounts {
    /// Sessions in the `Connecting` state.
    pub connecting: usize,
    /// Sessions in the `Connected` state.
    pub connected: usize,
    /// Sessions in the `Reconnecting` state.
    pub reconnecting: usize,
    /// Sessions in the `Disconnected` state.
    pub disconnected: usize,
    /// All tracked sessions, regardless of state.
    pub total: usize,
}
116
117/// Subscription manager for tracking active subscriptions.
118pub struct SubscriptionManager {
119    /// Subscriptions by ID.
120    subscriptions: Arc<RwLock<HashMap<SubscriptionId, SubscriptionInfo>>>,
121    /// Subscriptions by session ID for fast lookup.
122    by_session: Arc<RwLock<HashMap<SessionId, Vec<SubscriptionId>>>>,
123    /// Subscriptions by query hash for deduplication.
124    by_query_hash: Arc<RwLock<HashMap<String, Vec<SubscriptionId>>>>,
125    /// Maximum subscriptions per session.
126    max_per_session: usize,
127}
128
129impl SubscriptionManager {
130    /// Create a new subscription manager.
131    pub fn new(max_per_session: usize) -> Self {
132        Self {
133            subscriptions: Arc::new(RwLock::new(HashMap::new())),
134            by_session: Arc::new(RwLock::new(HashMap::new())),
135            by_query_hash: Arc::new(RwLock::new(HashMap::new())),
136            max_per_session,
137        }
138    }
139
140    /// Create a new subscription.
141    pub async fn create_subscription(
142        &self,
143        session_id: SessionId,
144        query_name: impl Into<String>,
145        args: serde_json::Value,
146    ) -> forge_core::Result<SubscriptionInfo> {
147        // Check limit
148        let by_session = self.by_session.read().await;
149        if let Some(subs) = by_session.get(&session_id) {
150            if subs.len() >= self.max_per_session {
151                return Err(forge_core::ForgeError::Validation(format!(
152                    "Maximum subscriptions per session ({}) exceeded",
153                    self.max_per_session
154                )));
155            }
156        }
157        drop(by_session);
158
159        let subscription = SubscriptionInfo::new(session_id, query_name, args);
160
161        // Store subscription
162        let mut subscriptions = self.subscriptions.write().await;
163        subscriptions.insert(subscription.id, subscription.clone());
164
165        // Index by session
166        let mut by_session = self.by_session.write().await;
167        by_session
168            .entry(session_id)
169            .or_default()
170            .push(subscription.id);
171
172        // Index by query hash
173        let mut by_query_hash = self.by_query_hash.write().await;
174        by_query_hash
175            .entry(subscription.query_hash.clone())
176            .or_default()
177            .push(subscription.id);
178
179        Ok(subscription)
180    }
181
182    /// Get a subscription by ID.
183    pub async fn get_subscription(
184        &self,
185        subscription_id: SubscriptionId,
186    ) -> Option<SubscriptionInfo> {
187        let subscriptions = self.subscriptions.read().await;
188        subscriptions.get(&subscription_id).cloned()
189    }
190
191    /// Update a subscription after execution.
192    pub async fn update_subscription(
193        &self,
194        subscription_id: SubscriptionId,
195        read_set: ReadSet,
196        result_hash: String,
197    ) {
198        let mut subscriptions = self.subscriptions.write().await;
199        if let Some(sub) = subscriptions.get_mut(&subscription_id) {
200            sub.record_execution(read_set, result_hash);
201        }
202    }
203
204    /// Remove a subscription.
205    pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
206        let mut subscriptions = self.subscriptions.write().await;
207        if let Some(sub) = subscriptions.remove(&subscription_id) {
208            // Remove from session index
209            let mut by_session = self.by_session.write().await;
210            if let Some(subs) = by_session.get_mut(&sub.session_id) {
211                subs.retain(|id| *id != subscription_id);
212            }
213
214            // Remove from query hash index
215            let mut by_query_hash = self.by_query_hash.write().await;
216            if let Some(subs) = by_query_hash.get_mut(&sub.query_hash) {
217                subs.retain(|id| *id != subscription_id);
218            }
219        }
220    }
221
222    /// Remove all subscriptions for a session.
223    pub async fn remove_session_subscriptions(&self, session_id: SessionId) {
224        let subscription_ids: Vec<SubscriptionId> = {
225            let by_session = self.by_session.read().await;
226            by_session.get(&session_id).cloned().unwrap_or_default()
227        };
228
229        for sub_id in subscription_ids {
230            self.remove_subscription(sub_id).await;
231        }
232
233        // Clean up session entry
234        let mut by_session = self.by_session.write().await;
235        by_session.remove(&session_id);
236    }
237
238    /// Find subscriptions affected by a change.
239    pub async fn find_affected_subscriptions(&self, change: &Change) -> Vec<SubscriptionId> {
240        let subscriptions = self.subscriptions.read().await;
241        subscriptions
242            .iter()
243            .filter(|(_, sub)| sub.should_invalidate(change))
244            .map(|(id, _)| *id)
245            .collect()
246    }
247
248    /// Get subscriptions by query hash (for coalescing).
249    pub async fn get_by_query_hash(&self, query_hash: &str) -> Vec<SubscriptionInfo> {
250        let by_query_hash = self.by_query_hash.read().await;
251        let subscriptions = self.subscriptions.read().await;
252
253        by_query_hash
254            .get(query_hash)
255            .map(|ids| {
256                ids.iter()
257                    .filter_map(|id| subscriptions.get(id).cloned())
258                    .collect()
259            })
260            .unwrap_or_default()
261    }
262
263    /// Get subscription counts.
264    pub async fn counts(&self) -> SubscriptionCounts {
265        let subscriptions = self.subscriptions.read().await;
266        let by_session = self.by_session.read().await;
267
268        SubscriptionCounts {
269            total: subscriptions.len(),
270            unique_queries: self.by_query_hash.read().await.len(),
271            sessions: by_session.len(),
272            memory_bytes: subscriptions.values().map(|s| s.memory_bytes).sum(),
273        }
274    }
275}
276
/// Subscription count statistics, as reported by `SubscriptionManager::counts`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SubscriptionCounts {
    /// Total number of stored subscriptions.
    pub total: usize,
    /// Number of distinct query hashes (queries that can be coalesced count once).
    pub unique_queries: usize,
    /// Number of sessions with an entry in the per-session index.
    pub sessions: usize,
    /// Sum of the `memory_bytes` reported by each subscription.
    pub memory_bytes: usize,
}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_session_manager_create() {
        let manager = SessionManager::new(NodeId::new());

        let created = manager.create_session().await;
        assert!(created.is_connected());

        assert!(manager.get_session(created.id).await.is_some());
    }

    #[tokio::test]
    async fn test_session_manager_disconnect() {
        let manager = SessionManager::new(NodeId::new());

        let created = manager.create_session().await;
        manager.disconnect_session(created.id).await;

        // Disconnecting only flips the session's state; it stays tracked.
        let stored = manager
            .get_session(created.id)
            .await
            .expect("disconnected session should still be tracked");
        assert!(!stored.is_connected());
    }

    #[tokio::test]
    async fn test_subscription_manager_create() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        let created = manager
            .create_subscription(session_id, "get_projects", serde_json::json!({}))
            .await
            .expect("first subscription is within the limit");
        assert_eq!(created.query_name, "get_projects");

        assert!(manager.get_subscription(created.id).await.is_some());
    }

    #[tokio::test]
    async fn test_subscription_manager_limit() {
        let manager = SubscriptionManager::new(2);
        let session_id = SessionId::new();

        // Fill the session exactly up to its limit of two.
        for query in ["query1", "query2"] {
            manager
                .create_subscription(session_id, query, serde_json::json!({}))
                .await
                .expect("subscription within the limit should succeed");
        }

        // One more must be rejected.
        let overflow = manager
            .create_subscription(session_id, "query3", serde_json::json!({}))
            .await;
        assert!(overflow.is_err());
    }

    #[tokio::test]
    async fn test_subscription_manager_remove_session() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        for query in ["query1", "query2"] {
            manager
                .create_subscription(session_id, query, serde_json::json!({}))
                .await
                .expect("subscription within the limit should succeed");
        }
        assert_eq!(manager.counts().await.total, 2);

        manager.remove_session_subscriptions(session_id).await;
        assert_eq!(manager.counts().await.total, 0);
    }
}
378}