Skip to main content

structured_proxy/shield/
mod.rs

//! Shield: request rate limiting.
//!
//! Enforces the `shield` config as an axum middleware. Two rule kinds are
//! supported: endpoint *classes* (glob path → per-client limit) and
//! per-*identifier* endpoints (limit keyed by a field read from the JSON
//! request body, falling back to the client address when the field is
//! absent). The counter backend is pluggable via [`RateLimitStore`]; the
//! default is an in-process store, with an optional Redis store for
//! multi-instance setups.

9pub mod matcher;
10pub mod rate;
11pub mod store;
12
13use std::sync::Arc;
14use std::time::Duration;
15
16use axum::body::Body;
17use axum::extract::State;
18use axum::http::{HeaderMap, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use axum::Json;
22
23use crate::config::ShieldConfig;
24use matcher::{EndpointClass, IdentifierEndpoint};
25use store::{Decision, MemoryStore, RateLimitStore};
26
/// Upper bound, in bytes, on how much request body is buffered when an
/// identifier rule has to read a field from it (256 KiB). A body that cannot
/// be buffered within this bound is answered with `413 Payload Too Large`.
const MAX_IDENTIFIER_BODY: usize = 256 << 10;
29
30/// Compiled Shield rules plus the counter store.
31pub struct Shield {
32    store: Arc<dyn RateLimitStore>,
33    classes: Vec<EndpointClass>,
34    identifiers: Vec<IdentifierEndpoint>,
35    /// CIDR ranges whose `X-Forwarded-For` / `X-Real-IP` headers we trust.
36    trusted_proxies: Vec<ipnet::IpNet>,
37}
38
39impl Shield {
40    /// Build a Shield from config, or `None` when disabled / has no rules.
41    ///
42    /// Uses a Redis store when `redis_url` is set and the `redis` feature is
43    /// compiled in; otherwise an in-process [`MemoryStore`].
44    ///
45    /// # Errors
46    /// Returns an error string when a glob pattern or rate fails to compile, or
47    /// when the configured Redis backend cannot be reached.
48    pub fn build(config: &ShieldConfig) -> Result<Option<Arc<Self>>, String> {
49        if !config.enabled {
50            return Ok(None);
51        }
52        if config.endpoint_classes.is_empty() && config.identifier_endpoints.is_empty() {
53            tracing::warn!("shield enabled but no endpoint_classes or identifier_endpoints set");
54            return Ok(None);
55        }
56
57        let default_window = Duration::from_secs(config.window_secs.max(1));
58        let classes = matcher::compile_endpoint_classes(&config.endpoint_classes, default_window)?;
59        let identifiers =
60            matcher::compile_identifier_endpoints(&config.identifier_endpoints, default_window)?;
61
62        let trusted_proxies = config
63            .trusted_proxies
64            .iter()
65            .map(|s| parse_cidr(s))
66            .collect::<Result<Vec<_>, _>>()?;
67
68        let store = build_store(config)?;
69        Ok(Some(Arc::new(Self {
70            store,
71            classes,
72            identifiers,
73            trusted_proxies,
74        })))
75    }
76
77    fn match_class(&self, path: &str) -> Option<&EndpointClass> {
78        self.classes.iter().find(|c| c.matcher.is_match(path))
79    }
80
81    fn match_identifier(&self, path: &str) -> Option<&IdentifierEndpoint> {
82        self.identifiers.iter().find(|i| i.matcher.is_match(path))
83    }
84}
85
86/// Select the counter store from config.
87fn build_store(config: &ShieldConfig) -> Result<Arc<dyn RateLimitStore>, String> {
88    match &config.redis_url {
89        Some(url) => open_redis(url),
90        None => Ok(Arc::new(MemoryStore::new())),
91    }
92}
93
#[cfg(feature = "redis")]
/// Open the Redis-backed counter store at `url`.
///
/// # Errors
/// Any failure from `RedisStore::open` is reported as an invalid-URL message.
fn open_redis(url: &str) -> Result<Arc<dyn RateLimitStore>, String> {
    match store::RedisStore::open(url) {
        Ok(redis) => Ok(Arc::new(redis)),
        Err(e) => Err(format!("invalid Redis URL for rate-limit store: {e}")),
    }
}
100
101#[cfg(not(feature = "redis"))]
102fn open_redis(_url: &str) -> Result<Arc<dyn RateLimitStore>, String> {
103    tracing::warn!(
104        "shield.redis_url is set but the `redis` feature is not compiled in; \
105         falling back to the in-process store (per-replica limits only)"
106    );
107    Ok(Arc::new(MemoryStore::new()))
108}
109
110/// Axum middleware enforcing the compiled Shield rules.
111pub async fn middleware(
112    State(shield): State<Arc<Shield>>,
113    request: Request,
114    next: Next,
115) -> Response {
116    let path = request.uri().path().to_string();
117    let peer = request
118        .extensions()
119        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
120        .map(|ci| ci.0.ip());
121    let client = client_ip(peer, request.headers(), &shield.trusted_proxies);
122
123    // The decision whose budget we surface to the client via headers.
124    let mut report = None;
125
126    // Endpoint-class limit (per client IP), no body needed.
127    if let Some(class) = shield.match_class(&path) {
128        let key = format!("class:{}:{client}", class.class);
129        let decision = shield.store.hit(&key, &class.rate).await;
130        if !decision.allowed {
131            return too_many_requests(decision);
132        }
133        report = Some(decision);
134    }
135
136    // Per-identifier limit: buffer the body, read the field, then restore it.
137    let request = if let Some(id_ep) = shield.match_identifier(&path) {
138        let (parts, body) = request.into_parts();
139        let bytes = match axum::body::to_bytes(body, MAX_IDENTIFIER_BODY).await {
140            Ok(b) => b,
141            Err(_) => return payload_too_large(),
142        };
143        // Key by the identifier value, or fall back to the client IP when the
144        // field is absent so the limit can't be bypassed by omitting it.
145        let key = match extract_body_field(&bytes, &id_ep.body_field) {
146            Some(ident) => format!("id:{path}:{}:{ident}", id_ep.body_field),
147            None => format!("id:{path}:{}:ip:{client}", id_ep.body_field),
148        };
149        let decision = shield.store.hit(&key, &id_ep.rate).await;
150        if !decision.allowed {
151            return too_many_requests(decision);
152        }
153        // The identifier rule is more specific than a class rule, so report it.
154        report = Some(decision);
155        Request::from_parts(parts, Body::from(bytes))
156    } else {
157        request
158    };
159
160    let mut response = next.run(request).await;
161    if let Some(decision) = report {
162        attach_rate_headers(response.headers_mut(), &decision);
163    }
164    response
165}
166
167/// Convenience alias for the axum request type used by the middleware.
168type Request = axum::extract::Request;
169
170/// Parse a trusted-proxy entry as a CIDR range, accepting a bare IP as a /32
171/// or /128 host range.
172fn parse_cidr(s: &str) -> Result<ipnet::IpNet, String> {
173    if let Ok(net) = s.parse::<ipnet::IpNet>() {
174        return Ok(net);
175    }
176    if let Ok(ip) = s.parse::<std::net::IpAddr>() {
177        let prefix = if ip.is_ipv4() { 32 } else { 128 };
178        return ipnet::IpNet::new(ip, prefix)
179            .map_err(|e| format!("invalid trusted_proxies entry {s:?}: {e}"));
180    }
181    Err(format!("invalid trusted_proxies CIDR/IP: {s:?}"))
182}
183
184/// Resolve the client identity for keying.
185///
186/// `X-Forwarded-For` is trusted only when the direct `peer` is a configured
187/// trusted proxy, and even then the *rightmost* hop outside the trusted ranges
188/// is used: appending load balancers (nginx, ALB, GCP) add the connecting IP on
189/// the right, so the leftmost entries are attacker-controlled. Without
190/// connection info (e.g. a custom server that does not provide it) the headers
191/// are taken as a best effort.
192fn client_ip(
193    peer: Option<std::net::IpAddr>,
194    headers: &HeaderMap,
195    trusted: &[ipnet::IpNet],
196) -> String {
197    match peer {
198        Some(ip) => {
199            if trusted.iter().any(|net| net.contains(&ip)) {
200                if let Some(client) = rightmost_untrusted(headers, trusted) {
201                    return client;
202                }
203            }
204            ip.to_string()
205        }
206        // No connection info: fall back to the headers as a best effort.
207        None => best_effort_forwarded(headers).unwrap_or_else(|| "unknown".to_string()),
208    }
209}
210
211/// Rightmost `X-Forwarded-For` hop that is not within a trusted range, i.e. the
212/// last address appended by an untrusted party. Falls back to `X-Real-IP`.
213fn rightmost_untrusted(headers: &HeaderMap, trusted: &[ipnet::IpNet]) -> Option<String> {
214    if let Some(xff) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok()) {
215        for hop in xff.split(',').rev() {
216            let hop = hop.trim();
217            if hop.is_empty() {
218                continue;
219            }
220            let trusted_hop = hop
221                .parse::<std::net::IpAddr>()
222                .is_ok_and(|ip| trusted.iter().any(|net| net.contains(&ip)));
223            if !trusted_hop {
224                return Some(hop.to_string());
225            }
226        }
227    }
228    // X-Real-IP is set by the proxy to the single real client address.
229    header_str(headers, "x-real-ip")
230}
231
232/// Best-effort client from forwarding headers when no peer is known: leftmost
233/// `X-Forwarded-For`, then `X-Real-IP`.
234fn best_effort_forwarded(headers: &HeaderMap) -> Option<String> {
235    headers
236        .get("x-forwarded-for")
237        .and_then(|v| v.to_str().ok())
238        .and_then(|v| v.split(',').next())
239        .map(str::trim)
240        .filter(|s| !s.is_empty())
241        .map(str::to_string)
242        .or_else(|| header_str(headers, "x-real-ip"))
243}
244
245/// Trimmed, non-empty value of a header.
246fn header_str(headers: &HeaderMap, name: &str) -> Option<String> {
247    headers
248        .get(name)
249        .and_then(|v| v.to_str().ok())
250        .map(str::trim)
251        .filter(|s| !s.is_empty())
252        .map(str::to_string)
253}
254
255/// Read a (possibly dotted) field from a JSON body as a string identifier.
256fn extract_body_field(bytes: &[u8], field: &str) -> Option<String> {
257    let value: serde_json::Value = serde_json::from_slice(bytes).ok()?;
258    let mut cur = &value;
259    for seg in field.split('.') {
260        cur = cur.get(seg)?;
261    }
262    match cur {
263        serde_json::Value::String(s) => Some(s.clone()),
264        serde_json::Value::Number(n) => Some(n.to_string()),
265        serde_json::Value::Bool(b) => Some(b.to_string()),
266        _ => None,
267    }
268}
269
270/// Add `X-RateLimit-*` headers describing the remaining budget.
271fn attach_rate_headers(headers: &mut HeaderMap, decision: &Decision) {
272    if let Ok(v) = decision.limit.to_string().parse() {
273        headers.insert("x-ratelimit-limit", v);
274    }
275    if let Ok(v) = decision.remaining.to_string().parse() {
276        headers.insert("x-ratelimit-remaining", v);
277    }
278}
279
280fn too_many_requests(decision: Decision) -> Response {
281    let retry = decision.retry_after.unwrap_or(Duration::ZERO).as_secs();
282    let mut response = (
283        StatusCode::TOO_MANY_REQUESTS,
284        Json(serde_json::json!({
285            "error": "RESOURCE_EXHAUSTED",
286            "message": "rate limit exceeded",
287        })),
288    )
289        .into_response();
290    let headers = response.headers_mut();
291    attach_rate_headers(headers, &decision);
292    if let Ok(v) = retry.to_string().parse() {
293        headers.insert("retry-after", v);
294    }
295    response
296}
297
298fn payload_too_large() -> Response {
299    (
300        StatusCode::PAYLOAD_TOO_LARGE,
301        Json(serde_json::json!({
302            "error": "INVALID_ARGUMENT",
303            "message": "request body too large for rate-limit identifier extraction",
304        })),
305    )
306        .into_response()
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    use crate::config::{EndpointClassConfig, IdentifierEndpointConfig, ShieldConfig};
    use axum::http::Request as HttpRequest;
    use std::net::IpAddr;
    use tower::ServiceExt;

    /// Parse an address literal used by the tests.
    fn addr(s: &str) -> IpAddr {
        s.parse().unwrap()
    }

    /// Header map holding a single `name: value` pair.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(name, value.parse().unwrap());
        map
    }

    #[test]
    fn client_ip_trusts_headers_only_from_trusted_peer() {
        let headers = headers_with("x-forwarded-for", "203.0.113.7, 10.0.0.1");
        let trusted = [parse_cidr("10.0.0.0/8").unwrap()];

        // Via the trusted LB, the hop left of the LB's own address is the client.
        assert_eq!(
            client_ip(Some(addr("10.0.0.1")), &headers, &trusted),
            "203.0.113.7"
        );
        // An untrusted direct client cannot spoof the header: key by its peer.
        assert_eq!(
            client_ip(Some(addr("198.51.100.9")), &headers, &trusted),
            "198.51.100.9"
        );
    }

    #[test]
    fn client_ip_uses_rightmost_untrusted_hop() {
        // Appending LBs (nginx proxy_add_x_forwarded_for, AWS ALB, GCP LB) add
        // the connecting IP on the RIGHT, so the leftmost entry is attacker
        // controlled. The real client is the rightmost hop outside trusted ranges.
        let headers = headers_with("x-forwarded-for", "1.1.1.1, 203.0.113.7");
        let trusted = [parse_cidr("10.0.0.0/8").unwrap()];
        assert_eq!(
            client_ip(Some(addr("10.0.0.5")), &headers, &trusted),
            "203.0.113.7"
        );
    }

    #[test]
    fn client_ip_without_peer_info_uses_headers_then_unknown() {
        let headers = headers_with("x-real-ip", "198.51.100.2");
        assert_eq!(client_ip(None, &headers, &[]), "198.51.100.2");
        assert_eq!(client_ip(None, &HeaderMap::new(), &[]), "unknown");
    }

    #[test]
    fn extract_body_field_reads_string_and_dotted() {
        let body: &[u8] = br#"{"email":"a@b.com","nested":{"id":42}}"#;
        let cases: [(&[u8], &str, Option<&str>); 4] = [
            (body, "email", Some("a@b.com")),
            (body, "nested.id", Some("42")),
            (body, "missing", None),
            (b"not json", "email", None),
        ];
        for (input, field, want) in cases {
            assert_eq!(
                extract_body_field(input, field).as_deref(),
                want,
                "field {field:?}"
            );
        }
    }

    // --- middleware enforcement (real router) ---

    fn shield_config(
        classes: Vec<EndpointClassConfig>,
        ids: Vec<IdentifierEndpointConfig>,
    ) -> ShieldConfig {
        ShieldConfig {
            enabled: true,
            endpoint_classes: classes,
            identifier_endpoints: ids,
            window_secs: 60,
            redis_url: None,
            trusted_proxies: Vec::new(),
        }
    }

    /// Shield with a single endpoint class `t` covering `/limited`.
    fn class_shield(rate: &'static str) -> Arc<Shield> {
        let class = EndpointClassConfig {
            pattern: "/limited".into(),
            class: "t".into(),
            rate: rate.into(),
        };
        Shield::build(&shield_config(vec![class], vec![]))
            .unwrap()
            .expect("shield should be enabled")
    }

    /// Shield with a single identifier rule on `/login`, keyed by `email`.
    fn login_shield(rate: &'static str) -> Arc<Shield> {
        let rule = IdentifierEndpointConfig {
            path: "/login".into(),
            body_field: "email".into(),
            rate: rate.into(),
        };
        Shield::build(&shield_config(vec![], vec![rule]))
            .unwrap()
            .expect("shield should be enabled")
    }

    fn app(shield: Arc<Shield>) -> axum::Router {
        let ok = || async { "ok" };
        axum::Router::new()
            .route("/limited", axum::routing::get(ok))
            .route("/login", axum::routing::post(ok))
            .layer(axum::middleware::from_fn_with_state(shield, middleware))
    }

    /// JSON `POST /login` with the given body.
    fn json_login(body: impl Into<Body>) -> HttpRequest<Body> {
        HttpRequest::post("/login")
            .header("content-type", "application/json")
            .body(body.into())
            .unwrap()
    }

    /// Send `req` through a clone of `app` and return the response status.
    async fn status_of(app: &axum::Router, req: HttpRequest<Body>) -> StatusCode {
        app.clone().oneshot(req).await.unwrap().status()
    }

    #[tokio::test]
    async fn middleware_blocks_after_endpoint_class_limit() {
        let app = app(class_shield("2/min"));
        let get = || HttpRequest::get("/limited").body(Body::empty()).unwrap();

        // No client IP header: every request shares the "unknown" bucket.
        assert_eq!(status_of(&app, get()).await, 200);
        let second = app.clone().oneshot(get()).await.unwrap();
        assert_eq!(second.status(), 200);
        assert_eq!(second.headers()["x-ratelimit-remaining"], "0");

        // Third request is over the limit: 429 with Retry-After.
        let third = app.clone().oneshot(get()).await.unwrap();
        assert_eq!(third.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(third.headers().contains_key("retry-after"));
    }

    #[tokio::test]
    async fn middleware_limits_per_identifier_value() {
        let app = app(login_shield("1/min"));
        let login_as = |email: &str| json_login(format!(r#"{{"email":"{email}"}}"#));

        // Alice's first attempt passes; her second exceeds the 1/min budget.
        assert_eq!(status_of(&app, login_as("alice")).await, 200);
        assert_eq!(
            status_of(&app, login_as("alice")).await,
            StatusCode::TOO_MANY_REQUESTS
        );
        // Bob is counted in his own bucket.
        assert_eq!(status_of(&app, login_as("bob")).await, 200);
    }

    #[tokio::test]
    async fn identifier_response_carries_ratelimit_headers() {
        let resp = app(login_shield("5/min"))
            .oneshot(json_login(r#"{"email":"a"}"#))
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        // No class rule is configured, so these come from the identifier decision.
        assert_eq!(resp.headers()["x-ratelimit-limit"], "5");
        assert_eq!(resp.headers()["x-ratelimit-remaining"], "4");
    }

    #[tokio::test]
    async fn identifier_limit_not_bypassed_when_field_absent() {
        // Omitting the identifier field must not skip the limit: requests with
        // no extractable identifier fall back to a client-keyed counter.
        let app = app(login_shield("1/min"));
        assert_eq!(status_of(&app, json_login("{}")).await, 200);
        assert_eq!(
            status_of(&app, json_login("{}")).await,
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    #[tokio::test]
    async fn middleware_ignores_unmatched_paths() {
        let app = app(class_shield("1/min"));
        // Only /limited is covered by a rule, so /login is never limited.
        for _ in 0..5 {
            let req = HttpRequest::post("/login").body(Body::empty()).unwrap();
            assert_eq!(status_of(&app, req).await, 200);
        }
    }
}