1use std::sync::Arc;
5
6use serde_json::Value;
7use tokio::sync::RwLock;
8use tracing::debug;
9
10use crate::error::{INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, METHOD_NOT_FOUND};
11use crate::mcp::protocol::{
12 CallToolParams, InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse,
13 ServerCapabilities, ServerInfo, ToolsCapability, ToolsListResult, PROTOCOL_VERSION,
14};
15use crate::mcp::tool::ToolRegistry;
16
17pub struct McpServer<S: Send + Sync> {
23 name: String,
24 version: String,
25 state: Arc<RwLock<S>>,
26 tools: ToolRegistry<S>,
27}
28
29impl<S: Send + Sync + 'static> McpServer<S> {
30 pub fn new(
32 name: impl Into<String>,
33 version: impl Into<String>,
34 tools: ToolRegistry<S>,
35 state: Arc<RwLock<S>>,
36 ) -> Self {
37 Self {
38 name: name.into(),
39 version: version.into(),
40 state,
41 tools,
42 }
43 }
44
45 pub async fn handle_raw(&self, raw: &str) -> Option<JsonRpcResponse> {
50 let request: JsonRpcRequest = match serde_json::from_str(raw) {
51 Ok(req) => req,
52 Err(e) => {
53 return Some(JsonRpcResponse::error(
54 None,
55 crate::error::PARSE_ERROR,
56 format!("Parse error: {e}"),
57 ));
58 }
59 };
60 self.handle_request(request).await
61 }
62
63 pub async fn handle_request(&self, request: JsonRpcRequest) -> Option<JsonRpcResponse> {
67 if request.jsonrpc != "2.0" {
69 return Some(JsonRpcResponse::error(
70 request.id,
71 INVALID_REQUEST,
72 format!("Unsupported JSON-RPC version: {}", request.jsonrpc),
73 ));
74 }
75
76 if request.id.is_none() {
78 debug!(method = %request.method, "Received notification, no response");
79 return None;
80 }
81
82 let response = match request.method.as_str() {
83 "initialize" => self.handle_initialize(request.id, request.params),
84 "tools/list" => self.handle_tools_list(request.id),
85 "tools/call" => self.handle_tools_call(request.id, request.params).await,
86 "ping" => JsonRpcResponse::success(request.id, Value::Object(serde_json::Map::new())),
87 method => {
88 debug!(method, "Unknown MCP method");
89 JsonRpcResponse::error(
90 request.id,
91 METHOD_NOT_FOUND,
92 format!("Method not found: {method}"),
93 )
94 }
95 };
96
97 Some(response)
98 }
99
100 fn handle_initialize(&self, id: Option<Value>, params: Option<Value>) -> JsonRpcResponse {
102 if let Some(params) = params {
103 if let Ok(init) = serde_json::from_value::<InitializeParams>(params) {
104 debug!(
105 client = %init.client_info.name,
106 version = ?init.client_info.version,
107 protocol = %init.protocol_version,
108 "MCP client connected"
109 );
110 }
111 }
112
113 let result = InitializeResult {
114 protocol_version: PROTOCOL_VERSION.to_owned(),
115 capabilities: ServerCapabilities {
116 tools: Some(ToolsCapability {}),
117 },
118 server_info: ServerInfo {
119 name: self.name.clone(),
120 version: self.version.clone(),
121 },
122 };
123
124 match serde_json::to_value(result) {
125 Ok(val) => JsonRpcResponse::success(id, val),
126 Err(e) => {
127 JsonRpcResponse::error(id, INTERNAL_ERROR, format!("Serialization error: {e}"))
128 }
129 }
130 }
131
132 fn handle_tools_list(&self, id: Option<Value>) -> JsonRpcResponse {
134 let result = ToolsListResult {
135 tools: self.tools.list_definitions(),
136 };
137
138 match serde_json::to_value(result) {
139 Ok(val) => JsonRpcResponse::success(id, val),
140 Err(e) => {
141 JsonRpcResponse::error(id, INTERNAL_ERROR, format!("Serialization error: {e}"))
142 }
143 }
144 }
145
146 async fn handle_tools_call(&self, id: Option<Value>, params: Option<Value>) -> JsonRpcResponse {
148 let call_params: CallToolParams = match params {
149 Some(p) => match serde_json::from_value(p) {
150 Ok(cp) => cp,
151 Err(e) => {
152 return JsonRpcResponse::error(
153 id,
154 INVALID_PARAMS,
155 format!("Invalid params: {e}"),
156 );
157 }
158 },
159 None => {
160 return JsonRpcResponse::error(
161 id,
162 INVALID_PARAMS,
163 "Missing params for tools/call".to_owned(),
164 );
165 }
166 };
167
168 let arguments = call_params
169 .arguments
170 .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
171
172 let result = self
173 .tools
174 .execute(&call_params.name, &self.state, arguments)
175 .await;
176
177 match serde_json::to_value(result) {
178 Ok(val) => JsonRpcResponse::success(id, val),
179 Err(e) => JsonRpcResponse::error(
180 id,
181 INTERNAL_ERROR,
182 format!("Result serialization error: {e}"),
183 ),
184 }
185 }
186}
187
188#[cfg(test)]
189mod tests {
190 use super::*;
191 use crate::mcp::protocol::CallToolResult;
192 use crate::mcp::tool::McpTool;
193 use serde_json::json;
194
195 struct TestState;
196
197 struct PingTool;
198
199 #[async_trait::async_trait]
200 impl McpTool<TestState> for PingTool {
201 fn definition(&self) -> crate::mcp::protocol::ToolDefinition {
202 crate::mcp::protocol::ToolDefinition {
203 name: "ping_tool".to_owned(),
204 description: "Returns pong".to_owned(),
205 input_schema: json!({"type": "object"}),
206 }
207 }
208
209 async fn execute(
210 &self,
211 _state: &Arc<RwLock<TestState>>,
212 _arguments: Value,
213 ) -> CallToolResult {
214 CallToolResult::text("pong".to_owned())
215 }
216 }
217
218 fn make_server() -> McpServer<TestState> {
219 let mut registry = ToolRegistry::new();
220 registry.register(Box::new(PingTool));
221 let state = Arc::new(RwLock::new(TestState));
222 McpServer::new("test-server", "0.1.0", registry, state)
223 }
224
225 #[tokio::test]
226 async fn handle_initialize() {
227 let server = make_server();
228 let raw = r#"{
229 "jsonrpc": "2.0",
230 "id": 1,
231 "method": "initialize",
232 "params": {
233 "protocolVersion": "2024-11-05",
234 "capabilities": {},
235 "clientInfo": { "name": "test-client" }
236 }
237 }"#;
238 let resp = server.handle_raw(raw).await.expect("response");
239 let result = resp.result.expect("result");
240 assert_eq!(result["protocolVersion"], "2024-11-05");
241 assert_eq!(result["serverInfo"]["name"], "test-server");
242 assert_eq!(result["serverInfo"]["version"], "0.1.0");
243 }
244
245 #[tokio::test]
246 async fn handle_initialize_without_params() {
247 let server = make_server();
248 let raw = r#"{"jsonrpc": "2.0", "id": 1, "method": "initialize"}"#;
249 let resp = server.handle_raw(raw).await.expect("response");
250 assert!(resp.result.is_some());
251 assert!(resp.error.is_none());
252 }
253
254 #[tokio::test]
255 async fn handle_tools_list() {
256 let server = make_server();
257 let raw = r#"{"jsonrpc": "2.0", "id": 2, "method": "tools/list"}"#;
258 let resp = server.handle_raw(raw).await.expect("response");
259 let result = resp.result.expect("result");
260 let tools = result["tools"].as_array().expect("tools array");
261 assert_eq!(tools.len(), 1);
262 assert_eq!(tools[0]["name"], "ping_tool");
263 }
264
265 #[tokio::test]
266 async fn handle_tools_call() {
267 let server = make_server();
268 let raw = r#"{
269 "jsonrpc": "2.0",
270 "id": 3,
271 "method": "tools/call",
272 "params": { "name": "ping_tool", "arguments": {} }
273 }"#;
274 let resp = server.handle_raw(raw).await.expect("response");
275 let result = resp.result.expect("result");
276 assert_eq!(result["content"][0]["text"], "pong");
277 }
278
279 #[tokio::test]
280 async fn handle_tools_call_unknown_tool() {
281 let server = make_server();
282 let raw = r#"{
283 "jsonrpc": "2.0",
284 "id": 4,
285 "method": "tools/call",
286 "params": { "name": "nonexistent" }
287 }"#;
288 let resp = server.handle_raw(raw).await.expect("response");
289 let result = resp.result.expect("result");
290 assert_eq!(result["isError"], true);
291 assert!(result["content"][0]["text"]
292 .as_str()
293 .expect("text")
294 .contains("Unknown tool"));
295 }
296
297 #[tokio::test]
298 async fn handle_tools_call_missing_params() {
299 let server = make_server();
300 let raw = r#"{"jsonrpc": "2.0", "id": 5, "method": "tools/call"}"#;
301 let resp = server.handle_raw(raw).await.expect("response");
302 let err = resp.error.expect("error");
303 assert_eq!(err.code, INVALID_PARAMS);
304 }
305
306 #[tokio::test]
307 async fn handle_ping() {
308 let server = make_server();
309 let raw = r#"{"jsonrpc": "2.0", "id": 6, "method": "ping"}"#;
310 let resp = server.handle_raw(raw).await.expect("response");
311 assert!(resp.result.is_some());
312 assert!(resp.error.is_none());
313 }
314
315 #[tokio::test]
316 async fn handle_unknown_method() {
317 let server = make_server();
318 let raw = r#"{"jsonrpc": "2.0", "id": 7, "method": "bogus/method"}"#;
319 let resp = server.handle_raw(raw).await.expect("response");
320 let err = resp.error.expect("error");
321 assert_eq!(err.code, METHOD_NOT_FOUND);
322 assert!(err.message.contains("bogus/method"));
323 }
324
325 #[tokio::test]
326 async fn handle_invalid_json() {
327 let server = make_server();
328 let resp = server
329 .handle_raw("not json at all")
330 .await
331 .expect("response");
332 let err = resp.error.expect("error");
333 assert_eq!(err.code, crate::error::PARSE_ERROR);
334 }
335
336 #[tokio::test]
337 async fn handle_wrong_jsonrpc_version() {
338 let server = make_server();
339 let raw = r#"{"jsonrpc": "1.0", "id": 8, "method": "ping"}"#;
340 let resp = server.handle_raw(raw).await.expect("response");
341 let err = resp.error.expect("error");
342 assert_eq!(err.code, INVALID_REQUEST);
343 }
344
345 #[tokio::test]
346 async fn notification_returns_none() {
347 let server = make_server();
348 let raw = r#"{"jsonrpc": "2.0", "method": "notifications/cancelled"}"#;
349 let resp = server.handle_raw(raw).await;
350 assert!(resp.is_none());
351 }
352
353 #[tokio::test]
354 async fn response_id_matches_request_id() {
355 let server = make_server();
356 let raw = r#"{"jsonrpc": "2.0", "id": 999, "method": "ping"}"#;
357 let resp = server.handle_raw(raw).await.expect("response");
358 assert_eq!(resp.id, Some(Value::from(999)));
359 }
360
361 #[tokio::test]
362 async fn tools_call_with_no_arguments_defaults_to_empty_object() {
363 let server = make_server();
364 let raw = r#"{
365 "jsonrpc": "2.0",
366 "id": 10,
367 "method": "tools/call",
368 "params": { "name": "ping_tool" }
369 }"#;
370 let resp = server.handle_raw(raw).await.expect("response");
371 let result = resp.result.expect("result");
372 assert_eq!(result["content"][0]["text"], "pong");
373 }
374
375 #[tokio::test]
376 async fn tools_call_with_invalid_params_structure() {
377 let server = make_server();
378 let raw = r#"{
379 "jsonrpc": "2.0",
380 "id": 11,
381 "method": "tools/call",
382 "params": "not an object"
383 }"#;
384 let resp = server.handle_raw(raw).await.expect("response");
385 let err = resp.error.expect("error");
386 assert_eq!(err.code, INVALID_PARAMS);
387 }
388}