Skip to main content

rmcp_openapi/
filter.rs

1//! Dynamic tool filtering based on request context.
2//!
3//! This module provides the [`ToolFilter`] trait for controlling which tools are
4//! visible and callable based on runtime context such as user permissions, scopes,
5//! or other request-specific criteria.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use rmcp_openapi::{ToolFilter, Tool};
11//! use rmcp::service::{RequestContext, RoleServer};
12//! use async_trait::async_trait;
13//!
14//! /// Filter that only allows read-only (GET) tools
15//! struct ReadOnlyFilter;
16//!
17//! #[async_trait]
18//! impl ToolFilter for ReadOnlyFilter {
19//!     async fn allow(&self, tool: &Tool, _context: &RequestContext<RoleServer>) -> bool {
20//!         tool.metadata.method == "GET"
21//!     }
22//! }
23//! ```
24
25use async_trait::async_trait;
26use rmcp::service::{RequestContext, RoleServer};
27
28use crate::tool::Tool;
29
30/// Trait for dynamically filtering tools based on request context.
31///
32/// Implement this to control which tools are visible and callable
33/// based on user permissions, scopes, or other runtime context.
34///
35/// # Usage
36///
37/// ```rust,ignore
38/// use std::sync::Arc;
39/// use rmcp_openapi::{Server, ToolFilter};
40///
41/// let server = Server::builder()
42///     .openapi_spec(spec)
43///     .base_url(url)
44///     .tool_filter(Arc::new(MyFilter::new()))
45///     .build();
46/// ```
47///
48/// # Behavior
49///
50/// - `list_tools`: Only returns tools where `allow` returns `true`
51/// - `call_tool`: Returns "tool not found" error if filter rejects the tool
52#[async_trait]
53pub trait ToolFilter: Send + Sync {
54    /// Returns true if the tool should be accessible in this context.
55    ///
56    /// Called for both `list_tools` (to filter visible tools) and
57    /// `call_tool` (to enforce access control).
58    ///
59    /// # Arguments
60    ///
61    /// * `tool` - The tool to check access for
62    /// * `context` - The request context containing extensions (e.g., user scopes)
63    ///
64    /// # Returns
65    ///
66    /// `true` if the tool should be accessible, `false` to hide/block it
67    async fn allow(&self, tool: &Tool, context: &RequestContext<RoleServer>) -> bool;
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Filter that allows all tools
    struct AllowAll;

    #[async_trait]
    impl ToolFilter for AllowAll {
        async fn allow(&self, _tool: &Tool, _context: &RequestContext<RoleServer>) -> bool {
            true
        }
    }

    /// Filter that blocks all tools
    struct BlockAll;

    #[async_trait]
    impl ToolFilter for BlockAll {
        async fn allow(&self, _tool: &Tool, _context: &RequestContext<RoleServer>) -> bool {
            false
        }
    }

    /// Filter based on tool name prefix
    struct PrefixFilter {
        allowed_prefix: String,
    }

    #[async_trait]
    impl ToolFilter for PrefixFilter {
        async fn allow(&self, tool: &Tool, _context: &RequestContext<RoleServer>) -> bool {
            tool.metadata.name.starts_with(&self.allowed_prefix)
        }
    }

    /// Compile-time assertion that `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync + ?Sized>() {}

    #[test]
    fn test_trait_is_object_safe() {
        // Verify the trait can be used with dynamic dispatch
        fn accepts_filter(_filter: &dyn ToolFilter) {}
        fn accepts_arc_filter(_filter: Arc<dyn ToolFilter>) {}

        let allow_all = AllowAll;
        let block_all = BlockAll;

        accepts_filter(&allow_all);
        accepts_filter(&block_all);
        accepts_arc_filter(Arc::new(AllowAll));
        accepts_arc_filter(Arc::new(BlockAll));
    }

    #[test]
    fn test_trait_objects_are_send_and_sync() {
        // The `Send + Sync` supertraits must carry over to trait objects so a
        // shared `Arc<dyn ToolFilter>` can be moved across threads/async tasks.
        assert_send_sync::<dyn ToolFilter>();
        assert_send_sync::<Arc<dyn ToolFilter>>();
    }

    #[test]
    fn test_filter_can_be_cloned_via_arc() {
        // Server derives Clone, so cloning must share the filter rather than copy it.
        let filter: Arc<dyn ToolFilter> = Arc::new(AllowAll);
        let cloned = Arc::clone(&filter);
        assert!(Arc::ptr_eq(&filter, &cloned));
        assert_eq!(Arc::strong_count(&filter), 2);
    }

    #[test]
    fn test_prefix_filter_can_be_constructed() {
        // Verify a stateful filter (PrefixFilter) can be used as a shared ToolFilter
        let filter: Arc<dyn ToolFilter> = Arc::new(PrefixFilter {
            allowed_prefix: "get".to_string(),
        });
        let cloned = Arc::clone(&filter);
        assert!(Arc::ptr_eq(&filter, &cloned));
        assert_eq!(Arc::strong_count(&filter), 2);
    }
}