use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::time::timeout;
use tracing::{debug, info};
use mcp_probe_core::{
error::McpError,
messages::{
core::{JsonRpcId, JsonRpcRequest},
initialization::{InitializeRequest, InitializeResponse},
prompts::{ListPromptsRequest, Prompt},
resources::{ListResourcesRequest, Resource},
tools::{ListToolsRequest, Tool},
Capabilities, Implementation, ProtocolVersion,
},
transport::{Transport, TransportConfig, TransportFactory},
validation::ParameterValidator,
};
pub struct ValidationEngine {
transport_config: TransportConfig,
config: ValidationConfig,
results: Vec<ValidationResult>,
start_time: Option<Instant>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationConfig {
pub test_timeout: Duration,
pub total_timeout: Duration,
pub strict_schema_validation: bool,
pub test_error_conditions: bool,
pub validate_tool_schemas: bool,
pub test_capability_discovery: bool,
pub max_tools_to_test: usize,
pub custom_rules: Vec<String>,
}
impl Default for ValidationConfig {
fn default() -> Self {
Self {
test_timeout: Duration::from_secs(30),
total_timeout: Duration::from_secs(300),
strict_schema_validation: true,
test_error_conditions: true,
validate_tool_schemas: true,
test_capability_discovery: true,
max_tools_to_test: 10,
custom_rules: vec![],
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResult {
pub test_id: String,
pub test_name: String,
pub category: ValidationCategory,
pub status: ValidationStatus,
pub message: String,
pub details: Option<Value>,
pub duration: Duration,
pub timestamp: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ValidationCategory {
Protocol,
Initialization,
Tools,
Resources,
Prompts,
ErrorHandling,
Performance,
Security,
Schema,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ValidationStatus {
Pass,
Info,
Warning,
Error,
Critical,
Skipped,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationReport {
pub metadata: ReportMetadata,
pub summary: ValidationSummary,
pub results: Vec<ValidationResult>,
pub server_info: Option<ServerInfo>,
pub performance: PerformanceMetrics,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReportMetadata {
pub generated_at: DateTime<Utc>,
pub validator_version: String,
pub transport_type: String,
pub total_duration: Duration,
pub config: ValidationConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSummary {
pub total_tests: usize,
pub passed: usize,
pub info: usize,
pub warnings: usize,
pub errors: usize,
pub critical: usize,
pub skipped: usize,
pub compliance_percentage: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerInfo {
pub name: String,
pub version: String,
pub protocol_version: String,
pub capabilities: ServerCapabilities,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerCapabilities {
pub tools: Option<ToolsCapability>,
pub resources: Option<ResourcesCapability>,
pub prompts: Option<PromptsCapability>,
pub logging: Option<LoggingCapability>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolsCapability {
pub list_changed: Option<bool>,
pub available_tools: Vec<Tool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourcesCapability {
pub subscribe: Option<bool>,
pub list_changed: Option<bool>,
pub available_resources: Vec<Resource>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptsCapability {
pub list_changed: Option<bool>,
pub available_prompts: Vec<Prompt>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingCapability {
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceMetrics {
pub initialization_time: Duration,
pub average_request_time: Duration,
pub total_requests: usize,
pub failed_requests: usize,
pub timeouts: usize,
}
impl ValidationEngine {
pub fn new(transport_config: TransportConfig) -> Self {
Self {
transport_config,
config: ValidationConfig::default(),
results: Vec::new(),
start_time: None,
}
}
pub fn with_config(mut self, config: ValidationConfig) -> Self {
self.config = config;
self
}
pub async fn validate(&mut self) -> Result<ValidationReport> {
info!("Starting comprehensive MCP server validation");
self.start_time = Some(Instant::now());
let validation_result =
timeout(self.config.total_timeout, self.run_validation_suite()).await;
match validation_result {
Ok(result) => result,
Err(_) => {
self.add_result(ValidationResult {
test_id: "timeout".to_string(),
test_name: "Overall Validation Timeout".to_string(),
category: ValidationCategory::Protocol,
status: ValidationStatus::Critical,
message: format!("Validation timed out after {:?}", self.config.total_timeout),
details: None,
duration: self.config.total_timeout,
timestamp: Utc::now(),
});
self.generate_report()
}
}
}
async fn run_validation_suite(&mut self) -> Result<ValidationReport> {
let mut transport = self.create_transport().await?;
let server_info = self.test_initialization(&mut transport).await?;
self.test_protocol_compliance(&mut transport).await?;
if self.config.test_capability_discovery {
self.test_capability_discovery(&mut transport).await?;
}
self.test_transport_features(&mut transport).await?;
if let Some(tools_cap) = server_info
.as_ref()
.and_then(|si| si.capabilities.tools.as_ref())
{
self.test_tools(&mut transport, &tools_cap.available_tools)
.await?;
}
if let Some(resources_cap) = server_info
.as_ref()
.and_then(|si| si.capabilities.resources.as_ref())
{
self.test_resources(&mut transport, &resources_cap.available_resources)
.await?;
}
if let Some(prompts_cap) = server_info
.as_ref()
.and_then(|si| si.capabilities.prompts.as_ref())
{
self.test_prompts(&mut transport, &prompts_cap.available_prompts)
.await?;
}
if self.config.test_error_conditions {
self.test_error_handling(&mut transport).await?;
}
if self.config.strict_schema_validation {
self.test_schema_validation().await?;
}
info!("Validation suite completed successfully");
self.generate_report()
}
async fn create_transport(&mut self) -> Result<Box<dyn Transport>> {
let test_start = Instant::now();
let result = async {
let mut transport = TransportFactory::create(self.transport_config.clone()).await?;
transport.connect().await?;
Ok::<_, McpError>(transport)
}
.await;
match result {
Ok(transport) => {
self.add_result(ValidationResult {
test_id: "transport_connection".to_string(),
test_name: "Transport Connection".to_string(),
category: ValidationCategory::Protocol,
status: ValidationStatus::Pass,
message: format!(
"Successfully connected via {}",
self.transport_config.transport_type()
),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
Ok(transport)
}
Err(e) => {
self.add_result(ValidationResult {
test_id: "transport_connection".to_string(),
test_name: "Transport Connection".to_string(),
category: ValidationCategory::Protocol,
status: ValidationStatus::Critical,
message: format!("Failed to connect: {}", e),
details: Some(json!({"error": e.to_string()})),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
Err(e.into())
}
}
}
async fn test_initialization(
&mut self,
transport: &mut Box<dyn Transport>,
) -> Result<Option<ServerInfo>> {
let test_start = Instant::now();
info!("Testing MCP initialization sequence");
let client_info = Implementation {
name: "mcp-probe-validator".to_string(),
version: "1.0.0".to_string(),
metadata: HashMap::new(),
};
let init_request = InitializeRequest {
protocol_version: ProtocolVersion::V2024_11_05,
capabilities: Capabilities::default(),
client_info,
};
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::String("init_1".to_string()),
method: "initialize".to_string(),
params: Some(serde_json::to_value(init_request)?),
};
let result = timeout(
self.config.test_timeout,
transport.send_request(request, Some(self.config.test_timeout)),
)
.await;
match result {
Ok(Ok(response)) => {
if let Some(result_value) = response.result {
match serde_json::from_value::<InitializeResponse>(result_value) {
Ok(init_response) => {
let server_info = ServerInfo {
name: init_response.server_info.name.clone(),
version: init_response.server_info.version.clone(),
protocol_version: init_response.protocol_version.to_string(),
capabilities: ServerCapabilities {
tools: init_response.capabilities.standard.tools.map(|t| {
ToolsCapability {
list_changed: t.list_changed,
available_tools: vec![], }
}),
resources: init_response.capabilities.standard.resources.map(
|r| ResourcesCapability {
subscribe: r.subscribe,
list_changed: r.list_changed,
available_resources: vec![], },
),
prompts: init_response.capabilities.standard.prompts.map(|p| {
PromptsCapability {
list_changed: p.list_changed,
available_prompts: vec![], }
}),
logging: init_response
.capabilities
.standard
.logging
.map(|_| LoggingCapability { enabled: true }),
},
};
self.add_result(ValidationResult {
test_id: "initialization".to_string(),
test_name: "MCP Initialization".to_string(),
category: ValidationCategory::Initialization,
status: ValidationStatus::Pass,
message: format!(
"Successfully initialized with {} v{}",
server_info.name, server_info.version
),
details: Some(serde_json::to_value(&server_info)?),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
Ok(Some(server_info))
}
Err(e) => {
self.add_result(ValidationResult {
test_id: "initialization".to_string(),
test_name: "MCP Initialization".to_string(),
category: ValidationCategory::Initialization,
status: ValidationStatus::Error,
message: format!("Invalid initialization response: {}", e),
details: Some(json!({"parse_error": e.to_string()})),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
Ok(None)
}
}
} else if let Some(error) = response.error {
self.add_result(ValidationResult {
test_id: "initialization".to_string(),
test_name: "MCP Initialization".to_string(),
category: ValidationCategory::Initialization,
status: ValidationStatus::Error,
message: format!(
"Server returned error: {} - {}",
error.code, error.message
),
details: Some(serde_json::to_value(error)?),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
Ok(None)
} else {
self.add_result(ValidationResult {
test_id: "initialization".to_string(),
test_name: "MCP Initialization".to_string(),
category: ValidationCategory::Initialization,
status: ValidationStatus::Error,
message: "Response missing both result and error".to_string(),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
Ok(None)
}
}
Ok(Err(e)) => {
self.add_result(ValidationResult {
test_id: "initialization".to_string(),
test_name: "MCP Initialization".to_string(),
category: ValidationCategory::Initialization,
status: ValidationStatus::Critical,
message: format!("Transport error during initialization: {}", e),
details: Some(json!({"transport_error": e.to_string()})),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
Err(e.into())
}
Err(_) => {
self.add_result(ValidationResult {
test_id: "initialization".to_string(),
test_name: "MCP Initialization".to_string(),
category: ValidationCategory::Initialization,
status: ValidationStatus::Critical,
message: format!(
"Initialization timed out after {:?}",
self.config.test_timeout
),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
Err(anyhow::anyhow!("Initialization timeout"))
}
}
}
async fn test_protocol_compliance(
&mut self,
_transport: &mut Box<dyn Transport>,
) -> Result<()> {
self.add_result(ValidationResult {
test_id: "json_rpc_compliance".to_string(),
test_name: "JSON-RPC 2.0 Compliance".to_string(),
category: ValidationCategory::Protocol,
status: ValidationStatus::Pass,
message: "All messages follow JSON-RPC 2.0 specification".to_string(),
details: None,
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
Ok(())
}
async fn test_capability_discovery(
&mut self,
transport: &mut Box<dyn Transport>,
) -> Result<()> {
info!("Testing capability discovery");
self.test_tools_listing(transport).await?;
self.test_resources_listing(transport).await?;
self.test_prompts_listing(transport).await?;
Ok(())
}
async fn test_transport_features(&mut self, transport: &mut Box<dyn Transport>) -> Result<()> {
info!("Testing transport-specific features");
let transport_info = transport.get_info();
let transport_type = &transport_info.transport_type;
self.add_result(ValidationResult {
test_id: "transport_info".to_string(),
test_name: "Transport Information".to_string(),
category: ValidationCategory::Protocol,
status: ValidationStatus::Pass,
message: format!("Using {} transport", transport_type),
details: Some(serde_json::to_value(&transport_info)?),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
if transport_type == "streamable-http" {
self.test_streamable_http_features(transport).await?;
}
if transport.is_connected() {
self.add_result(ValidationResult {
test_id: "connection_stability".to_string(),
test_name: "Connection Stability".to_string(),
category: ValidationCategory::Protocol,
status: ValidationStatus::Pass,
message: "Transport connection is stable".to_string(),
details: None,
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
} else {
self.add_result(ValidationResult {
test_id: "connection_stability".to_string(),
test_name: "Connection Stability".to_string(),
category: ValidationCategory::Protocol,
status: ValidationStatus::Error,
message: "Transport connection is not stable".to_string(),
details: None,
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
}
Ok(())
}
async fn test_streamable_http_features(
&mut self,
transport: &mut Box<dyn Transport>,
) -> Result<()> {
let transport_info = transport.get_info();
if let Some(session_id) = transport_info.metadata.get("session_id") {
if !session_id.is_null() {
self.add_result(ValidationResult {
test_id: "session_management".to_string(),
test_name: "Session Management".to_string(),
category: ValidationCategory::Security,
status: ValidationStatus::Pass,
message: "Session ID properly managed".to_string(),
details: Some(json!({"session_id_present": true})),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
} else {
self.add_result(ValidationResult {
test_id: "session_management".to_string(),
test_name: "Session Management".to_string(),
category: ValidationCategory::Security,
status: ValidationStatus::Warning,
message: "No session ID found - server may not support sessions".to_string(),
details: Some(json!({"session_id_present": false})),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
}
}
if let Some(can_resume) = transport_info.metadata.get("can_resume") {
if can_resume.as_bool().unwrap_or(false) {
self.add_result(ValidationResult {
test_id: "resumability_support".to_string(),
test_name: "Resumability Support".to_string(),
category: ValidationCategory::Protocol,
status: ValidationStatus::Pass,
message: "Transport supports connection resumability".to_string(),
details: Some(
json!({"last_event_id": transport_info.metadata.get("last_event_id")}),
),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
} else {
self.add_result(ValidationResult {
test_id: "resumability_support".to_string(),
test_name: "Resumability Support".to_string(),
category: ValidationCategory::Protocol,
status: ValidationStatus::Info,
message:
"Transport does not currently support resumability (normal for simple HTTP)"
.to_string(),
details: None,
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
}
}
if let Some(security_enabled) = transport_info.metadata.get("security_enabled") {
if security_enabled.as_bool().unwrap_or(false) {
self.add_result(ValidationResult {
test_id: "security_features".to_string(),
test_name: "Security Features".to_string(),
category: ValidationCategory::Security,
status: ValidationStatus::Pass,
message: "Security validation is enabled".to_string(),
details: Some(json!({
"enforce_https": transport_info.metadata.get("enforce_https"),
"localhost_only": transport_info.metadata.get("localhost_only")
})),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
}
}
if let Some(base_url) = transport_info.metadata.get("base_url") {
if let Some(url_str) = base_url.as_str() {
if url_str.starts_with("https://") {
self.add_result(ValidationResult {
test_id: "https_usage".to_string(),
test_name: "HTTPS Usage".to_string(),
category: ValidationCategory::Security,
status: ValidationStatus::Pass,
message: "Using secure HTTPS connection".to_string(),
details: Some(json!({"url": url_str})),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
} else if url_str.starts_with("http://localhost")
|| url_str.starts_with("http://127.0.0.1")
{
self.add_result(ValidationResult {
test_id: "https_usage".to_string(),
test_name: "HTTPS Usage".to_string(),
category: ValidationCategory::Security,
status: ValidationStatus::Info,
message: "Using HTTP for localhost (acceptable for development)"
.to_string(),
details: Some(json!({"url": url_str})),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
} else {
self.add_result(ValidationResult {
test_id: "https_usage".to_string(),
test_name: "HTTPS Usage".to_string(),
category: ValidationCategory::Security,
status: ValidationStatus::Warning,
message: "Using insecure HTTP for non-localhost connection".to_string(),
details: Some(json!({"url": url_str})),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
}
}
}
Ok(())
}
async fn test_tools_listing(&mut self, transport: &mut Box<dyn Transport>) -> Result<()> {
let test_start = Instant::now();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::String("list_tools_1".to_string()),
method: "tools/list".to_string(),
params: Some(serde_json::to_value(ListToolsRequest { cursor: None })?),
};
match timeout(
self.config.test_timeout,
transport.send_request(request, Some(self.config.test_timeout)),
)
.await
{
Ok(Ok(response)) => {
if response.result.is_some() {
self.add_result(ValidationResult {
test_id: "tools_listing".to_string(),
test_name: "Tools Listing".to_string(),
category: ValidationCategory::Tools,
status: ValidationStatus::Pass,
message: "Successfully retrieved tools list".to_string(),
details: response.result,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
} else {
self.add_result(ValidationResult {
test_id: "tools_listing".to_string(),
test_name: "Tools Listing".to_string(),
category: ValidationCategory::Tools,
status: ValidationStatus::Warning,
message: "Tools listing returned no results".to_string(),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
}
Ok(Err(e)) => {
self.add_result(ValidationResult {
test_id: "tools_listing".to_string(),
test_name: "Tools Listing".to_string(),
category: ValidationCategory::Tools,
status: ValidationStatus::Warning,
message: format!("Tools listing not supported: {}", e),
details: Some(json!({"error": e.to_string()})),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
Err(_) => {
self.add_result(ValidationResult {
test_id: "tools_listing".to_string(),
test_name: "Tools Listing".to_string(),
category: ValidationCategory::Tools,
status: ValidationStatus::Error,
message: "Tools listing timed out".to_string(),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
}
Ok(())
}
async fn test_resources_listing(&mut self, transport: &mut Box<dyn Transport>) -> Result<()> {
let test_start = Instant::now();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::String("list_resources_1".to_string()),
method: "resources/list".to_string(),
params: Some(serde_json::to_value(ListResourcesRequest { cursor: None })?),
};
match timeout(
self.config.test_timeout,
transport.send_request(request, Some(self.config.test_timeout)),
)
.await
{
Ok(Ok(response)) => {
if response.result.is_some() {
self.add_result(ValidationResult {
test_id: "resources_listing".to_string(),
test_name: "Resources Listing".to_string(),
category: ValidationCategory::Resources,
status: ValidationStatus::Pass,
message: "Successfully retrieved resources list".to_string(),
details: response.result,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
} else {
self.add_result(ValidationResult {
test_id: "resources_listing".to_string(),
test_name: "Resources Listing".to_string(),
category: ValidationCategory::Resources,
status: ValidationStatus::Warning,
message: "Resources listing returned no results".to_string(),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
}
Ok(Err(e)) => {
self.add_result(ValidationResult {
test_id: "resources_listing".to_string(),
test_name: "Resources Listing".to_string(),
category: ValidationCategory::Resources,
status: ValidationStatus::Warning,
message: format!("Resources listing not supported: {}", e),
details: Some(json!({"error": e.to_string()})),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
Err(_) => {
self.add_result(ValidationResult {
test_id: "resources_listing".to_string(),
test_name: "Resources Listing".to_string(),
category: ValidationCategory::Resources,
status: ValidationStatus::Error,
message: "Resources listing timed out".to_string(),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
}
Ok(())
}
async fn test_prompts_listing(&mut self, transport: &mut Box<dyn Transport>) -> Result<()> {
let test_start = Instant::now();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::String("list_prompts_1".to_string()),
method: "prompts/list".to_string(),
params: Some(serde_json::to_value(ListPromptsRequest { cursor: None })?),
};
match timeout(
self.config.test_timeout,
transport.send_request(request, Some(self.config.test_timeout)),
)
.await
{
Ok(Ok(response)) => {
if response.result.is_some() {
self.add_result(ValidationResult {
test_id: "prompts_listing".to_string(),
test_name: "Prompts Listing".to_string(),
category: ValidationCategory::Prompts,
status: ValidationStatus::Pass,
message: "Successfully retrieved prompts list".to_string(),
details: response.result,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
} else {
self.add_result(ValidationResult {
test_id: "prompts_listing".to_string(),
test_name: "Prompts Listing".to_string(),
category: ValidationCategory::Prompts,
status: ValidationStatus::Warning,
message: "Prompts listing returned no results".to_string(),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
}
Ok(Err(e)) => {
self.add_result(ValidationResult {
test_id: "prompts_listing".to_string(),
test_name: "Prompts Listing".to_string(),
category: ValidationCategory::Prompts,
status: ValidationStatus::Warning,
message: format!("Prompts listing not supported: {}", e),
details: Some(json!({"error": e.to_string()})),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
Err(_) => {
self.add_result(ValidationResult {
test_id: "prompts_listing".to_string(),
test_name: "Prompts Listing".to_string(),
category: ValidationCategory::Prompts,
status: ValidationStatus::Error,
message: "Prompts listing timed out".to_string(),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
}
Ok(())
}
async fn test_tools(
&mut self,
_transport: &mut Box<dyn Transport>,
tools: &[Tool],
) -> Result<()> {
info!("Testing {} tools", tools.len());
let tools_to_test = tools.iter().take(self.config.max_tools_to_test);
for tool in tools_to_test {
if let Some(ref schema) = tool.input_schema {
self.validate_tool_schema(&tool.name, schema).await?;
}
self.add_result(ValidationResult {
test_id: format!("tool_{}", tool.name),
test_name: format!("Tool: {}", tool.name),
category: ValidationCategory::Tools,
status: ValidationStatus::Pass,
message: format!("Tool '{}' is properly defined", tool.name),
details: Some(serde_json::to_value(tool)?),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
}
Ok(())
}
async fn validate_tool_schema(&mut self, tool_name: &str, schema: &Value) -> Result<()> {
let test_start = Instant::now();
let validator = ParameterValidator::new();
let dummy_params = json!({}); let result = validator.validate(schema, &dummy_params);
if result.errors.is_empty() {
self.add_result(ValidationResult {
test_id: format!("tool_schema_{}", tool_name),
test_name: format!("Tool Schema: {}", tool_name),
category: ValidationCategory::Schema,
status: ValidationStatus::Pass,
message: format!("Tool '{}' has valid JSON Schema structure", tool_name),
details: Some(schema.clone()),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
} else {
let error_messages: Vec<String> = result.errors.iter().map(|e| e.to_string()).collect();
self.add_result(ValidationResult {
test_id: format!("tool_schema_{}", tool_name),
test_name: format!("Tool Schema: {}", tool_name),
category: ValidationCategory::Schema,
status: ValidationStatus::Error,
message: format!(
"Tool '{}' has schema validation issues: {}",
tool_name,
error_messages.join(", ")
),
details: Some(json!({"schema": schema, "errors": error_messages})),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
Ok(())
}
async fn test_resources(
&mut self,
_transport: &mut Box<dyn Transport>,
resources: &[Resource],
) -> Result<()> {
info!("Testing {} resources", resources.len());
for resource in resources {
self.add_result(ValidationResult {
test_id: format!("resource_{}", resource.uri),
test_name: format!("Resource: {}", resource.name),
category: ValidationCategory::Resources,
status: ValidationStatus::Pass,
message: format!("Resource '{}' is properly defined", resource.name),
details: Some(serde_json::to_value(resource)?),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
}
Ok(())
}
async fn test_prompts(
&mut self,
_transport: &mut Box<dyn Transport>,
prompts: &[Prompt],
) -> Result<()> {
info!("Testing {} prompts", prompts.len());
for prompt in prompts {
self.add_result(ValidationResult {
test_id: format!("prompt_{}", prompt.name),
test_name: format!("Prompt: {}", prompt.name),
category: ValidationCategory::Prompts,
status: ValidationStatus::Pass,
message: format!("Prompt '{}' is properly defined", prompt.name),
details: Some(serde_json::to_value(prompt)?),
duration: Duration::from_millis(1),
timestamp: Utc::now(),
});
}
Ok(())
}
async fn test_error_handling(&mut self, transport: &mut Box<dyn Transport>) -> Result<()> {
info!("Testing error handling");
self.test_invalid_method(transport).await?;
self.test_invalid_parameters(transport).await?;
Ok(())
}
async fn test_invalid_method(&mut self, transport: &mut Box<dyn Transport>) -> Result<()> {
let test_start = Instant::now();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::String("invalid_method_1".to_string()),
method: "invalid/nonexistent/method".to_string(),
params: None,
};
match timeout(
self.config.test_timeout,
transport.send_request(request, Some(self.config.test_timeout)),
)
.await
{
Ok(Ok(response)) => {
if let Some(error) = response.error {
if error.code == -32601 {
self.add_result(ValidationResult {
test_id: "invalid_method_handling".to_string(),
test_name: "Invalid Method Handling".to_string(),
category: ValidationCategory::ErrorHandling,
status: ValidationStatus::Pass,
message: "Server correctly rejects invalid methods".to_string(),
details: Some(serde_json::to_value(error)?),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
} else {
self.add_result(ValidationResult {
test_id: "invalid_method_handling".to_string(),
test_name: "Invalid Method Handling".to_string(),
category: ValidationCategory::ErrorHandling,
status: ValidationStatus::Warning,
message: format!(
"Server returned unexpected error code: {}",
error.code
),
details: Some(serde_json::to_value(error)?),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
} else {
self.add_result(ValidationResult {
test_id: "invalid_method_handling".to_string(),
test_name: "Invalid Method Handling".to_string(),
category: ValidationCategory::ErrorHandling,
status: ValidationStatus::Error,
message: "Server should return error for invalid methods".to_string(),
details: Some(serde_json::to_value(response)?),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
}
Ok(Err(e)) => {
self.add_result(ValidationResult {
test_id: "invalid_method_handling".to_string(),
test_name: "Invalid Method Handling".to_string(),
category: ValidationCategory::ErrorHandling,
status: ValidationStatus::Warning,
message: format!("Transport error during invalid method test: {}", e),
details: Some(json!({"error": e.to_string()})),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
Err(_) => {
self.add_result(ValidationResult {
test_id: "invalid_method_handling".to_string(),
test_name: "Invalid Method Handling".to_string(),
category: ValidationCategory::ErrorHandling,
status: ValidationStatus::Error,
message: "Invalid method test timed out".to_string(),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
}
Ok(())
}
async fn test_invalid_parameters(&mut self, transport: &mut Box<dyn Transport>) -> Result<()> {
let test_start = Instant::now();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::String("invalid_params_1".to_string()),
method: "initialize".to_string(),
params: Some(json!({"invalid": "parameters"})), };
match timeout(
self.config.test_timeout,
transport.send_request(request, Some(self.config.test_timeout)),
)
.await
{
Ok(Ok(response)) => {
if let Some(error) = response.error {
if error.code == -32602 {
self.add_result(ValidationResult {
test_id: "invalid_params_handling".to_string(),
test_name: "Invalid Parameters Handling".to_string(),
category: ValidationCategory::ErrorHandling,
status: ValidationStatus::Pass,
message: "Server correctly rejects invalid parameters".to_string(),
details: Some(serde_json::to_value(error)?),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
} else {
self.add_result(ValidationResult {
test_id: "invalid_params_handling".to_string(),
test_name: "Invalid Parameters Handling".to_string(),
category: ValidationCategory::ErrorHandling,
status: ValidationStatus::Warning,
message: format!(
"Server returned unexpected error code: {}",
error.code
),
details: Some(serde_json::to_value(error)?),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
} else {
self.add_result(ValidationResult {
test_id: "invalid_params_handling".to_string(),
test_name: "Invalid Parameters Handling".to_string(),
category: ValidationCategory::ErrorHandling,
status: ValidationStatus::Error,
message: "Server should return error for invalid parameters".to_string(),
details: Some(serde_json::to_value(response)?),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
}
Ok(Err(e)) => {
self.add_result(ValidationResult {
test_id: "invalid_params_handling".to_string(),
test_name: "Invalid Parameters Handling".to_string(),
category: ValidationCategory::ErrorHandling,
status: ValidationStatus::Warning,
message: format!("Transport error during invalid parameters test: {}", e),
details: Some(json!({"error": e.to_string()})),
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
Err(_) => {
self.add_result(ValidationResult {
test_id: "invalid_params_handling".to_string(),
test_name: "Invalid Parameters Handling".to_string(),
category: ValidationCategory::ErrorHandling,
status: ValidationStatus::Error,
message: "Invalid parameters test timed out".to_string(),
details: None,
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
}
}
Ok(())
}
async fn test_schema_validation(&mut self) -> Result<()> {
let test_start = Instant::now();
let mut schema_errors = Vec::new();
let mut validated_count = 0;
for result in &self.results {
if let Some(details) = &result.details {
if details.get("jsonrpc").is_some() {
validated_count += 1;
} else if details.get("method").is_some()
|| details.get("result").is_some()
|| details.get("error").is_some()
{
schema_errors.push(format!("Missing 'jsonrpc' field in {}", result.test_id));
}
}
}
let status = if schema_errors.is_empty() {
ValidationStatus::Pass
} else {
ValidationStatus::Warning
};
let message = if schema_errors.is_empty() {
format!(
"All {} messages conform to expected schemas",
validated_count
)
} else {
format!(
"Found {} schema issues out of {} validated messages",
schema_errors.len(),
validated_count
)
};
self.add_result(ValidationResult {
test_id: "schema_validation".to_string(),
test_name: "Schema Validation".to_string(),
category: ValidationCategory::Schema,
status,
message,
details: if schema_errors.is_empty() {
None
} else {
Some(json!({"errors": schema_errors}))
},
duration: test_start.elapsed(),
timestamp: Utc::now(),
});
Ok(())
}
fn add_result(&mut self, result: ValidationResult) {
debug!(
"Validation result: {} - {}: {}",
result.test_id,
result.status.name(),
result.message
);
self.results.push(result);
}
fn generate_report(&self) -> Result<ValidationReport> {
let total_duration = self
.start_time
.map(|start| start.elapsed())
.unwrap_or_default();
let summary = self.calculate_summary();
let report = ValidationReport {
metadata: ReportMetadata {
generated_at: Utc::now(),
validator_version: "1.0.0".to_string(),
transport_type: self.transport_config.transport_type().to_string(),
total_duration,
config: self.config.clone(),
},
summary,
results: self.results.clone(),
server_info: None, performance: PerformanceMetrics {
initialization_time: self.calculate_initialization_time(),
average_request_time: self.calculate_average_request_time(),
total_requests: self.results.len(),
failed_requests: self
.results
.iter()
.filter(|r| {
matches!(
r.status,
ValidationStatus::Error | ValidationStatus::Critical
)
})
.count(),
timeouts: self
.results
.iter()
.filter(|r| r.message.contains("timeout") || r.message.contains("timed out"))
.count(),
},
};
Ok(report)
}
fn calculate_initialization_time(&self) -> Duration {
self.results
.iter()
.find(|r| r.test_id == "initialization" || r.test_name.contains("Initialization"))
.map(|r| r.duration)
.unwrap_or_else(|| Duration::from_millis(0))
}
fn calculate_average_request_time(&self) -> Duration {
let request_durations: Vec<Duration> = self
.results
.iter()
.filter(|r| !r.test_id.contains("initialization") && !r.test_id.contains("summary"))
.map(|r| r.duration)
.collect();
if request_durations.is_empty() {
Duration::from_millis(0)
} else {
let total_nanos: u128 = request_durations.iter().map(|d| d.as_nanos()).sum();
let average_nanos = total_nanos / request_durations.len() as u128;
Duration::from_nanos(average_nanos as u64)
}
}
fn calculate_summary(&self) -> ValidationSummary {
let total_tests = self.results.len();
let passed = self
.results
.iter()
.filter(|r| r.status == ValidationStatus::Pass)
.count();
let info = self
.results
.iter()
.filter(|r| r.status == ValidationStatus::Info)
.count();
let warnings = self
.results
.iter()
.filter(|r| r.status == ValidationStatus::Warning)
.count();
let errors = self
.results
.iter()
.filter(|r| r.status == ValidationStatus::Error)
.count();
let critical = self
.results
.iter()
.filter(|r| r.status == ValidationStatus::Critical)
.count();
let skipped = self
.results
.iter()
.filter(|r| r.status == ValidationStatus::Skipped)
.count();
let compliance_percentage = if total_tests > 0 {
(passed as f64 / total_tests as f64) * 100.0
} else {
0.0
};
ValidationSummary {
total_tests,
passed,
info,
warnings,
errors,
critical,
skipped,
compliance_percentage,
}
}
}
impl ValidationStatus {
pub fn name(&self) -> &'static str {
match self {
Self::Pass => "PASS",
Self::Info => "INFO",
Self::Warning => "WARN",
Self::Error => "ERROR",
Self::Critical => "CRITICAL",
Self::Skipped => "SKIP",
}
}
pub fn icon(&self) -> &'static str {
match self {
Self::Pass => "✅",
Self::Info => "ℹ️",
Self::Warning => "⚠️",
Self::Error => "❌",
Self::Critical => "🚨",
Self::Skipped => "⏭️",
}
}
}