// mcp_proxy/filter.rs

//! Capability filtering middleware for the proxy.
//!
//! Wraps a `Service<RouterRequest>` and filters tools, resources, and prompts
//! based on per-backend allow/deny lists from config.
5
6use std::convert::Infallible;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use tower::{Layer, Service};
13
14use tower_mcp::protocol::{McpRequest, McpResponse};
15use tower_mcp::{RouterRequest, RouterResponse};
16use tower_mcp_types::JsonRpcError;
17
18use crate::config::BackendFilter;
19
20/// Tower layer that produces a [`CapabilityFilterService`].
21///
22/// # Example
23///
24/// ```rust,ignore
25/// use tower::ServiceBuilder;
26/// use mcp_proxy::filter::CapabilityFilterLayer;
27///
28/// let service = ServiceBuilder::new()
29///     .layer(CapabilityFilterLayer::new(filters))
30///     .service(proxy);
31/// ```
32#[derive(Clone)]
33pub struct CapabilityFilterLayer {
34    filters: Vec<BackendFilter>,
35}
36
37impl CapabilityFilterLayer {
38    /// Create a new capability filter layer with the given filter rules.
39    pub fn new(filters: Vec<BackendFilter>) -> Self {
40        Self { filters }
41    }
42}
43
44impl<S> Layer<S> for CapabilityFilterLayer {
45    type Service = CapabilityFilterService<S>;
46
47    fn layer(&self, inner: S) -> Self::Service {
48        CapabilityFilterService::new(inner, self.filters.clone())
49    }
50}
51
52/// Middleware that filters capabilities from proxy responses.
53#[derive(Clone)]
54pub struct CapabilityFilterService<S> {
55    inner: S,
56    filters: Arc<Vec<BackendFilter>>,
57}
58
59impl<S> CapabilityFilterService<S> {
60    /// Create a new capability filter service with the given filter rules.
61    pub fn new(inner: S, filters: Vec<BackendFilter>) -> Self {
62        Self {
63            inner,
64            filters: Arc::new(filters),
65        }
66    }
67}
68
69impl<S> Service<RouterRequest> for CapabilityFilterService<S>
70where
71    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
72        + Clone
73        + Send
74        + 'static,
75    S::Future: Send,
76{
77    type Response = RouterResponse;
78    type Error = Infallible;
79    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
80
81    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
82        self.inner.poll_ready(cx)
83    }
84
85    fn call(&mut self, req: RouterRequest) -> Self::Future {
86        let filters = Arc::clone(&self.filters);
87        let request_id = req.id.clone();
88
89        // Check if this is a call/read/get for a filtered capability
90        match &req.inner {
91            McpRequest::CallTool(params) => {
92                if let Some(reason) = check_tool_denied(&filters, &params.name) {
93                    return Box::pin(async move {
94                        Ok(RouterResponse {
95                            id: request_id,
96                            inner: Err(JsonRpcError::invalid_params(reason)),
97                        })
98                    });
99                }
100            }
101            McpRequest::ReadResource(params) => {
102                if let Some(reason) = check_resource_denied(&filters, &params.uri) {
103                    return Box::pin(async move {
104                        Ok(RouterResponse {
105                            id: request_id,
106                            inner: Err(JsonRpcError::invalid_params(reason)),
107                        })
108                    });
109                }
110            }
111            McpRequest::GetPrompt(params) => {
112                if let Some(reason) = check_prompt_denied(&filters, &params.name) {
113                    return Box::pin(async move {
114                        Ok(RouterResponse {
115                            id: request_id,
116                            inner: Err(JsonRpcError::invalid_params(reason)),
117                        })
118                    });
119                }
120            }
121            _ => {}
122        }
123
124        let fut = self.inner.call(req);
125
126        Box::pin(async move {
127            let mut resp = fut.await?;
128
129            // Filter list responses
130            if let Ok(ref mut mcp_resp) = resp.inner {
131                match mcp_resp {
132                    McpResponse::ListTools(result) => {
133                        result.tools.retain(|tool| {
134                            for f in filters.iter() {
135                                if let Some(local_name) = tool.name.strip_prefix(&f.namespace) {
136                                    if !f.tool_filter.allows(local_name) {
137                                        return false;
138                                    }
139                                    // Annotation-based filtering
140                                    if let Some(ref annotations) = tool.annotations {
141                                        if f.hide_destructive && annotations.destructive_hint {
142                                            return false;
143                                        }
144                                        if f.read_only_only && !annotations.read_only_hint {
145                                            return false;
146                                        }
147                                    } else if f.read_only_only {
148                                        // No annotations = not known to be read-only
149                                        return false;
150                                    }
151                                    return true;
152                                }
153                            }
154                            true
155                        });
156                    }
157                    McpResponse::ListResources(result) => {
158                        result.resources.retain(|resource| {
159                            for f in filters.iter() {
160                                if let Some(local_uri) = resource.uri.strip_prefix(&f.namespace) {
161                                    return f.resource_filter.allows(local_uri);
162                                }
163                            }
164                            true
165                        });
166                    }
167                    McpResponse::ListResourceTemplates(result) => {
168                        result.resource_templates.retain(|template| {
169                            for f in filters.iter() {
170                                if let Some(local_uri) =
171                                    template.uri_template.strip_prefix(&f.namespace)
172                                {
173                                    return f.resource_filter.allows(local_uri);
174                                }
175                            }
176                            true
177                        });
178                    }
179                    McpResponse::ListPrompts(result) => {
180                        result.prompts.retain(|prompt| {
181                            for f in filters.iter() {
182                                if let Some(local_name) = prompt.name.strip_prefix(&f.namespace) {
183                                    return f.prompt_filter.allows(local_name);
184                                }
185                            }
186                            true
187                        });
188                    }
189                    _ => {}
190                }
191            }
192
193            Ok(resp)
194        })
195    }
196}
197
198/// Check if a namespaced tool name is denied by any filter.
199/// Returns Some(reason) if denied.
200fn check_tool_denied(filters: &[BackendFilter], namespaced_name: &str) -> Option<String> {
201    for f in filters {
202        if let Some(local_name) = namespaced_name.strip_prefix(&f.namespace) {
203            if !f.tool_filter.allows(local_name) {
204                return Some(format!("Tool not available: {}", namespaced_name));
205            }
206            return None;
207        }
208    }
209    None
210}
211
212/// Check if a namespaced resource URI is denied by any filter.
213fn check_resource_denied(filters: &[BackendFilter], namespaced_uri: &str) -> Option<String> {
214    for f in filters {
215        if let Some(local_uri) = namespaced_uri.strip_prefix(&f.namespace) {
216            if !f.resource_filter.allows(local_uri) {
217                return Some(format!("Resource not available: {}", namespaced_uri));
218            }
219            return None;
220        }
221    }
222    None
223}
224
225/// Check if a namespaced prompt name is denied by any filter.
226fn check_prompt_denied(filters: &[BackendFilter], namespaced_name: &str) -> Option<String> {
227    for f in filters {
228        if let Some(local_name) = namespaced_name.strip_prefix(&f.namespace) {
229            if !f.prompt_filter.allows(local_name) {
230                return Some(format!("Prompt not available: {}", namespaced_name));
231            }
232            return None;
233        }
234    }
235    None
236}
237
/// Tower layer that produces a [`SearchModeFilterService`].
///
/// When search mode is enabled, `ListTools` responses are filtered to only
/// include tools under the given namespace prefix (typically `"proxy/"`).
/// All other requests pass through unchanged -- `CallTool` requests for
/// backend tools still work, allowing `proxy/call_tool` to forward them.
///
/// The type implements [`Debug`](std::fmt::Debug) so that it can be logged or
/// embedded in other `Debug` types.
#[derive(Clone, Debug)]
pub struct SearchModeFilterLayer {
    /// Name prefix a tool must start with to stay in `ListTools` responses.
    prefix: String,
}

impl SearchModeFilterLayer {
    /// Create a new search mode filter that only lists tools matching `prefix`.
    ///
    /// `prefix` accepts anything convertible into a `String` (for example a
    /// `&str` literal or an owned `String`).
    pub fn new(prefix: impl Into<String>) -> Self {
        Self {
            prefix: prefix.into(),
        }
    }
}
257
258impl<S> Layer<S> for SearchModeFilterLayer {
259    type Service = SearchModeFilterService<S>;
260
261    fn layer(&self, inner: S) -> Self::Service {
262        SearchModeFilterService {
263            inner,
264            prefix: self.prefix.clone(),
265        }
266    }
267}
268
/// Middleware that filters `ListTools` responses to only show tools under
/// a specific namespace prefix.
///
/// Used by search mode to hide individual backend tools from tool listings
/// while keeping them callable through `proxy/call_tool`.
///
/// The type implements [`Debug`](std::fmt::Debug) whenever the wrapped
/// service does.
#[derive(Clone, Debug)]
pub struct SearchModeFilterService<S> {
    /// The wrapped service that requests are forwarded to.
    inner: S,
    /// Name prefix a tool must start with to stay in `ListTools` responses.
    prefix: String,
}

impl<S> SearchModeFilterService<S> {
    /// Create a new search mode filter service.
    ///
    /// `inner` is the service to wrap; `prefix` accepts anything convertible
    /// into a `String` and is the prefix listed tool names must start with.
    pub fn new(inner: S, prefix: impl Into<String>) -> Self {
        Self {
            inner,
            prefix: prefix.into(),
        }
    }
}
289
290impl<S> Service<RouterRequest> for SearchModeFilterService<S>
291where
292    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
293        + Clone
294        + Send
295        + 'static,
296    S::Future: Send,
297{
298    type Response = RouterResponse;
299    type Error = Infallible;
300    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
301
302    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303        self.inner.poll_ready(cx)
304    }
305
306    fn call(&mut self, req: RouterRequest) -> Self::Future {
307        let prefix = self.prefix.clone();
308        let fut = self.inner.call(req);
309
310        Box::pin(async move {
311            let mut resp = fut.await?;
312
313            if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
314                result.tools.retain(|tool| tool.name.starts_with(&prefix));
315            }
316
317            Ok(resp)
318        })
319    }
320}
321
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::CapabilityFilterService;
    use crate::config::{BackendFilter, NameFilter};
    use crate::test_util::{MockService, call_service};

    // --- Shared fixtures and helpers ---

    /// Build a filter for `namespace` with the given tool filter and
    /// annotation flags; resources and prompts always pass.
    fn tool_rule(
        namespace: &str,
        tool_filter: NameFilter,
        hide_destructive: bool,
        read_only_only: bool,
    ) -> BackendFilter {
        BackendFilter {
            namespace: namespace.to_string(),
            tool_filter,
            resource_filter: NameFilter::PassAll,
            prompt_filter: NameFilter::PassAll,
            hide_destructive,
            read_only_only,
        }
    }

    /// Filter that only allows the listed tools of `namespace`.
    fn allow_filter(namespace: &str, tools: &[&str]) -> BackendFilter {
        let allowed = NameFilter::allow_list(tools.iter().map(|s| s.to_string())).unwrap();
        tool_rule(namespace, allowed, false, false)
    }

    /// Filter that denies the listed tools of `namespace`.
    fn deny_filter(namespace: &str, tools: &[&str]) -> BackendFilter {
        let denied = NameFilter::deny_list(tools.iter().map(|s| s.to_string())).unwrap();
        tool_rule(namespace, denied, false, false)
    }

    /// `CallTool` request for `name` with empty arguments.
    fn call_tool_request(name: &str) -> McpRequest {
        McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Extract the tool names from a `ListTools` response, panicking on any
    /// other response kind.
    fn listed_tool_names(response: McpResponse) -> Vec<String> {
        match response {
            McpResponse::ListTools(result) => {
                result.tools.into_iter().map(|tool| tool.name).collect()
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    /// Whether `wanted` is among the listed tool names.
    fn has_tool(names: &[String], wanted: &str) -> bool {
        names.iter().any(|name| name.as_str() == wanted)
    }

    // --- Name-based filtering ---

    #[tokio::test]
    async fn test_filter_allow_list_tools() {
        let backend = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let rules = vec![allow_filter("fs/", &["read", "write"])];
        let mut svc = CapabilityFilterService::new(backend, rules);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        let names = listed_tool_names(resp.inner.unwrap());

        assert!(has_tool(&names, "fs/read"));
        assert!(has_tool(&names, "fs/write"));
        assert!(!has_tool(&names, "fs/delete"), "delete should be filtered");
    }

    #[tokio::test]
    async fn test_filter_deny_list_tools() {
        let backend = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let rules = vec![deny_filter("fs/", &["delete"])];
        let mut svc = CapabilityFilterService::new(backend, rules);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        let names = listed_tool_names(resp.inner.unwrap());

        assert!(has_tool(&names, "fs/read"));
        assert!(has_tool(&names, "fs/write"));
        assert!(!has_tool(&names, "fs/delete"));
    }

    #[tokio::test]
    async fn test_filter_denies_call_to_hidden_tool() {
        let backend = MockService::with_tools(&["fs/read", "fs/delete"]);
        let rules = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(backend, rules);

        let resp = call_service(&mut svc, call_tool_request("fs/delete")).await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("not available"),
            "should deny: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_filter_allows_call_to_permitted_tool() {
        let backend = MockService::with_tools(&["fs/read"]);
        let rules = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(backend, rules);

        let resp = call_service(&mut svc, call_tool_request("fs/read")).await;

        assert!(resp.inner.is_ok(), "allowed tool should succeed");
    }

    #[tokio::test]
    async fn test_filter_pass_all_allows_everything() {
        let backend = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let rules = vec![tool_rule("fs/", NameFilter::PassAll, false, false)];
        let mut svc = CapabilityFilterService::new(backend, rules);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        let names = listed_tool_names(resp.inner.unwrap());

        assert_eq!(names.len(), 3);
    }

    #[tokio::test]
    async fn test_filter_unmatched_namespace_passes_through() {
        let backend = MockService::with_tools(&["db/query"]);
        let rules = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(backend, rules);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        let names = listed_tool_names(resp.inner.unwrap());

        assert_eq!(names.len(), 1, "unmatched namespace should pass");
        assert_eq!(names[0], "db/query");
    }

    // --- Annotation-based filtering ---

    /// Create a mock service with tools that have annotations.
    fn mock_with_annotated_tools() -> MockService {
        use tower_mcp::protocol::ToolDefinition;
        use tower_mcp_types::protocol::ToolAnnotations;

        // (name, description, read_only, destructive, idempotent)
        let specs = [
            ("fs/read_file", "Read a file", true, false, true),
            ("fs/delete_file", "Delete a file", false, true, false),
            ("fs/write_file", "Write a file", false, false, true),
        ];

        let tools: Vec<ToolDefinition> = specs
            .iter()
            .map(
                |&(name, description, read_only, destructive, idempotent)| ToolDefinition {
                    name: name.to_string(),
                    title: None,
                    description: Some(description.to_string()),
                    input_schema: serde_json::json!({"type": "object"}),
                    output_schema: None,
                    icons: None,
                    annotations: Some(ToolAnnotations {
                        title: None,
                        read_only_hint: read_only,
                        destructive_hint: destructive,
                        idempotent_hint: idempotent,
                        open_world_hint: false,
                    }),
                    execution: None,
                    meta: None,
                },
            )
            .collect();

        MockService { tools }
    }

    #[tokio::test]
    async fn test_filter_hide_destructive() {
        let backend = mock_with_annotated_tools();
        let rules = vec![tool_rule("fs/", NameFilter::PassAll, true, false)];
        let mut svc = CapabilityFilterService::new(backend, rules);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        let names = listed_tool_names(resp.inner.unwrap());

        assert!(has_tool(&names, "fs/read_file"));
        assert!(has_tool(&names, "fs/write_file"));
        assert!(
            !has_tool(&names, "fs/delete_file"),
            "destructive tool should be hidden"
        );
    }

    #[tokio::test]
    async fn test_filter_read_only_only() {
        let backend = mock_with_annotated_tools();
        let rules = vec![tool_rule("fs/", NameFilter::PassAll, false, true)];
        let mut svc = CapabilityFilterService::new(backend, rules);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        let names = listed_tool_names(resp.inner.unwrap());

        assert!(has_tool(&names, "fs/read_file"), "read-only tool visible");
        assert!(!has_tool(&names, "fs/delete_file"), "non-read-only hidden");
        assert!(!has_tool(&names, "fs/write_file"), "non-read-only hidden");
    }

    // --- Search mode filtering ---

    #[tokio::test]
    async fn test_search_mode_only_shows_prefix_tools() {
        let backend = MockService::with_tools(&[
            "proxy/search_tools",
            "proxy/call_tool",
            "proxy/tool_categories",
            "fs/read",
            "fs/write",
            "db/query",
        ]);
        let mut svc = super::SearchModeFilterService::new(backend, "proxy/");

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        let names = listed_tool_names(resp.inner.unwrap());

        assert_eq!(names.len(), 3, "only proxy/ tools should be listed");
        assert!(has_tool(&names, "proxy/search_tools"));
        assert!(has_tool(&names, "proxy/call_tool"));
        assert!(has_tool(&names, "proxy/tool_categories"));
        assert!(!has_tool(&names, "fs/read"));
        assert!(!has_tool(&names, "db/query"));
    }

    #[tokio::test]
    async fn test_search_mode_allows_call_tool_for_backend() {
        let backend = MockService::with_tools(&["proxy/call_tool", "fs/read"]);
        let mut svc = super::SearchModeFilterService::new(backend, "proxy/");

        // CallTool requests should pass through regardless of namespace
        let resp = call_service(&mut svc, call_tool_request("fs/read")).await;

        assert!(
            resp.inner.is_ok(),
            "search mode should not block CallTool requests"
        );
    }

    #[tokio::test]
    async fn test_search_mode_no_proxy_tools_returns_empty() {
        let backend = MockService::with_tools(&["fs/read", "db/query"]);
        let mut svc = super::SearchModeFilterService::new(backend, "proxy/");

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        let names = listed_tool_names(resp.inner.unwrap());

        assert!(names.is_empty(), "no proxy/ tools means empty list");
    }
}