Skip to main content

keyhog_verifier/
rate_limit.rs

1//! Per-service rate limiting for verification requests.
2//!
3//! `RateLimiter` enforces a minimum inter-request interval per service
4//! (token-bucket-style with a 1-token bucket). Per-service entries can
5//! override the default interval via [`RateLimiter::update_limit`]; the
6//! default interval is hot-swappable at runtime via
7//! [`RateLimiter::set_default_rps`] so the CLI's `--verify-rate` flag
8//! can take effect after the global limiter has already been
9//! lazily initialised by an earlier call site.
10use dashmap::DashMap;
11use parking_lot::Mutex;
12use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
13use std::time::{Duration, Instant};
14
/// Pacing state tracked for a single service.
struct ServiceLimit {
    /// Minimum spacing enforced between two consecutive requests.
    interval: Duration,
    /// Reference point the next request is measured against: the time of the
    /// most recent request, or a later instant when a waiter has already
    /// been scheduled.
    last_request: Instant,
}
19
20pub struct RateLimiter {
21    services: DashMap<String, Mutex<ServiceLimit>>,
22    /// Default inter-request interval, in nanoseconds. Atomic so the
23    /// CLI can adjust the global limiter's pace after construction
24    /// without having to thread a setter through every caller.
25    default_interval_nanos: AtomicU64,
26    global_error_count: AtomicUsize,
27}
28
29impl RateLimiter {
30    pub fn new(rps: f64) -> Self {
31        Self {
32            services: DashMap::new(),
33            default_interval_nanos: AtomicU64::new(rps_to_nanos(rps)),
34            global_error_count: AtomicUsize::new(0),
35        }
36    }
37
38    /// Replace the default per-service interval. Existing per-service
39    /// entries created via [`Self::update_limit`] are left at their
40    /// override; only the lazily-created defaults pick up the new pace.
41    /// Non-finite or non-positive `rps` falls back to 1.0 — the same
42    /// guard as `new()` so a caller can't drive the limiter into a
43    /// zero-interval (= infinite-rate) state by accident.
44    pub fn set_default_rps(&self, rps: f64) {
45        self.default_interval_nanos
46            .store(rps_to_nanos(rps), Ordering::Relaxed);
47    }
48
49    /// Default interval as a `Duration`. Lock-free.
50    fn default_interval(&self) -> Duration {
51        Duration::from_nanos(self.default_interval_nanos.load(Ordering::Relaxed))
52    }
53
54    pub async fn wait(&self, service: &str) {
55        let bp = if self.global_error_count.load(Ordering::Relaxed) > 50 {
56            Duration::from_secs(1)
57        } else {
58            Duration::from_millis(0)
59        };
60        let wait_time = {
61            let default = self.default_interval();
62            let entry = self.services.entry(service.to_string()).or_insert_with(|| {
63                Mutex::new(ServiceLimit {
64                    last_request: Instant::now() - default,
65                    interval: default,
66                })
67            });
68            let mut limit = entry.value().lock();
69            let now = Instant::now();
70            let elapsed = now.duration_since(limit.last_request);
71            if elapsed < limit.interval {
72                let wait = limit.interval - elapsed;
73                limit.last_request = now + wait;
74                Some(wait)
75            } else {
76                limit.last_request = now;
77                None
78            }
79        };
80        if let Some(wait) = wait_time {
81            tokio::time::sleep(wait.max(bp)).await;
82        }
83    }
84
85    pub fn record_error(&self) {
86        self.global_error_count.fetch_add(1, Ordering::Relaxed);
87    }
88
89    pub fn record_success(&self) {
90        let _ = self
91            .global_error_count
92            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
93                Some(n.saturating_sub(1))
94            });
95    }
96
97    pub async fn update_limit(&self, service: &str, rps: f64) {
98        let interval = Duration::from_nanos(rps_to_nanos(rps));
99        self.services.insert(
100            service.to_string(),
101            Mutex::new(ServiceLimit {
102                last_request: Instant::now(),
103                interval,
104            }),
105        );
106    }
107}
108
/// Convert a requests-per-second rate into an inter-request interval in
/// nanoseconds.
///
/// A non-finite or non-positive `rps` is treated as 1.0. The interval is
/// rounded to the nearest nanosecond; if it rounds below 1 ns or does not
/// fit the `u64` range, the result is one second.
fn rps_to_nanos(rps: f64) -> u64 {
    const NANOS_PER_SECOND: f64 = 1.0e9;
    const FALLBACK_NANOS: u64 = 1_000_000_000; // 1s fallback for absurd inputs

    let rate = match rps {
        r if r.is_finite() && r > 0.0 => r,
        _ => 1.0,
    };
    let interval = (NANOS_PER_SECOND / rate).round();
    let representable = interval.is_finite() && (1.0..=u64::MAX as f64).contains(&interval);
    if !representable {
        return FALLBACK_NANOS;
    }
    interval as u64
}
118
119use std::sync::OnceLock;
120pub static GLOBAL_RATE_LIMITER: OnceLock<RateLimiter> = OnceLock::new();
121
122/// Lazily create the process-wide rate limiter at the default 5 rps.
123/// Use [`set_global_default_rps`] to retune after init.
124pub fn get_rate_limiter() -> &'static RateLimiter {
125    GLOBAL_RATE_LIMITER.get_or_init(|| RateLimiter::new(5.0))
126}
127
128/// Convenience setter the CLI calls once at startup to apply the
129/// `--verify-rate` flag. Idempotent; safe to call before or after the
130/// limiter has been lazily initialised.
131pub fn set_global_default_rps(rps: f64) {
132    get_rate_limiter().set_default_rps(rps);
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Interval produced for 1 rps, which is also the fallback value.
    const ONE_SECOND_NANOS: u64 = 1_000_000_000;

    #[test]
    fn rps_to_nanos_clamps_invalid_input() {
        // Zero, negative and non-finite rates all behave like 1 rps.
        for invalid in [0.0, -1.0, f64::NAN, f64::INFINITY] {
            assert_eq!(rps_to_nanos(invalid), ONE_SECOND_NANOS);
        }
    }

    #[test]
    fn rps_to_nanos_typical_rates() {
        let expectations: [(f64, u64); 3] = [
            (1.0, 1_000_000_000),
            (5.0, 200_000_000),
            (100.0, 10_000_000),
        ];
        for (rps, nanos) in expectations {
            assert_eq!(rps_to_nanos(rps), nanos);
        }
    }

    #[test]
    fn set_default_rps_updates_atomically() {
        let limiter = RateLimiter::new(5.0);
        let before = limiter.default_interval();
        assert_eq!(before, Duration::from_millis(200));

        limiter.set_default_rps(20.0);
        let after = limiter.default_interval();
        assert_eq!(after, Duration::from_millis(50));
    }
}