Skip to main content

mcp_proxy/
composite.rs

1//! Composite tool middleware for fan-out to multiple backend tools.
2//!
3//! Composite tools are virtual tools that do not exist on any single backend.
4//! When called, they fan out the request to multiple backend tools concurrently,
5//! aggregating all results into a single response. This is useful for
6//! cross-cutting operations like "search everything" or "health-check all
7//! backends."
8//!
9//! # How it works
10//!
11//! The [`CompositeService`] intercepts two request types:
12//!
13//! - **`ListTools`** -- appends the composite tool definitions to the response
14//!   so clients discover them alongside regular backend tools.
15//! - **`CallTool`** -- if the tool name matches a composite, the same arguments
//!   are forwarded to every target tool concurrently using `tokio::task::JoinSet`.
17//!   Results from all targets are collected into a single `CallToolResult`
18//!   whose `content` is the concatenation of all individual results. If any
19//!   target fails, the aggregated result's `is_error` flag is set to `true`,
20//!   but successful results are still included.
21//!
22//! All other request types pass through unchanged.
23//!
24//! # Strategy
25//!
26//! The `strategy` field controls execution order. Currently one strategy
27//! is supported:
28//!
29//! - **`parallel`** (default) -- all target tools execute concurrently via
//!   `tokio::task::JoinSet`. Results are returned in completion order.
31//!
32//! # Configuration
33//!
34//! Composite tools are defined at the top level in TOML, referencing
35//! namespaced tool names from any backend:
36//!
37//! ```toml
38//! [[composite_tools]]
39//! name = "search_all"
40//! description = "Search across all knowledge sources"
41//! tools = ["github/search", "jira/search", "docs/search"]
42//! strategy = "parallel"
43//! ```
44//!
45//! Validation enforces that composite tool names are non-empty, unique,
46//! and reference at least one target tool.
47//!
48//! # Middleware stack position
49//!
50//! Composite tools are the outermost middleware in the request-processing
51//! stack, applied after aliasing. This means composite tool names are not
52//! subject to alias rewriting, but the target tools they reference are
53//! resolved through the full middleware chain (including aliases, filters,
54//! and validation). The ordering in `proxy.rs`:
55//!
56//! 1. Request validation ([`crate::validation`])
57//! 2. Capability filtering ([`crate::filter`])
58//! 3. Search-mode filtering ([`crate::filter`])
59//! 4. Tool aliasing ([`crate::alias`])
60//! 5. **Composite tools** (this module)
61
62use std::convert::Infallible;
63use std::future::Future;
64use std::pin::Pin;
65use std::sync::Arc;
66use std::task::{Context, Poll};
67
68use tokio::task::JoinSet;
69use tower::{Layer, Service};
70use tower_mcp::protocol::{
71    CallToolParams, CallToolResult, McpRequest, McpResponse, ToolDefinition,
72};
73use tower_mcp::router::{RouterRequest, RouterResponse};
74
75use crate::config::CompositeToolConfig;
76
77/// Tower layer that produces a [`CompositeService`].
78///
79/// # Example
80///
81/// ```rust,ignore
82/// use tower::ServiceBuilder;
83/// use mcp_proxy::composite::CompositeLayer;
84/// use mcp_proxy::config::CompositeToolConfig;
85///
86/// let composites = vec![CompositeToolConfig {
87///     name: "search_all".into(),
88///     description: "Search everything".into(),
89///     tools: vec!["github/search".into(), "docs/search".into()],
90///     strategy: Default::default(),
91/// }];
92///
93/// let service = ServiceBuilder::new()
94///     .layer(CompositeLayer::new(composites))
95///     .service(proxy);
96/// ```
97#[derive(Clone)]
98pub struct CompositeLayer {
99    composites: Arc<Vec<CompositeToolConfig>>,
100}
101
102impl CompositeLayer {
103    /// Create a new composite layer with the given tool definitions.
104    pub fn new(composites: Vec<CompositeToolConfig>) -> Self {
105        Self {
106            composites: Arc::new(composites),
107        }
108    }
109}
110
111impl<S> Layer<S> for CompositeLayer {
112    type Service = CompositeService<S>;
113
114    fn layer(&self, inner: S) -> Self::Service {
115        CompositeService::new(inner, (*self.composites).clone())
116    }
117}
118
119/// Tower service that intercepts `ListTools` and `CallTool` requests
120/// to support composite tool fan-out.
121#[derive(Clone)]
122pub struct CompositeService<S> {
123    inner: S,
124    composites: Arc<Vec<CompositeToolConfig>>,
125}
126
127impl<S> CompositeService<S> {
128    /// Create a new composite service wrapping `inner` with the given composite tool configs.
129    pub fn new(inner: S, composites: Vec<CompositeToolConfig>) -> Self {
130        Self {
131            inner,
132            composites: Arc::new(composites),
133        }
134    }
135}
136
137impl<S> Service<RouterRequest> for CompositeService<S>
138where
139    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
140        + Clone
141        + Send
142        + 'static,
143    S::Future: Send,
144{
145    type Response = RouterResponse;
146    type Error = Infallible;
147    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
148
149    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150        self.inner.poll_ready(cx)
151    }
152
153    fn call(&mut self, req: RouterRequest) -> Self::Future {
154        let composites = Arc::clone(&self.composites);
155
156        // Check if this is a CallTool for a composite tool
157        if let McpRequest::CallTool(ref params) = req.inner
158            && let Some(composite) = composites.iter().find(|c| c.name == params.name)
159        {
160            let id = req.id.clone();
161            let extensions = req.extensions.clone();
162            let tool_names = composite.tools.clone();
163            let arguments = params.arguments.clone();
164            let meta = params.meta.clone();
165            let task = params.task.clone();
166            let inner = self.inner.clone();
167
168            return Box::pin(async move {
169                let mut join_set = JoinSet::new();
170
171                for tool_name in tool_names {
172                    let mut svc = inner.clone();
173                    let tool_req = RouterRequest {
174                        id: id.clone(),
175                        inner: McpRequest::CallTool(CallToolParams {
176                            name: tool_name,
177                            arguments: arguments.clone(),
178                            meta: meta.clone(),
179                            task: task.clone(),
180                        }),
181                        extensions: extensions.clone(),
182                    };
183                    join_set.spawn(async move { svc.call(tool_req).await });
184                }
185
186                let mut all_content = Vec::new();
187                let mut any_error = false;
188
189                while let Some(result) = join_set.join_next().await {
190                    match result {
191                        Ok(Ok(resp)) => match resp.inner {
192                            Ok(McpResponse::CallTool(call_result)) => {
193                                if call_result.is_error {
194                                    any_error = true;
195                                }
196                                all_content.extend(call_result.content);
197                            }
198                            Err(json_rpc_err) => {
199                                any_error = true;
200                                all_content.push(tower_mcp::protocol::Content::text(format!(
201                                    "Error: {}",
202                                    json_rpc_err.message
203                                )));
204                            }
205                            Ok(other) => {
206                                any_error = true;
207                                all_content.push(tower_mcp::protocol::Content::text(format!(
208                                    "Unexpected response type: {:?}",
209                                    other
210                                )));
211                            }
212                        },
213                        Ok(Err(_infallible)) => {
214                            // Infallible error -- cannot happen
215                        }
216                        Err(join_err) => {
217                            any_error = true;
218                            all_content.push(tower_mcp::protocol::Content::text(format!(
219                                "Task failed: {}",
220                                join_err
221                            )));
222                        }
223                    }
224                }
225
226                let result = CallToolResult {
227                    content: all_content,
228                    is_error: any_error,
229                    structured_content: None,
230                    meta: None,
231                };
232
233                Ok(RouterResponse {
234                    id,
235                    inner: Ok(McpResponse::CallTool(result)),
236                })
237            });
238        }
239
240        // For ListTools, append composite tool definitions
241        if matches!(req.inner, McpRequest::ListTools(_)) {
242            let fut = self.inner.call(req);
243
244            return Box::pin(async move {
245                let mut result = fut.await;
246
247                let Ok(ref mut resp) = result;
248                if let Ok(McpResponse::ListTools(ref mut list_result)) = resp.inner {
249                    for composite in composites.iter() {
250                        list_result.tools.push(ToolDefinition {
251                            name: composite.name.clone(),
252                            title: None,
253                            description: Some(composite.description.clone()),
254                            input_schema: serde_json::json!({"type": "object"}),
255                            output_schema: None,
256                            icons: None,
257                            annotations: None,
258                            execution: None,
259                            meta: None,
260                        });
261                    }
262                }
263
264                result
265            });
266        }
267
268        // All other requests pass through unchanged
269        let fut = self.inner.call(req);
270        Box::pin(fut)
271    }
272}
273
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::CompositeService;
    use crate::config::{CompositeStrategy, CompositeToolConfig};
    use crate::test_util::{ErrorMockService, MockService, call_service};

    /// A single `search_all` composite fanning out to two search tools.
    fn test_composites() -> Vec<CompositeToolConfig> {
        let search_all = CompositeToolConfig {
            name: "search_all".to_string(),
            description: "Search across all sources".to_string(),
            tools: vec!["github/search".to_string(), "docs/search".to_string()],
            strategy: CompositeStrategy::Parallel,
        };
        vec![search_all]
    }

    /// Builds a `CallTool` request for `name` with the given arguments and
    /// no meta or task.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_composite_appears_in_list_tools() {
        let backend = MockService::with_tools(&["github/search", "docs/search", "db/query"]);
        let mut svc = CompositeService::new(backend, test_composites());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let has_tool = |wanted: &str| result.tools.iter().any(|t| t.name == wanted);
                assert!(has_tool("github/search"));
                assert!(has_tool("docs/search"));
                assert!(has_tool("db/query"));
                assert!(has_tool("search_all"), "composite tool should appear");

                // The composite carries the configured description.
                let description = result
                    .tools
                    .iter()
                    .find(|t| t.name == "search_all")
                    .unwrap()
                    .description
                    .as_deref();
                assert_eq!(description, Some("Search across all sources"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_composite_fan_out_aggregates_results() {
        let backend = MockService::with_tools(&["github/search", "docs/search"]);
        let mut svc = CompositeService::new(backend, test_composites());

        let request = call_tool("search_all", serde_json::json!({"q": "test"}));
        let resp = call_service(&mut svc, request).await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.content.len(), 2, "should aggregate both results");

                let mut texts = Vec::new();
                for item in &result.content {
                    texts.push(item.as_text().unwrap().to_string());
                }
                // Completion order is not fixed, so only membership is checked.
                assert!(texts.iter().any(|t| t.as_str() == "called: github/search"));
                assert!(texts.iter().any(|t| t.as_str() == "called: docs/search"));
                assert!(!result.is_error, "no errors expected");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_non_composite_call_passes_through() {
        let backend = MockService::with_tools(&["db/query"]);
        let mut svc = CompositeService::new(backend, test_composites());

        let request = call_tool("db/query", serde_json::json!({}));
        let resp = call_service(&mut svc, request).await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: db/query");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_partial_failure_returns_partial_results() {
        // ErrorMockService fails every call, so each target contributes an
        // error entry to the aggregated content.
        let mut svc = CompositeService::new(ErrorMockService, test_composites());

        let request = call_tool("search_all", serde_json::json!({}));
        let resp = call_service(&mut svc, request).await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(
                    result.content.len(),
                    2,
                    "should have error content for both tools"
                );
                assert!(result.is_error, "should be marked as error");
                result.content.iter().for_each(|content| {
                    let text = content.as_text().unwrap();
                    assert!(
                        text.contains("Error:"),
                        "content should describe error: {text}"
                    );
                });
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_non_tool_requests_pass_through() {
        let backend = MockService::with_tools(&[]);
        let mut svc = CompositeService::new(backend, test_composites());

        let resp = call_service(&mut svc, McpRequest::Ping).await;
        match resp.inner.unwrap() {
            McpResponse::Pong(_) => {} // expected
            other => panic!("expected Pong, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_empty_composites_passes_through() {
        let backend = MockService::with_tools(&["tool1"]);
        let mut svc = CompositeService::new(backend, Vec::new());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                // With no composites configured the listing is untouched.
                assert_eq!(result.tools.len(), 1);
                assert_eq!(result.tools[0].name, "tool1");
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }
}