Skip to main content

mcp_proxy/
alias.rs

//! Tool aliasing middleware for the proxy.
//!
//! Rewrites tool names, resource URIs, and prompt names in list responses and
//! in call/read/get requests, based on per-backend alias configuration.
5
6use std::collections::HashMap;
7use std::convert::Infallible;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use tower::Service;
14use tower_mcp::router::{RouterRequest, RouterResponse};
15use tower_mcp_types::protocol::{McpRequest, McpResponse};
16
/// Resolved alias mappings for all backends.
///
/// `forward` and `reverse` are kept as exact inverses of each other: every
/// `original -> alias` entry in `forward` has a matching `alias -> original`
/// entry in `reverse`, and vice versa.
#[derive(Clone, Debug)]
pub struct AliasMap {
    /// Maps "namespace/original" -> "namespace/alias" (for list responses)
    pub forward: HashMap<String, String>,
    /// Maps "namespace/alias" -> "namespace/original" (for call requests)
    reverse: HashMap<String, String>,
}

impl AliasMap {
    /// Build an alias map from `(namespace, from, to)` triples.
    ///
    /// Keys are formed by concatenating `namespace` with the name verbatim, so
    /// `namespace` must already include its separator (e.g. `"files/"`).
    ///
    /// Returns `None` if `mappings` is empty.
    ///
    /// Conflicting mappings resolve as "last one wins". This covers the same
    /// original being given two aliases, and two originals being given the
    /// same alias. The earlier mapping is dropped from *both* directions, so a
    /// stale alias can never keep routing calls to an original that is now
    /// listed under a different name.
    pub fn new(mappings: Vec<(String, String, String)>) -> Option<Self> {
        if mappings.is_empty() {
            return None;
        }
        let mut forward = HashMap::with_capacity(mappings.len());
        let mut reverse = HashMap::with_capacity(mappings.len());
        for (namespace, from, to) in mappings {
            let original = format!("{namespace}{from}");
            let aliased = format!("{namespace}{to}");
            // This original was aliased before: forget its previous alias.
            if let Some(old_alias) = forward.insert(original.clone(), aliased.clone()) {
                reverse.remove(&old_alias);
            }
            // Another original already claimed this alias: un-alias it so the
            // listed names and the call routing stay in agreement.
            if let Some(old_original) = reverse.insert(aliased, original.clone()) {
                if old_original != original {
                    forward.remove(&old_original);
                }
            }
        }
        Some(Self { forward, reverse })
    }
}
43
44/// Tower service that rewrites tool names based on alias configuration.
45#[derive(Clone)]
46pub struct AliasService<S> {
47    inner: S,
48    aliases: Arc<AliasMap>,
49}
50
51impl<S> AliasService<S> {
52    /// Create a new alias service wrapping `inner` with the given alias map.
53    pub fn new(inner: S, aliases: AliasMap) -> Self {
54        Self {
55            inner,
56            aliases: Arc::new(aliases),
57        }
58    }
59}
60
61impl<S> Service<RouterRequest> for AliasService<S>
62where
63    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
64        + Clone
65        + Send
66        + 'static,
67    S::Future: Send,
68{
69    type Response = RouterResponse;
70    type Error = Infallible;
71    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
72
73    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
74        self.inner.poll_ready(cx)
75    }
76
77    fn call(&mut self, mut req: RouterRequest) -> Self::Future {
78        let aliases = Arc::clone(&self.aliases);
79
80        // Reverse-map aliased names back to originals in requests
81        match &mut req.inner {
82            McpRequest::CallTool(params) => {
83                if let Some(original) = aliases.reverse.get(&params.name) {
84                    params.name = original.clone();
85                }
86            }
87            McpRequest::ReadResource(params) => {
88                if let Some(original) = aliases.reverse.get(&params.uri) {
89                    params.uri = original.clone();
90                }
91            }
92            McpRequest::GetPrompt(params) => {
93                if let Some(original) = aliases.reverse.get(&params.name) {
94                    params.name = original.clone();
95                }
96            }
97            _ => {}
98        }
99
100        let fut = self.inner.call(req);
101
102        Box::pin(async move {
103            let mut result = fut.await;
104
105            // Forward-map original names to aliases in responses
106            let Ok(ref mut resp) = result;
107            if let Ok(mcp_resp) = &mut resp.inner {
108                match mcp_resp {
109                    McpResponse::ListTools(r) => {
110                        for tool in &mut r.tools {
111                            if let Some(aliased) = aliases.forward.get(&tool.name) {
112                                tool.name = aliased.clone();
113                            }
114                        }
115                    }
116                    McpResponse::ListResources(r) => {
117                        for res in &mut r.resources {
118                            if let Some(aliased) = aliases.forward.get(&res.uri) {
119                                res.uri = aliased.clone();
120                            }
121                        }
122                    }
123                    McpResponse::ListPrompts(r) => {
124                        for prompt in &mut r.prompts {
125                            if let Some(aliased) = aliases.forward.get(&prompt.name) {
126                                prompt.name = aliased.clone();
127                            }
128                        }
129                    }
130                    _ => {}
131                }
132            }
133
134            result
135        })
136    }
137}
138
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::{AliasMap, AliasService};
    use crate::test_util::{MockService, call_service};

    /// Two aliases in the `files/` namespace; other namespaces are unaliased.
    fn test_aliases() -> AliasMap {
        AliasMap::new(vec![
            ("files/".into(), "read_file".into(), "read".into()),
            ("files/".into(), "write_file".into(), "write".into()),
        ])
        .unwrap()
    }

    #[test]
    fn test_alias_map_empty_returns_none() {
        assert!(AliasMap::new(vec![]).is_none());
    }

    #[test]
    fn test_alias_map_forward_and_reverse() {
        let aliases = test_aliases();

        assert_eq!(
            aliases.forward.get("files/read_file").map(String::as_str),
            Some("files/read")
        );
        assert_eq!(
            aliases.forward.get("files/write_file").map(String::as_str),
            Some("files/write")
        );
        assert_eq!(aliases.forward.len(), 2);

        // The reverse direction is what call requests rely on; check it too.
        assert_eq!(
            aliases.reverse.get("files/read").map(String::as_str),
            Some("files/read_file")
        );
        assert_eq!(
            aliases.reverse.get("files/write").map(String::as_str),
            Some("files/write_file")
        );
        assert_eq!(aliases.reverse.len(), 2);
    }

    #[tokio::test]
    async fn test_alias_rewrites_list_tools() {
        let mock = MockService::with_tools(&["files/read_file", "files/write_file", "db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"files/read"));
                assert!(names.contains(&"files/write"));
                assert!(names.contains(&"db/query")); // unchanged
                // Aliased tools must no longer appear under their original names.
                assert!(!names.contains(&"files/read_file"));
                assert!(!names.contains(&"files/write_file"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_reverse_maps_call_tool() {
        let mock = MockService::with_tools(&["files/read_file"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "files/read".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: files/read_file");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_passthrough_non_aliased() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "db/query".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: db/query");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }
}
233}