Skip to main content

zvault_server/
middleware.rs

1//! Authentication middleware for `ZVault`.
2//!
3//! Extracts the `X-Vault-Token` header, validates it against the token store,
4//! and injects the token entry into the request extensions for downstream
5//! handlers to use for policy checks.
6
7use std::sync::Arc;
8
9use axum::extract::{Request, State};
10use axum::http::StatusCode;
11use axum::middleware::Next;
12use axum::response::{IntoResponse, Response};
13
14use crate::state::AppState;
15
/// Per-request authentication details.
///
/// `auth_middleware` builds one of these from the token-store entry of a
/// successfully validated token and inserts it into the request extensions,
/// where downstream handlers can read it.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// Hash of the presented token; intended for audit logging.
    pub token_hash: String,
    /// Names of the policies attached to the token.
    pub policies: Vec<String>,
    /// Display name associated with the token, for audit output.
    pub display_name: String,
}
26
27/// Middleware that validates the `X-Vault-Token` header.
28///
29/// Skips auth for health and seal-status endpoints.
30pub async fn auth_middleware(
31    State(state): State<Arc<AppState>>,
32    mut req: Request,
33    next: Next,
34) -> Response {
35    let path = req.uri().path().to_owned();
36
37    // Skip auth for public endpoints.
38    if path == "/v1/sys/health"
39        || path == "/v1/sys/seal-status"
40        || path == "/v1/sys/init"
41        || path == "/v1/sys/unseal"
42        || path.starts_with("/app/")
43        || path == "/app"
44        || path == "/"
45    {
46        return next.run(req).await;
47    }
48
49    let token = req
50        .headers()
51        .get("X-Vault-Token")
52        .and_then(|v| v.to_str().ok())
53        .map(String::from);
54
55    let Some(token) = token else {
56        return (
57            StatusCode::UNAUTHORIZED,
58            axum::Json(serde_json::json!({"error": "unauthorized", "message": "missing X-Vault-Token header"})),
59        ).into_response();
60    };
61
62    match state.token_store.lookup(&token).await {
63        Ok(entry) => {
64            let ctx = AuthContext {
65                token_hash: entry.token_hash.clone(),
66                policies: entry.policies.clone(),
67                display_name: entry.display_name.clone(),
68            };
69            req.extensions_mut().insert(ctx);
70            next.run(req).await
71        }
72        Err(_) => (
73            StatusCode::UNAUTHORIZED,
74            axum::Json(serde_json::json!({"error": "unauthorized", "message": "invalid or expired token"})),
75        ).into_response(),
76    }
77}