1use std::convert::Infallible;
63use std::future::Future;
64use std::pin::Pin;
65use std::sync::Arc;
66use std::task::{Context, Poll};
67
68use tokio::task::JoinSet;
69use tower::{Layer, Service};
70use tower_mcp::protocol::{
71 CallToolParams, CallToolResult, McpRequest, McpResponse, ToolDefinition,
72};
73use tower_mcp::router::{RouterRequest, RouterResponse};
74
75use crate::config::CompositeToolConfig;
76
77#[derive(Clone)]
98pub struct CompositeLayer {
99 composites: Arc<Vec<CompositeToolConfig>>,
100}
101
102impl CompositeLayer {
103 pub fn new(composites: Vec<CompositeToolConfig>) -> Self {
105 Self {
106 composites: Arc::new(composites),
107 }
108 }
109}
110
111impl<S> Layer<S> for CompositeLayer {
112 type Service = CompositeService<S>;
113
114 fn layer(&self, inner: S) -> Self::Service {
115 CompositeService::new(inner, (*self.composites).clone())
116 }
117}
118
119#[derive(Clone)]
122pub struct CompositeService<S> {
123 inner: S,
124 composites: Arc<Vec<CompositeToolConfig>>,
125}
126
127impl<S> CompositeService<S> {
128 pub fn new(inner: S, composites: Vec<CompositeToolConfig>) -> Self {
130 Self {
131 inner,
132 composites: Arc::new(composites),
133 }
134 }
135}
136
137impl<S> Service<RouterRequest> for CompositeService<S>
138where
139 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
140 + Clone
141 + Send
142 + 'static,
143 S::Future: Send,
144{
145 type Response = RouterResponse;
146 type Error = Infallible;
147 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
148
149 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150 self.inner.poll_ready(cx)
151 }
152
153 fn call(&mut self, req: RouterRequest) -> Self::Future {
154 let composites = Arc::clone(&self.composites);
155
156 if let McpRequest::CallTool(ref params) = req.inner
158 && let Some(composite) = composites.iter().find(|c| c.name == params.name)
159 {
160 let id = req.id.clone();
161 let extensions = req.extensions.clone();
162 let tool_names = composite.tools.clone();
163 let arguments = params.arguments.clone();
164 let meta = params.meta.clone();
165 let task = params.task.clone();
166 let inner = self.inner.clone();
167
168 return Box::pin(async move {
169 let mut join_set = JoinSet::new();
170
171 for tool_name in tool_names {
172 let mut svc = inner.clone();
173 let tool_req = RouterRequest {
174 id: id.clone(),
175 inner: McpRequest::CallTool(CallToolParams {
176 name: tool_name,
177 arguments: arguments.clone(),
178 meta: meta.clone(),
179 task: task.clone(),
180 }),
181 extensions: extensions.clone(),
182 };
183 join_set.spawn(async move { svc.call(tool_req).await });
184 }
185
186 let mut all_content = Vec::new();
187 let mut any_error = false;
188
189 while let Some(result) = join_set.join_next().await {
190 match result {
191 Ok(Ok(resp)) => match resp.inner {
192 Ok(McpResponse::CallTool(call_result)) => {
193 if call_result.is_error {
194 any_error = true;
195 }
196 all_content.extend(call_result.content);
197 }
198 Err(json_rpc_err) => {
199 any_error = true;
200 all_content.push(tower_mcp::protocol::Content::text(format!(
201 "Error: {}",
202 json_rpc_err.message
203 )));
204 }
205 Ok(other) => {
206 any_error = true;
207 all_content.push(tower_mcp::protocol::Content::text(format!(
208 "Unexpected response type: {:?}",
209 other
210 )));
211 }
212 },
213 Ok(Err(_infallible)) => {
214 }
216 Err(join_err) => {
217 any_error = true;
218 all_content.push(tower_mcp::protocol::Content::text(format!(
219 "Task failed: {}",
220 join_err
221 )));
222 }
223 }
224 }
225
226 let result = CallToolResult {
227 content: all_content,
228 is_error: any_error,
229 structured_content: None,
230 meta: None,
231 };
232
233 Ok(RouterResponse {
234 id,
235 inner: Ok(McpResponse::CallTool(result)),
236 })
237 });
238 }
239
240 if matches!(req.inner, McpRequest::ListTools(_)) {
242 let fut = self.inner.call(req);
243
244 return Box::pin(async move {
245 let mut result = fut.await;
246
247 let Ok(ref mut resp) = result;
248 if let Ok(McpResponse::ListTools(ref mut list_result)) = resp.inner {
249 for composite in composites.iter() {
250 list_result.tools.push(ToolDefinition {
251 name: composite.name.clone(),
252 title: None,
253 description: Some(composite.description.clone()),
254 input_schema: serde_json::json!({"type": "object"}),
255 output_schema: None,
256 icons: None,
257 annotations: None,
258 execution: None,
259 meta: None,
260 });
261 }
262 }
263
264 result
265 });
266 }
267
268 let fut = self.inner.call(req);
270 Box::pin(fut)
271 }
272}
273
#[cfg(test)]
mod tests {
    use tower::Layer;
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::{CompositeLayer, CompositeService};
    use crate::config::{CompositeStrategy, CompositeToolConfig};
    use crate::test_util::{ErrorMockService, MockService, call_service};

    fn test_composites() -> Vec<CompositeToolConfig> {
        vec![CompositeToolConfig {
            name: "search_all".to_string(),
            description: "Search across all sources".to_string(),
            tools: vec!["github/search".to_string(), "docs/search".to_string()],
            strategy: CompositeStrategy::Parallel,
        }]
    }

    #[tokio::test]
    async fn test_composite_appears_in_list_tools() {
        let mock = MockService::with_tools(&["github/search", "docs/search", "db/query"]);
        let mut svc = CompositeService::new(mock, test_composites());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"github/search"));
                assert!(names.contains(&"docs/search"));
                assert!(names.contains(&"db/query"));
                assert!(
                    names.contains(&"search_all"),
                    "composite tool should appear"
                );
                let composite_tool = result
                    .tools
                    .iter()
                    .find(|t| t.name == "search_all")
                    .unwrap();
                assert_eq!(
                    composite_tool.description.as_deref(),
                    Some("Search across all sources")
                );
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_composite_fan_out_aggregates_results() {
        let mock = MockService::with_tools(&["github/search", "docs/search"]);
        let mut svc = CompositeService::new(mock, test_composites());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "search_all".to_string(),
                arguments: serde_json::json!({"q": "test"}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.content.len(), 2, "should aggregate both results");
                let texts: Vec<String> = result
                    .content
                    .iter()
                    .map(|c| c.as_text().unwrap().to_string())
                    .collect();
                assert!(texts.contains(&"called: github/search".to_string()));
                assert!(texts.contains(&"called: docs/search".to_string()));
                assert!(!result.is_error, "no errors expected");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_composite_with_no_tools_returns_empty_result() {
        let composites = vec![CompositeToolConfig {
            name: "noop".to_string(),
            description: "No member tools".to_string(),
            tools: vec![],
            strategy: CompositeStrategy::Parallel,
        }];
        let mut svc = CompositeService::new(MockService::with_tools(&[]), composites);

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "noop".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert!(result.content.is_empty(), "no members means no content");
                assert!(!result.is_error, "an empty fan-out is not an error");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_non_composite_call_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut svc = CompositeService::new(mock, test_composites());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "db/query".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: db/query");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_partial_failure_returns_partial_results() {
        let mock = ErrorMockService;
        let mut svc = CompositeService::new(mock, test_composites());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "search_all".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(
                    result.content.len(),
                    2,
                    "should have error content for both tools"
                );
                assert!(result.is_error, "should be marked as error");
                for content in &result.content {
                    let text = content.as_text().unwrap();
                    assert!(
                        text.contains("Error:"),
                        "content should describe error: {text}"
                    );
                }
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_non_tool_requests_pass_through() {
        let mock = MockService::with_tools(&[]);
        let mut svc = CompositeService::new(mock, test_composites());

        let resp = call_service(&mut svc, McpRequest::Ping).await;
        match resp.inner.unwrap() {
            McpResponse::Pong(_) => {}
            other => panic!("expected Pong, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_empty_composites_passes_through() {
        let mock = MockService::with_tools(&["tool1"]);
        let mut svc = CompositeService::new(mock, vec![]);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                assert_eq!(result.tools.len(), 1);
                assert_eq!(result.tools[0].name, "tool1");
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_layer_wraps_service_with_composites() {
        let layer = CompositeLayer::new(test_composites());
        let mut svc = layer.layer(MockService::with_tools(&["github/search"]));

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"github/search"));
                assert!(
                    names.contains(&"search_all"),
                    "layer-built service should expose the composite"
                );
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }
}