use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use crate::McpError;
use crate::types::{LogLevel, LoggingNotification, ProgressNotification, ServerNotification};
use super::request::RequestContext;
/// Session ID -> that session's key/value store (JSON values behind a shared `RwLock`).
type SessionStateMap = dashmap::DashMap<String, Arc<RwLock<HashMap<String, Value>>>>;

/// Process-wide session state store, built on first access.
static SESSION_STATE: std::sync::LazyLock<SessionStateMap> =
    std::sync::LazyLock::new(SessionStateMap::default);
/// Scope guard for one session's entry in the process-wide state store.
///
/// Dropping the guard removes all state stored for its session ID (see its `Drop`
/// impl). Keep it alive for as long as the session is; a guard that is dropped right
/// away discards the session's state immediately, hence `#[must_use]`.
#[derive(Debug)]
#[must_use = "dropping a SessionStateGuard immediately removes the session's state"]
pub struct SessionStateGuard {
    session_id: String,
}

impl SessionStateGuard {
    /// Creates a guard for `session_id`.
    ///
    /// Construction does not touch the state store; cleanup only happens on drop.
    pub fn new(session_id: impl Into<String>) -> Self {
        Self {
            session_id: session_id.into(),
        }
    }

    /// Returns the session ID whose state this guard cleans up.
    #[must_use]
    pub fn session_id(&self) -> &str {
        &self.session_id
    }
}
impl Drop for SessionStateGuard {
    /// Removes everything stored for this guard's session from the global store.
    fn drop(&mut self) {
        let session_id: &str = self.session_id.as_str();
        cleanup_session_state(session_id);
    }
}
/// Errors returned by the fallible session-state accessors
/// (`try_get_state` / `try_set_state`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateError {
    /// The request context has no session ID, so there is no state to address.
    NoSessionId,
    /// Converting the value to JSON failed; carries the serializer's message.
    SerializationFailed(String),
    /// Converting the stored JSON to the requested type failed; carries the
    /// deserializer's message.
    DeserializationFailed(String),
}

impl std::fmt::Display for StateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both failure variants render as "<what> failed: <detail>".
        let (prefix, detail) = match self {
            Self::NoSessionId => return f.write_str("no session ID set on context"),
            Self::SerializationFailed(detail) => ("serialization failed", detail),
            Self::DeserializationFailed(detail) => ("deserialization failed", detail),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for StateError {}
/// Convenience extensions for request contexts: per-session typed state,
/// client-facing logging, and progress reporting.
///
/// State values are stored as JSON (`serde_json::Value`), keyed by the context's
/// session ID and then by a string key.
pub trait RichContextExt {
    /// Reads `key` from the session state and deserializes it as `T`.
    ///
    /// Lossy variant of [`RichContextExt::try_get_state`]: in the `RequestContext`
    /// implementation, a missing session ID, a missing key, and a deserialization
    /// failure all yield `None`.
    fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T>;
    /// Reads `key` from the session state and deserializes it as `T`.
    ///
    /// Returns `Ok(None)` when the session or key has no stored value.
    ///
    /// # Errors
    /// [`StateError::NoSessionId`] if the context has no session ID;
    /// [`StateError::DeserializationFailed`] if the stored value is not a valid `T`.
    fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError>;
    /// Serializes `value` and stores it under `key`; returns `true` on success.
    ///
    /// Lossy variant of [`RichContextExt::try_set_state`].
    fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool;
    /// Serializes `value` to JSON and stores it under `key`, replacing any previous
    /// value and creating the session's state map if needed.
    ///
    /// # Errors
    /// [`StateError::NoSessionId`] if the context has no session ID;
    /// [`StateError::SerializationFailed`] if `value` cannot be converted to JSON.
    fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError>;
    /// Removes `key` from the session state.
    ///
    /// Returns `false` if the context has no session ID or the session has no stored
    /// state at all.
    fn remove_state(&self, key: &str) -> bool;
    /// Removes every key from the session state (no-op without a session ID).
    fn clear_state(&self);
    /// Returns `true` if the session state currently contains `key`.
    fn has_state(&self, key: &str) -> bool;
    /// Sends `message` to the client as a log notification at `Debug` level.
    fn debug(
        &self,
        message: impl Into<String> + Send,
    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
    /// Sends `message` to the client as a log notification at `Info` level.
    fn info(
        &self,
        message: impl Into<String> + Send,
    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
    /// Sends `message` to the client as a log notification at `Warning` level.
    fn warning(
        &self,
        message: impl Into<String> + Send,
    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
    /// Sends `message` to the client as a log notification at `Error` level.
    fn error(
        &self,
        message: impl Into<String> + Send,
    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
    /// Sends a log notification with an explicit `level` and optional `logger` name.
    ///
    /// In the `RequestContext` implementation this resolves to `Ok(())` without
    /// sending anything when no server-to-client channel is attached.
    fn log(
        &self,
        level: LogLevel,
        message: impl Into<String> + Send,
        logger: Option<String>,
    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
    /// Reports progress `current` out of `total`, optionally with a status `message`.
    ///
    /// The `RequestContext` implementation uses the request ID as the progress token.
    fn report_progress(
        &self,
        current: u64,
        total: u64,
        message: Option<&str>,
    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
    /// Reports progress under an explicit progress `token`; `total` may be unknown
    /// (`None`).
    fn report_progress_with_token(
        &self,
        token: impl Into<String> + Send,
        current: u64,
        total: Option<u64>,
        message: Option<&str>,
    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
}
impl RichContextExt for RequestContext {
fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
self.try_get_state(key).ok().flatten()
}
fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError> {
let session_id = self.session_id.as_ref().ok_or(StateError::NoSessionId)?;
let Some(state) = SESSION_STATE.get(session_id) else {
return Ok(None);
};
let state_read = state.read();
let Some(value) = state_read.get(key) else {
return Ok(None);
};
serde_json::from_value(value.clone())
.map(Some)
.map_err(|e| StateError::DeserializationFailed(e.to_string()))
}
fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool {
self.try_set_state(key, value).is_ok()
}
fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError> {
let session_id = self.session_id.as_ref().ok_or(StateError::NoSessionId)?;
let json_value = serde_json::to_value(value)
.map_err(|e| StateError::SerializationFailed(e.to_string()))?;
let state = SESSION_STATE
.entry(session_id.clone())
.or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
state.write().insert(key.to_string(), json_value);
Ok(())
}
fn remove_state(&self, key: &str) -> bool {
let Some(ref session_id) = self.session_id else {
return false;
};
if let Some(state) = SESSION_STATE.get(session_id) {
state.write().remove(key);
return true;
}
false
}
fn clear_state(&self) {
if let Some(ref session_id) = self.session_id
&& let Some(state) = SESSION_STATE.get(session_id)
{
state.write().clear();
}
}
fn has_state(&self, key: &str) -> bool {
if let Some(ref session_id) = self.session_id
&& let Some(state) = SESSION_STATE.get(session_id)
{
return state.read().contains_key(key);
}
false
}
async fn debug(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
self.log(LogLevel::Debug, message, None).await
}
async fn info(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
self.log(LogLevel::Info, message, None).await
}
async fn warning(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
self.log(LogLevel::Warning, message, None).await
}
async fn error(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
self.log(LogLevel::Error, message, None).await
}
async fn log(
&self,
level: LogLevel,
message: impl Into<String> + Send,
logger: Option<String>,
) -> Result<(), McpError> {
let Some(s2c) = self.server_to_client() else {
return Ok(());
};
let notification = ServerNotification::Message(LoggingNotification {
level,
data: serde_json::Value::String(message.into()),
logger,
});
s2c.send_notification(notification).await
}
async fn report_progress(
&self,
current: u64,
total: u64,
message: Option<&str>,
) -> Result<(), McpError> {
self.report_progress_with_token(&self.request_id, current, Some(total), message)
.await
}
async fn report_progress_with_token(
&self,
token: impl Into<String> + Send,
current: u64,
total: Option<u64>,
message: Option<&str>,
) -> Result<(), McpError> {
let Some(s2c) = self.server_to_client() else {
return Ok(());
};
let notification = ServerNotification::Progress(ProgressNotification {
progress_token: token.into(),
progress: current,
total,
message: message.map(ToString::to_string),
});
s2c.send_notification(notification).await
}
}
pub fn cleanup_session_state(session_id: &str) {
SESSION_STATE.remove(session_id);
}
pub fn active_sessions_count() -> usize {
SESSION_STATE.len()
}
/// Test-only helper: removes every session's state from the global store.
#[cfg(test)]
pub fn clear_all_session_state() {
    let store: &SessionStateMap = &SESSION_STATE;
    store.clear();
}
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that write state hold a `SessionStateGuard` so the global store is
    // cleaned up even when an assertion panics; a trailing `cleanup_session_state`
    // call would be skipped on failure and leak entries.

    #[test]
    fn test_get_set_state() {
        let _guard = SessionStateGuard::new("test-session-1");
        let ctx = RequestContext::new().with_session_id("test-session-1");
        assert!(ctx.set_state("counter", &42i32));
        assert!(ctx.set_state("name", &"Alice".to_string()));
        assert_eq!(ctx.get_state::<i32>("counter"), Some(42));
        assert_eq!(ctx.get_state::<String>("name"), Some("Alice".to_string()));
        assert_eq!(ctx.get_state::<i32>("missing"), None);
        assert!(ctx.has_state("counter"));
        assert!(!ctx.has_state("missing"));
        assert!(ctx.remove_state("counter"));
        assert_eq!(ctx.get_state::<i32>("counter"), None);
        assert!(!ctx.has_state("counter"));
        ctx.clear_state();
        assert_eq!(ctx.get_state::<String>("name"), None);
    }

    #[test]
    fn test_state_without_session() {
        let ctx = RequestContext::new();
        assert!(!ctx.set_state("key", &"value"));
        assert_eq!(ctx.get_state::<String>("key"), None);
        assert!(!ctx.has_state("key"));
        assert!(!ctx.remove_state("key"));
        // Must be a harmless no-op when there is no session ID.
        ctx.clear_state();
        assert_eq!(
            ctx.try_set_state("key", &"value"),
            Err(StateError::NoSessionId)
        );
        assert_eq!(
            ctx.try_get_state::<String>("key"),
            Err(StateError::NoSessionId)
        );
    }

    #[test]
    fn test_state_isolation() {
        let _guard1 = SessionStateGuard::new("session-iso-1");
        let _guard2 = SessionStateGuard::new("session-iso-2");
        let ctx1 = RequestContext::new().with_session_id("session-iso-1");
        let ctx2 = RequestContext::new().with_session_id("session-iso-2");
        assert!(ctx1.set_state("value", &1i32));
        assert!(ctx2.set_state("value", &2i32));
        assert_eq!(ctx1.get_state::<i32>("value"), Some(1));
        assert_eq!(ctx2.get_state::<i32>("value"), Some(2));
    }

    #[test]
    fn test_complex_types() {
        #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
        struct MyData {
            count: i32,
            items: Vec<String>,
        }

        let _guard = SessionStateGuard::new("complex-session-1");
        let ctx = RequestContext::new().with_session_id("complex-session-1");
        let data = MyData {
            count: 3,
            items: vec!["a".to_string(), "b".to_string(), "c".to_string()],
        };
        assert!(ctx.set_state("data", &data));
        let retrieved: Option<MyData> = ctx.get_state("data");
        assert_eq!(retrieved, Some(data));
    }

    #[test]
    fn test_session_state_guard() {
        let session_id = "guard-test-session";
        {
            let _guard = SessionStateGuard::new(session_id);
            let ctx = RequestContext::new().with_session_id(session_id);
            assert!(ctx.set_state("key", &"value"));
            assert_eq!(ctx.get_state::<String>("key"), Some("value".to_string()));
            assert!(active_sessions_count() > 0);
        }
        // The guard has been dropped, so the session's state must be gone.
        let ctx = RequestContext::new().with_session_id(session_id);
        assert_eq!(ctx.get_state::<String>("key"), None);
    }

    #[test]
    fn test_try_get_state_errors() {
        let _guard = SessionStateGuard::new("error-test-session");
        let ctx = RequestContext::new().with_session_id("error-test-session");
        assert!(ctx.set_state("number", &42i32));
        let result: Result<Option<String>, StateError> = ctx.try_get_state("number");
        assert!(matches!(result, Err(StateError::DeserializationFailed(_))));
        // The lossy accessor hides the type mismatch as `None`.
        assert_eq!(ctx.get_state::<String>("number"), None);
    }

    #[test]
    fn test_state_error_display() {
        assert_eq!(
            StateError::NoSessionId.to_string(),
            "no session ID set on context"
        );
        assert!(
            StateError::SerializationFailed("test".into())
                .to_string()
                .contains("serialization failed")
        );
        assert!(
            StateError::DeserializationFailed("test".into())
                .to_string()
                .contains("deserialization failed")
        );
    }

    #[tokio::test]
    async fn test_logging_without_server_to_client() {
        let ctx = RequestContext::new().with_session_id("logging-test");
        assert!(ctx.debug("debug message").await.is_ok());
        assert!(ctx.info("info message").await.is_ok());
        assert!(ctx.warning("warning message").await.is_ok());
        assert!(ctx.error("error message").await.is_ok());
        assert!(ctx.log(LogLevel::Notice, "notice", None).await.is_ok());
    }

    #[tokio::test]
    async fn test_progress_without_server_to_client() {
        let ctx = RequestContext::new().with_session_id("progress-test");
        assert!(ctx.report_progress(50, 100, Some("halfway")).await.is_ok());
        assert!(ctx.report_progress(100, 100, None).await.is_ok());
        assert!(
            ctx.report_progress_with_token("custom-token", 25, Some(100), Some("processing"))
                .await
                .is_ok()
        );
    }
}