Skip to main content

mcp_authorization/
server.rs

1use std::future::Future;
2
3use rmcp::{
4    handler::server::ServerHandler,
5    model::*,
6    service::{MaybeSendFuture, NotificationContext, RequestContext, RoleServer},
7    ErrorData as McpError,
8};
9use schemars::JsonSchema;
10
11use crate::metadata::AuthSchemaMetadata;
12use crate::provider::{AuthProvider, DenyByDefault};
13use crate::registry::AuthToolRegistry;
14
/// Typestate marker: no auth source has been chosen yet.
///
/// An `AuthorizedServer<S, NoAuth>` is a *builder* — it deliberately does **not**
/// implement [`ServerHandler`], so it cannot be served over any transport. You
/// must first choose how requests get an [`AuthContext`](crate::AuthContext):
/// [`with_auth`](AuthorizedServer::with_auth) (required for network transports)
/// or [`deny_by_default`](AuthorizedServer::deny_by_default) (ergonomic for
/// stdio/dev).
///
/// The marker carries no data, so it is freely copyable, comparable and
/// debug-printable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoAuth;
24
/// Typestate marker: an auth source `P` is wired in.
///
/// [`AuthorizedServer`] implements [`ServerHandler`] only in this state, and only
/// when `P: AuthProvider`. The struct itself carries no bound; each `impl` states
/// what it needs. The field is `pub(crate)`, so only this crate can construct the
/// marker (via `with_auth` / `deny_by_default`).
///
/// `Debug`, `Clone` and `Copy` are available whenever the provider `P` has them.
#[derive(Debug, Clone, Copy)]
pub struct Authorized<P>(pub(crate) P);
28
29/// A `ServerHandler` wrapper that adds authorization-based schema shaping.
30///
31/// Wraps any inner `ServerHandler` and intercepts `list_tools` / `call_tool` to
32/// filter tools and shape schemas based on the request's
33/// [`AuthContext`](crate::AuthContext).
34///
35/// # Compile-time guarantee
36///
37/// In the spirit of [`Proof<C>`](crate::Proof) — which makes skipping a
38/// capability check *uncompilable* — forgetting to wire authentication is a
39/// **build error**, not a runtime panic. `ServerHandler` is implemented only for
40/// `AuthorizedServer<S, Authorized<P>>`, so a server with no auth source chosen
41/// cannot be served:
42///
43/// ```compile_fail
44/// use mcp_authorization::AuthorizedServer;
45/// use rmcp::handler::server::ServerHandler;
46/// use rmcp::model::{ServerInfo, ServerCapabilities, Implementation};
47///
48/// struct Inner;
49/// impl ServerHandler for Inner {
50///     fn get_info(&self) -> ServerInfo {
51///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
52///             .with_server_info(Implementation::new("inner", "0.0.0"))
53///     }
54/// }
55/// fn requires_handler<T: ServerHandler>(_: T) {}
56///
57/// // No auth source chosen → not a ServerHandler → does not compile.
58/// requires_handler(AuthorizedServer::new(Inner));
59/// ```
60///
61/// Choosing an auth source makes it servable:
62///
63/// ```
64/// use mcp_authorization::AuthorizedServer;
65/// use rmcp::handler::server::ServerHandler;
66/// use rmcp::model::{ServerInfo, ServerCapabilities, Implementation};
67///
68/// struct Inner;
69/// impl ServerHandler for Inner {
70///     fn get_info(&self) -> ServerInfo {
71///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
72///             .with_server_info(Implementation::new("inner", "0.0.0"))
73///     }
74/// }
75/// fn requires_handler<T: ServerHandler>(_: T) {}
76///
77/// // deny_by_default() (or with_auth(..)) yields a real ServerHandler.
78/// requires_handler(AuthorizedServer::new(Inner).deny_by_default());
79/// ```
80pub struct AuthorizedServer<S: ServerHandler, A = NoAuth> {
81    inner: S,
82    registry: AuthToolRegistry,
83    auth: A,
84}
85
86impl<S: ServerHandler> AuthorizedServer<S, NoAuth> {
87    /// Start building an authorized server. No auth source is chosen yet, so the
88    /// result is not yet a [`ServerHandler`] — call
89    /// [`with_auth`](Self::with_auth) or [`deny_by_default`](Self::deny_by_default).
90    pub fn new(inner: S) -> Self {
91        Self {
92            inner,
93            registry: AuthToolRegistry::new(),
94            auth: NoAuth,
95        }
96    }
97}
98
99// Builder operations available in any auth state.
100impl<S: ServerHandler, A> AuthorizedServer<S, A> {
101    /// Register a tool with typed input/output for schema generation and
102    /// authorization metadata.
103    pub fn register<I, O>(
104        mut self,
105        name: impl Into<String>,
106        description: impl Into<String>,
107    ) -> Self
108    where
109        I: JsonSchema + AuthSchemaMetadata + serde::de::DeserializeOwned + 'static,
110        O: JsonSchema + AuthSchemaMetadata + serde::Serialize + 'static,
111    {
112        self.registry.register_typed::<I, O>(name, description);
113        self
114    }
115
116    /// Set tool-level authorization for a named tool. The tool is hidden from
117    /// `list_tools` if the request's `AuthContext` lacks this capability.
118    pub fn authorize(mut self, tool_name: &str, capability: &'static str) -> Self {
119        self.registry.set_authorization(tool_name, capability);
120        self
121    }
122
123    /// Choose an explicit auth source. **Required before serving any network
124    /// transport.** Transitions the server into the servable
125    /// [`Authorized`] state.
126    ///
127    /// `provider` is anything implementing [`AuthProvider`], including a closure
128    /// `Fn(&RequestContext<RoleServer>) -> AuthContext`.
129    pub fn with_auth<P: AuthProvider>(self, provider: P) -> AuthorizedServer<S, Authorized<P>> {
130        AuthorizedServer {
131            inner: self.inner,
132            registry: self.registry,
133            auth: Authorized(provider),
134        }
135    }
136
137    /// Install [`DenyByDefault`]: use an `AuthContext` injected by middleware if
138    /// present, otherwise resolve to [`AuthContext::empty`](crate::AuthContext::empty).
139    ///
140    /// The ergonomic choice for stdio / local / dev: the server is immediately
141    /// servable and an unauthenticated client sees the least-privileged view
142    /// (only ungated tools) rather than an error.
143    pub fn deny_by_default(self) -> AuthorizedServer<S, Authorized<DenyByDefault>> {
144        self.with_auth(DenyByDefault)
145    }
146
147    /// Get a reference to the inner handler.
148    pub fn inner(&self) -> &S {
149        &self.inner
150    }
151
152    /// Get a reference to the authorization registry.
153    pub fn registry(&self) -> &AuthToolRegistry {
154        &self.registry
155    }
156}
157
158/// Sealed marker for "this server has an auth source and may be served".
159///
160/// Implemented only for [`Authorized`]. Use it as a bound (e.g. in your own
161/// serve helpers) to require — at compile time, with a readable error — that an
162/// auth source was chosen before serving.
163#[diagnostic::on_unimplemented(
164    message = "this `AuthorizedServer` has no auth source, so it cannot be served",
165    note = "call `.with_auth(provider)` (required before any network transport), \
166            or `.deny_by_default()` for stdio/local/dev (least-privilege unless \
167            middleware injects an AuthContext)"
168)]
169pub trait ReadyToServe {}
170impl<P: AuthProvider> ReadyToServe for Authorized<P> {}
171
172impl<S: ServerHandler, P: AuthProvider + 'static> ServerHandler for AuthorizedServer<S, Authorized<P>> {
173    fn get_info(&self) -> ServerInfo {
174        self.inner.get_info()
175    }
176
177    fn list_tools(
178        &self,
179        _request: Option<PaginatedRequestParams>,
180        context: RequestContext<RoleServer>,
181    ) -> impl Future<Output = Result<ListToolsResult, McpError>> + MaybeSendFuture + '_ {
182        async move {
183            // Resolve via the provider — never errors. No context → deny-by-default.
184            let auth = self.auth.0.auth_for(&context);
185            let tools = self.registry.materialize(&auth);
186            Ok(ListToolsResult {
187                tools,
188                ..Default::default()
189            })
190        }
191    }
192
193    fn call_tool(
194        &self,
195        request: CallToolRequestParams,
196        context: RequestContext<RoleServer>,
197    ) -> impl Future<Output = Result<CallToolResult, McpError>> + MaybeSendFuture + '_ {
198        async move {
199            // Always enforce tool-level visibility against the resolved context
200            // (previously skipped when no context was present — a silent gap).
201            let auth = self.auth.0.auth_for(&context);
202            if !self.registry.is_visible(&request.name, &auth) {
203                return Err(McpError::new(
204                    ErrorCode::METHOD_NOT_FOUND,
205                    format!("tool not found: {}", request.name),
206                    None,
207                ));
208            }
209            self.inner.call_tool(request, context).await
210        }
211    }
212
213    // --- Delegate everything else to inner ---
214
215    fn initialize(
216        &self,
217        request: InitializeRequestParams,
218        context: RequestContext<RoleServer>,
219    ) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ {
220        self.inner.initialize(request, context)
221    }
222
223    fn ping(
224        &self,
225        context: RequestContext<RoleServer>,
226    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
227        self.inner.ping(context)
228    }
229
230    fn complete(
231        &self,
232        request: CompleteRequestParams,
233        context: RequestContext<RoleServer>,
234    ) -> impl Future<Output = Result<CompleteResult, McpError>> + MaybeSendFuture + '_ {
235        self.inner.complete(request, context)
236    }
237
238    fn set_level(
239        &self,
240        request: SetLevelRequestParams,
241        context: RequestContext<RoleServer>,
242    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
243        self.inner.set_level(request, context)
244    }
245
246    fn get_prompt(
247        &self,
248        request: GetPromptRequestParams,
249        context: RequestContext<RoleServer>,
250    ) -> impl Future<Output = Result<GetPromptResult, McpError>> + MaybeSendFuture + '_ {
251        self.inner.get_prompt(request, context)
252    }
253
254    fn list_prompts(
255        &self,
256        request: Option<PaginatedRequestParams>,
257        context: RequestContext<RoleServer>,
258    ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + MaybeSendFuture + '_ {
259        self.inner.list_prompts(request, context)
260    }
261
262    fn list_resources(
263        &self,
264        request: Option<PaginatedRequestParams>,
265        context: RequestContext<RoleServer>,
266    ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + MaybeSendFuture + '_ {
267        self.inner.list_resources(request, context)
268    }
269
270    fn list_resource_templates(
271        &self,
272        request: Option<PaginatedRequestParams>,
273        context: RequestContext<RoleServer>,
274    ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + MaybeSendFuture + '_
275    {
276        self.inner.list_resource_templates(request, context)
277    }
278
279    fn read_resource(
280        &self,
281        request: ReadResourceRequestParams,
282        context: RequestContext<RoleServer>,
283    ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + MaybeSendFuture + '_ {
284        self.inner.read_resource(request, context)
285    }
286
287    fn subscribe(
288        &self,
289        request: SubscribeRequestParams,
290        context: RequestContext<RoleServer>,
291    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
292        self.inner.subscribe(request, context)
293    }
294
295    fn unsubscribe(
296        &self,
297        request: UnsubscribeRequestParams,
298        context: RequestContext<RoleServer>,
299    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
300        self.inner.unsubscribe(request, context)
301    }
302
303    fn get_tool(&self, name: &str) -> Option<Tool> {
304        self.inner.get_tool(name)
305    }
306
307    fn on_custom_request(
308        &self,
309        request: CustomRequest,
310        context: RequestContext<RoleServer>,
311    ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ {
312        self.inner.on_custom_request(request, context)
313    }
314
315    fn on_cancelled(
316        &self,
317        notification: CancelledNotificationParam,
318        context: NotificationContext<RoleServer>,
319    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
320        self.inner.on_cancelled(notification, context)
321    }
322
323    fn on_progress(
324        &self,
325        notification: ProgressNotificationParam,
326        context: NotificationContext<RoleServer>,
327    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
328        self.inner.on_progress(notification, context)
329    }
330
331    fn on_initialized(
332        &self,
333        context: NotificationContext<RoleServer>,
334    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
335        self.inner.on_initialized(context)
336    }
337
338    fn on_roots_list_changed(
339        &self,
340        context: NotificationContext<RoleServer>,
341    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
342        self.inner.on_roots_list_changed(context)
343    }
344
345    fn on_custom_notification(
346        &self,
347        notification: CustomNotification,
348        context: NotificationContext<RoleServer>,
349    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
350        self.inner.on_custom_notification(notification, context)
351    }
352
353    fn enqueue_task(
354        &self,
355        request: CallToolRequestParams,
356        context: RequestContext<RoleServer>,
357    ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + MaybeSendFuture + '_ {
358        self.inner.enqueue_task(request, context)
359    }
360
361    fn list_tasks(
362        &self,
363        request: Option<PaginatedRequestParams>,
364        context: RequestContext<RoleServer>,
365    ) -> impl Future<Output = Result<ListTasksResult, McpError>> + MaybeSendFuture + '_ {
366        self.inner.list_tasks(request, context)
367    }
368
369    fn get_task_info(
370        &self,
371        request: GetTaskInfoParams,
372        context: RequestContext<RoleServer>,
373    ) -> impl Future<Output = Result<GetTaskResult, McpError>> + MaybeSendFuture + '_ {
374        self.inner.get_task_info(request, context)
375    }
376
377    fn get_task_result(
378        &self,
379        request: GetTaskResultParams,
380        context: RequestContext<RoleServer>,
381    ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + MaybeSendFuture + '_ {
382        self.inner.get_task_result(request, context)
383    }
384
385    fn cancel_task(
386        &self,
387        request: CancelTaskParams,
388        context: RequestContext<RoleServer>,
389    ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + MaybeSendFuture + '_ {
390        self.inner.cancel_task(request, context)
391    }
392}