#![cfg(test)]
use std::sync::{Arc, Mutex};
use std::time::Duration;
use rmcp::handler::client::ClientHandler;
use rmcp::handler::server::ServerHandler;
use rmcp::model::{
CreateMessageRequestParams, CreateMessageResult, ErrorData as McpError,
Role as RmcpRole, SamplingMessage, SamplingMessageContent,
};
use rmcp::service::{
Peer, RequestContext, RoleClient, RoleServer, serve_directly,
};
use tokio::io::duplex;
/// Server-side handler for the tests. It overrides nothing, so every
/// `ServerHandler` method uses the trait's default implementation; the tests
/// only use the server side's `Peer` to send `create_message` requests to the
/// client.
struct TestServerHandler;
impl ServerHandler for TestServerHandler {}
/// Client-side handler that answers `create_message` (sampling) requests by
/// echoing the request's user text back as `ECHO[<text>]`.
struct TestClientHandler {
    // Artificial per-prompt delays: when the joined user text of a request
    // equals a key, `create_message` sleeps for the mapped duration before
    // replying. Used to make one in-flight call deliberately slow.
    delay_for_marker: Arc<Mutex<std::collections::HashMap<String, Duration>>>,
}
impl TestClientHandler {
    /// Creates a handler with no configured delays, so every request is
    /// answered immediately.
    fn new() -> Self {
        Self {
            delay_for_marker: Arc::new(Mutex::new(std::collections::HashMap::new())),
        }
    }

    /// Builder-style helper: requests whose joined user text equals `marker`
    /// will sleep for `delay` before being answered. Registering the same
    /// marker again overwrites the earlier delay.
    ///
    /// # Panics
    /// Panics if the delay-map mutex is poisoned.
    fn with_delay(self, marker: impl Into<String>, delay: Duration) -> Self {
        self.delay_for_marker
            .lock()
            .expect("delay map mutex")
            .insert(marker.into(), delay);
        self
    }

    /// Concatenates the text parts of every `User`-role message in `params`.
    ///
    /// Non-user messages and non-text content are skipped. A `'\n'` is
    /// inserted before a text part only once the output is already non-empty.
    fn extract_user_text(params: &CreateMessageRequestParams) -> String {
        let mut out = String::new();
        // Iterate via `.iter()` instead of `&params.messages`: the original
        // `&params` spelling was corrupted into `¶ms` (`&para` decoded as an
        // HTML entity), which does not compile.
        for msg in params
            .messages
            .iter()
            .filter(|msg| msg.role == RmcpRole::User)
        {
            for content in msg.content.iter() {
                if let SamplingMessageContent::Text(text) = content {
                    if !out.is_empty() {
                        out.push('\n');
                    }
                    out.push_str(&text.text);
                }
            }
        }
        out
    }
}
impl ClientHandler for TestClientHandler {
    /// Answers a sampling request with `ECHO[<user text>]`.
    ///
    /// If a delay was registered for the exact user text via
    /// `TestClientHandler::with_delay`, it sleeps for that long first. This
    /// lets tests keep one request in flight while others complete.
    ///
    /// # Panics
    /// Panics if the delay-map mutex is poisoned.
    async fn create_message(
        &self,
        params: CreateMessageRequestParams,
        _context: RequestContext<RoleClient>,
    ) -> std::result::Result<CreateMessageResult, McpError> {
        // `&params` (previously corrupted into `¶ms`, which does not compile).
        let prompt = Self::extract_user_text(&params);
        // Copy the delay out inside its own scope so the std `MutexGuard`
        // (which is `!Send`) is dropped before the `.await` below.
        let delay = {
            let delays = self.delay_for_marker.lock().expect("delay map mutex");
            delays.get(&prompt).copied()
        };
        if let Some(d) = delay {
            tokio::time::sleep(d).await;
        }
        let echoed = format!("ECHO[{prompt}]");
        Ok(CreateMessageResult::new(
            SamplingMessage::assistant_text(echoed),
            "test-client".to_string(),
        ))
    }
}
/// Connects a `TestServerHandler` service and a client service running
/// `handler` over an in-memory duplex pipe.
///
/// Both sides are started with `serve_directly` and no peer info (`None`).
/// That presumably skips the MCP initialize handshake; confirm against the
/// rmcp docs.
///
/// Returns the server-side `Peer` (used by the tests to send requests to the
/// client) together with both running services. The tests keep the services
/// bound rather than dropping them, presumably so they stay alive for the
/// duration of the test.
async fn make_peer(
    handler: TestClientHandler,
) -> (
    Peer<RoleServer>,
    rmcp::service::RunningService<RoleServer, TestServerHandler>,
    rmcp::service::RunningService<RoleClient, TestClientHandler>,
) {
    // Buffer size of the in-memory pipe, in bytes.
    const PIPE_CAPACITY: usize = 64 * 1024;

    let (server_end, client_end) = duplex(PIPE_CAPACITY);
    let server = serve_directly(TestServerHandler, server_end, None);
    let client = serve_directly(handler, client_end, None);
    let server_peer = server.peer().clone();
    (server_peer, server, client)
}
/// Collects the text parts of a sampling result's message into one string.
///
/// Non-text content is skipped. A `'\n'` is inserted before a text part only
/// once the accumulated string is non-empty.
fn extract_assistant_text(result: &CreateMessageResult) -> String {
    let mut joined = String::new();
    for part in result.message.content.iter() {
        let SamplingMessageContent::Text(text) = part else {
            continue;
        };
        if !joined.is_empty() {
            joined.push('\n');
        }
        joined.push_str(&text.text);
    }
    joined
}
/// Builds a sampling request whose only message is a user-text message
/// containing `marker`. The echo handler replies with `ECHO[<marker>]`.
fn marker_request(marker: &str) -> CreateMessageRequestParams {
    // Presumably the max-tokens budget. The echo handler in this file never
    // reads it.
    let max_tokens = 32;
    let messages = vec![SamplingMessage::user_text(marker)];
    CreateMessageRequestParams::new(messages, max_tokens)
}
/// Sends five `create_message` calls one after another and checks that each
/// reply echoes that call's own marker.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn peer_create_message_serial_calls_get_correct_responses() {
    let handler = TestClientHandler::new();
    let (peer, _server_handle, _client_handle) = make_peer(handler).await;

    for (i, marker) in (0..5).map(|i| (i, format!("CALL_{i}"))) {
        let response = peer
            .create_message(marker_request(&marker))
            .await
            .expect("create_message should succeed");
        let expected = format!("ECHO[{marker}]");
        assert_eq!(
            extract_assistant_text(&response),
            expected,
            "serial call {i} should receive its OWN response, not some other call's"
        );
    }
}
/// Spawns sixteen concurrent `create_message` calls over the same peer and
/// checks that every call receives the echo of its own marker.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn peer_create_message_handles_concurrent_calls_without_cross_wiring() {
    const N: usize = 16;

    let (peer, _server_handle, _client_handle) =
        make_peer(TestClientHandler::new()).await;
    let shared_peer = Arc::new(peer);

    // Spawn all N requests up front so they can be in flight together.
    let handles: Vec<_> = (0..N)
        .map(|call_id| {
            let peer = Arc::clone(&shared_peer);
            tokio::spawn(async move {
                let marker = format!("CALL_{call_id}");
                let result = peer
                    .create_message(marker_request(&marker))
                    .await
                    .expect("create_message should succeed");
                (call_id, extract_assistant_text(&result))
            })
        })
        .collect();

    let mut got = Vec::with_capacity(N);
    for handle in handles {
        got.push(handle.await.expect("join should succeed"));
    }

    for (call_id, echoed) in &got {
        let expected = format!("ECHO[CALL_{call_id}]");
        assert_eq!(
            echoed, &expected,
            "call {call_id} received {echoed:?}, expected {expected:?}. \
             rmcp cross-wired the response ids?"
        );
    }
    assert_eq!(got.len(), N);
}
/// Starts one call whose reply is delayed by 150 ms, then four undelayed
/// calls. Checks that every fast call and the slow call each get their own
/// echo.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn peer_create_message_handles_slow_call_among_fast_ones() {
    let slow_marker = "CALL_SLOW";
    let (peer, _server_handle, _client_handle) = make_peer(
        TestClientHandler::new().with_delay(slow_marker, Duration::from_millis(150)),
    )
    .await;
    let peer = Arc::new(peer);

    let slow_peer = Arc::clone(&peer);
    let slow_task = tokio::spawn(async move {
        slow_peer
            .create_message(marker_request(slow_marker))
            .await
            .expect("slow create_message should succeed")
    });
    // Yield so the slow request gets a chance (not a guarantee) to be sent
    // before the fast ones.
    tokio::task::yield_now().await;

    let fast_tasks: Vec<_> = (0..4)
        .map(|i| {
            let peer = Arc::clone(&peer);
            tokio::spawn(async move {
                let marker = format!("CALL_FAST_{i}");
                let result = peer
                    .create_message(marker_request(&marker))
                    .await
                    .expect("fast create_message should succeed");
                (i, extract_assistant_text(&result))
            })
        })
        .collect();

    for fast in fast_tasks {
        let (i, echoed) = fast.await.expect("join");
        assert_eq!(
            echoed,
            format!("ECHO[CALL_FAST_{i}]"),
            "fast call {i} got the wrong response while the slow call was \
             still in flight — cross-wiring detected"
        );
    }

    let slow_result = slow_task.await.expect("slow join");
    assert_eq!(
        extract_assistant_text(&slow_result),
        format!("ECHO[{slow_marker}]"),
        "slow call received a fast call's response — rmcp dispatch is not \
         correlating response ids correctly"
    );
}