use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::SystemTime;
use anyhow::{Context, Result};
use wasmtime::{Config, Engine, Linker, Module, Store};
use crate::types::{PluginInput, PluginOutput};
/// Resource budget shared by every plugin invocation.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Memory budget for a plugin instance, in bytes (default: 256 MiB).
    pub max_memory: usize,
    /// Execution budget. `Plugin::execute` converts it to wasmtime fuel at
    /// 1,000,000 units per "second", so it bounds executed work rather than
    /// wall-clock time (default: 30).
    pub max_time_secs: u64,
}

impl Default for RuntimeConfig {
    /// 256 MiB of memory and a budget of 30 fuel-"seconds".
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        Self {
            max_time_secs: 30,
            max_memory: 256 * MIB,
        }
    }
}
/// Statically checks that `bytes` is a self-contained plugin module.
///
/// A valid plugin:
/// * compiles as WebAssembly,
/// * declares no imports (plugins are instantiated against an empty [`Linker`]
///   in [`Plugin::execute`], so no import could ever be satisfied),
/// * exports a linear memory named `memory`, and
/// * exports functions named `alloc` and `process`.
///
/// Only export *kinds* are checked here; exact function signatures are checked
/// when [`Plugin::execute`] looks up the typed functions.
///
/// # Errors
/// Returns an error if the bytes do not compile, if the module has any import
/// (the message names the first one), or if a required export is missing or
/// is not of the required kind.
pub fn validate_plugin_module(bytes: &[u8]) -> Result<()> {
    let engine = Engine::default();
    let module = Module::new(&engine, bytes)?;

    if let Some(import) = module.imports().next() {
        anyhow::bail!(
            "plugin has forbidden import: {}::{}",
            import.module(),
            import.name()
        );
    }

    // Check kinds, not just names: an export called `memory` that is really a
    // global, or an `alloc` that is a table, would otherwise pass validation and
    // only fail at execution time.
    match module.get_export("memory") {
        None => anyhow::bail!("plugin must export 'memory'"),
        Some(ty) if ty.memory().is_none() => {
            anyhow::bail!("plugin export 'memory' must be a linear memory")
        }
        Some(_) => {}
    }
    for name in ["alloc", "process"] {
        match module.get_export(name) {
            None => anyhow::bail!("plugin must export '{}' function", name),
            Some(ty) if ty.func().is_none() => {
                anyhow::bail!("plugin export '{}' must be a function", name)
            }
            Some(_) => {}
        }
    }
    Ok(())
}
pub struct Plugin {
name: String,
module: Module,
engine: Arc<Engine>,
}
impl Plugin {
pub fn load(path: &Path, _config: &RuntimeConfig) -> Result<Self> {
let name = path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown")
.to_string();
let mut engine_config = Config::new();
engine_config.consume_fuel(true);
let engine = Arc::new(Engine::new(&engine_config)?);
let wasm_bytes =
std::fs::read(path).with_context(|| format!("failed to read {}", path.display()))?;
let module = Module::new(&engine, &wasm_bytes)
.map_err(anyhow::Error::from)
.with_context(|| format!("failed to compile {}", path.display()))?;
Ok(Self {
name,
module,
engine,
})
}
pub fn load_bytes(
name: impl Into<String>,
bytes: &[u8],
_config: &RuntimeConfig,
) -> Result<Self> {
let name = name.into();
let mut engine_config = Config::new();
engine_config.consume_fuel(true);
let engine = Arc::new(Engine::new(&engine_config)?);
let module = Module::new(&engine, bytes)?;
Ok(Self {
name,
module,
engine,
})
}
pub fn name(&self) -> &str {
&self.name
}
pub fn execute(&self, input: &PluginInput, config: &RuntimeConfig) -> Result<PluginOutput> {
let mut store = Store::new(&self.engine, ());
let fuel = config.max_time_secs * 1_000_000;
store.set_fuel(fuel)?;
let linker = Linker::new(&self.engine);
let instance = linker.instantiate(&mut store, &self.module)?;
let input_bytes = rmp_serde::to_vec(input)?;
let memory = instance
.get_memory(&mut store, "memory")
.ok_or_else(|| anyhow::anyhow!("plugin must export 'memory'"))?;
let alloc = instance
.get_typed_func::<u32, u32>(&mut store, "alloc")
.map_err(anyhow::Error::from)
.context("plugin must export 'alloc' function")?;
let input_ptr = alloc.call(&mut store, input_bytes.len() as u32)?;
memory.write(&mut store, input_ptr as usize, &input_bytes)?;
let process = instance
.get_typed_func::<(u32, u32), u64>(&mut store, "process")
.map_err(anyhow::Error::from)
.context("plugin must export 'process' function")?;
let result = process.call(&mut store, (input_ptr, input_bytes.len() as u32))?;
let output_ptr = (result >> 32) as u32;
let output_len = (result & 0xFFFF_FFFF) as u32;
let mut output_bytes = vec![0u8; output_len as usize];
memory.read(&store, output_ptr as usize, &mut output_bytes)?;
let output: PluginOutput = rmp_serde::from_slice(&output_bytes)?;
Ok(output)
}
}
/// An ordered collection of plugins that share one [`RuntimeConfig`].
///
/// Plugins are identified by the index returned when they are loaded; indices
/// are assigned sequentially starting at 0.
pub struct PluginManager {
    config: RuntimeConfig,
    plugins: Vec<Plugin>,
}

impl PluginManager {
    /// Creates an empty manager using [`RuntimeConfig::default`].
    pub fn new() -> Self {
        Self::with_config(RuntimeConfig::default())
    }

    /// Creates an empty manager whose plugins run under `config`.
    pub const fn with_config(config: RuntimeConfig) -> Self {
        Self {
            config,
            plugins: Vec::new(),
        }
    }

    /// Appends `plugin` and returns its index.
    fn register(&mut self, plugin: Plugin) -> usize {
        let index = self.plugins.len();
        self.plugins.push(plugin);
        index
    }

    /// Loads the plugin file at `path` and returns its index.
    ///
    /// # Errors
    /// Fails if the file cannot be read or compiled; nothing is registered then.
    pub fn load(&mut self, path: &Path) -> Result<usize> {
        let plugin = Plugin::load(path, &self.config)?;
        Ok(self.register(plugin))
    }

    /// Compiles `bytes` as a plugin named `name` and returns its index.
    ///
    /// # Errors
    /// Fails if `bytes` do not compile; nothing is registered then.
    pub fn load_bytes(&mut self, name: impl Into<String>, bytes: &[u8]) -> Result<usize> {
        let plugin = Plugin::load_bytes(name, bytes, &self.config)?;
        Ok(self.register(plugin))
    }

    /// Runs the plugin at `index` on `input`.
    ///
    /// # Errors
    /// Fails if `index` is out of bounds or the plugin's execution fails.
    pub fn execute(&self, index: usize, input: &PluginInput) -> Result<PluginOutput> {
        let plugin = self.plugins.get(index).with_context(|| {
            format!(
                "plugin index {index} out of bounds ({} loaded)",
                self.plugins.len()
            )
        })?;
        plugin.execute(input, &self.config)
    }

    /// Runs every plugin in load order as a pipeline.
    ///
    /// Each plugin's output `directives` become the next plugin's input
    /// `directives`; the other input fields are passed unchanged to every plugin.
    /// `errors` reported by all plugins are concatenated in order. With no
    /// plugins, the input directives are returned with no errors.
    ///
    /// # Errors
    /// Stops at the first plugin whose execution fails; the error is wrapped
    /// with that plugin's name.
    pub fn execute_all(&self, mut input: PluginInput) -> Result<PluginOutput> {
        let mut all_errors = Vec::new();
        for plugin in &self.plugins {
            let output = plugin
                .execute(&input, &self.config)
                .with_context(|| format!("plugin '{}' failed", plugin.name()))?;
            all_errors.extend(output.errors);
            input.directives = output.directives;
        }
        Ok(PluginOutput {
            directives: input.directives,
            errors: all_errors,
        })
    }

    /// Number of loaded plugins.
    pub const fn len(&self) -> usize {
        self.plugins.len()
    }

    /// `true` if no plugins are loaded.
    pub const fn is_empty(&self) -> bool {
        self.plugins.is_empty()
    }
}

impl Default for PluginManager {
    /// Same as [`PluginManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A loaded plugin together with the file it came from and that file's
/// modification time as observed just before the (re)load.
struct TrackedPlugin {
    plugin: Plugin,
    path: PathBuf,
    modified: SystemTime,
}

/// A plugin registry that tracks each plugin's source file so plugins can be
/// hot-reloaded when the file changes on disk.
///
/// Change detection is poll-based: call
/// [`check_and_reload`](Self::check_and_reload) periodically. Plugins can be
/// addressed by index or by name (the file stem). When two loaded files share
/// a stem, the name refers to the most recently loaded one.
pub struct WatchingPluginManager {
    config: RuntimeConfig,
    plugins: Vec<TrackedPlugin>,
    /// Plugin name -> index into `plugins`.
    name_index: HashMap<String, usize>,
    /// Called with the plugin name after each successful reload in `check_and_reload`.
    on_reload: Option<Box<dyn Fn(&str) + Send + Sync>>,
}

impl WatchingPluginManager {
    /// Creates an empty manager using [`RuntimeConfig::default`].
    pub fn new() -> Self {
        Self::with_config(RuntimeConfig::default())
    }

    /// Creates an empty manager whose plugins run under `config`.
    pub fn with_config(config: RuntimeConfig) -> Self {
        Self {
            config,
            plugins: Vec::new(),
            name_index: HashMap::new(),
            on_reload: None,
        }
    }

    /// Registers `callback`, invoked with a plugin's name whenever
    /// [`check_and_reload`](Self::check_and_reload) successfully reloads it.
    /// Replaces any previous callback. [`reload_all`](Self::reload_all) does
    /// not invoke it.
    pub fn on_reload<F>(&mut self, callback: F)
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.on_reload = Some(Box::new(callback));
    }

    /// Returns the modification time of `path`, naming the path in any error.
    fn file_mtime(path: &Path) -> Result<SystemTime> {
        std::fs::metadata(path)
            .and_then(|metadata| metadata.modified())
            .with_context(|| format!("failed to stat {}", path.display()))
    }

    /// Loads the plugin at `path`, starts tracking its file, and returns its index.
    ///
    /// The path is canonicalized when possible (falling back to the path as given),
    /// so a relative path keeps pointing at the same file if the working directory
    /// changes.
    ///
    /// # Errors
    /// Fails if the file cannot be stat'ed, read, or compiled.
    pub fn load(&mut self, path: impl AsRef<Path>) -> Result<usize> {
        let path = path.as_ref();
        let abs_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
        // Stat *before* reading. If the file is rewritten during the load, we
        // record the older timestamp and the next check reloads again, instead of
        // recording the new timestamp against stale code.
        let modified = Self::file_mtime(&abs_path)?;
        let plugin = Plugin::load(&abs_path, &self.config)?;
        let name = plugin.name().to_string();
        let index = self.plugins.len();
        self.plugins.push(TrackedPlugin {
            plugin,
            path: abs_path,
            modified,
        });
        self.name_index.insert(name, index);
        Ok(index)
    }

    /// Reloads every plugin whose file's modification time differs from the one
    /// recorded at its last successful (re)load.
    ///
    /// Files whose metadata cannot be read (e.g. temporarily missing) are skipped.
    /// A failed reload is reported on stderr and the previous version stays
    /// active. Because the recorded time is left unchanged, it is retried on the
    /// next call. Returns `true` if at least one plugin was reloaded; currently
    /// never returns `Err`.
    pub fn check_and_reload(&mut self) -> Result<bool> {
        let mut reloaded = false;
        for tracked in &mut self.plugins {
            let current_modified =
                match std::fs::metadata(&tracked.path).and_then(|m| m.modified()) {
                    Ok(m) => m,
                    Err(_) => continue,
                };
            // Any change counts, not only a newer timestamp: restoring an older
            // build (`cp -p`, VCS checkout, clock adjustment) must also reload.
            if current_modified == tracked.modified {
                continue;
            }
            match Plugin::load(&tracked.path, &self.config) {
                Ok(new_plugin) => {
                    tracked.plugin = new_plugin;
                    tracked.modified = current_modified;
                    reloaded = true;
                    if let Some(callback) = &self.on_reload {
                        callback(tracked.plugin.name());
                    }
                }
                Err(e) => {
                    eprintln!(
                        "warning: failed to reload plugin {}: {:#}",
                        tracked.path.display(),
                        e
                    );
                }
            }
        }
        Ok(reloaded)
    }

    /// Unconditionally reloads every tracked plugin from disk, in index order.
    ///
    /// # Errors
    /// Stops at the first plugin that cannot be stat'ed or loaded. Earlier plugins
    /// have already been replaced; that plugin and later ones keep their previous
    /// version.
    pub fn reload_all(&mut self) -> Result<()> {
        for tracked in &mut self.plugins {
            // Stat before loading (see `load`) so a concurrent write cannot pair
            // a fresh timestamp with stale code and hide the change from
            // `check_and_reload`.
            let modified = Self::file_mtime(&tracked.path)?;
            let new_plugin = Plugin::load(&tracked.path, &self.config)?;
            tracked.plugin = new_plugin;
            tracked.modified = modified;
        }
        Ok(())
    }

    /// Looks up a plugin by name (file stem).
    pub fn get(&self, name: &str) -> Option<&Plugin> {
        self.name_index.get(name).map(|&i| &self.plugins[i].plugin)
    }

    /// Runs the plugin at `index` on `input`.
    ///
    /// # Errors
    /// Fails if `index` is out of bounds or the plugin's execution fails.
    pub fn execute(&self, index: usize, input: &PluginInput) -> Result<PluginOutput> {
        let tracked = self.plugins.get(index).with_context(|| {
            format!(
                "plugin index {index} out of bounds ({} loaded)",
                self.plugins.len()
            )
        })?;
        tracked.plugin.execute(input, &self.config)
    }

    /// Runs the plugin registered under `name` on `input`.
    ///
    /// # Errors
    /// Fails if no plugin has that name or its execution fails.
    pub fn execute_by_name(&self, name: &str, input: &PluginInput) -> Result<PluginOutput> {
        let index = self
            .name_index
            .get(name)
            .with_context(|| format!("plugin '{name}' not found"))?;
        self.execute(*index, input)
    }

    /// Runs every plugin in load order as a pipeline.
    ///
    /// Each plugin's output `directives` become the next plugin's input
    /// `directives`; other input fields pass through unchanged. `errors` from all
    /// plugins are concatenated in order.
    ///
    /// # Errors
    /// Stops at the first plugin whose execution fails; the error is wrapped
    /// with that plugin's name.
    pub fn execute_all(&self, mut input: PluginInput) -> Result<PluginOutput> {
        let mut all_errors = Vec::new();
        for tracked in &self.plugins {
            let output = tracked
                .plugin
                .execute(&input, &self.config)
                .with_context(|| format!("plugin '{}' failed", tracked.plugin.name()))?;
            all_errors.extend(output.errors);
            input.directives = output.directives;
        }
        Ok(PluginOutput {
            directives: input.directives,
            errors: all_errors,
        })
    }

    /// Number of tracked plugins.
    pub const fn len(&self) -> usize {
        self.plugins.len()
    }

    /// `true` if no plugins are tracked.
    pub const fn is_empty(&self) -> bool {
        self.plugins.is_empty()
    }

    /// `(source path, recorded modification time)` for every tracked plugin,
    /// in index order.
    pub fn plugin_info(&self) -> Vec<(&Path, SystemTime)> {
        self.plugins
            .iter()
            .map(|t| (t.path.as_path(), t.modified))
            .collect()
    }
}

impl Default for WatchingPluginManager {
    /// Same as [`WatchingPluginManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest module satisfying the plugin ABI: a memory plus `alloc`/`process`.
    const VALID_WAT: &str = r#"
        (module
            (memory (export "memory") 1)
            (func (export "alloc") (param i32) (result i32)
                i32.const 0
            )
            (func (export "process") (param i32 i32) (result i64)
                i64.const 0
            )
        )
    "#;

    /// Assembles WAT text into a binary module; malformed fixtures are a test bug.
    fn wasm(src: &str) -> Vec<u8> {
        wat::parse_str(src).expect("valid wat")
    }

    /// A plugin input with no directives, default options and no config.
    fn empty_input() -> crate::types::PluginInput {
        crate::types::PluginInput {
            directives: vec![],
            options: crate::types::PluginOptions::default(),
            config: None,
        }
    }

    // ---- validate_plugin_module -------------------------------------------

    #[test]
    fn test_valid_plugin_validation() {
        let result = validate_plugin_module(&wasm(VALID_WAT));
        assert!(
            result.is_ok(),
            "valid plugin should pass validation: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_wasi_import_rejected() {
        let src = r#"
            (module
                (import "wasi_snapshot_preview1" "fd_write"
                    (func $fd_write (param i32 i32 i32 i32) (result i32))
                )
                (memory (export "memory") 1)
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
        "#;
        let err = validate_plugin_module(&wasm(src))
            .expect_err("module with WASI import should be rejected")
            .to_string();
        assert!(
            err.contains("forbidden import"),
            "error should mention forbidden import: {err}"
        );
        assert!(
            err.contains("wasi_snapshot_preview1"),
            "error should mention WASI: {err}"
        );
    }

    #[test]
    fn test_env_import_rejected() {
        let src = r#"
            (module
                (import "env" "some_func" (func $some_func))
                (memory (export "memory") 1)
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
        "#;
        let result = validate_plugin_module(&wasm(src));
        assert!(result.is_err(), "module with env import should be rejected");
    }

    #[test]
    fn test_missing_exports_rejected() {
        let src = r#"
            (module
                (memory (export "memory") 1)
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
        "#;
        let err = validate_plugin_module(&wasm(src))
            .expect_err("module missing alloc should be rejected");
        assert!(err.to_string().contains("alloc"));
    }

    #[test]
    fn test_missing_memory_rejected() {
        let src = r#"
            (module
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
        "#;
        let err = validate_plugin_module(&wasm(src))
            .expect_err("module missing memory should be rejected");
        assert!(err.to_string().contains("memory"));
    }

    #[test]
    fn test_missing_process_rejected() {
        let src = r#"
            (module
                (memory (export "memory") 1)
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
            )
        "#;
        let err = validate_plugin_module(&wasm(src))
            .expect_err("module missing process should be rejected");
        assert!(err.to_string().contains("process"));
    }

    #[test]
    fn test_invalid_wasm_rejected() {
        let result = validate_plugin_module(b"not valid wasm bytes");
        assert!(result.is_err(), "invalid WASM should be rejected");
    }

    #[test]
    fn test_validate_truncated_wasm() {
        // The `\0asm` magic alone, without the version field.
        assert!(validate_plugin_module(b"\0asm").is_err());
    }

    #[test]
    fn test_validate_wrong_magic() {
        assert!(validate_plugin_module(&[0xFF; 4]).is_err());
    }

    #[test]
    fn test_validate_empty_wasm() {
        assert!(validate_plugin_module(&[]).is_err());
    }

    // ---- RuntimeConfig ----------------------------------------------------

    #[test]
    fn test_runtime_config_defaults() {
        let RuntimeConfig {
            max_memory,
            max_time_secs,
        } = RuntimeConfig::default();
        assert_eq!(max_memory, 256 * 1024 * 1024);
        assert_eq!(max_time_secs, 30);
    }

    #[test]
    fn test_runtime_config_custom() {
        let config = RuntimeConfig {
            max_memory: 512 * 1024 * 1024,
            max_time_secs: 60,
        };
        assert_eq!(config.max_memory, 512 * 1024 * 1024);
        assert_eq!(config.max_time_secs, 60);
    }

    // ---- PluginManager ----------------------------------------------------

    #[test]
    fn test_plugin_manager_new() {
        let manager = PluginManager::new();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_plugin_manager_with_config() {
        let manager = PluginManager::with_config(RuntimeConfig {
            max_memory: 128 * 1024 * 1024,
            max_time_secs: 10,
        });
        assert!(manager.is_empty());
    }

    #[test]
    fn test_plugin_manager_default() {
        let manager = PluginManager::default();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_plugin_manager_execute_out_of_bounds() {
        let err = PluginManager::new()
            .execute(0, &empty_input())
            .expect_err("empty manager has no index 0");
        assert!(err.to_string().contains("out of bounds"));
    }

    #[test]
    fn test_plugin_manager_execute_all_empty() {
        let output = PluginManager::new()
            .execute_all(empty_input())
            .expect("empty pipeline should succeed");
        assert!(output.directives.is_empty());
        assert!(output.errors.is_empty());
    }

    #[test]
    fn test_plugin_load_bytes() {
        let plugin = Plugin::load_bytes("test_plugin", &wasm(VALID_WAT), &RuntimeConfig::default())
            .expect("valid module should load");
        assert_eq!(plugin.name(), "test_plugin");
    }

    #[test]
    fn test_plugin_manager_load_bytes() {
        let mut manager = PluginManager::new();
        let index = manager
            .load_bytes("my_plugin", &wasm(VALID_WAT))
            .expect("valid module should load");
        assert_eq!(index, 0);
        assert_eq!(manager.len(), 1);
        assert!(!manager.is_empty());
    }

    #[test]
    fn test_plugin_manager_multiple_plugins() {
        let module = wasm(VALID_WAT);
        let mut manager = PluginManager::new();
        for name in ["plugin1", "plugin2", "plugin3"] {
            manager.load_bytes(name, &module).unwrap();
        }
        assert_eq!(manager.len(), 3);
    }

    // ---- WatchingPluginManager --------------------------------------------

    #[test]
    fn test_watching_plugin_manager_new() {
        let manager = WatchingPluginManager::new();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
        assert!(manager.plugin_info().is_empty());
    }

    #[test]
    fn test_watching_plugin_manager_with_config() {
        let manager = WatchingPluginManager::with_config(RuntimeConfig {
            max_memory: 64 * 1024 * 1024,
            max_time_secs: 5,
        });
        assert!(manager.is_empty());
    }

    #[test]
    fn test_watching_plugin_manager_default() {
        let manager = WatchingPluginManager::default();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_watching_plugin_manager_get_unknown() {
        assert!(WatchingPluginManager::new().get("nonexistent").is_none());
    }

    #[test]
    fn test_watching_plugin_manager_execute_out_of_bounds() {
        let err = WatchingPluginManager::new()
            .execute(0, &empty_input())
            .expect_err("empty manager has no index 0");
        assert!(err.to_string().contains("out of bounds"));
    }

    #[test]
    fn test_watching_plugin_manager_execute_by_name_unknown() {
        let err = WatchingPluginManager::new()
            .execute_by_name("unknown", &empty_input())
            .expect_err("no plugin is named 'unknown'");
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn test_watching_plugin_manager_execute_all_empty() {
        let output = WatchingPluginManager::new()
            .execute_all(empty_input())
            .expect("empty pipeline should succeed");
        assert!(output.directives.is_empty());
        assert!(output.errors.is_empty());
    }

    #[test]
    fn test_watching_plugin_manager_check_reload_empty() {
        let mut manager = WatchingPluginManager::new();
        let reloaded = manager
            .check_and_reload()
            .expect("checking an empty manager should succeed");
        assert!(!reloaded, "nothing tracked, so nothing to reload");
    }

    #[test]
    fn test_watching_plugin_manager_reload_all_empty() {
        let mut manager = WatchingPluginManager::new();
        assert!(manager.reload_all().is_ok());
    }
}