1use std::collections::HashMap;
2use std::sync::Arc;
3
4use tokio::sync::RwLock;
5
6use forge_core::cluster::NodeId;
7use forge_core::realtime::{
8 Change, ReadSet, SessionId, SessionInfo, SessionStatus, SubscriptionId, SubscriptionInfo,
9};
10
11pub struct SessionManager {
13 sessions: Arc<RwLock<HashMap<SessionId, SessionInfo>>>,
14 node_id: NodeId,
15}
16
17impl SessionManager {
18 pub fn new(node_id: NodeId) -> Self {
20 Self {
21 sessions: Arc::new(RwLock::new(HashMap::new())),
22 node_id,
23 }
24 }
25
26 pub async fn create_session(&self) -> SessionInfo {
28 let mut session = SessionInfo::new(self.node_id);
29 session.connect();
30
31 let mut sessions = self.sessions.write().await;
32 sessions.insert(session.id, session.clone());
33
34 session
35 }
36
37 pub async fn get_session(&self, session_id: SessionId) -> Option<SessionInfo> {
39 let sessions = self.sessions.read().await;
40 sessions.get(&session_id).cloned()
41 }
42
43 pub async fn update_session(&self, session: SessionInfo) {
45 let mut sessions = self.sessions.write().await;
46 sessions.insert(session.id, session);
47 }
48
49 pub async fn remove_session(&self, session_id: SessionId) {
51 let mut sessions = self.sessions.write().await;
52 sessions.remove(&session_id);
53 }
54
55 pub async fn disconnect_session(&self, session_id: SessionId) {
57 let mut sessions = self.sessions.write().await;
58 if let Some(session) = sessions.get_mut(&session_id) {
59 session.disconnect();
60 }
61 }
62
63 pub async fn get_connected_sessions(&self) -> Vec<SessionInfo> {
65 let sessions = self.sessions.read().await;
66 sessions
67 .values()
68 .filter(|s| s.is_connected())
69 .cloned()
70 .collect()
71 }
72
73 pub async fn count_by_status(&self) -> SessionCounts {
75 let sessions = self.sessions.read().await;
76 let mut counts = SessionCounts::default();
77
78 for session in sessions.values() {
79 match session.status {
80 SessionStatus::Connecting => counts.connecting += 1,
81 SessionStatus::Connected => counts.connected += 1,
82 SessionStatus::Reconnecting => counts.reconnecting += 1,
83 SessionStatus::Disconnected => counts.disconnected += 1,
84 }
85 counts.total += 1;
86 }
87
88 counts
89 }
90
91 pub async fn cleanup_old_sessions(&self, max_age: std::time::Duration) {
93 let mut sessions = self.sessions.write().await;
94 let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_age).unwrap();
95
96 sessions.retain(|_, session| {
97 session.status != SessionStatus::Disconnected || session.last_active_at > cutoff
98 });
99 }
100}
101
/// Snapshot of how many sessions are in each status, as returned by
/// `SessionManager::count_by_status`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionCounts {
    /// Sessions in the `Connecting` status.
    pub connecting: usize,
    /// Sessions in the `Connected` status.
    pub connected: usize,
    /// Sessions in the `Reconnecting` status.
    pub reconnecting: usize,
    /// Sessions in the `Disconnected` status.
    pub disconnected: usize,
    /// All sessions counted, regardless of status.
    pub total: usize,
}
116
117pub struct SubscriptionManager {
119 subscriptions: Arc<RwLock<HashMap<SubscriptionId, SubscriptionInfo>>>,
121 by_session: Arc<RwLock<HashMap<SessionId, Vec<SubscriptionId>>>>,
123 by_query_hash: Arc<RwLock<HashMap<String, Vec<SubscriptionId>>>>,
125 max_per_session: usize,
127}
128
129impl SubscriptionManager {
130 pub fn new(max_per_session: usize) -> Self {
132 Self {
133 subscriptions: Arc::new(RwLock::new(HashMap::new())),
134 by_session: Arc::new(RwLock::new(HashMap::new())),
135 by_query_hash: Arc::new(RwLock::new(HashMap::new())),
136 max_per_session,
137 }
138 }
139
140 pub async fn create_subscription(
142 &self,
143 session_id: SessionId,
144 query_name: impl Into<String>,
145 args: serde_json::Value,
146 ) -> forge_core::Result<SubscriptionInfo> {
147 let by_session = self.by_session.read().await;
149 if let Some(subs) = by_session.get(&session_id) {
150 if subs.len() >= self.max_per_session {
151 return Err(forge_core::ForgeError::Validation(format!(
152 "Maximum subscriptions per session ({}) exceeded",
153 self.max_per_session
154 )));
155 }
156 }
157 drop(by_session);
158
159 let subscription = SubscriptionInfo::new(session_id, query_name, args);
160
161 let mut subscriptions = self.subscriptions.write().await;
163 subscriptions.insert(subscription.id, subscription.clone());
164
165 let mut by_session = self.by_session.write().await;
167 by_session
168 .entry(session_id)
169 .or_default()
170 .push(subscription.id);
171
172 let mut by_query_hash = self.by_query_hash.write().await;
174 by_query_hash
175 .entry(subscription.query_hash.clone())
176 .or_default()
177 .push(subscription.id);
178
179 Ok(subscription)
180 }
181
182 pub async fn get_subscription(
184 &self,
185 subscription_id: SubscriptionId,
186 ) -> Option<SubscriptionInfo> {
187 let subscriptions = self.subscriptions.read().await;
188 subscriptions.get(&subscription_id).cloned()
189 }
190
191 pub async fn update_subscription(
193 &self,
194 subscription_id: SubscriptionId,
195 read_set: ReadSet,
196 result_hash: String,
197 ) {
198 let mut subscriptions = self.subscriptions.write().await;
199 if let Some(sub) = subscriptions.get_mut(&subscription_id) {
200 sub.record_execution(read_set, result_hash);
201 }
202 }
203
204 pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
206 let mut subscriptions = self.subscriptions.write().await;
207 if let Some(sub) = subscriptions.remove(&subscription_id) {
208 let mut by_session = self.by_session.write().await;
210 if let Some(subs) = by_session.get_mut(&sub.session_id) {
211 subs.retain(|id| *id != subscription_id);
212 }
213
214 let mut by_query_hash = self.by_query_hash.write().await;
216 if let Some(subs) = by_query_hash.get_mut(&sub.query_hash) {
217 subs.retain(|id| *id != subscription_id);
218 }
219 }
220 }
221
222 pub async fn remove_session_subscriptions(&self, session_id: SessionId) {
224 let subscription_ids: Vec<SubscriptionId> = {
225 let by_session = self.by_session.read().await;
226 by_session.get(&session_id).cloned().unwrap_or_default()
227 };
228
229 for sub_id in subscription_ids {
230 self.remove_subscription(sub_id).await;
231 }
232
233 let mut by_session = self.by_session.write().await;
235 by_session.remove(&session_id);
236 }
237
238 pub async fn find_affected_subscriptions(&self, change: &Change) -> Vec<SubscriptionId> {
240 let subscriptions = self.subscriptions.read().await;
241 subscriptions
242 .iter()
243 .filter(|(_, sub)| sub.should_invalidate(change))
244 .map(|(id, _)| *id)
245 .collect()
246 }
247
248 pub async fn get_by_query_hash(&self, query_hash: &str) -> Vec<SubscriptionInfo> {
250 let by_query_hash = self.by_query_hash.read().await;
251 let subscriptions = self.subscriptions.read().await;
252
253 by_query_hash
254 .get(query_hash)
255 .map(|ids| {
256 ids.iter()
257 .filter_map(|id| subscriptions.get(id).cloned())
258 .collect()
259 })
260 .unwrap_or_default()
261 }
262
263 pub async fn counts(&self) -> SubscriptionCounts {
265 let subscriptions = self.subscriptions.read().await;
266 let by_session = self.by_session.read().await;
267
268 SubscriptionCounts {
269 total: subscriptions.len(),
270 unique_queries: self.by_query_hash.read().await.len(),
271 sessions: by_session.len(),
272 memory_bytes: subscriptions.values().map(|s| s.memory_bytes).sum(),
273 }
274 }
275}
276
/// Aggregate statistics returned by `SubscriptionManager::counts`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SubscriptionCounts {
    /// Number of stored subscriptions.
    pub total: usize,
    /// Number of keys in the query-hash index.
    pub unique_queries: usize,
    /// Number of keys in the per-session index.
    pub sessions: usize,
    /// Sum of each subscription's `memory_bytes` field.
    pub memory_bytes: usize,
}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_session_manager_create() {
        let manager = SessionManager::new(NodeId::new());

        let session = manager.create_session().await;
        assert!(session.is_connected());
        assert!(manager.get_session(session.id).await.is_some());
    }

    #[tokio::test]
    async fn test_session_manager_disconnect() {
        let manager = SessionManager::new(NodeId::new());
        let session = manager.create_session().await;

        manager.disconnect_session(session.id).await;

        let stored = manager
            .get_session(session.id)
            .await
            .expect("session should remain stored after disconnect");
        assert!(!stored.is_connected());
    }

    #[tokio::test]
    async fn test_subscription_manager_create() {
        let manager = SubscriptionManager::new(50);

        let sub = manager
            .create_subscription(SessionId::new(), "get_projects", serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(sub.query_name, "get_projects");
        assert!(manager.get_subscription(sub.id).await.is_some());
    }

    #[tokio::test]
    async fn test_subscription_manager_limit() {
        let manager = SubscriptionManager::new(2);
        let session_id = SessionId::new();

        for query in ["query1", "query2"] {
            manager
                .create_subscription(session_id, query, serde_json::json!({}))
                .await
                .unwrap();
        }

        let overflow = manager
            .create_subscription(session_id, "query3", serde_json::json!({}))
            .await;
        assert!(overflow.is_err());
    }

    #[tokio::test]
    async fn test_subscription_manager_remove_session() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        for query in ["query1", "query2"] {
            manager
                .create_subscription(session_id, query, serde_json::json!({}))
                .await
                .unwrap();
        }
        assert_eq!(manager.counts().await.total, 2);

        manager.remove_session_subscriptions(session_id).await;
        assert_eq!(manager.counts().await.total, 0);
    }
}