use std::path::Path;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ValidationError {
#[error("Failed to read WASM file: {0}")]
IoError(#[from] std::io::Error),
#[error("Invalid WASM module: {0}")]
InvalidModule(String),
#[error("Missing required export: {0}")]
MissingExport(String),
#[error("Invalid export signature for '{name}': expected {expected}, got {actual}")]
InvalidSignature {
name: String,
expected: String,
actual: String,
},
#[error("Module uses disallowed import: {module}::{name}")]
DisallowedImport { module: String, name: String },
#[error("Module exceeds size limit: {size} bytes (max: {max} bytes)")]
TooLarge { size: u64, max: u64 },
#[error("Validation failed: {0}")]
Other(String),
}
#[derive(Debug)]
pub struct ValidationResult {
pub is_valid: bool,
pub errors: Vec<ValidationError>,
pub warnings: Vec<String>,
pub exports: Vec<ExportInfo>,
pub imports: Vec<ImportInfo>,
pub size_bytes: u64,
}
#[derive(Debug, Clone)]
pub struct ExportInfo {
pub name: String,
pub kind: ExportKind,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportKind {
Function,
Memory,
Table,
Global,
}
#[derive(Debug, Clone)]
pub struct ImportInfo {
pub module: String,
pub name: String,
pub kind: ImportKind,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImportKind {
Function,
Memory,
Table,
Global,
}
pub struct WasmValidator {
max_size: u64,
required_exports: Vec<String>,
allowed_import_modules: Vec<String>,
}
impl Default for WasmValidator {
fn default() -> Self {
Self {
max_size: 10 * 1024 * 1024, required_exports: vec!["run".to_string()],
allowed_import_modules: vec![
"env".to_string(),
"wasi_snapshot_preview1".to_string(),
"wasi".to_string(),
],
}
}
}
impl WasmValidator {
pub fn new() -> Self {
Self::default()
}
pub fn with_max_size(mut self, max_bytes: u64) -> Self {
self.max_size = max_bytes;
self
}
pub fn with_required_export(mut self, name: impl Into<String>) -> Self {
self.required_exports.push(name.into());
self
}
pub fn with_allowed_import(mut self, module: impl Into<String>) -> Self {
self.allowed_import_modules.push(module.into());
self
}
pub async fn validate_file(&self, path: &Path) -> Result<ValidationResult, ValidationError> {
let bytes = tokio::fs::read(path).await?;
self.validate_bytes(&bytes)
}
pub fn validate_bytes(&self, bytes: &[u8]) -> Result<ValidationResult, ValidationError> {
let mut errors = Vec::new();
let mut warnings = Vec::new();
let mut exports = Vec::new();
let mut imports = Vec::new();
let size_bytes = bytes.len() as u64;
if size_bytes > self.max_size {
errors.push(ValidationError::TooLarge {
size: size_bytes,
max: self.max_size,
});
}
let parser = wasmparser::Parser::new(0);
for payload in parser.parse_all(bytes) {
match payload {
Ok(wasmparser::Payload::ExportSection(reader)) => {
for export in reader {
match export {
Ok(exp) => {
let kind = match exp.kind {
wasmparser::ExternalKind::Func => ExportKind::Function,
wasmparser::ExternalKind::Memory => ExportKind::Memory,
wasmparser::ExternalKind::Table => ExportKind::Table,
wasmparser::ExternalKind::Global => ExportKind::Global,
wasmparser::ExternalKind::Tag => continue,
};
exports.push(ExportInfo {
name: exp.name.to_string(),
kind,
});
}
Err(e) => {
errors.push(ValidationError::InvalidModule(format!(
"Failed to parse export: {}",
e
)));
}
}
}
}
Ok(wasmparser::Payload::ImportSection(reader)) => {
for import in reader {
match import {
Ok(imp) => {
let kind = match imp.ty {
wasmparser::TypeRef::Func(_) => ImportKind::Function,
wasmparser::TypeRef::Memory(_) => ImportKind::Memory,
wasmparser::TypeRef::Table(_) => ImportKind::Table,
wasmparser::TypeRef::Global(_) => ImportKind::Global,
wasmparser::TypeRef::Tag(_) => continue,
};
imports.push(ImportInfo {
module: imp.module.to_string(),
name: imp.name.to_string(),
kind,
});
if !self
.allowed_import_modules
.contains(&imp.module.to_string())
{
errors.push(ValidationError::DisallowedImport {
module: imp.module.to_string(),
name: imp.name.to_string(),
});
}
}
Err(e) => {
errors.push(ValidationError::InvalidModule(format!(
"Failed to parse import: {}",
e
)));
}
}
}
}
Ok(_) => {
}
Err(e) => {
errors.push(ValidationError::InvalidModule(format!(
"Failed to parse WASM: {}",
e
)));
break;
}
}
}
for required in &self.required_exports {
if !exports.iter().any(|e| &e.name == required) {
errors.push(ValidationError::MissingExport(required.clone()));
}
}
if !exports
.iter()
.any(|e| e.name == "memory" && e.kind == ExportKind::Memory)
{
warnings
.push("Module does not export memory - host cannot read/write data".to_string());
}
for import in &imports {
if import.module == "wasi_snapshot_preview1" {
match import.name.as_str() {
"fd_write" | "fd_read" | "path_open" | "path_create_directory" => {
warnings.push(format!(
"Module uses WASI filesystem function '{}' - ensure this is intended",
import.name
));
}
"sock_send" | "sock_recv" | "sock_accept" => {
warnings.push(format!(
"Module uses WASI socket function '{}' - ensure this is intended",
import.name
));
}
_ => {}
}
}
}
Ok(ValidationResult {
is_valid: errors.is_empty(),
errors,
warnings,
exports,
imports,
size_bytes,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validator_default() {
let validator = WasmValidator::new();
assert_eq!(validator.max_size, 10 * 1024 * 1024);
assert!(validator.required_exports.contains(&"run".to_string()));
}
#[test]
fn test_validator_builder() {
let validator = WasmValidator::new()
.with_max_size(1024)
.with_required_export("custom_export")
.with_allowed_import("custom_module");
assert_eq!(validator.max_size, 1024);
assert!(
validator
.required_exports
.contains(&"custom_export".to_string())
);
assert!(
validator
.allowed_import_modules
.contains(&"custom_module".to_string())
);
}
#[test]
fn test_validate_bytes_invalid_bytes() {
let validator = WasmValidator::new();
let garbage = b"this is not a wasm module at all";
let result = validator.validate_bytes(garbage).unwrap();
assert!(!result.is_valid);
assert!(
result
.errors
.iter()
.any(|e| matches!(e, ValidationError::InvalidModule(_)))
);
}
#[test]
fn test_validate_bytes_empty() {
let validator = WasmValidator::new();
let result = validator.validate_bytes(b"").unwrap();
assert!(!result.is_valid);
assert!(
result
.errors
.iter()
.any(|e| matches!(e, ValidationError::InvalidModule(_)))
);
}
#[test]
fn test_validate_bytes_minimal_wasm_missing_run_export() {
let validator = WasmValidator::new();
let minimal_wasm = b"\x00asm\x01\x00\x00\x00";
let result = validator.validate_bytes(minimal_wasm).unwrap();
assert!(!result.is_valid);
assert!(
result
.errors
.iter()
.any(|e| matches!(e, ValidationError::MissingExport(name) if name == "run"))
);
assert_eq!(result.size_bytes, 8);
}
#[test]
fn test_validation_result_is_valid_when_no_errors() {
let result = ValidationResult {
is_valid: true,
errors: vec![],
warnings: vec!["some warning".to_string()],
exports: vec![],
imports: vec![],
size_bytes: 0,
};
assert!(result.is_valid);
assert!(result.errors.is_empty());
}
#[test]
fn test_validation_result_is_invalid_when_errors_present() {
let result = ValidationResult {
is_valid: false,
errors: vec![ValidationError::MissingExport("run".to_string())],
warnings: vec![],
exports: vec![],
imports: vec![],
size_bytes: 0,
};
assert!(!result.is_valid);
assert_eq!(result.errors.len(), 1);
}
#[test]
fn test_validation_error_display() {
let io_err =
ValidationError::IoError(std::io::Error::new(std::io::ErrorKind::NotFound, "gone"));
assert!(io_err.to_string().contains("Failed to read WASM file"));
let invalid = ValidationError::InvalidModule("bad magic".to_string());
assert!(invalid.to_string().contains("Invalid WASM module"));
assert!(invalid.to_string().contains("bad magic"));
let missing = ValidationError::MissingExport("run".to_string());
assert!(missing.to_string().contains("Missing required export"));
assert!(missing.to_string().contains("run"));
let sig = ValidationError::InvalidSignature {
name: "run".to_string(),
expected: "() -> i32".to_string(),
actual: "() -> ()".to_string(),
};
assert!(sig.to_string().contains("Invalid export signature"));
assert!(sig.to_string().contains("run"));
let disallowed = ValidationError::DisallowedImport {
module: "evil".to_string(),
name: "hack".to_string(),
};
assert!(disallowed.to_string().contains("disallowed import"));
assert!(disallowed.to_string().contains("evil::hack"));
let too_large = ValidationError::TooLarge {
size: 200,
max: 100,
};
assert!(too_large.to_string().contains("200"));
assert!(too_large.to_string().contains("100"));
let other = ValidationError::Other("something broke".to_string());
assert!(other.to_string().contains("something broke"));
}
#[test]
fn test_export_kind_equality() {
assert_eq!(ExportKind::Function, ExportKind::Function);
assert_eq!(ExportKind::Memory, ExportKind::Memory);
assert_eq!(ExportKind::Table, ExportKind::Table);
assert_eq!(ExportKind::Global, ExportKind::Global);
assert_ne!(ExportKind::Function, ExportKind::Memory);
assert_ne!(ExportKind::Table, ExportKind::Global);
}
#[test]
fn test_import_kind_equality() {
assert_eq!(ImportKind::Function, ImportKind::Function);
assert_eq!(ImportKind::Memory, ImportKind::Memory);
assert_eq!(ImportKind::Table, ImportKind::Table);
assert_eq!(ImportKind::Global, ImportKind::Global);
assert_ne!(ImportKind::Function, ImportKind::Global);
assert_ne!(ImportKind::Memory, ImportKind::Table);
}
#[test]
fn test_validate_bytes_exceeds_max_size() {
let validator = WasmValidator::new().with_max_size(4);
let minimal_wasm = b"\x00asm\x01\x00\x00\x00";
let result = validator.validate_bytes(minimal_wasm).unwrap();
assert!(!result.is_valid);
assert!(
result
.errors
.iter()
.any(|e| matches!(e, ValidationError::TooLarge { size: 8, max: 4 }))
);
}
#[test]
fn test_with_max_size_then_validate_over_limit() {
let validator = WasmValidator::new().with_max_size(16);
let oversized = vec![0u8; 32];
let result = validator.validate_bytes(&oversized).unwrap();
assert!(!result.is_valid);
assert!(
result
.errors
.iter()
.any(|e| matches!(e, ValidationError::TooLarge { size: 32, max: 16 }))
);
}
}