1#![doc = include_str!("../README.md")]
2#![deny(unsafe_code, missing_docs, clippy::unwrap_used)]
3
4mod key;
5
6use axum_core::extract::{FromRef, FromRequestParts};
7use axum_core::response::{IntoResponse, Response};
8use dashmap::DashMap;
9use http::request::Parts;
10use http::StatusCode;
11use std::error::Error;
12use std::fmt::Display;
13use std::hash::Hash;
14use std::ops::{Deref, DerefMut};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18#[derive(Debug, Clone, Copy, Default)]
21pub struct Limit<const COUNT: usize, const PER: u64, K>(pub K::Extractor)
22where
23 K: Key;
24
25pub type LimitPerSecond<const COUNT: usize, K> = Limit<COUNT, 1000, K>;
27
28pub type LimitPerMinute<const COUNT: usize, K> = Limit<COUNT, 60_000, K>;
30
31pub type LimitPerHour<const COUNT: usize, K> = Limit<COUNT, 3_600_000, K>;
33
34pub type LimitPerDay<const COUNT: usize, K> = Limit<COUNT, 86_400_000, K>;
36
37impl<const COUNT: usize, const PER: u64, K> AsRef<K::Extractor> for Limit<COUNT, PER, K>
38where
39 K: Key,
40{
41 fn as_ref(&self) -> &K::Extractor {
42 &self.0
43 }
44}
45
46impl<const COUNT: usize, const PER: u64, K> AsMut<K::Extractor> for Limit<COUNT, PER, K>
47where
48 K: Key,
49{
50 fn as_mut(&mut self) -> &mut K::Extractor {
51 &mut self.0
52 }
53}
54
55impl<const COUNT: usize, const PER: u64, K> Deref for Limit<COUNT, PER, K>
56where
57 K: Key,
58{
59 type Target = K::Extractor;
60
61 fn deref(&self) -> &Self::Target {
62 &self.0
63 }
64}
65
66impl<const COUNT: usize, const PER: u64, K> DerefMut for Limit<COUNT, PER, K>
67where
68 K: Key,
69{
70 fn deref_mut(&mut self) -> &mut Self::Target {
71 &mut self.0
72 }
73}
74
75impl<const COUNT: usize, const PER: u64, K> Display for Limit<COUNT, PER, K>
76where
77 K: Key,
78 K::Extractor: Display,
79{
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 self.0.fmt(f)
82 }
83}
84
85impl<const COUNT: usize, const PER: u64, K> Limit<COUNT, PER, K>
86where
87 K: Key,
88{
89 pub const fn count() -> usize {
91 COUNT
92 }
93
94 pub const fn per() -> u64 {
96 PER
97 }
98
99 pub fn into_inner(self) -> K::Extractor {
101 self.0
102 }
103}
104
/// A hashable rate-limit key derived from a request extractor.
///
/// [`Limit`] first runs [`Key::Extractor`] against the request, then calls
/// [`Key::from_extractor`] on the extracted value to obtain the key whose token
/// bucket is charged. Requests that produce equal keys share one bucket.
///
/// The trait has no async methods, so it does not need `#[async_trait]`.
pub trait Key: Eq + Hash + Send + Sync {
    /// Extractor run against the incoming request; its output is handed to
    /// [`Key::from_extractor`] and then exposed by [`Limit`].
    type Extractor;

    /// Builds the rate-limit key from a successfully extracted value.
    fn from_extractor(extractor: &Self::Extractor) -> Self;
}
115
/// Per-key token bucket backing `LimitState`.
///
/// The bucket holds at most `capacity` tokens. Each admitted request takes one
/// token. Once at least one full `refill_duration` has elapsed since the start of
/// the current window, the bucket is topped back up to `capacity`. This makes
/// `COUNT` requests per `PER` milliseconds the sustained limit, not just the
/// initial burst.
struct TokenBucket {
    /// Tokens currently available.
    tokens: usize,
    /// Upper bound on `tokens`; also the number of tokens granted per window.
    capacity: usize,
    /// Start of the current refill window.
    last_refill_time: Instant,
    /// Length of one refill window.
    refill_duration: Duration,
}

impl TokenBucket {
    /// Creates a full bucket of `tokens` tokens that refills every `per` milliseconds.
    fn new(tokens: impl Into<usize>, per: impl Into<u64>) -> Self {
        let capacity = tokens.into();
        Self {
            tokens: capacity,
            capacity,
            last_refill_time: Instant::now(),
            refill_duration: Duration::from_millis(per.into()),
        }
    }

    /// Refills the bucket if a window has elapsed, then tries to take one token.
    ///
    /// Returns `true` if a token was taken (the request is allowed), `false` otherwise.
    fn try_acquire(&mut self) -> bool {
        self.refill();
        if self.tokens > 0 {
            self.tokens -= 1;
            true
        } else {
            false
        }
    }

    /// Tops the bucket up to `capacity` once a full refill window has elapsed.
    ///
    /// Tokens never accumulate beyond `capacity`, so an idle key cannot bank a
    /// burst larger than the configured limit. The window start moves forward by
    /// whole periods, keeping the sub-period remainder, so windows stay aligned.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_refill_time);
        if elapsed < self.refill_duration {
            return;
        }

        self.tokens = self.capacity;

        // A zero-length period would divide by zero below; it means "refill on every call".
        let period_nanos = self.refill_duration.as_nanos();
        let leftover = if period_nanos == 0 {
            Duration::ZERO
        } else {
            let remainder = elapsed.as_nanos() % period_nanos;
            // `remainder <= elapsed`, which in practice fits easily in u64 nanoseconds.
            Duration::from_nanos(u64::try_from(remainder).unwrap_or(u64::MAX))
        };
        self.last_refill_time = now.checked_sub(leftover).unwrap_or(now);
    }
}
165
166#[derive(Clone)]
170pub struct LimitState<K>
171where
172 K: Key,
173{
174 rate_limits: Arc<DashMap<K, TokenBucket>>,
175}
176
177impl<K> Default for LimitState<K>
178where
179 K: Key,
180{
181 fn default() -> Self {
183 Self {
184 rate_limits: Arc::new(DashMap::new()),
185 }
186 }
187}
188
189impl<K> LimitState<K>
190where
191 K: Key,
192{
193 pub fn check(&self, key: K, count: usize, per: u64) -> bool {
195 let mut bucket = self
196 .rate_limits
197 .entry(key)
198 .or_insert_with(|| TokenBucket::new(count, per));
199 bucket.try_acquire()
200 }
201}
202
203#[async_trait::async_trait]
204impl<const C: usize, const P: u64, K, S> FromRequestParts<S> for Limit<C, P, K>
205where
206 LimitState<K>: FromRef<S>,
207 S: Send + Sync,
208 K: Key,
209 K::Extractor: FromRequestParts<S>,
210{
211 type Rejection = LimitRejection<<<K as Key>::Extractor as FromRequestParts<S>>::Rejection>;
212
213 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
214 let key_extractor = match K::Extractor::from_request_parts(parts, state).await {
215 Ok(ke) => ke,
216 Err(rejection) => return Err(LimitRejection::KeyExtractionFailure(rejection)),
217 };
218
219 let limit_state: LimitState<K> = FromRef::from_ref(state);
220 let key = K::from_extractor(&key_extractor);
221 if limit_state.check(key, C, P) {
222 Ok(Self(key_extractor))
223 } else {
224 Err(LimitRejection::RateLimitExceeded)
225 }
226 }
227}
228
/// Rejection produced by [`Limit`] when a request is not admitted.
///
/// `R` is the rejection type of the underlying key extractor.
#[derive(Debug)]
pub enum LimitRejection<R> {
    /// The key extractor failed; its rejection is carried unchanged.
    KeyExtractionFailure(R),

    /// The key's token bucket had no tokens left.
    RateLimitExceeded,
}
238
239impl<R: Display> Display for LimitRejection<R> {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 match self {
242 LimitRejection::KeyExtractionFailure(r) => write!(f, "{r}"),
243 LimitRejection::RateLimitExceeded => write!(f, "Rate limit exceeded."),
244 }
245 }
246}
247
248impl<R: Error + 'static> Error for LimitRejection<R> {
249 fn source(&self) -> Option<&(dyn Error + 'static)> {
250 match self {
251 LimitRejection::KeyExtractionFailure(ve) => Some(ve),
252 LimitRejection::RateLimitExceeded => None,
253 }
254 }
255}
256
257impl<R: IntoResponse> IntoResponse for LimitRejection<R> {
258 fn into_response(self) -> Response {
259 match self {
260 LimitRejection::KeyExtractionFailure(rejection) => rejection.into_response(),
261 LimitRejection::RateLimitExceeded => {
262 (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded.").into_response()
263 }
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use axum::routing::get;
    use axum::Router;
    use axum_test::TestServer;
    use http::Uri;
    use std::future::IntoFuture;

    /// Exercises per-second limits end to end.
    ///
    /// `PER` is in milliseconds, so these handlers use `LimitPerSecond` (PER = 1000).
    /// A bare `PER` of 1 would be a 1 ms window, which makes the "second request is
    /// rejected" assertions flaky and the one-second sleeps meaningless.
    #[tokio::test]
    async fn limit() {
        const TEST_ROUTE0: &str = "/limit0";
        const TEST_ROUTE1: &str = "/limit1";

        async fn handler0(Limit(_uri): LimitPerSecond<1, Uri>) -> impl IntoResponse {}

        async fn handler1(Limit(_uri): LimitPerSecond<3, Uri>) -> impl IntoResponse {}

        let my_app = Router::new()
            .route(TEST_ROUTE0, get(handler0))
            .route(TEST_ROUTE1, get(handler1))
            .with_state(LimitState::default());

        let server = TestServer::new(my_app).expect("Failed to create test server");

        let response = server.get(TEST_ROUTE0).await;
        assert_eq!(response.status_code(), StatusCode::OK);
        let response = server.get(TEST_ROUTE0).await;
        assert_eq!(response.status_code(), StatusCode::TOO_MANY_REQUESTS);
        tokio::time::sleep(Duration::from_secs(1)).await;
        let response = server.get(TEST_ROUTE0).await;
        assert_eq!(response.status_code(), StatusCode::OK);

        let gets = vec![
            server.get(TEST_ROUTE1).into_future(),
            server.get(TEST_ROUTE1).into_future(),
            server.get(TEST_ROUTE1).into_future(),
        ];

        let resp = futures::future::join_all(gets).await;
        assert!(resp.iter().all(|r| r.status_code().is_success()));
        assert_eq!(
            server.get(TEST_ROUTE1).await.status_code(),
            StatusCode::TOO_MANY_REQUESTS
        );
        tokio::time::sleep(Duration::from_secs(1)).await;
        let response = server.get(TEST_ROUTE1).await;
        assert_eq!(response.status_code(), StatusCode::OK);
    }

    /// Exercises a sub-second window given directly in milliseconds.
    #[tokio::test]
    async fn limit_per_100_millis() {
        const TEST_ROUTE: &str = "/limit_per_100_millis";

        async fn handler(Limit(_uri): Limit<1, 100, Uri>) -> impl IntoResponse {}

        let my_app = Router::new()
            .route(TEST_ROUTE, get(handler))
            .with_state(LimitState::default());

        let server = TestServer::new(my_app).expect("Failed to create test server");

        let response = server.get(TEST_ROUTE).await;
        assert_eq!(response.status_code(), StatusCode::OK);

        let response = server.get(TEST_ROUTE).await;
        assert_eq!(response.status_code(), StatusCode::TOO_MANY_REQUESTS);

        tokio::time::sleep(Duration::from_millis(100)).await;

        let response = server.get(TEST_ROUTE).await;
        assert_eq!(response.status_code(), StatusCode::OK);
    }
}