use crate::{Bucket, BucketBuilder, RateLimitInfo};
use std::collections::HashMap;
use std::hash::Hash;
use std::mem::Discriminant;
use tokio::sync::Mutex;
#[cfg(feature = "cache")]
use std::collections::hash_map::Entry;
#[cfg(feature = "cache")]
use std::future::Future;
/// Maps a key to the rate-limit [`Bucket`] that should govern it.
///
/// Returning `None` means the key produces no bucket; for example
/// [`LimitedRequests::insert_group`] and [`CachedLimitedEnums::insert_enum`]
/// hand such keys back as `Err` instead of inserting anything.
pub trait ToBucket<
    Key: Hash + PartialEq + Clone + Eq + Send + Sync,
    Value: Clone + Send = (),
>
{
    /// Builds the bucket for `self`, or `None` if this key has no bucket.
    fn to_bucket(&self) -> Option<Bucket<Key, Value>>;
}
/// Builder returned by [`LimitedRequests::build_group`] for configuring one
/// group's [`Bucket`] by hand.
///
/// Nothing is stored in the parent collection until
/// [`LimitedRequestsModifier::build`] is called.
pub struct LimitedRequestsModifier<
    'a,
    GroupKey: Hash + PartialEq + Clone + Eq + Send + Sync,
    BucketKey: Hash + PartialEq + Clone + Eq + Send + Sync = GroupKey,
    Value: Clone + Send + Sync = (),
> {
    /// Key under which the finished bucket is inserted.
    group_key: GroupKey,
    /// Collection that receives the bucket when `build` runs.
    multi_bucket: &'a mut LimitedRequests<GroupKey, BucketKey, Value>,
    /// Accumulates the settings forwarded by the setter methods.
    bucket_builder: BucketBuilder,
}
impl<'a, GroupKey, BucketKey, Value> LimitedRequestsModifier<'a, GroupKey, BucketKey, Value>
where
GroupKey: Hash + PartialEq + Clone + Eq + Send + Sync,
BucketKey: Hash + PartialEq + Clone + Eq + Send + Sync,
Value: Clone + Send + Sync,
{
fn new(
group_key: GroupKey,
multi_bucket: &'a mut LimitedRequests<GroupKey, BucketKey, Value>,
) -> Self {
Self {
group_key,
multi_bucket,
bucket_builder: BucketBuilder::default(),
}
}
#[inline]
#[must_use]
pub fn delay(mut self, secs: u64) -> Self {
self.bucket_builder.delay(secs);
self
}
#[inline]
#[must_use]
pub fn time_span(mut self, secs: u64) -> Self {
self.bucket_builder.time_span(secs);
self
}
#[inline]
#[must_use]
pub fn limit(mut self, n: u32) -> Self {
self.bucket_builder.limit(n);
self
}
#[inline]
#[must_use]
pub fn await_ratelimits(mut self, is_awaiting: bool) -> Self {
self.bucket_builder.await_ratelimits(is_awaiting);
self
}
pub fn build(mut self) {
self.multi_bucket
.insert_bucket(self.group_key, self.bucket_builder.build());
}
}
/// Rate-limit buckets organised by group: each `GroupKey` owns exactly one
/// [`Bucket`] (behind a `tokio::sync::Mutex`) whose entries are addressed by
/// `BucketKey` and may cache a `Value`.
pub struct LimitedRequests<
    GroupKey: Hash + PartialEq + Clone + Eq + Send + Sync,
    BucketKey: Hash + PartialEq + Clone + Eq + Send + Sync = GroupKey,
    Value: Clone + Send + Sync = (),
> {
    /// One bucket per registered group.
    buckets: HashMap<GroupKey, Mutex<Bucket<BucketKey, Value>>>,
}
impl<GroupKey, BucketKey, Value> LimitedRequests<GroupKey, BucketKey, Value>
where
    GroupKey: Hash + PartialEq + Clone + Eq + Send + Sync,
    BucketKey: Hash + PartialEq + Clone + Eq + Send + Sync,
    Value: Clone + Send + Sync,
{
    /// Creates an empty collection with no groups registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            buckets: HashMap::new(),
        }
    }

    /// Registers the bucket produced by [`ToBucket::to_bucket`] under
    /// `bucket_key`, replacing any bucket already stored for that key.
    ///
    /// # Errors
    ///
    /// Returns `bucket_key` back as `Err` when `to_bucket` yields `None`;
    /// nothing is inserted in that case.
    // `if let … else` is kept over `map_or_else`, which would have to move
    // `bucket_key` into both closures.
    #[allow(clippy::option_if_let_else)]
    pub fn insert_group(&mut self, bucket_key: GroupKey) -> Result<(), GroupKey>
    where
        GroupKey: ToBucket<BucketKey, Value>,
    {
        if let Some(bucket) = bucket_key.to_bucket() {
            self.buckets.insert(bucket_key, Mutex::new(bucket));
            Ok(())
        } else {
            Err(bucket_key)
        }
    }

    /// Starts a builder that configures a bucket by hand and stores it under
    /// `bucket_key` once [`LimitedRequestsModifier::build`] is called.
    #[must_use = "the group is only inserted once `.build()` is called on the returned modifier"]
    pub fn build_group(
        &mut self,
        bucket_key: GroupKey,
    ) -> LimitedRequestsModifier<'_, GroupKey, BucketKey, Value> {
        LimitedRequestsModifier::new(bucket_key, self)
    }

    /// Calls [`Self::insert_group`] for every key in `bucket_keys`, in
    /// iteration order. Successful insertions are kept even when other keys
    /// fail.
    ///
    /// # Errors
    ///
    /// Returns every key whose `to_bucket` yielded `None`, in iteration order.
    // The former unused generic parameter `IterItem` could never be inferred,
    // which made every call without an explicit turbofish fail to compile.
    pub fn insert_groups(
        &mut self,
        bucket_keys: impl IntoIterator<Item = GroupKey>,
    ) -> Result<(), Vec<GroupKey>>
    where
        GroupKey: ToBucket<BucketKey, Value>,
    {
        let erroneous_keys: Vec<GroupKey> = bucket_keys
            .into_iter()
            .filter_map(|key| self.insert_group(key).err())
            .collect();
        if erroneous_keys.is_empty() {
            Ok(())
        } else {
            Err(erroneous_keys)
        }
    }

    /// Stores `bucket` as the bucket of `group_key`, replacing any previous one.
    pub fn insert_bucket(&mut self, group_key: GroupKey, bucket: Bucket<BucketKey, Value>) {
        self.buckets.insert(group_key, Mutex::new(bucket));
    }

    /// Inserts each bucket from `buckets` under the same `group_key`.
    ///
    /// A group holds exactly one bucket, so each insertion replaces the
    /// previous one. Only the last bucket yielded is kept, and nothing is
    /// inserted if `buckets` is empty.
    // NOTE(review): keeping only the last bucket is probably not what callers
    // expect from the name — confirm the intended semantics before relying on it.
    pub fn insert_buckets(
        &mut self,
        group_key: &GroupKey,
        buckets: impl IntoIterator<Item = Bucket<BucketKey, Value>>,
    ) {
        buckets
            .into_iter()
            .for_each(|bucket| self.insert_bucket(group_key.clone(), bucket));
    }

    /// Forwards to `Bucket::hit_limit` for `value_key` in the bucket of group
    /// `bucket_key`, holding that group's mutex while it runs.
    ///
    /// Returns `None` when no group is registered for `bucket_key`; otherwise
    /// whatever the bucket reports. [`Self::cache_or`] treats `Some` as
    /// "rate-limited".
    pub async fn hit_limit(
        &self,
        bucket_key: &GroupKey,
        value_key: &BucketKey,
    ) -> Option<RateLimitInfo<Value>> {
        let bucket = self.buckets.get(bucket_key)?;
        bucket.lock().await.hit_limit(value_key).await
    }

    /// Returns the cached value if the bucket of `group_key` reports a rate
    /// limit for `bucket_key`. Otherwise awaits `function` and, if it yields
    /// `Some`, stores a clone via `add_cache_value` before returning it.
    ///
    /// Returns `None` without polling `function` when `group_key` has no bucket.
    #[cfg(feature = "cache")]
    pub async fn cache_or(
        &self,
        group_key: &GroupKey,
        bucket_key: &BucketKey,
        function: impl Future<Output = Option<Value>> + Send,
    ) -> Option<Value> {
        let group = self.buckets.get(group_key)?;
        Self::cached_or_compute(group, bucket_key, function).await
    }

    /// Like [`Self::cache_or`], but first registers a bucket for `group_key`
    /// via [`ToBucket::to_bucket`] when none exists yet.
    ///
    /// Returns `None` without polling `function` when the group is missing and
    /// `to_bucket` yields `None`.
    #[cfg(feature = "cache")]
    pub async fn mut_cache_or(
        &mut self,
        group_key: &GroupKey,
        bucket_key: &BucketKey,
        function: impl Future<Output = Option<Value>> + Send,
    ) -> Option<Value>
    where
        GroupKey: ToBucket<BucketKey, Value>,
    {
        // Plain lookup first so the common "group already exists" path does
        // not clone `group_key` just to build an `Entry`.
        let group = match self.buckets.get(group_key) {
            Some(group) => group,
            None => match self.buckets.entry(group_key.clone()) {
                Entry::Occupied(entry) => entry.into_mut(),
                Entry::Vacant(entry) => {
                    let value = entry.key().to_bucket()?;
                    entry.insert(Mutex::new(value))
                }
            },
        };
        Self::cached_or_compute(group, bucket_key, function).await
    }

    /// Shared cache-or-compute step once the group's bucket has been located.
    #[cfg(feature = "cache")]
    async fn cached_or_compute(
        group: &Mutex<Bucket<BucketKey, Value>>,
        bucket_key: &BucketKey,
        function: impl Future<Output = Option<Value>> + Send,
    ) -> Option<Value> {
        // The guard is a temporary of this statement, so the lock is held only
        // while `hit_limit` runs, not while `function` is awaited below.
        let info = group.lock().await.hit_limit(bucket_key).await;
        if let Some(info) = info {
            return info.cached;
        }
        // NOTE(review): two concurrent callers that both pass `hit_limit` will
        // both await their `function` and both call `add_cache_value`.
        let to_cache = function.await;
        if let Some(ref value) = to_cache {
            group
                .lock()
                .await
                .add_cache_value(bucket_key, value.clone())
                .await;
        }
        to_cache
    }
}
impl<GroupKey, BucketKey, Value> Default for LimitedRequests<GroupKey, BucketKey, Value>
where
GroupKey: Hash + PartialEq + Clone + Eq + Send + Sync,
BucketKey: Hash + PartialEq + Clone + Eq + Send + Sync,
Value: Clone + Send + Sync,
{
fn default() -> Self {
Self::new()
}
}
/// Builder returned by [`CachedLimitedEnums::build_group`] for configuring the
/// [`Bucket`] shared by all values with the same discriminant as `key`.
///
/// Nothing is stored until [`CachedLimitedEnumsModifier::build`] is called.
pub struct CachedLimitedEnumsModifier<
    'a,
    Key: Hash + PartialEq + Clone + Eq + Send + Sync,
    Value: Clone + Send + Sync,
> {
    /// Value whose discriminant selects the bucket slot.
    key: &'a Key,
    /// Collection that receives the bucket when `build` runs.
    multi_bucket: &'a mut CachedLimitedEnums<Key, Value>,
    /// Accumulates the settings forwarded by the setter methods.
    bucket_builder: BucketBuilder,
}
// The previously declared lifetime `'b` was never used (clippy
// `extra_unused_lifetimes`), so only `'a` remains.
impl<'a, Key, Value> CachedLimitedEnumsModifier<'a, Key, Value>
where
    Key: Hash + PartialEq + Clone + Eq + Send + Sync,
    Value: Clone + Send + Sync,
{
    /// Creates a modifier targeting `key`'s discriminant in `multi_bucket`,
    /// starting from `BucketBuilder::default()`.
    fn new(key: &'a Key, multi_bucket: &'a mut CachedLimitedEnums<Key, Value>) -> Self {
        Self {
            key,
            multi_bucket,
            bucket_builder: BucketBuilder::default(),
        }
    }

    /// Forwards `secs` to `BucketBuilder::delay` for the pending bucket.
    #[inline]
    #[must_use]
    pub fn delay(mut self, secs: u64) -> Self {
        self.bucket_builder.delay(secs);
        self
    }

    /// Forwards `secs` to `BucketBuilder::time_span` for the pending bucket.
    #[inline]
    #[must_use]
    pub fn time_span(mut self, secs: u64) -> Self {
        self.bucket_builder.time_span(secs);
        self
    }

    /// Forwards `n` to `BucketBuilder::limit` for the pending bucket.
    #[inline]
    #[must_use]
    pub fn limit(mut self, n: u32) -> Self {
        self.bucket_builder.limit(n);
        self
    }

    /// Forwards `is_awaiting` to `BucketBuilder::await_ratelimits` for the
    /// pending bucket.
    // NOTE(review): presumably controls whether a limited hit waits instead of
    // returning immediately — confirm against `BucketBuilder`.
    #[inline]
    #[must_use]
    pub fn await_ratelimits(mut self, is_awaiting: bool) -> Self {
        self.bucket_builder.await_ratelimits(is_awaiting);
        self
    }

    /// Builds the configured bucket and stores it for `key`'s discriminant,
    /// replacing any bucket already registered for that variant.
    pub fn build(mut self) {
        self.multi_bucket
            .insert_bucket(self.key, self.bucket_builder.build());
    }
}
/// Rate-limit buckets keyed by enum *variant*.
///
/// Every value of `Key` with the same [`std::mem::discriminant`] shares one
/// [`Bucket`] (behind a `tokio::sync::Mutex`). The full `Key` value is what
/// addresses entries inside that bucket.
pub struct CachedLimitedEnums<Key, Value>
where
    Key: Hash + PartialEq + Clone + Eq + Send + Sync,
    Value: Clone + Send + Sync,
{
    /// One bucket per discriminant (variant) of `Key`.
    buckets: HashMap<Discriminant<Key>, Mutex<Bucket<Key, Value>>>,
}

// Implemented by hand rather than derived. `#[derive(Default)]` adds
// `Key: Default` and `Value: Default` bounds, so enum keys without a `Default`
// impl (the normal case) could not call `CachedLimitedEnums::default()`.
impl<Key, Value> Default for CachedLimitedEnums<Key, Value>
where
    Key: Hash + PartialEq + Clone + Eq + Send + Sync,
    Value: Clone + Send + Sync,
{
    fn default() -> Self {
        Self {
            buckets: HashMap::new(),
        }
    }
}
// `clippy::mem_discriminant_non_enum` is allowed on the methods below because
// `Key` is generic, so clippy cannot prove it is an enum. Callers are expected
// to use enum keys; for non-enum types `std::mem::discriminant` returns an
// unspecified value.
impl<Key, Value> CachedLimitedEnums<Key, Value>
where
    Key: Hash + PartialEq + Clone + Eq + Send + Sync,
    Value: Clone + Send + Sync,
{
    /// Creates an empty collection with no variants registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            buckets: HashMap::new(),
        }
    }

    /// Registers `key.to_bucket()` for `key`'s discriminant, replacing any
    /// bucket already stored for that variant.
    ///
    /// # Errors
    ///
    /// Returns `key` back as `Err` when `to_bucket` yields `None`; nothing is
    /// inserted in that case.
    // `if let … else` is kept over `map_or_else`, which would have to move
    // `key` into both closures.
    #[allow(clippy::mem_discriminant_non_enum)]
    #[allow(clippy::option_if_let_else)]
    pub fn insert_enum(&mut self, key: Key) -> Result<(), Key>
    where
        Key: ToBucket<Key, Value>,
    {
        if let Some(bucket) = key.to_bucket() {
            self.buckets
                .insert(std::mem::discriminant(&key), Mutex::new(bucket));
            Ok(())
        } else {
            Err(key)
        }
    }

    /// Stores `bucket` for `key`'s discriminant, replacing any previous bucket
    /// for that variant.
    #[allow(clippy::mem_discriminant_non_enum)]
    pub fn insert_bucket(&mut self, key: &Key, bucket: Bucket<Key, Value>) {
        self.buckets
            .insert(std::mem::discriminant(key), Mutex::new(bucket));
    }

    /// Starts a builder that configures a bucket by hand for `key`'s
    /// discriminant; it is stored once [`CachedLimitedEnumsModifier::build`]
    /// is called.
    #[must_use = "the group is only inserted once `.build()` is called on the returned modifier"]
    pub fn build_group<'a>(
        &'a mut self,
        key: &'a Key,
    ) -> CachedLimitedEnumsModifier<'a, Key, Value> {
        CachedLimitedEnumsModifier::new(key, self)
    }

    /// Calls [`Self::insert_enum`] for every key in `keys`, in iteration order.
    /// Successful insertions are kept even when other keys fail.
    ///
    /// # Errors
    ///
    /// Returns every key whose `to_bucket` yielded `None`, in iteration order.
    pub fn insert_enums(&mut self, keys: impl IntoIterator<Item = Key>) -> Result<(), Vec<Key>>
    where
        Key: ToBucket<Key, Value>,
    {
        let erroneous_keys: Vec<Key> = keys
            .into_iter()
            .filter_map(|key| self.insert_enum(key).err())
            .collect();
        if erroneous_keys.is_empty() {
            Ok(())
        } else {
            Err(erroneous_keys)
        }
    }

    /// Forwards to `Bucket::hit_limit` for `key` in the bucket of `key`'s
    /// variant, holding that bucket's mutex while it runs.
    ///
    /// Returns `None` when no bucket is registered for the variant; otherwise
    /// whatever the bucket reports. [`Self::cache_or`] treats `Some` as
    /// "rate-limited".
    // Only shared access is needed. `&self` is backward compatible with the
    // previous `&mut self` receiver and matches `LimitedRequests::hit_limit`.
    #[allow(clippy::mem_discriminant_non_enum)]
    pub async fn hit_limit(&self, key: &Key) -> Option<RateLimitInfo<Value>> {
        let bucket = self.buckets.get(&std::mem::discriminant(key))?;
        bucket.lock().await.hit_limit(key).await
    }

    /// Returns the cached value if the bucket of `key`'s variant reports a
    /// rate limit for `key`. Otherwise awaits `function` and, if it yields
    /// `Some`, stores a clone via `add_cache_value` before returning it.
    ///
    /// Returns `None` without polling `function` when the variant has no bucket.
    #[allow(clippy::mem_discriminant_non_enum)]
    #[cfg(feature = "cache")]
    pub async fn cache_or(
        &self,
        key: &Key,
        function: impl Future<Output = Option<Value>> + Send,
    ) -> Option<Value> {
        let group = self.buckets.get(&std::mem::discriminant(key))?;
        Self::cached_or_compute(group, key, function).await
    }

    /// Like [`Self::cache_or`], but first registers `key.to_bucket()` for
    /// `key`'s variant when no bucket exists yet.
    ///
    /// Returns `None` without polling `function` when the variant has no bucket
    /// and `to_bucket` yields `None`.
    #[allow(clippy::mem_discriminant_non_enum)]
    #[cfg(feature = "cache")]
    pub async fn mut_cache_or(
        &mut self,
        key: &Key,
        function: impl Future<Output = Option<Value>> + Send,
    ) -> Option<Value>
    where
        Key: ToBucket<Key, Value>,
    {
        // Only build a bucket for a vacant slot. Calling `to_bucket()?` eagerly
        // rebuilt a bucket on every call and wrongly returned `None` for
        // variants that already had a bucket (e.g. from `build_group`) whenever
        // `to_bucket` yields `None`.
        let group = match self.buckets.entry(std::mem::discriminant(key)) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => entry.insert(Mutex::new(key.to_bucket()?)),
        };
        Self::cached_or_compute(group, key, function).await
    }

    /// Shared cache-or-compute step once the variant's bucket has been located.
    #[cfg(feature = "cache")]
    async fn cached_or_compute(
        group: &Mutex<Bucket<Key, Value>>,
        key: &Key,
        function: impl Future<Output = Option<Value>> + Send,
    ) -> Option<Value> {
        // The guard is a temporary of this statement, so the lock is held only
        // while `hit_limit` runs, not while `function` is awaited below.
        let info = group.lock().await.hit_limit(key).await;
        if let Some(info) = info {
            return info.cached;
        }
        // NOTE(review): two concurrent callers that both pass `hit_limit` will
        // both await their `function` and both call `add_cache_value`.
        let to_cache = function.await;
        if let Some(ref value) = to_cache {
            group
                .lock()
                .await
                .add_cache_value(key, value.clone())
                .await;
        }
        to_cache
    }
}
// The tests below call `mut_cache_or`, which only exists with the `cache`
// feature. Gating the module keeps `cargo test` without that feature compiling.
#[cfg(all(test, feature = "cache"))]
mod tests {
    use super::*;

    /// Checks that `CachedLimitedEnums::mut_cache_or` runs the caller's future
    /// while a variant's bucket permits it, and falls back to the cached value
    /// once the bucket reports a limit.
    #[tokio::test]
    async fn test_routed_caching() {
        #[derive(Hash, PartialEq, Eq, Clone)]
        enum SpecificRoute {
            GetUser(u64),
            GetAllUsers,
        }

        impl ToBucket<SpecificRoute, String> for SpecificRoute {
            fn to_bucket(&self) -> Option<Bucket<SpecificRoute, String>> {
                Some(match self {
                    Self::GetUser(_) => BucketBuilder::new().limit(2).time_span(60).build(),
                    Self::GetAllUsers => BucketBuilder::new().limit(1).time_span(600).build(),
                })
            }
        }

        let mut buckets: CachedLimitedEnums<SpecificRoute, String> = CachedLimitedEnums::new();

        // `GetUser` is configured with `limit(2)`: both calls are expected to
        // run their future and return the freshly computed value.
        let value = buckets
            .mut_cache_or(&SpecificRoute::GetUser(1), async move {
                Some("Ferris".to_string())
            })
            .await;
        assert_eq!(value, Some("Ferris".to_string()));
        let value = buckets
            .mut_cache_or(&SpecificRoute::GetUser(1), async move {
                Some("Ferris2".to_string())
            })
            .await;
        assert_eq!(value, Some("Ferris2".to_string()));

        // `GetAllUsers` is configured with `limit(1)`: the second call is
        // expected to be limited and return the value cached by the first.
        let value = buckets
            .mut_cache_or(&SpecificRoute::GetAllUsers, async move {
                Some("Ferris, Ferris2".to_string())
            })
            .await;
        assert_eq!(value, Some("Ferris, Ferris2".to_string()));
        let value = buckets
            .mut_cache_or(&SpecificRoute::GetAllUsers, async move {
                Some("Ferris, Ferris2, Ferris3".to_string())
            })
            .await;
        assert_eq!(value, Some("Ferris, Ferris2".to_string()));
    }
}