1use bytes::Bytes;
2use derive_builder::Builder;
3use enum_as_inner::EnumAsInner;
4use std::collections::HashMap;
5use wasmer::Memory;
6
7mod host;
8pub mod instance;
9pub mod plugin;
10pub mod runner;
11mod utils;
12
13#[derive(Clone, Builder)]
14pub struct ExecutionRequest {
15 pub memory: Option<Memory>,
16 pub req_headers: HashMap<String, String>,
17 pub query: HashMap<String, Vec<String>>,
18 pub body: Option<Bytes>,
19}
20
21#[derive(Clone, Builder)]
22pub struct ExecutionResponse {
23 pub memory: Option<Memory>,
24 pub req_headers: HashMap<String, String>,
25 pub query: HashMap<String, Vec<String>>,
26 pub body: Option<Bytes>,
27
28 pub resp_headers: HashMap<String, String>,
30 pub status: i32,
31}
32
33#[derive(Clone, EnumAsInner)]
36pub enum ExecutionContext {
37 Inbound(ExecutionRequest),
38 Outbound(ExecutionResponse),
39}
40
41impl ExecutionContext {
42 pub fn replace_memory(&mut self, memory: Memory) {
43 match self {
44 ExecutionContext::Inbound(ctx) => {
45 ctx.memory.replace(memory);
46 }
47 ExecutionContext::Outbound(ctx) => {
48 ctx.memory.replace(memory);
49 }
50 }
51 }
52
53 pub fn body(&self) -> &Option<Bytes> {
54 match self {
55 ExecutionContext::Inbound(inbound) => &inbound.body,
56 ExecutionContext::Outbound(outbound) => &outbound.body,
57 }
58 }
59
60 pub fn req_headers(&self) -> &HashMap<String, String> {
61 match self {
62 ExecutionContext::Inbound(inbound) => &inbound.req_headers,
63 ExecutionContext::Outbound(outbound) => &outbound.req_headers,
64 }
65 }
66
67 pub fn query(&self) -> &HashMap<String, Vec<String>> {
68 match self {
69 ExecutionContext::Inbound(inbound) => &inbound.query,
70 ExecutionContext::Outbound(outbound) => &outbound.query,
71 }
72 }
73
74 pub fn memory(&self) -> &Option<Memory> {
75 match self {
76 ExecutionContext::Inbound(inbound) => &inbound.memory,
77 ExecutionContext::Outbound(outbound) => &outbound.memory,
78 }
79 }
80}
81
82#[cfg(test)]
83mod tests {
84 use super::*;
85 use crate::plugin::WasmPlugin;
86 use crate::runner::{ExecutionType, WasmRunner};
87 use bytes::Bytes;
88 use serde_json::Value;
89 use std::collections::HashMap;
90 use std::fs;
91 use std::path::{Path, PathBuf};
92
93 const CASE_ROOT: &str = "../../../tests/wasm-plugins";
94
/// Runs the `allow` fixture and expects it to declare an outbound execution.
#[test]
fn wasm_plugin_allow_sets_headers() {
    let scenario = "allow";
    run_wasm_case(scenario, ExecutionType::Outbound);
}
99
/// Runs the `block` fixture and expects it to declare an outbound execution.
#[test]
fn wasm_plugin_blocks_flagged_requests() {
    let scenario = "block";
    run_wasm_case(scenario, ExecutionType::Outbound);
}
104
/// Runs the `require-tenant` fixture and expects it to declare an outbound execution.
#[test]
fn wasm_plugin_requires_tenant() {
    let scenario = "require-tenant";
    run_wasm_case(scenario, ExecutionType::Outbound);
}
109
/// Runs the `inbound-allow` fixture and expects it to declare an inbound execution.
#[test]
fn wasm_inbound_plugin_allows_when_header_present() {
    let scenario = "inbound-allow";
    run_wasm_case(scenario, ExecutionType::Inbound);
}
114
/// Runs the `inbound-block` fixture and expects it to declare an inbound execution.
#[test]
fn wasm_inbound_plugin_blocks_without_header() {
    let scenario = "inbound-block";
    run_wasm_case(scenario, ExecutionType::Inbound);
}
119
120 fn run_wasm_case(name: &str, expected_type: ExecutionType) {
121 let case_dir = case_path(name);
122 let wasm_path = case_dir.join("plugin.wasm");
123 let incoming_path = case_dir.join("incoming_request.json");
124 let expected_path = case_dir.join("expected_response.json");
125
126 let wasm_plugin = WasmPlugin::from_path(&wasm_path)
127 .unwrap_or_else(|e| panic!("failed to load plugin {:?}: {}", wasm_path, e));
128
129 let incoming = load_json(&incoming_path);
130 let expected = load_json(&expected_path);
131
132 let expected = expected_response_from_value(&expected, name);
133 if !matches!(
134 (expected.execution_type, expected_type),
135 (ExecutionType::Inbound, ExecutionType::Inbound)
136 | (ExecutionType::Outbound, ExecutionType::Outbound)
137 ) {
138 panic!(
139 "fixture {} declares execution_type {:?} but test expected {:?}",
140 name, expected.execution_type, expected_type
141 );
142 }
143
144 let exec_ctx = execution_context_from_value(&incoming, expected.execution_type, name);
145
146 let runner = WasmRunner::new(&wasm_plugin, expected.execution_type);
147 let result = runner
148 .run(exec_ctx)
149 .unwrap_or_else(|e| panic!("plugin execution failed for {:?}: {}", wasm_path, e));
150
151 assert_eq!(
152 result.should_continue, expected.should_continue,
153 "decision mismatch for {}",
154 name
155 );
156 match expected.execution_type {
157 ExecutionType::Outbound => {
158 let outbound = result
159 .execution_context
160 .into_outbound()
161 .unwrap_or_else(|_| panic!("expected outbound context for {}", name));
162
163 let expected_status = expected.status.unwrap_or_else(|| {
164 panic!(
165 "outbound fixture {} must define a status field in expected_response.json",
166 name
167 )
168 });
169 assert_eq!(
170 outbound.status, expected_status,
171 "status mismatch for {}",
172 name
173 );
174
175 let actual_headers = lowercase_string_map(outbound.resp_headers.clone());
176 for (key, value) in expected.resp_headers.iter() {
177 let actual = actual_headers
178 .get(key)
179 .unwrap_or_else(|| panic!("missing header `{}` for {}", key, name));
180 assert_eq!(actual, value, "header `{}` mismatch for {}", key, name);
181 }
182 }
183 ExecutionType::Inbound => {
184 let inbound = result
185 .execution_context
186 .into_inbound()
187 .unwrap_or_else(|_| panic!("expected inbound context for {}", name));
188
189 assert!(
190 expected.status.is_none(),
191 "inbound fixture {} should not define a status field",
192 name
193 );
194
195 assert!(
196 expected.resp_headers.is_empty(),
197 "inbound fixture {} should not define resp_headers",
198 name
199 );
200
201 let _inbound = inbound;
203 }
204 }
205 }
206
207 fn case_path(name: &str) -> PathBuf {
208 Path::new(env!("CARGO_MANIFEST_DIR"))
209 .join(CASE_ROOT)
210 .join(name)
211 }
212
213 fn load_json(path: &Path) -> Value {
214 let data =
215 fs::read_to_string(path).unwrap_or_else(|e| panic!("failed to read {:?}: {}", path, e));
216 serde_json::from_str(&data).unwrap_or_else(|e| panic!("failed to parse {:?}: {}", path, e))
217 }
218
219 fn execution_context_from_value(
220 value: &Value,
221 exec_type: ExecutionType,
222 scenario: &str,
223 ) -> ExecutionContext {
224 let req_headers = lowercase_string_map(json_string_map(value.get("req_headers")));
225 let query = lowercase_string_vec_map(json_string_vec_map(value.get("query")));
226 let body = value.get("body").and_then(body_from_value);
227
228 match exec_type {
229 ExecutionType::Inbound => ExecutionContext::Inbound(ExecutionRequest {
230 memory: None,
231 req_headers,
232 query,
233 body,
234 }),
235 ExecutionType::Outbound => {
236 let resp_headers = lowercase_string_map(json_string_map(value.get("resp_headers")));
237 let status = value
238 .get("status")
239 .and_then(Value::as_i64)
240 .unwrap_or_else(|| {
241 panic!(
242 "outbound fixture {} must define a numeric status field",
243 scenario
244 )
245 }) as i32;
246
247 ExecutionContext::Outbound(ExecutionResponse {
248 memory: None,
249 req_headers,
250 query,
251 body,
252 resp_headers,
253 status,
254 })
255 }
256 }
257 }
258
259 fn expected_response_from_value(value: &Value, scenario: &str) -> ExpectedResponse {
260 let should_continue = value
261 .get("should_continue")
262 .and_then(Value::as_bool)
263 .unwrap_or(false);
264 let execution_type = execution_type_from_value(value.get("execution_type"), scenario);
265 let status = value
266 .get("status")
267 .and_then(Value::as_i64)
268 .map(|s| s as i32);
269 let resp_headers = lowercase_string_map(json_string_map(value.get("resp_headers")));
270
271 ExpectedResponse {
272 should_continue,
273 status,
274 resp_headers,
275 execution_type,
276 }
277 }
278
279 fn json_string_map(value: Option<&Value>) -> HashMap<String, String> {
280 match value {
281 Some(Value::Object(map)) => map
282 .iter()
283 .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
284 .collect(),
285 _ => HashMap::new(),
286 }
287 }
288
289 fn json_string_vec_map(value: Option<&Value>) -> HashMap<String, Vec<String>> {
290 match value {
291 Some(Value::Object(map)) => map
292 .iter()
293 .map(|(k, v)| (k.clone(), value_to_string_vec(v)))
294 .collect(),
295 _ => HashMap::new(),
296 }
297 }
298
299 fn value_to_string_vec(value: &Value) -> Vec<String> {
300 match value {
301 Value::String(s) => vec![s.to_string()],
302 Value::Array(arr) => arr
303 .iter()
304 .filter_map(|v| v.as_str().map(|s| s.to_string()))
305 .collect(),
306 _ => Vec::new(),
307 }
308 }
309
/// Returns `map` with every key converted to ASCII lowercase; values are
/// moved over unchanged.
///
/// Keys that differ only in ASCII case collapse into one entry.
fn lowercase_string_map(map: HashMap<String, String>) -> HashMap<String, String> {
    let mut lowered = HashMap::with_capacity(map.len());
    for (name, value) in map {
        lowered.insert(name.to_ascii_lowercase(), value);
    }
    lowered
}
315
/// Returns `map` with every key converted to ASCII lowercase; the value
/// lists are moved over unchanged.
///
/// Keys that differ only in ASCII case collapse into one entry.
fn lowercase_string_vec_map(map: HashMap<String, Vec<String>>) -> HashMap<String, Vec<String>> {
    let mut lowered = HashMap::with_capacity(map.len());
    for (name, values) in map {
        lowered.insert(name.to_ascii_lowercase(), values);
    }
    lowered
}
321
322 fn body_from_value(value: &Value) -> Option<Bytes> {
323 match value {
324 Value::Null => None,
325 Value::String(s) => Some(Bytes::from(s.clone())),
326 _ => None,
327 }
328 }
329
330 struct ExpectedResponse {
331 should_continue: bool,
332 status: Option<i32>,
333 resp_headers: HashMap<String, String>,
334 execution_type: ExecutionType,
335 }
336
337 fn execution_type_from_value(value: Option<&Value>, scenario: &str) -> ExecutionType {
338 let raw = value
339 .and_then(Value::as_str)
340 .unwrap_or_else(|| panic!("fixture {} must define execution_type", scenario));
341
342 match raw.to_ascii_lowercase().as_str() {
343 "inbound" => ExecutionType::Inbound,
344 "outbound" => ExecutionType::Outbound,
345 other => panic!(
346 "fixture {} has invalid execution_type '{}'; expected 'inbound' or 'outbound'",
347 scenario, other
348 ),
349 }
350 }
351}