Skip to main content

mcp_proxy/
alias.rs

1//! Tool aliasing middleware for the proxy.
2//!
//! Rewrites tool names, prompt names, and resource URIs in list responses and
//! in call/get/read requests, based on per-backend alias configuration.
5
6use std::collections::HashMap;
7use std::convert::Infallible;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use tower::{Layer, Service};
14use tower_mcp::router::{RouterRequest, RouterResponse};
15use tower_mcp_types::protocol::{McpRequest, McpResponse};
16
17/// Tower layer that produces an [`AliasService`].
18///
19/// # Example
20///
21/// ```rust,ignore
22/// use tower::ServiceBuilder;
23/// use mcp_proxy::alias::{AliasLayer, AliasMap};
24///
25/// let aliases = AliasMap::new(vec![
26///     ("math/".into(), "add".into(), "sum".into()),
27/// ]).unwrap();
28///
29/// let service = ServiceBuilder::new()
30///     .layer(AliasLayer::new(aliases))
31///     .service(proxy);
32/// ```
33#[derive(Clone)]
34pub struct AliasLayer {
35    aliases: AliasMap,
36}
37
38impl AliasLayer {
39    /// Create a new alias layer with the given alias map.
40    pub fn new(aliases: AliasMap) -> Self {
41        Self { aliases }
42    }
43}
44
45impl<S> Layer<S> for AliasLayer {
46    type Service = AliasService<S>;
47
48    fn layer(&self, inner: S) -> Self::Service {
49        AliasService::new(inner, self.aliases.clone())
50    }
51}
52
/// Resolved alias mappings for all backends.
///
/// Both tables are keyed by fully-qualified names: the backend namespace
/// prefix concatenated with a tool name, prompt name, or resource URI.
#[derive(Clone, Debug)]
pub struct AliasMap {
    /// Maps "namespace/original" -> "namespace/alias".
    ///
    /// Applied to tool names, prompt names, and resource URIs in list responses.
    pub forward: HashMap<String, String>,
    /// Maps "namespace/alias" -> "namespace/original".
    ///
    /// Applied to tool calls, prompt gets, and resource reads in incoming requests.
    reverse: HashMap<String, String>,
}

impl AliasMap {
    /// Build an alias map from `(namespace, from, to)` triples.
    ///
    /// `namespace` is prepended verbatim, with no separator inserted, so it
    /// should carry its own trailing `/`. For example, `("math/", "add", "sum")`
    /// maps `math/add` to `math/sum`.
    ///
    /// Returns `None` if `mappings` is empty.
    ///
    /// Triples are inserted in order, so when two of them produce the same
    /// original (or the same alias), the later one overwrites the earlier
    /// entry in that table.
    pub fn new(mappings: Vec<(String, String, String)>) -> Option<Self> {
        if mappings.is_empty() {
            return None;
        }
        // Each triple yields exactly one entry per table (before any
        // overwrites), so size both tables up front.
        let mut forward = HashMap::with_capacity(mappings.len());
        let mut reverse = HashMap::with_capacity(mappings.len());
        for (namespace, from, to) in mappings {
            let original = format!("{namespace}{from}");
            let aliased = format!("{namespace}{to}");
            reverse.insert(aliased.clone(), original.clone());
            forward.insert(original, aliased);
        }
        Some(Self { forward, reverse })
    }
}
79
80/// Tower service that rewrites tool names based on alias configuration.
81#[derive(Clone)]
82pub struct AliasService<S> {
83    inner: S,
84    aliases: Arc<AliasMap>,
85}
86
87impl<S> AliasService<S> {
88    /// Create a new alias service wrapping `inner` with the given alias map.
89    pub fn new(inner: S, aliases: AliasMap) -> Self {
90        Self {
91            inner,
92            aliases: Arc::new(aliases),
93        }
94    }
95}
96
97impl<S> Service<RouterRequest> for AliasService<S>
98where
99    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
100        + Clone
101        + Send
102        + 'static,
103    S::Future: Send,
104{
105    type Response = RouterResponse;
106    type Error = Infallible;
107    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
108
109    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110        self.inner.poll_ready(cx)
111    }
112
113    fn call(&mut self, mut req: RouterRequest) -> Self::Future {
114        let aliases = Arc::clone(&self.aliases);
115
116        // Reverse-map aliased names back to originals in requests
117        match &mut req.inner {
118            McpRequest::CallTool(params) => {
119                if let Some(original) = aliases.reverse.get(&params.name) {
120                    params.name = original.clone();
121                }
122            }
123            McpRequest::ReadResource(params) => {
124                if let Some(original) = aliases.reverse.get(&params.uri) {
125                    params.uri = original.clone();
126                }
127            }
128            McpRequest::GetPrompt(params) => {
129                if let Some(original) = aliases.reverse.get(&params.name) {
130                    params.name = original.clone();
131                }
132            }
133            _ => {}
134        }
135
136        let fut = self.inner.call(req);
137
138        Box::pin(async move {
139            let mut result = fut.await;
140
141            // Forward-map original names to aliases in responses
142            let Ok(ref mut resp) = result;
143            if let Ok(mcp_resp) = &mut resp.inner {
144                match mcp_resp {
145                    McpResponse::ListTools(r) => {
146                        for tool in &mut r.tools {
147                            if let Some(aliased) = aliases.forward.get(&tool.name) {
148                                tool.name = aliased.clone();
149                            }
150                        }
151                    }
152                    McpResponse::ListResources(r) => {
153                        for res in &mut r.resources {
154                            if let Some(aliased) = aliases.forward.get(&res.uri) {
155                                res.uri = aliased.clone();
156                            }
157                        }
158                    }
159                    McpResponse::ListPrompts(r) => {
160                        for prompt in &mut r.prompts {
161                            if let Some(aliased) = aliases.forward.get(&prompt.name) {
162                                prompt.name = aliased.clone();
163                            }
164                        }
165                    }
166                    _ => {}
167                }
168            }
169
170            result
171        })
172    }
173}
174
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::{AliasMap, AliasService};
    use crate::test_util::{MockService, call_service};

    /// Two aliases in the `files/` namespace: `read_file` -> `read`, `write_file` -> `write`.
    fn test_aliases() -> AliasMap {
        AliasMap::new(vec![
            ("files/".into(), "read_file".into(), "read".into()),
            ("files/".into(), "write_file".into(), "write".into()),
        ])
        .unwrap()
    }

    #[test]
    fn test_alias_map_empty_returns_none() {
        assert!(AliasMap::new(vec![]).is_none());
    }

    #[test]
    fn test_alias_map_forward_and_reverse() {
        let aliases = test_aliases();
        assert_eq!(
            aliases.forward.get("files/read_file").unwrap(),
            "files/read"
        );
        assert_eq!(
            aliases.forward.get("files/write_file").unwrap(),
            "files/write"
        );
        // The reverse table must mirror the forward one, so a call made with
        // an alias resolves to the original backend name.
        assert_eq!(
            aliases.reverse.get("files/read").unwrap(),
            "files/read_file"
        );
        assert_eq!(
            aliases.reverse.get("files/write").unwrap(),
            "files/write_file"
        );
        assert_eq!(aliases.forward.len(), 2);
        assert_eq!(aliases.reverse.len(), 2);
    }

    #[tokio::test]
    async fn test_alias_rewrites_list_tools() {
        let mock = MockService::with_tools(&["files/read_file", "files/write_file", "db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"files/read"));
                assert!(names.contains(&"files/write"));
                assert!(names.contains(&"db/query")); // unchanged
                // Original names are replaced, not duplicated.
                assert!(!names.contains(&"files/read_file"));
                assert!(!names.contains(&"files/write_file"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_reverse_maps_call_tool() {
        let mock = MockService::with_tools(&["files/read_file"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "files/read".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: files/read_file");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_passthrough_non_aliased() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "db/query".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: db/query");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }
}