use std::collections::HashMap;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, TokenData, Validation};
use reqwest::Client;
use serde::Deserialize;
use tokio::sync::RwLock;
use crate::config::Auth0Config;
use crate::error::AppError;
use crate::models::Claims;
/// JSON Web Key Set document returned by the Auth0 JWKS endpoint.
///
/// Only the `keys` array is read; serde ignores any other top-level members
/// because the struct does not set `deny_unknown_fields`.
#[derive(Deserialize, Debug)]
struct Jwks {
    /// The published signing keys.
    keys: Vec<Jwk>,
}
/// One RSA key from the JWKS, reduced to the members needed to build a
/// `DecodingKey`. Other members of the key object are ignored by serde.
#[derive(Deserialize, Debug)]
struct Jwk {
    /// Key ID; matched against the `kid` in a token's header.
    kid: String,
    /// RSA modulus, as published in the JWK `n` member (base64url per RFC 7518).
    n: String,
    /// RSA public exponent, as published in the JWK `e` member (base64url per RFC 7518).
    e: String,
}
pub struct TokenValidator {
cache: RwLock<HashMap<String, DecodingKey>>,
}
impl TokenValidator {
pub fn new() -> Self {
Self {
cache: RwLock::new(HashMap::new()),
}
}
pub async fn validate(&self, token: &str, config: &Auth0Config) -> Result<TokenData<Claims>, AppError> {
let header = decode_header(token)?;
let kid = header.kid.ok_or_else(|| AppError::InvalidToken("Missing kid".to_string()))?;
let validation = build_validation(config);
let cached_key = self.cache.read().await.get(&kid).cloned();
if let Some(key) = cached_key {
if let Ok(data) = decode::<Claims>(token, &key, &validation) {
return Ok(data);
}
}
self.refresh_cache(config).await?;
let key = self.cache.read().await.get(&kid).cloned()
.ok_or_else(|| AppError::InvalidToken(format!("No JWK found for kid: {kid}")))?;
decode::<Claims>(token, &key, &validation).map_err(AppError::Jwt)
}
async fn refresh_cache(&self, config: &Auth0Config) -> Result<(), AppError> {
let jwks: Jwks = Client::new()
.get(config.auth0_jwks_uri())
.send()
.await?
.json()
.await?;
let mut cache = self.cache.write().await;
for jwk in jwks.keys {
let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e).map_err(AppError::Jwt)?;
cache.insert(jwk.kid, key);
}
Ok(())
}
}
/// Builds the validation rules for Auth0 access tokens.
///
/// The rules require an RS256 signature, an `aud` equal to
/// `config.auth0_audience`, and an `iss` equal to `config.auth0_issuer()`.
/// All other checks keep the `Validation::new` defaults.
fn build_validation(config: &Auth0Config) -> Validation {
    let expected_issuer = config.auth0_issuer();
    let mut rules = Validation::new(Algorithm::RS256);
    rules.set_audience(&[&config.auth0_audience]);
    rules.set_issuer(&[expected_issuer]);
    rules
}