mcpkit_server/
server.rs

1//! Server runtime for MCP servers.
2//!
3//! This module provides the runtime that executes an MCP server over
4//! a transport, handling message routing, request correlation, and
5//! the connection lifecycle.
6//!
7//! # Overview
8//!
9//! The server runtime:
10//! 1. Accepts a transport for communication
11//! 2. Handles the initialize/initialized handshake
12//! 3. Routes incoming requests to the appropriate handlers
13//! 4. Manages the connection lifecycle
14//!
15//! # Example
16//!
17//! ```rust
18//! use mcpkit_server::{ServerBuilder, ServerHandler, ServerState};
19//! use mcpkit_core::capability::{ServerInfo, ServerCapabilities};
20//!
21//! struct MyHandler;
22//! impl ServerHandler for MyHandler {
23//!     fn server_info(&self) -> ServerInfo {
24//!         ServerInfo::new("my-server", "1.0.0")
25//!     }
26//! }
27//!
28//! // Build a server and create server state
29//! let server = ServerBuilder::new(MyHandler).build();
30//! let state = ServerState::new(server.capabilities().clone());
31//!
32//! assert!(!state.is_initialized());
33//! ```
34
35use crate::builder::{NotRegistered, Registered, Server};
36use crate::context::{CancellationToken, Context, Peer};
37use crate::handler::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
38use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
39use mcpkit_core::error::McpError;
40use mcpkit_core::protocol::{Message, Notification, ProgressToken, Request, Response};
41use mcpkit_core::protocol_version::ProtocolVersion;
42use mcpkit_core::types::CallToolResult;
43use mcpkit_transport::Transport;
44use std::collections::HashMap;
45use std::sync::Arc;
46use std::sync::RwLock;
47use std::sync::atomic::{AtomicBool, Ordering};
48
49/// State for a running server.
50pub struct ServerState {
51    /// Client capabilities negotiated during initialization.
52    pub client_caps: RwLock<ClientCapabilities>,
53    /// Server capabilities advertised during initialization.
54    pub server_caps: ServerCapabilities,
55    /// Whether the server has been initialized.
56    pub initialized: AtomicBool,
57    /// Active cancellation tokens by request ID.
58    pub cancellations: RwLock<HashMap<String, CancellationToken>>,
59    /// The protocol version negotiated during initialization.
60    ///
61    /// This is stored as a `ProtocolVersion` enum for type-safe feature detection.
62    /// Use methods like `protocol_version().supports_tasks()` to check capabilities.
63    pub negotiated_version: RwLock<Option<ProtocolVersion>>,
64}
65
66impl ServerState {
67    /// Create a new server state.
68    #[must_use]
69    pub fn new(server_caps: ServerCapabilities) -> Self {
70        Self {
71            client_caps: RwLock::new(ClientCapabilities::default()),
72            server_caps,
73            initialized: AtomicBool::new(false),
74            cancellations: RwLock::new(HashMap::new()),
75            negotiated_version: RwLock::new(None),
76        }
77    }
78
79    /// Get the negotiated protocol version.
80    ///
81    /// Returns `None` if not yet initialized.
82    ///
83    /// # Example
84    ///
85    /// ```rust,ignore
86    /// if let Some(version) = state.protocol_version() {
87    ///     if version.supports_tasks() {
88    ///         // Tasks are available in this session
89    ///     }
90    /// }
91    /// ```
92    pub fn protocol_version(&self) -> Option<ProtocolVersion> {
93        self.negotiated_version.read().ok().and_then(|guard| *guard)
94    }
95
96    /// Set the negotiated protocol version.
97    ///
98    /// Silently fails if the lock is poisoned.
99    pub fn set_protocol_version(&self, version: ProtocolVersion) {
100        if let Ok(mut guard) = self.negotiated_version.write() {
101            *guard = Some(version);
102        }
103    }
104
105    /// Get a snapshot of client capabilities.
106    ///
107    /// Returns default capabilities if the lock is poisoned.
108    pub fn client_caps(&self) -> ClientCapabilities {
109        self.client_caps
110            .read()
111            .map(|guard| guard.clone())
112            .unwrap_or_default()
113    }
114
115    /// Update client capabilities.
116    ///
117    /// Silently fails if the lock is poisoned.
118    pub fn set_client_caps(&self, caps: ClientCapabilities) {
119        if let Ok(mut guard) = self.client_caps.write() {
120            *guard = caps;
121        }
122    }
123
124    /// Check if the server is initialized.
125    pub fn is_initialized(&self) -> bool {
126        self.initialized.load(Ordering::Acquire)
127    }
128
129    /// Mark the server as initialized.
130    pub fn set_initialized(&self) {
131        self.initialized.store(true, Ordering::Release);
132    }
133
134    /// Register a cancellation token for a request.
135    pub fn register_cancellation(&self, request_id: &str, token: CancellationToken) {
136        if let Ok(mut cancellations) = self.cancellations.write() {
137            cancellations.insert(request_id.to_string(), token);
138        }
139    }
140
141    /// Cancel a request by ID.
142    pub fn cancel_request(&self, request_id: &str) {
143        if let Ok(cancellations) = self.cancellations.read() {
144            if let Some(token) = cancellations.get(request_id) {
145                token.cancel();
146            }
147        }
148    }
149
150    /// Remove a cancellation token after request completion.
151    pub fn remove_cancellation(&self, request_id: &str) {
152        if let Ok(mut cancellations) = self.cancellations.write() {
153            cancellations.remove(request_id);
154        }
155    }
156}
157
158/// A peer implementation that sends notifications over a transport.
159pub struct TransportPeer<T: Transport> {
160    transport: Arc<T>,
161}
162
163impl<T: Transport> TransportPeer<T> {
164    /// Create a new transport peer.
165    pub const fn new(transport: Arc<T>) -> Self {
166        Self { transport }
167    }
168}
169
170impl<T: Transport + 'static> Peer for TransportPeer<T>
171where
172    T::Error: Into<McpError>,
173{
174    fn notify(
175        &self,
176        notification: Notification,
177    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), McpError>> + Send + '_>>
178    {
179        let transport = self.transport.clone();
180        Box::pin(async move {
181            transport
182                .send(Message::Notification(notification))
183                .await
184                .map_err(std::convert::Into::into)
185        })
186    }
187}
188
/// Tunable options for the server runtime.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Whether the initialized notification should be sent automatically.
    pub auto_initialized: bool,
    /// Upper bound on the number of requests processed concurrently.
    pub max_concurrent_requests: usize,
}
197
198impl Default for RuntimeConfig {
199    fn default() -> Self {
200        Self {
201            auto_initialized: true,
202            max_concurrent_requests: 100,
203        }
204    }
205}
206
207/// Server runtime that handles the message loop.
208///
209/// This runtime manages the connection lifecycle, routes requests to
210/// handlers, and coordinates response delivery.
211pub struct ServerRuntime<S, Tr>
212where
213    Tr: Transport,
214{
215    server: S,
216    transport: Arc<Tr>,
217    state: Arc<ServerState>,
218    /// Runtime configuration (request timeouts, etc.) - will be used by advanced features.
219    #[allow(dead_code)]
220    config: RuntimeConfig,
221}
222
223impl<S, Tr> ServerRuntime<S, Tr>
224where
225    S: RequestRouter + Send + Sync,
226    Tr: Transport + 'static,
227    Tr::Error: Into<McpError>,
228{
229    /// Get the server state.
230    pub const fn state(&self) -> &Arc<ServerState> {
231        &self.state
232    }
233
234    /// Run the server message loop.
235    ///
236    /// This method runs until the connection is closed or an error occurs.
237    pub async fn run(&self) -> Result<(), McpError> {
238        loop {
239            match self.transport.recv().await {
240                Ok(Some(message)) => {
241                    if let Err(e) = self.handle_message(message).await {
242                        tracing::error!(error = %e, "Error handling message");
243                    }
244                }
245                Ok(None) => {
246                    // Connection closed cleanly
247                    tracing::info!("Connection closed");
248                    break;
249                }
250                Err(e) => {
251                    let err: McpError = e.into();
252                    tracing::error!(error = %err, "Transport error");
253                    return Err(err);
254                }
255            }
256        }
257
258        Ok(())
259    }
260
261    /// Handle a single message.
262    async fn handle_message(&self, message: Message) -> Result<(), McpError> {
263        match message {
264            Message::Request(request) => self.handle_request(request).await,
265            Message::Notification(notification) => self.handle_notification(notification).await,
266            Message::Response(_) => {
267                // Servers don't typically receive responses
268                tracing::warn!("Received unexpected response message");
269                Ok(())
270            }
271        }
272    }
273
274    /// Handle a request.
275    async fn handle_request(&self, request: Request) -> Result<(), McpError> {
276        let method = request.method.to_string();
277        let id = request.id.clone();
278
279        tracing::debug!(method = %method, id = %id, "Handling request");
280
281        let response = match method.as_str() {
282            "initialize" => self.handle_initialize(&request).await,
283            _ if !self.state.is_initialized() => {
284                Err(McpError::invalid_request("Server not initialized"))
285            }
286            _ => self.route_request(&request).await,
287        };
288
289        // Send response
290        let response_msg = match response {
291            Ok(result) => Response::success(id, result),
292            Err(e) => Response::error(id, e.into()),
293        };
294
295        self.transport
296            .send(Message::Response(response_msg))
297            .await
298            .map_err(std::convert::Into::into)
299    }
300
301    /// Handle the initialize request.
302    ///
303    /// This performs protocol version negotiation according to the MCP specification:
304    /// 1. Client sends its preferred protocol version
305    /// 2. Server responds with the same version if supported, or its preferred version
306    /// 3. Client must support the returned version or disconnect
307    async fn handle_initialize(&self, request: &Request) -> Result<serde_json::Value, McpError> {
308        if self.state.is_initialized() {
309            return Err(McpError::invalid_request("Already initialized"));
310        }
311
312        // Parse initialize params
313        let params = request
314            .params
315            .as_ref()
316            .ok_or_else(|| McpError::invalid_params("initialize", "missing params"))?;
317
318        // Extract and negotiate protocol version using type-safe enum
319        let requested_version_str = params
320            .get("protocolVersion")
321            .and_then(|v| v.as_str())
322            .unwrap_or("");
323
324        // Negotiate using the ProtocolVersion enum for type safety
325        let negotiated_version =
326            ProtocolVersion::negotiate(requested_version_str, ProtocolVersion::ALL)
327                .unwrap_or(ProtocolVersion::LATEST);
328
329        // Log version negotiation details for debugging
330        if requested_version_str == negotiated_version.as_str() {
331            tracing::debug!(
332                version = %negotiated_version,
333                "Protocol version negotiated successfully"
334            );
335        } else {
336            tracing::info!(
337                requested = %requested_version_str,
338                negotiated = %negotiated_version,
339                supported = ?ProtocolVersion::ALL.iter().map(ProtocolVersion::as_str).collect::<Vec<_>>(),
340                "Protocol version negotiation: client requested different version"
341            );
342        }
343
344        // Store the negotiated version (type-safe enum)
345        self.state.set_protocol_version(negotiated_version);
346
347        // Extract client info and capabilities
348        if let Some(caps) = params.get("capabilities") {
349            if let Ok(client_caps) = serde_json::from_value::<ClientCapabilities>(caps.clone()) {
350                self.state.set_client_caps(client_caps);
351            }
352        }
353
354        // Build response with negotiated version (serialized to string by serde)
355        let result = serde_json::json!({
356            "protocolVersion": negotiated_version.as_str(),
357            "serverInfo": self.server.server_info(),
358            "capabilities": self.state.server_caps
359        });
360
361        self.state.set_initialized();
362
363        Ok(result)
364    }
365
366    /// Route a request to the appropriate handler.
367    async fn route_request(&self, request: &Request) -> Result<serde_json::Value, McpError> {
368        let method = request.method.as_ref();
369        let params = request.params.as_ref();
370
371        // Extract progress token from params._meta.progressToken if present
372        let progress_token = extract_progress_token(params);
373
374        // Create context for the handler
375        let peer = TransportPeer::new(self.transport.clone());
376        let client_caps = self.state.client_caps();
377        let protocol_version = self
378            .state
379            .protocol_version()
380            .unwrap_or(ProtocolVersion::LATEST);
381        let ctx = Context::new(
382            &request.id,
383            progress_token.as_ref(),
384            &client_caps,
385            &self.state.server_caps,
386            protocol_version,
387            &peer,
388        );
389
390        // Delegate to the router
391        self.server.route(method, params, &ctx).await
392    }
393
394    /// Handle a notification.
395    async fn handle_notification(&self, notification: Notification) -> Result<(), McpError> {
396        let method = notification.method.as_ref();
397
398        tracing::debug!(method = %method, "Handling notification");
399
400        match method {
401            "notifications/initialized" => {
402                tracing::info!("Client sent initialized notification");
403                Ok(())
404            }
405            "notifications/cancelled" => {
406                if let Some(params) = &notification.params {
407                    if let Some(request_id) = params.get("requestId").and_then(|v| v.as_str()) {
408                        self.state.cancel_request(request_id);
409                    }
410                }
411                Ok(())
412            }
413            _ => {
414                tracing::debug!(method = %method, "Ignoring unknown notification");
415                Ok(())
416            }
417        }
418    }
419}
420
421// Constructor implementations for ServerRuntime with different server types
422impl<H, T, R, P, K, Tr> ServerRuntime<Server<H, T, R, P, K>, Tr>
423where
424    H: ServerHandler + Send + Sync,
425    T: Send + Sync,
426    R: Send + Sync,
427    P: Send + Sync,
428    K: Send + Sync,
429    Tr: Transport + 'static,
430    Tr::Error: Into<McpError>,
431{
432    /// Create a new server runtime.
433    pub fn new(server: Server<H, T, R, P, K>, transport: Tr) -> Self {
434        let caps = server.capabilities().clone();
435        Self {
436            server,
437            transport: Arc::new(transport),
438            state: Arc::new(ServerState::new(caps)),
439            config: RuntimeConfig::default(),
440        }
441    }
442
443    /// Create a new server runtime with custom configuration.
444    pub fn with_config(
445        server: Server<H, T, R, P, K>,
446        transport: Tr,
447        config: RuntimeConfig,
448    ) -> Self {
449        let caps = server.capabilities().clone();
450        Self {
451            server,
452            transport: Arc::new(transport),
453            state: Arc::new(ServerState::new(caps)),
454            config,
455        }
456    }
457}
458
459/// Trait for routing requests to handlers.
460///
461/// This trait is implemented by Server with different bounds depending on
462/// which handlers are registered.
463#[allow(async_fn_in_trait)]
464pub trait RequestRouter: Send + Sync {
465    /// Get the server info.
466    fn server_info(&self) -> mcpkit_core::capability::ServerInfo;
467
468    /// Route a request and return the result.
469    async fn route(
470        &self,
471        method: &str,
472        params: Option<&serde_json::Value>,
473        ctx: &Context<'_>,
474    ) -> Result<serde_json::Value, McpError>;
475}
476
477/// Extension methods for Server to run with a transport.
478impl<H, T, R, P, K> Server<H, T, R, P, K>
479where
480    H: ServerHandler + Send + Sync + 'static,
481    T: Send + Sync + 'static,
482    R: Send + Sync + 'static,
483    P: Send + Sync + 'static,
484    K: Send + Sync + 'static,
485    Self: RequestRouter,
486{
487    /// Run this server over the given transport.
488    pub async fn serve<Tr>(self, transport: Tr) -> Result<(), McpError>
489    where
490        Tr: Transport + 'static,
491        Tr::Error: Into<McpError>,
492    {
493        let runtime = ServerRuntime::new(self, transport);
494        runtime.run().await
495    }
496}
497
498// ============================================================================
499// RequestRouter implementations via macro
500// ============================================================================
501
502// Internal routing functions to reduce code duplication.
503// Each function handles a specific handler type's methods.
504
505async fn route_tools<TH: ToolHandler + Send + Sync>(
506    handler: &TH,
507    method: &str,
508    params: Option<&serde_json::Value>,
509    ctx: &Context<'_>,
510) -> Option<Result<serde_json::Value, McpError>> {
511    match method {
512        "tools/list" => {
513            tracing::debug!("Listing available tools");
514            let result = handler.list_tools(ctx).await;
515            match &result {
516                Ok(tools) => tracing::debug!(count = tools.len(), "Listed tools"),
517                Err(e) => tracing::warn!(error = %e, "Failed to list tools"),
518            }
519            Some(result.map(|tools| serde_json::json!({ "tools": tools })))
520        }
521        "tools/call" => {
522            let result = async {
523                let params = params.ok_or_else(|| {
524                    McpError::invalid_params("tools/call", "missing params")
525                })?;
526                let name = params.get("name")
527                    .and_then(|v| v.as_str())
528                    .ok_or_else(|| McpError::invalid_params("tools/call", "missing tool name"))?;
529                let args = params.get("arguments")
530                    .cloned()
531                    .unwrap_or_else(|| serde_json::json!({}));
532
533                tracing::info!(tool = %name, "Calling tool");
534                let start = std::time::Instant::now();
535                let output = handler.call_tool(name, args, ctx).await;
536                let duration = start.elapsed();
537
538                match &output {
539                    Ok(_) => tracing::info!(tool = %name, duration_ms = duration.as_millis(), "Tool call completed"),
540                    Err(e) => tracing::warn!(tool = %name, duration_ms = duration.as_millis(), error = %e, "Tool call failed"),
541                }
542
543                let output = output?;
544                let result: CallToolResult = output.into();
545                Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
546            }.await;
547            Some(result)
548        }
549        _ => None,
550    }
551}
552
553async fn route_resources<RH: ResourceHandler + Send + Sync>(
554    handler: &RH,
555    method: &str,
556    params: Option<&serde_json::Value>,
557    ctx: &Context<'_>,
558) -> Option<Result<serde_json::Value, McpError>> {
559    match method {
560        "resources/list" => {
561            tracing::debug!("Listing available resources");
562            let result = handler.list_resources(ctx).await;
563            match &result {
564                Ok(resources) => tracing::debug!(count = resources.len(), "Listed resources"),
565                Err(e) => tracing::warn!(error = %e, "Failed to list resources"),
566            }
567            Some(result.map(|resources| serde_json::json!({ "resources": resources })))
568        }
569        "resources/templates/list" => {
570            tracing::debug!("Listing available resource templates");
571            let result = handler.list_resource_templates(ctx).await;
572            match &result {
573                Ok(templates) => {
574                    tracing::debug!(count = templates.len(), "Listed resource templates");
575                }
576                Err(e) => tracing::warn!(error = %e, "Failed to list resource templates"),
577            }
578            Some(result.map(|templates| serde_json::json!({ "resourceTemplates": templates })))
579        }
580        "resources/read" => {
581            let result = async {
582                let params = params.ok_or_else(|| {
583                    McpError::invalid_params("resources/read", "missing params")
584                })?;
585                let uri = params.get("uri")
586                    .and_then(|v| v.as_str())
587                    .ok_or_else(|| McpError::invalid_params("resources/read", "missing uri"))?;
588
589                tracing::info!(uri = %uri, "Reading resource");
590                let start = std::time::Instant::now();
591                let contents = handler.read_resource(uri, ctx).await;
592                let duration = start.elapsed();
593
594                match &contents {
595                    Ok(_) => tracing::info!(uri = %uri, duration_ms = duration.as_millis(), "Resource read completed"),
596                    Err(e) => tracing::warn!(uri = %uri, duration_ms = duration.as_millis(), error = %e, "Resource read failed"),
597                }
598
599                let contents = contents?;
600                Ok(serde_json::json!({ "contents": contents }))
601            }.await;
602            Some(result)
603        }
604        _ => None,
605    }
606}
607
608async fn route_prompts<PH: PromptHandler + Send + Sync>(
609    handler: &PH,
610    method: &str,
611    params: Option<&serde_json::Value>,
612    ctx: &Context<'_>,
613) -> Option<Result<serde_json::Value, McpError>> {
614    match method {
615        "prompts/list" => {
616            tracing::debug!("Listing available prompts");
617            let result = handler.list_prompts(ctx).await;
618            match &result {
619                Ok(prompts) => tracing::debug!(count = prompts.len(), "Listed prompts"),
620                Err(e) => tracing::warn!(error = %e, "Failed to list prompts"),
621            }
622            Some(result.map(|prompts| serde_json::json!({ "prompts": prompts })))
623        }
624        "prompts/get" => {
625            let result = async {
626                let params = params.ok_or_else(|| {
627                    McpError::invalid_params("prompts/get", "missing params")
628                })?;
629                let name = params.get("name")
630                    .and_then(|v| v.as_str())
631                    .ok_or_else(|| McpError::invalid_params("prompts/get", "missing prompt name"))?;
632                let args = params.get("arguments")
633                    .and_then(|v| v.as_object())
634                    .cloned();
635
636                tracing::info!(prompt = %name, "Getting prompt");
637                let start = std::time::Instant::now();
638                let prompt_result = handler.get_prompt(name, args, ctx).await;
639                let duration = start.elapsed();
640
641                match &prompt_result {
642                    Ok(_) => tracing::info!(prompt = %name, duration_ms = duration.as_millis(), "Prompt retrieval completed"),
643                    Err(e) => tracing::warn!(prompt = %name, duration_ms = duration.as_millis(), error = %e, "Prompt retrieval failed"),
644                }
645
646                let result = prompt_result?;
647                Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
648            }.await;
649            Some(result)
650        }
651        _ => None,
652    }
653}
654
655/// Macro to generate `RequestRouter` implementations for all handler combinations.
656///
657/// This macro reduces code duplication by generating all 2^3 = 8 combinations
658/// of tool/resource/prompt handler registration states.
659macro_rules! impl_request_router {
660    // Base case: no handlers
661    (base; $($bounds:tt)*) => {
662        impl<H $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, NotRegistered, NotRegistered>
663        where
664            H: ServerHandler + Send + Sync,
665        {
666            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
667                self.handler().server_info()
668            }
669
670            async fn route(
671                &self,
672                method: &str,
673                _params: Option<&serde_json::Value>,
674                _ctx: &Context<'_>,
675            ) -> Result<serde_json::Value, McpError> {
676                match method {
677                    "ping" => Ok(serde_json::json!({})),
678                    _ => Err(McpError::method_not_found(method)),
679                }
680            }
681        }
682    };
683
684    // Tools only
685    (tools; $($bounds:tt)*) => {
686        impl<H, TH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, NotRegistered, NotRegistered>
687        where
688            H: ServerHandler + Send + Sync,
689            TH: ToolHandler + Send + Sync,
690        {
691            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
692                self.handler().server_info()
693            }
694
695            async fn route(
696                &self,
697                method: &str,
698                params: Option<&serde_json::Value>,
699                ctx: &Context<'_>,
700            ) -> Result<serde_json::Value, McpError> {
701                if method == "ping" {
702                    return Ok(serde_json::json!({}));
703                }
704                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
705                    return result;
706                }
707                Err(McpError::method_not_found(method))
708            }
709        }
710    };
711
712    // Resources only
713    (resources; $($bounds:tt)*) => {
714        impl<H, RH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, NotRegistered, NotRegistered>
715        where
716            H: ServerHandler + Send + Sync,
717            RH: ResourceHandler + Send + Sync,
718        {
719            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
720                self.handler().server_info()
721            }
722
723            async fn route(
724                &self,
725                method: &str,
726                params: Option<&serde_json::Value>,
727                ctx: &Context<'_>,
728            ) -> Result<serde_json::Value, McpError> {
729                if method == "ping" {
730                    return Ok(serde_json::json!({}));
731                }
732                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
733                    return result;
734                }
735                Err(McpError::method_not_found(method))
736            }
737        }
738    };
739
740    // Prompts only
741    (prompts; $($bounds:tt)*) => {
742        impl<H, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, Registered<PH>, NotRegistered>
743        where
744            H: ServerHandler + Send + Sync,
745            PH: PromptHandler + Send + Sync,
746        {
747            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
748                self.handler().server_info()
749            }
750
751            async fn route(
752                &self,
753                method: &str,
754                params: Option<&serde_json::Value>,
755                ctx: &Context<'_>,
756            ) -> Result<serde_json::Value, McpError> {
757                if method == "ping" {
758                    return Ok(serde_json::json!({}));
759                }
760                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
761                    return result;
762                }
763                Err(McpError::method_not_found(method))
764            }
765        }
766    };
767
768    // Tools + Resources
769    (tools_resources; $($bounds:tt)*) => {
770        impl<H, TH, RH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, NotRegistered, NotRegistered>
771        where
772            H: ServerHandler + Send + Sync,
773            TH: ToolHandler + Send + Sync,
774            RH: ResourceHandler + Send + Sync,
775        {
776            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
777                self.handler().server_info()
778            }
779
780            async fn route(
781                &self,
782                method: &str,
783                params: Option<&serde_json::Value>,
784                ctx: &Context<'_>,
785            ) -> Result<serde_json::Value, McpError> {
786                if method == "ping" {
787                    return Ok(serde_json::json!({}));
788                }
789                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
790                    return result;
791                }
792                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
793                    return result;
794                }
795                Err(McpError::method_not_found(method))
796            }
797        }
798    };
799
800    // Tools + Prompts
801    (tools_prompts; $($bounds:tt)*) => {
802        impl<H, TH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, Registered<PH>, NotRegistered>
803        where
804            H: ServerHandler + Send + Sync,
805            TH: ToolHandler + Send + Sync,
806            PH: PromptHandler + Send + Sync,
807        {
808            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
809                self.handler().server_info()
810            }
811
812            async fn route(
813                &self,
814                method: &str,
815                params: Option<&serde_json::Value>,
816                ctx: &Context<'_>,
817            ) -> Result<serde_json::Value, McpError> {
818                if method == "ping" {
819                    return Ok(serde_json::json!({}));
820                }
821                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
822                    return result;
823                }
824                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
825                    return result;
826                }
827                Err(McpError::method_not_found(method))
828            }
829        }
830    };
831
832    // Resources + Prompts
833    (resources_prompts; $($bounds:tt)*) => {
834        impl<H, RH, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, Registered<PH>, NotRegistered>
835        where
836            H: ServerHandler + Send + Sync,
837            RH: ResourceHandler + Send + Sync,
838            PH: PromptHandler + Send + Sync,
839        {
840            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
841                self.handler().server_info()
842            }
843
844            async fn route(
845                &self,
846                method: &str,
847                params: Option<&serde_json::Value>,
848                ctx: &Context<'_>,
849            ) -> Result<serde_json::Value, McpError> {
850                if method == "ping" {
851                    return Ok(serde_json::json!({}));
852                }
853                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
854                    return result;
855                }
856                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
857                    return result;
858                }
859                Err(McpError::method_not_found(method))
860            }
861        }
862    };
863
864    // Tools + Resources + Prompts
865    (tools_resources_prompts; $($bounds:tt)*) => {
866        impl<H, TH, RH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, Registered<PH>, NotRegistered>
867        where
868            H: ServerHandler + Send + Sync,
869            TH: ToolHandler + Send + Sync,
870            RH: ResourceHandler + Send + Sync,
871            PH: PromptHandler + Send + Sync,
872        {
873            fn server_info(&self) -> mcpkit_core::capability::ServerInfo {
874                self.handler().server_info()
875            }
876
877            async fn route(
878                &self,
879                method: &str,
880                params: Option<&serde_json::Value>,
881                ctx: &Context<'_>,
882            ) -> Result<serde_json::Value, McpError> {
883                if method == "ping" {
884                    return Ok(serde_json::json!({}));
885                }
886                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
887                    return result;
888                }
889                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
890                    return result;
891                }
892                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
893                    return result;
894                }
895                Err(McpError::method_not_found(method))
896            }
897        }
898    };
899}
900
// Generate all RequestRouter implementations.
//
// Each call picks the `impl_request_router!` arm whose leading identifier it
// names; the identifier lists which handler kinds are registered on the
// `Server` type the impl is written for (e.g. `tools_prompts` targets a server
// with tool and prompt handlers). The token list after the `;` is left empty
// in every call, so nothing extra is spliced into the generated impl's
// generic parameter list.
impl_request_router!(base;);
impl_request_router!(tools;);
impl_request_router!(resources;);
impl_request_router!(prompts;);
impl_request_router!(tools_resources;);
impl_request_router!(tools_prompts;);
impl_request_router!(resources_prompts;);
impl_request_router!(tools_resources_prompts;);
910
911// ============================================================================
912// Helper functions
913// ============================================================================
914
915/// Extract a progress token from request parameters.
916///
917/// Per the MCP specification, progress tokens are sent in the `_meta.progressToken`
918/// field of request parameters. This function attempts to extract and parse that
919/// field into a `ProgressToken`.
920///
921/// # Example JSON structure
922/// ```json
923/// {
924///   "_meta": {
925///     "progressToken": "token-123"
926///   },
927///   "name": "my-tool",
928///   "arguments": {}
929/// }
930/// ```
931fn extract_progress_token(params: Option<&serde_json::Value>) -> Option<ProgressToken> {
932    params?
933        .get("_meta")?
934        .get("progressToken")
935        .and_then(|v| serde_json::from_value(v.clone()).ok())
936}
937
#[cfg(test)]
mod tests {
    use super::*;

    /// A new `ServerState` is uninitialized until `set_initialized` is called.
    #[test]
    fn test_server_state_initialization() {
        let server_state = ServerState::new(ServerCapabilities::default());
        assert!(!server_state.is_initialized());

        server_state.set_initialized();
        assert!(server_state.is_initialized());
    }

    /// Cancelling a registered request flips the token that was registered
    /// for it; removing the registration afterwards must not panic.
    #[test]
    fn test_cancellation_management() {
        let server_state = ServerState::new(ServerCapabilities::default());
        let cancel = CancellationToken::new();

        server_state.register_cancellation("req-1", cancel.clone());
        assert!(!cancel.is_cancelled());

        server_state.cancel_request("req-1");
        assert!(cancel.is_cancelled());

        server_state.remove_cancellation("req-1");
    }

    /// Pins the default values of `RuntimeConfig`.
    #[test]
    fn test_runtime_config_default() {
        let defaults = RuntimeConfig::default();
        assert!(defaults.auto_initialized);
        assert_eq!(defaults.max_concurrent_requests, 100);
    }

    /// A string `progressToken` is extracted as `ProgressToken::String`.
    #[test]
    fn test_extract_progress_token_string() -> Result<(), Box<dyn std::error::Error>> {
        let request_params = serde_json::json!({
            "_meta": {
                "progressToken": "my-token-123"
            },
            "name": "test-tool"
        });

        let extracted =
            extract_progress_token(Some(&request_params)).ok_or("Token not found")?;
        assert_eq!(extracted, ProgressToken::String(String::from("my-token-123")));

        Ok(())
    }

    /// A numeric `progressToken` is extracted as `ProgressToken::Number`.
    #[test]
    fn test_extract_progress_token_number() -> Result<(), Box<dyn std::error::Error>> {
        let request_params = serde_json::json!({
            "_meta": {
                "progressToken": 42
            },
            "arguments": {}
        });

        let extracted =
            extract_progress_token(Some(&request_params)).ok_or("Token not found")?;
        assert_eq!(extracted, ProgressToken::Number(42));

        Ok(())
    }

    /// Parameters without a `_meta` object carry no token.
    #[test]
    fn test_extract_progress_token_missing_meta() {
        let request_params = serde_json::json!({
            "name": "test-tool",
            "arguments": {}
        });
        assert!(extract_progress_token(Some(&request_params)).is_none());
    }

    /// A `_meta` object without `progressToken` carries no token.
    #[test]
    fn test_extract_progress_token_missing_token() {
        let request_params = serde_json::json!({
            "_meta": {},
            "name": "test-tool"
        });
        assert!(extract_progress_token(Some(&request_params)).is_none());
    }

    /// Absent parameters carry no token.
    #[test]
    fn test_extract_progress_token_none_params() {
        assert!(extract_progress_token(None).is_none());
    }
}