1pub mod matcher;
10pub mod rate;
11pub mod store;
12
13use std::sync::Arc;
14use std::time::Duration;
15
16use axum::body::Body;
17use axum::extract::State;
18use axum::http::{HeaderMap, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use axum::Json;
22
23use crate::config::ShieldConfig;
24use matcher::{EndpointClass, IdentifierEndpoint};
25use store::{Decision, MemoryStore, RateLimitStore};
26
/// Upper bound (256 KiB) on the bytes buffered when a request body is read to
/// extract a rate-limit identifier; larger bodies are answered with `413`.
const MAX_IDENTIFIER_BODY: usize = 262_144;
29
/// Compiled rate-limiting configuration, built by [`Shield::build`] and
/// consulted by [`middleware`] on every request.
pub struct Shield {
    /// Backend that records hits and returns allow/deny [`Decision`]s.
    store: Arc<dyn RateLimitStore>,
    /// Path-pattern limits keyed by endpoint class and client address; the
    /// first matching entry applies.
    classes: Vec<EndpointClass>,
    /// Path-pattern limits keyed by a JSON body field (falling back to the
    /// client address); the first matching entry applies.
    identifiers: Vec<IdentifierEndpoint>,
    /// Networks whose forwarding headers are honored when resolving the
    /// client address.
    trusted_proxies: Vec<ipnet::IpNet>,
}
38
39impl Shield {
40 pub fn build(config: &ShieldConfig) -> Result<Option<Arc<Self>>, String> {
49 if !config.enabled {
50 return Ok(None);
51 }
52 if config.endpoint_classes.is_empty() && config.identifier_endpoints.is_empty() {
53 tracing::warn!("shield enabled but no endpoint_classes or identifier_endpoints set");
54 return Ok(None);
55 }
56
57 let default_window = Duration::from_secs(config.window_secs.max(1));
58 let classes = matcher::compile_endpoint_classes(&config.endpoint_classes, default_window)?;
59 let identifiers =
60 matcher::compile_identifier_endpoints(&config.identifier_endpoints, default_window)?;
61
62 let trusted_proxies = config
63 .trusted_proxies
64 .iter()
65 .map(|s| parse_cidr(s))
66 .collect::<Result<Vec<_>, _>>()?;
67
68 let store = build_store(config)?;
69 Ok(Some(Arc::new(Self {
70 store,
71 classes,
72 identifiers,
73 trusted_proxies,
74 })))
75 }
76
77 fn match_class(&self, path: &str) -> Option<&EndpointClass> {
78 self.classes.iter().find(|c| c.matcher.is_match(path))
79 }
80
81 fn match_identifier(&self, path: &str) -> Option<&IdentifierEndpoint> {
82 self.identifiers.iter().find(|i| i.matcher.is_match(path))
83 }
84}
85
86fn build_store(config: &ShieldConfig) -> Result<Arc<dyn RateLimitStore>, String> {
88 match &config.redis_url {
89 Some(url) => open_redis(url),
90 None => Ok(Arc::new(MemoryStore::new())),
91 }
92}
93
/// Opens the Redis-backed store at `url`.
///
/// # Errors
///
/// Returns a descriptive message when `store::RedisStore::open` rejects the URL.
#[cfg(feature = "redis")]
fn open_redis(url: &str) -> Result<Arc<dyn RateLimitStore>, String> {
    match store::RedisStore::open(url) {
        Ok(redis) => Ok(Arc::new(redis)),
        Err(e) => Err(format!("invalid Redis URL for rate-limit store: {e}")),
    }
}
100
101#[cfg(not(feature = "redis"))]
102fn open_redis(_url: &str) -> Result<Arc<dyn RateLimitStore>, String> {
103 tracing::warn!(
104 "shield.redis_url is set but the `redis` feature is not compiled in; \
105 falling back to the in-process store (per-replica limits only)"
106 );
107 Ok(Arc::new(MemoryStore::new()))
108}
109
110pub async fn middleware(
112 State(shield): State<Arc<Shield>>,
113 request: Request,
114 next: Next,
115) -> Response {
116 let path = request.uri().path().to_string();
117 let peer = request
118 .extensions()
119 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
120 .map(|ci| ci.0.ip());
121 let client = client_ip(peer, request.headers(), &shield.trusted_proxies);
122
123 let mut report = None;
125
126 if let Some(class) = shield.match_class(&path) {
128 let key = format!("class:{}:{client}", class.class);
129 let decision = shield.store.hit(&key, &class.rate).await;
130 if !decision.allowed {
131 return too_many_requests(decision);
132 }
133 report = Some(decision);
134 }
135
136 let request = if let Some(id_ep) = shield.match_identifier(&path) {
138 let (parts, body) = request.into_parts();
139 let bytes = match axum::body::to_bytes(body, MAX_IDENTIFIER_BODY).await {
140 Ok(b) => b,
141 Err(_) => return payload_too_large(),
142 };
143 let key = match extract_body_field(&bytes, &id_ep.body_field) {
146 Some(ident) => format!("id:{path}:{}:{ident}", id_ep.body_field),
147 None => format!("id:{path}:{}:ip:{client}", id_ep.body_field),
148 };
149 let decision = shield.store.hit(&key, &id_ep.rate).await;
150 if !decision.allowed {
151 return too_many_requests(decision);
152 }
153 report = Some(decision);
155 Request::from_parts(parts, Body::from(bytes))
156 } else {
157 request
158 };
159
160 let mut response = next.run(request).await;
161 if let Some(decision) = report {
162 attach_rate_headers(response.headers_mut(), &decision);
163 }
164 response
165}
166
167type Request = axum::extract::Request;
169
170fn parse_cidr(s: &str) -> Result<ipnet::IpNet, String> {
173 if let Ok(net) = s.parse::<ipnet::IpNet>() {
174 return Ok(net);
175 }
176 if let Ok(ip) = s.parse::<std::net::IpAddr>() {
177 let prefix = if ip.is_ipv4() { 32 } else { 128 };
178 return ipnet::IpNet::new(ip, prefix)
179 .map_err(|e| format!("invalid trusted_proxies entry {s:?}: {e}"));
180 }
181 Err(format!("invalid trusted_proxies CIDR/IP: {s:?}"))
182}
183
184fn client_ip(
193 peer: Option<std::net::IpAddr>,
194 headers: &HeaderMap,
195 trusted: &[ipnet::IpNet],
196) -> String {
197 match peer {
198 Some(ip) => {
199 if trusted.iter().any(|net| net.contains(&ip)) {
200 if let Some(client) = rightmost_untrusted(headers, trusted) {
201 return client;
202 }
203 }
204 ip.to_string()
205 }
206 None => best_effort_forwarded(headers).unwrap_or_else(|| "unknown".to_string()),
208 }
209}
210
211fn rightmost_untrusted(headers: &HeaderMap, trusted: &[ipnet::IpNet]) -> Option<String> {
214 if let Some(xff) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok()) {
215 for hop in xff.split(',').rev() {
216 let hop = hop.trim();
217 if hop.is_empty() {
218 continue;
219 }
220 let trusted_hop = hop
221 .parse::<std::net::IpAddr>()
222 .is_ok_and(|ip| trusted.iter().any(|net| net.contains(&ip)));
223 if !trusted_hop {
224 return Some(hop.to_string());
225 }
226 }
227 }
228 header_str(headers, "x-real-ip")
230}
231
232fn best_effort_forwarded(headers: &HeaderMap) -> Option<String> {
235 headers
236 .get("x-forwarded-for")
237 .and_then(|v| v.to_str().ok())
238 .and_then(|v| v.split(',').next())
239 .map(str::trim)
240 .filter(|s| !s.is_empty())
241 .map(str::to_string)
242 .or_else(|| header_str(headers, "x-real-ip"))
243}
244
245fn header_str(headers: &HeaderMap, name: &str) -> Option<String> {
247 headers
248 .get(name)
249 .and_then(|v| v.to_str().ok())
250 .map(str::trim)
251 .filter(|s| !s.is_empty())
252 .map(str::to_string)
253}
254
255fn extract_body_field(bytes: &[u8], field: &str) -> Option<String> {
257 let value: serde_json::Value = serde_json::from_slice(bytes).ok()?;
258 let mut cur = &value;
259 for seg in field.split('.') {
260 cur = cur.get(seg)?;
261 }
262 match cur {
263 serde_json::Value::String(s) => Some(s.clone()),
264 serde_json::Value::Number(n) => Some(n.to_string()),
265 serde_json::Value::Bool(b) => Some(b.to_string()),
266 _ => None,
267 }
268}
269
270fn attach_rate_headers(headers: &mut HeaderMap, decision: &Decision) {
272 if let Ok(v) = decision.limit.to_string().parse() {
273 headers.insert("x-ratelimit-limit", v);
274 }
275 if let Ok(v) = decision.remaining.to_string().parse() {
276 headers.insert("x-ratelimit-remaining", v);
277 }
278}
279
280fn too_many_requests(decision: Decision) -> Response {
281 let retry = decision.retry_after.unwrap_or(Duration::ZERO).as_secs();
282 let mut response = (
283 StatusCode::TOO_MANY_REQUESTS,
284 Json(serde_json::json!({
285 "error": "RESOURCE_EXHAUSTED",
286 "message": "rate limit exceeded",
287 })),
288 )
289 .into_response();
290 let headers = response.headers_mut();
291 attach_rate_headers(headers, &decision);
292 if let Ok(v) = retry.to_string().parse() {
293 headers.insert("retry-after", v);
294 }
295 response
296}
297
298fn payload_too_large() -> Response {
299 (
300 StatusCode::PAYLOAD_TOO_LARGE,
301 Json(serde_json::json!({
302 "error": "INVALID_ARGUMENT",
303 "message": "request body too large for rate-limit identifier extraction",
304 })),
305 )
306 .into_response()
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    // Forwarding headers are honored only when the socket peer is a trusted proxy.
    #[test]
    fn client_ip_trusts_headers_only_from_trusted_peer() {
        let mut h = HeaderMap::new();
        h.insert("x-forwarded-for", "203.0.113.7, 10.0.0.1".parse().unwrap());
        let trusted = vec![parse_cidr("10.0.0.0/8").unwrap()];
        let lb: std::net::IpAddr = "10.0.0.1".parse().unwrap();
        let direct: std::net::IpAddr = "198.51.100.9".parse().unwrap();

        // Peer is the trusted load balancer: the trusted hop is skipped.
        assert_eq!(client_ip(Some(lb), &h, &trusted), "203.0.113.7");
        // Peer is untrusted: the header is ignored and the socket address wins.
        assert_eq!(client_ip(Some(direct), &h, &trusted), "198.51.100.9");
    }

    #[test]
    fn client_ip_uses_rightmost_untrusted_hop() {
        let mut h = HeaderMap::new();
        // The leftmost entry may be client-forged; the rightmost untrusted hop
        // is the one the trusted proxy observed.
        h.insert("x-forwarded-for", "1.1.1.1, 203.0.113.7".parse().unwrap());
        let trusted = vec![parse_cidr("10.0.0.0/8").unwrap()];
        let lb: std::net::IpAddr = "10.0.0.5".parse().unwrap();
        assert_eq!(client_ip(Some(lb), &h, &trusted), "203.0.113.7");
    }

    #[test]
    fn client_ip_without_peer_info_uses_headers_then_unknown() {
        let mut h = HeaderMap::new();
        h.insert("x-real-ip", "198.51.100.2".parse().unwrap());
        assert_eq!(client_ip(None, &h, &[]), "198.51.100.2");
        // No peer and no headers: every such request shares the "unknown" key.
        assert_eq!(client_ip(None, &HeaderMap::new(), &[]), "unknown");
    }

    #[test]
    fn extract_body_field_reads_string_and_dotted() {
        let body = br#"{"email":"a@b.com","nested":{"id":42}}"#;
        assert_eq!(
            extract_body_field(body, "email"),
            Some("a@b.com".to_string())
        );
        // Dotted paths descend into objects; numbers come back as their text.
        assert_eq!(
            extract_body_field(body, "nested.id"),
            Some("42".to_string())
        );
        assert_eq!(extract_body_field(body, "missing"), None);
        assert_eq!(extract_body_field(b"not json", "email"), None);
    }

    use crate::config::{EndpointClassConfig, IdentifierEndpointConfig, ShieldConfig};
    use axum::http::Request as HttpRequest;
    use tower::ServiceExt;

    // Enabled shield with the given rules, a 60s default window, the
    // in-process store, and no trusted proxies.
    fn shield_config(
        classes: Vec<EndpointClassConfig>,
        ids: Vec<IdentifierEndpointConfig>,
    ) -> ShieldConfig {
        ShieldConfig {
            enabled: true,
            endpoint_classes: classes,
            identifier_endpoints: ids,
            window_secs: 60,
            redis_url: None,
            trusted_proxies: Vec::new(),
        }
    }

    // Router with two trivial handlers behind the shield middleware. Requests
    // sent through `oneshot` carry no `ConnectInfo`, so they all resolve to the
    // same client key.
    fn app(shield: Arc<Shield>) -> axum::Router {
        axum::Router::new()
            .route("/limited", axum::routing::get(|| async { "ok" }))
            .route("/login", axum::routing::post(|| async { "ok" }))
            .layer(axum::middleware::from_fn_with_state(shield, middleware))
    }

    #[tokio::test]
    async fn middleware_blocks_after_endpoint_class_limit() {
        let shield = Shield::build(&shield_config(
            vec![EndpointClassConfig {
                pattern: "/limited".into(),
                class: "t".into(),
                rate: "2/min".into(),
            }],
            vec![],
        ))
        .unwrap()
        .unwrap();
        let app = app(shield);

        let get = || HttpRequest::get("/limited").body(Body::empty()).unwrap();
        assert_eq!(app.clone().oneshot(get()).await.unwrap().status(), 200);
        // The second request uses up the quota of 2.
        let second = app.clone().oneshot(get()).await.unwrap();
        assert_eq!(second.status(), 200);
        assert_eq!(second.headers()["x-ratelimit-remaining"], "0");
        // The third request is rejected and told when to retry.
        let third = app.clone().oneshot(get()).await.unwrap();
        assert_eq!(third.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(third.headers().contains_key("retry-after"));
    }

    #[tokio::test]
    async fn middleware_limits_per_identifier_value() {
        let shield = Shield::build(&shield_config(
            vec![],
            vec![IdentifierEndpointConfig {
                path: "/login".into(),
                body_field: "email".into(),
                rate: "1/min".into(),
            }],
        ))
        .unwrap()
        .unwrap();
        let app = app(shield);

        let post = |email: &str| {
            HttpRequest::post("/login")
                .header("content-type", "application/json")
                .body(Body::from(format!(r#"{{"email":"{email}"}}"#)))
                .unwrap()
        };

        assert_eq!(
            app.clone().oneshot(post("alice")).await.unwrap().status(),
            200
        );
        assert_eq!(
            app.clone().oneshot(post("alice")).await.unwrap().status(),
            StatusCode::TOO_MANY_REQUESTS
        );
        // A different identifier has its own bucket, even from the same client.
        assert_eq!(
            app.clone().oneshot(post("bob")).await.unwrap().status(),
            200
        );
    }

    #[tokio::test]
    async fn identifier_response_carries_ratelimit_headers() {
        let shield = Shield::build(&shield_config(
            vec![],
            vec![IdentifierEndpointConfig {
                path: "/login".into(),
                body_field: "email".into(),
                rate: "5/min".into(),
            }],
        ))
        .unwrap()
        .unwrap();
        let app = app(shield);

        let resp = app
            .oneshot(
                HttpRequest::post("/login")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"email":"a"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.headers()["x-ratelimit-limit"], "5");
        // One hit has been charged against the quota of 5.
        assert_eq!(resp.headers()["x-ratelimit-remaining"], "4");
    }

    #[tokio::test]
    async fn identifier_limit_not_bypassed_when_field_absent() {
        // Omitting the identifier field must not skip the limit: such requests
        // fall back to a per-client bucket.
        let shield = Shield::build(&shield_config(
            vec![],
            vec![IdentifierEndpointConfig {
                path: "/login".into(),
                body_field: "email".into(),
                rate: "1/min".into(),
            }],
        ))
        .unwrap()
        .unwrap();
        let app = app(shield);

        let post_no_email = || {
            HttpRequest::post("/login")
                .header("content-type", "application/json")
                .body(Body::from("{}"))
                .unwrap()
        };
        assert_eq!(
            app.clone().oneshot(post_no_email()).await.unwrap().status(),
            200
        );
        assert_eq!(
            app.clone().oneshot(post_no_email()).await.unwrap().status(),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    #[tokio::test]
    async fn middleware_ignores_unmatched_paths() {
        let shield = Shield::build(&shield_config(
            vec![EndpointClassConfig {
                pattern: "/limited".into(),
                class: "t".into(),
                rate: "1/min".into(),
            }],
            vec![],
        ))
        .unwrap()
        .unwrap();
        let app = app(shield);

        // `/login` matches no configured rule, so repeated requests are never limited.
        for _ in 0..5 {
            let resp = app
                .clone()
                .oneshot(HttpRequest::post("/login").body(Body::empty()).unwrap())
                .await
                .unwrap();
            assert_eq!(resp.status(), 200);
        }
    }
}