fraiseql_server/middleware/rate_limit/
middleware_fn.rs1use std::{
7 net::{IpAddr, SocketAddr},
8 sync::Arc,
9};
10
11use axum::{
12 body::Body,
13 extract::ConnectInfo,
14 http::{Request, StatusCode},
15 middleware::Next,
16 response::{IntoResponse, Response},
17};
18use tracing::warn;
19
20use super::{config::RateLimitConfig, dispatch::RateLimiter, key::is_private_or_loopback};
21
22#[derive(Debug)]
29pub struct RateLimitExceeded {
30 pub retry_after_secs: u32,
32}
33
34impl IntoResponse for RateLimitExceeded {
35 fn into_response(self) -> Response {
36 let retry = self.retry_after_secs;
37 let retry_str = retry.to_string();
38 let body = format!(
39 r#"{{"errors":[{{"message":"Rate limit exceeded. Please retry after {retry} second{s}."}}]}}"#,
40 s = if retry == 1 { "" } else { "s" }
41 );
42 (
43 StatusCode::TOO_MANY_REQUESTS,
44 [
45 ("Content-Type", "application/json"),
46 ("Retry-After", retry_str.as_str()),
47 ],
48 body,
49 )
50 .into_response()
51 }
52}
53
54static PROXY_WARNING_LOGGED: std::sync::atomic::AtomicBool =
58 std::sync::atomic::AtomicBool::new(false);
59
60pub(super) fn extract_real_ip(
69 req: &Request<Body>,
70 trust_proxy: bool,
71 trusted_cidrs: &[ipnet::IpNet],
72 addr: &SocketAddr,
73) -> String {
74 if trust_proxy {
75 if !trusted_cidrs.is_empty() {
77 let direct: IpAddr = addr.ip();
78 let from_trusted_proxy = trusted_cidrs.iter().any(|cidr| cidr.contains(&direct));
79 if !from_trusted_proxy {
80 tracing::debug!(
81 %direct,
82 "Connection not from a trusted proxy CIDR; ignoring X-Forwarded-For"
83 );
84 return direct.to_string();
85 }
86 }
87
88 if let Some(real_ip) = req
89 .headers()
90 .get("x-real-ip")
91 .and_then(|v| v.to_str().ok())
92 .map(str::trim)
93 .filter(|s| !s.is_empty())
94 {
95 return real_ip.to_string();
96 }
97 if let Some(xff) = req.headers().get("x-forwarded-for").and_then(|v| v.to_str().ok()) {
98 if let Some(first) = xff.split(',').next().map(str::trim).filter(|s| !s.is_empty()) {
99 return first.to_string();
100 }
101 }
102 } else if is_private_or_loopback(addr.ip())
103 && !PROXY_WARNING_LOGGED.load(std::sync::atomic::Ordering::Relaxed)
104 && !PROXY_WARNING_LOGGED.swap(true, std::sync::atomic::Ordering::Relaxed)
105 {
106 warn!(
107 peer_ip = %addr.ip(),
108 "Rate limiter: peer address is loopback/RFC-1918 — server appears to be \
109 behind a reverse proxy. All requests will share a single rate-limit bucket \
110 unless you set `trust_proxy_headers = true` in [security.rate_limiting]."
111 );
112 }
113 addr.ip().to_string()
114}
115
116pub(super) fn extract_jwt_subject(authorization: &str) -> Option<String> {
124 use base64::Engine as _;
125 let token = authorization.strip_prefix("Bearer ")?;
126 let payload_b64 = token.split('.').nth(1)?;
127 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64).ok()?;
128 let json: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
129 json.get("sub").and_then(|v| v.as_str()).map(String::from)
130}
131
132#[allow(clippy::cognitive_complexity)] pub async fn rate_limit_middleware(
146 ConnectInfo(addr): ConnectInfo<SocketAddr>,
147 req: Request<Body>,
148 next: Next,
149) -> Result<Response, RateLimitExceeded> {
150 let limiter = req
152 .extensions()
153 .get::<Arc<RateLimiter>>()
154 .cloned()
155 .unwrap_or_else(|| Arc::new(RateLimiter::new(RateLimitConfig::default())));
156
157 let ip = extract_real_ip(
158 &req,
159 limiter.config().trust_proxy_headers,
160 &limiter.config().trusted_proxy_cidrs,
161 &addr,
162 );
163 let path = req.uri().path().to_string();
164
165 let user_id = req
167 .headers()
168 .get(axum::http::header::AUTHORIZATION)
169 .and_then(|v| v.to_str().ok())
170 .and_then(extract_jwt_subject);
171
172 let path_result = limiter.check_path_limit(&path, &ip).await;
174 if !path_result.allowed {
175 warn!(ip = %ip, path = %path, "Per-path rate limit exceeded");
176 return Err(RateLimitExceeded {
177 retry_after_secs: path_result.retry_after_secs,
178 });
179 }
180
181 let limit_result = if let Some(ref uid) = user_id {
183 limiter.check_user_limit(uid).await
185 } else {
186 limiter.check_ip_limit(&ip).await
188 };
189
190 if !limit_result.allowed {
191 if let Some(ref uid) = user_id {
192 warn!(user_id = %uid, "Per-user rate limit exceeded");
193 } else {
194 warn!(ip = %ip, "IP rate limit exceeded");
195 }
196 return Err(RateLimitExceeded {
197 retry_after_secs: limit_result.retry_after_secs,
198 });
199 }
200
201 let remaining = limit_result.remaining;
202
203 let response = next.run(req).await;
204
205 let mut response = response;
207 let limit = if user_id.is_some() {
208 limiter.config().rps_per_user
209 } else {
210 limiter.config().rps_per_ip
211 };
212 if let Ok(limit_value) = format!("{limit}").parse() {
213 response.headers_mut().insert("X-RateLimit-Limit", limit_value);
214 }
215 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
216 if let Ok(remaining_value) = format!("{}", remaining as u32).parse() {
218 response.headers_mut().insert("X-RateLimit-Remaining", remaining_value);
219 }
220
221 Ok(response)
222}
223
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // unwraps only touch literal fixtures that cannot fail

    use std::net::SocketAddr;

    use axum::{body::Body, http::Request};

    use super::extract_real_ip;

    /// Peer address with the given IPv4 octets and an arbitrary client port.
    fn peer(octets: [u8; 4]) -> SocketAddr {
        SocketAddr::from((octets, 12345))
    }

    /// GraphQL request whose only header is `X-Forwarded-For: <xff>`.
    fn forwarded_request(xff: &str) -> Request<Body> {
        let builder = Request::builder()
            .uri("http://example.com/graphql")
            .header("x-forwarded-for", xff);
        builder.body(Body::empty()).unwrap()
    }

    /// Parses CIDR literals into the list form `extract_real_ip` expects.
    fn trusted(cidrs: &[&str]) -> Vec<ipnet::IpNet> {
        cidrs.iter().map(|cidr| cidr.parse().unwrap()).collect()
    }

    #[test]
    fn test_spoofed_xforwardedfor_ignored_when_direct_ip_not_in_trusted_cidrs() {
        // Direct peer is a public address outside the trusted 10/8 proxy range.
        let resolved = extract_real_ip(
            &forwarded_request("1.2.3.4"),
            true,
            &trusted(&["10.0.0.0/8"]),
            &peer([203, 0, 113, 1]),
        );
        assert_eq!(resolved, "203.0.113.1", "Should use direct IP, not spoofed X-Forwarded-For");
    }

    #[test]
    fn test_forwarded_ip_used_when_direct_ip_is_trusted_proxy() {
        // Direct peer sits inside the trusted 10/8 proxy range.
        let resolved = extract_real_ip(
            &forwarded_request("5.6.7.8"),
            true,
            &trusted(&["10.0.0.0/8"]),
            &peer([10, 0, 1, 5]),
        );
        assert_eq!(resolved, "5.6.7.8", "Should use X-Forwarded-For from trusted proxy");
    }

    #[test]
    fn test_no_cidrs_trusts_all_proxies() {
        // No CIDRs configured: any peer's forwarding headers are honoured.
        let resolved = extract_real_ip(
            &forwarded_request("9.9.9.9"),
            true,
            &trusted(&[]),
            &peer([203, 0, 113, 1]),
        );
        assert_eq!(resolved, "9.9.9.9", "Empty CIDRs: all proxies trusted");
    }
}