auth_middleware/
lib.rs

1use axum::{
2    extract::Request,
3    http::StatusCode,
4    middleware::Next,
5    response::{IntoResponse, Response},
6    Json,
7};
8use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
9use serde::{Deserialize, Serialize};
10
11// JWT Claims structure - must match your token structure
12#[derive(Debug, Serialize, Deserialize, Clone)]
13pub struct Claims {
14    pub sub: String,        // Subject (user ID/username)
15    pub exp: usize,         // Expiration time (Unix timestamp)
16    pub role: String,       // User role (for RBAC)
17}
18
19// Error response structure
20#[derive(Debug, Serialize)]
21pub struct ErrorResponse {
22    pub message: String,
23}
24
/// Settings used by the JWT middleware to verify tokens.
///
/// The middleware reads this value from the request extensions, so an outer layer
/// must insert it there before the middleware runs.
///
/// Note: this type does not derive `Debug`, so the secret cannot be printed by
/// accident with `{:?}`.
#[derive(Clone)]
pub struct JwtConfig {
    /// Shared secret used to verify HS256 token signatures.
    pub secret: String,
}

impl JwtConfig {
    /// Creates a configuration that verifies tokens against `secret`.
    pub fn new(secret: String) -> Self {
        JwtConfig { secret }
    }
}
36
37// JWT Auth Middleware
38pub async fn jwt_auth_middleware(
39    mut request: Request,
40    next: Next,
41) -> Result<Response, Response> {
42    println!("→ JWT Auth Middleware: Checking authentication");
43
44    // Get secret from request extensions (injected by axum layer)
45    let config = request
46        .extensions()
47        .get::<JwtConfig>()
48        .ok_or_else(|| {
49            println!("   ✗ JWT config not found in extensions");
50            error_response(
51                StatusCode::INTERNAL_SERVER_ERROR,
52                "JWT configuration missing"
53            )
54        })?
55        .clone();
56
57    // Extract Authorization header
58    let auth_header = request
59        .headers()
60        .get("authorization")
61        .and_then(|h| h.to_str().ok());
62
63    let auth_header = match auth_header {
64        Some(header) => header,
65        None => {
66            println!("   ✗ No Authorization header found");
67            return Err(error_response(
68                StatusCode::UNAUTHORIZED,
69                "Missing authorization header"
70            ));
71        }
72    };
73
74    println!("   ✓ Authorization header found");
75
76    // Extract token - support both "Bearer <token>" and direct token
77    let token = if let Some(t) = auth_header.strip_prefix("Bearer ") {
78        // Bearer format: "Bearer eyJhbGc..."
79        println!("   ✓ Token format: Bearer <token>");
80        t
81    } else {
82        // Direct token format: "eyJhbGc..."
83        println!("   ✓ Token format: direct token (no Bearer prefix)");
84        auth_header
85    };
86
87    println!("   ✓ Token extracted from header");
88
89    // Validate JWT token with configured secret
90    let validation = Validation::new(Algorithm::HS256);
91
92    match decode::<Claims>(
93        token,
94        &DecodingKey::from_secret(config.secret.as_bytes()),
95        &validation,
96    ) {
97        Ok(token_data) => {
98            println!("   ✓ Token valid");
99            println!("      - User: {}", token_data.claims.sub);
100            println!("      - Role: {}", token_data.claims.role);
101
102            // Token is valid - attach claims to request for handlers to use
103            request.extensions_mut().insert(token_data.claims);
104
105            println!("   ✓ Authentication successful, proceeding to handler");
106            Ok(next.run(request).await)
107        }
108        Err(err) => {
109            println!("   ✗ Token validation failed: {:?}", err);
110
111            let error_message = match err.kind() {
112                jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
113                    "Token has expired"
114                }
115                jsonwebtoken::errors::ErrorKind::InvalidToken => {
116                    "Invalid token"
117                }
118                jsonwebtoken::errors::ErrorKind::InvalidSignature => {
119                    "Invalid token signature"
120                }
121                _ => "Token validation failed"
122            };
123
124            Err(error_response(StatusCode::UNAUTHORIZED, error_message))
125        }
126    }
127}
128
129// Helper function to create consistent error responses
130fn error_response(status: StatusCode, message: &str) -> Response {
131    let error = ErrorResponse {
132        message: message.to_string(),
133    };
134    (status, Json(error)).into_response()
135}
136
// Optional: implement `FromRequestParts` for `Claims` so handlers can take the
// authenticated claims directly as an extractor argument.
138use axum::extract::FromRequestParts;
139use axum::http::request::Parts;
140
141#[axum::async_trait]
142impl<S> FromRequestParts<S> for Claims
143where
144    S: Send + Sync,
145{
146    type Rejection = (StatusCode, String);
147
148    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
149        parts
150            .extensions
151            .get::<Claims>()
152            .cloned()
153            .ok_or_else(|| {
154                (
155                    StatusCode::INTERNAL_SERVER_ERROR,
156                    "Claims not found in request extensions".to_string(),
157                )
158            })
159    }
160}