aether_core/testing/
fake_mcp.rs1use rmcp::{
2 ErrorData as McpError, Json, RoleServer, ServerHandler,
3 handler::server::{router::tool::ToolRouter, wrapper::Parameters},
4 model::{CallToolResult, Content, Implementation, ServerCapabilities, ServerInfo},
5 service::DynService,
6 tool, tool_handler, tool_router,
7};
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10
11use mcp_utils::client::ServerConfig;
12
13pub fn fake_mcp(name: &str, server: FakeMcpServer) -> ServerConfig {
14 ServerConfig::InMemory {
15 name: name.to_string(),
16 server: server.as_dyn(),
17 }
18}
19
20#[derive(Clone)]
22pub struct FakeMcpServer {
23 tool_router: ToolRouter<Self>,
24}
25
26#[tool_handler(router = self.tool_router)]
27impl ServerHandler for FakeMcpServer {
28 fn get_info(&self) -> ServerInfo {
29 ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
30 .with_server_info(
31 Implementation::new("fake-mcp-server", "0.1.0")
32 .with_description("A fake MCP server for testing"),
33 )
34 .with_instructions("A fake MCP server for testing")
35 }
36}
37
38#[derive(Serialize, Deserialize, JsonSchema)]
39pub struct AddNumbersRequest {
40 pub a: u32,
41 pub b: u32,
42}
43
44impl AddNumbersRequest {
45 pub fn new(a: u32, b: u32) -> Self {
46 Self { a, b }
47 }
48
49 pub fn json(&self) -> Result<String, serde_json::Error> {
50 serde_json::to_string(self)
51 }
52}
53
54#[derive(Serialize, Deserialize, JsonSchema)]
55pub struct AddNumbersResult {
56 pub sum: u32,
57}
58
59impl AddNumbersResult {
60 pub fn new(sum: u32) -> Self {
61 Self { sum }
62 }
63
64 pub fn json(&self) -> Result<String, serde_json::Error> {
65 serde_json::to_string(self)
66 }
67}
68
69#[derive(Serialize, Deserialize, JsonSchema)]
70pub struct DivideNumbersRequest {
71 pub a: i32,
72 pub b: i32,
73}
74
75impl DivideNumbersRequest {
76 pub fn new(a: i32, b: i32) -> Self {
77 Self { a, b }
78 }
79
80 pub fn json(&self) -> Result<String, serde_json::Error> {
81 serde_json::to_string(self)
82 }
83}
84
85#[derive(Serialize, Deserialize, JsonSchema)]
86pub struct DivideNumbersResult {
87 pub quotient: i32,
88}
89
90impl DivideNumbersResult {
91 pub fn new(quotient: i32) -> Self {
92 Self { quotient }
93 }
94
95 pub fn json(&self) -> Result<String, serde_json::Error> {
96 serde_json::to_string(self)
97 }
98}
99
100#[derive(Serialize, Deserialize, JsonSchema)]
101pub struct SlowToolRequest {
102 pub sleep_ms: u64,
103}
104
105impl SlowToolRequest {
106 pub fn new(sleep_ms: u64) -> Self {
107 Self { sleep_ms }
108 }
109
110 pub fn json(&self) -> Result<String, serde_json::Error> {
111 serde_json::to_string(self)
112 }
113}
114
115#[derive(Serialize, Deserialize, JsonSchema)]
116pub struct SlowToolResult {
117 pub message: String,
118}
119
120impl SlowToolResult {
121 pub fn new(message: String) -> Self {
122 Self { message }
123 }
124
125 pub fn json(&self) -> Result<String, serde_json::Error> {
126 serde_json::to_string(self)
127 }
128}
129
130impl Default for FakeMcpServer {
131 fn default() -> Self {
132 Self {
133 tool_router: Self::tool_router(),
134 }
135 }
136}
137
138#[tool_router]
139impl FakeMcpServer {
140 pub fn new() -> Self {
141 Self::default()
142 }
143
144 pub fn as_dyn(self) -> Box<dyn DynService<RoleServer>> {
145 Box::new(self)
146 }
147
148 #[tool(description = "Adds two numbers together")]
149 pub async fn add_numbers(
150 &self,
151 request: Parameters<AddNumbersRequest>,
152 ) -> Json<AddNumbersResult> {
153 let Parameters(AddNumbersRequest { a, b }) = request;
154 Json(AddNumbersResult { sum: a + b })
155 }
156
157 #[tool(description = "Divides two numbers")]
158 pub async fn divide_numbers(
159 &self,
160 request: Parameters<DivideNumbersRequest>,
161 ) -> Result<CallToolResult, McpError> {
162 let Parameters(DivideNumbersRequest { a, b }) = request;
163
164 if b == 0 {
165 return Ok(CallToolResult::error(vec![Content::text(
166 "Division by zero",
167 )]));
168 }
169
170 let result = DivideNumbersResult { quotient: a / b };
171 let result_json = serde_json::to_string(&result).unwrap();
172
173 Ok(CallToolResult::success(vec![Content::text(result_json)]))
174 }
175
176 #[tool(description = "A tool that sleeps for a specified duration (for testing timeouts)")]
177 pub async fn slow_tool(&self, request: Parameters<SlowToolRequest>) -> Json<SlowToolResult> {
178 let Parameters(SlowToolRequest { sleep_ms }) = request;
179 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
180 Json(SlowToolResult {
181 message: format!("Slept for {sleep_ms}ms"),
182 })
183 }
184}