use super::ids::IdAllocator;
use super::tables::{ExportTable, ImportTable, Value};
use super::variable_state::VariableStateManager;
use base64::{engine::general_purpose, Engine as _};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
/// Serializable state of a session that is JSON-encoded and signed into a resume token.
///
/// NOTE: serde's derived impls serialize fields in declaration order, so renaming or
/// reordering fields changes the signed bytes and the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSnapshot {
    /// Identifier of the session this snapshot belongs to.
    pub session_id: String,
    /// Unix timestamp (whole seconds) at which the snapshot was created.
    pub created_at: u64,
    /// Unix timestamp (whole seconds) of the last recorded activity; `create_snapshot` sets it
    /// equal to `created_at`.
    pub last_activity: u64,
    /// Snapshot format version; `create_snapshot` always writes `1`.
    pub version: u32,
    /// Next positive ID counter. `create_snapshot` currently hard-codes `1`; presumably meant to
    /// mirror the `IdAllocator` state — TODO confirm.
    pub next_positive_id: i64,
    /// Next negative ID counter. `create_snapshot` currently hard-codes `-1`; presumably meant to
    /// mirror the `IdAllocator` state — TODO confirm.
    pub next_negative_id: i64,
    /// Import-table entries keyed by ID; `create_snapshot` currently leaves this empty.
    pub imports: HashMap<i64, SerializableImportValue>,
    /// Export-table entries keyed by ID; `create_snapshot` currently leaves this empty.
    pub exports: HashMap<i64, SerializableExportValue>,
    /// Variables exported from a `VariableStateManager` (keys presumably variable names).
    pub variables: HashMap<String, Value>,
    /// Maximum age in seconds past `created_at` after which the snapshot is rejected as too old.
    pub max_age_seconds: u64,
    /// Capability names associated with the session; `create_snapshot` currently leaves this
    /// empty — semantics TODO confirm.
    pub capabilities: Vec<String>,
}
/// Serializable form of an import-table entry stored in a [`SessionSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SerializableImportValue {
    /// A plain value.
    Value(Value),
    /// Reference to a stub; the string's format is not established here (presumably an
    /// identifier) — TODO confirm.
    StubReference(String),
    /// Reference to a promise; the string's format is not established here — TODO confirm.
    PromiseReference(String),
}
/// Serializable form of an export-table entry stored in a [`SessionSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SerializableExportValue {
    /// Export that resolved with this value (naming suggests a fulfilled promise — confirm).
    Resolved(Value),
    /// Export that was rejected with this value (naming suggests a rejected promise — confirm).
    Rejected(Value),
    /// Reference to a stub; the string's format is not established here — TODO confirm.
    StubReference(String),
    /// Reference to a promise; the string's format is not established here — TODO confirm.
    PromiseReference(String),
}
/// Resume token handed to a client so it can later resume its session.
///
/// Only `token_data` carries signed content. `session_id`, `issued_at` and `expires_at` are
/// plain copies for the caller's convenience and are NOT covered by the signature, so they
/// must not be trusted on their own.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResumeToken {
    /// Standard-alphabet base64 encoding of the JSON-serialized signed payload.
    pub token_data: String,
    /// Session ID of the snapshot the token was generated from.
    pub session_id: String,
    /// Unix timestamp (whole seconds) at which the token was issued.
    pub issued_at: u64,
    /// Unix timestamp (whole seconds) after which the token is treated as expired.
    pub expires_at: u64,
}
/// Issues and verifies signed [`ResumeToken`]s.
///
/// `Debug` is implemented by hand so the signing key never appears in logs, panic messages,
/// or the `Debug` output of types that embed this manager.
pub struct ResumeTokenManager {
    /// Secret used to sign tokens. Never printed.
    secret_key: Vec<u8>,
    /// Lifetime of newly issued tokens, in seconds.
    default_ttl: u64,
    /// Maximum session age, in seconds, recorded into newly created snapshots.
    max_session_age: u64,
}

impl std::fmt::Debug for ResumeTokenManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResumeTokenManager")
            // Redacted: a derived `Debug` would print the raw key bytes.
            .field("secret_key", &"<redacted>")
            .field("default_ttl", &self.default_ttl)
            .field("max_session_age", &self.max_session_age)
            .finish()
    }
}
impl ResumeTokenManager {
    /// Token lifetime used by [`ResumeTokenManager::new`]: one hour.
    const DEFAULT_TTL_SECS: u64 = 3600;
    /// Maximum session age used by [`ResumeTokenManager::new`]: one day.
    const DEFAULT_MAX_SESSION_AGE_SECS: u64 = 86400;
    /// SHA-256 block size in bytes, required by the HMAC construction (RFC 2104).
    const SHA256_BLOCK_LEN: usize = 64;

    /// Creates a manager with a one-hour token TTL and a one-day maximum session age.
    pub fn new(secret_key: Vec<u8>) -> Self {
        Self::with_settings(
            secret_key,
            Self::DEFAULT_TTL_SECS,
            Self::DEFAULT_MAX_SESSION_AGE_SECS,
        )
    }

    /// Creates a manager with an explicit token TTL and maximum session age, both in seconds.
    pub fn with_settings(secret_key: Vec<u8>, default_ttl: u64, max_session_age: u64) -> Self {
        Self {
            secret_key,
            default_ttl,
            max_session_age,
        }
    }

    /// Generates a fresh random 32-byte signing key.
    pub fn generate_secret_key() -> Vec<u8> {
        use rand::RngCore;
        let mut key = vec![0u8; 32];
        rand::rng().fill_bytes(&mut key);
        key
    }

    /// Captures the current state of a session into a [`SessionSnapshot`].
    ///
    /// Only variables are captured today, and only when `variables` is provided. The allocator
    /// and the import and export tables are accepted but not yet read. The snapshot therefore
    /// carries default ID counters and empty tables.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub async fn create_snapshot(
        &self,
        session_id: String,
        _allocator: &Arc<IdAllocator>,
        _imports: &Arc<ImportTable>,
        _exports: &Arc<ExportTable>,
        variables: Option<&VariableStateManager>,
    ) -> Result<SessionSnapshot, ResumeTokenError> {
        let now = Self::unix_now();
        tracing::info!(session_id = %session_id, "Creating session snapshot");
        let variables_map = match variables {
            Some(var_mgr) => var_mgr.export_variables().await,
            None => HashMap::new(),
        };
        Ok(SessionSnapshot {
            session_id,
            created_at: now,
            last_activity: now,
            version: 1,
            // TODO: populate the counters from `_allocator`.
            next_positive_id: 1,
            next_negative_id: -1,
            // TODO: serialize `_imports` / `_exports` into these maps.
            imports: HashMap::new(),
            exports: HashMap::new(),
            variables: variables_map,
            max_age_seconds: self.max_session_age,
            capabilities: Vec::new(),
        })
    }

    /// Serializes and signs `snapshot`, returning a token valid for `default_ttl` seconds.
    ///
    /// The signature covers the snapshot bytes together with the issue and expiry timestamps.
    /// None of them can be altered without invalidating the token.
    ///
    /// # Errors
    ///
    /// [`ResumeTokenError::SerializationError`] if the snapshot or the payload cannot be encoded.
    pub fn generate_token(
        &self,
        snapshot: SessionSnapshot,
    ) -> Result<ResumeToken, ResumeTokenError> {
        let now = Self::unix_now();
        // Saturate so an enormous configured TTL cannot overflow (a panic in debug builds).
        let expires_at = now.saturating_add(self.default_ttl);
        let snapshot_data = serde_json::to_vec(&snapshot)
            .map_err(|e| ResumeTokenError::SerializationError(e.to_string()))?;
        let signature = self.sign_payload(&snapshot_data, now, expires_at);
        let token_payload = TokenPayload {
            snapshot: snapshot_data,
            issued_at: now,
            expires_at,
            signature,
        };
        let token_bytes = serde_json::to_vec(&token_payload)
            .map_err(|e| ResumeTokenError::SerializationError(e.to_string()))?;
        Ok(ResumeToken {
            token_data: general_purpose::STANDARD.encode(token_bytes),
            session_id: snapshot.session_id,
            issued_at: now,
            expires_at,
        })
    }

    /// Verifies `token` and returns the snapshot it carries.
    ///
    /// Checks are performed in this order:
    /// 1. The caller-visible expiry.
    /// 2. Base64 and JSON decoding.
    /// 3. The signature.
    /// 4. The signed expiry.
    /// 5. That `token.session_id` matches the signed snapshot.
    /// 6. That the session is younger than both its recorded `max_age_seconds` and this
    ///    manager's `max_session_age`.
    ///
    /// # Errors
    ///
    /// - [`ResumeTokenError::TokenExpired`]: the token's expiry has passed.
    /// - [`ResumeTokenError::InvalidToken`]: decoding fails or the outer session ID does not
    ///   match the signed one.
    /// - [`ResumeTokenError::InvalidSignature`]: the signature does not verify under this key.
    /// - [`ResumeTokenError::SessionTooOld`]: the session exceeded its maximum age.
    pub fn parse_token(&self, token: &ResumeToken) -> Result<SessionSnapshot, ResumeTokenError> {
        let now = Self::unix_now();
        // Fast path: reject obviously stale tokens before decoding anything. The outer fields
        // are not authenticated, so the signed expiry is re-checked below.
        if now > token.expires_at {
            return Err(ResumeTokenError::TokenExpired);
        }
        let token_bytes = general_purpose::STANDARD
            .decode(&token.token_data)
            .map_err(|e| ResumeTokenError::InvalidToken(e.to_string()))?;
        let payload: TokenPayload = serde_json::from_slice(&token_bytes)
            .map_err(|e| ResumeTokenError::InvalidToken(e.to_string()))?;
        let expected_signature =
            self.sign_payload(&payload.snapshot, payload.issued_at, payload.expires_at);
        if !Self::constant_time_eq(payload.signature.as_bytes(), expected_signature.as_bytes()) {
            return Err(ResumeTokenError::InvalidSignature);
        }
        // Authoritative expiry check against the signed timestamp.
        if now > payload.expires_at {
            return Err(ResumeTokenError::TokenExpired);
        }
        let snapshot: SessionSnapshot = serde_json::from_slice(&payload.snapshot)
            .map_err(|e| ResumeTokenError::InvalidToken(e.to_string()))?;
        if snapshot.session_id != token.session_id {
            return Err(ResumeTokenError::InvalidToken(
                "session id does not match signed snapshot".to_string(),
            ));
        }
        // Enforce the server's current policy even if the snapshot recorded a longer max age.
        let max_age = snapshot.max_age_seconds.min(self.max_session_age);
        if snapshot.created_at.saturating_add(max_age) < now {
            return Err(ResumeTokenError::SessionTooOld);
        }
        Ok(snapshot)
    }

    /// Applies `snapshot` to a live session.
    ///
    /// Only variables are restored today, and only when `variables` is provided. The allocator
    /// and the import and export tables are accepted but not yet written.
    ///
    /// # Errors
    ///
    /// [`ResumeTokenError::RestoreError`] if importing the variables fails.
    pub async fn restore_session(
        &self,
        snapshot: SessionSnapshot,
        _allocator: &Arc<IdAllocator>,
        _imports: &Arc<ImportTable>,
        _exports: &Arc<ExportTable>,
        variables: Option<&VariableStateManager>,
    ) -> Result<(), ResumeTokenError> {
        tracing::info!(
            session_id = %snapshot.session_id,
            imports_count = snapshot.imports.len(),
            exports_count = snapshot.exports.len(),
            variables_count = snapshot.variables.len(),
            "Restoring session from snapshot"
        );
        if let Some(var_mgr) = variables {
            var_mgr
                .import_variables(snapshot.variables)
                .await
                .map_err(|e| ResumeTokenError::RestoreError(e.to_string()))?;
        }
        tracing::info!(session_id = %snapshot.session_id, "Session restoration completed");
        Ok(())
    }

    /// Signs the snapshot bytes together with the token's timestamps.
    ///
    /// The timestamps are prefixed as fixed-width big-endian integers, so the signed message
    /// is unambiguous.
    fn sign_payload(&self, snapshot: &[u8], issued_at: u64, expires_at: u64) -> String {
        let mut message = Vec::with_capacity(16 + snapshot.len());
        message.extend_from_slice(&issued_at.to_be_bytes());
        message.extend_from_slice(&expires_at.to_be_bytes());
        message.extend_from_slice(snapshot);
        self.sign_data(&message)
    }

    /// Computes HMAC-SHA256 of `data` under `secret_key` and returns it base64-encoded.
    ///
    /// HMAC (RFC 2104) is used instead of a plain `SHA256(key || data)` because the plain
    /// form is open to length-extension forgeries.
    fn sign_data(&self, data: &[u8]) -> String {
        // Per RFC 2104, keys longer than one block are hashed first; shorter keys are zero-padded.
        let mut key_block = [0u8; Self::SHA256_BLOCK_LEN];
        if self.secret_key.len() > Self::SHA256_BLOCK_LEN {
            let digest = Sha256::digest(&self.secret_key);
            key_block[..digest.len()].copy_from_slice(&digest);
        } else {
            key_block[..self.secret_key.len()].copy_from_slice(&self.secret_key);
        }
        let mut inner_pad = [0x36u8; Self::SHA256_BLOCK_LEN];
        let mut outer_pad = [0x5cu8; Self::SHA256_BLOCK_LEN];
        for ((ip, op), k) in inner_pad
            .iter_mut()
            .zip(outer_pad.iter_mut())
            .zip(key_block.iter())
        {
            *ip ^= *k;
            *op ^= *k;
        }
        let mut inner = Sha256::new();
        inner.update(inner_pad);
        inner.update(data);
        let inner_digest = inner.finalize();
        let mut outer = Sha256::new();
        outer.update(outer_pad);
        outer.update(inner_digest);
        general_purpose::STANDARD.encode(outer.finalize())
    }

    /// Compares two byte strings without stopping at the first difference.
    ///
    /// This keeps the comparison time from revealing how much of a forged signature matched.
    /// It is best-effort: without a dedicated crate, the optimizer is not formally prevented
    /// from short-circuiting. Lengths are not secret, since signatures have a fixed encoded
    /// length.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }

    /// Current Unix time in whole seconds.
    ///
    /// # Panics
    ///
    /// If the system clock reports a time before the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("System time should be after UNIX epoch")
            .as_secs()
    }
}
/// Internal structure that `ResumeTokenManager` JSON-encodes and base64-wraps into
/// `ResumeToken::token_data`.
#[derive(Debug, Serialize, Deserialize)]
struct TokenPayload {
    /// JSON-serialized `SessionSnapshot` bytes.
    /// NOTE(review): serde_json encodes a `Vec<u8>` as an array of numbers, which inflates the
    /// token size considerably.
    snapshot: Vec<u8>,
    /// Unix timestamp (whole seconds) at which the token was issued.
    issued_at: u64,
    /// Expiry Unix timestamp (whole seconds) recorded when the token was issued.
    expires_at: u64,
    /// Base64-encoded signature produced by `ResumeTokenManager` and re-computed on parse.
    signature: String,
}
/// Errors produced while creating, parsing, or restoring from resume tokens.
#[derive(Debug, thiserror::Error)]
pub enum ResumeTokenError {
    /// The snapshot or token payload could not be serialized to JSON.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// The token could not be decoded (base64 or JSON), or its contents are malformed or
    /// inconsistent.
    #[error("Invalid token: {0}")]
    InvalidToken(String),
    /// The current time is past the token's expiry timestamp.
    #[error("Token has expired")]
    TokenExpired,
    /// The token's signature does not match the one computed with this manager's key.
    #[error("Invalid token signature")]
    InvalidSignature,
    /// The snapshot's creation time plus its maximum age lies in the past.
    #[error("Session too old to resume")]
    SessionTooOld,
    /// Applying the snapshot to the live session failed (e.g. importing variables).
    #[error("Session restoration error: {0}")]
    RestoreError(String),
    /// Wraps a `VariableError`. `#[from]` allows conversion with `?`; the code visible in this
    /// file maps variable errors to `RestoreError` instead.
    #[error("Variable state error: {0}")]
    VariableStateError(#[from] super::variable_state::VariableError),
}
/// Issues resume tokens for sessions and tracks sessions restored from them.
#[derive(Debug)]
pub struct PersistentSessionManager {
    /// Creates, signs, and verifies the tokens.
    token_manager: ResumeTokenManager,
    /// Restored sessions keyed by session ID, guarded by an async `RwLock`.
    active_sessions: Arc<tokio::sync::RwLock<HashMap<String, SessionInfo>>>,
}
/// Bookkeeping record for a session tracked by `PersistentSessionManager`.
#[derive(Debug, Clone)]
struct SessionInfo {
    /// Copy of the session ID (also the map key); currently unused.
    _session_id: String,
    /// Unix timestamp (whole seconds) of the last activity; set when the session is restored
    /// and compared against the idle timeout during cleanup.
    last_activity: u64,
    /// Variable manager for the session; restored sessions are currently inserted with `None`.
    _variable_manager: Option<Arc<VariableStateManager>>,
}
impl PersistentSessionManager {
    /// Idle time, in seconds, after which `cleanup_expired_sessions` drops a session.
    const DEFAULT_IDLE_TIMEOUT_SECS: u64 = 3600;

    /// Creates a manager that uses `token_manager` and tracks no sessions yet.
    pub fn new(token_manager: ResumeTokenManager) -> Self {
        Self {
            token_manager,
            active_sessions: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
        }
    }

    /// Snapshots the given session and returns a signed resume token for it.
    ///
    /// # Errors
    ///
    /// Propagates errors from snapshot creation and token generation.
    pub async fn snapshot_session(
        &self,
        session_id: &str,
        _allocator: &Arc<IdAllocator>,
        _imports: &Arc<ImportTable>,
        _exports: &Arc<ExportTable>,
        variables: Option<&VariableStateManager>,
    ) -> Result<ResumeToken, ResumeTokenError> {
        let snapshot = self
            .token_manager
            .create_snapshot(
                session_id.to_string(),
                _allocator,
                _imports,
                _exports,
                variables,
            )
            .await?;
        self.token_manager.generate_token(snapshot)
    }

    /// Verifies `token`, restores the session it describes, and starts tracking that session.
    ///
    /// Returns the restored session's ID.
    ///
    /// # Errors
    ///
    /// Propagates token-verification and restoration errors. On error, no session is tracked.
    pub async fn restore_session(
        &self,
        token: &ResumeToken,
        _allocator: &Arc<IdAllocator>,
        _imports: &Arc<ImportTable>,
        _exports: &Arc<ExportTable>,
        variables: Option<&VariableStateManager>,
    ) -> Result<String, ResumeTokenError> {
        let snapshot = self.token_manager.parse_token(token)?;
        // Keep only the ID before handing the snapshot over by value. This avoids cloning the
        // whole snapshot, including its variable map.
        let session_id = snapshot.session_id.clone();
        self.token_manager
            .restore_session(snapshot, _allocator, _imports, _exports, variables)
            .await?;
        let info = SessionInfo {
            _session_id: session_id.clone(),
            last_activity: Self::unix_now(),
            _variable_manager: None,
        };
        self.active_sessions
            .write()
            .await
            .insert(session_id.clone(), info);
        Ok(session_id)
    }

    /// Drops sessions idle for an hour or more and returns how many were removed.
    pub async fn cleanup_expired_sessions(&self) -> usize {
        self.cleanup_idle_sessions(Self::DEFAULT_IDLE_TIMEOUT_SECS)
            .await
    }

    /// Drops sessions whose last activity is at least `max_idle_secs` in the past.
    ///
    /// Returns how many sessions were removed.
    pub async fn cleanup_idle_sessions(&self, max_idle_secs: u64) -> usize {
        let now = Self::unix_now();
        let mut sessions = self.active_sessions.write().await;
        let initial_count = sessions.len();
        // `saturating_sub`: if the wall clock stepped backwards, `last_activity` can lie in the
        // future. Treat that as "just active" rather than underflowing, which would panic in
        // debug builds and wrap in release builds.
        sessions.retain(|_, info| now.saturating_sub(info.last_activity) < max_idle_secs);
        let cleaned_count = initial_count - sessions.len();
        if cleaned_count > 0 {
            tracing::info!(
                cleaned_sessions = cleaned_count,
                "Cleaned up expired sessions"
            );
        }
        cleaned_count
    }

    /// Current Unix time in whole seconds.
    ///
    /// # Panics
    ///
    /// If the system clock reports a time before the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("System time should be after UNIX epoch")
            .as_secs()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Number;

    /// Seconds since the Unix epoch, as stored in snapshot timestamps.
    fn unix_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds a `"test-session"` snapshot created now, with a one-hour max age.
    fn snapshot_for_test(
        next_positive_id: i64,
        next_negative_id: i64,
        variables: HashMap<String, Value>,
        capabilities: Vec<String>,
    ) -> SessionSnapshot {
        let created_at = unix_secs();
        let last_activity = unix_secs();
        SessionSnapshot {
            session_id: "test-session".to_string(),
            created_at,
            last_activity,
            version: 1,
            next_positive_id,
            next_negative_id,
            imports: HashMap::new(),
            exports: HashMap::new(),
            variables,
            max_age_seconds: 3600,
            capabilities,
        }
    }

    /// Snapshot with default ID counters and no variables or capabilities.
    fn blank_snapshot() -> SessionSnapshot {
        snapshot_for_test(1, -1, HashMap::new(), Vec::new())
    }

    #[tokio::test]
    async fn test_basic_resume_token_flow() {
        let manager = ResumeTokenManager::new(ResumeTokenManager::generate_secret_key());
        let variables = HashMap::from([(
            "test_var".to_string(),
            Value::Number(Number::from(42)),
        )]);
        let original = snapshot_for_test(5, -3, variables, vec!["calculator".to_string()]);

        let token = manager.generate_token(original.clone()).unwrap();
        assert_eq!(token.session_id, "test-session");

        let restored = manager.parse_token(&token).unwrap();
        assert_eq!(restored.session_id, original.session_id);
        assert_eq!(restored.variables.len(), 1);
        match restored.variables.get("test_var") {
            Some(Value::Number(n)) => assert_eq!(n.as_i64(), Some(42)),
            _ => panic!("Expected test_var to be number 42"),
        }
    }

    #[tokio::test]
    async fn test_token_expiration() {
        // A zero TTL means the token expires at the instant it is issued.
        let manager =
            ResumeTokenManager::with_settings(ResumeTokenManager::generate_secret_key(), 0, 3600);
        let token = manager.generate_token(blank_snapshot()).unwrap();
        // Sleep past the next whole second so that `now > expires_at` holds.
        tokio::time::sleep(std::time::Duration::from_millis(1100)).await;
        let outcome = manager.parse_token(&token);
        assert!(matches!(outcome, Err(ResumeTokenError::TokenExpired)));
    }

    #[tokio::test]
    async fn test_invalid_signature() {
        let issuer = ResumeTokenManager::new(ResumeTokenManager::generate_secret_key());
        let verifier = ResumeTokenManager::new(ResumeTokenManager::generate_secret_key());
        let token = issuer.generate_token(blank_snapshot()).unwrap();
        let outcome = verifier.parse_token(&token);
        assert!(matches!(outcome, Err(ResumeTokenError::InvalidSignature)));
    }

    #[tokio::test]
    async fn test_persistent_session_manager() {
        let session_manager = PersistentSessionManager::new(ResumeTokenManager::new(
            ResumeTokenManager::generate_secret_key(),
        ));
        // Nothing has been restored yet, so there is nothing to clean up.
        assert_eq!(session_manager.cleanup_expired_sessions().await, 0);
    }
}