1use std::convert::Infallible;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use tower::Service;
13
14use tower_mcp::protocol::{McpRequest, McpResponse};
15use tower_mcp::{RouterRequest, RouterResponse};
16use tower_mcp_types::JsonRpcError;
17
18use crate::config::BackendFilter;
19
20#[derive(Clone)]
22pub struct CapabilityFilterService<S> {
23 inner: S,
24 filters: Arc<Vec<BackendFilter>>,
25}
26
27impl<S> CapabilityFilterService<S> {
28 pub fn new(inner: S, filters: Vec<BackendFilter>) -> Self {
30 Self {
31 inner,
32 filters: Arc::new(filters),
33 }
34 }
35}
36
37impl<S> Service<RouterRequest> for CapabilityFilterService<S>
38where
39 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
40 + Clone
41 + Send
42 + 'static,
43 S::Future: Send,
44{
45 type Response = RouterResponse;
46 type Error = Infallible;
47 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
48
49 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50 self.inner.poll_ready(cx)
51 }
52
53 fn call(&mut self, req: RouterRequest) -> Self::Future {
54 let filters = Arc::clone(&self.filters);
55 let request_id = req.id.clone();
56
57 match &req.inner {
59 McpRequest::CallTool(params) => {
60 if let Some(reason) = check_tool_denied(&filters, ¶ms.name) {
61 return Box::pin(async move {
62 Ok(RouterResponse {
63 id: request_id,
64 inner: Err(JsonRpcError::invalid_params(reason)),
65 })
66 });
67 }
68 }
69 McpRequest::ReadResource(params) => {
70 if let Some(reason) = check_resource_denied(&filters, ¶ms.uri) {
71 return Box::pin(async move {
72 Ok(RouterResponse {
73 id: request_id,
74 inner: Err(JsonRpcError::invalid_params(reason)),
75 })
76 });
77 }
78 }
79 McpRequest::GetPrompt(params) => {
80 if let Some(reason) = check_prompt_denied(&filters, ¶ms.name) {
81 return Box::pin(async move {
82 Ok(RouterResponse {
83 id: request_id,
84 inner: Err(JsonRpcError::invalid_params(reason)),
85 })
86 });
87 }
88 }
89 _ => {}
90 }
91
92 let fut = self.inner.call(req);
93
94 Box::pin(async move {
95 let mut resp = fut.await?;
96
97 if let Ok(ref mut mcp_resp) = resp.inner {
99 match mcp_resp {
100 McpResponse::ListTools(result) => {
101 result.tools.retain(|tool| {
102 for f in filters.iter() {
103 if let Some(local_name) = tool.name.strip_prefix(&f.namespace) {
104 return f.tool_filter.allows(local_name);
105 }
106 }
107 true
108 });
109 }
110 McpResponse::ListResources(result) => {
111 result.resources.retain(|resource| {
112 for f in filters.iter() {
113 if let Some(local_uri) = resource.uri.strip_prefix(&f.namespace) {
114 return f.resource_filter.allows(local_uri);
115 }
116 }
117 true
118 });
119 }
120 McpResponse::ListResourceTemplates(result) => {
121 result.resource_templates.retain(|template| {
122 for f in filters.iter() {
123 if let Some(local_uri) =
124 template.uri_template.strip_prefix(&f.namespace)
125 {
126 return f.resource_filter.allows(local_uri);
127 }
128 }
129 true
130 });
131 }
132 McpResponse::ListPrompts(result) => {
133 result.prompts.retain(|prompt| {
134 for f in filters.iter() {
135 if let Some(local_name) = prompt.name.strip_prefix(&f.namespace) {
136 return f.prompt_filter.allows(local_name);
137 }
138 }
139 true
140 });
141 }
142 _ => {}
143 }
144 }
145
146 Ok(resp)
147 })
148 }
149}
150
151fn check_tool_denied(filters: &[BackendFilter], namespaced_name: &str) -> Option<String> {
154 for f in filters {
155 if let Some(local_name) = namespaced_name.strip_prefix(&f.namespace) {
156 if !f.tool_filter.allows(local_name) {
157 return Some(format!("Tool not available: {}", namespaced_name));
158 }
159 return None;
160 }
161 }
162 None
163}
164
165fn check_resource_denied(filters: &[BackendFilter], namespaced_uri: &str) -> Option<String> {
167 for f in filters {
168 if let Some(local_uri) = namespaced_uri.strip_prefix(&f.namespace) {
169 if !f.resource_filter.allows(local_uri) {
170 return Some(format!("Resource not available: {}", namespaced_uri));
171 }
172 return None;
173 }
174 }
175 None
176}
177
178fn check_prompt_denied(filters: &[BackendFilter], namespaced_name: &str) -> Option<String> {
180 for f in filters {
181 if let Some(local_name) = namespaced_name.strip_prefix(&f.namespace) {
182 if !f.prompt_filter.allows(local_name) {
183 return Some(format!("Prompt not available: {}", namespaced_name));
184 }
185 return None;
186 }
187 }
188 None
189}
190
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::{
        CapabilityFilterService, check_prompt_denied, check_resource_denied, check_tool_denied,
    };
    use crate::config::{BackendFilter, NameFilter};
    use crate::test_util::{MockService, call_service};

    fn allow_filter(namespace: &str, tools: &[&str]) -> BackendFilter {
        BackendFilter {
            namespace: namespace.to_string(),
            tool_filter: NameFilter::AllowList(tools.iter().map(|s| s.to_string()).collect()),
            resource_filter: NameFilter::PassAll,
            prompt_filter: NameFilter::PassAll,
        }
    }

    fn deny_filter(namespace: &str, tools: &[&str]) -> BackendFilter {
        BackendFilter {
            namespace: namespace.to_string(),
            tool_filter: NameFilter::DenyList(tools.iter().map(|s| s.to_string()).collect()),
            resource_filter: NameFilter::PassAll,
            prompt_filter: NameFilter::PassAll,
        }
    }

    /// Filter that passes all tools but deny-lists the given resources and prompts.
    fn deny_resources_and_prompts(
        namespace: &str,
        resources: &[&str],
        prompts: &[&str],
    ) -> BackendFilter {
        BackendFilter {
            namespace: namespace.to_string(),
            tool_filter: NameFilter::PassAll,
            resource_filter: NameFilter::DenyList(
                resources.iter().map(|s| s.to_string()).collect(),
            ),
            prompt_filter: NameFilter::DenyList(prompts.iter().map(|s| s.to_string()).collect()),
        }
    }

    #[tokio::test]
    async fn test_filter_allow_list_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![allow_filter("fs/", &["read", "write"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"fs/read"));
                assert!(names.contains(&"fs/write"));
                assert!(!names.contains(&"fs/delete"), "delete should be filtered");
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_filter_deny_list_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![deny_filter("fs/", &["delete"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"fs/read"));
                assert!(names.contains(&"fs/write"));
                assert!(!names.contains(&"fs/delete"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_filter_denies_call_to_hidden_tool() {
        let mock = MockService::with_tools(&["fs/read", "fs/delete"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "fs/delete".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("not available"),
            "should deny: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_filter_allows_call_to_permitted_tool() {
        let mock = MockService::with_tools(&["fs/read"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
                name: "fs/read".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        assert!(resp.inner.is_ok(), "allowed tool should succeed");
    }

    #[tokio::test]
    async fn test_filter_pass_all_allows_everything() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![BackendFilter {
            namespace: "fs/".to_string(),
            tool_filter: NameFilter::PassAll,
            resource_filter: NameFilter::PassAll,
            prompt_filter: NameFilter::PassAll,
        }];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                assert_eq!(result.tools.len(), 3);
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_filter_unmatched_namespace_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                assert_eq!(result.tools.len(), 1, "unmatched namespace should pass");
                assert_eq!(result.tools[0].name, "db/query");
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[test]
    fn test_check_tool_denied_reason_and_passthrough() {
        let filters = vec![allow_filter("fs/", &["read"])];
        assert_eq!(
            check_tool_denied(&filters, "fs/delete").as_deref(),
            Some("Tool not available: fs/delete")
        );
        assert_eq!(check_tool_denied(&filters, "fs/read"), None);
        // Names outside every configured namespace are never blocked.
        assert_eq!(check_tool_denied(&filters, "db/delete"), None);
    }

    #[test]
    fn test_check_tool_denied_first_matching_namespace_wins() {
        // Both filters claim "fs/"; only the first one is consulted.
        let filters = vec![allow_filter("fs/", &["read"]), deny_filter("fs/", &[])];
        assert!(check_tool_denied(&filters, "fs/write").is_some());
        assert!(check_tool_denied(&filters, "fs/read").is_none());
    }

    #[test]
    fn test_check_resource_denied() {
        let filters = vec![deny_resources_and_prompts("fs/", &["file:///secret"], &[])];
        assert_eq!(
            check_resource_denied(&filters, "fs/file:///secret").as_deref(),
            Some("Resource not available: fs/file:///secret")
        );
        assert_eq!(check_resource_denied(&filters, "fs/file:///public"), None);
        assert_eq!(check_resource_denied(&filters, "db/file:///secret"), None);
    }

    #[test]
    fn test_check_prompt_denied() {
        let filters = vec![deny_resources_and_prompts("fs/", &[], &["dangerous"])];
        assert_eq!(
            check_prompt_denied(&filters, "fs/dangerous").as_deref(),
            Some("Prompt not available: fs/dangerous")
        );
        assert_eq!(check_prompt_denied(&filters, "fs/safe"), None);
        assert_eq!(check_prompt_denied(&filters, "db/dangerous"), None);
    }
}