1use std::convert::Infallible;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use tower::{Layer, Service};
13
14use tower_mcp::protocol::{McpRequest, McpResponse};
15use tower_mcp::{RouterRequest, RouterResponse};
16use tower_mcp_types::JsonRpcError;
17
18use crate::config::BackendFilter;
19
20#[derive(Clone)]
33pub struct CapabilityFilterLayer {
34 filters: Vec<BackendFilter>,
35}
36
37impl CapabilityFilterLayer {
38 pub fn new(filters: Vec<BackendFilter>) -> Self {
40 Self { filters }
41 }
42}
43
44impl<S> Layer<S> for CapabilityFilterLayer {
45 type Service = CapabilityFilterService<S>;
46
47 fn layer(&self, inner: S) -> Self::Service {
48 CapabilityFilterService::new(inner, self.filters.clone())
49 }
50}
51
52#[derive(Clone)]
54pub struct CapabilityFilterService<S> {
55 inner: S,
56 filters: Arc<Vec<BackendFilter>>,
57}
58
59impl<S> CapabilityFilterService<S> {
60 pub fn new(inner: S, filters: Vec<BackendFilter>) -> Self {
62 Self {
63 inner,
64 filters: Arc::new(filters),
65 }
66 }
67}
68
69impl<S> Service<RouterRequest> for CapabilityFilterService<S>
70where
71 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
72 + Clone
73 + Send
74 + 'static,
75 S::Future: Send,
76{
77 type Response = RouterResponse;
78 type Error = Infallible;
79 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
80
81 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
82 self.inner.poll_ready(cx)
83 }
84
85 fn call(&mut self, req: RouterRequest) -> Self::Future {
86 let filters = Arc::clone(&self.filters);
87 let request_id = req.id.clone();
88
89 match &req.inner {
91 McpRequest::CallTool(params) => {
92 if let Some(reason) = check_tool_denied(&filters, ¶ms.name) {
93 return Box::pin(async move {
94 Ok(RouterResponse {
95 id: request_id,
96 inner: Err(JsonRpcError::invalid_params(reason)),
97 })
98 });
99 }
100 }
101 McpRequest::ReadResource(params) => {
102 if let Some(reason) = check_resource_denied(&filters, ¶ms.uri) {
103 return Box::pin(async move {
104 Ok(RouterResponse {
105 id: request_id,
106 inner: Err(JsonRpcError::invalid_params(reason)),
107 })
108 });
109 }
110 }
111 McpRequest::GetPrompt(params) => {
112 if let Some(reason) = check_prompt_denied(&filters, ¶ms.name) {
113 return Box::pin(async move {
114 Ok(RouterResponse {
115 id: request_id,
116 inner: Err(JsonRpcError::invalid_params(reason)),
117 })
118 });
119 }
120 }
121 _ => {}
122 }
123
124 let fut = self.inner.call(req);
125
126 Box::pin(async move {
127 let mut resp = fut.await?;
128
129 if let Ok(ref mut mcp_resp) = resp.inner {
131 match mcp_resp {
132 McpResponse::ListTools(result) => {
133 result.tools.retain(|tool| {
134 for f in filters.iter() {
135 if let Some(local_name) = tool.name.strip_prefix(&f.namespace) {
136 if !f.tool_filter.allows(local_name) {
137 return false;
138 }
139 if let Some(ref annotations) = tool.annotations {
141 if f.hide_destructive && annotations.destructive_hint {
142 return false;
143 }
144 if f.read_only_only && !annotations.read_only_hint {
145 return false;
146 }
147 } else if f.read_only_only {
148 return false;
150 }
151 return true;
152 }
153 }
154 true
155 });
156 }
157 McpResponse::ListResources(result) => {
158 result.resources.retain(|resource| {
159 for f in filters.iter() {
160 if let Some(local_uri) = resource.uri.strip_prefix(&f.namespace) {
161 return f.resource_filter.allows(local_uri);
162 }
163 }
164 true
165 });
166 }
167 McpResponse::ListResourceTemplates(result) => {
168 result.resource_templates.retain(|template| {
169 for f in filters.iter() {
170 if let Some(local_uri) =
171 template.uri_template.strip_prefix(&f.namespace)
172 {
173 return f.resource_filter.allows(local_uri);
174 }
175 }
176 true
177 });
178 }
179 McpResponse::ListPrompts(result) => {
180 result.prompts.retain(|prompt| {
181 for f in filters.iter() {
182 if let Some(local_name) = prompt.name.strip_prefix(&f.namespace) {
183 return f.prompt_filter.allows(local_name);
184 }
185 }
186 true
187 });
188 }
189 _ => {}
190 }
191 }
192
193 Ok(resp)
194 })
195 }
196}
197
198fn check_tool_denied(filters: &[BackendFilter], namespaced_name: &str) -> Option<String> {
201 for f in filters {
202 if let Some(local_name) = namespaced_name.strip_prefix(&f.namespace) {
203 if !f.tool_filter.allows(local_name) {
204 return Some(format!("Tool not available: {}", namespaced_name));
205 }
206 return None;
207 }
208 }
209 None
210}
211
212fn check_resource_denied(filters: &[BackendFilter], namespaced_uri: &str) -> Option<String> {
214 for f in filters {
215 if let Some(local_uri) = namespaced_uri.strip_prefix(&f.namespace) {
216 if !f.resource_filter.allows(local_uri) {
217 return Some(format!("Resource not available: {}", namespaced_uri));
218 }
219 return None;
220 }
221 }
222 None
223}
224
225fn check_prompt_denied(filters: &[BackendFilter], namespaced_name: &str) -> Option<String> {
227 for f in filters {
228 if let Some(local_name) = namespaced_name.strip_prefix(&f.namespace) {
229 if !f.prompt_filter.allows(local_name) {
230 return Some(format!("Prompt not available: {}", namespaced_name));
231 }
232 return None;
233 }
234 }
235 None
236}
237
238#[derive(Clone)]
245pub struct SearchModeFilterLayer {
246 prefix: String,
247}
248
249impl SearchModeFilterLayer {
250 pub fn new(prefix: impl Into<String>) -> Self {
252 Self {
253 prefix: prefix.into(),
254 }
255 }
256}
257
258impl<S> Layer<S> for SearchModeFilterLayer {
259 type Service = SearchModeFilterService<S>;
260
261 fn layer(&self, inner: S) -> Self::Service {
262 SearchModeFilterService {
263 inner,
264 prefix: self.prefix.clone(),
265 }
266 }
267}
268
269#[derive(Clone)]
275pub struct SearchModeFilterService<S> {
276 inner: S,
277 prefix: String,
278}
279
280impl<S> SearchModeFilterService<S> {
281 pub fn new(inner: S, prefix: impl Into<String>) -> Self {
283 Self {
284 inner,
285 prefix: prefix.into(),
286 }
287 }
288}
289
290impl<S> Service<RouterRequest> for SearchModeFilterService<S>
291where
292 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
293 + Clone
294 + Send
295 + 'static,
296 S::Future: Send,
297{
298 type Response = RouterResponse;
299 type Error = Infallible;
300 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
301
302 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303 self.inner.poll_ready(cx)
304 }
305
306 fn call(&mut self, req: RouterRequest) -> Self::Future {
307 let prefix = self.prefix.clone();
308 let fut = self.inner.call(req);
309
310 Box::pin(async move {
311 let mut resp = fut.await?;
312
313 if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
314 result.tools.retain(|tool| tool.name.starts_with(&prefix));
315 }
316
317 Ok(resp)
318 })
319 }
320}
321
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest, McpResponse, ToolDefinition};
    use tower_mcp_types::protocol::ToolAnnotations;

    use super::{CapabilityFilterService, SearchModeFilterService};
    use crate::config::{BackendFilter, NameFilter};
    use crate::test_util::{MockService, call_service};

    // ----- fixtures --------------------------------------------------------

    /// Filter on `namespace` that restricts tools via `tool_filter` only.
    /// Resources, prompts and annotation-based hiding are left open.
    fn tools_only(namespace: &str, tool_filter: NameFilter) -> BackendFilter {
        BackendFilter {
            namespace: namespace.to_string(),
            tool_filter,
            resource_filter: NameFilter::PassAll,
            prompt_filter: NameFilter::PassAll,
            hide_destructive: false,
            read_only_only: false,
        }
    }

    fn allow_filter(namespace: &str, tools: &[&str]) -> BackendFilter {
        let allowed = tools.iter().map(|s| s.to_string());
        tools_only(namespace, NameFilter::allow_list(allowed).unwrap())
    }

    fn deny_filter(namespace: &str, tools: &[&str]) -> BackendFilter {
        let denied = tools.iter().map(|s| s.to_string());
        tools_only(namespace, NameFilter::deny_list(denied).unwrap())
    }

    /// Pass-all `fs/` filter with the given annotation-based switches.
    fn fs_annotation_filter(hide_destructive: bool, read_only_only: bool) -> BackendFilter {
        let mut filter = tools_only("fs/", NameFilter::PassAll);
        filter.hide_destructive = hide_destructive;
        filter.read_only_only = read_only_only;
        filter
    }

    fn list_tools_request() -> McpRequest {
        McpRequest::ListTools(Default::default())
    }

    fn call_tool_request(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Tool names from a `ListTools` response; panics on any other variant.
    fn tool_names(response: McpResponse) -> Vec<String> {
        match response {
            McpResponse::ListTools(result) => {
                result.tools.into_iter().map(|tool| tool.name).collect()
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    fn listed(names: &[String], wanted: &str) -> bool {
        names.iter().any(|name| name == wanted)
    }

    fn annotated_tool(
        name: &str,
        description: &str,
        read_only: bool,
        destructive: bool,
        idempotent: bool,
    ) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            title: None,
            description: Some(description.to_string()),
            input_schema: serde_json::json!({"type": "object"}),
            output_schema: None,
            icons: None,
            annotations: Some(ToolAnnotations {
                title: None,
                read_only_hint: read_only,
                destructive_hint: destructive,
                idempotent_hint: idempotent,
                open_world_hint: false,
            }),
            execution: None,
            meta: None,
        }
    }

    /// Three `fs/` tools: read-only, destructive, and non-read-only/non-destructive.
    fn mock_with_annotated_tools() -> MockService {
        MockService {
            tools: vec![
                annotated_tool("fs/read_file", "Read a file", true, false, true),
                annotated_tool("fs/delete_file", "Delete a file", false, true, false),
                annotated_tool("fs/write_file", "Write a file", false, false, true),
            ],
        }
    }

    // ----- CapabilityFilterService -----------------------------------------

    #[tokio::test]
    async fn test_filter_allow_list_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![allow_filter("fs/", &["read", "write"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, list_tools_request()).await;
        let names = tool_names(resp.inner.unwrap());
        assert!(listed(&names, "fs/read"));
        assert!(listed(&names, "fs/write"));
        assert!(!listed(&names, "fs/delete"), "delete should be filtered");
    }

    #[tokio::test]
    async fn test_filter_deny_list_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![deny_filter("fs/", &["delete"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, list_tools_request()).await;
        let names = tool_names(resp.inner.unwrap());
        assert!(listed(&names, "fs/read"));
        assert!(listed(&names, "fs/write"));
        assert!(!listed(&names, "fs/delete"));
    }

    #[tokio::test]
    async fn test_filter_denies_call_to_hidden_tool() {
        let mock = MockService::with_tools(&["fs/read", "fs/delete"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, call_tool_request("fs/delete")).await;
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("not available"),
            "should deny: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_filter_allows_call_to_permitted_tool() {
        let mock = MockService::with_tools(&["fs/read"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, call_tool_request("fs/read")).await;
        assert!(resp.inner.is_ok(), "allowed tool should succeed");
    }

    #[tokio::test]
    async fn test_filter_pass_all_allows_everything() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![tools_only("fs/", NameFilter::PassAll)];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, list_tools_request()).await;
        assert_eq!(tool_names(resp.inner.unwrap()).len(), 3);
    }

    #[tokio::test]
    async fn test_filter_unmatched_namespace_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, list_tools_request()).await;
        let names = tool_names(resp.inner.unwrap());
        assert_eq!(names.len(), 1, "unmatched namespace should pass");
        assert_eq!(names[0], "db/query");
    }

    #[tokio::test]
    async fn test_filter_hide_destructive() {
        let filters = vec![fs_annotation_filter(true, false)];
        let mut svc = CapabilityFilterService::new(mock_with_annotated_tools(), filters);

        let resp = call_service(&mut svc, list_tools_request()).await;
        let names = tool_names(resp.inner.unwrap());
        assert!(listed(&names, "fs/read_file"));
        assert!(listed(&names, "fs/write_file"));
        assert!(
            !listed(&names, "fs/delete_file"),
            "destructive tool should be hidden"
        );
    }

    #[tokio::test]
    async fn test_filter_read_only_only() {
        let filters = vec![fs_annotation_filter(false, true)];
        let mut svc = CapabilityFilterService::new(mock_with_annotated_tools(), filters);

        let resp = call_service(&mut svc, list_tools_request()).await;
        let names = tool_names(resp.inner.unwrap());
        assert!(listed(&names, "fs/read_file"), "read-only tool visible");
        assert!(!listed(&names, "fs/delete_file"), "non-read-only hidden");
        assert!(!listed(&names, "fs/write_file"), "non-read-only hidden");
    }

    // ----- SearchModeFilterService -----------------------------------------

    #[tokio::test]
    async fn test_search_mode_only_shows_prefix_tools() {
        let mock = MockService::with_tools(&[
            "proxy/search_tools",
            "proxy/call_tool",
            "proxy/tool_categories",
            "fs/read",
            "fs/write",
            "db/query",
        ]);
        let mut svc = SearchModeFilterService::new(mock, "proxy/");

        let resp = call_service(&mut svc, list_tools_request()).await;
        let names = tool_names(resp.inner.unwrap());
        assert_eq!(names.len(), 3, "only proxy/ tools should be listed");
        assert!(listed(&names, "proxy/search_tools"));
        assert!(listed(&names, "proxy/call_tool"));
        assert!(listed(&names, "proxy/tool_categories"));
        assert!(!listed(&names, "fs/read"));
        assert!(!listed(&names, "db/query"));
    }

    #[tokio::test]
    async fn test_search_mode_allows_call_tool_for_backend() {
        let mock = MockService::with_tools(&["proxy/call_tool", "fs/read"]);
        let mut svc = SearchModeFilterService::new(mock, "proxy/");

        // Search mode hides backend tools from listings only; calls go through.
        let resp = call_service(&mut svc, call_tool_request("fs/read")).await;
        assert!(
            resp.inner.is_ok(),
            "search mode should not block CallTool requests"
        );
    }

    #[tokio::test]
    async fn test_search_mode_no_proxy_tools_returns_empty() {
        let mock = MockService::with_tools(&["fs/read", "db/query"]);
        let mut svc = SearchModeFilterService::new(mock, "proxy/");

        let resp = call_service(&mut svc, list_tools_request()).await;
        assert!(
            tool_names(resp.inner.unwrap()).is_empty(),
            "no proxy/ tools means empty list"
        );
    }
}