1use crate::{error::AuthError, state::AppState};
2use axum::{
3 body::Body,
4 extract::State,
5 http::{HeaderMap, Request},
6 middleware::Next,
7 response::Response,
8};
9use jsonwebtoken::{DecodingKey, Validation, decode};
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Serialize, Deserialize, Clone)]
14pub struct Claims {
15 pub sub: String, pub exp: usize, }
20
21pub async fn auth_middleware(
22 State(state): State<AppState>,
23 headers: HeaderMap,
24 mut request: Request<Body>,
25 next: Next,
26) -> Result<Response, AuthError> {
27 let token = headers
28 .get("Authorization")
29 .and_then(|header| header.to_str().ok())
30 .and_then(|header| header.strip_prefix("Bearer "));
31
32 let token = token.ok_or(AuthError::MissingCredentials)?;
33
34 let decoding_key = DecodingKey::from_secret(state.config.jwt_secret.as_ref());
35
36 let validation = Validation::default();
37
38 let claims = decode::<Claims>(token, &decoding_key, &validation)
39 .map_err(|e| AuthError::InvalidToken(e.to_string()))?
40 .claims;
41
42 request.extensions_mut().insert(claims);
44
45 Ok(next.run(request).await)
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AppConfig;
    use jsonwebtoken::{EncodingKey, Header, encode};
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    /// Lifetime of the generated test token: ten 365-day years, in seconds.
    const TEN_YEARS_SECS: u64 = 3600 * 24 * 365 * 10;

    /// Signs a JWT for `user_id` with the configured `jwt_secret`.
    ///
    /// The token expires `duration_secs` seconds from now. Panics if the
    /// configuration cannot be loaded or the system clock is before the
    /// Unix epoch.
    fn generate_test_jwt(
        user_id: &str,
        duration_secs: u64,
    ) -> Result<String, jsonwebtoken::errors::Error> {
        let config = AppConfig::load().unwrap();

        let since_epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards");
        let expires_at = since_epoch + Duration::from_secs(duration_secs);

        let payload = Claims {
            sub: user_id.to_owned(),
            exp: expires_at.as_secs() as usize,
        };

        encode(
            &Header::default(),
            &payload,
            &EncodingKey::from_secret(config.jwt_secret.as_ref()),
        )
    }

    /// Round-trips a token: sign it with the configured secret, print it, then
    /// decode it with default validation and check the subject survives.
    #[test]
    fn test_jwt_generation() {
        let config = AppConfig::load().unwrap();
        let user_id = "test_user@example.com";

        let signed = generate_test_jwt(user_id, TEN_YEARS_SECS);
        assert!(signed.is_ok());
        let generated_token = signed.unwrap();
        println!("Generated Test JWT: {}", generated_token);

        let verify_key = DecodingKey::from_secret(config.jwt_secret.as_ref());
        let decoded = decode::<Claims>(&generated_token, &verify_key, &Validation::default());
        assert!(decoded.is_ok());
        assert_eq!(decoded.unwrap().claims.sub, user_id);
    }
}