1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU32, Ordering};
3use std::sync::{Arc, Mutex};
4
5use dashmap::DashMap;
6
7use forge_core::cluster::NodeId;
8use forge_core::function::AuthContext;
9use forge_core::realtime::{
10 AuthScope, Change, QueryGroup, QueryGroupId, ReadSet, SessionId, SessionInfo, SessionStatus,
11 Subscriber, SubscriberId, SubscriptionId,
12};
13
14pub struct SessionManager {
16 sessions: DashMap<SessionId, SessionInfo>,
17 node_id: NodeId,
18}
19
20impl SessionManager {
21 pub fn new(node_id: NodeId) -> Self {
23 Self {
24 sessions: DashMap::new(),
25 node_id,
26 }
27 }
28
29 pub fn create_session(&self) -> SessionInfo {
31 let mut session = SessionInfo::new(self.node_id);
32 session.connect();
33 self.sessions.insert(session.id, session.clone());
34 session
35 }
36
37 pub fn get_session(&self, session_id: SessionId) -> Option<SessionInfo> {
39 self.sessions.get(&session_id).map(|r| r.clone())
40 }
41
42 pub fn update_session(&self, session: SessionInfo) {
44 self.sessions.insert(session.id, session);
45 }
46
47 pub fn remove_session(&self, session_id: SessionId) {
49 self.sessions.remove(&session_id);
50 }
51
52 pub fn disconnect_session(&self, session_id: SessionId) {
54 if let Some(mut session) = self.sessions.get_mut(&session_id) {
55 session.disconnect();
56 }
57 }
58
59 pub fn get_connected_sessions(&self) -> Vec<SessionInfo> {
61 self.sessions
62 .iter()
63 .filter(|r| r.is_connected())
64 .map(|r| r.clone())
65 .collect()
66 }
67
68 pub fn count_by_status(&self) -> SessionCounts {
70 let mut counts = SessionCounts::default();
71
72 for entry in self.sessions.iter() {
73 match entry.status {
74 SessionStatus::Connecting => counts.connecting += 1,
75 SessionStatus::Connected => counts.connected += 1,
76 SessionStatus::Reconnecting => counts.reconnecting += 1,
77 SessionStatus::Disconnected => counts.disconnected += 1,
78 }
79 counts.total += 1;
80 }
81
82 counts
83 }
84
85 pub fn cleanup_old_sessions(&self, max_age: std::time::Duration) {
87 let cutoff = chrono::Utc::now()
88 - chrono::Duration::from_std(max_age).unwrap_or(chrono::TimeDelta::MAX);
89
90 self.sessions.retain(|_, session| {
91 session.status != SessionStatus::Disconnected || session.last_active_at > cutoff
92 });
93 }
94}
95
/// Number of registered sessions in each [`SessionStatus`], as tallied by
/// `SessionManager::count_by_status`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionCounts {
    /// Sessions whose status is `Connecting`.
    pub connecting: usize,
    /// Sessions whose status is `Connected`.
    pub connected: usize,
    /// Sessions whose status is `Reconnecting`.
    pub reconnecting: usize,
    /// Sessions whose status is `Disconnected`.
    pub disconnected: usize,
    /// All registered sessions, regardless of status.
    pub total: usize,
}
105
106pub struct SubscriptionManager {
113 groups: DashMap<QueryGroupId, QueryGroup>,
115 group_lookup: DashMap<u64, QueryGroupId>,
117 subscribers: Arc<Mutex<SubscriberStore>>,
119 session_subscribers: DashMap<SessionId, Vec<SubscriberId>>,
121 next_group_id: AtomicU32,
123 max_per_session: usize,
125}
126
127struct SubscriberStore {
129 entries: HashMap<usize, Subscriber>,
130 next_key: usize,
131}
132
133impl SubscriberStore {
134 fn new() -> Self {
135 Self {
136 entries: HashMap::new(),
137 next_key: 0,
138 }
139 }
140
141 fn insert(&mut self, value: Subscriber) -> usize {
142 let key = self.next_key;
143 self.next_key += 1;
144 self.entries.insert(key, value);
145 key
146 }
147
148 fn get(&self, key: usize) -> Option<&Subscriber> {
149 self.entries.get(&key)
150 }
151
152 fn remove(&mut self, key: usize) -> Subscriber {
153 self.entries
154 .remove(&key)
155 .expect("key not found in subscriber store")
156 }
157
158 fn contains(&self, key: usize) -> bool {
159 self.entries.contains_key(&key)
160 }
161
162 fn iter(&self) -> impl Iterator<Item = (usize, &Subscriber)> {
163 self.entries.iter().map(|(&k, v)| (k, v))
164 }
165}
166
167impl SubscriptionManager {
168 pub fn new(max_per_session: usize) -> Self {
170 Self {
171 groups: DashMap::new(),
172 group_lookup: DashMap::new(),
173 subscribers: Arc::new(Mutex::new(SubscriberStore::new())),
174 session_subscribers: DashMap::new(),
175 next_group_id: AtomicU32::new(0),
176 max_per_session,
177 }
178 }
179
180 #[allow(clippy::too_many_arguments)]
183 pub fn subscribe(
184 &self,
185 session_id: SessionId,
186 client_sub_id: String,
187 query_name: &str,
188 args: &serde_json::Value,
189 auth_context: &AuthContext,
190 table_deps: &'static [&'static str],
191 selected_cols: &'static [&'static str],
192 ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> {
193 if let Some(subs) = self.session_subscribers.get(&session_id)
195 && subs.len() >= self.max_per_session
196 {
197 return Err(forge_core::ForgeError::Validation(format!(
198 "Maximum subscriptions per session ({}) exceeded",
199 self.max_per_session
200 )));
201 }
202
203 let auth_scope = AuthScope::from_auth(auth_context);
204 let lookup_key = QueryGroup::compute_lookup_key(query_name, args, &auth_scope);
205
206 let mut is_new = false;
208 let group_id = *self.group_lookup.entry(lookup_key).or_insert_with(|| {
209 is_new = true;
210 let id = QueryGroupId(self.next_group_id.fetch_add(1, Ordering::Relaxed));
211 let group = QueryGroup {
212 id,
213 query_name: query_name.to_string(),
214 args: Arc::new(args.clone()),
215 auth_scope: auth_scope.clone(),
216 auth_context: auth_context.clone(),
217 table_deps,
218 selected_cols,
219 read_set: ReadSet::new(),
220 last_result_hash: None,
221 subscribers: Vec::new(),
222 created_at: chrono::Utc::now(),
223 execution_count: 0,
224 };
225 self.groups.insert(id, group);
226 id
227 });
228
229 let subscription_id = SubscriptionId::new();
231 let subscriber_id = {
232 let mut store = self.subscribers.lock().expect("subscriber store poisoned");
233 let key = store.next_key;
234 let sid = SubscriberId(key as u32);
235 store.insert(Subscriber {
236 id: sid,
237 session_id,
238 client_sub_id,
239 group_id,
240 subscription_id,
241 });
242 sid
243 };
244
245 if let Some(mut group) = self.groups.get_mut(&group_id) {
247 group.subscribers.push(subscriber_id);
248 }
249
250 self.session_subscribers
252 .entry(session_id)
253 .or_default()
254 .push(subscriber_id);
255
256 Ok((group_id, subscription_id, is_new))
257 }
258
259 pub fn unsubscribe(&self, subscription_id: SubscriptionId) {
261 let mut store = self.subscribers.lock().expect("subscriber store poisoned");
262
263 let sub_key = store
265 .iter()
266 .find(|(_, s)| s.subscription_id == subscription_id)
267 .map(|(key, s)| (key, s.group_id, s.session_id));
268
269 if let Some((key, group_id, session_id)) = sub_key {
270 let subscriber_id = SubscriberId(key as u32);
271 store.remove(key);
272
273 drop(store); if let Some(mut group) = self.groups.get_mut(&group_id) {
276 group.subscribers.retain(|s| *s != subscriber_id);
277
278 if group.subscribers.is_empty() {
280 let lookup_key = QueryGroup::compute_lookup_key(
281 &group.query_name,
282 &group.args,
283 &group.auth_scope,
284 );
285 drop(group);
286 self.groups.remove(&group_id);
287 self.group_lookup.remove(&lookup_key);
288 }
289 }
290
291 if let Some(mut session_subs) = self.session_subscribers.get_mut(&session_id) {
293 session_subs.retain(|s| *s != subscriber_id);
294 }
295 }
296 }
297
298 pub fn remove_session_subscriptions(&self, session_id: SessionId) -> Vec<SubscriptionId> {
300 let subscriber_ids: Vec<SubscriberId> = self
301 .session_subscribers
302 .remove(&session_id)
303 .map(|(_, ids)| ids)
304 .unwrap_or_default();
305
306 let mut removed_sub_ids = Vec::new();
307 let mut store = self.subscribers.lock().expect("subscriber store poisoned");
308
309 for sid in subscriber_ids {
310 let key = sid.0 as usize;
311 if store.contains(key) {
312 let sub = store.remove(key);
313 removed_sub_ids.push(sub.subscription_id);
314
315 if let Some(mut group) = self.groups.get_mut(&sub.group_id) {
317 group.subscribers.retain(|s| *s != sid);
318
319 if group.subscribers.is_empty() {
320 let lookup_key = QueryGroup::compute_lookup_key(
321 &group.query_name,
322 &group.args,
323 &group.auth_scope,
324 );
325 drop(group);
326 self.groups.remove(&sub.group_id);
327 self.group_lookup.remove(&lookup_key);
328 }
329 }
330 }
331 }
332
333 removed_sub_ids
334 }
335
336 pub fn find_affected_groups(&self, change: &Change) -> Vec<QueryGroupId> {
339 self.groups
340 .iter()
341 .filter(|entry| entry.should_invalidate(change))
342 .map(|entry| entry.id)
343 .collect()
344 }
345
346 pub fn get_group(
348 &self,
349 group_id: QueryGroupId,
350 ) -> Option<dashmap::mapref::one::Ref<'_, QueryGroupId, QueryGroup>> {
351 self.groups.get(&group_id)
352 }
353
354 pub fn get_group_mut(
356 &self,
357 group_id: QueryGroupId,
358 ) -> Option<dashmap::mapref::one::RefMut<'_, QueryGroupId, QueryGroup>> {
359 self.groups.get_mut(&group_id)
360 }
361
362 pub fn get_group_subscribers(&self, group_id: QueryGroupId) -> Vec<(SessionId, String)> {
364 let subscriber_ids: Vec<SubscriberId> = self
365 .groups
366 .get(&group_id)
367 .map(|g| g.subscribers.clone())
368 .unwrap_or_default();
369
370 let store = self.subscribers.lock().expect("subscriber store poisoned");
371 subscriber_ids
372 .iter()
373 .filter_map(|sid| {
374 store
375 .get(sid.0 as usize)
376 .map(|s| (s.session_id, s.client_sub_id.clone()))
377 })
378 .collect()
379 }
380
381 pub fn update_group(&self, group_id: QueryGroupId, read_set: ReadSet, result_hash: String) {
383 if let Some(mut group) = self.groups.get_mut(&group_id) {
384 group.record_execution(read_set, result_hash);
385 }
386 }
387
388 pub fn counts(&self) -> SubscriptionCounts {
390 let total_subscribers: usize = self.groups.iter().map(|g| g.subscribers.len()).sum();
391
392 SubscriptionCounts {
393 total: total_subscribers,
394 unique_queries: self.groups.len(),
395 sessions: self.session_subscribers.len(),
396 memory_bytes: 0, }
398 }
399
400 pub fn group_count(&self) -> usize {
402 self.groups.len()
403 }
404}
405
/// Aggregate statistics reported by `SubscriptionManager::counts`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SubscriptionCounts {
    /// Subscribers across all query groups.
    pub total: usize,
    /// Distinct query groups, i.e. coalesced queries.
    pub unique_queries: usize,
    /// Sessions with an entry in the per-session subscriber index.
    pub sessions: usize,
    /// Estimated memory footprint; currently always reported as `0` (not tracked).
    pub memory_bytes: usize,
}
414
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use forge_core::function::AuthContext;

    /// Subscribes with empty JSON args and no table/column dependencies, unwrapping the result.
    fn subscribe_ok(
        manager: &SubscriptionManager,
        session_id: SessionId,
        client_sub_id: &str,
        query_name: &str,
        auth: &AuthContext,
    ) -> (QueryGroupId, SubscriptionId, bool) {
        manager
            .subscribe(
                session_id,
                client_sub_id.to_string(),
                query_name,
                &serde_json::json!({}),
                auth,
                &[],
                &[],
            )
            .unwrap()
    }

    #[test]
    fn test_session_manager_create() {
        let node_id = NodeId::new();
        let manager = SessionManager::new(node_id);

        let session = manager.create_session();
        assert!(session.is_connected());

        let retrieved = manager.get_session(session.id);
        assert!(retrieved.is_some());
    }

    #[test]
    fn test_session_manager_disconnect() {
        let node_id = NodeId::new();
        let manager = SessionManager::new(node_id);

        let session = manager.create_session();
        manager.disconnect_session(session.id);

        let retrieved = manager.get_session(session.id).unwrap();
        assert!(!retrieved.is_connected());
    }

    #[test]
    fn test_session_manager_remove() {
        let manager = SessionManager::new(NodeId::new());

        let session = manager.create_session();
        manager.remove_session(session.id);

        assert!(manager.get_session(session.id).is_none());
        assert_eq!(manager.count_by_status().total, 0);
    }

    #[test]
    fn test_session_manager_connected_sessions() {
        let manager = SessionManager::new(NodeId::new());

        let kept = manager.create_session();
        let dropped = manager.create_session();
        manager.disconnect_session(dropped.id);

        let connected = manager.get_connected_sessions();
        assert_eq!(connected.len(), 1);
        assert!(connected[0].id == kept.id);
        assert_eq!(manager.count_by_status().total, 2);
    }

    #[test]
    fn test_subscription_manager_create() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();
        let auth = AuthContext::unauthenticated();

        let (group_id, _sub_id, is_new) =
            subscribe_ok(&manager, session_id, "sub-1", "get_projects", &auth);

        assert!(is_new);
        assert!(manager.get_group(group_id).is_some());
    }

    #[test]
    fn test_subscription_manager_coalescing() {
        let manager = SubscriptionManager::new(50);
        let session1 = SessionId::new();
        let session2 = SessionId::new();
        let auth = AuthContext::unauthenticated();

        let (g1, _, is_new1) = subscribe_ok(&manager, session1, "s1", "get_projects", &auth);
        let (g2, _, is_new2) = subscribe_ok(&manager, session2, "s2", "get_projects", &auth);

        assert!(is_new1);
        assert!(!is_new2);
        assert_eq!(g1, g2);

        let subs = manager.get_group_subscribers(g1);
        assert_eq!(subs.len(), 2);
    }

    #[test]
    fn test_subscription_manager_limit() {
        let manager = SubscriptionManager::new(2);
        let session_id = SessionId::new();
        let auth = AuthContext::unauthenticated();

        subscribe_ok(&manager, session_id, "s1", "q1", &auth);
        subscribe_ok(&manager, session_id, "s2", "q2", &auth);

        let result = manager.subscribe(
            session_id,
            "s3".to_string(),
            "q3",
            &serde_json::json!({}),
            &auth,
            &[],
            &[],
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_subscription_manager_remove_session() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();
        let auth = AuthContext::unauthenticated();

        let (_, sub1, _) = subscribe_ok(&manager, session_id, "s1", "q1", &auth);
        let (_, sub2, _) = subscribe_ok(&manager, session_id, "s2", "q2", &auth);

        let counts = manager.counts();
        assert_eq!(counts.total, 2);

        let removed = manager.remove_session_subscriptions(session_id);
        assert_eq!(removed.len(), 2);
        assert!(removed.contains(&sub1));
        assert!(removed.contains(&sub2));

        let counts = manager.counts();
        assert_eq!(counts.total, 0);
        assert_eq!(counts.unique_queries, 0);
        assert_eq!(counts.sessions, 0);
    }

    #[test]
    fn test_subscription_manager_unsubscribe_shared_group() {
        let manager = SubscriptionManager::new(50);
        let session1 = SessionId::new();
        let session2 = SessionId::new();
        let auth = AuthContext::unauthenticated();

        let (g1, sub1, _) = subscribe_ok(&manager, session1, "s1", "get_projects", &auth);
        let (g2, sub2, _) = subscribe_ok(&manager, session2, "s2", "get_projects", &auth);
        assert_eq!(g1, g2);

        // The group survives while another subscriber still uses it.
        manager.unsubscribe(sub1);
        assert!(manager.get_group(g1).is_some());
        let remaining = manager.get_group_subscribers(g1);
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].1, "s2");

        // The last unsubscribe removes the group and its lookup entry.
        manager.unsubscribe(sub2);
        assert!(manager.get_group(g1).is_none());
        let counts = manager.counts();
        assert_eq!(counts.total, 0);
        assert_eq!(counts.unique_queries, 0);

        let (_, _, is_new) = subscribe_ok(&manager, session1, "s3", "get_projects", &auth);
        assert!(is_new);
    }

    #[test]
    fn test_subscription_manager_unsubscribe_unknown_is_noop() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();
        let auth = AuthContext::unauthenticated();

        subscribe_ok(&manager, session_id, "s1", "q1", &auth);
        manager.unsubscribe(SubscriptionId::new());

        let counts = manager.counts();
        assert_eq!(counts.total, 1);
        assert_eq!(counts.unique_queries, 1);
    }
}