1use std::{
17 collections::HashMap,
18 path::PathBuf,
19 time::{Duration, Instant},
20};
21
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
23use serde::Deserialize;
24use tokio::sync::RwLock;
25
26use crate::auth::{AuthIdentity, AuthMethod};
27
28fn evaluate_oauth_redirect(
51 attempt: &reqwest::redirect::Attempt<'_>,
52 allow_http: bool,
53) -> Result<(), String> {
54 let prev_https = attempt
55 .previous()
56 .last()
57 .is_some_and(|prev| prev.scheme() == "https");
58 let target_url = attempt.url();
59 let dest_scheme = target_url.scheme();
60 if dest_scheme != "https" {
61 if prev_https {
62 return Err("redirect downgrades https -> http".to_owned());
63 }
64 if !allow_http || dest_scheme != "http" {
65 return Err("redirect to non-HTTP(S) URL refused".to_owned());
66 }
67 }
68 if let Some(reason) = crate::ssrf::redirect_target_reason(target_url) {
69 return Err(format!("redirect target forbidden: {reason}"));
70 }
71 if attempt.previous().len() >= 2 {
72 return Err("too many redirects (max 2)".to_owned());
73 }
74 Ok(())
75}
76
/// HTTP client for outbound OAuth calls (token, introspection and revocation proxying).
///
/// Wraps a `reqwest::Client` whose redirects are vetted by `evaluate_oauth_redirect`.
/// Build it with [`OauthHttpClient::with_config`] so `ca_cert_path` and
/// `allow_http_oauth_urls` from [`OAuthConfig`] are applied.
#[derive(Clone)]
pub struct OauthHttpClient {
    // Used directly by the token / admin proxy handlers in this module.
    inner: reqwest::Client,
}
121
122impl OauthHttpClient {
123 pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
141 Self::build(Some(config))
142 }
143
144 #[deprecated(
167 since = "1.2.1",
168 note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
169 )]
170 pub fn new() -> Result<Self, crate::error::McpxError> {
171 Self::build(None)
172 }
173
174 fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
177 let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);
178
179 let mut builder = reqwest::Client::builder()
180 .connect_timeout(Duration::from_secs(10))
181 .timeout(Duration::from_secs(30))
182 .redirect(reqwest::redirect::Policy::custom(move |attempt| {
183 match evaluate_oauth_redirect(&attempt, allow_http) {
193 Ok(()) => attempt.follow(),
194 Err(reason) => {
195 tracing::warn!(
196 reason = %reason,
197 target = %attempt.url(),
198 "oauth redirect rejected"
199 );
200 attempt.error(reason)
201 }
202 }
203 }));
204
205 if let Some(cfg) = config
206 && let Some(ref ca_path) = cfg.ca_cert_path
207 {
208 let pem = std::fs::read(ca_path).map_err(|e| {
213 crate::error::McpxError::Startup(format!(
214 "oauth http client: read ca_cert_path {}: {e}",
215 ca_path.display()
216 ))
217 })?;
218 let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
219 crate::error::McpxError::Startup(format!(
220 "oauth http client: parse ca_cert_path {}: {e}",
221 ca_path.display()
222 ))
223 })?;
224 builder = builder.add_root_certificate(cert);
225 }
226
227 let inner = builder.build().map_err(|e| {
228 crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
229 })?;
230 Ok(Self { inner })
231 }
232
233 #[doc(hidden)]
239 pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
240 self.inner.get(url).send().await
241 }
242}
243
244impl std::fmt::Debug for OauthHttpClient {
245 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 f.debug_struct("OauthHttpClient").finish_non_exhaustive()
247 }
248}
249
/// OAuth 2.0 resource-server configuration: JWT validation settings plus the optional
/// authorization-code proxy and token-exchange settings.
///
/// [`OAuthConfig::validate`] checks every configured URL. It enforces HTTPS unless
/// `allow_http_oauth_urls` is set, rejects embedded credentials, and rejects literal IPs
/// forbidden by the SSRF guard.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
    /// Expected `iss` claim. The JWT validator enforces it, and it is advertised as `issuer`
    /// in the authorization-server metadata.
    pub issuer: String,
    /// Expected audience. A token is accepted when its `aud` contains this value or its
    /// `azp` equals it.
    pub audience: String,
    /// URL of the JSON Web Key Set used to verify token signatures.
    pub jwks_uri: String,
    /// Scope -> role mappings, used only when `role_claim` is unset. The first mapping (in
    /// this order) whose scope appears in the token's `scope` claim wins.
    #[serde(default)]
    pub scopes: Vec<ScopeMapping>,
    /// Dot-separated path to a claim holding role values, either a whitespace-separated
    /// string or an array of strings. When set, `role_mappings` replaces `scopes` for role
    /// resolution.
    pub role_claim: Option<String>,
    /// Claim value -> role mappings used together with `role_claim`; the first match wins.
    #[serde(default)]
    pub role_mappings: Vec<RoleMapping>,
    /// JWKS cache lifetime as a `humantime` duration string (default `"10m"`). An
    /// unparseable value falls back to 10 minutes.
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl: String,
    /// Optional authorization-code proxy to an upstream identity provider.
    pub proxy: Option<OAuthProxyConfig>,
    /// Optional token-exchange settings.
    pub token_exchange: Option<TokenExchangeConfig>,
    /// PEM file with an additional trusted root CA for JWKS and OAuth HTTP traffic.
    #[serde(default)]
    pub ca_cert_path: Option<PathBuf>,
    /// Allow `http://` for the configured OAuth URLs and for redirects that do not
    /// downgrade from HTTPS. Strongly discouraged outside development.
    #[serde(default)]
    pub allow_http_oauth_urls: bool,
    /// Maximum number of keys accepted from one JWKS document (default 256). A larger set
    /// is rejected and the key cache is left unchanged.
    #[serde(default = "default_max_jwks_keys")]
    pub max_jwks_keys: usize,
}
326
/// Serde default for [`OAuthConfig::jwks_cache_ttl`]: ten minutes, in `humantime` syntax.
fn default_jwks_cache_ttl() -> String {
    String::from("10m")
}
330
/// Serde default for [`OAuthConfig::max_jwks_keys`].
const fn default_max_jwks_keys() -> usize {
    const DEFAULT_MAX_JWKS_KEYS: usize = 256;
    DEFAULT_MAX_JWKS_KEYS
}
334
335impl Default for OAuthConfig {
336 fn default() -> Self {
337 Self {
338 issuer: String::new(),
339 audience: String::new(),
340 jwks_uri: String::new(),
341 scopes: Vec::new(),
342 role_claim: None,
343 role_mappings: Vec::new(),
344 jwks_cache_ttl: default_jwks_cache_ttl(),
345 proxy: None,
346 token_exchange: None,
347 ca_cert_path: None,
348 allow_http_oauth_urls: false,
349 max_jwks_keys: default_max_jwks_keys(),
350 }
351 }
352}
353
354impl OAuthConfig {
355 pub fn builder(
361 issuer: impl Into<String>,
362 audience: impl Into<String>,
363 jwks_uri: impl Into<String>,
364 ) -> OAuthConfigBuilder {
365 OAuthConfigBuilder {
366 inner: Self {
367 issuer: issuer.into(),
368 audience: audience.into(),
369 jwks_uri: jwks_uri.into(),
370 ..Self::default()
371 },
372 }
373 }
374
375 pub fn validate(&self) -> Result<(), crate::error::McpxError> {
391 let allow_http = self.allow_http_oauth_urls;
392 let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
393 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
394 return Err(crate::error::McpxError::Config(format!(
395 "oauth.issuer forbidden ({reason})"
396 )));
397 }
398 let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
399 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
400 return Err(crate::error::McpxError::Config(format!(
401 "oauth.jwks_uri forbidden ({reason})"
402 )));
403 }
404 if let Some(proxy) = &self.proxy {
405 let url = check_oauth_url(
406 "oauth.proxy.authorize_url",
407 &proxy.authorize_url,
408 allow_http,
409 )?;
410 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
411 return Err(crate::error::McpxError::Config(format!(
412 "oauth.proxy.authorize_url forbidden ({reason})"
413 )));
414 }
415 let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
416 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
417 return Err(crate::error::McpxError::Config(format!(
418 "oauth.proxy.token_url forbidden ({reason})"
419 )));
420 }
421 if let Some(url) = &proxy.introspection_url {
422 let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
423 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
424 return Err(crate::error::McpxError::Config(format!(
425 "oauth.proxy.introspection_url forbidden ({reason})"
426 )));
427 }
428 }
429 if let Some(url) = &proxy.revocation_url {
430 let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
431 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
432 return Err(crate::error::McpxError::Config(format!(
433 "oauth.proxy.revocation_url forbidden ({reason})"
434 )));
435 }
436 }
437 }
438 if let Some(tx) = &self.token_exchange {
439 let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
440 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
441 return Err(crate::error::McpxError::Config(format!(
442 "oauth.token_exchange.token_url forbidden ({reason})"
443 )));
444 }
445 }
446 Ok(())
447 }
448}
449
450fn check_oauth_url(
457 field: &str,
458 raw: &str,
459 allow_http: bool,
460) -> Result<url::Url, crate::error::McpxError> {
461 let parsed = url::Url::parse(raw).map_err(|e| {
462 crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
463 })?;
464 if !parsed.username().is_empty() || parsed.password().is_some() {
465 return Err(crate::error::McpxError::Config(format!(
466 "{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
467 )));
468 }
469 match parsed.scheme() {
470 "https" => Ok(parsed),
471 "http" if allow_http => Ok(parsed),
472 "http" => Err(crate::error::McpxError::Config(format!(
473 "{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
474 to override - strongly discouraged in production)"
475 ))),
476 other => Err(crate::error::McpxError::Config(format!(
477 "{field}: must use https scheme (got {other:?})"
478 ))),
479 }
480}
481
/// Fluent builder for [`OAuthConfig`], obtained from [`OAuthConfig::builder`].
///
/// Each setter overwrites (or, for `scope` / `role_mapping`, appends to) the matching
/// field; [`OAuthConfigBuilder::build`] returns the config without validating it.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthConfigBuilder {
    // Config under construction.
    inner: OAuthConfig,
}
492
493impl OAuthConfigBuilder {
494 pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
496 self.inner.scopes = scopes;
497 self
498 }
499
500 pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
502 self.inner.scopes.push(ScopeMapping {
503 scope: scope.into(),
504 role: role.into(),
505 });
506 self
507 }
508
509 pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
512 self.inner.role_claim = Some(claim.into());
513 self
514 }
515
516 pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
518 self.inner.role_mappings = mappings;
519 self
520 }
521
522 pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
525 self.inner.role_mappings.push(RoleMapping {
526 claim_value: claim_value.into(),
527 role: role.into(),
528 });
529 self
530 }
531
532 pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
535 self.inner.jwks_cache_ttl = ttl.into();
536 self
537 }
538
539 pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
542 self.inner.proxy = Some(proxy);
543 self
544 }
545
546 pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
548 self.inner.token_exchange = Some(token_exchange);
549 self
550 }
551
552 pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
557 self.inner.ca_cert_path = Some(path.into());
558 self
559 }
560
561 pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
567 self.inner.allow_http_oauth_urls = allow;
568 self
569 }
570
571 #[must_use]
573 pub fn build(self) -> OAuthConfig {
574 self.inner
575 }
576}
577
/// Maps an OAuth scope to a server role (see [`OAuthConfig::scopes`]).
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ScopeMapping {
    /// Scope that must appear in the token's whitespace-separated `scope` claim.
    pub scope: String,
    /// Role granted when the scope is present.
    pub role: String,
}
587
/// Maps a value of the configured `role_claim` to a server role.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleMapping {
    /// Value that must appear among the values found at `role_claim`.
    pub claim_value: String,
    /// Role granted when the value is present.
    pub role: String,
}
599
/// Settings for exchanging an incoming token at an upstream token endpoint.
///
/// NOTE(review): the exchange request itself is not in this part of the file; the notes
/// on `client_cert` and `audience` below are assumptions to confirm against that code.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenExchangeConfig {
    /// Upstream token endpoint; `OAuthConfig::validate` checks it like the other OAuth URLs.
    pub token_url: String,
    /// Client identifier presented to the token endpoint.
    pub client_id: String,
    /// Optional client secret, held as a `SecretString`.
    pub client_secret: Option<secrecy::SecretString>,
    /// Optional client certificate; presumably used for mTLS client authentication.
    pub client_cert: Option<ClientCertConfig>,
    /// Presumably the audience requested for the exchanged token.
    pub audience: String,
}
626
627impl TokenExchangeConfig {
628 #[must_use]
630 pub fn new(
631 token_url: String,
632 client_id: String,
633 client_secret: Option<secrecy::SecretString>,
634 client_cert: Option<ClientCertConfig>,
635 audience: String,
636 ) -> Self {
637 Self {
638 token_url,
639 client_id,
640 client_secret,
641 client_cert,
642 audience,
643 }
644 }
645}
646
/// Filesystem locations of a client certificate and its private key.
///
/// NOTE(review): the file format (presumably PEM) is decided where these are loaded, which
/// is not in this part of the file; confirm.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ClientCertConfig {
    /// Path to the client certificate.
    pub cert_path: PathBuf,
    /// Path to the matching private key.
    pub key_path: PathBuf,
}
657
/// Token returned by a token exchange.
///
/// The fields are named after the RFC 8693 token-exchange response parameters.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ExchangedToken {
    /// The issued token.
    pub access_token: String,
    /// Lifetime hint, when the server sends one (seconds, per RFC 8693).
    pub expires_in: Option<u64>,
    /// Token type URI of the issued token, when present.
    pub issued_token_type: Option<String>,
}
670
/// Authorization-code proxy to an upstream identity provider.
///
/// Used by `handle_authorize`, `handle_token`, `handle_register`, `handle_introspect` and
/// `handle_revoke`. Those handlers forward to the URLs configured here, substituting
/// `client_id`.
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct OAuthProxyConfig {
    /// Upstream authorization endpoint; `handle_authorize` redirects here.
    pub authorize_url: String,
    /// Upstream token endpoint; `handle_token` POSTs here.
    pub token_url: String,
    /// Client identifier substituted into proxied requests and returned by
    /// `handle_register`.
    pub client_id: String,
    /// Optional secret, appended as a `client_secret` form parameter to proxied
    /// token / introspection / revocation requests.
    pub client_secret: Option<secrecy::SecretString>,
    /// Upstream introspection endpoint; `handle_introspect` answers 404 when unset.
    #[serde(default)]
    pub introspection_url: Option<String>,
    /// Upstream revocation endpoint; `handle_revoke` answers 404 when unset.
    #[serde(default)]
    pub revocation_url: Option<String>,
    /// Advertise `introspection_endpoint` / `revocation_endpoint` in the
    /// authorization-server metadata, each only if its URL is configured.
    #[serde(default)]
    pub expose_admin_endpoints: bool,
}
714
715impl OAuthProxyConfig {
716 pub fn builder(
724 authorize_url: impl Into<String>,
725 token_url: impl Into<String>,
726 client_id: impl Into<String>,
727 ) -> OAuthProxyConfigBuilder {
728 OAuthProxyConfigBuilder {
729 inner: Self {
730 authorize_url: authorize_url.into(),
731 token_url: token_url.into(),
732 client_id: client_id.into(),
733 ..Self::default()
734 },
735 }
736 }
737}
738
/// Fluent builder for [`OAuthProxyConfig`], obtained from [`OAuthProxyConfig::builder`].
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthProxyConfigBuilder {
    // Proxy config under construction.
    inner: OAuthProxyConfig,
}
749
750impl OAuthProxyConfigBuilder {
751 pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
753 self.inner.client_secret = Some(secret);
754 self
755 }
756
757 pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
761 self.inner.introspection_url = Some(url.into());
762 self
763 }
764
765 pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
769 self.inner.revocation_url = Some(url.into());
770 self
771 }
772
773 pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
781 self.inner.expose_admin_endpoints = expose;
782 self
783 }
784
785 #[must_use]
787 pub fn build(self) -> OAuthProxyConfig {
788 self.inner
789 }
790}
791
/// Output of `build_key_cache`, with each key paired with the algorithm its JWK declares:
/// 0. keys indexed by `kid`;
/// 1. keys that carry no `kid`.
type JwksKeyCache = (
    HashMap<String, (Algorithm, DecodingKey)>,
    Vec<(Algorithm, DecodingKey)>,
);
803
/// A fetched key set together with its freshness metadata.
struct CachedKeys {
    /// Keys addressable by `kid`.
    keys: HashMap<String, (Algorithm, DecodingKey)>,
    /// Keys without a `kid`; searched by algorithm when no named key matches.
    unnamed_keys: Vec<(Algorithm, DecodingKey)>,
    /// When this set was stored.
    fetched_at: Instant,
    /// How long after `fetched_at` the set counts as fresh.
    ttl: Duration,
}
812
813impl CachedKeys {
814 fn is_expired(&self) -> bool {
815 self.fetched_at.elapsed() >= self.ttl
816 }
817}
818
/// Validates JWT bearer tokens against a lazily fetched, TTL-cached JWKS.
///
/// It holds:
/// * the `reqwest` client used for JWKS fetches;
/// * the `jsonwebtoken` validation template (issuer, `exp`, `nbf`);
/// * the audience and role-mapping settings copied from [`OAuthConfig`].
#[allow(
    missing_debug_implementations,
    reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
)]
#[non_exhaustive]
pub struct JwksCache {
    /// Endpoint the key set is fetched from.
    jwks_uri: String,
    /// How long a fetched key set is treated as fresh.
    ttl: Duration,
    /// Upper bound on keys accepted from one JWKS document.
    max_jwks_keys: usize,
    /// Current key set; `None` until the first successful fetch.
    inner: RwLock<Option<CachedKeys>>,
    /// HTTP client for JWKS requests (OAuth redirect policy, optional extra root CA).
    http: reqwest::Client,
    /// Base validation settings, cloned per token and narrowed to the token's algorithm.
    validation_template: Validation,
    /// Audience compared against the token's `aud` / `azp`.
    expected_audience: String,
    /// Scope -> role mappings, used when `role_claim` is `None`.
    scopes: Vec<ScopeMapping>,
    /// Dot-separated claim path for role resolution, if configured.
    role_claim: Option<String>,
    /// Claim value -> role mappings, used with `role_claim`.
    role_mappings: Vec<RoleMapping>,
    /// Time of the most recent refresh attempt, used to enforce the refresh cooldown.
    last_refresh_attempt: RwLock<Option<Instant>>,
    /// Serializes refresh attempts so at most one fetch is in flight.
    refresh_lock: tokio::sync::Mutex<()>,
}
851
/// Minimum spacing between JWKS refresh attempts, whether or not they succeed.
///
/// This bounds how often tokens that miss the cache (unknown `kid`, expired cache) can
/// trigger upstream fetches.
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_secs(10);
854
/// Signature algorithms a token header may name.
///
/// Only asymmetric algorithms are listed. Tokens naming an HMAC (`HS*`) algorithm are
/// therefore rejected before key lookup, so a public JWKS key is never used as an HMAC
/// secret. The chosen JWKS key must also record the same algorithm as the header.
const ACCEPTED_ALGS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
    Algorithm::EdDSA,
];
867
/// Why a bearer token was rejected by [`JwksCache::validate_token_with_reason`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
    /// `jsonwebtoken` reported `ExpiredSignature`: the token's `exp` has passed.
    Expired,
    /// Any other failure. Examples: malformed header, unaccepted algorithm, no matching
    /// key, bad signature or claims, audience mismatch, or no role mapping.
    Invalid,
}
877
878impl JwksCache {
879 pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
886 rustls::crypto::ring::default_provider()
889 .install_default()
890 .ok();
891 jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
892 .install_default()
893 .ok();
894
895 let ttl =
896 humantime::parse_duration(&config.jwks_cache_ttl).unwrap_or(Duration::from_mins(10));
897
898 let mut validation = Validation::new(Algorithm::RS256);
899 validation.validate_aud = false;
911 validation.set_issuer(&[&config.issuer]);
912 validation.set_required_spec_claims(&["exp", "iss"]);
913 validation.validate_exp = true;
914 validation.validate_nbf = true;
915
916 let allow_http = config.allow_http_oauth_urls;
917
918 let mut http_builder = reqwest::Client::builder()
919 .timeout(Duration::from_secs(10))
920 .connect_timeout(Duration::from_secs(3))
921 .redirect(reqwest::redirect::Policy::custom(move |attempt| {
922 match evaluate_oauth_redirect(&attempt, allow_http) {
932 Ok(()) => attempt.follow(),
933 Err(reason) => {
934 tracing::warn!(
935 reason = %reason,
936 target = %attempt.url(),
937 "oauth redirect rejected"
938 );
939 attempt.error(reason)
940 }
941 }
942 }));
943
944 if let Some(ref ca_path) = config.ca_cert_path {
945 let pem = std::fs::read(ca_path)?;
951 let cert = reqwest::tls::Certificate::from_pem(&pem)?;
952 http_builder = http_builder.add_root_certificate(cert);
953 }
954
955 let http = http_builder.build()?;
956
957 Ok(Self {
958 jwks_uri: config.jwks_uri.clone(),
959 ttl,
960 max_jwks_keys: config.max_jwks_keys,
961 inner: RwLock::new(None),
962 http,
963 validation_template: validation,
964 expected_audience: config.audience.clone(),
965 scopes: config.scopes.clone(),
966 role_claim: config.role_claim.clone(),
967 role_mappings: config.role_mappings.clone(),
968 last_refresh_attempt: RwLock::new(None),
969 refresh_lock: tokio::sync::Mutex::new(()),
970 })
971 }
972
973 pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
975 self.validate_token_with_reason(token).await.ok()
976 }
977
978 pub async fn validate_token_with_reason(
985 &self,
986 token: &str,
987 ) -> Result<AuthIdentity, JwtValidationFailure> {
988 let claims = self.decode_claims(token).await?;
989
990 self.check_audience(&claims)?;
991 let role = self.resolve_role(&claims)?;
992
993 let sub = claims.sub;
996 let name = claims
997 .extra
998 .get("preferred_username")
999 .and_then(|v| v.as_str())
1000 .map(String::from)
1001 .or_else(|| sub.clone())
1002 .or(claims.azp)
1003 .or(claims.client_id)
1004 .unwrap_or_else(|| "oauth-client".into());
1005
1006 Ok(AuthIdentity {
1007 name,
1008 role,
1009 method: AuthMethod::OAuthJwt,
1010 raw_token: None,
1011 sub,
1012 })
1013 }
1014
1015 async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
1027 let (key, alg) = self.select_jwks_key(token).await?;
1028
1029 let mut validation = self.validation_template.clone();
1033 validation.algorithms = vec![alg];
1034
1035 let token_owned = token.to_owned();
1038 let join =
1039 tokio::task::spawn_blocking(move || decode::<Claims>(&token_owned, &key, &validation))
1040 .await;
1041
1042 let decode_result = match join {
1043 Ok(r) => r,
1044 Err(join_err) => {
1045 core::hint::cold_path();
1046 tracing::error!(
1047 error = %join_err,
1048 "JWT decode task panicked or was cancelled"
1049 );
1050 return Err(JwtValidationFailure::Invalid);
1051 }
1052 };
1053
1054 decode_result.map(|td| td.claims).map_err(|e| {
1055 core::hint::cold_path();
1056 let failure = if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
1057 JwtValidationFailure::Expired
1058 } else {
1059 JwtValidationFailure::Invalid
1060 };
1061 tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
1062 failure
1063 })
1064 }
1065
1066 #[allow(clippy::cognitive_complexity)]
1075 async fn select_jwks_key(
1076 &self,
1077 token: &str,
1078 ) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
1079 let Ok(header) = decode_header(token) else {
1080 core::hint::cold_path();
1081 tracing::debug!("JWT header decode failed");
1082 return Err(JwtValidationFailure::Invalid);
1083 };
1084 let kid = header.kid.as_deref();
1085 tracing::debug!(alg = ?header.alg, kid = kid.unwrap_or("-"), "JWT header decoded");
1086
1087 if !ACCEPTED_ALGS.contains(&header.alg) {
1088 core::hint::cold_path();
1089 tracing::debug!(alg = ?header.alg, "JWT algorithm not accepted");
1090 return Err(JwtValidationFailure::Invalid);
1091 }
1092
1093 let Some(key) = self.find_key(kid, header.alg).await else {
1094 core::hint::cold_path();
1095 tracing::debug!(kid = kid.unwrap_or("-"), alg = ?header.alg, "no matching JWKS key found");
1096 return Err(JwtValidationFailure::Invalid);
1097 };
1098
1099 Ok((key, header.alg))
1100 }
1101
1102 fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
1110 let aud_ok = claims.aud.contains(&self.expected_audience)
1111 || claims
1112 .azp
1113 .as_deref()
1114 .is_some_and(|azp| azp == self.expected_audience);
1115 if aud_ok {
1116 return Ok(());
1117 }
1118 core::hint::cold_path();
1119 tracing::debug!(
1120 aud = ?claims.aud.0,
1121 azp = ?claims.azp,
1122 expected = %self.expected_audience,
1123 "JWT rejected: audience mismatch (neither aud nor azp match)"
1124 );
1125 Err(JwtValidationFailure::Invalid)
1126 }
1127
1128 fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
1134 if let Some(ref claim_path) = self.role_claim {
1135 let values = resolve_claim_path(&claims.extra, claim_path);
1136 return self
1137 .role_mappings
1138 .iter()
1139 .find(|m| values.contains(&m.claim_value.as_str()))
1140 .map(|m| m.role.clone())
1141 .ok_or(JwtValidationFailure::Invalid);
1142 }
1143
1144 let token_scopes: Vec<&str> = claims
1145 .scope
1146 .as_deref()
1147 .unwrap_or("")
1148 .split_whitespace()
1149 .collect();
1150
1151 self.scopes
1152 .iter()
1153 .find(|m| token_scopes.contains(&m.scope.as_str()))
1154 .map(|m| m.role.clone())
1155 .ok_or(JwtValidationFailure::Invalid)
1156 }
1157
1158 async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
1161 {
1163 let guard = self.inner.read().await;
1164 if let Some(cached) = guard.as_ref()
1165 && !cached.is_expired()
1166 && let Some(key) = lookup_key(cached, kid, alg)
1167 {
1168 return Some(key);
1169 }
1170 }
1171
1172 self.refresh_with_cooldown().await;
1174
1175 let guard = self.inner.read().await;
1176 guard
1177 .as_ref()
1178 .and_then(|cached| lookup_key(cached, kid, alg))
1179 }
1180
1181 async fn refresh_with_cooldown(&self) {
1186 let _guard = self.refresh_lock.lock().await;
1188
1189 {
1191 let last = self.last_refresh_attempt.read().await;
1192 if let Some(ts) = *last
1193 && ts.elapsed() < JWKS_REFRESH_COOLDOWN
1194 {
1195 tracing::debug!(
1196 elapsed_ms = ts.elapsed().as_millis(),
1197 cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
1198 "JWKS refresh skipped (cooldown active)"
1199 );
1200 return;
1201 }
1202 }
1203
1204 {
1207 let mut last = self.last_refresh_attempt.write().await;
1208 *last = Some(Instant::now());
1209 }
1210
1211 let _ = self.refresh_inner().await;
1213 }
1214
1215 async fn refresh_inner(&self) -> Result<(), String> {
1220 let Some(jwks) = self.fetch_jwks().await else {
1221 return Ok(());
1222 };
1223 let (keys, unnamed_keys) = match build_key_cache(&jwks, self.max_jwks_keys) {
1224 Ok(cache) => cache,
1225 Err(msg) => {
1226 tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
1227 return Err(msg);
1228 }
1229 };
1230
1231 tracing::debug!(
1232 named = keys.len(),
1233 unnamed = unnamed_keys.len(),
1234 "JWKS refreshed"
1235 );
1236
1237 let mut guard = self.inner.write().await;
1238 *guard = Some(CachedKeys {
1239 keys,
1240 unnamed_keys,
1241 fetched_at: Instant::now(),
1242 ttl: self.ttl,
1243 });
1244 Ok(())
1245 }
1246
1247 async fn fetch_jwks(&self) -> Option<JwkSet> {
1249 let resp = match self.http.get(&self.jwks_uri).send().await {
1250 Ok(resp) => resp,
1251 Err(e) => {
1252 tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
1253 return None;
1254 }
1255 };
1256 match resp.json::<JwkSet>().await {
1257 Ok(jwks) => Some(jwks),
1258 Err(e) => {
1259 tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to parse JWKS");
1260 None
1261 }
1262 }
1263 }
1264
1265 #[cfg(any(test, feature = "test-helpers"))]
1268 #[doc(hidden)]
1269 pub async fn __test_refresh_now(&self) -> Result<(), String> {
1270 let jwks = self
1271 .fetch_jwks()
1272 .await
1273 .ok_or_else(|| "failed to fetch or parse JWKS".to_owned())?;
1274 let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;
1275 let mut guard = self.inner.write().await;
1276 *guard = Some(CachedKeys {
1277 keys,
1278 unnamed_keys,
1279 fetched_at: Instant::now(),
1280 ttl: self.ttl,
1281 });
1282 Ok(())
1283 }
1284
1285 #[cfg(any(test, feature = "test-helpers"))]
1288 #[doc(hidden)]
1289 pub async fn __test_has_kid(&self, kid: &str) -> bool {
1290 let guard = self.inner.read().await;
1291 guard
1292 .as_ref()
1293 .is_some_and(|cache| cache.keys.contains_key(kid))
1294 }
1295}
1296
1297fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
1299 if jwks.keys.len() > max_keys {
1300 return Err(format!(
1301 "jwks_key_count_exceeds_cap: got {} keys, max is {}",
1302 jwks.keys.len(),
1303 max_keys
1304 ));
1305 }
1306 let mut keys = HashMap::new();
1307 let mut unnamed_keys = Vec::new();
1308 for jwk in &jwks.keys {
1309 let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
1310 continue;
1311 };
1312 let Some(alg) = jwk_algorithm(jwk) else {
1313 continue;
1314 };
1315 if let Some(ref kid) = jwk.common.key_id {
1316 keys.insert(kid.clone(), (alg, decoding_key));
1317 } else {
1318 unnamed_keys.push((alg, decoding_key));
1319 }
1320 }
1321 Ok((keys, unnamed_keys))
1322}
1323
1324fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
1326 if let Some(kid) = kid
1327 && let Some((cached_alg, key)) = cached.keys.get(kid)
1328 && *cached_alg == alg
1329 {
1330 return Some(key.clone());
1331 }
1332 cached
1334 .unnamed_keys
1335 .iter()
1336 .find(|(a, _)| *a == alg)
1337 .map(|(_, k)| k.clone())
1338}
1339
1340#[allow(clippy::wildcard_enum_match_arm)]
1342fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
1343 jwk.common.key_algorithm.and_then(|ka| match ka {
1344 jsonwebtoken::jwk::KeyAlgorithm::RS256 => Some(Algorithm::RS256),
1345 jsonwebtoken::jwk::KeyAlgorithm::RS384 => Some(Algorithm::RS384),
1346 jsonwebtoken::jwk::KeyAlgorithm::RS512 => Some(Algorithm::RS512),
1347 jsonwebtoken::jwk::KeyAlgorithm::ES256 => Some(Algorithm::ES256),
1348 jsonwebtoken::jwk::KeyAlgorithm::ES384 => Some(Algorithm::ES384),
1349 jsonwebtoken::jwk::KeyAlgorithm::PS256 => Some(Algorithm::PS256),
1350 jsonwebtoken::jwk::KeyAlgorithm::PS384 => Some(Algorithm::PS384),
1351 jsonwebtoken::jwk::KeyAlgorithm::PS512 => Some(Algorithm::PS512),
1352 jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
1353 _ => None,
1354 })
1355}
1356
1357fn resolve_claim_path<'a>(
1371 extra: &'a HashMap<String, serde_json::Value>,
1372 path: &str,
1373) -> Vec<&'a str> {
1374 let mut segments = path.split('.');
1375 let Some(first) = segments.next() else {
1376 return Vec::new();
1377 };
1378
1379 let mut current: Option<&serde_json::Value> = extra.get(first);
1380
1381 for segment in segments {
1382 current = current.and_then(|v| v.get(segment));
1383 }
1384
1385 match current {
1386 Some(serde_json::Value::String(s)) => s.split_whitespace().collect(),
1387 Some(serde_json::Value::Array(arr)) => arr.iter().filter_map(|v| v.as_str()).collect(),
1388 _ => Vec::new(),
1389 }
1390}
1391
/// JWT claims deserialized by `jsonwebtoken::decode`.
///
/// `exp`, `nbf` and `iss` are enforced by the `Validation` template. They are not named
/// here, so they land in `extra`.
#[derive(Debug, Deserialize)]
struct Claims {
    /// Subject; also a fallback for the identity name.
    sub: Option<String>,
    /// Audience(s); a missing claim becomes an empty list.
    #[serde(default)]
    aud: OneOrMany,
    /// Authorized party; accepted in place of a matching `aud`.
    azp: Option<String>,
    /// Client identifier; a fallback for the identity name.
    client_id: Option<String>,
    /// Whitespace-separated scopes used for scope -> role mapping.
    scope: Option<String>,
    /// All claims not captured by the named fields above (serde `flatten`).
    ///
    /// Used for the `preferred_username` and `role_claim` lookups. Because the named
    /// fields consume `sub`, `aud`, `azp`, `client_id` and `scope`, those claims never
    /// appear here.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
1415
/// A claim that may be a single string or an array of strings (as JWT `aud` may be),
/// normalised to a list.
#[derive(Debug, Default)]
struct OneOrMany(Vec<String>);
1419
1420impl OneOrMany {
1421 fn contains(&self, value: &str) -> bool {
1422 self.0.iter().any(|v| v == value)
1423 }
1424}
1425
1426impl<'de> Deserialize<'de> for OneOrMany {
1427 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
1428 use serde::de;
1429
1430 struct Visitor;
1431 impl<'de> de::Visitor<'de> for Visitor {
1432 type Value = OneOrMany;
1433 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1434 f.write_str("a string or array of strings")
1435 }
1436 fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
1437 Ok(OneOrMany(vec![v.to_owned()]))
1438 }
1439 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
1440 let mut v = Vec::new();
1441 while let Some(s) = seq.next_element::<String>()? {
1442 v.push(s);
1443 }
1444 Ok(OneOrMany(v))
1445 }
1446 }
1447 deserializer.deserialize_any(Visitor)
1448 }
1449}
1450
1451#[must_use]
1458pub fn looks_like_jwt(token: &str) -> bool {
1459 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
1460
1461 let mut parts = token.splitn(4, '.');
1462 let Some(header_b64) = parts.next() else {
1463 return false;
1464 };
1465 if parts.next().is_none() || parts.next().is_none() || parts.next().is_some() {
1467 return false;
1468 }
1469 let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
1471 return false;
1472 };
1473 let Ok(header) = serde_json::from_slice::<serde_json::Value>(&header_bytes) else {
1475 return false;
1476 };
1477 header.get("alg").is_some()
1478}
1479
1480#[must_use]
1490pub fn protected_resource_metadata(
1491 resource_url: &str,
1492 server_url: &str,
1493 config: &OAuthConfig,
1494) -> serde_json::Value {
1495 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
1500 let auth_server = server_url;
1501 serde_json::json!({
1502 "resource": resource_url,
1503 "authorization_servers": [auth_server],
1504 "scopes_supported": scopes,
1505 "bearer_methods_supported": ["header"]
1506 })
1507}
1508
1509#[must_use]
1514pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
1515 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
1516 let mut meta = serde_json::json!({
1517 "issuer": &config.issuer,
1518 "authorization_endpoint": format!("{server_url}/authorize"),
1519 "token_endpoint": format!("{server_url}/token"),
1520 "registration_endpoint": format!("{server_url}/register"),
1521 "response_types_supported": ["code"],
1522 "grant_types_supported": ["authorization_code", "refresh_token"],
1523 "code_challenge_methods_supported": ["S256"],
1524 "scopes_supported": scopes,
1525 "token_endpoint_auth_methods_supported": ["none"],
1526 });
1527 if let Some(proxy) = &config.proxy
1528 && proxy.expose_admin_endpoints
1529 && let Some(obj) = meta.as_object_mut()
1530 {
1531 if proxy.introspection_url.is_some() {
1532 obj.insert(
1533 "introspection_endpoint".into(),
1534 serde_json::Value::String(format!("{server_url}/introspect")),
1535 );
1536 }
1537 if proxy.revocation_url.is_some() {
1538 obj.insert(
1539 "revocation_endpoint".into(),
1540 serde_json::Value::String(format!("{server_url}/revoke")),
1541 );
1542 }
1543 }
1544 meta
1545}
1546
1547#[must_use]
1560pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
1561 use axum::{
1562 http::{StatusCode, header},
1563 response::IntoResponse,
1564 };
1565
1566 let upstream_query = replace_client_id(query, &proxy.client_id);
1568 let redirect_url = format!("{}?{upstream_query}", proxy.authorize_url);
1569
1570 (StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
1571}
1572
1573pub async fn handle_token(
1579 http: &OauthHttpClient,
1580 proxy: &OAuthProxyConfig,
1581 body: &str,
1582) -> axum::response::Response {
1583 use axum::{
1584 http::{StatusCode, header},
1585 response::IntoResponse,
1586 };
1587
1588 let mut upstream_body = replace_client_id(body, &proxy.client_id);
1590
1591 if let Some(ref secret) = proxy.client_secret {
1593 use std::fmt::Write;
1594
1595 use secrecy::ExposeSecret;
1596 let _ = write!(
1597 upstream_body,
1598 "&client_secret={}",
1599 urlencoding::encode(secret.expose_secret())
1600 );
1601 }
1602
1603 let result = http
1604 .inner
1605 .post(&proxy.token_url)
1606 .header("Content-Type", "application/x-www-form-urlencoded")
1607 .body(upstream_body)
1608 .send()
1609 .await;
1610
1611 match result {
1612 Ok(resp) => {
1613 let status =
1614 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
1615 let body_bytes = resp.bytes().await.unwrap_or_default();
1616 (
1617 status,
1618 [(header::CONTENT_TYPE, "application/json")],
1619 body_bytes,
1620 )
1621 .into_response()
1622 }
1623 Err(e) => {
1624 tracing::error!(error = %e, "OAuth token proxy request failed");
1625 (
1626 StatusCode::BAD_GATEWAY,
1627 [(header::CONTENT_TYPE, "application/json")],
1628 "{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
1629 )
1630 .into_response()
1631 }
1632 }
1633}
1634
1635#[must_use]
1642pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
1643 let mut resp = serde_json::json!({
1644 "client_id": proxy.client_id,
1645 "token_endpoint_auth_method": "none",
1646 });
1647 if let Some(uris) = body.get("redirect_uris")
1648 && let Some(obj) = resp.as_object_mut()
1649 {
1650 obj.insert("redirect_uris".into(), uris.clone());
1651 }
1652 if let Some(name) = body.get("client_name")
1653 && let Some(obj) = resp.as_object_mut()
1654 {
1655 obj.insert("client_name".into(), name.clone());
1656 }
1657 resp
1658}
1659
1660pub async fn handle_introspect(
1666 http: &OauthHttpClient,
1667 proxy: &OAuthProxyConfig,
1668 body: &str,
1669) -> axum::response::Response {
1670 let Some(ref url) = proxy.introspection_url else {
1671 return oauth_error_response(
1672 axum::http::StatusCode::NOT_FOUND,
1673 "not_supported",
1674 "introspection endpoint is not configured",
1675 );
1676 };
1677 proxy_oauth_admin_request(http, proxy, url, body).await
1678}
1679
1680pub async fn handle_revoke(
1687 http: &OauthHttpClient,
1688 proxy: &OAuthProxyConfig,
1689 body: &str,
1690) -> axum::response::Response {
1691 let Some(ref url) = proxy.revocation_url else {
1692 return oauth_error_response(
1693 axum::http::StatusCode::NOT_FOUND,
1694 "not_supported",
1695 "revocation endpoint is not configured",
1696 );
1697 };
1698 proxy_oauth_admin_request(http, proxy, url, body).await
1699}
1700
1701async fn proxy_oauth_admin_request(
1705 http: &OauthHttpClient,
1706 proxy: &OAuthProxyConfig,
1707 upstream_url: &str,
1708 body: &str,
1709) -> axum::response::Response {
1710 use axum::{
1711 http::{StatusCode, header},
1712 response::IntoResponse,
1713 };
1714
1715 let mut upstream_body = replace_client_id(body, &proxy.client_id);
1716 if let Some(ref secret) = proxy.client_secret {
1717 use std::fmt::Write;
1718
1719 use secrecy::ExposeSecret;
1720 let _ = write!(
1721 upstream_body,
1722 "&client_secret={}",
1723 urlencoding::encode(secret.expose_secret())
1724 );
1725 }
1726
1727 let result = http
1728 .inner
1729 .post(upstream_url)
1730 .header("Content-Type", "application/x-www-form-urlencoded")
1731 .body(upstream_body)
1732 .send()
1733 .await;
1734
1735 match result {
1736 Ok(resp) => {
1737 let status =
1738 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
1739 let content_type = resp
1740 .headers()
1741 .get(header::CONTENT_TYPE)
1742 .and_then(|v| v.to_str().ok())
1743 .unwrap_or("application/json")
1744 .to_owned();
1745 let body_bytes = resp.bytes().await.unwrap_or_default();
1746 (status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
1747 }
1748 Err(e) => {
1749 tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
1750 oauth_error_response(
1751 StatusCode::BAD_GATEWAY,
1752 "server_error",
1753 "upstream endpoint unreachable",
1754 )
1755 }
1756 }
1757}
1758
1759fn oauth_error_response(
1760 status: axum::http::StatusCode,
1761 error: &str,
1762 description: &str,
1763) -> axum::response::Response {
1764 use axum::{http::header, response::IntoResponse};
1765 let body = serde_json::json!({
1766 "error": error,
1767 "error_description": description,
1768 });
1769 (
1770 status,
1771 [(header::CONTENT_TYPE, "application/json")],
1772 body.to_string(),
1773 )
1774 .into_response()
1775}
1776
/// The `error` / `error_description` members of an OAuth 2.0 error response body
/// (RFC 6749 §5.2), as returned by the authorization server on a failed token exchange.
///
/// Only these two members are captured. No `deny_unknown_fields` is set, so any other members
/// in the upstream body are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Upstream error code, e.g. `invalid_grant`. It is logged verbatim, but callers only ever
    /// see it after it passes through `sanitize_oauth_error_code`.
    error: String,
    /// Optional human-readable description from the upstream. It is only written to logs.
    error_description: Option<String>,
}
1787
/// Maps an upstream OAuth error code onto a fixed allow-list before it is surfaced to clients.
///
/// Recognised codes are returned as `'static` strings, so no upstream-controlled text leaks
/// through. Anything else collapses to `"server_error"`, including unknown, differently-cased
/// or malformed codes.
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
    const FORWARDABLE: [&str; 8] = [
        "invalid_request",
        "invalid_client",
        "invalid_grant",
        "unauthorized_client",
        "unsupported_grant_type",
        "invalid_scope",
        "temporarily_unavailable",
        "invalid_target",
    ];
    FORWARDABLE
        .iter()
        .copied()
        .find(|&code| code == raw)
        .unwrap_or("server_error")
}
1810
1811pub async fn exchange_token(
1823 http: &OauthHttpClient,
1824 config: &TokenExchangeConfig,
1825 subject_token: &str,
1826) -> Result<ExchangedToken, crate::error::McpxError> {
1827 use secrecy::ExposeSecret;
1828
1829 let mut req = http
1830 .inner
1831 .post(&config.token_url)
1832 .header("Content-Type", "application/x-www-form-urlencoded")
1833 .header("Accept", "application/json");
1834
1835 if let Some(ref secret) = config.client_secret {
1837 use base64::Engine;
1838 let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
1839 "{}:{}",
1840 urlencoding::encode(&config.client_id),
1841 urlencoding::encode(secret.expose_secret()),
1842 ));
1843 req = req.header("Authorization", format!("Basic {credentials}"));
1844 }
1845 let form_body = build_exchange_form(config, subject_token);
1848
1849 let resp = req.body(form_body).send().await.map_err(|e| {
1850 tracing::error!(error = %e, "token exchange request failed");
1851 crate::error::McpxError::Auth("server_error".into())
1853 })?;
1854
1855 let status = resp.status();
1856 let body_bytes = resp.bytes().await.map_err(|e| {
1857 tracing::error!(error = %e, "failed to read token exchange response");
1858 crate::error::McpxError::Auth("server_error".into())
1859 })?;
1860
1861 if !status.is_success() {
1862 core::hint::cold_path();
1863 let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
1866 let short_code = parsed
1867 .as_ref()
1868 .map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
1869 if let Some(ref e) = parsed {
1870 tracing::warn!(
1871 status = %status,
1872 upstream_error = %e.error,
1873 upstream_error_description = e.error_description.as_deref().unwrap_or(""),
1874 client_code = %short_code,
1875 "token exchange rejected by authorization server",
1876 );
1877 } else {
1878 tracing::warn!(
1879 status = %status,
1880 client_code = %short_code,
1881 "token exchange rejected (unparseable upstream body)",
1882 );
1883 }
1884 return Err(crate::error::McpxError::Auth(short_code.into()));
1885 }
1886
1887 let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
1888 tracing::error!(error = %e, "failed to parse token exchange response");
1889 crate::error::McpxError::Auth("server_error".into())
1892 })?;
1893
1894 log_exchanged_token(&exchanged);
1895
1896 Ok(exchanged)
1897}
1898
1899fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
1902 let body = format!(
1903 "grant_type={}&subject_token={}&subject_token_type={}&requested_token_type={}&audience={}",
1904 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
1905 urlencoding::encode(subject_token),
1906 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
1907 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
1908 urlencoding::encode(&config.audience),
1909 );
1910 if config.client_secret.is_none() {
1911 format!(
1912 "{body}&client_id={}",
1913 urlencoding::encode(&config.client_id)
1914 )
1915 } else {
1916 body
1917 }
1918}
1919
1920fn log_exchanged_token(exchanged: &ExchangedToken) {
1923 use base64::Engine;
1924
1925 if !looks_like_jwt(&exchanged.access_token) {
1926 tracing::debug!(
1927 token_len = exchanged.access_token.len(),
1928 issued_token_type = ?exchanged.issued_token_type,
1929 expires_in = exchanged.expires_in,
1930 "exchanged token (opaque)",
1931 );
1932 return;
1933 }
1934 let Some(payload) = exchanged.access_token.split('.').nth(1) else {
1935 return;
1936 };
1937 let Ok(decoded) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload) else {
1938 return;
1939 };
1940 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&decoded) else {
1941 return;
1942 };
1943 tracing::debug!(
1944 sub = ?claims.get("sub"),
1945 aud = ?claims.get("aud"),
1946 azp = ?claims.get("azp"),
1947 iss = ?claims.get("iss"),
1948 expires_in = exchanged.expires_in,
1949 "exchanged token claims (JWT)",
1950 );
1951}
1952
1953fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
1955 let encoded_id = urlencoding::encode(upstream_client_id);
1956 let mut parts: Vec<String> = params
1957 .split('&')
1958 .filter(|p| !p.starts_with("client_id="))
1959 .map(String::from)
1960 .collect();
1961 parts.push(format!("client_id={encoded_id}"));
1962 parts.join("&")
1963}
1964
1965#[cfg(test)]
1966mod tests {
1967 use std::sync::Arc;
1968
1969 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
1970
1971 use super::*;
1972
1973 #[test]
1974 fn looks_like_jwt_valid() {
1975 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
1977 let payload = URL_SAFE_NO_PAD.encode(b"{}");
1978 let token = format!("{header}.{payload}.signature");
1979 assert!(looks_like_jwt(&token));
1980 }
1981
1982 #[test]
1983 fn looks_like_jwt_rejects_opaque_token() {
1984 assert!(!looks_like_jwt("dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4"));
1985 }
1986
1987 #[test]
1988 fn looks_like_jwt_rejects_two_segments() {
1989 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
1990 let token = format!("{header}.payload");
1991 assert!(!looks_like_jwt(&token));
1992 }
1993
1994 #[test]
1995 fn looks_like_jwt_rejects_four_segments() {
1996 assert!(!looks_like_jwt("a.b.c.d"));
1997 }
1998
1999 #[test]
2000 fn looks_like_jwt_rejects_no_alg() {
2001 let header = URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}");
2002 let payload = URL_SAFE_NO_PAD.encode(b"{}");
2003 let token = format!("{header}.{payload}.sig");
2004 assert!(!looks_like_jwt(&token));
2005 }
2006
2007 #[test]
2008 fn protected_resource_metadata_shape() {
2009 let config = OAuthConfig {
2010 issuer: "https://auth.example.com".into(),
2011 audience: "https://mcp.example.com/mcp".into(),
2012 jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
2013 scopes: vec![
2014 ScopeMapping {
2015 scope: "mcp:read".into(),
2016 role: "viewer".into(),
2017 },
2018 ScopeMapping {
2019 scope: "mcp:admin".into(),
2020 role: "ops".into(),
2021 },
2022 ],
2023 role_claim: None,
2024 role_mappings: vec![],
2025 jwks_cache_ttl: "10m".into(),
2026 proxy: None,
2027 token_exchange: None,
2028 ca_cert_path: None,
2029 allow_http_oauth_urls: false,
2030 max_jwks_keys: default_max_jwks_keys(),
2031 };
2032 let meta = protected_resource_metadata(
2033 "https://mcp.example.com/mcp",
2034 "https://mcp.example.com",
2035 &config,
2036 );
2037 assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
2038 assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
2039 assert_eq!(meta["scopes_supported"].as_array().unwrap().len(), 2);
2040 assert_eq!(meta["bearer_methods_supported"][0], "header");
2041 }
2042
2043 fn validation_https_config() -> OAuthConfig {
2048 OAuthConfig::builder(
2049 "https://auth.example.com",
2050 "mcp",
2051 "https://auth.example.com/.well-known/jwks.json",
2052 )
2053 .build()
2054 }
2055
2056 #[test]
2057 fn validate_accepts_all_https_urls() {
2058 let cfg = validation_https_config();
2059 cfg.validate().expect("all-HTTPS config must validate");
2060 }
2061
2062 #[test]
2063 fn validate_rejects_http_jwks_uri() {
2064 let mut cfg = validation_https_config();
2065 cfg.jwks_uri = "http://auth.example.com/.well-known/jwks.json".into();
2066 let err = cfg.validate().expect_err("http jwks_uri must be rejected");
2067 let msg = err.to_string();
2068 assert!(
2069 msg.contains("oauth.jwks_uri") && msg.contains("https"),
2070 "error must reference offending field + scheme requirement; got {msg:?}"
2071 );
2072 }
2073
2074 #[test]
2075 fn validate_rejects_http_proxy_authorize_url() {
2076 let mut cfg = validation_https_config();
2077 cfg.proxy = Some(
2078 OAuthProxyConfig::builder(
2079 "http://idp.example.com/authorize", "https://idp.example.com/token",
2081 "client",
2082 )
2083 .build(),
2084 );
2085 let err = cfg
2086 .validate()
2087 .expect_err("http authorize_url must be rejected");
2088 assert!(
2089 err.to_string().contains("oauth.proxy.authorize_url"),
2090 "error must reference proxy.authorize_url; got {err}"
2091 );
2092 }
2093
2094 #[test]
2095 fn validate_rejects_http_proxy_token_url() {
2096 let mut cfg = validation_https_config();
2097 cfg.proxy = Some(
2098 OAuthProxyConfig::builder(
2099 "https://idp.example.com/authorize",
2100 "http://idp.example.com/token", "client",
2102 )
2103 .build(),
2104 );
2105 let err = cfg.validate().expect_err("http token_url must be rejected");
2106 assert!(
2107 err.to_string().contains("oauth.proxy.token_url"),
2108 "error must reference proxy.token_url; got {err}"
2109 );
2110 }
2111
2112 #[test]
2113 fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
2114 let mut cfg = validation_https_config();
2115 cfg.proxy = Some(
2116 OAuthProxyConfig::builder(
2117 "https://idp.example.com/authorize",
2118 "https://idp.example.com/token",
2119 "client",
2120 )
2121 .introspection_url("http://idp.example.com/introspect")
2122 .build(),
2123 );
2124 let err = cfg
2125 .validate()
2126 .expect_err("http introspection_url must be rejected");
2127 assert!(err.to_string().contains("oauth.proxy.introspection_url"));
2128
2129 let mut cfg = validation_https_config();
2130 cfg.proxy = Some(
2131 OAuthProxyConfig::builder(
2132 "https://idp.example.com/authorize",
2133 "https://idp.example.com/token",
2134 "client",
2135 )
2136 .revocation_url("http://idp.example.com/revoke")
2137 .build(),
2138 );
2139 let err = cfg
2140 .validate()
2141 .expect_err("http revocation_url must be rejected");
2142 assert!(err.to_string().contains("oauth.proxy.revocation_url"));
2143 }
2144
2145 #[test]
2146 fn validate_rejects_http_token_exchange_url() {
2147 let mut cfg = validation_https_config();
2148 cfg.token_exchange = Some(TokenExchangeConfig::new(
2149 "http://idp.example.com/token".into(), "client".into(),
2151 None,
2152 None,
2153 "downstream".into(),
2154 ));
2155 let err = cfg
2156 .validate()
2157 .expect_err("http token_exchange.token_url must be rejected");
2158 assert!(
2159 err.to_string().contains("oauth.token_exchange.token_url"),
2160 "error must reference token_exchange.token_url; got {err}"
2161 );
2162 }
2163
2164 #[test]
2165 fn validate_rejects_unparseable_url() {
2166 let mut cfg = validation_https_config();
2167 cfg.jwks_uri = "not a url".into();
2168 let err = cfg
2169 .validate()
2170 .expect_err("unparseable URL must be rejected");
2171 assert!(err.to_string().contains("invalid URL"));
2172 }
2173
2174 #[test]
2175 fn validate_rejects_non_http_scheme() {
2176 let mut cfg = validation_https_config();
2177 cfg.jwks_uri = "file:///etc/passwd".into();
2178 let err = cfg.validate().expect_err("file:// scheme must be rejected");
2179 let msg = err.to_string();
2180 assert!(
2181 msg.contains("must use https scheme") && msg.contains("file"),
2182 "error must reject non-http(s) schemes; got {msg:?}"
2183 );
2184 }
2185
2186 #[test]
2187 fn validate_accepts_http_with_escape_hatch() {
2188 let mut cfg = OAuthConfig::builder(
2193 "http://auth.local",
2194 "mcp",
2195 "http://auth.local/.well-known/jwks.json",
2196 )
2197 .allow_http_oauth_urls(true)
2198 .build();
2199 cfg.proxy = Some(
2200 OAuthProxyConfig::builder(
2201 "http://idp.local/authorize",
2202 "http://idp.local/token",
2203 "client",
2204 )
2205 .introspection_url("http://idp.local/introspect")
2206 .revocation_url("http://idp.local/revoke")
2207 .build(),
2208 );
2209 cfg.token_exchange = Some(TokenExchangeConfig::new(
2210 "http://idp.local/token".into(),
2211 "client".into(),
2212 None,
2213 None,
2214 "downstream".into(),
2215 ));
2216 cfg.validate()
2217 .expect("escape hatch must permit http on all URL fields");
2218 }
2219
2220 #[test]
2221 fn validate_with_escape_hatch_still_rejects_unparseable() {
2222 let mut cfg = validation_https_config();
2225 cfg.allow_http_oauth_urls = true;
2226 cfg.jwks_uri = "::not-a-url::".into();
2227 cfg.validate()
2228 .expect_err("escape hatch must NOT bypass URL parsing");
2229 }
2230
2231 #[tokio::test]
2232 async fn jwks_cache_rejects_redirect_downgrade_to_http() {
2233 rustls::crypto::ring::default_provider()
2248 .install_default()
2249 .ok();
2250
2251 let policy = reqwest::redirect::Policy::custom(|attempt| {
2252 if attempt.url().scheme() != "https" {
2253 attempt.error("redirect to non-HTTPS URL refused")
2254 } else if attempt.previous().len() >= 2 {
2255 attempt.error("too many redirects (max 2)")
2256 } else {
2257 attempt.follow()
2258 }
2259 });
2260 let client = reqwest::Client::builder()
2261 .timeout(Duration::from_secs(5))
2262 .connect_timeout(Duration::from_secs(3))
2263 .redirect(policy)
2264 .build()
2265 .expect("test client builds");
2266
2267 let mock = wiremock::MockServer::start().await;
2268 wiremock::Mock::given(wiremock::matchers::method("GET"))
2269 .and(wiremock::matchers::path("/jwks.json"))
2270 .respond_with(
2271 wiremock::ResponseTemplate::new(302)
2272 .insert_header("location", "http://example.invalid/jwks.json"),
2273 )
2274 .mount(&mock)
2275 .await;
2276
2277 let url = format!("{}/jwks.json", mock.uri());
2286 let err = client
2287 .get(&url)
2288 .send()
2289 .await
2290 .expect_err("redirect policy must reject scheme downgrade");
2291 let chain = format!("{err:#}");
2292 assert!(
2293 chain.contains("redirect to non-HTTPS URL refused")
2294 || chain.to_lowercase().contains("redirect"),
2295 "error must surface redirect-policy rejection; got {chain:?}"
2296 );
2297 }
2298
2299 use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
2304
2305 fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
2307 let mut rng = rsa::rand_core::OsRng;
2308 let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keypair generation");
2309 let private_pem = private_key
2310 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
2311 .expect("PKCS8 PEM export")
2312 .to_string();
2313
2314 let public_key = private_key.to_public_key();
2315 let n = URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be());
2316 let e = URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be());
2317
2318 let jwks = serde_json::json!({
2319 "keys": [{
2320 "kty": "RSA",
2321 "use": "sig",
2322 "alg": "RS256",
2323 "kid": kid,
2324 "n": n,
2325 "e": e
2326 }]
2327 });
2328
2329 (private_pem, jwks)
2330 }
2331
2332 fn mint_token(
2334 private_pem: &str,
2335 kid: &str,
2336 issuer: &str,
2337 audience: &str,
2338 subject: &str,
2339 scope: &str,
2340 ) -> String {
2341 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
2342 .expect("encoding key from PEM");
2343 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
2344 header.kid = Some(kid.into());
2345
2346 let now = jsonwebtoken::get_current_timestamp();
2347 let claims = serde_json::json!({
2348 "iss": issuer,
2349 "aud": audience,
2350 "sub": subject,
2351 "scope": scope,
2352 "exp": now + 3600,
2353 "iat": now,
2354 });
2355
2356 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
2357 }
2358
2359 fn test_config(jwks_uri: &str) -> OAuthConfig {
2360 OAuthConfig {
2361 issuer: "https://auth.test.local".into(),
2362 audience: "https://mcp.test.local/mcp".into(),
2363 jwks_uri: jwks_uri.into(),
2364 scopes: vec![
2365 ScopeMapping {
2366 scope: "mcp:read".into(),
2367 role: "viewer".into(),
2368 },
2369 ScopeMapping {
2370 scope: "mcp:admin".into(),
2371 role: "ops".into(),
2372 },
2373 ],
2374 role_claim: None,
2375 role_mappings: vec![],
2376 jwks_cache_ttl: "5m".into(),
2377 proxy: None,
2378 token_exchange: None,
2379 ca_cert_path: None,
2380 allow_http_oauth_urls: true,
2381 max_jwks_keys: default_max_jwks_keys(),
2382 }
2383 }
2384
2385 #[tokio::test]
2386 async fn valid_jwt_returns_identity() {
2387 let kid = "test-key-1";
2388 let (pem, jwks) = generate_test_keypair(kid);
2389
2390 let mock_server = wiremock::MockServer::start().await;
2391 wiremock::Mock::given(wiremock::matchers::method("GET"))
2392 .and(wiremock::matchers::path("/jwks.json"))
2393 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2394 .mount(&mock_server)
2395 .await;
2396
2397 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2398 let config = test_config(&jwks_uri);
2399 let cache = JwksCache::new(&config).unwrap();
2400
2401 let token = mint_token(
2402 &pem,
2403 kid,
2404 "https://auth.test.local",
2405 "https://mcp.test.local/mcp",
2406 "ci-bot",
2407 "mcp:read mcp:other",
2408 );
2409
2410 let identity = cache.validate_token(&token).await;
2411 assert!(identity.is_some(), "valid JWT should authenticate");
2412 let id = identity.unwrap();
2413 assert_eq!(id.name, "ci-bot");
2414 assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
2416 }
2417
2418 #[tokio::test]
2419 async fn wrong_issuer_rejected() {
2420 let kid = "test-key-2";
2421 let (pem, jwks) = generate_test_keypair(kid);
2422
2423 let mock_server = wiremock::MockServer::start().await;
2424 wiremock::Mock::given(wiremock::matchers::method("GET"))
2425 .and(wiremock::matchers::path("/jwks.json"))
2426 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2427 .mount(&mock_server)
2428 .await;
2429
2430 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2431 let config = test_config(&jwks_uri);
2432 let cache = JwksCache::new(&config).unwrap();
2433
2434 let token = mint_token(
2435 &pem,
2436 kid,
2437 "https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
2439 "attacker",
2440 "mcp:admin",
2441 );
2442
2443 assert!(cache.validate_token(&token).await.is_none());
2444 }
2445
2446 #[tokio::test]
2447 async fn wrong_audience_rejected() {
2448 let kid = "test-key-3";
2449 let (pem, jwks) = generate_test_keypair(kid);
2450
2451 let mock_server = wiremock::MockServer::start().await;
2452 wiremock::Mock::given(wiremock::matchers::method("GET"))
2453 .and(wiremock::matchers::path("/jwks.json"))
2454 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2455 .mount(&mock_server)
2456 .await;
2457
2458 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2459 let config = test_config(&jwks_uri);
2460 let cache = JwksCache::new(&config).unwrap();
2461
2462 let token = mint_token(
2463 &pem,
2464 kid,
2465 "https://auth.test.local",
2466 "https://wrong-audience.example.com", "attacker",
2468 "mcp:admin",
2469 );
2470
2471 assert!(cache.validate_token(&token).await.is_none());
2472 }
2473
2474 #[tokio::test]
2475 async fn expired_jwt_rejected() {
2476 let kid = "test-key-4";
2477 let (pem, jwks) = generate_test_keypair(kid);
2478
2479 let mock_server = wiremock::MockServer::start().await;
2480 wiremock::Mock::given(wiremock::matchers::method("GET"))
2481 .and(wiremock::matchers::path("/jwks.json"))
2482 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2483 .mount(&mock_server)
2484 .await;
2485
2486 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2487 let config = test_config(&jwks_uri);
2488 let cache = JwksCache::new(&config).unwrap();
2489
2490 let encoding_key =
2492 jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
2493 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
2494 header.kid = Some(kid.into());
2495 let now = jsonwebtoken::get_current_timestamp();
2496 let claims = serde_json::json!({
2497 "iss": "https://auth.test.local",
2498 "aud": "https://mcp.test.local/mcp",
2499 "sub": "expired-bot",
2500 "scope": "mcp:read",
2501 "exp": now - 120,
2502 "iat": now - 3720,
2503 });
2504 let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
2505
2506 assert!(cache.validate_token(&token).await.is_none());
2507 }
2508
2509 #[tokio::test]
2510 async fn no_matching_scope_rejected() {
2511 let kid = "test-key-5";
2512 let (pem, jwks) = generate_test_keypair(kid);
2513
2514 let mock_server = wiremock::MockServer::start().await;
2515 wiremock::Mock::given(wiremock::matchers::method("GET"))
2516 .and(wiremock::matchers::path("/jwks.json"))
2517 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2518 .mount(&mock_server)
2519 .await;
2520
2521 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2522 let config = test_config(&jwks_uri);
2523 let cache = JwksCache::new(&config).unwrap();
2524
2525 let token = mint_token(
2526 &pem,
2527 kid,
2528 "https://auth.test.local",
2529 "https://mcp.test.local/mcp",
2530 "limited-bot",
2531 "some:other:scope", );
2533
2534 assert!(cache.validate_token(&token).await.is_none());
2535 }
2536
2537 #[tokio::test]
2538 async fn wrong_signing_key_rejected() {
2539 let kid = "test-key-6";
2540 let (_pem, jwks) = generate_test_keypair(kid);
2541
2542 let (attacker_pem, _) = generate_test_keypair(kid);
2544
2545 let mock_server = wiremock::MockServer::start().await;
2546 wiremock::Mock::given(wiremock::matchers::method("GET"))
2547 .and(wiremock::matchers::path("/jwks.json"))
2548 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2549 .mount(&mock_server)
2550 .await;
2551
2552 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2553 let config = test_config(&jwks_uri);
2554 let cache = JwksCache::new(&config).unwrap();
2555
2556 let token = mint_token(
2558 &attacker_pem,
2559 kid,
2560 "https://auth.test.local",
2561 "https://mcp.test.local/mcp",
2562 "attacker",
2563 "mcp:admin",
2564 );
2565
2566 assert!(cache.validate_token(&token).await.is_none());
2567 }
2568
2569 #[tokio::test]
2570 async fn admin_scope_maps_to_ops_role() {
2571 let kid = "test-key-7";
2572 let (pem, jwks) = generate_test_keypair(kid);
2573
2574 let mock_server = wiremock::MockServer::start().await;
2575 wiremock::Mock::given(wiremock::matchers::method("GET"))
2576 .and(wiremock::matchers::path("/jwks.json"))
2577 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2578 .mount(&mock_server)
2579 .await;
2580
2581 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2582 let config = test_config(&jwks_uri);
2583 let cache = JwksCache::new(&config).unwrap();
2584
2585 let token = mint_token(
2586 &pem,
2587 kid,
2588 "https://auth.test.local",
2589 "https://mcp.test.local/mcp",
2590 "admin-bot",
2591 "mcp:admin",
2592 );
2593
2594 let id = cache
2595 .validate_token(&token)
2596 .await
2597 .expect("should authenticate");
2598 assert_eq!(id.role, "ops");
2599 assert_eq!(id.name, "admin-bot");
2600 }
2601
2602 #[tokio::test]
2603 async fn jwks_server_down_returns_none() {
2604 let config = test_config("http://127.0.0.1:1/jwks.json");
2606 let cache = JwksCache::new(&config).unwrap();
2607
2608 let kid = "orphan-key";
2609 let (pem, _) = generate_test_keypair(kid);
2610 let token = mint_token(
2611 &pem,
2612 kid,
2613 "https://auth.test.local",
2614 "https://mcp.test.local/mcp",
2615 "bot",
2616 "mcp:read",
2617 );
2618
2619 assert!(cache.validate_token(&token).await.is_none());
2620 }
2621
2622 #[test]
2627 fn resolve_claim_path_flat_string() {
2628 let mut extra = HashMap::new();
2629 extra.insert(
2630 "scope".into(),
2631 serde_json::Value::String("mcp:read mcp:admin".into()),
2632 );
2633 let values = resolve_claim_path(&extra, "scope");
2634 assert_eq!(values, vec!["mcp:read", "mcp:admin"]);
2635 }
2636
2637 #[test]
2638 fn resolve_claim_path_flat_array() {
2639 let mut extra = HashMap::new();
2640 extra.insert(
2641 "roles".into(),
2642 serde_json::json!(["mcp-admin", "mcp-viewer"]),
2643 );
2644 let values = resolve_claim_path(&extra, "roles");
2645 assert_eq!(values, vec!["mcp-admin", "mcp-viewer"]);
2646 }
2647
2648 #[test]
2649 fn resolve_claim_path_nested_keycloak() {
2650 let mut extra = HashMap::new();
2651 extra.insert(
2652 "realm_access".into(),
2653 serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]}),
2654 );
2655 let values = resolve_claim_path(&extra, "realm_access.roles");
2656 assert_eq!(values, vec!["uma_authorization", "mcp-admin"]);
2657 }
2658
2659 #[test]
2660 fn resolve_claim_path_missing_returns_empty() {
2661 let extra = HashMap::new();
2662 assert!(resolve_claim_path(&extra, "nonexistent.path").is_empty());
2663 }
2664
2665 #[test]
2666 fn resolve_claim_path_numeric_leaf_returns_empty() {
2667 let mut extra = HashMap::new();
2668 extra.insert("count".into(), serde_json::json!(42));
2669 assert!(resolve_claim_path(&extra, "count").is_empty());
2670 }
2671
2672 fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
2678 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
2679 .expect("encoding key from PEM");
2680 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
2681 header.kid = Some(kid.into());
2682 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
2683 }
2684
2685 fn test_config_with_role_claim(
2686 jwks_uri: &str,
2687 role_claim: &str,
2688 role_mappings: Vec<RoleMapping>,
2689 ) -> OAuthConfig {
2690 OAuthConfig {
2691 issuer: "https://auth.test.local".into(),
2692 audience: "https://mcp.test.local/mcp".into(),
2693 jwks_uri: jwks_uri.into(),
2694 scopes: vec![],
2695 role_claim: Some(role_claim.into()),
2696 role_mappings,
2697 jwks_cache_ttl: "5m".into(),
2698 proxy: None,
2699 token_exchange: None,
2700 ca_cert_path: None,
2701 allow_http_oauth_urls: true,
2702 max_jwks_keys: default_max_jwks_keys(),
2703 }
2704 }
2705
2706 #[tokio::test]
2707 async fn role_claim_keycloak_nested_array() {
2708 let kid = "test-role-1";
2709 let (pem, jwks) = generate_test_keypair(kid);
2710
2711 let mock_server = wiremock::MockServer::start().await;
2712 wiremock::Mock::given(wiremock::matchers::method("GET"))
2713 .and(wiremock::matchers::path("/jwks.json"))
2714 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2715 .mount(&mock_server)
2716 .await;
2717
2718 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2719 let config = test_config_with_role_claim(
2720 &jwks_uri,
2721 "realm_access.roles",
2722 vec![
2723 RoleMapping {
2724 claim_value: "mcp-admin".into(),
2725 role: "ops".into(),
2726 },
2727 RoleMapping {
2728 claim_value: "mcp-viewer".into(),
2729 role: "viewer".into(),
2730 },
2731 ],
2732 );
2733 let cache = JwksCache::new(&config).unwrap();
2734
2735 let now = jsonwebtoken::get_current_timestamp();
2736 let token = mint_token_with_claims(
2737 &pem,
2738 kid,
2739 &serde_json::json!({
2740 "iss": "https://auth.test.local",
2741 "aud": "https://mcp.test.local/mcp",
2742 "sub": "keycloak-user",
2743 "exp": now + 3600,
2744 "iat": now,
2745 "realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
2746 }),
2747 );
2748
2749 let id = cache
2750 .validate_token(&token)
2751 .await
2752 .expect("should authenticate");
2753 assert_eq!(id.name, "keycloak-user");
2754 assert_eq!(id.role, "ops");
2755 }
2756
2757 #[tokio::test]
2758 async fn role_claim_flat_roles_array() {
2759 let kid = "test-role-2";
2760 let (pem, jwks) = generate_test_keypair(kid);
2761
2762 let mock_server = wiremock::MockServer::start().await;
2763 wiremock::Mock::given(wiremock::matchers::method("GET"))
2764 .and(wiremock::matchers::path("/jwks.json"))
2765 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2766 .mount(&mock_server)
2767 .await;
2768
2769 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2770 let config = test_config_with_role_claim(
2771 &jwks_uri,
2772 "roles",
2773 vec![
2774 RoleMapping {
2775 claim_value: "MCP.Admin".into(),
2776 role: "ops".into(),
2777 },
2778 RoleMapping {
2779 claim_value: "MCP.Reader".into(),
2780 role: "viewer".into(),
2781 },
2782 ],
2783 );
2784 let cache = JwksCache::new(&config).unwrap();
2785
2786 let now = jsonwebtoken::get_current_timestamp();
2787 let token = mint_token_with_claims(
2788 &pem,
2789 kid,
2790 &serde_json::json!({
2791 "iss": "https://auth.test.local",
2792 "aud": "https://mcp.test.local/mcp",
2793 "sub": "azure-ad-user",
2794 "exp": now + 3600,
2795 "iat": now,
2796 "roles": ["MCP.Reader", "OtherApp.Admin"]
2797 }),
2798 );
2799
2800 let id = cache
2801 .validate_token(&token)
2802 .await
2803 .expect("should authenticate");
2804 assert_eq!(id.name, "azure-ad-user");
2805 assert_eq!(id.role, "viewer");
2806 }
2807
2808 #[tokio::test]
2809 async fn role_claim_no_matching_value_rejected() {
2810 let kid = "test-role-3";
2811 let (pem, jwks) = generate_test_keypair(kid);
2812
2813 let mock_server = wiremock::MockServer::start().await;
2814 wiremock::Mock::given(wiremock::matchers::method("GET"))
2815 .and(wiremock::matchers::path("/jwks.json"))
2816 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2817 .mount(&mock_server)
2818 .await;
2819
2820 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2821 let config = test_config_with_role_claim(
2822 &jwks_uri,
2823 "roles",
2824 vec![RoleMapping {
2825 claim_value: "mcp-admin".into(),
2826 role: "ops".into(),
2827 }],
2828 );
2829 let cache = JwksCache::new(&config).unwrap();
2830
2831 let now = jsonwebtoken::get_current_timestamp();
2832 let token = mint_token_with_claims(
2833 &pem,
2834 kid,
2835 &serde_json::json!({
2836 "iss": "https://auth.test.local",
2837 "aud": "https://mcp.test.local/mcp",
2838 "sub": "limited-user",
2839 "exp": now + 3600,
2840 "iat": now,
2841 "roles": ["some-other-role"]
2842 }),
2843 );
2844
2845 assert!(cache.validate_token(&token).await.is_none());
2846 }
2847
2848 #[tokio::test]
2849 async fn role_claim_space_separated_string() {
2850 let kid = "test-role-4";
2851 let (pem, jwks) = generate_test_keypair(kid);
2852
2853 let mock_server = wiremock::MockServer::start().await;
2854 wiremock::Mock::given(wiremock::matchers::method("GET"))
2855 .and(wiremock::matchers::path("/jwks.json"))
2856 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2857 .mount(&mock_server)
2858 .await;
2859
2860 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2861 let config = test_config_with_role_claim(
2862 &jwks_uri,
2863 "custom_scope",
2864 vec![
2865 RoleMapping {
2866 claim_value: "write".into(),
2867 role: "ops".into(),
2868 },
2869 RoleMapping {
2870 claim_value: "read".into(),
2871 role: "viewer".into(),
2872 },
2873 ],
2874 );
2875 let cache = JwksCache::new(&config).unwrap();
2876
2877 let now = jsonwebtoken::get_current_timestamp();
2878 let token = mint_token_with_claims(
2879 &pem,
2880 kid,
2881 &serde_json::json!({
2882 "iss": "https://auth.test.local",
2883 "aud": "https://mcp.test.local/mcp",
2884 "sub": "custom-client",
2885 "exp": now + 3600,
2886 "iat": now,
2887 "custom_scope": "read audit"
2888 }),
2889 );
2890
2891 let id = cache
2892 .validate_token(&token)
2893 .await
2894 .expect("should authenticate");
2895 assert_eq!(id.name, "custom-client");
2896 assert_eq!(id.role, "viewer");
2897 }
2898
2899 #[tokio::test]
2900 async fn scope_backward_compat_without_role_claim() {
2901 let kid = "test-compat-1";
2903 let (pem, jwks) = generate_test_keypair(kid);
2904
2905 let mock_server = wiremock::MockServer::start().await;
2906 wiremock::Mock::given(wiremock::matchers::method("GET"))
2907 .and(wiremock::matchers::path("/jwks.json"))
2908 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2909 .mount(&mock_server)
2910 .await;
2911
2912 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2913 let config = test_config(&jwks_uri); let cache = JwksCache::new(&config).unwrap();
2915
2916 let token = mint_token(
2917 &pem,
2918 kid,
2919 "https://auth.test.local",
2920 "https://mcp.test.local/mcp",
2921 "legacy-bot",
2922 "mcp:admin other:scope",
2923 );
2924
2925 let id = cache
2926 .validate_token(&token)
2927 .await
2928 .expect("should authenticate");
2929 assert_eq!(id.name, "legacy-bot");
2930 assert_eq!(id.role, "ops"); }
2932
2933 #[tokio::test]
2938 async fn jwks_refresh_deduplication() {
2939 let kid = "test-dedup";
2942 let (pem, jwks) = generate_test_keypair(kid);
2943
2944 let mock_server = wiremock::MockServer::start().await;
2945 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
2946 .and(wiremock::matchers::path("/jwks.json"))
2947 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2948 .expect(1) .mount(&mock_server)
2950 .await;
2951
2952 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2953 let config = test_config(&jwks_uri);
2954 let cache = Arc::new(JwksCache::new(&config).unwrap());
2955
2956 let token = mint_token(
2958 &pem,
2959 kid,
2960 "https://auth.test.local",
2961 "https://mcp.test.local/mcp",
2962 "concurrent-bot",
2963 "mcp:read",
2964 );
2965
2966 let mut handles = Vec::new();
2967 for _ in 0..5 {
2968 let c = Arc::clone(&cache);
2969 let t = token.clone();
2970 handles.push(tokio::spawn(async move { c.validate_token(&t).await }));
2971 }
2972
2973 for h in handles {
2974 let result = h.await.unwrap();
2975 assert!(result.is_some(), "all concurrent requests should succeed");
2976 }
2977
2978 }
2980
2981 #[tokio::test]
2982 async fn jwks_refresh_cooldown_blocks_rapid_requests() {
2983 let kid = "test-cooldown";
2986 let (_pem, jwks) = generate_test_keypair(kid);
2987
2988 let mock_server = wiremock::MockServer::start().await;
2989 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
2990 .and(wiremock::matchers::path("/jwks.json"))
2991 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2992 .expect(1) .mount(&mock_server)
2994 .await;
2995
2996 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2997 let config = test_config(&jwks_uri);
2998 let cache = JwksCache::new(&config).unwrap();
2999
3000 let fake_token1 =
3002 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig";
3003 let _ = cache.validate_token(fake_token1).await;
3004
3005 let fake_token2 =
3008 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig";
3009 let _ = cache.validate_token(fake_token2).await;
3010
3011 let fake_token3 =
3013 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig";
3014 let _ = cache.validate_token(fake_token3).await;
3015
3016 }
3018
3019 fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
3022 OAuthProxyConfig {
3023 authorize_url: "https://example.invalid/auth".into(),
3024 token_url: token_url.into(),
3025 client_id: "mcp-client".into(),
3026 client_secret: Some(secrecy::SecretString::from("shh".to_owned())),
3027 introspection_url: None,
3028 revocation_url: None,
3029 expose_admin_endpoints: false,
3030 }
3031 }
3032
3033 fn test_http_client() -> OauthHttpClient {
3036 rustls::crypto::ring::default_provider()
3037 .install_default()
3038 .ok();
3039 OauthHttpClient::with_config(&OAuthConfig::default()).expect("build test http client")
3040 }
3041
3042 #[tokio::test]
3043 async fn introspect_proxies_and_injects_client_credentials() {
3044 use wiremock::matchers::{body_string_contains, method, path};
3045
3046 let mock_server = wiremock::MockServer::start().await;
3047 wiremock::Mock::given(method("POST"))
3048 .and(path("/introspect"))
3049 .and(body_string_contains("client_id=mcp-client"))
3050 .and(body_string_contains("client_secret=shh"))
3051 .and(body_string_contains("token=abc"))
3052 .respond_with(
3053 wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
3054 "active": true,
3055 "scope": "read"
3056 })),
3057 )
3058 .expect(1)
3059 .mount(&mock_server)
3060 .await;
3061
3062 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
3063 proxy.introspection_url = Some(format!("{}/introspect", mock_server.uri()));
3064
3065 let http = test_http_client();
3066 let resp = handle_introspect(&http, &proxy, "token=abc").await;
3067 assert_eq!(resp.status(), 200);
3068 }
3069
3070 #[tokio::test]
3071 async fn introspect_returns_404_when_not_configured() {
3072 let proxy = proxy_cfg("https://example.invalid/token");
3073 let http = test_http_client();
3074 let resp = handle_introspect(&http, &proxy, "token=abc").await;
3075 assert_eq!(resp.status(), 404);
3076 }
3077
3078 #[tokio::test]
3079 async fn revoke_proxies_and_returns_upstream_status() {
3080 use wiremock::matchers::{method, path};
3081
3082 let mock_server = wiremock::MockServer::start().await;
3083 wiremock::Mock::given(method("POST"))
3084 .and(path("/revoke"))
3085 .respond_with(wiremock::ResponseTemplate::new(200))
3086 .expect(1)
3087 .mount(&mock_server)
3088 .await;
3089
3090 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
3091 proxy.revocation_url = Some(format!("{}/revoke", mock_server.uri()));
3092
3093 let http = test_http_client();
3094 let resp = handle_revoke(&http, &proxy, "token=abc").await;
3095 assert_eq!(resp.status(), 200);
3096 }
3097
3098 #[tokio::test]
3099 async fn revoke_returns_404_when_not_configured() {
3100 let proxy = proxy_cfg("https://example.invalid/token");
3101 let http = test_http_client();
3102 let resp = handle_revoke(&http, &proxy, "token=abc").await;
3103 assert_eq!(resp.status(), 404);
3104 }
3105
3106 #[test]
3107 fn metadata_advertises_endpoints_only_when_configured() {
3108 let mut cfg = test_config("https://auth.test.local/jwks.json");
3109 let m = authorization_server_metadata("https://mcp.local", &cfg);
3111 assert!(m.get("introspection_endpoint").is_none());
3112 assert!(m.get("revocation_endpoint").is_none());
3113
3114 let mut proxy = proxy_cfg("https://upstream.local/token");
3117 proxy.introspection_url = Some("https://upstream.local/introspect".into());
3118 proxy.revocation_url = Some("https://upstream.local/revoke".into());
3119 cfg.proxy = Some(proxy);
3120 let m = authorization_server_metadata("https://mcp.local", &cfg);
3121 assert!(
3122 m.get("introspection_endpoint").is_none(),
3123 "introspection must not be advertised when expose_admin_endpoints=false"
3124 );
3125 assert!(
3126 m.get("revocation_endpoint").is_none(),
3127 "revocation must not be advertised when expose_admin_endpoints=false"
3128 );
3129
3130 if let Some(p) = cfg.proxy.as_mut() {
3132 p.expose_admin_endpoints = true;
3133 p.revocation_url = None;
3134 }
3135 let m = authorization_server_metadata("https://mcp.local", &cfg);
3136 assert_eq!(
3137 m["introspection_endpoint"],
3138 serde_json::Value::String("https://mcp.local/introspect".into())
3139 );
3140 assert!(m.get("revocation_endpoint").is_none());
3141
3142 if let Some(p) = cfg.proxy.as_mut() {
3144 p.revocation_url = Some("https://upstream.local/revoke".into());
3145 }
3146 let m = authorization_server_metadata("https://mcp.local", &cfg);
3147 assert_eq!(
3148 m["revocation_endpoint"],
3149 serde_json::Value::String("https://mcp.local/revoke".into())
3150 );
3151 }
3152}