Skip to main content

rmcp_soddygo/handler/
server.rs

1use std::sync::Arc;
2
3use crate::{
4    error::ErrorData as McpError,
5    model::{TaskSupport, *},
6    service::{NotificationContext, RequestContext, RoleServer, Service, ServiceRole},
7};
8
9pub mod common;
10pub mod prompt;
11mod resource;
12pub mod router;
13pub mod tool;
14pub mod tool_name_validation;
15pub mod wrapper;
16
17impl<H: ServerHandler> Service<RoleServer> for H {
18    async fn handle_request(
19        &self,
20        request: <RoleServer as ServiceRole>::PeerReq,
21        context: RequestContext<RoleServer>,
22    ) -> Result<<RoleServer as ServiceRole>::Resp, McpError> {
23        match request {
24            ClientRequest::InitializeRequest(request) => self
25                .initialize(request.params, context)
26                .await
27                .map(ServerResult::InitializeResult),
28            ClientRequest::PingRequest(_request) => {
29                self.ping(context).await.map(ServerResult::empty)
30            }
31            ClientRequest::CompleteRequest(request) => self
32                .complete(request.params, context)
33                .await
34                .map(ServerResult::CompleteResult),
35            ClientRequest::SetLevelRequest(request) => self
36                .set_level(request.params, context)
37                .await
38                .map(ServerResult::empty),
39            ClientRequest::GetPromptRequest(request) => self
40                .get_prompt(request.params, context)
41                .await
42                .map(ServerResult::GetPromptResult),
43            ClientRequest::ListPromptsRequest(request) => self
44                .list_prompts(request.params, context)
45                .await
46                .map(ServerResult::ListPromptsResult),
47            ClientRequest::ListResourcesRequest(request) => self
48                .list_resources(request.params, context)
49                .await
50                .map(ServerResult::ListResourcesResult),
51            ClientRequest::ListResourceTemplatesRequest(request) => self
52                .list_resource_templates(request.params, context)
53                .await
54                .map(ServerResult::ListResourceTemplatesResult),
55            ClientRequest::ReadResourceRequest(request) => self
56                .read_resource(request.params, context)
57                .await
58                .map(ServerResult::ReadResourceResult),
59            ClientRequest::SubscribeRequest(request) => self
60                .subscribe(request.params, context)
61                .await
62                .map(ServerResult::empty),
63            ClientRequest::UnsubscribeRequest(request) => self
64                .unsubscribe(request.params, context)
65                .await
66                .map(ServerResult::empty),
67            ClientRequest::CallToolRequest(request) => {
68                let is_task = request.params.task.is_some();
69
70                // Validate task support mode per MCP specification
71                if let Some(tool) = self.get_tool(&request.params.name) {
72                    match (tool.task_support(), is_task) {
73                        // If taskSupport is "required", clients MUST invoke the tool as a task.
74                        // Servers MUST return a -32601 (Method not found) error if they don't.
75                        (TaskSupport::Required, false) => {
76                            return Err(McpError::new(
77                                ErrorCode::METHOD_NOT_FOUND,
78                                "Tool requires task-based invocation",
79                                None,
80                            ));
81                        }
82                        // If taskSupport is "forbidden" (default), clients MUST NOT invoke as a task.
83                        (TaskSupport::Forbidden, true) => {
84                            return Err(McpError::invalid_params(
85                                "Tool does not support task-based invocation",
86                                None,
87                            ));
88                        }
89                        _ => {}
90                    }
91                }
92
93                if is_task {
94                    tracing::info!("Enqueueing task for tool call: {}", request.params.name);
95                    self.enqueue_task(request.params, context.clone())
96                        .await
97                        .map(ServerResult::CreateTaskResult)
98                } else {
99                    self.call_tool(request.params, context)
100                        .await
101                        .map(ServerResult::CallToolResult)
102                }
103            }
104            ClientRequest::ListToolsRequest(request) => self
105                .list_tools(request.params, context)
106                .await
107                .map(ServerResult::ListToolsResult),
108            ClientRequest::CustomRequest(request) => self
109                .on_custom_request(request, context)
110                .await
111                .map(ServerResult::CustomResult),
112            ClientRequest::ListTasksRequest(request) => self
113                .list_tasks(request.params, context)
114                .await
115                .map(ServerResult::ListTasksResult),
116            ClientRequest::GetTaskInfoRequest(request) => self
117                .get_task_info(request.params, context)
118                .await
119                .map(ServerResult::GetTaskResult),
120            ClientRequest::GetTaskResultRequest(request) => self
121                .get_task_result(request.params, context)
122                .await
123                .map(ServerResult::GetTaskPayloadResult),
124            ClientRequest::CancelTaskRequest(request) => self
125                .cancel_task(request.params, context)
126                .await
127                .map(ServerResult::CancelTaskResult),
128        }
129    }
130
131    async fn handle_notification(
132        &self,
133        notification: <RoleServer as ServiceRole>::PeerNot,
134        context: NotificationContext<RoleServer>,
135    ) -> Result<(), McpError> {
136        match notification {
137            ClientNotification::CancelledNotification(notification) => {
138                self.on_cancelled(notification.params, context).await
139            }
140            ClientNotification::ProgressNotification(notification) => {
141                self.on_progress(notification.params, context).await
142            }
143            ClientNotification::InitializedNotification(_notification) => {
144                self.on_initialized(context).await
145            }
146            ClientNotification::RootsListChangedNotification(_notification) => {
147                self.on_roots_list_changed(context).await
148            }
149            ClientNotification::CustomNotification(notification) => {
150                self.on_custom_notification(notification, context).await
151            }
152        };
153        Ok(())
154    }
155
156    fn get_info(&self) -> <RoleServer as ServiceRole>::Info {
157        self.get_info()
158    }
159}
160
161#[allow(unused_variables)]
162pub trait ServerHandler: Sized + Send + Sync + 'static {
163    fn enqueue_task(
164        &self,
165        _request: CallToolRequestParams,
166        _context: RequestContext<RoleServer>,
167    ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + Send + '_ {
168        std::future::ready(Err(McpError::internal_error(
169            "Task processing not implemented".to_string(),
170            None,
171        )))
172    }
173    fn ping(
174        &self,
175        context: RequestContext<RoleServer>,
176    ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
177        std::future::ready(Ok(()))
178    }
179    // handle requests
180    fn initialize(
181        &self,
182        request: InitializeRequestParams,
183        context: RequestContext<RoleServer>,
184    ) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
185        if context.peer.peer_info().is_none() {
186            context.peer.set_peer_info(request);
187        }
188        std::future::ready(Ok(self.get_info()))
189    }
190    fn complete(
191        &self,
192        request: CompleteRequestParams,
193        context: RequestContext<RoleServer>,
194    ) -> impl Future<Output = Result<CompleteResult, McpError>> + Send + '_ {
195        std::future::ready(Ok(CompleteResult::default()))
196    }
197    fn set_level(
198        &self,
199        request: SetLevelRequestParams,
200        context: RequestContext<RoleServer>,
201    ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
202        std::future::ready(Err(McpError::method_not_found::<SetLevelRequestMethod>()))
203    }
204    fn get_prompt(
205        &self,
206        request: GetPromptRequestParams,
207        context: RequestContext<RoleServer>,
208    ) -> impl Future<Output = Result<GetPromptResult, McpError>> + Send + '_ {
209        std::future::ready(Err(McpError::method_not_found::<GetPromptRequestMethod>()))
210    }
211    fn list_prompts(
212        &self,
213        request: Option<PaginatedRequestParams>,
214        context: RequestContext<RoleServer>,
215    ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + Send + '_ {
216        std::future::ready(Ok(ListPromptsResult::default()))
217    }
218    fn list_resources(
219        &self,
220        request: Option<PaginatedRequestParams>,
221        context: RequestContext<RoleServer>,
222    ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
223        std::future::ready(Ok(ListResourcesResult::default()))
224    }
225    fn list_resource_templates(
226        &self,
227        request: Option<PaginatedRequestParams>,
228        context: RequestContext<RoleServer>,
229    ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + Send + '_ {
230        std::future::ready(Ok(ListResourceTemplatesResult::default()))
231    }
232    fn read_resource(
233        &self,
234        request: ReadResourceRequestParams,
235        context: RequestContext<RoleServer>,
236    ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ {
237        std::future::ready(Err(
238            McpError::method_not_found::<ReadResourceRequestMethod>(),
239        ))
240    }
241    fn subscribe(
242        &self,
243        request: SubscribeRequestParams,
244        context: RequestContext<RoleServer>,
245    ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
246        std::future::ready(Err(McpError::method_not_found::<SubscribeRequestMethod>()))
247    }
248    fn unsubscribe(
249        &self,
250        request: UnsubscribeRequestParams,
251        context: RequestContext<RoleServer>,
252    ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
253        std::future::ready(Err(McpError::method_not_found::<UnsubscribeRequestMethod>()))
254    }
255    fn call_tool(
256        &self,
257        request: CallToolRequestParams,
258        context: RequestContext<RoleServer>,
259    ) -> impl Future<Output = Result<CallToolResult, McpError>> + Send + '_ {
260        std::future::ready(Err(McpError::method_not_found::<CallToolRequestMethod>()))
261    }
262    fn list_tools(
263        &self,
264        request: Option<PaginatedRequestParams>,
265        context: RequestContext<RoleServer>,
266    ) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
267        std::future::ready(Ok(ListToolsResult::default()))
268    }
269    /// Get a tool definition by name.
270    ///
271    /// The default implementation returns `None`, which bypasses validation.
272    /// When using `#[tool_handler]`, this method is automatically implemented.
273    fn get_tool(&self, _name: &str) -> Option<Tool> {
274        None
275    }
276    fn on_custom_request(
277        &self,
278        request: CustomRequest,
279        context: RequestContext<RoleServer>,
280    ) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
281        let CustomRequest { method, .. } = request;
282        let _ = context;
283        std::future::ready(Err(McpError::new(
284            ErrorCode::METHOD_NOT_FOUND,
285            method,
286            None,
287        )))
288    }
289
290    fn on_cancelled(
291        &self,
292        notification: CancelledNotificationParam,
293        context: NotificationContext<RoleServer>,
294    ) -> impl Future<Output = ()> + Send + '_ {
295        std::future::ready(())
296    }
297    fn on_progress(
298        &self,
299        notification: ProgressNotificationParam,
300        context: NotificationContext<RoleServer>,
301    ) -> impl Future<Output = ()> + Send + '_ {
302        std::future::ready(())
303    }
304    fn on_initialized(
305        &self,
306        context: NotificationContext<RoleServer>,
307    ) -> impl Future<Output = ()> + Send + '_ {
308        tracing::info!("client initialized");
309        std::future::ready(())
310    }
311    fn on_roots_list_changed(
312        &self,
313        context: NotificationContext<RoleServer>,
314    ) -> impl Future<Output = ()> + Send + '_ {
315        std::future::ready(())
316    }
317    fn on_custom_notification(
318        &self,
319        notification: CustomNotification,
320        context: NotificationContext<RoleServer>,
321    ) -> impl Future<Output = ()> + Send + '_ {
322        let _ = (notification, context);
323        std::future::ready(())
324    }
325
326    fn get_info(&self) -> ServerInfo {
327        ServerInfo::default()
328    }
329
330    fn list_tasks(
331        &self,
332        request: Option<PaginatedRequestParams>,
333        context: RequestContext<RoleServer>,
334    ) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ {
335        std::future::ready(Err(McpError::method_not_found::<ListTasksMethod>()))
336    }
337
338    fn get_task_info(
339        &self,
340        request: GetTaskInfoParams,
341        context: RequestContext<RoleServer>,
342    ) -> impl Future<Output = Result<GetTaskResult, McpError>> + Send + '_ {
343        let _ = (request, context);
344        std::future::ready(Err(McpError::method_not_found::<GetTaskInfoMethod>()))
345    }
346
347    fn get_task_result(
348        &self,
349        request: GetTaskResultParams,
350        context: RequestContext<RoleServer>,
351    ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + Send + '_ {
352        let _ = (request, context);
353        std::future::ready(Err(McpError::method_not_found::<GetTaskResultMethod>()))
354    }
355
356    fn cancel_task(
357        &self,
358        request: CancelTaskParams,
359        context: RequestContext<RoleServer>,
360    ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + Send + '_ {
361        let _ = (request, context);
362        std::future::ready(Err(McpError::method_not_found::<CancelTaskMethod>()))
363    }
364}
365
366macro_rules! impl_server_handler_for_wrapper {
367    ($wrapper:ident) => {
368        impl<T: ServerHandler> ServerHandler for $wrapper<T> {
369            fn enqueue_task(
370                &self,
371                request: CallToolRequestParams,
372                context: RequestContext<RoleServer>,
373            ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + Send + '_ {
374                (**self).enqueue_task(request, context)
375            }
376
377            fn ping(
378                &self,
379                context: RequestContext<RoleServer>,
380            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
381                (**self).ping(context)
382            }
383
384            fn initialize(
385                &self,
386                request: InitializeRequestParams,
387                context: RequestContext<RoleServer>,
388            ) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
389                (**self).initialize(request, context)
390            }
391
392            fn complete(
393                &self,
394                request: CompleteRequestParams,
395                context: RequestContext<RoleServer>,
396            ) -> impl Future<Output = Result<CompleteResult, McpError>> + Send + '_ {
397                (**self).complete(request, context)
398            }
399
400            fn set_level(
401                &self,
402                request: SetLevelRequestParams,
403                context: RequestContext<RoleServer>,
404            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
405                (**self).set_level(request, context)
406            }
407
408            fn get_prompt(
409                &self,
410                request: GetPromptRequestParams,
411                context: RequestContext<RoleServer>,
412            ) -> impl Future<Output = Result<GetPromptResult, McpError>> + Send + '_ {
413                (**self).get_prompt(request, context)
414            }
415
416            fn list_prompts(
417                &self,
418                request: Option<PaginatedRequestParams>,
419                context: RequestContext<RoleServer>,
420            ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + Send + '_ {
421                (**self).list_prompts(request, context)
422            }
423
424            fn list_resources(
425                &self,
426                request: Option<PaginatedRequestParams>,
427                context: RequestContext<RoleServer>,
428            ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
429                (**self).list_resources(request, context)
430            }
431
432            fn list_resource_templates(
433                &self,
434                request: Option<PaginatedRequestParams>,
435                context: RequestContext<RoleServer>,
436            ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + Send + '_
437            {
438                (**self).list_resource_templates(request, context)
439            }
440
441            fn read_resource(
442                &self,
443                request: ReadResourceRequestParams,
444                context: RequestContext<RoleServer>,
445            ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ {
446                (**self).read_resource(request, context)
447            }
448
449            fn subscribe(
450                &self,
451                request: SubscribeRequestParams,
452                context: RequestContext<RoleServer>,
453            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
454                (**self).subscribe(request, context)
455            }
456
457            fn unsubscribe(
458                &self,
459                request: UnsubscribeRequestParams,
460                context: RequestContext<RoleServer>,
461            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
462                (**self).unsubscribe(request, context)
463            }
464
465            fn call_tool(
466                &self,
467                request: CallToolRequestParams,
468                context: RequestContext<RoleServer>,
469            ) -> impl Future<Output = Result<CallToolResult, McpError>> + Send + '_ {
470                (**self).call_tool(request, context)
471            }
472
473            fn list_tools(
474                &self,
475                request: Option<PaginatedRequestParams>,
476                context: RequestContext<RoleServer>,
477            ) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
478                (**self).list_tools(request, context)
479            }
480
481            fn get_tool(&self, name: &str) -> Option<Tool> {
482                (**self).get_tool(name)
483            }
484
485            fn on_custom_request(
486                &self,
487                request: CustomRequest,
488                context: RequestContext<RoleServer>,
489            ) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
490                (**self).on_custom_request(request, context)
491            }
492
493            fn on_cancelled(
494                &self,
495                notification: CancelledNotificationParam,
496                context: NotificationContext<RoleServer>,
497            ) -> impl Future<Output = ()> + Send + '_ {
498                (**self).on_cancelled(notification, context)
499            }
500
501            fn on_progress(
502                &self,
503                notification: ProgressNotificationParam,
504                context: NotificationContext<RoleServer>,
505            ) -> impl Future<Output = ()> + Send + '_ {
506                (**self).on_progress(notification, context)
507            }
508
509            fn on_initialized(
510                &self,
511                context: NotificationContext<RoleServer>,
512            ) -> impl Future<Output = ()> + Send + '_ {
513                (**self).on_initialized(context)
514            }
515
516            fn on_roots_list_changed(
517                &self,
518                context: NotificationContext<RoleServer>,
519            ) -> impl Future<Output = ()> + Send + '_ {
520                (**self).on_roots_list_changed(context)
521            }
522
523            fn on_custom_notification(
524                &self,
525                notification: CustomNotification,
526                context: NotificationContext<RoleServer>,
527            ) -> impl Future<Output = ()> + Send + '_ {
528                (**self).on_custom_notification(notification, context)
529            }
530
531            fn get_info(&self) -> ServerInfo {
532                (**self).get_info()
533            }
534
535            fn list_tasks(
536                &self,
537                request: Option<PaginatedRequestParams>,
538                context: RequestContext<RoleServer>,
539            ) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ {
540                (**self).list_tasks(request, context)
541            }
542
543            fn get_task_info(
544                &self,
545                request: GetTaskInfoParams,
546                context: RequestContext<RoleServer>,
547            ) -> impl Future<Output = Result<GetTaskResult, McpError>> + Send + '_ {
548                (**self).get_task_info(request, context)
549            }
550
551            fn get_task_result(
552                &self,
553                request: GetTaskResultParams,
554                context: RequestContext<RoleServer>,
555            ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + Send + '_ {
556                (**self).get_task_result(request, context)
557            }
558
559            fn cancel_task(
560                &self,
561                request: CancelTaskParams,
562                context: RequestContext<RoleServer>,
563            ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + Send + '_ {
564                (**self).cancel_task(request, context)
565            }
566        }
567    };
568}
569
570impl_server_handler_for_wrapper!(Box);
571impl_server_handler_for_wrapper!(Arc);