1use std::collections::HashMap;
60use std::convert::Infallible;
61use std::future::Future;
62use std::pin::Pin;
63use std::sync::Arc;
64use std::task::{Context, Poll};
65
66use tower::{Layer, Service};
67use tower_mcp::router::{RouterRequest, RouterResponse};
68use tower_mcp_types::protocol::{McpRequest, McpResponse};
69
70#[derive(Clone)]
87pub struct AliasLayer {
88 aliases: AliasMap,
89}
90
91impl AliasLayer {
92 pub fn new(aliases: AliasMap) -> Self {
94 Self { aliases }
95 }
96}
97
98impl<S> Layer<S> for AliasLayer {
99 type Service = AliasService<S>;
100
101 fn layer(&self, inner: S) -> Self::Service {
102 AliasService::new(inner, self.aliases.clone())
103 }
104}
105
/// Bidirectional name table used by the alias middleware.
///
/// `forward` maps an original (namespaced) name to the alias shown to clients, and
/// `reverse` maps an alias back to its original name. [`AliasMap::new`] keeps the two
/// maps exact inverses of each other.
#[derive(Clone)]
pub struct AliasMap {
    /// Original name -> alias exposed to clients.
    pub forward: HashMap<String, String>,
    /// Alias -> original name; the inverse of `forward`.
    reverse: HashMap<String, String>,
}

impl AliasMap {
    /// Builds the table from `(namespace, from, to)` triples.
    ///
    /// Each triple aliases `{namespace}{from}` as `{namespace}{to}`. Returns `None` when
    /// `mappings` is empty.
    ///
    /// Conflicting triples are resolved in favour of the later one:
    /// - a repeated source drops its earlier alias;
    /// - a repeated alias drops its earlier source.
    ///
    /// This keeps `forward` and `reverse` inverse to each other, so a hidden alias is never
    /// left routable and two listed items never share one alias.
    pub fn new(mappings: Vec<(String, String, String)>) -> Option<Self> {
        if mappings.is_empty() {
            return None;
        }
        let mut forward = HashMap::with_capacity(mappings.len());
        let mut reverse = HashMap::with_capacity(mappings.len());
        for (namespace, from, to) in mappings {
            let original = format!("{namespace}{from}");
            let aliased = format!("{namespace}{to}");

            // `original` previously had another alias: that alias must stop routing to it.
            if let Some(stale_alias) = forward.insert(original.clone(), aliased.clone()) {
                reverse.remove(&stale_alias);
            }
            // `aliased` previously named a different original: un-alias that original.
            if let Some(stale_original) = reverse
                .insert(aliased, original.clone())
                .filter(|previous| previous != &original)
            {
                forward.remove(&stale_original);
            }
        }
        Some(Self { forward, reverse })
    }
}
132
133#[derive(Clone)]
135pub struct AliasService<S> {
136 inner: S,
137 aliases: Arc<AliasMap>,
138}
139
140impl<S> AliasService<S> {
141 pub fn new(inner: S, aliases: AliasMap) -> Self {
143 Self {
144 inner,
145 aliases: Arc::new(aliases),
146 }
147 }
148}
149
150impl<S> Service<RouterRequest> for AliasService<S>
151where
152 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
153 + Clone
154 + Send
155 + 'static,
156 S::Future: Send,
157{
158 type Response = RouterResponse;
159 type Error = Infallible;
160 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
161
162 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
163 self.inner.poll_ready(cx)
164 }
165
166 fn call(&mut self, mut req: RouterRequest) -> Self::Future {
167 let aliases = Arc::clone(&self.aliases);
168
169 match &mut req.inner {
171 McpRequest::CallTool(params) => {
172 if let Some(original) = aliases.reverse.get(¶ms.name) {
173 params.name = original.clone();
174 }
175 }
176 McpRequest::ReadResource(params) => {
177 if let Some(original) = aliases.reverse.get(¶ms.uri) {
178 params.uri = original.clone();
179 }
180 }
181 McpRequest::GetPrompt(params) => {
182 if let Some(original) = aliases.reverse.get(¶ms.name) {
183 params.name = original.clone();
184 }
185 }
186 _ => {}
187 }
188
189 let fut = self.inner.call(req);
190
191 Box::pin(async move {
192 let mut result = fut.await;
193
194 let Ok(ref mut resp) = result;
196 if let Ok(mcp_resp) = &mut resp.inner {
197 match mcp_resp {
198 McpResponse::ListTools(r) => {
199 for tool in &mut r.tools {
200 if let Some(aliased) = aliases.forward.get(&tool.name) {
201 tool.name = aliased.clone();
202 }
203 }
204 }
205 McpResponse::ListResources(r) => {
206 for res in &mut r.resources {
207 if let Some(aliased) = aliases.forward.get(&res.uri) {
208 res.uri = aliased.clone();
209 }
210 }
211 }
212 McpResponse::ListPrompts(r) => {
213 for prompt in &mut r.prompts {
214 if let Some(aliased) = aliases.forward.get(&prompt.name) {
215 prompt.name = aliased.clone();
216 }
217 }
218 }
219 _ => {}
220 }
221 }
222
223 result
224 })
225 }
226}
227
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::{AliasMap, AliasService};
    use crate::test_util::{MockService, call_service};

    /// Two aliases in the `files/` namespace: `read_file` -> `read`, `write_file` -> `write`.
    fn test_aliases() -> AliasMap {
        AliasMap::new(vec![
            ("files/".into(), "read_file".into(), "read".into()),
            ("files/".into(), "write_file".into(), "write".into()),
        ])
        .unwrap()
    }

    #[test]
    fn test_alias_map_empty_returns_none() {
        assert!(AliasMap::new(vec![]).is_none());
    }

    #[test]
    fn test_alias_map_forward_and_reverse() {
        let aliases = test_aliases();
        assert_eq!(
            aliases.forward.get("files/read_file").unwrap(),
            "files/read"
        );
        assert_eq!(aliases.forward.len(), 2);

        // The reverse table is what routes calls on an alias back to the original name.
        assert_eq!(
            aliases.reverse.get("files/read").unwrap(),
            "files/read_file"
        );
        assert_eq!(
            aliases.reverse.get("files/write").unwrap(),
            "files/write_file"
        );
        assert_eq!(aliases.reverse.len(), 2);
    }

    #[tokio::test]
    async fn test_alias_rewrites_list_tools() {
        let mock = MockService::with_tools(&["files/read_file", "files/write_file", "db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"files/read"));
                assert!(names.contains(&"files/write"));
                assert!(names.contains(&"db/query"));
                // Aliased tools must no longer be listed under their original names.
                assert!(!names.contains(&"files/read_file"));
                assert!(!names.contains(&"files/write_file"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_reverse_maps_call_tool() {
        let mock = MockService::with_tools(&["files/read_file"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "files/read".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: files/read_file");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_passthrough_non_aliased() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "db/query".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: db/query");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }
}