1use std::collections::HashMap;
8use std::convert::Infallible;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14use tokio::sync::{Mutex, broadcast};
15use tower::{Layer, Service};
16use tower_mcp::router::{RouterRequest, RouterResponse};
17use tower_mcp_types::protocol::McpRequest;
18
/// Tower [`Layer`] that wraps a service in a [`CoalesceService`], so that
/// concurrent identical tool calls and resource reads share a single call to
/// the wrapped service.
#[derive(Debug, Clone)]
pub struct CoalesceLayer;

impl CoalesceLayer {
    /// Creates the layer.
    ///
    /// The layer carries no configuration, so construction is `const` and can
    /// be used in constant/static contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
29
30impl Default for CoalesceLayer {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl<S> Layer<S> for CoalesceLayer {
37 type Service = CoalesceService<S>;
38
39 fn layer(&self, inner: S) -> Self::Service {
40 CoalesceService::new(inner)
41 }
42}
43
44#[derive(Clone)]
46pub struct CoalesceService<S> {
47 inner: S,
48 in_flight: Arc<Mutex<HashMap<String, broadcast::Sender<RouterResponse>>>>,
49}
50
51impl<S> CoalesceService<S> {
52 pub fn new(inner: S) -> Self {
54 Self {
55 inner,
56 in_flight: Arc::new(Mutex::new(HashMap::new())),
57 }
58 }
59}
60
61fn coalesce_key(req: &McpRequest) -> Option<String> {
62 match req {
63 McpRequest::CallTool(params) => {
64 let args = serde_json::to_string(¶ms.arguments).unwrap_or_default();
65 Some(format!("tool:{}:{}", params.name, args))
66 }
67 McpRequest::ReadResource(params) => Some(format!("res:{}", params.uri)),
68 _ => None,
69 }
70}
71
72impl<S> Service<RouterRequest> for CoalesceService<S>
73where
74 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
75 + Clone
76 + Send
77 + 'static,
78 S::Future: Send,
79{
80 type Response = RouterResponse;
81 type Error = Infallible;
82 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
83
84 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
85 self.inner.poll_ready(cx)
86 }
87
88 fn call(&mut self, req: RouterRequest) -> Self::Future {
89 let Some(key) = coalesce_key(&req.inner) else {
90 let fut = self.inner.call(req);
92 return Box::pin(fut);
93 };
94
95 let in_flight = Arc::clone(&self.in_flight);
96 let mut inner = self.inner.clone();
97 let request_id = req.id.clone();
98
99 Box::pin(async move {
100 {
102 let map = in_flight.lock().await;
103 if let Some(tx) = map.get(&key) {
104 let mut rx = tx.subscribe();
105 drop(map);
106 if let Ok(resp) = rx.recv().await {
108 return Ok(RouterResponse {
109 id: request_id,
110 inner: resp.inner,
111 });
112 }
113 }
115 }
116
117 let (tx, _) = broadcast::channel(1);
119 {
120 let mut map = in_flight.lock().await;
121 map.insert(key.clone(), tx.clone());
122 }
123
124 let result = inner.call(req).await;
125
126 let Ok(ref resp) = result;
128 let _ = tx.send(resp.clone());
129 {
130 let mut map = in_flight.lock().await;
131 map.remove(&key);
132 }
133
134 result
135 })
136 }
137}
138
#[cfg(test)]
mod tests {
    use std::convert::Infallible;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Context, Poll};
    use std::time::Duration;

    use tower::Service;
    use tower_mcp::protocol::{
        CallToolParams, CallToolResult, McpRequest, McpResponse, ReadResourceParams, RequestId,
    };
    use tower_mcp::router::{Extensions, RouterRequest, RouterResponse};

    use super::CoalesceService;
    use crate::test_util::{MockService, call_service};

    /// Builds a `CallTool` request for `name` with the given JSON `arguments`.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// Wraps `inner` in a router request with the given `id` and no extensions.
    fn router_request(id: RequestId, inner: McpRequest) -> RouterRequest {
        RouterRequest {
            id,
            inner,
            extensions: Extensions::new(),
        }
    }

    /// Inner service that counts how many calls reach it and answers each one
    /// with the text `"result"` after 50 ms, so concurrent callers overlap with
    /// the call that is in flight.
    #[derive(Clone)]
    struct CountingService {
        call_count: Arc<AtomicUsize>,
    }

    impl Service<RouterRequest> for CountingService {
        type Response = RouterResponse;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: RouterRequest) -> Self::Future {
            let count = Arc::clone(&self.call_count);
            let id = req.id.clone();
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(RouterResponse {
                    id,
                    inner: Ok(McpResponse::CallTool(CallToolResult::text("result"))),
                })
            })
        }
    }

    #[tokio::test]
    async fn test_coalesce_passes_through_single_request() {
        let mock = MockService::with_tools(&["fs/read"]);
        let mut svc = CoalesceService::new(mock);

        let resp = call_service(&mut svc, call_tool("fs/read", serde_json::json!({}))).await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(r) => assert_eq!(r.all_text(), "called: fs/read"),
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_coalesce_non_coalesceable_passes_through() {
        let mock = MockService::with_tools(&["tool"]);
        let mut svc = CoalesceService::new(mock);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "list_tools should pass through");
    }

    #[test]
    fn test_coalesce_key_includes_arguments() {
        let key1 = super::coalesce_key(&call_tool("tool", serde_json::json!({"a": 1})));
        let key2 = super::coalesce_key(&call_tool("tool", serde_json::json!({"a": 2})));
        assert!(key1.is_some(), "CallTool should be coalesceable");
        assert_ne!(key1, key2, "different args should have different keys");
    }

    #[test]
    fn test_coalesce_key_same_arguments_produce_same_key() {
        let key1 = super::coalesce_key(&call_tool("tool", serde_json::json!({"a": 1})));
        let key2 = super::coalesce_key(&call_tool("tool", serde_json::json!({"a": 1})));
        assert!(key1.is_some(), "CallTool should be coalesceable");
        assert_eq!(key1, key2, "same tool+args should have the same key");
    }

    #[test]
    fn test_coalesce_key_read_resource() {
        let key = super::coalesce_key(&McpRequest::ReadResource(ReadResourceParams {
            uri: "file:///tmp/test.txt".to_string(),
            meta: None,
        }));
        assert_eq!(key, Some("res:file:///tmp/test.txt".to_string()));
    }

    #[test]
    fn test_coalesce_key_non_coalesceable_returns_none() {
        let key = super::coalesce_key(&McpRequest::ListTools(Default::default()));
        assert!(key.is_none(), "ListTools should not be coalesceable");

        let key = super::coalesce_key(&McpRequest::ListResources(Default::default()));
        assert!(key.is_none(), "ListResources should not be coalesceable");
    }

    #[tokio::test]
    async fn test_concurrent_identical_requests_coalesced() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let coalesce = CoalesceService::new(CountingService {
            call_count: Arc::clone(&call_count),
        });

        // Same tool and arguments for every caller, but a distinct request id
        // each, so we can check that followers get their own id back.
        let make_request = |id: RequestId| {
            let mut svc = coalesce.clone();
            let req = router_request(id, call_tool("tool", serde_json::json!({"x": 42})));
            async move { svc.call(req).await }
        };

        let (r1, r2, r3) = tokio::join!(
            make_request(RequestId::Number(1)),
            make_request(RequestId::Number(2)),
            make_request(RequestId::Number(3)),
        );

        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "identical concurrent requests should share a single backend call"
        );

        for (result, expected_id) in [
            (r1, RequestId::Number(1)),
            (r2, RequestId::Number(2)),
            (r3, RequestId::Number(3)),
        ] {
            let resp = result.unwrap();
            assert_eq!(resp.id, expected_id, "each caller should get its own request id back");
            match resp.inner {
                Ok(McpResponse::CallTool(r)) => assert_eq!(r.all_text(), "result"),
                other => panic!("expected CallTool, got: {:?}", other),
            }
        }
    }

    #[tokio::test]
    async fn test_different_requests_not_coalesced() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let coalesce = CoalesceService::new(CountingService {
            call_count: Arc::clone(&call_count),
        });

        let mut c1 = coalesce.clone();
        let mut c2 = coalesce.clone();
        let req1 = router_request(
            RequestId::Number(1),
            call_tool("tool", serde_json::json!({"x": 1})),
        );
        let req2 = router_request(
            RequestId::Number(2),
            call_tool("tool", serde_json::json!({"x": 2})),
        );

        let (r1, r2) = tokio::join!(c1.call(req1), c2.call(req2));

        assert!(r1.is_ok());
        assert!(r2.is_ok());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "requests with different arguments must each reach the backend"
        );
    }

    #[tokio::test]
    async fn test_coalesce_with_error_response() {
        use crate::test_util::ErrorMockService;

        let mut svc = CoalesceService::new(ErrorMockService);

        let resp = call_service(&mut svc, call_tool("failing_tool", serde_json::json!({}))).await;

        assert!(
            resp.inner.is_err(),
            "error response should propagate through coalesce"
        );
        let err = resp.inner.unwrap_err();
        assert_eq!(err.code, -32603);
        assert_eq!(err.message, "internal error");
    }
}