1use crate::builder::{NotRegistered, Registered, Server};
36use crate::context::{CancellationToken, Context, Peer};
37use crate::handler::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
38use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
39use mcpkit_core::error::McpError;
40use mcpkit_core::protocol::{Message, Notification, ProgressToken, Request, Response};
41use mcpkit_core::protocol_version::ProtocolVersion;
42use mcpkit_core::types::CallToolResult;
43use mcpkit_transport::Transport;
44use std::collections::HashMap;
45use std::sync::Arc;
46use std::sync::RwLock;
47use std::sync::atomic::{AtomicBool, Ordering};
48
49pub struct ServerState {
51 pub client_caps: RwLock<ClientCapabilities>,
53 pub server_caps: ServerCapabilities,
55 pub initialized: AtomicBool,
57 pub cancellations: RwLock<HashMap<String, CancellationToken>>,
59 pub negotiated_version: RwLock<Option<ProtocolVersion>>,
64}
65
66impl ServerState {
67 #[must_use]
69 pub fn new(server_caps: ServerCapabilities) -> Self {
70 Self {
71 client_caps: RwLock::new(ClientCapabilities::default()),
72 server_caps,
73 initialized: AtomicBool::new(false),
74 cancellations: RwLock::new(HashMap::new()),
75 negotiated_version: RwLock::new(None),
76 }
77 }
78
79 pub fn protocol_version(&self) -> Option<ProtocolVersion> {
93 self.negotiated_version.read().ok().and_then(|guard| *guard)
94 }
95
96 pub fn set_protocol_version(&self, version: ProtocolVersion) {
100 if let Ok(mut guard) = self.negotiated_version.write() {
101 *guard = Some(version);
102 }
103 }
104
105 pub fn client_caps(&self) -> ClientCapabilities {
109 self.client_caps
110 .read()
111 .map(|guard| guard.clone())
112 .unwrap_or_default()
113 }
114
115 pub fn set_client_caps(&self, caps: ClientCapabilities) {
119 if let Ok(mut guard) = self.client_caps.write() {
120 *guard = caps;
121 }
122 }
123
124 pub fn is_initialized(&self) -> bool {
126 self.initialized.load(Ordering::Acquire)
127 }
128
129 pub fn set_initialized(&self) {
131 self.initialized.store(true, Ordering::Release);
132 }
133
134 pub fn register_cancellation(&self, request_id: &str, token: CancellationToken) {
136 if let Ok(mut cancellations) = self.cancellations.write() {
137 cancellations.insert(request_id.to_string(), token);
138 }
139 }
140
141 pub fn cancel_request(&self, request_id: &str) {
143 if let Ok(cancellations) = self.cancellations.read() {
144 if let Some(token) = cancellations.get(request_id) {
145 token.cancel();
146 }
147 }
148 }
149
150 pub fn remove_cancellation(&self, request_id: &str) {
152 if let Ok(mut cancellations) = self.cancellations.write() {
153 cancellations.remove(request_id);
154 }
155 }
156}
157
158pub struct TransportPeer<T: Transport> {
160 transport: Arc<T>,
161}
162
163impl<T: Transport> TransportPeer<T> {
164 pub const fn new(transport: Arc<T>) -> Self {
166 Self { transport }
167 }
168}
169
170impl<T: Transport + 'static> Peer for TransportPeer<T>
171where
172 T::Error: Into<McpError>,
173{
174 fn notify(
175 &self,
176 notification: Notification,
177 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), McpError>> + Send + '_>>
178 {
179 let transport = self.transport.clone();
180 Box::pin(async move {
181 transport
182 .send(Message::Notification(notification))
183 .await
184 .map_err(std::convert::Into::into)
185 })
186 }
187}
188
/// Tunables for a `ServerRuntime`.
///
/// NOTE(review): the runtime stores this value but does not read it; its
/// `config` field is marked `dead_code`. Neither knob currently changes
/// behaviour.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Intended auto-initialization switch (semantics unconfirmed).
    pub auto_initialized: bool,
    /// Intended cap on in-flight requests (not yet enforced).
    pub max_concurrent_requests: usize,
}

impl Default for RuntimeConfig {
    /// Returns `auto_initialized = true` and `max_concurrent_requests = 100`.
    fn default() -> Self {
        let max_concurrent_requests = 100;
        Self {
            auto_initialized: true,
            max_concurrent_requests,
        }
    }
}
206
207pub struct ServerRuntime<S, Tr>
212where
213 Tr: Transport,
214{
215 server: S,
216 transport: Arc<Tr>,
217 state: Arc<ServerState>,
218 #[allow(dead_code)]
220 config: RuntimeConfig,
221}
222
223impl<S, Tr> ServerRuntime<S, Tr>
224where
225 S: RequestRouter + Send + Sync,
226 Tr: Transport + 'static,
227 Tr::Error: Into<McpError>,
228{
229 pub const fn state(&self) -> &Arc<ServerState> {
231 &self.state
232 }
233
234 pub async fn run(&self) -> Result<(), McpError> {
238 loop {
239 match self.transport.recv().await {
240 Ok(Some(message)) => {
241 if let Err(e) = self.handle_message(message).await {
242 tracing::error!(error = %e, "Error handling message");
243 }
244 }
245 Ok(None) => {
246 tracing::info!("Connection closed");
248 break;
249 }
250 Err(e) => {
251 let err: McpError = e.into();
252 tracing::error!(error = %err, "Transport error");
253 return Err(err);
254 }
255 }
256 }
257
258 Ok(())
259 }
260
261 async fn handle_message(&self, message: Message) -> Result<(), McpError> {
263 match message {
264 Message::Request(request) => self.handle_request(request).await,
265 Message::Notification(notification) => self.handle_notification(notification).await,
266 Message::Response(_) => {
267 tracing::warn!("Received unexpected response message");
269 Ok(())
270 }
271 }
272 }
273
274 async fn handle_request(&self, request: Request) -> Result<(), McpError> {
276 let method = request.method.to_string();
277 let id = request.id.clone();
278
279 tracing::debug!(method = %method, id = %id, "Handling request");
280
281 let response = match method.as_str() {
282 "initialize" => self.handle_initialize(&request).await,
283 _ if !self.state.is_initialized() => {
284 Err(McpError::invalid_request("Server not initialized"))
285 }
286 _ => self.route_request(&request).await,
287 };
288
289 let response_msg = match response {
291 Ok(result) => Response::success(id, result),
292 Err(e) => Response::error(id, e.into()),
293 };
294
295 self.transport
296 .send(Message::Response(response_msg))
297 .await
298 .map_err(std::convert::Into::into)
299 }
300
301 async fn handle_initialize(&self, request: &Request) -> Result<serde_json::Value, McpError> {
308 if self.state.is_initialized() {
309 return Err(McpError::invalid_request("Already initialized"));
310 }
311
312 let params = request
314 .params
315 .as_ref()
316 .ok_or_else(|| McpError::invalid_params("initialize", "missing params"))?;
317
318 let requested_version_str = params
320 .get("protocolVersion")
321 .and_then(|v| v.as_str())
322 .unwrap_or("");
323
324 let negotiated_version =
326 ProtocolVersion::negotiate(requested_version_str, ProtocolVersion::ALL)
327 .unwrap_or(ProtocolVersion::LATEST);
328
329 if requested_version_str == negotiated_version.as_str() {
331 tracing::debug!(
332 version = %negotiated_version,
333 "Protocol version negotiated successfully"
334 );
335 } else {
336 tracing::info!(
337 requested = %requested_version_str,
338 negotiated = %negotiated_version,
339 supported = ?ProtocolVersion::ALL.iter().map(ProtocolVersion::as_str).collect::<Vec<_>>(),
340 "Protocol version negotiation: client requested different version"
341 );
342 }
343
344 self.state.set_protocol_version(negotiated_version);
346
347 if let Some(caps) = params.get("capabilities") {
349 if let Ok(client_caps) = serde_json::from_value::<ClientCapabilities>(caps.clone()) {
350 self.state.set_client_caps(client_caps);
351 }
352 }
353
354 let result = serde_json::json!({
356 "protocolVersion": negotiated_version.as_str(),
357 "serverInfo": self.server.server_info(),
358 "capabilities": self.state.server_caps
359 });
360
361 self.state.set_initialized();
362
363 Ok(result)
364 }
365
366 async fn route_request(&self, request: &Request) -> Result<serde_json::Value, McpError> {
368 let method = request.method.as_ref();
369 let params = request.params.as_ref();
370
371 let progress_token = extract_progress_token(params);
373
374 let peer = TransportPeer::new(self.transport.clone());
376 let client_caps = self.state.client_caps();
377 let protocol_version = self
378 .state
379 .protocol_version()
380 .unwrap_or(ProtocolVersion::LATEST);
381 let ctx = Context::new(
382 &request.id,
383 progress_token.as_ref(),
384 &client_caps,
385 &self.state.server_caps,
386 protocol_version,
387 &peer,
388 );
389
390 self.server.route(method, params, &ctx).await
392 }
393
394 async fn handle_notification(&self, notification: Notification) -> Result<(), McpError> {
396 let method = notification.method.as_ref();
397
398 tracing::debug!(method = %method, "Handling notification");
399
400 match method {
401 "notifications/initialized" => {
402 tracing::info!("Client sent initialized notification");
403 Ok(())
404 }
405 "notifications/cancelled" => {
406 if let Some(params) = ¬ification.params {
407 if let Some(request_id) = params.get("requestId").and_then(|v| v.as_str()) {
408 self.state.cancel_request(request_id);
409 }
410 }
411 Ok(())
412 }
413 _ => {
414 tracing::debug!(method = %method, "Ignoring unknown notification");
415 Ok(())
416 }
417 }
418 }
419}
420
421impl<H, T, R, P, K, Tr> ServerRuntime<Server<H, T, R, P, K>, Tr>
423where
424 H: ServerHandler + Send + Sync,
425 T: Send + Sync,
426 R: Send + Sync,
427 P: Send + Sync,
428 K: Send + Sync,
429 Tr: Transport + 'static,
430 Tr::Error: Into<McpError>,
431{
432 pub fn new(server: Server<H, T, R, P, K>, transport: Tr) -> Self {
434 let caps = server.capabilities().clone();
435 Self {
436 server,
437 transport: Arc::new(transport),
438 state: Arc::new(ServerState::new(caps)),
439 config: RuntimeConfig::default(),
440 }
441 }
442
443 pub fn with_config(
445 server: Server<H, T, R, P, K>,
446 transport: Tr,
447 config: RuntimeConfig,
448 ) -> Self {
449 let caps = server.capabilities().clone();
450 Self {
451 server,
452 transport: Arc::new(transport),
453 state: Arc::new(ServerState::new(caps)),
454 config,
455 }
456 }
457}
458
459#[allow(async_fn_in_trait)]
464pub trait RequestRouter: Send + Sync {
465 fn server_info(&self) -> mcpkit_core::capability::ServerInfo;
467
468 async fn route(
470 &self,
471 method: &str,
472 params: Option<&serde_json::Value>,
473 ctx: &Context<'_>,
474 ) -> Result<serde_json::Value, McpError>;
475}
476
477impl<H, T, R, P, K> Server<H, T, R, P, K>
479where
480 H: ServerHandler + Send + Sync + 'static,
481 T: Send + Sync + 'static,
482 R: Send + Sync + 'static,
483 P: Send + Sync + 'static,
484 K: Send + Sync + 'static,
485 Self: RequestRouter,
486{
487 pub async fn serve<Tr>(self, transport: Tr) -> Result<(), McpError>
489 where
490 Tr: Transport + 'static,
491 Tr::Error: Into<McpError>,
492 {
493 let runtime = ServerRuntime::new(self, transport);
494 runtime.run().await
495 }
496}
497
498async fn route_tools<TH: ToolHandler + Send + Sync>(
506 handler: &TH,
507 method: &str,
508 params: Option<&serde_json::Value>,
509 ctx: &Context<'_>,
510) -> Option<Result<serde_json::Value, McpError>> {
511 match method {
512 "tools/list" => {
513 tracing::debug!("Listing available tools");
514 let result = handler.list_tools(ctx).await;
515 match &result {
516 Ok(tools) => tracing::debug!(count = tools.len(), "Listed tools"),
517 Err(e) => tracing::warn!(error = %e, "Failed to list tools"),
518 }
519 Some(result.map(|tools| serde_json::json!({ "tools": tools })))
520 }
521 "tools/call" => {
522 let result = async {
523 let params = params.ok_or_else(|| {
524 McpError::invalid_params("tools/call", "missing params")
525 })?;
526 let name = params.get("name")
527 .and_then(|v| v.as_str())
528 .ok_or_else(|| McpError::invalid_params("tools/call", "missing tool name"))?;
529 let args = params.get("arguments")
530 .cloned()
531 .unwrap_or_else(|| serde_json::json!({}));
532
533 tracing::info!(tool = %name, "Calling tool");
534 let start = std::time::Instant::now();
535 let output = handler.call_tool(name, args, ctx).await;
536 let duration = start.elapsed();
537
538 match &output {
539 Ok(_) => tracing::info!(tool = %name, duration_ms = duration.as_millis(), "Tool call completed"),
540 Err(e) => tracing::warn!(tool = %name, duration_ms = duration.as_millis(), error = %e, "Tool call failed"),
541 }
542
543 let output = output?;
544 let result: CallToolResult = output.into();
545 Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
546 }.await;
547 Some(result)
548 }
549 _ => None,
550 }
551}
552
553async fn route_resources<RH: ResourceHandler + Send + Sync>(
554 handler: &RH,
555 method: &str,
556 params: Option<&serde_json::Value>,
557 ctx: &Context<'_>,
558) -> Option<Result<serde_json::Value, McpError>> {
559 match method {
560 "resources/list" => {
561 tracing::debug!("Listing available resources");
562 let result = handler.list_resources(ctx).await;
563 match &result {
564 Ok(resources) => tracing::debug!(count = resources.len(), "Listed resources"),
565 Err(e) => tracing::warn!(error = %e, "Failed to list resources"),
566 }
567 Some(result.map(|resources| serde_json::json!({ "resources": resources })))
568 }
569 "resources/read" => {
570 let result = async {
571 let params = params.ok_or_else(|| {
572 McpError::invalid_params("resources/read", "missing params")
573 })?;
574 let uri = params.get("uri")
575 .and_then(|v| v.as_str())
576 .ok_or_else(|| McpError::invalid_params("resources/read", "missing uri"))?;
577
578 tracing::info!(uri = %uri, "Reading resource");
579 let start = std::time::Instant::now();
580 let contents = handler.read_resource(uri, ctx).await;
581 let duration = start.elapsed();
582
583 match &contents {
584 Ok(_) => tracing::info!(uri = %uri, duration_ms = duration.as_millis(), "Resource read completed"),
585 Err(e) => tracing::warn!(uri = %uri, duration_ms = duration.as_millis(), error = %e, "Resource read failed"),
586 }
587
588 let contents = contents?;
589 Ok(serde_json::json!({ "contents": contents }))
590 }.await;
591 Some(result)
592 }
593 _ => None,
594 }
595}
596
597async fn route_prompts<PH: PromptHandler + Send + Sync>(
598 handler: &PH,
599 method: &str,
600 params: Option<&serde_json::Value>,
601 ctx: &Context<'_>,
602) -> Option<Result<serde_json::Value, McpError>> {
603 match method {
604 "prompts/list" => {
605 tracing::debug!("Listing available prompts");
606 let result = handler.list_prompts(ctx).await;
607 match &result {
608 Ok(prompts) => tracing::debug!(count = prompts.len(), "Listed prompts"),
609 Err(e) => tracing::warn!(error = %e, "Failed to list prompts"),
610 }
611 Some(result.map(|prompts| serde_json::json!({ "prompts": prompts })))
612 }
613 "prompts/get" => {
614 let result = async {
615 let params = params.ok_or_else(|| {
616 McpError::invalid_params("prompts/get", "missing params")
617 })?;
618 let name = params.get("name")
619 .and_then(|v| v.as_str())
620 .ok_or_else(|| McpError::invalid_params("prompts/get", "missing prompt name"))?;
621 let args = params.get("arguments")
622 .and_then(|v| v.as_object())
623 .cloned();
624
625 tracing::info!(prompt = %name, "Getting prompt");
626 let start = std::time::Instant::now();
627 let prompt_result = handler.get_prompt(name, args, ctx).await;
628 let duration = start.elapsed();
629
630 match &prompt_result {
631 Ok(_) => tracing::info!(prompt = %name, duration_ms = duration.as_millis(), "Prompt retrieval completed"),
632 Err(e) => tracing::warn!(prompt = %name, duration_ms = duration.as_millis(), error = %e, "Prompt retrieval failed"),
633 }
634
635 let result = prompt_result?;
636 Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
637 }.await;
638 Some(result)
639 }
640 _ => None,
641 }
642}
643
// Generates one `RequestRouter` impl for `Server<H, T, R, P, K>`, for a single
// combination of registered handler groups.
//
// Each arm label names the handler slots that are `Registered`:
// - tools: 2nd type parameter;
// - resources: 3rd type parameter;
// - prompts: 4th type parameter.
// The 5th slot is `NotRegistered` in every arm, so no impl is generated for
// servers where that slot is registered.
// NOTE(review): confirm that this omission is intentional.
//
// The trailing `$($bounds:tt)*` tokens are spliced into the impl's generic
// parameter list. Every invocation in this file passes nothing there.
//
// Every generated `route` does the following:
// 1. answers `ping` with `{}`;
// 2. tries the registered groups in the order tools, resources, prompts;
// 3. falls back to `method_not_found`.
macro_rules! impl_request_router {
    // No handler groups registered: only `ping` is routable.
    (base; $($bounds:tt)*) => {
        impl<H $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
        {
            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
                self.handler().server_info()
            }

            async fn route(
                &self,
                method: &str,
                _params: Option<&serde_json::Value>,
                _ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                match method {
                    "ping" => Ok(serde_json::json!({})),
                    _ => Err(McpError::method_not_found(method)),
                }
            }
        }
    };

    // Tools only.
    (tools; $($bounds:tt)*) => {
        impl<H, TH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
        {
            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
                self.handler().server_info()
            }

            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Resources only.
    (resources; $($bounds:tt)*) => {
        impl<H, RH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
        {
            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
                self.handler().server_info()
            }

            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Prompts only.
    (prompts; $($bounds:tt)*) => {
        impl<H, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
                self.handler().server_info()
            }

            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tools + resources: tools are tried first.
    (tools_resources; $($bounds:tt)*) => {
        impl<H, TH, RH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
        {
            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
                self.handler().server_info()
            }

            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tools + prompts: tools are tried first.
    (tools_prompts; $($bounds:tt)*) => {
        impl<H, TH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
                self.handler().server_info()
            }

            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Resources + prompts: resources are tried first.
    (resources_prompts; $($bounds:tt)*) => {
        impl<H, RH, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
                self.handler().server_info()
            }

            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tools + resources + prompts, tried in that order.
    (tools_resources_prompts; $($bounds:tt)*) => {
        impl<H, TH, RH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
                self.handler().server_info()
            }

            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };
}
889
890impl_request_router!(base;);
892impl_request_router!(tools;);
893impl_request_router!(resources;);
894impl_request_router!(prompts;);
895impl_request_router!(tools_resources;);
896impl_request_router!(tools_prompts;);
897impl_request_router!(resources_prompts;);
898impl_request_router!(tools_resources_prompts;);
899
900fn extract_progress_token(params: Option<&serde_json::Value>) -> Option<ProgressToken> {
921 params?
922 .get("_meta")?
923 .get("progressToken")
924 .and_then(|v| serde_json::from_value(v.clone()).ok())
925}
926
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_state_initialization() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(!state.is_initialized());

        state.set_initialized();
        assert!(state.is_initialized());
    }

    #[test]
    fn test_cancellation_management() {
        let state = ServerState::new(ServerCapabilities::default());
        let token = CancellationToken::new();

        state.register_cancellation("req-1", token.clone());
        assert!(!token.is_cancelled());

        state.cancel_request("req-1");
        assert!(token.is_cancelled());

        state.remove_cancellation("req-1");
        // Removal must actually drop the entry, not merely leave it cancelled.
        assert!(!state.cancellations.read().unwrap().contains_key("req-1"));
    }

    #[test]
    fn test_cancel_unknown_request_is_noop() {
        let state = ServerState::new(ServerCapabilities::default());
        let token = CancellationToken::new();

        state.register_cancellation("req-1", token.clone());
        state.cancel_request("req-2");
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_protocol_version_round_trip() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(state.protocol_version().is_none());

        state.set_protocol_version(ProtocolVersion::LATEST);
        let stored = state.protocol_version().expect("version was just set");
        assert_eq!(stored.as_str(), ProtocolVersion::LATEST.as_str());
    }

    #[test]
    fn test_runtime_config_default() {
        let config = RuntimeConfig::default();
        assert!(config.auto_initialized);
        assert_eq!(config.max_concurrent_requests, 100);
    }

    #[test]
    fn test_extract_progress_token_string() {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": "my-token-123"
            },
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_some());
        assert_eq!(
            token.unwrap(),
            ProgressToken::String("my-token-123".to_string())
        );
    }

    #[test]
    fn test_extract_progress_token_number() {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": 42
            },
            "arguments": {}
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_some());
        assert_eq!(token.unwrap(), ProgressToken::Number(42));
    }

    #[test]
    fn test_extract_progress_token_missing_meta() {
        let params = serde_json::json!({
            "name": "test-tool",
            "arguments": {}
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_missing_token() {
        let params = serde_json::json!({
            "_meta": {},
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_none_params() {
        let token = extract_progress_token(None);
        assert!(token.is_none());
    }
}