Skip to main content

rmcp/handler/
server.rs

1use std::sync::Arc;
2
3use crate::{
4    error::ErrorData as McpError,
5    model::{TaskSupport, *},
6    service::{
7        MaybeSendFuture, NotificationContext, RequestContext, RoleServer, Service, ServiceRole,
8    },
9};
10
11pub mod common;
12pub mod prompt;
13mod resource;
14pub mod router;
15pub mod tool;
16pub mod tool_name_validation;
17pub mod wrapper;
18
19impl<H: ServerHandler> Service<RoleServer> for H {
20    async fn handle_request(
21        &self,
22        request: <RoleServer as ServiceRole>::PeerReq,
23        context: RequestContext<RoleServer>,
24    ) -> Result<<RoleServer as ServiceRole>::Resp, McpError> {
25        match request {
26            ClientRequest::InitializeRequest(request) => self
27                .initialize(request.params, context)
28                .await
29                .map(ServerResult::InitializeResult),
30            ClientRequest::PingRequest(_request) => {
31                self.ping(context).await.map(ServerResult::empty)
32            }
33            ClientRequest::CompleteRequest(request) => self
34                .complete(request.params, context)
35                .await
36                .map(ServerResult::CompleteResult),
37            ClientRequest::SetLevelRequest(request) => self
38                .set_level(request.params, context)
39                .await
40                .map(ServerResult::empty),
41            ClientRequest::GetPromptRequest(request) => self
42                .get_prompt(request.params, context)
43                .await
44                .map(ServerResult::GetPromptResult),
45            ClientRequest::ListPromptsRequest(request) => self
46                .list_prompts(request.params, context)
47                .await
48                .map(ServerResult::ListPromptsResult),
49            ClientRequest::ListResourcesRequest(request) => self
50                .list_resources(request.params, context)
51                .await
52                .map(ServerResult::ListResourcesResult),
53            ClientRequest::ListResourceTemplatesRequest(request) => self
54                .list_resource_templates(request.params, context)
55                .await
56                .map(ServerResult::ListResourceTemplatesResult),
57            ClientRequest::ReadResourceRequest(request) => self
58                .read_resource(request.params, context)
59                .await
60                .map(ServerResult::ReadResourceResult),
61            ClientRequest::SubscribeRequest(request) => self
62                .subscribe(request.params, context)
63                .await
64                .map(ServerResult::empty),
65            ClientRequest::UnsubscribeRequest(request) => self
66                .unsubscribe(request.params, context)
67                .await
68                .map(ServerResult::empty),
69            ClientRequest::CallToolRequest(request) => {
70                let is_task = request.params.task.is_some();
71
72                // Validate task support mode per MCP specification
73                if let Some(tool) = self.get_tool(&request.params.name) {
74                    match (tool.task_support(), is_task) {
75                        // If taskSupport is "required", clients MUST invoke the tool as a task.
76                        // Servers MUST return a -32601 (Method not found) error if they don't.
77                        (TaskSupport::Required, false) => {
78                            return Err(McpError::new(
79                                ErrorCode::METHOD_NOT_FOUND,
80                                "Tool requires task-based invocation",
81                                None,
82                            ));
83                        }
84                        // If taskSupport is "forbidden" (default), clients MUST NOT invoke as a task.
85                        (TaskSupport::Forbidden, true) => {
86                            return Err(McpError::invalid_params(
87                                "Tool does not support task-based invocation",
88                                None,
89                            ));
90                        }
91                        _ => {}
92                    }
93                }
94
95                if is_task {
96                    tracing::info!("Enqueueing task for tool call: {}", request.params.name);
97                    self.enqueue_task(request.params, context.clone())
98                        .await
99                        .map(ServerResult::CreateTaskResult)
100                } else {
101                    self.call_tool(request.params, context)
102                        .await
103                        .map(ServerResult::CallToolResult)
104                }
105            }
106            ClientRequest::ListToolsRequest(request) => self
107                .list_tools(request.params, context)
108                .await
109                .map(ServerResult::ListToolsResult),
110            ClientRequest::CustomRequest(request) => self
111                .on_custom_request(request, context)
112                .await
113                .map(ServerResult::CustomResult),
114            ClientRequest::ListTasksRequest(request) => self
115                .list_tasks(request.params, context)
116                .await
117                .map(ServerResult::ListTasksResult),
118            ClientRequest::GetTaskInfoRequest(request) => self
119                .get_task_info(request.params, context)
120                .await
121                .map(ServerResult::GetTaskResult),
122            ClientRequest::GetTaskResultRequest(request) => self
123                .get_task_result(request.params, context)
124                .await
125                .map(ServerResult::GetTaskPayloadResult),
126            ClientRequest::CancelTaskRequest(request) => self
127                .cancel_task(request.params, context)
128                .await
129                .map(ServerResult::CancelTaskResult),
130        }
131    }
132
133    async fn handle_notification(
134        &self,
135        notification: <RoleServer as ServiceRole>::PeerNot,
136        context: NotificationContext<RoleServer>,
137    ) -> Result<(), McpError> {
138        match notification {
139            ClientNotification::CancelledNotification(notification) => {
140                self.on_cancelled(notification.params, context).await
141            }
142            ClientNotification::ProgressNotification(notification) => {
143                self.on_progress(notification.params, context).await
144            }
145            ClientNotification::InitializedNotification(_notification) => {
146                self.on_initialized(context).await
147            }
148            ClientNotification::RootsListChangedNotification(_notification) => {
149                self.on_roots_list_changed(context).await
150            }
151            ClientNotification::CustomNotification(notification) => {
152                self.on_custom_notification(notification, context).await
153            }
154        };
155        Ok(())
156    }
157
158    fn get_info(&self) -> <RoleServer as ServiceRole>::Info {
159        self.get_info()
160    }
161}
162
/// Expands to the full method set of the `ServerHandler` trait, with a
/// default body for every method.
///
/// The method list lives in a macro so that both `ServerHandler` variants
/// (with and without the `local` feature) share exactly the same methods and
/// defaults. Default bodies ignore most of their parameters. The traits that
/// expand this macro are marked `#[allow(unused_variables)]` for that reason,
/// so parameters keep their descriptive names here.
macro_rules! server_handler_methods {
    () => {
        /// Enqueues a tool call for task-based execution.
        ///
        /// Called instead of [`call_tool`](Self::call_tool) when the incoming
        /// `CallToolRequest` has `params.task` set.
        ///
        /// Default: fails with an internal error ("Task processing not implemented").
        fn enqueue_task(
            &self,
            request: CallToolRequestParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(McpError::internal_error(
                "Task processing not implemented".to_string(),
                None,
            )))
        }

        /// Answers a client ping.
        ///
        /// Default: succeeds immediately.
        fn ping(
            &self,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Ok(()))
        }

        /// Handles the client's initialize request.
        ///
        /// Default: stores `request` as the peer's info if none is recorded yet,
        /// then returns [`get_info`](Self::get_info).
        fn initialize(
            &self,
            request: InitializeRequestParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ {
            if context.peer.peer_info().is_none() {
                context.peer.set_peer_info(request);
            }
            std::future::ready(Ok(self.get_info()))
        }

        /// Handles a completion request.
        ///
        /// Default: returns `CompleteResult::default()`.
        fn complete(
            &self,
            request: CompleteRequestParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<CompleteResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Ok(CompleteResult::default()))
        }

        /// Handles a set-level request.
        ///
        /// Default: fails with a method-not-found error.
        fn set_level(
            &self,
            request: SetLevelRequestParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(McpError::method_not_found::<SetLevelRequestMethod>()))
        }

        /// Returns a single prompt.
        ///
        /// Default: fails with a method-not-found error.
        fn get_prompt(
            &self,
            request: GetPromptRequestParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<GetPromptResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(McpError::method_not_found::<GetPromptRequestMethod>()))
        }

        /// Lists available prompts.
        ///
        /// Default: returns `ListPromptsResult::default()`.
        fn list_prompts(
            &self,
            request: Option<PaginatedRequestParams>,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Ok(ListPromptsResult::default()))
        }

        /// Lists available resources.
        ///
        /// Default: returns `ListResourcesResult::default()`.
        fn list_resources(
            &self,
            request: Option<PaginatedRequestParams>,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Ok(ListResourcesResult::default()))
        }

        /// Lists available resource templates.
        ///
        /// Default: returns `ListResourceTemplatesResult::default()`.
        fn list_resource_templates(
            &self,
            request: Option<PaginatedRequestParams>,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>>
               + MaybeSendFuture
               + '_ {
            std::future::ready(Ok(ListResourceTemplatesResult::default()))
        }

        /// Reads a single resource.
        ///
        /// Default: fails with a method-not-found error.
        fn read_resource(
            &self,
            request: ReadResourceRequestParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(
                McpError::method_not_found::<ReadResourceRequestMethod>(),
            ))
        }

        /// Handles a resource subscription request.
        ///
        /// Default: fails with a method-not-found error.
        fn subscribe(
            &self,
            request: SubscribeRequestParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(McpError::method_not_found::<SubscribeRequestMethod>()))
        }

        /// Handles a resource unsubscription request.
        ///
        /// Default: fails with a method-not-found error.
        fn unsubscribe(
            &self,
            request: UnsubscribeRequestParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(
                McpError::method_not_found::<UnsubscribeRequestMethod>(),
            ))
        }

        /// Executes a tool call directly (not as a task).
        ///
        /// Before this is called, the request dispatcher checks the tool's
        /// task-support mode through [`get_tool`](Self::get_tool).
        ///
        /// Default: fails with a method-not-found error.
        fn call_tool(
            &self,
            request: CallToolRequestParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<CallToolResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(McpError::method_not_found::<CallToolRequestMethod>()))
        }

        /// Lists available tools.
        ///
        /// Default: returns `ListToolsResult::default()`.
        fn list_tools(
            &self,
            request: Option<PaginatedRequestParams>,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<ListToolsResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Ok(ListToolsResult::default()))
        }

        /// Get a tool definition by name.
        ///
        /// The default implementation returns `None`, which bypasses validation.
        /// When using `#[tool_handler]`, this method is automatically implemented.
        fn get_tool(&self, name: &str) -> Option<Tool> {
            None
        }

        /// Handles a request whose method is not part of the standard set.
        ///
        /// Default: fails with `METHOD_NOT_FOUND`, using the request's method
        /// name as the error message.
        fn on_custom_request(
            &self,
            request: CustomRequest,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ {
            let CustomRequest { method, .. } = request;
            std::future::ready(Err(McpError::new(
                ErrorCode::METHOD_NOT_FOUND,
                method,
                None,
            )))
        }

        /// Called when the client sends a cancellation notification.
        ///
        /// Default: does nothing.
        fn on_cancelled(
            &self,
            notification: CancelledNotificationParam,
            context: NotificationContext<RoleServer>,
        ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
            std::future::ready(())
        }

        /// Called when the client sends a progress notification.
        ///
        /// Default: does nothing.
        fn on_progress(
            &self,
            notification: ProgressNotificationParam,
            context: NotificationContext<RoleServer>,
        ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
            std::future::ready(())
        }

        /// Called when the client sends its initialized notification.
        ///
        /// Default: logs "client initialized" at `info` level.
        fn on_initialized(
            &self,
            context: NotificationContext<RoleServer>,
        ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
            tracing::info!("client initialized");
            std::future::ready(())
        }

        /// Called when the client reports that its roots list changed.
        ///
        /// Default: does nothing.
        fn on_roots_list_changed(
            &self,
            context: NotificationContext<RoleServer>,
        ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
            std::future::ready(())
        }

        /// Called for a notification whose method is not part of the standard set.
        ///
        /// Default: does nothing.
        fn on_custom_notification(
            &self,
            notification: CustomNotification,
            context: NotificationContext<RoleServer>,
        ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
            std::future::ready(())
        }

        /// Describes this server.
        ///
        /// Also returned by the default [`initialize`](Self::initialize).
        ///
        /// Default: `ServerInfo::default()`.
        fn get_info(&self) -> ServerInfo {
            ServerInfo::default()
        }

        /// Lists tasks.
        ///
        /// Default: fails with a method-not-found error.
        fn list_tasks(
            &self,
            request: Option<PaginatedRequestParams>,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<ListTasksResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(McpError::method_not_found::<ListTasksMethod>()))
        }

        /// Returns information about a single task.
        ///
        /// Default: fails with a method-not-found error.
        fn get_task_info(
            &self,
            request: GetTaskInfoParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<GetTaskResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(McpError::method_not_found::<GetTaskInfoMethod>()))
        }

        /// Returns the result payload of a task.
        ///
        /// Default: fails with a method-not-found error.
        fn get_task_result(
            &self,
            request: GetTaskResultParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(McpError::method_not_found::<GetTaskResultMethod>()))
        }

        /// Cancels a task.
        ///
        /// Default: fails with a method-not-found error.
        fn cancel_task(
            &self,
            request: CancelTaskParams,
            context: RequestContext<RoleServer>,
        ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + MaybeSendFuture + '_ {
            std::future::ready(Err(McpError::method_not_found::<CancelTaskMethod>()))
        }
    };
}
372
373#[allow(unused_variables)]
374#[cfg(not(feature = "local"))]
375pub trait ServerHandler: Sized + Send + Sync + 'static {
376    server_handler_methods!();
377}
378
379#[allow(unused_variables)]
380#[cfg(feature = "local")]
381pub trait ServerHandler: Sized + 'static {
382    server_handler_methods!();
383}
384
/// Implements `ServerHandler` for a smart pointer `$wrapper<T>` whose target
/// `T` is itself a `ServerHandler`, forwarding every method to the target.
///
/// Each method calls the target through the fully-qualified
/// `<T as ServerHandler>::…` path. The `&$wrapper<T>` receiver deref-coerces
/// to `&T`. Naming `T` explicitly guarantees a call can never resolve back to
/// the wrapper's own impl.
macro_rules! impl_server_handler_for_wrapper {
    ($wrapper:ident) => {
        impl<T: ServerHandler> ServerHandler for $wrapper<T> {
            fn enqueue_task(
                &self,
                request: CallToolRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::enqueue_task(self, request, context)
            }

            fn ping(
                &self,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::ping(self, context)
            }

            fn initialize(
                &self,
                request: InitializeRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::initialize(self, request, context)
            }

            fn complete(
                &self,
                request: CompleteRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CompleteResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::complete(self, request, context)
            }

            fn set_level(
                &self,
                request: SetLevelRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::set_level(self, request, context)
            }

            fn get_prompt(
                &self,
                request: GetPromptRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<GetPromptResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::get_prompt(self, request, context)
            }

            fn list_prompts(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::list_prompts(self, request, context)
            }

            fn list_resources(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::list_resources(self, request, context)
            }

            fn list_resource_templates(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + MaybeSendFuture + '_
            {
                <T as ServerHandler>::list_resource_templates(self, request, context)
            }

            fn read_resource(
                &self,
                request: ReadResourceRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::read_resource(self, request, context)
            }

            fn subscribe(
                &self,
                request: SubscribeRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::subscribe(self, request, context)
            }

            fn unsubscribe(
                &self,
                request: UnsubscribeRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::unsubscribe(self, request, context)
            }

            fn call_tool(
                &self,
                request: CallToolRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CallToolResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::call_tool(self, request, context)
            }

            fn list_tools(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListToolsResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::list_tools(self, request, context)
            }

            fn get_tool(&self, name: &str) -> Option<Tool> {
                <T as ServerHandler>::get_tool(self, name)
            }

            fn on_custom_request(
                &self,
                request: CustomRequest,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::on_custom_request(self, request, context)
            }

            fn on_cancelled(
                &self,
                notification: CancelledNotificationParam,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
                <T as ServerHandler>::on_cancelled(self, notification, context)
            }

            fn on_progress(
                &self,
                notification: ProgressNotificationParam,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
                <T as ServerHandler>::on_progress(self, notification, context)
            }

            fn on_initialized(
                &self,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
                <T as ServerHandler>::on_initialized(self, context)
            }

            fn on_roots_list_changed(
                &self,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
                <T as ServerHandler>::on_roots_list_changed(self, context)
            }

            fn on_custom_notification(
                &self,
                notification: CustomNotification,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
                <T as ServerHandler>::on_custom_notification(self, notification, context)
            }

            fn get_info(&self) -> ServerInfo {
                <T as ServerHandler>::get_info(self)
            }

            fn list_tasks(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListTasksResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::list_tasks(self, request, context)
            }

            fn get_task_info(
                &self,
                request: GetTaskInfoParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<GetTaskResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::get_task_info(self, request, context)
            }

            fn get_task_result(
                &self,
                request: GetTaskResultParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::get_task_result(self, request, context)
            }

            fn cancel_task(
                &self,
                request: CancelTaskParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + MaybeSendFuture + '_ {
                <T as ServerHandler>::cancel_task(self, request, context)
            }
        }
    };
}
588
589impl_server_handler_for_wrapper!(Box);
590impl_server_handler_for_wrapper!(Arc);