Skip to main content

forge_runtime/realtime/
manager.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU32, Ordering};
3use std::sync::{Arc, Mutex};
4
5use dashmap::DashMap;
6
7use forge_core::cluster::NodeId;
8use forge_core::function::AuthContext;
9use forge_core::realtime::{
10    AuthScope, Change, QueryGroup, QueryGroupId, ReadSet, SessionId, SessionInfo, SessionStatus,
11    Subscriber, SubscriberId, SubscriptionId,
12};
13
14/// Session manager for tracking WebSocket connections.
15pub struct SessionManager {
16    sessions: DashMap<SessionId, SessionInfo>,
17    node_id: NodeId,
18}
19
20impl SessionManager {
21    /// Create a new session manager.
22    pub fn new(node_id: NodeId) -> Self {
23        Self {
24            sessions: DashMap::new(),
25            node_id,
26        }
27    }
28
29    /// Create a new session.
30    pub fn create_session(&self) -> SessionInfo {
31        let mut session = SessionInfo::new(self.node_id);
32        session.connect();
33        self.sessions.insert(session.id, session.clone());
34        session
35    }
36
37    /// Get a session by ID.
38    pub fn get_session(&self, session_id: SessionId) -> Option<SessionInfo> {
39        self.sessions.get(&session_id).map(|r| r.clone())
40    }
41
42    /// Update a session.
43    pub fn update_session(&self, session: SessionInfo) {
44        self.sessions.insert(session.id, session);
45    }
46
47    /// Remove a session.
48    pub fn remove_session(&self, session_id: SessionId) {
49        self.sessions.remove(&session_id);
50    }
51
52    /// Mark a session as disconnected.
53    pub fn disconnect_session(&self, session_id: SessionId) {
54        if let Some(mut session) = self.sessions.get_mut(&session_id) {
55            session.disconnect();
56        }
57    }
58
59    /// Get all connected sessions.
60    pub fn get_connected_sessions(&self) -> Vec<SessionInfo> {
61        self.sessions
62            .iter()
63            .filter(|r| r.is_connected())
64            .map(|r| r.clone())
65            .collect()
66    }
67
68    /// Count sessions by status.
69    pub fn count_by_status(&self) -> SessionCounts {
70        let mut counts = SessionCounts::default();
71
72        for entry in self.sessions.iter() {
73            match entry.status {
74                SessionStatus::Connecting => counts.connecting += 1,
75                SessionStatus::Connected => counts.connected += 1,
76                SessionStatus::Reconnecting => counts.reconnecting += 1,
77                SessionStatus::Disconnected => counts.disconnected += 1,
78            }
79            counts.total += 1;
80        }
81
82        counts
83    }
84
85    /// Clean up disconnected sessions older than the given duration.
86    pub fn cleanup_old_sessions(&self, max_age: std::time::Duration) {
87        let cutoff = chrono::Utc::now()
88            - chrono::Duration::from_std(max_age).unwrap_or(chrono::TimeDelta::MAX);
89
90        self.sessions.retain(|_, session| {
91            session.status != SessionStatus::Disconnected || session.last_active_at > cutoff
92        });
93    }
94}
95
/// Number of sessions in each status, as produced by `SessionManager::count_by_status`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionCounts {
    /// Sessions whose status is `Connecting`.
    pub connecting: usize,
    /// Sessions whose status is `Connected`.
    pub connected: usize,
    /// Sessions whose status is `Reconnecting`.
    pub reconnecting: usize,
    /// Sessions whose status is `Disconnected`.
    pub disconnected: usize,
    /// All sessions, regardless of status.
    pub total: usize,
}
105
106/// Group-based subscription manager using sharded concurrent data structures.
107///
108/// Primary index: groups by QueryGroupId (DashMap, 64 shards).
109/// Secondary: lookup key -> QueryGroupId for dedup.
110/// Subscribers stored in a HashMap for O(1) insert/remove by key.
111/// Session -> subscribers mapping for cleanup.
112pub struct SubscriptionManager {
113    /// Query groups indexed by ID. Sharded for concurrent access.
114    groups: DashMap<QueryGroupId, QueryGroup>,
115    /// Lookup: hash(query_name+args+auth_scope) -> QueryGroupId for dedup.
116    group_lookup: DashMap<u64, QueryGroupId>,
117    /// Subscribers indexed by auto-incrementing key.
118    subscribers: Arc<Mutex<SubscriberStore>>,
119    /// Session -> subscriber IDs for fast cleanup on disconnect.
120    session_subscribers: DashMap<SessionId, Vec<SubscriberId>>,
121    /// Monotonic counter for group IDs.
122    next_group_id: AtomicU32,
123    /// Maximum subscriptions per session.
124    max_per_session: usize,
125}
126
127/// Simple indexed store replacing the `slab` crate.
128struct SubscriberStore {
129    entries: HashMap<usize, Subscriber>,
130    next_key: usize,
131}
132
133impl SubscriberStore {
134    fn new() -> Self {
135        Self {
136            entries: HashMap::new(),
137            next_key: 0,
138        }
139    }
140
141    fn insert(&mut self, value: Subscriber) -> usize {
142        let key = self.next_key;
143        self.next_key += 1;
144        self.entries.insert(key, value);
145        key
146    }
147
148    fn get(&self, key: usize) -> Option<&Subscriber> {
149        self.entries.get(&key)
150    }
151
152    fn remove(&mut self, key: usize) -> Option<Subscriber> {
153        self.entries.remove(&key)
154    }
155
156    fn iter(&self) -> impl Iterator<Item = (usize, &Subscriber)> {
157        self.entries.iter().map(|(&k, v)| (k, v))
158    }
159}
160
161impl SubscriptionManager {
162    /// Create a new subscription manager.
163    pub fn new(max_per_session: usize) -> Self {
164        Self {
165            groups: DashMap::new(),
166            group_lookup: DashMap::new(),
167            subscribers: Arc::new(Mutex::new(SubscriberStore::new())),
168            session_subscribers: DashMap::new(),
169            next_group_id: AtomicU32::new(0),
170            max_per_session,
171        }
172    }
173
174    /// Subscribe to a query group. Returns the group ID and whether this is a new group.
175    /// If a group already exists for this query+args+auth_scope, the subscriber joins it.
176    #[allow(clippy::too_many_arguments)]
177    pub fn subscribe(
178        &self,
179        session_id: SessionId,
180        client_sub_id: String,
181        query_name: &str,
182        args: &serde_json::Value,
183        auth_context: &AuthContext,
184        table_deps: &'static [&'static str],
185        selected_cols: &'static [&'static str],
186    ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> {
187        // Check per-session limit
188        if let Some(subs) = self.session_subscribers.get(&session_id)
189            && subs.len() >= self.max_per_session
190        {
191            return Err(forge_core::ForgeError::Validation(format!(
192                "Maximum subscriptions per session ({}) exceeded",
193                self.max_per_session
194            )));
195        }
196
197        let auth_scope = AuthScope::from_auth(auth_context);
198        let lookup_key = QueryGroup::compute_lookup_key(query_name, args, &auth_scope);
199
200        // Atomic check-and-insert via DashMap entry API to avoid TOCTOU races
201        let mut is_new = false;
202        let group_id = *self.group_lookup.entry(lookup_key).or_insert_with(|| {
203            is_new = true;
204            let id = QueryGroupId(self.next_group_id.fetch_add(1, Ordering::Relaxed));
205            let group = QueryGroup {
206                id,
207                query_name: query_name.to_string(),
208                args: Arc::new(args.clone()),
209                auth_scope: auth_scope.clone(),
210                auth_context: auth_context.clone(),
211                table_deps,
212                selected_cols,
213                read_set: ReadSet::new(),
214                last_result_hash: None,
215                subscribers: Vec::new(),
216                created_at: chrono::Utc::now(),
217                execution_count: 0,
218            };
219            self.groups.insert(id, group);
220            id
221        });
222
223        // Create subscriber in the store
224        let subscription_id = SubscriptionId::new();
225        let subscriber_id = {
226            let mut store = self.subscribers.lock().unwrap_or_else(|e| {
227                tracing::error!("Subscriber store lock was poisoned, recovering");
228                e.into_inner()
229            });
230            let key = store.next_key;
231            let sid = SubscriberId(key as u32);
232            store.insert(Subscriber {
233                id: sid,
234                session_id,
235                client_sub_id,
236                group_id,
237                subscription_id,
238            });
239            sid
240        };
241
242        // Add subscriber to group
243        if let Some(mut group) = self.groups.get_mut(&group_id) {
244            group.subscribers.push(subscriber_id);
245        }
246
247        // Track session -> subscriber mapping
248        self.session_subscribers
249            .entry(session_id)
250            .or_default()
251            .push(subscriber_id);
252
253        Ok((group_id, subscription_id, is_new))
254    }
255
256    /// Remove a subscriber by its subscription ID.
257    pub fn unsubscribe(&self, subscription_id: SubscriptionId) {
258        let mut store = self.subscribers.lock().unwrap_or_else(|e| {
259            tracing::error!("Subscriber store lock was poisoned, recovering");
260            e.into_inner()
261        });
262
263        // Find the subscriber by subscription_id
264        let sub_key = store
265            .iter()
266            .find(|(_, s)| s.subscription_id == subscription_id)
267            .map(|(key, s)| (key, s.group_id, s.session_id));
268
269        if let Some((key, group_id, session_id)) = sub_key {
270            let subscriber_id = SubscriberId(key as u32);
271            store.remove(key);
272
273            // Remove from group
274            drop(store); // Release lock before accessing DashMap
275            if let Some(mut group) = self.groups.get_mut(&group_id) {
276                group.subscribers.retain(|s| *s != subscriber_id);
277
278                // If group is empty, remove it
279                if group.subscribers.is_empty() {
280                    let lookup_key = QueryGroup::compute_lookup_key(
281                        &group.query_name,
282                        &group.args,
283                        &group.auth_scope,
284                    );
285                    drop(group);
286                    self.groups.remove(&group_id);
287                    self.group_lookup.remove(&lookup_key);
288                }
289            }
290
291            // Remove from session mapping
292            if let Some(mut session_subs) = self.session_subscribers.get_mut(&session_id) {
293                session_subs.retain(|s| *s != subscriber_id);
294            }
295        }
296    }
297
298    /// Remove all subscriptions for a session.
299    pub fn remove_session_subscriptions(&self, session_id: SessionId) -> Vec<SubscriptionId> {
300        let subscriber_ids: Vec<SubscriberId> = self
301            .session_subscribers
302            .remove(&session_id)
303            .map(|(_, ids)| ids)
304            .unwrap_or_default();
305
306        let mut removed_sub_ids = Vec::new();
307        let mut store = self.subscribers.lock().unwrap_or_else(|e| {
308            tracing::error!("Subscriber store lock was poisoned, recovering");
309            e.into_inner()
310        });
311
312        for sid in subscriber_ids {
313            let key = sid.0 as usize;
314            if let Some(sub) = store.remove(key) {
315                removed_sub_ids.push(sub.subscription_id);
316
317                // Remove from group
318                if let Some(mut group) = self.groups.get_mut(&sub.group_id) {
319                    group.subscribers.retain(|s| *s != sid);
320
321                    if group.subscribers.is_empty() {
322                        let lookup_key = QueryGroup::compute_lookup_key(
323                            &group.query_name,
324                            &group.args,
325                            &group.auth_scope,
326                        );
327                        drop(group);
328                        self.groups.remove(&sub.group_id);
329                        self.group_lookup.remove(&lookup_key);
330                    }
331                }
332            }
333        }
334
335        removed_sub_ids
336    }
337
338    /// Find all groups affected by a change. Returns group IDs (not subscription IDs).
339    /// This is O(groups_for_table), not O(all_subscriptions).
340    pub fn find_affected_groups(&self, change: &Change) -> Vec<QueryGroupId> {
341        self.groups
342            .iter()
343            .filter(|entry| entry.should_invalidate(change))
344            .map(|entry| entry.id)
345            .collect()
346    }
347
348    /// Get a reference to a group by ID.
349    pub fn get_group(
350        &self,
351        group_id: QueryGroupId,
352    ) -> Option<dashmap::mapref::one::Ref<'_, QueryGroupId, QueryGroup>> {
353        self.groups.get(&group_id)
354    }
355
356    /// Get a mutable reference to a group by ID.
357    pub fn get_group_mut(
358        &self,
359        group_id: QueryGroupId,
360    ) -> Option<dashmap::mapref::one::RefMut<'_, QueryGroupId, QueryGroup>> {
361        self.groups.get_mut(&group_id)
362    }
363
364    /// Get all subscriber info for a group (for fan-out).
365    pub fn get_group_subscribers(&self, group_id: QueryGroupId) -> Vec<(SessionId, String)> {
366        let subscriber_ids: Vec<SubscriberId> = self
367            .groups
368            .get(&group_id)
369            .map(|g| g.subscribers.clone())
370            .unwrap_or_default();
371
372        let store = self.subscribers.lock().unwrap_or_else(|e| {
373            tracing::error!("Subscriber store lock was poisoned, recovering");
374            e.into_inner()
375        });
376        subscriber_ids
377            .iter()
378            .filter_map(|sid| {
379                store
380                    .get(sid.0 as usize)
381                    .map(|s| (s.session_id, s.client_sub_id.clone()))
382            })
383            .collect()
384    }
385
386    /// Update a group after re-execution.
387    pub fn update_group(&self, group_id: QueryGroupId, read_set: ReadSet, result_hash: String) {
388        if let Some(mut group) = self.groups.get_mut(&group_id) {
389            group.record_execution(read_set, result_hash);
390        }
391    }
392
393    /// Get subscription counts.
394    pub fn counts(&self) -> SubscriptionCounts {
395        let total_subscribers: usize = self.groups.iter().map(|g| g.subscribers.len()).sum();
396        let groups_count = self.groups.len();
397        let sessions_count = self.session_subscribers.len();
398
399        // Estimate memory usage:
400        // - Each QueryGroup: ~256 bytes (name, args, auth, read_set, subscribers vec)
401        // - Each subscriber entry in the store: ~128 bytes
402        // - Each session mapping entry: ~64 bytes + subscriber ID vec
403        let estimated_bytes =
404            (groups_count * 256) + (total_subscribers * 128) + (sessions_count * 64);
405
406        SubscriptionCounts {
407            total: total_subscribers,
408            unique_queries: groups_count,
409            sessions: sessions_count,
410            memory_bytes: estimated_bytes,
411        }
412    }
413
414    /// Get group count.
415    pub fn group_count(&self) -> usize {
416        self.groups.len()
417    }
418}
419
/// Aggregate subscription statistics returned by `SubscriptionManager::counts`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SubscriptionCounts {
    /// Total subscribers across all query groups.
    pub total: usize,
    /// Number of distinct query groups (deduplicated by query, args, and auth scope).
    pub unique_queries: usize,
    /// Number of sessions with an entry in the session -> subscriber map.
    pub sessions: usize,
    /// Heuristic estimate of memory used by the manager, in bytes (not a measurement).
    pub memory_bytes: usize,
}
428
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use forge_core::function::AuthContext;

    /// Subscribe with empty args, an unauthenticated context, and no table/column deps.
    fn subscribe_anon(
        manager: &SubscriptionManager,
        session_id: SessionId,
        client_sub_id: &str,
        query_name: &str,
    ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> {
        let auth = AuthContext::unauthenticated();
        manager.subscribe(
            session_id,
            client_sub_id.to_string(),
            query_name,
            &serde_json::json!({}),
            &auth,
            &[],
            &[],
        )
    }

    #[test]
    fn test_session_manager_create() {
        let node_id = NodeId::new();
        let manager = SessionManager::new(node_id);

        let session = manager.create_session();
        assert!(session.is_connected());

        let retrieved = manager.get_session(session.id);
        assert!(retrieved.is_some());
    }

    #[test]
    fn test_session_manager_disconnect() {
        let node_id = NodeId::new();
        let manager = SessionManager::new(node_id);

        let session = manager.create_session();
        manager.disconnect_session(session.id);

        let retrieved = manager.get_session(session.id).unwrap();
        assert!(!retrieved.is_connected());
    }

    #[test]
    fn test_session_manager_remove() {
        let manager = SessionManager::new(NodeId::new());

        let session = manager.create_session();
        manager.remove_session(session.id);

        assert!(manager.get_session(session.id).is_none());
        assert_eq!(manager.count_by_status().total, 0);
    }

    #[test]
    fn test_subscription_manager_create() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        let (group_id, _sub_id, is_new) =
            subscribe_anon(&manager, session_id, "sub-1", "get_projects").unwrap();

        assert!(is_new);
        assert!(manager.get_group(group_id).is_some());
    }

    #[test]
    fn test_subscription_manager_coalescing() {
        let manager = SubscriptionManager::new(50);
        let session1 = SessionId::new();
        let session2 = SessionId::new();

        let (g1, _, is_new1) = subscribe_anon(&manager, session1, "s1", "get_projects").unwrap();
        let (g2, _, is_new2) = subscribe_anon(&manager, session2, "s2", "get_projects").unwrap();

        assert!(is_new1);
        assert!(!is_new2);
        assert_eq!(g1, g2);

        // Group should have 2 subscribers
        let subs = manager.get_group_subscribers(g1);
        assert_eq!(subs.len(), 2);
    }

    #[test]
    fn test_subscription_manager_limit() {
        let manager = SubscriptionManager::new(2);
        let session_id = SessionId::new();

        subscribe_anon(&manager, session_id, "s1", "q1").unwrap();
        subscribe_anon(&manager, session_id, "s2", "q2").unwrap();

        let result = subscribe_anon(&manager, session_id, "s3", "q3");
        assert!(result.is_err());
    }

    #[test]
    fn test_subscription_manager_remove_session() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        subscribe_anon(&manager, session_id, "s1", "q1").unwrap();
        subscribe_anon(&manager, session_id, "s2", "q2").unwrap();

        let counts = manager.counts();
        assert_eq!(counts.total, 2);

        let removed = manager.remove_session_subscriptions(session_id);
        assert_eq!(removed.len(), 2);

        let counts = manager.counts();
        assert_eq!(counts.total, 0);
        assert_eq!(counts.unique_queries, 0);
    }

    #[test]
    fn test_unsubscribe_removes_empty_group() {
        let manager = SubscriptionManager::new(50);
        let session_id = SessionId::new();

        let (group_id, sub_id, _) = subscribe_anon(&manager, session_id, "s1", "q1").unwrap();
        manager.unsubscribe(sub_id);

        assert!(manager.get_group(group_id).is_none());
        let counts = manager.counts();
        assert_eq!(counts.total, 0);
        assert_eq!(counts.unique_queries, 0);

        // The dedup entry is gone too, so subscribing again creates a fresh group.
        let (_, _, is_new) = subscribe_anon(&manager, session_id, "s2", "q1").unwrap();
        assert!(is_new);
    }

    #[test]
    fn test_unsubscribe_keeps_shared_group() {
        let manager = SubscriptionManager::new(50);
        let session1 = SessionId::new();
        let session2 = SessionId::new();

        let (g1, sub1, _) = subscribe_anon(&manager, session1, "s1", "q1").unwrap();
        let (g2, _, _) = subscribe_anon(&manager, session2, "s2", "q1").unwrap();
        assert_eq!(g1, g2);

        manager.unsubscribe(sub1);
        // Unknown subscription IDs are ignored.
        manager.unsubscribe(SubscriptionId::new());

        let subs = manager.get_group_subscribers(g1);
        assert_eq!(subs.len(), 1);
        assert!(subs[0].0 == session2);
        assert_eq!(subs[0].1, "s2");
    }
}