1use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12use crate::core::{
13 PromptInfo, ResourceInfo, ToolInfo,
14 error::{McpError, McpResult},
15 prompt::{Prompt, PromptHandler},
16 resource::{Resource, ResourceHandler},
17 tool::{Tool, ToolHandler},
18};
19use crate::protocol::{error_codes::*, messages::*, methods, types::*, validation::*};
20use crate::transport::traits::ServerTransport;
21
/// Tunables for an MCP server instance.
///
/// NOTE(review): within this file only `validate_requests` is consulted (by
/// `McpServer::handle_request`); the other fields are carried but not enforced
/// here — confirm whether another module reads them.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Upper bound on concurrently handled requests (default: 100).
    pub max_concurrent_requests: usize,
    /// Per-request timeout in milliseconds (default: 30 000).
    pub request_timeout_ms: u64,
    /// Validate the JSON-RPC envelope and MCP parameters before dispatching
    /// (default: `true`).
    pub validate_requests: bool,
    /// Whether logging is enabled (default: `true`).
    pub enable_logging: bool,
}

impl Default for ServerConfig {
    /// 100 concurrent requests, a 30 s timeout, validation and logging enabled.
    fn default() -> Self {
        ServerConfig {
            validate_requests: true,
            enable_logging: true,
            max_concurrent_requests: 100,
            request_timeout_ms: 30_000,
        }
    }
}
45
46pub struct McpServer {
48 info: ServerInfo,
50 capabilities: ServerCapabilities,
52 config: ServerConfig,
54 resources: Arc<RwLock<HashMap<String, Resource>>>,
56 tools: Arc<RwLock<HashMap<String, Tool>>>,
58 prompts: Arc<RwLock<HashMap<String, Prompt>>>,
60 transport: Arc<Mutex<Option<Box<dyn ServerTransport>>>>,
62 state: Arc<RwLock<ServerState>>,
64 #[allow(dead_code)]
66 request_counter: Arc<Mutex<u64>>,
67}
68
/// Lifecycle state of an MCP server.
///
/// Fieldless, so it is `Copy` and fully comparable (`Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerState {
    /// Created but `start` has not been called (or a previous `start` failed).
    Uninitialized,
    /// `start` is in progress.
    Initializing,
    /// The transport has been started.
    Running,
    /// `stop` is in progress.
    Stopping,
    /// `stop` completed.
    Stopped,
}
83
84impl McpServer {
85 pub fn new(name: String, version: String) -> Self {
87 Self {
88 info: ServerInfo::new(name, version),
89 capabilities: ServerCapabilities {
90 prompts: Some(PromptsCapability {
91 list_changed: Some(true),
92 }),
93 resources: Some(ResourcesCapability {
94 subscribe: Some(true),
95 list_changed: Some(true),
96 }),
97 tools: Some(ToolsCapability {
98 list_changed: Some(true),
99 }),
100 sampling: None,
101 logging: None,
102 experimental: None,
103 completions: None,
104 },
105 config: ServerConfig::default(),
106 resources: Arc::new(RwLock::new(HashMap::new())),
107 tools: Arc::new(RwLock::new(HashMap::new())),
108 prompts: Arc::new(RwLock::new(HashMap::new())),
109 transport: Arc::new(Mutex::new(None)),
110 state: Arc::new(RwLock::new(ServerState::Uninitialized)),
111 request_counter: Arc::new(Mutex::new(0)),
112 }
113 }
114
115 pub fn with_config(name: String, version: String, config: ServerConfig) -> Self {
117 let mut server = Self::new(name, version);
118 server.config = config;
119 server
120 }
121
122 pub fn set_capabilities(&mut self, capabilities: ServerCapabilities) {
124 self.capabilities = capabilities;
125 }
126
127 pub fn info(&self) -> &ServerInfo {
129 &self.info
130 }
131
132 pub fn name(&self) -> &str {
134 &self.info.name
135 }
136
137 pub fn version(&self) -> &str {
139 &self.info.version
140 }
141
142 pub fn capabilities(&self) -> &ServerCapabilities {
144 &self.capabilities
145 }
146
147 pub fn config(&self) -> &ServerConfig {
149 &self.config
150 }
151
152 pub async fn add_resource<H>(&self, name: String, uri: String, handler: H) -> McpResult<()>
158 where
159 H: ResourceHandler + 'static,
160 {
161 let resource_info = ResourceInfo {
162 uri: uri.clone(),
163 name: name.clone(),
164 description: None,
165 mime_type: None,
166 annotations: None,
167 size: None,
168 title: None,
169 meta: None,
170 };
171
172 validate_resource_info(&resource_info)?;
173
174 let resource = Resource::new(resource_info, handler);
175
176 {
177 let mut resources = self.resources.write().await;
178 resources.insert(uri, resource);
179 }
180
181 self.emit_resources_list_changed().await?;
183
184 Ok(())
185 }
186
187 pub async fn add_resource_detailed<H>(&self, info: ResourceInfo, handler: H) -> McpResult<()>
189 where
190 H: ResourceHandler + 'static,
191 {
192 validate_resource_info(&info)?;
193
194 let uri = info.uri.clone();
195 let resource = Resource::new(info, handler);
196
197 {
198 let mut resources = self.resources.write().await;
199 resources.insert(uri, resource);
200 }
201
202 self.emit_resources_list_changed().await?;
203
204 Ok(())
205 }
206
207 pub async fn remove_resource(&self, uri: &str) -> McpResult<bool> {
209 let removed = {
210 let mut resources = self.resources.write().await;
211 resources.remove(uri).is_some()
212 };
213
214 if removed {
215 self.emit_resources_list_changed().await?;
216 }
217
218 Ok(removed)
219 }
220
221 pub async fn list_resources(&self) -> McpResult<Vec<ResourceInfo>> {
223 let resources = self.resources.read().await;
224 Ok(resources.values().map(|r| r.info.clone()).collect())
225 }
226
227 pub async fn read_resource(&self, uri: &str) -> McpResult<Vec<ResourceContents>> {
229 let resources = self.resources.read().await;
230
231 match resources.get(uri) {
232 Some(resource) => {
233 let params = HashMap::new(); resource.handler.read(uri, ¶ms).await
235 }
236 None => Err(McpError::ResourceNotFound(uri.to_string())),
237 }
238 }
239
240 pub async fn add_tool<H>(
246 &self,
247 name: String,
248 description: Option<String>,
249 schema: Value,
250 handler: H,
251 ) -> McpResult<()>
252 where
253 H: ToolHandler + 'static,
254 {
255 let tool_schema = ToolInputSchema {
256 schema_type: "object".to_string(),
257 properties: schema
258 .get("properties")
259 .and_then(|p| p.as_object())
260 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect()),
261 required: schema.get("required").and_then(|r| {
262 r.as_array().map(|arr| {
263 arr.iter()
264 .filter_map(|v| v.as_str().map(|s| s.to_string()))
265 .collect()
266 })
267 }),
268 additional_properties: schema
269 .as_object()
270 .unwrap_or(&serde_json::Map::new())
271 .iter()
272 .map(|(k, v)| (k.clone(), v.clone()))
273 .collect(),
274 };
275
276 let tool_info = ToolInfo {
277 name: name.clone(),
278 description,
279 input_schema: tool_schema,
280 annotations: None,
281 title: None,
282 meta: None,
283 };
284
285 validate_tool_info(&tool_info)?;
286
287 let tool = Tool::new(
288 name.clone(),
289 tool_info.description.clone(),
290 serde_json::to_value(&tool_info.input_schema)?,
291 handler,
292 );
293
294 {
295 let mut tools = self.tools.write().await;
296 tools.insert(name, tool);
297 }
298
299 self.emit_tools_list_changed().await?;
300
301 Ok(())
302 }
303
304 pub async fn add_tool_detailed<H>(&self, info: ToolInfo, handler: H) -> McpResult<()>
306 where
307 H: ToolHandler + 'static,
308 {
309 validate_tool_info(&info)?;
310
311 let name = info.name.clone();
312 let tool = Tool::new(
313 name.clone(),
314 info.description.clone(),
315 serde_json::to_value(&info.input_schema)?,
316 handler,
317 );
318
319 {
320 let mut tools = self.tools.write().await;
321 tools.insert(name, tool);
322 }
323
324 self.emit_tools_list_changed().await?;
325
326 Ok(())
327 }
328
329 pub async fn remove_tool(&self, name: &str) -> McpResult<bool> {
331 let removed = {
332 let mut tools = self.tools.write().await;
333 tools.remove(name).is_some()
334 };
335
336 if removed {
337 self.emit_tools_list_changed().await?;
338 }
339
340 Ok(removed)
341 }
342
343 pub async fn list_tools(&self) -> McpResult<Vec<ToolInfo>> {
345 let tools = self.tools.read().await;
346 Ok(tools.values().map(|t| t.info.clone()).collect())
347 }
348
349 pub async fn call_tool(
351 &self,
352 name: &str,
353 arguments: Option<HashMap<String, Value>>,
354 ) -> McpResult<ToolResult> {
355 let tools = self.tools.read().await;
356
357 match tools.get(name) {
358 Some(tool) => {
359 if !tool.enabled {
360 return Err(McpError::ToolNotFound(format!("Tool '{name}' is disabled")));
361 }
362
363 let args = arguments.unwrap_or_default();
364 tool.handler.call(args).await
365 }
366 None => Err(McpError::ToolNotFound(name.to_string())),
367 }
368 }
369
370 pub async fn add_prompt<H>(&self, info: PromptInfo, handler: H) -> McpResult<()>
376 where
377 H: PromptHandler + 'static,
378 {
379 validate_prompt_info(&info)?;
380
381 let name = info.name.clone();
382 let prompt = Prompt::new(info, handler);
383
384 {
385 let mut prompts = self.prompts.write().await;
386 prompts.insert(name, prompt);
387 }
388
389 self.emit_prompts_list_changed().await?;
390
391 Ok(())
392 }
393
394 pub async fn remove_prompt(&self, name: &str) -> McpResult<bool> {
396 let removed = {
397 let mut prompts = self.prompts.write().await;
398 prompts.remove(name).is_some()
399 };
400
401 if removed {
402 self.emit_prompts_list_changed().await?;
403 }
404
405 Ok(removed)
406 }
407
408 pub async fn list_prompts(&self) -> McpResult<Vec<PromptInfo>> {
410 let prompts = self.prompts.read().await;
411 Ok(prompts.values().map(|p| p.info.clone()).collect())
412 }
413
414 pub async fn get_prompt(
416 &self,
417 name: &str,
418 arguments: Option<HashMap<String, Value>>,
419 ) -> McpResult<PromptResult> {
420 let prompts = self.prompts.read().await;
421
422 match prompts.get(name) {
423 Some(prompt) => {
424 let args = arguments.unwrap_or_default();
425 prompt.handler.get(args).await
426 }
427 None => Err(McpError::PromptNotFound(name.to_string())),
428 }
429 }
430
431 pub async fn start<T>(&mut self, transport: T) -> McpResult<()>
437 where
438 T: ServerTransport + 'static,
439 {
440 let mut state = self.state.write().await;
441
442 match *state {
443 ServerState::Uninitialized => {
444 *state = ServerState::Initializing;
445 }
446 _ => return Err(McpError::Protocol("Server is already started".to_string())),
447 }
448
449 drop(state);
450
451 {
453 let mut transport_guard = self.transport.lock().await;
454 *transport_guard = Some(Box::new(transport));
455 }
456
457 {
459 let mut transport_guard = self.transport.lock().await;
460 if let Some(transport) = transport_guard.as_mut() {
461 transport.start().await?;
462 }
463 }
464
465 {
467 let mut state = self.state.write().await;
468 *state = ServerState::Running;
469 }
470
471 Ok(())
472 }
473
474 pub async fn stop(&self) -> McpResult<()> {
476 let mut state = self.state.write().await;
477
478 match *state {
479 ServerState::Running => {
480 *state = ServerState::Stopping;
481 }
482 ServerState::Stopped => return Ok(()),
483 _ => return Err(McpError::Protocol("Server is not running".to_string())),
484 }
485
486 drop(state);
487
488 {
490 let mut transport_guard = self.transport.lock().await;
491 if let Some(transport) = transport_guard.as_mut() {
492 transport.stop().await?;
493 }
494 }
495
496 {
498 let mut state = self.state.write().await;
499 *state = ServerState::Stopped;
500 }
501
502 Ok(())
503 }
504
505 pub async fn is_running(&self) -> bool {
507 let state = self.state.read().await;
508 matches!(*state, ServerState::Running)
509 }
510
511 pub async fn state(&self) -> ServerState {
513 let state = self.state.read().await;
514 state.clone()
515 }
516
517 pub async fn handle_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
523 if self.config.validate_requests {
525 validate_jsonrpc_request(&request)?;
526 validate_mcp_request(&request.method, request.params.as_ref())?;
527 }
528
529 let result = match request.method.as_str() {
531 methods::INITIALIZE => self.handle_initialize(request.params).await,
532 methods::PING => self.handle_ping().await,
533 methods::TOOLS_LIST => self.handle_tools_list(request.params).await,
534 methods::TOOLS_CALL => self.handle_tools_call(request.params).await,
535 methods::RESOURCES_LIST => self.handle_resources_list(request.params).await,
536 methods::RESOURCES_READ => self.handle_resources_read(request.params).await,
537 methods::RESOURCES_SUBSCRIBE => self.handle_resources_subscribe(request.params).await,
538 methods::RESOURCES_UNSUBSCRIBE => {
539 self.handle_resources_unsubscribe(request.params).await
540 }
541 methods::PROMPTS_LIST => self.handle_prompts_list(request.params).await,
542 methods::PROMPTS_GET => self.handle_prompts_get(request.params).await,
543 methods::LOGGING_SET_LEVEL => self.handle_logging_set_level(request.params).await,
544 _ => {
545 let method = &request.method;
546 Err(McpError::Protocol(format!("Unknown method: {method}")))
547 }
548 };
549
550 match result {
552 Ok(result_value) => Ok(JsonRpcResponse::success(request.id, result_value)?),
553 Err(error) => {
554 let (code, message) = match error {
555 McpError::ToolNotFound(_) => (TOOL_NOT_FOUND, error.to_string()),
556 McpError::ResourceNotFound(_) => (RESOURCE_NOT_FOUND, error.to_string()),
557 McpError::PromptNotFound(_) => (PROMPT_NOT_FOUND, error.to_string()),
558 McpError::Validation(_) => (INVALID_PARAMS, error.to_string()),
559 _ => (INTERNAL_ERROR, error.to_string()),
560 };
561 Ok(JsonRpcResponse::success(
564 request.id,
565 serde_json::json!({
566 "error": {
567 "code": code,
568 "message": message,
569 }
570 }),
571 )?)
572 }
573 }
574 }
575
576 async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
581 let params: InitializeParams = match params {
582 Some(p) => serde_json::from_value(p)?,
583 None => {
584 return Err(McpError::Validation(
585 "Missing initialize parameters".to_string(),
586 ));
587 }
588 };
589
590 validate_initialize_params(¶ms)?;
591
592 let result = InitializeResult::new(
593 crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
594 self.capabilities.clone(),
595 self.info.clone(),
596 );
597
598 Ok(serde_json::to_value(result)?)
599 }
600
601 async fn handle_ping(&self) -> McpResult<Value> {
602 Ok(serde_json::to_value(PingResult { meta: None })?)
603 }
604
605 async fn handle_tools_list(&self, params: Option<Value>) -> McpResult<Value> {
606 let _params: ListToolsParams = match params {
607 Some(p) => serde_json::from_value(p)?,
608 None => ListToolsParams::default(),
609 };
610
611 let tools = self.list_tools().await?;
612 let result = ListToolsResult {
613 tools,
614 next_cursor: None, meta: None,
616 };
617
618 Ok(serde_json::to_value(result)?)
619 }
620
621 async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
622 let params: CallToolParams = match params {
623 Some(p) => serde_json::from_value(p)?,
624 None => {
625 return Err(McpError::Validation(
626 "Missing tool call parameters".to_string(),
627 ));
628 }
629 };
630
631 validate_call_tool_params(¶ms)?;
632
633 let result = self.call_tool(¶ms.name, params.arguments).await?;
634 Ok(serde_json::to_value(result)?)
635 }
636
637 async fn handle_resources_list(&self, params: Option<Value>) -> McpResult<Value> {
638 let _params: ListResourcesParams = match params {
639 Some(p) => serde_json::from_value(p)?,
640 None => ListResourcesParams::default(),
641 };
642
643 let resources = self.list_resources().await?;
644 let result = ListResourcesResult {
645 resources,
646 next_cursor: None, meta: None,
648 };
649
650 Ok(serde_json::to_value(result)?)
651 }
652
653 async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
654 let params: ReadResourceParams = match params {
655 Some(p) => serde_json::from_value(p)?,
656 None => {
657 return Err(McpError::Validation(
658 "Missing resource read parameters".to_string(),
659 ));
660 }
661 };
662
663 validate_read_resource_params(¶ms)?;
664
665 let contents = self.read_resource(¶ms.uri).await?;
666 let result = ReadResourceResult {
667 contents,
668 meta: None,
669 };
670
671 Ok(serde_json::to_value(result)?)
672 }
673
674 async fn handle_resources_subscribe(&self, params: Option<Value>) -> McpResult<Value> {
675 let params: SubscribeResourceParams = match params {
676 Some(p) => serde_json::from_value(p)?,
677 None => {
678 return Err(McpError::Validation(
679 "Missing resource subscribe parameters".to_string(),
680 ));
681 }
682 };
683
684 let _uri = params.uri;
686 let result = SubscribeResourceResult { meta: None };
687
688 Ok(serde_json::to_value(result)?)
689 }
690
691 async fn handle_resources_unsubscribe(&self, params: Option<Value>) -> McpResult<Value> {
692 let params: UnsubscribeResourceParams = match params {
693 Some(p) => serde_json::from_value(p)?,
694 None => {
695 return Err(McpError::Validation(
696 "Missing resource unsubscribe parameters".to_string(),
697 ));
698 }
699 };
700
701 let _uri = params.uri;
703 let result = UnsubscribeResourceResult { meta: None };
704
705 Ok(serde_json::to_value(result)?)
706 }
707
708 async fn handle_prompts_list(&self, params: Option<Value>) -> McpResult<Value> {
709 let _params: ListPromptsParams = match params {
710 Some(p) => serde_json::from_value(p)?,
711 None => ListPromptsParams::default(),
712 };
713
714 let prompts = self.list_prompts().await?;
715 let result = ListPromptsResult {
716 prompts,
717 next_cursor: None, meta: None,
719 };
720
721 Ok(serde_json::to_value(result)?)
722 }
723
724 async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
725 let params: GetPromptParams = match params {
726 Some(p) => serde_json::from_value(p)?,
727 None => {
728 return Err(McpError::Validation(
729 "Missing prompt get parameters".to_string(),
730 ));
731 }
732 };
733
734 validate_get_prompt_params(¶ms)?;
735
736 let arguments = params.arguments.map(|args| {
737 args.into_iter()
738 .map(|(k, v)| (k, serde_json::Value::String(v)))
739 .collect()
740 });
741 let result = self.get_prompt(¶ms.name, arguments).await?;
742 Ok(serde_json::to_value(result)?)
743 }
744
745 async fn handle_logging_set_level(&self, params: Option<Value>) -> McpResult<Value> {
746 let _params: SetLoggingLevelParams = match params {
747 Some(p) => serde_json::from_value(p)?,
748 None => {
749 return Err(McpError::Validation(
750 "Missing logging level parameters".to_string(),
751 ));
752 }
753 };
754
755 let result = SetLoggingLevelResult { meta: None };
757 Ok(serde_json::to_value(result)?)
758 }
759
760 async fn emit_resources_list_changed(&self) -> McpResult<()> {
765 let notification = JsonRpcNotification::new(
766 methods::RESOURCES_LIST_CHANGED.to_string(),
767 Some(ResourceListChangedParams { meta: None }),
768 )?;
769
770 self.send_notification(notification).await
771 }
772
773 async fn emit_tools_list_changed(&self) -> McpResult<()> {
774 let notification = JsonRpcNotification::new(
775 methods::TOOLS_LIST_CHANGED.to_string(),
776 Some(ToolListChangedParams { meta: None }),
777 )?;
778
779 self.send_notification(notification).await
780 }
781
782 async fn emit_prompts_list_changed(&self) -> McpResult<()> {
783 let notification = JsonRpcNotification::new(
784 methods::PROMPTS_LIST_CHANGED.to_string(),
785 Some(PromptListChangedParams { meta: None }),
786 )?;
787
788 self.send_notification(notification).await
789 }
790
791 async fn send_notification(&self, notification: JsonRpcNotification) -> McpResult<()> {
793 let mut transport_guard = self.transport.lock().await;
794 if let Some(transport) = transport_guard.as_mut() {
795 transport.send_notification(notification).await?;
796 }
797 Ok(())
798 }
799
800 #[allow(dead_code)]
805 async fn next_request_id(&self) -> u64 {
806 let mut counter = self.request_counter.lock().await;
807 *counter += 1;
808 *counter
809 }
810}
811
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds the server every test in this module starts from.
    fn test_server() -> McpServer {
        McpServer::new("test-server".to_string(), "1.0.0".to_string())
    }

    /// Tool handler that ignores its arguments and returns one text item.
    struct TestToolHandler;

    #[async_trait::async_trait]
    impl ToolHandler for TestToolHandler {
        async fn call(&self, _arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
            let content = vec![Content::text("Hello from tool")];
            Ok(ToolResult {
                content,
                is_error: None,
                structured_content: None,
                meta: None,
            })
        }
    }

    #[tokio::test]
    async fn test_server_creation() {
        let server = test_server();
        let info = server.info();

        assert_eq!(info.name, "test-server");
        assert_eq!(info.version, "1.0.0");
        assert!(!server.is_running().await);
    }

    #[tokio::test]
    async fn test_tool_management() {
        let server = test_server();
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        server
            .add_tool(
                "test_tool".to_string(),
                Some("A test tool".to_string()),
                schema,
                TestToolHandler,
            )
            .await
            .expect("registering the tool should succeed");

        let listed = server.list_tools().await.expect("listing tools should succeed");
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].name, "test_tool");

        let output = server
            .call_tool("test_tool", None)
            .await
            .expect("calling the tool should succeed");
        assert_eq!(output.content.len(), 1);
    }

    #[tokio::test]
    async fn test_initialize_request() {
        let server = test_server();
        let client = ClientInfo {
            name: "test-client".to_string(),
            version: "1.0.0".to_string(),
            title: Some("Test Client".to_string()),
        };
        let init_params = InitializeParams::new(
            crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
            ClientCapabilities::default(),
            client,
        );

        let request =
            JsonRpcRequest::new(json!(1), methods::INITIALIZE.to_string(), Some(init_params))
                .expect("building the request should succeed");

        let response = server
            .handle_request(request)
            .await
            .expect("handling initialize should succeed");
        assert!(response.result.is_some());
    }
}