Skip to main content

dbrest_core/auth/
middleware.rs

//! Axum auth middleware
//!
//! Extracts the `Authorization: Bearer <token>` header, validates the JWT,
//! caches the result, and inserts an [`AuthResult`] into the request
//! extensions for downstream handlers.
//!
//! # Flow
//!
//! 1. Extract the token from the `Authorization` header. A header that uses a
//!    different scheme (e.g. `Basic`) is treated as if no token were sent.
//! 2. If there is no token and an anonymous role is configured → anonymous access.
//! 3. If there is no token and no anonymous role is configured → 401 (`DBRST302`).
//! 4. Check the JWT cache for a previous validation result.
//! 5. On a cache miss, validate via [`jwt::parse_and_validate`].
//! 6. Cache a successful result and attach it to the request extensions.
//!
//! # Error Response
//!
//! JWT errors produce a JSON error body with the appropriate DBRST error
//! code, plus a `WWW-Authenticate` header when the status is 401.
20
21use std::sync::Arc;
22
23use arc_swap::ArcSwap;
24use axum::{
25    extract::Request,
26    middleware::Next,
27    response::{IntoResponse, Response},
28};
29use http::header;
30
31use crate::config::AppConfig;
32use crate::error::response::ErrorResponse;
33
34use super::cache::JwtCache;
35use super::error::JwtError;
36use super::jwt;
37use super::types::AuthResult;
38
39/// Shared authentication state passed to the middleware via axum `State`.
40///
41/// Contains the config and JWT cache. Cloned per-request (cheap — all
42/// fields are `Arc` or `Clone`-cheap).
43///
44/// The `config` field is an `ArcSwap` so that live config reloads
45/// (triggered via `NOTIFY dbrst, 'reload config'`) are visible to the
46/// auth middleware without restarting the server.
47#[derive(Debug, Clone)]
48pub struct AuthState {
49    pub config: Arc<ArcSwap<AppConfig>>,
50    pub cache: JwtCache,
51}
52
53impl AuthState {
54    /// Create a new `AuthState` wrapping the given config snapshot.
55    ///
56    /// The config is placed inside a fresh `ArcSwap`. Use
57    /// [`with_shared_config`](Self::with_shared_config) to share the
58    /// same `ArcSwap` with `AppState` for live-reload support.
59    pub fn new(config: Arc<AppConfig>) -> Self {
60        let max_entries = config.jwt_cache_max_entries;
61        Self {
62            config: Arc::new(ArcSwap::new(config)),
63            cache: JwtCache::new(max_entries),
64        }
65    }
66
67    /// Create an `AuthState` that shares an existing `ArcSwap<AppConfig>`.
68    ///
69    /// When the `ArcSwap` is updated (e.g. during config reload), the
70    /// auth middleware automatically sees the new values.
71    pub fn with_shared_config(config: Arc<ArcSwap<AppConfig>>) -> Self {
72        let max_entries = config.load().jwt_cache_max_entries;
73        Self {
74            config,
75            cache: JwtCache::new(max_entries),
76        }
77    }
78
79    /// Get a snapshot of the current config.
80    pub fn load_config(&self) -> arc_swap::Guard<Arc<AppConfig>> {
81        self.config.load()
82    }
83}
84
85/// Axum middleware function for JWT authentication.
86///
87/// Attach to a router via `axum::middleware::from_fn_with_state`:
88///
89/// ```ignore
90/// use axum::{Router, middleware};
91/// use dbrest::auth::middleware::{auth_middleware, AuthState};
92///
93/// let state = AuthState::new(config.into());
94/// let app = Router::new()
95///     .route("/items", get(handler))
96///     .layer(middleware::from_fn_with_state(state, auth_middleware));
97/// ```
98pub async fn auth_middleware(
99    axum::extract::State(state): axum::extract::State<AuthState>,
100    mut request: Request,
101    next: Next,
102) -> Response {
103    match authenticate(&state, &request).await {
104        Ok(auth_result) => {
105            request.extensions_mut().insert(auth_result);
106            next.run(request).await
107        }
108        Err(jwt_err) => jwt_error_response(jwt_err),
109    }
110}
111
112/// Core authentication logic, separated for testability.
113pub async fn authenticate(state: &AuthState, request: &Request) -> Result<AuthResult, JwtError> {
114    let token = extract_bearer_token(request);
115    authenticate_token(state, token).await
116}
117
118/// Authenticate with an already-extracted token string.
119///
120/// This variant avoids borrowing the `Request` across await points, making
121/// it usable in contexts where the `Request` body is not `Sync`.
122pub async fn authenticate_token(
123    state: &AuthState,
124    token: Option<&str>,
125) -> Result<AuthResult, JwtError> {
126    let config = state.load_config();
127    match token {
128        Some(token) => {
129            // Check cache first
130            if let Some(cached) = state.cache.get(token).await {
131                return Ok((*cached).clone());
132            }
133
134            // Validate
135            let result = jwt::parse_and_validate(token, &config)?;
136
137            // Cache the result
138            state.cache.insert(token, result.clone()).await;
139
140            Ok(result)
141        }
142        None => {
143            // No token — check anonymous role
144            if let Some(ref anon_role) = config.db_anon_role {
145                Ok(AuthResult::anonymous(anon_role))
146            } else {
147                Err(JwtError::TokenRequired)
148            }
149        }
150    }
151}
152
153/// Extract the Bearer token from the `Authorization` header.
154///
155/// Returns `None` if no `Authorization` header is present.
156/// Returns `Some("")` if the header is `Bearer ` with an empty token,
157/// which is then caught by `parse_and_validate` as `EmptyAuthHeader`.
158fn extract_bearer_token(request: &Request) -> Option<&str> {
159    let header_value = request.headers().get(header::AUTHORIZATION)?;
160    let header_str = header_value.to_str().ok()?;
161
162    if let Some(token) = header_str.strip_prefix("Bearer ") {
163        Some(token)
164    } else if let Some(token) = header_str.strip_prefix("bearer ") {
165        Some(token)
166    } else {
167        // Not a Bearer token — ignore (might be Basic auth etc.)
168        None
169    }
170}
171
172/// Build an HTTP error response from a JWT error.
173pub fn jwt_error_response(err: JwtError) -> Response {
174    let status = err.status();
175    let www_auth = err.www_authenticate();
176
177    let body = ErrorResponse {
178        code: err.code(),
179        message: err.to_string(),
180        details: err.details(),
181        hint: None,
182    };
183
184    let mut response = (status, axum::Json(body)).into_response();
185
186    if let Some(www_auth_value) = www_auth
187        && let Ok(hv) = http::HeaderValue::from_str(&www_auth_value)
188    {
189        response.headers_mut().insert(header::WWW_AUTHENTICATE, hv);
190    }
191
192    response
193}
194
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

// Unit tests covering:
// - token extraction;
// - `authenticate` (valid, anonymous, missing, expired, wrong-secret, cached
//   and empty-bearer cases);
// - error-response construction;
// - `AuthState` config sharing.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use jsonwebtoken::{EncodingKey, Header as JwtHeader};

    /// `AuthState` with `secret` as `jwt_secret`, `web_anon` as the anonymous
    /// role, and a 100-entry JWT cache.
    fn test_state(secret: &str) -> AuthState {
        let mut config = AppConfig::default();
        config.jwt_secret = Some(secret.to_string());
        config.db_anon_role = Some("web_anon".to_string());
        config.jwt_cache_max_entries = 100;
        AuthState::new(Arc::new(config))
    }

    /// Like `test_state`, but with no anonymous role, so token-less requests
    /// must be rejected.
    fn test_state_no_anon(secret: &str) -> AuthState {
        let mut config = AppConfig::default();
        config.jwt_secret = Some(secret.to_string());
        config.db_anon_role = None;
        config.jwt_cache_max_entries = 100;
        AuthState::new(Arc::new(config))
    }

    /// Sign `claims` with `secret` using jsonwebtoken's default header (HS256).
    fn encode_token(claims: &serde_json::Value, secret: &str) -> String {
        jsonwebtoken::encode(
            &JwtHeader::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    /// `GET /items`, with `Authorization: Bearer <token>` when `token` is `Some`.
    fn make_request(token: Option<&str>) -> Request {
        let mut builder = Request::builder().method("GET").uri("/items");
        if let Some(t) = token {
            builder = builder.header("Authorization", format!("Bearer {t}"));
        }
        builder.body(Body::empty()).unwrap()
    }

    // A correctly signed, unexpired token authenticates as its `role` claim.
    #[tokio::test]
    async fn test_authenticate_valid_token() {
        let secret = "a]gq@2Yr4wLvA#_6!qnMb*X^tbP$I@av";
        let state = test_state(secret);
        let claims = serde_json::json!({
            "role": "test_author",
            "exp": chrono::Utc::now().timestamp() + 3600
        });
        let token = encode_token(&claims, secret);
        let request = make_request(Some(&token));

        let result = authenticate(&state, &request).await.unwrap();
        assert_eq!(result.role.as_str(), "test_author");
        assert!(!result.is_anonymous());
    }

    // No token + configured anonymous role -> anonymous `web_anon`.
    #[tokio::test]
    async fn test_authenticate_anonymous() {
        let state = test_state("secret");
        let request = make_request(None);

        let result = authenticate(&state, &request).await.unwrap();
        assert_eq!(result.role.as_str(), "web_anon");
        assert!(result.is_anonymous());
    }

    // No token + no anonymous role -> `TokenRequired`.
    #[tokio::test]
    async fn test_authenticate_no_anon_no_token() {
        let state = test_state_no_anon("secret");
        let request = make_request(None);

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(err, JwtError::TokenRequired));
    }

    // `exp` in the past -> claims validation error.
    #[tokio::test]
    async fn test_authenticate_expired_token() {
        let secret = "a]gq@2Yr4wLvA#_6!qnMb*X^tbP$I@av";
        let state = test_state(secret);
        let claims = serde_json::json!({
            "role": "test_author",
            "exp": chrono::Utc::now().timestamp() - 60
        });
        let token = encode_token(&claims, secret);
        let request = make_request(Some(&token));

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(err, JwtError::Claims(_)));
    }

    // Token signed with a different secret -> decode error.
    #[tokio::test]
    async fn test_authenticate_wrong_secret() {
        let state = test_state("correct_secret_is_long_enough");
        let claims = serde_json::json!({
            "role": "test_author",
            "exp": chrono::Utc::now().timestamp() + 3600
        });
        let token = encode_token(&claims, "wrong_secret_value_different");
        let request = make_request(Some(&token));

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(err, JwtError::Decode(_)));
    }

    // The first call validates and populates the cache; the second is served
    // from it. Statement order matters here (miss, then hit).
    #[tokio::test]
    async fn test_authenticate_cache_hit() {
        let secret = "a]gq@2Yr4wLvA#_6!qnMb*X^tbP$I@av";
        let state = test_state(secret);
        let claims = serde_json::json!({
            "role": "cached_role",
            "exp": chrono::Utc::now().timestamp() + 3600
        });
        let token = encode_token(&claims, secret);

        // First request — cache miss
        let request = make_request(Some(&token));
        let result1 = authenticate(&state, &request).await.unwrap();
        assert_eq!(result1.role.as_str(), "cached_role");

        // Second request — cache hit
        let request = make_request(Some(&token));
        let result2 = authenticate(&state, &request).await.unwrap();
        assert_eq!(result2.role.as_str(), "cached_role");

        // Verify cache has the entry
        assert!(state.cache.get(&token).await.is_some());
    }

    // `Bearer ` with nothing after it is an error (`EmptyAuthHeader`), not
    // anonymous access.
    #[tokio::test]
    async fn test_authenticate_empty_bearer() {
        let state = test_state("secret");
        let request = Request::builder()
            .method("GET")
            .uri("/items")
            .header("Authorization", "Bearer ")
            .body(Body::empty())
            .unwrap();

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(
            err,
            JwtError::Decode(super::super::error::JwtDecodeError::EmptyAuthHeader)
        ));
    }

    // Header parsing: present / absent / lowercase scheme / non-Bearer scheme.
    #[test]
    fn test_extract_bearer_token() {
        let req = make_request(Some("abc123"));
        assert_eq!(extract_bearer_token(&req), Some("abc123"));

        let req = make_request(None);
        assert!(extract_bearer_token(&req).is_none());

        // Lowercase scheme name "bearer" is also accepted
        let req = Request::builder()
            .method("GET")
            .uri("/")
            .header("Authorization", "bearer mytoken")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_bearer_token(&req), Some("mytoken"));

        // Basic auth — should return None
        let req = Request::builder()
            .method("GET")
            .uri("/")
            .header("Authorization", "Basic dXNlcjpwYXNz")
            .body(Body::empty())
            .unwrap();
        assert!(extract_bearer_token(&req).is_none());
    }

    // `TokenRequired` -> 401 with a bare `WWW-Authenticate: Bearer` challenge.
    #[test]
    fn test_jwt_error_response_has_www_authenticate() {
        let err = JwtError::TokenRequired;
        let response = jwt_error_response(err);
        assert_eq!(response.status(), http::StatusCode::UNAUTHORIZED);
        assert!(response.headers().contains_key(header::WWW_AUTHENTICATE));
        assert_eq!(
            response.headers().get(header::WWW_AUTHENTICATE).unwrap(),
            "Bearer"
        );
    }

    // Decode failures -> 401 whose challenge mentions `invalid_token`.
    #[test]
    fn test_jwt_error_response_decode() {
        let err = JwtError::Decode(super::super::error::JwtDecodeError::BadCrypto);
        let response = jwt_error_response(err);
        assert_eq!(response.status(), http::StatusCode::UNAUTHORIZED);
        let www = response
            .headers()
            .get(header::WWW_AUTHENTICATE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(www.contains("invalid_token"));
    }

    // `SecretMissing` -> 500 with no `WWW-Authenticate` challenge.
    #[test]
    fn test_jwt_error_response_secret_missing() {
        let err = JwtError::SecretMissing;
        let response = jwt_error_response(err);
        assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
        assert!(!response.headers().contains_key(header::WWW_AUTHENTICATE));
    }

    // A `store` into the shared `ArcSwap` is visible through `load_config`.
    // Relies on `AppConfig::default()` using port 3000.
    #[test]
    fn test_shared_config_swap_propagates() {
        let config = AppConfig::default();
        let swap = Arc::new(ArcSwap::new(Arc::new(config)));
        let auth = AuthState::with_shared_config(swap.clone());

        // Initial config
        assert_eq!(auth.load_config().server_port, 3000);

        // Swap in new config
        let mut new_config = AppConfig::default();
        new_config.server_port = 9999;
        swap.store(Arc::new(new_config));

        // Auth state sees the update immediately
        assert_eq!(auth.load_config().server_port, 9999);
    }

    // `AuthState::new` exposes the config it was given.
    #[test]
    fn test_new_constructor_creates_isolated_swap() {
        let config = AppConfig::default();
        let auth = AuthState::new(Arc::new(config));
        assert_eq!(auth.load_config().server_port, 3000);
    }
}