use std::time::Instant;
use serde::{Deserialize, Serialize};
use crate::error::WasmError;
use crate::instance::{PluginInstance, RequestContext};
use crate::trap::{TrapContext, TrapResult};
/// Optional observer the chain executors in this module invoke once per middleware hook.
///
/// Arguments are `(middleware_name, phase, elapsed_seconds, short_circuited)`, where
/// `phase` is `"request"` or `"response"` and `elapsed_seconds` is measured around the hook.
pub type MetricsCallback<'a> = Option<&'a (dyn Fn(&str, &str, f64, bool) + 'a)>;
/// Outcome of a single middleware's `on_request` hook, as decoded from its output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OnRequestResult {
    /// Continue the chain with these request bytes.
    ///
    /// The payload is empty when the middleware returned code 0 without writing output.
    Continue(Vec<u8>),
    /// Stop the chain and answer with these response bytes.
    ShortCircuit(Vec<u8>),
}
/// JSON envelope a middleware may write as its output.
///
/// When a middleware's output parses as this shape, `data` is re-serialized and
/// forwarded. Otherwise the raw output bytes are used as-is (see `parse_middleware_output`).
#[derive(Debug, Serialize, Deserialize)]
struct MiddlewareOutput {
    /// 0 requests "continue". Any other value requests a short-circuit, which
    /// only takes effect when the hook's result code is also non-zero.
    action: i32,
    /// Payload forwarded as the next request (continue) or as the response (short-circuit).
    data: serde_json::Value,
}
/// Name plus JSON configuration for one middleware entry in a chain.
#[derive(Debug, Clone)]
pub struct MiddlewareConfig {
    /// Middleware identifier, e.g. `"rate-limit"`.
    pub name: String,
    /// Middleware-specific settings, stored as an arbitrary JSON value.
    pub config: serde_json::Value,
}

impl MiddlewareConfig {
    /// Builds a config from any `String`-convertible name and a JSON settings value.
    pub fn new(name: impl Into<String>, config: serde_json::Value) -> Self {
        let name: String = name.into();
        MiddlewareConfig { name, config }
    }
}
/// Ordered list of middleware configurations; insertion order is preserved.
#[derive(Debug, Clone, Default)]
pub struct MiddlewareChain {
    configs: Vec<MiddlewareConfig>,
}

impl MiddlewareChain {
    /// Creates an empty chain (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Wraps an existing, already-ordered list of configurations.
    pub fn from_configs(configs: Vec<MiddlewareConfig>) -> Self {
        Self { configs }
    }

    /// Appends a middleware configuration to the end of the chain.
    pub fn push(&mut self, config: MiddlewareConfig) {
        self.configs.push(config);
    }

    /// Number of configured middlewares.
    pub fn len(&self) -> usize {
        self.configs.len()
    }

    /// Returns `true` when no middleware is configured.
    pub fn is_empty(&self) -> bool {
        self.configs.is_empty()
    }

    /// Borrows the configurations in chain order.
    pub fn configs(&self) -> &[MiddlewareConfig] {
        &self.configs
    }
}
/// Outcome of running the `on_request` hooks of a middleware chain.
#[derive(Debug)]
pub enum ChainResult {
    /// Every middleware continued; carries the final request bytes and context.
    Continue {
        request: Vec<u8>,
        context: RequestContext,
    },
    /// A middleware answered early; later middlewares were not run.
    ShortCircuit {
        /// Response bytes produced by the short-circuiting middleware.
        response: Vec<u8>,
        /// Position of that middleware in the instance slice that was executed.
        middleware_index: usize,
        /// Context as read back from that middleware after its hook ran.
        context: RequestContext,
    },
    /// A hook failed, or its output could not be interpreted.
    Error {
        error: WasmError,
        /// `error` classified for the `TrapContext::OnRequest` phase.
        trap_result: TrapResult,
    },
}
pub fn execute_on_request(
instances: &mut [PluginInstance],
initial_request: &[u8],
context: RequestContext,
) -> ChainResult {
execute_on_request_with_metrics(instances, initial_request, context, None)
}
pub fn execute_on_request_with_metrics(
instances: &mut [PluginInstance],
initial_request: &[u8],
context: RequestContext,
metrics_callback: MetricsCallback<'_>,
) -> ChainResult {
let mut current_request = initial_request.to_vec();
let mut current_context = context;
for (index, instance) in instances.iter_mut().enumerate() {
instance.set_context(current_context.clone());
let start = Instant::now();
let middleware_name = instance.name().to_string();
match instance.on_request(¤t_request) {
Ok(result_code) => {
let output = instance.take_output();
match parse_middleware_output(&output, result_code) {
Ok(OnRequestResult::Continue(new_request)) => {
if let Some(callback) = metrics_callback {
callback(
&middleware_name,
"request",
start.elapsed().as_secs_f64(),
false,
);
}
current_request = new_request;
current_context = instance.get_context();
}
Ok(OnRequestResult::ShortCircuit(response)) => {
if let Some(callback) = metrics_callback {
callback(
&middleware_name,
"request",
start.elapsed().as_secs_f64(),
true,
);
}
let final_context = instance.get_context();
return ChainResult::ShortCircuit {
response,
middleware_index: index,
context: final_context,
};
}
Err(e) => {
if let Some(callback) = metrics_callback {
callback(
&middleware_name,
"request",
start.elapsed().as_secs_f64(),
false,
);
}
return ChainResult::Error {
trap_result: TrapResult::from_error(&e, TrapContext::OnRequest),
error: e,
};
}
}
}
Err(e) => {
if let Some(callback) = metrics_callback {
callback(
&middleware_name,
"request",
start.elapsed().as_secs_f64(),
false,
);
}
return ChainResult::Error {
trap_result: TrapResult::from_error(&e, TrapContext::OnRequest),
error: e,
};
}
}
}
ChainResult::Continue {
request: current_request,
context: current_context,
}
}
pub fn execute_on_response(
instances: &mut [PluginInstance],
initial_response: &[u8],
context: RequestContext,
) -> Vec<u8> {
execute_on_response_with_metrics(instances, initial_response, context, None)
}
/// Runs every middleware's `on_response` hook in reverse slice order.
///
/// Each middleware is given a clone of `context` and the response as left by the
/// previous hook. Non-empty output replaces the response; empty output leaves it as is.
/// A failing hook is logged and skipped, so the response passes through unchanged
/// for that middleware. When `metrics_callback` is supplied, it is called once per
/// middleware with phase `"response"` and `short_circuited = false`.
///
/// Returns the final response bytes.
pub fn execute_on_response_with_metrics(
    instances: &mut [PluginInstance],
    initial_response: &[u8],
    context: RequestContext,
    metrics_callback: MetricsCallback<'_>,
) -> Vec<u8> {
    let mut current_response = initial_response.to_vec();
    for instance in instances.iter_mut().rev() {
        instance.set_context(context.clone());
        let start = Instant::now();
        let middleware_name = instance.name().to_string();
        let outcome = instance.on_response(&current_response);
        // Reported identically on success and failure, right after the hook returns.
        if let Some(callback) = metrics_callback {
            callback(
                &middleware_name,
                "response",
                start.elapsed().as_secs_f64(),
                false,
            );
        }
        match outcome {
            Ok(_) => {
                let output = instance.take_output();
                if !output.is_empty() {
                    current_response = output;
                }
            }
            Err(e) => {
                let trap_result = TrapResult::from_error(&e, TrapContext::OnResponse);
                // "original" here means the response as it stood before this middleware.
                tracing::warn!(
                    middleware = %middleware_name,
                    error = %trap_result.message(),
                    "Middleware on_response failed, continuing with original response"
                );
            }
        }
    }
    current_response
}
/// Runs `on_response` only for the middlewares that ran before a short-circuit.
///
/// `short_circuit_index` is the index of the middleware that short-circuited (see
/// `ChainResult::ShortCircuit::middleware_index`). Only `instances[..short_circuit_index]`
/// are executed, in reverse order. An index of 0 returns `response` unchanged.
///
/// An index beyond `instances.len()` is clamped to the slice length, so every
/// middleware runs instead of the call panicking on an out-of-range slice.
pub fn execute_on_response_partial(
    instances: &mut [PluginInstance],
    response: &[u8],
    short_circuit_index: usize,
    context: RequestContext,
) -> Vec<u8> {
    // An empty prefix (index 0) yields `response.to_vec()` from the executor itself.
    let end = short_circuit_index.min(instances.len());
    execute_on_response(&mut instances[..end], response, context)
}
/// Interprets a middleware's raw output together with its hook result code.
///
/// - Empty output: continues with an empty payload when `result_code == 0`; otherwise
///   this is an error, because a short-circuit needs a response body.
/// - Output matching the `{ "action", "data" }` envelope: `data` is re-serialized. The
///   chain continues if either `action` or `result_code` is 0, and short-circuits otherwise.
/// - Any other output is passed through verbatim, and `result_code` alone decides the outcome.
///
/// # Errors
/// Returns `WasmError::InitFailed` for an empty short-circuit or when `data` cannot be
/// re-serialized.
pub fn parse_middleware_output(
    output: &[u8],
    result_code: i32,
) -> Result<OnRequestResult, WasmError> {
    let host_says_continue = result_code == 0;

    if output.is_empty() {
        if host_says_continue {
            return Ok(OnRequestResult::Continue(Vec::new()));
        }
        return Err(WasmError::InitFailed(
            "middleware returned short-circuit without output".into(),
        ));
    }

    let (payload, keep_going) = match serde_json::from_slice::<MiddlewareOutput>(output) {
        Ok(envelope) => {
            let bytes = serde_json::to_vec(&envelope.data).map_err(|err| {
                WasmError::InitFailed(format!("failed to serialize output: {}", err))
            })?;
            (bytes, envelope.action == 0 || host_says_continue)
        }
        // Not an envelope: treat the bytes as an opaque request/response payload.
        Err(_) => (output.to_vec(), host_says_continue),
    };

    Ok(if keep_going {
        OnRequestResult::Continue(payload)
    } else {
        OnRequestResult::ShortCircuit(payload)
    })
}
#[cfg(test)]
mod tests {
    // Unit tests for chain configuration, middleware output parsing, the metrics
    // callback type, and the chain executors when no plugin instances are present.
    use super::*;
    use serde_json::json;
    #[test]
    fn middleware_config_new() {
        let config = MiddlewareConfig::new("rate-limit", json!({"quota": 100}));
        assert_eq!(config.name, "rate-limit");
        assert_eq!(config.config["quota"], 100);
    }
    #[test]
    fn chain_new_is_empty() {
        let chain = MiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
    }
    #[test]
    fn chain_push() {
        // Pushed configs keep insertion order.
        let mut chain = MiddlewareChain::new();
        chain.push(MiddlewareConfig::new("auth", json!({})));
        chain.push(MiddlewareConfig::new("rate-limit", json!({})));
        assert_eq!(chain.len(), 2);
        assert_eq!(chain.configs()[0].name, "auth");
        assert_eq!(chain.configs()[1].name, "rate-limit");
    }
    #[test]
    fn chain_from_configs() {
        let configs = vec![
            MiddlewareConfig::new("auth", json!({})),
            MiddlewareConfig::new("cors", json!({})),
        ];
        let chain = MiddlewareChain::from_configs(configs);
        assert_eq!(chain.len(), 2);
    }
    #[test]
    fn parse_continue_output() {
        // Envelope with action 0 and result code 0 continues.
        let output = serde_json::to_vec(&json!({
            "action": 0,
            "data": {"method": "GET", "path": "/api"}
        }))
        .unwrap();
        let result = parse_middleware_output(&output, 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(_)));
    }
    #[test]
    fn parse_short_circuit_output() {
        // Envelope with a non-zero action and non-zero result code short-circuits.
        let output = serde_json::to_vec(&json!({
            "action": 1,
            "data": {"status": 401, "body": "Unauthorized"}
        }))
        .unwrap();
        let result = parse_middleware_output(&output, 1).unwrap();
        assert!(matches!(result, OnRequestResult::ShortCircuit(_)));
    }
    #[test]
    fn parse_raw_output_continue() {
        // Non-JSON output is passed through; result code 0 means continue.
        let output = b"raw request data";
        let result = parse_middleware_output(output, 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(_)));
    }
    #[test]
    fn parse_raw_output_short_circuit() {
        // Non-JSON output is passed through; a non-zero result code short-circuits.
        let output = b"error response";
        let result = parse_middleware_output(output, 1).unwrap();
        assert!(matches!(result, OnRequestResult::ShortCircuit(_)));
    }
    #[test]
    fn parse_empty_continue() {
        let result = parse_middleware_output(&[], 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(data) if data.is_empty()));
    }
    #[test]
    fn parse_empty_short_circuit_fails() {
        // A short-circuit needs a response body, so empty output is an error.
        let result = parse_middleware_output(&[], 1);
        assert!(result.is_err());
    }
    #[test]
    fn parse_continue_with_request_metadata() {
        // A full SDK `Request` inside the envelope must survive the
        // deserialize -> re-serialize round trip.
        use barbacane_plugin_sdk::types::Request;
        use std::collections::BTreeMap;
        let req = Request {
            method: "POST".into(),
            path: "/upload".into(),
            query: None,
            headers: {
                let mut h = BTreeMap::new();
                h.insert("content-type".into(), "application/octet-stream".into());
                h
            },
            // Request built without a body; only metadata is round-tripped.
            body: None,
            client_ip: "127.0.0.1".into(),
            path_params: BTreeMap::new(),
        };
        let output = serde_json::to_vec(&json!({
            "action": 0,
            "data": req
        }))
        .unwrap();
        let result = parse_middleware_output(&output, 0).unwrap();
        match result {
            OnRequestResult::Continue(data) => {
                let parsed: Request = serde_json::from_slice(&data).unwrap();
                assert_eq!(parsed.method, "POST");
                assert_eq!(parsed.path, "/upload");
                // The absent body stays absent after the round trip.
                assert_eq!(parsed.body, None);
            }
            OnRequestResult::ShortCircuit(_) => panic!("expected Continue"),
        }
    }
    #[test]
    fn parse_short_circuit_with_response_metadata() {
        // A full SDK `Response` inside a short-circuit envelope must round-trip.
        use barbacane_plugin_sdk::types::Response;
        use std::collections::BTreeMap;
        let resp = Response {
            status: 403,
            headers: {
                let mut h = BTreeMap::new();
                h.insert("content-type".into(), "application/json".into());
                h
            },
            // Response built without a body; only status and headers are checked.
            body: None,
        };
        let output = serde_json::to_vec(&json!({
            "action": 1,
            "data": resp
        }))
        .unwrap();
        let result = parse_middleware_output(&output, 1).unwrap();
        match result {
            OnRequestResult::ShortCircuit(data) => {
                let parsed: Response = serde_json::from_slice(&data).unwrap();
                assert_eq!(parsed.status, 403);
                // The absent body stays absent after the round trip.
                assert_eq!(parsed.body, None);
            }
            OnRequestResult::Continue(_) => panic!("expected ShortCircuit"),
        }
    }
    #[test]
    fn metrics_callback_type_accepts_closure() {
        // A capturing closure coerces to `MetricsCallback` and receives each argument.
        use std::cell::RefCell;
        use std::rc::Rc;
        let invocations = Rc::new(RefCell::new(Vec::new()));
        let invocations_clone = invocations.clone();
        let callback = move |name: &str, phase: &str, duration: f64, short_circuit: bool| {
            invocations_clone.borrow_mut().push((
                name.to_string(),
                phase.to_string(),
                duration,
                short_circuit,
            ));
        };
        let metrics_callback: MetricsCallback<'_> = Some(&callback);
        assert!(metrics_callback.is_some());
        if let Some(cb) = metrics_callback {
            cb("test-middleware", "request", 0.001, false);
            cb("test-middleware", "response", 0.002, true);
        }
        let recorded = invocations.borrow();
        assert_eq!(recorded.len(), 2);
        assert_eq!(recorded[0].0, "test-middleware");
        assert_eq!(recorded[0].1, "request");
        // First call reported no short-circuit.
        assert!(!recorded[0].3);
        assert_eq!(recorded[1].1, "response");
        // Second call reported a short-circuit.
        assert!(recorded[1].3);
    }
    #[test]
    fn execute_on_request_empty_instances_returns_continue() {
        // With no middlewares the request passes through untouched.
        let mut instances: Vec<PluginInstance> = vec![];
        let request = b"test request";
        let context = RequestContext::default();
        let result = execute_on_request(&mut instances, request, context);
        assert!(matches!(result, ChainResult::Continue { .. }));
        if let ChainResult::Continue {
            request: req,
            context: _,
        } = result
        {
            assert_eq!(req, request.to_vec());
        }
    }
    #[test]
    fn execute_on_response_empty_instances_returns_input() {
        let mut instances: Vec<PluginInstance> = vec![];
        let response = b"test response";
        let context = RequestContext::default();
        let result = execute_on_response(&mut instances, response, context);
        assert_eq!(result, response.to_vec());
    }
    #[test]
    fn execute_with_metrics_none_callback_works() {
        // Both executors accept `None` for the metrics callback.
        let mut instances: Vec<PluginInstance> = vec![];
        let request = b"test";
        let context = RequestContext::default();
        let result =
            execute_on_request_with_metrics(&mut instances, request, context.clone(), None);
        assert!(matches!(result, ChainResult::Continue { .. }));
        let response = execute_on_response_with_metrics(&mut instances, request, context, None);
        assert_eq!(response, request.to_vec());
    }
}