use axum::body::Body;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::Response;
use serde_json::Value;
use crate::latency_profiles::LatencyProfiles;
use mockforge_chaos::core_failure_injection::FailureInjector;
use mockforge_chaos::core_traffic_shaping::TrafficShaper;
use mockforge_core::Overrides;
/// Metadata describing the API operation a request was routed to.
///
/// [`fault_then_next`] reads it from the request extensions and passes its tags to the
/// failure injector, latency profiles and traffic shaper; [`apply_overrides`] takes it
/// explicitly and forwards its id, tags and path to the configured overrides.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OperationMeta {
    /// Operation identifier (e.g. `getUserById`).
    pub id: String,
    /// Tags attached to the operation.
    pub tags: Vec<String>,
    /// Route path of the operation as passed to the overrides (e.g. `/users/{id}`).
    pub path: String,
}
/// State shared by the middleware in this module.
///
/// [`add_shared_extension`] inserts a copy into each request's extensions, where
/// [`fault_then_next`] picks it up; [`apply_overrides`] receives it by reference.
#[derive(Clone)]
pub struct Shared {
    /// Latency profiles asked for per-operation faults via `maybe_fault`.
    pub profiles: LatencyProfiles,
    /// Response overrides applied by [`apply_overrides`].
    pub overrides: Overrides,
    /// Gate for `overrides`; when `false`, [`apply_overrides`] leaves bodies untouched.
    pub overrides_enabled: bool,
    /// Optional failure injector consulted by [`fault_then_next`] before the handler runs.
    pub failure_injector: Option<FailureInjector>,
    /// Optional shaper applied to estimated request and response sizes.
    pub traffic_shaper: Option<TrafficShaper>,
    /// Gate for `traffic_shaper`; shaping is skipped entirely when `false`.
    pub traffic_shaping_enabled: bool,
}
/// Axum middleware that attaches the router's [`Shared`] state to the request.
///
/// The state is inserted into the request extensions so that later middleware such as
/// [`fault_then_next`] can read it, then the request is forwarded to `next`.
pub async fn add_shared_extension(
    State(shared): State<Shared>,
    mut req: Request<Body>,
    next: Next,
) -> Response {
    {
        let extensions = req.extensions_mut();
        extensions.insert(shared);
    }
    next.run(req).await
}
pub async fn fault_then_next(req: Request<Body>, next: Next) -> Response {
let shared = match req.extensions().get::<Shared>() {
Some(s) => s.clone(),
None => {
tracing::error!("Shared extension not found in request - ensure add_shared_extension middleware is configured");
let mut res =
Response::new(Body::from("Internal server error: middleware misconfiguration"));
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return res;
}
};
let op = req.extensions().get::<OperationMeta>().cloned();
if let Some(failure_injector) = &shared.failure_injector {
let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
if let Some((status_code, error_message)) = failure_injector.process_request(tags) {
let mut res = Response::new(Body::from(error_message));
*res.status_mut() =
StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
return res;
}
}
if let Some(op) = &op {
if let Some((code, msg)) = shared
.profiles
.maybe_fault(&op.id, &op.tags.iter().map(|s| s.to_string()).collect::<Vec<_>>())
.await
{
let mut res = Response::new(Body::from(msg));
*res.status_mut() =
StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
return res;
}
}
if shared.traffic_shaping_enabled {
if let Some(traffic_shaper) = &shared.traffic_shaper {
let request_size = calculate_request_size(&req);
let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
match traffic_shaper.process_transfer(request_size, tags).await {
Ok(Some(_timeout)) => {
let mut res =
Response::new(Body::from("Request timeout due to traffic shaping"));
*res.status_mut() = StatusCode::REQUEST_TIMEOUT;
return res;
}
Ok(None) => {
}
Err(e) => {
let mut res =
Response::new(Body::from(format!("Traffic shaping error: {}", e)));
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return res;
}
}
}
}
let (parts, body) = req.into_parts();
let req = Request::from_parts(parts, body);
let response = next.run(req).await;
if shared.traffic_shaping_enabled {
if let Some(traffic_shaper) = &shared.traffic_shaper {
let response_size = calculate_response_size(&response);
let tags = op.as_ref().map(|o| o.tags.as_slice()).unwrap_or(&[]);
match traffic_shaper.process_transfer(response_size, tags).await {
Ok(Some(_timeout)) => {
let mut res =
Response::new(Body::from("Response timeout due to traffic shaping"));
*res.status_mut() = StatusCode::GATEWAY_TIMEOUT;
return res;
}
Ok(None) => {
}
Err(e) => {
let mut res =
Response::new(Body::from(format!("Traffic shaping error: {}", e)));
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return res;
}
}
}
}
response
}
/// Applies the configured response overrides for `op` to a JSON body in place.
///
/// Does nothing when `shared.overrides_enabled` is `false` or when `op` is `None`;
/// otherwise delegates to `shared.overrides.apply` with the operation's id, tags and path.
pub fn apply_overrides(shared: &Shared, op: Option<&OperationMeta>, body: &mut Value) {
    if let Some(op) = op.filter(|_| shared.overrides_enabled) {
        // `&op.tags` is the same `&Vec<String>` the call previously received via a fresh
        // clone, so no per-call allocation is needed.
        shared.overrides.apply(&op.id, &op.tags, &op.path, body);
    }
}
/// Estimates the number of bytes a request occupies, for traffic shaping.
///
/// The estimate sums the length of every header name and value, the length of the request
/// URI as rendered by its `Display` impl, and the body length. The body length is taken from
/// the `Content-Length` header when it parses as a `u64`; otherwise `POST`, `PUT` and `PATCH`
/// requests are assumed to carry a 256-byte body and all other methods none.
///
/// All additions saturate at `u64::MAX`: `Content-Length` is client-controlled, and a value
/// close to `u64::MAX` would otherwise overflow (a panic when overflow checks are enabled,
/// a wrapped estimate when they are not).
fn calculate_request_size<B>(req: &Request<B>) -> u64 {
    // Body size assumed for payload-carrying methods that declare no length.
    const ASSUMED_BODY_BYTES: u64 = 256;

    let header_bytes = req.headers().iter().fold(0u64, |total, (name, value)| {
        total
            .saturating_add(name.as_str().len() as u64)
            .saturating_add(value.as_bytes().len() as u64)
    });
    let uri_bytes = req.uri().to_string().len() as u64;

    let declared_body = req
        .headers()
        .get(http::header::CONTENT_LENGTH)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse::<u64>().ok());
    let body_bytes = declared_body.unwrap_or_else(|| {
        let method = req.method();
        let carries_body = method == http::Method::POST
            || method == http::Method::PUT
            || method == http::Method::PATCH;
        if carries_body {
            ASSUMED_BODY_BYTES
        } else {
            0
        }
    });

    header_bytes
        .saturating_add(uri_bytes)
        .saturating_add(body_bytes)
}
/// Estimates the number of bytes a response occupies, for traffic shaping.
///
/// The estimate sums the length of every header name and value, a fixed 15-byte allowance
/// for the status line (the length of `HTTP/1.1 200 OK`), and the body length. The body
/// length is taken from the `Content-Length` header when it parses as a `u64`; otherwise
/// `204` and `304` responses are assumed to have no body and every other status a 256-byte
/// body.
///
/// All additions saturate at `u64::MAX`, so an extreme `Content-Length` cannot overflow.
fn calculate_response_size(res: &Response) -> u64 {
    // Fixed allowance for the status line.
    const STATUS_LINE_BYTES: u64 = 15;
    // Body size assumed when no length is declared and the status normally has a body.
    const ASSUMED_BODY_BYTES: u64 = 256;

    let header_bytes = res.headers().iter().fold(0u64, |total, (name, value)| {
        total
            .saturating_add(name.as_str().len() as u64)
            .saturating_add(value.as_bytes().len() as u64)
    });

    let declared_body = res
        .headers()
        .get(http::header::CONTENT_LENGTH)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse::<u64>().ok());
    let body_bytes = declared_body.unwrap_or_else(|| match res.status().as_u16() {
        // No Content / Not Modified responses carry no body.
        204 | 304 => 0,
        _ => ASSUMED_BODY_BYTES,
    });

    header_bytes
        .saturating_add(STATUS_LINE_BYTES)
        .saturating_add(body_bytes)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, Response, StatusCode};
    use serde_json::json;

    /// Returns a `Shared` with default profiles and overrides, no failure injector or
    /// traffic shaper, and the given feature toggles.
    fn shared_with(overrides_enabled: bool, traffic_shaping_enabled: bool) -> Shared {
        Shared {
            profiles: LatencyProfiles::default(),
            overrides: Overrides::default(),
            failure_injector: None,
            traffic_shaper: None,
            overrides_enabled,
            traffic_shaping_enabled,
        }
    }

    /// Builds an `OperationMeta` from borrowed parts.
    fn operation(id: &str, tags: &[&str], path: &str) -> OperationMeta {
        OperationMeta {
            id: id.to_owned(),
            tags: tags.iter().map(|tag| (*tag).to_owned()).collect(),
            path: path.to_owned(),
        }
    }

    /// Runs `apply_overrides` on `{"name": "John"}` and asserts the body is unchanged.
    fn assert_overrides_leave_body_unchanged(shared: &Shared, op: Option<&OperationMeta>) {
        let mut body = json!({"name": "John"});
        let before = body.clone();
        apply_overrides(shared, op, &mut body);
        assert_eq!(body, before);
    }

    /// Builds an empty-bodied `200 OK` response carrying the given headers.
    fn ok_response(headers: &[(&str, &str)]) -> Response<axum::body::Body> {
        headers
            .iter()
            .fold(Response::builder().status(StatusCode::OK), |builder, (name, value)| {
                builder.header(*name, *value)
            })
            .body(axum::body::Body::empty())
            .unwrap()
    }

    #[test]
    fn test_operation_meta_creation() {
        let meta = operation("getUserById", &["users", "public"], "/users/{id}");
        assert_eq!(meta.id, "getUserById");
        assert_eq!(meta.tags.len(), 2);
        assert_eq!(meta.path, "/users/{id}");
    }

    #[test]
    fn test_shared_creation() {
        let shared = shared_with(false, false);
        assert!(!shared.overrides_enabled);
        assert!(!shared.traffic_shaping_enabled);
        assert!(shared.failure_injector.is_none());
        assert!(shared.traffic_shaper.is_none());
    }

    #[test]
    fn test_shared_with_failure_injector() {
        let shared = Shared {
            failure_injector: Some(FailureInjector::new(None, true)),
            ..shared_with(false, false)
        };
        assert!(shared.failure_injector.is_some());
    }

    #[test]
    fn test_apply_overrides_disabled() {
        let op = operation("getUser", &[], "/users");
        assert_overrides_leave_body_unchanged(&shared_with(false, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_enabled_no_rules() {
        let op = operation("getUser", &[], "/users");
        assert_overrides_leave_body_unchanged(&shared_with(true, false), Some(&op));
    }

    #[test]
    fn test_apply_overrides_with_none_operation() {
        assert_overrides_leave_body_unchanged(&shared_with(true, false), None);
    }

    #[test]
    fn test_calculate_request_size_basic() {
        let req = Request::builder()
            .uri("/test")
            .header("content-type", "application/json")
            .body(())
            .unwrap();
        let size = calculate_request_size(&req);
        let lower_bound = ("/test".len() + "content-type".len()) as u64;
        assert!(size > 0);
        assert!(size >= lower_bound);
    }

    #[test]
    fn test_calculate_request_size_with_multiple_headers() {
        let req = Request::builder()
            .uri("/api/users")
            .header("content-type", "application/json")
            .header("authorization", "Bearer token123")
            .header("user-agent", "test-client")
            .body(())
            .unwrap();
        assert!(calculate_request_size(&req) > 50);
    }

    #[test]
    fn test_calculate_response_size_basic() {
        let res = ok_response(&[("content-type", "application/json")]);
        let size = calculate_response_size(&res);
        assert!(size > 0);
        assert!(size >= 50);
    }

    #[test]
    fn test_calculate_response_size_with_multiple_headers() {
        let res = ok_response(&[
            ("content-type", "application/json"),
            ("cache-control", "no-cache"),
            ("x-request-id", "123-456-789"),
        ]);
        assert!(calculate_response_size(&res) > 100);
    }

    #[test]
    fn test_shared_clone() {
        let shared = shared_with(true, true);
        let cloned = shared.clone();
        assert_eq!(cloned.overrides_enabled, shared.overrides_enabled);
        assert_eq!(cloned.traffic_shaping_enabled, shared.traffic_shaping_enabled);
    }

    #[test]
    fn test_operation_meta_clone() {
        let meta = operation("testOp", &["tag1"], "/test");
        let cloned = meta.clone();
        assert_eq!(cloned.id, meta.id);
        assert_eq!(cloned.tags, meta.tags);
        assert_eq!(cloned.path, meta.path);
    }
}