1use std::future::Future;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use tower::Layer;
19use tower_service::Service;
20
21use crate::error::{Error, JsonRpcError, Result};
22use crate::protocol::{
23 JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseMessage, McpRequest,
24};
25use crate::router::{Extensions, RouterRequest, RouterResponse};
26
/// Tower [`Layer`] that wraps an inner service in a [`JsonRpcService`].
///
/// The layer carries no configuration. Its single private unit field keeps
/// the struct from being built with a struct literal outside this module, so
/// callers go through [`JsonRpcLayer::new`] or [`Default::default`].
#[derive(Clone, Copy, Debug, Default)]
pub struct JsonRpcLayer {
    /// Zero-sized marker; exists only to keep the struct literal private.
    _priv: (),
}
49
50impl JsonRpcLayer {
51 pub fn new() -> Self {
53 Self { _priv: () }
54 }
55}
56
57impl<S> Layer<S> for JsonRpcLayer {
58 type Service = JsonRpcService<S>;
59
60 fn layer(&self, inner: S) -> Self::Service {
61 JsonRpcService::new(inner)
62 }
63}
64
65pub struct JsonRpcService<S> {
82 inner: S,
83 extensions: Extensions,
84}
85
86impl<S> JsonRpcService<S> {
87 pub fn new(inner: S) -> Self {
89 Self {
90 inner,
91 extensions: Extensions::new(),
92 }
93 }
94
95 pub fn with_extensions(mut self, ext: Extensions) -> Self {
100 self.extensions = ext;
101 self
102 }
103
104 pub async fn call_single(&mut self, req: JsonRpcRequest) -> Result<JsonRpcResponse>
106 where
107 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
108 + Clone
109 + Send
110 + 'static,
111 S::Future: Send,
112 {
113 process_single_request(self.inner.clone(), req, self.extensions.clone()).await
114 }
115
116 pub async fn call_batch(
118 &mut self,
119 requests: Vec<JsonRpcRequest>,
120 ) -> Result<Vec<JsonRpcResponse>>
121 where
122 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
123 + Clone
124 + Send
125 + 'static,
126 S::Future: Send,
127 {
128 if requests.is_empty() {
129 return Err(Error::JsonRpc(JsonRpcError::invalid_request(
130 "Empty batch request",
131 )));
132 }
133
134 let futures: Vec<_> = requests
136 .into_iter()
137 .map(|req| {
138 let inner = self.inner.clone();
139 let extensions = self.extensions.clone();
140 let req_id = req.id.clone();
141 async move {
142 match process_single_request(inner, req, extensions).await {
143 Ok(resp) => resp,
144 Err(e) => {
145 JsonRpcResponse::error(
147 Some(req_id),
148 JsonRpcError::internal_error(e.to_string()),
149 )
150 }
151 }
152 }
153 })
154 .collect();
155
156 let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
157
158 Ok(results)
160 }
161
162 pub async fn call_message(&mut self, msg: JsonRpcMessage) -> Result<JsonRpcResponseMessage>
164 where
165 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
166 + Clone
167 + Send
168 + 'static,
169 S::Future: Send,
170 {
171 match msg {
172 JsonRpcMessage::Single(req) => {
173 let response = self.call_single(req).await?;
174 Ok(JsonRpcResponseMessage::Single(response))
175 }
176 JsonRpcMessage::Batch(requests) => {
177 let responses = self.call_batch(requests).await?;
178 Ok(JsonRpcResponseMessage::Batch(responses))
179 }
180 }
181 }
182}
183
184impl<S> Clone for JsonRpcService<S>
185where
186 S: Clone,
187{
188 fn clone(&self) -> Self {
189 Self {
190 inner: self.inner.clone(),
191 extensions: self.extensions.clone(),
192 }
193 }
194}
195
196impl<S> Service<JsonRpcRequest> for JsonRpcService<S>
197where
198 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
199 + Clone
200 + Send
201 + 'static,
202 S::Future: Send,
203{
204 type Response = JsonRpcResponse;
205 type Error = Error;
206 type Future =
207 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
208
209 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
210 self.inner.poll_ready(cx).map_err(|_| unreachable!())
211 }
212
213 fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
214 let mut inner = self.inner.clone();
215 let extensions = self.extensions.clone();
216 Box::pin(async move {
217 let mcp_request = McpRequest::from_jsonrpc(&req)?;
219
220 let router_req = RouterRequest {
222 id: req.id,
223 inner: mcp_request,
224 extensions,
225 };
226
227 let response = inner.call(router_req).await.unwrap(); Ok(response.into_jsonrpc())
232 })
233 }
234}
235
236impl<S> Service<JsonRpcMessage> for JsonRpcService<S>
238where
239 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
240 + Clone
241 + Send
242 + 'static,
243 S::Future: Send,
244{
245 type Response = JsonRpcResponseMessage;
246 type Error = Error;
247 type Future =
248 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
249
250 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
251 self.inner.poll_ready(cx).map_err(|_| unreachable!())
252 }
253
254 fn call(&mut self, msg: JsonRpcMessage) -> Self::Future {
255 let inner = self.inner.clone();
256 let extensions = self.extensions.clone();
257 Box::pin(async move {
258 match msg {
259 JsonRpcMessage::Single(req) => {
260 let response = process_single_request(inner, req, extensions).await?;
261 Ok(JsonRpcResponseMessage::Single(response))
262 }
263 JsonRpcMessage::Batch(requests) => {
264 if requests.is_empty() {
265 return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
267 None,
268 JsonRpcError::invalid_request("Empty batch request"),
269 )));
270 }
271
272 let futures: Vec<_> = requests
274 .into_iter()
275 .map(|req| {
276 let inner = inner.clone();
277 let extensions = extensions.clone();
278 let req_id = req.id.clone();
279 async move {
280 match process_single_request(inner, req, extensions).await {
281 Ok(resp) => resp,
282 Err(e) => {
283 JsonRpcResponse::error(
285 Some(req_id),
286 JsonRpcError::internal_error(e.to_string()),
287 )
288 }
289 }
290 }
291 })
292 .collect();
293
294 let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
295
296 if results.is_empty() {
298 return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
299 None,
300 JsonRpcError::internal_error("All batch requests failed"),
301 )));
302 }
303
304 Ok(JsonRpcResponseMessage::Batch(results))
305 }
306 }
307 })
308 }
309}
310
311async fn process_single_request<S>(
313 mut inner: S,
314 req: JsonRpcRequest,
315 extensions: Extensions,
316) -> std::result::Result<JsonRpcResponse, Error>
317where
318 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
319 + Send
320 + 'static,
321 S::Future: Send,
322{
323 if let Err(e) = req.validate() {
325 return Ok(JsonRpcResponse::error(Some(req.id), e));
326 }
327
328 let mcp_request = match McpRequest::from_jsonrpc(&req) {
330 Ok(r) => r,
331 Err(e) => {
332 return Ok(JsonRpcResponse::error(
333 Some(req.id),
334 JsonRpcError::invalid_params(e.to_string()),
335 ));
336 }
337 };
338
339 let router_req = RouterRequest {
341 id: req.id,
342 inner: mcp_request,
343 extensions,
344 };
345
346 let response = inner.call(router_req).await.unwrap(); Ok(response.into_jsonrpc())
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::McpRouter;
    use crate::tool::ToolBuilder;
    use schemars::JsonSchema;
    use serde::Deserialize;

    /// Arguments accepted by the `add` test tool.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i32,
        b: i32,
    }

    /// Builds a router exposing a single `add` tool.
    fn create_test_router() -> McpRouter {
        let adder = ToolBuilder::new("add")
            .description("Add two numbers")
            .handler(|input: AddInput| async move {
                let sum = input.a + input.b;
                Ok(crate::CallToolResult::text(format!("{}", sum)))
            })
            .build()
            .unwrap();

        McpRouter::new()
            .server_info("test-server", "1.0.0")
            .tool(adder)
    }

    /// The `initialize` request (id 1) shared by the tests.
    fn initialize_request() -> JsonRpcRequest {
        JsonRpcRequest::new(1, "initialize").with_params(serde_json::json!({
            "protocolVersion": "2025-11-25",
            "capabilities": {},
            "clientInfo": { "name": "test", "version": "1.0" }
        }))
    }

    /// The `tools/list` request (id 2) shared by the tests.
    fn tools_list_request() -> JsonRpcRequest {
        JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}))
    }

    /// Asserts that `resp` is a successful result listing exactly one tool.
    fn assert_single_tool_listed(resp: JsonRpcResponse) {
        match resp {
            JsonRpcResponse::Result(r) => {
                let tools = r.result.get("tools").unwrap().as_array().unwrap();
                assert_eq!(tools.len(), 1);
            }
            JsonRpcResponse::Error(e) => panic!("Expected result, got error: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_jsonrpc_service() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        let init_resp = service.call_single(initialize_request()).await.unwrap();
        assert!(matches!(init_resp, JsonRpcResponse::Result(_)));

        router.handle_notification(crate::protocol::McpNotification::Initialized);

        let list_resp = service.call_single(tools_list_request()).await.unwrap();
        assert_single_tool_listed(list_resp);
    }

    #[tokio::test]
    async fn test_batch_request() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        service.call_single(initialize_request()).await.unwrap();
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        let add_call = JsonRpcRequest::new(3, "tools/call").with_params(serde_json::json!({
            "name": "add",
            "arguments": { "a": 1, "b": 2 }
        }));
        let batch = vec![tools_list_request(), add_call];

        let responses = service.call_batch(batch).await.unwrap();
        assert_eq!(responses.len(), 2);
    }

    #[tokio::test]
    async fn test_empty_batch_error() {
        let mut service = JsonRpcService::new(create_test_router());

        let outcome = service.call_batch(vec![]).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_jsonrpc_layer() {
        use tower::ServiceBuilder;

        let router = create_test_router();
        let router_clone = router.clone();

        let mut service = ServiceBuilder::new()
            .layer(JsonRpcLayer::new())
            .service(router);

        let init_resp = Service::<JsonRpcRequest>::call(&mut service, initialize_request())
            .await
            .unwrap();
        assert!(matches!(init_resp, JsonRpcResponse::Result(_)));

        router_clone.handle_notification(crate::protocol::McpNotification::Initialized);

        let list_resp = Service::<JsonRpcRequest>::call(&mut service, tools_list_request())
            .await
            .unwrap();
        assert_single_tool_listed(list_resp);
    }

    #[test]
    fn test_jsonrpc_layer_default() {
        let _layer = JsonRpcLayer::default();
    }

    #[test]
    fn test_jsonrpc_layer_clone() {
        // `JsonRpcLayer` is `Copy`, so using the value twice compiles.
        let layer = JsonRpcLayer::new();
        let _cloned = layer;
        let _copied = layer;
    }
}