use std::time::SystemTime;
/// A token-bucket rate limiter.
///
/// The bucket holds at most `b` tokens, starts full, and is refilled
/// continuously at `r` tokens per second. Callers spend tokens through
/// [`TokenBucket::acquire`].
pub struct TokenBucket {
    /// Refill rate, in tokens per second.
    r: f64,
    /// Capacity: the token balance is clamped to this value.
    b: f64,
    /// Token balance as of `last`.
    tokens: f64,
    /// Instant the balance in `tokens` refers to: construction time, the most
    /// recent granted acquisition, or the moment a backwards clock step was
    /// detected.
    last: SystemTime,
}

/// Outcome of [`TokenBucket::acquire`].
///
/// `Ok` means the request was granted, `Err` means it was denied. Both carry
/// the same figure: `1000 / elapsed_ms`, where `elapsed_ms` is the number of
/// whole milliseconds since `last`. The value is `f64::INFINITY` when less
/// than one millisecond has elapsed.
pub type TokenAcquisitionResult = Result<f64, f64>;

impl TokenBucket {
    /// Creates a full bucket with refill rate `r` (tokens per second) and
    /// capacity `b` (tokens).
    pub fn new(r: f64, b: f64) -> TokenBucket {
        TokenBucket {
            r,
            b,
            tokens: b,
            last: SystemTime::now(),
        }
    }

    /// Tries to take `count` tokens from the bucket.
    ///
    /// The balance available right now is the stored balance plus the refill
    /// accrued since `last`, clamped to the capacity. If that covers `count`
    /// the tokens are deducted and `Ok` is returned; otherwise `Err` is
    /// returned and the bucket is left untouched.
    ///
    /// See [`TokenAcquisitionResult`] for the meaning of the carried value.
    pub fn acquire(&mut self, count: f64) -> TokenAcquisitionResult {
        let now = SystemTime::now();
        let duration_ms: u128 = match now.duration_since(self.last) {
            Ok(elapsed) => elapsed.as_millis(),
            Err(_) => {
                // The wall clock was stepped backwards. Treat it as "no time
                // elapsed" instead of panicking, and re-anchor so that refill
                // resumes from the adjusted clock rather than stalling until
                // it catches up with the old timestamp.
                self.last = now;
                0
            }
        };
        let elapsed_ms = duration_ms as f64;

        // Work on a local: a denied request must not bank the refill, because
        // `last` is not advanced on denial and the same interval would then be
        // credited again on every retry.
        let available = self
            .b
            .min(self.tokens + (self.r * elapsed_ms) / 1000.0);
        let rate: f64 = (1f64 / elapsed_ms) * 1000.0;

        if available >= count {
            self.tokens = available - count;
            self.last = now;
            Ok(rate)
        } else {
            Err(rate)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{thread, time::Duration};

    /// A new bucket starts full, so a request for its whole capacity is granted.
    #[test]
    fn test_initial_acquire() {
        let mut limiter = TokenBucket::new(1.0, 1.0);
        assert!(limiter.acquire(1.0).is_ok());
    }

    /// After draining a 1-token bucket, one second at 1 token/s is enough for
    /// the next request to be granted.
    #[test]
    fn test_acquire_when_tokens_available() {
        let mut limiter = TokenBucket::new(1.0, 1.0);
        assert!(limiter.acquire(1.0).is_ok());

        thread::sleep(Duration::from_secs(1));
        assert!(limiter.acquire(1.0).is_ok());
    }

    /// Asking for more tokens than the bucket can hold is denied.
    #[test]
    fn test_acquire_when_tokens_not_available() {
        let mut limiter = TokenBucket::new(1.0, 1.0);
        assert!(limiter.acquire(2.0).is_err());
    }

    /// With capacity 2, a second single-token request one second after the
    /// first is granted.
    #[test]
    fn test_acquire_with_replenish() {
        let mut limiter = TokenBucket::new(1.0, 2.0);
        let first = limiter.acquire(1.0);
        assert!(first.is_ok());

        thread::sleep(Duration::from_secs(1));
        let second = limiter.acquire(1.0);
        assert!(second.is_ok());
    }

    /// Half a second at 1 token/s is not enough to refill a drained 1-token
    /// bucket, so the follow-up request is denied.
    #[test]
    fn test_rate_limited() {
        let mut limiter = TokenBucket::new(1.0, 1.0);
        let first = limiter.acquire(1.0);
        assert!(first.is_ok());

        thread::sleep(Duration::from_millis(500));
        let second = limiter.acquire(1.0);
        assert!(second.is_err());
    }
}