1use std::collections::HashMap;
7use std::convert::Infallible;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use tower::Service;
14use tower_mcp::router::{RouterRequest, RouterResponse};
15use tower_mcp_types::protocol::{McpRequest, McpResponse};
16
/// Bidirectional lookup table between original and aliased names.
///
/// Every key and value is a qualified name: the mapping's namespace string
/// concatenated directly (no separator is inserted) with the original or
/// aliased name.
#[derive(Debug, Clone)]
pub struct AliasMap {
    /// Original qualified name -> aliased qualified name.
    pub forward: HashMap<String, String>,
    /// Aliased qualified name -> original qualified name.
    reverse: HashMap<String, String>,
}

impl AliasMap {
    /// Builds an alias map from `(namespace, from, to)` triples.
    ///
    /// For each triple, `namespace + from` is registered as the original name
    /// and `namespace + to` as its alias, in both lookup directions.
    ///
    /// Returns `None` when `mappings` is empty, so callers can tell "no
    /// aliases configured" apart from a populated map.
    ///
    /// If the same original (or the same alias) occurs in several triples, the
    /// later triple overwrites the earlier entry for that direction.
    pub fn new(mappings: Vec<(String, String, String)>) -> Option<Self> {
        if mappings.is_empty() {
            return None;
        }
        // The number of entries is bounded by the number of triples, so size
        // both tables once instead of growing them during insertion.
        let mut forward = HashMap::with_capacity(mappings.len());
        let mut reverse = HashMap::with_capacity(mappings.len());
        for (namespace, from, to) in mappings {
            let original = format!("{namespace}{from}");
            let aliased = format!("{namespace}{to}");
            forward.insert(original.clone(), aliased.clone());
            reverse.insert(aliased, original);
        }
        Some(Self { forward, reverse })
    }
}
43
44#[derive(Clone)]
46pub struct AliasService<S> {
47 inner: S,
48 aliases: Arc<AliasMap>,
49}
50
51impl<S> AliasService<S> {
52 pub fn new(inner: S, aliases: AliasMap) -> Self {
54 Self {
55 inner,
56 aliases: Arc::new(aliases),
57 }
58 }
59}
60
61impl<S> Service<RouterRequest> for AliasService<S>
62where
63 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
64 + Clone
65 + Send
66 + 'static,
67 S::Future: Send,
68{
69 type Response = RouterResponse;
70 type Error = Infallible;
71 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
72
73 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
74 self.inner.poll_ready(cx)
75 }
76
77 fn call(&mut self, mut req: RouterRequest) -> Self::Future {
78 let aliases = Arc::clone(&self.aliases);
79
80 match &mut req.inner {
82 McpRequest::CallTool(params) => {
83 if let Some(original) = aliases.reverse.get(¶ms.name) {
84 params.name = original.clone();
85 }
86 }
87 McpRequest::ReadResource(params) => {
88 if let Some(original) = aliases.reverse.get(¶ms.uri) {
89 params.uri = original.clone();
90 }
91 }
92 McpRequest::GetPrompt(params) => {
93 if let Some(original) = aliases.reverse.get(¶ms.name) {
94 params.name = original.clone();
95 }
96 }
97 _ => {}
98 }
99
100 let fut = self.inner.call(req);
101
102 Box::pin(async move {
103 let mut result = fut.await;
104
105 let Ok(ref mut resp) = result;
107 if let Ok(mcp_resp) = &mut resp.inner {
108 match mcp_resp {
109 McpResponse::ListTools(r) => {
110 for tool in &mut r.tools {
111 if let Some(aliased) = aliases.forward.get(&tool.name) {
112 tool.name = aliased.clone();
113 }
114 }
115 }
116 McpResponse::ListResources(r) => {
117 for res in &mut r.resources {
118 if let Some(aliased) = aliases.forward.get(&res.uri) {
119 res.uri = aliased.clone();
120 }
121 }
122 }
123 McpResponse::ListPrompts(r) => {
124 for prompt in &mut r.prompts {
125 if let Some(aliased) = aliases.forward.get(&prompt.name) {
126 prompt.name = aliased.clone();
127 }
128 }
129 }
130 _ => {}
131 }
132 }
133
134 result
135 })
136 }
137}
138
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::{AliasMap, AliasService};
    use crate::test_util::{MockService, call_service};

    /// Two aliases in the `files/` namespace, shared by the tests below.
    fn test_aliases() -> AliasMap {
        AliasMap::new(vec![
            ("files/".into(), "read_file".into(), "read".into()),
            ("files/".into(), "write_file".into(), "write".into()),
        ])
        .unwrap()
    }

    #[test]
    fn test_alias_map_empty_returns_none() {
        assert!(AliasMap::new(vec![]).is_none());
    }

    #[test]
    fn test_alias_map_forward_and_reverse() {
        let aliases = test_aliases();
        assert_eq!(
            aliases.forward.get("files/read_file").unwrap(),
            "files/read"
        );
        assert_eq!(aliases.forward.len(), 2);
        // The reverse direction must mirror the forward one.
        assert_eq!(
            aliases.reverse.get("files/read").unwrap(),
            "files/read_file"
        );
        assert_eq!(aliases.reverse.len(), 2);
    }

    #[tokio::test]
    async fn test_alias_rewrites_list_tools() {
        let mock = MockService::with_tools(&["files/read_file", "files/write_file", "db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"files/read"));
                assert!(names.contains(&"files/write"));
                // `db/query` has no alias configured, so it is listed as-is.
                assert!(names.contains(&"db/query"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_reverse_maps_call_tool() {
        let mock = MockService::with_tools(&["files/read_file"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "files/read".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                // The inner service must have been called with the original
                // name, not the alias the client sent.
                assert_eq!(result.all_text(), "called: files/read_file");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_alias_passthrough_non_aliased() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut svc = AliasService::new(mock, test_aliases());

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "db/query".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => {
                assert_eq!(result.all_text(), "called: db/query");
            }
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }
}