use super::*;
use std::collections::HashSet;
use std::path::Path;
use mockforge_plugin_core::{
FilesystemPermissions, NetworkPermissions, PluginCapabilities, PluginId, PluginManifest,
PluginVersion, ResourceLimits,
};
use wasmparser::{Parser, Payload};
use ring::signature;
use shellexpand;
use base64::{engine::general_purpose, Engine as _};
use hex;
use serde_json;
/// A detached plugin signature, as decoded from a `.sig` JSON document.
#[derive(Debug, Clone)]
struct PluginSignature {
    /// Signature scheme name; `parse_signature` accepts "rsa", "ecdsa" or "ed25519".
    algorithm: String,
    /// Raw signature bytes (hex-decoded from the JSON `signature` field).
    signature: Vec<u8>,
    /// Identifier of the public key that should verify this signature.
    key_id: String,
}
/// Validates plugin directories, manifests, WebAssembly modules and signatures.
///
/// Loader-level settings (trusted keys, `allow_unsigned`, `skip_wasm_validation`,
/// inline key data) are read from `config`; manifest, capability and module
/// checks are delegated to `security_policies`.
pub struct PluginValidator {
    /// Loader configuration supplied by the caller.
    config: PluginLoaderConfig,
    /// Policy applied to manifests, capabilities and compiled WASM modules.
    security_policies: SecurityPolicies,
}
impl PluginValidator {
pub fn new(config: PluginLoaderConfig) -> Self {
Self {
config,
security_policies: SecurityPolicies::default(),
}
}
pub async fn validate_manifest(&self, manifest: &PluginManifest) -> LoaderResult<()> {
let mut errors = Vec::new();
if let Err(validation_error) = manifest.validate() {
errors.push(PluginLoaderError::manifest(validation_error));
}
if let Err(e) = self.security_policies.validate_manifest(manifest) {
errors.push(e);
}
let mut visited = HashSet::new();
visited.insert(manifest.info.id.clone());
if let Err(e) = self
.validate_dependencies(&manifest.info.id, &manifest.dependencies, &mut visited)
.await
{
errors.push(e);
}
if errors.is_empty() {
Ok(())
} else {
Err(PluginLoaderError::validation(format!(
"Manifest validation failed with {} errors: {}",
errors.len(),
errors.into_iter().map(|e| e.to_string()).collect::<Vec<_>>().join(", ")
)))
}
}
pub fn validate_capabilities(&self, capability_names: &[String]) -> LoaderResult<()> {
let capabilities = PluginCapabilities {
network: NetworkPermissions::default(),
filesystem: FilesystemPermissions::default(),
resources: ResourceLimits::default(),
custom: capability_names
.iter()
.map(|name| (name.clone(), serde_json::Value::Bool(true)))
.collect(),
};
self.security_policies.validate_capabilities(&capabilities)
}
pub async fn validate_wasm_file(&self, wasm_path: &Path) -> LoaderResult<()> {
if !wasm_path.exists() {
return Err(PluginLoaderError::fs("WASM file does not exist".to_string()));
}
let metadata = tokio::fs::metadata(wasm_path)
.await
.map_err(|e| PluginLoaderError::fs(format!("Cannot read WASM file metadata: {}", e)))?;
if !metadata.is_file() {
return Err(PluginLoaderError::fs("WASM path is not a file".to_string()));
}
let file_size = metadata.len();
if file_size > self.security_policies.max_wasm_file_size {
return Err(PluginLoaderError::security(format!(
"WASM file too large: {} bytes (max: {} bytes)",
file_size, self.security_policies.max_wasm_file_size
)));
}
self.validate_wasm_module(wasm_path).await?;
Ok(())
}
pub async fn validate_plugin_file(&self, plugin_path: &Path) -> LoaderResult<PluginManifest> {
if !plugin_path.exists() {
return Err(PluginLoaderError::fs("Plugin path does not exist".to_string()));
}
if !plugin_path.is_dir() {
return Err(PluginLoaderError::fs("Plugin path must be a directory".to_string()));
}
let manifest_path = plugin_path.join("plugin.yaml");
if !manifest_path.exists() {
return Err(PluginLoaderError::manifest("plugin.yaml not found".to_string()));
}
let manifest = PluginManifest::from_file(&manifest_path)
.map_err(|e| PluginLoaderError::manifest(format!("Failed to load manifest: {}", e)))?;
self.validate_manifest(&manifest).await?;
let wasm_files: Vec<_> = std::fs::read_dir(plugin_path)
.map_err(|e| PluginLoaderError::fs(format!("Cannot read plugin directory: {}", e)))?
.filter_map(|entry| entry.ok())
.map(|entry| entry.path())
.filter(|path| path.extension().is_some_and(|ext| ext == "wasm"))
.collect();
if wasm_files.is_empty() {
return Err(PluginLoaderError::load(
"No WebAssembly file found in plugin directory".to_string(),
));
}
if wasm_files.len() > 1 {
return Err(PluginLoaderError::load(
"Multiple WebAssembly files found in plugin directory".to_string(),
));
}
if !self.config.skip_wasm_validation {
self.validate_wasm_file(&wasm_files[0]).await?;
}
Ok(manifest)
}
async fn validate_dependencies(
&self,
current_plugin_id: &PluginId,
dependencies: &std::collections::HashMap<mockforge_plugin_core::PluginId, PluginVersion>,
visited: &mut HashSet<PluginId>,
) -> LoaderResult<()> {
for (plugin_id, version) in dependencies {
if self.would_create_circular_dependency(current_plugin_id, plugin_id, visited) {
return Err(PluginLoaderError::ValidationError {
message: format!(
"Circular dependency detected: '{}' -> '{}'",
current_plugin_id.0, plugin_id.0
),
});
}
if plugin_id.0.is_empty() {
return Err(PluginLoaderError::ValidationError {
message: "Dependency plugin ID cannot be empty".to_string(),
});
}
if plugin_id.0.len() > 100 {
return Err(PluginLoaderError::ValidationError {
message: format!(
"Dependency plugin ID '{}' is too long (max 100 characters)",
plugin_id.0
),
});
}
if version.major == 0 && version.minor == 0 && version.patch == 0 {
tracing::warn!("Dependency '{}' specifies version 0.0.0 which may indicate development/testing", plugin_id.0);
}
if plugin_id.0.contains("..") || plugin_id.0.contains("/") || plugin_id.0.contains("\\")
{
return Err(PluginLoaderError::SecurityViolation {
violation: format!(
"Dependency plugin ID '{}' contains potentially unsafe characters",
plugin_id.0
),
});
}
}
Ok(())
}
fn would_create_circular_dependency(
&self,
current_plugin_id: &PluginId,
dependency_id: &PluginId,
visited: &mut HashSet<PluginId>,
) -> bool {
if dependency_id == current_plugin_id {
return true;
}
visited.contains(dependency_id)
}
async fn validate_wasm_module(&self, wasm_path: &Path) -> LoaderResult<()> {
let module = wasmtime::Module::from_file(&wasmtime::Engine::default(), wasm_path)
.map_err(|e| PluginLoaderError::wasm(format!("Invalid WASM module: {}", e)))?;
self.security_policies.validate_wasm_module(&module)?;
Ok(())
}
pub async fn validate_plugin_signature(
&self,
plugin_path: &Path,
manifest: &PluginManifest,
) -> LoaderResult<()> {
if self.config.allow_unsigned {
return Ok(());
}
let sig_path = plugin_path.with_extension("sig");
if !sig_path.exists() {
return Err(PluginLoaderError::SecurityViolation {
violation: format!(
"Plugin '{}' requires a signature but none was found",
manifest.info.id.0
),
});
}
let signature_data =
std::fs::read(&sig_path).map_err(|e| PluginLoaderError::ValidationError {
message: format!("Failed to read signature file: {}", e),
})?;
let signature = self.parse_signature(&signature_data)?;
let plugin_data =
std::fs::read(plugin_path).map_err(|e| PluginLoaderError::ValidationError {
message: format!("Failed to read plugin file: {}", e),
})?;
self.verify_signature(&plugin_data, &signature, manifest).await?;
Ok(())
}
fn parse_signature(&self, data: &[u8]) -> Result<PluginSignature, PluginLoaderError> {
let sig_json: serde_json::Value =
serde_json::from_slice(data).map_err(|e| PluginLoaderError::ValidationError {
message: format!("Invalid signature JSON format: {}", e),
})?;
let algorithm = sig_json
.get("algorithm")
.and_then(|v| v.as_str())
.ok_or_else(|| PluginLoaderError::ValidationError {
message: "Missing or invalid 'algorithm' field".to_string(),
})?
.to_string();
let signature_hex =
sig_json.get("signature").and_then(|v| v.as_str()).ok_or_else(|| {
PluginLoaderError::ValidationError {
message: "Missing or invalid 'signature' field".to_string(),
}
})?;
let key_id = sig_json
.get("key_id")
.and_then(|v| v.as_str())
.ok_or_else(|| PluginLoaderError::ValidationError {
message: "Missing or invalid 'key_id' field".to_string(),
})?
.to_string();
if !["rsa", "ecdsa", "ed25519"].contains(&algorithm.as_str()) {
return Err(PluginLoaderError::ValidationError {
message: format!("Unsupported signature algorithm: {}", algorithm),
});
}
let signature =
hex::decode(signature_hex).map_err(|e| PluginLoaderError::ValidationError {
message: format!("Invalid signature hex: {}", e),
})?;
Ok(PluginSignature {
algorithm,
signature,
key_id,
})
}
async fn verify_signature(
&self,
data: &[u8],
signature: &PluginSignature,
manifest: &PluginManifest,
) -> LoaderResult<()> {
let public_key = self.get_trusted_key(&signature.key_id)?;
match signature.algorithm.as_str() {
"rsa" => {
self.verify_rsa_signature(data, &signature.signature, &public_key)?;
}
"ecdsa" => {
self.verify_ecdsa_signature(data, &signature.signature, &public_key)?;
}
"ed25519" => {
self.verify_ed25519_signature(data, &signature.signature, &public_key)?;
}
_ => {
return Err(PluginLoaderError::ValidationError {
message: format!("Unsupported algorithm: {}", signature.algorithm),
});
}
}
self.validate_key_authorization(&signature.key_id, manifest)?;
Ok(())
}
fn get_trusted_key(&self, key_id: &str) -> Result<Vec<u8>, PluginLoaderError> {
if !self.config.trusted_keys.contains(&key_id.to_string()) {
return Err(PluginLoaderError::SecurityViolation {
violation: format!("Key '{}' is not in the trusted keys list", key_id),
});
}
self.load_key_from_store(key_id)
}
fn load_key_from_store(&self, key_id: &str) -> Result<Vec<u8>, PluginLoaderError> {
if let Ok(key_data) = self.load_key_from_env(key_id) {
tracing::info!("Loaded key '{}' from environment variable", key_id);
return Ok(key_data);
}
if let Some(key_data) = self.config.key_data.get(key_id) {
tracing::info!("Loaded key '{}' from configuration", key_id);
return Ok(key_data.clone());
}
if let Ok(key_data) = self.load_key_from_filesystem(key_id) {
tracing::info!("Loaded key '{}' from filesystem", key_id);
return Ok(key_data);
}
if let Ok(key_data) = self.load_key_from_database(key_id) {
tracing::info!("Loaded key '{}' from database provider", key_id);
return Ok(key_data);
}
if let Ok(key_data) = self.load_key_from_kms(key_id) {
tracing::info!("Loaded key '{}' from key management service", key_id);
return Ok(key_data);
}
Err(PluginLoaderError::SecurityViolation {
violation: format!("Could not find key data for trusted key: {}", key_id),
})
}
fn load_key_from_env(&self, key_id: &str) -> Result<Vec<u8>, PluginLoaderError> {
self.load_key_material_from_prefixes(key_id, &["MOCKFORGE_KEY"], "environment")
}
fn load_key_material_from_prefixes(
&self,
key_id: &str,
prefixes: &[&str],
source_name: &str,
) -> Result<Vec<u8>, PluginLoaderError> {
let normalized = key_id.to_uppercase().replace("-", "_");
for prefix in prefixes {
let b64_env_key = format!("{}_{}_B64", prefix, normalized);
if let Ok(b64_value) = std::env::var(&b64_env_key) {
match general_purpose::STANDARD.decode(&b64_value) {
Ok(key_data) => return Ok(key_data),
Err(e) => {
tracing::warn!(
"Failed to decode base64 key from {} ({}): {}",
b64_env_key,
source_name,
e
);
}
}
}
let hex_env_key = format!("{}_{}_HEX", prefix, normalized);
if let Ok(hex_value) = std::env::var(&hex_env_key) {
match hex::decode(&hex_value) {
Ok(key_data) => return Ok(key_data),
Err(e) => {
tracing::warn!(
"Failed to decode hex key from {} ({}): {}",
hex_env_key,
source_name,
e
);
}
}
}
let raw_env_key = format!("{}_{}", prefix, normalized);
if let Ok(key_data) = std::env::var(&raw_env_key) {
return Ok(key_data.into_bytes());
}
}
Err(PluginLoaderError::SecurityViolation {
violation: format!("Key not found in {}: {}", source_name, key_id),
})
}
fn load_key_from_directory(
&self,
key_id: &str,
dir: &std::path::Path,
) -> Result<Vec<u8>, PluginLoaderError> {
let candidates = [
dir.join(format!("{}.der", key_id)),
dir.join(format!("{}.pem", key_id)),
dir.join(format!("{}.key", key_id)),
dir.join(format!("{}.bin", key_id)),
];
for path in candidates {
if path.exists() {
return std::fs::read(&path).map_err(|e| PluginLoaderError::SecurityViolation {
violation: format!("Failed to read key file {}: {}", path.display(), e),
});
}
}
Err(PluginLoaderError::SecurityViolation {
violation: format!("Key '{}' not found in directory {}", key_id, dir.display()),
})
}
fn load_key_from_filesystem(&self, key_id: &str) -> Result<Vec<u8>, PluginLoaderError> {
let key_paths = vec![
format!("~/.mockforge/keys/{}.der", key_id),
format!("~/.mockforge/keys/{}.pem", key_id),
format!("/etc/mockforge/keys/{}.der", key_id),
format!("/etc/mockforge/keys/{}.pem", key_id),
];
for key_path in key_paths {
let expanded_path = shellexpand::tilde(&key_path);
let path = std::path::Path::new(expanded_path.as_ref());
if path.exists() {
match std::fs::read(path) {
Ok(key_data) => return Ok(key_data),
Err(e) => {
tracing::warn!("Failed to read key file {}: {}", path.display(), e);
continue;
}
}
}
}
Err(PluginLoaderError::SecurityViolation {
violation: format!("Key not found in filesystem: {}", key_id),
})
}
#[allow(unused)]
fn load_key_from_database(&self, key_id: &str) -> Result<Vec<u8>, PluginLoaderError> {
let db_type = std::env::var("MOCKFORGE_DB_TYPE").map_err(|_| {
PluginLoaderError::SecurityViolation {
violation: "Database key loading requires MOCKFORGE_DB_TYPE environment variable"
.to_string(),
}
})?;
let connection_string = std::env::var("MOCKFORGE_DB_CONNECTION").map_err(|_| {
PluginLoaderError::SecurityViolation {
violation:
"Database key loading requires MOCKFORGE_DB_CONNECTION environment variable"
.to_string(),
}
})?;
let table_name =
std::env::var("MOCKFORGE_DB_KEY_TABLE").unwrap_or_else(|_| "plugin_keys".to_string());
tracing::info!("Database key loading configured: type={}, table={}", db_type, table_name);
tracing::debug!("Looking up key '{}' in database-backed key source", key_id);
if let Ok(key_data) =
self.load_key_material_from_prefixes(key_id, &["MOCKFORGE_DB_KEY"], "database env")
{
return Ok(key_data);
}
if let Ok(key_dir) = std::env::var("MOCKFORGE_DB_KEY_DIR") {
let expanded = shellexpand::tilde(&key_dir);
let path = std::path::Path::new(expanded.as_ref());
if path.exists() {
return self.load_key_from_directory(key_id, path);
}
}
Err(PluginLoaderError::SecurityViolation {
violation: format!(
"Database key '{}' not found in configured environment or key directory (connection: {})",
key_id, connection_string
),
})
}
#[allow(unused)]
fn load_key_from_kms(&self, key_id: &str) -> Result<Vec<u8>, PluginLoaderError> {
let kms_provider = std::env::var("MOCKFORGE_KMS_PROVIDER").map_err(|_| {
PluginLoaderError::SecurityViolation {
violation: "KMS key loading requires MOCKFORGE_KMS_PROVIDER environment variable"
.to_string(),
}
})?;
match kms_provider.to_lowercase().as_str() {
"aws" => self.load_key_from_aws_kms(key_id),
"vault" => self.load_key_from_vault(key_id),
"azure" => self.load_key_from_azure_kv(key_id),
"gcp" => self.load_key_from_gcp_kms(key_id),
_ => Err(PluginLoaderError::SecurityViolation {
violation: format!("Unsupported KMS provider: {}", kms_provider),
}),
}
}
fn load_key_from_aws_kms(&self, key_id: &str) -> Result<Vec<u8>, PluginLoaderError> {
let region =
std::env::var("MOCKFORGE_KMS_REGION").unwrap_or_else(|_| "us-east-1".to_string());
tracing::info!("AWS KMS key loading configured: region={}", region);
tracing::debug!("Looking up key '{}' in AWS KMS", key_id);
self.load_key_material_from_prefixes(
key_id,
&["MOCKFORGE_AWS_KMS_KEY", "MOCKFORGE_KMS_KEY"],
"AWS KMS",
)
}
fn load_key_from_vault(&self, key_id: &str) -> Result<Vec<u8>, PluginLoaderError> {
let vault_addr = std::env::var("MOCKFORGE_VAULT_ADDR").map_err(|_| {
PluginLoaderError::SecurityViolation {
violation: "Vault key loading requires MOCKFORGE_VAULT_ADDR environment variable"
.to_string(),
}
})?;
let _vault_token = std::env::var("MOCKFORGE_VAULT_TOKEN").map_err(|_| {
PluginLoaderError::SecurityViolation {
violation: "Vault key loading requires MOCKFORGE_VAULT_TOKEN environment variable"
.to_string(),
}
})?;
tracing::info!("HashCorp Vault key loading configured: addr={}", vault_addr);
tracing::debug!("Looking up key '{}' in Vault", key_id);
self.load_key_material_from_prefixes(
key_id,
&["MOCKFORGE_VAULT_KEY", "MOCKFORGE_KMS_KEY"],
"Vault",
)
}
fn load_key_from_azure_kv(&self, key_id: &str) -> Result<Vec<u8>, PluginLoaderError> {
tracing::info!("Azure Key Vault key loading requested");
tracing::debug!("Looking up key '{}' in Azure Key Vault", key_id);
self.load_key_material_from_prefixes(
key_id,
&["MOCKFORGE_AZURE_KV_KEY", "MOCKFORGE_KMS_KEY"],
"Azure Key Vault",
)
}
fn load_key_from_gcp_kms(&self, key_id: &str) -> Result<Vec<u8>, PluginLoaderError> {
tracing::info!("Google Cloud KMS key loading requested");
tracing::debug!("Looking up key '{}' in GCP KMS", key_id);
self.load_key_material_from_prefixes(
key_id,
&["MOCKFORGE_GCP_KMS_KEY", "MOCKFORGE_KMS_KEY"],
"Google Cloud KMS",
)
}
fn verify_rsa_signature(
&self,
data: &[u8],
signature: &[u8],
public_key: &[u8],
) -> LoaderResult<()> {
let public_key =
signature::UnparsedPublicKey::new(&signature::RSA_PKCS1_2048_8192_SHA256, public_key);
public_key
.verify(data, signature)
.map_err(|e| PluginLoaderError::SecurityViolation {
violation: format!("RSA signature verification failed: {}", e),
})?;
Ok(())
}
fn verify_ecdsa_signature(
&self,
data: &[u8],
signature: &[u8],
public_key: &[u8],
) -> LoaderResult<()> {
let public_key =
signature::UnparsedPublicKey::new(&signature::ECDSA_P256_SHA256_ASN1, public_key);
public_key
.verify(data, signature)
.map_err(|e| PluginLoaderError::SecurityViolation {
violation: format!("ECDSA signature verification failed: {}", e),
})?;
Ok(())
}
fn verify_ed25519_signature(
&self,
data: &[u8],
signature: &[u8],
public_key: &[u8],
) -> LoaderResult<()> {
let public_key = signature::UnparsedPublicKey::new(&signature::ED25519, public_key);
public_key
.verify(data, signature)
.map_err(|e| PluginLoaderError::SecurityViolation {
violation: format!("Ed25519 signature verification failed: {}", e),
})?;
Ok(())
}
fn validate_key_authorization(
&self,
key_id: &str,
manifest: &PluginManifest,
) -> LoaderResult<()> {
if self.config.trusted_keys.contains(&key_id.to_string()) {
return Ok(());
}
Err(PluginLoaderError::SecurityViolation {
violation: format!(
"Key '{}' is not authorized to sign plugins from '{}'",
key_id, manifest.info.author.name
),
})
}
pub async fn get_validation_summary(&self, plugin_path: &Path) -> ValidationSummary {
let mut summary = ValidationSummary::default();
if !plugin_path.exists() {
summary.errors.push("Plugin path does not exist".to_string());
return summary;
}
let manifest_result = self.validate_plugin_file(plugin_path).await;
match manifest_result {
Ok(manifest) => {
summary.manifest_valid = true;
summary.manifest = Some(manifest);
}
Err(e) => {
summary.errors.push(format!("Manifest validation failed: {}", e));
}
}
if let Ok(wasm_path) = self.find_wasm_file(plugin_path) {
let wasm_result = self.validate_wasm_file(&wasm_path).await;
summary.wasm_valid = wasm_result.is_ok();
if let Err(e) = wasm_result {
summary.errors.push(format!("WASM validation failed: {}", e));
}
} else {
summary.errors.push("No WebAssembly file found".to_string());
}
summary.is_valid =
summary.manifest_valid && summary.wasm_valid && summary.errors.is_empty();
summary
}
fn find_wasm_file(&self, plugin_path: &Path) -> LoaderResult<PathBuf> {
let entries = std::fs::read_dir(plugin_path)
.map_err(|e| PluginLoaderError::fs(format!("Cannot read directory: {}", e)))?;
for entry in entries {
let entry =
entry.map_err(|e| PluginLoaderError::fs(format!("Cannot read entry: {}", e)))?;
let path = entry.path();
if let Some(extension) = path.extension() {
if extension == "wasm" {
return Ok(path);
}
}
}
Err(PluginLoaderError::load("No WebAssembly file found".to_string()))
}
}
#[derive(Debug, Clone)]
pub struct SecurityPolicies {
pub max_wasm_file_size: u64,
pub allowed_imports: HashSet<String>,
pub forbidden_imports: HashSet<String>,
pub max_memory_pages: u32,
pub max_functions: u32,
pub allow_floats: bool,
pub allow_simd: bool,
pub allow_network_access: bool,
pub allow_filesystem_read: bool,
pub allow_filesystem_write: bool,
}
impl Default for SecurityPolicies {
fn default() -> Self {
let mut allowed_imports = HashSet::new();
allowed_imports.insert("env".to_string());
allowed_imports.insert("wasi_snapshot_preview1".to_string());
let mut forbidden_imports = HashSet::new();
forbidden_imports.insert("abort".to_string());
forbidden_imports.insert("exit".to_string());
Self {
max_wasm_file_size: 10 * 1024 * 1024, allowed_imports,
forbidden_imports,
max_memory_pages: 256, max_functions: 1000,
allow_floats: true,
allow_simd: false,
allow_network_access: false,
allow_filesystem_read: false,
allow_filesystem_write: false,
}
}
}
impl SecurityPolicies {
pub fn validate_manifest(&self, manifest: &PluginManifest) -> LoaderResult<()> {
let _caps = PluginCapabilities::from_strings(&manifest.capabilities);
Ok(())
}
pub fn validate_capabilities(&self, capabilities: &PluginCapabilities) -> LoaderResult<()> {
if capabilities.resources.max_memory_bytes > self.max_memory_bytes() {
return Err(PluginLoaderError::security(format!(
"Memory limit {} exceeds maximum allowed {}",
capabilities.resources.max_memory_bytes,
self.max_memory_bytes()
)));
}
if capabilities.resources.max_cpu_percent > self.max_cpu_percent() {
return Err(PluginLoaderError::security(format!(
"CPU limit {:.2}% exceeds maximum allowed {:.2}%",
capabilities.resources.max_cpu_percent,
self.max_cpu_percent()
)));
}
Ok(())
}
pub fn validate_wasm_module(&self, module: &wasmtime::Module) -> LoaderResult<()> {
self.validate_imports(module)?;
self.validate_exports(module)?;
self.validate_memory_usage(module)?;
self.check_dangerous_operations(module)?;
self.validate_function_limits(module)?;
self.validate_data_segments(module)?;
Ok(())
}
fn validate_imports(&self, module: &wasmtime::Module) -> LoaderResult<()> {
let _module_info = module.resources_required();
for import in module.imports() {
let module_name = import.module();
let field_name = import.name();
let allowed_modules = [
"wasi_snapshot_preview1",
"wasi:io/streams",
"wasi:filesystem/types",
"mockforge:plugin/host",
];
if !allowed_modules.contains(&module_name) {
return Err(PluginLoaderError::SecurityViolation {
violation: format!("Disallowed import module: {}", module_name),
});
}
match module_name {
"wasi_snapshot_preview1" => {
self.validate_wasi_import(field_name)?;
}
"mockforge:plugin/host" => {
self.validate_host_import(field_name)?;
}
_ => {
}
}
}
Ok(())
}
fn validate_wasi_import(&self, field_name: &str) -> LoaderResult<()> {
let allowed_functions = [
"fd_read",
"fd_write",
"fd_close",
"fd_fdstat_get",
"path_open",
"path_readlink",
"path_filestat_get",
"clock_time_get",
"proc_exit",
"random_get",
];
if !allowed_functions.contains(&field_name) {
return Err(PluginLoaderError::SecurityViolation {
violation: format!("Disallowed WASI function: {}", field_name),
});
}
Ok(())
}
fn validate_host_import(&self, field_name: &str) -> LoaderResult<()> {
let allowed_functions = [
"log_message",
"get_config_value",
"store_data",
"retrieve_data",
];
if !allowed_functions.contains(&field_name) {
return Err(PluginLoaderError::SecurityViolation {
violation: format!("Disallowed host function: {}", field_name),
});
}
Ok(())
}
fn validate_exports(&self, module: &wasmtime::Module) -> LoaderResult<()> {
let _module_info = module.resources_required();
let mut has_memory_export = false;
let mut function_exports = 0;
for export in module.exports() {
match export.ty() {
wasmtime::ExternType::Memory(_) => {
has_memory_export = true;
}
wasmtime::ExternType::Func(_) => {
function_exports += 1;
}
_ => {
}
}
}
if !has_memory_export {
return Err(PluginLoaderError::ValidationError {
message: "WASM module must export memory".to_string(),
});
}
if function_exports > 1000 {
return Err(PluginLoaderError::SecurityViolation {
violation: format!("Too many function exports: {} (max: 1000)", function_exports),
});
}
Ok(())
}
fn validate_memory_usage(&self, module: &wasmtime::Module) -> LoaderResult<()> {
let _module_info = module.resources_required();
for import in module.imports() {
if let wasmtime::ExternType::Memory(memory_type) = import.ty() {
if let Some(max) = memory_type.maximum() {
if max > 100 {
return Err(PluginLoaderError::SecurityViolation {
violation: format!("Memory limit too high: {} pages (max: 100)", max),
});
}
}
if memory_type.maximum().is_none() && memory_type.is_shared() {
return Err(PluginLoaderError::SecurityViolation {
violation: "Shared memory without maximum limit not allowed".to_string(),
});
}
}
}
for export in module.exports() {
if let wasmtime::ExternType::Memory(memory_type) = export.ty() {
if let Some(max) = memory_type.maximum() {
if max > 100 {
return Err(PluginLoaderError::SecurityViolation {
violation: format!("Exported memory limit too high: {} pages", max),
});
}
}
}
}
Ok(())
}
fn check_dangerous_operations(&self, module: &wasmtime::Module) -> LoaderResult<()> {
let _module_info = module.resources_required();
self.validate_function_sizes(module)?;
let suspicious_imports = ["env", "wasi_unstable", "wasi_experimental"];
for import in module.imports() {
if suspicious_imports.contains(&import.module()) {
return Err(PluginLoaderError::SecurityViolation {
violation: format!("Suspicious import module: {}", import.module()),
});
}
}
Ok(())
}
fn validate_function_limits(&self, module: &wasmtime::Module) -> LoaderResult<()> {
let _module_info = module.resources_required();
let mut function_count = 0;
for export in module.exports() {
if let wasmtime::ExternType::Func(_) = export.ty() {
function_count += 1;
}
}
for import in module.imports() {
if let wasmtime::ExternType::Func(_) = import.ty() {
function_count += 1;
}
}
if function_count > 10000 {
return Err(PluginLoaderError::SecurityViolation {
violation: format!("Too many functions: {} (max: 10000)", function_count),
});
}
Ok(())
}
fn validate_function_sizes(&self, module: &wasmtime::Module) -> LoaderResult<()> {
for export in module.exports() {
if let wasmtime::ExternType::Func(func_type) = export.ty() {
let param_count = func_type.params().len();
let result_count = func_type.results().len();
if param_count > 20 || result_count > 10 {
return Err(PluginLoaderError::SecurityViolation {
violation: format!(
"Function '{}' has suspiciously complex signature: {} params, {} results",
export.name(), param_count, result_count
),
});
}
let mut complex_types = 0;
for param in func_type.params() {
match param {
wasmtime::ValType::V128 | wasmtime::ValType::Ref(_) => {
complex_types += 1;
}
_ => {}
}
}
if complex_types > param_count / 2 && param_count > 5 {
return Err(PluginLoaderError::SecurityViolation {
violation: format!(
"Function '{}' has unusually complex parameter types: {} complex types out of {} params",
export.name(), complex_types, param_count
),
});
}
}
}
let mut total_functions = 0;
for export in module.exports() {
if let wasmtime::ExternType::Func(_) = export.ty() {
total_functions += 1;
}
}
for import in module.imports() {
if let wasmtime::ExternType::Func(_) = import.ty() {
total_functions += 1;
}
}
if total_functions > 5000 {
return Err(PluginLoaderError::SecurityViolation {
violation: format!(
"Too many functions: {} (reasonable limit: 5000)",
total_functions
),
});
}
Ok(())
}
fn validate_data_segments(&self, module: &wasmtime::Module) -> LoaderResult<()> {
let wasm_bytes = module.serialize().map_err(|e| PluginLoaderError::ValidationError {
message: format!("Failed to serialize WASM module: {}", e),
})?;
let parser = Parser::new(0);
let payloads =
parser.parse_all(&wasm_bytes).collect::<Result<Vec<_>, _>>().map_err(|e| {
PluginLoaderError::ValidationError {
message: format!("Failed to parse WASM module: {}", e),
}
})?;
let suspicious_patterns = [
"http://",
"https://",
"/bin/",
"/usr/bin/",
"eval(",
"exec(",
"system(",
"shell",
"cmd.exe",
"powershell",
"wget",
"curl",
"nc ",
"netcat",
"python -c",
"ruby -e",
"node -e",
"bash -c",
"sh -c",
];
for payload in payloads {
if let Payload::DataSection(data_section) = payload {
for data_segment_result in data_section {
let data_segment =
data_segment_result.map_err(|e| PluginLoaderError::ValidationError {
message: format!("Failed to read data segment: {}", e),
})?;
let data = data_segment.data;
if let Ok(data_str) = std::str::from_utf8(data) {
for pattern in &suspicious_patterns {
if data_str.contains(pattern) {
return Err(PluginLoaderError::SecurityViolation {
violation: format!(
"Data segment contains suspicious content: '{}'",
pattern
),
});
}
}
} else {
for pattern in &suspicious_patterns {
if data
.windows(pattern.len())
.any(|window| window == pattern.as_bytes())
{
return Err(PluginLoaderError::SecurityViolation {
violation: format!(
"Data segment contains suspicious content: '{}'",
pattern
),
});
}
}
}
}
}
}
Ok(())
}
pub fn allow_network_access(&self) -> bool {
self.allow_network_access
}
pub fn allow_filesystem_read(&self) -> bool {
self.allow_filesystem_read
}
pub fn allow_filesystem_write(&self) -> bool {
self.allow_filesystem_write
}
pub fn max_memory_bytes(&self) -> usize {
10 * 1024 * 1024 }
pub fn max_cpu_percent(&self) -> f64 {
0.5 }
}
/// Aggregated outcome of validating one plugin directory.
#[derive(Debug, Clone)]
pub struct ValidationSummary {
    /// Overall verdict; cleared by `add_error`.
    pub is_valid: bool,
    /// Whether the manifest passed validation.
    pub manifest_valid: bool,
    /// Whether the WebAssembly module passed validation.
    pub wasm_valid: bool,
    /// The parsed manifest, when it could be loaded and validated.
    pub manifest: Option<PluginManifest>,
    /// Human-readable error messages.
    pub errors: Vec<String>,
    /// Human-readable warnings; these do not affect `is_valid`.
    pub warnings: Vec<String>,
}
impl Default for ValidationSummary {
fn default() -> Self {
Self {
is_valid: true,
manifest_valid: false,
wasm_valid: false,
manifest: None,
errors: Vec::new(),
warnings: Vec::new(),
}
}
}
impl ValidationSummary {
    /// Records an error and marks the summary as invalid.
    pub fn add_error<S>(&mut self, error: S)
    where
        S: Into<String>,
    {
        self.is_valid = false;
        self.errors.push(error.into());
    }

    /// Records a warning; the validity flag is left untouched.
    pub fn add_warning<S>(&mut self, warning: S)
    where
        S: Into<String>,
    {
        self.warnings.push(warning.into());
    }

    /// `true` when at least one error has been recorded.
    pub fn has_errors(&self) -> bool {
        self.error_count() > 0
    }

    /// `true` when at least one warning has been recorded.
    pub fn has_warnings(&self) -> bool {
        self.warning_count() > 0
    }

    /// Number of recorded errors.
    pub fn error_count(&self) -> usize {
        self.errors.len()
    }

    /// Number of recorded warnings.
    pub fn warning_count(&self) -> usize {
        self.warnings.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default policy denies network and filesystem access and reports the
    /// expected resource ceilings.
    #[tokio::test]
    async fn test_security_policies_creation() {
        let defaults = SecurityPolicies::default();
        let denied = [
            defaults.allow_network_access(),
            defaults.allow_filesystem_read(),
            defaults.allow_filesystem_write(),
        ];
        assert!(denied.iter().all(|allowed| !allowed));
        assert_eq!(defaults.max_memory_bytes(), 10 * 1024 * 1024);
        assert_eq!(defaults.max_cpu_percent(), 0.5);
    }

    /// Errors invalidate the summary; warnings are only counted.
    #[tokio::test]
    async fn test_validation_summary() {
        let mut report = ValidationSummary::default();
        assert!(report.is_valid);
        assert!(!report.has_errors());
        assert!(!report.has_warnings());

        report.add_error("Test error");
        assert!(!report.is_valid);
        assert!(report.has_errors());
        assert_eq!(report.error_count(), 1);

        report.add_warning("Test warning");
        assert!(report.has_warnings());
        assert_eq!(report.warning_count(), 1);
    }

    /// Constructing a validator from the default config must not panic.
    #[tokio::test]
    async fn test_plugin_validator_creation() {
        let _validator = PluginValidator::new(PluginLoaderConfig::default());
    }
}