use super::*;
use crate::invocation_metrics::{InvocationMetricsBus, InvocationTimer};
use crate::memory_tracking::MemoryTracker;
use mockforge_plugin_core::{
PluginCapabilities, PluginContext, PluginHealth, PluginId, PluginMetrics, PluginResult,
PluginState,
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use wasmtime::{Engine, Linker, Module, ResourceLimiter, Store};
use wasmtime_wasi::{WasiCtx, WasiCtxBuilder};
/// Host-side state stored inside every sandbox [`Store`].
///
/// `make_store` fills it with a WASI context and a [`MemoryTracker`], and installs the tracker as
/// the store's [`ResourceLimiter`].
pub struct SandboxStoreData {
    /// WASI context; `make_store` builds it with stdout/stderr inherited from the host.
    pub wasi: WasiCtx,
    /// Resource limiter for the store, created with a byte limit; also queried for
    /// current/peak memory after each execution.
    pub tracker: MemoryTracker,
}
/// Creates a fresh [`Store`] for one sandbox.
///
/// The store carries a WASI context that inherits the host's stdout/stderr, plus a
/// [`MemoryTracker`] capped at `max_memory_bytes`. The tracker is registered as the store's
/// resource limiter, so memory growth is routed through it.
fn make_store(engine: &Engine, max_memory_bytes: usize) -> Store<SandboxStoreData> {
    let state = SandboxStoreData {
        wasi: WasiCtxBuilder::new().inherit_stderr().inherit_stdout().build(),
        tracker: MemoryTracker::with_byte_limit(max_memory_bytes),
    };
    let mut store = Store::new(engine, state);
    store.limiter(|data| &mut data.tracker as &mut dyn ResourceLimiter);
    store
}
/// Owns the wasmtime engine and the set of live, per-plugin sandboxes.
pub struct PluginSandbox {
    /// Engine used to compile plugin modules. `new` always sets this to `Some`; when it is
    /// `None`, `create_plugin_instance` falls back to stub sandboxes.
    engine: Option<Arc<Engine>>,
    /// Loader configuration; stored but not read by the methods in this impl.
    _config: PluginLoaderConfig,
    /// Live sandboxes keyed by plugin id.
    active_sandboxes: RwLock<HashMap<PluginId, SandboxInstance>>,
    /// Bus that invocation timers report to; a handle is shared with every sandbox.
    metrics_bus: Arc<InvocationMetricsBus>,
}
impl PluginSandbox {
    /// Creates a sandbox manager with a default wasmtime [`Engine`] and no active sandboxes.
    pub fn new(config: PluginLoaderConfig) -> Self {
        let engine = Some(Arc::new(Engine::default()));
        Self {
            engine,
            _config: config,
            active_sandboxes: RwLock::new(HashMap::new()),
            metrics_bus: Arc::new(InvocationMetricsBus::new()),
        }
    }

    /// Returns a shared handle to the bus that receives per-invocation metrics.
    pub fn metrics_bus(&self) -> Arc<InvocationMetricsBus> {
        self.metrics_bus.clone()
    }

    /// Builds a sandbox for the plugin described by `context` and registers it.
    ///
    /// Returns a core [`PluginInstance`] already set to [`PluginState::Ready`].
    ///
    /// # Errors
    /// Returns an "already loaded" error if a sandbox for `context.plugin_id` exists, either
    /// before compilation starts or by the time it finishes. Otherwise propagates any error from
    /// building the sandbox.
    pub async fn create_plugin_instance(
        &self,
        context: &PluginLoadContext,
    ) -> LoaderResult<PluginInstance> {
        let plugin_id = &context.plugin_id;

        // Cheap early rejection under a read lock, so we don't compile a module for an id that
        // is already loaded.
        {
            let sandboxes = self.active_sandboxes.read().await;
            if sandboxes.contains_key(plugin_id) {
                return Err(PluginLoaderError::already_loaded(plugin_id.clone()));
            }
        }

        // Built without holding any lock, because compilation can be slow.
        let sandbox = if let Some(ref engine) = self.engine {
            SandboxInstance::new(engine, context, self.metrics_bus.clone()).await?
        } else {
            SandboxInstance::stub_new(context, self.metrics_bus.clone()).await?
        };

        {
            let mut sandboxes = self.active_sandboxes.write().await;
            // Re-check under the write lock. A concurrent call for the same id may have inserted
            // its sandbox while we were compiling, and a blind insert would silently replace it.
            if sandboxes.contains_key(plugin_id) {
                return Err(PluginLoaderError::already_loaded(plugin_id.clone()));
            }
            sandboxes.insert(plugin_id.clone(), sandbox);
        }

        let mut core_instance = PluginInstance::new(plugin_id.clone(), context.manifest.clone());
        core_instance.set_state(PluginState::Ready);
        Ok(core_instance)
    }

    /// Runs `function_name` inside the plugin's sandbox.
    ///
    /// The write lock on the sandbox map is held for the whole call, so every other operation
    /// on this manager waits until the guest function returns.
    ///
    /// # Errors
    /// Returns "not found" if no sandbox exists for `plugin_id`. Otherwise returns whatever the
    /// sandbox's execution reports.
    pub async fn execute_plugin_function(
        &self,
        plugin_id: &PluginId,
        function_name: &str,
        context: &PluginContext,
        input: &[u8],
    ) -> LoaderResult<PluginResult<serde_json::Value>> {
        let mut sandboxes = self.active_sandboxes.write().await;
        let sandbox = sandboxes
            .get_mut(plugin_id)
            .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
        sandbox.execute_function(function_name, context, input).await
    }

    /// Returns the plugin-level health (including execution metrics) of a sandbox.
    ///
    /// # Errors
    /// Returns "not found" if no sandbox exists for `plugin_id`.
    pub async fn get_plugin_health(&self, plugin_id: &PluginId) -> LoaderResult<PluginHealth> {
        let sandboxes = self.active_sandboxes.read().await;
        let sandbox = sandboxes
            .get(plugin_id)
            .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
        Ok(sandbox.get_health().await)
    }

    /// Removes a sandbox and marks it destroyed. Unknown ids are ignored.
    pub async fn destroy_sandbox(&self, plugin_id: &PluginId) -> LoaderResult<()> {
        let mut sandboxes = self.active_sandboxes.write().await;
        if let Some(mut sandbox) = sandboxes.remove(plugin_id) {
            sandbox.destroy().await?;
        }
        Ok(())
    }

    /// Lists the ids of all currently registered sandboxes (in arbitrary order).
    pub async fn list_active_sandboxes(&self) -> Vec<PluginId> {
        let sandboxes = self.active_sandboxes.read().await;
        sandboxes.keys().cloned().collect()
    }

    /// Returns a snapshot of a sandbox's resource counters.
    ///
    /// # Errors
    /// Returns "not found" if no sandbox exists for `plugin_id`.
    pub async fn get_sandbox_resources(
        &self,
        plugin_id: &PluginId,
    ) -> LoaderResult<SandboxResources> {
        let sandboxes = self.active_sandboxes.read().await;
        let sandbox = sandboxes
            .get(plugin_id)
            .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
        Ok(sandbox.get_resources().await)
    }

    /// Returns a snapshot of a sandbox's internal health record.
    ///
    /// # Errors
    /// Returns "not found" if no sandbox exists for `plugin_id`.
    pub async fn check_sandbox_health(&self, plugin_id: &PluginId) -> LoaderResult<SandboxHealth> {
        let sandboxes = self.active_sandboxes.read().await;
        let sandbox = sandboxes
            .get(plugin_id)
            .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
        Ok(sandbox.check_health().await)
    }
}
/// One loaded plugin: its compiled module, wasmtime store and linker, usage counters and limits.
pub struct SandboxInstance {
    /// Id of the plugin this sandbox runs; reported with every invocation timer.
    plugin_id: PluginId,
    /// Compiled module, kept alive for the lifetime of the sandbox.
    _module: Module,
    /// Store holding guest state plus [`SandboxStoreData`].
    store: Store<SandboxStoreData>,
    /// Linker through which guest exports (the called function, `alloc`, `dealloc`, `memory`)
    /// are looked up by name under module name `""`.
    linker: Linker<SandboxStoreData>,
    /// Execution counters and timing/memory statistics.
    resources: SandboxResources,
    /// Health record reported by `get_health` / `check_health`.
    health: SandboxHealth,
    /// Limits checked on each `execute_function` call.
    limits: ExecutionLimits,
    /// Bus that invocation timers report to.
    metrics_bus: Arc<InvocationMetricsBus>,
}
impl SandboxInstance {
async fn new(
engine: &Engine,
context: &PluginLoadContext,
metrics_bus: Arc<InvocationMetricsBus>,
) -> LoaderResult<Self> {
let plugin_id = &context.plugin_id;
let module = Module::from_file(engine, &context.plugin_path)
.map_err(|e| PluginLoaderError::wasm(format!("Failed to load WASM module: {}", e)))?;
let plugin_capabilities = PluginCapabilities::default();
let limits = ExecutionLimits::from_capabilities(&plugin_capabilities);
let mut store = make_store(engine, limits.max_memory_bytes);
let linker = Linker::new(engine);
linker
.instantiate(&mut store, &module)
.map_err(|e| PluginLoaderError::wasm(format!("Failed to instantiate module: {}", e)))?;
Ok(Self {
plugin_id: plugin_id.clone(),
_module: module,
store,
linker,
resources: SandboxResources::default(),
health: SandboxHealth::healthy(),
limits,
metrics_bus,
})
}
async fn stub_new(
context: &PluginLoadContext,
metrics_bus: Arc<InvocationMetricsBus>,
) -> LoaderResult<Self> {
let plugin_id = &context.plugin_id;
let engine = Engine::default();
let module = Module::new(&engine, [])
.map_err(|e| PluginLoaderError::wasm(format!("Failed to create stub module: {}", e)))?;
let plugin_capabilities = PluginCapabilities::default();
let limits = ExecutionLimits::from_capabilities(&plugin_capabilities);
let store = make_store(&engine, limits.max_memory_bytes);
let linker = Linker::new(&engine);
Ok(Self {
plugin_id: plugin_id.clone(),
_module: module,
store,
linker,
resources: SandboxResources::default(),
health: SandboxHealth::healthy(),
limits,
metrics_bus,
})
}
async fn execute_function(
&mut self,
function_name: &str,
context: &PluginContext,
input: &[u8],
) -> LoaderResult<PluginResult<serde_json::Value>> {
self.resources.execution_count += 1;
self.resources.last_execution = chrono::Utc::now();
if self.resources.execution_count > self.limits.max_executions {
return Err(PluginLoaderError::resource_limit(format!(
"Maximum executions exceeded: {} allowed, {} used",
self.limits.max_executions, self.resources.execution_count
)));
}
let time_since_last = chrono::Utc::now().signed_duration_since(self.resources.created_at);
let time_since_last_std =
std::time::Duration::from_secs(time_since_last.num_seconds() as u64);
if time_since_last_std > self.limits.max_lifetime {
return Err(PluginLoaderError::resource_limit(format!(
"Maximum lifetime exceeded: {}s allowed, {}s used",
self.limits.max_lifetime.as_secs(),
time_since_last_std.as_secs()
)));
}
let timer = InvocationTimer::start(
self.metrics_bus.clone(),
self.plugin_id.clone(),
function_name.to_string(),
);
let start_time = std::time::Instant::now();
let func_lookup = self.linker.get(&mut self.store, "", function_name);
if func_lookup.is_none() {
self.resources.error_count += 1;
let err_msg = format!("Function '{}' not found", function_name);
timer.finish_failure(err_msg.clone(), self.resources.peak_memory_usage as u64);
return Err(PluginLoaderError::execution(err_msg));
}
let result = self.call_wasm_function(function_name, context, input).await;
let execution_time = start_time.elapsed();
self.resources.total_execution_time += execution_time;
self.resources.last_execution_time = execution_time;
if execution_time > self.resources.max_execution_time {
self.resources.max_execution_time = execution_time;
}
let peak_memory_bytes = self.store.data().tracker.peak_memory() as u64;
if (peak_memory_bytes as usize) > self.resources.peak_memory_usage {
self.resources.peak_memory_usage = peak_memory_bytes as usize;
}
self.resources.memory_usage = self.store.data().tracker.current_memory();
match result {
Ok(data) => {
self.resources.success_count += 1;
timer.finish_success(peak_memory_bytes);
Ok(PluginResult::success(data, execution_time.as_millis() as u64))
}
Err(e) => {
self.resources.error_count += 1;
timer.finish_failure(e.clone(), peak_memory_bytes);
Ok(PluginResult::failure(e, execution_time.as_millis() as u64))
}
}
}
async fn call_wasm_function(
&mut self,
function_name: &str,
context: &PluginContext,
input: &[u8],
) -> Result<serde_json::Value, String> {
let context_json = serde_json::to_string(context)
.map_err(|e| format!("Failed to serialize context: {}", e))?;
let combined_input = format!("{}\n{}", context_json, String::from_utf8_lossy(input));
let func_extern = self
.linker
.get(&mut self.store, "", function_name)
.ok_or_else(|| format!("Function '{}' not found in WASM module", function_name))?;
let func = func_extern
.into_func()
.ok_or_else(|| format!("Export '{}' is not a function", function_name))?;
let input_bytes = combined_input.as_bytes();
let input_len = input_bytes.len() as i32;
let alloc_extern = self.linker.get(&mut self.store, "", "alloc").ok_or_else(|| {
"WASM module must export an 'alloc' function for memory allocation".to_string()
})?;
let alloc_func = alloc_extern
.into_func()
.ok_or_else(|| "Export 'alloc' is not a function".to_string())?;
let mut alloc_result = [wasmtime::Val::I32(0)];
alloc_func
.call(&mut self.store, &[wasmtime::Val::I32(input_len)], &mut alloc_result)
.map_err(|e| format!("Failed to allocate memory for input: {}", e))?;
let input_ptr = match alloc_result[0] {
wasmtime::Val::I32(ptr) => ptr,
_ => return Err("alloc function did not return a valid pointer".to_string()),
};
let memory_extern = self
.linker
.get(&mut self.store, "", "memory")
.ok_or_else(|| "WASM module must export a 'memory'".to_string())?;
let memory = memory_extern
.into_memory()
.ok_or_else(|| "Export 'memory' is not a memory".to_string())?;
memory
.write(&mut self.store, input_ptr as usize, input_bytes)
.map_err(|e| format!("Failed to write input to WASM memory: {}", e))?;
let mut func_result = [wasmtime::Val::I32(0), wasmtime::Val::I32(0)];
func.call(
&mut self.store,
&[wasmtime::Val::I32(input_ptr), wasmtime::Val::I32(input_len)],
&mut func_result,
)
.map_err(|e| format!("Failed to call WASM function '{}': {}", function_name, e))?;
let output_ptr = match func_result[0] {
wasmtime::Val::I32(ptr) => ptr,
_ => {
return Err(format!(
"Function '{}' did not return a valid output pointer",
function_name
))
}
};
let output_len = match func_result[1] {
wasmtime::Val::I32(len) => len,
_ => {
return Err(format!(
"Function '{}' did not return a valid output length",
function_name
))
}
};
let mut output_bytes = vec![0u8; output_len as usize];
memory
.read(&mut self.store, output_ptr as usize, &mut output_bytes)
.map_err(|e| format!("Failed to read output from WASM memory: {}", e))?;
if let Some(dealloc_extern) = self.linker.get(&mut self.store, "", "dealloc") {
if let Some(dealloc_func) = dealloc_extern.into_func() {
let _ = dealloc_func.call(
&mut self.store,
&[wasmtime::Val::I32(input_ptr), wasmtime::Val::I32(input_len)],
&mut [],
);
let _ = dealloc_func.call(
&mut self.store,
&[
wasmtime::Val::I32(output_ptr),
wasmtime::Val::I32(output_len),
],
&mut [],
);
}
}
let output_str = String::from_utf8(output_bytes)
.map_err(|e| format!("Failed to convert output to string: {}", e))?;
serde_json::from_str(&output_str)
.map_err(|e| format!("Failed to parse WASM output as JSON: {}", e))
}
async fn get_health(&self) -> PluginHealth {
if self.health.is_healthy {
PluginHealth::healthy(
"Sandbox is healthy".to_string(),
PluginMetrics {
total_executions: self.resources.execution_count,
successful_executions: self.resources.success_count,
failed_executions: self.resources.error_count,
avg_execution_time_ms: self.resources.avg_execution_time_ms(),
max_execution_time_ms: self.resources.max_execution_time.as_millis() as u64,
memory_usage_bytes: self.resources.memory_usage,
peak_memory_usage_bytes: self.resources.peak_memory_usage,
},
)
} else {
PluginHealth::unhealthy(
PluginState::Error,
self.health.last_error.clone(),
PluginMetrics::default(),
)
}
}
async fn get_resources(&self) -> SandboxResources {
self.resources.clone()
}
async fn check_health(&self) -> SandboxHealth {
self.health.clone()
}
async fn destroy(&mut self) -> LoaderResult<()> {
self.health.is_healthy = false;
self.health.last_error = "Sandbox destroyed".to_string();
Ok(())
}
}
/// Usage counters and timing/memory statistics for one sandbox.
///
/// Note: the derived `Default` sets `created_at` and `last_execution` to chrono's default
/// `DateTime<Utc>` (the Unix epoch), not the current time. Set them explicitly when describing a
/// live sandbox.
#[derive(Debug, Clone, Default)]
pub struct SandboxResources {
    /// Number of executions recorded.
    pub execution_count: u64,
    /// Executions whose guest call returned a value.
    pub success_count: u64,
    /// Executions that failed (missing function or guest-side error).
    pub error_count: u64,
    /// Sum of measured execution wall-clock times.
    pub total_execution_time: std::time::Duration,
    /// Wall-clock time of the most recent measured execution.
    pub last_execution_time: std::time::Duration,
    /// Longest measured execution.
    pub max_execution_time: std::time::Duration,
    /// Current memory usage in bytes, as reported by the store's memory tracker.
    pub memory_usage: usize,
    /// Highest memory usage in bytes observed so far.
    pub peak_memory_usage: usize,
    /// When the sandbox was created; the lifetime limit is measured from here.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Timestamp of the most recent execution attempt.
    pub last_execution: chrono::DateTime<chrono::Utc>,
}
impl SandboxResources {
    /// Mean wall-clock time per recorded execution, in milliseconds (0.0 before any execution).
    ///
    /// Computed from nanoseconds so sub-millisecond executions are not truncated to zero, which
    /// whole-millisecond arithmetic would do for typical fast WASM calls.
    pub fn avg_execution_time_ms(&self) -> f64 {
        if self.execution_count == 0 {
            return 0.0;
        }
        let total_ms = self.total_execution_time.as_nanos() as f64 / 1_000_000.0;
        total_ms / self.execution_count as f64
    }

    /// Percentage (0–100) of recorded executions that succeeded; 0.0 before any execution.
    pub fn success_rate(&self) -> f64 {
        if self.execution_count == 0 {
            0.0
        } else {
            (self.success_count as f64 / self.execution_count as f64) * 100.0
        }
    }

    /// Returns `true` when execution count, memory usage and total execution time are all
    /// within `limits` (each bound inclusive).
    pub fn check_limits(&self, limits: &ExecutionLimits) -> bool {
        self.execution_count <= limits.max_executions
            && self.memory_usage <= limits.max_memory_bytes
            && self.total_execution_time <= limits.max_total_time
    }
}
/// Health record for one sandbox: an overall flag plus the checks that produced it.
#[derive(Debug, Clone)]
pub struct SandboxHealth {
    /// `false` once any check has failed or the sandbox has been destroyed.
    pub is_healthy: bool,
    /// When the record was created or a check was last added.
    pub last_check: chrono::DateTime<chrono::Utc>,
    /// Message of the most recent failure (empty while healthy).
    pub last_error: String,
    /// Checks recorded since the last `run_checks`.
    pub checks: Vec<HealthCheck>,
}
impl SandboxHealth {
    /// Creates a healthy record with no checks, stamped with the current time.
    pub fn healthy() -> Self {
        Self {
            is_healthy: true,
            last_check: chrono::Utc::now(),
            last_error: String::new(),
            checks: Vec::new(),
        }
    }

    /// Creates an unhealthy record whose last error is `error`.
    pub fn unhealthy<S: Into<String>>(error: S) -> Self {
        Self {
            is_healthy: false,
            last_check: chrono::Utc::now(),
            last_error: error.into(),
            checks: Vec::new(),
        }
    }

    /// Appends `check` and refreshes `last_check`.
    ///
    /// A failing check marks the record unhealthy and becomes `last_error`. A passing check never
    /// flips an unhealthy record back to healthy.
    pub fn add_check(&mut self, check: HealthCheck) {
        if !check.passed {
            self.is_healthy = false;
            self.last_error = check.message.clone();
        }
        self.checks.push(check);
        self.last_check = chrono::Utc::now();
    }

    /// Re-runs the memory, execution-count and success-rate checks against `limits`.
    ///
    /// Previously recorded checks are discarded. `is_healthy` and `last_error` are *not* reset,
    /// so a record that is already unhealthy stays unhealthy.
    ///
    /// A sandbox with no executions yet passes the success-rate check. `success_rate()` reports
    /// 0% in that case, which would otherwise mark every fresh sandbox unhealthy.
    pub async fn run_checks(&mut self, resources: &SandboxResources, limits: &ExecutionLimits) {
        self.checks.clear();

        let memory_check = if resources.memory_usage <= limits.max_memory_bytes {
            HealthCheck::pass("Memory usage within limits")
        } else {
            HealthCheck::fail(format!(
                "Memory usage {} exceeds limit {}",
                resources.memory_usage, limits.max_memory_bytes
            ))
        };
        self.add_check(memory_check);

        let execution_check = if resources.execution_count <= limits.max_executions {
            HealthCheck::pass("Execution count within limits")
        } else {
            HealthCheck::fail(format!(
                "Execution count {} exceeds limit {}",
                resources.execution_count, limits.max_executions
            ))
        };
        self.add_check(execution_check);

        let success_check = if resources.execution_count == 0 {
            HealthCheck::pass("No executions yet")
        } else {
            let success_rate = resources.success_rate();
            if success_rate >= 90.0 {
                HealthCheck::pass(format!("Success rate: {:.1}%", success_rate))
            } else {
                HealthCheck::fail(format!("Low success rate: {:.1}%", success_rate))
            }
        };
        self.add_check(success_check);
    }
}
/// Outcome of a single health check.
#[derive(Debug, Clone)]
pub struct HealthCheck {
    /// Check name; the `pass`/`fail` constructors always use `"health_check"`.
    pub name: String,
    /// Whether the check succeeded.
    pub passed: bool,
    /// Human-readable description of the outcome.
    pub message: String,
    /// When the check was constructed.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
impl HealthCheck {
    /// Builds a passing check named `"health_check"` carrying `message`.
    pub fn pass<S: Into<String>>(message: S) -> Self {
        Self::with_outcome(true, message.into())
    }

    /// Builds a failing check named `"health_check"` carrying `message`.
    pub fn fail<S: Into<String>>(message: S) -> Self {
        Self::with_outcome(false, message.into())
    }

    /// Shared constructor: every check gets the generic name and the current timestamp.
    fn with_outcome(passed: bool, message: String) -> Self {
        Self {
            name: String::from("health_check"),
            passed,
            message,
            timestamp: chrono::Utc::now(),
        }
    }
}
/// Per-sandbox execution limits.
#[derive(Debug, Clone)]
pub struct ExecutionLimits {
    /// Maximum number of executions over the sandbox's lifetime.
    pub max_executions: u64,
    /// Cap on cumulative execution time, compared by `SandboxResources::check_limits`.
    pub max_total_time: std::time::Duration,
    /// Maximum sandbox age, measured from `SandboxResources::created_at`.
    pub max_lifetime: std::time::Duration,
    /// Memory cap in bytes; also handed to the store's memory tracker via `make_store`.
    pub max_memory_bytes: usize,
    /// Intended per-execution CPU budget.
    /// NOTE(review): not read by any code in this file; confirm where (if anywhere) it is enforced.
    pub max_cpu_time_per_execution: std::time::Duration,
}
impl Default for ExecutionLimits {
    /// Conservative standalone limits: 1000 executions, 5 minutes of total execution time,
    /// a one-hour lifetime, 10 MiB of memory and 5 seconds per execution.
    fn default() -> Self {
        use std::time::Duration;
        const MIB: usize = 1024 * 1024;
        Self {
            max_executions: 1000,
            max_total_time: Duration::from_secs(5 * 60),
            max_lifetime: Duration::from_secs(60 * 60),
            max_memory_bytes: 10 * MIB,
            max_cpu_time_per_execution: Duration::from_secs(5),
        }
    }
}
impl ExecutionLimits {
    /// Derives limits from a plugin's declared capabilities.
    ///
    /// Memory comes straight from `capabilities.resources.max_memory_bytes`. The per-execution
    /// CPU budget is `max_cpu_percent * 1000` milliseconds. The remaining limits are fixed:
    /// 10 000 executions, 10 minutes of total execution time and a 24-hour lifetime.
    pub fn from_capabilities(capabilities: &PluginCapabilities) -> Self {
        use std::time::Duration;
        let resources = &capabilities.resources;
        // NOTE(review): treats `max_cpu_percent` as seconds-per-unit (e.g. 0.5 -> 500 ms);
        // confirm that matches the capability's intended units.
        let cpu_budget_ms = (resources.max_cpu_percent * 1000.0) as u64;
        Self {
            max_executions: 10_000,
            max_total_time: Duration::from_secs(10 * 60),
            max_lifetime: Duration::from_secs(24 * 60 * 60),
            max_memory_bytes: resources.max_memory_bytes,
            max_cpu_time_per_execution: Duration::from_millis(cpu_budget_ms),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Derived statistics follow directly from the raw counters.
    #[tokio::test]
    async fn test_sandbox_resources() {
        let stats = SandboxResources {
            total_execution_time: std::time::Duration::from_millis(1000),
            execution_count: 10,
            success_count: 8,
            error_count: 2,
            ..Default::default()
        };
        let (avg_ms, rate) = (stats.avg_execution_time_ms(), stats.success_rate());
        assert_eq!(avg_ms, 100.0);
        assert_eq!(rate, 80.0);
    }

    /// The `Default` limits allow 1000 executions and 10 MiB of memory.
    #[tokio::test]
    async fn test_execution_limits() {
        let ExecutionLimits {
            max_executions,
            max_memory_bytes,
            ..
        } = ExecutionLimits::default();
        assert_eq!(max_executions, 1000);
        assert_eq!(max_memory_bytes, 10 * 1024 * 1024);
    }

    /// A single failing check flips the record to unhealthy and records its message.
    #[tokio::test]
    async fn test_health_checks() {
        let mut record = SandboxHealth::healthy();
        assert!(record.is_healthy);
        record.add_check(HealthCheck::fail("Test failure"));
        assert!(!record.is_healthy);
        assert_eq!(record.last_error, "Test failure");
    }

    /// A freshly constructed manager has no sandboxes.
    #[tokio::test]
    async fn test_plugin_sandbox_creation() {
        let manager = PluginSandbox::new(PluginLoaderConfig::default());
        assert!(manager.list_active_sandboxes().await.is_empty());
    }
}