use axum::body::Body;
use axum::extract::Request;
use axum::http::{HeaderMap, HeaderValue, Method, Response, StatusCode};
use axum::middleware::Next;
use axum::response::IntoResponse;
use serde_json::json;
use subtle::ConstantTimeEq;
use vti_common::auth::extractor::ADMIN_SESSION_COOKIE;
/// Request paths that skip the CSRF check regardless of cookies or headers.
///
/// Entries are compared for exact string equality with the request path (no
/// prefix matching). For example, `/v1/auth/` does not cover
/// `/v1/auth/refresh`, which is therefore listed on its own.
const CSRF_EXEMPT_PATHS: &[&str] = &[
    "/v1/join-requests",
    // `/v1/auth/…`
    "/v1/auth/",
    "/v1/auth/challenge",
    "/v1/auth/refresh",
    "/v1/auth/admin-login",
    // `/v1/wallet/auth/…`
    "/v1/wallet/auth/",
    "/v1/wallet/auth/challenge",
    // `/v1/install/claim/…`
    "/v1/install/claim/start",
    "/v1/install/claim/finish",
    "/v1/admin/bootstrap",
];
/// Reports whether the request carries an `Authorization: Bearer …` header.
///
/// Only the first `Authorization` value is examined. After leading whitespace
/// is stripped, its first seven bytes are compared case-insensitively with
/// `"bearer "` (including the trailing space). The result is `false` in any of
/// these cases:
/// * the header is absent;
/// * the value is not visible ASCII;
/// * the value is shorter than seven bytes;
/// * the value uses any other scheme.
fn has_bearer_auth(headers: &HeaderMap) -> bool {
    let Some(value) = headers.get(axum::http::header::AUTHORIZATION) else {
        return false;
    };
    let Ok(text) = value.to_str() else {
        return false;
    };
    match text.trim_start().get(..7) {
        Some(scheme) => scheme.eq_ignore_ascii_case("bearer "),
        None => false,
    }
}
/// Returns `true` when `path` is excused from the CSRF check.
///
/// Two kinds of path are exempt:
/// * any entry of [`CSRF_EXEMPT_PATHS`], compared for exact equality;
/// * `/v1/join-requests/{id}/accept` and `/v1/join-requests/{id}/status`,
///   where `{id}` is exactly one non-empty path segment.
///
/// The second form is matched structurally rather than with a suffix test.
/// Other routes under `/v1/join-requests/` therefore remain CSRF-protected,
/// including `/v1/join-requests/{id}/approve` and any deeper path that merely
/// happens to end in `/accept` or `/status`.
fn is_csrf_exempt(path: &str) -> bool {
    if CSRF_EXEMPT_PATHS.contains(&path) {
        return true;
    }
    // `split_once('/')` splits at the first slash, so `id` can never contain
    // a `/`; `action` must then be exactly `accept` or `status` (no further
    // segments, no trailing slash).
    path.strip_prefix("/v1/join-requests/")
        .and_then(|rest| rest.split_once('/'))
        .is_some_and(|(id, action)| !id.is_empty() && matches!(action, "accept" | "status"))
}
fn has_session_cookie(headers: &HeaderMap) -> bool {
headers
.get_all(axum::http::header::COOKIE)
.iter()
.filter_map(|v| v.to_str().ok())
.flat_map(|s| s.split(';'))
.map(|s| s.trim())
.filter_map(|kv| kv.split_once('='))
.any(|(name, _)| name == ADMIN_SESSION_COOKIE)
}
/// CSRF-guard middleware, meant to be installed with
/// `axum::middleware::from_fn(enforce)`.
///
/// The request is forwarded to `next` as soon as one of these holds, checked in
/// this order:
/// 1. the method is not POST, PUT, PATCH or DELETE;
/// 2. the `Authorization` header uses the `Bearer` scheme;
/// 3. the path is CSRF-exempt (see `is_csrf_exempt`);
/// 4. no admin session cookie is present;
/// 5. `Sec-Fetch-Site` is exactly `same-origin`;
/// 6. the first `csrf` cookie is non-empty and equals the `X-CSRF-Token`
///    header. This comparison is constant-time for equal lengths.
///
/// If none of these holds, `next` is not called. The response is
/// `403 Forbidden` with a JSON body whose `"error"` field is `"CsrfFailed"`.
pub async fn enforce(request: Request, next: Next) -> Response<Body> {
    let state_changing = matches!(
        *request.method(),
        Method::POST | Method::PUT | Method::PATCH | Method::DELETE
    );
    let headers = request.headers();

    // Double-submit check. This is a closure so the cookie and header parsing
    // only happens when every cheaper bypass below has already failed.
    let double_submit_ok = || {
        let cookie_value = headers
            .get_all(axum::http::header::COOKIE)
            .iter()
            .filter_map(|raw| raw.to_str().ok())
            .flat_map(|line| line.split(';'))
            .find_map(|pair| pair.trim().strip_prefix("csrf="));
        let header_value = headers
            .get("x-csrf-token")
            .and_then(|raw| raw.to_str().ok());
        match (cookie_value, header_value) {
            (Some(expected), Some(presented))
                if !expected.is_empty() && expected.len() == presented.len() =>
            {
                bool::from(expected.as_bytes().ct_eq(presented.as_bytes()))
            }
            _ => false,
        }
    };

    // `||` short-circuits, so conditions are evaluated lazily in the order
    // documented above.
    let allowed = !state_changing
        || has_bearer_auth(headers)
        || is_csrf_exempt(request.uri().path())
        || !has_session_cookie(headers)
        || headers
            .get("sec-fetch-site")
            .is_some_and(|value| value == HeaderValue::from_static("same-origin"))
        || double_submit_ok();

    if allowed {
        next.run(request).await
    } else {
        let body = json!({
            "error": "CsrfFailed",
            "message": "POST/PUT/PATCH/DELETE requests require Sec-Fetch-Site: same-origin or a matching csrf cookie + X-CSRF-Token header",
        });
        (StatusCode::FORBIDDEN, axum::Json(body)).into_response()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::http::Request;
    use axum::routing::post;
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// `Cookie` header value carrying an admin session. The gated tests below
    /// rely on `has_session_cookie` recognising this cookie name.
    const SESSION_COOKIE: &str = "vtc_admin_session=jwt.header.sig";

    /// Handler behind every test route. A `200 OK` response means the
    /// middleware let the request through.
    async fn ok() -> &'static str {
        "ok"
    }

    /// Router with a representative set of routes, all wrapped by `enforce`.
    fn app() -> Router {
        Router::new()
            .route("/v1/members", post(ok).get(ok))
            .route("/v1/join-requests", post(ok))
            .route("/v1/join-requests/{id}/accept", post(ok))
            .route("/v1/join-requests/{id}/status", post(ok))
            .route("/v1/join-requests/{id}/approve", post(ok))
            .route("/v1/auth/challenge", post(ok))
            .layer(axum::middleware::from_fn(enforce))
    }

    /// Sends one empty-bodied request through a fresh `app()` and returns the
    /// response. Headers are added in the order given.
    async fn send(method: &str, uri: &str, headers: &[(&str, &str)]) -> Response<Body> {
        let builder = headers.iter().fold(
            Request::builder().method(method).uri(uri),
            |builder, &(name, value)| builder.header(name, value),
        );
        let request = builder.body(Body::empty()).unwrap();
        app().oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn get_bypasses() {
        let response = send("GET", "/v1/members", &[]).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cookie_session_post_without_csrf_returns_403() {
        let response = send("POST", "/v1/members", &[("cookie", SESSION_COOKIE)]).await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["error"], "CsrfFailed");
    }

    #[tokio::test]
    async fn post_without_session_cookie_passes_csrf() {
        let response = send("POST", "/v1/members", &[]).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cookie_session_post_with_sec_fetch_site_same_origin_passes() {
        let response = send(
            "POST",
            "/v1/members",
            &[("cookie", SESSION_COOKIE), ("sec-fetch-site", "same-origin")],
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cookie_session_post_with_matching_csrf_cookie_and_header_passes() {
        let cookie = format!("{SESSION_COOKIE}; csrf=abc123; other=foo");
        let response = send(
            "POST",
            "/v1/members",
            &[("cookie", cookie.as_str()), ("x-csrf-token", "abc123")],
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cookie_session_post_with_non_matching_csrf_header_returns_403() {
        let cookie = format!("{SESSION_COOKIE}; csrf=abc123");
        let response = send(
            "POST",
            "/v1/members",
            &[("cookie", cookie.as_str()), ("x-csrf-token", "different")],
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn public_join_request_post_bypasses() {
        let response = send("POST", "/v1/join-requests", &[]).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_challenge_post_bypasses() {
        let response = send("POST", "/v1/auth/challenge", &[]).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn bearer_post_without_cookie_or_origin_passes() {
        let response = send(
            "POST",
            "/v1/members",
            &[("authorization", "Bearer eyJabc.def.ghi")],
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn bearer_scheme_is_case_insensitive() {
        let response = send(
            "POST",
            "/v1/members",
            &[("authorization", "bearer eyJabc.def.ghi")],
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn non_bearer_authorization_scheme_on_cookie_session_stays_gated() {
        let response = send(
            "POST",
            "/v1/members",
            &[
                ("authorization", "Basic dXNlcjpwYXNz"),
                ("cookie", SESSION_COOKIE),
            ],
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn holder_accept_post_bypasses() {
        let response = send("POST", "/v1/join-requests/req-123/accept", &[]).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn holder_status_post_bypasses() {
        let response = send("POST", "/v1/join-requests/req-123/status", &[]).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn admin_approve_on_join_mount_stays_gated() {
        let response = send(
            "POST",
            "/v1/join-requests/req-123/approve",
            &[("cookie", SESSION_COOKIE)],
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn session_cookie_detected_among_others() {
        let headers_with = |cookie: &str| {
            let mut headers = HeaderMap::new();
            headers.insert("cookie", cookie.parse().unwrap());
            headers
        };
        assert!(has_session_cookie(&headers_with(
            "foo=1; vtc_admin_session=abc; csrf=t"
        )));
        assert!(!has_session_cookie(&headers_with("csrf=t; other_session=abc")));
    }
}