aether_core/testing/
fake_mcp.rs1use rmcp::{
2 ErrorData as McpError, Json, RoleServer, ServerHandler,
3 handler::server::{router::tool::ToolRouter, wrapper::Parameters},
4 model::{CallToolResult, Content, Implementation, ServerCapabilities, ServerInfo},
5 service::DynService,
6 tool, tool_handler, tool_router,
7};
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10
11use mcp_utils::client::{McpServer, McpTransport};
12
13pub fn fake_mcp(name: &str, server: FakeMcpServer) -> McpServer {
14 fake_mcp_with_proxy(name, server, false)
15}
16
17pub fn fake_mcp_with_proxy(name: &str, server: FakeMcpServer, proxy: bool) -> McpServer {
18 McpServer::new(name, McpTransport::InMemory { server: server.into_dyn() }, proxy)
19}
20
21#[derive(Clone)]
23pub struct FakeMcpServer {
24 tool_router: ToolRouter<Self>,
25}
26
27#[tool_handler(router = self.tool_router)]
28impl ServerHandler for FakeMcpServer {
29 fn get_info(&self) -> ServerInfo {
30 ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
31 .with_server_info(
32 Implementation::new("fake-mcp-server", "0.1.0").with_description("A fake MCP server for testing"),
33 )
34 .with_instructions("A fake MCP server for testing")
35 }
36}
37
38#[derive(Serialize, Deserialize, JsonSchema)]
39pub struct AddNumbersRequest {
40 pub a: u32,
41 pub b: u32,
42}
43
44impl AddNumbersRequest {
45 pub fn new(a: u32, b: u32) -> Self {
46 Self { a, b }
47 }
48
49 pub fn json(&self) -> Result<String, serde_json::Error> {
50 serde_json::to_string(self)
51 }
52}
53
54#[derive(Serialize, Deserialize, JsonSchema)]
55pub struct AddNumbersResult {
56 pub sum: u32,
57}
58
59impl AddNumbersResult {
60 pub fn new(sum: u32) -> Self {
61 Self { sum }
62 }
63
64 pub fn json(&self) -> Result<String, serde_json::Error> {
65 serde_json::to_string(self)
66 }
67}
68
69#[derive(Serialize, Deserialize, JsonSchema)]
70pub struct DivideNumbersRequest {
71 pub a: i32,
72 pub b: i32,
73}
74
75impl DivideNumbersRequest {
76 pub fn new(a: i32, b: i32) -> Self {
77 Self { a, b }
78 }
79
80 pub fn json(&self) -> Result<String, serde_json::Error> {
81 serde_json::to_string(self)
82 }
83}
84
85#[derive(Serialize, Deserialize, JsonSchema)]
86pub struct DivideNumbersResult {
87 pub quotient: i32,
88}
89
90impl DivideNumbersResult {
91 pub fn new(quotient: i32) -> Self {
92 Self { quotient }
93 }
94
95 pub fn json(&self) -> Result<String, serde_json::Error> {
96 serde_json::to_string(self)
97 }
98}
99
100#[derive(Serialize, Deserialize, JsonSchema)]
101pub struct SlowToolRequest {
102 pub sleep_ms: u64,
103}
104
105impl SlowToolRequest {
106 pub fn new(sleep_ms: u64) -> Self {
107 Self { sleep_ms }
108 }
109
110 pub fn json(&self) -> Result<String, serde_json::Error> {
111 serde_json::to_string(self)
112 }
113}
114
115#[derive(Serialize, Deserialize, JsonSchema)]
116pub struct SlowToolResult {
117 pub message: String,
118}
119
120impl SlowToolResult {
121 pub fn new(message: String) -> Self {
122 Self { message }
123 }
124
125 pub fn json(&self) -> Result<String, serde_json::Error> {
126 serde_json::to_string(self)
127 }
128}
129
130impl Default for FakeMcpServer {
131 fn default() -> Self {
132 Self { tool_router: Self::tool_router() }
133 }
134}
135
136#[tool_router]
137impl FakeMcpServer {
138 pub fn new() -> Self {
139 Self::default()
140 }
141
142 pub fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
143 Box::new(self)
144 }
145
146 #[tool(description = "Adds two numbers together")]
147 pub async fn add_numbers(&self, request: Parameters<AddNumbersRequest>) -> Json<AddNumbersResult> {
148 let Parameters(AddNumbersRequest { a, b }) = request;
149 Json(AddNumbersResult { sum: a + b })
150 }
151
152 #[tool(description = "Divides two numbers")]
153 pub async fn divide_numbers(&self, request: Parameters<DivideNumbersRequest>) -> Result<CallToolResult, McpError> {
154 let Parameters(DivideNumbersRequest { a, b }) = request;
155
156 if b == 0 {
157 return Ok(CallToolResult::error(vec![Content::text("Division by zero")]));
158 }
159
160 let result = DivideNumbersResult { quotient: a / b };
161 let result_json = serde_json::to_string(&result).unwrap();
162
163 Ok(CallToolResult::success(vec![Content::text(result_json)]))
164 }
165
166 #[tool(description = "A tool that sleeps for a specified duration (for testing timeouts)")]
167 pub async fn slow_tool(&self, request: Parameters<SlowToolRequest>) -> Json<SlowToolResult> {
168 let Parameters(SlowToolRequest { sleep_ms }) = request;
169 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
170 Json(SlowToolResult { message: format!("Slept for {sleep_ms}ms") })
171 }
172}