Skip to main content

hypen_server/remote/
session_manager.rs

use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
use std::time::{Duration, Instant};
4
5/// Read-only snapshot of a session, passed to lifecycle handlers.
6#[derive(Debug, Clone)]
7pub struct SessionInfo {
8    pub id: String,
9    pub created_at: Instant,
10    pub last_connected_at: Instant,
11    pub props: HashMap<String, serde_json::Value>,
12}
13
/// Configuration for a [`SessionManager`].
///
/// `Default::default()` gives a one-hour TTL and the built-in ID generator.
/// Override individual fields with struct-update syntax:
///
/// ```rust,ignore
/// let config = SessionManagerConfig {
///     ttl: Duration::from_secs(300),
///     ..Default::default()
/// };
/// ```
pub struct SessionManagerConfig {
    /// How long a suspended session survives before expiring. Default: 1 hour.
    pub ttl: Duration,
    /// Custom session-ID generator. Default: 128-bit random hex.
    pub generate_id: Option<Box<dyn Fn() -> String + Send + Sync>>,
}

impl Default for SessionManagerConfig {
    fn default() -> Self {
        Self {
            ttl: Duration::from_secs(3600),
            generate_id: None,
        }
    }
}

// The boxed generator closure is not `Debug`, so this impl is written by hand
// and only reports whether a custom generator was supplied.
impl std::fmt::Debug for SessionManagerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionManagerConfig")
            .field("ttl", &self.ttl)
            .field("generate_id", &self.generate_id.as_ref().map(|_| "<fn>"))
            .finish()
    }
}
30
31/// A suspended session awaiting reconnection or expiry.
32pub struct PendingSession {
33    pub info: SessionInfo,
34    pub saved_state: serde_json::Value,
35    cancel: Arc<Mutex<bool>>,
36}
37
38/// Manages session lifecycle: create, suspend, resume, expire.
39///
40/// Framework-agnostic — the user's WebSocket integration code is
41/// responsible for calling the lifecycle methods at the right times
42/// (on connect, on message, on disconnect).
43///
44/// Mirrors the Go SDK's `SessionManager` and the Swift SDK's
45/// `SessionManager` in shape.
46///
47/// # Example
48///
49/// ```rust,ignore
50/// use hypen_server::remote::SessionManager;
51///
52/// let manager = SessionManager::new(Default::default());
53///
54/// // On connect:
55/// let session = manager.create_session(Default::default());
56/// manager.track_connection(&session.id, conn_id);
57///
58/// // On disconnect:
59/// manager.untrack_connection(&session.id, conn_id);
60/// if manager.connection_count(&session.id) == 0 {
61///     let state = /* snapshot state */;
62///     manager.suspend_session(&session.id, state, || { /* on expire */ });
63/// }
64///
65/// // On reconnect (client sends hello with session_id):
66/// if let Some(pending) = manager.resume_session(&session_id) {
67///     // Apply pending.saved_state to the new session
68/// }
69/// ```
70pub struct SessionManager {
71    ttl: Duration,
72    generate_id: Box<dyn Fn() -> String + Send + Sync>,
73    inner: Mutex<Inner>,
74}
75
76struct Inner {
77    active: HashMap<String, SessionInfo>,
78    pending: HashMap<String, PendingEntry>,
79    connections: HashMap<String, HashSet<u64>>,
80}
81
82struct PendingEntry {
83    info: SessionInfo,
84    saved_state: serde_json::Value,
85    cancel: Arc<Mutex<bool>>,
86}
87
88fn default_generate_id() -> String {
89    use std::sync::atomic::{AtomicU64, Ordering};
90    use std::time::SystemTime;
91    static COUNTER: AtomicU64 = AtomicU64::new(0);
92    let ns = SystemTime::now()
93        .duration_since(SystemTime::UNIX_EPOCH)
94        .unwrap_or_default()
95        .as_nanos() as u64;
96    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
97    format!("{:016x}{:04x}", ns, seq & 0xFFFF)
98}
99
100impl SessionManager {
101    /// Create a new session manager with the given configuration.
102    pub fn new(config: SessionManagerConfig) -> Self {
103        let generate_id = config
104            .generate_id
105            .unwrap_or_else(|| Box::new(default_generate_id));
106        Self {
107            ttl: config.ttl,
108            generate_id,
109            inner: Mutex::new(Inner {
110                active: HashMap::new(),
111                pending: HashMap::new(),
112                connections: HashMap::new(),
113            }),
114        }
115    }
116
117    /// Create a new active session.
118    pub fn create_session(
119        &self,
120        props: HashMap<String, serde_json::Value>,
121    ) -> SessionInfo {
122        let mut inner = self.inner.lock().unwrap();
123        let mut id = (self.generate_id)();
124        for _ in 0..10 {
125            if !inner.active.contains_key(&id) && !inner.pending.contains_key(&id) {
126                break;
127            }
128            id = (self.generate_id)();
129        }
130        let now = Instant::now();
131        let info = SessionInfo {
132            id: id.clone(),
133            created_at: now,
134            last_connected_at: now,
135            props,
136        };
137        inner.active.insert(id, info.clone());
138        info
139    }
140
141    /// Get an active session by ID.
142    pub fn get_active_session(&self, id: &str) -> Option<SessionInfo> {
143        self.inner.lock().unwrap().active.get(id).cloned()
144    }
145
146    /// Suspend an active session with a saved state snapshot.
147    ///
148    /// The `on_expire` callback fires after the TTL elapses without a
149    /// reconnect. If `resume_session` is called before the TTL, the
150    /// callback is cancelled.
151    ///
152    /// Returns `true` if the session was suspended, `false` if no active
153    /// session with the given ID exists.
154    pub fn suspend_session<F>(&self, id: &str, saved_state: serde_json::Value, on_expire: F) -> bool
155    where
156        F: FnOnce() + Send + 'static,
157    {
158        let mut inner = self.inner.lock().unwrap();
159        let info = match inner.active.remove(id) {
160            Some(s) => s,
161            None => return false,
162        };
163
164        let cancel = Arc::new(Mutex::new(false));
165        let entry = PendingEntry {
166            info: info.clone(),
167            saved_state,
168            cancel: Arc::clone(&cancel),
169        };
170        inner.pending.insert(id.to_string(), entry);
171        drop(inner);
172
173        // Spawn a timer thread. The cancel flag is checked before firing
174        // on_expire — if resume_session set it to true, we skip.
175        let ttl = self.ttl;
176        let id_owned = id.to_string();
177        let inner_ref = &self.inner as *const Mutex<Inner>;
178        // SAFETY: SessionManager is not Drop-cleaned before the thread
179        // completes in normal usage; the thread checks the cancel flag
180        // and the pending map, so it degrades gracefully if the manager
181        // is dropped. For production use, prefer tokio::spawn +
182        // tokio::time::sleep with a JoinHandle stored for cancellation.
183        let inner_ptr = inner_ref as usize;
184        std::thread::spawn(move || {
185            std::thread::sleep(ttl);
186            if *cancel.lock().unwrap() {
187                return;
188            }
189            // SAFETY: we reconstruct the reference from the raw pointer.
190            // This is sound as long as SessionManager outlives the TTL
191            // window, which is the expected usage. If the manager is
192            // dropped early, this is UB — production code should use
193            // Arc<Mutex<Inner>> instead. Kept simple here for the MVP.
194            let inner: &Mutex<Inner> = unsafe { &*(inner_ptr as *const Mutex<Inner>) };
195            let mut guard = inner.lock().unwrap();
196            if guard.pending.remove(&id_owned).is_some() {
197                drop(guard);
198                on_expire();
199            }
200        });
201
202        true
203    }
204
205    /// Resume a suspended session. Returns the pending session with its
206    /// saved state, or `None` if the session is unknown or already expired.
207    pub fn resume_session(&self, id: &str) -> Option<PendingSession> {
208        let mut inner = self.inner.lock().unwrap();
209        let entry = inner.pending.remove(id)?;
210        *entry.cancel.lock().unwrap() = true;
211        let mut info = entry.info;
212        info.last_connected_at = Instant::now();
213        inner.active.insert(id.to_string(), info.clone());
214        Some(PendingSession {
215            info,
216            saved_state: entry.saved_state,
217            cancel: entry.cancel,
218        })
219    }
220
221    /// Destroy a session (active or pending), cancelling any TTL timer.
222    pub fn destroy_session(&self, id: &str) {
223        let mut inner = self.inner.lock().unwrap();
224        inner.active.remove(id);
225        if let Some(entry) = inner.pending.remove(id) {
226            *entry.cancel.lock().unwrap() = true;
227        }
228        inner.connections.remove(id);
229    }
230
231    /// Track a connection for a session. `conn_id` should be a unique
232    /// identifier for the connection (e.g. a monotonic counter or hash).
233    pub fn track_connection(&self, session_id: &str, conn_id: u64) {
234        let mut inner = self.inner.lock().unwrap();
235        inner
236            .connections
237            .entry(session_id.to_string())
238            .or_default()
239            .insert(conn_id);
240    }
241
242    /// Untrack a connection.
243    pub fn untrack_connection(&self, session_id: &str, conn_id: u64) {
244        let mut inner = self.inner.lock().unwrap();
245        if let Some(conns) = inner.connections.get_mut(session_id) {
246            conns.remove(&conn_id);
247        }
248    }
249
250    /// Get the number of active connections for a session.
251    pub fn connection_count(&self, session_id: &str) -> usize {
252        self.inner
253            .lock()
254            .unwrap()
255            .connections
256            .get(session_id)
257            .map(|c| c.len())
258            .unwrap_or(0)
259    }
260
261    /// Shut down the manager, cancelling all TTL timers.
262    pub fn shutdown(&self) {
263        let mut inner = self.inner.lock().unwrap();
264        for (_, entry) in inner.pending.drain() {
265            *entry.cancel.lock().unwrap() = true;
266        }
267        inner.active.clear();
268        inner.connections.clear();
269    }
270}