// dbrest_core/auth/middleware.rs

//! Axum auth middleware.
//!
//! Extracts the `Authorization: Bearer <token>` header, validates the JWT,
//! caches the result, and inserts an [`AuthResult`] into the request
//! extensions for downstream handlers.
//!
//! # Flow
//!
//! 1. Extract the token from the `Authorization` header. A header that uses a
//!    different scheme (e.g. `Basic`) is treated as if no token was sent.
//! 2. If there is no token and an anonymous role is configured → anonymous access.
//! 3. If there is no token and no anonymous role → 401 (`DBRST302`).
//! 4. Check the JWT cache for a previous validation result.
//! 5. On a cache miss, validate via [`jwt::parse_and_validate`].
//! 6. Store the result in the cache and attach it to the request extensions.
//!
//! # Error Response
//!
//! JWT errors produce a JSON error body with the appropriate DBRST error
//! code, plus a `WWW-Authenticate` header whenever the error supplies a
//! challenge (as the 401 errors do).

use std::future::Future;
use std::sync::Arc;

use arc_swap::ArcSwap;
use axum::{
    extract::Request,
    middleware::Next,
    response::{IntoResponse, Response},
};
use http::header;

use crate::config::AppConfig;
use crate::error::response::ErrorResponse;

use super::cache::JwtCache;
use super::error::JwtError;
use super::jwt;
use super::types::AuthResult;

39/// Shared authentication state passed to the middleware via axum `State`.
40///
41/// Contains the config and JWT cache. Cloned per-request (cheap — all
42/// fields are `Arc` or `Clone`-cheap).
43///
44/// The `config` field is an `ArcSwap` so that live config reloads
45/// (triggered via `NOTIFY dbrst, 'reload config'`) are visible to the
46/// auth middleware without restarting the server.
47#[derive(Debug, Clone)]
48pub struct AuthState {
49    pub config: Arc<ArcSwap<AppConfig>>,
50    pub cache: JwtCache,
51}
52
53impl AuthState {
54    /// Create a new `AuthState` wrapping the given config snapshot.
55    ///
56    /// The config is placed inside a fresh `ArcSwap`. Use
57    /// [`with_shared_config`](Self::with_shared_config) to share the
58    /// same `ArcSwap` with `AppState` for live-reload support.
59    pub fn new(config: Arc<AppConfig>) -> Self {
60        let max_entries = config.jwt_cache_max_entries;
61        Self {
62            config: Arc::new(ArcSwap::new(config)),
63            cache: JwtCache::new(max_entries),
64        }
65    }
66
67    /// Create an `AuthState` that shares an existing `ArcSwap<AppConfig>`.
68    ///
69    /// When the `ArcSwap` is updated (e.g. during config reload), the
70    /// auth middleware automatically sees the new values.
71    pub fn with_shared_config(config: Arc<ArcSwap<AppConfig>>) -> Self {
72        let max_entries = config.load().jwt_cache_max_entries;
73        Self {
74            config,
75            cache: JwtCache::new(max_entries),
76        }
77    }
78
79    /// Get a snapshot of the current config.
80    pub fn load_config(&self) -> arc_swap::Guard<Arc<AppConfig>> {
81        self.config.load()
82    }
83}
84
85/// Axum middleware function for JWT authentication.
86///
87/// Attach to a router via `axum::middleware::from_fn_with_state`:
88///
89/// ```ignore
90/// use axum::{Router, middleware};
91/// use dbrest::auth::middleware::{auth_middleware, AuthState};
92///
93/// let state = AuthState::new(config.into());
94/// let app = Router::new()
95///     .route("/items", get(handler))
96///     .layer(middleware::from_fn_with_state(state, auth_middleware));
97/// ```
98pub async fn auth_middleware(
99    axum::extract::State(state): axum::extract::State<AuthState>,
100    mut request: Request,
101    next: Next,
102) -> Response {
103    match authenticate(&state, &request).await {
104        Ok(auth_result) => {
105            request.extensions_mut().insert(auth_result);
106            next.run(request).await
107        }
108        Err(jwt_err) => jwt_error_response(jwt_err),
109    }
110}
111
112/// Core authentication logic, separated for testability.
113pub async fn authenticate(state: &AuthState, request: &Request) -> Result<AuthResult, JwtError> {
114    let token = extract_bearer_token(request);
115    authenticate_token(state, token).await
116}
117
118/// Authenticate with an already-extracted token string.
119///
120/// This variant avoids borrowing the `Request` across await points, making
121/// it usable in contexts where the `Request` body is not `Sync`.
122#[tracing::instrument(name = "authenticate", skip_all)]
123pub async fn authenticate_token(
124    state: &AuthState,
125    token: Option<&str>,
126) -> Result<AuthResult, JwtError> {
127    let config = state.load_config();
128    match token {
129        Some(token) => {
130            // Check cache first
131            if let Some(cached) = state.cache.get(token).await {
132                metrics::counter!("jwt.cache.hit.total").increment(1);
133                return Ok((*cached).clone());
134            }
135
136            // Validate
137            let result = jwt::parse_and_validate(token, &config)?;
138
139            // Cache the result
140            state.cache.insert(token, result.clone()).await;
141            metrics::counter!("jwt.cache.miss.total").increment(1);
142
143            Ok(result)
144        }
145        None => {
146            // No token — check anonymous role
147            if let Some(ref anon_role) = config.db_anon_role {
148                Ok(AuthResult::anonymous(anon_role))
149            } else {
150                Err(JwtError::TokenRequired)
151            }
152        }
153    }
154}
155
156/// Extract the Bearer token from the `Authorization` header.
157///
158/// Returns `None` if no `Authorization` header is present.
159/// Returns `Some("")` if the header is `Bearer ` with an empty token,
160/// which is then caught by `parse_and_validate` as `EmptyAuthHeader`.
161fn extract_bearer_token(request: &Request) -> Option<&str> {
162    let header_value = request.headers().get(header::AUTHORIZATION)?;
163    let header_str = header_value.to_str().ok()?;
164
165    if let Some(token) = header_str.strip_prefix("Bearer ") {
166        Some(token)
167    } else if let Some(token) = header_str.strip_prefix("bearer ") {
168        Some(token)
169    } else {
170        // Not a Bearer token — ignore (might be Basic auth etc.)
171        None
172    }
173}
174
175/// Build an HTTP error response from a JWT error.
176pub fn jwt_error_response(err: JwtError) -> Response {
177    tracing::warn!(
178        error_code = err.code(),
179        http_status = err.status().as_u16(),
180        "Auth rejected: {}",
181        err
182    );
183
184    let status = err.status();
185    let www_auth = err.www_authenticate();
186
187    let body = ErrorResponse {
188        code: err.code(),
189        message: err.to_string(),
190        details: err.details(),
191        hint: None,
192    };
193
194    let mut response = (status, axum::Json(body)).into_response();
195
196    if let Some(www_auth_value) = www_auth
197        && let Ok(hv) = http::HeaderValue::from_str(&www_auth_value)
198    {
199        response.headers_mut().insert(header::WWW_AUTHENTICATE, hv);
200    }
201
202    response
203}
204
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use jsonwebtoken::{EncodingKey, Header as JwtHeader};

    use super::super::error::JwtDecodeError;

    /// HS256 secret used by the happy-path tests.
    const STRONG_SECRET: &str = "a]gq@2Yr4wLvA#_6!qnMb*X^tbP$I@av";

    /// Build an `AuthState` with `secret`, an optional anonymous role and a
    /// 100-entry JWT cache.
    fn state_with(secret: &str, anon_role: Option<&str>) -> AuthState {
        let mut config = AppConfig::default();
        config.jwt_secret = Some(secret.to_string());
        config.db_anon_role = anon_role.map(str::to_string);
        config.jwt_cache_max_entries = 100;
        AuthState::new(Arc::new(config))
    }

    /// State whose anonymous role is `web_anon`.
    fn test_state(secret: &str) -> AuthState {
        state_with(secret, Some("web_anon"))
    }

    /// State with no anonymous role configured.
    fn test_state_no_anon(secret: &str) -> AuthState {
        state_with(secret, None)
    }

    /// Sign `claims` with the default `jsonwebtoken` header and `secret`.
    fn encode_token(claims: &serde_json::Value, secret: &str) -> String {
        let key = EncodingKey::from_secret(secret.as_bytes());
        jsonwebtoken::encode(&JwtHeader::default(), claims, &key).unwrap()
    }

    /// Claims for `role` whose `exp` lies `offset_secs` from now.
    fn claims_for(role: &str, offset_secs: i64) -> serde_json::Value {
        serde_json::json!({
            "role": role,
            "exp": chrono::Utc::now().timestamp() + offset_secs
        })
    }

    /// GET `uri`, with the raw `Authorization` header value if one is given.
    fn request_with_auth(uri: &str, auth: Option<&str>) -> Request {
        let builder = Request::builder().method("GET").uri(uri);
        let builder = match auth {
            Some(value) => builder.header("Authorization", value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// GET `/items`, optionally with `Authorization: Bearer <token>`.
    fn make_request(token: Option<&str>) -> Request {
        let auth_value = token.map(|t| format!("Bearer {t}"));
        request_with_auth("/items", auth_value.as_deref())
    }

    #[tokio::test]
    async fn test_authenticate_valid_token() {
        let state = test_state(STRONG_SECRET);
        let token = encode_token(&claims_for("test_author", 3600), STRONG_SECRET);
        let request = make_request(Some(&token));

        let result = authenticate(&state, &request).await.unwrap();
        assert_eq!(result.role.as_str(), "test_author");
        assert!(!result.is_anonymous());
    }

    #[tokio::test]
    async fn test_authenticate_anonymous() {
        let state = test_state("secret");
        let request = make_request(None);

        let result = authenticate(&state, &request).await.unwrap();
        assert_eq!(result.role.as_str(), "web_anon");
        assert!(result.is_anonymous());
    }

    #[tokio::test]
    async fn test_authenticate_no_anon_no_token() {
        let state = test_state_no_anon("secret");
        let request = make_request(None);

        let outcome = authenticate(&state, &request).await;
        assert!(matches!(outcome, Err(JwtError::TokenRequired)));
    }

    #[tokio::test]
    async fn test_authenticate_expired_token() {
        let state = test_state(STRONG_SECRET);
        let token = encode_token(&claims_for("test_author", -60), STRONG_SECRET);
        let request = make_request(Some(&token));

        let outcome = authenticate(&state, &request).await;
        assert!(matches!(outcome, Err(JwtError::Claims(_))));
    }

    #[tokio::test]
    async fn test_authenticate_wrong_secret() {
        let state = test_state("correct_secret_is_long_enough");
        let token = encode_token(
            &claims_for("test_author", 3600),
            "wrong_secret_value_different",
        );
        let request = make_request(Some(&token));

        let outcome = authenticate(&state, &request).await;
        assert!(matches!(outcome, Err(JwtError::Decode(_))));
    }

    #[tokio::test]
    async fn test_authenticate_cache_hit() {
        let state = test_state(STRONG_SECRET);
        let token = encode_token(&claims_for("cached_role", 3600), STRONG_SECRET);

        // First pass is a cache miss, second pass a cache hit; both must agree.
        for _ in 0..2 {
            let request = make_request(Some(&token));
            let result = authenticate(&state, &request).await.unwrap();
            assert_eq!(result.role.as_str(), "cached_role");
        }

        // The validated token must now be in the cache.
        assert!(state.cache.get(&token).await.is_some());
    }

    #[tokio::test]
    async fn test_authenticate_empty_bearer() {
        let state = test_state("secret");
        let request = request_with_auth("/items", Some("Bearer "));

        let outcome = authenticate(&state, &request).await;
        assert!(matches!(
            outcome,
            Err(JwtError::Decode(JwtDecodeError::EmptyAuthHeader))
        ));
    }

    #[test]
    fn test_extract_bearer_token() {
        let req = make_request(Some("abc123"));
        assert_eq!(extract_bearer_token(&req), Some("abc123"));

        let req = make_request(None);
        assert!(extract_bearer_token(&req).is_none());

        // The lowercase `bearer` scheme is accepted as well.
        let req = request_with_auth("/", Some("bearer mytoken"));
        assert_eq!(extract_bearer_token(&req), Some("mytoken"));

        // Other schemes (here Basic auth) yield no token.
        let req = request_with_auth("/", Some("Basic dXNlcjpwYXNz"));
        assert!(extract_bearer_token(&req).is_none());
    }

    #[test]
    fn test_jwt_error_response_has_www_authenticate() {
        let response = jwt_error_response(JwtError::TokenRequired);
        assert_eq!(response.status(), http::StatusCode::UNAUTHORIZED);

        let headers = response.headers();
        assert!(headers.contains_key(header::WWW_AUTHENTICATE));
        assert_eq!(headers.get(header::WWW_AUTHENTICATE).unwrap(), "Bearer");
    }

    #[test]
    fn test_jwt_error_response_decode() {
        let response = jwt_error_response(JwtError::Decode(JwtDecodeError::BadCrypto));
        assert_eq!(response.status(), http::StatusCode::UNAUTHORIZED);

        let challenge = response.headers()[header::WWW_AUTHENTICATE]
            .to_str()
            .unwrap();
        assert!(challenge.contains("invalid_token"));
    }

    #[test]
    fn test_jwt_error_response_secret_missing() {
        let response = jwt_error_response(JwtError::SecretMissing);
        assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
        assert!(!response.headers().contains_key(header::WWW_AUTHENTICATE));
    }

    #[test]
    fn test_shared_config_swap_propagates() {
        let swap = Arc::new(ArcSwap::new(Arc::new(AppConfig::default())));
        let auth = AuthState::with_shared_config(Arc::clone(&swap));

        // Initial config (default port).
        assert_eq!(auth.load_config().server_port, 3000);

        // Publish a new config through the shared swap.
        let mut replacement = AppConfig::default();
        replacement.server_port = 9999;
        swap.store(Arc::new(replacement));

        // The auth state observes it immediately.
        assert_eq!(auth.load_config().server_port, 9999);
    }

    #[test]
    fn test_new_constructor_creates_isolated_swap() {
        let auth = AuthState::new(Arc::new(AppConfig::default()));
        assert_eq!(auth.load_config().server_port, 3000);
    }
}