use axum::{
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
use serde::{Deserialize, Serialize};
/// Claims carried in the payload of an authenticated JWT.
///
/// Decoded by `jwt_auth_middleware` and stored in the request extensions so
/// that handlers can read them back.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    /// Subject of the token (`sub` claim); the middleware logs it as the user.
    pub sub: String,
    /// Expiration time (`exp` claim); RFC 7519 defines it as seconds since the
    /// Unix epoch.
    pub exp: usize,
    /// Role attached to the token (`role` claim).
    pub role: String,
}
/// JSON body attached to error responses, serialized as `{"message": "..."}`.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Human-readable description of what went wrong.
    pub message: String,
}
/// Settings needed to verify incoming JWTs.
///
/// `jwt_auth_middleware` looks this value up in the request extensions, so it
/// has to be placed there before the middleware runs.
#[derive(Clone)]
pub struct JwtConfig {
    /// Shared secret whose bytes are used as the token verification key.
    pub secret: String,
}

impl JwtConfig {
    /// Creates a configuration that verifies tokens with `secret`.
    pub fn new(secret: String) -> Self {
        JwtConfig { secret }
    }
}
pub async fn jwt_auth_middleware(
mut request: Request,
next: Next,
) -> Result<Response, Response> {
println!("→ JWT Auth Middleware: Checking authentication");
let config = request
.extensions()
.get::<JwtConfig>()
.ok_or_else(|| {
println!(" ✗ JWT config not found in extensions");
error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"JWT configuration missing"
)
})?
.clone();
let auth_header = request
.headers()
.get("authorization")
.and_then(|h| h.to_str().ok());
let auth_header = match auth_header {
Some(header) => header,
None => {
println!(" ✗ No Authorization header found");
return Err(error_response(
StatusCode::UNAUTHORIZED,
"Missing authorization header"
));
}
};
println!(" ✓ Authorization header found");
let token = if let Some(t) = auth_header.strip_prefix("Bearer ") {
println!(" ✓ Token format: Bearer <token>");
t
} else {
println!(" ✓ Token format: direct token (no Bearer prefix)");
auth_header
};
println!(" ✓ Token extracted from header");
let validation = Validation::new(Algorithm::HS256);
match decode::<Claims>(
token,
&DecodingKey::from_secret(config.secret.as_bytes()),
&validation,
) {
Ok(token_data) => {
println!(" ✓ Token valid");
println!(" - User: {}", token_data.claims.sub);
println!(" - Role: {}", token_data.claims.role);
request.extensions_mut().insert(token_data.claims);
println!(" ✓ Authentication successful, proceeding to handler");
Ok(next.run(request).await)
}
Err(err) => {
println!(" ✗ Token validation failed: {:?}", err);
let error_message = match err.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
"Token has expired"
}
jsonwebtoken::errors::ErrorKind::InvalidToken => {
"Invalid token"
}
jsonwebtoken::errors::ErrorKind::InvalidSignature => {
"Invalid token signature"
}
_ => "Token validation failed"
};
Err(error_response(StatusCode::UNAUTHORIZED, error_message))
}
}
}
/// Builds a response with the given `status` whose body is an
/// [`ErrorResponse`] carrying `message`, serialized as JSON.
fn error_response(status: StatusCode, message: &str) -> Response {
    let body = Json(ErrorResponse {
        message: String::from(message),
    });
    (status, body).into_response()
}
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
/// Extractor that lets handlers take [`Claims`] as an argument.
///
/// The claims are read (and cloned) from the request extensions, where
/// `jwt_auth_middleware` stores them after a successful authentication. When
/// no claims are present the extraction is rejected with
/// `500 Internal Server Error` and a plain-text message.
#[axum::async_trait]
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        match parts.extensions.get::<Claims>() {
            Some(claims) => Ok(claims.clone()),
            None => Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                String::from("Claims not found in request extensions"),
            )),
        }
    }
}