1use anyhow::Result;
7use serde_json::Value;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{debug, info, warn};
11
12use crate::{
13 prompts::{GetPromptParams, ListPromptsParams, PromptManager},
14 protocol::{
15 ClientInfo, InitializeParams, InitializeResult, JsonRpcError, JsonRpcNotification,
16 JsonRpcRequest, JsonRpcResponse, ServerInfo,
17 },
18 resources::{ListResourcesParams, ReadResourceParams, ResourceManager},
19 tools::{CallToolParams, ListToolsParams, ToolRegistry},
20 transport::{StdioTransport, Transport, TransportMessage},
21 CodePrismMcpServer,
22};
23
/// Lifecycle state of an [`McpServer`] connection.
///
/// Equality is total (the enum carries no data), so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerState {
    /// Created, but the client has not yet reported initialization complete.
    Uninitialized,
    /// The client has reported initialization complete; normal operation.
    Ready,
    /// The server has been shut down.
    Shutdown,
}
34
35pub struct McpServer {
37 state: ServerState,
39 protocol_version: String,
41 server_info: ServerInfo,
43 client_info: Option<ClientInfo>,
45 codeprism_server: Arc<RwLock<CodePrismMcpServer>>,
47 resource_manager: ResourceManager,
49 tool_registry: ToolRegistry,
51 prompt_manager: PromptManager,
53}
54
55impl McpServer {
56 pub fn new() -> Result<Self> {
58 let codeprism_server = Arc::new(RwLock::new(CodePrismMcpServer::new()?));
59
60 let resource_manager = ResourceManager::new(codeprism_server.clone());
61 let tool_registry = ToolRegistry::new(codeprism_server.clone());
62 let prompt_manager = PromptManager::new(codeprism_server.clone());
63
64 Ok(Self {
65 state: ServerState::Uninitialized,
66 protocol_version: "2024-11-05".to_string(),
67 server_info: ServerInfo {
68 name: "codeprism-mcp".to_string(),
69 version: "0.1.0".to_string(),
70 },
71 client_info: None,
72 codeprism_server,
73 resource_manager,
74 tool_registry,
75 prompt_manager,
76 })
77 }
78
79 pub fn new_with_config(
81 memory_limit_mb: usize,
82 batch_size: usize,
83 max_file_size_mb: usize,
84 disable_memory_limit: bool,
85 exclude_dirs: Vec<String>,
86 include_extensions: Option<Vec<String>>,
87 dependency_mode: Option<String>,
88 ) -> Result<Self> {
89 let codeprism_server = Arc::new(RwLock::new(CodePrismMcpServer::new_with_config(
90 memory_limit_mb,
91 batch_size,
92 max_file_size_mb,
93 disable_memory_limit,
94 exclude_dirs,
95 include_extensions,
96 dependency_mode,
97 )?));
98
99 let resource_manager = ResourceManager::new(codeprism_server.clone());
100 let tool_registry = ToolRegistry::new(codeprism_server.clone());
101 let prompt_manager = PromptManager::new(codeprism_server.clone());
102
103 Ok(Self {
104 state: ServerState::Uninitialized,
105 protocol_version: "2024-11-05".to_string(),
106 server_info: ServerInfo {
107 name: "codeprism-mcp".to_string(),
108 version: "0.1.0".to_string(),
109 },
110 client_info: None,
111 codeprism_server,
112 resource_manager,
113 tool_registry,
114 prompt_manager,
115 })
116 }
117
118 pub async fn initialize_with_repository<P: AsRef<std::path::Path>>(
120 &self,
121 path: P,
122 ) -> Result<()> {
123 let mut server = self.codeprism_server.write().await;
124 server.initialize_with_repository(path).await
125 }
126
127 pub async fn run_stdio(self) -> Result<()> {
129 info!("Starting CodePrism MCP server with stdio transport");
130
131 let mut transport = StdioTransport::new();
132 transport.start().await?;
133
134 self.run_with_transport(transport).await
135 }
136
137 pub async fn run_with_transport<T: Transport>(mut self, mut transport: T) -> Result<()> {
139 info!("Starting CodePrism MCP server");
140
141 loop {
142 match transport.receive().await? {
143 Some(message) => {
144 if let Some(response) = self.handle_message(message).await? {
145 transport.send(response).await?;
146 }
147 }
148 None => {
149 debug!("Transport closed, shutting down server");
150 break;
151 }
152 }
153 }
154
155 transport.close().await?;
156 info!("Prism MCP server stopped");
157 Ok(())
158 }
159
160 async fn handle_message(
162 &mut self,
163 message: TransportMessage,
164 ) -> Result<Option<TransportMessage>> {
165 match message {
166 TransportMessage::Request(request) => {
167 let response = self.handle_request(request).await;
168 Ok(Some(TransportMessage::Response(response)))
169 }
170 TransportMessage::Notification(notification) => {
171 self.handle_notification(notification).await?;
172 Ok(None) }
174 TransportMessage::Response(_) => {
175 warn!("Received unexpected response message");
176 Ok(None)
177 }
178 }
179 }
180
181 async fn handle_request(&mut self, request: JsonRpcRequest) -> JsonRpcResponse {
183 debug!(
184 "Handling request: method={}, id={:?}",
185 request.method, request.id
186 );
187
188 let result = match request.method.as_str() {
189 "initialize" => self.handle_initialize(request.params).await,
190 "resources/list" => self.handle_resources_list(request.params).await,
191 "resources/read" => self.handle_resources_read(request.params).await,
192 "tools/list" => self.handle_tools_list(request.params).await,
193 "tools/call" => self.handle_tools_call(request.params).await,
194 "prompts/list" => self.handle_prompts_list(request.params).await,
195 "prompts/get" => self.handle_prompts_get(request.params).await,
196 _ => Err(JsonRpcError::method_not_found(&request.method)),
197 };
198
199 match result {
200 Ok(result) => JsonRpcResponse::success(request.id, result),
201 Err(error) => JsonRpcResponse::error(request.id, error),
202 }
203 }
204
205 async fn handle_notification(&mut self, notification: JsonRpcNotification) -> Result<()> {
207 debug!("Handling notification: method={}", notification.method);
208
209 match notification.method.as_str() {
210 "initialized" => {
211 info!("Client reported initialization complete");
212 self.state = ServerState::Ready;
213 }
214 "notifications/cancelled" => {
215 debug!("Received cancellation notification");
216 }
218 _ => {
219 warn!("Unknown notification method: {}", notification.method);
220 }
221 }
222
223 Ok(())
224 }
225
226 async fn handle_initialize(&mut self, params: Option<Value>) -> Result<Value, JsonRpcError> {
228 let params: InitializeParams = params
229 .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
230 .try_into_type()
231 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
232
233 info!(
234 "Initializing MCP server with client: {} v{}",
235 params.client_info.name, params.client_info.version
236 );
237
238 self.client_info = Some(params.client_info);
240
241 if params.protocol_version != self.protocol_version {
243 warn!(
244 "Protocol version mismatch: client={}, server={}",
245 params.protocol_version, self.protocol_version
246 );
247 }
248
249 let server = self.codeprism_server.read().await;
251 let result = InitializeResult {
252 protocol_version: self.protocol_version.clone(),
253 capabilities: server.capabilities().clone(),
254 server_info: self.server_info.clone(),
255 };
256
257 serde_json::to_value(result)
258 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
259 }
260
261 async fn handle_resources_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
263 let params = params
264 .map(serde_json::from_value)
265 .transpose()
266 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
267 .unwrap_or(ListResourcesParams { cursor: None });
268
269 let result = self
270 .resource_manager
271 .list_resources(params)
272 .await
273 .map_err(|e| JsonRpcError::internal_error(format!("Resource list error: {}", e)))?;
274
275 serde_json::to_value(result)
276 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
277 }
278
279 async fn handle_resources_read(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
281 let params: ReadResourceParams = params
282 .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
283 .try_into_type()
284 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
285
286 let result = self
287 .resource_manager
288 .read_resource(params)
289 .await
290 .map_err(|e| JsonRpcError::internal_error(format!("Resource read error: {}", e)))?;
291
292 serde_json::to_value(result)
293 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
294 }
295
296 async fn handle_tools_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
298 let params = params
299 .map(serde_json::from_value)
300 .transpose()
301 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
302 .unwrap_or(ListToolsParams { cursor: None });
303
304 let result = self
305 .tool_registry
306 .list_tools(params)
307 .await
308 .map_err(|e| JsonRpcError::internal_error(format!("Tool list error: {}", e)))?;
309
310 serde_json::to_value(result)
311 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
312 }
313
314 async fn handle_tools_call(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
316 let params: CallToolParams = params
317 .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
318 .try_into_type()
319 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
320
321 let result = self
322 .tool_registry
323 .call_tool(params)
324 .await
325 .map_err(|e| JsonRpcError::internal_error(format!("Tool call error: {}", e)))?;
326
327 serde_json::to_value(result)
328 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
329 }
330
331 async fn handle_prompts_list(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
333 let params = params
334 .map(serde_json::from_value)
335 .transpose()
336 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?
337 .unwrap_or(ListPromptsParams { cursor: None });
338
339 let result = self
340 .prompt_manager
341 .list_prompts(params)
342 .await
343 .map_err(|e| JsonRpcError::internal_error(format!("Prompt list error: {}", e)))?;
344
345 serde_json::to_value(result)
346 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
347 }
348
349 async fn handle_prompts_get(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
351 let params: GetPromptParams = params
352 .ok_or_else(|| JsonRpcError::invalid_params("Missing parameters".to_string()))?
353 .try_into_type()
354 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
355
356 let result = self
357 .prompt_manager
358 .get_prompt(params)
359 .await
360 .map_err(|e| JsonRpcError::internal_error(format!("Prompt get error: {}", e)))?;
361
362 serde_json::to_value(result)
363 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
364 }
365
366 pub fn state(&self) -> ServerState {
368 self.state.clone()
369 }
370
371 pub fn server_info(&self) -> &ServerInfo {
373 &self.server_info
374 }
375
376 pub fn client_info(&self) -> Option<&ClientInfo> {
378 self.client_info.as_ref()
379 }
380}
381
382impl Default for McpServer {
383 fn default() -> Self {
384 Self::new().expect("Failed to create default MCP server")
385 }
386}
387
388trait TryIntoType<T> {
390 fn try_into_type(self) -> Result<T, serde_json::Error>;
391}
392
393impl<T> TryIntoType<T> for Value
394where
395 T: serde::de::DeserializeOwned,
396{
397 fn try_into_type(self) -> Result<T, serde_json::Error> {
398 serde_json::from_value(self)
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ClientCapabilities;
    use tempfile::TempDir;

    /// Contents of `app.py` in the fixture repository.
    const APP_PY: &str = r#"
"""Main application module."""

import logging
from typing import List, Optional, Dict, Any
from dataclasses import dataclass

@dataclass
class Config:
    """Application configuration."""
    database_url: str
    api_key: str
    debug: bool = False

class ApplicationService:
    """Main application service."""

    def __init__(self, config: Config):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self._users: Dict[str, 'User'] = {}

    def create_user(self, username: str, email: str) -> 'User':
        """Create a new user."""
        if username in self._users:
            raise ValueError(f"User {username} already exists")

        user = User(username=username, email=email)
        self._users[username] = user
        self.logger.info(f"Created user: {username}")
        return user

    def get_user(self, username: str) -> Optional['User']:
        """Get a user by username."""
        return self._users.get(username)

    def list_users(self) -> List['User']:
        """List all users."""
        return list(self._users.values())

    def delete_user(self, username: str) -> bool:
        """Delete a user."""
        if username in self._users:
            del self._users[username]
            self.logger.info(f"Deleted user: {username}")
            return True
        return False

class User:
    """User model."""

    def __init__(self, username: str, email: str):
        self.username = username
        self.email = email
        self.created_at = None  # Would be datetime in real app
        self.is_active = True

    def deactivate(self) -> None:
        """Deactivate the user."""
        self.is_active = False

    def activate(self) -> None:
        """Activate the user."""
        self.is_active = True

    def to_dict(self) -> Dict[str, Any]:
        """Convert user to dictionary."""
        return {
            'username': self.username,
            'email': self.email,
            'is_active': self.is_active
        }

def main():
    """Main application entry point."""
    config = Config(
        database_url="postgresql://localhost/myapp",
        api_key="secret-key"
    )

    app = ApplicationService(config)

    # Create some sample users
    app.create_user("alice", "alice@example.com")
    app.create_user("bob", "bob@example.com")

    # List users
    users = app.list_users()
    print(f"Created {len(users)} users")

if __name__ == "__main__":
    main()
"#;

    /// Contents of `utils.py` in the fixture repository.
    const UTILS_PY: &str = r#"
"""Utility functions for the application."""

import re
import hashlib
from typing import Optional, Union, List
from datetime import datetime, timedelta

# Constants
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
PASSWORD_MIN_LENGTH = 8

def validate_email(email: str) -> bool:
    """Validate email address format."""
    return bool(EMAIL_REGEX.match(email))

def validate_password(password: str) -> bool:
    """Validate password strength."""
    if len(password) < PASSWORD_MIN_LENGTH:
        return False

    has_upper = any(c.isupper() for c in password)
    has_lower = any(c.islower() for c in password)
    has_digit = any(c.isdigit() for c in password)

    return has_upper and has_lower and has_digit

def hash_password(password: str, salt: Optional[str] = None) -> str:
    """Hash a password with salt."""
    if salt is None:
        salt = "default_salt"  # In real app, use random salt

    combined = f"{password}{salt}"
    return hashlib.sha256(combined.encode()).hexdigest()

def generate_token(length: int = 32) -> str:
    """Generate a random token."""
    import secrets
    return secrets.token_hex(length)

class DateUtils:
    """Utility class for date operations."""

    @staticmethod
    def now() -> datetime:
        """Get current datetime."""
        return datetime.now()

    @staticmethod
    def add_days(date: datetime, days: int) -> datetime:
        """Add days to a date."""
        return date + timedelta(days=days)

    @staticmethod
    def format_date(date: datetime, format_str: str = "%Y-%m-%d") -> str:
        """Format a date as string."""
        return date.strftime(format_str)

def cleanup_string(text: str) -> str:
    """Clean up a string by removing extra whitespace."""
    return re.sub(r'\s+', ' ', text.strip())

def parse_config_value(value: str) -> Union[str, int, bool]:
    """Parse a configuration value to appropriate type."""
    # Try boolean
    if value.lower() in ('true', 'false'):
        return value.lower() == 'true'

    # Try integer
    try:
        return int(value)
    except ValueError:
        pass

    # Return as string
    return value
"#;

    /// `initialize` parameters sent by the simulated test client.
    fn test_initialize_params() -> InitializeParams {
        InitializeParams {
            protocol_version: "2024-11-05".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: ClientInfo {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
            },
        }
    }

    /// Creates a server indexed over a temporary two-file Python repository.
    ///
    /// The returned [`TempDir`] owns the repository on disk and deletes it when
    /// dropped, so callers bind it (e.g. `_temp_dir`, not `_`) for the whole test
    /// instead of leaking the directory.
    async fn create_test_server_with_repository() -> (McpServer, TempDir) {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let repo_path = temp_dir.path();

        std::fs::write(repo_path.join("app.py"), APP_PY).expect("Failed to write app.py");
        std::fs::write(repo_path.join("utils.py"), UTILS_PY).expect("Failed to write utils.py");

        let server = McpServer::new_with_config(
            2048,
            20,
            5,
            false,
            vec!["__pycache__".to_string(), ".pytest_cache".to_string()],
            Some(vec!["py".to_string()]),
            Some("exclude".to_string()),
        )
        .expect("Failed to create MCP server");

        server
            .initialize_with_repository(repo_path)
            .await
            .expect("Failed to initialize repository");

        (server, temp_dir)
    }

    #[tokio::test]
    async fn test_mcp_server_creation() {
        let server = McpServer::new().expect("Failed to create MCP server");
        assert_eq!(server.state(), ServerState::Uninitialized);
        assert_eq!(server.server_info().name, "codeprism-mcp");
        assert_eq!(server.server_info().version, "0.1.0");
    }

    #[tokio::test]
    async fn test_initialize_request() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        let params_value = serde_json::to_value(test_initialize_params()).unwrap();
        let result = server.handle_initialize(Some(params_value)).await;

        assert!(result.is_ok());
        assert_eq!(
            server.client_info().map(|info| info.name.as_str()),
            Some("test-client")
        );
    }

    #[test]
    fn test_server_states() {
        assert_eq!(ServerState::Uninitialized, ServerState::Uninitialized);
        assert_ne!(ServerState::Uninitialized, ServerState::Ready);
        assert_ne!(ServerState::Ready, ServerState::Shutdown);
    }

    #[tokio::test]
    async fn test_server_with_repository_initialization() {
        let (server, _temp_dir) = create_test_server_with_repository().await;

        // Indexing a repository does not by itself complete the MCP handshake.
        assert_eq!(server.state(), ServerState::Uninitialized);

        let codeprism_server = server.codeprism_server.read().await;
        assert!(codeprism_server.repository_path().is_some());
    }

    #[tokio::test]
    async fn test_full_mcp_workflow() {
        let (mut server, _temp_dir) = create_test_server_with_repository().await;

        let init_result = server
            .handle_initialize(Some(serde_json::to_value(test_initialize_params()).unwrap()))
            .await;
        assert!(init_result.is_ok());
        assert!(server.client_info().is_some());

        // Resources
        let resources_value = server
            .handle_resources_list(None)
            .await
            .expect("resources/list failed");
        let resources: crate::resources::ListResourcesResult =
            serde_json::from_value(resources_value).unwrap();
        assert!(!resources.resources.is_empty());

        let uris: Vec<&str> = resources.resources.iter().map(|r| r.uri.as_str()).collect();
        assert!(uris.contains(&"codeprism://repository/stats"));
        assert!(uris.contains(&"codeprism://graph/repository"));
        assert!(uris.iter().any(|uri| uri.contains("app.py")));

        let read_params = crate::resources::ReadResourceParams {
            uri: "codeprism://repository/stats".to_string(),
        };
        let read_result = server
            .handle_resources_read(Some(serde_json::to_value(read_params).unwrap()))
            .await;
        assert!(read_result.is_ok());

        // Tools
        let tools_value = server
            .handle_tools_list(None)
            .await
            .expect("tools/list failed");
        let tools: crate::tools::ListToolsResult = serde_json::from_value(tools_value).unwrap();
        assert_eq!(tools.tools.len(), 23);

        let tool_params = crate::tools::CallToolParams {
            name: "repository_stats".to_string(),
            arguments: Some(serde_json::json!({})),
        };
        let tool_result = server
            .handle_tools_call(Some(serde_json::to_value(tool_params).unwrap()))
            .await;
        assert!(tool_result.is_ok());

        // Prompts
        let prompts_value = server
            .handle_prompts_list(None)
            .await
            .expect("prompts/list failed");
        let prompts: crate::prompts::ListPromptsResult =
            serde_json::from_value(prompts_value).unwrap();
        assert_eq!(prompts.prompts.len(), 16);

        let prompt_params = crate::prompts::GetPromptParams {
            name: "repository_overview".to_string(),
            arguments: Some(serde_json::Map::from_iter([(
                "focus_area".to_string(),
                serde_json::Value::String("architecture".to_string()),
            )])),
        };
        let prompt_result = server
            .handle_prompts_get(Some(serde_json::to_value(prompt_params).unwrap()))
            .await;
        assert!(prompt_result.is_ok());
    }

    #[tokio::test]
    async fn test_request_handling_errors() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        // Unknown method -> JSON-RPC "method not found".
        let invalid_request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: serde_json::Value::Number(1.into()),
            method: "invalid_method".to_string(),
            params: None,
        };

        let response = server.handle_request(invalid_request).await;
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, -32601);

        // Missing mandatory params -> JSON-RPC "invalid params".
        let missing_params_request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: serde_json::Value::Number(2.into()),
            method: "resources/read".to_string(),
            params: None,
        };

        let response = server.handle_request(missing_params_request).await;
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, -32602);
    }

    #[tokio::test]
    async fn test_notification_handling() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        let initialized_notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "initialized".to_string(),
            params: None,
        };

        assert_eq!(server.state(), ServerState::Uninitialized);

        let result = server.handle_notification(initialized_notification).await;
        assert!(result.is_ok());
        assert_eq!(server.state(), ServerState::Ready);

        // Unknown notifications are logged, never treated as errors.
        let unknown_notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "unknown_notification".to_string(),
            params: None,
        };

        let result = server.handle_notification(unknown_notification).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_message_handling() {
        let mut server = McpServer::new().expect("Failed to create MCP server");

        // Requests always produce a response message.
        let request_message = crate::transport::TransportMessage::Request(JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: serde_json::Value::Number(1.into()),
            method: "initialize".to_string(),
            params: Some(serde_json::json!({
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "test-client",
                    "version": "1.0.0"
                }
            })),
        });

        let response = server.handle_message(request_message).await;
        assert!(response.is_ok());
        assert!(response.unwrap().is_some());

        // Notifications never produce a response message.
        let notification_message =
            crate::transport::TransportMessage::Notification(JsonRpcNotification {
                jsonrpc: "2.0".to_string(),
                method: "initialized".to_string(),
                params: None,
            });

        let response = server.handle_message(notification_message).await;
        assert!(response.is_ok());
        assert!(response.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_server_capabilities_validation() {
        let (server, _temp_dir) = create_test_server_with_repository().await;
        let codeprism_server = server.codeprism_server.read().await;
        let capabilities = codeprism_server.capabilities();

        assert!(capabilities.resources.is_some());
        assert!(capabilities.tools.is_some());
        assert!(capabilities.prompts.is_some());

        let resource_caps = capabilities.resources.as_ref().unwrap();
        assert_eq!(resource_caps.subscribe, Some(true));
        assert_eq!(resource_caps.list_changed, Some(true));

        let tool_caps = capabilities.tools.as_ref().unwrap();
        assert_eq!(tool_caps.list_changed, Some(true));

        let prompt_caps = capabilities.prompts.as_ref().unwrap();
        assert_eq!(prompt_caps.list_changed, Some(false));
    }

    #[tokio::test]
    async fn test_concurrent_requests() {
        let (server, _temp_dir) = create_test_server_with_repository().await;
        let server = Arc::new(RwLock::new(server));

        {
            // `initialize` mutates the server, so it needs exclusive access.
            let mut server_lock = server.write().await;
            server_lock
                .handle_initialize(Some(serde_json::to_value(test_initialize_params()).unwrap()))
                .await
                .unwrap();
        }

        let handles: Vec<_> = (0..5)
            .map(|i| {
                let server = Arc::clone(&server);
                tokio::spawn(async move {
                    // Both handlers take `&self`; read locks let the tasks actually
                    // overlap instead of being serialised by a write lock.
                    let server = server.read().await;
                    assert!(server.handle_resources_list(None).await.is_ok());
                    assert!(server.handle_tools_list(None).await.is_ok());
                    i
                })
            })
            .collect();

        for (expected, handle) in handles.into_iter().enumerate() {
            assert_eq!(handle.await.expect("request task panicked"), expected);
        }
    }

    #[test]
    fn test_server_info_serialization() {
        let server_info = ServerInfo {
            name: "test-server".to_string(),
            version: "1.0.0".to_string(),
        };

        let json = serde_json::to_string(&server_info).unwrap();
        let deserialized: ServerInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(server_info.name, deserialized.name);
        assert_eq!(server_info.version, deserialized.version);
    }

    #[test]
    fn test_client_info_serialization() {
        let client_info = ClientInfo {
            name: "test-client".to_string(),
            version: "2.0.0".to_string(),
        };

        let json = serde_json::to_string(&client_info).unwrap();
        let deserialized: ClientInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(client_info.name, deserialized.name);
        assert_eq!(client_info.version, deserialized.version);
    }
}