1use std::convert::Infallible;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use tokio::task::JoinSet;
14use tower::{Layer, Service};
15use tower_mcp::protocol::{
16 CallToolParams, CallToolResult, McpRequest, McpResponse, ToolDefinition,
17};
18use tower_mcp::router::{RouterRequest, RouterResponse};
19
20use crate::config::CompositeToolConfig;
21
/// Tower [`Layer`] that wraps a service in a [`CompositeService`].
///
/// The composite tool configuration is held behind an `Arc`, so cloning the layer
/// shares the configuration instead of copying it.
#[derive(Clone)]
pub struct CompositeLayer {
    /// Composite tool definitions handed to every service this layer wraps.
    composites: Arc<Vec<CompositeToolConfig>>,
}
46
47impl CompositeLayer {
48 pub fn new(composites: Vec<CompositeToolConfig>) -> Self {
50 Self {
51 composites: Arc::new(composites),
52 }
53 }
54}
55
56impl<S> Layer<S> for CompositeLayer {
57 type Service = CompositeService<S>;
58
59 fn layer(&self, inner: S) -> Self::Service {
60 CompositeService::new(inner, (*self.composites).clone())
61 }
62}
63
/// Service that layers composite tools on top of an inner MCP router service.
///
/// A composite tool is a named tool that, when called, invokes each of its member
/// tools on the inner service and merges their results into a single response.
/// Composite tools are also appended to the inner service's `ListTools` response.
/// All other requests are forwarded to the inner service unchanged.
#[derive(Clone)]
pub struct CompositeService<S> {
    /// Wrapped service; handles member tool calls and all non-composite requests.
    inner: S,
    /// Composite tool definitions, shared (not copied) between clones of this service.
    composites: Arc<Vec<CompositeToolConfig>>,
}
71
72impl<S> CompositeService<S> {
73 pub fn new(inner: S, composites: Vec<CompositeToolConfig>) -> Self {
75 Self {
76 inner,
77 composites: Arc::new(composites),
78 }
79 }
80}
81
82impl<S> Service<RouterRequest> for CompositeService<S>
83where
84 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
85 + Clone
86 + Send
87 + 'static,
88 S::Future: Send,
89{
90 type Response = RouterResponse;
91 type Error = Infallible;
92 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
93
94 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95 self.inner.poll_ready(cx)
96 }
97
98 fn call(&mut self, req: RouterRequest) -> Self::Future {
99 let composites = Arc::clone(&self.composites);
100
101 if let McpRequest::CallTool(ref params) = req.inner
103 && let Some(composite) = composites.iter().find(|c| c.name == params.name)
104 {
105 let id = req.id.clone();
106 let extensions = req.extensions.clone();
107 let tool_names = composite.tools.clone();
108 let arguments = params.arguments.clone();
109 let meta = params.meta.clone();
110 let task = params.task.clone();
111 let inner = self.inner.clone();
112
113 return Box::pin(async move {
114 let mut join_set = JoinSet::new();
115
116 for tool_name in tool_names {
117 let mut svc = inner.clone();
118 let tool_req = RouterRequest {
119 id: id.clone(),
120 inner: McpRequest::CallTool(CallToolParams {
121 name: tool_name,
122 arguments: arguments.clone(),
123 meta: meta.clone(),
124 task: task.clone(),
125 }),
126 extensions: extensions.clone(),
127 };
128 join_set.spawn(async move { svc.call(tool_req).await });
129 }
130
131 let mut all_content = Vec::new();
132 let mut any_error = false;
133
134 while let Some(result) = join_set.join_next().await {
135 match result {
136 Ok(Ok(resp)) => match resp.inner {
137 Ok(McpResponse::CallTool(call_result)) => {
138 if call_result.is_error {
139 any_error = true;
140 }
141 all_content.extend(call_result.content);
142 }
143 Err(json_rpc_err) => {
144 any_error = true;
145 all_content.push(tower_mcp::protocol::Content::text(format!(
146 "Error: {}",
147 json_rpc_err.message
148 )));
149 }
150 Ok(other) => {
151 any_error = true;
152 all_content.push(tower_mcp::protocol::Content::text(format!(
153 "Unexpected response type: {:?}",
154 other
155 )));
156 }
157 },
158 Ok(Err(_infallible)) => {
159 }
161 Err(join_err) => {
162 any_error = true;
163 all_content.push(tower_mcp::protocol::Content::text(format!(
164 "Task failed: {}",
165 join_err
166 )));
167 }
168 }
169 }
170
171 let result = CallToolResult {
172 content: all_content,
173 is_error: any_error,
174 structured_content: None,
175 meta: None,
176 };
177
178 Ok(RouterResponse {
179 id,
180 inner: Ok(McpResponse::CallTool(result)),
181 })
182 });
183 }
184
185 if matches!(req.inner, McpRequest::ListTools(_)) {
187 let fut = self.inner.call(req);
188
189 return Box::pin(async move {
190 let mut result = fut.await;
191
192 let Ok(ref mut resp) = result;
193 if let Ok(McpResponse::ListTools(ref mut list_result)) = resp.inner {
194 for composite in composites.iter() {
195 list_result.tools.push(ToolDefinition {
196 name: composite.name.clone(),
197 title: None,
198 description: Some(composite.description.clone()),
199 input_schema: serde_json::json!({"type": "object"}),
200 output_schema: None,
201 icons: None,
202 annotations: None,
203 execution: None,
204 meta: None,
205 });
206 }
207 }
208
209 result
210 });
211 }
212
213 let fut = self.inner.call(req);
215 Box::pin(fut)
216 }
217}
218
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::CompositeService;
    use crate::config::{CompositeStrategy, CompositeToolConfig};
    use crate::test_util::{ErrorMockService, MockService, call_service};

    /// One parallel composite, `search_all`, that fans out to two member tools.
    fn test_composites() -> Vec<CompositeToolConfig> {
        vec![CompositeToolConfig {
            name: "search_all".to_string(),
            description: "Search across all sources".to_string(),
            tools: vec!["github/search".to_string(), "docs/search".to_string()],
            strategy: CompositeStrategy::Parallel,
        }]
    }

    /// ListTools keeps every inner tool and adds the composite with its configured
    /// description.
    #[tokio::test]
    async fn test_composite_appears_in_list_tools() {
        let mock = MockService::with_tools(&["github/search", "docs/search", "db/query"]);
        let mut svc = CompositeService::new(mock, test_composites());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"github/search"));
                assert!(names.contains(&"docs/search"));
                assert!(names.contains(&"db/query"));
                assert!(
                    names.contains(&"search_all"),
                    "composite tool should appear"
                );
                // The composite entry carries the description from its config.
                let composite_tool = result
                    .tools
                    .iter()
                    .find(|t| t.name == "search_all")
                    .unwrap();
                assert_eq!(
                    composite_tool.description.as_deref(),
                    Some("Search across all sources")
                );
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    /// Calling the composite invokes every member tool and merges their content.
    /// The assertions are order-agnostic (`contains`).
    #[tokio::test]
    async fn test_composite_fan_out_aggregates_results() {
        let mock = MockService::with_tools(&["github/search", "docs/search"]);
        let mut svc = CompositeService::new(mock, test_composites());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "search_all".to_string(),
                arguments: serde_json::json!({"q": "test"}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.content.len(), 2, "should aggregate both results");
                let texts: Vec<String> = result
                    .content
                    .iter()
                    .map(|c| c.as_text().unwrap().to_string())
                    .collect();
                assert!(texts.contains(&"called: github/search".to_string()));
                assert!(texts.contains(&"called: docs/search".to_string()));
                assert!(!result.is_error, "no errors expected");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    /// A CallTool whose name is not a composite goes straight to the inner service.
    #[tokio::test]
    async fn test_non_composite_call_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut svc = CompositeService::new(mock, test_composites());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "db/query".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: db/query");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    /// Member errors are turned into "Error: ..." content, and the composite is
    /// flagged `is_error` instead of failing the whole request.
    ///
    /// NOTE(review): every member fails here (two "Error:" entries are expected),
    /// so this actually exercises total rather than partial failure; consider
    /// renaming or adding a mixed success/failure case.
    #[tokio::test]
    async fn test_partial_failure_returns_partial_results() {
        let mock = ErrorMockService;
        let mut svc = CompositeService::new(mock, test_composites());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "search_all".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(
                    result.content.len(),
                    2,
                    "should have error content for both tools"
                );
                assert!(result.is_error, "should be marked as error");
                for content in &result.content {
                    let text = content.as_text().unwrap();
                    assert!(
                        text.contains("Error:"),
                        "content should describe error: {text}"
                    );
                }
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    /// Requests other than CallTool and ListTools are forwarded untouched.
    #[tokio::test]
    async fn test_non_tool_requests_pass_through() {
        let mock = MockService::with_tools(&[]);
        let mut svc = CompositeService::new(mock, test_composites());

        let resp = call_service(&mut svc, McpRequest::Ping).await;
        match resp.inner.unwrap() {
            McpResponse::Pong(_) => {}
            other => panic!("expected Pong, got: {:?}", other),
        }
    }

    /// With no composites configured, ListTools returns exactly the inner tools.
    #[tokio::test]
    async fn test_empty_composites_passes_through() {
        let mock = MockService::with_tools(&["tool1"]);
        let mut svc = CompositeService::new(mock, vec![]);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                assert_eq!(result.tools.len(), 1);
                assert_eq!(result.tools[0].name, "tool1");
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }
}