//! `aether_core/testing/fake_mcp.rs` — an in-process fake MCP server used as a test fixture.

1use rmcp::{
2    ErrorData as McpError, Json, RoleServer, ServerHandler,
3    handler::server::{router::tool::ToolRouter, wrapper::Parameters},
4    model::{CallToolResult, Content, Implementation, ServerCapabilities, ServerInfo},
5    service::DynService,
6    tool, tool_handler, tool_router,
7};
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10
11use mcp_utils::client::ServerConfig;
12
13pub fn fake_mcp(name: &str, server: FakeMcpServer) -> ServerConfig {
14    ServerConfig::InMemory { name: name.to_string(), server: server.into_dyn() }
15}
16
17/// A fake MCP server for testing
18#[derive(Clone)]
19pub struct FakeMcpServer {
20    tool_router: ToolRouter<Self>,
21}
22
23#[tool_handler(router = self.tool_router)]
24impl ServerHandler for FakeMcpServer {
25    fn get_info(&self) -> ServerInfo {
26        ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
27            .with_server_info(
28                Implementation::new("fake-mcp-server", "0.1.0").with_description("A fake MCP server for testing"),
29            )
30            .with_instructions("A fake MCP server for testing")
31    }
32}
33
34#[derive(Serialize, Deserialize, JsonSchema)]
35pub struct AddNumbersRequest {
36    pub a: u32,
37    pub b: u32,
38}
39
40impl AddNumbersRequest {
41    pub fn new(a: u32, b: u32) -> Self {
42        Self { a, b }
43    }
44
45    pub fn json(&self) -> Result<String, serde_json::Error> {
46        serde_json::to_string(self)
47    }
48}
49
50#[derive(Serialize, Deserialize, JsonSchema)]
51pub struct AddNumbersResult {
52    pub sum: u32,
53}
54
55impl AddNumbersResult {
56    pub fn new(sum: u32) -> Self {
57        Self { sum }
58    }
59
60    pub fn json(&self) -> Result<String, serde_json::Error> {
61        serde_json::to_string(self)
62    }
63}
64
65#[derive(Serialize, Deserialize, JsonSchema)]
66pub struct DivideNumbersRequest {
67    pub a: i32,
68    pub b: i32,
69}
70
71impl DivideNumbersRequest {
72    pub fn new(a: i32, b: i32) -> Self {
73        Self { a, b }
74    }
75
76    pub fn json(&self) -> Result<String, serde_json::Error> {
77        serde_json::to_string(self)
78    }
79}
80
81#[derive(Serialize, Deserialize, JsonSchema)]
82pub struct DivideNumbersResult {
83    pub quotient: i32,
84}
85
86impl DivideNumbersResult {
87    pub fn new(quotient: i32) -> Self {
88        Self { quotient }
89    }
90
91    pub fn json(&self) -> Result<String, serde_json::Error> {
92        serde_json::to_string(self)
93    }
94}
95
96#[derive(Serialize, Deserialize, JsonSchema)]
97pub struct SlowToolRequest {
98    pub sleep_ms: u64,
99}
100
101impl SlowToolRequest {
102    pub fn new(sleep_ms: u64) -> Self {
103        Self { sleep_ms }
104    }
105
106    pub fn json(&self) -> Result<String, serde_json::Error> {
107        serde_json::to_string(self)
108    }
109}
110
111#[derive(Serialize, Deserialize, JsonSchema)]
112pub struct SlowToolResult {
113    pub message: String,
114}
115
116impl SlowToolResult {
117    pub fn new(message: String) -> Self {
118        Self { message }
119    }
120
121    pub fn json(&self) -> Result<String, serde_json::Error> {
122        serde_json::to_string(self)
123    }
124}
125
126impl Default for FakeMcpServer {
127    fn default() -> Self {
128        Self { tool_router: Self::tool_router() }
129    }
130}
131
132#[tool_router]
133impl FakeMcpServer {
134    pub fn new() -> Self {
135        Self::default()
136    }
137
138    pub fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
139        Box::new(self)
140    }
141
142    #[tool(description = "Adds two numbers together")]
143    pub async fn add_numbers(&self, request: Parameters<AddNumbersRequest>) -> Json<AddNumbersResult> {
144        let Parameters(AddNumbersRequest { a, b }) = request;
145        Json(AddNumbersResult { sum: a + b })
146    }
147
148    #[tool(description = "Divides two numbers")]
149    pub async fn divide_numbers(&self, request: Parameters<DivideNumbersRequest>) -> Result<CallToolResult, McpError> {
150        let Parameters(DivideNumbersRequest { a, b }) = request;
151
152        if b == 0 {
153            return Ok(CallToolResult::error(vec![Content::text("Division by zero")]));
154        }
155
156        let result = DivideNumbersResult { quotient: a / b };
157        let result_json = serde_json::to_string(&result).unwrap();
158
159        Ok(CallToolResult::success(vec![Content::text(result_json)]))
160    }
161
162    #[tool(description = "A tool that sleeps for a specified duration (for testing timeouts)")]
163    pub async fn slow_tool(&self, request: Parameters<SlowToolRequest>) -> Json<SlowToolResult> {
164        let Parameters(SlowToolRequest { sleep_ms }) = request;
165        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
166        Json(SlowToolResult { message: format!("Slept for {sleep_ms}ms") })
167    }
168}