use std::time::Duration;
use std::marker::PhantomData;
use crate::limits::RateLimit;
use crate::{Uint, SimpleRateLimitResult, RateLimitResult, BuildResult};
use crate::precision::Precision;
use crate::time_source::TimeSource;
use crate::error::{RateLimitError, BuildError};
use rate_guard_core::cores::SlidingWindowCounterCore;
use rate_guard_core::SimpleRateLimitError;
/// Sliding-window-counter rate limiter.
///
/// Pairs a [`SlidingWindowCounterCore`] (which works in abstract ticks) with a time
/// source `T` and a tick precision `P` that converts between `Duration`s and ticks.
/// Instances are produced by [`SlidingWindowCounterBuilder`].
pub struct SlidingWindowCounter<P, T>
where
    P: Precision,
    T: TimeSource,
{
    /// Tick-based counter state that the `RateLimit` methods delegate to.
    core: SlidingWindowCounterCore,
    /// Supplies the current time; it is converted to ticks with `P` on every call.
    time_source: T,
    /// Zero-sized marker recording the precision type `P`.
    _precision: PhantomData<P>,
}
/// First stage of configuring a [`SlidingWindowCounter`].
///
/// Collects `capacity`, `bucket_duration` and `bucket_count`; call
/// [`SlidingWindowCounterBuilder::with_time`] next to supply a time source.
/// Setters consume the builder and return it, so a discarded value loses its
/// configuration — hence `#[must_use]`.
#[must_use = "builders have no effect unless they are configured further and built"]
#[derive(Debug)]
pub struct SlidingWindowCounterBuilder {
    /// Capacity passed to the core; required, and `build` rejects 0.
    capacity: Option<Uint>,
    /// Length of one bucket; required, and `build` rejects a zero duration.
    bucket_duration: Option<Duration>,
    /// Number of buckets; required, and `build` rejects 0.
    bucket_count: Option<Uint>,
}
/// Second stage of configuring a [`SlidingWindowCounter`]: the time source is fixed.
///
/// Call [`SlidingWindowCounterBuilderWithTime::with_precision`] to choose the tick
/// precision. Marked `#[must_use]` because dropping it discards the configuration
/// (including the moved-in time source).
#[must_use = "builders have no effect unless they are configured further and built"]
pub struct SlidingWindowCounterBuilderWithTime<T: TimeSource> {
    /// Capacity carried over from the first stage (may still be unset).
    capacity: Option<Uint>,
    /// Bucket length carried over from the first stage (may still be unset).
    bucket_duration: Option<Duration>,
    /// Bucket count carried over from the first stage (may still be unset).
    bucket_count: Option<Uint>,
    /// Time source the finished limiter will read `now()` from.
    time_source: T,
}
/// Final stage of configuring a [`SlidingWindowCounter`]: time source and precision
/// `P` are both fixed.
///
/// Call [`ConfiguredSlidingWindowCounterBuilder::build`] to validate and construct
/// the limiter. Marked `#[must_use]` because dropping it discards the configuration.
#[must_use = "builders have no effect unless `build()` is called"]
pub struct ConfiguredSlidingWindowCounterBuilder<P: Precision, T: TimeSource> {
    /// Capacity to validate in `build` (may still be unset).
    capacity: Option<Uint>,
    /// Bucket length to validate and convert to ticks in `build` (may still be unset).
    bucket_duration: Option<Duration>,
    /// Bucket count to validate in `build` (may still be unset).
    bucket_count: Option<Uint>,
    /// Time source moved into the finished limiter.
    time_source: T,
    /// Zero-sized marker recording the precision type `P`.
    _precision: PhantomData<P>,
}
/// Shows the time source and the precision's type name. The core counter state is
/// not printed; `finish_non_exhaustive` marks the omission with `..`.
impl<P, T> std::fmt::Debug for SlidingWindowCounter<P, T>
where
    P: Precision,
    T: TimeSource + std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let precision_name = std::any::type_name::<P>();
        let mut out = f.debug_struct("SlidingWindowCounter");
        out.field("time_source", &self.time_source);
        out.field("_precision", &precision_name);
        out.finish_non_exhaustive()
    }
}
/// Prints every field of the second-stage builder, including the time source.
impl<T> std::fmt::Debug for SlidingWindowCounterBuilderWithTime<T>
where
    T: TimeSource + std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field becomes a compile error here,
        // which is a reminder to extend this impl.
        let Self {
            capacity,
            bucket_duration,
            bucket_count,
            time_source,
        } = self;
        let mut out = f.debug_struct("SlidingWindowCounterBuilderWithTime");
        out.field("capacity", capacity);
        out.field("bucket_duration", bucket_duration);
        out.field("bucket_count", bucket_count);
        out.field("time_source", time_source);
        out.finish()
    }
}
/// Prints every field of the final-stage builder; the precision marker is rendered
/// as the type name of `P`.
impl<P, T> std::fmt::Debug for ConfiguredSlidingWindowCounterBuilder<P, T>
where
    P: Precision,
    T: TimeSource + std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            capacity,
            bucket_duration,
            bucket_count,
            time_source,
            _precision: _,
        } = self;
        let precision_name = std::any::type_name::<P>();
        let mut out = f.debug_struct("ConfiguredSlidingWindowCounterBuilder");
        out.field("capacity", capacity);
        out.field("bucket_duration", bucket_duration);
        out.field("bucket_count", bucket_count);
        out.field("time_source", time_source);
        out.field("_precision", &precision_name);
        out.finish()
    }
}
/// Each method reads the current time from the time source, converts it to a tick
/// with `P::to_ticks`, and delegates to the core counter at that tick.
impl<P, T> RateLimit for SlidingWindowCounter<P, T>
where
    P: Precision,
    T: TimeSource,
{
    #[inline(always)]
    fn try_acquire(&self, tokens: Uint) -> SimpleRateLimitResult {
        let now_tick = P::to_ticks(self.time_source.now());
        self.core.try_acquire_at(now_tick, tokens)
    }

    #[inline(always)]
    fn try_acquire_verbose(&self, tokens: Uint) -> RateLimitResult {
        let now_tick = P::to_ticks(self.time_source.now());
        let outcome = self.core.try_acquire_verbose_at(now_tick, tokens);
        // The core reports in ticks; `P::from_ticks` converts those values back
        // when building the public error.
        outcome.map_err(|core_err| RateLimitError::from_core_error(core_err, |t| P::from_ticks(t)))
    }

    #[inline(always)]
    fn capacity_remaining(&self) -> Result<Uint, SimpleRateLimitError> {
        let now_tick = P::to_ticks(self.time_source.now());
        self.core.capacity_remaining(now_tick)
    }
}
impl SlidingWindowCounterBuilder {
pub fn builder() -> Self {
Self {
capacity: None,
bucket_duration: None,
bucket_count: None,
}
}
pub fn capacity(mut self, capacity: Uint) -> Self {
self.capacity = Some(capacity);
self
}
pub fn bucket_duration(mut self, duration: Duration) -> Self {
self.bucket_duration = Some(duration);
self
}
pub fn bucket_count(mut self, count: Uint) -> Self {
self.bucket_count = Some(count);
self
}
pub fn with_time<T: TimeSource>(self, time_source: T) -> SlidingWindowCounterBuilderWithTime<T> {
SlidingWindowCounterBuilderWithTime {
capacity: self.capacity,
bucket_duration: self.bucket_duration,
bucket_count: self.bucket_count,
time_source,
}
}
}
impl Default for SlidingWindowCounterBuilder {
fn default() -> Self {
Self::builder()
}
}
impl<T: TimeSource> SlidingWindowCounterBuilderWithTime<T> {
pub fn with_precision<P: Precision>(self) -> ConfiguredSlidingWindowCounterBuilder<P, T> {
ConfiguredSlidingWindowCounterBuilder {
capacity: self.capacity,
bucket_duration: self.bucket_duration,
bucket_count: self.bucket_count,
time_source: self.time_source,
_precision: PhantomData,
}
}
}
impl<P: Precision, T: TimeSource> ConfiguredSlidingWindowCounterBuilder<P, T> {
    /// Validates the collected configuration and constructs the [`SlidingWindowCounter`].
    ///
    /// `bucket_duration` is converted to ticks with `P::to_ticks` and passed to
    /// `SlidingWindowCounterCore::new` together with `capacity` and `bucket_count`.
    ///
    /// # Errors
    ///
    /// - [`BuildError::MissingArgument`] if `capacity`, `bucket_duration` or
    ///   `bucket_count` was never set (checked in that order).
    /// - [`BuildError::InvalidArgument`] if `capacity` is 0, `bucket_duration` is zero,
    ///   `bucket_count` is 0, or `bucket_duration` converts to 0 ticks at precision `P`.
    pub fn build(self) -> BuildResult<SlidingWindowCounter<P, T>> {
        let capacity = self.capacity.ok_or(BuildError::MissingArgument("capacity"))?;
        let bucket_duration = self.bucket_duration.ok_or(BuildError::MissingArgument("bucket_duration"))?;
        let bucket_count = self.bucket_count.ok_or(BuildError::MissingArgument("bucket_count"))?;
        if capacity == 0 {
            return Err(BuildError::InvalidArgument {
                field: "capacity",
                reason: "must be greater than 0"
            });
        }
        if bucket_duration == Duration::ZERO {
            return Err(BuildError::InvalidArgument {
                field: "bucket_duration",
                reason: "must be greater than zero"
            });
        }
        if bucket_count == 0 {
            return Err(BuildError::InvalidArgument {
                field: "bucket_count",
                reason: "must be greater than 0"
            });
        }
        let bucket_ticks = P::to_ticks(bucket_duration);
        // A non-zero duration can still map to 0 ticks when it is finer than one
        // tick of `P` (presumably `to_ticks` truncates — confirm in `Precision`).
        // Reject it here rather than hand the core a zero-length bucket. This check
        // runs after the existing ones so their error precedence is unchanged.
        if bucket_ticks == 0 {
            return Err(BuildError::InvalidArgument {
                field: "bucket_duration",
                reason: "must be at least one tick at the selected precision"
            });
        }
        let core = SlidingWindowCounterCore::new(capacity, bucket_ticks, bucket_count);
        Ok(SlidingWindowCounter {
            core,
            time_source: self.time_source,
            _precision: PhantomData,
        })
    }
}
/// Alternate public name for [`SlidingWindowCounter`].
pub use SlidingWindowCounter as PrecisionSlidingWindowCounter;