1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
15#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
16#![cfg_attr(docsrs, feature(doc_cfg))]
17
18use std::borrow::Borrow;
19use std::error::Error as StdError;
20use std::hash::Hash;
21
22use salvo_core::conn::SocketAddr;
23use salvo_core::handler::{Skipper, none_skipper};
24use salvo_core::http::{HeaderValue, Request, Response, StatusCode, StatusError};
25use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
26
27mod quota;
28pub use quota::{BasicQuota, CelledQuota, QuotaGetter};
29#[macro_use]
30mod cfg;
31
32cfg_feature! {
33 #![feature = "moka-store"]
34
35 mod moka_store;
36 pub use moka_store::MokaStore;
37}
38
39cfg_feature! {
40 #![feature = "fixed-guard"]
41
42 mod fixed_guard;
43 pub use fixed_guard::FixedGuard;
44}
45
46cfg_feature! {
47 #![feature = "sliding-guard"]
48
49 mod sliding_guard;
50 pub use sliding_guard::SlidingGuard;
51}
52
53pub trait RateIssuer: Send + Sync + 'static {
55 type Key: Hash + Eq + Send + Sync + 'static;
57 fn issue(
59 &self,
60 req: &mut Request,
61 depot: &Depot,
62 ) -> impl Future<Output = Option<Self::Key>> + Send;
63}
64impl<F, K> RateIssuer for F
65where
66 F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
67 K: Hash + Eq + Send + Sync + 'static,
68{
69 type Key = K;
70 async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
71 (self)(req, depot)
72 }
73}
74
75pub struct RemoteIpIssuer;
77impl RateIssuer for RemoteIpIssuer {
78 type Key = String;
79 async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
80 match req.remote_addr() {
81 SocketAddr::IPv4(addr) => Some(addr.ip().to_string()),
82 SocketAddr::IPv6(addr) => Some(addr.ip().to_string()),
83 _ => None,
84 }
85 }
86}
87
/// Per-key state that decides whether a request fits within a quota.
///
/// `RateLimiter` clones nothing itself. It asks the [`RateStore`] for the guard of
/// the current key, calls [`verify`](RateGuard::verify) on it, and saves it back.
pub trait RateGuard: Clone + Send + Sync + 'static {
    /// Quota type this guard evaluates. The limiter requires it to equal the
    /// quota type returned by its `QuotaGetter`.
    type Quota: Clone + Send + Sync + 'static;

    /// Checks the current request against `quota`.
    ///
    /// Takes `&mut self` so implementations can update their internal state; the
    /// limiter saves the guard after this call. Returning `false` makes the limiter
    /// answer `429 Too Many Requests`.
    fn verify(&mut self, quota: &Self::Quota) -> impl Future<Output = bool> + Send;

    /// Requests still available under `quota`.
    ///
    /// Emitted as `X-RateLimit-Remaining` when headers are enabled.
    fn remaining(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;

    /// Reset indicator for `quota`, emitted as `X-RateLimit-Reset` when headers are
    /// enabled.
    ///
    /// NOTE(review): whether this is a timestamp or a relative duration is defined
    /// by the implementations — confirm before relying on its unit.
    fn reset(&self, quota: &Self::Quota) -> impl Future<Output = i64> + Send;

    /// Maximum number of requests allowed by `quota`.
    ///
    /// Emitted as `X-RateLimit-Limit` when headers are enabled.
    fn limit(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;
}
104
/// Storage that maps rate-limit keys to their guards.
pub trait RateStore: Send + Sync + 'static {
    /// Error produced by store operations.
    ///
    /// A load failure makes `RateLimiter` respond `500`; a save failure is only logged.
    type Error: StdError;

    /// Key under which guards are kept. The limiter requires it to match its
    /// issuer's key type.
    type Key: Hash + Eq + Send + Clone + 'static;

    /// Guard type kept per key.
    type Guard;

    /// Fetches the guard associated with `key`.
    ///
    /// `refer` is the limiter's configured template guard. Presumably it is cloned
    /// when no guard is stored for `key` yet — TODO confirm against the store
    /// implementations.
    fn load_guard<Q: Hash + Eq + Sync>(
        &self,
        key: &Q,
        refer: &Self::Guard,
    ) -> impl Future<Output = Result<Self::Guard, Self::Error>> + Send
    where
        Self::Key: Borrow<Q>;

    /// Persists `guard` under `key`. The limiter calls this after every checked
    /// request, whether or not it was allowed.
    fn save_guard(
        &self,
        key: Self::Key,
        guard: Self::Guard,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
129
130pub struct RateLimiter<G, S, I, Q> {
132 guard: G,
133 store: S,
134 issuer: I,
135 quota_getter: Q,
136 add_headers: bool,
137 skipper: Box<dyn Skipper>,
138}
139
140impl<G: RateGuard, S: RateStore, I: RateIssuer, P: QuotaGetter<I::Key>> RateLimiter<G, S, I, P> {
141 #[inline]
143 pub fn new(guard: G, store: S, issuer: I, quota_getter: P) -> Self {
144 Self {
145 guard,
146 store,
147 issuer,
148 quota_getter,
149 add_headers: false,
150 skipper: Box::new(none_skipper),
151 }
152 }
153
154 #[inline]
156 pub fn with_skipper(mut self, skipper: impl Skipper) -> Self {
157 self.skipper = Box::new(skipper);
158 self
159 }
160
161 #[inline]
164 pub fn add_headers(mut self, add_headers: bool) -> Self {
165 self.add_headers = add_headers;
166 self
167 }
168}
169
170#[async_trait]
171impl<G, S, I, P> Handler for RateLimiter<G, S, I, P>
172where
173 G: RateGuard<Quota = P::Quota>,
174 S: RateStore<Key = I::Key, Guard = G>,
175 P: QuotaGetter<I::Key>,
176 I: RateIssuer,
177{
178 async fn handle(
179 &self,
180 req: &mut Request,
181 depot: &mut Depot,
182 res: &mut Response,
183 ctrl: &mut FlowCtrl,
184 ) {
185 if self.skipper.skipped(req, depot) {
186 return;
187 }
188 let key = match self.issuer.issue(req, depot).await {
189 Some(key) => key,
190 None => {
191 res.render(StatusError::bad_request().brief("Invalid identifier."));
192 ctrl.skip_rest();
193 return;
194 }
195 };
196 let quota = match self.quota_getter.get(&key).await {
197 Ok(quota) => quota,
198 Err(e) => {
199 tracing::error!(error = ?e, "RateLimiter error: {}", e);
200 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
201 ctrl.skip_rest();
202 return;
203 }
204 };
205 let mut guard = match self.store.load_guard(&key, &self.guard).await {
206 Ok(guard) => guard,
207 Err(e) => {
208 tracing::error!(error = ?e, "RateLimiter error: {}", e);
209 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
210 ctrl.skip_rest();
211 return;
212 }
213 };
214 let verified = guard.verify("a).await;
215
216 if self.add_headers {
217 res.headers_mut().insert(
218 "X-RateLimit-Limit",
219 HeaderValue::from_str(&guard.limit("a).await.to_string())
220 .expect("Invalid header value"),
221 );
222 res.headers_mut().insert(
223 "X-RateLimit-Remaining",
224 HeaderValue::from_str(&(guard.remaining("a).await).to_string())
225 .expect("Invalid header value"),
226 );
227 res.headers_mut().insert(
228 "X-RateLimit-Reset",
229 HeaderValue::from_str(&guard.reset("a).await.to_string())
230 .expect("Invalid header value"),
231 );
232 }
233 if !verified {
234 res.status_code(StatusCode::TOO_MANY_REQUESTS);
235 ctrl.skip_rest();
236 }
237 if let Err(e) = self.store.save_guard(key, guard).await {
238 tracing::error!(error = ?e, "RateLimiter save guard failed");
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    // The tests below need optional features (`moka-store`, `fixed-guard`,
    // `sliding-guard`). When those features are disabled the tests are compiled out,
    // and the shared helpers/imports would otherwise trigger unused warnings.
    #![allow(dead_code, unused_imports)]

    use std::collections::HashMap;
    use std::sync::LazyLock;

    use salvo_core::Error;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// Keys requests by the `user` query parameter; requests without it get `None`.
    struct UserIssuer;
    impl RateIssuer for UserIssuer {
        type Key = String;
        async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
            req.query::<Self::Key>("user")
        }
    }

    #[handler]
    async fn limited() -> &'static str {
        "Limited page"
    }

    /// Per-user quotas with the fixed-window guard.
    ///
    /// user1 may send 1 request per second. user2 may send 1 request per 5 seconds,
    /// so it is still limited 1s later and allowed again after a further 6s.
    #[cfg(all(feature = "fixed-guard", feature = "moka-store"))]
    #[tokio::test]
    async fn test_fixed_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, BasicQuota>> = LazyLock::new(|| {
            let mut map = HashMap::new();
            map.insert("user1".into(), BasicQuota::per_second(1));
            map.insert("user2".into(), BasicQuota::set_seconds(1, 5));
            map
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = BasicQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                USER_QUOTAS
                    .get(key)
                    .cloned()
                    .ok_or_else(|| Error::other("user not found"))
            }
        }
        let limiter = RateLimiter::new(
            FixedGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        );
        let router = Router::new().push(Router::with_path("limited").hoop(limiter).get(limited));
        let service = Service::new(router);

        let mut response = TestClient::get("http://127.0.0.1:5800/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = TestClient::get("http://127.0.0.1:5800/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let mut response = TestClient::get("http://127.0.0.1:5800/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let mut response = TestClient::get("http://127.0.0.1:5800/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = TestClient::get("http://127.0.0.1:5800/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        // Still inside user2's 5-second window.
        let response = TestClient::get("http://127.0.0.1:5800/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;

        let mut response = TestClient::get("http://127.0.0.1:5800/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");
    }

    /// The same scenario as `test_fixed_dynamic_quota`, using the sliding-window
    /// guard with celled quotas.
    #[cfg(all(feature = "sliding-guard", feature = "moka-store"))]
    #[tokio::test]
    async fn test_sliding_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, CelledQuota>> = LazyLock::new(|| {
            let mut map = HashMap::new();
            map.insert("user1".into(), CelledQuota::per_second(1, 1));
            map.insert("user2".into(), CelledQuota::set_seconds(1, 1, 5));
            map
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = CelledQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                USER_QUOTAS
                    .get(key)
                    .cloned()
                    .ok_or_else(|| Error::other("user not found"))
            }
        }
        let limiter = RateLimiter::new(
            SlidingGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        );
        let router = Router::new().push(Router::with_path("limited").hoop(limiter).get(limited));
        let service = Service::new(router);

        let mut response = TestClient::get("http://127.0.0.1:5800/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = TestClient::get("http://127.0.0.1:5800/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let mut response = TestClient::get("http://127.0.0.1:5800/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let mut response = TestClient::get("http://127.0.0.1:5800/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = TestClient::get("http://127.0.0.1:5800/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        // Still inside user2's 5-second window.
        let response = TestClient::get("http://127.0.0.1:5800/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;

        let mut response = TestClient::get("http://127.0.0.1:5800/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");
    }
}