1use std::collections::HashMap;
2use std::sync::Arc;
3
4use tokio::sync::RwLock;
5
6use forge_core::cluster::NodeId;
7use forge_core::realtime::{
8 Change, ReadSet, SessionId, SessionInfo, SessionStatus, SubscriptionId, SubscriptionInfo,
9};
10
11pub struct SessionManager {
13 sessions: Arc<RwLock<HashMap<SessionId, SessionInfo>>>,
14 node_id: NodeId,
15}
16
17impl SessionManager {
18 pub fn new(node_id: NodeId) -> Self {
20 Self {
21 sessions: Arc::new(RwLock::new(HashMap::new())),
22 node_id,
23 }
24 }
25
26 pub async fn create_session(&self) -> SessionInfo {
28 let mut session = SessionInfo::new(self.node_id);
29 session.connect();
30
31 let mut sessions = self.sessions.write().await;
32 sessions.insert(session.id, session.clone());
33
34 session
35 }
36
37 pub async fn get_session(&self, session_id: SessionId) -> Option<SessionInfo> {
39 let sessions = self.sessions.read().await;
40 sessions.get(&session_id).cloned()
41 }
42
43 pub async fn update_session(&self, session: SessionInfo) {
45 let mut sessions = self.sessions.write().await;
46 sessions.insert(session.id, session);
47 }
48
49 pub async fn remove_session(&self, session_id: SessionId) {
51 let mut sessions = self.sessions.write().await;
52 sessions.remove(&session_id);
53 }
54
55 pub async fn disconnect_session(&self, session_id: SessionId) {
57 let mut sessions = self.sessions.write().await;
58 if let Some(session) = sessions.get_mut(&session_id) {
59 session.disconnect();
60 }
61 }
62
63 pub async fn get_connected_sessions(&self) -> Vec<SessionInfo> {
65 let sessions = self.sessions.read().await;
66 sessions
67 .values()
68 .filter(|s| s.is_connected())
69 .cloned()
70 .collect()
71 }
72
73 pub async fn count_by_status(&self) -> SessionCounts {
75 let sessions = self.sessions.read().await;
76 let mut counts = SessionCounts::default();
77
78 for session in sessions.values() {
79 match session.status {
80 SessionStatus::Connecting => counts.connecting += 1,
81 SessionStatus::Connected => counts.connected += 1,
82 SessionStatus::Reconnecting => counts.reconnecting += 1,
83 SessionStatus::Disconnected => counts.disconnected += 1,
84 }
85 counts.total += 1;
86 }
87
88 counts
89 }
90
91 pub async fn cleanup_old_sessions(&self, max_age: std::time::Duration) {
93 let mut sessions = self.sessions.write().await;
94 let cutoff = chrono::Utc::now()
95 - chrono::Duration::from_std(max_age).expect("duration within chrono range");
96
97 sessions.retain(|_, session| {
98 session.status != SessionStatus::Disconnected || session.last_active_at > cutoff
99 });
100 }
101}
102
/// Snapshot of how many tracked sessions are in each status.
///
/// Produced by `SessionManager::count_by_status`; `total` counts every
/// session examined, so it equals the sum of the per-status fields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionCounts {
    /// Sessions in the `Connecting` status.
    pub connecting: usize,
    /// Sessions in the `Connected` status.
    pub connected: usize,
    /// Sessions in the `Reconnecting` status.
    pub reconnecting: usize,
    /// Sessions in the `Disconnected` status.
    pub disconnected: usize,
    /// All tracked sessions, regardless of status.
    pub total: usize,
}
117
118pub struct SubscriptionManager {
120 subscriptions: Arc<RwLock<HashMap<SubscriptionId, SubscriptionInfo>>>,
122 by_session: Arc<RwLock<HashMap<SessionId, Vec<SubscriptionId>>>>,
124 by_query_hash: Arc<RwLock<HashMap<String, Vec<SubscriptionId>>>>,
126 max_per_session: usize,
128}
129
130impl SubscriptionManager {
131 pub fn new(max_per_session: usize) -> Self {
133 Self {
134 subscriptions: Arc::new(RwLock::new(HashMap::new())),
135 by_session: Arc::new(RwLock::new(HashMap::new())),
136 by_query_hash: Arc::new(RwLock::new(HashMap::new())),
137 max_per_session,
138 }
139 }
140
141 pub async fn create_subscription(
143 &self,
144 session_id: SessionId,
145 query_name: impl Into<String>,
146 args: serde_json::Value,
147 ) -> forge_core::Result<SubscriptionInfo> {
148 let by_session = self.by_session.read().await;
150 if let Some(subs) = by_session.get(&session_id)
151 && subs.len() >= self.max_per_session
152 {
153 return Err(forge_core::ForgeError::Validation(format!(
154 "Maximum subscriptions per session ({}) exceeded",
155 self.max_per_session
156 )));
157 }
158 drop(by_session);
159
160 let subscription = SubscriptionInfo::new(session_id, query_name, args);
161
162 let mut subscriptions = self.subscriptions.write().await;
164 subscriptions.insert(subscription.id, subscription.clone());
165
166 let mut by_session = self.by_session.write().await;
168 by_session
169 .entry(session_id)
170 .or_default()
171 .push(subscription.id);
172
173 let mut by_query_hash = self.by_query_hash.write().await;
175 by_query_hash
176 .entry(subscription.query_hash.clone())
177 .or_default()
178 .push(subscription.id);
179
180 Ok(subscription)
181 }
182
183 pub async fn get_subscription(
185 &self,
186 subscription_id: SubscriptionId,
187 ) -> Option<SubscriptionInfo> {
188 let subscriptions = self.subscriptions.read().await;
189 subscriptions.get(&subscription_id).cloned()
190 }
191
192 pub async fn update_subscription(
194 &self,
195 subscription_id: SubscriptionId,
196 read_set: ReadSet,
197 result_hash: String,
198 ) {
199 let mut subscriptions = self.subscriptions.write().await;
200 if let Some(sub) = subscriptions.get_mut(&subscription_id) {
201 sub.record_execution(read_set, result_hash);
202 }
203 }
204
205 pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
207 let mut subscriptions = self.subscriptions.write().await;
208 if let Some(sub) = subscriptions.remove(&subscription_id) {
209 let mut by_session = self.by_session.write().await;
211 if let Some(subs) = by_session.get_mut(&sub.session_id) {
212 subs.retain(|id| *id != subscription_id);
213 }
214
215 let mut by_query_hash = self.by_query_hash.write().await;
217 if let Some(subs) = by_query_hash.get_mut(&sub.query_hash) {
218 subs.retain(|id| *id != subscription_id);
219 }
220 }
221 }
222
223 pub async fn remove_session_subscriptions(&self, session_id: SessionId) {
225 let subscription_ids: Vec<SubscriptionId> = {
226 let by_session = self.by_session.read().await;
227 by_session.get(&session_id).cloned().unwrap_or_default()
228 };
229
230 for sub_id in subscription_ids {
231 self.remove_subscription(sub_id).await;
232 }
233
234 let mut by_session = self.by_session.write().await;
236 by_session.remove(&session_id);
237 }
238
239 pub async fn find_affected_subscriptions(&self, change: &Change) -> Vec<SubscriptionId> {
241 let subscriptions = self.subscriptions.read().await;
242 subscriptions
243 .iter()
244 .filter(|(_, sub)| sub.should_invalidate(change))
245 .map(|(id, _)| *id)
246 .collect()
247 }
248
249 pub async fn get_by_query_hash(&self, query_hash: &str) -> Vec<SubscriptionInfo> {
251 let by_query_hash = self.by_query_hash.read().await;
252 let subscriptions = self.subscriptions.read().await;
253
254 by_query_hash
255 .get(query_hash)
256 .map(|ids| {
257 ids.iter()
258 .filter_map(|id| subscriptions.get(id).cloned())
259 .collect()
260 })
261 .unwrap_or_default()
262 }
263
264 pub async fn counts(&self) -> SubscriptionCounts {
266 let subscriptions = self.subscriptions.read().await;
267 let by_session = self.by_session.read().await;
268
269 SubscriptionCounts {
270 total: subscriptions.len(),
271 unique_queries: self.by_query_hash.read().await.len(),
272 sessions: by_session.len(),
273 memory_bytes: subscriptions.values().map(|s| s.memory_bytes).sum(),
274 }
275 }
276}
277
/// Snapshot of subscription-registry totals, as returned by
/// `SubscriptionManager::counts`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SubscriptionCounts {
    /// Number of stored subscriptions.
    pub total: usize,
    /// Number of distinct query hashes tracked by the query-hash index.
    pub unique_queries: usize,
    /// Number of sessions tracked by the per-session index.
    pub sessions: usize,
    /// Sum of every subscription's `memory_bytes` estimate.
    pub memory_bytes: usize,
}
290
#[cfg(test)]
// Tests report failures by panicking, so unwrap/index/panic are the intended tools here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a subscription manager with the given per-session limit, plus a
    /// fresh session id to subscribe with.
    fn subscription_fixture(max_per_session: usize) -> (SubscriptionManager, SessionId) {
        (SubscriptionManager::new(max_per_session), SessionId::new())
    }

    #[tokio::test]
    async fn test_session_manager_create() {
        let manager = SessionManager::new(NodeId::new());

        let created = manager.create_session().await;
        assert!(created.is_connected());

        assert!(manager.get_session(created.id).await.is_some());
    }

    #[tokio::test]
    async fn test_session_manager_disconnect() {
        let manager = SessionManager::new(NodeId::new());
        let created = manager.create_session().await;

        manager.disconnect_session(created.id).await;

        let stored = manager.get_session(created.id).await.unwrap();
        assert!(!stored.is_connected());
    }

    #[tokio::test]
    async fn test_subscription_manager_create() {
        let (manager, session_id) = subscription_fixture(50);

        let created = manager
            .create_subscription(session_id, "get_projects", serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(created.query_name, "get_projects");

        assert!(manager.get_subscription(created.id).await.is_some());
    }

    #[tokio::test]
    async fn test_subscription_manager_limit() {
        let (manager, session_id) = subscription_fixture(2);

        for query in ["query1", "query2"] {
            manager
                .create_subscription(session_id, query, serde_json::json!({}))
                .await
                .unwrap();
        }

        let over_limit = manager
            .create_subscription(session_id, "query3", serde_json::json!({}))
            .await;
        assert!(over_limit.is_err());
    }

    #[tokio::test]
    async fn test_subscription_manager_remove_session() {
        let (manager, session_id) = subscription_fixture(50);

        for query in ["query1", "query2"] {
            manager
                .create_subscription(session_id, query, serde_json::json!({}))
                .await
                .unwrap();
        }
        assert_eq!(manager.counts().await.total, 2);

        manager.remove_session_subscriptions(session_id).await;
        assert_eq!(manager.counts().await.total, 0);
    }
}