use crate::time::{Instant, instant_now};
/// Burst size used by `TokenBucket::new`: the most tokens such a bucket can hold.
pub const DEFAULT_BURST_CAPACITY: u32 = 100;
/// Refill rate used by `TokenBucket::new`, in tokens regained per second.
pub const DEFAULT_REFILL_RATE: f64 = 10.0;
/// A token-bucket rate limiter.
///
/// The bucket holds at most `capacity` tokens, regains `refill_rate` tokens
/// per second of elapsed time, and each granted request spends tokens.
#[derive(Clone, Debug)]
pub struct TokenBucket {
    /// Maximum balance, i.e. the largest burst that can be granted at once.
    capacity: u32,
    /// Current balance; fractional because refill is credited continuously.
    tokens: f64,
    /// Tokens credited per second of elapsed time.
    refill_rate: f64,
    /// When the balance was last refilled; the next refill credits time since then.
    last_refill: Instant,
}
impl TokenBucket {
    /// Creates a full bucket using [`DEFAULT_BURST_CAPACITY`] and
    /// [`DEFAULT_REFILL_RATE`].
    pub fn new() -> Self {
        Self::with_params(DEFAULT_BURST_CAPACITY, DEFAULT_REFILL_RATE)
    }

    /// Creates a full bucket that holds at most `capacity` tokens and regains
    /// `refill_rate` tokens per second of elapsed time.
    ///
    /// A `refill_rate` that is not a positive number (zero, negative or NaN)
    /// yields a bucket that never refills; it neither drains nor becomes NaN.
    pub fn with_params(capacity: u32, refill_rate: f64) -> Self {
        Self {
            capacity,
            tokens: f64::from(capacity),
            refill_rate,
            last_refill: instant_now(),
        }
    }

    /// Tries to take a single token; equivalent to `try_acquire_n(1)`.
    pub fn try_acquire(&mut self) -> bool {
        self.try_acquire_n(1)
    }

    /// Refills, then takes `n` tokens if at least `n` are available.
    ///
    /// Returns `true` when the tokens were taken, and `false` (taking nothing)
    /// otherwise. The balance is capped at `capacity`, so a request for more
    /// than `capacity` tokens can never succeed.
    pub fn try_acquire_n(&mut self, n: u32) -> bool {
        self.refill();
        let wanted = f64::from(n);
        if self.tokens >= wanted {
            self.tokens -= wanted;
            true
        } else {
            false
        }
    }

    /// Test helper: refills, then reports whether at least one token is available.
    #[cfg(test)]
    pub fn available(&mut self) -> bool {
        self.refill();
        self.tokens >= 1.0
    }

    /// Test helper: refills, then returns the (possibly fractional) balance.
    #[cfg(test)]
    pub fn tokens(&mut self) -> f64 {
        self.refill();
        self.tokens
    }

    /// Test helper: the maximum balance.
    #[cfg(test)]
    pub fn capacity(&self) -> u32 {
        self.capacity
    }

    /// Credits the tokens earned since the last refill, capped at `capacity`,
    /// and moves the refill clock to now.
    fn refill(&mut self) {
        let now = instant_now();
        let elapsed_secs = now.duration_since(self.last_refill).as_secs_f64();
        let gained = elapsed_secs * self.refill_rate;
        // Only a positive gain is applied. A negative rate would otherwise
        // drain the bucket, and a NaN gain (a NaN rate, or `0 * inf`) would
        // poison the balance forever, because every later comparison is false.
        if gained > 0.0 {
            self.tokens = (self.tokens + gained).min(f64::from(self.capacity));
        }
        self.last_refill = now;
    }

    /// Test helper: fills the bucket to capacity and restarts the refill clock.
    #[cfg(test)]
    pub fn reset(&mut self) {
        self.tokens = f64::from(self.capacity);
        self.last_refill = instant_now();
    }

    /// Test helper: how long until at least one token is available.
    ///
    /// Returns `Duration::ZERO` if a token is available now, and `Duration::MAX`
    /// if the bucket can never produce one (non-positive or NaN refill rate).
    #[cfg(test)]
    pub fn time_until_available(&mut self) -> std::time::Duration {
        self.refill();
        if self.tokens >= 1.0 {
            return std::time::Duration::ZERO;
        }
        let needed = 1.0 - self.tokens;
        let secs = needed / self.refill_rate;
        // `Duration::from_secs_f64` panics on negative, non-finite or
        // overflowing input. A zero rate gives inf, a negative rate gives a
        // negative value, a NaN rate gives NaN, and a tiny rate overflows.
        // `u64::MAX as f64` rounds to 2^64, so anything below it fits.
        if secs.is_finite() && secs >= 0.0 && secs < u64::MAX as f64 {
            std::time::Duration::from_secs_f64(secs)
        } else {
            std::time::Duration::MAX
        }
    }
}
impl Default for TokenBucket {
fn default() -> Self {
Self::new()
}
}
/// Admission control for handshakes.
///
/// Each new handshake must take a token from `bucket` and also fit under the
/// `max_pending` limit on handshakes that have started but not yet completed.
#[derive(Debug)]
pub struct HandshakeRateLimiter {
    /// Token source; one token is spent per admitted handshake.
    bucket: TokenBucket,
    /// Handshakes admitted by `start_handshake` and not yet completed.
    pending_count: usize,
    /// Upper bound on `pending_count`.
    max_pending: usize,
}
impl HandshakeRateLimiter {
pub fn with_params(bucket: TokenBucket, max_pending: usize) -> Self {
Self {
bucket,
pending_count: 0,
max_pending,
}
}
#[cfg(test)]
pub fn can_start_handshake(&mut self) -> bool {
self.bucket.available() && self.pending_count < self.max_pending
}
pub fn start_handshake(&mut self) -> bool {
if self.pending_count >= self.max_pending {
return false;
}
if self.bucket.try_acquire() {
self.pending_count += 1;
true
} else {
false
}
}
pub fn complete_handshake(&mut self) {
if self.pending_count > 0 {
self.pending_count -= 1;
}
}
#[cfg(test)]
pub fn pending_count(&self) -> usize {
self.pending_count
}
#[cfg(test)]
pub fn bucket(&self) -> &TokenBucket {
&self.bucket
}
#[cfg(test)]
pub fn reset(&mut self) {
self.bucket.reset();
self.pending_count = 0;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Calls `try_acquire` `times` times, discarding each result, to empty a bucket.
    fn drain(bucket: &mut TokenBucket, times: usize) {
        (0..times).for_each(|_| {
            bucket.try_acquire();
        });
    }

    #[test]
    fn test_token_bucket_basic() {
        let mut tb = TokenBucket::with_params(10, 1.0);
        assert_eq!(tb.capacity(), 10);
        assert!(tb.tokens() >= 9.9);
        // A full bucket of ten grants ten single tokens back to back.
        assert!((0..10).all(|_| tb.try_acquire()));
        assert!(!tb.try_acquire());
        assert!(!tb.available());
    }

    #[test]
    fn test_token_bucket_refill() {
        let mut tb = TokenBucket::with_params(10, 100.0);
        drain(&mut tb, 10);
        assert!(!tb.available());

        // The expected refill is derived from the time that actually passed,
        // not from the nominal sleep length.
        let started = Instant::now();
        thread::sleep(Duration::from_millis(50));
        let slept_secs = started.elapsed().as_secs_f64();
        let expected = (slept_secs * 100.0).min(10.0);
        let low = (expected * 0.8).min(expected - 0.5).max(0.0);
        let high = (expected * 1.2).max(expected + 0.5).min(10.0);

        let observed = tb.tokens();
        assert!(
            (low..=high).contains(&observed),
            "tokens: {}, expected ~{:.2} (range {:.2}..={:.2})",
            observed,
            expected,
            low,
            high
        );
    }

    #[test]
    fn test_token_bucket_try_acquire_n() {
        let mut tb = TokenBucket::with_params(10, 1.0);
        assert!(tb.try_acquire_n(5));
        let remaining = tb.tokens();
        assert!((4.9..=5.1).contains(&remaining));
        assert!(tb.try_acquire_n(5));
        assert!(!tb.try_acquire_n(1));
    }

    #[test]
    fn test_token_bucket_reset() {
        let mut tb = TokenBucket::with_params(10, 1.0);
        drain(&mut tb, 10);
        tb.reset();
        assert!(tb.tokens() >= 9.9);
    }

    #[test]
    fn test_token_bucket_time_until_available() {
        let mut tb = TokenBucket::with_params(10, 10.0);
        assert_eq!(tb.time_until_available(), Duration::ZERO);
        drain(&mut tb, 10);
        // At 10 tokens/s, one token takes about 100 ms to come back.
        let wait_ms = tb.time_until_available().as_millis();
        assert!((90..=110).contains(&wait_ms));
    }

    #[test]
    fn test_handshake_rate_limiter_basic() {
        let mut limiter = HandshakeRateLimiter::with_params(TokenBucket::new(), 100);
        assert!(limiter.can_start_handshake());
        assert_eq!(limiter.pending_count(), 0);

        assert!(limiter.start_handshake());
        assert_eq!(limiter.pending_count(), 1);

        limiter.complete_handshake();
        assert_eq!(limiter.pending_count(), 0);
    }

    #[test]
    fn test_handshake_rate_limiter_max_pending() {
        let plentiful = TokenBucket::with_params(1000, 100.0);
        let mut limiter = HandshakeRateLimiter::with_params(plentiful, 3);
        for _ in 0..3 {
            assert!(limiter.start_handshake());
        }
        // The pending cap, not the token supply, blocks the fourth handshake.
        assert!(!limiter.can_start_handshake());
        assert!(!limiter.start_handshake());

        limiter.complete_handshake();
        assert!(limiter.can_start_handshake());
        assert!(limiter.start_handshake());
    }

    #[test]
    fn test_handshake_rate_limiter_token_exhaustion() {
        // With a zero refill rate, spent tokens never come back.
        let finite = TokenBucket::with_params(5, 0.0);
        let mut limiter = HandshakeRateLimiter::with_params(finite, 100);
        assert!((0..5).all(|_| limiter.start_handshake()));
        (0..5).for_each(|_| limiter.complete_handshake());
        // Completing handshakes frees pending slots but cannot restore tokens.
        assert!(!limiter.can_start_handshake());
        assert!(!limiter.start_handshake());
    }

    #[test]
    fn test_handshake_rate_limiter_reset() {
        let mut limiter = HandshakeRateLimiter::with_params(TokenBucket::new(), 100);
        limiter.start_handshake();
        limiter.start_handshake();
        assert_eq!(limiter.pending_count(), 2);

        limiter.reset();
        assert_eq!(limiter.pending_count(), 0);
        assert!(limiter.bucket().tokens >= f64::from(DEFAULT_BURST_CAPACITY) - 0.1);
    }
}