1use std::future::Future;
13use std::net::{IpAddr, SocketAddr};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::{Duration, Instant};
18
19use arc_swap::ArcSwap;
20use axum::{
21 body::{Body, Bytes},
22 extract::{ConnectInfo, State},
23 http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode},
24};
25use governor::{clock::DefaultClock, state::keyed::DefaultKeyedStateStore, RateLimiter};
26use http_body_util::{BodyExt, Full, Limited};
27use hyper::body::{Body as HttpBody, Frame, SizeHint};
28use hyper_util::client::legacy::{connect::HttpConnector, Client};
29use hyper_util::rt::TokioIo;
30use tokio::net::TcpStream;
31use tracing::{debug, info, warn};
32
33use crate::auth::{AuthEngine, Challenge, Decision};
34use crate::config::{Config, HeadersCfg};
35use crate::limiter::{Admit, DistributedLimiter};
36use crate::metrics::Metrics;
37use crate::waf::{WafEngine, WafMode};
38
/// In-process rate limiter keyed by client IP address (used for the per-IP
/// and per-route limits).
pub type KeyedLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
/// In-process rate limiter keyed by a string; used here to limit requests per
/// authenticated principal.
pub type StrLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
/// HTTP client used to reach upstreams. Request bodies are sent fully
/// buffered as `Full<Bytes>`.
pub type UpstreamClient = Client<HttpConnector, Full<Bytes>>;
43
/// State shared by every handler in this module (cloned per request by axum).
#[derive(Clone)]
pub struct AppState {
    /// Client used for all upstream requests, buffered and upgrade alike.
    pub client: UpstreamClient,
    /// Sink for request outcomes, latency, usage bytes and WAF/limiter hits.
    pub metrics: Arc<Metrics>,
    /// Current runtime snapshot behind an `ArcSwap`; handlers load it per request.
    pub runtime: Arc<ArcSwap<Runtime>>,
    /// Optional control-plane client; this module uses it to forward CSP reports.
    pub cp: Option<Arc<crate::cp::CpClient>>,
    /// Control-plane quota state, consulted when `control_plane.enforce_quota` is set.
    pub quota: Arc<crate::cp::QuotaState>,
}
59
/// A per-client-IP rate limiter that applies to requests whose path matches
/// `prefix` (the longest matching prefix wins over the plain per-IP limiter).
pub struct RouteLimiter {
    /// Path prefix this limiter is attached to.
    pub prefix: String,
    /// Limiter keyed by client IP.
    pub limiter: Arc<KeyedLimiter>,
}
65
/// Everything a request needs that is derived from configuration; handlers
/// load one snapshot of this per request from `AppState::runtime`.
pub struct Runtime {
    /// Configuration consulted directly by the handlers (rate limiting, auth,
    /// validation, headers, TLS, control plane).
    pub cfg: Arc<Config>,
    /// Default upstream base, prepended to the request path when no entry in
    /// `upstream_routes` matches.
    pub upstream_base: Arc<String>,
    /// `(path prefix, upstream base)` pairs; the longest prefix that matches on
    /// a segment boundary wins (see `pick_upstream`).
    pub upstream_routes: Vec<(String, Arc<String>)>,
    /// Authentication engine, evaluated after the rate-limit and preflight checks.
    pub auth: AuthEngine,
    /// WAF evaluated against the path, request headers and buffered body.
    pub waf: WafEngine,
    /// CORS policy; `None` disables preflight answers and response decoration.
    pub cors: Option<crate::cors::CorsPolicy>,
    /// IP access policy; `None` disables the check.
    pub access: Option<crate::access::AccessPolicy>,
    /// Shared rate limiter. When present it is used instead of the in-process
    /// `ip_limiter`, `route_limiters` and `key_limiter`.
    pub distributed: Option<DistributedLimiter>,
    /// In-process per-IP limiter, used when no route limiter matches.
    pub ip_limiter: Option<Arc<KeyedLimiter>>,
    /// In-process per-route limiters keyed by client IP.
    pub route_limiters: Vec<RouteLimiter>,
    /// In-process per-principal limiter, applied after authentication.
    pub key_limiter: Option<Arc<StrLimiter>>,
    /// Request body limit in bytes, passed straight to `axum::body::to_bytes`.
    pub max_body: usize,
    /// Buffered upstream response body limit in bytes; `0` disables the limit.
    pub max_response_body: usize,
    /// Limit on summed request header name+value bytes; `0` disables the check.
    pub max_header_bytes: usize,
    /// Deadline for the upstream exchange (response headers plus buffered
    /// body, or the upgrade handshake); `None` means no deadline.
    pub upstream_timeout: Option<Duration>,
    /// Stream `text/event-stream` responses instead of buffering them.
    pub stream_passthrough: bool,
    /// Tunnel `Connection: upgrade` requests (e.g. WebSocket) to the upstream.
    pub websocket_passthrough: bool,
}
106
107impl Runtime {
108 pub fn pick_upstream(&self, path: &str) -> &str {
111 self.upstream_routes
112 .iter()
113 .filter(|(prefix, _)| path_prefix_matches(path, prefix))
114 .max_by_key(|(prefix, _)| prefix.len())
115 .map(|(_, base)| base.as_str())
116 .unwrap_or_else(|| self.upstream_base.as_str())
117 }
118}
119
/// Reports whether `prefix` matches `path` on a path-segment boundary.
///
/// `/` matches everything. Otherwise `path` must start with `prefix`, and
/// either the prefix itself ends in `/`, or the remainder is empty or begins
/// with `/` or `?`. So `/api` matches `/api`, `/api/x` and `/api?q`, but not
/// `/apiary`.
fn path_prefix_matches(path: &str, prefix: &str) -> bool {
    if prefix == "/" {
        return true;
    }
    let Some(rest) = path.strip_prefix(prefix) else {
        return false;
    };
    prefix.ends_with('/') || matches!(rest.chars().next(), None | Some('/') | Some('?'))
}
139
/// Connection-scoped headers that a proxy must not forward (see RFC 9110
/// §7.6.1). `strip_hop_by_hop` removes these and, in addition, every header
/// named in a `Connection` value.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
151
152pub async fn handle(
153 State(state): State<AppState>,
154 ConnectInfo(peer): ConnectInfo<SocketAddr>,
155 req: Request<Body>,
156) -> Response<Body> {
157 let rt = state.runtime.load_full();
161 let origin = req
165 .headers()
166 .get(header::ORIGIN)
167 .and_then(|v| v.to_str().ok())
168 .map(str::to_owned);
169 let mut resp = handle_inner(&state, &rt, peer, req).await;
170 if let Some(origin) = &origin {
171 if let Some(cors) = &rt.cors {
172 cors.decorate_origin(origin, &mut resp);
173 }
174 }
175 resp
176}
177
178async fn handle_inner(
179 state: &AppState,
180 rt: &Runtime,
181 peer: SocketAddr,
182 req: Request<Body>,
183) -> Response<Body> {
184 let started = Instant::now();
185 let m = &state.metrics;
186
187 let method = req.method().clone();
188 let path = req
189 .uri()
190 .path_and_query()
191 .map(|p| p.as_str().to_string())
192 .unwrap_or_else(|| req.uri().path().to_string());
193
194 let ip = client_ip(req.headers(), peer, rt.cfg.server.trust_forwarded_for);
195 let rid = resolve_request_id(req.headers());
198
199 if req.uri().path().starts_with("/__edgeguard/") {
205 return finish(
206 m,
207 &rid,
208 &method,
209 &path,
210 ip,
211 started,
212 "not_found",
213 text(StatusCode::NOT_FOUND, "Not Found"),
214 );
215 }
216
217 if let Some(access) = &rt.access {
221 if !access.allowed(ip) {
222 return finish(
223 m,
224 &rid,
225 &method,
226 &path,
227 ip,
228 started,
229 "ip_denied",
230 text(StatusCode::FORBIDDEN, "Forbidden"),
231 );
232 }
233 }
234
235 if rt.max_header_bytes > 0 && header_bytes(req.headers()) > rt.max_header_bytes {
237 return finish(
238 m,
239 &rid,
240 &method,
241 &path,
242 ip,
243 started,
244 "header_too_large",
245 text(
246 StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
247 "Request Header Fields Too Large",
248 ),
249 );
250 }
251
252 if rt.cfg.control_plane.enforce_quota && state.quota.blocked() {
258 let mut resp = text(StatusCode::TOO_MANY_REQUESTS, "Quota Exceeded");
259 let reset = state.quota.reset_epoch();
260 if reset > 0 {
261 let now = std::time::SystemTime::now()
262 .duration_since(std::time::UNIX_EPOCH)
263 .map(|d| d.as_secs() as i64)
264 .unwrap_or(0);
265 let retry_after = reset.saturating_sub(now).max(0);
266 if let Ok(v) = HeaderValue::from_str(&retry_after.to_string()) {
267 resp.headers_mut().insert(header::RETRY_AFTER, v);
268 }
269 }
270 return finish(m, &rid, &method, &path, ip, started, "over_quota", resp);
271 }
272
273 if rt.cfg.ratelimit.enabled {
277 if let Some(d) = &rt.distributed {
278 match d.check_ip_route(ip, &path).await {
279 Admit::Allowed => {}
280 Admit::Limited(scope) => {
281 m.record_ratelimit_hit(scope);
282 return finish(
283 m,
284 &rid,
285 &method,
286 &path,
287 ip,
288 started,
289 "rate_limited",
290 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
291 );
292 }
293 Admit::Error => {
294 return finish(
295 m,
296 &rid,
297 &method,
298 &path,
299 ip,
300 started,
301 "limiter_error",
302 text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
303 );
304 }
305 }
306 } else {
307 let (limiter, scope) = match longest_route(&rt.route_limiters, &path) {
308 Some(r) => (Some(r.limiter.as_ref()), "route"),
309 None => (rt.ip_limiter.as_deref(), "ip"),
310 };
311 if let Some(limiter) = limiter {
312 if limiter.check_key(&ip).is_err() {
313 m.record_ratelimit_hit(scope);
314 return finish(
315 m,
316 &rid,
317 &method,
318 &path,
319 ip,
320 started,
321 "rate_limited",
322 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
323 );
324 }
325 }
326 }
327 }
328
329 if method == Method::OPTIONS {
334 if let Some(cors) = &rt.cors {
335 if let Some(resp) = cors.preflight_response(req.headers()) {
336 return finish(m, &rid, &method, &path, ip, started, "cors_preflight", resp);
337 }
338 }
339 }
340
341 let principal = match rt.auth.authorize(&rt.cfg.auth, req.headers()).await {
343 Decision::Allow(principal) => principal,
344 Decision::Deny(challenge) => {
345 let mut resp = text(StatusCode::UNAUTHORIZED, "Unauthorized");
346 let challenge_value = match challenge {
347 Challenge::Basic(c) => Some(c),
348 Challenge::Bearer => Some("Bearer".to_string()),
349 Challenge::None => None,
350 };
351 if let Some(c) = challenge_value {
352 if let Ok(v) = HeaderValue::from_str(&c) {
353 resp.headers_mut().insert(header::WWW_AUTHENTICATE, v);
354 }
355 }
356 return finish(m, &rid, &method, &path, ip, started, "unauthorized", resp);
357 }
358 };
359
360 if let Some(principal) = &principal {
363 let key_admit = if let Some(d) = &rt.distributed {
364 Some(d.check_key(principal).await)
365 } else {
366 rt.key_limiter.as_ref().map(|limiter| {
367 if limiter.check_key(principal).is_err() {
368 Admit::Limited("key")
369 } else {
370 Admit::Allowed
371 }
372 })
373 };
374 match key_admit {
375 Some(Admit::Limited(scope)) => {
376 m.record_ratelimit_hit(scope);
377 return finish(
378 m,
379 &rid,
380 &method,
381 &path,
382 ip,
383 started,
384 "rate_limited",
385 text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
386 );
387 }
388 Some(Admit::Error) => {
389 return finish(
390 m,
391 &rid,
392 &method,
393 &path,
394 ip,
395 started,
396 "limiter_error",
397 text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
398 );
399 }
400 Some(Admit::Allowed) | None => {}
401 }
402 }
403
404 let allow = &rt.cfg.validation.allow_methods;
406 if !allow.is_empty()
407 && !allow
408 .iter()
409 .any(|x| x.eq_ignore_ascii_case(method.as_str()))
410 {
411 return finish(
412 m,
413 &rid,
414 &method,
415 &path,
416 ip,
417 started,
418 "method_not_allowed",
419 text(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed"),
420 );
421 }
422
423 if rt.websocket_passthrough && is_upgrade_request(req.headers()) {
430 return proxy_upgrade(state, rt, req, &rid, &method, &path, ip, started).await;
431 }
432
433 let (parts, body) = req.into_parts();
435 let body_bytes = match axum::body::to_bytes(body, rt.max_body).await {
436 Ok(b) => b,
437 Err(_) => {
438 return finish(
439 m,
440 &rid,
441 &method,
442 &path,
443 ip,
444 started,
445 "payload_too_large",
446 text(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large"),
447 )
448 }
449 };
450 let ingress_bytes = header_bytes(&parts.headers).saturating_add(body_bytes.len());
452
453 if let Some(hit) = rt.waf.evaluate(&path, &parts.headers, &body_bytes) {
458 m.record_waf_hit(hit.class);
459 match rt.waf.mode() {
460 WafMode::Block => {
461 warn!(
462 rule = %hit.rule_id,
463 class = hit.class,
464 location = hit.location,
465 client_ip = %ip,
466 path = %path,
467 "WAF blocked request"
468 );
469 return finish(
470 m,
471 &rid,
472 &method,
473 &path,
474 ip,
475 started,
476 "forbidden",
477 text(StatusCode::FORBIDDEN, "Forbidden"),
478 );
479 }
480 WafMode::Report => warn!(
481 rule = %hit.rule_id,
482 class = hit.class,
483 location = hit.location,
484 client_ip = %ip,
485 path = %path,
486 "WAF rule matched (report-only)"
487 ),
488 WafMode::Off => {}
491 }
492 }
493
494 let uri = format!("{}{}", rt.pick_upstream(&path), path);
496 let mut up = Request::builder().method(parts.method.clone()).uri(&uri);
497 {
498 let headers = up.headers_mut().expect("builder headers");
499 let mut forwarded = parts.headers.clone();
502 strip_hop_by_hop(&mut forwarded);
503 for (name, value) in forwarded.iter() {
504 if name == header::HOST {
505 continue; }
507 headers.insert(name.clone(), value.clone());
508 }
509 if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
511 headers.insert(HeaderName::from_static("x-forwarded-for"), v);
512 }
513 headers.insert(
514 HeaderName::from_static("x-forwarded-proto"),
515 HeaderValue::from_static(forwarded_proto(&rt.cfg, &parts.headers)),
516 );
517 if let Ok(v) = HeaderValue::from_str(&rid) {
519 headers.insert(HeaderName::from_static(REQUEST_ID_HEADER), v);
520 }
521 }
522
523 let upstream_req = match up.body(Full::new(body_bytes)) {
524 Ok(r) => r,
525 Err(e) => {
526 warn!(error = %e, "failed to build upstream request");
527 return finish(
528 m,
529 &rid,
530 &method,
531 &path,
532 ip,
533 started,
534 "bad_gateway",
535 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
536 );
537 }
538 };
539
540 let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
543 let timed_out = || {
544 warn!(upstream = %uri, "upstream timed out");
545 text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")
546 };
547
548 let upstream_resp = match within(deadline, state.client.request(upstream_req)).await {
549 Ok(Ok(r)) => r,
550 Ok(Err(e)) => {
551 warn!(error = %e, upstream = %uri, "upstream unreachable");
552 return finish(
553 m,
554 &rid,
555 &method,
556 &path,
557 ip,
558 started,
559 "upstream_error",
560 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
561 );
562 }
563 Err(_) => {
564 return finish(
565 m,
566 &rid,
567 &method,
568 &path,
569 ip,
570 started,
571 "upstream_timeout",
572 timed_out(),
573 )
574 }
575 };
576
577 let (mut resp_parts, resp_body) = upstream_resp.into_parts();
578
579 if rt.stream_passthrough && is_event_stream(&resp_parts.headers) {
587 strip_hop_by_hop(&mut resp_parts.headers);
588 resp_parts.headers.remove(header::CONTENT_LENGTH);
589 let header_egress = header_bytes(&resp_parts.headers);
590 let body = Body::new(CountingBody::new(
591 resp_body,
592 Arc::clone(m),
593 ingress_bytes,
594 header_egress,
595 ));
596 let mut response = Response::from_parts(resp_parts, body);
597 harden_response(&rt.cfg, &mut response);
598 return finish(m, &rid, &method, &path, ip, started, "ok", response);
600 }
601
602 let resp_bytes = if rt.max_response_body > 0 {
604 match within(
605 deadline,
606 Limited::new(resp_body, rt.max_response_body).collect(),
607 )
608 .await
609 {
610 Ok(Ok(c)) => c.to_bytes(),
611 Ok(Err(_)) => {
612 warn!(
613 limit = rt.max_response_body,
614 "upstream response exceeded max_response_body"
615 );
616 return finish(
617 m,
618 &rid,
619 &method,
620 &path,
621 ip,
622 started,
623 "upstream_body_too_large",
624 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
625 );
626 }
627 Err(_) => {
628 return finish(
629 m,
630 &rid,
631 &method,
632 &path,
633 ip,
634 started,
635 "upstream_timeout",
636 timed_out(),
637 )
638 }
639 }
640 } else {
641 match within(deadline, resp_body.collect()).await {
642 Ok(Ok(c)) => c.to_bytes(),
643 Ok(Err(e)) => {
644 warn!(error = %e, "failed reading upstream body");
645 return finish(
646 m,
647 &rid,
648 &method,
649 &path,
650 ip,
651 started,
652 "upstream_body_error",
653 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
654 );
655 }
656 Err(_) => {
657 return finish(
658 m,
659 &rid,
660 &method,
661 &path,
662 ip,
663 started,
664 "upstream_timeout",
665 timed_out(),
666 )
667 }
668 }
669 };
670
671 strip_hop_by_hop(&mut resp_parts.headers);
674 resp_parts.headers.remove(header::CONTENT_LENGTH);
675
676 m.add_usage_bytes(
679 ingress_bytes,
680 header_bytes(&resp_parts.headers).saturating_add(resp_bytes.len()),
681 );
682
683 let mut response = Response::from_parts(resp_parts, Body::from(resp_bytes));
684 harden_response(&rt.cfg, &mut response);
685 finish(m, &rid, &method, &path, ip, started, "ok", response)
688}
689
690pub async fn ready(State(state): State<AppState>) -> StatusCode {
695 let rt = state.runtime.load();
696 let Some((host, port)) = rt.cfg.upstream_probe_addr() else {
697 return StatusCode::SERVICE_UNAVAILABLE;
698 };
699 match tokio::time::timeout(
700 Duration::from_secs(2),
701 TcpStream::connect((host.as_str(), port)),
702 )
703 .await
704 {
705 Ok(Ok(_)) => StatusCode::OK,
706 _ => StatusCode::SERVICE_UNAVAILABLE,
707 }
708}
709
710pub async fn metrics_handler(State(state): State<AppState>) -> Response<Body> {
714 let body = state.metrics.render();
715 let mut resp = Response::new(Body::from(body));
716 resp.headers_mut().insert(
717 header::CONTENT_TYPE,
718 HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8"),
719 );
720 resp
721}
722
723pub async fn csp_report(State(state): State<AppState>, body: Bytes) -> StatusCode {
726 state.metrics.record_csp_report();
727 if let Some(cp) = &state.cp {
731 if state.runtime.load().cfg.control_plane.forward_csp {
732 let cp = cp.clone();
733 let raw = body.clone();
734 tokio::spawn(async move { cp.forward_csp(&raw).await });
735 }
736 }
737 match serde_json::from_slice::<serde_json::Value>(&body) {
741 Ok(report) => {
742 let directive = report
743 .get("csp-report")
744 .and_then(|r| {
745 r.get("violated-directive")
746 .or_else(|| r.get("effective-directive"))
747 })
748 .and_then(|v| v.as_str())
749 .unwrap_or("unknown");
750 debug!(target: "edgeguard::csp", directive, "CSP violation report");
751 }
752 Err(_) => warn!(
753 bytes = body.len(),
754 "CSP violation report with an unparseable body"
755 ),
756 }
757 StatusCode::NO_CONTENT
758}
759
/// Request-ID header. A valid inbound value is reused (see
/// `resolve_request_id`); the ID is forwarded upstream and echoed on every
/// response by `finish`.
const REQUEST_ID_HEADER: &str = "x-request-id";
764
765fn resolve_request_id(headers: &HeaderMap) -> String {
770 if let Some(v) = headers.get(REQUEST_ID_HEADER).and_then(|v| v.to_str().ok()) {
771 let v = v.trim();
772 if !v.is_empty() && v.len() <= 128 && v.bytes().all(|b| b.is_ascii_graphic()) {
773 return v.to_string();
774 }
775 }
776 uuid::Uuid::new_v4().to_string()
777}
778
779fn client_ip(headers: &HeaderMap, peer: SocketAddr, trust_forwarded: bool) -> IpAddr {
783 if trust_forwarded {
784 if let Some(xff) = headers.get("x-forwarded-for") {
785 if let Ok(s) = xff.to_str() {
786 if let Some(first) = s.split(',').next() {
787 if let Ok(ip) = first.trim().parse::<IpAddr>() {
788 return ip;
789 }
790 }
791 }
792 }
793 }
794 peer.ip()
795}
796
797fn header_bytes(headers: &HeaderMap) -> usize {
800 headers
801 .iter()
802 .map(|(name, value)| name.as_str().len() + value.as_bytes().len())
803 .sum()
804}
805
806fn is_event_stream(headers: &HeaderMap) -> bool {
810 headers
811 .get(header::CONTENT_TYPE)
812 .and_then(|v| v.to_str().ok())
813 .map(|v| {
814 v.split(';')
815 .next()
816 .map(str::trim)
817 .map(|ct| ct.eq_ignore_ascii_case("text/event-stream"))
818 .unwrap_or(false)
819 })
820 .unwrap_or(false)
821}
822
823fn is_upgrade_request(headers: &HeaderMap) -> bool {
826 let conn_has_upgrade = headers
827 .get_all(header::CONNECTION)
828 .iter()
829 .filter_map(|v| v.to_str().ok())
830 .flat_map(|v| v.split(','))
831 .any(|t| t.trim().eq_ignore_ascii_case("upgrade"));
832 conn_has_upgrade && headers.contains_key(header::UPGRADE)
833}
834
835#[allow(clippy::too_many_arguments)]
843async fn proxy_upgrade(
844 state: &AppState,
845 rt: &Runtime,
846 mut req: Request<Body>,
847 request_id: &str,
848 method: &Method,
849 path: &str,
850 ip: IpAddr,
851 started: Instant,
852) -> Response<Body> {
853 let m = &state.metrics;
854
855 let client_upgrade = hyper::upgrade::on(&mut req);
858
859 let uri = format!("{}{}", rt.pick_upstream(path), path);
862 let mut up = Request::builder().method(req.method().clone()).uri(&uri);
863 {
864 let headers = up.headers_mut().expect("builder headers");
865 let upgrade = req.headers().get(header::UPGRADE).cloned();
869 let mut forwarded = req.headers().clone();
870 strip_hop_by_hop(&mut forwarded);
871 for (name, value) in forwarded.iter() {
872 if name == header::HOST {
873 continue;
874 }
875 headers.insert(name.clone(), value.clone());
876 }
877 headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
878 if let Some(v) = upgrade {
879 headers.insert(header::UPGRADE, v);
880 }
881 if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
882 headers.insert(HeaderName::from_static("x-forwarded-for"), v);
883 }
884 headers.insert(
885 HeaderName::from_static("x-forwarded-proto"),
886 HeaderValue::from_static(forwarded_proto(&rt.cfg, req.headers())),
887 );
888 if let Ok(v) = HeaderValue::from_str(request_id) {
889 headers.insert(HeaderName::from_static(REQUEST_ID_HEADER), v);
890 }
891 }
892 let upstream_req = match up.body(Full::new(Bytes::new())) {
893 Ok(r) => r,
894 Err(e) => {
895 warn!(error = %e, "failed to build upstream upgrade request");
896 return finish(
897 m,
898 request_id,
899 method,
900 path,
901 ip,
902 started,
903 "bad_gateway",
904 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
905 );
906 }
907 };
908
909 let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
912 let timed_out = || {
913 warn!(upstream = %uri, "upstream timed out (upgrade)");
914 finish(
915 m,
916 request_id,
917 method,
918 path,
919 ip,
920 started,
921 "upstream_timeout",
922 text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout"),
923 )
924 };
925
926 let mut up_resp = match within(deadline, state.client.request(upstream_req)).await {
927 Ok(Ok(r)) => r,
928 Ok(Err(e)) => {
929 warn!(error = %e, upstream = %uri, "upstream unreachable (upgrade)");
930 return finish(
931 m,
932 request_id,
933 method,
934 path,
935 ip,
936 started,
937 "upstream_error",
938 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
939 );
940 }
941 Err(_) => return timed_out(),
942 };
943
944 if up_resp.status() != StatusCode::SWITCHING_PROTOCOLS {
948 let (mut parts, body) = up_resp.into_parts();
949 let body_fut = async {
952 if rt.max_response_body > 0 {
953 Limited::new(body, rt.max_response_body)
954 .collect()
955 .await
956 .map(|c| c.to_bytes())
957 .map_err(|_| ())
958 } else {
959 body.collect().await.map(|c| c.to_bytes()).map_err(|_| ())
960 }
961 };
962 let bytes = match within(deadline, body_fut).await {
963 Ok(Ok(b)) => b,
964 Ok(Err(())) => {
965 warn!("upstream upgrade-rejection body failed or exceeded max_response_body");
966 return finish(
967 m,
968 request_id,
969 method,
970 path,
971 ip,
972 started,
973 "bad_gateway",
974 text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
975 );
976 }
977 Err(_) => return timed_out(),
978 };
979 strip_hop_by_hop(&mut parts.headers);
980 parts.headers.remove(header::CONTENT_LENGTH);
981 let mut response = Response::from_parts(parts, Body::from(bytes));
982 harden_response(&rt.cfg, &mut response);
983 return finish(m, request_id, method, path, ip, started, "ok", response);
984 }
985
986 let upstream_upgrade = hyper::upgrade::on(&mut up_resp);
988 tokio::spawn(async move {
989 match tokio::join!(client_upgrade, upstream_upgrade) {
990 (Ok(client_io), Ok(up_io)) => {
991 let mut client_io = TokioIo::new(client_io);
992 let mut up_io = TokioIo::new(up_io);
993 if let Err(e) = tokio::io::copy_bidirectional(&mut client_io, &mut up_io).await {
994 debug!(error = %e, "websocket tunnel closed");
995 }
996 }
997 (c, u) => warn!(
998 client_ok = c.is_ok(),
999 upstream_ok = u.is_ok(),
1000 "websocket upgrade did not complete"
1001 ),
1002 }
1003 });
1004
1005 let (mut parts, _body) = up_resp.into_parts();
1011 let upgrade = parts.headers.get(header::UPGRADE).cloned();
1012 strip_hop_by_hop(&mut parts.headers);
1013 parts.headers.remove(header::CONTENT_LENGTH);
1014 parts
1015 .headers
1016 .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
1017 if let Some(v) = upgrade {
1018 parts.headers.insert(header::UPGRADE, v);
1019 }
1020 let response = Response::from_parts(parts, Body::empty());
1021 finish(
1022 m,
1023 request_id,
1024 method,
1025 path,
1026 ip,
1027 started,
1028 "ws_upgrade",
1029 response,
1030 )
1031}
1032
1033struct CountingBody<B> {
1039 inner: B,
1040 metrics: Arc<Metrics>,
1041 ingress: usize,
1042 egress: usize,
1044}
1045
1046impl<B> CountingBody<B> {
1047 fn new(inner: B, metrics: Arc<Metrics>, ingress: usize, header_egress: usize) -> Self {
1048 Self {
1049 inner,
1050 metrics,
1051 ingress,
1052 egress: header_egress,
1053 }
1054 }
1055}
1056
1057impl<B> HttpBody for CountingBody<B>
1058where
1059 B: HttpBody<Data = Bytes> + Unpin,
1060{
1061 type Data = Bytes;
1062 type Error = B::Error;
1063
1064 fn poll_frame(
1065 mut self: Pin<&mut Self>,
1066 cx: &mut Context<'_>,
1067 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
1068 let this = self.as_mut().get_mut();
1069 let polled = Pin::new(&mut this.inner).poll_frame(cx);
1070 if let Poll::Ready(Some(Ok(frame))) = &polled {
1071 if let Some(data) = frame.data_ref() {
1072 this.egress = this.egress.saturating_add(data.len());
1073 }
1074 }
1075 polled
1076 }
1077
1078 fn is_end_stream(&self) -> bool {
1079 self.inner.is_end_stream()
1080 }
1081
1082 fn size_hint(&self) -> SizeHint {
1083 self.inner.size_hint()
1084 }
1085}
1086
1087impl<B> Drop for CountingBody<B> {
1088 fn drop(&mut self) {
1089 self.metrics.add_usage_bytes(self.ingress, self.egress);
1090 }
1091}
1092
1093fn strip_hop_by_hop(headers: &mut HeaderMap) {
1097 let connection_named: Vec<HeaderName> = headers
1100 .get_all(header::CONNECTION)
1101 .iter()
1102 .filter_map(|v| v.to_str().ok())
1103 .flat_map(|v| v.split(','))
1104 .filter_map(|token| HeaderName::from_bytes(token.trim().as_bytes()).ok())
1105 .collect();
1106 for name in HOP_BY_HOP {
1107 headers.remove(*name);
1108 }
1109 for name in connection_named {
1110 headers.remove(name);
1111 }
1112}
1113
1114fn forwarded_proto(cfg: &Config, headers: &HeaderMap) -> &'static str {
1120 if cfg.tls.enabled {
1121 return "https";
1122 }
1123 if cfg.server.trust_forwarded_for {
1124 if let Some(value) = headers
1125 .get("x-forwarded-proto")
1126 .and_then(|v| v.to_str().ok())
1127 {
1128 match value.split(',').next().map(str::trim) {
1129 Some(p) if p.eq_ignore_ascii_case("https") => return "https",
1130 Some(p) if p.eq_ignore_ascii_case("http") => return "http",
1131 _ => {}
1132 }
1133 }
1134 }
1135 "http"
1136}
1137
1138fn longest_route<'a>(routes: &'a [RouteLimiter], path: &str) -> Option<&'a RouteLimiter> {
1140 routes
1141 .iter()
1142 .filter(|r| path.starts_with(&r.prefix))
1143 .max_by_key(|r| r.prefix.len())
1144}
1145
/// `Strict-Transport-Security` value emitted when `headers.hsts` is enabled:
/// a max-age of 63072000 s (two years), covering subdomains.
pub const HSTS_VALUE: &str = "max-age=63072000; includeSubDomains";
1150
1151pub fn security_headers(cfg: &HeadersCfg) -> Vec<(&'static str, String)> {
1163 let mut out: Vec<(&'static str, String)> = Vec::with_capacity(6);
1164 out.push(("X-Content-Type-Options", "nosniff".to_string()));
1165 if !cfg.frame_options.is_empty() {
1166 out.push(("X-Frame-Options", cfg.frame_options.clone()));
1167 }
1168 if !cfg.referrer_policy.is_empty() {
1169 out.push(("Referrer-Policy", cfg.referrer_policy.clone()));
1170 }
1171 if !cfg.permissions_policy.is_empty() {
1172 out.push(("Permissions-Policy", cfg.permissions_policy.clone()));
1173 }
1174 if !cfg.csp.is_empty() {
1175 let mut value = cfg.csp.clone();
1177 if !cfg.csp_report_uri.is_empty() {
1178 value.push_str("; report-uri ");
1179 value.push_str(&cfg.csp_report_uri);
1180 }
1181 let name = if cfg.csp_report_only {
1182 "Content-Security-Policy-Report-Only"
1183 } else {
1184 "Content-Security-Policy"
1185 };
1186 out.push((name, value));
1187 }
1188 if cfg.hsts {
1189 out.push(("Strict-Transport-Security", HSTS_VALUE.to_string()));
1190 }
1191 out
1192}
1193
1194fn harden_response(cfg: &Config, resp: &mut Response<Body>) {
1196 let h = resp.headers_mut();
1197
1198 for (name, value) in security_headers(&cfg.headers) {
1202 if let (Ok(n), Ok(v)) = (
1203 HeaderName::from_bytes(name.as_bytes()),
1204 HeaderValue::from_str(&value),
1205 ) {
1206 h.insert(n, v);
1207 }
1208 }
1209
1210 for name in &cfg.headers.strip {
1212 if let Ok(hn) = HeaderName::from_bytes(name.as_bytes()) {
1213 h.remove(hn);
1214 }
1215 }
1216
1217 if cfg.headers.force_secure_cookies {
1219 let cookies: Vec<HeaderValue> = h.get_all(header::SET_COOKIE).iter().cloned().collect();
1220 if !cookies.is_empty() {
1221 h.remove(header::SET_COOKIE);
1222 for c in cookies {
1223 if let Ok(s) = c.to_str() {
1224 let hardened = harden_cookie(s);
1225 if let Ok(v) = HeaderValue::from_str(&hardened) {
1226 h.append(header::SET_COOKIE, v);
1227 }
1228 } else {
1229 h.append(header::SET_COOKIE, c);
1230 }
1231 }
1232 }
1233 }
1234}
1235
/// Adds `Secure`, `HttpOnly` and `SameSite=Lax` to a `Set-Cookie` value
/// unless the cookie already carries an attribute of that name.
///
/// Attribute names are compared case-insensitively, and only after the
/// leading `name=value` pair, so a value such as `securetoken` is not mistaken
/// for the `Secure` attribute. Existing attributes (e.g. `SameSite=Strict`)
/// are preserved. Trailing separators *and* whitespace are trimmed before
/// appending. Trimming only `;` turned `"sid=abc; "` into
/// `"sid=abc; ; Secure…"`, which has an empty attribute.
fn harden_cookie(cookie: &str) -> String {
    // Lower-cased attribute names present after the `name=value` pair.
    let attrs: std::collections::HashSet<String> = cookie
        .split(';')
        .skip(1)
        .filter_map(|p| p.trim().split('=').next())
        .map(|k| k.trim().to_ascii_lowercase())
        .collect();

    let mut out = cookie
        .trim_end_matches(|c: char| c == ';' || c.is_ascii_whitespace())
        .to_string();
    if !attrs.contains("secure") {
        out.push_str("; Secure");
    }
    if !attrs.contains("httponly") {
        out.push_str("; HttpOnly");
    }
    if !attrs.contains("samesite") {
        out.push_str("; SameSite=Lax");
    }
    out
}
1259
1260async fn within<F: Future>(
1263 deadline: Option<tokio::time::Instant>,
1264 fut: F,
1265) -> Result<F::Output, tokio::time::error::Elapsed> {
1266 match deadline {
1267 Some(dl) => tokio::time::timeout_at(dl, fut).await,
1268 None => Ok(fut.await),
1269 }
1270}
1271
1272fn text(status: StatusCode, msg: &str) -> Response<Body> {
1273 let mut resp = Response::new(Body::from(msg.to_string()));
1274 *resp.status_mut() = status;
1275 resp.headers_mut().insert(
1276 header::CONTENT_TYPE,
1277 HeaderValue::from_static("text/plain; charset=utf-8"),
1278 );
1279 resp
1280}
1281
1282#[allow(clippy::too_many_arguments)]
1287fn finish(
1288 metrics: &Metrics,
1289 request_id: &str,
1290 method: &Method,
1291 path: &str,
1292 ip: IpAddr,
1293 started: Instant,
1294 outcome: &str,
1295 mut resp: Response<Body>,
1296) -> Response<Body> {
1297 if let Ok(v) = HeaderValue::from_str(request_id) {
1300 resp.headers_mut().insert(REQUEST_ID_HEADER, v);
1301 }
1302 let elapsed = started.elapsed();
1303 info!(
1304 request_id,
1305 %method,
1306 path = %path,
1307 client_ip = %ip,
1308 status = resp.status().as_u16(),
1309 outcome,
1310 latency_ms = elapsed.as_millis() as u64,
1311 "request"
1312 );
1313 metrics.record_request(outcome);
1314 metrics.observe_latency(elapsed);
1315 metrics.add_usage_request();
1318 resp
1319}
1320
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map holding a single header.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(name, HeaderValue::from_str(value).unwrap());
        h
    }

    // Previously split across two near-identical tests; merged here with the
    // union of their assertions.
    #[test]
    fn path_prefix_matches_on_segment_boundaries() {
        // Exact match and matches ending at a `/` or `?` boundary.
        assert!(path_prefix_matches("/api", "/api"));
        assert!(path_prefix_matches("/api/users", "/api"));
        assert!(path_prefix_matches("/api?x=1", "/api"));
        assert!(path_prefix_matches("/api?q=1", "/api"));
        assert!(path_prefix_matches("/api/users", "/api/"));
        // A shared leading substring is not a segment match.
        assert!(!path_prefix_matches("/apiary", "/api"));
        assert!(!path_prefix_matches("/apiary/honey", "/api"));
        assert!(!path_prefix_matches("/apiary", "/api/"));
        // The root prefix matches everything.
        assert!(path_prefix_matches("/anything", "/"));
    }

    #[test]
    fn client_ip_ignores_xff_when_untrusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        let h = headers_with("x-forwarded-for", "1.2.3.4");
        assert_eq!(client_ip(&h, peer, false), peer.ip());
    }

    #[test]
    fn client_ip_uses_first_xff_hop_when_trusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        let h = headers_with("x-forwarded-for", "1.2.3.4, 5.6.7.8");
        assert_eq!(client_ip(&h, peer, true).to_string(), "1.2.3.4");
    }

    #[test]
    fn client_ip_falls_back_to_peer_on_missing_or_garbage_xff() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        assert_eq!(client_ip(&HeaderMap::new(), peer, true), peer.ip());
        let garbage = headers_with("x-forwarded-for", "not-an-ip");
        assert_eq!(client_ip(&garbage, peer, true), peer.ip());
    }

    #[test]
    fn header_bytes_sums_names_and_values() {
        let mut h = HeaderMap::new();
        h.insert("a", HeaderValue::from_static("bb"));
        h.insert("ccc", HeaderValue::from_static("dddd"));
        assert_eq!(header_bytes(&h), 1 + 2 + 3 + 4);
    }

    #[test]
    fn strip_hop_by_hop_removes_fixed_and_connection_named() {
        let mut h = HeaderMap::new();
        h.insert(
            "connection",
            HeaderValue::from_static("keep-alive, X-Custom-Hop"),
        );
        h.insert("keep-alive", HeaderValue::from_static("timeout=5"));
        h.insert("x-custom-hop", HeaderValue::from_static("secret"));
        h.insert("content-type", HeaderValue::from_static("text/plain"));
        strip_hop_by_hop(&mut h);
        assert!(!h.contains_key("connection"));
        assert!(!h.contains_key("keep-alive"));
        assert!(!h.contains_key("x-custom-hop"));
        assert!(h.contains_key("content-type"));
    }

    #[test]
    fn forwarded_proto_reflects_tls_and_trust() {
        let mut cfg = Config::default();

        // TLS terminated here always wins.
        cfg.tls.enabled = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http")),
            "https"
        );

        // Untrusted inbound header is ignored.
        cfg.tls.enabled = false;
        cfg.server.trust_forwarded_for = false;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "http"
        );

        // Trusted: first entry decides; missing or garbage falls back to http.
        cfg.server.trust_forwarded_for = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "https"
        );
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http, https")),
            "http"
        );
        assert_eq!(forwarded_proto(&cfg, &HeaderMap::new()), "http");
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "garbage")),
            "http"
        );
    }

    #[test]
    fn longest_route_picks_most_specific_prefix() {
        let mk = |p: &str| RouteLimiter {
            prefix: p.to_string(),
            limiter: Arc::new(RateLimiter::keyed(governor::Quota::per_second(
                std::num::NonZeroU32::new(1).unwrap(),
            ))),
        };
        let routes = vec![mk("/api/"), mk("/api/admin/")];
        assert_eq!(
            longest_route(&routes, "/api/admin/users").map(|r| r.prefix.as_str()),
            Some("/api/admin/")
        );
        assert_eq!(
            longest_route(&routes, "/api/things").map(|r| r.prefix.as_str()),
            Some("/api/")
        );
        assert!(longest_route(&routes, "/public").is_none());
    }

    #[test]
    fn harden_cookie_adds_missing_flags() {
        let out = harden_cookie("sid=abc");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("; HttpOnly"), "{out}");
        assert!(out.contains("; SameSite=Lax"), "{out}");
    }

    #[test]
    fn harden_cookie_preserves_existing_attributes() {
        let out = harden_cookie("sid=abc; HttpOnly; SameSite=Strict");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("SameSite=Strict"), "{out}");
        assert!(!out.contains("SameSite=Lax"), "{out}");
        assert_eq!(out.matches("HttpOnly").count(), 1, "{out}");
    }

    #[test]
    fn harden_cookie_matches_attribute_names_case_insensitively() {
        let cookie = "sid=abc; secure; httponly; samesite=none";
        assert_eq!(harden_cookie(cookie), cookie);
    }

    #[test]
    fn harden_cookie_value_resembling_an_attr_is_not_skipped() {
        // "securetoken" is the cookie value, not the Secure attribute.
        let out = harden_cookie("session=securetoken");
        assert!(out.contains("; Secure"), "{out}");
    }

    #[test]
    fn security_headers_reflects_config_toggles() {
        let cfg = HeadersCfg::default();
        let got = security_headers(&cfg);
        let names: Vec<&str> = got.iter().map(|(n, _)| *n).collect();
        assert!(names.contains(&"X-Content-Type-Options"));
        assert!(names.contains(&"X-Frame-Options"));
        assert!(names.contains(&"Referrer-Policy"));
        assert!(names.contains(&"Permissions-Policy"));
        assert!(names.contains(&"Content-Security-Policy"));
        assert!(names.contains(&"Strict-Transport-Security"));
        assert!(!names.contains(&"Content-Security-Policy-Report-Only"));

        let cfg = HeadersCfg {
            hsts: false,
            frame_options: String::new(),
            csp: "default-src 'self'".into(),
            csp_report_only: true,
            csp_report_uri: "/__edgeguard/csp-report".into(),
            ..HeadersCfg::default()
        };
        let got = security_headers(&cfg);
        let map: std::collections::HashMap<&str, String> =
            got.iter().map(|(n, v)| (*n, v.clone())).collect();
        assert!(!map.contains_key("Strict-Transport-Security"));
        assert!(!map.contains_key("X-Frame-Options"));
        assert!(!map.contains_key("Content-Security-Policy"));
        assert_eq!(
            map.get("Content-Security-Policy-Report-Only")
                .map(|s| s.as_str()),
            Some("default-src 'self'; report-uri /__edgeguard/csp-report")
        );
    }
}