1use std::future::Future;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use tower_service::Service;
15
16use crate::error::{Error, JsonRpcError, Result};
17use crate::protocol::{
18 JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseMessage, McpRequest,
19};
20use crate::router::{RouterRequest, RouterResponse};
21
/// Adapter that accepts JSON-RPC requests and forwards them, as
/// [`RouterRequest`]s, to a wrapped router-level service.
///
/// The `Service` implementations in this module require `S` to be a
/// `Service<RouterRequest, Response = RouterResponse, Error = Infallible>`.
pub struct JsonRpcService<S> {
    /// Router-level service that every decoded request is dispatched to.
    inner: S,
}
38
39impl<S> JsonRpcService<S> {
40 pub fn new(inner: S) -> Self {
42 Self { inner }
43 }
44
45 pub async fn call_single(&mut self, req: JsonRpcRequest) -> Result<JsonRpcResponse>
47 where
48 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
49 + Clone
50 + Send
51 + 'static,
52 S::Future: Send,
53 {
54 process_single_request(self.inner.clone(), req).await
55 }
56
57 pub async fn call_batch(
59 &mut self,
60 requests: Vec<JsonRpcRequest>,
61 ) -> Result<Vec<JsonRpcResponse>>
62 where
63 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
64 + Clone
65 + Send
66 + 'static,
67 S::Future: Send,
68 {
69 if requests.is_empty() {
70 return Err(Error::JsonRpc(JsonRpcError::invalid_request(
71 "Empty batch request",
72 )));
73 }
74
75 let futures: Vec<_> = requests
77 .into_iter()
78 .map(|req| {
79 let inner = self.inner.clone();
80 let req_id = req.id.clone();
81 async move {
82 match process_single_request(inner, req).await {
83 Ok(resp) => resp,
84 Err(e) => {
85 JsonRpcResponse::error(
87 Some(req_id),
88 JsonRpcError::internal_error(e.to_string()),
89 )
90 }
91 }
92 }
93 })
94 .collect();
95
96 let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
97
98 Ok(results)
100 }
101
102 pub async fn call_message(&mut self, msg: JsonRpcMessage) -> Result<JsonRpcResponseMessage>
104 where
105 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
106 + Clone
107 + Send
108 + 'static,
109 S::Future: Send,
110 {
111 match msg {
112 JsonRpcMessage::Single(req) => {
113 let response = self.call_single(req).await?;
114 Ok(JsonRpcResponseMessage::Single(response))
115 }
116 JsonRpcMessage::Batch(requests) => {
117 let responses = self.call_batch(requests).await?;
118 Ok(JsonRpcResponseMessage::Batch(responses))
119 }
120 }
121 }
122}
123
124impl<S> Clone for JsonRpcService<S>
125where
126 S: Clone,
127{
128 fn clone(&self) -> Self {
129 Self {
130 inner: self.inner.clone(),
131 }
132 }
133}
134
135impl<S> Service<JsonRpcRequest> for JsonRpcService<S>
136where
137 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
138 + Clone
139 + Send
140 + 'static,
141 S::Future: Send,
142{
143 type Response = JsonRpcResponse;
144 type Error = Error;
145 type Future =
146 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
147
148 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
149 self.inner.poll_ready(cx).map_err(|_| unreachable!())
150 }
151
152 fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
153 let mut inner = self.inner.clone();
154 Box::pin(async move {
155 let mcp_request = McpRequest::from_jsonrpc(&req)?;
157
158 let router_req = RouterRequest {
160 id: req.id,
161 inner: mcp_request,
162 };
163
164 let response = inner.call(router_req).await.unwrap(); Ok(response.into_jsonrpc())
169 })
170 }
171}
172
173impl<S> Service<JsonRpcMessage> for JsonRpcService<S>
175where
176 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
177 + Clone
178 + Send
179 + 'static,
180 S::Future: Send,
181{
182 type Response = JsonRpcResponseMessage;
183 type Error = Error;
184 type Future =
185 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
186
187 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
188 self.inner.poll_ready(cx).map_err(|_| unreachable!())
189 }
190
191 fn call(&mut self, msg: JsonRpcMessage) -> Self::Future {
192 let inner = self.inner.clone();
193 Box::pin(async move {
194 match msg {
195 JsonRpcMessage::Single(req) => {
196 let response = process_single_request(inner, req).await?;
197 Ok(JsonRpcResponseMessage::Single(response))
198 }
199 JsonRpcMessage::Batch(requests) => {
200 if requests.is_empty() {
201 return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
203 None,
204 JsonRpcError::invalid_request("Empty batch request"),
205 )));
206 }
207
208 let futures: Vec<_> = requests
210 .into_iter()
211 .map(|req| {
212 let inner = inner.clone();
213 let req_id = req.id.clone();
214 async move {
215 match process_single_request(inner, req).await {
216 Ok(resp) => resp,
217 Err(e) => {
218 JsonRpcResponse::error(
220 Some(req_id),
221 JsonRpcError::internal_error(e.to_string()),
222 )
223 }
224 }
225 }
226 })
227 .collect();
228
229 let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
230
231 if results.is_empty() {
233 return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
234 None,
235 JsonRpcError::internal_error("All batch requests failed"),
236 )));
237 }
238
239 Ok(JsonRpcResponseMessage::Batch(results))
240 }
241 }
242 })
243 }
244}
245
246async fn process_single_request<S>(
248 mut inner: S,
249 req: JsonRpcRequest,
250) -> std::result::Result<JsonRpcResponse, Error>
251where
252 S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
253 + Send
254 + 'static,
255 S::Future: Send,
256{
257 if let Err(e) = req.validate() {
259 return Ok(JsonRpcResponse::error(Some(req.id), e));
260 }
261
262 let mcp_request = match McpRequest::from_jsonrpc(&req) {
264 Ok(r) => r,
265 Err(e) => {
266 return Ok(JsonRpcResponse::error(
267 Some(req.id),
268 JsonRpcError::invalid_params(e.to_string()),
269 ));
270 }
271 };
272
273 let router_req = RouterRequest {
275 id: req.id,
276 inner: mcp_request,
277 };
278
279 let response = inner.call(router_req).await.unwrap(); Ok(response.into_jsonrpc())
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::McpRouter;
    use crate::tool::ToolBuilder;
    use schemars::JsonSchema;
    use serde::Deserialize;

    /// Arguments accepted by the `add` test tool.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i32,
        b: i32,
    }

    /// Builds a router that exposes a single `add` tool.
    fn create_test_router() -> McpRouter {
        let add_tool = ToolBuilder::new("add")
            .description("Add two numbers")
            .handler(|input: AddInput| async move {
                let sum = input.a + input.b;
                Ok(crate::CallToolResult::text(format!("{}", sum)))
            })
            .build()
            .unwrap();

        McpRouter::new()
            .server_info("test-server", "1.0.0")
            .tool(add_tool)
    }

    /// Builds the `initialize` request (id 1) shared by the tests below.
    fn init_request() -> JsonRpcRequest {
        let params = serde_json::json!({
            "protocolVersion": "2025-03-26",
            "capabilities": {},
            "clientInfo": { "name": "test", "version": "1.0" }
        });
        JsonRpcRequest::new(1, "initialize").with_params(params)
    }

    #[tokio::test]
    async fn test_jsonrpc_service() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        let init_resp = service.call_single(init_request()).await.unwrap();
        assert!(matches!(init_resp, JsonRpcResponse::Result(_)));

        router.handle_notification(crate::protocol::McpNotification::Initialized);

        let list_req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let tool_count = match service.call_single(list_req).await.unwrap() {
            JsonRpcResponse::Result(r) => r.result.get("tools").unwrap().as_array().unwrap().len(),
            JsonRpcResponse::Error(e) => panic!("Expected result, got error: {:?}", e),
        };
        assert_eq!(tool_count, 1);
    }

    #[tokio::test]
    async fn test_batch_request() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        service.call_single(init_request()).await.unwrap();
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        let list_req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let add_req = JsonRpcRequest::new(3, "tools/call").with_params(serde_json::json!({
            "name": "add",
            "arguments": { "a": 1, "b": 2 }
        }));

        let responses = service.call_batch(vec![list_req, add_req]).await.unwrap();
        assert_eq!(responses.len(), 2);
    }

    #[tokio::test]
    async fn test_empty_batch_error() {
        let mut service = JsonRpcService::new(create_test_router());
        assert!(service.call_batch(vec![]).await.is_err());
    }
}