// File: request_rate_limiter/keyed.rs

//! Keyed rate limiting functionality.
//!
//! This module provides rate limiting with per-key isolation, allowing independent
//! rate limiting across different clients, users, or request types.

use std::{fmt::Debug, hash::Hash, sync::Arc, time::Duration};

use async_trait::async_trait;
use dashmap::{mapref::one::Ref, DashMap};

use crate::{
    algorithms::RateLimitAlgorithm,
    limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome, Token},
};

16/// Controls the rate of requests over time with per-key rate limiting.
17///
18/// Each key maintains its own rate limit state, allowing for independent
19/// rate limiting across different clients, users, or request types.
20#[async_trait]
21pub trait RateLimiterKeyed<K>: Sync
22where
23    K: Hash + Eq + Send + Sync,
24{
25    /// Acquire permission to make a request for a specific key. Waits until a token is available.
26    async fn acquire(&self, key: &K) -> Token;
27
28    /// Acquire permission to make a request for a specific key with a timeout.
29    /// Returns a token if successful.
30    async fn acquire_timeout(&self, key: &K, duration: Duration) -> Option<Token>;
31
32    /// Release the token and record the outcome of the request for the specific key.
33    /// The response time is calculated from when the token was acquired.
34    async fn release(&self, key: &K, token: Token, outcome: Option<RequestOutcome>);
35}
36
37/// A keyed rate limiter that maintains separate rate limiters for each key.
38///
39/// Uses DashMap for efficient concurrent access to per-key rate limiters.
40/// Each key gets its own independent rate limiter instance.
41pub struct DefaultRateLimiterKeyed<T, K, F>
42where
43    T: RateLimitAlgorithm + Debug,
44    K: Hash + Eq + Send + Sync,
45    F: Fn() -> T + Send + Sync,
46{
47    limiters: DashMap<K, DefaultRateLimiter<T>>,
48    algorithm_factory: F,
49}
50
51impl<T, K, F> DefaultRateLimiterKeyed<T, K, F>
52where
53    T: RateLimitAlgorithm + Debug,
54    K: Hash + Eq + Clone + Send + Sync,
55    F: Fn() -> T + Send + Sync,
56{
57    /// Create a new keyed rate limiter with the given algorithm factory function.
58    /// Each key will get a fresh instance of the algorithm created by calling the factory.
59    pub fn new(algorithm_factory: F) -> Self {
60        Self {
61            limiters: DashMap::new(),
62            algorithm_factory,
63        }
64    }
65
66    /// Get or create a rate limiter for the given key.
67    fn get_or_create_limiter(&self, key: &K) -> Ref<'_, K, DefaultRateLimiter<T>> {
68        if !self.limiters.contains_key(key) {
69            self.limiters.insert(
70                key.clone(),
71                DefaultRateLimiter::new((self.algorithm_factory)()),
72            );
73        }
74
75        self.limiters.get(key).unwrap()
76    }
77
78    /// Get the number of active keys being tracked.
79    pub fn active_keys(&self) -> usize {
80        self.limiters.len()
81    }
82
83    /// Remove a key and its associated rate limiter.
84    /// Returns true if the key existed and was removed.
85    pub fn remove_key(&self, key: &K) -> bool {
86        self.limiters.remove(key).is_some()
87    }
88
89    /// Clear all keys and their associated rate limiters.
90    pub fn clear(&self) {
91        self.limiters.clear();
92    }
93}
94
95#[async_trait]
96impl<T, K, F> RateLimiterKeyed<K> for DefaultRateLimiterKeyed<T, K, F>
97where
98    T: RateLimitAlgorithm + Send + Sync + Debug,
99    K: Hash + Eq + Clone + Send + Sync,
100    F: Fn() -> T + Send + Sync,
101{
102    async fn acquire(&self, key: &K) -> Token {
103        let limiter = self.get_or_create_limiter(key);
104        limiter.acquire().await
105    }
106
107    async fn acquire_timeout(&self, key: &K, duration: Duration) -> Option<Token> {
108        let limiter = self.get_or_create_limiter(key);
109        limiter.acquire_timeout(duration).await
110    }
111
112    async fn release(&self, key: &K, token: Token, outcome: Option<RequestOutcome>) {
113        let limiter = self.get_or_create_limiter(key);
114        limiter.release(token, outcome).await
115    }
116}
117
#[cfg(test)]
mod tests {
    use crate::{
        algorithms::Fixed,
        keyed::{DefaultRateLimiterKeyed, RateLimiterKeyed},
        limiter::RequestOutcome,
    };

    #[tokio::test]
    async fn keyed_rate_limiter_works_independently_per_key() {
        let limiter = DefaultRateLimiterKeyed::<_, String, _>::new(|| Fixed::new(1));
        let (first, second) = (String::from("key1"), String::from("key2"));

        // With a per-key limit of 1, both acquisitions are expected to succeed only
        // because they target different keys.
        let first_token = limiter.acquire(&first).await;
        let second_token = limiter.acquire(&second).await;

        for (key, token) in [(&first, first_token), (&second, second_token)] {
            limiter
                .release(key, token, Some(RequestOutcome::Success))
                .await;
        }

        assert_eq!(limiter.active_keys(), 2);
    }

    #[tokio::test]
    async fn keyed_rate_limiter_manages_keys() {
        let limiter = DefaultRateLimiterKeyed::<_, String, _>::new(|| Fixed::new(10));

        // Touch three distinct keys; keep their tokens alive for the rest of the test.
        let mut held_tokens = Vec::new();
        for user in ["user1", "user2", "user3"] {
            held_tokens.push(limiter.acquire(&user.to_string()).await);
        }
        assert_eq!(limiter.active_keys(), 3);

        // Removing an existing key shrinks the tracked set by one.
        assert!(limiter.remove_key(&"user2".to_string()));
        assert_eq!(limiter.active_keys(), 2);

        // Removing an unknown key reports that nothing was removed.
        assert!(!limiter.remove_key(&"nonexistent".to_string()));

        // Clearing drops every remaining key.
        limiter.clear();
        assert_eq!(limiter.active_keys(), 0);
    }
}