// mockforge_chaos/rate_limit.rs
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;

use governor::{
    clock::DefaultClock,
    state::{InMemoryState, NotKeyed},
    Quota, RateLimiter as GovernorRateLimiter,
};
use nonzero_ext::nonzero;
use parking_lot::RwLock;
use tracing::debug;

use crate::{config::RateLimitConfig, ChaosError, Result};
15
16#[derive(Clone)]
18pub struct RateLimiter {
19 config: RateLimitConfig,
20 global_limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
21 #[allow(clippy::type_complexity)]
22 ip_limiters: Arc<
23 RwLock<HashMap<String, Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>>>,
24 >,
25 #[allow(clippy::type_complexity)]
26 endpoint_limiters: Arc<
27 RwLock<HashMap<String, Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>>>,
28 >,
29}
30
31impl RateLimiter {
32 pub fn new(config: RateLimitConfig) -> Self {
34 let quota = Quota::per_second(
35 NonZeroU32::new(config.requests_per_second).unwrap_or(nonzero!(100u32)),
36 )
37 .allow_burst(NonZeroU32::new(config.burst_size).unwrap_or(nonzero!(10u32)));
38
39 let global_limiter = Arc::new(GovernorRateLimiter::direct(quota));
40
41 Self {
42 config,
43 global_limiter,
44 ip_limiters: Arc::new(RwLock::new(HashMap::new())),
45 endpoint_limiters: Arc::new(RwLock::new(HashMap::new())),
46 }
47 }
48
49 pub fn is_enabled(&self) -> bool {
51 self.config.enabled
52 }
53
54 pub fn check_global(&self) -> Result<()> {
56 if !self.config.enabled {
57 return Ok(());
58 }
59
60 if self.global_limiter.check().is_err() {
61 debug!("Global rate limit exceeded");
62 return Err(ChaosError::RateLimitExceeded);
63 }
64
65 Ok(())
66 }
67
68 pub fn check_ip(&self, ip: &str) -> Result<()> {
70 if !self.config.enabled || !self.config.per_ip {
71 return Ok(());
72 }
73
74 let limiter = {
75 let mut limiters = self.ip_limiters.write();
76 limiters
77 .entry(ip.to_string())
78 .or_insert_with(|| {
79 let quota = Quota::per_second(
80 NonZeroU32::new(self.config.requests_per_second)
81 .unwrap_or(nonzero!(100u32)),
82 )
83 .allow_burst(
84 NonZeroU32::new(self.config.burst_size).unwrap_or(nonzero!(10u32)),
85 );
86 Arc::new(GovernorRateLimiter::direct(quota))
87 })
88 .clone()
89 };
90
91 if limiter.check().is_err() {
92 debug!("Per-IP rate limit exceeded for {}", ip);
93 return Err(ChaosError::RateLimitExceeded);
94 }
95
96 Ok(())
97 }
98
99 pub fn check_endpoint(&self, endpoint: &str) -> Result<()> {
101 if !self.config.enabled || !self.config.per_endpoint {
102 return Ok(());
103 }
104
105 let limiter = {
106 let mut limiters = self.endpoint_limiters.write();
107 limiters
108 .entry(endpoint.to_string())
109 .or_insert_with(|| {
110 let quota = Quota::per_second(
111 NonZeroU32::new(self.config.requests_per_second)
112 .unwrap_or(nonzero!(100u32)),
113 )
114 .allow_burst(
115 NonZeroU32::new(self.config.burst_size).unwrap_or(nonzero!(10u32)),
116 );
117 Arc::new(GovernorRateLimiter::direct(quota))
118 })
119 .clone()
120 };
121
122 if limiter.check().is_err() {
123 debug!("Per-endpoint rate limit exceeded for {}", endpoint);
124 return Err(ChaosError::RateLimitExceeded);
125 }
126
127 Ok(())
128 }
129
130 pub fn check(&self, ip: Option<&str>, endpoint: Option<&str>) -> Result<()> {
132 self.check_global()?;
133
134 if let Some(ip_addr) = ip {
135 self.check_ip(ip_addr)?;
136 }
137
138 if let Some(endpoint_path) = endpoint {
139 self.check_endpoint(endpoint_path)?;
140 }
141
142 Ok(())
143 }
144
145 pub fn config(&self) -> &RateLimitConfig {
147 &self.config
148 }
149
150 pub fn update_config(&mut self, config: RateLimitConfig) {
152 self.config = config;
153 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_global_rate_limit() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 2,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = RateLimiter::new(config);

        // A burst of 2 admits two immediate requests; the third is rejected.
        assert!(limiter.check_global().is_ok());
        assert!(limiter.check_global().is_ok());
        assert!(matches!(limiter.check_global(), Err(ChaosError::RateLimitExceeded)));
    }

    #[test]
    fn test_disabled_rate_limit() {
        let config = RateLimitConfig {
            enabled: false,
            ..Default::default()
        };

        let limiter = RateLimiter::new(config);

        for _ in 0..1000 {
            assert!(limiter.check_global().is_ok());
        }
    }

    #[test]
    fn test_per_ip_rate_limit() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 2,
            per_ip: true,
            per_endpoint: false,
        };

        let limiter = RateLimiter::new(config);

        // Each IP gets its own bucket, so exhausting one does not affect the other.
        assert!(limiter.check_ip("192.168.1.1").is_ok());
        assert!(limiter.check_ip("192.168.1.2").is_ok());

        assert!(limiter.check_ip("192.168.1.1").is_ok());
        assert!(limiter.check_ip("192.168.1.2").is_ok());

        assert!(matches!(limiter.check_ip("192.168.1.1"), Err(ChaosError::RateLimitExceeded)));
        assert!(matches!(limiter.check_ip("192.168.1.2"), Err(ChaosError::RateLimitExceeded)));
    }

    #[test]
    fn test_per_endpoint_rate_limit() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 1,
            per_ip: false,
            per_endpoint: true,
        };

        let limiter = RateLimiter::new(config);

        assert!(limiter.check_endpoint("/users").is_ok());
        assert!(limiter.check_endpoint("/orders").is_ok());

        assert!(matches!(limiter.check_endpoint("/users"), Err(ChaosError::RateLimitExceeded)));
        assert!(matches!(limiter.check_endpoint("/orders"), Err(ChaosError::RateLimitExceeded)));
    }

    #[test]
    fn test_per_ip_disabled_skips_ip_limit() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 1,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = RateLimiter::new(config);

        // With `per_ip` off, `check_ip` never consults a limiter.
        for _ in 0..100 {
            assert!(limiter.check_ip("10.0.0.1").is_ok());
        }
    }
}