use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::Router;
use axum::body::Body;
use axum::extract::ConnectInfo;
use axum::http::{Request, StatusCode};
use axum::middleware::{self, Next};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use subtle::ConstantTimeEq;
use tokio::sync::Mutex;
use tower_http::limit::RequestBodyLimitLayer;
use super::handlers::{health_handler, webhook_handler};
use super::server::AppState;
/// State handed to `auth_middleware` by `build_router`.
#[derive(Clone)]
struct AuthConfig {
    /// BLAKE3 hash of the configured bearer token, or `None` when auth is disabled
    /// (the middleware then passes every request through). Only the digest is kept,
    /// so comparisons are always between two fixed-length 32-byte values.
    token_hash: Option<blake3::Hash>,
}
/// Number of tracked client keys at which the rate limiter starts reclaiming entries
/// before it admits a key it is not yet tracking.
const MAX_RATE_LIMIT_ENTRIES: usize = 10_000;
/// Length of one fixed rate-limit window: sixty seconds.
const RATE_WINDOW: Duration = Duration::from_secs(60);
/// An IP network in CIDR notation (`addr/prefix_len`), used to recognise trusted proxies.
///
/// Host bits in `addr` are not cleared when parsing; [`Cidr::contains`] only compares the
/// leading `prefix_len` bits, so `10.1.2.3/8` behaves exactly like `10.0.0.0/8`.
#[derive(Clone, Debug)]
struct Cidr {
    /// Network address; its family decides which candidates can match.
    addr: IpAddr,
    /// Leading bits that must match. `parse` guarantees `<= 32` for IPv4 and `<= 128` for IPv6.
    prefix_len: u8,
}

impl Cidr {
    /// Parses `"<ip>/<prefix>"`, e.g. `"10.0.0.0/8"` or `"fd00::/8"`.
    ///
    /// Returns `None` when the `/` is missing, either half fails to parse, or the prefix
    /// is longer than the address family allows (32 bits for IPv4, 128 for IPv6).
    fn parse(s: &str) -> Option<Self> {
        let (addr_str, prefix_str) = s.split_once('/')?;
        let addr: IpAddr = addr_str.parse().ok()?;
        let prefix_len: u8 = prefix_str.parse().ok()?;
        let max_len = match addr {
            IpAddr::V4(_) => 32,
            IpAddr::V6(_) => 128,
        };
        (prefix_len <= max_len).then_some(Self { addr, prefix_len })
    }

    /// Returns `true` if `ip` lies inside this network.
    ///
    /// Addresses of the other family never match, with one exception: an IPv4 network
    /// also matches IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`). Dual-stack sockets
    /// (e.g. a listener bound to `[::]`) report IPv4 peers in that form. Without this
    /// normalisation, a configured `10.0.0.0/8` proxy range would never match such peers.
    fn contains(&self, ip: IpAddr) -> bool {
        let ip = match (self.addr, ip) {
            (IpAddr::V4(_), IpAddr::V6(v6)) => v6.to_ipv4_mapped().map_or(ip, IpAddr::V4),
            _ => ip,
        };
        match (self.addr, ip) {
            // A /0 network matches every address of its family. It is handled separately
            // because shifting by the full bit width would overflow.
            (IpAddr::V4(_), IpAddr::V4(_)) | (IpAddr::V6(_), IpAddr::V6(_))
                if self.prefix_len == 0 =>
            {
                true
            }
            (IpAddr::V4(net), IpAddr::V4(candidate)) => {
                let shift = 32 - u32::from(self.prefix_len);
                u32::from(net) >> shift == u32::from(candidate) >> shift
            }
            (IpAddr::V6(net), IpAddr::V6(candidate)) => {
                let shift = 128 - u32::from(self.prefix_len);
                u128::from(net) >> shift == u128::from(candidate) >> shift
            }
            _ => false,
        }
    }
}
/// State handed to `rate_limit_middleware` by `build_router`. The counter map and CIDR
/// list sit behind `Arc`s, so every clone of this state shares them.
#[derive(Clone)]
struct RateLimitState {
    /// Requests allowed per client key per `RATE_WINDOW`; `0` makes the middleware a no-op.
    limit: u32,
    /// Per-client-IP `(request count, window start)` pairs, guarded by an async mutex.
    counters: Arc<Mutex<HashMap<IpAddr, (u32, Instant)>>>,
    /// Proxy networks whose `X-Forwarded-For` header the middleware is willing to honour.
    trusted_cidrs: Arc<Vec<Cidr>>,
}
/// Builds the gateway's HTTP router.
///
/// * `GET /health` is public: no auth, rate limiting, or body-size limit is applied.
/// * `POST /webhook` is wrapped, outermost first, in:
///   1. a request-body limit of `max_body_size` bytes,
///   2. the per-client rate limiter (`rate_limit` requests per window; `0` disables it),
///   3. bearer-token auth when `auth_token` is `Some`.
///
/// Rate limiting deliberately runs *before* authentication. Requests with a missing or
/// wrong token therefore still consume the caller's budget, which throttles token guessing.
///
/// Entries of `trusted_proxy_cidrs` that do not parse as `addr/prefix` are logged and ignored.
pub(crate) fn build_router(
    state: AppState,
    auth_token: Option<&str>,
    rate_limit: u32,
    max_body_size: usize,
    trusted_proxy_cidrs: &[String],
) -> Router {
    let auth_cfg = AuthConfig {
        token_hash: auth_token.map(|t| blake3::hash(t.as_bytes())),
    };
    let parsed_cidrs: Vec<Cidr> = trusted_proxy_cidrs
        .iter()
        .filter_map(|s| {
            let c = Cidr::parse(s);
            if c.is_none() {
                tracing::warn!(cidr = %s, "gateway: invalid trusted_proxy_cidr, ignoring");
            }
            c
        })
        .collect();
    let rate_state = RateLimitState {
        limit: rate_limit,
        counters: Arc::new(Mutex::new(HashMap::new())),
        trusted_cidrs: Arc::new(parsed_cidrs),
    };
    // `Router::layer` wraps everything added before it, so the layer added LAST runs FIRST.
    // Resulting request flow: body limit -> rate limit -> auth -> webhook_handler.
    let protected = Router::new()
        .route("/webhook", post(webhook_handler))
        .layer(middleware::from_fn_with_state(auth_cfg, auth_middleware))
        .layer(middleware::from_fn_with_state(
            rate_state,
            rate_limit_middleware,
        ))
        .layer(RequestBodyLimitLayer::new(max_body_size));
    Router::new()
        .route("/health", get(health_handler))
        .merge(protected)
        .with_state(state)
}
/// Rejects requests to protected routes that do not present the configured bearer token.
///
/// When `cfg.token_hash` is `None`, every request passes through. Otherwise the
/// `Authorization` header must hold a non-empty token under the `Bearer` scheme. The
/// scheme is matched case-insensitively, as RFC 9110 §11.1 requires. The token's BLAKE3
/// hash must equal the configured hash. The comparison uses `subtle::ConstantTimeEq` over
/// two fixed-length 32-byte digests, so its timing does not depend on the presented token.
///
/// A header that is missing, not visible ASCII, not `Bearer`, or carries an empty token is
/// rejected before any comparison. As a result, an empty configured token can never be
/// satisfied merely by omitting the header. Every rejection is `401 Unauthorized` with a
/// `WWW-Authenticate: Bearer` challenge (RFC 9110 §11.6.1 requires one on 401 responses).
async fn auth_middleware(
    axum::extract::State(cfg): axum::extract::State<AuthConfig>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let Some(expected_hash) = cfg.token_hash else {
        return next.run(req).await;
    };
    let presented = req
        .headers()
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_token);
    let authorized = presented.is_some_and(|token| {
        let token_hash = blake3::hash(token.as_bytes());
        bool::from(token_hash.as_bytes().ct_eq(expected_hash.as_bytes()))
    });
    if !authorized {
        return (StatusCode::UNAUTHORIZED, [("www-authenticate", "Bearer")]).into_response();
    }
    next.run(req).await
}

/// Extracts `<token>` from an `Authorization` value of the form `Bearer <token>`.
///
/// The scheme is compared ignoring ASCII case, and extra spaces between the scheme and
/// the token are skipped. Returns `None` for any other scheme or for an empty token.
fn bearer_token(header: &str) -> Option<&str> {
    let (scheme, token) = header.split_once(' ')?;
    let token = token.trim_start_matches(' ');
    (scheme.eq_ignore_ascii_case("Bearer") && !token.is_empty()).then_some(token)
}
async fn rate_limit_middleware(
axum::extract::State(state): axum::extract::State<RateLimitState>,
req: Request<Body>,
next: Next,
) -> Response {
if state.limit == 0 {
return next.run(req).await;
}
let peer_ip = req
.extensions()
.get::<ConnectInfo<std::net::SocketAddr>>()
.map_or(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), |ci| ci.0.ip());
let ip = if !state.trusted_cidrs.is_empty()
&& state.trusted_cidrs.iter().any(|c| c.contains(peer_ip))
{
let xff_ip = req
.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|v| {
v.split(',')
.map(str::trim)
.filter_map(|s| s.parse::<IpAddr>().ok())
.rev()
.find(|ip| !state.trusted_cidrs.iter().any(|c| c.contains(*ip)))
});
xff_ip.unwrap_or(peer_ip)
} else {
peer_ip
};
let now = Instant::now();
let mut counters = state.counters.lock().await;
if counters.len() >= MAX_RATE_LIMIT_ENTRIES && !counters.contains_key(&ip) {
counters.retain(|_, (_, ts)| now.duration_since(*ts) < RATE_WINDOW);
}
let entry = counters.entry(ip).or_insert((0, now));
if now.duration_since(entry.1) >= RATE_WINDOW {
*entry = (1, now);
} else {
entry.0 += 1;
if entry.0 > state.limit {
return StatusCode::TOO_MANY_REQUESTS.into_response();
}
}
drop(counters);
next.run(req).await
}
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use http_body_util::BodyExt;
    use tower::{Service, ServiceExt};

    use super::*;
    use crate::server::AppState;

    /// Request-body limit used by every router except `body_size_limit`'s (1 MiB).
    const BODY_LIMIT: usize = 1_048_576;

    /// Returns an `AppState` wired to a fresh 16-slot webhook channel, plus the receiver so a
    /// test can inspect forwarded messages or drop it.
    fn test_state() -> (AppState, tokio::sync::mpsc::Receiver<String>) {
        let (tx, rx) = tokio::sync::mpsc::channel(16);
        let state = AppState {
            webhook_tx: tx,
            started_at: Instant::now(),
            webhook_send_timeout: std::time::Duration::from_secs(5),
        };
        (state, rx)
    }

    /// Router with the default body limit and no trusted proxies.
    fn make_router(
        auth: Option<&str>,
        rate_limit: u32,
    ) -> (Router, tokio::sync::mpsc::Receiver<String>) {
        let (state, rx) = test_state();
        (build_router(state, auth, rate_limit, BODY_LIMIT, &[]), rx)
    }

    /// Router without auth whose rate limiter (one request per window) trusts `10.0.0.0/8`.
    fn make_proxied_router() -> (Router, tokio::sync::mpsc::Receiver<String>) {
        let (state, rx) = test_state();
        let cidrs = vec!["10.0.0.0/8".to_string()];
        (build_router(state, None, 1, BODY_LIMIT, &cidrs), rx)
    }

    /// A webhook payload with every required field present.
    fn valid_payload() -> serde_json::Value {
        serde_json::json!({"channel": "a", "sender": "b", "body": "c"})
    }

    /// `POST /webhook` carrying `payload` as JSON plus any `extra_headers`.
    fn webhook_request(
        payload: &serde_json::Value,
        extra_headers: &[(&str, &str)],
    ) -> Request<Body> {
        let mut builder = Request::builder()
            .method("POST")
            .uri("/webhook")
            .header("content-type", "application/json");
        for &(name, value) in extra_headers {
            builder = builder.header(name, value);
        }
        builder
            .body(Body::from(serde_json::to_vec(payload).unwrap()))
            .unwrap()
    }

    /// A valid webhook request that "arrived" over TCP from `peer` with the given
    /// `X-Forwarded-For` value. The `ConnectInfo` extension is inserted by hand, standing
    /// in for what a connect-info-aware server would provide.
    fn proxied_request(peer: &str, xff: &str) -> Request<Body> {
        let mut req = webhook_request(&valid_payload(), &[("x-forwarded-for", xff)]);
        let peer: std::net::SocketAddr = peer.parse().unwrap();
        req.extensions_mut().insert(ConnectInfo(peer));
        req
    }

    /// `GET /health` with an empty body.
    fn health_request() -> Request<Body> {
        Request::builder()
            .uri("/health")
            .body(Body::empty())
            .unwrap()
    }

    /// Collects a response body and parses it as JSON.
    async fn json_body(resp: Response) -> serde_json::Value {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Asserts that `resp` is a JSON error: the expected status, a JSON content type, and a
    /// body containing an `error` field and a numeric `status` equal to `status`.
    async fn assert_json_error(resp: Response, status: u16) {
        assert_eq!(resp.status(), status);
        let ct = resp
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/json"),
            "expected JSON content-type for {status}, got: {ct}"
        );
        let json = json_body(resp).await;
        assert!(json.get("error").is_some());
        assert_eq!(json["status"], status);
    }

    #[tokio::test]
    async fn health_returns_ok() {
        let (app, _rx) = make_router(None, 0);
        let resp = app.oneshot(health_request()).await.unwrap();
        assert_eq!(resp.status(), 200);
        let json = json_body(resp).await;
        assert_eq!(json["status"], "ok");
    }

    #[tokio::test]
    async fn webhook_accepted() {
        let (app, mut rx) = make_router(None, 0);
        let payload = serde_json::json!({
            "channel": "discord",
            "sender": "user1",
            "body": "hello"
        });
        let resp = app.oneshot(webhook_request(&payload, &[])).await.unwrap();
        assert_eq!(resp.status(), 200);
        let msg = rx.try_recv().unwrap();
        assert!(msg.contains("user1"));
    }

    #[tokio::test]
    async fn auth_rejects_missing_token() {
        let (app, _rx) = make_router(Some("secret"), 0);
        let resp = app
            .oneshot(webhook_request(&valid_payload(), &[]))
            .await
            .unwrap();
        assert_eq!(resp.status(), 401);
    }

    #[tokio::test]
    async fn auth_accepts_valid_token() {
        let (app, _rx) = make_router(Some("secret"), 0);
        let req = webhook_request(&valid_payload(), &[("authorization", "Bearer secret")]);
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        let (app, _rx) = make_router(Some("secret"), 0);
        let req = webhook_request(&valid_payload(), &[("authorization", "Bearer wrong")]);
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 401);
    }

    #[tokio::test]
    async fn health_skips_auth() {
        let (app, _rx) = make_router(Some("secret"), 0);
        let resp = app.oneshot(health_request()).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn rate_limit_enforced() {
        let (mut app, _rx) = make_router(None, 2);
        for (i, expected) in [200_u16, 200, 429].into_iter().enumerate() {
            let resp = app
                .call(webhook_request(&valid_payload(), &[]))
                .await
                .unwrap();
            assert_eq!(resp.status(), expected, "request #{}", i + 1);
        }
    }

    #[tokio::test]
    async fn no_auth_when_token_unset() {
        let (app, _rx) = make_router(None, 0);
        let resp = app
            .oneshot(webhook_request(&valid_payload(), &[]))
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn webhook_missing_field_returns_json_error() {
        let (app, _rx) = make_router(None, 0);
        let payload = serde_json::json!({"channel": "ci643", "body": "test"});
        let resp = app.oneshot(webhook_request(&payload, &[])).await.unwrap();
        assert_json_error(resp, 422).await;
    }

    #[tokio::test]
    async fn webhook_validation_failure_returns_json_error() {
        let (app, _rx) = make_router(None, 0);
        let payload = serde_json::json!({
            "channel": "ci643",
            "sender": "a".repeat(257),
            "body": "hello"
        });
        let resp = app.oneshot(webhook_request(&payload, &[])).await.unwrap();
        assert_json_error(resp, 422).await;
    }

    #[tokio::test]
    async fn webhook_503_returns_json_error() {
        let (tx, rx) = tokio::sync::mpsc::channel::<String>(1);
        // Dropping the receiver closes the channel, so the handler cannot forward the message.
        drop(rx);
        let state = AppState {
            webhook_tx: tx,
            started_at: Instant::now(),
            webhook_send_timeout: std::time::Duration::from_secs(5),
        };
        let app = build_router(state, None, 0, BODY_LIMIT, &[]);
        let payload = serde_json::json!({"channel": "c", "sender": "s", "body": "b"});
        let resp = app.oneshot(webhook_request(&payload, &[])).await.unwrap();
        assert_json_error(resp, 503).await;
    }

    #[tokio::test]
    async fn body_size_limit() {
        let (state, _rx) = test_state();
        let app = build_router(state, None, 0, 64, &[]);
        let req = Request::builder()
            .method("POST")
            .uri("/webhook")
            .header("content-type", "application/json")
            .body(Body::from(vec![b'a'; 128]))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 413);
    }

    #[test]
    fn bearer_ct_eq_is_constant_time() {
        use std::hint::black_box;
        use std::time::Instant;
        const ROUNDS: usize = 5;
        const ITERS: u32 = 20_000;
        const MAX_RATIO: u128 = 10;
        let expected_hash = blake3::hash(b"super-secret-gateway-token");
        let candidates: &[&[u8]] = &[b"x", b"wrong_token_123", &[b'z'; 512]];
        let times_ns: Vec<u128> = candidates
            .iter()
            .map(|candidate| {
                let h = blake3::hash(candidate);
                let compare = || black_box(h.as_bytes().ct_eq(expected_hash.as_bytes()));
                for _ in 0..1_000 {
                    let _ = compare();
                }
                // Best of several rounds: one descheduling or frequency change inflates a
                // single round, not the minimum. Totals (not per-iteration averages) are
                // compared so integer division can never round a fast run down to 0 ns.
                (0..ROUNDS)
                    .map(|_| {
                        let start = Instant::now();
                        for _ in 0..ITERS {
                            let _ = compare();
                        }
                        start.elapsed().as_nanos()
                    })
                    .min()
                    .expect("ROUNDS > 0")
            })
            .collect();
        let min = *times_ns.iter().min().unwrap();
        let max = *times_ns.iter().max().unwrap();
        assert!(
            min > 0 && max / min < MAX_RATIO,
            "ct_eq timing ratio {max}/{min} exceeds {MAX_RATIO}×; best-of-{ROUNDS} totals for {ITERS} iters: {times_ns:?} ns"
        );
    }

    #[test]
    fn cidr_ipv4_contains_in_range() {
        let cidr = Cidr::parse("10.0.0.0/8").unwrap();
        assert!(cidr.contains("10.1.2.3".parse().unwrap()));
        assert!(cidr.contains("10.255.255.255".parse().unwrap()));
        assert!(!cidr.contains("11.0.0.0".parse().unwrap()));
        assert!(!cidr.contains("9.255.255.255".parse().unwrap()));
    }

    #[test]
    fn cidr_ipv4_slash32_exact_host() {
        let cidr = Cidr::parse("192.168.1.100/32").unwrap();
        assert!(cidr.contains("192.168.1.100".parse().unwrap()));
        assert!(!cidr.contains("192.168.1.101".parse().unwrap()));
    }

    #[test]
    fn cidr_ipv4_slash0_matches_all() {
        let cidr = Cidr::parse("0.0.0.0/0").unwrap();
        assert!(cidr.contains("1.2.3.4".parse().unwrap()));
        assert!(cidr.contains("255.255.255.255".parse().unwrap()));
    }

    #[test]
    fn cidr_ipv6_contains_in_range() {
        let cidr = Cidr::parse("::1/128").unwrap();
        assert!(cidr.contains("::1".parse().unwrap()));
        assert!(!cidr.contains("::2".parse().unwrap()));
    }

    #[test]
    fn cidr_ipv4_v6_mismatch_returns_false() {
        let cidr = Cidr::parse("10.0.0.0/8").unwrap();
        assert!(!cidr.contains("::1".parse().unwrap()));
    }

    #[test]
    fn cidr_parse_rejects_invalid() {
        assert!(Cidr::parse("not-a-cidr").is_none());
        assert!(Cidr::parse("10.0.0.0/33").is_none());
        assert!(Cidr::parse("::1/129").is_none());
        assert!(Cidr::parse("10.0.0.0/").is_none());
    }

    #[tokio::test]
    async fn xff_rightmost_untrusted_selected() {
        let (mut app, _rx) = make_proxied_router();
        // Peer 10.0.0.9 is a trusted proxy, so the key is the right-most untrusted hop
        // (1.2.3.4). The trusted 10.0.0.1 hop and the spoofable left-most entries are ignored.
        let resp1 = app
            .call(proxied_request("10.0.0.9:4000", "9.9.9.9, 1.2.3.4, 10.0.0.1"))
            .await
            .unwrap();
        assert_eq!(resp1.status(), 200);
        let resp2 = app
            .call(proxied_request("10.0.0.9:4000", "8.8.8.8, 1.2.3.4, 10.0.0.1"))
            .await
            .unwrap();
        assert_eq!(
            resp2.status(),
            429,
            "second request from same real IP must be rate-limited"
        );
        let resp3 = app
            .call(proxied_request("10.0.0.9:4000", "5.6.7.8, 10.0.0.1"))
            .await
            .unwrap();
        assert_eq!(
            resp3.status(),
            200,
            "a different client behind the same proxy must have its own budget"
        );
    }

    #[tokio::test]
    async fn xff_ignored_from_untrusted_peer() {
        let (mut app, _rx) = make_proxied_router();
        // 203.0.113.7 is not a trusted proxy: its X-Forwarded-For is ignored and both
        // requests are keyed on the peer itself.
        let resp1 = app
            .call(proxied_request("203.0.113.7:4000", "1.2.3.4"))
            .await
            .unwrap();
        assert_eq!(resp1.status(), 200);
        let resp2 = app
            .call(proxied_request("203.0.113.7:4000", "5.6.7.8"))
            .await
            .unwrap();
        assert_eq!(
            resp2.status(),
            429,
            "spoofed XFF from an untrusted peer must not grant a fresh budget"
        );
    }

    #[tokio::test]
    async fn xff_absent_falls_back_to_tcp_peer() {
        let (mut app, _rx) = make_router(None, 1);
        let resp1 = app
            .call(webhook_request(&valid_payload(), &[]))
            .await
            .unwrap();
        assert_eq!(resp1.status(), 200);
        let resp2 = app
            .call(webhook_request(&valid_payload(), &[]))
            .await
            .unwrap();
        assert_eq!(
            resp2.status(),
            429,
            "second request must be rate-limited via TCP peer"
        );
    }

    #[tokio::test]
    async fn xff_all_trusted_falls_back_to_peer() {
        let (mut app, _rx) = make_proxied_router();
        // Every hop is trusted, so the key falls back to the TCP peer. Both requests
        // therefore share 10.0.0.9's budget even though their XFF chains differ.
        let resp1 = app
            .call(proxied_request("10.0.0.9:4000", "10.0.0.1, 10.0.0.2"))
            .await
            .unwrap();
        assert_eq!(resp1.status(), 200);
        let resp2 = app
            .call(proxied_request("10.0.0.9:4000", "10.0.0.3"))
            .await
            .unwrap();
        assert_eq!(
            resp2.status(),
            429,
            "all-trusted chains must be keyed on the TCP peer"
        );
    }
}