use chrono::Duration;
use mcp_execution_core::{ServerConfig, ServerId, ToolName};
use mcp_execution_introspector::{ServerCapabilities, ServerInfo, ToolInfo};
use mcp_execution_server::{CategorizedTool, GeneratorService, PendingGeneration, StateManager};
use rmcp::handler::server::ServerHandler;
use std::sync::Arc;
/// Checks that `get_info` advertises the expected protocol version, a tools
/// capability that does not announce list changes, and server instructions
/// that mention the progressive-loading workflow and its tools.
#[test]
fn test_service_info_has_correct_capabilities() {
    let service = GeneratorService::new();
    let info = service.get_info();

    assert_eq!(
        info.protocol_version,
        rmcp::model::ProtocolVersion::V_2025_06_18
    );

    assert!(info.capabilities.tools.is_some());
    let tools_capability = info.capabilities.tools.as_ref().unwrap();
    assert!(!tools_capability.list_changed.unwrap_or(false));

    assert!(info.instructions.is_some());
    let text = info.instructions.unwrap();
    for needle in [
        "progressive loading",
        "introspect_server",
        "save_categorized_tools",
    ] {
        assert!(text.contains(needle));
    }
}
/// Two independently constructed services must report the same protocol
/// version, and each must expose a tools capability.
#[test]
fn test_service_new_creates_fresh_instance() {
    let (first, second) = (GeneratorService::new(), GeneratorService::new());
    let (first_info, second_info) = (first.get_info(), second.get_info());

    assert_eq!(first_info.protocol_version, second_info.protocol_version);
    for info in [&first_info, &second_info] {
        assert!(info.capabilities.tools.is_some());
    }
}
/// Covers the basic store/take lifecycle. A stored session is counted, can be
/// taken exactly once with its data intact, and is gone afterwards.
#[tokio::test]
async fn test_state_manager_workflow() {
    let state = StateManager::new();

    let id = ServerId::new("test-server");
    let pending = PendingGeneration::new(
        id.clone(),
        create_test_server_info(id),
        ServerConfig::builder().command("echo".to_string()).build(),
        std::env::temp_dir().join("mcp-server-test"),
    );

    let session_id = state.store(pending.clone()).await;
    assert_eq!(state.pending_count().await, 1);

    let Some(first_take) = state.take(session_id).await else {
        panic!("Should retrieve stored session");
    };
    assert_eq!(first_take.server_id, pending.server_id);
    assert_eq!(state.pending_count().await, 0);

    // A second take must fail because `take` consumes the session.
    assert!(
        state.take(session_id).await.is_none(),
        "Session should be consumed"
    );
}
/// Stores several independent pending sessions at once and checks two things:
/// each can be taken back exactly once, and each take returns the generation
/// that was stored under that session id.
///
/// Sessions are created sequentially. "Concurrent" here means several sessions
/// pending at the same time, not several tasks.
#[tokio::test]
async fn test_multiple_concurrent_sessions() {
    let state = StateManager::new();
    let mut sessions = Vec::with_capacity(5);

    for i in 0..5 {
        let server_id = ServerId::new(&format!("server-{i}"));
        let server_info = ServerInfo {
            id: server_id.clone(),
            name: format!("Server {i}"),
            version: "1.0.0".to_string(),
            capabilities: ServerCapabilities {
                supports_tools: true,
                supports_resources: false,
                supports_prompts: false,
            },
            tools: vec![],
        };
        let config = ServerConfig::builder().command("echo".to_string()).build();
        let output_dir = std::env::temp_dir().join(format!("mcp-test-{i}"));
        let pending =
            PendingGeneration::new(server_id.clone(), server_info, config, output_dir);
        let session_id = state.store(pending).await;
        // Record the owning server so retrieval can be matched, not just counted.
        sessions.push((session_id, server_id));
    }

    assert_eq!(state.pending_count().await, 5);

    for (session_id, expected_server) in sessions {
        let retrieved = state
            .take(session_id)
            .await
            .expect("Session should be retrievable");
        assert_eq!(
            retrieved.server_id, expected_server,
            "Session should return the generation it was stored with"
        );
    }

    assert_eq!(state.pending_count().await, 0);
}
/// A session whose `expires_at` is already in the past must not be returned by
/// `take`, and must not be counted as pending.
#[tokio::test]
async fn test_state_manager_handles_expiration() {
    let state = StateManager::new();

    let server_id = ServerId::new("test");
    let mut stale = PendingGeneration::new(
        server_id.clone(),
        create_test_server_info(server_id),
        ServerConfig::builder().command("echo".to_string()).build(),
        std::env::temp_dir().join("mcp-expire-test"),
    );
    // Backdate the expiry so the session is already stale when it is stored.
    stale.expires_at = chrono::Utc::now() - Duration::hours(1);

    let session_id = state.store(stale).await;

    assert!(
        state.take(session_id).await.is_none(),
        "Expired session should not be retrievable"
    );
    assert_eq!(state.pending_count().await, 0);
}
/// Stores one valid and one already-expired session, then runs
/// `cleanup_expired`.
///
/// The expired entry is expected not to be counted by `pending_count`, and to
/// be removed by cleanup. The valid entry must survive cleanup untouched.
#[tokio::test]
async fn test_state_manager_lazy_cleanup() {
    let state = StateManager::new();

    let valid_id = state.store(create_test_pending("valid-server")).await;

    let mut expired_pending = create_test_pending("expired-server");
    expired_pending.expires_at = chrono::Utc::now() - Duration::hours(1);
    state.store(expired_pending).await;

    // Only the unexpired session is counted, even before cleanup has run.
    assert_eq!(state.pending_count().await, 1);

    let removed = state.cleanup_expired().await;
    assert_eq!(removed, 1, "Should remove 1 expired session");

    // A cleanup that wiped every session would otherwise go unnoticed.
    assert_eq!(state.pending_count().await, 1);
    assert!(
        state.get(valid_id).await.is_some(),
        "Valid session should survive cleanup"
    );

    // Once nothing expired remains, a further cleanup removes nothing.
    assert_eq!(state.cleanup_expired().await, 0, "Nothing left to clean up");
}
/// `get` is non-destructive and may be repeated. `take` removes the session,
/// after which `get` finds nothing.
#[tokio::test]
async fn test_state_manager_get_without_consuming() {
    let state = StateManager::new();
    let session_id = state.store(create_test_pending("test")).await;

    // Two peeks in a row must both succeed.
    for _ in 0..2 {
        assert!(state.get(session_id).await.is_some());
    }

    assert!(state.take(session_id).await.is_some());
    assert!(state.get(session_id).await.is_none());
}
/// A freshly created pending generation must not report itself as expired.
#[test]
fn test_pending_generation_not_expired_initially() {
    let fresh = create_test_pending("test");
    let expired = fresh.is_expired();
    assert!(!expired);
}
/// Moving `expires_at` into the past must make `is_expired` return true.
#[test]
fn test_pending_generation_expires_correctly() {
    let mut pending = create_test_pending("test");
    let one_minute_ago = chrono::Utc::now() - Duration::minutes(1);
    pending.expires_at = one_minute_ago;
    assert!(pending.is_expired());
}
/// The window between `created_at` and `expires_at` of a new pending
/// generation must equal `PendingGeneration::DEFAULT_TIMEOUT_MINUTES`.
#[test]
fn test_pending_generation_has_correct_timeout() {
    let pending = create_test_pending("test");
    let lifetime_minutes = (pending.expires_at - pending.created_at).num_minutes();
    assert_eq!(
        lifetime_minutes,
        PendingGeneration::DEFAULT_TIMEOUT_MINUTES,
        "Should use default timeout"
    );
}
/// A `CategorizedTool` must survive a JSON serialize/deserialize round trip
/// with every field intact.
#[test]
fn test_categorized_tool_serialization_roundtrip() {
    let original = CategorizedTool {
        name: "test_tool".to_string(),
        category: "testing".to_string(),
        keywords: "test,tool,demo".to_string(),
        short_description: "A test tool".to_string(),
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: CategorizedTool = serde_json::from_str(&encoded).unwrap();

    assert_eq!(decoded.name, original.name);
    assert_eq!(decoded.category, original.category);
    assert_eq!(decoded.keywords, original.keywords);
    assert_eq!(decoded.short_description, original.short_description);
}
/// Ten spawned tasks store sessions into a shared `StateManager`. Every
/// session must be recorded, and each must receive a distinct session id.
#[tokio::test]
async fn test_state_manager_concurrent_access() {
    let state = Arc::new(StateManager::new());

    // All tasks are spawned here, before any handle is awaited.
    let handles: Vec<_> = (0..10)
        .map(|i| {
            let state = Arc::clone(&state);
            tokio::spawn(async move {
                let pending = create_test_pending(&format!("server-{i}"));
                state.store(pending).await
            })
        })
        .collect();

    let mut session_ids = Vec::with_capacity(handles.len());
    for handle in handles {
        session_ids.push(handle.await.unwrap());
    }

    assert_eq!(state.pending_count().await, 10);

    let distinct: std::collections::HashSet<_> = session_ids.iter().collect();
    assert_eq!(distinct.len(), 10, "All session IDs should be unique");
}
/// Runs reads and a write against a shared `StateManager` at the same time.
///
/// Two tasks `get` an existing session while a third `store`s a new one. Both
/// reads must succeed. The write must land under a fresh session id, leaving
/// two pending sessions.
#[tokio::test]
async fn test_state_manager_concurrent_read_write() {
    let state = Arc::new(StateManager::new());
    let session_id = state.store(create_test_pending("test")).await;

    let reader1 = {
        let state = Arc::clone(&state);
        tokio::spawn(async move { state.get(session_id).await })
    };
    let reader2 = {
        let state = Arc::clone(&state);
        tokio::spawn(async move { state.get(session_id).await })
    };
    // A concurrent writer makes this exercise read/write interleaving rather
    // than only parallel reads.
    let writer = {
        let state = Arc::clone(&state);
        tokio::spawn(async move {
            let pending = create_test_pending("test-writer");
            state.store(pending).await
        })
    };

    let result1 = reader1.await.unwrap();
    let result2 = reader2.await.unwrap();
    let new_session_id = writer.await.unwrap();

    assert!(result1.is_some());
    assert!(result2.is_some());
    assert!(
        new_session_id != session_id,
        "Writer should receive a fresh session ID"
    );
    assert_eq!(state.pending_count().await, 2);
}
/// Builds a `ServerInfo` for `server_id` named "Test Server" (version 1.0.0).
///
/// The server advertises tool support only. It exposes a single `test_tool`
/// whose input schema requires a string `message` and has no output schema.
fn create_test_server_info(server_id: ServerId) -> ServerInfo {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "message": {"type": "string"}
        },
        "required": ["message"]
    });

    let test_tool = ToolInfo {
        name: ToolName::new("test_tool"),
        description: "A test tool".to_string(),
        input_schema: message_schema,
        output_schema: None,
    };

    let capabilities = ServerCapabilities {
        supports_tools: true,
        supports_resources: false,
        supports_prompts: false,
    };

    ServerInfo {
        id: server_id,
        name: "Test Server".to_string(),
        version: "1.0.0".to_string(),
        capabilities,
        tools: vec![test_tool],
    }
}
/// Creates a `PendingGeneration` for the server id `server_id_str`.
///
/// It uses the `create_test_server_info` fixture, an `echo` command config,
/// and an output path `mcp-test-<id>` under the system temp directory. The
/// path is only computed here, not created.
fn create_test_pending(server_id_str: &str) -> PendingGeneration {
    let id = ServerId::new(server_id_str);
    let info = create_test_server_info(id.clone());
    let echo_config = ServerConfig::builder().command("echo".to_string()).build();
    let target_dir = std::env::temp_dir().join(format!("mcp-test-{server_id_str}"));
    PendingGeneration::new(id, info, echo_config, target_dir)
}