Skip to main content

contextvm_sdk/transport/server/
session_store.rs

//! Server-side session store for managing client sessions.
//!
//! Uses an LRU cache bounded by `max_sessions` (default 1000, matching the TS SDK
//! server session store).  When a new session would exceed capacity, the
//! least-recently-used session is evicted.  If the evicted session still has
//! active routes in the correlation store, it is recreated with clean state
//! (eviction safety, matching the TS SDK's `hasActiveRoutesForClient` check) and
//! no callback fires.  Otherwise the optional eviction callback fires so external
//! code can clean up resources.
9
10use std::num::NonZeroUsize;
11use std::sync::Arc;
12
13use lru::LruCache;
14use tokio::sync::RwLock;
15
16use crate::core::types::ClientSession;
17use crate::transport::server::ServerEventRouteStore;
18
/// Default upper bound on the number of client sessions held at once.
///
/// Mirrors the TS SDK `SessionStore` default (`maxSessions ?? 1000`) rather than
/// the larger `DEFAULT_LRU_SIZE` (5000) that the TS SDK uses for other caches.
pub const DEFAULT_MAX_SESSIONS: usize = 1_000;

/// Hook run when a client session is evicted from the LRU cache.
///
/// It is handed the evicted client's public key (hex) by value.
pub type EvictionCallback = Arc<dyn Fn(String) + Sync + Send>;

/// `tracing` target used for every log event emitted by this module.
const LOG_TARGET: &str = "contextvm_sdk::transport::server::session_store";
30
/// Manages client sessions keyed by public key (hex).
///
/// Backed by an LRU cache so memory usage is bounded.
///
/// Cloning is cheap and all clones share the same cache (the map lives behind
/// an `Arc`). The eviction-callback slot, however, is copied per handle: a
/// callback registered on one handle is not seen by clones made before that.
#[derive(Clone)]
pub struct SessionStore {
    // Shared LRU map from client pubkey (hex) to per-client session state.
    sessions: Arc<RwLock<LruCache<String, ClientSession>>>,
    // Optional hook handed the pubkey of an evicted client; see `handle_eviction`.
    on_evicted: Option<EvictionCallback>,
}
39
40impl Default for SessionStore {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl SessionStore {
47    /// Create a store with the default capacity ([`DEFAULT_MAX_SESSIONS`]).
48    pub fn new() -> Self {
49        Self::with_capacity(DEFAULT_MAX_SESSIONS)
50    }
51
52    /// Create a store with a specific maximum number of sessions.
53    pub fn with_capacity(max_sessions: usize) -> Self {
54        Self {
55            sessions: Arc::new(RwLock::new(LruCache::new(
56                NonZeroUsize::new(max_sessions).unwrap_or(NonZeroUsize::new(1).unwrap()),
57            ))),
58            on_evicted: None,
59        }
60    }
61
62    /// Register a callback that fires when a session is evicted from the LRU.
63    pub fn set_eviction_callback(&mut self, cb: EvictionCallback) {
64        self.on_evicted = Some(cb);
65    }
66
67    /// Clone the eviction callback (cheap Arc clone) for use outside the lock.
68    pub fn eviction_callback(&self) -> Option<EvictionCallback> {
69        self.on_evicted.clone()
70    }
71
72    /// Get an existing session or create a new one. Returns `true` if a new session was created.
73    ///
74    /// `event_routes` is consulted during eviction safety: if the evicted client
75    /// still has active routes, the session is recreated with clean state
76    /// (matching TS SDK's `hasActiveRoutesForClient` check).
77    pub async fn get_or_create_session(
78        &self,
79        client_pubkey: &str,
80        is_encrypted: bool,
81        event_routes: &ServerEventRouteStore,
82    ) -> bool {
83        let on_evicted = self.on_evicted.clone();
84        let mut sessions = self.sessions.write().await;
85        if let Some(session) = sessions.get_mut(client_pubkey) {
86            session.is_encrypted = is_encrypted;
87            false
88        } else {
89            let new_session = ClientSession::new(is_encrypted);
90            let evicted = sessions.push(client_pubkey.to_string(), new_session);
91            Self::handle_eviction(
92                client_pubkey,
93                evicted,
94                &mut sessions,
95                on_evicted.as_ref(),
96                event_routes,
97            )
98            .await;
99            true
100        }
101    }
102
103    /// Get a read-only snapshot of session fields.
104    /// Returns `None` if the session does not exist.
105    pub async fn get_session(&self, client_pubkey: &str) -> Option<SessionSnapshot> {
106        let sessions = self.sessions.read().await;
107        sessions.peek(client_pubkey).map(|s| SessionSnapshot {
108            is_initialized: s.is_initialized,
109            is_encrypted: s.is_encrypted,
110            has_sent_common_tags: s.has_sent_common_tags,
111            supports_ephemeral_gift_wrap: s.supports_ephemeral_gift_wrap,
112        })
113    }
114
115    /// Mark a session as initialized. Returns `true` if the session existed.
116    pub async fn mark_initialized(&self, client_pubkey: &str) -> bool {
117        let mut sessions = self.sessions.write().await;
118        if let Some(session) = sessions.get_mut(client_pubkey) {
119            session.is_initialized = true;
120            true
121        } else {
122            false
123        }
124    }
125
126    /// Mark that common tags have been sent for this session.
127    pub async fn mark_common_tags_sent(&self, client_pubkey: &str) -> bool {
128        let mut sessions = self.sessions.write().await;
129        if let Some(session) = sessions.get_mut(client_pubkey) {
130            session.has_sent_common_tags = true;
131            true
132        } else {
133            false
134        }
135    }
136
137    /// Remove a session. Returns `true` if it existed.
138    pub async fn remove_session(&self, client_pubkey: &str) -> bool {
139        self.sessions.write().await.pop(client_pubkey).is_some()
140    }
141
142    /// Remove all sessions.
143    pub async fn clear(&self) {
144        self.sessions.write().await.clear();
145    }
146
147    /// Number of active sessions.
148    pub async fn session_count(&self) -> usize {
149        self.sessions.read().await.len()
150    }
151
152    /// Return a snapshot of all sessions as `(client_pubkey, snapshot)` pairs.
153    pub async fn get_all_sessions(&self) -> Vec<(String, SessionSnapshot)> {
154        let sessions = self.sessions.read().await;
155        sessions
156            .iter()
157            .map(|(k, s)| {
158                (
159                    k.clone(),
160                    SessionSnapshot {
161                        is_initialized: s.is_initialized,
162                        is_encrypted: s.is_encrypted,
163                        has_sent_common_tags: s.has_sent_common_tags,
164                        supports_ephemeral_gift_wrap: s.supports_ephemeral_gift_wrap,
165                    },
166                )
167            })
168            .collect()
169    }
170
171    /// Acquire write access to the underlying LRU cache (transport internals only).
172    pub(crate) async fn write(
173        &self,
174    ) -> tokio::sync::RwLockWriteGuard<'_, LruCache<String, ClientSession>> {
175        self.sessions.write().await
176    }
177
178    /// Acquire read access to the underlying LRU cache (transport internals only).
179    pub(crate) async fn read(
180        &self,
181    ) -> tokio::sync::RwLockReadGuard<'_, LruCache<String, ClientSession>> {
182        self.sessions.read().await
183    }
184
185    /// Handle a potential LRU eviction after inserting a session.
186    ///
187    /// If the evicted client still has active routes in the correlation store,
188    /// a clean session is re-inserted (eviction safety, matching TS SDK's
189    /// `hasActiveRoutesForClient` check).  The eviction callback fires only
190    /// for genuine, non-vetoed evictions.
191    pub(crate) async fn handle_eviction(
192        inserted_key: &str,
193        evicted: Option<(String, ClientSession)>,
194        sessions: &mut LruCache<String, ClientSession>,
195        on_evicted: Option<&EvictionCallback>,
196        event_routes: &ServerEventRouteStore,
197    ) {
198        if let Some((evicted_key, evicted_session)) = evicted {
199            // `push` also returns the old value when the *same* key is updated;
200            // only act when a *different* key was evicted due to capacity.
201            if evicted_key != inserted_key {
202                if event_routes
203                    .has_active_routes_for_client(&evicted_key)
204                    .await
205                {
206                    tracing::warn!(
207                        target: LOG_TARGET,
208                        client_pubkey = %evicted_key,
209                        "LRU eviction of session with active routes; recreating with clean state"
210                    );
211                    // Re-insert with clean state so the client isn't orphaned.
212                    // Skip the external callback — the session still exists
213                    // (matches TS SDK: vetoed evictions don't fire the callback).
214                    let _ = sessions.push(
215                        evicted_key.clone(),
216                        ClientSession::new(evicted_session.is_encrypted),
217                    );
218                } else if let Some(cb) = on_evicted {
219                    cb(evicted_key);
220                }
221            }
222        }
223    }
224}
225
/// Read-only copy of the per-client session flags.
///
/// Returned by the async store API so callers never hold a reference into the
/// locked cache and never see the full `ClientSession`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SessionSnapshot {
    /// Whether the session has been marked initialized.
    pub is_initialized: bool,
    /// Whether the session is flagged as encrypted.
    pub is_encrypted: bool,
    /// Whether common tags have already been sent to this client.
    pub has_sent_common_tags: bool,
    /// Copied from the session's `supports_ephemeral_gift_wrap` flag.
    pub supports_ephemeral_gift_wrap: bool,
}
235
// Unit tests for `SessionStore`: basic CRUD, per-session capability flags,
// and LRU eviction (including the active-routes eviction-safety veto).
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Fresh correlation (event-route) store with no registered routes, so no
    // eviction is vetoed unless a test registers a route explicitly.
    fn routes() -> ServerEventRouteStore {
        ServerEventRouteStore::new()
    }

    // A newly created session reports the requested encryption and starts uninitialized.
    #[tokio::test]
    async fn create_and_retrieve_session() {
        let store = SessionStore::new();
        let r = routes();

        let created = store.get_or_create_session("client-1", true, &r).await;
        assert!(created);

        let snap = store.get_session("client-1").await.unwrap();
        assert!(snap.is_encrypted);
        assert!(!snap.is_initialized);
    }

    // A second call for the same key returns `false` but still updates `is_encrypted`.
    #[tokio::test]
    async fn get_or_create_returns_existing() {
        let store = SessionStore::new();
        let r = routes();

        let created = store.get_or_create_session("client-1", false, &r).await;
        assert!(created);

        let created2 = store.get_or_create_session("client-1", true, &r).await;
        assert!(!created2);

        let snap = store.get_session("client-1").await.unwrap();
        assert!(snap.is_encrypted);
    }

    #[tokio::test]
    async fn mark_initialized() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        assert!(store.mark_initialized("client-1").await);
        let snap = store.get_session("client-1").await.unwrap();
        assert!(snap.is_initialized);
    }

    #[tokio::test]
    async fn mark_initialized_unknown_returns_false() {
        let store = SessionStore::new();
        assert!(!store.mark_initialized("unknown").await);
    }

    #[tokio::test]
    async fn remove_session() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;
        assert!(store.remove_session("client-1").await);
        assert!(store.get_session("client-1").await.is_none());
    }

    #[tokio::test]
    async fn remove_unknown_returns_false() {
        let store = SessionStore::new();
        assert!(!store.remove_session("unknown").await);
    }

    #[tokio::test]
    async fn clear_all_sessions() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;
        store.get_or_create_session("client-2", true, &r).await;

        store.clear().await;

        assert_eq!(store.session_count().await, 0);
        assert!(store.get_session("client-1").await.is_none());
        assert!(store.get_session("client-2").await.is_none());
    }

    // Order of `get_all_sessions` is not asserted, only membership.
    #[tokio::test]
    async fn get_all_sessions() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;
        store.get_or_create_session("client-2", true, &r).await;

        let all = store.get_all_sessions().await;
        assert_eq!(all.len(), 2);

        let keys: Vec<&str> = all.iter().map(|(k, _)| k.as_str()).collect();
        assert!(keys.contains(&"client-1"));
        assert!(keys.contains(&"client-2"));
    }

    // ── CEP-35 capability fields ────────────────────────────────
    // These tests reach into `ClientSession` through the crate-internal
    // `read()`/`write()` guards, since the snapshot does not expose every flag.

    #[tokio::test]
    async fn new_session_capability_fields_default_false() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        let sessions = store.read().await;
        let session = sessions.peek("client-1").unwrap();
        assert!(!session.has_sent_common_tags);
        assert!(!session.supports_encryption);
        assert!(!session.supports_ephemeral_encryption);
        assert!(!session.supports_oversized_transfer);
    }

    #[tokio::test]
    async fn has_sent_common_tags_flag() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        let mut sessions = store.write().await;
        let session = sessions.get_mut("client-1").unwrap();
        assert!(!session.has_sent_common_tags);
        session.has_sent_common_tags = true;
        assert!(session.has_sent_common_tags);
    }

    // Capabilities are accumulated with `|=`, so a later `false` must not clear an earlier `true`.
    #[tokio::test]
    async fn capability_or_assign_persists() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        {
            let mut sessions = store.write().await;
            let session = sessions.get_mut("client-1").unwrap();
            session.supports_encryption |= true;
            session.supports_ephemeral_encryption |= false;
        }

        {
            let mut sessions = store.write().await;
            let session = sessions.get_mut("client-1").unwrap();
            session.supports_encryption |= false;
            session.supports_ephemeral_encryption |= true;
        }

        let sessions = store.read().await;
        let session = sessions.peek("client-1").unwrap();
        assert!(session.supports_encryption, "OR-assign must not downgrade");
        assert!(session.supports_ephemeral_encryption);
        assert!(!session.supports_oversized_transfer);
    }

    #[tokio::test]
    async fn capability_fields_independent_per_client() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-a", false, &r).await;
        store.get_or_create_session("client-b", false, &r).await;

        {
            let mut sessions = store.write().await;
            let sa = sessions.get_mut("client-a").unwrap();
            sa.supports_encryption = true;
            sa.has_sent_common_tags = true;
        }

        let sessions = store.read().await;
        let sa = sessions.peek("client-a").unwrap();
        let sb = sessions.peek("client-b").unwrap();
        assert!(sa.supports_encryption);
        assert!(sa.has_sent_common_tags);
        assert!(!sb.supports_encryption);
        assert!(!sb.has_sent_common_tags);
    }

    // Re-calling get_or_create on an existing key must not reset capability flags.
    #[tokio::test]
    async fn get_or_create_preserves_capability_fields() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        {
            let mut sessions = store.write().await;
            let session = sessions.get_mut("client-1").unwrap();
            session.supports_encryption = true;
            session.has_sent_common_tags = true;
        }

        let created = store.get_or_create_session("client-1", true, &r).await;
        assert!(!created);

        let sessions = store.read().await;
        let session = sessions.peek("client-1").unwrap();
        assert!(session.supports_encryption);
        assert!(session.has_sent_common_tags);
    }

    #[tokio::test]
    async fn clear_resets_capability_fields() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;
        {
            let mut sessions = store.write().await;
            let s = sessions.get_mut("client-1").unwrap();
            s.supports_encryption = true;
        }

        store.clear().await;
        store.get_or_create_session("client-1", false, &r).await;

        let sessions = store.read().await;
        let session = sessions.peek("client-1").unwrap();
        assert!(!session.supports_encryption);
        assert!(!session.has_sent_common_tags);
    }

    // ── LRU eviction ────────────────────────────────────────────

    #[tokio::test]
    async fn lru_eviction_drops_oldest_session() {
        let store = SessionStore::with_capacity(3);
        let r = routes();
        store.get_or_create_session("a", false, &r).await;
        store.get_or_create_session("b", false, &r).await;
        store.get_or_create_session("c", false, &r).await;

        // Cache is full; inserting "d" pushes out the least-recently-used "a".
        store.get_or_create_session("d", false, &r).await;

        assert!(
            store.get_session("a").await.is_none(),
            "a should be evicted"
        );
        assert!(store.get_session("b").await.is_some());
        assert!(store.get_session("c").await.is_some());
        assert!(store.get_session("d").await.is_some());
        assert_eq!(store.session_count().await, 3);
    }

    #[tokio::test]
    async fn eviction_callback_fires_on_lru_eviction() {
        // Shared sink so the callback (which must be Send + Sync) can record pubkeys.
        let evicted = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
        let evicted_clone = evicted.clone();
        let r = routes();

        let mut store = SessionStore::with_capacity(2);
        store.set_eviction_callback(Arc::new(move |pubkey| {
            evicted_clone.lock().unwrap().push(pubkey);
        }));

        store.get_or_create_session("a", false, &r).await;
        store.get_or_create_session("b", false, &r).await;
        store.get_or_create_session("c", false, &r).await;

        let evicted = evicted.lock().unwrap();
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0], "a");
    }

    #[tokio::test]
    async fn eviction_safety_recreates_session_with_active_routes() {
        let store = SessionStore::with_capacity(2);
        let r = routes();
        store.get_or_create_session("a", true, &r).await;
        store.get_or_create_session("b", false, &r).await;

        // Register an active route for client "a" in the correlation store
        r.register("evt1".into(), "a".into(), json!(1), None).await;

        // Adding "c" would normally evict "a", but eviction safety recreates it
        // because "a" has active routes.
        store.get_or_create_session("c", false, &r).await;

        let snap = store.get_session("a").await;
        assert!(
            snap.is_some(),
            "session with active routes must survive eviction"
        );
        // "b" was evicted instead (next LRU after "a" was re-inserted)
        assert!(
            store.get_session("b").await.is_none(),
            "b should be evicted"
        );
    }

    #[tokio::test]
    async fn with_capacity_sets_limit() {
        let store = SessionStore::with_capacity(5);
        let r = routes();
        for i in 0..10 {
            store
                .get_or_create_session(&format!("client-{i}"), false, &r)
                .await;
        }
        assert_eq!(store.session_count().await, 5);
    }
}