1use jsonwebtoken::{
26 decode, decode_header,
27 jwk::{JwkSet, KeyAlgorithm},
28 Algorithm, DecodingKey, TokenData, Validation,
29};
30use serde::{Deserialize, Serialize};
31use std::{
32 collections::HashMap,
33 time::{Duration, Instant},
34};
35use tokio::sync::RwLock;
36
37use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
38use sha2::{Digest, Sha256};
39
40fn contains_control_chars(s: &str) -> bool {
46 vellaveto_types::has_dangerous_chars(s)
47}
48
49use vellaveto_types::uri_util::normalize_dpop_htu;
53
54#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
56#[serde(rename_all = "snake_case")]
57pub enum DpopMode {
58 #[default]
60 Off,
61 Optional,
63 Required,
65}
66
67#[derive(Debug, Clone)]
69pub struct OAuthConfig {
70 pub issuer: String,
73
74 pub audience: String,
77
78 pub jwks_uri: Option<String>,
81
82 pub required_scopes: Vec<String>,
85
86 pub pass_through: bool,
89
90 pub allowed_algorithms: Vec<Algorithm>,
98
99 pub expected_resource: Option<String>,
103
104 pub clock_skew_leeway: Duration,
107
108 pub require_audience: bool,
111
112 pub dpop_mode: DpopMode,
114
115 pub dpop_allowed_algorithms: Vec<Algorithm>,
117
118 pub dpop_require_ath: bool,
120
121 pub dpop_max_clock_skew: Duration,
123}
124
125pub fn default_allowed_algorithms() -> Vec<Algorithm> {
131 vec![
132 Algorithm::RS256,
133 Algorithm::RS384,
134 Algorithm::RS512,
135 Algorithm::ES256,
136 Algorithm::ES384,
137 Algorithm::PS256,
138 Algorithm::PS384,
139 Algorithm::PS512,
140 Algorithm::EdDSA,
141 ]
142}
143
144pub fn default_dpop_allowed_algorithms() -> Vec<Algorithm> {
148 vec![Algorithm::ES256, Algorithm::EdDSA]
149}
150
151impl OAuthConfig {
152 pub fn effective_jwks_uri(&self) -> String {
154 self.jwks_uri.clone().unwrap_or_else(|| {
155 let base = self.issuer.trim_end_matches('/');
156 format!("{base}/.well-known/jwks.json")
157 })
158 }
159}
160
161pub fn extract_bearer_token(auth_header: &str) -> Result<&str, OAuthError> {
163 let token = if auth_header.len() > 7 && auth_header[..7].eq_ignore_ascii_case("bearer ") {
166 &auth_header[7..]
167 } else {
168 return Err(OAuthError::InvalidFormat);
169 };
170
171 if token.is_empty() {
172 return Err(OAuthError::InvalidFormat);
173 }
174
175 Ok(token)
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct OAuthClaims {
181 #[serde(default)]
183 pub sub: String,
184
185 #[serde(default)]
187 pub iss: String,
188
189 #[serde(default, deserialize_with = "deserialize_aud")]
191 pub aud: Vec<String>,
192
193 #[serde(default)]
195 pub exp: u64,
196
197 #[serde(default)]
199 pub iat: u64,
200
201 #[serde(default)]
203 pub scope: String,
204
205 #[serde(default)]
208 pub resource: Option<String>,
209
210 #[serde(default)]
213 pub cnf: Option<OAuthConfirmationClaim>,
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct OAuthConfirmationClaim {
219 #[serde(default)]
221 pub jkt: Option<String>,
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize)]
232struct DpopClaims {
233 #[serde(default)]
234 htm: String,
235 #[serde(default)]
236 htu: String,
237 #[serde(default)]
238 iat: u64,
239 #[serde(default)]
240 jti: String,
241 #[serde(default)]
242 ath: Option<String>,
243 #[serde(default)]
245 exp: Option<u64>,
246 #[serde(default)]
248 nbf: Option<u64>,
249 #[serde(flatten)]
252 _extra: serde_json::Map<String, serde_json::Value>,
253}
254
255impl OAuthClaims {
256 pub fn scopes(&self) -> Vec<&str> {
258 if self.scope.is_empty() {
259 Vec::new()
260 } else {
261 self.scope.split(' ').filter(|s| !s.is_empty()).collect()
262 }
263 }
264}
265
266fn deserialize_aud<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
268where
269 D: serde::Deserializer<'de>,
270{
271 use serde::de;
272
273 struct AudVisitor;
274
275 impl<'de> de::Visitor<'de> for AudVisitor {
276 type Value = Vec<String>;
277
278 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
279 f.write_str("a string or array of strings")
280 }
281
282 fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
283 Ok(vec![v.to_string()])
284 }
285
286 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
287 let mut values = Vec::new();
288 while let Some(v) = seq.next_element::<String>()? {
289 values.push(v);
290 }
291 Ok(values)
292 }
293
294 fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
295 Ok(Vec::new())
296 }
297
298 fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
299 Ok(Vec::new())
300 }
301 }
302
303 deserializer.deserialize_any(AudVisitor)
304}
305
306fn jwk_required_str_field<'a>(
307 obj: &'a serde_json::Map<String, serde_json::Value>,
308 field: &str,
309) -> Result<&'a str, OAuthError> {
310 obj.get(field)
311 .and_then(serde_json::Value::as_str)
312 .ok_or_else(|| OAuthError::InvalidDpopProof(format!("DPoP JWK missing '{field}' field")))
313}
314
315fn dpop_jwk_thumbprint_sha256(jwk: &jsonwebtoken::jwk::Jwk) -> Result<String, OAuthError> {
320 let value = serde_json::to_value(jwk)
321 .map_err(|e| OAuthError::InvalidDpopProof(format!("invalid DPoP JWK: {e}")))?;
322 let obj = value
323 .as_object()
324 .ok_or_else(|| OAuthError::InvalidDpopProof("invalid DPoP JWK object".to_string()))?;
325
326 let kty = jwk_required_str_field(obj, "kty")?;
327
328 let canonical = match kty {
329 "EC" => {
330 let crv = jwk_required_str_field(obj, "crv")?;
331 let x = jwk_required_str_field(obj, "x")?;
332 let y = jwk_required_str_field(obj, "y")?;
333 format!(
334 r#"{{"crv":{},"kty":{},"x":{},"y":{}}}"#,
335 serde_json::to_string(crv).map_err(|e| {
336 OAuthError::InvalidDpopProof(format!("failed to encode JWK curve: {e}"))
337 })?,
338 serde_json::to_string(kty).map_err(|e| {
339 OAuthError::InvalidDpopProof(format!("failed to encode JWK type: {e}"))
340 })?,
341 serde_json::to_string(x).map_err(|e| {
342 OAuthError::InvalidDpopProof(format!("failed to encode JWK x: {e}"))
343 })?,
344 serde_json::to_string(y).map_err(|e| {
345 OAuthError::InvalidDpopProof(format!("failed to encode JWK y: {e}"))
346 })?
347 )
348 }
349 "OKP" => {
350 let crv = jwk_required_str_field(obj, "crv")?;
351 let x = jwk_required_str_field(obj, "x")?;
352 format!(
353 r#"{{"crv":{},"kty":{},"x":{}}}"#,
354 serde_json::to_string(crv).map_err(|e| {
355 OAuthError::InvalidDpopProof(format!("failed to encode JWK curve: {e}"))
356 })?,
357 serde_json::to_string(kty).map_err(|e| {
358 OAuthError::InvalidDpopProof(format!("failed to encode JWK type: {e}"))
359 })?,
360 serde_json::to_string(x).map_err(|e| {
361 OAuthError::InvalidDpopProof(format!("failed to encode JWK x: {e}"))
362 })?
363 )
364 }
365 "RSA" => {
366 let e = jwk_required_str_field(obj, "e")?;
367 let n = jwk_required_str_field(obj, "n")?;
368 format!(
369 r#"{{"e":{},"kty":{},"n":{}}}"#,
370 serde_json::to_string(e).map_err(|err| {
371 OAuthError::InvalidDpopProof(format!("failed to encode JWK e: {err}"))
372 })?,
373 serde_json::to_string(kty).map_err(|err| {
374 OAuthError::InvalidDpopProof(format!("failed to encode JWK type: {err}"))
375 })?,
376 serde_json::to_string(n).map_err(|err| {
377 OAuthError::InvalidDpopProof(format!("failed to encode JWK n: {err}"))
378 })?
379 )
380 }
381 _ => {
382 return Err(OAuthError::InvalidDpopProof(format!(
383 "unsupported DPoP JWK key type '{kty}'"
384 )));
385 }
386 };
387
388 Ok(URL_SAFE_NO_PAD.encode(Sha256::digest(canonical.as_bytes())))
389}
390
391#[derive(Debug, thiserror::Error)]
393pub enum OAuthError {
394 #[error("missing Authorization header")]
395 MissingToken,
396
397 #[error("invalid Authorization header format (expected: Bearer <token>)")]
398 InvalidFormat,
399
400 #[error("JWT validation failed: {0}")]
401 JwtError(#[from] jsonwebtoken::errors::Error),
402
403 #[error("insufficient scope: required {required}, found {found}")]
404 InsufficientScope { required: String, found: String },
405
406 #[error("JWKS fetch failed: {0}")]
407 JwksFetchFailed(String),
408
409 #[error("no matching key found in JWKS for kid '{0}'")]
410 NoMatchingKey(String),
411
412 #[error("disallowed algorithm: {0:?} is not in the allowed list")]
413 DisallowedAlgorithm(Algorithm),
414
415 #[error("token missing 'kid' header but JWKS contains {0} keys — ambiguous key selection")]
416 MissingKid(usize),
417
418 #[error("resource mismatch: token resource '{token}' does not match expected '{expected}' (RFC 8707)")]
419 ResourceMismatch { expected: String, token: String },
420
421 #[error("token missing required 'aud' claim")]
422 MissingAudience,
423
424 #[error("token audience mismatch: expected '{expected}', found '{found}'")]
425 AudienceMismatch { expected: String, found: String },
426
427 #[error("authorization server does not support PKCE (S256)")]
428 PkceNotSupported,
429
430 #[error("missing DPoP proof header")]
431 MissingDpopProof,
432
433 #[error("invalid DPoP proof: {0}")]
434 InvalidDpopProof(String),
435
436 #[error("DPoP replay detected")]
437 DpopReplayDetected,
438
439 #[error("JWT claim contains control characters")]
442 ClaimControlCharacters,
443}
444
445struct CachedJwks {
447 keys: JwkSet,
448 fetched_at: Instant,
449}
450
451pub struct OAuthValidator {
455 config: OAuthConfig,
456 http_client: reqwest::Client,
457 jwks_cache: RwLock<Option<CachedJwks>>,
458 cache_ttl: Duration,
460 dpop_jti_cache: RwLock<HashMap<String, u64>>,
464}
465
466impl OAuthValidator {
467 pub fn new(config: OAuthConfig, http_client: reqwest::Client) -> Self {
471 if config.dpop_mode != DpopMode::Off && !config.dpop_require_ath {
475 tracing::warn!(
476 "DPoP mode is {:?} but dpop_require_ath is false — \
477 access-token binding (RFC 9449 §4.3) is NOT enforced",
478 config.dpop_mode
479 );
480 }
481 Self {
482 config,
483 http_client,
484 jwks_cache: RwLock::new(None),
485 cache_ttl: Duration::from_secs(300), dpop_jti_cache: RwLock::new(HashMap::new()),
487 }
488 }
489
490 pub async fn validate_token(&self, auth_header: &str) -> Result<OAuthClaims, OAuthError> {
494 let token = extract_bearer_token(auth_header)?;
495
496 let header = decode_header(token)?;
498
499 if !self.config.allowed_algorithms.contains(&header.alg) {
506 return Err(OAuthError::DisallowedAlgorithm(header.alg));
507 }
508
509 let kid = header.kid.clone().unwrap_or_default();
510
511 let decoding_key = self.get_decoding_key(&kid, &header.alg).await?;
513
514 let mut validation = Validation::new(header.alg);
516 validation.set_issuer(&[&self.config.issuer]);
517 validation.set_audience(&[&self.config.audience]);
518 validation.validate_exp = true;
519 validation.validate_nbf = true; validation.leeway = self.config.clock_skew_leeway.as_secs();
521
522 let token_data: TokenData<OAuthClaims> = decode(token, &decoding_key, &validation)?;
524 let claims = token_data.claims;
525
526 if contains_control_chars(&claims.sub)
530 || contains_control_chars(&claims.iss)
531 || contains_control_chars(&claims.scope)
532 || claims.aud.iter().any(|a| contains_control_chars(a))
533 || claims
534 .resource
535 .as_deref()
536 .is_some_and(contains_control_chars)
537 || claims
538 .cnf
539 .as_ref()
540 .and_then(|cnf| cnf.jkt.as_deref())
541 .is_some_and(contains_control_chars)
542 {
543 tracing::warn!("SECURITY: Rejecting JWT with control characters in claims");
544 return Err(OAuthError::ClaimControlCharacters);
545 }
546
547 if claims.aud.is_empty() {
548 if self.config.require_audience {
549 return Err(OAuthError::MissingAudience);
550 }
551 } else if !claims.aud.iter().any(|aud| aud == &self.config.audience) {
552 return Err(OAuthError::AudienceMismatch {
553 expected: self.config.audience.clone(),
554 found: claims.aud.join(" "),
555 });
556 }
557
558 if !self.config.required_scopes.is_empty() {
560 let token_scopes = claims.scopes();
561 for required in &self.config.required_scopes {
562 if !token_scopes.contains(&required.as_str()) {
563 return Err(OAuthError::InsufficientScope {
564 required: self.config.required_scopes.join(" "),
565 found: claims.scope.clone(),
566 });
567 }
568 }
569 }
570
571 if let Some(ref expected_resource) = self.config.expected_resource {
575 match &claims.resource {
576 Some(token_resource) if token_resource == expected_resource => {
577 }
579 Some(token_resource) => {
580 return Err(OAuthError::ResourceMismatch {
581 expected: expected_resource.clone(),
582 token: token_resource.clone(),
583 });
584 }
585 None => {
586 return Err(OAuthError::ResourceMismatch {
587 expected: expected_resource.clone(),
588 token: String::new(),
589 });
590 }
591 }
592 }
593
594 if self.config.dpop_mode == DpopMode::Required {
595 let token_jkt = claims
596 .cnf
597 .as_ref()
598 .and_then(|cnf| cnf.jkt.as_deref())
599 .map(str::trim)
600 .filter(|jkt| !jkt.is_empty());
601
602 if token_jkt.is_none() {
603 return Err(OAuthError::InvalidDpopProof(
604 "missing cnf.jkt in access token for required DPoP mode".to_string(),
605 ));
606 }
607 }
608
609 Ok(claims)
610 }
611
612 pub async fn validate_dpop_proof(
614 &self,
615 dpop_header: Option<&str>,
616 access_token: &str,
617 expected_method: &str,
618 expected_uri: &str,
619 token_claims: Option<&OAuthClaims>,
620 ) -> Result<(), OAuthError> {
621 match self.config.dpop_mode {
622 DpopMode::Off => return Ok(()),
623 DpopMode::Optional if dpop_header.is_none() => return Ok(()),
624 DpopMode::Required if dpop_header.is_none() => {
625 return Err(OAuthError::MissingDpopProof)
626 }
627 _ => {}
628 }
629
630 let proof_jwt = dpop_header
631 .map(str::trim)
632 .filter(|v| !v.is_empty())
633 .ok_or(OAuthError::MissingDpopProof)?;
634
635 let header = decode_header(proof_jwt)?;
636
637 if !self.config.dpop_allowed_algorithms.contains(&header.alg) {
638 return Err(OAuthError::DisallowedAlgorithm(header.alg));
639 }
640
641 let has_dpop_typ = header
642 .typ
643 .as_deref()
644 .map(|typ| typ.eq_ignore_ascii_case("dpop+jwt"))
645 .unwrap_or(false);
646 if !has_dpop_typ {
647 return Err(OAuthError::InvalidDpopProof(
648 "missing typ=dpop+jwt header".to_string(),
649 ));
650 }
651
652 let jwk = header.jwk.ok_or_else(|| {
653 OAuthError::InvalidDpopProof("missing embedded JWK in DPoP header".to_string())
654 })?;
655 let decoding_key = DecodingKey::from_jwk(&jwk)
656 .map_err(|e| OAuthError::InvalidDpopProof(format!("invalid DPoP JWK: {e}")))?;
657
658 let mut validation = Validation::new(header.alg);
659 validation.validate_exp = false;
660 validation.validate_nbf = false;
661 validation.required_spec_claims.clear();
662 let token_data: TokenData<DpopClaims> = decode(proof_jwt, &decoding_key, &validation)?;
663 let claims = token_data.claims;
664
665 if contains_control_chars(&claims.htm)
669 || contains_control_chars(&claims.htu)
670 || contains_control_chars(&claims.jti)
671 {
672 return Err(OAuthError::InvalidDpopProof(
673 "DPoP claims contain control or format characters".to_string(),
674 ));
675 }
676
677 if claims.htm.is_empty() || !claims.htm.eq_ignore_ascii_case(expected_method) {
678 return Err(OAuthError::InvalidDpopProof(format!(
679 "htm mismatch: expected '{}', got '{}'",
680 expected_method, claims.htm
681 )));
682 }
683
684 if !claims.htu.is_ascii() {
688 return Err(OAuthError::InvalidDpopProof(
689 "htu contains non-ASCII characters".to_string(),
690 ));
691 }
692
693 if normalize_dpop_htu(&claims.htu) != normalize_dpop_htu(expected_uri) {
699 return Err(OAuthError::InvalidDpopProof(format!(
700 "htu mismatch: expected '{}', got '{}'",
701 expected_uri, claims.htu
702 )));
703 }
704
705 if claims.jti.trim().is_empty() {
706 return Err(OAuthError::InvalidDpopProof("missing jti".to_string()));
707 }
708 if claims.jti.len() > 256 {
710 return Err(OAuthError::InvalidDpopProof(
711 "jti exceeds maximum length".to_string(),
712 ));
713 }
714
715 let now = chrono::Utc::now().timestamp();
716 let iat = i64::try_from(claims.iat)
718 .map_err(|_| OAuthError::InvalidDpopProof("iat out of range".to_string()))?;
719 let skew = self
720 .config
721 .dpop_max_clock_skew
722 .as_secs()
723 .min(i64::MAX as u64) as i64;
724 if (now - iat).abs() > skew {
725 return Err(OAuthError::InvalidDpopProof(format!(
726 "iat outside allowed skew window (iat={}, now={})",
727 claims.iat, now
728 )));
729 }
730
731 if let Some(exp) = claims.exp {
734 let exp_i64 = i64::try_from(exp)
735 .map_err(|_| OAuthError::InvalidDpopProof("exp out of range".to_string()))?;
736 if exp_i64 < now - skew {
737 return Err(OAuthError::InvalidDpopProof(
738 "DPoP proof expired".to_string(),
739 ));
740 }
741 }
742 if let Some(nbf) = claims.nbf {
743 let nbf_i64 = i64::try_from(nbf)
744 .map_err(|_| OAuthError::InvalidDpopProof("nbf out of range".to_string()))?;
745 if nbf_i64 > now + skew {
746 return Err(OAuthError::InvalidDpopProof(
747 "DPoP proof not yet valid (nbf in future)".to_string(),
748 ));
749 }
750 }
751
752 if claims._extra.len() > 20 {
755 return Err(OAuthError::InvalidDpopProof(
756 "DPoP proof contains too many unknown claims".to_string(),
757 ));
758 }
759
760 if self.config.dpop_require_ath {
761 let expected_ath = URL_SAFE_NO_PAD.encode(Sha256::digest(access_token.as_bytes()));
762 match claims.ath.as_deref() {
763 Some(ath) if ath == expected_ath => {}
764 _ => {
765 return Err(OAuthError::InvalidDpopProof(
766 "ath mismatch for access token binding".to_string(),
767 ));
768 }
769 }
770 }
771
772 if let Some(token_jkt) = token_claims
773 .and_then(|c| c.cnf.as_ref())
774 .and_then(|cnf| cnf.jkt.as_deref())
775 .map(str::trim)
776 .filter(|jkt| !jkt.is_empty())
777 {
778 let proof_jkt = dpop_jwk_thumbprint_sha256(&jwk)?;
779 if proof_jkt != token_jkt {
780 return Err(OAuthError::InvalidDpopProof(
781 "cnf.jkt does not match DPoP proof key thumbprint".to_string(),
782 ));
783 }
784 }
785
786 let now_u64 = now.max(0) as u64;
788 let replay_window = std::cmp::max((skew.max(0) as u64).saturating_mul(2), 600);
790 let oldest_allowed = now_u64.saturating_sub(replay_window);
791
792 let replay_key = match claims.ath.as_deref() {
795 Some(ath) if !ath.is_empty() => format!("{}:{}", claims.jti, ath),
796 _ => claims.jti.clone(),
797 };
798 if replay_key.len() > 512 {
799 return Err(OAuthError::InvalidDpopProof(
800 "DPoP replay key exceeds maximum length".to_string(),
801 ));
802 }
803
804 const MAX_JTI_CACHE_SIZE: usize = 8192;
810
811 let mut cache = self.dpop_jti_cache.write().await;
812
813 cache.retain(|_, ts| *ts >= oldest_allowed);
815
816 if cache.contains_key(&replay_key) {
818 return Err(OAuthError::DpopReplayDetected);
819 }
820
821 if cache.len() >= MAX_JTI_CACHE_SIZE {
823 return Err(OAuthError::InvalidDpopProof(
824 "DPoP replay cache at capacity — try again later".to_string(),
825 ));
826 }
827
828 cache.insert(replay_key, now_u64);
829
830 Ok(())
831 }
832
833 async fn get_decoding_key(
839 &self,
840 kid: &str,
841 alg: &Algorithm,
842 ) -> Result<DecodingKey, OAuthError> {
843 {
845 let cache = self.jwks_cache.read().await;
846 if let Some(cached) = cache.as_ref() {
847 if cached.fetched_at.elapsed() < self.cache_ttl {
848 if let Some(key) = find_key_in_jwks(&cached.keys, kid, alg) {
849 return Ok(key);
850 }
851 }
852 }
853 }
854 let mut cache = self.jwks_cache.write().await;
858
859 if let Some(cached) = cache.as_ref() {
861 if cached.fetched_at.elapsed() < self.cache_ttl {
862 if let Some(key) = find_key_in_jwks(&cached.keys, kid, alg) {
863 return Ok(key);
864 }
865 }
866 }
867
868 let jwks = self.fetch_jwks().await?;
870
871 if kid.is_empty() && jwks.keys.len() > 1 {
875 return Err(OAuthError::MissingKid(jwks.keys.len()));
876 }
877
878 let key = find_key_in_jwks(&jwks, kid, alg)
879 .ok_or_else(|| OAuthError::NoMatchingKey(kid.to_string()))?;
880
881 *cache = Some(CachedJwks {
883 keys: jwks,
884 fetched_at: Instant::now(),
885 });
886
887 Ok(key)
888 }
889
890 async fn fetch_jwks(&self) -> Result<JwkSet, OAuthError> {
892 let uri = self.config.effective_jwks_uri();
893
894 tracing::debug!("Fetching JWKS from {}", uri);
895
896 let response = self
897 .http_client
898 .get(&uri)
899 .timeout(Duration::from_secs(10))
900 .send()
901 .await
902 .map_err(|e| OAuthError::JwksFetchFailed(format!("request failed: {e}")))?;
903
904 if !response.status().is_success() {
905 return Err(OAuthError::JwksFetchFailed(format!(
906 "HTTP {}",
907 response.status()
908 )));
909 }
910
911 const MAX_JWKS_BODY_SIZE: usize = 1024 * 1024;
914
915 if let Some(len) = response.content_length() {
917 if len > MAX_JWKS_BODY_SIZE as u64 {
918 return Err(OAuthError::JwksFetchFailed(format!(
919 "JWKS Content-Length {len} exceeds {MAX_JWKS_BODY_SIZE} byte limit"
920 )));
921 }
922 }
923
924 let capacity = std::cmp::min(
927 response.content_length().unwrap_or(8192) as usize,
928 MAX_JWKS_BODY_SIZE,
929 );
930 let mut body = Vec::with_capacity(capacity);
931 let mut response = response;
932 while let Some(chunk) = response
933 .chunk()
934 .await
935 .map_err(|e| OAuthError::JwksFetchFailed(format!("body read failed: {e}")))?
936 {
937 if body.len().saturating_add(chunk.len()) > MAX_JWKS_BODY_SIZE {
938 return Err(OAuthError::JwksFetchFailed(format!(
939 "JWKS response exceeds {MAX_JWKS_BODY_SIZE} byte limit"
940 )));
941 }
942 body.extend_from_slice(&chunk);
943 }
944
945 let jwks: JwkSet = serde_json::from_slice(&body)
946 .map_err(|e| OAuthError::JwksFetchFailed(format!("invalid JWKS JSON: {e}")))?;
947
948 tracing::info!("Fetched {} keys from JWKS endpoint", jwks.keys.len());
949
950 Ok(jwks)
951 }
952
953 pub fn config(&self) -> &OAuthConfig {
955 &self.config
956 }
957}
958
959fn key_algorithm_to_algorithm(ka: &KeyAlgorithm) -> Option<Algorithm> {
964 match ka {
965 KeyAlgorithm::HS256 => Some(Algorithm::HS256),
966 KeyAlgorithm::HS384 => Some(Algorithm::HS384),
967 KeyAlgorithm::HS512 => Some(Algorithm::HS512),
968 KeyAlgorithm::ES256 => Some(Algorithm::ES256),
969 KeyAlgorithm::ES384 => Some(Algorithm::ES384),
970 KeyAlgorithm::RS256 => Some(Algorithm::RS256),
971 KeyAlgorithm::RS384 => Some(Algorithm::RS384),
972 KeyAlgorithm::RS512 => Some(Algorithm::RS512),
973 KeyAlgorithm::PS256 => Some(Algorithm::PS256),
974 KeyAlgorithm::PS384 => Some(Algorithm::PS384),
975 KeyAlgorithm::PS512 => Some(Algorithm::PS512),
976 KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
977 _ => None,
979 }
980}
981
982fn find_key_in_jwks(jwks: &JwkSet, kid: &str, alg: &Algorithm) -> Option<DecodingKey> {
991 for key in &jwks.keys {
992 if !kid.is_empty() {
994 match &key.common.key_id {
995 Some(key_kid) if key_kid == kid => {} Some(_) => continue, None => continue, }
999 }
1000
1001 if let Some(ref key_alg) = key.common.key_algorithm {
1003 match key_algorithm_to_algorithm(key_alg) {
1004 Some(mapped) if &mapped == alg => {} _ => continue, }
1007 }
1008
1009 if let Ok(dk) = DecodingKey::from_jwk(key) {
1011 return Some(dk);
1012 }
1013 }
1014 None
1015}
1016
1017pub fn verify_pkce_support(metadata: &serde_json::Value) -> Result<(), OAuthError> {
1050 let supported = metadata
1051 .get("code_challenge_methods_supported")
1052 .and_then(|v| v.as_array())
1053 .map(|arr| arr.iter().any(|m| m.as_str() == Some("S256")))
1054 .unwrap_or(false);
1055
1056 if !supported {
1057 return Err(OAuthError::PkceNotSupported);
1058 }
1059 Ok(())
1060}
1061
1062#[cfg(test)]
1063mod tests {
1064 use super::*;
1065 use vellaveto_types::uri_util::decode_unreserved_percent;
1066
1067 #[test]
1068 fn test_oauth_config_effective_jwks_uri_explicit() {
1069 let config = OAuthConfig {
1070 issuer: "https://auth.example.com".to_string(),
1071 audience: "mcp-server".to_string(),
1072 jwks_uri: Some("https://auth.example.com/keys".to_string()),
1073 required_scopes: vec![],
1074 pass_through: false,
1075 allowed_algorithms: default_allowed_algorithms(),
1076 expected_resource: None,
1077 clock_skew_leeway: Duration::from_secs(30),
1078 require_audience: true,
1079 dpop_mode: DpopMode::Off,
1080 dpop_allowed_algorithms: default_dpop_allowed_algorithms(),
1081 dpop_require_ath: true,
1082 dpop_max_clock_skew: Duration::from_secs(300),
1083 };
1084 assert_eq!(config.effective_jwks_uri(), "https://auth.example.com/keys");
1085 }
1086
1087 #[test]
1088 fn test_oauth_config_effective_jwks_uri_wellknown() {
1089 let config = OAuthConfig {
1090 issuer: "https://auth.example.com".to_string(),
1091 audience: "mcp-server".to_string(),
1092 jwks_uri: None,
1093 required_scopes: vec![],
1094 pass_through: false,
1095 allowed_algorithms: default_allowed_algorithms(),
1096 expected_resource: None,
1097 clock_skew_leeway: Duration::from_secs(30),
1098 require_audience: true,
1099 dpop_mode: DpopMode::Off,
1100 dpop_allowed_algorithms: default_dpop_allowed_algorithms(),
1101 dpop_require_ath: true,
1102 dpop_max_clock_skew: Duration::from_secs(300),
1103 };
1104 assert_eq!(
1105 config.effective_jwks_uri(),
1106 "https://auth.example.com/.well-known/jwks.json"
1107 );
1108 }
1109
1110 #[test]
1111 fn test_oauth_config_effective_jwks_uri_trailing_slash() {
1112 let config = OAuthConfig {
1113 issuer: "https://auth.example.com/".to_string(),
1114 audience: "mcp-server".to_string(),
1115 jwks_uri: None,
1116 required_scopes: vec![],
1117 pass_through: false,
1118 allowed_algorithms: default_allowed_algorithms(),
1119 expected_resource: None,
1120 clock_skew_leeway: Duration::from_secs(30),
1121 require_audience: true,
1122 dpop_mode: DpopMode::Off,
1123 dpop_allowed_algorithms: default_dpop_allowed_algorithms(),
1124 dpop_require_ath: true,
1125 dpop_max_clock_skew: Duration::from_secs(300),
1126 };
1127 assert_eq!(
1128 config.effective_jwks_uri(),
1129 "https://auth.example.com/.well-known/jwks.json"
1130 );
1131 }
1132
1133 #[test]
1134 fn test_oauth_claims_scopes_parsing() {
1135 let claims = OAuthClaims {
1136 sub: "user-123".to_string(),
1137 iss: "https://auth.example.com".to_string(),
1138 aud: vec!["mcp-server".to_string()],
1139 exp: 0,
1140 iat: 0,
1141 scope: "tools.call resources.read admin".to_string(),
1142 resource: None,
1143 cnf: None,
1144 };
1145 let scopes = claims.scopes();
1146 assert_eq!(scopes, vec!["tools.call", "resources.read", "admin"]);
1147 }
1148
1149 #[test]
1150 fn test_oauth_claims_empty_scope() {
1151 let claims = OAuthClaims {
1152 sub: "user-123".to_string(),
1153 iss: "https://auth.example.com".to_string(),
1154 aud: vec![],
1155 exp: 0,
1156 iat: 0,
1157 scope: String::new(),
1158 resource: None,
1159 cnf: None,
1160 };
1161 let scopes = claims.scopes();
1162 assert!(scopes.is_empty());
1163 }
1164
1165 #[test]
1166 fn test_deserialize_aud_string() {
1167 let json = r#"{"sub":"user","aud":"mcp-server","scope":""}"#;
1168 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1169 assert_eq!(claims.aud, vec!["mcp-server"]);
1170 }
1171
1172 #[test]
1173 fn test_deserialize_aud_array() {
1174 let json = r#"{"sub":"user","aud":["mcp-server","other"],"scope":""}"#;
1175 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1176 assert_eq!(claims.aud, vec!["mcp-server", "other"]);
1177 }
1178
1179 #[test]
1180 fn test_oauth_error_display() {
1181 let err = OAuthError::MissingToken;
1182 assert_eq!(err.to_string(), "missing Authorization header");
1183
1184 let err = OAuthError::InsufficientScope {
1185 required: "tools.call admin".to_string(),
1186 found: "tools.call".to_string(),
1187 };
1188 assert!(err.to_string().contains("insufficient scope"));
1189 }
1190
1191 #[test]
1193 fn test_default_allowed_algorithms_excludes_hmac() {
1194 let allowed = default_allowed_algorithms();
1195 assert!(!allowed.contains(&Algorithm::HS256));
1196 assert!(!allowed.contains(&Algorithm::HS384));
1197 assert!(!allowed.contains(&Algorithm::HS512));
1198 }
1199
1200 #[test]
1201 fn test_default_allowed_algorithms_includes_asymmetric() {
1202 let allowed = default_allowed_algorithms();
1203 assert!(allowed.contains(&Algorithm::RS256));
1204 assert!(allowed.contains(&Algorithm::ES256));
1205 assert!(allowed.contains(&Algorithm::PS256));
1206 assert!(allowed.contains(&Algorithm::EdDSA));
1207 }
1208
1209 #[test]
1210 fn test_disallowed_algorithm_error_display() {
1211 let err = OAuthError::DisallowedAlgorithm(Algorithm::HS256);
1212 assert!(err.to_string().contains("disallowed algorithm"));
1213 assert!(err.to_string().contains("HS256"));
1214 }
1215
1216 #[test]
1217 fn test_missing_kid_error_display() {
1218 let err = OAuthError::MissingKid(3);
1219 assert!(err.to_string().contains("missing 'kid'"));
1220 assert!(err.to_string().contains("3 keys"));
1221 }
1222
1223 #[test]
1225 fn test_key_algorithm_to_algorithm_all_signing() {
1226 assert_eq!(
1227 key_algorithm_to_algorithm(&KeyAlgorithm::HS256),
1228 Some(Algorithm::HS256)
1229 );
1230 assert_eq!(
1231 key_algorithm_to_algorithm(&KeyAlgorithm::RS256),
1232 Some(Algorithm::RS256)
1233 );
1234 assert_eq!(
1235 key_algorithm_to_algorithm(&KeyAlgorithm::ES256),
1236 Some(Algorithm::ES256)
1237 );
1238 assert_eq!(
1239 key_algorithm_to_algorithm(&KeyAlgorithm::PS256),
1240 Some(Algorithm::PS256)
1241 );
1242 assert_eq!(
1243 key_algorithm_to_algorithm(&KeyAlgorithm::EdDSA),
1244 Some(Algorithm::EdDSA)
1245 );
1246 }
1247
1248 #[test]
1249 fn test_key_algorithm_to_algorithm_encryption_returns_none() {
1250 assert_eq!(key_algorithm_to_algorithm(&KeyAlgorithm::RSA1_5), None);
1251 assert_eq!(key_algorithm_to_algorithm(&KeyAlgorithm::RSA_OAEP), None);
1252 assert_eq!(
1253 key_algorithm_to_algorithm(&KeyAlgorithm::RSA_OAEP_256),
1254 None
1255 );
1256 }
1257
1258 #[test]
1260 fn test_resource_mismatch_error_display() {
1261 let err = OAuthError::ResourceMismatch {
1262 expected: "https://mcp.example.com".to_string(),
1263 token: "https://other.example.com".to_string(),
1264 };
1265 let msg = err.to_string();
1266 assert!(msg.contains("resource mismatch"));
1267 assert!(msg.contains("https://mcp.example.com"));
1268 assert!(msg.contains("https://other.example.com"));
1269 assert!(msg.contains("RFC 8707"));
1270 }
1271
1272 #[test]
1273 fn test_resource_mismatch_missing_claim_error_display() {
1274 let err = OAuthError::ResourceMismatch {
1275 expected: "https://mcp.example.com".to_string(),
1276 token: String::new(),
1277 };
1278 let msg = err.to_string();
1279 assert!(msg.contains("resource mismatch"));
1280 assert!(msg.contains("https://mcp.example.com"));
1281 }
1282
1283 #[test]
1284 fn test_deserialize_claims_with_resource() {
1285 let json =
1286 r#"{"sub":"user","aud":"mcp-server","scope":"","resource":"https://mcp.example.com"}"#;
1287 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1288 assert_eq!(claims.resource, Some("https://mcp.example.com".to_string()));
1289 }
1290
1291 #[test]
1292 fn test_deserialize_claims_without_resource() {
1293 let json = r#"{"sub":"user","aud":"mcp-server","scope":""}"#;
1294 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1295 assert_eq!(claims.resource, None);
1296 }
1297
1298 #[test]
1299 fn test_deserialize_claims_with_cnf_jkt() {
1300 let json = r#"{"sub":"user","aud":"mcp-server","scope":"","cnf":{"jkt":"thumbprint-123"}}"#;
1301 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1302 let jkt = claims
1303 .cnf
1304 .as_ref()
1305 .and_then(|cnf| cnf.jkt.as_deref())
1306 .expect("cnf.jkt must deserialize");
1307 assert_eq!(jkt, "thumbprint-123");
1308 }
1309
1310 #[test]
1311 fn test_dpop_jwk_thumbprint_sha256_rsa() {
1312 let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(serde_json::json!({
1313 "kty": "RSA",
1314 "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
1315 "e": "AQAB"
1316 }))
1317 .expect("valid RSA JWK");
1318
1319 let thumbprint = dpop_jwk_thumbprint_sha256(&jwk).expect("thumbprint should compute");
1320 assert_eq!(thumbprint, "NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs");
1321 }
1322
1323 #[test]
1324 fn test_dpop_jwk_thumbprint_rejects_unsupported_key_type() {
1325 let jwk: jsonwebtoken::jwk::Jwk =
1326 serde_json::from_value(serde_json::json!({"kty": "oct", "k": "AQAB"}))
1327 .expect("valid octet JWK");
1328
1329 let err = dpop_jwk_thumbprint_sha256(&jwk).expect_err("octet keys are not valid for DPoP");
1330 assert!(err.to_string().contains("unsupported DPoP JWK key type"));
1331 }
1332
1333 #[test]
1334 fn test_clock_skew_leeway_configurable() {
1335 let config = OAuthConfig {
1336 issuer: "https://auth.example.com".to_string(),
1337 audience: "mcp-server".to_string(),
1338 jwks_uri: None,
1339 required_scopes: vec![],
1340 pass_through: false,
1341 allowed_algorithms: default_allowed_algorithms(),
1342 expected_resource: None,
1343 clock_skew_leeway: Duration::from_secs(60),
1344 require_audience: true,
1345 dpop_mode: DpopMode::Off,
1346 dpop_allowed_algorithms: default_dpop_allowed_algorithms(),
1347 dpop_require_ath: true,
1348 dpop_max_clock_skew: Duration::from_secs(300),
1349 };
1350 assert_eq!(config.clock_skew_leeway, Duration::from_secs(60));
1351 }
1352
1353 #[test]
1354 fn test_deserialize_missing_aud_yields_empty_vec() {
1355 let json = r#"{"sub":"user","scope":"read"}"#;
1356 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1357 assert!(claims.aud.is_empty());
1358 }
1359
1360 #[test]
1361 fn test_missing_audience_error_display() {
1362 let err = OAuthError::MissingAudience;
1363 assert_eq!(err.to_string(), "token missing required 'aud' claim");
1364 }
1365
1366 #[test]
1367 fn test_audience_mismatch_error_display() {
1368 let err = OAuthError::AudienceMismatch {
1369 expected: "mcp-server".to_string(),
1370 found: "other-aud".to_string(),
1371 };
1372 let msg = err.to_string();
1373 assert!(msg.contains("audience mismatch"));
1374 assert!(msg.contains("mcp-server"));
1375 assert!(msg.contains("other-aud"));
1376 }
1377
1378 #[test]
1383 fn test_verify_pkce_support_s256_supported() {
1384 let metadata = serde_json::json!({
1385 "issuer": "https://auth.example.com",
1386 "code_challenge_methods_supported": ["S256", "plain"]
1387 });
1388 assert!(verify_pkce_support(&metadata).is_ok());
1389 }
1390
1391 #[test]
1392 fn test_verify_pkce_support_s256_only() {
1393 let metadata = serde_json::json!({
1394 "issuer": "https://auth.example.com",
1395 "code_challenge_methods_supported": ["S256"]
1396 });
1397 assert!(verify_pkce_support(&metadata).is_ok());
1398 }
1399
1400 #[test]
1401 fn test_verify_pkce_support_missing_field() {
1402 let metadata = serde_json::json!({
1403 "issuer": "https://auth.example.com"
1404 });
1405 let result = verify_pkce_support(&metadata);
1406 assert!(matches!(result, Err(OAuthError::PkceNotSupported)));
1407 }
1408
1409 #[test]
1410 fn test_verify_pkce_support_plain_only() {
1411 let metadata = serde_json::json!({
1413 "issuer": "https://auth.example.com",
1414 "code_challenge_methods_supported": ["plain"]
1415 });
1416 let result = verify_pkce_support(&metadata);
1417 assert!(matches!(result, Err(OAuthError::PkceNotSupported)));
1418 }
1419
1420 #[test]
1421 fn test_verify_pkce_support_empty_array() {
1422 let metadata = serde_json::json!({
1423 "issuer": "https://auth.example.com",
1424 "code_challenge_methods_supported": []
1425 });
1426 let result = verify_pkce_support(&metadata);
1427 assert!(matches!(result, Err(OAuthError::PkceNotSupported)));
1428 }
1429
1430 #[test]
1431 fn test_pkce_not_supported_error_display() {
1432 let err = OAuthError::PkceNotSupported;
1433 assert!(err.to_string().contains("PKCE"));
1434 assert!(err.to_string().contains("S256"));
1435 }
1436
    /// Test-only 2048-bit RSA private key in PKCS#8 PEM form (`BEGIN PRIVATE KEY`).
    ///
    /// `sign_test_jwt` uses it to sign RS256 tokens. Its public modulus/exponent are the
    /// `TEST_RSA_N` / `TEST_RSA_E` constants served by the mock JWKS, so tokens signed here
    /// verify against that JWKS. Never use this key outside tests.
    const TEST_RSA_PRIVATE_PEM: &str = r#"-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDXsFTcrmrrw3RK
Ll2pK5mhqySdvuoQY/U1CwFXmu1S2BAEnh+/Yzsilc/LWJjBmcDdmY88NC+F8PhO
q6+hQjWZR08QewinBg69w2+TRqr4x09XXZm/w3Y+jOlspHR85PISy8sqkHzGk3o4
cLNCxDkw2mwQaVDQQz1YJ0x22+IoOZniRTntUK1yAyI0jqhpZJjn9dY+CbDt/H8B
nGollAlhKQFizDAIMYOL/duJLJv1jtgv5hwvH94tSYgLGzJcufmxvioBBD4ZcxDq
Lk2vNdC5ETVkS9GeyYfuQbHW55lYSACe2NfYwgwYc3PO6X6PlJxuksL7JfR6kivU
WkYHOTVLAgMBAAECggEATuXElRkEKYvMrRn6ztgREa9N7JoaerZlyupkqkwUxfod
GeNRj6vXxNXyNdsJvb/laeozF/2q6J715aktzJowiwonpMqsppQzrjygQspV3jzi
C/5EMH5qcYUQGdqqdck1t6Rug/poeicWTTEEkca/eNxdLT+o/RWrieSONuhF+Ro0
7S60Dc4tFRA6XBDayikUzuFd2XoroRfoukC+HcC7mHMQdPHNt7QjORJNitdjwP+P
BcwNm61043sz9VSdW9FMtrdpg+pndbzRiYwYCDRt7r0hhUSZ4cojY6Tyoqexa5BD
7W5jDTmySO5/Jzl3QGvtevyKVx3x6DQE9858W8kucQKBgQDynqb19kVK6IFPccFX
01D9qVg0vZ1WoAR30s79DkoKA/NyM3sjP431p38Kkj1QomERBSCb0O1OOfsfjAEO
2SoEqTSa+2cgDYikQ16IEqKfbucilDNMTsYQz9Jwx16BEJRGoz+lbt52exhZN+nf
qfVtuwIvlb46bksxh5pXJ9L7wwKBgQDjlXatVjiZgigpGwmj4/Uj9tI0c+AmtooX
zA/2B3GJdXbRVtMvFsQB73/d7U2lCwUHJmG56FYS1Dg8C88Xn7nE6kBSYoeguxCA
bLPBbCGtPD/VeGP2ymxLxsULLiiRx4S+6K7hUulCkg9m3CWkc3m5AH5lrHVEXgi1
YadcFKMv2QKBgAuaJqXQdxPT9osUB4jppA/dT0iGYMXJtSz9ucREMKo18ihd6d+P
pHxA3ERnJeN7QGUN97c70H1TLH0fttU88VNzu/5FU3Mm8ofYaObc7UXuicMPjzxw
7+vR5GBcSFqnrk+Kcvq4SI8l584sbFSzzfbHYJ1h7czhhVsC/xB36RD9AoGBANQ8
JXGer6fQrp0u3r2dL5Y7bmqGCWpw3rU0k0nwRRxYk9bDbqxCQcZAUHFpBPi+HxE8
5PQXTHXAvTSaGqXASeDuR8/MnQjyioAJX1Uo/vrr7eeonyieO4IrOsSjZigU9aGH
otb0mB2B0qUs9lm3arNxV25/9tgsDVkBWa7QfCJ5AoGBAJF/XTU+YnGjQYGgxvfg
Ma5j2E3NRga/10ncKjDKbRNzLXk887xp4kl68vDTayAKGLu+ndYQ9dMpHCUTPky5
2KGQijoG2H/1Ri4JE8dGa+RbjG3gMIRIdbYApn/Q4nrAadrWrDLaTpbnAhhL95FJ
TfzccotDw2uXy3Xbwy/kdpfK
-----END PRIVATE KEY-----"#;
1479
    /// Unpadded base64url RSA modulus `n` of `TEST_RSA_PRIVATE_PEM`, published in the mock JWKS.
    const TEST_RSA_N: &str = "17BU3K5q68N0Si5dqSuZoasknb7qEGP1NQsBV5rtUtgQBJ4fv2M7IpXPy1iYwZnA3ZmPPDQvhfD4TquvoUI1mUdPEHsIpwYOvcNvk0aq-MdPV12Zv8N2PozpbKR0fOTyEsvLKpB8xpN6OHCzQsQ5MNpsEGlQ0EM9WCdMdtviKDmZ4kU57VCtcgMiNI6oaWSY5_XWPgmw7fx_AZxqJZQJYSkBYswwCDGDi_3biSyb9Y7YL-YcLx_eLUmICxsyXLn5sb4qAQQ-GXMQ6i5NrzXQuRE1ZEvRnsmH7kGx1ueZWEgAntjX2MIMGHNzzul-j5ScbpLC-yX0epIr1FpGBzk1Sw";
    /// RSA public exponent `e` of the test key (`AQAB` = 65537).
    const TEST_RSA_E: &str = "AQAB";
1483
1484 fn test_jwks_json(kid: &str) -> String {
1486 serde_json::json!({
1487 "keys": [{
1488 "kty": "RSA",
1489 "use": "sig",
1490 "alg": "RS256",
1491 "kid": kid,
1492 "n": TEST_RSA_N,
1493 "e": TEST_RSA_E
1494 }]
1495 })
1496 .to_string()
1497 }
1498
1499 async fn start_mock_jwks_server(
1505 jwks_json: String,
1506 ) -> Option<(String, tokio::task::JoinHandle<()>)> {
1507 use axum::{routing::get, Router};
1508 use std::net::SocketAddr;
1509
1510 let app = Router::new().route(
1511 "/.well-known/jwks.json",
1512 get(move || {
1513 let json = jwks_json.clone();
1514 async move {
1515 (
1516 [(
1517 http::header::CONTENT_TYPE,
1518 http::HeaderValue::from_static("application/json"),
1519 )],
1520 json,
1521 )
1522 }
1523 }),
1524 );
1525
1526 let listener = match tokio::net::TcpListener::bind("127.0.0.1:0").await {
1527 Ok(listener) => listener,
1528 Err(error) if error.kind() == std::io::ErrorKind::PermissionDenied => {
1529 eprintln!("skipping oauth e2e test: cannot bind local jwks server: {error}");
1530 return None;
1531 }
1532 Err(error) => panic!("bind to random port: {error}"),
1533 };
1534 let addr: SocketAddr = listener.local_addr().expect("local addr");
1535 let base_url = format!("http://127.0.0.1:{}", addr.port());
1536
1537 let handle = tokio::spawn(async move {
1538 axum::serve(listener, app)
1539 .await
1540 .expect("mock JWKS server failed");
1541 });
1542
1543 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1545
1546 Some((base_url, handle))
1547 }
1548
1549 fn test_oauth_config(jwks_url: String) -> OAuthConfig {
1551 OAuthConfig {
1552 issuer: "https://auth.example.com".to_string(),
1553 audience: "mcp-server".to_string(),
1554 jwks_uri: Some(jwks_url),
1555 required_scopes: vec!["tools.call".to_string()],
1556 pass_through: false,
1557 allowed_algorithms: default_allowed_algorithms(),
1558 expected_resource: None,
1559 clock_skew_leeway: Duration::from_secs(30),
1560 require_audience: true,
1561 dpop_mode: DpopMode::Off,
1562 dpop_allowed_algorithms: default_dpop_allowed_algorithms(),
1563 dpop_require_ath: true,
1564 dpop_max_clock_skew: Duration::from_secs(300),
1565 }
1566 }
1567
1568 fn sign_test_jwt(claims: &serde_json::Value, kid: &str) -> String {
1570 use jsonwebtoken::{encode, EncodingKey, Header};
1571
1572 let key =
1573 EncodingKey::from_rsa_pem(TEST_RSA_PRIVATE_PEM.as_bytes()).expect("valid RSA PEM");
1574 let mut header = Header::new(Algorithm::RS256);
1575 header.kid = Some(kid.to_string());
1576
1577 encode(&header, claims, &key).expect("JWT signing must succeed")
1578 }
1579
1580 fn valid_claims() -> serde_json::Value {
1582 let now = chrono::Utc::now().timestamp() as u64;
1583 serde_json::json!({
1584 "sub": "user-123",
1585 "iss": "https://auth.example.com",
1586 "aud": "mcp-server",
1587 "exp": now + 3600,
1588 "iat": now,
1589 "nbf": now - 10,
1590 "scope": "tools.call resources.read"
1591 })
1592 }
1593
1594 #[tokio::test]
1595 async fn test_e2e_valid_jwt_accepted() {
1596 let kid = "test-key-1";
1597 let jwks = test_jwks_json(kid);
1598 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1599 return;
1600 };
1601
1602 let config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1603 let validator = OAuthValidator::new(config, reqwest::Client::new());
1604
1605 let token = sign_test_jwt(&valid_claims(), kid);
1606 let auth_header = format!("Bearer {token}");
1607
1608 let claims = validator
1609 .validate_token(&auth_header)
1610 .await
1611 .expect("valid JWT must be accepted");
1612 assert_eq!(claims.sub, "user-123");
1613 assert_eq!(claims.iss, "https://auth.example.com");
1614 assert!(claims.scopes().contains(&"tools.call"));
1615 }
1616
1617 #[tokio::test]
1618 async fn test_e2e_expired_jwt_rejected() {
1619 let kid = "test-key-1";
1620 let jwks = test_jwks_json(kid);
1621 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1622 return;
1623 };
1624
1625 let config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1626 let validator = OAuthValidator::new(config, reqwest::Client::new());
1627
1628 let now = chrono::Utc::now().timestamp() as u64;
1629 let mut claims = valid_claims();
1630 claims["exp"] = serde_json::json!(now - 600); let token = sign_test_jwt(&claims, kid);
1633 let auth_header = format!("Bearer {token}");
1634
1635 let err = validator
1636 .validate_token(&auth_header)
1637 .await
1638 .expect_err("expired JWT must be rejected");
1639 assert!(
1640 matches!(err, OAuthError::JwtError(_)),
1641 "expected JwtError for expired token, got: {err}"
1642 );
1643 }
1644
1645 #[tokio::test]
1646 async fn test_e2e_wrong_algorithm_rejected() {
1647 let kid = "test-key-1";
1648 let jwks = test_jwks_json(kid);
1649 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1650 return;
1651 };
1652
1653 let mut config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1655 config.allowed_algorithms = vec![Algorithm::ES256];
1656
1657 let validator = OAuthValidator::new(config, reqwest::Client::new());
1658
1659 let token = sign_test_jwt(&valid_claims(), kid);
1661 let auth_header = format!("Bearer {token}");
1662
1663 let err = validator
1664 .validate_token(&auth_header)
1665 .await
1666 .expect_err("RS256 JWT must be rejected when only ES256 is allowed");
1667 assert!(
1668 matches!(err, OAuthError::DisallowedAlgorithm(Algorithm::RS256)),
1669 "expected DisallowedAlgorithm(RS256), got: {err}"
1670 );
1671 }
1672
1673 #[tokio::test]
1674 async fn test_e2e_wrong_issuer_rejected() {
1675 let kid = "test-key-1";
1676 let jwks = test_jwks_json(kid);
1677 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1678 return;
1679 };
1680
1681 let config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1682 let validator = OAuthValidator::new(config, reqwest::Client::new());
1683
1684 let mut claims = valid_claims();
1685 claims["iss"] = serde_json::json!("https://evil.example.com");
1686
1687 let token = sign_test_jwt(&claims, kid);
1688 let auth_header = format!("Bearer {token}");
1689
1690 let err = validator
1691 .validate_token(&auth_header)
1692 .await
1693 .expect_err("wrong issuer must be rejected");
1694 assert!(
1695 matches!(err, OAuthError::JwtError(_)),
1696 "expected JwtError for issuer mismatch, got: {err}"
1697 );
1698 }
1699
1700 #[tokio::test]
1701 async fn test_e2e_wrong_audience_rejected() {
1702 let kid = "test-key-1";
1703 let jwks = test_jwks_json(kid);
1704 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1705 return;
1706 };
1707
1708 let config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1709 let validator = OAuthValidator::new(config, reqwest::Client::new());
1710
1711 let mut claims = valid_claims();
1712 claims["aud"] = serde_json::json!("wrong-audience");
1713
1714 let token = sign_test_jwt(&claims, kid);
1715 let auth_header = format!("Bearer {token}");
1716
1717 let err = validator
1718 .validate_token(&auth_header)
1719 .await
1720 .expect_err("wrong audience must be rejected");
1721 assert!(
1723 matches!(
1724 err,
1725 OAuthError::JwtError(_) | OAuthError::AudienceMismatch { .. }
1726 ),
1727 "expected audience rejection, got: {err}"
1728 );
1729 }
1730
1731 #[tokio::test]
1732 async fn test_e2e_missing_required_scope_rejected() {
1733 let kid = "test-key-1";
1734 let jwks = test_jwks_json(kid);
1735 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1736 return;
1737 };
1738
1739 let config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1740 let validator = OAuthValidator::new(config, reqwest::Client::new());
1741
1742 let mut claims = valid_claims();
1743 claims["scope"] = serde_json::json!("resources.read"); let token = sign_test_jwt(&claims, kid);
1746 let auth_header = format!("Bearer {token}");
1747
1748 let err = validator
1749 .validate_token(&auth_header)
1750 .await
1751 .expect_err("missing required scope must be rejected");
1752 assert!(
1753 matches!(err, OAuthError::InsufficientScope { .. }),
1754 "expected InsufficientScope, got: {err}"
1755 );
1756 }
1757
1758 #[tokio::test]
1759 async fn test_e2e_resource_mismatch_rejected() {
1760 let kid = "test-key-1";
1761 let jwks = test_jwks_json(kid);
1762 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1763 return;
1764 };
1765
1766 let mut config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1767 config.expected_resource = Some("https://mcp.example.com".to_string());
1768
1769 let validator = OAuthValidator::new(config, reqwest::Client::new());
1770
1771 let mut claims = valid_claims();
1773 claims["resource"] = serde_json::json!("https://evil.example.com");
1774
1775 let token = sign_test_jwt(&claims, kid);
1776 let auth_header = format!("Bearer {token}");
1777
1778 let err = validator
1779 .validate_token(&auth_header)
1780 .await
1781 .expect_err("resource mismatch must be rejected");
1782 assert!(
1783 matches!(err, OAuthError::ResourceMismatch { .. }),
1784 "expected ResourceMismatch, got: {err}"
1785 );
1786 }
1787
1788 #[tokio::test]
1789 async fn test_e2e_resource_missing_when_required_rejected() {
1790 let kid = "test-key-1";
1791 let jwks = test_jwks_json(kid);
1792 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1793 return;
1794 };
1795
1796 let mut config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1797 config.expected_resource = Some("https://mcp.example.com".to_string());
1798
1799 let validator = OAuthValidator::new(config, reqwest::Client::new());
1800
1801 let token = sign_test_jwt(&valid_claims(), kid);
1803 let auth_header = format!("Bearer {token}");
1804
1805 let err = validator
1806 .validate_token(&auth_header)
1807 .await
1808 .expect_err("missing resource when required must be rejected");
1809 assert!(
1810 matches!(err, OAuthError::ResourceMismatch { .. }),
1811 "expected ResourceMismatch, got: {err}"
1812 );
1813 }
1814
1815 #[tokio::test]
1816 async fn test_e2e_resource_match_accepted() {
1817 let kid = "test-key-1";
1818 let jwks = test_jwks_json(kid);
1819 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1820 return;
1821 };
1822
1823 let mut config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1824 config.expected_resource = Some("https://mcp.example.com".to_string());
1825
1826 let validator = OAuthValidator::new(config, reqwest::Client::new());
1827
1828 let mut claims = valid_claims();
1829 claims["resource"] = serde_json::json!("https://mcp.example.com");
1830
1831 let token = sign_test_jwt(&claims, kid);
1832 let auth_header = format!("Bearer {token}");
1833
1834 let result = validator.validate_token(&auth_header).await;
1835 assert!(
1836 result.is_ok(),
1837 "matching resource must be accepted: {result:?}"
1838 );
1839 }
1840
1841 #[tokio::test]
1842 async fn test_e2e_kid_mismatch_rejected() {
1843 let jwks = test_jwks_json("server-key-1");
1844 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1845 return;
1846 };
1847
1848 let config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1849 let validator = OAuthValidator::new(config, reqwest::Client::new());
1850
1851 let token = sign_test_jwt(&valid_claims(), "wrong-key");
1853 let auth_header = format!("Bearer {token}");
1854
1855 let err = validator
1856 .validate_token(&auth_header)
1857 .await
1858 .expect_err("kid mismatch must be rejected");
1859 assert!(
1860 matches!(err, OAuthError::NoMatchingKey(_)),
1861 "expected NoMatchingKey, got: {err}"
1862 );
1863 }
1864
1865 #[tokio::test]
1866 async fn test_e2e_multi_key_jwks_no_kid_rejected() {
1867 let jwks = serde_json::json!({
1869 "keys": [
1870 {
1871 "kty": "RSA", "use": "sig", "alg": "RS256",
1872 "kid": "key-1",
1873 "n": TEST_RSA_N, "e": TEST_RSA_E
1874 },
1875 {
1876 "kty": "RSA", "use": "sig", "alg": "RS256",
1877 "kid": "key-2",
1878 "n": TEST_RSA_N, "e": TEST_RSA_E
1879 }
1880 ]
1881 })
1882 .to_string();
1883 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1884 return;
1885 };
1886
1887 let config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1888 let validator = OAuthValidator::new(config, reqwest::Client::new());
1889
1890 let key = jsonwebtoken::EncodingKey::from_rsa_pem(TEST_RSA_PRIVATE_PEM.as_bytes())
1892 .expect("valid RSA PEM");
1893 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
1894 header.kid = None; let token = jsonwebtoken::encode(&header, &valid_claims(), &key).expect("JWT signing");
1896 let auth_header = format!("Bearer {token}");
1897
1898 let err = validator
1899 .validate_token(&auth_header)
1900 .await
1901 .expect_err("missing kid with multi-key JWKS must be rejected");
1902 assert!(
1903 matches!(err, OAuthError::MissingKid(2)),
1904 "expected MissingKid(2), got: {err}"
1905 );
1906 }
1907
1908 #[tokio::test]
1909 async fn test_e2e_tampered_signature_rejected() {
1910 let kid = "test-key-1";
1911 let jwks = test_jwks_json(kid);
1912 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1913 return;
1914 };
1915
1916 let config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1917 let validator = OAuthValidator::new(config, reqwest::Client::new());
1918
1919 let token = sign_test_jwt(&valid_claims(), kid);
1920 let mut tampered = token.clone();
1922 let last_char = tampered.pop().unwrap_or('A');
1923 tampered.push(if last_char == 'A' { 'B' } else { 'A' });
1924
1925 let auth_header = format!("Bearer {tampered}");
1926
1927 let err = validator
1928 .validate_token(&auth_header)
1929 .await
1930 .expect_err("tampered signature must be rejected");
1931 assert!(
1932 matches!(err, OAuthError::JwtError(_)),
1933 "expected JwtError for signature tampering, got: {err}"
1934 );
1935 }
1936
1937 #[tokio::test]
1938 async fn test_e2e_missing_audience_with_require_audience_rejected() {
1939 let kid = "test-key-1";
1940 let jwks = test_jwks_json(kid);
1941 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1942 return;
1943 };
1944
1945 let mut config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1946 config.require_audience = true;
1947
1948 let validator = OAuthValidator::new(config, reqwest::Client::new());
1949
1950 let now = chrono::Utc::now().timestamp() as u64;
1952 let claims = serde_json::json!({
1953 "sub": "user-123",
1954 "iss": "https://auth.example.com",
1955 "exp": now + 3600,
1956 "iat": now,
1957 "nbf": now - 10,
1958 "scope": "tools.call"
1959 });
1960
1961 let token = sign_test_jwt(&claims, kid);
1962 let auth_header = format!("Bearer {token}");
1963
1964 let err = validator
1965 .validate_token(&auth_header)
1966 .await
1967 .expect_err("missing aud with require_audience must be rejected");
1968 assert!(
1969 matches!(err, OAuthError::JwtError(_) | OAuthError::MissingAudience),
1970 "expected audience rejection, got: {err}"
1971 );
1972 }
1973
1974 #[tokio::test]
1975 async fn test_e2e_bearer_case_insensitive() {
1976 let kid = "test-key-1";
1977 let jwks = test_jwks_json(kid);
1978 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1979 return;
1980 };
1981
1982 let config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
1983 let validator = OAuthValidator::new(config, reqwest::Client::new());
1984
1985 let token = sign_test_jwt(&valid_claims(), kid);
1986
1987 let auth_header = format!("BEARER {token}");
1989 let result = validator.validate_token(&auth_header).await;
1990 assert!(
1991 result.is_ok(),
1992 "BEARER (uppercase) must be accepted: {result:?}"
1993 );
1994 }
1995
1996 #[tokio::test]
1997 async fn test_e2e_not_before_future_rejected() {
1998 let kid = "test-key-1";
1999 let jwks = test_jwks_json(kid);
2000 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
2001 return;
2002 };
2003
2004 let config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
2005 let validator = OAuthValidator::new(config, reqwest::Client::new());
2006
2007 let now = chrono::Utc::now().timestamp() as u64;
2008 let mut claims = valid_claims();
2009 claims["nbf"] = serde_json::json!(now + 600); let token = sign_test_jwt(&claims, kid);
2012 let auth_header = format!("Bearer {token}");
2013
2014 let err = validator
2015 .validate_token(&auth_header)
2016 .await
2017 .expect_err("token with future nbf must be rejected");
2018 assert!(
2019 matches!(err, OAuthError::JwtError(_)),
2020 "expected JwtError for nbf in the future, got: {err}"
2021 );
2022 }
2023
2024 #[tokio::test]
2025 async fn test_e2e_dpop_required_but_no_cnf_jkt_rejected() {
2026 let kid = "test-key-1";
2027 let jwks = test_jwks_json(kid);
2028 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
2029 return;
2030 };
2031
2032 let mut config = test_oauth_config(format!("{base_url}/.well-known/jwks.json"));
2033 config.dpop_mode = DpopMode::Required;
2034
2035 let validator = OAuthValidator::new(config, reqwest::Client::new());
2036
2037 let token = sign_test_jwt(&valid_claims(), kid);
2039 let auth_header = format!("Bearer {token}");
2040
2041 let err = validator
2042 .validate_token(&auth_header)
2043 .await
2044 .expect_err("DPoP required but no cnf.jkt must be rejected");
2045 assert!(
2046 matches!(err, OAuthError::InvalidDpopProof(_)),
2047 "expected InvalidDpopProof, got: {err}"
2048 );
2049 }
2050
2051 #[test]
2054 fn test_contains_control_chars_rejects_dangerous_chars() {
2055 assert!(!contains_control_chars("normal-user@example.com"));
2057 assert!(!contains_control_chars("admin"));
2058 assert!(!contains_control_chars(""));
2059
2060 assert!(contains_control_chars("user\x00name")); assert!(contains_control_chars("user\x1bname")); assert!(contains_control_chars("user\x07name")); assert!(contains_control_chars("user\u{200B}name")); assert!(contains_control_chars("user\u{202E}name")); assert!(contains_control_chars("user\u{FEFF}name")); assert!(contains_control_chars("line1\nline2")); assert!(contains_control_chars("col1\tcol2")); assert!(contains_control_chars("x\u{00AD}y")); assert!(contains_control_chars("x\u{FFF9}y")); }
2080
2081 #[test]
2086 fn test_decode_unreserved_percent_decodes_unreserved_chars() {
2087 assert_eq!(decode_unreserved_percent("%2D"), "-");
2088 assert_eq!(decode_unreserved_percent("%2E"), ".");
2089 assert_eq!(decode_unreserved_percent("%5F"), "_");
2090 assert_eq!(decode_unreserved_percent("%7E"), "~");
2091 assert_eq!(decode_unreserved_percent("%41"), "A");
2092 assert_eq!(decode_unreserved_percent("%61"), "a");
2093 assert_eq!(decode_unreserved_percent("%30"), "0");
2094 }
2095
2096 #[test]
2097 fn test_decode_unreserved_percent_keeps_reserved_encoded() {
2098 assert_eq!(decode_unreserved_percent("%2F"), "%2F"); assert_eq!(decode_unreserved_percent("%40"), "%40"); assert_eq!(decode_unreserved_percent("%3A"), "%3A"); assert_eq!(decode_unreserved_percent("%00"), "%00"); assert_eq!(decode_unreserved_percent("%20"), "%20"); assert_eq!(decode_unreserved_percent("%3F"), "%3F"); assert_eq!(decode_unreserved_percent("%23"), "%23"); }
2106
2107 #[test]
2108 fn test_decode_unreserved_percent_normalizes_hex_case() {
2109 assert_eq!(decode_unreserved_percent("%2d"), "-");
2111 assert_eq!(decode_unreserved_percent("%7e"), "~");
2112 assert_eq!(decode_unreserved_percent("%2f"), "%2F");
2114 assert_eq!(decode_unreserved_percent("%3a"), "%3A");
2115 }
2116
2117 #[test]
2118 fn test_decode_unreserved_percent_incomplete_sequences() {
2119 assert_eq!(decode_unreserved_percent("foo%"), "foo%");
2120 assert_eq!(decode_unreserved_percent("foo%2"), "foo%2");
2121 assert_eq!(decode_unreserved_percent("%"), "%");
2122 assert_eq!(decode_unreserved_percent("%G0"), "%G0"); }
2124
2125 #[test]
2126 fn test_decode_unreserved_percent_mixed_content() {
2127 assert_eq!(
2128 decode_unreserved_percent("foo%2Dbar%2Fbaz"),
2129 "foo-bar%2Fbaz"
2130 );
2131 assert_eq!(decode_unreserved_percent(""), "");
2132 assert_eq!(decode_unreserved_percent("no-encoding"), "no-encoding");
2133 assert_eq!(decode_unreserved_percent("a%2Db%2Ec%5Fd%7Ee"), "a-b.c_d~e");
2134 }
2135
2136 }