1use crate::builder::{NotRegistered, Registered, Server};
36use crate::context::{CancellationToken, Context, Peer};
37use crate::handler::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
38use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
39use mcpkit_core::error::McpError;
40use mcpkit_core::protocol::{Message, Notification, ProgressToken, Request, Response};
41use mcpkit_core::protocol_version::ProtocolVersion;
42use mcpkit_core::types::CallToolResult;
43use mcpkit_transport::Transport;
44use std::collections::HashMap;
45use std::sync::Arc;
46use std::sync::RwLock;
47use std::sync::atomic::{AtomicBool, Ordering};
48
49pub struct ServerState {
51 pub client_caps: RwLock<ClientCapabilities>,
53 pub server_caps: ServerCapabilities,
55 pub initialized: AtomicBool,
57 pub cancellations: RwLock<HashMap<String, CancellationToken>>,
59 pub negotiated_version: RwLock<Option<ProtocolVersion>>,
64}
65
66impl ServerState {
67 #[must_use]
69 pub fn new(server_caps: ServerCapabilities) -> Self {
70 Self {
71 client_caps: RwLock::new(ClientCapabilities::default()),
72 server_caps,
73 initialized: AtomicBool::new(false),
74 cancellations: RwLock::new(HashMap::new()),
75 negotiated_version: RwLock::new(None),
76 }
77 }
78
79 pub fn protocol_version(&self) -> Option<ProtocolVersion> {
93 self.negotiated_version.read().ok().and_then(|guard| *guard)
94 }
95
96 pub fn set_protocol_version(&self, version: ProtocolVersion) {
100 if let Ok(mut guard) = self.negotiated_version.write() {
101 *guard = Some(version);
102 }
103 }
104
105 pub fn client_caps(&self) -> ClientCapabilities {
109 self.client_caps
110 .read()
111 .map(|guard| guard.clone())
112 .unwrap_or_default()
113 }
114
115 pub fn set_client_caps(&self, caps: ClientCapabilities) {
119 if let Ok(mut guard) = self.client_caps.write() {
120 *guard = caps;
121 }
122 }
123
124 pub fn is_initialized(&self) -> bool {
126 self.initialized.load(Ordering::Acquire)
127 }
128
129 pub fn set_initialized(&self) {
131 self.initialized.store(true, Ordering::Release);
132 }
133
134 pub fn register_cancellation(&self, request_id: &str, token: CancellationToken) {
136 if let Ok(mut cancellations) = self.cancellations.write() {
137 cancellations.insert(request_id.to_string(), token);
138 }
139 }
140
141 pub fn cancel_request(&self, request_id: &str) {
143 if let Ok(cancellations) = self.cancellations.read() {
144 if let Some(token) = cancellations.get(request_id) {
145 token.cancel();
146 }
147 }
148 }
149
150 pub fn remove_cancellation(&self, request_id: &str) {
152 if let Ok(mut cancellations) = self.cancellations.write() {
153 cancellations.remove(request_id);
154 }
155 }
156}
157
158pub struct TransportPeer<T: Transport> {
160 transport: Arc<T>,
161}
162
163impl<T: Transport> TransportPeer<T> {
164 pub const fn new(transport: Arc<T>) -> Self {
166 Self { transport }
167 }
168}
169
170impl<T: Transport + 'static> Peer for TransportPeer<T>
171where
172 T::Error: Into<McpError>,
173{
174 fn notify(
175 &self,
176 notification: Notification,
177 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), McpError>> + Send + '_>>
178 {
179 let transport = self.transport.clone();
180 Box::pin(async move {
181 transport
182 .send(Message::Notification(notification))
183 .await
184 .map_err(std::convert::Into::into)
185 })
186 }
187}
188
/// Tunables for a server runtime.
///
/// NOTE(review): the runtime stores this value but nothing reads it yet, so these
/// settings are currently advisory only.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Whether the session should be treated as initialized automatically
    /// (not yet consulted).
    pub auto_initialized: bool,
    /// Intended cap on concurrently processed requests (not yet enforced).
    pub max_concurrent_requests: usize,
}

impl Default for RuntimeConfig {
    /// `auto_initialized = true`, `max_concurrent_requests = 100`.
    fn default() -> Self {
        let max_concurrent_requests = 100;
        let auto_initialized = true;
        RuntimeConfig {
            auto_initialized,
            max_concurrent_requests,
        }
    }
}
206
207pub struct ServerRuntime<S, Tr>
212where
213 Tr: Transport,
214{
215 server: S,
216 transport: Arc<Tr>,
217 state: Arc<ServerState>,
218 #[allow(dead_code)]
220 config: RuntimeConfig,
221}
222
223impl<S, Tr> ServerRuntime<S, Tr>
224where
225 S: RequestRouter + Send + Sync,
226 Tr: Transport + 'static,
227 Tr::Error: Into<McpError>,
228{
229 pub const fn state(&self) -> &Arc<ServerState> {
231 &self.state
232 }
233
234 pub async fn run(&self) -> Result<(), McpError> {
238 loop {
239 match self.transport.recv().await {
240 Ok(Some(message)) => {
241 if let Err(e) = self.handle_message(message).await {
242 tracing::error!(error = %e, "Error handling message");
243 }
244 }
245 Ok(None) => {
246 tracing::info!("Connection closed");
248 break;
249 }
250 Err(e) => {
251 let err: McpError = e.into();
252 tracing::error!(error = %err, "Transport error");
253 return Err(err);
254 }
255 }
256 }
257
258 Ok(())
259 }
260
261 async fn handle_message(&self, message: Message) -> Result<(), McpError> {
263 match message {
264 Message::Request(request) => self.handle_request(request).await,
265 Message::Notification(notification) => self.handle_notification(notification).await,
266 Message::Response(_) => {
267 tracing::warn!("Received unexpected response message");
269 Ok(())
270 }
271 }
272 }
273
274 async fn handle_request(&self, request: Request) -> Result<(), McpError> {
276 let method = request.method.to_string();
277 let id = request.id.clone();
278
279 tracing::debug!(method = %method, id = %id, "Handling request");
280
281 let response = match method.as_str() {
282 "initialize" => self.handle_initialize(&request).await,
283 _ if !self.state.is_initialized() => {
284 Err(McpError::invalid_request("Server not initialized"))
285 }
286 _ => self.route_request(&request).await,
287 };
288
289 let response_msg = match response {
291 Ok(result) => Response::success(id, result),
292 Err(e) => Response::error(id, e.into()),
293 };
294
295 self.transport
296 .send(Message::Response(response_msg))
297 .await
298 .map_err(std::convert::Into::into)
299 }
300
301 async fn handle_initialize(&self, request: &Request) -> Result<serde_json::Value, McpError> {
308 if self.state.is_initialized() {
309 return Err(McpError::invalid_request("Already initialized"));
310 }
311
312 let params = request
314 .params
315 .as_ref()
316 .ok_or_else(|| McpError::invalid_params("initialize", "missing params"))?;
317
318 let requested_version_str = params
320 .get("protocolVersion")
321 .and_then(|v| v.as_str())
322 .unwrap_or("");
323
324 let negotiated_version =
326 ProtocolVersion::negotiate(requested_version_str, ProtocolVersion::ALL)
327 .unwrap_or(ProtocolVersion::LATEST);
328
329 if requested_version_str == negotiated_version.as_str() {
331 tracing::debug!(
332 version = %negotiated_version,
333 "Protocol version negotiated successfully"
334 );
335 } else {
336 tracing::info!(
337 requested = %requested_version_str,
338 negotiated = %negotiated_version,
339 supported = ?ProtocolVersion::ALL.iter().map(ProtocolVersion::as_str).collect::<Vec<_>>(),
340 "Protocol version negotiation: client requested different version"
341 );
342 }
343
344 self.state.set_protocol_version(negotiated_version);
346
347 if let Some(caps) = params.get("capabilities") {
349 if let Ok(client_caps) = serde_json::from_value::<ClientCapabilities>(caps.clone()) {
350 self.state.set_client_caps(client_caps);
351 }
352 }
353
354 let result = serde_json::json!({
356 "protocolVersion": negotiated_version.as_str(),
357 "serverInfo": self.server.server_info(),
358 "capabilities": self.state.server_caps
359 });
360
361 self.state.set_initialized();
362
363 Ok(result)
364 }
365
366 async fn route_request(&self, request: &Request) -> Result<serde_json::Value, McpError> {
368 let method = request.method.as_ref();
369 let params = request.params.as_ref();
370
371 let progress_token = extract_progress_token(params);
373
374 let peer = TransportPeer::new(self.transport.clone());
376 let client_caps = self.state.client_caps();
377 let protocol_version = self
378 .state
379 .protocol_version()
380 .unwrap_or(ProtocolVersion::LATEST);
381 let ctx = Context::new(
382 &request.id,
383 progress_token.as_ref(),
384 &client_caps,
385 &self.state.server_caps,
386 protocol_version,
387 &peer,
388 );
389
390 self.server.route(method, params, &ctx).await
392 }
393
394 async fn handle_notification(&self, notification: Notification) -> Result<(), McpError> {
396 let method = notification.method.as_ref();
397
398 tracing::debug!(method = %method, "Handling notification");
399
400 match method {
401 "notifications/initialized" => {
402 tracing::info!("Client sent initialized notification");
403 Ok(())
404 }
405 "notifications/cancelled" => {
406 if let Some(params) = ¬ification.params {
407 if let Some(request_id) = params.get("requestId").and_then(|v| v.as_str()) {
408 self.state.cancel_request(request_id);
409 }
410 }
411 Ok(())
412 }
413 _ => {
414 tracing::debug!(method = %method, "Ignoring unknown notification");
415 Ok(())
416 }
417 }
418 }
419}
420
421impl<H, T, R, P, K, Tr> ServerRuntime<Server<H, T, R, P, K>, Tr>
423where
424 H: ServerHandler + Send + Sync,
425 T: Send + Sync,
426 R: Send + Sync,
427 P: Send + Sync,
428 K: Send + Sync,
429 Tr: Transport + 'static,
430 Tr::Error: Into<McpError>,
431{
432 pub fn new(server: Server<H, T, R, P, K>, transport: Tr) -> Self {
434 let caps = server.capabilities().clone();
435 Self {
436 server,
437 transport: Arc::new(transport),
438 state: Arc::new(ServerState::new(caps)),
439 config: RuntimeConfig::default(),
440 }
441 }
442
443 pub fn with_config(
445 server: Server<H, T, R, P, K>,
446 transport: Tr,
447 config: RuntimeConfig,
448 ) -> Self {
449 let caps = server.capabilities().clone();
450 Self {
451 server,
452 transport: Arc::new(transport),
453 state: Arc::new(ServerState::new(caps)),
454 config,
455 }
456 }
457}
458
459#[allow(async_fn_in_trait)]
464pub trait RequestRouter: Send + Sync {
465 fn server_info(&self) -> mcpkit_core::capability::ServerInfo;
467
468 async fn route(
470 &self,
471 method: &str,
472 params: Option<&serde_json::Value>,
473 ctx: &Context<'_>,
474 ) -> Result<serde_json::Value, McpError>;
475}
476
477impl<H, T, R, P, K> Server<H, T, R, P, K>
479where
480 H: ServerHandler + Send + Sync + 'static,
481 T: Send + Sync + 'static,
482 R: Send + Sync + 'static,
483 P: Send + Sync + 'static,
484 K: Send + Sync + 'static,
485 Self: RequestRouter,
486{
487 pub async fn serve<Tr>(self, transport: Tr) -> Result<(), McpError>
489 where
490 Tr: Transport + 'static,
491 Tr::Error: Into<McpError>,
492 {
493 let runtime = ServerRuntime::new(self, transport);
494 runtime.run().await
495 }
496}
497
498async fn route_tools<TH: ToolHandler + Send + Sync>(
506 handler: &TH,
507 method: &str,
508 params: Option<&serde_json::Value>,
509 ctx: &Context<'_>,
510) -> Option<Result<serde_json::Value, McpError>> {
511 match method {
512 "tools/list" => {
513 tracing::debug!("Listing available tools");
514 let result = handler.list_tools(ctx).await;
515 match &result {
516 Ok(tools) => tracing::debug!(count = tools.len(), "Listed tools"),
517 Err(e) => tracing::warn!(error = %e, "Failed to list tools"),
518 }
519 Some(result.map(|tools| serde_json::json!({ "tools": tools })))
520 }
521 "tools/call" => {
522 let result = async {
523 let params = params.ok_or_else(|| {
524 McpError::invalid_params("tools/call", "missing params")
525 })?;
526 let name = params.get("name")
527 .and_then(|v| v.as_str())
528 .ok_or_else(|| McpError::invalid_params("tools/call", "missing tool name"))?;
529 let args = params.get("arguments")
530 .cloned()
531 .unwrap_or_else(|| serde_json::json!({}));
532
533 tracing::info!(tool = %name, "Calling tool");
534 let start = std::time::Instant::now();
535 let output = handler.call_tool(name, args, ctx).await;
536 let duration = start.elapsed();
537
538 match &output {
539 Ok(_) => tracing::info!(tool = %name, duration_ms = duration.as_millis(), "Tool call completed"),
540 Err(e) => tracing::warn!(tool = %name, duration_ms = duration.as_millis(), error = %e, "Tool call failed"),
541 }
542
543 let output = output?;
544 let result: CallToolResult = output.into();
545 Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
546 }.await;
547 Some(result)
548 }
549 _ => None,
550 }
551}
552
553async fn route_resources<RH: ResourceHandler + Send + Sync>(
554 handler: &RH,
555 method: &str,
556 params: Option<&serde_json::Value>,
557 ctx: &Context<'_>,
558) -> Option<Result<serde_json::Value, McpError>> {
559 match method {
560 "resources/list" => {
561 tracing::debug!("Listing available resources");
562 let result = handler.list_resources(ctx).await;
563 match &result {
564 Ok(resources) => tracing::debug!(count = resources.len(), "Listed resources"),
565 Err(e) => tracing::warn!(error = %e, "Failed to list resources"),
566 }
567 Some(result.map(|resources| serde_json::json!({ "resources": resources })))
568 }
569 "resources/templates/list" => {
570 tracing::debug!("Listing available resource templates");
571 let result = handler.list_resource_templates(ctx).await;
572 match &result {
573 Ok(templates) => {
574 tracing::debug!(count = templates.len(), "Listed resource templates");
575 }
576 Err(e) => tracing::warn!(error = %e, "Failed to list resource templates"),
577 }
578 Some(result.map(|templates| serde_json::json!({ "resourceTemplates": templates })))
579 }
580 "resources/read" => {
581 let result = async {
582 let params = params.ok_or_else(|| {
583 McpError::invalid_params("resources/read", "missing params")
584 })?;
585 let uri = params.get("uri")
586 .and_then(|v| v.as_str())
587 .ok_or_else(|| McpError::invalid_params("resources/read", "missing uri"))?;
588
589 tracing::info!(uri = %uri, "Reading resource");
590 let start = std::time::Instant::now();
591 let contents = handler.read_resource(uri, ctx).await;
592 let duration = start.elapsed();
593
594 match &contents {
595 Ok(_) => tracing::info!(uri = %uri, duration_ms = duration.as_millis(), "Resource read completed"),
596 Err(e) => tracing::warn!(uri = %uri, duration_ms = duration.as_millis(), error = %e, "Resource read failed"),
597 }
598
599 let contents = contents?;
600 Ok(serde_json::json!({ "contents": contents }))
601 }.await;
602 Some(result)
603 }
604 _ => None,
605 }
606}
607
608async fn route_prompts<PH: PromptHandler + Send + Sync>(
609 handler: &PH,
610 method: &str,
611 params: Option<&serde_json::Value>,
612 ctx: &Context<'_>,
613) -> Option<Result<serde_json::Value, McpError>> {
614 match method {
615 "prompts/list" => {
616 tracing::debug!("Listing available prompts");
617 let result = handler.list_prompts(ctx).await;
618 match &result {
619 Ok(prompts) => tracing::debug!(count = prompts.len(), "Listed prompts"),
620 Err(e) => tracing::warn!(error = %e, "Failed to list prompts"),
621 }
622 Some(result.map(|prompts| serde_json::json!({ "prompts": prompts })))
623 }
624 "prompts/get" => {
625 let result = async {
626 let params = params.ok_or_else(|| {
627 McpError::invalid_params("prompts/get", "missing params")
628 })?;
629 let name = params.get("name")
630 .and_then(|v| v.as_str())
631 .ok_or_else(|| McpError::invalid_params("prompts/get", "missing prompt name"))?;
632 let args = params.get("arguments")
633 .and_then(|v| v.as_object())
634 .cloned();
635
636 tracing::info!(prompt = %name, "Getting prompt");
637 let start = std::time::Instant::now();
638 let prompt_result = handler.get_prompt(name, args, ctx).await;
639 let duration = start.elapsed();
640
641 match &prompt_result {
642 Ok(_) => tracing::info!(prompt = %name, duration_ms = duration.as_millis(), "Prompt retrieval completed"),
643 Err(e) => tracing::warn!(prompt = %name, duration_ms = duration.as_millis(), error = %e, "Prompt retrieval failed"),
644 }
645
646 let result = prompt_result?;
647 Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
648 }.await;
649 Some(result)
650 }
651 _ => None,
652 }
653}
654
/// Generates [`RequestRouter`] impls for each supported combination of registered
/// tool/resource/prompt handlers on [`Server`] (the fifth slot must be `NotRegistered`).
///
/// Every generated router behaves the same way:
/// 1. It answers `ping` with `{}`.
/// 2. It tries the registered routers in the fixed order tools → resources → prompts.
/// 3. Otherwise it reports `method_not_found`.
///
/// Tokens after the `;` are appended to the impl's generic parameter list.
macro_rules! impl_request_router {
    // Shared expansion for every arm that registers at least one handler.
    //   generics: handler type parameters, each with the handler trait it must implement
    //   slots:    the tools / resources / prompts type arguments of `Server`
    //   routes:   `route_*` helper plus the accessor yielding its handler, tried in order
    //   extra:    caller-supplied tokens appended to the generic parameter list
    (@emit
        generics: [$($generic:ident : $bound:ident),*],
        slots: [$tools:ty, $resources:ty, $prompts:ty],
        routes: [$($route_fn:ident($accessor:ident)),*],
        extra: [$($extra:tt)*]
    ) => {
        impl<H $(, $generic)* $($extra)*> RequestRouter
            for Server<H, $tools, $resources, $prompts, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            $($generic: $bound + Send + Sync,)*
        {
            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
                self.handler().server_info()
            }

            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                $(
                    if let Some(routed) = $route_fn(self.$accessor(), method, params, ctx).await {
                        return routed;
                    }
                )*
                Err(McpError::method_not_found(method))
            }
        }
    };

    // No handlers registered: only `ping` is understood.
    (base; $($bounds:tt)*) => {
        impl<H $($bounds)*> RequestRouter
            for Server<H, NotRegistered, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
        {
            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
                self.handler().server_info()
            }

            async fn route(
                &self,
                method: &str,
                _params: Option<&serde_json::Value>,
                _ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    Ok(serde_json::json!({}))
                } else {
                    Err(McpError::method_not_found(method))
                }
            }
        }
    };

    (tools; $($bounds:tt)*) => {
        impl_request_router!(@emit
            generics: [TH: ToolHandler],
            slots: [Registered<TH>, NotRegistered, NotRegistered],
            routes: [route_tools(tool_handler)],
            extra: [$($bounds)*]
        );
    };

    (resources; $($bounds:tt)*) => {
        impl_request_router!(@emit
            generics: [RH: ResourceHandler],
            slots: [NotRegistered, Registered<RH>, NotRegistered],
            routes: [route_resources(resource_handler)],
            extra: [$($bounds)*]
        );
    };

    (prompts; $($bounds:tt)*) => {
        impl_request_router!(@emit
            generics: [PH: PromptHandler],
            slots: [NotRegistered, NotRegistered, Registered<PH>],
            routes: [route_prompts(prompt_handler)],
            extra: [$($bounds)*]
        );
    };

    (tools_resources; $($bounds:tt)*) => {
        impl_request_router!(@emit
            generics: [TH: ToolHandler, RH: ResourceHandler],
            slots: [Registered<TH>, Registered<RH>, NotRegistered],
            routes: [route_tools(tool_handler), route_resources(resource_handler)],
            extra: [$($bounds)*]
        );
    };

    (tools_prompts; $($bounds:tt)*) => {
        impl_request_router!(@emit
            generics: [TH: ToolHandler, PH: PromptHandler],
            slots: [Registered<TH>, NotRegistered, Registered<PH>],
            routes: [route_tools(tool_handler), route_prompts(prompt_handler)],
            extra: [$($bounds)*]
        );
    };

    (resources_prompts; $($bounds:tt)*) => {
        impl_request_router!(@emit
            generics: [RH: ResourceHandler, PH: PromptHandler],
            slots: [NotRegistered, Registered<RH>, Registered<PH>],
            routes: [route_resources(resource_handler), route_prompts(prompt_handler)],
            extra: [$($bounds)*]
        );
    };

    (tools_resources_prompts; $($bounds:tt)*) => {
        impl_request_router!(@emit
            generics: [TH: ToolHandler, RH: ResourceHandler, PH: PromptHandler],
            slots: [Registered<TH>, Registered<RH>, Registered<PH>],
            routes: [
                route_tools(tool_handler),
                route_resources(resource_handler),
                route_prompts(prompt_handler)
            ],
            extra: [$($bounds)*]
        );
    };
}
900
901impl_request_router!(base;);
903impl_request_router!(tools;);
904impl_request_router!(resources;);
905impl_request_router!(prompts;);
906impl_request_router!(tools_resources;);
907impl_request_router!(tools_prompts;);
908impl_request_router!(resources_prompts;);
909impl_request_router!(tools_resources_prompts;);
910
911fn extract_progress_token(params: Option<&serde_json::Value>) -> Option<ProgressToken> {
932 params?
933 .get("_meta")?
934 .get("progressToken")
935 .and_then(|v| serde_json::from_value(v.clone()).ok())
936}
937
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_state_initialization() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(!state.is_initialized());

        state.set_initialized();
        assert!(state.is_initialized());
    }

    #[test]
    fn test_server_state_protocol_version() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(state.protocol_version().is_none());

        state.set_protocol_version(ProtocolVersion::LATEST);
        assert!(state.protocol_version().is_some());
    }

    #[test]
    fn test_cancellation_management() {
        let state = ServerState::new(ServerCapabilities::default());
        let token = CancellationToken::new();

        state.register_cancellation("req-1", token.clone());
        assert!(!token.is_cancelled());

        state.cancel_request("req-1");
        assert!(token.is_cancelled());

        state.remove_cancellation("req-1");
    }

    #[test]
    fn test_cancel_unknown_request_is_noop() {
        let state = ServerState::new(ServerCapabilities::default());
        let token = CancellationToken::new();

        state.register_cancellation("req-1", token.clone());
        state.cancel_request("req-2");
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_runtime_config_default() {
        let config = RuntimeConfig::default();
        assert!(config.auto_initialized);
        assert_eq!(config.max_concurrent_requests, 100);
    }

    #[test]
    fn test_extract_progress_token_string() -> Result<(), Box<dyn std::error::Error>> {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": "my-token-123"
            },
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_some());
        assert_eq!(
            token.ok_or("Token not found")?,
            ProgressToken::String("my-token-123".to_string())
        );

        Ok(())
    }

    #[test]
    fn test_extract_progress_token_number() -> Result<(), Box<dyn std::error::Error>> {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": 42
            },
            "arguments": {}
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_some());
        assert_eq!(token.ok_or("Token not found")?, ProgressToken::Number(42));

        Ok(())
    }

    #[test]
    fn test_extract_progress_token_missing_meta() {
        let params = serde_json::json!({
            "name": "test-tool",
            "arguments": {}
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_non_object_meta() {
        let params = serde_json::json!({
            "_meta": "not-an-object",
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_missing_token() {
        let params = serde_json::json!({
            "_meta": {},
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_none_params() {
        let token = extract_progress_token(None);
        assert!(token.is_none());
    }
}