Skip to main content

aether_core/testing/
fake_mcp.rs

1use rmcp::{
2    ErrorData as McpError, Json, RoleServer, ServerHandler,
3    handler::server::{router::tool::ToolRouter, wrapper::Parameters},
4    model::{CallToolResult, Content, Implementation, ServerCapabilities, ServerInfo},
5    service::DynService,
6    tool, tool_handler, tool_router,
7};
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10
11use mcp_utils::client::ServerConfig;
12
13pub fn fake_mcp(name: &str, server: FakeMcpServer) -> ServerConfig {
14    ServerConfig::InMemory {
15        name: name.to_string(),
16        server: server.as_dyn(),
17    }
18}
19
20/// A fake MCP server for testing
21#[derive(Clone)]
22pub struct FakeMcpServer {
23    tool_router: ToolRouter<Self>,
24}
25
26#[tool_handler(router = self.tool_router)]
27impl ServerHandler for FakeMcpServer {
28    fn get_info(&self) -> ServerInfo {
29        ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
30            .with_server_info(
31                Implementation::new("fake-mcp-server", "0.1.0")
32                    .with_description("A fake MCP server for testing"),
33            )
34            .with_instructions("A fake MCP server for testing")
35    }
36}
37
38#[derive(Serialize, Deserialize, JsonSchema)]
39pub struct AddNumbersRequest {
40    pub a: u32,
41    pub b: u32,
42}
43
44impl AddNumbersRequest {
45    pub fn new(a: u32, b: u32) -> Self {
46        Self { a, b }
47    }
48
49    pub fn json(&self) -> Result<String, serde_json::Error> {
50        serde_json::to_string(self)
51    }
52}
53
54#[derive(Serialize, Deserialize, JsonSchema)]
55pub struct AddNumbersResult {
56    pub sum: u32,
57}
58
59impl AddNumbersResult {
60    pub fn new(sum: u32) -> Self {
61        Self { sum }
62    }
63
64    pub fn json(&self) -> Result<String, serde_json::Error> {
65        serde_json::to_string(self)
66    }
67}
68
69#[derive(Serialize, Deserialize, JsonSchema)]
70pub struct DivideNumbersRequest {
71    pub a: i32,
72    pub b: i32,
73}
74
75impl DivideNumbersRequest {
76    pub fn new(a: i32, b: i32) -> Self {
77        Self { a, b }
78    }
79
80    pub fn json(&self) -> Result<String, serde_json::Error> {
81        serde_json::to_string(self)
82    }
83}
84
85#[derive(Serialize, Deserialize, JsonSchema)]
86pub struct DivideNumbersResult {
87    pub quotient: i32,
88}
89
90impl DivideNumbersResult {
91    pub fn new(quotient: i32) -> Self {
92        Self { quotient }
93    }
94
95    pub fn json(&self) -> Result<String, serde_json::Error> {
96        serde_json::to_string(self)
97    }
98}
99
100#[derive(Serialize, Deserialize, JsonSchema)]
101pub struct SlowToolRequest {
102    pub sleep_ms: u64,
103}
104
105impl SlowToolRequest {
106    pub fn new(sleep_ms: u64) -> Self {
107        Self { sleep_ms }
108    }
109
110    pub fn json(&self) -> Result<String, serde_json::Error> {
111        serde_json::to_string(self)
112    }
113}
114
115#[derive(Serialize, Deserialize, JsonSchema)]
116pub struct SlowToolResult {
117    pub message: String,
118}
119
120impl SlowToolResult {
121    pub fn new(message: String) -> Self {
122        Self { message }
123    }
124
125    pub fn json(&self) -> Result<String, serde_json::Error> {
126        serde_json::to_string(self)
127    }
128}
129
130impl Default for FakeMcpServer {
131    fn default() -> Self {
132        Self {
133            tool_router: Self::tool_router(),
134        }
135    }
136}
137
138#[tool_router]
139impl FakeMcpServer {
140    pub fn new() -> Self {
141        Self::default()
142    }
143
144    pub fn as_dyn(self) -> Box<dyn DynService<RoleServer>> {
145        Box::new(self)
146    }
147
148    #[tool(description = "Adds two numbers together")]
149    pub async fn add_numbers(
150        &self,
151        request: Parameters<AddNumbersRequest>,
152    ) -> Json<AddNumbersResult> {
153        let Parameters(AddNumbersRequest { a, b }) = request;
154        Json(AddNumbersResult { sum: a + b })
155    }
156
157    #[tool(description = "Divides two numbers")]
158    pub async fn divide_numbers(
159        &self,
160        request: Parameters<DivideNumbersRequest>,
161    ) -> Result<CallToolResult, McpError> {
162        let Parameters(DivideNumbersRequest { a, b }) = request;
163
164        if b == 0 {
165            return Ok(CallToolResult::error(vec![Content::text(
166                "Division by zero",
167            )]));
168        }
169
170        let result = DivideNumbersResult { quotient: a / b };
171        let result_json = serde_json::to_string(&result).unwrap();
172
173        Ok(CallToolResult::success(vec![Content::text(result_json)]))
174    }
175
176    #[tool(description = "A tool that sleeps for a specified duration (for testing timeouts)")]
177    pub async fn slow_tool(&self, request: Parameters<SlowToolRequest>) -> Json<SlowToolResult> {
178        let Parameters(SlowToolRequest { sleep_ms }) = request;
179        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
180        Json(SlowToolResult {
181            message: format!("Slept for {sleep_ms}ms"),
182        })
183    }
184}