systemprompt-mcp 0.10.2

Native Model Context Protocol (MCP) implementation for systemprompt.io. Orchestration, per-server OAuth2, RBAC middleware, and tool-call governance — the core of the AI governance pipeline.
Documentation
use std::fmt;
use std::sync::Arc;

use futures::{Stream, StreamExt};
use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
use rmcp::transport::common::server_side_http::ServerSseMessage;
use rmcp::transport::streamable_http_server::session::local::{
    LocalSessionManager, LocalSessionManagerError,
};
use rmcp::transport::streamable_http_server::session::{SessionId, SessionManager};
use systemprompt_database::DbPool;
use tokio::sync::RwLock;

use crate::repository::McpSessionRepository;

/// Errors reported by [`DatabaseSessionManager`].
#[derive(Debug)]
pub enum DatabaseSessionManagerError {
    /// An error returned by the wrapped [`LocalSessionManager`].
    Local(LocalSessionManagerError),
    /// A database-layer error.
    // NOTE(review): not constructed anywhere in the visible part of this file.
    Database(crate::error::McpDomainError),
    /// No usable session exists for the id carried in the payload.
    SessionNotFound(String),
    /// The payload is rendered as the session id by `Display`.
    // NOTE(review): not constructed anywhere in the visible part of this file.
    SessionExpired(String),
    /// The database still lists the session (payload: its id) as active but
    /// the in-memory manager does not know it; returned by `resume`.
    SessionNeedsReconnect(String),
}

impl fmt::Display for DatabaseSessionManagerError {
    /// Renders a short, human-readable description: a fixed prefix naming
    /// the failure category followed by the wrapped error or session id.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Variants wrapping an underlying error.
            Self::Local(e) => f.write_fmt(format_args!("Local session error: {e}")),
            Self::Database(e) => f.write_fmt(format_args!("Database error: {e}")),
            // Variants carrying only the affected session id.
            Self::SessionNotFound(id) => f.write_fmt(format_args!("Session not found: {id}")),
            Self::SessionExpired(id) => f.write_fmt(format_args!("Session expired: {id}")),
            Self::SessionNeedsReconnect(id) => {
                f.write_fmt(format_args!("Session needs reconnect: {id}"))
            },
        }
    }
}

impl std::error::Error for DatabaseSessionManagerError {
    /// Exposes the wrapped error for the `Local` and `Database` variants;
    /// the id-only variants have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Local(inner) => Some(inner),
            Self::Database(inner) => Some(inner),
            // Listed explicitly so that adding a variant forces a decision here.
            Self::SessionNotFound(_)
            | Self::SessionExpired(_)
            | Self::SessionNeedsReconnect(_) => None,
        }
    }
}

impl From<LocalSessionManagerError> for DatabaseSessionManagerError {
    fn from(e: LocalSessionManagerError) -> Self {
        Self::Local(e)
    }
}

/// Session manager that delegates session handling to an in-memory
/// [`LocalSessionManager`] and mirrors session lifecycle events to the
/// database through an optional `McpSessionRepository`.
pub struct DatabaseSessionManager {
    /// In-memory manager that performs the actual session work.
    local_manager: LocalSessionManager,
    /// Database repository used for persistence. Holds `None` when the
    /// repository could not be constructed, in which case the persistence
    /// helpers do nothing.
    repository: Arc<RwLock<Option<McpSessionRepository>>>,
}

impl fmt::Debug for DatabaseSessionManager {
    /// Formats the manager as a struct listing both of its fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            local_manager,
            repository,
        } = self;

        let mut out = f.debug_struct("DatabaseSessionManager");
        out.field("local_manager", local_manager);
        out.field("repository", repository);
        out.finish()
    }
}

impl DatabaseSessionManager {
    pub fn new(db_pool: &DbPool) -> Self {
        Self::with_timeouts(db_pool, crate::SessionTimeouts::default())
    }

    pub fn with_timeouts(db_pool: &DbPool, timeouts: crate::SessionTimeouts) -> Self {
        let mut local_manager = LocalSessionManager::default();
        let cfg = &mut local_manager.session_config;
        cfg.init_timeout = timeouts.init.or(cfg.init_timeout);
        cfg.keep_alive = timeouts.keep_alive.or(cfg.keep_alive);
        Self {
            local_manager,
            repository: Arc::new(RwLock::new(McpSessionRepository::new(db_pool).ok())),
        }
    }

    async fn persist_create(&self, session_id: &SessionId) {
        let repo_guard = self.repository.read().await;
        if let Some(repo) = repo_guard.as_ref()
            && let Err(e) = repo
                .create(
                    &systemprompt_identifiers::SessionId::new(session_id.as_ref()),
                    None,
                    None,
                )
                .await
        {
            tracing::warn!(
                session_id = %session_id,
                error = %e,
                "Failed to persist session creation to database"
            );
        }
    }

    async fn persist_close(&self, session_id: &SessionId) {
        let repo_guard = self.repository.read().await;
        if let Some(repo) = repo_guard.as_ref()
            && let Err(e) = repo
                .close(&systemprompt_identifiers::SessionId::new(
                    session_id.as_ref(),
                ))
                .await
        {
            tracing::warn!(
                session_id = %session_id,
                error = %e,
                "Failed to persist session close to database"
            );
        }
    }

    async fn update_activity(&self, session_id: &SessionId) {
        let repo_guard = self.repository.read().await;
        if let Some(repo) = repo_guard.as_ref()
            && let Err(e) = repo
                .update_activity(&systemprompt_identifiers::SessionId::new(
                    session_id.as_ref(),
                ))
                .await
        {
            tracing::debug!(
                session_id = %session_id,
                error = %e,
                "Failed to update session activity"
            );
        }
    }

    async fn check_db_session(&self, session_id: &SessionId) -> Option<bool> {
        let repo_guard = self.repository.read().await;
        if let Some(repo) = repo_guard.as_ref() {
            match repo
                .find_active(&systemprompt_identifiers::SessionId::new(
                    session_id.as_ref(),
                ))
                .await
            {
                Ok(Some(_)) => Some(true),
                Ok(None) => Some(false),
                Err(e) => {
                    tracing::warn!(
                        session_id = %session_id,
                        error = %e,
                        "Failed to check session in database"
                    );
                    None
                },
            }
        } else {
            None
        }
    }
}

impl SessionManager for DatabaseSessionManager {
    type Error = DatabaseSessionManagerError;
    type Transport = <LocalSessionManager as SessionManager>::Transport;

    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
        let (id, transport) = self.local_manager.create_session().await?;
        tracing::info!(session_id = %id, "MCP session created");
        self.persist_create(&id).await;
        Ok((id, transport))
    }

    async fn initialize_session(
        &self,
        id: &SessionId,
        message: ClientJsonRpcMessage,
    ) -> Result<ServerJsonRpcMessage, Self::Error> {
        let result = self.local_manager.initialize_session(id, message).await?;
        self.update_activity(id).await;
        Ok(result)
    }

    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
        if self.local_manager.has_session(id).await.unwrap_or(false) {
            return Ok(true);
        }
        if self.check_db_session(id).await == Some(true) {
            tracing::info!(
                session_id = %id,
                "Session in DB but not memory — session not available (client should re-initialize)"
            );
        }
        Ok(false)
    }

    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
        tracing::info!(session_id = %id, "MCP session closing");
        if let Err(e) = self.local_manager.close_session(id).await {
            tracing::warn!(session_id = %id, error = %e, "Failed to close local session");
        }
        self.persist_close(id).await;
        Ok(())
    }

    async fn create_stream(
        &self,
        id: &SessionId,
        message: ClientJsonRpcMessage,
    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
        let stream = self.local_manager.create_stream(id, message).await?;
        self.update_activity(id).await;
        Ok(stream)
    }

    async fn accept_message(
        &self,
        id: &SessionId,
        message: ClientJsonRpcMessage,
    ) -> Result<(), Self::Error> {
        self.local_manager.accept_message(id, message).await?;
        self.update_activity(id).await;
        Ok(())
    }

    async fn create_standalone_stream(
        &self,
        id: &SessionId,
    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
        let stream = self.local_manager.create_standalone_stream(id).await?;
        self.update_activity(id).await;
        Ok(stream)
    }

    async fn resume(
        &self,
        id: &SessionId,
        last_event_id: String,
    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
        if !self.local_manager.has_session(id).await.unwrap_or(false) {
            if self.check_db_session(id).await == Some(true) {
                tracing::info!(
                    session_id = %id,
                    "Session in DB but not memory (server restart?) — signaling reconnect"
                );
                self.persist_close(id).await;
                return Err(DatabaseSessionManagerError::SessionNeedsReconnect(
                    id.to_string(),
                ));
            }
            tracing::warn!(
                session_id = %id,
                "Resume called but session not found anywhere"
            );
            return Err(DatabaseSessionManagerError::SessionNotFound(id.to_string()));
        }

        match self.local_manager.resume(id, last_event_id).await {
            Ok(stream) => {
                tracing::info!(
                    session_id = %id,
                    "Session resumed successfully"
                );
                self.update_activity(id).await;
                Ok(stream.left_stream())
            },
            Err(e) => {
                tracing::info!(
                    session_id = %id,
                    error = %e,
                    "Resume failed, attempting recovery via new standalone stream"
                );
                match self.local_manager.create_standalone_stream(id).await {
                    Ok(stream) => {
                        tracing::info!(
                            session_id = %id,
                            "Session recovered with new standalone stream"
                        );
                        self.update_activity(id).await;
                        Ok(stream.right_stream())
                    },
                    Err(e2) => {
                        tracing::warn!(
                            session_id = %id,
                            error = %e2,
                            "Session worker is dead, cleaning up"
                        );
                        if let Err(e) = self.local_manager.close_session(id).await {
                            tracing::warn!(session_id = %id, error = %e, "Failed to close local session during recovery");
                        }
                        self.persist_close(id).await;
                        Err(DatabaseSessionManagerError::SessionNotFound(id.to_string()))
                    },
                }
            },
        }
    }
}