Skip to main content

mcp_proxy/
coalesce.rs

//! Request coalescing middleware for the proxy.
//!
//! Deduplicates identical in-flight `CallTool` and `ReadResource` requests.
//! When multiple identical requests arrive concurrently, only one is forwarded
//! to the backend; every caller receives a copy of that one response, tagged
//! with its own request id.
6
7use std::collections::HashMap;
8use std::convert::Infallible;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14use tokio::sync::{Mutex, broadcast};
15use tower::{Layer, Service};
16use tower_mcp::router::{RouterRequest, RouterResponse};
17use tower_mcp_types::protocol::McpRequest;
18
/// Tower layer that produces a [`CoalesceService`].
///
/// The layer is a stateless unit struct: each call to `layer` builds a
/// service with its own, initially empty, in-flight registry.
#[derive(Debug, Clone, Default)]
pub struct CoalesceLayer;

impl CoalesceLayer {
    /// Create a new request coalescing layer.
    ///
    /// Equivalent to [`CoalesceLayer::default`]; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
35
36impl<S> Layer<S> for CoalesceLayer {
37    type Service = CoalesceService<S>;
38
39    fn layer(&self, inner: S) -> Self::Service {
40        CoalesceService::new(inner)
41    }
42}
43
44/// Tower service that coalesces identical in-flight requests.
45#[derive(Clone)]
46pub struct CoalesceService<S> {
47    inner: S,
48    in_flight: Arc<Mutex<HashMap<String, broadcast::Sender<RouterResponse>>>>,
49}
50
51impl<S> CoalesceService<S> {
52    /// Create a new request coalescing service wrapping `inner`.
53    pub fn new(inner: S) -> Self {
54        Self {
55            inner,
56            in_flight: Arc::new(Mutex::new(HashMap::new())),
57        }
58    }
59}
60
61fn coalesce_key(req: &McpRequest) -> Option<String> {
62    match req {
63        McpRequest::CallTool(params) => {
64            let args = serde_json::to_string(&params.arguments).unwrap_or_default();
65            Some(format!("tool:{}:{}", params.name, args))
66        }
67        McpRequest::ReadResource(params) => Some(format!("res:{}", params.uri)),
68        _ => None,
69    }
70}
71
72impl<S> Service<RouterRequest> for CoalesceService<S>
73where
74    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
75        + Clone
76        + Send
77        + 'static,
78    S::Future: Send,
79{
80    type Response = RouterResponse;
81    type Error = Infallible;
82    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
83
84    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
85        self.inner.poll_ready(cx)
86    }
87
88    fn call(&mut self, req: RouterRequest) -> Self::Future {
89        let Some(key) = coalesce_key(&req.inner) else {
90            // Non-coalesceable request, pass through
91            let fut = self.inner.call(req);
92            return Box::pin(fut);
93        };
94
95        let in_flight = Arc::clone(&self.in_flight);
96        let mut inner = self.inner.clone();
97        let request_id = req.id.clone();
98
99        Box::pin(async move {
100            // Check if there's already an in-flight request for this key
101            {
102                let map = in_flight.lock().await;
103                if let Some(tx) = map.get(&key) {
104                    let mut rx = tx.subscribe();
105                    drop(map);
106                    // Wait for the in-flight request to complete
107                    if let Ok(resp) = rx.recv().await {
108                        return Ok(RouterResponse {
109                            id: request_id,
110                            inner: resp.inner,
111                        });
112                    }
113                    // Sender dropped (shouldn't happen), fall through to make our own request
114                }
115            }
116
117            // We're the first — register ourselves
118            let (tx, _) = broadcast::channel(1);
119            {
120                let mut map = in_flight.lock().await;
121                map.insert(key.clone(), tx.clone());
122            }
123
124            let result = inner.call(req).await;
125
126            // Broadcast result to any waiters and clean up
127            let Ok(ref resp) = result;
128            let _ = tx.send(resp.clone());
129            {
130                let mut map = in_flight.lock().await;
131                map.remove(&key);
132            }
133
134            result
135        })
136    }
137}
138
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::CoalesceService;
    use crate::test_util::{MockService, call_service};

    /// Build a `CallTool` request for `name` with the given JSON `arguments`.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_coalesce_passes_through_single_request() {
        let mock = MockService::with_tools(&["fs/read"]);
        let mut svc = CoalesceService::new(mock);

        let resp = call_service(&mut svc, call_tool("fs/read", serde_json::json!({}))).await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(r) => assert_eq!(r.all_text(), "called: fs/read"),
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_coalesce_non_coalesceable_passes_through() {
        let mock = MockService::with_tools(&["tool"]);
        let mut svc = CoalesceService::new(mock);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "list_tools should pass through");
    }

    // The key-derivation tests below are pure and synchronous, so they do not
    // need an async runtime.

    #[test]
    fn test_coalesce_key_includes_arguments() {
        // Different arguments should produce different keys
        let key1 = super::coalesce_key(&call_tool("tool", serde_json::json!({"a": 1})));
        let key2 = super::coalesce_key(&call_tool("tool", serde_json::json!({"a": 2})));
        assert_ne!(key1, key2, "different args should have different keys");
    }

    #[test]
    fn test_coalesce_key_same_arguments_produce_same_key() {
        let key1 = super::coalesce_key(&call_tool("tool", serde_json::json!({"a": 1})));
        let key2 = super::coalesce_key(&call_tool("tool", serde_json::json!({"a": 1})));
        assert!(key1.is_some(), "CallTool should be coalesceable");
        assert_eq!(key1, key2, "same tool+args should have the same key");
    }

    #[test]
    fn test_coalesce_key_read_resource() {
        let key = super::coalesce_key(&McpRequest::ReadResource(
            tower_mcp::protocol::ReadResourceParams {
                uri: "file:///tmp/test.txt".to_string(),
                meta: None,
            },
        ));
        assert_eq!(key, Some("res:file:///tmp/test.txt".to_string()));
    }

    #[test]
    fn test_coalesce_key_non_coalesceable_returns_none() {
        let key = super::coalesce_key(&McpRequest::ListTools(Default::default()));
        assert!(key.is_none(), "ListTools should not be coalesceable");

        let key = super::coalesce_key(&McpRequest::ListResources(Default::default()));
        assert!(key.is_none(), "ListResources should not be coalesceable");
    }

    #[tokio::test]
    async fn test_concurrent_identical_requests_coalesced() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use tower::Service;
        use tower_mcp::protocol::RequestId;

        // A mock that counts how many times it's actually called
        #[derive(Clone)]
        struct CountingService {
            call_count: Arc<AtomicUsize>,
        }

        impl Service<tower_mcp::router::RouterRequest> for CountingService {
            type Response = tower_mcp::router::RouterResponse;
            type Error = std::convert::Infallible;
            type Future = std::pin::Pin<
                Box<
                    dyn std::future::Future<
                            Output = Result<
                                tower_mcp::router::RouterResponse,
                                std::convert::Infallible,
                            >,
                        > + Send,
                >,
            >;

            fn poll_ready(
                &mut self,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                std::task::Poll::Ready(Ok(()))
            }

            fn call(&mut self, req: tower_mcp::router::RouterRequest) -> Self::Future {
                let count = self.call_count.clone();
                let id = req.id.clone();
                Box::pin(async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    // Small delay to ensure concurrent requests overlap
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    Ok(tower_mcp::router::RouterResponse {
                        id,
                        inner: Ok(McpResponse::CallTool(
                            tower_mcp::protocol::CallToolResult::text("result"),
                        )),
                    })
                })
            }
        }

        let call_count = Arc::new(AtomicUsize::new(0));
        let svc = CountingService {
            call_count: call_count.clone(),
        };
        let coalesce = CoalesceService::new(svc);

        // Same tool and arguments every time; only the request id differs.
        let make_request = |id: RequestId| {
            let mut c = coalesce.clone();
            let req = tower_mcp::router::RouterRequest {
                id,
                inner: call_tool("tool", serde_json::json!({"x": 42})),
                extensions: tower_mcp::router::Extensions::new(),
            };
            async move { c.call(req).await }
        };

        // Fire 3 identical requests concurrently
        let (r1, r2, r3) = tokio::join!(
            make_request(RequestId::Number(1)),
            make_request(RequestId::Number(2)),
            make_request(RequestId::Number(3)),
        );

        let resp1 = r1.unwrap();
        let resp2 = r2.unwrap();
        let resp3 = r3.unwrap();

        // Every caller gets a successful payload...
        assert!(resp1.inner.is_ok());
        assert!(resp2.inner.is_ok());
        assert!(resp3.inner.is_ok());

        // ...tagged with its own request id, not the id of whichever request
        // actually reached the backend.
        assert!(
            matches!(resp1.id, RequestId::Number(1)),
            "response 1 should carry request id 1"
        );
        assert!(
            matches!(resp2.id, RequestId::Number(2)),
            "response 2 should carry request id 2"
        );
        assert!(
            matches!(resp3.id, RequestId::Number(3)),
            "response 3 should carry request id 3"
        );

        // `#[tokio::test]` drives the joined futures on a single thread, so the
        // first one polled registers itself and reaches the backend's sleep
        // before the others run; they must all attach to that one call.
        let count = call_count.load(Ordering::SeqCst);
        assert_eq!(
            count, 1,
            "expected exactly 1 backend call due to coalescing, got {count}"
        );
    }

    #[tokio::test]
    async fn test_different_requests_not_coalesced() {
        let mock = MockService::with_tools(&["tool"]);
        let coalesce = CoalesceService::new(mock);

        // Two requests with different arguments
        let mut c1 = coalesce.clone();
        let req1 = tower_mcp::router::RouterRequest {
            id: tower_mcp::protocol::RequestId::Number(1),
            inner: call_tool("tool", serde_json::json!({"x": 1})),
            extensions: tower_mcp::router::Extensions::new(),
        };

        let mut c2 = coalesce.clone();
        let req2 = tower_mcp::router::RouterRequest {
            id: tower_mcp::protocol::RequestId::Number(2),
            inner: call_tool("tool", serde_json::json!({"x": 2})),
            extensions: tower_mcp::router::Extensions::new(),
        };

        let (r1, r2) = tokio::join!(
            tower::Service::call(&mut c1, req1),
            tower::Service::call(&mut c2, req2)
        );

        // Both should succeed independently
        assert!(r1.is_ok());
        assert!(r2.is_ok());
    }

    #[tokio::test]
    async fn test_coalesce_with_error_response() {
        use crate::test_util::ErrorMockService;

        let mock = ErrorMockService;
        let mut svc = CoalesceService::new(mock);

        let resp = call_service(&mut svc, call_tool("failing_tool", serde_json::json!({}))).await;

        // The error response should pass through correctly
        assert!(
            resp.inner.is_err(),
            "error response should propagate through coalesce"
        );
        let err = resp.inner.unwrap_err();
        assert_eq!(err.code, -32603);
        assert_eq!(err.message, "internal error");
    }
}