rmcp 1.5.0

Rust SDK for Model Context Protocol
Documentation
#![cfg(not(feature = "local"))]
use std::sync::Arc;

use rmcp::{
    ClientHandler, ServerHandler, ServiceExt,
    model::{
        ClientRequest, ClientResult, CustomRequest, CustomResult, ServerRequest, ServerResult,
    },
};
use serde_json::json;
use tokio::sync::{Mutex, Notify};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

type CustomRequestPayload = (String, Option<serde_json::Value>);

struct CustomRequestServer {
    receive_signal: Arc<Notify>,
    payload: Arc<Mutex<Option<CustomRequestPayload>>>,
}

impl ServerHandler for CustomRequestServer {
    async fn on_custom_request(
        &self,
        request: CustomRequest,
        _context: rmcp::service::RequestContext<rmcp::RoleServer>,
    ) -> Result<CustomResult, rmcp::ErrorData> {
        let CustomRequest { method, params, .. } = request;
        *self.payload.lock().await = Some((method, params));
        self.receive_signal.notify_one();
        Ok(CustomResult::new(json!({ "status": "ok" })))
    }
}

#[tokio::test]
async fn test_custom_client_request_reaches_server() -> anyhow::Result<()> {
    let _ = tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "debug".to_string().into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .try_init();

    let (server_transport, client_transport) = tokio::io::duplex(4096);
    let receive_signal = Arc::new(Notify::new());
    let payload = Arc::new(Mutex::new(None));

    {
        let receive_signal = receive_signal.clone();
        let payload = payload.clone();
        tokio::spawn(async move {
            let server = CustomRequestServer {
                receive_signal,
                payload,
            }
            .serve(server_transport)
            .await?;
            server.waiting().await?;
            anyhow::Ok(())
        });
    }

    let client = ().serve(client_transport).await?;

    let response = client
        .send_request(ClientRequest::CustomRequest(CustomRequest::new(
            "requests/custom-test",
            Some(json!({ "foo": "bar" })),
        )))
        .await?;

    tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?;

    let (method, params) = payload.lock().await.take().expect("payload set");
    assert_eq!("requests/custom-test", method);
    assert_eq!(Some(json!({ "foo": "bar" })), params);

    match response {
        ServerResult::CustomResult(result) => {
            assert_eq!(result.0, json!({ "status": "ok" }));
        }
        other => panic!("Expected custom result, got: {other:?}"),
    }

    client.cancel().await?;
    Ok(())
}

struct CustomRequestClient {
    receive_signal: Arc<Notify>,
    payload: Arc<Mutex<Option<CustomRequestPayload>>>,
}

impl ClientHandler for CustomRequestClient {
    async fn on_custom_request(
        &self,
        request: CustomRequest,
        _context: rmcp::service::RequestContext<rmcp::RoleClient>,
    ) -> Result<CustomResult, rmcp::ErrorData> {
        let CustomRequest { method, params, .. } = request;
        *self.payload.lock().await = Some((method, params));
        self.receive_signal.notify_one();
        Ok(CustomResult::new(json!({ "status": "ok" })))
    }
}

struct CustomRequestServerNotifier {
    receive_signal: Arc<Notify>,
    response: Arc<Mutex<Option<Result<serde_json::Value, String>>>>,
}

impl ServerHandler for CustomRequestServerNotifier {
    async fn on_initialized(&self, context: rmcp::service::NotificationContext<rmcp::RoleServer>) {
        let peer = context.peer.clone();
        let receive_signal = self.receive_signal.clone();
        let response = self.response.clone();
        tokio::spawn(async move {
            let result = peer
                .send_request(ServerRequest::CustomRequest(CustomRequest::new(
                    "requests/custom-server",
                    Some(json!({ "ping": "pong" })),
                )))
                .await;
            let payload = match result {
                Ok(ClientResult::CustomResult(result)) => Ok(result.0),
                Ok(other) => Err(format!("Unexpected response: {other:?}")),
                Err(err) => Err(format!("Failed to send request: {err:?}")),
            };
            *response.lock().await = Some(payload);
            receive_signal.notify_one();
        });
    }
}

#[tokio::test]
async fn test_custom_server_request_reaches_client() -> anyhow::Result<()> {
    let _ = tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "debug".to_string().into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .try_init();

    let (server_transport, client_transport) = tokio::io::duplex(4096);
    let response_signal = Arc::new(Notify::new());
    let response = Arc::new(Mutex::new(None));
    tokio::spawn({
        let response_signal = response_signal.clone();
        let response = response.clone();
        async move {
            let server = CustomRequestServerNotifier {
                receive_signal: response_signal,
                response,
            }
            .serve(server_transport)
            .await?;
            server.waiting().await?;
            anyhow::Ok(())
        }
    });

    let receive_signal = Arc::new(Notify::new());
    let payload = Arc::new(Mutex::new(None));

    let client = CustomRequestClient {
        receive_signal: receive_signal.clone(),
        payload: payload.clone(),
    }
    .serve(client_transport)
    .await?;

    tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?;
    tokio::time::timeout(
        std::time::Duration::from_secs(5),
        response_signal.notified(),
    )
    .await?;

    let (method, params) = payload.lock().await.take().expect("payload set");
    assert_eq!("requests/custom-server", method);
    assert_eq!(Some(json!({ "ping": "pong" })), params);

    let response = response.lock().await.take().expect("response set");
    let response = response.expect("custom request response ok");
    assert_eq!(response, json!({ "status": "ok" }));

    client.cancel().await?;
    Ok(())
}