1use std::future::Future;
2
3use rmcp::{
4 ErrorData as McpError,
5 handler::server::ServerHandler,
6 model::*,
7 service::{MaybeSendFuture, NotificationContext, RequestContext, RoleServer},
8};
9use schemars::JsonSchema;
10
11use crate::capability::AuthContext;
12use crate::metadata::AuthSchemaMetadata;
13use crate::registry::AuthToolRegistry;
14
15pub struct AuthorizedServer<S: ServerHandler> {
30 inner: S,
31 registry: AuthToolRegistry,
32}
33
34impl<S: ServerHandler> AuthorizedServer<S> {
35 pub fn new(inner: S) -> Self {
36 Self {
37 inner,
38 registry: AuthToolRegistry::new(),
39 }
40 }
41
42 pub fn register<I, O>(
45 mut self,
46 name: impl Into<String>,
47 description: impl Into<String>,
48 ) -> Self
49 where
50 I: JsonSchema + AuthSchemaMetadata + serde::de::DeserializeOwned + 'static,
51 O: JsonSchema + AuthSchemaMetadata + serde::Serialize + 'static,
52 {
53 self.registry
54 .register_typed::<I, O>(name, description);
55 self
56 }
57
58 pub fn authorize(mut self, tool_name: &str, capability: &'static str) -> Self {
61 self.registry.set_authorization(tool_name, capability);
62 self
63 }
64
65 pub fn inner(&self) -> &S {
67 &self.inner
68 }
69
70 pub fn registry(&self) -> &AuthToolRegistry {
72 &self.registry
73 }
74}
75
76impl<S: ServerHandler> ServerHandler for AuthorizedServer<S> {
77 fn get_info(&self) -> ServerInfo {
78 self.inner.get_info()
79 }
80
81 fn list_tools(
82 &self,
83 _request: Option<PaginatedRequestParams>,
84 context: RequestContext<RoleServer>,
85 ) -> impl Future<Output = Result<ListToolsResult, McpError>> + MaybeSendFuture + '_ {
86 async move {
87 let auth = context
88 .extensions
89 .get::<AuthContext>()
90 .ok_or_else(|| McpError::internal_error("missing AuthContext in extensions", None))?;
91
92 let tools = self.registry.materialize(auth);
93 Ok(ListToolsResult {
94 tools,
95 ..Default::default()
96 })
97 }
98 }
99
100 fn call_tool(
101 &self,
102 request: CallToolRequestParams,
103 context: RequestContext<RoleServer>,
104 ) -> impl Future<Output = Result<CallToolResult, McpError>> + MaybeSendFuture + '_ {
105 async move {
106 if let Some(auth) = context.extensions.get::<AuthContext>() {
108 if !self.registry.is_visible(&request.name, auth) {
109 return Err(McpError::new(
110 ErrorCode::METHOD_NOT_FOUND,
111 format!("tool not found: {}", request.name),
112 None,
113 ));
114 }
115 }
116
117 self.inner.call_tool(request, context).await
119 }
120 }
121
122 fn initialize(
125 &self,
126 request: InitializeRequestParams,
127 context: RequestContext<RoleServer>,
128 ) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ {
129 self.inner.initialize(request, context)
130 }
131
132 fn ping(
133 &self,
134 context: RequestContext<RoleServer>,
135 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
136 self.inner.ping(context)
137 }
138
139 fn complete(
140 &self,
141 request: CompleteRequestParams,
142 context: RequestContext<RoleServer>,
143 ) -> impl Future<Output = Result<CompleteResult, McpError>> + MaybeSendFuture + '_ {
144 self.inner.complete(request, context)
145 }
146
147 fn set_level(
148 &self,
149 request: SetLevelRequestParams,
150 context: RequestContext<RoleServer>,
151 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
152 self.inner.set_level(request, context)
153 }
154
155 fn get_prompt(
156 &self,
157 request: GetPromptRequestParams,
158 context: RequestContext<RoleServer>,
159 ) -> impl Future<Output = Result<GetPromptResult, McpError>> + MaybeSendFuture + '_ {
160 self.inner.get_prompt(request, context)
161 }
162
163 fn list_prompts(
164 &self,
165 request: Option<PaginatedRequestParams>,
166 context: RequestContext<RoleServer>,
167 ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + MaybeSendFuture + '_ {
168 self.inner.list_prompts(request, context)
169 }
170
171 fn list_resources(
172 &self,
173 request: Option<PaginatedRequestParams>,
174 context: RequestContext<RoleServer>,
175 ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + MaybeSendFuture + '_ {
176 self.inner.list_resources(request, context)
177 }
178
179 fn list_resource_templates(
180 &self,
181 request: Option<PaginatedRequestParams>,
182 context: RequestContext<RoleServer>,
183 ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + MaybeSendFuture + '_
184 {
185 self.inner.list_resource_templates(request, context)
186 }
187
188 fn read_resource(
189 &self,
190 request: ReadResourceRequestParams,
191 context: RequestContext<RoleServer>,
192 ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + MaybeSendFuture + '_ {
193 self.inner.read_resource(request, context)
194 }
195
196 fn subscribe(
197 &self,
198 request: SubscribeRequestParams,
199 context: RequestContext<RoleServer>,
200 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
201 self.inner.subscribe(request, context)
202 }
203
204 fn unsubscribe(
205 &self,
206 request: UnsubscribeRequestParams,
207 context: RequestContext<RoleServer>,
208 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
209 self.inner.unsubscribe(request, context)
210 }
211
212 fn get_tool(&self, name: &str) -> Option<Tool> {
213 self.inner.get_tool(name)
214 }
215
216 fn on_custom_request(
217 &self,
218 request: CustomRequest,
219 context: RequestContext<RoleServer>,
220 ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ {
221 self.inner.on_custom_request(request, context)
222 }
223
224 fn on_cancelled(
225 &self,
226 notification: CancelledNotificationParam,
227 context: NotificationContext<RoleServer>,
228 ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
229 self.inner.on_cancelled(notification, context)
230 }
231
232 fn on_progress(
233 &self,
234 notification: ProgressNotificationParam,
235 context: NotificationContext<RoleServer>,
236 ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
237 self.inner.on_progress(notification, context)
238 }
239
240 fn on_initialized(
241 &self,
242 context: NotificationContext<RoleServer>,
243 ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
244 self.inner.on_initialized(context)
245 }
246
247 fn on_roots_list_changed(
248 &self,
249 context: NotificationContext<RoleServer>,
250 ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
251 self.inner.on_roots_list_changed(context)
252 }
253
254 fn on_custom_notification(
255 &self,
256 notification: CustomNotification,
257 context: NotificationContext<RoleServer>,
258 ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
259 self.inner.on_custom_notification(notification, context)
260 }
261
262 fn enqueue_task(
263 &self,
264 request: CallToolRequestParams,
265 context: RequestContext<RoleServer>,
266 ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + MaybeSendFuture + '_ {
267 self.inner.enqueue_task(request, context)
268 }
269
270 fn list_tasks(
271 &self,
272 request: Option<PaginatedRequestParams>,
273 context: RequestContext<RoleServer>,
274 ) -> impl Future<Output = Result<ListTasksResult, McpError>> + MaybeSendFuture + '_ {
275 self.inner.list_tasks(request, context)
276 }
277
278 fn get_task_info(
279 &self,
280 request: GetTaskInfoParams,
281 context: RequestContext<RoleServer>,
282 ) -> impl Future<Output = Result<GetTaskResult, McpError>> + MaybeSendFuture + '_ {
283 self.inner.get_task_info(request, context)
284 }
285
286 fn get_task_result(
287 &self,
288 request: GetTaskResultParams,
289 context: RequestContext<RoleServer>,
290 ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + MaybeSendFuture + '_ {
291 self.inner.get_task_result(request, context)
292 }
293
294 fn cancel_task(
295 &self,
296 request: CancelTaskParams,
297 context: RequestContext<RoleServer>,
298 ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + MaybeSendFuture + '_ {
299 self.inner.cancel_task(request, context)
300 }
301}