//! Per-key rate limiting.
//!
//! A [`Bucket`], created with a [`BucketBuilder`], enforces a minimum delay
//! between hits and a maximum number of hits per time window, tracked
//! separately for every key. A rate-limited hit is either rejected with a
//! [`RateLimitInfo`] or, when the bucket is configured to await rate limits,
//! held back by sleeping.
//!
//! With the `cache` feature, a value can be cached per key and handed back
//! while that key is rate limited.
#![deny(clippy::all)]
#![deny(clippy::pedantic)]
#![deny(clippy::nursery)]
#![deny(clippy::cargo)]
#![deny(missing_docs)]
use std::{
collections::HashMap,
hash::Hash,
time::{Duration, Instant},
};
#[cfg(not(feature = "cache"))]
use std::marker::PhantomData;
#[cfg(feature = "cache")]
use std::future::Future;
pub mod multi_bucket;
#[cfg(all(feature = "tokio_0_2", not(feature = "tokio")))]
extern crate tokio_compat as tokio;
#[cfg(all(feature = "tokio_0_2", not(feature = "tokio")))]
use tokio::time::delay_for as sleep;
#[cfg(all(feature = "tokio"))]
use tokio::time::sleep;
/// What happened to a hit that ran into a rate limit.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RateLimitAction {
    /// The hit was held back by sleeping until the rate limit expired.
    Delayed,
    /// A delay could not be performed.
    ///
    /// NOTE(review): this file never constructs this variant; confirm its
    /// exact meaning with the code that produces it.
    FailedDelay,
    /// The hit was rejected without waiting.
    Cancelled,
}
/// A rate limiter that tracks hits separately for every `Key`.
///
/// Create one with [`BucketBuilder`]. With the `cache` feature, a `Value` can
/// be cached per key and handed back while that key is rate limited.
pub struct Bucket<Key, Value = ()>
where
    Key: Hash + PartialEq + Clone + Eq + Send + Sync,
    Value: Clone + Send,
{
    /// Delay and ticket-window settings shared by every key.
    pub(crate) ratelimit: RateLimitSpec,
    /// Per-key rate-limit state, created the first time a key is hit.
    pub(crate) tickets_for: HashMap<Key, RateLimitInstance<Value>>,
    /// Whether a rate-limited hit sleeps until allowed instead of being rejected.
    pub(crate) await_ratelimits: bool,
}
impl<Key: Hash + PartialEq + Clone + Eq + Send + Sync, Value: Clone + Send> Bucket<Key, Value> {
pub async fn hit_limit(&mut self, key: &Key) -> Option<RateLimitInfo<Value>> {
let now = Instant::now();
let Self {
tickets_for,
ratelimit,
..
} = self;
let ticket_owner = match tickets_for.get_mut(key) {
Some(bucket) => bucket,
None => tickets_for
.entry(key.clone())
.or_insert_with(|| RateLimitInstance::new(now)),
};
if let Some((timespan, limit)) = ratelimit.limit {
if (ticket_owner.tickets + 1) > limit {
if let Some(ratelimit) =
(ticket_owner.set_time + timespan).checked_duration_since(now)
{
return Self::rating(ticket_owner, ratelimit, self.await_ratelimits).await;
} else {
ticket_owner.tickets = 0;
ticket_owner.set_time = now;
#[cfg(feature = "cache")]
{
ticket_owner.cached_value = None;
}
}
}
}
if let Some(ratelimit) = ticket_owner
.last_time
.and_then(|x| (x + ratelimit.delay).checked_duration_since(now))
{
return Self::rating(ticket_owner, ratelimit, self.await_ratelimits).await;
} else {
ticket_owner.awaiting = ticket_owner.awaiting.saturating_sub(1);
ticket_owner.tickets += 1;
ticket_owner.is_first_try = true;
ticket_owner.last_time = Some(now);
#[cfg(feature = "cache")]
{
ticket_owner.cached_value = None;
}
}
None
}
#[cfg(feature = "cache")]
pub async fn hit_or_cache(
&mut self,
key: &Key,
function: impl Future<Output = Option<Value>> + Send,
) -> Option<Value> {
if let Some(cached_value) = self.hit_limit(key).await.and_then(|i| i.cached) {
Some(cached_value)
} else {
let to_cache = function.await?;
let value = to_cache.clone();
self.add_cache_value(key, value).await;
Some(to_cache)
}
}
#[cfg(feature = "cache")]
pub async fn add_cache_value(&mut self, key: &Key, value: Value) {
if let Some(instance) = self.tickets_for.get_mut(key) {
instance.cached_value = Some(value);
}
}
async fn rating(
ticket_owner: &mut RateLimitInstance<Value>,
ratelimit: Duration,
await_ratelimits: bool,
) -> Option<RateLimitInfo<Value>> {
let was_first_try = ticket_owner.is_first_try;
let action = if await_ratelimits {
RateLimitAction::Delayed
} else {
RateLimitAction::Cancelled
};
if let RateLimitAction::Delayed = action {
sleep(ratelimit).await;
return None;
}
Some(RateLimitInfo {
rate_limit: ratelimit,
active_delays: ticket_owner.awaiting,
#[cfg(feature = "cache")]
cached: ticket_owner.cached_value.clone(),
action,
is_first_try: was_first_try,
#[cfg(not(feature = "cache"))]
phantom: PhantomData,
})
}
}
/// Rate-limit settings applied to every key of a `Bucket`.
pub(crate) struct RateLimitSpec {
    /// Minimum time between two allowed hits for the same key.
    pub(crate) delay: Duration,
    /// Optional ticket window as `(window length, hits allowed per window)`.
    pub(crate) limit: Option<(Duration, u32)>,
}
/// Rate-limit bookkeeping for a single key of a `Bucket`.
pub(crate) struct RateLimitInstance<Value> {
    /// When the key's most recent hit was allowed; `None` before the first one.
    pub(crate) last_time: Option<Instant>,
    /// Start of the key's current ticket window.
    pub(crate) set_time: Instant,
    /// Number of hits counted in the current ticket window.
    pub(crate) tickets: u32,
    /// Counter reported as `RateLimitInfo::active_delays`.
    pub(crate) awaiting: u32,
    /// Flag reported as `RateLimitInfo::is_first_try`; set whenever a hit is allowed.
    pub(crate) is_first_try: bool,
    /// Value cached for this key, cleared when a hit is allowed or the window restarts.
    #[cfg(feature = "cache")]
    pub(crate) cached_value: Option<Value>,
    /// Keeps the `Value` parameter in use when the `cache` feature is disabled.
    #[cfg(not(feature = "cache"))]
    phantom: PhantomData<Value>,
}
impl<Value> RateLimitInstance<Value> {
    /// Fresh state for a key first seen at `created_at`.
    ///
    /// The ticket window starts at `created_at`, no hits are recorded yet, the
    /// first-try flag is set, and nothing is cached.
    const fn new(created_at: Instant) -> Self {
        Self {
            set_time: created_at,
            last_time: None,
            is_first_try: true,
            tickets: 0,
            awaiting: 0,
            #[cfg(not(feature = "cache"))]
            phantom: PhantomData,
            #[cfg(feature = "cache")]
            cached_value: None,
        }
    }
}
/// Details about a hit that was rate limited.
#[derive(Clone, Debug)]
pub struct RateLimitInfo<Value: Clone> {
    /// Time remaining before the limit that caught this hit expires.
    pub rate_limit: Duration,
    /// The key's `awaiting` counter at the time of this hit.
    pub active_delays: u32,
    /// Whether the key's first-try flag was set when this hit was rate limited.
    ///
    /// The flag is set again every time a hit for the key is allowed through.
    pub is_first_try: bool,
    /// What was done with the hit.
    pub action: RateLimitAction,
    /// The value cached for the key when the hit was rate limited, if any.
    #[cfg(feature = "cache")]
    pub cached: Option<Value>,
    // Keeps the `Value` parameter in use when the `cache` feature is disabled.
    #[cfg(not(feature = "cache"))]
    phantom: PhantomData<Value>,
}
/// Builder for [`Bucket`].
///
/// The settings are a delay between hits, a ticket window with a hit limit,
/// and whether rate-limited hits are awaited or rejected.
#[derive(Clone, Debug)]
pub struct BucketBuilder {
    /// Minimum time between two allowed hits for the same key.
    pub(crate) delay: Duration,
    /// Length of the ticket window in which at most `limit` hits are allowed.
    pub(crate) time_span: Duration,
    /// Maximum number of hits per key within one ticket window.
    pub(crate) limit: u32,
    /// Whether rate-limited hits sleep until allowed instead of being rejected.
    pub(crate) await_ratelimits: bool,
}
impl Default for BucketBuilder {
    /// No delay, a zero-length ticket window, one hit per window, and
    /// rate-limited hits rejected rather than awaited.
    fn default() -> Self {
        let no_time = Duration::default();
        Self {
            limit: 1,
            await_ratelimits: false,
            delay: no_time,
            time_span: no_time,
        }
    }
}
impl BucketBuilder {
    /// Creates a builder with the default settings.
    ///
    /// The defaults are no delay between hits, a zero-length ticket window, a
    /// limit of one hit per window, and rate-limited hits rejected rather than
    /// awaited.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the minimum time, in whole seconds, between two allowed hits for
    /// the same key.
    #[inline]
    pub fn delay(&mut self, secs: u64) -> &mut Self {
        self.delay = Duration::from_secs(secs);
        self
    }

    /// Sets the length, in whole seconds, of the ticket window in which at most
    /// [`limit`](Self::limit) hits per key are allowed.
    #[inline]
    pub fn time_span(&mut self, secs: u64) -> &mut Self {
        self.time_span = Duration::from_secs(secs);
        self
    }

    /// Sets how many hits per key are allowed within one ticket window.
    #[inline]
    pub fn limit(&mut self, n: u32) -> &mut Self {
        self.limit = n;
        self
    }

    /// Chooses what happens to a rate-limited hit.
    ///
    /// `true` makes the bucket sleep until the limit expires; `false` rejects
    /// the hit with a [`RateLimitInfo`].
    #[inline]
    pub fn await_ratelimits(&mut self, is_awaiting: bool) -> &mut Self {
        self.await_ratelimits = is_awaiting;
        self
    }

    /// Creates an empty [`Bucket`] from the current settings.
    ///
    /// The builder itself is not modified and can build further buckets.
    #[inline]
    #[must_use]
    pub fn build<Key, Value>(&mut self) -> Bucket<Key, Value>
    where
        Key: Hash + PartialEq + Clone + Eq + Send + Sync,
        Value: Clone + Send,
    {
        Bucket {
            ratelimit: RateLimitSpec {
                delay: self.delay,
                limit: Some((self.time_span, self.limit)),
            },
            tickets_for: HashMap::new(),
            await_ratelimits: self.await_ratelimits,
        }
    }
}
// `test_caching` exercises `hit_or_cache`, which only exists with the `cache`
// feature, so the module is compiled only when that feature is enabled.
#[cfg(all(test, feature = "cache"))]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_caching() {
        #[derive(Clone, Hash, PartialEq, Eq)]
        enum Route {
            GetUser(u64),
        }

        // Two hits per 60 s window, with at least 5 s between hits.
        let mut bucket: Bucket<Route, String> =
            BucketBuilder::new().limit(2).time_span(60).delay(5).build();

        // The first hit is allowed, so the future runs and its result is cached.
        let value = bucket
            .hit_or_cache(
                &Route::GetUser(1),
                async move { Some("success1".to_string()) },
            )
            .await;
        assert_eq!(value, Some("success1".to_string()));

        // An immediate second hit breaks the 5 s delay, so the cached value is
        // returned instead of the new future's result.
        let value = bucket
            .hit_or_cache(
                &Route::GetUser(1),
                async move { Some("success2".to_string()) },
            )
            .await;
        assert_eq!(value, Some("success1".to_string()));
    }
}