1use super::{AuthError, AuthenticatedPrincipal};
28use jsonwebtoken::jwk::{Jwk, JwkSet};
29use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, decode, decode_header};
30use serde_json::Value;
31use solo_core::TenantId;
32use std::collections::HashMap;
33use std::sync::Arc;
34use std::time::{Duration, Instant};
35use tokio::sync::RwLock;
36
/// Static configuration for an [`OidcValidator`].
#[derive(Debug, Clone)]
pub struct OidcConfig {
    /// URL of the provider's discovery document (in the tests, a URL ending in
    /// `/.well-known/openid-configuration`). Its `jwks_uri` field is followed
    /// to fetch the signing keys.
    pub discovery_url: String,
    /// Audience every accepted token must carry in its `aud` claim.
    pub audience: String,
    /// Name of the string claim holding the tenant identifier
    /// (the tests use `solo_tenant`).
    pub tenant_claim_name: String,
}
45
/// One snapshot of the provider's signing keys and how long it stays fresh.
struct CachedJwks {
    /// Usable verification keys, indexed by the JWK `kid`.
    keys: HashMap<String, KeyEntry>,
    /// When this snapshot was fetched; compared against `ttl` to detect staleness.
    fetched_at: Instant,
    /// Freshness window, derived from the provider's `Cache-Control: max-age`
    /// and falling back to one hour when no such directive is present.
    ttl: Duration,
}
54
/// A verification key paired with the single algorithm it may be used with.
struct KeyEntry {
    /// Key material decoded from the JWK.
    key: DecodingKey,
    /// Algorithm declared by the JWK's `alg`. Tokens whose header `alg`
    /// differs are rejected, so a token cannot pick its own algorithm.
    algorithm: Algorithm,
}
65
/// Validates OIDC bearer tokens against the provider's published JWKS.
///
/// Signing keys are discovered lazily through the configured discovery
/// document and cached. Clones share the same `Arc`-wrapped key cache.
#[derive(Clone)]
pub struct OidcValidator {
    config: OidcConfig,
    // HTTP client used for both the discovery and JWKS requests.
    http_client: reqwest::Client,
    // Shared key cache; `None` until the first successful fetch.
    jwks_cache: Arc<RwLock<Option<CachedJwks>>>,
}
74
75impl std::fmt::Debug for OidcValidator {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("OidcValidator")
78 .field("config", &self.config)
79 .field("jwks_cache", &"<RwLock>")
80 .finish()
81 }
82}
83
84impl OidcValidator {
85 pub fn new(config: OidcConfig) -> Self {
86 let http_client = reqwest::Client::builder()
87 .timeout(Duration::from_secs(10))
88 .build()
89 .expect("reqwest client");
90 Self {
91 config,
92 http_client,
93 jwks_cache: Arc::new(RwLock::new(None)),
94 }
95 }
96
97 #[cfg(test)]
101 pub fn with_http_client(config: OidcConfig, http_client: reqwest::Client) -> Self {
102 Self {
103 config,
104 http_client,
105 jwks_cache: Arc::new(RwLock::new(None)),
106 }
107 }
108
109 pub async fn validate(
111 &self,
112 header: Option<&str>,
113 ) -> Result<AuthenticatedPrincipal, AuthError> {
114 let header = header.ok_or(AuthError::MissingAuthHeader)?;
115 let token = header
116 .strip_prefix("Bearer ")
117 .ok_or(AuthError::MalformedAuthHeader)?;
118
119 let jwt_header = decode_header(token).map_err(|e| AuthError::InvalidOidcToken {
121 reason: format!("decode header: {e}"),
122 })?;
123 let kid = jwt_header
124 .kid
125 .clone()
126 .ok_or_else(|| AuthError::InvalidOidcToken {
127 reason: "missing kid in token header".to_string(),
128 })?;
129
130 let entry = self.get_key(&kid).await?;
132
133 if entry.algorithm != jwt_header.alg {
138 return Err(AuthError::InvalidOidcToken {
139 reason: format!(
140 "token alg {:?} does not match JWK alg {:?}",
141 jwt_header.alg, entry.algorithm
142 ),
143 });
144 }
145
146 let mut validation = Validation::new(entry.algorithm);
147 validation.set_audience(&[&self.config.audience]);
148 let token_data: TokenData<Value> =
149 decode(token, &entry.key, &validation).map_err(|e| AuthError::InvalidOidcToken {
150 reason: format!("{e}"),
151 })?;
152
153 let subject = token_data
154 .claims
155 .get("sub")
156 .and_then(|v| v.as_str())
157 .unwrap_or("unknown")
158 .to_string();
159
160 let tenant_claim_str = token_data
161 .claims
162 .get(&self.config.tenant_claim_name)
163 .and_then(|v| v.as_str())
164 .ok_or_else(|| AuthError::MissingTenantClaim {
165 claim_name: self.config.tenant_claim_name.clone(),
166 })?;
167 let tenant_claim = TenantId::new(tenant_claim_str.to_string())?;
168
169 let scopes = token_data
170 .claims
171 .get("scope")
172 .and_then(|v| v.as_str())
173 .map(|s| s.split_whitespace().map(String::from).collect())
174 .unwrap_or_default();
175
176 Ok(AuthenticatedPrincipal {
177 subject,
178 tenant_claim: Some(tenant_claim),
179 scopes,
180 claims: token_data.claims,
181 })
182 }
183
184 async fn get_key(&self, kid: &str) -> Result<KeyEntry, AuthError> {
188 {
190 let cache = self.jwks_cache.read().await;
191 if let Some(c) = cache.as_ref()
192 && c.fetched_at.elapsed() < c.ttl
193 && let Some(entry) = c.keys.get(kid)
194 {
195 return Ok(KeyEntry {
196 key: entry.key.clone(),
197 algorithm: entry.algorithm,
198 });
199 }
200 }
202
203 self.refresh_cache().await?;
205 let cache = self.jwks_cache.read().await;
206 cache
207 .as_ref()
208 .and_then(|c| c.keys.get(kid))
209 .map(|entry| KeyEntry {
210 key: entry.key.clone(),
211 algorithm: entry.algorithm,
212 })
213 .ok_or_else(|| AuthError::Jwks(format!("kid '{kid}' not found in JWKS")))
214 }
215
216 async fn refresh_cache(&self) -> Result<(), AuthError> {
220 let discovery_resp = self
221 .http_client
222 .get(&self.config.discovery_url)
223 .send()
224 .await
225 .map_err(|e| AuthError::Discovery(format!("{e}")))?
226 .error_for_status()
227 .map_err(|e| AuthError::Discovery(format!("{e}")))?;
228
229 let ttl = parse_max_age(
232 discovery_resp
233 .headers()
234 .get("cache-control")
235 .and_then(|h| h.to_str().ok()),
236 )
237 .unwrap_or(Duration::from_secs(3600));
238
239 let body: Value = discovery_resp
240 .json()
241 .await
242 .map_err(|e| AuthError::Discovery(format!("{e}")))?;
243 let jwks_uri = body
244 .get("jwks_uri")
245 .and_then(|v| v.as_str())
246 .ok_or_else(|| AuthError::Discovery("discovery missing jwks_uri".to_string()))?;
247
248 let jwks: JwkSet = self
249 .http_client
250 .get(jwks_uri)
251 .send()
252 .await
253 .map_err(|e| AuthError::Jwks(format!("{e}")))?
254 .error_for_status()
255 .map_err(|e| AuthError::Jwks(format!("{e}")))?
256 .json()
257 .await
258 .map_err(|e| AuthError::Jwks(format!("{e}")))?;
259
260 let mut keys = HashMap::new();
261 for jwk in jwks.keys.iter() {
262 let Some(kid) = jwk.common.key_id.as_deref() else {
263 continue;
264 };
265 let Some(algorithm) = jwk_algorithm(jwk) else {
266 continue;
267 };
268 let key = match DecodingKey::from_jwk(jwk) {
269 Ok(k) => k,
270 Err(_) => continue,
271 };
272 keys.insert(kid.to_string(), KeyEntry { key, algorithm });
273 }
274
275 let mut cache = self.jwks_cache.write().await;
276 *cache = Some(CachedJwks {
277 keys,
278 fetched_at: Instant::now(),
279 ttl,
280 });
281 Ok(())
282 }
283}
284
285fn jwk_algorithm(jwk: &Jwk) -> Option<Algorithm> {
290 use jsonwebtoken::jwk::KeyAlgorithm;
291 match jwk.common.key_algorithm? {
292 KeyAlgorithm::HS256 => Some(Algorithm::HS256),
293 KeyAlgorithm::HS384 => Some(Algorithm::HS384),
294 KeyAlgorithm::HS512 => Some(Algorithm::HS512),
295 KeyAlgorithm::RS256 => Some(Algorithm::RS256),
296 KeyAlgorithm::RS384 => Some(Algorithm::RS384),
297 KeyAlgorithm::RS512 => Some(Algorithm::RS512),
298 KeyAlgorithm::PS256 => Some(Algorithm::PS256),
299 KeyAlgorithm::PS384 => Some(Algorithm::PS384),
300 KeyAlgorithm::PS512 => Some(Algorithm::PS512),
301 KeyAlgorithm::ES256 => Some(Algorithm::ES256),
302 KeyAlgorithm::ES384 => Some(Algorithm::ES384),
303 KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
304 _ => None,
306 }
307}
308
/// Extracts the `max-age` directive from a `Cache-Control` header value.
///
/// Returns `None` when the header is absent or has no usable `max-age`
/// directive, so the caller can fall back to its own default lifetime.
///
/// Parsing follows RFC 9111 where it matters for a client-side cache:
/// - Directive names are matched case-insensitively (§5.2).
/// - The argument may be a bare token (`max-age=300`) or a quoted string
///   (`max-age="300"`).
/// - The value must be all ASCII digits (`delta-seconds`, §1.2.2).
/// - A value too large to represent is clamped to 2^31 seconds (§1.2.2)
///   rather than discarded.
///
/// Invalid `max-age` directives are skipped; the first valid one wins.
fn parse_max_age(header: Option<&str>) -> Option<Duration> {
    // Value RFC 9111 §1.2.2 says to assume when delta-seconds overflows.
    const OVERFLOW_DELTA_SECS: u64 = 1 << 31;

    header?
        .split(',')
        .filter_map(|directive| {
            let (name, value) = directive.split_once('=')?;
            if !name.trim().eq_ignore_ascii_case("max-age") {
                return None;
            }
            let value = value.trim();
            // Recipients are asked to accept the quoted-string form as well.
            let value = value
                .strip_prefix('"')
                .and_then(|v| v.strip_suffix('"'))
                .unwrap_or(value);
            // delta-seconds is 1*DIGIT: reject signs, blanks and other junk.
            if value.is_empty() || !value.bytes().all(|b| b.is_ascii_digit()) {
                return None;
            }
            // Only overflow can fail here, since the value is all digits.
            Some(value.parse::<u64>().unwrap_or(OVERFLOW_DELTA_SECS))
        })
        .next()
        .map(Duration::from_secs)
}
322
// Tests that run the validator against an in-process fake identity provider
// (wiremock) serving a discovery document and an HS256 JWKS.
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header};
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Fake OpenID provider backed by a `MockServer`.
    ///
    /// Serves discovery at `/.well-known/openid-configuration` and a one-key
    /// HS256 JWKS at `/jwks`. It can also mint tokens signed with the
    /// currently published key.
    struct FakeIdp {
        /// Wiremock server hosting both endpoints.
        server: MockServer,
        /// HMAC secret of the currently published key.
        signing_secret: Vec<u8>,
        /// `kid` of the currently published key.
        signing_kid: String,
    }

    impl FakeIdp {
        /// Starts the server and mounts the discovery and JWKS responses.
        ///
        /// When `cache_max_age_secs` is `Some`, the discovery response carries
        /// `Cache-Control: max-age=<secs>`.
        async fn start(signing_kid: &str, cache_max_age_secs: Option<u64>) -> Self {
            let server = MockServer::start().await;
            let secret = b"fixture-secret-bytes-for-hmac-tests".to_vec();
            let kid = signing_kid.to_string();

            let discovery_body = json!({
                "issuer": server.uri(),
                "jwks_uri": format!("{}/jwks", server.uri()),
            });
            let mut discovery_resp = ResponseTemplate::new(200).set_body_json(discovery_body);
            if let Some(secs) = cache_max_age_secs {
                discovery_resp = discovery_resp
                    .insert_header("cache-control", format!("max-age={secs}").as_str());
            }
            Mock::given(method("GET"))
                .and(path("/.well-known/openid-configuration"))
                .respond_with(discovery_resp)
                .mount(&server)
                .await;

            let jwks_body = json!({
                "keys": [
                    {
                        "kty": "oct",
                        "kid": &kid,
                        "alg": "HS256",
                        "k": base64_url(&secret),
                    }
                ]
            });
            Mock::given(method("GET"))
                .and(path("/jwks"))
                .respond_with(ResponseTemplate::new(200).set_body_json(jwks_body))
                .mount(&server)
                .await;

            Self {
                server,
                signing_secret: secret,
                signing_kid: kid,
            }
        }

        /// Simulates key rotation.
        ///
        /// Clears every mounted mock, re-mounts discovery, and serves a JWKS
        /// containing only `new_kid`. Later `mint` calls sign with the new key.
        async fn rotate_to(&mut self, new_kid: &str, new_secret: &[u8]) {
            let jwks_body = json!({
                "keys": [
                    {
                        "kty": "oct",
                        "kid": new_kid,
                        "alg": "HS256",
                        "k": base64_url(new_secret),
                    }
                ]
            });
            // `reset` drops the discovery mock too, so it is re-mounted below.
            self.server.reset().await;
            let discovery_body = json!({
                "issuer": self.server.uri(),
                "jwks_uri": format!("{}/jwks", self.server.uri()),
            });
            Mock::given(method("GET"))
                .and(path("/.well-known/openid-configuration"))
                .respond_with(ResponseTemplate::new(200).set_body_json(discovery_body))
                .mount(&self.server)
                .await;
            Mock::given(method("GET"))
                .and(path("/jwks"))
                .respond_with(ResponseTemplate::new(200).set_body_json(jwks_body))
                .mount(&self.server)
                .await;
            self.signing_secret = new_secret.to_vec();
            self.signing_kid = new_kid.to_string();
        }

        /// Mints a token with the current `kid` and secret (see `mint_with_kid`).
        fn mint(&self, claims_override: Value) -> String {
            self.mint_with_kid(&self.signing_kid, &self.signing_secret, claims_override)
        }

        /// Signs an HS256 token whose header carries `kid`.
        ///
        /// The default claims are:
        /// - issuer: the server URI;
        /// - subject: `test-subject`;
        /// - audience: `test-audience`;
        /// - expiry: 600 seconds from now;
        /// - tenant: `default`.
        ///
        /// `claims_override` replaces these key by key. It only has an effect
        /// when it is a JSON object.
        fn mint_with_kid(&self, kid: &str, secret: &[u8], claims_override: Value) -> String {
            let mut header = Header::new(Algorithm::HS256);
            header.kid = Some(kid.to_string());
            let now = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs();
            let mut claims = json!({
                "iss": self.server.uri(),
                "sub": "test-subject",
                "aud": "test-audience",
                "exp": now + 600,
                "iat": now,
                "solo_tenant": "default",
            });
            if let (Value::Object(c), Value::Object(o)) = (&mut claims, claims_override) {
                for (k, v) in o {
                    c.insert(k, v);
                }
            }
            jsonwebtoken::encode(&header, &claims, &EncodingKey::from_secret(secret))
                .expect("encode")
        }
    }

    /// Encodes bytes as unpadded URL-safe base64, as used for the JWK `k` fields here.
    fn base64_url(bytes: &[u8]) -> String {
        use base64::Engine;
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
    }

    /// Builds a validator pointed at `server_uri`'s discovery endpoint.
    ///
    /// Uses the audience and tenant claim name the fake IdP mints, and a
    /// 2-second HTTP timeout.
    fn make_validator(server_uri: &str) -> OidcValidator {
        OidcValidator::with_http_client(
            OidcConfig {
                discovery_url: format!("{server_uri}/.well-known/openid-configuration"),
                audience: "test-audience".to_string(),
                tenant_claim_name: "solo_tenant".to_string(),
            },
            reqwest::Client::builder()
                .timeout(Duration::from_secs(2))
                .build()
                .unwrap(),
        )
    }

    /// A well-formed token yields its subject and parsed tenant.
    #[tokio::test]
    async fn oidc_happy_path() {
        let idp = FakeIdp::start("test-kid-1", None).await;
        let validator = make_validator(&idp.server.uri());

        let token = idp.mint(json!({ "solo_tenant": "tenant-a" }));
        let principal = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("validate");
        assert_eq!(principal.subject, "test-subject");
        assert_eq!(
            principal.tenant_claim,
            Some(TenantId::new("tenant-a").unwrap())
        );
    }

    /// A kid the warm cache has never seen forces a refetch that picks up the rotated key.
    #[tokio::test]
    async fn oidc_key_rotation() {
        let mut idp = FakeIdp::start("old-kid", None).await;
        let validator = make_validator(&idp.server.uri());

        let warmup_token = idp.mint(json!({}));
        // Populate the cache with the old key before rotating.
        let _ = validator
            .validate(Some(&format!("Bearer {warmup_token}")))
            .await
            .expect("warmup");

        let new_secret = b"new-rotated-secret-32-bytes--here".to_vec();
        // After rotation only "new-kid" is published, and tokens are signed with it.
        idp.rotate_to("new-kid", &new_secret).await;
        let token = idp.mint(json!({}));

        // The cache is still within its TTL, so success here proves the
        // unknown kid triggered a refresh.
        let principal = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("post-rotation");
        assert_eq!(principal.subject, "test-subject");
    }

    /// A token minted for another audience is rejected as an invalid token.
    #[tokio::test]
    async fn oidc_invalid_audience() {
        let idp = FakeIdp::start("kid-aud", None).await;
        let validator = make_validator(&idp.server.uri());

        let token = idp.mint(json!({ "aud": "wrong-audience" }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidOidcToken { .. }),
            "got {err:?}"
        );
    }

    /// A token that expired five minutes ago is rejected as an invalid token.
    #[tokio::test]
    async fn oidc_expired_token() {
        let idp = FakeIdp::start("kid-exp", None).await;
        let validator = make_validator(&idp.server.uri());

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let token = idp.mint(json!({ "exp": now - 300, "iat": now - 600 }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidOidcToken { .. }),
            "got {err:?}"
        );
    }

    /// A null tenant claim is reported as missing, naming the configured claim.
    #[tokio::test]
    async fn oidc_missing_tenant_claim() {
        let idp = FakeIdp::start("kid-no-tenant", None).await;
        let validator = make_validator(&idp.server.uri());

        let token = idp.mint(json!({ "solo_tenant": null }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::MissingTenantClaim { ref claim_name } if claim_name == "solo_tenant"),
            "got {err:?}"
        );
    }

    /// A tenant string that `TenantId::new` rejects surfaces as `InvalidTenantClaim`.
    #[tokio::test]
    async fn oidc_invalid_tenant_claim_format() {
        let idp = FakeIdp::start("kid-bad-tenant", None).await;
        let validator = make_validator(&idp.server.uri());

        let token = idp.mint(json!({ "solo_tenant": "INVALID UPPERCASE" }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidTenantClaim(_)),
            "got {err:?}"
        );
    }

    /// Two validations within the default TTL hit each endpoint only once.
    #[tokio::test]
    async fn oidc_jwks_cache_within_ttl_no_refetch() {
        let server = MockServer::start().await;
        let secret = b"counted-secret-32-bytes--padding".to_vec();
        let kid = "counted-kid";

        let discovery_body = json!({
            "issuer": server.uri(),
            "jwks_uri": format!("{}/jwks", server.uri()),
        });
        // `expect(1)`: wiremock verifies the call count when the server is dropped.
        Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(discovery_body))
            .expect(1)
            .mount(&server)
            .await;
        let jwks_body = json!({
            "keys": [
                {
                    "kty": "oct",
                    "kid": kid,
                    "alg": "HS256",
                    "k": base64_url(&secret),
                }
            ]
        });
        Mock::given(method("GET"))
            .and(path("/jwks"))
            .respond_with(ResponseTemplate::new(200).set_body_json(jwks_body))
            .expect(1)
            .mount(&server)
            .await;

        let validator = make_validator(&server.uri());

        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(kid.to_string());
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let claims = json!({
            "iss": server.uri(),
            "sub": "subj",
            "aud": "test-audience",
            "exp": now + 600,
            "iat": now,
            "solo_tenant": "default",
        });
        let token =
            jsonwebtoken::encode(&header, &claims, &EncodingKey::from_secret(&secret)).unwrap();

        let _ = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("first");
        let _ = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("second");
    }

    /// With `max-age=0`, every validation refetches both documents.
    #[tokio::test]
    async fn oidc_jwks_cache_respects_cache_control_max_age() {
        let server = MockServer::start().await;
        let secret = b"max-age-secret-bytes-for-tests--".to_vec();
        let kid = "max-age-kid";

        let discovery_body = json!({
            "issuer": server.uri(),
            "jwks_uri": format!("{}/jwks", server.uri()),
        });
        // A zero max-age makes the cache stale immediately; expect one fetch per validation.
        Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("cache-control", "max-age=0")
                    .set_body_json(discovery_body),
            )
            .expect(2)
            .mount(&server)
            .await;
        let jwks_body = json!({
            "keys": [
                {
                    "kty": "oct",
                    "kid": kid,
                    "alg": "HS256",
                    "k": base64_url(&secret),
                }
            ]
        });
        Mock::given(method("GET"))
            .and(path("/jwks"))
            .respond_with(ResponseTemplate::new(200).set_body_json(jwks_body))
            .expect(2)
            .mount(&server)
            .await;

        let validator = make_validator(&server.uri());

        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(kid.to_string());
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let claims = json!({
            "iss": server.uri(),
            "sub": "subj",
            "aud": "test-audience",
            "exp": now + 600,
            "iat": now,
            "solo_tenant": "default",
        });
        let token =
            jsonwebtoken::encode(&header, &claims, &EncodingKey::from_secret(&secret)).unwrap();

        let _ = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("first");
        let _ = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("second");
    }

    /// `parse_max_age` handles a bare directive, a directive among others, and absence.
    #[test]
    fn parse_max_age_handles_typical_headers() {
        assert_eq!(
            parse_max_age(Some("max-age=300")),
            Some(Duration::from_secs(300))
        );
        assert_eq!(
            parse_max_age(Some("public, max-age=86400, must-revalidate")),
            Some(Duration::from_secs(86400))
        );
        assert_eq!(parse_max_age(Some("no-cache, no-store")), None);
        assert_eq!(parse_max_age(None), None);
    }
}