Skip to main content

mcp_proxy/
composite.rs

//! Composite tool middleware for fan-out to multiple backend tools.
//!
//! Adds config-defined composite tools that appear in `ListTools` responses
//! and, when called, fan out the request to multiple backend tools concurrently
//! using [`tokio::task::JoinSet`], aggregating all results into a single
//! response. If any backend call fails, the aggregated result is marked with
//! `is_error`.
6
7use std::convert::Infallible;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use tokio::task::JoinSet;
14use tower::{Layer, Service};
15use tower_mcp::protocol::{
16    CallToolParams, CallToolResult, McpRequest, McpResponse, ToolDefinition,
17};
18use tower_mcp::router::{RouterRequest, RouterResponse};
19
20use crate::config::CompositeToolConfig;
21
22/// Tower layer that produces a [`CompositeService`].
23///
24/// # Example
25///
26/// ```rust,ignore
27/// use tower::ServiceBuilder;
28/// use mcp_proxy::composite::CompositeLayer;
29/// use mcp_proxy::config::CompositeToolConfig;
30///
31/// let composites = vec![CompositeToolConfig {
32///     name: "search_all".into(),
33///     description: "Search everything".into(),
34///     tools: vec!["github/search".into(), "docs/search".into()],
35///     strategy: Default::default(),
36/// }];
37///
38/// let service = ServiceBuilder::new()
39///     .layer(CompositeLayer::new(composites))
40///     .service(proxy);
41/// ```
42#[derive(Clone)]
43pub struct CompositeLayer {
44    composites: Arc<Vec<CompositeToolConfig>>,
45}
46
47impl CompositeLayer {
48    /// Create a new composite layer with the given tool definitions.
49    pub fn new(composites: Vec<CompositeToolConfig>) -> Self {
50        Self {
51            composites: Arc::new(composites),
52        }
53    }
54}
55
56impl<S> Layer<S> for CompositeLayer {
57    type Service = CompositeService<S>;
58
59    fn layer(&self, inner: S) -> Self::Service {
60        CompositeService::new(inner, (*self.composites).clone())
61    }
62}
63
64/// Tower service that intercepts `ListTools` and `CallTool` requests
65/// to support composite tool fan-out.
66#[derive(Clone)]
67pub struct CompositeService<S> {
68    inner: S,
69    composites: Arc<Vec<CompositeToolConfig>>,
70}
71
72impl<S> CompositeService<S> {
73    /// Create a new composite service wrapping `inner` with the given composite tool configs.
74    pub fn new(inner: S, composites: Vec<CompositeToolConfig>) -> Self {
75        Self {
76            inner,
77            composites: Arc::new(composites),
78        }
79    }
80}
81
82impl<S> Service<RouterRequest> for CompositeService<S>
83where
84    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
85        + Clone
86        + Send
87        + 'static,
88    S::Future: Send,
89{
90    type Response = RouterResponse;
91    type Error = Infallible;
92    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
93
94    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95        self.inner.poll_ready(cx)
96    }
97
98    fn call(&mut self, req: RouterRequest) -> Self::Future {
99        let composites = Arc::clone(&self.composites);
100
101        // Check if this is a CallTool for a composite tool
102        if let McpRequest::CallTool(ref params) = req.inner
103            && let Some(composite) = composites.iter().find(|c| c.name == params.name)
104        {
105            let id = req.id.clone();
106            let extensions = req.extensions.clone();
107            let tool_names = composite.tools.clone();
108            let arguments = params.arguments.clone();
109            let meta = params.meta.clone();
110            let task = params.task.clone();
111            let inner = self.inner.clone();
112
113            return Box::pin(async move {
114                let mut join_set = JoinSet::new();
115
116                for tool_name in tool_names {
117                    let mut svc = inner.clone();
118                    let tool_req = RouterRequest {
119                        id: id.clone(),
120                        inner: McpRequest::CallTool(CallToolParams {
121                            name: tool_name,
122                            arguments: arguments.clone(),
123                            meta: meta.clone(),
124                            task: task.clone(),
125                        }),
126                        extensions: extensions.clone(),
127                    };
128                    join_set.spawn(async move { svc.call(tool_req).await });
129                }
130
131                let mut all_content = Vec::new();
132                let mut any_error = false;
133
134                while let Some(result) = join_set.join_next().await {
135                    match result {
136                        Ok(Ok(resp)) => match resp.inner {
137                            Ok(McpResponse::CallTool(call_result)) => {
138                                if call_result.is_error {
139                                    any_error = true;
140                                }
141                                all_content.extend(call_result.content);
142                            }
143                            Err(json_rpc_err) => {
144                                any_error = true;
145                                all_content.push(tower_mcp::protocol::Content::text(format!(
146                                    "Error: {}",
147                                    json_rpc_err.message
148                                )));
149                            }
150                            Ok(other) => {
151                                any_error = true;
152                                all_content.push(tower_mcp::protocol::Content::text(format!(
153                                    "Unexpected response type: {:?}",
154                                    other
155                                )));
156                            }
157                        },
158                        Ok(Err(_infallible)) => {
159                            // Infallible error -- cannot happen
160                        }
161                        Err(join_err) => {
162                            any_error = true;
163                            all_content.push(tower_mcp::protocol::Content::text(format!(
164                                "Task failed: {}",
165                                join_err
166                            )));
167                        }
168                    }
169                }
170
171                let result = CallToolResult {
172                    content: all_content,
173                    is_error: any_error,
174                    structured_content: None,
175                    meta: None,
176                };
177
178                Ok(RouterResponse {
179                    id,
180                    inner: Ok(McpResponse::CallTool(result)),
181                })
182            });
183        }
184
185        // For ListTools, append composite tool definitions
186        if matches!(req.inner, McpRequest::ListTools(_)) {
187            let fut = self.inner.call(req);
188
189            return Box::pin(async move {
190                let mut result = fut.await;
191
192                let Ok(ref mut resp) = result;
193                if let Ok(McpResponse::ListTools(ref mut list_result)) = resp.inner {
194                    for composite in composites.iter() {
195                        list_result.tools.push(ToolDefinition {
196                            name: composite.name.clone(),
197                            title: None,
198                            description: Some(composite.description.clone()),
199                            input_schema: serde_json::json!({"type": "object"}),
200                            output_schema: None,
201                            icons: None,
202                            annotations: None,
203                            execution: None,
204                            meta: None,
205                        });
206                    }
207                }
208
209                result
210            });
211        }
212
213        // All other requests pass through unchanged
214        let fut = self.inner.call(req);
215        Box::pin(fut)
216    }
217}
218
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest, McpResponse};

    use super::CompositeService;
    use crate::config::{CompositeStrategy, CompositeToolConfig};
    use crate::test_util::{ErrorMockService, MockService, call_service};

    /// One parallel composite, `search_all`, spanning two backend search tools.
    fn test_composites() -> Vec<CompositeToolConfig> {
        let search_all = CompositeToolConfig {
            name: "search_all".to_string(),
            description: "Search across all sources".to_string(),
            tools: ["github/search", "docs/search"]
                .into_iter()
                .map(String::from)
                .collect(),
            strategy: CompositeStrategy::Parallel,
        };
        vec![search_all]
    }

    /// Shorthand for a `CallTool` request with no `meta` and no `task`.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_composite_appears_in_list_tools() {
        let backend = MockService::with_tools(&["github/search", "docs/search", "db/query"]);
        let mut svc = CompositeService::new(backend, test_composites());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        let listed = match resp.inner.unwrap() {
            McpResponse::ListTools(listed) => listed,
            other => panic!("expected ListTools, got: {:?}", other),
        };

        let names: Vec<&str> = listed.tools.iter().map(|t| t.name.as_str()).collect();
        for backend_tool in ["github/search", "docs/search", "db/query"] {
            assert!(names.contains(&backend_tool));
        }
        assert!(
            names.contains(&"search_all"),
            "composite tool should appear"
        );

        // The composite's configured description is advertised verbatim.
        let composite_tool = listed
            .tools
            .iter()
            .find(|t| t.name == "search_all")
            .unwrap();
        assert_eq!(
            composite_tool.description.as_deref(),
            Some("Search across all sources")
        );
    }

    #[tokio::test]
    async fn test_composite_fan_out_aggregates_results() {
        let backend = MockService::with_tools(&["github/search", "docs/search"]);
        let mut svc = CompositeService::new(backend, test_composites());

        let request = call_tool("search_all", serde_json::json!({"q": "test"}));
        let resp = call_service(&mut svc, request).await;

        let result = match resp.inner.unwrap() {
            McpResponse::CallTool(result) => result,
            other => panic!("expected CallTool, got: {:?}", other),
        };
        assert_eq!(result.content.len(), 2, "should aggregate both results");

        let texts: Vec<String> = result
            .content
            .iter()
            .map(|c| c.as_text().unwrap().to_string())
            .collect();
        for expected in ["called: github/search", "called: docs/search"] {
            assert!(texts.contains(&expected.to_string()));
        }
        assert!(!result.is_error, "no errors expected");
    }

    #[tokio::test]
    async fn test_non_composite_call_passes_through() {
        let backend = MockService::with_tools(&["db/query"]);
        let mut svc = CompositeService::new(backend, test_composites());

        let resp = call_service(&mut svc, call_tool("db/query", serde_json::json!({}))).await;
        let result = match resp.inner.unwrap() {
            McpResponse::CallTool(result) => result,
            other => panic!("expected CallTool, got: {:?}", other),
        };
        assert_eq!(result.all_text(), "called: db/query");
    }

    #[tokio::test]
    async fn test_partial_failure_returns_partial_results() {
        // ErrorMockService fails every call, so each backend contributes error content.
        let mut svc = CompositeService::new(ErrorMockService, test_composites());

        let resp = call_service(&mut svc, call_tool("search_all", serde_json::json!({}))).await;
        let result = match resp.inner.unwrap() {
            McpResponse::CallTool(result) => result,
            other => panic!("expected CallTool, got: {:?}", other),
        };
        assert_eq!(
            result.content.len(),
            2,
            "should have error content for both tools"
        );
        assert!(result.is_error, "should be marked as error");
        result
            .content
            .iter()
            .map(|c| c.as_text().unwrap())
            .for_each(|text| {
                assert!(
                    text.contains("Error:"),
                    "content should describe error: {text}"
                );
            });
    }

    #[tokio::test]
    async fn test_non_tool_requests_pass_through() {
        let mut svc = CompositeService::new(MockService::with_tools(&[]), test_composites());

        let resp = call_service(&mut svc, McpRequest::Ping).await;
        match resp.inner.unwrap() {
            McpResponse::Pong(_) => {}
            other => panic!("expected Pong, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_empty_composites_passes_through() {
        let mut svc = CompositeService::new(MockService::with_tools(&["tool1"]), vec![]);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        let listed = match resp.inner.unwrap() {
            McpResponse::ListTools(listed) => listed,
            other => panic!("expected ListTools, got: {:?}", other),
        };
        // With no composites configured, the backend listing is returned as-is.
        assert_eq!(listed.tools.len(), 1);
        assert_eq!(listed.tools[0].name, "tool1");
    }
}