1use std::future::Future;
13use std::net::{IpAddr, SocketAddr};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
17use arc_swap::ArcSwap;
18use axum::{
19 body::{Body, Bytes},
20 extract::{ConnectInfo, State},
21 http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode},
22};
23use governor::{clock::DefaultClock, state::keyed::DefaultKeyedStateStore, RateLimiter};
24use http_body_util::{BodyExt, Full, Limited};
25use hyper_util::client::legacy::{connect::HttpConnector, Client};
26use tokio::net::TcpStream;
27use tracing::{debug, info, warn};
28
29use crate::auth::{AuthEngine, Challenge, Decision};
30use crate::config::{Config, HeadersCfg};
31use crate::limiter::{Admit, DistributedLimiter};
32use crate::metrics::Metrics;
33use crate::waf::{WafEngine, WafMode};
34
/// In-process `governor` rate limiter keyed by client IP (used for the global IP limiter
/// and for per-route limiters).
pub type KeyedLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
/// In-process `governor` rate limiter keyed by a string; in this module it is keyed by the
/// authenticated principal returned from the auth engine.
pub type StrLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
/// Plain-HTTP (`HttpConnector`) client that sends fully buffered request bodies upstream.
pub type UpstreamClient = Client<HttpConnector, Full<Bytes>>;
39
/// State shared by every handler in this module through axum's `State` extractor.
#[derive(Clone)]
pub struct AppState {
    /// Client used to forward requests to the upstream.
    pub client: UpstreamClient,
    /// Counters and latency observations recorded by `finish` and the CSP endpoint.
    pub metrics: Arc<Metrics>,
    /// Current runtime snapshot, read by handlers through `ArcSwap`.
    /// NOTE(review): presumably swapped on config reload; the swap is not in this file.
    pub runtime: Arc<ArcSwap<Runtime>>,
}
48
/// Rate limiter bound to a request-path prefix. `longest_route` picks the entry with the
/// longest matching prefix; that limiter is then applied instead of the global IP limiter.
pub struct RouteLimiter {
    /// Prefix compared against the request path (plus query) with `str::starts_with`.
    pub prefix: String,
    /// Limiter keyed by client IP for requests under `prefix`.
    pub limiter: Arc<KeyedLimiter>,
}
54
/// Everything `handle` needs for one request, derived from the active configuration.
pub struct Runtime {
    /// The configuration this runtime was built from.
    pub cfg: Arc<Config>,
    /// Prefix prepended to the request path-and-query to form the upstream URI.
    pub upstream_base: Arc<String>,
    /// Authentication engine consulted after IP/route rate limiting.
    pub auth: AuthEngine,
    /// WAF evaluated against path, headers and buffered body before proxying.
    pub waf: WafEngine,
    /// Distributed limiter. When present it is used instead of the in-process
    /// `ip_limiter` / `route_limiters` / `key_limiter`.
    pub distributed: Option<DistributedLimiter>,
    /// In-process per-IP limiter, used when no route prefix matches.
    pub ip_limiter: Option<Arc<KeyedLimiter>>,
    /// In-process per-route limiters; the longest matching prefix wins.
    pub route_limiters: Vec<RouteLimiter>,
    /// In-process per-principal limiter for authenticated requests.
    pub key_limiter: Option<Arc<StrLimiter>>,
    /// Maximum buffered request body in bytes. Exceeding it yields 413.
    /// NOTE(review): unlike the two limits below, 0 is not special-cased by the handler —
    /// confirm config never produces 0 here.
    pub max_body: usize,
    /// Maximum buffered upstream response body in bytes; 0 disables the limit.
    pub max_response_body: usize,
    /// Maximum summed header name+value bytes of a request; 0 disables the check.
    pub max_header_bytes: usize,
    /// Single deadline covering the upstream request and reading its body; `None` = no limit.
    pub upstream_timeout: Option<Duration>,
}
80
/// Hop-by-hop header names (lowercase) that a proxy must not forward. `strip_hop_by_hop`
/// removes these, plus any header named in `Connection`, from both the upstream request
/// and the relayed response.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
92
93pub async fn handle(
94 State(state): State<AppState>,
95 ConnectInfo(peer): ConnectInfo<SocketAddr>,
96 req: Request<Body>,
97) -> Response<Body> {
98 let started = Instant::now();
99 let rt = state.runtime.load();
102 let m = &state.metrics;
103
104 let method = req.method().clone();
105 let path = req
106 .uri()
107 .path_and_query()
108 .map(|p| p.as_str().to_string())
109 .unwrap_or_else(|| req.uri().path().to_string());
110
111 let ip = client_ip(req.headers(), peer, rt.cfg.server.trust_forwarded_for);
112
113 if req.uri().path().starts_with("/__edgeguard/") {
119 return finish(
120 m,
121 &method,
122 &path,
123 ip,
124 started,
125 "not_found",
126 text(StatusCode::NOT_FOUND, "Not Found"),
127 );
128 }
129
130 if rt.max_header_bytes > 0 && header_bytes(req.headers()) > rt.max_header_bytes {
132 return finish(
133 m,
134 &method,
135 &path,
136 ip,
137 started,
138 "header_too_large",
139 text(
140 StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
141 "Request Header Fields Too Large",
142 ),
143 );
144 }
145
146 if rt.cfg.ratelimit.enabled {
150 if let Some(d) = &rt.distributed {
151 match d.check_ip_route(ip, &path).await {
152 Admit::Allowed => {}
153 Admit::Limited(scope) => {
154 m.record_ratelimit_hit(scope);
155 return finish(
156 m,
157 &method,
158 &path,
159 ip,
160 started,
161 "rate_limited",
162 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
163 );
164 }
165 Admit::Error => {
166 return finish(
167 m,
168 &method,
169 &path,
170 ip,
171 started,
172 "limiter_error",
173 text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
174 );
175 }
176 }
177 } else {
178 let (limiter, scope) = match longest_route(&rt.route_limiters, &path) {
179 Some(r) => (Some(r.limiter.as_ref()), "route"),
180 None => (rt.ip_limiter.as_deref(), "ip"),
181 };
182 if let Some(limiter) = limiter {
183 if limiter.check_key(&ip).is_err() {
184 m.record_ratelimit_hit(scope);
185 return finish(
186 m,
187 &method,
188 &path,
189 ip,
190 started,
191 "rate_limited",
192 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
193 );
194 }
195 }
196 }
197 }
198
199 let principal = match rt.auth.authorize(&rt.cfg.auth, req.headers()).await {
201 Decision::Allow(principal) => principal,
202 Decision::Deny(challenge) => {
203 let mut resp = text(StatusCode::UNAUTHORIZED, "Unauthorized");
204 let challenge_value = match challenge {
205 Challenge::Basic(c) => Some(c),
206 Challenge::Bearer => Some("Bearer".to_string()),
207 Challenge::None => None,
208 };
209 if let Some(c) = challenge_value {
210 if let Ok(v) = HeaderValue::from_str(&c) {
211 resp.headers_mut().insert(header::WWW_AUTHENTICATE, v);
212 }
213 }
214 return finish(m, &method, &path, ip, started, "unauthorized", resp);
215 }
216 };
217
218 if let Some(principal) = &principal {
221 let key_admit = if let Some(d) = &rt.distributed {
222 Some(d.check_key(principal).await)
223 } else {
224 rt.key_limiter.as_ref().map(|limiter| {
225 if limiter.check_key(principal).is_err() {
226 Admit::Limited("key")
227 } else {
228 Admit::Allowed
229 }
230 })
231 };
232 match key_admit {
233 Some(Admit::Limited(scope)) => {
234 m.record_ratelimit_hit(scope);
235 return finish(
236 m,
237 &method,
238 &path,
239 ip,
240 started,
241 "rate_limited",
242 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
243 );
244 }
245 Some(Admit::Error) => {
246 return finish(
247 m,
248 &method,
249 &path,
250 ip,
251 started,
252 "limiter_error",
253 text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
254 );
255 }
256 Some(Admit::Allowed) | None => {}
257 }
258 }
259
260 let allow = &rt.cfg.validation.allow_methods;
262 if !allow.is_empty()
263 && !allow
264 .iter()
265 .any(|x| x.eq_ignore_ascii_case(method.as_str()))
266 {
267 return finish(
268 m,
269 &method,
270 &path,
271 ip,
272 started,
273 "method_not_allowed",
274 text(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed"),
275 );
276 }
277
278 let (parts, body) = req.into_parts();
280 let body_bytes = match axum::body::to_bytes(body, rt.max_body).await {
281 Ok(b) => b,
282 Err(_) => {
283 return finish(
284 m,
285 &method,
286 &path,
287 ip,
288 started,
289 "payload_too_large",
290 text(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large"),
291 )
292 }
293 };
294
295 if let Some(hit) = rt.waf.evaluate(&path, &parts.headers, &body_bytes) {
300 m.record_waf_hit(hit.class);
301 match rt.waf.mode() {
302 WafMode::Block => {
303 warn!(
304 rule = %hit.rule_id,
305 class = hit.class,
306 location = hit.location,
307 client_ip = %ip,
308 path = %path,
309 "WAF blocked request"
310 );
311 return finish(
312 m,
313 &method,
314 &path,
315 ip,
316 started,
317 "forbidden",
318 text(StatusCode::FORBIDDEN, "Forbidden"),
319 );
320 }
321 WafMode::Report => warn!(
322 rule = %hit.rule_id,
323 class = hit.class,
324 location = hit.location,
325 client_ip = %ip,
326 path = %path,
327 "WAF rule matched (report-only)"
328 ),
329 WafMode::Off => {}
332 }
333 }
334
335 let uri = format!("{}{}", rt.upstream_base, path);
337 let mut up = Request::builder().method(parts.method.clone()).uri(&uri);
338 {
339 let headers = up.headers_mut().expect("builder headers");
340 let mut forwarded = parts.headers.clone();
343 strip_hop_by_hop(&mut forwarded);
344 for (name, value) in forwarded.iter() {
345 if name == header::HOST {
346 continue; }
348 headers.insert(name.clone(), value.clone());
349 }
350 if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
352 headers.insert(HeaderName::from_static("x-forwarded-for"), v);
353 }
354 headers.insert(
355 HeaderName::from_static("x-forwarded-proto"),
356 HeaderValue::from_static(forwarded_proto(&rt.cfg, &parts.headers)),
357 );
358 }
359
360 let upstream_req = match up.body(Full::new(body_bytes)) {
361 Ok(r) => r,
362 Err(e) => {
363 warn!(error = %e, "failed to build upstream request");
364 return finish(
365 m,
366 &method,
367 &path,
368 ip,
369 started,
370 "bad_gateway",
371 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
372 );
373 }
374 };
375
376 let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
379 let timed_out = || {
380 warn!(upstream = %uri, "upstream timed out");
381 text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")
382 };
383
384 let upstream_resp = match within(deadline, state.client.request(upstream_req)).await {
385 Ok(Ok(r)) => r,
386 Ok(Err(e)) => {
387 warn!(error = %e, upstream = %uri, "upstream unreachable");
388 return finish(
389 m,
390 &method,
391 &path,
392 ip,
393 started,
394 "upstream_error",
395 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
396 );
397 }
398 Err(_) => {
399 return finish(
400 m,
401 &method,
402 &path,
403 ip,
404 started,
405 "upstream_timeout",
406 timed_out(),
407 )
408 }
409 };
410
411 let (mut resp_parts, resp_body) = upstream_resp.into_parts();
412 let resp_bytes = if rt.max_response_body > 0 {
414 match within(
415 deadline,
416 Limited::new(resp_body, rt.max_response_body).collect(),
417 )
418 .await
419 {
420 Ok(Ok(c)) => c.to_bytes(),
421 Ok(Err(_)) => {
422 warn!(
423 limit = rt.max_response_body,
424 "upstream response exceeded max_response_body"
425 );
426 return finish(
427 m,
428 &method,
429 &path,
430 ip,
431 started,
432 "upstream_body_too_large",
433 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
434 );
435 }
436 Err(_) => {
437 return finish(
438 m,
439 &method,
440 &path,
441 ip,
442 started,
443 "upstream_timeout",
444 timed_out(),
445 )
446 }
447 }
448 } else {
449 match within(deadline, resp_body.collect()).await {
450 Ok(Ok(c)) => c.to_bytes(),
451 Ok(Err(e)) => {
452 warn!(error = %e, "failed reading upstream body");
453 return finish(
454 m,
455 &method,
456 &path,
457 ip,
458 started,
459 "upstream_body_error",
460 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
461 );
462 }
463 Err(_) => {
464 return finish(
465 m,
466 &method,
467 &path,
468 ip,
469 started,
470 "upstream_timeout",
471 timed_out(),
472 )
473 }
474 }
475 };
476
477 strip_hop_by_hop(&mut resp_parts.headers);
480 resp_parts.headers.remove(header::CONTENT_LENGTH);
481
482 let mut response = Response::from_parts(resp_parts, Body::from(resp_bytes));
483 harden_response(&rt.cfg, &mut response);
484
485 finish(m, &method, &path, ip, started, "ok", response)
486}
487
488pub async fn ready(State(state): State<AppState>) -> StatusCode {
493 let rt = state.runtime.load();
494 let Some((host, port)) = rt.cfg.upstream_probe_addr() else {
495 return StatusCode::SERVICE_UNAVAILABLE;
496 };
497 match tokio::time::timeout(
498 Duration::from_secs(2),
499 TcpStream::connect((host.as_str(), port)),
500 )
501 .await
502 {
503 Ok(Ok(_)) => StatusCode::OK,
504 _ => StatusCode::SERVICE_UNAVAILABLE,
505 }
506}
507
508pub async fn metrics_handler(State(state): State<AppState>) -> Response<Body> {
512 let body = state.metrics.render();
513 let mut resp = Response::new(Body::from(body));
514 resp.headers_mut().insert(
515 header::CONTENT_TYPE,
516 HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8"),
517 );
518 resp
519}
520
521pub async fn csp_report(State(state): State<AppState>, body: Bytes) -> StatusCode {
524 state.metrics.record_csp_report();
525 match serde_json::from_slice::<serde_json::Value>(&body) {
529 Ok(report) => {
530 let directive = report
531 .get("csp-report")
532 .and_then(|r| {
533 r.get("violated-directive")
534 .or_else(|| r.get("effective-directive"))
535 })
536 .and_then(|v| v.as_str())
537 .unwrap_or("unknown");
538 debug!(target: "edgeguard::csp", directive, "CSP violation report");
539 }
540 Err(_) => warn!(
541 bytes = body.len(),
542 "CSP violation report with an unparseable body"
543 ),
544 }
545 StatusCode::NO_CONTENT
546}
547
548fn client_ip(headers: &HeaderMap, peer: SocketAddr, trust_forwarded: bool) -> IpAddr {
552 if trust_forwarded {
553 if let Some(xff) = headers.get("x-forwarded-for") {
554 if let Ok(s) = xff.to_str() {
555 if let Some(first) = s.split(',').next() {
556 if let Ok(ip) = first.trim().parse::<IpAddr>() {
557 return ip;
558 }
559 }
560 }
561 }
562 }
563 peer.ip()
564}
565
566fn header_bytes(headers: &HeaderMap) -> usize {
569 headers
570 .iter()
571 .map(|(name, value)| name.as_str().len() + value.as_bytes().len())
572 .sum()
573}
574
575fn strip_hop_by_hop(headers: &mut HeaderMap) {
579 let connection_named: Vec<HeaderName> = headers
582 .get_all(header::CONNECTION)
583 .iter()
584 .filter_map(|v| v.to_str().ok())
585 .flat_map(|v| v.split(','))
586 .filter_map(|token| HeaderName::from_bytes(token.trim().as_bytes()).ok())
587 .collect();
588 for name in HOP_BY_HOP {
589 headers.remove(*name);
590 }
591 for name in connection_named {
592 headers.remove(name);
593 }
594}
595
596fn forwarded_proto(cfg: &Config, headers: &HeaderMap) -> &'static str {
602 if cfg.tls.enabled {
603 return "https";
604 }
605 if cfg.server.trust_forwarded_for {
606 if let Some(value) = headers
607 .get("x-forwarded-proto")
608 .and_then(|v| v.to_str().ok())
609 {
610 match value.split(',').next().map(str::trim) {
611 Some(p) if p.eq_ignore_ascii_case("https") => return "https",
612 Some(p) if p.eq_ignore_ascii_case("http") => return "http",
613 _ => {}
614 }
615 }
616 }
617 "http"
618}
619
620fn longest_route<'a>(routes: &'a [RouteLimiter], path: &str) -> Option<&'a RouteLimiter> {
622 routes
623 .iter()
624 .filter(|r| path.starts_with(&r.prefix))
625 .max_by_key(|r| r.prefix.len())
626}
627
628pub const HSTS_VALUE: &str = "max-age=63072000; includeSubDomains";
632
633pub fn security_headers(cfg: &HeadersCfg) -> Vec<(&'static str, String)> {
645 let mut out: Vec<(&'static str, String)> = Vec::with_capacity(6);
646 out.push(("X-Content-Type-Options", "nosniff".to_string()));
647 if !cfg.frame_options.is_empty() {
648 out.push(("X-Frame-Options", cfg.frame_options.clone()));
649 }
650 if !cfg.referrer_policy.is_empty() {
651 out.push(("Referrer-Policy", cfg.referrer_policy.clone()));
652 }
653 if !cfg.permissions_policy.is_empty() {
654 out.push(("Permissions-Policy", cfg.permissions_policy.clone()));
655 }
656 if !cfg.csp.is_empty() {
657 let mut value = cfg.csp.clone();
659 if !cfg.csp_report_uri.is_empty() {
660 value.push_str("; report-uri ");
661 value.push_str(&cfg.csp_report_uri);
662 }
663 let name = if cfg.csp_report_only {
664 "Content-Security-Policy-Report-Only"
665 } else {
666 "Content-Security-Policy"
667 };
668 out.push((name, value));
669 }
670 if cfg.hsts {
671 out.push(("Strict-Transport-Security", HSTS_VALUE.to_string()));
672 }
673 out
674}
675
676fn harden_response(cfg: &Config, resp: &mut Response<Body>) {
678 let h = resp.headers_mut();
679
680 for (name, value) in security_headers(&cfg.headers) {
684 if let (Ok(n), Ok(v)) = (
685 HeaderName::from_bytes(name.as_bytes()),
686 HeaderValue::from_str(&value),
687 ) {
688 h.insert(n, v);
689 }
690 }
691
692 for name in &cfg.headers.strip {
694 if let Ok(hn) = HeaderName::from_bytes(name.as_bytes()) {
695 h.remove(hn);
696 }
697 }
698
699 if cfg.headers.force_secure_cookies {
701 let cookies: Vec<HeaderValue> = h.get_all(header::SET_COOKIE).iter().cloned().collect();
702 if !cookies.is_empty() {
703 h.remove(header::SET_COOKIE);
704 for c in cookies {
705 if let Ok(s) = c.to_str() {
706 let hardened = harden_cookie(s);
707 if let Ok(v) = HeaderValue::from_str(&hardened) {
708 h.append(header::SET_COOKIE, v);
709 }
710 } else {
711 h.append(header::SET_COOKIE, c);
712 }
713 }
714 }
715 }
716}
717
/// Adds `Secure`, `HttpOnly` and `SameSite=Lax` to a `Set-Cookie` value, unless the cookie
/// already carries an attribute of that name. Names are matched case-insensitively, and an
/// existing `SameSite=...` of any value is kept as is.
///
/// Only the segments after the first `;` are treated as attributes. A cookie *value* such
/// as `session=securetoken` therefore never counts as the `Secure` flag. Trailing `;` and
/// whitespace are trimmed before appending, so `"sid=abc; "` does not gain an empty
/// `; ;` attribute.
fn harden_cookie(cookie: &str) -> String {
    let mut has_secure = false;
    let mut has_httponly = false;
    let mut has_samesite = false;
    for attr in cookie.split(';').skip(1) {
        // Attribute name is the part before any `=`, with surrounding whitespace removed.
        let key = attr.split('=').next().unwrap_or("").trim();
        has_secure |= key.eq_ignore_ascii_case("secure");
        has_httponly |= key.eq_ignore_ascii_case("httponly");
        has_samesite |= key.eq_ignore_ascii_case("samesite");
    }

    let base = cookie.trim_end_matches(|c: char| c == ';' || c.is_ascii_whitespace());
    let mut out = String::with_capacity(base.len() + 40);
    out.push_str(base);
    if !has_secure {
        out.push_str("; Secure");
    }
    if !has_httponly {
        out.push_str("; HttpOnly");
    }
    if !has_samesite {
        out.push_str("; SameSite=Lax");
    }
    out
}
741
742async fn within<F: Future>(
745 deadline: Option<tokio::time::Instant>,
746 fut: F,
747) -> Result<F::Output, tokio::time::error::Elapsed> {
748 match deadline {
749 Some(dl) => tokio::time::timeout_at(dl, fut).await,
750 None => Ok(fut.await),
751 }
752}
753
754fn text(status: StatusCode, msg: &str) -> Response<Body> {
755 let mut resp = Response::new(Body::from(msg.to_string()));
756 *resp.status_mut() = status;
757 resp.headers_mut().insert(
758 header::CONTENT_TYPE,
759 HeaderValue::from_static("text/plain; charset=utf-8"),
760 );
761 resp
762}
763
764fn finish(
766 metrics: &Metrics,
767 method: &Method,
768 path: &str,
769 ip: IpAddr,
770 started: Instant,
771 outcome: &str,
772 resp: Response<Body>,
773) -> Response<Body> {
774 let elapsed = started.elapsed();
775 info!(
776 %method,
777 path = %path,
778 client_ip = %ip,
779 status = resp.status().as_u16(),
780 outcome,
781 latency_ms = elapsed.as_millis() as u64,
782 "request"
783 );
784 metrics.record_request(outcome);
785 metrics.observe_latency(elapsed);
786 resp
787}
788
#[cfg(test)]
mod tests {
    use super::*;

    /// Header map holding a single header; test values are always valid.
    fn one_header(name: &'static str, value: &str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(name, HeaderValue::from_str(value).unwrap());
        map
    }

    /// TCP peer used by the client-IP tests.
    fn test_peer() -> SocketAddr {
        "203.0.113.9:55000".parse().unwrap()
    }

    #[test]
    fn client_ip_ignores_xff_when_untrusted() {
        let peer = test_peer();
        let spoofed = one_header("x-forwarded-for", "1.2.3.4");
        assert_eq!(client_ip(&spoofed, peer, false), peer.ip());
    }

    #[test]
    fn client_ip_uses_first_xff_hop_when_trusted() {
        let chain = one_header("x-forwarded-for", "1.2.3.4, 5.6.7.8");
        let expected: IpAddr = "1.2.3.4".parse().unwrap();
        assert_eq!(client_ip(&chain, test_peer(), true), expected);
    }

    #[test]
    fn client_ip_falls_back_to_peer_on_missing_or_garbage_xff() {
        let peer = test_peer();
        for headers in [HeaderMap::new(), one_header("x-forwarded-for", "not-an-ip")] {
            assert_eq!(client_ip(&headers, peer, true), peer.ip());
        }
    }

    #[test]
    fn header_bytes_sums_names_and_values() {
        let mut map = HeaderMap::new();
        map.insert("a", HeaderValue::from_static("bb"));
        map.insert("ccc", HeaderValue::from_static("dddd"));
        assert_eq!(header_bytes(&map), 1 + 2 + 3 + 4);
    }

    #[test]
    fn strip_hop_by_hop_removes_fixed_and_connection_named() {
        let mut map = HeaderMap::new();
        for (name, value) in [
            ("connection", "keep-alive, X-Custom-Hop"),
            ("keep-alive", "timeout=5"),
            ("x-custom-hop", "secret"),
            ("content-type", "text/plain"),
        ] {
            map.insert(name, HeaderValue::from_static(value));
        }

        strip_hop_by_hop(&mut map);

        for gone in ["connection", "keep-alive", "x-custom-hop"] {
            assert!(!map.contains_key(gone), "{gone} should have been stripped");
        }
        assert!(map.contains_key("content-type"));
    }

    #[test]
    fn forwarded_proto_reflects_tls_and_trust() {
        let mut cfg = Config::default();
        let xfp = |value: &str| one_header("x-forwarded-proto", value);

        // TLS enabled: always https, whatever the header says.
        cfg.tls.enabled = true;
        assert_eq!(forwarded_proto(&cfg, &xfp("http")), "https");

        // Untrusted: the header is ignored.
        cfg.tls.enabled = false;
        cfg.server.trust_forwarded_for = false;
        assert_eq!(forwarded_proto(&cfg, &xfp("https")), "http");

        // Trusted: the first entry is honoured; missing or garbage falls back to http.
        cfg.server.trust_forwarded_for = true;
        assert_eq!(forwarded_proto(&cfg, &xfp("https")), "https");
        assert_eq!(forwarded_proto(&cfg, &xfp("http, https")), "http");
        assert_eq!(forwarded_proto(&cfg, &HeaderMap::new()), "http");
        assert_eq!(forwarded_proto(&cfg, &xfp("garbage")), "http");
    }

    #[test]
    fn longest_route_picks_most_specific_prefix() {
        let route = |prefix: &str| RouteLimiter {
            prefix: prefix.to_owned(),
            limiter: Arc::new(RateLimiter::keyed(governor::Quota::per_second(
                std::num::NonZeroU32::new(1).unwrap(),
            ))),
        };
        let routes = vec![route("/api/"), route("/api/admin/")];
        let pick = |path: &str| longest_route(&routes, path).map(|r| r.prefix.as_str());

        assert_eq!(pick("/api/admin/users"), Some("/api/admin/"));
        assert_eq!(pick("/api/things"), Some("/api/"));
        assert_eq!(pick("/public"), None);
    }

    #[test]
    fn harden_cookie_adds_missing_flags() {
        let out = harden_cookie("sid=abc");
        for flag in ["; Secure", "; HttpOnly", "; SameSite=Lax"] {
            assert!(out.contains(flag), "{out}");
        }
    }

    #[test]
    fn harden_cookie_preserves_existing_attributes() {
        let out = harden_cookie("sid=abc; HttpOnly; SameSite=Strict");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("SameSite=Strict"), "{out}");
        assert!(!out.contains("SameSite=Lax"), "{out}");
        // The existing HttpOnly must not be duplicated.
        assert_eq!(out.matches("HttpOnly").count(), 1, "{out}");
    }

    #[test]
    fn harden_cookie_value_resembling_an_attr_is_not_skipped() {
        // "securetoken" is part of the cookie value, not a `Secure` attribute.
        let out = harden_cookie("session=securetoken");
        assert!(out.contains("; Secure"), "{out}");
    }

    #[test]
    fn security_headers_reflects_config_toggles() {
        let defaults = security_headers(&HeadersCfg::default());
        let has = |name: &str| defaults.iter().any(|(n, _)| *n == name);
        for expected in [
            "X-Content-Type-Options",
            "X-Frame-Options",
            "Referrer-Policy",
            "Permissions-Policy",
            "Content-Security-Policy",
            "Strict-Transport-Security",
        ] {
            assert!(has(expected), "missing {expected}");
        }
        assert!(!has("Content-Security-Policy-Report-Only"));

        let custom = HeadersCfg {
            hsts: false,
            frame_options: String::new(),
            csp: "default-src 'self'".into(),
            csp_report_only: true,
            csp_report_uri: "/__edgeguard/csp-report".into(),
            ..HeadersCfg::default()
        };
        let by_name: std::collections::HashMap<&str, String> =
            security_headers(&custom).into_iter().collect();
        assert!(!by_name.contains_key("Strict-Transport-Security"));
        assert!(!by_name.contains_key("X-Frame-Options"));
        assert!(!by_name.contains_key("Content-Security-Policy"));
        assert_eq!(
            by_name
                .get("Content-Security-Policy-Report-Only")
                .map(String::as_str),
            Some("default-src 'self'; report-uri /__edgeguard/csp-report")
        );
    }
}