#![cfg(not(feature = "local"))]
#![cfg(feature = "client")]
use rmcp::{
ClientHandler, ServerHandler, ServiceError, ServiceExt,
handler::server::router::tool::ToolRouter,
model::{CallToolRequestParams, ClientInfo, ErrorCode, JsonObject},
tool, tool_handler, tool_router,
};
#[derive(Debug, Clone)]
pub struct TaskSupportTestServer {
tool_router: ToolRouter<Self>,
}
impl Default for TaskSupportTestServer {
fn default() -> Self {
Self::new()
}
}
impl TaskSupportTestServer {
pub fn new() -> Self {
Self {
tool_router: Self::tool_router(),
}
}
}
#[tool_router]
impl TaskSupportTestServer {
#[tool(
description = "Tool that requires task-based invocation",
execution(task_support = "required")
)]
async fn required_task_tool(&self) -> String {
"required task executed".to_string()
}
#[tool(
description = "Tool that forbids task-based invocation",
execution(task_support = "forbidden")
)]
async fn forbidden_task_tool(&self) -> String {
"forbidden task executed".to_string()
}
#[tool(
description = "Tool that optionally supports task-based invocation",
execution(task_support = "optional")
)]
async fn optional_task_tool(&self) -> String {
"optional task executed".to_string()
}
}
#[tool_handler]
impl ServerHandler for TaskSupportTestServer {}
#[derive(Debug, Clone, Default)]
struct DummyClientHandler {}
impl ClientHandler for DummyClientHandler {
fn get_info(&self) -> ClientInfo {
ClientInfo::default()
}
}
fn make_task() -> JsonObject {
serde_json::Map::new()
}
#[tokio::test]
async fn test_required_task_tool_without_task_returns_method_not_found() -> anyhow::Result<()> {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server = TaskSupportTestServer::new();
let server_handle = tokio::spawn(async move {
server.serve(server_transport).await?.waiting().await?;
anyhow::Ok(())
});
let client_handler = DummyClientHandler::default();
let client = client_handler.serve(client_transport).await?;
let result = client
.call_tool(CallToolRequestParams::new("required_task_tool"))
.await;
assert!(
result.is_err(),
"Expected error for required task tool without task"
);
let error = result.unwrap_err();
match error {
ServiceError::McpError(error_data) => {
assert_eq!(
error_data.code,
ErrorCode::METHOD_NOT_FOUND,
"Expected METHOD_NOT_FOUND error code (-32601)"
);
assert!(
error_data
.message
.contains("requires task-based invocation"),
"Error message should indicate task-based invocation is required, got: {}",
error_data.message
);
}
_ => panic!("Expected McpError variant, got: {:?}", error),
}
client.cancel().await?;
server_handle.await??;
Ok(())
}
#[tokio::test]
async fn test_forbidden_task_tool_with_task_returns_error() -> anyhow::Result<()> {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server = TaskSupportTestServer::new();
let server_handle = tokio::spawn(async move {
server.serve(server_transport).await?.waiting().await?;
anyhow::Ok(())
});
let client_handler = DummyClientHandler::default();
let client = client_handler.serve(client_transport).await?;
let result = client
.call_tool(CallToolRequestParams::new("forbidden_task_tool").with_task(make_task()))
.await;
assert!(
result.is_err(),
"Expected error for forbidden task tool with task"
);
let error = result.unwrap_err();
match error {
ServiceError::McpError(error_data) => {
assert_eq!(
error_data.code,
ErrorCode::INVALID_PARAMS,
"Expected INVALID_PARAMS error code"
);
assert!(
error_data
.message
.contains("does not support task-based invocation"),
"Error message should indicate task-based invocation is not supported, got: {}",
error_data.message
);
}
_ => panic!("Expected McpError variant, got: {:?}", error),
}
client.cancel().await?;
server_handle.await??;
Ok(())
}
#[tokio::test]
async fn test_forbidden_task_tool_without_task_succeeds() -> anyhow::Result<()> {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server = TaskSupportTestServer::new();
let server_handle = tokio::spawn(async move {
server.serve(server_transport).await?.waiting().await?;
anyhow::Ok(())
});
let client_handler = DummyClientHandler::default();
let client = client_handler.serve(client_transport).await?;
let result = client
.call_tool(CallToolRequestParams::new("forbidden_task_tool"))
.await;
assert!(
result.is_ok(),
"Forbidden task tool without task should succeed"
);
let result = result.unwrap();
let text = result
.content
.first()
.and_then(|c| c.raw.as_text())
.map(|t| t.text.as_str())
.unwrap_or("");
assert_eq!(text, "forbidden task executed");
client.cancel().await?;
server_handle.await??;
Ok(())
}
#[tokio::test]
async fn test_optional_task_tool_without_task_succeeds() -> anyhow::Result<()> {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server = TaskSupportTestServer::new();
let server_handle = tokio::spawn(async move {
server.serve(server_transport).await?.waiting().await?;
anyhow::Ok(())
});
let client_handler = DummyClientHandler::default();
let client = client_handler.serve(client_transport).await?;
let result = client
.call_tool(CallToolRequestParams::new("optional_task_tool"))
.await;
assert!(
result.is_ok(),
"Optional task tool without task should succeed"
);
let result = result.unwrap();
let text = result
.content
.first()
.and_then(|c| c.raw.as_text())
.map(|t| t.text.as_str())
.unwrap_or("");
assert_eq!(text, "optional task executed");
client.cancel().await?;
server_handle.await??;
Ok(())
}