use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Settings consumed by the JWT middleware in this file when validating
/// bearer tokens.
///
/// The key material is held in an `Arc<str>`, so cloning the configuration
/// does not copy it.
#[derive(Clone)]
pub struct JwtConfig {
/// Key material handed to the token decoder. The manual `Debug` impl in
/// this file prints a placeholder instead of this value.
secret: Arc<str>,
/// Signature algorithms a token is allowed to use.
algorithms: Vec<Algorithm>,
/// Expected issuer; only passed to the validator when `Some`.
issuer: Option<String>,
/// Expected audience; only passed to the validator when `Some`.
audience: Option<String>,
/// When `false`, the middleware forwards requests without inspecting any
/// token.
enabled: bool,
}
impl JwtConfig {
    /// Shared constructor: an enabled configuration that permits exactly one
    /// algorithm and places no constraint on issuer or audience.
    fn enabled_with(key: String, algorithm: Algorithm) -> Self {
        Self {
            secret: Arc::from(key),
            algorithms: vec![algorithm],
            issuer: None,
            audience: None,
            enabled: true,
        }
    }

    /// Creates an enabled configuration restricted to HS256, storing `secret`
    /// as the key material.
    pub fn hs256(secret: impl Into<String>) -> Self {
        Self::enabled_with(secret.into(), Algorithm::HS256)
    }

    /// Creates an enabled configuration restricted to RS256, storing
    /// `public_key` as the key material.
    pub fn rs256(public_key: impl Into<String>) -> Self {
        Self::enabled_with(public_key.into(), Algorithm::RS256)
    }

    /// Returns the configuration with the expected issuer set to `issuer`.
    pub fn with_issuer(self, issuer: impl Into<String>) -> Self {
        Self {
            issuer: Some(issuer.into()),
            ..self
        }
    }

    /// Returns the configuration with the expected audience set to `audience`.
    pub fn with_audience(self, audience: impl Into<String>) -> Self {
        Self {
            audience: Some(audience.into()),
            ..self
        }
    }

    /// Returns the configuration with its permitted algorithm list replaced
    /// by `algorithms`.
    pub fn with_algorithms(self, algorithms: Vec<Algorithm>) -> Self {
        Self { algorithms, ..self }
    }

    /// Creates a configuration with validation switched off: empty key
    /// material, no algorithms, and no issuer/audience constraints.
    pub fn disabled() -> Self {
        Self {
            secret: Arc::from(""),
            algorithms: Vec::new(),
            issuer: None,
            audience: None,
            enabled: false,
        }
    }

    /// Reports whether token validation is switched on.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
}
/// Hand-written `Debug` so the key material never appears in formatted
/// output; a fixed placeholder is printed in its place.
impl std::fmt::Debug for JwtConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces a
        // decision here about whether it is safe to print.
        let Self {
            secret: _,
            algorithms,
            issuer,
            audience,
            enabled,
        } = self;

        let mut out = f.debug_struct("JwtConfig");
        out.field("enabled", enabled);
        out.field("algorithms", algorithms);
        out.field("issuer", issuer);
        out.field("audience", audience);
        out.field("secret", &"[redacted]");
        out.finish()
    }
}
/// Claims decoded from a validated token.
///
/// Every named claim is optional and is left out of serialized output when
/// it is `None`. Claims without a dedicated field are gathered in `extra`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
/// `iss` (issuer) claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub iss: Option<String>,
/// `sub` (subject) claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub sub: Option<String>,
/// `aud` (audience) claim, modelled here as a single string.
#[serde(skip_serializing_if = "Option::is_none")]
pub aud: Option<String>,
/// `exp` (expiration time) claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub exp: Option<usize>,
/// `nbf` (not before) claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub nbf: Option<usize>,
/// `iat` (issued at) claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub iat: Option<usize>,
/// `jti` (token identifier) claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub jti: Option<String>,
/// All remaining claims, flattened into the same JSON object as the named
/// fields.
#[serde(flatten)]
pub extra: serde_json::Map<String, serde_json::Value>,
}
impl JwtClaims {
    /// Borrows the `sub` claim as a string slice, or returns `None` when the
    /// claim is absent.
    pub fn subject(&self) -> Option<&str> {
        let Self { sub, .. } = self;
        sub.as_deref()
    }
}
/// Middleware that authenticates a request from its bearer token.
///
/// * When the configuration is disabled, the request is forwarded untouched.
/// * When no bearer token can be extracted, a 401 response is returned.
/// * When the token fails validation, the failure is logged at `warn` level
///   and a 401 response is returned.
/// * Otherwise the decoded [`JwtClaims`] are stored in the request
///   extensions and the request continues down the stack.
pub async fn jwt_middleware(
    State(config): State<Arc<JwtConfig>>,
    mut req: Request<Body>,
    next: Next,
) -> Response {
    if !config.is_enabled() {
        return next.run(req).await;
    }

    let Some(token) = extract_bearer_token(req.headers()) else {
        return unauthorized_response("missing Authorization header");
    };

    // Validation finishes before the request is mutated, so the borrow of the
    // header value behind `token` has ended by the time the claims are stored.
    let claims = match validate_token(&config, token) {
        Ok(claims) => claims,
        Err(e) => {
            tracing::warn!(error = %e, "JWT validation failed");
            return unauthorized_response("invalid token");
        }
    };

    req.extensions_mut().insert(claims);
    next.run(req).await
}
/// Pulls the bearer token out of the `Authorization` header.
///
/// Returns `None` when the header is absent, is not visible ASCII, uses a
/// scheme other than `Bearer`, or carries no token after the scheme.
///
/// The scheme name is compared case-insensitively, as HTTP authentication
/// schemes are case-insensitive, and padding around the token is ignored.
fn extract_bearer_token(headers: &axum::http::HeaderMap) -> Option<&str> {
    let value = headers.get(header::AUTHORIZATION)?.to_str().ok()?;

    // Split at the first space: everything before it is the scheme, the rest
    // is the credential.
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }

    // An empty credential is treated as "no token" rather than handed to the
    // decoder as an empty string.
    let token = credentials.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}
fn validate_token(config: &JwtConfig, token: &str) -> Result<JwtClaims, String> {
let mut validation = Validation::new(
config
.algorithms
.first()
.copied()
.unwrap_or(Algorithm::HS256),
);
if config.algorithms.len() > 1 {
validation.algorithms = config.algorithms.clone();
}
if let Some(ref issuer) = config.issuer {
validation.set_issuer(&[issuer.as_str()]);
}
if let Some(ref audience) = config.audience {
validation.set_audience(&[audience.as_str()]);
}
validation.validate_exp = true;
validation.validate_nbf = true;
validation.leeway = 30;
let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
let token_data = decode::<JwtClaims>(token, &decoding_key, &validation)
.map_err(|e| format!("JWT validation error: {e}"))?;
Ok(token_data.claims)
}
fn unauthorized_response(message: &str) -> Response {
(
StatusCode::UNAUTHORIZED,
[(header::CONTENT_TYPE, "application/json")],
serde_json::json!({"error": message}).to_string(),
)
.into_response()
}
/// Returns the [`JwtClaims`] stored in the request extensions, or `None`
/// when no claims of that type are present.
///
/// The body type `B` is deliberately unconstrained: the lookup only touches
/// the extension map, so requiring `Debug + Send + Sync + 'static` on the
/// body needlessly ruled out requests whose body type lacks one of those.
pub fn get_claims<B>(req: &Request<B>) -> Option<&JwtClaims> {
    req.extensions().get::<JwtClaims>()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Header map whose only entry is `Authorization: <value>`.
    fn authorization_headers(value: &'static str) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            axum::http::HeaderValue::from_static(value),
        );
        headers
    }

    /// Claims with every field empty except an optional subject.
    fn claims_with_subject(sub: Option<&str>) -> JwtClaims {
        JwtClaims {
            iss: None,
            sub: sub.map(str::to_string),
            aud: None,
            exp: None,
            nbf: None,
            iat: None,
            jti: None,
            extra: serde_json::Map::new(),
        }
    }

    #[test]
    fn test_extract_bearer_token_valid() {
        let headers = authorization_headers("Bearer my.jwt.token");
        assert_eq!(extract_bearer_token(&headers), Some("my.jwt.token"));
    }

    #[test]
    fn test_extract_bearer_token_missing() {
        let headers = axum::http::HeaderMap::new();
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn test_extract_bearer_token_not_bearer() {
        let headers = authorization_headers("Basic dXNlcjpwYXNz");
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn test_jwt_config_disabled() {
        assert!(!JwtConfig::disabled().is_enabled());
    }

    #[test]
    fn test_jwt_config_debug_redacts_secret() {
        let rendered = format!("{:?}", JwtConfig::hs256("super-secret-key-1234567890"));
        assert!(!rendered.contains("super-secret"));
        assert!(rendered.contains("[redacted]"));
    }

    #[test]
    fn test_jwt_claims_subject() {
        let claims = claims_with_subject(Some("client-123"));
        assert_eq!(claims.subject(), Some("client-123"));
    }

    #[test]
    fn test_jwt_claims_no_subject() {
        let claims = claims_with_subject(None);
        assert_eq!(claims.subject(), None);
    }
}