1use std::future::Future;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use tower::{Layer, ServiceExt};
19use tower_service::Service;
20
21use crate::error::{Error, JsonRpcError, Result};
22use crate::protocol::{
23 JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseMessage, McpRequest,
24};
25use crate::router::{Extensions, RouterRequest, RouterResponse};
26
/// Tower [`Layer`] that wraps an inner router service in a [`JsonRpcService`].
///
/// The layer carries no configuration. Its private unit field only stops
/// code outside this module from building it with a struct literal.
#[derive(Clone, Copy, Debug, Default)]
pub struct JsonRpcLayer {
    _priv: (),
}

impl JsonRpcLayer {
    /// Creates the layer; identical to [`JsonRpcLayer::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
56
57impl<S> Layer<S> for JsonRpcLayer {
58 type Service = JsonRpcService<S>;
59
60 fn layer(&self, inner: S) -> Self::Service {
61 JsonRpcService::new(inner)
62 }
63}
64
/// Tower service that converts JSON-RPC requests/messages into
/// [`RouterRequest`]s for the wrapped service `S`, and converts the
/// resulting [`RouterResponse`]s back into JSON-RPC responses.
pub struct JsonRpcService<S> {
    /// The wrapped router service; a clone of it handles each dispatched request.
    inner: S,
    /// Extensions copied into every `RouterRequest` this service builds.
    extensions: Extensions,
}
85
86impl<S> JsonRpcService<S> {
87 pub fn new(inner: S) -> Self {
89 Self {
90 inner,
91 extensions: Extensions::new(),
92 }
93 }
94
95 pub fn with_extensions(mut self, ext: Extensions) -> Self {
100 self.extensions = ext;
101 self
102 }
103
104 pub async fn call_single(&mut self, req: JsonRpcRequest) -> Result<JsonRpcResponse>
106 where
107 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
108 + Clone
109 + Send
110 + 'static,
111 S::Future: Send,
112 {
113 process_single_request(self.inner.clone(), req, self.extensions.clone()).await
114 }
115
116 pub async fn call_batch(
118 &mut self,
119 requests: Vec<JsonRpcRequest>,
120 ) -> Result<Vec<JsonRpcResponse>>
121 where
122 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
123 + Clone
124 + Send
125 + 'static,
126 S::Future: Send,
127 {
128 if requests.is_empty() {
129 return Err(Error::JsonRpc(JsonRpcError::invalid_request(
130 "Empty batch request",
131 )));
132 }
133
134 let futures: Vec<_> = requests
136 .into_iter()
137 .map(|req| {
138 let inner = self.inner.clone();
139 let extensions = self.extensions.clone();
140 let req_id = req.id.clone();
141 async move {
142 match process_single_request(inner, req, extensions).await {
143 Ok(resp) => resp,
144 Err(e) => {
145 JsonRpcResponse::error(
147 Some(req_id),
148 JsonRpcError::internal_error(e.to_string()),
149 )
150 }
151 }
152 }
153 })
154 .collect();
155
156 let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
157
158 Ok(results)
160 }
161
162 pub async fn call_message(&mut self, msg: JsonRpcMessage) -> Result<JsonRpcResponseMessage>
164 where
165 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
166 + Clone
167 + Send
168 + 'static,
169 S::Future: Send,
170 {
171 match msg {
172 JsonRpcMessage::Single(req) => {
173 let response = self.call_single(req).await?;
174 Ok(JsonRpcResponseMessage::Single(response))
175 }
176 JsonRpcMessage::Batch(requests) => {
177 let responses = self.call_batch(requests).await?;
178 Ok(JsonRpcResponseMessage::Batch(responses))
179 }
180 _ => Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
181 None,
182 JsonRpcError::invalid_request("Unsupported message type"),
183 ))),
184 }
185 }
186}
187
188impl<S> Clone for JsonRpcService<S>
189where
190 S: Clone,
191{
192 fn clone(&self) -> Self {
193 Self {
194 inner: self.inner.clone(),
195 extensions: self.extensions.clone(),
196 }
197 }
198}
199
200impl<S> Service<JsonRpcRequest> for JsonRpcService<S>
201where
202 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
203 + Clone
204 + Send
205 + 'static,
206 S::Future: Send,
207{
208 type Response = JsonRpcResponse;
209 type Error = Error;
210 type Future =
211 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
212
213 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
214 self.inner.poll_ready(cx).map_err(|_| unreachable!())
215 }
216
217 fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
218 let inner = self.inner.clone();
219 let extensions = self.extensions.clone();
220 Box::pin(async move {
221 let mcp_request = McpRequest::from_jsonrpc(&req)?;
223
224 let router_req = RouterRequest {
226 id: req.id,
227 inner: mcp_request,
228 extensions,
229 };
230
231 let response = inner.oneshot(router_req).await.unwrap(); Ok(response.into_jsonrpc())
236 })
237 }
238}
239
240impl<S> Service<JsonRpcMessage> for JsonRpcService<S>
242where
243 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
244 + Clone
245 + Send
246 + 'static,
247 S::Future: Send,
248{
249 type Response = JsonRpcResponseMessage;
250 type Error = Error;
251 type Future =
252 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
253
254 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
255 self.inner.poll_ready(cx).map_err(|_| unreachable!())
256 }
257
258 fn call(&mut self, msg: JsonRpcMessage) -> Self::Future {
259 let inner = self.inner.clone();
260 let extensions = self.extensions.clone();
261 Box::pin(async move {
262 match msg {
263 JsonRpcMessage::Single(req) => {
264 let response = process_single_request(inner, req, extensions).await?;
265 Ok(JsonRpcResponseMessage::Single(response))
266 }
267 JsonRpcMessage::Batch(requests) => {
268 if requests.is_empty() {
269 return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
271 None,
272 JsonRpcError::invalid_request("Empty batch request"),
273 )));
274 }
275
276 let futures: Vec<_> = requests
278 .into_iter()
279 .map(|req| {
280 let inner = inner.clone();
281 let extensions = extensions.clone();
282 let req_id = req.id.clone();
283 async move {
284 match process_single_request(inner, req, extensions).await {
285 Ok(resp) => resp,
286 Err(e) => {
287 JsonRpcResponse::error(
289 Some(req_id),
290 JsonRpcError::internal_error(e.to_string()),
291 )
292 }
293 }
294 }
295 })
296 .collect();
297
298 let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
299
300 if results.is_empty() {
302 return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
303 None,
304 JsonRpcError::internal_error("All batch requests failed"),
305 )));
306 }
307
308 Ok(JsonRpcResponseMessage::Batch(results))
309 }
310 _ => Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
311 None,
312 JsonRpcError::invalid_request("Unsupported message type"),
313 ))),
314 }
315 })
316 }
317}
318
319async fn process_single_request<S>(
321 inner: S,
322 req: JsonRpcRequest,
323 extensions: Extensions,
324) -> std::result::Result<JsonRpcResponse, Error>
325where
326 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
327 + Send
328 + 'static,
329 S::Future: Send,
330{
331 if let Err(e) = req.validate() {
333 return Ok(JsonRpcResponse::error(Some(req.id), e));
334 }
335
336 let mcp_request = match McpRequest::from_jsonrpc(&req) {
338 Ok(r) => r,
339 Err(e) => {
340 return Ok(JsonRpcResponse::error(
341 Some(req.id),
342 JsonRpcError::invalid_params(e.to_string()),
343 ));
344 }
345 };
346
347 let router_req = RouterRequest {
349 id: req.id,
350 inner: mcp_request,
351 extensions,
352 };
353
354 let response = inner.oneshot(router_req).await.unwrap(); Ok(response.into_jsonrpc())
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::McpRouter;
    use crate::tool::ToolBuilder;
    use schemars::JsonSchema;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i32,
        b: i32,
    }

    fn create_test_router() -> McpRouter {
        let add_tool = ToolBuilder::new("add")
            .description("Add two numbers")
            .handler(|input: AddInput| async move {
                Ok(crate::CallToolResult::text((input.a + input.b).to_string()))
            })
            .build();

        McpRouter::new()
            .server_info("test-server", "1.0.0")
            .tool(add_tool)
    }

    /// Builds the `initialize` request (id 1) that each session starts with.
    fn init_request() -> JsonRpcRequest {
        JsonRpcRequest::new(1, "initialize").with_params(serde_json::json!({
            "protocolVersion": "2025-11-25",
            "capabilities": {},
            "clientInfo": { "name": "test", "version": "1.0" }
        }))
    }

    #[tokio::test]
    async fn test_jsonrpc_service() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        let resp = service.call_single(init_request()).await.unwrap();
        assert!(matches!(resp, JsonRpcResponse::Result(_)));

        router.handle_notification(crate::protocol::McpNotification::Initialized);

        let req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let resp = service.call_single(req).await.unwrap();

        match resp {
            JsonRpcResponse::Result(r) => {
                let tools = r.result.get("tools").unwrap().as_array().unwrap();
                assert_eq!(tools.len(), 1);
            }
            JsonRpcResponse::Error(e) => panic!("Expected result, got error: {:?}", e),
            _ => panic!("unexpected response variant"),
        }
    }

    #[tokio::test]
    async fn test_batch_request() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        service.call_single(init_request()).await.unwrap();
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        let requests = vec![
            JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({})),
            JsonRpcRequest::new(3, "tools/call").with_params(serde_json::json!({
                "name": "add",
                "arguments": { "a": 1, "b": 2 }
            })),
        ];

        let responses = service.call_batch(requests).await.unwrap();
        assert_eq!(responses.len(), 2);
    }

    #[tokio::test]
    async fn test_empty_batch_error() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router);

        let result = service.call_batch(vec![]).await;
        assert!(matches!(result, Err(Error::JsonRpc(_))));
    }

    #[tokio::test]
    async fn test_service_empty_batch_returns_single_error() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router);

        let resp = Service::<JsonRpcMessage>::call(&mut service, JsonRpcMessage::Batch(vec![]))
            .await
            .unwrap();
        assert!(matches!(
            resp,
            JsonRpcResponseMessage::Single(JsonRpcResponse::Error(_))
        ));
    }

    #[tokio::test]
    async fn test_jsonrpc_layer() {
        use tower::ServiceBuilder;

        let router = create_test_router();
        let router_clone = router.clone();

        let mut service = ServiceBuilder::new()
            .layer(JsonRpcLayer::new())
            .service(router);

        let resp = Service::<JsonRpcRequest>::call(&mut service, init_request())
            .await
            .unwrap();
        assert!(matches!(resp, JsonRpcResponse::Result(_)));

        router_clone.handle_notification(crate::protocol::McpNotification::Initialized);

        let req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let resp = Service::<JsonRpcRequest>::call(&mut service, req)
            .await
            .unwrap();

        match resp {
            JsonRpcResponse::Result(r) => {
                let tools = r.result.get("tools").unwrap().as_array().unwrap();
                assert_eq!(tools.len(), 1);
            }
            JsonRpcResponse::Error(e) => panic!("Expected result, got error: {:?}", e),
            _ => panic!("unexpected response variant"),
        }
    }

    #[test]
    fn test_jsonrpc_layer_default() {
        let _layer = JsonRpcLayer::default();
    }

    #[test]
    fn test_jsonrpc_layer_clone() {
        // Compile-time check that the layer implements both traits.
        fn assert_clone_and_copy<T: Clone + Copy>(_: &T) {}

        let layer = JsonRpcLayer::new();
        assert_clone_and_copy(&layer);
        // Reusing `layer` after a move compiles only because it is `Copy`.
        let _first = layer;
        let _second = layer;
    }
}