mcpkit_server/
server.rs

1//! Server runtime for MCP servers.
2//!
3//! This module provides the runtime that executes an MCP server over
4//! a transport, handling message routing, request correlation, and
5//! the connection lifecycle.
6//!
7//! # Overview
8//!
9//! The server runtime:
10//! 1. Accepts a transport for communication
11//! 2. Handles the initialize/initialized handshake
12//! 3. Routes incoming requests to the appropriate handlers
13//! 4. Manages the connection lifecycle
14//!
15//! # Example
16//!
17//! ```rust
18//! use mcpkit_server::{ServerBuilder, ServerHandler, ServerState};
19//! use mcpkit_core::capability::{ServerInfo, ServerCapabilities};
20//!
21//! struct MyHandler;
22//! impl ServerHandler for MyHandler {
23//!     fn server_info(&self) -> ServerInfo {
24//!         ServerInfo::new("my-server", "1.0.0")
25//!     }
26//! }
27//!
28//! // Build a server and create server state
29//! let server = ServerBuilder::new(MyHandler).build();
30//! let state = ServerState::new(server.capabilities().clone());
31//!
32//! assert!(!state.is_initialized());
33//! ```
34
35use crate::builder::{NotRegistered, Registered, Server};
36use crate::context::{CancellationToken, Context, Peer};
37use crate::handler::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
38use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
39use mcpkit_core::error::McpError;
40use mcpkit_core::protocol::{Message, Notification, ProgressToken, Request, Response};
41use mcpkit_core::protocol_version::ProtocolVersion;
42use mcpkit_core::types::CallToolResult;
43use mcpkit_transport::Transport;
44use std::collections::HashMap;
45use std::sync::atomic::{AtomicBool, Ordering};
46use std::sync::Arc;
47use std::sync::RwLock;
48
49/// State for a running server.
50pub struct ServerState {
51    /// Client capabilities negotiated during initialization.
52    pub client_caps: RwLock<ClientCapabilities>,
53    /// Server capabilities advertised during initialization.
54    pub server_caps: ServerCapabilities,
55    /// Whether the server has been initialized.
56    pub initialized: AtomicBool,
57    /// Active cancellation tokens by request ID.
58    pub cancellations: RwLock<HashMap<String, CancellationToken>>,
59    /// The protocol version negotiated during initialization.
60    ///
61    /// This is stored as a `ProtocolVersion` enum for type-safe feature detection.
62    /// Use methods like `protocol_version().supports_tasks()` to check capabilities.
63    pub negotiated_version: RwLock<Option<ProtocolVersion>>,
64}
65
66impl ServerState {
67    /// Create a new server state.
68    #[must_use]
69    pub fn new(server_caps: ServerCapabilities) -> Self {
70        Self {
71            client_caps: RwLock::new(ClientCapabilities::default()),
72            server_caps,
73            initialized: AtomicBool::new(false),
74            cancellations: RwLock::new(HashMap::new()),
75            negotiated_version: RwLock::new(None),
76        }
77    }
78
79    /// Get the negotiated protocol version.
80    ///
81    /// Returns `None` if not yet initialized.
82    ///
83    /// # Example
84    ///
85    /// ```rust,ignore
86    /// if let Some(version) = state.protocol_version() {
87    ///     if version.supports_tasks() {
88    ///         // Tasks are available in this session
89    ///     }
90    /// }
91    /// ```
92    pub fn protocol_version(&self) -> Option<ProtocolVersion> {
93        self.negotiated_version.read().ok().and_then(|guard| *guard)
94    }
95
96    /// Set the negotiated protocol version.
97    ///
98    /// Silently fails if the lock is poisoned.
99    pub fn set_protocol_version(&self, version: ProtocolVersion) {
100        if let Ok(mut guard) = self.negotiated_version.write() {
101            *guard = Some(version);
102        }
103    }
104
105    /// Get a snapshot of client capabilities.
106    ///
107    /// Returns default capabilities if the lock is poisoned.
108    pub fn client_caps(&self) -> ClientCapabilities {
109        self.client_caps
110            .read()
111            .map(|guard| guard.clone())
112            .unwrap_or_default()
113    }
114
115    /// Update client capabilities.
116    ///
117    /// Silently fails if the lock is poisoned.
118    pub fn set_client_caps(&self, caps: ClientCapabilities) {
119        if let Ok(mut guard) = self.client_caps.write() {
120            *guard = caps;
121        }
122    }
123
124    /// Check if the server is initialized.
125    pub fn is_initialized(&self) -> bool {
126        self.initialized.load(Ordering::Acquire)
127    }
128
129    /// Mark the server as initialized.
130    pub fn set_initialized(&self) {
131        self.initialized.store(true, Ordering::Release);
132    }
133
134    /// Register a cancellation token for a request.
135    pub fn register_cancellation(&self, request_id: &str, token: CancellationToken) {
136        if let Ok(mut cancellations) = self.cancellations.write() {
137            cancellations.insert(request_id.to_string(), token);
138        }
139    }
140
141    /// Cancel a request by ID.
142    pub fn cancel_request(&self, request_id: &str) {
143        if let Ok(cancellations) = self.cancellations.read() {
144            if let Some(token) = cancellations.get(request_id) {
145                token.cancel();
146            }
147        }
148    }
149
150    /// Remove a cancellation token after request completion.
151    pub fn remove_cancellation(&self, request_id: &str) {
152        if let Ok(mut cancellations) = self.cancellations.write() {
153            cancellations.remove(request_id);
154        }
155    }
156}
157
158/// A peer implementation that sends notifications over a transport.
159pub struct TransportPeer<T: Transport> {
160    transport: Arc<T>,
161}
162
163impl<T: Transport> TransportPeer<T> {
164    /// Create a new transport peer.
165    pub const fn new(transport: Arc<T>) -> Self {
166        Self { transport }
167    }
168}
169
170impl<T: Transport + 'static> Peer for TransportPeer<T>
171where
172    T::Error: Into<McpError>,
173{
174    fn notify(
175        &self,
176        notification: Notification,
177    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), McpError>> + Send + '_>>
178    {
179        let transport = self.transport.clone();
180        Box::pin(async move {
181            transport
182                .send(Message::Notification(notification))
183                .await
184                .map_err(std::convert::Into::into)
185        })
186    }
187}
188
/// Tunables for the server runtime.
///
/// NOTE(review): the runtime in this module does not appear to read either
/// field yet (requests are handled one at a time) — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Whether the runtime should send the `initialized` notification on its own.
    pub auto_initialized: bool,
    /// Upper bound on the number of requests processed concurrently.
    pub max_concurrent_requests: usize,
}

impl Default for RuntimeConfig {
    /// `auto_initialized = true`, `max_concurrent_requests = 100`.
    fn default() -> Self {
        let max_concurrent_requests = 100;
        RuntimeConfig {
            max_concurrent_requests,
            auto_initialized: true,
        }
    }
}
206
207/// Server runtime that handles the message loop.
208///
209/// This runtime manages the connection lifecycle, routes requests to
210/// handlers, and coordinates response delivery.
211pub struct ServerRuntime<S, Tr>
212where
213    Tr: Transport,
214{
215    server: S,
216    transport: Arc<Tr>,
217    state: Arc<ServerState>,
218    /// Runtime configuration (request timeouts, etc.) - will be used by advanced features.
219    #[allow(dead_code)]
220    config: RuntimeConfig,
221}
222
223impl<S, Tr> ServerRuntime<S, Tr>
224where
225    S: RequestRouter + Send + Sync,
226    Tr: Transport + 'static,
227    Tr::Error: Into<McpError>,
228{
229    /// Get the server state.
230    pub const fn state(&self) -> &Arc<ServerState> {
231        &self.state
232    }
233
234    /// Run the server message loop.
235    ///
236    /// This method runs until the connection is closed or an error occurs.
237    pub async fn run(&self) -> Result<(), McpError> {
238        loop {
239            match self.transport.recv().await {
240                Ok(Some(message)) => {
241                    if let Err(e) = self.handle_message(message).await {
242                        tracing::error!(error = %e, "Error handling message");
243                    }
244                }
245                Ok(None) => {
246                    // Connection closed cleanly
247                    tracing::info!("Connection closed");
248                    break;
249                }
250                Err(e) => {
251                    let err: McpError = e.into();
252                    tracing::error!(error = %err, "Transport error");
253                    return Err(err);
254                }
255            }
256        }
257
258        Ok(())
259    }
260
261    /// Handle a single message.
262    async fn handle_message(&self, message: Message) -> Result<(), McpError> {
263        match message {
264            Message::Request(request) => self.handle_request(request).await,
265            Message::Notification(notification) => self.handle_notification(notification).await,
266            Message::Response(_) => {
267                // Servers don't typically receive responses
268                tracing::warn!("Received unexpected response message");
269                Ok(())
270            }
271        }
272    }
273
274    /// Handle a request.
275    async fn handle_request(&self, request: Request) -> Result<(), McpError> {
276        let method = request.method.to_string();
277        let id = request.id.clone();
278
279        tracing::debug!(method = %method, id = %id, "Handling request");
280
281        let response = match method.as_str() {
282            "initialize" => self.handle_initialize(&request).await,
283            _ if !self.state.is_initialized() => {
284                Err(McpError::invalid_request("Server not initialized"))
285            }
286            _ => self.route_request(&request).await,
287        };
288
289        // Send response
290        let response_msg = match response {
291            Ok(result) => Response::success(id, result),
292            Err(e) => Response::error(id, e.into()),
293        };
294
295        self.transport
296            .send(Message::Response(response_msg))
297            .await
298            .map_err(std::convert::Into::into)
299    }
300
301    /// Handle the initialize request.
302    ///
303    /// This performs protocol version negotiation according to the MCP specification:
304    /// 1. Client sends its preferred protocol version
305    /// 2. Server responds with the same version if supported, or its preferred version
306    /// 3. Client must support the returned version or disconnect
307    async fn handle_initialize(&self, request: &Request) -> Result<serde_json::Value, McpError> {
308        if self.state.is_initialized() {
309            return Err(McpError::invalid_request("Already initialized"));
310        }
311
312        // Parse initialize params
313        let params = request
314            .params
315            .as_ref()
316            .ok_or_else(|| McpError::invalid_params("initialize", "missing params"))?;
317
318        // Extract and negotiate protocol version using type-safe enum
319        let requested_version_str = params
320            .get("protocolVersion")
321            .and_then(|v| v.as_str())
322            .unwrap_or("");
323
324        // Negotiate using the ProtocolVersion enum for type safety
325        let negotiated_version =
326            ProtocolVersion::negotiate(requested_version_str, ProtocolVersion::ALL)
327                .unwrap_or(ProtocolVersion::LATEST);
328
329        // Log version negotiation details for debugging
330        if requested_version_str == negotiated_version.as_str() {
331            tracing::debug!(
332                version = %negotiated_version,
333                "Protocol version negotiated successfully"
334            );
335        } else {
336            tracing::info!(
337                requested = %requested_version_str,
338                negotiated = %negotiated_version,
339                supported = ?ProtocolVersion::ALL.iter().map(ProtocolVersion::as_str).collect::<Vec<_>>(),
340                "Protocol version negotiation: client requested different version"
341            );
342        }
343
344        // Store the negotiated version (type-safe enum)
345        self.state.set_protocol_version(negotiated_version);
346
347        // Extract client info and capabilities
348        if let Some(caps) = params.get("capabilities") {
349            if let Ok(client_caps) = serde_json::from_value::<ClientCapabilities>(caps.clone()) {
350                self.state.set_client_caps(client_caps);
351            }
352        }
353
354        // Build response with negotiated version (serialized to string by serde)
355        let result = serde_json::json!({
356            "protocolVersion": negotiated_version.as_str(),
357            "serverInfo": {
358                "name": "mcp-server",
359                "version": "1.0.0"
360            },
361            "capabilities": self.state.server_caps
362        });
363
364        self.state.set_initialized();
365
366        Ok(result)
367    }
368
369    /// Route a request to the appropriate handler.
370    async fn route_request(&self, request: &Request) -> Result<serde_json::Value, McpError> {
371        let method = request.method.as_ref();
372        let params = request.params.as_ref();
373
374        // Extract progress token from params._meta.progressToken if present
375        let progress_token = extract_progress_token(params);
376
377        // Create context for the handler
378        let peer = TransportPeer::new(self.transport.clone());
379        let client_caps = self.state.client_caps();
380        let protocol_version = self
381            .state
382            .protocol_version()
383            .unwrap_or(ProtocolVersion::LATEST);
384        let ctx = Context::new(
385            &request.id,
386            progress_token.as_ref(),
387            &client_caps,
388            &self.state.server_caps,
389            protocol_version,
390            &peer,
391        );
392
393        // Delegate to the router
394        self.server.route(method, params, &ctx).await
395    }
396
397    /// Handle a notification.
398    async fn handle_notification(&self, notification: Notification) -> Result<(), McpError> {
399        let method = notification.method.as_ref();
400
401        tracing::debug!(method = %method, "Handling notification");
402
403        match method {
404            "notifications/initialized" => {
405                tracing::info!("Client sent initialized notification");
406                Ok(())
407            }
408            "notifications/cancelled" => {
409                if let Some(params) = &notification.params {
410                    if let Some(request_id) = params.get("requestId").and_then(|v| v.as_str()) {
411                        self.state.cancel_request(request_id);
412                    }
413                }
414                Ok(())
415            }
416            _ => {
417                tracing::debug!(method = %method, "Ignoring unknown notification");
418                Ok(())
419            }
420        }
421    }
422}
423
424// Constructor implementations for ServerRuntime with different server types
425impl<H, T, R, P, K, Tr> ServerRuntime<Server<H, T, R, P, K>, Tr>
426where
427    H: ServerHandler + Send + Sync,
428    T: Send + Sync,
429    R: Send + Sync,
430    P: Send + Sync,
431    K: Send + Sync,
432    Tr: Transport + 'static,
433    Tr::Error: Into<McpError>,
434{
435    /// Create a new server runtime.
436    pub fn new(server: Server<H, T, R, P, K>, transport: Tr) -> Self {
437        let caps = server.capabilities().clone();
438        Self {
439            server,
440            transport: Arc::new(transport),
441            state: Arc::new(ServerState::new(caps)),
442            config: RuntimeConfig::default(),
443        }
444    }
445
446    /// Create a new server runtime with custom configuration.
447    pub fn with_config(
448        server: Server<H, T, R, P, K>,
449        transport: Tr,
450        config: RuntimeConfig,
451    ) -> Self {
452        let caps = server.capabilities().clone();
453        Self {
454            server,
455            transport: Arc::new(transport),
456            state: Arc::new(ServerState::new(caps)),
457            config,
458        }
459    }
460}
461
462/// Trait for routing requests to handlers.
463///
464/// This trait is implemented by Server with different bounds depending on
465/// which handlers are registered.
466#[allow(async_fn_in_trait)]
467pub trait RequestRouter: Send + Sync {
468    /// Route a request and return the result.
469    async fn route(
470        &self,
471        method: &str,
472        params: Option<&serde_json::Value>,
473        ctx: &Context<'_>,
474    ) -> Result<serde_json::Value, McpError>;
475}
476
477/// Extension methods for Server to run with a transport.
478impl<H, T, R, P, K> Server<H, T, R, P, K>
479where
480    H: ServerHandler + Send + Sync + 'static,
481    T: Send + Sync + 'static,
482    R: Send + Sync + 'static,
483    P: Send + Sync + 'static,
484    K: Send + Sync + 'static,
485    Self: RequestRouter,
486{
487    /// Run this server over the given transport.
488    pub async fn serve<Tr>(self, transport: Tr) -> Result<(), McpError>
489    where
490        Tr: Transport + 'static,
491        Tr::Error: Into<McpError>,
492    {
493        let runtime = ServerRuntime::new(self, transport);
494        runtime.run().await
495    }
496}
497
498// ============================================================================
499// RequestRouter implementations via macro
500// ============================================================================
501
502// Internal routing functions to reduce code duplication.
503// Each function handles a specific handler type's methods.
504
505async fn route_tools<TH: ToolHandler + Send + Sync>(
506    handler: &TH,
507    method: &str,
508    params: Option<&serde_json::Value>,
509    ctx: &Context<'_>,
510) -> Option<Result<serde_json::Value, McpError>> {
511    match method {
512        "tools/list" => {
513            tracing::debug!("Listing available tools");
514            let result = handler.list_tools(ctx).await;
515            match &result {
516                Ok(tools) => tracing::debug!(count = tools.len(), "Listed tools"),
517                Err(e) => tracing::warn!(error = %e, "Failed to list tools"),
518            }
519            Some(result.map(|tools| serde_json::json!({ "tools": tools })))
520        }
521        "tools/call" => {
522            let result = async {
523                let params = params.ok_or_else(|| {
524                    McpError::invalid_params("tools/call", "missing params")
525                })?;
526                let name = params.get("name")
527                    .and_then(|v| v.as_str())
528                    .ok_or_else(|| McpError::invalid_params("tools/call", "missing tool name"))?;
529                let args = params.get("arguments")
530                    .cloned()
531                    .unwrap_or_else(|| serde_json::json!({}));
532
533                tracing::info!(tool = %name, "Calling tool");
534                let start = std::time::Instant::now();
535                let output = handler.call_tool(name, args, ctx).await;
536                let duration = start.elapsed();
537
538                match &output {
539                    Ok(_) => tracing::info!(tool = %name, duration_ms = duration.as_millis(), "Tool call completed"),
540                    Err(e) => tracing::warn!(tool = %name, duration_ms = duration.as_millis(), error = %e, "Tool call failed"),
541                }
542
543                let output = output?;
544                let result: CallToolResult = output.into();
545                Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
546            }.await;
547            Some(result)
548        }
549        _ => None,
550    }
551}
552
553async fn route_resources<RH: ResourceHandler + Send + Sync>(
554    handler: &RH,
555    method: &str,
556    params: Option<&serde_json::Value>,
557    ctx: &Context<'_>,
558) -> Option<Result<serde_json::Value, McpError>> {
559    match method {
560        "resources/list" => {
561            tracing::debug!("Listing available resources");
562            let result = handler.list_resources(ctx).await;
563            match &result {
564                Ok(resources) => tracing::debug!(count = resources.len(), "Listed resources"),
565                Err(e) => tracing::warn!(error = %e, "Failed to list resources"),
566            }
567            Some(result.map(|resources| serde_json::json!({ "resources": resources })))
568        }
569        "resources/read" => {
570            let result = async {
571                let params = params.ok_or_else(|| {
572                    McpError::invalid_params("resources/read", "missing params")
573                })?;
574                let uri = params.get("uri")
575                    .and_then(|v| v.as_str())
576                    .ok_or_else(|| McpError::invalid_params("resources/read", "missing uri"))?;
577
578                tracing::info!(uri = %uri, "Reading resource");
579                let start = std::time::Instant::now();
580                let contents = handler.read_resource(uri, ctx).await;
581                let duration = start.elapsed();
582
583                match &contents {
584                    Ok(_) => tracing::info!(uri = %uri, duration_ms = duration.as_millis(), "Resource read completed"),
585                    Err(e) => tracing::warn!(uri = %uri, duration_ms = duration.as_millis(), error = %e, "Resource read failed"),
586                }
587
588                let contents = contents?;
589                Ok(serde_json::json!({ "contents": contents }))
590            }.await;
591            Some(result)
592        }
593        _ => None,
594    }
595}
596
597async fn route_prompts<PH: PromptHandler + Send + Sync>(
598    handler: &PH,
599    method: &str,
600    params: Option<&serde_json::Value>,
601    ctx: &Context<'_>,
602) -> Option<Result<serde_json::Value, McpError>> {
603    match method {
604        "prompts/list" => {
605            tracing::debug!("Listing available prompts");
606            let result = handler.list_prompts(ctx).await;
607            match &result {
608                Ok(prompts) => tracing::debug!(count = prompts.len(), "Listed prompts"),
609                Err(e) => tracing::warn!(error = %e, "Failed to list prompts"),
610            }
611            Some(result.map(|prompts| serde_json::json!({ "prompts": prompts })))
612        }
613        "prompts/get" => {
614            let result = async {
615                let params = params.ok_or_else(|| {
616                    McpError::invalid_params("prompts/get", "missing params")
617                })?;
618                let name = params.get("name")
619                    .and_then(|v| v.as_str())
620                    .ok_or_else(|| McpError::invalid_params("prompts/get", "missing prompt name"))?;
621                let args = params.get("arguments")
622                    .and_then(|v| v.as_object())
623                    .cloned();
624
625                tracing::info!(prompt = %name, "Getting prompt");
626                let start = std::time::Instant::now();
627                let prompt_result = handler.get_prompt(name, args, ctx).await;
628                let duration = start.elapsed();
629
630                match &prompt_result {
631                    Ok(_) => tracing::info!(prompt = %name, duration_ms = duration.as_millis(), "Prompt retrieval completed"),
632                    Err(e) => tracing::warn!(prompt = %name, duration_ms = duration.as_millis(), error = %e, "Prompt retrieval failed"),
633                }
634
635                let result = prompt_result?;
636                Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
637            }.await;
638            Some(result)
639        }
640        _ => None,
641    }
642}
643
/// Macro to generate `RequestRouter` implementations for all handler combinations.
///
/// This macro reduces code duplication by generating all 2^3 = 8 combinations
/// of tool/resource/prompt handler registration states.
///
/// A leading keyword selects the arm: `base`, `tools`, `resources`, `prompts`,
/// `tools_resources`, `tools_prompts`, `resources_prompts` or
/// `tools_resources_prompts`. Each arm emits one `impl RequestRouter for Server<...>`.
/// Its `route` answers `ping` first. It then tries the `route_*` helpers of the
/// registered handler families in the order tools, resources, prompts; the
/// first helper returning `Some` wins. Anything left over gets
/// `McpError::method_not_found`.
///
/// The trailing `$($bounds:tt)*` tokens are spliced into each impl's generic
/// parameter list after the handler type parameters; the invocations in this
/// file pass none.
///
/// NOTE(review): the fifth `Server` type parameter is `NotRegistered` in every
/// arm, so no `RequestRouter` impl is generated when it is registered — confirm
/// this is intended.
macro_rules! impl_request_router {
    // Base case: no handlers
    (base; $($bounds:tt)*) => {
        impl<H $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                _params: Option<&serde_json::Value>,
                _ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                // Only `ping` is answered; every other method is method-not-found.
                match method {
                    "ping" => Ok(serde_json::json!({})),
                    _ => Err(McpError::method_not_found(method)),
                }
            }
        }
    };

    // Tools only
    (tools; $($bounds:tt)*) => {
        impl<H, TH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                // `ping` is answered before any handler is consulted.
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                // `route_tools` returns `None` for non-tools methods.
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Resources only
    (resources; $($bounds:tt)*) => {
        impl<H, RH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                // `ping` is answered before any handler is consulted.
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                // `route_resources` returns `None` for non-resources methods.
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Prompts only
    (prompts; $($bounds:tt)*) => {
        impl<H, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                // `ping` is answered before any handler is consulted.
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                // `route_prompts` returns `None` for non-prompts methods.
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tools + Resources
    (tools_resources; $($bounds:tt)*) => {
        impl<H, TH, RH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                // `ping` is answered before any handler is consulted.
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                // Each helper returns `None` outside its method family, so fall through in order.
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tools + Prompts
    (tools_prompts; $($bounds:tt)*) => {
        impl<H, TH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                // `ping` is answered before any handler is consulted.
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                // Each helper returns `None` outside its method family, so fall through in order.
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Resources + Prompts
    (resources_prompts; $($bounds:tt)*) => {
        impl<H, RH, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                // `ping` is answered before any handler is consulted.
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                // Each helper returns `None` outside its method family, so fall through in order.
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };

    // Tools + Resources + Prompts
    (tools_resources_prompts; $($bounds:tt)*) => {
        impl<H, TH, RH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                // `ping` is answered before any handler is consulted.
                if method == "ping" {
                    return Ok(serde_json::json!({}));
                }
                // Each helper returns `None` outside its method family, so fall through in order.
                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
                    return result;
                }
                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
                    return result;
                }
                Err(McpError::method_not_found(method))
            }
        }
    };
}
857
// Generate all RequestRouter implementations.
//
// Each invocation below expands one arm of `impl_request_router!`: `base`
// plus every combination of the tools / resources / prompts handler slots.
// Whatever follows the `;` is spliced into the generated impl's generic
// parameter list (through the arm's `$($bounds:tt)*` matcher); it is left
// empty here, so none of the generated impls gains extra generic parameters.
impl_request_router!(base;);
impl_request_router!(tools;);
impl_request_router!(resources;);
impl_request_router!(prompts;);
impl_request_router!(tools_resources;);
impl_request_router!(tools_prompts;);
impl_request_router!(resources_prompts;);
impl_request_router!(tools_resources_prompts;);
867
868// ============================================================================
869// Helper functions
870// ============================================================================
871
872/// Extract a progress token from request parameters.
873///
874/// Per the MCP specification, progress tokens are sent in the `_meta.progressToken`
875/// field of request parameters. This function attempts to extract and parse that
876/// field into a `ProgressToken`.
877///
878/// # Example JSON structure
879/// ```json
880/// {
881///   "_meta": {
882///     "progressToken": "token-123"
883///   },
884///   "name": "my-tool",
885///   "arguments": {}
886/// }
887/// ```
888fn extract_progress_token(params: Option<&serde_json::Value>) -> Option<ProgressToken> {
889    params?
890        .get("_meta")?
891        .get("progressToken")
892        .and_then(|v| serde_json::from_value(v.clone()).ok())
893}
894
#[cfg(test)]
mod tests {
    //! Unit tests for server state bookkeeping, runtime defaults, and
    //! progress-token extraction from request parameters.

    use super::*;

    #[test]
    fn test_server_state_initialization() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(!state.is_initialized());

        state.set_initialized();
        assert!(state.is_initialized());
    }

    #[test]
    fn test_cancellation_management() {
        let state = ServerState::new(ServerCapabilities::default());
        let token = CancellationToken::new();
        let other = CancellationToken::new();

        state.register_cancellation("req-1", token.clone());
        state.register_cancellation("req-2", other.clone());
        assert!(!token.is_cancelled());
        assert!(!other.is_cancelled());

        state.cancel_request("req-1");
        assert!(token.is_cancelled());
        // Cancelling one request must not affect a different in-flight request.
        assert!(!other.is_cancelled());

        state.remove_cancellation("req-1");
        state.remove_cancellation("req-2");
    }

    #[test]
    fn test_runtime_config_default() {
        let config = RuntimeConfig::default();
        assert!(config.auto_initialized);
        assert_eq!(config.max_concurrent_requests, 100);
    }

    #[test]
    fn test_extract_progress_token_string() {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": "my-token-123"
            },
            "name": "test-tool"
        });
        // A single `assert_eq!` against `Some(..)` reports the actual value on
        // failure, unlike `is_some()` followed by `unwrap()`.
        assert_eq!(
            extract_progress_token(Some(&params)),
            Some(ProgressToken::String("my-token-123".to_string()))
        );
    }

    #[test]
    fn test_extract_progress_token_number() {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": 42
            },
            "arguments": {}
        });
        assert_eq!(
            extract_progress_token(Some(&params)),
            Some(ProgressToken::Number(42))
        );
    }

    #[test]
    fn test_extract_progress_token_missing_meta() {
        let params = serde_json::json!({
            "name": "test-tool",
            "arguments": {}
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_missing_token() {
        let params = serde_json::json!({
            "_meta": {},
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_none_params() {
        let token = extract_progress_token(None);
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_non_object_params() {
        // `Value::get` returns `None` on non-objects, so malformed params are ignored.
        let params = serde_json::json!(["_meta", "progressToken"]);
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_non_object_meta() {
        let params = serde_json::json!({
            "_meta": "not-an-object",
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_invalid_token_type() {
        // A token that is neither a string nor a number cannot become a
        // `ProgressToken`; the deserialization error is swallowed into `None`.
        let params = serde_json::json!({
            "_meta": {
                "progressToken": { "nested": true }
            }
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }
}