mockforge_chaos/middleware.rs

//! Chaos engineering middleware for HTTP

use crate::{
    config::CorruptionType,
    fault::FaultInjector,
    latency::LatencyInjector,
    latency_metrics::LatencyMetricsTracker,
    rate_limit::RateLimiter,
    resilience::{Bulkhead, CircuitBreaker},
    traffic_shaping::TrafficShaper,
    ChaosConfig,
};
use axum::{
    body::Body,
    extract::{Request, State},
    http::StatusCode,
    middleware::Next,
    response::{IntoResponse, Response},
};
use http_body_util::BodyExt;
use rand::Rng;
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::RwLock;
use tracing::{debug, warn};

/// Chaos middleware state
///
/// This middleware reads configuration from a shared `Arc<RwLock<ChaosConfig>>`
/// to support hot-reload of chaos settings at runtime.
#[derive(Clone)]
pub struct ChaosMiddleware {
    /// Shared chaos configuration (read on each request for hot-reload support)
    config: Arc<RwLock<ChaosConfig>>,
    /// Latency metrics tracker for recording injected latencies
    latency_tracker: Arc<LatencyMetricsTracker>,
    /// Cached injectors (recreated when config changes)
    /// These are cached for performance but can be updated via update_from_config()
    latency_injector: Arc<RwLock<LatencyInjector>>,
    fault_injector: Arc<RwLock<FaultInjector>>,
    rate_limiter: Arc<RwLock<RateLimiter>>,
    traffic_shaper: Arc<RwLock<TrafficShaper>>,
    circuit_breaker: Arc<RwLock<CircuitBreaker>>,
    bulkhead: Arc<RwLock<Bulkhead>>,
}

impl ChaosMiddleware {
    /// Create new chaos middleware from shared config
    ///
    /// # Arguments
    /// * `config` - Shared chaos configuration (Arc<RwLock<ChaosConfig>>)
    /// * `latency_tracker` - Latency metrics tracker for recording injected latencies
    ///
    /// The middleware will read from the shared config on each request,
    /// allowing hot-reload of chaos settings without restarting the server.
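    ///
    /// # Example
    ///
    /// A minimal construction sketch (runs inside an async context; the default
    /// `ChaosConfig` shown here is just illustrative):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use tokio::sync::RwLock;
    ///
    /// let config = Arc::new(RwLock::new(ChaosConfig::default()));
    /// let tracker = Arc::new(LatencyMetricsTracker::new());
    /// let middleware = Arc::new(ChaosMiddleware::new(config, tracker));
    /// // Sync the cached injectors with the initial config before serving traffic.
    /// middleware.init_from_config().await;
    /// ```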
    pub fn new(
        config: Arc<RwLock<ChaosConfig>>,
        latency_tracker: Arc<LatencyMetricsTracker>,
    ) -> Self {
        // Initialize injectors with defaults (will be updated via init_from_config)
        let latency_injector = Arc::new(RwLock::new(LatencyInjector::new(Default::default())));

        // FaultInjector doesn't support hot-reload, so fault injection reads from the
        // shared config directly. We still keep a RwLock-wrapped instance for API
        // consistency with the other injectors, even though it isn't consulted here.
        let fault_injector = Arc::new(RwLock::new(FaultInjector::new(Default::default())));

        let rate_limiter = Arc::new(RwLock::new(RateLimiter::new(Default::default())));

        let traffic_shaper = Arc::new(RwLock::new(TrafficShaper::new(Default::default())));

        let circuit_breaker = Arc::new(RwLock::new(CircuitBreaker::new(Default::default())));

        let bulkhead = Arc::new(RwLock::new(Bulkhead::new(Default::default())));

        Self {
            config,
            latency_tracker,
            latency_injector,
            fault_injector,
            rate_limiter,
            traffic_shaper,
            circuit_breaker,
            bulkhead,
        }
    }

    /// Initialize middleware from config (async version)
    ///
    /// This should be called after creation to sync injectors with the actual config.
    /// This is a convenience method that calls `update_from_config()`.
    pub async fn init_from_config(&self) {
        self.update_from_config().await;
    }

    /// Update injectors from the current config
    ///
    /// Call this whenever the shared config is updated so the cached injectors
    /// are rebuilt from the new settings. The request path re-reads the shared
    /// config on each request but does not rebuild the cached injectors itself.
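    ///
    /// # Example
    ///
    /// A sketch of a hot-reload flow (the field being toggled is illustrative;
    /// `middleware` is a `ChaosMiddleware` built as in [`ChaosMiddleware::new`]):
    ///
    /// ```ignore
    /// // Mutate the shared config...
    /// let config = middleware.config();
    /// {
    ///     let mut cfg = config.write().await;
    ///     cfg.enabled = true;
    /// }
    /// // ...then rebuild the cached injectors from it.
    /// middleware.update_from_config().await;
    /// ```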
    pub async fn update_from_config(&self) {
        let config = self.config.read().await;

        // Update latency injector
        {
            let mut injector = self.latency_injector.write().await;
            *injector = LatencyInjector::new(config.latency.clone().unwrap_or_default());
        }

        // Note: FaultInjector doesn't have an update method, so we'd need to recreate it
        // For now, we'll read from config directly in the middleware

        // Update rate limiter
        {
            let mut limiter = self.rate_limiter.write().await;
            *limiter = RateLimiter::new(config.rate_limit.clone().unwrap_or_default());
        }

        // Update traffic shaper
        {
            let mut shaper = self.traffic_shaper.write().await;
            *shaper = TrafficShaper::new(config.traffic_shaping.clone().unwrap_or_default());
        }

        // Update circuit breaker
        {
            let mut breaker = self.circuit_breaker.write().await;
            *breaker = CircuitBreaker::new(config.circuit_breaker.clone().unwrap_or_default());
        }

        // Update bulkhead
        {
            let mut bh = self.bulkhead.write().await;
            *bh = Bulkhead::new(config.bulkhead.clone().unwrap_or_default());
        }
    }

    /// Get a shared handle to the latency injector
    pub fn latency_injector(&self) -> Arc<RwLock<LatencyInjector>> {
        self.latency_injector.clone()
    }

    /// Get a shared handle to the fault injector
    /// Note: FaultInjector doesn't support hot-reload, so fault injection reads from config directly
    pub fn fault_injector(&self) -> Arc<RwLock<FaultInjector>> {
        self.fault_injector.clone()
    }

    /// Get a shared handle to the rate limiter
    pub fn rate_limiter(&self) -> Arc<RwLock<RateLimiter>> {
        self.rate_limiter.clone()
    }

    /// Get a shared handle to the traffic shaper
    pub fn traffic_shaper(&self) -> Arc<RwLock<TrafficShaper>> {
        self.traffic_shaper.clone()
    }

    /// Get a shared handle to the circuit breaker
    pub fn circuit_breaker(&self) -> Arc<RwLock<CircuitBreaker>> {
        self.circuit_breaker.clone()
    }

    /// Get a shared handle to the bulkhead
    pub fn bulkhead(&self) -> Arc<RwLock<Bulkhead>> {
        self.bulkhead.clone()
    }

    /// Get shared config (for direct access if needed)
    pub fn config(&self) -> Arc<RwLock<ChaosConfig>> {
        self.config.clone()
    }

    /// Get latency tracker
    pub fn latency_tracker(&self) -> &Arc<LatencyMetricsTracker> {
        &self.latency_tracker
    }
}

/// Chaos middleware handler (takes state directly, for use with from_fn)
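///
/// # Example
///
/// One possible wiring sketch using `axum::middleware::from_fn` with a closure
/// that captures the middleware state (the route and handler are illustrative):
///
/// ```ignore
/// use axum::{middleware::from_fn, routing::get, Router};
/// use std::sync::Arc;
/// use tokio::sync::RwLock;
///
/// let config = Arc::new(RwLock::new(ChaosConfig::default()));
/// let tracker = Arc::new(LatencyMetricsTracker::new());
/// let chaos = Arc::new(ChaosMiddleware::new(config, tracker));
///
/// let app: Router = Router::new()
///     .route("/health", get(|| async { "ok" }))
///     .layer(from_fn(move |req, next| {
///         let chaos = chaos.clone();
///         async move { chaos_middleware_with_state(chaos, req, next).await }
///     }));
/// ```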
pub async fn chaos_middleware_with_state(
    chaos: Arc<ChaosMiddleware>,
    req: Request<Body>,
    next: Next,
) -> Response {
    // Expose the middleware state via request extensions so downstream handlers
    // and extractors can access it, then invoke the core logic directly.
    let (mut parts, body) = req.into_parts();
    parts.extensions.insert(chaos.clone());
    let req = Request::from_parts(parts, body);

    chaos_middleware_core(chaos, req, next).await
}

/// Chaos middleware handler (uses State extractor, for use with from_fn_with_state)
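///
/// # Example
///
/// A minimal wiring sketch using `axum::middleware::from_fn_with_state`
/// (the route and handler are illustrative):
///
/// ```ignore
/// use axum::{middleware::from_fn_with_state, routing::get, Router};
/// use std::sync::Arc;
/// use tokio::sync::RwLock;
///
/// let config = Arc::new(RwLock::new(ChaosConfig::default()));
/// let tracker = Arc::new(LatencyMetricsTracker::new());
/// let chaos = Arc::new(ChaosMiddleware::new(config, tracker));
///
/// let app: Router = Router::new()
///     .route("/health", get(|| async { "ok" }))
///     .layer(from_fn_with_state(chaos, chaos_middleware));
/// ```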
pub async fn chaos_middleware(
    State(chaos): State<Arc<ChaosMiddleware>>,
    req: Request<Body>,
    next: Next,
) -> Response {
    chaos_middleware_core(chaos, req, next).await
}

/// Core chaos middleware logic
async fn chaos_middleware_core(
    chaos: Arc<ChaosMiddleware>,
    req: Request<Body>,
    next: Next,
) -> Response {
    // Read config at start of request (supports hot-reload)
    let config = chaos.config.read().await;

    // Early return if chaos is disabled
    if !config.enabled {
        drop(config);
        return next.run(req).await;
    }

    let path = req.uri().path().to_string();

    // Extract client IP from request extensions (set by ConnectInfo if available) or headers
    let ip = req
        .extensions()
        .get::<SocketAddr>()
        .map(|addr| addr.ip().to_string())
        .or_else(|| {
            req.headers()
                .get("x-forwarded-for")
                .or_else(|| req.headers().get("x-real-ip"))
                .and_then(|h| h.to_str().ok())
                .map(|s| s.split(',').next().unwrap_or(s).trim().to_string())
        })
        .unwrap_or_else(|| "127.0.0.1".to_string());

    debug!("Chaos middleware processing: {} {}", req.method(), path);

    // Release config lock early (we'll read specific configs as needed)
    drop(config);

    // Check circuit breaker
    {
        let circuit_breaker = chaos.circuit_breaker.read().await;
        if !circuit_breaker.allow_request().await {
            warn!("Circuit breaker open, rejecting request: {}", path);
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Service temporarily unavailable (circuit breaker open)",
            )
                .into_response();
        }
    }

    // Try to acquire bulkhead slot
    let _bulkhead_guard = {
        let bulkhead = chaos.bulkhead.read().await;
        match bulkhead.try_acquire().await {
            Ok(guard) => guard,
            Err(e) => {
                warn!("Bulkhead rejected request: {} - {:?}", path, e);
                return (StatusCode::SERVICE_UNAVAILABLE, format!("Service overloaded: {}", e))
                    .into_response();
            }
        }
    };

    // Check rate limits
    let rate_limiter = chaos.rate_limiter.read().await;
    if let Err(_e) = rate_limiter.check(Some(&ip), Some(&path)) {
        drop(rate_limiter);
        warn!("Rate limit exceeded: {} - {}", ip, path);
        return (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response();
    }
    drop(rate_limiter);

    // Check connection limits
    let traffic_shaper = chaos.traffic_shaper.read().await;
    if !traffic_shaper.check_connection_limit() {
        drop(traffic_shaper);
        warn!("Connection limit exceeded");
        return (StatusCode::SERVICE_UNAVAILABLE, "Connection limit exceeded").into_response();
    }

    // Always release connection on scope exit
    let _connection_guard = crate::traffic_shaping::ConnectionGuard::new(traffic_shaper.clone());

    // Check for packet loss (simulate dropped connection)
    if traffic_shaper.should_drop_packet() {
        drop(traffic_shaper);
        warn!("Simulating packet loss for: {}", path);
        return (StatusCode::REQUEST_TIMEOUT, "Connection dropped").into_response();
    }
    drop(traffic_shaper);

    // Inject latency and record it for metrics
    let latency_injector = chaos.latency_injector.read().await;
    let delay_ms = latency_injector.inject().await;
    drop(latency_injector);
    if delay_ms > 0 {
        chaos.latency_tracker.record_latency(delay_ms);
    }

    // Check for fault injection (read from config for hot-reload)
    let config = chaos.config.read().await;
    let fault_config = config.fault_injection.as_ref();
    let should_inject_fault = fault_config.map(|f| f.enabled).unwrap_or(false);
    let http_error_status = if should_inject_fault {
        // Check probability and get error status
        fault_config.and_then(|f| {
            let mut rng = rand::rng();
            if rng.random::<f64>() <= f.http_error_probability && !f.http_errors.is_empty() {
                Some(f.http_errors[rng.random_range(0..f.http_errors.len())])
            } else {
                None
            }
        })
    } else {
        None
    };
    drop(config);

    if let Some(status_code) = http_error_status {
        warn!("Injecting HTTP error: {}", status_code);
        return (
            StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
            format!("Injected error: {}", status_code),
        )
            .into_response();
    }

    // Buffer the full request body so its size can be used for bandwidth throttling
    let (parts, body) = req.into_parts();
    let body_bytes = match body.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) => {
            warn!("Failed to read request body: {}", e);
            return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
        }
    };

    let request_size = body_bytes.len();

    // Throttle request bandwidth
    {
        let traffic_shaper = chaos.traffic_shaper.read().await;
        traffic_shaper.throttle_bandwidth(request_size).await;
    }

    // Reconstruct request
    let req = Request::from_parts(parts, Body::from(body_bytes));

    // Pass to next handler
    let response = next.run(req).await;

    // Record circuit breaker result based on response status
    let status = response.status();
    {
        let circuit_breaker = chaos.circuit_breaker.read().await;
        if status.is_server_error() || status == StatusCode::SERVICE_UNAVAILABLE {
            circuit_breaker.record_failure().await;
        } else if status.is_success() {
            circuit_breaker.record_success().await;
        }
    }

    // Buffer the full response body so its size can be used for bandwidth throttling
    let (parts, body) = response.into_parts();
    let response_body_bytes = match body.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) => {
            warn!("Failed to read response body: {}", e);
            return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response body")
                .into_response();
        }
    };

    let response_size = response_body_bytes.len();

    // Decide whether to truncate the response (partial-response simulation),
    // reading from the shared config for hot-reload support
    let config = chaos.config.read().await;
    let should_truncate = config
        .fault_injection
        .as_ref()
        .map(|f| f.enabled && f.timeout_errors)
        .unwrap_or(false);
    let should_corrupt = config.fault_injection.as_ref().map(|f| f.enabled).unwrap_or(false);
    let corruption_type = config
        .fault_injection
        .as_ref()
        .map(|f| f.corruption_type)
        .unwrap_or(CorruptionType::None);
    drop(config);

    let mut final_body_bytes = if should_truncate {
        warn!("Injecting partial response");
        let truncate_at = response_size / 2;
        response_body_bytes.slice(0..truncate_at).to_vec()
    } else {
        response_body_bytes.to_vec()
    };

    // Apply payload corruption if enabled
    if should_corrupt && corruption_type != CorruptionType::None {
        warn!("Injecting payload corruption: {:?}", corruption_type);
        final_body_bytes = corrupt_payload(&final_body_bytes, corruption_type);
    }

    let final_body = Body::from(final_body_bytes);

    // Throttle response bandwidth
    {
        let traffic_shaper = chaos.traffic_shaper.read().await;
        traffic_shaper.throttle_bandwidth(response_size).await;
    }

    Response::from_parts(parts, final_body)
}

/// Corrupt a payload based on the corruption type
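///
/// # Example
///
/// A small illustrative sketch of the variants' behavior (the input bytes are
/// arbitrary; `BitFlip` preserves length, `Truncate` shortens the payload):
///
/// ```ignore
/// let original = b"hello world".to_vec();
///
/// let flipped = corrupt_payload(&original, CorruptionType::BitFlip);
/// assert_eq!(flipped.len(), original.len());
///
/// let truncated = corrupt_payload(&original, CorruptionType::Truncate);
/// assert!(truncated.len() < original.len());
/// ```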
fn corrupt_payload(data: &[u8], corruption_type: CorruptionType) -> Vec<u8> {
    if data.is_empty() {
        return data.to_vec();
    }

    let mut rng = rand::rng();
    let mut corrupted = data.to_vec();

    match corruption_type {
        CorruptionType::None => corrupted,
        CorruptionType::RandomBytes => {
            // Replace 10% of bytes with random values
            let num_bytes_to_corrupt = (data.len() as f64 * 0.1).max(1.0) as usize;
            for _ in 0..num_bytes_to_corrupt {
                let index = rng.random_range(0..data.len());
                corrupted[index] = rng.random::<u8>();
            }
            corrupted
        }
        CorruptionType::Truncate => {
            // Truncate at random position (between 50% and 90% of original length)
            let min_truncate = data.len() / 2;
            let max_truncate = (data.len() as f64 * 0.9) as usize;
            let truncate_at = if max_truncate > min_truncate {
                rng.random_range(min_truncate..=max_truncate)
            } else {
                min_truncate
            };
            corrupted.truncate(truncate_at);
            corrupted
        }
        CorruptionType::BitFlip => {
            // Flip random bits in 10% of bytes
            let num_bytes_to_flip = (data.len() as f64 * 0.1).max(1.0) as usize;
            for _ in 0..num_bytes_to_flip {
                let index = rng.random_range(0..data.len());
                let bit_to_flip = rng.random_range(0..8);
                corrupted[index] ^= 1 << bit_to_flip;
            }
            corrupted
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{LatencyConfig, RateLimitConfig};
    use crate::latency_metrics::LatencyMetricsTracker;

    #[tokio::test]
    async fn test_middleware_creation() {
        let config = ChaosConfig {
            enabled: true,
            latency: Some(LatencyConfig {
                enabled: true,
                fixed_delay_ms: Some(10),
                ..Default::default()
            }),
            ..Default::default()
        };

        let latency_tracker = Arc::new(LatencyMetricsTracker::new());
        let config_arc = Arc::new(RwLock::new(config));
        let middleware = ChaosMiddleware::new(config_arc, latency_tracker);
        assert!(middleware.latency_injector.read().await.is_enabled());
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let config = Arc::new(RwLock::new(ChaosConfig {
            enabled: true,
            rate_limit: Some(RateLimitConfig {
                enabled: true,
                requests_per_second: 1,
                burst_size: 2, // burst_size is the total capacity, not additional requests
                ..Default::default()
            }),
            ..Default::default()
        }));

        let latency_tracker = Arc::new(LatencyMetricsTracker::new());
        let middleware = Arc::new(ChaosMiddleware::new(config.clone(), latency_tracker));
        middleware.init_from_config().await;

        // First two requests should succeed (rate + burst)
        {
            let rate_limiter = middleware.rate_limiter.read().await;
            assert!(rate_limiter.check(Some("127.0.0.1"), Some("/test")).is_ok());
            assert!(rate_limiter.check(Some("127.0.0.1"), Some("/test")).is_ok());
        }

        // Third should fail
        {
            let rate_limiter = middleware.rate_limiter.read().await;
            assert!(rate_limiter.check(Some("127.0.0.1"), Some("/test")).is_err());
        }
    }

    #[tokio::test]
    async fn test_latency_recording() {
        let config = Arc::new(RwLock::new(ChaosConfig {
            enabled: true,
            latency: Some(LatencyConfig {
                enabled: true,
                fixed_delay_ms: Some(50),
                probability: 1.0,
                ..Default::default()
            }),
            ..Default::default()
        }));

        let latency_tracker = Arc::new(LatencyMetricsTracker::new());
        let middleware = Arc::new(ChaosMiddleware::new(config.clone(), latency_tracker.clone()));
        middleware.init_from_config().await;

        // Verify tracker is accessible via getter
        let tracker_from_middleware = middleware.latency_tracker();
        assert_eq!(Arc::as_ptr(tracker_from_middleware), Arc::as_ptr(&latency_tracker));

        // Manually inject latency and record it (simulating what middleware does)
        let delay_ms = {
            let injector = middleware.latency_injector.read().await;
            injector.inject().await
        };
        if delay_ms > 0 {
            latency_tracker.record_latency(delay_ms);
        }

        // Verify latency was recorded
        let samples = latency_tracker.get_samples();
        assert!(!samples.is_empty(), "Should have recorded at least one latency sample");
        assert_eq!(samples[0].latency_ms, 50, "Recorded latency should match injected delay");
    }
}