use std::sync::Arc;
use std::sync::RwLock;
use std::sync::atomic::{AtomicU8, Ordering};
use crate::router::Extensions;
/// Lifecycle phase of a session.
///
/// The phase is stored as a raw `u8` (hence `#[repr(u8)]`); the explicit
/// discriminants below are exactly the values written to and decoded from
/// that byte.
#[non_exhaustive]
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionPhase {
    /// Initial phase: initialization has not started.
    Uninitialized = 0,
    /// Initialization has started but not completed.
    Initializing = 1,
    /// Initialization has completed.
    Initialized = 2,
}

impl From<u8> for SessionPhase {
    /// Decodes a raw phase discriminant.
    ///
    /// `1` and `2` map to their variants; every other byte (including `0`
    /// and out-of-range values) decodes to `Uninitialized`, the phase in
    /// which `SessionState::is_request_allowed` is most restrictive.
    fn from(raw: u8) -> Self {
        match raw {
            1 => Self::Initializing,
            2 => Self::Initialized,
            _ => Self::Uninitialized,
        }
    }
}
/// Per-session state handle.
///
/// Both fields live behind an `Arc`, so the derived `Clone` yields another
/// handle onto the *same* phase and extension map rather than a copy of them.
#[derive(Clone)]
pub struct SessionState {
    /// Typed values attached to the session, looked up by their type via
    /// `get::<T>()`.
    extensions: Arc<RwLock<Extensions>>,
    /// Current `SessionPhase`, stored as its `u8` discriminant.
    phase: Arc<AtomicU8>,
}
impl Default for SessionState {
fn default() -> Self {
Self::new()
}
}
impl SessionState {
    /// Creates a session in `SessionPhase::Uninitialized` with an empty
    /// extension map.
    pub fn new() -> Self {
        Self {
            phase: Arc::new(AtomicU8::new(SessionPhase::Uninitialized as u8)),
            extensions: Arc::new(RwLock::new(Extensions::new())),
        }
    }

    /// Stores `val` in the extension map shared by every clone of this handle.
    ///
    /// Inserting a second value of the same type `T` replaces the first.
    pub fn insert<T: Send + Sync + Clone + 'static>(&self, val: T) {
        // A poisoned lock only means some thread panicked while holding the
        // write guard. Recover the guard rather than silently dropping `val`.
        let mut ext = self
            .extensions
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        ext.insert(val);
    }

    /// Returns a clone of the stored value of type `T`, or `None` if no such
    /// value has been inserted.
    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
        // Same poison recovery as `insert`: a panic elsewhere must not make
        // previously stored values invisible.
        let ext = self
            .extensions
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        ext.get::<T>().cloned()
    }

    /// Returns the current phase (loaded with `Acquire` ordering).
    pub fn phase(&self) -> SessionPhase {
        SessionPhase::from(self.phase.load(Ordering::Acquire))
    }

    /// Returns `true` once the session has reached `SessionPhase::Initialized`.
    pub fn is_initialized(&self) -> bool {
        self.phase() == SessionPhase::Initialized
    }

    /// Atomically moves `Uninitialized` -> `Initializing`.
    ///
    /// Returns `true` only if this call performed the transition, and `false`
    /// if the session had already left `Uninitialized`.
    pub fn mark_initializing(&self) -> bool {
        self.phase
            .compare_exchange(
                SessionPhase::Uninitialized as u8,
                SessionPhase::Initializing as u8,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_ok()
    }

    /// Atomically moves the session to `Initialized` from either
    /// `Uninitialized` or `Initializing`.
    ///
    /// Returns `true` if this call performed the transition, and `false` if
    /// the session was already `Initialized`.
    pub fn mark_initialized(&self) -> bool {
        // A single read-modify-write. Two back-to-back CASes could both fail
        // if another thread ran `mark_initializing` between them, spuriously
        // returning `false` while the session was still not initialized.
        self.phase
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |raw| {
                let can_advance = raw == SessionPhase::Uninitialized as u8
                    || raw == SessionPhase::Initializing as u8;
                can_advance.then_some(SessionPhase::Initialized as u8)
            })
            .is_ok()
    }

    /// Reports whether a request named `method` may be processed in the
    /// current phase.
    ///
    /// While `Uninitialized`, only `initialize`, `ping` and `server/discover`
    /// are accepted. In `Initializing` and `Initialized`, every method is
    /// accepted.
    pub fn is_request_allowed(&self, method: &str) -> bool {
        match self.phase() {
            SessionPhase::Uninitialized => {
                matches!(method, "initialize" | "ping" | "server/discover")
            }
            SessionPhase::Initializing | SessionPhase::Initialized => true,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_lifecycle() {
        let session = SessionState::new();
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(!session.is_initialized());
        assert!(session.is_request_allowed("initialize"));
        assert!(session.is_request_allowed("ping"));
        assert!(!session.is_request_allowed("tools/list"));
        assert!(session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initializing);
        assert!(!session.is_initialized());
        assert!(!session.mark_initializing());
        assert!(session.is_request_allowed("tools/list"));
        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
        assert!(session.is_initialized());
        assert!(!session.mark_initialized());
    }

    #[test]
    fn test_session_clone_shares_state() {
        let session1 = SessionState::new();
        let session2 = session1.clone();
        session1.mark_initializing();
        assert_eq!(session2.phase(), SessionPhase::Initializing);
        session2.mark_initialized();
        assert_eq!(session1.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_session_extensions_insert_and_get() {
        let session = SessionState::new();
        session.insert(42u32);
        assert_eq!(session.get::<u32>(), Some(42));
        assert_eq!(session.get::<String>(), None);
    }

    #[test]
    fn test_session_extensions_overwrite() {
        let session = SessionState::new();
        session.insert(42u32);
        assert_eq!(session.get::<u32>(), Some(42));
        session.insert(100u32);
        assert_eq!(session.get::<u32>(), Some(100));
    }

    #[test]
    fn test_session_extensions_multiple_types() {
        let session = SessionState::new();
        session.insert(42u32);
        session.insert("hello".to_string());
        session.insert(true);
        assert_eq!(session.get::<u32>(), Some(42));
        assert_eq!(session.get::<String>(), Some("hello".to_string()));
        assert_eq!(session.get::<bool>(), Some(true));
    }

    #[test]
    fn test_session_extensions_shared_across_clones() {
        let session1 = SessionState::new();
        let session2 = session1.clone();
        session1.insert(42u32);
        assert_eq!(session2.get::<u32>(), Some(42));
        session2.insert("world".to_string());
        assert_eq!(session1.get::<String>(), Some("world".to_string()));
    }

    #[test]
    fn test_mark_initialized_from_uninitialized() {
        let session = SessionState::new();
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
        assert!(session.is_initialized());
        assert!(session.is_request_allowed("tools/list"));
        assert!(session.is_request_allowed("ping"));
    }

    #[test]
    fn test_mark_initialized_idempotent_when_already_initialized() {
        let session = SessionState::new();
        session.mark_initializing();
        session.mark_initialized();
        assert_eq!(session.phase(), SessionPhase::Initialized);
        assert!(!session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_mark_initializing_rejected_once_initialized() {
        let session = SessionState::new();
        assert!(session.mark_initialized());
        assert!(!session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_default_is_uninitialized() {
        let session = SessionState::default();
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(!session.is_initialized());
        assert_eq!(session.get::<u32>(), None);
    }

    #[test]
    fn test_server_discover_allowed_before_initialization() {
        let session = SessionState::new();
        assert!(session.is_request_allowed("server/discover"));
        assert!(!session.is_request_allowed("resources/list"));
    }

    #[test]
    fn test_phase_from_u8_round_trips_and_defaults_unknown() {
        for phase in [
            SessionPhase::Uninitialized,
            SessionPhase::Initializing,
            SessionPhase::Initialized,
        ] {
            assert_eq!(SessionPhase::from(phase as u8), phase);
        }
        assert_eq!(SessionPhase::from(3), SessionPhase::Uninitialized);
        assert_eq!(SessionPhase::from(u8::MAX), SessionPhase::Uninitialized);
    }

    #[test]
    fn test_session_extensions_custom_type() {
        #[derive(Debug, Clone, PartialEq)]
        struct UserClaims {
            user_id: String,
            role: String,
        }
        let session = SessionState::new();
        session.insert(UserClaims {
            user_id: "user123".to_string(),
            role: "admin".to_string(),
        });
        let claims = session.get::<UserClaims>();
        assert!(claims.is_some());
        let claims = claims.unwrap();
        assert_eq!(claims.user_id, "user123");
        assert_eq!(claims.role, "admin");
    }
}