1use crate::builder::{NotRegistered, Registered, Server};
36use crate::context::{CancellationToken, Context, Peer};
37use crate::handler::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
38use mcpkit_core::capability::{
39 negotiate_version, ClientCapabilities, ServerCapabilities, SUPPORTED_PROTOCOL_VERSIONS,
40};
41use mcpkit_core::error::McpError;
42use mcpkit_core::protocol::{Message, Notification, ProgressToken, Request, Response};
43use mcpkit_core::types::CallToolResult;
44use mcpkit_transport::Transport;
45use std::collections::HashMap;
46use std::sync::atomic::{AtomicBool, Ordering};
47use std::sync::Arc;
48use std::sync::RwLock;
49
50pub struct ServerState {
52 pub client_caps: RwLock<ClientCapabilities>,
54 pub server_caps: ServerCapabilities,
56 pub initialized: AtomicBool,
58 pub cancellations: RwLock<HashMap<String, CancellationToken>>,
60 pub negotiated_version: RwLock<Option<String>>,
62}
63
64impl ServerState {
65 #[must_use]
67 pub fn new(server_caps: ServerCapabilities) -> Self {
68 Self {
69 client_caps: RwLock::new(ClientCapabilities::default()),
70 server_caps,
71 initialized: AtomicBool::new(false),
72 cancellations: RwLock::new(HashMap::new()),
73 negotiated_version: RwLock::new(None),
74 }
75 }
76
77 pub fn protocol_version(&self) -> Option<String> {
81 self.negotiated_version
82 .read()
83 .ok()
84 .and_then(|guard| guard.clone())
85 }
86
87 pub fn set_protocol_version(&self, version: String) {
91 if let Ok(mut guard) = self.negotiated_version.write() {
92 *guard = Some(version);
93 }
94 }
95
96 pub fn client_caps(&self) -> ClientCapabilities {
100 self.client_caps
101 .read()
102 .map(|guard| guard.clone())
103 .unwrap_or_default()
104 }
105
106 pub fn set_client_caps(&self, caps: ClientCapabilities) {
110 if let Ok(mut guard) = self.client_caps.write() {
111 *guard = caps;
112 }
113 }
114
115 pub fn is_initialized(&self) -> bool {
117 self.initialized.load(Ordering::Acquire)
118 }
119
120 pub fn set_initialized(&self) {
122 self.initialized.store(true, Ordering::Release);
123 }
124
125 pub fn register_cancellation(&self, request_id: &str, token: CancellationToken) {
127 if let Ok(mut cancellations) = self.cancellations.write() {
128 cancellations.insert(request_id.to_string(), token);
129 }
130 }
131
132 pub fn cancel_request(&self, request_id: &str) {
134 if let Ok(cancellations) = self.cancellations.read() {
135 if let Some(token) = cancellations.get(request_id) {
136 token.cancel();
137 }
138 }
139 }
140
141 pub fn remove_cancellation(&self, request_id: &str) {
143 if let Ok(mut cancellations) = self.cancellations.write() {
144 cancellations.remove(request_id);
145 }
146 }
147}
148
149pub struct TransportPeer<T: Transport> {
151 transport: Arc<T>,
152}
153
154impl<T: Transport> TransportPeer<T> {
155 pub const fn new(transport: Arc<T>) -> Self {
157 Self { transport }
158 }
159}
160
161impl<T: Transport + 'static> Peer for TransportPeer<T>
162where
163 T::Error: Into<McpError>,
164{
165 fn notify(
166 &self,
167 notification: Notification,
168 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), McpError>> + Send + '_>>
169 {
170 let transport = self.transport.clone();
171 Box::pin(async move {
172 transport
173 .send(Message::Notification(notification))
174 .await
175 .map_err(std::convert::Into::into)
176 })
177 }
178}
179
/// Tuning knobs for `ServerRuntime`.
///
/// NOTE(review): `ServerRuntime` stores this value, but nothing in this module reads either
/// field yet, so neither setting currently changes runtime behavior.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Presumably whether the server considers itself initialized without waiting for the
    /// client's `notifications/initialized` — TODO confirm intended semantics.
    pub auto_initialized: bool,
    /// Intended upper bound on concurrently handled requests.
    pub max_concurrent_requests: usize,
}

impl Default for RuntimeConfig {
    /// Defaults: `auto_initialized = true`, `max_concurrent_requests = 100`.
    fn default() -> Self {
        RuntimeConfig {
            max_concurrent_requests: 100,
            auto_initialized: true,
        }
    }
}
197
198pub struct ServerRuntime<S, Tr>
203where
204 Tr: Transport,
205{
206 server: S,
207 transport: Arc<Tr>,
208 state: Arc<ServerState>,
209 #[allow(dead_code)]
211 config: RuntimeConfig,
212}
213
214impl<S, Tr> ServerRuntime<S, Tr>
215where
216 S: RequestRouter + Send + Sync,
217 Tr: Transport + 'static,
218 Tr::Error: Into<McpError>,
219{
220 pub const fn state(&self) -> &Arc<ServerState> {
222 &self.state
223 }
224
225 pub async fn run(&self) -> Result<(), McpError> {
229 loop {
230 match self.transport.recv().await {
231 Ok(Some(message)) => {
232 if let Err(e) = self.handle_message(message).await {
233 tracing::error!(error = %e, "Error handling message");
234 }
235 }
236 Ok(None) => {
237 tracing::info!("Connection closed");
239 break;
240 }
241 Err(e) => {
242 let err: McpError = e.into();
243 tracing::error!(error = %err, "Transport error");
244 return Err(err);
245 }
246 }
247 }
248
249 Ok(())
250 }
251
252 async fn handle_message(&self, message: Message) -> Result<(), McpError> {
254 match message {
255 Message::Request(request) => self.handle_request(request).await,
256 Message::Notification(notification) => self.handle_notification(notification).await,
257 Message::Response(_) => {
258 tracing::warn!("Received unexpected response message");
260 Ok(())
261 }
262 }
263 }
264
265 async fn handle_request(&self, request: Request) -> Result<(), McpError> {
267 let method = request.method.to_string();
268 let id = request.id.clone();
269
270 tracing::debug!(method = %method, id = %id, "Handling request");
271
272 let response = match method.as_str() {
273 "initialize" => self.handle_initialize(&request).await,
274 _ if !self.state.is_initialized() => {
275 Err(McpError::invalid_request("Server not initialized"))
276 }
277 _ => self.route_request(&request).await,
278 };
279
280 let response_msg = match response {
282 Ok(result) => Response::success(id, result),
283 Err(e) => Response::error(id, e.into()),
284 };
285
286 self.transport
287 .send(Message::Response(response_msg))
288 .await
289 .map_err(std::convert::Into::into)
290 }
291
292 async fn handle_initialize(&self, request: &Request) -> Result<serde_json::Value, McpError> {
299 if self.state.is_initialized() {
300 return Err(McpError::invalid_request("Already initialized"));
301 }
302
303 let params = request
305 .params
306 .as_ref()
307 .ok_or_else(|| McpError::invalid_params("initialize", "missing params"))?;
308
309 let requested_version = params
311 .get("protocolVersion")
312 .and_then(|v| v.as_str())
313 .unwrap_or("");
314
315 let negotiated_version = negotiate_version(requested_version);
316
317 if requested_version == negotiated_version {
319 tracing::debug!(
320 version = %negotiated_version,
321 "Protocol version negotiated successfully"
322 );
323 } else {
324 tracing::info!(
325 requested = %requested_version,
326 negotiated = %negotiated_version,
327 supported = ?SUPPORTED_PROTOCOL_VERSIONS,
328 "Protocol version negotiation: client requested unsupported version"
329 );
330 }
331
332 self.state
334 .set_protocol_version(negotiated_version.to_string());
335
336 if let Some(caps) = params.get("capabilities") {
338 if let Ok(client_caps) = serde_json::from_value::<ClientCapabilities>(caps.clone()) {
339 self.state.set_client_caps(client_caps);
340 }
341 }
342
343 let result = serde_json::json!({
345 "protocolVersion": negotiated_version,
346 "serverInfo": {
347 "name": "mcp-server",
348 "version": "1.0.0"
349 },
350 "capabilities": self.state.server_caps
351 });
352
353 self.state.set_initialized();
354
355 Ok(result)
356 }
357
358 async fn route_request(&self, request: &Request) -> Result<serde_json::Value, McpError> {
360 let method = request.method.as_ref();
361 let params = request.params.as_ref();
362
363 let progress_token = extract_progress_token(params);
365
366 let peer = TransportPeer::new(self.transport.clone());
368 let client_caps = self.state.client_caps();
369 let ctx = Context::new(
370 &request.id,
371 progress_token.as_ref(),
372 &client_caps,
373 &self.state.server_caps,
374 &peer,
375 );
376
377 self.server.route(method, params, &ctx).await
379 }
380
381 async fn handle_notification(&self, notification: Notification) -> Result<(), McpError> {
383 let method = notification.method.as_ref();
384
385 tracing::debug!(method = %method, "Handling notification");
386
387 match method {
388 "notifications/initialized" => {
389 tracing::info!("Client sent initialized notification");
390 Ok(())
391 }
392 "notifications/cancelled" => {
393 if let Some(params) = ¬ification.params {
394 if let Some(request_id) = params.get("requestId").and_then(|v| v.as_str()) {
395 self.state.cancel_request(request_id);
396 }
397 }
398 Ok(())
399 }
400 _ => {
401 tracing::debug!(method = %method, "Ignoring unknown notification");
402 Ok(())
403 }
404 }
405 }
406}
407
408impl<H, T, R, P, K, Tr> ServerRuntime<Server<H, T, R, P, K>, Tr>
410where
411 H: ServerHandler + Send + Sync,
412 T: Send + Sync,
413 R: Send + Sync,
414 P: Send + Sync,
415 K: Send + Sync,
416 Tr: Transport + 'static,
417 Tr::Error: Into<McpError>,
418{
419 pub fn new(server: Server<H, T, R, P, K>, transport: Tr) -> Self {
421 let caps = server.capabilities().clone();
422 Self {
423 server,
424 transport: Arc::new(transport),
425 state: Arc::new(ServerState::new(caps)),
426 config: RuntimeConfig::default(),
427 }
428 }
429
430 pub fn with_config(
432 server: Server<H, T, R, P, K>,
433 transport: Tr,
434 config: RuntimeConfig,
435 ) -> Self {
436 let caps = server.capabilities().clone();
437 Self {
438 server,
439 transport: Arc::new(transport),
440 state: Arc::new(ServerState::new(caps)),
441 config,
442 }
443 }
444}
445
446#[allow(async_fn_in_trait)]
451pub trait RequestRouter: Send + Sync {
452 async fn route(
454 &self,
455 method: &str,
456 params: Option<&serde_json::Value>,
457 ctx: &Context<'_>,
458 ) -> Result<serde_json::Value, McpError>;
459}
460
461impl<H, T, R, P, K> Server<H, T, R, P, K>
463where
464 H: ServerHandler + Send + Sync + 'static,
465 T: Send + Sync + 'static,
466 R: Send + Sync + 'static,
467 P: Send + Sync + 'static,
468 K: Send + Sync + 'static,
469 Self: RequestRouter,
470{
471 pub async fn serve<Tr>(self, transport: Tr) -> Result<(), McpError>
473 where
474 Tr: Transport + 'static,
475 Tr::Error: Into<McpError>,
476 {
477 let runtime = ServerRuntime::new(self, transport);
478 runtime.run().await
479 }
480}
481
482async fn route_tools<TH: ToolHandler + Send + Sync>(
490 handler: &TH,
491 method: &str,
492 params: Option<&serde_json::Value>,
493 ctx: &Context<'_>,
494) -> Option<Result<serde_json::Value, McpError>> {
495 match method {
496 "tools/list" => {
497 tracing::debug!("Listing available tools");
498 let result = handler.list_tools(ctx).await;
499 match &result {
500 Ok(tools) => tracing::debug!(count = tools.len(), "Listed tools"),
501 Err(e) => tracing::warn!(error = %e, "Failed to list tools"),
502 }
503 Some(result.map(|tools| serde_json::json!({ "tools": tools })))
504 }
505 "tools/call" => {
506 let result = async {
507 let params = params.ok_or_else(|| {
508 McpError::invalid_params("tools/call", "missing params")
509 })?;
510 let name = params.get("name")
511 .and_then(|v| v.as_str())
512 .ok_or_else(|| McpError::invalid_params("tools/call", "missing tool name"))?;
513 let args = params.get("arguments")
514 .cloned()
515 .unwrap_or_else(|| serde_json::json!({}));
516
517 tracing::info!(tool = %name, "Calling tool");
518 let start = std::time::Instant::now();
519 let output = handler.call_tool(name, args, ctx).await;
520 let duration = start.elapsed();
521
522 match &output {
523 Ok(_) => tracing::info!(tool = %name, duration_ms = duration.as_millis(), "Tool call completed"),
524 Err(e) => tracing::warn!(tool = %name, duration_ms = duration.as_millis(), error = %e, "Tool call failed"),
525 }
526
527 let output = output?;
528 let result: CallToolResult = output.into();
529 Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
530 }.await;
531 Some(result)
532 }
533 _ => None,
534 }
535}
536
537async fn route_resources<RH: ResourceHandler + Send + Sync>(
538 handler: &RH,
539 method: &str,
540 params: Option<&serde_json::Value>,
541 ctx: &Context<'_>,
542) -> Option<Result<serde_json::Value, McpError>> {
543 match method {
544 "resources/list" => {
545 tracing::debug!("Listing available resources");
546 let result = handler.list_resources(ctx).await;
547 match &result {
548 Ok(resources) => tracing::debug!(count = resources.len(), "Listed resources"),
549 Err(e) => tracing::warn!(error = %e, "Failed to list resources"),
550 }
551 Some(result.map(|resources| serde_json::json!({ "resources": resources })))
552 }
553 "resources/read" => {
554 let result = async {
555 let params = params.ok_or_else(|| {
556 McpError::invalid_params("resources/read", "missing params")
557 })?;
558 let uri = params.get("uri")
559 .and_then(|v| v.as_str())
560 .ok_or_else(|| McpError::invalid_params("resources/read", "missing uri"))?;
561
562 tracing::info!(uri = %uri, "Reading resource");
563 let start = std::time::Instant::now();
564 let contents = handler.read_resource(uri, ctx).await;
565 let duration = start.elapsed();
566
567 match &contents {
568 Ok(_) => tracing::info!(uri = %uri, duration_ms = duration.as_millis(), "Resource read completed"),
569 Err(e) => tracing::warn!(uri = %uri, duration_ms = duration.as_millis(), error = %e, "Resource read failed"),
570 }
571
572 let contents = contents?;
573 Ok(serde_json::json!({ "contents": contents }))
574 }.await;
575 Some(result)
576 }
577 _ => None,
578 }
579}
580
581async fn route_prompts<PH: PromptHandler + Send + Sync>(
582 handler: &PH,
583 method: &str,
584 params: Option<&serde_json::Value>,
585 ctx: &Context<'_>,
586) -> Option<Result<serde_json::Value, McpError>> {
587 match method {
588 "prompts/list" => {
589 tracing::debug!("Listing available prompts");
590 let result = handler.list_prompts(ctx).await;
591 match &result {
592 Ok(prompts) => tracing::debug!(count = prompts.len(), "Listed prompts"),
593 Err(e) => tracing::warn!(error = %e, "Failed to list prompts"),
594 }
595 Some(result.map(|prompts| serde_json::json!({ "prompts": prompts })))
596 }
597 "prompts/get" => {
598 let result = async {
599 let params = params.ok_or_else(|| {
600 McpError::invalid_params("prompts/get", "missing params")
601 })?;
602 let name = params.get("name")
603 .and_then(|v| v.as_str())
604 .ok_or_else(|| McpError::invalid_params("prompts/get", "missing prompt name"))?;
605 let args = params.get("arguments")
606 .and_then(|v| v.as_object())
607 .cloned();
608
609 tracing::info!(prompt = %name, "Getting prompt");
610 let start = std::time::Instant::now();
611 let prompt_result = handler.get_prompt(name, args, ctx).await;
612 let duration = start.elapsed();
613
614 match &prompt_result {
615 Ok(_) => tracing::info!(prompt = %name, duration_ms = duration.as_millis(), "Prompt retrieval completed"),
616 Err(e) => tracing::warn!(prompt = %name, duration_ms = duration.as_millis(), error = %e, "Prompt retrieval failed"),
617 }
618
619 let result = prompt_result?;
620 Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
621 }.await;
622 Some(result)
623 }
624 _ => None,
625 }
626}
627
/// Generates a [`RequestRouter`] impl for one combination of registered handlers.
///
/// Each arm targets one `Server<H, _, _, _, _>` type-state, where the second through fourth
/// type parameters are `Registered<..>` or `NotRegistered` for tools, resources and prompts.
/// Every generated `route`:
/// - answers `ping` with `{}`;
/// - tries `route_tools`, `route_resources` and `route_prompts`, in that order, for whichever
///   handlers the arm registers;
/// - falls back to `McpError::method_not_found`.
///
/// NOTE(review): every arm fixes the fifth type parameter to `NotRegistered`, so a server with
/// that slot registered gets no `RequestRouter` impl here — confirm that this is intended.
/// NOTE(review): the trailing `$($bounds:tt)*` is spliced into the generic parameter list, but
/// every invocation in this module passes nothing for it.
macro_rules! impl_request_router {
    // No handlers registered: only `ping` is answered.
    (base; $($bounds:tt)*) => {
        impl<H $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                _params: Option<&serde_json::Value>,
                _ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                match method {
                    "ping" => Ok(serde_json::json!({})),
                    _ => Err(McpError::method_not_found(method)),
                }
            }
        }
    };

    // Tool handler only.
    (tools; $($bounds:tt)*) => {
        impl<H, TH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                // `route_tools` returns `None` for methods it does not own.
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Resource handler only.
    (resources; $($bounds:tt)*) => {
        impl<H, RH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Prompt handler only.
    (prompts; $($bounds:tt)*) => {
        impl<H, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tool and resource handlers.
    (tools_resources; $($bounds:tt)*) => {
        impl<H, TH, RH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tool and prompt handlers.
    (tools_prompts; $($bounds:tt)*) => {
        impl<H, TH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Resource and prompt handlers.
    (resources_prompts; $($bounds:tt)*) => {
        impl<H, RH, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tool, resource and prompt handlers.
    (tools_resources_prompts; $($bounds:tt)*) => {
        impl<H, TH, RH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };
}
841
842impl_request_router!(base;);
844impl_request_router!(tools;);
845impl_request_router!(resources;);
846impl_request_router!(prompts;);
847impl_request_router!(tools_resources;);
848impl_request_router!(tools_prompts;);
849impl_request_router!(resources_prompts;);
850impl_request_router!(tools_resources_prompts;);
851
852fn extract_progress_token(params: Option<&serde_json::Value>) -> Option<ProgressToken> {
873 params?
874 .get("_meta")?
875 .get("progressToken")
876 .and_then(|v| serde_json::from_value(v.clone()).ok())
877}
878
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_state_initialization() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(!state.is_initialized(), "new state must start uninitialized");

        state.set_initialized();
        assert!(state.is_initialized(), "set_initialized must mark the state initialized");
    }

    #[test]
    fn test_cancellation_management() {
        let state = ServerState::new(ServerCapabilities::default());
        let token = CancellationToken::new();

        state.register_cancellation("req-1", token.clone());
        assert!(!token.is_cancelled(), "registering must not cancel the token");

        state.cancel_request("req-1");
        assert!(token.is_cancelled(), "cancel_request must trigger the registered token");

        state.remove_cancellation("req-1");
    }

    #[test]
    fn test_runtime_config_default() {
        let RuntimeConfig {
            auto_initialized,
            max_concurrent_requests,
        } = RuntimeConfig::default();
        assert!(auto_initialized);
        assert_eq!(max_concurrent_requests, 100);
    }

    #[test]
    fn test_extract_progress_token_string() {
        let params = serde_json::json!({
            "_meta": { "progressToken": "my-token-123" },
            "name": "test-tool"
        });
        assert_eq!(
            extract_progress_token(Some(&params)),
            Some(ProgressToken::String("my-token-123".to_string()))
        );
    }

    #[test]
    fn test_extract_progress_token_number() {
        let params = serde_json::json!({
            "_meta": { "progressToken": 42 },
            "arguments": {}
        });
        assert_eq!(
            extract_progress_token(Some(&params)),
            Some(ProgressToken::Number(42))
        );
    }

    #[test]
    fn test_extract_progress_token_missing_meta() {
        let params = serde_json::json!({ "name": "test-tool", "arguments": {} });
        assert_eq!(extract_progress_token(Some(&params)), None);
    }

    #[test]
    fn test_extract_progress_token_missing_token() {
        let params = serde_json::json!({ "_meta": {}, "name": "test-tool" });
        assert_eq!(extract_progress_token(Some(&params)), None);
    }

    #[test]
    fn test_extract_progress_token_none_params() {
        assert_eq!(extract_progress_token(None), None);
    }
}