Skip to main content

mcpkit_rs/handler/
server.rs

1use std::sync::Arc;
2
3use crate::{
4    error::ErrorData as McpError,
5    model::*,
6    service::{NotificationContext, RequestContext, RoleServer, Service, ServiceRole},
7};
8
9pub mod common;
10#[cfg(feature = "policy")]
11pub mod policy;
12pub mod prompt;
13mod resource;
14pub mod router;
15pub mod tool;
16pub mod tool_name_validation;
17pub mod wrapper;
18
19impl<H: ServerHandler> Service<RoleServer> for H {
20    async fn handle_request(
21        &self,
22        request: <RoleServer as ServiceRole>::PeerReq,
23        context: RequestContext<RoleServer>,
24    ) -> Result<<RoleServer as ServiceRole>::Resp, McpError> {
25        match request {
26            ClientRequest::InitializeRequest(request) => self
27                .initialize(request.params, context)
28                .await
29                .map(ServerResult::InitializeResult),
30            ClientRequest::PingRequest(_request) => {
31                self.ping(context).await.map(ServerResult::empty)
32            }
33            ClientRequest::CompleteRequest(request) => self
34                .complete(request.params, context)
35                .await
36                .map(ServerResult::CompleteResult),
37            ClientRequest::SetLevelRequest(request) => self
38                .set_level(request.params, context)
39                .await
40                .map(ServerResult::empty),
41            ClientRequest::GetPromptRequest(request) => self
42                .get_prompt(request.params, context)
43                .await
44                .map(ServerResult::GetPromptResult),
45            ClientRequest::ListPromptsRequest(request) => self
46                .list_prompts(request.params, context)
47                .await
48                .map(ServerResult::ListPromptsResult),
49            ClientRequest::ListResourcesRequest(request) => self
50                .list_resources(request.params, context)
51                .await
52                .map(ServerResult::ListResourcesResult),
53            ClientRequest::ListResourceTemplatesRequest(request) => self
54                .list_resource_templates(request.params, context)
55                .await
56                .map(ServerResult::ListResourceTemplatesResult),
57            ClientRequest::ReadResourceRequest(request) => self
58                .read_resource(request.params, context)
59                .await
60                .map(ServerResult::ReadResourceResult),
61            ClientRequest::SubscribeRequest(request) => self
62                .subscribe(request.params, context)
63                .await
64                .map(ServerResult::empty),
65            ClientRequest::UnsubscribeRequest(request) => self
66                .unsubscribe(request.params, context)
67                .await
68                .map(ServerResult::empty),
69            ClientRequest::CallToolRequest(request) => {
70                if request.params.task.is_some() {
71                    tracing::info!("Enqueueing task for tool call: {}", request.params.name);
72                    self.enqueue_task(request.params, context.clone())
73                        .await
74                        .map(ServerResult::CreateTaskResult)
75                } else {
76                    self.call_tool(request.params, context)
77                        .await
78                        .map(ServerResult::CallToolResult)
79                }
80            }
81            ClientRequest::ListToolsRequest(request) => self
82                .list_tools(request.params, context)
83                .await
84                .map(ServerResult::ListToolsResult),
85            ClientRequest::CustomRequest(request) => self
86                .on_custom_request(request, context)
87                .await
88                .map(ServerResult::CustomResult),
89            ClientRequest::ListTasksRequest(request) => self
90                .list_tasks(request.params, context)
91                .await
92                .map(ServerResult::ListTasksResult),
93            ClientRequest::GetTaskInfoRequest(request) => self
94                .get_task_info(request.params, context)
95                .await
96                .map(ServerResult::GetTaskInfoResult),
97            ClientRequest::GetTaskResultRequest(request) => self
98                .get_task_result(request.params, context)
99                .await
100                .map(ServerResult::TaskResult),
101            ClientRequest::CancelTaskRequest(request) => self
102                .cancel_task(request.params, context)
103                .await
104                .map(ServerResult::empty),
105        }
106    }
107
108    async fn handle_notification(
109        &self,
110        notification: <RoleServer as ServiceRole>::PeerNot,
111        context: NotificationContext<RoleServer>,
112    ) -> Result<(), McpError> {
113        match notification {
114            ClientNotification::CancelledNotification(notification) => {
115                self.on_cancelled(notification.params, context).await
116            }
117            ClientNotification::ProgressNotification(notification) => {
118                self.on_progress(notification.params, context).await
119            }
120            ClientNotification::InitializedNotification(_notification) => {
121                self.on_initialized(context).await
122            }
123            ClientNotification::RootsListChangedNotification(_notification) => {
124                self.on_roots_list_changed(context).await
125            }
126            ClientNotification::CustomNotification(notification) => {
127                self.on_custom_notification(notification, context).await
128            }
129        };
130        Ok(())
131    }
132
133    fn get_info(&self) -> <RoleServer as ServiceRole>::Info {
134        self.get_info()
135    }
136}
137
138#[allow(unused_variables)]
139pub trait ServerHandler: Sized + Send + Sync + 'static {
140    fn enqueue_task(
141        &self,
142        _request: CallToolRequestParams,
143        _context: RequestContext<RoleServer>,
144    ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + Send + '_ {
145        std::future::ready(Err(McpError::internal_error(
146            "Task processing not implemented".to_string(),
147            None,
148        )))
149    }
150    fn ping(
151        &self,
152        context: RequestContext<RoleServer>,
153    ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
154        std::future::ready(Ok(()))
155    }
156    // handle requests
157    fn initialize(
158        &self,
159        request: InitializeRequestParams,
160        context: RequestContext<RoleServer>,
161    ) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
162        if context.peer.peer_info().is_none() {
163            context.peer.set_peer_info(request);
164        }
165        std::future::ready(Ok(self.get_info()))
166    }
167    fn complete(
168        &self,
169        request: CompleteRequestParams,
170        context: RequestContext<RoleServer>,
171    ) -> impl Future<Output = Result<CompleteResult, McpError>> + Send + '_ {
172        std::future::ready(Ok(CompleteResult::default()))
173    }
174    fn set_level(
175        &self,
176        request: SetLevelRequestParams,
177        context: RequestContext<RoleServer>,
178    ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
179        std::future::ready(Err(McpError::method_not_found::<SetLevelRequestMethod>()))
180    }
181    fn get_prompt(
182        &self,
183        request: GetPromptRequestParams,
184        context: RequestContext<RoleServer>,
185    ) -> impl Future<Output = Result<GetPromptResult, McpError>> + Send + '_ {
186        std::future::ready(Err(McpError::method_not_found::<GetPromptRequestMethod>()))
187    }
188    fn list_prompts(
189        &self,
190        request: Option<PaginatedRequestParams>,
191        context: RequestContext<RoleServer>,
192    ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + Send + '_ {
193        std::future::ready(Ok(ListPromptsResult::default()))
194    }
195    fn list_resources(
196        &self,
197        request: Option<PaginatedRequestParams>,
198        context: RequestContext<RoleServer>,
199    ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
200        std::future::ready(Ok(ListResourcesResult::default()))
201    }
202    fn list_resource_templates(
203        &self,
204        request: Option<PaginatedRequestParams>,
205        context: RequestContext<RoleServer>,
206    ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + Send + '_ {
207        std::future::ready(Ok(ListResourceTemplatesResult::default()))
208    }
209    fn read_resource(
210        &self,
211        request: ReadResourceRequestParams,
212        context: RequestContext<RoleServer>,
213    ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ {
214        std::future::ready(Err(
215            McpError::method_not_found::<ReadResourceRequestMethod>(),
216        ))
217    }
218    fn subscribe(
219        &self,
220        request: SubscribeRequestParams,
221        context: RequestContext<RoleServer>,
222    ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
223        std::future::ready(Err(McpError::method_not_found::<SubscribeRequestMethod>()))
224    }
225    fn unsubscribe(
226        &self,
227        request: UnsubscribeRequestParams,
228        context: RequestContext<RoleServer>,
229    ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
230        std::future::ready(Err(McpError::method_not_found::<UnsubscribeRequestMethod>()))
231    }
232    fn call_tool(
233        &self,
234        request: CallToolRequestParams,
235        context: RequestContext<RoleServer>,
236    ) -> impl Future<Output = Result<CallToolResult, McpError>> + Send + '_ {
237        std::future::ready(Err(McpError::method_not_found::<CallToolRequestMethod>()))
238    }
239    fn list_tools(
240        &self,
241        request: Option<PaginatedRequestParams>,
242        context: RequestContext<RoleServer>,
243    ) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
244        std::future::ready(Ok(ListToolsResult::default()))
245    }
246    fn on_custom_request(
247        &self,
248        request: CustomRequest,
249        context: RequestContext<RoleServer>,
250    ) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
251        let CustomRequest { method, .. } = request;
252        let _ = context;
253        std::future::ready(Err(McpError::new(
254            ErrorCode::METHOD_NOT_FOUND,
255            method,
256            None,
257        )))
258    }
259
260    fn on_cancelled(
261        &self,
262        notification: CancelledNotificationParam,
263        context: NotificationContext<RoleServer>,
264    ) -> impl Future<Output = ()> + Send + '_ {
265        std::future::ready(())
266    }
267    fn on_progress(
268        &self,
269        notification: ProgressNotificationParam,
270        context: NotificationContext<RoleServer>,
271    ) -> impl Future<Output = ()> + Send + '_ {
272        std::future::ready(())
273    }
274    fn on_initialized(
275        &self,
276        context: NotificationContext<RoleServer>,
277    ) -> impl Future<Output = ()> + Send + '_ {
278        tracing::info!("client initialized");
279        std::future::ready(())
280    }
281    fn on_roots_list_changed(
282        &self,
283        context: NotificationContext<RoleServer>,
284    ) -> impl Future<Output = ()> + Send + '_ {
285        std::future::ready(())
286    }
287    fn on_custom_notification(
288        &self,
289        notification: CustomNotification,
290        context: NotificationContext<RoleServer>,
291    ) -> impl Future<Output = ()> + Send + '_ {
292        let _ = (notification, context);
293        std::future::ready(())
294    }
295
296    fn get_info(&self) -> ServerInfo {
297        ServerInfo::default()
298    }
299
300    fn list_tasks(
301        &self,
302        request: Option<PaginatedRequestParams>,
303        context: RequestContext<RoleServer>,
304    ) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ {
305        std::future::ready(Err(McpError::method_not_found::<ListTasksMethod>()))
306    }
307
308    fn get_task_info(
309        &self,
310        request: GetTaskInfoParams,
311        context: RequestContext<RoleServer>,
312    ) -> impl Future<Output = Result<GetTaskInfoResult, McpError>> + Send + '_ {
313        std::future::ready(Err(McpError::method_not_found::<GetTaskInfoMethod>()))
314    }
315
316    fn get_task_result(
317        &self,
318        request: GetTaskResultParams,
319        context: RequestContext<RoleServer>,
320    ) -> impl Future<Output = Result<TaskResult, McpError>> + Send + '_ {
321        let _ = (request, context);
322        std::future::ready(Err(McpError::method_not_found::<GetTaskResultMethod>()))
323    }
324
325    fn cancel_task(
326        &self,
327        request: CancelTaskParams,
328        context: RequestContext<RoleServer>,
329    ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
330        let _ = (request, context);
331        std::future::ready(Err(McpError::method_not_found::<CancelTaskMethod>()))
332    }
333}
334
/// Implements [`ServerHandler`] for a smart-pointer wrapper `$wrapper<T>`
/// (used below for `Box` and `Arc`) by forwarding every method to the wrapped
/// `T`.
///
/// Each forwarder calls `<T as ServerHandler>::method` with `self`
/// (`&$wrapper<T>`), which deref-coerces to `&T` at the call site.
macro_rules! impl_server_handler_for_wrapper {
    ($wrapper:ident) => {
        impl<T: ServerHandler> ServerHandler for $wrapper<T> {
            // ---- metadata ----

            fn get_info(&self) -> ServerInfo {
                <T as ServerHandler>::get_info(self)
            }

            // ---- lifecycle ----

            fn initialize(
                &self,
                request: InitializeRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
                <T as ServerHandler>::initialize(self, request, context)
            }

            fn ping(
                &self,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
                <T as ServerHandler>::ping(self, context)
            }

            // ---- requests ----

            fn complete(
                &self,
                request: CompleteRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CompleteResult, McpError>> + Send + '_ {
                <T as ServerHandler>::complete(self, request, context)
            }

            fn set_level(
                &self,
                request: SetLevelRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
                <T as ServerHandler>::set_level(self, request, context)
            }

            fn get_prompt(
                &self,
                request: GetPromptRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<GetPromptResult, McpError>> + Send + '_ {
                <T as ServerHandler>::get_prompt(self, request, context)
            }

            fn list_prompts(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + Send + '_ {
                <T as ServerHandler>::list_prompts(self, request, context)
            }

            fn list_resources(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
                <T as ServerHandler>::list_resources(self, request, context)
            }

            fn list_resource_templates(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + Send + '_
            {
                <T as ServerHandler>::list_resource_templates(self, request, context)
            }

            fn read_resource(
                &self,
                request: ReadResourceRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ {
                <T as ServerHandler>::read_resource(self, request, context)
            }

            fn subscribe(
                &self,
                request: SubscribeRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
                <T as ServerHandler>::subscribe(self, request, context)
            }

            fn unsubscribe(
                &self,
                request: UnsubscribeRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
                <T as ServerHandler>::unsubscribe(self, request, context)
            }

            fn call_tool(
                &self,
                request: CallToolRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CallToolResult, McpError>> + Send + '_ {
                <T as ServerHandler>::call_tool(self, request, context)
            }

            fn list_tools(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
                <T as ServerHandler>::list_tools(self, request, context)
            }

            fn on_custom_request(
                &self,
                request: CustomRequest,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
                <T as ServerHandler>::on_custom_request(self, request, context)
            }

            // ---- tasks ----

            fn enqueue_task(
                &self,
                request: CallToolRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + Send + '_ {
                <T as ServerHandler>::enqueue_task(self, request, context)
            }

            fn list_tasks(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ {
                <T as ServerHandler>::list_tasks(self, request, context)
            }

            fn get_task_info(
                &self,
                request: GetTaskInfoParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<GetTaskInfoResult, McpError>> + Send + '_ {
                <T as ServerHandler>::get_task_info(self, request, context)
            }

            fn get_task_result(
                &self,
                request: GetTaskResultParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<TaskResult, McpError>> + Send + '_ {
                <T as ServerHandler>::get_task_result(self, request, context)
            }

            fn cancel_task(
                &self,
                request: CancelTaskParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
                <T as ServerHandler>::cancel_task(self, request, context)
            }

            // ---- notifications ----

            fn on_cancelled(
                &self,
                notification: CancelledNotificationParam,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + Send + '_ {
                <T as ServerHandler>::on_cancelled(self, notification, context)
            }

            fn on_progress(
                &self,
                notification: ProgressNotificationParam,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + Send + '_ {
                <T as ServerHandler>::on_progress(self, notification, context)
            }

            fn on_initialized(
                &self,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + Send + '_ {
                <T as ServerHandler>::on_initialized(self, context)
            }

            fn on_roots_list_changed(
                &self,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + Send + '_ {
                <T as ServerHandler>::on_roots_list_changed(self, context)
            }

            fn on_custom_notification(
                &self,
                notification: CustomNotification,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + Send + '_ {
                <T as ServerHandler>::on_custom_notification(self, notification, context)
            }
        }
    };
}
534
535impl_server_handler_for_wrapper!(Box);
536impl_server_handler_for_wrapper!(Arc);