Skip to main content

quantrs2_device/security/
rate_limit.rs

//! Token bucket rate limiter for cloud API calls.
//!
//! Prevents hitting provider rate limits (IBM: 5 jobs/min, AWS: 10 req/s, etc.)
//! by tracking token consumption per backend. The token bucket algorithm smooths
//! bursty traffic while respecting sustained-rate limits.
6
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
/// A token bucket: accumulates tokens at `refill_rate` tokens/second, up to `capacity`.
///
/// Starts full (all tokens available). Each API call consumes one or more tokens.
/// Tokens are replenished continuously (fractional tokens accumulate) and never
/// exceed `capacity`.
///
/// `capacity` and `refill_rate` are expected to be finite and non-negative. A
/// `refill_rate` of `0.0` gives a bucket that never refills once drained.
pub struct TokenBucket {
    /// Maximum number of tokens the bucket can hold (burst ceiling).
    capacity: f64,
    /// Tokens available as of `last_refill`.
    tokens: f64,
    /// Tokens added per second (sustained throughput).
    refill_rate: f64,
    /// Instant up to which elapsed time has already been credited into `tokens`.
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a new token bucket with the given capacity and refill rate.
    ///
    /// The bucket starts full.
    ///
    /// # Arguments
    /// * `capacity`    — maximum number of tokens (burst ceiling)
    /// * `refill_rate` — tokens added per second (sustained throughput)
    pub fn new(capacity: f64, refill_rate: f64) -> Self {
        Self {
            capacity,
            tokens: capacity,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Credit tokens for the time elapsed since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        // Read the clock once and use the same reading for both the elapsed time
        // and the new mark; reading it twice would drop the interval between reads.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
        self.last_refill = now;
    }

    /// Try to consume `tokens` from the bucket.
    ///
    /// Returns `true` if the tokens were available and consumed,
    /// `false` if the bucket is too empty (caller should wait).
    /// Negative or NaN requests are rejected (`false`) because they would
    /// otherwise inflate the bucket beyond `capacity`.
    pub fn try_consume(&mut self, tokens: f64) -> bool {
        if tokens.is_nan() || tokens < 0.0 {
            return false;
        }
        self.refill();
        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }

    /// Estimate wait time before `tokens` become available.
    ///
    /// Returns [`Duration::ZERO`] if sufficient tokens exist right now, and
    /// [`Duration::MAX`] if the request can never be satisfied: it exceeds
    /// `capacity`, it is NaN, or the refill rate is zero, negative or NaN.
    pub fn wait_time(&mut self, tokens: f64) -> Duration {
        self.refill();
        if self.tokens >= tokens {
            return Duration::ZERO;
        }
        // The bucket is capped at `capacity`, so larger requests never fit.
        if tokens.is_nan() || tokens > self.capacity {
            return Duration::MAX;
        }
        let wait_secs = (tokens - self.tokens) / self.refill_rate;
        // A zero rate gives +inf and a negative rate gives a negative value;
        // `Duration::from_secs_f64` panics on both, as well as on NaN and on
        // values too large for a `Duration` (>= 2^64 seconds).
        if wait_secs >= 0.0 && wait_secs < u64::MAX as f64 {
            Duration::from_secs_f64(wait_secs)
        } else {
            Duration::MAX
        }
    }

    /// Return the current token count after applying any accumulated refill.
    pub fn available_tokens(&mut self) -> f64 {
        self.refill();
        self.tokens
    }

    /// Return the configured bucket capacity (burst ceiling).
    pub fn capacity(&self) -> f64 {
        self.capacity
    }

    /// Return the configured refill rate in tokens per second.
    pub fn refill_rate(&self) -> f64 {
        self.refill_rate
    }
}

impl std::fmt::Debug for TokenBucket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `last_refill` is omitted (an opaque Instant); `..` signals that.
        f.debug_struct("TokenBucket")
            .field("capacity", &self.capacity)
            .field("tokens", &self.tokens)
            .field("refill_rate", &self.refill_rate)
            .finish_non_exhaustive()
    }
}
97
98/// Per-provider token-bucket rate limiter.
99///
100/// Maintains a separate [`TokenBucket`] for each provider key. Providers not
101/// explicitly configured get a bucket with the default capacity and rate.
102///
103/// # Example
104///
105/// ```rust
106/// use quantrs2_device::security::rate_limit::RateLimiter;
107///
108/// let mut limiter = RateLimiter::with_cloud_defaults();
109///
110/// // Fast path — tokens available
111/// if limiter.try_consume("aws") {
112///     // submit request
113/// } else {
114///     let delay = limiter.wait_time("aws");
115///     // sleep(delay), then retry
116/// }
117/// ```
118pub struct RateLimiter {
119    buckets: HashMap<String, TokenBucket>,
120    default_capacity: f64,
121    default_rate: f64,
122}
123
124impl RateLimiter {
125    /// Create a rate limiter with the given defaults for unconfigured providers.
126    ///
127    /// # Arguments
128    /// * `default_capacity`         — burst limit for unknown providers
129    /// * `default_rate_per_second`  — sustained rate for unknown providers
130    pub fn new(default_capacity: f64, default_rate_per_second: f64) -> Self {
131        Self {
132            buckets: HashMap::new(),
133            default_capacity,
134            default_rate: default_rate_per_second,
135        }
136    }
137
138    /// Register a provider-specific bucket, overriding the defaults.
139    pub fn with_provider(mut self, provider: impl Into<String>, capacity: f64, rate: f64) -> Self {
140        self.buckets
141            .insert(provider.into(), TokenBucket::new(capacity, rate));
142        self
143    }
144
145    /// Try to consume one token for the given provider key.
146    ///
147    /// Creates a default bucket for the provider if it has not been seen before.
148    /// Returns `true` if the request can proceed immediately.
149    pub fn try_consume(&mut self, provider: &str) -> bool {
150        let (cap, rate) = (self.default_capacity, self.default_rate);
151        let bucket = self
152            .buckets
153            .entry(provider.to_string())
154            .or_insert_with(|| TokenBucket::new(cap, rate));
155        bucket.try_consume(1.0)
156    }
157
158    /// Estimate the wait time before a token is available for the given provider.
159    ///
160    /// Returns [`Duration::ZERO`] if a token is available immediately.
161    pub fn wait_time(&mut self, provider: &str) -> Duration {
162        let (cap, rate) = (self.default_capacity, self.default_rate);
163        let bucket = self
164            .buckets
165            .entry(provider.to_string())
166            .or_insert_with(|| TokenBucket::new(cap, rate));
167        bucket.wait_time(1.0)
168    }
169
170    /// Return the number of available tokens for a provider (after refill).
171    ///
172    /// Creates a default bucket if the provider has not been seen before.
173    pub fn available_tokens(&mut self, provider: &str) -> f64 {
174        let (cap, rate) = (self.default_capacity, self.default_rate);
175        let bucket = self
176            .buckets
177            .entry(provider.to_string())
178            .or_insert_with(|| TokenBucket::new(cap, rate));
179        bucket.available_tokens()
180    }
181
182    /// Pre-configured limiter with typical cloud provider limits:
183    ///
184    /// | Provider | Burst | Sustained       |
185    /// |----------|-------|-----------------|
186    /// | IBM      | 5     | 5 / 60s         |
187    /// | AWS      | 10    | 10/s            |
188    /// | Azure    | 10    | 10/s            |
189    ///
190    /// Unknown providers get a 10-token bucket at 1 token/second.
191    pub fn with_cloud_defaults() -> Self {
192        Self::new(10.0, 1.0)
193            .with_provider("ibm", 5.0, 5.0 / 60.0)
194            .with_provider("aws", 10.0, 10.0)
195            .with_provider("azure", 10.0, 10.0)
196    }
197
198    /// Return the list of currently tracked provider keys
199    pub fn tracked_providers(&self) -> Vec<&str> {
200        self.buckets.keys().map(|s| s.as_str()).collect()
201    }
202}
203
204impl std::fmt::Debug for RateLimiter {
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        f.debug_struct("RateLimiter")
207            .field("providers", &self.buckets.keys().collect::<Vec<_>>())
208            .field("default_capacity", &self.default_capacity)
209            .field("default_rate", &self.default_rate)
210            .finish()
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Move `bucket`'s refill mark `secs` seconds into the past so the next
    /// refill credits that much elapsed time without sleeping.
    ///
    /// Returns `false` if the platform clock cannot represent an instant that
    /// far back (possible shortly after boot); callers then skip the check.
    fn backdate(bucket: &mut TokenBucket, secs: u64) -> bool {
        match Instant::now().checked_sub(Duration::from_secs(secs)) {
            Some(past) => {
                bucket.last_refill = past;
                true
            }
            None => false,
        }
    }

    #[test]
    fn test_token_bucket_starts_full() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert!((bucket.available_tokens() - 10.0).abs() < 1e-9);
    }

    #[test]
    fn test_token_bucket_consume_success() {
        let mut bucket = TokenBucket::new(5.0, 1.0);
        assert!(bucket.try_consume(3.0));
        // ~2 tokens remain (plus a negligible refill since construction)
        let remaining = bucket.available_tokens();
        assert!((remaining - 2.0).abs() < 0.1, "remaining = {remaining}");
    }

    #[test]
    fn test_token_bucket_consume_fails_when_empty() {
        let mut bucket = TokenBucket::new(3.0, 0.001); // very slow refill
        // Drain all tokens
        assert!(bucket.try_consume(3.0));
        // Next consume should fail immediately
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_bucket_wait_time_zero_when_full() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        let wait = bucket.wait_time(1.0);
        assert_eq!(wait, Duration::ZERO);
    }

    #[test]
    fn test_token_bucket_wait_time_nonzero_when_empty() {
        let mut bucket = TokenBucket::new(3.0, 0.001); // very slow refill
        assert!(bucket.try_consume(3.0));
        let wait = bucket.wait_time(1.0);
        // wait should be positive (≈ 1000s at 0.001 t/s)
        assert!(wait > Duration::ZERO);
    }

    #[test]
    fn test_token_bucket_capacity_ceiling() {
        // After a long (simulated) wait the bucket still won't exceed capacity:
        // 10s at 100 t/s would be 1000 tokens, but capacity is 5.
        let mut bucket = TokenBucket::new(5.0, 100.0);
        assert!(bucket.try_consume(5.0));
        if !backdate(&mut bucket, 10) {
            return;
        }
        let tokens = bucket.available_tokens();
        assert!((tokens - 5.0).abs() < 1e-9, "tokens = {tokens}");
    }

    #[test]
    fn test_token_bucket_refill_after_elapsed_time() {
        // 1s at 2 t/s should credit ~2 tokens to a drained bucket.
        let mut bucket = TokenBucket::new(10.0, 2.0);
        assert!(bucket.try_consume(10.0));
        if !backdate(&mut bucket, 1) {
            return;
        }
        let tokens = bucket.available_tokens();
        assert!((tokens - 2.0).abs() < 0.1, "tokens = {tokens}");
    }

    #[test]
    fn test_token_bucket_accessors() {
        let bucket = TokenBucket::new(10.0, 2.5);
        assert!((bucket.capacity() - 10.0).abs() < 1e-9);
        assert!((bucket.refill_rate() - 2.5).abs() < 1e-9);
    }

    #[test]
    fn test_rate_limiter_new_provider_gets_defaults() {
        let mut limiter = RateLimiter::new(5.0, 1.0);
        // "unknown" provider should start with full bucket
        let tokens = limiter.available_tokens("unknown_provider");
        assert!((tokens - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_rate_limiter_try_consume_success() {
        let mut limiter = RateLimiter::new(10.0, 1.0);
        assert!(limiter.try_consume("aws"));
    }

    #[test]
    fn test_rate_limiter_exhaustion() {
        let mut limiter = RateLimiter::new(3.0, 0.001);
        assert!(limiter.try_consume("test"));
        assert!(limiter.try_consume("test"));
        assert!(limiter.try_consume("test"));
        // 4th should fail
        assert!(!limiter.try_consume("test"));
    }

    #[test]
    fn test_rate_limiter_cloud_defaults_ibm() {
        let mut limiter = RateLimiter::with_cloud_defaults();
        // IBM bucket has capacity 5
        for _ in 0..5 {
            assert!(limiter.try_consume("ibm"));
        }
        // 6th should fail (very slow refill)
        assert!(!limiter.try_consume("ibm"));
    }

    #[test]
    fn test_rate_limiter_cloud_defaults_aws() {
        let mut limiter = RateLimiter::with_cloud_defaults();
        // AWS bucket has capacity 10
        for _ in 0..10 {
            assert!(limiter.try_consume("aws"));
        }
        assert!(!limiter.try_consume("aws"));
    }

    #[test]
    fn test_rate_limiter_wait_time_zero_when_available() {
        let mut limiter = RateLimiter::new(10.0, 1.0);
        let wait = limiter.wait_time("any_provider");
        assert_eq!(wait, Duration::ZERO);
    }

    #[test]
    fn test_rate_limiter_wait_time_positive_when_exhausted() {
        let mut limiter = RateLimiter::new(1.0, 0.001);
        assert!(limiter.try_consume("provider"));
        let wait = limiter.wait_time("provider");
        assert!(wait > Duration::ZERO);
    }

    #[test]
    fn test_rate_limiter_independent_providers() {
        let mut limiter = RateLimiter::new(2.0, 0.001);
        // Exhaust "provider_a"
        assert!(limiter.try_consume("provider_a"));
        assert!(limiter.try_consume("provider_a"));
        assert!(!limiter.try_consume("provider_a"));

        // "provider_b" should be unaffected
        assert!(limiter.try_consume("provider_b"));
        assert!(limiter.try_consume("provider_b"));
    }

    #[test]
    fn test_rate_limiter_tracked_providers() {
        let mut limiter = RateLimiter::with_cloud_defaults();
        // Pre-configured providers are already tracked
        let providers = limiter.tracked_providers();
        assert!(providers.contains(&"ibm"));
        assert!(providers.contains(&"aws"));
        assert!(providers.contains(&"azure"));

        // Accessing a new provider adds it to the tracked set
        limiter.try_consume("ionq");
        let providers = limiter.tracked_providers();
        assert!(providers.contains(&"ionq"));
    }

    #[test]
    fn test_token_bucket_debug() {
        let bucket = TokenBucket::new(5.0, 1.0);
        let s = format!("{:?}", bucket);
        assert!(s.contains("TokenBucket"));
    }

    #[test]
    fn test_rate_limiter_debug() {
        let limiter = RateLimiter::with_cloud_defaults();
        let s = format!("{:?}", limiter);
        assert!(s.contains("RateLimiter"));
    }
}
376}