1use axum::{
6 body::Body,
7 extract::{Request, State},
8 http::{header::AUTHORIZATION, StatusCode},
9 middleware::Next,
10 response::{IntoResponse, Response},
11};
12use std::sync::Arc;
13use tracing::{debug, warn};
14
15use crate::{
16 error::ErrorResponse,
17 jwt::{Claims, JwtManager, TokenError},
18};
19
20#[derive(Debug, Clone)]
22pub struct AuthUser {
23 pub claims: Claims,
25}
26
27impl AuthUser {
28 pub fn new(claims: Claims) -> Self {
30 Self { claims }
31 }
32
33 pub fn user_id(&self) -> &str {
35 &self.claims.sub
36 }
37
38 pub fn email(&self) -> Option<&str> {
40 self.claims.email.as_deref()
41 }
42
43 pub fn has_role(&self, role: &str) -> bool {
45 self.claims.has_role(role)
46 }
47
48 pub fn has_any_role(&self, roles: &[&str]) -> bool {
50 self.claims.has_any_role(roles)
51 }
52
53 pub fn has_all_roles(&self, roles: &[&str]) -> bool {
55 self.claims.has_all_roles(roles)
56 }
57}
58
59#[derive(Clone)]
61pub struct AuthState {
62 jwt_manager: Arc<JwtManager>,
63}
64
65impl AuthState {
66 pub fn new(jwt_manager: JwtManager) -> Self {
68 Self {
69 jwt_manager: Arc::new(jwt_manager),
70 }
71 }
72
73 pub fn jwt_manager(&self) -> &JwtManager {
75 &self.jwt_manager
76 }
77}
78
79pub async fn require_auth(
101 State(auth_state): State<AuthState>,
102 mut request: Request,
103 next: Next,
104) -> Result<Response, AuthError> {
105 debug!("Authenticating request");
106
107 let auth_header = request
109 .headers()
110 .get(AUTHORIZATION)
111 .and_then(|h| h.to_str().ok())
112 .ok_or(AuthError::MissingToken)?;
113
114 let token = JwtManager::extract_token_from_header(auth_header)
116 .map_err(|_| AuthError::InvalidToken)?;
117
118 let claims = auth_state
120 .jwt_manager
121 .validate_token(token)
122 .map_err(|e| match e {
123 TokenError::Expired => AuthError::ExpiredToken,
124 TokenError::NotYetValid => AuthError::InvalidToken,
125 _ => AuthError::InvalidToken,
126 })?;
127
128 debug!("User authenticated: {}", claims.sub);
129
130 request.extensions_mut().insert(AuthUser::new(claims));
132
133 Ok(next.run(request).await)
134}
135
136pub async fn optional_auth(
159 State(auth_state): State<AuthState>,
160 mut request: Request,
161 next: Next,
162) -> Response {
163 debug!("Attempting optional authentication");
164
165 if let Some(auth_header) = request.headers().get(AUTHORIZATION) {
167 if let Ok(header_str) = auth_header.to_str() {
168 if let Ok(token) = JwtManager::extract_token_from_header(header_str) {
169 if let Ok(claims) = auth_state.jwt_manager.validate_token(token) {
170 debug!("User optionally authenticated: {}", claims.sub);
171 request.extensions_mut().insert(AuthUser::new(claims));
172 }
173 }
174 }
175 }
176
177 next.run(request).await
178}
179
180pub async fn require_role(
207 State((auth_state, allowed_roles)): State<(AuthState, Vec<String>)>,
208 mut request: Request,
209 next: Next,
210) -> Result<Response, AuthError> {
211 debug!("Authenticating request with role check");
212
213 let auth_header = request
215 .headers()
216 .get(AUTHORIZATION)
217 .and_then(|h| h.to_str().ok())
218 .ok_or(AuthError::MissingToken)?;
219
220 let token = JwtManager::extract_token_from_header(auth_header)
221 .map_err(|_| AuthError::InvalidToken)?;
222
223 let claims = auth_state
224 .jwt_manager
225 .validate_token(token)
226 .map_err(|e| match e {
227 TokenError::Expired => AuthError::ExpiredToken,
228 TokenError::NotYetValid => AuthError::InvalidToken,
229 _ => AuthError::InvalidToken,
230 })?;
231
232 let role_refs: Vec<&str> = allowed_roles.iter().map(|s| s.as_str()).collect();
234 if !claims.has_any_role(&role_refs) {
235 warn!("User {} lacks required role", claims.sub);
236 return Err(AuthError::InsufficientPermissions);
237 }
238
239 debug!("User authenticated with role: {}", claims.sub);
240 request.extensions_mut().insert(AuthUser::new(claims));
241
242 Ok(next.run(request).await)
243}
244
245pub fn extract_user(request: &Request<Body>) -> Result<&AuthUser, AuthError> {
261 request
262 .extensions()
263 .get::<AuthUser>()
264 .ok_or(AuthError::Unauthenticated)
265}
266
/// Reasons an authentication or authorization check can fail.
///
/// All variants are fieldless, so the type also derives `Clone`, `PartialEq` and
/// `Eq`. That lets callers and tests compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// No usable `Authorization` header was found on the request.
    MissingToken,

    /// The bearer token was malformed or failed validation.
    InvalidToken,

    /// The bearer token was rejected because it has expired.
    ExpiredToken,

    /// No authenticated user is attached to the request.
    Unauthenticated,

    /// The user is authenticated but lacks a required role.
    InsufficientPermissions,
}
285
286impl IntoResponse for AuthError {
287 fn into_response(self) -> Response {
288 let (status, message) = match self {
289 AuthError::MissingToken => (
290 StatusCode::UNAUTHORIZED,
291 "Missing authentication token",
292 ),
293 AuthError::InvalidToken => (
294 StatusCode::UNAUTHORIZED,
295 "Invalid authentication token",
296 ),
297 AuthError::ExpiredToken => (
298 StatusCode::UNAUTHORIZED,
299 "Authentication token has expired",
300 ),
301 AuthError::Unauthenticated => (
302 StatusCode::UNAUTHORIZED,
303 "Authentication required",
304 ),
305 AuthError::InsufficientPermissions => (
306 StatusCode::FORBIDDEN,
307 "Insufficient permissions",
308 ),
309 };
310
311 let error_response = ErrorResponse {
312 status: status.as_u16(),
313 error: message.to_string(),
314 code: None,
315 timestamp: chrono::Utc::now(),
316 };
317
318 (status, axum::Json(error_response)).into_response()
319 }
320}
321
322impl std::fmt::Display for AuthError {
323 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324 match self {
325 AuthError::MissingToken => write!(f, "Missing authentication token"),
326 AuthError::InvalidToken => write!(f, "Invalid authentication token"),
327 AuthError::ExpiredToken => write!(f, "Authentication token has expired"),
328 AuthError::Unauthenticated => write!(f, "Authentication required"),
329 AuthError::InsufficientPermissions => write!(f, "Insufficient permissions"),
330 }
331 }
332}
333
334impl std::error::Error for AuthError {}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jwt::{JwtConfig, JwtManager};
    use axum::{
        body::Body,
        extract::Extension,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    fn create_test_jwt_manager() -> JwtManager {
        let config = JwtConfig::new("test-secret-key")
            .with_issuer("test")
            .with_audience("test");
        JwtManager::new(config).unwrap()
    }

    async fn protected_handler(Extension(user): Extension<AuthUser>) -> String {
        format!("Hello, {}", user.user_id())
    }

    async fn greeting_handler(user: Option<Extension<AuthUser>>) -> String {
        match user {
            Some(Extension(u)) => format!("Hello, {}", u.user_id()),
            None => "Hello, guest".to_string(),
        }
    }

    /// Router exposing `/protected` behind `require_auth`.
    fn protected_app(auth_state: AuthState) -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn_with_state(auth_state, require_auth))
    }

    /// Router exposing `/public` behind `optional_auth`.
    fn optional_app(auth_state: AuthState) -> Router {
        Router::new()
            .route("/public", get(greeting_handler))
            .layer(middleware::from_fn_with_state(auth_state, optional_auth))
    }

    /// Router exposing `/admin` behind `require_role` with the `admin` role.
    fn admin_app(auth_state: AuthState) -> Router {
        Router::new()
            .route("/admin", get(protected_handler))
            .layer(middleware::from_fn_with_state(
                (auth_state, vec!["admin".to_string()]),
                require_role,
            ))
    }

    fn get_request(uri: &str, authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header(AUTHORIZATION, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    async fn body_string(response: Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_require_auth_with_valid_token() {
        let jwt_manager = create_test_jwt_manager();
        let token = jwt_manager.generate_token("user123").unwrap();
        let app = protected_app(AuthState::new(jwt_manager));

        let bearer = format!("Bearer {}", token);
        let response = app
            .oneshot(get_request("/protected", Some(&bearer)))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_string(response).await, "Hello, user123");
    }

    #[tokio::test]
    async fn test_require_auth_without_token() {
        let app = protected_app(AuthState::new(create_test_jwt_manager()));

        let response = app.oneshot(get_request("/protected", None)).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_require_auth_with_invalid_token() {
        let app = protected_app(AuthState::new(create_test_jwt_manager()));

        let response = app
            .oneshot(get_request("/protected", Some("Bearer invalid.token.here")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_require_auth_with_non_bearer_scheme() {
        let app = protected_app(AuthState::new(create_test_jwt_manager()));

        let response = app
            .oneshot(get_request("/protected", Some("Basic dXNlcjpwYXNz")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_optional_auth_with_token() {
        let jwt_manager = create_test_jwt_manager();
        let token = jwt_manager.generate_token("user123").unwrap();
        let app = optional_app(AuthState::new(jwt_manager));

        let bearer = format!("Bearer {}", token);
        let response = app
            .oneshot(get_request("/public", Some(&bearer)))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_string(response).await, "Hello, user123");
    }

    #[tokio::test]
    async fn test_optional_auth_without_token() {
        let app = optional_app(AuthState::new(create_test_jwt_manager()));

        let response = app.oneshot(get_request("/public", None)).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_string(response).await, "Hello, guest");
    }

    #[tokio::test]
    async fn test_optional_auth_with_invalid_token_falls_back_to_guest() {
        let app = optional_app(AuthState::new(create_test_jwt_manager()));

        let response = app
            .oneshot(get_request("/public", Some("Bearer invalid.token.here")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_string(response).await, "Hello, guest");
    }

    #[tokio::test]
    async fn test_require_role_without_token() {
        let app = admin_app(AuthState::new(create_test_jwt_manager()));

        let response = app.oneshot(get_request("/admin", None)).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_require_role_with_token_lacking_role() {
        let jwt_manager = create_test_jwt_manager();
        let token = jwt_manager.generate_token("user123").unwrap();
        let app = admin_app(AuthState::new(jwt_manager));

        let bearer = format!("Bearer {}", token);
        let response = app
            .oneshot(get_request("/admin", Some(&bearer)))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_extract_user() {
        let mut request = get_request("/", None);
        assert!(matches!(
            extract_user(&request),
            Err(AuthError::Unauthenticated)
        ));

        let claims = crate::jwt::Claims::new("user123", "test", "test", 3600);
        request.extensions_mut().insert(AuthUser::new(claims));
        assert_eq!(extract_user(&request).unwrap().user_id(), "user123");
    }

    #[test]
    fn test_auth_error_status_codes() {
        let unauthorized = [
            AuthError::MissingToken,
            AuthError::InvalidToken,
            AuthError::ExpiredToken,
            AuthError::Unauthenticated,
        ];
        for error in unauthorized {
            assert_eq!(error.into_response().status(), StatusCode::UNAUTHORIZED);
        }
        assert_eq!(
            AuthError::InsufficientPermissions.into_response().status(),
            StatusCode::FORBIDDEN
        );
        assert_eq!(
            AuthError::ExpiredToken.to_string(),
            "Authentication token has expired"
        );
    }

    #[test]
    fn test_auth_user() {
        let claims = crate::jwt::Claims::new("user123", "test", "test", 3600)
            .with_email("user@example.com")
            .with_role("admin");

        let auth_user = AuthUser::new(claims);

        assert_eq!(auth_user.user_id(), "user123");
        assert_eq!(auth_user.email(), Some("user@example.com"));
        assert!(auth_user.has_role("admin"));
        assert!(!auth_user.has_role("moderator"));
    }
}