use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::{RwLock, broadcast};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use turul_mcp_protocol::{ClientCapabilities, Implementation, McpVersion, ServerCapabilities};
use turul_mcp_session_storage::{SessionStorage, SessionStorageError, SessionView};
/// Boxed, `Send`, `'static` future returned by the async closures stored on `SessionContext`.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// Async state getter: looks up `key`, yielding `None` when no value is stored.
type GetStateFn = Arc<dyn Fn(&str) -> BoxFuture<Option<Value>> + Send + Sync>;
/// Async state setter: stores `value` under `key`.
type SetStateFn = Arc<dyn Fn(&str, Value) -> BoxFuture<()> + Send + Sync>;
/// Async state remover: deletes `key`, yielding the removed value if there was one.
type RemoveStateFn = Arc<dyn Fn(&str) -> BoxFuture<Option<Value>> + Send + Sync>;
/// Per-request view of an MCP session.
///
/// State access and notification delivery are stored as boxed async closures, so the
/// same type can be backed by a storage backend or by an in-memory map (as in tests).
#[derive(Clone)]
pub struct SessionContext {
    /// Identifier of the session this context belongs to.
    pub session_id: String,
    /// Reads a session state value by key.
    pub get_state: GetStateFn,
    /// Writes a session state value.
    pub set_state: SetStateFn,
    /// Removes a session state value, returning it if present.
    pub remove_state: RemoveStateFn,
    /// Reports whether the session has completed initialization.
    pub is_initialized: Arc<dyn Fn() -> BoxFuture<bool> + Send + Sync>,
    /// Delivers a `SessionEvent` to the session's client.
    pub send_notification: Arc<dyn Fn(SessionEvent) -> BoxFuture<()> + Send + Sync>,
    /// Type-erased notification broadcaster. The storage-backed notifier expects it to
    /// downcast to `SharedNotificationBroadcaster`.
    pub broadcaster: Option<Arc<dyn std::any::Any + Send + Sync>>,
    /// Free-form JSON extensions carried over from the JSON-RPC layer.
    pub extensions: HashMap<String, Value>,
}
impl SessionContext {
pub(crate) fn from_json_rpc_with_broadcaster(
json_rpc_ctx: turul_mcp_json_rpc_server::SessionContext,
storage: Arc<dyn SessionStorage<Error = SessionStorageError>>,
) -> Self {
let session_id = json_rpc_ctx.session_id.clone();
let broadcaster = json_rpc_ctx.broadcaster.clone();
let get_state = {
let storage = storage.clone();
let session_id = session_id.clone();
Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
let storage = storage.clone();
let session_id = session_id.clone();
let key = key.to_string();
Box::pin(async move {
match storage.get_session_state(&session_id, &key).await {
Ok(Some(value)) => Some(value),
Ok(None) => None,
Err(e) => {
tracing::warn!("Failed to get session state for key '{}': {}", key, e);
None
}
}
})
})
};
let set_state = {
let storage = storage.clone();
let session_id = session_id.clone();
Arc::new(move |key: &str, value: Value| -> BoxFuture<()> {
let storage = storage.clone();
let session_id = session_id.clone();
let key = key.to_string();
Box::pin(async move {
if let Err(e) = storage.set_session_state(&session_id, &key, value).await {
tracing::error!("Failed to set session state for key '{}': {}", key, e);
}
})
})
};
let remove_state = {
let storage = storage.clone();
let session_id = session_id.clone();
Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
let storage = storage.clone();
let session_id = session_id.clone();
let key = key.to_string();
Box::pin(async move {
match storage.remove_session_state(&session_id, &key).await {
Ok(value) => value,
Err(e) => {
tracing::warn!(
"Failed to remove session state for key '{}': {}",
key,
e
);
None
}
}
})
})
};
let is_initialized = {
let storage = storage.clone();
let session_id = session_id.clone();
Arc::new(move || -> BoxFuture<bool> {
let storage = storage.clone();
let session_id = session_id.clone();
Box::pin(async move {
match storage.get_session(&session_id).await {
Ok(Some(session_info)) => session_info.is_initialized,
Ok(None) => {
tracing::warn!("Session {} not found in storage", session_id);
false
}
Err(e) => {
tracing::error!("Failed to check session initialization: {}", e);
false
}
}
})
})
};
let send_notification = {
let session_id = session_id.clone();
let broadcaster = broadcaster.clone();
Arc::new(move |event: SessionEvent| -> BoxFuture<()> {
let session_id = session_id.clone();
let broadcaster = broadcaster.clone();
Box::pin(async move {
debug!(
"📨 SessionContext.send_notification() called for session {}: {:?}",
session_id, event
);
if let Some(broadcaster_any) = &broadcaster {
debug!(
"✅ NotificationBroadcaster available for session: {}",
session_id
);
match event {
SessionEvent::Notification(json_value) => {
debug!(
"🔧 Attempting to send notification via StreamManagerNotificationBroadcaster"
);
debug!("📦 Notification JSON: {}", json_value);
match parse_and_send_notification_with_broadcaster(
&session_id,
&json_value,
broadcaster_any,
)
.await
{
Ok(_) => debug!(
"✅ Bridge working: Successfully processed notification for session {}",
session_id
),
Err(e) => error!(
"❌ Bridge error: Failed to process notification for session {}: {}",
session_id, e
),
}
}
_ => {
debug!("⚠️ Non-notification event, ignoring: {:?}", event);
}
}
} else {
debug!("⚠️ No broadcaster available for session {}", session_id);
}
})
})
};
SessionContext {
session_id,
get_state,
set_state,
remove_state,
is_initialized,
send_notification,
broadcaster,
extensions: json_rpc_ctx.extensions,
}
}
pub fn has_broadcaster(&self) -> bool {
self.broadcaster.is_some()
}
pub fn get_raw_broadcaster(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
self.broadcaster.clone()
}
pub fn get_extension(&self, key: &str) -> Option<&Value> {
self.extensions.get(key)
}
pub fn get_typed_extension<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
self.extensions
.get(key)
.and_then(|v| serde_json::from_value(v.clone()).ok())
}
#[cfg(feature = "test-utils")]
pub fn from_json_rpc_with_broadcaster_for_tests(
json_rpc_ctx: turul_mcp_json_rpc_server::SessionContext,
storage: Arc<dyn SessionStorage<Error = SessionStorageError>>,
) -> Self {
Self::from_json_rpc_with_broadcaster(json_rpc_ctx, storage)
}
pub async fn get_typed_state<T>(&self, key: &str) -> Option<T>
where
T: serde::de::DeserializeOwned,
{
(self.get_state)(key)
.await
.and_then(|v| serde_json::from_value(v).ok())
}
pub async fn set_typed_state<T>(&self, key: &str, value: T) -> Result<(), String>
where
T: serde::Serialize,
{
match serde_json::to_value(value) {
Ok(json_value) => {
(self.set_state)(key, json_value).await;
Ok(())
}
Err(e) => Err(format!("Failed to serialize value: {}", e)),
}
}
#[cfg(test)]
pub fn new_test() -> Self {
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
let state = Arc::new(RwLock::new(HashMap::<String, Value>::new()));
let get_state = {
let state = state.clone();
Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
let state = state.clone();
let key = key.to_string();
Box::pin(async move { state.read().await.get(&key).cloned() })
})
};
let set_state = {
let state = state.clone();
Arc::new(move |key: &str, value: Value| -> BoxFuture<()> {
let state = state.clone();
let key = key.to_string();
Box::pin(async move {
state.write().await.insert(key, value);
})
})
};
let remove_state = {
let state = state.clone();
Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
let state = state.clone();
let key = key.to_string();
Box::pin(async move { state.write().await.remove(&key) })
})
};
let is_initialized = Arc::new(|| -> BoxFuture<bool> { Box::pin(async { true }) });
let send_notification =
Arc::new(|_event: SessionEvent| -> BoxFuture<()> { Box::pin(async {}) });
SessionContext {
session_id: Uuid::now_v7().as_simple().to_string(),
get_state,
set_state,
remove_state,
is_initialized,
send_notification,
broadcaster: None,
extensions: HashMap::new(),
}
}
pub async fn notify(&self, event: SessionEvent) {
debug!(
"📨 SessionContext.notify() called for session {}: {:?}",
self.session_id, event
);
(self.send_notification)(event).await;
debug!("🚀 SessionContext.notify() send_notification closure completed");
}
pub async fn notify_progress(&self, progress_token: impl Into<String>, progress: u64) {
if self.has_broadcaster() {
debug!(
"🔔 notify_progress using NotificationBroadcaster for session: {}",
self.session_id
);
} else {
debug!(
"🔔 notify_progress using OLD SessionManager for session: {}",
self.session_id
);
}
let mut other = std::collections::HashMap::new();
other.insert(
"progressToken".to_string(),
serde_json::json!(progress_token.into()),
);
other.insert("progress".to_string(), serde_json::json!(progress));
let params = turul_mcp_protocol::RequestParams { meta: None, other };
let notification =
turul_mcp_protocol::JsonRpcNotification::new("notifications/progress".to_string())
.with_params(params);
self.notify(SessionEvent::Notification(
serde_json::to_value(notification).unwrap(),
))
.await;
}
pub async fn notify_progress_with_total(
&self,
progress_token: impl Into<String>,
progress: u64,
total: u64,
) {
let mut other = std::collections::HashMap::new();
other.insert(
"progressToken".to_string(),
serde_json::json!(progress_token.into()),
);
other.insert("progress".to_string(), serde_json::json!(progress));
other.insert("total".to_string(), serde_json::json!(total));
let params = turul_mcp_protocol::RequestParams { meta: None, other };
let notification =
turul_mcp_protocol::JsonRpcNotification::new("notifications/progress".to_string())
.with_params(params);
self.notify(SessionEvent::Notification(
serde_json::to_value(notification).unwrap(),
))
.await;
}
pub async fn notify_log(
&self,
level: turul_mcp_protocol::logging::LoggingLevel,
data: serde_json::Value,
logger: Option<String>,
meta: Option<std::collections::HashMap<String, serde_json::Value>>,
) {
let message_level = level;
if !self.should_log(message_level).await {
let threshold = self.get_logging_level().await;
debug!(
"🔕 Filtering out {:?} level message for session {} (threshold: {:?})",
message_level, self.session_id, threshold
);
return;
}
let threshold = self.get_logging_level().await;
debug!(
"📢 Sending {:?} level message to session {} (threshold: {:?})",
message_level, self.session_id, threshold
);
use turul_mcp_protocol::notifications::LoggingMessageNotification;
let mut notification = LoggingMessageNotification::new(message_level, data);
if let Some(logger) = logger {
notification = notification.with_logger(logger);
}
if let Some(meta) = meta {
notification = notification.with_meta(meta);
}
if self.has_broadcaster() {
debug!(
"🔔 notify_log using NotificationBroadcaster for session: {}",
self.session_id
);
self.notify(SessionEvent::Notification(
serde_json::to_value(notification).unwrap(),
))
.await;
return;
} else {
debug!(
"🔔 notify_log using OLD SessionManager for session: {}",
self.session_id
);
}
self.notify(SessionEvent::Notification(
serde_json::to_value(notification).unwrap(),
))
.await;
}
pub async fn notify_resources_changed(&self) {
let notification = turul_mcp_protocol::JsonRpcNotification::new(
"notifications/resources/list_changed".to_string(),
);
self.notify(SessionEvent::Notification(
serde_json::to_value(notification).unwrap(),
))
.await;
}
pub async fn notify_resource_updated(&self, uri: impl Into<String>) {
let mut other = std::collections::HashMap::new();
other.insert("uri".to_string(), serde_json::json!(uri.into()));
let params = turul_mcp_protocol::RequestParams { meta: None, other };
let notification = turul_mcp_protocol::JsonRpcNotification::new(
"notifications/resources/updated".to_string(),
)
.with_params(params);
self.notify(SessionEvent::Notification(
serde_json::to_value(notification).unwrap(),
))
.await;
}
pub async fn notify_tools_changed(&self) {
let notification = turul_mcp_protocol::JsonRpcNotification::new(
"notifications/tools/list_changed".to_string(),
);
self.notify(SessionEvent::Notification(
serde_json::to_value(notification).unwrap(),
))
.await;
}
pub async fn get_logging_level(&self) -> turul_mcp_protocol::logging::LoggingLevel {
use turul_mcp_protocol::logging::LoggingLevel;
if let Some(level_value) = (self.get_state)("mcp:logging:level").await {
if let Some(level_str) = level_value.as_str() {
match level_str {
"debug" => LoggingLevel::Debug,
"info" => LoggingLevel::Info,
"notice" => LoggingLevel::Notice,
"warning" => LoggingLevel::Warning,
"error" => LoggingLevel::Error,
"critical" => LoggingLevel::Critical,
"alert" => LoggingLevel::Alert,
"emergency" => LoggingLevel::Emergency,
_ => LoggingLevel::Info, }
} else {
LoggingLevel::Info }
} else {
LoggingLevel::Info }
}
pub async fn set_logging_level(&self, level: turul_mcp_protocol::logging::LoggingLevel) {
use turul_mcp_protocol::logging::LoggingLevel;
let level_str = match level {
LoggingLevel::Debug => "debug",
LoggingLevel::Info => "info",
LoggingLevel::Notice => "notice",
LoggingLevel::Warning => "warning",
LoggingLevel::Error => "error",
LoggingLevel::Critical => "critical",
LoggingLevel::Alert => "alert",
LoggingLevel::Emergency => "emergency",
};
(self.set_state)("mcp:logging:level", serde_json::json!(level_str)).await;
debug!(
"🎯 Set logging level for session {}: {:?}",
self.session_id, level
);
}
pub async fn should_log(
&self,
message_level: turul_mcp_protocol::logging::LoggingLevel,
) -> bool {
let session_threshold = self.get_logging_level().await;
message_level.should_log(session_threshold)
}
pub fn should_log_sync(
&self,
message_level: turul_mcp_protocol::logging::LoggingLevel,
) -> bool {
let session_level = futures::executor::block_on(self.get_logging_level());
message_level.should_log(session_level)
}
}
/// `SessionView` adapter over the closure-based accessors of `SessionContext`.
///
/// The accessors never report failures, so every method returns `Ok`. Metadata lives in
/// the regular state map under keys prefixed with `__meta__:`.
#[async_trait]
impl SessionView for SessionContext {
    fn session_id(&self) -> &str {
        self.session_id.as_str()
    }

    async fn get_state(&self, key: &str) -> Result<Option<Value>, String> {
        let stored = (self.get_state)(key).await;
        Ok(stored)
    }

    async fn set_state(&self, key: &str, value: Value) -> Result<(), String> {
        (self.set_state)(key, value).await;
        Ok(())
    }

    async fn get_metadata(&self, key: &str) -> Result<Option<Value>, String> {
        let namespaced_key = format!("__meta__:{}", key);
        let stored = (self.get_state)(namespaced_key.as_str()).await;
        Ok(stored)
    }

    async fn set_metadata(&self, key: &str, value: Value) -> Result<(), String> {
        let namespaced_key = format!("__meta__:{}", key);
        (self.set_state)(namespaced_key.as_str(), value).await;
        Ok(())
    }
}
impl turul_mcp_builders::logging::LoggingTarget for SessionContext {
fn should_log(&self, level: turul_mcp_protocol::logging::LoggingLevel) -> bool {
self.should_log_sync(level)
}
fn notify_log(
&self,
level: turul_mcp_protocol::logging::LoggingLevel,
data: serde_json::Value,
logger: Option<String>,
meta: Option<std::collections::HashMap<String, serde_json::Value>>,
) {
let session_ctx = self.clone();
tokio::spawn(async move {
session_ctx.notify_log(level, data, logger, meta).await;
});
}
}
async fn parse_and_send_notification_with_broadcaster(
session_id: &str,
json_value: &Value,
broadcaster_any: &Arc<dyn std::any::Any + Send + Sync>,
) -> Result<(), String> {
debug!(
"🔍 Parsing notification JSON for session {}: {:?}",
session_id, json_value
);
use turul_http_mcp_server::notification_bridge::SharedNotificationBroadcaster;
use turul_mcp_protocol::notifications::{LoggingMessageNotification, ProgressNotification};
debug!(
"🔍 Attempting downcast for session {}, broadcaster type: {:?}",
session_id,
std::any::type_name::<SharedNotificationBroadcaster>()
);
if let Some(broadcaster) = broadcaster_any.downcast_ref::<SharedNotificationBroadcaster>() {
debug!(
"✅ Successfully downcast broadcaster for session {}",
session_id
);
if let Some(method) = json_value.get("method").and_then(|v| v.as_str()) {
match method {
"notifications/message" => {
debug!(
"📝 Message notification detected, deserializing directly to LoggingMessageNotification"
);
match serde_json::from_value::<LoggingMessageNotification>(json_value.clone()) {
Ok(notification) => {
debug!(
"✅ Successfully deserialized LoggingMessageNotification: level={:?}, logger={:?}",
notification.params.level, notification.params.logger
);
debug!(
"🔧 About to call broadcaster.send_message_notification() for session {}",
session_id
);
match broadcaster
.send_message_notification(session_id, notification)
.await
{
Ok(()) => {
debug!(
"🎉 SUCCESS: LoggingMessageNotification sent to StreamManager for session {}",
session_id
);
debug!(
"🚀 Streamable HTTP Transport Bridge: Complete end-to-end delivery confirmed!"
);
return Ok(());
}
Err(e) => {
error!(
"❌ Failed to send LoggingMessageNotification to StreamManager: {}",
e
);
return Err(format!(
"Failed to send LoggingMessageNotification: {}",
e
));
}
}
}
Err(e) => {
error!("❌ Failed to deserialize LoggingMessageNotification: {}", e);
return Err(format!(
"Failed to deserialize LoggingMessageNotification: {}",
e
));
}
}
}
"notifications/progress" => {
if let Some(params) = json_value.get("params")
&& let Some(token) = params.get("progressToken").and_then(|v| v.as_str())
{
debug!("📊 Progress notification detected: token={}", token);
let progress = params
.get("progress")
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
let notification = ProgressNotification {
method: "notifications/progress".to_string(),
params: turul_mcp_protocol::notifications::ProgressNotificationParams {
progress_token: token.to_string().into(),
progress,
total: params.get("total").and_then(|v| v.as_f64()),
message: params
.get("message")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
meta: None,
},
};
debug!(
"🔧 About to call broadcaster.send_progress_notification() for session {}",
session_id
);
match broadcaster
.send_progress_notification(session_id, notification)
.await
{
Ok(()) => {
debug!(
"🎉 SUCCESS: ProgressNotification sent to StreamManager for session {}",
session_id
);
debug!(
"🚀 Streamable HTTP Transport Bridge: Complete end-to-end delivery confirmed!"
);
return Ok(());
}
Err(e) => {
error!(
"❌ Failed to send ProgressNotification to StreamManager: {}",
e
);
return Err(format!("Failed to send ProgressNotification: {}", e));
}
}
}
}
_ => {
debug!(
"🔧 Other notification method: {} - sending as generic JsonRpcNotification",
method
);
let params_map: std::collections::HashMap<String, serde_json::Value> =
json_value
.get("params")
.and_then(|p| p.as_object())
.unwrap_or(&serde_json::Map::new())
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
let json_rpc_notification =
turul_mcp_json_rpc_server::JsonRpcNotification::new_with_object_params(
method.to_string(),
params_map,
);
match broadcaster
.send_notification(session_id, json_rpc_notification)
.await
{
Ok(()) => {
debug!(
"🎉 SUCCESS: Generic notification sent to StreamManager for session {}",
session_id
);
return Ok(());
}
Err(e) => {
error!(
"❌ Failed to send generic notification to StreamManager: {}",
e
);
return Err(format!("Failed to send generic notification: {}", e));
}
}
}
}
}
} else {
error!(
"❌ Failed to downcast broadcaster for session {}",
session_id
);
return Err("Failed to downcast broadcaster to SharedNotificationBroadcaster".to_string());
}
debug!(
"❓ Could not determine notification type for session {}",
session_id
);
Ok(())
}
/// Events that can be delivered to a session.
#[derive(Debug, Clone)]
pub enum SessionEvent {
    /// A serialized JSON-RPC notification (expected to carry a `method` field).
    Notification(Value),
    /// NOTE(review): presumably a keep-alive/heartbeat signal; no producer is visible here.
    KeepAlive,
    /// Sent by `SessionManager` when it removes a session, or expires one, from its cache.
    Disconnect,
    /// Application-defined event with a free-form type tag and JSON payload.
    Custom { event_type: String, data: Value },
}
/// In-memory record of one MCP session, as cached by `SessionManager`.
#[derive(Debug)]
pub struct McpSession {
    /// Session identifier; `McpSession::new` generates a UUIDv7 in simple (hyphen-less) form.
    pub id: String,
    /// Creation time, or a time reconstructed from storage timestamps.
    pub created: Instant,
    /// Last time the session was touched; `is_expired` measures from here.
    pub last_accessed: Instant,
    /// Negotiated MCP protocol version; starts as `McpVersion::CURRENT`.
    pub mcp_version: McpVersion,
    /// Capabilities reported by the client once initialized.
    pub client_capabilities: Option<ClientCapabilities>,
    /// Capabilities this server advertises for the session.
    pub server_capabilities: ServerCapabilities,
    /// Client implementation info recorded during initialization.
    pub client_info: Option<Implementation>,
    /// Per-session key/value state.
    pub state: HashMap<String, Value>,
    /// Broadcast channel for `SessionEvent`s; listeners attach via `subscribe_events`.
    pub event_sender: broadcast::Sender<SessionEvent>,
    /// Whether initialization has completed.
    pub initialized: bool,
}
impl McpSession {
pub fn new(server_capabilities: ServerCapabilities) -> Self {
let session_id = Uuid::now_v7().as_simple().to_string();
let (event_sender, _) = broadcast::channel(128);
Self {
id: session_id,
created: Instant::now(),
last_accessed: Instant::now(),
mcp_version: McpVersion::CURRENT,
client_capabilities: None,
server_capabilities,
client_info: None,
state: HashMap::new(),
event_sender,
initialized: false,
}
}
pub fn touch(&mut self) {
self.last_accessed = Instant::now();
}
pub fn is_expired(&self, timeout: Duration) -> bool {
self.last_accessed.elapsed() > timeout
}
pub fn initialize(
&mut self,
client_info: Implementation,
client_capabilities: ClientCapabilities,
) {
self.client_info = Some(client_info);
self.client_capabilities = Some(client_capabilities);
self.initialized = true;
self.touch();
}
pub fn initialize_with_version(
&mut self,
client_info: Implementation,
client_capabilities: ClientCapabilities,
mcp_version: McpVersion,
) {
self.client_info = Some(client_info);
self.client_capabilities = Some(client_capabilities);
self.mcp_version = mcp_version;
self.initialized = true;
self.touch();
}
pub fn get_state(&self, key: &str) -> Option<Value> {
self.state.get(key).cloned()
}
pub fn set_state(&mut self, key: &str, value: Value) {
self.state.insert(key.to_string(), value);
self.touch();
}
pub fn remove_state(&mut self, key: &str) -> Option<Value> {
let result = self.state.remove(key);
if result.is_some() {
self.touch();
}
result
}
pub fn send_event(&self, event: SessionEvent) -> Result<(), String> {
self.event_sender
.send(event)
.map_err(|e| format!("Failed to send event: {}", e))?;
Ok(())
}
pub fn subscribe_events(&self) -> broadcast::Receiver<SessionEvent> {
self.event_sender.subscribe()
}
}
/// Errors returned by `SessionManager` operations.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session with the given id was found (in storage or the cache, depending on the
    /// operation).
    #[error("Session not found: {0}")]
    NotFound(String),
    /// The session exists but its inactivity timeout has elapsed.
    #[error("Session expired: {0}")]
    Expired(String),
    /// NOTE(review): not constructed in the code visible here; presumably for operations
    /// that require a completed initialization.
    #[error("Session not initialized: {0}")]
    NotInitialized(String),
    /// Used when a per-session event send fails (see `SessionManager::send_event_to_session`).
    #[error("Invalid session data: {0}")]
    InvalidData(String),
    /// The storage backend reported an error; the payload is its message.
    #[error("Storage error: {0}")]
    StorageError(String),
}
/// Tracks MCP sessions in a pluggable storage backend plus an in-memory cache.
///
/// Most lookups consult `storage` first and fall back to the cache only when the backend
/// returns an error.
pub struct SessionManager {
    /// Storage backend; treated as authoritative by the lookup methods.
    storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
    /// In-memory cache of sessions keyed by session id; also owns per-session event channels.
    sessions: RwLock<HashMap<String, McpSession>>,
    /// Inactivity period after which a session is considered expired.
    session_timeout: Duration,
    /// NOTE(review): not read in the code visible here; presumably the interval of a
    /// periodic `cleanup_expired` sweep.
    cleanup_interval: Duration,
    /// Capabilities for newly created sessions, and for loaded sessions that lack stored ones.
    default_capabilities: ServerCapabilities,
    /// Fan-out of `(session_id, event)` pairs for events sent via `send_event_to_session`.
    global_event_sender: broadcast::Sender<(String, SessionEvent)>,
}
impl SessionManager {
/// Creates a manager backed by a fresh in-memory storage backend, with the default
/// timeouts of [`Self::with_storage`] (30-minute session timeout, 60-second cleanup interval).
pub fn new(default_capabilities: ServerCapabilities) -> Self {
    let in_memory: Arc<turul_mcp_session_storage::BoxedSessionStorage> =
        Arc::new(turul_mcp_session_storage::InMemorySessionStorage::new());
    Self::with_storage(in_memory, default_capabilities)
}
/// Creates a manager backed by a fresh in-memory storage backend, with explicit timeouts.
pub fn with_timeouts(
    default_capabilities: ServerCapabilities,
    session_timeout: Duration,
    cleanup_interval: Duration,
) -> Self {
    let in_memory_backend: Arc<turul_mcp_session_storage::BoxedSessionStorage> =
        Arc::new(turul_mcp_session_storage::InMemorySessionStorage::new());
    Self::with_storage_and_timeouts(
        in_memory_backend,
        default_capabilities,
        session_timeout,
        cleanup_interval,
    )
}
/// Creates a manager over `storage` with a 30-minute session timeout and a 60-second
/// cleanup interval.
pub fn with_storage(
    storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
    default_capabilities: ServerCapabilities,
) -> Self {
    let session_timeout = Duration::from_secs(30 * 60);
    let cleanup_interval = Duration::from_secs(60);
    Self::with_storage_and_timeouts(
        storage,
        default_capabilities,
        session_timeout,
        cleanup_interval,
    )
}
/// Fully explicit constructor: storage backend, default capabilities, and both timeouts.
/// The cache starts empty, and the global event channel holds up to 1000 pending events.
pub fn with_storage_and_timeouts(
    storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
    default_capabilities: ServerCapabilities,
    session_timeout: Duration,
    cleanup_interval: Duration,
) -> Self {
    // Only the sender is kept; subscribers are created on demand from it.
    let global_event_sender = broadcast::channel(1000).0;
    let sessions = RwLock::new(HashMap::new());
    Self {
        storage,
        sessions,
        session_timeout,
        cleanup_interval,
        default_capabilities,
        global_event_sender,
    }
}
pub async fn create_session(&self) -> String {
let session = McpSession::new(self.default_capabilities.clone());
let session_id = session.id.clone();
debug!("Creating new session: {}", session_id);
match self
.storage
.create_session_with_id(session_id.clone(), self.default_capabilities.clone())
.await
{
Ok(_) => debug!("Session {} created in storage backend", session_id),
Err(e) => error!("Failed to create session {} in storage: {}", session_id, e),
}
self.sessions
.write()
.await
.insert(session_id.clone(), session);
session_id
}
/// Like [`Self::create_session`], but uses the caller-supplied `session_id`.
/// A storage failure is only logged; the id is returned regardless.
pub async fn create_session_with_id(&self, session_id: String) -> String {
    let mut fresh = McpSession::new(self.default_capabilities.clone());
    fresh.id = session_id.clone();
    debug!("Creating session with provided ID: {}", session_id);
    let capabilities = self.default_capabilities.clone();
    match self
        .storage
        .create_session_with_id(session_id.clone(), capabilities)
        .await
    {
        Err(e) => error!("Failed to create session {} in storage: {}", session_id, e),
        Ok(_) => debug!("Session {} created in storage backend", session_id),
    }
    let mut cache = self.sessions.write().await;
    cache.insert(session_id.clone(), fresh);
    drop(cache);
    session_id
}
/// Caches a session that was created elsewhere (e.g. directly in storage), without
/// touching the storage backend. Any cached entry with the same id is replaced.
pub async fn add_session_to_cache(
    &self,
    session_id: String,
    server_capabilities: ServerCapabilities,
) {
    let mut cached = McpSession::new(server_capabilities);
    cached.id = session_id.clone();
    debug!("Adding externally created session {} to cache", session_id);
    let mut cache = self.sessions.write().await;
    cache.insert(session_id, cached);
}
/// Loads a session from the storage backend into the in-memory cache, replacing any
/// cached entry with the same id.
///
/// Returns `Ok(true)` when the session was found and cached, and `Ok(false)` when storage
/// has no such session.
///
/// Storage timestamps (`created_at` / `last_activity`, milliseconds since the Unix epoch)
/// are mapped onto `Instant`s relative to now. Timestamps in the future clamp to "now".
/// An elapsed time too large to subtract from the current `Instant` also falls back to
/// "now" instead of panicking, as `Instant - Duration` would.
///
/// # Errors
///
/// `SessionError::StorageError` if the backend lookup fails.
pub async fn load_session_from_storage(&self, session_id: &str) -> Result<bool, SessionError> {
    let session_info = match self.storage.get_session(session_id).await {
        Ok(Some(session_info)) => session_info,
        Ok(None) => {
            debug!("Session {} not found in storage", session_id);
            return Ok(false);
        }
        Err(e) => {
            error!("Failed to get session {} from storage: {}", session_id, e);
            return Err(SessionError::StorageError(e.to_string()));
        }
    };

    debug!("Loading session {} from storage", session_id);
    let server_capabilities = session_info.server_capabilities.clone().unwrap_or_else(|| {
        warn!(
            "Session {} in storage has no server capabilities, using defaults",
            session_id
        );
        self.default_capabilities.clone()
    });
    let mut session = McpSession::new(server_capabilities);
    session.id = session_id.to_string();
    session.initialized = session_info.is_initialized;
    session.client_capabilities = session_info.client_capabilities.clone();
    session.state = session_info.state.clone();

    let now_millis = u64::try_from(
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis(),
    )
    .unwrap_or(u64::MAX);
    // Map an epoch-millisecond timestamp onto the monotonic clock without panicking.
    let to_instant = move |epoch_millis: u64| -> Instant {
        let elapsed = Duration::from_millis(now_millis.saturating_sub(epoch_millis));
        let now = Instant::now();
        now.checked_sub(elapsed).unwrap_or(now)
    };
    session.created = to_instant(session_info.created_at);
    session.last_accessed = to_instant(session_info.last_activity);

    self.sessions
        .write()
        .await
        .insert(session_id.to_string(), session);
    debug!(
        "Session {} loaded from storage: initialized={}, has_capabilities={}",
        session_id,
        session_info.is_initialized,
        session_info.server_capabilities.is_some()
    );
    Ok(true)
}
/// Refreshes a cached session's last-access time.
///
/// # Errors
///
/// * `SessionError::NotFound` if the session is not in the cache.
/// * `SessionError::Expired` if it already timed out; the entry is evicted from the cache.
pub async fn touch_session(&self, session_id: &str) -> Result<(), SessionError> {
    let mut cache = self.sessions.write().await;
    let Some(entry) = cache.get_mut(session_id) else {
        return Err(SessionError::NotFound(session_id.to_string()));
    };
    if entry.is_expired(self.session_timeout) {
        cache.remove(session_id);
        return Err(SessionError::Expired(session_id.to_string()));
    }
    entry.touch();
    Ok(())
}
/// Marks a session as initialized with the given client info and capabilities.
///
/// The storage update is best effort: it is skipped when storage has no record (or the
/// lookup fails), and an update failure is only logged.
///
/// # Errors
///
/// `SessionError::NotFound` if the session is not in the in-memory cache.
pub async fn initialize_session(
    &self,
    session_id: &str,
    client_info: Implementation,
    client_capabilities: ClientCapabilities,
) -> Result<(), SessionError> {
    if let Ok(Some(mut stored)) = self.storage.get_session(session_id).await {
        stored.client_capabilities = Some(client_capabilities.clone());
        stored.is_initialized = true;
        stored.touch();
        if let Err(e) = self.storage.update_session(stored).await {
            error!("Failed to update session in storage: {}", e);
        }
    }

    let mut cache = self.sessions.write().await;
    let Some(entry) = cache.get_mut(session_id) else {
        return Err(SessionError::NotFound(session_id.to_string()));
    };
    entry.initialize(client_info, client_capabilities);
    debug!("Session {} initialized", session_id);
    Ok(())
}
pub async fn initialize_session_with_version(
&self,
session_id: &str,
client_info: Implementation,
client_capabilities: ClientCapabilities,
mcp_version: McpVersion,
) -> Result<(), SessionError> {
if let Ok(Some(mut session_info)) = self.storage.get_session(session_id).await {
session_info.client_capabilities = Some(client_capabilities.clone());
session_info.is_initialized = true;
session_info.touch();
if let Err(e) = self.storage.update_session(session_info).await {
error!(
"❌ CRITICAL: Failed to update session {} in storage: {}",
session_id, e
);
return Err(SessionError::StorageError(format!(
"Failed to persist session initialization: {}",
e
)));
}
debug!(
"✅ Session {} storage updated with is_initialized=true",
session_id
);
} else {
error!(
"❌ Session {} not found in storage during initialization",
session_id
);
return Err(SessionError::NotFound(session_id.to_string()));
}
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(session_id) {
session.initialize_with_version(client_info, client_capabilities, mcp_version);
debug!(
"✅ Session {} cache updated with protocol version {}",
session_id, mcp_version
);
Ok(())
} else {
warn!(
"⚠️ Session {} not found in cache but exists in storage - creating cache entry",
session_id
);
Ok(())
}
}
/// Returns `true` if the session exists and has not been idle longer than the timeout.
///
/// Storage is consulted first; its `last_activity` (milliseconds since the Unix epoch) is
/// compared against the configured timeout at millisecond precision. The previous
/// whole-minute conversion truncated sub-minute timeouts to zero. On a storage error, the
/// in-memory cache is used instead.
pub async fn session_exists(&self, session_id: &str) -> bool {
    match self.storage.get_session(session_id).await {
        Ok(Some(session_info)) => {
            let now_millis = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis();
            let idle_millis = now_millis.saturating_sub(u128::from(session_info.last_activity));
            // Same rule as `McpSession::is_expired`: expired only when idle > timeout.
            idle_millis <= self.session_timeout.as_millis()
        }
        Ok(None) => false,
        Err(e) => {
            debug!("Storage backend error for session_exists: {}", e);
            let sessions = self.sessions.read().await;
            sessions
                .get(session_id)
                .is_some_and(|s| !s.is_expired(self.session_timeout))
        }
    }
}
pub async fn get_session_state(&self, session_id: &str, key: &str) -> Option<Value> {
match self.storage.get_session_state(session_id, key).await {
Ok(value) => value,
Err(e) => {
debug!("Storage backend error for get_session_state: {}", e);
let sessions = self.sessions.read().await;
sessions.get(session_id)?.get_state(key)
}
}
}
/// Writes a session state value to storage (failures are logged) and mirrors it into the
/// cached session if one exists.
pub async fn set_session_state(&self, session_id: &str, key: &str, value: Value) {
    let persisted = self
        .storage
        .set_session_state(session_id, key, value.clone())
        .await;
    if let Err(e) = persisted {
        error!("Failed to set session state in storage: {}", e);
    }
    let mut cache = self.sessions.write().await;
    if let Some(entry) = cache.get_mut(session_id) {
        entry.set_state(key, value);
    }
}
/// Removes a session state value from storage and from the cached session (if cached).
///
/// Returns the value removed from storage, or else the one removed from the cache. A value
/// removed from storage is now returned even when the session is not cached; previously
/// the early `?` on the cache lookup discarded it.
pub async fn remove_session_state(&self, session_id: &str, key: &str) -> Option<Value> {
    let storage_result = match self.storage.remove_session_state(session_id, key).await {
        Ok(value) => value,
        Err(e) => {
            error!("Failed to remove session state from storage: {}", e);
            None
        }
    };
    let memory_result = self
        .sessions
        .write()
        .await
        .get_mut(session_id)
        .and_then(|session| session.remove_state(key));
    storage_result.or(memory_result)
}
/// Reports whether a session finished initialization.
///
/// Storage is authoritative. The in-memory cache is consulted only if the storage lookup
/// fails, and an unknown session counts as not initialized.
pub async fn is_session_initialized(&self, session_id: &str) -> bool {
    let lookup = self.storage.get_session(session_id).await;
    match lookup {
        Err(e) => {
            warn!(
                "⚠️ Failed to check session {} in storage: {} - falling back to cache",
                session_id, e
            );
            let cache = self.sessions.read().await;
            cache.get(session_id).is_some_and(|cached| cached.initialized)
        }
        Ok(None) => {
            debug!("⚠️ Session {} not found in storage", session_id);
            false
        }
        Ok(Some(stored)) => {
            debug!(
                "✅ Session {} initialization status from storage: {}",
                session_id, stored.is_initialized
            );
            stored.is_initialized
        }
    }
}
/// Deletes a session from storage and evicts it from the cache, sending
/// `SessionEvent::Disconnect` to the evicted session's subscribers.
///
/// Returns `true` if either storage or the cache actually held the session. Storage
/// errors are logged and counted as "not removed".
pub async fn remove_session(&self, session_id: &str) -> bool {
    let storage_removed = self
        .storage
        .delete_session(session_id)
        .await
        .unwrap_or_else(|e| {
            error!(
                "Failed to remove session {} from storage: {}",
                session_id, e
            );
            false
        });
    if storage_removed {
        debug!("Session {} removed from storage backend", session_id);
    }

    let mut cache = self.sessions.write().await;
    let memory_removed = match cache.remove(session_id) {
        Some(evicted) => {
            debug!("Session {} removed from memory cache", session_id);
            let _ = evicted.send_event(SessionEvent::Disconnect);
            true
        }
        None => false,
    };
    storage_removed || memory_removed
}
/// Expires idle sessions in storage and in the in-memory cache.
///
/// Evicted cached sessions receive `SessionEvent::Disconnect`. Returns the larger of the
/// two removal counts, because the same session is usually present in both places.
///
/// Both cutoffs are computed without panicking. `SystemTime::now() - timeout` and
/// `Instant::now() - timeout` panic when the result is not representable (e.g. a very
/// large timeout). The storage cutoff therefore uses `checked_sub` (falling back to the
/// Unix epoch, which expires nothing), and the cache uses `McpSession::is_expired`.
pub async fn cleanup_expired(&self) -> usize {
    let timeout = self.session_timeout;
    let cutoff = std::time::SystemTime::now()
        .checked_sub(timeout)
        .unwrap_or(std::time::UNIX_EPOCH);
    let storage_removed = match self.storage.expire_sessions(cutoff).await {
        Ok(expired_ids) => {
            let count = expired_ids.len();
            if count > 0 {
                info!(
                    "Storage backend cleaned up {} expired sessions: {:?}",
                    count, expired_ids
                );
            }
            count
        }
        Err(e) => {
            error!("Failed to clean up expired sessions from storage: {}", e);
            0
        }
    };

    let mut sessions = self.sessions.write().await;
    let initial_count = sessions.len();
    sessions.retain(|id, session| {
        // Equivalent to `last_accessed >= now - timeout`, without the Instant subtraction.
        if !session.is_expired(timeout) {
            return true;
        }
        info!("Session {} expired and removed from memory cache", id);
        let _ = session.send_event(SessionEvent::Disconnect);
        false
    });
    let memory_removed = initial_count - sessions.len();
    storage_removed.max(memory_removed)
}
/// Sends `event` to a cached session's subscribers, then forwards
/// `(session_id, event)` to the global broadcaster. Having no global listeners is only
/// logged.
///
/// NOTE(review): if the per-session send fails (e.g. the session channel has no
/// subscribers), the error is returned and the global broadcaster is never notified.
///
/// # Errors
///
/// * `SessionError::NotFound` if the session is not cached.
/// * `SessionError::InvalidData` if the per-session send fails.
pub async fn send_event_to_session(
    &self,
    session_id: &str,
    event: SessionEvent,
) -> Result<(), SessionError> {
    let cache = self.sessions.read().await;
    let Some(target) = cache.get(session_id) else {
        return Err(SessionError::NotFound(session_id.to_string()));
    };
    target
        .send_event(event.clone())
        .map_err(SessionError::InvalidData)?;
    debug!(
        "🌐 Forwarding event to global broadcaster: session={}, event={:?}",
        session_id, event
    );
    match self
        .global_event_sender
        .send((session_id.to_string(), event))
    {
        Ok(_) => debug!("✅ Global event broadcast succeeded"),
        Err(e) => debug!("⚠️ Global event broadcast failed (no listeners): {}", e),
    }
    Ok(())
}
/// Sends a copy of `event` to every session currently in the in-memory cache.
///
/// A failed send to one session is logged as a warning and does not stop delivery
/// to the rest. This method does not publish to the global broadcaster.
pub async fn broadcast_event(&self, event: SessionEvent) {
    let sessions = self.sessions.read().await;
    sessions
        .iter()
        .filter_map(|(session_id, session)| {
            session
                .send_event(event.clone())
                .err()
                .map(|failure| (session_id, failure))
        })
        .for_each(|(session_id, failure)| {
            warn!("Failed to send event to session {}: {}", session_id, failure)
        });
}
/// Returns the number of sessions, preferring the storage backend's count.
///
/// If the backend query fails, the error is logged at debug level and the size of
/// the in-memory cache is returned instead.
pub async fn session_count(&self) -> usize {
    match self.storage.session_count().await {
        Err(e) => {
            debug!("Storage backend error for session_count: {}", e);
            let cache = self.sessions.read().await;
            cache.len()
        }
        Ok(count) => count,
    }
}
pub fn create_session_context(self: &Arc<Self>, session_id: &str) -> Option<SessionContext> {
let session_id = session_id.to_string();
let session_manager = Arc::clone(self);
let get_state = {
let session_manager = session_manager.clone();
let session_id = session_id.clone();
Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
let session_manager = session_manager.clone();
let session_id = session_id.clone();
let key = key.to_string();
Box::pin(async move { session_manager.get_session_state(&session_id, &key).await })
})
};
let set_state = {
let session_manager = session_manager.clone();
let session_id = session_id.clone();
Arc::new(move |key: &str, value: Value| -> BoxFuture<()> {
let session_manager = session_manager.clone();
let session_id = session_id.clone();
let key = key.to_string();
Box::pin(async move {
let _ = session_manager
.set_session_state(&session_id, &key, value)
.await;
})
})
};
let remove_state = {
let session_manager = session_manager.clone();
let session_id = session_id.clone();
Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
let session_manager = session_manager.clone();
let session_id = session_id.clone();
let key = key.to_string();
Box::pin(async move {
session_manager
.remove_session_state(&session_id, &key)
.await
})
})
};
let is_initialized = {
let session_manager = session_manager.clone();
let session_id = session_id.clone();
Arc::new(move || -> BoxFuture<bool> {
let session_manager = session_manager.clone();
let session_id = session_id.clone();
Box::pin(async move { session_manager.is_session_initialized(&session_id).await })
})
};
let send_notification = {
let session_manager = session_manager.clone();
let session_id = session_id.clone();
Arc::new(move |event: SessionEvent| -> BoxFuture<()> {
let session_manager = session_manager.clone();
let session_id = session_id.clone();
Box::pin(async move {
let _ = session_manager
.send_event_to_session(&session_id, event)
.await;
})
})
};
Some(SessionContext {
session_id,
get_state,
set_state,
remove_state,
is_initialized,
send_notification,
broadcaster: None, extensions: HashMap::new(),
})
}
/// Spawns a background task that calls `cleanup_expired` once per `cleanup_interval`.
///
/// The first tick of a `tokio::time::interval` completes immediately, so one cleanup
/// pass runs right after the task starts. A slow pass (e.g. a slow storage backend)
/// can make later ticks late. In that case the next tick is delayed rather than
/// "bursted", so full cleanup passes never run back-to-back to catch up.
///
/// The task loops forever. Stop it by aborting the returned `JoinHandle`.
///
/// # Panics
/// The spawned task (not the caller) panics if `cleanup_interval` is zero, because
/// `tokio::time::interval` rejects a zero period.
pub fn start_cleanup_task(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
    // `self` is already an owned `Arc`; move it into the task directly.
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(self.cleanup_interval);
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            interval.tick().await;
            let cleaned = self.cleanup_expired().await;
            if cleaned > 0 {
                debug!("Cleaned up {} expired sessions", cleaned);
            }
        }
    })
}
/// Subscribes to the per-session event channel of a cached session.
///
/// Returns `None` when `session_id` is not in the in-memory cache. Storage is not
/// consulted.
pub async fn get_session_event_receiver(
    &self,
    session_id: &str,
) -> Option<broadcast::Receiver<SessionEvent>> {
    let sessions = self.sessions.read().await;
    sessions
        .get(session_id)
        .map(|session| session.subscribe_events())
}
/// Subscribes to the manager-wide event stream of `(session_id, event)` pairs.
///
/// These are the pairs published on the global broadcaster, e.g. by
/// `send_event_to_session`. Per tokio's broadcast semantics, the new receiver
/// only sees messages sent after this call.
pub fn subscribe_all_session_events(&self) -> broadcast::Receiver<(String, SessionEvent)> {
    let sender = &self.global_event_sender;
    sender.subscribe()
}
/// Returns a shared handle to the session storage backend.
///
/// This only bumps the `Arc` reference count; the backend itself is not copied.
pub fn get_storage(&self) -> Arc<turul_mcp_session_storage::BoxedSessionStorage> {
    let storage = &self.storage;
    Arc::clone(storage)
}
pub fn get_default_capabilities(&self) -> ServerCapabilities {
self.default_capabilities.clone()
}
/// Reports whether `session_id` is present in the in-memory cache.
///
/// The storage backend is not consulted.
pub async fn session_exists_in_cache(&self, session_id: &str) -> bool {
    let cache = self.sessions.read().await;
    cache.contains_key(session_id)
}
}
/// A handler that receives optional JSON parameters together with an optional
/// session context.
#[async_trait]
pub trait SessionAware {
/// Handles a call with the given optional JSON `params` and, when one is
/// available, the caller's [`SessionContext`].
///
/// Returns a JSON value on success, or an error message as a `String`.
async fn handle_with_session(
&self,
params: Option<Value>,
session: Option<SessionContext>,
) -> Result<Value, String>;
}
#[cfg(test)]
mod tests {
// Unit tests for `SessionManager` and the framework-level `SessionContext`.
use super::*;
use serde_json::json;
// A newly created session has a non-empty id, and the manager reports it as existing.
#[tokio::test]
async fn test_session_creation() {
let capabilities = ServerCapabilities::default();
let manager = SessionManager::new(capabilities);
let session_id = manager.create_session().await;
assert!(!session_id.is_empty());
assert!(manager.session_exists(&session_id).await);
}
// Round-trips a state value through the manager's set/get/remove API.
// Removing returns the old value, and a subsequent get sees nothing.
#[tokio::test]
async fn test_session_state() {
let capabilities = ServerCapabilities::default();
let manager = SessionManager::new(capabilities);
let session_id = manager.create_session().await;
manager
.set_session_state(&session_id, "test_key", json!("test_value"))
.await;
let value = manager.get_session_state(&session_id, "test_key").await;
assert_eq!(value, Some(json!("test_value")));
let removed = manager.remove_session_state(&session_id, "test_key").await;
assert_eq!(removed, Some(json!("test_value")));
let value = manager.get_session_state(&session_id, "test_key").await;
assert_eq!(value, None);
}
// Exercises the callbacks of a context built by `create_session_context`. The
// notify_log / notify_progress calls at the end are smoke checks only: nothing is
// asserted about delivery, just that they complete without panicking.
#[tokio::test]
async fn test_session_context() {
let capabilities = ServerCapabilities::default();
let manager = Arc::new(SessionManager::new(capabilities));
let session_id = manager.create_session().await;
let ctx = manager.create_session_context(&session_id).unwrap();
(ctx.set_state)("test", json!("value")).await;
let value = (ctx.get_state)("test").await;
assert_eq!(value, Some(json!("value")));
let removed = (ctx.remove_state)("test").await;
assert_eq!(removed, Some(json!("value")));
ctx.notify_log(
turul_mcp_protocol::logging::LoggingLevel::Info,
serde_json::json!("Test notification"),
Some("test".to_string()),
None,
)
.await;
ctx.notify_progress("test-token", 50).await;
}
// With a 100 ms timeout, touching a session after sleeping 150 ms must report
// `SessionError::Expired`. This test depends on wall-clock timing (real sleep).
#[tokio::test]
async fn test_session_expiry() {
let capabilities = ServerCapabilities::default();
let mut manager = SessionManager::new(capabilities);
manager.session_timeout = Duration::from_millis(100);
let session_id = manager.create_session().await;
assert!(manager.session_exists_in_cache(&session_id).await);
tokio::time::sleep(Duration::from_millis(150)).await;
let result = manager.touch_session(&session_id).await;
assert!(matches!(result, Err(SessionError::Expired(_))));
}
// Writing session state through `set_state` must not leak into `extensions`,
// which is a separate, request-scoped map.
#[tokio::test]
async fn test_extensions_not_persisted_to_session_storage() {
let ctx = SessionContext::new_test();
assert!(ctx.extensions.is_empty());
(ctx.set_state)("key", json!("value")).await;
let val = (ctx.get_state)("key").await;
assert_eq!(val, Some(json!("value")));
assert!(ctx.extensions.is_empty());
}
// A fresh test context has no extensions, and lookups of any key return `None`.
#[tokio::test]
async fn test_extensions_empty_when_no_middleware() {
let ctx = SessionContext::new_test();
assert!(ctx.extensions.is_empty());
assert!(ctx.get_extension("anything").is_none());
}
// `get_typed_extension` deserializes a stored JSON value into the requested type.
// It yields `None` both for a missing key and for a value whose shape does not
// match the target type (an object is requested as `Vec<String>`).
#[tokio::test]
async fn test_get_typed_extension_deserialization() {
#[derive(Debug, serde::Deserialize, PartialEq)]
struct TokenClaims {
sub: String,
iss: String,
}
let mut ctx = SessionContext::new_test();
ctx.extensions.insert(
"__turul_internal.auth_claims".to_string(),
json!({
"sub": "user-123",
"iss": "https://auth.example.com"
}),
);
let claims: Option<TokenClaims> = ctx.get_typed_extension("__turul_internal.auth_claims");
assert!(claims.is_some());
let claims = claims.unwrap();
assert_eq!(claims.sub, "user-123");
assert_eq!(claims.iss, "https://auth.example.com");
let missing: Option<TokenClaims> = ctx.get_typed_extension("nonexistent");
assert!(missing.is_none());
let wrong: Option<Vec<String>> = ctx.get_typed_extension("__turul_internal.auth_claims");
assert!(wrong.is_none());
}
// Extensions set on the JSON-RPC layer's context must be visible on the
// framework context built by `from_json_rpc_with_broadcaster`.
#[tokio::test]
async fn test_extensions_thread_from_json_rpc_to_framework() {
use turul_mcp_session_storage::InMemorySessionStorage;
let storage = Arc::new(InMemorySessionStorage::new());
storage
.create_session_with_id("test-session".to_string(), Default::default())
.await
.unwrap();
let mut json_rpc_ctx = turul_mcp_json_rpc_server::SessionContext {
session_id: "test-session".to_string(),
metadata: HashMap::new(),
broadcaster: None,
timestamp: 0,
extensions: HashMap::new(),
};
json_rpc_ctx.extensions.insert(
"__turul_internal.auth_claims".to_string(),
json!({"sub": "user-456"}),
);
let framework_ctx = SessionContext::from_json_rpc_with_broadcaster(json_rpc_ctx, storage);
assert_eq!(
framework_ctx.get_extension("__turul_internal.auth_claims"),
Some(&json!({"sub": "user-456"}))
);
}
}