allowthem_server/
rate_limit.rs1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::extract::ConnectInfo;
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7pub use governor::Quota;
8use governor::RateLimiter;
9use governor::clock::{Clock, DefaultClock, QuantaInstant};
10use governor::middleware::NoOpMiddleware;
11use governor::state::keyed::DashMapStateStore;
12
13type KeyedLimiter =
14 RateLimiter<String, DashMapStateStore<String>, DefaultClock, NoOpMiddleware<QuantaInstant>>;
15
16#[derive(Clone)]
21pub struct AuthRateLimiter {
22 inner: Arc<KeyedLimiter>,
23}
24
25impl AuthRateLimiter {
26 pub fn new(quota: Quota) -> Self {
27 Self {
28 inner: Arc::new(RateLimiter::keyed(quota)),
29 }
30 }
31
32 #[allow(clippy::result_large_err)]
33 pub fn check(&self, key: &str) -> Result<(), Response> {
34 match self.inner.check_key(&key.to_owned()) {
35 Ok(_) => Ok(()),
36 Err(not_until) => {
37 let wait = not_until.wait_time_from(DefaultClock::default().now());
38 let retry_after = wait.as_secs().saturating_add(1);
39 Err(rate_limit_response(retry_after))
40 }
41 }
42 }
43}
44
45pub fn extract_client_ip(extensions: &axum::http::Extensions) -> String {
50 extensions
51 .get::<ConnectInfo<SocketAddr>>()
52 .map(|ci| ci.0.ip().to_string())
53 .unwrap_or_else(|| "unknown".into())
54}
55
56fn rate_limit_response(retry_after_secs: u64) -> Response {
57 let mut response = (
58 StatusCode::TOO_MANY_REQUESTS,
59 format!(
60 "Too many requests. Retry after {} seconds.",
61 retry_after_secs
62 ),
63 )
64 .into_response();
65 if let Ok(val) = axum::http::HeaderValue::from_str(&retry_after_secs.to_string()) {
66 response.headers_mut().insert("retry-after", val);
67 }
68 response
69}
70
#[cfg(test)]
mod tests {
    use std::num::NonZeroU32;

    use super::*;

    #[test]
    fn requests_within_burst_are_allowed() {
        let limiter = AuthRateLimiter::new(Quota::per_minute(NonZeroU32::new(3).unwrap()));
        let ip = "127.0.0.1";
        assert!(limiter.check(ip).is_ok());
        assert!(limiter.check(ip).is_ok());
        assert!(limiter.check(ip).is_ok());
    }

    #[test]
    fn requests_exceeding_burst_get_429() {
        let limiter = AuthRateLimiter::new(Quota::per_minute(NonZeroU32::new(2).unwrap()));
        let ip = "192.168.1.1";
        assert!(limiter.check(ip).is_ok());
        assert!(limiter.check(ip).is_ok());
        let err = limiter.check(ip).unwrap_err();
        assert_eq!(err.status(), StatusCode::TOO_MANY_REQUESTS);
        // Retry-After must be an integer number of seconds; with 2 requests per
        // minute the next slot is never more than a minute away.
        let retry_after: u64 = err
            .headers()
            .get("retry-after")
            .expect("429 response must carry Retry-After")
            .to_str()
            .expect("Retry-After must be visible ASCII")
            .parse()
            .expect("Retry-After must be an integer number of seconds");
        assert!(
            (1..=60).contains(&retry_after),
            "unexpected Retry-After: {retry_after}"
        );
    }

    #[test]
    fn different_keys_have_independent_limits() {
        let limiter = AuthRateLimiter::new(Quota::per_minute(NonZeroU32::new(1).unwrap()));
        assert!(limiter.check("10.0.0.1").is_ok());
        assert!(limiter.check("10.0.0.2").is_ok());
        assert!(limiter.check("10.0.0.1").is_err());
        assert!(limiter.check("10.0.0.3").is_ok());
    }

    #[test]
    fn clones_share_the_same_budget() {
        let limiter = AuthRateLimiter::new(Quota::per_minute(NonZeroU32::new(1).unwrap()));
        let clone = limiter.clone();
        assert!(limiter.check("10.0.0.9").is_ok());
        assert!(clone.check("10.0.0.9").is_err());
    }

    #[test]
    fn extract_client_ip_returns_unknown_without_connect_info() {
        let extensions = axum::http::Extensions::new();
        assert_eq!(extract_client_ip(&extensions), "unknown");
    }

    #[test]
    fn extract_client_ip_returns_ipv4_without_port() {
        let mut extensions = axum::http::Extensions::new();
        extensions.insert(ConnectInfo(SocketAddr::from(([203, 0, 113, 7], 4242))));
        assert_eq!(extract_client_ip(&extensions), "203.0.113.7");
    }

    #[test]
    fn extract_client_ip_returns_ipv6_without_port() {
        let mut extensions = axum::http::Extensions::new();
        let v6: [u16; 8] = [0x2001, 0xdb8, 0, 0, 0, 0, 0, 1];
        extensions.insert(ConnectInfo(SocketAddr::from((v6, 443))));
        assert_eq!(extract_client_ip(&extensions), "2001:db8::1");
    }
}