use std::collections::HashMap;
use std::time::{Duration, SystemTime};
use dashmap::DashMap;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, TokenData, Validation};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use super::authn::Authenticator;
use super::error::AuthError;
use super::principal::{AuthMethod, Principal, PrincipalBuilder};
/// A JSON Web Key Set document as served by the issuer's JWKS endpoint.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct Jwks {
    /// Published keys; an entry is selected by comparing its `kid` with the
    /// `kid` header of the token being validated.
    keys: Vec<Jwk>,
}
/// One JSON Web Key from a JWKS document.
///
/// Only the members needed to build an RSA verification key are modelled.
/// The struct does not use `deny_unknown_fields`, so any other members in the
/// document are ignored during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Jwk {
    /// Key ID (`kid`), compared against the `kid` header of incoming tokens.
    ///
    /// `kid` is OPTIONAL in RFC 7517 §4.5. Defaulting it to an empty string
    /// means one published key without an ID no longer makes the whole JWKS
    /// fail to parse (which would reject every token). Such a key can only be
    /// selected by a token whose `kid` header is itself empty, and that
    /// token's signature must still verify against the key.
    #[serde(rename = "kid", default)]
    key_id: String,
    /// Key type (`kty`), e.g. `"RSA"`.
    #[serde(rename = "kty")]
    key_type: String,
    /// Intended algorithm (`alg`), when the issuer publishes one.
    #[serde(rename = "alg")]
    algorithm: Option<String>,
    /// Intended key use (`use`), e.g. `"sig"` or `"enc"`, when published.
    #[serde(rename = "use")]
    key_use: Option<String>,
    /// RSA modulus (base64url), present for RSA keys.
    n: Option<String>,
    /// RSA public exponent (base64url), present for RSA keys.
    e: Option<String>,
}
/// Registered JWT claims (RFC 7519 §4.1) that every accepted token carries.
///
/// `sub`, `iss`, `aud` and `exp` are mandatory; a token missing any of them
/// fails to deserialize.
#[derive(Debug, Serialize, Deserialize)]
struct StandardClaims {
    /// Subject: the identifier of the authenticated user.
    sub: String,
    /// Issuer that minted the token.
    iss: String,
    /// Audience: a single string or a list of strings.
    aud: StringOrVec,
    /// Expiry time (seconds since the Unix epoch).
    exp: u64,
    /// Issued-at time (seconds since the Unix epoch), if present.
    iat: Option<u64>,
    /// Not-before time (seconds since the Unix epoch), if present.
    nbf: Option<u64>,
}
/// Full claim set decoded from a token: the registered claims plus the
/// application-specific ones this service understands.
#[derive(Debug, Serialize, Deserialize)]
struct CustomClaims {
    /// Registered claims, flattened into the same JSON object.
    #[serde(flatten)]
    standard: StandardClaims,
    /// Value of the `tenant_id` claim; `None` when absent.
    #[serde(default)]
    tenant_id: Option<String>,
    /// Value of the `roles` claim; empty when absent.
    #[serde(default)]
    roles: Vec<String>,
    /// Value of the `name` claim; `None` when absent.
    #[serde(default)]
    name: Option<String>,
    /// Value of the `email` claim; `None` when absent.
    #[serde(default)]
    email: Option<String>,
    /// Every remaining claim that no other field captured.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
/// A claim value that may be encoded either as one string or as an array of
/// strings, as allowed for `aud` by RFC 7519.
///
/// Being `untagged`, deserialization tries the variants in declaration order.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum StringOrVec {
    /// The single-string form.
    String(String),
    /// The array form.
    Vec(Vec<String>),
}
impl StringOrVec {
    /// Returns `true` when `value` equals the single string, or any element of
    /// the list form.
    #[cfg(test)]
    fn contains(&self, value: &str) -> bool {
        // View both encodings as a slice so one comparison path serves both.
        let candidates: &[String] = match self {
            StringOrVec::String(single) => std::slice::from_ref(single),
            StringOrVec::Vec(many) => many.as_slice(),
        };
        candidates.iter().any(|candidate| candidate == value)
    }
}
/// Settings for [`JwtAuthenticator`].
#[derive(Debug, Clone)]
pub struct JwtConfig {
    /// Expected `iss` claim. Must be non-empty.
    pub issuer: String,
    /// Expected audience (`aud`). Must be non-empty.
    pub audience: String,
    /// URL of the issuer's JWKS document. Must be non-empty.
    pub jwks_url: String,
    /// Tenant assigned to principals whose token carries no tenant claim.
    pub default_tenant_id: String,
    /// Name of the tenant-identifier claim (default `"tenant_id"`).
    pub tenant_claim: String,
    /// Name of the roles claim (default `"roles"`).
    pub roles_claim: String,
    /// How long a fetched JWKS is reused before it is fetched again.
    pub jwks_cache_ttl: Duration,
    /// Signature algorithms accepted in a token's header; tokens using any
    /// other algorithm are rejected.
    pub allowed_algorithms: Vec<Algorithm>,
}
impl Default for JwtConfig {
fn default() -> Self {
Self {
issuer: String::new(),
audience: String::new(),
jwks_url: String::new(),
default_tenant_id: "default".to_string(),
tenant_claim: "tenant_id".to_string(),
roles_claim: "roles".to_string(),
jwks_cache_ttl: Duration::from_secs(3600),
allowed_algorithms: vec![Algorithm::RS256],
}
}
}
/// A fetched key set paired with the moment it stops being served from cache.
struct CachedJwks {
    /// Key set exactly as returned by the JWKS endpoint.
    jwks: Jwks,
    /// Fetch time plus `JwtConfig::jwks_cache_ttl`; the cache is stale once
    /// the current time reaches this instant.
    expires_at: SystemTime,
}
/// Authenticates requests carrying an RSA-signed JWT bearer token, verified
/// against keys published at `JwtConfig::jwks_url`.
pub struct JwtAuthenticator {
    /// Validation settings supplied at construction.
    config: JwtConfig,
    /// Most recently fetched JWKS and its expiry; `None` until the first fetch.
    jwks_cache: RwLock<Option<CachedJwks>>,
    /// RSA decoding keys built from the JWKS, keyed by `kid`; cleared whenever
    /// the JWKS is re-fetched.
    decoding_keys: DashMap<String, DecodingKey>,
    /// HTTP client used to download the JWKS.
    http_client: reqwest::Client,
}
impl std::fmt::Debug for JwtAuthenticator {
    /// Formats the authenticator for diagnostics.
    ///
    /// The configuration is printed in full, the HTTP client is shown as a
    /// fixed placeholder, and the key caches are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("JwtAuthenticator");
        out.field("config", &self.config);
        out.field("http_client", &"<reqwest::Client>");
        out.finish()
    }
}
impl JwtAuthenticator {
pub fn new(config: JwtConfig) -> Result<Self, AuthError> {
if config.issuer.is_empty() {
return Err(AuthError::Configuration(
"JWT issuer is required".to_string(),
));
}
if config.audience.is_empty() {
return Err(AuthError::Configuration(
"JWT audience is required".to_string(),
));
}
if config.jwks_url.is_empty() {
return Err(AuthError::Configuration("JWKS URL is required".to_string()));
}
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.map_err(|e| {
AuthError::Configuration(format!("Failed to create HTTP client: {}", e))
})?;
Ok(Self {
config,
jwks_cache: RwLock::new(None),
decoding_keys: DashMap::new(),
http_client,
})
}
async fn fetch_jwks(&self) -> Result<Jwks, AuthError> {
let response = self
.http_client
.get(&self.config.jwks_url)
.send()
.await
.map_err(|e| AuthError::External(format!("Failed to fetch JWKS: {}", e)))?;
if !response.status().is_success() {
return Err(AuthError::External(format!(
"JWKS endpoint returned status {}",
response.status()
)));
}
let jwks: Jwks = response
.json()
.await
.map_err(|e| AuthError::External(format!("Failed to parse JWKS: {}", e)))?;
Ok(jwks)
}
async fn get_jwks(&self) -> Result<Jwks, AuthError> {
{
let cache = self.jwks_cache.read();
if let Some(cached) = cache.as_ref() {
if SystemTime::now() < cached.expires_at {
return Ok(cached.jwks.clone());
}
}
}
let jwks = self.fetch_jwks().await?;
self.decoding_keys.clear();
tracing::debug!(
jwks_keys = jwks.keys.len(),
"Refreshed JWKS and purged cached decoding keys"
);
{
let mut cache = self.jwks_cache.write();
*cache = Some(CachedJwks {
jwks: jwks.clone(),
expires_at: SystemTime::now() + self.config.jwks_cache_ttl,
});
}
Ok(jwks)
}
async fn get_decoding_key(&self, kid: &str) -> Result<DecodingKey, AuthError> {
if let Some(key) = self.decoding_keys.get(kid) {
return Ok(key.clone());
}
let jwks = self.get_jwks().await?;
let jwk = jwks
.keys
.iter()
.find(|k| k.key_id == kid)
.ok_or_else(|| AuthError::InvalidToken("Key ID not found in JWKS".to_string()))?;
if jwk.key_type != "RSA" {
return Err(AuthError::InvalidToken(format!(
"Unsupported key type: {}",
jwk.key_type
)));
}
let n = jwk
.n
.as_ref()
.ok_or_else(|| AuthError::InvalidToken("Missing RSA modulus".to_string()))?;
let e = jwk
.e
.as_ref()
.ok_or_else(|| AuthError::InvalidToken("Missing RSA exponent".to_string()))?;
let decoding_key = DecodingKey::from_rsa_components(n, e)
.map_err(|e| AuthError::InvalidToken(format!("Invalid RSA key: {}", e)))?;
self.decoding_keys
.insert(kid.to_string(), decoding_key.clone());
Ok(decoding_key)
}
const MAX_TOKEN_SIZE: usize = 16 * 1024;
async fn validate_token(&self, token: &str) -> Result<TokenData<CustomClaims>, AuthError> {
if token.len() > Self::MAX_TOKEN_SIZE {
return Err(AuthError::InvalidToken(format!(
"Token exceeds maximum size of {} bytes",
Self::MAX_TOKEN_SIZE,
)));
}
let header = decode_header(token)
.map_err(|e| AuthError::InvalidToken(format!("Failed to decode header: {}", e)))?;
let kid = header
.kid
.ok_or_else(|| AuthError::InvalidToken("Missing key ID in token".to_string()))?;
let alg = header.alg;
if !self.config.allowed_algorithms.contains(&alg) {
return Err(AuthError::InvalidToken(format!(
"Algorithm {:?} not allowed",
alg
)));
}
let decoding_key = self.get_decoding_key(&kid).await?;
let mut validation = Validation::new(alg);
validation.set_issuer(&[&self.config.issuer]);
validation.set_audience(&[&self.config.audience]);
let token_data = decode::<CustomClaims>(token, &decoding_key, &validation)
.map_err(|e| AuthError::InvalidToken(format!("Token validation failed: {}", e)))?;
Ok(token_data)
}
fn claims_to_principal(&self, claims: CustomClaims) -> Result<Principal, AuthError> {
let tenant_id = claims
.tenant_id
.clone()
.unwrap_or_else(|| self.config.default_tenant_id.clone());
let name = claims
.name
.clone()
.or_else(|| claims.email.clone())
.unwrap_or_else(|| claims.standard.sub.clone());
let mut builder = PrincipalBuilder::new(
claims.standard.sub.clone(),
name,
super::principal::PrincipalType::User,
tenant_id,
AuthMethod::Bearer,
);
if claims.roles.is_empty() {
builder = builder.with_role("user"); } else {
builder = builder.with_roles(claims.roles.clone());
}
Ok(builder.build())
}
fn extract_token(headers: &axum::http::HeaderMap) -> Option<String> {
headers
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer ").map(|t| t.to_string()))
}
}
#[async_trait::async_trait]
impl Authenticator for JwtAuthenticator {
    /// Authenticates a request from its `Authorization` bearer token.
    ///
    /// Yields [`AuthError::Unauthenticated`] when no bearer token can be
    /// extracted. Errors from token validation and claim mapping are
    /// propagated unchanged.
    async fn authenticate(&self, headers: &axum::http::HeaderMap) -> Result<Principal, AuthError> {
        let bearer = Self::extract_token(headers).ok_or(AuthError::Unauthenticated)?;
        let verified = self.validate_token(&bearer).await?;
        self.claims_to_principal(verified.claims)
    }

    /// Identifies this authenticator as handling bearer tokens.
    fn auth_method(&self) -> AuthMethod {
        AuthMethod::Bearer
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;
    use chrono::Utc;

    const ISSUER: &str = "https://issuer.example.com";
    const AUDIENCE: &str = "rustberg-api";

    /// Builds an authenticator from a complete, valid configuration.
    fn configured_authenticator() -> JwtAuthenticator {
        let config = JwtConfig {
            issuer: ISSUER.to_string(),
            audience: AUDIENCE.to_string(),
            jwks_url: format!("{}/.well-known/jwks.json", ISSUER),
            default_tenant_id: "default".to_string(),
            ..Default::default()
        };
        JwtAuthenticator::new(config).unwrap()
    }

    /// Registered claims for subject `user123`, expiring one hour from now.
    fn standard_claims() -> StandardClaims {
        let now = Utc::now().timestamp();
        StandardClaims {
            sub: "user123".to_string(),
            iss: ISSUER.to_string(),
            aud: StringOrVec::String(AUDIENCE.to_string()),
            exp: (now + 3600) as u64,
            iat: Some(now as u64),
            nbf: None,
        }
    }

    /// A header map whose `Authorization` header is exactly `value`.
    fn authorization_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(axum::http::header::AUTHORIZATION, value.parse().unwrap());
        headers
    }

    #[test]
    fn test_jwt_config_validation() {
        let result = JwtAuthenticator::new(JwtConfig::default());
        assert!(matches!(result, Err(AuthError::Configuration(_))));
    }

    #[test]
    fn test_extract_token_from_bearer_header() {
        let headers = authorization_headers("Bearer test_token_123");
        assert_eq!(
            JwtAuthenticator::extract_token(&headers),
            Some("test_token_123".to_string())
        );
    }

    #[test]
    fn test_extract_token_no_bearer_prefix() {
        let headers = authorization_headers("test_token_123");
        assert_eq!(JwtAuthenticator::extract_token(&headers), None);
    }

    #[test]
    fn test_extract_token_missing_header() {
        assert_eq!(JwtAuthenticator::extract_token(&HeaderMap::new()), None);
    }

    #[test]
    fn test_string_or_vec_contains() {
        let single = StringOrVec::String("test".to_string());
        assert!(single.contains("test"));
        assert!(!single.contains("other"));

        let multiple = StringOrVec::Vec(vec!["test1".to_string(), "test2".to_string()]);
        assert!(multiple.contains("test1"));
        assert!(multiple.contains("test2"));
        assert!(!multiple.contains("test3"));
    }

    #[tokio::test]
    async fn test_claims_to_principal() {
        let authenticator = configured_authenticator();
        let claims = CustomClaims {
            standard: standard_claims(),
            tenant_id: Some("tenant1".to_string()),
            roles: vec!["admin".to_string()],
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            extra: HashMap::new(),
        };

        let principal = authenticator.claims_to_principal(claims).unwrap();

        assert_eq!(principal.id(), "user123");
        assert_eq!(principal.name(), "Test User");
        assert_eq!(principal.tenant_id(), "tenant1");
        assert!(principal.roles().contains("admin"));
        assert_eq!(principal.roles().len(), 1);
        assert_eq!(principal.auth_method(), &AuthMethod::Bearer);
    }

    #[tokio::test]
    async fn test_claims_to_principal_with_defaults() {
        let authenticator = configured_authenticator();
        let claims = CustomClaims {
            standard: standard_claims(),
            tenant_id: None,
            roles: vec![],
            name: None,
            email: None,
            extra: HashMap::new(),
        };

        let principal = authenticator.claims_to_principal(claims).unwrap();

        assert_eq!(principal.id(), "user123");
        // With no name or email the subject doubles as the display name.
        assert_eq!(principal.name(), "user123");
        assert_eq!(principal.tenant_id(), "default");
        // A token without roles receives the baseline `user` role.
        assert!(principal.roles().contains("user"));
        assert_eq!(principal.roles().len(), 1);
        assert_eq!(principal.auth_method(), &AuthMethod::Bearer);
    }
}