1use crate::builder::{NotRegistered, Registered, Server};
36use crate::context::{CancellationToken, Context, Peer};
37use crate::handler::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
38use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
39use mcpkit_core::error::McpError;
40use mcpkit_core::protocol::{Message, Notification, ProgressToken, Request, Response};
41use mcpkit_core::protocol_version::ProtocolVersion;
42use mcpkit_core::types::CallToolResult;
43use mcpkit_transport::Transport;
44use std::collections::HashMap;
45use std::sync::atomic::{AtomicBool, Ordering};
46use std::sync::Arc;
47use std::sync::RwLock;
48
49pub struct ServerState {
51 pub client_caps: RwLock<ClientCapabilities>,
53 pub server_caps: ServerCapabilities,
55 pub initialized: AtomicBool,
57 pub cancellations: RwLock<HashMap<String, CancellationToken>>,
59 pub negotiated_version: RwLock<Option<ProtocolVersion>>,
64}
65
66impl ServerState {
67 #[must_use]
69 pub fn new(server_caps: ServerCapabilities) -> Self {
70 Self {
71 client_caps: RwLock::new(ClientCapabilities::default()),
72 server_caps,
73 initialized: AtomicBool::new(false),
74 cancellations: RwLock::new(HashMap::new()),
75 negotiated_version: RwLock::new(None),
76 }
77 }
78
79 pub fn protocol_version(&self) -> Option<ProtocolVersion> {
93 self.negotiated_version.read().ok().and_then(|guard| *guard)
94 }
95
96 pub fn set_protocol_version(&self, version: ProtocolVersion) {
100 if let Ok(mut guard) = self.negotiated_version.write() {
101 *guard = Some(version);
102 }
103 }
104
105 pub fn client_caps(&self) -> ClientCapabilities {
109 self.client_caps
110 .read()
111 .map(|guard| guard.clone())
112 .unwrap_or_default()
113 }
114
115 pub fn set_client_caps(&self, caps: ClientCapabilities) {
119 if let Ok(mut guard) = self.client_caps.write() {
120 *guard = caps;
121 }
122 }
123
124 pub fn is_initialized(&self) -> bool {
126 self.initialized.load(Ordering::Acquire)
127 }
128
129 pub fn set_initialized(&self) {
131 self.initialized.store(true, Ordering::Release);
132 }
133
134 pub fn register_cancellation(&self, request_id: &str, token: CancellationToken) {
136 if let Ok(mut cancellations) = self.cancellations.write() {
137 cancellations.insert(request_id.to_string(), token);
138 }
139 }
140
141 pub fn cancel_request(&self, request_id: &str) {
143 if let Ok(cancellations) = self.cancellations.read() {
144 if let Some(token) = cancellations.get(request_id) {
145 token.cancel();
146 }
147 }
148 }
149
150 pub fn remove_cancellation(&self, request_id: &str) {
152 if let Ok(mut cancellations) = self.cancellations.write() {
153 cancellations.remove(request_id);
154 }
155 }
156}
157
158pub struct TransportPeer<T: Transport> {
160 transport: Arc<T>,
161}
162
163impl<T: Transport> TransportPeer<T> {
164 pub const fn new(transport: Arc<T>) -> Self {
166 Self { transport }
167 }
168}
169
170impl<T: Transport + 'static> Peer for TransportPeer<T>
171where
172 T::Error: Into<McpError>,
173{
174 fn notify(
175 &self,
176 notification: Notification,
177 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), McpError>> + Send + '_>>
178 {
179 let transport = self.transport.clone();
180 Box::pin(async move {
181 transport
182 .send(Message::Notification(notification))
183 .await
184 .map_err(std::convert::Into::into)
185 })
186 }
187}
188
/// Tunables for a `ServerRuntime`.
///
/// NOTE(review): `ServerRuntime` stores this value, but none of the methods in
/// the runtime module read these fields yet.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Auto-initialization toggle; enabled by default.
    pub auto_initialized: bool,
    /// Upper bound on concurrently handled requests; 100 by default.
    pub max_concurrent_requests: usize,
}

impl Default for RuntimeConfig {
    fn default() -> Self {
        const DEFAULT_MAX_CONCURRENT_REQUESTS: usize = 100;
        RuntimeConfig {
            max_concurrent_requests: DEFAULT_MAX_CONCURRENT_REQUESTS,
            auto_initialized: true,
        }
    }
}
206
207pub struct ServerRuntime<S, Tr>
212where
213 Tr: Transport,
214{
215 server: S,
216 transport: Arc<Tr>,
217 state: Arc<ServerState>,
218 #[allow(dead_code)]
220 config: RuntimeConfig,
221}
222
223impl<S, Tr> ServerRuntime<S, Tr>
224where
225 S: RequestRouter + Send + Sync,
226 Tr: Transport + 'static,
227 Tr::Error: Into<McpError>,
228{
229 pub const fn state(&self) -> &Arc<ServerState> {
231 &self.state
232 }
233
234 pub async fn run(&self) -> Result<(), McpError> {
238 loop {
239 match self.transport.recv().await {
240 Ok(Some(message)) => {
241 if let Err(e) = self.handle_message(message).await {
242 tracing::error!(error = %e, "Error handling message");
243 }
244 }
245 Ok(None) => {
246 tracing::info!("Connection closed");
248 break;
249 }
250 Err(e) => {
251 let err: McpError = e.into();
252 tracing::error!(error = %err, "Transport error");
253 return Err(err);
254 }
255 }
256 }
257
258 Ok(())
259 }
260
261 async fn handle_message(&self, message: Message) -> Result<(), McpError> {
263 match message {
264 Message::Request(request) => self.handle_request(request).await,
265 Message::Notification(notification) => self.handle_notification(notification).await,
266 Message::Response(_) => {
267 tracing::warn!("Received unexpected response message");
269 Ok(())
270 }
271 }
272 }
273
274 async fn handle_request(&self, request: Request) -> Result<(), McpError> {
276 let method = request.method.to_string();
277 let id = request.id.clone();
278
279 tracing::debug!(method = %method, id = %id, "Handling request");
280
281 let response = match method.as_str() {
282 "initialize" => self.handle_initialize(&request).await,
283 _ if !self.state.is_initialized() => {
284 Err(McpError::invalid_request("Server not initialized"))
285 }
286 _ => self.route_request(&request).await,
287 };
288
289 let response_msg = match response {
291 Ok(result) => Response::success(id, result),
292 Err(e) => Response::error(id, e.into()),
293 };
294
295 self.transport
296 .send(Message::Response(response_msg))
297 .await
298 .map_err(std::convert::Into::into)
299 }
300
301 async fn handle_initialize(&self, request: &Request) -> Result<serde_json::Value, McpError> {
308 if self.state.is_initialized() {
309 return Err(McpError::invalid_request("Already initialized"));
310 }
311
312 let params = request
314 .params
315 .as_ref()
316 .ok_or_else(|| McpError::invalid_params("initialize", "missing params"))?;
317
318 let requested_version_str = params
320 .get("protocolVersion")
321 .and_then(|v| v.as_str())
322 .unwrap_or("");
323
324 let negotiated_version =
326 ProtocolVersion::negotiate(requested_version_str, ProtocolVersion::ALL)
327 .unwrap_or(ProtocolVersion::LATEST);
328
329 if requested_version_str == negotiated_version.as_str() {
331 tracing::debug!(
332 version = %negotiated_version,
333 "Protocol version negotiated successfully"
334 );
335 } else {
336 tracing::info!(
337 requested = %requested_version_str,
338 negotiated = %negotiated_version,
339 supported = ?ProtocolVersion::ALL.iter().map(ProtocolVersion::as_str).collect::<Vec<_>>(),
340 "Protocol version negotiation: client requested different version"
341 );
342 }
343
344 self.state.set_protocol_version(negotiated_version);
346
347 if let Some(caps) = params.get("capabilities") {
349 if let Ok(client_caps) = serde_json::from_value::<ClientCapabilities>(caps.clone()) {
350 self.state.set_client_caps(client_caps);
351 }
352 }
353
354 let result = serde_json::json!({
356 "protocolVersion": negotiated_version.as_str(),
357 "serverInfo": {
358 "name": "mcp-server",
359 "version": "1.0.0"
360 },
361 "capabilities": self.state.server_caps
362 });
363
364 self.state.set_initialized();
365
366 Ok(result)
367 }
368
369 async fn route_request(&self, request: &Request) -> Result<serde_json::Value, McpError> {
371 let method = request.method.as_ref();
372 let params = request.params.as_ref();
373
374 let progress_token = extract_progress_token(params);
376
377 let peer = TransportPeer::new(self.transport.clone());
379 let client_caps = self.state.client_caps();
380 let protocol_version = self
381 .state
382 .protocol_version()
383 .unwrap_or(ProtocolVersion::LATEST);
384 let ctx = Context::new(
385 &request.id,
386 progress_token.as_ref(),
387 &client_caps,
388 &self.state.server_caps,
389 protocol_version,
390 &peer,
391 );
392
393 self.server.route(method, params, &ctx).await
395 }
396
397 async fn handle_notification(&self, notification: Notification) -> Result<(), McpError> {
399 let method = notification.method.as_ref();
400
401 tracing::debug!(method = %method, "Handling notification");
402
403 match method {
404 "notifications/initialized" => {
405 tracing::info!("Client sent initialized notification");
406 Ok(())
407 }
408 "notifications/cancelled" => {
409 if let Some(params) = ¬ification.params {
410 if let Some(request_id) = params.get("requestId").and_then(|v| v.as_str()) {
411 self.state.cancel_request(request_id);
412 }
413 }
414 Ok(())
415 }
416 _ => {
417 tracing::debug!(method = %method, "Ignoring unknown notification");
418 Ok(())
419 }
420 }
421 }
422}
423
424impl<H, T, R, P, K, Tr> ServerRuntime<Server<H, T, R, P, K>, Tr>
426where
427 H: ServerHandler + Send + Sync,
428 T: Send + Sync,
429 R: Send + Sync,
430 P: Send + Sync,
431 K: Send + Sync,
432 Tr: Transport + 'static,
433 Tr::Error: Into<McpError>,
434{
435 pub fn new(server: Server<H, T, R, P, K>, transport: Tr) -> Self {
437 let caps = server.capabilities().clone();
438 Self {
439 server,
440 transport: Arc::new(transport),
441 state: Arc::new(ServerState::new(caps)),
442 config: RuntimeConfig::default(),
443 }
444 }
445
446 pub fn with_config(
448 server: Server<H, T, R, P, K>,
449 transport: Tr,
450 config: RuntimeConfig,
451 ) -> Self {
452 let caps = server.capabilities().clone();
453 Self {
454 server,
455 transport: Arc::new(transport),
456 state: Arc::new(ServerState::new(caps)),
457 config,
458 }
459 }
460}
461
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because
// callers cannot add a `Send` bound to the returned future; allowed deliberately.
#[allow(async_fn_in_trait)]
/// Routes an already-initialized request to the handler that serves `method`.
///
/// Implementations return the JSON result for the request or an [`McpError`]
/// (the runtime turns either into a response message).
pub trait RequestRouter: Send + Sync {
    /// Handles `method` with optional `params` in the given request context.
    async fn route(
        &self,
        method: &str,
        params: Option<&serde_json::Value>,
        ctx: &Context<'_>,
    ) -> Result<serde_json::Value, McpError>;
}
476
477impl<H, T, R, P, K> Server<H, T, R, P, K>
479where
480 H: ServerHandler + Send + Sync + 'static,
481 T: Send + Sync + 'static,
482 R: Send + Sync + 'static,
483 P: Send + Sync + 'static,
484 K: Send + Sync + 'static,
485 Self: RequestRouter,
486{
487 pub async fn serve<Tr>(self, transport: Tr) -> Result<(), McpError>
489 where
490 Tr: Transport + 'static,
491 Tr::Error: Into<McpError>,
492 {
493 let runtime = ServerRuntime::new(self, transport);
494 runtime.run().await
495 }
496}
497
498async fn route_tools<TH: ToolHandler + Send + Sync>(
506 handler: &TH,
507 method: &str,
508 params: Option<&serde_json::Value>,
509 ctx: &Context<'_>,
510) -> Option<Result<serde_json::Value, McpError>> {
511 match method {
512 "tools/list" => {
513 tracing::debug!("Listing available tools");
514 let result = handler.list_tools(ctx).await;
515 match &result {
516 Ok(tools) => tracing::debug!(count = tools.len(), "Listed tools"),
517 Err(e) => tracing::warn!(error = %e, "Failed to list tools"),
518 }
519 Some(result.map(|tools| serde_json::json!({ "tools": tools })))
520 }
521 "tools/call" => {
522 let result = async {
523 let params = params.ok_or_else(|| {
524 McpError::invalid_params("tools/call", "missing params")
525 })?;
526 let name = params.get("name")
527 .and_then(|v| v.as_str())
528 .ok_or_else(|| McpError::invalid_params("tools/call", "missing tool name"))?;
529 let args = params.get("arguments")
530 .cloned()
531 .unwrap_or_else(|| serde_json::json!({}));
532
533 tracing::info!(tool = %name, "Calling tool");
534 let start = std::time::Instant::now();
535 let output = handler.call_tool(name, args, ctx).await;
536 let duration = start.elapsed();
537
538 match &output {
539 Ok(_) => tracing::info!(tool = %name, duration_ms = duration.as_millis(), "Tool call completed"),
540 Err(e) => tracing::warn!(tool = %name, duration_ms = duration.as_millis(), error = %e, "Tool call failed"),
541 }
542
543 let output = output?;
544 let result: CallToolResult = output.into();
545 Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
546 }.await;
547 Some(result)
548 }
549 _ => None,
550 }
551}
552
553async fn route_resources<RH: ResourceHandler + Send + Sync>(
554 handler: &RH,
555 method: &str,
556 params: Option<&serde_json::Value>,
557 ctx: &Context<'_>,
558) -> Option<Result<serde_json::Value, McpError>> {
559 match method {
560 "resources/list" => {
561 tracing::debug!("Listing available resources");
562 let result = handler.list_resources(ctx).await;
563 match &result {
564 Ok(resources) => tracing::debug!(count = resources.len(), "Listed resources"),
565 Err(e) => tracing::warn!(error = %e, "Failed to list resources"),
566 }
567 Some(result.map(|resources| serde_json::json!({ "resources": resources })))
568 }
569 "resources/read" => {
570 let result = async {
571 let params = params.ok_or_else(|| {
572 McpError::invalid_params("resources/read", "missing params")
573 })?;
574 let uri = params.get("uri")
575 .and_then(|v| v.as_str())
576 .ok_or_else(|| McpError::invalid_params("resources/read", "missing uri"))?;
577
578 tracing::info!(uri = %uri, "Reading resource");
579 let start = std::time::Instant::now();
580 let contents = handler.read_resource(uri, ctx).await;
581 let duration = start.elapsed();
582
583 match &contents {
584 Ok(_) => tracing::info!(uri = %uri, duration_ms = duration.as_millis(), "Resource read completed"),
585 Err(e) => tracing::warn!(uri = %uri, duration_ms = duration.as_millis(), error = %e, "Resource read failed"),
586 }
587
588 let contents = contents?;
589 Ok(serde_json::json!({ "contents": contents }))
590 }.await;
591 Some(result)
592 }
593 _ => None,
594 }
595}
596
597async fn route_prompts<PH: PromptHandler + Send + Sync>(
598 handler: &PH,
599 method: &str,
600 params: Option<&serde_json::Value>,
601 ctx: &Context<'_>,
602) -> Option<Result<serde_json::Value, McpError>> {
603 match method {
604 "prompts/list" => {
605 tracing::debug!("Listing available prompts");
606 let result = handler.list_prompts(ctx).await;
607 match &result {
608 Ok(prompts) => tracing::debug!(count = prompts.len(), "Listed prompts"),
609 Err(e) => tracing::warn!(error = %e, "Failed to list prompts"),
610 }
611 Some(result.map(|prompts| serde_json::json!({ "prompts": prompts })))
612 }
613 "prompts/get" => {
614 let result = async {
615 let params = params.ok_or_else(|| {
616 McpError::invalid_params("prompts/get", "missing params")
617 })?;
618 let name = params.get("name")
619 .and_then(|v| v.as_str())
620 .ok_or_else(|| McpError::invalid_params("prompts/get", "missing prompt name"))?;
621 let args = params.get("arguments")
622 .and_then(|v| v.as_object())
623 .cloned();
624
625 tracing::info!(prompt = %name, "Getting prompt");
626 let start = std::time::Instant::now();
627 let prompt_result = handler.get_prompt(name, args, ctx).await;
628 let duration = start.elapsed();
629
630 match &prompt_result {
631 Ok(_) => tracing::info!(prompt = %name, duration_ms = duration.as_millis(), "Prompt retrieval completed"),
632 Err(e) => tracing::warn!(prompt = %name, duration_ms = duration.as_millis(), error = %e, "Prompt retrieval failed"),
633 }
634
635 let result = prompt_result?;
636 Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
637 }.await;
638 Some(result)
639 }
640 _ => None,
641 }
642}
643
/// Generates a `RequestRouter` impl for one combination of registered handlers.
///
/// Each arm targets a concrete `Server<H, Tools, Resources, Prompts, K>` type
/// pattern, where each slot is either `Registered<_>` or `NotRegistered`.
/// Every arm answers `ping` with `{}`, then tries the matching `route_*`
/// helpers in tools → resources → prompts order. Any method left unclaimed
/// becomes `method_not_found`.
///
/// The `$($bounds:tt)*` tail is spliced into the generic list. Every invocation
/// in this file passes it empty.
///
/// NOTE(review): every arm fixes the fifth type parameter to `NotRegistered`, so
/// no `RequestRouter` impl is generated when that slot is registered.
macro_rules! impl_request_router {
    // No handlers registered: only `ping` is served.
    (base; $($bounds:tt)*) => {
        impl<H $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                _params: Option<&serde_json::Value>,
                _ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                match method {
                    "ping" => Ok(serde_json::json!({})),
                    _ => Err(McpError::method_not_found(method)),
                }
            }
        }
    };

    // Tools only.
    (tools; $($bounds:tt)*) => {
        impl<H, TH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Resources only.
    (resources; $($bounds:tt)*) => {
        impl<H, RH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Prompts only.
    (prompts; $($bounds:tt)*) => {
        impl<H, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tools + resources.
    (tools_resources; $($bounds:tt)*) => {
        impl<H, TH, RH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tools + prompts.
    (tools_prompts; $($bounds:tt)*) => {
        impl<H, TH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Resources + prompts.
    (resources_prompts; $($bounds:tt)*) => {
        impl<H, RH, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tools + resources + prompts.
    (tools_resources_prompts; $($bounds:tt)*) => {
        impl<H, TH, RH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };
}
857
// Instantiate a `RequestRouter` impl for every tools/resources/prompts
// registration combination. All generated impls require the fifth `Server`
// type parameter to be `NotRegistered`.
impl_request_router!(base;);
impl_request_router!(tools;);
impl_request_router!(resources;);
impl_request_router!(prompts;);
impl_request_router!(tools_resources;);
impl_request_router!(tools_prompts;);
impl_request_router!(resources_prompts;);
impl_request_router!(tools_resources_prompts;);
867
868fn extract_progress_token(params: Option<&serde_json::Value>) -> Option<ProgressToken> {
889 params?
890 .get("_meta")?
891 .get("progressToken")
892 .and_then(|v| serde_json::from_value(v.clone()).ok())
893}
894
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_state_initialization() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(!state.is_initialized());

        state.set_initialized();
        assert!(state.is_initialized());
    }

    #[test]
    fn test_cancellation_management() {
        let state = ServerState::new(ServerCapabilities::default());
        let token = CancellationToken::new();

        state.register_cancellation("req-1", token.clone());
        assert!(!token.is_cancelled());

        state.cancel_request("req-1");
        assert!(token.is_cancelled());

        state.remove_cancellation("req-1");

        // Cancelling an id that is no longer (or never was) registered is a no-op.
        state.cancel_request("req-1");
        state.cancel_request("never-registered");

        // Cancelling a removed id must not touch tokens registered under other ids.
        let other = CancellationToken::new();
        state.register_cancellation("req-2", other.clone());
        state.cancel_request("req-1");
        assert!(!other.is_cancelled());
    }

    #[test]
    fn test_protocol_version_roundtrip() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(state.protocol_version().is_none());

        state.set_protocol_version(ProtocolVersion::LATEST);
        let negotiated = state.protocol_version().expect("version was just set");
        assert_eq!(negotiated.as_str(), ProtocolVersion::LATEST.as_str());
    }

    #[test]
    fn test_runtime_config_default() {
        let config = RuntimeConfig::default();
        assert!(config.auto_initialized);
        assert_eq!(config.max_concurrent_requests, 100);
    }

    #[test]
    fn test_extract_progress_token_string() {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": "my-token-123"
            },
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_some());
        assert_eq!(
            token.unwrap(),
            ProgressToken::String("my-token-123".to_string())
        );
    }

    #[test]
    fn test_extract_progress_token_number() {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": 42
            },
            "arguments": {}
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_some());
        assert_eq!(token.unwrap(), ProgressToken::Number(42));
    }

    #[test]
    fn test_extract_progress_token_missing_meta() {
        let params = serde_json::json!({
            "name": "test-tool",
            "arguments": {}
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_missing_token() {
        let params = serde_json::json!({
            "_meta": {},
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_none_params() {
        let token = extract_progress_token(None);
        assert!(token.is_none());
    }
}