1use crate::{
4 config::CorruptionType,
5 fault::FaultInjector,
6 latency::LatencyInjector,
7 latency_metrics::LatencyMetricsTracker,
8 rate_limit::RateLimiter,
9 resilience::{Bulkhead, CircuitBreaker},
10 traffic_shaping::TrafficShaper,
11 ChaosConfig,
12};
13use axum::{
14 body::Body,
15 extract::{Request, State},
16 http::StatusCode,
17 middleware::Next,
18 response::{IntoResponse, Response},
19};
20use http_body_util::BodyExt;
21use rand::Rng;
22use std::{net::SocketAddr, sync::Arc};
23use tokio::sync::RwLock;
24use tracing::{debug, warn};
25
26#[derive(Clone)]
31pub struct ChaosMiddleware {
32 config: Arc<RwLock<ChaosConfig>>,
34 latency_tracker: Arc<LatencyMetricsTracker>,
36 latency_injector: Arc<RwLock<LatencyInjector>>,
39 fault_injector: Arc<RwLock<FaultInjector>>,
40 rate_limiter: Arc<RwLock<RateLimiter>>,
41 traffic_shaper: Arc<RwLock<TrafficShaper>>,
42 circuit_breaker: Arc<RwLock<CircuitBreaker>>,
43 bulkhead: Arc<RwLock<Bulkhead>>,
44}
45
46impl ChaosMiddleware {
47 pub fn new(
56 config: Arc<RwLock<ChaosConfig>>,
57 latency_tracker: Arc<LatencyMetricsTracker>,
58 ) -> Self {
59 let latency_injector = Arc::new(RwLock::new(LatencyInjector::new(Default::default())));
61
62 let fault_injector = Arc::new(RwLock::new(FaultInjector::new(Default::default())));
66
67 let rate_limiter = Arc::new(RwLock::new(RateLimiter::new(Default::default())));
68
69 let traffic_shaper = Arc::new(RwLock::new(TrafficShaper::new(Default::default())));
70
71 let circuit_breaker = Arc::new(RwLock::new(CircuitBreaker::new(Default::default())));
72
73 let bulkhead = Arc::new(RwLock::new(Bulkhead::new(Default::default())));
74
75 Self {
76 config,
77 latency_tracker,
78 latency_injector,
79 fault_injector,
80 rate_limiter,
81 traffic_shaper,
82 circuit_breaker,
83 bulkhead,
84 }
85 }
86
87 pub async fn init_from_config(&self) {
92 self.update_from_config().await;
93 }
94
95 pub async fn update_from_config(&self) {
101 let config = self.config.read().await;
102
103 {
105 let mut injector = self.latency_injector.write().await;
106 *injector = LatencyInjector::new(config.latency.clone().unwrap_or_default());
107 }
108
109 {
114 let mut limiter = self.rate_limiter.write().await;
115 *limiter = RateLimiter::new(config.rate_limit.clone().unwrap_or_default());
116 }
117
118 {
120 let mut shaper = self.traffic_shaper.write().await;
121 *shaper = TrafficShaper::new(config.traffic_shaping.clone().unwrap_or_default());
122 }
123
124 {
126 let mut breaker = self.circuit_breaker.write().await;
127 *breaker = CircuitBreaker::new(config.circuit_breaker.clone().unwrap_or_default());
128 }
129
130 {
132 let mut bh = self.bulkhead.write().await;
133 *bh = Bulkhead::new(config.bulkhead.clone().unwrap_or_default());
134 }
135 }
136
137 pub fn latency_injector(&self) -> Arc<RwLock<LatencyInjector>> {
139 self.latency_injector.clone()
140 }
141
142 pub fn fault_injector(&self) -> Arc<RwLock<FaultInjector>> {
145 self.fault_injector.clone()
146 }
147
148 pub fn rate_limiter(&self) -> Arc<RwLock<RateLimiter>> {
150 self.rate_limiter.clone()
151 }
152
153 pub fn traffic_shaper(&self) -> Arc<RwLock<TrafficShaper>> {
155 self.traffic_shaper.clone()
156 }
157
158 pub fn circuit_breaker(&self) -> Arc<RwLock<CircuitBreaker>> {
160 self.circuit_breaker.clone()
161 }
162
163 pub fn bulkhead(&self) -> Arc<RwLock<Bulkhead>> {
165 self.bulkhead.clone()
166 }
167
168 pub fn config(&self) -> Arc<RwLock<ChaosConfig>> {
170 self.config.clone()
171 }
172
173 pub fn latency_tracker(&self) -> &Arc<LatencyMetricsTracker> {
175 &self.latency_tracker
176 }
177}
178
179pub async fn chaos_middleware_with_state(
181 chaos: Arc<ChaosMiddleware>,
182 req: Request<Body>,
183 next: Next,
184) -> Response {
185 let (mut parts, body) = req.into_parts();
188 parts.extensions.insert(chaos.clone());
189 let req = Request::from_parts(parts, body);
190
191 chaos_middleware_core(chaos, req, next).await
194}
195
196pub async fn chaos_middleware(
198 State(chaos): State<Arc<ChaosMiddleware>>,
199 req: Request<Body>,
200 next: Next,
201) -> Response {
202 chaos_middleware_core(chaos, req, next).await
203}
204
205async fn chaos_middleware_core(
207 chaos: Arc<ChaosMiddleware>,
208 req: Request<Body>,
209 next: Next,
210) -> Response {
211 let config = chaos.config.read().await;
213
214 if !config.enabled {
216 drop(config);
217 return next.run(req).await;
218 }
219
220 let path = req.uri().path().to_string();
221
222 let ip = req
224 .extensions()
225 .get::<SocketAddr>()
226 .map(|addr| addr.ip().to_string())
227 .or_else(|| {
228 req.headers()
229 .get("x-forwarded-for")
230 .or_else(|| req.headers().get("x-real-ip"))
231 .and_then(|h| h.to_str().ok())
232 .map(|s| s.split(',').next().unwrap_or(s).trim().to_string())
233 })
234 .unwrap_or_else(|| "127.0.0.1".to_string());
235
236 debug!("Chaos middleware processing: {} {}", req.method(), path);
237
238 drop(config);
240
241 {
243 let circuit_breaker = chaos.circuit_breaker.read().await;
244 if !circuit_breaker.allow_request().await {
245 warn!("Circuit breaker open, rejecting request: {}", path);
246 return (
247 StatusCode::SERVICE_UNAVAILABLE,
248 "Service temporarily unavailable (circuit breaker open)",
249 )
250 .into_response();
251 }
252 }
253
254 let _bulkhead_guard = {
256 let bulkhead = chaos.bulkhead.read().await;
257 match bulkhead.try_acquire().await {
258 Ok(guard) => guard,
259 Err(e) => {
260 warn!("Bulkhead rejected request: {} - {:?}", path, e);
261 return (StatusCode::SERVICE_UNAVAILABLE, format!("Service overloaded: {}", e))
262 .into_response();
263 }
264 }
265 };
266
267 let rate_limiter = chaos.rate_limiter.read().await;
269 if let Err(_e) = rate_limiter.check(Some(&ip), Some(&path)) {
270 drop(rate_limiter);
271 warn!("Rate limit exceeded: {} - {}", ip, path);
272 return (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response();
273 }
274 drop(rate_limiter);
275
276 let traffic_shaper = chaos.traffic_shaper.read().await;
278 if !traffic_shaper.check_connection_limit() {
279 drop(traffic_shaper);
280 warn!("Connection limit exceeded");
281 return (StatusCode::SERVICE_UNAVAILABLE, "Connection limit exceeded").into_response();
282 }
283
284 let _connection_guard = crate::traffic_shaping::ConnectionGuard::new(traffic_shaper.clone());
286
287 if traffic_shaper.should_drop_packet() {
289 drop(traffic_shaper);
290 warn!("Simulating packet loss for: {}", path);
291 return (StatusCode::REQUEST_TIMEOUT, "Connection dropped").into_response();
292 }
293 drop(traffic_shaper);
294
295 let latency_injector = chaos.latency_injector.read().await;
297 let delay_ms = latency_injector.inject().await;
298 drop(latency_injector);
299 if delay_ms > 0 {
300 chaos.latency_tracker.record_latency(delay_ms);
301 }
302
303 let config = chaos.config.read().await;
305 let fault_config = config.fault_injection.as_ref();
306 let should_inject_fault = fault_config.map(|f| f.enabled).unwrap_or(false);
307 let http_error_status = if should_inject_fault {
308 fault_config.and_then(|f| {
310 let mut rng = rand::rng();
311 if rng.random::<f64>() <= f.http_error_probability && !f.http_errors.is_empty() {
312 Some(f.http_errors[rng.random_range(0..f.http_errors.len())])
313 } else {
314 None
315 }
316 })
317 } else {
318 None
319 };
320 drop(config);
321
322 if let Some(status_code) = http_error_status {
323 warn!("Injecting HTTP error: {}", status_code);
324 return (
325 StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
326 format!("Injected error: {}", status_code),
327 )
328 .into_response();
329 }
330
331 let (parts, body) = req.into_parts();
333 let body_bytes = match body.collect().await {
334 Ok(collected) => collected.to_bytes(),
335 Err(e) => {
336 warn!("Failed to read request body: {}", e);
337 return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
338 }
339 };
340
341 let request_size = body_bytes.len();
342
343 {
345 let traffic_shaper = chaos.traffic_shaper.read().await;
346 traffic_shaper.throttle_bandwidth(request_size).await;
347 }
348
349 let req = Request::from_parts(parts, Body::from(body_bytes));
351
352 let response = next.run(req).await;
354
355 let status = response.status();
357 {
358 let circuit_breaker = chaos.circuit_breaker.read().await;
359 if status.is_server_error() || status == StatusCode::SERVICE_UNAVAILABLE {
360 circuit_breaker.record_failure().await;
361 } else if status.is_success() {
362 circuit_breaker.record_success().await;
363 }
364 }
365
366 let (parts, body) = response.into_parts();
368 let response_body_bytes = match body.collect().await {
369 Ok(collected) => collected.to_bytes(),
370 Err(e) => {
371 warn!("Failed to read response body: {}", e);
372 return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response body")
373 .into_response();
374 }
375 };
376
377 let response_size = response_body_bytes.len();
378
379 let config = chaos.config.read().await;
382 let should_truncate = config
383 .fault_injection
384 .as_ref()
385 .map(|f| f.enabled && f.timeout_errors)
386 .unwrap_or(false);
387 let should_corrupt = config.fault_injection.as_ref().map(|f| f.enabled).unwrap_or(false);
388 let corruption_type = config
389 .fault_injection
390 .as_ref()
391 .map(|f| f.corruption_type)
392 .unwrap_or(CorruptionType::None);
393 drop(config);
394
395 let mut final_body_bytes = if should_truncate {
396 warn!("Injecting partial response");
397 let truncate_at = response_size / 2;
398 response_body_bytes.slice(0..truncate_at).to_vec()
399 } else {
400 response_body_bytes.to_vec()
401 };
402
403 if should_corrupt && corruption_type != CorruptionType::None {
405 warn!("Injecting payload corruption: {:?}", corruption_type);
406 final_body_bytes = corrupt_payload(&final_body_bytes, corruption_type);
407 }
408
409 let final_body = Body::from(final_body_bytes);
410
411 {
413 let traffic_shaper = chaos.traffic_shaper.read().await;
414 traffic_shaper.throttle_bandwidth(response_size).await;
415 }
416
417 Response::from_parts(parts, final_body)
418}
419
420fn corrupt_payload(data: &[u8], corruption_type: CorruptionType) -> Vec<u8> {
422 if data.is_empty() {
423 return data.to_vec();
424 }
425
426 let mut rng = rand::rng();
427 let mut corrupted = data.to_vec();
428
429 match corruption_type {
430 CorruptionType::None => corrupted,
431 CorruptionType::RandomBytes => {
432 let num_bytes_to_corrupt = (data.len() as f64 * 0.1).max(1.0) as usize;
434 for _ in 0..num_bytes_to_corrupt {
435 let index = rng.random_range(0..data.len());
436 corrupted[index] = rng.random::<u8>();
437 }
438 corrupted
439 }
440 CorruptionType::Truncate => {
441 let min_truncate = data.len() / 2;
443 let max_truncate = (data.len() as f64 * 0.9) as usize;
444 let truncate_at = if max_truncate > min_truncate {
445 rng.random_range(min_truncate..=max_truncate)
446 } else {
447 min_truncate
448 };
449 corrupted.truncate(truncate_at);
450 corrupted
451 }
452 CorruptionType::BitFlip => {
453 let num_bytes_to_flip = (data.len() as f64 * 0.1).max(1.0) as usize;
455 for _ in 0..num_bytes_to_flip {
456 let index = rng.random_range(0..data.len());
457 let bit_to_flip = rng.random_range(0..8);
458 corrupted[index] ^= 1 << bit_to_flip;
459 }
460 corrupted
461 }
462 }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{LatencyConfig, RateLimitConfig};
    use crate::latency_metrics::LatencyMetricsTracker;

    /// A middleware initialised from a config with latency enabled must end up
    /// with an enabled latency injector.
    #[tokio::test]
    async fn test_middleware_creation() {
        let config = ChaosConfig {
            enabled: true,
            latency: Some(LatencyConfig {
                enabled: true,
                fixed_delay_ms: Some(10),
                ..Default::default()
            }),
            ..Default::default()
        };

        let latency_tracker = Arc::new(LatencyMetricsTracker::new());
        let config_arc = Arc::new(RwLock::new(config));
        let middleware = ChaosMiddleware::new(config_arc, latency_tracker);
        // `new` builds every component from `Default::default()`; the config
        // above only takes effect once it has been applied explicitly.
        middleware.init_from_config().await;
        assert!(middleware.latency_injector.read().await.is_enabled());
    }

    /// With a burst size of 2, the first two checks pass and the third is rejected.
    #[tokio::test]
    async fn test_rate_limiting() {
        let config = Arc::new(RwLock::new(ChaosConfig {
            enabled: true,
            rate_limit: Some(RateLimitConfig {
                enabled: true,
                requests_per_second: 1,
                burst_size: 2,
                ..Default::default()
            }),
            ..Default::default()
        }));

        let latency_tracker = Arc::new(LatencyMetricsTracker::new());
        let middleware = Arc::new(ChaosMiddleware::new(config.clone(), latency_tracker));
        middleware.init_from_config().await;

        {
            let rate_limiter = middleware.rate_limiter.read().await;
            assert!(rate_limiter.check(Some("127.0.0.1"), Some("/test")).is_ok());
            assert!(rate_limiter.check(Some("127.0.0.1"), Some("/test")).is_ok());
        }

        {
            let rate_limiter = middleware.rate_limiter.read().await;
            assert!(rate_limiter.check(Some("127.0.0.1"), Some("/test")).is_err());
        }
    }

    /// The middleware shares the caller's tracker, and a delay reported by the
    /// injector shows up in it once recorded.
    #[tokio::test]
    async fn test_latency_recording() {
        let config = Arc::new(RwLock::new(ChaosConfig {
            enabled: true,
            latency: Some(LatencyConfig {
                enabled: true,
                fixed_delay_ms: Some(50),
                probability: 1.0,
                ..Default::default()
            }),
            ..Default::default()
        }));

        let latency_tracker = Arc::new(LatencyMetricsTracker::new());
        let middleware = Arc::new(ChaosMiddleware::new(config.clone(), latency_tracker.clone()));
        middleware.init_from_config().await;

        // Same allocation, not merely an equal tracker.
        let tracker_from_middleware = middleware.latency_tracker();
        assert_eq!(Arc::as_ptr(tracker_from_middleware), Arc::as_ptr(&latency_tracker));

        let delay_ms = {
            let injector = middleware.latency_injector.read().await;
            injector.inject().await
        };
        if delay_ms > 0 {
            latency_tracker.record_latency(delay_ms);
        }

        let samples = latency_tracker.get_samples();
        assert!(!samples.is_empty(), "Should have recorded at least one latency sample");
        assert_eq!(samples[0].latency_ms, 50, "Recorded latency should match injected delay");
    }
}