1use std::collections::HashMap;
8use std::convert::Infallible;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14use tokio::sync::{Mutex, broadcast};
15use tower::{Layer, Service};
16use tower_mcp::router::{RouterRequest, RouterResponse};
17use tower_mcp_types::protocol::McpRequest;
18
/// Tower [`Layer`] that wraps a service in a [`CoalesceService`].
///
/// The layer carries no configuration, so it is a zero-sized, freely
/// copyable marker type.
#[derive(Debug, Clone, Copy, Default)]
pub struct CoalesceLayer;

impl CoalesceLayer {
    /// Creates a new coalescing layer.
    ///
    /// Equivalent to [`CoalesceLayer::default`]; provided as a `const fn` so
    /// the layer can also be built in constant contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
35
36impl<S> Layer<S> for CoalesceLayer {
37 type Service = CoalesceService<S>;
38
39 fn layer(&self, inner: S) -> Self::Service {
40 CoalesceService::new(inner)
41 }
42}
43
/// Service wrapper that coalesces identical in-flight requests.
///
/// Cloning the service clones the wrapped service but shares the in-flight
/// registry, because the registry lives behind an [`Arc`].
#[derive(Clone)]
pub struct CoalesceService<S> {
    /// The wrapped service that actually executes requests.
    inner: S,
    /// Requests currently being executed, keyed by the string produced by
    /// `coalesce_key`. Each entry holds the broadcast sender through which
    /// the executing call publishes its response to waiting duplicates.
    in_flight: Arc<Mutex<HashMap<String, broadcast::Sender<RouterResponse>>>>,
}
50
51impl<S> CoalesceService<S> {
52 pub fn new(inner: S) -> Self {
54 Self {
55 inner,
56 in_flight: Arc::new(Mutex::new(HashMap::new())),
57 }
58 }
59}
60
61fn coalesce_key(req: &McpRequest) -> Option<String> {
62 match req {
63 McpRequest::CallTool(params) => {
64 let args = serde_json::to_string(¶ms.arguments).unwrap_or_default();
65 Some(format!("tool:{}:{}", params.name, args))
66 }
67 McpRequest::ReadResource(params) => Some(format!("res:{}", params.uri)),
68 _ => None,
69 }
70}
71
72impl<S> Service<RouterRequest> for CoalesceService<S>
73where
74 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
75 + Clone
76 + Send
77 + 'static,
78 S::Future: Send,
79{
80 type Response = RouterResponse;
81 type Error = Infallible;
82 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
83
84 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
85 self.inner.poll_ready(cx)
86 }
87
88 fn call(&mut self, req: RouterRequest) -> Self::Future {
89 let Some(key) = coalesce_key(&req.inner) else {
90 let fut = self.inner.call(req);
92 return Box::pin(fut);
93 };
94
95 let in_flight = Arc::clone(&self.in_flight);
96 let mut inner = self.inner.clone();
97 let request_id = req.id.clone();
98
99 Box::pin(async move {
100 {
102 let map = in_flight.lock().await;
103 if let Some(tx) = map.get(&key) {
104 let mut rx = tx.subscribe();
105 drop(map);
106 if let Ok(resp) = rx.recv().await {
108 return Ok(RouterResponse {
109 id: request_id,
110 inner: resp.inner,
111 });
112 }
113 }
115 }
116
117 let (tx, _) = broadcast::channel(1);
119 {
120 let mut map = in_flight.lock().await;
121 map.insert(key.clone(), tx.clone());
122 }
123
124 let result = inner.call(req).await;
125
126 let Ok(ref resp) = result;
128 let _ = tx.send(resp.clone());
129 {
130 let mut map = in_flight.lock().await;
131 map.remove(&key);
132 }
133
134 result
135 })
136 }
137}
138
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest, McpResponse};

    use super::{CoalesceService, coalesce_key};
    use crate::test_util::{MockService, call_service};

    /// Builds a `CallTool` request for `name` with the given JSON arguments.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// A lone coalesceable request reaches the wrapped service.
    #[tokio::test]
    async fn test_coalesce_passes_through_single_request() {
        let mut svc = CoalesceService::new(MockService::with_tools(&["fs/read"]));

        let request = call_tool("fs/read", serde_json::json!({}));
        let resp = call_service(&mut svc, request).await;

        let payload = resp.inner.unwrap();
        match payload {
            McpResponse::CallTool(r) => assert_eq!(r.all_text(), "called: fs/read"),
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    /// Requests without a coalescing key are forwarded untouched.
    #[tokio::test]
    async fn test_coalesce_non_coalesceable_passes_through() {
        let mut svc = CoalesceService::new(MockService::with_tools(&["tool"]));

        let request = McpRequest::ListTools(Default::default());
        let resp = call_service(&mut svc, request).await;

        assert!(resp.inner.is_ok(), "list_tools should pass through");
    }

    /// Calls to the same tool with different arguments get different keys.
    #[tokio::test]
    async fn test_coalesce_key_includes_arguments() {
        let key1 = coalesce_key(&call_tool("tool", serde_json::json!({"a": 1})));
        let key2 = coalesce_key(&call_tool("tool", serde_json::json!({"a": 2})));

        assert_ne!(key1, key2, "different args should have different keys");
    }
}