use axum::{body::Body, http::Request, middleware::Next, response::IntoResponse, Router};
pub fn require_bearer(
    expected: impl Into<String>,
) -> impl Fn(Router<()>) -> Router<()> + Clone + Send + 'static {
    // One shared, immutable copy of the secret: cloning an `Arc` per request is a
    // reference-count bump rather than a fresh `String` allocation.
    let expected: String = expected.into();
    let expected: std::sync::Arc<str> = std::sync::Arc::from(expected);
    move |router: Router<()>| {
        let expected = std::sync::Arc::clone(&expected);
        router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let expected = std::sync::Arc::clone(&expected);
                async move {
                    let authorized = req
                        .headers()
                        .get(axum::http::header::AUTHORIZATION)
                        .and_then(|value| value.to_str().ok())
                        .and_then(bearer_token)
                        .map(|token| constant_time_eq(token.as_bytes(), expected.as_bytes()))
                        .unwrap_or(false);
                    if authorized {
                        next.run(req).await
                    } else {
                        // RFC 7235 §3.1: a 401 response MUST carry a WWW-Authenticate
                        // challenge telling the client which scheme to use.
                        let mut response = axum::http::StatusCode::UNAUTHORIZED.into_response();
                        response.headers_mut().insert(
                            axum::http::header::WWW_AUTHENTICATE,
                            axum::http::HeaderValue::from_static("Bearer"),
                        );
                        response
                    }
                }
            },
        ))
    }
}

/// Extracts the token from an `Authorization` header value that uses the `Bearer` scheme.
///
/// Returns `None` in either of these cases:
/// - the scheme is not `Bearer`, compared ASCII-case-insensitively as RFC 7235 §2.1 requires;
/// - no non-empty token follows the scheme.
///
/// One or more spaces may separate the scheme from the token (RFC 6750 §2.1).
fn bearer_token(authorization: &str) -> Option<&str> {
    let (scheme, rest) = authorization.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = rest.trim_start_matches(' ');
    // An empty token is never valid (RFC 6750 `b64token` is 1*char). Rejecting it
    // also keeps a misconfigured empty secret failing closed.
    (!token.is_empty()).then_some(token)
}
/// Returns a router transformer that admits a request only when `guard_fn` approves it.
///
/// Every route of the wrapped router is layered with a middleware that calls
/// `guard_fn` with a reference to the incoming request:
/// - `true` forwards the request to the inner service unchanged;
/// - `false` short-circuits with `403 Forbidden`, and the inner service is never called.
pub fn guard<G>(guard_fn: G) -> impl Fn(Router<()>) -> Router<()> + Clone + Send + 'static
where
    G: Fn(&Request<Body>) -> bool + Clone + Send + Sync + 'static,
{
    move |router: Router<()>| {
        // Each wrapped router gets its own copy of the predicate.
        let predicate = guard_fn.clone();
        let middleware = move |request: Request<Body>, next: Next| {
            // The returned future must own its predicate, so hand it a fresh clone.
            let check = predicate.clone();
            async move {
                if !check(&request) {
                    return axum::http::StatusCode::FORBIDDEN.into_response();
                }
                next.run(request).await
            }
        };
        router.layer(axum::middleware::from_fn(middleware))
    }
}
/// Compares two byte slices for equality without exiting early on the first mismatch.
///
/// Every byte position up to the longer slice's length is visited no matter where (or
/// whether) the inputs differ. The running time therefore depends on the input lengths,
/// not on their contents. Slices of different lengths always compare unequal.
///
/// This is a best-effort defence against timing side channels. Nothing here stops the
/// optimizer from reintroducing a data-dependent branch.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // Seed the accumulator with the length difference. Without it, the zero padding
    // used below would make e.g. b"abc" compare equal to b"abc\0".
    let mut diff = a.len() ^ b.len();
    let len = a.len().max(b.len());
    for i in 0..len {
        let x = a.get(i).copied().unwrap_or(0);
        let y = b.get(i).copied().unwrap_or(0);
        diff |= usize::from(x ^ y);
    }
    diff == 0
}
#[cfg(test)]
mod tests {
    use axum::{body::Body, http::Request, routing::get, Router};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    use super::*;

    // ---- constant_time_eq -------------------------------------------------

    #[test]
    fn ct_eq_identical_slices() {
        assert!(constant_time_eq(b"secret", b"secret"));
    }

    #[test]
    fn ct_eq_different_slices() {
        assert!(!constant_time_eq(b"secret", b"wrong!"));
    }

    #[test]
    fn ct_eq_empty_slices() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn ct_eq_different_lengths_short_a() {
        assert!(!constant_time_eq(b"abc", b"abcd"));
    }

    #[test]
    fn ct_eq_different_lengths_short_b() {
        assert!(!constant_time_eq(b"abcd", b"abc"));
    }

    #[test]
    fn ct_eq_empty_vs_nonempty() {
        assert!(!constant_time_eq(b"", b"x"));
    }

    // ---- shared helpers -----------------------------------------------------

    /// Builds a request for `uri` with an empty body, attaching each `(name, value)`
    /// pair as a header.
    fn get_request(uri: &str, headers: &[(&str, &str)]) -> Request<Body> {
        headers
            .iter()
            .fold(Request::builder().uri(uri), |builder, &(name, value)| {
                builder.header(name, value)
            })
            .body(Body::empty())
            .unwrap()
    }

    /// Drives `request` through `app` exactly once and returns the response.
    async fn send(app: Router<()>, request: Request<Body>) -> axum::response::Response {
        app.oneshot(request).await.unwrap()
    }

    // ---- require_bearer -----------------------------------------------------

    fn bearer_router() -> Router<()> {
        let inner = Router::new().route("/protected", get(|| async { "ok" }));
        require_bearer("correct-token")(inner)
    }

    #[tokio::test]
    async fn bearer_accepts_correct_token() {
        let request = get_request("/protected", &[("Authorization", "Bearer correct-token")]);
        let response = send(bearer_router(), request).await;
        assert_eq!(response.status(), 200);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"ok");
    }

    #[tokio::test]
    async fn bearer_rejects_wrong_token() {
        let request = get_request("/protected", &[("Authorization", "Bearer wrong-token")]);
        let response = send(bearer_router(), request).await;
        assert_eq!(response.status(), 401);
    }

    #[tokio::test]
    async fn bearer_rejects_missing_header() {
        let response = send(bearer_router(), get_request("/protected", &[])).await;
        assert_eq!(response.status(), 401);
    }

    #[tokio::test]
    async fn bearer_rejects_malformed_header() {
        let request = get_request("/protected", &[("Authorization", "correct-token")]);
        let response = send(bearer_router(), request).await;
        assert_eq!(response.status(), 401);
    }

    // ---- guard --------------------------------------------------------------

    fn guard_router(
        predicate: impl Fn(&Request<Body>) -> bool + Clone + Send + Sync + 'static,
    ) -> Router<()> {
        let inner = Router::new().route("/guarded", get(|| async { "ok" }));
        guard(predicate)(inner)
    }

    #[tokio::test]
    async fn guard_allows_request_when_predicate_is_true() {
        let app = guard_router(|_req| true);
        let response = send(app, get_request("/guarded", &[])).await;
        assert_eq!(response.status(), 200);
    }

    #[tokio::test]
    async fn guard_blocks_request_with_403_when_predicate_is_false() {
        let app = guard_router(|_req| false);
        let response = send(app, get_request("/guarded", &[])).await;
        assert_eq!(response.status(), 403);
    }

    #[tokio::test]
    async fn guard_predicate_receives_live_request_headers() {
        let app = guard_router(|req| req.headers().contains_key("x-allowed"));

        let blocked = send(app.clone(), get_request("/guarded", &[])).await;
        assert_eq!(blocked.status(), 403);

        let allowed = send(app, get_request("/guarded", &[("x-allowed", "yes")])).await;
        assert_eq!(allowed.status(), 200);
    }

    #[tokio::test]
    async fn guard_predicate_receives_live_request_uri() {
        let app = guard_router(|req| req.uri().path().starts_with("/guarded"));
        let response = send(app, get_request("/guarded", &[])).await;
        assert_eq!(response.status(), 200);
    }
}