1use super::{AuthError, AuthenticatedPrincipal};
28use jsonwebtoken::jwk::{Jwk, JwkSet};
29use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, decode, decode_header};
30use serde_json::Value;
31use solo_core::TenantId;
32use std::collections::HashMap;
33use std::sync::Arc;
34use std::time::{Duration, Instant};
35use tokio::sync::RwLock;
36
/// Settings describing which OpenID Connect provider to trust and how to read
/// its tokens.
#[derive(Debug, Clone)]
pub struct OidcConfig {
    /// Location of the provider's discovery document (for example
    /// `https://idp.example/.well-known/openid-configuration`). Its `jwks_uri`
    /// entry is followed to obtain the signing keys.
    pub discovery_url: String,
    /// Audience value that every accepted token must carry in its `aud` claim.
    pub audience: String,
    /// Name of the string claim holding the caller's tenant identifier.
    pub tenant_claim_name: String,
}
45
46struct CachedJwks {
50 keys: HashMap<String, KeyEntry>,
51 fetched_at: Instant,
52 ttl: Duration,
53}
54
55struct KeyEntry {
62 key: DecodingKey,
63 algorithm: Algorithm,
64}
65
66#[derive(Clone)]
69pub struct OidcValidator {
70 config: OidcConfig,
71 http_client: reqwest::Client,
72 jwks_cache: Arc<RwLock<Option<CachedJwks>>>,
73}
74
75impl std::fmt::Debug for OidcValidator {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("OidcValidator")
78 .field("config", &self.config)
79 .field("jwks_cache", &"<RwLock>")
80 .finish()
81 }
82}
83
84impl OidcValidator {
85 pub fn new(config: OidcConfig) -> Self {
86 let http_client = reqwest::Client::builder()
87 .timeout(Duration::from_secs(10))
88 .build()
89 .expect("reqwest client");
90 Self {
91 config,
92 http_client,
93 jwks_cache: Arc::new(RwLock::new(None)),
94 }
95 }
96
97 #[cfg(test)]
101 pub fn with_http_client(config: OidcConfig, http_client: reqwest::Client) -> Self {
102 Self {
103 config,
104 http_client,
105 jwks_cache: Arc::new(RwLock::new(None)),
106 }
107 }
108
109 pub async fn validate(
111 &self,
112 header: Option<&str>,
113 ) -> Result<AuthenticatedPrincipal, AuthError> {
114 let header = header.ok_or(AuthError::MissingAuthHeader)?;
115 let token = header
116 .strip_prefix("Bearer ")
117 .ok_or(AuthError::MalformedAuthHeader)?;
118
119 let jwt_header = decode_header(token).map_err(|e| AuthError::InvalidOidcToken {
121 reason: format!("decode header: {e}"),
122 })?;
123 let kid = jwt_header
124 .kid
125 .clone()
126 .ok_or_else(|| AuthError::InvalidOidcToken {
127 reason: "missing kid in token header".to_string(),
128 })?;
129
130 let entry = self.get_key(&kid).await?;
132
133 if entry.algorithm != jwt_header.alg {
138 return Err(AuthError::InvalidOidcToken {
139 reason: format!(
140 "token alg {:?} does not match JWK alg {:?}",
141 jwt_header.alg, entry.algorithm
142 ),
143 });
144 }
145
146 let mut validation = Validation::new(entry.algorithm);
147 validation.set_audience(&[&self.config.audience]);
148 let token_data: TokenData<Value> =
149 decode(token, &entry.key, &validation).map_err(|e| AuthError::InvalidOidcToken {
150 reason: format!("{e}"),
151 })?;
152
153 let subject = token_data
154 .claims
155 .get("sub")
156 .and_then(|v| v.as_str())
157 .unwrap_or("unknown")
158 .to_string();
159
160 let tenant_claim_str = token_data
161 .claims
162 .get(&self.config.tenant_claim_name)
163 .and_then(|v| v.as_str())
164 .ok_or_else(|| AuthError::MissingTenantClaim {
165 claim_name: self.config.tenant_claim_name.clone(),
166 })?;
167 let tenant_claim = TenantId::new(tenant_claim_str.to_string())?;
168
169 let scopes = token_data
170 .claims
171 .get("scope")
172 .and_then(|v| v.as_str())
173 .map(|s| s.split_whitespace().map(String::from).collect())
174 .unwrap_or_default();
175
176 Ok(AuthenticatedPrincipal {
177 subject,
178 tenant_claim: Some(tenant_claim),
179 scopes,
180 claims: token_data.claims,
181 })
182 }
183
184 async fn get_key(&self, kid: &str) -> Result<KeyEntry, AuthError> {
188 {
190 let cache = self.jwks_cache.read().await;
191 if let Some(c) = cache.as_ref()
192 && c.fetched_at.elapsed() < c.ttl
193 && let Some(entry) = c.keys.get(kid)
194 {
195 return Ok(KeyEntry {
196 key: entry.key.clone(),
197 algorithm: entry.algorithm,
198 });
199 }
200 }
202
203 self.refresh_cache().await?;
205 let cache = self.jwks_cache.read().await;
206 cache
207 .as_ref()
208 .and_then(|c| c.keys.get(kid))
209 .map(|entry| KeyEntry {
210 key: entry.key.clone(),
211 algorithm: entry.algorithm,
212 })
213 .ok_or_else(|| AuthError::Jwks(format!("kid '{kid}' not found in JWKS")))
214 }
215
216 async fn refresh_cache(&self) -> Result<(), AuthError> {
220 let discovery_resp = self
221 .http_client
222 .get(&self.config.discovery_url)
223 .send()
224 .await
225 .map_err(|e| AuthError::Discovery(format!("{e}")))?
226 .error_for_status()
227 .map_err(|e| AuthError::Discovery(format!("{e}")))?;
228
229 let ttl = parse_max_age(
232 discovery_resp
233 .headers()
234 .get("cache-control")
235 .and_then(|h| h.to_str().ok()),
236 )
237 .unwrap_or(Duration::from_secs(3600));
238
239 let body: Value = discovery_resp
240 .json()
241 .await
242 .map_err(|e| AuthError::Discovery(format!("{e}")))?;
243 let jwks_uri = body
244 .get("jwks_uri")
245 .and_then(|v| v.as_str())
246 .ok_or_else(|| AuthError::Discovery("discovery missing jwks_uri".to_string()))?;
247
248 let jwks: JwkSet = self
249 .http_client
250 .get(jwks_uri)
251 .send()
252 .await
253 .map_err(|e| AuthError::Jwks(format!("{e}")))?
254 .error_for_status()
255 .map_err(|e| AuthError::Jwks(format!("{e}")))?
256 .json()
257 .await
258 .map_err(|e| AuthError::Jwks(format!("{e}")))?;
259
260 let mut keys = HashMap::new();
261 for jwk in jwks.keys.iter() {
262 let Some(kid) = jwk.common.key_id.as_deref() else {
263 continue;
264 };
265 let Some(algorithm) = jwk_algorithm(jwk) else {
266 continue;
267 };
268 let key = match DecodingKey::from_jwk(jwk) {
269 Ok(k) => k,
270 Err(_) => continue,
271 };
272 keys.insert(
273 kid.to_string(),
274 KeyEntry {
275 key,
276 algorithm,
277 },
278 );
279 }
280
281 let mut cache = self.jwks_cache.write().await;
282 *cache = Some(CachedJwks {
283 keys,
284 fetched_at: Instant::now(),
285 ttl,
286 });
287 Ok(())
288 }
289}
290
291fn jwk_algorithm(jwk: &Jwk) -> Option<Algorithm> {
296 use jsonwebtoken::jwk::KeyAlgorithm;
297 match jwk.common.key_algorithm? {
298 KeyAlgorithm::HS256 => Some(Algorithm::HS256),
299 KeyAlgorithm::HS384 => Some(Algorithm::HS384),
300 KeyAlgorithm::HS512 => Some(Algorithm::HS512),
301 KeyAlgorithm::RS256 => Some(Algorithm::RS256),
302 KeyAlgorithm::RS384 => Some(Algorithm::RS384),
303 KeyAlgorithm::RS512 => Some(Algorithm::RS512),
304 KeyAlgorithm::PS256 => Some(Algorithm::PS256),
305 KeyAlgorithm::PS384 => Some(Algorithm::PS384),
306 KeyAlgorithm::PS512 => Some(Algorithm::PS512),
307 KeyAlgorithm::ES256 => Some(Algorithm::ES256),
308 KeyAlgorithm::ES384 => Some(Algorithm::ES384),
309 KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
310 _ => None,
312 }
313}
314
/// Extracts the `max-age` directive from a `Cache-Control` header value.
///
/// Directive names are compared case-insensitively. A quoted value such as
/// `max-age="300"` is also accepted, since RFC 9111 §5.2 asks recipients to
/// tolerate that form. When several `max-age` directives appear, the first one
/// whose value parses wins. `s-maxage` is a different directive and is not
/// consulted.
///
/// Returns `None` when the header is absent or contains no parseable `max-age`.
fn parse_max_age(header: Option<&str>) -> Option<Duration> {
    header?
        .split(',')
        .filter_map(|directive| {
            let (name, value) = directive.trim().split_once('=')?;
            if !name.trim().eq_ignore_ascii_case("max-age") {
                return None;
            }
            let value = value.trim();
            let unquoted = value
                .strip_prefix('"')
                .and_then(|v| v.strip_suffix('"'))
                .unwrap_or(value);
            unquoted.parse::<u64>().ok()
        })
        .next()
        .map(Duration::from_secs)
}
328
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header};
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path at which the fake provider serves its discovery document.
    const DISCOVERY_PATH: &str = "/.well-known/openid-configuration";
    /// Path at which the fake provider serves its JWKS.
    const JWKS_PATH: &str = "/jwks";

    /// Minimal in-process OIDC provider.
    ///
    /// It serves a discovery document and a one-key HS256 JWKS, and mints tokens
    /// signed with that key.
    struct FakeIdp {
        server: MockServer,
        signing_secret: Vec<u8>,
        signing_kid: String,
    }

    impl FakeIdp {
        /// Starts the provider publishing the single key `signing_kid`.
        ///
        /// When `cache_max_age_secs` is set, the discovery response carries
        /// `Cache-Control: max-age=<secs>`.
        async fn start(signing_kid: &str, cache_max_age_secs: Option<u64>) -> Self {
            let server = MockServer::start().await;
            let secret = b"fixture-secret-bytes-for-hmac-tests".to_vec();
            mount_discovery(&server, cache_max_age_secs, None).await;
            mount_jwks(&server, signing_kid, &secret, None).await;
            Self {
                server,
                signing_secret: secret,
                signing_kid: signing_kid.to_string(),
            }
        }

        /// Replaces the published JWKS with a single new key and signs subsequent
        /// [`FakeIdp::mint`] tokens with it.
        ///
        /// The discovery document is re-mounted without a `Cache-Control` header.
        async fn rotate_to(&mut self, new_kid: &str, new_secret: &[u8]) {
            self.server.reset().await;
            mount_discovery(&self.server, None, None).await;
            mount_jwks(&self.server, new_kid, new_secret, None).await;
            self.signing_secret = new_secret.to_vec();
            self.signing_kid = new_kid.to_string();
        }

        /// Mints a token with the current signing key; see [`mint_token`].
        fn mint(&self, claims_override: Value) -> String {
            self.mint_with_kid(&self.signing_kid, &self.signing_secret, claims_override)
        }

        /// Mints a token whose header names `kid` and whose signature uses `secret`.
        ///
        /// Neither value needs to match what the provider publishes.
        fn mint_with_kid(&self, kid: &str, secret: &[u8], claims_override: Value) -> String {
            mint_token(&self.server.uri(), kid, secret, claims_override)
        }
    }

    /// Mounts a discovery document whose `jwks_uri` points at this server.
    ///
    /// `max_age_secs` adds a `Cache-Control: max-age` header. `expected_hits`
    /// makes wiremock verify the request count when the server is dropped.
    async fn mount_discovery(
        server: &MockServer,
        max_age_secs: Option<u64>,
        expected_hits: Option<u64>,
    ) {
        let body = json!({
            "issuer": server.uri(),
            "jwks_uri": format!("{}{}", server.uri(), JWKS_PATH),
        });
        let mut response = ResponseTemplate::new(200).set_body_json(body);
        if let Some(secs) = max_age_secs {
            response = response.insert_header("cache-control", format!("max-age={secs}").as_str());
        }
        let mut mock = Mock::given(method("GET"))
            .and(path(DISCOVERY_PATH))
            .respond_with(response);
        if let Some(hits) = expected_hits {
            mock = mock.expect(hits);
        }
        mock.mount(server).await;
    }

    /// Mounts a JWKS containing one `oct` HS256 key named `kid`.
    ///
    /// `expected_hits` works as in [`mount_discovery`].
    async fn mount_jwks(
        server: &MockServer,
        kid: &str,
        secret: &[u8],
        expected_hits: Option<u64>,
    ) {
        let body = json!({
            "keys": [
                {
                    "kty": "oct",
                    "kid": kid,
                    "alg": "HS256",
                    "k": base64_url(secret),
                }
            ]
        });
        let mut mock = Mock::given(method("GET"))
            .and(path(JWKS_PATH))
            .respond_with(ResponseTemplate::new(200).set_body_json(body));
        if let Some(hits) = expected_hits {
            mock = mock.expect(hits);
        }
        mock.mount(server).await;
    }

    /// Current time as whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds an HS256 token with a valid default claim set.
    ///
    /// The defaults are:
    /// - `iss` = `issuer`
    /// - `sub` = "test-subject"
    /// - `aud` = "test-audience"
    /// - a 10-minute `exp`
    /// - `solo_tenant` = "default"
    ///
    /// Top-level keys in `claims_override` replace or add to these defaults.
    fn mint_token(issuer: &str, kid: &str, secret: &[u8], claims_override: Value) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(kid.to_string());
        let now = unix_now();
        let mut claims = json!({
            "iss": issuer,
            "sub": "test-subject",
            "aud": "test-audience",
            "exp": now + 600,
            "iat": now,
            "solo_tenant": "default",
        });
        if let (Value::Object(base), Value::Object(overrides)) = (&mut claims, claims_override) {
            base.extend(overrides);
        }
        jsonwebtoken::encode(&header, &claims, &EncodingKey::from_secret(secret))
            .expect("encode")
    }

    /// Unpadded base64url, as the JWK `k` member requires.
    fn base64_url(bytes: &[u8]) -> String {
        use base64::Engine;
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
    }

    /// Validator pointed at `server_uri` with the audience and tenant claim the
    /// minted tokens use.
    fn make_validator(server_uri: &str) -> OidcValidator {
        OidcValidator::with_http_client(
            OidcConfig {
                discovery_url: format!("{server_uri}{DISCOVERY_PATH}"),
                audience: "test-audience".to_string(),
                tenant_claim_name: "solo_tenant".to_string(),
            },
            reqwest::Client::builder()
                .timeout(Duration::from_secs(2))
                .build()
                .unwrap(),
        )
    }

    #[tokio::test]
    async fn oidc_happy_path() {
        let idp = FakeIdp::start("test-kid-1", None).await;
        let validator = make_validator(&idp.server.uri());

        let token = idp.mint(json!({ "solo_tenant": "tenant-a" }));
        let principal = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("validate");
        assert_eq!(principal.subject, "test-subject");
        assert_eq!(
            principal.tenant_claim,
            Some(TenantId::new("tenant-a").unwrap())
        );
    }

    #[tokio::test]
    async fn oidc_key_rotation() {
        let mut idp = FakeIdp::start("old-kid", None).await;
        let validator = make_validator(&idp.server.uri());

        // Warm the cache with "old-kid" under the default one-hour TTL, so the
        // rotated key can only be found through the refresh on an unknown kid.
        let warmup_token = idp.mint(json!({}));
        validator
            .validate(Some(&format!("Bearer {warmup_token}")))
            .await
            .expect("warmup");

        let new_secret = b"new-rotated-secret-32-bytes--here".to_vec();
        idp.rotate_to("new-kid", &new_secret).await;
        let token = idp.mint(json!({}));

        let principal = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("post-rotation");
        assert_eq!(principal.subject, "test-subject");
    }

    #[tokio::test]
    async fn oidc_invalid_audience() {
        let idp = FakeIdp::start("kid-aud", None).await;
        let validator = make_validator(&idp.server.uri());

        let token = idp.mint(json!({ "aud": "wrong-audience" }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidOidcToken { .. }),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn oidc_expired_token() {
        let idp = FakeIdp::start("kid-exp", None).await;
        let validator = make_validator(&idp.server.uri());

        // Expired well beyond jsonwebtoken's default leeway.
        let now = unix_now();
        let token = idp.mint(json!({ "exp": now - 300, "iat": now - 600 }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidOidcToken { .. }),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn oidc_missing_tenant_claim() {
        let idp = FakeIdp::start("kid-no-tenant", None).await;
        let validator = make_validator(&idp.server.uri());

        let token = idp.mint(json!({ "solo_tenant": null }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::MissingTenantClaim { ref claim_name } if claim_name == "solo_tenant"),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn oidc_invalid_tenant_claim_format() {
        let idp = FakeIdp::start("kid-bad-tenant", None).await;
        let validator = make_validator(&idp.server.uri());

        let token = idp.mint(json!({ "solo_tenant": "INVALID UPPERCASE" }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidTenantClaim(_)),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn oidc_missing_auth_header() {
        // Rejected before any network access, so the URL is never contacted.
        let validator = make_validator("http://127.0.0.1:9");
        let err = validator.validate(None).await.unwrap_err();
        assert!(matches!(err, AuthError::MissingAuthHeader), "got {err:?}");
    }

    #[tokio::test]
    async fn oidc_non_bearer_scheme_is_malformed() {
        let validator = make_validator("http://127.0.0.1:9");
        let err = validator
            .validate(Some("Basic dXNlcjpwYXNz"))
            .await
            .unwrap_err();
        assert!(matches!(err, AuthError::MalformedAuthHeader), "got {err:?}");
    }

    #[tokio::test]
    async fn oidc_unknown_kid_is_rejected() {
        let idp = FakeIdp::start("kid-known", None).await;
        let validator = make_validator(&idp.server.uri());

        let token = idp.mint_with_kid("kid-unknown", &idp.signing_secret, json!({}));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(matches!(err, AuthError::Jwks(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn oidc_bad_signature_is_rejected() {
        let idp = FakeIdp::start("kid-sig", None).await;
        let validator = make_validator(&idp.server.uri());

        let token = idp.mint_with_kid(
            &idp.signing_kid,
            b"some-other-secret-not-in-the-jwks",
            json!({}),
        );
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidOidcToken { .. }),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn oidc_jwks_cache_within_ttl_no_refetch() {
        // Each endpoint may be hit exactly once. The second validation must be
        // served from the cache (default one-hour TTL).
        let server = MockServer::start().await;
        let secret = b"counted-secret-32-bytes--padding".to_vec();
        let kid = "counted-kid";
        mount_discovery(&server, None, Some(1)).await;
        mount_jwks(&server, kid, &secret, Some(1)).await;

        let validator = make_validator(&server.uri());
        let bearer = format!(
            "Bearer {}",
            mint_token(&server.uri(), kid, &secret, json!({}))
        );

        validator.validate(Some(&bearer)).await.expect("first");
        validator.validate(Some(&bearer)).await.expect("second");
    }

    #[tokio::test]
    async fn oidc_jwks_cache_respects_cache_control_max_age() {
        // `max-age=0` makes every snapshot stale immediately, so each validation
        // refetches both documents.
        let server = MockServer::start().await;
        let secret = b"max-age-secret-bytes-for-tests--".to_vec();
        let kid = "max-age-kid";
        mount_discovery(&server, Some(0), Some(2)).await;
        mount_jwks(&server, kid, &secret, Some(2)).await;

        let validator = make_validator(&server.uri());
        let bearer = format!(
            "Bearer {}",
            mint_token(&server.uri(), kid, &secret, json!({}))
        );

        validator.validate(Some(&bearer)).await.expect("first");
        validator.validate(Some(&bearer)).await.expect("second");
    }

    #[test]
    fn parse_max_age_handles_typical_headers() {
        assert_eq!(
            parse_max_age(Some("max-age=300")),
            Some(Duration::from_secs(300))
        );
        assert_eq!(
            parse_max_age(Some("public, max-age=86400, must-revalidate")),
            Some(Duration::from_secs(86400))
        );
        assert_eq!(parse_max_age(Some("no-cache, no-store")), None);
        assert_eq!(parse_max_age(None), None);
    }
}