1use std::future::Future;
13use std::net::{IpAddr, SocketAddr};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
17use arc_swap::ArcSwap;
18use axum::{
19 body::{Body, Bytes},
20 extract::{ConnectInfo, State},
21 http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode},
22};
23use governor::{clock::DefaultClock, state::keyed::DefaultKeyedStateStore, RateLimiter};
24use http_body_util::{BodyExt, Full, Limited};
25use hyper_util::client::legacy::{connect::HttpConnector, Client};
26use tokio::net::TcpStream;
27use tracing::{debug, info, warn};
28
29use crate::auth::{AuthEngine, Challenge, Decision};
30use crate::config::{Config, HeadersCfg};
31use crate::limiter::{Admit, DistributedLimiter};
32use crate::metrics::Metrics;
33use crate::waf::{WafEngine, WafMode};
34
/// In-process rate limiter keyed by client IP; used for both the global
/// per-IP limit (`Runtime::ip_limiter`) and per-route limits (`RouteLimiter`).
pub type KeyedLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
/// In-process rate limiter keyed by string; `handle` keys it by the
/// authenticated principal returned from the auth engine.
pub type StrLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
/// Client used to forward requests upstream over a plain `HttpConnector`;
/// request bodies are sent fully buffered as `Full<Bytes>`.
pub type UpstreamClient = Client<HttpConnector, Full<Bytes>>;
39
/// Shared state handed to every axum handler in this module.
#[derive(Clone)]
pub struct AppState {
    /// Client used to forward requests to the upstream.
    pub client: UpstreamClient,
    /// Request/latency/usage counters rendered by `metrics_handler`.
    pub metrics: Arc<Metrics>,
    /// Atomically swappable runtime; handlers take a snapshot of it per request.
    pub runtime: Arc<ArcSwap<Runtime>>,
    /// Control-plane client, if configured; `csp_report` uses it to forward reports.
    pub cp: Option<Arc<crate::cp::CpClient>>,
}
51
/// A per-client-IP rate limiter that applies to requests whose path starts
/// with `prefix`. When several prefixes match, `longest_route` picks the
/// longest one.
pub struct RouteLimiter {
    /// Prefix matched against the request path *including* its query string.
    pub prefix: String,
    /// Limiter keyed by client IP.
    pub limiter: Arc<KeyedLimiter>,
}
57
/// Everything the request path needs that is derived from a [`Config`].
/// Held behind `AppState::runtime` so it can be replaced atomically
/// (presumably on config reload — TODO confirm against the builder).
pub struct Runtime {
    pub cfg: Arc<Config>,
    /// Upstream origin; `handle` appends the request path-and-query verbatim.
    pub upstream_base: Arc<String>,
    pub auth: AuthEngine,
    pub waf: WafEngine,
    /// When set, replaces the in-process IP/route/key limiters below.
    pub distributed: Option<DistributedLimiter>,
    /// Per-IP limiter used when no entry in `route_limiters` matches the path.
    pub ip_limiter: Option<Arc<KeyedLimiter>>,
    /// Prefix-scoped per-IP limiters; the longest matching prefix wins.
    pub route_limiters: Vec<RouteLimiter>,
    /// Per-principal limiter, checked after successful authentication whenever
    /// present (it is not gated by `cfg.ratelimit.enabled`).
    pub key_limiter: Option<Arc<StrLimiter>>,
    /// Byte limit applied while reading the request body.
    pub max_body: usize,
    /// Byte limit for the buffered upstream response body; 0 disables it.
    pub max_response_body: usize,
    /// Limit on summed header name+value bytes; 0 disables the check.
    pub max_header_bytes: usize,
    /// Single deadline covering the upstream request and the response-body
    /// read; `None` means no timeout.
    pub upstream_timeout: Option<Duration>,
}
83
/// Hop-by-hop headers that must not be forwarded in either direction
/// (RFC 9110 §7.6.1 plus the legacy list from RFC 2616 §13.5.1). Headers named
/// inside a `Connection` header are removed separately by `strip_hop_by_hop`.
///
/// `proxy-connection` is non-standard but still emitted by some clients; like
/// Go's `httputil.ReverseProxy`, it is treated as hop-by-hop so it never
/// reaches the upstream.
///
/// All names are lowercase, as `HeaderMap` stores them.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
95
96pub async fn handle(
97 State(state): State<AppState>,
98 ConnectInfo(peer): ConnectInfo<SocketAddr>,
99 req: Request<Body>,
100) -> Response<Body> {
101 let started = Instant::now();
102 let rt = state.runtime.load();
105 let m = &state.metrics;
106
107 let method = req.method().clone();
108 let path = req
109 .uri()
110 .path_and_query()
111 .map(|p| p.as_str().to_string())
112 .unwrap_or_else(|| req.uri().path().to_string());
113
114 let ip = client_ip(req.headers(), peer, rt.cfg.server.trust_forwarded_for);
115
116 if req.uri().path().starts_with("/__edgeguard/") {
122 return finish(
123 m,
124 &method,
125 &path,
126 ip,
127 started,
128 "not_found",
129 text(StatusCode::NOT_FOUND, "Not Found"),
130 );
131 }
132
133 if rt.max_header_bytes > 0 && header_bytes(req.headers()) > rt.max_header_bytes {
135 return finish(
136 m,
137 &method,
138 &path,
139 ip,
140 started,
141 "header_too_large",
142 text(
143 StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
144 "Request Header Fields Too Large",
145 ),
146 );
147 }
148
149 if rt.cfg.ratelimit.enabled {
153 if let Some(d) = &rt.distributed {
154 match d.check_ip_route(ip, &path).await {
155 Admit::Allowed => {}
156 Admit::Limited(scope) => {
157 m.record_ratelimit_hit(scope);
158 return finish(
159 m,
160 &method,
161 &path,
162 ip,
163 started,
164 "rate_limited",
165 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
166 );
167 }
168 Admit::Error => {
169 return finish(
170 m,
171 &method,
172 &path,
173 ip,
174 started,
175 "limiter_error",
176 text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
177 );
178 }
179 }
180 } else {
181 let (limiter, scope) = match longest_route(&rt.route_limiters, &path) {
182 Some(r) => (Some(r.limiter.as_ref()), "route"),
183 None => (rt.ip_limiter.as_deref(), "ip"),
184 };
185 if let Some(limiter) = limiter {
186 if limiter.check_key(&ip).is_err() {
187 m.record_ratelimit_hit(scope);
188 return finish(
189 m,
190 &method,
191 &path,
192 ip,
193 started,
194 "rate_limited",
195 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
196 );
197 }
198 }
199 }
200 }
201
202 let principal = match rt.auth.authorize(&rt.cfg.auth, req.headers()).await {
204 Decision::Allow(principal) => principal,
205 Decision::Deny(challenge) => {
206 let mut resp = text(StatusCode::UNAUTHORIZED, "Unauthorized");
207 let challenge_value = match challenge {
208 Challenge::Basic(c) => Some(c),
209 Challenge::Bearer => Some("Bearer".to_string()),
210 Challenge::None => None,
211 };
212 if let Some(c) = challenge_value {
213 if let Ok(v) = HeaderValue::from_str(&c) {
214 resp.headers_mut().insert(header::WWW_AUTHENTICATE, v);
215 }
216 }
217 return finish(m, &method, &path, ip, started, "unauthorized", resp);
218 }
219 };
220
221 if let Some(principal) = &principal {
224 let key_admit = if let Some(d) = &rt.distributed {
225 Some(d.check_key(principal).await)
226 } else {
227 rt.key_limiter.as_ref().map(|limiter| {
228 if limiter.check_key(principal).is_err() {
229 Admit::Limited("key")
230 } else {
231 Admit::Allowed
232 }
233 })
234 };
235 match key_admit {
236 Some(Admit::Limited(scope)) => {
237 m.record_ratelimit_hit(scope);
238 return finish(
239 m,
240 &method,
241 &path,
242 ip,
243 started,
244 "rate_limited",
245 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
246 );
247 }
248 Some(Admit::Error) => {
249 return finish(
250 m,
251 &method,
252 &path,
253 ip,
254 started,
255 "limiter_error",
256 text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
257 );
258 }
259 Some(Admit::Allowed) | None => {}
260 }
261 }
262
263 let allow = &rt.cfg.validation.allow_methods;
265 if !allow.is_empty()
266 && !allow
267 .iter()
268 .any(|x| x.eq_ignore_ascii_case(method.as_str()))
269 {
270 return finish(
271 m,
272 &method,
273 &path,
274 ip,
275 started,
276 "method_not_allowed",
277 text(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed"),
278 );
279 }
280
281 let (parts, body) = req.into_parts();
283 let body_bytes = match axum::body::to_bytes(body, rt.max_body).await {
284 Ok(b) => b,
285 Err(_) => {
286 return finish(
287 m,
288 &method,
289 &path,
290 ip,
291 started,
292 "payload_too_large",
293 text(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large"),
294 )
295 }
296 };
297 let ingress_bytes = header_bytes(&parts.headers).saturating_add(body_bytes.len());
299
300 if let Some(hit) = rt.waf.evaluate(&path, &parts.headers, &body_bytes) {
305 m.record_waf_hit(hit.class);
306 match rt.waf.mode() {
307 WafMode::Block => {
308 warn!(
309 rule = %hit.rule_id,
310 class = hit.class,
311 location = hit.location,
312 client_ip = %ip,
313 path = %path,
314 "WAF blocked request"
315 );
316 return finish(
317 m,
318 &method,
319 &path,
320 ip,
321 started,
322 "forbidden",
323 text(StatusCode::FORBIDDEN, "Forbidden"),
324 );
325 }
326 WafMode::Report => warn!(
327 rule = %hit.rule_id,
328 class = hit.class,
329 location = hit.location,
330 client_ip = %ip,
331 path = %path,
332 "WAF rule matched (report-only)"
333 ),
334 WafMode::Off => {}
337 }
338 }
339
340 let uri = format!("{}{}", rt.upstream_base, path);
342 let mut up = Request::builder().method(parts.method.clone()).uri(&uri);
343 {
344 let headers = up.headers_mut().expect("builder headers");
345 let mut forwarded = parts.headers.clone();
348 strip_hop_by_hop(&mut forwarded);
349 for (name, value) in forwarded.iter() {
350 if name == header::HOST {
351 continue; }
353 headers.insert(name.clone(), value.clone());
354 }
355 if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
357 headers.insert(HeaderName::from_static("x-forwarded-for"), v);
358 }
359 headers.insert(
360 HeaderName::from_static("x-forwarded-proto"),
361 HeaderValue::from_static(forwarded_proto(&rt.cfg, &parts.headers)),
362 );
363 }
364
365 let upstream_req = match up.body(Full::new(body_bytes)) {
366 Ok(r) => r,
367 Err(e) => {
368 warn!(error = %e, "failed to build upstream request");
369 return finish(
370 m,
371 &method,
372 &path,
373 ip,
374 started,
375 "bad_gateway",
376 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
377 );
378 }
379 };
380
381 let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
384 let timed_out = || {
385 warn!(upstream = %uri, "upstream timed out");
386 text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")
387 };
388
389 let upstream_resp = match within(deadline, state.client.request(upstream_req)).await {
390 Ok(Ok(r)) => r,
391 Ok(Err(e)) => {
392 warn!(error = %e, upstream = %uri, "upstream unreachable");
393 return finish(
394 m,
395 &method,
396 &path,
397 ip,
398 started,
399 "upstream_error",
400 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
401 );
402 }
403 Err(_) => {
404 return finish(
405 m,
406 &method,
407 &path,
408 ip,
409 started,
410 "upstream_timeout",
411 timed_out(),
412 )
413 }
414 };
415
416 let (mut resp_parts, resp_body) = upstream_resp.into_parts();
417 let resp_bytes = if rt.max_response_body > 0 {
419 match within(
420 deadline,
421 Limited::new(resp_body, rt.max_response_body).collect(),
422 )
423 .await
424 {
425 Ok(Ok(c)) => c.to_bytes(),
426 Ok(Err(_)) => {
427 warn!(
428 limit = rt.max_response_body,
429 "upstream response exceeded max_response_body"
430 );
431 return finish(
432 m,
433 &method,
434 &path,
435 ip,
436 started,
437 "upstream_body_too_large",
438 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
439 );
440 }
441 Err(_) => {
442 return finish(
443 m,
444 &method,
445 &path,
446 ip,
447 started,
448 "upstream_timeout",
449 timed_out(),
450 )
451 }
452 }
453 } else {
454 match within(deadline, resp_body.collect()).await {
455 Ok(Ok(c)) => c.to_bytes(),
456 Ok(Err(e)) => {
457 warn!(error = %e, "failed reading upstream body");
458 return finish(
459 m,
460 &method,
461 &path,
462 ip,
463 started,
464 "upstream_body_error",
465 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
466 );
467 }
468 Err(_) => {
469 return finish(
470 m,
471 &method,
472 &path,
473 ip,
474 started,
475 "upstream_timeout",
476 timed_out(),
477 )
478 }
479 }
480 };
481
482 strip_hop_by_hop(&mut resp_parts.headers);
485 resp_parts.headers.remove(header::CONTENT_LENGTH);
486
487 m.add_usage_bytes(
490 ingress_bytes,
491 header_bytes(&resp_parts.headers).saturating_add(resp_bytes.len()),
492 );
493
494 let mut response = Response::from_parts(resp_parts, Body::from(resp_bytes));
495 harden_response(&rt.cfg, &mut response);
496
497 finish(m, &method, &path, ip, started, "ok", response)
498}
499
500pub async fn ready(State(state): State<AppState>) -> StatusCode {
505 let rt = state.runtime.load();
506 let Some((host, port)) = rt.cfg.upstream_probe_addr() else {
507 return StatusCode::SERVICE_UNAVAILABLE;
508 };
509 match tokio::time::timeout(
510 Duration::from_secs(2),
511 TcpStream::connect((host.as_str(), port)),
512 )
513 .await
514 {
515 Ok(Ok(_)) => StatusCode::OK,
516 _ => StatusCode::SERVICE_UNAVAILABLE,
517 }
518}
519
520pub async fn metrics_handler(State(state): State<AppState>) -> Response<Body> {
524 let body = state.metrics.render();
525 let mut resp = Response::new(Body::from(body));
526 resp.headers_mut().insert(
527 header::CONTENT_TYPE,
528 HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8"),
529 );
530 resp
531}
532
533pub async fn csp_report(State(state): State<AppState>, body: Bytes) -> StatusCode {
536 state.metrics.record_csp_report();
537 if let Some(cp) = &state.cp {
541 if state.runtime.load().cfg.control_plane.forward_csp {
542 let cp = cp.clone();
543 let raw = body.clone();
544 tokio::spawn(async move { cp.forward_csp(&raw).await });
545 }
546 }
547 match serde_json::from_slice::<serde_json::Value>(&body) {
551 Ok(report) => {
552 let directive = report
553 .get("csp-report")
554 .and_then(|r| {
555 r.get("violated-directive")
556 .or_else(|| r.get("effective-directive"))
557 })
558 .and_then(|v| v.as_str())
559 .unwrap_or("unknown");
560 debug!(target: "edgeguard::csp", directive, "CSP violation report");
561 }
562 Err(_) => warn!(
563 bytes = body.len(),
564 "CSP violation report with an unparseable body"
565 ),
566 }
567 StatusCode::NO_CONTENT
568}
569
570fn client_ip(headers: &HeaderMap, peer: SocketAddr, trust_forwarded: bool) -> IpAddr {
574 if trust_forwarded {
575 if let Some(xff) = headers.get("x-forwarded-for") {
576 if let Ok(s) = xff.to_str() {
577 if let Some(first) = s.split(',').next() {
578 if let Ok(ip) = first.trim().parse::<IpAddr>() {
579 return ip;
580 }
581 }
582 }
583 }
584 }
585 peer.ip()
586}
587
588fn header_bytes(headers: &HeaderMap) -> usize {
591 headers
592 .iter()
593 .map(|(name, value)| name.as_str().len() + value.as_bytes().len())
594 .sum()
595}
596
597fn strip_hop_by_hop(headers: &mut HeaderMap) {
601 let connection_named: Vec<HeaderName> = headers
604 .get_all(header::CONNECTION)
605 .iter()
606 .filter_map(|v| v.to_str().ok())
607 .flat_map(|v| v.split(','))
608 .filter_map(|token| HeaderName::from_bytes(token.trim().as_bytes()).ok())
609 .collect();
610 for name in HOP_BY_HOP {
611 headers.remove(*name);
612 }
613 for name in connection_named {
614 headers.remove(name);
615 }
616}
617
618fn forwarded_proto(cfg: &Config, headers: &HeaderMap) -> &'static str {
624 if cfg.tls.enabled {
625 return "https";
626 }
627 if cfg.server.trust_forwarded_for {
628 if let Some(value) = headers
629 .get("x-forwarded-proto")
630 .and_then(|v| v.to_str().ok())
631 {
632 match value.split(',').next().map(str::trim) {
633 Some(p) if p.eq_ignore_ascii_case("https") => return "https",
634 Some(p) if p.eq_ignore_ascii_case("http") => return "http",
635 _ => {}
636 }
637 }
638 }
639 "http"
640}
641
642fn longest_route<'a>(routes: &'a [RouteLimiter], path: &str) -> Option<&'a RouteLimiter> {
644 routes
645 .iter()
646 .filter(|r| path.starts_with(&r.prefix))
647 .max_by_key(|r| r.prefix.len())
648}
649
/// `Strict-Transport-Security` value emitted by `security_headers` when
/// `hsts` is enabled: two years (63072000 s), including subdomains.
pub const HSTS_VALUE: &str = "max-age=63072000; includeSubDomains";
654
655pub fn security_headers(cfg: &HeadersCfg) -> Vec<(&'static str, String)> {
667 let mut out: Vec<(&'static str, String)> = Vec::with_capacity(6);
668 out.push(("X-Content-Type-Options", "nosniff".to_string()));
669 if !cfg.frame_options.is_empty() {
670 out.push(("X-Frame-Options", cfg.frame_options.clone()));
671 }
672 if !cfg.referrer_policy.is_empty() {
673 out.push(("Referrer-Policy", cfg.referrer_policy.clone()));
674 }
675 if !cfg.permissions_policy.is_empty() {
676 out.push(("Permissions-Policy", cfg.permissions_policy.clone()));
677 }
678 if !cfg.csp.is_empty() {
679 let mut value = cfg.csp.clone();
681 if !cfg.csp_report_uri.is_empty() {
682 value.push_str("; report-uri ");
683 value.push_str(&cfg.csp_report_uri);
684 }
685 let name = if cfg.csp_report_only {
686 "Content-Security-Policy-Report-Only"
687 } else {
688 "Content-Security-Policy"
689 };
690 out.push((name, value));
691 }
692 if cfg.hsts {
693 out.push(("Strict-Transport-Security", HSTS_VALUE.to_string()));
694 }
695 out
696}
697
698fn harden_response(cfg: &Config, resp: &mut Response<Body>) {
700 let h = resp.headers_mut();
701
702 for (name, value) in security_headers(&cfg.headers) {
706 if let (Ok(n), Ok(v)) = (
707 HeaderName::from_bytes(name.as_bytes()),
708 HeaderValue::from_str(&value),
709 ) {
710 h.insert(n, v);
711 }
712 }
713
714 for name in &cfg.headers.strip {
716 if let Ok(hn) = HeaderName::from_bytes(name.as_bytes()) {
717 h.remove(hn);
718 }
719 }
720
721 if cfg.headers.force_secure_cookies {
723 let cookies: Vec<HeaderValue> = h.get_all(header::SET_COOKIE).iter().cloned().collect();
724 if !cookies.is_empty() {
725 h.remove(header::SET_COOKIE);
726 for c in cookies {
727 if let Ok(s) = c.to_str() {
728 let hardened = harden_cookie(s);
729 if let Ok(v) = HeaderValue::from_str(&hardened) {
730 h.append(header::SET_COOKIE, v);
731 }
732 } else {
733 h.append(header::SET_COOKIE, c);
734 }
735 }
736 }
737 }
738}
739
/// Adds `Secure`, `HttpOnly` and `SameSite=Lax` to a `Set-Cookie` value
/// unless an attribute of that name is already present.
///
/// Attribute names are matched case-insensitively. The leading `name=value`
/// pair is never treated as an attribute, so a cookie called `secure=…` still
/// gets the flag. Trailing `;` and whitespace are trimmed first; otherwise an
/// input like `"sid=abc; "` would yield an empty attribute (`"; ; Secure"`).
fn harden_cookie(cookie: &str) -> String {
    let base = cookie.trim_end_matches(|c: char| c == ';' || c.is_ascii_whitespace());

    // Scans the attribute segments (everything after the first `;`) for a name.
    let has_attr = |wanted: &str| {
        base.split(';')
            .skip(1)
            .filter_map(|attr| attr.split('=').next())
            .any(|name| name.trim().eq_ignore_ascii_case(wanted))
    };

    let mut out = String::with_capacity(base.len() + "; Secure; HttpOnly; SameSite=Lax".len());
    out.push_str(base);
    if !has_attr("secure") {
        out.push_str("; Secure");
    }
    if !has_attr("httponly") {
        out.push_str("; HttpOnly");
    }
    if !has_attr("samesite") {
        out.push_str("; SameSite=Lax");
    }
    out
}
763
764async fn within<F: Future>(
767 deadline: Option<tokio::time::Instant>,
768 fut: F,
769) -> Result<F::Output, tokio::time::error::Elapsed> {
770 match deadline {
771 Some(dl) => tokio::time::timeout_at(dl, fut).await,
772 None => Ok(fut.await),
773 }
774}
775
776fn text(status: StatusCode, msg: &str) -> Response<Body> {
777 let mut resp = Response::new(Body::from(msg.to_string()));
778 *resp.status_mut() = status;
779 resp.headers_mut().insert(
780 header::CONTENT_TYPE,
781 HeaderValue::from_static("text/plain; charset=utf-8"),
782 );
783 resp
784}
785
786fn finish(
788 metrics: &Metrics,
789 method: &Method,
790 path: &str,
791 ip: IpAddr,
792 started: Instant,
793 outcome: &str,
794 resp: Response<Body>,
795) -> Response<Body> {
796 let elapsed = started.elapsed();
797 info!(
798 %method,
799 path = %path,
800 client_ip = %ip,
801 status = resp.status().as_u16(),
802 outcome,
803 latency_ms = elapsed.as_millis() as u64,
804 "request"
805 );
806 metrics.record_request(outcome);
807 metrics.observe_latency(elapsed);
808 metrics.add_usage_request();
811 resp
812}
813
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HeaderMap` holding exactly one header.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(name, HeaderValue::from_str(value).unwrap());
        map
    }

    /// Socket peer shared by the `client_ip` tests.
    fn test_peer() -> SocketAddr {
        "203.0.113.9:55000".parse().unwrap()
    }

    #[test]
    fn client_ip_ignores_xff_when_untrusted() {
        let peer = test_peer();
        let spoofed = headers_with("x-forwarded-for", "1.2.3.4");
        // Without trust, a client must not be able to choose its own address.
        assert_eq!(client_ip(&spoofed, peer, false), peer.ip());
    }

    #[test]
    fn client_ip_uses_first_xff_hop_when_trusted() {
        let chain = headers_with("x-forwarded-for", "1.2.3.4, 5.6.7.8");
        assert_eq!(client_ip(&chain, test_peer(), true).to_string(), "1.2.3.4");
    }

    #[test]
    fn client_ip_falls_back_to_peer_on_missing_or_garbage_xff() {
        let peer = test_peer();
        assert_eq!(client_ip(&HeaderMap::new(), peer, true), peer.ip());
        let garbage = headers_with("x-forwarded-for", "not-an-ip");
        assert_eq!(client_ip(&garbage, peer, true), peer.ip());
    }

    #[test]
    fn header_bytes_sums_names_and_values() {
        let mut map = HeaderMap::new();
        map.insert("a", HeaderValue::from_static("bb"));
        map.insert("ccc", HeaderValue::from_static("dddd"));
        assert_eq!(header_bytes(&map), 1 + 2 + 3 + 4);
    }

    #[test]
    fn strip_hop_by_hop_removes_fixed_and_connection_named() {
        let mut map = HeaderMap::new();
        map.insert(
            "connection",
            HeaderValue::from_static("keep-alive, X-Custom-Hop"),
        );
        map.insert("keep-alive", HeaderValue::from_static("timeout=5"));
        map.insert("x-custom-hop", HeaderValue::from_static("secret"));
        map.insert("content-type", HeaderValue::from_static("text/plain"));

        strip_hop_by_hop(&mut map);

        for gone in ["connection", "keep-alive", "x-custom-hop"] {
            assert!(!map.contains_key(gone), "{gone} should have been stripped");
        }
        // End-to-end headers survive.
        assert!(map.contains_key("content-type"));
    }

    #[test]
    fn forwarded_proto_reflects_tls_and_trust() {
        let mut cfg = Config::default();

        // Local TLS termination wins regardless of what the client claims.
        cfg.tls.enabled = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http")),
            "https"
        );

        // Untrusted: the header is ignored.
        cfg.tls.enabled = false;
        cfg.server.trust_forwarded_for = false;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "http"
        );

        // Trusted: only the first listed hop counts; anything unknown is http.
        cfg.server.trust_forwarded_for = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "https"
        );
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http, https")),
            "http"
        );
        assert_eq!(forwarded_proto(&cfg, &HeaderMap::new()), "http");
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "garbage")),
            "http"
        );
    }

    #[test]
    fn longest_route_picks_most_specific_prefix() {
        let route = |prefix: &str| RouteLimiter {
            prefix: prefix.to_string(),
            limiter: Arc::new(RateLimiter::keyed(governor::Quota::per_second(
                std::num::NonZeroU32::new(1).unwrap(),
            ))),
        };
        let routes = vec![route("/api/"), route("/api/admin/")];

        assert_eq!(
            longest_route(&routes, "/api/admin/users").map(|r| r.prefix.as_str()),
            Some("/api/admin/")
        );
        assert_eq!(
            longest_route(&routes, "/api/things").map(|r| r.prefix.as_str()),
            Some("/api/")
        );
        assert!(longest_route(&routes, "/public").is_none());
    }

    #[test]
    fn harden_cookie_adds_missing_flags() {
        let out = harden_cookie("sid=abc");
        for flag in ["; Secure", "; HttpOnly", "; SameSite=Lax"] {
            assert!(out.contains(flag), "{out}");
        }
    }

    #[test]
    fn harden_cookie_preserves_existing_attributes() {
        let out = harden_cookie("sid=abc; HttpOnly; SameSite=Strict");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("SameSite=Strict"), "{out}");
        assert!(!out.contains("SameSite=Lax"), "{out}");
        // An existing attribute must not be duplicated.
        assert_eq!(out.matches("HttpOnly").count(), 1, "{out}");
    }

    #[test]
    fn harden_cookie_value_resembling_an_attr_is_not_skipped() {
        // "securetoken" is part of the cookie value, not a `Secure` attribute.
        let out = harden_cookie("session=securetoken");
        assert!(out.contains("; Secure"), "{out}");
    }

    #[test]
    fn security_headers_reflects_config_toggles() {
        let defaults = security_headers(&HeadersCfg::default());
        let names: Vec<&str> = defaults.iter().map(|(n, _)| *n).collect();
        for expected in [
            "X-Content-Type-Options",
            "X-Frame-Options",
            "Referrer-Policy",
            "Permissions-Policy",
            "Content-Security-Policy",
            "Strict-Transport-Security",
        ] {
            assert!(names.contains(&expected), "missing {expected}");
        }
        assert!(!names.contains(&"Content-Security-Policy-Report-Only"));

        // HSTS and frame options off, CSP in report-only mode with a report URI.
        let cfg = HeadersCfg {
            hsts: false,
            frame_options: String::new(),
            csp: "default-src 'self'".into(),
            csp_report_only: true,
            csp_report_uri: "/__edgeguard/csp-report".into(),
            ..HeadersCfg::default()
        };
        let map: std::collections::HashMap<&str, String> =
            security_headers(&cfg).into_iter().collect();
        assert!(!map.contains_key("Strict-Transport-Security"));
        assert!(!map.contains_key("X-Frame-Options"));
        assert!(!map.contains_key("Content-Security-Policy"));
        assert_eq!(
            map.get("Content-Security-Policy-Report-Only")
                .map(String::as_str),
            Some("default-src 'self'; report-uri /__edgeguard/csp-report")
        );
    }
}