mcp_protocol_sdk/server/
mcp_server.rs

1//! MCP server implementation
2//!
3//! This module provides the main MCP server implementation that handles client connections,
4//! manages resources, tools, and prompts, and processes JSON-RPC requests according to
5//! the Model Context Protocol specification.
6
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12use crate::core::{
13    error::{McpError, McpResult},
14    prompt::{Prompt, PromptHandler},
15    resource::{Resource, ResourceHandler},
16    tool::{Tool, ToolHandler},
17    PromptInfo, ResourceInfo, ToolInfo,
18};
19use crate::protocol::{messages::*, types::*, validation::*};
20use crate::transport::traits::ServerTransport;
21
/// Tunable settings that govern how an [`McpServer`] behaves.
///
/// Start from [`ServerConfig::default`] and override individual fields with
/// struct-update syntax, then pass the value to `McpServer::with_config`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Intended ceiling on the number of requests handled at the same time.
    /// NOTE(review): not consulted by `McpServer` in this file — confirm where it is enforced.
    pub max_concurrent_requests: usize,
    /// Intended per-request timeout, in milliseconds.
    /// NOTE(review): not consulted by `McpServer` in this file — confirm where it is enforced.
    pub request_timeout_ms: u64,
    /// When `true`, `McpServer::handle_request` runs the JSON-RPC and MCP
    /// validators on each request before routing it.
    pub validate_requests: bool,
    /// When `true`, detailed logging is requested.
    pub enable_logging: bool,
}

impl Default for ServerConfig {
    /// Returns the stock settings: at most 100 concurrent requests, a 30 s
    /// timeout, request validation enabled and detailed logging enabled.
    fn default() -> Self {
        let thirty_seconds_ms: u64 = 30_000;
        Self {
            validate_requests: true,
            enable_logging: true,
            max_concurrent_requests: 100,
            request_timeout_ms: thirty_seconds_ms,
        }
    }
}
45
46/// Main MCP server implementation
47pub struct McpServer {
48    /// Server information
49    info: ServerInfo,
50    /// Server capabilities
51    capabilities: ServerCapabilities,
52    /// Server configuration
53    config: ServerConfig,
54    /// Registered resources
55    resources: Arc<RwLock<HashMap<String, Resource>>>,
56    /// Registered tools
57    tools: Arc<RwLock<HashMap<String, Tool>>>,
58    /// Registered prompts
59    prompts: Arc<RwLock<HashMap<String, Prompt>>>,
60    /// Active transport
61    transport: Arc<Mutex<Option<Box<dyn ServerTransport>>>>,
62    /// Server state
63    state: Arc<RwLock<ServerState>>,
64    /// Request ID counter
65    #[allow(dead_code)]
66    request_counter: Arc<Mutex<u64>>,
67}
68
/// Lifecycle phase of an [`McpServer`].
///
/// `McpServer::start` moves a server from `Uninitialized` through
/// `Initializing` to `Running`; `McpServer::stop` moves it from `Running`
/// through `Stopping` to `Stopped`.
///
/// The enum is fieldless, so equality is total: it derives `Eq` and `Hash` in
/// addition to `PartialEq` so it can be used as a map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ServerState {
    /// Server is not yet initialized
    Uninitialized,
    /// Server is initializing
    Initializing,
    /// Server is running and ready to accept requests
    Running,
    /// Server is shutting down
    Stopping,
    /// Server has stopped
    Stopped,
}
83
84impl McpServer {
85    /// Create a new MCP server with the given name and version
86    pub fn new(name: String, version: String) -> Self {
87        Self {
88            info: ServerInfo { name, version },
89            capabilities: ServerCapabilities {
90                prompts: Some(PromptsCapability {
91                    list_changed: Some(true),
92                }),
93                resources: Some(ResourcesCapability {
94                    subscribe: Some(true),
95                    list_changed: Some(true),
96                }),
97                tools: Some(ToolsCapability {
98                    list_changed: Some(true),
99                }),
100                sampling: None,
101            },
102            config: ServerConfig::default(),
103            resources: Arc::new(RwLock::new(HashMap::new())),
104            tools: Arc::new(RwLock::new(HashMap::new())),
105            prompts: Arc::new(RwLock::new(HashMap::new())),
106            transport: Arc::new(Mutex::new(None)),
107            state: Arc::new(RwLock::new(ServerState::Uninitialized)),
108            request_counter: Arc::new(Mutex::new(0)),
109        }
110    }
111
112    /// Create a new MCP server with custom configuration
113    pub fn with_config(name: String, version: String, config: ServerConfig) -> Self {
114        let mut server = Self::new(name, version);
115        server.config = config;
116        server
117    }
118
119    /// Set server capabilities
120    pub fn set_capabilities(&mut self, capabilities: ServerCapabilities) {
121        self.capabilities = capabilities;
122    }
123
124    /// Get server information
125    pub fn info(&self) -> &ServerInfo {
126        &self.info
127    }
128
129    /// Get server capabilities
130    pub fn capabilities(&self) -> &ServerCapabilities {
131        &self.capabilities
132    }
133
134    /// Get server configuration
135    pub fn config(&self) -> &ServerConfig {
136        &self.config
137    }
138
139    // ========================================================================
140    // Resource Management
141    // ========================================================================
142
143    /// Add a resource to the server
144    pub async fn add_resource<H>(&self, name: String, uri: String, handler: H) -> McpResult<()>
145    where
146        H: ResourceHandler + 'static,
147    {
148        let resource_info = ResourceInfo {
149            uri: uri.clone(),
150            name: name.clone(),
151            description: None,
152            mime_type: None,
153        };
154
155        validate_resource_info(&resource_info)?;
156
157        let resource = Resource::new(resource_info, handler);
158
159        {
160            let mut resources = self.resources.write().await;
161            resources.insert(uri, resource);
162        }
163
164        // Emit list changed notification if we have an active transport
165        self.emit_resources_list_changed().await?;
166
167        Ok(())
168    }
169
170    /// Add a resource with detailed information
171    pub async fn add_resource_detailed<H>(&self, info: ResourceInfo, handler: H) -> McpResult<()>
172    where
173        H: ResourceHandler + 'static,
174    {
175        validate_resource_info(&info)?;
176
177        let uri = info.uri.clone();
178        let resource = Resource::new(info, handler);
179
180        {
181            let mut resources = self.resources.write().await;
182            resources.insert(uri, resource);
183        }
184
185        self.emit_resources_list_changed().await?;
186
187        Ok(())
188    }
189
190    /// Remove a resource from the server
191    pub async fn remove_resource(&self, uri: &str) -> McpResult<bool> {
192        let removed = {
193            let mut resources = self.resources.write().await;
194            resources.remove(uri).is_some()
195        };
196
197        if removed {
198            self.emit_resources_list_changed().await?;
199        }
200
201        Ok(removed)
202    }
203
204    /// List all registered resources
205    pub async fn list_resources(&self) -> McpResult<Vec<ResourceInfo>> {
206        let resources = self.resources.read().await;
207        Ok(resources.values().map(|r| r.info.clone()).collect())
208    }
209
210    /// Read a resource
211    pub async fn read_resource(&self, uri: &str) -> McpResult<Vec<ResourceContent>> {
212        let resources = self.resources.read().await;
213
214        match resources.get(uri) {
215            Some(resource) => {
216                let params = HashMap::new(); // URL parameter extraction will be implemented in future versions
217                resource.handler.read(uri, &params).await
218            }
219            None => Err(McpError::ResourceNotFound(uri.to_string())),
220        }
221    }
222
223    // ========================================================================
224    // Tool Management
225    // ========================================================================
226
227    /// Add a tool to the server
228    pub async fn add_tool<H>(
229        &self,
230        name: String,
231        description: Option<String>,
232        schema: Value,
233        handler: H,
234    ) -> McpResult<()>
235    where
236        H: ToolHandler + 'static,
237    {
238        let tool_info = ToolInfo {
239            name: name.clone(),
240            description,
241            input_schema: schema,
242        };
243
244        validate_tool_info(&tool_info)?;
245
246        let tool = Tool::new(
247            name.clone(),
248            tool_info.description.clone(),
249            tool_info.input_schema.clone(),
250            handler,
251        );
252
253        {
254            let mut tools = self.tools.write().await;
255            tools.insert(name, tool);
256        }
257
258        self.emit_tools_list_changed().await?;
259
260        Ok(())
261    }
262
263    /// Add a tool with detailed information
264    pub async fn add_tool_detailed<H>(&self, info: ToolInfo, handler: H) -> McpResult<()>
265    where
266        H: ToolHandler + 'static,
267    {
268        validate_tool_info(&info)?;
269
270        let name = info.name.clone();
271        let tool = Tool::new(
272            name.clone(),
273            info.description.clone(),
274            info.input_schema.clone(),
275            handler,
276        );
277
278        {
279            let mut tools = self.tools.write().await;
280            tools.insert(name, tool);
281        }
282
283        self.emit_tools_list_changed().await?;
284
285        Ok(())
286    }
287
288    /// Remove a tool from the server
289    pub async fn remove_tool(&self, name: &str) -> McpResult<bool> {
290        let removed = {
291            let mut tools = self.tools.write().await;
292            tools.remove(name).is_some()
293        };
294
295        if removed {
296            self.emit_tools_list_changed().await?;
297        }
298
299        Ok(removed)
300    }
301
302    /// List all registered tools
303    pub async fn list_tools(&self) -> McpResult<Vec<ToolInfo>> {
304        let tools = self.tools.read().await;
305        Ok(tools.values().map(|t| t.info.clone()).collect())
306    }
307
308    /// Call a tool
309    pub async fn call_tool(
310        &self,
311        name: &str,
312        arguments: Option<HashMap<String, Value>>,
313    ) -> McpResult<ToolResult> {
314        let tools = self.tools.read().await;
315
316        match tools.get(name) {
317            Some(tool) => {
318                if !tool.enabled {
319                    return Err(McpError::ToolNotFound(format!(
320                        "Tool '{}' is disabled",
321                        name
322                    )));
323                }
324
325                let args = arguments.unwrap_or_default();
326                tool.handler.call(args).await
327            }
328            None => Err(McpError::ToolNotFound(name.to_string())),
329        }
330    }
331
332    // ========================================================================
333    // Prompt Management
334    // ========================================================================
335
336    /// Add a prompt to the server
337    pub async fn add_prompt<H>(&self, info: PromptInfo, handler: H) -> McpResult<()>
338    where
339        H: PromptHandler + 'static,
340    {
341        validate_prompt_info(&info)?;
342
343        let name = info.name.clone();
344        let prompt = Prompt::new(info, handler);
345
346        {
347            let mut prompts = self.prompts.write().await;
348            prompts.insert(name, prompt);
349        }
350
351        self.emit_prompts_list_changed().await?;
352
353        Ok(())
354    }
355
356    /// Remove a prompt from the server
357    pub async fn remove_prompt(&self, name: &str) -> McpResult<bool> {
358        let removed = {
359            let mut prompts = self.prompts.write().await;
360            prompts.remove(name).is_some()
361        };
362
363        if removed {
364            self.emit_prompts_list_changed().await?;
365        }
366
367        Ok(removed)
368    }
369
370    /// List all registered prompts
371    pub async fn list_prompts(&self) -> McpResult<Vec<PromptInfo>> {
372        let prompts = self.prompts.read().await;
373        Ok(prompts.values().map(|p| p.info.clone()).collect())
374    }
375
376    /// Get a prompt
377    pub async fn get_prompt(
378        &self,
379        name: &str,
380        arguments: Option<HashMap<String, Value>>,
381    ) -> McpResult<PromptResult> {
382        let prompts = self.prompts.read().await;
383
384        match prompts.get(name) {
385            Some(prompt) => {
386                let args = arguments.unwrap_or_default();
387                prompt.handler.get(args).await
388            }
389            None => Err(McpError::PromptNotFound(name.to_string())),
390        }
391    }
392
393    // ========================================================================
394    // Server Lifecycle
395    // ========================================================================
396
397    /// Start the server with the given transport
398    pub async fn start<T>(&mut self, transport: T) -> McpResult<()>
399    where
400        T: ServerTransport + 'static,
401    {
402        let mut state = self.state.write().await;
403
404        match *state {
405            ServerState::Uninitialized => {
406                *state = ServerState::Initializing;
407            }
408            _ => return Err(McpError::Protocol("Server is already started".to_string())),
409        }
410
411        drop(state);
412
413        // Set up the transport
414        {
415            let mut transport_guard = self.transport.lock().await;
416            *transport_guard = Some(Box::new(transport));
417        }
418
419        // Start the transport
420        {
421            let mut transport_guard = self.transport.lock().await;
422            if let Some(transport) = transport_guard.as_mut() {
423                transport.start().await?;
424            }
425        }
426
427        // Update state to running
428        {
429            let mut state = self.state.write().await;
430            *state = ServerState::Running;
431        }
432
433        Ok(())
434    }
435
436    /// Stop the server
437    pub async fn stop(&self) -> McpResult<()> {
438        let mut state = self.state.write().await;
439
440        match *state {
441            ServerState::Running => {
442                *state = ServerState::Stopping;
443            }
444            ServerState::Stopped => return Ok(()),
445            _ => return Err(McpError::Protocol("Server is not running".to_string())),
446        }
447
448        drop(state);
449
450        // Stop the transport
451        {
452            let mut transport_guard = self.transport.lock().await;
453            if let Some(transport) = transport_guard.as_mut() {
454                transport.stop().await?;
455            }
456        }
457
458        // Update state to stopped
459        {
460            let mut state = self.state.write().await;
461            *state = ServerState::Stopped;
462        }
463
464        Ok(())
465    }
466
467    /// Check if the server is running
468    pub async fn is_running(&self) -> bool {
469        let state = self.state.read().await;
470        matches!(*state, ServerState::Running)
471    }
472
473    /// Get the current server state
474    pub async fn state(&self) -> ServerState {
475        let state = self.state.read().await;
476        state.clone()
477    }
478
479    // ========================================================================
480    // Request Handling
481    // ========================================================================
482
483    /// Handle an incoming JSON-RPC request
484    pub async fn handle_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
485        // Validate the request if configured to do so
486        if self.config.validate_requests {
487            validate_jsonrpc_request(&request)?;
488            validate_mcp_request(&request.method, request.params.as_ref())?;
489        }
490
491        // Route the request to the appropriate handler
492        let result = match request.method.as_str() {
493            methods::INITIALIZE => self.handle_initialize(request.params).await,
494            methods::PING => self.handle_ping().await,
495            methods::TOOLS_LIST => self.handle_tools_list(request.params).await,
496            methods::TOOLS_CALL => self.handle_tools_call(request.params).await,
497            methods::RESOURCES_LIST => self.handle_resources_list(request.params).await,
498            methods::RESOURCES_READ => self.handle_resources_read(request.params).await,
499            methods::RESOURCES_SUBSCRIBE => self.handle_resources_subscribe(request.params).await,
500            methods::RESOURCES_UNSUBSCRIBE => {
501                self.handle_resources_unsubscribe(request.params).await
502            }
503            methods::PROMPTS_LIST => self.handle_prompts_list(request.params).await,
504            methods::PROMPTS_GET => self.handle_prompts_get(request.params).await,
505            methods::LOGGING_SET_LEVEL => self.handle_logging_set_level(request.params).await,
506            _ => Err(McpError::Protocol(format!(
507                "Unknown method: {}",
508                request.method
509            ))),
510        };
511
512        // Convert the result to a JSON-RPC response
513        match result {
514            Ok(result_value) => Ok(JsonRpcResponse::success(request.id, result_value)?),
515            Err(error) => {
516                let (code, message) = match error {
517                    McpError::ToolNotFound(_) => (TOOL_NOT_FOUND, error.to_string()),
518                    McpError::ResourceNotFound(_) => (RESOURCE_NOT_FOUND, error.to_string()),
519                    McpError::PromptNotFound(_) => (PROMPT_NOT_FOUND, error.to_string()),
520                    McpError::Validation(_) => (INVALID_PARAMS, error.to_string()),
521                    _ => (INTERNAL_ERROR, error.to_string()),
522                };
523                Ok(JsonRpcResponse::error(request.id, code, message, None))
524            }
525        }
526    }
527
528    // ========================================================================
529    // Individual Request Handlers
530    // ========================================================================
531
532    async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
533        let params: InitializeParams = match params {
534            Some(p) => serde_json::from_value(p)?,
535            None => {
536                return Err(McpError::Validation(
537                    "Missing initialize parameters".to_string(),
538                ))
539            }
540        };
541
542        validate_initialize_params(&params)?;
543
544        let result = InitializeResult::new(
545            self.info.clone(),
546            self.capabilities.clone(),
547            MCP_PROTOCOL_VERSION.to_string(),
548        );
549
550        Ok(serde_json::to_value(result)?)
551    }
552
553    async fn handle_ping(&self) -> McpResult<Value> {
554        Ok(serde_json::to_value(PingResult {})?)
555    }
556
557    async fn handle_tools_list(&self, params: Option<Value>) -> McpResult<Value> {
558        let _params: ListToolsParams = match params {
559            Some(p) => serde_json::from_value(p)?,
560            None => ListToolsParams::default(),
561        };
562
563        let tools = self.list_tools().await?;
564        let result = ListToolsResult {
565            tools,
566            next_cursor: None, // Pagination support will be added in future versions
567        };
568
569        Ok(serde_json::to_value(result)?)
570    }
571
572    async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
573        let params: CallToolParams = match params {
574            Some(p) => serde_json::from_value(p)?,
575            None => {
576                return Err(McpError::Validation(
577                    "Missing tool call parameters".to_string(),
578                ))
579            }
580        };
581
582        validate_call_tool_params(&params)?;
583
584        let result = self.call_tool(&params.name, params.arguments).await?;
585        Ok(serde_json::to_value(result)?)
586    }
587
588    async fn handle_resources_list(&self, params: Option<Value>) -> McpResult<Value> {
589        let _params: ListResourcesParams = match params {
590            Some(p) => serde_json::from_value(p)?,
591            None => ListResourcesParams::default(),
592        };
593
594        let resources = self.list_resources().await?;
595        let result = ListResourcesResult {
596            resources,
597            next_cursor: None, // Pagination support will be added in future versions
598        };
599
600        Ok(serde_json::to_value(result)?)
601    }
602
603    async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
604        let params: ReadResourceParams = match params {
605            Some(p) => serde_json::from_value(p)?,
606            None => {
607                return Err(McpError::Validation(
608                    "Missing resource read parameters".to_string(),
609                ))
610            }
611        };
612
613        validate_read_resource_params(&params)?;
614
615        let contents = self.read_resource(&params.uri).await?;
616        let result = ReadResourceResult { contents };
617
618        Ok(serde_json::to_value(result)?)
619    }
620
621    async fn handle_resources_subscribe(&self, params: Option<Value>) -> McpResult<Value> {
622        let params: SubscribeResourceParams = match params {
623            Some(p) => serde_json::from_value(p)?,
624            None => {
625                return Err(McpError::Validation(
626                    "Missing resource subscribe parameters".to_string(),
627                ))
628            }
629        };
630
631        // Resource subscriptions functionality planned for future implementation
632        let _uri = params.uri;
633        let result = SubscribeResourceResult {};
634
635        Ok(serde_json::to_value(result)?)
636    }
637
638    async fn handle_resources_unsubscribe(&self, params: Option<Value>) -> McpResult<Value> {
639        let params: UnsubscribeResourceParams = match params {
640            Some(p) => serde_json::from_value(p)?,
641            None => {
642                return Err(McpError::Validation(
643                    "Missing resource unsubscribe parameters".to_string(),
644                ))
645            }
646        };
647
648        // Resource subscriptions functionality planned for future implementation
649        let _uri = params.uri;
650        let result = UnsubscribeResourceResult {};
651
652        Ok(serde_json::to_value(result)?)
653    }
654
655    async fn handle_prompts_list(&self, params: Option<Value>) -> McpResult<Value> {
656        let _params: ListPromptsParams = match params {
657            Some(p) => serde_json::from_value(p)?,
658            None => ListPromptsParams::default(),
659        };
660
661        let prompts = self.list_prompts().await?;
662        let result = ListPromptsResult {
663            prompts,
664            next_cursor: None, // Pagination support will be added in future versions
665        };
666
667        Ok(serde_json::to_value(result)?)
668    }
669
670    async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
671        let params: GetPromptParams = match params {
672            Some(p) => serde_json::from_value(p)?,
673            None => {
674                return Err(McpError::Validation(
675                    "Missing prompt get parameters".to_string(),
676                ))
677            }
678        };
679
680        validate_get_prompt_params(&params)?;
681
682        let result = self.get_prompt(&params.name, params.arguments).await?;
683        Ok(serde_json::to_value(result)?)
684    }
685
686    async fn handle_logging_set_level(&self, params: Option<Value>) -> McpResult<Value> {
687        let _params: SetLoggingLevelParams = match params {
688            Some(p) => serde_json::from_value(p)?,
689            None => {
690                return Err(McpError::Validation(
691                    "Missing logging level parameters".to_string(),
692                ))
693            }
694        };
695
696        // Logging level management feature planned for future implementation
697        let result = SetLoggingLevelResult {};
698        Ok(serde_json::to_value(result)?)
699    }
700
701    // ========================================================================
702    // Notification Helpers
703    // ========================================================================
704
705    async fn emit_resources_list_changed(&self) -> McpResult<()> {
706        let notification = JsonRpcNotification::new(
707            methods::RESOURCES_LIST_CHANGED.to_string(),
708            Some(ResourceListChangedParams {}),
709        )?;
710
711        self.send_notification(notification).await
712    }
713
714    async fn emit_tools_list_changed(&self) -> McpResult<()> {
715        let notification = JsonRpcNotification::new(
716            methods::TOOLS_LIST_CHANGED.to_string(),
717            Some(ToolListChangedParams {}),
718        )?;
719
720        self.send_notification(notification).await
721    }
722
723    async fn emit_prompts_list_changed(&self) -> McpResult<()> {
724        let notification = JsonRpcNotification::new(
725            methods::PROMPTS_LIST_CHANGED.to_string(),
726            Some(PromptListChangedParams {}),
727        )?;
728
729        self.send_notification(notification).await
730    }
731
732    /// Send a notification through the transport
733    async fn send_notification(&self, notification: JsonRpcNotification) -> McpResult<()> {
734        let mut transport_guard = self.transport.lock().await;
735        if let Some(transport) = transport_guard.as_mut() {
736            transport.send_notification(notification).await?;
737        }
738        Ok(())
739    }
740
741    // ========================================================================
742    // Utility Methods
743    // ========================================================================
744
745    #[allow(dead_code)]
746    async fn next_request_id(&self) -> u64 {
747        let mut counter = self.request_counter.lock().await;
748        *counter += 1;
749        *counter
750    }
751}
752
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds the server instance every test in this module starts from.
    fn test_server() -> McpServer {
        McpServer::new("test-server".to_string(), "1.0.0".to_string())
    }

    /// Tool handler that ignores its arguments and replies with a fixed text.
    struct TestToolHandler;

    #[async_trait::async_trait]
    impl ToolHandler for TestToolHandler {
        async fn call(&self, _arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
            Ok(ToolResult {
                content: vec![Content::text("Hello from tool")],
                is_error: None,
            })
        }
    }

    #[tokio::test]
    async fn test_server_creation() {
        let server = test_server();
        let info = server.info();
        assert_eq!(info.name, "test-server");
        assert_eq!(info.version, "1.0.0");
        assert!(!server.is_running().await);
    }

    #[tokio::test]
    async fn test_tool_management() {
        let server = test_server();
        let input_schema = json!({
            "type": "object",
            "properties": { "name": { "type": "string" } }
        });

        server
            .add_tool(
                "test_tool".to_string(),
                Some("A test tool".to_string()),
                input_schema,
                TestToolHandler,
            )
            .await
            .expect("registering the tool should succeed");

        let listed = server
            .list_tools()
            .await
            .expect("listing tools should succeed");
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].name, "test_tool");

        let outcome = server
            .call_tool("test_tool", None)
            .await
            .expect("calling the tool should succeed");
        assert_eq!(outcome.content.len(), 1);
    }

    #[tokio::test]
    async fn test_initialize_request() {
        let server = test_server();
        let client = ClientInfo {
            name: "test-client".to_string(),
            version: "1.0.0".to_string(),
        };
        let params = InitializeParams::new(
            client,
            ClientCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );
        let request = JsonRpcRequest::new(json!(1), methods::INITIALIZE.to_string(), Some(params))
            .expect("building the initialize request should succeed");

        let response = server
            .handle_request(request)
            .await
            .expect("initialize should yield a response");
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }
}