request_rate_limiter/
keyed.rs1use std::{fmt::Debug, hash::Hash, time::Duration};
7
8use async_trait::async_trait;
9use dashmap::{mapref::one::Ref, DashMap};
10
11use crate::{
12 algorithms::RateLimitAlgorithm,
13 limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome, Token},
14};
15
16#[async_trait]
21pub trait RateLimiterKeyed<K>: Sync
22where
23 K: Hash + Eq + Send + Sync,
24{
25 async fn acquire(&self, key: &K) -> Token;
27
28 async fn acquire_timeout(&self, key: &K, duration: Duration) -> Option<Token>;
31
32 async fn release(&self, key: &K, token: Token, outcome: Option<RequestOutcome>);
35}
36
37pub struct DefaultRateLimiterKeyed<T, K, F>
42where
43 T: RateLimitAlgorithm + Debug,
44 K: Hash + Eq + Send + Sync,
45 F: Fn() -> T + Send + Sync,
46{
47 limiters: DashMap<K, DefaultRateLimiter<T>>,
48 algorithm_factory: F,
49}
50
51impl<T, K, F> DefaultRateLimiterKeyed<T, K, F>
52where
53 T: RateLimitAlgorithm + Debug,
54 K: Hash + Eq + Clone + Send + Sync,
55 F: Fn() -> T + Send + Sync,
56{
57 pub fn new(algorithm_factory: F) -> Self {
60 Self {
61 limiters: DashMap::new(),
62 algorithm_factory,
63 }
64 }
65
66 fn get_or_create_limiter(&self, key: &K) -> Ref<'_, K, DefaultRateLimiter<T>> {
68 if !self.limiters.contains_key(key) {
69 self.limiters.insert(
70 key.clone(),
71 DefaultRateLimiter::new((self.algorithm_factory)()),
72 );
73 }
74
75 self.limiters.get(key).unwrap()
76 }
77
78 pub fn active_keys(&self) -> usize {
80 self.limiters.len()
81 }
82
83 pub fn remove_key(&self, key: &K) -> bool {
86 self.limiters.remove(key).is_some()
87 }
88
89 pub fn clear(&self) {
91 self.limiters.clear();
92 }
93}
94
95#[async_trait]
96impl<T, K, F> RateLimiterKeyed<K> for DefaultRateLimiterKeyed<T, K, F>
97where
98 T: RateLimitAlgorithm + Send + Sync + Debug,
99 K: Hash + Eq + Clone + Send + Sync,
100 F: Fn() -> T + Send + Sync,
101{
102 async fn acquire(&self, key: &K) -> Token {
103 let limiter = self.get_or_create_limiter(key);
104 limiter.acquire().await
105 }
106
107 async fn acquire_timeout(&self, key: &K, duration: Duration) -> Option<Token> {
108 let limiter = self.get_or_create_limiter(key);
109 limiter.acquire_timeout(duration).await
110 }
111
112 async fn release(&self, key: &K, token: Token, outcome: Option<RequestOutcome>) {
113 let limiter = self.get_or_create_limiter(key);
114 limiter.release(token, outcome).await
115 }
116}
117
#[cfg(test)]
mod tests {
    use crate::{
        algorithms::Fixed,
        keyed::{DefaultRateLimiterKeyed, RateLimiterKeyed},
        limiter::RequestOutcome,
    };

    // Two distinct keys are each acquired and then released; both must end up
    // tracked as separate keys.
    #[tokio::test]
    async fn keyed_rate_limiter_works_independently_per_key() {
        let limiter = DefaultRateLimiterKeyed::<_, String, _>::new(|| Fixed::new(1));
        let keys = ["key1".to_string(), "key2".to_string()];

        let first = limiter.acquire(&keys[0]).await;
        let second = limiter.acquire(&keys[1]).await;

        // Release in the same order the tokens were acquired.
        for (key, token) in keys.iter().zip([first, second]) {
            limiter
                .release(key, token, Some(RequestOutcome::Success))
                .await;
        }

        assert_eq!(limiter.active_keys(), 2);
    }

    // Acquiring creates keys; remove_key and clear shrink the set.
    #[tokio::test]
    async fn keyed_rate_limiter_manages_keys() {
        let limiter = DefaultRateLimiterKeyed::<_, String, _>::new(|| Fixed::new(10));

        // Keep every token alive for the whole test.
        let mut held_tokens = Vec::new();
        for user in ["user1", "user2", "user3"] {
            held_tokens.push(limiter.acquire(&user.to_string()).await);
        }
        assert_eq!(limiter.active_keys(), 3);

        let removed_existing = limiter.remove_key(&"user2".to_string());
        assert!(removed_existing);
        assert_eq!(limiter.active_keys(), 2);

        let removed_missing = limiter.remove_key(&"nonexistent".to_string());
        assert!(!removed_missing);

        limiter.clear();
        assert_eq!(limiter.active_keys(), 0);
    }
}