1use ::wasmer::Memory;
2use bytes::Bytes;
3use parking_lot::RwLock;
4use std::collections::HashMap;
5use std::sync::Arc;
6
/// Response-side state exposed to a plugin run.
///
/// Holds the response headers, the current status code, and a flag recording
/// whether that status was set explicitly rather than being the initial default.
#[derive(Clone, Debug)]
pub struct ResponseState {
    /// Response headers keyed by header name.
    headers: HashMap<String, String>,
    /// Current status code.
    status: u16,
    /// `true` once the status counts as an explicit override.
    status_overridden: bool,
}

impl ResponseState {
    /// Creates a state with no headers whose `status` is a default, i.e. not
    /// reported by [`ResponseState::status_override`].
    pub fn with_default_status(status: u16) -> Self {
        Self {
            headers: HashMap::new(),
            status,
            status_overridden: false,
        }
    }

    /// Assembles a state directly from its three components.
    pub fn from_parts(
        headers: HashMap<String, String>,
        status: u16,
        status_overridden: bool,
    ) -> Self {
        Self {
            headers,
            status,
            status_overridden,
        }
    }

    /// Read-only view of the response headers.
    pub fn headers(&self) -> &HashMap<String, String> {
        &self.headers
    }

    /// Mutable access to the response headers.
    pub fn headers_mut(&mut self) -> &mut HashMap<String, String> {
        &mut self.headers
    }

    /// Current status code, whether defaulted or overridden.
    pub fn status(&self) -> u16 {
        self.status
    }

    /// Stores `status` and marks it as an explicit override.
    pub fn set_status(&mut self, status: u16) {
        self.status_overridden = true;
        self.status = status;
    }

    /// Returns the status only if it was explicitly overridden, `None` otherwise.
    pub fn status_override(&self) -> Option<u16> {
        if self.status_overridden {
            Some(self.status)
        } else {
            None
        }
    }
}

impl Default for ResponseState {
    /// Empty headers, status `0`, not overridden.
    fn default() -> Self {
        Self::from_parts(HashMap::new(), 0, false)
    }
}
58
59mod host;
60pub mod instance;
61pub mod plugin;
62pub mod runner;
63pub mod utils;
64
65pub mod wasmer {
66 pub use wasmer::*;
67}
68
69#[derive(Clone, Debug)]
70struct RequestData {
71 headers: Arc<HashMap<String, String>>,
72 query: Arc<HashMap<String, Vec<String>>>,
73 body: Option<Bytes>,
74 persistent_vars: Arc<RwLock<HashMap<String, String>>>,
75}
76
77impl RequestData {
78 fn new(
79 headers: HashMap<String, String>,
80 query: HashMap<String, Vec<String>>,
81 body: Option<Bytes>,
82 persistent_vars: Arc<RwLock<HashMap<String, String>>>,
83 ) -> Self {
84 Self {
85 headers: Arc::new(normalize_headers(headers)),
86 query: Arc::new(normalize_query(query)),
87 body,
88 persistent_vars,
89 }
90 }
91}
92
93impl Default for RequestData {
94 fn default() -> Self {
95 Self {
96 headers: Arc::new(HashMap::new()),
97 query: Arc::new(HashMap::new()),
98 body: None,
99 persistent_vars: Arc::new(RwLock::new(HashMap::new())),
100 }
101 }
102}
103
104#[derive(Clone, Default, Debug)]
105pub struct ExecutionContext {
106 memory: Option<Memory>,
107 request: RequestData,
108 response: ResponseState,
109}
110
111impl ExecutionContext {
112 pub fn new() -> Self {
113 Self::default()
114 }
115
116 pub fn with_response(response: ResponseState) -> Self {
117 Self {
118 response,
119 ..Self::default()
120 }
121 }
122
123 pub fn from_parts(
124 req_headers: HashMap<String, String>,
125 query: HashMap<String, Vec<String>>,
126 body: Option<Bytes>,
127 response: ResponseState,
128 persistent_vars: Arc<RwLock<HashMap<String, String>>>,
129 ) -> Self {
130 Self {
131 memory: None,
132 request: RequestData::new(req_headers, query, body, persistent_vars),
133 response,
134 }
135 }
136
137 pub fn replace_memory(&mut self, memory: Memory) {
138 self.memory.replace(memory);
139 }
140
141 pub fn memory(&self) -> &Option<Memory> {
142 &self.memory
143 }
144
145 pub fn memory_mut(&mut self) -> &mut Option<Memory> {
146 &mut self.memory
147 }
148
149 pub fn req_headers(&self) -> &HashMap<String, String> {
150 &self.request.headers
151 }
152
153 pub fn req_headers_mut(&mut self) -> &mut HashMap<String, String> {
154 Arc::make_mut(&mut self.request.headers)
155 }
156
157 pub fn query(&self) -> &HashMap<String, Vec<String>> {
158 &self.request.query
159 }
160
161 pub fn query_mut(&mut self) -> &mut HashMap<String, Vec<String>> {
162 Arc::make_mut(&mut self.request.query)
163 }
164
165 pub fn body(&self) -> &Option<Bytes> {
166 &self.request.body
167 }
168
169 pub fn set_body(&mut self, body: Option<Bytes>) {
170 self.request.body = body;
171 }
172
173 pub fn response(&self) -> &ResponseState {
174 &self.response
175 }
176
177 pub fn response_mut(&mut self) -> &mut ResponseState {
178 &mut self.response
179 }
180
181 pub fn persistent_vars(&self) -> &Arc<RwLock<HashMap<String, String>>> {
182 &self.request.persistent_vars
183 }
184}
185
186pub type SharedExecutionContext = Arc<RwLock<ExecutionContext>>;
187
/// Returns a copy of `headers` with every key ASCII-lowercased; values are
/// moved over unchanged.
///
/// If two keys differ only by ASCII case, only one entry survives, and which
/// one depends on the input map's iteration order.
fn normalize_headers(headers: HashMap<String, String>) -> HashMap<String, String> {
    let mut lowered = HashMap::with_capacity(headers.len());
    for (name, value) in headers {
        lowered.insert(name.to_ascii_lowercase(), value);
    }
    lowered
}
194
/// ASCII-lowercases every query-parameter name, merging parameters whose names
/// differ only by case.
///
/// The input map may contain keys such as `Tag` and `tag` side by side
/// (presumably when the upstream parser keeps names case-sensitive — confirm
/// against the caller). Collecting the lowercased pairs straight into a new
/// map would keep just one of the colliding value lists, picked by hash
/// iteration order. Here the lists are concatenated instead, visiting the
/// original keys in byte-wise sorted order so the result is the same on every
/// run.
///
/// Keys that do not collide keep their value list unchanged, including empty
/// lists.
fn normalize_query(query: HashMap<String, Vec<String>>) -> HashMap<String, Vec<String>> {
    let mut entries: Vec<(String, Vec<String>)> = query.into_iter().collect();
    // Input keys are unique, so an unstable sort has no equal elements to reorder.
    entries.sort_unstable_by(|left, right| left.0.cmp(&right.0));

    let mut normalized: HashMap<String, Vec<String>> = HashMap::with_capacity(entries.len());
    for (name, values) in entries {
        normalized
            .entry(name.to_ascii_lowercase())
            .or_default()
            .extend(values);
    }
    normalized
}
201
202#[cfg(test)]
203mod tests {
204 use super::*;
205 use crate::plugin::WasmPlugin;
206 use crate::runner::WasmRunner;
207 use bytes::Bytes;
208 use serde_json::Value;
209 use std::collections::HashMap;
210 use std::fs;
211 use std::path::{Path, PathBuf};
212
213 const CASE_ROOT: &str = "../../../tests/wasm-plugins";
214
/// Runs the `allow` fixture as a response-phase scenario.
#[test]
fn wasm_plugin_allow_sets_headers() {
    let fixture = "allow";
    run_wasm_case(fixture, ScenarioKind::Response);
}
219
/// Runs the `block` fixture as a response-phase scenario.
#[test]
fn wasm_plugin_blocks_flagged_requests() {
    let fixture = "block";
    run_wasm_case(fixture, ScenarioKind::Response);
}
224
/// Runs the `require-tenant` fixture as a response-phase scenario.
#[test]
fn wasm_plugin_requires_tenant() {
    let fixture = "require-tenant";
    run_wasm_case(fixture, ScenarioKind::Response);
}
229
/// Runs the `inbound-allow` fixture as a request-phase scenario.
#[test]
fn wasm_inbound_plugin_allows_when_header_present() {
    let fixture = "inbound-allow";
    run_wasm_case(fixture, ScenarioKind::Request);
}
234
/// Runs the `inbound-block` fixture as a request-phase scenario.
#[test]
fn wasm_inbound_plugin_blocks_without_header() {
    let fixture = "inbound-block";
    run_wasm_case(fixture, ScenarioKind::Request);
}
239
240 fn run_wasm_case(name: &str, expected_type: ScenarioKind) {
241 let case_dir = case_path(name);
242 let wasm_path = case_dir.join("plugin.wasm");
243 let incoming_path = case_dir.join("incoming_request.json");
244 let expected_path = case_dir.join("expected_response.json");
245
246 let wasm_plugin = WasmPlugin::from_path(&wasm_path)
247 .unwrap_or_else(|e| panic!("failed to load plugin {:?}: {}", wasm_path, e));
248
249 let incoming = load_json(&incoming_path);
250 let expected = load_json(&expected_path);
251
252 let expected = expected_response_from_value(&expected, name);
253 if expected.execution_type != expected_type {
254 panic!(
255 "fixture {} declares execution_type {:?} but test expected {:?}",
256 name, expected.execution_type, expected_type
257 );
258 }
259
260 let exec_ctx = execution_context_from_value(&incoming, expected.execution_type, name);
261
262 let runner = WasmRunner::new(&wasm_plugin, None);
263 let shared_ctx = Arc::new(RwLock::new(exec_ctx));
264 let result = runner
265 .run(shared_ctx)
266 .unwrap_or_else(|e| panic!("plugin execution failed for {:?}: {}", wasm_path, e));
267
268 assert_eq!(
269 result.should_continue, expected.should_continue,
270 "decision mismatch for {}",
271 name
272 );
273 let context = result.execution_context;
274 match expected.execution_type {
275 ScenarioKind::Response => {
276 let response = context.response();
277 let expected_status = expected.status.unwrap_or_else(|| {
278 panic!(
279 "outbound fixture {} must define a status field in expected_response.json",
280 name
281 )
282 });
283 assert_eq!(
284 i32::from(response.status()),
285 expected_status,
286 "status mismatch for {}",
287 name
288 );
289
290 let actual_headers = lowercase_string_map(response.headers().clone());
291 for (key, value) in expected.resp_headers.iter() {
292 let actual = actual_headers
293 .get(key)
294 .unwrap_or_else(|| panic!("missing header `{}` for {}", key, name));
295 assert_eq!(actual, value, "header `{}` mismatch for {}", key, name);
296 }
297 }
298 ScenarioKind::Request => {
299 assert!(
300 expected.status.is_none(),
301 "inbound fixture {} should not define a status field",
302 name
303 );
304
305 assert!(
306 expected.resp_headers.is_empty(),
307 "inbound fixture {} should not define resp_headers",
308 name
309 );
310
311 }
313 }
314 }
315
316 fn case_path(name: &str) -> PathBuf {
317 Path::new(env!("CARGO_MANIFEST_DIR"))
318 .join(CASE_ROOT)
319 .join(name)
320 }
321
322 fn load_json(path: &Path) -> Value {
323 let data =
324 fs::read_to_string(path).unwrap_or_else(|e| panic!("failed to read {:?}: {}", path, e));
325 serde_json::from_str(&data).unwrap_or_else(|e| panic!("failed to parse {:?}: {}", path, e))
326 }
327
328 fn execution_context_from_value(
329 value: &Value,
330 scenario_kind: ScenarioKind,
331 _scenario: &str,
332 ) -> ExecutionContext {
333 let req_headers = lowercase_string_map(json_string_map(value.get("req_headers")));
334 let query = lowercase_string_vec_map(json_string_vec_map(value.get("query")));
335 let body = value.get("body").and_then(body_from_value);
336
337 let response_state = match scenario_kind {
338 ScenarioKind::Request => response_state_from_value(value, 403),
339 ScenarioKind::Response => response_state_from_value(value, 0),
340 };
341
342 ExecutionContext::from_parts(
343 req_headers,
344 query,
345 body,
346 response_state,
347 Arc::new(RwLock::new(HashMap::new())),
348 )
349 }
350
351 fn expected_response_from_value(value: &Value, scenario: &str) -> ExpectedResponse {
352 let should_continue = value
353 .get("should_continue")
354 .and_then(Value::as_bool)
355 .unwrap_or(false);
356 let execution_type = execution_type_from_value(value.get("execution_type"), scenario);
357 let status = value
358 .get("status")
359 .and_then(Value::as_i64)
360 .map(|s| s as i32);
361 let resp_headers = lowercase_string_map(json_string_map(value.get("resp_headers")));
362
363 ExpectedResponse {
364 should_continue,
365 status,
366 resp_headers,
367 execution_type,
368 }
369 }
370
371 fn json_string_map(value: Option<&Value>) -> HashMap<String, String> {
372 match value {
373 Some(Value::Object(map)) => map
374 .iter()
375 .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
376 .collect(),
377 _ => HashMap::new(),
378 }
379 }
380
381 fn json_string_vec_map(value: Option<&Value>) -> HashMap<String, Vec<String>> {
382 match value {
383 Some(Value::Object(map)) => map
384 .iter()
385 .map(|(k, v)| (k.clone(), value_to_string_vec(v)))
386 .collect(),
387 _ => HashMap::new(),
388 }
389 }
390
391 fn value_to_string_vec(value: &Value) -> Vec<String> {
392 match value {
393 Value::String(s) => vec![s.to_string()],
394 Value::Array(arr) => arr
395 .iter()
396 .filter_map(|v| v.as_str().map(|s| s.to_string()))
397 .collect(),
398 _ => Vec::new(),
399 }
400 }
401
/// Returns `map` with every key ASCII-lowercased; values are left untouched.
fn lowercase_string_map(map: HashMap<String, String>) -> HashMap<String, String> {
    let mut lowered = HashMap::new();
    for (key, value) in map {
        lowered.insert(key.to_ascii_lowercase(), value);
    }
    lowered
}
407
408 fn response_state_from_value(value: &Value, default_status: u16) -> ResponseState {
409 let headers = lowercase_string_map(json_string_map(value.get("resp_headers")));
410 let override_status = value.get("status").and_then(Value::as_i64).and_then(|raw| {
411 if raw > 0 {
412 u16::try_from(raw).ok()
413 } else {
414 None
415 }
416 });
417
418 match override_status {
419 Some(status) => ResponseState::from_parts(headers, status, true),
420 None => ResponseState::from_parts(headers, default_status, false),
421 }
422 }
423
/// Returns `map` with every key ASCII-lowercased; the value lists are left
/// untouched.
fn lowercase_string_vec_map(map: HashMap<String, Vec<String>>) -> HashMap<String, Vec<String>> {
    map.into_iter()
        .fold(HashMap::new(), |mut lowered, (key, values)| {
            lowered.insert(key.to_ascii_lowercase(), values);
            lowered
        })
}
429
430 fn body_from_value(value: &Value) -> Option<Bytes> {
431 match value {
432 Value::Null => None,
433 Value::String(s) => Some(Bytes::from(s.clone())),
434 _ => None,
435 }
436 }
437
438 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
439 enum ScenarioKind {
440 Request,
441 Response,
442 }
443
444 struct ExpectedResponse {
445 should_continue: bool,
446 status: Option<i32>,
447 resp_headers: HashMap<String, String>,
448 execution_type: ScenarioKind,
449 }
450
451 fn execution_type_from_value(value: Option<&Value>, scenario: &str) -> ScenarioKind {
452 let raw = value
453 .and_then(Value::as_str)
454 .unwrap_or_else(|| panic!("fixture {} must define execution_type", scenario));
455
456 match raw.to_ascii_lowercase().as_str() {
457 "inbound" => ScenarioKind::Request,
458 "outbound" => ScenarioKind::Response,
459 other => panic!(
460 "fixture {} has invalid execution_type '{}'; expected 'inbound' or 'outbound'",
461 scenario, other
462 ),
463 }
464 }
465}