Skip to main content

auth0_integration/services/validate_token.rs

use std::collections::HashMap;
use std::time::{Duration, Instant};

use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, TokenData, Validation};
use reqwest::Client;
use serde::Deserialize;
use tokio::sync::{Mutex, RwLock};

use crate::config::Auth0Config;
use crate::error::AppError;
use crate::models::Claims;

12#[derive(Debug, Deserialize)]
13struct Jwks {
14    keys: Vec<Jwk>,
15}
16
17#[derive(Debug, Deserialize)]
18struct Jwk {
19    kid: String,
20    n: String,
21    e: String,
22}
23
24pub struct TokenValidator {
25    cache: RwLock<HashMap<String, DecodingKey>>,
26}
27
28impl TokenValidator {
29    pub fn new() -> Self {
30        Self {
31            cache: RwLock::new(HashMap::new()),
32        }
33    }
34
35    pub async fn validate(&self, token: &str, config: &Auth0Config) -> Result<TokenData<Claims>, AppError> {
36        let header = decode_header(token)?;
37        let kid = header.kid.ok_or_else(|| AppError::InvalidToken("Missing kid".to_string()))?;
38
39        let validation = build_validation(config);
40
41        // Try with cached key first
42        let cached_key = self.cache.read().await.get(&kid).cloned();
43        if let Some(key) = cached_key {
44            if let Ok(data) = decode::<Claims>(token, &key, &validation) {
45                return Ok(data);
46            }
47        }
48
49        // Cache miss or validation failed → refresh JWKS and retry once
50        self.refresh_cache(config).await?;
51
52        let key = self.cache.read().await.get(&kid).cloned()
53            .ok_or_else(|| AppError::InvalidToken(format!("No JWK found for kid: {kid}")))?;
54
55        decode::<Claims>(token, &key, &validation).map_err(AppError::Jwt)
56    }
57
58    async fn refresh_cache(&self, config: &Auth0Config) -> Result<(), AppError> {
59        let jwks: Jwks = Client::new()
60            .get(config.auth0_jwks_uri())
61            .send()
62            .await?
63            .json()
64            .await?;
65
66        let mut cache = self.cache.write().await;
67        for jwk in jwks.keys {
68            let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e).map_err(AppError::Jwt)?;
69            cache.insert(jwk.kid, key);
70        }
71
72        Ok(())
73    }
74}
75
76fn build_validation(config: &Auth0Config) -> Validation {
77    let mut validation = Validation::new(Algorithm::RS256);
78    validation.set_issuer(&[config.auth0_issuer()]);
79    validation.set_audience(&[&config.auth0_audience]);
80    validation
81}