1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU32, Ordering};
3use std::sync::{Arc, Mutex};
4
5use dashmap::DashMap;
6
7use forge_core::cluster::NodeId;
8use forge_core::function::AuthContext;
9use forge_core::realtime::{
10 AuthScope, Change, QueryGroup, QueryGroupId, ReadSet, SessionId, SessionInfo, SessionStatus,
11 Subscriber, SubscriberId, SubscriptionId,
12};
13
14pub struct SessionManager {
16 sessions: DashMap<SessionId, SessionInfo>,
17 node_id: NodeId,
18}
19
20impl SessionManager {
21 pub fn new(node_id: NodeId) -> Self {
23 Self {
24 sessions: DashMap::new(),
25 node_id,
26 }
27 }
28
29 pub fn create_session(&self) -> SessionInfo {
31 let mut session = SessionInfo::new(self.node_id);
32 session.connect();
33 self.sessions.insert(session.id, session.clone());
34 session
35 }
36
37 pub fn get_session(&self, session_id: SessionId) -> Option<SessionInfo> {
39 self.sessions.get(&session_id).map(|r| r.clone())
40 }
41
42 pub fn update_session(&self, session: SessionInfo) {
44 self.sessions.insert(session.id, session);
45 }
46
47 pub fn remove_session(&self, session_id: SessionId) {
49 self.sessions.remove(&session_id);
50 }
51
52 pub fn disconnect_session(&self, session_id: SessionId) {
54 if let Some(mut session) = self.sessions.get_mut(&session_id) {
55 session.disconnect();
56 }
57 }
58
59 pub fn get_connected_sessions(&self) -> Vec<SessionInfo> {
61 self.sessions
62 .iter()
63 .filter(|r| r.is_connected())
64 .map(|r| r.clone())
65 .collect()
66 }
67
68 pub fn count_by_status(&self) -> SessionCounts {
70 let mut counts = SessionCounts::default();
71
72 for entry in self.sessions.iter() {
73 match entry.status {
74 SessionStatus::Connecting => counts.connecting += 1,
75 SessionStatus::Connected => counts.connected += 1,
76 SessionStatus::Reconnecting => counts.reconnecting += 1,
77 SessionStatus::Disconnected => counts.disconnected += 1,
78 }
79 counts.total += 1;
80 }
81
82 counts
83 }
84
85 pub fn cleanup_old_sessions(&self, max_age: std::time::Duration) {
87 let cutoff = chrono::Utc::now()
88 - chrono::Duration::from_std(max_age).unwrap_or(chrono::TimeDelta::MAX);
89
90 self.sessions.retain(|_, session| {
91 session.status != SessionStatus::Disconnected || session.last_active_at > cutoff
92 });
93 }
94}
95
/// Per-status tally of tracked sessions, as produced by `SessionManager::count_by_status`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionCounts {
    /// Sessions in `SessionStatus::Connecting`.
    pub connecting: usize,
    /// Sessions in `SessionStatus::Connected`.
    pub connected: usize,
    /// Sessions in `SessionStatus::Reconnecting`.
    pub reconnecting: usize,
    /// Sessions in `SessionStatus::Disconnected`.
    pub disconnected: usize,
    /// All tracked sessions, regardless of status.
    pub total: usize,
}
105
106pub struct SubscriptionManager {
113 groups: DashMap<QueryGroupId, QueryGroup>,
115 group_lookup: DashMap<u64, QueryGroupId>,
117 subscribers: Arc<Mutex<SubscriberStore>>,
119 session_subscribers: DashMap<SessionId, Vec<SubscriberId>>,
121 next_group_id: AtomicU32,
123 max_per_session: usize,
125}
126
127struct SubscriberStore {
129 entries: HashMap<usize, Subscriber>,
130 next_key: usize,
131}
132
133impl SubscriberStore {
134 fn new() -> Self {
135 Self {
136 entries: HashMap::new(),
137 next_key: 0,
138 }
139 }
140
141 fn insert(&mut self, value: Subscriber) -> usize {
142 let key = self.next_key;
143 self.next_key += 1;
144 self.entries.insert(key, value);
145 key
146 }
147
148 fn get(&self, key: usize) -> Option<&Subscriber> {
149 self.entries.get(&key)
150 }
151
152 fn remove(&mut self, key: usize) -> Option<Subscriber> {
153 self.entries.remove(&key)
154 }
155
156 fn iter(&self) -> impl Iterator<Item = (usize, &Subscriber)> {
157 self.entries.iter().map(|(&k, v)| (k, v))
158 }
159}
160
161impl SubscriptionManager {
162 pub fn new(max_per_session: usize) -> Self {
164 Self {
165 groups: DashMap::new(),
166 group_lookup: DashMap::new(),
167 subscribers: Arc::new(Mutex::new(SubscriberStore::new())),
168 session_subscribers: DashMap::new(),
169 next_group_id: AtomicU32::new(0),
170 max_per_session,
171 }
172 }
173
174 #[allow(clippy::too_many_arguments)]
177 pub fn subscribe(
178 &self,
179 session_id: SessionId,
180 client_sub_id: String,
181 query_name: &str,
182 args: &serde_json::Value,
183 auth_context: &AuthContext,
184 table_deps: &'static [&'static str],
185 selected_cols: &'static [&'static str],
186 ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> {
187 if let Some(subs) = self.session_subscribers.get(&session_id)
189 && subs.len() >= self.max_per_session
190 {
191 return Err(forge_core::ForgeError::Validation(format!(
192 "Maximum subscriptions per session ({}) exceeded",
193 self.max_per_session
194 )));
195 }
196
197 let auth_scope = AuthScope::from_auth(auth_context);
198 let lookup_key = QueryGroup::compute_lookup_key(query_name, args, &auth_scope);
199
200 let mut is_new = false;
202 let group_id = *self.group_lookup.entry(lookup_key).or_insert_with(|| {
203 is_new = true;
204 let id = QueryGroupId(self.next_group_id.fetch_add(1, Ordering::Relaxed));
205 let group = QueryGroup {
206 id,
207 query_name: query_name.to_string(),
208 args: Arc::new(args.clone()),
209 auth_scope: auth_scope.clone(),
210 auth_context: auth_context.clone(),
211 table_deps,
212 selected_cols,
213 read_set: ReadSet::new(),
214 last_result_hash: None,
215 subscribers: Vec::new(),
216 created_at: chrono::Utc::now(),
217 execution_count: 0,
218 };
219 self.groups.insert(id, group);
220 id
221 });
222
223 let subscription_id = SubscriptionId::new();
225 let subscriber_id = {
226 let mut store = self.subscribers.lock().unwrap_or_else(|e| {
227 tracing::error!("Subscriber store lock was poisoned, recovering");
228 e.into_inner()
229 });
230 let key = store.next_key;
231 let sid = SubscriberId(key as u32);
232 store.insert(Subscriber {
233 id: sid,
234 session_id,
235 client_sub_id,
236 group_id,
237 subscription_id,
238 });
239 sid
240 };
241
242 if let Some(mut group) = self.groups.get_mut(&group_id) {
244 group.subscribers.push(subscriber_id);
245 }
246
247 self.session_subscribers
249 .entry(session_id)
250 .or_default()
251 .push(subscriber_id);
252
253 Ok((group_id, subscription_id, is_new))
254 }
255
256 pub fn unsubscribe(&self, subscription_id: SubscriptionId) {
258 let mut store = self.subscribers.lock().unwrap_or_else(|e| {
259 tracing::error!("Subscriber store lock was poisoned, recovering");
260 e.into_inner()
261 });
262
263 let sub_key = store
265 .iter()
266 .find(|(_, s)| s.subscription_id == subscription_id)
267 .map(|(key, s)| (key, s.group_id, s.session_id));
268
269 if let Some((key, group_id, session_id)) = sub_key {
270 let subscriber_id = SubscriberId(key as u32);
271 store.remove(key);
272
273 drop(store); if let Some(mut group) = self.groups.get_mut(&group_id) {
276 group.subscribers.retain(|s| *s != subscriber_id);
277
278 if group.subscribers.is_empty() {
280 let lookup_key = QueryGroup::compute_lookup_key(
281 &group.query_name,
282 &group.args,
283 &group.auth_scope,
284 );
285 drop(group);
286 self.groups.remove(&group_id);
287 self.group_lookup.remove(&lookup_key);
288 }
289 }
290
291 if let Some(mut session_subs) = self.session_subscribers.get_mut(&session_id) {
293 session_subs.retain(|s| *s != subscriber_id);
294 }
295 }
296 }
297
298 pub fn remove_session_subscriptions(&self, session_id: SessionId) -> Vec<SubscriptionId> {
300 let subscriber_ids: Vec<SubscriberId> = self
301 .session_subscribers
302 .remove(&session_id)
303 .map(|(_, ids)| ids)
304 .unwrap_or_default();
305
306 let mut removed_sub_ids = Vec::new();
307 let mut store = self.subscribers.lock().unwrap_or_else(|e| {
308 tracing::error!("Subscriber store lock was poisoned, recovering");
309 e.into_inner()
310 });
311
312 for sid in subscriber_ids {
313 let key = sid.0 as usize;
314 if let Some(sub) = store.remove(key) {
315 removed_sub_ids.push(sub.subscription_id);
316
317 if let Some(mut group) = self.groups.get_mut(&sub.group_id) {
319 group.subscribers.retain(|s| *s != sid);
320
321 if group.subscribers.is_empty() {
322 let lookup_key = QueryGroup::compute_lookup_key(
323 &group.query_name,
324 &group.args,
325 &group.auth_scope,
326 );
327 drop(group);
328 self.groups.remove(&sub.group_id);
329 self.group_lookup.remove(&lookup_key);
330 }
331 }
332 }
333 }
334
335 removed_sub_ids
336 }
337
338 pub fn find_affected_groups(&self, change: &Change) -> Vec<QueryGroupId> {
341 self.groups
342 .iter()
343 .filter(|entry| entry.should_invalidate(change))
344 .map(|entry| entry.id)
345 .collect()
346 }
347
348 pub fn get_group(
350 &self,
351 group_id: QueryGroupId,
352 ) -> Option<dashmap::mapref::one::Ref<'_, QueryGroupId, QueryGroup>> {
353 self.groups.get(&group_id)
354 }
355
356 pub fn get_group_mut(
358 &self,
359 group_id: QueryGroupId,
360 ) -> Option<dashmap::mapref::one::RefMut<'_, QueryGroupId, QueryGroup>> {
361 self.groups.get_mut(&group_id)
362 }
363
364 pub fn get_group_subscribers(&self, group_id: QueryGroupId) -> Vec<(SessionId, String)> {
366 let subscriber_ids: Vec<SubscriberId> = self
367 .groups
368 .get(&group_id)
369 .map(|g| g.subscribers.clone())
370 .unwrap_or_default();
371
372 let store = self.subscribers.lock().unwrap_or_else(|e| {
373 tracing::error!("Subscriber store lock was poisoned, recovering");
374 e.into_inner()
375 });
376 subscriber_ids
377 .iter()
378 .filter_map(|sid| {
379 store
380 .get(sid.0 as usize)
381 .map(|s| (s.session_id, s.client_sub_id.clone()))
382 })
383 .collect()
384 }
385
386 pub fn update_group(&self, group_id: QueryGroupId, read_set: ReadSet, result_hash: String) {
388 if let Some(mut group) = self.groups.get_mut(&group_id) {
389 group.record_execution(read_set, result_hash);
390 }
391 }
392
393 pub fn counts(&self) -> SubscriptionCounts {
395 let total_subscribers: usize = self.groups.iter().map(|g| g.subscribers.len()).sum();
396 let groups_count = self.groups.len();
397 let sessions_count = self.session_subscribers.len();
398
399 let estimated_bytes =
404 (groups_count * 256) + (total_subscribers * 128) + (sessions_count * 64);
405
406 SubscriptionCounts {
407 total: total_subscribers,
408 unique_queries: groups_count,
409 sessions: sessions_count,
410 memory_bytes: estimated_bytes,
411 }
412 }
413
414 pub fn group_count(&self) -> usize {
416 self.groups.len()
417 }
418}
419
/// Snapshot of subscription state, as produced by `SubscriptionManager::counts`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SubscriptionCounts {
    /// Total subscribers across all query groups.
    pub total: usize,
    /// Number of distinct query groups.
    pub unique_queries: usize,
    /// Number of sessions with a subscriber list.
    pub sessions: usize,
    /// Rough, fixed-size-per-item memory estimate in bytes.
    pub memory_bytes: usize,
}
428
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use forge_core::function::AuthContext;

    /// Subscribes `session_id` to `query` with empty args and no table/column deps.
    fn subscribe_simple(
        manager: &SubscriptionManager,
        session_id: SessionId,
        client_sub_id: &str,
        query: &str,
        auth: &AuthContext,
    ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> {
        manager.subscribe(
            session_id,
            client_sub_id.to_string(),
            query,
            &serde_json::json!({}),
            auth,
            &[],
            &[],
        )
    }

    #[test]
    fn test_session_manager_create() {
        let node_id = NodeId::new();
        let manager = SessionManager::new(node_id);

        let session = manager.create_session();
        assert!(session.is_connected());

        let retrieved = manager.get_session(session.id);
        assert!(retrieved.is_some());
    }

    #[test]
    fn test_session_manager_disconnect() {
        let node_id = NodeId::new();
        let manager = SessionManager::new(node_id);

        let session = manager.create_session();
        manager.disconnect_session(session.id);

        let retrieved = manager.get_session(session.id).unwrap();
        assert!(!retrieved.is_connected());
    }

    #[test]
    fn test_subscription_manager_create() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();
        let auth = AuthContext::unauthenticated();

        let (group_id, _sub_id, is_new) =
            subscribe_simple(&manager, session_id, "sub-1", "get_projects", &auth).unwrap();

        assert!(is_new);
        assert!(manager.get_group(group_id).is_some());
    }

    #[test]
    fn test_subscription_manager_coalescing() {
        let manager = SubscriptionManager::new(50);
        let session1 = SessionId::new();
        let session2 = SessionId::new();
        let auth = AuthContext::unauthenticated();

        let (g1, _, is_new1) =
            subscribe_simple(&manager, session1, "s1", "get_projects", &auth).unwrap();
        let (g2, _, is_new2) =
            subscribe_simple(&manager, session2, "s2", "get_projects", &auth).unwrap();

        assert!(is_new1);
        assert!(!is_new2);
        assert_eq!(g1, g2);

        let subs = manager.get_group_subscribers(g1);
        assert_eq!(subs.len(), 2);
    }

    #[test]
    fn test_subscription_manager_limit() {
        let manager = SubscriptionManager::new(2);
        let session_id = SessionId::new();
        let auth = AuthContext::unauthenticated();

        subscribe_simple(&manager, session_id, "s1", "q1", &auth).unwrap();
        subscribe_simple(&manager, session_id, "s2", "q2", &auth).unwrap();

        let result = subscribe_simple(&manager, session_id, "s3", "q3", &auth);
        assert!(result.is_err());
    }

    #[test]
    fn test_subscription_manager_remove_session() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();
        let auth = AuthContext::unauthenticated();

        subscribe_simple(&manager, session_id, "s1", "q1", &auth).unwrap();
        subscribe_simple(&manager, session_id, "s2", "q2", &auth).unwrap();

        let counts = manager.counts();
        assert_eq!(counts.total, 2);

        manager.remove_session_subscriptions(session_id);

        let counts = manager.counts();
        assert_eq!(counts.total, 0);
        assert_eq!(counts.unique_queries, 0);
    }

    #[test]
    fn test_subscription_manager_unsubscribe_last_removes_group() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();
        let auth = AuthContext::unauthenticated();

        let (group_id, sub_id, _) =
            subscribe_simple(&manager, session_id, "s1", "q1", &auth).unwrap();
        manager.unsubscribe(sub_id);

        assert!(manager.get_group(group_id).is_none());
        let counts = manager.counts();
        assert_eq!(counts.total, 0);
        assert_eq!(counts.unique_queries, 0);

        // The lookup entry was removed with the group, so the same query
        // starts a fresh group.
        let (_, _, is_new) = subscribe_simple(&manager, session_id, "s2", "q1", &auth).unwrap();
        assert!(is_new);
    }

    #[test]
    fn test_subscription_manager_unsubscribe_keeps_shared_group() {
        let manager = SubscriptionManager::new(50);
        let session1 = SessionId::new();
        let session2 = SessionId::new();
        let auth = AuthContext::unauthenticated();

        let (g1, sub1, _) =
            subscribe_simple(&manager, session1, "s1", "get_projects", &auth).unwrap();
        let (g2, _, _) =
            subscribe_simple(&manager, session2, "s2", "get_projects", &auth).unwrap();
        assert_eq!(g1, g2);

        manager.unsubscribe(sub1);

        assert!(manager.get_group(g1).is_some());
        let subs = manager.get_group_subscribers(g1);
        assert_eq!(subs.len(), 1);
        assert!(subs[0].0 == session2);
        assert_eq!(subs[0].1, "s2");
    }
}