use std::collections::VecDeque;
use std::fmt;
use parking_lot::{Mutex, MutexGuard};
use super::Instant; use crate::time::Duration;
/// A rate-limit bucket that hands out tokens over a rolling time window.
///
/// Implementors are queried through `&self`, so any bookkeeping they need must
/// use interior mutability.
pub trait TokenBucket {
    /// Fraction of the bucket's capacity that is currently unused.
    fn get_capacity(&self) -> f32;

    /// How long to wait before another token may be taken, or `None` if no wait is needed.
    fn get_delay(&self) -> Option<Duration>;

    /// Takes `n` tokens; the returned flag reports whether the bucket stayed within its limits.
    fn get_tokens(&self, n: usize) -> bool;

    /// Length of the window this bucket limits over.
    #[cfg_attr(feature = "nightly", allow(dead_code, reason = "false positive"))]
    #[cfg_attr(not(feature = "nightly"), allow(dead_code))]
    fn get_bucket_duration(&self) -> Duration;

    /// Maximum number of tokens allowed per window.
    #[cfg_attr(feature = "nightly", allow(dead_code, reason = "false positive"))]
    #[cfg_attr(not(feature = "nightly"), allow(dead_code))]
    fn get_total_limit(&self) -> usize;
}
/// Sliding-window token bucket that stores one `Instant` per token taken.
pub struct VectorTokenBucket {
    // --- window configuration ---
    /// Length of the rate-limit window.
    duration: Duration,
    /// Extra time added to `duration` when deciding whether a timestamp is still in the window.
    duration_overhead: Duration,
    /// Maximum number of tokens per `duration + duration_overhead` window (at least 1).
    total_limit: usize,
    /// Shorter window used for the burst check.
    burst_duration: Duration,
    /// Maximum number of tokens within `burst_duration` (at least 1).
    burst_limit: usize,

    // --- construction inputs kept for reference (not read by the code shown) ---
    _given_total_limit: usize,
    _rate_usage_factor: f32,

    // --- mutable state ---
    /// Times at which tokens were taken; newest at the front, oldest at the back.
    timestamps: Mutex<VecDeque<Instant>>,
}
impl VectorTokenBucket {
    /// Builds a bucket allowing roughly `given_total_limit * rate_usage_factor` tokens per
    /// `duration + duration_overhead`, with a burst sub-limit scaled by `burst_factor`.
    ///
    /// * `total_limit` = `max(1, floor(given_total_limit * rate_usage_factor))`
    /// * `burst_limit` = `max(1, floor(total_limit * burst_factor))`
    /// * `burst_duration` = `(duration + duration_overhead) * burst_factor`
    ///
    /// # Panics
    /// In debug builds, if either factor lies outside `(0.0, 1.0]`. `Duration::mul_f32`
    /// panics if `burst_factor` yields a negative, non-finite, or overflowing duration.
    pub fn new(
        duration: Duration,
        given_total_limit: usize,
        duration_overhead: Duration,
        burst_factor: f32,
        rate_usage_factor: f32,
    ) -> Self {
        debug_assert!(
            rate_usage_factor > 0.0 && rate_usage_factor <= 1.0,
            "BAD rate_usage_factor {}.",
            rate_usage_factor
        );
        debug_assert!(
            burst_factor > 0.0 && burst_factor <= 1.0,
            "BAD burst_factor {}.",
            burst_factor
        );

        // Scale the advertised limit down, but never allow a limit of zero.
        let scaled_total = (given_total_limit as f32 * rate_usage_factor).floor() as usize;
        let total_limit = scaled_total.max(1);

        let window = duration + duration_overhead;
        let burst_duration = window.mul_f32(burst_factor);

        let scaled_burst = (total_limit as f32 * burst_factor).floor() as usize;
        let burst_limit = scaled_burst.max(1);
        debug_assert!(burst_limit <= total_limit);

        Self {
            duration,
            duration_overhead,
            total_limit,
            burst_duration,
            burst_limit,
            _given_total_limit: given_total_limit,
            _rate_usage_factor: rate_usage_factor,
            timestamps: Mutex::new(VecDeque::with_capacity(total_limit)),
        }
    }

    /// Locks the timestamp queue, drops every entry older than
    /// `duration + duration_overhead`, and returns the still-held guard.
    fn update_get_timestamps(&self) -> MutexGuard<'_, VecDeque<Instant>> {
        let window = self.duration + self.duration_overhead;
        let mut guard = self.timestamps.lock();

        // If `now - window` is not representable, nothing can be old enough to expire.
        let Some(cutoff) = Instant::now().checked_sub(window) else {
            return guard;
        };

        // Oldest timestamps live at the back; stop at the first one still in the window.
        while let Some(oldest) = guard.back() {
            if *oldest >= cutoff {
                break;
            }
            guard.pop_back();
        }
        guard
    }
}
impl TokenBucket for VectorTokenBucket {
    /// Returns the unused fraction of `total_limit` after expiring old timestamps,
    /// i.e. `1.0 - used / total_limit`.
    ///
    /// Returns `-1.0` when the bucket holds more timestamps than `total_limit`. This can
    /// happen because `get_tokens` records tokens even when doing so violates a limit.
    /// It also returns `-1.0` for a zero `total_limit`, a defensive guard against
    /// division by zero.
    fn get_capacity(&self) -> f32 {
        if self.total_limit == 0 {
            return -1.0;
        }
        let timestamps = self.update_get_timestamps();
        if timestamps.len() > self.total_limit {
            return -1.0;
        }
        1.0 - (timestamps.len() as f32 / self.total_limit as f32)
    }

    /// Returns how long to wait before a token can be taken without violating the total
    /// limit or the burst limit, or `None` if no wait is needed.
    ///
    /// Both limits are evaluated and the longer wait wins. A lapsed total-limit timestamp
    /// therefore cannot hide a burst limit that is still active.
    fn get_delay(&self) -> Option<Duration> {
        let timestamps = self.update_get_timestamps();
        let now = Instant::now();

        // Time left until the `limit`-th newest timestamp is at least `window` old.
        // `None` when fewer than `limit` timestamps are held, or it is already that old.
        let remaining = |limit: usize, window: Duration| -> Option<Duration> {
            let ts = timestamps.get(limit.checked_sub(1)?)?;
            // A timestamp later than `now` is treated as taken just now (saturating).
            let passed = now.checked_duration_since(*ts).unwrap_or_default();
            window.checked_sub(passed)
        };

        let total_delay = remaining(self.total_limit, self.duration + self.duration_overhead);
        let burst_delay = remaining(self.burst_limit, self.burst_duration);
        // `Option` orders `None` before every `Some`, so this picks the longer pending
        // delay, or `None` when neither limit requires waiting.
        total_delay.max(burst_delay)
    }

    /// Records `n` tokens as taken now and reports whether the bucket is still within
    /// its limits.
    ///
    /// The tokens are recorded even when `false` is returned. The result only tells the
    /// caller whether taking them exceeded the total limit or the burst limit.
    fn get_tokens(&self, n: usize) -> bool {
        let mut timestamps = self.update_get_timestamps();
        let now = Instant::now();

        timestamps.reserve(n);
        for _ in 0..n {
            timestamps.push_front(now);
        }

        // Total limit: at most `total_limit` timestamps inside the window.
        if self.total_limit < timestamps.len() {
            return false;
        }

        // Burst limit: the entry just past the newest `burst_limit` ones must be at least
        // `burst_duration` old. Saturate rather than depend on `duration_since` when a
        // timestamp is (unexpectedly) later than `now`.
        if let Some(burst_time) = timestamps.get(self.burst_limit) {
            let since = now.checked_duration_since(*burst_time).unwrap_or_default();
            if since < self.burst_duration {
                return false;
            }
        }

        true
    }

    /// Length of the base rate-limit window (without `duration_overhead`).
    fn get_bucket_duration(&self) -> Duration {
        self.duration
    }

    /// Effective per-window limit after applying the rate usage factor.
    fn get_total_limit(&self) -> usize {
        self.total_limit
    }
}
impl fmt::Debug for VectorTokenBucket {
    /// Formats as `(used/total_limit:duration_secs)`.
    ///
    /// `used` is the raw length of the timestamp queue. No expiry pass runs here, so it
    /// may include timestamps that have already left the window.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let used = self.timestamps.lock().len();
        let window_secs = self.duration.as_secs();
        write!(f, "({}/{}:{})", used, self.total_limit, window_secs)
    }
}