Skip to main content

tower_mcp/
jsonrpc.rs

1//! JSON-RPC 2.0 service layer
2//!
3//! Provides a Tower service that handles JSON-RPC framing for MCP requests.
4//! This layer wraps an MCP router and handles:
5//! - Single request processing
6//! - Batch request processing (concurrent execution)
7//! - JSON-RPC version validation
8//! - Error conversion to JSON-RPC error responses
9
10use std::future::Future;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use tower_service::Service;
15
16use crate::error::{Error, JsonRpcError, Result};
17use crate::protocol::{
18    JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseMessage, McpRequest,
19};
20use crate::router::{RouterRequest, RouterResponse};
21
/// Service that handles JSON-RPC framing
///
/// Wraps an MCP service and handles JSON-RPC request/response conversion.
/// Supports both single requests and batch requests.
///
/// The wrapped service is cloned for every request that is dispatched, so it
/// should be cheap to clone.
///
/// # Example
///
/// ```rust
/// use tower_mcp::{McpRouter, JsonRpcService};
///
/// let router = McpRouter::new().server_info("my-server", "1.0.0");
/// let service = JsonRpcService::new(router);
/// ```
#[derive(Debug)]
pub struct JsonRpcService<S> {
    /// The wrapped MCP service that receives decoded router requests.
    inner: S,
}
38
39impl<S> JsonRpcService<S> {
40    /// Create a new JSON-RPC service wrapping the given inner service
41    pub fn new(inner: S) -> Self {
42        Self { inner }
43    }
44
45    /// Process a single JSON-RPC request
46    pub async fn call_single(&mut self, req: JsonRpcRequest) -> Result<JsonRpcResponse>
47    where
48        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
49            + Clone
50            + Send
51            + 'static,
52        S::Future: Send,
53    {
54        process_single_request(self.inner.clone(), req).await
55    }
56
57    /// Process a batch of JSON-RPC requests concurrently
58    pub async fn call_batch(
59        &mut self,
60        requests: Vec<JsonRpcRequest>,
61    ) -> Result<Vec<JsonRpcResponse>>
62    where
63        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
64            + Clone
65            + Send
66            + 'static,
67        S::Future: Send,
68    {
69        if requests.is_empty() {
70            return Err(Error::JsonRpc(JsonRpcError::invalid_request(
71                "Empty batch request",
72            )));
73        }
74
75        // Process all requests concurrently
76        let futures: Vec<_> = requests
77            .into_iter()
78            .map(|req| {
79                let inner = self.inner.clone();
80                let req_id = req.id.clone();
81                async move {
82                    match process_single_request(inner, req).await {
83                        Ok(resp) => resp,
84                        Err(e) => {
85                            // Convert errors to error responses instead of dropping
86                            JsonRpcResponse::error(
87                                Some(req_id),
88                                JsonRpcError::internal_error(e.to_string()),
89                            )
90                        }
91                    }
92                }
93            })
94            .collect();
95
96        let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
97
98        // Results will never be empty since we converted all errors to responses
99        Ok(results)
100    }
101
102    /// Process a JSON-RPC message (single or batch)
103    pub async fn call_message(&mut self, msg: JsonRpcMessage) -> Result<JsonRpcResponseMessage>
104    where
105        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
106            + Clone
107            + Send
108            + 'static,
109        S::Future: Send,
110    {
111        match msg {
112            JsonRpcMessage::Single(req) => {
113                let response = self.call_single(req).await?;
114                Ok(JsonRpcResponseMessage::Single(response))
115            }
116            JsonRpcMessage::Batch(requests) => {
117                let responses = self.call_batch(requests).await?;
118                Ok(JsonRpcResponseMessage::Batch(responses))
119            }
120        }
121    }
122}
123
124impl<S> Clone for JsonRpcService<S>
125where
126    S: Clone,
127{
128    fn clone(&self) -> Self {
129        Self {
130            inner: self.inner.clone(),
131        }
132    }
133}
134
135impl<S> Service<JsonRpcRequest> for JsonRpcService<S>
136where
137    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
138        + Clone
139        + Send
140        + 'static,
141    S::Future: Send,
142{
143    type Response = JsonRpcResponse;
144    type Error = Error;
145    type Future =
146        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
147
148    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
149        self.inner.poll_ready(cx).map_err(|_| unreachable!())
150    }
151
152    fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
153        let mut inner = self.inner.clone();
154        Box::pin(async move {
155            // Parse the MCP request from JSON-RPC
156            let mcp_request = McpRequest::from_jsonrpc(&req)?;
157
158            // Create router request
159            let router_req = RouterRequest {
160                id: req.id,
161                inner: mcp_request,
162            };
163
164            // Call the inner service
165            let response = inner.call(router_req).await.unwrap(); // Infallible
166
167            // Convert to JSON-RPC response
168            Ok(response.into_jsonrpc())
169        })
170    }
171}
172
173/// Service implementation for JSON-RPC batch requests
174impl<S> Service<JsonRpcMessage> for JsonRpcService<S>
175where
176    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
177        + Clone
178        + Send
179        + 'static,
180    S::Future: Send,
181{
182    type Response = JsonRpcResponseMessage;
183    type Error = Error;
184    type Future =
185        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
186
187    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
188        self.inner.poll_ready(cx).map_err(|_| unreachable!())
189    }
190
191    fn call(&mut self, msg: JsonRpcMessage) -> Self::Future {
192        let inner = self.inner.clone();
193        Box::pin(async move {
194            match msg {
195                JsonRpcMessage::Single(req) => {
196                    let response = process_single_request(inner, req).await?;
197                    Ok(JsonRpcResponseMessage::Single(response))
198                }
199                JsonRpcMessage::Batch(requests) => {
200                    if requests.is_empty() {
201                        // Empty batch is an invalid request per JSON-RPC spec
202                        return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
203                            None,
204                            JsonRpcError::invalid_request("Empty batch request"),
205                        )));
206                    }
207
208                    // Process all requests concurrently
209                    let futures: Vec<_> = requests
210                        .into_iter()
211                        .map(|req| {
212                            let inner = inner.clone();
213                            let req_id = req.id.clone();
214                            async move {
215                                match process_single_request(inner, req).await {
216                                    Ok(resp) => resp,
217                                    Err(e) => {
218                                        // Convert errors to error responses instead of dropping
219                                        JsonRpcResponse::error(
220                                            Some(req_id),
221                                            JsonRpcError::internal_error(e.to_string()),
222                                        )
223                                    }
224                                }
225                            }
226                        })
227                        .collect();
228
229                    let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
230
231                    // Empty results only possible if input was empty (already handled above)
232                    if results.is_empty() {
233                        return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
234                            None,
235                            JsonRpcError::internal_error("All batch requests failed"),
236                        )));
237                    }
238
239                    Ok(JsonRpcResponseMessage::Batch(results))
240                }
241            }
242        })
243    }
244}
245
246/// Helper function to process a single JSON-RPC request
247async fn process_single_request<S>(
248    mut inner: S,
249    req: JsonRpcRequest,
250) -> std::result::Result<JsonRpcResponse, Error>
251where
252    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
253        + Send
254        + 'static,
255    S::Future: Send,
256{
257    // Validate JSON-RPC version
258    if let Err(e) = req.validate() {
259        return Ok(JsonRpcResponse::error(Some(req.id), e));
260    }
261
262    // Parse the MCP request from JSON-RPC
263    let mcp_request = match McpRequest::from_jsonrpc(&req) {
264        Ok(r) => r,
265        Err(e) => {
266            return Ok(JsonRpcResponse::error(
267                Some(req.id),
268                JsonRpcError::invalid_params(e.to_string()),
269            ));
270        }
271    };
272
273    // Create router request
274    let router_req = RouterRequest {
275        id: req.id,
276        inner: mcp_request,
277    };
278
279    // Call the inner service
280    let response = inner.call(router_req).await.unwrap(); // Infallible
281
282    // Convert to JSON-RPC response
283    Ok(response.into_jsonrpc())
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::McpRouter;
    use crate::tool::ToolBuilder;
    use schemars::JsonSchema;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i32,
        b: i32,
    }

    fn create_test_router() -> McpRouter {
        let add_tool = ToolBuilder::new("add")
            .description("Add two numbers")
            .handler(|input: AddInput| async move {
                Ok(crate::CallToolResult::text(format!(
                    "{}",
                    input.a + input.b
                )))
            })
            .build()
            .unwrap();

        McpRouter::new()
            .server_info("test-server", "1.0.0")
            .tool(add_tool)
    }

    /// Send `initialize` through `service`, check it succeeded, then deliver the
    /// `Initialized` notification to `router` (the handshake the tool tests
    /// below perform before issuing tool requests).
    async fn initialize(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
        let init_req = JsonRpcRequest::new(1, "initialize").with_params(serde_json::json!({
            "protocolVersion": "2025-03-26",
            "capabilities": {},
            "clientInfo": { "name": "test", "version": "1.0" }
        }));
        let resp = service.call_single(init_req).await.unwrap();
        assert!(matches!(resp, JsonRpcResponse::Result(_)));
        router.handle_notification(crate::protocol::McpNotification::Initialized);
    }

    #[tokio::test]
    async fn test_jsonrpc_service() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());
        initialize(&mut service, &router).await;

        let req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let resp = service.call_single(req).await.unwrap();

        match resp {
            JsonRpcResponse::Result(r) => {
                let tools = r.result.get("tools").unwrap().as_array().unwrap();
                assert_eq!(tools.len(), 1);
            }
            JsonRpcResponse::Error(e) => panic!("Expected result, got error: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_batch_request() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());
        initialize(&mut service, &router).await;

        let requests = vec![
            JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({})),
            JsonRpcRequest::new(3, "tools/call").with_params(serde_json::json!({
                "name": "add",
                "arguments": { "a": 1, "b": 2 }
            })),
        ];

        let responses = service.call_batch(requests).await.unwrap();
        assert_eq!(responses.len(), 2);
    }

    #[tokio::test]
    async fn test_empty_batch_error() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router);

        let result = service.call_batch(vec![]).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_call_message_batch() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());
        initialize(&mut service, &router).await;

        let msg = JsonRpcMessage::Batch(vec![
            JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({})),
        ]);
        let resp = service.call_message(msg).await.unwrap();
        let JsonRpcResponseMessage::Batch(responses) = resp else {
            panic!("Expected a batch response");
        };
        assert_eq!(responses.len(), 1);
    }

    #[tokio::test]
    async fn test_service_empty_batch_returns_error_response() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router);

        // Through the tower `Service` entry point an empty batch is answered
        // in-band rather than returned as `Err`.
        let resp =
            Service::<JsonRpcMessage>::call(&mut service, JsonRpcMessage::Batch(Vec::new()))
                .await
                .unwrap();
        assert!(matches!(
            resp,
            JsonRpcResponseMessage::Single(JsonRpcResponse::Error(_))
        ));
    }
}