1use std::sync::Arc;
2
3use crate::{
4 error::ErrorData as McpError,
5 model::{TaskSupport, *},
6 service::{NotificationContext, RequestContext, RoleServer, Service, ServiceRole},
7};
8
9pub mod common;
10pub mod prompt;
11mod resource;
12pub mod router;
13pub mod tool;
14pub mod tool_name_validation;
15pub mod wrapper;
16
17impl<H: ServerHandler> Service<RoleServer> for H {
18 async fn handle_request(
19 &self,
20 request: <RoleServer as ServiceRole>::PeerReq,
21 context: RequestContext<RoleServer>,
22 ) -> Result<<RoleServer as ServiceRole>::Resp, McpError> {
23 match request {
24 ClientRequest::InitializeRequest(request) => self
25 .initialize(request.params, context)
26 .await
27 .map(ServerResult::InitializeResult),
28 ClientRequest::PingRequest(_request) => {
29 self.ping(context).await.map(ServerResult::empty)
30 }
31 ClientRequest::CompleteRequest(request) => self
32 .complete(request.params, context)
33 .await
34 .map(ServerResult::CompleteResult),
35 ClientRequest::SetLevelRequest(request) => self
36 .set_level(request.params, context)
37 .await
38 .map(ServerResult::empty),
39 ClientRequest::GetPromptRequest(request) => self
40 .get_prompt(request.params, context)
41 .await
42 .map(ServerResult::GetPromptResult),
43 ClientRequest::ListPromptsRequest(request) => self
44 .list_prompts(request.params, context)
45 .await
46 .map(ServerResult::ListPromptsResult),
47 ClientRequest::ListResourcesRequest(request) => self
48 .list_resources(request.params, context)
49 .await
50 .map(ServerResult::ListResourcesResult),
51 ClientRequest::ListResourceTemplatesRequest(request) => self
52 .list_resource_templates(request.params, context)
53 .await
54 .map(ServerResult::ListResourceTemplatesResult),
55 ClientRequest::ReadResourceRequest(request) => self
56 .read_resource(request.params, context)
57 .await
58 .map(ServerResult::ReadResourceResult),
59 ClientRequest::SubscribeRequest(request) => self
60 .subscribe(request.params, context)
61 .await
62 .map(ServerResult::empty),
63 ClientRequest::UnsubscribeRequest(request) => self
64 .unsubscribe(request.params, context)
65 .await
66 .map(ServerResult::empty),
67 ClientRequest::CallToolRequest(request) => {
68 let is_task = request.params.task.is_some();
69
70 if let Some(tool) = self.get_tool(&request.params.name) {
72 match (tool.task_support(), is_task) {
73 (TaskSupport::Required, false) => {
76 return Err(McpError::new(
77 ErrorCode::METHOD_NOT_FOUND,
78 "Tool requires task-based invocation",
79 None,
80 ));
81 }
82 (TaskSupport::Forbidden, true) => {
84 return Err(McpError::invalid_params(
85 "Tool does not support task-based invocation",
86 None,
87 ));
88 }
89 _ => {}
90 }
91 }
92
93 if is_task {
94 tracing::info!("Enqueueing task for tool call: {}", request.params.name);
95 self.enqueue_task(request.params, context.clone())
96 .await
97 .map(ServerResult::CreateTaskResult)
98 } else {
99 self.call_tool(request.params, context)
100 .await
101 .map(ServerResult::CallToolResult)
102 }
103 }
104 ClientRequest::ListToolsRequest(request) => self
105 .list_tools(request.params, context)
106 .await
107 .map(ServerResult::ListToolsResult),
108 ClientRequest::CustomRequest(request) => self
109 .on_custom_request(request, context)
110 .await
111 .map(ServerResult::CustomResult),
112 ClientRequest::ListTasksRequest(request) => self
113 .list_tasks(request.params, context)
114 .await
115 .map(ServerResult::ListTasksResult),
116 ClientRequest::GetTaskInfoRequest(request) => self
117 .get_task_info(request.params, context)
118 .await
119 .map(ServerResult::GetTaskResult),
120 ClientRequest::GetTaskResultRequest(request) => self
121 .get_task_result(request.params, context)
122 .await
123 .map(ServerResult::GetTaskPayloadResult),
124 ClientRequest::CancelTaskRequest(request) => self
125 .cancel_task(request.params, context)
126 .await
127 .map(ServerResult::CancelTaskResult),
128 }
129 }
130
131 async fn handle_notification(
132 &self,
133 notification: <RoleServer as ServiceRole>::PeerNot,
134 context: NotificationContext<RoleServer>,
135 ) -> Result<(), McpError> {
136 match notification {
137 ClientNotification::CancelledNotification(notification) => {
138 self.on_cancelled(notification.params, context).await
139 }
140 ClientNotification::ProgressNotification(notification) => {
141 self.on_progress(notification.params, context).await
142 }
143 ClientNotification::InitializedNotification(_notification) => {
144 self.on_initialized(context).await
145 }
146 ClientNotification::RootsListChangedNotification(_notification) => {
147 self.on_roots_list_changed(context).await
148 }
149 ClientNotification::CustomNotification(notification) => {
150 self.on_custom_notification(notification, context).await
151 }
152 };
153 Ok(())
154 }
155
156 fn get_info(&self) -> <RoleServer as ServiceRole>::Info {
157 self.get_info()
158 }
159}
160
161#[allow(unused_variables)]
162pub trait ServerHandler: Sized + Send + Sync + 'static {
163 fn enqueue_task(
164 &self,
165 _request: CallToolRequestParams,
166 _context: RequestContext<RoleServer>,
167 ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + Send + '_ {
168 std::future::ready(Err(McpError::internal_error(
169 "Task processing not implemented".to_string(),
170 None,
171 )))
172 }
173 fn ping(
174 &self,
175 context: RequestContext<RoleServer>,
176 ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
177 std::future::ready(Ok(()))
178 }
179 fn initialize(
181 &self,
182 request: InitializeRequestParams,
183 context: RequestContext<RoleServer>,
184 ) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
185 if context.peer.peer_info().is_none() {
186 context.peer.set_peer_info(request);
187 }
188 std::future::ready(Ok(self.get_info()))
189 }
190 fn complete(
191 &self,
192 request: CompleteRequestParams,
193 context: RequestContext<RoleServer>,
194 ) -> impl Future<Output = Result<CompleteResult, McpError>> + Send + '_ {
195 std::future::ready(Ok(CompleteResult::default()))
196 }
197 fn set_level(
198 &self,
199 request: SetLevelRequestParams,
200 context: RequestContext<RoleServer>,
201 ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
202 std::future::ready(Err(McpError::method_not_found::<SetLevelRequestMethod>()))
203 }
204 fn get_prompt(
205 &self,
206 request: GetPromptRequestParams,
207 context: RequestContext<RoleServer>,
208 ) -> impl Future<Output = Result<GetPromptResult, McpError>> + Send + '_ {
209 std::future::ready(Err(McpError::method_not_found::<GetPromptRequestMethod>()))
210 }
211 fn list_prompts(
212 &self,
213 request: Option<PaginatedRequestParams>,
214 context: RequestContext<RoleServer>,
215 ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + Send + '_ {
216 std::future::ready(Ok(ListPromptsResult::default()))
217 }
218 fn list_resources(
219 &self,
220 request: Option<PaginatedRequestParams>,
221 context: RequestContext<RoleServer>,
222 ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
223 std::future::ready(Ok(ListResourcesResult::default()))
224 }
225 fn list_resource_templates(
226 &self,
227 request: Option<PaginatedRequestParams>,
228 context: RequestContext<RoleServer>,
229 ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + Send + '_ {
230 std::future::ready(Ok(ListResourceTemplatesResult::default()))
231 }
232 fn read_resource(
233 &self,
234 request: ReadResourceRequestParams,
235 context: RequestContext<RoleServer>,
236 ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ {
237 std::future::ready(Err(
238 McpError::method_not_found::<ReadResourceRequestMethod>(),
239 ))
240 }
241 fn subscribe(
242 &self,
243 request: SubscribeRequestParams,
244 context: RequestContext<RoleServer>,
245 ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
246 std::future::ready(Err(McpError::method_not_found::<SubscribeRequestMethod>()))
247 }
248 fn unsubscribe(
249 &self,
250 request: UnsubscribeRequestParams,
251 context: RequestContext<RoleServer>,
252 ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
253 std::future::ready(Err(McpError::method_not_found::<UnsubscribeRequestMethod>()))
254 }
255 fn call_tool(
256 &self,
257 request: CallToolRequestParams,
258 context: RequestContext<RoleServer>,
259 ) -> impl Future<Output = Result<CallToolResult, McpError>> + Send + '_ {
260 std::future::ready(Err(McpError::method_not_found::<CallToolRequestMethod>()))
261 }
262 fn list_tools(
263 &self,
264 request: Option<PaginatedRequestParams>,
265 context: RequestContext<RoleServer>,
266 ) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
267 std::future::ready(Ok(ListToolsResult::default()))
268 }
269 fn get_tool(&self, _name: &str) -> Option<Tool> {
274 None
275 }
276 fn on_custom_request(
277 &self,
278 request: CustomRequest,
279 context: RequestContext<RoleServer>,
280 ) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
281 let CustomRequest { method, .. } = request;
282 let _ = context;
283 std::future::ready(Err(McpError::new(
284 ErrorCode::METHOD_NOT_FOUND,
285 method,
286 None,
287 )))
288 }
289
290 fn on_cancelled(
291 &self,
292 notification: CancelledNotificationParam,
293 context: NotificationContext<RoleServer>,
294 ) -> impl Future<Output = ()> + Send + '_ {
295 std::future::ready(())
296 }
297 fn on_progress(
298 &self,
299 notification: ProgressNotificationParam,
300 context: NotificationContext<RoleServer>,
301 ) -> impl Future<Output = ()> + Send + '_ {
302 std::future::ready(())
303 }
304 fn on_initialized(
305 &self,
306 context: NotificationContext<RoleServer>,
307 ) -> impl Future<Output = ()> + Send + '_ {
308 tracing::info!("client initialized");
309 std::future::ready(())
310 }
311 fn on_roots_list_changed(
312 &self,
313 context: NotificationContext<RoleServer>,
314 ) -> impl Future<Output = ()> + Send + '_ {
315 std::future::ready(())
316 }
317 fn on_custom_notification(
318 &self,
319 notification: CustomNotification,
320 context: NotificationContext<RoleServer>,
321 ) -> impl Future<Output = ()> + Send + '_ {
322 let _ = (notification, context);
323 std::future::ready(())
324 }
325
326 fn get_info(&self) -> ServerInfo {
327 ServerInfo::default()
328 }
329
330 fn list_tasks(
331 &self,
332 request: Option<PaginatedRequestParams>,
333 context: RequestContext<RoleServer>,
334 ) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ {
335 std::future::ready(Err(McpError::method_not_found::<ListTasksMethod>()))
336 }
337
338 fn get_task_info(
339 &self,
340 request: GetTaskInfoParams,
341 context: RequestContext<RoleServer>,
342 ) -> impl Future<Output = Result<GetTaskResult, McpError>> + Send + '_ {
343 let _ = (request, context);
344 std::future::ready(Err(McpError::method_not_found::<GetTaskInfoMethod>()))
345 }
346
347 fn get_task_result(
348 &self,
349 request: GetTaskResultParams,
350 context: RequestContext<RoleServer>,
351 ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + Send + '_ {
352 let _ = (request, context);
353 std::future::ready(Err(McpError::method_not_found::<GetTaskResultMethod>()))
354 }
355
356 fn cancel_task(
357 &self,
358 request: CancelTaskParams,
359 context: RequestContext<RoleServer>,
360 ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + Send + '_ {
361 let _ = (request, context);
362 std::future::ready(Err(McpError::method_not_found::<CancelTaskMethod>()))
363 }
364}
365
/// Implements [`ServerHandler`] for a single-type-parameter smart pointer
/// (`$wrapper<T>`, e.g. `Box<T>` or `Arc<T>`) around any `T: ServerHandler`.
///
/// Every trait method is forwarded to the wrapped handler. `self` (a
/// `&$wrapper<T>`) deref-coerces to `&T` at each call, so any method `T`
/// overrides keeps its custom behavior through the wrapper.
macro_rules! impl_server_handler_for_wrapper {
    ($wrapper:ident) => {
        impl<T: ServerHandler> ServerHandler for $wrapper<T> {
            fn enqueue_task(
                &self,
                request: CallToolRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + Send + '_ {
                T::enqueue_task(self, request, context)
            }

            fn ping(
                &self,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
                T::ping(self, context)
            }

            fn initialize(
                &self,
                request: InitializeRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
                T::initialize(self, request, context)
            }

            fn complete(
                &self,
                request: CompleteRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CompleteResult, McpError>> + Send + '_ {
                T::complete(self, request, context)
            }

            fn set_level(
                &self,
                request: SetLevelRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
                T::set_level(self, request, context)
            }

            fn get_prompt(
                &self,
                request: GetPromptRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<GetPromptResult, McpError>> + Send + '_ {
                T::get_prompt(self, request, context)
            }

            fn list_prompts(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + Send + '_ {
                T::list_prompts(self, request, context)
            }

            fn list_resources(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
                T::list_resources(self, request, context)
            }

            fn list_resource_templates(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + Send + '_
            {
                T::list_resource_templates(self, request, context)
            }

            fn read_resource(
                &self,
                request: ReadResourceRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ {
                T::read_resource(self, request, context)
            }

            fn subscribe(
                &self,
                request: SubscribeRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
                T::subscribe(self, request, context)
            }

            fn unsubscribe(
                &self,
                request: UnsubscribeRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
                T::unsubscribe(self, request, context)
            }

            fn call_tool(
                &self,
                request: CallToolRequestParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CallToolResult, McpError>> + Send + '_ {
                T::call_tool(self, request, context)
            }

            fn list_tools(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
                T::list_tools(self, request, context)
            }

            fn get_tool(&self, name: &str) -> Option<Tool> {
                T::get_tool(self, name)
            }

            fn on_custom_request(
                &self,
                request: CustomRequest,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
                T::on_custom_request(self, request, context)
            }

            fn on_cancelled(
                &self,
                notification: CancelledNotificationParam,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + Send + '_ {
                T::on_cancelled(self, notification, context)
            }

            fn on_progress(
                &self,
                notification: ProgressNotificationParam,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + Send + '_ {
                T::on_progress(self, notification, context)
            }

            fn on_initialized(
                &self,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + Send + '_ {
                T::on_initialized(self, context)
            }

            fn on_roots_list_changed(
                &self,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + Send + '_ {
                T::on_roots_list_changed(self, context)
            }

            fn on_custom_notification(
                &self,
                notification: CustomNotification,
                context: NotificationContext<RoleServer>,
            ) -> impl Future<Output = ()> + Send + '_ {
                T::on_custom_notification(self, notification, context)
            }

            fn get_info(&self) -> ServerInfo {
                // Fully qualified because `Service` also declares `get_info`.
                <T as ServerHandler>::get_info(self)
            }

            fn list_tasks(
                &self,
                request: Option<PaginatedRequestParams>,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ {
                T::list_tasks(self, request, context)
            }

            fn get_task_info(
                &self,
                request: GetTaskInfoParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<GetTaskResult, McpError>> + Send + '_ {
                T::get_task_info(self, request, context)
            }

            fn get_task_result(
                &self,
                request: GetTaskResultParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + Send + '_ {
                T::get_task_result(self, request, context)
            }

            fn cancel_task(
                &self,
                request: CancelTaskParams,
                context: RequestContext<RoleServer>,
            ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + Send + '_ {
                T::cancel_task(self, request, context)
            }
        }
    };
}
569
570impl_server_handler_for_wrapper!(Box);
571impl_server_handler_for_wrapper!(Arc);