use std::fmt;
use std::sync::Arc;
use futures::{Stream, StreamExt};
use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
use rmcp::transport::common::server_side_http::ServerSseMessage;
use rmcp::transport::streamable_http_server::session::local::{
LocalSessionManager, LocalSessionManagerError,
};
use rmcp::transport::streamable_http_server::session::{SessionId, SessionManager};
use systemprompt_database::DbPool;
use tokio::sync::RwLock;
use crate::repository::McpSessionRepository;
/// Error type of [`DatabaseSessionManager`]'s `SessionManager` implementation.
#[derive(Debug)]
pub enum DatabaseSessionManagerError {
    /// Error propagated from the wrapped in-memory [`LocalSessionManager`]
    /// (lifted automatically via the `From` impl / `?`).
    Local(LocalSessionManagerError),
    /// Database-layer failure.
    // NOTE(review): not constructed anywhere in this chunk — confirm where it is produced.
    Database(anyhow::Error),
    /// The session (payload: its id) is known neither in memory nor as an
    /// active database row, or its worker could not be recovered in `resume`.
    SessionNotFound(String),
    /// Session expired (payload: its id).
    // NOTE(review): not constructed anywhere in this chunk; presumably produced
    // elsewhere or reserved — confirm before relying on it.
    SessionExpired(String),
    /// The session (payload: its id) is active in the database but missing from
    /// memory (e.g. after a server restart); the client has to reconnect.
    SessionNeedsReconnect(String),
}
impl fmt::Display for DatabaseSessionManagerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Local(e) => write!(f, "Local session error: {e}"),
Self::Database(e) => write!(f, "Database error: {e}"),
Self::SessionNotFound(id) => write!(f, "Session not found: {id}"),
Self::SessionExpired(id) => write!(f, "Session expired: {id}"),
Self::SessionNeedsReconnect(id) => write!(f, "Session needs reconnect: {id}"),
}
}
}
impl std::error::Error for DatabaseSessionManagerError {
    /// Exposes the wrapped error for the `Local` and `Database` variants.
    ///
    /// The id-only variants carry no underlying error and return `None`. They
    /// are listed explicitly instead of via a `_` catch-all, so adding a new
    /// variant forces a decision here about whether it has a source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Local(e) => Some(e),
            Self::Database(e) => Some(e.as_ref()),
            Self::SessionNotFound(_) | Self::SessionExpired(_) | Self::SessionNeedsReconnect(_) => {
                None
            },
        }
    }
}
impl From<LocalSessionManagerError> for DatabaseSessionManagerError {
fn from(e: LocalSessionManagerError) -> Self {
Self::Local(e)
}
}
/// MCP session manager that delegates all session handling to an in-memory
/// [`LocalSessionManager`] and mirrors session lifecycle events (create,
/// close, activity) into the database via [`McpSessionRepository`].
pub struct DatabaseSessionManager {
    /// In-memory manager every session operation is delegated to first.
    local_manager: LocalSessionManager,
    /// Database repository; `None` when `McpSessionRepository::new` failed in
    /// `DatabaseSessionManager::new`, in which case persistence is skipped.
    // NOTE(review): nothing in this chunk writes through the RwLock — confirm
    // whether the Arc/RwLock wrapper is still needed.
    repository: Arc<RwLock<Option<McpSessionRepository>>>,
}
impl fmt::Debug for DatabaseSessionManager {
    /// Formats both fields using their own `Debug` implementations.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("DatabaseSessionManager");
        builder.field("local_manager", &self.local_manager);
        builder.field("repository", &self.repository);
        builder.finish()
    }
}
impl DatabaseSessionManager {
    /// Creates a manager with a fresh in-memory [`LocalSessionManager`] and,
    /// if it can be constructed, an [`McpSessionRepository`] on `db_pool`.
    ///
    /// A repository construction failure is deliberately not fatal: sessions
    /// are still served from memory, but the persistence helpers become
    /// no-ops. The failure is logged so missing persistence is diagnosable
    /// rather than silent.
    pub fn new(db_pool: &DbPool) -> Self {
        let repository = match McpSessionRepository::new(db_pool) {
            Ok(repo) => Some(repo),
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "Failed to create MCP session repository; sessions will not be persisted"
                );
                None
            },
        };
        Self {
            local_manager: LocalSessionManager::default(),
            repository: Arc::new(RwLock::new(repository)),
        }
    }

    /// Converts an rmcp session id into the identifier type the repository expects.
    fn db_session_id(session_id: &SessionId) -> systemprompt_identifiers::SessionId {
        systemprompt_identifiers::SessionId::new(session_id.as_ref())
    }

    /// Records a newly created session in the database (best-effort).
    ///
    /// No-op without a repository. Errors are logged at `warn` and otherwise
    /// ignored, so persistence problems never fail the caller.
    async fn persist_create(&self, session_id: &SessionId) {
        let repo_guard = self.repository.read().await;
        let Some(repo) = repo_guard.as_ref() else {
            return;
        };
        if let Err(e) = repo
            .create(&Self::db_session_id(session_id), None, None)
            .await
        {
            tracing::warn!(
                session_id = %session_id,
                error = %e,
                "Failed to persist session creation to database"
            );
        }
    }

    /// Marks a session as closed in the database (best-effort).
    ///
    /// No-op without a repository; errors are logged at `warn` and ignored.
    async fn persist_close(&self, session_id: &SessionId) {
        let repo_guard = self.repository.read().await;
        let Some(repo) = repo_guard.as_ref() else {
            return;
        };
        if let Err(e) = repo.close(&Self::db_session_id(session_id)).await {
            tracing::warn!(
                session_id = %session_id,
                error = %e,
                "Failed to persist session close to database"
            );
        }
    }

    /// Calls the repository's `update_activity` for the session (best-effort).
    ///
    /// Runs on every successful message/stream operation, so failures are only
    /// logged at `debug` to avoid flooding the logs.
    async fn update_activity(&self, session_id: &SessionId) {
        let repo_guard = self.repository.read().await;
        let Some(repo) = repo_guard.as_ref() else {
            return;
        };
        if let Err(e) = repo
            .update_activity(&Self::db_session_id(session_id))
            .await
        {
            tracing::debug!(
                session_id = %session_id,
                error = %e,
                "Failed to update session activity"
            );
        }
    }

    /// Looks the session up via the repository's `find_active`.
    ///
    /// Returns `Some(true)` if an active row exists, `Some(false)` if not, and
    /// `None` when there is no repository or the query failed (logged at `warn`).
    async fn check_db_session(&self, session_id: &SessionId) -> Option<bool> {
        let repo_guard = self.repository.read().await;
        let repo = repo_guard.as_ref()?;
        match repo.find_active(&Self::db_session_id(session_id)).await {
            Ok(found) => Some(found.is_some()),
            Err(e) => {
                tracing::warn!(
                    session_id = %session_id,
                    error = %e,
                    "Failed to check session in database"
                );
                None
            },
        }
    }
}
impl SessionManager for DatabaseSessionManager {
type Error = DatabaseSessionManagerError;
type Transport = <LocalSessionManager as SessionManager>::Transport;
async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
let (id, transport) = self.local_manager.create_session().await?;
tracing::info!(session_id = %id, "MCP session created");
self.persist_create(&id).await;
Ok((id, transport))
}
async fn initialize_session(
&self,
id: &SessionId,
message: ClientJsonRpcMessage,
) -> Result<ServerJsonRpcMessage, Self::Error> {
let result = self.local_manager.initialize_session(id, message).await?;
self.update_activity(id).await;
Ok(result)
}
async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
if self.local_manager.has_session(id).await.unwrap_or(false) {
return Ok(true);
}
if self.check_db_session(id).await == Some(true) {
tracing::info!(
session_id = %id,
"Session in DB but not memory — session not available (client should re-initialize)"
);
}
Ok(false)
}
async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
tracing::info!(session_id = %id, "MCP session closing");
if let Err(e) = self.local_manager.close_session(id).await {
tracing::warn!(session_id = %id, error = %e, "Failed to close local session");
}
self.persist_close(id).await;
Ok(())
}
async fn create_stream(
&self,
id: &SessionId,
message: ClientJsonRpcMessage,
) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
let stream = self.local_manager.create_stream(id, message).await?;
self.update_activity(id).await;
Ok(stream)
}
async fn accept_message(
&self,
id: &SessionId,
message: ClientJsonRpcMessage,
) -> Result<(), Self::Error> {
self.local_manager.accept_message(id, message).await?;
self.update_activity(id).await;
Ok(())
}
async fn create_standalone_stream(
&self,
id: &SessionId,
) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
let stream = self.local_manager.create_standalone_stream(id).await?;
self.update_activity(id).await;
Ok(stream)
}
async fn resume(
&self,
id: &SessionId,
last_event_id: String,
) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
if !self.local_manager.has_session(id).await.unwrap_or(false) {
if self.check_db_session(id).await == Some(true) {
tracing::info!(
session_id = %id,
"Session in DB but not memory (server restart?) — signaling reconnect"
);
self.persist_close(id).await;
return Err(DatabaseSessionManagerError::SessionNeedsReconnect(
id.to_string(),
));
}
tracing::warn!(
session_id = %id,
"Resume called but session not found anywhere"
);
return Err(DatabaseSessionManagerError::SessionNotFound(id.to_string()));
}
match self.local_manager.resume(id, last_event_id).await {
Ok(stream) => {
tracing::info!(
session_id = %id,
"Session resumed successfully"
);
self.update_activity(id).await;
Ok(stream.left_stream())
},
Err(e) => {
tracing::info!(
session_id = %id,
error = %e,
"Resume failed, attempting recovery via new standalone stream"
);
match self.local_manager.create_standalone_stream(id).await {
Ok(stream) => {
tracing::info!(
session_id = %id,
"Session recovered with new standalone stream"
);
self.update_activity(id).await;
Ok(stream.right_stream())
},
Err(e2) => {
tracing::warn!(
session_id = %id,
error = %e2,
"Session worker is dead, cleaning up"
);
if let Err(e) = self.local_manager.close_session(id).await {
tracing::warn!(session_id = %id, error = %e, "Failed to close local session during recovery");
}
self.persist_close(id).await;
Err(DatabaseSessionManagerError::SessionNotFound(id.to_string()))
},
}
},
}
}
}