use serde::{Deserialize, Serialize};
use crate::error::WasmError;
use crate::instance::{PluginInstance, RequestContext};
use crate::trap::{TrapContext, TrapResult};
/// What a single middleware's `on_request` hook decided to do with the request.
#[derive(Debug)]
pub enum OnRequestResult {
    /// Keep going: the bytes are handed to the next middleware as the request.
    Continue(Vec<u8>),
    /// Stop the chain: the bytes become the response.
    ShortCircuit(Vec<u8>),
}
/// Structured JSON envelope a middleware may write as its `on_request` output.
///
/// When the output does not deserialize into this shape, it is treated as raw bytes
/// by `parse_middleware_output`.
#[derive(Debug, Serialize, Deserialize)]
struct MiddlewareOutput {
    /// Non-zero requests a short-circuit; only honoured when the hook's result code
    /// is non-zero as well.
    action: i32,
    /// Payload; re-serialized to bytes and used as the next request or as the response.
    data: serde_json::Value,
}
/// A named middleware entry together with its JSON configuration.
#[derive(Debug, Clone)]
pub struct MiddlewareConfig {
    /// Identifier of the middleware.
    pub name: String,
    /// Arbitrary JSON configuration for this middleware.
    pub config: serde_json::Value,
}

impl MiddlewareConfig {
    /// Builds an entry from anything convertible into a `String` name and a JSON value.
    pub fn new(name: impl Into<String>, config: serde_json::Value) -> Self {
        let name: String = name.into();
        MiddlewareConfig { name, config }
    }
}
/// An ordered list of middleware configurations.
///
/// Insertion order is preserved: `configs()[0]` is the first entry pushed (or the
/// first element of the vector given to `from_configs`).
#[derive(Default)]
pub struct MiddlewareChain {
    configs: Vec<MiddlewareConfig>,
}

impl MiddlewareChain {
    /// Creates a chain with no entries.
    pub fn new() -> Self {
        Self::default()
    }

    /// Wraps an existing list of configurations, keeping its order.
    pub fn from_configs(configs: Vec<MiddlewareConfig>) -> Self {
        MiddlewareChain { configs }
    }

    /// Appends `config` after every entry already in the chain.
    pub fn push(&mut self, config: MiddlewareConfig) {
        self.configs.push(config)
    }

    /// Number of entries in the chain.
    pub fn len(&self) -> usize {
        self.configs.len()
    }

    /// `true` when the chain holds no entries.
    pub fn is_empty(&self) -> bool {
        self.configs.is_empty()
    }

    /// Borrows the entries in chain order.
    pub fn configs(&self) -> &[MiddlewareConfig] {
        self.configs.as_slice()
    }
}
/// Overall outcome of pushing a request through the middleware chain with
/// `execute_on_request`.
#[derive(Debug)]
pub enum ChainResult {
    /// Every middleware let the request through.
    Continue {
        /// Request bytes after every middleware has run.
        request: Vec<u8>,
        /// The context that was handed to each middleware.
        context: RequestContext,
    },
    /// One middleware stopped the chain and supplied a response.
    ShortCircuit {
        /// Response bytes produced by the short-circuiting middleware.
        response: Vec<u8>,
        /// Position of the short-circuiting middleware in the chain.
        middleware_index: usize,
        /// The context that was handed to each middleware.
        context: RequestContext,
    },
    /// A middleware hook failed, or its output could not be interpreted.
    Error {
        /// The underlying error.
        error: WasmError,
        /// Trap classification derived from `error` via `TrapResult::from_error`.
        trap_result: TrapResult,
    },
}
pub fn execute_on_request(
instances: &mut [PluginInstance],
initial_request: &[u8],
context: RequestContext,
) -> ChainResult {
let mut current_request = initial_request.to_vec();
let current_context = context;
for (index, instance) in instances.iter_mut().enumerate() {
instance.set_context(current_context.clone());
match instance.on_request(¤t_request) {
Ok(result_code) => {
let output = instance.take_output();
match parse_middleware_output(&output, result_code) {
Ok(OnRequestResult::Continue(new_request)) => {
current_request = new_request;
}
Ok(OnRequestResult::ShortCircuit(response)) => {
return ChainResult::ShortCircuit {
response,
middleware_index: index,
context: current_context,
};
}
Err(e) => {
return ChainResult::Error {
trap_result: TrapResult::from_error(&e, TrapContext::OnRequest),
error: e,
};
}
}
}
Err(e) => {
return ChainResult::Error {
trap_result: TrapResult::from_error(&e, TrapContext::OnRequest),
error: e,
};
}
}
}
ChainResult::Continue {
request: current_request,
context: current_context,
}
}
/// Runs each plugin instance's `on_response` hook in reverse chain order (last
/// middleware first), threading the response through the chain.
///
/// Before its hook runs, every instance receives a clone of `context`. When a hook
/// succeeds and leaves non-empty output, that output replaces the current response.
/// Empty output leaves the response untouched, and the hook's return code is not
/// consulted.
///
/// A failing hook never aborts this stage. The failure is logged at `warn` level
/// with the middleware's chain position, and the response as it stood before that
/// middleware is passed on.
///
/// Returns the final response bytes.
pub fn execute_on_response(
    instances: &mut [PluginInstance],
    initial_response: &[u8],
    context: RequestContext,
) -> Vec<u8> {
    let mut current_response = initial_response.to_vec();
    // `enumerate` before `rev` so the logged index is the instance's position in
    // the chain, matching `middleware_index` on the request path.
    for (index, instance) in instances.iter_mut().enumerate().rev() {
        instance.set_context(context.clone());
        match instance.on_response(&current_response) {
            Ok(_result_code) => {
                let output = instance.take_output();
                if !output.is_empty() {
                    current_response = output;
                }
            }
            Err(e) => {
                let trap_result = TrapResult::from_error(&e, TrapContext::OnResponse);
                tracing::warn!(
                    middleware_index = index,
                    error = %trap_result.message(),
                    "Middleware on_response failed, continuing with original response"
                );
            }
        }
    }
    current_response
}
/// Runs `on_response` only for the middlewares that ran before the one that
/// short-circuited: instances `[0, short_circuit_index)`, last one first.
///
/// The short-circuiting middleware itself is excluded. An index of `0` means no
/// middleware precedes it, so `response` is returned unchanged. An index past the
/// end of `instances` is clamped to `instances.len()` rather than panicking on the
/// slice.
pub fn execute_on_response_partial(
    instances: &mut [PluginInstance],
    response: &[u8],
    short_circuit_index: usize,
    context: RequestContext,
) -> Vec<u8> {
    // A stale or out-of-range index must not panic inside the response path.
    let end = short_circuit_index.min(instances.len());
    if end == 0 {
        return response.to_vec();
    }
    execute_on_response(&mut instances[..end], response, context)
}
/// Decodes the bytes a middleware wrote during `on_request` into an
/// [`OnRequestResult`].
///
/// `result_code` is the hook's return value: `0` means "continue", and any other
/// value requests a short-circuit. Three output shapes are accepted:
///
/// * Empty output is valid only when continuing and yields `Continue(vec![])`. A
///   short-circuit without output is an error.
/// * A JSON envelope `{"action": <i32>, "data": <any>}`. Its `data` is
///   re-serialized to bytes, and the result short-circuits only when both `action`
///   and `result_code` are non-zero.
/// * Anything else is passed through verbatim, with the direction chosen by
///   `result_code` alone.
///
/// # Errors
///
/// Returns `WasmError::InitFailed` when a short-circuit carries no output, or when
/// re-serializing the envelope's `data` fails.
fn parse_middleware_output(output: &[u8], result_code: i32) -> Result<OnRequestResult, WasmError> {
    let wants_short_circuit = result_code != 0;

    if output.is_empty() {
        if wants_short_circuit {
            return Err(WasmError::InitFailed(
                "middleware returned short-circuit without output".into(),
            ));
        }
        return Ok(OnRequestResult::Continue(Vec::new()));
    }

    let (payload, short_circuit) = match serde_json::from_slice::<MiddlewareOutput>(output) {
        Ok(envelope) => {
            let bytes = serde_json::to_vec(&envelope.data).map_err(|err| {
                WasmError::InitFailed(format!("failed to serialize output: {}", err))
            })?;
            // Both signals must agree before the chain is cut short.
            (bytes, envelope.action != 0 && wants_short_circuit)
        }
        // Not an envelope: treat the output as opaque bytes.
        Err(_) => (output.to_vec(), wants_short_circuit),
    };

    Ok(if short_circuit {
        OnRequestResult::ShortCircuit(payload)
    } else {
        OnRequestResult::Continue(payload)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Encodes a JSON value into the raw bytes a middleware would write.
    fn to_bytes(value: serde_json::Value) -> Vec<u8> {
        serde_json::to_vec(&value).expect("serializing a json! literal cannot fail")
    }

    #[test]
    fn middleware_config_new() {
        let config = MiddlewareConfig::new("rate-limit", json!({"quota": 100}));
        assert_eq!(config.name, "rate-limit");
        assert_eq!(config.config["quota"], 100);
    }

    #[test]
    fn chain_new_is_empty() {
        let chain = MiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
        assert!(MiddlewareChain::default().is_empty());
    }

    #[test]
    fn chain_push() {
        let mut chain = MiddlewareChain::new();
        chain.push(MiddlewareConfig::new("auth", json!({})));
        chain.push(MiddlewareConfig::new("rate-limit", json!({})));
        assert_eq!(chain.len(), 2);
        assert_eq!(chain.configs()[0].name, "auth");
        assert_eq!(chain.configs()[1].name, "rate-limit");
    }

    #[test]
    fn chain_from_configs() {
        let configs = vec![
            MiddlewareConfig::new("auth", json!({})),
            MiddlewareConfig::new("cors", json!({})),
        ];
        let chain = MiddlewareChain::from_configs(configs);
        assert_eq!(chain.len(), 2);
        assert_eq!(chain.configs()[0].name, "auth");
        assert_eq!(chain.configs()[1].name, "cors");
    }

    #[test]
    fn parse_continue_output() {
        let output = to_bytes(json!({
            "action": 0,
            "data": {"method": "GET", "path": "/api"}
        }));
        let result = parse_middleware_output(&output, 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(_)));
    }

    #[test]
    fn parse_continue_output_carries_data() {
        let output = to_bytes(json!({
            "action": 0,
            "data": {"method": "GET", "path": "/api"}
        }));
        match parse_middleware_output(&output, 0).unwrap() {
            OnRequestResult::Continue(data) => {
                let decoded: serde_json::Value = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, json!({"method": "GET", "path": "/api"}));
            }
            other => panic!("expected Continue, got {:?}", other),
        }
    }

    #[test]
    fn parse_short_circuit_output() {
        let output = to_bytes(json!({
            "action": 1,
            "data": {"status": 401, "body": "Unauthorized"}
        }));
        match parse_middleware_output(&output, 1).unwrap() {
            OnRequestResult::ShortCircuit(data) => {
                let decoded: serde_json::Value = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, json!({"status": 401, "body": "Unauthorized"}));
            }
            other => panic!("expected ShortCircuit, got {:?}", other),
        }
    }

    #[test]
    fn parse_envelope_action_zero_wins_over_short_circuit_code() {
        let output = to_bytes(json!({"action": 0, "data": "x"}));
        let result = parse_middleware_output(&output, 1).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(_)));
    }

    #[test]
    fn parse_envelope_result_code_zero_wins_over_short_circuit_action() {
        let output = to_bytes(json!({"action": 1, "data": "x"}));
        let result = parse_middleware_output(&output, 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(_)));
    }

    #[test]
    fn parse_json_without_envelope_is_passed_through_raw() {
        // Valid JSON, but missing `action`/`data`, so it is not an envelope.
        let output = to_bytes(json!({"status": 401}));
        match parse_middleware_output(&output, 1).unwrap() {
            OnRequestResult::ShortCircuit(data) => assert_eq!(data, output),
            other => panic!("expected ShortCircuit, got {:?}", other),
        }
    }

    #[test]
    fn parse_raw_output_continue() {
        let output = b"raw request data";
        match parse_middleware_output(output, 0).unwrap() {
            OnRequestResult::Continue(data) => assert_eq!(data, output.to_vec()),
            other => panic!("expected Continue, got {:?}", other),
        }
    }

    #[test]
    fn parse_raw_output_short_circuit() {
        let output = b"error response";
        match parse_middleware_output(output, 1).unwrap() {
            OnRequestResult::ShortCircuit(data) => assert_eq!(data, output.to_vec()),
            other => panic!("expected ShortCircuit, got {:?}", other),
        }
    }

    #[test]
    fn parse_empty_continue() {
        let result = parse_middleware_output(&[], 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(data) if data.is_empty()));
    }

    #[test]
    fn parse_empty_short_circuit_fails() {
        let result = parse_middleware_output(&[], 1);
        assert!(result.is_err());
    }
}