dbrest_core/auth/
middleware.rs1use std::sync::Arc;
22
23use arc_swap::ArcSwap;
24use axum::{
25 extract::Request,
26 middleware::Next,
27 response::{IntoResponse, Response},
28};
29use http::header;
30
31use crate::config::AppConfig;
32use crate::error::response::ErrorResponse;
33
34use super::cache::JwtCache;
35use super::error::JwtError;
36use super::jwt;
37use super::types::AuthResult;
38
39#[derive(Debug, Clone)]
48pub struct AuthState {
49 pub config: Arc<ArcSwap<AppConfig>>,
50 pub cache: JwtCache,
51}
52
53impl AuthState {
54 pub fn new(config: Arc<AppConfig>) -> Self {
60 let max_entries = config.jwt_cache_max_entries;
61 Self {
62 config: Arc::new(ArcSwap::new(config)),
63 cache: JwtCache::new(max_entries),
64 }
65 }
66
67 pub fn with_shared_config(config: Arc<ArcSwap<AppConfig>>) -> Self {
72 let max_entries = config.load().jwt_cache_max_entries;
73 Self {
74 config,
75 cache: JwtCache::new(max_entries),
76 }
77 }
78
79 pub fn load_config(&self) -> arc_swap::Guard<Arc<AppConfig>> {
81 self.config.load()
82 }
83}
84
85pub async fn auth_middleware(
99 axum::extract::State(state): axum::extract::State<AuthState>,
100 mut request: Request,
101 next: Next,
102) -> Response {
103 match authenticate(&state, &request).await {
104 Ok(auth_result) => {
105 request.extensions_mut().insert(auth_result);
106 next.run(request).await
107 }
108 Err(jwt_err) => jwt_error_response(jwt_err),
109 }
110}
111
112pub async fn authenticate(state: &AuthState, request: &Request) -> Result<AuthResult, JwtError> {
114 let token = extract_bearer_token(request);
115 authenticate_token(state, token).await
116}
117
118pub async fn authenticate_token(
123 state: &AuthState,
124 token: Option<&str>,
125) -> Result<AuthResult, JwtError> {
126 let config = state.load_config();
127 match token {
128 Some(token) => {
129 if let Some(cached) = state.cache.get(token).await {
131 return Ok((*cached).clone());
132 }
133
134 let result = jwt::parse_and_validate(token, &config)?;
136
137 state.cache.insert(token, result.clone()).await;
139
140 Ok(result)
141 }
142 None => {
143 if let Some(ref anon_role) = config.db_anon_role {
145 Ok(AuthResult::anonymous(anon_role))
146 } else {
147 Err(JwtError::TokenRequired)
148 }
149 }
150 }
151}
152
153fn extract_bearer_token(request: &Request) -> Option<&str> {
159 let header_value = request.headers().get(header::AUTHORIZATION)?;
160 let header_str = header_value.to_str().ok()?;
161
162 if let Some(token) = header_str.strip_prefix("Bearer ") {
163 Some(token)
164 } else if let Some(token) = header_str.strip_prefix("bearer ") {
165 Some(token)
166 } else {
167 None
169 }
170}
171
172pub fn jwt_error_response(err: JwtError) -> Response {
174 let status = err.status();
175 let www_auth = err.www_authenticate();
176
177 let body = ErrorResponse {
178 code: err.code(),
179 message: err.to_string(),
180 details: err.details(),
181 hint: None,
182 };
183
184 let mut response = (status, axum::Json(body)).into_response();
185
186 if let Some(www_auth_value) = www_auth
187 && let Ok(hv) = http::HeaderValue::from_str(&www_auth_value)
188 {
189 response.headers_mut().insert(header::WWW_AUTHENTICATE, hv);
190 }
191
192 response
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use jsonwebtoken::{EncodingKey, Header as JwtHeader};

    /// Secret shared by every test that needs a token the state will accept.
    const TEST_SECRET: &str = "a]gq@2Yr4wLvA#_6!qnMb*X^tbP$I@av";

    /// Builds an `AuthState` with the given secret and optional anonymous role.
    fn state_with(secret: &str, anon_role: Option<&str>) -> AuthState {
        let mut config = AppConfig::default();
        config.jwt_secret = Some(secret.to_string());
        config.db_anon_role = anon_role.map(str::to_string);
        config.jwt_cache_max_entries = 100;
        AuthState::new(Arc::new(config))
    }

    /// State whose anonymous role is `web_anon`.
    fn test_state(secret: &str) -> AuthState {
        state_with(secret, Some("web_anon"))
    }

    /// State with no anonymous role, so token-less requests are rejected.
    fn test_state_no_anon(secret: &str) -> AuthState {
        state_with(secret, None)
    }

    /// Claims for `role` expiring `offset_secs` seconds from now (negative = already expired).
    fn claims(role: &str, offset_secs: i64) -> serde_json::Value {
        serde_json::json!({
            "role": role,
            "exp": chrono::Utc::now().timestamp() + offset_secs
        })
    }

    fn encode_token(claims: &serde_json::Value, secret: &str) -> String {
        jsonwebtoken::encode(
            &JwtHeader::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    fn make_request(token: Option<&str>) -> Request {
        let mut builder = Request::builder().method("GET").uri("/items");
        if let Some(t) = token {
            builder = builder.header("Authorization", format!("Bearer {t}"));
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_authenticate_valid_token() {
        let state = test_state(TEST_SECRET);
        let token = encode_token(&claims("test_author", 3600), TEST_SECRET);
        let request = make_request(Some(&token));

        let result = authenticate(&state, &request).await.unwrap();
        assert_eq!(result.role.as_str(), "test_author");
        assert!(!result.is_anonymous());
    }

    #[tokio::test]
    async fn test_authenticate_anonymous() {
        let state = test_state("secret");
        let request = make_request(None);

        let result = authenticate(&state, &request).await.unwrap();
        assert_eq!(result.role.as_str(), "web_anon");
        assert!(result.is_anonymous());
    }

    #[tokio::test]
    async fn test_authenticate_no_anon_no_token() {
        let state = test_state_no_anon("secret");
        let request = make_request(None);

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(err, JwtError::TokenRequired));
    }

    #[tokio::test]
    async fn test_authenticate_expired_token() {
        let state = test_state(TEST_SECRET);
        let token = encode_token(&claims("test_author", -60), TEST_SECRET);
        let request = make_request(Some(&token));

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(err, JwtError::Claims(_)));
    }

    #[tokio::test]
    async fn test_authenticate_wrong_secret() {
        let state = test_state("correct_secret_is_long_enough");
        let token = encode_token(&claims("test_author", 3600), "wrong_secret_value_different");
        let request = make_request(Some(&token));

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(err, JwtError::Decode(_)));
    }

    #[tokio::test]
    async fn test_authenticate_cache_hit() {
        let state = test_state(TEST_SECRET);
        let token = encode_token(&claims("cached_role", 3600), TEST_SECRET);

        // A fresh state must not know the token yet...
        assert!(state.cache.get(&token).await.is_none());

        let request = make_request(Some(&token));
        let result1 = authenticate(&state, &request).await.unwrap();
        assert_eq!(result1.role.as_str(), "cached_role");

        // ...and the first successful validation must populate the cache.
        assert!(state.cache.get(&token).await.is_some());

        let request = make_request(Some(&token));
        let result2 = authenticate(&state, &request).await.unwrap();
        assert_eq!(result2.role.as_str(), "cached_role");
    }

    #[tokio::test]
    async fn test_authenticate_empty_bearer() {
        let state = test_state("secret");
        let request = Request::builder()
            .method("GET")
            .uri("/items")
            .header("Authorization", "Bearer ")
            .body(Body::empty())
            .unwrap();

        let err = authenticate(&state, &request).await.unwrap_err();
        assert!(matches!(
            err,
            JwtError::Decode(super::super::error::JwtDecodeError::EmptyAuthHeader)
        ));
    }

    #[test]
    fn test_extract_bearer_token() {
        let req = make_request(Some("abc123"));
        assert_eq!(extract_bearer_token(&req), Some("abc123"));

        let req = make_request(None);
        assert!(extract_bearer_token(&req).is_none());

        let req = Request::builder()
            .method("GET")
            .uri("/")
            .header("Authorization", "bearer mytoken")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_bearer_token(&req), Some("mytoken"));

        let req = Request::builder()
            .method("GET")
            .uri("/")
            .header("Authorization", "Basic dXNlcjpwYXNz")
            .body(Body::empty())
            .unwrap();
        assert!(extract_bearer_token(&req).is_none());
    }

    #[test]
    fn test_jwt_error_response_has_www_authenticate() {
        let err = JwtError::TokenRequired;
        let response = jwt_error_response(err);
        assert_eq!(response.status(), http::StatusCode::UNAUTHORIZED);
        assert!(response.headers().contains_key(header::WWW_AUTHENTICATE));
        assert_eq!(
            response.headers().get(header::WWW_AUTHENTICATE).unwrap(),
            "Bearer"
        );
    }

    #[test]
    fn test_jwt_error_response_decode() {
        let err = JwtError::Decode(super::super::error::JwtDecodeError::BadCrypto);
        let response = jwt_error_response(err);
        assert_eq!(response.status(), http::StatusCode::UNAUTHORIZED);
        let www = response
            .headers()
            .get(header::WWW_AUTHENTICATE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(www.contains("invalid_token"));
    }

    #[test]
    fn test_jwt_error_response_secret_missing() {
        let err = JwtError::SecretMissing;
        let response = jwt_error_response(err);
        assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
        assert!(!response.headers().contains_key(header::WWW_AUTHENTICATE));
    }

    #[test]
    fn test_shared_config_swap_propagates() {
        let config = AppConfig::default();
        let swap = Arc::new(ArcSwap::new(Arc::new(config)));
        let auth = AuthState::with_shared_config(swap.clone());

        assert_eq!(auth.load_config().server_port, 3000);

        let mut new_config = AppConfig::default();
        new_config.server_port = 9999;
        swap.store(Arc::new(new_config));

        assert_eq!(auth.load_config().server_port, 9999);
    }

    #[test]
    fn test_new_constructor_creates_isolated_swap() {
        let config = AppConfig::default();
        let auth = AuthState::new(Arc::new(config));
        assert_eq!(auth.load_config().server_port, 3000);
    }
}