// aether_core/testing/fake_mcp.rs
use rmcp::{
2 ErrorData as McpError, Json, RoleServer, ServerHandler,
3 handler::server::{router::tool::ToolRouter, wrapper::Parameters},
4 model::{CallToolResult, Content, Implementation, ServerCapabilities, ServerInfo},
5 service::DynService,
6 tool, tool_handler, tool_router,
7};
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10
11use mcp_utils::client::ServerConfig;
12
13pub fn fake_mcp(name: &str, server: FakeMcpServer) -> ServerConfig {
14 ServerConfig::InMemory { name: name.to_string(), server: server.into_dyn() }
15}
16
17#[derive(Clone)]
19pub struct FakeMcpServer {
20 tool_router: ToolRouter<Self>,
21}
22
23#[tool_handler(router = self.tool_router)]
24impl ServerHandler for FakeMcpServer {
25 fn get_info(&self) -> ServerInfo {
26 ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
27 .with_server_info(
28 Implementation::new("fake-mcp-server", "0.1.0").with_description("A fake MCP server for testing"),
29 )
30 .with_instructions("A fake MCP server for testing")
31 }
32}
33
34#[derive(Serialize, Deserialize, JsonSchema)]
35pub struct AddNumbersRequest {
36 pub a: u32,
37 pub b: u32,
38}
39
40impl AddNumbersRequest {
41 pub fn new(a: u32, b: u32) -> Self {
42 Self { a, b }
43 }
44
45 pub fn json(&self) -> Result<String, serde_json::Error> {
46 serde_json::to_string(self)
47 }
48}
49
50#[derive(Serialize, Deserialize, JsonSchema)]
51pub struct AddNumbersResult {
52 pub sum: u32,
53}
54
55impl AddNumbersResult {
56 pub fn new(sum: u32) -> Self {
57 Self { sum }
58 }
59
60 pub fn json(&self) -> Result<String, serde_json::Error> {
61 serde_json::to_string(self)
62 }
63}
64
65#[derive(Serialize, Deserialize, JsonSchema)]
66pub struct DivideNumbersRequest {
67 pub a: i32,
68 pub b: i32,
69}
70
71impl DivideNumbersRequest {
72 pub fn new(a: i32, b: i32) -> Self {
73 Self { a, b }
74 }
75
76 pub fn json(&self) -> Result<String, serde_json::Error> {
77 serde_json::to_string(self)
78 }
79}
80
81#[derive(Serialize, Deserialize, JsonSchema)]
82pub struct DivideNumbersResult {
83 pub quotient: i32,
84}
85
86impl DivideNumbersResult {
87 pub fn new(quotient: i32) -> Self {
88 Self { quotient }
89 }
90
91 pub fn json(&self) -> Result<String, serde_json::Error> {
92 serde_json::to_string(self)
93 }
94}
95
96#[derive(Serialize, Deserialize, JsonSchema)]
97pub struct SlowToolRequest {
98 pub sleep_ms: u64,
99}
100
101impl SlowToolRequest {
102 pub fn new(sleep_ms: u64) -> Self {
103 Self { sleep_ms }
104 }
105
106 pub fn json(&self) -> Result<String, serde_json::Error> {
107 serde_json::to_string(self)
108 }
109}
110
111#[derive(Serialize, Deserialize, JsonSchema)]
112pub struct SlowToolResult {
113 pub message: String,
114}
115
116impl SlowToolResult {
117 pub fn new(message: String) -> Self {
118 Self { message }
119 }
120
121 pub fn json(&self) -> Result<String, serde_json::Error> {
122 serde_json::to_string(self)
123 }
124}
125
126impl Default for FakeMcpServer {
127 fn default() -> Self {
128 Self { tool_router: Self::tool_router() }
129 }
130}
131
132#[tool_router]
133impl FakeMcpServer {
134 pub fn new() -> Self {
135 Self::default()
136 }
137
138 pub fn into_dyn(self) -> Box<dyn DynService<RoleServer>> {
139 Box::new(self)
140 }
141
142 #[tool(description = "Adds two numbers together")]
143 pub async fn add_numbers(&self, request: Parameters<AddNumbersRequest>) -> Json<AddNumbersResult> {
144 let Parameters(AddNumbersRequest { a, b }) = request;
145 Json(AddNumbersResult { sum: a + b })
146 }
147
148 #[tool(description = "Divides two numbers")]
149 pub async fn divide_numbers(&self, request: Parameters<DivideNumbersRequest>) -> Result<CallToolResult, McpError> {
150 let Parameters(DivideNumbersRequest { a, b }) = request;
151
152 if b == 0 {
153 return Ok(CallToolResult::error(vec![Content::text("Division by zero")]));
154 }
155
156 let result = DivideNumbersResult { quotient: a / b };
157 let result_json = serde_json::to_string(&result).unwrap();
158
159 Ok(CallToolResult::success(vec![Content::text(result_json)]))
160 }
161
162 #[tool(description = "A tool that sleeps for a specified duration (for testing timeouts)")]
163 pub async fn slow_tool(&self, request: Parameters<SlowToolRequest>) -> Json<SlowToolResult> {
164 let Parameters(SlowToolRequest { sleep_ms }) = request;
165 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
166 Json(SlowToolResult { message: format!("Slept for {sleep_ms}ms") })
167 }
168}