1use anyhow::Result;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use tokio::time::Duration;
12use tracing::{debug, info, warn};
13
14use crate::{
15 prompts::{GetPromptParams, ListPromptsParams, PromptManager},
16 protocol::{
17 CancellationParams, CancellationToken, ClientInfo, ClientOptimizations, ClientType,
18 InitializeParams, InitializeResult, JsonRpcError, JsonRpcNotification, JsonRpcRequest,
19 JsonRpcResponse, ServerInfo, VersionNegotiation, DEFAULT_PROTOCOL_VERSION,
20 },
21 resources::{ListResourcesParams, ReadResourceParams, ResourceManager},
22 tools::{CallToolParams, ListToolsParams, ToolRegistry},
23 transport::{StdioTransport, Transport, TransportMessage},
24 CodePrismMcpServer,
25};
26
/// Lifecycle state of the MCP session handled by the server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerState {
    /// Initial state: the client has not yet reported that initialization is complete.
    Uninitialized,
    /// The client has reported initialization complete.
    Ready,
    /// Shutdown state. NOTE(review): nothing in this module assigns it yet.
    Shutdown,
}
37
38#[derive(Debug, Default)]
40pub struct RequestRegistry {
41 active_requests: HashMap<String, CancellationToken>,
43}
44
45impl RequestRegistry {
46 pub fn register(&mut self, request_id: String, token: CancellationToken) {
48 debug!("Registering request {}", request_id);
49 self.active_requests.insert(request_id, token);
50 }
51
52 pub fn unregister(&mut self, request_id: &str) {
54 debug!("Unregistering request {}", request_id);
55 self.active_requests.remove(request_id);
56 }
57
58 pub fn cancel_request(&mut self, request_id: &str) -> bool {
60 if let Some(token) = self.active_requests.get(request_id) {
61 info!("Cancelling request {}", request_id);
62 token.cancel();
63 true
64 } else {
65 warn!("Request {} not found for cancellation", request_id);
66 false
67 }
68 }
69
70 pub fn get_token(&self, request_id: &str) -> Option<&CancellationToken> {
72 self.active_requests.get(request_id)
73 }
74
75 pub fn cancel_all(&mut self) {
77 info!(
78 "Cancelling all {} active requests",
79 self.active_requests.len()
80 );
81 for (request_id, token) in &self.active_requests {
82 debug!("Cancelling request {}", request_id);
83 token.cancel();
84 }
85 self.active_requests.clear();
86 }
87}
88
89pub struct McpServer {
91 state: ServerState,
93 #[allow(dead_code)] protocol_version: String,
96 server_info: ServerInfo,
98 client_info: Option<ClientInfo>,
100 client_type: Option<ClientType>,
102 client_optimizations: ClientOptimizations,
103 version_negotiation: Option<VersionNegotiation>,
105 codeprism_server: Arc<RwLock<CodePrismMcpServer>>,
107 resource_manager: ResourceManager,
109 tool_registry: ToolRegistry,
111 prompt_manager: PromptManager,
113 request_registry: Arc<RwLock<RequestRegistry>>,
115 default_timeout: Duration,
117}
118
119impl McpServer {
120 pub fn new() -> Result<Self> {
122 let codeprism_server = Arc::new(RwLock::new(CodePrismMcpServer::new()?));
123
124 let resource_manager = ResourceManager::new(codeprism_server.clone());
125 let tool_registry = ToolRegistry::new(codeprism_server.clone());
126 let prompt_manager = PromptManager::new(codeprism_server.clone());
127
128 Ok(Self {
129 state: ServerState::Uninitialized,
130 protocol_version: DEFAULT_PROTOCOL_VERSION.to_string(),
131 server_info: ServerInfo {
132 name: "codeprism-mcp".to_string(),
133 version: "0.1.0".to_string(),
134 },
135 client_info: None,
136 client_type: None,
137 client_optimizations: ClientOptimizations::default(),
138 version_negotiation: None,
139 codeprism_server,
140 resource_manager,
141 tool_registry,
142 prompt_manager,
143 request_registry: Arc::new(RwLock::new(RequestRegistry::default())),
144 default_timeout: Duration::from_secs(300), })
146 }
147
148 pub fn new_with_config(
150 memory_limit_mb: usize,
151 batch_size: usize,
152 max_file_size_mb: usize,
153 disable_memory_limit: bool,
154 exclude_dirs: Vec<String>,
155 include_extensions: Option<Vec<String>>,
156 dependency_mode: Option<String>,
157 ) -> Result<Self> {
158 let codeprism_server = Arc::new(RwLock::new(CodePrismMcpServer::new_with_config(
159 memory_limit_mb,
160 batch_size,
161 max_file_size_mb,
162 disable_memory_limit,
163 exclude_dirs,
164 include_extensions,
165 dependency_mode,
166 )?));
167
168 let resource_manager = ResourceManager::new(codeprism_server.clone());
169 let tool_registry = ToolRegistry::new(codeprism_server.clone());
170 let prompt_manager = PromptManager::new(codeprism_server.clone());
171
172 Ok(Self {
173 state: ServerState::Uninitialized,
174 protocol_version: DEFAULT_PROTOCOL_VERSION.to_string(),
175 server_info: ServerInfo {
176 name: "codeprism-mcp".to_string(),
177 version: "0.1.0".to_string(),
178 },
179 client_info: None,
180 client_type: None,
181 client_optimizations: ClientOptimizations::default(),
182 version_negotiation: None,
183 codeprism_server,
184 resource_manager,
185 tool_registry,
186 prompt_manager,
187 request_registry: Arc::new(RwLock::new(RequestRegistry::default())),
188 default_timeout: Duration::from_secs(300), })
190 }
191
192 pub async fn initialize_with_repository<P: AsRef<std::path::Path>>(
194 &self,
195 path: P,
196 ) -> Result<()> {
197 let mut server = self.codeprism_server.write().await;
198 server.initialize_with_repository(path).await
199 }
200
201 pub async fn run_stdio(self) -> Result<()> {
203 info!("Starting CodePrism MCP server with stdio transport");
204
205 let mut transport = StdioTransport::new();
206 transport.start().await?;
207
208 self.run_with_transport(transport).await
209 }
210
211 pub async fn run_with_transport<T: Transport>(mut self, mut transport: T) -> Result<()> {
213 info!("Starting CodePrism MCP server");
214
215 loop {
216 match transport.receive().await? {
217 Some(message) => {
218 if let Some(response) = self.handle_message(message).await? {
219 transport.send(response).await?;
220 }
221 }
222 None => {
223 debug!("Transport closed, shutting down server");
224 break;
225 }
226 }
227 }
228
229 let mut registry = self.request_registry.write().await;
231 registry.cancel_all();
232 drop(registry);
233
234 transport.close().await?;
235 info!("Prism MCP server stopped");
236 Ok(())
237 }
238
239 async fn handle_message(
241 &mut self,
242 message: TransportMessage,
243 ) -> Result<Option<TransportMessage>> {
244 match message {
245 TransportMessage::Request(request) => {
246 let response = self.handle_request(request).await;
247 Ok(Some(TransportMessage::Response(response)))
248 }
249 TransportMessage::Notification(notification) => {
250 self.handle_notification(notification).await?;
251 Ok(None) }
253 TransportMessage::Response(_) => {
254 warn!("Received unexpected response message");
255 Ok(None)
256 }
257 }
258 }
259
260 async fn handle_request(&mut self, request: JsonRpcRequest) -> JsonRpcResponse {
262 debug!(
263 "Handling request: method={}, id={:?}",
264 request.method, request.id
265 );
266
267 let token = CancellationToken::new(request.id.clone());
269 let request_id_str = request.id.to_string();
270
271 {
273 let mut registry = self.request_registry.write().await;
274 registry.register(request_id_str.clone(), token.clone());
275 }
276
277 let result = self
279 .handle_request_with_cancellation(request.clone(), token.clone())
280 .await;
281
282 {
284 let mut registry = self.request_registry.write().await;
285 registry.unregister(&request_id_str);
286 }
287
288 match result {
289 Ok(result) => JsonRpcResponse::success(request.id, result),
290 Err(error) => JsonRpcResponse::error(request.id, error),
291 }
292 }
293
294 async fn handle_request_with_cancellation(
296 &mut self,
297 request: JsonRpcRequest,
298 token: CancellationToken,
299 ) -> Result<Value, JsonRpcError> {
300 let timeout = self.default_timeout;
301
302 let operation = async {
303 match request.method.as_str() {
304 "initialize" => self.handle_initialize(request.params).await,
305 "resources/list" => self.handle_resources_list(request.params).await,
306 "resources/read" => self.handle_resources_read(request.params).await,
307 "tools/list" => self.handle_tools_list(request.params).await,
308 "tools/call" => self.handle_tools_call(request.params, token.clone()).await,
309 "prompts/list" => self.handle_prompts_list(request.params).await,
310 "prompts/get" => self.handle_prompts_get(request.params).await,
311 _ => Err(JsonRpcError::method_not_found(&request.method)),
312 }
313 };
314
315 match token.with_timeout(timeout, operation).await {
316 Ok(result) => result,
317 Err(crate::protocol::CancellationError::Cancelled) => Err(JsonRpcError::new(
318 -32800,
319 "Request was cancelled".to_string(),
320 None,
321 )),
322 Err(crate::protocol::CancellationError::Timeout) => Err(JsonRpcError::new(
323 -32801,
324 "Request timed out".to_string(),
325 None,
326 )),
327 }
328 }
329
330 async fn handle_notification(&mut self, notification: JsonRpcNotification) -> Result<()> {
332 debug!("Handling notification: method={}", notification.method);
333
334 match notification.method.as_str() {
335 "initialized" => {
336 info!("Client reported initialization complete");
337 self.state = ServerState::Ready;
338 }
339 "notifications/cancelled" => {
340 if let Some(params) = notification.params {
341 match serde_json::from_value::<CancellationParams>(params) {
342 Ok(cancel_params) => {
343 let request_id = cancel_params.id.to_string();
344 let mut registry = self.request_registry.write().await;
345 if registry.cancel_request(&request_id) {
346 info!("Successfully cancelled request {}", request_id);
347 if let Some(reason) = cancel_params.reason {
348 debug!("Cancellation reason: {}", reason);
349 }
350 }
351 }
352 Err(e) => {
353 warn!("Invalid cancellation notification params: {}", e);
354 }
355 }
356 } else {
357 warn!("Cancellation notification missing parameters");
358 }
359 }
360 _ => {
361 warn!("Unknown notification method: {}", notification.method);
362 }
363 }
364
365 Ok(())
366 }
367
368 async fn handle_initialize(&mut self, params: Option<Value>) -> Result<Value, JsonRpcError> {
370 let params: InitializeParams = params
371 .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
372 .try_into_type()
373 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
374
375 info!(
376 "Initializing MCP server with client: {} v{}",
377 params.client_info.name, params.client_info.version
378 );
379
380 let negotiation = VersionNegotiation::negotiate(¶ms.protocol_version);
382
383 if !negotiation.is_acceptable() {
385 return Err(JsonRpcError::new(
386 -32600,
387 format!(
388 "Incompatible protocol version. Client: {}, Server supports: {:?}",
389 params.protocol_version, negotiation.server_versions
390 ),
391 None,
392 ));
393 }
394
395 for warning in &negotiation.warnings {
397 warn!("Version negotiation: {}", warning);
398 }
399
400 let client_type = ClientType::from_client_info(¶ms.client_info);
402 let client_optimizations = client_type.get_optimizations();
403
404 info!(
405 "Client detected: {:?}, applying optimizations: max_response_size={}, timeout={}s",
406 client_type,
407 client_optimizations.max_response_size,
408 client_optimizations.preferred_timeout.as_secs()
409 );
410
411 self.client_info = Some(params.client_info);
413 self.client_type = Some(client_type);
414 self.client_optimizations = client_optimizations.clone();
415 self.version_negotiation = Some(negotiation.clone());
416
417 self.default_timeout = client_optimizations.preferred_timeout;
419
420 let server = self.codeprism_server.read().await;
422 let result = InitializeResult {
423 protocol_version: negotiation.agreed_version,
424 capabilities: server.capabilities().clone(),
425 server_info: self.server_info.clone(),
426 };
427
428 serde_json::to_value(result)
429 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
430 }
431
432 async fn handle_resources_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
434 let params = params
435 .map(serde_json::from_value)
436 .transpose()
437 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
438 .unwrap_or(ListResourcesParams { cursor: None });
439
440 let result = self
441 .resource_manager
442 .list_resources(params)
443 .await
444 .map_err(|e| JsonRpcError::internal_error(format!("Resource list error: {}", e)))?;
445
446 serde_json::to_value(result)
447 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
448 }
449
450 async fn handle_resources_read(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
452 let params: ReadResourceParams = params
453 .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
454 .try_into_type()
455 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
456
457 let result = self
458 .resource_manager
459 .read_resource(params)
460 .await
461 .map_err(|e| JsonRpcError::internal_error(format!("Resource read error: {}", e)))?;
462
463 serde_json::to_value(result)
464 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
465 }
466
467 async fn handle_tools_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
469 let params = params
470 .map(serde_json::from_value)
471 .transpose()
472 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
473 .unwrap_or(ListToolsParams { cursor: None });
474
475 let result = self
476 .tool_registry
477 .list_tools(params)
478 .await
479 .map_err(|e| JsonRpcError::internal_error(format!("Tool list error: {}", e)))?;
480
481 serde_json::to_value(result)
482 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
483 }
484
485 async fn handle_tools_call(
487 &self,
488 params: Option<Value>,
489 _token: CancellationToken,
490 ) -> Result<Value, JsonRpcError> {
491 let params: CallToolParams = params
492 .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
493 .try_into_type()
494 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
495
496 let result = self
498 .tool_registry
499 .call_tool(params)
500 .await
501 .map_err(|e| JsonRpcError::internal_error(format!("Tool call error: {}", e)))?;
502
503 serde_json::to_value(result)
504 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
505 }
506
507 async fn handle_prompts_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
509 let params = params
510 .map(serde_json::from_value)
511 .transpose()
512 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
513 .unwrap_or(ListPromptsParams { cursor: None });
514
515 let result = self
516 .prompt_manager
517 .list_prompts(params)
518 .await
519 .map_err(|e| JsonRpcError::internal_error(format!("Prompt list error: {}", e)))?;
520
521 serde_json::to_value(result)
522 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
523 }
524
525 async fn handle_prompts_get(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
527 let params: GetPromptParams = params
528 .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
529 .try_into_type()
530 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
531
532 let result = self
533 .prompt_manager
534 .get_prompt(params)
535 .await
536 .map_err(|e| JsonRpcError::internal_error(format!("Prompt get error: {}", e)))?;
537
538 serde_json::to_value(result)
539 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
540 }
541
542 pub fn state(&self) -> ServerState {
544 self.state.clone()
545 }
546
547 pub fn server_info(&self) -> &ServerInfo {
549 &self.server_info
550 }
551
552 pub fn client_info(&self) -> Option<&ClientInfo> {
554 self.client_info.as_ref()
555 }
556}
557
558impl Default for McpServer {
559 fn default() -> Self {
560 Self::new().expect("Failed to create default MCP server")
561 }
562}
563
564trait TryIntoType<T> {
566 fn try_into_type(self) -> Result<T, serde_json::Error>;
567}
568
569impl<T> TryIntoType<T> for Value
570where
571 T: serde::de::DeserializeOwned,
572{
573 fn try_into_type(self) -> Result<T, serde_json::Error> {
574 serde_json::from_value(self)
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ClientCapabilities;

    /// `initialize` params for the client identity used throughout these tests.
    fn test_initialize_params() -> InitializeParams {
        InitializeParams {
            protocol_version: "2024-11-05".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: ClientInfo {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
            },
        }
    }

    #[tokio::test]
    async fn test_mcp_server_creation() {
        let server = McpServer::new().expect("Failed to create MCP server");
        assert_eq!(server.state(), ServerState::Uninitialized);
        assert_eq!(server.server_info().name, "codeprism-mcp");
        assert_eq!(server.server_info().version, "0.1.0");
    }

    #[tokio::test]
    async fn test_initialize_request() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        let params_value = serde_json::to_value(test_initialize_params()).unwrap();
        let result = server.handle_initialize(Some(params_value)).await;

        assert!(result.is_ok());
        assert!(server.client_info().is_some());
        assert_eq!(server.client_info().unwrap().name, "test-client");
    }

    #[test]
    fn test_server_states() {
        assert_eq!(ServerState::Uninitialized, ServerState::Uninitialized);
        assert_ne!(ServerState::Uninitialized, ServerState::Ready);
        assert_ne!(ServerState::Ready, ServerState::Shutdown);
    }

    /// Creates a server indexed over a temporary two-file Python repository.
    ///
    /// The returned `TempDir` owns the repository on disk and deletes it when dropped,
    /// so callers must keep it alive (bind it to a named variable, not `_`) for as
    /// long as they use the server.
    async fn create_test_server_with_repository() -> (McpServer, tempfile::TempDir) {
        use std::fs;
        use tempfile::TempDir;

        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let repo_path = temp_dir.path();

        fs::write(
            repo_path.join("app.py"),
            r#"
"""Main application module."""

import logging
from typing import List, Optional, Dict, Any
from dataclasses import dataclass

@dataclass
class Config:
    """Application configuration."""
    database_url: str
    api_key: str
    debug: bool = False

class ApplicationService:
    """Main application service."""

    def __init__(self, config: Config):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self._users: Dict[str, 'User'] = {}

    def create_user(self, username: str, email: str) -> 'User':
        """Create a new user."""
        if username in self._users:
            raise ValueError(f"User {username} already exists")

        user = User(username=username, email=email)
        self._users[username] = user
        self.logger.info(f"Created user: {username}")
        return user

    def get_user(self, username: str) -> Optional['User']:
        """Get a user by username."""
        return self._users.get(username)

    def list_users(self) -> List['User']:
        """List all users."""
        return list(self._users.values())

    def delete_user(self, username: str) -> bool:
        """Delete a user."""
        if username in self._users:
            del self._users[username]
            self.logger.info(f"Deleted user: {username}")
            return True
        return False

class User:
    """User model."""

    def __init__(self, username: str, email: str):
        self.username = username
        self.email = email
        self.created_at = None # Would be datetime in real app
        self.is_active = True

    def deactivate(self) -> None:
        """Deactivate the user."""
        self.is_active = False

    def activate(self) -> None:
        """Activate the user."""
        self.is_active = True

    def to_dict(self) -> Dict[str, Any]:
        """Convert user to dictionary."""
        return {
            'username': self.username,
            'email': self.email,
            'is_active': self.is_active
        }

def main():
    """Main application entry point."""
    config = Config(
        database_url="postgresql://localhost/myapp",
        api_key="secret-key"
    )

    app = ApplicationService(config)

    # Create some sample users
    app.create_user("alice", "alice@example.com")
    app.create_user("bob", "bob@example.com")

    # List users
    users = app.list_users()
    print(f"Created {len(users)} users")

if __name__ == "__main__":
    main()
"#,
        )
        .unwrap();

        fs::write(
            repo_path.join("utils.py"),
            r#"
"""Utility functions for the application."""

import re
import hashlib
from typing import Optional, Union, List
from datetime import datetime, timedelta

# Constants
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
PASSWORD_MIN_LENGTH = 8

def validate_email(email: str) -> bool:
    """Validate email address format."""
    return bool(EMAIL_REGEX.match(email))

def validate_password(password: str) -> bool:
    """Validate password strength."""
    if len(password) < PASSWORD_MIN_LENGTH:
        return False

    has_upper = any(c.isupper() for c in password)
    has_lower = any(c.islower() for c in password)
    has_digit = any(c.isdigit() for c in password)

    return has_upper and has_lower and has_digit

def hash_password(password: str, salt: Optional[str] = None) -> str:
    """Hash a password with salt."""
    if salt is None:
        salt = "default_salt" # In real app, use random salt

    combined = f"{password}{salt}"
    return hashlib.sha256(combined.encode()).hexdigest()

def generate_token(length: int = 32) -> str:
    """Generate a random token."""
    import secrets
    return secrets.token_hex(length)

class DateUtils:
    """Utility class for date operations."""

    @staticmethod
    def now() -> datetime:
        """Get current datetime."""
        return datetime.now()

    @staticmethod
    def add_days(date: datetime, days: int) -> datetime:
        """Add days to a date."""
        return date + timedelta(days=days)

    @staticmethod
    def format_date(date: datetime, format_str: str = "%Y-%m-%d") -> str:
        """Format a date as string."""
        return date.strftime(format_str)

def cleanup_string(text: str) -> str:
    """Clean up a string by removing extra whitespace."""
    return re.sub(r'\s+', ' ', text.strip())

def parse_config_value(value: str) -> Union[str, int, bool]:
    """Parse a configuration value to appropriate type."""
    # Try boolean
    if value.lower() in ('true', 'false'):
        return value.lower() == 'true'

    # Try integer
    try:
        return int(value)
    except ValueError:
        pass

    # Return as string
    return value
"#,
        )
        .unwrap();

        let server = McpServer::new_with_config(
            2048,
            20,
            5,
            false,
            vec!["__pycache__".to_string(), ".pytest_cache".to_string()],
            Some(vec!["py".to_string()]),
            Some("exclude".to_string()),
        )
        .expect("Failed to create MCP server");

        server
            .initialize_with_repository(repo_path)
            .await
            .expect("Failed to initialize repository");

        // Hand the directory guard to the caller instead of leaking it with
        // `mem::forget`, so the files are removed when the test finishes.
        (server, temp_dir)
    }

    #[tokio::test]
    async fn test_server_with_repository_initialization() {
        let (server, _repo_dir) = create_test_server_with_repository().await;

        // Indexing a repository does not by itself complete the MCP handshake.
        assert_eq!(server.state(), ServerState::Uninitialized);

        let codeprism_server = server.codeprism_server.read().await;
        assert!(codeprism_server.repository_path().is_some());
    }

    #[tokio::test]
    async fn test_full_mcp_workflow() {
        let (mut server, _repo_dir) = create_test_server_with_repository().await;

        let init_result = server
            .handle_initialize(Some(serde_json::to_value(test_initialize_params()).unwrap()))
            .await;
        assert!(init_result.is_ok());
        assert!(server.client_info().is_some());

        // Resources
        let resources_result = server.handle_resources_list(None).await;
        assert!(resources_result.is_ok());

        let resources: crate::resources::ListResourcesResult =
            serde_json::from_value(resources_result.unwrap()).unwrap();
        assert!(!resources.resources.is_empty());

        let uris: Vec<String> = resources.resources.iter().map(|r| r.uri.clone()).collect();
        assert!(uris.iter().any(|uri| uri == "codeprism://repository/stats"));
        assert!(uris.iter().any(|uri| uri == "codeprism://graph/repository"));
        assert!(uris.iter().any(|uri| uri.contains("app.py")));

        let read_params = crate::resources::ReadResourceParams {
            uri: "codeprism://repository/stats".to_string(),
        };
        let read_result = server
            .handle_resources_read(Some(serde_json::to_value(read_params).unwrap()))
            .await;
        assert!(read_result.is_ok());

        // Tools
        let tools_result = server.handle_tools_list(None).await;
        assert!(tools_result.is_ok());

        let tools: crate::tools::ListToolsResult =
            serde_json::from_value(tools_result.unwrap()).unwrap();
        assert_eq!(tools.tools.len(), 26);

        let tool_params = crate::tools::CallToolParams {
            name: "repository_stats".to_string(),
            arguments: Some(serde_json::json!({})),
        };
        let dummy_token = CancellationToken::new(serde_json::Value::String("test".to_string()));
        let tool_result = server
            .handle_tools_call(
                Some(serde_json::to_value(tool_params).unwrap()),
                dummy_token,
            )
            .await;
        assert!(tool_result.is_ok());

        // Prompts
        let prompts_result = server.handle_prompts_list(None).await;
        assert!(prompts_result.is_ok());

        let prompts: crate::prompts::ListPromptsResult =
            serde_json::from_value(prompts_result.unwrap()).unwrap();
        assert_eq!(prompts.prompts.len(), 16);

        let prompt_params = crate::prompts::GetPromptParams {
            name: "repository_overview".to_string(),
            arguments: Some(serde_json::Map::from_iter([(
                "focus_area".to_string(),
                serde_json::Value::String("architecture".to_string()),
            )])),
        };
        let prompt_result = server
            .handle_prompts_get(Some(serde_json::to_value(prompt_params).unwrap()))
            .await;
        assert!(prompt_result.is_ok());
    }

    #[tokio::test]
    async fn test_request_handling_errors() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        // Unknown method -> method not found.
        let invalid_request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: serde_json::Value::Number(1.into()),
            method: "invalid_method".to_string(),
            params: None,
        };

        let response = server.handle_request(invalid_request).await;
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, -32601);

        // Required params missing -> invalid params.
        let missing_params_request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: serde_json::Value::Number(2.into()),
            method: "resources/read".to_string(),
            params: None,
        };

        let response = server.handle_request(missing_params_request).await;
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, -32602);
    }

    #[tokio::test]
    async fn test_notification_handling() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        let initialized_notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "initialized".to_string(),
            params: None,
        };

        assert_eq!(server.state(), ServerState::Uninitialized);

        let result = server.handle_notification(initialized_notification).await;
        assert!(result.is_ok());
        assert_eq!(server.state(), ServerState::Ready);

        // Unknown notifications are tolerated, not errors.
        let unknown_notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "unknown_notification".to_string(),
            params: None,
        };

        let result = server.handle_notification(unknown_notification).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_message_handling() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        let request_message = crate::transport::TransportMessage::Request(JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: serde_json::Value::Number(1.into()),
            method: "initialize".to_string(),
            params: Some(serde_json::json!({
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "test-client",
                    "version": "1.0.0"
                }
            })),
        });

        // Requests always produce a response message.
        let response = server.handle_message(request_message).await;
        assert!(response.is_ok());
        assert!(response.unwrap().is_some());

        let notification_message =
            crate::transport::TransportMessage::Notification(JsonRpcNotification {
                jsonrpc: "2.0".to_string(),
                method: "initialized".to_string(),
                params: None,
            });

        // Notifications never produce one.
        let response = server.handle_message(notification_message).await;
        assert!(response.is_ok());
        assert!(response.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_server_capabilities_validation() {
        let (server, _repo_dir) = create_test_server_with_repository().await;
        let codeprism_server = server.codeprism_server.read().await;
        let capabilities = codeprism_server.capabilities();

        assert!(capabilities.resources.is_some());
        assert!(capabilities.tools.is_some());
        assert!(capabilities.prompts.is_some());

        let resource_caps = capabilities.resources.as_ref().unwrap();
        assert_eq!(resource_caps.subscribe, Some(true));
        assert_eq!(resource_caps.list_changed, Some(true));

        let tool_caps = capabilities.tools.as_ref().unwrap();
        assert_eq!(tool_caps.list_changed, Some(true));

        let prompt_caps = capabilities.prompts.as_ref().unwrap();
        assert_eq!(prompt_caps.list_changed, Some(false));
    }

    #[tokio::test]
    async fn test_concurrent_requests() {
        let (server, _repo_dir) = create_test_server_with_repository().await;
        let server = Arc::new(RwLock::new(server));

        {
            // `initialize` mutates the server, so it needs the exclusive lock.
            let mut server_lock = server.write().await;
            server_lock
                .handle_initialize(Some(serde_json::to_value(test_initialize_params()).unwrap()))
                .await
                .unwrap();
        }

        let handles: Vec<_> = (0..5)
            .map(|i| {
                let server = Arc::clone(&server);
                tokio::spawn(async move {
                    // The list handlers take `&self`, so a shared read lock suffices and
                    // lets the tasks genuinely overlap instead of running one by one.
                    let server_lock = server.read().await;

                    let resources_result = server_lock.handle_resources_list(None).await;
                    assert!(resources_result.is_ok());

                    let tools_result = server_lock.handle_tools_list(None).await;
                    assert!(tools_result.is_ok());

                    i
                })
            })
            .collect();

        for (expected, handle) in handles.into_iter().enumerate() {
            assert_eq!(handle.await.expect("request task panicked"), expected);
        }
    }

    #[test]
    fn test_server_info_serialization() {
        let server_info = ServerInfo {
            name: "test-server".to_string(),
            version: "1.0.0".to_string(),
        };

        let json = serde_json::to_string(&server_info).unwrap();
        let deserialized: ServerInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(server_info.name, deserialized.name);
        assert_eq!(server_info.version, deserialized.version);
    }

    #[test]
    fn test_client_info_serialization() {
        let client_info = ClientInfo {
            name: "test-client".to_string(),
            version: "2.0.0".to_string(),
        };

        let json = serde_json::to_string(&client_info).unwrap();
        let deserialized: ClientInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(client_info.name, deserialized.name);
        assert_eq!(client_info.version, deserialized.version);
    }
}