Skip to main content

fraiseql_server/middleware/rate_limit/
middleware_fn.rs

//! Rate limit middleware function and supporting helpers.
//!
//! Contains the axum middleware entry point, IP extraction logic, and the
//! JWT subject parser used for per-user rate limiting.

6use std::{
7    net::{IpAddr, SocketAddr},
8    sync::Arc,
9};
10
11use axum::{
12    body::Body,
13    extract::ConnectInfo,
14    http::{Request, StatusCode},
15    middleware::Next,
16    response::{IntoResponse, Response},
17};
18use tracing::warn;
19
20use super::{config::RateLimitConfig, dispatch::RateLimiter, key::is_private_or_loopback};
21
/// Rejection produced by [`rate_limit_middleware`] when a request exceeds a rate limit.
///
/// Carries the number of seconds the client should wait before retrying,
/// derived from the active rate-limit configuration at the time the request
/// was rejected.  This value is emitted both as the `Retry-After` HTTP header
/// and inside the GraphQL error message body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitExceeded {
    /// Seconds until the token bucket refills by at least one token.
    pub retry_after_secs: u32,
}
33
34impl IntoResponse for RateLimitExceeded {
35    fn into_response(self) -> Response {
36        let retry = self.retry_after_secs;
37        let retry_str = retry.to_string();
38        let body = format!(
39            r#"{{"errors":[{{"message":"Rate limit exceeded. Please retry after {retry} second{s}."}}]}}"#,
40            s = if retry == 1 { "" } else { "s" }
41        );
42        (
43            StatusCode::TOO_MANY_REQUESTS,
44            [
45                ("Content-Type", "application/json"),
46                ("Retry-After", retry_str.as_str()),
47            ],
48            body,
49        )
50            .into_response()
51    }
52}
53
/// One-shot latch guarding the "appears to be behind a reverse proxy" warning.
///
/// Starts out `false`.  The first caller that wins the `swap(true, ..)` logs the
/// warning, so it is printed at most once per process.  The warning exists
/// because, with `trust_proxy_headers = false` behind a proxy, every request is
/// rate-limited under the proxy's own IP address.
static PROXY_WARNING_LOGGED: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
59
60/// Extract the real client IP from request headers when behind a trusted reverse proxy.
61///
62/// Checks `X-Real-IP` first, then the first address in `X-Forwarded-For` (set by
63/// the proxy to the original client).  Falls back to the TCP peer address when
64/// neither header is present or `trust_proxy` is false.
65///
66/// **Security**: only enable `trust_proxy` when the server is guaranteed to sit
67/// behind a proxy that sets these headers; otherwise clients can spoof the IP.
68pub(super) fn extract_real_ip(
69    req: &Request<Body>,
70    trust_proxy: bool,
71    trusted_cidrs: &[ipnet::IpNet],
72    addr: &SocketAddr,
73) -> String {
74    if trust_proxy {
75        // If trusted_cidrs is non-empty, verify the direct connection IP is a known proxy.
76        if !trusted_cidrs.is_empty() {
77            let direct: IpAddr = addr.ip();
78            let from_trusted_proxy = trusted_cidrs.iter().any(|cidr| cidr.contains(&direct));
79            if !from_trusted_proxy {
80                tracing::debug!(
81                    %direct,
82                    "Connection not from a trusted proxy CIDR; ignoring X-Forwarded-For"
83                );
84                return direct.to_string();
85            }
86        }
87
88        if let Some(real_ip) = req
89            .headers()
90            .get("x-real-ip")
91            .and_then(|v| v.to_str().ok())
92            .map(str::trim)
93            .filter(|s| !s.is_empty())
94        {
95            return real_ip.to_string();
96        }
97        if let Some(xff) = req.headers().get("x-forwarded-for").and_then(|v| v.to_str().ok()) {
98            if let Some(first) = xff.split(',').next().map(str::trim).filter(|s| !s.is_empty()) {
99                return first.to_string();
100            }
101        }
102    } else if is_private_or_loopback(addr.ip())
103        && !PROXY_WARNING_LOGGED.load(std::sync::atomic::Ordering::Relaxed)
104        && !PROXY_WARNING_LOGGED.swap(true, std::sync::atomic::Ordering::Relaxed)
105    {
106        warn!(
107            peer_ip = %addr.ip(),
108            "Rate limiter: peer address is loopback/RFC-1918 — server appears to be \
109             behind a reverse proxy. All requests will share a single rate-limit bucket \
110             unless you set `trust_proxy_headers = true` in [security.rate_limiting]."
111        );
112    }
113    addr.ip().to_string()
114}
115
116/// Decode a JWT bearer token's payload section and extract the `sub` claim
117/// without performing cryptographic signature verification.
118///
119/// Signature verification is intentionally omitted: rate limiting is a
120/// best-effort control that degrades gracefully — an invalid or forged JWT
121/// simply returns `None`, falling back to IP-based limiting.  Verified
122/// identity is handled by the auth middleware upstream.
123pub(super) fn extract_jwt_subject(authorization: &str) -> Option<String> {
124    use base64::Engine as _;
125    let token = authorization.strip_prefix("Bearer ")?;
126    let payload_b64 = token.split('.').nth(1)?;
127    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64).ok()?;
128    let json: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
129    json.get("sub").and_then(|v| v.as_str()).map(String::from)
130}
131
132/// Rate limiting middleware for GraphQL requests.
133///
134/// Decision order:
135/// 1. Per-path limit (auth endpoints) — always checked, uses path-specific window.
136/// 2. Per-user limit (authenticated requests) — checked when a JWT `sub` claim is present in the
137///    `Authorization` header; authenticated users get `rps_per_user` (default 10× `rps_per_ip`)
138///    instead of the shared IP bucket.
139/// 3. Per-IP limit (unauthenticated or no bearer token) — fallback.
140///
141/// # Errors
142///
143/// Returns `RateLimitExceeded` if the per-path, per-user, or per-IP rate limit is exceeded.
144#[allow(clippy::cognitive_complexity)] // Reason: multi-dimension rate limiting (per-path, per-user, per-IP) with config lookups
145pub async fn rate_limit_middleware(
146    ConnectInfo(addr): ConnectInfo<SocketAddr>,
147    req: Request<Body>,
148    next: Next,
149) -> Result<Response, RateLimitExceeded> {
150    // Get or create rate limiter from state
151    let limiter = req
152        .extensions()
153        .get::<Arc<RateLimiter>>()
154        .cloned()
155        .unwrap_or_else(|| Arc::new(RateLimiter::new(RateLimitConfig::default())));
156
157    let ip = extract_real_ip(
158        &req,
159        limiter.config().trust_proxy_headers,
160        &limiter.config().trusted_proxy_cidrs,
161        &addr,
162    );
163    let path = req.uri().path().to_string();
164
165    // Extract JWT subject for per-user limiting (no signature verification needed here).
166    let user_id = req
167        .headers()
168        .get(axum::http::header::AUTHORIZATION)
169        .and_then(|v| v.to_str().ok())
170        .and_then(extract_jwt_subject);
171
172    // ── Per-path limit (strictest, always enforced) ───────────────────────
173    let path_result = limiter.check_path_limit(&path, &ip).await;
174    if !path_result.allowed {
175        warn!(ip = %ip, path = %path, "Per-path rate limit exceeded");
176        return Err(RateLimitExceeded {
177            retry_after_secs: path_result.retry_after_secs,
178        });
179    }
180
181    // ── Per-user or per-IP limit ──────────────────────────────────────────
182    let limit_result = if let Some(ref uid) = user_id {
183        // Authenticated: apply the higher per-user bucket.
184        limiter.check_user_limit(uid).await
185    } else {
186        // Unauthenticated: apply the shared IP bucket.
187        limiter.check_ip_limit(&ip).await
188    };
189
190    if !limit_result.allowed {
191        if let Some(ref uid) = user_id {
192            warn!(user_id = %uid, "Per-user rate limit exceeded");
193        } else {
194            warn!(ip = %ip, "IP rate limit exceeded");
195        }
196        return Err(RateLimitExceeded {
197            retry_after_secs: limit_result.retry_after_secs,
198        });
199    }
200
201    let remaining = limit_result.remaining;
202
203    let response = next.run(req).await;
204
205    // Add rate limit headers
206    let mut response = response;
207    let limit = if user_id.is_some() {
208        limiter.config().rps_per_user
209    } else {
210        limiter.config().rps_per_ip
211    };
212    if let Ok(limit_value) = format!("{limit}").parse() {
213        response.headers_mut().insert("X-RateLimit-Limit", limit_value);
214    }
215    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
216    // Reason: remaining tokens is a small non-negative count that fits in u32
217    if let Ok(remaining_value) = format!("{}", remaining as u32).parse() {
218        response.headers_mut().insert("X-RateLimit-Remaining", remaining_value);
219    }
220
221    Ok(response)
222}
223
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics acceptable

    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    use axum::{body::Body, http::Request};

    use super::extract_real_ip;

    /// Build an IPv4 peer address with a fixed ephemeral port.
    fn socket_addr(ip: [u8; 4]) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip)), 12345)
    }

    /// Build a GraphQL request carrying the given `(name, value)` headers.
    fn req_with_headers(headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri("http://example.com/graphql");
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Build a request carrying only an `X-Forwarded-For` header.
    fn req_with_xff(xff: &str) -> Request<Body> {
        req_with_headers(&[("x-forwarded-for", xff)])
    }

    #[test]
    fn test_spoofed_xforwardedfor_ignored_when_direct_ip_not_in_trusted_cidrs() {
        // Direct IP is a public internet address — NOT in the trusted proxy CIDR 10.0.0.0/8.
        // Even though trust_proxy_headers = true, X-Forwarded-For must be ignored.
        let cidrs: Vec<ipnet::IpNet> = vec!["10.0.0.0/8".parse().unwrap()];
        let addr = socket_addr([203, 0, 113, 1]); // TEST-NET-3, public
        let req = req_with_xff("1.2.3.4");

        let ip = extract_real_ip(&req, true, &cidrs, &addr);
        assert_eq!(ip, "203.0.113.1", "Should use direct IP, not spoofed X-Forwarded-For");
    }

    #[test]
    fn test_forwarded_ip_used_when_direct_ip_is_trusted_proxy() {
        // Direct IP is inside 10.0.0.0/8 (trusted proxy CIDR).
        // X-Forwarded-For should be honoured.
        let cidrs: Vec<ipnet::IpNet> = vec!["10.0.0.0/8".parse().unwrap()];
        let addr = socket_addr([10, 0, 1, 5]); // inside trusted CIDR
        let req = req_with_xff("5.6.7.8");

        let ip = extract_real_ip(&req, true, &cidrs, &addr);
        assert_eq!(ip, "5.6.7.8", "Should use X-Forwarded-For from trusted proxy");
    }

    #[test]
    fn test_no_cidrs_trusts_all_proxies() {
        // When trusted_proxy_cidrs is empty and trust_proxy_headers = true,
        // all direct IPs are treated as trusted proxies.
        let cidrs: Vec<ipnet::IpNet> = vec![];
        let addr = socket_addr([203, 0, 113, 1]); // public IP
        let req = req_with_xff("9.9.9.9");

        let ip = extract_real_ip(&req, true, &cidrs, &addr);
        assert_eq!(ip, "9.9.9.9", "Empty CIDRs: all proxies trusted");
    }

    #[test]
    fn test_headers_ignored_when_trust_proxy_disabled() {
        // trust_proxy_headers = false: forwarding headers must never be consulted.
        let cidrs: Vec<ipnet::IpNet> = vec![];
        let addr = socket_addr([203, 0, 113, 1]);
        let req = req_with_xff("1.2.3.4");

        let ip = extract_real_ip(&req, false, &cidrs, &addr);
        assert_eq!(ip, "203.0.113.1", "Proxy headers ignored when trust is disabled");
    }

    #[test]
    fn test_x_real_ip_takes_precedence_over_xff() {
        let cidrs: Vec<ipnet::IpNet> = vec!["10.0.0.0/8".parse().unwrap()];
        let addr = socket_addr([10, 0, 0, 1]);
        let req = req_with_headers(&[("x-real-ip", "7.7.7.7"), ("x-forwarded-for", "8.8.8.8")]);

        let ip = extract_real_ip(&req, true, &cidrs, &addr);
        assert_eq!(ip, "7.7.7.7", "X-Real-IP should win over X-Forwarded-For");
    }

    #[test]
    fn test_multi_hop_xff_without_cidrs_uses_leftmost() {
        // With no CIDRs every hop is trusted, so the originating client is the leftmost entry.
        let cidrs: Vec<ipnet::IpNet> = vec![];
        let addr = socket_addr([203, 0, 113, 1]);
        let req = req_with_xff("9.9.9.9, 10.0.0.2");

        let ip = extract_real_ip(&req, true, &cidrs, &addr);
        assert_eq!(ip, "9.9.9.9", "Leftmost hop is the client when all hops are trusted");
    }

    #[test]
    fn test_no_forwarding_headers_falls_back_to_peer() {
        let cidrs: Vec<ipnet::IpNet> = vec!["10.0.0.0/8".parse().unwrap()];
        let addr = socket_addr([10, 0, 0, 1]);
        let req = req_with_headers(&[]);

        let ip = extract_real_ip(&req, true, &cidrs, &addr);
        assert_eq!(ip, "10.0.0.1", "Without headers the peer address is used");
    }
}