1use crate::builder::{NotRegistered, Registered, Server};
36use crate::context::{CancellationToken, Context, Peer};
37use crate::handler::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
38use mcpkit_core::capability::{
39 negotiate_version, ClientCapabilities, ServerCapabilities, SUPPORTED_PROTOCOL_VERSIONS,
40};
41use mcpkit_core::error::McpError;
42use mcpkit_core::protocol::{Message, Notification, ProgressToken, Request, Response};
43use mcpkit_core::types::CallToolResult;
44use mcpkit_transport::Transport;
45use std::collections::HashMap;
46use std::sync::atomic::{AtomicBool, Ordering};
47use std::sync::Arc;
48use std::sync::RwLock;
49
50pub struct ServerState {
52 pub client_caps: RwLock<ClientCapabilities>,
54 pub server_caps: ServerCapabilities,
56 pub initialized: AtomicBool,
58 pub cancellations: RwLock<HashMap<String, CancellationToken>>,
60 pub negotiated_version: RwLock<Option<String>>,
62}
63
64impl ServerState {
65 pub fn new(server_caps: ServerCapabilities) -> Self {
67 Self {
68 client_caps: RwLock::new(ClientCapabilities::default()),
69 server_caps,
70 initialized: AtomicBool::new(false),
71 cancellations: RwLock::new(HashMap::new()),
72 negotiated_version: RwLock::new(None),
73 }
74 }
75
76 pub fn protocol_version(&self) -> Option<String> {
80 self.negotiated_version
81 .read()
82 .ok()
83 .and_then(|guard| guard.clone())
84 }
85
86 pub fn set_protocol_version(&self, version: String) {
90 if let Ok(mut guard) = self.negotiated_version.write() {
91 *guard = Some(version);
92 }
93 }
94
95 pub fn client_caps(&self) -> ClientCapabilities {
99 self.client_caps
100 .read()
101 .map(|guard| guard.clone())
102 .unwrap_or_default()
103 }
104
105 pub fn set_client_caps(&self, caps: ClientCapabilities) {
109 if let Ok(mut guard) = self.client_caps.write() {
110 *guard = caps;
111 }
112 }
113
114 pub fn is_initialized(&self) -> bool {
116 self.initialized.load(Ordering::Acquire)
117 }
118
119 pub fn set_initialized(&self) {
121 self.initialized.store(true, Ordering::Release);
122 }
123
124 pub fn register_cancellation(&self, request_id: &str, token: CancellationToken) {
126 if let Ok(mut cancellations) = self.cancellations.write() {
127 cancellations.insert(request_id.to_string(), token);
128 }
129 }
130
131 pub fn cancel_request(&self, request_id: &str) {
133 if let Ok(cancellations) = self.cancellations.read() {
134 if let Some(token) = cancellations.get(request_id) {
135 token.cancel();
136 }
137 }
138 }
139
140 pub fn remove_cancellation(&self, request_id: &str) {
142 if let Ok(mut cancellations) = self.cancellations.write() {
143 cancellations.remove(request_id);
144 }
145 }
146}
147
148pub struct TransportPeer<T: Transport> {
150 transport: Arc<T>,
151}
152
153impl<T: Transport> TransportPeer<T> {
154 pub fn new(transport: Arc<T>) -> Self {
156 Self { transport }
157 }
158}
159
160impl<T: Transport + 'static> Peer for TransportPeer<T>
161where
162 T::Error: Into<McpError>,
163{
164 fn notify(
165 &self,
166 notification: Notification,
167 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), McpError>> + Send + '_>>
168 {
169 let transport = self.transport.clone();
170 Box::pin(async move {
171 transport
172 .send(Message::Notification(notification))
173 .await
174 .map_err(|e| e.into())
175 })
176 }
177}
178
/// Tuning options stored on a `ServerRuntime`.
///
/// NOTE(review): no code in this file reads either field; they are carried
/// by the runtime but currently have no effect.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Presumably: treat the session as initialized automatically — TODO confirm semantics.
    pub auto_initialized: bool,
    /// Intended cap on concurrently processed requests (not enforced here).
    pub max_concurrent_requests: usize,
}

impl Default for RuntimeConfig {
    /// `auto_initialized: true`, `max_concurrent_requests: 100`.
    fn default() -> Self {
        const DEFAULT_MAX_CONCURRENT_REQUESTS: usize = 100;

        RuntimeConfig {
            max_concurrent_requests: DEFAULT_MAX_CONCURRENT_REQUESTS,
            auto_initialized: true,
        }
    }
}
196
197pub struct ServerRuntime<S, Tr>
202where
203 Tr: Transport,
204{
205 server: S,
206 transport: Arc<Tr>,
207 state: Arc<ServerState>,
208 #[allow(dead_code)]
210 config: RuntimeConfig,
211}
212
213impl<S, Tr> ServerRuntime<S, Tr>
214where
215 S: RequestRouter + Send + Sync,
216 Tr: Transport + 'static,
217 Tr::Error: Into<McpError>,
218{
219 pub fn state(&self) -> &Arc<ServerState> {
221 &self.state
222 }
223
224 pub async fn run(&self) -> Result<(), McpError> {
228 loop {
229 match self.transport.recv().await {
230 Ok(Some(message)) => {
231 if let Err(e) = self.handle_message(message).await {
232 tracing::error!(error = %e, "Error handling message");
233 }
234 }
235 Ok(None) => {
236 tracing::info!("Connection closed");
238 break;
239 }
240 Err(e) => {
241 let err: McpError = e.into();
242 tracing::error!(error = %err, "Transport error");
243 return Err(err);
244 }
245 }
246 }
247
248 Ok(())
249 }
250
251 async fn handle_message(&self, message: Message) -> Result<(), McpError> {
253 match message {
254 Message::Request(request) => self.handle_request(request).await,
255 Message::Notification(notification) => self.handle_notification(notification).await,
256 Message::Response(_) => {
257 tracing::warn!("Received unexpected response message");
259 Ok(())
260 }
261 }
262 }
263
264 async fn handle_request(&self, request: Request) -> Result<(), McpError> {
266 let method = request.method.to_string();
267 let id = request.id.clone();
268
269 tracing::debug!(method = %method, id = %id, "Handling request");
270
271 let response = match method.as_str() {
272 "initialize" => self.handle_initialize(&request).await,
273 _ if !self.state.is_initialized() => {
274 Err(McpError::invalid_request("Server not initialized"))
275 }
276 _ => self.route_request(&request).await,
277 };
278
279 let response_msg = match response {
281 Ok(result) => Response::success(id, result),
282 Err(e) => Response::error(id, e.into()),
283 };
284
285 self.transport
286 .send(Message::Response(response_msg))
287 .await
288 .map_err(|e| e.into())
289 }
290
291 async fn handle_initialize(
298 &self,
299 request: &Request,
300 ) -> Result<serde_json::Value, McpError> {
301 if self.state.is_initialized() {
302 return Err(McpError::invalid_request("Already initialized"));
303 }
304
305 let params = request.params.as_ref().ok_or_else(|| {
307 McpError::invalid_params("initialize", "missing params")
308 })?;
309
310 let requested_version = params
312 .get("protocolVersion")
313 .and_then(|v| v.as_str())
314 .unwrap_or("");
315
316 let negotiated_version = negotiate_version(requested_version);
317
318 if requested_version != negotiated_version {
320 tracing::info!(
321 requested = %requested_version,
322 negotiated = %negotiated_version,
323 supported = ?SUPPORTED_PROTOCOL_VERSIONS,
324 "Protocol version negotiation: client requested unsupported version"
325 );
326 } else {
327 tracing::debug!(
328 version = %negotiated_version,
329 "Protocol version negotiated successfully"
330 );
331 }
332
333 self.state.set_protocol_version(negotiated_version.to_string());
335
336 if let Some(caps) = params.get("capabilities") {
338 if let Ok(client_caps) = serde_json::from_value::<ClientCapabilities>(caps.clone()) {
339 self.state.set_client_caps(client_caps);
340 }
341 }
342
343 let result = serde_json::json!({
345 "protocolVersion": negotiated_version,
346 "serverInfo": {
347 "name": "mcp-server",
348 "version": "1.0.0"
349 },
350 "capabilities": self.state.server_caps
351 });
352
353 self.state.set_initialized();
354
355 Ok(result)
356 }
357
358 async fn route_request(&self, request: &Request) -> Result<serde_json::Value, McpError> {
360 let method = request.method.as_ref();
361 let params = request.params.as_ref();
362
363 let progress_token = extract_progress_token(params);
365
366 let peer = TransportPeer::new(self.transport.clone());
368 let client_caps = self.state.client_caps();
369 let ctx = Context::new(
370 &request.id,
371 progress_token.as_ref(),
372 &client_caps,
373 &self.state.server_caps,
374 &peer,
375 );
376
377 self.server.route(method, params, &ctx).await
379 }
380
381 async fn handle_notification(&self, notification: Notification) -> Result<(), McpError> {
383 let method = notification.method.as_ref();
384
385 tracing::debug!(method = %method, "Handling notification");
386
387 match method {
388 "notifications/initialized" => {
389 tracing::info!("Client sent initialized notification");
390 Ok(())
391 }
392 "notifications/cancelled" => {
393 if let Some(params) = ¬ification.params {
394 if let Some(request_id) = params.get("requestId").and_then(|v| v.as_str()) {
395 self.state.cancel_request(request_id);
396 }
397 }
398 Ok(())
399 }
400 _ => {
401 tracing::debug!(method = %method, "Ignoring unknown notification");
402 Ok(())
403 }
404 }
405 }
406}
407
408impl<H, T, R, P, K, Tr> ServerRuntime<Server<H, T, R, P, K>, Tr>
410where
411 H: ServerHandler + Send + Sync,
412 T: Send + Sync,
413 R: Send + Sync,
414 P: Send + Sync,
415 K: Send + Sync,
416 Tr: Transport + 'static,
417 Tr::Error: Into<McpError>,
418{
419 pub fn new(server: Server<H, T, R, P, K>, transport: Tr) -> Self {
421 let caps = server.capabilities().clone();
422 Self {
423 server,
424 transport: Arc::new(transport),
425 state: Arc::new(ServerState::new(caps)),
426 config: RuntimeConfig::default(),
427 }
428 }
429
430 pub fn with_config(server: Server<H, T, R, P, K>, transport: Tr, config: RuntimeConfig) -> Self {
432 let caps = server.capabilities().clone();
433 Self {
434 server,
435 transport: Arc::new(transport),
436 state: Arc::new(ServerState::new(caps)),
437 config,
438 }
439 }
440}
441
/// Dispatches an MCP request method to the handler that implements it.
///
/// `impl_request_router!` implements this for each supported combination of
/// registered handlers on `Server`.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint, because
// callers cannot require the returned future to be `Send`. The lint is
// allowed here deliberately.
#[allow(async_fn_in_trait)]
pub trait RequestRouter: Send + Sync {
    /// Routes `method` with its optional `params` and returns the JSON result.
    ///
    /// Implementations in this file return `method_not_found` for methods
    /// they do not handle.
    async fn route(
        &self,
        method: &str,
        params: Option<&serde_json::Value>,
        ctx: &Context<'_>,
    ) -> Result<serde_json::Value, McpError>;
}
456
457impl<H, T, R, P, K> Server<H, T, R, P, K>
459where
460 H: ServerHandler + Send + Sync + 'static,
461 T: Send + Sync + 'static,
462 R: Send + Sync + 'static,
463 P: Send + Sync + 'static,
464 K: Send + Sync + 'static,
465 Self: RequestRouter,
466{
467 pub async fn serve<Tr>(self, transport: Tr) -> Result<(), McpError>
469 where
470 Tr: Transport + 'static,
471 Tr::Error: Into<McpError>,
472 {
473 let runtime = ServerRuntime::new(self, transport);
474 runtime.run().await
475 }
476}
477
478async fn route_tools<TH: ToolHandler + Send + Sync>(
486 handler: &TH,
487 method: &str,
488 params: Option<&serde_json::Value>,
489 ctx: &Context<'_>,
490) -> Option<Result<serde_json::Value, McpError>> {
491 match method {
492 "tools/list" => {
493 let result = handler.list_tools(ctx).await;
494 Some(result.map(|tools| serde_json::json!({ "tools": tools })))
495 }
496 "tools/call" => {
497 let result = (|| async {
498 let params = params.ok_or_else(|| {
499 McpError::invalid_params("tools/call", "missing params")
500 })?;
501 let name = params.get("name")
502 .and_then(|v| v.as_str())
503 .ok_or_else(|| McpError::invalid_params("tools/call", "missing tool name"))?;
504 let args = params.get("arguments")
505 .cloned()
506 .unwrap_or(serde_json::json!({}));
507 let output = handler.call_tool(name, args, ctx).await?;
508 let result: CallToolResult = output.into();
509 Ok(serde_json::to_value(result).unwrap_or(serde_json::json!({})))
510 })().await;
511 Some(result)
512 }
513 _ => None,
514 }
515}
516
517async fn route_resources<RH: ResourceHandler + Send + Sync>(
518 handler: &RH,
519 method: &str,
520 params: Option<&serde_json::Value>,
521 ctx: &Context<'_>,
522) -> Option<Result<serde_json::Value, McpError>> {
523 match method {
524 "resources/list" => {
525 let result = handler.list_resources(ctx).await;
526 Some(result.map(|resources| serde_json::json!({ "resources": resources })))
527 }
528 "resources/read" => {
529 let result = (|| async {
530 let params = params.ok_or_else(|| {
531 McpError::invalid_params("resources/read", "missing params")
532 })?;
533 let uri = params.get("uri")
534 .and_then(|v| v.as_str())
535 .ok_or_else(|| McpError::invalid_params("resources/read", "missing uri"))?;
536 let contents = handler.read_resource(uri, ctx).await?;
537 Ok(serde_json::json!({ "contents": contents }))
538 })().await;
539 Some(result)
540 }
541 _ => None,
542 }
543}
544
545async fn route_prompts<PH: PromptHandler + Send + Sync>(
546 handler: &PH,
547 method: &str,
548 params: Option<&serde_json::Value>,
549 ctx: &Context<'_>,
550) -> Option<Result<serde_json::Value, McpError>> {
551 match method {
552 "prompts/list" => {
553 let result = handler.list_prompts(ctx).await;
554 Some(result.map(|prompts| serde_json::json!({ "prompts": prompts })))
555 }
556 "prompts/get" => {
557 let result = (|| async {
558 let params = params.ok_or_else(|| {
559 McpError::invalid_params("prompts/get", "missing params")
560 })?;
561 let name = params.get("name")
562 .and_then(|v| v.as_str())
563 .ok_or_else(|| McpError::invalid_params("prompts/get", "missing prompt name"))?;
564 let args = params.get("arguments")
565 .and_then(|v| v.as_object())
566 .cloned();
567 let result = handler.get_prompt(name, args, ctx).await?;
568 Ok(serde_json::to_value(result).unwrap_or(serde_json::json!({})))
569 })().await;
570 Some(result)
571 }
572 _ => None,
573 }
574}
575
/// Generates the [`RequestRouter`] impl for one combination of registered
/// handler families on `Server`.
///
/// Each public arm (`base`, `tools`, ..., `tools_resources_prompts`) targets the
/// `Server<H, T, R, P, K>` type whose tool/resource/prompt slots are
/// `Registered<_>` exactly for the families in its name. The fifth slot is
/// always `NotRegistered`. Tokens after the `;` are spliced verbatim after the
/// impl's generic parameters.
///
/// Every generated router answers `ping` with `{}`. It then offers the method
/// to each registered family in the order tools, resources, prompts, and
/// finally returns `method_not_found`.
macro_rules! impl_request_router {
    // Internal arm: shared body of every generated `route` method.
    // Each `$router => $getter` pair names a `route_*` helper and the `Server`
    // accessor for the handler it needs; pairs are tried in the order given,
    // and the first helper that returns `Some` decides the result.
    (@dispatch $this:ident, $method:ident, $params:ident, $ctx:ident; $($router:ident => $getter:ident),*) => {{
        if $method == "ping" {
            return Ok(serde_json::json!({}));
        }
        $(
            if let Some(result) = $router($this.$getter(), $method, $params, $ctx).await {
                return result;
            }
        )*
        Err(McpError::method_not_found($method))
    }};

    (base; $($bounds:tt)*) => {
        impl<H $($bounds)*> RequestRouter
            for Server<H, NotRegistered, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                _params: Option<&serde_json::Value>,
                _ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                impl_request_router!(@dispatch self, method, _params, _ctx;)
            }
        }
    };

    (tools; $($bounds:tt)*) => {
        impl<H, TH $($bounds)*> RequestRouter
            for Server<H, Registered<TH>, NotRegistered, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                impl_request_router!(@dispatch self, method, params, ctx;
                    route_tools => tool_handler)
            }
        }
    };

    (resources; $($bounds:tt)*) => {
        impl<H, RH $($bounds)*> RequestRouter
            for Server<H, NotRegistered, Registered<RH>, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                impl_request_router!(@dispatch self, method, params, ctx;
                    route_resources => resource_handler)
            }
        }
    };

    (prompts; $($bounds:tt)*) => {
        impl<H, PH $($bounds)*> RequestRouter
            for Server<H, NotRegistered, NotRegistered, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                impl_request_router!(@dispatch self, method, params, ctx;
                    route_prompts => prompt_handler)
            }
        }
    };

    (tools_resources; $($bounds:tt)*) => {
        impl<H, TH, RH $($bounds)*> RequestRouter
            for Server<H, Registered<TH>, Registered<RH>, NotRegistered, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                impl_request_router!(@dispatch self, method, params, ctx;
                    route_tools => tool_handler,
                    route_resources => resource_handler)
            }
        }
    };

    (tools_prompts; $($bounds:tt)*) => {
        impl<H, TH, PH $($bounds)*> RequestRouter
            for Server<H, Registered<TH>, NotRegistered, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                impl_request_router!(@dispatch self, method, params, ctx;
                    route_tools => tool_handler,
                    route_prompts => prompt_handler)
            }
        }
    };

    (resources_prompts; $($bounds:tt)*) => {
        impl<H, RH, PH $($bounds)*> RequestRouter
            for Server<H, NotRegistered, Registered<RH>, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                impl_request_router!(@dispatch self, method, params, ctx;
                    route_resources => resource_handler,
                    route_prompts => prompt_handler)
            }
        }
    };

    (tools_resources_prompts; $($bounds:tt)*) => {
        impl<H, TH, RH, PH $($bounds)*> RequestRouter
            for Server<H, Registered<TH>, Registered<RH>, Registered<PH>, NotRegistered>
        where
            H: ServerHandler + Send + Sync,
            TH: ToolHandler + Send + Sync,
            RH: ResourceHandler + Send + Sync,
            PH: PromptHandler + Send + Sync,
        {
            async fn route(
                &self,
                method: &str,
                params: Option<&serde_json::Value>,
                ctx: &Context<'_>,
            ) -> Result<serde_json::Value, McpError> {
                impl_request_router!(@dispatch self, method, params, ctx;
                    route_tools => tool_handler,
                    route_resources => resource_handler,
                    route_prompts => prompt_handler)
            }
        }
    };
}
789
790impl_request_router!(base;);
792impl_request_router!(tools;);
793impl_request_router!(resources;);
794impl_request_router!(prompts;);
795impl_request_router!(tools_resources;);
796impl_request_router!(tools_prompts;);
797impl_request_router!(resources_prompts;);
798impl_request_router!(tools_resources_prompts;);
799
800fn extract_progress_token(params: Option<&serde_json::Value>) -> Option<ProgressToken> {
821 params?
822 .get("_meta")?
823 .get("progressToken")
824 .and_then(|v| serde_json::from_value(v.clone()).ok())
825}
826
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_state_initialization() {
        let state = ServerState::new(ServerCapabilities::default());
        assert!(!state.is_initialized());

        state.set_initialized();
        assert!(state.is_initialized());
    }

    #[test]
    fn test_cancellation_management() {
        let state = ServerState::new(ServerCapabilities::default());
        let token = CancellationToken::new();

        state.register_cancellation("req-1", token.clone());
        assert!(!token.is_cancelled());

        state.cancel_request("req-1");
        assert!(token.is_cancelled());

        state.remove_cancellation("req-1");
        assert!(!state.cancellations.read().unwrap().contains_key("req-1"));

        // Cancelling removed or never-registered requests is a no-op.
        state.cancel_request("req-1");
        state.cancel_request("never-registered");
    }

    #[test]
    fn test_cancel_request_only_affects_matching_id() {
        let state = ServerState::new(ServerCapabilities::default());
        let first = CancellationToken::new();
        let second = CancellationToken::new();

        state.register_cancellation("a", first.clone());
        state.register_cancellation("b", second.clone());
        state.cancel_request("a");

        assert!(first.is_cancelled());
        assert!(!second.is_cancelled());
    }

    #[test]
    fn test_protocol_version_round_trip() {
        let state = ServerState::new(ServerCapabilities::default());
        assert_eq!(state.protocol_version(), None);

        state.set_protocol_version("2025-03-26".to_string());
        assert_eq!(state.protocol_version().as_deref(), Some("2025-03-26"));
    }

    #[test]
    fn test_runtime_config_default() {
        let config = RuntimeConfig::default();
        assert!(config.auto_initialized);
        assert_eq!(config.max_concurrent_requests, 100);
    }

    #[test]
    fn test_extract_progress_token_string() {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": "my-token-123"
            },
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_some());
        assert_eq!(token.unwrap(), ProgressToken::String("my-token-123".to_string()));
    }

    #[test]
    fn test_extract_progress_token_number() {
        let params = serde_json::json!({
            "_meta": {
                "progressToken": 42
            },
            "arguments": {}
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_some());
        assert_eq!(token.unwrap(), ProgressToken::Number(42));
    }

    #[test]
    fn test_extract_progress_token_missing_meta() {
        let params = serde_json::json!({
            "name": "test-tool",
            "arguments": {}
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_missing_token() {
        let params = serde_json::json!({
            "_meta": {},
            "name": "test-tool"
        });
        let token = extract_progress_token(Some(&params));
        assert!(token.is_none());
    }

    #[test]
    fn test_extract_progress_token_none_params() {
        let token = extract_progress_token(None);
        assert!(token.is_none());
    }
}