1use std::future::Future;
2
3use rmcp::{
4 handler::server::ServerHandler,
5 model::*,
6 service::{MaybeSendFuture, NotificationContext, RequestContext, RoleServer},
7 ErrorData as McpError,
8};
9use schemars::JsonSchema;
10
11use crate::metadata::AuthSchemaMetadata;
12use crate::provider::{AuthProvider, DenyByDefault};
13use crate::registry::AuthToolRegistry;
14
/// Type-state marker: no auth source has been attached yet.
///
/// `AuthorizedServer::new` starts in this state. The `ServerHandler` impl in this file is only
/// provided for `AuthorizedServer<_, Authorized<P>>`, so a server in this state cannot act as a
/// handler until `with_auth` or `deny_by_default` moves it to `Authorized<_>`.
pub struct NoAuth;
24
/// Type-state marker: an auth source `P` has been attached (via `with_auth` or
/// `deny_by_default`).
///
/// The provider is crate-private. The `ServerHandler` impl calls `auth_for` on it to resolve
/// the caller's auth for per-request decisions (e.g. in `list_tools` and `call_tool`).
pub struct Authorized<P: AuthProvider>(pub(crate) P);
28
29pub struct AuthorizedServer<S: ServerHandler, A = NoAuth> {
81 inner: S,
82 registry: AuthToolRegistry,
83 auth: A,
84}
85
86impl<S: ServerHandler> AuthorizedServer<S, NoAuth> {
87 pub fn new(inner: S) -> Self {
91 Self {
92 inner,
93 registry: AuthToolRegistry::new(),
94 auth: NoAuth,
95 }
96 }
97}
98
99impl<S: ServerHandler, A> AuthorizedServer<S, A> {
101 pub fn register<I, O>(
104 mut self,
105 name: impl Into<String>,
106 description: impl Into<String>,
107 ) -> Self
108 where
109 I: JsonSchema + AuthSchemaMetadata + serde::de::DeserializeOwned + 'static,
110 O: JsonSchema + AuthSchemaMetadata + serde::Serialize + 'static,
111 {
112 self.registry.register_typed::<I, O>(name, description);
113 self
114 }
115
116 pub fn authorize(mut self, tool_name: &str, capability: &'static str) -> Self {
119 self.registry.set_authorization(tool_name, capability);
120 self
121 }
122
123 pub fn with_auth<P: AuthProvider>(self, provider: P) -> AuthorizedServer<S, Authorized<P>> {
130 AuthorizedServer {
131 inner: self.inner,
132 registry: self.registry,
133 auth: Authorized(provider),
134 }
135 }
136
137 pub fn deny_by_default(self) -> AuthorizedServer<S, Authorized<DenyByDefault>> {
144 self.with_auth(DenyByDefault)
145 }
146
147 pub fn inner(&self) -> &S {
149 &self.inner
150 }
151
152 pub fn registry(&self) -> &AuthToolRegistry {
154 &self.registry
155 }
156}
157
158#[diagnostic::on_unimplemented(
164 message = "this `AuthorizedServer` has no auth source, so it cannot be served",
165 note = "call `.with_auth(provider)` (required before any network transport), \
166 or `.deny_by_default()` for stdio/local/dev (least-privilege unless \
167 middleware injects an AuthContext)"
168)]
169pub trait ReadyToServe {}
170impl<P: AuthProvider> ReadyToServe for Authorized<P> {}
171
172impl<S: ServerHandler, P: AuthProvider + 'static> ServerHandler for AuthorizedServer<S, Authorized<P>> {
173 fn get_info(&self) -> ServerInfo {
174 self.inner.get_info()
175 }
176
177 fn list_tools(
178 &self,
179 _request: Option<PaginatedRequestParams>,
180 context: RequestContext<RoleServer>,
181 ) -> impl Future<Output = Result<ListToolsResult, McpError>> + MaybeSendFuture + '_ {
182 async move {
183 let auth = self.auth.0.auth_for(&context);
185 let tools = self.registry.materialize(&auth);
186 Ok(ListToolsResult {
187 tools,
188 ..Default::default()
189 })
190 }
191 }
192
193 fn call_tool(
194 &self,
195 request: CallToolRequestParams,
196 context: RequestContext<RoleServer>,
197 ) -> impl Future<Output = Result<CallToolResult, McpError>> + MaybeSendFuture + '_ {
198 async move {
199 let auth = self.auth.0.auth_for(&context);
202 if !self.registry.is_visible(&request.name, &auth) {
203 return Err(McpError::new(
204 ErrorCode::METHOD_NOT_FOUND,
205 format!("tool not found: {}", request.name),
206 None,
207 ));
208 }
209 self.inner.call_tool(request, context).await
210 }
211 }
212
213 fn initialize(
216 &self,
217 request: InitializeRequestParams,
218 context: RequestContext<RoleServer>,
219 ) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ {
220 self.inner.initialize(request, context)
221 }
222
223 fn ping(
224 &self,
225 context: RequestContext<RoleServer>,
226 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
227 self.inner.ping(context)
228 }
229
230 fn complete(
231 &self,
232 request: CompleteRequestParams,
233 context: RequestContext<RoleServer>,
234 ) -> impl Future<Output = Result<CompleteResult, McpError>> + MaybeSendFuture + '_ {
235 self.inner.complete(request, context)
236 }
237
238 fn set_level(
239 &self,
240 request: SetLevelRequestParams,
241 context: RequestContext<RoleServer>,
242 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
243 self.inner.set_level(request, context)
244 }
245
246 fn get_prompt(
247 &self,
248 request: GetPromptRequestParams,
249 context: RequestContext<RoleServer>,
250 ) -> impl Future<Output = Result<GetPromptResult, McpError>> + MaybeSendFuture + '_ {
251 self.inner.get_prompt(request, context)
252 }
253
254 fn list_prompts(
255 &self,
256 request: Option<PaginatedRequestParams>,
257 context: RequestContext<RoleServer>,
258 ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + MaybeSendFuture + '_ {
259 self.inner.list_prompts(request, context)
260 }
261
262 fn list_resources(
263 &self,
264 request: Option<PaginatedRequestParams>,
265 context: RequestContext<RoleServer>,
266 ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + MaybeSendFuture + '_ {
267 self.inner.list_resources(request, context)
268 }
269
270 fn list_resource_templates(
271 &self,
272 request: Option<PaginatedRequestParams>,
273 context: RequestContext<RoleServer>,
274 ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + MaybeSendFuture + '_
275 {
276 self.inner.list_resource_templates(request, context)
277 }
278
279 fn read_resource(
280 &self,
281 request: ReadResourceRequestParams,
282 context: RequestContext<RoleServer>,
283 ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + MaybeSendFuture + '_ {
284 self.inner.read_resource(request, context)
285 }
286
287 fn subscribe(
288 &self,
289 request: SubscribeRequestParams,
290 context: RequestContext<RoleServer>,
291 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
292 self.inner.subscribe(request, context)
293 }
294
295 fn unsubscribe(
296 &self,
297 request: UnsubscribeRequestParams,
298 context: RequestContext<RoleServer>,
299 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
300 self.inner.unsubscribe(request, context)
301 }
302
303 fn get_tool(&self, name: &str) -> Option<Tool> {
304 self.inner.get_tool(name)
305 }
306
307 fn on_custom_request(
308 &self,
309 request: CustomRequest,
310 context: RequestContext<RoleServer>,
311 ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ {
312 self.inner.on_custom_request(request, context)
313 }
314
315 fn on_cancelled(
316 &self,
317 notification: CancelledNotificationParam,
318 context: NotificationContext<RoleServer>,
319 ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
320 self.inner.on_cancelled(notification, context)
321 }
322
323 fn on_progress(
324 &self,
325 notification: ProgressNotificationParam,
326 context: NotificationContext<RoleServer>,
327 ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
328 self.inner.on_progress(notification, context)
329 }
330
331 fn on_initialized(
332 &self,
333 context: NotificationContext<RoleServer>,
334 ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
335 self.inner.on_initialized(context)
336 }
337
338 fn on_roots_list_changed(
339 &self,
340 context: NotificationContext<RoleServer>,
341 ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
342 self.inner.on_roots_list_changed(context)
343 }
344
345 fn on_custom_notification(
346 &self,
347 notification: CustomNotification,
348 context: NotificationContext<RoleServer>,
349 ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
350 self.inner.on_custom_notification(notification, context)
351 }
352
353 fn enqueue_task(
354 &self,
355 request: CallToolRequestParams,
356 context: RequestContext<RoleServer>,
357 ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + MaybeSendFuture + '_ {
358 self.inner.enqueue_task(request, context)
359 }
360
361 fn list_tasks(
362 &self,
363 request: Option<PaginatedRequestParams>,
364 context: RequestContext<RoleServer>,
365 ) -> impl Future<Output = Result<ListTasksResult, McpError>> + MaybeSendFuture + '_ {
366 self.inner.list_tasks(request, context)
367 }
368
369 fn get_task_info(
370 &self,
371 request: GetTaskInfoParams,
372 context: RequestContext<RoleServer>,
373 ) -> impl Future<Output = Result<GetTaskResult, McpError>> + MaybeSendFuture + '_ {
374 self.inner.get_task_info(request, context)
375 }
376
377 fn get_task_result(
378 &self,
379 request: GetTaskResultParams,
380 context: RequestContext<RoleServer>,
381 ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + MaybeSendFuture + '_ {
382 self.inner.get_task_result(request, context)
383 }
384
385 fn cancel_task(
386 &self,
387 request: CancelTaskParams,
388 context: RequestContext<RoleServer>,
389 ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + MaybeSendFuture + '_ {
390 self.inner.cancel_task(request, context)
391 }
392}