Skip to main content

mcp_proxy/
alias.rs

1//! Tool aliasing middleware for the proxy.
2//!
//! Rewrites tool, resource, and prompt names in list responses and in
//! call/read/get requests based on per-backend alias configuration. This lets
//! operators expose backend tools under different names without modifying the
//! backends themselves.
6//!
7//! # How it works
8//!
9//! Aliasing maintains a bidirectional mapping between original and aliased
10//! names (stored in [`AliasMap`]):
11//!
12//! - **Forward mapping** (original -> alias) -- applied to `ListTools`,
13//!   `ListResources`, and `ListPrompts` responses so clients see the
14//!   aliased names.
15//! - **Reverse mapping** (alias -> original) -- applied to `CallTool`,
16//!   `ReadResource`, and `GetPrompt` requests so the backend receives
17//!   the original name it expects.
18//!
19//! Names that have no alias configured pass through unchanged in both
20//! directions.
21//!
22//! # Configuration
23//!
24//! Aliases are configured per-backend in TOML. The `from` field is the
25//! backend-local tool name (without the namespace prefix); the `to` field
26//! is the new name to expose:
27//!
28//! ```toml
29//! [[backends]]
30//! name = "files"
31//! transport = "stdio"
32//! command = "file-server"
33//!
34//! [[backends.aliases]]
35//! from = "read_file"
36//! to = "read"
37//!
38//! [[backends.aliases]]
39//! from = "write_file"
40//! to = "write"
41//! ```
42//!
43//! With this config, `files/read_file` appears to clients as `files/read`,
44//! and calling `files/read` is transparently forwarded to the backend as
45//! `files/read_file`.
46//!
47//! # Middleware stack position
48//!
//! Aliasing runs after capability filtering and search-mode filtering, so
//! filters operate on original names and aliases are applied only once all
//! filtering is done. The ordering in `proxy.rs`:
52//!
53//! 1. Request validation ([`crate::validation`])
54//! 2. Capability filtering ([`crate::filter`])
55//! 3. Search-mode filtering ([`crate::filter`])
56//! 4. **Tool aliasing** (this module)
57//! 5. Composite tools ([`crate::composite`])
58
59use std::collections::HashMap;
60use std::convert::Infallible;
61use std::future::Future;
62use std::pin::Pin;
63use std::sync::Arc;
64use std::task::{Context, Poll};
65
66use tower::{Layer, Service};
67use tower_mcp::router::{RouterRequest, RouterResponse};
68use tower_mcp_types::protocol::{McpRequest, McpResponse};
69
70/// Tower layer that produces an [`AliasService`].
71///
72/// # Example
73///
74/// ```rust,ignore
75/// use tower::ServiceBuilder;
76/// use mcp_proxy::alias::{AliasLayer, AliasMap};
77///
78/// let aliases = AliasMap::new(vec![
79///     ("math/".into(), "add".into(), "sum".into()),
80/// ]).unwrap();
81///
82/// let service = ServiceBuilder::new()
83///     .layer(AliasLayer::new(aliases))
84///     .service(proxy);
85/// ```
86#[derive(Clone)]
87pub struct AliasLayer {
88    aliases: AliasMap,
89}
90
91impl AliasLayer {
92    /// Create a new alias layer with the given alias map.
93    pub fn new(aliases: AliasMap) -> Self {
94        Self { aliases }
95    }
96}
97
98impl<S> Layer<S> for AliasLayer {
99    type Service = AliasService<S>;
100
101    fn layer(&self, inner: S) -> Self::Service {
102        AliasService::new(inner, self.aliases.clone())
103    }
104}
105
/// Resolved alias mappings for all backends.
///
/// Keys and values are fully qualified names: the backend namespace prefix
/// concatenated directly with the backend-local name (e.g. `"files/"` +
/// `"read_file"` -> `"files/read_file"`).
#[derive(Clone, Debug)]
pub struct AliasMap {
    /// Maps "namespace/original" -> "namespace/alias" (for list responses)
    pub forward: HashMap<String, String>,
    /// Maps "namespace/alias" -> "namespace/original" (for call requests)
    reverse: HashMap<String, String>,
}

impl AliasMap {
    /// Build an alias map from `(namespace, from, to)` triples.
    ///
    /// `namespace` is prepended verbatim to both `from` and `to`, so it must
    /// already include its separator (e.g. `"files/"`, not `"files"`).
    ///
    /// Returns `None` if `mappings` is empty.
    ///
    /// Duplicate entries follow `HashMap::insert` semantics (later entries
    /// win):
    /// - If the same original appears twice, list responses show the later
    ///   alias, but the earlier alias still resolves for requests.
    /// - If two originals share one alias, requests for that alias reach the
    ///   later original.
    pub fn new(mappings: Vec<(String, String, String)>) -> Option<Self> {
        if mappings.is_empty() {
            return None;
        }
        // Each triple contributes exactly one entry per direction (before dedup).
        let mut forward = HashMap::with_capacity(mappings.len());
        let mut reverse = HashMap::with_capacity(mappings.len());
        for (namespace, from, to) in mappings {
            let original = format!("{namespace}{from}");
            let aliased = format!("{namespace}{to}");
            forward.insert(original.clone(), aliased.clone());
            reverse.insert(aliased, original);
        }
        Some(Self { forward, reverse })
    }
}
132
133/// Tower service that rewrites tool names based on alias configuration.
134#[derive(Clone)]
135pub struct AliasService<S> {
136    inner: S,
137    aliases: Arc<AliasMap>,
138}
139
140impl<S> AliasService<S> {
141    /// Create a new alias service wrapping `inner` with the given alias map.
142    pub fn new(inner: S, aliases: AliasMap) -> Self {
143        Self {
144            inner,
145            aliases: Arc::new(aliases),
146        }
147    }
148}
149
150impl<S> Service<RouterRequest> for AliasService<S>
151where
152    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
153        + Clone
154        + Send
155        + 'static,
156    S::Future: Send,
157{
158    type Response = RouterResponse;
159    type Error = Infallible;
160    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
161
162    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
163        self.inner.poll_ready(cx)
164    }
165
166    fn call(&mut self, mut req: RouterRequest) -> Self::Future {
167        let aliases = Arc::clone(&self.aliases);
168
169        // Reverse-map aliased names back to originals in requests
170        match &mut req.inner {
171            McpRequest::CallTool(params) => {
172                if let Some(original) = aliases.reverse.get(&params.name) {
173                    params.name = original.clone();
174                }
175            }
176            McpRequest::ReadResource(params) => {
177                if let Some(original) = aliases.reverse.get(&params.uri) {
178                    params.uri = original.clone();
179                }
180            }
181            McpRequest::GetPrompt(params) => {
182                if let Some(original) = aliases.reverse.get(&params.name) {
183                    params.name = original.clone();
184                }
185            }
186            _ => {}
187        }
188
189        let fut = self.inner.call(req);
190
191        Box::pin(async move {
192            let mut result = fut.await;
193
194            // Forward-map original names to aliases in responses
195            let Ok(ref mut resp) = result;
196            if let Ok(mcp_resp) = &mut resp.inner {
197                match mcp_resp {
198                    McpResponse::ListTools(r) => {
199                        for tool in &mut r.tools {
200                            if let Some(aliased) = aliases.forward.get(&tool.name) {
201                                tool.name = aliased.clone();
202                            }
203                        }
204                    }
205                    McpResponse::ListResources(r) => {
206                        for res in &mut r.resources {
207                            if let Some(aliased) = aliases.forward.get(&res.uri) {
208                                res.uri = aliased.clone();
209                            }
210                        }
211                    }
212                    McpResponse::ListPrompts(r) => {
213                        for prompt in &mut r.prompts {
214                            if let Some(aliased) = aliases.forward.get(&prompt.name) {
215                                prompt.name = aliased.clone();
216                            }
217                        }
218                    }
219                    _ => {}
220                }
221            }
222
223            result
224        })
225    }
226}
227
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::{AliasMap, AliasService};
    use crate::test_util::{MockService, call_service};

    /// Shared fixture: aliases `files/read_file` -> `files/read` and
    /// `files/write_file` -> `files/write`.
    fn test_aliases() -> AliasMap {
        AliasMap::new(vec![
            ("files/".into(), "read_file".into(), "read".into()),
            ("files/".into(), "write_file".into(), "write".into()),
        ])
        .unwrap()
    }

    #[test]
    fn test_alias_map_empty_returns_none() {
        assert!(AliasMap::new(vec![]).is_none());
    }

    #[test]
    fn test_alias_map_forward_and_reverse() {
        let aliases = test_aliases();

        // Forward: original -> alias (applied to list responses).
        assert_eq!(
            aliases.forward.get("files/read_file").unwrap(),
            "files/read"
        );
        assert_eq!(
            aliases.forward.get("files/write_file").unwrap(),
            "files/write"
        );
        assert_eq!(aliases.forward.len(), 2);

        // Reverse: alias -> original (applied to requests). `reverse` is
        // private to the parent module but visible from this child module.
        assert_eq!(
            aliases.reverse.get("files/read").unwrap(),
            "files/read_file"
        );
        assert_eq!(
            aliases.reverse.get("files/write").unwrap(),
            "files/write_file"
        );
        assert_eq!(aliases.reverse.len(), 2);
    }

    #[tokio::test]
    async fn test_alias_rewrites_list_tools() {
        let mock = MockService::with_tools(&["files/read_file", "files/write_file", "db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"files/read"));
                assert!(names.contains(&"files/write"));
                assert!(names.contains(&"db/query")); // unchanged
                // Originals must be replaced, not listed alongside their aliases.
                assert!(!names.contains(&"files/read_file"));
                assert!(!names.contains(&"files/write_file"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_reverse_maps_call_tool() {
        let mock = MockService::with_tools(&["files/read_file"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "files/read".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                // The backend must see the original name, not the alias.
                assert_eq!(result.all_text(), "called: files/read_file");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_passthrough_non_aliased() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "db/query".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: db/query");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }
}