request_rate_limiter/
keyed.rs1use std::{fmt::Debug, hash::Hash, time::Duration};
7
8use async_trait::async_trait;
9use dashmap::DashMap;
10
11use crate::{
12 algorithms::RateLimitAlgorithm,
13 limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome, Token},
14};
15
16#[async_trait]
21pub trait RateLimiterKeyed<K>: Sync
22where
23 K: Hash + Eq + Send + Sync,
24{
25 async fn acquire(&self, key: &K) -> Token;
27
28 async fn acquire_timeout(&self, key: &K, duration: Duration) -> Option<Token>;
31
32 async fn release(&self, key: &K, token: Token, outcome: Option<RequestOutcome>);
35}
36
37#[derive(Debug)]
42pub struct DefaultRateLimiterKeyed<T, K>
43where
44 T: RateLimitAlgorithm + Debug + Clone,
45 K: Hash + Eq + Send + Sync,
46{
47 limiters: DashMap<K, DefaultRateLimiter<T>>,
48 algorithm: T,
49}
50
51impl<T, K> DefaultRateLimiterKeyed<T, K>
52where
53 T: RateLimitAlgorithm + Debug + Clone,
54 K: Hash + Eq + Clone + Send + Sync,
55{
56 pub fn new(algorithm: T) -> Self {
59 Self {
60 limiters: DashMap::new(),
61 algorithm,
62 }
63 }
64
65 fn get_or_create_limiter(&self, key: &K) -> DefaultRateLimiter<T> {
68 if let Some(limiter) = self.limiters.get(key) {
69 return limiter.value().clone();
70 }
71
72 let limiter_ref = self
73 .limiters
74 .entry(key.clone())
75 .or_insert_with(|| DefaultRateLimiter::new(self.algorithm.clone()));
76
77 limiter_ref.value().clone()
78 }
79
80 pub fn active_keys(&self) -> usize {
82 self.limiters.len()
83 }
84
85 pub fn remove_key(&self, key: &K) -> bool {
88 self.limiters.remove(key).is_some()
89 }
90
91 pub fn clear(&self) {
93 self.limiters.clear();
94 }
95}
96
97#[async_trait]
98impl<T, K> RateLimiterKeyed<K> for DefaultRateLimiterKeyed<T, K>
99where
100 T: RateLimitAlgorithm + Send + Clone + Sync + Debug,
101 K: Hash + Eq + Clone + Send + Sync,
102{
103 async fn acquire(&self, key: &K) -> Token {
104 let limiter = self.get_or_create_limiter(key);
105 limiter.acquire().await
106 }
107
108 async fn acquire_timeout(&self, key: &K, duration: Duration) -> Option<Token> {
109 let limiter = self.get_or_create_limiter(key);
110 limiter.acquire_timeout(duration).await
111 }
112
113 async fn release(&self, key: &K, token: Token, outcome: Option<RequestOutcome>) {
114 let limiter = self.get_or_create_limiter(key);
115 limiter.release(token, outcome).await
116 }
117}
118
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::{
        algorithms::Fixed,
        keyed::{DefaultRateLimiterKeyed, RateLimiterKeyed},
        limiter::RequestOutcome,
    };

    /// Upper bound on waiting for a token that should be available right away.
    /// A regression in key isolation then fails the test instead of hanging the suite.
    const ACQUIRE_DEADLINE: Duration = Duration::from_secs(1);

    #[tokio::test]
    async fn keyed_rate_limiter_works_independently_per_key() {
        // Assumes `Fixed::new(1)` admits a single outstanding token per limiter. Key2 can
        // then only get a token while key1's is held if the two keys do not share a limiter.
        let limiter = DefaultRateLimiterKeyed::<_, String>::new(Fixed::new(1));

        let key1 = "key1".to_string();
        let key2 = "key2".to_string();
        let token1 = limiter.acquire(&key1).await;
        let token2 = limiter
            .acquire_timeout(&key2, ACQUIRE_DEADLINE)
            .await
            .expect("key2 must not be blocked by the token held for key1");

        limiter
            .release(&key1, token1, Some(RequestOutcome::Success))
            .await;
        limiter
            .release(&key2, token2, Some(RequestOutcome::Success))
            .await;

        assert_eq!(limiter.active_keys(), 2);
    }

    #[tokio::test]
    async fn keyed_rate_limiter_manages_keys() {
        let limiter = DefaultRateLimiterKeyed::<_, String>::new(Fixed::new(10));

        let _token1 = limiter.acquire(&"user1".to_string()).await;
        let _token2 = limiter.acquire(&"user2".to_string()).await;
        let _token3 = limiter.acquire(&"user3".to_string()).await;

        assert_eq!(limiter.active_keys(), 3);

        assert!(limiter.remove_key(&"user2".to_string()));
        assert_eq!(limiter.active_keys(), 2);

        assert!(!limiter.remove_key(&"nonexistent".to_string()));

        limiter.clear();
        assert_eq!(limiter.active_keys(), 0);
    }
}