use crate::error::{KernelError, KernelResult};
use crate::plugin::{Extension, ExtensionCapability, ExtensionInfo, ObservabilityExtension};
use parking_lot::RwLock;
use std::any::Any;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
/// Capability grants and resource limits for one WASM plugin.
///
/// Table grants are patterns matched by [`WasmPluginCapabilities::can_read`] and
/// [`WasmPluginCapabilities::can_write`]:
/// - `"*"` matches every table name;
/// - a pattern ending in `*` (e.g. `"logs_*"`) matches any name starting with the text
///   before the `*`;
/// - any other pattern must equal the table name exactly.
#[derive(Debug, Clone)]
pub struct WasmPluginCapabilities {
    /// Table-name patterns the plugin may read.
    pub can_read_table: Vec<String>,
    /// Table-name patterns the plugin may write.
    pub can_write_table: Vec<String>,
    /// Whether index search is granted.
    pub can_index_search: bool,
    /// Whether vector search is granted.
    pub can_vector_search: bool,
    /// Plugins this plugin may call (presumably by name — confirm against callers).
    pub can_call_plugin: Vec<String>,
    /// Memory limit in bytes.
    pub memory_limit_bytes: u64,
    /// Fuel budget for the plugin.
    pub fuel_limit: u64,
    /// Execution timeout in milliseconds.
    pub timeout_ms: u64,
}

impl Default for WasmPluginCapabilities {
    /// No table or search grants; 16 MiB memory, 1,000,000 fuel, 100 ms timeout.
    fn default() -> Self {
        Self {
            can_read_table: vec![],
            can_write_table: vec![],
            can_index_search: false,
            can_vector_search: false,
            can_call_plugin: vec![],
            memory_limit_bytes: 16 * 1024 * 1024,
            fuel_limit: 1_000_000,
            timeout_ms: 100,
        }
    }
}

impl WasmPluginCapabilities {
    /// Minimal profile for metrics/tracing plugins: no grants, 4 MiB memory,
    /// 100,000 fuel, 10 ms timeout.
    pub fn observability_only() -> Self {
        Self {
            can_read_table: vec![],
            can_write_table: vec![],
            can_index_search: false,
            can_vector_search: false,
            can_call_plugin: vec![],
            memory_limit_bytes: 4 * 1024 * 1024,
            fuel_limit: 100_000,
            timeout_ms: 10,
        }
    }

    /// Read-only profile: reads on `tables` (patterns, see the type docs), index and
    /// vector search enabled, no writes; 64 MiB memory, 10,000,000 fuel, 1 s timeout.
    pub fn read_only(tables: Vec<String>) -> Self {
        Self {
            can_read_table: tables,
            can_write_table: vec![],
            can_index_search: true,
            can_vector_search: true,
            can_call_plugin: vec![],
            memory_limit_bytes: 64 * 1024 * 1024,
            fuel_limit: 10_000_000,
            timeout_ms: 1000,
        }
    }

    /// Returns `true` if any read pattern grants access to `table_name`.
    pub fn can_read(&self, table_name: &str) -> bool {
        Self::any_pattern_matches(&self.can_read_table, table_name)
    }

    /// Returns `true` if any write pattern grants access to `table_name`.
    pub fn can_write(&self, table_name: &str) -> bool {
        Self::any_pattern_matches(&self.can_write_table, table_name)
    }

    /// Shared matcher for read/write grants; see the type-level docs for pattern syntax.
    fn any_pattern_matches(patterns: &[String], table_name: &str) -> bool {
        patterns.iter().any(|pattern| match pattern.strip_suffix('*') {
            // A bare "*" strips to the empty prefix, which every name starts with.
            Some(prefix) => table_name.starts_with(prefix),
            None => pattern == table_name,
        })
    }
}
/// Lifecycle state of a [`WasmPluginInstance`].
///
/// `WasmPluginInstance::call` only proceeds from `Ready`; any other state makes the call
/// fail with a "not ready" error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WasmPluginState {
    /// Freshly constructed by `WasmPluginInstance::new`, before `init`.
    Loading,
    /// Initialized and able to accept a call.
    Ready,
    /// A call is currently running on this instance.
    Executing,
    /// The instance trapped (e.g. ran out of fuel) and no longer accepts calls.
    Trapped,
    /// The registry is removing this instance.
    Unloading,
    /// Intended terminal state once unloading has finished.
    Unloaded,
}
/// Cumulative per-instance execution counters; a snapshot is returned by
/// `WasmPluginInstance::stats`.
#[derive(Debug, Clone, Default)]
pub struct WasmPluginStats {
    /// Number of calls recorded by the instance.
    pub total_calls: u64,
    /// Sum of the fuel charged for recorded calls.
    pub total_fuel_consumed: u64,
    /// Accumulated wall-clock time of recorded calls, in microseconds.
    pub total_execution_us: u64,
    /// Number of traps recorded by the instance.
    pub trap_count: u64,
    /// Peak memory use in bytes. NOTE(review): nothing in this file updates it.
    pub peak_memory_bytes: u64,
}
/// Construction-time settings for a [`WasmPluginInstance`].
#[derive(Debug, Clone)]
pub struct WasmInstanceConfig {
    /// Grants and resource limits applied to the instance.
    pub capabilities: WasmPluginCapabilities,
    /// Debug-mode flag (not read anywhere in this file).
    pub debug_mode: bool,
    /// When `true`, calls are refused once the instance's fuel counter reaches zero.
    pub enable_fuel: bool,
    /// Epoch-interruption flag (not read anywhere in this file).
    pub enable_epochs: bool,
    /// Epoch tick interval in milliseconds (not read anywhere in this file).
    pub epoch_interval_ms: u64,
}

impl Default for WasmInstanceConfig {
    /// Default capabilities, fuel and epochs enabled, debug off, 10 ms epoch interval.
    fn default() -> Self {
        let capabilities = WasmPluginCapabilities::default();
        Self {
            epoch_interval_ms: 10,
            enable_epochs: true,
            enable_fuel: true,
            debug_mode: false,
            capabilities,
        }
    }
}
/// A single (simulated) WASM plugin instance with fuel metering and call statistics.
pub struct WasmPluginInstance {
    // Plugin name, used in error messages and copied into `info`.
    name: String,
    // Current lifecycle state; calls require `Ready`.
    state: RwLock<WasmPluginState>,
    // Settings the instance was created with (capabilities, fuel flag, ...).
    config: WasmInstanceConfig,
    // Cumulative execution counters.
    stats: RwLock<WasmPluginStats>,
    // Remaining fuel; starts at `config.capabilities.fuel_limit` and is reset by `refuel`.
    fuel_remaining: AtomicU64,
    // Extension metadata exposed through the `Extension` trait wrapper.
    info: ExtensionInfo,
    // Fingerprint of the module bytes given at construction.
    module_hash: [u8; 32],
}
impl WasmPluginInstance {
    /// Creates an instance named `name` in the [`WasmPluginState::Loading`] state.
    ///
    /// The module bytes are only fingerprinted (see [`Self::module_hash`]); nothing is
    /// compiled or executed. The fuel counter starts at `config.capabilities.fuel_limit`.
    /// Call [`Self::init`] before [`Self::call`].
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`.
    pub fn new(name: &str, wasm_bytes: &[u8], config: WasmInstanceConfig) -> KernelResult<Self> {
        let module_hash = Self::compute_hash(wasm_bytes);
        // Read the limit before `config` is moved into the struct, so no clone is needed.
        let initial_fuel = config.capabilities.fuel_limit;
        Ok(Self {
            name: name.to_string(),
            state: RwLock::new(WasmPluginState::Loading),
            config,
            stats: RwLock::new(WasmPluginStats::default()),
            fuel_remaining: AtomicU64::new(initial_fuel),
            info: ExtensionInfo {
                name: name.to_string(),
                version: "1.0.0".to_string(),
                description: format!("WASM plugin: {}", name),
                author: "SochDB".to_string(),
                capabilities: vec![ExtensionCapability::Custom("wasm".to_string())],
            },
            module_hash,
        })
    }

    /// Marks the instance [`WasmPluginState::Ready`], whatever its current state.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`.
    pub fn init(&self) -> KernelResult<()> {
        *self.state.write() = WasmPluginState::Ready;
        Ok(())
    }

    /// Invokes the export `func_name` with `args`.
    ///
    /// The instance must be `Ready`. It is `Executing` for the duration of the call and
    /// goes back to `Ready` afterwards, unless another thread moved it to a different
    /// state (for example `Unloading`) in the meantime. Every call that starts executing
    /// is counted in [`Self::stats`] and debits fuel, even if the export itself fails.
    ///
    /// # Errors
    ///
    /// - the instance is not `Ready` (this includes another call already executing);
    /// - fuel metering is enabled and no fuel remains; the instance becomes `Trapped`;
    /// - the export name is unknown.
    pub fn call(&self, func_name: &str, args: &[WasmValue]) -> KernelResult<Vec<WasmValue>> {
        self.begin_call()?;
        let start = Instant::now();
        let result = self.simulate_call(func_name, args);

        // Simulated cost model: a fixed base plus a per-argument charge.
        let fuel_used = 100 + args.len() as u64 * 10;
        {
            let mut stats = self.stats.write();
            stats.total_calls += 1;
            stats.total_execution_us += start.elapsed().as_micros() as u64;
            stats.total_fuel_consumed += fuel_used;
        }
        self.consume_fuel(fuel_used);
        self.finish_call();
        result
    }

    /// Atomically moves `Ready -> Executing`, or `Ready -> Trapped` when fuel is exhausted.
    ///
    /// The check and the transition happen under one write lock, so two concurrent callers
    /// cannot both observe `Ready` and run at the same time.
    fn begin_call(&self) -> KernelResult<()> {
        let mut state = self.state.write();
        if *state != WasmPluginState::Ready {
            return Err(KernelError::Plugin {
                message: format!("plugin {} not ready, state: {:?}", self.name, *state),
            });
        }
        if self.config.enable_fuel && self.fuel_remaining.load(Ordering::Acquire) == 0 {
            *state = WasmPluginState::Trapped;
            // Release the state lock before touching stats; never hold both.
            drop(state);
            self.stats.write().trap_count += 1;
            return Err(KernelError::Plugin {
                message: format!("plugin {} exhausted fuel limit", self.name),
            });
        }
        *state = WasmPluginState::Executing;
        Ok(())
    }

    /// Debits `amount` fuel, saturating at zero instead of wrapping around.
    fn consume_fuel(&self, amount: u64) {
        // The closure always returns `Some`, so `fetch_update` cannot fail.
        let _ = self
            .fuel_remaining
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |fuel| {
                Some(fuel.saturating_sub(amount))
            });
    }

    /// Returns to `Ready` only if this call's `Executing` state is still in place.
    ///
    /// The check avoids overwriting a state another thread set mid-call, such as
    /// `Unloading`.
    fn finish_call(&self) {
        let mut state = self.state.write();
        if *state == WasmPluginState::Executing {
            *state = WasmPluginState::Ready;
        }
    }

    /// Resets the fuel counter to `config.capabilities.fuel_limit`.
    ///
    /// Does not change the state: an instance that is already `Trapped` stays `Trapped`.
    pub fn refuel(&self) {
        self.fuel_remaining
            .store(self.config.capabilities.fuel_limit, Ordering::Release);
    }

    /// Returns a snapshot of the cumulative call statistics.
    pub fn stats(&self) -> WasmPluginStats {
        self.stats.read().clone()
    }

    /// Returns the current lifecycle state.
    pub fn state(&self) -> WasmPluginState {
        *self.state.read()
    }

    /// Returns the plugin name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the capabilities this instance was configured with.
    pub fn capabilities(&self) -> &WasmPluginCapabilities {
        &self.config.capabilities
    }

    /// Returns the module fingerprint computed at construction (see `compute_hash`).
    pub fn module_hash(&self) -> &[u8; 32] {
        &self.module_hash
    }

    /// Stand-in for real WASM execution: fixed results for a few known exports.
    fn simulate_call(&self, func_name: &str, args: &[WasmValue]) -> KernelResult<Vec<WasmValue>> {
        match func_name {
            "on_insert" | "on_update" | "on_delete" => Ok(vec![WasmValue::I32(0)]),
            "get_metrics" => Ok(vec![WasmValue::F64(42.0)]),
            // Echoes the first argument, or `I32(0)` when there is none.
            "transform" => Ok(vec![args.first().cloned().unwrap_or(WasmValue::I32(0))]),
            _ => Err(KernelError::Plugin {
                message: format!("unknown function: {}", func_name),
            }),
        }
    }

    /// Placeholder fingerprint: the little-endian CRC32 of `bytes`, repeated 8 times.
    ///
    /// Only 32 bits are significant, so this is not collision-resistant.
    fn compute_hash(bytes: &[u8]) -> [u8; 32] {
        let word = crc32fast::hash(bytes).to_le_bytes();
        let mut hash = [0u8; 32];
        for chunk in hash.chunks_exact_mut(word.len()) {
            chunk.copy_from_slice(&word);
        }
        hash
    }
}
/// A value passed to or returned from a WASM export.
#[derive(Debug, Clone, PartialEq)]
pub enum WasmValue {
    /// 32-bit integer.
    I32(i32),
    /// 64-bit integer.
    I64(i64),
    /// 32-bit float.
    F32(f32),
    /// 64-bit float.
    F64(f64),
    /// Opaque reference handle (presumably host-side — confirm with the runtime).
    ExternRef(u64),
}

impl WasmValue {
    /// Returns the payload if this is `I32`, otherwise `None`.
    pub fn as_i32(&self) -> Option<i32> {
        if let Self::I32(raw) = *self {
            Some(raw)
        } else {
            None
        }
    }

    /// Returns the payload if this is `I64`, otherwise `None`.
    pub fn as_i64(&self) -> Option<i64> {
        if let Self::I64(raw) = *self {
            Some(raw)
        } else {
            None
        }
    }

    /// Returns the payload if this is `F32`, otherwise `None`.
    pub fn as_f32(&self) -> Option<f32> {
        if let Self::F32(raw) = *self {
            Some(raw)
        } else {
            None
        }
    }

    /// Returns the payload if this is `F64`, otherwise `None`.
    pub fn as_f64(&self) -> Option<f64> {
        if let Self::F64(raw) = *self {
            Some(raw)
        } else {
            None
        }
    }
}
pub struct WasmPluginRegistry {
plugins: RwLock<HashMap<String, Arc<WasmPluginInstance>>>,
load_order: RwLock<Vec<String>>,
total_calls: AtomicU64,
total_traps: AtomicU64,
}
impl Default for WasmPluginRegistry {
fn default() -> Self {
Self::new()
}
}
impl WasmPluginRegistry {
pub fn new() -> Self {
Self {
plugins: RwLock::new(HashMap::new()),
load_order: RwLock::new(Vec::new()),
total_calls: AtomicU64::new(0),
total_traps: AtomicU64::new(0),
}
}
pub fn load(
&self,
name: &str,
wasm_bytes: &[u8],
config: WasmInstanceConfig,
) -> KernelResult<()> {
if self.plugins.read().contains_key(name) {
return Err(KernelError::Plugin {
message: format!("plugin '{}' already registered", name),
});
}
let instance = WasmPluginInstance::new(name, wasm_bytes, config)?;
instance.init()?;
let arc = Arc::new(instance);
self.plugins.write().insert(name.to_string(), arc);
self.load_order.write().push(name.to_string());
Ok(())
}
pub fn load_from_file(
&self,
name: &str,
path: &Path,
config: WasmInstanceConfig,
) -> KernelResult<()> {
let wasm_bytes = std::fs::read(path).map_err(|e| KernelError::Plugin {
message: format!("failed to read WASM file: {}", e),
})?;
self.load(name, &wasm_bytes, config)
}
pub fn unload(&self, name: &str) -> KernelResult<()> {
let mut plugins = self.plugins.write();
if !plugins.contains_key(name) {
return Err(KernelError::Plugin {
message: format!("plugin '{}' not found", name),
});
}
if let Some(plugin) = plugins.get(name) {
*plugin.state.write() = WasmPluginState::Unloading;
}
plugins.remove(name);
self.load_order.write().retain(|n| n != name);
Ok(())
}
pub fn get(&self, name: &str) -> Option<Arc<WasmPluginInstance>> {
self.plugins.read().get(name).cloned()
}
pub fn call(
&self,
plugin_name: &str,
func_name: &str,
args: &[WasmValue],
) -> KernelResult<Vec<WasmValue>> {
let plugin = self.get(plugin_name).ok_or_else(|| KernelError::Plugin {
message: format!("plugin '{}' not found", plugin_name),
})?;
self.total_calls.fetch_add(1, Ordering::Relaxed);
match plugin.call(func_name, args) {
Ok(result) => Ok(result),
Err(e) => {
self.total_traps.fetch_add(1, Ordering::Relaxed);
Err(e)
}
}
}
pub fn list(&self) -> Vec<String> {
self.load_order.read().clone()
}
pub fn count(&self) -> usize {
self.plugins.read().len()
}
pub fn global_stats(&self) -> (u64, u64) {
(
self.total_calls.load(Ordering::Relaxed),
self.total_traps.load(Ordering::Relaxed),
)
}
pub fn shutdown_all(&self) -> KernelResult<()> {
let order = self.load_order.read().clone();
for name in order.iter().rev() {
if let Err(e) = self.unload(name) {
eprintln!("warning: failed to unload plugin {}: {}", name, e);
}
}
Ok(())
}
}
pub struct WasmObservabilityPlugin {
instance: Arc<WasmPluginInstance>,
}
impl WasmObservabilityPlugin {
pub fn new(instance: Arc<WasmPluginInstance>) -> Self {
Self { instance }
}
}
impl Extension for WasmObservabilityPlugin {
fn info(&self) -> ExtensionInfo {
self.instance.info.clone()
}
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
}
impl ObservabilityExtension for WasmObservabilityPlugin {
fn counter_inc(&self, name: &str, value: u64, labels: &[(&str, &str)]) {
let _ = self.instance.call(
"counter_inc",
&[
WasmValue::I64(name.as_ptr() as i64),
WasmValue::I32(name.len() as i32),
WasmValue::I64(value as i64),
WasmValue::I32(labels.len() as i32),
],
);
}
fn gauge_set(&self, name: &str, value: f64, labels: &[(&str, &str)]) {
let _ = self.instance.call(
"gauge_set",
&[
WasmValue::I64(name.as_ptr() as i64),
WasmValue::I32(name.len() as i32),
WasmValue::F64(value),
WasmValue::I32(labels.len() as i32),
],
);
}
fn histogram_observe(&self, name: &str, value: f64, labels: &[(&str, &str)]) {
let _ = self.instance.call(
"histogram_observe",
&[
WasmValue::I64(name.as_ptr() as i64),
WasmValue::I32(name.len() as i32),
WasmValue::F64(value),
WasmValue::I32(labels.len() as i32),
],
);
}
fn span_start(&self, name: &str, parent: Option<u64>) -> u64 {
match self.instance.call(
"span_start",
&[
WasmValue::I64(name.as_ptr() as i64),
WasmValue::I32(name.len() as i32),
WasmValue::I64(parent.unwrap_or(0) as i64),
],
) {
Ok(results) => results.first().and_then(|v| v.as_i64()).unwrap_or(0) as u64,
Err(_) => 0,
}
}
fn span_end(&self, span_id: u64) {
let _ = self
.instance
.call("span_end", &[WasmValue::I64(span_id as i64)]);
}
fn span_event(&self, span_id: u64, name: &str, _attributes: &[(&str, &str)]) {
let _ = self.instance.call(
"span_event",
&[
WasmValue::I64(span_id as i64),
WasmValue::I64(name.as_ptr() as i64),
WasmValue::I32(name.len() as i32),
],
);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_capabilities_default() {
        let caps = WasmPluginCapabilities::default();
        assert_eq!(caps.memory_limit_bytes, 16 * 1024 * 1024);
        assert_eq!(caps.fuel_limit, 1_000_000);
        assert!(!caps.can_read("any_table"));
        assert!(!caps.can_write("any_table"));
    }

    #[test]
    fn test_capabilities_read_patterns() {
        let caps = WasmPluginCapabilities {
            can_read_table: vec!["users".to_string(), "logs_*".to_string()],
            ..Default::default()
        };
        assert!(caps.can_read("users"));
        assert!(caps.can_read("logs_2024"));
        assert!(caps.can_read("logs_"));
        assert!(!caps.can_read("orders"));
    }

    #[test]
    fn test_capabilities_write_patterns() {
        let caps = WasmPluginCapabilities {
            can_write_table: vec!["audit".to_string(), "tmp_*".to_string()],
            ..Default::default()
        };
        assert!(caps.can_write("audit"));
        assert!(caps.can_write("tmp_1"));
        // The prefix before `*` must match in full.
        assert!(!caps.can_write("tmp"));
        // Write grants do not imply read grants.
        assert!(!caps.can_read("audit"));
    }

    #[test]
    fn test_capabilities_wildcard() {
        let caps = WasmPluginCapabilities {
            can_read_table: vec!["*".to_string()],
            ..Default::default()
        };
        assert!(caps.can_read("any_table"));
        assert!(caps.can_read("another_table"));
    }

    #[test]
    fn test_wasm_instance_creation() {
        let config = WasmInstanceConfig::default();
        let instance = WasmPluginInstance::new("test_plugin", b"fake wasm bytes", config).unwrap();
        assert_eq!(instance.name(), "test_plugin");
        assert_eq!(instance.state(), WasmPluginState::Loading);
        instance.init().unwrap();
        assert_eq!(instance.state(), WasmPluginState::Ready);
    }

    #[test]
    fn test_call_before_init_fails() {
        let config = WasmInstanceConfig::default();
        let instance = WasmPluginInstance::new("test_plugin", b"fake wasm bytes", config).unwrap();
        assert!(instance.call("on_insert", &[]).is_err());
        // Rejected calls are not recorded.
        assert_eq!(instance.stats().total_calls, 0);
        assert_eq!(instance.state(), WasmPluginState::Loading);
    }

    #[test]
    fn test_wasm_instance_call() {
        let config = WasmInstanceConfig::default();
        let instance = WasmPluginInstance::new("test_plugin", b"fake wasm bytes", config).unwrap();
        instance.init().unwrap();
        let result = instance.call("on_insert", &[WasmValue::I32(42)]).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], WasmValue::I32(0));
        let stats = instance.stats();
        assert_eq!(stats.total_calls, 1);
        assert!(stats.total_fuel_consumed > 0);
        assert_eq!(instance.state(), WasmPluginState::Ready);
    }

    #[test]
    fn test_unknown_function_keeps_instance_ready() {
        let config = WasmInstanceConfig::default();
        let instance = WasmPluginInstance::new("test_plugin", b"fake wasm bytes", config).unwrap();
        instance.init().unwrap();
        assert!(instance.call("no_such_export", &[]).is_err());
        assert_eq!(instance.state(), WasmPluginState::Ready);
        assert_eq!(instance.stats().total_calls, 1);
    }

    #[test]
    fn test_wasm_registry() {
        let registry = WasmPluginRegistry::new();
        registry
            .load("plugin1", b"fake wasm", WasmInstanceConfig::default())
            .unwrap();
        assert_eq!(registry.count(), 1);
        assert_eq!(registry.list(), vec!["plugin1".to_string()]);
        let result = registry.call("plugin1", "on_insert", &[]).unwrap();
        assert_eq!(result[0], WasmValue::I32(0));
        let (calls, traps) = registry.global_stats();
        assert_eq!(calls, 1);
        assert_eq!(traps, 0);
        registry.unload("plugin1").unwrap();
        assert_eq!(registry.count(), 0);
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_wasm_registry_duplicate() {
        let registry = WasmPluginRegistry::new();
        registry
            .load("plugin1", b"fake wasm", WasmInstanceConfig::default())
            .unwrap();
        let result = registry.load("plugin1", b"fake wasm", WasmInstanceConfig::default());
        assert!(result.is_err());
        // The rejected load must not leave a second entry behind.
        assert_eq!(registry.count(), 1);
        assert_eq!(registry.list().len(), 1);
    }

    #[test]
    fn test_wasm_registry_missing_plugin() {
        let registry = WasmPluginRegistry::new();
        assert!(registry.unload("missing").is_err());
        assert!(registry.call("missing", "on_insert", &[]).is_err());
        // Calls to unknown plugins are not counted.
        assert_eq!(registry.global_stats(), (0, 0));
    }

    #[test]
    fn test_wasm_registry_shutdown_all() {
        let registry = WasmPluginRegistry::new();
        registry
            .load("a", b"fake wasm a", WasmInstanceConfig::default())
            .unwrap();
        registry
            .load("b", b"fake wasm b", WasmInstanceConfig::default())
            .unwrap();
        assert_eq!(registry.list(), vec!["a".to_string(), "b".to_string()]);
        registry.shutdown_all().unwrap();
        assert_eq!(registry.count(), 0);
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_wasm_value_conversions() {
        let v = WasmValue::I32(42);
        assert_eq!(v.as_i32(), Some(42));
        assert_eq!(v.as_i64(), None);
        let v = WasmValue::F64(2.5);
        assert_eq!(v.as_f64(), Some(2.5));
        assert_eq!(v.as_f32(), None);
    }

    #[test]
    fn test_fuel_exhaustion() {
        let config = WasmInstanceConfig {
            capabilities: WasmPluginCapabilities {
                fuel_limit: 100,
                ..Default::default()
            },
            enable_fuel: true,
            ..Default::default()
        };
        let instance = WasmPluginInstance::new("test", b"fake wasm", config).unwrap();
        instance.init().unwrap();
        // The first call fits the 100-fuel budget exactly and drains it.
        assert!(instance.call("on_insert", &[]).is_ok());
        let result = instance.call("on_insert", &[]);
        assert!(result.is_err());
        assert_eq!(instance.state(), WasmPluginState::Trapped);
    }

    #[test]
    fn test_refuel() {
        let config = WasmInstanceConfig {
            capabilities: WasmPluginCapabilities {
                fuel_limit: 150,
                ..Default::default()
            },
            enable_fuel: true,
            ..Default::default()
        };
        let instance = WasmPluginInstance::new("test", b"fake wasm", config).unwrap();
        instance.init().unwrap();
        let _ = instance.call("on_insert", &[]);
        instance.refuel();
        let result = instance.call("on_insert", &[]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_observability_wrapper() {
        let config = WasmInstanceConfig::default();
        let instance =
            Arc::new(WasmPluginInstance::new("obs_plugin", b"fake wasm", config).unwrap());
        instance.init().unwrap();
        let wrapper = WasmObservabilityPlugin::new(instance.clone());
        wrapper.counter_inc("test_counter", 1, &[]);
        wrapper.gauge_set("test_gauge", 42.0, &[]);
        wrapper.histogram_observe("test_histogram", 0.5, &[]);
        let span = wrapper.span_start("test_span", None);
        wrapper.span_event(span, "event1", &[]);
        wrapper.span_end(span);
        // The simulated instance exports none of these hooks, so span_start falls back to 0.
        assert_eq!(span, 0);
        // Every hook reached the instance, and failed exports leave it usable.
        assert_eq!(instance.stats().total_calls, 6);
        assert_eq!(instance.state(), WasmPluginState::Ready);
    }
}