quantrs2_device/security/
rate_limit.rs1use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
/// A token bucket for a single rate-limited resource.
///
/// The bucket holds at most `capacity` tokens and regains them continuously at
/// `refill_rate` tokens per second. Refilling is lazy: it happens inside each
/// `&mut self` query/consume method, crediting the time elapsed on the
/// monotonic [`Instant`] clock since the previous refill.
pub struct TokenBucket {
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens currently available; clamped to `capacity` on every refill.
    tokens: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// Time of the last refill; the next refill credits time elapsed since then.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full (`capacity` tokens available).
    ///
    /// `refill_rate` is in tokens per second. No validation is performed on
    /// either argument; with a rate of zero the bucket never refills.
    pub fn new(capacity: f64, refill_rate: f64) -> Self {
        Self {
            capacity,
            tokens: capacity,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Credits tokens for the time elapsed since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        // Read the clock once so that the interval credited here and the new
        // `last_refill` agree exactly; reading it twice would drop the time
        // between the two reads and make the bucket refill slightly too slowly.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
        self.last_refill = now;
    }

    /// Refills, then removes `tokens` if at least that many are available.
    ///
    /// Returns `true` on success. On failure no tokens are removed.
    pub fn try_consume(&mut self, tokens: f64) -> bool {
        self.refill();
        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }

    /// Returns how long to wait until `tokens` tokens will be available.
    ///
    /// Returns [`Duration::ZERO`] if they are available now. Returns
    /// [`Duration::MAX`] if the request can never be satisfied: either it
    /// exceeds `capacity`, or the refill rate is not positive, or the wait is
    /// too long to represent as a `Duration`.
    pub fn wait_time(&mut self, tokens: f64) -> Duration {
        self.refill();
        if self.tokens >= tokens {
            return Duration::ZERO;
        }
        if tokens > self.capacity {
            // The bucket never holds more than `capacity`, so no wait helps.
            return Duration::MAX;
        }
        let needed = tokens - self.tokens;
        // A zero/negative/NaN rate makes this quotient inf, negative or NaN;
        // `try_from_secs_f64` rejects those (and overflowing values) instead of
        // panicking like `from_secs_f64` would.
        Duration::try_from_secs_f64(needed / self.refill_rate).unwrap_or(Duration::MAX)
    }

    /// Refills, then returns the number of tokens currently available.
    pub fn available_tokens(&mut self) -> f64 {
        self.refill();
        self.tokens
    }

    /// Maximum number of tokens the bucket can hold.
    pub fn capacity(&self) -> f64 {
        self.capacity
    }

    /// Refill rate in tokens per second.
    pub fn refill_rate(&self) -> f64 {
        self.refill_rate
    }
}

/// Shows capacity, current token count and refill rate; `last_refill` is omitted.
///
/// The token count shown is the value as of the last refill, because
/// formatting takes `&self` and cannot refill.
impl std::fmt::Debug for TokenBucket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenBucket")
            .field("capacity", &self.capacity)
            .field("tokens", &self.tokens)
            .field("refill_rate", &self.refill_rate)
            .finish()
    }
}
97
98pub struct RateLimiter {
119 buckets: HashMap<String, TokenBucket>,
120 default_capacity: f64,
121 default_rate: f64,
122}
123
124impl RateLimiter {
125 pub fn new(default_capacity: f64, default_rate_per_second: f64) -> Self {
131 Self {
132 buckets: HashMap::new(),
133 default_capacity,
134 default_rate: default_rate_per_second,
135 }
136 }
137
138 pub fn with_provider(mut self, provider: impl Into<String>, capacity: f64, rate: f64) -> Self {
140 self.buckets
141 .insert(provider.into(), TokenBucket::new(capacity, rate));
142 self
143 }
144
145 pub fn try_consume(&mut self, provider: &str) -> bool {
150 let (cap, rate) = (self.default_capacity, self.default_rate);
151 let bucket = self
152 .buckets
153 .entry(provider.to_string())
154 .or_insert_with(|| TokenBucket::new(cap, rate));
155 bucket.try_consume(1.0)
156 }
157
158 pub fn wait_time(&mut self, provider: &str) -> Duration {
162 let (cap, rate) = (self.default_capacity, self.default_rate);
163 let bucket = self
164 .buckets
165 .entry(provider.to_string())
166 .or_insert_with(|| TokenBucket::new(cap, rate));
167 bucket.wait_time(1.0)
168 }
169
170 pub fn available_tokens(&mut self, provider: &str) -> f64 {
174 let (cap, rate) = (self.default_capacity, self.default_rate);
175 let bucket = self
176 .buckets
177 .entry(provider.to_string())
178 .or_insert_with(|| TokenBucket::new(cap, rate));
179 bucket.available_tokens()
180 }
181
182 pub fn with_cloud_defaults() -> Self {
192 Self::new(10.0, 1.0)
193 .with_provider("ibm", 5.0, 5.0 / 60.0)
194 .with_provider("aws", 10.0, 10.0)
195 .with_provider("azure", 10.0, 10.0)
196 }
197
198 pub fn tracked_providers(&self) -> Vec<&str> {
200 self.buckets.keys().map(|s| s.as_str()).collect()
201 }
202}
203
204impl std::fmt::Debug for RateLimiter {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 f.debug_struct("RateLimiter")
207 .field("providers", &self.buckets.keys().collect::<Vec<_>>())
208 .field("default_capacity", &self.default_capacity)
209 .field("default_rate", &self.default_rate)
210 .finish()
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Pretends `by` has elapsed since the bucket's last refill, so refill
    /// behaviour can be tested deterministically without sleeping.
    fn rewind(bucket: &mut TokenBucket, by: Duration) {
        bucket.last_refill = Instant::now()
            .checked_sub(by)
            .expect("monotonic clock should have enough history to rewind");
    }

    #[test]
    fn test_token_bucket_starts_full() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert!((bucket.available_tokens() - 10.0).abs() < 1e-9);
    }

    #[test]
    fn test_token_bucket_consume_success() {
        let mut bucket = TokenBucket::new(5.0, 1.0);
        assert!(bucket.try_consume(3.0));
        assert!(bucket.available_tokens() < 3.0);
    }

    #[test]
    fn test_token_bucket_consume_fails_when_empty() {
        let mut bucket = TokenBucket::new(3.0, 0.001);
        assert!(bucket.try_consume(3.0));
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_bucket_wait_time_zero_when_full() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        let wait = bucket.wait_time(1.0);
        assert_eq!(wait, Duration::ZERO);
    }

    #[test]
    fn test_token_bucket_wait_time_nonzero_when_empty() {
        let mut bucket = TokenBucket::new(3.0, 0.001);
        assert!(bucket.try_consume(3.0));
        let wait = bucket.wait_time(1.0);
        assert!(wait > Duration::ZERO);
    }

    #[test]
    fn test_token_bucket_refills_over_time() {
        let mut bucket = TokenBucket::new(10.0, 4.0);
        assert!(bucket.try_consume(10.0));
        // 0.5 s at 4 tokens/s should credit about 2 tokens.
        rewind(&mut bucket, Duration::from_millis(500));
        let tokens = bucket.available_tokens();
        assert!((tokens - 2.0).abs() < 0.1, "tokens = {tokens}");
    }

    #[test]
    fn test_token_bucket_capacity_ceiling() {
        let mut bucket = TokenBucket::new(5.0, 100.0);
        assert!(bucket.try_consume(5.0));
        // 0.5 s at 100 tokens/s would credit ~50 tokens; the bucket must cap at 5.
        rewind(&mut bucket, Duration::from_millis(500));
        let tokens = bucket.available_tokens();
        assert!((tokens - 5.0).abs() < 1e-9, "tokens = {tokens}");
    }

    #[test]
    fn test_token_bucket_accessors() {
        let bucket = TokenBucket::new(10.0, 2.5);
        assert!((bucket.capacity() - 10.0).abs() < 1e-9);
        assert!((bucket.refill_rate() - 2.5).abs() < 1e-9);
    }

    #[test]
    fn test_rate_limiter_new_provider_gets_defaults() {
        let mut limiter = RateLimiter::new(5.0, 1.0);
        let tokens = limiter.available_tokens("unknown_provider");
        assert!((tokens - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_rate_limiter_try_consume_success() {
        let mut limiter = RateLimiter::new(10.0, 1.0);
        assert!(limiter.try_consume("aws"));
    }

    #[test]
    fn test_rate_limiter_exhaustion() {
        let mut limiter = RateLimiter::new(3.0, 0.001);
        for _ in 0..3 {
            assert!(limiter.try_consume("test"));
        }
        assert!(!limiter.try_consume("test"));
    }

    #[test]
    fn test_rate_limiter_cloud_defaults_ibm() {
        let mut limiter = RateLimiter::with_cloud_defaults();
        for _ in 0..5 {
            assert!(limiter.try_consume("ibm"));
        }
        assert!(!limiter.try_consume("ibm"));
    }

    #[test]
    fn test_rate_limiter_cloud_defaults_aws() {
        let mut limiter = RateLimiter::with_cloud_defaults();
        for _ in 0..10 {
            assert!(limiter.try_consume("aws"));
        }
        assert!(!limiter.try_consume("aws"));
    }

    #[test]
    fn test_rate_limiter_wait_time_zero_when_available() {
        let mut limiter = RateLimiter::new(10.0, 1.0);
        let wait = limiter.wait_time("any_provider");
        assert_eq!(wait, Duration::ZERO);
    }

    #[test]
    fn test_rate_limiter_wait_time_positive_when_exhausted() {
        let mut limiter = RateLimiter::new(1.0, 0.001);
        assert!(limiter.try_consume("provider"));
        let wait = limiter.wait_time("provider");
        assert!(wait > Duration::ZERO);
    }

    #[test]
    fn test_rate_limiter_independent_providers() {
        let mut limiter = RateLimiter::new(2.0, 0.001);
        assert!(limiter.try_consume("provider_a"));
        assert!(limiter.try_consume("provider_a"));
        assert!(!limiter.try_consume("provider_a"));

        // Exhausting provider_a must not affect provider_b.
        assert!(limiter.try_consume("provider_b"));
        assert!(limiter.try_consume("provider_b"));
    }

    #[test]
    fn test_rate_limiter_tracked_providers() {
        let mut limiter = RateLimiter::with_cloud_defaults();
        let providers = limiter.tracked_providers();
        assert!(providers.contains(&"ibm"));
        assert!(providers.contains(&"aws"));
        assert!(providers.contains(&"azure"));

        // An unseen provider gets a default (full) bucket on first use.
        assert!(limiter.try_consume("ionq"));
        let providers = limiter.tracked_providers();
        assert!(providers.contains(&"ionq"));
    }

    #[test]
    fn test_token_bucket_debug() {
        let bucket = TokenBucket::new(5.0, 1.0);
        let s = format!("{:?}", bucket);
        assert!(s.contains("TokenBucket"));
    }

    #[test]
    fn test_rate_limiter_debug() {
        let limiter = RateLimiter::with_cloud_defaults();
        let s = format!("{:?}", limiter);
        assert!(s.contains("RateLimiter"));
    }
}