1use axum::{
2 extract::Request,
3 http::StatusCode,
4 middleware::Next,
5 response::{IntoResponse, Response},
6 Json,
7};
8use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Serialize, Deserialize, Clone)]
13pub struct Claims {
14 pub sub: String, pub exp: usize, pub role: String, }
18
19#[derive(Debug, Serialize)]
21pub struct ErrorResponse {
22 pub message: String,
23}
24
/// Configuration required to verify incoming JWTs.
///
/// `jwt_auth_middleware` looks this value up in the request extensions, so it
/// has to be inserted there before the middleware runs.
#[derive(Clone)]
pub struct JwtConfig {
    /// Shared secret; its bytes are used as the key when verifying tokens.
    pub secret: String,
}

impl JwtConfig {
    /// Creates a configuration that verifies tokens with `secret`.
    pub fn new(secret: String) -> Self {
        Self { secret }
    }
}
36
37pub async fn jwt_auth_middleware(
39 mut request: Request,
40 next: Next,
41) -> Result<Response, Response> {
42 println!("→ JWT Auth Middleware: Checking authentication");
43
44 let config = request
46 .extensions()
47 .get::<JwtConfig>()
48 .ok_or_else(|| {
49 println!(" ✗ JWT config not found in extensions");
50 error_response(
51 StatusCode::INTERNAL_SERVER_ERROR,
52 "JWT configuration missing"
53 )
54 })?
55 .clone();
56
57 let auth_header = request
59 .headers()
60 .get("authorization")
61 .and_then(|h| h.to_str().ok());
62
63 let auth_header = match auth_header {
64 Some(header) => header,
65 None => {
66 println!(" ✗ No Authorization header found");
67 return Err(error_response(
68 StatusCode::UNAUTHORIZED,
69 "Missing authorization header"
70 ));
71 }
72 };
73
74 println!(" ✓ Authorization header found");
75
76 let token = if let Some(t) = auth_header.strip_prefix("Bearer ") {
78 println!(" ✓ Token format: Bearer <token>");
80 t
81 } else {
82 println!(" ✓ Token format: direct token (no Bearer prefix)");
84 auth_header
85 };
86
87 println!(" ✓ Token extracted from header");
88
89 let validation = Validation::new(Algorithm::HS256);
91
92 match decode::<Claims>(
93 token,
94 &DecodingKey::from_secret(config.secret.as_bytes()),
95 &validation,
96 ) {
97 Ok(token_data) => {
98 println!(" ✓ Token valid");
99 println!(" - User: {}", token_data.claims.sub);
100 println!(" - Role: {}", token_data.claims.role);
101
102 request.extensions_mut().insert(token_data.claims);
104
105 println!(" ✓ Authentication successful, proceeding to handler");
106 Ok(next.run(request).await)
107 }
108 Err(err) => {
109 println!(" ✗ Token validation failed: {:?}", err);
110
111 let error_message = match err.kind() {
112 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
113 "Token has expired"
114 }
115 jsonwebtoken::errors::ErrorKind::InvalidToken => {
116 "Invalid token"
117 }
118 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
119 "Invalid token signature"
120 }
121 _ => "Token validation failed"
122 };
123
124 Err(error_response(StatusCode::UNAUTHORIZED, error_message))
125 }
126 }
127}
128
129fn error_response(status: StatusCode, message: &str) -> Response {
131 let error = ErrorResponse {
132 message: message.to_string(),
133 };
134 (status, Json(error)).into_response()
135}
136
137use axum::extract::FromRequestParts;
139use axum::http::request::Parts;
140
141#[axum::async_trait]
142impl<S> FromRequestParts<S> for Claims
143where
144 S: Send + Sync,
145{
146 type Rejection = (StatusCode, String);
147
148 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
149 parts
150 .extensions
151 .get::<Claims>()
152 .cloned()
153 .ok_or_else(|| {
154 (
155 StatusCode::INTERNAL_SERVER_ERROR,
156 "Claims not found in request extensions".to_string(),
157 )
158 })
159 }
160}