use axum::{
extract::{Request, State},
http::{StatusCode, HeaderMap},
middleware::Next,
response::Response,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
use crate::server::AppState;
/// JWT payload that `auth_middleware` decodes and validates.
///
/// Every field is a non-`Option` type, so a token whose payload lacks any of
/// them fails to deserialize and is rejected by the middleware with 401.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Subject of the token; logged at debug level after successful authentication.
    pub sub: String,
    /// Expiration time. NOTE(review): presumably an RFC 7519 NumericDate (seconds
    /// since the Unix epoch), which jsonwebtoken's default `Validation` checks — confirm.
    pub exp: usize,
    /// Issued-at time; presumably also a NumericDate — confirm with the token issuer.
    pub iat: usize,
    /// Role claim. Not inspected by `auth_middleware`; presumably consumed elsewhere.
    pub role: String,
}
/// Axum middleware that enforces JWT bearer authentication on incoming requests.
///
/// Some requests are forwarded to `next` without any check:
/// - requests whose path is exactly `/health` or `/metrics`;
/// - every request when `state.config.auth.enabled` is false.
///
/// For all other requests, the `Authorization` header must carry a `Bearer`
/// credential. That credential must decode as [`Claims`] under
/// `Validation::default()`, using `state.config.auth.jwt_secret` as the shared
/// secret.
///
/// # Errors
///
/// Returns `StatusCode::UNAUTHORIZED` when the header is absent, when it is not
/// a well-formed bearer credential, or when the token fails decoding/validation.
pub async fn auth_middleware(
    State(state): State<AppState>,
    headers: HeaderMap,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let path = request.uri().path();
    // Exact-match allowlist so health/metrics probes work without credentials.
    if path == "/health" || path == "/metrics" || !state.config.auth.enabled {
        return Ok(next.run(request).await);
    }

    let raw_header = match headers.get("Authorization") {
        Some(value) => value,
        None => {
            warn!("Missing authorization header");
            return Err(StatusCode::UNAUTHORIZED);
        }
    };

    // `to_str` fails on non-visible-ASCII bytes; treat that like a wrong scheme.
    let token = match raw_header.to_str().ok().and_then(parse_bearer_token) {
        Some(token) => token,
        None => {
            warn!("Malformed authorization header");
            return Err(StatusCode::UNAUTHORIZED);
        }
    };

    let decoding_key = DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes());
    let validation = Validation::default();
    match decode::<Claims>(token, &decoding_key, &validation) {
        Ok(token_data) => {
            debug!("Authenticated user: {}", token_data.claims.sub);
            Ok(next.run(request).await)
        }
        Err(e) => {
            warn!("Invalid token: {}", e);
            Err(StatusCode::UNAUTHORIZED)
        }
    }
}

/// Extracts the token from an `Authorization` header value of the form `Bearer <token>`.
///
/// The scheme name is compared case-insensitively, because RFC 7235 §2.1
/// defines auth-scheme as a case-insensitive token. Whitespace around the
/// token is ignored.
///
/// Returns `None` in two cases: the value does not use the Bearer scheme, or
/// the token part is empty.
fn parse_bearer_token(value: &str) -> Option<&str> {
    let (scheme, rest) = value.trim().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = rest.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}