async_oidc_jwt_validator/
lib.rs1use jsonwebtoken::errors::{Error as JwtError, ErrorKind, Result as JwtResult};
2use jsonwebtoken::jwk::{Jwk, JwkSet};
3use jsonwebtoken::{decode, DecodingKey};
4use serde::Deserialize;
5use std::collections::HashMap;
6
7pub use jsonwebtoken::{Algorithm, Validation};
9
10#[derive(Debug, Deserialize)]
12struct OidcDiscovery {
13 issuer: String,
14 jwks_uri: String,
15}
16
/// Settings that identify the OpenID Connect provider and this relying party.
///
/// Build one with `OidcConfig::new` when every value is known up front, or with
/// `OidcConfig::new_with_discovery` to look up `jwks_uri` from the provider.
#[derive(Debug, Clone)]
pub struct OidcConfig {
    /// Provider issuer URL; `OidcValidator::validate` requires it as the `iss` claim.
    pub issuer_url: String,
    /// Client identifier; `OidcValidator::validate` requires it as the audience.
    pub client_id: String,
    /// URL from which the provider's JSON Web Key Set is downloaded.
    pub jwks_uri: String,
}
24
25#[derive(Clone)]
27pub struct OidcValidator {
28 config: OidcConfig,
29 jwks_cache: std::sync::Arc<tokio::sync::RwLock<HashMap<String, Jwk>>>,
30}
31
32impl OidcConfig {
33 pub fn new(issuer_url: String, client_id: String, jwks_uri: String) -> Self {
35 Self {
36 issuer_url,
37 client_id,
38 jwks_uri,
39 }
40 }
41
42 pub async fn new_with_discovery(issuer_url: String, client_id: String) -> JwtResult<Self> {
43 let jwks_uri = Self::discover_jwks_uri(&issuer_url).await?;
44 Ok(Self {
45 issuer_url,
46 client_id,
47 jwks_uri,
48 })
49 }
50
51 async fn discover_jwks_uri(issuer_url: &str) -> JwtResult<String> {
52 let discovery_url = format!("{}/.well-known/openid-configuration", issuer_url);
53
54 log::debug!("Fetching OpenID Connect Discovery from: {}", discovery_url);
55
56 let response = reqwest::get(&discovery_url).await.map_err(|e| {
57 JwtError::from(ErrorKind::InvalidRsaKey(format!(
58 "Failed to fetch OIDC discovery document: {}",
59 e
60 )))
61 })?;
62
63 let content_type = response
64 .headers()
65 .get("content-type")
66 .and_then(|value| value.to_str().ok())
67 .unwrap_or_default();
68
69 if !content_type.starts_with("application/json") {
70 return Err(JwtError::from(ErrorKind::InvalidRsaKey(format!(
71 "Unexpected Content-Type: '{}', expected 'application/json'",
72 content_type
73 ))));
74 }
75
76 if !response.status().is_success() {
77 return Err(JwtError::from(ErrorKind::InvalidRsaKey(format!(
78 "OIDC discovery request failed with status: {}",
79 response.status()
80 ))));
81 }
82
83 let discovery: OidcDiscovery = response.json().await.map_err(|e| {
84 JwtError::from(ErrorKind::InvalidRsaKey(format!(
85 "Failed to parse OIDC discovery response: {}",
86 e
87 )))
88 })?;
89
90 if discovery.issuer != issuer_url {
91 return Err(JwtError::from(ErrorKind::InvalidIssuer));
92 }
93
94 log::debug!("Discovered JWKS URI: {}", discovery.jwks_uri);
95 Ok(discovery.jwks_uri)
96 }
97}
98
99impl OidcValidator {
100 pub fn new(config: OidcConfig) -> Self {
102 Self {
103 config,
104 jwks_cache: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
105 }
106 }
107
108 async fn fetch_jwks(&self) -> JwtResult<JwkSet> {
109 let jwks_url = self.config.jwks_uri.clone();
110
111 log::debug!("Fetching JWKS from: {}", jwks_url);
112
113 let response = reqwest::get(&jwks_url).await.map_err(|e| {
114 JwtError::from(ErrorKind::InvalidRsaKey(format!(
115 "Failed to fetch JWKS: {}",
116 e
117 )))
118 })?;
119
120 if !response.status().is_success() {
121 return Err(JwtError::from(ErrorKind::InvalidRsaKey(format!(
122 "JWKS request failed with status: {}",
123 response.status()
124 ))));
125 }
126
127 let jwks: JwkSet = response.json().await.map_err(|e| {
128 JwtError::from(ErrorKind::InvalidRsaKey(format!(
129 "Failed to parse JWKS response: {}",
130 e
131 )))
132 })?;
133
134 log::debug!("Fetched {} keys from JWKS", jwks.keys.len());
135 Ok(jwks)
136 }
137
138 async fn get_jwk(&self, kid: &str) -> JwtResult<Jwk> {
139 {
141 let cache = self.jwks_cache.read().await;
142 if let Some(jwk) = cache.get(kid) {
143 return Ok(jwk.clone());
144 }
145 }
146
147 self.refresh_jwks_cache().await?;
149
150 let cache = self.jwks_cache.read().await;
151 cache
152 .get(kid)
153 .cloned()
154 .ok_or_else(|| JwtError::from(ErrorKind::InvalidToken))
155 }
156
157 pub async fn validate_custom<T>(&self, token: &str, validation: &Validation) -> JwtResult<T>
158 where
159 T: for<'de> Deserialize<'de>,
160 {
161 log::debug!("Verifying JWT token");
162
163 let header = jsonwebtoken::decode_header(token)?;
165
166 let kid = header
167 .kid
168 .ok_or_else(|| JwtError::from(ErrorKind::InvalidToken))?;
169 log::debug!("Token kid: {}", kid);
170
171 let jwk = self.get_jwk(&kid).await?;
173
174 log::debug!("Found matching key with kid: {}", kid);
175
176 let decoding_key = DecodingKey::from_jwk(&jwk)
177 .map_err(|_e| JwtError::from(ErrorKind::InvalidKeyFormat))?;
178
179 let token_data = decode::<T>(token, &decoding_key, validation)?;
181
182 log::debug!("Token verified successfully");
183 Ok(token_data.claims)
184 }
185
186 pub async fn validate<T>(&self, token: &str) -> JwtResult<T>
187 where
188 T: for<'de> Deserialize<'de>,
189 {
190 log::debug!("Validating JWT token with minimal validation");
191
192 let mut validation = Validation::new(Algorithm::RS256);
194 validation.set_issuer(&[&self.config.issuer_url]);
195 validation.set_audience(&[&self.config.client_id]);
196
197 self.validate_custom(token, &validation).await
198 }
199
200 pub async fn refresh_jwks_cache(&self) -> JwtResult<()> {
202 log::info!("Refreshing JWKS cache");
203 let new_jwks = self.fetch_jwks().await?;
204
205 let needs_update = {
207 let cache = self.jwks_cache.read().await;
208
209 let lengths_are_different = new_jwks.keys.len() != cache.len();
211
212 let has_added_keys = if lengths_are_different {
215 false } else {
217 new_jwks.keys.iter().any(|jwk| {
218 if let Some(kid) = &jwk.common.key_id {
219 !cache.contains_key(kid)
220 } else {
221 false }
223 })
224 };
225 lengths_are_different || has_added_keys
226 }; if needs_update {
230 log::info!("New keys detected, replacing entire cache");
231
232 let mut new_cache = HashMap::new();
234 for jwk in new_jwks.keys {
235 if let Some(kid) = jwk.common.key_id.clone() {
236 log::debug!("Adding key to new cache: {}", kid);
237 new_cache.insert(kid, jwk);
238 }
239 }
240
241 let mut cache = self.jwks_cache.write().await;
243 *cache = new_cache;
244
245 log::info!("Successfully replaced JWKS cache with {} keys", cache.len());
246 } else {
247 log::debug!("No new keys found in JWKS, cache unchanged");
248 }
249
250 Ok(())
251 }
252}