1use std::collections::HashMap;
15use std::time::{Duration, Instant};
16
/// Tuning knobs for the adaptive rate limiter.
///
/// Call [`AdaptiveRateConfig::validate`] to check a configuration; the
/// limiter's constructor falls back to [`AdaptiveRateConfig::default`] when
/// validation fails.
#[derive(Clone, Debug)]
pub struct AdaptiveRateConfig {
    /// Requests allowed per counting window before requests are throttled.
    /// Despite the name, the window length is `window_size` (60 s by default).
    pub base_rate_per_minute: u64,
    /// Factor applied to an entity's current rate to obtain the hard ceiling:
    /// requests above the ceiling are denied, requests between the current
    /// rate and the ceiling are throttled. Must be finite and positive.
    pub burst_multiplier: f64,
    /// Fraction of `base_rate_per_minute` an entity is limited to while an
    /// anomaly is signalled for it. Must lie in `[0.0, 1.0]`.
    pub anomaly_reduction_factor: f64,
    /// How long a signalled anomaly lasts before the base rate is restored.
    /// Must be non-zero.
    pub recovery_period: Duration,
    /// Length of the fixed counting window. Must be non-zero.
    pub window_size: Duration,
}

impl Default for AdaptiveRateConfig {
    /// 100 requests per 60 s window, 2x burst headroom, and the rate halved
    /// for five minutes after an anomaly.
    fn default() -> Self {
        AdaptiveRateConfig {
            window_size: Duration::from_secs(60),
            recovery_period: Duration::from_secs(5 * 60),
            anomaly_reduction_factor: 0.5,
            burst_multiplier: 2.0,
            base_rate_per_minute: 100,
        }
    }
}

impl AdaptiveRateConfig {
    /// Checks that every field is usable.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first offending field. Fields are
    /// checked in this order: `base_rate_per_minute` (must be non-zero),
    /// `burst_multiplier` (finite and `> 0.0`), `anomaly_reduction_factor`
    /// (finite and within `[0.0, 1.0]`), `recovery_period` (non-zero), and
    /// `window_size` (non-zero).
    pub fn validate(&self) -> Result<(), String> {
        if self.base_rate_per_minute == 0 {
            return Err(String::from("base_rate_per_minute must be > 0"));
        }

        let burst = self.burst_multiplier;
        // `is_finite` also rules out NaN, so the comparison below is total.
        if !(burst.is_finite() && burst > 0.0) {
            return Err(format!(
                "burst_multiplier must be finite and > 0.0, got {}",
                burst
            ));
        }

        let factor = self.anomaly_reduction_factor;
        if !(factor.is_finite() && (0.0..=1.0).contains(&factor)) {
            return Err(format!(
                "anomaly_reduction_factor must be in [0.0, 1.0], got {}",
                factor
            ));
        }

        if self.recovery_period.is_zero() {
            return Err(String::from("recovery_period must be > 0"));
        }

        if self.window_size.is_zero() {
            return Err(String::from("window_size must be > 0"));
        }

        Ok(())
    }
}
79
/// Per-entity bookkeeping kept by the rate limiter.
#[derive(Clone, Debug)]
pub struct RateState {
    /// Requests allowed per window before throttling: the configured base
    /// rate, or a reduced rate while an anomaly is applied.
    pub current_rate: u64,
    /// Requests counted in the current window (saturating).
    pub request_count: u64,
    /// When the current counting window began.
    pub window_start: Instant,
    /// Whether anomaly restrictions are applied. Expiry is applied lazily, so
    /// this can remain `true` after `anomaly_expires` has passed until the
    /// entity is next checked.
    pub anomaly_detected: bool,
    /// When the anomaly restrictions lapse; `None` means no expiry is scheduled.
    pub anomaly_expires: Option<Instant>,
    /// Lifetime number of checks recorded for this entity (saturating).
    pub total_requests: u64,
    /// Lifetime number of denied checks for this entity (saturating).
    pub total_denied: u64,
}
98
/// Outcome of a single rate-limit check.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RateDecision {
    /// The window count is within the entity's current rate.
    Allow,
    /// The window count exceeds the burst ceiling, or a new entity could not
    /// be tracked because the tracking table is full.
    Deny,
    /// The window count exceeds the current rate but not the burst ceiling.
    Throttle,
}
109
/// Snapshot of counters aggregated across every tracked entity.
#[derive(Clone, Debug)]
pub struct RateLimiterStats {
    /// Number of entities currently tracked.
    pub active_entities: usize,
    /// Sum of every entity's lifetime request count (saturating).
    pub total_requests: u64,
    /// Sum of every entity's lifetime denial count (saturating).
    pub total_denied: u64,
    /// Number of entities whose anomaly flag is set.
    pub anomaly_count: usize,
}
118
/// Upper bound on the number of entities `check` will start tracking.
/// Once it is reached, checks for previously unseen entities are denied.
const MAX_TRACKED_ENTITIES: usize = 100_000;
122
123pub struct AdaptiveRateLimiter {
129 config: AdaptiveRateConfig,
130 states: HashMap<String, RateState>,
131}
132
133impl AdaptiveRateLimiter {
134 pub fn new(config: AdaptiveRateConfig) -> Self {
139 if let Err(e) = config.validate() {
140 tracing::warn!(error = %e, "AdaptiveRateConfig validation failed — using defaults");
141 return Self {
142 config: AdaptiveRateConfig::default(),
143 states: HashMap::new(),
144 };
145 }
146 Self {
147 config,
148 states: HashMap::new(),
149 }
150 }
151
152 #[must_use = "rate limit decisions must not be discarded"]
157 pub fn check(&mut self, entity_id: &str) -> RateDecision {
158 let now = Instant::now();
159 let base_rate = self.config.base_rate_per_minute;
160 let window_size = self.config.window_size;
161
162 if !self.states.contains_key(entity_id) && self.states.len() >= MAX_TRACKED_ENTITIES {
166 return RateDecision::Deny;
167 }
168
169 let state = self
170 .states
171 .entry(entity_id.to_string())
172 .or_insert_with(|| RateState {
173 current_rate: base_rate,
174 request_count: 0,
175 window_start: now,
176 anomaly_detected: false,
177 anomaly_expires: None,
178 total_requests: 0,
179 total_denied: 0,
180 });
181
182 if now.duration_since(state.window_start) >= window_size {
184 state.window_start = now;
185 state.request_count = 0;
186 }
187
188 if state.anomaly_detected {
190 if let Some(expires) = state.anomaly_expires {
191 if now >= expires {
192 state.anomaly_detected = false;
193 state.anomaly_expires = None;
194 state.current_rate = base_rate;
195 }
196 }
197 }
198
199 state.request_count = state.request_count.saturating_add(1);
201 state.total_requests = state.total_requests.saturating_add(1);
202
203 let burst_multiplier = clamp_f64(self.config.burst_multiplier, 0.0, 100.0);
205 let burst_ceiling = (state.current_rate as f64 * burst_multiplier) as u64;
206 let burst_ceiling = burst_ceiling.max(state.current_rate);
208
209 if state.request_count > burst_ceiling {
210 state.total_denied = state.total_denied.saturating_add(1);
211 RateDecision::Deny
212 } else if state.request_count > state.current_rate {
213 RateDecision::Throttle
214 } else {
215 RateDecision::Allow
216 }
217 }
218
219 pub fn signal_anomaly(&mut self, entity_id: &str) {
224 let now = Instant::now();
225 let base_rate = self.config.base_rate_per_minute;
226 let reduction = clamp_f64(self.config.anomaly_reduction_factor, 0.0, 1.0);
227 let recovery = self.config.recovery_period;
228
229 let state = self
230 .states
231 .entry(entity_id.to_string())
232 .or_insert_with(|| RateState {
233 current_rate: base_rate,
234 request_count: 0,
235 window_start: now,
236 anomaly_detected: false,
237 anomaly_expires: None,
238 total_requests: 0,
239 total_denied: 0,
240 });
241
242 state.anomaly_detected = true;
243 state.anomaly_expires = Some(now + recovery);
244 let reduced = (base_rate as f64 * reduction) as u64;
246 state.current_rate = reduced.max(1);
247 }
248
249 pub fn clear_anomaly(&mut self, entity_id: &str) {
251 if let Some(state) = self.states.get_mut(entity_id) {
252 state.anomaly_detected = false;
253 state.anomaly_expires = None;
254 state.current_rate = self.config.base_rate_per_minute;
255 }
256 }
257
258 pub fn get_state(&self, entity_id: &str) -> Option<&RateState> {
260 self.states.get(entity_id)
261 }
262
263 pub fn prune_stale(&mut self) {
268 let now = Instant::now();
269 let staleness_threshold = self.config.window_size * 2;
270 self.states.retain(|_, state| {
271 if state.anomaly_detected {
273 return true;
274 }
275 now.duration_since(state.window_start) < staleness_threshold
277 });
278 }
279
280 pub fn stats(&self) -> RateLimiterStats {
282 let mut total_requests: u64 = 0;
283 let mut total_denied: u64 = 0;
284 let mut anomaly_count: usize = 0;
285
286 for state in self.states.values() {
287 total_requests = total_requests.saturating_add(state.total_requests);
288 total_denied = total_denied.saturating_add(state.total_denied);
289 if state.anomaly_detected {
290 anomaly_count = anomaly_count.saturating_add(1);
291 }
292 }
293
294 RateLimiterStats {
295 active_entities: self.states.len(),
296 total_requests,
297 total_denied,
298 anomaly_count,
299 }
300 }
301}
302
/// Clamps `val` into `[min, max]`, mapping NaN and ±infinity to `min`.
///
/// Unlike `f64::clamp`, this never panics when `min > max`: the lower bound
/// is tested first, so such inputs below `min` yield `min`.
fn clamp_f64(val: f64, min: f64, max: f64) -> f64 {
    match val {
        // Non-finite input falls back to the lower bound rather than saturating.
        v if !v.is_finite() => min,
        v if v < min => min,
        v if v > max => max,
        v => v,
    }
}
316
// Unit tests for AdaptiveRateLimiter.
//
// Most tests use `fast_limiter`, whose 100 ms window and 50 ms recovery
// period let the tests cross time boundaries with short sleeps. Tests that
// issue several checks in a row assume all of them land inside a single
// 100 ms window.
#[cfg(test)]
mod tests {
    use super::*;

    // Limiter with the production defaults (60 s window, 300 s recovery).
    fn default_limiter() -> AdaptiveRateLimiter {
        AdaptiveRateLimiter::new(AdaptiveRateConfig::default())
    }

    // Limiter with short windows for time-based tests: burst ceiling is
    // 2x the base rate, anomalies halve the rate for 50 ms, and the
    // counting window is 100 ms.
    fn fast_limiter(base_rate: u64) -> AdaptiveRateLimiter {
        AdaptiveRateLimiter::new(AdaptiveRateConfig {
            base_rate_per_minute: base_rate,
            burst_multiplier: 2.0,
            anomaly_reduction_factor: 0.5,
            recovery_period: Duration::from_millis(50),
            window_size: Duration::from_millis(100),
        })
    }

    // Up to the base rate, every request is allowed.
    #[test]
    fn test_adaptive_rate_allow_under_limit() {
        let mut limiter = fast_limiter(10);
        for _ in 0..10 {
            assert_eq!(limiter.check("agent-1"), RateDecision::Allow);
        }
    }

    // The 6th request exceeds the base rate (5) but not the burst ceiling (10).
    #[test]
    fn test_adaptive_rate_throttle_above_base() {
        let mut limiter = fast_limiter(5);
        for _ in 0..5 {
            assert_eq!(limiter.check("agent-1"), RateDecision::Allow);
        }
        assert_eq!(limiter.check("agent-1"), RateDecision::Throttle);
    }

    // Burst ceiling is 5 * 2.0 = 10, so the 11th request is denied.
    #[test]
    fn test_adaptive_rate_deny_above_burst() {
        let mut limiter = fast_limiter(5);
        for _ in 0..10 {
            let _ = limiter.check("agent-1");
        }
        assert_eq!(limiter.check("agent-1"), RateDecision::Deny);
    }

    // Sleeping past the 100 ms window resets the per-window count.
    #[test]
    fn test_adaptive_rate_window_reset() {
        let mut limiter = fast_limiter(5);
        for _ in 0..5 {
            let _ = limiter.check("agent-1");
        }
        std::thread::sleep(Duration::from_millis(120));
        assert_eq!(limiter.check("agent-1"), RateDecision::Allow);
    }

    // With a 0.5 reduction factor, the rate drops from 10 to 5.
    #[test]
    fn test_adaptive_rate_anomaly_reduces_rate() {
        let mut limiter = fast_limiter(10);
        limiter.signal_anomaly("agent-1");
        let state = limiter.get_state("agent-1").unwrap();
        assert!(state.anomaly_detected);
        assert_eq!(state.current_rate, 5);
    }

    // After the 50 ms recovery period, the next check lifts the anomaly and
    // restores the base rate (expiry is applied lazily inside `check`).
    #[test]
    fn test_adaptive_rate_anomaly_expires() {
        let mut limiter = fast_limiter(10);
        limiter.signal_anomaly("agent-1");
        std::thread::sleep(Duration::from_millis(60));
        let _ = limiter.check("agent-1");
        let state = limiter.get_state("agent-1").unwrap();
        assert!(!state.anomaly_detected);
        assert_eq!(state.current_rate, 10);
    }

    // Manually clearing an anomaly restores the base rate immediately.
    #[test]
    fn test_adaptive_rate_clear_anomaly() {
        let mut limiter = fast_limiter(10);
        limiter.signal_anomaly("agent-1");
        assert!(limiter.get_state("agent-1").unwrap().anomaly_detected);
        limiter.clear_anomaly("agent-1");
        let state = limiter.get_state("agent-1").unwrap();
        assert!(!state.anomaly_detected);
        assert_eq!(state.current_rate, 10);
    }

    // Clearing an unknown entity must not start tracking it.
    #[test]
    fn test_adaptive_rate_clear_anomaly_nonexistent() {
        let mut limiter = default_limiter();
        limiter.clear_anomaly("ghost");
        assert!(limiter.get_state("ghost").is_none());
    }

    #[test]
    fn test_adaptive_rate_get_state_none() {
        let limiter = default_limiter();
        assert!(limiter.get_state("unknown").is_none());
    }

    // Entities are counted independently: "a" reaches Throttle on its 6th
    // request while "b" (2 requests) is still allowed.
    #[test]
    fn test_adaptive_rate_multiple_entities() {
        let mut limiter = fast_limiter(5);
        assert_eq!(limiter.check("a"), RateDecision::Allow);
        assert_eq!(limiter.check("b"), RateDecision::Allow);
        for _ in 0..4 {
            let _ = limiter.check("a");
        }
        assert_eq!(limiter.check("a"), RateDecision::Throttle);
        assert_eq!(limiter.check("b"), RateDecision::Allow);
    }

    // Stats sum lifetime requests across entities; nothing is denied under
    // a base rate of 100.
    #[test]
    fn test_adaptive_rate_stats_basic() {
        let mut limiter = fast_limiter(100);
        for _ in 0..10 {
            let _ = limiter.check("e1");
        }
        for _ in 0..5 {
            let _ = limiter.check("e2");
        }
        let stats = limiter.stats();
        assert_eq!(stats.active_entities, 2);
        assert_eq!(stats.total_requests, 15);
        assert_eq!(stats.total_denied, 0);
        assert_eq!(stats.anomaly_count, 0);
    }

    // `signal_anomaly` on an unknown entity starts tracking it as anomalous.
    #[test]
    fn test_adaptive_rate_stats_with_anomaly() {
        let mut limiter = fast_limiter(100);
        let _ = limiter.check("e1");
        limiter.signal_anomaly("e2");
        let stats = limiter.stats();
        assert_eq!(stats.anomaly_count, 1);
    }

    // 210 ms exceeds the staleness threshold of 2 x 100 ms window.
    #[test]
    fn test_adaptive_rate_prune_stale() {
        let mut limiter = fast_limiter(100);
        let _ = limiter.check("stale");
        std::thread::sleep(Duration::from_millis(210));
        limiter.prune_stale();
        assert!(limiter.get_state("stale").is_none());
    }

    // Anomalous entities survive pruning. The 50 ms recovery period has lapsed
    // here, but the flag is only cleared by `check`, which is never called.
    #[test]
    fn test_adaptive_rate_prune_keeps_anomaly() {
        let mut limiter = fast_limiter(100);
        limiter.signal_anomaly("anomalous");
        std::thread::sleep(Duration::from_millis(210));
        limiter.prune_stale();
        assert!(limiter.get_state("anomalous").is_some());
    }

    // Counters pinned at u64::MAX saturate instead of overflowing.
    #[test]
    fn test_adaptive_rate_saturating_counters() {
        let mut limiter = fast_limiter(u64::MAX);
        let _ = limiter.check("saturate");
        let state = limiter.states.get_mut("saturate").unwrap();
        state.total_requests = u64::MAX;
        state.request_count = u64::MAX;
        let _ = limiter.check("saturate");
        let state = limiter.get_state("saturate").unwrap();
        assert_eq!(state.total_requests, u64::MAX);
    }

    // A 0.0 reduction factor would give a rate of 0; it is floored at 1.
    #[test]
    fn test_adaptive_rate_anomaly_minimum_rate_one() {
        let mut limiter = AdaptiveRateLimiter::new(AdaptiveRateConfig {
            base_rate_per_minute: 10,
            burst_multiplier: 2.0,
            anomaly_reduction_factor: 0.0,
            recovery_period: Duration::from_secs(60),
            window_size: Duration::from_secs(60),
        });
        limiter.signal_anomaly("agent");
        let state = limiter.get_state("agent").unwrap();
        assert_eq!(state.current_rate, 1);
    }

    // A NaN multiplier fails validation, so `new` falls back to the whole
    // default config (base rate 100, not the supplied 10); nothing is clamped.
    #[test]
    fn test_adaptive_rate_clamp_nan_multiplier() {
        let limiter = AdaptiveRateLimiter::new(AdaptiveRateConfig {
            base_rate_per_minute: 10,
            burst_multiplier: f64::NAN,
            anomaly_reduction_factor: 0.5,
            recovery_period: Duration::from_secs(60),
            window_size: Duration::from_secs(60),
        });
        assert_eq!(limiter.config.base_rate_per_minute, 100);
        assert!((limiter.config.burst_multiplier - 2.0).abs() < f64::EPSILON);
    }

    // Pins the documented default configuration values.
    #[test]
    fn test_adaptive_rate_default_config() {
        let config = AdaptiveRateConfig::default();
        assert_eq!(config.base_rate_per_minute, 100);
        assert!((config.burst_multiplier - 2.0).abs() < f64::EPSILON);
        assert!((config.anomaly_reduction_factor - 0.5).abs() < f64::EPSILON);
        assert_eq!(config.recovery_period, Duration::from_secs(300));
        assert_eq!(config.window_size, Duration::from_secs(60));
    }
}