1use std::collections::HashMap;
34use std::time::{Duration, SystemTime};
35
/// Tuning parameters for [`AdaptiveRateLimiter`].
///
/// A peer's limit is `base_rate` scaled by its reputation score and then bounded by
/// `min_rate` and `max_rate`. Requests are counted over a trailing window of
/// `base_window_secs` seconds.
#[derive(Clone, Debug)]
pub struct AdaptiveRateLimitConfig {
    /// Per-window limit (before bounding) for a peer with reputation score 0.0.
    pub base_rate: u64,
    /// Length of the trailing request-counting window, in seconds.
    pub base_window_secs: u64,
    /// Lower bound applied to every computed per-peer limit.
    pub min_rate: u64,
    /// Upper bound applied to every computed per-peer limit.
    pub max_rate: u64,
    /// Factor applied to `base_rate` at reputation score 1.0; intermediate scores
    /// interpolate linearly between 1.0 and this value.
    pub reputation_multiplier: f64,
    /// Factor applied to a peer's limit to obtain the hard cap on requests admitted
    /// within one window.
    pub burst_multiplier: f64,
    /// Minimum number of seconds between passes that drop inactive peer entries.
    pub cleanup_interval_secs: u64,
}
54
55impl Default for AdaptiveRateLimitConfig {
56 fn default() -> Self {
57 Self {
58 base_rate: 100,
59 base_window_secs: 60,
60 min_rate: 10,
61 max_rate: 1000,
62 reputation_multiplier: 2.0,
63 burst_multiplier: 1.5,
64 cleanup_interval_secs: 300,
65 }
66 }
67}
68
/// One entry in a peer's request log: when it was recorded and how many requests it
/// stands for.
#[derive(Clone, Debug)]
struct RequestRecord {
    /// Wall-clock time at which the request was admitted.
    timestamp: SystemTime,
    /// Number of requests represented by this entry.
    count: u64,
}
75
76#[derive(Debug, Clone)]
78struct PeerRateLimit {
79 current_limit: u64,
81 requests: Vec<RequestRecord>,
83 last_reset: SystemTime,
85 total_requests: u64,
87 violations: u64,
89}
90
91pub struct AdaptiveRateLimiter {
93 config: AdaptiveRateLimitConfig,
94 peer_limits: HashMap<String, PeerRateLimit>,
95 last_cleanup: SystemTime,
96}
97
98impl AdaptiveRateLimiter {
99 #[must_use]
101 #[inline]
102 pub fn new(config: AdaptiveRateLimitConfig) -> Self {
103 Self {
104 config,
105 peer_limits: HashMap::new(),
106 last_cleanup: SystemTime::now(),
107 }
108 }
109
110 pub fn check_rate_limit(&mut self, peer_id: &str, reputation_score: f64) -> bool {
119 self.maybe_cleanup();
120
121 let now = SystemTime::now();
122 let calculated_limit = self.calculate_limit(reputation_score);
123
124 let state = self
125 .peer_limits
126 .entry(peer_id.to_string())
127 .or_insert_with(|| PeerRateLimit {
128 current_limit: calculated_limit,
129 requests: Vec::new(),
130 last_reset: now,
131 total_requests: 0,
132 violations: 0,
133 });
134
135 state.current_limit = calculated_limit;
137
138 let window = Duration::from_secs(self.config.base_window_secs);
140 state.requests.retain(|r| {
141 if let Ok(age) = now.duration_since(r.timestamp) {
142 age < window
143 } else {
144 false
145 }
146 });
147
148 let current_count: u64 = state.requests.iter().map(|r| r.count).sum();
150
151 let burst_limit = (state.current_limit as f64 * self.config.burst_multiplier) as u64;
153
154 if current_count < burst_limit {
155 state.requests.push(RequestRecord {
157 timestamp: now,
158 count: 1,
159 });
160 state.total_requests += 1;
161 true
162 } else {
163 state.violations += 1;
165 false
166 }
167 }
168
169 #[inline]
171 fn calculate_limit(&self, reputation_score: f64) -> u64 {
172 let reputation_score = reputation_score.clamp(0.0, 1.0);
173
174 let multiplier = 1.0 + (reputation_score * (self.config.reputation_multiplier - 1.0));
176 let limit = (self.config.base_rate as f64 * multiplier) as u64;
177
178 limit.clamp(self.config.min_rate, self.config.max_rate)
179 }
180
181 #[must_use]
183 #[inline]
184 pub fn get_limit(&mut self, peer_id: &str, reputation_score: f64) -> u64 {
185 let limit = self.calculate_limit(reputation_score);
186
187 if let Some(state) = self.peer_limits.get_mut(peer_id) {
188 state.current_limit = limit;
189 }
190
191 limit
192 }
193
194 #[must_use]
196 #[inline]
197 pub fn get_remaining(&mut self, peer_id: &str, reputation_score: f64) -> u64 {
198 let now = SystemTime::now();
199 let window = Duration::from_secs(self.config.base_window_secs);
200
201 let state = match self.peer_limits.get_mut(peer_id) {
202 Some(s) => s,
203 None => return self.calculate_limit(reputation_score),
204 };
205
206 state.requests.retain(|r| {
208 if let Ok(age) = now.duration_since(r.timestamp) {
209 age < window
210 } else {
211 false
212 }
213 });
214
215 let current_count: u64 = state.requests.iter().map(|r| r.count).sum();
216 let limit = self.calculate_limit(reputation_score);
217
218 limit.saturating_sub(current_count)
219 }
220
221 #[must_use]
223 #[inline]
224 pub fn get_reset_time(&self, peer_id: &str) -> Option<Duration> {
225 let state = self.peer_limits.get(peer_id)?;
226 let now = SystemTime::now();
227
228 let oldest = state.requests.iter().min_by_key(|r| r.timestamp)?;
230
231 let window = Duration::from_secs(self.config.base_window_secs);
232 let age = now.duration_since(oldest.timestamp).ok()?;
233
234 if age < window {
235 Some(window - age)
236 } else {
237 Some(Duration::from_secs(0))
238 }
239 }
240
241 #[inline]
243 pub fn reset_peer(&mut self, peer_id: &str) {
244 if let Some(state) = self.peer_limits.get_mut(peer_id) {
245 state.requests.clear();
246 state.last_reset = SystemTime::now();
247 }
248 }
249
250 #[must_use]
252 #[inline]
253 pub fn get_peer_stats(&self, peer_id: &str) -> Option<PeerRateLimitStats> {
254 let state = self.peer_limits.get(peer_id)?;
255 let current_count: u64 = state.requests.iter().map(|r| r.count).sum();
256
257 Some(PeerRateLimitStats {
258 current_limit: state.current_limit,
259 current_usage: current_count,
260 total_requests: state.total_requests,
261 violations: state.violations,
262 })
263 }
264
265 #[must_use]
267 #[inline]
268 pub fn get_global_stats(&self) -> GlobalRateLimitStats {
269 let total_peers = self.peer_limits.len();
270 let total_requests: u64 = self.peer_limits.values().map(|s| s.total_requests).sum();
271 let total_violations: u64 = self.peer_limits.values().map(|s| s.violations).sum();
272
273 GlobalRateLimitStats {
274 total_peers,
275 total_requests,
276 total_violations,
277 }
278 }
279
280 #[inline]
282 fn maybe_cleanup(&mut self) {
283 let now = SystemTime::now();
284
285 if let Ok(duration) = now.duration_since(self.last_cleanup) {
286 if duration.as_secs() < self.config.cleanup_interval_secs {
287 return;
288 }
289 }
290
291 let cleanup_threshold = Duration::from_secs(self.config.base_window_secs * 5);
292
293 self.peer_limits.retain(|_, state| {
294 if state.requests.is_empty() {
295 if let Ok(age) = now.duration_since(state.last_reset) {
296 age < cleanup_threshold
297 } else {
298 true
299 }
300 } else {
301 true
302 }
303 });
304
305 self.last_cleanup = now;
306 }
307
308 #[inline]
310 pub fn remove_peer(&mut self, peer_id: &str) {
311 self.peer_limits.remove(peer_id);
312 }
313
314 #[must_use]
316 #[inline]
317 pub fn peer_count(&self) -> usize {
318 self.peer_limits.len()
319 }
320
321 #[inline]
323 pub fn clear(&mut self) {
324 self.peer_limits.clear();
325 self.last_cleanup = SystemTime::now();
326 }
327}
328
/// Point-in-time snapshot of one peer's rate-limiting counters.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct PeerRateLimitStats {
    /// Limit computed for the peer on its most recent check or query.
    pub current_limit: u64,
    /// Recorded requests attributed to the peer's current rate window.
    pub current_usage: u64,
    /// Requests admitted since the peer was first tracked.
    pub total_requests: u64,
    /// Requests rejected since the peer was first tracked.
    pub violations: u64,
}
341
/// Counters aggregated over every peer currently tracked by the limiter.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct GlobalRateLimitStats {
    /// Number of tracked peers.
    pub total_peers: usize,
    /// Sum of admitted requests over tracked peers.
    pub total_requests: u64,
    /// Sum of rejected requests over tracked peers.
    pub total_violations: u64,
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_basic_rate_limiting() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 10,
            base_window_secs: 1,
            burst_multiplier: 1.0,
            reputation_multiplier: 1.0,
            ..Default::default()
        };

        let mut limiter = AdaptiveRateLimiter::new(config);

        for _ in 0..10 {
            assert!(limiter.check_rate_limit("peer1", 0.5));
        }

        assert!(!limiter.check_rate_limit("peer1", 0.5));
    }

    #[test]
    fn test_reputation_based_limits() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 100,
            reputation_multiplier: 3.0,
            ..Default::default()
        };

        let mut limiter = AdaptiveRateLimiter::new(config);

        let low_limit = limiter.get_limit("peer1", 0.1);
        let high_limit = limiter.get_limit("peer2", 0.9);

        assert!(high_limit > low_limit);
    }

    #[test]
    fn test_window_expiration() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 5,
            base_window_secs: 1,
            burst_multiplier: 1.0,
            reputation_multiplier: 1.0,
            min_rate: 1,
            max_rate: 1000,
            ..Default::default()
        };

        let mut limiter = AdaptiveRateLimiter::new(config);

        for _ in 0..5 {
            assert!(limiter.check_rate_limit("peer1", 0.5));
        }

        assert!(!limiter.check_rate_limit("peer1", 0.5));

        // Wait past the 1-second window so the recorded requests expire.
        thread::sleep(Duration::from_millis(1100));

        assert!(limiter.check_rate_limit("peer1", 0.5));
    }

    #[test]
    fn test_burst_allowance() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 10,
            burst_multiplier: 2.0,
            reputation_multiplier: 1.0,
            ..Default::default()
        };

        let mut limiter = AdaptiveRateLimiter::new(config);

        // Limit 10 with a 2x burst multiplier admits 20 requests per window.
        for _ in 0..20 {
            assert!(limiter.check_rate_limit("peer1", 0.5));
        }

        assert!(!limiter.check_rate_limit("peer1", 0.5));
    }

    #[test]
    fn test_get_remaining() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 10,
            burst_multiplier: 1.0,
            reputation_multiplier: 1.0,
            ..Default::default()
        };

        let mut limiter = AdaptiveRateLimiter::new(config);

        assert_eq!(limiter.get_remaining("peer1", 0.5), 10);

        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer1", 0.5);

        assert_eq!(limiter.get_remaining("peer1", 0.5), 7);
    }

    #[test]
    fn test_reset_peer() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 5,
            burst_multiplier: 1.0,
            reputation_multiplier: 1.0,
            min_rate: 1,
            max_rate: 1000,
            ..Default::default()
        };

        let mut limiter = AdaptiveRateLimiter::new(config);

        for _ in 0..5 {
            assert!(limiter.check_rate_limit("peer1", 0.5));
        }

        assert_eq!(limiter.get_remaining("peer1", 0.5), 0);

        limiter.reset_peer("peer1");

        assert_eq!(limiter.get_remaining("peer1", 0.5), 5);
    }

    #[test]
    fn test_peer_stats() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 10,
            burst_multiplier: 1.0,
            ..Default::default()
        };

        let mut limiter = AdaptiveRateLimiter::new(config);

        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer1", 0.5);

        let stats = limiter.get_peer_stats("peer1").unwrap();
        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.current_usage, 3);
    }

    #[test]
    fn test_violation_tracking() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 2,
            burst_multiplier: 1.0,
            reputation_multiplier: 1.0,
            min_rate: 1,
            max_rate: 1000,
            ..Default::default()
        };

        let mut limiter = AdaptiveRateLimiter::new(config);

        assert!(limiter.check_rate_limit("peer1", 0.5));
        assert!(limiter.check_rate_limit("peer1", 0.5));
        assert!(!limiter.check_rate_limit("peer1", 0.5));
        assert!(!limiter.check_rate_limit("peer1", 0.5));

        let stats = limiter.get_peer_stats("peer1").unwrap();
        assert_eq!(stats.violations, 2);
    }

    #[test]
    fn test_global_stats() {
        let config = AdaptiveRateLimitConfig::default();
        let mut limiter = AdaptiveRateLimiter::new(config);

        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer2", 0.5);
        limiter.check_rate_limit("peer3", 0.5);

        let stats = limiter.get_global_stats();
        assert_eq!(stats.total_peers, 3);
        assert_eq!(stats.total_requests, 3);
    }

    #[test]
    fn test_min_max_limits() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 100,
            min_rate: 50,
            max_rate: 200,
            reputation_multiplier: 10.0,
            ..Default::default()
        };

        let mut limiter = AdaptiveRateLimiter::new(config);

        // Reputation 0.0 -> multiplier 1.0 -> base_rate, already within bounds.
        let low_limit = limiter.get_limit("peer1", 0.0);
        assert_eq!(low_limit, 100);

        // Reputation 1.0 -> multiplier 10.0 -> 1000, capped at max_rate.
        let high_limit = limiter.get_limit("peer2", 1.0);
        assert_eq!(high_limit, 200);
    }

    #[test]
    fn test_remove_peer() {
        let config = AdaptiveRateLimitConfig::default();
        let mut limiter = AdaptiveRateLimiter::new(config);

        limiter.check_rate_limit("peer1", 0.5);
        assert_eq!(limiter.peer_count(), 1);

        limiter.remove_peer("peer1");
        assert_eq!(limiter.peer_count(), 0);
    }

    #[test]
    fn test_clear() {
        let config = AdaptiveRateLimitConfig::default();
        let mut limiter = AdaptiveRateLimiter::new(config);

        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer2", 0.5);
        limiter.check_rate_limit("peer3", 0.5);

        assert_eq!(limiter.peer_count(), 3);

        limiter.clear();
        assert_eq!(limiter.peer_count(), 0);
    }

    #[test]
    fn test_unknown_peer_queries() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 10,
            reputation_multiplier: 1.0,
            ..Default::default()
        };
        let mut limiter = AdaptiveRateLimiter::new(config);

        // Read-only queries must not register the peer.
        assert_eq!(limiter.get_remaining("ghost", 0.5), 10);
        assert!(limiter.get_reset_time("ghost").is_none());
        assert!(limiter.get_peer_stats("ghost").is_none());
        assert_eq!(limiter.peer_count(), 0);
    }
}