contextvm_sdk/transport/server/
session_store.rs1use std::num::NonZeroUsize;
11use std::sync::Arc;
12
13use lru::LruCache;
14use tokio::sync::RwLock;
15
16use crate::core::types::ClientSession;
17use crate::transport::server::ServerEventRouteStore;
18
/// `tracing` target attached to this module's log events.
const LOG_TARGET: &str = "contextvm_sdk::transport::server::session_store";

/// Session capacity used by `SessionStore::new` (and therefore `SessionStore::default`).
pub const DEFAULT_MAX_SESSIONS: usize = 1_000;

/// Hook that receives the pubkey of a client whose session was evicted.
///
/// Reference-counted so it can be cloned cheaply, and `Send + Sync` so it can be
/// shared across tasks.
pub type EvictionCallback = Arc<dyn Fn(String) + Sync + Send>;
30
31#[derive(Clone)]
35pub struct SessionStore {
36 sessions: Arc<RwLock<LruCache<String, ClientSession>>>,
37 on_evicted: Option<EvictionCallback>,
38}
39
40impl Default for SessionStore {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl SessionStore {
47 pub fn new() -> Self {
49 Self::with_capacity(DEFAULT_MAX_SESSIONS)
50 }
51
52 pub fn with_capacity(max_sessions: usize) -> Self {
54 Self {
55 sessions: Arc::new(RwLock::new(LruCache::new(
56 NonZeroUsize::new(max_sessions).unwrap_or(NonZeroUsize::new(1).unwrap()),
57 ))),
58 on_evicted: None,
59 }
60 }
61
62 pub fn set_eviction_callback(&mut self, cb: EvictionCallback) {
64 self.on_evicted = Some(cb);
65 }
66
67 pub fn eviction_callback(&self) -> Option<EvictionCallback> {
69 self.on_evicted.clone()
70 }
71
72 pub async fn get_or_create_session(
78 &self,
79 client_pubkey: &str,
80 is_encrypted: bool,
81 event_routes: &ServerEventRouteStore,
82 ) -> bool {
83 let on_evicted = self.on_evicted.clone();
84 let mut sessions = self.sessions.write().await;
85 if let Some(session) = sessions.get_mut(client_pubkey) {
86 session.is_encrypted = is_encrypted;
87 false
88 } else {
89 let new_session = ClientSession::new(is_encrypted);
90 let evicted = sessions.push(client_pubkey.to_string(), new_session);
91 Self::handle_eviction(
92 client_pubkey,
93 evicted,
94 &mut sessions,
95 on_evicted.as_ref(),
96 event_routes,
97 )
98 .await;
99 true
100 }
101 }
102
103 pub async fn get_session(&self, client_pubkey: &str) -> Option<SessionSnapshot> {
106 let sessions = self.sessions.read().await;
107 sessions.peek(client_pubkey).map(|s| SessionSnapshot {
108 is_initialized: s.is_initialized,
109 is_encrypted: s.is_encrypted,
110 has_sent_common_tags: s.has_sent_common_tags,
111 supports_ephemeral_gift_wrap: s.supports_ephemeral_gift_wrap,
112 })
113 }
114
115 pub async fn mark_initialized(&self, client_pubkey: &str) -> bool {
117 let mut sessions = self.sessions.write().await;
118 if let Some(session) = sessions.get_mut(client_pubkey) {
119 session.is_initialized = true;
120 true
121 } else {
122 false
123 }
124 }
125
126 pub async fn mark_common_tags_sent(&self, client_pubkey: &str) -> bool {
128 let mut sessions = self.sessions.write().await;
129 if let Some(session) = sessions.get_mut(client_pubkey) {
130 session.has_sent_common_tags = true;
131 true
132 } else {
133 false
134 }
135 }
136
137 pub async fn remove_session(&self, client_pubkey: &str) -> bool {
139 self.sessions.write().await.pop(client_pubkey).is_some()
140 }
141
142 pub async fn clear(&self) {
144 self.sessions.write().await.clear();
145 }
146
147 pub async fn session_count(&self) -> usize {
149 self.sessions.read().await.len()
150 }
151
152 pub async fn get_all_sessions(&self) -> Vec<(String, SessionSnapshot)> {
154 let sessions = self.sessions.read().await;
155 sessions
156 .iter()
157 .map(|(k, s)| {
158 (
159 k.clone(),
160 SessionSnapshot {
161 is_initialized: s.is_initialized,
162 is_encrypted: s.is_encrypted,
163 has_sent_common_tags: s.has_sent_common_tags,
164 supports_ephemeral_gift_wrap: s.supports_ephemeral_gift_wrap,
165 },
166 )
167 })
168 .collect()
169 }
170
171 pub(crate) async fn write(
173 &self,
174 ) -> tokio::sync::RwLockWriteGuard<'_, LruCache<String, ClientSession>> {
175 self.sessions.write().await
176 }
177
178 pub(crate) async fn read(
180 &self,
181 ) -> tokio::sync::RwLockReadGuard<'_, LruCache<String, ClientSession>> {
182 self.sessions.read().await
183 }
184
185 pub(crate) async fn handle_eviction(
192 inserted_key: &str,
193 evicted: Option<(String, ClientSession)>,
194 sessions: &mut LruCache<String, ClientSession>,
195 on_evicted: Option<&EvictionCallback>,
196 event_routes: &ServerEventRouteStore,
197 ) {
198 if let Some((evicted_key, evicted_session)) = evicted {
199 if evicted_key != inserted_key {
202 if event_routes
203 .has_active_routes_for_client(&evicted_key)
204 .await
205 {
206 tracing::warn!(
207 target: LOG_TARGET,
208 client_pubkey = %evicted_key,
209 "LRU eviction of session with active routes; recreating with clean state"
210 );
211 let _ = sessions.push(
215 evicted_key.clone(),
216 ClientSession::new(evicted_session.is_encrypted),
217 );
218 } else if let Some(cb) = on_evicted {
219 cb(evicted_key);
220 }
221 }
222 }
223 }
224}
225
/// Point-in-time, detached copy of a client session's flags.
///
/// Returned by `SessionStore::get_session` and `SessionStore::get_all_sessions`.
/// Mutating a snapshot does not affect the stored session. `Default` yields the
/// all-`false` state, and `Hash` allows snapshots to be used in sets and map keys.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct SessionSnapshot {
    /// Set by `SessionStore::mark_initialized`.
    pub is_initialized: bool,
    /// Last `is_encrypted` value passed to `SessionStore::get_or_create_session`.
    pub is_encrypted: bool,
    /// Set by `SessionStore::mark_common_tags_sent`.
    pub has_sent_common_tags: bool,
    /// Mirrors the session's `supports_ephemeral_gift_wrap` field.
    pub supports_ephemeral_gift_wrap: bool,
}
235
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fresh, empty route store; sessions are evicted normally when it has no routes.
    fn routes() -> ServerEventRouteStore {
        ServerEventRouteStore::new()
    }

    #[tokio::test]
    async fn create_and_retrieve_session() {
        let store = SessionStore::new();
        let r = routes();

        let created = store.get_or_create_session("client-1", true, &r).await;
        assert!(created);

        let snap = store.get_session("client-1").await.unwrap();
        assert!(snap.is_encrypted);
        assert!(!snap.is_initialized);
    }

    #[tokio::test]
    async fn get_or_create_returns_existing() {
        let store = SessionStore::new();
        let r = routes();

        let created = store.get_or_create_session("client-1", false, &r).await;
        assert!(created);

        let created2 = store.get_or_create_session("client-1", true, &r).await;
        assert!(!created2);

        let snap = store.get_session("client-1").await.unwrap();
        assert!(snap.is_encrypted);
    }

    #[tokio::test]
    async fn get_session_unknown_returns_none() {
        let store = SessionStore::new();
        assert!(store.get_session("unknown").await.is_none());
    }

    #[tokio::test]
    async fn mark_initialized() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        assert!(store.mark_initialized("client-1").await);
        let snap = store.get_session("client-1").await.unwrap();
        assert!(snap.is_initialized);
    }

    #[tokio::test]
    async fn mark_initialized_unknown_returns_false() {
        let store = SessionStore::new();
        assert!(!store.mark_initialized("unknown").await);
    }

    #[tokio::test]
    async fn mark_common_tags_sent_sets_flag() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        assert!(store.mark_common_tags_sent("client-1").await);
        let snap = store.get_session("client-1").await.unwrap();
        assert!(snap.has_sent_common_tags);
        assert!(!snap.is_initialized);
    }

    #[tokio::test]
    async fn mark_common_tags_sent_unknown_returns_false() {
        let store = SessionStore::new();
        assert!(!store.mark_common_tags_sent("unknown").await);
    }

    #[tokio::test]
    async fn remove_session() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;
        assert!(store.remove_session("client-1").await);
        assert!(store.get_session("client-1").await.is_none());
    }

    #[tokio::test]
    async fn remove_unknown_returns_false() {
        let store = SessionStore::new();
        assert!(!store.remove_session("unknown").await);
    }

    #[tokio::test]
    async fn clear_all_sessions() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;
        store.get_or_create_session("client-2", true, &r).await;

        store.clear().await;

        assert_eq!(store.session_count().await, 0);
        assert!(store.get_session("client-1").await.is_none());
        assert!(store.get_session("client-2").await.is_none());
    }

    #[tokio::test]
    async fn get_all_sessions() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;
        store.get_or_create_session("client-2", true, &r).await;

        let all = store.get_all_sessions().await;
        assert_eq!(all.len(), 2);

        let keys: Vec<&str> = all.iter().map(|(k, _)| k.as_str()).collect();
        assert!(keys.contains(&"client-1"));
        assert!(keys.contains(&"client-2"));
    }

    #[tokio::test]
    async fn new_session_capability_fields_default_false() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        let sessions = store.read().await;
        let session = sessions.peek("client-1").unwrap();
        assert!(!session.has_sent_common_tags);
        assert!(!session.supports_encryption);
        assert!(!session.supports_ephemeral_encryption);
        assert!(!session.supports_oversized_transfer);
    }

    #[tokio::test]
    async fn has_sent_common_tags_flag() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        let mut sessions = store.write().await;
        let session = sessions.get_mut("client-1").unwrap();
        assert!(!session.has_sent_common_tags);
        session.has_sent_common_tags = true;
        assert!(session.has_sent_common_tags);
    }

    #[tokio::test]
    async fn capability_or_assign_persists() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        {
            let mut sessions = store.write().await;
            let session = sessions.get_mut("client-1").unwrap();
            session.supports_encryption |= true;
            session.supports_ephemeral_encryption |= false;
        }

        {
            let mut sessions = store.write().await;
            let session = sessions.get_mut("client-1").unwrap();
            session.supports_encryption |= false;
            session.supports_ephemeral_encryption |= true;
        }

        let sessions = store.read().await;
        let session = sessions.peek("client-1").unwrap();
        assert!(session.supports_encryption, "OR-assign must not downgrade");
        assert!(session.supports_ephemeral_encryption);
        assert!(!session.supports_oversized_transfer);
    }

    #[tokio::test]
    async fn capability_fields_independent_per_client() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-a", false, &r).await;
        store.get_or_create_session("client-b", false, &r).await;

        {
            let mut sessions = store.write().await;
            let sa = sessions.get_mut("client-a").unwrap();
            sa.supports_encryption = true;
            sa.has_sent_common_tags = true;
        }

        let sessions = store.read().await;
        let sa = sessions.peek("client-a").unwrap();
        let sb = sessions.peek("client-b").unwrap();
        assert!(sa.supports_encryption);
        assert!(sa.has_sent_common_tags);
        assert!(!sb.supports_encryption);
        assert!(!sb.has_sent_common_tags);
    }

    #[tokio::test]
    async fn get_or_create_preserves_capability_fields() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;

        {
            let mut sessions = store.write().await;
            let session = sessions.get_mut("client-1").unwrap();
            session.supports_encryption = true;
            session.has_sent_common_tags = true;
        }

        let created = store.get_or_create_session("client-1", true, &r).await;
        assert!(!created);

        let sessions = store.read().await;
        let session = sessions.peek("client-1").unwrap();
        assert!(session.supports_encryption);
        assert!(session.has_sent_common_tags);
    }

    #[tokio::test]
    async fn clear_resets_capability_fields() {
        let store = SessionStore::new();
        let r = routes();
        store.get_or_create_session("client-1", false, &r).await;
        {
            let mut sessions = store.write().await;
            let s = sessions.get_mut("client-1").unwrap();
            s.supports_encryption = true;
        }

        store.clear().await;
        store.get_or_create_session("client-1", false, &r).await;

        let sessions = store.read().await;
        let session = sessions.peek("client-1").unwrap();
        assert!(!session.supports_encryption);
        assert!(!session.has_sent_common_tags);
    }

    #[tokio::test]
    async fn lru_eviction_drops_oldest_session() {
        let store = SessionStore::with_capacity(3);
        let r = routes();
        store.get_or_create_session("a", false, &r).await;
        store.get_or_create_session("b", false, &r).await;
        store.get_or_create_session("c", false, &r).await;

        store.get_or_create_session("d", false, &r).await;

        assert!(
            store.get_session("a").await.is_none(),
            "a should be evicted"
        );
        assert!(store.get_session("b").await.is_some());
        assert!(store.get_session("c").await.is_some());
        assert!(store.get_session("d").await.is_some());
        assert_eq!(store.session_count().await, 3);
    }

    #[tokio::test]
    async fn eviction_callback_fires_on_lru_eviction() {
        let evicted = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
        let evicted_clone = evicted.clone();
        let r = routes();

        let mut store = SessionStore::with_capacity(2);
        store.set_eviction_callback(Arc::new(move |pubkey| {
            evicted_clone.lock().unwrap().push(pubkey);
        }));

        store.get_or_create_session("a", false, &r).await;
        store.get_or_create_session("b", false, &r).await;
        store.get_or_create_session("c", false, &r).await;

        let evicted = evicted.lock().unwrap();
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0], "a");
    }

    #[tokio::test]
    async fn eviction_callback_accessor_reflects_setter() {
        let mut store = SessionStore::new();
        assert!(store.eviction_callback().is_none());
        store.set_eviction_callback(Arc::new(|_pubkey: String| {}));
        assert!(store.eviction_callback().is_some());
    }

    #[tokio::test]
    async fn eviction_safety_recreates_session_with_active_routes() {
        let store = SessionStore::with_capacity(2);
        let r = routes();
        store.get_or_create_session("a", true, &r).await;
        store.get_or_create_session("b", false, &r).await;

        r.register("evt1".into(), "a".into(), json!(1), None).await;

        store.get_or_create_session("c", false, &r).await;

        let snap = store.get_session("a").await;
        assert!(
            snap.is_some(),
            "session with active routes must survive eviction"
        );
        assert!(
            store.get_session("b").await.is_none(),
            "b should be evicted"
        );
    }

    #[tokio::test]
    async fn with_capacity_sets_limit() {
        let store = SessionStore::with_capacity(5);
        let r = routes();
        for i in 0..10 {
            store
                .get_or_create_session(&format!("client-{i}"), false, &r)
                .await;
        }
        assert_eq!(store.session_count().await, 5);
    }

    #[tokio::test]
    async fn zero_capacity_is_clamped_to_one() {
        let store = SessionStore::with_capacity(0);
        let r = routes();
        store.get_or_create_session("a", false, &r).await;
        store.get_or_create_session("b", false, &r).await;

        assert_eq!(store.session_count().await, 1);
        assert!(store.get_session("a").await.is_none());
        assert!(store.get_session("b").await.is_some());
    }

    #[tokio::test]
    async fn clones_share_session_map() {
        let store = SessionStore::new();
        let other = store.clone();
        let r = routes();

        store.get_or_create_session("client-1", false, &r).await;

        assert!(other.get_session("client-1").await.is_some());
        assert_eq!(other.session_count().await, 1);
    }
}