1use anyhow::{Context, Result};
24use async_trait::async_trait;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::Path;
28use std::sync::Arc;
29use tokio::sync::Mutex;
30use tracing::{debug, error, info, warn};
31use wasmtime::*;
32
33use sentinel_agent_protocol::{
34 AgentHandler, AgentResponse, AuditMetadata, ConfigureEvent, HeaderOp, RequestHeadersEvent,
35 ResponseHeadersEvent,
36};
37
/// Decision returned by a Wasm handler, deserialized from the JSON it writes to guest memory.
///
/// `build_response` translates this structure into an `AgentResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WasmResult {
    /// Decision keyword, compared case-insensitively by `build_response`:
    /// `"block"`/`"deny"` block the request, `"redirect"` issues a redirect, and any
    /// other value is treated as allow.
    pub decision: String,
    /// HTTP status override; `build_response` falls back to 403 for block/deny and
    /// 302 for redirect when this is `None`.
    pub status: Option<u16>,
    /// Response body for block/deny decisions; for `"redirect"` it is used as the
    /// `Location` header value.
    pub body: Option<String>,
    /// Request headers to set, as name -> value.
    pub add_request_headers: Option<HashMap<String, String>>,
    /// Names of request headers to remove.
    pub remove_request_headers: Option<Vec<String>>,
    /// Response headers to set, as name -> value.
    pub add_response_headers: Option<HashMap<String, String>>,
    /// Names of response headers to remove.
    pub remove_response_headers: Option<Vec<String>>,
    /// Tags copied into the response's audit metadata.
    pub tags: Option<Vec<String>>,
}
58
/// Request description serialized to JSON and handed to the module's
/// `on_request_headers` export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmRequest {
    /// Request method, copied from the incoming event.
    pub method: String,
    /// Request URI, copied from the incoming event.
    pub uri: String,
    /// Client IP, copied from the event metadata.
    pub client_ip: String,
    /// Correlation id, copied from the event metadata.
    pub correlation_id: String,
    /// Request headers; the agent joins multiple values of one header with `", "`.
    pub headers: HashMap<String, String>,
}
68
/// Response description serialized to JSON and handed to the module's
/// `on_response_headers` export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmResponse {
    /// Upstream response status, copied from the incoming event.
    pub status: u16,
    /// Correlation id, copied from the incoming event.
    pub correlation_id: String,
    /// Response headers; the agent joins multiple values of one header with `", "`.
    pub headers: HashMap<String, String>,
}
76
/// Shape of the JSON configuration parsed by `on_configure`.
///
/// Field names are kebab-case on the wire (`pool-size`, `fail-open`).
/// NOTE(review): `on_configure` currently only logs these values; it does not apply
/// them to the running agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct WasmConfigJson {
    /// Requested instance pool size.
    pub pool_size: usize,
    /// Requested fail-open flag.
    pub fail_open: bool,
}
89
/// One instantiated copy of the Wasm module together with the exports the agent uses.
struct WasmInstance {
    /// Store that owns this instance's state.
    store: Store<()>,
    /// The module's exported `memory`, used to exchange JSON with the guest.
    memory: Memory,
    /// Guest export `alloc(len) -> ptr`.
    alloc: TypedFunc<i32, i32>,
    /// Guest export `dealloc(ptr, len)`.
    dealloc: TypedFunc<(i32, i32), ()>,
    /// Optional guest export `on_request_headers(ptr, len) -> i64`.
    on_request_headers: Option<TypedFunc<(i32, i32), i64>>,
    /// Optional guest export `on_response_headers(ptr, len) -> i64`.
    on_response_headers: Option<TypedFunc<(i32, i32), i64>>,
}
99
/// Agent that delegates header decisions to a compiled Wasm module.
pub struct WasmAgent {
    /// Engine used to create a store for each new instance.
    engine: Engine,
    /// Compiled module that every instance is created from.
    module: Module,
    /// Idle instances available for reuse.
    instance_pool: Arc<Mutex<Vec<WasmInstance>>>,
    /// Maximum number of idle instances kept in `instance_pool`; extra instances are
    /// dropped on release.
    pool_size: usize,
    /// When `true`, Wasm failures result in an allow response; otherwise in a 500 block.
    fail_open: bool,
}
108
// NOTE(review): these manual impls have no stated safety justification. If every field
// of `WasmAgent` (wasmtime `Engine`/`Module`, the `Arc<Mutex<..>>` pool, plain values)
// is already `Send + Sync` for the wasmtime version in use, the auto traits apply and
// these `unsafe impl`s can be deleted; otherwise the invariant must be documented here.
// TODO confirm.
unsafe impl Send for WasmAgent {}
unsafe impl Sync for WasmAgent {}
112
113impl WasmAgent {
114 pub fn new<P: AsRef<Path>>(module_path: P, pool_size: usize, fail_open: bool) -> Result<Self> {
116 let module_bytes = std::fs::read(module_path.as_ref())
117 .with_context(|| format!("Failed to read Wasm module: {:?}", module_path.as_ref()))?;
118
119 Self::from_bytes(&module_bytes, pool_size, fail_open)
120 }
121
122 pub fn from_bytes(module_bytes: &[u8], pool_size: usize, fail_open: bool) -> Result<Self> {
124 let mut config = Config::new();
125 config.wasm_multi_memory(true);
126 config.wasm_bulk_memory(true);
127
128 let engine = Engine::new(&config).context("Failed to create Wasm engine")?;
129 let module = Module::new(&engine, module_bytes).context("Failed to compile Wasm module")?;
130
131 info!("Wasm module compiled successfully");
132
133 let agent = Self {
134 engine,
135 module,
136 instance_pool: Arc::new(Mutex::new(Vec::with_capacity(pool_size))),
137 pool_size,
138 fail_open,
139 };
140
141 Ok(agent)
142 }
143
144 fn create_instance(&self) -> Result<WasmInstance> {
146 let mut store = Store::new(&self.engine, ());
147 let instance = Instance::new(&mut store, &self.module, &[])
148 .context("Failed to instantiate Wasm module")?;
149
150 let memory = instance
152 .get_memory(&mut store, "memory")
153 .context("Wasm module must export 'memory'")?;
154
155 let alloc = instance
157 .get_typed_func::<i32, i32>(&mut store, "alloc")
158 .context("Wasm module must export 'alloc(i32) -> i32'")?;
159
160 let dealloc = instance
161 .get_typed_func::<(i32, i32), ()>(&mut store, "dealloc")
162 .context("Wasm module must export 'dealloc(i32, i32)'")?;
163
164 let on_request_headers = instance
166 .get_typed_func::<(i32, i32), i64>(&mut store, "on_request_headers")
167 .ok();
168
169 let on_response_headers = instance
170 .get_typed_func::<(i32, i32), i64>(&mut store, "on_response_headers")
171 .ok();
172
173 if on_request_headers.is_none() && on_response_headers.is_none() {
174 anyhow::bail!(
175 "Wasm module must export at least one of: on_request_headers, on_response_headers"
176 );
177 }
178
179 debug!("Created new Wasm instance");
180
181 Ok(WasmInstance {
182 store,
183 memory,
184 alloc,
185 dealloc,
186 on_request_headers,
187 on_response_headers,
188 })
189 }
190
191 async fn acquire_instance(&self) -> Result<WasmInstance> {
193 let mut pool = self.instance_pool.lock().await;
194 if let Some(instance) = pool.pop() {
195 Ok(instance)
196 } else {
197 drop(pool); self.create_instance()
199 }
200 }
201
202 async fn release_instance(&self, instance: WasmInstance) {
204 let mut pool = self.instance_pool.lock().await;
205 if pool.len() < self.pool_size {
206 pool.push(instance);
207 }
208 }
210
211 fn has_request_handler(instance: &WasmInstance) -> bool {
213 instance.on_request_headers.is_some()
214 }
215
216 fn has_response_handler(instance: &WasmInstance) -> bool {
218 instance.on_response_headers.is_some()
219 }
220
221 fn call_request_handler(instance: &mut WasmInstance, input_json: &str) -> Result<String> {
223 let handler = instance
224 .on_request_headers
225 .clone()
226 .expect("on_request_headers should exist");
227 Self::call_wasm_handler_impl(instance, handler, input_json)
228 }
229
230 fn call_response_handler(instance: &mut WasmInstance, input_json: &str) -> Result<String> {
232 let handler = instance
233 .on_response_headers
234 .clone()
235 .expect("on_response_headers should exist");
236 Self::call_wasm_handler_impl(instance, handler, input_json)
237 }
238
239 fn call_wasm_handler_impl(
241 instance: &mut WasmInstance,
242 handler: TypedFunc<(i32, i32), i64>,
243 input_json: &str,
244 ) -> Result<String> {
245 let input_bytes = input_json.as_bytes();
246 let input_len = input_bytes.len() as i32;
247
248 let input_ptr = instance
250 .alloc
251 .call(&mut instance.store, input_len)
252 .context("Failed to allocate input memory")?;
253
254 instance
256 .memory
257 .write(&mut instance.store, input_ptr as usize, input_bytes)
258 .context("Failed to write input to Wasm memory")?;
259
260 let result = handler
262 .call(&mut instance.store, (input_ptr, input_len))
263 .context("Wasm handler call failed")?;
264
265 instance
267 .dealloc
268 .call(&mut instance.store, (input_ptr, input_len))
269 .ok(); let result_ptr = (result >> 32) as i32;
273 let result_len = (result & 0xFFFFFFFF) as i32;
274
275 if result_ptr == 0 || result_len == 0 {
276 return Ok(r#"{"decision":"allow"}"#.to_string());
277 }
278
279 let mut result_bytes = vec![0u8; result_len as usize];
281 instance
282 .memory
283 .read(&instance.store, result_ptr as usize, &mut result_bytes)
284 .context("Failed to read result from Wasm memory")?;
285
286 instance
288 .dealloc
289 .call(&mut instance.store, (result_ptr, result_len))
290 .ok(); String::from_utf8(result_bytes).context("Wasm result is not valid UTF-8")
293 }
294
295 pub fn build_response(result: WasmResult) -> AgentResponse {
297 let decision = result.decision.to_lowercase();
298
299 let mut response = match decision.as_str() {
300 "block" | "deny" => {
301 let status = result.status.unwrap_or(403);
302 AgentResponse::block(status, result.body)
303 }
304 "redirect" => {
305 let status = result.status.unwrap_or(302);
306 let mut resp = AgentResponse::block(status, None);
307 if let Some(url) = result.body {
308 resp = resp.add_response_header(HeaderOp::Set {
309 name: "Location".to_string(),
310 value: url,
311 });
312 }
313 resp
314 }
315 _ => AgentResponse::default_allow(),
316 };
317
318 if let Some(headers) = result.add_request_headers {
320 for (name, value) in headers {
321 response = response.add_request_header(HeaderOp::Set { name, value });
322 }
323 }
324
325 if let Some(headers) = result.remove_request_headers {
327 for name in headers {
328 response = response.add_request_header(HeaderOp::Remove { name });
329 }
330 }
331
332 if let Some(headers) = result.add_response_headers {
334 for (name, value) in headers {
335 response = response.add_response_header(HeaderOp::Set { name, value });
336 }
337 }
338
339 if let Some(headers) = result.remove_response_headers {
341 for name in headers {
342 response = response.add_response_header(HeaderOp::Remove { name });
343 }
344 }
345
346 if let Some(tags) = result.tags {
348 response = response.with_audit(AuditMetadata {
349 tags,
350 ..Default::default()
351 });
352 }
353
354 response
355 }
356
357 fn handle_error(&self, error: anyhow::Error, correlation_id: &str) -> AgentResponse {
359 error!(
360 correlation_id = correlation_id,
361 error = %error,
362 "Wasm execution failed"
363 );
364
365 if self.fail_open {
366 AgentResponse::default_allow().with_audit(AuditMetadata {
367 tags: vec!["wasm-error".to_string(), "fail-open".to_string()],
368 reason_codes: vec![error.to_string()],
369 ..Default::default()
370 })
371 } else {
372 AgentResponse::block(500, Some("Wasm Error".to_string())).with_audit(AuditMetadata {
373 tags: vec!["wasm-error".to_string()],
374 reason_codes: vec![error.to_string()],
375 ..Default::default()
376 })
377 }
378 }
379}
380
381#[async_trait]
382impl AgentHandler for WasmAgent {
383 async fn on_configure(&self, event: ConfigureEvent) -> AgentResponse {
384 info!(
385 agent_id = %event.agent_id,
386 "Received configuration event"
387 );
388
389 let config: WasmConfigJson = match serde_json::from_value(event.config) {
391 Ok(c) => c,
392 Err(e) => {
393 error!(error = %e, "Failed to parse Wasm agent configuration");
394 return AgentResponse::block(
395 500,
396 Some(format!("Invalid Wasm agent configuration: {}", e)),
397 );
398 }
399 };
400
401 info!(
406 pool_size = config.pool_size,
407 fail_open = config.fail_open,
408 "Wasm agent configuration received (note: module cannot be changed dynamically)"
409 );
410
411 AgentResponse::default_allow()
412 }
413
414 async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
415 let correlation_id = event.metadata.correlation_id.clone();
416
417 let mut instance = match self.acquire_instance().await {
419 Ok(inst) => inst,
420 Err(e) => return self.handle_error(e, &correlation_id),
421 };
422
423 let mut headers: HashMap<String, String> = HashMap::new();
425 for (name, values) in &event.headers {
426 headers.insert(name.clone(), values.join(", "));
427 }
428
429 let request = WasmRequest {
430 method: event.method.clone(),
431 uri: event.uri.clone(),
432 client_ip: event.metadata.client_ip.clone(),
433 correlation_id: correlation_id.clone(),
434 headers,
435 };
436
437 let input_json = match serde_json::to_string(&request) {
438 Ok(j) => j,
439 Err(e) => {
440 self.release_instance(instance).await;
441 return self.handle_error(e.into(), &correlation_id);
442 }
443 };
444
445 if !Self::has_request_handler(&instance) {
447 self.release_instance(instance).await;
448 return AgentResponse::default_allow();
449 }
450
451 let result = Self::call_request_handler(&mut instance, &input_json);
453
454 self.release_instance(instance).await;
456
457 match result {
458 Ok(output_json) => {
459 debug!(
460 correlation_id = correlation_id,
461 output = %output_json,
462 "Wasm handler returned"
463 );
464
465 match serde_json::from_str::<WasmResult>(&output_json) {
466 Ok(wasm_result) => Self::build_response(wasm_result),
467 Err(e) => {
468 warn!(
469 correlation_id = correlation_id,
470 error = %e,
471 output = %output_json,
472 "Failed to parse Wasm result"
473 );
474 self.handle_error(e.into(), &correlation_id)
475 }
476 }
477 }
478 Err(e) => self.handle_error(e, &correlation_id),
479 }
480 }
481
482 async fn on_response_headers(&self, event: ResponseHeadersEvent) -> AgentResponse {
483 let correlation_id = event.correlation_id.clone();
484
485 let mut instance = match self.acquire_instance().await {
487 Ok(inst) => inst,
488 Err(e) => return self.handle_error(e, &correlation_id),
489 };
490
491 let mut headers: HashMap<String, String> = HashMap::new();
493 for (name, values) in &event.headers {
494 headers.insert(name.clone(), values.join(", "));
495 }
496
497 let response = WasmResponse {
498 status: event.status,
499 correlation_id: correlation_id.clone(),
500 headers,
501 };
502
503 let input_json = match serde_json::to_string(&response) {
504 Ok(j) => j,
505 Err(e) => {
506 self.release_instance(instance).await;
507 return self.handle_error(e.into(), &correlation_id);
508 }
509 };
510
511 if !Self::has_response_handler(&instance) {
513 self.release_instance(instance).await;
514 return AgentResponse::default_allow();
515 }
516
517 let result = Self::call_response_handler(&mut instance, &input_json);
519
520 self.release_instance(instance).await;
522
523 match result {
524 Ok(output_json) => {
525 debug!(
526 correlation_id = correlation_id,
527 output = %output_json,
528 "Wasm handler returned"
529 );
530
531 match serde_json::from_str::<WasmResult>(&output_json) {
532 Ok(wasm_result) => Self::build_response(wasm_result),
533 Err(e) => {
534 warn!(
535 correlation_id = correlation_id,
536 error = %e,
537 output = %output_json,
538 "Failed to parse Wasm result"
539 );
540 self.handle_error(e.into(), &correlation_id)
541 }
542 }
543 }
544 Err(e) => self.handle_error(e, &correlation_id),
545 }
546 }
547}