mockforge_chaos/
rate_limit.rs

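//! Rate limiting for controlling request throughput.
//!
//! [`RateLimiter`] combines one global limiter with optional per-IP and per-endpoint
//! limiters, all built from a single `RateLimitConfig`.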
2
3use crate::{config::RateLimitConfig, ChaosError, Result};
4use governor::{
5    clock::DefaultClock,
6    state::{InMemoryState, NotKeyed},
7    Quota, RateLimiter as GovernorRateLimiter,
8};
9use nonzero_ext::nonzero;
10use parking_lot::RwLock;
11use std::collections::HashMap;
12use std::num::NonZeroU32;
13use std::sync::Arc;
14use tracing::debug;
15
16/// Rate limiter for controlling request throughput
17#[derive(Clone)]
18pub struct RateLimiter {
19    config: RateLimitConfig,
20    global_limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
21    #[allow(clippy::type_complexity)]
22    ip_limiters: Arc<
23        RwLock<HashMap<String, Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>>>,
24    >,
25    #[allow(clippy::type_complexity)]
26    endpoint_limiters: Arc<
27        RwLock<HashMap<String, Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>>>,
28    >,
29}
30
31impl RateLimiter {
32    /// Create a new rate limiter
33    pub fn new(config: RateLimitConfig) -> Self {
34        let quota = Quota::per_second(
35            NonZeroU32::new(config.requests_per_second).unwrap_or(nonzero!(100u32)),
36        )
37        .allow_burst(NonZeroU32::new(config.burst_size).unwrap_or(nonzero!(10u32)));
38
39        let global_limiter = Arc::new(GovernorRateLimiter::direct(quota));
40
41        Self {
42            config,
43            global_limiter,
44            ip_limiters: Arc::new(RwLock::new(HashMap::new())),
45            endpoint_limiters: Arc::new(RwLock::new(HashMap::new())),
46        }
47    }
48
49    /// Check if rate limiting is enabled
50    pub fn is_enabled(&self) -> bool {
51        self.config.enabled
52    }
53
54    /// Check global rate limit
55    pub fn check_global(&self) -> Result<()> {
56        if !self.config.enabled {
57            return Ok(());
58        }
59
60        if self.global_limiter.check().is_err() {
61            debug!("Global rate limit exceeded");
62            return Err(ChaosError::RateLimitExceeded);
63        }
64
65        Ok(())
66    }
67
68    /// Check per-IP rate limit
69    pub fn check_ip(&self, ip: &str) -> Result<()> {
70        if !self.config.enabled || !self.config.per_ip {
71            return Ok(());
72        }
73
74        let limiter = {
75            let mut limiters = self.ip_limiters.write();
76            limiters
77                .entry(ip.to_string())
78                .or_insert_with(|| {
79                    let quota = Quota::per_second(
80                        NonZeroU32::new(self.config.requests_per_second)
81                            .unwrap_or(nonzero!(100u32)),
82                    )
83                    .allow_burst(
84                        NonZeroU32::new(self.config.burst_size).unwrap_or(nonzero!(10u32)),
85                    );
86                    Arc::new(GovernorRateLimiter::direct(quota))
87                })
88                .clone()
89        };
90
91        if limiter.check().is_err() {
92            debug!("Per-IP rate limit exceeded for {}", ip);
93            return Err(ChaosError::RateLimitExceeded);
94        }
95
96        Ok(())
97    }
98
99    /// Check per-endpoint rate limit
100    pub fn check_endpoint(&self, endpoint: &str) -> Result<()> {
101        if !self.config.enabled || !self.config.per_endpoint {
102            return Ok(());
103        }
104
105        let limiter = {
106            let mut limiters = self.endpoint_limiters.write();
107            limiters
108                .entry(endpoint.to_string())
109                .or_insert_with(|| {
110                    let quota = Quota::per_second(
111                        NonZeroU32::new(self.config.requests_per_second)
112                            .unwrap_or(nonzero!(100u32)),
113                    )
114                    .allow_burst(
115                        NonZeroU32::new(self.config.burst_size).unwrap_or(nonzero!(10u32)),
116                    );
117                    Arc::new(GovernorRateLimiter::direct(quota))
118                })
119                .clone()
120        };
121
122        if limiter.check().is_err() {
123            debug!("Per-endpoint rate limit exceeded for {}", endpoint);
124            return Err(ChaosError::RateLimitExceeded);
125        }
126
127        Ok(())
128    }
129
130    /// Check all applicable rate limits
131    pub fn check(&self, ip: Option<&str>, endpoint: Option<&str>) -> Result<()> {
132        self.check_global()?;
133
134        if let Some(ip_addr) = ip {
135            self.check_ip(ip_addr)?;
136        }
137
138        if let Some(endpoint_path) = endpoint {
139            self.check_endpoint(endpoint_path)?;
140        }
141
142        Ok(())
143    }
144
145    /// Get configuration
146    pub fn config(&self) -> &RateLimitConfig {
147        &self.config
148    }
149
150    /// Update configuration
151    pub fn update_config(&mut self, config: RateLimitConfig) {
152        self.config = config;
153        // Note: Updating limiters would require recreating them
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_global_rate_limit() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 2, // burst_size is the total capacity, not additional requests
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = RateLimiter::new(config);

        // First request should succeed
        assert!(limiter.check_global().is_ok());

        // Second request uses the remaining burst capacity
        assert!(limiter.check_global().is_ok());

        // Next request should fail: the burst capacity of 2 is used up
        assert!(matches!(limiter.check_global(), Err(ChaosError::RateLimitExceeded)));
    }

    #[test]
    fn test_disabled_rate_limit() {
        let config = RateLimitConfig {
            enabled: false,
            ..Default::default()
        };

        let limiter = RateLimiter::new(config);

        // Should always succeed when disabled
        for _ in 0..1000 {
            assert!(limiter.check_global().is_ok());
        }
    }

    #[test]
    fn test_per_ip_rate_limit() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 2, // burst_size is the total capacity, not additional requests
            per_ip: true,
            per_endpoint: false,
        };

        let limiter = RateLimiter::new(config);

        // Requests from different IPs should be independent
        assert!(limiter.check_ip("192.168.1.1").is_ok());
        assert!(limiter.check_ip("192.168.1.2").is_ok());

        // Burst
        assert!(limiter.check_ip("192.168.1.1").is_ok());
        assert!(limiter.check_ip("192.168.1.2").is_ok());

        // Should fail for each IP independently
        assert!(matches!(limiter.check_ip("192.168.1.1"), Err(ChaosError::RateLimitExceeded)));
        assert!(matches!(limiter.check_ip("192.168.1.2"), Err(ChaosError::RateLimitExceeded)));
    }

    #[test]
    fn test_per_endpoint_rate_limit() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 2,
            per_ip: false,
            per_endpoint: true,
        };

        let limiter = RateLimiter::new(config);

        // Each endpoint has its own limiter
        assert!(limiter.check_endpoint("/api/users").is_ok());
        assert!(limiter.check_endpoint("/api/orders").is_ok());
        assert!(limiter.check_endpoint("/api/users").is_ok());
        assert!(limiter.check_endpoint("/api/orders").is_ok());

        assert!(matches!(
            limiter.check_endpoint("/api/users"),
            Err(ChaosError::RateLimitExceeded)
        ));
        assert!(matches!(
            limiter.check_endpoint("/api/orders"),
            Err(ChaosError::RateLimitExceeded)
        ));
    }

    #[test]
    fn test_per_key_checks_skipped_when_flags_off() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 1,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = RateLimiter::new(config);

        // With per_ip / per_endpoint off, these checks never limit
        for _ in 0..10 {
            assert!(limiter.check_ip("192.168.1.1").is_ok());
            assert!(limiter.check_endpoint("/api/users").is_ok());
        }
    }

    #[test]
    fn test_combined_check_enforces_limits() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 2,
            per_ip: true,
            per_endpoint: true,
        };

        let limiter = RateLimiter::new(config);

        assert!(limiter.check(Some("10.0.0.1"), Some("/a")).is_ok());
        assert!(limiter.check(Some("10.0.0.1"), Some("/a")).is_ok());

        // Global, IP and endpoint limits are all exhausted now
        assert!(matches!(
            limiter.check(Some("10.0.0.1"), Some("/a")),
            Err(ChaosError::RateLimitExceeded)
        ));
    }

    #[test]
    fn test_zero_values_fall_back_to_defaults() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 0,
            burst_size: 0,
            per_ip: false,
            per_endpoint: false,
        };

        // Zero is not a valid quota; construction must fall back instead of panicking
        let limiter = RateLimiter::new(config);

        // The fallback burst capacity of 10 is available immediately
        for _ in 0..10 {
            assert!(limiter.check_global().is_ok());
        }
    }
}