use std::{fmt::Debug, hash::Hash, sync::Arc, time::Duration};

use async_trait::async_trait;
use dashmap::{mapref::one::Ref, DashMap};

use crate::{
    algorithms::RateLimitAlgorithm,
    limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome, Token},
};
/// A rate limiter whose state is selected by a key of type `K`.
///
/// Every operation names the key it applies to; a token acquired for one key
/// is expected to be released with that same key.
#[async_trait]
pub trait RateLimiterKeyed<K>: Sync
where
    K: Hash + Eq + Send + Sync,
{
    /// Acquires a token for `key`, awaiting as long as the implementation requires.
    async fn acquire(&self, key: &K) -> Token;

    /// Acquires a token for `key`, giving up after `timeout`.
    ///
    /// NOTE(review): presumably yields `None` when no token could be obtained
    /// within `timeout` — confirm against the implementation.
    async fn acquire_timeout(&self, key: &K, timeout: Duration) -> Option<Token>;

    /// Hands `token` (previously acquired for `key`) back to the limiter,
    /// optionally reporting how the guarded request turned out.
    async fn release(&self, key: &K, token: Token, outcome: Option<RequestOutcome>);
}
#[derive(Debug)]
pub struct DefaultRateLimiterKeyed<T, K>
where
T: RateLimitAlgorithm + Debug + Clone,
K: Hash + Eq + Send + Sync,
{
limiters: DashMap<K, DefaultRateLimiter<T>>,
algorithm: T,
}
impl<T, K> DefaultRateLimiterKeyed<T, K>
where
T: RateLimitAlgorithm + Debug + Clone,
K: Hash + Eq + Clone + Send + Sync,
{
pub fn new(algorithm: T) -> Self {
Self {
limiters: DashMap::new(),
algorithm,
}
}
fn get_or_create_limiter(&self, key: &K) -> Ref<'_, K, DefaultRateLimiter<T>> {
self.limiters
.entry(key.clone())
.or_insert_with(|| DefaultRateLimiter::new(self.algorithm.clone()));
self.limiters.get(key).unwrap()
}
pub fn active_keys(&self) -> usize {
self.limiters.len()
}
pub fn remove_key(&self, key: &K) -> bool {
self.limiters.remove(key).is_some()
}
pub fn clear(&self) {
self.limiters.clear();
}
}
#[async_trait]
impl<T, K> RateLimiterKeyed<K> for DefaultRateLimiterKeyed<T, K>
where
    T: RateLimitAlgorithm + Send + Clone + Sync + Debug,
    K: Hash + Eq + Clone + Send + Sync,
{
    /// Delegates to the limiter dedicated to `key`, creating it on first use.
    async fn acquire(&self, key: &K) -> Token {
        self.get_or_create_limiter(key).acquire().await
    }

    /// Delegates to the per-key limiter's `acquire_timeout`, creating the
    /// limiter on first use.
    async fn acquire_timeout(&self, key: &K, duration: Duration) -> Option<Token> {
        self.get_or_create_limiter(key)
            .acquire_timeout(duration)
            .await
    }

    /// Hands `token` back to the per-key limiter, forwarding `outcome`.
    ///
    /// NOTE(review): if `key` was removed (`remove_key`/`clear`) while `token`
    /// was outstanding, this creates a brand-new limiter and releases into it.
    /// Confirm that `DefaultRateLimiter::release` tolerates a token it did not issue.
    async fn release(&self, key: &K, token: Token, outcome: Option<RequestOutcome>) {
        self.get_or_create_limiter(key)
            .release(token, outcome)
            .await
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        algorithms::Fixed,
        keyed::{DefaultRateLimiterKeyed, RateLimiterKeyed},
        limiter::RequestOutcome,
    };

    /// With a limit of one, holding a token for one key must not stop another
    /// key from acquiring. Both keys then remain tracked after release.
    #[tokio::test]
    async fn keyed_rate_limiter_works_independently_per_key() {
        let limiter = DefaultRateLimiterKeyed::<_, String>::new(Fixed::new(1));
        let (first, second) = (String::from("key1"), String::from("key2"));

        let first_token = limiter.acquire(&first).await;
        let second_token = limiter.acquire(&second).await;

        for (key, token) in [(&first, first_token), (&second, second_token)] {
            limiter
                .release(key, token, Some(RequestOutcome::Success))
                .await;
        }

        assert_eq!(limiter.active_keys(), 2);
    }

    /// Keys appear on first use, and can be removed one at a time or all at once.
    #[tokio::test]
    async fn keyed_rate_limiter_manages_keys() {
        let limiter = DefaultRateLimiterKeyed::<_, String>::new(Fixed::new(10));

        // Tokens are kept alive for the whole test so the limits stay occupied.
        let mut held = Vec::with_capacity(3);
        for user in ["user1", "user2", "user3"] {
            held.push(limiter.acquire(&user.to_string()).await);
        }
        assert_eq!(limiter.active_keys(), 3);

        assert!(limiter.remove_key(&"user2".to_string()));
        assert_eq!(limiter.active_keys(), 2);

        let missing = "nonexistent".to_string();
        assert!(!limiter.remove_key(&missing));

        limiter.clear();
        assert_eq!(limiter.active_keys(), 0);
    }
}