1use std::collections::HashMap;
7use std::convert::Infallible;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use tower::{Layer, Service};
14use tower_mcp::router::{RouterRequest, RouterResponse};
15use tower_mcp_types::protocol::{McpRequest, McpResponse};
16
17#[derive(Clone)]
34pub struct AliasLayer {
35 aliases: AliasMap,
36}
37
38impl AliasLayer {
39 pub fn new(aliases: AliasMap) -> Self {
41 Self { aliases }
42 }
43}
44
45impl<S> Layer<S> for AliasLayer {
46 type Service = AliasService<S>;
47
48 fn layer(&self, inner: S) -> Self::Service {
49 AliasService::new(inner, self.aliases.clone())
50 }
51}
52
/// Bidirectional name table used by [`AliasService`].
///
/// `forward` maps an original (namespaced) name to its alias; `reverse` maps
/// the alias back to the original.
#[derive(Clone)]
pub struct AliasMap {
    /// Original name -> alias (used to rewrite list responses).
    pub forward: HashMap<String, String>,
    /// Alias -> original name (used to rewrite incoming requests).
    reverse: HashMap<String, String>,
}

impl AliasMap {
    /// Builds the table from `(namespace, from, to)` triples.
    ///
    /// Both sides of each mapping are prefixed with `namespace`, so
    /// `("files/", "read_file", "read")` maps `files/read_file` <-> `files/read`.
    ///
    /// Returns `None` when `mappings` is empty, so callers can skip installing
    /// the layer entirely.
    ///
    /// Entries are inserted in order. If two triples share an original name
    /// (or an alias), the later one overwrites the earlier one in `forward`
    /// (or `reverse`).
    pub fn new(mappings: Vec<(String, String, String)>) -> Option<Self> {
        if mappings.is_empty() {
            return None;
        }

        let (forward, reverse): (HashMap<String, String>, HashMap<String, String>) = mappings
            .into_iter()
            .map(|(ns, from, to)| {
                let original = format!("{}{}", ns, from);
                let aliased = format!("{}{}", ns, to);
                ((original.clone(), aliased.clone()), (aliased, original))
            })
            .unzip();

        Some(AliasMap { forward, reverse })
    }
}
79
80#[derive(Clone)]
82pub struct AliasService<S> {
83 inner: S,
84 aliases: Arc<AliasMap>,
85}
86
87impl<S> AliasService<S> {
88 pub fn new(inner: S, aliases: AliasMap) -> Self {
90 Self {
91 inner,
92 aliases: Arc::new(aliases),
93 }
94 }
95}
96
97impl<S> Service<RouterRequest> for AliasService<S>
98where
99 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
100 + Clone
101 + Send
102 + 'static,
103 S::Future: Send,
104{
105 type Response = RouterResponse;
106 type Error = Infallible;
107 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
108
109 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110 self.inner.poll_ready(cx)
111 }
112
113 fn call(&mut self, mut req: RouterRequest) -> Self::Future {
114 let aliases = Arc::clone(&self.aliases);
115
116 match &mut req.inner {
118 McpRequest::CallTool(params) => {
119 if let Some(original) = aliases.reverse.get(¶ms.name) {
120 params.name = original.clone();
121 }
122 }
123 McpRequest::ReadResource(params) => {
124 if let Some(original) = aliases.reverse.get(¶ms.uri) {
125 params.uri = original.clone();
126 }
127 }
128 McpRequest::GetPrompt(params) => {
129 if let Some(original) = aliases.reverse.get(¶ms.name) {
130 params.name = original.clone();
131 }
132 }
133 _ => {}
134 }
135
136 let fut = self.inner.call(req);
137
138 Box::pin(async move {
139 let mut result = fut.await;
140
141 let Ok(ref mut resp) = result;
143 if let Ok(mcp_resp) = &mut resp.inner {
144 match mcp_resp {
145 McpResponse::ListTools(r) => {
146 for tool in &mut r.tools {
147 if let Some(aliased) = aliases.forward.get(&tool.name) {
148 tool.name = aliased.clone();
149 }
150 }
151 }
152 McpResponse::ListResources(r) => {
153 for res in &mut r.resources {
154 if let Some(aliased) = aliases.forward.get(&res.uri) {
155 res.uri = aliased.clone();
156 }
157 }
158 }
159 McpResponse::ListPrompts(r) => {
160 for prompt in &mut r.prompts {
161 if let Some(aliased) = aliases.forward.get(&prompt.name) {
162 prompt.name = aliased.clone();
163 }
164 }
165 }
166 _ => {}
167 }
168 }
169
170 result
171 })
172 }
173}
174
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::{AliasMap, AliasService};
    use crate::test_util::{MockService, call_service};

    /// Two aliases in the `files/` namespace: `read_file` -> `read`, `write_file` -> `write`.
    fn test_aliases() -> AliasMap {
        AliasMap::new(vec![
            ("files/".into(), "read_file".into(), "read".into()),
            ("files/".into(), "write_file".into(), "write".into()),
        ])
        .unwrap()
    }

    #[test]
    fn test_alias_map_empty_returns_none() {
        assert!(AliasMap::new(vec![]).is_none());
    }

    #[test]
    fn test_alias_map_forward_and_reverse() {
        let aliases = test_aliases();
        assert_eq!(aliases.forward.get("files/read_file").unwrap(), "files/read");
        assert_eq!(aliases.forward.len(), 2);
        // The reverse map must be the exact inverse so aliased calls resolve back.
        assert_eq!(aliases.reverse.get("files/read").unwrap(), "files/read_file");
        assert_eq!(aliases.reverse.get("files/write").unwrap(), "files/write_file");
        assert_eq!(aliases.reverse.len(), 2);
    }

    #[tokio::test]
    async fn test_alias_rewrites_list_tools() {
        let mock = MockService::with_tools(&["files/read_file", "files/write_file", "db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"files/read"));
                assert!(names.contains(&"files/write"));
                assert!(names.contains(&"db/query"));
                // Aliased originals must not leak into the listing.
                assert!(!names.contains(&"files/read_file"));
                assert!(!names.contains(&"files/write_file"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_reverse_maps_call_tool() {
        let mock = MockService::with_tools(&["files/read_file"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "files/read".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: files/read_file");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_passthrough_non_aliased() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "db/query".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: db/query");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }
}