use crate::{
config::CorruptionType,
fault::FaultInjector,
latency::LatencyInjector,
latency_metrics::LatencyMetricsTracker,
rate_limit::RateLimiter,
resilience::{Bulkhead, CircuitBreaker},
traffic_shaping::TrafficShaper,
ChaosConfig,
};
use axum::{
body::Body,
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use http_body_util::BodyExt;
use rand::Rng;
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// Shared state used by the chaos middleware functions in this file.
///
/// Every component is stored behind an `Arc`, so cloning this struct yields a
/// handle to the same underlying instances.
#[derive(Clone)]
pub struct ChaosMiddleware {
/// Chaos configuration; read on each request and when components are rebuilt.
config: Arc<RwLock<ChaosConfig>>,
/// Receives every non-zero delay reported by the latency injector.
latency_tracker: Arc<LatencyMetricsTracker>,
/// Produces the per-request artificial delay.
latency_injector: Arc<RwLock<LatencyInjector>>,
/// Fault injector handle exposed through `fault_injector()`; the request path
/// in this file reads its fault settings from `config` directly.
fault_injector: Arc<RwLock<FaultInjector>>,
/// Rate limiter consulted with the client IP and request path.
rate_limiter: Arc<RwLock<RateLimiter>>,
/// Used for the connection limit, simulated packet loss and bandwidth throttling.
traffic_shaper: Arc<RwLock<TrafficShaper>>,
/// Decides whether a request is admitted and is told about response outcomes.
circuit_breaker: Arc<RwLock<CircuitBreaker>>,
/// Admission control whose acquired guard is held for the duration of a request.
bulkhead: Arc<RwLock<Bulkhead>>,
}
impl ChaosMiddleware {
    /// Creates a middleware whose components all start from their default configuration.
    ///
    /// The components are not built from `config` here; call
    /// [`init_from_config`](Self::init_from_config) once after construction to do that.
    pub fn new(
        config: Arc<RwLock<ChaosConfig>>,
        latency_tracker: Arc<LatencyMetricsTracker>,
    ) -> Self {
        Self {
            config,
            latency_tracker,
            latency_injector: Arc::new(RwLock::new(LatencyInjector::new(Default::default()))),
            fault_injector: Arc::new(RwLock::new(FaultInjector::new(Default::default()))),
            rate_limiter: Arc::new(RwLock::new(RateLimiter::new(Default::default()))),
            traffic_shaper: Arc::new(RwLock::new(TrafficShaper::new(Default::default()))),
            circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::new(Default::default()))),
            bulkhead: Arc::new(RwLock::new(Bulkhead::new(Default::default()))),
        }
    }

    /// Builds every component from the current configuration.
    ///
    /// Equivalent to [`update_from_config`](Self::update_from_config).
    pub async fn init_from_config(&self) {
        self.update_from_config().await;
    }

    /// Rebuilds every component from the current configuration.
    ///
    /// Each component is replaced by a fresh instance created from its section of
    /// the configuration (or that section's default when it is absent), so whatever
    /// state the previous instance held is discarded.
    pub async fn update_from_config(&self) {
        let config = self.config.read().await;
        {
            let mut injector = self.latency_injector.write().await;
            *injector = LatencyInjector::new(config.latency.clone().unwrap_or_default());
        }
        {
            // Previously skipped, which left the fault injector handed out by
            // `fault_injector()` on its default configuration forever.
            let mut injector = self.fault_injector.write().await;
            *injector = FaultInjector::new(config.fault_injection.clone().unwrap_or_default());
        }
        {
            let mut limiter = self.rate_limiter.write().await;
            *limiter = RateLimiter::new(config.rate_limit.clone().unwrap_or_default());
        }
        {
            let mut shaper = self.traffic_shaper.write().await;
            *shaper = TrafficShaper::new(config.traffic_shaping.clone().unwrap_or_default());
        }
        {
            let mut breaker = self.circuit_breaker.write().await;
            *breaker = CircuitBreaker::new(config.circuit_breaker.clone().unwrap_or_default());
        }
        {
            let mut bh = self.bulkhead.write().await;
            *bh = Bulkhead::new(config.bulkhead.clone().unwrap_or_default());
        }
    }

    /// Returns a shared handle to the latency injector.
    pub fn latency_injector(&self) -> Arc<RwLock<LatencyInjector>> {
        Arc::clone(&self.latency_injector)
    }

    /// Returns a shared handle to the fault injector.
    pub fn fault_injector(&self) -> Arc<RwLock<FaultInjector>> {
        Arc::clone(&self.fault_injector)
    }

    /// Returns a shared handle to the rate limiter.
    pub fn rate_limiter(&self) -> Arc<RwLock<RateLimiter>> {
        Arc::clone(&self.rate_limiter)
    }

    /// Returns a shared handle to the traffic shaper.
    pub fn traffic_shaper(&self) -> Arc<RwLock<TrafficShaper>> {
        Arc::clone(&self.traffic_shaper)
    }

    /// Returns a shared handle to the circuit breaker.
    pub fn circuit_breaker(&self) -> Arc<RwLock<CircuitBreaker>> {
        Arc::clone(&self.circuit_breaker)
    }

    /// Returns a shared handle to the bulkhead.
    pub fn bulkhead(&self) -> Arc<RwLock<Bulkhead>> {
        Arc::clone(&self.bulkhead)
    }

    /// Returns a shared handle to the chaos configuration.
    pub fn config(&self) -> Arc<RwLock<ChaosConfig>> {
        Arc::clone(&self.config)
    }

    /// Borrows the latency metrics tracker supplied at construction.
    pub fn latency_tracker(&self) -> &Arc<LatencyMetricsTracker> {
        &self.latency_tracker
    }
}
/// Middleware entry point for callers that pass the [`ChaosMiddleware`] handle directly.
///
/// A clone of the handle is stored in the request extensions before the request
/// is handed to the shared chaos pipeline.
pub async fn chaos_middleware_with_state(
    chaos: Arc<ChaosMiddleware>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let mut req = req;
    req.extensions_mut().insert(Arc::clone(&chaos));
    chaos_middleware_core(chaos, req, next).await
}
/// Middleware entry point that receives the shared [`ChaosMiddleware`] through the
/// `State` extractor and forwards the request to `chaos_middleware_core`.
pub async fn chaos_middleware(
State(chaos): State<Arc<ChaosMiddleware>>,
req: Request<Body>,
next: Next,
) -> Response {
chaos_middleware_core(chaos, req, next).await
}
/// Runs one request through the chaos pipeline.
///
/// When chaos is disabled the request is forwarded untouched. Otherwise the
/// following steps run in order, and any of them may answer the request itself:
///
/// 1. circuit breaker admission (503 when the request is not allowed)
/// 2. bulkhead admission (503 on rejection); the acquired guard is kept until return
/// 3. rate limiting keyed by client IP and path (429)
/// 4. connection limit (503) and simulated packet loss (408)
/// 5. latency injection; non-zero delays are recorded in the latency tracker
/// 6. probabilistic HTTP error injection
/// 7. request-body bandwidth throttling, then the inner service
/// 8. circuit breaker bookkeeping based on the response status
/// 9. optional response truncation and payload corruption, then response throttling
async fn chaos_middleware_core(
    chaos: Arc<ChaosMiddleware>,
    req: Request<Body>,
    next: Next,
) -> Response {
    // The read guard is a temporary here, so it is released before the inner service runs.
    let enabled = chaos.config.read().await.enabled;
    if !enabled {
        return next.run(req).await;
    }

    let path = req.uri().path().to_string();
    let ip = resolve_client_ip(&req);
    debug!("Chaos middleware processing: {} {}", req.method(), path);

    {
        let circuit_breaker = chaos.circuit_breaker.read().await;
        if !circuit_breaker.allow_request().await {
            warn!("Circuit breaker open, rejecting request: {}", path);
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Service temporarily unavailable (circuit breaker open)",
            )
                .into_response();
        }
    }

    // Bound to a name (not `_`) so the guard stays alive for the whole request.
    let _bulkhead_guard = {
        let bulkhead = chaos.bulkhead.read().await;
        match bulkhead.try_acquire().await {
            Ok(guard) => guard,
            Err(e) => {
                warn!("Bulkhead rejected request: {} - {:?}", path, e);
                return (StatusCode::SERVICE_UNAVAILABLE, format!("Service overloaded: {}", e))
                    .into_response();
            }
        }
    };

    {
        let rate_limiter = chaos.rate_limiter.read().await;
        if rate_limiter.check(Some(&ip), Some(&path)).is_err() {
            warn!("Rate limit exceeded: {} - {}", ip, path);
            return (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response();
        }
    }

    let traffic_shaper = chaos.traffic_shaper.read().await;
    if !traffic_shaper.check_connection_limit() {
        drop(traffic_shaper);
        warn!("Connection limit exceeded");
        return (StatusCode::SERVICE_UNAVAILABLE, "Connection limit exceeded").into_response();
    }
    // Kept until the function returns.
    let _connection_guard = crate::traffic_shaping::ConnectionGuard::new(traffic_shaper.clone());
    if traffic_shaper.should_drop_packet() {
        drop(traffic_shaper);
        warn!("Simulating packet loss for: {}", path);
        return (StatusCode::REQUEST_TIMEOUT, "Connection dropped").into_response();
    }
    drop(traffic_shaper);

    let delay_ms = {
        let latency_injector = chaos.latency_injector.read().await;
        latency_injector.inject().await
    };
    if delay_ms > 0 {
        chaos.latency_tracker.record_latency(delay_ms);
    }

    // Decide on an injected HTTP error. The RNG lives only inside the closure, so it
    // is never held across an await point.
    let config = chaos.config.read().await;
    let http_error_status = config.fault_injection.as_ref().filter(|f| f.enabled).and_then(|f| {
        let mut rng = rand::rng();
        // `random::<f64>()` samples from [0, 1): the strict `<` guarantees that a
        // probability of 0.0 never fires while 1.0 always does.
        if rng.random::<f64>() < f.http_error_probability && !f.http_errors.is_empty() {
            Some(f.http_errors[rng.random_range(0..f.http_errors.len())])
        } else {
            None
        }
    });
    drop(config);
    if let Some(status_code) = http_error_status {
        warn!("Injecting HTTP error: {}", status_code);
        return (
            StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
            format!("Injected error: {}", status_code),
        )
            .into_response();
    }

    // Buffer the request body so its size is known for bandwidth throttling.
    let (parts, body) = req.into_parts();
    let body_bytes = match body.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) => {
            warn!("Failed to read request body: {}", e);
            return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
        }
    };
    let request_size = body_bytes.len();
    {
        let traffic_shaper = chaos.traffic_shaper.read().await;
        traffic_shaper.throttle_bandwidth(request_size).await;
    }

    let req = Request::from_parts(parts, Body::from(body_bytes));
    let response = next.run(req).await;
    let status = response.status();
    {
        let circuit_breaker = chaos.circuit_breaker.read().await;
        // 503 is itself a server error, so it needs no separate check.
        if status.is_server_error() {
            circuit_breaker.record_failure().await;
        } else if status.is_success() {
            circuit_breaker.record_success().await;
        }
    }

    // Buffer the response body so it can be truncated, corrupted and throttled.
    let (parts, body) = response.into_parts();
    let response_body_bytes = match body.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) => {
            warn!("Failed to read response body: {}", e);
            return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response body")
                .into_response();
        }
    };
    let response_size = response_body_bytes.len();

    // Both response faults apply only while fault injection is enabled.
    let config = chaos.config.read().await;
    let (should_truncate, corruption_type) =
        match config.fault_injection.as_ref().filter(|f| f.enabled) {
            Some(f) => (f.timeout_errors, f.corruption_type),
            None => (false, CorruptionType::None),
        };
    drop(config);

    let mut final_body_bytes = if should_truncate {
        warn!("Injecting partial response");
        // A partial response keeps the first half of the body.
        response_body_bytes.slice(0..response_size / 2).to_vec()
    } else {
        response_body_bytes.to_vec()
    };
    if corruption_type != CorruptionType::None {
        warn!("Injecting payload corruption: {:?}", corruption_type);
        final_body_bytes = corrupt_payload(&final_body_bytes, corruption_type);
    }
    let final_body = Body::from(final_body_bytes);

    {
        // Throttling is based on the size of the body as received from the inner service.
        let traffic_shaper = chaos.traffic_shaper.read().await;
        traffic_shaper.throttle_bandwidth(response_size).await;
    }
    Response::from_parts(parts, final_body)
}

/// Determines the client IP used as the rate-limiting key.
///
/// Prefers a `SocketAddr` stored in the request extensions, then the first
/// comma-separated entry of `x-forwarded-for` (or `x-real-ip` when that header is
/// absent), and finally falls back to `127.0.0.1`.
fn resolve_client_ip(req: &Request<Body>) -> String {
    req.extensions()
        .get::<SocketAddr>()
        .map(|addr| addr.ip().to_string())
        .or_else(|| {
            req.headers()
                .get("x-forwarded-for")
                .or_else(|| req.headers().get("x-real-ip"))
                .and_then(|h| h.to_str().ok())
                .map(|s| s.split(',').next().unwrap_or(s).trim().to_string())
        })
        .unwrap_or_else(|| "127.0.0.1".to_string())
}
/// Returns a corrupted copy of `data` according to `corruption_type`.
///
/// * `None` – the copy is returned unchanged.
/// * `RandomBytes` – randomly chosen positions are overwritten with random bytes.
/// * `Truncate` – the copy is cut at a random length between 50% and 90% of the input.
/// * `BitFlip` – one random bit is flipped at each randomly chosen position.
///
/// The byte-level modes pick 10% of the payload length (at least one) positions;
/// positions are drawn independently, so the same index may be hit more than once.
/// Empty input yields an empty result.
fn corrupt_payload(data: &[u8], corruption_type: CorruptionType) -> Vec<u8> {
    let len = data.len();
    if len == 0 {
        return Vec::new();
    }

    let mut rng = rand::rng();
    let mut bytes = data.to_vec();
    // Number of positions touched by the byte-level corruption modes.
    let touched = (len as f64 * 0.1).max(1.0) as usize;

    match corruption_type {
        CorruptionType::None => {}
        CorruptionType::RandomBytes => {
            for _ in 0..touched {
                let index = rng.random_range(0..len);
                bytes[index] = rng.random::<u8>();
            }
        }
        CorruptionType::Truncate => {
            let lower = len / 2;
            let upper = (len as f64 * 0.9) as usize;
            // For very short inputs the bounds coincide and no draw is needed.
            let keep = if upper > lower {
                rng.random_range(lower..=upper)
            } else {
                lower
            };
            bytes.truncate(keep);
        }
        CorruptionType::BitFlip => {
            for _ in 0..touched {
                let index = rng.random_range(0..len);
                let bit = rng.random_range(0..8);
                bytes[index] ^= 1 << bit;
            }
        }
    }
    bytes
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{LatencyConfig, RateLimitConfig};
    use crate::latency_metrics::LatencyMetricsTracker;

    /// Builds a middleware from `chaos_config`, initialises it, and returns it
    /// together with the latency tracker it was constructed with.
    async fn initialised_middleware(
        chaos_config: ChaosConfig,
    ) -> (Arc<ChaosMiddleware>, Arc<LatencyMetricsTracker>) {
        let tracker = Arc::new(LatencyMetricsTracker::new());
        let shared_config = Arc::new(RwLock::new(chaos_config));
        let middleware = Arc::new(ChaosMiddleware::new(shared_config, Arc::clone(&tracker)));
        middleware.init_from_config().await;
        (middleware, tracker)
    }

    /// Initialising from a config with latency enabled turns the injector on.
    #[tokio::test]
    async fn test_middleware_creation() {
        let latency = LatencyConfig {
            enabled: true,
            fixed_delay_ms: Some(10),
            ..Default::default()
        };
        let (middleware, _tracker) = initialised_middleware(ChaosConfig {
            enabled: true,
            latency: Some(latency),
            ..Default::default()
        })
        .await;

        let injector = middleware.latency_injector.read().await;
        assert!(injector.is_enabled());
    }

    /// With a burst size of two, the third immediate check is rejected.
    #[tokio::test]
    async fn test_rate_limiting() {
        let rate_limit = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 2,
            ..Default::default()
        };
        let (middleware, _tracker) = initialised_middleware(ChaosConfig {
            enabled: true,
            rate_limit: Some(rate_limit),
            ..Default::default()
        })
        .await;

        let limiter = middleware.rate_limiter.read().await;
        for _ in 0..2 {
            assert!(limiter.check(Some("127.0.0.1"), Some("/test")).is_ok());
        }
        assert!(limiter.check(Some("127.0.0.1"), Some("/test")).is_err());
    }

    /// The middleware exposes the tracker it was given, and an injected delay
    /// recorded on that tracker shows up as a sample.
    #[tokio::test]
    async fn test_latency_recording() {
        let latency = LatencyConfig {
            enabled: true,
            fixed_delay_ms: Some(50),
            probability: 1.0,
            ..Default::default()
        };
        let (middleware, latency_tracker) = initialised_middleware(ChaosConfig {
            enabled: true,
            latency: Some(latency),
            ..Default::default()
        })
        .await;

        assert_eq!(Arc::as_ptr(middleware.latency_tracker()), Arc::as_ptr(&latency_tracker));

        let delay_ms = middleware.latency_injector.read().await.inject().await;
        if delay_ms > 0 {
            latency_tracker.record_latency(delay_ms);
        }

        let samples = latency_tracker.get_samples();
        assert!(!samples.is_empty(), "Should have recorded at least one latency sample");
        assert_eq!(samples[0].latency_ms, 50, "Recorded latency should match injected delay");
    }
}