use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use regex::Regex;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use wasmtime::*;
use super::types::{ModuleAction, RequestInfo, ResponseInfo};
use crate::charter::{Module as ModuleConfig, Phase};
/// One instantiated request- or response-phase module together with the state needed to run it.
struct LoadedModule {
/// Module name from its configuration; used as the `module` field in log events.
name: String,
/// Compiled module, kept alive alongside its instance (not read after loading).
_module: Module,
/// Instance whose `on_request` / `on_response` export is invoked.
instance: Instance,
/// Store the instance was created in; its data is the module's host-side `ModuleState`.
store: Store<ModuleState>,
/// Compiled route patterns; an empty list means the module applies to every path.
routes: Vec<Regex>,
/// Phase (request or response) in which this module runs.
phase: Phase,
}
/// Host-side data attached to every wasm `Store`; host functions reach it through
/// `Caller::data` / `Caller::data_mut`.
struct ModuleState {
/// Serialized request for the current call, exposed via `env.get_request` / `env.read_request`.
request_json: String,
/// Serialized response for the current call, exposed via `env.get_response` / `env.read_response`.
response_json: String,
/// Action JSON written by `env.set_action` / `env.set_block*`; `ModuleRuntime` resets it to
/// `{"Continue":null}` before each `on_request` / `on_response` call.
action_json: String,
/// Module configuration as JSON, exposed via `env.get_config` / `env.read_config`.
config_json: String,
/// `(conn_id, payload)` pairs queued by `mothership.host_send`.
pending_sends: Vec<(String, String)>,
/// `(stream, payload)` pairs queued by `mothership.host_broadcast`.
pending_broadcasts: Vec<(String, String)>,
/// `(conn_id, code, reason)` triples queued by `mothership.host_close`.
pending_closes: Vec<(String, i32, String)>,
/// Limits installed with `Store::limiter`; caps the module's linear memory size.
limiter: StoreLimits,
/// Creation time; `env.get_time` reports the seconds elapsed since this instant.
start_time: Instant,
}
impl Default for ModuleState {
    /// Builds a fresh state: empty request/response/action buffers, an empty-object config,
    /// no queued WebSocket effects, a memory cap of `WASM_MAX_MEMORY_PAGES` 64 KiB pages, and
    /// a clock started at construction time.
    fn default() -> Self {
        let max_memory_bytes = WASM_MAX_MEMORY_PAGES as usize * 65536;
        let limiter = StoreLimitsBuilder::new()
            .memory_size(max_memory_bytes)
            .build();

        Self {
            request_json: String::default(),
            response_json: String::default(),
            action_json: String::default(),
            config_json: "{}".to_owned(),
            pending_sends: Vec::default(),
            pending_broadcasts: Vec::default(),
            pending_closes: Vec::default(),
            limiter,
            start_time: Instant::now(),
        }
    }
}
/// Reads `len` bytes of UTF-8 text starting at guest address `ptr` from the calling module's
/// exported `memory`.
///
/// A null pointer or a non-positive length yields `Some("")`. Returns `None` if the module has
/// no `memory` export, the range lies outside guest memory, or the bytes are not valid UTF-8.
///
/// Guest pointers are wasm32 addresses, so `ptr` is reinterpreted as `u32` and the end of the
/// range is computed with checked arithmetic. Previously a negative `ptr` became a huge `usize`.
/// Either `start + len` overflowed (a debug-build panic), or the wrapped `end` produced an
/// inverted slice range (a release-build panic). Either way, a hostile guest could crash the host.
fn read_wasm_string(caller: &mut Caller<'_, ModuleState>, ptr: i32, len: i32) -> Option<String> {
    if ptr == 0 || len <= 0 {
        return Some(String::new());
    }
    let Some(Extern::Memory(memory)) = caller.get_export("memory") else {
        return None;
    };
    let start = ptr as u32 as usize;
    // `len` is known positive here, so the cast is lossless.
    let end = start.checked_add(len as usize)?;
    let bytes = memory.data(&*caller).get(start..end)?;
    std::str::from_utf8(bytes).ok().map(str::to_owned)
}
/// Runs configured WASM modules against requests and responses, selected by phase and path.
pub struct ModuleRuntime {
/// Engine shared by every module (configured with async support and fuel metering in `new`).
engine: Engine,
/// Loaded modules in load order. Phases take the write lock because calling into a module
/// needs `&mut Store`.
modules: Arc<RwLock<Vec<LoadedModule>>>,
}
/// Upper bound on a module's linear memory, in 64 KiB wasm pages (256 pages = 16 MiB).
const WASM_MAX_MEMORY_PAGES: u64 = 256;

/// Fuel budget installed with `Store::set_fuel` before every guest call.
const WASM_FUEL_LIMIT: u64 = 1_000_000;
impl ModuleRuntime {
pub fn new() -> Result<Self> {
let mut config = Config::new();
config.async_support(true);
config.consume_fuel(true);
let engine = Engine::new(&config)?;
Ok(Self {
engine,
modules: Arc::new(RwLock::new(Vec::new())),
})
}
pub async fn load_modules(&self, configs: &[ModuleConfig]) -> anyhow::Result<()> {
let mut modules = self.modules.write().await;
let mut failures = Vec::new();
for config in configs {
match self.load_module(config).await {
Ok(module) => {
info!(name = %config.name, wasm = %config.wasm, "Loaded WASM module");
modules.push(module);
}
Err(e) => {
error!(name = %config.name, error = %e, "Failed to load WASM module");
failures.push(format!("{} ({})", config.name, e));
}
}
}
if !failures.is_empty() {
anyhow::bail!(
"One or more WASM modules failed to load: {}",
failures.join(", ")
);
}
Ok(())
}
async fn load_module(&self, config: &ModuleConfig) -> anyhow::Result<LoadedModule> {
let path = Path::new(&config.wasm);
if !path.exists() {
anyhow::bail!("Module file not found: {}", config.wasm);
}
let routes: Vec<Regex> = config
.routes
.iter()
.filter_map(|r| match Regex::new(r) {
Ok(regex) => Some(regex),
Err(e) => {
warn!(pattern = %r, error = %e, "Invalid route pattern");
None
}
})
.collect();
let module = Module::from_file(&self.engine, path)?;
let mut store = Store::new(&self.engine, ModuleState::default());
store.data_mut().config_json =
serde_json::to_string(&config.config).unwrap_or_else(|_| "{}".to_string());
store.limiter(|state| &mut state.limiter);
let mut linker = Linker::new(&self.engine);
linker.func_wrap(
"env",
"get_request",
|caller: Caller<'_, ModuleState>| -> i32 {
let state = caller.data();
state.request_json.len() as i32
},
)?;
linker.func_wrap(
"env",
"read_request",
|mut caller: Caller<'_, ModuleState>, ptr: i32, max_len: i32| -> i32 {
let json = caller.data().request_json.clone();
let bytes = json.as_bytes();
let len = bytes.len().min(max_len as usize);
if let Some(Extern::Memory(memory)) = caller.get_export("memory")
&& let Some(slice) = memory
.data_mut(&mut caller)
.get_mut(ptr as usize..(ptr as usize + len))
{
slice.copy_from_slice(&bytes[..len]);
return len as i32;
}
-1
},
)?;
linker.func_wrap(
"env",
"get_response",
|caller: Caller<'_, ModuleState>| -> i32 {
let state = caller.data();
state.response_json.len() as i32
},
)?;
linker.func_wrap(
"env",
"read_response",
|mut caller: Caller<'_, ModuleState>, ptr: i32, max_len: i32| -> i32 {
let json = caller.data().response_json.clone();
let bytes = json.as_bytes();
let len = bytes.len().min(max_len as usize);
if let Some(Extern::Memory(memory)) = caller.get_export("memory")
&& let Some(slice) = memory
.data_mut(&mut caller)
.get_mut(ptr as usize..(ptr as usize + len))
{
slice.copy_from_slice(&bytes[..len]);
return len as i32;
}
-1
},
)?;
linker.func_wrap(
"env",
"get_config",
|caller: Caller<'_, ModuleState>| -> i32 {
let state = caller.data();
state.config_json.len() as i32
},
)?;
linker.func_wrap(
"env",
"read_config",
|mut caller: Caller<'_, ModuleState>, ptr: i32, max_len: i32| -> i32 {
let json = caller.data().config_json.clone();
let bytes = json.as_bytes();
let len = bytes.len().min(max_len as usize);
if let Some(Extern::Memory(memory)) = caller.get_export("memory")
&& let Some(slice) = memory
.data_mut(&mut caller)
.get_mut(ptr as usize..(ptr as usize + len))
{
slice.copy_from_slice(&bytes[..len]);
return len as i32;
}
-1
},
)?;
linker.func_wrap(
"env",
"set_action",
|mut caller: Caller<'_, ModuleState>, action: i32| {
let state = caller.data_mut();
state.action_json = match action {
0 => r#"{"Continue":null}"#.to_string(),
1 => r#"{"Block":{"status":403,"body":"Blocked by module"}}"#.to_string(),
_ => r#"{"Continue":null}"#.to_string(),
};
},
)?;
linker.func_wrap(
"env",
"set_block",
|mut caller: Caller<'_, ModuleState>, status: i32, body_ptr: i32, body_len: i32| {
let body = read_wasm_string(&mut caller, body_ptr, body_len).unwrap_or_default();
let state = caller.data_mut();
let escaped_body =
serde_json::to_string(&body).unwrap_or_else(|_| "\"\"".to_string());
let escaped_body = &escaped_body[1..escaped_body.len() - 1];
state.action_json = format!(
r#"{{"Block":{{"status":{},"body":"{}","headers":null}}}}"#,
status, escaped_body
);
},
)?;
linker.func_wrap(
"env",
"set_block_with_headers",
|mut caller: Caller<'_, ModuleState>,
status: i32,
body_ptr: i32,
body_len: i32,
headers_ptr: i32,
headers_len: i32| {
let body = read_wasm_string(&mut caller, body_ptr, body_len).unwrap_or_default();
let headers_json =
read_wasm_string(&mut caller, headers_ptr, headers_len).unwrap_or_default();
let state = caller.data_mut();
let escaped_body =
serde_json::to_string(&body).unwrap_or_else(|_| "\"\"".to_string());
let escaped_body = &escaped_body[1..escaped_body.len() - 1];
state.action_json = format!(
r#"{{"Block":{{"status":{},"body":"{}","headers":{}}}}}"#,
status, escaped_body, headers_json
);
},
)?;
linker.func_wrap(
"env",
"log",
|_caller: Caller<'_, ModuleState>, level: i32, ptr: i32, len: i32| {
debug!(level = level, ptr = ptr, len = len, "Module log");
},
)?;
linker.func_wrap(
"env",
"get_time",
|caller: Caller<'_, ModuleState>| -> f64 {
caller.data().start_time.elapsed().as_secs_f64()
},
)?;
linker.func_wrap(
"mothership",
"host_log",
|mut caller: Caller<'_, ModuleState>, level: i32, msg_ptr: i32, msg_len: i32| -> i32 {
let message = read_wasm_string(&mut caller, msg_ptr, msg_len)
.unwrap_or_else(|| "<invalid utf-8>".to_string());
match level {
0 => debug!(target: "mothership::plugin", "{}", message),
1 => warn!(target: "mothership::plugin", "{}", message),
2 => info!(target: "mothership::plugin", "{}", message),
_ => debug!(target: "mothership::plugin", "{}", message),
}
0 },
)?;
linker.func_wrap(
"mothership",
"host_send",
|mut caller: Caller<'_, ModuleState>,
conn_id_ptr: i32,
conn_id_len: i32,
payload_ptr: i32,
payload_len: i32|
-> i32 {
let conn_id = match read_wasm_string(&mut caller, conn_id_ptr, conn_id_len) {
Some(s) => s,
None => return -2,
};
let payload = match read_wasm_string(&mut caller, payload_ptr, payload_len) {
Some(s) => s,
None => return -3,
};
debug!(target: "mothership::plugin", conn_id = %conn_id, payload_len = payload.len(), "host_send");
caller.data_mut().pending_sends.push((conn_id, payload));
0 },
)?;
linker.func_wrap(
"mothership",
"host_broadcast",
|mut caller: Caller<'_, ModuleState>,
stream_ptr: i32,
stream_len: i32,
payload_ptr: i32,
payload_len: i32|
-> i32 {
let stream = match read_wasm_string(&mut caller, stream_ptr, stream_len) {
Some(s) => s,
None => return -2,
};
let payload = match read_wasm_string(&mut caller, payload_ptr, payload_len) {
Some(s) => s,
None => return -3,
};
debug!(target: "mothership::plugin", stream = %stream, payload_len = payload.len(), "host_broadcast");
caller.data_mut().pending_broadcasts.push((stream, payload));
0 },
)?;
linker.func_wrap(
"mothership",
"host_close",
|mut caller: Caller<'_, ModuleState>,
conn_id_ptr: i32,
conn_id_len: i32,
code: i32,
reason_ptr: i32,
reason_len: i32|
-> i32 {
let conn_id = match read_wasm_string(&mut caller, conn_id_ptr, conn_id_len) {
Some(s) => s,
None => return -2,
};
let reason = read_wasm_string(&mut caller, reason_ptr, reason_len)
.unwrap_or_default();
debug!(target: "mothership::plugin", conn_id = %conn_id, code = code, reason = %reason, "host_close");
caller.data_mut().pending_closes.push((conn_id, code, reason));
0 },
)?;
let instance = linker.instantiate_async(&mut store, &module).await?;
Ok(LoadedModule {
name: config.name.clone(),
_module: module,
instance,
store,
routes,
phase: config.phase,
})
}
pub async fn process_request(&self, path: &str, request: &RequestInfo) -> ModuleAction {
let mut modules = self.modules.write().await;
for module in modules.iter_mut() {
if module.phase != Phase::Request {
continue;
}
if !module.routes.is_empty() && !module.routes.iter().any(|r| r.is_match(path)) {
continue;
}
module.store.data_mut().request_json =
serde_json::to_string(request).unwrap_or_default();
module.store.data_mut().action_json = r#"{"Continue":null}"#.to_string();
if let Some(func) = module.instance.get_func(&mut module.store, "on_request") {
let _ = module.store.set_fuel(WASM_FUEL_LIMIT);
match func.call_async(&mut module.store, &[], &mut []).await {
Ok(_) => {
let json = &module.store.data().action_json;
let action: ModuleAction = match serde_json::from_str(json) {
Ok(action) => action,
Err(e) => {
error!(
module = %module.name,
error = %e,
"Module returned invalid action JSON; failing closed"
);
return ModuleAction::Block {
status: 503,
body: "Module processing failure".to_string(),
headers: None,
};
}
};
match &action {
ModuleAction::Continue => continue,
ModuleAction::Modify { .. } => return action,
ModuleAction::Block { .. } => return action,
}
}
Err(e) => {
error!(
module = %module.name,
error = %e,
"Module on_request failed; failing closed"
);
return ModuleAction::Block {
status: 503,
body: "Module processing failure".to_string(),
headers: None,
};
}
}
}
}
ModuleAction::Continue
}
pub async fn process_response(&self, path: &str, response: &ResponseInfo) -> ModuleAction {
let mut modules = self.modules.write().await;
for module in modules.iter_mut() {
if module.phase != Phase::Response {
continue;
}
if !module.routes.is_empty() && !module.routes.iter().any(|r| r.is_match(path)) {
continue;
}
module.store.data_mut().response_json =
serde_json::to_string(response).unwrap_or_default();
module.store.data_mut().action_json = r#"{"Continue":null}"#.to_string();
let func = module.instance.get_func(&mut module.store, "on_response");
if let Some(func) = func {
let _ = module.store.set_fuel(WASM_FUEL_LIMIT);
match func.call_async(&mut module.store, &[], &mut []).await {
Ok(_) => {
let json = &module.store.data().action_json;
let action: ModuleAction = match serde_json::from_str(json) {
Ok(action) => action,
Err(e) => {
error!(
module = %module.name,
error = %e,
"Module returned invalid action JSON; failing closed"
);
return ModuleAction::Block {
status: 503,
body: "Module processing failure".to_string(),
headers: None,
};
}
};
match &action {
ModuleAction::Continue => continue,
ModuleAction::Modify { .. } => return action,
ModuleAction::Block { .. } => return action,
}
}
Err(e) => {
error!(
module = %module.name,
error = %e,
"Module on_response failed; failing closed"
);
return ModuleAction::Block {
status: 503,
body: "Module processing failure".to_string(),
headers: None,
};
}
}
}
}
ModuleAction::Continue
}
pub async fn module_names(&self) -> Vec<String> {
self.modules
.read()
.await
.iter()
.map(|m| m.name.clone())
.collect()
}
}
impl Default for ModuleRuntime {
    /// Equivalent to [`ModuleRuntime::new`], panicking if the engine cannot be created.
    fn default() -> Self {
        match Self::new() {
            Ok(runtime) => runtime,
            Err(e) => panic!("Failed to create module runtime: {e:?}"),
        }
    }
}
/// Kind of WebSocket event delivered to a plugin's `wasm_on_event` export.
///
/// The explicit discriminants are part of the guest ABI: the variant is passed to the guest
/// as the first `i32` argument.
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsEventType {
    /// A connection was opened (`0`).
    Connect = 0,
    /// A message arrived on a connection (`1`).
    Message = 1,
    /// A connection was closed (`2`).
    Disconnect = 2,
}
/// Side effects a plugin requested during one `on_event` call, each list in request order.
#[derive(Default, Debug)]
pub struct WsEventResult {
    /// `(conn_id, payload)` pairs queued by `host_send`.
    pub sends: Vec<(String, String)>,
    /// `(stream, payload)` pairs queued by `host_broadcast`.
    pub broadcasts: Vec<(String, String)>,
    /// `(conn_id, code, reason)` triples queued by `host_close`.
    pub closes: Vec<(String, i32, String)>,
}
/// Hosts a single WebSocket plugin module and dispatches connection events to it.
pub struct WsPluginRuntime {
/// Engine created in `new` with fuel metering and without async support.
engine: Engine,
/// Compiled plugin; `None` until `load` succeeds.
module: Option<Module>,
/// Plugin instance; `None` until `load` succeeds.
instance: Option<Instance>,
/// Store the plugin was instantiated in; `None` until `load` succeeds.
store: Option<Store<ModuleState>>,
/// Name passed to `load` (empty before that); used in log fields.
name: String,
}
impl WsPluginRuntime {
    /// Creates an empty runtime whose engine meters fuel.
    ///
    /// Async support is intentionally not enabled: `on_event` drives the guest with synchronous
    /// calls, and `load` therefore instantiates synchronously as well.
    ///
    /// # Errors
    /// Returns an error if wasmtime rejects the engine configuration.
    pub fn new() -> Result<Self> {
        let mut config = Config::new();
        config.consume_fuel(true);
        let engine = Engine::new(&config)?;
        Ok(Self {
            engine,
            module: None,
            instance: None,
            store: None,
            name: String::new(),
        })
    }

    /// Compiles the plugin at `path`, links the `mothership` host imports, and instantiates it.
    ///
    /// On success this replaces any previously loaded plugin. On failure the previous plugin
    /// (if any) is left untouched.
    ///
    /// # Errors
    /// Fails if the file is missing or if compilation, linking, or instantiation fails.
    pub async fn load(&mut self, path: &Path, name: &str) -> anyhow::Result<()> {
        if !path.exists() {
            anyhow::bail!("Plugin file not found: {}", path.display());
        }
        let module = Module::from_file(&self.engine, path)?;
        let mut store = Store::new(&self.engine, ModuleState::default());
        store.limiter(|state| &mut state.limiter);
        let mut linker = Linker::new(&self.engine);
        linker.func_wrap(
            "mothership",
            "host_log",
            |mut caller: Caller<'_, ModuleState>, level: i32, msg_ptr: i32, msg_len: i32| -> i32 {
                let message = read_wasm_string(&mut caller, msg_ptr, msg_len)
                    .unwrap_or_else(|| "<invalid utf-8>".to_string());
                match level {
                    0 => debug!(target: "mothership::plugin", "{}", message),
                    1 => warn!(target: "mothership::plugin", "{}", message),
                    2 => info!(target: "mothership::plugin", "{}", message),
                    _ => debug!(target: "mothership::plugin", "{}", message),
                }
                0
            },
        )?;
        linker.func_wrap(
            "mothership",
            "host_send",
            |mut caller: Caller<'_, ModuleState>,
             conn_id_ptr: i32,
             conn_id_len: i32,
             payload_ptr: i32,
             payload_len: i32|
             -> i32 {
                let conn_id = match read_wasm_string(&mut caller, conn_id_ptr, conn_id_len) {
                    Some(s) => s,
                    None => return -2,
                };
                let payload = match read_wasm_string(&mut caller, payload_ptr, payload_len) {
                    Some(s) => s,
                    None => return -3,
                };
                caller.data_mut().pending_sends.push((conn_id, payload));
                0
            },
        )?;
        linker.func_wrap(
            "mothership",
            "host_broadcast",
            |mut caller: Caller<'_, ModuleState>,
             stream_ptr: i32,
             stream_len: i32,
             payload_ptr: i32,
             payload_len: i32|
             -> i32 {
                let stream = match read_wasm_string(&mut caller, stream_ptr, stream_len) {
                    Some(s) => s,
                    None => return -2,
                };
                let payload = match read_wasm_string(&mut caller, payload_ptr, payload_len) {
                    Some(s) => s,
                    None => return -3,
                };
                caller.data_mut().pending_broadcasts.push((stream, payload));
                0
            },
        )?;
        linker.func_wrap(
            "mothership",
            "host_close",
            |mut caller: Caller<'_, ModuleState>,
             conn_id_ptr: i32,
             conn_id_len: i32,
             code: i32,
             reason_ptr: i32,
             reason_len: i32|
             -> i32 {
                let conn_id = match read_wasm_string(&mut caller, conn_id_ptr, conn_id_len) {
                    Some(s) => s,
                    None => return -2,
                };
                let reason =
                    read_wasm_string(&mut caller, reason_ptr, reason_len).unwrap_or_default();
                caller
                    .data_mut()
                    .pending_closes
                    .push((conn_id, code, reason));
                0
            },
        )?;
        // The engine was built without async support (see `new`), so instantiation must be
        // synchronous. `instantiate_async` requires an async-enabled store, which meant every
        // plugin load failed here.
        let instance = linker.instantiate(&mut store, &module)?;
        self.module = Some(module);
        self.instance = Some(instance);
        self.store = Some(store);
        self.name = name.to_string();
        info!(name = %name, path = %path.display(), "Loaded WebSocket plugin");
        Ok(())
    }

    /// Delivers one WebSocket event to the plugin's `wasm_on_event` export and returns the
    /// sends, broadcasts and closes the plugin queued during the call.
    ///
    /// The export's signature is
    /// `(event: i32, conn_ptr: i32, conn_len: i32, payload_ptr: i32, payload_len: i32) -> i32`.
    /// A non-zero return value is logged as a warning but does not fail the call.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - no plugin is loaded;
    /// - the plugin lacks a `memory` or `wasm_on_event` export;
    /// - guest memory cannot be grown to fit the inputs;
    /// - the guest call itself fails.
    pub fn on_event(
        &mut self,
        event_type: WsEventType,
        conn_id: &str,
        payload: &str,
    ) -> anyhow::Result<WsEventResult> {
        const WASM_PAGE_SIZE: usize = 65536;

        let store = self
            .store
            .as_mut()
            .ok_or_else(|| anyhow::anyhow!("Plugin not loaded"))?;
        let instance = self
            .instance
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Plugin not loaded"))?;
        {
            let state = store.data_mut();
            state.pending_sends.clear();
            state.pending_broadcasts.clear();
            state.pending_closes.clear();
        }
        let memory = instance
            .get_memory(&mut *store, "memory")
            .ok_or_else(|| anyhow::anyhow!("No memory export"))?;
        let on_event = instance
            .get_typed_func::<(i32, i32, i32, i32, i32), i32>(&mut *store, "wasm_on_event")
            .map_err(|e| anyhow::anyhow!("wasm_on_event not found: {}", e))?;

        let conn_id_bytes = conn_id.as_bytes();
        let payload_bytes = payload.as_bytes();
        // NOTE(review): the inputs are written at a fixed guest address (1024) that the guest
        // never allocated. Depending on the guest toolchain, that region may hold the guest's
        // own static data or stack. A guest-exported allocator would be safer, but the layout
        // is part of the plugin ABI, so it is kept as-is.
        let conn_id_ptr = 1024i32;
        let payload_ptr = conn_id_ptr + conn_id_bytes.len() as i32 + 16;
        let required_size = (payload_ptr as usize) + payload_bytes.len() + 1024;
        let current_size = memory.data_size(&*store);
        if required_size > current_size {
            let pages_needed = (required_size - current_size).div_ceil(WASM_PAGE_SIZE);
            memory.grow(&mut *store, pages_needed as u64)?;
        }
        memory.write(&mut *store, conn_id_ptr as usize, conn_id_bytes)?;
        memory.write(&mut *store, payload_ptr as usize, payload_bytes)?;

        // Fuel metering is enabled on the engine in `new`, so this cannot fail.
        let _ = store.set_fuel(WASM_FUEL_LIMIT);
        let result = on_event.call(
            &mut *store,
            (
                event_type as i32,
                conn_id_ptr,
                conn_id_bytes.len() as i32,
                payload_ptr,
                payload_bytes.len() as i32,
            ),
        )?;
        if result != 0 {
            warn!(
                plugin = %self.name,
                event = ?event_type,
                conn_id = %conn_id,
                result = result,
                "Plugin returned error"
            );
        }

        // The queues are cleared at the start of every call, so move them out rather than clone.
        let state = store.data_mut();
        Ok(WsEventResult {
            sends: std::mem::take(&mut state.pending_sends),
            broadcasts: std::mem::take(&mut state.pending_broadcasts),
            closes: std::mem::take(&mut state.pending_closes),
        })
    }

    /// Name given to the most recent successful `load` (empty if none).
    pub fn name(&self) -> &str {
        &self.name
    }
}
impl Default for WsPluginRuntime {
    /// Equivalent to [`WsPluginRuntime::new`], panicking if the engine cannot be created.
    fn default() -> Self {
        match Self::new() {
            Ok(runtime) => runtime,
            Err(e) => panic!("Failed to create WebSocket plugin runtime: {e:?}"),
        }
    }
}