use std::time::Duration;
use std::marker::PhantomData;
use crate::limits::RateLimit;
use crate::{Uint, SimpleRateLimitResult, RateLimitResult, BuildResult};
use crate::precision::Precision;
use crate::time_source::TimeSource;
use crate::error::{RateLimitError, BuildError};
use rate_guard_core::cores::TokenBucketCore;
use rate_guard_core::SimpleRateLimitError;
/// A token-bucket rate limiter driven by a pluggable time source.
///
/// `P` selects the tick precision used to convert readings from the time
/// source (and the configured refill interval) into the integer ticks consumed
/// by the underlying [`TokenBucketCore`]; `T` supplies the current time.
///
/// Construct one through [`TokenBucketBuilder`].
pub struct TokenBucket<P, T>
where
    P: Precision,
    T: TimeSource,
{
    /// Tick-based bucket state; every limiting decision is delegated to it.
    core: TokenBucketCore,
    /// Clock read on every rate-limit query.
    time_source: T,
    /// Zero-sized marker that binds the precision type without storing a value.
    _precision: PhantomData<P>,
}
/// Entry point for configuring a [`TokenBucket`].
///
/// Collects the capacity and refill parameters. Every field starts unset, and
/// nothing is validated until the final builder stage's `build` is called.
/// Attach a time source with [`TokenBucketBuilder::with_time`] to continue.
///
/// The builder is `Clone`, so a partially-configured template can be reused
/// to create several buckets.
#[derive(Debug, Clone)]
pub struct TokenBucketBuilder {
    /// Bucket capacity in tokens; required, and must be non-zero at build time.
    capacity: Option<Uint>,
    /// Tokens refilled per `refill_every` interval; required, non-zero, and
    /// not larger than `capacity` at build time.
    refill_amount: Option<Uint>,
    /// Interval between refills; required and non-zero at build time.
    refill_every: Option<Duration>,
}
/// Second builder stage: the bucket parameters plus the chosen time source.
///
/// Produced by [`TokenBucketBuilder::with_time`]; pick a tick precision with
/// [`TokenBucketBuilderWithTime::with_precision`] to reach the final stage.
pub struct TokenBucketBuilderWithTime<T>
where
    T: TimeSource,
{
    /// Capacity carried over from the first stage (may still be unset).
    capacity: Option<Uint>,
    /// Refill amount carried over from the first stage (may still be unset).
    refill_amount: Option<Uint>,
    /// Refill interval carried over from the first stage (may still be unset).
    refill_every: Option<Duration>,
    /// Clock the finished bucket will read.
    time_source: T,
}
/// Final builder stage: parameters, time source, and tick precision `P`.
///
/// Call `build` to validate the parameters and obtain a [`TokenBucket`].
pub struct ConfiguredTokenBucketBuilder<P, T>
where
    P: Precision,
    T: TimeSource,
{
    /// Capacity carried over from earlier stages (may still be unset).
    capacity: Option<Uint>,
    /// Refill amount carried over from earlier stages (may still be unset).
    refill_amount: Option<Uint>,
    /// Refill interval carried over from earlier stages (may still be unset).
    refill_every: Option<Duration>,
    /// Clock the finished bucket will read.
    time_source: T,
    /// Zero-sized marker recording the chosen precision type.
    _precision: PhantomData<P>,
}
/// Debug output shows the time source and the precision type name; the core
/// state is omitted, which is signalled with a trailing `..`.
impl<P, T> std::fmt::Debug for TokenBucket<P, T>
where
    P: Precision,
    T: TimeSource + std::fmt::Debug,
{
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let precision_name = std::any::type_name::<P>();
        let mut out = formatter.debug_struct("TokenBucket");
        out.field("time_source", &self.time_source);
        out.field("_precision", &precision_name);
        out.finish_non_exhaustive()
    }
}
/// Debug output lists every field of the second builder stage.
impl<T> std::fmt::Debug for TokenBucketBuilderWithTime<T>
where
    T: TimeSource + std::fmt::Debug,
{
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field breaks compilation here,
        // so it cannot be silently left out of the debug output.
        let Self {
            capacity,
            refill_amount,
            refill_every,
            time_source,
        } = self;
        out.debug_struct("TokenBucketBuilderWithTime")
            .field("capacity", capacity)
            .field("refill_amount", refill_amount)
            .field("refill_every", refill_every)
            .field("time_source", time_source)
            .finish()
    }
}
/// Debug output lists every field of the final builder stage; the precision
/// marker is rendered as the type name of `P`.
impl<P, T> std::fmt::Debug for ConfiguredTokenBucketBuilder<P, T>
where
    P: Precision,
    T: TimeSource + std::fmt::Debug,
{
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            capacity,
            refill_amount,
            refill_every,
            time_source,
            ..
        } = self;
        let precision_name = std::any::type_name::<P>();
        let mut out = formatter.debug_struct("ConfiguredTokenBucketBuilder");
        out.field("capacity", capacity);
        out.field("refill_amount", refill_amount);
        out.field("refill_every", refill_every);
        out.field("time_source", time_source);
        out.field("_precision", &precision_name);
        out.finish()
    }
}
/// Time-aware [`RateLimit`] implementation.
///
/// Every method reads `time_source.now()`, converts the reading to ticks with
/// `P::to_ticks`, and forwards to the matching tick-based method on the core.
impl<P, T> RateLimit for TokenBucket<P, T>
where
    P: Precision,
    T: TimeSource,
{
    /// Attempts to take `tokens` from the bucket at the current tick.
    #[inline(always)]
    fn try_acquire(&self, tokens: Uint) -> SimpleRateLimitResult {
        let now_tick = P::to_ticks(self.time_source.now());
        self.core.try_acquire_at(now_tick, tokens)
    }

    /// Like `try_acquire`, but returns a detailed error: the core's error is
    /// converted into a [`RateLimitError`], with tick values mapped back
    /// through `P::from_ticks`.
    #[inline(always)]
    fn try_acquire_verbose(&self, tokens: Uint) -> RateLimitResult {
        let now_tick = P::to_ticks(self.time_source.now());
        self.core
            .try_acquire_verbose_at(now_tick, tokens)
            .map_err(|core_err| RateLimitError::from_core_error(core_err, P::from_ticks))
    }

    /// Queries the core's remaining capacity at the current tick.
    #[inline(always)]
    fn capacity_remaining(&self) -> Result<Uint, SimpleRateLimitError> {
        let now_tick = P::to_ticks(self.time_source.now());
        self.core.capacity_remaining(now_tick)
    }
}
impl TokenBucketBuilder {
    /// Creates a builder with every parameter unset.
    ///
    /// `capacity`, `refill_amount` and `refill_every` must all be supplied
    /// before the final stage's `build` call can succeed.
    #[must_use]
    pub fn builder() -> Self {
        Self {
            capacity: None,
            refill_amount: None,
            refill_every: None,
        }
    }

    /// Sets the bucket capacity; it must be non-zero when the bucket is built.
    /// Calling this again overwrites the previous value.
    // `self` is consumed: ignoring the return value would silently discard the
    // whole configuration, so the compiler is asked to warn about that.
    #[must_use = "builder methods consume `self`; use the returned builder"]
    pub fn capacity(mut self, capacity: Uint) -> Self {
        self.capacity = Some(capacity);
        self
    }

    /// Sets how many tokens are refilled per `refill_every` interval; it must
    /// be non-zero and not exceed `capacity` when the bucket is built.
    #[must_use = "builder methods consume `self`; use the returned builder"]
    pub fn refill_amount(mut self, amount: Uint) -> Self {
        self.refill_amount = Some(amount);
        self
    }

    /// Sets the interval between refills; it must be non-zero when the bucket
    /// is built.
    #[must_use = "builder methods consume `self`; use the returned builder"]
    pub fn refill_every(mut self, interval: Duration) -> Self {
        self.refill_every = Some(interval);
        self
    }

    /// Attaches the time source the bucket will read and moves to the next
    /// builder stage. Parameters set so far are carried over unchanged.
    #[must_use = "builder methods consume `self`; use the returned builder"]
    pub fn with_time<T: TimeSource>(self, time_source: T) -> TokenBucketBuilderWithTime<T> {
        TokenBucketBuilderWithTime {
            capacity: self.capacity,
            refill_amount: self.refill_amount,
            refill_every: self.refill_every,
            time_source,
        }
    }
}
impl Default for TokenBucketBuilder {
fn default() -> Self {
Self::builder()
}
}
impl<T: TimeSource> TokenBucketBuilderWithTime<T> {
    /// Fixes the tick precision `P`, producing the final builder stage.
    ///
    /// All parameters and the time source are moved over unchanged. Nothing is
    /// validated here; validation happens in `build` on the returned builder.
    // `self` is consumed: dropping the result would discard the entire
    // configuration, so ask the compiler to warn about unused results.
    #[must_use = "builder methods consume `self`; use the returned builder"]
    pub fn with_precision<P: Precision>(self) -> ConfiguredTokenBucketBuilder<P, T> {
        ConfiguredTokenBucketBuilder {
            capacity: self.capacity,
            refill_amount: self.refill_amount,
            refill_every: self.refill_every,
            time_source: self.time_source,
            _precision: PhantomData,
        }
    }
}
impl<P: Precision, T: TimeSource> ConfiguredTokenBucketBuilder<P, T> {
    /// Validates the collected parameters and constructs the [`TokenBucket`].
    ///
    /// The refill interval is converted to ticks with `P::to_ticks` before it
    /// is handed to [`TokenBucketCore::new`] together with the capacity and
    /// refill amount.
    ///
    /// # Errors
    ///
    /// Returns [`BuildError::MissingArgument`] if `capacity`, `refill_amount`
    /// or `refill_every` was never set. Checks run in this order, and the
    /// first failure is reported as [`BuildError::InvalidArgument`]:
    /// - `capacity` is zero;
    /// - `refill_amount` is zero;
    /// - `refill_every` is zero;
    /// - `refill_amount` exceeds `capacity`;
    /// - `refill_every` converts to zero ticks at precision `P` (i.e. it is
    ///   shorter than one tick).
    pub fn build(self) -> BuildResult<TokenBucket<P, T>> {
        let capacity = self.capacity.ok_or(BuildError::MissingArgument("capacity"))?;
        let refill_amount = self.refill_amount.ok_or(BuildError::MissingArgument("refill_amount"))?;
        let refill_every = self.refill_every.ok_or(BuildError::MissingArgument("refill_every"))?;
        if capacity == 0 {
            return Err(BuildError::InvalidArgument {
                field: "capacity",
                reason: "must be greater than 0"
            });
        }
        if refill_amount == 0 {
            return Err(BuildError::InvalidArgument {
                field: "refill_amount",
                reason: "must be greater than 0"
            });
        }
        if refill_every == Duration::ZERO {
            return Err(BuildError::InvalidArgument {
                field: "refill_every",
                reason: "must be greater than zero"
            });
        }
        if refill_amount > capacity {
            return Err(BuildError::InvalidArgument {
                field: "refill_amount",
                reason: "should not exceed capacity for optimal rate limiting behavior"
            });
        }
        // Re-check the interval after conversion. A non-zero Duration smaller
        // than one tick (e.g. 500µs at millisecond precision) may convert to
        // zero ticks. It would pass the `Duration::ZERO` check above and hand
        // the core a zero refill interval.
        let refill_every_ticks = P::to_ticks(refill_every);
        if refill_every_ticks == 0 {
            return Err(BuildError::InvalidArgument {
                field: "refill_every",
                reason: "must be at least one tick at the configured precision"
            });
        }
        let core = TokenBucketCore::new(capacity, refill_every_ticks, refill_amount);
        Ok(TokenBucket {
            core,
            time_source: self.time_source,
            _precision: PhantomData,
        })
    }
}