1use super::*;
7use crate::invocation_metrics::{InvocationMetricsBus, InvocationTimer};
8use crate::memory_tracking::MemoryTracker;
9use mockforge_plugin_core::{
10 PluginCapabilities, PluginContext, PluginHealth, PluginId, PluginMetrics, PluginResult,
11 PluginState,
12};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16use wasmtime::{Engine, Linker, Module, ResourceLimiter, Store};
17use wasmtime_wasi::{WasiCtx, WasiCtxBuilder};
18
19pub struct SandboxStoreData {
28 pub wasi: WasiCtx,
30 pub tracker: MemoryTracker,
33}
34
35fn make_store(engine: &Engine, max_memory_bytes: usize) -> Store<SandboxStoreData> {
39 let wasi = WasiCtxBuilder::new().inherit_stderr().inherit_stdout().build();
40 let tracker = MemoryTracker::with_byte_limit(max_memory_bytes);
41 let mut store = Store::new(engine, SandboxStoreData { wasi, tracker });
42 store.limiter(|d| &mut d.tracker as &mut dyn ResourceLimiter);
43 store
44}
45
46pub struct PluginSandbox {
48 engine: Option<Arc<Engine>>,
50 _config: PluginLoaderConfig,
52 active_sandboxes: RwLock<HashMap<PluginId, SandboxInstance>>,
54 metrics_bus: Arc<InvocationMetricsBus>,
58}
59
60impl PluginSandbox {
61 pub fn new(config: PluginLoaderConfig) -> Self {
63 let engine = Some(Arc::new(Engine::default()));
65
66 Self {
67 engine,
68 _config: config,
69 active_sandboxes: RwLock::new(HashMap::new()),
70 metrics_bus: Arc::new(InvocationMetricsBus::new()),
71 }
72 }
73
74 pub fn metrics_bus(&self) -> Arc<InvocationMetricsBus> {
80 self.metrics_bus.clone()
81 }
82
83 pub async fn create_plugin_instance(
85 &self,
86 context: &PluginLoadContext,
87 ) -> LoaderResult<PluginInstance> {
88 let plugin_id = &context.plugin_id;
89
90 {
92 let sandboxes = self.active_sandboxes.read().await;
93 if sandboxes.contains_key(plugin_id) {
94 return Err(PluginLoaderError::already_loaded(plugin_id.clone()));
95 }
96 }
97
98 let sandbox = if let Some(ref engine) = self.engine {
100 SandboxInstance::new(engine, context, self.metrics_bus.clone()).await?
101 } else {
102 SandboxInstance::stub_new(context, self.metrics_bus.clone()).await?
104 };
105
106 let mut sandboxes = self.active_sandboxes.write().await;
108 sandboxes.insert(plugin_id.clone(), sandbox);
109
110 let mut core_instance = PluginInstance::new(plugin_id.clone(), context.manifest.clone());
112 core_instance.set_state(PluginState::Ready);
113
114 Ok(core_instance)
115 }
116
117 pub async fn execute_plugin_function(
119 &self,
120 plugin_id: &PluginId,
121 function_name: &str,
122 context: &PluginContext,
123 input: &[u8],
124 ) -> LoaderResult<PluginResult<serde_json::Value>> {
125 let mut sandboxes = self.active_sandboxes.write().await;
126 let sandbox = sandboxes
127 .get_mut(plugin_id)
128 .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
129
130 sandbox.execute_function(function_name, context, input).await
131 }
132
133 pub async fn get_plugin_health(&self, plugin_id: &PluginId) -> LoaderResult<PluginHealth> {
135 let sandboxes = self.active_sandboxes.read().await;
136 let sandbox = sandboxes
137 .get(plugin_id)
138 .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
139
140 Ok(sandbox.get_health().await)
141 }
142
143 pub async fn destroy_sandbox(&self, plugin_id: &PluginId) -> LoaderResult<()> {
145 let mut sandboxes = self.active_sandboxes.write().await;
146 if let Some(mut sandbox) = sandboxes.remove(plugin_id) {
147 sandbox.destroy().await?;
148 }
149 Ok(())
150 }
151
152 pub async fn list_active_sandboxes(&self) -> Vec<PluginId> {
154 let sandboxes = self.active_sandboxes.read().await;
155 sandboxes.keys().cloned().collect()
156 }
157
158 pub async fn get_sandbox_resources(
160 &self,
161 plugin_id: &PluginId,
162 ) -> LoaderResult<SandboxResources> {
163 let sandboxes = self.active_sandboxes.read().await;
164 let sandbox = sandboxes
165 .get(plugin_id)
166 .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
167
168 Ok(sandbox.get_resources().await)
169 }
170
171 pub async fn check_sandbox_health(&self, plugin_id: &PluginId) -> LoaderResult<SandboxHealth> {
173 let sandboxes = self.active_sandboxes.read().await;
174 let sandbox = sandboxes
175 .get(plugin_id)
176 .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
177
178 Ok(sandbox.check_health().await)
179 }
180}
181
182pub struct SandboxInstance {
184 plugin_id: PluginId,
186 _module: Module,
188 store: Store<SandboxStoreData>,
190 linker: Linker<SandboxStoreData>,
192 resources: SandboxResources,
194 health: SandboxHealth,
196 limits: ExecutionLimits,
198 metrics_bus: Arc<InvocationMetricsBus>,
201}
202
203impl SandboxInstance {
204 async fn new(
206 engine: &Engine,
207 context: &PluginLoadContext,
208 metrics_bus: Arc<InvocationMetricsBus>,
209 ) -> LoaderResult<Self> {
210 let plugin_id = &context.plugin_id;
211
212 let module = Module::from_file(engine, &context.plugin_path)
214 .map_err(|e| PluginLoaderError::wasm(format!("Failed to load WASM module: {}", e)))?;
215
216 let plugin_capabilities = PluginCapabilities::default();
219 let limits = ExecutionLimits::from_capabilities(&plugin_capabilities);
220
221 let mut store = make_store(engine, limits.max_memory_bytes);
222
223 let linker = Linker::new(engine);
225
226 linker
232 .instantiate(&mut store, &module)
233 .map_err(|e| PluginLoaderError::wasm(format!("Failed to instantiate module: {}", e)))?;
234
235 Ok(Self {
236 plugin_id: plugin_id.clone(),
237 _module: module,
238 store,
239 linker,
240 resources: SandboxResources::default(),
241 health: SandboxHealth::healthy(),
242 limits,
243 metrics_bus,
244 })
245 }
246
247 async fn stub_new(
249 context: &PluginLoadContext,
250 metrics_bus: Arc<InvocationMetricsBus>,
251 ) -> LoaderResult<Self> {
252 let plugin_id = &context.plugin_id;
253
254 let engine = Engine::default();
256 let module = Module::new(&engine, [])
257 .map_err(|e| PluginLoaderError::wasm(format!("Failed to create stub module: {}", e)))?;
258
259 let plugin_capabilities = PluginCapabilities::default();
260 let limits = ExecutionLimits::from_capabilities(&plugin_capabilities);
261
262 let store = make_store(&engine, limits.max_memory_bytes);
263 let linker = Linker::new(&engine);
264
265 Ok(Self {
266 plugin_id: plugin_id.clone(),
267 _module: module,
268 store,
269 linker,
270 resources: SandboxResources::default(),
271 health: SandboxHealth::healthy(),
272 limits,
273 metrics_bus,
274 })
275 }
276
277 async fn execute_function(
279 &mut self,
280 function_name: &str,
281 context: &PluginContext,
282 input: &[u8],
283 ) -> LoaderResult<PluginResult<serde_json::Value>> {
284 self.resources.execution_count += 1;
286 self.resources.last_execution = chrono::Utc::now();
287
288 if self.resources.execution_count > self.limits.max_executions {
290 return Err(PluginLoaderError::resource_limit(format!(
291 "Maximum executions exceeded: {} allowed, {} used",
292 self.limits.max_executions, self.resources.execution_count
293 )));
294 }
295
296 let time_since_last = chrono::Utc::now().signed_duration_since(self.resources.created_at);
298 let time_since_last_std =
299 std::time::Duration::from_secs(time_since_last.num_seconds() as u64);
300 if time_since_last_std > self.limits.max_lifetime {
301 return Err(PluginLoaderError::resource_limit(format!(
302 "Maximum lifetime exceeded: {}s allowed, {}s used",
303 self.limits.max_lifetime.as_secs(),
304 time_since_last_std.as_secs()
305 )));
306 }
307
308 let timer = InvocationTimer::start(
313 self.metrics_bus.clone(),
314 self.plugin_id.clone(),
315 function_name.to_string(),
316 );
317
318 let start_time = std::time::Instant::now();
320
321 let func_lookup = self.linker.get(&mut self.store, "", function_name);
323 if func_lookup.is_none() {
324 self.resources.error_count += 1;
327 let err_msg = format!("Function '{}' not found", function_name);
328 timer.finish_failure(err_msg.clone(), self.resources.peak_memory_usage as u64);
329 return Err(PluginLoaderError::execution(err_msg));
330 }
331
332 let result = self.call_wasm_function(function_name, context, input).await;
334
335 let execution_time = start_time.elapsed();
337 self.resources.total_execution_time += execution_time;
338 self.resources.last_execution_time = execution_time;
339
340 if execution_time > self.resources.max_execution_time {
341 self.resources.max_execution_time = execution_time;
342 }
343
344 let peak_memory_bytes = self.store.data().tracker.peak_memory() as u64;
352 if (peak_memory_bytes as usize) > self.resources.peak_memory_usage {
353 self.resources.peak_memory_usage = peak_memory_bytes as usize;
354 }
355 self.resources.memory_usage = self.store.data().tracker.current_memory();
356
357 match result {
358 Ok(data) => {
359 self.resources.success_count += 1;
360 timer.finish_success(peak_memory_bytes);
361 Ok(PluginResult::success(data, execution_time.as_millis() as u64))
362 }
363 Err(e) => {
364 self.resources.error_count += 1;
365 timer.finish_failure(e.clone(), peak_memory_bytes);
366 Ok(PluginResult::failure(e, execution_time.as_millis() as u64))
367 }
368 }
369 }
370
371 async fn call_wasm_function(
373 &mut self,
374 function_name: &str,
375 context: &PluginContext,
376 input: &[u8],
377 ) -> Result<serde_json::Value, String> {
378 let context_json = serde_json::to_string(context)
380 .map_err(|e| format!("Failed to serialize context: {}", e))?;
381 let combined_input = format!("{}\n{}", context_json, String::from_utf8_lossy(input));
382
383 let func_extern = self
385 .linker
386 .get(&mut self.store, "", function_name)
387 .ok_or_else(|| format!("Function '{}' not found in WASM module", function_name))?;
388 let func = func_extern
389 .into_func()
390 .ok_or_else(|| format!("Export '{}' is not a function", function_name))?;
391
392 let input_bytes = combined_input.as_bytes();
394 let input_len = input_bytes.len() as i32;
395
396 let alloc_extern = self.linker.get(&mut self.store, "", "alloc").ok_or_else(|| {
398 "WASM module must export an 'alloc' function for memory allocation".to_string()
399 })?;
400 let alloc_func = alloc_extern
401 .into_func()
402 .ok_or_else(|| "Export 'alloc' is not a function".to_string())?;
403
404 let mut alloc_result = [wasmtime::Val::I32(0)];
405 alloc_func
406 .call(&mut self.store, &[wasmtime::Val::I32(input_len)], &mut alloc_result)
407 .map_err(|e| format!("Failed to allocate memory for input: {}", e))?;
408
409 let input_ptr = match alloc_result[0] {
410 wasmtime::Val::I32(ptr) => ptr,
411 _ => return Err("alloc function did not return a valid pointer".to_string()),
412 };
413
414 let memory_extern = self
416 .linker
417 .get(&mut self.store, "", "memory")
418 .ok_or_else(|| "WASM module must export a 'memory'".to_string())?;
419 let memory = memory_extern
420 .into_memory()
421 .ok_or_else(|| "Export 'memory' is not a memory".to_string())?;
422
423 memory
424 .write(&mut self.store, input_ptr as usize, input_bytes)
425 .map_err(|e| format!("Failed to write input to WASM memory: {}", e))?;
426
427 let mut func_result = [wasmtime::Val::I32(0), wasmtime::Val::I32(0)];
429 func.call(
430 &mut self.store,
431 &[wasmtime::Val::I32(input_ptr), wasmtime::Val::I32(input_len)],
432 &mut func_result,
433 )
434 .map_err(|e| format!("Failed to call WASM function '{}': {}", function_name, e))?;
435
436 let output_ptr = match func_result[0] {
438 wasmtime::Val::I32(ptr) => ptr,
439 _ => {
440 return Err(format!(
441 "Function '{}' did not return a valid output pointer",
442 function_name
443 ))
444 }
445 };
446
447 let output_len = match func_result[1] {
448 wasmtime::Val::I32(len) => len,
449 _ => {
450 return Err(format!(
451 "Function '{}' did not return a valid output length",
452 function_name
453 ))
454 }
455 };
456
457 let mut output_bytes = vec![0u8; output_len as usize];
459 memory
460 .read(&mut self.store, output_ptr as usize, &mut output_bytes)
461 .map_err(|e| format!("Failed to read output from WASM memory: {}", e))?;
462
463 if let Some(dealloc_extern) = self.linker.get(&mut self.store, "", "dealloc") {
465 if let Some(dealloc_func) = dealloc_extern.into_func() {
466 let _ = dealloc_func.call(
467 &mut self.store,
468 &[wasmtime::Val::I32(input_ptr), wasmtime::Val::I32(input_len)],
469 &mut [],
470 );
471 let _ = dealloc_func.call(
472 &mut self.store,
473 &[
474 wasmtime::Val::I32(output_ptr),
475 wasmtime::Val::I32(output_len),
476 ],
477 &mut [],
478 );
479 }
480 }
481
482 let output_str = String::from_utf8(output_bytes)
484 .map_err(|e| format!("Failed to convert output to string: {}", e))?;
485
486 serde_json::from_str(&output_str)
487 .map_err(|e| format!("Failed to parse WASM output as JSON: {}", e))
488 }
489
490 async fn get_health(&self) -> PluginHealth {
492 if self.health.is_healthy {
493 PluginHealth::healthy(
494 "Sandbox is healthy".to_string(),
495 PluginMetrics {
496 total_executions: self.resources.execution_count,
497 successful_executions: self.resources.success_count,
498 failed_executions: self.resources.error_count,
499 avg_execution_time_ms: self.resources.avg_execution_time_ms(),
500 max_execution_time_ms: self.resources.max_execution_time.as_millis() as u64,
501 memory_usage_bytes: self.resources.memory_usage,
502 peak_memory_usage_bytes: self.resources.peak_memory_usage,
503 },
504 )
505 } else {
506 PluginHealth::unhealthy(
507 PluginState::Error,
508 self.health.last_error.clone(),
509 PluginMetrics::default(),
510 )
511 }
512 }
513
514 async fn get_resources(&self) -> SandboxResources {
516 self.resources.clone()
517 }
518
519 async fn check_health(&self) -> SandboxHealth {
521 self.health.clone()
522 }
523
524 async fn destroy(&mut self) -> LoaderResult<()> {
526 self.health.is_healthy = false;
528 self.health.last_error = "Sandbox destroyed".to_string();
529 Ok(())
530 }
531}
532
533#[derive(Debug, Clone, Default)]
535pub struct SandboxResources {
536 pub execution_count: u64,
538 pub success_count: u64,
540 pub error_count: u64,
542 pub total_execution_time: std::time::Duration,
544 pub last_execution_time: std::time::Duration,
546 pub max_execution_time: std::time::Duration,
548 pub memory_usage: usize,
550 pub peak_memory_usage: usize,
552 pub created_at: chrono::DateTime<chrono::Utc>,
554 pub last_execution: chrono::DateTime<chrono::Utc>,
556}
557
558impl SandboxResources {
559 pub fn avg_execution_time_ms(&self) -> f64 {
561 if self.execution_count == 0 {
562 0.0
563 } else {
564 self.total_execution_time.as_millis() as f64 / self.execution_count as f64
565 }
566 }
567
568 pub fn success_rate(&self) -> f64 {
570 if self.execution_count == 0 {
571 0.0
572 } else {
573 (self.success_count as f64 / self.execution_count as f64) * 100.0
574 }
575 }
576
577 pub fn check_limits(&self, limits: &ExecutionLimits) -> bool {
579 self.execution_count <= limits.max_executions
580 && self.memory_usage <= limits.max_memory_bytes
581 && self.total_execution_time <= limits.max_total_time
582 }
583}
584
585#[derive(Debug, Clone)]
587pub struct SandboxHealth {
588 pub is_healthy: bool,
590 pub last_check: chrono::DateTime<chrono::Utc>,
592 pub last_error: String,
594 pub checks: Vec<HealthCheck>,
596}
597
598impl SandboxHealth {
599 pub fn healthy() -> Self {
601 Self {
602 is_healthy: true,
603 last_check: chrono::Utc::now(),
604 last_error: String::new(),
605 checks: Vec::new(),
606 }
607 }
608
609 pub fn unhealthy<S: Into<String>>(error: S) -> Self {
611 Self {
612 is_healthy: false,
613 last_check: chrono::Utc::now(),
614 last_error: error.into(),
615 checks: Vec::new(),
616 }
617 }
618
619 pub fn add_check(&mut self, check: HealthCheck) {
621 let failed = !check.passed;
622 let error_message = if failed {
623 Some(check.message.clone())
624 } else {
625 None
626 };
627
628 self.checks.push(check);
629 self.last_check = chrono::Utc::now();
630
631 if failed {
633 self.is_healthy = false;
634 if let Some(msg) = error_message {
635 self.last_error = msg;
636 }
637 }
638 }
639
640 pub async fn run_checks(&mut self, resources: &SandboxResources, limits: &ExecutionLimits) {
642 self.checks.clear();
643
644 let memory_check = if resources.memory_usage <= limits.max_memory_bytes {
646 HealthCheck::pass("Memory usage within limits")
647 } else {
648 HealthCheck::fail(format!(
649 "Memory usage {} exceeds limit {}",
650 resources.memory_usage, limits.max_memory_bytes
651 ))
652 };
653 self.add_check(memory_check);
654
655 let execution_check = if resources.execution_count <= limits.max_executions {
657 HealthCheck::pass("Execution count within limits")
658 } else {
659 HealthCheck::fail(format!(
660 "Execution count {} exceeds limit {}",
661 resources.execution_count, limits.max_executions
662 ))
663 };
664 self.add_check(execution_check);
665
666 let success_rate = resources.success_rate();
668 let success_check = if success_rate >= 90.0 {
669 HealthCheck::pass(format!("Success rate: {:.1}%", success_rate))
670 } else {
671 HealthCheck::fail(format!("Low success rate: {:.1}%", success_rate))
672 };
673 self.add_check(success_check);
674 }
675}
676
677#[derive(Debug, Clone)]
679pub struct HealthCheck {
680 pub name: String,
682 pub passed: bool,
684 pub message: String,
686 pub timestamp: chrono::DateTime<chrono::Utc>,
688}
689
690impl HealthCheck {
691 pub fn pass<S: Into<String>>(message: S) -> Self {
693 Self {
694 name: "health_check".to_string(),
695 passed: true,
696 message: message.into(),
697 timestamp: chrono::Utc::now(),
698 }
699 }
700
701 pub fn fail<S: Into<String>>(message: S) -> Self {
703 Self {
704 name: "health_check".to_string(),
705 passed: false,
706 message: message.into(),
707 timestamp: chrono::Utc::now(),
708 }
709 }
710}
711
712#[derive(Debug, Clone)]
714pub struct ExecutionLimits {
715 pub max_executions: u64,
717 pub max_total_time: std::time::Duration,
719 pub max_lifetime: std::time::Duration,
721 pub max_memory_bytes: usize,
723 pub max_cpu_time_per_execution: std::time::Duration,
725}
726
727impl Default for ExecutionLimits {
728 fn default() -> Self {
729 Self {
730 max_executions: 1000,
731 max_total_time: std::time::Duration::from_secs(300), max_lifetime: std::time::Duration::from_secs(3600), max_memory_bytes: 10 * 1024 * 1024, max_cpu_time_per_execution: std::time::Duration::from_secs(5),
735 }
736 }
737}
738
739impl ExecutionLimits {
740 pub fn from_capabilities(capabilities: &PluginCapabilities) -> Self {
742 Self {
743 max_executions: 10000, max_total_time: std::time::Duration::from_secs(600), max_lifetime: std::time::Duration::from_secs(86400), max_memory_bytes: capabilities.resources.max_memory_bytes,
747 max_cpu_time_per_execution: std::time::Duration::from_millis(
748 (capabilities.resources.max_cpu_percent * 1000.0) as u64,
749 ),
750 }
751 }
752}
753
#[cfg(test)]
mod tests {
    use super::*;

    /// Averages and success rate derived from fixed counters.
    #[test]
    fn test_sandbox_resources() {
        let resources = SandboxResources {
            execution_count: 10,
            success_count: 8,
            error_count: 2,
            total_execution_time: std::time::Duration::from_secs(1),
            ..SandboxResources::default()
        };

        let expected_avg_ms = 100.0;
        let expected_success_pct = 80.0;
        assert_eq!(resources.avg_execution_time_ms(), expected_avg_ms);
        assert_eq!(resources.success_rate(), expected_success_pct);
    }

    /// Default limits match the documented constants.
    #[test]
    fn test_execution_limits() {
        let ExecutionLimits {
            max_executions,
            max_memory_bytes,
            ..
        } = ExecutionLimits::default();

        assert_eq!(max_executions, 1000);
        assert_eq!(max_memory_bytes, 10 * 1024 * 1024);
    }

    /// A failing check flips a healthy status and records its message.
    #[test]
    fn test_health_checks() {
        let mut status = SandboxHealth::healthy();
        assert!(status.is_healthy);

        status.add_check(HealthCheck::fail("Test failure"));

        assert!(!status.is_healthy);
        assert_eq!(status.last_error, "Test failure");
    }

    /// A fresh registry has no sandboxes.
    #[tokio::test]
    async fn test_plugin_sandbox_creation() {
        let registry = PluginSandbox::new(PluginLoaderConfig::default());
        assert!(registry.list_active_sandboxes().await.is_empty());
    }
}