1use std::collections::HashMap;
5use std::time::{Duration, Instant};
6
7use jsonwebtoken::jwk::KeyAlgorithm;
8use jsonwebtoken::{
9 Algorithm, DecodingKey, Header,
10 jwk::{Jwk, JwkSet},
11};
12use parking_lot::RwLock;
13use reqwest::{Client as ReqwestClient, StatusCode};
14use url::Url;
15
16use crate::errors::AuthError;
17
18#[derive(Clone, Debug)]
20pub struct JwksCache {
21 pub jwks: JwkSet,
22 pub fetched_at: Instant,
23 pub ttl: Duration,
24}
25
26#[derive(Debug)]
47pub struct KeyResolver {
48 client: ReqwestClient,
49 jwks_cache: RwLock<HashMap<String, JwksCache>>,
50 default_jwks_ttl: Duration,
51}
52
53impl Default for KeyResolver {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl KeyResolver {
60 const STATIC_JWKS_ENTRY: &'static str = "static_jwks";
61
62 pub fn new() -> Self {
64 let client = ReqwestClient::builder()
66 .user_agent("AGNTCY Slim Auth")
67 .build()
68 .expect("Failed to create reqwest client");
69
70 Self {
71 client,
72 jwks_cache: RwLock::new(HashMap::new()),
73 default_jwks_ttl: Duration::from_secs(3600), }
75 }
76
77 pub fn with_jwks(jwks: JwkSet) -> Self {
78 let mut cache = HashMap::new();
80 cache.insert(
81 Self::STATIC_JWKS_ENTRY.to_string(),
82 JwksCache {
83 jwks,
84 fetched_at: Instant::now(),
85 ttl: Duration::from_secs(u64::MAX), },
87 );
88
89 let client = ReqwestClient::builder()
90 .user_agent("AGNTCY Slim Auth")
91 .build()
92 .expect("Failed to create reqwest client");
93
94 Self {
95 client,
96 jwks_cache: RwLock::new(cache),
97 default_jwks_ttl: Duration::from_secs(3600), }
99 }
100
101 pub fn with_jwks_ttl(mut self, ttl: Duration) -> Self {
103 self.default_jwks_ttl = ttl;
104 self
105 }
106
107 pub async fn resolve_key(
118 &self,
119 issuer: &str,
120 token_header: &Header,
121 ) -> Result<DecodingKey, AuthError> {
122 if let Some(cache_entry) = self.jwks_cache.read().get(Self::STATIC_JWKS_ENTRY) {
124 return self.get_decoded_key_from_jwks(&cache_entry.jwks, token_header);
126 }
127
128 if let Ok(cached_key) = self.get_cached_key(issuer, token_header) {
130 return Ok(cached_key);
131 }
132
133 let jwks = self.fetch_jwks(issuer).await?;
135
136 self.get_decoded_key_from_jwks(&jwks, token_header)
138 }
139
140 fn jwk_to_decoding_key(&self, jwk: &Jwk) -> Result<DecodingKey, AuthError> {
142 let ret = DecodingKey::from_jwk(jwk)?;
143 Ok(ret)
144 }
145
146 fn key_alg_to_algorithm(&self, alg: &KeyAlgorithm) -> Result<Algorithm, AuthError> {
147 match alg {
148 KeyAlgorithm::HS256 => Ok(Algorithm::HS256),
149 KeyAlgorithm::HS384 => Ok(Algorithm::HS384),
150 KeyAlgorithm::HS512 => Ok(Algorithm::HS512),
151 KeyAlgorithm::ES256 => Ok(Algorithm::ES256),
152 KeyAlgorithm::ES384 => Ok(Algorithm::ES384),
153 KeyAlgorithm::RS256 => Ok(Algorithm::RS256),
154 KeyAlgorithm::RS384 => Ok(Algorithm::RS384),
155 KeyAlgorithm::RS512 => Ok(Algorithm::RS512),
156 KeyAlgorithm::PS256 => Ok(Algorithm::PS256),
157 KeyAlgorithm::PS384 => Ok(Algorithm::PS384),
158 KeyAlgorithm::PS512 => Ok(Algorithm::PS512),
159 KeyAlgorithm::EdDSA => Ok(Algorithm::EdDSA),
160 _ => Err(AuthError::JwtUnsupportedKeyAlgorithm(*alg)),
161 }
162 }
163
164 fn get_decoded_key_from_jwks(
165 &self,
166 jwks: &JwkSet,
167 token_header: &Header,
168 ) -> Result<DecodingKey, AuthError> {
169 if let Some(kid) = &token_header.kid {
171 for key in &jwks.keys {
173 if let Some(id) = &key.common.key_id
174 && id == kid
175 {
176 return self.jwk_to_decoding_key(key);
177 }
178 }
179 } else {
180 for key in &jwks.keys {
182 if let Some(alg) = &key.common.key_algorithm
183 && let Ok(algorithm) = self.key_alg_to_algorithm(alg)
184 {
185 if algorithm == token_header.alg {
187 return self.jwk_to_decoding_key(key);
188 }
189 }
190 }
191 }
192
193 Err(AuthError::JwksNoSuitableKey)
195 }
196
197 pub fn get_cached_key(
199 &self,
200 issuer: &str,
201 token_header: &Header,
202 ) -> Result<DecodingKey, AuthError> {
203 let cache = self.jwks_cache.read();
205
206 if let Some(cache_entry) = cache.get(Self::STATIC_JWKS_ENTRY) {
208 return self.get_decoded_key_from_jwks(&cache_entry.jwks, token_header);
210 }
211
212 let cache_entry = cache.get(issuer);
213 if cache_entry.is_none() {
214 return Err(AuthError::JwksCacheMiss {
215 issuer: issuer.to_string(),
216 });
217 }
218
219 let cache_entry = cache_entry.unwrap();
220
221 if cache_entry.fetched_at.elapsed() > cache_entry.ttl {
222 return Err(AuthError::JwksCacheExpired {
223 issuer: issuer.to_string(),
224 });
225 }
226
227 self.get_decoded_key_from_jwks(&cache_entry.jwks, token_header)
229 }
230
231 async fn fetch_jwks(&self, issuer: &str) -> Result<JwkSet, AuthError> {
236 let jwks_uri = self.build_jwks_uri(issuer).await?;
238
239 let jwks = self.fetch_jwks_from_uri(&jwks_uri).await?;
241
242 self.jwks_cache.write().insert(
244 issuer.to_string(),
245 JwksCache {
246 jwks: jwks.clone(),
247 fetched_at: Instant::now(),
248 ttl: self.default_jwks_ttl,
249 },
250 );
251
252 Ok(jwks)
253 }
254
255 async fn build_jwks_uri(&self, issuer: &str) -> Result<String, AuthError> {
261 let mut issuer_url = Url::parse(issuer)?;
264
265 let mut openid_config_url = issuer_url.clone();
267 let mut openid_path = openid_config_url.path().trim_end_matches('/').to_owned();
268 openid_path.push_str("/.well-known/openid-configuration");
269 openid_config_url.set_path(&openid_path);
270
271 let openid_config_response = self.client.get(openid_config_url.to_string()).send().await;
273
274 if let Ok(response) = openid_config_response
276 && response.status() == StatusCode::OK
277 && let Ok(config) = response.json::<serde_json::Value>().await
278 && let Some(jwks_uri) = config.get("jwks_uri").and_then(|v| v.as_str())
279 {
280 return Ok(jwks_uri.to_string());
281 }
282
283 let mut path = issuer_url.path().trim_end_matches('/').to_owned();
285 path.push_str("/.well-known/jwks.json");
286 issuer_url.set_path(&path);
287
288 Ok(issuer_url.to_string())
289 }
290
291 async fn fetch_jwks_from_uri(&self, uri: &str) -> Result<JwkSet, AuthError> {
293 let response = self.client.get(uri).send().await?;
295
296 if response.status() != StatusCode::OK {
298 return Err(AuthError::JwtFetchJwksFailed(response.status()));
299 }
300
301 let body = response.bytes().await?;
303
304 let jwks: JwkSet = serde_json::from_slice(&body)?;
306
307 Ok(jwks)
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// An issuer URL nothing is expected to listen on. Tests that must not hit
    /// the network use it, so any accidental request surfaces as a
    /// transport error instead of the asserted result.
    const UNREACHABLE_ISSUER: &str = "http://127.0.0.1:1";

    async fn create_test_resolver() -> (KeyResolver, MockServer) {
        let server = MockServer::start().await;
        let resolver = KeyResolver::new();
        (resolver, server)
    }

    /// Builds an RSA JWK as JSON with an optional `kid` and `alg`.
    /// `n` is a placeholder modulus, not a usable RSA key.
    fn rsa_jwk(kid: Option<&str>, alg: Option<&str>) -> serde_json::Value {
        let mut jwk = json!({
            "kty": "RSA",
            "n": "some-modulus",
            "e": "AQAB"
        });
        if let Some(kid) = kid {
            jwk["kid"] = json!(kid);
        }
        if let Some(alg) = alg {
            jwk["alg"] = json!(alg);
        }
        jwk
    }

    /// Wraps JWK JSON values into a JWKS document.
    fn jwks_json(keys: Vec<serde_json::Value>) -> serde_json::Value {
        json!({ "keys": keys })
    }

    /// Deserializes a JWKS built from `keys`.
    fn jwk_set(keys: Vec<serde_json::Value>) -> JwkSet {
        serde_json::from_value(jwks_json(keys)).expect("test JWKS must deserialize")
    }

    /// Builds a token header with the given algorithm and optional `kid`.
    fn header(alg: Algorithm, kid: Option<&str>) -> Header {
        let mut header = Header::new(alg);
        header.kid = kid.map(str::to_owned);
        header
    }

    #[tokio::test]
    async fn test_build_jwks_uri_with_openid_discovery() {
        let (resolver, mock_server) = create_test_resolver().await;

        let jwks_uri = "https://example.com/custom/path/to/jwks.json";
        Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "issuer": "https://example.com",
                "jwks_uri": jwks_uri
            })))
            .mount(&mock_server)
            .await;

        let uri = resolver.build_jwks_uri(&mock_server.uri()).await.unwrap();
        assert_eq!(uri, jwks_uri);
    }

    #[tokio::test]
    async fn test_build_jwks_uri_fallback() {
        let (resolver, mock_server) = create_test_resolver().await;

        Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&mock_server)
            .await;

        let uri = resolver.build_jwks_uri(&mock_server.uri()).await.unwrap();
        assert_eq!(uri, format!("{}/.well-known/jwks.json", mock_server.uri()));
    }

    #[tokio::test]
    async fn test_build_jwks_uri_appends_to_issuer_path() {
        let (resolver, mock_server) = create_test_resolver().await;

        // Discovery must append to the issuer's path, and a trailing slash on
        // the issuer must not produce a double separator.
        let jwks_uri = "https://example.com/tenant/keys";
        Mock::given(method("GET"))
            .and(path("/tenant/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "jwks_uri": jwks_uri
            })))
            .mount(&mock_server)
            .await;

        let issuer = format!("{}/tenant/", mock_server.uri());
        let uri = resolver.build_jwks_uri(&issuer).await.unwrap();
        assert_eq!(uri, jwks_uri);
    }

    #[tokio::test]
    async fn test_fetch_jwks_from_uri() {
        let (resolver, mock_server) = create_test_resolver().await;

        Mock::given(method("GET"))
            .and(path("/.well-known/jwks.json"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(jwks_json(vec![rsa_jwk(Some("test-key"), None)])),
            )
            .mount(&mock_server)
            .await;

        let jwks_uri = format!("{}/.well-known/jwks.json", mock_server.uri());
        let fetched_jwks = resolver.fetch_jwks_from_uri(&jwks_uri).await.unwrap();

        assert_eq!(fetched_jwks.keys.len(), 1);
        assert_eq!(
            fetched_jwks.keys[0].common.key_id.as_deref(),
            Some("test-key")
        );
    }

    #[tokio::test]
    async fn test_get_cached_key_reports_cache_miss() {
        let resolver = KeyResolver::new();

        let result = resolver.get_cached_key(
            UNREACHABLE_ISSUER,
            &header(Algorithm::RS256, Some("test-key")),
        );

        assert!(matches!(
            result,
            Err(AuthError::JwksCacheMiss { issuer }) if issuer == UNREACHABLE_ISSUER
        ));
    }

    #[tokio::test]
    async fn test_static_jwks_resolves_by_kid_without_network() {
        let resolver = KeyResolver::with_jwks(jwk_set(vec![rsa_jwk(Some("test-key"), None)]));

        let found = resolver
            .resolve_key(UNREACHABLE_ISSUER, &header(Algorithm::RS256, Some("test-key")))
            .await;
        assert!(found.is_ok());

        // An unknown kid must be reported as such, not fall back to fetching.
        let missing = resolver
            .resolve_key(UNREACHABLE_ISSUER, &header(Algorithm::RS256, Some("unknown-key")))
            .await;
        assert!(matches!(missing, Err(AuthError::JwksNoSuitableKey)));
    }

    #[tokio::test]
    async fn test_static_jwks_without_kid_matches_declared_alg() {
        let resolver = KeyResolver::with_jwks(jwk_set(vec![
            rsa_jwk(None, None),
            rsa_jwk(None, Some("RS256")),
        ]));

        assert!(resolver
            .get_cached_key(UNREACHABLE_ISSUER, &header(Algorithm::RS256, None))
            .is_ok());
        assert!(matches!(
            resolver.get_cached_key(UNREACHABLE_ISSUER, &header(Algorithm::ES256, None)),
            Err(AuthError::JwksNoSuitableKey)
        ));
    }

    #[tokio::test]
    async fn test_resolve_key_fetches_once_then_uses_cache() {
        let (resolver, mock_server) = create_test_resolver().await;

        // Discovery is left unmocked (wiremock answers 404), so the resolver
        // falls back to the well-known JWKS path, which must be hit exactly once.
        Mock::given(method("GET"))
            .and(path("/.well-known/jwks.json"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(jwks_json(vec![rsa_jwk(Some("test-key"), None)])),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        let issuer = mock_server.uri();
        let token_header = header(Algorithm::RS256, Some("test-key"));

        assert!(resolver.resolve_key(&issuer, &token_header).await.is_ok());
        assert!(resolver.get_cached_key(&issuer, &token_header).is_ok());
        assert!(resolver.resolve_key(&issuer, &token_header).await.is_ok());
    }
}