#[cfg(test)]
mod tests {
use std::{collections::HashMap, result::Result as StdResult, sync::Arc};
use async_trait::async_trait;
use rmcp::{ServiceExt, model as rmcp_model};
use rmcp_model::{CallToolRequestParams, InitializeRequestParams, PaginatedRequestParams};
use serde_json::json;
use tmcp::{
Arguments, Client, Error, Result, Server, ServerCtx, ServerHandler, ToolError, schema::*,
testutils::make_duplex_pair,
};
use tokio::{
io::duplex,
time::{Duration, sleep, timeout},
};
use tracing_subscriber::fmt;
/// Minimal tmcp server handler exposing a single `echo` tool, which returns
/// its `message` argument back as a text content block.
struct EchoConnection;

#[async_trait]
impl ServerHandler for EchoConnection {
    /// Accepts any client (all handshake inputs are ignored) and advertises
    /// tool support.
    async fn initialize(
        &self,
        _context: &ServerCtx,
        _protocol_version: String,
        _capabilities: ClientCapabilities,
        _client_info: Implementation,
    ) -> Result<InitializeResult> {
        Ok(InitializeResult::new("test-server")
            .with_version("0.1.0")
            .with_tools(true))
    }

    /// Lists the single `echo` tool, whose input schema requires a `message`
    /// string. The cursor is ignored because everything fits on one page.
    async fn list_tools(
        &self,
        _context: &ServerCtx,
        _cursor: Option<Cursor>,
    ) -> Result<ListToolsResult> {
        // Log under the actual trait method name so traces can be grepped for it.
        tracing::info!("EchoConnection.list_tools called");
        let schema = ToolSchema::default()
            .with_property(
                "message",
                json!({
                    "type": "string",
                    "description": "The message to echo"
                }),
            )
            .with_required("message");
        Ok(ListToolsResult::new()
            .with_tool(Tool::new("echo", schema).with_description("Echoes the input message")))
    }

    /// Echoes `arguments["message"]` back as a single text block.
    ///
    /// # Errors
    /// Returns `Error::ToolNotFound` for any tool name other than `echo`.
    /// A missing argument map or `message` parameter is not a protocol error.
    /// It is reported in-band as a `ToolError::invalid_input` result.
    async fn call_tool(
        &self,
        _context: &ServerCtx,
        name: String,
        arguments: Option<Arguments>,
        _task: Option<TaskMetadata>,
    ) -> Result<CallToolResult> {
        if name != "echo" {
            return Err(Error::ToolNotFound(name));
        }
        let Some(args) = arguments else {
            return Ok(ToolError::invalid_input("echo: Missing arguments").into());
        };
        let Some(message) = args.get::<String>("message") else {
            return Ok(ToolError::invalid_input("echo: Missing message parameter").into());
        };
        Ok(CallToolResult {
            content: vec![ContentBlock::Text(TextContent {
                text: message,
                annotations: None,
                _meta: None,
            })],
            is_error: Some(false),
            structured_content: None,
            _meta: None,
        })
    }
}
/// Interop check: a tmcp [`Server`] hosting [`EchoConnection`], driven by an
/// rmcp client over an in-memory duplex pair.
#[tokio::test(flavor = "multi_thread")]
async fn test_tmcp_server_with_rmcp_client() {
    // Best-effort: installing the global subscriber fails if one is already set.
    fmt::try_init().ok();

    let (server_reader, server_writer, client_reader, client_writer) = make_duplex_pair();
    let server = Server::new(|| EchoConnection);
    let server_handle = tmcp::ServerHandle::from_stream(server, server_reader, server_writer)
        .await
        .expect("Failed to start server");
    // NOTE(review): short grace period kept from the original. `from_stream` has
    // already returned a handle at this point, so this is likely removable.
    sleep(Duration::from_millis(100)).await;

    let client = timeout(
        Duration::from_secs(5),
        ().serve((client_reader, client_writer)),
    )
    .await
    .expect("Client connection timed out")
    .expect("Failed to connect client");

    let tools = client.list_tools(None).await.expect("tools/list failed");
    assert_eq!(tools.tools.len(), 1);
    assert_eq!(tools.tools[0].name, "echo");

    let mut args = serde_json::Map::new();
    args.insert("message".to_string(), json!("Hello from rmcp!"));
    let result = client
        .call_tool(rmcp_model::CallToolRequestParams {
            meta: None,
            name: "echo".into(),
            arguments: Some(args),
            task: None,
        })
        .await
        .expect("tools/call failed");
    // A tool-level failure still arrives as Ok(..) with `is_error` set, so
    // check it explicitly rather than relying on the content alone.
    assert_ne!(result.is_error, Some(true), "echo reported a tool error");
    assert_eq!(result.content.len(), 1);
    match &result.content[0].raw {
        rmcp_model::RawContent::Text(text_content) => {
            assert_eq!(&text_content.text, "Hello from rmcp!");
        }
        _ => panic!("Expected text content"),
    }

    // Tear down: close the client, then give the server a bounded window to
    // stop. Cleanup is best-effort, so a timeout or stop error is ignored.
    drop(client);
    let _ = timeout(Duration::from_millis(100), server_handle.stop()).await;
}
/// Interop check in the opposite direction: an rmcp server exposing a single
/// `reverse` tool, driven by a tmcp [`Client`] over an in-memory duplex pair.
#[tokio::test]
async fn test_rmcp_server_with_tmcp_client() {
    use rmcp::{
        handler::server::ServerHandler,
        service::{RequestContext, RoleServer},
    };

    // Minimal rmcp server with one tool, `reverse`, which returns its `text`
    // argument reversed character by character.
    #[derive(Debug, Clone)]
    struct TestRmcpServer;

    impl ServerHandler for TestRmcpServer {
        async fn initialize(
            &self,
            _request: InitializeRequestParams,
            _ctx: RequestContext<RoleServer>,
        ) -> StdResult<rmcp_model::InitializeResult, rmcp::ErrorData> {
            Ok(rmcp_model::InitializeResult {
                protocol_version: rmcp_model::ProtocolVersion::default(),
                // This server answers tools/list and tools/call, so it must
                // declare the `tools` capability. MCP requires this, and it
                // matches what `EchoConnection` advertises via `with_tools(true)`.
                capabilities: rmcp_model::ServerCapabilities::builder()
                    .enable_tools()
                    .build(),
                server_info: rmcp_model::Implementation {
                    name: "test-rmcp-server".to_string(),
                    title: None,
                    version: "0.1.0".to_string(),
                    icons: None,
                    website_url: None,
                },
                instructions: None,
            })
        }

        async fn list_tools(
            &self,
            _params: Option<PaginatedRequestParams>,
            _ctx: RequestContext<RoleServer>,
        ) -> StdResult<rmcp_model::ListToolsResult, rmcp::ErrorData> {
            let schema_value = json!({
                "type": "object",
                "properties": {
                    "text": {
                        "type": "string",
                        "description": "Text to reverse"
                    }
                },
                "required": ["text"]
            });
            let serde_json::Value::Object(input_schema) = schema_value else {
                unreachable!("schema literal is a JSON object");
            };
            Ok(rmcp_model::ListToolsResult {
                meta: None,
                next_cursor: None,
                tools: vec![rmcp_model::Tool {
                    name: "reverse".into(),
                    title: None,
                    description: Some("Reverses a string".into()),
                    input_schema: Arc::new(input_schema),
                    output_schema: None,
                    annotations: None,
                    icons: None,
                    meta: None,
                }],
            })
        }

        async fn call_tool(
            &self,
            params: CallToolRequestParams,
            _ctx: RequestContext<RoleServer>,
        ) -> StdResult<rmcp_model::CallToolResult, rmcp::ErrorData> {
            if params.name != "reverse" {
                // MCP reports unknown tools as an invalid-params (-32602) protocol error.
                return Err(rmcp::ErrorData::invalid_params("Unknown tool", None));
            }
            let text = params
                .arguments
                .as_ref()
                .and_then(|args| args.get("text"))
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    rmcp::ErrorData::invalid_params("reverse: Missing text parameter", None)
                })?;
            let reversed = text.chars().rev().collect::<String>();
            Ok(rmcp_model::CallToolResult {
                content: vec![rmcp_model::Content::text(reversed)],
                structured_content: None,
                is_error: None,
                meta: None,
            })
        }
    }

    let (client_reader, server_writer) = duplex(8192);
    let (server_reader, client_writer) = duplex(8192);
    let server_handle = tokio::spawn(async move {
        let service = TestRmcpServer
            .serve((server_reader, server_writer))
            .await
            .expect("Failed to start rmcp server");
        // Serve until the transport closes or the task is aborted below. The
        // original fixed sleep would tear the service down mid-test if the test
        // ever ran longer than that sleep.
        let _ = service.waiting().await;
    });

    // No settling sleep is needed: the duplex buffers the client's initialize
    // request until the spawned server task reads it.
    let mut client = Client::new("test-client", "0.1.0");
    let init_result = timeout(
        Duration::from_secs(5),
        client.connect_stream(client_reader, client_writer),
    )
    .await
    .expect("Client connection timed out")
    .expect("Failed to connect client");
    assert_eq!(init_result.server_info.name, "test-rmcp-server");

    let tools = client.list_tools(None).await.expect("tools/list failed");
    assert_eq!(tools.tools.len(), 1);
    assert_eq!(tools.tools[0].name, "reverse");

    let mut args = HashMap::new();
    args.insert("text".to_string(), json!("hello"));
    let result = client
        .call_tool("reverse", args)
        .await
        .expect("tools/call failed");
    assert_eq!(result.content.len(), 1);
    match &result.content[0] {
        ContentBlock::Text(text) => {
            assert_eq!(text.text, "olleh");
        }
        _ => panic!("Expected text content"),
    }

    server_handle.abort();
}
}