#[cfg(feature = "wasm-sandbox")]
use std::path::Path;
#[cfg(feature = "wasm-sandbox")]
use std::collections::HashMap;
#[cfg(feature = "wasm-sandbox")]
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Identifies which sandbox limit a [`WasmError::OutOfResources`] error refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ResourceKind {
/// Memory limit.
/// NOTE(review): no code visible in this file produces this variant — confirm where it is used.
Memory,
/// Instruction (fuel) budget; reported by `execute_tool` when a guest call fails with no fuel remaining.
Instructions,
/// Module size limit.
/// NOTE(review): in the visible code oversized modules are reported through
/// `WasmError::ModuleTooLarge`, not through this variant.
ModuleSize,
}
/// Errors returned by [`WasmSandbox`] operations.
#[derive(Debug, Clone, Error)]
pub enum WasmError {
/// No module is registered under the given name (the payload is that name).
#[error("WASM module '{0}' not found")]
ModuleNotFound(String),
/// The requested export could not be obtained as a typed function.
/// Payload order is `(module name, function name)`; the message prints the function first.
#[error("Function '{1}' not found in module '{0}'")]
FunctionNotFound(String, String),
/// A step of running the tool failed: input serialization, guest memory access,
/// the guest call itself, or output deserialization.
#[error("Execution failed: {0}")]
ExecutionFailed(String),
/// A resource budget was exhausted; `limit` is the configured value for `kind`.
#[error("Out of resources: {kind:?} limit {limit} exceeded")]
OutOfResources {
kind: ResourceKind,
limit: u64,
},
/// Setting up the engine, linker, store, module, or instance failed.
/// Also used in this file when a module file cannot be read.
#[error("Module instantiation failed: {0}")]
InstantiationFailed(String),
/// The module's byte length (`size`) exceeds the configured `limit`.
#[error("Module too large: {size} bytes exceeds limit of {limit} bytes")]
ModuleTooLarge { size: u64, limit: u64 },
/// Returned by the stub implementation compiled when the `wasm-sandbox` feature is off.
#[error("WASM sandbox feature is disabled")]
FeatureDisabled,
}
/// Resource limits passed to [`WasmSandbox::new`].
#[derive(Debug, Clone)]
pub struct WasmConfig {
/// Memory budget in bytes.
/// NOTE(review): no code visible in this file reads this field — confirm where
/// (or whether) the limit is enforced.
pub max_memory_bytes: u64,
/// Fuel handed to the store (via `set_fuel`) before each `execute_tool` call.
pub max_instructions: u64,
/// Largest module byte length `load_module` accepts; anything larger yields
/// `WasmError::ModuleTooLarge`.
pub max_module_size_bytes: u64,
}
impl Default for WasmConfig {
    /// Builds the default limits: 50 MiB of memory, ten million instructions,
    /// and a 10 MiB cap on module size.
    fn default() -> Self {
        /// Bytes in one mebibyte.
        const MIB: u64 = 1024 * 1024;

        let max_memory_bytes = 50 * MIB;
        let max_instructions = 10_000_000;
        let max_module_size_bytes = 10 * MIB;

        Self {
            max_memory_bytes,
            max_instructions,
            max_module_size_bytes,
        }
    }
}
/// Sandbox that holds compiled WASM modules by name and runs their exported functions.
///
/// Only compiled when the `wasm-sandbox` feature is enabled.
#[cfg(feature = "wasm-sandbox")]
pub struct WasmSandbox {
// Engine shared by every module compiled and every store created by this sandbox.
engine: wasmtime::Engine,
// Linker used to instantiate modules.
// NOTE(review): `WasiCtx` is named under `wasmtime::` here, while the builder used
// elsewhere in this file is `wasmtime_wasi::WasiCtxBuilder` — confirm the path.
linker: wasmtime::Linker<wasmtime::WasiCtx>,
// Limits supplied at construction time.
config: WasmConfig,
// Compiled modules keyed by the name given to `load_module`.
modules: RwLock<HashMap<String, wasmtime::Module>>,
}
/// Placeholder compiled when the `wasm-sandbox` feature is disabled.
///
/// Its methods in this file either return `WasmError::FeatureDisabled` or report
/// that no modules exist.
#[cfg(not(feature = "wasm-sandbox"))]
pub struct WasmSandbox;
#[cfg(feature = "wasm-sandbox")]
impl WasmSandbox {
    /// Creates a sandbox whose engine meters fuel and whose linker has WASI registered.
    ///
    /// # Errors
    ///
    /// Returns [`WasmError::InstantiationFailed`] if the engine cannot be created or
    /// WASI cannot be registered on the linker.
    pub fn new(config: WasmConfig) -> Result<Self, WasmError> {
        let mut engine_config = wasmtime::Config::new();
        // Fuel metering must be on for `Store::set_fuel` in `execute_tool` to work.
        engine_config.consume_fuel(true);
        let engine = wasmtime::Engine::new(&engine_config)
            .map_err(|e| WasmError::InstantiationFailed(e.to_string()))?;
        let mut linker = wasmtime::Linker::new(&engine);
        let wasi_ctx = wasmtime_wasi::WasiCtxBuilder::new().build();
        // NOTE(review): upstream wasmtime registers WASI through `wasmtime_wasi`'s
        // `add_to_linker` helpers; confirm `define_wasi` exists in the pinned version.
        linker
            .define_wasi(wasi_ctx)
            .map_err(|e| WasmError::InstantiationFailed(e.to_string()))?;
        Ok(Self {
            engine,
            linker,
            config,
            modules: RwLock::new(HashMap::new()),
        })
    }

    /// Compiles `wasm_bytes` and registers the module under `name`, replacing any
    /// module already registered with that name.
    ///
    /// # Errors
    ///
    /// * [`WasmError::ModuleTooLarge`] if the input exceeds `max_module_size_bytes`.
    /// * [`WasmError::InstantiationFailed`] if compilation fails.
    pub fn load_module(&self, name: &str, wasm_bytes: &[u8]) -> Result<(), WasmError> {
        let module_size = wasm_bytes.len() as u64;
        if module_size > self.config.max_module_size_bytes {
            return Err(WasmError::ModuleTooLarge {
                size: module_size,
                limit: self.config.max_module_size_bytes,
            });
        }
        // Compile before taking the write lock so readers are not blocked meanwhile.
        let module = wasmtime::Module::from_binary(&self.engine, wasm_bytes)
            .map_err(|e| WasmError::InstantiationFailed(e.to_string()))?;
        self.modules.write().insert(name.to_string(), module);
        Ok(())
    }

    /// Reads a module from `path` and registers it under `name` via [`Self::load_module`].
    ///
    /// The file size is checked against `max_module_size_bytes` before the contents
    /// are read, so an oversized file is rejected without being loaded into memory.
    ///
    /// # Errors
    ///
    /// * [`WasmError::InstantiationFailed`] if the file cannot be inspected or read.
    /// * [`WasmError::ModuleTooLarge`] if the file exceeds `max_module_size_bytes`.
    /// * Any error returned by [`Self::load_module`].
    pub fn load_module_from_file(&self, name: &str, path: &Path) -> Result<(), WasmError> {
        fn read_error(e: std::io::Error) -> WasmError {
            WasmError::InstantiationFailed(format!("Failed to read file: {}", e))
        }

        let file_size = std::fs::metadata(path).map_err(read_error)?.len();
        if file_size > self.config.max_module_size_bytes {
            return Err(WasmError::ModuleTooLarge {
                size: file_size,
                limit: self.config.max_module_size_bytes,
            });
        }
        // `load_module` re-checks the size, which covers a file that grew in between.
        let wasm_bytes = std::fs::read(path).map_err(read_error)?;
        self.load_module(name, &wasm_bytes)
    }

    /// Runs the exported function `func_name` of the module registered as `module_name`.
    ///
    /// `input_json` is serialized and written to the guest's exported `memory` at
    /// offset 0. The function is called as `(ptr, len) -> (ptr, len)`, and the bytes
    /// it points to are parsed as JSON and returned.
    ///
    /// NOTE(review): writing at offset 0 overwrites whatever the guest keeps there;
    /// confirm the guest ABI reserves that region.
    /// NOTE(review): `max_memory_bytes` is not consulted here, so guest memory growth
    /// is not limited by this function.
    /// NOTE(review): the guest call is synchronous and will occupy the calling task
    /// until it returns or runs out of fuel.
    ///
    /// # Errors
    ///
    /// * [`WasmError::ModuleNotFound`] if no module is registered under `module_name`.
    /// * [`WasmError::InstantiationFailed`] if fuel setup or instantiation fails.
    /// * [`WasmError::FunctionNotFound`] if the export is missing or has another signature.
    /// * [`WasmError::OutOfResources`] if the call fails with no fuel remaining.
    /// * [`WasmError::ExecutionFailed`] for serialization, memory access, guest trap,
    ///   or output parsing failures.
    pub async fn execute_tool(
        &self,
        module_name: &str,
        func_name: &str,
        input_json: serde_json::Value,
    ) -> Result<serde_json::Value, WasmError> {
        // Clone the module handle in its own scope so the read guard is dropped
        // before the await point below.
        let module = {
            let modules = self.modules.read();
            modules
                .get(module_name)
                .cloned()
                .ok_or_else(|| WasmError::ModuleNotFound(module_name.to_string()))?
        };
        let mut store =
            wasmtime::Store::new(&self.engine, wasmtime_wasi::WasiCtxBuilder::new().build());
        store
            .set_fuel(self.config.max_instructions)
            .map_err(|e| WasmError::InstantiationFailed(e.to_string()))?;
        // NOTE(review): upstream wasmtime's awaitable form is `instantiate_async`
        // (with async support enabled on the engine); confirm `.await` is valid here.
        let instance = self
            .linker
            .instantiate(&mut store, &module)
            .await
            .map_err(|e| WasmError::InstantiationFailed(e.to_string()))?;
        let func = instance
            .get_typed_func::<(i32, i32), (i32, i32)>(&mut store, func_name)
            .map_err(|_| {
                WasmError::FunctionNotFound(module_name.to_string(), func_name.to_string())
            })?;
        let input_bytes = serde_json::to_vec(&input_json)
            .map_err(|e| WasmError::ExecutionFailed(format!("Failed to serialize input: {}", e)))?;
        let memory = instance.get_memory(&mut store, "memory").ok_or_else(|| {
            WasmError::ExecutionFailed("Module does not export 'memory'".to_string())
        })?;
        // Guest lengths are 32-bit; reject anything larger instead of truncating it.
        let input_len = u32::try_from(input_bytes.len()).map_err(|_| {
            WasmError::ExecutionFailed(format!(
                "Input of {} bytes does not fit in a 32-bit guest length",
                input_bytes.len()
            ))
        })?;
        let input_ptr = 0i32;
        memory
            .write(&mut store, Self::guest_value_to_usize(input_ptr), &input_bytes)
            .map_err(|e| {
                WasmError::ExecutionFailed(format!("Failed to write input to memory: {}", e))
            })?;
        // `as i32` only reinterprets the checked u32 as the bit pattern the guest
        // signature expects.
        let (raw_output_ptr, raw_output_len) = func
            .call(&mut store, (input_ptr, input_len as i32))
            .map_err(|e| {
                // An empty fuel tank after a failed call is taken to mean the
                // instruction budget was what stopped the guest.
                let fuel_err = store
                    .fuel_remaining()
                    .map(|remaining| remaining == 0)
                    .unwrap_or(false);
                if fuel_err {
                    WasmError::OutOfResources {
                        kind: ResourceKind::Instructions,
                        limit: self.config.max_instructions,
                    }
                } else {
                    WasmError::ExecutionFailed(e.to_string())
                }
            })?;
        let output_ptr = Self::guest_value_to_usize(raw_output_ptr);
        let output_len = Self::guest_value_to_usize(raw_output_len);
        // Validate the guest-supplied range before allocating, so a hostile length
        // cannot force a huge host allocation.
        let memory_size = memory.data_size(&store);
        let in_bounds = matches!(
            output_ptr.checked_add(output_len),
            Some(end) if end <= memory_size
        );
        if !in_bounds {
            return Err(WasmError::ExecutionFailed(format!(
                "Failed to read output from memory: range {}+{} exceeds memory size {}",
                output_ptr, output_len, memory_size
            )));
        }
        let mut output_bytes = vec![0u8; output_len];
        memory
            .read(&store, output_ptr, &mut output_bytes)
            .map_err(|e| {
                WasmError::ExecutionFailed(format!("Failed to read output from memory: {}", e))
            })?;
        serde_json::from_slice(&output_bytes).map_err(|e| {
            WasmError::ExecutionFailed(format!("Failed to deserialize output: {}", e))
        })
    }

    /// Returns the names of all registered modules, in no particular order.
    pub fn list_modules(&self) -> Vec<String> {
        self.modules.read().keys().cloned().collect()
    }

    /// Removes the module registered as `name`; returns `true` if one was present.
    pub fn unload_module(&self, name: &str) -> bool {
        self.modules.write().remove(name).is_some()
    }

    /// Reinterprets a guest `i32` as the unsigned 32-bit offset or length it encodes.
    fn guest_value_to_usize(value: i32) -> usize {
        // Linear-memory addresses are unsigned; the `i32` is only a bit pattern, so
        // sign-extending it would turn offsets at or above 2 GiB into bogus values.
        value as u32 as usize
    }
}
#[cfg(not(feature = "wasm-sandbox"))]
impl WasmSandbox {
    /// Creates the stub sandbox; the configuration is ignored and this never fails.
    pub fn new(_config: WasmConfig) -> Result<Self, WasmError> {
        Ok(Self)
    }

    /// Always fails with [`WasmError::FeatureDisabled`].
    pub fn load_module(&self, _name: &str, _wasm: &[u8]) -> Result<(), WasmError> {
        Err(WasmError::FeatureDisabled)
    }

    /// Always fails with [`WasmError::FeatureDisabled`]; the file is never touched.
    // `Path` is spelled out in full because the file-level `use std::path::Path` is
    // itself gated on the `wasm-sandbox` feature and is absent in this configuration.
    pub fn load_module_from_file(
        &self,
        _name: &str,
        _path: &std::path::Path,
    ) -> Result<(), WasmError> {
        Err(WasmError::FeatureDisabled)
    }

    /// Always fails with [`WasmError::FeatureDisabled`].
    pub async fn execute_tool(
        &self,
        _module_name: &str,
        _func_name: &str,
        _input_json: serde_json::Value,
    ) -> Result<serde_json::Value, WasmError> {
        Err(WasmError::FeatureDisabled)
    }

    /// Returns an empty list: the stub never holds modules.
    pub fn list_modules(&self) -> Vec<String> {
        Vec::new()
    }

    /// Returns `false`: there is never a module to remove.
    pub fn unload_module(&self, _name: &str) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default limits are 50 MiB memory, 10M instructions, 10 MiB module size.
    #[test]
    fn test_wasm_config_default() {
        let config = WasmConfig::default();
        assert_eq!(config.max_memory_bytes, 50 * 1024 * 1024);
        assert_eq!(config.max_instructions, 10_000_000);
        assert_eq!(config.max_module_size_bytes, 10 * 1024 * 1024);
    }

    /// Every `WasmError` variant renders its documented message.
    #[test]
    fn test_wasm_error_display() {
        let err = WasmError::ModuleNotFound("test".to_string());
        assert_eq!(err.to_string(), "WASM module 'test' not found");

        // The payload is (module, function) but the message names the function first.
        let err = WasmError::FunctionNotFound("mod".to_string(), "func".to_string());
        assert_eq!(err.to_string(), "Function 'func' not found in module 'mod'");

        let err = WasmError::ExecutionFailed("boom".to_string());
        assert_eq!(err.to_string(), "Execution failed: boom");

        let err = WasmError::OutOfResources {
            kind: ResourceKind::Instructions,
            limit: 5,
        };
        assert_eq!(
            err.to_string(),
            "Out of resources: Instructions limit 5 exceeded"
        );

        let err = WasmError::InstantiationFailed("bad".to_string());
        assert_eq!(err.to_string(), "Module instantiation failed: bad");

        let err = WasmError::ModuleTooLarge {
            size: 11,
            limit: 10,
        };
        assert_eq!(
            err.to_string(),
            "Module too large: 11 bytes exceeds limit of 10 bytes"
        );

        let err = WasmError::FeatureDisabled;
        assert_eq!(err.to_string(), "WASM sandbox feature is disabled");
    }

    /// `ResourceKind` serializes as its bare variant name.
    #[test]
    fn test_resource_kind_serde() {
        let memory = serde_json::to_string(&ResourceKind::Memory).unwrap();
        let instructions = serde_json::to_string(&ResourceKind::Instructions).unwrap();
        let module_size = serde_json::to_string(&ResourceKind::ModuleSize).unwrap();
        assert_eq!(memory, "\"Memory\"");
        assert_eq!(instructions, "\"Instructions\"");
        assert_eq!(module_size, "\"ModuleSize\"");
    }

    /// Tests for the stub implementation compiled without the `wasm-sandbox` feature.
    #[cfg(not(feature = "wasm-sandbox"))]
    mod stub_tests {
        use super::*;

        #[test]
        fn test_stub_new() {
            assert!(WasmSandbox::new(WasmConfig::default()).is_ok());
        }

        #[test]
        fn test_stub_load_module() {
            let sandbox = WasmSandbox::new(WasmConfig::default()).unwrap();
            let result = sandbox.load_module("test", &[0, 1, 2]);
            assert!(matches!(result, Err(WasmError::FeatureDisabled)));
        }

        #[test]
        fn test_stub_load_module_from_file() {
            let sandbox = WasmSandbox::new(WasmConfig::default()).unwrap();
            // The path does not need to exist: the stub must fail before any I/O.
            let result =
                sandbox.load_module_from_file("test", std::path::Path::new("missing.wasm"));
            assert!(matches!(result, Err(WasmError::FeatureDisabled)));
        }

        #[test]
        fn test_stub_list_modules() {
            let sandbox = WasmSandbox::new(WasmConfig::default()).unwrap();
            assert!(sandbox.list_modules().is_empty());
        }

        #[test]
        fn test_stub_unload_module() {
            let sandbox = WasmSandbox::new(WasmConfig::default()).unwrap();
            assert!(!sandbox.unload_module("test"));
        }

        #[tokio::test]
        async fn test_stub_execute_tool() {
            let sandbox = WasmSandbox::new(WasmConfig::default()).unwrap();
            let result = sandbox
                .execute_tool("test", "func", serde_json::json!({}))
                .await;
            assert!(matches!(result, Err(WasmError::FeatureDisabled)));
        }
    }
}