1use super::ids::IdAllocator;
5use super::tables::{ExportTable, ImportTable, Value};
6use super::variable_state::VariableStateManager;
7use base64::{engine::general_purpose, Engine as _};
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13
/// Serializable capture of a session's state.
///
/// A snapshot is serialized with `serde_json` and carried inside a signed
/// [`ResumeToken`]; `ResumeTokenManager::parse_token` deserializes it again.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSnapshot {
    /// Identifier of the session this snapshot belongs to.
    pub session_id: String,
    /// UNIX timestamp (seconds) at which the snapshot was created. Combined with
    /// `max_age_seconds` to decide whether the session is too old to resume.
    pub created_at: u64,
    /// UNIX timestamp (seconds) of the last recorded activity; `create_snapshot`
    /// sets it to the same value as `created_at`.
    pub last_activity: u64,
    /// Snapshot format version; `create_snapshot` currently writes `1`.
    pub version: u32,

    // NOTE(review): presumably the next positive/negative IDs of the session's
    // `IdAllocator` — TODO confirm. `create_snapshot` currently writes fixed
    // values (1 and -1) rather than reading the allocator.
    pub next_positive_id: i64,
    pub next_negative_id: i64,

    /// Serializable import entries keyed by numeric ID (presumably import-table
    /// IDs — TODO confirm). `create_snapshot` currently leaves this empty.
    pub imports: HashMap<i64, SerializableImportValue>,

    /// Serializable export entries keyed by numeric ID (presumably export-table
    /// IDs — TODO confirm). `create_snapshot` currently leaves this empty.
    pub exports: HashMap<i64, SerializableExportValue>,

    /// Variable name to value, as produced by `VariableStateManager::export_variables`.
    pub variables: HashMap<String, Value>,

    /// Maximum age in seconds, measured from `created_at`, after which
    /// `parse_token` refuses the snapshot with `SessionTooOld`.
    pub max_age_seconds: u64,
    /// Capability names (presumably — TODO confirm); `create_snapshot` currently
    /// leaves this empty.
    pub capabilities: Vec<String>,
}
40
/// Serializable form of an import-table entry stored in a [`SessionSnapshot`].
///
/// NOTE(review): nothing in this file constructs these values yet; the meaning
/// of the reference variants below is inferred from their names — TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SerializableImportValue {
    /// A plain value that can be stored directly.
    Value(Value),
    /// Presumably an identifier standing in for a stub that cannot itself be serialized.
    StubReference(String),
    /// Presumably an identifier standing in for a pending promise.
    PromiseReference(String),
}
48
/// Serializable form of an export-table entry stored in a [`SessionSnapshot`].
///
/// NOTE(review): nothing in this file constructs these values yet; the meaning
/// of the variants below is inferred from their names — TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SerializableExportValue {
    /// An export that completed with this value.
    Resolved(Value),
    /// An export that failed with this error value.
    Rejected(Value),
    /// Presumably an identifier standing in for a stub that cannot itself be serialized.
    StubReference(String),
    /// Presumably an identifier standing in for a pending promise.
    PromiseReference(String),
}
57
/// Opaque token a client presents to resume a previous session.
///
/// Only `token_data` is signed. `session_id`, `issued_at` and `expires_at` are
/// plain copies sent alongside it: they are not covered by the signature and a
/// client can change them freely, so treat them as untrusted hints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResumeToken {
    /// Base64 (standard alphabet) encoding of the JSON-serialized signed payload.
    pub token_data: String,
    /// Session identifier, copied from the snapshot at issue time (unsigned).
    pub session_id: String,
    /// UNIX timestamp (seconds) at which the token was issued (unsigned copy).
    pub issued_at: u64,
    /// UNIX timestamp (seconds) after which the token is expired (unsigned copy).
    pub expires_at: u64,
}
70
/// Issues and verifies signed resume tokens that embed a [`SessionSnapshot`].
///
/// `Debug` is implemented by hand so the signing key is never printed; a derived
/// impl would dump the raw key bytes into any `{:?}` log line, including those of
/// types that embed this manager.
pub struct ResumeTokenManager {
    /// Secret used to sign and verify tokens. Never exposed through `Debug`.
    secret_key: Vec<u8>,
    /// Lifetime in seconds of a newly issued token.
    default_ttl: u64,
    /// Maximum session age in seconds, recorded into each snapshot.
    max_session_age: u64,
}

impl std::fmt::Debug for ResumeTokenManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResumeTokenManager")
            // Only reveal the key length, which helps diagnose misconfiguration
            // without leaking key material.
            .field(
                "secret_key",
                &format_args!("<redacted, {} bytes>", self.secret_key.len()),
            )
            .field("default_ttl", &self.default_ttl)
            .field("max_session_age", &self.max_session_age)
            .finish()
    }
}
81
82impl ResumeTokenManager {
83 pub fn new(secret_key: Vec<u8>) -> Self {
85 Self {
86 secret_key,
87 default_ttl: 3600, max_session_age: 86400, }
90 }
91
92 pub fn with_settings(secret_key: Vec<u8>, default_ttl: u64, max_session_age: u64) -> Self {
94 Self {
95 secret_key,
96 default_ttl,
97 max_session_age,
98 }
99 }
100
101 pub fn generate_secret_key() -> Vec<u8> {
103 use rand::RngCore;
104 let mut key = vec![0u8; 32];
105 rand::rng().fill_bytes(&mut key);
106 key
107 }
108
109 pub async fn create_snapshot(
111 &self,
112 session_id: String,
113 _allocator: &Arc<IdAllocator>,
114 _imports: &Arc<ImportTable>,
115 _exports: &Arc<ExportTable>,
116 variables: Option<&VariableStateManager>,
117 ) -> Result<SessionSnapshot, ResumeTokenError> {
118 let now = SystemTime::now()
119 .duration_since(UNIX_EPOCH)
120 .expect("System time should be after UNIX epoch")
121 .as_secs();
122
123 let serializable_imports = HashMap::new();
125
126 tracing::info!(session_id = %session_id, "Creating session snapshot");
129
130 let variables_map = if let Some(var_mgr) = variables {
132 var_mgr.export_variables().await
133 } else {
134 HashMap::new()
135 };
136
137 let snapshot = SessionSnapshot {
138 session_id: session_id.clone(),
139 created_at: now,
140 last_activity: now,
141 version: 1, next_positive_id: 1,
145 next_negative_id: -1,
146
147 imports: serializable_imports,
148 exports: HashMap::new(), variables: variables_map,
151
152 max_age_seconds: self.max_session_age,
153 capabilities: Vec::new(), };
155
156 Ok(snapshot)
157 }
158
159 pub fn generate_token(
161 &self,
162 snapshot: SessionSnapshot,
163 ) -> Result<ResumeToken, ResumeTokenError> {
164 let now = SystemTime::now()
165 .duration_since(UNIX_EPOCH)
166 .expect("System time should be after UNIX epoch")
167 .as_secs();
168
169 let expires_at = now + self.default_ttl;
170
171 let snapshot_data = serde_json::to_vec(&snapshot)
173 .map_err(|e| ResumeTokenError::SerializationError(e.to_string()))?;
174
175 let signature = self.sign_data(&snapshot_data);
177 let token_payload = TokenPayload {
178 snapshot: snapshot_data,
179 issued_at: now,
180 expires_at,
181 signature,
182 };
183
184 let token_bytes = serde_json::to_vec(&token_payload)
185 .map_err(|e| ResumeTokenError::SerializationError(e.to_string()))?;
186
187 let token_data = general_purpose::STANDARD.encode(&token_bytes);
188
189 Ok(ResumeToken {
190 token_data,
191 session_id: snapshot.session_id,
192 issued_at: now,
193 expires_at,
194 })
195 }
196
197 pub fn parse_token(&self, token: &ResumeToken) -> Result<SessionSnapshot, ResumeTokenError> {
199 let now = SystemTime::now()
200 .duration_since(UNIX_EPOCH)
201 .expect("System time should be after UNIX epoch")
202 .as_secs();
203
204 if now > token.expires_at {
206 return Err(ResumeTokenError::TokenExpired);
207 }
208
209 let token_bytes = general_purpose::STANDARD
211 .decode(&token.token_data)
212 .map_err(|e| ResumeTokenError::InvalidToken(e.to_string()))?;
213
214 let token_payload: TokenPayload = serde_json::from_slice(&token_bytes)
215 .map_err(|e| ResumeTokenError::InvalidToken(e.to_string()))?;
216
217 let expected_signature = self.sign_data(&token_payload.snapshot);
219 if token_payload.signature != expected_signature {
220 return Err(ResumeTokenError::InvalidSignature);
221 }
222
223 let snapshot: SessionSnapshot = serde_json::from_slice(&token_payload.snapshot)
225 .map_err(|e| ResumeTokenError::InvalidToken(e.to_string()))?;
226
227 if snapshot.created_at + snapshot.max_age_seconds < now {
229 return Err(ResumeTokenError::SessionTooOld);
230 }
231
232 Ok(snapshot)
233 }
234
235 pub async fn restore_session(
237 &self,
238 snapshot: SessionSnapshot,
239 _allocator: &Arc<IdAllocator>,
240 _imports: &Arc<ImportTable>,
241 _exports: &Arc<ExportTable>,
242 variables: Option<&VariableStateManager>,
243 ) -> Result<(), ResumeTokenError> {
244 tracing::info!(
245 session_id = %snapshot.session_id,
246 imports_count = snapshot.imports.len(),
247 exports_count = snapshot.exports.len(),
248 variables_count = snapshot.variables.len(),
249 "Restoring session from snapshot"
250 );
251
252 if let Some(var_mgr) = variables {
254 var_mgr
255 .import_variables(snapshot.variables)
256 .await
257 .map_err(|e| ResumeTokenError::RestoreError(e.to_string()))?;
258 }
259
260 tracing::info!(session_id = %snapshot.session_id, "Session restoration completed");
267 Ok(())
268 }
269
270 fn sign_data(&self, data: &[u8]) -> String {
272 let mut hasher = Sha256::new();
273 hasher.update(&self.secret_key);
274 hasher.update(data);
275 general_purpose::STANDARD.encode(hasher.finalize())
276 }
277}
278
/// Signed envelope inside a [`ResumeToken`].
///
/// `generate_token` serializes this to JSON and base64-encodes it into
/// `ResumeToken::token_data`; `parse_token` reverses that. See
/// `ResumeTokenManager::sign_data` / `parse_token` for exactly which fields the
/// signature covers.
#[derive(Debug, Serialize, Deserialize)]
struct TokenPayload {
    /// JSON-encoded `SessionSnapshot` bytes (serde_json writes a `Vec<u8>` as an
    /// array of numbers).
    snapshot: Vec<u8>,
    /// UNIX timestamp (seconds) at which the token was issued.
    issued_at: u64,
    /// UNIX timestamp (seconds) after which the token is expired.
    expires_at: u64,
    /// Base64-encoded digest produced by `ResumeTokenManager::sign_data`.
    signature: String,
}
287
288#[derive(Debug, thiserror::Error)]
290pub enum ResumeTokenError {
291 #[error("Serialization error: {0}")]
292 SerializationError(String),
293
294 #[error("Invalid token: {0}")]
295 InvalidToken(String),
296
297 #[error("Token has expired")]
298 TokenExpired,
299
300 #[error("Invalid token signature")]
301 InvalidSignature,
302
303 #[error("Session too old to resume")]
304 SessionTooOld,
305
306 #[error("Session restoration error: {0}")]
307 RestoreError(String),
308
309 #[error("Variable state error: {0}")]
310 VariableStateError(#[from] super::variable_state::VariableError),
311}
312
/// Couples a [`ResumeTokenManager`] with a registry of sessions restored from tokens.
#[derive(Debug)]
pub struct PersistentSessionManager {
    /// Issues and verifies the resume tokens.
    token_manager: ResumeTokenManager,
    /// Sessions restored through `restore_session`, keyed by session id and
    /// guarded by an async `RwLock` so it can be locked across `.await` points.
    active_sessions: Arc<tokio::sync::RwLock<HashMap<String, SessionInfo>>>,
}
319
/// Bookkeeping for a session tracked by [`PersistentSessionManager`].
#[derive(Debug, Clone)]
struct SessionInfo {
    /// Session id (duplicates the map key; currently unused, hence the underscore).
    _session_id: String,
    /// UNIX timestamp (seconds) of the last recorded activity; cleanup compares
    /// against it.
    last_activity: u64,
    /// Currently always `None` when inserted by `restore_session` (unused).
    _variable_manager: Option<Arc<VariableStateManager>>,
}
326
327impl PersistentSessionManager {
328 pub fn new(token_manager: ResumeTokenManager) -> Self {
330 Self {
331 token_manager,
332 active_sessions: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
333 }
334 }
335
336 pub async fn snapshot_session(
338 &self,
339 session_id: &str,
340 _allocator: &Arc<IdAllocator>,
341 _imports: &Arc<ImportTable>,
342 _exports: &Arc<ExportTable>,
343 variables: Option<&VariableStateManager>,
344 ) -> Result<ResumeToken, ResumeTokenError> {
345 let snapshot = self
346 .token_manager
347 .create_snapshot(
348 session_id.to_string(),
349 _allocator,
350 _imports,
351 _exports,
352 variables,
353 )
354 .await?;
355
356 self.token_manager.generate_token(snapshot)
357 }
358
359 pub async fn restore_session(
361 &self,
362 token: &ResumeToken,
363 _allocator: &Arc<IdAllocator>,
364 _imports: &Arc<ImportTable>,
365 _exports: &Arc<ExportTable>,
366 variables: Option<&VariableStateManager>,
367 ) -> Result<String, ResumeTokenError> {
368 let snapshot = self.token_manager.parse_token(token)?;
369
370 self.token_manager
371 .restore_session(snapshot.clone(), _allocator, _imports, _exports, variables)
372 .await?;
373
374 let mut sessions = self.active_sessions.write().await;
376 sessions.insert(
377 snapshot.session_id.clone(),
378 SessionInfo {
379 _session_id: snapshot.session_id.clone(),
380 last_activity: SystemTime::now()
381 .duration_since(UNIX_EPOCH)
382 .expect("System time should be after UNIX epoch")
383 .as_secs(),
384 _variable_manager: None, },
386 );
387
388 Ok(snapshot.session_id)
389 }
390
391 pub async fn cleanup_expired_sessions(&self) -> usize {
393 let now = SystemTime::now()
394 .duration_since(UNIX_EPOCH)
395 .expect("System time should be after UNIX epoch")
396 .as_secs();
397
398 let mut sessions = self.active_sessions.write().await;
399 let initial_count = sessions.len();
400
401 sessions.retain(|_, info| {
402 now - info.last_activity < 3600 });
404
405 let cleaned_count = initial_count - sessions.len();
406 if cleaned_count > 0 {
407 tracing::info!(
408 cleaned_sessions = cleaned_count,
409 "Cleaned up expired sessions"
410 );
411 }
412
413 cleaned_count
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Number;

    /// Current UNIX time in whole seconds.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Fresh, minimal snapshot for "test-session"; tests override individual
    /// fields with struct-update syntax.
    fn baseline_snapshot() -> SessionSnapshot {
        let now = unix_now();
        SessionSnapshot {
            session_id: "test-session".to_string(),
            created_at: now,
            last_activity: now,
            version: 1,
            next_positive_id: 1,
            next_negative_id: -1,
            imports: HashMap::new(),
            exports: HashMap::new(),
            variables: HashMap::new(),
            max_age_seconds: 3600,
            capabilities: Vec::new(),
        }
    }

    #[tokio::test]
    async fn test_basic_resume_token_flow() {
        let manager = ResumeTokenManager::new(ResumeTokenManager::generate_secret_key());

        let variables: HashMap<_, _> =
            std::iter::once(("test_var".to_string(), Value::Number(Number::from(42))))
                .collect();

        let snapshot = SessionSnapshot {
            next_positive_id: 5,
            next_negative_id: -3,
            variables,
            capabilities: vec!["calculator".to_string()],
            ..baseline_snapshot()
        };

        let token = manager.generate_token(snapshot.clone()).unwrap();
        assert_eq!(token.session_id, "test-session");

        let restored = manager.parse_token(&token).unwrap();
        assert_eq!(restored.session_id, snapshot.session_id);
        assert_eq!(restored.variables.len(), 1);

        match restored.variables.get("test_var") {
            Some(Value::Number(n)) => assert_eq!(n.as_i64(), Some(42)),
            _ => panic!("Expected test_var to be number 42"),
        }
    }

    #[tokio::test]
    async fn test_token_expiration() {
        // A zero TTL makes the token valid only within the second it was issued.
        let manager =
            ResumeTokenManager::with_settings(ResumeTokenManager::generate_secret_key(), 0, 3600);

        let token = manager.generate_token(baseline_snapshot()).unwrap();

        // Timestamps have one-second granularity, so wait past the next second.
        tokio::time::sleep(std::time::Duration::from_millis(1100)).await;

        assert!(matches!(
            manager.parse_token(&token),
            Err(ResumeTokenError::TokenExpired)
        ));
    }

    #[tokio::test]
    async fn test_invalid_signature() {
        let issuer = ResumeTokenManager::new(ResumeTokenManager::generate_secret_key());
        let verifier = ResumeTokenManager::new(ResumeTokenManager::generate_secret_key());

        let token = issuer.generate_token(baseline_snapshot()).unwrap();

        // A manager holding a different key must reject the token.
        assert!(matches!(
            verifier.parse_token(&token),
            Err(ResumeTokenError::InvalidSignature)
        ));
    }

    #[tokio::test]
    async fn test_persistent_session_manager() {
        let session_manager = PersistentSessionManager::new(ResumeTokenManager::new(
            ResumeTokenManager::generate_secret_key(),
        ));

        // Nothing has been restored yet, so nothing can be cleaned up.
        assert_eq!(session_manager.cleanup_expired_sessions().await, 0);
    }
}