r2_data2/
auth.rs

1use crate::{error::AuthError, state::AppState};
2use axum::{
3    body::Body,
4    extract::State,
5    http::{HeaderMap, Request},
6    middleware::Next,
7    response::Response,
8};
9use jsonwebtoken::{DecodingKey, Validation, decode};
10use serde::{Deserialize, Serialize};
11
12// Define the structure of the JWT claims
13#[derive(Debug, Serialize, Deserialize, Clone)]
14pub struct Claims {
15    pub sub: String, // Subject (e.g., user ID or email)
16    pub exp: usize,  // Expiration time (timestamp)
17                     // Add any other custom claims you might need
18                     // pub roles: Vec<String>,
19}
20
21pub async fn auth_middleware(
22    State(state): State<AppState>,
23    headers: HeaderMap,
24    mut request: Request<Body>,
25    next: Next,
26) -> Result<Response, AuthError> {
27    let token = headers
28        .get("Authorization")
29        .and_then(|header| header.to_str().ok())
30        .and_then(|header| header.strip_prefix("Bearer "));
31
32    let token = token.ok_or(AuthError::MissingCredentials)?;
33
34    let decoding_key = DecodingKey::from_secret(state.config.jwt_secret.as_ref());
35
36    let validation = Validation::default();
37
38    let claims = decode::<Claims>(token, &decoding_key, &validation)
39        .map_err(|e| AuthError::InvalidToken(e.to_string()))?
40        .claims;
41
42    // Store claims in request extensions for handlers to use if needed
43    request.extensions_mut().insert(claims);
44
45    Ok(next.run(request).await)
46}
47
// Unit tests for JWT creation and verification.
#[cfg(test)]
mod tests {
    use crate::AppConfig;

    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    /// Signs a JWT for `user_id` with `secret` that expires `duration_secs`
    /// seconds from now.
    ///
    /// The secret is a parameter so that callers decide where it comes from
    /// and the application configuration is not reloaded on every call.
    fn generate_test_jwt(
        secret: &[u8],
        user_id: &str,
        duration_secs: u64,
    ) -> Result<String, jsonwebtoken::errors::Error> {
        let expiration = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards")
            + Duration::from_secs(duration_secs);

        let claims = Claims {
            sub: user_id.to_owned(),
            // Checked conversion: fail loudly instead of silently truncating
            // on targets where `usize` is narrower than `u64`.
            exp: usize::try_from(expiration.as_secs())
                .expect("expiration timestamp should fit in usize"),
        };

        let header = Header::default(); // Default algorithm is HS256
        let encoding_key = EncodingKey::from_secret(secret);

        encode(&header, &claims, &encoding_key)
    }

    #[test]
    fn test_jwt_generation() {
        let config = AppConfig::load().expect("test configuration should load");
        let secret: &[u8] = config.jwt_secret.as_ref();
        let user_id = "test_user@example.com";

        // 10 years expiration
        let generated_token = generate_test_jwt(secret, user_id, 3600 * 24 * 365 * 10)
            .expect("encoding a test JWT should succeed");
        println!("Generated Test JWT: {}", generated_token);

        // Verify the generated token with the same decoding steps as the middleware.
        let decoding_key = DecodingKey::from_secret(secret);
        let validation = Validation::default();
        let decoded = decode::<Claims>(&generated_token, &decoding_key, &validation)
            .expect("a freshly generated token should decode");
        assert_eq!(decoded.claims.sub, user_id);
    }

    #[test]
    fn test_jwt_rejected_with_wrong_secret() {
        let token = generate_test_jwt(b"signing-secret", "test_user@example.com", 3600)
            .expect("encoding a test JWT should succeed");

        // A token signed with one secret must not verify under another.
        let decoding_key = DecodingKey::from_secret(b"different-secret");
        let decoded = decode::<Claims>(&token, &decoding_key, &Validation::default());
        assert!(decoded.is_err());
    }
}