codeprism_mcp/
server.rs

1//! MCP Server implementation
2//!
3//! This module implements the main MCP server that handles the protocol lifecycle,
4//! request routing, and integration with repository components.
5
6use anyhow::Result;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use tokio::time::Duration;
12use tracing::{debug, info, warn};
13
14use crate::{
15    prompts::{GetPromptParams, ListPromptsParams, PromptManager},
16    protocol::{
17        CancellationParams, CancellationToken, ClientInfo, ClientOptimizations, ClientType,
18        InitializeParams, InitializeResult, JsonRpcError, JsonRpcNotification, JsonRpcRequest,
19        JsonRpcResponse, ServerInfo, VersionNegotiation, DEFAULT_PROTOCOL_VERSION,
20    },
21    resources::{ListResourcesParams, ReadResourceParams, ResourceManager},
22    tools::{CallToolParams, ListToolsParams, ToolRegistry},
23    transport::{StdioTransport, Transport, TransportMessage},
24    CodePrismMcpServer,
25};
26
/// Lifecycle state of the MCP server.
///
/// `Eq` is derived alongside `PartialEq` because equality between states is
/// total; this lets the type be used wherever a full equivalence is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerState {
    /// Server is not initialized yet.
    Uninitialized,
    /// Server is initialized and ready for normal operation.
    Ready,
    /// Server is shutting down.
    Shutdown,
}
37
38/// Registry for tracking active requests and their cancellation tokens
39#[derive(Debug, Default)]
40pub struct RequestRegistry {
41    /// Active requests mapped by their ID
42    active_requests: HashMap<String, CancellationToken>,
43}
44
45impl RequestRegistry {
46    /// Register a new request with its cancellation token
47    pub fn register(&mut self, request_id: String, token: CancellationToken) {
48        debug!("Registering request {}", request_id);
49        self.active_requests.insert(request_id, token);
50    }
51
52    /// Unregister a completed request
53    pub fn unregister(&mut self, request_id: &str) {
54        debug!("Unregistering request {}", request_id);
55        self.active_requests.remove(request_id);
56    }
57
58    /// Cancel a specific request
59    pub fn cancel_request(&mut self, request_id: &str) -> bool {
60        if let Some(token) = self.active_requests.get(request_id) {
61            info!("Cancelling request {}", request_id);
62            token.cancel();
63            true
64        } else {
65            warn!("Request {} not found for cancellation", request_id);
66            false
67        }
68    }
69
70    /// Get the cancellation token for a request
71    pub fn get_token(&self, request_id: &str) -> Option<&CancellationToken> {
72        self.active_requests.get(request_id)
73    }
74
75    /// Cancel all active requests (for shutdown)
76    pub fn cancel_all(&mut self) {
77        info!(
78            "Cancelling all {} active requests",
79            self.active_requests.len()
80        );
81        for (request_id, token) in &self.active_requests {
82            debug!("Cancelling request {}", request_id);
83            token.cancel();
84        }
85        self.active_requests.clear();
86    }
87}
88
89/// Main MCP Server implementation
90pub struct McpServer {
91    /// Current server state
92    state: ServerState,
93    /// MCP protocol version
94    #[allow(dead_code)] // TODO: Will be used for protocol compatibility checks
95    protocol_version: String,
96    /// Server information
97    server_info: ServerInfo,
98    /// Client information (set during initialization)
99    client_info: Option<ClientInfo>,
100    /// Client type and optimizations
101    client_type: Option<ClientType>,
102    client_optimizations: ClientOptimizations,
103    /// Version negotiation result
104    version_negotiation: Option<VersionNegotiation>,
105    /// Core CodePrism server instance
106    codeprism_server: Arc<RwLock<CodePrismMcpServer>>,
107    /// Resource manager
108    resource_manager: ResourceManager,
109    /// Tool registry
110    tool_registry: ToolRegistry,
111    /// Prompt manager
112    prompt_manager: PromptManager,
113    /// Request registry for cancellation tracking
114    request_registry: Arc<RwLock<RequestRegistry>>,
115    /// Default timeout for operations
116    default_timeout: Duration,
117}
118
119impl McpServer {
120    /// Create a new MCP server
121    pub fn new() -> Result<Self> {
122        let codeprism_server = Arc::new(RwLock::new(CodePrismMcpServer::new()?));
123
124        let resource_manager = ResourceManager::new(codeprism_server.clone());
125        let tool_registry = ToolRegistry::new(codeprism_server.clone());
126        let prompt_manager = PromptManager::new(codeprism_server.clone());
127
128        Ok(Self {
129            state: ServerState::Uninitialized,
130            protocol_version: DEFAULT_PROTOCOL_VERSION.to_string(),
131            server_info: ServerInfo {
132                name: "codeprism-mcp".to_string(),
133                version: "0.1.0".to_string(),
134            },
135            client_info: None,
136            client_type: None,
137            client_optimizations: ClientOptimizations::default(),
138            version_negotiation: None,
139            codeprism_server,
140            resource_manager,
141            tool_registry,
142            prompt_manager,
143            request_registry: Arc::new(RwLock::new(RequestRegistry::default())),
144            default_timeout: Duration::from_secs(300), // 5 minutes default timeout
145        })
146    }
147
148    /// Create a new MCP server with custom configuration
149    pub fn new_with_config(
150        memory_limit_mb: usize,
151        batch_size: usize,
152        max_file_size_mb: usize,
153        disable_memory_limit: bool,
154        exclude_dirs: Vec<String>,
155        include_extensions: Option<Vec<String>>,
156        dependency_mode: Option<String>,
157    ) -> Result<Self> {
158        let codeprism_server = Arc::new(RwLock::new(CodePrismMcpServer::new_with_config(
159            memory_limit_mb,
160            batch_size,
161            max_file_size_mb,
162            disable_memory_limit,
163            exclude_dirs,
164            include_extensions,
165            dependency_mode,
166        )?));
167
168        let resource_manager = ResourceManager::new(codeprism_server.clone());
169        let tool_registry = ToolRegistry::new(codeprism_server.clone());
170        let prompt_manager = PromptManager::new(codeprism_server.clone());
171
172        Ok(Self {
173            state: ServerState::Uninitialized,
174            protocol_version: DEFAULT_PROTOCOL_VERSION.to_string(),
175            server_info: ServerInfo {
176                name: "codeprism-mcp".to_string(),
177                version: "0.1.0".to_string(),
178            },
179            client_info: None,
180            client_type: None,
181            client_optimizations: ClientOptimizations::default(),
182            version_negotiation: None,
183            codeprism_server,
184            resource_manager,
185            tool_registry,
186            prompt_manager,
187            request_registry: Arc::new(RwLock::new(RequestRegistry::default())),
188            default_timeout: Duration::from_secs(300), // 5 minutes default timeout
189        })
190    }
191
192    /// Initialize with repository path
193    pub async fn initialize_with_repository<P: AsRef<std::path::Path>>(
194        &self,
195        path: P,
196    ) -> Result<()> {
197        let mut server = self.codeprism_server.write().await;
198        server.initialize_with_repository(path).await
199    }
200
201    /// Run the MCP server with stdio transport
202    pub async fn run_stdio(self) -> Result<()> {
203        info!("Starting CodePrism MCP server with stdio transport");
204
205        let mut transport = StdioTransport::new();
206        transport.start().await?;
207
208        self.run_with_transport(transport).await
209    }
210
211    /// Run the MCP server with a custom transport
212    pub async fn run_with_transport<T: Transport>(mut self, mut transport: T) -> Result<()> {
213        info!("Starting CodePrism MCP server");
214
215        loop {
216            match transport.receive().await? {
217                Some(message) => {
218                    if let Some(response) = self.handle_message(message).await? {
219                        transport.send(response).await?;
220                    }
221                }
222                None => {
223                    debug!("Transport closed, shutting down server");
224                    break;
225                }
226            }
227        }
228
229        // Cancel all active requests before shutdown
230        let mut registry = self.request_registry.write().await;
231        registry.cancel_all();
232        drop(registry);
233
234        transport.close().await?;
235        info!("Prism MCP server stopped");
236        Ok(())
237    }
238
239    /// Handle an incoming message
240    async fn handle_message(
241        &mut self,
242        message: TransportMessage,
243    ) -> Result<Option<TransportMessage>> {
244        match message {
245            TransportMessage::Request(request) => {
246                let response = self.handle_request(request).await;
247                Ok(Some(TransportMessage::Response(response)))
248            }
249            TransportMessage::Notification(notification) => {
250                self.handle_notification(notification).await?;
251                Ok(None) // Notifications don't get responses
252            }
253            TransportMessage::Response(_) => {
254                warn!("Received unexpected response message");
255                Ok(None)
256            }
257        }
258    }
259
260    /// Handle a JSON-RPC request with cancellation support
261    async fn handle_request(&mut self, request: JsonRpcRequest) -> JsonRpcResponse {
262        debug!(
263            "Handling request: method={}, id={:?}",
264            request.method, request.id
265        );
266
267        // Create cancellation token for this request
268        let token = CancellationToken::new(request.id.clone());
269        let request_id_str = request.id.to_string();
270
271        // Register the request
272        {
273            let mut registry = self.request_registry.write().await;
274            registry.register(request_id_str.clone(), token.clone());
275        }
276
277        // Handle the request with cancellation support
278        let result = self
279            .handle_request_with_cancellation(request.clone(), token.clone())
280            .await;
281
282        // Unregister the request
283        {
284            let mut registry = self.request_registry.write().await;
285            registry.unregister(&request_id_str);
286        }
287
288        match result {
289            Ok(result) => JsonRpcResponse::success(request.id, result),
290            Err(error) => JsonRpcResponse::error(request.id, error),
291        }
292    }
293
294    /// Handle a request with cancellation token
295    async fn handle_request_with_cancellation(
296        &mut self,
297        request: JsonRpcRequest,
298        token: CancellationToken,
299    ) -> Result<Value, JsonRpcError> {
300        let timeout = self.default_timeout;
301
302        let operation = async {
303            match request.method.as_str() {
304                "initialize" => self.handle_initialize(request.params).await,
305                "resources/list" => self.handle_resources_list(request.params).await,
306                "resources/read" => self.handle_resources_read(request.params).await,
307                "tools/list" => self.handle_tools_list(request.params).await,
308                "tools/call" => self.handle_tools_call(request.params, token.clone()).await,
309                "prompts/list" => self.handle_prompts_list(request.params).await,
310                "prompts/get" => self.handle_prompts_get(request.params).await,
311                _ => Err(JsonRpcError::method_not_found(&request.method)),
312            }
313        };
314
315        match token.with_timeout(timeout, operation).await {
316            Ok(result) => result,
317            Err(crate::protocol::CancellationError::Cancelled) => Err(JsonRpcError::new(
318                -32800,
319                "Request was cancelled".to_string(),
320                None,
321            )),
322            Err(crate::protocol::CancellationError::Timeout) => Err(JsonRpcError::new(
323                -32801,
324                "Request timed out".to_string(),
325                None,
326            )),
327        }
328    }
329
330    /// Handle a JSON-RPC notification
331    async fn handle_notification(&mut self, notification: JsonRpcNotification) -> Result<()> {
332        debug!("Handling notification: method={}", notification.method);
333
334        match notification.method.as_str() {
335            "initialized" => {
336                info!("Client reported initialization complete");
337                self.state = ServerState::Ready;
338            }
339            "notifications/cancelled" => {
340                if let Some(params) = notification.params {
341                    match serde_json::from_value::<CancellationParams>(params) {
342                        Ok(cancel_params) => {
343                            let request_id = cancel_params.id.to_string();
344                            let mut registry = self.request_registry.write().await;
345                            if registry.cancel_request(&request_id) {
346                                info!("Successfully cancelled request {}", request_id);
347                                if let Some(reason) = cancel_params.reason {
348                                    debug!("Cancellation reason: {}", reason);
349                                }
350                            }
351                        }
352                        Err(e) => {
353                            warn!("Invalid cancellation notification params: {}", e);
354                        }
355                    }
356                } else {
357                    warn!("Cancellation notification missing parameters");
358                }
359            }
360            _ => {
361                warn!("Unknown notification method: {}", notification.method);
362            }
363        }
364
365        Ok(())
366    }
367
368    /// Handle initialize request
369    async fn handle_initialize(&mut self, params: Option<Value>) -> Result<Value, JsonRpcError> {
370        let params: InitializeParams = params
371            .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
372            .try_into_type()
373            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
374
375        info!(
376            "Initializing MCP server with client: {} v{}",
377            params.client_info.name, params.client_info.version
378        );
379
380        // Perform version negotiation
381        let negotiation = VersionNegotiation::negotiate(&params.protocol_version);
382
383        // Check if the negotiation is acceptable
384        if !negotiation.is_acceptable() {
385            return Err(JsonRpcError::new(
386                -32600,
387                format!(
388                    "Incompatible protocol version. Client: {}, Server supports: {:?}",
389                    params.protocol_version, negotiation.server_versions
390                ),
391                None,
392            ));
393        }
394
395        // Log any warnings from version negotiation
396        for warning in &negotiation.warnings {
397            warn!("Version negotiation: {}", warning);
398        }
399
400        // Detect client type and set optimizations
401        let client_type = ClientType::from_client_info(&params.client_info);
402        let client_optimizations = client_type.get_optimizations();
403
404        info!(
405            "Client detected: {:?}, applying optimizations: max_response_size={}, timeout={}s",
406            client_type,
407            client_optimizations.max_response_size,
408            client_optimizations.preferred_timeout.as_secs()
409        );
410
411        // Store client info, type, and negotiation results
412        self.client_info = Some(params.client_info);
413        self.client_type = Some(client_type);
414        self.client_optimizations = client_optimizations.clone();
415        self.version_negotiation = Some(negotiation.clone());
416
417        // Update timeout based on client optimizations
418        self.default_timeout = client_optimizations.preferred_timeout;
419
420        // Create initialize result
421        let server = self.codeprism_server.read().await;
422        let result = InitializeResult {
423            protocol_version: negotiation.agreed_version,
424            capabilities: server.capabilities().clone(),
425            server_info: self.server_info.clone(),
426        };
427
428        serde_json::to_value(result)
429            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
430    }
431
432    /// Handle resources/list request
433    async fn handle_resources_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
434        let params = params
435            .map(serde_json::from_value)
436            .transpose()
437            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
438            .unwrap_or(ListResourcesParams { cursor: None });
439
440        let result = self
441            .resource_manager
442            .list_resources(params)
443            .await
444            .map_err(|e| JsonRpcError::internal_error(format!("Resource list error: {}", e)))?;
445
446        serde_json::to_value(result)
447            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
448    }
449
450    /// Handle resources/read request
451    async fn handle_resources_read(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
452        let params: ReadResourceParams = params
453            .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
454            .try_into_type()
455            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
456
457        let result = self
458            .resource_manager
459            .read_resource(params)
460            .await
461            .map_err(|e| JsonRpcError::internal_error(format!("Resource read error: {}", e)))?;
462
463        serde_json::to_value(result)
464            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
465    }
466
467    /// Handle tools/list request
468    async fn handle_tools_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
469        let params = params
470            .map(serde_json::from_value)
471            .transpose()
472            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
473            .unwrap_or(ListToolsParams { cursor: None });
474
475        let result = self
476            .tool_registry
477            .list_tools(params)
478            .await
479            .map_err(|e| JsonRpcError::internal_error(format!("Tool list error: {}", e)))?;
480
481        serde_json::to_value(result)
482            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
483    }
484
485    /// Handle tools/call request with cancellation support
486    async fn handle_tools_call(
487        &self,
488        params: Option<Value>,
489        _token: CancellationToken,
490    ) -> Result<Value, JsonRpcError> {
491        let params: CallToolParams = params
492            .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
493            .try_into_type()
494            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
495
496        // TODO: Pass cancellation token to tool registry in the future
497        let result = self
498            .tool_registry
499            .call_tool(params)
500            .await
501            .map_err(|e| JsonRpcError::internal_error(format!("Tool call error: {}", e)))?;
502
503        serde_json::to_value(result)
504            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
505    }
506
507    /// Handle prompts/list request
508    async fn handle_prompts_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
509        let params = params
510            .map(serde_json::from_value)
511            .transpose()
512            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
513            .unwrap_or(ListPromptsParams { cursor: None });
514
515        let result = self
516            .prompt_manager
517            .list_prompts(params)
518            .await
519            .map_err(|e| JsonRpcError::internal_error(format!("Prompt list error: {}", e)))?;
520
521        serde_json::to_value(result)
522            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
523    }
524
525    /// Handle prompts/get request
526    async fn handle_prompts_get(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
527        let params: GetPromptParams = params
528            .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
529            .try_into_type()
530            .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
531
532        let result = self
533            .prompt_manager
534            .get_prompt(params)
535            .await
536            .map_err(|e| JsonRpcError::internal_error(format!("Prompt get error: {}", e)))?;
537
538        serde_json::to_value(result)
539            .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
540    }
541
542    /// Get current server state
543    pub fn state(&self) -> ServerState {
544        self.state.clone()
545    }
546
547    /// Get server info
548    pub fn server_info(&self) -> &ServerInfo {
549        &self.server_info
550    }
551
552    /// Get client info (if initialized)
553    pub fn client_info(&self) -> Option<&ClientInfo> {
554        self.client_info.as_ref()
555    }
556}
557
558impl Default for McpServer {
559    fn default() -> Self {
560        Self::new().expect("Failed to create default MCP server")
561    }
562}
563
564// Helper trait for converting JSON values to types
565trait TryIntoType<T> {
566    fn try_into_type(self) -> Result<T, serde_json::Error>;
567}
568
569impl<T> TryIntoType<T> for Value
570where
571    T: serde::de::DeserializeOwned,
572{
573    fn try_into_type(self) -> Result<T, serde_json::Error> {
574        serde_json::from_value(self)
575    }
576}
577
578#[cfg(test)]
579mod tests {
580    use super::*;
581    use crate::protocol::ClientCapabilities;
582
583    #[tokio::test]
584    async fn test_mcp_server_creation() {
585        let server = McpServer::new().expect("Failed to create MCP server");
586        assert_eq!(server.state(), ServerState::Uninitialized);
587        assert_eq!(server.server_info().name, "codeprism-mcp");
588        assert_eq!(server.server_info().version, "0.1.0");
589    }
590
591    #[tokio::test]
592    async fn test_initialize_request() {
593        let mut server = McpServer::new().expect("Failed to create MCP server");
594
595        let params = InitializeParams {
596            protocol_version: "2024-11-05".to_string(),
597            capabilities: ClientCapabilities::default(),
598            client_info: ClientInfo {
599                name: "test-client".to_string(),
600                version: "1.0.0".to_string(),
601            },
602        };
603
604        let params_value = serde_json::to_value(params).unwrap();
605        let result = server.handle_initialize(Some(params_value)).await;
606
607        assert!(result.is_ok());
608        assert!(server.client_info().is_some());
609        assert_eq!(server.client_info().unwrap().name, "test-client");
610    }
611
612    #[test]
613    fn test_server_states() {
614        assert_eq!(ServerState::Uninitialized, ServerState::Uninitialized);
615        assert_ne!(ServerState::Uninitialized, ServerState::Ready);
616        assert_ne!(ServerState::Ready, ServerState::Shutdown);
617    }
618
619    async fn create_test_server_with_repository() -> McpServer {
620        use std::fs;
621        use tempfile::TempDir;
622
623        let temp_dir = TempDir::new().expect("Failed to create temp dir");
624        let repo_path = temp_dir.path();
625
626        // Create test files for server testing
627        fs::write(
628            repo_path.join("app.py"),
629            r#"
630"""Main application module."""
631
632import logging
633from typing import List, Optional, Dict, Any
634from dataclasses import dataclass
635
636@dataclass
637class Config:
638    """Application configuration."""
639    database_url: str
640    api_key: str
641    debug: bool = False
642
643class ApplicationService:
644    """Main application service."""
645    
646    def __init__(self, config: Config):
647        self.config = config
648        self.logger = logging.getLogger(__name__)
649        self._users: Dict[str, 'User'] = {}
650    
651    def create_user(self, username: str, email: str) -> 'User':
652        """Create a new user."""
653        if username in self._users:
654            raise ValueError(f"User {username} already exists")
655        
656        user = User(username=username, email=email)
657        self._users[username] = user
658        self.logger.info(f"Created user: {username}")
659        return user
660    
661    def get_user(self, username: str) -> Optional['User']:
662        """Get a user by username."""
663        return self._users.get(username)
664    
665    def list_users(self) -> List['User']:
666        """List all users."""
667        return list(self._users.values())
668    
669    def delete_user(self, username: str) -> bool:
670        """Delete a user."""
671        if username in self._users:
672            del self._users[username]
673            self.logger.info(f"Deleted user: {username}")
674            return True
675        return False
676
677class User:
678    """User model."""
679    
680    def __init__(self, username: str, email: str):
681        self.username = username
682        self.email = email
683        self.created_at = None  # Would be datetime in real app
684        self.is_active = True
685    
686    def deactivate(self) -> None:
687        """Deactivate the user."""
688        self.is_active = False
689    
690    def activate(self) -> None:
691        """Activate the user."""
692        self.is_active = True
693    
694    def to_dict(self) -> Dict[str, Any]:
695        """Convert user to dictionary."""
696        return {
697            'username': self.username,
698            'email': self.email,
699            'is_active': self.is_active
700        }
701
702def main():
703    """Main application entry point."""
704    config = Config(
705        database_url="postgresql://localhost/myapp",
706        api_key="secret-key"
707    )
708    
709    app = ApplicationService(config)
710    
711    # Create some sample users
712    app.create_user("alice", "alice@example.com")
713    app.create_user("bob", "bob@example.com")
714    
715    # List users
716    users = app.list_users()
717    print(f"Created {len(users)} users")
718
719if __name__ == "__main__":
720    main()
721"#,
722        )
723        .unwrap();
724
725        fs::write(
726            repo_path.join("utils.py"),
727            r#"
728"""Utility functions for the application."""
729
730import re
731import hashlib
732from typing import Optional, Union, List
733from datetime import datetime, timedelta
734
735# Constants
736EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
737PASSWORD_MIN_LENGTH = 8
738
739def validate_email(email: str) -> bool:
740    """Validate email address format."""
741    return bool(EMAIL_REGEX.match(email))
742
743def validate_password(password: str) -> bool:
744    """Validate password strength."""
745    if len(password) < PASSWORD_MIN_LENGTH:
746        return False
747    
748    has_upper = any(c.isupper() for c in password)
749    has_lower = any(c.islower() for c in password)
750    has_digit = any(c.isdigit() for c in password)
751    
752    return has_upper and has_lower and has_digit
753
754def hash_password(password: str, salt: Optional[str] = None) -> str:
755    """Hash a password with salt."""
756    if salt is None:
757        salt = "default_salt"  # In real app, use random salt
758    
759    combined = f"{password}{salt}"
760    return hashlib.sha256(combined.encode()).hexdigest()
761
762def generate_token(length: int = 32) -> str:
763    """Generate a random token."""
764    import secrets
765    return secrets.token_hex(length)
766
767class DateUtils:
768    """Utility class for date operations."""
769    
770    @staticmethod
771    def now() -> datetime:
772        """Get current datetime."""
773        return datetime.now()
774    
775    @staticmethod
776    def add_days(date: datetime, days: int) -> datetime:
777        """Add days to a date."""
778        return date + timedelta(days=days)
779    
780    @staticmethod
781    def format_date(date: datetime, format_str: str = "%Y-%m-%d") -> str:
782        """Format a date as string."""
783        return date.strftime(format_str)
784
785def cleanup_string(text: str) -> str:
786    """Clean up a string by removing extra whitespace."""
787    return re.sub(r'\s+', ' ', text.strip())
788
789def parse_config_value(value: str) -> Union[str, int, bool]:
790    """Parse a configuration value to appropriate type."""
791    # Try boolean
792    if value.lower() in ('true', 'false'):
793        return value.lower() == 'true'
794    
795    # Try integer
796    try:
797        return int(value)
798    except ValueError:
799        pass
800    
801    # Return as string
802    return value
803"#,
804        )
805        .unwrap();
806
807        let server = McpServer::new_with_config(
808            2048,  // memory_limit_mb
809            20,    // batch_size
810            5,     // max_file_size_mb
811            false, // disable_memory_limit
812            vec!["__pycache__".to_string(), ".pytest_cache".to_string()],
813            Some(vec!["py".to_string()]),
814            Some("exclude".to_string()),
815        )
816        .expect("Failed to create MCP server");
817
818        server
819            .initialize_with_repository(repo_path)
820            .await
821            .expect("Failed to initialize repository");
822
823        // Keep temp_dir alive
824        std::mem::forget(temp_dir);
825
826        server
827    }
828
829    #[tokio::test]
830    async fn test_server_with_repository_initialization() {
831        let server = create_test_server_with_repository().await;
832
833        // Server should be properly configured
834        assert_eq!(server.state(), ServerState::Uninitialized);
835
836        // Should have repository configured
837        let codeprism_server = server.codeprism_server.read().await;
838        assert!(codeprism_server.repository_path().is_some());
839    }
840
841    #[tokio::test]
842    async fn test_full_mcp_workflow() {
843        let mut server = create_test_server_with_repository().await;
844
845        // 1. Initialize the MCP server
846        let init_params = InitializeParams {
847            protocol_version: "2024-11-05".to_string(),
848            capabilities: ClientCapabilities::default(),
849            client_info: ClientInfo {
850                name: "test-client".to_string(),
851                version: "1.0.0".to_string(),
852            },
853        };
854
855        let init_result = server
856            .handle_initialize(Some(serde_json::to_value(init_params).unwrap()))
857            .await;
858        assert!(init_result.is_ok());
859
860        // Server should have client info
861        assert!(server.client_info().is_some());
862
863        // 2. Test resources/list
864        let resources_result = server.handle_resources_list(None).await;
865        assert!(resources_result.is_ok());
866
867        let resources_value = resources_result.unwrap();
868        let resources: crate::resources::ListResourcesResult =
869            serde_json::from_value(resources_value).unwrap();
870        assert!(!resources.resources.is_empty());
871
872        // Should have various resource types
873        let uris: Vec<String> = resources.resources.iter().map(|r| r.uri.clone()).collect();
874        assert!(uris.iter().any(|uri| uri == "codeprism://repository/stats"));
875        assert!(uris.iter().any(|uri| uri == "codeprism://graph/repository"));
876        assert!(uris.iter().any(|uri| uri.contains("app.py")));
877
878        // 3. Test resources/read
879        let read_params = crate::resources::ReadResourceParams {
880            uri: "codeprism://repository/stats".to_string(),
881        };
882        let read_result = server
883            .handle_resources_read(Some(serde_json::to_value(read_params).unwrap()))
884            .await;
885        assert!(read_result.is_ok());
886
887        // 4. Test tools/list
888        let tools_result = server.handle_tools_list(None).await;
889        assert!(tools_result.is_ok());
890
891        let tools_value = tools_result.unwrap();
892        let tools: crate::tools::ListToolsResult = serde_json::from_value(tools_value).unwrap();
893        assert_eq!(tools.tools.len(), 26); // All 26 tools should be available including the new JavaScript analysis tools (Phase 2.1)
894
895        // 5. Test tools/call with repository_stats
896        let tool_params = crate::tools::CallToolParams {
897            name: "repository_stats".to_string(),
898            arguments: Some(serde_json::json!({})),
899        };
900        let dummy_token = CancellationToken::new(serde_json::Value::String("test".to_string()));
901        let tool_result = server
902            .handle_tools_call(
903                Some(serde_json::to_value(tool_params).unwrap()),
904                dummy_token,
905            )
906            .await;
907        assert!(tool_result.is_ok());
908
909        // 6. Test prompts/list
910        let prompts_result = server.handle_prompts_list(None).await;
911        assert!(prompts_result.is_ok());
912
913        let prompts_value = prompts_result.unwrap();
914        let prompts: crate::prompts::ListPromptsResult =
915            serde_json::from_value(prompts_value).unwrap();
916        assert_eq!(prompts.prompts.len(), 16); // All 16 prompts should be available (original 8 + 8 new for large codebase understanding)
917
918        // 7. Test prompts/get
919        let prompt_params = crate::prompts::GetPromptParams {
920            name: "repository_overview".to_string(),
921            arguments: Some(serde_json::Map::from_iter([(
922                "focus_area".to_string(),
923                serde_json::Value::String("architecture".to_string()),
924            )])),
925        };
926        let prompt_result = server
927            .handle_prompts_get(Some(serde_json::to_value(prompt_params).unwrap()))
928            .await;
929        assert!(prompt_result.is_ok());
930    }
931
932    #[tokio::test]
933    async fn test_request_handling_errors() {
934        let mut server = McpServer::new().expect("Failed to create MCP server");
935
936        // Test invalid method
937        let invalid_request = JsonRpcRequest {
938            jsonrpc: "2.0".to_string(),
939            id: serde_json::Value::Number(1.into()),
940            method: "invalid_method".to_string(),
941            params: None,
942        };
943
944        let response = server.handle_request(invalid_request).await;
945        assert!(response.error.is_some());
946        assert_eq!(response.error.unwrap().code, -32601); // Method not found
947
948        // Test missing required parameters
949        let missing_params_request = JsonRpcRequest {
950            jsonrpc: "2.0".to_string(),
951            id: serde_json::Value::Number(2.into()),
952            method: "resources/read".to_string(),
953            params: None, // Missing required uri parameter
954        };
955
956        let response = server.handle_request(missing_params_request).await;
957        assert!(response.error.is_some());
958        assert_eq!(response.error.unwrap().code, -32602); // Invalid params
959    }
960
961    #[tokio::test]
962    async fn test_notification_handling() {
963        let mut server = McpServer::new().expect("Failed to create MCP server");
964
965        // Test initialized notification
966        let initialized_notification = JsonRpcNotification {
967            jsonrpc: "2.0".to_string(),
968            method: "initialized".to_string(),
969            params: None,
970        };
971
972        assert_eq!(server.state(), ServerState::Uninitialized);
973
974        let result = server.handle_notification(initialized_notification).await;
975        assert!(result.is_ok());
976        assert_eq!(server.state(), ServerState::Ready);
977
978        // Test unknown notification
979        let unknown_notification = JsonRpcNotification {
980            jsonrpc: "2.0".to_string(),
981            method: "unknown_notification".to_string(),
982            params: None,
983        };
984
985        let result = server.handle_notification(unknown_notification).await;
986        assert!(result.is_ok()); // Should not fail, just log warning
987    }
988
989    #[tokio::test]
990    async fn test_message_handling() {
991        let mut server = McpServer::new().expect("Failed to create MCP server");
992
993        // Test request message handling
994        let request_message = crate::transport::TransportMessage::Request(JsonRpcRequest {
995            jsonrpc: "2.0".to_string(),
996            id: serde_json::Value::Number(1.into()),
997            method: "initialize".to_string(),
998            params: Some(serde_json::json!({
999                "protocolVersion": "2024-11-05",
1000                "capabilities": {},
1001                "clientInfo": {
1002                    "name": "test-client",
1003                    "version": "1.0.0"
1004                }
1005            })),
1006        });
1007
1008        let response = server.handle_message(request_message).await;
1009        assert!(response.is_ok());
1010        assert!(response.unwrap().is_some()); // Should return a response
1011
1012        // Test notification message handling
1013        let notification_message =
1014            crate::transport::TransportMessage::Notification(JsonRpcNotification {
1015                jsonrpc: "2.0".to_string(),
1016                method: "initialized".to_string(),
1017                params: None,
1018            });
1019
1020        let response = server.handle_message(notification_message).await;
1021        assert!(response.is_ok());
1022        assert!(response.unwrap().is_none()); // Notifications don't return responses
1023    }
1024
1025    #[tokio::test]
1026    async fn test_server_capabilities_validation() {
1027        let server = create_test_server_with_repository().await;
1028        let codeprism_server = server.codeprism_server.read().await;
1029        let capabilities = codeprism_server.capabilities();
1030
1031        // Verify all required capabilities are present
1032        assert!(capabilities.resources.is_some());
1033        assert!(capabilities.tools.is_some());
1034        assert!(capabilities.prompts.is_some());
1035
1036        // Verify resource capabilities
1037        let resource_caps = capabilities.resources.as_ref().unwrap();
1038        assert_eq!(resource_caps.subscribe, Some(true));
1039        assert_eq!(resource_caps.list_changed, Some(true));
1040
1041        // Verify tool capabilities
1042        let tool_caps = capabilities.tools.as_ref().unwrap();
1043        assert_eq!(tool_caps.list_changed, Some(true));
1044
1045        // Verify prompt capabilities
1046        let prompt_caps = capabilities.prompts.as_ref().unwrap();
1047        assert_eq!(prompt_caps.list_changed, Some(false));
1048    }
1049
1050    #[tokio::test]
1051    async fn test_concurrent_requests() {
1052        use std::sync::Arc;
1053        use tokio::sync::RwLock;
1054
1055        let server = Arc::new(RwLock::new(create_test_server_with_repository().await));
1056
1057        // Initialize the server first
1058        {
1059            let mut server_lock = server.write().await;
1060            let init_params = InitializeParams {
1061                protocol_version: "2024-11-05".to_string(),
1062                capabilities: ClientCapabilities::default(),
1063                client_info: ClientInfo {
1064                    name: "test-client".to_string(),
1065                    version: "1.0.0".to_string(),
1066                },
1067            };
1068
1069            server_lock
1070                .handle_initialize(Some(serde_json::to_value(init_params).unwrap()))
1071                .await
1072                .unwrap();
1073        }
1074
1075        // Run multiple concurrent requests
1076        let mut handles = Vec::new();
1077
1078        for i in 0..5 {
1079            let server_clone = server.clone();
1080            let handle = tokio::spawn(async move {
1081                let server_lock = server_clone.write().await;
1082
1083                // Test resources/list
1084                let resources_result = server_lock.handle_resources_list(None).await;
1085                assert!(resources_result.is_ok());
1086
1087                // Test tools/list
1088                let tools_result = server_lock.handle_tools_list(None).await;
1089                assert!(tools_result.is_ok());
1090
1091                i // Return the task number
1092            });
1093
1094            handles.push(handle);
1095        }
1096
1097        // Wait for all tasks to complete
1098        for handle in handles {
1099            let result = handle.await;
1100            assert!(result.is_ok());
1101        }
1102    }
1103
1104    #[test]
1105    fn test_server_info_serialization() {
1106        let server_info = ServerInfo {
1107            name: "test-server".to_string(),
1108            version: "1.0.0".to_string(),
1109        };
1110
1111        let json = serde_json::to_string(&server_info).unwrap();
1112        let deserialized: ServerInfo = serde_json::from_str(&json).unwrap();
1113
1114        assert_eq!(server_info.name, deserialized.name);
1115        assert_eq!(server_info.version, deserialized.version);
1116    }
1117
1118    #[test]
1119    fn test_client_info_serialization() {
1120        let client_info = ClientInfo {
1121            name: "test-client".to_string(),
1122            version: "2.0.0".to_string(),
1123        };
1124
1125        let json = serde_json::to_string(&client_info).unwrap();
1126        let deserialized: ClientInfo = serde_json::from_str(&json).unwrap();
1127
1128        assert_eq!(client_info.name, deserialized.name);
1129        assert_eq!(client_info.version, deserialized.version);
1130    }
1131}