codeprism_mcp/
server.rs

1//! MCP Server implementation
2//!
3//! This module implements the main MCP server that handles the protocol lifecycle,
4//! request routing, and integration with repository components.
5
6use anyhow::Result;
7use serde_json::Value;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{debug, info, warn};
11
12use crate::{
13    prompts::{GetPromptParams, ListPromptsParams, PromptManager},
14    protocol::{
15        ClientInfo, InitializeParams, InitializeResult, JsonRpcError, JsonRpcNotification,
16        JsonRpcRequest, JsonRpcResponse, ServerInfo,
17    },
18    resources::{ListResourcesParams, ReadResourceParams, ResourceManager},
19    tools::{CallToolParams, ListToolsParams, ToolRegistry},
20    transport::{StdioTransport, Transport, TransportMessage},
21    CodePrismMcpServer,
22};
23
/// Lifecycle state of an [`McpServer`].
///
/// A freshly constructed server is `Uninitialized`; it moves to `Ready` when
/// the client reports that initialization is complete.
///
/// The enum is fieldless, so it is `Copy` and can be compared, hashed and
/// passed around by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServerState {
    /// Server is not initialized
    Uninitialized,
    /// Server is initialized and ready
    Ready,
    /// Server is shutting down
    Shutdown,
}
34
35/// Main MCP Server implementation
36pub struct McpServer {
37    /// Current server state
38    state: ServerState,
39    /// MCP protocol version
40    protocol_version: String,
41    /// Server information
42    server_info: ServerInfo,
43    /// Client information (set during initialization)
44    client_info: Option<ClientInfo>,
45    /// Core CodePrism server instance
46    codeprism_server: Arc<RwLock<CodePrismMcpServer>>,
47    /// Resource manager
48    resource_manager: ResourceManager,
49    /// Tool registry
50    tool_registry: ToolRegistry,
51    /// Prompt manager
52    prompt_manager: PromptManager,
53}
54
55impl McpServer {
56    /// Create a new MCP server
57    pub fn new() -> Result<Self> {
58        let codeprism_server = Arc::new(RwLock::new(CodePrismMcpServer::new()?));
59
60        let resource_manager = ResourceManager::new(codeprism_server.clone());
61        let tool_registry = ToolRegistry::new(codeprism_server.clone());
62        let prompt_manager = PromptManager::new(codeprism_server.clone());
63
64        Ok(Self {
65            state: ServerState::Uninitialized,
66            protocol_version: "2024-11-05".to_string(),
67            server_info: ServerInfo {
68                name: "codeprism-mcp".to_string(),
69                version: "0.1.0".to_string(),
70            },
71            client_info: None,
72            codeprism_server,
73            resource_manager,
74            tool_registry,
75            prompt_manager,
76        })
77    }
78
79    /// Create a new MCP server with custom configuration
80    pub fn new_with_config(
81        memory_limit_mb: usize,
82        batch_size: usize,
83        max_file_size_mb: usize,
84        disable_memory_limit: bool,
85        exclude_dirs: Vec<String>,
86        include_extensions: Option<Vec<String>>,
87        dependency_mode: Option<String>,
88    ) -> Result<Self> {
89        let codeprism_server = Arc::new(RwLock::new(CodePrismMcpServer::new_with_config(
90            memory_limit_mb,
91            batch_size,
92            max_file_size_mb,
93            disable_memory_limit,
94            exclude_dirs,
95            include_extensions,
96            dependency_mode,
97        )?));
98
99        let resource_manager = ResourceManager::new(codeprism_server.clone());
100        let tool_registry = ToolRegistry::new(codeprism_server.clone());
101        let prompt_manager = PromptManager::new(codeprism_server.clone());
102
103        Ok(Self {
104            state: ServerState::Uninitialized,
105            protocol_version: "2024-11-05".to_string(),
106            server_info: ServerInfo {
107                name: "codeprism-mcp".to_string(),
108                version: "0.1.0".to_string(),
109            },
110            client_info: None,
111            codeprism_server,
112            resource_manager,
113            tool_registry,
114            prompt_manager,
115        })
116    }
117
118    /// Initialize with repository path
119    pub async fn initialize_with_repository<P: AsRef<std::path::Path>>(
120        &self,
121        path: P,
122    ) -> Result<()> {
123        let mut server = self.codeprism_server.write().await;
124        server.initialize_with_repository(path).await
125    }
126
127    /// Run the MCP server with stdio transport
128    pub async fn run_stdio(self) -> Result<()> {
129        info!("Starting CodePrism MCP server with stdio transport");
130
131        let mut transport = StdioTransport::new();
132        transport.start().await?;
133
134        self.run_with_transport(transport).await
135    }
136
137    /// Run the MCP server with a custom transport
138    pub async fn run_with_transport<T: Transport>(mut self, mut transport: T) -> Result<()> {
139        info!("Starting CodePrism MCP server");
140
141        loop {
142            match transport.receive().await? {
143                Some(message) => {
144                    if let Some(response) = self.handle_message(message).await? {
145                        transport.send(response).await?;
146                    }
147                }
148                None => {
149                    debug!("Transport closed, shutting down server");
150                    break;
151                }
152            }
153        }
154
155        transport.close().await?;
156        info!("Prism MCP server stopped");
157        Ok(())
158    }
159
160    /// Handle an incoming message
161    async fn handle_message(
162        &mut self,
163        message: TransportMessage,
164    ) -> Result<Option<TransportMessage>> {
165        match message {
166            TransportMessage::Request(request) => {
167                let response = self.handle_request(request).await;
168                Ok(Some(TransportMessage::Response(response)))
169            }
170            TransportMessage::Notification(notification) => {
171                self.handle_notification(notification).await?;
172                Ok(None) // Notifications don't get responses
173            }
174            TransportMessage::Response(_) => {
175                warn!("Received unexpected response message");
176                Ok(None)
177            }
178        }
179    }
180
181    /// Handle a JSON-RPC request
182    async fn handle_request(&mut self, request: JsonRpcRequest) -> JsonRpcResponse {
183        debug!(
184            "Handling request: method={}, id={:?}",
185            request.method, request.id
186        );
187
188        let result = match request.method.as_str() {
189            "initialize" => self.handle_initialize(request.params).await,
190            "resources/list" => self.handle_resources_list(request.params).await,
191            "resources/read" => self.handle_resources_read(request.params).await,
192            "tools/list" => self.handle_tools_list(request.params).await,
193            "tools/call" => self.handle_tools_call(request.params).await,
194            "prompts/list" => self.handle_prompts_list(request.params).await,
195            "prompts/get" => self.handle_prompts_get(request.params).await,
196            _ => Err(JsonRpcError::method_not_found(&request.method)),
197        };
198
199        match result {
200            Ok(result) => JsonRpcResponse::success(request.id, result),
201            Err(error) => JsonRpcResponse::error(request.id, error),
202        }
203    }
204
205    /// Handle a JSON-RPC notification
206    async fn handle_notification(&mut self, notification: JsonRpcNotification) -> Result<()> {
207        debug!("Handling notification: method={}", notification.method);
208
209        match notification.method.as_str() {
210            "initialized" => {
211                info!("Client reported initialization complete");
212                self.state = ServerState::Ready;
213            }
214            "notifications/cancelled" => {
215                debug!("Received cancellation notification");
216                // TODO: Handle request cancellation
217            }
218            _ => {
219                warn!("Unknown notification method: {}", notification.method);
220            }
221        }
222
223        Ok(())
224    }
225
226    /// Handle initialize request
227    async fn handle_initialize(&mut self, params: Option<Value>) -> Result<Value, JsonRpcError> {
228        let params: InitializeParams = params
229            .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
230            .try_into_type()
231            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
232
233        info!(
234            "Initializing MCP server with client: {} v{}",
235            params.client_info.name, params.client_info.version
236        );
237
238        // Store client info
239        self.client_info = Some(params.client_info);
240
241        // Check protocol version compatibility
242        if params.protocol_version != self.protocol_version {
243            warn!(
244                "Protocol version mismatch: client={}, server={}",
245                params.protocol_version, self.protocol_version
246            );
247        }
248
249        // Create initialize result
250        let server = self.codeprism_server.read().await;
251        let result = InitializeResult {
252            protocol_version: self.protocol_version.clone(),
253            capabilities: server.capabilities().clone(),
254            server_info: self.server_info.clone(),
255        };
256
257        serde_json::to_value(result)
258            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
259    }
260
261    /// Handle resources/list request
262    async fn handle_resources_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
263        let params = params
264            .map(serde_json::from_value)
265            .transpose()
266            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
267            .unwrap_or(ListResourcesParams { cursor: None });
268
269        let result = self
270            .resource_manager
271            .list_resources(params)
272            .await
273            .map_err(|e| JsonRpcError::internal_error(format!("Resource list error: {}", e)))?;
274
275        serde_json::to_value(result)
276            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
277    }
278
279    /// Handle resources/read request
280    async fn handle_resources_read(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
281        let params: ReadResourceParams = params
282            .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
283            .try_into_type()
284            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
285
286        let result = self
287            .resource_manager
288            .read_resource(params)
289            .await
290            .map_err(|e| JsonRpcError::internal_error(format!("Resource read error: {}", e)))?;
291
292        serde_json::to_value(result)
293            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
294    }
295
296    /// Handle tools/list request
297    async fn handle_tools_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
298        let params = params
299            .map(serde_json::from_value)
300            .transpose()
301            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
302            .unwrap_or(ListToolsParams { cursor: None });
303
304        let result = self
305            .tool_registry
306            .list_tools(params)
307            .await
308            .map_err(|e| JsonRpcError::internal_error(format!("Tool list error: {}", e)))?;
309
310        serde_json::to_value(result)
311            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
312    }
313
314    /// Handle tools/call request
315    async fn handle_tools_call(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
316        let params: CallToolParams = params
317            .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
318            .try_into_type()
319            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
320
321        let result = self
322            .tool_registry
323            .call_tool(params)
324            .await
325            .map_err(|e| JsonRpcError::internal_error(format!("Tool call error: {}", e)))?;
326
327        serde_json::to_value(result)
328            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
329    }
330
331    /// Handle prompts/list request
332    async fn handle_prompts_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
333        let params = params
334            .map(serde_json::from_value)
335            .transpose()
336            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
337            .unwrap_or(ListPromptsParams { cursor: None });
338
339        let result = self
340            .prompt_manager
341            .list_prompts(params)
342            .await
343            .map_err(|e| JsonRpcError::internal_error(format!("Prompt list error: {}", e)))?;
344
345        serde_json::to_value(result)
346            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
347    }
348
349    /// Handle prompts/get request
350    async fn handle_prompts_get(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
351        let params: GetPromptParams = params
352            .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
353            .try_into_type()
354            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
355
356        let result = self
357            .prompt_manager
358            .get_prompt(params)
359            .await
360            .map_err(|e| JsonRpcError::internal_error(format!("Prompt get error: {}", e)))?;
361
362        serde_json::to_value(result)
363            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
364    }
365
366    /// Get current server state
367    pub fn state(&self) -> ServerState {
368        self.state.clone()
369    }
370
371    /// Get server info
372    pub fn server_info(&self) -> &ServerInfo {
373        &self.server_info
374    }
375
376    /// Get client info (if initialized)
377    pub fn client_info(&self) -> Option<&ClientInfo> {
378        self.client_info.as_ref()
379    }
380}
381
382impl Default for McpServer {
383    fn default() -> Self {
384        Self::new().expect("Failed to create default MCP server")
385    }
386}
387
388// Helper trait for converting JSON values to types
389trait TryIntoType<T> {
390    fn try_into_type(self) -> Result<T, serde_json::Error>;
391}
392
393impl<T> TryIntoType<T> for Value
394where
395    T: serde::de::DeserializeOwned,
396{
397    fn try_into_type(self) -> Result<T, serde_json::Error> {
398        serde_json::from_value(self)
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ClientCapabilities;
    use tempfile::TempDir;

    /// Python fixture: a small application module with classes and functions.
    const APP_PY: &str = r#"
"""Main application module."""

import logging
from typing import List, Optional, Dict, Any
from dataclasses import dataclass

@dataclass
class Config:
    """Application configuration."""
    database_url: str
    api_key: str
    debug: bool = False

class ApplicationService:
    """Main application service."""
    
    def __init__(self, config: Config):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self._users: Dict[str, 'User'] = {}
    
    def create_user(self, username: str, email: str) -> 'User':
        """Create a new user."""
        if username in self._users:
            raise ValueError(f"User {username} already exists")
        
        user = User(username=username, email=email)
        self._users[username] = user
        self.logger.info(f"Created user: {username}")
        return user
    
    def get_user(self, username: str) -> Optional['User']:
        """Get a user by username."""
        return self._users.get(username)
    
    def list_users(self) -> List['User']:
        """List all users."""
        return list(self._users.values())
    
    def delete_user(self, username: str) -> bool:
        """Delete a user."""
        if username in self._users:
            del self._users[username]
            self.logger.info(f"Deleted user: {username}")
            return True
        return False

class User:
    """User model."""
    
    def __init__(self, username: str, email: str):
        self.username = username
        self.email = email
        self.created_at = None  # Would be datetime in real app
        self.is_active = True
    
    def deactivate(self) -> None:
        """Deactivate the user."""
        self.is_active = False
    
    def activate(self) -> None:
        """Activate the user."""
        self.is_active = True
    
    def to_dict(self) -> Dict[str, Any]:
        """Convert user to dictionary."""
        return {
            'username': self.username,
            'email': self.email,
            'is_active': self.is_active
        }

def main():
    """Main application entry point."""
    config = Config(
        database_url="postgresql://localhost/myapp",
        api_key="secret-key"
    )
    
    app = ApplicationService(config)
    
    # Create some sample users
    app.create_user("alice", "alice@example.com")
    app.create_user("bob", "bob@example.com")
    
    # List users
    users = app.list_users()
    print(f"Created {len(users)} users")

if __name__ == "__main__":
    main()
"#;

    /// Python fixture: a utility module with free functions and a static-method class.
    const UTILS_PY: &str = r#"
"""Utility functions for the application."""

import re
import hashlib
from typing import Optional, Union, List
from datetime import datetime, timedelta

# Constants
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
PASSWORD_MIN_LENGTH = 8

def validate_email(email: str) -> bool:
    """Validate email address format."""
    return bool(EMAIL_REGEX.match(email))

def validate_password(password: str) -> bool:
    """Validate password strength."""
    if len(password) < PASSWORD_MIN_LENGTH:
        return False
    
    has_upper = any(c.isupper() for c in password)
    has_lower = any(c.islower() for c in password)
    has_digit = any(c.isdigit() for c in password)
    
    return has_upper and has_lower and has_digit

def hash_password(password: str, salt: Optional[str] = None) -> str:
    """Hash a password with salt."""
    if salt is None:
        salt = "default_salt"  # In real app, use random salt
    
    combined = f"{password}{salt}"
    return hashlib.sha256(combined.encode()).hexdigest()

def generate_token(length: int = 32) -> str:
    """Generate a random token."""
    import secrets
    return secrets.token_hex(length)

class DateUtils:
    """Utility class for date operations."""
    
    @staticmethod
    def now() -> datetime:
        """Get current datetime."""
        return datetime.now()
    
    @staticmethod
    def add_days(date: datetime, days: int) -> datetime:
        """Add days to a date."""
        return date + timedelta(days=days)
    
    @staticmethod
    def format_date(date: datetime, format_str: str = "%Y-%m-%d") -> str:
        """Format a date as string."""
        return date.strftime(format_str)

def cleanup_string(text: str) -> str:
    """Clean up a string by removing extra whitespace."""
    return re.sub(r'\s+', ' ', text.strip())

def parse_config_value(value: str) -> Union[str, int, bool]:
    """Parse a configuration value to appropriate type."""
    # Try boolean
    if value.lower() in ('true', 'false'):
        return value.lower() == 'true'
    
    # Try integer
    try:
        return int(value)
    except ValueError:
        pass
    
    # Return as string
    return value
"#;

    /// `initialize` params shared by the tests that perform the handshake.
    fn test_initialize_params() -> InitializeParams {
        InitializeParams {
            protocol_version: "2024-11-05".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: ClientInfo {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
            },
        }
    }

    /// Create a server indexed over a temporary repository holding two
    /// Python files.
    ///
    /// The [`TempDir`] is returned together with the server. Callers must keep
    /// it bound for as long as they use the server. When dropped, it deletes
    /// the directory instead of leaking it on disk.
    async fn create_test_server_with_repository() -> (McpServer, TempDir) {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let repo_path = temp_dir.path();

        std::fs::write(repo_path.join("app.py"), APP_PY).unwrap();
        std::fs::write(repo_path.join("utils.py"), UTILS_PY).unwrap();

        let server = McpServer::new_with_config(
            2048,  // memory_limit_mb
            20,    // batch_size
            5,     // max_file_size_mb
            false, // disable_memory_limit
            vec!["__pycache__".to_string(), ".pytest_cache".to_string()],
            Some(vec!["py".to_string()]),
            Some("exclude".to_string()),
        )
        .expect("Failed to create MCP server");

        server
            .initialize_with_repository(repo_path)
            .await
            .expect("Failed to initialize repository");

        (server, temp_dir)
    }

    #[tokio::test]
    async fn test_mcp_server_creation() {
        let server = McpServer::new().expect("Failed to create MCP server");
        assert_eq!(server.state(), ServerState::Uninitialized);
        assert_eq!(server.server_info().name, "codeprism-mcp");
        assert_eq!(server.server_info().version, "0.1.0");
    }

    #[tokio::test]
    async fn test_initialize_request() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        let params_value = serde_json::to_value(test_initialize_params()).unwrap();
        let result = server.handle_initialize(Some(params_value)).await;

        assert!(result.is_ok());
        assert!(server.client_info().is_some());
        assert_eq!(server.client_info().unwrap().name, "test-client");
    }

    #[test]
    fn test_server_states() {
        assert_eq!(ServerState::Uninitialized, ServerState::Uninitialized);
        assert_ne!(ServerState::Uninitialized, ServerState::Ready);
        assert_ne!(ServerState::Ready, ServerState::Shutdown);
    }

    #[tokio::test]
    async fn test_server_with_repository_initialization() {
        let (server, _repo_dir) = create_test_server_with_repository().await;

        // Indexing a repository does not perform the MCP handshake.
        assert_eq!(server.state(), ServerState::Uninitialized);

        // Should have repository configured
        let codeprism_server = server.codeprism_server.read().await;
        assert!(codeprism_server.repository_path().is_some());
    }

    #[tokio::test]
    async fn test_full_mcp_workflow() {
        let (mut server, _repo_dir) = create_test_server_with_repository().await;

        // 1. Initialize the MCP server
        let init_result = server
            .handle_initialize(Some(serde_json::to_value(test_initialize_params()).unwrap()))
            .await;
        assert!(init_result.is_ok());

        // Server should have client info
        assert!(server.client_info().is_some());

        // 2. Test resources/list
        let resources_result = server.handle_resources_list(None).await;
        assert!(resources_result.is_ok());

        let resources_value = resources_result.unwrap();
        let resources: crate::resources::ListResourcesResult =
            serde_json::from_value(resources_value).unwrap();
        assert!(!resources.resources.is_empty());

        // Should have various resource types
        let uris: Vec<String> = resources.resources.iter().map(|r| r.uri.clone()).collect();
        assert!(uris.iter().any(|uri| uri == "codeprism://repository/stats"));
        assert!(uris.iter().any(|uri| uri == "codeprism://graph/repository"));
        assert!(uris.iter().any(|uri| uri.contains("app.py")));

        // 3. Test resources/read
        let read_params = crate::resources::ReadResourceParams {
            uri: "codeprism://repository/stats".to_string(),
        };
        let read_result = server
            .handle_resources_read(Some(serde_json::to_value(read_params).unwrap()))
            .await;
        assert!(read_result.is_ok());

        // 4. Test tools/list
        let tools_result = server.handle_tools_list(None).await;
        assert!(tools_result.is_ok());

        let tools_value = tools_result.unwrap();
        let tools: crate::tools::ListToolsResult = serde_json::from_value(tools_value).unwrap();
        assert_eq!(tools.tools.len(), 23); // All 23 tools should be available including all implemented tools

        // 5. Test tools/call with repository_stats
        let tool_params = crate::tools::CallToolParams {
            name: "repository_stats".to_string(),
            arguments: Some(serde_json::json!({})),
        };
        let tool_result = server
            .handle_tools_call(Some(serde_json::to_value(tool_params).unwrap()))
            .await;
        assert!(tool_result.is_ok());

        // 6. Test prompts/list
        let prompts_result = server.handle_prompts_list(None).await;
        assert!(prompts_result.is_ok());

        let prompts_value = prompts_result.unwrap();
        let prompts: crate::prompts::ListPromptsResult =
            serde_json::from_value(prompts_value).unwrap();
        assert_eq!(prompts.prompts.len(), 16); // All 16 prompts should be available (original 8 + 8 new for large codebase understanding)

        // 7. Test prompts/get
        let prompt_params = crate::prompts::GetPromptParams {
            name: "repository_overview".to_string(),
            arguments: Some(serde_json::Map::from_iter([(
                "focus_area".to_string(),
                serde_json::Value::String("architecture".to_string()),
            )])),
        };
        let prompt_result = server
            .handle_prompts_get(Some(serde_json::to_value(prompt_params).unwrap()))
            .await;
        assert!(prompt_result.is_ok());
    }

    #[tokio::test]
    async fn test_request_handling_errors() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        // Test invalid method
        let invalid_request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: serde_json::Value::Number(1.into()),
            method: "invalid_method".to_string(),
            params: None,
        };

        let response = server.handle_request(invalid_request).await;
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, -32601); // Method not found

        // Test missing required parameters
        let missing_params_request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: serde_json::Value::Number(2.into()),
            method: "resources/read".to_string(),
            params: None, // Missing required uri parameter
        };

        let response = server.handle_request(missing_params_request).await;
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, -32602); // Invalid params
    }

    #[tokio::test]
    async fn test_notification_handling() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        // Test initialized notification
        let initialized_notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "initialized".to_string(),
            params: None,
        };

        assert_eq!(server.state(), ServerState::Uninitialized);

        let result = server.handle_notification(initialized_notification).await;
        assert!(result.is_ok());
        assert_eq!(server.state(), ServerState::Ready);

        // Test unknown notification
        let unknown_notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "unknown_notification".to_string(),
            params: None,
        };

        let result = server.handle_notification(unknown_notification).await;
        assert!(result.is_ok()); // Should not fail, just log warning
    }

    #[tokio::test]
    async fn test_message_handling() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        // Test request message handling
        let request_message = crate::transport::TransportMessage::Request(JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: serde_json::Value::Number(1.into()),
            method: "initialize".to_string(),
            params: Some(serde_json::json!({
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "test-client",
                    "version": "1.0.0"
                }
            })),
        });

        let response = server.handle_message(request_message).await;
        assert!(response.is_ok());
        assert!(response.unwrap().is_some()); // Should return a response

        // Test notification message handling
        let notification_message =
            crate::transport::TransportMessage::Notification(JsonRpcNotification {
                jsonrpc: "2.0".to_string(),
                method: "initialized".to_string(),
                params: None,
            });

        let response = server.handle_message(notification_message).await;
        assert!(response.is_ok());
        assert!(response.unwrap().is_none()); // Notifications don't return responses
    }

    #[tokio::test]
    async fn test_server_capabilities_validation() {
        let (server, _repo_dir) = create_test_server_with_repository().await;
        let codeprism_server = server.codeprism_server.read().await;
        let capabilities = codeprism_server.capabilities();

        // Verify all required capabilities are present
        assert!(capabilities.resources.is_some());
        assert!(capabilities.tools.is_some());
        assert!(capabilities.prompts.is_some());

        // Verify resource capabilities
        let resource_caps = capabilities.resources.as_ref().unwrap();
        assert_eq!(resource_caps.subscribe, Some(true));
        assert_eq!(resource_caps.list_changed, Some(true));

        // Verify tool capabilities
        let tool_caps = capabilities.tools.as_ref().unwrap();
        assert_eq!(tool_caps.list_changed, Some(true));

        // Verify prompt capabilities
        let prompt_caps = capabilities.prompts.as_ref().unwrap();
        assert_eq!(prompt_caps.list_changed, Some(false));
    }

    #[tokio::test]
    async fn test_concurrent_requests() {
        use std::sync::Arc;
        use tokio::sync::RwLock;

        let (server, _repo_dir) = create_test_server_with_repository().await;
        let server = Arc::new(RwLock::new(server));

        // Initialize the server first (needs `&mut`, hence the write lock).
        {
            let mut server_lock = server.write().await;
            server_lock
                .handle_initialize(Some(serde_json::to_value(test_initialize_params()).unwrap()))
                .await
                .unwrap();
        }

        // The list handlers take `&self`. Each task takes a shared read lock so
        // the tasks can overlap, instead of serializing behind a write lock.
        let handles: Vec<_> = (0..5)
            .map(|i| {
                let server_clone = Arc::clone(&server);
                tokio::spawn(async move {
                    let server_lock = server_clone.read().await;

                    let resources_result = server_lock.handle_resources_list(None).await;
                    assert!(resources_result.is_ok());

                    let tools_result = server_lock.handle_tools_list(None).await;
                    assert!(tools_result.is_ok());

                    i // Return the task number
                })
            })
            .collect();

        // Every task must finish without panicking and report its own number.
        for (expected, handle) in handles.into_iter().enumerate() {
            let task_number = handle.await.expect("concurrent request task panicked");
            assert_eq!(task_number, expected);
        }
    }

    #[test]
    fn test_server_info_serialization() {
        let server_info = ServerInfo {
            name: "test-server".to_string(),
            version: "1.0.0".to_string(),
        };

        let json = serde_json::to_string(&server_info).unwrap();
        let deserialized: ServerInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(server_info.name, deserialized.name);
        assert_eq!(server_info.version, deserialized.version);
    }

    #[test]
    fn test_client_info_serialization() {
        let client_info = ClientInfo {
            name: "test-client".to_string(),
            version: "2.0.0".to_string(),
        };

        let json = serde_json::to_string(&client_info).unwrap();
        let deserialized: ClientInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(client_info.name, deserialized.name);
        assert_eq!(client_info.version, deserialized.version);
    }
}