ipfrs_storage/
rate_limit.rs1use parking_lot::Mutex;
26use serde::{Deserialize, Serialize};
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29use tokio::time::sleep;
30
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
33pub enum RateLimitAlgorithm {
34 TokenBucket,
36 LeakyBucket,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct RateLimitConfig {
43 pub capacity: u64,
45 pub refill_rate: u64,
47 pub refill_interval: Duration,
49 pub algorithm: RateLimitAlgorithm,
51 pub block_on_limit: bool,
53}
54
55impl RateLimitConfig {
56 pub fn new(capacity: u64, refill_interval: Duration) -> Self {
62 Self {
63 capacity,
64 refill_rate: capacity,
65 refill_interval,
66 algorithm: RateLimitAlgorithm::TokenBucket,
67 block_on_limit: true,
68 }
69 }
70
71 pub fn per_second(requests: u64) -> Self {
73 Self::new(requests, Duration::from_secs(1))
74 }
75
76 pub fn per_minute(requests: u64) -> Self {
78 Self::new(requests, Duration::from_secs(60))
79 }
80
81 pub fn with_refill_rate(mut self, rate: u64) -> Self {
83 self.refill_rate = rate;
84 self
85 }
86
87 pub fn with_algorithm(mut self, algorithm: RateLimitAlgorithm) -> Self {
89 self.algorithm = algorithm;
90 self
91 }
92
93 pub fn with_blocking(mut self, block: bool) -> Self {
95 self.block_on_limit = block;
96 self
97 }
98}
99
/// Mutable bucket state shared (behind a mutex) by a limiter and its clones.
#[derive(Debug)]
struct RateLimiterState {
    /// Tokens currently in the bucket; fractional because refills are proportional to time.
    tokens: f64,
    /// Moment the bucket was last refilled (or reset / created).
    last_refill: Instant,
    /// Number of requests counted since creation or the last reset.
    total_requests: u64,
    /// Requests that were granted tokens.
    requests_allowed: u64,
    /// Requests that were refused.
    requests_denied: u64,
}
114
115#[derive(Debug, Clone, Default, Serialize, Deserialize)]
117pub struct RateLimitStats {
118 pub total_requests: u64,
120 pub requests_allowed: u64,
122 pub requests_denied: u64,
124 pub available_tokens: u64,
126 pub utilization_percent: f64,
128}
129
130pub struct RateLimiter {
132 config: RateLimitConfig,
133 state: Arc<Mutex<RateLimiterState>>,
134}
135
136impl RateLimiter {
137 pub fn new(config: RateLimitConfig) -> Self {
139 Self {
140 state: Arc::new(Mutex::new(RateLimiterState {
141 tokens: config.capacity as f64,
142 last_refill: Instant::now(),
143 total_requests: 0,
144 requests_allowed: 0,
145 requests_denied: 0,
146 })),
147 config,
148 }
149 }
150
151 pub async fn acquire(&self, tokens: u64) -> bool {
159 loop {
160 let wait_duration = {
162 let mut state = self.state.lock();
163 self.refill_tokens(&mut state);
164
165 state.total_requests += 1;
166
167 if state.tokens >= tokens as f64 {
168 state.tokens -= tokens as f64;
170 state.requests_allowed += 1;
171 return true;
172 } else {
173 state.requests_denied += 1;
174
175 if !self.config.block_on_limit {
176 return false;
177 }
178
179 let tokens_needed = tokens as f64 - state.tokens;
181 let tokens_per_ms = self.config.refill_rate as f64
182 / self.config.refill_interval.as_millis() as f64;
183 let wait_ms = (tokens_needed / tokens_per_ms).ceil() as u64;
184 Duration::from_millis(wait_ms.max(1))
185 }
186 };
187
188 sleep(wait_duration).await;
190 }
191 }
192
193 pub fn try_acquire(&self, tokens: u64) -> bool {
195 let mut state = self.state.lock();
196 self.refill_tokens(&mut state);
197
198 state.total_requests += 1;
199
200 if state.tokens >= tokens as f64 {
201 state.tokens -= tokens as f64;
202 state.requests_allowed += 1;
203 true
204 } else {
205 state.requests_denied += 1;
206 false
207 }
208 }
209
210 pub fn stats(&self) -> RateLimitStats {
212 let mut state = self.state.lock();
213 self.refill_tokens(&mut state);
214
215 RateLimitStats {
216 total_requests: state.total_requests,
217 requests_allowed: state.requests_allowed,
218 requests_denied: state.requests_denied,
219 available_tokens: state.tokens as u64,
220 utilization_percent: if state.total_requests > 0 {
221 (state.requests_allowed as f64 / state.total_requests as f64) * 100.0
222 } else {
223 0.0
224 },
225 }
226 }
227
228 pub fn reset(&self) {
230 let mut state = self.state.lock();
231 state.tokens = self.config.capacity as f64;
232 state.last_refill = Instant::now();
233 state.total_requests = 0;
234 state.requests_allowed = 0;
235 state.requests_denied = 0;
236 }
237
238 fn refill_tokens(&self, state: &mut RateLimiterState) {
240 let now = Instant::now();
241 let elapsed = now.duration_since(state.last_refill);
242
243 if elapsed >= self.config.refill_interval {
244 let intervals = elapsed.as_secs_f64() / self.config.refill_interval.as_secs_f64();
245 let tokens_to_add = intervals * self.config.refill_rate as f64;
246
247 state.tokens = (state.tokens + tokens_to_add).min(self.config.capacity as f64);
248 state.last_refill = now;
249 }
250 }
251}
252
253impl Clone for RateLimiter {
254 fn clone(&self) -> Self {
255 Self {
256 config: self.config.clone(),
257 state: Arc::clone(&self.state),
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    // `super::*` already brings `Duration`, `Instant` and tokio's `sleep` into scope.
    use super::*;

    #[test]
    fn test_rate_limiter_basic() {
        let config = RateLimitConfig::new(10, Duration::from_secs(1));
        let limiter = RateLimiter::new(config);

        for _ in 0..10 {
            assert!(limiter.try_acquire(1));
        }

        // Bucket is empty and a full second has not elapsed, so no refill yet.
        assert!(!limiter.try_acquire(1));

        let stats = limiter.stats();
        assert_eq!(stats.requests_allowed, 10);
        assert_eq!(stats.requests_denied, 1);
    }

    #[tokio::test]
    async fn test_rate_limiter_refill() {
        let config = RateLimitConfig::new(5, Duration::from_millis(100));
        let limiter = RateLimiter::new(config);

        for _ in 0..5 {
            assert!(limiter.try_acquire(1));
        }
        assert!(!limiter.try_acquire(1));

        // Sleep past one refill interval.
        sleep(Duration::from_millis(150)).await;

        assert!(limiter.try_acquire(1));
    }

    #[tokio::test]
    async fn test_rate_limiter_blocking() {
        let config = RateLimitConfig::new(2, Duration::from_millis(100)).with_blocking(true);
        let limiter = RateLimiter::new(config);

        assert!(limiter.acquire(2).await);

        let start = Instant::now();
        assert!(limiter.acquire(1).await);
        let elapsed = start.elapsed();

        assert!(elapsed >= Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_rate_limiter_non_blocking_acquire() {
        let config = RateLimitConfig::new(1, Duration::from_secs(60)).with_blocking(false);
        let limiter = RateLimiter::new(config);

        assert!(limiter.acquire(1).await);
        assert!(!limiter.acquire(1).await);

        let stats = limiter.stats();
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.requests_allowed, 1);
        assert_eq!(stats.requests_denied, 1);
    }

    #[test]
    fn test_rate_limiter_stats() {
        let config = RateLimitConfig::new(10, Duration::from_secs(1));
        let limiter = RateLimiter::new(config);

        for _ in 0..5 {
            limiter.try_acquire(1);
        }

        let stats = limiter.stats();
        assert_eq!(stats.total_requests, 5);
        assert_eq!(stats.requests_allowed, 5);
        assert_eq!(stats.requests_denied, 0);
        assert_eq!(stats.available_tokens, 5);
        assert_eq!(stats.utilization_percent, 100.0);
    }

    #[test]
    fn test_rate_limiter_reset() {
        let config = RateLimitConfig::new(5, Duration::from_secs(1));
        let limiter = RateLimiter::new(config);

        for _ in 0..5 {
            limiter.try_acquire(1);
        }

        limiter.reset();

        assert!(limiter.try_acquire(1));

        let stats = limiter.stats();
        assert_eq!(stats.total_requests, 1);
    }

    #[test]
    fn test_rate_limiter_per_second() {
        let config = RateLimitConfig::per_second(100);
        let limiter = RateLimiter::new(config);

        assert_eq!(limiter.config.capacity, 100);
        assert_eq!(limiter.config.refill_interval, Duration::from_secs(1));
    }

    #[test]
    fn test_rate_limiter_per_minute() {
        let config = RateLimitConfig::per_minute(1000);
        let limiter = RateLimiter::new(config);

        assert_eq!(limiter.config.capacity, 1000);
        assert_eq!(limiter.config.refill_interval, Duration::from_secs(60));
    }

    #[test]
    fn test_config_builders() {
        let config = RateLimitConfig::per_second(10)
            .with_refill_rate(5)
            .with_algorithm(RateLimitAlgorithm::LeakyBucket)
            .with_blocking(false);

        assert_eq!(config.capacity, 10);
        assert_eq!(config.refill_rate, 5);
        assert_eq!(config.algorithm, RateLimitAlgorithm::LeakyBucket);
        assert!(!config.block_on_limit);
    }
}