dbrest_core/auth/
middleware.rs1use std::sync::Arc;
22
23use arc_swap::ArcSwap;
24use axum::{
25 extract::Request,
26 middleware::Next,
27 response::{IntoResponse, Response},
28};
29use http::header;
30
31use crate::config::AppConfig;
32use crate::error::response::ErrorResponse;
33
34use super::cache::JwtCache;
35use super::error::JwtError;
36use super::jwt;
37use super::types::AuthResult;
38
39#[derive(Debug, Clone)]
48pub struct AuthState {
49 pub config: Arc<ArcSwap<AppConfig>>,
50 pub cache: JwtCache,
51}
52
53impl AuthState {
54 pub fn new(config: Arc<AppConfig>) -> Self {
60 let max_entries = config.jwt_cache_max_entries;
61 Self {
62 config: Arc::new(ArcSwap::new(config)),
63 cache: JwtCache::new(max_entries),
64 }
65 }
66
67 pub fn with_shared_config(config: Arc<ArcSwap<AppConfig>>) -> Self {
72 let max_entries = config.load().jwt_cache_max_entries;
73 Self {
74 config,
75 cache: JwtCache::new(max_entries),
76 }
77 }
78
79 pub fn load_config(&self) -> arc_swap::Guard<Arc<AppConfig>> {
81 self.config.load()
82 }
83}
84
85pub async fn auth_middleware(
99 axum::extract::State(state): axum::extract::State<AuthState>,
100 mut request: Request,
101 next: Next,
102) -> Response {
103 match authenticate(&state, &request).await {
104 Ok(auth_result) => {
105 request.extensions_mut().insert(auth_result);
106 next.run(request).await
107 }
108 Err(jwt_err) => jwt_error_response(jwt_err),
109 }
110}
111
112pub async fn authenticate(state: &AuthState, request: &Request) -> Result<AuthResult, JwtError> {
114 let token = extract_bearer_token(request);
115 authenticate_token(state, token).await
116}
117
118#[tracing::instrument(name = "authenticate", skip_all)]
123pub async fn authenticate_token(
124 state: &AuthState,
125 token: Option<&str>,
126) -> Result<AuthResult, JwtError> {
127 let config = state.load_config();
128 match token {
129 Some(token) => {
130 if let Some(cached) = state.cache.get(token).await {
132 metrics::counter!("jwt.cache.hit.total").increment(1);
133 return Ok((*cached).clone());
134 }
135
136 let result = jwt::parse_and_validate(token, &config)?;
138
139 state.cache.insert(token, result.clone()).await;
141 metrics::counter!("jwt.cache.miss.total").increment(1);
142
143 Ok(result)
144 }
145 None => {
146 if let Some(ref anon_role) = config.db_anon_role {
148 Ok(AuthResult::anonymous(anon_role))
149 } else {
150 Err(JwtError::TokenRequired)
151 }
152 }
153 }
154}
155
156fn extract_bearer_token(request: &Request) -> Option<&str> {
162 let header_value = request.headers().get(header::AUTHORIZATION)?;
163 let header_str = header_value.to_str().ok()?;
164
165 if let Some(token) = header_str.strip_prefix("Bearer ") {
166 Some(token)
167 } else if let Some(token) = header_str.strip_prefix("bearer ") {
168 Some(token)
169 } else {
170 None
172 }
173}
174
175pub fn jwt_error_response(err: JwtError) -> Response {
177 tracing::warn!(
178 error_code = err.code(),
179 http_status = err.status().as_u16(),
180 "Auth rejected: {}",
181 err
182 );
183
184 let status = err.status();
185 let www_auth = err.www_authenticate();
186
187 let body = ErrorResponse {
188 code: err.code(),
189 message: err.to_string(),
190 details: err.details(),
191 hint: None,
192 };
193
194 let mut response = (status, axum::Json(body)).into_response();
195
196 if let Some(www_auth_value) = www_auth
197 && let Ok(hv) = http::HeaderValue::from_str(&www_auth_value)
198 {
199 response.headers_mut().insert(header::WWW_AUTHENTICATE, hv);
200 }
201
202 response
203}
204
#[cfg(test)]
mod tests {
    use super::super::error::JwtDecodeError;
    use super::*;
    use axum::body::Body;
    use jsonwebtoken::{EncodingKey, Header as JwtHeader};

    /// Secret used by the tests that need a token the state will accept.
    const STRONG_SECRET: &str = "a]gq@2Yr4wLvA#_6!qnMb*X^tbP$I@av";

    /// Builds an [`AuthState`] with `secret` as the JWT secret, `anon_role` as the
    /// anonymous role, and a 100-entry token cache.
    fn state_with(secret: &str, anon_role: Option<&str>) -> AuthState {
        let mut config = AppConfig::default();
        config.jwt_secret = Some(secret.to_string());
        config.db_anon_role = anon_role.map(String::from);
        config.jwt_cache_max_entries = 100;
        AuthState::new(Arc::new(config))
    }

    /// State that permits anonymous access as `web_anon`.
    fn state_with_anon(secret: &str) -> AuthState {
        state_with(secret, Some("web_anon"))
    }

    /// Claims for `role` expiring `ttl_secs` seconds from now (negative = already expired).
    fn role_claims(role: &str, ttl_secs: i64) -> serde_json::Value {
        serde_json::json!({
            "role": role,
            "exp": chrono::Utc::now().timestamp() + ttl_secs
        })
    }

    /// Signs `claims` with the default JWT header and `secret`.
    fn sign(claims: &serde_json::Value, secret: &str) -> String {
        let key = EncodingKey::from_secret(secret.as_bytes());
        jsonwebtoken::encode(&JwtHeader::default(), claims, &key).unwrap()
    }

    /// `GET uri`, carrying the given raw `Authorization` header value when present.
    fn request_with_auth(uri: &str, authorization: Option<&str>) -> Request {
        let builder = Request::builder().method("GET").uri(uri);
        let builder = match authorization {
            Some(value) => builder.header("Authorization", value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// `GET /items`, carrying `Authorization: Bearer <token>` when a token is given.
    fn bearer_request(token: Option<&str>) -> Request {
        let authorization = token.map(|t| format!("Bearer {t}"));
        request_with_auth("/items", authorization.as_deref())
    }

    #[tokio::test]
    async fn test_authenticate_valid_token() {
        let state = state_with_anon(STRONG_SECRET);
        let token = sign(&role_claims("test_author", 3600), STRONG_SECRET);
        let request = bearer_request(Some(&token));

        let result = authenticate(&state, &request).await.unwrap();
        assert_eq!(result.role.as_str(), "test_author");
        assert!(!result.is_anonymous());
    }

    #[tokio::test]
    async fn test_authenticate_anonymous() {
        let state = state_with_anon("secret");
        let request = bearer_request(None);

        let result = authenticate(&state, &request).await.unwrap();
        assert_eq!(result.role.as_str(), "web_anon");
        assert!(result.is_anonymous());
    }

    #[tokio::test]
    async fn test_authenticate_no_anon_no_token() {
        let state = state_with("secret", None);
        let request = bearer_request(None);

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(err, JwtError::TokenRequired));
    }

    #[tokio::test]
    async fn test_authenticate_expired_token() {
        let state = state_with_anon(STRONG_SECRET);
        let token = sign(&role_claims("test_author", -60), STRONG_SECRET);
        let request = bearer_request(Some(&token));

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(err, JwtError::Claims(_)));
    }

    #[tokio::test]
    async fn test_authenticate_wrong_secret() {
        let state = state_with_anon("correct_secret_is_long_enough");
        let token = sign(
            &role_claims("test_author", 3600),
            "wrong_secret_value_different",
        );
        let request = bearer_request(Some(&token));

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(err, JwtError::Decode(_)));
    }

    #[tokio::test]
    async fn test_authenticate_cache_hit() {
        let state = state_with_anon(STRONG_SECRET);
        let token = sign(&role_claims("cached_role", 3600), STRONG_SECRET);

        // Authenticate the same token twice; the first call populates the cache.
        for _ in 0..2 {
            let request = bearer_request(Some(&token));
            let result = authenticate(&state, &request).await.unwrap();
            assert_eq!(result.role.as_str(), "cached_role");
        }

        assert!(state.cache.get(&token).await.is_some());
    }

    #[tokio::test]
    async fn test_authenticate_empty_bearer() {
        let state = state_with_anon("secret");
        let request = request_with_auth("/items", Some("Bearer "));

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(
            err,
            JwtError::Decode(JwtDecodeError::EmptyAuthHeader)
        ));
    }

    #[test]
    fn test_extract_bearer_token() {
        let cases: [(Request, Option<&str>); 4] = [
            (bearer_request(Some("abc123")), Some("abc123")),
            (bearer_request(None), None),
            (request_with_auth("/", Some("bearer mytoken")), Some("mytoken")),
            (request_with_auth("/", Some("Basic dXNlcjpwYXNz")), None),
        ];

        for (request, expected) in &cases {
            assert_eq!(extract_bearer_token(request), *expected);
        }
    }

    #[test]
    fn test_jwt_error_response_has_www_authenticate() {
        let response = jwt_error_response(JwtError::TokenRequired);
        assert_eq!(response.status(), http::StatusCode::UNAUTHORIZED);

        let challenge = response.headers().get(header::WWW_AUTHENTICATE);
        assert!(challenge.is_some());
        assert_eq!(challenge.unwrap(), "Bearer");
    }

    #[test]
    fn test_jwt_error_response_decode() {
        let response = jwt_error_response(JwtError::Decode(JwtDecodeError::BadCrypto));
        assert_eq!(response.status(), http::StatusCode::UNAUTHORIZED);

        let challenge = response
            .headers()
            .get(header::WWW_AUTHENTICATE)
            .and_then(|value| value.to_str().ok())
            .expect("decode errors should carry a readable WWW-Authenticate challenge");
        assert!(challenge.contains("invalid_token"));
    }

    #[test]
    fn test_jwt_error_response_secret_missing() {
        let response = jwt_error_response(JwtError::SecretMissing);
        assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
        assert!(!response.headers().contains_key(header::WWW_AUTHENTICATE));
    }

    #[test]
    fn test_shared_config_swap_propagates() {
        let swap = Arc::new(ArcSwap::new(Arc::new(AppConfig::default())));
        let auth = AuthState::with_shared_config(Arc::clone(&swap));
        assert_eq!(auth.load_config().server_port, 3000);

        let mut replacement = AppConfig::default();
        replacement.server_port = 9999;
        swap.store(Arc::new(replacement));

        assert_eq!(auth.load_config().server_port, 9999);
    }

    #[test]
    fn test_new_constructor_creates_isolated_swap() {
        let auth = AuthState::new(Arc::new(AppConfig::default()));
        assert_eq!(auth.load_config().server_port, 3000);
    }
}