// systemprompt-mcp 0.1.9
//
// Core MCP (Model Context Protocol) functionality for systemprompt.io OS.
use std::fmt;
use std::sync::Arc;

use futures::Stream;
use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
use rmcp::transport::common::server_side_http::ServerSseMessage;
use rmcp::transport::streamable_http_server::session::local::{
    LocalSessionManager, LocalSessionManagerError,
};
use rmcp::transport::streamable_http_server::session::{SessionId, SessionManager};
use systemprompt_database::DbPool;
use tokio::sync::RwLock;

use crate::repository::McpSessionRepository;

/// Errors produced by [`DatabaseSessionManager`].
///
/// Wraps failures from the in-memory [`LocalSessionManager`] and from the
/// database layer, plus session-state conditions detected by the manager.
#[derive(Debug)]
pub enum DatabaseSessionManagerError {
    /// An error returned by the wrapped [`LocalSessionManager`].
    Local(LocalSessionManagerError),
    /// A database-layer error.
    Database(anyhow::Error),
    /// No active session with this id exists; carries the session id.
    SessionNotFound(String),
    /// The session has expired; carries the session id.
    SessionExpired(String),
    /// The session is known but cannot be served from memory, so the client
    /// must establish a new session; carries the session id.
    SessionNeedsReconnect(String),
}

impl fmt::Display for DatabaseSessionManagerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Local(e) => write!(f, "Local session error: {e}"),
            Self::Database(e) => write!(f, "Database error: {e}"),
            Self::SessionNotFound(id) => write!(f, "Session not found: {id}"),
            Self::SessionExpired(id) => write!(f, "Session expired: {id}"),
            Self::SessionNeedsReconnect(id) => write!(f, "Session needs reconnect: {id}"),
        }
    }
}

impl std::error::Error for DatabaseSessionManagerError {
    /// Exposes the wrapped error for `Local` and `Database`; the id-only
    /// variants have no underlying cause.
    ///
    /// NOTE(review): `Local` and `Database` also embed the inner message in
    /// `Display`, so reporters that walk the `source()` chain will print it
    /// twice — consider dropping one of the two.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Arms are listed explicitly (no `_`) so adding a variant forces a
        // decision about its source here.
        match self {
            Self::Local(e) => Some(e),
            Self::Database(e) => Some(e.as_ref()),
            Self::SessionNotFound(_)
            | Self::SessionExpired(_)
            | Self::SessionNeedsReconnect(_) => None,
        }
    }
}

/// Lets `?` lift [`LocalSessionManager`] errors into this error type.
impl From<LocalSessionManagerError> for DatabaseSessionManagerError {
    fn from(e: LocalSessionManagerError) -> Self {
        Self::Local(e)
    }
}

/// A session manager that keeps live sessions in a [`LocalSessionManager`]
/// and mirrors their lifecycle (creation, activity, close) into the database
/// through [`McpSessionRepository`].
///
/// Database persistence is best-effort: repository failures are logged and
/// never fail the session operation that triggered them.
#[derive(Debug)]
pub struct DatabaseSessionManager {
    /// In-memory manager that every session operation is delegated to.
    local_manager: LocalSessionManager,
    /// Session repository; `None` when it could not be constructed, in which
    /// case every persistence call is skipped.
    repository: Arc<RwLock<Option<McpSessionRepository>>>,
}

impl DatabaseSessionManager {
    /// Creates a manager with a fresh [`LocalSessionManager`] and a
    /// [`McpSessionRepository`] built from `db_pool`.
    ///
    /// If the repository cannot be built, the manager still works but runs
    /// memory-only; the failure is logged so that degraded mode is visible
    /// instead of being silently swallowed.
    pub fn new(db_pool: &DbPool) -> Self {
        let repository = match McpSessionRepository::new(db_pool) {
            Ok(repo) => Some(repo),
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "Failed to create MCP session repository; sessions will not be persisted"
                );
                None
            },
        };
        Self {
            local_manager: LocalSessionManager::default(),
            repository: Arc::new(RwLock::new(repository)),
        }
    }

    /// Records a newly created session in the database (best-effort).
    async fn persist_create(&self, session_id: &SessionId) {
        let repo_guard = self.repository.read().await;
        let Some(repo) = repo_guard.as_ref() else {
            return;
        };
        if let Err(e) = repo.create(session_id.as_ref(), None, None).await {
            tracing::warn!(
                session_id = %session_id,
                error = %e,
                "Failed to persist session creation to database"
            );
        }
    }

    /// Marks a session as closed in the database (best-effort).
    async fn persist_close(&self, session_id: &SessionId) {
        let repo_guard = self.repository.read().await;
        let Some(repo) = repo_guard.as_ref() else {
            return;
        };
        if let Err(e) = repo.close(session_id.as_ref()).await {
            tracing::warn!(
                session_id = %session_id,
                error = %e,
                "Failed to persist session close to database"
            );
        }
    }

    /// Records activity for a session in the database (best-effort).
    ///
    /// Failures are logged at debug level only, because this runs on every
    /// message handled for the session.
    async fn update_activity(&self, session_id: &SessionId) {
        let repo_guard = self.repository.read().await;
        let Some(repo) = repo_guard.as_ref() else {
            return;
        };
        if let Err(e) = repo.update_activity(session_id.as_ref()).await {
            tracing::debug!(
                session_id = %session_id,
                error = %e,
                "Failed to update session activity"
            );
        }
    }

    /// Looks the session up among the database's active sessions.
    ///
    /// # Returns
    /// - `Some(true)`: an active record exists.
    /// - `Some(false)`: no active record exists.
    /// - `None`: there is no repository, or the lookup failed (the failure is
    ///   logged).
    async fn check_db_session(&self, session_id: &SessionId) -> Option<bool> {
        let repo_guard = self.repository.read().await;
        let repo = repo_guard.as_ref()?;
        match repo.find_active(session_id.as_ref()).await {
            Ok(found) => Some(found.is_some()),
            Err(e) => {
                tracing::warn!(
                    session_id = %session_id,
                    error = %e,
                    "Failed to check session in database"
                );
                None
            },
        }
    }
}

/// Every operation is delegated to the in-memory [`LocalSessionManager`].
/// Creation, activity and close events are then mirrored into the database
/// on a best-effort basis.
impl SessionManager for DatabaseSessionManager {
    type Error = DatabaseSessionManagerError;
    type Transport = <LocalSessionManager as SessionManager>::Transport;

    /// Creates a session in memory, then records it in the database.
    ///
    /// A database failure is logged and does not fail session creation.
    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
        let (id, transport) = self.local_manager.create_session().await?;
        self.persist_create(&id).await;
        Ok((id, transport))
    }

    /// Runs initialization on the in-memory session and, on success, records
    /// session activity in the database.
    async fn initialize_session(
        &self,
        id: &SessionId,
        message: ClientJsonRpcMessage,
    ) -> Result<ServerJsonRpcMessage, Self::Error> {
        let result = self.local_manager.initialize_session(id, message).await?;
        self.update_activity(id).await;
        Ok(result)
    }

    /// Reports whether the session is known, either in memory or as an
    /// active database record. A failed database lookup counts as "not
    /// found".
    ///
    /// NOTE(review): a database-only session yields `true` here, yet
    /// `create_stream` / `accept_message` then go to `local_manager`, which
    /// does not hold that session. `resume` reports `SessionNeedsReconnect`
    /// for this case; confirm callers handle the other paths too.
    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
        if self.local_manager.has_session(id).await? {
            return Ok(true);
        }
        Ok(self.check_db_session(id).await.unwrap_or(false))
    }

    /// Closes the session in memory and marks it closed in the database.
    ///
    /// Always returns `Ok(())`. The session may be absent from memory (for
    /// example a database-only session accepted by `has_session`), so a local
    /// close failure is logged rather than propagated. The database record is
    /// closed regardless.
    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
        if let Err(e) = self.local_manager.close_session(id).await {
            tracing::debug!(
                session_id = %id,
                error = %e,
                "Failed to close in-memory session; closing database record anyway"
            );
        }
        self.persist_close(id).await;
        Ok(())
    }

    /// Opens a response stream for `message` on the in-memory session and,
    /// on success, records session activity.
    async fn create_stream(
        &self,
        id: &SessionId,
        message: ClientJsonRpcMessage,
    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
        let stream = self.local_manager.create_stream(id, message).await?;
        self.update_activity(id).await;
        Ok(stream)
    }

    /// Hands `message` to the in-memory session and, on success, records
    /// session activity.
    async fn accept_message(
        &self,
        id: &SessionId,
        message: ClientJsonRpcMessage,
    ) -> Result<(), Self::Error> {
        self.local_manager.accept_message(id, message).await?;
        self.update_activity(id).await;
        Ok(())
    }

    /// Opens a standalone stream on the in-memory session and, on success,
    /// records session activity.
    async fn create_standalone_stream(
        &self,
        id: &SessionId,
    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
        let stream = self.local_manager.create_standalone_stream(id).await?;
        self.update_activity(id).await;
        Ok(stream)
    }

    /// Resumes a stream for the session from `last_event_id`.
    ///
    /// The session is looked up in this order:
    /// - **In memory:** resume there. If that fails, the session is treated as
    ///   stale. It is closed locally and in the database, and
    ///   `SessionNeedsReconnect` is returned.
    /// - **Not in memory, active database record:** `SessionNeedsReconnect`.
    /// - **Not in memory, no database record:** `SessionNotFound`.
    /// - **Database unavailable:** the in-memory manager's own `resume` result
    ///   is returned.
    async fn resume(
        &self,
        id: &SessionId,
        last_event_id: String,
    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
        if self.local_manager.has_session(id).await.unwrap_or(false) {
            match self.local_manager.resume(id, last_event_id).await {
                Ok(stream) => return Ok(stream),
                Err(e) => {
                    tracing::info!(
                        session_id = %id,
                        error = %e,
                        "Session channel closed - cleaning up stale session"
                    );
                    // Cleanup is best-effort; surface a failure in the logs
                    // instead of discarding it.
                    if let Err(close_err) = self.local_manager.close_session(id).await {
                        tracing::debug!(
                            session_id = %id,
                            error = %close_err,
                            "Failed to close stale in-memory session"
                        );
                    }
                    self.persist_close(id).await;
                    return Err(DatabaseSessionManagerError::SessionNeedsReconnect(
                        id.to_string(),
                    ));
                },
            }
        }

        match self.check_db_session(id).await {
            Some(true) => {
                tracing::info!(
                    session_id = %id,
                    "Session exists in database but not in memory - client needs to reconnect"
                );
                Err(DatabaseSessionManagerError::SessionNeedsReconnect(
                    id.to_string(),
                ))
            },
            Some(false) => {
                tracing::debug!(session_id = %id, "Session not found in database");
                Err(DatabaseSessionManagerError::SessionNotFound(id.to_string()))
            },
            None => self
                .local_manager
                .resume(id, last_event_id)
                .await
                .map_err(Into::into),
        }
    }
}