use std::sync::Mutex;
use crate::{SimpleRateLimitResult, SimpleRateLimitError, Uint, VerboseRateLimitResult, VerboseRateLimitError};
use crate::rate_limit::RateLimitCore;
/// Token-bucket rate limiter driven by caller-supplied ticks.
///
/// The configuration fields are fixed at construction; the mutable token
/// count and refill bookkeeping live behind a [`Mutex`].
pub struct TokenBucketCore {
    /// Upper bound on the number of tokens the bucket can hold.
    capacity: Uint,
    /// Number of ticks that make up one refill period.
    refill_interval: Uint,
    /// Tokens credited for every completed refill period.
    refill_amount: Uint,
    /// Mutable bucket state, guarded by a mutex.
    state: Mutex<TokenBucketCoreState>,
}
/// Mutable portion of a token bucket: the current token count and the tick
/// up to which refills have already been credited.
struct TokenBucketCoreState {
    /// Tokens currently available for acquisition.
    available: Uint,
    /// Tick of the most recent refill boundary that has been applied.
    last_refill_tick: Uint,
}
impl RateLimitCore for TokenBucketCore {
#[inline(always)]
fn try_acquire_at(&self, tick: Uint,tokens: Uint) -> SimpleRateLimitResult {
self.try_acquire_at(tick, tokens)
}
#[inline(always)]
fn try_acquire_verbose_at(&self, tick: Uint, tokens: Uint) -> VerboseRateLimitResult {
self.try_acquire_verbose_at(tick, tokens)
}
#[inline(always)]
fn capacity_remaining(&self, tick: Uint) -> Result<Uint, SimpleRateLimitError> {
self.capacity_remaining(tick)
}
#[inline(always)]
fn capacity_remaining_or_0(&self, tick: Uint) -> Uint {
self.capacity_remaining_or_0(tick)
}
}
impl TokenBucketCore {
pub fn new(capacity: Uint, refill_interval: Uint, refill_amount: Uint) -> Self {
assert!(capacity > 0, "capacity must be greater than 0");
assert!(refill_interval > 0, "refill_interval must be greater than 0");
assert!(refill_amount > 0, "refill_amount must be greater than 0");
TokenBucketCore {
capacity,
refill_interval,
refill_amount,
state: Mutex::new(TokenBucketCoreState {
available: capacity, last_refill_tick: 0,
}),
}
}
#[inline(always)]
pub fn try_acquire_at(&self, tick: Uint,tokens: Uint) -> SimpleRateLimitResult {
if tokens == 0 {
return Ok(());
}
if tokens > self.capacity {
return Err(SimpleRateLimitError::BeyondCapacity);
}
let mut state = match self.state.try_lock() {
Ok(guard) => guard,
Err(_) => return Err(SimpleRateLimitError::ContentionFailure),
};
if tick < state.last_refill_tick {
return Err(SimpleRateLimitError::ExpiredTick);
}
let elapsed_ticks = tick - state.last_refill_tick;
let refill_times = elapsed_ticks / self.refill_interval;
let total_refilled = refill_times.saturating_mul(self.refill_amount);
state.available = (state.available.saturating_add(total_refilled)).min(self.capacity);
if refill_times > 0 {
state.last_refill_tick = state.last_refill_tick + (refill_times * self.refill_interval);
}
if tokens <= state.available {
state.available -= tokens;
Ok(())
} else {
Err(SimpleRateLimitError::InsufficientCapacity)
}
}
#[inline(always)]
pub fn try_acquire_verbose_at(&self, tick: Uint, tokens: Uint) -> VerboseRateLimitResult {
if tokens == 0 {
return Ok(());
}
let mut state = self.state.try_lock()
.map_err(|_| VerboseRateLimitError::ContentionFailure)?;
if tick < state.last_refill_tick {
return Err(VerboseRateLimitError::ExpiredTick {
min_acceptable_tick: state.last_refill_tick,
});
}
if tokens > self.capacity {
return Err(VerboseRateLimitError::BeyondCapacity {
acquiring: tokens,
capacity: self.capacity,
});
}
let elapsed_ticks = tick - state.last_refill_tick;
let refill_times = elapsed_ticks / self.refill_interval;
let total_refilled = refill_times.saturating_mul(self.refill_amount);
state.available = (state.available + total_refilled).min(self.capacity);
if refill_times > 0 {
state.last_refill_tick += refill_times * self.refill_interval;
}
if tokens <= state.available {
state.available -= tokens;
Ok(())
} else {
let available = state.available;
let shortfall = tokens.saturating_sub(available);
debug_assert!(shortfall > 0);
let needed_refills = (shortfall + self.refill_amount - 1) / self.refill_amount; debug_assert!(needed_refills >= 1);
let next_refill_tick = state.last_refill_tick + self.refill_interval;
let retry_after_ticks =
(needed_refills - 1) * self.refill_interval + (next_refill_tick - tick);
Err(VerboseRateLimitError::InsufficientCapacity {
acquiring: tokens,
available,
retry_after_ticks,
})
}
}
#[inline]
pub fn tokens_in_bucket(&self, tick: Uint) -> Result<Uint, SimpleRateLimitError> {
self.capacity_remaining(tick)
}
#[inline(always)]
pub fn capacity_remaining(&self, tick: Uint) -> Result<Uint, SimpleRateLimitError> {
let mut state = match self.state.try_lock() {
Ok(guard) => guard,
Err(_) => return Err(SimpleRateLimitError::ContentionFailure),
};
if tick < state.last_refill_tick {
return Err(SimpleRateLimitError::ExpiredTick);
}
let elapsed_ticks = tick - state.last_refill_tick;
let refill_times = elapsed_ticks / self.refill_interval;
let total_refilled = refill_times.saturating_mul(self.refill_amount);
state.available = (state.available.saturating_add(total_refilled)).min(self.capacity);
if refill_times > 0 {
state.last_refill_tick = state.last_refill_tick + (refill_times * self.refill_interval);
}
Ok(state.available)
}
#[inline(always)]
pub fn capacity_remaining_or_0(&self, tick: Uint) -> Uint {
self.capacity_remaining(tick).unwrap_or(0)
}
#[inline(always)]
pub fn current_capacity(&self) -> Result<Uint, SimpleRateLimitError> {
let state = match self.state.try_lock() {
Ok(guard) => guard,
Err(_) => return Err(SimpleRateLimitError::ContentionFailure),
};
Ok(state.available)
}
#[inline(always)]
pub fn current_capacity_or_0(&self) -> Uint {
self.current_capacity().unwrap_or(0)
}
}
/// Plain-data description of a token bucket: the three values that
/// `TokenBucketCore::new` takes.
#[derive(Debug, Clone)]
pub struct TokenBucketCoreConfig {
    /// Maximum number of tokens the bucket can hold.
    pub capacity: Uint,
    /// Number of ticks in one refill period.
    pub refill_interval: Uint,
    /// Tokens credited per completed refill period.
    pub refill_amount: Uint,
}
impl TokenBucketCoreConfig {
pub fn new(capacity: Uint, refill_interval: Uint, refill_amount: Uint) -> Self {
Self {
capacity,
refill_interval,
refill_amount,
}
}
}
impl From<TokenBucketCoreConfig> for TokenBucketCore {
#[inline(always)]
fn from(config: TokenBucketCoreConfig) -> Self {
TokenBucketCore::new(config.capacity, config.refill_interval, config.refill_amount)
}
}