auth0_integration/services/
validate_token.rs1use std::collections::HashMap;
2
3use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, TokenData, Validation};
4use reqwest::Client;
5use serde::Deserialize;
6use tokio::sync::RwLock;
7
8use crate::config::Auth0Config;
9use crate::error::AppError;
10use crate::models::Claims;
11
12#[derive(Debug, Deserialize)]
13struct Jwks {
14 keys: Vec<Jwk>,
15}
16
17#[derive(Debug, Deserialize)]
18struct Jwk {
19 kid: String,
20 n: String,
21 e: String,
22}
23
24pub struct TokenValidator {
25 cache: RwLock<HashMap<String, DecodingKey>>,
26}
27
28impl TokenValidator {
29 pub fn new() -> Self {
30 Self {
31 cache: RwLock::new(HashMap::new()),
32 }
33 }
34
35 pub async fn validate(&self, token: &str, config: &Auth0Config) -> Result<TokenData<Claims>, AppError> {
36 let header = decode_header(token)?;
37 let kid = header.kid.ok_or_else(|| AppError::InvalidToken("Missing kid".to_string()))?;
38
39 let validation = build_validation(config);
40
41 let cached_key = self.cache.read().await.get(&kid).cloned();
43 if let Some(key) = cached_key {
44 if let Ok(data) = decode::<Claims>(token, &key, &validation) {
45 return Ok(data);
46 }
47 }
48
49 self.refresh_cache(config).await?;
51
52 let key = self.cache.read().await.get(&kid).cloned()
53 .ok_or_else(|| AppError::InvalidToken(format!("No JWK found for kid: {kid}")))?;
54
55 decode::<Claims>(token, &key, &validation).map_err(AppError::Jwt)
56 }
57
58 async fn refresh_cache(&self, config: &Auth0Config) -> Result<(), AppError> {
59 let jwks: Jwks = Client::new()
60 .get(config.auth0_jwks_uri())
61 .send()
62 .await?
63 .json()
64 .await?;
65
66 let mut cache = self.cache.write().await;
67 for jwk in jwks.keys {
68 let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e).map_err(AppError::Jwt)?;
69 cache.insert(jwk.kid, key);
70 }
71
72 Ok(())
73 }
74}
75
76fn build_validation(config: &Auth0Config) -> Validation {
77 let mut validation = Validation::new(Algorithm::RS256);
78 validation.set_issuer(&[config.auth0_issuer()]);
79 validation.set_audience(&[&config.auth0_audience]);
80 validation
81}