1use std::convert::Infallible;
103use std::future::Future;
104use std::pin::Pin;
105use std::sync::Arc;
106use std::task::{Context, Poll};
107
108use tower::{Layer, Service};
109
110use tower_mcp::protocol::{McpRequest, McpResponse};
111use tower_mcp::{RouterRequest, RouterResponse};
112use tower_mcp_types::JsonRpcError;
113
114use crate::config::BackendFilter;
115
116#[derive(Clone)]
129pub struct CapabilityFilterLayer {
130 filters: Vec<BackendFilter>,
131}
132
133impl CapabilityFilterLayer {
134 pub fn new(filters: Vec<BackendFilter>) -> Self {
136 Self { filters }
137 }
138}
139
140impl<S> Layer<S> for CapabilityFilterLayer {
141 type Service = CapabilityFilterService<S>;
142
143 fn layer(&self, inner: S) -> Self::Service {
144 CapabilityFilterService::new(inner, self.filters.clone())
145 }
146}
147
148#[derive(Clone)]
150pub struct CapabilityFilterService<S> {
151 inner: S,
152 filters: Arc<Vec<BackendFilter>>,
153}
154
155impl<S> CapabilityFilterService<S> {
156 pub fn new(inner: S, filters: Vec<BackendFilter>) -> Self {
158 Self {
159 inner,
160 filters: Arc::new(filters),
161 }
162 }
163}
164
165impl<S> Service<RouterRequest> for CapabilityFilterService<S>
166where
167 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
168 + Clone
169 + Send
170 + 'static,
171 S::Future: Send,
172{
173 type Response = RouterResponse;
174 type Error = Infallible;
175 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
176
177 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
178 self.inner.poll_ready(cx)
179 }
180
181 fn call(&mut self, req: RouterRequest) -> Self::Future {
182 let filters = Arc::clone(&self.filters);
183 let request_id = req.id.clone();
184
185 match &req.inner {
187 McpRequest::CallTool(params) => {
188 if let Some(reason) = check_tool_denied(&filters, ¶ms.name) {
189 return Box::pin(async move {
190 Ok(RouterResponse {
191 id: request_id,
192 inner: Err(JsonRpcError::invalid_params(reason)),
193 })
194 });
195 }
196 }
197 McpRequest::ReadResource(params) => {
198 if let Some(reason) = check_resource_denied(&filters, ¶ms.uri) {
199 return Box::pin(async move {
200 Ok(RouterResponse {
201 id: request_id,
202 inner: Err(JsonRpcError::invalid_params(reason)),
203 })
204 });
205 }
206 }
207 McpRequest::GetPrompt(params) => {
208 if let Some(reason) = check_prompt_denied(&filters, ¶ms.name) {
209 return Box::pin(async move {
210 Ok(RouterResponse {
211 id: request_id,
212 inner: Err(JsonRpcError::invalid_params(reason)),
213 })
214 });
215 }
216 }
217 _ => {}
218 }
219
220 let fut = self.inner.call(req);
221
222 Box::pin(async move {
223 let mut resp = fut.await?;
224
225 if let Ok(ref mut mcp_resp) = resp.inner {
227 match mcp_resp {
228 McpResponse::ListTools(result) => {
229 result.tools.retain(|tool| {
230 for f in filters.iter() {
231 if let Some(local_name) = tool.name.strip_prefix(&f.namespace) {
232 if !f.tool_filter.allows(local_name) {
233 return false;
234 }
235 if let Some(ref annotations) = tool.annotations {
237 if f.hide_destructive && annotations.destructive_hint {
238 return false;
239 }
240 if f.read_only_only && !annotations.read_only_hint {
241 return false;
242 }
243 } else if f.read_only_only {
244 return false;
246 }
247 return true;
248 }
249 }
250 true
251 });
252 }
253 McpResponse::ListResources(result) => {
254 result.resources.retain(|resource| {
255 for f in filters.iter() {
256 if let Some(local_uri) = resource.uri.strip_prefix(&f.namespace) {
257 return f.resource_filter.allows(local_uri);
258 }
259 }
260 true
261 });
262 }
263 McpResponse::ListResourceTemplates(result) => {
264 result.resource_templates.retain(|template| {
265 for f in filters.iter() {
266 if let Some(local_uri) =
267 template.uri_template.strip_prefix(&f.namespace)
268 {
269 return f.resource_filter.allows(local_uri);
270 }
271 }
272 true
273 });
274 }
275 McpResponse::ListPrompts(result) => {
276 result.prompts.retain(|prompt| {
277 for f in filters.iter() {
278 if let Some(local_name) = prompt.name.strip_prefix(&f.namespace) {
279 return f.prompt_filter.allows(local_name);
280 }
281 }
282 true
283 });
284 }
285 _ => {}
286 }
287 }
288
289 Ok(resp)
290 })
291 }
292}
293
294fn check_tool_denied(filters: &[BackendFilter], namespaced_name: &str) -> Option<String> {
297 for f in filters {
298 if let Some(local_name) = namespaced_name.strip_prefix(&f.namespace) {
299 if !f.tool_filter.allows(local_name) {
300 return Some(format!("Tool not available: {}", namespaced_name));
301 }
302 return None;
303 }
304 }
305 None
306}
307
308fn check_resource_denied(filters: &[BackendFilter], namespaced_uri: &str) -> Option<String> {
310 for f in filters {
311 if let Some(local_uri) = namespaced_uri.strip_prefix(&f.namespace) {
312 if !f.resource_filter.allows(local_uri) {
313 return Some(format!("Resource not available: {}", namespaced_uri));
314 }
315 return None;
316 }
317 }
318 None
319}
320
321fn check_prompt_denied(filters: &[BackendFilter], namespaced_name: &str) -> Option<String> {
323 for f in filters {
324 if let Some(local_name) = namespaced_name.strip_prefix(&f.namespace) {
325 if !f.prompt_filter.allows(local_name) {
326 return Some(format!("Prompt not available: {}", namespaced_name));
327 }
328 return None;
329 }
330 }
331 None
332}
333
/// Tower layer that wraps a service in a [`SearchModeFilterService`].
///
/// Every service built by this layer receives a copy of the same tool-name
/// prefix.
#[derive(Clone)]
pub struct SearchModeFilterLayer {
    /// Tool-name prefix handed to each service this layer creates.
    prefix: String,
}

impl SearchModeFilterLayer {
    /// Creates a layer for `prefix`; anything convertible into a `String` is
    /// accepted.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        Self { prefix }
    }
}
353
354impl<S> Layer<S> for SearchModeFilterLayer {
355 type Service = SearchModeFilterService<S>;
356
357 fn layer(&self, inner: S) -> Self::Service {
358 SearchModeFilterService {
359 inner,
360 prefix: self.prefix.clone(),
361 }
362 }
363}
364
/// Service wrapper that pairs a wrapped service with a tool-name prefix.
///
/// Its `Service` implementation uses the prefix to trim tool listings down
/// to the tools whose names start with it.
#[derive(Clone)]
pub struct SearchModeFilterService<S> {
    /// The wrapped service.
    inner: S,
    /// Prefix a tool name must start with to stay in a tool listing.
    prefix: String,
}

impl<S> SearchModeFilterService<S> {
    /// Wraps `inner`, keeping `prefix` (anything convertible into a `String`)
    /// as the tool-name prefix.
    pub fn new(inner: S, prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        Self { inner, prefix }
    }
}
385
386impl<S> Service<RouterRequest> for SearchModeFilterService<S>
387where
388 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
389 + Clone
390 + Send
391 + 'static,
392 S::Future: Send,
393{
394 type Response = RouterResponse;
395 type Error = Infallible;
396 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
397
398 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
399 self.inner.poll_ready(cx)
400 }
401
402 fn call(&mut self, req: RouterRequest) -> Self::Future {
403 let prefix = self.prefix.clone();
404 let fut = self.inner.call(req);
405
406 Box::pin(async move {
407 let mut resp = fut.await?;
408
409 if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
410 result.tools.retain(|tool| tool.name.starts_with(&prefix));
411 }
412
413 Ok(resp)
414 })
415 }
416}
417
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest, McpResponse};

    use super::{CapabilityFilterService, SearchModeFilterService};
    use crate::config::{BackendFilter, NameFilter};
    use crate::test_util::{MockService, call_service};

    // ---------------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------------

    /// Filter for `namespace` with the given tool rule, pass-all resource and
    /// prompt rules, and both annotation switches off.
    fn backend_filter(namespace: &str, tool_filter: NameFilter) -> BackendFilter {
        BackendFilter {
            namespace: namespace.to_string(),
            tool_filter,
            resource_filter: NameFilter::PassAll,
            prompt_filter: NameFilter::PassAll,
            hide_destructive: false,
            read_only_only: false,
        }
    }

    /// Filter that only lets the listed tools of `namespace` through.
    fn allow_filter(namespace: &str, tools: &[&str]) -> BackendFilter {
        let rule = NameFilter::allow_list(tools.iter().map(|s| s.to_string())).unwrap();
        backend_filter(namespace, rule)
    }

    /// Filter that blocks the listed tools of `namespace`.
    fn deny_filter(namespace: &str, tools: &[&str]) -> BackendFilter {
        let rule = NameFilter::deny_list(tools.iter().map(|s| s.to_string())).unwrap();
        backend_filter(namespace, rule)
    }

    /// A `ListTools` request with default parameters.
    fn list_tools() -> McpRequest {
        McpRequest::ListTools(Default::default())
    }

    /// A `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Extracts the tool names from a `ListTools` response, panicking on any
    /// other response kind.
    fn tool_names(response: McpResponse) -> Vec<String> {
        match response {
            McpResponse::ListTools(result) => {
                result.tools.into_iter().map(|tool| tool.name).collect()
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    /// Whether `wanted` is among `names`.
    fn has(names: &[String], wanted: &str) -> bool {
        names.iter().any(|name| name.as_str() == wanted)
    }

    /// Mock backend exposing three `fs/` tools with distinct annotations:
    /// a read-only one, a destructive one, and a plain writer.
    fn mock_with_annotated_tools() -> MockService {
        use tower_mcp::protocol::ToolDefinition;
        use tower_mcp_types::protocol::ToolAnnotations;

        let tool = |name: &str,
                    description: &str,
                    read_only_hint: bool,
                    destructive_hint: bool,
                    idempotent_hint: bool| ToolDefinition {
            name: name.to_string(),
            title: None,
            description: Some(description.to_string()),
            input_schema: serde_json::json!({"type": "object"}),
            output_schema: None,
            icons: None,
            annotations: Some(ToolAnnotations {
                title: None,
                read_only_hint,
                destructive_hint,
                idempotent_hint,
                open_world_hint: false,
            }),
            execution: None,
            meta: None,
        };

        let tools = vec![
            tool("fs/read_file", "Read a file", true, false, true),
            tool("fs/delete_file", "Delete a file", false, true, false),
            tool("fs/write_file", "Write a file", false, false, true),
        ];
        MockService { tools }
    }

    // ---------------------------------------------------------------------
    // CapabilityFilterService
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_filter_allow_list_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![allow_filter("fs/", &["read", "write"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, list_tools()).await;
        let names = tool_names(resp.inner.unwrap());

        assert!(has(&names, "fs/read"));
        assert!(has(&names, "fs/write"));
        assert!(!has(&names, "fs/delete"), "delete should be filtered");
    }

    #[tokio::test]
    async fn test_filter_deny_list_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![deny_filter("fs/", &["delete"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, list_tools()).await;
        let names = tool_names(resp.inner.unwrap());

        assert!(has(&names, "fs/read"));
        assert!(has(&names, "fs/write"));
        assert!(!has(&names, "fs/delete"));
    }

    #[tokio::test]
    async fn test_filter_denies_call_to_hidden_tool() {
        let mock = MockService::with_tools(&["fs/read", "fs/delete"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, call_tool("fs/delete")).await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("not available"),
            "should deny: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_filter_allows_call_to_permitted_tool() {
        let mock = MockService::with_tools(&["fs/read"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, call_tool("fs/read")).await;

        assert!(resp.inner.is_ok(), "allowed tool should succeed");
    }

    #[tokio::test]
    async fn test_filter_pass_all_allows_everything() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![backend_filter("fs/", NameFilter::PassAll)];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, list_tools()).await;
        let names = tool_names(resp.inner.unwrap());

        assert_eq!(names.len(), 3);
    }

    #[tokio::test]
    async fn test_filter_unmatched_namespace_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, list_tools()).await;
        let names = tool_names(resp.inner.unwrap());

        assert_eq!(names.len(), 1, "unmatched namespace should pass");
        assert_eq!(names[0], "db/query");
    }

    #[tokio::test]
    async fn test_filter_hide_destructive() {
        let mock = mock_with_annotated_tools();
        let filters = vec![BackendFilter {
            hide_destructive: true,
            ..backend_filter("fs/", NameFilter::PassAll)
        }];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, list_tools()).await;
        let names = tool_names(resp.inner.unwrap());

        assert!(has(&names, "fs/read_file"));
        assert!(has(&names, "fs/write_file"));
        assert!(
            !has(&names, "fs/delete_file"),
            "destructive tool should be hidden"
        );
    }

    #[tokio::test]
    async fn test_filter_read_only_only() {
        let mock = mock_with_annotated_tools();
        let filters = vec![BackendFilter {
            read_only_only: true,
            ..backend_filter("fs/", NameFilter::PassAll)
        }];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, list_tools()).await;
        let names = tool_names(resp.inner.unwrap());

        assert!(has(&names, "fs/read_file"), "read-only tool visible");
        assert!(!has(&names, "fs/delete_file"), "non-read-only hidden");
        assert!(!has(&names, "fs/write_file"), "non-read-only hidden");
    }

    // ---------------------------------------------------------------------
    // SearchModeFilterService
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_search_mode_only_shows_prefix_tools() {
        let mock = MockService::with_tools(&[
            "proxy/search_tools",
            "proxy/call_tool",
            "proxy/tool_categories",
            "fs/read",
            "fs/write",
            "db/query",
        ]);
        let mut svc = SearchModeFilterService::new(mock, "proxy/");

        let resp = call_service(&mut svc, list_tools()).await;
        let names = tool_names(resp.inner.unwrap());

        assert_eq!(names.len(), 3, "only proxy/ tools should be listed");
        assert!(has(&names, "proxy/search_tools"));
        assert!(has(&names, "proxy/call_tool"));
        assert!(has(&names, "proxy/tool_categories"));
        assert!(!has(&names, "fs/read"));
        assert!(!has(&names, "db/query"));
    }

    #[tokio::test]
    async fn test_search_mode_allows_call_tool_for_backend() {
        let mock = MockService::with_tools(&["proxy/call_tool", "fs/read"]);
        let mut svc = SearchModeFilterService::new(mock, "proxy/");

        // Only listings are trimmed; a direct call to a backend tool must
        // still reach the wrapped service.
        let resp = call_service(&mut svc, call_tool("fs/read")).await;

        assert!(
            resp.inner.is_ok(),
            "search mode should not block CallTool requests"
        );
    }

    #[tokio::test]
    async fn test_search_mode_no_proxy_tools_returns_empty() {
        let mock = MockService::with_tools(&["fs/read", "db/query"]);
        let mut svc = SearchModeFilterService::new(mock, "proxy/");

        let resp = call_service(&mut svc, list_tools()).await;
        let names = tool_names(resp.inner.unwrap());

        assert!(names.is_empty(), "no proxy/ tools means empty list");
    }
}