Skip to main content

mcp_proxy/
coalesce.rs

1//! Request coalescing middleware for the proxy.
2//!
3//! Deduplicates identical in-flight `CallTool` and `ReadResource` requests.
4//! When multiple identical requests arrive concurrently, only one is forwarded
5//! to the backend; all callers receive the same response.
6
7use std::collections::HashMap;
8use std::convert::Infallible;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14use tokio::sync::{Mutex, broadcast};
15use tower::{Layer, Service};
16use tower_mcp::router::{RouterRequest, RouterResponse};
17use tower_mcp_types::protocol::McpRequest;
18
/// Tower layer that produces a [`CoalesceService`].
///
/// The layer is a stateless unit struct; each call to `layer` builds a fresh
/// [`CoalesceService`] around the service it is given.
#[derive(Debug, Clone)]
pub struct CoalesceLayer;
22
23impl CoalesceLayer {
24    /// Create a new request coalescing layer.
25    pub fn new() -> Self {
26        Self
27    }
28}
29
30impl Default for CoalesceLayer {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl<S> Layer<S> for CoalesceLayer {
37    type Service = CoalesceService<S>;
38
39    fn layer(&self, inner: S) -> Self::Service {
40        CoalesceService::new(inner)
41    }
42}
43
44/// Tower service that coalesces identical in-flight requests.
45#[derive(Clone)]
46pub struct CoalesceService<S> {
47    inner: S,
48    in_flight: Arc<Mutex<HashMap<String, broadcast::Sender<RouterResponse>>>>,
49}
50
51impl<S> CoalesceService<S> {
52    /// Create a new request coalescing service wrapping `inner`.
53    pub fn new(inner: S) -> Self {
54        Self {
55            inner,
56            in_flight: Arc::new(Mutex::new(HashMap::new())),
57        }
58    }
59}
60
61fn coalesce_key(req: &McpRequest) -> Option<String> {
62    match req {
63        McpRequest::CallTool(params) => {
64            let args = serde_json::to_string(&params.arguments).unwrap_or_default();
65            Some(format!("tool:{}:{}", params.name, args))
66        }
67        McpRequest::ReadResource(params) => Some(format!("res:{}", params.uri)),
68        _ => None,
69    }
70}
71
72impl<S> Service<RouterRequest> for CoalesceService<S>
73where
74    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
75        + Clone
76        + Send
77        + 'static,
78    S::Future: Send,
79{
80    type Response = RouterResponse;
81    type Error = Infallible;
82    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
83
84    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
85        self.inner.poll_ready(cx)
86    }
87
88    fn call(&mut self, req: RouterRequest) -> Self::Future {
89        let Some(key) = coalesce_key(&req.inner) else {
90            // Non-coalesceable request, pass through
91            let fut = self.inner.call(req);
92            return Box::pin(fut);
93        };
94
95        let in_flight = Arc::clone(&self.in_flight);
96        let mut inner = self.inner.clone();
97        let request_id = req.id.clone();
98
99        Box::pin(async move {
100            // Check if there's already an in-flight request for this key
101            {
102                let map = in_flight.lock().await;
103                if let Some(tx) = map.get(&key) {
104                    let mut rx = tx.subscribe();
105                    drop(map);
106                    // Wait for the in-flight request to complete
107                    if let Ok(resp) = rx.recv().await {
108                        return Ok(RouterResponse {
109                            id: request_id,
110                            inner: resp.inner,
111                        });
112                    }
113                    // Sender dropped (shouldn't happen), fall through to make our own request
114                }
115            }
116
117            // We're the first — register ourselves
118            let (tx, _) = broadcast::channel(1);
119            {
120                let mut map = in_flight.lock().await;
121                map.insert(key.clone(), tx.clone());
122            }
123
124            let result = inner.call(req).await;
125
126            // Broadcast result to any waiters and clean up
127            let Ok(ref resp) = result;
128            let _ = tx.send(resp.clone());
129            {
130                let mut map = in_flight.lock().await;
131                map.remove(&key);
132            }
133
134            result
135        })
136    }
137}
138
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest, McpResponse};

    use super::{CoalesceService, coalesce_key};
    use crate::test_util::{MockService, call_service};

    /// Build a `CallTool` request for `name` carrying the given JSON `arguments`.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// A lone coalesceable request reaches the backend and its response comes back.
    #[tokio::test]
    async fn test_coalesce_passes_through_single_request() {
        let mut svc = CoalesceService::new(MockService::with_tools(&["fs/read"]));

        let request = call_tool("fs/read", serde_json::json!({}));
        let resp = call_service(&mut svc, request).await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => assert_eq!(result.all_text(), "called: fs/read"),
            unexpected => panic!("expected CallTool, got: {:?}", unexpected),
        }
    }

    /// Requests without a coalescing key are forwarded untouched.
    #[tokio::test]
    async fn test_coalesce_non_coalesceable_passes_through() {
        let mut svc = CoalesceService::new(MockService::with_tools(&["tool"]));

        let request = McpRequest::ListTools(Default::default());
        let resp = call_service(&mut svc, request).await;

        assert!(resp.inner.is_ok(), "list_tools should pass through");
    }

    /// The same tool called with different arguments must not share a key.
    #[tokio::test]
    async fn test_coalesce_key_includes_arguments() {
        let first = coalesce_key(&call_tool("tool", serde_json::json!({"a": 1})));
        let second = coalesce_key(&call_tool("tool", serde_json::json!({"a": 2})));

        assert_ne!(first, second, "different args should have different keys");
    }
}