async_oidc_jwt_validator/
validator.rs1use crate::config::OidcConfig;
2use jsonwebtoken::errors::{Error as JwtError, ErrorKind, Result as JwtResult};
3use jsonwebtoken::jwk::{Jwk, JwkSet};
4use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
5use serde::Deserialize;
6use std::collections::HashMap;
7
8#[derive(Clone)]
10pub struct OidcValidator {
11 config: OidcConfig,
12 jwks_cache: std::sync::Arc<tokio::sync::RwLock<HashMap<String, Jwk>>>,
13}
14
15impl OidcValidator {
16 pub fn new(config: OidcConfig) -> Self {
18 Self {
19 config,
20 jwks_cache: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
21 }
22 }
23
24 async fn fetch_jwks(&self) -> JwtResult<JwkSet> {
25 let jwks_url = self.config.jwks_uri.clone();
26
27 log::debug!("Fetching JWKS from: {}", jwks_url);
28
29 let response = reqwest::get(&jwks_url).await.map_err(|e| {
30 JwtError::from(ErrorKind::InvalidRsaKey(format!(
31 "Failed to fetch JWKS: {}",
32 e
33 )))
34 })?;
35
36 if !response.status().is_success() {
37 return Err(JwtError::from(ErrorKind::InvalidRsaKey(format!(
38 "JWKS request failed with status: {}",
39 response.status()
40 ))));
41 }
42
43 let jwks: JwkSet = response.json().await.map_err(|e| {
44 JwtError::from(ErrorKind::InvalidRsaKey(format!(
45 "Failed to parse JWKS response: {}",
46 e
47 )))
48 })?;
49
50 log::debug!("Fetched {} keys from JWKS", jwks.keys.len());
51 Ok(jwks)
52 }
53
54 async fn get_jwk(&self, kid: &str) -> JwtResult<Jwk> {
55 {
57 let cache = self.jwks_cache.read().await;
58 if let Some(jwk) = cache.get(kid) {
59 return Ok(jwk.clone());
60 }
61 }
62
63 self.refresh_jwks_cache().await?;
65
66 let cache = self.jwks_cache.read().await;
67 cache
68 .get(kid)
69 .cloned()
70 .ok_or_else(|| JwtError::from(ErrorKind::InvalidToken))
71 }
72
73 pub async fn validate_custom<T>(&self, token: &str, validation: &Validation) -> JwtResult<T>
74 where
75 T: for<'de> Deserialize<'de>,
76 {
77 log::debug!("Verifying JWT token");
78
79 let header = jsonwebtoken::decode_header(token)?;
81
82 let kid = header
83 .kid
84 .ok_or_else(|| JwtError::from(ErrorKind::InvalidToken))?;
85 log::debug!("Token kid: {}", kid);
86
87 let jwk = self.get_jwk(&kid).await?;
89
90 log::debug!("Found matching key with kid: {}", kid);
91
92 let decoding_key = DecodingKey::from_jwk(&jwk)
93 .map_err(|_e| JwtError::from(ErrorKind::InvalidKeyFormat))?;
94
95 let token_data = decode::<T>(token, &decoding_key, validation)?;
97
98 log::debug!("Token verified successfully");
99 Ok(token_data.claims)
100 }
101
102 pub async fn validate<T>(&self, token: &str) -> JwtResult<T>
103 where
104 T: for<'de> Deserialize<'de>,
105 {
106 log::debug!("Validating JWT token with minimal validation");
107
108 let mut validation = Validation::new(Algorithm::RS256);
110
111 validation.set_issuer(&[&self.config.issuer_url]);
112
113 validation.set_audience(&[&self.config.client_id]);
114
115 self.validate_custom(token, &validation).await
116 }
117
118 pub async fn refresh_jwks_cache(&self) -> JwtResult<()> {
120 log::info!("Refreshing JWKS cache");
121 let new_jwks = self.fetch_jwks().await?;
122
123 let needs_update = {
125 let cache = self.jwks_cache.read().await;
126
127 let lengths_are_different = new_jwks.keys.len() != cache.len();
129
130 let has_added_keys = if lengths_are_different {
133 false } else {
135 new_jwks.keys.iter().any(|jwk| {
136 if let Some(kid) = &jwk.common.key_id {
137 !cache.contains_key(kid)
138 } else {
139 false }
141 })
142 };
143 lengths_are_different || has_added_keys
144 }; if needs_update {
148 log::info!("New keys detected, replacing entire cache");
149
150 let mut new_cache = HashMap::new();
152 for jwk in new_jwks.keys {
153 if let Some(kid) = jwk.common.key_id.clone() {
154 log::debug!("Adding key to new cache: {}", kid);
155 new_cache.insert(kid, jwk);
156 }
157 }
158
159 let mut cache = self.jwks_cache.write().await;
161 *cache = new_cache;
162
163 log::info!("Successfully replaced JWKS cache with {} keys", cache.len());
164 } else {
165 log::debug!("No new keys found in JWKS, cache unchanged");
166 }
167
168 Ok(())
169 }
170}