use axum::http::header;
use jsonwebtoken::{DecodingKey, Validation};
use serde::{Deserialize, Serialize};
/// Hardening headers as `(header name, header value)` pairs.
///
/// The code that attaches these to responses is not in this file. Notes on the values:
/// - `X-XSS-Protection: 0` switches the legacy browser XSS auditor *off*, which current
///   guidance prefers over enabling it.
/// - `Strict-Transport-Security` pins HTTPS for `63072000` seconds (two years), including
///   subdomains, and opts into browser preload lists.
/// - `Content-Security-Policy: default-src 'none'` forbids loading any sub-resource, and
///   `frame-ancestors 'none'` forbids framing (same intent as `X-Frame-Options: DENY`).
///   NOTE(review): this presumably assumes responses are JSON rather than HTML pages — confirm.
pub const SECURITY_HEADERS: [(&str, &str); 7] = [
    ("X-Content-Type-Options", "nosniff"),
    ("X-Frame-Options", "DENY"),
    ("X-XSS-Protection", "0"),
    ("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload"),
    ("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'"),
    ("Referrer-Policy", "strict-origin-when-cross-origin"),
    ("Permissions-Policy", "camera=(), microphone=(), geolocation=()"),
];
/// A strategy that turns an incoming HTTP request into an [`AuthIdentity`].
///
/// The `Send + Sync + 'static` bounds allow one provider instance to be shared across
/// threads and stored for the lifetime of the application.
pub trait AuthProvider: Send + Sync + 'static {
    /// Inspects `req` and returns the caller's identity.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] describing why the request could not be authenticated.
    fn authenticate(
        &self,
        req: &axum::http::Request<axum::body::Body>,
    ) -> Result<AuthIdentity, AuthError>;
}
/// The authenticated caller, as produced by an [`AuthProvider`].
#[derive(Debug, Clone)]
pub struct AuthIdentity {
    /// Stable identifier of the caller (for JWTs, the `sub` claim).
    pub subject: String,
    /// Scopes granted to the caller; may be empty.
    pub scopes: Vec<String>,
    /// Optional role name, if the credential carried one.
    pub role: Option<String>,
}
/// Reasons a request can fail authentication or authorization.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// No usable `Authorization` header was supplied.
    #[error("Missing authorization header")]
    MissingToken,
    /// The credential was malformed, had a bad signature, or failed validation.
    #[error("Invalid token")]
    InvalidToken,
    /// The credential was otherwise valid but its expiry time has passed.
    #[error("Token expired")]
    TokenExpired,
    /// The caller is authenticated but lacks a required scope.
    ///
    /// Not produced in this file; presumably raised by scope-checking code elsewhere.
    #[error("Insufficient scopes")]
    InsufficientScopes,
}
/// Claims this module reads from a JWT payload.
///
/// Only `sub` is mandatory; every other claim falls back to its default when absent.
#[derive(Serialize, Deserialize, Debug)]
pub struct JwtClaims {
    /// Subject identifier (registered claim `sub`).
    pub sub: String,
    /// Granted scopes, accepted either as an array or as a delimited string.
    #[serde(default)]
    pub scopes: ScopeClaim,
    /// Optional application-defined role.
    #[serde(default)]
    pub role: Option<String>,
    /// Expiry time (registered claim `exp`); presumably seconds since the Unix epoch, as
    /// RFC 7519 defines for NumericDate.
    #[serde(default)]
    pub exp: Option<u64>,
    /// Issued-at time (registered claim `iat`); same unit as `exp`.
    #[serde(default)]
    pub iat: Option<u64>,
}
/// Scope list from a JWT.
///
/// This is a newtype so it can carry a lenient custom `Deserialize` implementation; it
/// serializes as a plain list of strings. `Default` yields an empty list.
#[derive(Debug, Clone, Default)]
pub struct ScopeClaim(pub Vec<String>);
impl Serialize for ScopeClaim {
    /// Emits the scopes as a sequence of strings.
    ///
    /// This is the same representation `Vec<String>` produces, so a serialized claim
    /// deserializes back through the array form.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let scopes = &self.0;
        serializer.collect_seq(scopes)
    }
}
impl<'de> Deserialize<'de> for ScopeClaim {
    /// Leniently parses the scope claim.
    ///
    /// Accepted shapes:
    /// - an array of strings, taken verbatim;
    /// - a single string, split on commas and/or whitespace. OAuth scope tokens cannot
    ///   contain spaces (RFC 6749 §3.3), so `"read write"` and `"read, write"` both become
    ///   `["read", "write"]`. Empty tokens such as those from `"read,,write"`, a trailing
    ///   comma, or an all-blank string are dropped rather than turned into `""` scopes;
    /// - `null`, treated as "no scopes", matching the behaviour of an absent claim.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum ScopeValue {
            Array(Vec<String>),
            String(String),
        }
        let scopes = match Option::<ScopeValue>::deserialize(deserializer)? {
            None => Vec::new(),
            Some(ScopeValue::Array(list)) => list,
            Some(ScopeValue::String(joined)) => joined
                .split(|c: char| c == ',' || c.is_whitespace())
                .filter(|token| !token.is_empty())
                .map(str::to_owned)
                .collect(),
        };
        Ok(ScopeClaim(scopes))
    }
}
/// Configuration errors raised while constructing a [`JwtAuth`].
#[derive(Debug, thiserror::Error)]
pub enum JwtAuthError {
    /// The supplied PEM could not be parsed as a key of the expected type; the payload is
    /// the underlying library's error message.
    #[error("Invalid PEM key: {0}")]
    InvalidKey(String),
    /// `JWT_ALGORITHM` named an algorithm this module does not support.
    #[error("Unsupported algorithm: {0}. Supported: HS256, RS256, ES256")]
    UnsupportedAlgorithm(String),
    /// A required environment variable was not available; the payload is its name.
    #[error("Missing environment variable: {0}")]
    MissingEnvVar(String),
}
/// JWT-based [`AuthProvider`]: verifies bearer tokens against a single key and algorithm.
///
/// `Debug` is implemented by hand rather than derived, so the decoding key is not printed.
#[derive(Clone)]
pub struct JwtAuth {
    /// Key used to check token signatures (HMAC secret or public key).
    decoding_key: DecodingKey,
    /// Accepted algorithm(s) and claim-validation rules passed to `jsonwebtoken::decode`.
    validation: Validation,
}
impl std::fmt::Debug for JwtAuth {
    /// Prints only the accepted algorithm list.
    ///
    /// The decoding key is deliberately left out so key material never ends up in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("JwtAuth");
        out.field("algorithm", &self.validation.algorithms);
        out.finish()
    }
}
impl JwtAuth {
    /// Builds a verifier from environment variables.
    ///
    /// * `JWT_ALGORITHM` — `HS256` (the default when unset), `RS256` or `ES256`, matched
    ///   case-insensitively.
    /// * `JWT_SECRET` — HMAC secret for `HS256`. Required in release builds. Builds with
    ///   `debug_assertions` enabled fall back to the well-known `"dev-secret"`, so local
    ///   development works without configuration.
    /// * `JWT_PUBLIC_KEY` — PEM-encoded public key for `RS256` / `ES256`.
    ///
    /// # Errors
    ///
    /// * [`JwtAuthError::MissingEnvVar`] if a required variable is unset or not valid
    ///   Unicode. This also covers `JWT_SECRET` being empty in a release build.
    /// * [`JwtAuthError::InvalidKey`] if `JWT_PUBLIC_KEY` is not a PEM key of the expected type.
    /// * [`JwtAuthError::UnsupportedAlgorithm`] for any other `JWT_ALGORITHM` value.
    pub fn from_env() -> Result<Self, JwtAuthError> {
        let algorithm = std::env::var("JWT_ALGORITHM").unwrap_or_else(|_| "HS256".to_string());
        match algorithm.to_uppercase().as_str() {
            "HS256" => Self::hmac_secret_from_env().map(|secret| Self::new(&secret)),
            "RS256" => Self::from_rsa_pem(Self::required_env("JWT_PUBLIC_KEY")?.as_bytes()),
            "ES256" => Self::from_ec_pem(Self::required_env("JWT_PUBLIC_KEY")?.as_bytes()),
            other => Err(JwtAuthError::UnsupportedAlgorithm(other.to_string())),
        }
    }

    /// Reads `JWT_SECRET`, failing closed in release builds when it is unset or empty.
    fn hmac_secret_from_env() -> Result<String, JwtAuthError> {
        match std::env::var("JWT_SECRET") {
            Ok(secret) if !secret.is_empty() => Ok(secret),
            // A publicly known fallback (or an empty key) would let anyone mint valid
            // HS256 tokens, so it is only tolerated in debug builds.
            _ if cfg!(debug_assertions) => Ok("dev-secret".to_string()),
            _ => Err(JwtAuthError::MissingEnvVar("JWT_SECRET".to_string())),
        }
    }

    /// Reads a required environment variable. Absence or non-Unicode content maps to
    /// [`JwtAuthError::MissingEnvVar`].
    fn required_env(name: &str) -> Result<String, JwtAuthError> {
        std::env::var(name).map_err(|_| JwtAuthError::MissingEnvVar(name.to_string()))
    }

    /// Creates an `HS256` verifier using `secret` as the raw HMAC key.
    ///
    /// The secret is not checked here: an empty or weak secret is accepted as given.
    pub fn new(secret: &str) -> Self {
        let decoding_key = DecodingKey::from_secret(secret.as_bytes());
        Self::with_algorithm(decoding_key, jsonwebtoken::Algorithm::HS256)
    }

    /// Creates an `RS256` verifier from a PEM-encoded RSA public key.
    ///
    /// # Errors
    ///
    /// [`JwtAuthError::InvalidKey`] if `pem` cannot be parsed as an RSA key.
    pub fn from_rsa_pem(pem: &[u8]) -> Result<Self, JwtAuthError> {
        let decoding_key =
            DecodingKey::from_rsa_pem(pem).map_err(|e| JwtAuthError::InvalidKey(e.to_string()))?;
        Ok(Self::with_algorithm(
            decoding_key,
            jsonwebtoken::Algorithm::RS256,
        ))
    }

    /// Creates an `ES256` verifier from a PEM-encoded EC public key.
    ///
    /// # Errors
    ///
    /// [`JwtAuthError::InvalidKey`] if `pem` cannot be parsed as an EC key.
    pub fn from_ec_pem(pem: &[u8]) -> Result<Self, JwtAuthError> {
        let decoding_key =
            DecodingKey::from_ec_pem(pem).map_err(|e| JwtAuthError::InvalidKey(e.to_string()))?;
        Ok(Self::with_algorithm(
            decoding_key,
            jsonwebtoken::Algorithm::ES256,
        ))
    }

    /// Shared constructor: accepts only `algorithm` and validates `exp` when present.
    ///
    /// NOTE(review): `required_spec_claims` is cleared, so a token *without* an `exp` claim
    /// is accepted and never expires. This matches `JwtClaims::exp` being optional. Confirm
    /// that every issuer sets `exp`, or consider requiring it.
    fn with_algorithm(decoding_key: DecodingKey, algorithm: jsonwebtoken::Algorithm) -> Self {
        let mut validation = Validation::new(algorithm);
        validation.required_spec_claims.clear();
        validation.validate_exp = true;
        Self {
            decoding_key,
            validation,
        }
    }
}
impl AuthProvider for JwtAuth {
    /// Authenticates a request carrying `Authorization: Bearer <jwt>`.
    ///
    /// The scheme name is compared case-insensitively, since RFC 7235 §2.1 defines
    /// auth-scheme as case-insensitive. Whitespace around the token is ignored. The token is
    /// verified with the configured key and validation rules, and its `sub`, `scopes` and
    /// `role` claims become the returned [`AuthIdentity`].
    ///
    /// # Errors
    ///
    /// * [`AuthError::MissingToken`] — no `Authorization` header is present.
    /// * [`AuthError::TokenExpired`] — the library reported an expired signature.
    /// * [`AuthError::InvalidToken`] — any of the following:
    ///   - the header is not visible ASCII;
    ///   - the scheme is not `Bearer`;
    ///   - the token is empty;
    ///   - any other decode or validation failure.
    fn authenticate(
        &self,
        req: &axum::http::Request<axum::body::Body>,
    ) -> Result<AuthIdentity, AuthError> {
        const SCHEME: &str = "Bearer ";

        let value = req
            .headers()
            .get(header::AUTHORIZATION)
            .ok_or(AuthError::MissingToken)?
            // A header that is present but unreadable is malformed, not missing.
            .to_str()
            .map_err(|_| AuthError::InvalidToken)?;

        // `get(..len)` returns None for short headers, so the later slice cannot panic.
        let token = value
            .get(..SCHEME.len())
            .filter(|scheme| scheme.eq_ignore_ascii_case(SCHEME))
            .map(|_| value[SCHEME.len()..].trim())
            .filter(|token| !token.is_empty())
            .ok_or(AuthError::InvalidToken)?;

        let claims = jsonwebtoken::decode::<JwtClaims>(token, &self.decoding_key, &self.validation)
            .map_err(|e| match e.kind() {
                jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
                _ => AuthError::InvalidToken,
            })?
            .claims;

        Ok(AuthIdentity {
            subject: claims.sub,
            scopes: claims.scopes.0,
            role: claims.role,
        })
    }
}