use thiserror::Error;
/// Reasons a module can be rejected during validation.
///
/// The `Display` text of every variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum ValidationError {
/// The module's byte length (`actual`) is greater than the configured
/// `max_module_size` (`max`).
#[error("Module exceeds maximum size: {actual} > {max}")]
ModuleTooLarge { actual: usize, max: usize },
/// The recorded function count (`actual`) is greater than `max_functions` (`max`).
#[error("Too many functions: {actual} > {max}")]
TooManyFunctions { actual: usize, max: usize },
/// The recorded import count (`actual`) is greater than `max_imports` (`max`).
#[error("Too many imports: {actual} > {max}")]
TooManyImports { actual: usize, max: usize },
/// The recorded export count (`actual`) is greater than `max_exports` (`max`).
#[error("Too many exports: {actual} > {max}")]
TooManyExports { actual: usize, max: usize },
/// A complexity measure (`actual`) is above its limit (`max`).
// NOTE(review): nothing in the visible code constructs this variant; presumably
// it pairs with `max_function_complexity` — confirm.
#[error("Function complexity too high: {actual} > {max}")]
ComplexityTooHigh { actual: usize, max: usize },
/// A suspicious pattern was found; the payload is free-form text.
// NOTE(review): not constructed anywhere in the visible code.
#[error("Suspicious pattern detected: {0}")]
SuspiciousPattern(String),
/// The byte stream is not structurally acceptable, e.g. the magic number is
/// wrong or a LEB128 value is malformed. The payload describes the problem.
#[error("Invalid module structure: {0}")]
InvalidStructure(String),
/// Carries a failure message attributed to Wasmtime.
// NOTE(review): not constructed anywhere in the visible code.
#[error("Wasmtime validation failed: {0}")]
WasmtimeError(String),
}
/// Limits and feature switches that steer module validation.
///
/// Start from one of the presets ([`ValidationConfig::embedded`],
/// [`ValidationConfig::standard`], [`ValidationConfig::compute`],
/// [`ValidationConfig::fuzzing`]) and override individual fields as needed.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Largest accepted module, in bytes.
    pub max_module_size: usize,
    /// Largest accepted function count.
    pub max_functions: usize,
    /// Largest accepted import count.
    pub max_imports: usize,
    /// Largest accepted export count.
    pub max_exports: usize,
    /// Ceiling for per-function complexity.
    pub max_function_complexity: usize,
    /// Ceiling for memory pages.
    pub max_memory_pages: u32,
    /// Ceiling for table size.
    pub max_table_size: u32,
    /// Whether the suspicious-pattern pass runs.
    pub check_suspicious_patterns: bool,
    /// Whether the deep-validation pass runs.
    pub deep_validation: bool,
}

impl Default for ValidationConfig {
    /// General-purpose limits: 10 MiB modules, 10 000 functions, every pass enabled.
    fn default() -> Self {
        ValidationConfig {
            // 10 MiB
            max_module_size: 10 << 20,
            max_functions: 10_000,
            max_imports: 1_000,
            max_exports: 1_000,
            max_function_complexity: 100_000,
            max_memory_pages: 256,
            max_table_size: 10_000,
            check_suspicious_patterns: true,
            deep_validation: true,
        }
    }
}

impl ValidationConfig {
    /// Tight limits for constrained targets (512 KiB modules, 100 functions).
    ///
    /// Both validation passes stay enabled, as in the default preset.
    pub fn embedded() -> Self {
        ValidationConfig {
            // 512 KiB
            max_module_size: 512 << 10,
            max_functions: 100,
            max_imports: 50,
            max_exports: 50,
            max_function_complexity: 10_000,
            max_memory_pages: 16,
            max_table_size: 100,
            ..Self::default()
        }
    }

    /// Alias for [`Default::default`].
    pub fn standard() -> Self {
        <Self as Default>::default()
    }

    /// Generous limits for large workloads (100 MiB modules, 100 000 functions).
    ///
    /// Deep validation is switched off; the suspicious-pattern pass stays on.
    pub fn compute() -> Self {
        ValidationConfig {
            // 100 MiB
            max_module_size: 100 << 20,
            max_functions: 100_000,
            max_imports: 10_000,
            max_exports: 10_000,
            max_function_complexity: 1_000_000,
            max_memory_pages: 4096,
            max_table_size: 100_000,
            deep_validation: false,
            ..Self::default()
        }
    }

    /// Every numeric limit at its type's maximum and both optional passes disabled.
    pub fn fuzzing() -> Self {
        let unbounded = usize::MAX;
        ValidationConfig {
            max_module_size: unbounded,
            max_functions: unbounded,
            max_imports: unbounded,
            max_exports: unbounded,
            max_function_complexity: unbounded,
            max_memory_pages: u32::MAX,
            max_table_size: u32::MAX,
            check_suspicious_patterns: false,
            deep_validation: false,
        }
    }
}
/// Counters gathered while inspecting a module.
///
/// Every field starts at zero via `Default`.
#[derive(Debug, Default, Clone)]
pub struct ModuleStats {
    /// Module length in bytes.
    pub module_size: usize,
    /// Recorded function count.
    pub function_count: usize,
    /// Recorded import count.
    pub import_count: usize,
    /// Recorded export count.
    pub export_count: usize,
    /// Highest complexity value recorded.
    pub max_complexity: usize,
    /// Recorded memory count.
    pub memory_count: usize,
    /// Recorded table count.
    pub table_count: usize,
    /// Recorded global count.
    pub global_count: usize,
    /// Recorded data-segment count.
    pub data_segment_count: usize,
    /// Recorded element-segment count.
    pub element_segment_count: usize,
}

impl ModuleStats {
    /// Returns `true` when the module has more than 1000 functions or a
    /// `max_complexity` above 50 000.
    pub fn is_complex(&self) -> bool {
        let many_functions = self.function_count > 1000;
        let heavy_function = self.max_complexity > 50_000;
        many_functions || heavy_function
    }

    /// Folds function count and peak complexity into a 0..=100 score.
    ///
    /// Each input contributes up to 50 points, scaling linearly until it
    /// saturates at 10 000 functions and 100 000 complexity respectively. The
    /// fractional part of the sum is dropped.
    pub fn complexity_score(&self) -> u8 {
        // Linear share of 50 points, capped once `value` reaches `saturation`.
        let half_score =
            |value: usize, saturation: f64| -> f64 { (value as f64 / saturation * 50.0).min(50.0) };

        let total = half_score(self.function_count, 10_000.0)
            + half_score(self.max_complexity, 100_000.0);
        total as u8
    }
}
/// Categories of questionable code a module scan can report.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SuspiciousPattern {
    ExcessiveMemoryGrowth,
    TightInfiniteLoop,
    UnboundedRecursion,
    SuspiciousArithmetic,
    ExcessiveBranching,
}

impl SuspiciousPattern {
    /// Short human-readable explanation of the pattern.
    pub fn description(&self) -> &'static str {
        let text: &'static str = match *self {
            SuspiciousPattern::ExcessiveMemoryGrowth => "Excessive memory.grow operations",
            SuspiciousPattern::TightInfiniteLoop => "Tight loop without termination",
            SuspiciousPattern::UnboundedRecursion => "Recursive call without apparent base case",
            SuspiciousPattern::SuspiciousArithmetic => "Suspicious arithmetic operations",
            SuspiciousPattern::ExcessiveBranching => "Excessive conditional branching",
        };
        text
    }
}
/// Checks a WebAssembly binary against a [`ValidationConfig`] and collects
/// [`ModuleStats`] while doing so.
///
/// A validator can be reused: every call to [`ModuleValidator::validate`]
/// resets the statistics and the recorded suspicious patterns first.
pub struct ModuleValidator {
    config: ValidationConfig,
    suspicious_patterns: Vec<SuspiciousPattern>,
    stats: ModuleStats,
}

impl ModuleValidator {
    /// Length of the module preamble: 4-byte magic followed by a 4-byte version.
    const HEADER_LEN: usize = 8;

    /// Creates a validator that enforces `config`.
    pub fn new(config: ValidationConfig) -> Self {
        Self {
            config,
            suspicious_patterns: Vec::new(),
            stats: ModuleStats::default(),
        }
    }

    /// Creates a validator with [`ValidationConfig::default`].
    pub fn with_default() -> Self {
        Self::new(ValidationConfig::default())
    }

    /// Validates `wasm_bytes` and returns the statistics gathered on success.
    ///
    /// Steps, in order: size limit, magic number, section walk, count limits,
    /// then the optional deep-validation and suspicious-pattern passes when the
    /// configuration enables them.
    ///
    /// # Errors
    ///
    /// * [`ValidationError::ModuleTooLarge`] if the input is longer than
    ///   `max_module_size`.
    /// * [`ValidationError::InvalidStructure`] if the magic number is wrong, the
    ///   8-byte header is incomplete, a section runs past the end of the input,
    ///   or a LEB128 value is malformed.
    /// * [`ValidationError::TooManyFunctions`], [`ValidationError::TooManyImports`]
    ///   or [`ValidationError::TooManyExports`] if the corresponding count
    ///   exceeds its limit.
    pub fn validate(&mut self, wasm_bytes: &[u8]) -> Result<ModuleStats, ValidationError> {
        self.suspicious_patterns.clear();
        self.stats = ModuleStats::default();
        self.stats.module_size = wasm_bytes.len();

        if self.stats.module_size > self.config.max_module_size {
            return Err(ValidationError::ModuleTooLarge {
                actual: self.stats.module_size,
                max: self.config.max_module_size,
            });
        }
        if !wasm_bytes.starts_with(b"\0asm") {
            return Err(ValidationError::InvalidStructure(
                "Invalid WASM magic number".to_string(),
            ));
        }

        self.parse_module_structure(wasm_bytes)?;
        self.check_limits()?;
        if self.config.deep_validation {
            self.deep_validate(wasm_bytes)?;
        }
        if self.config.check_suspicious_patterns {
            self.detect_suspicious_patterns(wasm_bytes)?;
        }
        Ok(self.stats.clone())
    }

    /// Walks the section list and records how many entries each counted
    /// section declares.
    ///
    /// Every section is `id byte, LEB128 payload size, payload`. For the
    /// import, function, table, memory, global, export, element and data
    /// sections the payload starts with a LEB128 entry count, which is added to
    /// the matching counter. The code section is skipped on purpose: its
    /// entries mirror the function section and would double the function count.
    /// The version field of the header is not inspected.
    fn parse_module_structure(&mut self, wasm_bytes: &[u8]) -> Result<(), ValidationError> {
        let mut rest = wasm_bytes.get(Self::HEADER_LEN..).ok_or_else(|| {
            ValidationError::InvalidStructure("Truncated WASM header".to_string())
        })?;

        while let Some((&section_id, after_id)) = rest.split_first() {
            let (size, size_len) = self.read_leb128_u32(after_id)?;
            // `size_len` bytes were just read from `after_id`, so this cannot go out of range.
            let body = &after_id[size_len..];
            // Lossless on targets where `usize` is at least 32 bits wide.
            let size = size as usize;
            if size > body.len() {
                return Err(ValidationError::InvalidStructure(format!(
                    "Section {} extends past the end of the module",
                    section_id
                )));
            }
            let (payload, remaining) = body.split_at(size);
            rest = remaining;

            // Custom (0), type (1), start (8), code (10), data-count (12) and
            // unknown ids do not feed any counter.
            if !matches!(section_id, 2..=7 | 9 | 11) {
                continue;
            }

            let (entries, _) = self.read_leb128_u32(payload)?;
            let entries = entries as usize;
            let stats = &mut self.stats;
            let counter = match section_id {
                2 => &mut stats.import_count,
                3 => &mut stats.function_count,
                4 => &mut stats.table_count,
                5 => &mut stats.memory_count,
                6 => &mut stats.global_count,
                7 => &mut stats.export_count,
                9 => &mut stats.element_segment_count,
                // Only id 11 (data) can reach this arm because of the filter above.
                _ => &mut stats.data_segment_count,
            };
            // Saturate so repeated sections in hostile input cannot overflow the counter.
            *counter = counter.saturating_add(entries);
        }
        Ok(())
    }

    /// Decodes an unsigned 32-bit LEB128 value from the start of `bytes`.
    ///
    /// Returns the value and the number of bytes consumed (1 to 5).
    ///
    /// # Errors
    ///
    /// [`ValidationError::InvalidStructure`] if `bytes` ends before the value
    /// terminates, if the value needs more than five bytes, or if the fifth
    /// byte carries bits that do not fit into a `u32`.
    fn read_leb128_u32(&self, bytes: &[u8]) -> Result<(u32, usize), ValidationError> {
        let mut result = 0u32;
        for (index, &byte) in bytes.iter().take(5).enumerate() {
            let group = u32::from(byte & 0x7F);
            // Four groups cover 28 bits, so the fifth may only contribute 4 more.
            if index == 4 && group > 0x0F {
                break;
            }
            result |= group << (7 * index);
            if byte & 0x80 == 0 {
                return Ok((result, index + 1));
            }
        }
        Err(ValidationError::InvalidStructure(
            "Invalid LEB128 encoding".to_string(),
        ))
    }

    /// Compares the recorded function, import and export counts with the
    /// configured limits, reporting the first one that is exceeded.
    ///
    /// The complexity, memory-page and table-size limits are not checked here.
    fn check_limits(&self) -> Result<(), ValidationError> {
        if self.stats.function_count > self.config.max_functions {
            return Err(ValidationError::TooManyFunctions {
                actual: self.stats.function_count,
                max: self.config.max_functions,
            });
        }
        if self.stats.import_count > self.config.max_imports {
            return Err(ValidationError::TooManyImports {
                actual: self.stats.import_count,
                max: self.config.max_imports,
            });
        }
        if self.stats.export_count > self.config.max_exports {
            return Err(ValidationError::TooManyExports {
                actual: self.stats.export_count,
                max: self.config.max_exports,
            });
        }
        Ok(())
    }

    /// Placeholder for the deep-validation pass; currently accepts every input.
    fn deep_validate(&mut self, _wasm_bytes: &[u8]) -> Result<(), ValidationError> {
        Ok(())
    }

    /// Placeholder for the suspicious-pattern pass; currently records nothing
    /// and accepts every input.
    fn detect_suspicious_patterns(&mut self, _wasm_bytes: &[u8]) -> Result<(), ValidationError> {
        Ok(())
    }

    /// Patterns recorded by the most recent [`ModuleValidator::validate`] call.
    pub fn suspicious_patterns(&self) -> &[SuspiciousPattern] {
        &self.suspicious_patterns
    }

    /// Statistics from the most recent [`ModuleValidator::validate`] call,
    /// including partial results when that call failed.
    pub fn stats(&self) -> &ModuleStats {
        &self.stats
    }
}
/// Deterministic source of seed modules and byte-level mutations for fuzzing.
///
/// All randomness comes from an internal xorshift state derived from the seed,
/// so two generators built with the same seed produce the same mutations.
#[derive(Debug, Clone)]
pub struct FuzzInputGenerator {
    seed: u64,
    mutation_rate: f64,
}

impl FuzzInputGenerator {
    /// Module preamble: `\0asm` magic followed by version 1.
    const WASM_HEADER: [u8; 8] = [0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00];

    /// Type section declaring one function type, `() -> i32`.
    const TYPE_SECTION: [u8; 7] = [0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f];

    /// State used in place of a zero seed, which xorshift can never leave.
    const ZERO_SEED_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Cap on generated functions; keeps every section size far below `u32::MAX`.
    const MAX_GENERATED_FUNCTIONS: usize = 1_000_000;

    /// Creates a generator with a mutation rate of `0.1`.
    ///
    /// A `seed` of 0 is replaced by a fixed non-zero constant, because an
    /// all-zero xorshift state would make every "random" value 0.
    pub fn new(seed: u64) -> Self {
        let seed = if seed == 0 {
            Self::ZERO_SEED_REPLACEMENT
        } else {
            seed
        };
        Self {
            seed,
            mutation_rate: 0.1,
        }
    }

    /// Sets the fraction of bytes rewritten by [`FuzzInputGenerator::mutate`],
    /// clamped to `0.0..=1.0`.
    pub fn with_mutation_rate(mut self, rate: f64) -> Self {
        self.mutation_rate = rate.clamp(0.0, 1.0);
        self
    }

    /// Returns the smallest module: just the 8-byte header.
    pub fn generate_minimal_module(&self) -> Vec<u8> {
        Self::WASM_HEADER.to_vec()
    }

    /// Returns a copy of `input` with `len * mutation_rate` (rounded down)
    /// randomly chosen positions overwritten by random bytes.
    ///
    /// The same position may be picked more than once. An empty `input` yields
    /// an empty output.
    pub fn mutate(&mut self, input: &[u8]) -> Vec<u8> {
        let mut output = input.to_vec();
        // Zero for empty input, so the modulo below never divides by zero.
        let mutation_count = (output.len() as f64 * self.mutation_rate) as usize;
        for _ in 0..mutation_count {
            let idx = self.next_random() as usize % output.len();
            output[idx] = self.next_random() as u8;
        }
        output
    }

    /// Advances the xorshift state (shifts 13, 7, 17) and returns it.
    fn next_random(&mut self) -> u64 {
        self.seed ^= self.seed << 13;
        self.seed ^= self.seed >> 7;
        self.seed ^= self.seed << 17;
        self.seed
    }

    /// Returns the built-in seed corpus: a minimal module, one with a memory,
    /// one with ten functions, and one containing a loop.
    pub fn generate_test_cases(&mut self) -> Vec<Vec<u8>> {
        vec![
            self.generate_minimal_module(),
            self.generate_with_memory(),
            self.generate_with_functions(10),
            self.generate_with_loops(),
        ]
    }

    /// Module with a memory section holding one memory (limits flag 0, minimum 1).
    fn generate_with_memory(&self) -> Vec<u8> {
        let mut module = Self::WASM_HEADER.to_vec();
        module.extend_from_slice(&[0x05, 0x03, 0x01, 0x00, 0x01]);
        module
    }

    /// Module with `count` functions of type `() -> i32`, each returning 42.
    ///
    /// `count` is capped at [`Self::MAX_GENERATED_FUNCTIONS`].
    fn generate_with_functions(&self, count: usize) -> Vec<u8> {
        // body size 4, no locals, `i32.const 42`, `end`
        const BODY: [u8; 5] = [0x04, 0x00, 0x41, 0x2a, 0x0b];

        let count = count.min(Self::MAX_GENERATED_FUNCTIONS);
        let mut count_leb = Vec::with_capacity(5);
        Self::push_leb128(&mut count_leb, count);

        let mut module = Vec::with_capacity(
            Self::WASM_HEADER.len() + Self::TYPE_SECTION.len() + 24 + count * (1 + BODY.len()),
        );
        module.extend_from_slice(&Self::WASM_HEADER);
        module.extend_from_slice(&Self::TYPE_SECTION);

        // Function section: entry count, then one type index (0) per function.
        module.push(0x03);
        Self::push_leb128(&mut module, count_leb.len() + count);
        module.extend_from_slice(&count_leb);
        let new_len = module.len() + count;
        module.resize(new_len, 0x00);

        // Code section: entry count, then one body per function.
        module.push(0x0a);
        Self::push_leb128(&mut module, count_leb.len() + count * BODY.len());
        module.extend_from_slice(&count_leb);
        for _ in 0..count {
            module.extend_from_slice(&BODY);
        }
        module
    }

    /// Module with one function whose body is `loop; i32.const 0; br 0; end; end`.
    ///
    /// The code-section size (11) and body size (9) match the bytes emitted.
    // NOTE(review): the function type declares an i32 result, but the body does
    // not appear to leave a value after the loop; a type-checking validator may
    // reject it — confirm whether that is intended for this seed.
    fn generate_with_loops(&self) -> Vec<u8> {
        let mut module = Self::WASM_HEADER.to_vec();
        module.extend_from_slice(&Self::TYPE_SECTION);
        module.extend_from_slice(&[0x03, 0x02, 0x01, 0x00]);
        module.extend_from_slice(&[
            0x0a, 0x0b, // code section, 11 payload bytes
            0x01, // one body
            0x09, // body size
            0x00, // no locals
            0x03, 0x40, // loop (empty block type)
            0x41, 0x00, // i32.const 0
            0x0c, 0x00, // br 0
            0x0b, // end loop
            0x0b, // end function
        ]);
        module
    }

    /// Appends `value` to `out` as unsigned LEB128.
    fn push_leb128(out: &mut Vec<u8>, mut value: usize) {
        loop {
            let group = (value & 0x7F) as u8;
            value >>= 7;
            if value == 0 {
                out.push(group);
                return;
            }
            out.push(group | 0x80);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The presets are ordered embedded < default < compute; fuzzing is unbounded.
    #[test]
    fn test_validation_config_presets() {
        let embedded = ValidationConfig::embedded();
        assert!(embedded.max_module_size < ValidationConfig::default().max_module_size);
        let compute = ValidationConfig::compute();
        assert!(compute.max_module_size > ValidationConfig::default().max_module_size);
        let fuzzing = ValidationConfig::fuzzing();
        assert_eq!(fuzzing.max_module_size, usize::MAX);
    }

    #[test]
    fn test_module_stats() {
        let stats = ModuleStats {
            function_count: 5000,
            max_complexity: 75_000,
            ..Default::default()
        };
        assert!(stats.is_complex());
        assert!(stats.complexity_score() > 50);
    }

    #[test]
    fn test_validator_basic() {
        let mut validator = ModuleValidator::with_default();
        let wasm = wat::parse_str("(module)").unwrap();
        let stats = validator.validate(&wasm).unwrap();
        assert_eq!(stats.module_size, wasm.len());
    }

    /// An 8-byte module against a 4-byte limit must be rejected as too large.
    #[test]
    fn test_validator_size_limit() {
        let config = ValidationConfig {
            max_module_size: 4,
            ..Default::default()
        };
        let mut validator = ModuleValidator::new(config);
        let wasm = vec![0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00];
        // Match on the variant so that any other outcome fails the test instead
        // of silently skipping the field assertions.
        match validator.validate(&wasm) {
            Err(ValidationError::ModuleTooLarge { actual, max }) => {
                assert_eq!(max, 4);
                assert_eq!(actual, 8);
            }
            other => panic!("expected ModuleTooLarge, got {:?}", other),
        }
    }

    #[test]
    fn test_validator_invalid_magic() {
        let mut validator = ModuleValidator::with_default();
        let invalid_wasm = b"invalid";
        let result = validator.validate(invalid_wasm);
        assert!(matches!(result, Err(ValidationError::InvalidStructure(_))));
    }

    #[test]
    fn test_fuzz_generator() {
        let mut gen = FuzzInputGenerator::new(12345);
        let minimal = gen.generate_minimal_module();
        assert!(!minimal.is_empty());
        let test_cases = gen.generate_test_cases();
        assert_eq!(test_cases.len(), 4);
    }

    #[test]
    fn test_fuzz_mutation() {
        let mut gen = FuzzInputGenerator::new(12345).with_mutation_rate(0.1);
        let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mutated = gen.mutate(&input);
        assert_eq!(mutated.len(), input.len());
        assert_ne!(mutated, input);
    }

    /// Every variant must have a non-empty description.
    #[test]
    fn test_suspicious_pattern_descriptions() {
        let patterns = [
            SuspiciousPattern::ExcessiveMemoryGrowth,
            SuspiciousPattern::TightInfiniteLoop,
            SuspiciousPattern::UnboundedRecursion,
            SuspiciousPattern::SuspiciousArithmetic,
            SuspiciousPattern::ExcessiveBranching,
        ];
        for pattern in &patterns {
            assert!(!pattern.description().is_empty());
        }
    }
}