use std::collections::HashMap;
use std::time::{Duration, Instant};
/// A token-bucket rate limiter.
///
/// The bucket holds at most `capacity` tokens and regains them continuously at
/// `refill_rate` tokens per second. Refill is computed lazily from the time elapsed
/// since the previous observation, so every query first brings the bucket up to date.
pub struct TokenBucket {
    /// Maximum number of tokens the bucket can hold (the burst size).
    capacity: f64,
    /// Tokens currently available; fractional values are allowed.
    tokens: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// Instant up to which refill has already been credited to `tokens`.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    ///
    /// `capacity` is the maximum number of tokens; `refill_rate` is in tokens per second.
    pub fn new(capacity: f64, refill_rate: f64) -> Self {
        Self {
            capacity,
            tokens: capacity,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Credits tokens for the time elapsed since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        // Read the clock exactly once. Measuring `elapsed()` and then separately storing
        // `Instant::now()` would silently drop the gap between the two reads on every call.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
        self.last_refill = now;
    }

    /// Attempts to take `tokens` from the bucket.
    ///
    /// Refills first. Returns `true` and deducts the tokens if enough are available;
    /// otherwise returns `false` and deducts nothing.
    pub fn try_consume(&mut self, tokens: f64) -> bool {
        self.refill();
        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }

    /// Returns how long a caller must wait before `tokens` can be consumed.
    ///
    /// Returns [`Duration::ZERO`] if the tokens are available now. Returns
    /// [`Duration::MAX`] if they can never become available: the request exceeds
    /// `capacity`, the bucket does not refill (`refill_rate <= 0`), or the wait is
    /// too long (or not a number) to represent as a [`Duration`].
    pub fn wait_time(&mut self, tokens: f64) -> Duration {
        self.refill();
        if self.tokens >= tokens {
            return Duration::ZERO;
        }
        // Refill stops at `capacity`, so larger requests are never satisfiable, and a
        // non-positive rate never adds tokens. Without this guard the division below
        // would yield a meaningless finite wait or an inf/negative value that makes
        // `Duration::from_secs_f64` panic.
        if tokens > self.capacity || self.refill_rate <= 0.0 {
            return Duration::MAX;
        }
        let wait_secs = (tokens - self.tokens) / self.refill_rate;
        // `from_secs_f64` panics on NaN or on values that overflow `Duration`;
        // `NaN < x` is false, so NaN also falls through to `Duration::MAX`.
        if wait_secs < Duration::MAX.as_secs_f64() {
            Duration::from_secs_f64(wait_secs)
        } else {
            Duration::MAX
        }
    }

    /// Returns the number of tokens available right now, after refilling.
    pub fn available_tokens(&mut self) -> f64 {
        self.refill();
        self.tokens
    }

    /// Returns the maximum number of tokens the bucket can hold.
    pub fn capacity(&self) -> f64 {
        self.capacity
    }

    /// Returns the refill rate in tokens per second.
    pub fn refill_rate(&self) -> f64 {
        self.refill_rate
    }
}

// Hand-written so that the internal `last_refill` timestamp is left out of the output.
impl std::fmt::Debug for TokenBucket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenBucket")
            .field("capacity", &self.capacity)
            .field("tokens", &self.tokens)
            .field("refill_rate", &self.refill_rate)
            .finish()
    }
}
/// Per-provider rate limiting built from independent [`TokenBucket`]s.
///
/// Providers registered with [`RateLimiter::with_provider`] use their own limits.
/// Any other provider gets a bucket with the default limits the first time it is
/// queried, and that bucket is kept from then on.
pub struct RateLimiter {
    /// One bucket per provider name.
    buckets: HashMap<String, TokenBucket>,
    /// Capacity of buckets created lazily for unregistered providers.
    default_capacity: f64,
    /// Refill rate (tokens per second) of buckets created lazily for unregistered providers.
    default_rate: f64,
}

impl RateLimiter {
    /// Creates a limiter with no registered providers.
    ///
    /// Providers seen later get `default_capacity` tokens of burst and refill at
    /// `default_rate_per_second` tokens per second.
    pub fn new(default_capacity: f64, default_rate_per_second: f64) -> Self {
        Self {
            buckets: HashMap::new(),
            default_capacity,
            default_rate: default_rate_per_second,
        }
    }

    /// Builder-style registration of a provider-specific bucket.
    ///
    /// `rate` is in tokens per second. Registering the same provider twice replaces
    /// the earlier bucket.
    pub fn with_provider(mut self, provider: impl Into<String>, capacity: f64, rate: f64) -> Self {
        self.buckets
            .insert(provider.into(), TokenBucket::new(capacity, rate));
        self
    }

    /// Returns the bucket for `provider`, creating one with the default limits if absent.
    ///
    /// This is the single place where unknown providers get a bucket.
    fn bucket_mut(&mut self, provider: &str) -> &mut TokenBucket {
        // Copy the defaults out so the closure doesn't borrow `self` while
        // `self.buckets` is mutably borrowed.
        let (capacity, rate) = (self.default_capacity, self.default_rate);
        self.buckets
            .entry(provider.to_string())
            .or_insert_with(|| TokenBucket::new(capacity, rate))
    }

    /// Tries to consume one token for `provider`; returns whether it succeeded.
    pub fn try_consume(&mut self, provider: &str) -> bool {
        self.bucket_mut(provider).try_consume(1.0)
    }

    /// Returns how long to wait before one token is available for `provider`.
    pub fn wait_time(&mut self, provider: &str) -> Duration {
        self.bucket_mut(provider).wait_time(1.0)
    }

    /// Returns the number of tokens currently available for `provider`.
    pub fn available_tokens(&mut self, provider: &str) -> f64 {
        self.bucket_mut(provider).available_tokens()
    }

    /// Preset limits for common cloud providers.
    ///
    /// - `ibm`: burst 5, refilling 5 tokens per minute.
    /// - `aws` and `azure`: burst 10, refilling 10 tokens per second.
    /// - Any other provider: burst 10, refilling 1 token per second.
    pub fn with_cloud_defaults() -> Self {
        Self::new(10.0, 1.0)
            .with_provider("ibm", 5.0, 5.0 / 60.0)
            .with_provider("aws", 10.0, 10.0)
            .with_provider("azure", 10.0, 10.0)
    }

    /// Names of all providers that currently have a bucket, in unspecified order.
    ///
    /// This includes registered providers and any provider already queried through
    /// `try_consume`, `wait_time` or `available_tokens`.
    pub fn tracked_providers(&self) -> Vec<&str> {
        self.buckets.keys().map(|s| s.as_str()).collect()
    }
}

impl std::fmt::Debug for RateLimiter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sort provider names so the output is stable across runs; HashMap
        // iteration order is not.
        let mut providers: Vec<&String> = self.buckets.keys().collect();
        providers.sort_unstable();
        f.debug_struct("RateLimiter")
            .field("providers", &providers)
            .field("default_capacity", &self.default_capacity)
            .field("default_rate", &self.default_rate)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Moves a bucket's refill clock into the past so tests can simulate elapsed
    /// time without sleeping.
    fn rewind(bucket: &mut TokenBucket, by: Duration) {
        bucket.last_refill = Instant::now()
            .checked_sub(by)
            .expect("monotonic clock should be able to rewind a few seconds");
    }

    #[test]
    fn test_token_bucket_starts_full() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert!((bucket.available_tokens() - 10.0).abs() < 1e-9);
    }

    #[test]
    fn test_token_bucket_consume_success() {
        let mut bucket = TokenBucket::new(5.0, 1.0);
        assert!(bucket.try_consume(3.0));
        assert!(bucket.available_tokens() < 3.0);
    }

    #[test]
    fn test_token_bucket_consume_fails_when_empty() {
        let mut bucket = TokenBucket::new(3.0, 0.001);
        assert!(bucket.try_consume(3.0));
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_bucket_wait_time_zero_when_full() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        let wait = bucket.wait_time(1.0);
        assert_eq!(wait, Duration::ZERO);
    }

    #[test]
    fn test_token_bucket_wait_time_nonzero_when_empty() {
        let mut bucket = TokenBucket::new(3.0, 0.001);
        assert!(bucket.try_consume(3.0));
        let wait = bucket.wait_time(1.0);
        assert!(wait > Duration::ZERO);
    }

    #[test]
    fn test_token_bucket_refills_over_time() {
        let mut bucket = TokenBucket::new(3.0, 1.0);
        assert!(bucket.try_consume(3.0));
        // Two simulated seconds at 1 token/s should credit two tokens.
        rewind(&mut bucket, Duration::from_secs(2));
        let tokens = bucket.available_tokens();
        assert!((2.0..2.5).contains(&tokens), "expected ~2 tokens, got {tokens}");
    }

    #[test]
    fn test_token_bucket_capacity_ceiling() {
        let mut bucket = TokenBucket::new(5.0, 100.0);
        assert!(bucket.try_consume(5.0));
        // One simulated second at 100 tokens/s would overshoot; refill must cap at 5.
        rewind(&mut bucket, Duration::from_secs(1));
        let tokens = bucket.available_tokens();
        assert!((tokens - 5.0).abs() < 1e-9, "expected capacity 5, got {tokens}");
    }

    #[test]
    fn test_token_bucket_accessors() {
        let bucket = TokenBucket::new(10.0, 2.5);
        assert!((bucket.capacity() - 10.0).abs() < 1e-9);
        assert!((bucket.refill_rate() - 2.5).abs() < 1e-9);
    }

    #[test]
    fn test_rate_limiter_new_provider_gets_defaults() {
        let mut limiter = RateLimiter::new(5.0, 1.0);
        let tokens = limiter.available_tokens("unknown_provider");
        assert!((tokens - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_rate_limiter_try_consume_success() {
        let mut limiter = RateLimiter::new(10.0, 1.0);
        assert!(limiter.try_consume("aws"));
    }

    #[test]
    fn test_rate_limiter_exhaustion() {
        let mut limiter = RateLimiter::new(3.0, 0.001);
        assert!(limiter.try_consume("test"));
        assert!(limiter.try_consume("test"));
        assert!(limiter.try_consume("test"));
        assert!(!limiter.try_consume("test"));
    }

    #[test]
    fn test_rate_limiter_cloud_defaults_ibm() {
        let mut limiter = RateLimiter::with_cloud_defaults();
        for _ in 0..5 {
            assert!(limiter.try_consume("ibm"));
        }
        assert!(!limiter.try_consume("ibm"));
    }

    #[test]
    fn test_rate_limiter_cloud_defaults_aws() {
        let mut limiter = RateLimiter::with_cloud_defaults();
        for _ in 0..10 {
            assert!(limiter.try_consume("aws"));
        }
        assert!(!limiter.try_consume("aws"));
    }

    #[test]
    fn test_rate_limiter_wait_time_zero_when_available() {
        let mut limiter = RateLimiter::new(10.0, 1.0);
        let wait = limiter.wait_time("any_provider");
        assert_eq!(wait, Duration::ZERO);
    }

    #[test]
    fn test_rate_limiter_wait_time_positive_when_exhausted() {
        let mut limiter = RateLimiter::new(1.0, 0.001);
        assert!(limiter.try_consume("provider"));
        let wait = limiter.wait_time("provider");
        assert!(wait > Duration::ZERO);
    }

    #[test]
    fn test_rate_limiter_independent_providers() {
        let mut limiter = RateLimiter::new(2.0, 0.001);
        assert!(limiter.try_consume("provider_a"));
        assert!(limiter.try_consume("provider_a"));
        assert!(!limiter.try_consume("provider_a"));
        assert!(limiter.try_consume("provider_b"));
        assert!(limiter.try_consume("provider_b"));
    }

    #[test]
    fn test_rate_limiter_tracked_providers() {
        let mut limiter = RateLimiter::with_cloud_defaults();
        let providers = limiter.tracked_providers();
        assert!(providers.contains(&"ibm"));
        assert!(providers.contains(&"aws"));
        assert!(providers.contains(&"azure"));
        limiter.try_consume("ionq");
        let providers = limiter.tracked_providers();
        assert!(providers.contains(&"ionq"));
    }

    #[test]
    fn test_token_bucket_debug() {
        let bucket = TokenBucket::new(5.0, 1.0);
        let s = format!("{:?}", bucket);
        assert!(s.contains("TokenBucket"));
    }

    #[test]
    fn test_rate_limiter_debug() {
        let limiter = RateLimiter::with_cloud_defaults();
        let s = format!("{:?}", limiter);
        assert!(s.contains("RateLimiter"));
    }
}