1use std::collections::HashSet;
32use std::sync::Arc;
33use std::time::Duration;
34
35use jsonwebtoken::dangerous::insecure_decode;
36use jsonwebtoken::jwk::JwkSet;
37use jsonwebtoken::{DecodingKey, Validation, decode, decode_header};
38use parking_lot::RwLock;
39use serde::de::DeserializeOwned;
40
41use crate::error::{Error, Result};
42
43#[derive(Clone)]
48pub struct ExternalJwtIssuer {
49 issuer_url: String,
50 audience: HashSet<String>,
51 jwks_uri: String,
52 jwks: Arc<RwLock<JwkSet>>,
53 refresh_interval: Duration,
54}
55
56impl ExternalJwtIssuer {
57 pub async fn discover(
63 issuer_url: String,
64 audience: Vec<String>,
65 refresh_secs: u64,
66 ) -> Result<Self> {
67 let trimmed = issuer_url.trim_end_matches('/').to_string();
68 let discovery_url = format!("{trimmed}/.well-known/openid-configuration");
69
70 let client = reqwest::Client::builder()
71 .timeout(Duration::from_secs(10))
72 .build()
73 .map_err(|e| Error::Oidc(format!("build http client: {e}")))?;
74
75 let metadata: serde_json::Value = client
76 .get(&discovery_url)
77 .send()
78 .await
79 .map_err(|e| Error::Oidc(format!("discover {discovery_url}: {e}")))?
80 .error_for_status()
81 .map_err(|e| Error::Oidc(format!("discover {discovery_url}: {e}")))?
82 .json()
83 .await
84 .map_err(|e| Error::Oidc(format!("parse {discovery_url}: {e}")))?;
85
86 let jwks_uri = metadata
87 .get("jwks_uri")
88 .and_then(|v| v.as_str())
89 .ok_or_else(|| {
90 Error::Oidc(format!("{discovery_url} missing `jwks_uri` field"))
91 })?
92 .to_string();
93
94 let jwks = fetch_jwks(&client, &jwks_uri).await?;
95
96 let verifier = Self {
97 issuer_url: trimmed.clone(),
98 audience: audience.into_iter().collect(),
99 jwks_uri: jwks_uri.clone(),
100 jwks: Arc::new(RwLock::new(jwks)),
101 refresh_interval: Duration::from_secs(refresh_secs.max(60)),
102 };
103
104 verifier.spawn_refresh(client);
105 Ok(verifier)
106 }
107
108 pub fn issuer(&self) -> &str {
112 &self.issuer_url
113 }
114
115 pub fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<jsonwebtoken::TokenData<T>> {
121 let header = decode_header(token)
122 .map_err(|e| Error::Oidc(format!("decode jwt header: {e}")))?;
123 let kid = header
124 .kid
125 .as_ref()
126 .ok_or_else(|| Error::Oidc("jwt header missing `kid`".to_string()))?;
127
128 let jwk = {
129 let jwks = self.jwks.read();
130 jwks.find(kid).cloned()
131 };
132 let jwk = jwk.ok_or_else(|| {
133 Error::Oidc(format!(
134 "kid `{kid}` not in cached jwks for issuer `{}`",
135 self.issuer_url
136 ))
137 })?;
138
139 let key = DecodingKey::from_jwk(&jwk)
140 .map_err(|e| Error::Oidc(format!("build decoding key from jwk: {e}")))?;
141
142 let mut validation = Validation::new(header.alg);
143 validation.set_issuer(&[&self.issuer_url]);
144 if self.audience.is_empty() {
145 validation.validate_aud = false;
148 } else {
149 let aud: Vec<&str> = self.audience.iter().map(String::as_str).collect();
150 validation.set_audience(&aud);
151 }
152
153 decode::<T>(token, &key, &validation)
154 .map_err(|e| Error::Oidc(format!("verify jwt against `{}`: {e}", self.issuer_url)))
155 }
156
157 fn spawn_refresh(&self, client: reqwest::Client) {
158 let jwks = Arc::clone(&self.jwks);
159 let jwks_uri = self.jwks_uri.clone();
160 let interval = self.refresh_interval;
161 let issuer_url = self.issuer_url.clone();
162
163 tokio::spawn(async move {
164 loop {
165 tokio::time::sleep(interval).await;
166 match fetch_jwks(&client, &jwks_uri).await {
167 Ok(fresh) => {
168 *jwks.write() = fresh;
169 tracing::debug!(
170 target: "assay-auth::external_jwt",
171 issuer = %issuer_url,
172 "refreshed jwks"
173 );
174 }
175 Err(e) => {
176 tracing::warn!(
177 target: "assay-auth::external_jwt",
178 issuer = %issuer_url,
179 error = %e,
180 "failed to refresh jwks; keeping previous keys"
181 );
182 }
183 }
184 }
185 });
186 }
187}
188
189async fn fetch_jwks(client: &reqwest::Client, uri: &str) -> Result<JwkSet> {
190 let body: serde_json::Value = client
191 .get(uri)
192 .send()
193 .await
194 .map_err(|e| Error::Oidc(format!("fetch jwks {uri}: {e}")))?
195 .error_for_status()
196 .map_err(|e| Error::Oidc(format!("fetch jwks {uri}: {e}")))?
197 .json()
198 .await
199 .map_err(|e| Error::Oidc(format!("parse jwks {uri}: {e}")))?;
200 serde_json::from_value(body)
201 .map_err(|e| Error::Oidc(format!("decode jwks {uri}: {e}")))
202}
203
204pub fn verify_with_any<T: DeserializeOwned>(
211 issuers: &[ExternalJwtIssuer],
212 token: &str,
213) -> Option<Result<jsonwebtoken::TokenData<T>>> {
214 if issuers.is_empty() {
215 return None;
216 }
217
218 #[derive(serde::Deserialize)]
223 struct IssClaim {
224 iss: String,
225 }
226 let unverified = insecure_decode::<IssClaim>(token).ok()?;
227 let iss = unverified.claims.iss;
228 let trimmed = iss.trim_end_matches('/');
229
230 for issuer in issuers {
231 if issuer.issuer() == trimmed || issuer.issuer() == iss {
232 return Some(issuer.verify::<T>(token));
233 }
234 }
235 None
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
    use serde::{Deserialize, Serialize};

    /// Claim set carried by every test token.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
    struct TestClaims {
        iss: String,
        aud: String,
        sub: String,
        // `u64` rather than `usize`: unix timestamps such as 9_999_999_999 do
        // not fit a 32-bit `usize`, which would break the build there.
        exp: u64,
    }

    /// Returns the current unix time shifted by `offset_secs` (may be negative).
    fn exp_from_now(offset_secs: i64) -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is before the unix epoch")
            .as_secs()
            .saturating_add_signed(offset_secs)
    }

    /// Builds claims for the fixed subject `user-42`.
    fn claims_for(iss: &str, aud: &str, exp: u64) -> TestClaims {
        TestClaims {
            iss: iss.to_string(),
            aud: aud.to_string(),
            sub: "user-42".to_string(),
            exp,
        }
    }

    /// Builds a verifier directly from parts, without discovery, network
    /// access, or a background refresh task.
    fn verifier_for_tests(issuer: &str, audience: Vec<String>, jwks: JwkSet) -> ExternalJwtIssuer {
        ExternalJwtIssuer {
            issuer_url: issuer.trim_end_matches('/').to_string(),
            audience: audience.into_iter().collect(),
            jwks_uri: format!("{issuer}/jwks"),
            jwks: Arc::new(RwLock::new(jwks)),
            refresh_interval: Duration::from_secs(3600),
        }
    }

    /// Builds a one-key JWK set holding the symmetric HS256 `secret` under `kid`.
    fn hs256_jwks_with_kid(kid: &str, secret: &[u8]) -> JwkSet {
        let json = serde_json::json!({
            "keys": [{
                "kty": "oct",
                "use": "sig",
                "alg": "HS256",
                "kid": kid,
                "k": base64_url(secret)
            }]
        });
        serde_json::from_value(json).unwrap()
    }

    /// Encodes `b` as unpadded URL-safe base64.
    fn base64_url(b: &[u8]) -> String {
        const T: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
        let mut out = Vec::with_capacity(b.len().div_ceil(3) * 4);
        for chunk in b.chunks(3) {
            let mut buf = [0u8; 3];
            buf[..chunk.len()].copy_from_slice(chunk);
            let n = u32::from_be_bytes([0, buf[0], buf[1], buf[2]]);
            out.push(T[((n >> 18) & 0x3F) as usize]);
            out.push(T[((n >> 12) & 0x3F) as usize]);
            // A short final chunk emits fewer symbols instead of `=` padding.
            if chunk.len() >= 2 {
                out.push(T[((n >> 6) & 0x3F) as usize]);
            }
            if chunk.len() == 3 {
                out.push(T[(n & 0x3F) as usize]);
            }
        }
        String::from_utf8(out).expect("ascii")
    }

    /// Signs `claims` with HS256 using `secret`, stamping `kid` into the header.
    fn issue_test_token(secret: &[u8], kid: &str, claims: &TestClaims) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(kid.to_string());
        encode(&header, claims, &EncodingKey::from_secret(secret)).unwrap()
    }

    #[test]
    fn verifies_token_from_configured_issuer() {
        let secret = b"unit-test-secret-key-32bytes!!!!";
        let kid = "test-key-1";
        let issuer = "https://hydra.example.com";
        let aud = "test-app";
        let claims = claims_for(issuer, aud, exp_from_now(3600));
        let token = issue_test_token(secret, kid, &claims);

        let v = verifier_for_tests(issuer, vec![aud.to_string()], hs256_jwks_with_kid(kid, secret));
        let out = v.verify::<TestClaims>(&token).unwrap();
        assert_eq!(out.claims, claims);
    }

    #[test]
    fn rejects_token_with_wrong_issuer() {
        let secret = b"unit-test-secret-key-32bytes!!!!";
        let kid = "test-key-1";
        let claims = claims_for("https://other.example.com", "test-app", exp_from_now(3600));
        let token = issue_test_token(secret, kid, &claims);

        let v = verifier_for_tests(
            "https://hydra.example.com",
            vec!["test-app".to_string()],
            hs256_jwks_with_kid(kid, secret),
        );
        assert!(v.verify::<TestClaims>(&token).is_err());
    }

    #[test]
    fn rejects_token_with_wrong_audience() {
        let secret = b"unit-test-secret-key-32bytes!!!!";
        let kid = "test-key-1";
        let issuer = "https://hydra.example.com";
        let claims = claims_for(issuer, "some-other-app", exp_from_now(3600));
        let token = issue_test_token(secret, kid, &claims);

        let v = verifier_for_tests(
            issuer,
            vec!["test-app".to_string()],
            hs256_jwks_with_kid(kid, secret),
        );
        assert!(v.verify::<TestClaims>(&token).is_err());
    }

    #[test]
    fn rejects_token_with_unknown_kid() {
        let secret = b"unit-test-secret-key-32bytes!!!!";
        let issuer = "https://hydra.example.com";
        let claims = claims_for(issuer, "test-app", exp_from_now(3600));
        let token = issue_test_token(secret, "rotated-key", &claims);

        let v = verifier_for_tests(
            issuer,
            vec!["test-app".to_string()],
            hs256_jwks_with_kid("current-key", secret),
        );
        let err = v.verify::<TestClaims>(&token).unwrap_err().to_string();
        assert!(err.contains("kid"), "error should mention kid: {err}");
    }

    #[test]
    fn rejects_expired_token() {
        let secret = b"unit-test-secret-key-32bytes!!!!";
        let kid = "test-key-1";
        let issuer = "https://hydra.example.com";
        // Expired one hour ago.
        let claims = claims_for(issuer, "test-app", exp_from_now(-3600));
        let token = issue_test_token(secret, kid, &claims);

        let v = verifier_for_tests(
            issuer,
            vec!["test-app".to_string()],
            hs256_jwks_with_kid(kid, secret),
        );
        assert!(v.verify::<TestClaims>(&token).is_err());
    }

    #[test]
    fn empty_audience_list_skips_aud_check() {
        let secret = b"unit-test-secret-key-32bytes!!!!";
        let kid = "test-key-1";
        let issuer = "https://hydra.example.com";
        let claims = claims_for(issuer, "literally-anything", exp_from_now(3600));
        let token = issue_test_token(secret, kid, &claims);

        let v = verifier_for_tests(issuer, vec![], hs256_jwks_with_kid(kid, secret));
        assert!(v.verify::<TestClaims>(&token).is_ok());
    }

    #[test]
    fn verify_with_any_routes_by_iss() {
        let secret_a = b"key-A-secret-32bytes-unit-tests!";
        let secret_b = b"key-B-secret-32bytes-unit-tests!";
        let issuer_a = "https://hydra-a.example.com";
        let issuer_b = "https://hydra-b.example.com";

        let v_a = verifier_for_tests(
            issuer_a,
            vec!["test-app".to_string()],
            hs256_jwks_with_kid("a-key", secret_a),
        );
        let v_b = verifier_for_tests(
            issuer_b,
            vec!["test-app".to_string()],
            hs256_jwks_with_kid("b-key", secret_b),
        );

        let claims_b = claims_for(issuer_b, "test-app", exp_from_now(3600));
        let token_b = issue_test_token(secret_b, "b-key", &claims_b);

        let result = verify_with_any::<TestClaims>(&[v_a, v_b], &token_b)
            .expect("verifier should match issuer_b")
            .expect("verification should succeed");
        assert_eq!(result.claims, claims_b);
    }

    #[test]
    fn verify_with_any_returns_none_for_unknown_issuer() {
        let secret = b"unit-test-secret-key-32bytes!!!!";
        let v = verifier_for_tests(
            "https://hydra.example.com",
            vec!["test-app".to_string()],
            hs256_jwks_with_kid("a-key", secret),
        );
        let claims = claims_for("https://stranger.example.com", "test-app", exp_from_now(3600));
        let token = issue_test_token(secret, "a-key", &claims);

        let result = verify_with_any::<TestClaims>(&[v], &token);
        assert!(result.is_none(), "unknown issuer should fall through");
    }

    #[test]
    fn verify_with_any_returns_none_for_empty_issuer_list() {
        let secret = b"unit-test-secret-key-32bytes!!!!";
        let claims = claims_for("https://anywhere.example.com", "test-app", 9_999_999_999);
        let token = issue_test_token(secret, "x", &claims);
        assert!(verify_with_any::<TestClaims>(&[], &token).is_none());
    }
}