Skip to main content

mcp_proxy/
filter.rs

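//! Capability filtering middleware for the proxy.
//!
//! Wraps a `Service<RouterRequest>` and applies per-backend allow/deny lists from config:
//! filtered tools, resources, resource templates, and prompts are removed from list
//! responses, and direct call/read/get requests that target them are rejected.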
5
6use std::convert::Infallible;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use tower::Service;
13
14use tower_mcp::protocol::{McpRequest, McpResponse};
15use tower_mcp::{RouterRequest, RouterResponse};
16use tower_mcp_types::JsonRpcError;
17
18use crate::config::BackendFilter;
19
20/// Middleware that filters capabilities from proxy responses.
21#[derive(Clone)]
22pub struct CapabilityFilterService<S> {
23    inner: S,
24    filters: Arc<Vec<BackendFilter>>,
25}
26
27impl<S> CapabilityFilterService<S> {
28    /// Create a new capability filter service with the given filter rules.
29    pub fn new(inner: S, filters: Vec<BackendFilter>) -> Self {
30        Self {
31            inner,
32            filters: Arc::new(filters),
33        }
34    }
35}
36
37impl<S> Service<RouterRequest> for CapabilityFilterService<S>
38where
39    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
40        + Clone
41        + Send
42        + 'static,
43    S::Future: Send,
44{
45    type Response = RouterResponse;
46    type Error = Infallible;
47    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
48
49    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50        self.inner.poll_ready(cx)
51    }
52
53    fn call(&mut self, req: RouterRequest) -> Self::Future {
54        let filters = Arc::clone(&self.filters);
55        let request_id = req.id.clone();
56
57        // Check if this is a call/read/get for a filtered capability
58        match &req.inner {
59            McpRequest::CallTool(params) => {
60                if let Some(reason) = check_tool_denied(&filters, &params.name) {
61                    return Box::pin(async move {
62                        Ok(RouterResponse {
63                            id: request_id,
64                            inner: Err(JsonRpcError::invalid_params(reason)),
65                        })
66                    });
67                }
68            }
69            McpRequest::ReadResource(params) => {
70                if let Some(reason) = check_resource_denied(&filters, &params.uri) {
71                    return Box::pin(async move {
72                        Ok(RouterResponse {
73                            id: request_id,
74                            inner: Err(JsonRpcError::invalid_params(reason)),
75                        })
76                    });
77                }
78            }
79            McpRequest::GetPrompt(params) => {
80                if let Some(reason) = check_prompt_denied(&filters, &params.name) {
81                    return Box::pin(async move {
82                        Ok(RouterResponse {
83                            id: request_id,
84                            inner: Err(JsonRpcError::invalid_params(reason)),
85                        })
86                    });
87                }
88            }
89            _ => {}
90        }
91
92        let fut = self.inner.call(req);
93
94        Box::pin(async move {
95            let mut resp = fut.await?;
96
97            // Filter list responses
98            if let Ok(ref mut mcp_resp) = resp.inner {
99                match mcp_resp {
100                    McpResponse::ListTools(result) => {
101                        result.tools.retain(|tool| {
102                            for f in filters.iter() {
103                                if let Some(local_name) = tool.name.strip_prefix(&f.namespace) {
104                                    return f.tool_filter.allows(local_name);
105                                }
106                            }
107                            true
108                        });
109                    }
110                    McpResponse::ListResources(result) => {
111                        result.resources.retain(|resource| {
112                            for f in filters.iter() {
113                                if let Some(local_uri) = resource.uri.strip_prefix(&f.namespace) {
114                                    return f.resource_filter.allows(local_uri);
115                                }
116                            }
117                            true
118                        });
119                    }
120                    McpResponse::ListResourceTemplates(result) => {
121                        result.resource_templates.retain(|template| {
122                            for f in filters.iter() {
123                                if let Some(local_uri) =
124                                    template.uri_template.strip_prefix(&f.namespace)
125                                {
126                                    return f.resource_filter.allows(local_uri);
127                                }
128                            }
129                            true
130                        });
131                    }
132                    McpResponse::ListPrompts(result) => {
133                        result.prompts.retain(|prompt| {
134                            for f in filters.iter() {
135                                if let Some(local_name) = prompt.name.strip_prefix(&f.namespace) {
136                                    return f.prompt_filter.allows(local_name);
137                                }
138                            }
139                            true
140                        });
141                    }
142                    _ => {}
143                }
144            }
145
146            Ok(resp)
147        })
148    }
149}
150
151/// Check if a namespaced tool name is denied by any filter.
152/// Returns Some(reason) if denied.
153fn check_tool_denied(filters: &[BackendFilter], namespaced_name: &str) -> Option<String> {
154    for f in filters {
155        if let Some(local_name) = namespaced_name.strip_prefix(&f.namespace) {
156            if !f.tool_filter.allows(local_name) {
157                return Some(format!("Tool not available: {}", namespaced_name));
158            }
159            return None;
160        }
161    }
162    None
163}
164
165/// Check if a namespaced resource URI is denied by any filter.
166fn check_resource_denied(filters: &[BackendFilter], namespaced_uri: &str) -> Option<String> {
167    for f in filters {
168        if let Some(local_uri) = namespaced_uri.strip_prefix(&f.namespace) {
169            if !f.resource_filter.allows(local_uri) {
170                return Some(format!("Resource not available: {}", namespaced_uri));
171            }
172            return None;
173        }
174    }
175    None
176}
177
178/// Check if a namespaced prompt name is denied by any filter.
179fn check_prompt_denied(filters: &[BackendFilter], namespaced_name: &str) -> Option<String> {
180    for f in filters {
181        if let Some(local_name) = namespaced_name.strip_prefix(&f.namespace) {
182            if !f.prompt_filter.allows(local_name) {
183                return Some(format!("Prompt not available: {}", namespaced_name));
184            }
185            return None;
186        }
187    }
188    None
189}
190
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest, McpResponse};

    use super::CapabilityFilterService;
    use crate::config::{BackendFilter, NameFilter};
    use crate::test_util::{MockService, call_service};

    /// Filter that allow-lists `tools` under `namespace` and passes all resources/prompts.
    fn allow_filter(namespace: &str, tools: &[&str]) -> BackendFilter {
        BackendFilter {
            namespace: namespace.to_string(),
            tool_filter: NameFilter::AllowList(tools.iter().map(|s| s.to_string()).collect()),
            resource_filter: NameFilter::PassAll,
            prompt_filter: NameFilter::PassAll,
        }
    }

    /// Filter that deny-lists `tools` under `namespace` and passes all resources/prompts.
    fn deny_filter(namespace: &str, tools: &[&str]) -> BackendFilter {
        BackendFilter {
            namespace: namespace.to_string(),
            tool_filter: NameFilter::DenyList(tools.iter().map(|s| s.to_string()).collect()),
            resource_filter: NameFilter::PassAll,
            prompt_filter: NameFilter::PassAll,
        }
    }

    /// Builds a `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_filter_allow_list_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![allow_filter("fs/", &["read", "write"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"fs/read"));
                assert!(names.contains(&"fs/write"));
                assert!(!names.contains(&"fs/delete"), "delete should be filtered");
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_filter_deny_list_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![deny_filter("fs/", &["delete"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"fs/read"));
                assert!(names.contains(&"fs/write"));
                assert!(!names.contains(&"fs/delete"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_filter_denies_call_to_hidden_tool() {
        let mock = MockService::with_tools(&["fs/read", "fs/delete"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, call_tool("fs/delete")).await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("not available"),
            "should deny: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_filter_allows_call_to_permitted_tool() {
        let mock = MockService::with_tools(&["fs/read"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, call_tool("fs/read")).await;

        assert!(resp.inner.is_ok(), "allowed tool should succeed");
    }

    #[tokio::test]
    async fn test_filter_deny_list_blocks_call() {
        let mock = MockService::with_tools(&["fs/read", "fs/delete"]);
        let filters = vec![deny_filter("fs/", &["delete"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, call_tool("fs/delete")).await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("fs/delete"),
            "error should name the denied tool: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_filter_unmatched_namespace_call_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, call_tool("db/query")).await;

        assert!(
            resp.inner.is_ok(),
            "call outside every filtered namespace should pass"
        );
    }

    #[tokio::test]
    async fn test_filter_pass_all_allows_everything() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let filters = vec![BackendFilter {
            namespace: "fs/".to_string(),
            tool_filter: NameFilter::PassAll,
            resource_filter: NameFilter::PassAll,
            prompt_filter: NameFilter::PassAll,
        }];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                assert_eq!(result.tools.len(), 3);
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_filter_unmatched_namespace_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let filters = vec![allow_filter("fs/", &["read"])];
        let mut svc = CapabilityFilterService::new(mock, filters);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                assert_eq!(result.tools.len(), 1, "unmatched namespace should pass");
                assert_eq!(result.tools[0].name, "db/query");
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }
}