1use std::{
17 collections::HashMap,
18 path::PathBuf,
19 time::{Duration, Instant},
20};
21
22use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
23use serde::Deserialize;
24use tokio::{net::lookup_host, sync::RwLock};
25
26use crate::auth::{AuthIdentity, AuthMethod};
27
28fn evaluate_oauth_redirect(
51 attempt: &reqwest::redirect::Attempt<'_>,
52 allow_http: bool,
53) -> Result<(), String> {
54 let prev_https = attempt
55 .previous()
56 .last()
57 .is_some_and(|prev| prev.scheme() == "https");
58 let target_url = attempt.url();
59 let dest_scheme = target_url.scheme();
60 if dest_scheme != "https" {
61 if prev_https {
62 return Err("redirect downgrades https -> http".to_owned());
63 }
64 if !allow_http || dest_scheme != "http" {
65 return Err("redirect to non-HTTP(S) URL refused".to_owned());
66 }
67 }
68 if let Some(reason) = crate::ssrf::redirect_target_reason(target_url) {
69 return Err(format!("redirect target forbidden: {reason}"));
70 }
71 if attempt.previous().len() >= 2 {
72 return Err("too many redirects (max 2)".to_owned());
73 }
74 Ok(())
75}
76
77#[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
86async fn screen_oauth_target_with_test_override(
87 url: &str,
88 allow_http: bool,
89 #[cfg(any(test, feature = "test-helpers"))] test_allow_loopback_ssrf: bool,
90) -> Result<(), crate::error::McpxError> {
91 let parsed = check_oauth_url("oauth target", url, allow_http)?;
92 #[cfg(any(test, feature = "test-helpers"))]
93 if test_allow_loopback_ssrf {
94 return Ok(());
95 }
96 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
97 return Err(crate::error::McpxError::Config(format!(
98 "OAuth target forbidden ({reason}): {url}"
99 )));
100 }
101
102 let host = parsed.host_str().ok_or_else(|| {
103 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
104 })?;
105 let port = parsed.port_or_known_default().ok_or_else(|| {
106 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
107 })?;
108
109 let addrs = lookup_host((host, port)).await.map_err(|error| {
110 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
111 })?;
112
113 let mut any_addr = false;
114 for addr in addrs {
115 any_addr = true;
116 if let Some(reason) = crate::ssrf::ip_block_reason(addr.ip()) {
117 return Err(crate::error::McpxError::Config(format!(
118 "OAuth target resolved to blocked IP ({reason}): {url}"
119 )));
120 }
121 }
122 if !any_addr {
123 return Err(crate::error::McpxError::Config(format!(
124 "OAuth target DNS resolution returned no addresses: {url}"
125 )));
126 }
127
128 Ok(())
129}
130
131async fn screen_oauth_target(url: &str, allow_http: bool) -> Result<(), crate::error::McpxError> {
132 #[cfg(any(test, feature = "test-helpers"))]
133 {
134 screen_oauth_target_with_test_override(url, allow_http, false).await
135 }
136 #[cfg(not(any(test, feature = "test-helpers")))]
137 {
138 let parsed = check_oauth_url("oauth target", url, allow_http)?;
139 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
140 return Err(crate::error::McpxError::Config(format!(
141 "OAuth target forbidden ({reason}): {url}"
142 )));
143 }
144
145 let host = parsed.host_str().ok_or_else(|| {
146 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
147 })?;
148 let port = parsed.port_or_known_default().ok_or_else(|| {
149 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
150 })?;
151
152 let addrs = lookup_host((host, port)).await.map_err(|error| {
153 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
154 })?;
155
156 let mut any_addr = false;
157 for addr in addrs {
158 any_addr = true;
159 if let Some(reason) = crate::ssrf::ip_block_reason(addr.ip()) {
160 return Err(crate::error::McpxError::Config(format!(
161 "OAuth target resolved to blocked IP ({reason}): {url}"
162 )));
163 }
164 }
165 if !any_addr {
166 return Err(crate::error::McpxError::Config(format!(
167 "OAuth target DNS resolution returned no addresses: {url}"
168 )));
169 }
170
171 Ok(())
172 }
173}
174
/// HTTP client for outbound OAuth requests.
///
/// Wraps a `reqwest::Client` configured by `build` (timeouts, custom redirect
/// policy, optional extra root certificate) together with the settings
/// `send_screened` needs to screen a target URL before sending.
#[derive(Clone)]
pub struct OauthHttpClient {
    /// Underlying reqwest client.
    inner: reqwest::Client,
    /// Whether plain `http` URLs are accepted by target screening and by the
    /// redirect policy.
    allow_http: bool,
    /// Test-only switch: when set, `send_screened` uses the override path that
    /// skips IP screening and DNS resolution.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: bool,
}
222
223impl OauthHttpClient {
    /// Builds a client that takes `allow_http_oauth_urls` and `ca_cert_path`
    /// from `config`.
    ///
    /// # Errors
    /// Returns whatever `build` returns: `McpxError::Startup` when the CA file
    /// cannot be read or parsed, or when the reqwest client cannot be built.
    pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
        Self::build(Some(config))
    }
244
    /// Builds a client without any `OAuthConfig`: plain `http` is not allowed
    /// and no extra root certificate is added.
    ///
    /// # Errors
    /// Returns `McpxError::Startup` when the reqwest client cannot be built.
    #[deprecated(
        since = "1.2.1",
        note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
    )]
    pub fn new() -> Result<Self, crate::error::McpxError> {
        Self::build(None)
    }
274
275 fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
278 let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);
279
280 let mut builder = reqwest::Client::builder()
281 .connect_timeout(Duration::from_secs(10))
282 .timeout(Duration::from_secs(30))
283 .redirect(reqwest::redirect::Policy::custom(move |attempt| {
284 match evaluate_oauth_redirect(&attempt, allow_http) {
294 Ok(()) => attempt.follow(),
295 Err(reason) => {
296 tracing::warn!(
297 reason = %reason,
298 target = %attempt.url(),
299 "oauth redirect rejected"
300 );
301 attempt.error(reason)
302 }
303 }
304 }));
305
306 if let Some(cfg) = config
307 && let Some(ref ca_path) = cfg.ca_cert_path
308 {
309 let pem = std::fs::read(ca_path).map_err(|e| {
314 crate::error::McpxError::Startup(format!(
315 "oauth http client: read ca_cert_path {}: {e}",
316 ca_path.display()
317 ))
318 })?;
319 let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
320 crate::error::McpxError::Startup(format!(
321 "oauth http client: parse ca_cert_path {}: {e}",
322 ca_path.display()
323 ))
324 })?;
325 builder = builder.add_root_certificate(cert);
326 }
327
328 let inner = builder.build().map_err(|e| {
329 crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
330 })?;
331 Ok(Self {
332 inner,
333 allow_http,
334 #[cfg(any(test, feature = "test-helpers"))]
335 test_allow_loopback_ssrf: false,
336 })
337 }
338
339 async fn send_screened(
340 &self,
341 url: &str,
342 request: reqwest::RequestBuilder,
343 ) -> Result<reqwest::Response, crate::error::McpxError> {
344 #[cfg(any(test, feature = "test-helpers"))]
345 if self.test_allow_loopback_ssrf {
346 screen_oauth_target_with_test_override(url, self.allow_http, true).await?;
347 } else {
348 screen_oauth_target(url, self.allow_http).await?;
349 }
350 #[cfg(not(any(test, feature = "test-helpers")))]
351 screen_oauth_target(url, self.allow_http).await?;
352 request.send().await.map_err(|error| {
353 crate::error::McpxError::Config(format!("oauth request {url}: {error}"))
354 })
355 }
356
    /// Test-only: returns the client with `test_allow_loopback_ssrf` set, so
    /// `send_screened` skips IP screening and DNS resolution (the URL
    /// shape/scheme check still runs).
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_allow_loopback_ssrf(mut self) -> Self {
        self.test_allow_loopback_ssrf = true;
        self
    }
368
    /// Issues a GET for `url` directly on the inner reqwest client.
    ///
    /// This goes around `send_screened`, so no target screening happens here;
    /// only the client's redirect policy applies.
    // NOTE(review): unlike the other `__test_*` helpers this one is not
    // cfg-gated in the visible code - confirm that is intentional.
    #[doc(hidden)]
    pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
        self.inner.get(url).send().await
    }
378}
379
/// Opaque `Debug` implementation: prints the type name only.
impl std::fmt::Debug for OauthHttpClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // No fields are listed; `finish_non_exhaustive` marks the output as
        // intentionally incomplete.
        f.debug_struct("OauthHttpClient").finish_non_exhaustive()
    }
}
385
/// OAuth / JWT validation settings, deserializable with serde.
///
/// `validate` checks the URL-valued fields; `JwksCache::new` and
/// `OauthHttpClient::with_config` consume the rest.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
    /// Issuer URL; tokens must carry this value in `iss`.
    pub issuer: String,
    /// Expected audience, compared against `aud` (and `azp` unless
    /// `strict_audience_validation` is set).
    pub audience: String,
    /// URL the JWKS document is fetched from.
    pub jwks_uri: String,
    /// Scope-to-role mappings, used when `role_claim` is `None`.
    #[serde(default)]
    pub scopes: Vec<ScopeMapping>,
    /// Dot-separated path of the claim that carries role values; when set,
    /// roles come from `role_mappings` instead of `scopes`.
    pub role_claim: Option<String>,
    /// Claim-value-to-role mappings, used together with `role_claim`.
    #[serde(default)]
    pub role_mappings: Vec<RoleMapping>,
    /// JWKS cache lifetime as a humantime string (default `"10m"`).
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl: String,
    /// Optional OAuth proxy settings; its URLs are checked by `validate`.
    pub proxy: Option<OAuthProxyConfig>,
    /// Optional token-exchange settings; its `token_url` is checked by
    /// `validate`.
    pub token_exchange: Option<TokenExchangeConfig>,
    /// PEM file added as an extra root certificate for OAuth HTTP traffic.
    #[serde(default)]
    pub ca_cert_path: Option<PathBuf>,
    /// Permit plain `http` OAuth URLs and redirects (default `false`).
    #[serde(default)]
    pub allow_http_oauth_urls: bool,
    /// Maximum number of keys accepted from one JWKS document (default 256).
    #[serde(default = "default_max_jwks_keys")]
    pub max_jwks_keys: usize,
    /// When `true`, only `aud` may satisfy the audience check; the `azp`
    /// fallback is disabled.
    #[serde(default)]
    pub strict_audience_validation: bool,
    /// Maximum JWKS response body size in bytes (default 1 MiB).
    #[serde(default = "default_jwks_max_bytes")]
    pub jwks_max_response_bytes: u64,
}
475
/// Serde default for `OAuthConfig::jwks_cache_ttl`: ten minutes, written as a
/// humantime string.
fn default_jwks_cache_ttl() -> String {
    String::from("10m")
}
479
/// Serde default for `OAuthConfig::max_jwks_keys`.
const fn default_max_jwks_keys() -> usize {
    const DEFAULT_KEY_CAP: usize = 256;
    DEFAULT_KEY_CAP
}
483
/// Serde default for `OAuthConfig::jwks_max_response_bytes`: one mebibyte.
const fn default_jwks_max_bytes() -> u64 {
    const ONE_MIB: u64 = 1 << 20;
    ONE_MIB
}
487
/// Empty configuration using the same defaults serde applies to omitted
/// fields.
///
/// `issuer`, `audience` and `jwks_uri` are empty strings here, so a default
/// value will not pass `validate` until they are filled in.
impl Default for OAuthConfig {
    fn default() -> Self {
        Self {
            issuer: String::new(),
            audience: String::new(),
            jwks_uri: String::new(),
            scopes: Vec::new(),
            role_claim: None,
            role_mappings: Vec::new(),
            // Reuse the serde default helpers so both construction paths agree.
            jwks_cache_ttl: default_jwks_cache_ttl(),
            proxy: None,
            token_exchange: None,
            ca_cert_path: None,
            allow_http_oauth_urls: false,
            max_jwks_keys: default_max_jwks_keys(),
            strict_audience_validation: false,
            jwks_max_response_bytes: default_jwks_max_bytes(),
        }
    }
}
508
509impl OAuthConfig {
    /// Starts an `OAuthConfigBuilder` from the three required values; every
    /// other field starts at its `Default` value.
    pub fn builder(
        issuer: impl Into<String>,
        audience: impl Into<String>,
        jwks_uri: impl Into<String>,
    ) -> OAuthConfigBuilder {
        OAuthConfigBuilder {
            inner: Self {
                issuer: issuer.into(),
                audience: audience.into(),
                jwks_uri: jwks_uri.into(),
                ..Self::default()
            },
        }
    }
529
530 pub fn validate(&self) -> Result<(), crate::error::McpxError> {
546 let allow_http = self.allow_http_oauth_urls;
547 let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
548 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
549 return Err(crate::error::McpxError::Config(format!(
550 "oauth.issuer forbidden ({reason})"
551 )));
552 }
553 let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
554 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
555 return Err(crate::error::McpxError::Config(format!(
556 "oauth.jwks_uri forbidden ({reason})"
557 )));
558 }
559 if let Some(proxy) = &self.proxy {
560 let url = check_oauth_url(
561 "oauth.proxy.authorize_url",
562 &proxy.authorize_url,
563 allow_http,
564 )?;
565 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
566 return Err(crate::error::McpxError::Config(format!(
567 "oauth.proxy.authorize_url forbidden ({reason})"
568 )));
569 }
570 let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
571 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
572 return Err(crate::error::McpxError::Config(format!(
573 "oauth.proxy.token_url forbidden ({reason})"
574 )));
575 }
576 if let Some(url) = &proxy.introspection_url {
577 let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
578 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
579 return Err(crate::error::McpxError::Config(format!(
580 "oauth.proxy.introspection_url forbidden ({reason})"
581 )));
582 }
583 }
584 if let Some(url) = &proxy.revocation_url {
585 let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
586 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
587 return Err(crate::error::McpxError::Config(format!(
588 "oauth.proxy.revocation_url forbidden ({reason})"
589 )));
590 }
591 }
592 }
593 if let Some(tx) = &self.token_exchange {
594 let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
595 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
596 return Err(crate::error::McpxError::Config(format!(
597 "oauth.token_exchange.token_url forbidden ({reason})"
598 )));
599 }
600 }
601 Ok(())
602 }
603}
604
605fn check_oauth_url(
612 field: &str,
613 raw: &str,
614 allow_http: bool,
615) -> Result<url::Url, crate::error::McpxError> {
616 let parsed = url::Url::parse(raw).map_err(|e| {
617 crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
618 })?;
619 if !parsed.username().is_empty() || parsed.password().is_some() {
620 return Err(crate::error::McpxError::Config(format!(
621 "{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
622 )));
623 }
624 match parsed.scheme() {
625 "https" => Ok(parsed),
626 "http" if allow_http => Ok(parsed),
627 "http" => Err(crate::error::McpxError::Config(format!(
628 "{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
629 to override - strongly discouraged in production)"
630 ))),
631 other => Err(crate::error::McpxError::Config(format!(
632 "{field}: must use https scheme (got {other:?})"
633 ))),
634 }
635}
636
/// Builder for `OAuthConfig`, created by `OAuthConfig::builder`.
///
/// Holds the configuration under construction; each setter consumes and
/// returns the builder, and `build` hands the configuration back.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthConfigBuilder {
    /// Configuration being assembled.
    inner: OAuthConfig,
}
647
/// Chainable setters for `OAuthConfig`. None of them validates its input;
/// call `OAuthConfig::validate` on the built value for URL checks.
impl OAuthConfigBuilder {
    /// Replaces the whole scope-to-role mapping list.
    pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
        self.inner.scopes = scopes;
        self
    }

    /// Appends one scope-to-role mapping.
    pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.scopes.push(ScopeMapping {
            scope: scope.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the claim path roles are read from.
    pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
        self.inner.role_claim = Some(claim.into());
        self
    }

    /// Replaces the whole claim-value-to-role mapping list.
    pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
        self.inner.role_mappings = mappings;
        self
    }

    /// Appends one claim-value-to-role mapping.
    pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.role_mappings.push(RoleMapping {
            claim_value: claim_value.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the JWKS cache lifetime (a humantime string such as `"10m"`).
    pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
        self.inner.jwks_cache_ttl = ttl.into();
        self
    }

    /// Sets the OAuth proxy configuration.
    pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
        self.inner.proxy = Some(proxy);
        self
    }

    /// Sets the token-exchange configuration.
    pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
        self.inner.token_exchange = Some(token_exchange);
        self
    }

    /// Sets the PEM file to add as an extra root certificate.
    pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.inner.ca_cert_path = Some(path.into());
        self
    }

    /// Allows (`true`) or forbids (`false`) plain `http` OAuth URLs.
    pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
        self.inner.allow_http_oauth_urls = allow;
        self
    }

    /// Enables or disables strict audience validation (no `azp` fallback).
    pub const fn strict_audience_validation(mut self, strict: bool) -> Self {
        self.inner.strict_audience_validation = strict;
        self
    }

    /// Sets the maximum accepted JWKS response size in bytes.
    pub const fn jwks_max_response_bytes(mut self, bytes: u64) -> Self {
        self.inner.jwks_max_response_bytes = bytes;
        self
    }

    /// Finishes the builder and returns the assembled configuration.
    #[must_use]
    pub fn build(self) -> OAuthConfig {
        self.inner
    }
}
745
/// Maps one OAuth scope to a role.
///
/// When no `role_claim` is configured, the first mapping whose `scope`
/// appears in the token's whitespace-separated `scope` claim decides the
/// role.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ScopeMapping {
    /// Scope name to look for in the token.
    pub scope: String,
    /// Role granted when the scope is present.
    pub role: String,
}
755
/// Maps one value of the configured role claim to a role.
///
/// When `role_claim` is configured, the first mapping whose `claim_value`
/// appears among the values found at that claim path decides the role.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleMapping {
    /// Value expected inside the role claim.
    pub claim_value: String,
    /// Role granted when the value is present.
    pub role: String,
}
767
/// Settings for token exchange.
///
/// Only `token_url` is used in the visible part of this file (it is checked
/// by `OAuthConfig::validate`); the remaining descriptions follow the field
/// names and should be confirmed against the code that performs the exchange.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenExchangeConfig {
    /// Token endpoint URL.
    pub token_url: String,
    /// Client identifier (presumably sent to the token endpoint).
    pub client_id: String,
    /// Optional client secret, held as a `SecretString`.
    pub client_secret: Option<secrecy::SecretString>,
    /// Optional client certificate (presumably for mTLS client
    /// authentication - confirm).
    pub client_cert: Option<ClientCertConfig>,
    /// Audience value (presumably the audience requested for the exchanged
    /// token - confirm).
    pub audience: String,
}
794
impl TokenExchangeConfig {
    /// Creates a `TokenExchangeConfig` from its five fields, stored as given.
    ///
    /// Exists because the struct is `#[non_exhaustive]` and so cannot be
    /// built with a struct literal from another crate.
    #[must_use]
    pub fn new(
        token_url: String,
        client_id: String,
        client_secret: Option<secrecy::SecretString>,
        client_cert: Option<ClientCertConfig>,
        audience: String,
    ) -> Self {
        Self {
            token_url,
            client_id,
            client_secret,
            client_cert,
            audience,
        }
    }
}
814
/// File locations of a client certificate and its key.
///
/// NOTE(review): the expected file format and how the pair is loaded are not
/// visible in this part of the file.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ClientCertConfig {
    /// Path to the certificate file.
    pub cert_path: PathBuf,
    /// Path to the private-key file.
    pub key_path: PathBuf,
}
825
/// Deserialized result of a token exchange.
///
/// NOTE(review): presumably the token endpoint's JSON response; the code that
/// produces it is not visible in this part of the file.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ExchangedToken {
    /// The issued access token.
    pub access_token: String,
    /// Token lifetime, if reported (presumably seconds - confirm).
    pub expires_in: Option<u64>,
    /// Type identifier of the issued token, if reported.
    pub issued_token_type: Option<String>,
}
838
/// Settings for the OAuth proxy.
///
/// The four URL fields are checked by `OAuthConfig::validate`. How the other
/// fields are consumed is not visible in this part of the file.
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct OAuthProxyConfig {
    /// Upstream authorization endpoint URL.
    pub authorize_url: String,
    /// Upstream token endpoint URL.
    pub token_url: String,
    /// Client identifier.
    pub client_id: String,
    /// Optional client secret, held as a `SecretString`.
    pub client_secret: Option<secrecy::SecretString>,
    /// Optional introspection endpoint URL.
    #[serde(default)]
    pub introspection_url: Option<String>,
    /// Optional revocation endpoint URL.
    #[serde(default)]
    pub revocation_url: Option<String>,
    /// Flag for exposing admin endpoints (default `false`); presumably the
    /// introspection/revocation routes - confirm against the router.
    #[serde(default)]
    pub expose_admin_endpoints: bool,
    /// Flag for requiring authentication on admin endpoints (default
    /// `false`).
    #[serde(default)]
    pub require_auth_on_admin_endpoints: bool,
}
889
impl OAuthProxyConfig {
    /// Starts an `OAuthProxyConfigBuilder` from the three required values;
    /// every other field starts at its `Default` value.
    pub fn builder(
        authorize_url: impl Into<String>,
        token_url: impl Into<String>,
        client_id: impl Into<String>,
    ) -> OAuthProxyConfigBuilder {
        OAuthProxyConfigBuilder {
            inner: Self {
                authorize_url: authorize_url.into(),
                token_url: token_url.into(),
                client_id: client_id.into(),
                ..Self::default()
            },
        }
    }
}
913
/// Builder for `OAuthProxyConfig`, created by `OAuthProxyConfig::builder`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthProxyConfigBuilder {
    /// Configuration being assembled.
    inner: OAuthProxyConfig,
}
924
/// Chainable setters for `OAuthProxyConfig`; none of them validates its
/// input.
impl OAuthProxyConfigBuilder {
    /// Sets the client secret.
    pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
        self.inner.client_secret = Some(secret);
        self
    }

    /// Sets the introspection endpoint URL.
    pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
        self.inner.introspection_url = Some(url.into());
        self
    }

    /// Sets the revocation endpoint URL.
    pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
        self.inner.revocation_url = Some(url.into());
        self
    }

    /// Sets the `expose_admin_endpoints` flag.
    pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
        self.inner.expose_admin_endpoints = expose;
        self
    }

    /// Sets the `require_auth_on_admin_endpoints` flag.
    pub const fn require_auth_on_admin_endpoints(mut self, require: bool) -> Self {
        self.inner.require_auth_on_admin_endpoints = require;
        self
    }

    /// Finishes the builder and returns the assembled configuration.
    #[must_use]
    pub fn build(self) -> OAuthProxyConfig {
        self.inner
    }
}
973
/// Result of `build_key_cache`: keys that carry a `kid` (indexed by it) and
/// keys that do not, each paired with its algorithm.
type JwksKeyCache = (
    HashMap<String, (Algorithm, DecodingKey)>,
    Vec<(Algorithm, DecodingKey)>,
);
985
/// One fetched JWKS snapshot together with its freshness information.
struct CachedKeys {
    /// Keys that carry a `kid`, indexed by that `kid`.
    keys: HashMap<String, (Algorithm, DecodingKey)>,
    /// Keys without a `kid`; matched by algorithm only.
    unnamed_keys: Vec<(Algorithm, DecodingKey)>,
    /// When this snapshot was stored.
    fetched_at: Instant,
    /// How long after `fetched_at` the snapshot counts as fresh.
    ttl: Duration,
}
994
995impl CachedKeys {
996 fn is_expired(&self) -> bool {
997 self.fetched_at.elapsed() >= self.ttl
998 }
999}
1000
/// JWT validator backed by a cached copy of the issuer's JWKS document.
///
/// Keys are fetched from `jwks_uri` on demand, kept for `ttl`, and refreshed
/// at most once per `JWKS_REFRESH_COOLDOWN`. Validated tokens are turned into
/// an `AuthIdentity` using the configured scope or role-claim mappings.
#[allow(
    missing_debug_implementations,
    reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
)]
#[non_exhaustive]
pub struct JwksCache {
    /// URL the JWKS document is fetched from.
    jwks_uri: String,
    /// Lifetime given to each fetched snapshot.
    ttl: Duration,
    /// Maximum number of keys accepted from one JWKS document.
    max_jwks_keys: usize,
    /// Maximum JWKS response body size in bytes.
    max_response_bytes: u64,
    /// Whether plain `http` is accepted for the JWKS URL and redirects.
    allow_http: bool,
    /// Current key snapshot; `None` until the first successful refresh.
    inner: RwLock<Option<CachedKeys>>,
    /// HTTP client used for JWKS fetches.
    http: reqwest::Client,
    /// Base validation settings; cloned per token, with the token's algorithm
    /// filled in before decoding.
    validation_template: Validation,
    /// Audience a token must carry (checked by `check_audience`, not by the
    /// validation template).
    expected_audience: String,
    /// When `true`, `azp` may not stand in for `aud`.
    strict_audience_validation: bool,
    /// Scope-to-role mappings, used when `role_claim` is `None`.
    scopes: Vec<ScopeMapping>,
    /// Claim path roles are read from, if configured.
    role_claim: Option<String>,
    /// Claim-value-to-role mappings, used with `role_claim`.
    role_mappings: Vec<RoleMapping>,
    /// Time of the most recent refresh attempt, for the cooldown.
    last_refresh_attempt: RwLock<Option<Instant>>,
    /// Held for the whole of `refresh_with_cooldown`, so refresh attempts do
    /// not overlap.
    refresh_lock: tokio::sync::Mutex<()>,
    /// Test-only switch that makes JWKS fetches skip IP screening.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: bool,
}
1039
/// Minimum spacing between JWKS refresh attempts: `refresh_with_cooldown`
/// skips the fetch while the previous attempt is younger than this.
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_secs(10);
1042
/// Signature algorithms a token header may declare; `select_jwks_key` rejects
/// any token whose `alg` is not listed here. The list holds RSA, ECDSA and
/// EdDSA variants only - no HMAC (`HS*`) entry.
const ACCEPTED_ALGS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
    Algorithm::EdDSA,
];
1055
/// Why `JwksCache::validate_token_with_reason` rejected a token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
    /// Decoding failed specifically because the token's signature/`exp` has
    /// expired.
    Expired,
    /// Any other rejection: undecodable header, algorithm not accepted, no
    /// matching key, failed decode, audience mismatch, or no role mapping.
    Invalid,
}
1065
1066impl JwksCache {
1067 pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
1074 rustls::crypto::ring::default_provider()
1077 .install_default()
1078 .ok();
1079 jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
1080 .install_default()
1081 .ok();
1082
1083 let ttl =
1084 humantime::parse_duration(&config.jwks_cache_ttl).unwrap_or(Duration::from_mins(10));
1085
1086 let mut validation = Validation::new(Algorithm::RS256);
1087 validation.validate_aud = false;
1099 validation.set_issuer(&[&config.issuer]);
1100 validation.set_required_spec_claims(&["exp", "iss"]);
1101 validation.validate_exp = true;
1102 validation.validate_nbf = true;
1103
1104 let allow_http = config.allow_http_oauth_urls;
1105
1106 let mut http_builder = reqwest::Client::builder()
1107 .timeout(Duration::from_secs(10))
1108 .connect_timeout(Duration::from_secs(3))
1109 .redirect(reqwest::redirect::Policy::custom(move |attempt| {
1110 match evaluate_oauth_redirect(&attempt, allow_http) {
1120 Ok(()) => attempt.follow(),
1121 Err(reason) => {
1122 tracing::warn!(
1123 reason = %reason,
1124 target = %attempt.url(),
1125 "oauth redirect rejected"
1126 );
1127 attempt.error(reason)
1128 }
1129 }
1130 }));
1131
1132 if let Some(ref ca_path) = config.ca_cert_path {
1133 let pem = std::fs::read(ca_path)?;
1139 let cert = reqwest::tls::Certificate::from_pem(&pem)?;
1140 http_builder = http_builder.add_root_certificate(cert);
1141 }
1142
1143 let http = http_builder.build()?;
1144
1145 Ok(Self {
1146 jwks_uri: config.jwks_uri.clone(),
1147 ttl,
1148 max_jwks_keys: config.max_jwks_keys,
1149 max_response_bytes: config.jwks_max_response_bytes,
1150 allow_http,
1151 inner: RwLock::new(None),
1152 http,
1153 validation_template: validation,
1154 expected_audience: config.audience.clone(),
1155 strict_audience_validation: config.strict_audience_validation,
1156 scopes: config.scopes.clone(),
1157 role_claim: config.role_claim.clone(),
1158 role_mappings: config.role_mappings.clone(),
1159 last_refresh_attempt: RwLock::new(None),
1160 refresh_lock: tokio::sync::Mutex::new(()),
1161 #[cfg(any(test, feature = "test-helpers"))]
1162 test_allow_loopback_ssrf: false,
1163 })
1164 }
1165
    /// Test-only: returns the cache with `test_allow_loopback_ssrf` set, so
    /// `fetch_jwks` skips IP screening and DNS resolution for the JWKS URL.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_allow_loopback_ssrf(mut self) -> Self {
        self.test_allow_loopback_ssrf = true;
        self
    }
1176
    /// Validates `token` and returns the resulting identity, or `None` on any
    /// failure. Use `validate_token_with_reason` to learn why a token was
    /// rejected.
    pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
        self.validate_token_with_reason(token).await.ok()
    }
1181
1182 pub async fn validate_token_with_reason(
1189 &self,
1190 token: &str,
1191 ) -> Result<AuthIdentity, JwtValidationFailure> {
1192 let claims = self.decode_claims(token).await?;
1193
1194 self.check_audience(&claims)?;
1195 let role = self.resolve_role(&claims)?;
1196
1197 let sub = claims.sub;
1200 let name = claims
1201 .extra
1202 .get("preferred_username")
1203 .and_then(|v| v.as_str())
1204 .map(String::from)
1205 .or_else(|| sub.clone())
1206 .or(claims.azp)
1207 .or(claims.client_id)
1208 .unwrap_or_else(|| "oauth-client".into());
1209
1210 Ok(AuthIdentity {
1211 name,
1212 role,
1213 method: AuthMethod::OAuthJwt,
1214 raw_token: None,
1215 sub,
1216 })
1217 }
1218
1219 async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
1231 let (key, alg) = self.select_jwks_key(token).await?;
1232
1233 let mut validation = self.validation_template.clone();
1237 validation.algorithms = vec![alg];
1238
1239 let token_owned = token.to_owned();
1242 let join =
1243 tokio::task::spawn_blocking(move || decode::<Claims>(&token_owned, &key, &validation))
1244 .await;
1245
1246 let decode_result = match join {
1247 Ok(r) => r,
1248 Err(join_err) => {
1249 core::hint::cold_path();
1250 tracing::error!(
1251 error = %join_err,
1252 "JWT decode task panicked or was cancelled"
1253 );
1254 return Err(JwtValidationFailure::Invalid);
1255 }
1256 };
1257
1258 decode_result.map(|td| td.claims).map_err(|e| {
1259 core::hint::cold_path();
1260 let failure = if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
1261 JwtValidationFailure::Expired
1262 } else {
1263 JwtValidationFailure::Invalid
1264 };
1265 tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
1266 failure
1267 })
1268 }
1269
1270 #[allow(clippy::cognitive_complexity)]
1279 async fn select_jwks_key(
1280 &self,
1281 token: &str,
1282 ) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
1283 let Ok(header) = decode_header(token) else {
1284 core::hint::cold_path();
1285 tracing::debug!("JWT header decode failed");
1286 return Err(JwtValidationFailure::Invalid);
1287 };
1288 let kid = header.kid.as_deref();
1289 tracing::debug!(alg = ?header.alg, kid = kid.unwrap_or("-"), "JWT header decoded");
1290
1291 if !ACCEPTED_ALGS.contains(&header.alg) {
1292 core::hint::cold_path();
1293 tracing::debug!(alg = ?header.alg, "JWT algorithm not accepted");
1294 return Err(JwtValidationFailure::Invalid);
1295 }
1296
1297 let Some(key) = self.find_key(kid, header.alg).await else {
1298 core::hint::cold_path();
1299 tracing::debug!(kid = kid.unwrap_or("-"), alg = ?header.alg, "no matching JWKS key found");
1300 return Err(JwtValidationFailure::Invalid);
1301 };
1302
1303 Ok((key, header.alg))
1304 }
1305
1306 fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
1314 let aud_ok = claims.aud.contains(&self.expected_audience)
1315 || (!self.strict_audience_validation
1316 && claims
1317 .azp
1318 .as_deref()
1319 .is_some_and(|azp| azp == self.expected_audience));
1320 if aud_ok {
1321 return Ok(());
1322 }
1323 core::hint::cold_path();
1324 tracing::debug!(
1325 aud = ?claims.aud.0,
1326 azp = ?claims.azp,
1327 expected = %self.expected_audience,
1328 strict = self.strict_audience_validation,
1329 "JWT rejected: audience mismatch"
1330 );
1331 Err(JwtValidationFailure::Invalid)
1332 }
1333
1334 fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
1340 if let Some(ref claim_path) = self.role_claim {
1341 let values = resolve_claim_path(&claims.extra, claim_path);
1342 return self
1343 .role_mappings
1344 .iter()
1345 .find(|m| values.contains(&m.claim_value.as_str()))
1346 .map(|m| m.role.clone())
1347 .ok_or(JwtValidationFailure::Invalid);
1348 }
1349
1350 let token_scopes: Vec<&str> = claims
1351 .scope
1352 .as_deref()
1353 .unwrap_or("")
1354 .split_whitespace()
1355 .collect();
1356
1357 self.scopes
1358 .iter()
1359 .find(|m| token_scopes.contains(&m.scope.as_str()))
1360 .map(|m| m.role.clone())
1361 .ok_or(JwtValidationFailure::Invalid)
1362 }
1363
1364 async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
1367 {
1369 let guard = self.inner.read().await;
1370 if let Some(cached) = guard.as_ref()
1371 && !cached.is_expired()
1372 && let Some(key) = lookup_key(cached, kid, alg)
1373 {
1374 return Some(key);
1375 }
1376 }
1377
1378 self.refresh_with_cooldown().await;
1380
1381 let guard = self.inner.read().await;
1382 guard
1383 .as_ref()
1384 .and_then(|cached| lookup_key(cached, kid, alg))
1385 }
1386
1387 async fn refresh_with_cooldown(&self) {
1392 let _guard = self.refresh_lock.lock().await;
1394
1395 {
1397 let last = self.last_refresh_attempt.read().await;
1398 if let Some(ts) = *last
1399 && ts.elapsed() < JWKS_REFRESH_COOLDOWN
1400 {
1401 tracing::debug!(
1402 elapsed_ms = ts.elapsed().as_millis(),
1403 cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
1404 "JWKS refresh skipped (cooldown active)"
1405 );
1406 return;
1407 }
1408 }
1409
1410 {
1413 let mut last = self.last_refresh_attempt.write().await;
1414 *last = Some(Instant::now());
1415 }
1416
1417 let _ = self.refresh_inner().await;
1419 }
1420
1421 async fn refresh_inner(&self) -> Result<(), String> {
1426 let Some(jwks) = self.fetch_jwks().await else {
1427 return Ok(());
1428 };
1429 let (keys, unnamed_keys) = match build_key_cache(&jwks, self.max_jwks_keys) {
1430 Ok(cache) => cache,
1431 Err(msg) => {
1432 tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
1433 return Err(msg);
1434 }
1435 };
1436
1437 tracing::debug!(
1438 named = keys.len(),
1439 unnamed = unnamed_keys.len(),
1440 "JWKS refreshed"
1441 );
1442
1443 let mut guard = self.inner.write().await;
1444 *guard = Some(CachedKeys {
1445 keys,
1446 unnamed_keys,
1447 fetched_at: Instant::now(),
1448 ttl: self.ttl,
1449 });
1450 Ok(())
1451 }
1452
1453 #[allow(
1455 clippy::cognitive_complexity,
1456 reason = "screening, bounded streaming, and parse logging are intentionally kept in one fetch path"
1457 )]
1458 async fn fetch_jwks(&self) -> Option<JwkSet> {
1459 #[cfg(any(test, feature = "test-helpers"))]
1460 let screening = if self.test_allow_loopback_ssrf {
1461 screen_oauth_target_with_test_override(&self.jwks_uri, self.allow_http, true).await
1462 } else {
1463 screen_oauth_target(&self.jwks_uri, self.allow_http).await
1464 };
1465 #[cfg(not(any(test, feature = "test-helpers")))]
1466 let screening = screen_oauth_target(&self.jwks_uri, self.allow_http).await;
1467
1468 if let Err(error) = screening {
1469 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to screen JWKS target");
1470 return None;
1471 }
1472
1473 let mut resp = match self.http.get(&self.jwks_uri).send().await {
1474 Ok(resp) => resp,
1475 Err(e) => {
1476 tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
1477 return None;
1478 }
1479 };
1480
1481 let initial_capacity =
1482 usize::try_from(self.max_response_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
1483 let mut body = Vec::with_capacity(initial_capacity);
1484 while let Some(chunk) = match resp.chunk().await {
1485 Ok(chunk) => chunk,
1486 Err(error) => {
1487 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to read JWKS response");
1488 return None;
1489 }
1490 } {
1491 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
1492 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
1493 if body_len.saturating_add(chunk_len) > self.max_response_bytes {
1494 tracing::warn!(
1495 uri = %self.jwks_uri,
1496 max_bytes = self.max_response_bytes,
1497 "JWKS response exceeded configured size cap"
1498 );
1499 return None;
1500 }
1501 body.extend_from_slice(&chunk);
1502 }
1503
1504 match serde_json::from_slice::<JwkSet>(&body) {
1505 Ok(jwks) => Some(jwks),
1506 Err(error) => {
1507 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to parse JWKS");
1508 None
1509 }
1510 }
1511 }
1512
    /// Test-only: fetches the JWKS document immediately and replaces the
    /// cached snapshot, without taking `refresh_lock` or honouring the
    /// cooldown.
    ///
    /// # Errors
    /// Unlike `refresh_inner`, a failed fetch is reported as an error; a
    /// document exceeding the key cap is reported with `build_key_cache`'s
    /// message.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    pub async fn __test_refresh_now(&self) -> Result<(), String> {
        let jwks = self
            .fetch_jwks()
            .await
            .ok_or_else(|| "failed to fetch or parse JWKS".to_owned())?;
        let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;
        let mut guard = self.inner.write().await;
        *guard = Some(CachedKeys {
            keys,
            unnamed_keys,
            fetched_at: Instant::now(),
            ttl: self.ttl,
        });
        Ok(())
    }
1532
    /// Test-only: reports whether the current snapshot holds a named key with
    /// this `kid`. Returns `false` when nothing is cached; expiry is not
    /// considered.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    pub async fn __test_has_kid(&self, kid: &str) -> bool {
        let guard = self.inner.read().await;
        guard
            .as_ref()
            .is_some_and(|cache| cache.keys.contains_key(kid))
    }
1543}
1544
1545fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
1547 if jwks.keys.len() > max_keys {
1548 return Err(format!(
1549 "jwks_key_count_exceeds_cap: got {} keys, max is {}",
1550 jwks.keys.len(),
1551 max_keys
1552 ));
1553 }
1554 let mut keys = HashMap::new();
1555 let mut unnamed_keys = Vec::new();
1556 for jwk in &jwks.keys {
1557 let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
1558 continue;
1559 };
1560 let Some(alg) = jwk_algorithm(jwk) else {
1561 continue;
1562 };
1563 if let Some(ref kid) = jwk.common.key_id {
1564 keys.insert(kid.clone(), (alg, decoding_key));
1565 } else {
1566 unnamed_keys.push((alg, decoding_key));
1567 }
1568 }
1569 Ok((keys, unnamed_keys))
1570}
1571
1572fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
1574 if let Some(kid) = kid
1575 && let Some((cached_alg, key)) = cached.keys.get(kid)
1576 && *cached_alg == alg
1577 {
1578 return Some(key.clone());
1579 }
1580 cached
1582 .unnamed_keys
1583 .iter()
1584 .find(|(a, _)| *a == alg)
1585 .map(|(_, k)| k.clone())
1586}
1587
1588#[allow(clippy::wildcard_enum_match_arm)]
1590fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
1591 jwk.common.key_algorithm.and_then(|ka| match ka {
1592 jsonwebtoken::jwk::KeyAlgorithm::RS256 => Some(Algorithm::RS256),
1593 jsonwebtoken::jwk::KeyAlgorithm::RS384 => Some(Algorithm::RS384),
1594 jsonwebtoken::jwk::KeyAlgorithm::RS512 => Some(Algorithm::RS512),
1595 jsonwebtoken::jwk::KeyAlgorithm::ES256 => Some(Algorithm::ES256),
1596 jsonwebtoken::jwk::KeyAlgorithm::ES384 => Some(Algorithm::ES384),
1597 jsonwebtoken::jwk::KeyAlgorithm::PS256 => Some(Algorithm::PS256),
1598 jsonwebtoken::jwk::KeyAlgorithm::PS384 => Some(Algorithm::PS384),
1599 jsonwebtoken::jwk::KeyAlgorithm::PS512 => Some(Algorithm::PS512),
1600 jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
1601 _ => None,
1602 })
1603}
1604
1605fn resolve_claim_path<'a>(
1619 extra: &'a HashMap<String, serde_json::Value>,
1620 path: &str,
1621) -> Vec<&'a str> {
1622 let mut segments = path.split('.');
1623 let Some(first) = segments.next() else {
1624 return Vec::new();
1625 };
1626
1627 let mut current: Option<&serde_json::Value> = extra.get(first);
1628
1629 for segment in segments {
1630 current = current.and_then(|v| v.get(segment));
1631 }
1632
1633 match current {
1634 Some(serde_json::Value::String(s)) => s.split_whitespace().collect(),
1635 Some(serde_json::Value::Array(arr)) => arr.iter().filter_map(|v| v.as_str()).collect(),
1636 _ => Vec::new(),
1637 }
1638}
1639
/// Claims deserialized from a JWT payload.
///
/// Named fields capture the claims listed below; every other top-level claim
/// is collected into `extra` through `#[serde(flatten)]`.
1640#[derive(Debug, Deserialize)]
1646struct Claims {
/// `sub` claim, if present.
1647 sub: Option<String>,
/// `aud` claim, accepted as a single string or an array of strings. When the
/// claim is absent the serde default (an empty list) is used.
1649 #[serde(default)]
1652 aud: OneOrMany,
/// `azp` claim, if present.
1653 azp: Option<String>,
/// `client_id` claim, if present.
1655 client_id: Option<String>,
/// `scope` claim, if present.
/// NOTE(review): presumably a space-delimited scope list — confirm against the validator.
1657 scope: Option<String>,
/// All remaining top-level claims, keyed by claim name.
1659 #[serde(flatten)]
1661 extra: HashMap<String, serde_json::Value>,
1662}
1663
1664#[derive(Debug, Default)]
1666struct OneOrMany(Vec<String>);
1667
1668impl OneOrMany {
1669 fn contains(&self, value: &str) -> bool {
1670 self.0.iter().any(|v| v == value)
1671 }
1672}
1673
1674impl<'de> Deserialize<'de> for OneOrMany {
1675 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
1676 use serde::de;
1677
1678 struct Visitor;
1679 impl<'de> de::Visitor<'de> for Visitor {
1680 type Value = OneOrMany;
1681 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1682 f.write_str("a string or array of strings")
1683 }
1684 fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
1685 Ok(OneOrMany(vec![v.to_owned()]))
1686 }
1687 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
1688 let mut v = Vec::new();
1689 while let Some(s) = seq.next_element::<String>()? {
1690 v.push(s);
1691 }
1692 Ok(OneOrMany(v))
1693 }
1694 }
1695 deserializer.deserialize_any(Visitor)
1696 }
1697}
1698
1699#[must_use]
1706pub fn looks_like_jwt(token: &str) -> bool {
1707 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
1708
1709 let mut parts = token.splitn(4, '.');
1710 let Some(header_b64) = parts.next() else {
1711 return false;
1712 };
1713 if parts.next().is_none() || parts.next().is_none() || parts.next().is_some() {
1715 return false;
1716 }
1717 let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
1719 return false;
1720 };
1721 let Ok(header) = serde_json::from_slice::<serde_json::Value>(&header_bytes) else {
1723 return false;
1724 };
1725 header.get("alg").is_some()
1726}
1727
1728#[must_use]
1738pub fn protected_resource_metadata(
1739 resource_url: &str,
1740 server_url: &str,
1741 config: &OAuthConfig,
1742) -> serde_json::Value {
1743 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
1748 let auth_server = server_url;
1749 serde_json::json!({
1750 "resource": resource_url,
1751 "authorization_servers": [auth_server],
1752 "scopes_supported": scopes,
1753 "bearer_methods_supported": ["header"]
1754 })
1755}
1756
1757#[must_use]
1762pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
1763 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
1764 let mut meta = serde_json::json!({
1765 "issuer": &config.issuer,
1766 "authorization_endpoint": format!("{server_url}/authorize"),
1767 "token_endpoint": format!("{server_url}/token"),
1768 "registration_endpoint": format!("{server_url}/register"),
1769 "response_types_supported": ["code"],
1770 "grant_types_supported": ["authorization_code", "refresh_token"],
1771 "code_challenge_methods_supported": ["S256"],
1772 "scopes_supported": scopes,
1773 "token_endpoint_auth_methods_supported": ["none"],
1774 });
1775 if let Some(proxy) = &config.proxy
1776 && proxy.expose_admin_endpoints
1777 && let Some(obj) = meta.as_object_mut()
1778 {
1779 if proxy.introspection_url.is_some() {
1780 obj.insert(
1781 "introspection_endpoint".into(),
1782 serde_json::Value::String(format!("{server_url}/introspect")),
1783 );
1784 }
1785 if proxy.revocation_url.is_some() {
1786 obj.insert(
1787 "revocation_endpoint".into(),
1788 serde_json::Value::String(format!("{server_url}/revoke")),
1789 );
1790 }
1791 if proxy.require_auth_on_admin_endpoints {
1792 obj.insert(
1793 "introspection_endpoint_auth_methods_supported".into(),
1794 serde_json::json!(["bearer"]),
1795 );
1796 obj.insert(
1797 "revocation_endpoint_auth_methods_supported".into(),
1798 serde_json::json!(["bearer"]),
1799 );
1800 }
1801 }
1802 meta
1803}
1804
1805#[must_use]
1818pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
1819 use axum::{
1820 http::{StatusCode, header},
1821 response::IntoResponse,
1822 };
1823
1824 let upstream_query = replace_client_id(query, &proxy.client_id);
1826 let redirect_url = format!("{}?{upstream_query}", proxy.authorize_url);
1827
1828 (StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
1829}
1830
1831pub async fn handle_token(
1837 http: &OauthHttpClient,
1838 proxy: &OAuthProxyConfig,
1839 body: &str,
1840) -> axum::response::Response {
1841 use axum::{
1842 http::{StatusCode, header},
1843 response::IntoResponse,
1844 };
1845
1846 let mut upstream_body = replace_client_id(body, &proxy.client_id);
1848
1849 if let Some(ref secret) = proxy.client_secret {
1851 use std::fmt::Write;
1852
1853 use secrecy::ExposeSecret;
1854 let _ = write!(
1855 upstream_body,
1856 "&client_secret={}",
1857 urlencoding::encode(secret.expose_secret())
1858 );
1859 }
1860
1861 let result = http
1862 .send_screened(
1863 &proxy.token_url,
1864 http.inner
1865 .post(&proxy.token_url)
1866 .header("Content-Type", "application/x-www-form-urlencoded")
1867 .body(upstream_body),
1868 )
1869 .await;
1870
1871 match result {
1872 Ok(resp) => {
1873 let status =
1874 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
1875 let body_bytes = resp.bytes().await.unwrap_or_default();
1876 (
1877 status,
1878 [(header::CONTENT_TYPE, "application/json")],
1879 body_bytes,
1880 )
1881 .into_response()
1882 }
1883 Err(e) => {
1884 tracing::error!(error = %e, "OAuth token proxy request failed");
1885 (
1886 StatusCode::BAD_GATEWAY,
1887 [(header::CONTENT_TYPE, "application/json")],
1888 "{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
1889 )
1890 .into_response()
1891 }
1892 }
1893}
1894
1895#[must_use]
1902pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
1903 let mut resp = serde_json::json!({
1904 "client_id": proxy.client_id,
1905 "token_endpoint_auth_method": "none",
1906 });
1907 if let Some(uris) = body.get("redirect_uris")
1908 && let Some(obj) = resp.as_object_mut()
1909 {
1910 obj.insert("redirect_uris".into(), uris.clone());
1911 }
1912 if let Some(name) = body.get("client_name")
1913 && let Some(obj) = resp.as_object_mut()
1914 {
1915 obj.insert("client_name".into(), name.clone());
1916 }
1917 resp
1918}
1919
1920pub async fn handle_introspect(
1926 http: &OauthHttpClient,
1927 proxy: &OAuthProxyConfig,
1928 body: &str,
1929) -> axum::response::Response {
1930 let Some(ref url) = proxy.introspection_url else {
1931 return oauth_error_response(
1932 axum::http::StatusCode::NOT_FOUND,
1933 "not_supported",
1934 "introspection endpoint is not configured",
1935 );
1936 };
1937 proxy_oauth_admin_request(http, proxy, url, body).await
1938}
1939
1940pub async fn handle_revoke(
1947 http: &OauthHttpClient,
1948 proxy: &OAuthProxyConfig,
1949 body: &str,
1950) -> axum::response::Response {
1951 let Some(ref url) = proxy.revocation_url else {
1952 return oauth_error_response(
1953 axum::http::StatusCode::NOT_FOUND,
1954 "not_supported",
1955 "revocation endpoint is not configured",
1956 );
1957 };
1958 proxy_oauth_admin_request(http, proxy, url, body).await
1959}
1960
1961async fn proxy_oauth_admin_request(
1965 http: &OauthHttpClient,
1966 proxy: &OAuthProxyConfig,
1967 upstream_url: &str,
1968 body: &str,
1969) -> axum::response::Response {
1970 use axum::{
1971 http::{StatusCode, header},
1972 response::IntoResponse,
1973 };
1974
1975 let mut upstream_body = replace_client_id(body, &proxy.client_id);
1976 if let Some(ref secret) = proxy.client_secret {
1977 use std::fmt::Write;
1978
1979 use secrecy::ExposeSecret;
1980 let _ = write!(
1981 upstream_body,
1982 "&client_secret={}",
1983 urlencoding::encode(secret.expose_secret())
1984 );
1985 }
1986
1987 let result = http
1988 .send_screened(
1989 upstream_url,
1990 http.inner
1991 .post(upstream_url)
1992 .header("Content-Type", "application/x-www-form-urlencoded")
1993 .body(upstream_body),
1994 )
1995 .await;
1996
1997 match result {
1998 Ok(resp) => {
1999 let status =
2000 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2001 let content_type = resp
2002 .headers()
2003 .get(header::CONTENT_TYPE)
2004 .and_then(|v| v.to_str().ok())
2005 .unwrap_or("application/json")
2006 .to_owned();
2007 let body_bytes = resp.bytes().await.unwrap_or_default();
2008 (status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
2009 }
2010 Err(e) => {
2011 tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
2012 oauth_error_response(
2013 StatusCode::BAD_GATEWAY,
2014 "server_error",
2015 "upstream endpoint unreachable",
2016 )
2017 }
2018 }
2019}
2020
2021fn oauth_error_response(
2022 status: axum::http::StatusCode,
2023 error: &str,
2024 description: &str,
2025) -> axum::response::Response {
2026 use axum::{http::header, response::IntoResponse};
2027 let body = serde_json::json!({
2028 "error": error,
2029 "error_description": description,
2030 });
2031 (
2032 status,
2033 [(header::CONTENT_TYPE, "application/json")],
2034 body.to_string(),
2035 )
2036 .into_response()
2037}
2038
/// Error body returned by an upstream authorization server, as parsed from
/// JSON. Members other than the two below are ignored.
2039#[derive(Debug, Deserialize)]
2045struct OAuthErrorResponse {
/// Upstream `error` code (required for the body to parse).
2046 error: String,
/// Optional upstream `error_description`.
2047 error_description: Option<String>,
2048}
2049
/// Maps an upstream OAuth error code onto a fixed allowlist.
///
/// A code that exactly matches one of the allowlisted values is returned as
/// the corresponding static string; anything else becomes `"server_error"`,
/// so arbitrary upstream text is never passed through.
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
    const ALLOWED: [&str; 8] = [
        "invalid_request",
        "invalid_client",
        "invalid_grant",
        "unauthorized_client",
        "unsupported_grant_type",
        "invalid_scope",
        "temporarily_unavailable",
        "invalid_target",
    ];

    ALLOWED
        .iter()
        .copied()
        .find(|code| *code == raw)
        .unwrap_or("server_error")
}
2072
2073pub async fn exchange_token(
2085 http: &OauthHttpClient,
2086 config: &TokenExchangeConfig,
2087 subject_token: &str,
2088) -> Result<ExchangedToken, crate::error::McpxError> {
2089 use secrecy::ExposeSecret;
2090
2091 let mut req = http
2092 .inner
2093 .post(&config.token_url)
2094 .header("Content-Type", "application/x-www-form-urlencoded")
2095 .header("Accept", "application/json");
2096
2097 if let Some(ref secret) = config.client_secret {
2099 use base64::Engine;
2100 let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
2101 "{}:{}",
2102 urlencoding::encode(&config.client_id),
2103 urlencoding::encode(secret.expose_secret()),
2104 ));
2105 req = req.header("Authorization", format!("Basic {credentials}"));
2106 }
2107 let form_body = build_exchange_form(config, subject_token);
2110
2111 let resp = http
2112 .send_screened(&config.token_url, req.body(form_body))
2113 .await
2114 .map_err(|e| {
2115 tracing::error!(error = %e, "token exchange request failed");
2116 crate::error::McpxError::Auth("server_error".into())
2118 })?;
2119
2120 let status = resp.status();
2121 let body_bytes = resp.bytes().await.map_err(|e| {
2122 tracing::error!(error = %e, "failed to read token exchange response");
2123 crate::error::McpxError::Auth("server_error".into())
2124 })?;
2125
2126 if !status.is_success() {
2127 core::hint::cold_path();
2128 let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
2131 let short_code = parsed
2132 .as_ref()
2133 .map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
2134 if let Some(ref e) = parsed {
2135 tracing::warn!(
2136 status = %status,
2137 upstream_error = %e.error,
2138 upstream_error_description = e.error_description.as_deref().unwrap_or(""),
2139 client_code = %short_code,
2140 "token exchange rejected by authorization server",
2141 );
2142 } else {
2143 tracing::warn!(
2144 status = %status,
2145 client_code = %short_code,
2146 "token exchange rejected (unparseable upstream body)",
2147 );
2148 }
2149 return Err(crate::error::McpxError::Auth(short_code.into()));
2150 }
2151
2152 let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
2153 tracing::error!(error = %e, "failed to parse token exchange response");
2154 crate::error::McpxError::Auth("server_error".into())
2157 })?;
2158
2159 log_exchanged_token(&exchanged);
2160
2161 Ok(exchanged)
2162}
2163
2164fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
2167 let body = format!(
2168 "grant_type={}&subject_token={}&subject_token_type={}&requested_token_type={}&audience={}",
2169 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
2170 urlencoding::encode(subject_token),
2171 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
2172 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
2173 urlencoding::encode(&config.audience),
2174 );
2175 if config.client_secret.is_none() {
2176 format!(
2177 "{body}&client_id={}",
2178 urlencoding::encode(&config.client_id)
2179 )
2180 } else {
2181 body
2182 }
2183}
2184
2185fn log_exchanged_token(exchanged: &ExchangedToken) {
2188 use base64::Engine;
2189
2190 if !looks_like_jwt(&exchanged.access_token) {
2191 tracing::debug!(
2192 token_len = exchanged.access_token.len(),
2193 issued_token_type = ?exchanged.issued_token_type,
2194 expires_in = exchanged.expires_in,
2195 "exchanged token (opaque)",
2196 );
2197 return;
2198 }
2199 let Some(payload) = exchanged.access_token.split('.').nth(1) else {
2200 return;
2201 };
2202 let Ok(decoded) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload) else {
2203 return;
2204 };
2205 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&decoded) else {
2206 return;
2207 };
2208 tracing::debug!(
2209 sub = ?claims.get("sub"),
2210 aud = ?claims.get("aud"),
2211 azp = ?claims.get("azp"),
2212 iss = ?claims.get("iss"),
2213 expires_in = exchanged.expires_in,
2214 "exchanged token claims (JWT)",
2215 );
2216}
2217
2218fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
2220 let encoded_id = urlencoding::encode(upstream_client_id);
2221 let mut parts: Vec<String> = params
2222 .split('&')
2223 .filter(|p| !p.starts_with("client_id="))
2224 .map(String::from)
2225 .collect();
2226 parts.push(format!("client_id={encoded_id}"));
2227 parts.join("&")
2228}
2229
2230#[cfg(test)]
2231mod tests {
2232 use std::sync::Arc;
2233
2234 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2235
2236 use super::*;
2237
/// A three-segment token whose header decodes to JSON with `alg` is accepted.
#[test]
fn looks_like_jwt_valid() {
    let segments = [
        URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}"),
        URL_SAFE_NO_PAD.encode(b"{}"),
        "signature".to_owned(),
    ];
    let token = segments.join(".");
    assert!(looks_like_jwt(&token));
}
2246
/// A token without any `.` separators is not JWT-shaped.
#[test]
fn looks_like_jwt_rejects_opaque_token() {
    let opaque = "dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4";
    assert!(!looks_like_jwt(opaque));
}
2251
/// Two segments are rejected even when the header itself is valid.
#[test]
fn looks_like_jwt_rejects_two_segments() {
    let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
    let token = [header.as_str(), "payload"].join(".");
    assert!(!looks_like_jwt(&token));
}
2258
/// More than three segments are rejected.
#[test]
fn looks_like_jwt_rejects_four_segments() {
    let token = "a.b.c.d";
    assert!(!looks_like_jwt(token));
}
2263
/// A well-formed header without an `alg` member is rejected.
#[test]
fn looks_like_jwt_rejects_no_alg() {
    let segments = [
        URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}"),
        URL_SAFE_NO_PAD.encode(b"{}"),
        "sig".to_owned(),
    ];
    let token = segments.join(".");
    assert!(!looks_like_jwt(&token));
}
2271
/// The protected-resource metadata exposes the resource, the authorization
/// server, both configured scopes and the `header` bearer method.
#[test]
fn protected_resource_metadata_shape() {
    let scopes = [("mcp:read", "viewer"), ("mcp:admin", "ops")]
        .into_iter()
        .map(|(scope, role)| ScopeMapping {
            scope: scope.into(),
            role: role.into(),
        })
        .collect();
    let config = OAuthConfig {
        issuer: "https://auth.example.com".into(),
        audience: "https://mcp.example.com/mcp".into(),
        jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
        scopes,
        role_claim: None,
        role_mappings: Vec::new(),
        jwks_cache_ttl: "10m".into(),
        proxy: None,
        token_exchange: None,
        ca_cert_path: None,
        allow_http_oauth_urls: false,
        max_jwks_keys: default_max_jwks_keys(),
        strict_audience_validation: false,
        jwks_max_response_bytes: default_jwks_max_bytes(),
    };

    let resource = "https://mcp.example.com/mcp";
    let server = "https://mcp.example.com";
    let meta = protected_resource_metadata(resource, server, &config);

    assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
    assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
    let advertised_scopes = meta["scopes_supported"].as_array().unwrap();
    assert_eq!(advertised_scopes.len(), 2);
    assert_eq!(meta["bearer_methods_supported"][0], "header");
}
2309
2310 fn validation_https_config() -> OAuthConfig {
2315 OAuthConfig::builder(
2316 "https://auth.example.com",
2317 "mcp",
2318 "https://auth.example.com/.well-known/jwks.json",
2319 )
2320 .build()
2321 }
2322
/// The all-HTTPS baseline configuration validates.
#[test]
fn validate_accepts_all_https_urls() {
    let outcome = validation_https_config().validate();
    outcome.expect("all-HTTPS config must validate");
}
2328
/// An `http://` JWKS URI is rejected and the error names the field and scheme.
#[test]
fn validate_rejects_http_jwks_uri() {
    let cfg = OAuthConfig {
        jwks_uri: "http://auth.example.com/.well-known/jwks.json".into(),
        ..validation_https_config()
    };
    let msg = cfg
        .validate()
        .expect_err("http jwks_uri must be rejected")
        .to_string();
    let names_field = msg.contains("oauth.jwks_uri");
    let names_scheme = msg.contains("https");
    assert!(
        names_field && names_scheme,
        "error must reference offending field + scheme requirement; got {msg:?}"
    );
}
2340
/// An `http://` proxy authorize URL is rejected and the error names the field.
#[test]
fn validate_rejects_http_proxy_authorize_url() {
    let proxy = OAuthProxyConfig::builder(
        "http://idp.example.com/authorize",
        "https://idp.example.com/token",
        "client",
    )
    .build();
    let cfg = OAuthConfig {
        proxy: Some(proxy),
        ..validation_https_config()
    };

    let err = cfg
        .validate()
        .expect_err("http authorize_url must be rejected");
    assert!(
        err.to_string().contains("oauth.proxy.authorize_url"),
        "error must reference proxy.authorize_url; got {err}"
    );
}
2360
/// An `http://` proxy token URL is rejected and the error names the field.
#[test]
fn validate_rejects_http_proxy_token_url() {
    let proxy = OAuthProxyConfig::builder(
        "https://idp.example.com/authorize",
        "http://idp.example.com/token",
        "client",
    )
    .build();
    let cfg = OAuthConfig {
        proxy: Some(proxy),
        ..validation_https_config()
    };

    let err = cfg.validate().expect_err("http token_url must be rejected");
    assert!(
        err.to_string().contains("oauth.proxy.token_url"),
        "error must reference proxy.token_url; got {err}"
    );
}
2378
/// `http://` introspection and revocation URLs are each rejected, and each
/// error names its own field.
#[test]
fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
    // Fresh builder with HTTPS authorize/token URLs for each scenario.
    let https_proxy = || {
        OAuthProxyConfig::builder(
            "https://idp.example.com/authorize",
            "https://idp.example.com/token",
            "client",
        )
    };

    let with_http_introspection = OAuthConfig {
        proxy: Some(
            https_proxy()
                .introspection_url("http://idp.example.com/introspect")
                .build(),
        ),
        ..validation_https_config()
    };
    let err = with_http_introspection
        .validate()
        .expect_err("http introspection_url must be rejected");
    assert!(err.to_string().contains("oauth.proxy.introspection_url"));

    let with_http_revocation = OAuthConfig {
        proxy: Some(
            https_proxy()
                .revocation_url("http://idp.example.com/revoke")
                .build(),
        ),
        ..validation_https_config()
    };
    let err = with_http_revocation
        .validate()
        .expect_err("http revocation_url must be rejected");
    assert!(err.to_string().contains("oauth.proxy.revocation_url"));
}
2411
/// An `http://` token-exchange URL is rejected and the error names the field.
#[test]
fn validate_rejects_http_token_exchange_url() {
    let exchange = TokenExchangeConfig::new(
        "http://idp.example.com/token".into(),
        "client".into(),
        None,
        None,
        "downstream".into(),
    );
    let cfg = OAuthConfig {
        token_exchange: Some(exchange),
        ..validation_https_config()
    };

    let err = cfg
        .validate()
        .expect_err("http token_exchange.token_url must be rejected");
    assert!(
        err.to_string().contains("oauth.token_exchange.token_url"),
        "error must reference token_exchange.token_url; got {err}"
    );
}
2430
/// A JWKS URI that is not a URL at all is rejected as an invalid URL.
#[test]
fn validate_rejects_unparseable_url() {
    let cfg = OAuthConfig {
        jwks_uri: "not a url".into(),
        ..validation_https_config()
    };
    let message = cfg
        .validate()
        .expect_err("unparseable URL must be rejected")
        .to_string();
    assert!(message.contains("invalid URL"));
}
2440
/// A `file://` JWKS URI is rejected; the error states the https requirement
/// and mentions the offending scheme.
#[test]
fn validate_rejects_non_http_scheme() {
    let cfg = OAuthConfig {
        jwks_uri: "file:///etc/passwd".into(),
        ..validation_https_config()
    };
    let msg = cfg
        .validate()
        .expect_err("file:// scheme must be rejected")
        .to_string();
    let states_requirement = msg.contains("must use https scheme");
    let names_scheme = msg.contains("file");
    assert!(
        states_requirement && names_scheme,
        "error must reject non-http(s) schemes; got {msg:?}"
    );
}
2452
/// With `allow_http_oauth_urls` enabled, `http://` is accepted on every URL
/// field: issuer/JWKS, all proxy URLs and the token-exchange URL.
#[test]
fn validate_accepts_http_with_escape_hatch() {
    let proxy = OAuthProxyConfig::builder(
        "http://idp.local/authorize",
        "http://idp.local/token",
        "client",
    )
    .introspection_url("http://idp.local/introspect")
    .revocation_url("http://idp.local/revoke")
    .build();
    let exchange = TokenExchangeConfig::new(
        "http://idp.local/token".into(),
        "client".into(),
        None,
        None,
        "downstream".into(),
    );

    let base = OAuthConfig::builder(
        "http://auth.local",
        "mcp",
        "http://auth.local/.well-known/jwks.json",
    )
    .allow_http_oauth_urls(true)
    .build();
    let cfg = OAuthConfig {
        proxy: Some(proxy),
        token_exchange: Some(exchange),
        ..base
    };

    cfg.validate()
        .expect("escape hatch must permit http on all URL fields");
}
2486
/// Enabling `allow_http_oauth_urls` does not make an unparseable URL valid.
#[test]
fn validate_with_escape_hatch_still_rejects_unparseable() {
    let cfg = OAuthConfig {
        allow_http_oauth_urls: true,
        jwks_uri: "::not-a-url::".into(),
        ..validation_https_config()
    };
    let _err = cfg
        .validate()
        .expect_err("escape hatch must NOT bypass URL parsing");
}
2497
2498 #[tokio::test]
2499 async fn jwks_cache_rejects_redirect_downgrade_to_http() {
2500 rustls::crypto::ring::default_provider()
2515 .install_default()
2516 .ok();
2517
2518 let policy = reqwest::redirect::Policy::custom(|attempt| {
2519 if attempt.url().scheme() != "https" {
2520 attempt.error("redirect to non-HTTPS URL refused")
2521 } else if attempt.previous().len() >= 2 {
2522 attempt.error("too many redirects (max 2)")
2523 } else {
2524 attempt.follow()
2525 }
2526 });
2527 let client = reqwest::Client::builder()
2528 .timeout(Duration::from_secs(5))
2529 .connect_timeout(Duration::from_secs(3))
2530 .redirect(policy)
2531 .build()
2532 .expect("test client builds");
2533
2534 let mock = wiremock::MockServer::start().await;
2535 wiremock::Mock::given(wiremock::matchers::method("GET"))
2536 .and(wiremock::matchers::path("/jwks.json"))
2537 .respond_with(
2538 wiremock::ResponseTemplate::new(302)
2539 .insert_header("location", "http://example.invalid/jwks.json"),
2540 )
2541 .mount(&mock)
2542 .await;
2543
2544 let url = format!("{}/jwks.json", mock.uri());
2553 let err = client
2554 .get(&url)
2555 .send()
2556 .await
2557 .expect_err("redirect policy must reject scheme downgrade");
2558 let chain = format!("{err:#}");
2559 assert!(
2560 chain.contains("redirect to non-HTTPS URL refused")
2561 || chain.to_lowercase().contains("redirect"),
2562 "error must surface redirect-policy rejection; got {chain:?}"
2563 );
2564 }
2565
2566 use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
2571
2572 fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
2574 let mut rng = rsa::rand_core::OsRng;
2575 let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keypair generation");
2576 let private_pem = private_key
2577 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
2578 .expect("PKCS8 PEM export")
2579 .to_string();
2580
2581 let public_key = private_key.to_public_key();
2582 let n = URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be());
2583 let e = URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be());
2584
2585 let jwks = serde_json::json!({
2586 "keys": [{
2587 "kty": "RSA",
2588 "use": "sig",
2589 "alg": "RS256",
2590 "kid": kid,
2591 "n": n,
2592 "e": e
2593 }]
2594 });
2595
2596 (private_pem, jwks)
2597 }
2598
2599 fn mint_token(
2601 private_pem: &str,
2602 kid: &str,
2603 issuer: &str,
2604 audience: &str,
2605 subject: &str,
2606 scope: &str,
2607 ) -> String {
2608 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
2609 .expect("encoding key from PEM");
2610 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
2611 header.kid = Some(kid.into());
2612
2613 let now = jsonwebtoken::get_current_timestamp();
2614 let claims = serde_json::json!({
2615 "iss": issuer,
2616 "aud": audience,
2617 "sub": subject,
2618 "scope": scope,
2619 "exp": now + 3600,
2620 "iat": now,
2621 });
2622
2623 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
2624 }
2625
2626 fn test_config(jwks_uri: &str) -> OAuthConfig {
2627 OAuthConfig {
2628 issuer: "https://auth.test.local".into(),
2629 audience: "https://mcp.test.local/mcp".into(),
2630 jwks_uri: jwks_uri.into(),
2631 scopes: vec![
2632 ScopeMapping {
2633 scope: "mcp:read".into(),
2634 role: "viewer".into(),
2635 },
2636 ScopeMapping {
2637 scope: "mcp:admin".into(),
2638 role: "ops".into(),
2639 },
2640 ],
2641 role_claim: None,
2642 role_mappings: vec![],
2643 jwks_cache_ttl: "5m".into(),
2644 proxy: None,
2645 token_exchange: None,
2646 ca_cert_path: None,
2647 allow_http_oauth_urls: true,
2648 max_jwks_keys: default_max_jwks_keys(),
2649 strict_audience_validation: false,
2650 jwks_max_response_bytes: default_jwks_max_bytes(),
2651 }
2652 }
2653
2654 fn test_cache(config: &OAuthConfig) -> JwksCache {
2655 JwksCache::new(config).unwrap().__test_allow_loopback_ssrf()
2656 }
2657
2658 #[tokio::test]
2659 async fn valid_jwt_returns_identity() {
2660 let kid = "test-key-1";
2661 let (pem, jwks) = generate_test_keypair(kid);
2662
2663 let mock_server = wiremock::MockServer::start().await;
2664 wiremock::Mock::given(wiremock::matchers::method("GET"))
2665 .and(wiremock::matchers::path("/jwks.json"))
2666 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2667 .mount(&mock_server)
2668 .await;
2669
2670 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2671 let config = test_config(&jwks_uri);
2672 let cache = test_cache(&config);
2673
2674 let token = mint_token(
2675 &pem,
2676 kid,
2677 "https://auth.test.local",
2678 "https://mcp.test.local/mcp",
2679 "ci-bot",
2680 "mcp:read mcp:other",
2681 );
2682
2683 let identity = cache.validate_token(&token).await;
2684 assert!(identity.is_some(), "valid JWT should authenticate");
2685 let id = identity.unwrap();
2686 assert_eq!(id.name, "ci-bot");
2687 assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
2689 }
2690
2691 #[tokio::test]
2692 async fn wrong_issuer_rejected() {
2693 let kid = "test-key-2";
2694 let (pem, jwks) = generate_test_keypair(kid);
2695
2696 let mock_server = wiremock::MockServer::start().await;
2697 wiremock::Mock::given(wiremock::matchers::method("GET"))
2698 .and(wiremock::matchers::path("/jwks.json"))
2699 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2700 .mount(&mock_server)
2701 .await;
2702
2703 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2704 let config = test_config(&jwks_uri);
2705 let cache = test_cache(&config);
2706
2707 let token = mint_token(
2708 &pem,
2709 kid,
2710 "https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
2712 "attacker",
2713 "mcp:admin",
2714 );
2715
2716 assert!(cache.validate_token(&token).await.is_none());
2717 }
2718
2719 #[tokio::test]
2720 async fn wrong_audience_rejected() {
2721 let kid = "test-key-3";
2722 let (pem, jwks) = generate_test_keypair(kid);
2723
2724 let mock_server = wiremock::MockServer::start().await;
2725 wiremock::Mock::given(wiremock::matchers::method("GET"))
2726 .and(wiremock::matchers::path("/jwks.json"))
2727 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2728 .mount(&mock_server)
2729 .await;
2730
2731 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2732 let config = test_config(&jwks_uri);
2733 let cache = test_cache(&config);
2734
2735 let token = mint_token(
2736 &pem,
2737 kid,
2738 "https://auth.test.local",
2739 "https://wrong-audience.example.com", "attacker",
2741 "mcp:admin",
2742 );
2743
2744 assert!(cache.validate_token(&token).await.is_none());
2745 }
2746
2747 #[tokio::test]
2748 async fn expired_jwt_rejected() {
2749 let kid = "test-key-4";
2750 let (pem, jwks) = generate_test_keypair(kid);
2751
2752 let mock_server = wiremock::MockServer::start().await;
2753 wiremock::Mock::given(wiremock::matchers::method("GET"))
2754 .and(wiremock::matchers::path("/jwks.json"))
2755 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2756 .mount(&mock_server)
2757 .await;
2758
2759 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2760 let config = test_config(&jwks_uri);
2761 let cache = test_cache(&config);
2762
2763 let encoding_key =
2765 jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
2766 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
2767 header.kid = Some(kid.into());
2768 let now = jsonwebtoken::get_current_timestamp();
2769 let claims = serde_json::json!({
2770 "iss": "https://auth.test.local",
2771 "aud": "https://mcp.test.local/mcp",
2772 "sub": "expired-bot",
2773 "scope": "mcp:read",
2774 "exp": now - 120,
2775 "iat": now - 3720,
2776 });
2777 let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
2778
2779 assert!(cache.validate_token(&token).await.is_none());
2780 }
2781
2782 #[tokio::test]
2783 async fn no_matching_scope_rejected() {
2784 let kid = "test-key-5";
2785 let (pem, jwks) = generate_test_keypair(kid);
2786
2787 let mock_server = wiremock::MockServer::start().await;
2788 wiremock::Mock::given(wiremock::matchers::method("GET"))
2789 .and(wiremock::matchers::path("/jwks.json"))
2790 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2791 .mount(&mock_server)
2792 .await;
2793
2794 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2795 let config = test_config(&jwks_uri);
2796 let cache = test_cache(&config);
2797
2798 let token = mint_token(
2799 &pem,
2800 kid,
2801 "https://auth.test.local",
2802 "https://mcp.test.local/mcp",
2803 "limited-bot",
2804 "some:other:scope", );
2806
2807 assert!(cache.validate_token(&token).await.is_none());
2808 }
2809
2810 #[tokio::test]
2811 async fn wrong_signing_key_rejected() {
2812 let kid = "test-key-6";
2813 let (_pem, jwks) = generate_test_keypair(kid);
2814
2815 let (attacker_pem, _) = generate_test_keypair(kid);
2817
2818 let mock_server = wiremock::MockServer::start().await;
2819 wiremock::Mock::given(wiremock::matchers::method("GET"))
2820 .and(wiremock::matchers::path("/jwks.json"))
2821 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2822 .mount(&mock_server)
2823 .await;
2824
2825 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2826 let config = test_config(&jwks_uri);
2827 let cache = test_cache(&config);
2828
2829 let token = mint_token(
2831 &attacker_pem,
2832 kid,
2833 "https://auth.test.local",
2834 "https://mcp.test.local/mcp",
2835 "attacker",
2836 "mcp:admin",
2837 );
2838
2839 assert!(cache.validate_token(&token).await.is_none());
2840 }
2841
2842 #[tokio::test]
2843 async fn admin_scope_maps_to_ops_role() {
2844 let kid = "test-key-7";
2845 let (pem, jwks) = generate_test_keypair(kid);
2846
2847 let mock_server = wiremock::MockServer::start().await;
2848 wiremock::Mock::given(wiremock::matchers::method("GET"))
2849 .and(wiremock::matchers::path("/jwks.json"))
2850 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2851 .mount(&mock_server)
2852 .await;
2853
2854 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2855 let config = test_config(&jwks_uri);
2856 let cache = test_cache(&config);
2857
2858 let token = mint_token(
2859 &pem,
2860 kid,
2861 "https://auth.test.local",
2862 "https://mcp.test.local/mcp",
2863 "admin-bot",
2864 "mcp:admin",
2865 );
2866
2867 let id = cache
2868 .validate_token(&token)
2869 .await
2870 .expect("should authenticate");
2871 assert_eq!(id.role, "ops");
2872 assert_eq!(id.name, "admin-bot");
2873 }
2874
2875 #[tokio::test]
2876 async fn jwks_server_down_returns_none() {
2877 let config = test_config("http://127.0.0.1:1/jwks.json");
2879 let cache = test_cache(&config);
2880
2881 let kid = "orphan-key";
2882 let (pem, _) = generate_test_keypair(kid);
2883 let token = mint_token(
2884 &pem,
2885 kid,
2886 "https://auth.test.local",
2887 "https://mcp.test.local/mcp",
2888 "bot",
2889 "mcp:read",
2890 );
2891
2892 assert!(cache.validate_token(&token).await.is_none());
2893 }
2894
2895 #[test]
2900 fn resolve_claim_path_flat_string() {
2901 let mut extra = HashMap::new();
2902 extra.insert(
2903 "scope".into(),
2904 serde_json::Value::String("mcp:read mcp:admin".into()),
2905 );
2906 let values = resolve_claim_path(&extra, "scope");
2907 assert_eq!(values, vec!["mcp:read", "mcp:admin"]);
2908 }
2909
2910 #[test]
2911 fn resolve_claim_path_flat_array() {
2912 let mut extra = HashMap::new();
2913 extra.insert(
2914 "roles".into(),
2915 serde_json::json!(["mcp-admin", "mcp-viewer"]),
2916 );
2917 let values = resolve_claim_path(&extra, "roles");
2918 assert_eq!(values, vec!["mcp-admin", "mcp-viewer"]);
2919 }
2920
2921 #[test]
2922 fn resolve_claim_path_nested_keycloak() {
2923 let mut extra = HashMap::new();
2924 extra.insert(
2925 "realm_access".into(),
2926 serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]}),
2927 );
2928 let values = resolve_claim_path(&extra, "realm_access.roles");
2929 assert_eq!(values, vec!["uma_authorization", "mcp-admin"]);
2930 }
2931
2932 #[test]
2933 fn resolve_claim_path_missing_returns_empty() {
2934 let extra = HashMap::new();
2935 assert!(resolve_claim_path(&extra, "nonexistent.path").is_empty());
2936 }
2937
2938 #[test]
2939 fn resolve_claim_path_numeric_leaf_returns_empty() {
2940 let mut extra = HashMap::new();
2941 extra.insert("count".into(), serde_json::json!(42));
2942 assert!(resolve_claim_path(&extra, "count").is_empty());
2943 }
2944
2945 fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
2951 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
2952 .expect("encoding key from PEM");
2953 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
2954 header.kid = Some(kid.into());
2955 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
2956 }
2957
2958 fn test_config_with_role_claim(
2959 jwks_uri: &str,
2960 role_claim: &str,
2961 role_mappings: Vec<RoleMapping>,
2962 ) -> OAuthConfig {
2963 OAuthConfig {
2964 issuer: "https://auth.test.local".into(),
2965 audience: "https://mcp.test.local/mcp".into(),
2966 jwks_uri: jwks_uri.into(),
2967 scopes: vec![],
2968 role_claim: Some(role_claim.into()),
2969 role_mappings,
2970 jwks_cache_ttl: "5m".into(),
2971 proxy: None,
2972 token_exchange: None,
2973 ca_cert_path: None,
2974 allow_http_oauth_urls: true,
2975 max_jwks_keys: default_max_jwks_keys(),
2976 strict_audience_validation: false,
2977 jwks_max_response_bytes: default_jwks_max_bytes(),
2978 }
2979 }
2980
2981 #[tokio::test]
2982 async fn screen_oauth_target_rejects_literal_ip() {
2983 let err = screen_oauth_target("https://127.0.0.1/jwks.json", false)
2984 .await
2985 .expect_err("literal IPs must be rejected");
2986 let msg = err.to_string();
2987 assert!(msg.contains("literal IPv4 addresses are forbidden"));
2988 }
2989
2990 #[tokio::test]
2991 async fn screen_oauth_target_rejects_private_dns_resolution() {
2992 let err = screen_oauth_target("https://localhost/jwks.json", false)
2993 .await
2994 .expect_err("localhost resolution must be rejected");
2995 let msg = err.to_string();
2996 assert!(
2997 msg.contains("blocked IP") && msg.contains("loopback"),
2998 "got {msg:?}"
2999 );
3000 }
3001
3002 #[tokio::test]
3003 async fn screen_oauth_target_rejects_literal_ip_even_with_allow_http() {
3004 let err = screen_oauth_target("http://127.0.0.1/jwks.json", true)
3005 .await
3006 .expect_err("literal IPs must still be rejected when http is allowed");
3007 let msg = err.to_string();
3008 assert!(msg.contains("literal IPv4 addresses are forbidden"));
3009 }
3010
3011 #[tokio::test]
3012 async fn screen_oauth_target_rejects_private_dns_even_with_allow_http() {
3013 let err = screen_oauth_target("http://localhost/jwks.json", true)
3014 .await
3015 .expect_err("private DNS resolution must still be rejected when http is allowed");
3016 let msg = err.to_string();
3017 assert!(
3018 msg.contains("blocked IP") && msg.contains("loopback"),
3019 "got {msg:?}"
3020 );
3021 }
3022
3023 #[tokio::test]
3024 async fn screen_oauth_target_allows_public_hostname() {
3025 screen_oauth_target("https://example.com/.well-known/jwks.json", false)
3026 .await
3027 .expect("public hostname should pass screening");
3028 }
3029
3030 #[tokio::test]
3031 async fn audience_falls_back_to_azp_by_default() {
3032 let kid = "test-audience-azp-default";
3033 let (pem, jwks) = generate_test_keypair(kid);
3034
3035 let mock_server = wiremock::MockServer::start().await;
3036 wiremock::Mock::given(wiremock::matchers::method("GET"))
3037 .and(wiremock::matchers::path("/jwks.json"))
3038 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3039 .mount(&mock_server)
3040 .await;
3041
3042 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3043 let config = test_config(&jwks_uri);
3044 let cache = test_cache(&config);
3045
3046 let now = jsonwebtoken::get_current_timestamp();
3047 let token = mint_token_with_claims(
3048 &pem,
3049 kid,
3050 &serde_json::json!({
3051 "iss": "https://auth.test.local",
3052 "aud": "https://some-other-resource.example.com",
3053 "azp": "https://mcp.test.local/mcp",
3054 "sub": "compat-client",
3055 "scope": "mcp:read",
3056 "exp": now + 3600,
3057 "iat": now,
3058 }),
3059 );
3060
3061 let identity = cache
3062 .validate_token_with_reason(&token)
3063 .await
3064 .expect("azp fallback should remain enabled by default");
3065 assert_eq!(identity.role, "viewer");
3066 }
3067
3068 #[tokio::test]
3069 async fn strict_audience_validation_rejects_azp_only_match() {
3070 let kid = "test-audience-azp-strict";
3071 let (pem, jwks) = generate_test_keypair(kid);
3072
3073 let mock_server = wiremock::MockServer::start().await;
3074 wiremock::Mock::given(wiremock::matchers::method("GET"))
3075 .and(wiremock::matchers::path("/jwks.json"))
3076 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3077 .mount(&mock_server)
3078 .await;
3079
3080 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3081 let mut config = test_config(&jwks_uri);
3082 config.strict_audience_validation = true;
3083 let cache = test_cache(&config);
3084
3085 let now = jsonwebtoken::get_current_timestamp();
3086 let token = mint_token_with_claims(
3087 &pem,
3088 kid,
3089 &serde_json::json!({
3090 "iss": "https://auth.test.local",
3091 "aud": "https://some-other-resource.example.com",
3092 "azp": "https://mcp.test.local/mcp",
3093 "sub": "strict-client",
3094 "scope": "mcp:read",
3095 "exp": now + 3600,
3096 "iat": now,
3097 }),
3098 );
3099
3100 let failure = cache
3101 .validate_token_with_reason(&token)
3102 .await
3103 .expect_err("strict audience validation must ignore azp fallback");
3104 assert_eq!(failure, JwtValidationFailure::Invalid);
3105 }
3106
3107 #[derive(Clone, Default)]
3108 struct CapturedLogs(Arc<std::sync::Mutex<Vec<u8>>>);
3109
3110 impl CapturedLogs {
3111 fn contents(&self) -> String {
3112 let bytes = self.0.lock().map(|guard| guard.clone()).unwrap_or_default();
3113 String::from_utf8(bytes).unwrap_or_default()
3114 }
3115 }
3116
3117 struct CapturedLogsWriter(Arc<std::sync::Mutex<Vec<u8>>>);
3118
3119 impl std::io::Write for CapturedLogsWriter {
3120 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
3121 if let Ok(mut guard) = self.0.lock() {
3122 guard.extend_from_slice(buf);
3123 }
3124 Ok(buf.len())
3125 }
3126
3127 fn flush(&mut self) -> std::io::Result<()> {
3128 Ok(())
3129 }
3130 }
3131
3132 impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CapturedLogs {
3133 type Writer = CapturedLogsWriter;
3134
3135 fn make_writer(&'a self) -> Self::Writer {
3136 CapturedLogsWriter(Arc::clone(&self.0))
3137 }
3138 }
3139
3140 #[tokio::test]
3141 async fn jwks_response_size_cap_returns_none_and_logs_warning() {
3142 let kid = "oversized-jwks";
3143 let (_pem, jwks) = generate_test_keypair(kid);
3144 let mut oversized_body = serde_json::to_string(&jwks).expect("jwks json");
3145 oversized_body.push_str(&" ".repeat(4096));
3146
3147 let mock_server = wiremock::MockServer::start().await;
3148 wiremock::Mock::given(wiremock::matchers::method("GET"))
3149 .and(wiremock::matchers::path("/jwks.json"))
3150 .respond_with(
3151 wiremock::ResponseTemplate::new(200)
3152 .insert_header("content-type", "application/json")
3153 .set_body_string(oversized_body),
3154 )
3155 .mount(&mock_server)
3156 .await;
3157
3158 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3159 let mut config = test_config(&jwks_uri);
3160 config.jwks_max_response_bytes = 256;
3161 let cache = test_cache(&config);
3162
3163 let logs = CapturedLogs::default();
3164 let subscriber = tracing_subscriber::fmt()
3165 .with_writer(logs.clone())
3166 .with_ansi(false)
3167 .without_time()
3168 .finish();
3169 let _guard = tracing::subscriber::set_default(subscriber);
3170
3171 let result = cache.fetch_jwks().await;
3172 assert!(result.is_none(), "oversized JWKS must be dropped");
3173 assert!(
3174 logs.contents()
3175 .contains("JWKS response exceeded configured size cap"),
3176 "expected cap-exceeded warning in logs"
3177 );
3178 }
3179
3180 #[tokio::test]
3181 async fn role_claim_keycloak_nested_array() {
3182 let kid = "test-role-1";
3183 let (pem, jwks) = generate_test_keypair(kid);
3184
3185 let mock_server = wiremock::MockServer::start().await;
3186 wiremock::Mock::given(wiremock::matchers::method("GET"))
3187 .and(wiremock::matchers::path("/jwks.json"))
3188 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3189 .mount(&mock_server)
3190 .await;
3191
3192 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3193 let config = test_config_with_role_claim(
3194 &jwks_uri,
3195 "realm_access.roles",
3196 vec![
3197 RoleMapping {
3198 claim_value: "mcp-admin".into(),
3199 role: "ops".into(),
3200 },
3201 RoleMapping {
3202 claim_value: "mcp-viewer".into(),
3203 role: "viewer".into(),
3204 },
3205 ],
3206 );
3207 let cache = test_cache(&config);
3208
3209 let now = jsonwebtoken::get_current_timestamp();
3210 let token = mint_token_with_claims(
3211 &pem,
3212 kid,
3213 &serde_json::json!({
3214 "iss": "https://auth.test.local",
3215 "aud": "https://mcp.test.local/mcp",
3216 "sub": "keycloak-user",
3217 "exp": now + 3600,
3218 "iat": now,
3219 "realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
3220 }),
3221 );
3222
3223 let id = cache
3224 .validate_token(&token)
3225 .await
3226 .expect("should authenticate");
3227 assert_eq!(id.name, "keycloak-user");
3228 assert_eq!(id.role, "ops");
3229 }
3230
3231 #[tokio::test]
3232 async fn role_claim_flat_roles_array() {
3233 let kid = "test-role-2";
3234 let (pem, jwks) = generate_test_keypair(kid);
3235
3236 let mock_server = wiremock::MockServer::start().await;
3237 wiremock::Mock::given(wiremock::matchers::method("GET"))
3238 .and(wiremock::matchers::path("/jwks.json"))
3239 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3240 .mount(&mock_server)
3241 .await;
3242
3243 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3244 let config = test_config_with_role_claim(
3245 &jwks_uri,
3246 "roles",
3247 vec![
3248 RoleMapping {
3249 claim_value: "MCP.Admin".into(),
3250 role: "ops".into(),
3251 },
3252 RoleMapping {
3253 claim_value: "MCP.Reader".into(),
3254 role: "viewer".into(),
3255 },
3256 ],
3257 );
3258 let cache = test_cache(&config);
3259
3260 let now = jsonwebtoken::get_current_timestamp();
3261 let token = mint_token_with_claims(
3262 &pem,
3263 kid,
3264 &serde_json::json!({
3265 "iss": "https://auth.test.local",
3266 "aud": "https://mcp.test.local/mcp",
3267 "sub": "azure-ad-user",
3268 "exp": now + 3600,
3269 "iat": now,
3270 "roles": ["MCP.Reader", "OtherApp.Admin"]
3271 }),
3272 );
3273
3274 let id = cache
3275 .validate_token(&token)
3276 .await
3277 .expect("should authenticate");
3278 assert_eq!(id.name, "azure-ad-user");
3279 assert_eq!(id.role, "viewer");
3280 }
3281
3282 #[tokio::test]
3283 async fn role_claim_no_matching_value_rejected() {
3284 let kid = "test-role-3";
3285 let (pem, jwks) = generate_test_keypair(kid);
3286
3287 let mock_server = wiremock::MockServer::start().await;
3288 wiremock::Mock::given(wiremock::matchers::method("GET"))
3289 .and(wiremock::matchers::path("/jwks.json"))
3290 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3291 .mount(&mock_server)
3292 .await;
3293
3294 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3295 let config = test_config_with_role_claim(
3296 &jwks_uri,
3297 "roles",
3298 vec![RoleMapping {
3299 claim_value: "mcp-admin".into(),
3300 role: "ops".into(),
3301 }],
3302 );
3303 let cache = test_cache(&config);
3304
3305 let now = jsonwebtoken::get_current_timestamp();
3306 let token = mint_token_with_claims(
3307 &pem,
3308 kid,
3309 &serde_json::json!({
3310 "iss": "https://auth.test.local",
3311 "aud": "https://mcp.test.local/mcp",
3312 "sub": "limited-user",
3313 "exp": now + 3600,
3314 "iat": now,
3315 "roles": ["some-other-role"]
3316 }),
3317 );
3318
3319 assert!(cache.validate_token(&token).await.is_none());
3320 }
3321
3322 #[tokio::test]
3323 async fn role_claim_space_separated_string() {
3324 let kid = "test-role-4";
3325 let (pem, jwks) = generate_test_keypair(kid);
3326
3327 let mock_server = wiremock::MockServer::start().await;
3328 wiremock::Mock::given(wiremock::matchers::method("GET"))
3329 .and(wiremock::matchers::path("/jwks.json"))
3330 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3331 .mount(&mock_server)
3332 .await;
3333
3334 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3335 let config = test_config_with_role_claim(
3336 &jwks_uri,
3337 "custom_scope",
3338 vec![
3339 RoleMapping {
3340 claim_value: "write".into(),
3341 role: "ops".into(),
3342 },
3343 RoleMapping {
3344 claim_value: "read".into(),
3345 role: "viewer".into(),
3346 },
3347 ],
3348 );
3349 let cache = test_cache(&config);
3350
3351 let now = jsonwebtoken::get_current_timestamp();
3352 let token = mint_token_with_claims(
3353 &pem,
3354 kid,
3355 &serde_json::json!({
3356 "iss": "https://auth.test.local",
3357 "aud": "https://mcp.test.local/mcp",
3358 "sub": "custom-client",
3359 "exp": now + 3600,
3360 "iat": now,
3361 "custom_scope": "read audit"
3362 }),
3363 );
3364
3365 let id = cache
3366 .validate_token(&token)
3367 .await
3368 .expect("should authenticate");
3369 assert_eq!(id.name, "custom-client");
3370 assert_eq!(id.role, "viewer");
3371 }
3372
3373 #[tokio::test]
3374 async fn scope_backward_compat_without_role_claim() {
3375 let kid = "test-compat-1";
3377 let (pem, jwks) = generate_test_keypair(kid);
3378
3379 let mock_server = wiremock::MockServer::start().await;
3380 wiremock::Mock::given(wiremock::matchers::method("GET"))
3381 .and(wiremock::matchers::path("/jwks.json"))
3382 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3383 .mount(&mock_server)
3384 .await;
3385
3386 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3387 let config = test_config(&jwks_uri); let cache = test_cache(&config);
3389
3390 let token = mint_token(
3391 &pem,
3392 kid,
3393 "https://auth.test.local",
3394 "https://mcp.test.local/mcp",
3395 "legacy-bot",
3396 "mcp:admin other:scope",
3397 );
3398
3399 let id = cache
3400 .validate_token(&token)
3401 .await
3402 .expect("should authenticate");
3403 assert_eq!(id.name, "legacy-bot");
3404 assert_eq!(id.role, "ops"); }
3406
3407 #[tokio::test]
3412 async fn jwks_refresh_deduplication() {
3413 let kid = "test-dedup";
3416 let (pem, jwks) = generate_test_keypair(kid);
3417
3418 let mock_server = wiremock::MockServer::start().await;
3419 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
3420 .and(wiremock::matchers::path("/jwks.json"))
3421 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3422 .expect(1) .mount(&mock_server)
3424 .await;
3425
3426 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3427 let config = test_config(&jwks_uri);
3428 let cache = Arc::new(test_cache(&config));
3429
3430 let token = mint_token(
3432 &pem,
3433 kid,
3434 "https://auth.test.local",
3435 "https://mcp.test.local/mcp",
3436 "concurrent-bot",
3437 "mcp:read",
3438 );
3439
3440 let mut handles = Vec::new();
3441 for _ in 0..5 {
3442 let c = Arc::clone(&cache);
3443 let t = token.clone();
3444 handles.push(tokio::spawn(async move { c.validate_token(&t).await }));
3445 }
3446
3447 for h in handles {
3448 let result = h.await.unwrap();
3449 assert!(result.is_some(), "all concurrent requests should succeed");
3450 }
3451
3452 }
3454
3455 #[tokio::test]
3456 async fn jwks_refresh_cooldown_blocks_rapid_requests() {
3457 let kid = "test-cooldown";
3460 let (_pem, jwks) = generate_test_keypair(kid);
3461
3462 let mock_server = wiremock::MockServer::start().await;
3463 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
3464 .and(wiremock::matchers::path("/jwks.json"))
3465 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3466 .expect(1) .mount(&mock_server)
3468 .await;
3469
3470 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3471 let config = test_config(&jwks_uri);
3472 let cache = test_cache(&config);
3473
3474 let fake_token1 =
3476 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig";
3477 let _ = cache.validate_token(fake_token1).await;
3478
3479 let fake_token2 =
3482 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig";
3483 let _ = cache.validate_token(fake_token2).await;
3484
3485 let fake_token3 =
3487 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig";
3488 let _ = cache.validate_token(fake_token3).await;
3489
3490 }
3492
3493 fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
3496 OAuthProxyConfig {
3497 authorize_url: "https://example.invalid/auth".into(),
3498 token_url: token_url.into(),
3499 client_id: "mcp-client".into(),
3500 client_secret: Some(secrecy::SecretString::from("shh".to_owned())),
3501 introspection_url: None,
3502 revocation_url: None,
3503 expose_admin_endpoints: false,
3504 require_auth_on_admin_endpoints: false,
3505 }
3506 }
3507
3508 fn test_http_client() -> OauthHttpClient {
3511 rustls::crypto::ring::default_provider()
3512 .install_default()
3513 .ok();
3514 let config = OAuthConfig::builder(
3515 "https://auth.test.local",
3516 "https://mcp.test.local/mcp",
3517 "https://auth.test.local/.well-known/jwks.json",
3518 )
3519 .allow_http_oauth_urls(true)
3520 .build();
3521 OauthHttpClient::with_config(&config)
3522 .expect("build test http client")
3523 .__test_allow_loopback_ssrf()
3524 }
3525
3526 #[tokio::test]
3527 async fn introspect_proxies_and_injects_client_credentials() {
3528 use wiremock::matchers::{body_string_contains, method, path};
3529
3530 let mock_server = wiremock::MockServer::start().await;
3531 wiremock::Mock::given(method("POST"))
3532 .and(path("/introspect"))
3533 .and(body_string_contains("client_id=mcp-client"))
3534 .and(body_string_contains("client_secret=shh"))
3535 .and(body_string_contains("token=abc"))
3536 .respond_with(
3537 wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
3538 "active": true,
3539 "scope": "read"
3540 })),
3541 )
3542 .expect(1)
3543 .mount(&mock_server)
3544 .await;
3545
3546 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
3547 proxy.introspection_url = Some(format!("{}/introspect", mock_server.uri()));
3548
3549 let http = test_http_client();
3550 let resp = handle_introspect(&http, &proxy, "token=abc").await;
3551 assert_eq!(resp.status(), 200);
3552 }
3553
3554 #[tokio::test]
3555 async fn introspect_returns_404_when_not_configured() {
3556 let proxy = proxy_cfg("https://example.invalid/token");
3557 let http = test_http_client();
3558 let resp = handle_introspect(&http, &proxy, "token=abc").await;
3559 assert_eq!(resp.status(), 404);
3560 }
3561
3562 #[tokio::test]
3563 async fn revoke_proxies_and_returns_upstream_status() {
3564 use wiremock::matchers::{method, path};
3565
3566 let mock_server = wiremock::MockServer::start().await;
3567 wiremock::Mock::given(method("POST"))
3568 .and(path("/revoke"))
3569 .respond_with(wiremock::ResponseTemplate::new(200))
3570 .expect(1)
3571 .mount(&mock_server)
3572 .await;
3573
3574 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
3575 proxy.revocation_url = Some(format!("{}/revoke", mock_server.uri()));
3576
3577 let http = test_http_client();
3578 let resp = handle_revoke(&http, &proxy, "token=abc").await;
3579 assert_eq!(resp.status(), 200);
3580 }
3581
3582 #[tokio::test]
3583 async fn revoke_returns_404_when_not_configured() {
3584 let proxy = proxy_cfg("https://example.invalid/token");
3585 let http = test_http_client();
3586 let resp = handle_revoke(&http, &proxy, "token=abc").await;
3587 assert_eq!(resp.status(), 404);
3588 }
3589
3590 #[test]
3591 fn metadata_advertises_endpoints_only_when_configured() {
3592 let mut cfg = test_config("https://auth.test.local/jwks.json");
3593 let m = authorization_server_metadata("https://mcp.local", &cfg);
3595 assert!(m.get("introspection_endpoint").is_none());
3596 assert!(m.get("revocation_endpoint").is_none());
3597
3598 let mut proxy = proxy_cfg("https://upstream.local/token");
3601 proxy.introspection_url = Some("https://upstream.local/introspect".into());
3602 proxy.revocation_url = Some("https://upstream.local/revoke".into());
3603 cfg.proxy = Some(proxy);
3604 let m = authorization_server_metadata("https://mcp.local", &cfg);
3605 assert!(
3606 m.get("introspection_endpoint").is_none(),
3607 "introspection must not be advertised when expose_admin_endpoints=false"
3608 );
3609 assert!(
3610 m.get("revocation_endpoint").is_none(),
3611 "revocation must not be advertised when expose_admin_endpoints=false"
3612 );
3613
3614 if let Some(p) = cfg.proxy.as_mut() {
3616 p.expose_admin_endpoints = true;
3617 p.revocation_url = None;
3618 }
3619 let m = authorization_server_metadata("https://mcp.local", &cfg);
3620 assert_eq!(
3621 m["introspection_endpoint"],
3622 serde_json::Value::String("https://mcp.local/introspect".into())
3623 );
3624 assert!(m.get("revocation_endpoint").is_none());
3625
3626 if let Some(p) = cfg.proxy.as_mut() {
3628 p.revocation_url = Some("https://upstream.local/revoke".into());
3629 }
3630 let m = authorization_server_metadata("https://mcp.local", &cfg);
3631 assert_eq!(
3632 m["revocation_endpoint"],
3633 serde_json::Value::String("https://mcp.local/revoke".into())
3634 );
3635 }
3636}