Skip to main content

allowthem_server/
rate_limit.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::extract::ConnectInfo;
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7pub use governor::Quota;
8use governor::RateLimiter;
9use governor::clock::{Clock, DefaultClock, QuantaInstant};
10use governor::middleware::NoOpMiddleware;
11use governor::state::keyed::DashMapStateStore;
12
13type KeyedLimiter =
14    RateLimiter<String, DashMapStateStore<String>, DefaultClock, NoOpMiddleware<QuantaInstant>>;
15
16/// A keyed rate limiter for IP-based throttling.
17///
18/// Create one per endpoint group with the desired quota, store it in your
19/// app state, and call `check` at the top of the handler.
20#[derive(Clone)]
21pub struct AuthRateLimiter {
22    inner: Arc<KeyedLimiter>,
23}
24
25impl AuthRateLimiter {
26    pub fn new(quota: Quota) -> Self {
27        Self {
28            inner: Arc::new(RateLimiter::keyed(quota)),
29        }
30    }
31
32    #[allow(clippy::result_large_err)]
33    pub fn check(&self, key: &str) -> Result<(), Response> {
34        match self.inner.check_key(&key.to_owned()) {
35            Ok(_) => Ok(()),
36            Err(not_until) => {
37                let wait = not_until.wait_time_from(DefaultClock::default().now());
38                let retry_after = wait.as_secs().saturating_add(1);
39                Err(rate_limit_response(retry_after))
40            }
41        }
42    }
43}
44
45/// Extract the client IP address from request extensions.
46///
47/// Returns the IP as a string, or `"unknown"` if `ConnectInfo` is not present.
48/// Requires the server to use `into_make_service_with_connect_info::<SocketAddr>()`.
49pub fn extract_client_ip(extensions: &axum::http::Extensions) -> String {
50    extensions
51        .get::<ConnectInfo<SocketAddr>>()
52        .map(|ci| ci.0.ip().to_string())
53        .unwrap_or_else(|| "unknown".into())
54}
55
56fn rate_limit_response(retry_after_secs: u64) -> Response {
57    let mut response = (
58        StatusCode::TOO_MANY_REQUESTS,
59        format!(
60            "Too many requests. Retry after {} seconds.",
61            retry_after_secs
62        ),
63    )
64        .into_response();
65    if let Ok(val) = axum::http::HeaderValue::from_str(&retry_after_secs.to_string()) {
66        response.headers_mut().insert("retry-after", val);
67    }
68    response
69}
70
#[cfg(test)]
mod tests {
    use std::num::NonZeroU32;

    use super::*;

    /// Limiter allowing `n` requests per minute per key.
    fn per_minute(n: u32) -> AuthRateLimiter {
        AuthRateLimiter::new(Quota::per_minute(NonZeroU32::new(n).unwrap()))
    }

    #[test]
    fn requests_within_burst_are_allowed() {
        let limiter = per_minute(3);
        let ip = "127.0.0.1";
        assert!(limiter.check(ip).is_ok());
        assert!(limiter.check(ip).is_ok());
        assert!(limiter.check(ip).is_ok());
    }

    #[test]
    fn requests_exceeding_burst_get_429() {
        let limiter = per_minute(2);
        let ip = "192.168.1.1";
        assert!(limiter.check(ip).is_ok());
        assert!(limiter.check(ip).is_ok());
        let err = limiter.check(ip).unwrap_err();
        assert_eq!(err.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(err.headers().contains_key("retry-after"));
    }

    #[test]
    fn retry_after_header_is_whole_seconds_within_quota_period() {
        let limiter = per_minute(1);
        let ip = "172.16.0.1";
        assert!(limiter.check(ip).is_ok());
        let err = limiter.check(ip).unwrap_err();
        let secs: u64 = err.headers()["retry-after"]
            .to_str()
            .unwrap()
            .parse()
            .unwrap();
        // One request per minute: the wait is at most a minute, rounded up.
        assert!((1..=61).contains(&secs), "unexpected Retry-After: {secs}");
    }

    #[test]
    fn different_keys_have_independent_limits() {
        let limiter = per_minute(1);
        assert!(limiter.check("10.0.0.1").is_ok());
        assert!(limiter.check("10.0.0.2").is_ok());
        // First IP is now limited, second is not
        assert!(limiter.check("10.0.0.1").is_err());
        assert!(limiter.check("10.0.0.3").is_ok());
    }

    #[test]
    fn extract_client_ip_returns_unknown_without_connect_info() {
        let extensions = axum::http::Extensions::new();
        assert_eq!(extract_client_ip(&extensions), "unknown");
    }

    #[test]
    fn extract_client_ip_reads_connect_info_and_drops_port() {
        let mut v4 = axum::http::Extensions::new();
        v4.insert(ConnectInfo("10.1.2.3:4567".parse::<SocketAddr>().unwrap()));
        assert_eq!(extract_client_ip(&v4), "10.1.2.3");

        let mut v6 = axum::http::Extensions::new();
        v6.insert(ConnectInfo("[2001:db8::1]:443".parse::<SocketAddr>().unwrap()));
        assert_eq!(extract_client_ip(&v6), "2001:db8::1");
    }
}