Skip to main content

tower_mcp/
jsonrpc.rs

1//! JSON-RPC 2.0 service layer
2//!
3//! Provides a Tower [`Layer`] and [`Service`] for JSON-RPC framing of MCP requests.
4//!
5//! - [`JsonRpcLayer`] - Tower layer for [`ServiceBuilder`](tower::ServiceBuilder) composition
6//! - [`JsonRpcService`] - Tower service wrapping an MCP router
7//!
8//! The service handles:
9//! - Single request processing
10//! - Batch request processing (concurrent execution)
11//! - JSON-RPC version validation
12//! - Error conversion to JSON-RPC error responses
13
14use std::future::Future;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use tower::Layer;
19use tower_service::Service;
20
21use crate::error::{Error, JsonRpcError, Result};
22use crate::protocol::{
23    JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseMessage, McpRequest,
24};
25use crate::router::{Extensions, RouterRequest, RouterResponse};
26
27/// Tower layer that adds JSON-RPC 2.0 framing to an MCP service.
28///
29/// This is the standard way to compose `JsonRpcService` with other tower
30/// middleware via [`ServiceBuilder`](tower::ServiceBuilder).
31///
32/// # Example
33///
34/// ```rust
35/// use tower::ServiceBuilder;
36/// use tower_mcp::{McpRouter, JsonRpcLayer, JsonRpcService};
37///
38/// let router = McpRouter::new().server_info("my-server", "1.0.0");
39///
40/// // Compose with ServiceBuilder
41/// let service = ServiceBuilder::new()
42///     .layer(JsonRpcLayer::new())
43///     .service(router);
44/// ```
45#[derive(Debug, Clone, Copy, Default)]
46pub struct JsonRpcLayer {
47    _priv: (),
48}
49
50impl JsonRpcLayer {
51    /// Create a new `JsonRpcLayer`.
52    pub fn new() -> Self {
53        Self { _priv: () }
54    }
55}
56
57impl<S> Layer<S> for JsonRpcLayer {
58    type Service = JsonRpcService<S>;
59
60    fn layer(&self, inner: S) -> Self::Service {
61        JsonRpcService::new(inner)
62    }
63}
64
65/// Service that handles JSON-RPC framing.
66///
67/// Wraps an MCP service and handles JSON-RPC request/response conversion.
68/// Supports both single requests and batch requests.
69///
70/// Can be created directly via [`JsonRpcService::new`] or through the
71/// [`JsonRpcLayer`] for [`ServiceBuilder`](tower::ServiceBuilder) composition.
72///
73/// # Example
74///
75/// ```rust
76/// use tower_mcp::{McpRouter, JsonRpcService};
77///
78/// let router = McpRouter::new().server_info("my-server", "1.0.0");
79/// let service = JsonRpcService::new(router);
80/// ```
81pub struct JsonRpcService<S> {
82    inner: S,
83    extensions: Extensions,
84}
85
86impl<S> JsonRpcService<S> {
87    /// Create a new JSON-RPC service wrapping the given inner service
88    pub fn new(inner: S) -> Self {
89        Self {
90            inner,
91            extensions: Extensions::new(),
92        }
93    }
94
95    /// Set extensions to inject into every `RouterRequest` created by this service.
96    ///
97    /// This is used by transports to bridge data (e.g., `TokenClaims`) from the
98    /// HTTP/WebSocket layer into the MCP request pipeline.
99    pub fn with_extensions(mut self, ext: Extensions) -> Self {
100        self.extensions = ext;
101        self
102    }
103
104    /// Process a single JSON-RPC request
105    pub async fn call_single(&mut self, req: JsonRpcRequest) -> Result<JsonRpcResponse>
106    where
107        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
108            + Clone
109            + Send
110            + 'static,
111        S::Future: Send,
112    {
113        process_single_request(self.inner.clone(), req, self.extensions.clone()).await
114    }
115
116    /// Process a batch of JSON-RPC requests concurrently
117    pub async fn call_batch(
118        &mut self,
119        requests: Vec<JsonRpcRequest>,
120    ) -> Result<Vec<JsonRpcResponse>>
121    where
122        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
123            + Clone
124            + Send
125            + 'static,
126        S::Future: Send,
127    {
128        if requests.is_empty() {
129            return Err(Error::JsonRpc(JsonRpcError::invalid_request(
130                "Empty batch request",
131            )));
132        }
133
134        // Process all requests concurrently
135        let futures: Vec<_> = requests
136            .into_iter()
137            .map(|req| {
138                let inner = self.inner.clone();
139                let extensions = self.extensions.clone();
140                let req_id = req.id.clone();
141                async move {
142                    match process_single_request(inner, req, extensions).await {
143                        Ok(resp) => resp,
144                        Err(e) => {
145                            // Convert errors to error responses instead of dropping
146                            JsonRpcResponse::error(
147                                Some(req_id),
148                                JsonRpcError::internal_error(e.to_string()),
149                            )
150                        }
151                    }
152                }
153            })
154            .collect();
155
156        let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
157
158        // Results will never be empty since we converted all errors to responses
159        Ok(results)
160    }
161
162    /// Process a JSON-RPC message (single or batch)
163    pub async fn call_message(&mut self, msg: JsonRpcMessage) -> Result<JsonRpcResponseMessage>
164    where
165        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
166            + Clone
167            + Send
168            + 'static,
169        S::Future: Send,
170    {
171        match msg {
172            JsonRpcMessage::Single(req) => {
173                let response = self.call_single(req).await?;
174                Ok(JsonRpcResponseMessage::Single(response))
175            }
176            JsonRpcMessage::Batch(requests) => {
177                let responses = self.call_batch(requests).await?;
178                Ok(JsonRpcResponseMessage::Batch(responses))
179            }
180        }
181    }
182}
183
184impl<S> Clone for JsonRpcService<S>
185where
186    S: Clone,
187{
188    fn clone(&self) -> Self {
189        Self {
190            inner: self.inner.clone(),
191            extensions: self.extensions.clone(),
192        }
193    }
194}
195
196impl<S> Service<JsonRpcRequest> for JsonRpcService<S>
197where
198    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
199        + Clone
200        + Send
201        + 'static,
202    S::Future: Send,
203{
204    type Response = JsonRpcResponse;
205    type Error = Error;
206    type Future =
207        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
208
209    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
210        self.inner.poll_ready(cx).map_err(|_| unreachable!())
211    }
212
213    fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
214        let mut inner = self.inner.clone();
215        let extensions = self.extensions.clone();
216        Box::pin(async move {
217            // Parse the MCP request from JSON-RPC
218            let mcp_request = McpRequest::from_jsonrpc(&req)?;
219
220            // Create router request
221            let router_req = RouterRequest {
222                id: req.id,
223                inner: mcp_request,
224                extensions,
225            };
226
227            // Call the inner service
228            let response = inner.call(router_req).await.unwrap(); // Infallible
229
230            // Convert to JSON-RPC response
231            Ok(response.into_jsonrpc())
232        })
233    }
234}
235
236/// Service implementation for JSON-RPC batch requests
237impl<S> Service<JsonRpcMessage> for JsonRpcService<S>
238where
239    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
240        + Clone
241        + Send
242        + 'static,
243    S::Future: Send,
244{
245    type Response = JsonRpcResponseMessage;
246    type Error = Error;
247    type Future =
248        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
249
250    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
251        self.inner.poll_ready(cx).map_err(|_| unreachable!())
252    }
253
254    fn call(&mut self, msg: JsonRpcMessage) -> Self::Future {
255        let inner = self.inner.clone();
256        let extensions = self.extensions.clone();
257        Box::pin(async move {
258            match msg {
259                JsonRpcMessage::Single(req) => {
260                    let response = process_single_request(inner, req, extensions).await?;
261                    Ok(JsonRpcResponseMessage::Single(response))
262                }
263                JsonRpcMessage::Batch(requests) => {
264                    if requests.is_empty() {
265                        // Empty batch is an invalid request per JSON-RPC spec
266                        return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
267                            None,
268                            JsonRpcError::invalid_request("Empty batch request"),
269                        )));
270                    }
271
272                    // Process all requests concurrently
273                    let futures: Vec<_> = requests
274                        .into_iter()
275                        .map(|req| {
276                            let inner = inner.clone();
277                            let extensions = extensions.clone();
278                            let req_id = req.id.clone();
279                            async move {
280                                match process_single_request(inner, req, extensions).await {
281                                    Ok(resp) => resp,
282                                    Err(e) => {
283                                        // Convert errors to error responses instead of dropping
284                                        JsonRpcResponse::error(
285                                            Some(req_id),
286                                            JsonRpcError::internal_error(e.to_string()),
287                                        )
288                                    }
289                                }
290                            }
291                        })
292                        .collect();
293
294                    let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
295
296                    // Empty results only possible if input was empty (already handled above)
297                    if results.is_empty() {
298                        return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
299                            None,
300                            JsonRpcError::internal_error("All batch requests failed"),
301                        )));
302                    }
303
304                    Ok(JsonRpcResponseMessage::Batch(results))
305                }
306            }
307        })
308    }
309}
310
311/// Helper function to process a single JSON-RPC request
312async fn process_single_request<S>(
313    mut inner: S,
314    req: JsonRpcRequest,
315    extensions: Extensions,
316) -> std::result::Result<JsonRpcResponse, Error>
317where
318    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
319        + Send
320        + 'static,
321    S::Future: Send,
322{
323    // Validate JSON-RPC version
324    if let Err(e) = req.validate() {
325        return Ok(JsonRpcResponse::error(Some(req.id), e));
326    }
327
328    // Parse the MCP request from JSON-RPC
329    let mcp_request = match McpRequest::from_jsonrpc(&req) {
330        Ok(r) => r,
331        Err(e) => {
332            return Ok(JsonRpcResponse::error(
333                Some(req.id),
334                JsonRpcError::invalid_params(e.to_string()),
335            ));
336        }
337    };
338
339    // Create router request
340    let router_req = RouterRequest {
341        id: req.id,
342        inner: mcp_request,
343        extensions,
344    };
345
346    // Call the inner service
347    let response = inner.call(router_req).await.unwrap(); // Infallible
348
349    // Convert to JSON-RPC response
350    Ok(response.into_jsonrpc())
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::McpRouter;
    use crate::tool::ToolBuilder;
    use schemars::JsonSchema;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i32,
        b: i32,
    }

    fn create_test_router() -> McpRouter {
        let add_tool = ToolBuilder::new("add")
            .description("Add two numbers")
            .handler(|input: AddInput| async move {
                Ok(crate::CallToolResult::text((input.a + input.b).to_string()))
            })
            .build();

        McpRouter::new()
            .server_info("test-server", "1.0.0")
            .tool(add_tool)
    }

    /// Builds the `initialize` request (id 1) the tests send before other calls.
    fn initialize_request() -> JsonRpcRequest {
        JsonRpcRequest::new(1, "initialize").with_params(serde_json::json!({
            "protocolVersion": "2025-11-25",
            "capabilities": {},
            "clientInfo": { "name": "test", "version": "1.0" }
        }))
    }

    #[tokio::test]
    async fn test_jsonrpc_service() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        // Initialize first
        let resp = service.call_single(initialize_request()).await.unwrap();
        assert!(matches!(resp, JsonRpcResponse::Result(_)));

        // Mark as initialized
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        // Now list tools
        let req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let resp = service.call_single(req).await.unwrap();

        match resp {
            JsonRpcResponse::Result(r) => {
                let tools = r.result.get("tools").unwrap().as_array().unwrap();
                assert_eq!(tools.len(), 1);
            }
            JsonRpcResponse::Error(e) => panic!("Expected result, got error: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_batch_request() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        // Initialize first
        service.call_single(initialize_request()).await.unwrap();
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        // Batch request
        let requests = vec![
            JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({})),
            JsonRpcRequest::new(3, "tools/call").with_params(serde_json::json!({
                "name": "add",
                "arguments": { "a": 1, "b": 2 }
            })),
        ];

        let responses = service.call_batch(requests).await.unwrap();
        assert_eq!(responses.len(), 2);
    }

    #[tokio::test]
    async fn test_empty_batch_error() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router);

        let result = service.call_batch(vec![]).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_message_service_single() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router);

        let msg = JsonRpcMessage::Single(initialize_request());
        let resp = Service::<JsonRpcMessage>::call(&mut service, msg)
            .await
            .unwrap();

        assert!(matches!(
            resp,
            JsonRpcResponseMessage::Single(JsonRpcResponse::Result(_))
        ));
    }

    #[tokio::test]
    async fn test_message_service_batch() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        service.call_single(initialize_request()).await.unwrap();
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        let batch = JsonRpcMessage::Batch(vec![
            JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({})),
            JsonRpcRequest::new(3, "tools/call").with_params(serde_json::json!({
                "name": "add",
                "arguments": { "a": 1, "b": 2 }
            })),
        ]);
        let resp = Service::<JsonRpcMessage>::call(&mut service, batch)
            .await
            .unwrap();

        // One response per request in the batch.
        assert!(matches!(
            resp,
            JsonRpcResponseMessage::Batch(ref responses) if responses.len() == 2
        ));
    }

    #[tokio::test]
    async fn test_message_service_empty_batch_is_single_error() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router);

        let resp = Service::<JsonRpcMessage>::call(&mut service, JsonRpcMessage::Batch(vec![]))
            .await
            .unwrap();

        // JSON-RPC 2.0 answers an empty batch with a single error object, not an array.
        assert!(matches!(
            resp,
            JsonRpcResponseMessage::Single(JsonRpcResponse::Error(_))
        ));
    }

    #[tokio::test]
    async fn test_jsonrpc_layer() {
        use tower::ServiceBuilder;

        let router = create_test_router();
        let router_clone = router.clone();

        // Build service using the layer via ServiceBuilder
        let mut service = ServiceBuilder::new()
            .layer(JsonRpcLayer::new())
            .service(router);

        // Initialize
        let resp = Service::<JsonRpcRequest>::call(&mut service, initialize_request())
            .await
            .unwrap();
        assert!(matches!(resp, JsonRpcResponse::Result(_)));

        router_clone.handle_notification(crate::protocol::McpNotification::Initialized);

        // List tools through the layer-composed service
        let req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let resp = Service::<JsonRpcRequest>::call(&mut service, req)
            .await
            .unwrap();

        match resp {
            JsonRpcResponse::Result(r) => {
                let tools = r.result.get("tools").unwrap().as_array().unwrap();
                assert_eq!(tools.len(), 1);
            }
            JsonRpcResponse::Error(e) => panic!("Expected result, got error: {:?}", e),
        }
    }

    #[test]
    fn test_jsonrpc_layer_default() {
        // JsonRpcLayer implements Default
        let _layer = JsonRpcLayer::default();
    }

    #[test]
    fn test_jsonrpc_layer_clone() {
        // Compile-time check that JsonRpcLayer implements Clone and Copy.
        fn assert_clone_copy<T: Clone + Copy>() {}
        assert_clone_copy::<JsonRpcLayer>();

        let layer = JsonRpcLayer::new();
        let _copied = layer;
        // Using `layer` again only compiles because it is Copy.
        let _copied_again = layer;
    }
}