1use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12use crate::core::{
13 error::{McpError, McpResult},
14 prompt::{Prompt, PromptHandler},
15 resource::{Resource, ResourceHandler},
16 tool::{Tool, ToolHandler},
17 PromptInfo, ResourceInfo, ToolInfo,
18};
19use crate::protocol::{messages::*, types::*, validation::*};
20use crate::transport::traits::ServerTransport;
21
/// Tunable settings for an `McpServer`.
///
/// NOTE(review): within this file only `validate_requests` is read by the
/// server; the concurrency, timeout and logging settings do not appear to be
/// enforced here yet — confirm before relying on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// Intended upper bound on requests processed at once (default: 100).
    pub max_concurrent_requests: usize,
    /// Intended per-request timeout in milliseconds (default: 30000).
    pub request_timeout_ms: u64,
    /// When `true`, `handle_request` validates the JSON-RPC envelope and the
    /// MCP parameters before dispatching (default: `true`).
    pub validate_requests: bool,
    /// Logging toggle (default: `true`).
    pub enable_logging: bool,
}

impl Default for ServerConfig {
    /// Returns the defaults documented on each field.
    fn default() -> Self {
        Self {
            max_concurrent_requests: 100,
            request_timeout_ms: 30000,
            validate_requests: true,
            enable_logging: true,
        }
    }
}
45
46pub struct McpServer {
48 info: ServerInfo,
50 capabilities: ServerCapabilities,
52 config: ServerConfig,
54 resources: Arc<RwLock<HashMap<String, Resource>>>,
56 tools: Arc<RwLock<HashMap<String, Tool>>>,
58 prompts: Arc<RwLock<HashMap<String, Prompt>>>,
60 transport: Arc<Mutex<Option<Box<dyn ServerTransport>>>>,
62 state: Arc<RwLock<ServerState>>,
64 request_counter: Arc<Mutex<u64>>,
66}
67
68#[derive(Debug, Clone, PartialEq)]
70enum ServerState {
71 Uninitialized,
73 Initializing,
75 Running,
77 Stopping,
79 Stopped,
81}
82
83impl McpServer {
84 pub fn new(name: String, version: String) -> Self {
86 Self {
87 info: ServerInfo { name, version },
88 capabilities: ServerCapabilities {
89 prompts: Some(PromptsCapability {
90 list_changed: Some(true),
91 }),
92 resources: Some(ResourcesCapability {
93 subscribe: Some(true),
94 list_changed: Some(true),
95 }),
96 tools: Some(ToolsCapability {
97 list_changed: Some(true),
98 }),
99 sampling: None,
100 },
101 config: ServerConfig::default(),
102 resources: Arc::new(RwLock::new(HashMap::new())),
103 tools: Arc::new(RwLock::new(HashMap::new())),
104 prompts: Arc::new(RwLock::new(HashMap::new())),
105 transport: Arc::new(Mutex::new(None)),
106 state: Arc::new(RwLock::new(ServerState::Uninitialized)),
107 request_counter: Arc::new(Mutex::new(0)),
108 }
109 }
110
111 pub fn with_config(name: String, version: String, config: ServerConfig) -> Self {
113 let mut server = Self::new(name, version);
114 server.config = config;
115 server
116 }
117
118 pub fn set_capabilities(&mut self, capabilities: ServerCapabilities) {
120 self.capabilities = capabilities;
121 }
122
123 pub fn info(&self) -> &ServerInfo {
125 &self.info
126 }
127
128 pub fn capabilities(&self) -> &ServerCapabilities {
130 &self.capabilities
131 }
132
133 pub fn config(&self) -> &ServerConfig {
135 &self.config
136 }
137
138 pub async fn add_resource<H>(&self, name: String, uri: String, handler: H) -> McpResult<()>
144 where
145 H: ResourceHandler + 'static,
146 {
147 let resource_info = ResourceInfo {
148 uri: uri.clone(),
149 name: name.clone(),
150 description: None,
151 mime_type: None,
152 };
153
154 validate_resource_info(&resource_info)?;
155
156 let resource = Resource::new(resource_info, handler);
157
158 {
159 let mut resources = self.resources.write().await;
160 resources.insert(uri, resource);
161 }
162
163 self.emit_resources_list_changed().await?;
165
166 Ok(())
167 }
168
169 pub async fn add_resource_detailed<H>(&self, info: ResourceInfo, handler: H) -> McpResult<()>
171 where
172 H: ResourceHandler + 'static,
173 {
174 validate_resource_info(&info)?;
175
176 let uri = info.uri.clone();
177 let resource = Resource::new(info, handler);
178
179 {
180 let mut resources = self.resources.write().await;
181 resources.insert(uri, resource);
182 }
183
184 self.emit_resources_list_changed().await?;
185
186 Ok(())
187 }
188
189 pub async fn remove_resource(&self, uri: &str) -> McpResult<bool> {
191 let removed = {
192 let mut resources = self.resources.write().await;
193 resources.remove(uri).is_some()
194 };
195
196 if removed {
197 self.emit_resources_list_changed().await?;
198 }
199
200 Ok(removed)
201 }
202
203 pub async fn list_resources(&self) -> McpResult<Vec<ResourceInfo>> {
205 let resources = self.resources.read().await;
206 Ok(resources.values().map(|r| r.info.clone()).collect())
207 }
208
209 pub async fn read_resource(&self, uri: &str) -> McpResult<Vec<ResourceContent>> {
211 let resources = self.resources.read().await;
212
213 match resources.get(uri) {
214 Some(resource) => {
215 let params = HashMap::new(); resource.handler.read(uri, ¶ms).await
217 }
218 None => Err(McpError::ResourceNotFound(uri.to_string())),
219 }
220 }
221
222 pub async fn add_tool<H>(
228 &self,
229 name: String,
230 description: Option<String>,
231 schema: Value,
232 handler: H,
233 ) -> McpResult<()>
234 where
235 H: ToolHandler + 'static,
236 {
237 let tool_info = ToolInfo {
238 name: name.clone(),
239 description,
240 input_schema: schema,
241 };
242
243 validate_tool_info(&tool_info)?;
244
245 let tool = Tool::new(
246 name.clone(),
247 tool_info.description.clone(),
248 tool_info.input_schema.clone(),
249 handler,
250 );
251
252 {
253 let mut tools = self.tools.write().await;
254 tools.insert(name, tool);
255 }
256
257 self.emit_tools_list_changed().await?;
258
259 Ok(())
260 }
261
262 pub async fn add_tool_detailed<H>(&self, info: ToolInfo, handler: H) -> McpResult<()>
264 where
265 H: ToolHandler + 'static,
266 {
267 validate_tool_info(&info)?;
268
269 let name = info.name.clone();
270 let tool = Tool::new(
271 name.clone(),
272 info.description.clone(),
273 info.input_schema.clone(),
274 handler,
275 );
276
277 {
278 let mut tools = self.tools.write().await;
279 tools.insert(name, tool);
280 }
281
282 self.emit_tools_list_changed().await?;
283
284 Ok(())
285 }
286
287 pub async fn remove_tool(&self, name: &str) -> McpResult<bool> {
289 let removed = {
290 let mut tools = self.tools.write().await;
291 tools.remove(name).is_some()
292 };
293
294 if removed {
295 self.emit_tools_list_changed().await?;
296 }
297
298 Ok(removed)
299 }
300
301 pub async fn list_tools(&self) -> McpResult<Vec<ToolInfo>> {
303 let tools = self.tools.read().await;
304 Ok(tools.values().map(|t| t.info.clone()).collect())
305 }
306
307 pub async fn call_tool(
309 &self,
310 name: &str,
311 arguments: Option<HashMap<String, Value>>,
312 ) -> McpResult<ToolResult> {
313 let tools = self.tools.read().await;
314
315 match tools.get(name) {
316 Some(tool) => {
317 if !tool.enabled {
318 return Err(McpError::ToolNotFound(format!(
319 "Tool '{}' is disabled",
320 name
321 )));
322 }
323
324 let args = arguments.unwrap_or_default();
325 tool.handler.call(args).await
326 }
327 None => Err(McpError::ToolNotFound(name.to_string())),
328 }
329 }
330
331 pub async fn add_prompt<H>(&self, info: PromptInfo, handler: H) -> McpResult<()>
337 where
338 H: PromptHandler + 'static,
339 {
340 validate_prompt_info(&info)?;
341
342 let name = info.name.clone();
343 let prompt = Prompt::new(info, handler);
344
345 {
346 let mut prompts = self.prompts.write().await;
347 prompts.insert(name, prompt);
348 }
349
350 self.emit_prompts_list_changed().await?;
351
352 Ok(())
353 }
354
355 pub async fn remove_prompt(&self, name: &str) -> McpResult<bool> {
357 let removed = {
358 let mut prompts = self.prompts.write().await;
359 prompts.remove(name).is_some()
360 };
361
362 if removed {
363 self.emit_prompts_list_changed().await?;
364 }
365
366 Ok(removed)
367 }
368
369 pub async fn list_prompts(&self) -> McpResult<Vec<PromptInfo>> {
371 let prompts = self.prompts.read().await;
372 Ok(prompts.values().map(|p| p.info.clone()).collect())
373 }
374
375 pub async fn get_prompt(
377 &self,
378 name: &str,
379 arguments: Option<HashMap<String, Value>>,
380 ) -> McpResult<PromptResult> {
381 let prompts = self.prompts.read().await;
382
383 match prompts.get(name) {
384 Some(prompt) => {
385 let args = arguments.unwrap_or_default();
386 prompt.handler.get(args).await
387 }
388 None => Err(McpError::PromptNotFound(name.to_string())),
389 }
390 }
391
392 pub async fn start<T>(&mut self, transport: T) -> McpResult<()>
398 where
399 T: ServerTransport + 'static,
400 {
401 let mut state = self.state.write().await;
402
403 match *state {
404 ServerState::Uninitialized => {
405 *state = ServerState::Initializing;
406 }
407 _ => return Err(McpError::Protocol("Server is already started".to_string())),
408 }
409
410 drop(state);
411
412 {
414 let mut transport_guard = self.transport.lock().await;
415 *transport_guard = Some(Box::new(transport));
416 }
417
418 {
420 let mut transport_guard = self.transport.lock().await;
421 if let Some(transport) = transport_guard.as_mut() {
422 transport.start().await?;
423 }
424 }
425
426 {
428 let mut state = self.state.write().await;
429 *state = ServerState::Running;
430 }
431
432 Ok(())
433 }
434
435 pub async fn stop(&self) -> McpResult<()> {
437 let mut state = self.state.write().await;
438
439 match *state {
440 ServerState::Running => {
441 *state = ServerState::Stopping;
442 }
443 ServerState::Stopped => return Ok(()),
444 _ => return Err(McpError::Protocol("Server is not running".to_string())),
445 }
446
447 drop(state);
448
449 {
451 let mut transport_guard = self.transport.lock().await;
452 if let Some(transport) = transport_guard.as_mut() {
453 transport.stop().await?;
454 }
455 }
456
457 {
459 let mut state = self.state.write().await;
460 *state = ServerState::Stopped;
461 }
462
463 Ok(())
464 }
465
466 pub async fn is_running(&self) -> bool {
468 let state = self.state.read().await;
469 matches!(*state, ServerState::Running)
470 }
471
472 pub async fn state(&self) -> ServerState {
474 let state = self.state.read().await;
475 state.clone()
476 }
477
478 pub async fn handle_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
484 if self.config.validate_requests {
486 validate_jsonrpc_request(&request)?;
487 validate_mcp_request(&request.method, request.params.as_ref())?;
488 }
489
490 let result = match request.method.as_str() {
492 methods::INITIALIZE => self.handle_initialize(request.params).await,
493 methods::PING => self.handle_ping().await,
494 methods::TOOLS_LIST => self.handle_tools_list(request.params).await,
495 methods::TOOLS_CALL => self.handle_tools_call(request.params).await,
496 methods::RESOURCES_LIST => self.handle_resources_list(request.params).await,
497 methods::RESOURCES_READ => self.handle_resources_read(request.params).await,
498 methods::RESOURCES_SUBSCRIBE => self.handle_resources_subscribe(request.params).await,
499 methods::RESOURCES_UNSUBSCRIBE => {
500 self.handle_resources_unsubscribe(request.params).await
501 }
502 methods::PROMPTS_LIST => self.handle_prompts_list(request.params).await,
503 methods::PROMPTS_GET => self.handle_prompts_get(request.params).await,
504 methods::LOGGING_SET_LEVEL => self.handle_logging_set_level(request.params).await,
505 _ => Err(McpError::Protocol(format!(
506 "Unknown method: {}",
507 request.method
508 ))),
509 };
510
511 match result {
513 Ok(result_value) => Ok(JsonRpcResponse::success(request.id, result_value)?),
514 Err(error) => {
515 let (code, message) = match error {
516 McpError::ToolNotFound(_) => (TOOL_NOT_FOUND, error.to_string()),
517 McpError::ResourceNotFound(_) => (RESOURCE_NOT_FOUND, error.to_string()),
518 McpError::PromptNotFound(_) => (PROMPT_NOT_FOUND, error.to_string()),
519 McpError::Validation(_) => (INVALID_PARAMS, error.to_string()),
520 _ => (INTERNAL_ERROR, error.to_string()),
521 };
522 Ok(JsonRpcResponse::error(request.id, code, message, None))
523 }
524 }
525 }
526
527 async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
532 let params: InitializeParams = match params {
533 Some(p) => serde_json::from_value(p)?,
534 None => {
535 return Err(McpError::Validation(
536 "Missing initialize parameters".to_string(),
537 ))
538 }
539 };
540
541 validate_initialize_params(¶ms)?;
542
543 let result = InitializeResult::new(
544 self.info.clone(),
545 self.capabilities.clone(),
546 MCP_PROTOCOL_VERSION.to_string(),
547 );
548
549 Ok(serde_json::to_value(result)?)
550 }
551
552 async fn handle_ping(&self) -> McpResult<Value> {
553 Ok(serde_json::to_value(PingResult {})?)
554 }
555
556 async fn handle_tools_list(&self, params: Option<Value>) -> McpResult<Value> {
557 let _params: ListToolsParams = match params {
558 Some(p) => serde_json::from_value(p)?,
559 None => ListToolsParams::default(),
560 };
561
562 let tools = self.list_tools().await?;
563 let result = ListToolsResult {
564 tools,
565 next_cursor: None, };
567
568 Ok(serde_json::to_value(result)?)
569 }
570
571 async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
572 let params: CallToolParams = match params {
573 Some(p) => serde_json::from_value(p)?,
574 None => {
575 return Err(McpError::Validation(
576 "Missing tool call parameters".to_string(),
577 ))
578 }
579 };
580
581 validate_call_tool_params(¶ms)?;
582
583 let result = self.call_tool(¶ms.name, params.arguments).await?;
584 Ok(serde_json::to_value(result)?)
585 }
586
587 async fn handle_resources_list(&self, params: Option<Value>) -> McpResult<Value> {
588 let _params: ListResourcesParams = match params {
589 Some(p) => serde_json::from_value(p)?,
590 None => ListResourcesParams::default(),
591 };
592
593 let resources = self.list_resources().await?;
594 let result = ListResourcesResult {
595 resources,
596 next_cursor: None, };
598
599 Ok(serde_json::to_value(result)?)
600 }
601
602 async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
603 let params: ReadResourceParams = match params {
604 Some(p) => serde_json::from_value(p)?,
605 None => {
606 return Err(McpError::Validation(
607 "Missing resource read parameters".to_string(),
608 ))
609 }
610 };
611
612 validate_read_resource_params(¶ms)?;
613
614 let contents = self.read_resource(¶ms.uri).await?;
615 let result = ReadResourceResult { contents };
616
617 Ok(serde_json::to_value(result)?)
618 }
619
620 async fn handle_resources_subscribe(&self, params: Option<Value>) -> McpResult<Value> {
621 let params: SubscribeResourceParams = match params {
622 Some(p) => serde_json::from_value(p)?,
623 None => {
624 return Err(McpError::Validation(
625 "Missing resource subscribe parameters".to_string(),
626 ))
627 }
628 };
629
630 let _uri = params.uri;
632 let result = SubscribeResourceResult {};
633
634 Ok(serde_json::to_value(result)?)
635 }
636
637 async fn handle_resources_unsubscribe(&self, params: Option<Value>) -> McpResult<Value> {
638 let params: UnsubscribeResourceParams = match params {
639 Some(p) => serde_json::from_value(p)?,
640 None => {
641 return Err(McpError::Validation(
642 "Missing resource unsubscribe parameters".to_string(),
643 ))
644 }
645 };
646
647 let _uri = params.uri;
649 let result = UnsubscribeResourceResult {};
650
651 Ok(serde_json::to_value(result)?)
652 }
653
654 async fn handle_prompts_list(&self, params: Option<Value>) -> McpResult<Value> {
655 let _params: ListPromptsParams = match params {
656 Some(p) => serde_json::from_value(p)?,
657 None => ListPromptsParams::default(),
658 };
659
660 let prompts = self.list_prompts().await?;
661 let result = ListPromptsResult {
662 prompts,
663 next_cursor: None, };
665
666 Ok(serde_json::to_value(result)?)
667 }
668
669 async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
670 let params: GetPromptParams = match params {
671 Some(p) => serde_json::from_value(p)?,
672 None => {
673 return Err(McpError::Validation(
674 "Missing prompt get parameters".to_string(),
675 ))
676 }
677 };
678
679 validate_get_prompt_params(¶ms)?;
680
681 let result = self.get_prompt(¶ms.name, params.arguments).await?;
682 Ok(serde_json::to_value(result)?)
683 }
684
685 async fn handle_logging_set_level(&self, params: Option<Value>) -> McpResult<Value> {
686 let _params: SetLoggingLevelParams = match params {
687 Some(p) => serde_json::from_value(p)?,
688 None => {
689 return Err(McpError::Validation(
690 "Missing logging level parameters".to_string(),
691 ))
692 }
693 };
694
695 let result = SetLoggingLevelResult {};
697 Ok(serde_json::to_value(result)?)
698 }
699
700 async fn emit_resources_list_changed(&self) -> McpResult<()> {
705 let notification = JsonRpcNotification::new(
706 methods::RESOURCES_LIST_CHANGED.to_string(),
707 Some(ResourceListChangedParams {}),
708 )?;
709
710 self.send_notification(notification).await
711 }
712
713 async fn emit_tools_list_changed(&self) -> McpResult<()> {
714 let notification = JsonRpcNotification::new(
715 methods::TOOLS_LIST_CHANGED.to_string(),
716 Some(ToolListChangedParams {}),
717 )?;
718
719 self.send_notification(notification).await
720 }
721
722 async fn emit_prompts_list_changed(&self) -> McpResult<()> {
723 let notification = JsonRpcNotification::new(
724 methods::PROMPTS_LIST_CHANGED.to_string(),
725 Some(PromptListChangedParams {}),
726 )?;
727
728 self.send_notification(notification).await
729 }
730
731 async fn send_notification(&self, notification: JsonRpcNotification) -> McpResult<()> {
733 let mut transport_guard = self.transport.lock().await;
734 if let Some(transport) = transport_guard.as_mut() {
735 transport.send_notification(notification).await?;
736 }
737 Ok(())
738 }
739
740 async fn next_request_id(&self) -> u64 {
745 let mut counter = self.request_counter.lock().await;
746 *counter += 1;
747 *counter
748 }
749}
750
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a server with the fixed name/version used throughout these tests.
    fn test_server() -> McpServer {
        McpServer::new("test-server".to_string(), "1.0.0".to_string())
    }

    #[tokio::test]
    async fn test_server_creation() {
        let server = test_server();
        assert_eq!(server.info().name, "test-server");
        assert_eq!(server.info().version, "1.0.0");
        assert!(!server.is_running().await);
        assert!(server.state().await == ServerState::Uninitialized);
    }

    #[tokio::test]
    async fn test_with_config() {
        let config = ServerConfig {
            validate_requests: false,
            ..ServerConfig::default()
        };
        let server =
            McpServer::with_config("test-server".to_string(), "1.0.0".to_string(), config);
        assert!(!server.config().validate_requests);
        assert_eq!(server.config().max_concurrent_requests, 100);
        assert_eq!(server.info().name, "test-server");
    }

    #[tokio::test]
    async fn test_tool_management() {
        let server = test_server();

        // Minimal object schema with a single string property.
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        struct TestToolHandler;

        #[async_trait::async_trait]
        impl ToolHandler for TestToolHandler {
            async fn call(&self, _arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
                Ok(ToolResult {
                    content: vec![Content::text("Hello from tool")],
                    is_error: None,
                })
            }
        }

        server
            .add_tool(
                "test_tool".to_string(),
                Some("A test tool".to_string()),
                schema,
                TestToolHandler,
            )
            .await
            .unwrap();

        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "test_tool");

        let result = server.call_tool("test_tool", None).await.unwrap();
        assert_eq!(result.content.len(), 1);

        // Removing reports success once, after which the tool is gone.
        assert!(server.remove_tool("test_tool").await.unwrap());
        assert!(!server.remove_tool("test_tool").await.unwrap());
        assert!(server.list_tools().await.unwrap().is_empty());
        assert!(matches!(
            server.call_tool("test_tool", None).await,
            Err(McpError::ToolNotFound(_))
        ));
    }

    #[tokio::test]
    async fn test_missing_entries_are_reported() {
        let server = test_server();

        assert!(matches!(
            server.call_tool("missing", None).await,
            Err(McpError::ToolNotFound(_))
        ));
        assert!(matches!(
            server.read_resource("file:///missing").await,
            Err(McpError::ResourceNotFound(_))
        ));
        assert!(matches!(
            server.get_prompt("missing", None).await,
            Err(McpError::PromptNotFound(_))
        ));

        assert!(!server.remove_tool("missing").await.unwrap());
        assert!(!server.remove_resource("file:///missing").await.unwrap());
        assert!(!server.remove_prompt("missing").await.unwrap());

        assert!(server.list_resources().await.unwrap().is_empty());
        assert!(server.list_prompts().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_stop_before_start_fails() {
        let server = test_server();
        assert!(matches!(server.stop().await, Err(McpError::Protocol(_))));
        assert!(server.state().await == ServerState::Uninitialized);
        assert!(!server.is_running().await);
    }

    #[tokio::test]
    async fn test_initialize_request() {
        let server = test_server();

        let init_params = InitializeParams::new(
            ClientInfo {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
            },
            ClientCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );

        let request =
            JsonRpcRequest::new(json!(1), methods::INITIALIZE.to_string(), Some(init_params))
                .unwrap();

        let response = server.handle_request(request).await.unwrap();
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }
}