1use std::time::Instant;
8
9use serde::{Deserialize, Serialize};
10
11use crate::error::WasmError;
12use crate::instance::{PluginInstance, RequestContext};
13use crate::trap::{TrapContext, TrapResult};
14
/// Optional hook used to report per-middleware timing.
///
/// When present, the chain executors call it once per middleware invocation
/// with, in order:
///
/// 1. the middleware name,
/// 2. the phase (`"request"` or `"response"`),
/// 3. the elapsed time in seconds,
/// 4. whether the middleware short-circuited the chain.
///
/// `None` disables reporting.
pub type MetricsCallback<'a> = Option<&'a dyn Fn(&str, &str, f64, bool)>;
18
/// Decoded outcome of a single middleware's `on_request` output.
#[derive(Debug)]
pub enum OnRequestResult {
    /// Keep going; the payload is the request bytes handed to the next
    /// middleware in the chain.
    Continue(Vec<u8>),
    /// Stop the chain; the payload is the response bytes to return.
    ShortCircuit(Vec<u8>),
}
27
28#[derive(Debug, Serialize, Deserialize)]
32struct MiddlewareOutput {
33 action: i32,
35 data: serde_json::Value,
37}
38
39#[derive(Debug, Clone)]
41pub struct MiddlewareConfig {
42 pub name: String,
44 pub config: serde_json::Value,
46}
47
48impl MiddlewareConfig {
49 pub fn new(name: impl Into<String>, config: serde_json::Value) -> Self {
51 Self {
52 name: name.into(),
53 config,
54 }
55 }
56}
57
58pub struct MiddlewareChain {
60 configs: Vec<MiddlewareConfig>,
62}
63
64impl MiddlewareChain {
65 pub fn new() -> Self {
67 Self {
68 configs: Vec::new(),
69 }
70 }
71
72 pub fn from_configs(configs: Vec<MiddlewareConfig>) -> Self {
74 Self { configs }
75 }
76
77 pub fn push(&mut self, config: MiddlewareConfig) {
79 self.configs.push(config);
80 }
81
82 pub fn len(&self) -> usize {
84 self.configs.len()
85 }
86
87 pub fn is_empty(&self) -> bool {
89 self.configs.is_empty()
90 }
91
92 pub fn configs(&self) -> &[MiddlewareConfig] {
94 &self.configs
95 }
96}
97
98impl Default for MiddlewareChain {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104#[derive(Debug)]
106pub enum ChainResult {
107 Continue {
109 request: Vec<u8>,
111 context: RequestContext,
113 },
114 ShortCircuit {
116 response: Vec<u8>,
118 middleware_index: usize,
120 context: RequestContext,
122 },
123 Error {
125 error: WasmError,
127 trap_result: TrapResult,
129 },
130}
131
132pub fn execute_on_request(
138 instances: &mut [PluginInstance],
139 initial_request: &[u8],
140 context: RequestContext,
141) -> ChainResult {
142 execute_on_request_with_metrics(instances, initial_request, context, None)
143}
144
145pub fn execute_on_request_with_metrics(
147 instances: &mut [PluginInstance],
148 initial_request: &[u8],
149 context: RequestContext,
150 metrics_callback: MetricsCallback<'_>,
151) -> ChainResult {
152 let mut current_request = initial_request.to_vec();
153 let mut current_context = context;
154
155 for (index, instance) in instances.iter_mut().enumerate() {
156 instance.set_context(current_context.clone());
158
159 let start = Instant::now();
161 let middleware_name = instance.name().to_string();
162
163 match instance.on_request(¤t_request) {
165 Ok(result_code) => {
166 let output = instance.take_output();
167
168 match parse_middleware_output(&output, result_code) {
170 Ok(OnRequestResult::Continue(new_request)) => {
171 if let Some(callback) = metrics_callback {
173 callback(
174 &middleware_name,
175 "request",
176 start.elapsed().as_secs_f64(),
177 false,
178 );
179 }
180 current_request = new_request;
181 current_context = instance.get_context();
183 }
184 Ok(OnRequestResult::ShortCircuit(response)) => {
185 if let Some(callback) = metrics_callback {
187 callback(
188 &middleware_name,
189 "request",
190 start.elapsed().as_secs_f64(),
191 true,
192 );
193 }
194 let final_context = instance.get_context();
196 return ChainResult::ShortCircuit {
197 response,
198 middleware_index: index,
199 context: final_context,
200 };
201 }
202 Err(e) => {
203 if let Some(callback) = metrics_callback {
205 callback(
206 &middleware_name,
207 "request",
208 start.elapsed().as_secs_f64(),
209 false,
210 );
211 }
212 return ChainResult::Error {
213 trap_result: TrapResult::from_error(&e, TrapContext::OnRequest),
214 error: e,
215 };
216 }
217 }
218 }
219 Err(e) => {
220 if let Some(callback) = metrics_callback {
222 callback(
223 &middleware_name,
224 "request",
225 start.elapsed().as_secs_f64(),
226 false,
227 );
228 }
229 return ChainResult::Error {
230 trap_result: TrapResult::from_error(&e, TrapContext::OnRequest),
231 error: e,
232 };
233 }
234 }
235 }
236
237 ChainResult::Continue {
238 request: current_request,
239 context: current_context,
240 }
241}
242
243pub fn execute_on_response(
248 instances: &mut [PluginInstance],
249 initial_response: &[u8],
250 context: RequestContext,
251) -> Vec<u8> {
252 execute_on_response_with_metrics(instances, initial_response, context, None)
253}
254
255pub fn execute_on_response_with_metrics(
257 instances: &mut [PluginInstance],
258 initial_response: &[u8],
259 context: RequestContext,
260 metrics_callback: MetricsCallback<'_>,
261) -> Vec<u8> {
262 let mut current_response = initial_response.to_vec();
263
264 for instance in instances.iter_mut().rev() {
266 instance.set_context(context.clone());
267
268 let start = Instant::now();
270 let middleware_name = instance.name().to_string();
271
272 match instance.on_response(¤t_response) {
273 Ok(_result_code) => {
274 if let Some(callback) = metrics_callback {
276 callback(
277 &middleware_name,
278 "response",
279 start.elapsed().as_secs_f64(),
280 false,
281 );
282 }
283 let output = instance.take_output();
284 if !output.is_empty() {
285 current_response = output;
286 }
287 }
288 Err(e) => {
289 if let Some(callback) = metrics_callback {
291 callback(
292 &middleware_name,
293 "response",
294 start.elapsed().as_secs_f64(),
295 false,
296 );
297 }
298 let trap_result = TrapResult::from_error(&e, TrapContext::OnResponse);
300 tracing::warn!(
301 error = %trap_result.message(),
302 "Middleware on_response failed, continuing with original response"
303 );
304 }
305 }
306 }
307
308 current_response
309}
310
311pub fn execute_on_response_partial(
315 instances: &mut [PluginInstance],
316 response: &[u8],
317 short_circuit_index: usize,
318 context: RequestContext,
319) -> Vec<u8> {
320 if short_circuit_index == 0 {
321 return response.to_vec();
322 }
323
324 let partial_instances = &mut instances[..short_circuit_index];
325 execute_on_response(partial_instances, response, context)
326}
327
328pub fn parse_middleware_output(
330 output: &[u8],
331 result_code: i32,
332) -> Result<OnRequestResult, WasmError> {
333 if output.is_empty() {
335 return if result_code == 0 {
336 Ok(OnRequestResult::Continue(Vec::new()))
337 } else {
338 Err(WasmError::InitFailed(
339 "middleware returned short-circuit without output".into(),
340 ))
341 };
342 }
343
344 match serde_json::from_slice::<MiddlewareOutput>(output) {
346 Ok(parsed) => {
347 let data = serde_json::to_vec(&parsed.data)
348 .map_err(|e| WasmError::InitFailed(format!("failed to serialize output: {}", e)))?;
349
350 if parsed.action == 0 || result_code == 0 {
351 Ok(OnRequestResult::Continue(data))
352 } else {
353 Ok(OnRequestResult::ShortCircuit(data))
354 }
355 }
356 Err(_) => {
357 if result_code == 0 {
359 Ok(OnRequestResult::Continue(output.to_vec()))
360 } else {
361 Ok(OnRequestResult::ShortCircuit(output.to_vec()))
362 }
363 }
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes a structured `{ "action", "data" }` middleware envelope.
    fn envelope(action: i32, data: serde_json::Value) -> Vec<u8> {
        serde_json::to_vec(&json!({ "action": action, "data": data })).unwrap()
    }

    // --- MiddlewareConfig / MiddlewareChain ---------------------------------

    #[test]
    fn middleware_config_new() {
        let MiddlewareConfig { name, config } =
            MiddlewareConfig::new("rate-limit", json!({"quota": 100}));

        assert_eq!(name, "rate-limit");
        assert_eq!(config["quota"], 100);
    }

    #[test]
    fn chain_new_is_empty() {
        let chain = MiddlewareChain::new();

        assert_eq!(chain.len(), 0);
        assert!(chain.is_empty());
    }

    #[test]
    fn chain_push() {
        let mut chain = MiddlewareChain::new();
        for name in ["auth", "rate-limit"] {
            chain.push(MiddlewareConfig::new(name, json!({})));
        }

        assert_eq!(chain.len(), 2);
        let stored = chain.configs();
        assert_eq!(stored[0].name, "auth");
        assert_eq!(stored[1].name, "rate-limit");
    }

    #[test]
    fn chain_from_configs() {
        let chain = MiddlewareChain::from_configs(vec![
            MiddlewareConfig::new("auth", json!({})),
            MiddlewareConfig::new("cors", json!({})),
        ]);

        assert_eq!(chain.len(), 2);
    }

    // --- parse_middleware_output --------------------------------------------

    #[test]
    fn parse_continue_output() {
        let output = envelope(0, json!({"method": "GET", "path": "/api"}));

        let result = parse_middleware_output(&output, 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(_)));
    }

    #[test]
    fn parse_short_circuit_output() {
        let output = envelope(1, json!({"status": 401, "body": "Unauthorized"}));

        let result = parse_middleware_output(&output, 1).unwrap();
        assert!(matches!(result, OnRequestResult::ShortCircuit(_)));
    }

    #[test]
    fn parse_raw_output_continue() {
        let result = parse_middleware_output(b"raw request data", 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(_)));
    }

    #[test]
    fn parse_raw_output_short_circuit() {
        let result = parse_middleware_output(b"error response", 1).unwrap();
        assert!(matches!(result, OnRequestResult::ShortCircuit(_)));
    }

    #[test]
    fn parse_empty_continue() {
        let result = parse_middleware_output(&[], 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(data) if data.is_empty()));
    }

    #[test]
    fn parse_empty_short_circuit_fails() {
        assert!(parse_middleware_output(&[], 1).is_err());
    }

    /// A request without a body must survive the envelope round trip.
    #[test]
    fn parse_continue_with_request_metadata() {
        use barbacane_plugin_sdk::types::Request;
        use std::collections::BTreeMap;

        let mut headers = BTreeMap::new();
        headers.insert("content-type".into(), "application/octet-stream".into());

        let req = Request {
            method: "POST".into(),
            path: "/upload".into(),
            query: None,
            headers,
            body: None,
            client_ip: "127.0.0.1".into(),
            path_params: BTreeMap::new(),
        };
        let output = envelope(0, serde_json::to_value(&req).unwrap());

        let data = match parse_middleware_output(&output, 0).unwrap() {
            OnRequestResult::Continue(data) => data,
            OnRequestResult::ShortCircuit(_) => panic!("expected Continue"),
        };

        let parsed: Request = serde_json::from_slice(&data).unwrap();
        assert_eq!(parsed.method, "POST");
        assert_eq!(parsed.path, "/upload");
        assert_eq!(parsed.body, None);
    }

    /// A response without a body must survive the envelope round trip.
    #[test]
    fn parse_short_circuit_with_response_metadata() {
        use barbacane_plugin_sdk::types::Response;
        use std::collections::BTreeMap;

        let mut headers = BTreeMap::new();
        headers.insert("content-type".into(), "application/json".into());

        let resp = Response {
            status: 403,
            headers,
            body: None,
        };
        let output = envelope(1, serde_json::to_value(&resp).unwrap());

        let data = match parse_middleware_output(&output, 1).unwrap() {
            OnRequestResult::ShortCircuit(data) => data,
            OnRequestResult::Continue(_) => panic!("expected ShortCircuit"),
        };

        let parsed: Response = serde_json::from_slice(&data).unwrap();
        assert_eq!(parsed.status, 403);
        assert_eq!(parsed.body, None);
    }

    // --- MetricsCallback ----------------------------------------------------

    #[test]
    fn metrics_callback_type_accepts_closure() {
        use std::cell::RefCell;
        use std::rc::Rc;

        let invocations = Rc::new(RefCell::new(Vec::new()));
        let sink = Rc::clone(&invocations);

        let callback = move |name: &str, phase: &str, duration: f64, short_circuit: bool| {
            let entry = (name.to_string(), phase.to_string(), duration, short_circuit);
            sink.borrow_mut().push(entry);
        };

        let metrics_callback: MetricsCallback<'_> = Some(&callback);
        assert!(metrics_callback.is_some());

        if let Some(cb) = metrics_callback {
            cb("test-middleware", "request", 0.001, false);
            cb("test-middleware", "response", 0.002, true);
        }

        let recorded = invocations.borrow();
        assert_eq!(recorded.len(), 2);

        let (first, second) = (&recorded[0], &recorded[1]);
        assert_eq!(first.0, "test-middleware");
        assert_eq!(first.1, "request");
        assert!(!first.3);
        assert_eq!(second.1, "response");
        assert!(second.3);
    }

    // --- chain execution with no instances ----------------------------------

    #[test]
    fn execute_on_request_empty_instances_returns_continue() {
        let mut instances: Vec<PluginInstance> = Vec::new();
        let request = b"test request";

        let result = execute_on_request(&mut instances, request, RequestContext::default());

        match result {
            ChainResult::Continue { request: req, .. } => assert_eq!(req, request.to_vec()),
            other => panic!("expected Continue, got {:?}", other),
        }
    }

    #[test]
    fn execute_on_response_empty_instances_returns_input() {
        let mut instances: Vec<PluginInstance> = Vec::new();
        let response = b"test response";

        let result = execute_on_response(&mut instances, response, RequestContext::default());

        assert_eq!(result, response.to_vec());
    }

    #[test]
    fn execute_with_metrics_none_callback_works() {
        let mut instances: Vec<PluginInstance> = Vec::new();
        let payload = b"test";
        let context = RequestContext::default();

        let on_request =
            execute_on_request_with_metrics(&mut instances, payload, context.clone(), None);
        assert!(matches!(on_request, ChainResult::Continue { .. }));

        let on_response = execute_on_response_with_metrics(&mut instances, payload, context, None);
        assert_eq!(on_response, payload.to_vec());
    }
}