use std::future::Future;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use arc_swap::ArcSwap;
use axum::{
body::{Body, Bytes},
extract::{ConnectInfo, State},
http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode},
};
use governor::{clock::DefaultClock, state::keyed::DefaultKeyedStateStore, RateLimiter};
use http_body_util::{BodyExt, Full, Limited};
use hyper_util::client::legacy::{connect::HttpConnector, Client};
use tokio::net::TcpStream;
use tracing::{debug, info, warn};
use crate::auth::{AuthEngine, Challenge, Decision};
use crate::config::{Config, HeadersCfg};
use crate::limiter::{Admit, DistributedLimiter};
use crate::metrics::Metrics;
use crate::waf::{WafEngine, WafMode};
/// Local (non-distributed) governor rate limiter keyed by client IP address;
/// used for the per-IP and per-route limits in `handle`.
pub type KeyedLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
/// Local (non-distributed) governor rate limiter keyed by a string; `handle`
/// keys it by the authenticated principal.
pub type StrLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
/// Upstream HTTP client over a plain `HttpConnector`; request bodies are sent
/// as fully buffered `Full<Bytes>`.
pub type UpstreamClient = Client<HttpConnector, Full<Bytes>>;
/// Shared state extracted by every handler via axum's `State` extractor
/// (hence `Clone`).
#[derive(Clone)]
pub struct AppState {
    /// Client used to forward requests to the upstream.
    pub client: UpstreamClient,
    /// Metrics sink for request outcomes, latency, rate-limit, WAF and CSP counters.
    pub metrics: Arc<Metrics>,
    /// Swappable runtime (config + engines + limiters); handlers `load()` one
    /// snapshot per request.
    pub runtime: Arc<ArcSwap<Runtime>>,
}
/// A per-IP limiter that applies to request paths starting with `prefix`;
/// when several prefixes match, the longest one is used (see `longest_route`).
pub struct RouteLimiter {
    /// Path prefix this limiter applies to (compared with `str::starts_with`).
    pub prefix: String,
    /// Limiter keyed by client IP for this route.
    pub limiter: Arc<KeyedLimiter>,
}
/// Everything `handle` needs for one request, built from a `Config` and
/// swapped atomically through `AppState::runtime`.
pub struct Runtime {
    /// Configuration this runtime was built from.
    pub cfg: Arc<Config>,
    /// Upstream scheme+authority prefix; the request path+query is appended to it.
    pub upstream_base: Arc<String>,
    /// Authentication engine consulted for every proxied request.
    pub auth: AuthEngine,
    /// WAF engine evaluated against path, headers and buffered body.
    pub waf: WafEngine,
    /// Shared limiter; when `Some`, it replaces the local limiters below.
    pub distributed: Option<DistributedLimiter>,
    /// Local per-IP fallback used when no route prefix matches.
    pub ip_limiter: Option<Arc<KeyedLimiter>>,
    /// Local per-route limiters (longest matching prefix wins).
    pub route_limiters: Vec<RouteLimiter>,
    /// Local per-principal limiter applied after successful authentication.
    pub key_limiter: Option<Arc<StrLimiter>>,
    /// Request body limit passed to `axum::body::to_bytes` (no 0-disables rule).
    pub max_body: usize,
    /// Upstream response body limit; 0 disables the limit.
    pub max_response_body: usize,
    /// Total request header name+value bytes limit; 0 disables the check.
    pub max_header_bytes: usize,
    /// Deadline for the upstream round trip plus body read; `None` means no deadline.
    pub upstream_timeout: Option<Duration>,
}
/// Header names that are always removed before forwarding, in either
/// direction, regardless of what `Connection` lists.
///
/// This covers the classic RFC 2616 hop-by-hop set plus `proxy-connection`.
/// RFC 9110 §7.6.1 tells intermediaries to strip that non-standard field even
/// when it is not named in `Connection`. All names are lowercase so they can
/// be passed straight to `HeaderMap::remove`.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
/// Main reverse-proxy handler. It runs a request through the edge pipeline
/// and, if every check passes, forwards it to the configured upstream.
///
/// Stages, in order (each may short-circuit with an error response):
/// 1. reserved `/__edgeguard/` paths are never proxied (404);
/// 2. total request-header size limit (431);
/// 3. per-IP / per-route rate limiting when enabled (429, or 503 if the
///    distributed limiter reports an error);
/// 4. authentication (401, with a `WWW-Authenticate` challenge when available);
/// 5. per-principal rate limiting (429 / 503);
/// 6. method allow-list (405);
/// 7. request body buffering under `max_body` (413);
/// 8. WAF evaluation (403 in block mode, log-only in report mode);
/// 9. upstream round trip and response buffering under `upstream_timeout`
///    (502 / 504), followed by response hardening.
///
/// Every exit goes through [`finish`], which logs the request and records
/// metrics under an `outcome` label.
pub async fn handle(
    State(state): State<AppState>,
    ConnectInfo(peer): ConnectInfo<SocketAddr>,
    req: Request<Body>,
) -> Response<Body> {
    let started = Instant::now();
    // One snapshot per request, so a concurrent reload cannot change the
    // rules halfway through.
    let rt = state.runtime.load();
    let m = &state.metrics;
    let method = req.method().clone();
    // Path *and* query: used for logging, route matching, WAF and the upstream URI.
    let path = req
        .uri()
        .path_and_query()
        .map(|p| p.as_str().to_string())
        .unwrap_or_else(|| req.uri().path().to_string());
    let ip = client_ip(req.headers(), peer, rt.cfg.server.trust_forwarded_for);

    // Single funnel for every exit so logging and metrics stay uniform.
    let done = |outcome: &str, resp: Response<Body>| {
        finish(m, &method, &path, ip, started, outcome, resp)
    };

    if req.uri().path().starts_with("/__edgeguard/") {
        return done("not_found", text(StatusCode::NOT_FOUND, "Not Found"));
    }
    if rt.max_header_bytes > 0 && header_bytes(req.headers()) > rt.max_header_bytes {
        return done(
            "header_too_large",
            text(
                StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
                "Request Header Fields Too Large",
            ),
        );
    }

    if rt.cfg.ratelimit.enabled {
        if let Some(d) = &rt.distributed {
            match d.check_ip_route(ip, &path).await {
                Admit::Allowed => {}
                Admit::Limited(scope) => {
                    m.record_ratelimit_hit(scope);
                    return done(
                        "rate_limited",
                        text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
                    );
                }
                Admit::Error => {
                    return done(
                        "limiter_error",
                        text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
                    );
                }
            }
        } else {
            // Local limiters: the most specific route limiter wins; otherwise
            // fall back to the global per-IP limiter, if one is configured.
            let (limiter, scope) = match longest_route(&rt.route_limiters, &path) {
                Some(r) => (Some(r.limiter.as_ref()), "route"),
                None => (rt.ip_limiter.as_deref(), "ip"),
            };
            if let Some(limiter) = limiter {
                if limiter.check_key(&ip).is_err() {
                    m.record_ratelimit_hit(scope);
                    return done(
                        "rate_limited",
                        text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
                    );
                }
            }
        }
    }

    let principal = match rt.auth.authorize(&rt.cfg.auth, req.headers()).await {
        Decision::Allow(principal) => principal,
        Decision::Deny(challenge) => {
            let mut resp = text(StatusCode::UNAUTHORIZED, "Unauthorized");
            let challenge_value = match challenge {
                Challenge::Basic(c) => Some(c),
                Challenge::Bearer => Some("Bearer".to_string()),
                Challenge::None => None,
            };
            // A challenge that is not a valid header value is silently omitted.
            if let Some(v) = challenge_value.and_then(|c| HeaderValue::from_str(&c).ok()) {
                resp.headers_mut().insert(header::WWW_AUTHENTICATE, v);
            }
            return done("unauthorized", resp);
        }
    };

    // Per-principal limit. The distributed limiter takes precedence over the
    // local one.
    // NOTE(review): this runs even when `cfg.ratelimit.enabled` is false;
    // confirm that is intended.
    if let Some(principal) = &principal {
        let key_admit = if let Some(d) = &rt.distributed {
            Some(d.check_key(principal).await)
        } else {
            rt.key_limiter.as_ref().map(|limiter| {
                if limiter.check_key(principal).is_err() {
                    Admit::Limited("key")
                } else {
                    Admit::Allowed
                }
            })
        };
        match key_admit {
            Some(Admit::Limited(scope)) => {
                m.record_ratelimit_hit(scope);
                return done(
                    "rate_limited",
                    text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
                );
            }
            Some(Admit::Error) => {
                return done(
                    "limiter_error",
                    text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
                );
            }
            Some(Admit::Allowed) | None => {}
        }
    }

    // An empty allow-list means every method is accepted.
    let allow = &rt.cfg.validation.allow_methods;
    if !allow.is_empty()
        && !allow
            .iter()
            .any(|x| x.eq_ignore_ascii_case(method.as_str()))
    {
        return done(
            "method_not_allowed",
            text(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed"),
        );
    }

    let (parts, body) = req.into_parts();
    // NOTE: any body read error (not only exceeding `max_body`) is reported as 413.
    let body_bytes = match axum::body::to_bytes(body, rt.max_body).await {
        Ok(b) => b,
        Err(_) => {
            return done(
                "payload_too_large",
                text(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large"),
            )
        }
    };

    if let Some(hit) = rt.waf.evaluate(&path, &parts.headers, &body_bytes) {
        m.record_waf_hit(hit.class);
        match rt.waf.mode() {
            WafMode::Block => {
                warn!(
                    rule = %hit.rule_id,
                    class = hit.class,
                    location = hit.location,
                    client_ip = %ip,
                    path = %path,
                    "WAF blocked request"
                );
                return done("forbidden", text(StatusCode::FORBIDDEN, "Forbidden"));
            }
            WafMode::Report => warn!(
                rule = %hit.rule_id,
                class = hit.class,
                location = hit.location,
                client_ip = %ip,
                path = %path,
                "WAF rule matched (report-only)"
            ),
            WafMode::Off => {}
        }
    }

    let uri = format!("{}{}", rt.upstream_base, path);
    let mut up = Request::builder().method(parts.method.clone()).uri(&uri);
    // `headers_mut` is `None` when the builder already holds an error (e.g. the
    // URI failed to parse). Skip header building in that case; `up.body(...)`
    // below reports the error as a 502 instead of panicking here.
    if let Some(headers) = up.headers_mut() {
        let mut forwarded = parts.headers.clone();
        strip_hop_by_hop(&mut forwarded);
        for (name, value) in forwarded.iter() {
            // The upstream Host is derived from `uri`, not from the client.
            if name == header::HOST {
                continue;
            }
            // `append`, not `insert`: `iter` yields one entry per value, so
            // `insert` would keep only the last value of a repeated header.
            headers.append(name.clone(), value.clone());
        }
        // `insert` replaces any client-supplied values of these two headers.
        if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
            headers.insert(HeaderName::from_static("x-forwarded-for"), v);
        }
        headers.insert(
            HeaderName::from_static("x-forwarded-proto"),
            HeaderValue::from_static(forwarded_proto(&rt.cfg, &parts.headers)),
        );
    }
    let upstream_req = match up.body(Full::new(body_bytes)) {
        Ok(r) => r,
        Err(e) => {
            warn!(error = %e, "failed to build upstream request");
            return done("bad_gateway", text(StatusCode::BAD_GATEWAY, "Bad Gateway"));
        }
    };

    // One deadline covers both the upstream round trip and reading its body.
    let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
    let timed_out = || {
        warn!(upstream = %uri, "upstream timed out");
        done(
            "upstream_timeout",
            text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout"),
        )
    };

    let upstream_resp = match within(deadline, state.client.request(upstream_req)).await {
        Ok(Ok(r)) => r,
        Ok(Err(e)) => {
            warn!(error = %e, upstream = %uri, "upstream unreachable");
            return done(
                "upstream_error",
                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
            );
        }
        Err(_) => return timed_out(),
    };

    // The upstream body is fully buffered so it can be size-checked and
    // post-processed before it is sent downstream.
    let (mut resp_parts, resp_body) = upstream_resp.into_parts();
    let resp_bytes = if rt.max_response_body > 0 {
        match within(
            deadline,
            Limited::new(resp_body, rt.max_response_body).collect(),
        )
        .await
        {
            Ok(Ok(c)) => c.to_bytes(),
            Ok(Err(_)) => {
                warn!(
                    limit = rt.max_response_body,
                    "upstream response exceeded max_response_body"
                );
                return done(
                    "upstream_body_too_large",
                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
                );
            }
            Err(_) => return timed_out(),
        }
    } else {
        match within(deadline, resp_body.collect()).await {
            Ok(Ok(c)) => c.to_bytes(),
            Ok(Err(e)) => {
                warn!(error = %e, "failed reading upstream body");
                return done(
                    "upstream_body_error",
                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
                );
            }
            Err(_) => return timed_out(),
        }
    };

    strip_hop_by_hop(&mut resp_parts.headers);
    // The body was re-buffered; drop the upstream Content-Length and let the
    // server derive it from the new body.
    resp_parts.headers.remove(header::CONTENT_LENGTH);
    let mut response = Response::from_parts(resp_parts, Body::from(resp_bytes));
    harden_response(&rt.cfg, &mut response);
    done("ok", response)
}
/// Readiness probe. Returns 200 only if a TCP connection to the upstream probe
/// address succeeds within two seconds. Returns 503 on failure, on timeout, or
/// when the config yields no probe address.
pub async fn ready(State(state): State<AppState>) -> StatusCode {
    let runtime = state.runtime.load();
    let (host, port) = match runtime.cfg.upstream_probe_addr() {
        Some(addr) => addr,
        None => return StatusCode::SERVICE_UNAVAILABLE,
    };
    let probe = TcpStream::connect((host.as_str(), port));
    let reachable = matches!(
        tokio::time::timeout(Duration::from_secs(2), probe).await,
        Ok(Ok(_))
    );
    if reachable {
        StatusCode::OK
    } else {
        StatusCode::SERVICE_UNAVAILABLE
    }
}
/// Metrics scrape endpoint. Serves `Metrics::render()` with the
/// `text/plain; version=0.0.4` content type used by the Prometheus text format.
pub async fn metrics_handler(State(state): State<AppState>) -> Response<Body> {
    let content_type = HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8");
    let mut response = Response::new(Body::from(state.metrics.render()));
    response.headers_mut().insert(header::CONTENT_TYPE, content_type);
    response
}
/// Receives CSP violation reports.
///
/// Every call is counted. Parseable JSON bodies log the directive at debug
/// level: `csp-report.violated-directive`, falling back to
/// `effective-directive`, else `"unknown"`. Unparseable bodies log a warning
/// with the body size. Always answers 204 No Content.
pub async fn csp_report(State(state): State<AppState>, body: Bytes) -> StatusCode {
    state.metrics.record_csp_report();
    let Ok(report) = serde_json::from_slice::<serde_json::Value>(&body) else {
        warn!(
            bytes = body.len(),
            "CSP violation report with an unparseable body"
        );
        return StatusCode::NO_CONTENT;
    };
    let inner = report.get("csp-report");
    let directive = inner
        .and_then(|r| r.get("violated-directive"))
        .or_else(|| inner.and_then(|r| r.get("effective-directive")))
        .and_then(serde_json::Value::as_str)
        .unwrap_or("unknown");
    debug!(target: "edgeguard::csp", directive, "CSP violation report");
    StatusCode::NO_CONTENT
}
/// Resolves the client address used for rate limiting and logging.
///
/// When `trust_forwarded` is set, the left-most entry of the first
/// `X-Forwarded-For` header is used if it parses as an IP address. In every
/// other case the TCP peer address is returned.
///
/// NOTE(review): the left-most XFF entry is client-controlled unless the
/// trusted front proxy overwrites the header; confirm the deployment does.
fn client_ip(headers: &HeaderMap, peer: SocketAddr, trust_forwarded: bool) -> IpAddr {
    let forwarded = trust_forwarded
        .then(|| headers.get("x-forwarded-for"))
        .flatten()
        .and_then(|value| value.to_str().ok())
        .and_then(|list| list.split(',').next())
        .and_then(|hop| hop.trim().parse::<IpAddr>().ok());
    forwarded.unwrap_or_else(|| peer.ip())
}
/// Size of a header block: the summed byte lengths of every name and value.
/// Separators and line endings are not counted, and a repeated header counts
/// its name once per value.
fn header_bytes(headers: &HeaderMap) -> usize {
    let mut total = 0usize;
    for (name, value) in headers {
        total += name.as_str().len();
        total += value.len();
    }
    total
}
/// Removes hop-by-hop headers before a message is forwarded. This covers the
/// fixed [`HOP_BY_HOP`] set plus every header named as an option in
/// `Connection`.
///
/// The `Connection` options are collected *before* anything is removed,
/// because `Connection` itself belongs to the fixed set. Options that are not
/// valid header names are ignored.
fn strip_hop_by_hop(headers: &mut HeaderMap) {
    let mut listed: Vec<HeaderName> = Vec::new();
    for value in headers.get_all(header::CONNECTION) {
        let Ok(options) = value.to_str() else {
            continue;
        };
        listed.extend(
            options
                .split(',')
                .filter_map(|opt| HeaderName::from_bytes(opt.trim().as_bytes()).ok()),
        );
    }
    HOP_BY_HOP.iter().for_each(|name| {
        headers.remove(*name);
    });
    for name in &listed {
        headers.remove(name);
    }
}
/// Scheme reported upstream in `X-Forwarded-Proto`.
///
/// - If TLS is enabled here, the result is always `https`.
/// - Otherwise, when forwarded headers are trusted, the first entry of the
///   incoming `X-Forwarded-Proto` is honoured if it is `https`
///   (case-insensitive).
/// - Anything else yields `http`.
fn forwarded_proto(cfg: &Config, headers: &HeaderMap) -> &'static str {
    if cfg.tls.enabled {
        return "https";
    }
    if !cfg.server.trust_forwarded_for {
        return "http";
    }
    let first = headers
        .get("x-forwarded-proto")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.split(',').next())
        .map(str::trim);
    match first {
        Some(p) if p.eq_ignore_ascii_case("https") => "https",
        _ => "http",
    }
}
/// Picks the route limiter whose `prefix` is the longest prefix of `path`, or
/// `None` if no prefix matches. On equal-length matches the later entry wins,
/// matching `Iterator::max_by_key`.
fn longest_route<'a>(routes: &'a [RouteLimiter], path: &str) -> Option<&'a RouteLimiter> {
    let mut best: Option<&'a RouteLimiter> = None;
    for route in routes.iter().filter(|r| path.starts_with(r.prefix.as_str())) {
        if best.map_or(true, |b| route.prefix.len() >= b.prefix.len()) {
            best = Some(route);
        }
    }
    best
}
/// `Strict-Transport-Security` value emitted when `headers.hsts` is on:
/// 63072000 s (two years), including subdomains.
pub const HSTS_VALUE: &str = "max-age=63072000; includeSubDomains";
/// Builds the security response headers implied by `cfg`, in a fixed order.
///
/// - `X-Content-Type-Options: nosniff` is always present.
/// - `X-Frame-Options`, `Referrer-Policy` and `Permissions-Policy` are added
///   only when configured (non-empty).
/// - The CSP is added when non-empty, with `; report-uri <uri>` appended when a
///   report URI is set. It is emitted as `Content-Security-Policy-Report-Only`
///   when `csp_report_only` is on.
/// - HSTS ([`HSTS_VALUE`]) is added when `hsts` is on.
pub fn security_headers(cfg: &HeadersCfg) -> Vec<(&'static str, String)> {
    let optional = [
        ("X-Frame-Options", &cfg.frame_options),
        ("Referrer-Policy", &cfg.referrer_policy),
        ("Permissions-Policy", &cfg.permissions_policy),
    ];
    let mut out: Vec<(&'static str, String)> = Vec::with_capacity(6);
    out.push(("X-Content-Type-Options", "nosniff".to_string()));
    out.extend(
        optional
            .iter()
            .filter(|(_, value)| !value.is_empty())
            .map(|&(name, value)| (name, value.clone())),
    );
    if !cfg.csp.is_empty() {
        let policy = if cfg.csp_report_uri.is_empty() {
            cfg.csp.clone()
        } else {
            format!("{}; report-uri {}", cfg.csp, cfg.csp_report_uri)
        };
        let header_name = if cfg.csp_report_only {
            "Content-Security-Policy-Report-Only"
        } else {
            "Content-Security-Policy"
        };
        out.push((header_name, policy));
    }
    if cfg.hsts {
        out.push(("Strict-Transport-Security", HSTS_VALUE.to_string()));
    }
    out
}
fn harden_response(cfg: &Config, resp: &mut Response<Body>) {
let h = resp.headers_mut();
for (name, value) in security_headers(&cfg.headers) {
if let (Ok(n), Ok(v)) = (
HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_str(&value),
) {
h.insert(n, v);
}
}
for name in &cfg.headers.strip {
if let Ok(hn) = HeaderName::from_bytes(name.as_bytes()) {
h.remove(hn);
}
}
if cfg.headers.force_secure_cookies {
let cookies: Vec<HeaderValue> = h.get_all(header::SET_COOKIE).iter().cloned().collect();
if !cookies.is_empty() {
h.remove(header::SET_COOKIE);
for c in cookies {
if let Ok(s) = c.to_str() {
let hardened = harden_cookie(s);
if let Ok(v) = HeaderValue::from_str(&hardened) {
h.append(header::SET_COOKIE, v);
}
} else {
h.append(header::SET_COOKIE, c);
}
}
}
}
}
/// Adds `Secure`, `HttpOnly` and `SameSite=Lax` to a `Set-Cookie` value when
/// the cookie does not already carry the corresponding attribute.
///
/// Attribute names are compared case-insensitively, and only among the parts
/// after the first `;`. A cookie *value* such as `session=securetoken` is
/// therefore never mistaken for the `Secure` flag. Trailing `;` and whitespace
/// are trimmed before appending, so `"sid=abc; "` does not turn into
/// `"sid=abc; ; Secure; ..."` with an empty attribute.
fn harden_cookie(cookie: &str) -> String {
    let base = cookie.trim_end_matches(|c: char| c == ';' || c.is_ascii_whitespace());
    let attrs: std::collections::HashSet<String> = base
        .split(';')
        .skip(1)
        .filter_map(|p| p.trim().split('=').next())
        .map(|k| k.trim().to_ascii_lowercase())
        .collect();
    let mut out = base.to_string();
    if !attrs.contains("secure") {
        out.push_str("; Secure");
    }
    if !attrs.contains("httponly") {
        out.push_str("; HttpOnly");
    }
    if !attrs.contains("samesite") {
        out.push_str("; SameSite=Lax");
    }
    out
}
/// Awaits `fut`, bounded by `deadline` when one is given.
///
/// Returns `Err(Elapsed)` if the deadline passes first. With no deadline the
/// future runs to completion and its output is wrapped in `Ok`.
async fn within<F: Future>(
    deadline: Option<tokio::time::Instant>,
    fut: F,
) -> Result<F::Output, tokio::time::error::Elapsed> {
    let Some(at) = deadline else {
        return Ok(fut.await);
    };
    tokio::time::timeout_at(at, fut).await
}
/// Builds a small `text/plain; charset=utf-8` response carrying `msg` with
/// the given status code.
fn text(status: StatusCode, msg: &str) -> Response<Body> {
    let mut resp = Response::new(Body::from(msg.to_owned()));
    let content_type = HeaderValue::from_static("text/plain; charset=utf-8");
    resp.headers_mut().insert(header::CONTENT_TYPE, content_type);
    *resp.status_mut() = status;
    resp
}
/// Common exit point for every request. Emits the structured `request` log
/// line (method, path, client IP, status, outcome, latency in ms), records
/// the outcome and latency metrics, and returns `resp` unchanged.
fn finish(
    metrics: &Metrics,
    method: &Method,
    path: &str,
    ip: IpAddr,
    started: Instant,
    outcome: &str,
    resp: Response<Body>,
) -> Response<Body> {
    let latency = started.elapsed();
    let status = resp.status().as_u16();
    info!(
        %method,
        path = %path,
        client_ip = %ip,
        status,
        outcome,
        latency_ms = latency.as_millis() as u64,
        "request"
    );
    metrics.record_request(outcome);
    metrics.observe_latency(latency);
    resp
}
#[cfg(test)]
mod tests {
    // Unit tests for the standalone helpers in this module; the async
    // handlers are not exercised here.
    use super::*;
    // Builds a header map containing a single `name: value` entry.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(name, HeaderValue::from_str(value).unwrap());
        h
    }
    // Without trust, a client-supplied X-Forwarded-For must not override the peer.
    #[test]
    fn client_ip_ignores_xff_when_untrusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        let h = headers_with("x-forwarded-for", "1.2.3.4");
        assert_eq!(client_ip(&h, peer, false), peer.ip());
    }
    // With trust, the left-most XFF entry is taken as the client address.
    #[test]
    fn client_ip_uses_first_xff_hop_when_trusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        let h = headers_with("x-forwarded-for", "1.2.3.4, 5.6.7.8");
        assert_eq!(client_ip(&h, peer, true).to_string(), "1.2.3.4");
    }
    // Missing or unparseable XFF falls back to the TCP peer.
    #[test]
    fn client_ip_falls_back_to_peer_on_missing_or_garbage_xff() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        assert_eq!(client_ip(&HeaderMap::new(), peer, true), peer.ip());
        let garbage = headers_with("x-forwarded-for", "not-an-ip");
        assert_eq!(client_ip(&garbage, peer, true), peer.ip());
    }
    // Header size counts name and value bytes only.
    #[test]
    fn header_bytes_sums_names_and_values() {
        let mut h = HeaderMap::new();
        h.insert("a", HeaderValue::from_static("bb")); h.insert("ccc", HeaderValue::from_static("dddd")); assert_eq!(header_bytes(&h), 1 + 2 + 3 + 4);
    }
    // Both the fixed hop-by-hop set and Connection-listed names are removed;
    // end-to-end headers survive.
    #[test]
    fn strip_hop_by_hop_removes_fixed_and_connection_named() {
        let mut h = HeaderMap::new();
        h.insert(
            "connection",
            HeaderValue::from_static("keep-alive, X-Custom-Hop"),
        );
        h.insert("keep-alive", HeaderValue::from_static("timeout=5"));
        h.insert("x-custom-hop", HeaderValue::from_static("secret"));
        h.insert("content-type", HeaderValue::from_static("text/plain"));
        strip_hop_by_hop(&mut h);
        assert!(!h.contains_key("connection"));
        assert!(!h.contains_key("keep-alive"));
        assert!(!h.contains_key("x-custom-hop"));
        assert!(h.contains_key("content-type"));
    }
    // TLS forces https; untrusted headers are ignored; only the first list
    // entry counts; unknown values fall back to http.
    #[test]
    fn forwarded_proto_reflects_tls_and_trust() {
        let mut cfg = Config::default();
        cfg.tls.enabled = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http")),
            "https"
        );
        cfg.tls.enabled = false;
        cfg.server.trust_forwarded_for = false;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "http"
        );
        cfg.server.trust_forwarded_for = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "https"
        );
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http, https")),
            "http"
        );
        assert_eq!(forwarded_proto(&cfg, &HeaderMap::new()), "http");
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "garbage")),
            "http"
        );
    }
    // The most specific (longest) matching prefix wins; no match yields None.
    #[test]
    fn longest_route_picks_most_specific_prefix() {
        let mk = |p: &str| RouteLimiter {
            prefix: p.to_string(),
            limiter: Arc::new(RateLimiter::keyed(governor::Quota::per_second(
                std::num::NonZeroU32::new(1).unwrap(),
            ))),
        };
        let routes = vec![mk("/api/"), mk("/api/admin/")];
        assert_eq!(
            longest_route(&routes, "/api/admin/users").map(|r| r.prefix.as_str()),
            Some("/api/admin/")
        );
        assert_eq!(
            longest_route(&routes, "/api/things").map(|r| r.prefix.as_str()),
            Some("/api/")
        );
        assert!(longest_route(&routes, "/public").is_none());
    }
    // A bare cookie gets all three hardening attributes.
    #[test]
    fn harden_cookie_adds_missing_flags() {
        let out = harden_cookie("sid=abc");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("; HttpOnly"), "{out}");
        assert!(out.contains("; SameSite=Lax"), "{out}");
    }
    // Existing attributes are kept and not duplicated; an explicit SameSite
    // is not overridden.
    #[test]
    fn harden_cookie_preserves_existing_attributes() {
        let out = harden_cookie("sid=abc; HttpOnly; SameSite=Strict");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("SameSite=Strict"), "{out}");
        assert!(!out.contains("SameSite=Lax"), "{out}");
        assert_eq!(out.matches("HttpOnly").count(), 1, "{out}");
    }
    // The cookie value itself is never treated as an attribute name.
    #[test]
    fn harden_cookie_value_resembling_an_attr_is_not_skipped() {
        let out = harden_cookie("session=securetoken");
        assert!(out.contains("; Secure"), "{out}");
    }
    // Default config emits the full enforcing set; toggles remove headers and
    // switch CSP to report-only with the report URI appended.
    #[test]
    fn security_headers_reflects_config_toggles() {
        let cfg = HeadersCfg::default();
        let got = security_headers(&cfg);
        let names: Vec<&str> = got.iter().map(|(n, _)| *n).collect();
        assert!(names.contains(&"X-Content-Type-Options"));
        assert!(names.contains(&"X-Frame-Options"));
        assert!(names.contains(&"Referrer-Policy"));
        assert!(names.contains(&"Permissions-Policy"));
        assert!(names.contains(&"Content-Security-Policy"));
        assert!(names.contains(&"Strict-Transport-Security"));
        assert!(!names.contains(&"Content-Security-Policy-Report-Only"));
        let cfg = HeadersCfg {
            hsts: false,
            frame_options: String::new(),
            csp: "default-src 'self'".into(),
            csp_report_only: true,
            csp_report_uri: "/__edgeguard/csp-report".into(),
            ..HeadersCfg::default()
        };
        let got = security_headers(&cfg);
        let map: std::collections::HashMap<&str, String> =
            got.iter().map(|(n, v)| (*n, v.clone())).collect();
        assert!(!map.contains_key("Strict-Transport-Security"));
        assert!(!map.contains_key("X-Frame-Options"));
        assert!(!map.contains_key("Content-Security-Policy"));
        assert_eq!(
            map.get("Content-Security-Policy-Report-Only")
                .map(|s| s.as_str()),
            Some("default-src 'self'; report-uri /__edgeguard/csp-report")
        );
    }
}