Skip to main content

mcp_authorization/
server.rs

1use std::future::Future;
2
3use rmcp::{
4    ErrorData as McpError,
5    handler::server::ServerHandler,
6    model::*,
7    service::{MaybeSendFuture, NotificationContext, RequestContext, RoleServer},
8};
9use schemars::JsonSchema;
10
11use crate::capability::AuthContext;
12use crate::metadata::AuthSchemaMetadata;
13use crate::registry::AuthToolRegistry;
14
15/// A `ServerHandler` wrapper that adds authorization-based schema shaping.
16///
17/// Wraps any inner `ServerHandler` and intercepts `list_tools` and
18/// `call_tool` to filter tools and shape schemas based on the user's
19/// `AuthContext` (stored in `RequestContext::extensions`).
20///
21/// ```ignore
22/// let server = AuthorizedServer::new(MyHandler)
23///     .register::<AdvanceStepInput, AdvanceStepOutput>(
24///         "advance_step",
25///         "Advance an applicant",
26///     )
27///     .authorize("advance_step", "manage_workflows");
28/// ```
29pub struct AuthorizedServer<S: ServerHandler> {
30    inner: S,
31    registry: AuthToolRegistry,
32}
33
34impl<S: ServerHandler> AuthorizedServer<S> {
35    pub fn new(inner: S) -> Self {
36        Self {
37            inner,
38            registry: AuthToolRegistry::new(),
39        }
40    }
41
42    /// Register a tool with typed input/output for schema generation
43    /// and authorization metadata.
44    pub fn register<I, O>(
45        mut self,
46        name: impl Into<String>,
47        description: impl Into<String>,
48    ) -> Self
49    where
50        I: JsonSchema + AuthSchemaMetadata + serde::de::DeserializeOwned + 'static,
51        O: JsonSchema + AuthSchemaMetadata + serde::Serialize + 'static,
52    {
53        self.registry
54            .register_typed::<I, O>(name, description);
55        self
56    }
57
58    /// Set tool-level authorization for a named tool.
59    /// The tool will be hidden from `list_tools` if the user lacks this capability.
60    pub fn authorize(mut self, tool_name: &str, capability: &'static str) -> Self {
61        self.registry.set_authorization(tool_name, capability);
62        self
63    }
64
65    /// Get a reference to the inner handler.
66    pub fn inner(&self) -> &S {
67        &self.inner
68    }
69
70    /// Get a reference to the authorization registry.
71    pub fn registry(&self) -> &AuthToolRegistry {
72        &self.registry
73    }
74}
75
76impl<S: ServerHandler> ServerHandler for AuthorizedServer<S> {
77    fn get_info(&self) -> ServerInfo {
78        self.inner.get_info()
79    }
80
81    fn list_tools(
82        &self,
83        _request: Option<PaginatedRequestParams>,
84        context: RequestContext<RoleServer>,
85    ) -> impl Future<Output = Result<ListToolsResult, McpError>> + MaybeSendFuture + '_ {
86        async move {
87            let auth = context
88                .extensions
89                .get::<AuthContext>()
90                .ok_or_else(|| McpError::internal_error("missing AuthContext in extensions", None))?;
91
92            let tools = self.registry.materialize(auth);
93            Ok(ListToolsResult {
94                tools,
95                ..Default::default()
96            })
97        }
98    }
99
100    fn call_tool(
101        &self,
102        request: CallToolRequestParams,
103        context: RequestContext<RoleServer>,
104    ) -> impl Future<Output = Result<CallToolResult, McpError>> + MaybeSendFuture + '_ {
105        async move {
106            // Check tool-level authorization before delegating
107            if let Some(auth) = context.extensions.get::<AuthContext>() {
108                if !self.registry.is_visible(&request.name, auth) {
109                    return Err(McpError::new(
110                        ErrorCode::METHOD_NOT_FOUND,
111                        format!("tool not found: {}", request.name),
112                        None,
113                    ));
114                }
115            }
116
117            // Delegate actual execution to the inner handler
118            self.inner.call_tool(request, context).await
119        }
120    }
121
122    // --- Delegate everything else to inner ---
123
124    fn initialize(
125        &self,
126        request: InitializeRequestParams,
127        context: RequestContext<RoleServer>,
128    ) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ {
129        self.inner.initialize(request, context)
130    }
131
132    fn ping(
133        &self,
134        context: RequestContext<RoleServer>,
135    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
136        self.inner.ping(context)
137    }
138
139    fn complete(
140        &self,
141        request: CompleteRequestParams,
142        context: RequestContext<RoleServer>,
143    ) -> impl Future<Output = Result<CompleteResult, McpError>> + MaybeSendFuture + '_ {
144        self.inner.complete(request, context)
145    }
146
147    fn set_level(
148        &self,
149        request: SetLevelRequestParams,
150        context: RequestContext<RoleServer>,
151    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
152        self.inner.set_level(request, context)
153    }
154
155    fn get_prompt(
156        &self,
157        request: GetPromptRequestParams,
158        context: RequestContext<RoleServer>,
159    ) -> impl Future<Output = Result<GetPromptResult, McpError>> + MaybeSendFuture + '_ {
160        self.inner.get_prompt(request, context)
161    }
162
163    fn list_prompts(
164        &self,
165        request: Option<PaginatedRequestParams>,
166        context: RequestContext<RoleServer>,
167    ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + MaybeSendFuture + '_ {
168        self.inner.list_prompts(request, context)
169    }
170
171    fn list_resources(
172        &self,
173        request: Option<PaginatedRequestParams>,
174        context: RequestContext<RoleServer>,
175    ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + MaybeSendFuture + '_ {
176        self.inner.list_resources(request, context)
177    }
178
179    fn list_resource_templates(
180        &self,
181        request: Option<PaginatedRequestParams>,
182        context: RequestContext<RoleServer>,
183    ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + MaybeSendFuture + '_
184    {
185        self.inner.list_resource_templates(request, context)
186    }
187
188    fn read_resource(
189        &self,
190        request: ReadResourceRequestParams,
191        context: RequestContext<RoleServer>,
192    ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + MaybeSendFuture + '_ {
193        self.inner.read_resource(request, context)
194    }
195
196    fn subscribe(
197        &self,
198        request: SubscribeRequestParams,
199        context: RequestContext<RoleServer>,
200    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
201        self.inner.subscribe(request, context)
202    }
203
204    fn unsubscribe(
205        &self,
206        request: UnsubscribeRequestParams,
207        context: RequestContext<RoleServer>,
208    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
209        self.inner.unsubscribe(request, context)
210    }
211
212    fn get_tool(&self, name: &str) -> Option<Tool> {
213        self.inner.get_tool(name)
214    }
215
216    fn on_custom_request(
217        &self,
218        request: CustomRequest,
219        context: RequestContext<RoleServer>,
220    ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ {
221        self.inner.on_custom_request(request, context)
222    }
223
224    fn on_cancelled(
225        &self,
226        notification: CancelledNotificationParam,
227        context: NotificationContext<RoleServer>,
228    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
229        self.inner.on_cancelled(notification, context)
230    }
231
232    fn on_progress(
233        &self,
234        notification: ProgressNotificationParam,
235        context: NotificationContext<RoleServer>,
236    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
237        self.inner.on_progress(notification, context)
238    }
239
240    fn on_initialized(
241        &self,
242        context: NotificationContext<RoleServer>,
243    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
244        self.inner.on_initialized(context)
245    }
246
247    fn on_roots_list_changed(
248        &self,
249        context: NotificationContext<RoleServer>,
250    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
251        self.inner.on_roots_list_changed(context)
252    }
253
254    fn on_custom_notification(
255        &self,
256        notification: CustomNotification,
257        context: NotificationContext<RoleServer>,
258    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
259        self.inner.on_custom_notification(notification, context)
260    }
261
262    fn enqueue_task(
263        &self,
264        request: CallToolRequestParams,
265        context: RequestContext<RoleServer>,
266    ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + MaybeSendFuture + '_ {
267        self.inner.enqueue_task(request, context)
268    }
269
270    fn list_tasks(
271        &self,
272        request: Option<PaginatedRequestParams>,
273        context: RequestContext<RoleServer>,
274    ) -> impl Future<Output = Result<ListTasksResult, McpError>> + MaybeSendFuture + '_ {
275        self.inner.list_tasks(request, context)
276    }
277
278    fn get_task_info(
279        &self,
280        request: GetTaskInfoParams,
281        context: RequestContext<RoleServer>,
282    ) -> impl Future<Output = Result<GetTaskResult, McpError>> + MaybeSendFuture + '_ {
283        self.inner.get_task_info(request, context)
284    }
285
286    fn get_task_result(
287        &self,
288        request: GetTaskResultParams,
289        context: RequestContext<RoleServer>,
290    ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + MaybeSendFuture + '_ {
291        self.inner.get_task_result(request, context)
292    }
293
294    fn cancel_task(
295        &self,
296        request: CancelTaskParams,
297        context: RequestContext<RoleServer>,
298    ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + MaybeSendFuture + '_ {
299        self.inner.cancel_task(request, context)
300    }
301}