mcpkit_server/
server.rs

1//! Server runtime for MCP servers.
2//!
3//! This module provides the runtime that executes an MCP server over
4//! a transport, handling message routing, request correlation, and
5//! the connection lifecycle.
6//!
7//! # Overview
8//!
9//! The server runtime:
10//! 1. Accepts a transport for communication
11//! 2. Handles the initialize/initialized handshake
12//! 3. Routes incoming requests to the appropriate handlers
13//! 4. Manages the connection lifecycle
14//!
15//! # Example
16//!
17//! ```rust
18//! use mcpkit_server::{ServerBuilder, ServerHandler, ServerState};
19//! use mcpkit_core::capability::{ServerInfo, ServerCapabilities};
20//!
21//! struct MyHandler;
22//! impl ServerHandler for MyHandler {
23//!     fn server_info(&self) -> ServerInfo {
24//!         ServerInfo::new("my-server", "1.0.0")
25//!     }
26//! }
27//!
28//! // Build a server and create server state
29//! let server = ServerBuilder::new(MyHandler).build();
30//! let state = ServerState::new(server.capabilities().clone());
31//!
32//! assert!(!state.is_initialized());
33//! ```
34
35use crate::builder::{NotRegistered, Registered, Server};
36use crate::context::{CancellationToken, Context, Peer};
37use crate::handler::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
38use mcpkit_core::capability::{
39    negotiate_version, ClientCapabilities, ServerCapabilities, SUPPORTED_PROTOCOL_VERSIONS,
40};
41use mcpkit_core::error::McpError;
42use mcpkit_core::protocol::{Message, Notification, ProgressToken, Request, Response};
43use mcpkit_core::types::CallToolResult;
44use mcpkit_transport::Transport;
45use std::collections::HashMap;
46use std::sync::atomic::{AtomicBool, Ordering};
47use std::sync::Arc;
48use std::sync::RwLock;
49
50/// State for a running server.
51pub struct ServerState {
52    /// Client capabilities negotiated during initialization.
53    pub client_caps: RwLock<ClientCapabilities>,
54    /// Server capabilities advertised during initialization.
55    pub server_caps: ServerCapabilities,
56    /// Whether the server has been initialized.
57    pub initialized: AtomicBool,
58    /// Active cancellation tokens by request ID.
59    pub cancellations: RwLock<HashMap<String, CancellationToken>>,
60    /// The protocol version negotiated during initialization.
61    pub negotiated_version: RwLock<Option<String>>,
62}
63
64impl ServerState {
65    /// Create a new server state.
66    #[must_use]
67    pub fn new(server_caps: ServerCapabilities) -> Self {
68        Self {
69            client_caps: RwLock::new(ClientCapabilities::default()),
70            server_caps,
71            initialized: AtomicBool::new(false),
72            cancellations: RwLock::new(HashMap::new()),
73            negotiated_version: RwLock::new(None),
74        }
75    }
76
77    /// Get the negotiated protocol version.
78    ///
79    /// Returns `None` if not yet initialized.
80    pub fn protocol_version(&self) -> Option<String> {
81        self.negotiated_version
82            .read()
83            .ok()
84            .and_then(|guard| guard.clone())
85    }
86
87    /// Set the negotiated protocol version.
88    ///
89    /// Silently fails if the lock is poisoned.
90    pub fn set_protocol_version(&self, version: String) {
91        if let Ok(mut guard) = self.negotiated_version.write() {
92            *guard = Some(version);
93        }
94    }
95
96    /// Get a snapshot of client capabilities.
97    ///
98    /// Returns default capabilities if the lock is poisoned.
99    pub fn client_caps(&self) -> ClientCapabilities {
100        self.client_caps
101            .read()
102            .map(|guard| guard.clone())
103            .unwrap_or_default()
104    }
105
106    /// Update client capabilities.
107    ///
108    /// Silently fails if the lock is poisoned.
109    pub fn set_client_caps(&self, caps: ClientCapabilities) {
110        if let Ok(mut guard) = self.client_caps.write() {
111            *guard = caps;
112        }
113    }
114
115    /// Check if the server is initialized.
116    pub fn is_initialized(&self) -> bool {
117        self.initialized.load(Ordering::Acquire)
118    }
119
120    /// Mark the server as initialized.
121    pub fn set_initialized(&self) {
122        self.initialized.store(true, Ordering::Release);
123    }
124
125    /// Register a cancellation token for a request.
126    pub fn register_cancellation(&self, request_id: &str, token: CancellationToken) {
127        if let Ok(mut cancellations) = self.cancellations.write() {
128            cancellations.insert(request_id.to_string(), token);
129        }
130    }
131
132    /// Cancel a request by ID.
133    pub fn cancel_request(&self, request_id: &str) {
134        if let Ok(cancellations) = self.cancellations.read() {
135            if let Some(token) = cancellations.get(request_id) {
136                token.cancel();
137            }
138        }
139    }
140
141    /// Remove a cancellation token after request completion.
142    pub fn remove_cancellation(&self, request_id: &str) {
143        if let Ok(mut cancellations) = self.cancellations.write() {
144            cancellations.remove(request_id);
145        }
146    }
147}
148
149/// A peer implementation that sends notifications over a transport.
150pub struct TransportPeer<T: Transport> {
151    transport: Arc<T>,
152}
153
154impl<T: Transport> TransportPeer<T> {
155    /// Create a new transport peer.
156    pub const fn new(transport: Arc<T>) -> Self {
157        Self { transport }
158    }
159}
160
161impl<T: Transport + 'static> Peer for TransportPeer<T>
162where
163    T::Error: Into<McpError>,
164{
165    fn notify(
166        &self,
167        notification: Notification,
168    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), McpError>> + Send + '_>>
169    {
170        let transport = self.transport.clone();
171        Box::pin(async move {
172            transport
173                .send(Message::Notification(notification))
174                .await
175                .map_err(std::convert::Into::into)
176        })
177    }
178}
179
/// Server runtime configuration.
///
/// NOTE(review): within this module, neither field is read. `ServerRuntime`
/// stores its config under `#[allow(dead_code)]`, so both settings currently
/// have no effect and are kept for future use.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuntimeConfig {
    /// Whether to automatically send the initialized notification.
    /// Not currently consulted by the runtime.
    pub auto_initialized: bool,
    /// Intended upper bound on concurrently processed requests.
    ///
    /// Not currently enforced: `ServerRuntime::run` handles messages one at a time.
    pub max_concurrent_requests: usize,
}

impl Default for RuntimeConfig {
    /// Defaults: `auto_initialized = true`, `max_concurrent_requests = 100`.
    fn default() -> Self {
        Self {
            auto_initialized: true,
            max_concurrent_requests: 100,
        }
    }
}
197
198/// Server runtime that handles the message loop.
199///
200/// This runtime manages the connection lifecycle, routes requests to
201/// handlers, and coordinates response delivery.
202pub struct ServerRuntime<S, Tr>
203where
204    Tr: Transport,
205{
206    server: S,
207    transport: Arc<Tr>,
208    state: Arc<ServerState>,
209    /// Runtime configuration (request timeouts, etc.) - will be used by advanced features.
210    #[allow(dead_code)]
211    config: RuntimeConfig,
212}
213
214impl<S, Tr> ServerRuntime<S, Tr>
215where
216    S: RequestRouter + Send + Sync,
217    Tr: Transport + 'static,
218    Tr::Error: Into<McpError>,
219{
220    /// Get the server state.
221    pub const fn state(&self) -> &Arc<ServerState> {
222        &self.state
223    }
224
225    /// Run the server message loop.
226    ///
227    /// This method runs until the connection is closed or an error occurs.
228    pub async fn run(&self) -> Result<(), McpError> {
229        loop {
230            match self.transport.recv().await {
231                Ok(Some(message)) => {
232                    if let Err(e) = self.handle_message(message).await {
233                        tracing::error!(error = %e, "Error handling message");
234                    }
235                }
236                Ok(None) => {
237                    // Connection closed cleanly
238                    tracing::info!("Connection closed");
239                    break;
240                }
241                Err(e) => {
242                    let err: McpError = e.into();
243                    tracing::error!(error = %err, "Transport error");
244                    return Err(err);
245                }
246            }
247        }
248
249        Ok(())
250    }
251
252    /// Handle a single message.
253    async fn handle_message(&self, message: Message) -> Result<(), McpError> {
254        match message {
255            Message::Request(request) => self.handle_request(request).await,
256            Message::Notification(notification) => self.handle_notification(notification).await,
257            Message::Response(_) => {
258                // Servers don't typically receive responses
259                tracing::warn!("Received unexpected response message");
260                Ok(())
261            }
262        }
263    }
264
265    /// Handle a request.
266    async fn handle_request(&self, request: Request) -> Result<(), McpError> {
267        let method = request.method.to_string();
268        let id = request.id.clone();
269
270        tracing::debug!(method = %method, id = %id, "Handling request");
271
272        let response = match method.as_str() {
273            "initialize" => self.handle_initialize(&request).await,
274            _ if !self.state.is_initialized() => {
275                Err(McpError::invalid_request("Server not initialized"))
276            }
277            _ => self.route_request(&request).await,
278        };
279
280        // Send response
281        let response_msg = match response {
282            Ok(result) => Response::success(id, result),
283            Err(e) => Response::error(id, e.into()),
284        };
285
286        self.transport
287            .send(Message::Response(response_msg))
288            .await
289            .map_err(std::convert::Into::into)
290    }
291
292    /// Handle the initialize request.
293    ///
294    /// This performs protocol version negotiation according to the MCP specification:
295    /// 1. Client sends its preferred protocol version
296    /// 2. Server responds with the same version if supported, or its preferred version
297    /// 3. Client must support the returned version or disconnect
298    async fn handle_initialize(&self, request: &Request) -> Result<serde_json::Value, McpError> {
299        if self.state.is_initialized() {
300            return Err(McpError::invalid_request("Already initialized"));
301        }
302
303        // Parse initialize params
304        let params = request
305            .params
306            .as_ref()
307            .ok_or_else(|| McpError::invalid_params("initialize", "missing params"))?;
308
309        // Extract and negotiate protocol version
310        let requested_version = params
311            .get("protocolVersion")
312            .and_then(|v| v.as_str())
313            .unwrap_or("");
314
315        let negotiated_version = negotiate_version(requested_version);
316
317        // Log version negotiation details for debugging
318        if requested_version == negotiated_version {
319            tracing::debug!(
320                version = %negotiated_version,
321                "Protocol version negotiated successfully"
322            );
323        } else {
324            tracing::info!(
325                requested = %requested_version,
326                negotiated = %negotiated_version,
327                supported = ?SUPPORTED_PROTOCOL_VERSIONS,
328                "Protocol version negotiation: client requested unsupported version"
329            );
330        }
331
332        // Store the negotiated version
333        self.state
334            .set_protocol_version(negotiated_version.to_string());
335
336        // Extract client info and capabilities
337        if let Some(caps) = params.get("capabilities") {
338            if let Ok(client_caps) = serde_json::from_value::<ClientCapabilities>(caps.clone()) {
339                self.state.set_client_caps(client_caps);
340            }
341        }
342
343        // Build response with negotiated version
344        let result = serde_json::json!({
345            "protocolVersion": negotiated_version,
346            "serverInfo": {
347                "name": "mcp-server",
348                "version": "1.0.0"
349            },
350            "capabilities": self.state.server_caps
351        });
352
353        self.state.set_initialized();
354
355        Ok(result)
356    }
357
358    /// Route a request to the appropriate handler.
359    async fn route_request(&self, request: &Request) -> Result<serde_json::Value, McpError> {
360        let method = request.method.as_ref();
361        let params = request.params.as_ref();
362
363        // Extract progress token from params._meta.progressToken if present
364        let progress_token = extract_progress_token(params);
365
366        // Create context for the handler
367        let peer = TransportPeer::new(self.transport.clone());
368        let client_caps = self.state.client_caps();
369        let ctx = Context::new(
370            &request.id,
371            progress_token.as_ref(),
372            &client_caps,
373            &self.state.server_caps,
374            &peer,
375        );
376
377        // Delegate to the router
378        self.server.route(method, params, &ctx).await
379    }
380
381    /// Handle a notification.
382    async fn handle_notification(&self, notification: Notification) -> Result<(), McpError> {
383        let method = notification.method.as_ref();
384
385        tracing::debug!(method = %method, "Handling notification");
386
387        match method {
388            "notifications/initialized" => {
389                tracing::info!("Client sent initialized notification");
390                Ok(())
391            }
392            "notifications/cancelled" => {
393                if let Some(params) = &notification.params {
394                    if let Some(request_id) = params.get("requestId").and_then(|v| v.as_str()) {
395                        self.state.cancel_request(request_id);
396                    }
397                }
398                Ok(())
399            }
400            _ => {
401                tracing::debug!(method = %method, "Ignoring unknown notification");
402                Ok(())
403            }
404        }
405    }
406}
407
408// Constructor implementations for ServerRuntime with different server types
409impl<H, T, R, P, K, Tr> ServerRuntime<Server<H, T, R, P, K>, Tr>
410where
411    H: ServerHandler + Send + Sync,
412    T: Send + Sync,
413    R: Send + Sync,
414    P: Send + Sync,
415    K: Send + Sync,
416    Tr: Transport + 'static,
417    Tr::Error: Into<McpError>,
418{
419    /// Create a new server runtime.
420    pub fn new(server: Server<H, T, R, P, K>, transport: Tr) -> Self {
421        let caps = server.capabilities().clone();
422        Self {
423            server,
424            transport: Arc::new(transport),
425            state: Arc::new(ServerState::new(caps)),
426            config: RuntimeConfig::default(),
427        }
428    }
429
430    /// Create a new server runtime with custom configuration.
431    pub fn with_config(
432        server: Server<H, T, R, P, K>,
433        transport: Tr,
434        config: RuntimeConfig,
435    ) -> Self {
436        let caps = server.capabilities().clone();
437        Self {
438            server,
439            transport: Arc::new(transport),
440            state: Arc::new(ServerState::new(caps)),
441            config,
442        }
443    }
444}
445
446/// Trait for routing requests to handlers.
447///
448/// This trait is implemented by Server with different bounds depending on
449/// which handlers are registered.
450#[allow(async_fn_in_trait)]
451pub trait RequestRouter: Send + Sync {
452    /// Route a request and return the result.
453    async fn route(
454        &self,
455        method: &str,
456        params: Option<&serde_json::Value>,
457        ctx: &Context<'_>,
458    ) -> Result<serde_json::Value, McpError>;
459}
460
461/// Extension methods for Server to run with a transport.
462impl<H, T, R, P, K> Server<H, T, R, P, K>
463where
464    H: ServerHandler + Send + Sync + 'static,
465    T: Send + Sync + 'static,
466    R: Send + Sync + 'static,
467    P: Send + Sync + 'static,
468    K: Send + Sync + 'static,
469    Self: RequestRouter,
470{
471    /// Run this server over the given transport.
472    pub async fn serve<Tr>(self, transport: Tr) -> Result<(), McpError>
473    where
474        Tr: Transport + 'static,
475        Tr::Error: Into<McpError>,
476    {
477        let runtime = ServerRuntime::new(self, transport);
478        runtime.run().await
479    }
480}
481
482// ============================================================================
483// RequestRouter implementations via macro
484// ============================================================================
485
486// Internal routing functions to reduce code duplication.
487// Each function handles a specific handler type's methods.
488
489async fn route_tools<TH: ToolHandler + Send + Sync>(
490    handler: &TH,
491    method: &str,
492    params: Option<&serde_json::Value>,
493    ctx: &Context<'_>,
494) -> Option<Result<serde_json::Value, McpError>> {
495    match method {
496        "tools/list" => {
497            tracing::debug!("Listing available tools");
498            let result = handler.list_tools(ctx).await;
499            match &result {
500                Ok(tools) => tracing::debug!(count = tools.len(), "Listed tools"),
501                Err(e) => tracing::warn!(error = %e, "Failed to list tools"),
502            }
503            Some(result.map(|tools| serde_json::json!({ "tools": tools })))
504        }
505        "tools/call" => {
506            let result = async {
507                let params = params.ok_or_else(|| {
508                    McpError::invalid_params("tools/call", "missing params")
509                })?;
510                let name = params.get("name")
511                    .and_then(|v| v.as_str())
512                    .ok_or_else(|| McpError::invalid_params("tools/call", "missing tool name"))?;
513                let args = params.get("arguments")
514                    .cloned()
515                    .unwrap_or_else(|| serde_json::json!({}));
516
517                tracing::info!(tool = %name, "Calling tool");
518                let start = std::time::Instant::now();
519                let output = handler.call_tool(name, args, ctx).await;
520                let duration = start.elapsed();
521
522                match &output {
523                    Ok(_) => tracing::info!(tool = %name, duration_ms = duration.as_millis(), "Tool call completed"),
524                    Err(e) => tracing::warn!(tool = %name, duration_ms = duration.as_millis(), error = %e, "Tool call failed"),
525                }
526
527                let output = output?;
528                let result: CallToolResult = output.into();
529                Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
530            }.await;
531            Some(result)
532        }
533        _ => None,
534    }
535}
536
537async fn route_resources<RH: ResourceHandler + Send + Sync>(
538    handler: &RH,
539    method: &str,
540    params: Option<&serde_json::Value>,
541    ctx: &Context<'_>,
542) -> Option<Result<serde_json::Value, McpError>> {
543    match method {
544        "resources/list" => {
545            tracing::debug!("Listing available resources");
546            let result = handler.list_resources(ctx).await;
547            match &result {
548                Ok(resources) => tracing::debug!(count = resources.len(), "Listed resources"),
549                Err(e) => tracing::warn!(error = %e, "Failed to list resources"),
550            }
551            Some(result.map(|resources| serde_json::json!({ "resources": resources })))
552        }
553        "resources/read" => {
554            let result = async {
555                let params = params.ok_or_else(|| {
556                    McpError::invalid_params("resources/read", "missing params")
557                })?;
558                let uri = params.get("uri")
559                    .and_then(|v| v.as_str())
560                    .ok_or_else(|| McpError::invalid_params("resources/read", "missing uri"))?;
561
562                tracing::info!(uri = %uri, "Reading resource");
563                let start = std::time::Instant::now();
564                let contents = handler.read_resource(uri, ctx).await;
565                let duration = start.elapsed();
566
567                match &contents {
568                    Ok(_) => tracing::info!(uri = %uri, duration_ms = duration.as_millis(), "Resource read completed"),
569                    Err(e) => tracing::warn!(uri = %uri, duration_ms = duration.as_millis(), error = %e, "Resource read failed"),
570                }
571
572                let contents = contents?;
573                Ok(serde_json::json!({ "contents": contents }))
574            }.await;
575            Some(result)
576        }
577        _ => None,
578    }
579}
580
581async fn route_prompts<PH: PromptHandler + Send + Sync>(
582    handler: &PH,
583    method: &str,
584    params: Option<&serde_json::Value>,
585    ctx: &Context<'_>,
586) -> Option<Result<serde_json::Value, McpError>> {
587    match method {
588        "prompts/list" => {
589            tracing::debug!("Listing available prompts");
590            let result = handler.list_prompts(ctx).await;
591            match &result {
592                Ok(prompts) => tracing::debug!(count = prompts.len(), "Listed prompts"),
593                Err(e) => tracing::warn!(error = %e, "Failed to list prompts"),
594            }
595            Some(result.map(|prompts| serde_json::json!({ "prompts": prompts })))
596        }
597        "prompts/get" => {
598            let result = async {
599                let params = params.ok_or_else(|| {
600                    McpError::invalid_params("prompts/get", "missing params")
601                })?;
602                let name = params.get("name")
603                    .and_then(|v| v.as_str())
604                    .ok_or_else(|| McpError::invalid_params("prompts/get", "missing prompt name"))?;
605                let args = params.get("arguments")
606                    .and_then(|v| v.as_object())
607                    .cloned();
608
609                tracing::info!(prompt = %name, "Getting prompt");
610                let start = std::time::Instant::now();
611                let prompt_result = handler.get_prompt(name, args, ctx).await;
612                let duration = start.elapsed();
613
614                match &prompt_result {
615                    Ok(_) => tracing::info!(prompt = %name, duration_ms = duration.as_millis(), "Prompt retrieval completed"),
616                    Err(e) => tracing::warn!(prompt = %name, duration_ms = duration.as_millis(), error = %e, "Prompt retrieval failed"),
617                }
618
619                let result = prompt_result?;
620                Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
621            }.await;
622            Some(result)
623        }
624        _ => None,
625    }
626}
627
628/// Macro to generate `RequestRouter` implementations for all handler combinations.
629///
630/// This macro reduces code duplication by generating all 2^3 = 8 combinations
631/// of tool/resource/prompt handler registration states.
632macro_rules! impl_request_router {
633    // Base case: no handlers
634    (base; $($bounds:tt)*) => {
635        impl<H $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, NotRegistered, NotRegistered>
636        where
637            H: ServerHandler + Send + Sync,
638        {
639            async fn route(
640                &self,
641                method: &str,
642                _params: Option<&serde_json::Value>,
643                _ctx: &Context<'_>,
644            ) -> Result<serde_json::Value, McpError> {
645                match method {
646                    "ping" => Ok(serde_json::json!({})),
647                    _ => Err(McpError::method_not_found(method)),
648                }
649            }
650        }
651    };
652
653    // Tools only
654    (tools; $($bounds:tt)*) => {
655        impl<H, TH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, NotRegistered, NotRegistered>
656        where
657            H: ServerHandler + Send + Sync,
658            TH: ToolHandler + Send + Sync,
659        {
660            async fn route(
661                &self,
662                method: &str,
663                params: Option<&serde_json::Value>,
664                ctx: &Context<'_>,
665            ) -> Result<serde_json::Value, McpError> {
666                if method == "ping" {
667                    return Ok(serde_json::json!({}));
668                }
669                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
670                    return result;
671                }
672                Err(McpError::method_not_found(method))
673            }
674        }
675    };
676
677    // Resources only
678    (resources; $($bounds:tt)*) => {
679        impl<H, RH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, NotRegistered, NotRegistered>
680        where
681            H: ServerHandler + Send + Sync,
682            RH: ResourceHandler + Send + Sync,
683        {
684            async fn route(
685                &self,
686                method: &str,
687                params: Option<&serde_json::Value>,
688                ctx: &Context<'_>,
689            ) -> Result<serde_json::Value, McpError> {
690                if method == "ping" {
691                    return Ok(serde_json::json!({}));
692                }
693                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
694                    return result;
695                }
696                Err(McpError::method_not_found(method))
697            }
698        }
699    };
700
701    // Prompts only
702    (prompts; $($bounds:tt)*) => {
703        impl<H, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, NotRegistered, Registered<PH>, NotRegistered>
704        where
705            H: ServerHandler + Send + Sync,
706            PH: PromptHandler + Send + Sync,
707        {
708            async fn route(
709                &self,
710                method: &str,
711                params: Option<&serde_json::Value>,
712                ctx: &Context<'_>,
713            ) -> Result<serde_json::Value, McpError> {
714                if method == "ping" {
715                    return Ok(serde_json::json!({}));
716                }
717                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
718                    return result;
719                }
720                Err(McpError::method_not_found(method))
721            }
722        }
723    };
724
725    // Tools + Resources
726    (tools_resources; $($bounds:tt)*) => {
727        impl<H, TH, RH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, NotRegistered, NotRegistered>
728        where
729            H: ServerHandler + Send + Sync,
730            TH: ToolHandler + Send + Sync,
731            RH: ResourceHandler + Send + Sync,
732        {
733            async fn route(
734                &self,
735                method: &str,
736                params: Option<&serde_json::Value>,
737                ctx: &Context<'_>,
738            ) -> Result<serde_json::Value, McpError> {
739                if method == "ping" {
740                    return Ok(serde_json::json!({}));
741                }
742                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
743                    return result;
744                }
745                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
746                    return result;
747                }
748                Err(McpError::method_not_found(method))
749            }
750        }
751    };
752
753    // Tools + Prompts
754    (tools_prompts; $($bounds:tt)*) => {
755        impl<H, TH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, NotRegistered, Registered<PH>, NotRegistered>
756        where
757            H: ServerHandler + Send + Sync,
758            TH: ToolHandler + Send + Sync,
759            PH: PromptHandler + Send + Sync,
760        {
761            async fn route(
762                &self,
763                method: &str,
764                params: Option<&serde_json::Value>,
765                ctx: &Context<'_>,
766            ) -> Result<serde_json::Value, McpError> {
767                if method == "ping" {
768                    return Ok(serde_json::json!({}));
769                }
770                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
771                    return result;
772                }
773                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
774                    return result;
775                }
776                Err(McpError::method_not_found(method))
777            }
778        }
779    };
780
781    // Resources + Prompts
782    (resources_prompts; $($bounds:tt)*) => {
783        impl<H, RH, PH $($bounds)*> RequestRouter for Server<H, NotRegistered, Registered<RH>, Registered<PH>, NotRegistered>
784        where
785            H: ServerHandler + Send + Sync,
786            RH: ResourceHandler + Send + Sync,
787            PH: PromptHandler + Send + Sync,
788        {
789            async fn route(
790                &self,
791                method: &str,
792                params: Option<&serde_json::Value>,
793                ctx: &Context<'_>,
794            ) -> Result<serde_json::Value, McpError> {
795                if method == "ping" {
796                    return Ok(serde_json::json!({}));
797                }
798                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
799                    return result;
800                }
801                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
802                    return result;
803                }
804                Err(McpError::method_not_found(method))
805            }
806        }
807    };
808
809    // Tools + Resources + Prompts
810    (tools_resources_prompts; $($bounds:tt)*) => {
811        impl<H, TH, RH, PH $($bounds)*> RequestRouter for Server<H, Registered<TH>, Registered<RH>, Registered<PH>, NotRegistered>
812        where
813            H: ServerHandler + Send + Sync,
814            TH: ToolHandler + Send + Sync,
815            RH: ResourceHandler + Send + Sync,
816            PH: PromptHandler + Send + Sync,
817        {
818            async fn route(
819                &self,
820                method: &str,
821                params: Option<&serde_json::Value>,
822                ctx: &Context<'_>,
823            ) -> Result<serde_json::Value, McpError> {
824                if method == "ping" {
825                    return Ok(serde_json::json!({}));
826                }
827                if let Some(result) = route_tools(self.tool_handler(), method, params, ctx).await {
828                    return result;
829                }
830                if let Some(result) = route_resources(self.resource_handler(), method, params, ctx).await {
831                    return result;
832                }
833                if let Some(result) = route_prompts(self.prompt_handler(), method, params, ctx).await {
834                    return result;
835                }
836                Err(McpError::method_not_found(method))
837            }
838        }
839    };
840}
841
// Generate all RequestRouter implementations.
//
// One invocation per arm of `impl_request_router!`. Each arm name lists which
// of the tool / resource / prompt routers its generated `route` consults, in
// that order (`base` presumably covers the case with none of them registered —
// confirm against the arm definition). The visible arms answer `ping` directly
// and fall back to `McpError::method_not_found` for anything unmatched.
//
// The (empty) token list after each `;` is spliced into the impl's generic
// parameter list as extra bounds; none are needed here.
impl_request_router!(base;);
impl_request_router!(tools;);
impl_request_router!(resources;);
impl_request_router!(prompts;);
impl_request_router!(tools_resources;);
impl_request_router!(tools_prompts;);
impl_request_router!(resources_prompts;);
impl_request_router!(tools_resources_prompts;);
851
852// ============================================================================
853// Helper functions
854// ============================================================================
855
856/// Extract a progress token from request parameters.
857///
858/// Per the MCP specification, progress tokens are sent in the `_meta.progressToken`
859/// field of request parameters. This function attempts to extract and parse that
860/// field into a `ProgressToken`.
861///
862/// # Example JSON structure
863/// ```json
864/// {
865///   "_meta": {
866///     "progressToken": "token-123"
867///   },
868///   "name": "my-tool",
869///   "arguments": {}
870/// }
871/// ```
872fn extract_progress_token(params: Option<&serde_json::Value>) -> Option<ProgressToken> {
873    params?
874        .get("_meta")?
875        .get("progressToken")
876        .and_then(|v| serde_json::from_value(v.clone()).ok())
877}
878
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_state_initialization() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(!state.is_initialized());

        state.set_initialized();
        assert!(state.is_initialized());
    }

    #[test]
    fn test_cancellation_management() {
        let state = ServerState::new(ServerCapabilities::default());
        let token = CancellationToken::new();

        state.register_cancellation("req-1", token.clone());
        assert!(!token.is_cancelled());

        state.cancel_request("req-1");
        assert!(token.is_cancelled());

        // Removing the entry after it has been cancelled must not panic.
        state.remove_cancellation("req-1");
    }

    #[test]
    fn test_runtime_config_default() {
        let config = RuntimeConfig::default();
        assert!(config.auto_initialized);
        assert_eq!(config.max_concurrent_requests, 100);
    }

    #[test]
    fn test_extract_progress_token_string() {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": "my-token-123"
            },
            "name": "test-tool"
        });
        assert_eq!(
            extract_progress_token(Some(&params)),
            Some(ProgressToken::String("my-token-123".to_string()))
        );
    }

    #[test]
    fn test_extract_progress_token_number() {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": 42
            },
            "arguments": {}
        });
        assert_eq!(
            extract_progress_token(Some(&params)),
            Some(ProgressToken::Number(42))
        );
    }

    #[test]
    fn test_extract_progress_token_missing_meta() {
        let params = serde_json::json!({
            "name": "test-tool",
            "arguments": {}
        });
        assert!(extract_progress_token(Some(&params)).is_none());
    }

    #[test]
    fn test_extract_progress_token_missing_token() {
        let params = serde_json::json!({
            "_meta": {},
            "name": "test-tool"
        });
        assert!(extract_progress_token(Some(&params)).is_none());
    }

    #[test]
    fn test_extract_progress_token_null_meta() {
        // `_meta` present but not an object: lookup of the token must fail cleanly.
        let params = serde_json::json!({
            "_meta": null,
            "name": "test-tool"
        });
        assert!(extract_progress_token(Some(&params)).is_none());
    }

    #[test]
    fn test_extract_progress_token_non_object_params() {
        // Params that are not a JSON object cannot carry `_meta` at all.
        let params = serde_json::json!(["_meta", "progressToken"]);
        assert!(extract_progress_token(Some(&params)).is_none());
    }

    #[test]
    fn test_extract_progress_token_none_params() {
        assert!(extract_progress_token(None).is_none());
    }
}