1use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12use crate::core::{
13 error::{McpError, McpResult},
14 prompt::{Prompt, PromptHandler},
15 resource::{Resource, ResourceHandler},
16 tool::{Tool, ToolHandler},
17 PromptInfo, ResourceInfo, ToolInfo,
18};
19use crate::protocol::{error_codes::*, messages::*, methods, types::*, validation::*};
20use crate::transport::traits::ServerTransport;
21
/// Tunable settings for an [`McpServer`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Intended upper bound on concurrently processed requests.
    /// NOTE(review): not read anywhere in this file — confirm it is enforced elsewhere.
    pub max_concurrent_requests: usize,
    /// Intended per-request timeout, in milliseconds.
    /// NOTE(review): not read anywhere in this file — confirm it is enforced elsewhere.
    pub request_timeout_ms: u64,
    /// When `true`, `McpServer::handle_request` runs `validate_jsonrpc_request` and
    /// `validate_mcp_request` on every request before dispatching it.
    pub validate_requests: bool,
    /// Logging switch.
    /// NOTE(review): not read anywhere in this file.
    pub enable_logging: bool,
}
34
35impl Default for ServerConfig {
36 fn default() -> Self {
37 Self {
38 max_concurrent_requests: 100,
39 request_timeout_ms: 30000,
40 validate_requests: true,
41 enable_logging: true,
42 }
43 }
44}
45
/// An MCP server: registries of resources, tools and prompts, plus the transport
/// and lifecycle state used to serve them.
pub struct McpServer {
    /// Name and version reported to clients in the `initialize` response.
    info: ServerInfo,
    /// Capabilities reported to clients in the `initialize` response.
    capabilities: ServerCapabilities,
    /// Runtime configuration (see [`ServerConfig`]).
    config: ServerConfig,
    /// Registered resources, keyed by URI.
    resources: Arc<RwLock<HashMap<String, Resource>>>,
    /// Registered tools, keyed by tool name.
    tools: Arc<RwLock<HashMap<String, Tool>>>,
    /// Registered prompts, keyed by prompt name.
    prompts: Arc<RwLock<HashMap<String, Prompt>>>,
    /// Transport installed by `start`; `None` until then, in which case
    /// notifications are silently skipped.
    transport: Arc<Mutex<Option<Box<dyn ServerTransport>>>>,
    /// Current lifecycle state (see [`ServerState`]).
    state: Arc<RwLock<ServerState>>,
    /// Backing counter for `next_request_id`, which nothing calls yet; hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    request_counter: Arc<Mutex<u64>>,
}
68
/// Lifecycle state of an [`McpServer`].
#[derive(Debug, Clone, PartialEq)]
pub enum ServerState {
    /// Constructed; `start` has not been called yet.
    Uninitialized,
    /// `start` is installing and starting the transport.
    Initializing,
    /// The transport has started.
    Running,
    /// `stop` is shutting the transport down.
    Stopping,
    /// `stop` has completed.
    Stopped,
}
83
84impl McpServer {
85 pub fn new(name: String, version: String) -> Self {
87 Self {
88 info: ServerInfo { name, version },
89 capabilities: ServerCapabilities {
90 prompts: Some(PromptsCapability {
91 list_changed: Some(true),
92 }),
93 resources: Some(ResourcesCapability {
94 subscribe: Some(true),
95 list_changed: Some(true),
96 }),
97 tools: Some(ToolsCapability {
98 list_changed: Some(true),
99 }),
100 sampling: None,
101 logging: None,
102 experimental: None,
103 completions: None,
104 },
105 config: ServerConfig::default(),
106 resources: Arc::new(RwLock::new(HashMap::new())),
107 tools: Arc::new(RwLock::new(HashMap::new())),
108 prompts: Arc::new(RwLock::new(HashMap::new())),
109 transport: Arc::new(Mutex::new(None)),
110 state: Arc::new(RwLock::new(ServerState::Uninitialized)),
111 request_counter: Arc::new(Mutex::new(0)),
112 }
113 }
114
115 pub fn with_config(name: String, version: String, config: ServerConfig) -> Self {
117 let mut server = Self::new(name, version);
118 server.config = config;
119 server
120 }
121
122 pub fn set_capabilities(&mut self, capabilities: ServerCapabilities) {
124 self.capabilities = capabilities;
125 }
126
127 pub fn info(&self) -> &ServerInfo {
129 &self.info
130 }
131
132 pub fn name(&self) -> &str {
134 &self.info.name
135 }
136
137 pub fn version(&self) -> &str {
139 &self.info.version
140 }
141
142 pub fn capabilities(&self) -> &ServerCapabilities {
144 &self.capabilities
145 }
146
147 pub fn config(&self) -> &ServerConfig {
149 &self.config
150 }
151
152 pub async fn add_resource<H>(&self, name: String, uri: String, handler: H) -> McpResult<()>
158 where
159 H: ResourceHandler + 'static,
160 {
161 let resource_info = ResourceInfo {
162 uri: uri.clone(),
163 name: Some(name.clone()),
164 description: None,
165 mime_type: None,
166 annotations: None,
167 size: None,
168 };
169
170 validate_resource_info(&resource_info)?;
171
172 let resource = Resource::new(resource_info, handler);
173
174 {
175 let mut resources = self.resources.write().await;
176 resources.insert(uri, resource);
177 }
178
179 self.emit_resources_list_changed().await?;
181
182 Ok(())
183 }
184
185 pub async fn add_resource_detailed<H>(&self, info: ResourceInfo, handler: H) -> McpResult<()>
187 where
188 H: ResourceHandler + 'static,
189 {
190 validate_resource_info(&info)?;
191
192 let uri = info.uri.clone();
193 let resource = Resource::new(info, handler);
194
195 {
196 let mut resources = self.resources.write().await;
197 resources.insert(uri, resource);
198 }
199
200 self.emit_resources_list_changed().await?;
201
202 Ok(())
203 }
204
205 pub async fn remove_resource(&self, uri: &str) -> McpResult<bool> {
207 let removed = {
208 let mut resources = self.resources.write().await;
209 resources.remove(uri).is_some()
210 };
211
212 if removed {
213 self.emit_resources_list_changed().await?;
214 }
215
216 Ok(removed)
217 }
218
219 pub async fn list_resources(&self) -> McpResult<Vec<ResourceInfo>> {
221 let resources = self.resources.read().await;
222 Ok(resources.values().map(|r| r.info.clone()).collect())
223 }
224
225 pub async fn read_resource(&self, uri: &str) -> McpResult<Vec<ResourceContents>> {
227 let resources = self.resources.read().await;
228
229 match resources.get(uri) {
230 Some(resource) => {
231 let params = HashMap::new(); resource.handler.read(uri, ¶ms).await
233 }
234 None => Err(McpError::ResourceNotFound(uri.to_string())),
235 }
236 }
237
238 pub async fn add_tool<H>(
244 &self,
245 name: String,
246 description: Option<String>,
247 schema: Value,
248 handler: H,
249 ) -> McpResult<()>
250 where
251 H: ToolHandler + 'static,
252 {
253 let tool_schema = ToolInputSchema {
254 schema_type: "object".to_string(),
255 properties: schema
256 .get("properties")
257 .and_then(|p| p.as_object())
258 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect()),
259 required: schema.get("required").and_then(|r| {
260 r.as_array().map(|arr| {
261 arr.iter()
262 .filter_map(|v| v.as_str().map(|s| s.to_string()))
263 .collect()
264 })
265 }),
266 additional_properties: schema
267 .as_object()
268 .unwrap_or(&serde_json::Map::new())
269 .iter()
270 .map(|(k, v)| (k.clone(), v.clone()))
271 .collect(),
272 };
273
274 let tool_info = ToolInfo {
275 name: name.clone(),
276 description,
277 input_schema: tool_schema,
278 annotations: None,
279 };
280
281 validate_tool_info(&tool_info)?;
282
283 let tool = Tool::new(
284 name.clone(),
285 tool_info.description.clone(),
286 serde_json::to_value(&tool_info.input_schema)?,
287 handler,
288 );
289
290 {
291 let mut tools = self.tools.write().await;
292 tools.insert(name, tool);
293 }
294
295 self.emit_tools_list_changed().await?;
296
297 Ok(())
298 }
299
300 pub async fn add_tool_detailed<H>(&self, info: ToolInfo, handler: H) -> McpResult<()>
302 where
303 H: ToolHandler + 'static,
304 {
305 validate_tool_info(&info)?;
306
307 let name = info.name.clone();
308 let tool = Tool::new(
309 name.clone(),
310 info.description.clone(),
311 serde_json::to_value(&info.input_schema)?,
312 handler,
313 );
314
315 {
316 let mut tools = self.tools.write().await;
317 tools.insert(name, tool);
318 }
319
320 self.emit_tools_list_changed().await?;
321
322 Ok(())
323 }
324
325 pub async fn remove_tool(&self, name: &str) -> McpResult<bool> {
327 let removed = {
328 let mut tools = self.tools.write().await;
329 tools.remove(name).is_some()
330 };
331
332 if removed {
333 self.emit_tools_list_changed().await?;
334 }
335
336 Ok(removed)
337 }
338
339 pub async fn list_tools(&self) -> McpResult<Vec<ToolInfo>> {
341 let tools = self.tools.read().await;
342 Ok(tools.values().map(|t| t.info.clone()).collect())
343 }
344
345 pub async fn call_tool(
347 &self,
348 name: &str,
349 arguments: Option<HashMap<String, Value>>,
350 ) -> McpResult<ToolResult> {
351 let tools = self.tools.read().await;
352
353 match tools.get(name) {
354 Some(tool) => {
355 if !tool.enabled {
356 return Err(McpError::ToolNotFound(format!(
357 "Tool '{}' is disabled",
358 name
359 )));
360 }
361
362 let args = arguments.unwrap_or_default();
363 tool.handler.call(args).await
364 }
365 None => Err(McpError::ToolNotFound(name.to_string())),
366 }
367 }
368
369 pub async fn add_prompt<H>(&self, info: PromptInfo, handler: H) -> McpResult<()>
375 where
376 H: PromptHandler + 'static,
377 {
378 validate_prompt_info(&info)?;
379
380 let name = info.name.clone();
381 let prompt = Prompt::new(info, handler);
382
383 {
384 let mut prompts = self.prompts.write().await;
385 prompts.insert(name, prompt);
386 }
387
388 self.emit_prompts_list_changed().await?;
389
390 Ok(())
391 }
392
393 pub async fn remove_prompt(&self, name: &str) -> McpResult<bool> {
395 let removed = {
396 let mut prompts = self.prompts.write().await;
397 prompts.remove(name).is_some()
398 };
399
400 if removed {
401 self.emit_prompts_list_changed().await?;
402 }
403
404 Ok(removed)
405 }
406
407 pub async fn list_prompts(&self) -> McpResult<Vec<PromptInfo>> {
409 let prompts = self.prompts.read().await;
410 Ok(prompts.values().map(|p| p.info.clone()).collect())
411 }
412
413 pub async fn get_prompt(
415 &self,
416 name: &str,
417 arguments: Option<HashMap<String, Value>>,
418 ) -> McpResult<PromptResult> {
419 let prompts = self.prompts.read().await;
420
421 match prompts.get(name) {
422 Some(prompt) => {
423 let args = arguments.unwrap_or_default();
424 prompt.handler.get(args).await
425 }
426 None => Err(McpError::PromptNotFound(name.to_string())),
427 }
428 }
429
430 pub async fn start<T>(&mut self, transport: T) -> McpResult<()>
436 where
437 T: ServerTransport + 'static,
438 {
439 let mut state = self.state.write().await;
440
441 match *state {
442 ServerState::Uninitialized => {
443 *state = ServerState::Initializing;
444 }
445 _ => return Err(McpError::Protocol("Server is already started".to_string())),
446 }
447
448 drop(state);
449
450 {
452 let mut transport_guard = self.transport.lock().await;
453 *transport_guard = Some(Box::new(transport));
454 }
455
456 {
458 let mut transport_guard = self.transport.lock().await;
459 if let Some(transport) = transport_guard.as_mut() {
460 transport.start().await?;
461 }
462 }
463
464 {
466 let mut state = self.state.write().await;
467 *state = ServerState::Running;
468 }
469
470 Ok(())
471 }
472
473 pub async fn stop(&self) -> McpResult<()> {
475 let mut state = self.state.write().await;
476
477 match *state {
478 ServerState::Running => {
479 *state = ServerState::Stopping;
480 }
481 ServerState::Stopped => return Ok(()),
482 _ => return Err(McpError::Protocol("Server is not running".to_string())),
483 }
484
485 drop(state);
486
487 {
489 let mut transport_guard = self.transport.lock().await;
490 if let Some(transport) = transport_guard.as_mut() {
491 transport.stop().await?;
492 }
493 }
494
495 {
497 let mut state = self.state.write().await;
498 *state = ServerState::Stopped;
499 }
500
501 Ok(())
502 }
503
504 pub async fn is_running(&self) -> bool {
506 let state = self.state.read().await;
507 matches!(*state, ServerState::Running)
508 }
509
510 pub async fn state(&self) -> ServerState {
512 let state = self.state.read().await;
513 state.clone()
514 }
515
516 pub async fn handle_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
522 if self.config.validate_requests {
524 validate_jsonrpc_request(&request)?;
525 validate_mcp_request(&request.method, request.params.as_ref())?;
526 }
527
528 let result = match request.method.as_str() {
530 methods::INITIALIZE => self.handle_initialize(request.params).await,
531 methods::PING => self.handle_ping().await,
532 methods::TOOLS_LIST => self.handle_tools_list(request.params).await,
533 methods::TOOLS_CALL => self.handle_tools_call(request.params).await,
534 methods::RESOURCES_LIST => self.handle_resources_list(request.params).await,
535 methods::RESOURCES_READ => self.handle_resources_read(request.params).await,
536 methods::RESOURCES_SUBSCRIBE => self.handle_resources_subscribe(request.params).await,
537 methods::RESOURCES_UNSUBSCRIBE => {
538 self.handle_resources_unsubscribe(request.params).await
539 }
540 methods::PROMPTS_LIST => self.handle_prompts_list(request.params).await,
541 methods::PROMPTS_GET => self.handle_prompts_get(request.params).await,
542 methods::LOGGING_SET_LEVEL => self.handle_logging_set_level(request.params).await,
543 _ => Err(McpError::Protocol(format!(
544 "Unknown method: {}",
545 request.method
546 ))),
547 };
548
549 match result {
551 Ok(result_value) => Ok(JsonRpcResponse::success(request.id, result_value)?),
552 Err(error) => {
553 let (code, message) = match error {
554 McpError::ToolNotFound(_) => (TOOL_NOT_FOUND, error.to_string()),
555 McpError::ResourceNotFound(_) => (RESOURCE_NOT_FOUND, error.to_string()),
556 McpError::PromptNotFound(_) => (PROMPT_NOT_FOUND, error.to_string()),
557 McpError::Validation(_) => (INVALID_PARAMS, error.to_string()),
558 _ => (INTERNAL_ERROR, error.to_string()),
559 };
560 Ok(JsonRpcResponse::success(
563 request.id,
564 serde_json::json!({
565 "error": {
566 "code": code,
567 "message": message,
568 }
569 }),
570 )?)
571 }
572 }
573 }
574
575 async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
580 let params: InitializeParams = match params {
581 Some(p) => serde_json::from_value(p)?,
582 None => {
583 return Err(McpError::Validation(
584 "Missing initialize parameters".to_string(),
585 ))
586 }
587 };
588
589 validate_initialize_params(¶ms)?;
590
591 let result = InitializeResult::new(
592 crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
593 self.capabilities.clone(),
594 self.info.clone(),
595 );
596
597 Ok(serde_json::to_value(result)?)
598 }
599
600 async fn handle_ping(&self) -> McpResult<Value> {
601 Ok(serde_json::to_value(PingResult { meta: None })?)
602 }
603
604 async fn handle_tools_list(&self, params: Option<Value>) -> McpResult<Value> {
605 let _params: ListToolsParams = match params {
606 Some(p) => serde_json::from_value(p)?,
607 None => ListToolsParams::default(),
608 };
609
610 let tools = self.list_tools().await?;
611 let result = ListToolsResult {
612 tools,
613 next_cursor: None, meta: None,
615 };
616
617 Ok(serde_json::to_value(result)?)
618 }
619
620 async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
621 let params: CallToolParams = match params {
622 Some(p) => serde_json::from_value(p)?,
623 None => {
624 return Err(McpError::Validation(
625 "Missing tool call parameters".to_string(),
626 ))
627 }
628 };
629
630 validate_call_tool_params(¶ms)?;
631
632 let result = self.call_tool(¶ms.name, params.arguments).await?;
633 Ok(serde_json::to_value(result)?)
634 }
635
636 async fn handle_resources_list(&self, params: Option<Value>) -> McpResult<Value> {
637 let _params: ListResourcesParams = match params {
638 Some(p) => serde_json::from_value(p)?,
639 None => ListResourcesParams::default(),
640 };
641
642 let resources = self.list_resources().await?;
643 let result = ListResourcesResult {
644 resources,
645 next_cursor: None, meta: None,
647 };
648
649 Ok(serde_json::to_value(result)?)
650 }
651
652 async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
653 let params: ReadResourceParams = match params {
654 Some(p) => serde_json::from_value(p)?,
655 None => {
656 return Err(McpError::Validation(
657 "Missing resource read parameters".to_string(),
658 ))
659 }
660 };
661
662 validate_read_resource_params(¶ms)?;
663
664 let contents = self.read_resource(¶ms.uri).await?;
665 let result = ReadResourceResult {
666 contents,
667 meta: None,
668 };
669
670 Ok(serde_json::to_value(result)?)
671 }
672
673 async fn handle_resources_subscribe(&self, params: Option<Value>) -> McpResult<Value> {
674 let params: SubscribeResourceParams = match params {
675 Some(p) => serde_json::from_value(p)?,
676 None => {
677 return Err(McpError::Validation(
678 "Missing resource subscribe parameters".to_string(),
679 ))
680 }
681 };
682
683 let _uri = params.uri;
685 let result = SubscribeResourceResult { meta: None };
686
687 Ok(serde_json::to_value(result)?)
688 }
689
690 async fn handle_resources_unsubscribe(&self, params: Option<Value>) -> McpResult<Value> {
691 let params: UnsubscribeResourceParams = match params {
692 Some(p) => serde_json::from_value(p)?,
693 None => {
694 return Err(McpError::Validation(
695 "Missing resource unsubscribe parameters".to_string(),
696 ))
697 }
698 };
699
700 let _uri = params.uri;
702 let result = UnsubscribeResourceResult { meta: None };
703
704 Ok(serde_json::to_value(result)?)
705 }
706
707 async fn handle_prompts_list(&self, params: Option<Value>) -> McpResult<Value> {
708 let _params: ListPromptsParams = match params {
709 Some(p) => serde_json::from_value(p)?,
710 None => ListPromptsParams::default(),
711 };
712
713 let prompts = self.list_prompts().await?;
714 let result = ListPromptsResult {
715 prompts,
716 next_cursor: None, meta: None,
718 };
719
720 Ok(serde_json::to_value(result)?)
721 }
722
723 async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
724 let params: GetPromptParams = match params {
725 Some(p) => serde_json::from_value(p)?,
726 None => {
727 return Err(McpError::Validation(
728 "Missing prompt get parameters".to_string(),
729 ))
730 }
731 };
732
733 validate_get_prompt_params(¶ms)?;
734
735 let arguments = params.arguments.map(|args| {
736 args.into_iter()
737 .map(|(k, v)| (k, serde_json::Value::String(v)))
738 .collect()
739 });
740 let result = self.get_prompt(¶ms.name, arguments).await?;
741 Ok(serde_json::to_value(result)?)
742 }
743
744 async fn handle_logging_set_level(&self, params: Option<Value>) -> McpResult<Value> {
745 let _params: SetLoggingLevelParams = match params {
746 Some(p) => serde_json::from_value(p)?,
747 None => {
748 return Err(McpError::Validation(
749 "Missing logging level parameters".to_string(),
750 ))
751 }
752 };
753
754 let result = SetLoggingLevelResult { meta: None };
756 Ok(serde_json::to_value(result)?)
757 }
758
759 async fn emit_resources_list_changed(&self) -> McpResult<()> {
764 let notification = JsonRpcNotification::new(
765 methods::RESOURCES_LIST_CHANGED.to_string(),
766 Some(ResourceListChangedParams { meta: None }),
767 )?;
768
769 self.send_notification(notification).await
770 }
771
772 async fn emit_tools_list_changed(&self) -> McpResult<()> {
773 let notification = JsonRpcNotification::new(
774 methods::TOOLS_LIST_CHANGED.to_string(),
775 Some(ToolListChangedParams { meta: None }),
776 )?;
777
778 self.send_notification(notification).await
779 }
780
781 async fn emit_prompts_list_changed(&self) -> McpResult<()> {
782 let notification = JsonRpcNotification::new(
783 methods::PROMPTS_LIST_CHANGED.to_string(),
784 Some(PromptListChangedParams { meta: None }),
785 )?;
786
787 self.send_notification(notification).await
788 }
789
790 async fn send_notification(&self, notification: JsonRpcNotification) -> McpResult<()> {
792 let mut transport_guard = self.transport.lock().await;
793 if let Some(transport) = transport_guard.as_mut() {
794 transport.send_notification(notification).await?;
795 }
796 Ok(())
797 }
798
799 #[allow(dead_code)]
804 async fn next_request_id(&self) -> u64 {
805 let mut counter = self.request_counter.lock().await;
806 *counter += 1;
807 *counter
808 }
809}
810
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds the server instance used by every test in this module.
    fn make_server() -> McpServer {
        McpServer::new("test-server".to_string(), "1.0.0".to_string())
    }

    #[tokio::test]
    async fn test_server_creation() {
        let server = make_server();
        let info = server.info();
        assert_eq!(info.name, "test-server");
        assert_eq!(info.version, "1.0.0");
        assert!(!server.is_running().await);
    }

    #[tokio::test]
    async fn test_tool_management() {
        // Handler that ignores its arguments and returns a single text item.
        struct TestToolHandler;

        #[async_trait::async_trait]
        impl ToolHandler for TestToolHandler {
            async fn call(&self, _arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
                Ok(ToolResult {
                    content: vec![Content::text("Hello from tool")],
                    is_error: None,
                    meta: None,
                })
            }
        }

        let server = make_server();
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        server
            .add_tool(
                "test_tool".to_string(),
                Some("A test tool".to_string()),
                schema,
                TestToolHandler,
            )
            .await
            .unwrap();

        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "test_tool");

        let outcome = server.call_tool("test_tool", None).await.unwrap();
        assert_eq!(outcome.content.len(), 1);
    }

    #[tokio::test]
    async fn test_initialize_request() {
        let server = make_server();

        let client = ClientInfo {
            name: "test-client".to_string(),
            version: "1.0.0".to_string(),
        };
        let init_params = InitializeParams::new(
            crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
            ClientCapabilities::default(),
            client,
        );
        let request =
            JsonRpcRequest::new(json!(1), methods::INITIALIZE.to_string(), Some(init_params))
                .unwrap();

        let response = server.handle_request(request).await.unwrap();
        assert!(response.result.is_some());
    }
}
889}