1#[cfg(feature = "http-server")]
57pub mod http;
58#[cfg(feature = "stdio-server")]
59pub mod stdio;
60
61use std::collections::HashMap;
62use std::sync::atomic::{AtomicU8, Ordering};
63use std::sync::Arc;
64
65use serde_json::Value;
66use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
67
68use crate::protocol::{
69 ClientInbound, JsonRpcId, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
70 McpCapabilities, McpServerInfo, ServerOutbound, MCP_PROTOCOL_VERSION,
71};
72use crate::tool::{DynTool, McpTool, ToolCallResult, ToolProvider, ToolRegistry};
73
/// Lifecycle state of the server.
///
/// The explicit `u8` discriminants allow the state to be stored in an
/// `AtomicU8` and recovered through the `From<u8>` implementation below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ServerStatus {
    /// The server is not processing messages.
    Stopped = 0,
    /// The server is processing messages.
    Running = 1,
    /// The server stopped because of an error.
    Faulted = 2,
}

impl From<u8> for ServerStatus {
    /// Decodes a raw discriminant.
    ///
    /// `1` maps to `Running` and `2` to `Faulted`; every other value,
    /// including `0` and out-of-range bytes, maps to `Stopped`.
    fn from(value: u8) -> Self {
        if value == ServerStatus::Running as u8 {
            ServerStatus::Running
        } else if value == ServerStatus::Faulted as u8 {
            ServerStatus::Faulted
        } else {
            ServerStatus::Stopped
        }
    }
}
96
/// Channel endpoints handed to the transport when a server is created.
///
/// Returned by `McpServer::new` alongside the server handle.
pub struct McpServerChannels {
    /// Sender for messages coming from the client; the server's message loop
    /// owns the matching receiver.
    pub inbound_tx: mpsc::Sender<ClientInbound>,
    /// Clone of the sender the server itself uses for outgoing messages.
    pub outbound_tx: mpsc::Sender<ServerOutbound>,
    /// Receiver for everything the server sends (responses, requests and
    /// notifications).
    pub outbound_rx: mpsc::Receiver<ServerOutbound>,
}
110
111pub struct McpServerConfig {
113 pub(crate) name: String,
114 pub(crate) version: String,
115 pub(crate) registry: ToolRegistry,
116 pub(crate) capabilities: McpCapabilities,
117}
118
119impl McpServerConfig {
120 pub fn builder() -> McpServerConfigBuilder {
122 McpServerConfigBuilder::new()
123 }
124
125 pub fn name(&self) -> &str {
127 &self.name
128 }
129
130 pub fn version(&self) -> &str {
132 &self.version
133 }
134
135 pub fn registry(&self) -> &ToolRegistry {
137 &self.registry
138 }
139}
140
141#[derive(Default)]
143pub struct McpServerConfigBuilder {
144 name: String,
145 version: String,
146 registry: ToolRegistry,
147 capabilities: McpCapabilities,
148}
149
150impl McpServerConfigBuilder {
151 pub fn new() -> Self {
153 Self {
154 name: "mcp-server".to_string(),
155 version: "0.1.0".to_string(),
156 registry: ToolRegistry::new(),
157 capabilities: McpCapabilities {
158 tools: Some(serde_json::json!({})),
159 ..Default::default()
160 },
161 }
162 }
163
164 pub fn name(mut self, name: impl Into<String>) -> Self {
166 self.name = name.into();
167 self
168 }
169
170 pub fn version(mut self, version: impl Into<String>) -> Self {
172 self.version = version.into();
173 self
174 }
175
176 pub fn with_tool<T: McpTool + 'static>(mut self, tool: T) -> Self {
178 self.registry.register(Arc::new(tool));
179 self
180 }
181
182 pub fn with_dyn_tool(mut self, tool: DynTool) -> Self {
184 self.registry.register(tool);
185 self
186 }
187
188 pub fn with_tools(mut self, tools: Vec<DynTool>) -> Self {
190 for tool in tools {
191 self.registry.register(tool);
192 }
193 self
194 }
195
196 pub fn with_tools_from<P: ToolProvider>(mut self, provider: P) -> Self {
198 self.registry.register_provider(provider);
199 self
200 }
201
202 pub fn register_tools(mut self) -> Self {
204 for tool in crate::tool::all_tools() {
205 self.registry.register(tool);
206 }
207 self
208 }
209
210 pub fn register_tools_in_group(mut self, group: &str) -> Self {
212 for tool in crate::tool::tools_in_group(group) {
213 self.registry.register(tool);
214 }
215 self
216 }
217
218 pub fn with_capabilities(mut self, capabilities: McpCapabilities) -> Self {
220 self.capabilities = capabilities;
221 self
222 }
223
224 pub fn with_resources(mut self) -> Self {
226 self.capabilities.resources = Some(serde_json::json!({}));
227 self
228 }
229
230 pub fn with_prompts(mut self) -> Self {
232 self.capabilities.prompts = Some(serde_json::json!({}));
233 self
234 }
235
236 pub fn with_elicitation(mut self) -> Self {
238 self.capabilities.elicitation = Some(serde_json::json!({}));
239 self
240 }
241
242 pub fn with_tasks(mut self) -> Self {
244 self.capabilities.tasks = Some(serde_json::json!({}));
245 self
246 }
247
248 pub fn with_logging(mut self) -> Self {
250 self.capabilities.logging = Some(serde_json::json!({}));
251 self
252 }
253
254 pub fn with_completions(mut self) -> Self {
256 self.capabilities.completions = Some(serde_json::json!({}));
257 self
258 }
259
260 pub fn build(self) -> McpServerConfig {
262 McpServerConfig {
263 name: self.name,
264 version: self.version,
265 registry: self.registry,
266 capabilities: self.capabilities,
267 }
268 }
269}
270
271type PendingRequests = Arc<Mutex<HashMap<JsonRpcId, oneshot::Sender<JsonRpcResponse>>>>;
273
274pub struct McpServer {
298 name: String,
299 version: String,
300 registry: Arc<ToolRegistry>,
301 capabilities: McpCapabilities,
302 status: Arc<AtomicU8>,
303 pending_requests: PendingRequests,
304 next_request_id: Arc<AtomicU8>,
305 outbound_tx: mpsc::Sender<ServerOutbound>,
306 fault_reason: Arc<RwLock<Option<String>>>,
307 start_time: std::time::Instant,
308}
309
310impl McpServer {
311 pub fn new(config: McpServerConfig) -> (Arc<Self>, McpServerChannels) {
316 let (inbound_tx, inbound_rx) = mpsc::channel::<ClientInbound>(256);
317 let (outbound_tx, outbound_rx) = mpsc::channel::<ServerOutbound>(256);
318
319 let server = Arc::new(Self {
320 name: config.name,
321 version: config.version,
322 registry: Arc::new(config.registry),
323 capabilities: config.capabilities,
324 status: Arc::new(AtomicU8::new(ServerStatus::Running as u8)),
325 pending_requests: Arc::new(Mutex::new(HashMap::new())),
326 next_request_id: Arc::new(AtomicU8::new(1)),
327 outbound_tx,
328 fault_reason: Arc::new(RwLock::new(None)),
329 start_time: std::time::Instant::now(),
330 });
331
332 let server_clone = Arc::clone(&server);
334 tokio::spawn(async move {
335 server_clone.message_loop(inbound_rx).await;
336 });
337
338 let channels = McpServerChannels {
339 inbound_tx,
340 outbound_tx: server.outbound_tx.clone(),
341 outbound_rx,
342 };
343
344 (server, channels)
345 }
346
347 pub fn status(&self) -> ServerStatus {
349 ServerStatus::from(self.status.load(Ordering::SeqCst))
350 }
351
352 pub async fn fault_reason(&self) -> Option<String> {
354 self.fault_reason.read().await.clone()
355 }
356
357 pub fn stop(&self) {
359 self.status
360 .store(ServerStatus::Stopped as u8, Ordering::SeqCst);
361 }
362
363 pub fn name(&self) -> &str {
365 &self.name
366 }
367
368 pub fn version(&self) -> &str {
370 &self.version
371 }
372
373 pub fn uptime(&self) -> std::time::Duration {
375 self.start_time.elapsed()
376 }
377
378 pub fn server_info(&self) -> McpServerInfo {
380 McpServerInfo {
381 name: self.name.clone(),
382 version: self.version.clone(),
383 title: None,
384 description: None,
385 icons: None,
386 website_url: None,
387 }
388 }
389
390 pub async fn send_request(
395 &self,
396 method: impl Into<String>,
397 params: Option<Value>,
398 timeout: std::time::Duration,
399 ) -> Result<JsonRpcResponse, ServerError> {
400 let id = JsonRpcId::Number(self.next_request_id.fetch_add(1, Ordering::SeqCst) as i64);
401
402 let request = JsonRpcRequest {
403 jsonrpc: "2.0".to_string(),
404 id: id.clone(),
405 method: method.into(),
406 params,
407 };
408
409 let (response_tx, response_rx) = oneshot::channel();
411
412 {
414 let mut pending = self.pending_requests.lock().await;
415 pending.insert(id.clone(), response_tx);
416 }
417
418 self.outbound_tx
420 .send(ServerOutbound::Request(request))
421 .await
422 .map_err(|_| ServerError::ChannelClosed)?;
423
424 match tokio::time::timeout(timeout, response_rx).await {
426 Ok(Ok(response)) => Ok(response),
427 Ok(Err(_)) => {
428 self.pending_requests.lock().await.remove(&id);
430 Err(ServerError::ChannelClosed)
431 }
432 Err(_) => {
433 self.pending_requests.lock().await.remove(&id);
435 Err(ServerError::ResponseTimeout)
436 }
437 }
438 }
439
440 pub async fn send_request_default_timeout(
442 &self,
443 method: impl Into<String>,
444 params: Option<Value>,
445 ) -> Result<JsonRpcResponse, ServerError> {
446 self.send_request(method, params, std::time::Duration::from_secs(30))
447 .await
448 }
449
450 pub async fn send_notification(
452 &self,
453 method: impl Into<String>,
454 params: Option<Value>,
455 ) -> Result<(), ServerError> {
456 let notification = JsonRpcNotification::new(method, params);
457 self.outbound_tx
458 .send(ServerOutbound::Notification(notification))
459 .await
460 .map_err(|_| ServerError::ChannelClosed)
461 }
462
463 pub async fn send_progress(
465 &self,
466 token: impl Into<String>,
467 progress: f64,
468 message: Option<String>,
469 ) -> Result<(), ServerError> {
470 let params = serde_json::json!({
471 "progressToken": token.into(),
472 "progress": progress,
473 "message": message
474 });
475 self.send_notification("notifications/progress", Some(params))
476 .await
477 }
478
479 pub async fn send_log(
481 &self,
482 level: &str,
483 message: impl Into<String>,
484 logger: Option<&str>,
485 data: Option<Value>,
486 ) -> Result<(), ServerError> {
487 let mut params = serde_json::json!({
488 "level": level,
489 "message": message.into()
490 });
491 if let Some(l) = logger {
492 params["logger"] = serde_json::json!(l);
493 }
494 if let Some(d) = data {
495 params["data"] = d;
496 }
497 self.send_notification("notifications/message", Some(params))
498 .await
499 }
500
501 pub async fn call_tool(&self, name: &str, args: Value) -> ToolCallResult {
503 self.registry.call(name, args).await
504 }
505
506 pub fn list_tools(&self) -> Vec<crate::protocol::McpToolDefinition> {
508 self.registry.definitions()
509 }
510
511 pub fn registry(&self) -> &ToolRegistry {
513 &self.registry
514 }
515
516 async fn message_loop(self: Arc<Self>, mut inbound_rx: mpsc::Receiver<ClientInbound>) {
518 while self.status() == ServerStatus::Running {
519 match inbound_rx.recv().await {
520 Some(message) => {
521 if let Err(e) = self.handle_inbound(message).await {
522 self.status
524 .store(ServerStatus::Faulted as u8, Ordering::SeqCst);
525 *self.fault_reason.write().await = Some(e.to_string());
526 break;
527 }
528 }
529 None => {
530 self.status
532 .store(ServerStatus::Stopped as u8, Ordering::SeqCst);
533 break;
534 }
535 }
536 }
537 }
538
539 async fn handle_inbound(&self, message: ClientInbound) -> Result<(), ServerError> {
541 match message {
542 ClientInbound::Request(request) => {
543 let response = self.handle_rpc_request(request).await;
544 self.outbound_tx
545 .send(ServerOutbound::Response(response))
546 .await
547 .map_err(|_| ServerError::ChannelClosed)?;
548 }
549 ClientInbound::Response(response) => {
550 let mut pending = self.pending_requests.lock().await;
552 if let Some(tx) = pending.remove(&response.id) {
553 let _ = tx.send(response);
554 }
555 }
556 ClientInbound::Notification(notification) => {
557 self.handle_notification(notification).await?;
559 }
560 }
561 Ok(())
562 }
563
564 async fn handle_rpc_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
566 match request.method.as_str() {
567 "initialize" => self.handle_initialize(request.id),
568 "tools/list" => self.handle_tools_list(request.id),
569 "tools/call" => self.handle_tools_call(request.id, request.params).await,
570 "ping" => JsonRpcResponse::success(request.id, serde_json::json!({})),
571 "health/check" => self.handle_health_check(request.id),
572 _ => JsonRpcResponse::error(
573 request.id,
574 -32601,
575 format!("Method not found: {}", request.method),
576 None,
577 ),
578 }
579 }
580
581 fn handle_initialize(&self, id: JsonRpcId) -> JsonRpcResponse {
582 JsonRpcResponse::success(
583 id,
584 serde_json::json!({
585 "protocolVersion": MCP_PROTOCOL_VERSION,
586 "serverInfo": self.server_info(),
587 "capabilities": self.capabilities
588 }),
589 )
590 }
591
592 fn handle_tools_list(&self, id: JsonRpcId) -> JsonRpcResponse {
593 let tools = self.registry.definitions();
594 JsonRpcResponse::success(id, serde_json::json!({ "tools": tools }))
595 }
596
597 fn handle_health_check(&self, id: JsonRpcId) -> JsonRpcResponse {
598 let uptime_secs = self.start_time.elapsed().as_secs();
599 let status = self.status();
600 JsonRpcResponse::success(
601 id,
602 serde_json::json!({
603 "status": match status {
604 ServerStatus::Running => "healthy",
605 ServerStatus::Stopped => "stopped",
606 ServerStatus::Faulted => "unhealthy",
607 },
608 "uptime_seconds": uptime_secs,
609 "server_name": self.name,
610 "server_version": self.version,
611 "tool_count": self.registry.definitions().len()
612 }),
613 )
614 }
615
616 async fn handle_tools_call(&self, id: JsonRpcId, params: Option<Value>) -> JsonRpcResponse {
617 let params = match params {
618 Some(p) => p,
619 None => {
620 return JsonRpcResponse::error(id, -32602, "Missing params".to_string(), None);
621 }
622 };
623
624 let name = match params.get("name").and_then(|n| n.as_str()) {
625 Some(n) => n,
626 None => {
627 return JsonRpcResponse::error(id, -32602, "Missing tool name".to_string(), None);
628 }
629 };
630
631 let arguments = params
632 .get("arguments")
633 .cloned()
634 .unwrap_or(serde_json::json!({}));
635
636 let result = self.registry.call(name, arguments).await;
637
638 match result {
639 Ok(content) => JsonRpcResponse::success(
640 id,
641 serde_json::json!({
642 "content": content,
643 "isError": false
644 }),
645 ),
646 Err(e) => JsonRpcResponse::success(
647 id,
648 serde_json::json!({
649 "content": [{ "type": "text", "text": e.to_string() }],
650 "isError": true
651 }),
652 ),
653 }
654 }
655
656 async fn handle_notification(
658 &self,
659 notification: JsonRpcNotification,
660 ) -> Result<(), ServerError> {
661 match notification.method.as_str() {
662 "notifications/cancelled" => {
663 eprintln!("[MCP] Received cancellation notification (not yet implemented)");
666 }
667 "notifications/initialized" => {
668 }
670 method => {
671 eprintln!("[MCP] Unknown notification method: {}", method);
673 }
674 }
675 Ok(())
676 }
677}
678
679#[derive(Debug)]
681pub enum ServerError {
682 Io(std::io::Error),
684 Serialization(serde_json::Error),
686 ChannelClosed,
688 ResponseTimeout,
690 Transport(String),
692}
693
694impl std::fmt::Display for ServerError {
695 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
696 match self {
697 ServerError::Io(e) => write!(f, "I/O error: {}", e),
698 ServerError::Serialization(e) => write!(f, "Serialization error: {}", e),
699 ServerError::ChannelClosed => write!(f, "Channel closed"),
700 ServerError::ResponseTimeout => write!(f, "Response timeout"),
701 ServerError::Transport(e) => write!(f, "Transport error: {}", e),
702 }
703 }
704}
705
706impl std::error::Error for ServerError {
707 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
708 match self {
709 ServerError::Io(e) => Some(e),
710 ServerError::Serialization(e) => Some(e),
711 ServerError::ChannelClosed => None,
712 ServerError::ResponseTimeout => None,
713 ServerError::Transport(_) => None,
714 }
715 }
716}
717
/// Builds a `Vec<DynTool>` from a comma-separated list of tool values.
///
/// Each expression is wrapped in an `Arc` and cast to `$crate::DynTool`. A
/// trailing comma is accepted. With no arguments the macro yields an empty
/// `Vec<DynTool>`; the element type is spelled out so that `tools![]` also
/// type-checks where nothing else constrains it.
///
/// Standard-library paths are absolute (`::std::…`) so the expansion does not
/// depend on what the call site has in scope.
#[macro_export]
macro_rules! tools {
    () => {
        ::std::vec::Vec::<$crate::DynTool>::new()
    };
    ($($tool:expr),+ $(,)?) => {
        ::std::vec![
            $(::std::sync::Arc::new($tool) as $crate::DynTool),+
        ]
    };
}
730
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{McpToolDefinition, ToolContent};
    use crate::tool::BoxFuture;

    /// Tool that returns its `message` argument, or `"no message"` when the
    /// argument is absent or not a string.
    struct EchoTool;

    impl McpTool for EchoTool {
        fn definition(&self) -> McpToolDefinition {
            let schema = serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            });
            McpToolDefinition::new("echo")
                .with_description("Echo the input")
                .with_schema(schema)
        }

        fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
            Box::pin(async move {
                let text = match args.get("message").and_then(|m| m.as_str()) {
                    Some(given) => given,
                    None => "no message",
                };
                Ok(vec![ToolContent::text(text)])
            })
        }
    }

    /// Builder preset with the given name and version `1.0.0`.
    fn builder_named(name: &str) -> McpServerConfigBuilder {
        McpServerConfig::builder().name(name).version("1.0.0")
    }

    /// Request with id `1` for `method`.
    fn request(method: &str, params: Option<Value>) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: JsonRpcId::Number(1),
            method: method.to_string(),
            params,
        }
    }

    /// Feeds `request` to the server and returns the response it emits.
    async fn round_trip(
        channels: &mut McpServerChannels,
        request: JsonRpcRequest,
    ) -> JsonRpcResponse {
        channels
            .inbound_tx
            .send(ClientInbound::Request(request))
            .await
            .unwrap();

        match channels.outbound_rx.recv().await.unwrap() {
            ServerOutbound::Response(response) => response,
            _ => panic!("Expected response"),
        }
    }

    #[test]
    fn test_config_builder() {
        let config = builder_named("test-server").with_tool(EchoTool).build();

        assert_eq!(config.name(), "test-server");
        assert_eq!(config.version(), "1.0.0");
        assert_eq!(config.registry.len(), 1);
    }

    #[test]
    fn test_tools_macro() {
        let tools = tools![EchoTool];
        assert_eq!(tools.len(), 1);
    }

    #[test]
    fn test_config_with_tools() {
        let config = builder_named("test-server")
            .with_tools(tools![EchoTool])
            .build();

        assert_eq!(config.registry.len(), 1);
    }

    #[tokio::test]
    async fn test_server_creation() {
        let config = builder_named("test-server").with_tool(EchoTool).build();
        let (server, _channels) = McpServer::new(config);

        assert_eq!(server.name(), "test-server");
        assert_eq!(server.version(), "1.0.0");
        assert_eq!(server.status(), ServerStatus::Running);
    }

    #[tokio::test]
    async fn test_server_bidirectional() {
        let config = builder_named("bidir-test").with_tool(EchoTool).build();
        let (server, mut channels) = McpServer::new(config);

        let response = round_trip(&mut channels, request("tools/list", None)).await;
        assert!(response.is_success());

        // Answering a request must leave the server running.
        assert_eq!(server.status(), ServerStatus::Running);
    }

    #[tokio::test]
    async fn test_server_stop() {
        let (server, _channels) = McpServer::new(builder_named("stop-test").build());
        assert_eq!(server.status(), ServerStatus::Running);

        server.stop();
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        assert_eq!(server.status(), ServerStatus::Stopped);
    }

    #[tokio::test]
    async fn test_server_tool_call() {
        let config = builder_named("tool-test").with_tool(EchoTool).build();
        let (_server, mut channels) = McpServer::new(config);

        let call = request(
            "tools/call",
            Some(serde_json::json!({
                "name": "echo",
                "arguments": { "message": "hello world" }
            })),
        );

        let response = round_trip(&mut channels, call).await;
        assert!(response.is_success());
        let result = response.result().unwrap();
        assert!(result.to_string().contains("hello world"));
    }
}