// request_rate_limiter/keyed.rs
use std::{fmt::Debug, hash::Hash, sync::Arc, time::Duration};

use async_trait::async_trait;
use dashmap::{mapref::one::Ref, DashMap};

use crate::{
    algorithms::RateLimitAlgorithm,
    limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome, Token},
};

16#[async_trait]
21pub trait RateLimiterKeyed<K>: Sync
22where
23 K: Hash + Eq + Send + Sync,
24{
25 async fn acquire(&self, key: &K) -> Token;
27
28 async fn acquire_timeout(&self, key: &K, duration: Duration) -> Option<Token>;
31
32 async fn release(&self, key: &K, token: Token, outcome: Option<RequestOutcome>);
35}
36
37pub struct DefaultRateLimiterKeyed<T, K>
42where
43 T: RateLimitAlgorithm + Debug + Clone,
44 K: Hash + Eq + Send + Sync,
45{
46 limiters: DashMap<K, DefaultRateLimiter<T>>,
47 algorithm: T,
48}
49
50impl<T, K> DefaultRateLimiterKeyed<T, K>
51where
52 T: RateLimitAlgorithm + Debug + Clone,
53 K: Hash + Eq + Clone + Send + Sync,
54{
55 pub fn new(algorithm: T) -> Self {
58 Self {
59 limiters: DashMap::new(),
60 algorithm,
61 }
62 }
63
64 fn get_or_create_limiter(&self, key: &K) -> Ref<'_, K, DefaultRateLimiter<T>> {
66 if !self.limiters.contains_key(key) {
67 self.limiters
68 .insert(key.clone(), DefaultRateLimiter::new(self.algorithm.clone()));
69 }
70
71 self.limiters.get(key).unwrap()
72 }
73
74 pub fn active_keys(&self) -> usize {
76 self.limiters.len()
77 }
78
79 pub fn remove_key(&self, key: &K) -> bool {
82 self.limiters.remove(key).is_some()
83 }
84
85 pub fn clear(&self) {
87 self.limiters.clear();
88 }
89}
90
91#[async_trait]
92impl<T, K> RateLimiterKeyed<K> for DefaultRateLimiterKeyed<T, K>
93where
94 T: RateLimitAlgorithm + Send + Clone + Sync + Debug,
95 K: Hash + Eq + Clone + Send + Sync,
96{
97 async fn acquire(&self, key: &K) -> Token {
98 let limiter = self.get_or_create_limiter(key);
99 limiter.acquire().await
100 }
101
102 async fn acquire_timeout(&self, key: &K, duration: Duration) -> Option<Token> {
103 let limiter = self.get_or_create_limiter(key);
104 limiter.acquire_timeout(duration).await
105 }
106
107 async fn release(&self, key: &K, token: Token, outcome: Option<RequestOutcome>) {
108 let limiter = self.get_or_create_limiter(key);
109 limiter.release(token, outcome).await
110 }
111}
112
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::{
        algorithms::Fixed,
        keyed::{DefaultRateLimiterKeyed, RateLimiterKeyed},
        limiter::RequestOutcome,
    };

    /// Upper bound on how long a test waits for a token.
    ///
    /// Short enough to keep the suite fast, and long enough that a free token is always
    /// obtained in time.
    const WAIT: Duration = Duration::from_millis(50);

    #[tokio::test]
    async fn keyed_rate_limiter_works_independently_per_key() {
        let limiter = DefaultRateLimiterKeyed::<_, String>::new(Fixed::new(1));

        let key1 = "key1".to_string();
        let key2 = "key2".to_string();

        let token1 = limiter.acquire(&key1).await;
        // key1 is now at its limit of 1, but key2 must still be free. Using a timeout turns
        // a shared-limit regression into a test failure instead of a hang.
        let token2 = limiter
            .acquire_timeout(&key2, WAIT)
            .await
            .expect("key2 must not be limited by key1's outstanding token");
        assert!(
            limiter.acquire_timeout(&key1, WAIT).await.is_none(),
            "key1 should be at its limit of 1 while token1 is held"
        );

        limiter
            .release(&key1, token1, Some(RequestOutcome::Success))
            .await;
        limiter
            .release(&key2, token2, Some(RequestOutcome::Success))
            .await;

        assert_eq!(limiter.active_keys(), 2);
    }

    #[tokio::test]
    async fn keyed_rate_limiter_manages_keys() {
        let limiter = DefaultRateLimiterKeyed::<_, String>::new(Fixed::new(10));

        let _token1 = limiter.acquire(&"user1".to_string()).await;
        let _token2 = limiter.acquire(&"user2".to_string()).await;
        let _token3 = limiter.acquire(&"user3".to_string()).await;

        assert_eq!(limiter.active_keys(), 3);

        assert!(limiter.remove_key(&"user2".to_string()));
        assert_eq!(limiter.active_keys(), 2);

        assert!(!limiter.remove_key(&"nonexistent".to_string()));

        limiter.clear();
        assert_eq!(limiter.active_keys(), 0);
    }
}