// mcp_protocol_sdk/server/mcp_server.rs

//! MCP server implementation
//!
//! This module provides the main MCP server implementation that handles client connections,
//! manages resources, tools, and prompts, and processes JSON-RPC requests according to
//! the Model Context Protocol specification.

7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12use crate::core::{
13    error::{McpError, McpResult},
14    prompt::{Prompt, PromptHandler},
15    resource::{Resource, ResourceHandler},
16    tool::{Tool, ToolHandler},
17    PromptInfo, ResourceInfo, ToolInfo,
18};
19use crate::protocol::{messages::*, types::*, validation::*};
20use crate::transport::traits::ServerTransport;
21
/// Tunable settings for an MCP server.
///
/// `ServerConfig::default()` yields 100 concurrent requests, a 30 000 ms request
/// timeout, and both request validation and logging switched on.
///
/// NOTE(review): within this file only `validate_requests` is read (by
/// `McpServer::handle_request`); confirm where the other limits are enforced.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Upper bound on the number of requests handled at the same time
    pub max_concurrent_requests: usize,
    /// Per-request time limit, in milliseconds
    pub request_timeout_ms: u64,
    /// Run JSON-RPC and MCP validation on every incoming request
    pub validate_requests: bool,
    /// Emit detailed log output
    pub enable_logging: bool,
}

impl Default for ServerConfig {
    fn default() -> Self {
        let max_concurrent_requests = 100;
        let request_timeout_ms = 30_000;

        ServerConfig {
            max_concurrent_requests,
            request_timeout_ms,
            validate_requests: true,
            enable_logging: true,
        }
    }
}
45
46/// Main MCP server implementation
47pub struct McpServer {
48    /// Server information
49    info: ServerInfo,
50    /// Server capabilities
51    capabilities: ServerCapabilities,
52    /// Server configuration
53    config: ServerConfig,
54    /// Registered resources
55    resources: Arc<RwLock<HashMap<String, Resource>>>,
56    /// Registered tools
57    tools: Arc<RwLock<HashMap<String, Tool>>>,
58    /// Registered prompts
59    prompts: Arc<RwLock<HashMap<String, Prompt>>>,
60    /// Active transport
61    transport: Arc<Mutex<Option<Box<dyn ServerTransport>>>>,
62    /// Server state
63    state: Arc<RwLock<ServerState>>,
64    /// Request ID counter
65    request_counter: Arc<Mutex<u64>>,
66}
67
/// Lifecycle state of an `McpServer`, as reported by `McpServer::state`.
///
/// `start` moves `Uninitialized -> Initializing -> Running`; `stop` moves
/// `Running -> Stopping -> Stopped`.
///
/// This type is `pub` because the public `McpServer::state` method returns it; a
/// private type there is a private-type-in-public-interface error/lint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerState {
    /// Server has not been started yet
    Uninitialized,
    /// `start` is installing and starting the transport
    Initializing,
    /// Server is running and ready to accept requests
    Running,
    /// `stop` is shutting the transport down
    Stopping,
    /// Server has stopped
    Stopped,
}
82
83impl McpServer {
84    /// Create a new MCP server with the given name and version
85    pub fn new(name: String, version: String) -> Self {
86        Self {
87            info: ServerInfo { name, version },
88            capabilities: ServerCapabilities {
89                prompts: Some(PromptsCapability {
90                    list_changed: Some(true),
91                }),
92                resources: Some(ResourcesCapability {
93                    subscribe: Some(true),
94                    list_changed: Some(true),
95                }),
96                tools: Some(ToolsCapability {
97                    list_changed: Some(true),
98                }),
99                sampling: None,
100            },
101            config: ServerConfig::default(),
102            resources: Arc::new(RwLock::new(HashMap::new())),
103            tools: Arc::new(RwLock::new(HashMap::new())),
104            prompts: Arc::new(RwLock::new(HashMap::new())),
105            transport: Arc::new(Mutex::new(None)),
106            state: Arc::new(RwLock::new(ServerState::Uninitialized)),
107            request_counter: Arc::new(Mutex::new(0)),
108        }
109    }
110
111    /// Create a new MCP server with custom configuration
112    pub fn with_config(name: String, version: String, config: ServerConfig) -> Self {
113        let mut server = Self::new(name, version);
114        server.config = config;
115        server
116    }
117
118    /// Set server capabilities
119    pub fn set_capabilities(&mut self, capabilities: ServerCapabilities) {
120        self.capabilities = capabilities;
121    }
122
123    /// Get server information
124    pub fn info(&self) -> &ServerInfo {
125        &self.info
126    }
127
128    /// Get server capabilities
129    pub fn capabilities(&self) -> &ServerCapabilities {
130        &self.capabilities
131    }
132
133    /// Get server configuration
134    pub fn config(&self) -> &ServerConfig {
135        &self.config
136    }
137
138    // ========================================================================
139    // Resource Management
140    // ========================================================================
141
142    /// Add a resource to the server
143    pub async fn add_resource<H>(&self, name: String, uri: String, handler: H) -> McpResult<()>
144    where
145        H: ResourceHandler + 'static,
146    {
147        let resource_info = ResourceInfo {
148            uri: uri.clone(),
149            name: name.clone(),
150            description: None,
151            mime_type: None,
152        };
153
154        validate_resource_info(&resource_info)?;
155
156        let resource = Resource::new(resource_info, handler);
157
158        {
159            let mut resources = self.resources.write().await;
160            resources.insert(uri, resource);
161        }
162
163        // Emit list changed notification if we have an active transport
164        self.emit_resources_list_changed().await?;
165
166        Ok(())
167    }
168
169    /// Add a resource with detailed information
170    pub async fn add_resource_detailed<H>(&self, info: ResourceInfo, handler: H) -> McpResult<()>
171    where
172        H: ResourceHandler + 'static,
173    {
174        validate_resource_info(&info)?;
175
176        let uri = info.uri.clone();
177        let resource = Resource::new(info, handler);
178
179        {
180            let mut resources = self.resources.write().await;
181            resources.insert(uri, resource);
182        }
183
184        self.emit_resources_list_changed().await?;
185
186        Ok(())
187    }
188
189    /// Remove a resource from the server
190    pub async fn remove_resource(&self, uri: &str) -> McpResult<bool> {
191        let removed = {
192            let mut resources = self.resources.write().await;
193            resources.remove(uri).is_some()
194        };
195
196        if removed {
197            self.emit_resources_list_changed().await?;
198        }
199
200        Ok(removed)
201    }
202
203    /// List all registered resources
204    pub async fn list_resources(&self) -> McpResult<Vec<ResourceInfo>> {
205        let resources = self.resources.read().await;
206        Ok(resources.values().map(|r| r.info.clone()).collect())
207    }
208
209    /// Read a resource
210    pub async fn read_resource(&self, uri: &str) -> McpResult<Vec<ResourceContent>> {
211        let resources = self.resources.read().await;
212
213        match resources.get(uri) {
214            Some(resource) => {
215                let params = HashMap::new(); // TODO: Extract params from URI
216                resource.handler.read(uri, &params).await
217            }
218            None => Err(McpError::ResourceNotFound(uri.to_string())),
219        }
220    }
221
222    // ========================================================================
223    // Tool Management
224    // ========================================================================
225
226    /// Add a tool to the server
227    pub async fn add_tool<H>(
228        &self,
229        name: String,
230        description: Option<String>,
231        schema: Value,
232        handler: H,
233    ) -> McpResult<()>
234    where
235        H: ToolHandler + 'static,
236    {
237        let tool_info = ToolInfo {
238            name: name.clone(),
239            description,
240            input_schema: schema,
241        };
242
243        validate_tool_info(&tool_info)?;
244
245        let tool = Tool::new(
246            name.clone(),
247            tool_info.description.clone(),
248            tool_info.input_schema.clone(),
249            handler,
250        );
251
252        {
253            let mut tools = self.tools.write().await;
254            tools.insert(name, tool);
255        }
256
257        self.emit_tools_list_changed().await?;
258
259        Ok(())
260    }
261
262    /// Add a tool with detailed information
263    pub async fn add_tool_detailed<H>(&self, info: ToolInfo, handler: H) -> McpResult<()>
264    where
265        H: ToolHandler + 'static,
266    {
267        validate_tool_info(&info)?;
268
269        let name = info.name.clone();
270        let tool = Tool::new(
271            name.clone(),
272            info.description.clone(),
273            info.input_schema.clone(),
274            handler,
275        );
276
277        {
278            let mut tools = self.tools.write().await;
279            tools.insert(name, tool);
280        }
281
282        self.emit_tools_list_changed().await?;
283
284        Ok(())
285    }
286
287    /// Remove a tool from the server
288    pub async fn remove_tool(&self, name: &str) -> McpResult<bool> {
289        let removed = {
290            let mut tools = self.tools.write().await;
291            tools.remove(name).is_some()
292        };
293
294        if removed {
295            self.emit_tools_list_changed().await?;
296        }
297
298        Ok(removed)
299    }
300
301    /// List all registered tools
302    pub async fn list_tools(&self) -> McpResult<Vec<ToolInfo>> {
303        let tools = self.tools.read().await;
304        Ok(tools.values().map(|t| t.info.clone()).collect())
305    }
306
307    /// Call a tool
308    pub async fn call_tool(
309        &self,
310        name: &str,
311        arguments: Option<HashMap<String, Value>>,
312    ) -> McpResult<ToolResult> {
313        let tools = self.tools.read().await;
314
315        match tools.get(name) {
316            Some(tool) => {
317                if !tool.enabled {
318                    return Err(McpError::ToolNotFound(format!(
319                        "Tool '{}' is disabled",
320                        name
321                    )));
322                }
323
324                let args = arguments.unwrap_or_default();
325                tool.handler.call(args).await
326            }
327            None => Err(McpError::ToolNotFound(name.to_string())),
328        }
329    }
330
331    // ========================================================================
332    // Prompt Management
333    // ========================================================================
334
335    /// Add a prompt to the server
336    pub async fn add_prompt<H>(&self, info: PromptInfo, handler: H) -> McpResult<()>
337    where
338        H: PromptHandler + 'static,
339    {
340        validate_prompt_info(&info)?;
341
342        let name = info.name.clone();
343        let prompt = Prompt::new(info, handler);
344
345        {
346            let mut prompts = self.prompts.write().await;
347            prompts.insert(name, prompt);
348        }
349
350        self.emit_prompts_list_changed().await?;
351
352        Ok(())
353    }
354
355    /// Remove a prompt from the server
356    pub async fn remove_prompt(&self, name: &str) -> McpResult<bool> {
357        let removed = {
358            let mut prompts = self.prompts.write().await;
359            prompts.remove(name).is_some()
360        };
361
362        if removed {
363            self.emit_prompts_list_changed().await?;
364        }
365
366        Ok(removed)
367    }
368
369    /// List all registered prompts
370    pub async fn list_prompts(&self) -> McpResult<Vec<PromptInfo>> {
371        let prompts = self.prompts.read().await;
372        Ok(prompts.values().map(|p| p.info.clone()).collect())
373    }
374
375    /// Get a prompt
376    pub async fn get_prompt(
377        &self,
378        name: &str,
379        arguments: Option<HashMap<String, Value>>,
380    ) -> McpResult<PromptResult> {
381        let prompts = self.prompts.read().await;
382
383        match prompts.get(name) {
384            Some(prompt) => {
385                let args = arguments.unwrap_or_default();
386                prompt.handler.get(args).await
387            }
388            None => Err(McpError::PromptNotFound(name.to_string())),
389        }
390    }
391
392    // ========================================================================
393    // Server Lifecycle
394    // ========================================================================
395
396    /// Start the server with the given transport
397    pub async fn start<T>(&mut self, transport: T) -> McpResult<()>
398    where
399        T: ServerTransport + 'static,
400    {
401        let mut state = self.state.write().await;
402
403        match *state {
404            ServerState::Uninitialized => {
405                *state = ServerState::Initializing;
406            }
407            _ => return Err(McpError::Protocol("Server is already started".to_string())),
408        }
409
410        drop(state);
411
412        // Set up the transport
413        {
414            let mut transport_guard = self.transport.lock().await;
415            *transport_guard = Some(Box::new(transport));
416        }
417
418        // Start the transport
419        {
420            let mut transport_guard = self.transport.lock().await;
421            if let Some(transport) = transport_guard.as_mut() {
422                transport.start().await?;
423            }
424        }
425
426        // Update state to running
427        {
428            let mut state = self.state.write().await;
429            *state = ServerState::Running;
430        }
431
432        Ok(())
433    }
434
435    /// Stop the server
436    pub async fn stop(&self) -> McpResult<()> {
437        let mut state = self.state.write().await;
438
439        match *state {
440            ServerState::Running => {
441                *state = ServerState::Stopping;
442            }
443            ServerState::Stopped => return Ok(()),
444            _ => return Err(McpError::Protocol("Server is not running".to_string())),
445        }
446
447        drop(state);
448
449        // Stop the transport
450        {
451            let mut transport_guard = self.transport.lock().await;
452            if let Some(transport) = transport_guard.as_mut() {
453                transport.stop().await?;
454            }
455        }
456
457        // Update state to stopped
458        {
459            let mut state = self.state.write().await;
460            *state = ServerState::Stopped;
461        }
462
463        Ok(())
464    }
465
466    /// Check if the server is running
467    pub async fn is_running(&self) -> bool {
468        let state = self.state.read().await;
469        matches!(*state, ServerState::Running)
470    }
471
472    /// Get the current server state
473    pub async fn state(&self) -> ServerState {
474        let state = self.state.read().await;
475        state.clone()
476    }
477
478    // ========================================================================
479    // Request Handling
480    // ========================================================================
481
482    /// Handle an incoming JSON-RPC request
483    pub async fn handle_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
484        // Validate the request if configured to do so
485        if self.config.validate_requests {
486            validate_jsonrpc_request(&request)?;
487            validate_mcp_request(&request.method, request.params.as_ref())?;
488        }
489
490        // Route the request to the appropriate handler
491        let result = match request.method.as_str() {
492            methods::INITIALIZE => self.handle_initialize(request.params).await,
493            methods::PING => self.handle_ping().await,
494            methods::TOOLS_LIST => self.handle_tools_list(request.params).await,
495            methods::TOOLS_CALL => self.handle_tools_call(request.params).await,
496            methods::RESOURCES_LIST => self.handle_resources_list(request.params).await,
497            methods::RESOURCES_READ => self.handle_resources_read(request.params).await,
498            methods::RESOURCES_SUBSCRIBE => self.handle_resources_subscribe(request.params).await,
499            methods::RESOURCES_UNSUBSCRIBE => {
500                self.handle_resources_unsubscribe(request.params).await
501            }
502            methods::PROMPTS_LIST => self.handle_prompts_list(request.params).await,
503            methods::PROMPTS_GET => self.handle_prompts_get(request.params).await,
504            methods::LOGGING_SET_LEVEL => self.handle_logging_set_level(request.params).await,
505            _ => Err(McpError::Protocol(format!(
506                "Unknown method: {}",
507                request.method
508            ))),
509        };
510
511        // Convert the result to a JSON-RPC response
512        match result {
513            Ok(result_value) => Ok(JsonRpcResponse::success(request.id, result_value)?),
514            Err(error) => {
515                let (code, message) = match error {
516                    McpError::ToolNotFound(_) => (TOOL_NOT_FOUND, error.to_string()),
517                    McpError::ResourceNotFound(_) => (RESOURCE_NOT_FOUND, error.to_string()),
518                    McpError::PromptNotFound(_) => (PROMPT_NOT_FOUND, error.to_string()),
519                    McpError::Validation(_) => (INVALID_PARAMS, error.to_string()),
520                    _ => (INTERNAL_ERROR, error.to_string()),
521                };
522                Ok(JsonRpcResponse::error(request.id, code, message, None))
523            }
524        }
525    }
526
527    // ========================================================================
528    // Individual Request Handlers
529    // ========================================================================
530
531    async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
532        let params: InitializeParams = match params {
533            Some(p) => serde_json::from_value(p)?,
534            None => {
535                return Err(McpError::Validation(
536                    "Missing initialize parameters".to_string(),
537                ))
538            }
539        };
540
541        validate_initialize_params(&params)?;
542
543        let result = InitializeResult::new(
544            self.info.clone(),
545            self.capabilities.clone(),
546            MCP_PROTOCOL_VERSION.to_string(),
547        );
548
549        Ok(serde_json::to_value(result)?)
550    }
551
552    async fn handle_ping(&self) -> McpResult<Value> {
553        Ok(serde_json::to_value(PingResult {})?)
554    }
555
556    async fn handle_tools_list(&self, params: Option<Value>) -> McpResult<Value> {
557        let _params: ListToolsParams = match params {
558            Some(p) => serde_json::from_value(p)?,
559            None => ListToolsParams::default(),
560        };
561
562        let tools = self.list_tools().await?;
563        let result = ListToolsResult {
564            tools,
565            next_cursor: None, // TODO: Implement pagination
566        };
567
568        Ok(serde_json::to_value(result)?)
569    }
570
571    async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
572        let params: CallToolParams = match params {
573            Some(p) => serde_json::from_value(p)?,
574            None => {
575                return Err(McpError::Validation(
576                    "Missing tool call parameters".to_string(),
577                ))
578            }
579        };
580
581        validate_call_tool_params(&params)?;
582
583        let result = self.call_tool(&params.name, params.arguments).await?;
584        Ok(serde_json::to_value(result)?)
585    }
586
587    async fn handle_resources_list(&self, params: Option<Value>) -> McpResult<Value> {
588        let _params: ListResourcesParams = match params {
589            Some(p) => serde_json::from_value(p)?,
590            None => ListResourcesParams::default(),
591        };
592
593        let resources = self.list_resources().await?;
594        let result = ListResourcesResult {
595            resources,
596            next_cursor: None, // TODO: Implement pagination
597        };
598
599        Ok(serde_json::to_value(result)?)
600    }
601
602    async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
603        let params: ReadResourceParams = match params {
604            Some(p) => serde_json::from_value(p)?,
605            None => {
606                return Err(McpError::Validation(
607                    "Missing resource read parameters".to_string(),
608                ))
609            }
610        };
611
612        validate_read_resource_params(&params)?;
613
614        let contents = self.read_resource(&params.uri).await?;
615        let result = ReadResourceResult { contents };
616
617        Ok(serde_json::to_value(result)?)
618    }
619
620    async fn handle_resources_subscribe(&self, params: Option<Value>) -> McpResult<Value> {
621        let params: SubscribeResourceParams = match params {
622            Some(p) => serde_json::from_value(p)?,
623            None => {
624                return Err(McpError::Validation(
625                    "Missing resource subscribe parameters".to_string(),
626                ))
627            }
628        };
629
630        // TODO: Implement resource subscriptions
631        let _uri = params.uri;
632        let result = SubscribeResourceResult {};
633
634        Ok(serde_json::to_value(result)?)
635    }
636
637    async fn handle_resources_unsubscribe(&self, params: Option<Value>) -> McpResult<Value> {
638        let params: UnsubscribeResourceParams = match params {
639            Some(p) => serde_json::from_value(p)?,
640            None => {
641                return Err(McpError::Validation(
642                    "Missing resource unsubscribe parameters".to_string(),
643                ))
644            }
645        };
646
647        // TODO: Implement resource subscriptions
648        let _uri = params.uri;
649        let result = UnsubscribeResourceResult {};
650
651        Ok(serde_json::to_value(result)?)
652    }
653
654    async fn handle_prompts_list(&self, params: Option<Value>) -> McpResult<Value> {
655        let _params: ListPromptsParams = match params {
656            Some(p) => serde_json::from_value(p)?,
657            None => ListPromptsParams::default(),
658        };
659
660        let prompts = self.list_prompts().await?;
661        let result = ListPromptsResult {
662            prompts,
663            next_cursor: None, // TODO: Implement pagination
664        };
665
666        Ok(serde_json::to_value(result)?)
667    }
668
669    async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
670        let params: GetPromptParams = match params {
671            Some(p) => serde_json::from_value(p)?,
672            None => {
673                return Err(McpError::Validation(
674                    "Missing prompt get parameters".to_string(),
675                ))
676            }
677        };
678
679        validate_get_prompt_params(&params)?;
680
681        let result = self.get_prompt(&params.name, params.arguments).await?;
682        Ok(serde_json::to_value(result)?)
683    }
684
685    async fn handle_logging_set_level(&self, params: Option<Value>) -> McpResult<Value> {
686        let _params: SetLoggingLevelParams = match params {
687            Some(p) => serde_json::from_value(p)?,
688            None => {
689                return Err(McpError::Validation(
690                    "Missing logging level parameters".to_string(),
691                ))
692            }
693        };
694
695        // TODO: Implement logging level management
696        let result = SetLoggingLevelResult {};
697        Ok(serde_json::to_value(result)?)
698    }
699
700    // ========================================================================
701    // Notification Helpers
702    // ========================================================================
703
704    async fn emit_resources_list_changed(&self) -> McpResult<()> {
705        let notification = JsonRpcNotification::new(
706            methods::RESOURCES_LIST_CHANGED.to_string(),
707            Some(ResourceListChangedParams {}),
708        )?;
709
710        self.send_notification(notification).await
711    }
712
713    async fn emit_tools_list_changed(&self) -> McpResult<()> {
714        let notification = JsonRpcNotification::new(
715            methods::TOOLS_LIST_CHANGED.to_string(),
716            Some(ToolListChangedParams {}),
717        )?;
718
719        self.send_notification(notification).await
720    }
721
722    async fn emit_prompts_list_changed(&self) -> McpResult<()> {
723        let notification = JsonRpcNotification::new(
724            methods::PROMPTS_LIST_CHANGED.to_string(),
725            Some(PromptListChangedParams {}),
726        )?;
727
728        self.send_notification(notification).await
729    }
730
731    /// Send a notification through the transport
732    async fn send_notification(&self, notification: JsonRpcNotification) -> McpResult<()> {
733        let mut transport_guard = self.transport.lock().await;
734        if let Some(transport) = transport_guard.as_mut() {
735            transport.send_notification(notification).await?;
736        }
737        Ok(())
738    }
739
740    // ========================================================================
741    // Utility Methods
742    // ========================================================================
743
744    async fn next_request_id(&self) -> u64 {
745        let mut counter = self.request_counter.lock().await;
746        *counter += 1;
747        *counter
748    }
749}
750
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fresh, unstarted server shared by the tests in this module.
    fn fresh_server() -> McpServer {
        McpServer::new(String::from("test-server"), String::from("1.0.0"))
    }

    /// Tool handler that ignores its arguments and answers with one text item.
    struct TestToolHandler;

    #[async_trait::async_trait]
    impl ToolHandler for TestToolHandler {
        async fn call(&self, _arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
            let content = vec![Content::text("Hello from tool")];
            Ok(ToolResult {
                content,
                is_error: None,
            })
        }
    }

    #[tokio::test]
    async fn test_server_creation() {
        let server = fresh_server();
        let info = server.info();

        assert_eq!(info.name, "test-server");
        assert_eq!(info.version, "1.0.0");
        assert!(!server.is_running().await);
    }

    #[tokio::test]
    async fn test_tool_management() {
        let server = fresh_server();
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        server
            .add_tool(
                String::from("test_tool"),
                Some(String::from("A test tool")),
                schema,
                TestToolHandler,
            )
            .await
            .unwrap();

        let listed = server.list_tools().await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].name, "test_tool");

        let outcome = server.call_tool("test_tool", None).await.unwrap();
        assert_eq!(outcome.content.len(), 1);
    }

    #[tokio::test]
    async fn test_initialize_request() {
        let server = fresh_server();

        let client = ClientInfo {
            name: String::from("test-client"),
            version: String::from("1.0.0"),
        };
        let init_params = InitializeParams::new(
            client,
            ClientCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );
        let request =
            JsonRpcRequest::new(json!(1), methods::INITIALIZE.to_string(), Some(init_params))
                .unwrap();

        let response = server.handle_request(request).await.unwrap();
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }
}