// forge_runtime/realtime/manager.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU32, Ordering};
3use std::sync::{Arc, Mutex};
4
5use dashmap::DashMap;
6
7use forge_core::cluster::NodeId;
8use forge_core::function::AuthContext;
9use forge_core::realtime::{
10    AuthScope, Change, QueryGroup, QueryGroupId, ReadSet, SessionId, SessionInfo, SessionStatus,
11    Subscriber, SubscriberId, SubscriptionId,
12};
13
14/// Session manager for tracking WebSocket connections.
15pub struct SessionManager {
16    sessions: DashMap<SessionId, SessionInfo>,
17    node_id: NodeId,
18}
19
20impl SessionManager {
21    /// Create a new session manager.
22    pub fn new(node_id: NodeId) -> Self {
23        Self {
24            sessions: DashMap::new(),
25            node_id,
26        }
27    }
28
29    /// Create a new session.
30    pub fn create_session(&self) -> SessionInfo {
31        let mut session = SessionInfo::new(self.node_id);
32        session.connect();
33        self.sessions.insert(session.id, session.clone());
34        session
35    }
36
37    /// Get a session by ID.
38    pub fn get_session(&self, session_id: SessionId) -> Option<SessionInfo> {
39        self.sessions.get(&session_id).map(|r| r.clone())
40    }
41
42    /// Update a session.
43    pub fn update_session(&self, session: SessionInfo) {
44        self.sessions.insert(session.id, session);
45    }
46
47    /// Remove a session.
48    pub fn remove_session(&self, session_id: SessionId) {
49        self.sessions.remove(&session_id);
50    }
51
52    /// Mark a session as disconnected.
53    pub fn disconnect_session(&self, session_id: SessionId) {
54        if let Some(mut session) = self.sessions.get_mut(&session_id) {
55            session.disconnect();
56        }
57    }
58
59    /// Get all connected sessions.
60    pub fn get_connected_sessions(&self) -> Vec<SessionInfo> {
61        self.sessions
62            .iter()
63            .filter(|r| r.is_connected())
64            .map(|r| r.clone())
65            .collect()
66    }
67
68    /// Count sessions by status.
69    pub fn count_by_status(&self) -> SessionCounts {
70        let mut counts = SessionCounts::default();
71
72        for entry in self.sessions.iter() {
73            match entry.status {
74                SessionStatus::Connecting => counts.connecting += 1,
75                SessionStatus::Connected => counts.connected += 1,
76                SessionStatus::Reconnecting => counts.reconnecting += 1,
77                SessionStatus::Disconnected => counts.disconnected += 1,
78            }
79            counts.total += 1;
80        }
81
82        counts
83    }
84
85    /// Clean up disconnected sessions older than the given duration.
86    pub fn cleanup_old_sessions(&self, max_age: std::time::Duration) {
87        let cutoff = chrono::Utc::now()
88            - chrono::Duration::from_std(max_age).unwrap_or(chrono::TimeDelta::MAX);
89
90        self.sessions.retain(|_, session| {
91            session.status != SessionStatus::Disconnected || session.last_active_at > cutoff
92        });
93    }
94}
95
/// Per-status tally of registered sessions, as produced by
/// `SessionManager::count_by_status`.
#[derive(Debug, Clone, Default)]
pub struct SessionCounts {
    /// Sessions with status `Connecting`.
    pub connecting: usize,
    /// Sessions with status `Connected`.
    pub connected: usize,
    /// Sessions with status `Reconnecting`.
    pub reconnecting: usize,
    /// Sessions with status `Disconnected`.
    pub disconnected: usize,
    /// All sessions counted, regardless of status.
    pub total: usize,
}
105
106/// Group-based subscription manager using sharded concurrent data structures.
107///
108/// Primary index: groups by QueryGroupId (DashMap, 64 shards).
109/// Secondary: lookup key -> QueryGroupId for dedup.
110/// Subscribers stored in a HashMap for O(1) insert/remove by key.
111/// Session -> subscribers mapping for cleanup.
112pub struct SubscriptionManager {
113    /// Query groups indexed by ID. Sharded for concurrent access.
114    groups: DashMap<QueryGroupId, QueryGroup>,
115    /// Lookup: hash(query_name+args+auth_scope) -> QueryGroupId for dedup.
116    group_lookup: DashMap<u64, QueryGroupId>,
117    /// Subscribers indexed by auto-incrementing key.
118    subscribers: Arc<Mutex<SubscriberStore>>,
119    /// Session -> subscriber IDs for fast cleanup on disconnect.
120    session_subscribers: DashMap<SessionId, Vec<SubscriberId>>,
121    /// Monotonic counter for group IDs.
122    next_group_id: AtomicU32,
123    /// Maximum subscriptions per session.
124    max_per_session: usize,
125}
126
127/// Simple indexed store replacing the `slab` crate.
128struct SubscriberStore {
129    entries: HashMap<usize, Subscriber>,
130    next_key: usize,
131}
132
133impl SubscriberStore {
134    fn new() -> Self {
135        Self {
136            entries: HashMap::new(),
137            next_key: 0,
138        }
139    }
140
141    fn insert(&mut self, value: Subscriber) -> usize {
142        let key = self.next_key;
143        self.next_key += 1;
144        self.entries.insert(key, value);
145        key
146    }
147
148    fn get(&self, key: usize) -> Option<&Subscriber> {
149        self.entries.get(&key)
150    }
151
152    fn remove(&mut self, key: usize) -> Subscriber {
153        self.entries
154            .remove(&key)
155            .expect("key not found in subscriber store")
156    }
157
158    fn contains(&self, key: usize) -> bool {
159        self.entries.contains_key(&key)
160    }
161
162    fn iter(&self) -> impl Iterator<Item = (usize, &Subscriber)> {
163        self.entries.iter().map(|(&k, v)| (k, v))
164    }
165}
166
167impl SubscriptionManager {
168    /// Create a new subscription manager.
169    pub fn new(max_per_session: usize) -> Self {
170        Self {
171            groups: DashMap::new(),
172            group_lookup: DashMap::new(),
173            subscribers: Arc::new(Mutex::new(SubscriberStore::new())),
174            session_subscribers: DashMap::new(),
175            next_group_id: AtomicU32::new(0),
176            max_per_session,
177        }
178    }
179
180    /// Subscribe to a query group. Returns the group ID and whether this is a new group.
181    /// If a group already exists for this query+args+auth_scope, the subscriber joins it.
182    #[allow(clippy::too_many_arguments)]
183    pub fn subscribe(
184        &self,
185        session_id: SessionId,
186        client_sub_id: String,
187        query_name: &str,
188        args: &serde_json::Value,
189        auth_context: &AuthContext,
190        table_deps: &'static [&'static str],
191        selected_cols: &'static [&'static str],
192    ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> {
193        // Check per-session limit
194        if let Some(subs) = self.session_subscribers.get(&session_id)
195            && subs.len() >= self.max_per_session
196        {
197            return Err(forge_core::ForgeError::Validation(format!(
198                "Maximum subscriptions per session ({}) exceeded",
199                self.max_per_session
200            )));
201        }
202
203        let auth_scope = AuthScope::from_auth(auth_context);
204        let lookup_key = QueryGroup::compute_lookup_key(query_name, args, &auth_scope);
205
206        // Atomic check-and-insert via DashMap entry API to avoid TOCTOU races
207        let mut is_new = false;
208        let group_id = *self.group_lookup.entry(lookup_key).or_insert_with(|| {
209            is_new = true;
210            let id = QueryGroupId(self.next_group_id.fetch_add(1, Ordering::Relaxed));
211            let group = QueryGroup {
212                id,
213                query_name: query_name.to_string(),
214                args: Arc::new(args.clone()),
215                auth_scope: auth_scope.clone(),
216                auth_context: auth_context.clone(),
217                table_deps,
218                selected_cols,
219                read_set: ReadSet::new(),
220                last_result_hash: None,
221                subscribers: Vec::new(),
222                created_at: chrono::Utc::now(),
223                execution_count: 0,
224            };
225            self.groups.insert(id, group);
226            id
227        });
228
229        // Create subscriber in the store
230        let subscription_id = SubscriptionId::new();
231        let subscriber_id = {
232            let mut store = self.subscribers.lock().expect("subscriber store poisoned");
233            let key = store.next_key;
234            let sid = SubscriberId(key as u32);
235            store.insert(Subscriber {
236                id: sid,
237                session_id,
238                client_sub_id,
239                group_id,
240                subscription_id,
241            });
242            sid
243        };
244
245        // Add subscriber to group
246        if let Some(mut group) = self.groups.get_mut(&group_id) {
247            group.subscribers.push(subscriber_id);
248        }
249
250        // Track session -> subscriber mapping
251        self.session_subscribers
252            .entry(session_id)
253            .or_default()
254            .push(subscriber_id);
255
256        Ok((group_id, subscription_id, is_new))
257    }
258
259    /// Remove a subscriber by its subscription ID.
260    pub fn unsubscribe(&self, subscription_id: SubscriptionId) {
261        let mut store = self.subscribers.lock().expect("subscriber store poisoned");
262
263        // Find the subscriber by subscription_id
264        let sub_key = store
265            .iter()
266            .find(|(_, s)| s.subscription_id == subscription_id)
267            .map(|(key, s)| (key, s.group_id, s.session_id));
268
269        if let Some((key, group_id, session_id)) = sub_key {
270            let subscriber_id = SubscriberId(key as u32);
271            store.remove(key);
272
273            // Remove from group
274            drop(store); // Release lock before accessing DashMap
275            if let Some(mut group) = self.groups.get_mut(&group_id) {
276                group.subscribers.retain(|s| *s != subscriber_id);
277
278                // If group is empty, remove it
279                if group.subscribers.is_empty() {
280                    let lookup_key = QueryGroup::compute_lookup_key(
281                        &group.query_name,
282                        &group.args,
283                        &group.auth_scope,
284                    );
285                    drop(group);
286                    self.groups.remove(&group_id);
287                    self.group_lookup.remove(&lookup_key);
288                }
289            }
290
291            // Remove from session mapping
292            if let Some(mut session_subs) = self.session_subscribers.get_mut(&session_id) {
293                session_subs.retain(|s| *s != subscriber_id);
294            }
295        }
296    }
297
298    /// Remove all subscriptions for a session.
299    pub fn remove_session_subscriptions(&self, session_id: SessionId) -> Vec<SubscriptionId> {
300        let subscriber_ids: Vec<SubscriberId> = self
301            .session_subscribers
302            .remove(&session_id)
303            .map(|(_, ids)| ids)
304            .unwrap_or_default();
305
306        let mut removed_sub_ids = Vec::new();
307        let mut store = self.subscribers.lock().expect("subscriber store poisoned");
308
309        for sid in subscriber_ids {
310            let key = sid.0 as usize;
311            if store.contains(key) {
312                let sub = store.remove(key);
313                removed_sub_ids.push(sub.subscription_id);
314
315                // Remove from group
316                if let Some(mut group) = self.groups.get_mut(&sub.group_id) {
317                    group.subscribers.retain(|s| *s != sid);
318
319                    if group.subscribers.is_empty() {
320                        let lookup_key = QueryGroup::compute_lookup_key(
321                            &group.query_name,
322                            &group.args,
323                            &group.auth_scope,
324                        );
325                        drop(group);
326                        self.groups.remove(&sub.group_id);
327                        self.group_lookup.remove(&lookup_key);
328                    }
329                }
330            }
331        }
332
333        removed_sub_ids
334    }
335
336    /// Find all groups affected by a change. Returns group IDs (not subscription IDs).
337    /// This is O(groups_for_table), not O(all_subscriptions).
338    pub fn find_affected_groups(&self, change: &Change) -> Vec<QueryGroupId> {
339        self.groups
340            .iter()
341            .filter(|entry| entry.should_invalidate(change))
342            .map(|entry| entry.id)
343            .collect()
344    }
345
346    /// Get a reference to a group by ID.
347    pub fn get_group(
348        &self,
349        group_id: QueryGroupId,
350    ) -> Option<dashmap::mapref::one::Ref<'_, QueryGroupId, QueryGroup>> {
351        self.groups.get(&group_id)
352    }
353
354    /// Get a mutable reference to a group by ID.
355    pub fn get_group_mut(
356        &self,
357        group_id: QueryGroupId,
358    ) -> Option<dashmap::mapref::one::RefMut<'_, QueryGroupId, QueryGroup>> {
359        self.groups.get_mut(&group_id)
360    }
361
362    /// Get all subscriber info for a group (for fan-out).
363    pub fn get_group_subscribers(&self, group_id: QueryGroupId) -> Vec<(SessionId, String)> {
364        let subscriber_ids: Vec<SubscriberId> = self
365            .groups
366            .get(&group_id)
367            .map(|g| g.subscribers.clone())
368            .unwrap_or_default();
369
370        let store = self.subscribers.lock().expect("subscriber store poisoned");
371        subscriber_ids
372            .iter()
373            .filter_map(|sid| {
374                store
375                    .get(sid.0 as usize)
376                    .map(|s| (s.session_id, s.client_sub_id.clone()))
377            })
378            .collect()
379    }
380
381    /// Update a group after re-execution.
382    pub fn update_group(&self, group_id: QueryGroupId, read_set: ReadSet, result_hash: String) {
383        if let Some(mut group) = self.groups.get_mut(&group_id) {
384            group.record_execution(read_set, result_hash);
385        }
386    }
387
388    /// Get subscription counts.
389    pub fn counts(&self) -> SubscriptionCounts {
390        let total_subscribers: usize = self.groups.iter().map(|g| g.subscribers.len()).sum();
391
392        SubscriptionCounts {
393            total: total_subscribers,
394            unique_queries: self.groups.len(),
395            sessions: self.session_subscribers.len(),
396            memory_bytes: 0, // TODO: calculate if needed
397        }
398    }
399
400    /// Get group count.
401    pub fn group_count(&self) -> usize {
402        self.groups.len()
403    }
404}
405
/// Snapshot of subscription bookkeeping, as produced by `SubscriptionManager::counts`.
#[derive(Debug, Clone, Default)]
pub struct SubscriptionCounts {
    /// Number of subscribers across all groups.
    pub total: usize,
    /// Number of distinct query groups.
    pub unique_queries: usize,
    /// Number of sessions with a subscriber list.
    pub sessions: usize,
    /// Estimated memory use in bytes (currently always reported as 0).
    pub memory_bytes: usize,
}
414
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use forge_core::function::AuthContext;

    /// Subscribe `session_id` to `query_name` with empty args as an unauthenticated user.
    fn subscribe_simple(
        manager: &SubscriptionManager,
        session_id: SessionId,
        client_sub_id: &str,
        query_name: &str,
    ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> {
        manager.subscribe(
            session_id,
            client_sub_id.to_string(),
            query_name,
            &serde_json::json!({}),
            &AuthContext::unauthenticated(),
            &[],
            &[],
        )
    }

    #[test]
    fn test_session_manager_create() {
        let node_id = NodeId::new();
        let manager = SessionManager::new(node_id);

        let session = manager.create_session();
        assert!(session.is_connected());

        let retrieved = manager.get_session(session.id);
        assert!(retrieved.is_some());
    }

    #[test]
    fn test_session_manager_disconnect() {
        let node_id = NodeId::new();
        let manager = SessionManager::new(node_id);

        let session = manager.create_session();
        manager.disconnect_session(session.id);

        let retrieved = manager.get_session(session.id).unwrap();
        assert!(!retrieved.is_connected());
    }

    #[test]
    fn test_subscription_manager_create() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        let (group_id, _sub_id, is_new) =
            subscribe_simple(&manager, session_id, "sub-1", "get_projects").unwrap();

        assert!(is_new);
        assert!(manager.get_group(group_id).is_some());
    }

    #[test]
    fn test_subscription_manager_coalescing() {
        let manager = SubscriptionManager::new(50);
        let session1 = SessionId::new();
        let session2 = SessionId::new();

        let (g1, _, is_new1) = subscribe_simple(&manager, session1, "s1", "get_projects").unwrap();
        let (g2, _, is_new2) = subscribe_simple(&manager, session2, "s2", "get_projects").unwrap();

        assert!(is_new1);
        assert!(!is_new2);
        assert_eq!(g1, g2);

        // Group should have 2 subscribers
        let subs = manager.get_group_subscribers(g1);
        assert_eq!(subs.len(), 2);
    }

    #[test]
    fn test_subscription_manager_limit() {
        let manager = SubscriptionManager::new(2);
        let session_id = SessionId::new();

        subscribe_simple(&manager, session_id, "s1", "q1").unwrap();
        subscribe_simple(&manager, session_id, "s2", "q2").unwrap();

        let result = subscribe_simple(&manager, session_id, "s3", "q3");
        assert!(result.is_err());
    }

    #[test]
    fn test_subscription_manager_remove_session() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        subscribe_simple(&manager, session_id, "s1", "q1").unwrap();
        subscribe_simple(&manager, session_id, "s2", "q2").unwrap();

        let counts = manager.counts();
        assert_eq!(counts.total, 2);

        manager.remove_session_subscriptions(session_id);

        let counts = manager.counts();
        assert_eq!(counts.total, 0);
        assert_eq!(counts.unique_queries, 0);
    }

    #[test]
    fn test_remove_session_returns_ids_and_keeps_shared_groups() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();
        let other = SessionId::new();

        let (_, sub1, _) = subscribe_simple(&manager, session_id, "s1", "q1").unwrap();
        let (_, sub2, _) = subscribe_simple(&manager, session_id, "s2", "q2").unwrap();
        subscribe_simple(&manager, other, "s3", "q1").unwrap();

        let removed = manager.remove_session_subscriptions(session_id);
        assert_eq!(removed.len(), 2);
        assert!(removed.contains(&sub1));
        assert!(removed.contains(&sub2));

        // The other session still subscribes to q1, so only q2's group is gone.
        assert_eq!(manager.counts().total, 1);
        assert_eq!(manager.group_count(), 1);
    }

    #[test]
    fn test_unsubscribe_last_subscriber_removes_group() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        let (group_id, sub_id, _) = subscribe_simple(&manager, session_id, "s1", "q1").unwrap();
        manager.unsubscribe(sub_id);

        assert!(manager.get_group(group_id).is_none());
        assert_eq!(manager.group_count(), 0);
        assert_eq!(manager.counts().total, 0);

        // The lookup entry must be gone too, so the same query starts a new group.
        let (_, _, is_new) = subscribe_simple(&manager, session_id, "s2", "q1").unwrap();
        assert!(is_new);
    }

    #[test]
    fn test_unsubscribe_keeps_shared_group() {
        let manager = SubscriptionManager::new(50);
        let session1 = SessionId::new();
        let session2 = SessionId::new();

        let (g1, sub1, _) = subscribe_simple(&manager, session1, "a", "q").unwrap();
        let (g2, _, _) = subscribe_simple(&manager, session2, "b", "q").unwrap();
        assert_eq!(g1, g2);

        manager.unsubscribe(sub1);

        let remaining = manager.get_group_subscribers(g1);
        assert_eq!(remaining.len(), 1);
        assert!(remaining[0].0 == session2);
        assert_eq!(remaining[0].1, "b");
    }

    #[test]
    fn test_unsubscribe_unknown_is_noop() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        subscribe_simple(&manager, session_id, "s1", "q1").unwrap();
        manager.unsubscribe(SubscriptionId::new());

        assert_eq!(manager.counts().total, 1);
        assert_eq!(manager.group_count(), 1);
    }
}