1use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12use crate::core::{
13 error::{McpError, McpResult},
14 prompt::{Prompt, PromptHandler},
15 resource::{Resource, ResourceHandler},
16 tool::{Tool, ToolHandler},
17 PromptInfo, ResourceInfo, ToolInfo,
18};
19use crate::protocol::{messages::*, types::*, validation::*};
20use crate::transport::traits::ServerTransport;
21
/// Runtime settings for an MCP server.
///
/// Only `validate_requests` is read by `McpServer` itself; the other fields are
/// exposed through `McpServer::config` and are presumably consumed elsewhere
/// (e.g. by transports) — NOTE(review): confirm where they are enforced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// Intended cap on simultaneously processed requests.
    pub max_concurrent_requests: usize,
    /// Intended per-request timeout, in milliseconds.
    pub request_timeout_ms: u64,
    /// When `true`, incoming requests are checked for JSON-RPC and MCP
    /// well-formedness before being dispatched.
    pub validate_requests: bool,
    /// Whether logging is enabled.
    pub enable_logging: bool,
}

impl Default for ServerConfig {
    /// 100 concurrent requests, a 30 000 ms timeout, request validation on and
    /// logging on.
    fn default() -> Self {
        Self {
            max_concurrent_requests: 100,
            request_timeout_ms: 30000,
            validate_requests: true,
            enable_logging: true,
        }
    }
}
45
/// A Model Context Protocol server: registries of resources, tools and prompts,
/// plus the transport and lifecycle state used to serve them.
///
/// Registries and runtime state sit behind `Arc`-wrapped `tokio` locks, which is
/// why most methods on the server take `&self`.
pub struct McpServer {
    /// Name and version reported to clients in the `initialize` response.
    info: ServerInfo,
    /// Capabilities reported to clients in the `initialize` response.
    capabilities: ServerCapabilities,
    /// Runtime settings.
    config: ServerConfig,
    /// Registered resources, keyed by URI.
    resources: Arc<RwLock<HashMap<String, Resource>>>,
    /// Registered tools, keyed by tool name.
    tools: Arc<RwLock<HashMap<String, Tool>>>,
    /// Registered prompts, keyed by prompt name.
    prompts: Arc<RwLock<HashMap<String, Prompt>>>,
    /// Transport attached by `start`; `None` before that, in which case
    /// notifications are silently dropped.
    transport: Arc<Mutex<Option<Box<dyn ServerTransport>>>>,
    /// Current lifecycle phase.
    state: Arc<RwLock<ServerState>>,
    // Only read by `next_request_id`, which nothing calls yet — presumably
    // reserved for server-initiated requests (confirm), hence the allow.
    #[allow(dead_code)]
    request_counter: Arc<Mutex<u64>>,
}
68
/// Lifecycle phase of an MCP server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerState {
    /// Created but never started; the only phase from which `start` proceeds.
    Uninitialized,
    /// `start` is attaching and starting the transport.
    Initializing,
    /// The transport has started.
    Running,
    /// `stop` is shutting the transport down.
    Stopping,
    /// The transport has been stopped.
    Stopped,
}
83
84impl McpServer {
85 pub fn new(name: String, version: String) -> Self {
87 Self {
88 info: ServerInfo { name, version },
89 capabilities: ServerCapabilities {
90 prompts: Some(PromptsCapability {
91 list_changed: Some(true),
92 }),
93 resources: Some(ResourcesCapability {
94 subscribe: Some(true),
95 list_changed: Some(true),
96 }),
97 tools: Some(ToolsCapability {
98 list_changed: Some(true),
99 }),
100 sampling: None,
101 },
102 config: ServerConfig::default(),
103 resources: Arc::new(RwLock::new(HashMap::new())),
104 tools: Arc::new(RwLock::new(HashMap::new())),
105 prompts: Arc::new(RwLock::new(HashMap::new())),
106 transport: Arc::new(Mutex::new(None)),
107 state: Arc::new(RwLock::new(ServerState::Uninitialized)),
108 request_counter: Arc::new(Mutex::new(0)),
109 }
110 }
111
112 pub fn with_config(name: String, version: String, config: ServerConfig) -> Self {
114 let mut server = Self::new(name, version);
115 server.config = config;
116 server
117 }
118
119 pub fn set_capabilities(&mut self, capabilities: ServerCapabilities) {
121 self.capabilities = capabilities;
122 }
123
124 pub fn info(&self) -> &ServerInfo {
126 &self.info
127 }
128
129 pub fn capabilities(&self) -> &ServerCapabilities {
131 &self.capabilities
132 }
133
134 pub fn config(&self) -> &ServerConfig {
136 &self.config
137 }
138
139 pub async fn add_resource<H>(&self, name: String, uri: String, handler: H) -> McpResult<()>
145 where
146 H: ResourceHandler + 'static,
147 {
148 let resource_info = ResourceInfo {
149 uri: uri.clone(),
150 name: name.clone(),
151 description: None,
152 mime_type: None,
153 };
154
155 validate_resource_info(&resource_info)?;
156
157 let resource = Resource::new(resource_info, handler);
158
159 {
160 let mut resources = self.resources.write().await;
161 resources.insert(uri, resource);
162 }
163
164 self.emit_resources_list_changed().await?;
166
167 Ok(())
168 }
169
170 pub async fn add_resource_detailed<H>(&self, info: ResourceInfo, handler: H) -> McpResult<()>
172 where
173 H: ResourceHandler + 'static,
174 {
175 validate_resource_info(&info)?;
176
177 let uri = info.uri.clone();
178 let resource = Resource::new(info, handler);
179
180 {
181 let mut resources = self.resources.write().await;
182 resources.insert(uri, resource);
183 }
184
185 self.emit_resources_list_changed().await?;
186
187 Ok(())
188 }
189
190 pub async fn remove_resource(&self, uri: &str) -> McpResult<bool> {
192 let removed = {
193 let mut resources = self.resources.write().await;
194 resources.remove(uri).is_some()
195 };
196
197 if removed {
198 self.emit_resources_list_changed().await?;
199 }
200
201 Ok(removed)
202 }
203
204 pub async fn list_resources(&self) -> McpResult<Vec<ResourceInfo>> {
206 let resources = self.resources.read().await;
207 Ok(resources.values().map(|r| r.info.clone()).collect())
208 }
209
210 pub async fn read_resource(&self, uri: &str) -> McpResult<Vec<ResourceContent>> {
212 let resources = self.resources.read().await;
213
214 match resources.get(uri) {
215 Some(resource) => {
216 let params = HashMap::new(); resource.handler.read(uri, ¶ms).await
218 }
219 None => Err(McpError::ResourceNotFound(uri.to_string())),
220 }
221 }
222
223 pub async fn add_tool<H>(
229 &self,
230 name: String,
231 description: Option<String>,
232 schema: Value,
233 handler: H,
234 ) -> McpResult<()>
235 where
236 H: ToolHandler + 'static,
237 {
238 let tool_info = ToolInfo {
239 name: name.clone(),
240 description,
241 input_schema: schema,
242 };
243
244 validate_tool_info(&tool_info)?;
245
246 let tool = Tool::new(
247 name.clone(),
248 tool_info.description.clone(),
249 tool_info.input_schema.clone(),
250 handler,
251 );
252
253 {
254 let mut tools = self.tools.write().await;
255 tools.insert(name, tool);
256 }
257
258 self.emit_tools_list_changed().await?;
259
260 Ok(())
261 }
262
263 pub async fn add_tool_detailed<H>(&self, info: ToolInfo, handler: H) -> McpResult<()>
265 where
266 H: ToolHandler + 'static,
267 {
268 validate_tool_info(&info)?;
269
270 let name = info.name.clone();
271 let tool = Tool::new(
272 name.clone(),
273 info.description.clone(),
274 info.input_schema.clone(),
275 handler,
276 );
277
278 {
279 let mut tools = self.tools.write().await;
280 tools.insert(name, tool);
281 }
282
283 self.emit_tools_list_changed().await?;
284
285 Ok(())
286 }
287
288 pub async fn remove_tool(&self, name: &str) -> McpResult<bool> {
290 let removed = {
291 let mut tools = self.tools.write().await;
292 tools.remove(name).is_some()
293 };
294
295 if removed {
296 self.emit_tools_list_changed().await?;
297 }
298
299 Ok(removed)
300 }
301
302 pub async fn list_tools(&self) -> McpResult<Vec<ToolInfo>> {
304 let tools = self.tools.read().await;
305 Ok(tools.values().map(|t| t.info.clone()).collect())
306 }
307
308 pub async fn call_tool(
310 &self,
311 name: &str,
312 arguments: Option<HashMap<String, Value>>,
313 ) -> McpResult<ToolResult> {
314 let tools = self.tools.read().await;
315
316 match tools.get(name) {
317 Some(tool) => {
318 if !tool.enabled {
319 return Err(McpError::ToolNotFound(format!(
320 "Tool '{}' is disabled",
321 name
322 )));
323 }
324
325 let args = arguments.unwrap_or_default();
326 tool.handler.call(args).await
327 }
328 None => Err(McpError::ToolNotFound(name.to_string())),
329 }
330 }
331
332 pub async fn add_prompt<H>(&self, info: PromptInfo, handler: H) -> McpResult<()>
338 where
339 H: PromptHandler + 'static,
340 {
341 validate_prompt_info(&info)?;
342
343 let name = info.name.clone();
344 let prompt = Prompt::new(info, handler);
345
346 {
347 let mut prompts = self.prompts.write().await;
348 prompts.insert(name, prompt);
349 }
350
351 self.emit_prompts_list_changed().await?;
352
353 Ok(())
354 }
355
356 pub async fn remove_prompt(&self, name: &str) -> McpResult<bool> {
358 let removed = {
359 let mut prompts = self.prompts.write().await;
360 prompts.remove(name).is_some()
361 };
362
363 if removed {
364 self.emit_prompts_list_changed().await?;
365 }
366
367 Ok(removed)
368 }
369
370 pub async fn list_prompts(&self) -> McpResult<Vec<PromptInfo>> {
372 let prompts = self.prompts.read().await;
373 Ok(prompts.values().map(|p| p.info.clone()).collect())
374 }
375
376 pub async fn get_prompt(
378 &self,
379 name: &str,
380 arguments: Option<HashMap<String, Value>>,
381 ) -> McpResult<PromptResult> {
382 let prompts = self.prompts.read().await;
383
384 match prompts.get(name) {
385 Some(prompt) => {
386 let args = arguments.unwrap_or_default();
387 prompt.handler.get(args).await
388 }
389 None => Err(McpError::PromptNotFound(name.to_string())),
390 }
391 }
392
393 pub async fn start<T>(&mut self, transport: T) -> McpResult<()>
399 where
400 T: ServerTransport + 'static,
401 {
402 let mut state = self.state.write().await;
403
404 match *state {
405 ServerState::Uninitialized => {
406 *state = ServerState::Initializing;
407 }
408 _ => return Err(McpError::Protocol("Server is already started".to_string())),
409 }
410
411 drop(state);
412
413 {
415 let mut transport_guard = self.transport.lock().await;
416 *transport_guard = Some(Box::new(transport));
417 }
418
419 {
421 let mut transport_guard = self.transport.lock().await;
422 if let Some(transport) = transport_guard.as_mut() {
423 transport.start().await?;
424 }
425 }
426
427 {
429 let mut state = self.state.write().await;
430 *state = ServerState::Running;
431 }
432
433 Ok(())
434 }
435
436 pub async fn stop(&self) -> McpResult<()> {
438 let mut state = self.state.write().await;
439
440 match *state {
441 ServerState::Running => {
442 *state = ServerState::Stopping;
443 }
444 ServerState::Stopped => return Ok(()),
445 _ => return Err(McpError::Protocol("Server is not running".to_string())),
446 }
447
448 drop(state);
449
450 {
452 let mut transport_guard = self.transport.lock().await;
453 if let Some(transport) = transport_guard.as_mut() {
454 transport.stop().await?;
455 }
456 }
457
458 {
460 let mut state = self.state.write().await;
461 *state = ServerState::Stopped;
462 }
463
464 Ok(())
465 }
466
467 pub async fn is_running(&self) -> bool {
469 let state = self.state.read().await;
470 matches!(*state, ServerState::Running)
471 }
472
473 pub async fn state(&self) -> ServerState {
475 let state = self.state.read().await;
476 state.clone()
477 }
478
479 pub async fn handle_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
485 if self.config.validate_requests {
487 validate_jsonrpc_request(&request)?;
488 validate_mcp_request(&request.method, request.params.as_ref())?;
489 }
490
491 let result = match request.method.as_str() {
493 methods::INITIALIZE => self.handle_initialize(request.params).await,
494 methods::PING => self.handle_ping().await,
495 methods::TOOLS_LIST => self.handle_tools_list(request.params).await,
496 methods::TOOLS_CALL => self.handle_tools_call(request.params).await,
497 methods::RESOURCES_LIST => self.handle_resources_list(request.params).await,
498 methods::RESOURCES_READ => self.handle_resources_read(request.params).await,
499 methods::RESOURCES_SUBSCRIBE => self.handle_resources_subscribe(request.params).await,
500 methods::RESOURCES_UNSUBSCRIBE => {
501 self.handle_resources_unsubscribe(request.params).await
502 }
503 methods::PROMPTS_LIST => self.handle_prompts_list(request.params).await,
504 methods::PROMPTS_GET => self.handle_prompts_get(request.params).await,
505 methods::LOGGING_SET_LEVEL => self.handle_logging_set_level(request.params).await,
506 _ => Err(McpError::Protocol(format!(
507 "Unknown method: {}",
508 request.method
509 ))),
510 };
511
512 match result {
514 Ok(result_value) => Ok(JsonRpcResponse::success(request.id, result_value)?),
515 Err(error) => {
516 let (code, message) = match error {
517 McpError::ToolNotFound(_) => (TOOL_NOT_FOUND, error.to_string()),
518 McpError::ResourceNotFound(_) => (RESOURCE_NOT_FOUND, error.to_string()),
519 McpError::PromptNotFound(_) => (PROMPT_NOT_FOUND, error.to_string()),
520 McpError::Validation(_) => (INVALID_PARAMS, error.to_string()),
521 _ => (INTERNAL_ERROR, error.to_string()),
522 };
523 Ok(JsonRpcResponse::error(request.id, code, message, None))
524 }
525 }
526 }
527
528 async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
533 let params: InitializeParams = match params {
534 Some(p) => serde_json::from_value(p)?,
535 None => {
536 return Err(McpError::Validation(
537 "Missing initialize parameters".to_string(),
538 ))
539 }
540 };
541
542 validate_initialize_params(¶ms)?;
543
544 let result = InitializeResult::new(
545 self.info.clone(),
546 self.capabilities.clone(),
547 MCP_PROTOCOL_VERSION.to_string(),
548 );
549
550 Ok(serde_json::to_value(result)?)
551 }
552
553 async fn handle_ping(&self) -> McpResult<Value> {
554 Ok(serde_json::to_value(PingResult {})?)
555 }
556
557 async fn handle_tools_list(&self, params: Option<Value>) -> McpResult<Value> {
558 let _params: ListToolsParams = match params {
559 Some(p) => serde_json::from_value(p)?,
560 None => ListToolsParams::default(),
561 };
562
563 let tools = self.list_tools().await?;
564 let result = ListToolsResult {
565 tools,
566 next_cursor: None, };
568
569 Ok(serde_json::to_value(result)?)
570 }
571
572 async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
573 let params: CallToolParams = match params {
574 Some(p) => serde_json::from_value(p)?,
575 None => {
576 return Err(McpError::Validation(
577 "Missing tool call parameters".to_string(),
578 ))
579 }
580 };
581
582 validate_call_tool_params(¶ms)?;
583
584 let result = self.call_tool(¶ms.name, params.arguments).await?;
585 Ok(serde_json::to_value(result)?)
586 }
587
588 async fn handle_resources_list(&self, params: Option<Value>) -> McpResult<Value> {
589 let _params: ListResourcesParams = match params {
590 Some(p) => serde_json::from_value(p)?,
591 None => ListResourcesParams::default(),
592 };
593
594 let resources = self.list_resources().await?;
595 let result = ListResourcesResult {
596 resources,
597 next_cursor: None, };
599
600 Ok(serde_json::to_value(result)?)
601 }
602
603 async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
604 let params: ReadResourceParams = match params {
605 Some(p) => serde_json::from_value(p)?,
606 None => {
607 return Err(McpError::Validation(
608 "Missing resource read parameters".to_string(),
609 ))
610 }
611 };
612
613 validate_read_resource_params(¶ms)?;
614
615 let contents = self.read_resource(¶ms.uri).await?;
616 let result = ReadResourceResult { contents };
617
618 Ok(serde_json::to_value(result)?)
619 }
620
621 async fn handle_resources_subscribe(&self, params: Option<Value>) -> McpResult<Value> {
622 let params: SubscribeResourceParams = match params {
623 Some(p) => serde_json::from_value(p)?,
624 None => {
625 return Err(McpError::Validation(
626 "Missing resource subscribe parameters".to_string(),
627 ))
628 }
629 };
630
631 let _uri = params.uri;
633 let result = SubscribeResourceResult {};
634
635 Ok(serde_json::to_value(result)?)
636 }
637
638 async fn handle_resources_unsubscribe(&self, params: Option<Value>) -> McpResult<Value> {
639 let params: UnsubscribeResourceParams = match params {
640 Some(p) => serde_json::from_value(p)?,
641 None => {
642 return Err(McpError::Validation(
643 "Missing resource unsubscribe parameters".to_string(),
644 ))
645 }
646 };
647
648 let _uri = params.uri;
650 let result = UnsubscribeResourceResult {};
651
652 Ok(serde_json::to_value(result)?)
653 }
654
655 async fn handle_prompts_list(&self, params: Option<Value>) -> McpResult<Value> {
656 let _params: ListPromptsParams = match params {
657 Some(p) => serde_json::from_value(p)?,
658 None => ListPromptsParams::default(),
659 };
660
661 let prompts = self.list_prompts().await?;
662 let result = ListPromptsResult {
663 prompts,
664 next_cursor: None, };
666
667 Ok(serde_json::to_value(result)?)
668 }
669
670 async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
671 let params: GetPromptParams = match params {
672 Some(p) => serde_json::from_value(p)?,
673 None => {
674 return Err(McpError::Validation(
675 "Missing prompt get parameters".to_string(),
676 ))
677 }
678 };
679
680 validate_get_prompt_params(¶ms)?;
681
682 let result = self.get_prompt(¶ms.name, params.arguments).await?;
683 Ok(serde_json::to_value(result)?)
684 }
685
686 async fn handle_logging_set_level(&self, params: Option<Value>) -> McpResult<Value> {
687 let _params: SetLoggingLevelParams = match params {
688 Some(p) => serde_json::from_value(p)?,
689 None => {
690 return Err(McpError::Validation(
691 "Missing logging level parameters".to_string(),
692 ))
693 }
694 };
695
696 let result = SetLoggingLevelResult {};
698 Ok(serde_json::to_value(result)?)
699 }
700
701 async fn emit_resources_list_changed(&self) -> McpResult<()> {
706 let notification = JsonRpcNotification::new(
707 methods::RESOURCES_LIST_CHANGED.to_string(),
708 Some(ResourceListChangedParams {}),
709 )?;
710
711 self.send_notification(notification).await
712 }
713
714 async fn emit_tools_list_changed(&self) -> McpResult<()> {
715 let notification = JsonRpcNotification::new(
716 methods::TOOLS_LIST_CHANGED.to_string(),
717 Some(ToolListChangedParams {}),
718 )?;
719
720 self.send_notification(notification).await
721 }
722
723 async fn emit_prompts_list_changed(&self) -> McpResult<()> {
724 let notification = JsonRpcNotification::new(
725 methods::PROMPTS_LIST_CHANGED.to_string(),
726 Some(PromptListChangedParams {}),
727 )?;
728
729 self.send_notification(notification).await
730 }
731
732 async fn send_notification(&self, notification: JsonRpcNotification) -> McpResult<()> {
734 let mut transport_guard = self.transport.lock().await;
735 if let Some(transport) = transport_guard.as_mut() {
736 transport.send_notification(notification).await?;
737 }
738 Ok(())
739 }
740
741 #[allow(dead_code)]
746 async fn next_request_id(&self) -> u64 {
747 let mut counter = self.request_counter.lock().await;
748 *counter += 1;
749 *counter
750 }
751}
752
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Tool handler that always answers with a single fixed text item.
    struct TestToolHandler;

    #[async_trait::async_trait]
    impl ToolHandler for TestToolHandler {
        async fn call(&self, _arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
            Ok(ToolResult {
                content: vec![Content::text("Hello from tool")],
                is_error: None,
            })
        }
    }

    /// A fresh server with the identity used throughout these tests.
    fn test_server() -> McpServer {
        McpServer::new("test-server".to_string(), "1.0.0".to_string())
    }

    /// Minimal object schema with a single string property.
    fn object_schema() -> Value {
        json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        })
    }

    /// Registers `TestToolHandler` under `name`.
    async fn register_test_tool(server: &McpServer, name: &str) {
        server
            .add_tool(
                name.to_string(),
                Some("A test tool".to_string()),
                object_schema(),
                TestToolHandler,
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_server_creation() {
        let server = test_server();
        assert_eq!(server.info().name, "test-server");
        assert_eq!(server.info().version, "1.0.0");
        assert!(!server.is_running().await);
        assert_eq!(server.state().await, ServerState::Uninitialized);
    }

    #[tokio::test]
    async fn test_tool_management() {
        let server = test_server();
        register_test_tool(&server, "test_tool").await;

        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "test_tool");

        let result = server.call_tool("test_tool", None).await.unwrap();
        assert_eq!(result.content.len(), 1);
    }

    #[tokio::test]
    async fn test_tool_removal() {
        let server = test_server();
        register_test_tool(&server, "test_tool").await;

        // The first removal finds the tool; the second reports nothing removed.
        assert!(server.remove_tool("test_tool").await.unwrap());
        assert!(!server.remove_tool("test_tool").await.unwrap());
        assert!(server.list_tools().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_call_missing_tool() {
        let server = test_server();
        let outcome = server.call_tool("missing", None).await;
        assert!(matches!(outcome, Err(McpError::ToolNotFound(_))));
    }

    #[tokio::test]
    async fn test_initialize_request() {
        let server = test_server();

        let init_params = InitializeParams::new(
            ClientInfo {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
            },
            ClientCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );

        let request =
            JsonRpcRequest::new(json!(1), methods::INITIALIZE.to_string(), Some(init_params))
                .unwrap();

        let response = server.handle_request(request).await.unwrap();
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }
}