use std::collections::HashMap;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::agent_lifecycle_manager::WasmModuleValidator;
#[cfg(test)]
use crate::domain::ValidationRuleType;
use crate::domain::{
AgentVersion, CustomValidationRule, ValidationFailure, ValidationResult, ValidationWarning,
VersionNumber, WasmModule, WasmSecurityPolicy, WasmValidationError, millis_to_f64_for_stats,
u64_to_f64_for_stats,
};
use crate::domain_types::AgentName;
/// On/off switch for one optional validation phase.
///
/// The validator runs a phase only when its mode compares equal to `Enabled`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ValidationMode {
    /// The phase runs.
    Enabled,
    /// The phase is skipped.
    Disabled,
}
/// Strictness level stored in a validation configuration.
///
/// NOTE(review): no validation step in this module reads this value; confirm whether it is
/// consumed elsewhere before relying on it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StrictnessLevel {
    /// Level used by the strict configuration preset.
    Strict,
    /// Level used by the permissive and testing configuration presets.
    Relaxed,
}
/// Settings that decide which validation phases run and which security policy applies.
#[derive(Clone, Debug)]
pub struct ValidationConfig {
    /// Policy passed to security analysis and attached to every module the validator builds.
    pub security_policy: WasmSecurityPolicy,
    /// Whether structural analysis runs (otherwise zeroed placeholder figures are used).
    pub structural_validation: ValidationMode,
    /// Whether security analysis runs and whether `validate_security` performs any check.
    pub security_validation: ValidationMode,
    /// Whether performance analysis runs and its bottlenecks are reported as warnings.
    pub performance_analysis: ValidationMode,
    /// Overall strictness level.
    pub strictness: StrictnessLevel,
    /// Intended validation time budget, in milliseconds.
    // NOTE(review): the validator in this module never enforces this budget — confirm intent.
    pub max_validation_time_ms: u64,
}
impl ValidationConfig {
    /// Preset with every phase enabled under the strict security policy.
    ///
    /// Strictness `Strict`, time budget 30 000 ms.
    pub fn strict() -> Self {
        let security_policy = WasmSecurityPolicy::strict();
        Self {
            security_policy,
            structural_validation: ValidationMode::Enabled,
            security_validation: ValidationMode::Enabled,
            performance_analysis: ValidationMode::Enabled,
            strictness: StrictnessLevel::Strict,
            max_validation_time_ms: 30_000,
        }
    }

    /// Preset that only runs structural validation, under the permissive security policy.
    ///
    /// Security and performance analysis are disabled; strictness `Relaxed`, time budget
    /// 10 000 ms.
    pub fn permissive() -> Self {
        let security_policy = WasmSecurityPolicy::permissive();
        Self {
            security_policy,
            structural_validation: ValidationMode::Enabled,
            security_validation: ValidationMode::Disabled,
            performance_analysis: ValidationMode::Disabled,
            strictness: StrictnessLevel::Relaxed,
            max_validation_time_ms: 10_000,
        }
    }

    /// Preset for tests: every phase enabled under the testing security policy.
    ///
    /// Strictness `Relaxed`, time budget 5 000 ms.
    pub fn testing() -> Self {
        let security_policy = WasmSecurityPolicy::testing();
        Self {
            security_policy,
            structural_validation: ValidationMode::Enabled,
            security_validation: ValidationMode::Enabled,
            performance_analysis: ValidationMode::Enabled,
            strictness: StrictnessLevel::Relaxed,
            max_validation_time_ms: 5_000,
        }
    }
}
impl Default for ValidationConfig {
fn default() -> Self {
Self::strict()
}
}
/// Size and shape figures describing a module's structure.
#[derive(Clone, Debug)]
struct StructuralAnalysis {
    /// Number of functions; counts above 1000 fail validation, above 500 warn.
    pub function_count: usize,
    /// Number of imports; counts above 100 fail validation.
    pub import_count: usize,
    /// Number of exports; counts above 100 fail validation.
    pub export_count: usize,
    /// Linear-memory pages; used for the memory estimate.
    pub memory_pages: u32,
    /// Table element count.
    pub table_elements: u32,
    /// Heuristic complexity score; values above 0.8 are flagged as a performance bottleneck.
    pub complexity_score: f64,
}
/// Findings of the security analysis phase.
#[derive(Clone, Debug)]
struct SecurityAnalysis {
    /// Imported function names not permitted by the policy; each becomes a validation failure.
    pub unauthorized_imports: Vec<String>,
    /// Required exports that are absent; each becomes a validation failure.
    pub missing_exports: Vec<String>,
    /// Free-form policy violation descriptions; each becomes a validation failure.
    pub policy_violations: Vec<String>,
    /// Overall security score, recorded in module metadata as `security_score`.
    pub security_score: f64,
}
/// Resource estimates and advice from the performance analysis phase.
#[derive(Clone, Debug)]
struct PerformanceAnalysis {
    /// Estimated linear-memory size in bytes.
    pub estimated_memory_usage: usize,
    /// Heuristic execution cost (logged as "fuel cost").
    pub estimated_execution_cost: u64,
    /// Human-readable bottleneck descriptions; reported as warnings when enabled.
    pub potential_bottlenecks: Vec<String>,
    /// Human-readable optimization suggestions.
    pub optimization_suggestions: Vec<String>,
}
/// Running totals over every validation attempt the validator has recorded.
#[derive(Clone, Debug)]
pub struct ValidationStatistics {
    /// Total number of recorded validation attempts.
    pub modules_validated: u64,
    /// Attempts recorded as passing.
    pub modules_passed: u64,
    /// Attempts recorded as failing.
    pub modules_failed: u64,
    /// Mean duration of all recorded attempts, in milliseconds.
    pub average_validation_time_ms: f64,
    /// Failure count per failure-reason string (only failing attempts with a reason are tallied).
    pub common_failures: HashMap<String, u32>,
}

impl Default for ValidationStatistics {
    fn default() -> Self {
        ValidationStatistics::new()
    }
}

impl ValidationStatistics {
    /// Creates an empty record: every counter at zero and no failure reasons.
    pub fn new() -> Self {
        Self {
            modules_validated: 0,
            modules_passed: 0,
            modules_failed: 0,
            average_validation_time_ms: 0.0,
            common_failures: HashMap::new(),
        }
    }

    /// Records the outcome of a single validation attempt.
    ///
    /// `failure_reason` is tallied only when `passed` is `false`; it is ignored for passing
    /// attempts. The running mean is updated incrementally from the previous mean.
    pub fn record_validation(
        &mut self,
        passed: bool,
        validation_time_ms: f64,
        failure_reason: Option<&str>,
    ) {
        self.modules_validated += 1;
        match (passed, failure_reason) {
            (true, _) => self.modules_passed += 1,
            (false, None) => self.modules_failed += 1,
            (false, Some(reason)) => {
                self.modules_failed += 1;
                *self.common_failures.entry(reason.to_string()).or_default() += 1;
            }
        }
        // Rebuild the previous total from the old mean, add the new sample, divide again.
        let count = u64_to_f64_for_stats(self.modules_validated);
        let previous_total = self.average_validation_time_ms * (count - 1.0);
        self.average_validation_time_ms = (previous_total + validation_time_ms) / count;
    }

    /// Percentage (0–100) of recorded attempts that passed; `0.0` when nothing is recorded.
    pub fn success_rate(&self) -> f64 {
        match self.modules_validated {
            0 => 0.0,
            total => {
                let fraction =
                    u64_to_f64_for_stats(self.modules_passed) / u64_to_f64_for_stats(total);
                fraction * 100.0
            }
        }
    }
}
/// Validator for WebAssembly agent modules.
///
/// The configuration, running statistics and custom rules each sit behind their own
/// `Arc<RwLock<_>>`, so all three can be read and updated through `&self`.
pub struct CaxtonWasmModuleValidator {
    /// Active configuration; replaced wholesale by `update_config`.
    config: Arc<RwLock<ValidationConfig>>,
    /// Outcomes recorded by `validate_module`.
    statistics: Arc<RwLock<ValidationStatistics>>,
    /// Rules registered through `add_custom_rule`.
    custom_rules: Arc<RwLock<Vec<CustomValidationRule>>>,
}
impl CaxtonWasmModuleValidator {
/// Builds a validator around `config`, with zeroed statistics and no custom rules.
pub fn new(config: ValidationConfig) -> Self {
    let config = Arc::new(RwLock::new(config));
    let statistics = Arc::new(RwLock::new(ValidationStatistics::new()));
    let custom_rules = Arc::new(RwLock::new(Vec::new()));
    Self {
        config,
        statistics,
        custom_rules,
    }
}
/// Validator configured with `ValidationConfig::strict()`.
pub fn strict() -> Self {
    let preset = ValidationConfig::strict();
    Self::new(preset)
}

/// Validator configured with `ValidationConfig::permissive()`.
pub fn permissive() -> Self {
    let preset = ValidationConfig::permissive();
    Self::new(preset)
}

/// Validator configured with `ValidationConfig::testing()`.
pub fn testing() -> Self {
    let preset = ValidationConfig::testing();
    Self::new(preset)
}
/// Replaces the active configuration; validations that read the config afterwards use `config`.
pub async fn update_config(&self, config: ValidationConfig) {
    *self.config.write().await = config;
}
/// Appends `rule` to the validator's custom rule list.
///
/// NOTE(review): `apply_custom_validation_rules` in this module does not read the stored
/// rules yet, so registered rules currently have no effect on validation results.
pub async fn add_custom_rule(&self, rule: CustomValidationRule) {
    self.custom_rules.write().await.push(rule);
}
/// Returns a snapshot (clone) of the statistics recorded so far.
pub async fn get_statistics(&self) -> ValidationStatistics {
    let guard = self.statistics.read().await;
    guard.clone()
}
/// Checks the 8-byte WASM preamble: the `\0asm` magic number followed by version 1.
///
/// # Errors
/// - `ValidationFailure::InvalidWasmFormat` when the magic number is missing or wrong, or
///   when the buffer ends before the 4 version bytes.
/// - `ValidationFailure::UnsupportedWasmVersion` when the magic is correct but the version
///   bytes are not `01 00 00 00`.
fn validate_wasm_format(wasm_bytes: &[u8]) -> Result<(), ValidationFailure> {
    const MAGIC: [u8; 4] = [0x00, 0x61, 0x73, 0x6D];
    const VERSION_ONE: [u8; 4] = [0x01, 0x00, 0x00, 0x00];

    // Covers empty input, fewer than 4 bytes, and a wrong magic number alike.
    match wasm_bytes.get(0..4) {
        Some(magic) if magic == MAGIC => {}
        _ => return Err(ValidationFailure::InvalidWasmFormat),
    }

    match wasm_bytes.get(4..8) {
        None => Err(ValidationFailure::InvalidWasmFormat),
        Some(version) if version == VERSION_ONE => Ok(()),
        Some(_) => Err(ValidationFailure::UnsupportedWasmVersion),
    }
}
/// Placeholder structural analysis.
///
/// The bytes are not inspected: every call returns the same fixed figures (10 functions,
/// 2 imports, 3 exports, 16 memory pages, no table elements, complexity 0.3).
// TODO(review): replace with a real parser of the module's sections.
fn analyze_structure(_wasm_bytes: &[u8]) -> StructuralAnalysis {
    let (function_count, import_count, export_count) = (10, 2, 3);
    StructuralAnalysis {
        function_count,
        import_count,
        export_count,
        memory_pages: 16,
        table_elements: 0,
        complexity_score: 0.3,
    }
}
/// Placeholder security analysis that reports no findings.
///
/// Only the policy name affects the result: `"strict"` scores 98.0, `"permissive"` 75.0,
/// and any other name 95.0. The bytes and the structural analysis are not inspected yet.
fn analyze_security(
    _wasm_bytes: &[u8],
    policy: &WasmSecurityPolicy,
    _analysis: &StructuralAnalysis,
) -> SecurityAnalysis {
    let security_score = if policy.name == "strict" {
        98.0
    } else if policy.name == "permissive" {
        75.0
    } else {
        95.0
    };
    SecurityAnalysis {
        unauthorized_imports: Vec::new(),
        missing_exports: Vec::new(),
        policy_violations: Vec::new(),
        security_score,
    }
}
/// Estimates resource usage from a structural analysis and lists likely bottlenecks.
///
/// - Memory estimate: `memory_pages` × 64 KiB (one WASM page).
/// - Execution cost estimate: 1000 units per function.
///
/// Both products saturate at their type's maximum instead of overflowing. A huge
/// `function_count` previously overflowed `u64`, and on 32-bit targets the memory product
/// can exceed `usize::MAX`; either case panicked in debug builds.
///
/// Each triggered check contributes one bottleneck and one matching suggestion, in this order:
/// more than 100 functions, more than 64 memory pages, complexity score above 0.8.
fn analyze_performance(
    _wasm_bytes: &[u8],
    analysis: &StructuralAnalysis,
) -> PerformanceAnalysis {
    /// Bytes in one WebAssembly linear-memory page.
    const WASM_PAGE_SIZE: usize = 65_536;
    /// Heuristic execution cost attributed to each function.
    const COST_PER_FUNCTION: u64 = 1_000;

    // The `as` casts are lossless on 32- and 64-bit targets (u32 -> usize, usize -> u64);
    // only the multiplications can overflow, hence the saturating arithmetic.
    let estimated_memory_usage = (analysis.memory_pages as usize).saturating_mul(WASM_PAGE_SIZE);
    let estimated_execution_cost =
        (analysis.function_count as u64).saturating_mul(COST_PER_FUNCTION);

    // (triggered, bottleneck, suggestion), checked in a fixed order so reports are stable.
    let checks = [
        (
            analysis.function_count > 100,
            "High function count may impact load time",
            "Consider splitting module into smaller components",
        ),
        (
            analysis.memory_pages > 64,
            "High memory usage may impact performance",
            "Optimize memory layout and reduce allocations",
        ),
        (
            analysis.complexity_score > 0.8,
            "High complexity may impact execution performance",
            "Refactor complex functions for better performance",
        ),
    ];
    let (potential_bottlenecks, optimization_suggestions): (Vec<String>, Vec<String>) = checks
        .iter()
        .filter(|(triggered, _, _)| *triggered)
        .map(|(_, bottleneck, suggestion)| ((*bottleneck).to_string(), (*suggestion).to_string()))
        .unzip();

    PerformanceAnalysis {
        estimated_memory_usage,
        estimated_execution_cost,
        potential_bottlenecks,
        optimization_suggestions,
    }
}
/// Combines the analysis outputs into a single validation verdict.
///
/// Failures, in order: more than 1000 functions, more than 100 imports, more than 100
/// exports, then one entry per unauthorized import, missing export, and policy violation.
/// Warnings: the supplied `custom_warnings`, then (only when performance analysis is
/// enabled) one per bottleneck, then a large-function-count warning above 500 functions.
///
/// Any failure yields `Invalid` (warnings are dropped). Otherwise any warning yields
/// `Warning`, and the result is `Valid` when there is neither.
fn create_validation_result(
    structural: &StructuralAnalysis,
    security: &SecurityAnalysis,
    performance: &PerformanceAnalysis,
    custom_warnings: Vec<ValidationWarning>,
    config: &ValidationConfig,
) -> ValidationResult {
    let mut failures = Vec::new();

    if structural.function_count > 1000 {
        failures.push(ValidationFailure::TooManyFunctions {
            count: structural.function_count,
            limit: 1000,
        });
    }
    if structural.import_count > 100 {
        failures.push(ValidationFailure::TooManyImports {
            count: structural.import_count,
            limit: 100,
        });
    }
    if structural.export_count > 100 {
        failures.push(ValidationFailure::TooManyExports {
            count: structural.export_count,
            limit: 100,
        });
    }
    failures.extend(
        security
            .unauthorized_imports
            .iter()
            .cloned()
            .map(|function_name| ValidationFailure::UnauthorizedImport { function_name }),
    );
    failures.extend(
        security
            .missing_exports
            .iter()
            .cloned()
            .map(|function_name| ValidationFailure::MissingRequiredExport { function_name }),
    );
    failures.extend(security.policy_violations.iter().map(|violation| {
        ValidationFailure::SecurityViolation {
            policy: config.security_policy.name.clone(),
            violation: violation.clone(),
        }
    }));

    let mut warnings = custom_warnings;
    if config.performance_analysis == ValidationMode::Enabled {
        warnings.extend(
            performance
                .potential_bottlenecks
                .iter()
                .map(|bottleneck| ValidationWarning::PerformanceWarning {
                    warning: bottleneck.clone(),
                }),
        );
    }
    if structural.function_count > 500 {
        warnings.push(ValidationWarning::LargeFunctionCount {
            count: structural.function_count,
        });
    }

    match (failures.is_empty(), warnings.is_empty()) {
        (false, _) => ValidationResult::Invalid { reasons: failures },
        (true, false) => ValidationResult::Warning { warnings },
        (true, true) => ValidationResult::Valid,
    }
}
/// Builds the metadata entries that need no analysis.
///
/// Always present: `size_bytes`, `size_kb` (rounded up to whole KiB) and
/// `validation_timestamp` (Unix seconds, `0` if the clock is before the epoch). The
/// `wasm_version` entry (always `"1"`) is added only when at least 8 bytes are present.
fn extract_basic_metadata(wasm_bytes: &[u8]) -> HashMap<String, String> {
    let size = wasm_bytes.len();
    let timestamp_secs = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_secs())
        .unwrap_or_default();

    let mut metadata: HashMap<String, String> = [
        ("size_bytes", size.to_string()),
        ("size_kb", size.div_ceil(1024).to_string()),
        ("validation_timestamp", timestamp_secs.to_string()),
    ]
    .map(|(key, value)| (key.to_string(), value))
    .into_iter()
    .collect();

    if size >= 8 {
        metadata.insert("wasm_version".to_string(), "1".to_string());
    }
    metadata
}
/// Runs the full validation pipeline, using the moment of the call as the timing start.
async fn validate_module_comprehensive(
    &self,
    wasm_bytes: &[u8],
    agent_name: Option<AgentName>,
) -> Result<WasmModule, WasmValidationError> {
    self.perform_validation_steps(wasm_bytes, agent_name, SystemTime::now())
        .await
}
async fn perform_validation_steps(
&self,
wasm_bytes: &[u8],
agent_name: Option<AgentName>,
validation_start: SystemTime,
) -> Result<WasmModule, WasmValidationError> {
let config = self.config.read().await.clone();
Self::validate_basic_format(wasm_bytes)?;
let (structural, security, performance) = Self::perform_analysis_phase(wasm_bytes, &config);
let custom_warnings =
Self::apply_custom_validation_rules(wasm_bytes, agent_name.as_ref(), &config);
let validation_result = Self::create_validation_result(
&structural,
&security,
&performance,
custom_warnings,
&config,
);
self.finalize_validation(
wasm_bytes,
agent_name,
validation_result,
validation_start,
(structural, security, performance),
)
.await
}
fn validate_basic_format(wasm_bytes: &[u8]) -> Result<(), WasmValidationError> {
info!(
"Starting comprehensive WASM module validation (size: {} bytes)",
wasm_bytes.len()
);
Self::validate_wasm_format(wasm_bytes).map_err(|failure| {
WasmValidationError::InvalidFormat {
reason: failure.to_string(),
}
})
}
fn perform_analysis_phase(
wasm_bytes: &[u8],
config: &ValidationConfig,
) -> (StructuralAnalysis, SecurityAnalysis, PerformanceAnalysis) {
let structural = if config.structural_validation == ValidationMode::Enabled {
Self::analyze_structure(wasm_bytes)
} else {
StructuralAnalysis {
function_count: 0,
import_count: 0,
export_count: 0,
memory_pages: 0,
table_elements: 0,
complexity_score: 0.0,
}
};
debug!(
"Structural analysis completed: {} functions, {} imports, {} exports",
structural.function_count, structural.import_count, structural.export_count
);
let security = if config.security_validation == ValidationMode::Enabled {
Self::analyze_security(wasm_bytes, &config.security_policy, &structural)
} else {
SecurityAnalysis {
unauthorized_imports: Vec::new(),
missing_exports: Vec::new(),
policy_violations: Vec::new(),
security_score: 100.0,
}
};
debug!(
"Security analysis completed: score {}",
security.security_score
);
let performance = if config.performance_analysis == ValidationMode::Enabled {
Self::analyze_performance(wasm_bytes, &structural)
} else {
PerformanceAnalysis {
estimated_memory_usage: 0,
estimated_execution_cost: 0,
potential_bottlenecks: Vec::new(),
optimization_suggestions: Vec::new(),
}
};
debug!(
"Performance analysis completed: {} bytes memory, {} fuel cost",
performance.estimated_memory_usage, performance.estimated_execution_cost
);
(structural, security, performance)
}
fn apply_custom_validation_rules(
_wasm_bytes: &[u8],
_agent_name: Option<&AgentName>,
_config: &ValidationConfig,
) -> Vec<ValidationWarning> {
Vec::new()
}
async fn finalize_validation(
&self,
wasm_bytes: &[u8],
agent_name: Option<AgentName>,
validation_result: ValidationResult,
validation_start: SystemTime,
analysis_results: (StructuralAnalysis, SecurityAnalysis, PerformanceAnalysis),
) -> Result<WasmModule, WasmValidationError> {
let validation_millis = validation_start.elapsed().unwrap_or_default().as_millis();
let validation_duration = millis_to_f64_for_stats(validation_millis);
let config = self.config.read().await;
let mut final_module = WasmModule::from_bytes(
AgentVersion::generate(),
VersionNumber::first(),
None, agent_name,
wasm_bytes,
&config.security_policy,
)?;
final_module.validation_result = validation_result.clone();
let (structural, security, performance) = analysis_results;
final_module.metadata.insert(
"estimated_memory_usage".to_string(),
performance.estimated_memory_usage.to_string(),
);
final_module.metadata.insert(
"estimated_execution_cost".to_string(),
performance.estimated_execution_cost.to_string(),
);
final_module.metadata.insert(
"security_score".to_string(),
security.security_score.to_string(),
);
final_module.metadata.insert(
"complexity_score".to_string(),
structural.complexity_score.to_string(),
);
info!(
"WASM module validation completed in {:.2}ms with result: {:?}",
validation_duration, final_module.validation_result
);
Ok(final_module)
}
}
#[async_trait::async_trait]
impl WasmModuleValidator for CaxtonWasmModuleValidator {
async fn validate_module(
&self,
wasm_bytes: &[u8],
agent_name: Option<AgentName>,
) -> Result<WasmModule, WasmValidationError> {
let validation_start = SystemTime::now();
let early_error = if wasm_bytes.is_empty() {
Some(WasmValidationError::EmptyModule)
} else if wasm_bytes.len() > 100 * 1024 * 1024 {
Some(WasmValidationError::ModuleTooLarge {
size: wasm_bytes.len(),
limit: 100 * 1024 * 1024,
})
} else {
None
};
if let Some(error) = early_error {
let validation_millis = validation_start.elapsed().unwrap_or_default().as_millis();
let validation_time = millis_to_f64_for_stats(validation_millis);
let validation_time = validation_time.max(0.1); let mut stats = self.statistics.write().await;
stats.record_validation(false, validation_time, Some(&error.to_string()));
return Err(error);
}
let result = self
.validate_module_comprehensive(wasm_bytes, agent_name)
.await;
let validation_millis = validation_start.elapsed().unwrap_or_default().as_millis();
let validation_time = millis_to_f64_for_stats(validation_millis);
let validation_time = validation_time.max(0.1); let passed = result.is_ok();
let failure_reason = if let Err(ref error) = result {
Some(error.to_string())
} else {
None
};
{
let mut stats = self.statistics.write().await;
stats.record_validation(passed, validation_time, failure_reason.as_deref());
}
result
}
/// Re-checks a built module against the security policy it carries.
///
/// When security validation is disabled in the current configuration this returns
/// `ValidationResult::Valid` without checking. Otherwise it returns the policy's verdict,
/// logging each violation at warn level and each warning at info level.
async fn validate_security(
    &self,
    module: &WasmModule,
) -> Result<ValidationResult, WasmValidationError> {
    info!(
        "Performing security validation for module {}",
        module
            .name
            .as_ref()
            .map_or_else(|| "unnamed".to_string(), std::string::ToString::to_string)
    );

    // Copy the flag out so the config lock is released before the policy check runs.
    let security_enabled =
        self.config.read().await.security_validation == ValidationMode::Enabled;
    if !security_enabled {
        debug!("Security validation disabled, returning valid");
        return Ok(ValidationResult::Valid);
    }

    let verdict = module.security_policy.validate_module(module);
    match &verdict {
        ValidationResult::Valid => debug!("Security validation passed"),
        ValidationResult::Invalid { reasons } => {
            warn!("Security validation failed: {} violations", reasons.len());
            reasons
                .iter()
                .for_each(|reason| warn!("Security violation: {}", reason));
        }
        ValidationResult::Warning { warnings } => {
            info!(
                "Security validation passed with {} warnings",
                warnings.len()
            );
            warnings
                .iter()
                .for_each(|warning| info!("Security warning: {}", warning));
        }
    }
    Ok(verdict)
}
/// Extracts descriptive metadata from raw module bytes without building a `WasmModule`.
///
/// The basic entries from `extract_basic_metadata` are always included. The structural
/// counts are added when structural validation is enabled, and the performance estimates
/// when performance analysis is enabled. Both groups come from a single structural
/// analysis pass; earlier code ran the analysis twice when both phases were enabled.
///
/// # Errors
/// - `WasmValidationError::EmptyModule` for an empty byte slice.
/// - `WasmValidationError::InvalidFormat` when the magic number or version is wrong.
async fn extract_metadata(
    &self,
    wasm_bytes: &[u8],
) -> Result<HashMap<String, String>, WasmValidationError> {
    debug!(
        "Extracting metadata from WASM module ({} bytes)",
        wasm_bytes.len()
    );
    if wasm_bytes.is_empty() {
        return Err(WasmValidationError::EmptyModule);
    }
    Self::validate_wasm_format(wasm_bytes).map_err(|e| WasmValidationError::InvalidFormat {
        reason: e.to_string(),
    })?;
    let mut metadata = Self::extract_basic_metadata(wasm_bytes);

    // Copy the two flags out so the config lock is not held while analysis runs.
    let (structural_enabled, performance_enabled) = {
        let config = self.config.read().await;
        (
            config.structural_validation == ValidationMode::Enabled,
            config.performance_analysis == ValidationMode::Enabled,
        )
    };

    if structural_enabled || performance_enabled {
        let structural = Self::analyze_structure(wasm_bytes);
        if structural_enabled {
            metadata.extend(
                [
                    ("function_count", structural.function_count.to_string()),
                    ("import_count", structural.import_count.to_string()),
                    ("export_count", structural.export_count.to_string()),
                    ("memory_pages", structural.memory_pages.to_string()),
                    ("table_elements", structural.table_elements.to_string()),
                    ("complexity_score", structural.complexity_score.to_string()),
                ]
                .map(|(key, value)| (key.to_string(), value)),
            );
        }
        if performance_enabled {
            let performance = Self::analyze_performance(wasm_bytes, &structural);
            metadata.extend(
                [
                    (
                        "estimated_memory_usage",
                        performance.estimated_memory_usage.to_string(),
                    ),
                    (
                        "estimated_execution_cost",
                        performance.estimated_execution_cost.to_string(),
                    ),
                    (
                        "bottleneck_count",
                        performance.potential_bottlenecks.len().to_string(),
                    ),
                    (
                        "optimization_suggestions_count",
                        performance.optimization_suggestions.len().to_string(),
                    ),
                ]
                .map(|(key, value)| (key.to_string(), value)),
            );
        }
    }

    debug!("Extracted {} metadata fields", metadata.len());
    Ok(metadata)
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the WASM module validator and its configuration and statistics types.
    use super::*;

    /// Smallest well-formed preamble: the `\0asm` magic number followed by version 1.
    fn create_valid_wasm_bytes() -> Vec<u8> {
        vec![0x00, 0x61, 0x73, 0x6D, 0x01, 0x00, 0x00, 0x00]
    }

    /// Four bytes that do not match the WASM magic number.
    fn create_invalid_wasm_bytes() -> Vec<u8> {
        vec![0xFF, 0xFF, 0xFF, 0xFF]
    }

    #[tokio::test]
    async fn test_valid_wasm_module_validation() {
        let validator = CaxtonWasmModuleValidator::testing();
        let wasm_bytes = create_valid_wasm_bytes();
        let result = validator.validate_module(&wasm_bytes, None).await;
        assert!(result.is_ok());
        let module = result.unwrap();
        assert!(module.is_valid() || module.validation_result.has_warnings());
    }

    #[tokio::test]
    async fn test_invalid_wasm_module_validation() {
        let validator = CaxtonWasmModuleValidator::strict();
        let wasm_bytes = create_invalid_wasm_bytes();
        let result = validator.validate_module(&wasm_bytes, None).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            WasmValidationError::InvalidFormat { .. }
        ));
    }

    #[tokio::test]
    async fn test_truncated_preamble_is_rejected() {
        let validator = CaxtonWasmModuleValidator::testing();
        // Correct magic number but no version bytes.
        let magic_only = vec![0x00, 0x61, 0x73, 0x6D];
        let result = validator.validate_module(&magic_only, None).await;
        assert!(matches!(
            result,
            Err(WasmValidationError::InvalidFormat { .. })
        ));
    }

    #[tokio::test]
    async fn test_unsupported_wasm_version_is_rejected() {
        let validator = CaxtonWasmModuleValidator::testing();
        let version_two = vec![0x00, 0x61, 0x73, 0x6D, 0x02, 0x00, 0x00, 0x00];
        let result = validator.validate_module(&version_two, None).await;
        assert!(matches!(
            result,
            Err(WasmValidationError::InvalidFormat { .. })
        ));
    }

    #[tokio::test]
    async fn test_empty_wasm_module_validation() {
        let validator = CaxtonWasmModuleValidator::testing();
        let empty_bytes: Vec<u8> = Vec::new();
        let result = validator.validate_module(&empty_bytes, None).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            WasmValidationError::EmptyModule
        ));
    }

    #[tokio::test]
    async fn test_security_validation() {
        let validator = CaxtonWasmModuleValidator::testing();
        let wasm_bytes = create_valid_wasm_bytes();
        let module = validator.validate_module(&wasm_bytes, None).await.unwrap();
        let security_result = validator.validate_security(&module).await.unwrap();
        assert!(
            security_result.is_valid(),
            "Security validation should pass for valid WASM module with testing policy"
        );
    }

    #[tokio::test]
    async fn test_metadata_extraction() {
        let validator = CaxtonWasmModuleValidator::testing();
        let wasm_bytes = create_valid_wasm_bytes();
        let result = validator.extract_metadata(&wasm_bytes).await;
        assert!(result.is_ok());
        let metadata = result.unwrap();
        assert!(metadata.contains_key("size_bytes"));
        assert!(metadata.contains_key("wasm_version"));
        // The testing preset enables both structural and performance analysis.
        assert!(metadata.contains_key("function_count"));
        assert!(metadata.contains_key("estimated_memory_usage"));
    }

    #[tokio::test]
    async fn test_metadata_extraction_rejects_empty_input() {
        let validator = CaxtonWasmModuleValidator::testing();
        let result = validator.extract_metadata(&[]).await;
        assert!(matches!(result, Err(WasmValidationError::EmptyModule)));
    }

    #[tokio::test]
    async fn test_custom_validation_rules() {
        let validator = CaxtonWasmModuleValidator::testing();
        let custom_rule = CustomValidationRule {
            name: "test_rule".to_string(),
            description: "Test custom rule".to_string(),
            rule_type: ValidationRuleType::FunctionNamePattern,
            parameters: {
                let mut params = HashMap::new();
                params.insert("pattern".to_string(), "test_*".to_string());
                params
            },
        };
        validator.add_custom_rule(custom_rule).await;
        let wasm_bytes = create_valid_wasm_bytes();
        let result = validator.validate_module(&wasm_bytes, None).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_validation_statistics() {
        let validator = CaxtonWasmModuleValidator::testing();
        let wasm_bytes = create_valid_wasm_bytes();
        for _ in 0..3 {
            let _ = validator.validate_module(&wasm_bytes, None).await;
        }
        let stats = validator.get_statistics().await;
        // A fresh validator has no prior history, so the count is exact.
        assert_eq!(stats.modules_validated, 3);
        assert!(stats.success_rate() > 0.0);
    }

    #[tokio::test]
    async fn test_failed_validation_is_recorded() {
        let validator = CaxtonWasmModuleValidator::strict();
        let wasm_bytes = create_invalid_wasm_bytes();
        let _ = validator.validate_module(&wasm_bytes, None).await;
        let stats = validator.get_statistics().await;
        assert_eq!(stats.modules_validated, 1);
        assert_eq!(stats.modules_failed, 1);
        assert_eq!(stats.modules_passed, 0);
        assert_eq!(stats.common_failures.values().sum::<u32>(), 1);
    }

    #[tokio::test]
    async fn test_configuration_update() {
        let validator = CaxtonWasmModuleValidator::strict();
        let new_config = ValidationConfig::permissive();
        validator.update_config(new_config).await;
        let wasm_bytes = create_valid_wasm_bytes();
        let result = validator.validate_module(&wasm_bytes, None).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_validation_config_creation() {
        let strict_config = ValidationConfig::strict();
        assert_eq!(strict_config.strictness, StrictnessLevel::Strict);
        assert_eq!(strict_config.security_validation, ValidationMode::Enabled);
        let permissive_config = ValidationConfig::permissive();
        assert_eq!(permissive_config.strictness, StrictnessLevel::Relaxed);
        assert_eq!(
            permissive_config.security_validation,
            ValidationMode::Disabled
        );
        let testing_config = ValidationConfig::testing();
        assert_eq!(testing_config.strictness, StrictnessLevel::Relaxed);
        assert_eq!(testing_config.security_validation, ValidationMode::Enabled);
    }

    #[test]
    fn test_validation_statistics_tracking() {
        let mut stats = ValidationStatistics::new();
        stats.record_validation(true, 100.0, None);
        stats.record_validation(false, 150.0, Some("test_error"));
        stats.record_validation(true, 120.0, None);
        assert_eq!(stats.modules_validated, 3);
        assert_eq!(stats.modules_passed, 2);
        assert_eq!(stats.modules_failed, 1);
        // `f64::EPSILON` (~2.2e-16) is far below one ULP at magnitude ~66 (~1.4e-14), so it
        // would demand bit-exact equality; use a tolerance suited to the value's scale.
        let expected_rate = 200.0 / 3.0;
        assert!((stats.success_rate() - expected_rate).abs() < 1e-9);
        assert!((stats.average_validation_time_ms - 370.0 / 3.0).abs() < 1e-9);
        assert_eq!(stats.common_failures.get("test_error"), Some(&1));
    }
}