1use std::collections::HashMap;
17use std::sync::Mutex;
18use std::time::Instant;
19
20use serde::{Deserialize, Serialize};
21
/// Rate-limit settings shared by every actor's bucket.
///
/// The container-level `#[serde(default)]` means any field missing from the
/// serialized form takes its `Default` value (zero), and an all-zero config is
/// reported as disabled by [`RateLimitConfig::enabled`].
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct RateLimitConfig {
    /// Sustained refill rate in requests per second; `0` disables limiting.
    pub requests_per_second: u32,
    /// Bucket capacity (largest admissible burst); `0` means "use
    /// `requests_per_second` as the capacity".
    pub burst: u32,
}
34
35impl RateLimitConfig {
36 #[must_use]
38 pub fn enabled(&self) -> bool {
39 self.requests_per_second > 0
40 }
41
42 fn capacity(&self) -> u32 {
44 if self.burst > 0 {
45 self.burst
46 } else {
47 self.requests_per_second
48 }
49 }
50
51 pub fn apply_env_overrides(&mut self) -> Result<(), String> {
57 for (key, slot) in [
58 (
59 "QUIVER_RATE_LIMIT_REQUESTS_PER_SECOND",
60 &mut self.requests_per_second,
61 ),
62 ("QUIVER_RATE_LIMIT_BURST", &mut self.burst),
63 ] {
64 if let Ok(raw) = std::env::var(key) {
65 *slot = raw
66 .parse()
67 .map_err(|_| format!("{key} must be a non-negative integer, got {raw:?}"))?;
68 }
69 }
70 Ok(())
71 }
72}
73
/// Bucket state reported alongside an admitted request.
///
/// When the limiter is disabled every field is `0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitSnapshot {
    /// Bucket capacity in requests (the configured burst, or the per-second
    /// rate when no burst is configured).
    pub limit: u32,
    /// Whole tokens left in the bucket after this request was charged.
    pub remaining: u32,
    /// Seconds, rounded up, until the bucket is full again at the configured
    /// refill rate.
    pub reset_secs: u64,
}
84
/// Outcome of a rate-limit check for one request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateDecision {
    /// The request is admitted; the snapshot describes the bucket afterwards.
    Allowed(RateLimitSnapshot),
    /// The request is rejected because less than one token is available.
    Limited {
        /// Seconds, rounded up and never less than `1`, until one full token
        /// will have been refilled.
        retry_after_secs: u64,
        /// Bucket capacity in requests.
        limit: u32,
    },
}
98
/// Token-bucket state for a single actor.
struct Bucket {
    /// Tokens currently available; fractional because refill is proportional
    /// to elapsed time.
    tokens: f64,
    /// Instant up to which refill has already been credited.
    last: Instant,
}
104
/// Token-bucket rate limiter that keeps one bucket per actor key.
pub struct RateLimiter {
    /// Rate and capacity applied to every bucket.
    config: RateLimitConfig,
    /// Buckets keyed by actor, created on an actor's first check.
    // NOTE(review): nothing visible in this file removes entries, so the map
    // grows with the number of distinct actors — confirm that set is bounded.
    buckets: Mutex<HashMap<String, Bucket>>,
}
112
113impl RateLimiter {
114 #[must_use]
116 pub fn new(config: RateLimitConfig) -> Self {
117 Self {
118 config,
119 buckets: Mutex::new(HashMap::new()),
120 }
121 }
122
123 #[must_use]
125 pub fn enabled(&self) -> bool {
126 self.config.enabled()
127 }
128
129 #[must_use]
131 pub fn check(&self, actor: &str) -> RateDecision {
132 self.check_at(actor, Instant::now())
133 }
134
135 fn check_at(&self, actor: &str, now: Instant) -> RateDecision {
137 if !self.config.enabled() {
138 return RateDecision::Allowed(RateLimitSnapshot {
139 limit: 0,
140 remaining: 0,
141 reset_secs: 0,
142 });
143 }
144 let capacity = f64::from(self.config.capacity());
145 let rate = f64::from(self.config.requests_per_second);
146 let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
148 let bucket = buckets.entry(actor.to_owned()).or_insert(Bucket {
149 tokens: capacity,
150 last: now,
151 });
152 let elapsed = now.saturating_duration_since(bucket.last).as_secs_f64();
154 bucket.tokens = (bucket.tokens + elapsed * rate).min(capacity);
155 bucket.last = now;
156
157 let limit = self.config.capacity();
158 if bucket.tokens >= 1.0 {
159 bucket.tokens -= 1.0;
160 let reset_secs = ((capacity - bucket.tokens) / rate).ceil() as u64;
162 RateDecision::Allowed(RateLimitSnapshot {
163 limit,
164 remaining: bucket.tokens as u32,
165 reset_secs,
166 })
167 } else {
168 let retry_after_secs = ((1.0 - bucket.tokens) / rate).ceil().max(1.0) as u64;
170 RateDecision::Limited {
171 retry_after_secs,
172 limit,
173 }
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a config from a refill rate and a burst size.
    fn cfg(rps: u32, burst: u32) -> RateLimitConfig {
        RateLimitConfig {
            requests_per_second: rps,
            burst,
        }
    }

    /// `true` when the decision admits the request.
    fn admitted(decision: RateDecision) -> bool {
        matches!(decision, RateDecision::Allowed(_))
    }

    /// `true` when the decision rejects the request.
    fn limited(decision: RateDecision) -> bool {
        matches!(decision, RateDecision::Limited { .. })
    }

    #[test]
    fn disabled_limiter_always_admits() {
        let limiter = RateLimiter::new(cfg(0, 0));
        assert!(!limiter.enabled());
        assert!((0..1000).all(|_| admitted(limiter.check("k"))));
    }

    #[test]
    fn burst_is_allowed_then_the_next_request_is_limited() {
        let limiter = RateLimiter::new(cfg(10, 3));
        let start = Instant::now();

        // The three burst tokens are spent one by one at the same instant.
        for expected_remaining in [2, 1, 0] {
            let RateDecision::Allowed(snapshot) = limiter.check_at("k", start) else {
                panic!("burst should be admitted");
            };
            assert_eq!(snapshot.limit, 3);
            assert_eq!(snapshot.remaining, expected_remaining);
        }

        let RateDecision::Limited {
            retry_after_secs,
            limit,
        } = limiter.check_at("k", start)
        else {
            panic!("4th request in a burst of 3 must be limited");
        };
        assert_eq!(limit, 3);
        assert!(retry_after_secs >= 1);
    }

    #[test]
    fn tokens_refill_at_the_configured_rate() {
        let limiter = RateLimiter::new(cfg(2, 2));
        let start = Instant::now();
        let one_second_later = start + Duration::from_secs(1);

        // Two tokens at each instant: both are admitted, the third is not.
        for at in [start, one_second_later] {
            assert!(admitted(limiter.check_at("k", at)));
            assert!(admitted(limiter.check_at("k", at)));
            assert!(limited(limiter.check_at("k", at)));
        }
    }

    #[test]
    fn keys_have_independent_buckets() {
        let limiter = RateLimiter::new(cfg(5, 1));
        let start = Instant::now();
        assert!(admitted(limiter.check_at("a", start)));
        assert!(limited(limiter.check_at("a", start)));
        // Exhausting "a" must not affect "b".
        assert!(admitted(limiter.check_at("b", start)));
    }

    #[test]
    fn burst_defaults_to_the_per_second_rate() {
        let limiter = RateLimiter::new(cfg(4, 0));
        let start = Instant::now();
        assert!((0..4).all(|_| admitted(limiter.check_at("k", start))));
        assert!(limited(limiter.check_at("k", start)));
    }

    #[test]
    fn env_overrides_parse_and_reject_garbage() {
        const RPS_VAR: &str = "QUIVER_RATE_LIMIT_REQUESTS_PER_SECOND";
        const BURST_VAR: &str = "QUIVER_RATE_LIMIT_BURST";

        let mut config = RateLimitConfig::default();

        // SAFETY: mutating the process environment is only sound while no other
        // thread accesses it. NOTE(review): assumes no other test in this binary
        // reads or writes the environment concurrently — confirm.
        unsafe {
            std::env::set_var(RPS_VAR, "25");
            std::env::set_var(BURST_VAR, "50");
        }
        config.apply_env_overrides().unwrap();
        assert_eq!(config.requests_per_second, 25);
        assert_eq!(config.burst, 50);

        // SAFETY: see the first block above.
        unsafe {
            std::env::set_var(BURST_VAR, "lots");
        }
        assert!(config.apply_env_overrides().is_err());

        // SAFETY: see the first block above.
        unsafe {
            std::env::remove_var(RPS_VAR);
            std::env::remove_var(BURST_VAR);
        }
    }
}