keyhog_verifier/
rate_limit.rs1use dashmap::DashMap;
11use parking_lot::Mutex;
12use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
13use std::time::{Duration, Instant};
14
/// Throttling state kept for a single service.
struct ServiceLimit {
    /// Minimum spacing between two consecutive requests to the service.
    interval: Duration,
    /// Most recent request slot handed out for the service. This can lie in
    /// the future when a caller has claimed a slot it is still sleeping towards.
    last_request: Instant,
}
19
20pub struct RateLimiter {
21 services: DashMap<String, Mutex<ServiceLimit>>,
22 default_interval_nanos: AtomicU64,
26 global_error_count: AtomicUsize,
27}
28
29impl RateLimiter {
30 pub fn new(rps: f64) -> Self {
31 Self {
32 services: DashMap::new(),
33 default_interval_nanos: AtomicU64::new(rps_to_nanos(rps)),
34 global_error_count: AtomicUsize::new(0),
35 }
36 }
37
38 pub fn set_default_rps(&self, rps: f64) {
45 self.default_interval_nanos
46 .store(rps_to_nanos(rps), Ordering::Relaxed);
47 }
48
49 fn default_interval(&self) -> Duration {
51 Duration::from_nanos(self.default_interval_nanos.load(Ordering::Relaxed))
52 }
53
54 pub async fn wait(&self, service: &str) {
55 let bp = if self.global_error_count.load(Ordering::Relaxed) > 50 {
56 Duration::from_secs(1)
57 } else {
58 Duration::from_millis(0)
59 };
60 let wait_time = {
61 let default = self.default_interval();
62 let entry = self.services.entry(service.to_string()).or_insert_with(|| {
63 Mutex::new(ServiceLimit {
64 last_request: Instant::now() - default,
65 interval: default,
66 })
67 });
68 let mut limit = entry.value().lock();
69 let now = Instant::now();
70 let elapsed = now.duration_since(limit.last_request);
71 if elapsed < limit.interval {
72 let wait = limit.interval - elapsed;
73 limit.last_request = now + wait;
74 Some(wait)
75 } else {
76 limit.last_request = now;
77 None
78 }
79 };
80 if let Some(wait) = wait_time {
81 tokio::time::sleep(wait.max(bp)).await;
82 }
83 }
84
85 pub fn record_error(&self) {
86 self.global_error_count.fetch_add(1, Ordering::Relaxed);
87 }
88
89 pub fn record_success(&self) {
90 let _ = self
91 .global_error_count
92 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
93 Some(n.saturating_sub(1))
94 });
95 }
96
97 pub async fn update_limit(&self, service: &str, rps: f64) {
98 let interval = Duration::from_nanos(rps_to_nanos(rps));
99 self.services.insert(
100 service.to_string(),
101 Mutex::new(ServiceLimit {
102 last_request: Instant::now(),
103 interval,
104 }),
105 );
106 }
107}
108
/// Converts a requests-per-second rate into the spacing between requests,
/// in nanoseconds.
///
/// * A non-finite or non-positive `rps` is treated as 1 request per second.
/// * Rates too fast to space by even 1 ns are clamped to 1 ns, so a huge rate
///   never turns into a slow one.
/// * Rates so slow that the interval does not fit in a `u64` fall back to
///   1 second.
fn rps_to_nanos(rps: f64) -> u64 {
    const ONE_SECOND_NANOS: u64 = 1_000_000_000;

    if !rps.is_finite() || rps <= 0.0 {
        return ONE_SECOND_NANOS;
    }
    // Without `max(1.0)`, any rate of 2e9 rps or more rounds to a 0 ns interval.
    let nanos = (1.0e9 / rps).round().max(1.0);
    // `u64::MAX as f64` rounds up to 2^64, which itself does not fit in a u64.
    if nanos.is_finite() && nanos < u64::MAX as f64 {
        nanos as u64
    } else {
        ONE_SECOND_NANOS
    }
}
118
119use std::sync::OnceLock;
120pub static GLOBAL_RATE_LIMITER: OnceLock<RateLimiter> = OnceLock::new();
121
122pub fn get_rate_limiter() -> &'static RateLimiter {
125 GLOBAL_RATE_LIMITER.get_or_init(|| RateLimiter::new(5.0))
126}
127
128pub fn set_global_default_rps(rps: f64) {
132 get_rate_limiter().set_default_rps(rps);
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Interval, in nanoseconds, used for invalid rates.
    const ONE_SECOND_NANOS: u64 = 1_000_000_000;

    #[test]
    fn rps_to_nanos_clamps_invalid_input() {
        for &bad in &[0.0, -1.0, f64::NAN, f64::INFINITY] {
            assert_eq!(rps_to_nanos(bad), ONE_SECOND_NANOS, "rps = {}", bad);
        }
    }

    #[test]
    fn rps_to_nanos_typical_rates() {
        let cases: [(f64, u64); 3] = [
            (1.0, 1_000_000_000),
            (5.0, 200_000_000),
            (100.0, 10_000_000),
        ];
        for &(rps, expected) in &cases {
            assert_eq!(rps_to_nanos(rps), expected, "rps = {}", rps);
        }
    }

    #[test]
    fn set_default_rps_updates_atomically() {
        let limiter = RateLimiter::new(5.0);
        assert_eq!(limiter.default_interval(), Duration::from_millis(200));

        limiter.set_default_rps(20.0);
        assert_eq!(limiter.default_interval(), Duration::from_millis(50));
    }
}