use {axum::{body::Body,
http::{Request,
StatusCode},
response::{IntoResponse,
Response}},
jsonwebtoken::{Algorithm,
DecodingKey,
Validation,
decode},
serde::{Deserialize,
Serialize},
std::{future::Future,
pin::Pin,
sync::Arc,
task::{Context,
Poll}},
thiserror::Error,
tower::{Layer,
Service}};
/// Configuration for HMAC-based JWT verification.
///
/// `Debug` is implemented by hand rather than derived so that the shared
/// secret can never end up in logs or panic messages via `{:?}`.
#[derive(Clone)]
pub struct JwtConfig {
    /// Base64-encoded shared secret. An empty string leaves JWT
    /// authentication unconfigured.
    pub secret: String,
    /// Signature algorithm that incoming tokens are verified with.
    pub algorithm: Algorithm,
}

impl std::fmt::Debug for JwtConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a secret is present, never its contents.
        let secret = if self.secret.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("JwtConfig")
            .field("secret", &secret)
            .field("algorithm", &self.algorithm)
            .finish()
    }
}
/// The default configuration carries no secret and selects HS512.
impl Default for JwtConfig {
    fn default() -> Self {
        // An empty secret means "not configured" to `JwtValidator::from_config`.
        let secret = String::new();
        let algorithm = Algorithm::HS512;
        Self { secret, algorithm }
    }
}
/// Payload fields read from (and written to) a JWT.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Claims {
    /// Subject of the token; this is the value handed on after validation.
    pub sub: String,
    /// Issued-at timestamp; `None` when the claim is absent.
    #[serde(default)]
    pub iat: Option<i64>,
    /// Expiry timestamp; `None` when the claim is absent.
    #[serde(default)]
    pub exp: Option<i64>,
}
#[derive(Debug, Error)]
pub enum JwtAuthError {
#[error("Token expired")]
TokenExpired,
#[error("Invalid token signature")]
InvalidSignature,
#[error("Invalid token format: {0}")]
InvalidFormat(String),
#[error("JWT not configured")]
NotConfigured,
}
/// Renders every authentication failure as `401 Unauthorized` with a short,
/// fixed plain-text body.
impl IntoResponse for JwtAuthError {
    fn into_response(self) -> Response {
        // The detail string inside `InvalidFormat` is deliberately left out of
        // the response body; only a generic message is sent to the client.
        let message: &'static str = match self {
            | Self::NotConfigured => "Authentication not configured",
            | Self::InvalidFormat(_) => "Invalid token format",
            | Self::InvalidSignature => "Invalid token signature",
            | Self::TokenExpired => "Token expired",
        };
        let status = StatusCode::UNAUTHORIZED;
        (status, message).into_response()
    }
}
/// How (and whether) tokens are verified.
#[derive(Clone)]
enum VerificationMode {
    /// No key is available; every validation attempt is refused.
    None,
    /// Verify with a shared-secret key and the given algorithm.
    Hmac {
        /// Key built from the configured base64 secret.
        decoding_key: DecodingKey,
        /// Algorithm passed to the decoder's validation settings.
        algorithm: Algorithm,
    },
}
/// Verifies JWTs and extracts their subject claim.
#[derive(Clone)]
pub struct JwtValidator {
    /// Active verification strategy, fixed at construction time.
    mode: VerificationMode,
}
impl JwtValidator {
    /// Builds a validator from `config`.
    ///
    /// An empty secret, or one that cannot be decoded as base64, yields an
    /// unconfigured validator. The decode failure is logged, not returned.
    pub fn from_config(config: &JwtConfig) -> Self {
        let mode = if config.secret.is_empty() {
            VerificationMode::None
        } else {
            match DecodingKey::from_base64_secret(&config.secret) {
                | Ok(decoding_key) => VerificationMode::Hmac {
                    decoding_key,
                    algorithm: config.algorithm,
                },
                | Err(e) => {
                    tracing::error!("Failed to decode base64 JWT secret: {}", e);
                    VerificationMode::None
                }
            }
        };
        Self { mode }
    }

    /// Returns `true` when a verification key is available.
    pub fn is_configured(&self) -> bool {
        match self.mode {
            | VerificationMode::Hmac { .. } => true,
            | VerificationMode::None => false,
        }
    }

    /// Verifies `token` and returns its `sub` claim.
    ///
    /// A leading `"Bearer "` prefix is stripped if present, so both a raw
    /// token and a full header value are accepted.
    ///
    /// # Errors
    ///
    /// * [`JwtAuthError::NotConfigured`] if no key is configured.
    /// * [`JwtAuthError::TokenExpired`] if the decoder reports expiry.
    /// * [`JwtAuthError::InvalidSignature`] if the signature does not match.
    /// * [`JwtAuthError::InvalidFormat`] for any other decoding failure.
    pub fn validate_and_extract_subject(&self, token: &str) -> Result<String, JwtAuthError> {
        // Resolve the key material first; without it nothing else matters.
        let (key, algorithm) = match &self.mode {
            | VerificationMode::Hmac {
                decoding_key,
                algorithm,
            } => (decoding_key, *algorithm),
            | VerificationMode::None => return Err(JwtAuthError::NotConfigured),
        };

        let raw = token.strip_prefix("Bearer ").unwrap_or(token);

        let mut validation = Validation::new(algorithm);
        // Expiry checking is switched on explicitly.
        validation.validate_exp = true;

        match decode::<Claims>(raw, key, &validation) {
            | Ok(data) => Ok(data.claims.sub),
            | Err(e) => Err(match e.kind() {
                | jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtAuthError::TokenExpired,
                | jsonwebtoken::errors::ErrorKind::InvalidSignature => JwtAuthError::InvalidSignature,
                | _ => JwtAuthError::InvalidFormat(e.to_string()),
            }),
        }
    }
}
/// Tower layer that wraps services in [`JwtAuthMiddleware`].
#[derive(Clone)]
pub struct JwtAuthLayer {
    /// Validator shared by every middleware instance this layer creates.
    validator: Arc<JwtValidator>,
    /// Optional extra predicate applied to the token's subject.
    subject_validator: Option<fn(&str) -> bool>,
}
impl JwtAuthLayer {
    /// Creates a layer whose validator is built from `config`, with no
    /// subject predicate installed.
    pub fn new(config: &JwtConfig) -> Self {
        let validator = Arc::new(JwtValidator::from_config(config));
        Self {
            validator,
            subject_validator: None,
        }
    }

    /// Returns the layer with `f` installed as the subject predicate.
    ///
    /// Requests whose token subject makes `f` return `false` are rejected by
    /// the middleware.
    pub fn with_subject_validator(self, f: fn(&str) -> bool) -> Self {
        Self {
            subject_validator: Some(f),
            ..self
        }
    }
}
impl<S> Layer<S> for JwtAuthLayer {
type Service = JwtAuthMiddleware<S>;
fn layer(&self, inner: S) -> Self::Service {
JwtAuthMiddleware {
inner,
validator: self.validator.clone(),
subject_validator: self.subject_validator,
}
}
}
/// Service wrapper that checks the bearer token before calling `inner`.
#[derive(Clone)]
pub struct JwtAuthMiddleware<S> {
    /// The wrapped service that receives authenticated requests.
    inner: S,
    /// Shared token validator.
    validator: Arc<JwtValidator>,
    /// Optional extra predicate applied to the token's subject.
    subject_validator: Option<fn(&str) -> bool>,
}
impl<S> Service<Request<Body>> for JwtAuthMiddleware<S>
where
S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
S::Future: Send,
{
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
type Response = S::Response;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let mut inner = self.inner.clone();
let validator = self.validator.clone();
let subject_validator = self.subject_validator;
Box::pin(async move {
let auth_header = req
.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
let token = match auth_header {
| Some(ref header) if header.starts_with("Bearer ") => &header[7 ..],
| _ => {
return Ok((StatusCode::UNAUTHORIZED, "Missing or invalid Authorization header").into_response());
}
};
let subject = match validator.validate_and_extract_subject(token) {
| Ok(sub) => sub,
| Err(e) => return Ok(e.into_response()),
};
if let Some(validate_fn) = subject_validator {
if !validate_fn(&subject) {
return Ok((StatusCode::UNAUTHORIZED, "Invalid subject in token").into_response());
}
}
req.extensions_mut().insert(subject);
inner.call(req).await
})
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};

    /// Raw (un-encoded) secret shared by most tests.
    const SECRET: &[u8] = b"test-secret-key-for-jwt";
    /// Offset, in seconds, used to place `exp` in the future or the past.
    const ONE_HOUR: i64 = 3600;

    /// Current Unix time in whole seconds.
    fn unix_now() -> i64 {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap();
        since_epoch.as_secs() as i64
    }

    /// HS512 configuration carrying `secret` verbatim (expected to be base64).
    fn hs512_config(secret: String) -> JwtConfig {
        JwtConfig {
            secret,
            algorithm: Algorithm::HS512,
        }
    }

    /// Validator configured with the base64 encoding of `raw_secret`.
    fn validator_for(raw_secret: &[u8]) -> JwtValidator {
        let config = hs512_config(data_encoding::BASE64.encode(raw_secret));
        JwtValidator::from_config(&config)
    }

    /// Signs an HS512 token for `sub` with `secret` and the given expiry.
    fn signed_token(sub: &str, secret: &[u8], exp: Option<i64>) -> String {
        let claims = Claims {
            sub: sub.to_string(),
            iat: Some(unix_now()),
            exp,
        };
        let key = EncodingKey::from_secret(secret);
        encode(&Header::new(Algorithm::HS512), &claims, &key).unwrap()
    }

    /// Subject predicate used by the layer test: exactly eight ASCII alphanumerics.
    fn is_eight_alphanumerics(s: &str) -> bool {
        s.len() == 8 && s.chars().all(|c| c.is_ascii_alphanumeric())
    }

    #[test]
    fn test_from_config_empty_secret() {
        let validator = JwtValidator::from_config(&JwtConfig::default());
        assert!(!validator.is_configured());
    }

    #[test]
    fn test_from_config_valid_secret() {
        let validator = validator_for(SECRET);
        assert!(validator.is_configured());
    }

    #[test]
    fn test_from_config_invalid_base64() {
        let config = hs512_config("not-valid-base64!!!".to_string());
        let validator = JwtValidator::from_config(&config);
        assert!(!validator.is_configured());
    }

    #[test]
    fn test_not_configured_returns_error() {
        let validator = JwtValidator::from_config(&JwtConfig::default());
        let outcome = validator.validate_and_extract_subject("some-token");
        assert!(matches!(outcome, Err(JwtAuthError::NotConfigured)));
    }

    #[test]
    fn test_valid_token() {
        let validator = validator_for(SECRET);
        let token = signed_token("user123", SECRET, Some(unix_now() + ONE_HOUR));

        let outcome = validator.validate_and_extract_subject(&token);

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "user123");
    }

    #[test]
    fn test_valid_token_with_bearer_prefix() {
        let validator = validator_for(SECRET);
        let token = signed_token("user123", SECRET, Some(unix_now() + ONE_HOUR));
        let header_value = format!("Bearer {}", token);

        let outcome = validator.validate_and_extract_subject(&header_value);

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "user123");
    }

    #[test]
    fn test_expired_token() {
        let validator = validator_for(SECRET);
        let token = signed_token("user123", SECRET, Some(unix_now() - ONE_HOUR));

        let outcome = validator.validate_and_extract_subject(&token);

        assert!(matches!(outcome, Err(JwtAuthError::TokenExpired)));
    }

    #[test]
    fn test_invalid_signature() {
        let validator = validator_for(SECRET);
        // Signed with a different key than the validator was configured with.
        let token = signed_token("user123", b"wrong-secret-key-for-jwt", Some(unix_now() + ONE_HOUR));

        let outcome = validator.validate_and_extract_subject(&token);

        assert!(matches!(outcome, Err(JwtAuthError::InvalidSignature)));
    }

    #[test]
    fn test_invalid_token_format() {
        let validator = validator_for(SECRET);
        let outcome = validator.validate_and_extract_subject("not-a-jwt-token");
        assert!(matches!(outcome, Err(JwtAuthError::InvalidFormat(_))));
    }

    #[test]
    fn test_subject_validator_hook() {
        let layer = JwtAuthLayer::new(&JwtConfig::default()).with_subject_validator(is_eight_alphanumerics);
        assert!(layer.subject_validator.is_some());

        let validate = layer.subject_validator.unwrap();
        assert!(validate("abc12345"));
        assert!(!validate("short"));
        assert!(!validate("has spaces"));
    }
}