1use std::future::Future;
13use std::net::{IpAddr, SocketAddr};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::{Duration, Instant};
18
19use arc_swap::ArcSwap;
20use axum::{
21 body::{Body, Bytes},
22 extract::{ConnectInfo, State},
23 http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode},
24};
25use governor::{clock::DefaultClock, state::keyed::DefaultKeyedStateStore, RateLimiter};
26use http_body_util::{BodyExt, Full, Limited};
27use hyper::body::{Body as HttpBody, Frame, SizeHint};
28use hyper_util::client::legacy::{connect::HttpConnector, Client};
29use tokio::net::TcpStream;
30use tracing::{debug, info, warn};
31
32use crate::auth::{AuthEngine, Challenge, Decision};
33use crate::config::{Config, HeadersCfg};
34use crate::limiter::{Admit, DistributedLimiter};
35use crate::metrics::Metrics;
36use crate::waf::{WafEngine, WafMode};
37
/// In-process (governor) rate limiter keyed by client IP address.
pub type KeyedLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
/// In-process (governor) rate limiter keyed by a string; `handle` keys it by the
/// authenticated principal.
pub type StrLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
/// Client used to forward requests upstream over a plain `HttpConnector`; request
/// bodies are fully buffered (`Full<Bytes>`) before they are sent.
pub type UpstreamClient = Client<HttpConnector, Full<Bytes>>;
42
/// State shared by every handler in this module (extracted via axum's `State`,
/// which requires `Clone`).
#[derive(Clone)]
pub struct AppState {
    /// Client used by `handle` to forward requests upstream.
    pub client: UpstreamClient,
    /// Metrics registry; rendered by `metrics_handler`.
    pub metrics: Arc<Metrics>,
    /// Current runtime snapshot. It sits behind an `ArcSwap` so it can be replaced
    /// while requests are in flight.
    pub runtime: Arc<ArcSwap<Runtime>>,
    /// Optional control-plane client; `csp_report` forwards reports to it when
    /// `control_plane.forward_csp` is enabled.
    pub cp: Option<Arc<crate::cp::CpClient>>,
}
54
/// A per-client-IP rate limiter scoped to requests whose path starts with `prefix`.
///
/// When several prefixes match, `longest_route` selects the longest one.
pub struct RouteLimiter {
    /// Prefix compared with `str::starts_with` against the request path-and-query.
    pub prefix: String,
    /// Limiter applied (keyed by client IP) to matching requests.
    pub limiter: Arc<KeyedLimiter>,
}
60
/// Snapshot of everything a proxied request needs, stored in `AppState::runtime`.
pub struct Runtime {
    /// Configuration this snapshot reflects.
    pub cfg: Arc<Config>,
    /// Upstream origin; the request's path-and-query is appended verbatim to build
    /// the upstream URI.
    pub upstream_base: Arc<String>,
    /// Authentication engine consulted for every proxied request.
    pub auth: AuthEngine,
    /// Web application firewall evaluated over the path, headers and buffered body.
    pub waf: WafEngine,
    /// Distributed limiter (presumably shared across instances — see `crate::limiter`).
    /// When present, `handle` uses it instead of the in-process limiters below.
    pub distributed: Option<DistributedLimiter>,
    /// In-process per-IP limiter for paths that match no `route_limiters` prefix.
    pub ip_limiter: Option<Arc<KeyedLimiter>>,
    /// In-process per-IP limiters scoped to path prefixes (longest prefix wins).
    pub route_limiters: Vec<RouteLimiter>,
    /// In-process limiter keyed by the authenticated principal.
    pub key_limiter: Option<Arc<StrLimiter>>,
    /// Limit, in bytes, passed to `axum::body::to_bytes` when buffering the request body.
    pub max_body: usize,
    /// Limit, in bytes, on a buffered upstream response body; 0 disables the limit.
    pub max_response_body: usize,
    /// Limit on request header size as measured by `header_bytes`; 0 disables the check.
    pub max_header_bytes: usize,
    /// Overall deadline for the upstream response head plus buffered body; `None`
    /// means no deadline.
    pub upstream_timeout: Option<Duration>,
    /// Stream `text/event-stream` responses through instead of buffering them.
    pub stream_passthrough: bool,
}
89
/// Hop-by-hop headers that a proxy must not forward (RFC 9110 §7.6.1; the classic
/// list is in RFC 2616 §13.5.1), plus the legacy, non-standard `Proxy-Connection`
/// that some clients still send and that reverse proxies conventionally strip.
///
/// Names are lowercase so they can be passed directly to `HeaderMap::remove`.
/// Headers nominated by the `Connection` header are removed separately in
/// `strip_hop_by_hop`.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
101
102pub async fn handle(
103 State(state): State<AppState>,
104 ConnectInfo(peer): ConnectInfo<SocketAddr>,
105 req: Request<Body>,
106) -> Response<Body> {
107 let started = Instant::now();
108 let rt = state.runtime.load();
111 let m = &state.metrics;
112
113 let method = req.method().clone();
114 let path = req
115 .uri()
116 .path_and_query()
117 .map(|p| p.as_str().to_string())
118 .unwrap_or_else(|| req.uri().path().to_string());
119
120 let ip = client_ip(req.headers(), peer, rt.cfg.server.trust_forwarded_for);
121
122 if req.uri().path().starts_with("/__edgeguard/") {
128 return finish(
129 m,
130 &method,
131 &path,
132 ip,
133 started,
134 "not_found",
135 text(StatusCode::NOT_FOUND, "Not Found"),
136 );
137 }
138
139 if rt.max_header_bytes > 0 && header_bytes(req.headers()) > rt.max_header_bytes {
141 return finish(
142 m,
143 &method,
144 &path,
145 ip,
146 started,
147 "header_too_large",
148 text(
149 StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
150 "Request Header Fields Too Large",
151 ),
152 );
153 }
154
155 if rt.cfg.ratelimit.enabled {
159 if let Some(d) = &rt.distributed {
160 match d.check_ip_route(ip, &path).await {
161 Admit::Allowed => {}
162 Admit::Limited(scope) => {
163 m.record_ratelimit_hit(scope);
164 return finish(
165 m,
166 &method,
167 &path,
168 ip,
169 started,
170 "rate_limited",
171 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
172 );
173 }
174 Admit::Error => {
175 return finish(
176 m,
177 &method,
178 &path,
179 ip,
180 started,
181 "limiter_error",
182 text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
183 );
184 }
185 }
186 } else {
187 let (limiter, scope) = match longest_route(&rt.route_limiters, &path) {
188 Some(r) => (Some(r.limiter.as_ref()), "route"),
189 None => (rt.ip_limiter.as_deref(), "ip"),
190 };
191 if let Some(limiter) = limiter {
192 if limiter.check_key(&ip).is_err() {
193 m.record_ratelimit_hit(scope);
194 return finish(
195 m,
196 &method,
197 &path,
198 ip,
199 started,
200 "rate_limited",
201 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
202 );
203 }
204 }
205 }
206 }
207
208 let principal = match rt.auth.authorize(&rt.cfg.auth, req.headers()).await {
210 Decision::Allow(principal) => principal,
211 Decision::Deny(challenge) => {
212 let mut resp = text(StatusCode::UNAUTHORIZED, "Unauthorized");
213 let challenge_value = match challenge {
214 Challenge::Basic(c) => Some(c),
215 Challenge::Bearer => Some("Bearer".to_string()),
216 Challenge::None => None,
217 };
218 if let Some(c) = challenge_value {
219 if let Ok(v) = HeaderValue::from_str(&c) {
220 resp.headers_mut().insert(header::WWW_AUTHENTICATE, v);
221 }
222 }
223 return finish(m, &method, &path, ip, started, "unauthorized", resp);
224 }
225 };
226
227 if let Some(principal) = &principal {
230 let key_admit = if let Some(d) = &rt.distributed {
231 Some(d.check_key(principal).await)
232 } else {
233 rt.key_limiter.as_ref().map(|limiter| {
234 if limiter.check_key(principal).is_err() {
235 Admit::Limited("key")
236 } else {
237 Admit::Allowed
238 }
239 })
240 };
241 match key_admit {
242 Some(Admit::Limited(scope)) => {
243 m.record_ratelimit_hit(scope);
244 return finish(
245 m,
246 &method,
247 &path,
248 ip,
249 started,
250 "rate_limited",
251 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
252 );
253 }
254 Some(Admit::Error) => {
255 return finish(
256 m,
257 &method,
258 &path,
259 ip,
260 started,
261 "limiter_error",
262 text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
263 );
264 }
265 Some(Admit::Allowed) | None => {}
266 }
267 }
268
269 let allow = &rt.cfg.validation.allow_methods;
271 if !allow.is_empty()
272 && !allow
273 .iter()
274 .any(|x| x.eq_ignore_ascii_case(method.as_str()))
275 {
276 return finish(
277 m,
278 &method,
279 &path,
280 ip,
281 started,
282 "method_not_allowed",
283 text(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed"),
284 );
285 }
286
287 let (parts, body) = req.into_parts();
289 let body_bytes = match axum::body::to_bytes(body, rt.max_body).await {
290 Ok(b) => b,
291 Err(_) => {
292 return finish(
293 m,
294 &method,
295 &path,
296 ip,
297 started,
298 "payload_too_large",
299 text(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large"),
300 )
301 }
302 };
303 let ingress_bytes = header_bytes(&parts.headers).saturating_add(body_bytes.len());
305
306 if let Some(hit) = rt.waf.evaluate(&path, &parts.headers, &body_bytes) {
311 m.record_waf_hit(hit.class);
312 match rt.waf.mode() {
313 WafMode::Block => {
314 warn!(
315 rule = %hit.rule_id,
316 class = hit.class,
317 location = hit.location,
318 client_ip = %ip,
319 path = %path,
320 "WAF blocked request"
321 );
322 return finish(
323 m,
324 &method,
325 &path,
326 ip,
327 started,
328 "forbidden",
329 text(StatusCode::FORBIDDEN, "Forbidden"),
330 );
331 }
332 WafMode::Report => warn!(
333 rule = %hit.rule_id,
334 class = hit.class,
335 location = hit.location,
336 client_ip = %ip,
337 path = %path,
338 "WAF rule matched (report-only)"
339 ),
340 WafMode::Off => {}
343 }
344 }
345
346 let uri = format!("{}{}", rt.upstream_base, path);
348 let mut up = Request::builder().method(parts.method.clone()).uri(&uri);
349 {
350 let headers = up.headers_mut().expect("builder headers");
351 let mut forwarded = parts.headers.clone();
354 strip_hop_by_hop(&mut forwarded);
355 for (name, value) in forwarded.iter() {
356 if name == header::HOST {
357 continue; }
359 headers.insert(name.clone(), value.clone());
360 }
361 if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
363 headers.insert(HeaderName::from_static("x-forwarded-for"), v);
364 }
365 headers.insert(
366 HeaderName::from_static("x-forwarded-proto"),
367 HeaderValue::from_static(forwarded_proto(&rt.cfg, &parts.headers)),
368 );
369 }
370
371 let upstream_req = match up.body(Full::new(body_bytes)) {
372 Ok(r) => r,
373 Err(e) => {
374 warn!(error = %e, "failed to build upstream request");
375 return finish(
376 m,
377 &method,
378 &path,
379 ip,
380 started,
381 "bad_gateway",
382 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
383 );
384 }
385 };
386
387 let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
390 let timed_out = || {
391 warn!(upstream = %uri, "upstream timed out");
392 text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")
393 };
394
395 let upstream_resp = match within(deadline, state.client.request(upstream_req)).await {
396 Ok(Ok(r)) => r,
397 Ok(Err(e)) => {
398 warn!(error = %e, upstream = %uri, "upstream unreachable");
399 return finish(
400 m,
401 &method,
402 &path,
403 ip,
404 started,
405 "upstream_error",
406 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
407 );
408 }
409 Err(_) => {
410 return finish(
411 m,
412 &method,
413 &path,
414 ip,
415 started,
416 "upstream_timeout",
417 timed_out(),
418 )
419 }
420 };
421
422 let (mut resp_parts, resp_body) = upstream_resp.into_parts();
423
424 if rt.stream_passthrough && is_event_stream(&resp_parts.headers) {
432 strip_hop_by_hop(&mut resp_parts.headers);
433 resp_parts.headers.remove(header::CONTENT_LENGTH);
434 let header_egress = header_bytes(&resp_parts.headers);
435 let body = Body::new(CountingBody::new(
436 resp_body,
437 Arc::clone(m),
438 ingress_bytes,
439 header_egress,
440 ));
441 let mut response = Response::from_parts(resp_parts, body);
442 harden_response(&rt.cfg, &mut response);
443 return finish(m, &method, &path, ip, started, "ok", response);
444 }
445
446 let resp_bytes = if rt.max_response_body > 0 {
448 match within(
449 deadline,
450 Limited::new(resp_body, rt.max_response_body).collect(),
451 )
452 .await
453 {
454 Ok(Ok(c)) => c.to_bytes(),
455 Ok(Err(_)) => {
456 warn!(
457 limit = rt.max_response_body,
458 "upstream response exceeded max_response_body"
459 );
460 return finish(
461 m,
462 &method,
463 &path,
464 ip,
465 started,
466 "upstream_body_too_large",
467 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
468 );
469 }
470 Err(_) => {
471 return finish(
472 m,
473 &method,
474 &path,
475 ip,
476 started,
477 "upstream_timeout",
478 timed_out(),
479 )
480 }
481 }
482 } else {
483 match within(deadline, resp_body.collect()).await {
484 Ok(Ok(c)) => c.to_bytes(),
485 Ok(Err(e)) => {
486 warn!(error = %e, "failed reading upstream body");
487 return finish(
488 m,
489 &method,
490 &path,
491 ip,
492 started,
493 "upstream_body_error",
494 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
495 );
496 }
497 Err(_) => {
498 return finish(
499 m,
500 &method,
501 &path,
502 ip,
503 started,
504 "upstream_timeout",
505 timed_out(),
506 )
507 }
508 }
509 };
510
511 strip_hop_by_hop(&mut resp_parts.headers);
514 resp_parts.headers.remove(header::CONTENT_LENGTH);
515
516 m.add_usage_bytes(
519 ingress_bytes,
520 header_bytes(&resp_parts.headers).saturating_add(resp_bytes.len()),
521 );
522
523 let mut response = Response::from_parts(resp_parts, Body::from(resp_bytes));
524 harden_response(&rt.cfg, &mut response);
525
526 finish(m, &method, &path, ip, started, "ok", response)
527}
528
529pub async fn ready(State(state): State<AppState>) -> StatusCode {
534 let rt = state.runtime.load();
535 let Some((host, port)) = rt.cfg.upstream_probe_addr() else {
536 return StatusCode::SERVICE_UNAVAILABLE;
537 };
538 match tokio::time::timeout(
539 Duration::from_secs(2),
540 TcpStream::connect((host.as_str(), port)),
541 )
542 .await
543 {
544 Ok(Ok(_)) => StatusCode::OK,
545 _ => StatusCode::SERVICE_UNAVAILABLE,
546 }
547}
548
549pub async fn metrics_handler(State(state): State<AppState>) -> Response<Body> {
553 let body = state.metrics.render();
554 let mut resp = Response::new(Body::from(body));
555 resp.headers_mut().insert(
556 header::CONTENT_TYPE,
557 HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8"),
558 );
559 resp
560}
561
562pub async fn csp_report(State(state): State<AppState>, body: Bytes) -> StatusCode {
565 state.metrics.record_csp_report();
566 if let Some(cp) = &state.cp {
570 if state.runtime.load().cfg.control_plane.forward_csp {
571 let cp = cp.clone();
572 let raw = body.clone();
573 tokio::spawn(async move { cp.forward_csp(&raw).await });
574 }
575 }
576 match serde_json::from_slice::<serde_json::Value>(&body) {
580 Ok(report) => {
581 let directive = report
582 .get("csp-report")
583 .and_then(|r| {
584 r.get("violated-directive")
585 .or_else(|| r.get("effective-directive"))
586 })
587 .and_then(|v| v.as_str())
588 .unwrap_or("unknown");
589 debug!(target: "edgeguard::csp", directive, "CSP violation report");
590 }
591 Err(_) => warn!(
592 bytes = body.len(),
593 "CSP violation report with an unparseable body"
594 ),
595 }
596 StatusCode::NO_CONTENT
597}
598
599fn client_ip(headers: &HeaderMap, peer: SocketAddr, trust_forwarded: bool) -> IpAddr {
603 if trust_forwarded {
604 if let Some(xff) = headers.get("x-forwarded-for") {
605 if let Ok(s) = xff.to_str() {
606 if let Some(first) = s.split(',').next() {
607 if let Ok(ip) = first.trim().parse::<IpAddr>() {
608 return ip;
609 }
610 }
611 }
612 }
613 }
614 peer.ip()
615}
616
617fn header_bytes(headers: &HeaderMap) -> usize {
620 headers
621 .iter()
622 .map(|(name, value)| name.as_str().len() + value.as_bytes().len())
623 .sum()
624}
625
626fn is_event_stream(headers: &HeaderMap) -> bool {
630 headers
631 .get(header::CONTENT_TYPE)
632 .and_then(|v| v.to_str().ok())
633 .map(|v| {
634 v.split(';')
635 .next()
636 .map(str::trim)
637 .map(|ct| ct.eq_ignore_ascii_case("text/event-stream"))
638 .unwrap_or(false)
639 })
640 .unwrap_or(false)
641}
642
/// Response-body wrapper that counts the data bytes it yields and reports the
/// exchange's ingress/egress usage to [`Metrics`] when dropped.
///
/// `handle` uses it for pass-through (`text/event-stream`) responses, whose size
/// is unknown when the handler returns.
struct CountingBody<B> {
    /// Upstream body being forwarded.
    inner: B,
    /// Registry that receives the usage figures on drop.
    metrics: Arc<Metrics>,
    /// Request bytes (headers + buffered body) attributed to this exchange.
    ingress: usize,
    /// Response bytes so far: header bytes at construction plus polled data frames.
    egress: usize,
}
655
656impl<B> CountingBody<B> {
657 fn new(inner: B, metrics: Arc<Metrics>, ingress: usize, header_egress: usize) -> Self {
658 Self {
659 inner,
660 metrics,
661 ingress,
662 egress: header_egress,
663 }
664 }
665}
666
667impl<B> HttpBody for CountingBody<B>
668where
669 B: HttpBody<Data = Bytes> + Unpin,
670{
671 type Data = Bytes;
672 type Error = B::Error;
673
674 fn poll_frame(
675 mut self: Pin<&mut Self>,
676 cx: &mut Context<'_>,
677 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
678 let this = self.as_mut().get_mut();
679 let polled = Pin::new(&mut this.inner).poll_frame(cx);
680 if let Poll::Ready(Some(Ok(frame))) = &polled {
681 if let Some(data) = frame.data_ref() {
682 this.egress = this.egress.saturating_add(data.len());
683 }
684 }
685 polled
686 }
687
688 fn is_end_stream(&self) -> bool {
689 self.inner.is_end_stream()
690 }
691
692 fn size_hint(&self) -> SizeHint {
693 self.inner.size_hint()
694 }
695}
696
/// Reports the accumulated usage when the body is dropped, whether it was
/// streamed to completion or abandoned early (e.g. on client disconnect).
impl<B> Drop for CountingBody<B> {
    fn drop(&mut self) {
        // Single report per exchange: on the pass-through path `handle` returns
        // before its own `add_usage_bytes` call.
        self.metrics.add_usage_bytes(self.ingress, self.egress);
    }
}
702
703fn strip_hop_by_hop(headers: &mut HeaderMap) {
707 let connection_named: Vec<HeaderName> = headers
710 .get_all(header::CONNECTION)
711 .iter()
712 .filter_map(|v| v.to_str().ok())
713 .flat_map(|v| v.split(','))
714 .filter_map(|token| HeaderName::from_bytes(token.trim().as_bytes()).ok())
715 .collect();
716 for name in HOP_BY_HOP {
717 headers.remove(*name);
718 }
719 for name in connection_named {
720 headers.remove(name);
721 }
722}
723
724fn forwarded_proto(cfg: &Config, headers: &HeaderMap) -> &'static str {
730 if cfg.tls.enabled {
731 return "https";
732 }
733 if cfg.server.trust_forwarded_for {
734 if let Some(value) = headers
735 .get("x-forwarded-proto")
736 .and_then(|v| v.to_str().ok())
737 {
738 match value.split(',').next().map(str::trim) {
739 Some(p) if p.eq_ignore_ascii_case("https") => return "https",
740 Some(p) if p.eq_ignore_ascii_case("http") => return "http",
741 _ => {}
742 }
743 }
744 }
745 "http"
746}
747
748fn longest_route<'a>(routes: &'a [RouteLimiter], path: &str) -> Option<&'a RouteLimiter> {
750 routes
751 .iter()
752 .filter(|r| path.starts_with(&r.prefix))
753 .max_by_key(|r| r.prefix.len())
754}
755
/// `Strict-Transport-Security` value: a two-year max-age (63 072 000 s) that also
/// covers subdomains.
pub const HSTS_VALUE: &str = "max-age=63072000; includeSubDomains";
760
761pub fn security_headers(cfg: &HeadersCfg) -> Vec<(&'static str, String)> {
773 let mut out: Vec<(&'static str, String)> = Vec::with_capacity(6);
774 out.push(("X-Content-Type-Options", "nosniff".to_string()));
775 if !cfg.frame_options.is_empty() {
776 out.push(("X-Frame-Options", cfg.frame_options.clone()));
777 }
778 if !cfg.referrer_policy.is_empty() {
779 out.push(("Referrer-Policy", cfg.referrer_policy.clone()));
780 }
781 if !cfg.permissions_policy.is_empty() {
782 out.push(("Permissions-Policy", cfg.permissions_policy.clone()));
783 }
784 if !cfg.csp.is_empty() {
785 let mut value = cfg.csp.clone();
787 if !cfg.csp_report_uri.is_empty() {
788 value.push_str("; report-uri ");
789 value.push_str(&cfg.csp_report_uri);
790 }
791 let name = if cfg.csp_report_only {
792 "Content-Security-Policy-Report-Only"
793 } else {
794 "Content-Security-Policy"
795 };
796 out.push((name, value));
797 }
798 if cfg.hsts {
799 out.push(("Strict-Transport-Security", HSTS_VALUE.to_string()));
800 }
801 out
802}
803
804fn harden_response(cfg: &Config, resp: &mut Response<Body>) {
806 let h = resp.headers_mut();
807
808 for (name, value) in security_headers(&cfg.headers) {
812 if let (Ok(n), Ok(v)) = (
813 HeaderName::from_bytes(name.as_bytes()),
814 HeaderValue::from_str(&value),
815 ) {
816 h.insert(n, v);
817 }
818 }
819
820 for name in &cfg.headers.strip {
822 if let Ok(hn) = HeaderName::from_bytes(name.as_bytes()) {
823 h.remove(hn);
824 }
825 }
826
827 if cfg.headers.force_secure_cookies {
829 let cookies: Vec<HeaderValue> = h.get_all(header::SET_COOKIE).iter().cloned().collect();
830 if !cookies.is_empty() {
831 h.remove(header::SET_COOKIE);
832 for c in cookies {
833 if let Ok(s) = c.to_str() {
834 let hardened = harden_cookie(s);
835 if let Ok(v) = HeaderValue::from_str(&hardened) {
836 h.append(header::SET_COOKIE, v);
837 }
838 } else {
839 h.append(header::SET_COOKIE, c);
840 }
841 }
842 }
843 }
844}
845
/// Adds `Secure`, `HttpOnly` and `SameSite=Lax` to a `Set-Cookie` value, each only
/// if an attribute of that name (case-insensitive) is not already present.
/// Existing attributes, including any explicit `SameSite`, are left untouched.
///
/// Only the attribute part (after the first `;`) is inspected, so a cookie value
/// such as `session=securetoken` is not mistaken for a `Secure` attribute.
/// Trailing `;` and whitespace are trimmed before attributes are appended.
/// Without that, an input like `"sid=abc; "` would produce an empty attribute
/// (`"sid=abc; ; Secure"`).
fn harden_cookie(cookie: &str) -> String {
    let attrs: std::collections::HashSet<String> = cookie
        .split(';')
        .skip(1)
        .map(|attr| {
            attr.split('=')
                .next()
                .unwrap_or("")
                .trim()
                .to_ascii_lowercase()
        })
        .collect();

    let mut out = cookie
        .trim_end_matches(|c: char| c == ';' || c.is_ascii_whitespace())
        .to_string();
    if !attrs.contains("secure") {
        out.push_str("; Secure");
    }
    if !attrs.contains("httponly") {
        out.push_str("; HttpOnly");
    }
    if !attrs.contains("samesite") {
        out.push_str("; SameSite=Lax");
    }
    out
}
869
870async fn within<F: Future>(
873 deadline: Option<tokio::time::Instant>,
874 fut: F,
875) -> Result<F::Output, tokio::time::error::Elapsed> {
876 match deadline {
877 Some(dl) => tokio::time::timeout_at(dl, fut).await,
878 None => Ok(fut.await),
879 }
880}
881
882fn text(status: StatusCode, msg: &str) -> Response<Body> {
883 let mut resp = Response::new(Body::from(msg.to_string()));
884 *resp.status_mut() = status;
885 resp.headers_mut().insert(
886 header::CONTENT_TYPE,
887 HeaderValue::from_static("text/plain; charset=utf-8"),
888 );
889 resp
890}
891
892fn finish(
894 metrics: &Metrics,
895 method: &Method,
896 path: &str,
897 ip: IpAddr,
898 started: Instant,
899 outcome: &str,
900 resp: Response<Body>,
901) -> Response<Body> {
902 let elapsed = started.elapsed();
903 info!(
904 %method,
905 path = %path,
906 client_ip = %ip,
907 status = resp.status().as_u16(),
908 outcome,
909 latency_ms = elapsed.as_millis() as u64,
910 "request"
911 );
912 metrics.record_request(outcome);
913 metrics.observe_latency(elapsed);
914 metrics.add_usage_request();
917 resp
918}
919
/// Unit tests for the pure helpers in this module. The request pipeline in
/// `handle` is not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map containing a single `name: value` entry.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(name, HeaderValue::from_str(value).unwrap());
        h
    }

    // Without trust, a client-supplied X-Forwarded-For must not override the peer.
    #[test]
    fn client_ip_ignores_xff_when_untrusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        let h = headers_with("x-forwarded-for", "1.2.3.4");
        assert_eq!(client_ip(&h, peer, false), peer.ip());
    }

    // With trust, the left-most hop of the chain is the client.
    #[test]
    fn client_ip_uses_first_xff_hop_when_trusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        let h = headers_with("x-forwarded-for", "1.2.3.4, 5.6.7.8");
        assert_eq!(client_ip(&h, peer, true).to_string(), "1.2.3.4");
    }

    #[test]
    fn client_ip_falls_back_to_peer_on_missing_or_garbage_xff() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        assert_eq!(client_ip(&HeaderMap::new(), peer, true), peer.ip());
        let garbage = headers_with("x-forwarded-for", "not-an-ip");
        assert_eq!(client_ip(&garbage, peer, true), peer.ip());
    }

    // Only name and value lengths count; separators and CRLFs are excluded.
    #[test]
    fn header_bytes_sums_names_and_values() {
        let mut h = HeaderMap::new();
        h.insert("a", HeaderValue::from_static("bb"));
        h.insert("ccc", HeaderValue::from_static("dddd"));
        assert_eq!(header_bytes(&h), 1 + 2 + 3 + 4);
    }

    // Covers both the fixed list and a header nominated via `Connection`
    // (mixed case in the token), while end-to-end headers survive.
    #[test]
    fn strip_hop_by_hop_removes_fixed_and_connection_named() {
        let mut h = HeaderMap::new();
        h.insert(
            "connection",
            HeaderValue::from_static("keep-alive, X-Custom-Hop"),
        );
        h.insert("keep-alive", HeaderValue::from_static("timeout=5"));
        h.insert("x-custom-hop", HeaderValue::from_static("secret"));
        h.insert("content-type", HeaderValue::from_static("text/plain"));
        strip_hop_by_hop(&mut h);
        assert!(!h.contains_key("connection"));
        assert!(!h.contains_key("keep-alive"));
        assert!(!h.contains_key("x-custom-hop"));
        assert!(h.contains_key("content-type"));
    }

    #[test]
    fn forwarded_proto_reflects_tls_and_trust() {
        let mut cfg = Config::default();

        // Local TLS termination always wins over the incoming header.
        cfg.tls.enabled = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http")),
            "https"
        );

        // Untrusted incoming header is ignored.
        cfg.tls.enabled = false;
        cfg.server.trust_forwarded_for = false;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "http"
        );

        // Trusted header: first entry decides; missing or unknown values fall back to http.
        cfg.server.trust_forwarded_for = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "https"
        );
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http, https")),
            "http"
        );
        assert_eq!(forwarded_proto(&cfg, &HeaderMap::new()), "http");
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "garbage")),
            "http"
        );
    }

    #[test]
    fn longest_route_picks_most_specific_prefix() {
        // The quota value is irrelevant; only prefix selection is under test.
        let mk = |p: &str| RouteLimiter {
            prefix: p.to_string(),
            limiter: Arc::new(RateLimiter::keyed(governor::Quota::per_second(
                std::num::NonZeroU32::new(1).unwrap(),
            ))),
        };
        let routes = vec![mk("/api/"), mk("/api/admin/")];
        assert_eq!(
            longest_route(&routes, "/api/admin/users").map(|r| r.prefix.as_str()),
            Some("/api/admin/")
        );
        assert_eq!(
            longest_route(&routes, "/api/things").map(|r| r.prefix.as_str()),
            Some("/api/")
        );
        assert!(longest_route(&routes, "/public").is_none());
    }

    #[test]
    fn harden_cookie_adds_missing_flags() {
        let out = harden_cookie("sid=abc");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("; HttpOnly"), "{out}");
        assert!(out.contains("; SameSite=Lax"), "{out}");
    }

    // Existing attributes are neither overridden nor duplicated.
    #[test]
    fn harden_cookie_preserves_existing_attributes() {
        let out = harden_cookie("sid=abc; HttpOnly; SameSite=Strict");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("SameSite=Strict"), "{out}");
        assert!(!out.contains("SameSite=Lax"), "{out}");
        assert_eq!(out.matches("HttpOnly").count(), 1, "{out}");
    }

    // A cookie *value* containing "secure" must not count as the Secure attribute.
    #[test]
    fn harden_cookie_value_resembling_an_attr_is_not_skipped() {
        let out = harden_cookie("session=securetoken");
        assert!(out.contains("; Secure"), "{out}");
    }

    #[test]
    fn security_headers_reflects_config_toggles() {
        // Defaults: every header present, enforcing (not report-only) CSP.
        let cfg = HeadersCfg::default();
        let got = security_headers(&cfg);
        let names: Vec<&str> = got.iter().map(|(n, _)| *n).collect();
        assert!(names.contains(&"X-Content-Type-Options"));
        assert!(names.contains(&"X-Frame-Options"));
        assert!(names.contains(&"Referrer-Policy"));
        assert!(names.contains(&"Permissions-Policy"));
        assert!(names.contains(&"Content-Security-Policy"));
        assert!(names.contains(&"Strict-Transport-Security"));
        assert!(!names.contains(&"Content-Security-Policy-Report-Only"));

        // Toggles: HSTS off, frame options blanked, report-only CSP with a report URI.
        let cfg = HeadersCfg {
            hsts: false,
            frame_options: String::new(),
            csp: "default-src 'self'".into(),
            csp_report_only: true,
            csp_report_uri: "/__edgeguard/csp-report".into(),
            ..HeadersCfg::default()
        };
        let got = security_headers(&cfg);
        let map: std::collections::HashMap<&str, String> =
            got.iter().map(|(n, v)| (*n, v.clone())).collect();
        assert!(!map.contains_key("Strict-Transport-Security"));
        assert!(!map.contains_key("X-Frame-Options"));
        assert!(!map.contains_key("Content-Security-Policy"));
        assert_eq!(
            map.get("Content-Security-Policy-Report-Only")
                .map(|s| s.as_str()),
            Some("default-src 'self'; report-uri /__edgeguard/csp-report")
        );
    }
}