auth-middleware-pkg 0.1.0

JWT authentication middleware for Axum: validates HS256-signed tokens and exposes the token's `role` claim to handlers for role-based access control
Documentation
use axum::{
    extract::Request,
    http::StatusCode,
    middleware::Next,
    response::{IntoResponse, Response},
    Json,
};
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
use serde::{Deserialize, Serialize};

/// Payload carried by the JWTs this middleware accepts.
///
/// Deserialized from the token body by `jsonwebtoken::decode`, so the field names here
/// must match the claim names used when the tokens are issued.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Claims {
    /// Subject of the token (user ID / username).
    pub sub: String,
    /// Expiration time as a Unix timestamp; expired tokens are rejected during validation.
    pub exp: usize,
    /// Role of the subject, made available to handlers for role-based access control.
    pub role: String,
}

/// JSON body sent with every authentication failure, serialized as `{"message": "..."}`.
#[derive(Serialize, Debug)]
pub struct ErrorResponse {
    /// Human-readable reason the request was rejected.
    pub message: String,
}

/// Settings consumed by `jwt_auth_middleware`.
///
/// The middleware looks this value up in the request extensions, so it must be inserted
/// there before the middleware runs. `Debug` is not derived, which keeps the secret out
/// of `{:?}` output.
#[derive(Clone)]
pub struct JwtConfig {
    /// Shared HMAC secret used to verify HS256 token signatures.
    pub secret: String,
}

impl JwtConfig {
    /// Builds a configuration from the HS256 signing secret.
    pub fn new(secret: String) -> Self {
        JwtConfig { secret }
    }
}

// JWT Auth Middleware
pub async fn jwt_auth_middleware(
    mut request: Request,
    next: Next,
) -> Result<Response, Response> {
    println!("→ JWT Auth Middleware: Checking authentication");

    // Get secret from request extensions (injected by axum layer)
    let config = request
        .extensions()
        .get::<JwtConfig>()
        .ok_or_else(|| {
            println!("   ✗ JWT config not found in extensions");
            error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "JWT configuration missing"
            )
        })?
        .clone();

    // Extract Authorization header
    let auth_header = request
        .headers()
        .get("authorization")
        .and_then(|h| h.to_str().ok());

    let auth_header = match auth_header {
        Some(header) => header,
        None => {
            println!("   ✗ No Authorization header found");
            return Err(error_response(
                StatusCode::UNAUTHORIZED,
                "Missing authorization header"
            ));
        }
    };

    println!("   ✓ Authorization header found");

    // Extract token - support both "Bearer <token>" and direct token
    let token = if let Some(t) = auth_header.strip_prefix("Bearer ") {
        // Bearer format: "Bearer eyJhbGc..."
        println!("   ✓ Token format: Bearer <token>");
        t
    } else {
        // Direct token format: "eyJhbGc..."
        println!("   ✓ Token format: direct token (no Bearer prefix)");
        auth_header
    };

    println!("   ✓ Token extracted from header");

    // Validate JWT token with configured secret
    let validation = Validation::new(Algorithm::HS256);

    match decode::<Claims>(
        token,
        &DecodingKey::from_secret(config.secret.as_bytes()),
        &validation,
    ) {
        Ok(token_data) => {
            println!("   ✓ Token valid");
            println!("      - User: {}", token_data.claims.sub);
            println!("      - Role: {}", token_data.claims.role);

            // Token is valid - attach claims to request for handlers to use
            request.extensions_mut().insert(token_data.claims);

            println!("   ✓ Authentication successful, proceeding to handler");
            Ok(next.run(request).await)
        }
        Err(err) => {
            println!("   ✗ Token validation failed: {:?}", err);

            let error_message = match err.kind() {
                jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
                    "Token has expired"
                }
                jsonwebtoken::errors::ErrorKind::InvalidToken => {
                    "Invalid token"
                }
                jsonwebtoken::errors::ErrorKind::InvalidSignature => {
                    "Invalid token signature"
                }
                _ => "Token validation failed"
            };

            Err(error_response(StatusCode::UNAUTHORIZED, error_message))
        }
    }
}

/// Builds a response with the given status and a JSON body `{"message": "..."}`.
///
/// Every failure path in this module goes through here, so all errors share one body shape.
fn error_response(status: StatusCode, message: &str) -> Response {
    let body = Json(ErrorResponse {
        message: String::from(message),
    });
    (status, body).into_response()
}

// Optional: `FromRequestParts` extractor so handlers can take `Claims` directly as an argument
use axum::extract::FromRequestParts;
use axum::http::request::Parts;

/// Lets handlers take [`Claims`] directly as an argument.
///
/// Returns a clone of the claims stored in the request extensions (the JWT middleware
/// inserts them after a successful validation). If none are present — presumably because
/// the route is not wrapped by that middleware — the request is rejected with
/// `500 Internal Server Error`, treating it as a server-side misconfiguration.
#[axum::async_trait]
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        match parts.extensions.get::<Self>() {
            Some(found) => Ok(found.clone()),
            None => Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                String::from("Claims not found in request extensions"),
            )),
        }
    }
}