1use std::{
17 collections::HashMap,
18 path::PathBuf,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering},
22 },
23 time::{Duration, Instant},
24};
25
26use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
27use serde::Deserialize;
28use tokio::{net::lookup_host, sync::RwLock};
29
30use crate::auth::{AuthIdentity, AuthMethod};
31
32fn evaluate_oauth_redirect(
58 attempt: &reqwest::redirect::Attempt<'_>,
59 allow_http: bool,
60 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
61) -> Result<(), String> {
62 let prev_https = attempt
63 .previous()
64 .last()
65 .is_some_and(|prev| prev.scheme() == "https");
66 let target_url = attempt.url();
67 let dest_scheme = target_url.scheme();
68 if dest_scheme != "https" {
69 if prev_https {
70 return Err("redirect downgrades https -> http".to_owned());
71 }
72 if !allow_http || dest_scheme != "http" {
73 return Err("redirect to non-HTTP(S) URL refused".to_owned());
74 }
75 }
76 if let Some(reason) = crate::ssrf::redirect_target_reason_with_allowlist(target_url, allowlist)
77 {
78 return Err(format!("redirect target forbidden: {reason}"));
79 }
80 if attempt.previous().len() >= 2 {
81 return Err("too many redirects (max 2)".to_owned());
82 }
83 Ok(())
84}
85
86async fn screen_oauth_target_core(
106 url: &str,
107 allow_http: bool,
108 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
109 test_allow_loopback_ssrf: bool,
110) -> Result<(), crate::error::McpxError> {
111 let parsed = check_oauth_url("oauth target", url, allow_http)?;
112 if test_allow_loopback_ssrf {
113 return Ok(());
114 }
115 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
116 return Err(crate::error::McpxError::Config(format!(
117 "OAuth target forbidden ({reason}): {url}"
118 )));
119 }
120
121 let host = parsed.host_str().ok_or_else(|| {
122 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
123 })?;
124 let port = parsed.port_or_known_default().ok_or_else(|| {
125 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
126 })?;
127
128 let addrs = lookup_host((host, port)).await.map_err(|error| {
129 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
130 })?;
131
132 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
133 let mut any_addr = false;
134 for addr in addrs {
135 any_addr = true;
136 let ip = addr.ip();
137 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
138 if reason == "cloud_metadata" {
141 return Err(crate::error::McpxError::Config(format!(
142 "OAuth target resolved to blocked IP ({reason}): {url}"
143 )));
144 }
145 if allowlist.is_empty() {
149 return Err(crate::error::McpxError::Config(format!(
150 "OAuth target resolved to blocked IP ({reason}): {url}"
151 )));
152 }
153 if host_allowed || allowlist.ip_allowed(ip) {
155 continue;
156 }
157 return Err(crate::error::McpxError::Config(format!(
158 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
159 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
160 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
161 URL: {url}"
162 )));
163 }
164 }
165 if !any_addr {
166 return Err(crate::error::McpxError::Config(format!(
167 "OAuth target DNS resolution returned no addresses: {url}"
168 )));
169 }
170
171 Ok(())
172}
173
/// Screens an OAuth target URL with the test loopback override disabled.
///
/// Thin wrapper over `screen_oauth_target_core` that always passes `false`
/// for `test_allow_loopback_ssrf`, so the literal-IP and DNS checks always run.
///
/// # Errors
/// Propagates the `McpxError` returned by `screen_oauth_target_core`.
async fn screen_oauth_target(
    url: &str,
    allow_http: bool,
    allowlist: &crate::ssrf::CompiledSsrfAllowlist,
) -> Result<(), crate::error::McpxError> {
    screen_oauth_target_core(url, allow_http, allowlist, false).await
}
183
/// Test-only variant of `screen_oauth_target` that forwards the caller's
/// `test_allow_loopback_ssrf` flag to `screen_oauth_target_core`.
///
/// Compiled only under `cfg(test)` or the `test-helpers` feature. When the
/// flag is `true`, the core routine returns after URL validation and skips
/// the literal-IP and DNS checks.
///
/// # Errors
/// Propagates the `McpxError` returned by `screen_oauth_target_core`.
#[cfg(any(test, feature = "test-helpers"))]
async fn screen_oauth_target_with_test_override(
    url: &str,
    allow_http: bool,
    allowlist: &crate::ssrf::CompiledSsrfAllowlist,
    test_allow_loopback_ssrf: bool,
) -> Result<(), crate::error::McpxError> {
    screen_oauth_target_core(url, allow_http, allowlist, test_allow_loopback_ssrf).await
}
196
/// HTTP client used for OAuth traffic, bundled with the policy needed to
/// screen request targets before they are sent.
///
/// Constructed via `OauthHttpClient::with_config` (or the deprecated `new`).
#[derive(Clone)]
pub struct OauthHttpClient {
    /// Default `reqwest` client (no client certificate).
    inner: reqwest::Client,
    /// Whether plain `http` URLs are accepted by the screening and redirect checks.
    allow_http: bool,
    /// Compiled SSRF allowlist shared with the redirect policy and DNS resolver.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Clients carrying a TLS client identity, keyed by certificate/key path pair.
    #[cfg(feature = "oauth-mtls-client")]
    mtls_clients: Arc<HashMap<MtlsClientKey, reqwest::Client>>,
    /// Test-only flag that, once set, makes `send_screened` skip the
    /// literal-IP and DNS screening.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
260
/// Lookup key for the mTLS client map: the certificate and private-key paths
/// that together identify one configured client identity.
#[cfg(feature = "oauth-mtls-client")]
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct MtlsClientKey {
    /// Path to the client certificate file.
    cert_path: PathBuf,
    /// Path to the client private-key file.
    key_path: PathBuf,
}
270
impl OauthHttpClient {
    /// Builds a client that honours `config`'s `allow_http_oauth_urls`,
    /// `ssrf_allowlist`, `ca_cert_path` and (with the `oauth-mtls-client`
    /// feature) the token-exchange client certificate.
    ///
    /// # Errors
    /// Returns `McpxError::Startup` when the allowlist does not compile, the CA
    /// certificate cannot be read or parsed, or the underlying client fails to build.
    pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
        Self::build(Some(config))
    }

    /// Builds a client without any `OAuthConfig`: plain `http` is refused, the
    /// allowlist is empty, and no extra CA certificate is installed.
    ///
    /// # Errors
    /// Returns `McpxError::Startup` when the underlying client fails to build.
    #[deprecated(
        since = "1.2.1",
        note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
    )]
    pub fn new() -> Result<Self, crate::error::McpxError> {
        Self::build(None)
    }

    /// Shared constructor behind `with_config` and `new`.
    ///
    /// The client is built with `no_proxy()`, the SSRF-screening DNS resolver,
    /// a 10 s connect timeout, a 30 s overall timeout, and a custom redirect
    /// policy that delegates each hop to `evaluate_oauth_redirect`.
    fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
        // Without a config, plain http stays disabled.
        let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);

        // Compile the operator allowlist, or fall back to an empty one.
        let allowlist = match config.and_then(|c| c.ssrf_allowlist.as_ref()) {
            Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
                crate::error::McpxError::Startup(format!("oauth http client: {e}"))
            })?),
            None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
        };

        // Separate handle that is moved into the redirect-policy closure.
        let redirect_allowlist = Arc::clone(&allowlist);

        // The bypass flag is a shared atomic in test builds and `()` otherwise.
        #[cfg(any(test, feature = "test-helpers"))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
            Arc::new(AtomicBool::new(false));
        #[cfg(not(any(test, feature = "test-helpers")))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();

        let resolver: Arc<dyn reqwest::dns::Resolve> =
            Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
                Arc::clone(&allowlist),
                #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
                test_bypass.clone(),
            ));

        let mut builder = reqwest::Client::builder()
            .no_proxy()
            .dns_resolver(Arc::clone(&resolver))
            .connect_timeout(Duration::from_secs(10))
            .timeout(Duration::from_secs(30))
            .redirect(reqwest::redirect::Policy::custom(move |attempt| {
                match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
                    Ok(()) => attempt.follow(),
                    Err(reason) => {
                        // Log the rejected hop with a sanitized URL before failing the request.
                        tracing::warn!(
                            reason = %reason,
                            target = %crate::ssrf::sanitized_url_for_log(attempt.url()),
                            "oauth redirect rejected"
                        );
                        attempt.error(reason)
                    }
                }
            }));

        // Optional extra root certificate, read from disk as PEM.
        if let Some(cfg) = config
            && let Some(ref ca_path) = cfg.ca_cert_path
        {
            let pem = std::fs::read(ca_path).map_err(|e| {
                crate::error::McpxError::Startup(format!(
                    "oauth http client: read ca_cert_path {}: {e}",
                    ca_path.display()
                ))
            })?;
            let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
                crate::error::McpxError::Startup(format!(
                    "oauth http client: parse ca_cert_path {}: {e}",
                    ca_path.display()
                ))
            })?;
            builder = builder.add_root_certificate(cert);
        }

        let inner = builder.build().map_err(|e| {
            crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
        })?;

        #[cfg(feature = "oauth-mtls-client")]
        let mtls_clients = build_mtls_clients(config, &allowlist, &test_bypass)?;

        Ok(Self {
            inner,
            allow_http,
            allowlist,
            #[cfg(feature = "oauth-mtls-client")]
            mtls_clients,
            #[cfg(any(test, feature = "test-helpers"))]
            test_allow_loopback_ssrf: test_bypass,
        })
    }

    /// Screens `url` and, if it passes, sends `request`.
    ///
    /// In test builds the screening honours the loopback bypass flag; in all
    /// other builds `screen_oauth_target` always runs in full.
    ///
    /// # Errors
    /// Returns the screening error, or `McpxError::Config` wrapping the
    /// transport error when sending fails.
    async fn send_screened(
        &self,
        url: &str,
        request: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, crate::error::McpxError> {
        #[cfg(any(test, feature = "test-helpers"))]
        if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
            screen_oauth_target_with_test_override(url, self.allow_http, &self.allowlist, true)
                .await?;
        } else {
            screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
        }
        #[cfg(not(any(test, feature = "test-helpers")))]
        screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
        request.send().await.map_err(|error| {
            crate::error::McpxError::Config(format!("oauth request {url}: {error}"))
        })
    }

    /// Test helper: sets the loopback bypass flag and returns the client.
    ///
    /// The flag is stored through a shared reference, so it is written in
    /// place rather than on a copy.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_allow_loopback_ssrf(self) -> Self {
        self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
        self
    }

    /// Test helper: issues a GET through the inner client directly, without
    /// the pre-flight screening performed by `send_screened`.
    #[doc(hidden)]
    pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
        self.inner.get(url).send().await
    }

    /// Test helper: borrows the inner `reqwest::Client`.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_inner_client(&self) -> &reqwest::Client {
        &self.inner
    }

    /// Picks the client for a token exchange: the mTLS client registered for
    /// `cfg.client_cert`'s path pair when one exists, otherwise the default client.
    #[cfg(feature = "oauth-mtls-client")]
    fn client_for(&self, cfg: &TokenExchangeConfig) -> &reqwest::Client {
        if let Some(cc) = &cfg.client_cert {
            let key = MtlsClientKey {
                cert_path: cc.cert_path.clone(),
                key_path: cc.key_path.clone(),
            };
            if let Some(client) = self.mtls_clients.get(&key) {
                return client;
            }
        }
        &self.inner
    }

    /// Without the `oauth-mtls-client` feature there is only the default client.
    #[cfg(not(feature = "oauth-mtls-client"))]
    fn client_for(&self, _cfg: &TokenExchangeConfig) -> &reqwest::Client {
        &self.inner
    }
}
516
/// Opaque `Debug` output: prints only the type name and no field values.
impl std::fmt::Debug for OauthHttpClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OauthHttpClient").finish_non_exhaustive()
    }
}
522
/// Raw (uncompiled) operator allowlist for OAuth SSRF screening, as read from
/// configuration. It is validated and compiled by `compile_oauth_ssrf_allowlist`.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct OAuthSsrfAllowlist {
    /// Bare DNS hostnames (no scheme, port, path, userinfo, query or fragment;
    /// literal IPs are rejected here). Defaults to empty.
    #[serde(default)]
    pub hosts: Vec<String>,
    /// CIDR entries, each parsed by `crate::ssrf::CidrEntry::parse`. Defaults to empty.
    #[serde(default)]
    pub cidrs: Vec<String>,
}
599
600fn compile_oauth_ssrf_allowlist(
607 raw: &OAuthSsrfAllowlist,
608) -> Result<crate::ssrf::CompiledSsrfAllowlist, String> {
609 let mut hosts: Vec<String> = Vec::with_capacity(raw.hosts.len());
610 for (idx, entry) in raw.hosts.iter().enumerate() {
611 let trimmed = entry.trim();
612 if trimmed.is_empty() {
613 return Err(format!("oauth.ssrf_allowlist.hosts[{idx}]: empty entry"));
614 }
615 if trimmed.contains([':', '/', '@', '?', '#']) {
619 return Err(format!(
620 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: must be a bare DNS hostname \
621 (no scheme, port, path, userinfo, query, or fragment)"
622 ));
623 }
624 match url::Host::parse(trimmed) {
625 Ok(url::Host::Domain(_)) => {}
626 Ok(url::Host::Ipv4(_) | url::Host::Ipv6(_)) => {
627 return Err(format!(
628 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: literal IPs are forbidden \
629 here -- list them via oauth.ssrf_allowlist.cidrs instead"
630 ));
631 }
632 Err(e) => {
633 return Err(format!(
634 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: invalid hostname: {e}"
635 ));
636 }
637 }
638 hosts.push(trimmed.to_ascii_lowercase());
639 }
640 hosts.sort();
641 hosts.dedup();
642
643 let mut cidrs = Vec::with_capacity(raw.cidrs.len());
644 for (idx, entry) in raw.cidrs.iter().enumerate() {
645 let parsed = crate::ssrf::CidrEntry::parse(entry)
646 .map_err(|e| format!("oauth.ssrf_allowlist.cidrs[{idx}]: {e}"))?;
647 cidrs.push(parsed);
648 }
649
650 Ok(crate::ssrf::CompiledSsrfAllowlist::new(hosts, cidrs))
651}
652
/// OAuth configuration, deserialised from configuration or assembled via
/// `OAuthConfig::builder`. Call `validate` to check URLs and related settings.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
    /// Issuer URL; checked by `validate` and used as the expected `iss` value
    /// when the JWT validation template is built.
    pub issuer: String,
    /// Expected token audience.
    pub audience: String,
    /// URL of the JWKS document; checked by `validate`.
    pub jwks_uri: String,
    /// Scope-to-role mappings. Defaults to empty.
    #[serde(default)]
    pub scopes: Vec<ScopeMapping>,
    /// Optional name of the claim that carries role values.
    pub role_claim: Option<String>,
    /// Claim-value-to-role mappings. Defaults to empty.
    #[serde(default)]
    pub role_mappings: Vec<RoleMapping>,
    /// JWKS cache lifetime as a `humantime` duration string (default `"10m"`).
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl: String,
    /// Optional OAuth proxy settings; its URLs are checked by `validate`.
    pub proxy: Option<OAuthProxyConfig>,
    /// Optional token-exchange settings; checked by `validate`.
    pub token_exchange: Option<TokenExchangeConfig>,
    /// Optional PEM file added as an extra root certificate to the OAuth HTTP clients.
    #[serde(default)]
    pub ca_cert_path: Option<PathBuf>,
    /// Permits plain `http` OAuth URLs when `true` (default `false`).
    #[serde(default)]
    pub allow_http_oauth_urls: bool,
    /// Optional operator allowlist that relaxes SSRF screening for listed hosts/CIDRs.
    #[serde(default)]
    pub ssrf_allowlist: Option<OAuthSsrfAllowlist>,
    /// Upper bound on JWKS keys (default 256).
    #[serde(default = "default_max_jwks_keys")]
    pub max_jwks_keys: usize,
    /// Legacy strictness flag; consulted only when `audience_validation_mode` is `None`.
    #[serde(default)]
    #[deprecated(
        since = "1.7.0",
        note = "use `audience_validation_mode` instead; this field is consulted only when `audience_validation_mode` is None"
    )]
    pub strict_audience_validation: bool,
    /// Explicit audience validation mode; overrides the legacy flag when set.
    #[serde(default)]
    pub audience_validation_mode: Option<AudienceValidationMode>,
    /// Upper bound, in bytes, on a JWKS response (default 1 MiB).
    #[serde(default = "default_jwks_max_bytes")]
    pub jwks_max_response_bytes: u64,
}
763
/// Serde default for `jwks_cache_ttl`: ten minutes, in `humantime` notation.
fn default_jwks_cache_ttl() -> String {
    String::from("10m")
}
767
/// Serde default for `max_jwks_keys`.
const fn default_max_jwks_keys() -> usize {
    const DEFAULT_MAX_KEYS: usize = 256;
    DEFAULT_MAX_KEYS
}
771
/// Serde default for `jwks_max_response_bytes`: one mebibyte.
const fn default_jwks_max_bytes() -> u64 {
    const KIB: u64 = 1024;
    KIB * KIB
}
775
/// How strictly the token audience is validated.
///
/// Deserialised from snake_case names (`permissive`, `warn`, `strict`).
/// NOTE(review): the exact enforcement behind each variant lives outside this
/// excerpt — confirm against the validation code before relying on it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AudienceValidationMode {
    /// Least strict mode.
    Permissive,
    /// Default mode; also the fallback when the legacy strict flag is unset.
    #[default]
    Warn,
    /// Strictest mode; selected by the legacy `strict_audience_validation` flag.
    Strict,
}
810
811impl Default for OAuthConfig {
812 fn default() -> Self {
813 Self {
814 issuer: String::new(),
815 audience: String::new(),
816 jwks_uri: String::new(),
817 scopes: Vec::new(),
818 role_claim: None,
819 role_mappings: Vec::new(),
820 jwks_cache_ttl: default_jwks_cache_ttl(),
821 proxy: None,
822 token_exchange: None,
823 ca_cert_path: None,
824 allow_http_oauth_urls: false,
825 max_jwks_keys: default_max_jwks_keys(),
826 #[allow(
827 deprecated,
828 reason = "default-construct deprecated field for backward compat"
829 )]
830 strict_audience_validation: false,
831 audience_validation_mode: None,
832 jwks_max_response_bytes: default_jwks_max_bytes(),
833 ssrf_allowlist: None,
834 }
835 }
836}
837
838impl OAuthConfig {
839 #[must_use]
845 pub fn effective_audience_validation_mode(&self) -> AudienceValidationMode {
846 if let Some(mode) = self.audience_validation_mode {
847 return mode;
848 }
849 #[allow(deprecated, reason = "intentional: legacy flag resolution path")]
850 if self.strict_audience_validation {
851 AudienceValidationMode::Strict
852 } else {
853 AudienceValidationMode::Warn
854 }
855 }
856
857 pub fn builder(
863 issuer: impl Into<String>,
864 audience: impl Into<String>,
865 jwks_uri: impl Into<String>,
866 ) -> OAuthConfigBuilder {
867 OAuthConfigBuilder {
868 inner: Self {
869 issuer: issuer.into(),
870 audience: audience.into(),
871 jwks_uri: jwks_uri.into(),
872 ..Self::default()
873 },
874 }
875 }
876
877 pub fn validate(&self) -> Result<(), crate::error::McpxError> {
893 let allow_http = self.allow_http_oauth_urls;
894 let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
895 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
896 return Err(crate::error::McpxError::Config(format!(
897 "oauth.issuer forbidden ({reason})"
898 )));
899 }
900 let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
901 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
902 return Err(crate::error::McpxError::Config(format!(
903 "oauth.jwks_uri forbidden ({reason})"
904 )));
905 }
906 if let Some(proxy) = &self.proxy {
907 let url = check_oauth_url(
908 "oauth.proxy.authorize_url",
909 &proxy.authorize_url,
910 allow_http,
911 )?;
912 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
913 return Err(crate::error::McpxError::Config(format!(
914 "oauth.proxy.authorize_url forbidden ({reason})"
915 )));
916 }
917 let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
918 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
919 return Err(crate::error::McpxError::Config(format!(
920 "oauth.proxy.token_url forbidden ({reason})"
921 )));
922 }
923 if let Some(url) = &proxy.introspection_url {
924 let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
925 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
926 return Err(crate::error::McpxError::Config(format!(
927 "oauth.proxy.introspection_url forbidden ({reason})"
928 )));
929 }
930 }
931 if let Some(url) = &proxy.revocation_url {
932 let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
933 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
934 return Err(crate::error::McpxError::Config(format!(
935 "oauth.proxy.revocation_url forbidden ({reason})"
936 )));
937 }
938 }
939 if proxy.expose_admin_endpoints
946 && !proxy.require_auth_on_admin_endpoints
947 && !proxy.allow_unauthenticated_admin_endpoints
948 {
949 return Err(crate::error::McpxError::Config(
950 "oauth.proxy: expose_admin_endpoints = true requires \
951 require_auth_on_admin_endpoints = true (recommended) \
952 or allow_unauthenticated_admin_endpoints = true \
953 (explicit opt-out, only safe behind an authenticated \
954 reverse proxy)"
955 .into(),
956 ));
957 }
958 }
959 if let Some(tx) = &self.token_exchange {
960 let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
961 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
962 return Err(crate::error::McpxError::Config(format!(
963 "oauth.token_exchange.token_url forbidden ({reason})"
964 )));
965 }
966 validate_token_exchange_client_auth(tx)?;
969 }
970 if let Some(raw) = &self.ssrf_allowlist {
974 let compiled = compile_oauth_ssrf_allowlist(raw).map_err(|e| {
975 crate::error::McpxError::Config(format!("oauth.ssrf_allowlist: {e}"))
976 })?;
977 if !compiled.is_empty() {
978 tracing::warn!(
979 host_count = compiled.host_count(),
980 cidr_count = compiled.cidr_count(),
981 "oauth.ssrf_allowlist is configured: private/loopback OAuth/JWKS targets \
982 are now reachable. Cloud-metadata addresses remain blocked. \
983 See SECURITY.md \"Operator allowlist\"."
984 );
985 }
986 }
987 humantime::parse_duration(&self.jwks_cache_ttl).map_err(|e| {
990 crate::error::McpxError::Config(format!(
991 "oauth.jwks_cache_ttl {:?} is not a valid humantime duration (e.g. \"10m\", \"1h30m\"): {e}",
992 self.jwks_cache_ttl
993 ))
994 })?;
995 Ok(())
996 }
997}
998
999fn validate_token_exchange_client_auth(
1005 tx: &TokenExchangeConfig,
1006) -> Result<(), crate::error::McpxError> {
1007 match (&tx.client_cert, tx.client_secret.is_some()) {
1008 (Some(_), true) => Err(crate::error::McpxError::Config(
1009 "oauth.token_exchange: client_cert and client_secret are mutually \
1010 exclusive (RFC 8705 ยง2). Set exactly one."
1011 .into(),
1012 )),
1013 (None, false) => Err(crate::error::McpxError::Config(
1014 "oauth.token_exchange: token exchange requires client authentication. \
1015 Set either client_secret (RFC 6749 ยง2.3.1) or client_cert (RFC 8705 ยง2)."
1016 .into(),
1017 )),
1018 (Some(cc), false) => validate_client_cert_config(cc),
1019 (None, true) => Ok(()),
1020 }
1021}
1022
1023fn validate_client_cert_config(cc: &ClientCertConfig) -> Result<(), crate::error::McpxError> {
1036 #[cfg(not(feature = "oauth-mtls-client"))]
1037 {
1038 let _ = cc;
1039 Err(crate::error::McpxError::Config(
1040 "oauth.token_exchange.client_cert requires the `oauth-mtls-client` cargo feature; \
1041 rebuild rmcp-server-kit with --features oauth-mtls-client (or have your \
1042 application crate enable it via `rmcp-server-kit/oauth-mtls-client`), or remove \
1043 the field"
1044 .into(),
1045 ))
1046 }
1047 #[cfg(feature = "oauth-mtls-client")]
1048 {
1049 let cert_bytes = std::fs::read(&cc.cert_path).map_err(|e| {
1050 tracing::warn!(error = %e, path = %cc.cert_path.display(), "client cert read failed");
1051 crate::error::McpxError::Config(format!(
1052 "oauth.token_exchange.client_cert.cert_path unreadable: {}",
1053 cc.cert_path.display()
1054 ))
1055 })?;
1056 let key_bytes = std::fs::read(&cc.key_path).map_err(|e| {
1057 tracing::warn!(error = %e, path = %cc.key_path.display(), "client cert key read failed");
1058 crate::error::McpxError::Config(format!(
1059 "oauth.token_exchange.client_cert.key_path unreadable: {}",
1060 cc.key_path.display()
1061 ))
1062 })?;
1063 let mut combined = Vec::with_capacity(cert_bytes.len() + 1 + key_bytes.len());
1064 combined.extend_from_slice(&cert_bytes);
1065 if !cert_bytes.ends_with(b"\n") {
1066 combined.push(b'\n');
1067 }
1068 combined.extend_from_slice(&key_bytes);
1069 let _identity = reqwest::Identity::from_pem(&combined).map_err(|e| {
1070 tracing::warn!(
1071 error = %e,
1072 cert_path = %cc.cert_path.display(),
1073 key_path = %cc.key_path.display(),
1074 "client cert PEM parse failed"
1075 );
1076 crate::error::McpxError::Config(format!(
1077 "oauth.token_exchange.client_cert: PEM parse failed (cert={}, key={})",
1078 cc.cert_path.display(),
1079 cc.key_path.display()
1080 ))
1081 })?;
1082 Ok(())
1083 }
1084}
1085
/// Builds the map of mTLS-capable HTTP clients for token exchange.
///
/// The map is empty unless `config` is present and carries a
/// `token_exchange.client_cert`. In that case a single client is built with
/// the certificate and key loaded as a TLS identity, `no_proxy()`, the
/// SSRF-screening DNS resolver, a 10 s connect timeout, a 30 s overall
/// timeout, redirects disabled, and the optional `ca_cert_path` root
/// certificate. It is stored under the cert/key path pair.
///
/// # Errors
/// Returns `McpxError::Startup` when a file cannot be read, PEM data does not
/// parse, or the client fails to build.
#[cfg(feature = "oauth-mtls-client")]
fn build_mtls_clients(
    config: Option<&OAuthConfig>,
    allowlist: &Arc<crate::ssrf::CompiledSsrfAllowlist>,
    test_bypass: &crate::ssrf_resolver::TestLoopbackBypass,
) -> Result<Arc<HashMap<MtlsClientKey, reqwest::Client>>, crate::error::McpxError> {
    let startup_err = crate::error::McpxError::Startup;
    let mut clients: HashMap<MtlsClientKey, reqwest::Client> = HashMap::new();

    // Nothing to build unless a client certificate is configured.
    let configured = config.and_then(|cfg| {
        cfg.token_exchange
            .as_ref()
            .and_then(|tx| tx.client_cert.as_ref())
            .map(|cc| (cfg, cc))
    });
    let Some((cfg, cc)) = configured else {
        return Ok(Arc::new(clients));
    };

    let cert_pem = match std::fs::read(&cc.cert_path) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(startup_err(format!(
                "oauth http client mTLS: read cert_path {}: {e}",
                cc.cert_path.display()
            )));
        }
    };
    let key_pem = match std::fs::read(&cc.key_path) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(startup_err(format!(
                "oauth http client mTLS: read key_path {}: {e}",
                cc.key_path.display()
            )));
        }
    };

    // Certificate first, then the key, separated by a newline.
    let mut bundle = cert_pem;
    if !bundle.ends_with(b"\n") {
        bundle.push(b'\n');
    }
    bundle.extend_from_slice(&key_pem);

    let identity = match reqwest::Identity::from_pem(&bundle) {
        Ok(identity) => identity,
        Err(e) => {
            return Err(startup_err(format!(
                "oauth http client mTLS: PEM parse (cert={}, key={}): {e}",
                cc.cert_path.display(),
                cc.key_path.display()
            )));
        }
    };

    let resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(allowlist),
            #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
            test_bypass.clone(),
        ));

    let mut builder = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&resolver))
        .connect_timeout(Duration::from_secs(10))
        .timeout(Duration::from_secs(30))
        .redirect(reqwest::redirect::Policy::none())
        .identity(identity);

    if let Some(ca_path) = cfg.ca_cert_path.as_ref() {
        let ca_pem = match std::fs::read(ca_path) {
            Ok(bytes) => bytes,
            Err(e) => {
                return Err(startup_err(format!(
                    "oauth http client mTLS: read ca_cert_path {}: {e}",
                    ca_path.display()
                )));
            }
        };
        let root = match reqwest::tls::Certificate::from_pem(&ca_pem) {
            Ok(cert) => cert,
            Err(e) => {
                return Err(startup_err(format!(
                    "oauth http client mTLS: parse ca_cert_path {}: {e}",
                    ca_path.display()
                )));
            }
        };
        builder = builder.add_root_certificate(root);
    }

    let client = match builder.build() {
        Ok(client) => client,
        Err(e) => {
            return Err(startup_err(format!(
                "oauth http client mTLS init: {e}"
            )));
        }
    };

    let key = MtlsClientKey {
        cert_path: cc.cert_path.clone(),
        key_path: cc.key_path.clone(),
    };
    clients.insert(key, client);
    Ok(Arc::new(clients))
}
1184
1185fn check_oauth_url(
1192 field: &str,
1193 raw: &str,
1194 allow_http: bool,
1195) -> Result<url::Url, crate::error::McpxError> {
1196 let parsed = url::Url::parse(raw).map_err(|e| {
1197 crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
1198 })?;
1199 if !parsed.username().is_empty() || parsed.password().is_some() {
1200 return Err(crate::error::McpxError::Config(format!(
1201 "{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
1202 )));
1203 }
1204 match parsed.scheme() {
1205 "https" => Ok(parsed),
1206 "http" if allow_http => Ok(parsed),
1207 "http" => Err(crate::error::McpxError::Config(format!(
1208 "{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
1209 to override - strongly discouraged in production)"
1210 ))),
1211 other => Err(crate::error::McpxError::Config(format!(
1212 "{field}: must use https scheme (got {other:?})"
1213 ))),
1214 }
1215}
1216
/// Fluent builder for `OAuthConfig`, obtained from `OAuthConfig::builder`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthConfigBuilder {
    /// Configuration being assembled; returned as-is by `build`.
    inner: OAuthConfig,
}
1227
impl OAuthConfigBuilder {
    /// Replaces the whole scope-to-role mapping list.
    pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
        self.inner.scopes = scopes;
        self
    }

    /// Appends one scope-to-role mapping.
    pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.scopes.push(ScopeMapping {
            scope: scope.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the name of the claim that carries role values.
    pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
        self.inner.role_claim = Some(claim.into());
        self
    }

    /// Replaces the whole claim-value-to-role mapping list.
    pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
        self.inner.role_mappings = mappings;
        self
    }

    /// Appends one claim-value-to-role mapping.
    pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.role_mappings.push(RoleMapping {
            claim_value: claim_value.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the JWKS cache lifetime as a `humantime` duration string
    /// (checked later by `OAuthConfig::validate`).
    pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
        self.inner.jwks_cache_ttl = ttl.into();
        self
    }

    /// Sets the OAuth proxy configuration.
    pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
        self.inner.proxy = Some(proxy);
        self
    }

    /// Sets the token-exchange configuration.
    pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
        self.inner.token_exchange = Some(token_exchange);
        self
    }

    /// Sets the path of a PEM file to add as an extra root certificate.
    pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.inner.ca_cert_path = Some(path.into());
        self
    }

    /// Allows (`true`) or forbids (`false`) plain `http` OAuth URLs.
    pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
        self.inner.allow_http_oauth_urls = allow;
        self
    }

    /// Sets the legacy strictness flag and clears any explicit
    /// `audience_validation_mode`, so the flag is the value that takes effect.
    #[deprecated(since = "1.7.0", note = "use `audience_validation_mode` instead")]
    pub const fn strict_audience_validation(mut self, strict: bool) -> Self {
        #[allow(
            deprecated,
            reason = "intentional: deprecated builder forwards to deprecated field"
        )]
        {
            self.inner.strict_audience_validation = strict;
        }
        self.inner.audience_validation_mode = None;
        self
    }

    /// Sets an explicit audience validation mode, which overrides the legacy flag.
    pub const fn audience_validation_mode(mut self, mode: AudienceValidationMode) -> Self {
        self.inner.audience_validation_mode = Some(mode);
        self
    }

    /// Sets the upper bound, in bytes, on a JWKS response.
    pub const fn jwks_max_response_bytes(mut self, bytes: u64) -> Self {
        self.inner.jwks_max_response_bytes = bytes;
        self
    }

    /// Sets the operator SSRF allowlist.
    pub fn ssrf_allowlist(mut self, allowlist: OAuthSsrfAllowlist) -> Self {
        self.inner.ssrf_allowlist = Some(allowlist);
        self
    }

    /// Finishes the builder and returns the assembled configuration.
    /// No validation is performed here; call `OAuthConfig::validate` separately.
    #[must_use]
    pub fn build(self) -> OAuthConfig {
        self.inner
    }
}
1363
/// Associates an OAuth scope with a role name.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ScopeMapping {
    /// Scope string to match.
    pub scope: String,
    /// Role associated with that scope.
    pub role: String,
}
1373
/// Associates a value of the role claim with a role name.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleMapping {
    /// Claim value to match.
    pub claim_value: String,
    /// Role associated with that claim value.
    pub role: String,
}
1385
/// Settings for token exchange against an authorization server.
///
/// Validation requires exactly one of `client_secret` and `client_cert`.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenExchangeConfig {
    /// Token endpoint URL; checked by `OAuthConfig::validate`.
    pub token_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Client secret; mutually exclusive with `client_cert`.
    pub client_secret: Option<secrecy::SecretString>,
    /// Client certificate for mTLS; mutually exclusive with `client_secret`
    /// and only usable with the `oauth-mtls-client` feature.
    pub client_cert: Option<ClientCertConfig>,
    /// Audience associated with the exchange.
    pub audience: String,
}
1423
impl TokenExchangeConfig {
    /// Creates a token-exchange configuration from its parts.
    ///
    /// No validation happens here; the "exactly one of `client_secret` /
    /// `client_cert`" rule is enforced by `OAuthConfig::validate`.
    #[must_use]
    pub fn new(
        token_url: String,
        client_id: String,
        client_secret: Option<secrecy::SecretString>,
        client_cert: Option<ClientCertConfig>,
        audience: String,
    ) -> Self {
        Self {
            token_url,
            client_id,
            client_secret,
            client_cert,
            audience,
        }
    }
}
1443
/// Locations of the PEM files that make up a TLS client identity.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ClientCertConfig {
    /// Path to the PEM client certificate.
    pub cert_path: PathBuf,
    /// Path to the PEM private key.
    pub key_path: PathBuf,
}
1458
impl ClientCertConfig {
    /// Creates a client certificate configuration from the two file paths.
    /// The files are not read or checked here.
    #[must_use]
    pub fn new(cert_path: PathBuf, key_path: PathBuf) -> Self {
        Self {
            cert_path,
            key_path,
        }
    }
}
1471
/// Deserialised token-exchange response.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ExchangedToken {
    /// The issued access token.
    pub access_token: String,
    /// Token lifetime, if reported (presumably seconds — confirm against the
    /// code that consumes it).
    pub expires_in: Option<u64>,
    /// Type identifier of the issued token, if reported.
    pub issued_token_type: Option<String>,
}
1484
/// Settings for proxying OAuth flows to an upstream authorization server.
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct OAuthProxyConfig {
    /// Upstream authorization endpoint URL; checked by `OAuthConfig::validate`.
    pub authorize_url: String,
    /// Upstream token endpoint URL; checked by `OAuthConfig::validate`.
    pub token_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Optional client secret.
    pub client_secret: Option<secrecy::SecretString>,
    /// Optional introspection endpoint URL; checked by `validate` when present.
    #[serde(default)]
    pub introspection_url: Option<String>,
    /// Optional revocation endpoint URL; checked by `validate` when present.
    #[serde(default)]
    pub revocation_url: Option<String>,
    /// Whether admin endpoints are exposed. When `true`, `validate` requires
    /// one of the two flags below to be set as well.
    #[serde(default)]
    pub expose_admin_endpoints: bool,
    /// Require authentication on the exposed admin endpoints (recommended).
    #[serde(default)]
    pub require_auth_on_admin_endpoints: bool,
    /// Explicit opt-out allowing unauthenticated admin endpoints.
    #[serde(default)]
    pub allow_unauthenticated_admin_endpoints: bool,
}
1547
impl OAuthProxyConfig {
    /// Starts a builder from the three required values; all remaining fields
    /// begin at their `Default` values (`None` / `false`).
    pub fn builder(
        authorize_url: impl Into<String>,
        token_url: impl Into<String>,
        client_id: impl Into<String>,
    ) -> OAuthProxyConfigBuilder {
        OAuthProxyConfigBuilder {
            inner: Self {
                authorize_url: authorize_url.into(),
                token_url: token_url.into(),
                client_id: client_id.into(),
                ..Self::default()
            },
        }
    }
}
1571
/// Fluent builder for `OAuthProxyConfig`, obtained from `OAuthProxyConfig::builder`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthProxyConfigBuilder {
    /// Configuration being assembled; returned as-is by `build`.
    inner: OAuthProxyConfig,
}
1582
impl OAuthProxyConfigBuilder {
    /// Sets the client secret.
    pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
        self.inner.client_secret = Some(secret);
        self
    }

    /// Sets the introspection endpoint URL.
    pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
        self.inner.introspection_url = Some(url.into());
        self
    }

    /// Sets the revocation endpoint URL.
    pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
        self.inner.revocation_url = Some(url.into());
        self
    }

    /// Chooses whether admin endpoints are exposed. When enabled,
    /// `OAuthConfig::validate` requires an explicit auth decision via one of
    /// the two setters below.
    pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
        self.inner.expose_admin_endpoints = expose;
        self
    }

    /// Chooses whether exposed admin endpoints require authentication.
    pub const fn require_auth_on_admin_endpoints(mut self, require: bool) -> Self {
        self.inner.require_auth_on_admin_endpoints = require;
        self
    }

    /// Explicitly opts in to unauthenticated admin endpoints.
    pub const fn allow_unauthenticated_admin_endpoints(mut self, allow: bool) -> Self {
        self.inner.allow_unauthenticated_admin_endpoints = allow;
        self
    }

    /// Finishes the builder and returns the assembled configuration.
    #[must_use]
    pub fn build(self) -> OAuthProxyConfig {
        self.inner
    }
}
1639
/// Pair of decoded key collections: a map keyed by `String` (presumably the
/// JWK `kid` — confirm against the code that builds it) and a list of keys
/// without such an identifier. Each key is stored with its algorithm.
type JwksKeyCache = (
    HashMap<String, (Algorithm, DecodingKey)>,
    Vec<(Algorithm, DecodingKey)>,
);
1651
/// A snapshot of decoding keys together with the information needed to decide
/// when it has gone stale.
struct CachedKeys {
    /// Keys addressable by a string identifier.
    keys: HashMap<String, (Algorithm, DecodingKey)>,
    /// Keys that have no identifier.
    unnamed_keys: Vec<(Algorithm, DecodingKey)>,
    /// When this snapshot was fetched.
    fetched_at: Instant,
    /// How long the snapshot stays valid after `fetched_at`.
    ttl: Duration,
}
1660
1661impl CachedKeys {
1662 fn is_expired(&self) -> bool {
1663 self.fetched_at.elapsed() >= self.ttl
1664 }
1665}
1666
/// JWKS-backed JWT validation state: the configured endpoint, cache, HTTP
/// client, and validation/role-mapping settings.
///
/// NOTE(review): field descriptions below are inferred from names and types;
/// the methods that use them are outside this excerpt — confirm before
/// relying on them.
#[allow(
    missing_debug_implementations,
    reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
)]
#[non_exhaustive]
pub struct JwksCache {
    /// URL of the JWKS document.
    jwks_uri: String,
    /// Lifetime of a cached key set.
    ttl: Duration,
    /// Upper bound on the number of keys.
    max_jwks_keys: usize,
    /// Upper bound on the JWKS response size, in bytes.
    max_response_bytes: u64,
    /// Whether plain `http` is allowed.
    allow_http: bool,
    /// The cached key set, if any, behind an async read/write lock.
    inner: RwLock<Option<CachedKeys>>,
    /// HTTP client.
    http: reqwest::Client,
    /// Base `Validation` settings.
    validation_template: Validation,
    /// Audience value tokens are expected to carry.
    expected_audience: String,
    /// Audience validation strictness.
    audience_mode: AudienceValidationMode,
    /// Flag related to an `azp` fallback warning (presumably "warned once").
    azp_fallback_warned: AtomicBool,
    /// Scope-to-role mappings.
    scopes: Vec<ScopeMapping>,
    /// Optional name of the role claim.
    role_claim: Option<String>,
    /// Claim-value-to-role mappings.
    role_mappings: Vec<RoleMapping>,
    /// Time of the most recent refresh attempt, if any.
    last_refresh_attempt: RwLock<Option<Instant>>,
    /// Async mutex (presumably serialising refreshes).
    refresh_lock: tokio::sync::Mutex<()>,
    /// Compiled SSRF allowlist.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Test-only loopback bypass flag.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
1715
/// Minimum spacing between two JWKS refresh attempts; enforced by
/// `JwksCache::refresh_with_cooldown`.
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_secs(10);
1718
/// Signature algorithms a token header may declare. `select_jwks_key` rejects
/// any token whose `alg` is not in this list; every entry is an asymmetric
/// (RSA, ECDSA or EdDSA) algorithm.
const ACCEPTED_ALGS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
    Algorithm::EdDSA,
];
1731
/// Coarse reason a token was rejected by `JwksCache::validate_token_with_reason`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
    /// The decoder reported the token as expired (`ExpiredSignature`).
    Expired,
    /// Any other rejection: bad header, unaccepted algorithm, no matching key,
    /// failed decode, audience mismatch, or no role mapping.
    Invalid,
}
1741
1742impl JwksCache {
    /// Builds a validator from `config`, without fetching any keys yet.
    ///
    /// The key cache starts empty and is filled lazily by the first lookup.
    ///
    /// # Errors
    ///
    /// Returns an error when `jwks_cache_ttl` cannot be parsed, the SSRF
    /// allowlist fails to compile, the CA certificate cannot be read or
    /// parsed, or the HTTP client cannot be built.
    pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        // Install process-wide crypto providers. The results are discarded on
        // purpose — presumably because a provider may already be installed.
        rustls::crypto::ring::default_provider()
            .install_default()
            .ok();
        jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
            .install_default()
            .ok();

        let ttl = humantime::parse_duration(&config.jwks_cache_ttl).map_err(|error| {
            format!(
                "invalid jwks_cache_ttl {:?}: {error}",
                config.jwks_cache_ttl
            )
        })?;

        // RS256 is only a placeholder: `decode_claims` overwrites the
        // algorithm list per token. Audience validation is switched off here
        // because `check_audience` performs it after decoding.
        let mut validation = Validation::new(Algorithm::RS256);
        validation.validate_aud = false;
        validation.set_issuer(&[&config.issuer]);
        validation.set_required_spec_claims(&["exp", "iss"]);
        validation.validate_exp = true;
        validation.validate_nbf = true;

        let allow_http = config.allow_http_oauth_urls;

        // Compile the configured allowlist, or fall back to the default one.
        let allowlist = match config.ssrf_allowlist.as_ref() {
            Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
                Box::<dyn std::error::Error + Send + Sync>::from(format!(
                    "oauth.ssrf_allowlist: {e}"
                ))
            })?),
            None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
        };
        // Separate handle that is moved into the redirect-policy closure.
        let redirect_allowlist = Arc::clone(&allowlist);

        // The bypass handle is a shared flag in test builds and `()` otherwise.
        #[cfg(any(test, feature = "test-helpers"))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
            Arc::new(AtomicBool::new(false));
        #[cfg(not(any(test, feature = "test-helpers")))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();

        // Custom resolver that receives the allowlist and the bypass handle.
        let resolver: Arc<dyn reqwest::dns::Resolve> =
            Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
                Arc::clone(&allowlist),
                #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
                test_bypass.clone(),
            ));

        // No proxy, screened DNS, 10 s total / 3 s connect timeouts, and a
        // redirect policy that defers every hop to `evaluate_oauth_redirect`.
        let mut http_builder = reqwest::Client::builder()
            .no_proxy()
            .dns_resolver(Arc::clone(&resolver))
            .timeout(Duration::from_secs(10))
            .connect_timeout(Duration::from_secs(3))
            .redirect(reqwest::redirect::Policy::custom(move |attempt| {
                match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
                    Ok(()) => attempt.follow(),
                    Err(reason) => {
                        tracing::warn!(
                            reason = %reason,
                            target = %crate::ssrf::sanitized_url_for_log(attempt.url()),
                            "oauth redirect rejected"
                        );
                        attempt.error(reason)
                    }
                }
            }));

        // Optional extra root certificate, read from disk as PEM.
        if let Some(ref ca_path) = config.ca_cert_path {
            let pem = std::fs::read(ca_path)?;
            let cert = reqwest::tls::Certificate::from_pem(&pem)?;
            http_builder = http_builder.add_root_certificate(cert);
        }

        let http = http_builder.build()?;

        Ok(Self {
            jwks_uri: config.jwks_uri.clone(),
            ttl,
            max_jwks_keys: config.max_jwks_keys,
            max_response_bytes: config.jwks_max_response_bytes,
            allow_http,
            inner: RwLock::new(None),
            http,
            validation_template: validation,
            expected_audience: config.audience.clone(),
            audience_mode: config.effective_audience_validation_mode(),
            azp_fallback_warned: AtomicBool::new(false),
            scopes: config.scopes.clone(),
            role_claim: config.role_claim.clone(),
            role_mappings: config.role_mappings.clone(),
            last_refresh_attempt: RwLock::new(None),
            refresh_lock: tokio::sync::Mutex::new(()),
            allowlist,
            #[cfg(any(test, feature = "test-helpers"))]
            test_allow_loopback_ssrf: test_bypass,
        })
    }
1884
    /// Test-only: turns on the loopback bypass flag and returns `self`.
    ///
    /// The flag is the same handle that `new` passed to the DNS resolver, and
    /// `fetch_jwks` reads it to choose the screening variant.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_allow_loopback_ssrf(self) -> Self {
        self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
        self
    }
1897
1898 pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
1900 self.validate_token_with_reason(token).await.ok()
1901 }
1902
1903 pub async fn validate_token_with_reason(
1910 &self,
1911 token: &str,
1912 ) -> Result<AuthIdentity, JwtValidationFailure> {
1913 let claims = self.decode_claims(token).await?;
1914
1915 self.check_audience(&claims)?;
1916 let role = self.resolve_role(&claims)?;
1917
1918 let sub = claims.sub;
1921 let name = claims
1922 .extra
1923 .get("preferred_username")
1924 .and_then(|v| v.as_str())
1925 .map(String::from)
1926 .or_else(|| sub.clone())
1927 .or(claims.azp)
1928 .or(claims.client_id)
1929 .unwrap_or_else(|| "oauth-client".into());
1930
1931 Ok(AuthIdentity {
1932 name,
1933 role,
1934 method: AuthMethod::OAuthJwt,
1935 raw_token: None,
1936 sub,
1937 })
1938 }
1939
1940 async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
1952 let (key, alg) = self.select_jwks_key(token).await?;
1953
1954 let mut validation = self.validation_template.clone();
1958 validation.algorithms = vec![alg];
1959
1960 let token_owned = token.to_owned();
1963 let join =
1964 tokio::task::spawn_blocking(move || decode::<Claims>(&token_owned, &key, &validation))
1965 .await;
1966
1967 let decode_result = match join {
1968 Ok(r) => r,
1969 Err(join_err) => {
1970 core::hint::cold_path();
1971 tracing::error!(
1972 error = %join_err,
1973 "JWT decode task panicked or was cancelled"
1974 );
1975 return Err(JwtValidationFailure::Invalid);
1976 }
1977 };
1978
1979 decode_result.map(|td| td.claims).map_err(|e| {
1980 core::hint::cold_path();
1981 let failure = if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
1982 JwtValidationFailure::Expired
1983 } else {
1984 JwtValidationFailure::Invalid
1985 };
1986 tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
1987 failure
1988 })
1989 }
1990
    /// Reads the token header and returns the matching decoding key together
    /// with the algorithm the header declared.
    ///
    /// # Errors
    ///
    /// Returns `Invalid` when the header cannot be decoded, the algorithm is
    /// not in `ACCEPTED_ALGS`, or `find_key` yields no key.
    #[allow(
        clippy::cognitive_complexity,
        reason = "each failure arm pairs `cold_path()` with a distinct `tracing::debug!` site for observability; collapsing into combinators would lose structured-field log sites without reducing real complexity"
    )]
    async fn select_jwks_key(
        &self,
        token: &str,
    ) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
        let Ok(header) = decode_header(token) else {
            core::hint::cold_path();
            tracing::debug!("JWT header decode failed");
            return Err(JwtValidationFailure::Invalid);
        };
        let kid = header.kid.as_deref();
        tracing::debug!(alg = ?header.alg, kid = kid.unwrap_or("-"), "JWT header decoded");

        // Refuse unaccepted algorithms before any key lookup (and therefore
        // before any JWKS refresh) can be triggered.
        if !ACCEPTED_ALGS.contains(&header.alg) {
            core::hint::cold_path();
            tracing::debug!(alg = ?header.alg, "JWT algorithm not accepted");
            return Err(JwtValidationFailure::Invalid);
        }

        let Some(key) = self.find_key(kid, header.alg).await else {
            core::hint::cold_path();
            tracing::debug!(kid = kid.unwrap_or("-"), alg = ?header.alg, "no matching JWKS key found");
            return Err(JwtValidationFailure::Invalid);
        };

        Ok((key, header.alg))
    }
2029
    /// Checks that the token is intended for `expected_audience`.
    ///
    /// A match in `aud` is always accepted. A match only in `azp` is handled
    /// according to `audience_mode`: `Permissive` accepts silently, `Warn`
    /// accepts and logs a warning the first time, `Strict` rejects.
    ///
    /// # Errors
    ///
    /// Returns `Invalid` when neither rule accepts the token.
    fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
        if claims.aud.contains(&self.expected_audience) {
            return Ok(());
        }
        let azp_match = claims
            .azp
            .as_deref()
            .is_some_and(|azp| azp == self.expected_audience);
        if azp_match {
            match self.audience_mode {
                AudienceValidationMode::Permissive => return Ok(()),
                AudienceValidationMode::Warn => {
                    // `swap` returns the previous value, so only the first
                    // caller to flip the flag emits the warning.
                    if !self.azp_fallback_warned.swap(true, Ordering::Relaxed) {
                        tracing::warn!(
                            expected = %self.expected_audience,
                            azp = ?claims.azp,
                            "JWT accepted via deprecated azp-only audience fallback. \
                             Configure your IdP to populate aud, or set \
                             audience_validation_mode = \"strict\" once tokens carry aud correctly. \
                             To silence this warning without changing acceptance, \
                             set audience_validation_mode = \"permissive\". \
                             This warning logs once per process."
                        );
                    }
                    return Ok(());
                }
                // Strict mode falls through to the rejection below.
                AudienceValidationMode::Strict => {}
            }
        }
        core::hint::cold_path();
        tracing::debug!(
            aud = ?claims.aud.0,
            azp = ?claims.azp,
            expected = %self.expected_audience,
            mode = ?self.audience_mode,
            "JWT rejected: audience mismatch"
        );
        Err(JwtValidationFailure::Invalid)
    }
2077
2078 fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
2084 if let Some(ref claim_path) = self.role_claim {
2085 let owned_first_class: Vec<String> = first_class_claim_values(claims, claim_path);
2086 let mut values: Vec<&str> = owned_first_class.iter().map(String::as_str).collect();
2087 values.extend(resolve_claim_path(&claims.extra, claim_path));
2088 return self
2089 .role_mappings
2090 .iter()
2091 .find(|m| values.contains(&m.claim_value.as_str()))
2092 .map(|m| m.role.clone())
2093 .ok_or(JwtValidationFailure::Invalid);
2094 }
2095
2096 let token_scopes: Vec<&str> = claims
2097 .scope
2098 .as_deref()
2099 .unwrap_or("")
2100 .split_whitespace()
2101 .collect();
2102
2103 self.scopes
2104 .iter()
2105 .find(|m| token_scopes.contains(&m.scope.as_str()))
2106 .map(|m| m.role.clone())
2107 .ok_or(JwtValidationFailure::Invalid)
2108 }
2109
    /// Looks up a decoding key for `kid`/`alg`, refreshing the JWKS if needed.
    ///
    /// A fresh cache hit is returned immediately. Otherwise (empty cache,
    /// expired cache, or a miss) a cooldown-limited refresh is attempted and
    /// the lookup is repeated. The second lookup does not re-check expiry, so
    /// it can return a key from an expired snapshot when the refresh was
    /// skipped or failed.
    async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
        // Scoped so the read guard is released before the refresh path needs
        // the write lock.
        {
            let guard = self.inner.read().await;
            if let Some(cached) = guard.as_ref()
                && !cached.is_expired()
                && let Some(key) = lookup_key(cached, kid, alg)
            {
                return Some(key);
            }
        }

        self.refresh_with_cooldown().await;

        let guard = self.inner.read().await;
        guard
            .as_ref()
            .and_then(|cached| lookup_key(cached, kid, alg))
    }
2132
    /// Refreshes the JWKS unless an attempt was made within
    /// `JWKS_REFRESH_COOLDOWN`.
    ///
    /// Callers are serialized on `refresh_lock`, so a caller that waited on
    /// the lock sees the timestamp left by the previous holder and skips.
    async fn refresh_with_cooldown(&self) {
        let _guard = self.refresh_lock.lock().await;

        {
            let last = self.last_refresh_attempt.read().await;
            if let Some(ts) = *last
                && ts.elapsed() < JWKS_REFRESH_COOLDOWN
            {
                tracing::debug!(
                    elapsed_ms = ts.elapsed().as_millis(),
                    cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
                    "JWKS refresh skipped (cooldown active)"
                );
                return;
            }
        }

        // Stamp the attempt before fetching, so failed fetches are also
        // subject to the cooldown.
        {
            let mut last = self.last_refresh_attempt.write().await;
            *last = Some(Instant::now());
        }

        // The error is discarded here; `refresh_inner` already logs it.
        let _ = self.refresh_inner().await;
    }
2166
2167 async fn refresh_inner(&self) -> Result<(), String> {
2172 let Some(jwks) = self.fetch_jwks().await else {
2173 return Ok(());
2174 };
2175 let (keys, unnamed_keys) = match build_key_cache(&jwks, self.max_jwks_keys) {
2176 Ok(cache) => cache,
2177 Err(msg) => {
2178 tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
2179 return Err(msg);
2180 }
2181 };
2182
2183 tracing::debug!(
2184 named = keys.len(),
2185 unnamed = unnamed_keys.len(),
2186 "JWKS refreshed"
2187 );
2188
2189 let mut guard = self.inner.write().await;
2190 *guard = Some(CachedKeys {
2191 keys,
2192 unnamed_keys,
2193 fetched_at: Instant::now(),
2194 ttl: self.ttl,
2195 });
2196 drop(guard);
2197 Ok(())
2198 }
2199
2200 #[allow(
2202 clippy::cognitive_complexity,
2203 reason = "screening, bounded streaming, and parse logging are intentionally kept in one fetch path"
2204 )]
2205 async fn fetch_jwks(&self) -> Option<JwkSet> {
2206 #[cfg(any(test, feature = "test-helpers"))]
2207 let screening = if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
2208 screen_oauth_target_with_test_override(
2209 &self.jwks_uri,
2210 self.allow_http,
2211 &self.allowlist,
2212 true,
2213 )
2214 .await
2215 } else {
2216 screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await
2217 };
2218 #[cfg(not(any(test, feature = "test-helpers")))]
2219 let screening = screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await;
2220
2221 if let Err(error) = screening {
2222 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to screen JWKS target");
2223 return None;
2224 }
2225
2226 let mut resp = match self.http.get(&self.jwks_uri).send().await {
2227 Ok(resp) => resp,
2228 Err(e) => {
2229 tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
2230 return None;
2231 }
2232 };
2233
2234 let initial_capacity =
2235 usize::try_from(self.max_response_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
2236 let mut body = Vec::with_capacity(initial_capacity);
2237 while let Some(chunk) = match resp.chunk().await {
2238 Ok(chunk) => chunk,
2239 Err(error) => {
2240 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to read JWKS response");
2241 return None;
2242 }
2243 } {
2244 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
2245 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
2246 if body_len.saturating_add(chunk_len) > self.max_response_bytes {
2247 tracing::warn!(
2248 uri = %self.jwks_uri,
2249 max_bytes = self.max_response_bytes,
2250 "JWKS response exceeded configured size cap"
2251 );
2252 return None;
2253 }
2254 body.extend_from_slice(&chunk);
2255 }
2256
2257 match serde_json::from_slice::<JwkSet>(&body) {
2258 Ok(jwks) => Some(jwks),
2259 Err(error) => {
2260 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to parse JWKS");
2261 None
2262 }
2263 }
2264 }
2265
    /// Test-only: fetches the JWKS immediately and replaces the cache,
    /// without taking the refresh lock or honouring the cooldown.
    ///
    /// # Errors
    ///
    /// Returns a message when the fetch/parse fails or the key cap is
    /// exceeded. Unlike `refresh_inner`, a failed fetch is an error here.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    pub async fn __test_refresh_now(&self) -> Result<(), String> {
        let jwks = self
            .fetch_jwks()
            .await
            .ok_or_else(|| "failed to fetch or parse JWKS".to_owned())?;
        let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;
        let mut guard = self.inner.write().await;
        *guard = Some(CachedKeys {
            keys,
            unnamed_keys,
            fetched_at: Instant::now(),
            ttl: self.ttl,
        });
        drop(guard);
        Ok(())
    }
2286
2287 #[cfg(any(test, feature = "test-helpers"))]
2290 #[doc(hidden)]
2291 pub async fn __test_has_kid(&self, kid: &str) -> bool {
2292 let guard = self.inner.read().await;
2293 guard
2294 .as_ref()
2295 .is_some_and(|cache| cache.keys.contains_key(kid))
2296 }
2297}
2298
2299fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
2301 if jwks.keys.len() > max_keys {
2302 return Err(format!(
2303 "jwks_key_count_exceeds_cap: got {} keys, max is {}",
2304 jwks.keys.len(),
2305 max_keys
2306 ));
2307 }
2308 let mut keys = HashMap::new();
2309 let mut unnamed_keys = Vec::new();
2310 for jwk in &jwks.keys {
2311 let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
2312 continue;
2313 };
2314 let Some(alg) = jwk_algorithm(jwk) else {
2315 continue;
2316 };
2317 if let Some(ref kid) = jwk.common.key_id {
2318 keys.insert(kid.clone(), (alg, decoding_key));
2319 } else {
2320 unnamed_keys.push((alg, decoding_key));
2321 }
2322 }
2323 Ok((keys, unnamed_keys))
2324}
2325
2326fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2328 if let Some(kid) = kid
2329 && let Some((cached_alg, key)) = cached.keys.get(kid)
2330 && *cached_alg == alg
2331 {
2332 return Some(key.clone());
2333 }
2334 cached
2336 .unnamed_keys
2337 .iter()
2338 .find(|(a, _)| *a == alg)
2339 .map(|(_, k)| k.clone())
2340}
2341
2342#[allow(
2344 clippy::wildcard_enum_match_arm,
2345 reason = "jsonwebtoken KeyAlgorithm is a large external enum; only the JWT-signing variants are mappable to `Algorithm`"
2346)]
2347fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
2348 jwk.common.key_algorithm.and_then(|ka| match ka {
2349 jsonwebtoken::jwk::KeyAlgorithm::RS256 => Some(Algorithm::RS256),
2350 jsonwebtoken::jwk::KeyAlgorithm::RS384 => Some(Algorithm::RS384),
2351 jsonwebtoken::jwk::KeyAlgorithm::RS512 => Some(Algorithm::RS512),
2352 jsonwebtoken::jwk::KeyAlgorithm::ES256 => Some(Algorithm::ES256),
2353 jsonwebtoken::jwk::KeyAlgorithm::ES384 => Some(Algorithm::ES384),
2354 jsonwebtoken::jwk::KeyAlgorithm::PS256 => Some(Algorithm::PS256),
2355 jsonwebtoken::jwk::KeyAlgorithm::PS384 => Some(Algorithm::PS384),
2356 jsonwebtoken::jwk::KeyAlgorithm::PS512 => Some(Algorithm::PS512),
2357 jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
2358 _ => None,
2359 })
2360}
2361
2362fn first_class_claim_values(claims: &Claims, path: &str) -> Vec<String> {
2383 match path {
2384 "sub" => claims.sub.iter().cloned().collect(),
2385 "azp" => claims.azp.iter().cloned().collect(),
2386 "client_id" => claims.client_id.iter().cloned().collect(),
2387 "aud" => claims.aud.0.clone(),
2388 "scope" => claims
2389 .scope
2390 .as_deref()
2391 .unwrap_or("")
2392 .split_whitespace()
2393 .map(str::to_owned)
2394 .collect(),
2395 _ => Vec::new(),
2396 }
2397}
2398
2399fn resolve_claim_path<'a>(
2409 extra: &'a HashMap<String, serde_json::Value>,
2410 path: &str,
2411) -> Vec<&'a str> {
2412 let mut segments = path.split('.');
2413 let Some(first) = segments.next() else {
2414 return Vec::new();
2415 };
2416
2417 let mut current: Option<&serde_json::Value> = extra.get(first);
2418
2419 for segment in segments {
2420 current = current.and_then(|v| v.get(segment));
2421 }
2422
2423 match current {
2424 Some(serde_json::Value::String(s)) => s.split_whitespace().collect(),
2425 Some(serde_json::Value::Array(arr)) => arr.iter().filter_map(|v| v.as_str()).collect(),
2426 _ => Vec::new(),
2427 }
2428}
2429
/// The JWT claims this module reads directly; every other claim is collected
/// into `extra`.
#[derive(Debug, Deserialize)]
struct Claims {
    /// `sub` claim, if present.
    sub: Option<String>,
    /// `aud` claim, accepted as a single string or an array of strings.
    /// A missing claim deserializes to an empty list.
    #[serde(default)]
    aud: OneOrMany,
    /// `azp` claim, if present.
    azp: Option<String>,
    /// `client_id` claim, if present.
    client_id: Option<String>,
    /// `scope` claim as one string; consumers split it on whitespace.
    scope: Option<String>,
    /// All remaining claims, keyed by claim name.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
2453
/// List of strings for a claim that may be written as a single string or as
/// an array of strings. The default value is the empty list.
#[derive(Debug, Default)]
struct OneOrMany(Vec<String>);

impl OneOrMany {
    /// Returns `true` when any entry equals `value` exactly.
    fn contains(&self, value: &str) -> bool {
        let Self(entries) = self;
        entries
            .iter()
            .map(String::as_str)
            .any(|entry| entry == value)
    }
}
2463
2464impl<'de> Deserialize<'de> for OneOrMany {
2465 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
2466 use serde::de;
2467
2468 struct Visitor;
2469 impl<'de> de::Visitor<'de> for Visitor {
2470 type Value = OneOrMany;
2471 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2472 f.write_str("a string or array of strings")
2473 }
2474 fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
2475 Ok(OneOrMany(vec![v.to_owned()]))
2476 }
2477 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
2478 let mut v = Vec::new();
2479 while let Some(s) = seq.next_element::<String>()? {
2480 v.push(s);
2481 }
2482 Ok(OneOrMany(v))
2483 }
2484 }
2485 deserializer.deserialize_any(Visitor)
2486 }
2487}
2488
2489#[must_use]
2496pub fn looks_like_jwt(token: &str) -> bool {
2497 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2498
2499 let mut parts = token.splitn(4, '.');
2500 let Some(header_b64) = parts.next() else {
2501 return false;
2502 };
2503 if parts.next().is_none() || parts.next().is_none() || parts.next().is_some() {
2505 return false;
2506 }
2507 let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
2509 return false;
2510 };
2511 let Ok(header) = serde_json::from_slice::<serde_json::Value>(&header_bytes) else {
2513 return false;
2514 };
2515 header.get("alg").is_some()
2516}
2517
2518#[must_use]
2528pub fn protected_resource_metadata(
2529 resource_url: &str,
2530 server_url: &str,
2531 config: &OAuthConfig,
2532) -> serde_json::Value {
2533 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2538 let auth_server = server_url;
2539 serde_json::json!({
2540 "resource": resource_url,
2541 "authorization_servers": [auth_server],
2542 "scopes_supported": scopes,
2543 "bearer_methods_supported": ["header"]
2544 })
2545}
2546
2547#[must_use]
2552pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
2553 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2554 let mut meta = serde_json::json!({
2555 "issuer": &config.issuer,
2556 "authorization_endpoint": format!("{server_url}/authorize"),
2557 "token_endpoint": format!("{server_url}/token"),
2558 "registration_endpoint": format!("{server_url}/register"),
2559 "response_types_supported": ["code"],
2560 "grant_types_supported": ["authorization_code", "refresh_token"],
2561 "code_challenge_methods_supported": ["S256"],
2562 "scopes_supported": scopes,
2563 "token_endpoint_auth_methods_supported": ["none"],
2564 });
2565 if let Some(proxy) = &config.proxy
2566 && proxy.expose_admin_endpoints
2567 && let Some(obj) = meta.as_object_mut()
2568 {
2569 if proxy.introspection_url.is_some() {
2570 obj.insert(
2571 "introspection_endpoint".into(),
2572 serde_json::Value::String(format!("{server_url}/introspect")),
2573 );
2574 }
2575 if proxy.revocation_url.is_some() {
2576 obj.insert(
2577 "revocation_endpoint".into(),
2578 serde_json::Value::String(format!("{server_url}/revoke")),
2579 );
2580 }
2581 if proxy.require_auth_on_admin_endpoints {
2582 obj.insert(
2583 "introspection_endpoint_auth_methods_supported".into(),
2584 serde_json::json!(["bearer"]),
2585 );
2586 obj.insert(
2587 "revocation_endpoint_auth_methods_supported".into(),
2588 serde_json::json!(["bearer"]),
2589 );
2590 }
2591 }
2592 meta
2593}
2594
2595#[must_use]
2608pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
2609 use axum::{
2610 http::{StatusCode, header},
2611 response::IntoResponse,
2612 };
2613
2614 let upstream_query = replace_client_id(query, &proxy.client_id);
2616 let redirect_url = format!("{}?{upstream_query}", proxy.authorize_url);
2617
2618 (StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
2619}
2620
2621pub async fn handle_token(
2627 http: &OauthHttpClient,
2628 proxy: &OAuthProxyConfig,
2629 body: &str,
2630) -> axum::response::Response {
2631 use axum::{
2632 http::{StatusCode, header},
2633 response::IntoResponse,
2634 };
2635
2636 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2638
2639 if let Some(ref secret) = proxy.client_secret {
2641 use std::fmt::Write;
2642
2643 use secrecy::ExposeSecret;
2644 let _ = write!(
2645 upstream_body,
2646 "&client_secret={}",
2647 urlencoding::encode(secret.expose_secret())
2648 );
2649 }
2650
2651 let result = http
2652 .send_screened(
2653 &proxy.token_url,
2654 http.inner
2655 .post(&proxy.token_url)
2656 .header("Content-Type", "application/x-www-form-urlencoded")
2657 .body(upstream_body),
2658 )
2659 .await;
2660
2661 match result {
2662 Ok(resp) => {
2663 let status =
2664 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2665 let body_bytes = resp.bytes().await.unwrap_or_default();
2666 (
2667 status,
2668 [(header::CONTENT_TYPE, "application/json")],
2669 body_bytes,
2670 )
2671 .into_response()
2672 }
2673 Err(e) => {
2674 tracing::error!(error = %e, "OAuth token proxy request failed");
2675 (
2676 StatusCode::BAD_GATEWAY,
2677 [(header::CONTENT_TYPE, "application/json")],
2678 "{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
2679 )
2680 .into_response()
2681 }
2682 }
2683}
2684
2685#[must_use]
2692pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
2693 let mut resp = serde_json::json!({
2694 "client_id": proxy.client_id,
2695 "token_endpoint_auth_method": "none",
2696 });
2697 if let Some(uris) = body.get("redirect_uris")
2698 && let Some(obj) = resp.as_object_mut()
2699 {
2700 obj.insert("redirect_uris".into(), uris.clone());
2701 }
2702 if let Some(name) = body.get("client_name")
2703 && let Some(obj) = resp.as_object_mut()
2704 {
2705 obj.insert("client_name".into(), name.clone());
2706 }
2707 resp
2708}
2709
2710pub async fn handle_introspect(
2716 http: &OauthHttpClient,
2717 proxy: &OAuthProxyConfig,
2718 body: &str,
2719) -> axum::response::Response {
2720 let Some(ref url) = proxy.introspection_url else {
2721 return oauth_error_response(
2722 axum::http::StatusCode::NOT_FOUND,
2723 "not_supported",
2724 "introspection endpoint is not configured",
2725 );
2726 };
2727 proxy_oauth_admin_request(http, proxy, url, body).await
2728}
2729
2730pub async fn handle_revoke(
2737 http: &OauthHttpClient,
2738 proxy: &OAuthProxyConfig,
2739 body: &str,
2740) -> axum::response::Response {
2741 let Some(ref url) = proxy.revocation_url else {
2742 return oauth_error_response(
2743 axum::http::StatusCode::NOT_FOUND,
2744 "not_supported",
2745 "revocation endpoint is not configured",
2746 );
2747 };
2748 proxy_oauth_admin_request(http, proxy, url, body).await
2749}
2750
2751async fn proxy_oauth_admin_request(
2755 http: &OauthHttpClient,
2756 proxy: &OAuthProxyConfig,
2757 upstream_url: &str,
2758 body: &str,
2759) -> axum::response::Response {
2760 use axum::{
2761 http::{StatusCode, header},
2762 response::IntoResponse,
2763 };
2764
2765 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2766 if let Some(ref secret) = proxy.client_secret {
2767 use std::fmt::Write;
2768
2769 use secrecy::ExposeSecret;
2770 let _ = write!(
2771 upstream_body,
2772 "&client_secret={}",
2773 urlencoding::encode(secret.expose_secret())
2774 );
2775 }
2776
2777 let result = http
2778 .send_screened(
2779 upstream_url,
2780 http.inner
2781 .post(upstream_url)
2782 .header("Content-Type", "application/x-www-form-urlencoded")
2783 .body(upstream_body),
2784 )
2785 .await;
2786
2787 match result {
2788 Ok(resp) => {
2789 let status =
2790 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2791 let content_type = resp
2792 .headers()
2793 .get(header::CONTENT_TYPE)
2794 .and_then(|v| v.to_str().ok())
2795 .unwrap_or("application/json")
2796 .to_owned();
2797 let body_bytes = resp.bytes().await.unwrap_or_default();
2798 (status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
2799 }
2800 Err(e) => {
2801 tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
2802 oauth_error_response(
2803 StatusCode::BAD_GATEWAY,
2804 "server_error",
2805 "upstream endpoint unreachable",
2806 )
2807 }
2808 }
2809}
2810
2811fn oauth_error_response(
2812 status: axum::http::StatusCode,
2813 error: &str,
2814 description: &str,
2815) -> axum::response::Response {
2816 use axum::{http::header, response::IntoResponse};
2817 let body = serde_json::json!({
2818 "error": error,
2819 "error_description": description,
2820 });
2821 (
2822 status,
2823 [(header::CONTENT_TYPE, "application/json")],
2824 body.to_string(),
2825 )
2826 .into_response()
2827}
2828
/// Error body returned by an upstream authorization server, as parsed by
/// `exchange_token` from a non-success response.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Upstream error code; sanitized before it is handed back to callers.
    error: String,
    /// Optional upstream description; only written to the log.
    error_description: Option<String>,
}
2839
/// Maps an upstream OAuth error code onto a fixed set of known codes.
///
/// A code from the pass-through list is returned unchanged (as a `'static`
/// string); anything else, including the empty string and differently-cased
/// spellings, collapses to `server_error`.
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
    const PASS_THROUGH: [&str; 8] = [
        "invalid_request",
        "invalid_client",
        "invalid_grant",
        "unauthorized_client",
        "unsupported_grant_type",
        "invalid_scope",
        "temporarily_unavailable",
        "invalid_target",
    ];

    PASS_THROUGH
        .iter()
        .copied()
        .find(|known| *known == raw)
        .unwrap_or("server_error")
}
2862
/// Exchanges `subject_token` for a new token at `config.token_url`.
///
/// The request is a form POST built by `build_exchange_form` and sent
/// through `send_screened`. When no client certificate is configured and a
/// client secret is, the client authenticates with an HTTP Basic header.
///
/// # Errors
///
/// Always returns `McpxError::Auth` carrying a short code:
/// - `server_error` when the request fails, the response body cannot be
///   read, or a success body cannot be parsed;
/// - the sanitized upstream error code when the server answers with a
///   non-success status (`server_error` if its body is unparseable).
///
/// Upstream error details are written to the log only, never returned.
pub async fn exchange_token(
    http: &OauthHttpClient,
    config: &TokenExchangeConfig,
    subject_token: &str,
) -> Result<ExchangedToken, crate::error::McpxError> {
    use secrecy::ExposeSecret;

    let client = http.client_for(config);
    let mut req = client
        .post(&config.token_url)
        .header("Content-Type", "application/x-www-form-urlencoded")
        .header("Accept", "application/json");

    // Basic auth is used only when no client certificate is configured.
    // Both halves are percent-encoded before being joined and base64-encoded.
    if config.client_cert.is_none()
        && let Some(ref secret) = config.client_secret
    {
        use base64::Engine;
        let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
            "{}:{}",
            urlencoding::encode(&config.client_id),
            urlencoding::encode(secret.expose_secret()),
        ));
        req = req.header("Authorization", format!("Basic {credentials}"));
    }

    let form_body = build_exchange_form(config, subject_token);

    let resp = http
        .send_screened(&config.token_url, req.body(form_body))
        .await
        .map_err(|e| {
            tracing::error!(error = %e, "token exchange request failed");
            crate::error::McpxError::Auth("server_error".into())
        })?;

    // Capture the status before `bytes()` consumes the response.
    let status = resp.status();
    let body_bytes = resp.bytes().await.map_err(|e| {
        tracing::error!(error = %e, "failed to read token exchange response");
        crate::error::McpxError::Auth("server_error".into())
    })?;

    if !status.is_success() {
        core::hint::cold_path();
        let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
        // Only a sanitized code is returned to the caller.
        let short_code = parsed
            .as_ref()
            .map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
        if let Some(ref e) = parsed {
            tracing::warn!(
                status = %status,
                upstream_error = %e.error,
                upstream_error_description = e.error_description.as_deref().unwrap_or(""),
                client_code = %short_code,
                "token exchange rejected by authorization server",
            );
        } else {
            tracing::warn!(
                status = %status,
                client_code = %short_code,
                "token exchange rejected (unparseable upstream body)",
            );
        }
        return Err(crate::error::McpxError::Auth(short_code.into()));
    }

    let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
        tracing::error!(error = %e, "failed to parse token exchange response");
        crate::error::McpxError::Auth("server_error".into())
    })?;

    log_exchanged_token(&exchanged);

    Ok(exchanged)
}
2961
2962fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
2965 let body = format!(
2966 "grant_type={}&subject_token={}&subject_token_type={}&requested_token_type={}&audience={}",
2967 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
2968 urlencoding::encode(subject_token),
2969 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
2970 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
2971 urlencoding::encode(&config.audience),
2972 );
2973 if config.client_secret.is_none() {
2974 format!(
2975 "{body}&client_id={}",
2976 urlencoding::encode(&config.client_id)
2977 )
2978 } else {
2979 body
2980 }
2981}
2982
2983fn log_exchanged_token(exchanged: &ExchangedToken) {
2986 use base64::Engine;
2987
2988 if !looks_like_jwt(&exchanged.access_token) {
2989 tracing::debug!(
2990 token_len = exchanged.access_token.len(),
2991 issued_token_type = ?exchanged.issued_token_type,
2992 expires_in = exchanged.expires_in,
2993 "exchanged token (opaque)",
2994 );
2995 return;
2996 }
2997 let Some(payload) = exchanged.access_token.split('.').nth(1) else {
2998 return;
2999 };
3000 let Ok(decoded) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload) else {
3001 return;
3002 };
3003 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&decoded) else {
3004 return;
3005 };
3006 tracing::debug!(
3007 sub = ?claims.get("sub"),
3008 aud = ?claims.get("aud"),
3009 azp = ?claims.get("azp"),
3010 iss = ?claims.get("iss"),
3011 expires_in = exchanged.expires_in,
3012 "exchanged token claims (JWT)",
3013 );
3014}
3015
3016fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
3018 let encoded_id = urlencoding::encode(upstream_client_id);
3019 let mut parts: Vec<String> = params
3020 .split('&')
3021 .filter(|p| !p.starts_with("client_id="))
3022 .map(String::from)
3023 .collect();
3024 parts.push(format!("client_id={encoded_id}"));
3025 parts.join("&")
3026}
3027
3028#[cfg(test)]
3029mod tests {
3030 use std::sync::Arc;
3031
3032 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
3033
3034 use super::*;
3035
3036 #[test]
3037 fn looks_like_jwt_valid() {
3038 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
3040 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3041 let token = format!("{header}.{payload}.signature");
3042 assert!(looks_like_jwt(&token));
3043 }
3044
3045 #[test]
3046 fn looks_like_jwt_rejects_opaque_token() {
3047 assert!(!looks_like_jwt("dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4"));
3048 }
3049
3050 #[test]
3051 fn looks_like_jwt_rejects_two_segments() {
3052 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
3053 let token = format!("{header}.payload");
3054 assert!(!looks_like_jwt(&token));
3055 }
3056
3057 #[test]
3058 fn looks_like_jwt_rejects_four_segments() {
3059 assert!(!looks_like_jwt("a.b.c.d"));
3060 }
3061
3062 #[test]
3063 fn looks_like_jwt_rejects_no_alg() {
3064 let header = URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}");
3065 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3066 let token = format!("{header}.{payload}.sig");
3067 assert!(!looks_like_jwt(&token));
3068 }
3069
3070 #[test]
3071 fn protected_resource_metadata_shape() {
3072 let config = OAuthConfig {
3073 issuer: "https://auth.example.com".into(),
3074 audience: "https://mcp.example.com/mcp".into(),
3075 jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
3076 scopes: vec![
3077 ScopeMapping {
3078 scope: "mcp:read".into(),
3079 role: "viewer".into(),
3080 },
3081 ScopeMapping {
3082 scope: "mcp:admin".into(),
3083 role: "ops".into(),
3084 },
3085 ],
3086 role_claim: None,
3087 role_mappings: vec![],
3088 jwks_cache_ttl: "10m".into(),
3089 proxy: None,
3090 token_exchange: None,
3091 ca_cert_path: None,
3092 allow_http_oauth_urls: false,
3093 max_jwks_keys: default_max_jwks_keys(),
3094 #[allow(
3095 deprecated,
3096 reason = "test fixture: explicit value for the deprecated field"
3097 )]
3098 strict_audience_validation: false,
3099 audience_validation_mode: None,
3100 jwks_max_response_bytes: default_jwks_max_bytes(),
3101 ssrf_allowlist: None,
3102 };
3103 let meta = protected_resource_metadata(
3104 "https://mcp.example.com/mcp",
3105 "https://mcp.example.com",
3106 &config,
3107 );
3108 assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
3109 assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
3110 assert_eq!(meta["scopes_supported"].as_array().unwrap().len(), 2);
3111 assert_eq!(meta["bearer_methods_supported"][0], "header");
3112 }
3113
3114 fn validation_https_config() -> OAuthConfig {
3119 OAuthConfig::builder(
3120 "https://auth.example.com",
3121 "mcp",
3122 "https://auth.example.com/.well-known/jwks.json",
3123 )
3124 .build()
3125 }
3126
3127 #[test]
3128 fn validate_accepts_all_https_urls() {
3129 let cfg = validation_https_config();
3130 cfg.validate().expect("all-HTTPS config must validate");
3131 }
3132
3133 #[test]
3134 fn validate_rejects_unparseable_jwks_cache_ttl() {
3135 let mut cfg = validation_https_config();
3136 cfg.jwks_cache_ttl = "not-a-duration".into();
3137 let err = cfg
3138 .validate()
3139 .expect_err("malformed jwks_cache_ttl must be rejected");
3140 let msg = err.to_string();
3141 assert!(
3142 msg.contains("jwks_cache_ttl"),
3143 "error must reference offending field; got {msg:?}"
3144 );
3145 }
3146
3147 #[test]
3148 fn validate_rejects_http_jwks_uri() {
3149 let mut cfg = validation_https_config();
3150 cfg.jwks_uri = "http://auth.example.com/.well-known/jwks.json".into();
3151 let err = cfg.validate().expect_err("http jwks_uri must be rejected");
3152 let msg = err.to_string();
3153 assert!(
3154 msg.contains("oauth.jwks_uri") && msg.contains("https"),
3155 "error must reference offending field + scheme requirement; got {msg:?}"
3156 );
3157 }
3158
3159 #[test]
3160 fn validate_rejects_http_proxy_authorize_url() {
3161 let mut cfg = validation_https_config();
3162 cfg.proxy = Some(
3163 OAuthProxyConfig::builder(
3164 "http://idp.example.com/authorize", "https://idp.example.com/token",
3166 "client",
3167 )
3168 .build(),
3169 );
3170 let err = cfg
3171 .validate()
3172 .expect_err("http authorize_url must be rejected");
3173 assert!(
3174 err.to_string().contains("oauth.proxy.authorize_url"),
3175 "error must reference proxy.authorize_url; got {err}"
3176 );
3177 }
3178
3179 #[test]
3180 fn validate_rejects_http_proxy_token_url() {
3181 let mut cfg = validation_https_config();
3182 cfg.proxy = Some(
3183 OAuthProxyConfig::builder(
3184 "https://idp.example.com/authorize",
3185 "http://idp.example.com/token", "client",
3187 )
3188 .build(),
3189 );
3190 let err = cfg.validate().expect_err("http token_url must be rejected");
3191 assert!(
3192 err.to_string().contains("oauth.proxy.token_url"),
3193 "error must reference proxy.token_url; got {err}"
3194 );
3195 }
3196
3197 #[test]
3198 fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
3199 let mut cfg = validation_https_config();
3200 cfg.proxy = Some(
3201 OAuthProxyConfig::builder(
3202 "https://idp.example.com/authorize",
3203 "https://idp.example.com/token",
3204 "client",
3205 )
3206 .introspection_url("http://idp.example.com/introspect")
3207 .build(),
3208 );
3209 let err = cfg
3210 .validate()
3211 .expect_err("http introspection_url must be rejected");
3212 assert!(err.to_string().contains("oauth.proxy.introspection_url"));
3213
3214 let mut cfg = validation_https_config();
3215 cfg.proxy = Some(
3216 OAuthProxyConfig::builder(
3217 "https://idp.example.com/authorize",
3218 "https://idp.example.com/token",
3219 "client",
3220 )
3221 .revocation_url("http://idp.example.com/revoke")
3222 .build(),
3223 );
3224 let err = cfg
3225 .validate()
3226 .expect_err("http revocation_url must be rejected");
3227 assert!(err.to_string().contains("oauth.proxy.revocation_url"));
3228 }
3229
3230 #[test]
3233 fn validate_rejects_exposed_admin_endpoints_without_auth() {
3234 let mut cfg = validation_https_config();
3235 cfg.proxy = Some(
3236 OAuthProxyConfig::builder(
3237 "https://idp.example.com/authorize",
3238 "https://idp.example.com/token",
3239 "client",
3240 )
3241 .introspection_url("https://idp.example.com/introspect")
3242 .expose_admin_endpoints(true)
3243 .build(),
3244 );
3245 let err = cfg
3246 .validate()
3247 .expect_err("expose_admin_endpoints without auth must fail");
3248 let msg = err.to_string();
3249 assert!(msg.contains("require_auth_on_admin_endpoints"), "{msg}");
3250 assert!(
3251 msg.contains("allow_unauthenticated_admin_endpoints"),
3252 "{msg}"
3253 );
3254 }
3255
3256 #[test]
3257 fn validate_accepts_exposed_admin_endpoints_with_auth() {
3258 let mut cfg = validation_https_config();
3259 cfg.proxy = Some(
3260 OAuthProxyConfig::builder(
3261 "https://idp.example.com/authorize",
3262 "https://idp.example.com/token",
3263 "client",
3264 )
3265 .introspection_url("https://idp.example.com/introspect")
3266 .expose_admin_endpoints(true)
3267 .require_auth_on_admin_endpoints(true)
3268 .build(),
3269 );
3270 cfg.validate()
3271 .expect("authed admin endpoints must validate");
3272 }
3273
3274 #[test]
3275 fn validate_accepts_exposed_admin_endpoints_with_explicit_unauth_optout() {
3276 let mut cfg = validation_https_config();
3277 cfg.proxy = Some(
3278 OAuthProxyConfig::builder(
3279 "https://idp.example.com/authorize",
3280 "https://idp.example.com/token",
3281 "client",
3282 )
3283 .introspection_url("https://idp.example.com/introspect")
3284 .expose_admin_endpoints(true)
3285 .allow_unauthenticated_admin_endpoints(true)
3286 .build(),
3287 );
3288 cfg.validate()
3289 .expect("explicit unauth opt-out must validate");
3290 }
3291
3292 #[test]
3293 fn validate_accepts_unexposed_admin_endpoints_without_auth() {
3294 let mut cfg = validation_https_config();
3297 cfg.proxy = Some(
3298 OAuthProxyConfig::builder(
3299 "https://idp.example.com/authorize",
3300 "https://idp.example.com/token",
3301 "client",
3302 )
3303 .introspection_url("https://idp.example.com/introspect")
3304 .build(),
3305 );
3306 cfg.validate()
3307 .expect("unexposed admin endpoints must validate");
3308 }
3309
3310 #[test]
3311 fn validate_rejects_http_token_exchange_url() {
3312 let mut cfg = validation_https_config();
3313 cfg.token_exchange = Some(TokenExchangeConfig::new(
3314 "http://idp.example.com/token".into(), "client".into(),
3316 None,
3317 None,
3318 "downstream".into(),
3319 ));
3320 let err = cfg
3321 .validate()
3322 .expect_err("http token_exchange.token_url must be rejected");
3323 assert!(
3324 err.to_string().contains("oauth.token_exchange.token_url"),
3325 "error must reference token_exchange.token_url; got {err}"
3326 );
3327 }
3328
3329 #[test]
3330 fn validate_rejects_unparseable_url() {
3331 let mut cfg = validation_https_config();
3332 cfg.jwks_uri = "not a url".into();
3333 let err = cfg
3334 .validate()
3335 .expect_err("unparseable URL must be rejected");
3336 assert!(err.to_string().contains("invalid URL"));
3337 }
3338
3339 #[test]
3340 fn validate_rejects_non_http_scheme() {
3341 let mut cfg = validation_https_config();
3342 cfg.jwks_uri = "file:///etc/passwd".into();
3343 let err = cfg.validate().expect_err("file:// scheme must be rejected");
3344 let msg = err.to_string();
3345 assert!(
3346 msg.contains("must use https scheme") && msg.contains("file"),
3347 "error must reject non-http(s) schemes; got {msg:?}"
3348 );
3349 }
3350
3351 #[test]
3352 fn validate_accepts_http_with_escape_hatch() {
3353 let mut cfg = OAuthConfig::builder(
3358 "http://auth.local",
3359 "mcp",
3360 "http://auth.local/.well-known/jwks.json",
3361 )
3362 .allow_http_oauth_urls(true)
3363 .build();
3364 cfg.proxy = Some(
3365 OAuthProxyConfig::builder(
3366 "http://idp.local/authorize",
3367 "http://idp.local/token",
3368 "client",
3369 )
3370 .introspection_url("http://idp.local/introspect")
3371 .revocation_url("http://idp.local/revoke")
3372 .build(),
3373 );
3374 cfg.token_exchange = Some(TokenExchangeConfig::new(
3375 "http://idp.local/token".into(),
3376 "client".into(),
3377 Some(secrecy::SecretString::new("dev-secret".into())),
3378 None,
3379 "downstream".into(),
3380 ));
3381 cfg.validate()
3382 .expect("escape hatch must permit http on all URL fields");
3383 }
3384
3385 #[test]
3386 fn validate_with_escape_hatch_still_rejects_unparseable() {
3387 let mut cfg = validation_https_config();
3390 cfg.allow_http_oauth_urls = true;
3391 cfg.jwks_uri = "::not-a-url::".into();
3392 cfg.validate()
3393 .expect_err("escape hatch must NOT bypass URL parsing");
3394 }
3395
3396 #[tokio::test]
3397 async fn jwks_cache_rejects_redirect_downgrade_to_http() {
3398 rustls::crypto::ring::default_provider()
3413 .install_default()
3414 .ok();
3415
3416 let policy = reqwest::redirect::Policy::custom(|attempt| {
3417 if attempt.url().scheme() != "https" {
3418 attempt.error("redirect to non-HTTPS URL refused")
3419 } else if attempt.previous().len() >= 2 {
3420 attempt.error("too many redirects (max 2)")
3421 } else {
3422 attempt.follow()
3423 }
3424 });
3425 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = Arc::new(AtomicBool::new(true));
3432 let allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
3433 let resolver: Arc<dyn reqwest::dns::Resolve> = Arc::new(
3434 crate::ssrf_resolver::SsrfScreeningResolver::new(Arc::clone(&allowlist), test_bypass),
3435 );
3436 let client = reqwest::Client::builder()
3437 .no_proxy()
3438 .dns_resolver(Arc::clone(&resolver))
3439 .timeout(Duration::from_secs(5))
3440 .connect_timeout(Duration::from_secs(3))
3441 .redirect(policy)
3442 .build()
3443 .expect("test client builds");
3444
3445 let mock = wiremock::MockServer::start().await;
3446 wiremock::Mock::given(wiremock::matchers::method("GET"))
3447 .and(wiremock::matchers::path("/jwks.json"))
3448 .respond_with(
3449 wiremock::ResponseTemplate::new(302)
3450 .insert_header("location", "http://example.invalid/jwks.json"),
3451 )
3452 .mount(&mock)
3453 .await;
3454
3455 let url = format!("{}/jwks.json", mock.uri());
3464 let err = client
3465 .get(&url)
3466 .send()
3467 .await
3468 .expect_err("redirect policy must reject scheme downgrade");
3469 let chain = format!("{err:#}");
3470 assert!(
3471 chain.contains("redirect to non-HTTPS URL refused")
3472 || chain.to_lowercase().contains("redirect"),
3473 "error must surface redirect-policy rejection; got {chain:?}"
3474 );
3475 }
3476
3477 use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
3482
3483 fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
3485 let mut rng = rsa::rand_core::OsRng;
3486 let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keypair generation");
3487 let private_pem = private_key
3488 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
3489 .expect("PKCS8 PEM export")
3490 .to_string();
3491
3492 let public_key = private_key.to_public_key();
3493 let n = URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be());
3494 let e = URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be());
3495
3496 let jwks = serde_json::json!({
3497 "keys": [{
3498 "kty": "RSA",
3499 "use": "sig",
3500 "alg": "RS256",
3501 "kid": kid,
3502 "n": n,
3503 "e": e
3504 }]
3505 });
3506
3507 (private_pem, jwks)
3508 }
3509
3510 fn mint_token(
3512 private_pem: &str,
3513 kid: &str,
3514 issuer: &str,
3515 audience: &str,
3516 subject: &str,
3517 scope: &str,
3518 ) -> String {
3519 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3520 .expect("encoding key from PEM");
3521 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3522 header.kid = Some(kid.into());
3523
3524 let now = jsonwebtoken::get_current_timestamp();
3525 let claims = serde_json::json!({
3526 "iss": issuer,
3527 "aud": audience,
3528 "sub": subject,
3529 "scope": scope,
3530 "exp": now + 3600,
3531 "iat": now,
3532 });
3533
3534 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3535 }
3536
3537 fn test_config(jwks_uri: &str) -> OAuthConfig {
3538 OAuthConfig {
3539 issuer: "https://auth.test.local".into(),
3540 audience: "https://mcp.test.local/mcp".into(),
3541 jwks_uri: jwks_uri.into(),
3542 scopes: vec![
3543 ScopeMapping {
3544 scope: "mcp:read".into(),
3545 role: "viewer".into(),
3546 },
3547 ScopeMapping {
3548 scope: "mcp:admin".into(),
3549 role: "ops".into(),
3550 },
3551 ],
3552 role_claim: None,
3553 role_mappings: vec![],
3554 jwks_cache_ttl: "5m".into(),
3555 proxy: None,
3556 token_exchange: None,
3557 ca_cert_path: None,
3558 allow_http_oauth_urls: true,
3559 max_jwks_keys: default_max_jwks_keys(),
3560 #[allow(
3561 deprecated,
3562 reason = "test fixture: explicit value for the deprecated field"
3563 )]
3564 strict_audience_validation: false,
3565 audience_validation_mode: None,
3566 jwks_max_response_bytes: default_jwks_max_bytes(),
3567 ssrf_allowlist: None,
3568 }
3569 }
3570
3571 fn test_cache(config: &OAuthConfig) -> JwksCache {
3572 JwksCache::new(config).unwrap().__test_allow_loopback_ssrf()
3573 }
3574
3575 #[tokio::test]
3576 async fn valid_jwt_returns_identity() {
3577 let kid = "test-key-1";
3578 let (pem, jwks) = generate_test_keypair(kid);
3579
3580 let mock_server = wiremock::MockServer::start().await;
3581 wiremock::Mock::given(wiremock::matchers::method("GET"))
3582 .and(wiremock::matchers::path("/jwks.json"))
3583 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3584 .mount(&mock_server)
3585 .await;
3586
3587 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3588 let config = test_config(&jwks_uri);
3589 let cache = test_cache(&config);
3590
3591 let token = mint_token(
3592 &pem,
3593 kid,
3594 "https://auth.test.local",
3595 "https://mcp.test.local/mcp",
3596 "ci-bot",
3597 "mcp:read mcp:other",
3598 );
3599
3600 let identity = cache.validate_token(&token).await;
3601 assert!(identity.is_some(), "valid JWT should authenticate");
3602 let id = identity.unwrap();
3603 assert_eq!(id.name, "ci-bot");
3604 assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
3606 }
3607
3608 #[tokio::test]
3609 async fn wrong_issuer_rejected() {
3610 let kid = "test-key-2";
3611 let (pem, jwks) = generate_test_keypair(kid);
3612
3613 let mock_server = wiremock::MockServer::start().await;
3614 wiremock::Mock::given(wiremock::matchers::method("GET"))
3615 .and(wiremock::matchers::path("/jwks.json"))
3616 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3617 .mount(&mock_server)
3618 .await;
3619
3620 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3621 let config = test_config(&jwks_uri);
3622 let cache = test_cache(&config);
3623
3624 let token = mint_token(
3625 &pem,
3626 kid,
3627 "https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
3629 "attacker",
3630 "mcp:admin",
3631 );
3632
3633 assert!(cache.validate_token(&token).await.is_none());
3634 }
3635
3636 #[tokio::test]
3637 async fn wrong_audience_rejected() {
3638 let kid = "test-key-3";
3639 let (pem, jwks) = generate_test_keypair(kid);
3640
3641 let mock_server = wiremock::MockServer::start().await;
3642 wiremock::Mock::given(wiremock::matchers::method("GET"))
3643 .and(wiremock::matchers::path("/jwks.json"))
3644 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3645 .mount(&mock_server)
3646 .await;
3647
3648 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3649 let config = test_config(&jwks_uri);
3650 let cache = test_cache(&config);
3651
3652 let token = mint_token(
3653 &pem,
3654 kid,
3655 "https://auth.test.local",
3656 "https://wrong-audience.example.com", "attacker",
3658 "mcp:admin",
3659 );
3660
3661 assert!(cache.validate_token(&token).await.is_none());
3662 }
3663
3664 #[tokio::test]
3665 async fn expired_jwt_rejected() {
3666 let kid = "test-key-4";
3667 let (pem, jwks) = generate_test_keypair(kid);
3668
3669 let mock_server = wiremock::MockServer::start().await;
3670 wiremock::Mock::given(wiremock::matchers::method("GET"))
3671 .and(wiremock::matchers::path("/jwks.json"))
3672 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3673 .mount(&mock_server)
3674 .await;
3675
3676 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3677 let config = test_config(&jwks_uri);
3678 let cache = test_cache(&config);
3679
3680 let encoding_key =
3682 jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
3683 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3684 header.kid = Some(kid.into());
3685 let now = jsonwebtoken::get_current_timestamp();
3686 let claims = serde_json::json!({
3687 "iss": "https://auth.test.local",
3688 "aud": "https://mcp.test.local/mcp",
3689 "sub": "expired-bot",
3690 "scope": "mcp:read",
3691 "exp": now - 120,
3692 "iat": now - 3720,
3693 });
3694 let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
3695
3696 assert!(cache.validate_token(&token).await.is_none());
3697 }
3698
3699 #[tokio::test]
3700 async fn no_matching_scope_rejected() {
3701 let kid = "test-key-5";
3702 let (pem, jwks) = generate_test_keypair(kid);
3703
3704 let mock_server = wiremock::MockServer::start().await;
3705 wiremock::Mock::given(wiremock::matchers::method("GET"))
3706 .and(wiremock::matchers::path("/jwks.json"))
3707 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3708 .mount(&mock_server)
3709 .await;
3710
3711 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3712 let config = test_config(&jwks_uri);
3713 let cache = test_cache(&config);
3714
3715 let token = mint_token(
3716 &pem,
3717 kid,
3718 "https://auth.test.local",
3719 "https://mcp.test.local/mcp",
3720 "limited-bot",
3721 "some:other:scope", );
3723
3724 assert!(cache.validate_token(&token).await.is_none());
3725 }
3726
3727 #[tokio::test]
3728 async fn wrong_signing_key_rejected() {
3729 let kid = "test-key-6";
3730 let (_pem, jwks) = generate_test_keypair(kid);
3731
3732 let (attacker_pem, _) = generate_test_keypair(kid);
3734
3735 let mock_server = wiremock::MockServer::start().await;
3736 wiremock::Mock::given(wiremock::matchers::method("GET"))
3737 .and(wiremock::matchers::path("/jwks.json"))
3738 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3739 .mount(&mock_server)
3740 .await;
3741
3742 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3743 let config = test_config(&jwks_uri);
3744 let cache = test_cache(&config);
3745
3746 let token = mint_token(
3748 &attacker_pem,
3749 kid,
3750 "https://auth.test.local",
3751 "https://mcp.test.local/mcp",
3752 "attacker",
3753 "mcp:admin",
3754 );
3755
3756 assert!(cache.validate_token(&token).await.is_none());
3757 }
3758
3759 #[tokio::test]
3760 async fn admin_scope_maps_to_ops_role() {
3761 let kid = "test-key-7";
3762 let (pem, jwks) = generate_test_keypair(kid);
3763
3764 let mock_server = wiremock::MockServer::start().await;
3765 wiremock::Mock::given(wiremock::matchers::method("GET"))
3766 .and(wiremock::matchers::path("/jwks.json"))
3767 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3768 .mount(&mock_server)
3769 .await;
3770
3771 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3772 let config = test_config(&jwks_uri);
3773 let cache = test_cache(&config);
3774
3775 let token = mint_token(
3776 &pem,
3777 kid,
3778 "https://auth.test.local",
3779 "https://mcp.test.local/mcp",
3780 "admin-bot",
3781 "mcp:admin",
3782 );
3783
3784 let id = cache
3785 .validate_token(&token)
3786 .await
3787 .expect("should authenticate");
3788 assert_eq!(id.role, "ops");
3789 assert_eq!(id.name, "admin-bot");
3790 }
3791
3792 #[tokio::test]
3793 async fn jwks_server_down_returns_none() {
3794 let config = test_config("http://127.0.0.1:1/jwks.json");
3796 let cache = test_cache(&config);
3797
3798 let kid = "orphan-key";
3799 let (pem, _) = generate_test_keypair(kid);
3800 let token = mint_token(
3801 &pem,
3802 kid,
3803 "https://auth.test.local",
3804 "https://mcp.test.local/mcp",
3805 "bot",
3806 "mcp:read",
3807 );
3808
3809 assert!(cache.validate_token(&token).await.is_none());
3810 }
3811
3812 #[test]
3817 fn resolve_claim_path_flat_string() {
3818 let mut extra = HashMap::new();
3819 extra.insert(
3820 "scope".into(),
3821 serde_json::Value::String("mcp:read mcp:admin".into()),
3822 );
3823 let values = resolve_claim_path(&extra, "scope");
3824 assert_eq!(values, vec!["mcp:read", "mcp:admin"]);
3825 }
3826
3827 #[test]
3828 fn resolve_claim_path_flat_array() {
3829 let mut extra = HashMap::new();
3830 extra.insert(
3831 "roles".into(),
3832 serde_json::json!(["mcp-admin", "mcp-viewer"]),
3833 );
3834 let values = resolve_claim_path(&extra, "roles");
3835 assert_eq!(values, vec!["mcp-admin", "mcp-viewer"]);
3836 }
3837
3838 #[test]
3839 fn resolve_claim_path_nested_keycloak() {
3840 let mut extra = HashMap::new();
3841 extra.insert(
3842 "realm_access".into(),
3843 serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]}),
3844 );
3845 let values = resolve_claim_path(&extra, "realm_access.roles");
3846 assert_eq!(values, vec!["uma_authorization", "mcp-admin"]);
3847 }
3848
3849 #[test]
3850 fn resolve_claim_path_missing_returns_empty() {
3851 let extra = HashMap::new();
3852 assert!(resolve_claim_path(&extra, "nonexistent.path").is_empty());
3853 }
3854
3855 #[test]
3856 fn resolve_claim_path_numeric_leaf_returns_empty() {
3857 let mut extra = HashMap::new();
3858 extra.insert("count".into(), serde_json::json!(42));
3859 assert!(resolve_claim_path(&extra, "count").is_empty());
3860 }
3861
3862 fn make_claims(json: serde_json::Value) -> Claims {
3863 serde_json::from_value(json).expect("test claims must deserialize")
3864 }
3865
3866 #[test]
3867 fn first_class_scope_claim_splits_on_whitespace() {
3868 let claims = make_claims(serde_json::json!({
3869 "iss": "https://issuer.example.com",
3870 "exp": 9_999_999_999_u64,
3871 "scope": "read write admin",
3872 }));
3873 let values = first_class_claim_values(&claims, "scope");
3874 assert_eq!(values, vec!["read", "write", "admin"]);
3875 }
3876
3877 #[test]
3878 fn first_class_sub_claim_returns_single_value() {
3879 let claims = make_claims(serde_json::json!({
3880 "iss": "https://issuer.example.com",
3881 "exp": 9_999_999_999_u64,
3882 "sub": "service-account-orders",
3883 }));
3884 let values = first_class_claim_values(&claims, "sub");
3885 assert_eq!(values, vec!["service-account-orders"]);
3886 }
3887
3888 #[test]
3889 fn first_class_aud_claim_returns_every_audience() {
3890 let claims = make_claims(serde_json::json!({
3891 "iss": "https://issuer.example.com",
3892 "exp": 9_999_999_999_u64,
3893 "aud": ["api-a", "api-b"],
3894 }));
3895 let values = first_class_claim_values(&claims, "aud");
3896 assert_eq!(values, vec!["api-a", "api-b"]);
3897 }
3898
3899 #[test]
3900 fn first_class_unknown_path_returns_empty() {
3901 let claims = make_claims(serde_json::json!({
3902 "iss": "https://issuer.example.com",
3903 "exp": 9_999_999_999_u64,
3904 }));
3905 assert!(first_class_claim_values(&claims, "realm_access.roles").is_empty());
3906 }
3907
3908 fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
3914 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3915 .expect("encoding key from PEM");
3916 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3917 header.kid = Some(kid.into());
3918 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3919 }
3920
3921 fn test_config_with_role_claim(
3922 jwks_uri: &str,
3923 role_claim: &str,
3924 role_mappings: Vec<RoleMapping>,
3925 ) -> OAuthConfig {
3926 OAuthConfig {
3927 issuer: "https://auth.test.local".into(),
3928 audience: "https://mcp.test.local/mcp".into(),
3929 jwks_uri: jwks_uri.into(),
3930 scopes: vec![],
3931 role_claim: Some(role_claim.into()),
3932 role_mappings,
3933 jwks_cache_ttl: "5m".into(),
3934 proxy: None,
3935 token_exchange: None,
3936 ca_cert_path: None,
3937 allow_http_oauth_urls: true,
3938 max_jwks_keys: default_max_jwks_keys(),
3939 #[allow(
3940 deprecated,
3941 reason = "test fixture: explicit value for the deprecated field"
3942 )]
3943 strict_audience_validation: false,
3944 audience_validation_mode: None,
3945 jwks_max_response_bytes: default_jwks_max_bytes(),
3946 ssrf_allowlist: None,
3947 }
3948 }
3949
3950 #[tokio::test]
3951 async fn screen_oauth_target_rejects_literal_ip() {
3952 let err = screen_oauth_target(
3953 "https://127.0.0.1/jwks.json",
3954 false,
3955 &crate::ssrf::CompiledSsrfAllowlist::default(),
3956 )
3957 .await
3958 .expect_err("literal IPs must be rejected");
3959 let msg = err.to_string();
3960 assert!(msg.contains("literal IPv4 addresses are forbidden"));
3961 }
3962
3963 #[tokio::test]
3964 async fn screen_oauth_target_rejects_private_dns_resolution() {
3965 let err = screen_oauth_target(
3966 "https://localhost/jwks.json",
3967 false,
3968 &crate::ssrf::CompiledSsrfAllowlist::default(),
3969 )
3970 .await
3971 .expect_err("localhost resolution must be rejected");
3972 let msg = err.to_string();
3973 assert!(
3974 msg.contains("blocked IP") && msg.contains("loopback"),
3975 "got {msg:?}"
3976 );
3977 }
3978
3979 #[tokio::test]
3980 async fn screen_oauth_target_rejects_literal_ip_even_with_allow_http() {
3981 let err = screen_oauth_target(
3982 "http://127.0.0.1/jwks.json",
3983 true,
3984 &crate::ssrf::CompiledSsrfAllowlist::default(),
3985 )
3986 .await
3987 .expect_err("literal IPs must still be rejected when http is allowed");
3988 let msg = err.to_string();
3989 assert!(msg.contains("literal IPv4 addresses are forbidden"));
3990 }
3991
3992 #[tokio::test]
3993 async fn screen_oauth_target_rejects_private_dns_even_with_allow_http() {
3994 let err = screen_oauth_target(
3995 "http://localhost/jwks.json",
3996 true,
3997 &crate::ssrf::CompiledSsrfAllowlist::default(),
3998 )
3999 .await
4000 .expect_err("private DNS resolution must still be rejected when http is allowed");
4001 let msg = err.to_string();
4002 assert!(
4003 msg.contains("blocked IP") && msg.contains("loopback"),
4004 "got {msg:?}"
4005 );
4006 }
4007
4008 #[tokio::test]
4009 async fn screen_oauth_target_allows_public_hostname() {
4010 screen_oauth_target(
4011 "https://example.com/.well-known/jwks.json",
4012 false,
4013 &crate::ssrf::CompiledSsrfAllowlist::default(),
4014 )
4015 .await
4016 .expect("public hostname should pass screening");
4017 }
4018
4019 fn make_allowlist(hosts: &[&str], cidrs: &[&str]) -> crate::ssrf::CompiledSsrfAllowlist {
4025 let raw = OAuthSsrfAllowlist {
4026 hosts: hosts.iter().map(|s| (*s).to_owned()).collect(),
4027 cidrs: cidrs.iter().map(|s| (*s).to_owned()).collect(),
4028 };
4029 compile_oauth_ssrf_allowlist(&raw).expect("test allowlist compiles")
4030 }
4031
4032 #[test]
4033 fn compile_oauth_ssrf_allowlist_lowercases_and_dedupes_hosts() {
4034 let raw = OAuthSsrfAllowlist {
4035 hosts: vec!["RHBK.ops.example.com".into(), "rhbk.ops.example.com".into()],
4036 cidrs: vec![],
4037 };
4038 let compiled = compile_oauth_ssrf_allowlist(&raw).expect("compiles");
4039 assert_eq!(compiled.host_count(), 1);
4040 assert!(compiled.host_allowed("rhbk.ops.example.com"));
4041 assert!(compiled.host_allowed("RHBK.OPS.EXAMPLE.COM"));
4042 }
4043
4044 #[test]
4045 fn compile_oauth_ssrf_allowlist_rejects_literal_ip_in_hosts() {
4046 let raw = OAuthSsrfAllowlist {
4047 hosts: vec!["10.0.0.1".into()],
4048 cidrs: vec![],
4049 };
4050 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("literal IP in hosts");
4051 assert!(err.contains("literal IPs are forbidden"), "got {err:?}");
4052 }
4053
4054 #[test]
4055 fn compile_oauth_ssrf_allowlist_rejects_host_with_port() {
4056 let raw = OAuthSsrfAllowlist {
4057 hosts: vec!["rhbk.ops.example.com:8443".into()],
4058 cidrs: vec![],
4059 };
4060 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("host:port");
4061 assert!(err.contains("must be a bare DNS hostname"), "got {err:?}");
4062 }
4063
4064 #[test]
4065 fn compile_oauth_ssrf_allowlist_rejects_invalid_cidr() {
4066 let raw = OAuthSsrfAllowlist {
4067 hosts: vec![],
4068 cidrs: vec!["not-a-cidr".into()],
4069 };
4070 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("invalid CIDR");
4071 assert!(err.contains("oauth.ssrf_allowlist.cidrs[0]"), "got {err:?}");
4072 }
4073
4074 #[test]
4075 fn validate_rejects_misconfigured_allowlist() {
4076 let mut cfg = OAuthConfig::builder(
4077 "https://auth.example.com/",
4078 "mcp",
4079 "https://auth.example.com/jwks.json",
4080 )
4081 .build();
4082 cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
4083 hosts: vec!["10.0.0.1".into()],
4084 cidrs: vec![],
4085 });
4086 let err = cfg
4087 .validate()
4088 .expect_err("literal IP host must be rejected");
4089 assert!(
4090 err.to_string().contains("oauth.ssrf_allowlist"),
4091 "got {err}"
4092 );
4093 }
4094
4095 #[tokio::test]
4096 async fn screen_oauth_target_with_allowlist_emits_helpful_error() {
4097 let allow = make_allowlist(&["other.example.com"], &["10.0.0.0/8"]);
4101 let err = screen_oauth_target("https://localhost/jwks.json", false, &allow)
4102 .await
4103 .expect_err("loopback must still be blocked when not in allowlist");
4104 let msg = err.to_string();
4105 assert!(msg.contains("OAuth target blocked"), "got {msg:?}");
4106 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4107 assert!(msg.contains("SECURITY.md"), "got {msg:?}");
4108 }
4109
4110 #[tokio::test]
4111 async fn screen_oauth_target_empty_allowlist_uses_legacy_message() {
4112 let err = screen_oauth_target(
4115 "https://localhost/jwks.json",
4116 false,
4117 &crate::ssrf::CompiledSsrfAllowlist::default(),
4118 )
4119 .await
4120 .expect_err("loopback rejection");
4121 let msg = err.to_string();
4122 assert!(msg.contains("blocked IP"), "got {msg:?}");
4123 assert!(msg.contains("loopback"), "got {msg:?}");
4124 assert!(!msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4126 }
4127
4128 #[tokio::test]
4129 async fn screen_oauth_target_allows_loopback_when_host_allowlisted() {
4130 let allow = make_allowlist(&["localhost"], &[]);
4132 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4133 .await
4134 .expect("allowlisted host must pass");
4135 }
4136
4137 #[tokio::test]
4138 async fn screen_oauth_target_allows_loopback_when_cidr_allowlisted() {
4139 let allow = make_allowlist(&[], &["127.0.0.0/8", "::1/128"]);
4142 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4143 .await
4144 .expect("allowlisted CIDR must pass");
4145 }
4146
4147 #[tokio::test]
4148 async fn jwks_cache_rejects_misconfigured_allowlist_at_startup() {
4149 let mut cfg = OAuthConfig::builder(
4150 "https://auth.example.com/",
4151 "mcp",
4152 "https://auth.example.com/jwks.json",
4153 )
4154 .build();
4155 cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
4156 hosts: vec![],
4157 cidrs: vec!["bad-cidr".into()],
4158 });
4159 let Err(err) = JwksCache::new(&cfg) else {
4160 panic!("invalid CIDR must fail JwksCache::new")
4161 };
4162 let msg = err.to_string();
4163 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4164 }
4165
4166 #[tokio::test]
4167 async fn jwks_cache_new_invalid_ttl_is_err() {
4168 let cfg = OAuthConfig::builder(
4171 "https://auth.example.com/",
4172 "mcp",
4173 "https://auth.example.com/jwks.json",
4174 )
4175 .jwks_cache_ttl("not-a-duration")
4176 .build();
4177 let Err(err) = JwksCache::new(&cfg) else {
4178 panic!("invalid jwks_cache_ttl must fail JwksCache::new")
4179 };
4180 let msg = err.to_string();
4181 assert!(msg.contains("jwks_cache_ttl"), "got {msg:?}");
4182 }
4183
4184 #[tokio::test]
4185 async fn audience_falls_back_to_azp_by_default() {
4186 let kid = "test-audience-azp-default";
4187 let (pem, jwks) = generate_test_keypair(kid);
4188
4189 let mock_server = wiremock::MockServer::start().await;
4190 wiremock::Mock::given(wiremock::matchers::method("GET"))
4191 .and(wiremock::matchers::path("/jwks.json"))
4192 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4193 .mount(&mock_server)
4194 .await;
4195
4196 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4197 let config = test_config(&jwks_uri);
4198 let cache = test_cache(&config);
4199
4200 let now = jsonwebtoken::get_current_timestamp();
4201 let token = mint_token_with_claims(
4202 &pem,
4203 kid,
4204 &serde_json::json!({
4205 "iss": "https://auth.test.local",
4206 "aud": "https://some-other-resource.example.com",
4207 "azp": "https://mcp.test.local/mcp",
4208 "sub": "compat-client",
4209 "scope": "mcp:read",
4210 "exp": now + 3600,
4211 "iat": now,
4212 }),
4213 );
4214
4215 let identity = cache
4216 .validate_token_with_reason(&token)
4217 .await
4218 .expect("azp fallback should remain enabled by default");
4219 assert_eq!(identity.role, "viewer");
4220 }
4221
4222 #[tokio::test]
4223 async fn strict_audience_validation_rejects_azp_only_match() {
4224 let kid = "test-audience-azp-strict";
4225 let (pem, jwks) = generate_test_keypair(kid);
4226
4227 let mock_server = wiremock::MockServer::start().await;
4228 wiremock::Mock::given(wiremock::matchers::method("GET"))
4229 .and(wiremock::matchers::path("/jwks.json"))
4230 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4231 .mount(&mock_server)
4232 .await;
4233
4234 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4235 let mut config = test_config(&jwks_uri);
4236 #[allow(deprecated, reason = "covers the legacy bool resolution path")]
4237 {
4238 config.strict_audience_validation = true;
4239 }
4240 let cache = test_cache(&config);
4241
4242 let now = jsonwebtoken::get_current_timestamp();
4243 let token = mint_token_with_claims(
4244 &pem,
4245 kid,
4246 &serde_json::json!({
4247 "iss": "https://auth.test.local",
4248 "aud": "https://some-other-resource.example.com",
4249 "azp": "https://mcp.test.local/mcp",
4250 "sub": "strict-client",
4251 "scope": "mcp:read",
4252 "exp": now + 3600,
4253 "iat": now,
4254 }),
4255 );
4256
4257 let failure = cache
4258 .validate_token_with_reason(&token)
4259 .await
4260 .expect_err("strict audience validation must ignore azp fallback");
4261 assert_eq!(failure, JwtValidationFailure::Invalid);
4262 }
4263
4264 #[tokio::test]
4265 async fn warn_mode_accepts_azp_only_match_and_warns_once() {
4266 let kid = "test-audience-warn-mode";
4267 let (pem, jwks) = generate_test_keypair(kid);
4268
4269 let mock_server = wiremock::MockServer::start().await;
4270 wiremock::Mock::given(wiremock::matchers::method("GET"))
4271 .and(wiremock::matchers::path("/jwks.json"))
4272 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4273 .mount(&mock_server)
4274 .await;
4275
4276 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4277 let mut config = test_config(&jwks_uri);
4278 config.audience_validation_mode = Some(AudienceValidationMode::Warn);
4279 let cache = test_cache(&config);
4280
4281 let now = jsonwebtoken::get_current_timestamp();
4282 let claims = serde_json::json!({
4283 "iss": "https://auth.test.local",
4284 "aud": "https://some-other-resource.example.com",
4285 "azp": "https://mcp.test.local/mcp",
4286 "sub": "warn-client",
4287 "scope": "mcp:read",
4288 "exp": now + 3600,
4289 "iat": now,
4290 });
4291 let token = mint_token_with_claims(&pem, kid, &claims);
4292
4293 let identity = cache
4294 .validate_token_with_reason(&token)
4295 .await
4296 .expect("warn mode must accept azp-only match");
4297 assert_eq!(identity.role, "viewer");
4298 assert!(
4299 cache.azp_fallback_warned.load(Ordering::Relaxed),
4300 "warn-once flag should be set after first azp-only match"
4301 );
4302
4303 let token2 = mint_token_with_claims(&pem, kid, &claims);
4304 cache
4305 .validate_token_with_reason(&token2)
4306 .await
4307 .expect("warn mode must continue accepting subsequent matches");
4308 assert!(
4309 cache.azp_fallback_warned.load(Ordering::Relaxed),
4310 "warn-once flag must remain set; the assertion guards against accidental clearing"
4311 );
4312 }
4313
4314 #[tokio::test]
4315 async fn permissive_mode_accepts_azp_only_match_silently() {
4316 let kid = "test-audience-permissive-mode";
4317 let (pem, jwks) = generate_test_keypair(kid);
4318
4319 let mock_server = wiremock::MockServer::start().await;
4320 wiremock::Mock::given(wiremock::matchers::method("GET"))
4321 .and(wiremock::matchers::path("/jwks.json"))
4322 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4323 .mount(&mock_server)
4324 .await;
4325
4326 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4327 let mut config = test_config(&jwks_uri);
4328 config.audience_validation_mode = Some(AudienceValidationMode::Permissive);
4329 let cache = test_cache(&config);
4330
4331 let now = jsonwebtoken::get_current_timestamp();
4332 let token = mint_token_with_claims(
4333 &pem,
4334 kid,
4335 &serde_json::json!({
4336 "iss": "https://auth.test.local",
4337 "aud": "https://some-other-resource.example.com",
4338 "azp": "https://mcp.test.local/mcp",
4339 "sub": "permissive-client",
4340 "scope": "mcp:read",
4341 "exp": now + 3600,
4342 "iat": now,
4343 }),
4344 );
4345
4346 cache
4347 .validate_token_with_reason(&token)
4348 .await
4349 .expect("permissive mode must accept azp-only match");
4350 assert!(
4351 !cache.azp_fallback_warned.load(Ordering::Relaxed),
4352 "permissive mode must not flip the warn-once flag"
4353 );
4354 }
4355
/// An explicit `audience_validation_mode` wins over the deprecated
/// `strict_audience_validation` bool in both directions.
#[test]
fn audience_validation_mode_overrides_legacy_bool() {
    // Legacy bool says "not strict", explicit mode says Strict.
    let mut legacy_false = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
    {
        legacy_false.strict_audience_validation = false;
    }
    legacy_false.audience_validation_mode = Some(AudienceValidationMode::Strict);
    let resolved = legacy_false.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Strict,
        "explicit mode must override legacy false"
    );

    // Legacy bool says "strict", explicit mode says Permissive.
    let mut legacy_true = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
    {
        legacy_true.strict_audience_validation = true;
    }
    legacy_true.audience_validation_mode = Some(AudienceValidationMode::Permissive);
    let resolved = legacy_true.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Permissive,
        "explicit mode must override legacy true"
    );
}
4382
/// A default `OAuthConfig` resolves its effective audience mode to `Warn`.
#[test]
fn audience_validation_mode_default_is_warn_when_unset() {
    let resolved = OAuthConfig::default().effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Warn,
        "unset mode + unset bool must resolve to Warn (the new default)"
    );
}
4392
/// With no explicit mode set, the deprecated `strict_audience_validation =
/// true` resolves to `AudienceValidationMode::Strict`.
#[test]
fn audience_validation_legacy_bool_true_resolves_to_strict() {
    let mut legacy = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the legacy bool resolution path")]
    {
        legacy.strict_audience_validation = true;
    }
    let resolved = legacy.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Strict,
        "legacy bool=true must resolve to Strict for backward compat"
    );
}
4406
/// Shared in-memory byte buffer used by tests to capture log output.
///
/// Cloning clones the inner `Arc`, so every clone reads the same buffer.
#[derive(Clone, Default)]
struct CapturedLogs(Arc<std::sync::Mutex<Vec<u8>>>);

impl CapturedLogs {
    /// Returns everything captured so far as text.
    ///
    /// Invalid UTF-8 sequences are replaced with U+FFFD instead of discarding
    /// the whole capture, and a poisoned mutex is recovered rather than
    /// treated as "no logs". Either condition previously produced an empty
    /// string, which made log assertions fail with no hint of what had
    /// actually been written.
    fn contents(&self) -> String {
        let bytes = self
            .0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        String::from_utf8_lossy(bytes.as_slice()).into_owned()
    }
}
4416
/// `io::Write` handle that appends into a shared, mutex-guarded byte buffer.
struct CapturedLogsWriter(Arc<std::sync::Mutex<Vec<u8>>>);

impl std::io::Write for CapturedLogsWriter {
    /// Appends `buf` to the shared buffer and reports the full length as
    /// written. If the lock cannot be taken the bytes are dropped, but the
    /// call still reports success.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = buf.len();
        match self.0.lock() {
            Ok(mut sink) => sink.extend_from_slice(buf),
            Err(_poisoned) => {}
        }
        Ok(accepted)
    }

    /// Nothing is buffered locally, so flushing is a no-op.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
4431
4432 impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CapturedLogs {
4433 type Writer = CapturedLogsWriter;
4434
4435 fn make_writer(&'a self) -> Self::Writer {
4436 CapturedLogsWriter(Arc::clone(&self.0))
4437 }
4438 }
4439
4440 #[tokio::test]
4441 async fn jwks_response_size_cap_returns_none_and_logs_warning() {
4442 let kid = "oversized-jwks";
4443 let (_pem, jwks) = generate_test_keypair(kid);
4444 let mut oversized_body = serde_json::to_string(&jwks).expect("jwks json");
4445 oversized_body.push_str(&" ".repeat(4096));
4446
4447 let mock_server = wiremock::MockServer::start().await;
4448 wiremock::Mock::given(wiremock::matchers::method("GET"))
4449 .and(wiremock::matchers::path("/jwks.json"))
4450 .respond_with(
4451 wiremock::ResponseTemplate::new(200)
4452 .insert_header("content-type", "application/json")
4453 .set_body_string(oversized_body),
4454 )
4455 .mount(&mock_server)
4456 .await;
4457
4458 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4459 let mut config = test_config(&jwks_uri);
4460 config.jwks_max_response_bytes = 256;
4461 let cache = test_cache(&config);
4462
4463 let logs = CapturedLogs::default();
4464 let subscriber = tracing_subscriber::fmt()
4465 .with_writer(logs.clone())
4466 .with_ansi(false)
4467 .without_time()
4468 .finish();
4469 let _guard = tracing::subscriber::set_default(subscriber);
4470
4471 let result = cache.fetch_jwks().await;
4472 assert!(result.is_none(), "oversized JWKS must be dropped");
4473 assert!(
4474 logs.contents()
4475 .contains("JWKS response exceeded configured size cap"),
4476 "expected cap-exceeded warning in logs"
4477 );
4478 }
4479
4480 #[tokio::test]
4484 async fn redirect_rejection_log_does_not_echo_credentials() {
4485 let mock_server = wiremock::MockServer::start().await;
4486 wiremock::Mock::given(wiremock::matchers::method("GET"))
4487 .and(wiremock::matchers::path("/jwks.json"))
4488 .respond_with(
4489 wiremock::ResponseTemplate::new(302)
4490 .insert_header("location", "https://u:p@redirect-target.example/next"),
4491 )
4492 .mount(&mock_server)
4493 .await;
4494
4495 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4496 let config = test_config(&jwks_uri);
4497 let cache = test_cache(&config);
4498
4499 let logs = CapturedLogs::default();
4500 let subscriber = tracing_subscriber::fmt()
4501 .with_writer(logs.clone())
4502 .with_ansi(false)
4503 .without_time()
4504 .finish();
4505 let _guard = tracing::subscriber::set_default(subscriber);
4506
4507 let result = cache.fetch_jwks().await;
4508 assert!(result.is_none(), "rejected redirect must fail the fetch");
4509 let contents = logs.contents();
4510 assert!(
4511 contents.contains("oauth redirect rejected"),
4512 "expected redirect-rejection warning in logs: {contents}"
4513 );
4514 assert!(
4515 !contents.contains("u:p"),
4516 "rejection log must not echo userinfo credentials: {contents}"
4517 );
4518 }
4519
4520 #[tokio::test]
4521 async fn role_claim_keycloak_nested_array() {
4522 let kid = "test-role-1";
4523 let (pem, jwks) = generate_test_keypair(kid);
4524
4525 let mock_server = wiremock::MockServer::start().await;
4526 wiremock::Mock::given(wiremock::matchers::method("GET"))
4527 .and(wiremock::matchers::path("/jwks.json"))
4528 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4529 .mount(&mock_server)
4530 .await;
4531
4532 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4533 let config = test_config_with_role_claim(
4534 &jwks_uri,
4535 "realm_access.roles",
4536 vec![
4537 RoleMapping {
4538 claim_value: "mcp-admin".into(),
4539 role: "ops".into(),
4540 },
4541 RoleMapping {
4542 claim_value: "mcp-viewer".into(),
4543 role: "viewer".into(),
4544 },
4545 ],
4546 );
4547 let cache = test_cache(&config);
4548
4549 let now = jsonwebtoken::get_current_timestamp();
4550 let token = mint_token_with_claims(
4551 &pem,
4552 kid,
4553 &serde_json::json!({
4554 "iss": "https://auth.test.local",
4555 "aud": "https://mcp.test.local/mcp",
4556 "sub": "keycloak-user",
4557 "exp": now + 3600,
4558 "iat": now,
4559 "realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
4560 }),
4561 );
4562
4563 let id = cache
4564 .validate_token(&token)
4565 .await
4566 .expect("should authenticate");
4567 assert_eq!(id.name, "keycloak-user");
4568 assert_eq!(id.role, "ops");
4569 }
4570
4571 #[tokio::test]
4572 async fn role_claim_flat_roles_array() {
4573 let kid = "test-role-2";
4574 let (pem, jwks) = generate_test_keypair(kid);
4575
4576 let mock_server = wiremock::MockServer::start().await;
4577 wiremock::Mock::given(wiremock::matchers::method("GET"))
4578 .and(wiremock::matchers::path("/jwks.json"))
4579 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4580 .mount(&mock_server)
4581 .await;
4582
4583 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4584 let config = test_config_with_role_claim(
4585 &jwks_uri,
4586 "roles",
4587 vec![
4588 RoleMapping {
4589 claim_value: "MCP.Admin".into(),
4590 role: "ops".into(),
4591 },
4592 RoleMapping {
4593 claim_value: "MCP.Reader".into(),
4594 role: "viewer".into(),
4595 },
4596 ],
4597 );
4598 let cache = test_cache(&config);
4599
4600 let now = jsonwebtoken::get_current_timestamp();
4601 let token = mint_token_with_claims(
4602 &pem,
4603 kid,
4604 &serde_json::json!({
4605 "iss": "https://auth.test.local",
4606 "aud": "https://mcp.test.local/mcp",
4607 "sub": "azure-ad-user",
4608 "exp": now + 3600,
4609 "iat": now,
4610 "roles": ["MCP.Reader", "OtherApp.Admin"]
4611 }),
4612 );
4613
4614 let id = cache
4615 .validate_token(&token)
4616 .await
4617 .expect("should authenticate");
4618 assert_eq!(id.name, "azure-ad-user");
4619 assert_eq!(id.role, "viewer");
4620 }
4621
4622 #[tokio::test]
4623 async fn role_claim_no_matching_value_rejected() {
4624 let kid = "test-role-3";
4625 let (pem, jwks) = generate_test_keypair(kid);
4626
4627 let mock_server = wiremock::MockServer::start().await;
4628 wiremock::Mock::given(wiremock::matchers::method("GET"))
4629 .and(wiremock::matchers::path("/jwks.json"))
4630 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4631 .mount(&mock_server)
4632 .await;
4633
4634 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4635 let config = test_config_with_role_claim(
4636 &jwks_uri,
4637 "roles",
4638 vec![RoleMapping {
4639 claim_value: "mcp-admin".into(),
4640 role: "ops".into(),
4641 }],
4642 );
4643 let cache = test_cache(&config);
4644
4645 let now = jsonwebtoken::get_current_timestamp();
4646 let token = mint_token_with_claims(
4647 &pem,
4648 kid,
4649 &serde_json::json!({
4650 "iss": "https://auth.test.local",
4651 "aud": "https://mcp.test.local/mcp",
4652 "sub": "limited-user",
4653 "exp": now + 3600,
4654 "iat": now,
4655 "roles": ["some-other-role"]
4656 }),
4657 );
4658
4659 assert!(cache.validate_token(&token).await.is_none());
4660 }
4661
4662 #[tokio::test]
4663 async fn role_claim_space_separated_string() {
4664 let kid = "test-role-4";
4665 let (pem, jwks) = generate_test_keypair(kid);
4666
4667 let mock_server = wiremock::MockServer::start().await;
4668 wiremock::Mock::given(wiremock::matchers::method("GET"))
4669 .and(wiremock::matchers::path("/jwks.json"))
4670 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4671 .mount(&mock_server)
4672 .await;
4673
4674 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4675 let config = test_config_with_role_claim(
4676 &jwks_uri,
4677 "custom_scope",
4678 vec![
4679 RoleMapping {
4680 claim_value: "write".into(),
4681 role: "ops".into(),
4682 },
4683 RoleMapping {
4684 claim_value: "read".into(),
4685 role: "viewer".into(),
4686 },
4687 ],
4688 );
4689 let cache = test_cache(&config);
4690
4691 let now = jsonwebtoken::get_current_timestamp();
4692 let token = mint_token_with_claims(
4693 &pem,
4694 kid,
4695 &serde_json::json!({
4696 "iss": "https://auth.test.local",
4697 "aud": "https://mcp.test.local/mcp",
4698 "sub": "custom-client",
4699 "exp": now + 3600,
4700 "iat": now,
4701 "custom_scope": "read audit"
4702 }),
4703 );
4704
4705 let id = cache
4706 .validate_token(&token)
4707 .await
4708 .expect("should authenticate");
4709 assert_eq!(id.name, "custom-client");
4710 assert_eq!(id.role, "viewer");
4711 }
4712
4713 #[tokio::test]
4714 async fn scope_backward_compat_without_role_claim() {
4715 let kid = "test-compat-1";
4717 let (pem, jwks) = generate_test_keypair(kid);
4718
4719 let mock_server = wiremock::MockServer::start().await;
4720 wiremock::Mock::given(wiremock::matchers::method("GET"))
4721 .and(wiremock::matchers::path("/jwks.json"))
4722 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4723 .mount(&mock_server)
4724 .await;
4725
4726 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4727 let config = test_config(&jwks_uri); let cache = test_cache(&config);
4729
4730 let token = mint_token(
4731 &pem,
4732 kid,
4733 "https://auth.test.local",
4734 "https://mcp.test.local/mcp",
4735 "legacy-bot",
4736 "mcp:admin other:scope",
4737 );
4738
4739 let id = cache
4740 .validate_token(&token)
4741 .await
4742 .expect("should authenticate");
4743 assert_eq!(id.name, "legacy-bot");
4744 assert_eq!(id.role, "ops"); }
4746
4747 #[tokio::test]
4752 async fn jwks_refresh_deduplication() {
4753 let kid = "test-dedup";
4756 let (pem, jwks) = generate_test_keypair(kid);
4757
4758 let mock_server = wiremock::MockServer::start().await;
4759 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4760 .and(wiremock::matchers::path("/jwks.json"))
4761 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4762 .expect(1) .mount(&mock_server)
4764 .await;
4765
4766 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4767 let config = test_config(&jwks_uri);
4768 let cache = Arc::new(test_cache(&config));
4769
4770 let token = mint_token(
4772 &pem,
4773 kid,
4774 "https://auth.test.local",
4775 "https://mcp.test.local/mcp",
4776 "concurrent-bot",
4777 "mcp:read",
4778 );
4779
4780 let mut handles = Vec::new();
4781 for _ in 0..5 {
4782 let c = Arc::clone(&cache);
4783 let t = token.clone();
4784 handles.push(tokio::spawn(async move { c.validate_token(&t).await }));
4785 }
4786
4787 for h in handles {
4788 let result = h.await.unwrap();
4789 assert!(result.is_some(), "all concurrent requests should succeed");
4790 }
4791
4792 }
4794
4795 #[tokio::test]
4796 async fn jwks_refresh_cooldown_blocks_rapid_requests() {
4797 let kid = "test-cooldown";
4800 let (_pem, jwks) = generate_test_keypair(kid);
4801
4802 let mock_server = wiremock::MockServer::start().await;
4803 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4804 .and(wiremock::matchers::path("/jwks.json"))
4805 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4806 .expect(1) .mount(&mock_server)
4808 .await;
4809
4810 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4811 let config = test_config(&jwks_uri);
4812 let cache = test_cache(&config);
4813
4814 let fake_token1 =
4816 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig";
4817 let _ = cache.validate_token(fake_token1).await;
4818
4819 let fake_token2 =
4822 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig";
4823 let _ = cache.validate_token(fake_token2).await;
4824
4825 let fake_token3 =
4827 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig";
4828 let _ = cache.validate_token(fake_token3).await;
4829
4830 }
4832
4833 fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
4836 OAuthProxyConfig {
4837 authorize_url: "https://example.invalid/auth".into(),
4838 token_url: token_url.into(),
4839 client_id: "mcp-client".into(),
4840 client_secret: Some(secrecy::SecretString::from("shh".to_owned())),
4841 introspection_url: None,
4842 revocation_url: None,
4843 expose_admin_endpoints: false,
4844 require_auth_on_admin_endpoints: false,
4845 allow_unauthenticated_admin_endpoints: false,
4846 }
4847 }
4848
4849 fn test_http_client() -> OauthHttpClient {
4852 rustls::crypto::ring::default_provider()
4853 .install_default()
4854 .ok();
4855 let config = OAuthConfig::builder(
4856 "https://auth.test.local",
4857 "https://mcp.test.local/mcp",
4858 "https://auth.test.local/.well-known/jwks.json",
4859 )
4860 .allow_http_oauth_urls(true)
4861 .build();
4862 OauthHttpClient::with_config(&config)
4863 .expect("build test http client")
4864 .__test_allow_loopback_ssrf()
4865 }
4866
4867 #[tokio::test]
4868 async fn introspect_proxies_and_injects_client_credentials() {
4869 use wiremock::matchers::{body_string_contains, method, path};
4870
4871 let mock_server = wiremock::MockServer::start().await;
4872 wiremock::Mock::given(method("POST"))
4873 .and(path("/introspect"))
4874 .and(body_string_contains("client_id=mcp-client"))
4875 .and(body_string_contains("client_secret=shh"))
4876 .and(body_string_contains("token=abc"))
4877 .respond_with(
4878 wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
4879 "active": true,
4880 "scope": "read"
4881 })),
4882 )
4883 .expect(1)
4884 .mount(&mock_server)
4885 .await;
4886
4887 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4888 proxy.introspection_url = Some(format!("{}/introspect", mock_server.uri()));
4889
4890 let http = test_http_client();
4891 let resp = handle_introspect(&http, &proxy, "token=abc").await;
4892 assert_eq!(resp.status(), 200);
4893 }
4894
4895 #[tokio::test]
4896 async fn introspect_returns_404_when_not_configured() {
4897 let proxy = proxy_cfg("https://example.invalid/token");
4898 let http = test_http_client();
4899 let resp = handle_introspect(&http, &proxy, "token=abc").await;
4900 assert_eq!(resp.status(), 404);
4901 }
4902
4903 #[tokio::test]
4904 async fn revoke_proxies_and_returns_upstream_status() {
4905 use wiremock::matchers::{method, path};
4906
4907 let mock_server = wiremock::MockServer::start().await;
4908 wiremock::Mock::given(method("POST"))
4909 .and(path("/revoke"))
4910 .respond_with(wiremock::ResponseTemplate::new(200))
4911 .expect(1)
4912 .mount(&mock_server)
4913 .await;
4914
4915 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4916 proxy.revocation_url = Some(format!("{}/revoke", mock_server.uri()));
4917
4918 let http = test_http_client();
4919 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4920 assert_eq!(resp.status(), 200);
4921 }
4922
4923 #[tokio::test]
4924 async fn revoke_returns_404_when_not_configured() {
4925 let proxy = proxy_cfg("https://example.invalid/token");
4926 let http = test_http_client();
4927 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4928 assert_eq!(resp.status(), 404);
4929 }
4930
/// Walks `authorization_server_metadata` through four configurations and
/// checks which of `introspection_endpoint` / `revocation_endpoint` appear:
/// no proxy, proxy with admin endpoints hidden, admin endpoints exposed with
/// only introspection configured, and finally both configured.
#[test]
fn metadata_advertises_endpoints_only_when_configured() {
    const BASE: &str = "https://mcp.local";
    let mut cfg = test_config("https://auth.test.local/jwks.json");

    // 1. No proxy at all: neither endpoint is advertised.
    let meta = authorization_server_metadata(BASE, &cfg);
    assert!(meta.get("introspection_endpoint").is_none());
    assert!(meta.get("revocation_endpoint").is_none());

    // 2. Both upstream URLs set, but `expose_admin_endpoints` is still false.
    let mut proxy = proxy_cfg("https://upstream.local/token");
    proxy.introspection_url = Some("https://upstream.local/introspect".into());
    proxy.revocation_url = Some("https://upstream.local/revoke".into());
    cfg.proxy = Some(proxy);
    let meta = authorization_server_metadata(BASE, &cfg);
    assert!(
        meta.get("introspection_endpoint").is_none(),
        "introspection must not be advertised when expose_admin_endpoints=false"
    );
    assert!(
        meta.get("revocation_endpoint").is_none(),
        "revocation must not be advertised when expose_admin_endpoints=false"
    );

    // 3. Exposed, with the revocation URL removed: only introspection shows.
    if let Some(proxy) = &mut cfg.proxy {
        proxy.expose_admin_endpoints = true;
        proxy.revocation_url = None;
    }
    let meta = authorization_server_metadata(BASE, &cfg);
    let expected_introspect = serde_json::Value::String("https://mcp.local/introspect".into());
    assert_eq!(meta["introspection_endpoint"], expected_introspect);
    assert!(meta.get("revocation_endpoint").is_none());

    // 4. Revocation URL restored: the revocation endpoint is advertised too.
    if let Some(proxy) = &mut cfg.proxy {
        proxy.revocation_url = Some("https://upstream.local/revoke".into());
    }
    let meta = authorization_server_metadata(BASE, &cfg);
    let expected_revoke = serde_json::Value::String("https://mcp.local/revoke".into());
    assert_eq!(meta["revocation_endpoint"], expected_revoke);
}
4977
4978 fn https_cfg_with_tx(tx: TokenExchangeConfig) -> OAuthConfig {
4981 let mut cfg = validation_https_config();
4982 cfg.token_exchange = Some(tx);
4983 cfg
4984 }
4985
4986 fn tx_with(
4987 client_secret: Option<&str>,
4988 client_cert: Option<ClientCertConfig>,
4989 ) -> TokenExchangeConfig {
4990 TokenExchangeConfig::new(
4991 "https://idp.example.com/token".into(),
4992 "client".into(),
4993 client_secret.map(|s| secrecy::SecretString::new(s.into())),
4994 client_cert,
4995 "downstream".into(),
4996 )
4997 }
4998
/// A token-exchange config with neither a secret nor a client certificate
/// must fail validation with a "requires client authentication" error.
#[test]
fn validate_rejects_token_exchange_without_client_auth() {
    let no_auth = tx_with(None, None);
    let msg = https_cfg_with_tx(no_auth)
        .validate()
        .expect_err("token_exchange without client auth must be rejected")
        .to_string();
    let explains = msg.contains("requires client authentication");
    assert!(
        explains,
        "error must explain missing client auth; got {msg:?}"
    );
}
5011
/// Supplying both a client secret and a client certificate must fail
/// validation with an error containing "mutually" and "exclusive".
#[test]
fn validate_rejects_token_exchange_with_both_secret_and_cert() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let both = tx_with(Some("s"), Some(client_cert));
    let msg = https_cfg_with_tx(both)
        .validate()
        .expect_err("client_secret + client_cert must be rejected")
        .to_string();
    let explains = msg.contains("mutually") && msg.contains("exclusive");
    assert!(
        explains,
        "error must explain mutual exclusion; got {msg:?}"
    );
}
5028
/// Built without the `oauth-mtls-client` feature, a config that supplies a
/// client certificate must fail validation with an error naming that feature.
#[cfg(not(feature = "oauth-mtls-client"))]
#[test]
fn validate_rejects_client_cert_without_feature() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let cert_only = tx_with(None, Some(client_cert));
    let err = https_cfg_with_tx(cert_only)
        .validate()
        .expect_err("client_cert without feature must be rejected");
    let names_feature = err.to_string().contains("oauth-mtls-client");
    assert!(
        names_feature,
        "error must reference the cargo feature; got {err}"
    );
}
5045
/// With the `oauth-mtls-client` feature on, pointing at cert/key paths that
/// do not exist must fail validation with an "unreadable" error.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_missing_client_cert_files() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let cert_only = tx_with(None, Some(client_cert));
    let err = https_cfg_with_tx(cert_only)
        .validate()
        .expect_err("missing cert file must be rejected");
    let calls_out_unreadable = err.to_string().contains("unreadable");
    assert!(
        calls_out_unreadable,
        "error must call out unreadable file; got {err}"
    );
}
5062
/// With the `oauth-mtls-client` feature on, cert/key files that exist but do
/// not contain PEM must fail validation with a "PEM parse failed" error.
///
/// The temporary files are removed *before* the result is asserted, so a
/// failing run does not leave them behind in the temp directory (the same
/// order `validate_accepts_well_formed_client_cert` uses).
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_malformed_client_cert_pem() {
    let dir = std::env::temp_dir();
    // File names are keyed by process id only; no other test writes these
    // particular names.
    let cert = dir.join(format!("rmcp-mtls-bad-cert-{}.pem", std::process::id()));
    let key = dir.join(format!("rmcp-mtls-bad-key-{}.pem", std::process::id()));
    std::fs::write(&cert, b"not a real PEM").expect("write tmp cert");
    std::fs::write(&key, b"not a real PEM either").expect("write tmp key");
    let cc = ClientCertConfig {
        cert_path: cert.clone(),
        key_path: key.clone(),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));

    // Capture the outcome first; `expect_err` would panic past the cleanup.
    let outcome = cfg.validate();
    let _ = std::fs::remove_file(&cert);
    let _ = std::fs::remove_file(&key);

    let err = outcome.expect_err("malformed PEM must be rejected");
    assert!(
        err.to_string().contains("PEM parse failed"),
        "error must call out PEM parse failure; got {err}"
    );
}
5084
/// Generates a self-signed certificate for `client.test`, writes the cert
/// and key as PEM files into the temp directory, and returns
/// `(cert_path, key_path)`. File names combine the process id with a random
/// nonce; callers are responsible for removing the files.
#[cfg(feature = "oauth-mtls-client")]
fn write_self_signed_pem() -> (PathBuf, PathBuf) {
    let generated = rcgen::generate_simple_self_signed(vec!["client.test".into()]).expect("rcgen");

    let tmp = std::env::temp_dir();
    let pid = std::process::id();
    let nonce = rand::random::<u64>();
    let cert_path = tmp.join(format!("rmcp-mtls-cert-{pid}-{nonce}.pem"));
    let key_path = tmp.join(format!("rmcp-mtls-key-{pid}-{nonce}.pem"));

    std::fs::write(&cert_path, generated.cert.pem()).expect("write cert");
    std::fs::write(&key_path, generated.signing_key.serialize_pem()).expect("write key");
    (cert_path, key_path)
}
5097
/// Installs the `ring` rustls crypto provider as the default, ignoring the
/// result so the helper can be called from several tests.
#[cfg(feature = "oauth-mtls-client")]
fn install_test_crypto_provider() {
    let provider = rustls::crypto::ring::default_provider();
    let _ = provider.install_default();
}
5102
/// A freshly generated self-signed cert/key pair must pass validation.
/// The temp files are removed before the result is asserted.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_accepts_well_formed_client_cert() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();
    let client_cert = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };
    let config = https_cfg_with_tx(tx_with(None, Some(client_cert)));

    let outcome = config.validate();
    for leftover in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(leftover);
    }
    outcome.expect("well-formed cert+key must validate");
}
5118
/// For a config built with a client certificate, `client_for` must return a
/// different client for the cert-bearing token-exchange config than for a
/// secret-only one (compared by pointer identity).
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_returns_cached_mtls_client() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();
    let client_cert = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };
    let config = https_cfg_with_tx(tx_with(None, Some(client_cert)));
    let http = OauthHttpClient::with_config(&config).expect("build mtls client");

    let with_cert = config.token_exchange.as_ref().expect("tx set");
    let secret_only = tx_with(Some("s"), None);
    let cert_client = http.client_for(with_cert);
    let inner_client = http.client_for(&secret_only);

    // Clean up the temp files before asserting.
    for leftover in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(leftover);
    }
    let same_client = std::ptr::eq(cert_client, inner_client);
    assert!(
        !same_client,
        "client_for must return distinct clients for cert vs no-cert configs"
    );
}
5140
/// When asked for a client-cert config the HTTP client was not built with,
/// `client_for` must return the same client it returns for a secret-only
/// config (compared by pointer identity).
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_falls_back_to_inner_when_cache_miss() {
    install_test_crypto_provider();
    let config = validation_https_config();
    let http = OauthHttpClient::with_config(&config).expect("build client");

    // These paths are not part of `config`.
    let unknown_cert = ClientCertConfig {
        cert_path: PathBuf::from("/cache/miss/cert.pem"),
        key_path: PathBuf::from("/cache/miss/key.pem"),
    };
    let unknown_tx = tx_with(None, Some(unknown_cert));
    let secret_only = tx_with(Some("s"), None);

    let fallback = http.client_for(&unknown_tx);
    let inner = http.client_for(&secret_only);
    let same_client = std::ptr::eq(fallback, inner);
    assert!(same_client, "cache miss must fall back to inner client");
}
5159}