use oxify_model::{ExecutionContext, ExecutionResult, Node};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::Duration;
use thiserror::Error;
#[cfg(feature = "wasm")]
use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, Store, TypedFunction};
/// Errors produced while loading or running a WASM plugin.
#[derive(Error, Debug)]
pub enum WasmError {
    /// Reading the module file or compiling the module bytes failed.
    #[error("Failed to load WASM module: {0}")]
    LoadError(String),
    /// Creating the host memory or instantiating the module against its imports failed.
    #[error("Failed to instantiate WASM module: {0}")]
    InstantiationError(String),
    /// Looking up a required guest export (e.g. `execute`) failed.
    #[error("Function not found: {0}")]
    FunctionNotFound(String),
    /// The guest call failed, guest memory could not be accessed, or the guest's output
    /// could not be decoded.
    #[error("Execution error: {0}")]
    ExecutionError(String),
    /// A memory limit was exceeded.
    /// NOTE(review): not constructed anywhere in this module — confirm whether it is still used.
    #[error("Memory limit exceeded")]
    MemoryLimitExceeded,
    /// A guest call took longer than the configured timeout.
    #[error("Timeout exceeded")]
    TimeoutExceeded,
    /// Host-side input for the guest could not be prepared (e.g. JSON serialization failed).
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),
    /// The crate was built without the `wasm` feature, so plugins cannot be loaded or run.
    #[error("WASM feature not enabled")]
    FeatureNotEnabled,
}
/// Resource limits for WASM plugins.
///
/// Presets: `WasmPluginConfig::default()`, the tighter [`WasmPluginConfig::strict`], and the
/// looser [`WasmPluginConfig::permissive`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmPluginConfig {
    /// Maximum size of the host-created linear memory, in WASM pages (64 KiB each).
    pub max_memory_pages: u32,
    /// Wall-clock budget for a guest call. In this module it is compared against the elapsed
    /// time after the call returns; it does not interrupt a running guest.
    pub timeout: Duration,
    /// Whether fuel metering should be enabled.
    /// NOTE(review): not read by the loader or plugin in this module — confirm it is wired up.
    pub enable_fuel_metering: bool,
    /// Fuel budget for metering.
    /// NOTE(review): likewise not read in this module.
    pub fuel_limit: u64,
}
impl Default for WasmPluginConfig {
    /// Middle-of-the-road limits: 256 pages (16 MiB) of memory, a 30 second timeout, and
    /// fuel metering enabled with a budget of 1,000,000.
    fn default() -> Self {
        let max_memory_pages = 256;
        let timeout = Duration::from_secs(30);
        let fuel_limit = 1_000_000;
        Self {
            max_memory_pages,
            timeout,
            enable_fuel_metering: true,
            fuel_limit,
        }
    }
}
impl WasmPluginConfig {
pub fn strict() -> Self {
Self {
max_memory_pages: 64, timeout: Duration::from_secs(5),
enable_fuel_metering: true,
fuel_limit: 100_000,
}
}
pub fn permissive() -> Self {
Self {
max_memory_pages: 1024, timeout: Duration::from_secs(300),
enable_fuel_metering: true,
fuel_limit: 10_000_000,
}
}
}
#[cfg(feature = "wasm")]
/// Compiles and instantiates WASM plugins inside a single wasmer [`Store`].
///
/// Every plugin produced by this loader is instantiated in `store`. Wasmer objects are tied
/// to the store that created them, so that same store is the one to execute plugins with.
pub struct WasmPluginLoader {
    /// Limits cloned into every plugin this loader creates.
    config: WasmPluginConfig,
    /// Store shared by all plugins loaded through this loader.
    store: Store,
}
#[cfg(feature = "wasm")]
impl WasmPluginLoader {
    /// Creates a loader with a fresh default [`Store`] and the given limits.
    pub fn new(config: WasmPluginConfig) -> Self {
        let store = Store::default();
        Self { config, store }
    }

    /// Returns the store that every plugin from this loader is instantiated in.
    pub fn store(&self) -> &Store {
        &self.store
    }

    /// Returns the store mutably, e.g. to pass to [`WasmPlugin::execute`].
    ///
    /// Wasmer objects belong to the store that created them, so plugins loaded here must be
    /// executed with this store rather than a newly created one.
    pub fn store_mut(&mut self) -> &mut Store {
        &mut self.store
    }

    /// Reads a `.wasm` file from `path` and loads it with [`Self::load_from_bytes`].
    ///
    /// # Errors
    ///
    /// [`WasmError::LoadError`] if the file cannot be read, plus any error returned by
    /// [`Self::load_from_bytes`].
    pub fn load_from_file(&mut self, path: &Path) -> Result<WasmPlugin, WasmError> {
        let wasm_bytes = std::fs::read(path).map_err(|e| WasmError::LoadError(e.to_string()))?;
        self.load_from_bytes(&wasm_bytes)
    }

    /// Compiles `wasm_bytes` and instantiates the module in this loader's store.
    ///
    /// The module is offered two imports in the `env` namespace:
    /// - `memory`: a host-created memory with 1 initial page and at most
    ///   `config.max_memory_pages` pages;
    /// - `log`: a host function taking `(i32, i32)` that logs both values.
    ///
    /// If the instance exports a memory named `memory`, that export is used for passing data
    /// to and from the guest. A module that defines its own memory never sees the imported
    /// one, so writing to the host memory would be invisible to it. The page limit above is
    /// not applied to such a guest-defined memory.
    ///
    /// # Errors
    ///
    /// - [`WasmError::LoadError`] if the bytes do not compile to a module.
    /// - [`WasmError::InstantiationError`] if the host memory cannot be created or
    ///   instantiation fails (e.g. unsatisfied imports).
    pub fn load_from_bytes(&mut self, wasm_bytes: &[u8]) -> Result<WasmPlugin, WasmError> {
        let module = Module::new(&self.store, wasm_bytes)
            .map_err(|e| WasmError::LoadError(e.to_string()))?;

        let memory_type = MemoryType::new(1, Some(self.config.max_memory_pages), false);
        let host_memory = Memory::new(&mut self.store, memory_type)
            .map_err(|e| WasmError::InstantiationError(e.to_string()))?;

        let imports = imports! {
            "env" => {
                "memory" => host_memory.clone(),
                "log" => Function::new_typed(&mut self.store, wasm_host_log),
            },
        };
        let instance = Instance::new(&mut self.store, &module, &imports)
            .map_err(|e| WasmError::InstantiationError(e.to_string()))?;

        // Use the memory the guest actually reads from. If it imports `env.memory` and
        // re-exports it, both branches yield the same memory.
        let memory = match instance.exports.get_memory("memory") {
            Ok(exported) => exported.clone(),
            Err(_) => host_memory,
        };

        Ok(WasmPlugin {
            instance,
            memory,
            config: self.config.clone(),
        })
    }
}
/// Host function provided to plugins as the `env.log` import.
///
/// Only the two raw integers (pointer and length) are logged at `debug` level. The guest
/// bytes they describe are not read, since this function has no access to guest memory.
#[cfg(feature = "wasm")]
fn wasm_host_log(offset: i32, length: i32) {
    tracing::debug!("WASM log: ptr={}, len={}", offset, length);
}
#[cfg(feature = "wasm")]
/// A WASM plugin instance created by [`WasmPluginLoader`].
pub struct WasmPlugin {
    /// The instantiated module; its exports include the guest `execute` function.
    instance: Instance,
    /// Linear memory used to pass input to and read output from the guest.
    memory: Memory,
    /// Limits cloned from the loader when the plugin was created.
    config: WasmPluginConfig,
}
#[cfg(feature = "wasm")]
impl WasmPlugin {
    /// Runs the guest's exported `execute` function for `node` within `context`.
    ///
    /// Protocol implemented here:
    /// 1. `{"node": node, "context": context}` is serialized to JSON and copied into guest
    ///    memory at a fixed offset.
    /// 2. The guest export `execute(ptr: i32, len: i32) -> i32` is called with that region.
    /// 3. The returned value is read as the address of a little-endian `u32` byte length,
    ///    followed by that many bytes of UTF-8 JSON, which is parsed as an [`ExecutionResult`].
    ///
    /// `store` must be the store the plugin was loaded in, since wasmer objects are bound to
    /// the store that created them.
    ///
    /// # Errors
    ///
    /// - [`WasmError::InvalidParameter`] if the input cannot be serialized, or if its length
    ///   does not fit in an `i32`.
    /// - [`WasmError::FunctionNotFound`] if no `execute(i32, i32) -> i32` export can be found.
    /// - [`WasmError::ExecutionError`] if guest memory cannot be written or read, the guest
    ///   call fails, or the output is not valid UTF-8 / `ExecutionResult` JSON.
    /// - [`WasmError::TimeoutExceeded`] if the call took longer than the configured timeout.
    ///   The check happens after the guest returns; a long-running guest is not interrupted.
    ///
    /// # Panics
    ///
    /// Uses `tokio::task::block_in_place`, which panics when called from a current-thread
    /// Tokio runtime.
    pub fn execute(
        &mut self,
        store: &mut Store,
        node: &Node,
        context: &ExecutionContext,
    ) -> Result<ExecutionResult, WasmError> {
        let input = serde_json::json!({
            "node": node,
            "context": context,
        });
        let input_str = serde_json::to_string(&input)
            .map_err(|e| WasmError::InvalidParameter(e.to_string()))?;
        // The guest ABI carries the length as an i32; refuse rather than silently truncate.
        if input_str.len() > i32::MAX as usize {
            return Err(WasmError::InvalidParameter(
                "input too large to pass to a wasm32 guest".to_string(),
            ));
        }

        let execute_fn: TypedFunction<(i32, i32), i32> = self
            .instance
            .exports
            .get_typed_function(store, "execute")
            .map_err(|_| WasmError::FunctionNotFound("execute".to_string()))?;

        let input_ptr = self.write_to_memory(store, input_str.as_bytes())?;

        let timeout = self.config.timeout;
        let result = tokio::task::block_in_place(|| {
            let start = std::time::Instant::now();
            let result_ptr = execute_fn
                .call(store, input_ptr as i32, input_str.len() as i32)
                .map_err(|e| WasmError::ExecutionError(e.to_string()))?;
            // Post-hoc check: the guest has already finished by the time this runs.
            if start.elapsed() > timeout {
                return Err(WasmError::TimeoutExceeded);
            }
            Ok(result_ptr)
        })?;

        let result_str = self.read_from_memory(store, result)?;
        serde_json::from_str(&result_str).map_err(|e| WasmError::ExecutionError(e.to_string()))
    }

    /// Copies `data` into guest memory at a fixed offset and returns that offset.
    ///
    /// NOTE(review): the offset (1024) is hard-coded rather than obtained from the guest, so it
    /// may overlap guest-owned data. Confirm the guest ABI reserves this region.
    fn write_to_memory(&self, store: &mut Store, data: &[u8]) -> Result<usize, WasmError> {
        let ptr = 1024usize;
        // One bulk write instead of a bounds-checked call per byte.
        self.memory
            .view(store)
            .write(ptr as u64, data)
            .map_err(|_| WasmError::ExecutionError("Failed to write to memory".to_string()))?;
        Ok(ptr)
    }

    /// Reads a length-prefixed UTF-8 string from guest memory at `ptr`.
    ///
    /// The layout is a little-endian `u32` byte count followed by that many bytes.
    fn read_from_memory(&self, store: &Store, ptr: i32) -> Result<String, WasmError> {
        let memory_view = self.memory.view(store);
        // wasm32 addresses are unsigned: reinterpret the i32 bits instead of sign-extending,
        // so addresses at or above 2 GiB are read correctly.
        let offset = u64::from(ptr as u32);

        let mut len_bytes = [0u8; 4];
        memory_view
            .read(offset, &mut len_bytes)
            .map_err(|_| WasmError::ExecutionError("Failed to read length".to_string()))?;
        let len = u64::from(u32::from_le_bytes(len_bytes));

        // The length is guest-controlled. Validate it against the memory size before
        // allocating, so a bogus prefix cannot make the host allocate up to 4 GiB.
        let data_offset = offset + 4;
        let in_bounds = matches!(
            data_offset.checked_add(len),
            Some(end) if end <= memory_view.data_size()
        );
        if !in_bounds {
            return Err(WasmError::ExecutionError("Failed to read data".to_string()));
        }

        // `len` is bounded by the size of host-mapped memory, so it fits in `usize`.
        let mut data = vec![0u8; len as usize];
        memory_view
            .read(data_offset, &mut data)
            .map_err(|_| WasmError::ExecutionError("Failed to read data".to_string()))?;
        String::from_utf8(data).map_err(|e| WasmError::ExecutionError(e.to_string()))
    }

    /// Lists the instance's export names, omitting names that start with `__`
    /// (presumably toolchain-internal symbols).
    ///
    /// The `_store` argument is currently unused.
    pub fn get_exports(&self, _store: &Store) -> Vec<String> {
        self.instance
            .exports
            .iter()
            .map(|(name, _)| name)
            .filter(|name| !name.starts_with("__"))
            .map(|name| name.to_string())
            .collect()
    }
}
/// Aggregate runtime statistics for a plugin.
///
/// NOTE(review): nothing in this module populates these fields. The units of `memory_usage`
/// (bytes vs. pages) and `fuel_consumed` are not established here — confirm with the producer.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct WasmPluginStats {
    /// Number of executions recorded.
    pub executions: u64,
    /// Sum of recorded execution durations.
    pub total_execution_time: Duration,
    /// Mean execution duration; presumably `total_execution_time / executions` — confirm.
    pub avg_execution_time: Duration,
    /// Memory usage figure (units unconfirmed).
    pub memory_usage: u64,
    /// Fuel consumed (only meaningful if fuel metering is actually applied).
    pub fuel_consumed: u64,
}
#[cfg(not(feature = "wasm"))]
/// Stand-in for the WASM plugin loader when the `wasm` feature is disabled.
///
/// Every load attempt fails with [`WasmError::FeatureNotEnabled`].
// `config` is never read in this build; it is kept so construction matches the `wasm` build.
#[allow(dead_code)]
pub struct WasmPluginLoader {
    /// Stored for API parity with the `wasm` build; not used by the stub.
    config: WasmPluginConfig,
}
#[cfg(not(feature = "wasm"))]
impl WasmPluginLoader {
#[allow(dead_code)]
pub fn new(config: WasmPluginConfig) -> Self {
Self { config }
}
#[allow(dead_code)]
pub fn load_from_file(&mut self, _path: &Path) -> Result<WasmPlugin, WasmError> {
Err(WasmError::FeatureNotEnabled)
}
#[allow(dead_code)]
pub fn load_from_bytes(&mut self, _wasm_bytes: &[u8]) -> Result<WasmPlugin, WasmError> {
Err(WasmError::FeatureNotEnabled)
}
}
#[cfg(not(feature = "wasm"))]
/// Placeholder plugin type used when the `wasm` feature is disabled.
///
/// The stub loader never returns one. Its methods fail with
/// [`WasmError::FeatureNotEnabled`] or report no exports.
// Allowed: the stub type exists only for API parity and may never be constructed.
#[allow(dead_code)]
pub struct WasmPlugin;
#[cfg(not(feature = "wasm"))]
impl WasmPlugin {
    /// Stub for plugin execution without the `wasm` feature.
    ///
    /// NOTE(review): unlike the `wasm` build, this signature takes no `store` argument. Code
    /// that calls it under both configurations therefore needs its own `cfg` branches.
    ///
    /// # Errors
    ///
    /// Always returns [`WasmError::FeatureNotEnabled`].
    // Allowed: the stub API may go unused in builds without the `wasm` feature.
    #[allow(dead_code)]
    pub fn execute(
        &mut self,
        _node: &Node,
        _context: &ExecutionContext,
    ) -> Result<ExecutionResult, WasmError> {
        let unavailable = WasmError::FeatureNotEnabled;
        Err(unavailable)
    }

    /// Reports no exports, since no module is ever loaded in this build.
    #[allow(dead_code)]
    pub fn get_exports(&self) -> Vec<String> {
        Vec::new()
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the config presets, stats defaults, and loader error paths in both
    // feature configurations.
    use super::*;

    #[test]
    fn test_wasm_plugin_config_default() {
        let config = WasmPluginConfig::default();
        assert_eq!(config.max_memory_pages, 256);
        assert_eq!(config.timeout, Duration::from_secs(30));
        assert!(config.enable_fuel_metering);
        assert_eq!(config.fuel_limit, 1_000_000);
    }

    #[test]
    fn test_wasm_plugin_config_strict() {
        let config = WasmPluginConfig::strict();
        assert_eq!(config.max_memory_pages, 64);
        assert_eq!(config.timeout, Duration::from_secs(5));
        assert!(config.enable_fuel_metering);
        assert_eq!(config.fuel_limit, 100_000);
    }

    #[test]
    fn test_wasm_plugin_config_permissive() {
        let config = WasmPluginConfig::permissive();
        assert_eq!(config.max_memory_pages, 1024);
        assert_eq!(config.timeout, Duration::from_secs(300));
        assert!(config.enable_fuel_metering);
        assert_eq!(config.fuel_limit, 10_000_000);
    }

    #[test]
    fn test_presets_scale_from_strict_to_permissive() {
        let strict = WasmPluginConfig::strict();
        let default = WasmPluginConfig::default();
        let permissive = WasmPluginConfig::permissive();
        assert!(strict.max_memory_pages < default.max_memory_pages);
        assert!(default.max_memory_pages < permissive.max_memory_pages);
        assert!(strict.timeout < default.timeout && default.timeout < permissive.timeout);
        assert!(strict.fuel_limit < default.fuel_limit);
        assert!(default.fuel_limit < permissive.fuel_limit);
    }

    #[test]
    fn test_wasm_plugin_loader_creation() {
        let config = WasmPluginConfig::default();
        let _loader = WasmPluginLoader::new(config);
    }

    #[test]
    fn test_wasm_plugin_stats_default() {
        let stats = WasmPluginStats::default();
        assert_eq!(stats.executions, 0);
        assert_eq!(stats.total_execution_time, Duration::from_secs(0));
        assert_eq!(stats.avg_execution_time, Duration::from_secs(0));
        assert_eq!(stats.memory_usage, 0);
        assert_eq!(stats.fuel_consumed, 0);
    }

    #[cfg(not(feature = "wasm"))]
    #[test]
    fn test_wasm_plugin_loader_without_feature() {
        let config = WasmPluginConfig::default();
        let mut loader = WasmPluginLoader::new(config);
        let result = loader.load_from_bytes(&[]);
        assert!(matches!(result, Err(WasmError::FeatureNotEnabled)));
        let result = loader.load_from_file(Path::new("/nonexistent/plugin.wasm"));
        assert!(matches!(result, Err(WasmError::FeatureNotEnabled)));
    }

    #[cfg(not(feature = "wasm"))]
    #[test]
    fn test_wasm_plugin_stub_has_no_exports() {
        assert!(WasmPlugin.get_exports().is_empty());
    }

    #[cfg(feature = "wasm")]
    #[test]
    fn test_wasm_plugin_loader_rejects_invalid_bytes() {
        let mut loader = WasmPluginLoader::new(WasmPluginConfig::default());
        let result = loader.load_from_bytes(b"definitely not a wasm module");
        assert!(matches!(result, Err(WasmError::LoadError(_))));
    }

    #[cfg(feature = "wasm")]
    #[test]
    fn test_wasm_plugin_loader_missing_file() {
        let mut loader = WasmPluginLoader::new(WasmPluginConfig::default());
        let result = loader.load_from_file(Path::new("/nonexistent/plugin.wasm"));
        assert!(matches!(result, Err(WasmError::LoadError(_))));
    }
}