use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
use wasmtime::*;
use sentinel_agent_protocol::{
AgentHandler, AgentResponse, AuditMetadata, ConfigureEvent, HeaderOp, RequestHeadersEvent,
ResponseHeadersEvent,
};
/// Verdict produced by a guest hook, deserialized from the JSON string the module
/// leaves in its linear memory.
///
/// Converted into an [`AgentResponse`] by `WasmAgent::build_response`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WasmResult {
/// Action to take, matched case-insensitively: `"block"`/`"deny"` block,
/// `"redirect"` blocks with a `Location` header, any other value allows.
pub decision: String,
/// Status code for block/redirect verdicts; defaults to 403 (block) or 302 (redirect).
pub status: Option<u16>,
/// Body of a block response, or the target URL of a redirect verdict.
pub body: Option<String>,
/// Request headers to set (name -> value).
pub add_request_headers: Option<HashMap<String, String>>,
/// Names of request headers to remove.
pub remove_request_headers: Option<Vec<String>>,
/// Response headers to set (name -> value).
pub add_response_headers: Option<HashMap<String, String>>,
/// Names of response headers to remove.
pub remove_response_headers: Option<Vec<String>>,
/// Tags recorded in the response's audit metadata.
pub tags: Option<Vec<String>>,
}
/// Request metadata serialized to JSON and passed to the guest's `on_request_headers` export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmRequest {
/// Request method, copied from the request-headers event.
pub method: String,
/// Request URI, copied from the request-headers event.
pub uri: String,
/// Client IP address, copied from the event metadata.
pub client_ip: String,
/// Correlation id of the request; the agent also uses it in its own log records.
pub correlation_id: String,
/// Request headers; multiple values for one name are joined with `", "`.
pub headers: HashMap<String, String>,
}
/// Response metadata serialized to JSON and passed to the guest's `on_response_headers` export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmResponse {
/// Upstream response status code.
pub status: u16,
/// Correlation id of the request this response belongs to.
pub correlation_id: String,
/// Response headers; multiple values for one name are joined with `", "`.
pub headers: HashMap<String, String>,
}
/// Shape of the JSON payload carried by a configure event.
///
/// Keys are kebab-case (`pool-size`, `fail-open`). `on_configure` validates and logs
/// these values; it does not apply them to the running agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct WasmConfigJson {
/// Requested maximum number of idle pooled instances.
pub pool_size: usize,
/// Requested error policy: allow (`true`) or block (`false`) when the guest fails.
pub fail_open: bool,
}
/// One instantiation of the module together with its own store and resolved exports.
///
/// Guest ABI, as used by `call_wasm_handler_impl`:
/// * `alloc(len) -> ptr` reserves `len` bytes in guest memory for the input JSON;
/// * `dealloc(ptr, len)` frees a buffer previously handed across the boundary;
/// * each handler takes `(ptr, len)` of the input JSON and returns an `i64` whose upper
///   32 bits are the output pointer and lower 32 bits the output length
///   (`0` in either half means "allow").
struct WasmInstance {
/// Store owning this instance's runtime state.
store: Store<()>,
/// The exported `memory` through which JSON is exchanged.
memory: Memory,
/// Exported `alloc(i32) -> i32`.
alloc: TypedFunc<i32, i32>,
/// Exported `dealloc(i32, i32)`.
dealloc: TypedFunc<(i32, i32), ()>,
/// Optional `on_request_headers(ptr, len) -> packed` export.
on_request_headers: Option<TypedFunc<(i32, i32), i64>>,
/// Optional `on_response_headers(ptr, len) -> packed` export.
on_response_headers: Option<TypedFunc<(i32, i32), i64>>,
}
/// Agent that delegates request/response header decisions to a Wasm module.
///
/// The module is compiled once; instances are created on demand and recycled
/// through a bounded pool.
pub struct WasmAgent {
/// Engine the module was compiled with; every store is created from it.
engine: Engine,
/// The compiled guest module.
module: Module,
/// Idle instances available for reuse.
instance_pool: Arc<Mutex<Vec<WasmInstance>>>,
/// Maximum number of idle instances retained; surplus instances are dropped on release.
pool_size: usize,
/// When true, guest failures allow the request; otherwise they block it with a 500.
fail_open: bool,
}
unsafe impl Send for WasmAgent {}
unsafe impl Sync for WasmAgent {}
impl WasmAgent {
pub fn new<P: AsRef<Path>>(module_path: P, pool_size: usize, fail_open: bool) -> Result<Self> {
let module_bytes = std::fs::read(module_path.as_ref())
.with_context(|| format!("Failed to read Wasm module: {:?}", module_path.as_ref()))?;
Self::from_bytes(&module_bytes, pool_size, fail_open)
}
pub fn from_bytes(module_bytes: &[u8], pool_size: usize, fail_open: bool) -> Result<Self> {
let mut config = Config::new();
config.wasm_multi_memory(true);
config.wasm_bulk_memory(true);
let engine = Engine::new(&config).context("Failed to create Wasm engine")?;
let module = Module::new(&engine, module_bytes).context("Failed to compile Wasm module")?;
info!("Wasm module compiled successfully");
let agent = Self {
engine,
module,
instance_pool: Arc::new(Mutex::new(Vec::with_capacity(pool_size))),
pool_size,
fail_open,
};
Ok(agent)
}
fn create_instance(&self) -> Result<WasmInstance> {
let mut store = Store::new(&self.engine, ());
let instance = Instance::new(&mut store, &self.module, &[])
.context("Failed to instantiate Wasm module")?;
let memory = instance
.get_memory(&mut store, "memory")
.context("Wasm module must export 'memory'")?;
let alloc = instance
.get_typed_func::<i32, i32>(&mut store, "alloc")
.context("Wasm module must export 'alloc(i32) -> i32'")?;
let dealloc = instance
.get_typed_func::<(i32, i32), ()>(&mut store, "dealloc")
.context("Wasm module must export 'dealloc(i32, i32)'")?;
let on_request_headers = instance
.get_typed_func::<(i32, i32), i64>(&mut store, "on_request_headers")
.ok();
let on_response_headers = instance
.get_typed_func::<(i32, i32), i64>(&mut store, "on_response_headers")
.ok();
if on_request_headers.is_none() && on_response_headers.is_none() {
anyhow::bail!(
"Wasm module must export at least one of: on_request_headers, on_response_headers"
);
}
debug!("Created new Wasm instance");
Ok(WasmInstance {
store,
memory,
alloc,
dealloc,
on_request_headers,
on_response_headers,
})
}
async fn acquire_instance(&self) -> Result<WasmInstance> {
let mut pool = self.instance_pool.lock().await;
if let Some(instance) = pool.pop() {
Ok(instance)
} else {
drop(pool); self.create_instance()
}
}
async fn release_instance(&self, instance: WasmInstance) {
let mut pool = self.instance_pool.lock().await;
if pool.len() < self.pool_size {
pool.push(instance);
}
}
fn has_request_handler(instance: &WasmInstance) -> bool {
instance.on_request_headers.is_some()
}
fn has_response_handler(instance: &WasmInstance) -> bool {
instance.on_response_headers.is_some()
}
fn call_request_handler(instance: &mut WasmInstance, input_json: &str) -> Result<String> {
let handler = instance
.on_request_headers
.clone()
.expect("on_request_headers should exist");
Self::call_wasm_handler_impl(instance, handler, input_json)
}
fn call_response_handler(instance: &mut WasmInstance, input_json: &str) -> Result<String> {
let handler = instance
.on_response_headers
.clone()
.expect("on_response_headers should exist");
Self::call_wasm_handler_impl(instance, handler, input_json)
}
fn call_wasm_handler_impl(
instance: &mut WasmInstance,
handler: TypedFunc<(i32, i32), i64>,
input_json: &str,
) -> Result<String> {
let input_bytes = input_json.as_bytes();
let input_len = input_bytes.len() as i32;
let input_ptr = instance
.alloc
.call(&mut instance.store, input_len)
.context("Failed to allocate input memory")?;
instance
.memory
.write(&mut instance.store, input_ptr as usize, input_bytes)
.context("Failed to write input to Wasm memory")?;
let result = handler
.call(&mut instance.store, (input_ptr, input_len))
.context("Wasm handler call failed")?;
instance
.dealloc
.call(&mut instance.store, (input_ptr, input_len))
.ok();
let result_ptr = (result >> 32) as i32;
let result_len = (result & 0xFFFFFFFF) as i32;
if result_ptr == 0 || result_len == 0 {
return Ok(r#"{"decision":"allow"}"#.to_string());
}
let mut result_bytes = vec![0u8; result_len as usize];
instance
.memory
.read(&instance.store, result_ptr as usize, &mut result_bytes)
.context("Failed to read result from Wasm memory")?;
instance
.dealloc
.call(&mut instance.store, (result_ptr, result_len))
.ok();
String::from_utf8(result_bytes).context("Wasm result is not valid UTF-8")
}
pub fn build_response(result: WasmResult) -> AgentResponse {
let decision = result.decision.to_lowercase();
let mut response = match decision.as_str() {
"block" | "deny" => {
let status = result.status.unwrap_or(403);
AgentResponse::block(status, result.body)
}
"redirect" => {
let status = result.status.unwrap_or(302);
let mut resp = AgentResponse::block(status, None);
if let Some(url) = result.body {
resp = resp.add_response_header(HeaderOp::Set {
name: "Location".to_string(),
value: url,
});
}
resp
}
_ => AgentResponse::default_allow(),
};
if let Some(headers) = result.add_request_headers {
for (name, value) in headers {
response = response.add_request_header(HeaderOp::Set { name, value });
}
}
if let Some(headers) = result.remove_request_headers {
for name in headers {
response = response.add_request_header(HeaderOp::Remove { name });
}
}
if let Some(headers) = result.add_response_headers {
for (name, value) in headers {
response = response.add_response_header(HeaderOp::Set { name, value });
}
}
if let Some(headers) = result.remove_response_headers {
for name in headers {
response = response.add_response_header(HeaderOp::Remove { name });
}
}
if let Some(tags) = result.tags {
response = response.with_audit(AuditMetadata {
tags,
..Default::default()
});
}
response
}
fn handle_error(&self, error: anyhow::Error, correlation_id: &str) -> AgentResponse {
error!(
correlation_id = correlation_id,
error = %error,
"Wasm execution failed"
);
if self.fail_open {
AgentResponse::default_allow().with_audit(AuditMetadata {
tags: vec!["wasm-error".to_string(), "fail-open".to_string()],
reason_codes: vec![error.to_string()],
..Default::default()
})
} else {
AgentResponse::block(500, Some("Wasm Error".to_string())).with_audit(AuditMetadata {
tags: vec!["wasm-error".to_string()],
reason_codes: vec![error.to_string()],
..Default::default()
})
}
}
}
#[async_trait]
impl AgentHandler for WasmAgent {
async fn on_configure(&self, event: ConfigureEvent) -> AgentResponse {
info!(
agent_id = %event.agent_id,
"Received configuration event"
);
let config: WasmConfigJson = match serde_json::from_value(event.config) {
Ok(c) => c,
Err(e) => {
error!(error = %e, "Failed to parse Wasm agent configuration");
return AgentResponse::block(
500,
Some(format!("Invalid Wasm agent configuration: {}", e)),
);
}
};
info!(
pool_size = config.pool_size,
fail_open = config.fail_open,
"Wasm agent configuration received (note: module cannot be changed dynamically)"
);
AgentResponse::default_allow()
}
async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
let correlation_id = event.metadata.correlation_id.clone();
let mut instance = match self.acquire_instance().await {
Ok(inst) => inst,
Err(e) => return self.handle_error(e, &correlation_id),
};
let mut headers: HashMap<String, String> = HashMap::new();
for (name, values) in &event.headers {
headers.insert(name.clone(), values.join(", "));
}
let request = WasmRequest {
method: event.method.clone(),
uri: event.uri.clone(),
client_ip: event.metadata.client_ip.clone(),
correlation_id: correlation_id.clone(),
headers,
};
let input_json = match serde_json::to_string(&request) {
Ok(j) => j,
Err(e) => {
self.release_instance(instance).await;
return self.handle_error(e.into(), &correlation_id);
}
};
if !Self::has_request_handler(&instance) {
self.release_instance(instance).await;
return AgentResponse::default_allow();
}
let result = Self::call_request_handler(&mut instance, &input_json);
self.release_instance(instance).await;
match result {
Ok(output_json) => {
debug!(
correlation_id = correlation_id,
output = %output_json,
"Wasm handler returned"
);
match serde_json::from_str::<WasmResult>(&output_json) {
Ok(wasm_result) => Self::build_response(wasm_result),
Err(e) => {
warn!(
correlation_id = correlation_id,
error = %e,
output = %output_json,
"Failed to parse Wasm result"
);
self.handle_error(e.into(), &correlation_id)
}
}
}
Err(e) => self.handle_error(e, &correlation_id),
}
}
async fn on_response_headers(&self, event: ResponseHeadersEvent) -> AgentResponse {
let correlation_id = event.correlation_id.clone();
let mut instance = match self.acquire_instance().await {
Ok(inst) => inst,
Err(e) => return self.handle_error(e, &correlation_id),
};
let mut headers: HashMap<String, String> = HashMap::new();
for (name, values) in &event.headers {
headers.insert(name.clone(), values.join(", "));
}
let response = WasmResponse {
status: event.status,
correlation_id: correlation_id.clone(),
headers,
};
let input_json = match serde_json::to_string(&response) {
Ok(j) => j,
Err(e) => {
self.release_instance(instance).await;
return self.handle_error(e.into(), &correlation_id);
}
};
if !Self::has_response_handler(&instance) {
self.release_instance(instance).await;
return AgentResponse::default_allow();
}
let result = Self::call_response_handler(&mut instance, &input_json);
self.release_instance(instance).await;
match result {
Ok(output_json) => {
debug!(
correlation_id = correlation_id,
output = %output_json,
"Wasm handler returned"
);
match serde_json::from_str::<WasmResult>(&output_json) {
Ok(wasm_result) => Self::build_response(wasm_result),
Err(e) => {
warn!(
correlation_id = correlation_id,
error = %e,
output = %output_json,
"Failed to parse Wasm result"
);
self.handle_error(e.into(), &correlation_id)
}
}
}
Err(e) => self.handle_error(e, &correlation_id),
}
}
}