mcpkit_server/
server.rs

1//! Server runtime for MCP servers.
2//!
3//! This module provides the runtime that executes an MCP server over
4//! a transport, handling message routing, request correlation, and
5//! the connection lifecycle.
6//!
7//! # Overview
8//!
9//! The server runtime:
10//! 1. Accepts a transport for communication
11//! 2. Handles the initialize/initialized handshake
12//! 3. Routes incoming requests to the appropriate handlers
13//! 4. Manages the connection lifecycle
14//!
15//! # Example
16//!
17//! ```rust
18//! use mcpkit_server::{ServerBuilder, ServerHandler, ServerState};
19//! use mcpkit_core::capability::{ServerInfo, ServerCapabilities};
20//!
21//! struct MyHandler;
22//! impl ServerHandler for MyHandler {
23//!     fn server_info(&self) -> ServerInfo {
24//!         ServerInfo::new("my-server", "1.0.0")
25//!     }
26//! }
27//!
28//! // Build a server and create server state
29//! let server = ServerBuilder::new(MyHandler).build();
30//! let state = ServerState::new(server.capabilities().clone());
31//!
32//! assert!(!state.is_initialized());
33//! ```
34
35use crate::builder::{NotRegistered, Registered, Server};
36use crate::context::{CancellationToken, Context, Peer};
37use crate::handler::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
38use mcpkit_core::capability::{
39    negotiate_version, ClientCapabilities, ServerCapabilities, SUPPORTED_PROTOCOL_VERSIONS,
40};
41use mcpkit_core::error::McpError;
42use mcpkit_core::protocol::{Message, Notification, ProgressToken, Request, Response};
43use mcpkit_core::types::CallToolResult;
44use mcpkit_transport::Transport;
45use std::collections::HashMap;
46use std::sync::atomic::{AtomicBool, Ordering};
47use std::sync::Arc;
48use std::sync::RwLock;
49
50/// State for a running server.
51pub struct ServerState {
52    /// Client capabilities negotiated during initialization.
53    pub client_caps: RwLock<ClientCapabilities>,
54    /// Server capabilities advertised during initialization.
55    pub server_caps: ServerCapabilities,
56    /// Whether the server has been initialized.
57    pub initialized: AtomicBool,
58    /// Active cancellation tokens by request ID.
59    pub cancellations: RwLock<HashMap<String, CancellationToken>>,
60    /// The protocol version negotiated during initialization.
61    pub negotiated_version: RwLock<Option<String>>,
62}
63
64impl ServerState {
65    /// Create a new server state.
66    pub fn new(server_caps: ServerCapabilities) -> Self {
67        Self {
68            client_caps: RwLock::new(ClientCapabilities::default()),
69            server_caps,
70            initialized: AtomicBool::new(false),
71            cancellations: RwLock::new(HashMap::new()),
72            negotiated_version: RwLock::new(None),
73        }
74    }
75
76    /// Get the negotiated protocol version.
77    ///
78    /// Returns `None` if not yet initialized.
79    pub fn protocol_version(&self) -> Option<String> {
80        self.negotiated_version
81            .read()
82            .ok()
83            .and_then(|guard| guard.clone())
84    }
85
86    /// Set the negotiated protocol version.
87    ///
88    /// Silently fails if the lock is poisoned.
89    pub fn set_protocol_version(&self, version: String) {
90        if let Ok(mut guard) = self.negotiated_version.write() {
91            *guard = Some(version);
92        }
93    }
94
95    /// Get a snapshot of client capabilities.
96    ///
97    /// Returns default capabilities if the lock is poisoned.
98    pub fn client_caps(&self) -> ClientCapabilities {
99        self.client_caps
100            .read()
101            .map(|guard| guard.clone())
102            .unwrap_or_default()
103    }
104
105    /// Update client capabilities.
106    ///
107    /// Silently fails if the lock is poisoned.
108    pub fn set_client_caps(&self, caps: ClientCapabilities) {
109        if let Ok(mut guard) = self.client_caps.write() {
110            *guard = caps;
111        }
112    }
113
114    /// Check if the server is initialized.
115    pub fn is_initialized(&self) -> bool {
116        self.initialized.load(Ordering::Acquire)
117    }
118
119    /// Mark the server as initialized.
120    pub fn set_initialized(&self) {
121        self.initialized.store(true, Ordering::Release);
122    }
123
124    /// Register a cancellation token for a request.
125    pub fn register_cancellation(&self, request_id: &str, token: CancellationToken) {
126        if let Ok(mut cancellations) = self.cancellations.write() {
127            cancellations.insert(request_id.to_string(), token);
128        }
129    }
130
131    /// Cancel a request by ID.
132    pub fn cancel_request(&self, request_id: &str) {
133        if let Ok(cancellations) = self.cancellations.read() {
134            if let Some(token) = cancellations.get(request_id) {
135                token.cancel();
136            }
137        }
138    }
139
140    /// Remove a cancellation token after request completion.
141    pub fn remove_cancellation(&self, request_id: &str) {
142        if let Ok(mut cancellations) = self.cancellations.write() {
143            cancellations.remove(request_id);
144        }
145    }
146}
147
148/// A peer implementation that sends notifications over a transport.
149pub struct TransportPeer<T: Transport> {
150    transport: Arc<T>,
151}
152
153impl<T: Transport> TransportPeer<T> {
154    /// Create a new transport peer.
155    pub fn new(transport: Arc<T>) -> Self {
156        Self { transport }
157    }
158}
159
160impl<T: Transport + 'static> Peer for TransportPeer<T>
161where
162    T::Error: Into<McpError>,
163{
164    fn notify(
165        &self,
166        notification: Notification,
167    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), McpError>> + Send + '_>>
168    {
169        let transport = self.transport.clone();
170        Box::pin(async move {
171            transport
172                .send(Message::Notification(notification))
173                .await
174                .map_err(|e| e.into())
175        })
176    }
177}
178
/// Tunables for the server runtime.
///
/// NOTE(review): `ServerRuntime` stores this value but does not read either
/// field yet (its `config` field carries `#[allow(dead_code)]`).
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Whether to automatically send initialized notification.
    pub auto_initialized: bool,
    /// Maximum concurrent requests to process.
    pub max_concurrent_requests: usize,
}

impl Default for RuntimeConfig {
    /// Defaults: `auto_initialized = true`, `max_concurrent_requests = 100`.
    fn default() -> Self {
        const DEFAULT_MAX_CONCURRENT_REQUESTS: usize = 100;

        Self {
            max_concurrent_requests: DEFAULT_MAX_CONCURRENT_REQUESTS,
            auto_initialized: true,
        }
    }
}
196
197/// Server runtime that handles the message loop.
198///
199/// This runtime manages the connection lifecycle, routes requests to
200/// handlers, and coordinates response delivery.
201pub struct ServerRuntime<S, Tr>
202where
203    Tr: Transport,
204{
205    server: S,
206    transport: Arc<Tr>,
207    state: Arc<ServerState>,
208    /// Runtime configuration (request timeouts, etc.) - will be used by advanced features.
209    #[allow(dead_code)]
210    config: RuntimeConfig,
211}
212
213impl<S, Tr> ServerRuntime<S, Tr>
214where
215    S: RequestRouter + Send + Sync,
216    Tr: Transport + 'static,
217    Tr::Error: Into<McpError>,
218{
219    /// Get the server state.
220    pub fn state(&self) -> &Arc<ServerState> {
221        &self.state
222    }
223
224    /// Run the server message loop.
225    ///
226    /// This method runs until the connection is closed or an error occurs.
227    pub async fn run(&self) -> Result<(), McpError> {
228        loop {
229            match self.transport.recv().await {
230                Ok(Some(message)) => {
231                    if let Err(e) = self.handle_message(message).await {
232                        tracing::error!(error = %e, "Error handling message");
233                    }
234                }
235                Ok(None) => {
236                    // Connection closed cleanly
237                    tracing::info!("Connection closed");
238                    break;
239                }
240                Err(e) => {
241                    let err: McpError = e.into();
242                    tracing::error!(error = %err, "Transport error");
243                    return Err(err);
244                }
245            }
246        }
247
248        Ok(())
249    }
250
251    /// Handle a single message.
252    async fn handle_message(&self, message: Message) -> Result<(), McpError> {
253        match message {
254            Message::Request(request) => self.handle_request(request).await,
255            Message::Notification(notification) => self.handle_notification(notification).await,
256            Message::Response(_) => {
257                // Servers don't typically receive responses
258                tracing::warn!("Received unexpected response message");
259                Ok(())
260            }
261        }
262    }
263
264    /// Handle a request.
265    async fn handle_request(&self, request: Request) -> Result<(), McpError> {
266        let method = request.method.to_string();
267        let id = request.id.clone();
268
269        tracing::debug!(method = %method, id = %id, "Handling request");
270
271        let response = match method.as_str() {
272            "initialize" => self.handle_initialize(&request).await,
273            _ if !self.state.is_initialized() => {
274                Err(McpError::invalid_request("Server not initialized"))
275            }
276            _ => self.route_request(&request).await,
277        };
278
279        // Send response
280        let response_msg = match response {
281            Ok(result) => Response::success(id, result),
282            Err(e) => Response::error(id, e.into()),
283        };
284
285        self.transport
286            .send(Message::Response(response_msg))
287            .await
288            .map_err(|e| e.into())
289    }
290
291    /// Handle the initialize request.
292    ///
293    /// This performs protocol version negotiation according to the MCP specification:
294    /// 1. Client sends its preferred protocol version
295    /// 2. Server responds with the same version if supported, or its preferred version
296    /// 3. Client must support the returned version or disconnect
297    async fn handle_initialize(
298        &self,
299        request: &Request,
300    ) -> Result<serde_json::Value, McpError> {
301        if self.state.is_initialized() {
302            return Err(McpError::invalid_request("Already initialized"));
303        }
304
305        // Parse initialize params
306        let params = request.params.as_ref().ok_or_else(|| {
307            McpError::invalid_params("initialize", "missing params")
308        })?;
309
310        // Extract and negotiate protocol version
311        let requested_version = params
312            .get("protocolVersion")
313            .and_then(|v| v.as_str())
314            .unwrap_or("");
315
316        let negotiated_version = negotiate_version(requested_version);
317
318        // Log version negotiation details for debugging
319        if requested_version != negotiated_version {
320            tracing::info!(
321                requested = %requested_version,
322                negotiated = %negotiated_version,
323                supported = ?SUPPORTED_PROTOCOL_VERSIONS,
324                "Protocol version negotiation: client requested unsupported version"
325            );
326        } else {
327            tracing::debug!(
328                version = %negotiated_version,
329                "Protocol version negotiated successfully"
330            );
331        }
332
333        // Store the negotiated version
334        self.state.set_protocol_version(negotiated_version.to_string());
335
336        // Extract client info and capabilities
337        if let Some(caps) = params.get("capabilities") {
338            if let Ok(client_caps) = serde_json::from_value::<ClientCapabilities>(caps.clone()) {
339                self.state.set_client_caps(client_caps);
340            }
341        }
342
343        // Build response with negotiated version
344        let result = serde_json::json!({
345            "protocolVersion": negotiated_version,
346            "serverInfo": {
347                "name": "mcp-server",
348                "version": "1.0.0"
349            },
350            "capabilities": self.state.server_caps
351        });
352
353        self.state.set_initialized();
354
355        Ok(result)
356    }
357
358    /// Route a request to the appropriate handler.
359    async fn route_request(&self, request: &Request) -> Result<serde_json::Value, McpError> {
360        let method = request.method.as_ref();
361        let params = request.params.as_ref();
362
363        // Extract progress token from params._meta.progressToken if present
364        let progress_token = extract_progress_token(params);
365
366        // Create context for the handler
367        let peer = TransportPeer::new(self.transport.clone());
368        let client_caps = self.state.client_caps();
369        let ctx = Context::new(
370            &request.id,
371            progress_token.as_ref(),
372            &client_caps,
373            &self.state.server_caps,
374            &peer,
375        );
376
377        // Delegate to the router
378        self.server.route(method, params, &ctx).await
379    }
380
381    /// Handle a notification.
382    async fn handle_notification(&self, notification: Notification) -> Result<(), McpError> {
383        let method = notification.method.as_ref();
384
385        tracing::debug!(method = %method, "Handling notification");
386
387        match method {
388            "notifications/initialized" => {
389                tracing::info!("Client sent initialized notification");
390                Ok(())
391            }
392            "notifications/cancelled" => {
393                if let Some(params) = &notification.params {
394                    if let Some(request_id) = params.get("requestId").and_then(|v| v.as_str()) {
395                        self.state.cancel_request(request_id);
396                    }
397                }
398                Ok(())
399            }
400            _ => {
401                tracing::debug!(method = %method, "Ignoring unknown notification");
402                Ok(())
403            }
404        }
405    }
406}
407
408// Constructor implementations for ServerRuntime with different server types
409impl<H, T, R, P, K, Tr> ServerRuntime<Server<H, T, R, P, K>, Tr>
410where
411    H: ServerHandler + Send + Sync,
412    T: Send + Sync,
413    R: Send + Sync,
414    P: Send + Sync,
415    K: Send + Sync,
416    Tr: Transport + 'static,
417    Tr::Error: Into<McpError>,
418{
419    /// Create a new server runtime.
420    pub fn new(server: Server<H, T, R, P, K>, transport: Tr) -> Self {
421        let caps = server.capabilities().clone();
422        Self {
423            server,
424            transport: Arc::new(transport),
425            state: Arc::new(ServerState::new(caps)),
426            config: RuntimeConfig::default(),
427        }
428    }
429
430    /// Create a new server runtime with custom configuration.
431    pub fn with_config(server: Server<H, T, R, P, K>, transport: Tr, config: RuntimeConfig) -> Self {
432        let caps = server.capabilities().clone();
433        Self {
434            server,
435            transport: Arc::new(transport),
436            state: Arc::new(ServerState::new(caps)),
437            config,
438        }
439    }
440}
441
/// Trait for routing requests to handlers.
///
/// This trait is implemented by Server with different bounds depending on
/// which handlers are registered. `ServerRuntime` answers `initialize` itself
/// and uses this trait to dispatch the remaining request methods.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because
// callers cannot require the returned future to be `Send`; it is allowed here
// deliberately.
#[allow(async_fn_in_trait)]
pub trait RequestRouter: Send + Sync {
    /// Route a request and return the result.
    ///
    /// `method` is the JSON-RPC method name, `params` the raw request
    /// parameters (if any), and `ctx` the per-request handler context.
    ///
    /// # Errors
    ///
    /// The implementations in this module return a method-not-found error for
    /// methods they do not handle, and propagate handler errors otherwise.
    async fn route(
        &self,
        method: &str,
        params: Option<&serde_json::Value>,
        ctx: &Context<'_>,
    ) -> Result<serde_json::Value, McpError>;
}
456
457/// Extension methods for Server to run with a transport.
458impl<H, T, R, P, K> Server<H, T, R, P, K>
459where
460    H: ServerHandler + Send + Sync + 'static,
461    T: Send + Sync + 'static,
462    R: Send + Sync + 'static,
463    P: Send + Sync + 'static,
464    K: Send + Sync + 'static,
465    Self: RequestRouter,
466{
467    /// Run this server over the given transport.
468    pub async fn serve<Tr>(self, transport: Tr) -> Result<(), McpError>
469    where
470        Tr: Transport + 'static,
471        Tr::Error: Into<McpError>,
472    {
473        let runtime = ServerRuntime::new(self, transport);
474        runtime.run().await
475    }
476}
477
478// ============================================================================
479// RequestRouter implementations via macro
480// ============================================================================
481
482// Internal routing functions to reduce code duplication.
483// Each function handles a specific handler type's methods.
484
485async fn route_tools<TH: ToolHandler + Send + Sync>(
486    handler: &TH,
487    method: &str,
488    params: Option<&serde_json::Value>,
489    ctx: &Context<'_>,
490) -> Option<Result<serde_json::Value, McpError>> {
491    match method {
492        "tools/list" => {
493            let result = handler.list_tools(ctx).await;
494            Some(result.map(|tools| serde_json::json!({ "tools": tools })))
495        }
496        "tools/call" => {
497            let result = (|| async {
498                let params = params.ok_or_else(|| {
499                    McpError::invalid_params("tools/call", "missing params")
500                })?;
501                let name = params.get("name")
502                    .and_then(|v| v.as_str())
503                    .ok_or_else(|| McpError::invalid_params("tools/call", "missing tool name"))?;
504                let args = params.get("arguments")
505                    .cloned()
506                    .unwrap_or(serde_json::json!({}));
507                let output = handler.call_tool(name, args, ctx).await?;
508                let result: CallToolResult = output.into();
509                Ok(serde_json::to_value(result).unwrap_or(serde_json::json!({})))
510            })().await;
511            Some(result)
512        }
513        _ => None,
514    }
515}
516
517async fn route_resources<RH: ResourceHandler + Send + Sync>(
518    handler: &RH,
519    method: &str,
520    params: Option<&serde_json::Value>,
521    ctx: &Context<'_>,
522) -> Option<Result<serde_json::Value, McpError>> {
523    match method {
524        "resources/list" => {
525            let result = handler.list_resources(ctx).await;
526            Some(result.map(|resources| serde_json::json!({ "resources": resources })))
527        }
528        "resources/read" => {
529            let result = (|| async {
530                let params = params.ok_or_else(|| {
531                    McpError::invalid_params("resources/read", "missing params")
532                })?;
533                let uri = params.get("uri")
534                    .and_then(|v| v.as_str())
535                    .ok_or_else(|| McpError::invalid_params("resources/read", "missing uri"))?;
536                let contents = handler.read_resource(uri, ctx).await?;
537                Ok(serde_json::json!({ "contents": contents }))
538            })().await;
539            Some(result)
540        }
541        _ => None,
542    }
543}
544
545async fn route_prompts<PH: PromptHandler + Send + Sync>(
546    handler: &PH,
547    method: &str,
548    params: Option<&serde_json::Value>,
549    ctx: &Context<'_>,
550) -> Option<Result<serde_json::Value, McpError>> {
551    match method {
552        "prompts/list" => {
553            let result = handler.list_prompts(ctx).await;
554            Some(result.map(|prompts| serde_json::json!({ "prompts": prompts })))
555        }
556        "prompts/get" => {
557            let result = (|| async {
558                let params = params.ok_or_else(|| {
559                    McpError::invalid_params("prompts/get", "missing params")
560                })?;
561                let name = params.get("name")
562                    .and_then(|v| v.as_str())
563                    .ok_or_else(|| McpError::invalid_params("prompts/get", "missing prompt name"))?;
564                let args = params.get("arguments")
565                    .and_then(|v| v.as_object())
566                    .cloned();
567                let result = handler.get_prompt(name, args, ctx).await?;
568                Ok(serde_json::to_value(result).unwrap_or(serde_json::json!({})))
569            })().await;
570            Some(result)
571        }
572        _ => None,
573    }
574}
575
/// Macro to generate `RequestRouter` implementations for all handler combinations.
///
/// Each public arm (`base`, `tools`, `resources`, ...) names one of the
/// 2^3 = 8 combinations of tool/resource/prompt registration. Every arm except
/// `base` forwards to the internal `@impl` arm. That arm emits one
/// implementation which:
/// 1. answers `ping`;
/// 2. tries each listed `route_*` helper in order (tools, resources, prompts);
/// 3. finally reports method-not-found.
///
/// The fifth `Server` type parameter is always `NotRegistered` here. Any tokens
/// after the `;` are spliced verbatim into the impl's generic parameter list.
macro_rules! impl_request_router {
    // Shared implementation for every combination with at least one handler.
    (@impl
        generics: [$($gen:ident: $bound:ident),+]
        server: [$tools:ty, $resources:ty, $prompts:ty]
        routes: [$($route:ident => $getter:ident),+]
        extra: [$($extra:tt)*]
    ) => {
        impl<H, $($gen),+ $($extra)*> RequestRouter
            for Server<H, $tools, $resources, $prompts, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            $($gen: $bound + Send + Sync,)+
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                $(
                    if let Some(result) = $route(self.$getter(), method, params, ctx).await {
                        return result;
                    }
                )+
                Err(McpError::method_not_found(method))
            }
        }
    };

    // No handlers registered: only `ping` is supported.
    (base; $($bounds:tt)*) => {
        impl<H $($bounds)*> RequestRouter
            for Server<H, NotRegistered, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                _params: Option<&serde_json::Value>,
                _ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                if method == "ping" {
                    Ok(serde_json::json!({}))
                } else {
                    Err(McpError::method_not_found(method))
                }
            }
        }
    };

    // Tools only
    (tools; $($bounds:tt)*) => {
        impl_request_router!(@impl
            generics: [TH: ToolHandler]
            server: [Registered<TH>, NotRegistered, NotRegistered]
            routes: [route_tools => tool_handler]
            extra: [$($bounds)*]
        );
    };

    // Resources only
    (resources; $($bounds:tt)*) => {
        impl_request_router!(@impl
            generics: [RH: ResourceHandler]
            server: [NotRegistered, Registered<RH>, NotRegistered]
            routes: [route_resources => resource_handler]
            extra: [$($bounds)*]
        );
    };

    // Prompts only
    (prompts; $($bounds:tt)*) => {
        impl_request_router!(@impl
            generics: [PH: PromptHandler]
            server: [NotRegistered, NotRegistered, Registered<PH>]
            routes: [route_prompts => prompt_handler]
            extra: [$($bounds)*]
        );
    };

    // Tools + Resources
    (tools_resources; $($bounds:tt)*) => {
        impl_request_router!(@impl
            generics: [TH: ToolHandler, RH: ResourceHandler]
            server: [Registered<TH>, Registered<RH>, NotRegistered]
            routes: [route_tools => tool_handler, route_resources => resource_handler]
            extra: [$($bounds)*]
        );
    };

    // Tools + Prompts
    (tools_prompts; $($bounds:tt)*) => {
        impl_request_router!(@impl
            generics: [TH: ToolHandler, PH: PromptHandler]
            server: [Registered<TH>, NotRegistered, Registered<PH>]
            routes: [route_tools => tool_handler, route_prompts => prompt_handler]
            extra: [$($bounds)*]
        );
    };

    // Resources + Prompts
    (resources_prompts; $($bounds:tt)*) => {
        impl_request_router!(@impl
            generics: [RH: ResourceHandler, PH: PromptHandler]
            server: [NotRegistered, Registered<RH>, Registered<PH>]
            routes: [route_resources => resource_handler, route_prompts => prompt_handler]
            extra: [$($bounds)*]
        );
    };

    // Tools + Resources + Prompts
    (tools_resources_prompts; $($bounds:tt)*) => {
        impl_request_router!(@impl
            generics: [TH: ToolHandler, RH: ResourceHandler, PH: PromptHandler]
            server: [Registered<TH>, Registered<RH>, Registered<PH>]
            routes: [
                route_tools => tool_handler,
                route_resources => resource_handler,
                route_prompts => prompt_handler
            ]
            extra: [$($bounds)*]
        );
    };
}
789
790// Generate all RequestRouter implementations
791impl_request_router!(base;);
792impl_request_router!(tools;);
793impl_request_router!(resources;);
794impl_request_router!(prompts;);
795impl_request_router!(tools_resources;);
796impl_request_router!(tools_prompts;);
797impl_request_router!(resources_prompts;);
798impl_request_router!(tools_resources_prompts;);
799
800// ============================================================================
801// Helper functions
802// ============================================================================
803
804/// Extract a progress token from request parameters.
805///
806/// Per the MCP specification, progress tokens are sent in the `_meta.progressToken`
807/// field of request parameters. This function attempts to extract and parse that
808/// field into a `ProgressToken`.
809///
810/// # Example JSON structure
811/// ```json
812/// {
813///   "_meta": {
814///     "progressToken": "token-123"
815///   },
816///   "name": "my-tool",
817///   "arguments": {}
818/// }
819/// ```
820fn extract_progress_token(params: Option<&serde_json::Value>) -> Option<ProgressToken> {
821    params?
822        .get("_meta")?
823        .get("progressToken")
824        .and_then(|v| serde_json::from_value(v.clone()).ok())
825}
826
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_state_initialization() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(!state.is_initialized(), "a new state starts uninitialized");

        state.set_initialized();
        assert!(state.is_initialized(), "set_initialized marks the state initialized");
    }

    #[test]
    fn test_cancellation_management() {
        let state = ServerState::new(ServerCapabilities::default());
        let token = CancellationToken::new();

        state.register_cancellation("req-1", token.clone());
        assert!(!token.is_cancelled(), "registering must not cancel the token");

        state.cancel_request("req-1");
        assert!(token.is_cancelled(), "cancel_request cancels the registered token");

        // Cleanup after completion must not panic.
        state.remove_cancellation("req-1");
    }

    #[test]
    fn test_runtime_config_default() {
        let config = RuntimeConfig::default();
        assert!(config.auto_initialized);
        assert_eq!(config.max_concurrent_requests, 100);
    }

    #[test]
    fn test_extract_progress_token_string() {
        let params = serde_json::json!({
            "_meta": { "progressToken": "my-token-123" },
            "name": "test-tool"
        });
        assert_eq!(
            extract_progress_token(Some(&params)),
            Some(ProgressToken::String("my-token-123".to_string()))
        );
    }

    #[test]
    fn test_extract_progress_token_number() {
        let params = serde_json::json!({
            "_meta": { "progressToken": 42 },
            "arguments": {}
        });
        assert_eq!(
            extract_progress_token(Some(&params)),
            Some(ProgressToken::Number(42))
        );
    }

    #[test]
    fn test_extract_progress_token_missing_meta() {
        let params = serde_json::json!({
            "name": "test-tool",
            "arguments": {}
        });
        assert!(extract_progress_token(Some(&params)).is_none());
    }

    #[test]
    fn test_extract_progress_token_missing_token() {
        let params = serde_json::json!({
            "_meta": {},
            "name": "test-tool"
        });
        assert!(extract_progress_token(Some(&params)).is_none());
    }

    #[test]
    fn test_extract_progress_token_none_params() {
        assert!(extract_progress_token(None).is_none());
    }
}