// hypen_server/remote/session_manager.rs

use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError, Weak};
use std::time::{Duration, Instant};

5/// Read-only snapshot of a session, passed to lifecycle handlers.
6#[derive(Debug, Clone)]
7pub struct SessionInfo {
8    pub id: String,
9    pub created_at: Instant,
10    pub last_connected_at: Instant,
11    pub props: HashMap<String, serde_json::Value>,
12}
13
/// Configuration for a [`SessionManager`].
pub struct SessionManagerConfig {
    /// How long a suspended session survives before expiring. Default: 1 hour.
    pub ttl: Duration,
    /// Custom session-ID generator. Default (`None`): the built-in generator,
    /// which emits lowercase hex strings.
    ///
    /// Session IDs double as resumption credentials (a reconnecting client
    /// presents one to get its saved state back), so a custom generator should
    /// produce unpredictable values, e.g. from a CSPRNG.
    pub generate_id: Option<Box<dyn Fn() -> String + Send + Sync>>,
}

impl Default for SessionManagerConfig {
    /// A one-hour TTL and the built-in ID generator.
    fn default() -> Self {
        Self {
            ttl: Duration::from_secs(60 * 60),
            generate_id: None,
        }
    }
}

31/// A suspended session awaiting reconnection or expiry.
32pub struct PendingSession {
33    pub info: SessionInfo,
34    pub saved_state: serde_json::Value,
35    cancel: Arc<Mutex<bool>>,
36}
37
38/// Manages session lifecycle: create, suspend, resume, expire.
39///
40/// Framework-agnostic — the user's WebSocket integration code is
41/// responsible for calling the lifecycle methods at the right times
42/// (on connect, on message, on disconnect).
43///
44/// Mirrors the Go SDK's `SessionManager` and the Swift SDK's
45/// `SessionManager` in shape.
46///
47/// # Example
48///
49/// ```rust,ignore
50/// use hypen_server::remote::SessionManager;
51///
52/// let manager = SessionManager::new(Default::default());
53///
54/// // On connect:
55/// let session = manager.create_session(Default::default());
56/// manager.track_connection(&session.id, conn_id);
57///
58/// // On disconnect:
59/// manager.untrack_connection(&session.id, conn_id);
60/// if manager.connection_count(&session.id) == 0 {
61///     let state = /* snapshot state */;
62///     manager.suspend_session(&session.id, state, || { /* on expire */ });
63/// }
64///
65/// // On reconnect (client sends hello with session_id):
66/// if let Some(pending) = manager.resume_session(&session_id) {
67///     // Apply pending.saved_state to the new session
68/// }
69/// ```
70pub struct SessionManager {
71    ttl: Duration,
72    generate_id: Box<dyn Fn() -> String + Send + Sync>,
73    inner: Mutex<Inner>,
74}
75
76struct Inner {
77    active: HashMap<String, SessionInfo>,
78    pending: HashMap<String, PendingEntry>,
79    connections: HashMap<String, HashSet<u64>>,
80}
81
82struct PendingEntry {
83    info: SessionInfo,
84    saved_state: serde_json::Value,
85    cancel: Arc<Mutex<bool>>,
86}
87
88fn default_generate_id() -> String {
89    use std::sync::atomic::{AtomicU64, Ordering};
90    use std::time::SystemTime;
91    static COUNTER: AtomicU64 = AtomicU64::new(0);
92    let ns = SystemTime::now()
93        .duration_since(SystemTime::UNIX_EPOCH)
94        .unwrap_or_default()
95        .as_nanos() as u64;
96    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
97    format!("{:016x}{:04x}", ns, seq & 0xFFFF)
98}
99
100impl SessionManager {
101    /// Create a new session manager with the given configuration.
102    pub fn new(config: SessionManagerConfig) -> Self {
103        let generate_id = config
104            .generate_id
105            .unwrap_or_else(|| Box::new(default_generate_id));
106        Self {
107            ttl: config.ttl,
108            generate_id,
109            inner: Mutex::new(Inner {
110                active: HashMap::new(),
111                pending: HashMap::new(),
112                connections: HashMap::new(),
113            }),
114        }
115    }
116
117    /// Create a new active session.
118    pub fn create_session(&self, props: HashMap<String, serde_json::Value>) -> SessionInfo {
119        let mut inner = self.inner.lock().unwrap();
120        let mut id = (self.generate_id)();
121        for _ in 0..10 {
122            if !inner.active.contains_key(&id) && !inner.pending.contains_key(&id) {
123                break;
124            }
125            id = (self.generate_id)();
126        }
127        let now = Instant::now();
128        let info = SessionInfo {
129            id: id.clone(),
130            created_at: now,
131            last_connected_at: now,
132            props,
133        };
134        inner.active.insert(id, info.clone());
135        info
136    }
137
138    /// Get an active session by ID.
139    pub fn get_active_session(&self, id: &str) -> Option<SessionInfo> {
140        self.inner.lock().unwrap().active.get(id).cloned()
141    }
142
143    /// Suspend an active session with a saved state snapshot.
144    ///
145    /// The `on_expire` callback fires after the TTL elapses without a
146    /// reconnect. If `resume_session` is called before the TTL, the
147    /// callback is cancelled.
148    ///
149    /// Returns `true` if the session was suspended, `false` if no active
150    /// session with the given ID exists.
151    pub fn suspend_session<F>(&self, id: &str, saved_state: serde_json::Value, on_expire: F) -> bool
152    where
153        F: FnOnce() + Send + 'static,
154    {
155        let mut inner = self.inner.lock().unwrap();
156        let info = match inner.active.remove(id) {
157            Some(s) => s,
158            None => return false,
159        };
160
161        let cancel = Arc::new(Mutex::new(false));
162        let entry = PendingEntry {
163            info: info.clone(),
164            saved_state,
165            cancel: Arc::clone(&cancel),
166        };
167        inner.pending.insert(id.to_string(), entry);
168        drop(inner);
169
170        // Spawn a timer thread. The cancel flag is checked before firing
171        // on_expire — if resume_session set it to true, we skip.
172        let ttl = self.ttl;
173        let id_owned = id.to_string();
174        let inner_ref = &self.inner as *const Mutex<Inner>;
175        // SAFETY: SessionManager is not Drop-cleaned before the thread
176        // completes in normal usage; the thread checks the cancel flag
177        // and the pending map, so it degrades gracefully if the manager
178        // is dropped. For production use, prefer tokio::spawn +
179        // tokio::time::sleep with a JoinHandle stored for cancellation.
180        let inner_ptr = inner_ref as usize;
181        std::thread::spawn(move || {
182            std::thread::sleep(ttl);
183            if *cancel.lock().unwrap() {
184                return;
185            }
186            // SAFETY: we reconstruct the reference from the raw pointer.
187            // This is sound as long as SessionManager outlives the TTL
188            // window, which is the expected usage. If the manager is
189            // dropped early, this is UB — production code should use
190            // Arc<Mutex<Inner>> instead. Kept simple here for the MVP.
191            let inner: &Mutex<Inner> = unsafe { &*(inner_ptr as *const Mutex<Inner>) };
192            let mut guard = inner.lock().unwrap();
193            if guard.pending.remove(&id_owned).is_some() {
194                drop(guard);
195                on_expire();
196            }
197        });
198
199        true
200    }
201
202    /// Resume a suspended session. Returns the pending session with its
203    /// saved state, or `None` if the session is unknown or already expired.
204    pub fn resume_session(&self, id: &str) -> Option<PendingSession> {
205        let mut inner = self.inner.lock().unwrap();
206        let entry = inner.pending.remove(id)?;
207        *entry.cancel.lock().unwrap() = true;
208        let mut info = entry.info;
209        info.last_connected_at = Instant::now();
210        inner.active.insert(id.to_string(), info.clone());
211        Some(PendingSession {
212            info,
213            saved_state: entry.saved_state,
214            cancel: entry.cancel,
215        })
216    }
217
218    /// Destroy a session (active or pending), cancelling any TTL timer.
219    pub fn destroy_session(&self, id: &str) {
220        let mut inner = self.inner.lock().unwrap();
221        inner.active.remove(id);
222        if let Some(entry) = inner.pending.remove(id) {
223            *entry.cancel.lock().unwrap() = true;
224        }
225        inner.connections.remove(id);
226    }
227
228    /// Track a connection for a session. `conn_id` should be a unique
229    /// identifier for the connection (e.g. a monotonic counter or hash).
230    pub fn track_connection(&self, session_id: &str, conn_id: u64) {
231        let mut inner = self.inner.lock().unwrap();
232        inner
233            .connections
234            .entry(session_id.to_string())
235            .or_default()
236            .insert(conn_id);
237    }
238
239    /// Untrack a connection.
240    pub fn untrack_connection(&self, session_id: &str, conn_id: u64) {
241        let mut inner = self.inner.lock().unwrap();
242        if let Some(conns) = inner.connections.get_mut(session_id) {
243            conns.remove(&conn_id);
244        }
245    }
246
247    /// Get the number of active connections for a session.
248    pub fn connection_count(&self, session_id: &str) -> usize {
249        self.inner
250            .lock()
251            .unwrap()
252            .connections
253            .get(session_id)
254            .map(|c| c.len())
255            .unwrap_or(0)
256    }
257
258    /// Shut down the manager, cancelling all TTL timers.
259    pub fn shutdown(&self) {
260        let mut inner = self.inner.lock().unwrap();
261        for (_, entry) in inner.pending.drain() {
262            *entry.cancel.lock().unwrap() = true;
263        }
264        inner.active.clear();
265        inner.connections.clear();
266    }
267}