axum_limit/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(unsafe_code, missing_docs, clippy::unwrap_used)]
3
4mod key;
5
6use axum_core::extract::{FromRef, FromRequestParts};
7use axum_core::response::{IntoResponse, Response};
8use dashmap::DashMap;
9use http::request::Parts;
10use http::StatusCode;
11use std::error::Error;
12use std::fmt::Display;
13use std::hash::Hash;
14use std::ops::{Deref, DerefMut};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18/// Represents a rate limit configuration with generic parameters for count and time period.
19/// This struct uses generics to allow flexible integration with any extractor that implements the `Key` trait.
20#[derive(Debug, Clone, Copy, Default)]
21pub struct Limit<const COUNT: usize, const PER: u64, K>(pub K::Extractor)
22where
23    K: Key;
24
25/// Rate limit configured to apply per second.
26pub type LimitPerSecond<const COUNT: usize, K> = Limit<COUNT, 1000, K>;
27
28/// Rate limit configured to apply per minute.
29pub type LimitPerMinute<const COUNT: usize, K> = Limit<COUNT, 60_000, K>;
30
31/// Rate limit configured to apply per hour.
32pub type LimitPerHour<const COUNT: usize, K> = Limit<COUNT, 3_600_000, K>;
33
34/// Rate limit configured to apply per day.
35pub type LimitPerDay<const COUNT: usize, K> = Limit<COUNT, 86_400_000, K>;
36
37impl<const COUNT: usize, const PER: u64, K> AsRef<K::Extractor> for Limit<COUNT, PER, K>
38where
39    K: Key,
40{
41    fn as_ref(&self) -> &K::Extractor {
42        &self.0
43    }
44}
45
46impl<const COUNT: usize, const PER: u64, K> AsMut<K::Extractor> for Limit<COUNT, PER, K>
47where
48    K: Key,
49{
50    fn as_mut(&mut self) -> &mut K::Extractor {
51        &mut self.0
52    }
53}
54
55impl<const COUNT: usize, const PER: u64, K> Deref for Limit<COUNT, PER, K>
56where
57    K: Key,
58{
59    type Target = K::Extractor;
60
61    fn deref(&self) -> &Self::Target {
62        &self.0
63    }
64}
65
66impl<const COUNT: usize, const PER: u64, K> DerefMut for Limit<COUNT, PER, K>
67where
68    K: Key,
69{
70    fn deref_mut(&mut self) -> &mut Self::Target {
71        &mut self.0
72    }
73}
74
75impl<const COUNT: usize, const PER: u64, K> Display for Limit<COUNT, PER, K>
76where
77    K: Key,
78    K::Extractor: Display,
79{
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        self.0.fmt(f)
82    }
83}
84
85impl<const COUNT: usize, const PER: u64, K> Limit<COUNT, PER, K>
86where
87    K: Key,
88{
89    /// Returns the count of requests allowed within the specified period.
90    pub const fn count() -> usize {
91        COUNT
92    }
93
94    /// Returns the period (in milliseconds) for which the limit applies.
95    pub const fn per() -> u64 {
96        PER
97    }
98
99    /// Consumes the limit and returns the inner extractor, allowing direct access to the underlying mechanism.
100    pub fn into_inner(self) -> K::Extractor {
101        self.0
102    }
103}
104
/// Trait for types that identify the subject of a rate limit.
///
/// A `Key` is built from the value produced by its associated [`Key::Extractor`]. It is used
/// as the lookup key for the subject's token bucket, which is why it must be `Eq + Hash`.
/// `Send + Sync` lets keys be stored in the shared, concurrently accessed limit state.
///
/// The rate-limit parameters themselves are not extracted here. They are the `COUNT` and
/// `PER` const parameters of `Limit`.
// The trait declares no async methods, so no `#[async_trait]` attribute is needed.
pub trait Key: Eq + Hash + Send + Sync {
    /// The extractor whose output identifies the subject.
    ///
    /// `Limit` runs it during request extraction and hands the extracted value to the handler.
    type Extractor;

    /// Builds the key from a reference to the extracted value.
    fn from_extractor(extractor: &Self::Extractor) -> Self;
}
115
/// Implements a token bucket for rate limiting.
///
/// The bucket holds at most `capacity` tokens, and each request takes one. Once a full
/// `refill_duration` has elapsed since the last refill, the bucket is topped back up to
/// `capacity`.
///
/// Tokens are never banked beyond `capacity`, so a key that stays idle for a long time cannot
/// build up a burst larger than the configured limit. A zero `refill_duration` refills on
/// every call (effectively unlimited). A zero `capacity` rejects every request.
struct TokenBucket {
    /// Tokens currently available.
    tokens: usize,
    /// Maximum number of tokens the bucket can hold; also the amount restored per window.
    capacity: usize,
    /// Start of the current refill window.
    last_refill_time: Instant,
    /// Length of one refill window.
    refill_duration: Duration,
}

impl TokenBucket {
    /// Constructs a full `TokenBucket` that refills every `per` milliseconds.
    ///
    /// `tokens` is both the initial number of tokens and the bucket's capacity.
    fn new(tokens: impl Into<usize>, per: impl Into<u64>) -> Self {
        let capacity = tokens.into();
        Self {
            tokens: capacity,
            capacity,
            last_refill_time: Instant::now(),
            refill_duration: Duration::from_millis(per.into()),
        }
    }

    /// Attempts to acquire a token. Returns `true` if a token was successfully acquired.
    fn try_acquire(&mut self) -> bool {
        self.refill();
        if self.tokens > 0 {
            self.tokens -= 1;
            true
        } else {
            false
        }
    }

    /// Refills the bucket to `capacity` once at least one full refill window has elapsed.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill_time);
        if elapsed < self.refill_duration {
            return;
        }

        // Every elapsed window is worth `capacity` tokens. The bucket can never hold more
        // than `capacity`, so any number of elapsed windows simply fills it.
        self.tokens = self.capacity;

        let window_millis = self.refill_duration.as_millis();
        self.last_refill_time = if window_millis == 0 {
            // A zero-length window has no boundary to align to, and the modulo below would
            // divide by zero.
            now
        } else {
            // Start the new window at the most recent window boundary rather than at `now`,
            // so windows stay aligned to the original schedule.
            let into_window = elapsed.as_millis() % window_millis;
            u64::try_from(into_window)
                .ok()
                .and_then(|ms| now.checked_sub(Duration::from_millis(ms)))
                .unwrap_or(now)
        };
    }
}
165
166/// Manages the state of rate limits for various keys.
167/// This struct holds a concurrent map of keys to their corresponding `TokenBucket` instances,
168/// enabling efficient state management across asynchronous tasks.
169#[derive(Clone)]
170pub struct LimitState<K>
171where
172    K: Key,
173{
174    rate_limits: Arc<DashMap<K, TokenBucket>>,
175}
176
177impl<K> Default for LimitState<K>
178where
179    K: Key,
180{
181    /// Constructs a new `LimitState` with an empty map of rate limits.
182    fn default() -> Self {
183        Self {
184            rate_limits: Arc::new(DashMap::new()),
185        }
186    }
187}
188
189impl<K> LimitState<K>
190where
191    K: Key,
192{
193    /// Checks and updates the rate limit for the given key, returning `true` if the request can proceed.
194    pub fn check(&self, key: K, count: usize, per: u64) -> bool {
195        let mut bucket = self
196            .rate_limits
197            .entry(key)
198            .or_insert_with(|| TokenBucket::new(count, per));
199        bucket.try_acquire()
200    }
201}
202
203#[async_trait::async_trait]
204impl<const C: usize, const P: u64, K, S> FromRequestParts<S> for Limit<C, P, K>
205where
206    LimitState<K>: FromRef<S>,
207    S: Send + Sync,
208    K: Key,
209    K::Extractor: FromRequestParts<S>,
210{
211    type Rejection = LimitRejection<<<K as Key>::Extractor as FromRequestParts<S>>::Rejection>;
212
213    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
214        let key_extractor = match K::Extractor::from_request_parts(parts, state).await {
215            Ok(ke) => ke,
216            Err(rejection) => return Err(LimitRejection::KeyExtractionFailure(rejection)),
217        };
218
219        let limit_state: LimitState<K> = FromRef::from_ref(state);
220        let key = K::from_extractor(&key_extractor);
221        if limit_state.check(key, C, P) {
222            Ok(Self(key_extractor))
223        } else {
224            Err(LimitRejection::RateLimitExceeded)
225        }
226    }
227}
228
/// Reasons the rate-limiting extractor can reject a request.
#[derive(Debug)]
pub enum LimitRejection<R> {
    /// The wrapped key extractor failed; its own rejection is carried unchanged.
    KeyExtractionFailure(R),

    /// No tokens were left for this key under the configured limit.
    RateLimitExceeded,
}

impl<R> Display for LimitRejection<R>
where
    R: Display,
{
    /// Shows the inner rejection's message, or a fixed message when the limit was hit.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::RateLimitExceeded => write!(f, "Rate limit exceeded."),
            Self::KeyExtractionFailure(inner) => write!(f, "{inner}"),
        }
    }
}

impl<R> Error for LimitRejection<R>
where
    R: Error + 'static,
{
    /// Exposes the inner extractor's error as the source; a hit limit has no underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::RateLimitExceeded => None,
            Self::KeyExtractionFailure(inner) => Some(inner),
        }
    }
}
256
257impl<R: IntoResponse> IntoResponse for LimitRejection<R> {
258    fn into_response(self) -> Response {
259        match self {
260            LimitRejection::KeyExtractionFailure(rejection) => rejection.into_response(),
261            LimitRejection::RateLimitExceeded => {
262                (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded.").into_response()
263            }
264        }
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use axum::routing::get;
    use axum::Router;
    use axum_test::TestServer;
    use http::Uri;
    use std::future::IntoFuture;

    /// Per-second limits: a burst of up to `COUNT` requests succeeds and the next one is
    /// rejected. After the one-second window has passed, requests are accepted again.
    #[tokio::test]
    async fn limit() {
        const TEST_ROUTE0: &str = "/limit0";
        const TEST_ROUTE1: &str = "/limit1";

        // `PER` is in milliseconds, so these must be per-second limits to match the one-second
        // sleeps below. With a 1 ms window (`Limit<_, 1, _>`), the TOO_MANY_REQUESTS assertions
        // would depend on how quickly the requests happen to run.
        async fn handler0(Limit(_uri): LimitPerSecond<1, Uri>) -> impl IntoResponse {}

        async fn handler1(Limit(_uri): LimitPerSecond<3, Uri>) -> impl IntoResponse {}

        let my_app = Router::new()
            .route(TEST_ROUTE0, get(handler0))
            .route(TEST_ROUTE1, get(handler1))
            .with_state(LimitState::default());

        let server = TestServer::new(my_app).expect("Failed to create test server");

        let response = server.get(TEST_ROUTE0).await;
        assert_eq!(response.status_code(), StatusCode::OK);
        let response = server.get(TEST_ROUTE0).await;
        assert_eq!(response.status_code(), StatusCode::TOO_MANY_REQUESTS);
        tokio::time::sleep(Duration::from_secs(1)).await;
        let response = server.get(TEST_ROUTE0).await;
        assert_eq!(response.status_code(), StatusCode::OK);

        let gets = vec![
            server.get(TEST_ROUTE1).into_future(),
            server.get(TEST_ROUTE1).into_future(),
            server.get(TEST_ROUTE1).into_future(),
        ];

        let resp = futures::future::join_all(gets).await;
        assert!(resp.iter().all(|r| r.status_code().is_success()));
        assert_eq!(
            server.get(TEST_ROUTE1).await.status_code(),
            StatusCode::TOO_MANY_REQUESTS
        );
        tokio::time::sleep(Duration::from_secs(1)).await;
        let response = server.get(TEST_ROUTE1).await;
        assert_eq!(response.status_code(), StatusCode::OK);
    }

    /// A one-request-per-100 ms limit rejects an immediate retry and accepts one after the
    /// window has passed.
    #[tokio::test]
    async fn limit_per_100_millis() {
        const TEST_ROUTE: &str = "/limit_per_100_millis";

        async fn handler(Limit(_uri): Limit<1, 100, Uri>) -> impl IntoResponse {}

        let my_app = Router::new()
            .route(TEST_ROUTE, get(handler))
            .with_state(LimitState::default());

        let server = TestServer::new(my_app).expect("Failed to create test server");

        // The first request should succeed.
        let response = server.get(TEST_ROUTE).await;
        assert_eq!(response.status_code(), StatusCode::OK);

        // An immediate second request should be rate limited.
        let response = server.get(TEST_ROUTE).await;
        assert_eq!(response.status_code(), StatusCode::TOO_MANY_REQUESTS);

        // Wait 100 milliseconds for the window to pass.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Requesting again should succeed.
        let response = server.get(TEST_ROUTE).await;
        assert_eq!(response.status_code(), StatusCode::OK);
    }
}