Skip to main content

forge_runtime/realtime/
manager.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU32, Ordering};
3use std::sync::{Arc, Mutex};
4
5use dashmap::DashMap;
6
7use forge_core::cluster::NodeId;
8use forge_core::function::AuthContext;
9use forge_core::realtime::{
10    AuthScope, Change, QueryGroup, QueryGroupId, ReadSet, SessionId, SessionInfo, SessionStatus,
11    Subscriber, SubscriberId, SubscriptionId,
12};
13
14/// Session manager for tracking WebSocket connections.
15pub struct SessionManager {
16    sessions: DashMap<SessionId, SessionInfo>,
17    node_id: NodeId,
18}
19
20impl SessionManager {
21    /// Create a new session manager.
22    pub fn new(node_id: NodeId) -> Self {
23        Self {
24            sessions: DashMap::new(),
25            node_id,
26        }
27    }
28
29    /// Create a new session.
30    pub fn create_session(&self) -> SessionInfo {
31        let mut session = SessionInfo::new(self.node_id);
32        session.connect();
33        self.sessions.insert(session.id, session.clone());
34        session
35    }
36
37    /// Get a session by ID.
38    pub fn get_session(&self, session_id: SessionId) -> Option<SessionInfo> {
39        self.sessions.get(&session_id).map(|r| r.clone())
40    }
41
42    /// Update a session.
43    pub fn update_session(&self, session: SessionInfo) {
44        self.sessions.insert(session.id, session);
45    }
46
47    /// Remove a session.
48    pub fn remove_session(&self, session_id: SessionId) {
49        self.sessions.remove(&session_id);
50    }
51
52    /// Mark a session as disconnected.
53    pub fn disconnect_session(&self, session_id: SessionId) {
54        if let Some(mut session) = self.sessions.get_mut(&session_id) {
55            session.disconnect();
56        }
57    }
58
59    /// Get all connected sessions.
60    pub fn get_connected_sessions(&self) -> Vec<SessionInfo> {
61        self.sessions
62            .iter()
63            .filter(|r| r.is_connected())
64            .map(|r| r.clone())
65            .collect()
66    }
67
68    /// Count sessions by status.
69    pub fn count_by_status(&self) -> SessionCounts {
70        let mut counts = SessionCounts::default();
71
72        for entry in self.sessions.iter() {
73            match entry.status {
74                SessionStatus::Connecting => counts.connecting += 1,
75                SessionStatus::Connected => counts.connected += 1,
76                SessionStatus::Reconnecting => counts.reconnecting += 1,
77                SessionStatus::Disconnected => counts.disconnected += 1,
78            }
79            counts.total += 1;
80        }
81
82        counts
83    }
84
85    /// Clean up disconnected sessions older than the given duration.
86    pub fn cleanup_old_sessions(&self, max_age: std::time::Duration) {
87        let cutoff = chrono::Utc::now()
88            - chrono::Duration::from_std(max_age).unwrap_or(chrono::TimeDelta::MAX);
89
90        self.sessions.retain(|_, session| {
91            session.status != SessionStatus::Disconnected || session.last_active_at > cutoff
92        });
93    }
94}
95
/// Session count statistics, broken down by `SessionStatus`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionCounts {
    /// Sessions in the `Connecting` state.
    pub connecting: usize,
    /// Sessions in the `Connected` state.
    pub connected: usize,
    /// Sessions in the `Reconnecting` state.
    pub reconnecting: usize,
    /// Sessions in the `Disconnected` state.
    pub disconnected: usize,
    /// All sessions counted, regardless of status.
    pub total: usize,
}
105
106/// Group-based subscription manager using sharded concurrent data structures.
107///
108/// Primary index: groups by QueryGroupId (DashMap, 64 shards).
109/// Secondary: lookup key -> QueryGroupId for dedup.
110/// Subscribers stored in a HashMap for O(1) insert/remove by key.
111/// Session -> subscribers mapping for cleanup.
112pub struct SubscriptionManager {
113    /// Query groups indexed by ID. Sharded for concurrent access.
114    groups: DashMap<QueryGroupId, QueryGroup>,
115    /// Lookup: hash(query_name+args+auth_scope) -> QueryGroupId for dedup.
116    group_lookup: DashMap<u64, QueryGroupId>,
117    /// Subscribers indexed by auto-incrementing key.
118    subscribers: Arc<Mutex<SubscriberStore>>,
119    /// Session -> subscriber IDs for fast cleanup on disconnect.
120    session_subscribers: DashMap<SessionId, Vec<SubscriberId>>,
121    /// Monotonic counter for group IDs.
122    next_group_id: AtomicU32,
123    /// Maximum subscriptions per session.
124    max_per_session: usize,
125}
126
127/// Simple indexed store replacing the `slab` crate.
128struct SubscriberStore {
129    entries: HashMap<usize, Subscriber>,
130    next_key: usize,
131}
132
133impl SubscriberStore {
134    fn new() -> Self {
135        Self {
136            entries: HashMap::new(),
137            next_key: 0,
138        }
139    }
140
141    fn insert(&mut self, value: Subscriber) -> usize {
142        let key = self.next_key;
143        self.next_key += 1;
144        self.entries.insert(key, value);
145        key
146    }
147
148    fn get(&self, key: usize) -> Option<&Subscriber> {
149        self.entries.get(&key)
150    }
151
152    fn remove(&mut self, key: usize) -> Subscriber {
153        self.entries
154            .remove(&key)
155            .expect("key not found in subscriber store")
156    }
157
158    fn contains(&self, key: usize) -> bool {
159        self.entries.contains_key(&key)
160    }
161
162    fn iter(&self) -> impl Iterator<Item = (usize, &Subscriber)> {
163        self.entries.iter().map(|(&k, v)| (k, v))
164    }
165}
166
167impl SubscriptionManager {
168    /// Create a new subscription manager.
169    pub fn new(max_per_session: usize) -> Self {
170        Self {
171            groups: DashMap::new(),
172            group_lookup: DashMap::new(),
173            subscribers: Arc::new(Mutex::new(SubscriberStore::new())),
174            session_subscribers: DashMap::new(),
175            next_group_id: AtomicU32::new(0),
176            max_per_session,
177        }
178    }
179
180    /// Subscribe to a query group. Returns the group ID and whether this is a new group.
181    /// If a group already exists for this query+args+auth_scope, the subscriber joins it.
182    #[allow(clippy::too_many_arguments)]
183    pub fn subscribe(
184        &self,
185        session_id: SessionId,
186        client_sub_id: String,
187        query_name: &str,
188        args: &serde_json::Value,
189        auth_context: &AuthContext,
190        table_deps: &'static [&'static str],
191        selected_cols: &'static [&'static str],
192    ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> {
193        // Check per-session limit
194        if let Some(subs) = self.session_subscribers.get(&session_id)
195            && subs.len() >= self.max_per_session
196        {
197            return Err(forge_core::ForgeError::Validation(format!(
198                "Maximum subscriptions per session ({}) exceeded",
199                self.max_per_session
200            )));
201        }
202
203        let auth_scope = AuthScope::from_auth(auth_context);
204        let lookup_key = QueryGroup::compute_lookup_key(query_name, args, &auth_scope);
205
206        // Atomic check-and-insert via DashMap entry API to avoid TOCTOU races
207        let mut is_new = false;
208        let group_id = *self.group_lookup.entry(lookup_key).or_insert_with(|| {
209            is_new = true;
210            let id = QueryGroupId(self.next_group_id.fetch_add(1, Ordering::Relaxed));
211            let group = QueryGroup {
212                id,
213                query_name: query_name.to_string(),
214                args: Arc::new(args.clone()),
215                auth_scope: auth_scope.clone(),
216                auth_context: auth_context.clone(),
217                table_deps,
218                selected_cols,
219                read_set: ReadSet::new(),
220                last_result_hash: None,
221                subscribers: Vec::new(),
222                created_at: chrono::Utc::now(),
223                execution_count: 0,
224            };
225            self.groups.insert(id, group);
226            id
227        });
228
229        // Create subscriber in the store
230        let subscription_id = SubscriptionId::new();
231        let subscriber_id = {
232            let mut store = self.subscribers.lock().expect("subscriber store poisoned");
233            let key = store.next_key;
234            let sid = SubscriberId(key as u32);
235            store.insert(Subscriber {
236                id: sid,
237                session_id,
238                client_sub_id,
239                group_id,
240                subscription_id,
241            });
242            sid
243        };
244
245        // Add subscriber to group
246        if let Some(mut group) = self.groups.get_mut(&group_id) {
247            group.subscribers.push(subscriber_id);
248        }
249
250        // Track session -> subscriber mapping
251        self.session_subscribers
252            .entry(session_id)
253            .or_default()
254            .push(subscriber_id);
255
256        Ok((group_id, subscription_id, is_new))
257    }
258
259    /// Remove a subscriber by its subscription ID.
260    pub fn unsubscribe(&self, subscription_id: SubscriptionId) {
261        let mut store = self.subscribers.lock().expect("subscriber store poisoned");
262
263        // Find the subscriber by subscription_id
264        let sub_key = store
265            .iter()
266            .find(|(_, s)| s.subscription_id == subscription_id)
267            .map(|(key, s)| (key, s.group_id, s.session_id));
268
269        if let Some((key, group_id, session_id)) = sub_key {
270            let subscriber_id = SubscriberId(key as u32);
271            store.remove(key);
272
273            // Remove from group
274            drop(store); // Release lock before accessing DashMap
275            if let Some(mut group) = self.groups.get_mut(&group_id) {
276                group.subscribers.retain(|s| *s != subscriber_id);
277
278                // If group is empty, remove it
279                if group.subscribers.is_empty() {
280                    let lookup_key = QueryGroup::compute_lookup_key(
281                        &group.query_name,
282                        &group.args,
283                        &group.auth_scope,
284                    );
285                    drop(group);
286                    self.groups.remove(&group_id);
287                    self.group_lookup.remove(&lookup_key);
288                }
289            }
290
291            // Remove from session mapping
292            if let Some(mut session_subs) = self.session_subscribers.get_mut(&session_id) {
293                session_subs.retain(|s| *s != subscriber_id);
294            }
295        }
296    }
297
298    /// Remove all subscriptions for a session.
299    pub fn remove_session_subscriptions(&self, session_id: SessionId) -> Vec<SubscriptionId> {
300        let subscriber_ids: Vec<SubscriberId> = self
301            .session_subscribers
302            .remove(&session_id)
303            .map(|(_, ids)| ids)
304            .unwrap_or_default();
305
306        let mut removed_sub_ids = Vec::new();
307        let mut store = self.subscribers.lock().expect("subscriber store poisoned");
308
309        for sid in subscriber_ids {
310            let key = sid.0 as usize;
311            if store.contains(key) {
312                let sub = store.remove(key);
313                removed_sub_ids.push(sub.subscription_id);
314
315                // Remove from group
316                if let Some(mut group) = self.groups.get_mut(&sub.group_id) {
317                    group.subscribers.retain(|s| *s != sid);
318
319                    if group.subscribers.is_empty() {
320                        let lookup_key = QueryGroup::compute_lookup_key(
321                            &group.query_name,
322                            &group.args,
323                            &group.auth_scope,
324                        );
325                        drop(group);
326                        self.groups.remove(&sub.group_id);
327                        self.group_lookup.remove(&lookup_key);
328                    }
329                }
330            }
331        }
332
333        removed_sub_ids
334    }
335
336    /// Find all groups affected by a change. Returns group IDs (not subscription IDs).
337    /// This is O(groups_for_table), not O(all_subscriptions).
338    pub fn find_affected_groups(&self, change: &Change) -> Vec<QueryGroupId> {
339        self.groups
340            .iter()
341            .filter(|entry| entry.should_invalidate(change))
342            .map(|entry| entry.id)
343            .collect()
344    }
345
346    /// Get a reference to a group by ID.
347    pub fn get_group(
348        &self,
349        group_id: QueryGroupId,
350    ) -> Option<dashmap::mapref::one::Ref<'_, QueryGroupId, QueryGroup>> {
351        self.groups.get(&group_id)
352    }
353
354    /// Get a mutable reference to a group by ID.
355    pub fn get_group_mut(
356        &self,
357        group_id: QueryGroupId,
358    ) -> Option<dashmap::mapref::one::RefMut<'_, QueryGroupId, QueryGroup>> {
359        self.groups.get_mut(&group_id)
360    }
361
362    /// Get all subscriber info for a group (for fan-out).
363    pub fn get_group_subscribers(&self, group_id: QueryGroupId) -> Vec<(SessionId, String)> {
364        let subscriber_ids: Vec<SubscriberId> = self
365            .groups
366            .get(&group_id)
367            .map(|g| g.subscribers.clone())
368            .unwrap_or_default();
369
370        let store = self.subscribers.lock().expect("subscriber store poisoned");
371        subscriber_ids
372            .iter()
373            .filter_map(|sid| {
374                store
375                    .get(sid.0 as usize)
376                    .map(|s| (s.session_id, s.client_sub_id.clone()))
377            })
378            .collect()
379    }
380
381    /// Update a group after re-execution.
382    pub fn update_group(&self, group_id: QueryGroupId, read_set: ReadSet, result_hash: String) {
383        if let Some(mut group) = self.groups.get_mut(&group_id) {
384            group.record_execution(read_set, result_hash);
385        }
386    }
387
388    /// Get subscription counts.
389    pub fn counts(&self) -> SubscriptionCounts {
390        let total_subscribers: usize = self.groups.iter().map(|g| g.subscribers.len()).sum();
391        let groups_count = self.groups.len();
392        let sessions_count = self.session_subscribers.len();
393
394        // Estimate memory usage:
395        // - Each QueryGroup: ~256 bytes (name, args, auth, read_set, subscribers vec)
396        // - Each subscriber entry in the store: ~128 bytes
397        // - Each session mapping entry: ~64 bytes + subscriber ID vec
398        let estimated_bytes =
399            (groups_count * 256) + (total_subscribers * 128) + (sessions_count * 64);
400
401        SubscriptionCounts {
402            total: total_subscribers,
403            unique_queries: groups_count,
404            sessions: sessions_count,
405            memory_bytes: estimated_bytes,
406        }
407    }
408
409    /// Get group count.
410    pub fn group_count(&self) -> usize {
411        self.groups.len()
412    }
413}
414
/// Subscription count statistics, as produced by `SubscriptionManager::counts`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SubscriptionCounts {
    /// Total number of subscribers across all query groups.
    pub total: usize,
    /// Number of distinct query groups (deduplicated query + args + auth scope).
    pub unique_queries: usize,
    /// Number of entries in the session → subscriber mapping.
    pub sessions: usize,
    /// Heuristic estimate, in bytes, of memory used by the subscription indexes.
    pub memory_bytes: usize,
}
423
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use forge_core::function::AuthContext;

    /// Subscribe with empty JSON args and no table/column dependencies.
    fn subscribe_plain(
        manager: &SubscriptionManager,
        session_id: SessionId,
        client_sub_id: &str,
        query_name: &str,
        auth: &AuthContext,
    ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> {
        manager.subscribe(
            session_id,
            client_sub_id.to_string(),
            query_name,
            &serde_json::json!({}),
            auth,
            &[],
            &[],
        )
    }

    #[test]
    fn test_session_manager_create() {
        let manager = SessionManager::new(NodeId::new());

        let session = manager.create_session();

        assert!(session.is_connected());
        assert!(manager.get_session(session.id).is_some());
    }

    #[test]
    fn test_session_manager_disconnect() {
        let manager = SessionManager::new(NodeId::new());
        let session = manager.create_session();

        manager.disconnect_session(session.id);

        let stored = manager.get_session(session.id).unwrap();
        assert!(!stored.is_connected());
    }

    #[test]
    fn test_subscription_manager_create() {
        let manager = SubscriptionManager::new(50);
        let auth = AuthContext::unauthenticated();

        let (group_id, _, is_new) =
            subscribe_plain(&manager, SessionId::new(), "sub-1", "get_projects", &auth).unwrap();

        assert!(is_new);
        assert!(manager.get_group(group_id).is_some());
    }

    #[test]
    fn test_subscription_manager_coalescing() {
        let manager = SubscriptionManager::new(50);
        let auth = AuthContext::unauthenticated();

        let (first_group, _, first_is_new) =
            subscribe_plain(&manager, SessionId::new(), "s1", "get_projects", &auth).unwrap();
        let (second_group, _, second_is_new) =
            subscribe_plain(&manager, SessionId::new(), "s2", "get_projects", &auth).unwrap();

        assert!(first_is_new);
        assert!(!second_is_new);
        assert_eq!(first_group, second_group);

        // Both sessions share the single coalesced group.
        assert_eq!(manager.get_group_subscribers(first_group).len(), 2);
    }

    #[test]
    fn test_subscription_manager_limit() {
        let manager = SubscriptionManager::new(2);
        let session_id = SessionId::new();
        let auth = AuthContext::unauthenticated();

        for (client_sub_id, query_name) in [("s1", "q1"), ("s2", "q2")] {
            subscribe_plain(&manager, session_id, client_sub_id, query_name, &auth).unwrap();
        }

        let over_limit = subscribe_plain(&manager, session_id, "s3", "q3", &auth);
        assert!(over_limit.is_err());
    }

    #[test]
    fn test_subscription_manager_remove_session() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();
        let auth = AuthContext::unauthenticated();

        for (client_sub_id, query_name) in [("s1", "q1"), ("s2", "q2")] {
            subscribe_plain(&manager, session_id, client_sub_id, query_name, &auth).unwrap();
        }
        assert_eq!(manager.counts().total, 2);

        manager.remove_session_subscriptions(session_id);

        let after = manager.counts();
        assert_eq!(after.total, 0);
        assert_eq!(after.unique_queries, 0);
    }
}