Skip to main content

dravr_tronc/mcp/
server.rs

// ABOUTME: Generic MCP server that routes JSON-RPC requests to protocol handlers and tools
// ABOUTME: Implements initialize, tools/list, tools/call, and ping — parameterized over state S

4use std::sync::Arc;
5
6use serde_json::Value;
7use tokio::sync::RwLock;
8use tracing::debug;
9
10use crate::error::{INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, METHOD_NOT_FOUND};
11use crate::mcp::protocol::{
12    CallToolParams, InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse,
13    ServerCapabilities, ServerInfo, ToolsCapability, ToolsListResult, PROTOCOL_VERSION,
14};
15use crate::mcp::tool::ToolRegistry;
16
17/// MCP server that dispatches JSON-RPC requests to the appropriate handler
18///
19/// Generic over `S` — the project-specific server state type.
20/// Owns the shared state and tool registry. Transport layers feed parsed
21/// requests into `handle_request` and send the returned responses.
22pub struct McpServer<S: Send + Sync> {
23    name: String,
24    version: String,
25    state: Arc<RwLock<S>>,
26    tools: ToolRegistry<S>,
27}
28
29impl<S: Send + Sync + 'static> McpServer<S> {
30    /// Create a server with the given name, version, tool registry, and shared state
31    pub fn new(
32        name: impl Into<String>,
33        version: impl Into<String>,
34        tools: ToolRegistry<S>,
35        state: Arc<RwLock<S>>,
36    ) -> Self {
37        Self {
38            name: name.into(),
39            version: version.into(),
40            state,
41            tools,
42        }
43    }
44
45    /// Route a raw JSON string to the appropriate MCP handler
46    ///
47    /// Parses the string as a `JsonRpcRequest`, dispatches it, and returns
48    /// the serialized response. Returns `None` for notifications.
49    pub async fn handle_raw(&self, raw: &str) -> Option<JsonRpcResponse> {
50        let request: JsonRpcRequest = match serde_json::from_str(raw) {
51            Ok(req) => req,
52            Err(e) => {
53                return Some(JsonRpcResponse::error(
54                    None,
55                    crate::error::PARSE_ERROR,
56                    format!("Parse error: {e}"),
57                ));
58            }
59        };
60        self.handle_request(request).await
61    }
62
63    /// Route a parsed JSON-RPC request to the appropriate MCP handler
64    ///
65    /// Returns `None` for notifications (requests without an id).
66    pub async fn handle_request(&self, request: JsonRpcRequest) -> Option<JsonRpcResponse> {
67        // Validate JSON-RPC protocol version
68        if request.jsonrpc != "2.0" {
69            return Some(JsonRpcResponse::error(
70                request.id,
71                INVALID_REQUEST,
72                format!("Unsupported JSON-RPC version: {}", request.jsonrpc),
73            ));
74        }
75
76        // Notifications have no id and expect no response
77        if request.id.is_none() {
78            debug!(method = %request.method, "Received notification, no response");
79            return None;
80        }
81
82        let response = match request.method.as_str() {
83            "initialize" => self.handle_initialize(request.id, request.params),
84            "tools/list" => self.handle_tools_list(request.id),
85            "tools/call" => self.handle_tools_call(request.id, request.params).await,
86            "ping" => JsonRpcResponse::success(request.id, Value::Object(serde_json::Map::new())),
87            method => {
88                debug!(method, "Unknown MCP method");
89                JsonRpcResponse::error(
90                    request.id,
91                    METHOD_NOT_FOUND,
92                    format!("Method not found: {method}"),
93                )
94            }
95        };
96
97        Some(response)
98    }
99
100    /// Handle `initialize` — parse client info and return server capabilities
101    fn handle_initialize(&self, id: Option<Value>, params: Option<Value>) -> JsonRpcResponse {
102        if let Some(params) = params {
103            if let Ok(init) = serde_json::from_value::<InitializeParams>(params) {
104                debug!(
105                    client = %init.client_info.name,
106                    version = ?init.client_info.version,
107                    protocol = %init.protocol_version,
108                    "MCP client connected"
109                );
110            }
111        }
112
113        let result = InitializeResult {
114            protocol_version: PROTOCOL_VERSION.to_owned(),
115            capabilities: ServerCapabilities {
116                tools: Some(ToolsCapability {}),
117            },
118            server_info: ServerInfo {
119                name: self.name.clone(),
120                version: self.version.clone(),
121            },
122        };
123
124        match serde_json::to_value(result) {
125            Ok(val) => JsonRpcResponse::success(id, val),
126            Err(e) => {
127                JsonRpcResponse::error(id, INTERNAL_ERROR, format!("Serialization error: {e}"))
128            }
129        }
130    }
131
132    /// Handle `tools/list` — return all registered tool definitions
133    fn handle_tools_list(&self, id: Option<Value>) -> JsonRpcResponse {
134        let result = ToolsListResult {
135            tools: self.tools.list_definitions(),
136        };
137
138        match serde_json::to_value(result) {
139            Ok(val) => JsonRpcResponse::success(id, val),
140            Err(e) => {
141                JsonRpcResponse::error(id, INTERNAL_ERROR, format!("Serialization error: {e}"))
142            }
143        }
144    }
145
146    /// Handle `tools/call` — dispatch to the named tool handler
147    async fn handle_tools_call(&self, id: Option<Value>, params: Option<Value>) -> JsonRpcResponse {
148        let call_params: CallToolParams = match params {
149            Some(p) => match serde_json::from_value(p) {
150                Ok(cp) => cp,
151                Err(e) => {
152                    return JsonRpcResponse::error(
153                        id,
154                        INVALID_PARAMS,
155                        format!("Invalid params: {e}"),
156                    );
157                }
158            },
159            None => {
160                return JsonRpcResponse::error(
161                    id,
162                    INVALID_PARAMS,
163                    "Missing params for tools/call".to_owned(),
164                );
165            }
166        };
167
168        let arguments = call_params
169            .arguments
170            .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
171
172        let result = self
173            .tools
174            .execute(&call_params.name, &self.state, arguments)
175            .await;
176
177        match serde_json::to_value(result) {
178            Ok(val) => JsonRpcResponse::success(id, val),
179            Err(e) => JsonRpcResponse::error(
180                id,
181                INTERNAL_ERROR,
182                format!("Result serialization error: {e}"),
183            ),
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::CallToolResult;
    use crate::mcp::tool::McpTool;
    use serde_json::json;

    /// Placeholder state; these tests exercise request routing only.
    struct TestState;

    /// Tool that ignores its arguments and always replies "pong".
    struct PingTool;

    #[async_trait::async_trait]
    impl McpTool<TestState> for PingTool {
        fn definition(&self) -> crate::mcp::protocol::ToolDefinition {
            crate::mcp::protocol::ToolDefinition {
                name: "ping_tool".to_owned(),
                description: "Returns pong".to_owned(),
                input_schema: json!({"type": "object"}),
            }
        }

        async fn execute(
            &self,
            _state: &Arc<RwLock<TestState>>,
            _arguments: Value,
        ) -> CallToolResult {
            CallToolResult::text("pong".to_owned())
        }
    }

    /// Build a server named "test-server" (v0.1.0) with `PingTool` registered.
    fn make_server() -> McpServer<TestState> {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(PingTool));
        let state = Arc::new(RwLock::new(TestState));
        McpServer::new("test-server", "0.1.0", registry, state)
    }

    #[tokio::test]
    async fn handle_initialize() {
        let server = make_server();
        let raw = r#"{
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": { "name": "test-client" }
            }
        }"#;
        let resp = server.handle_raw(raw).await.expect("response");
        let result = resp.result.expect("result");
        // Compare against the constant rather than a literal so the test keeps
        // passing when the supported protocol version is bumped.
        assert_eq!(result["protocolVersion"], PROTOCOL_VERSION);
        assert_eq!(result["serverInfo"]["name"], "test-server");
        assert_eq!(result["serverInfo"]["version"], "0.1.0");
    }

    #[tokio::test]
    async fn handle_initialize_without_params() {
        let server = make_server();
        let raw = r#"{"jsonrpc": "2.0", "id": 1, "method": "initialize"}"#;
        let resp = server.handle_raw(raw).await.expect("response");
        assert!(resp.result.is_some());
        assert!(resp.error.is_none());
    }

    #[tokio::test]
    async fn initialize_tolerates_malformed_params() {
        let server = make_server();
        let raw = r#"{"jsonrpc": "2.0", "id": 12, "method": "initialize", "params": "garbage"}"#;
        let resp = server.handle_raw(raw).await.expect("response");
        assert!(resp.error.is_none());
        let result = resp.result.expect("result");
        assert_eq!(result["serverInfo"]["name"], "test-server");
    }

    #[tokio::test]
    async fn handle_tools_list() {
        let server = make_server();
        let raw = r#"{"jsonrpc": "2.0", "id": 2, "method": "tools/list"}"#;
        let resp = server.handle_raw(raw).await.expect("response");
        let result = resp.result.expect("result");
        let tools = result["tools"].as_array().expect("tools array");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "ping_tool");
    }

    #[tokio::test]
    async fn handle_tools_call() {
        let server = make_server();
        let raw = r#"{
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": { "name": "ping_tool", "arguments": {} }
        }"#;
        let resp = server.handle_raw(raw).await.expect("response");
        let result = resp.result.expect("result");
        assert_eq!(result["content"][0]["text"], "pong");
    }

    #[tokio::test]
    async fn handle_tools_call_unknown_tool() {
        let server = make_server();
        let raw = r#"{
            "jsonrpc": "2.0",
            "id": 4,
            "method": "tools/call",
            "params": { "name": "nonexistent" }
        }"#;
        let resp = server.handle_raw(raw).await.expect("response");
        let result = resp.result.expect("result");
        assert_eq!(result["isError"], true);
        assert!(result["content"][0]["text"]
            .as_str()
            .expect("text")
            .contains("Unknown tool"));
    }

    #[tokio::test]
    async fn handle_tools_call_missing_params() {
        let server = make_server();
        let raw = r#"{"jsonrpc": "2.0", "id": 5, "method": "tools/call"}"#;
        let resp = server.handle_raw(raw).await.expect("response");
        let err = resp.error.expect("error");
        assert_eq!(err.code, INVALID_PARAMS);
    }

    #[tokio::test]
    async fn handle_ping() {
        let server = make_server();
        let raw = r#"{"jsonrpc": "2.0", "id": 6, "method": "ping"}"#;
        let resp = server.handle_raw(raw).await.expect("response");
        // MCP ping must answer with an empty result object.
        assert_eq!(resp.result, Some(json!({})));
        assert!(resp.error.is_none());
    }

    #[tokio::test]
    async fn handle_unknown_method() {
        let server = make_server();
        let raw = r#"{"jsonrpc": "2.0", "id": 7, "method": "bogus/method"}"#;
        let resp = server.handle_raw(raw).await.expect("response");
        let err = resp.error.expect("error");
        assert_eq!(err.code, METHOD_NOT_FOUND);
        assert!(err.message.contains("bogus/method"));
    }

    #[tokio::test]
    async fn handle_invalid_json() {
        let server = make_server();
        let resp = server
            .handle_raw("not json at all")
            .await
            .expect("response");
        let err = resp.error.expect("error");
        assert_eq!(err.code, crate::error::PARSE_ERROR);
    }

    #[tokio::test]
    async fn handle_wrong_jsonrpc_version() {
        let server = make_server();
        let raw = r#"{"jsonrpc": "1.0", "id": 8, "method": "ping"}"#;
        let resp = server.handle_raw(raw).await.expect("response");
        // The id is known here, so the error must still be correlatable.
        assert_eq!(resp.id, Some(Value::from(8)));
        let err = resp.error.expect("error");
        assert_eq!(err.code, INVALID_REQUEST);
    }

    #[tokio::test]
    async fn notification_returns_none() {
        let server = make_server();
        let raw = r#"{"jsonrpc": "2.0", "method": "notifications/cancelled"}"#;
        let resp = server.handle_raw(raw).await;
        assert!(resp.is_none());
    }

    #[tokio::test]
    async fn notification_for_request_method_returns_none() {
        // Even a method that normally answers must stay silent without an id.
        let server = make_server();
        let raw = r#"{"jsonrpc": "2.0", "method": "ping"}"#;
        let resp = server.handle_raw(raw).await;
        assert!(resp.is_none());
    }

    #[tokio::test]
    async fn response_id_matches_request_id() {
        let server = make_server();
        let raw = r#"{"jsonrpc": "2.0", "id": 999, "method": "ping"}"#;
        let resp = server.handle_raw(raw).await.expect("response");
        assert_eq!(resp.id, Some(Value::from(999)));
    }

    #[tokio::test]
    async fn tools_call_with_no_arguments_defaults_to_empty_object() {
        let server = make_server();
        let raw = r#"{
            "jsonrpc": "2.0",
            "id": 10,
            "method": "tools/call",
            "params": { "name": "ping_tool" }
        }"#;
        let resp = server.handle_raw(raw).await.expect("response");
        let result = resp.result.expect("result");
        assert_eq!(result["content"][0]["text"], "pong");
    }

    #[tokio::test]
    async fn tools_call_with_invalid_params_structure() {
        let server = make_server();
        let raw = r#"{
            "jsonrpc": "2.0",
            "id": 11,
            "method": "tools/call",
            "params": "not an object"
        }"#;
        let resp = server.handle_raw(raw).await.expect("response");
        let err = resp.error.expect("error");
        assert_eq!(err.code, INVALID_PARAMS);
    }
}