use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use crate::plugin_manifest::{ManifestCapabilities, PluginManifest};
use crate::wasm_host_abi::{HostCallResult, HostFunctionContext};
use crate::wasm_runtime::WasmPluginCapabilities;
/// Limits and tuning knobs applied by a [`WasmSandboxRuntime`] to the sandboxes it creates.
#[derive(Debug, Clone)]
pub struct SandboxConfig {
    /// Memory ceiling per sandbox, in bytes.
    pub max_memory_bytes: usize,
    /// Fuel budget a sandbox starts with (and is refilled to by `reset_fuel`).
    pub fuel_limit: u64,
    /// NOTE(review): not consumed by `WasmSandboxRuntime` in this module.
    pub epoch_interval: Duration,
    /// Maximum number of simultaneously loaded sandboxes.
    pub max_sandboxes: usize,
    /// NOTE(review): not consumed by `WasmSandboxRuntime` in this module.
    pub enable_tracing: bool,
    /// NOTE(review): not consumed by `WasmSandboxRuntime` in this module.
    pub pool_size: usize,
    /// Time budget for host-side work.
    pub host_timeout: Duration,
}

impl Default for SandboxConfig {
    /// 16 MiB of memory, a one-billion-unit fuel budget, 64 sandboxes, a 10 ms epoch,
    /// a pool of 4, tracing off and a 30 s host timeout.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        SandboxConfig {
            host_timeout: Duration::from_millis(30_000),
            pool_size: 4,
            enable_tracing: false,
            max_sandboxes: 64,
            epoch_interval: Duration::from_millis(10),
            fuel_limit: 1_000_000_000,
            max_memory_bytes: 16 * MIB,
        }
    }
}
/// Registry and execution front-end for plugin sandboxes.
///
/// Mutable state sits behind `RwLock`/`Mutex`, so the runtime's methods operate through `&self`.
pub struct WasmSandboxRuntime {
/// Limits applied to every sandbox this runtime creates.
config: SandboxConfig,
/// Loaded sandboxes keyed by plugin id.
sandboxes: RwLock<HashMap<String, Arc<PluginSandbox>>>,
/// Module bytes recorded per plugin id at load time.
module_cache: RwLock<HashMap<String, CompiledModule>>,
/// Runtime-wide counters returned by `get_runtime_stats`.
stats: RwLock<SandboxRuntimeStats>,
/// Factory used to build each plugin's host context when it is loaded.
host_context: Arc<dyn HostContextProvider + Send + Sync>,
/// Set by `shutdown`; once true, `load_plugin` returns `SandboxError::RuntimeShutdown`.
shutdown: Mutex<bool>,
}
/// Backend that supplies the host side of the plugin ABI.
///
/// `Send + Sync` is a supertrait so one provider can be shared (via `Arc`) by the runtime.
/// NOTE(review): whether implementations must enforce the capabilities carried in a
/// `HostFunctionContext` is not specified here — confirm with the ABI owner.
pub trait HostContextProvider: Send + Sync {
/// Builds the host context for `plugin_id` from the capabilities declared in its manifest.
fn create_context(
&self,
plugin_id: &str,
capabilities: &ManifestCapabilities,
) -> HostFunctionContext;
/// Host-side read of row `row_id` in table `table_id` on behalf of the plugin in `ctx`.
fn read(&self, ctx: &HostFunctionContext, table_id: u32, row_id: u64) -> HostCallResult;
/// Host-side write of `data` to row `row_id` in table `table_id` on behalf of the plugin in `ctx`.
fn write(
&self,
ctx: &HostFunctionContext,
table_id: u32,
row_id: u64,
data: &[u8],
) -> HostCallResult;
/// Vector search of `vector` against `index`; `top_k` presumably caps the result count
/// (TODO confirm with implementations).
fn vector_search(
&self,
ctx: &HostFunctionContext,
index: &str,
vector: &[f32],
top_k: u32,
) -> HostCallResult;
/// Emits `message` on behalf of the plugin. NOTE(review): the encoding of `level` is not
/// defined in this module.
fn log(&self, ctx: &HostFunctionContext, level: u8, message: &str);
}
/// Module entry cached per plugin id when a plugin is loaded.
///
/// Despite the name, nothing is compiled in this module: the validated bytes are copied and
/// hashed.
// `dead_code` is allowed because several fields are recorded but not yet read in this module.
#[derive(Clone)]
#[allow(dead_code)]
struct CompiledModule {
// Copy of the module bytes passed to the loader.
wasm_bytes: Vec<u8>,
// When this cache entry was created.
compiled_at: Instant,
// Starts at 0; not incremented anywhere in this module.
instantiation_count: u64,
// `DefaultHasher` digest of `wasm_bytes` (see `compute_hash`).
source_hash: u64,
}
/// Per-plugin execution state, resource budgets and accounting.
///
/// Each mutable field has its own `Mutex`, so readers (e.g. `list_plugins`) lock fields
/// individually.
// `dead_code` is allowed because some fields (e.g. `capabilities`, `host_context`) are stored
// but not read in this module.
#[allow(dead_code)]
pub struct PluginSandbox {
// Id the sandbox is registered under in the runtime map.
plugin_id: String,
// Manifest supplied at load time; name/version are surfaced by `list_plugins`.
manifest: PluginManifest,
// Copy of `manifest.capabilities` taken at load time.
capabilities: ManifestCapabilities,
// Memory in use, compared against `SandboxConfig::max_memory_bytes` before each call.
memory_used: Mutex<usize>,
// Remaining fuel; starts at `SandboxConfig::fuel_limit`.
fuel_remaining: Mutex<u64>,
// Number of invocations that reached execution.
call_count: Mutex<u64>,
// Creation time; `list_plugins` reports uptime from it.
created_at: Instant,
// Time the most recent invocation started.
last_activity: Mutex<Instant>,
// Lifecycle state gating `invoke`.
state: Mutex<SandboxState>,
// Host context produced by the runtime's `HostContextProvider`.
host_context: HostFunctionContext,
// Per-sandbox counters returned by `get_plugin_stats`.
stats: Mutex<SandboxStats>,
}
/// Lifecycle state of a [`PluginSandbox`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SandboxState {
/// Idle; `invoke` accepts calls.
Ready,
/// A call is in progress; a concurrent `invoke` gets `SandboxError::AlreadyExecuting`.
Executing,
/// NOTE(review): never assigned in this module; `invoke` does not reject it.
Suspended,
/// Unloaded or shut down; `invoke` returns `SandboxError::SandboxTerminated`.
Terminated,
/// `invoke` returns `SandboxError::SandboxFailed`; never assigned in this module.
Failed,
}
/// Counters kept for a single sandbox (snapshot returned by `get_plugin_stats`).
#[derive(Debug, Clone, Default)]
pub struct SandboxStats {
/// Calls that passed the state checks in `invoke`.
pub total_calls: u64,
/// Calls whose execution returned `Ok`.
pub successful_calls: u64,
/// Calls whose execution returned `Err`.
pub failed_calls: u64,
/// Fuel charged to this sandbox.
pub fuel_consumed: u64,
/// Wall-clock time measured around execution in `invoke`.
pub total_execution_time: Duration,
/// NOTE(review): not updated anywhere in this module.
pub peak_memory_bytes: usize,
/// NOTE(review): not updated anywhere in this module.
pub host_calls: u64,
/// NOTE(review): not updated anywhere in this module.
pub host_call_errors: u64,
}
/// Runtime-wide counters (snapshot returned by `get_runtime_stats`).
#[derive(Debug, Clone, Default)]
pub struct SandboxRuntimeStats {
/// Successful loads, including the reload performed by `hot_reload`.
pub sandboxes_created: u64,
/// Gauge of loaded sandboxes as tracked by the load/unload paths.
pub active_sandboxes: u64,
/// Invocations that passed the state checks in `invoke`.
pub total_invocations: u64,
/// Aggregate fuel counter. NOTE(review): confirm it is updated wherever fuel is charged.
pub total_fuel_consumed: u64,
/// NOTE(review): not updated anywhere in this module.
pub module_cache_hits: u64,
/// NOTE(review): not updated anywhere in this module.
pub module_cache_misses: u64,
/// Calls rejected because a sandbox was over its memory limit.
pub memory_violations: u64,
/// Calls rejected because a sandbox had no fuel left.
pub fuel_exhaustions: u64,
/// NOTE(review): not updated anywhere in this module.
pub capability_denials: u64,
}
impl WasmSandboxRuntime {
    /// Creates an empty runtime.
    ///
    /// `config` supplies the limits applied to every sandbox. `host_context` builds the
    /// per-plugin [`HostFunctionContext`] when a plugin is loaded.
    pub fn new(
        config: SandboxConfig,
        host_context: Arc<dyn HostContextProvider + Send + Sync>,
    ) -> Self {
        Self {
            config,
            sandboxes: RwLock::new(HashMap::new()),
            module_cache: RwLock::new(HashMap::new()),
            stats: RwLock::new(SandboxRuntimeStats::default()),
            host_context,
            shutdown: Mutex::new(false),
        }
    }

    /// Loads `plugin_id` from `wasm_bytes`, replacing any sandbox already registered
    /// under that id.
    ///
    /// A replaced sandbox is marked [`SandboxState::Terminated`]. A replacement does not
    /// add an active sandbox, so it is never rejected by `max_sandboxes`.
    ///
    /// # Errors
    ///
    /// * [`SandboxError::RuntimeShutdown`] after [`shutdown`](Self::shutdown).
    /// * [`SandboxError::TooManySandboxes`] when a new id would exceed `max_sandboxes`.
    /// * [`SandboxError::InvalidWasm`] when the module header is malformed.
    ///
    /// # Panics
    ///
    /// Panics if one of the runtime's locks is poisoned.
    pub fn load_plugin(
        &self,
        plugin_id: &str,
        wasm_bytes: &[u8],
        manifest: PluginManifest,
    ) -> Result<(), SandboxError> {
        self.ensure_running()?;
        // Fast pre-check so a full runtime reports TooManySandboxes before validation.
        // `install_sandbox` repeats the check under the write lock; that repeated check
        // is what actually enforces the limit.
        {
            let sandboxes = self.sandboxes.read().unwrap();
            if !sandboxes.contains_key(plugin_id) {
                self.check_capacity(sandboxes.len())?;
            }
        }
        self.validate_wasm(wasm_bytes)?;
        self.install_sandbox(plugin_id, wasm_bytes, manifest, false)
    }

    /// Invokes `function` with `args` inside the sandbox of `plugin_id`.
    ///
    /// Only one call per sandbox may run at a time. Per-sandbox and runtime stats are
    /// updated whether the call succeeds or fails.
    ///
    /// # Errors
    ///
    /// * [`SandboxError::PluginNotFound`] if `plugin_id` is not loaded.
    /// * [`SandboxError::SandboxTerminated`] / [`SandboxError::SandboxFailed`] for dead sandboxes.
    /// * [`SandboxError::AlreadyExecuting`] if another call is in progress.
    /// * Any error from the limit checks (fuel, memory).
    pub fn invoke(
        &self,
        plugin_id: &str,
        function: &str,
        args: &[SandboxValue],
    ) -> Result<Vec<SandboxValue>, SandboxError> {
        let sandbox = self.get_sandbox(plugin_id)?;
        // Check and claim the sandbox under ONE guard. With two separate locks, two
        // callers could both observe Ready and run concurrently.
        {
            let mut state = sandbox.state.lock().unwrap();
            match *state {
                SandboxState::Terminated => {
                    return Err(SandboxError::SandboxTerminated(plugin_id.to_string()));
                }
                SandboxState::Failed => {
                    return Err(SandboxError::SandboxFailed(plugin_id.to_string()));
                }
                SandboxState::Executing => {
                    return Err(SandboxError::AlreadyExecuting(plugin_id.to_string()));
                }
                SandboxState::Ready | SandboxState::Suspended => {
                    *state = SandboxState::Executing;
                }
            }
        }
        *sandbox.last_activity.lock().unwrap() = Instant::now();

        let start = Instant::now();
        let result = self.execute_with_limits(&sandbox, function, args);
        let elapsed = start.elapsed();

        {
            let mut stats = sandbox.stats.lock().unwrap();
            stats.total_calls += 1;
            stats.total_execution_time += elapsed;
            if result.is_ok() {
                stats.successful_calls += 1;
            } else {
                stats.failed_calls += 1;
            }
        }
        *sandbox.call_count.lock().unwrap() += 1;
        self.stats.write().unwrap().total_invocations += 1;

        // Only release our own claim. If the sandbox was terminated (unload, shutdown,
        // replacement) while the call ran, it must stay terminated rather than be
        // revived to Ready.
        {
            let mut state = sandbox.state.lock().unwrap();
            if *state == SandboxState::Executing {
                *state = SandboxState::Ready;
            }
        }
        result
    }

    /// Applies the fuel and memory limits, then runs the (placeholder) call.
    ///
    /// Fuel is charged atomically. A call never drives the budget below zero; it
    /// consumes at most what remains.
    fn execute_with_limits(
        &self,
        sandbox: &PluginSandbox,
        function: &str,
        args: &[SandboxValue],
    ) -> Result<Vec<SandboxValue>, SandboxError> {
        let fuel_available = *sandbox.fuel_remaining.lock().unwrap();
        if fuel_available == 0 {
            self.stats.write().unwrap().fuel_exhaustions += 1;
            // The budget is only ever set to `fuel_limit` (load and `reset_fuel`), so an
            // empty tank means that whole budget has been consumed.
            return Err(SandboxError::FuelExhausted {
                plugin_id: sandbox.plugin_id.clone(),
                consumed: self.config.fuel_limit,
            });
        }
        let memory_used = *sandbox.memory_used.lock().unwrap();
        if memory_used > self.config.max_memory_bytes {
            self.stats.write().unwrap().memory_violations += 1;
            return Err(SandboxError::MemoryLimitExceeded {
                plugin_id: sandbox.plugin_id.clone(),
                used: memory_used,
                limit: self.config.max_memory_bytes,
            });
        }
        // Placeholder cost model: no real WASM execution happens here yet.
        let requested = (function.len() as u64)
            .saturating_mul(1000)
            .saturating_add((args.len() as u64).saturating_mul(100));
        // Read and charge under one lock, so a concurrent charge cannot interleave and
        // underflow the counter.
        let charged = {
            let mut fuel = sandbox.fuel_remaining.lock().unwrap();
            let charged = requested.min(*fuel);
            *fuel -= charged;
            charged
        };
        sandbox.stats.lock().unwrap().fuel_consumed += charged;
        self.stats.write().unwrap().total_fuel_consumed += charged;
        Ok(vec![SandboxValue::I32(0)])
    }

    /// Returns a shared handle to the sandbox registered as `plugin_id`.
    fn get_sandbox(&self, plugin_id: &str) -> Result<Arc<PluginSandbox>, SandboxError> {
        self.sandboxes
            .read()
            .unwrap()
            .get(plugin_id)
            .cloned()
            .ok_or_else(|| SandboxError::PluginNotFound(plugin_id.to_string()))
    }

    /// Checks the 8-byte WASM header: `\0asm` magic followed by little-endian version 1.
    fn validate_wasm(&self, wasm_bytes: &[u8]) -> Result<(), SandboxError> {
        if wasm_bytes.len() < 8 {
            return Err(SandboxError::InvalidWasm("too short".to_string()));
        }
        if &wasm_bytes[0..4] != b"\0asm" {
            return Err(SandboxError::InvalidWasm(
                "invalid magic number".to_string(),
            ));
        }
        let version =
            u32::from_le_bytes([wasm_bytes[4], wasm_bytes[5], wasm_bytes[6], wasm_bytes[7]]);
        if version != 1 {
            return Err(SandboxError::InvalidWasm(format!(
                "unsupported version: {}",
                version
            )));
        }
        Ok(())
    }

    /// Hashes `bytes` with `DefaultHasher` (not stable across Rust releases; in-process use only).
    fn compute_hash(&self, bytes: &[u8]) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        bytes.hash(&mut hasher);
        hasher.finish()
    }

    /// Removes `plugin_id`, marks its sandbox terminated and drops its cached module.
    ///
    /// # Errors
    ///
    /// [`SandboxError::PluginNotFound`] if `plugin_id` is not loaded.
    pub fn unload_plugin(&self, plugin_id: &str) -> Result<(), SandboxError> {
        // Lock order used throughout: sandboxes -> module_cache -> stats.
        let mut sandboxes = self.sandboxes.write().unwrap();
        let sandbox = sandboxes
            .remove(plugin_id)
            .ok_or_else(|| SandboxError::PluginNotFound(plugin_id.to_string()))?;
        self.module_cache.write().unwrap().remove(plugin_id);
        {
            let mut stats = self.stats.write().unwrap();
            stats.active_sandboxes = stats.active_sandboxes.saturating_sub(1);
        }
        drop(sandboxes);
        *sandbox.state.lock().unwrap() = SandboxState::Terminated;
        Ok(())
    }

    /// Replaces the code and manifest of a loaded plugin.
    ///
    /// Waits (up to `config.host_timeout`) for an in-flight call to finish. It then swaps
    /// the new sandbox in atomically, so callers never observe the plugin as missing.
    /// Per-plugin stats start from zero.
    ///
    /// # Errors
    ///
    /// * [`SandboxError::InvalidWasm`] if the new module header is malformed.
    /// * [`SandboxError::PluginNotFound`] if `plugin_id` is not (or no longer) loaded.
    /// * [`SandboxError::Timeout`] if the running call does not finish in time.
    /// * [`SandboxError::RuntimeShutdown`] if the runtime shut down meanwhile.
    pub fn hot_reload(
        &self,
        plugin_id: &str,
        new_wasm_bytes: &[u8],
        new_manifest: PluginManifest,
    ) -> Result<(), SandboxError> {
        self.validate_wasm(new_wasm_bytes)?;
        let old_sandbox = self.get_sandbox(plugin_id)?;
        let wait_started = Instant::now();
        loop {
            let state = *old_sandbox.state.lock().unwrap();
            if state != SandboxState::Executing {
                break;
            }
            let waited = wait_started.elapsed();
            if waited >= self.config.host_timeout {
                return Err(SandboxError::Timeout {
                    plugin_id: plugin_id.to_string(),
                    elapsed: waited,
                });
            }
            std::thread::sleep(Duration::from_millis(10));
        }
        self.install_sandbox(plugin_id, new_wasm_bytes, new_manifest, true)
    }

    /// Returns a snapshot of the per-sandbox counters for `plugin_id`.
    pub fn get_plugin_stats(&self, plugin_id: &str) -> Result<SandboxStats, SandboxError> {
        let sandbox = self.get_sandbox(plugin_id)?;
        Ok(sandbox.stats.lock().unwrap().clone())
    }

    /// Returns a snapshot of the runtime-wide counters.
    pub fn get_runtime_stats(&self) -> SandboxRuntimeStats {
        self.stats.read().unwrap().clone()
    }

    /// Lists every loaded plugin with its current state and usage figures.
    pub fn list_plugins(&self) -> Vec<PluginInfo> {
        self.sandboxes
            .read()
            .unwrap()
            .values()
            .map(|s| PluginInfo {
                id: s.plugin_id.clone(),
                name: s.manifest.plugin.name.clone(),
                version: s.manifest.plugin.version.clone(),
                state: *s.state.lock().unwrap(),
                memory_used: *s.memory_used.lock().unwrap(),
                call_count: *s.call_count.lock().unwrap(),
                uptime: s.created_at.elapsed(),
            })
            .collect()
    }

    /// Refills the fuel budget of `plugin_id` to `config.fuel_limit`.
    pub fn reset_fuel(&self, plugin_id: &str) -> Result<(), SandboxError> {
        let sandbox = self.get_sandbox(plugin_id)?;
        *sandbox.fuel_remaining.lock().unwrap() = self.config.fuel_limit;
        Ok(())
    }

    /// Stops the runtime: further loads fail and every loaded sandbox is terminated and
    /// removed.
    pub fn shutdown(&self) {
        *self.shutdown.lock().unwrap() = true;
        // Drain under the write lock. `install_sandbox` re-checks the flag under this
        // same lock, so no load can slip in after the drain.
        let mut sandboxes = self.sandboxes.write().unwrap();
        for (_, sandbox) in sandboxes.drain() {
            *sandbox.state.lock().unwrap() = SandboxState::Terminated;
        }
        self.module_cache.write().unwrap().clear();
        self.stats.write().unwrap().active_sandboxes = 0;
    }

    /// Fails with `RuntimeShutdown` once `shutdown` has been called.
    fn ensure_running(&self) -> Result<(), SandboxError> {
        if *self.shutdown.lock().unwrap() {
            Err(SandboxError::RuntimeShutdown)
        } else {
            Ok(())
        }
    }

    /// Fails with `TooManySandboxes` if `current` loaded sandboxes leave no room for another.
    fn check_capacity(&self, current: usize) -> Result<(), SandboxError> {
        if current >= self.config.max_sandboxes {
            return Err(SandboxError::TooManySandboxes {
                current,
                max: self.config.max_sandboxes,
            });
        }
        Ok(())
    }

    /// Builds a fresh, `Ready` sandbox with a full fuel budget for `plugin_id`.
    fn build_sandbox(&self, plugin_id: &str, manifest: PluginManifest) -> PluginSandbox {
        let capabilities = manifest.capabilities.clone();
        let host_context = self.host_context.create_context(plugin_id, &capabilities);
        let now = Instant::now();
        PluginSandbox {
            plugin_id: plugin_id.to_string(),
            manifest,
            capabilities,
            memory_used: Mutex::new(0),
            fuel_remaining: Mutex::new(self.config.fuel_limit),
            call_count: Mutex::new(0),
            created_at: now,
            last_activity: Mutex::new(now),
            state: Mutex::new(SandboxState::Ready),
            host_context,
            stats: Mutex::new(SandboxStats::default()),
        }
    }

    /// Registers a new sandbox (and its module) under `plugin_id` atomically.
    ///
    /// With `require_existing` (hot reload), a missing id is `PluginNotFound`. Otherwise
    /// a missing id is a fresh load subject to `max_sandboxes`. Only fresh loads bump
    /// `active_sandboxes`; a replaced sandbox is marked `Terminated`.
    fn install_sandbox(
        &self,
        plugin_id: &str,
        wasm_bytes: &[u8],
        manifest: PluginManifest,
        require_existing: bool,
    ) -> Result<(), SandboxError> {
        let module = CompiledModule {
            wasm_bytes: wasm_bytes.to_vec(),
            compiled_at: Instant::now(),
            instantiation_count: 0,
            source_hash: self.compute_hash(wasm_bytes),
        };
        // Built before taking the write lock so the host provider is not called while
        // the registry is locked.
        let sandbox = Arc::new(self.build_sandbox(plugin_id, manifest));

        // Lock order used throughout: sandboxes -> module_cache -> stats.
        let mut sandboxes = self.sandboxes.write().unwrap();
        // Re-check under the write lock: shutdown() or a concurrent load may have run
        // since the caller's pre-checks.
        self.ensure_running()?;
        let replacing = sandboxes.contains_key(plugin_id);
        if !replacing {
            if require_existing {
                return Err(SandboxError::PluginNotFound(plugin_id.to_string()));
            }
            self.check_capacity(sandboxes.len())?;
        }
        let previous = sandboxes.insert(plugin_id.to_string(), sandbox);
        self.module_cache
            .write()
            .unwrap()
            .insert(plugin_id.to_string(), module);
        {
            let mut stats = self.stats.write().unwrap();
            stats.sandboxes_created += 1;
            if !replacing {
                stats.active_sandboxes += 1;
            }
        }
        if let Some(old) = previous {
            *old.state.lock().unwrap() = SandboxState::Terminated;
        }
        Ok(())
    }
}
/// A value passed into, or returned from, a sandboxed call.
#[derive(Debug, Clone, PartialEq)]
pub enum SandboxValue {
    I32(i32),
    I64(i64),
    F32(f32),
    F64(f64),
    Bytes(Vec<u8>),
    String(String),
}

impl SandboxValue {
    /// Returns the payload of an `I32`; every other variant yields `None` (no widening or
    /// narrowing).
    pub fn as_i32(&self) -> Option<i32> {
        if let Self::I32(value) = self {
            Some(*value)
        } else {
            None
        }
    }

    /// Returns the payload of an `I64`; every other variant yields `None`.
    pub fn as_i64(&self) -> Option<i64> {
        if let Self::I64(value) = self {
            Some(*value)
        } else {
            None
        }
    }

    /// Borrows the raw bytes of a `Bytes` value or the UTF-8 bytes of a `String` value.
    /// Numeric variants yield `None`.
    pub fn as_bytes(&self) -> Option<&[u8]> {
        match self {
            Self::Bytes(raw) => Some(raw.as_slice()),
            Self::String(text) => Some(text.as_bytes()),
            Self::I32(_) | Self::I64(_) | Self::F32(_) | Self::F64(_) => None,
        }
    }
}
/// Point-in-time summary of one loaded plugin, as produced by `list_plugins`.
#[derive(Debug, Clone)]
pub struct PluginInfo {
/// Id the plugin is registered under.
pub id: String,
/// `name` from the plugin's manifest metadata.
pub name: String,
/// `version` from the plugin's manifest metadata.
pub version: String,
/// Sandbox lifecycle state at the time of the snapshot.
pub state: SandboxState,
/// Sandbox memory counter at the time of the snapshot.
pub memory_used: usize,
/// Invocations that reached execution.
pub call_count: u64,
/// Time elapsed since the sandbox was created.
pub uptime: Duration,
}
/// Everything that can go wrong while loading, running or managing a sandboxed plugin.
#[derive(Debug, Clone)]
pub enum SandboxError {
    /// No sandbox is registered under this id.
    PluginNotFound(String),
    /// The module bytes failed header validation.
    InvalidWasm(String),
    /// The sandbox used more memory than allowed.
    MemoryLimitExceeded {
        plugin_id: String,
        used: usize,
        limit: usize,
    },
    /// The sandbox ran out of fuel.
    FuelExhausted { plugin_id: String, consumed: u64 },
    /// The plugin attempted something its capabilities do not grant.
    CapabilityDenied {
        plugin_id: String,
        capability: String,
    },
    /// Loading another sandbox would exceed the configured maximum.
    TooManySandboxes { current: usize, max: usize },
    /// The sandbox has been unloaded or shut down.
    SandboxTerminated(String),
    /// The sandbox is in the failed state.
    SandboxFailed(String),
    /// Another call is already running in this sandbox.
    AlreadyExecuting(String),
    /// The runtime has been shut down.
    RuntimeShutdown,
    /// The host backend reported an error.
    HostError(String),
    /// An operation did not finish in time.
    Timeout {
        plugin_id: String,
        elapsed: Duration,
    },
    /// The guest trapped.
    Trap { plugin_id: String, message: String },
}

impl std::fmt::Display for SandboxError {
    /// Renders a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PluginNotFound(id) => write!(f, "plugin not found: {}", id),
            Self::InvalidWasm(detail) => write!(f, "invalid WASM: {}", detail),
            Self::MemoryLimitExceeded {
                plugin_id: id,
                used: in_use,
                limit: cap,
            } => write!(f, "plugin {} exceeded memory limit: {} > {}", id, in_use, cap),
            Self::FuelExhausted {
                plugin_id: id,
                consumed: spent,
            } => write!(f, "plugin {} exhausted fuel after {}", id, spent),
            Self::CapabilityDenied {
                plugin_id: id,
                capability: denied,
            } => write!(f, "plugin {} denied capability: {}", id, denied),
            Self::TooManySandboxes { current: now, max: cap } => {
                write!(f, "too many sandboxes: {} >= {}", now, cap)
            }
            Self::SandboxTerminated(id) => write!(f, "sandbox terminated: {}", id),
            Self::SandboxFailed(id) => write!(f, "sandbox failed: {}", id),
            Self::AlreadyExecuting(id) => write!(f, "sandbox already executing: {}", id),
            Self::RuntimeShutdown => write!(f, "runtime is shutdown"),
            Self::HostError(detail) => write!(f, "host error: {}", detail),
            Self::Timeout {
                plugin_id: id,
                elapsed: waited,
            } => write!(f, "plugin {} timed out after {:?}", id, waited),
            Self::Trap {
                plugin_id: id,
                message: why,
            } => write!(f, "plugin {} trapped: {}", id, why),
        }
    }
}

impl std::error::Error for SandboxError {}
/// No-op host backend: reads and searches return empty results, writes succeed and logs
/// are discarded.
pub struct DefaultHostContextProvider;

impl HostContextProvider for DefaultHostContextProvider {
    /// Copies the manifest's capability grants into a `WasmPluginCapabilities`.
    ///
    /// The limits are fixed: 16 MiB of memory, 1,000,000 fuel and a 100 ms timeout.
    /// NOTE(review): these differ from `SandboxConfig::default()` (1e9 fuel, 30 s host timeout).
    fn create_context(
        &self,
        plugin_id: &str,
        capabilities: &ManifestCapabilities,
    ) -> HostFunctionContext {
        HostFunctionContext::new(
            plugin_id,
            WasmPluginCapabilities {
                can_read_table: capabilities.can_read_table.clone(),
                can_write_table: capabilities.can_write_table.clone(),
                can_vector_search: capabilities.can_vector_search,
                can_index_search: capabilities.can_index_search,
                can_call_plugin: capabilities.can_call_plugin.clone(),
                memory_limit_bytes: 16 * 1024 * 1024,
                fuel_limit: 1_000_000,
                timeout_ms: 100,
            },
        )
    }

    /// Always succeeds with an empty payload.
    fn read(&self, _ctx: &HostFunctionContext, _table: u32, _row: u64) -> HostCallResult {
        HostCallResult::Success(vec![])
    }

    /// Accepts and discards the data.
    fn write(
        &self,
        _ctx: &HostFunctionContext,
        _table: u32,
        _row: u64,
        _payload: &[u8],
    ) -> HostCallResult {
        HostCallResult::Ok
    }

    /// Always succeeds with an empty payload.
    fn vector_search(
        &self,
        _ctx: &HostFunctionContext,
        _index_name: &str,
        _query: &[f32],
        _limit: u32,
    ) -> HostCallResult {
        HostCallResult::Success(vec![])
    }

    /// Discards the message.
    fn log(&self, _ctx: &HostFunctionContext, _level: u8, _text: &str) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runtime with default limits, backed by the no-op host provider.
    fn create_test_runtime() -> WasmSandboxRuntime {
        runtime_with_config(SandboxConfig::default())
    }

    /// Runtime with custom limits, backed by the no-op host provider.
    fn runtime_with_config(config: SandboxConfig) -> WasmSandboxRuntime {
        WasmSandboxRuntime::new(config, Arc::new(DefaultHostContextProvider))
    }

    /// Default runtime with a single plugin already loaded under `id`.
    fn runtime_with_plugin(id: &str) -> WasmSandboxRuntime {
        let runtime = create_test_runtime();
        runtime
            .load_plugin(id, &create_valid_wasm(), create_test_manifest())
            .unwrap();
        runtime
    }

    /// Manifest with fixed metadata and default capabilities, resources, exports and hooks.
    fn create_test_manifest() -> PluginManifest {
        let metadata = crate::plugin_manifest::PluginMetadata {
            name: "test-plugin".to_string(),
            version: "1.0.0".to_string(),
            description: "Test plugin".to_string(),
            author: "Test Author".to_string(),
            license: Some("MIT".to_string()),
            homepage: None,
            repository: None,
            min_kernel_version: None,
        };
        PluginManifest {
            plugin: metadata,
            capabilities: crate::plugin_manifest::ManifestCapabilities::default(),
            resources: crate::plugin_manifest::ResourceLimits::default(),
            exports: crate::plugin_manifest::ExportedFunctions::default(),
            hooks: crate::plugin_manifest::TableHooks::default(),
            config_schema: None,
        }
    }

    /// Smallest module accepted by `validate_wasm`: the `\0asm` magic plus version 1.
    fn create_valid_wasm() -> Vec<u8> {
        b"\0asm\x01\x00\x00\x00".to_vec()
    }

    #[test]
    fn test_load_plugin() {
        let runtime = create_test_runtime();
        let outcome = runtime.load_plugin("test", &create_valid_wasm(), create_test_manifest());
        assert!(outcome.is_ok());
        let listed = runtime.list_plugins();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, "test");
    }

    #[test]
    fn test_load_invalid_wasm() {
        let runtime = create_test_runtime();
        let outcome = runtime.load_plugin("test", b"not wasm", create_test_manifest());
        assert!(matches!(outcome, Err(SandboxError::InvalidWasm(_))));
    }

    #[test]
    fn test_unload_plugin() {
        let runtime = runtime_with_plugin("test");
        assert_eq!(runtime.list_plugins().len(), 1);
        runtime.unload_plugin("test").unwrap();
        assert_eq!(runtime.list_plugins().len(), 0);
    }

    #[test]
    fn test_invoke_plugin() {
        let runtime = runtime_with_plugin("test");
        let outcome = runtime.invoke("test", "test_fn", &[SandboxValue::I32(42)]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_invoke_nonexistent() {
        let runtime = create_test_runtime();
        let outcome = runtime.invoke("nonexistent", "fn", &[]);
        assert!(matches!(outcome, Err(SandboxError::PluginNotFound(_))));
    }

    #[test]
    fn test_sandbox_limit() {
        let runtime = runtime_with_config(SandboxConfig {
            max_sandboxes: 2,
            ..Default::default()
        });
        let wasm = create_valid_wasm();
        for id in ["p1", "p2"] {
            runtime.load_plugin(id, &wasm, create_test_manifest()).unwrap();
        }
        let overflow = runtime.load_plugin("p3", &wasm, create_test_manifest());
        assert!(matches!(overflow, Err(SandboxError::TooManySandboxes { .. })));
    }

    #[test]
    fn test_runtime_stats() {
        let runtime = runtime_with_plugin("test");
        for function in ["fn1", "fn2"] {
            runtime.invoke("test", function, &[]).unwrap();
        }
        let totals = runtime.get_runtime_stats();
        assert_eq!(totals.sandboxes_created, 1);
        assert_eq!(totals.active_sandboxes, 1);
        assert_eq!(totals.total_invocations, 2);
    }

    #[test]
    fn test_plugin_stats() {
        let runtime = runtime_with_plugin("test");
        runtime.invoke("test", "fn", &[]).unwrap();
        let per_plugin = runtime.get_plugin_stats("test").unwrap();
        assert_eq!(per_plugin.total_calls, 1);
        assert_eq!(per_plugin.successful_calls, 1);
    }

    #[test]
    fn test_hot_reload() {
        let runtime = runtime_with_plugin("test");
        let outcome = runtime.hot_reload("test", &create_valid_wasm(), create_test_manifest());
        assert!(outcome.is_ok());
        let per_plugin = runtime.get_plugin_stats("test").unwrap();
        assert_eq!(per_plugin.total_calls, 0);
    }

    #[test]
    fn test_reset_fuel() {
        let runtime = runtime_with_plugin("test");
        runtime.invoke("test", "some_function", &[]).unwrap();
        runtime.reset_fuel("test").unwrap();
        assert!(runtime.invoke("test", "fn", &[]).is_ok());
    }

    #[test]
    fn test_shutdown() {
        let runtime = runtime_with_plugin("test");
        runtime.shutdown();
        assert_eq!(runtime.list_plugins().len(), 0);
        let refused = runtime.load_plugin("new", &create_valid_wasm(), create_test_manifest());
        assert!(matches!(refused, Err(SandboxError::RuntimeShutdown)));
    }

    #[test]
    fn test_sandbox_value() {
        let int_value = SandboxValue::I32(42);
        assert_eq!(int_value.as_i32(), Some(42));
        assert_eq!(int_value.as_i64(), None);
        let text_value = SandboxValue::String("hello".to_string());
        assert_eq!(text_value.as_bytes(), Some(b"hello".as_slice()));
    }
}