1use std::{
17 collections::HashMap,
18 path::PathBuf,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering},
22 },
23 time::{Duration, Instant},
24};
25
26use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
27use serde::Deserialize;
28use tokio::{net::lookup_host, sync::RwLock};
29
30use crate::auth::{AuthIdentity, AuthMethod};
31
32fn evaluate_oauth_redirect(
58 attempt: &reqwest::redirect::Attempt<'_>,
59 allow_http: bool,
60 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
61) -> Result<(), String> {
62 let prev_https = attempt
63 .previous()
64 .last()
65 .is_some_and(|prev| prev.scheme() == "https");
66 let target_url = attempt.url();
67 let dest_scheme = target_url.scheme();
68 if dest_scheme != "https" {
69 if prev_https {
70 return Err("redirect downgrades https -> http".to_owned());
71 }
72 if !allow_http || dest_scheme != "http" {
73 return Err("redirect to non-HTTP(S) URL refused".to_owned());
74 }
75 }
76 if let Some(reason) = crate::ssrf::redirect_target_reason_with_allowlist(target_url, allowlist)
77 {
78 return Err(format!("redirect target forbidden: {reason}"));
79 }
80 if attempt.previous().len() >= 2 {
81 return Err("too many redirects (max 2)".to_owned());
82 }
83 Ok(())
84}
85
86async fn screen_oauth_target_core(
106 url: &str,
107 allow_http: bool,
108 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
109 test_allow_loopback_ssrf: bool,
110) -> Result<(), crate::error::McpxError> {
111 let parsed = check_oauth_url("oauth target", url, allow_http)?;
112 if test_allow_loopback_ssrf {
113 return Ok(());
114 }
115 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
116 return Err(crate::error::McpxError::Config(format!(
117 "OAuth target forbidden ({reason}): {url}"
118 )));
119 }
120
121 let host = parsed.host_str().ok_or_else(|| {
122 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
123 })?;
124 let port = parsed.port_or_known_default().ok_or_else(|| {
125 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
126 })?;
127
128 let addrs = lookup_host((host, port)).await.map_err(|error| {
129 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
130 })?;
131
132 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
133 let mut any_addr = false;
134 for addr in addrs {
135 any_addr = true;
136 let ip = addr.ip();
137 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
138 if reason == "cloud_metadata" {
141 return Err(crate::error::McpxError::Config(format!(
142 "OAuth target resolved to blocked IP ({reason}): {url}"
143 )));
144 }
145 if allowlist.is_empty() {
149 return Err(crate::error::McpxError::Config(format!(
150 "OAuth target resolved to blocked IP ({reason}): {url}"
151 )));
152 }
153 if host_allowed || allowlist.ip_allowed(ip) {
155 continue;
156 }
157 return Err(crate::error::McpxError::Config(format!(
158 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
159 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
160 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
161 URL: {url}"
162 )));
163 }
164 }
165 if !any_addr {
166 return Err(crate::error::McpxError::Config(format!(
167 "OAuth target DNS resolution returned no addresses: {url}"
168 )));
169 }
170
171 Ok(())
172}
173
174async fn screen_oauth_target(
177 url: &str,
178 allow_http: bool,
179 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
180) -> Result<(), crate::error::McpxError> {
181 screen_oauth_target_core(url, allow_http, allowlist, false).await
182}
183
/// Test-only variant of OAuth target screening.
///
/// Forwards to [`screen_oauth_target_core`], letting the caller decide whether
/// the post-URL-validation SSRF checks are skipped via
/// `test_allow_loopback_ssrf`.
///
/// # Errors
///
/// Propagates any error returned by [`screen_oauth_target_core`].
#[cfg(any(test, feature = "test-helpers"))]
async fn screen_oauth_target_with_test_override(
    url: &str,
    allow_http: bool,
    allowlist: &crate::ssrf::CompiledSsrfAllowlist,
    test_allow_loopback_ssrf: bool,
) -> Result<(), crate::error::McpxError> {
    let screening = screen_oauth_target_core(url, allow_http, allowlist, test_allow_loopback_ssrf);
    screening.await
}
196
/// HTTP client used for outbound OAuth traffic.
///
/// Wraps a `reqwest::Client` configured with an SSRF-screening DNS resolver
/// and a custom redirect policy, together with the settings needed to screen
/// each target URL before a request is sent.
#[derive(Clone)]
pub struct OauthHttpClient {
    /// Default client used when no mTLS identity applies.
    inner: reqwest::Client,
    /// Whether plain `http` OAuth URLs are accepted (taken from
    /// `OAuthConfig::allow_http_oauth_urls`; `false` without a config).
    allow_http: bool,
    /// Compiled operator SSRF allowlist shared with the resolver and the
    /// redirect policy.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Clients carrying a client-certificate identity, keyed by the
    /// certificate/key path pair they were built from.
    #[cfg(feature = "oauth-mtls-client")]
    mtls_clients: Arc<HashMap<MtlsClientKey, reqwest::Client>>,
    /// Test-only switch that, once set, makes target screening skip its
    /// SSRF checks.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
260
/// Lookup key for the per-identity mTLS client map: the certificate and
/// private-key paths the client identity was loaded from.
#[cfg(feature = "oauth-mtls-client")]
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct MtlsClientKey {
    /// Path of the client certificate PEM file.
    cert_path: PathBuf,
    /// Path of the client private-key PEM file.
    key_path: PathBuf,
}
270
271impl OauthHttpClient {
272 pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
290 Self::build(Some(config))
291 }
292
293 #[deprecated(
316 since = "1.2.1",
317 note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
318 )]
319 pub fn new() -> Result<Self, crate::error::McpxError> {
320 Self::build(None)
321 }
322
323 fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
326 let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);
327
328 let allowlist = match config.and_then(|c| c.ssrf_allowlist.as_ref()) {
333 Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
334 crate::error::McpxError::Startup(format!("oauth http client: {e}"))
335 })?),
336 None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
337 };
338
339 let redirect_allowlist = Arc::clone(&allowlist);
342
343 #[cfg(any(test, feature = "test-helpers"))]
347 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
348 Arc::new(AtomicBool::new(false));
349 #[cfg(not(any(test, feature = "test-helpers")))]
350 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();
351
352 let resolver: Arc<dyn reqwest::dns::Resolve> =
353 Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
354 Arc::clone(&allowlist),
355 #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
359 test_bypass.clone(),
360 ));
361
362 let mut builder = reqwest::Client::builder()
363 .no_proxy()
367 .dns_resolver(Arc::clone(&resolver))
368 .connect_timeout(Duration::from_secs(10))
369 .timeout(Duration::from_secs(30))
370 .redirect(reqwest::redirect::Policy::custom(move |attempt| {
371 match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
381 Ok(()) => attempt.follow(),
382 Err(reason) => {
383 tracing::warn!(
384 reason = %reason,
385 target = %attempt.url(),
386 "oauth redirect rejected"
387 );
388 attempt.error(reason)
389 }
390 }
391 }));
392
393 if let Some(cfg) = config
394 && let Some(ref ca_path) = cfg.ca_cert_path
395 {
396 let pem = std::fs::read(ca_path).map_err(|e| {
401 crate::error::McpxError::Startup(format!(
402 "oauth http client: read ca_cert_path {}: {e}",
403 ca_path.display()
404 ))
405 })?;
406 let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
407 crate::error::McpxError::Startup(format!(
408 "oauth http client: parse ca_cert_path {}: {e}",
409 ca_path.display()
410 ))
411 })?;
412 builder = builder.add_root_certificate(cert);
413 }
414
415 let inner = builder.build().map_err(|e| {
416 crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
417 })?;
418
419 #[cfg(feature = "oauth-mtls-client")]
420 let mtls_clients = build_mtls_clients(config, &allowlist, &test_bypass)?;
421
422 Ok(Self {
423 inner,
424 allow_http,
425 allowlist,
426 #[cfg(feature = "oauth-mtls-client")]
427 mtls_clients,
428 #[cfg(any(test, feature = "test-helpers"))]
429 test_allow_loopback_ssrf: test_bypass,
430 })
431 }
432
433 async fn send_screened(
434 &self,
435 url: &str,
436 request: reqwest::RequestBuilder,
437 ) -> Result<reqwest::Response, crate::error::McpxError> {
438 #[cfg(any(test, feature = "test-helpers"))]
439 if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
440 screen_oauth_target_with_test_override(url, self.allow_http, &self.allowlist, true)
441 .await?;
442 } else {
443 screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
444 }
445 #[cfg(not(any(test, feature = "test-helpers")))]
446 screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
447 request.send().await.map_err(|error| {
448 crate::error::McpxError::Config(format!("oauth request {url}: {error}"))
449 })
450 }
451
452 #[cfg(any(test, feature = "test-helpers"))]
457 #[doc(hidden)]
458 #[must_use]
459 pub fn __test_allow_loopback_ssrf(self) -> Self {
460 self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
463 self
464 }
465
466 #[doc(hidden)]
472 pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
473 self.inner.get(url).send().await
474 }
475
476 #[cfg(any(test, feature = "test-helpers"))]
482 #[doc(hidden)]
483 #[must_use]
484 pub fn __test_inner_client(&self) -> &reqwest::Client {
485 &self.inner
486 }
487
488 #[cfg(feature = "oauth-mtls-client")]
495 fn client_for(&self, cfg: &TokenExchangeConfig) -> &reqwest::Client {
496 if let Some(cc) = &cfg.client_cert {
497 let key = MtlsClientKey {
498 cert_path: cc.cert_path.clone(),
499 key_path: cc.key_path.clone(),
500 };
501 if let Some(client) = self.mtls_clients.get(&key) {
502 return client;
503 }
504 }
505 &self.inner
506 }
507
508 #[cfg(not(feature = "oauth-mtls-client"))]
509 fn client_for(&self, _cfg: &TokenExchangeConfig) -> &reqwest::Client {
510 &self.inner
511 }
512}
513
514impl std::fmt::Debug for OauthHttpClient {
515 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516 f.debug_struct("OauthHttpClient").finish_non_exhaustive()
517 }
518}
519
/// Operator-supplied SSRF allowlist for OAuth/JWKS targets, as written in
/// configuration. It is validated and compiled by
/// `compile_oauth_ssrf_allowlist` before use.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct OAuthSsrfAllowlist {
    /// Bare DNS hostnames (no scheme, port, path, userinfo, query or
    /// fragment; literal IPs are rejected here).
    #[serde(default)]
    pub hosts: Vec<String>,
    /// CIDR entries, each parsed with `crate::ssrf::CidrEntry::parse`.
    #[serde(default)]
    pub cidrs: Vec<String>,
}
596
597fn compile_oauth_ssrf_allowlist(
604 raw: &OAuthSsrfAllowlist,
605) -> Result<crate::ssrf::CompiledSsrfAllowlist, String> {
606 let mut hosts: Vec<String> = Vec::with_capacity(raw.hosts.len());
607 for (idx, entry) in raw.hosts.iter().enumerate() {
608 let trimmed = entry.trim();
609 if trimmed.is_empty() {
610 return Err(format!("oauth.ssrf_allowlist.hosts[{idx}]: empty entry"));
611 }
612 if trimmed.contains([':', '/', '@', '?', '#']) {
616 return Err(format!(
617 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: must be a bare DNS hostname \
618 (no scheme, port, path, userinfo, query, or fragment)"
619 ));
620 }
621 match url::Host::parse(trimmed) {
622 Ok(url::Host::Domain(_)) => {}
623 Ok(url::Host::Ipv4(_) | url::Host::Ipv6(_)) => {
624 return Err(format!(
625 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: literal IPs are forbidden \
626 here -- list them via oauth.ssrf_allowlist.cidrs instead"
627 ));
628 }
629 Err(e) => {
630 return Err(format!(
631 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: invalid hostname: {e}"
632 ));
633 }
634 }
635 hosts.push(trimmed.to_ascii_lowercase());
636 }
637 hosts.sort();
638 hosts.dedup();
639
640 let mut cidrs = Vec::with_capacity(raw.cidrs.len());
641 for (idx, entry) in raw.cidrs.iter().enumerate() {
642 let parsed = crate::ssrf::CidrEntry::parse(entry)
643 .map_err(|e| format!("oauth.ssrf_allowlist.cidrs[{idx}]: {e}"))?;
644 cidrs.push(parsed);
645 }
646
647 Ok(crate::ssrf::CompiledSsrfAllowlist::new(hosts, cidrs))
648}
649
/// OAuth / JWT validation configuration.
///
/// Deserializable from configuration files; use [`OAuthConfig::builder`] for
/// programmatic construction and [`OAuthConfig::validate`] to check it.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
    /// Token issuer URL; screened by `validate` and used as the expected
    /// `iss` value when building the JWKS cache.
    pub issuer: String,
    /// Expected token audience.
    pub audience: String,
    /// URL of the JWKS document; screened by `validate`.
    pub jwks_uri: String,
    /// Scope-to-role mappings.
    #[serde(default)]
    pub scopes: Vec<ScopeMapping>,
    /// Optional name of the claim that carries role values.
    pub role_claim: Option<String>,
    /// Claim-value-to-role mappings.
    #[serde(default)]
    pub role_mappings: Vec<RoleMapping>,
    /// JWKS cache lifetime as a humantime duration string (default `"10m"`).
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl: String,
    /// Optional OAuth proxy configuration.
    pub proxy: Option<OAuthProxyConfig>,
    /// Optional token-exchange configuration.
    pub token_exchange: Option<TokenExchangeConfig>,
    /// Optional PEM file added as a root certificate to the OAuth HTTP
    /// clients.
    #[serde(default)]
    pub ca_cert_path: Option<PathBuf>,
    /// Permits plain `http` OAuth URLs when `true` (default `false`).
    #[serde(default)]
    pub allow_http_oauth_urls: bool,
    /// Optional operator SSRF allowlist for OAuth/JWKS targets.
    #[serde(default)]
    pub ssrf_allowlist: Option<OAuthSsrfAllowlist>,
    /// Upper bound on JWKS keys (default 256).
    #[serde(default = "default_max_jwks_keys")]
    pub max_jwks_keys: usize,
    /// Legacy audience-strictness flag; only consulted when
    /// `audience_validation_mode` is `None`.
    #[serde(default)]
    #[deprecated(
        since = "1.7.0",
        note = "use `audience_validation_mode` instead; this field is consulted only when `audience_validation_mode` is None"
    )]
    pub strict_audience_validation: bool,
    /// Explicit audience validation mode; takes precedence over the legacy
    /// flag when set.
    #[serde(default)]
    pub audience_validation_mode: Option<AudienceValidationMode>,
    /// Upper bound on the JWKS response size in bytes (default 1 MiB).
    #[serde(default = "default_jwks_max_bytes")]
    pub jwks_max_response_bytes: u64,
}
760
/// Serde default for `OAuthConfig::jwks_cache_ttl`: ten minutes, expressed as
/// a humantime string.
fn default_jwks_cache_ttl() -> String {
    String::from("10m")
}
764
/// Serde default for `OAuthConfig::max_jwks_keys`.
const fn default_max_jwks_keys() -> usize {
    const DEFAULT_MAX_JWKS_KEYS: usize = 256;
    DEFAULT_MAX_JWKS_KEYS
}
768
/// Serde default for `OAuthConfig::jwks_max_response_bytes`: one mebibyte.
const fn default_jwks_max_bytes() -> u64 {
    const ONE_MIB: u64 = 1 << 20;
    ONE_MIB
}
772
/// How strictly the token audience is validated.
///
/// Deserialized from `snake_case` names (`permissive`, `warn`, `strict`).
/// NOTE(review): the exact enforcement behind each mode lives outside this
/// section of the file -- confirm against the validation code before relying
/// on the descriptions below.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AudienceValidationMode {
    /// Most lenient mode.
    Permissive,
    /// Default mode; also the result of the legacy flag being `false`.
    #[default]
    Warn,
    /// Strictest mode; also the result of the legacy flag being `true`.
    Strict,
}
807
808impl Default for OAuthConfig {
809 fn default() -> Self {
810 Self {
811 issuer: String::new(),
812 audience: String::new(),
813 jwks_uri: String::new(),
814 scopes: Vec::new(),
815 role_claim: None,
816 role_mappings: Vec::new(),
817 jwks_cache_ttl: default_jwks_cache_ttl(),
818 proxy: None,
819 token_exchange: None,
820 ca_cert_path: None,
821 allow_http_oauth_urls: false,
822 max_jwks_keys: default_max_jwks_keys(),
823 #[allow(
824 deprecated,
825 reason = "default-construct deprecated field for backward compat"
826 )]
827 strict_audience_validation: false,
828 audience_validation_mode: None,
829 jwks_max_response_bytes: default_jwks_max_bytes(),
830 ssrf_allowlist: None,
831 }
832 }
833}
834
835impl OAuthConfig {
836 #[must_use]
842 pub fn effective_audience_validation_mode(&self) -> AudienceValidationMode {
843 if let Some(mode) = self.audience_validation_mode {
844 return mode;
845 }
846 #[allow(deprecated, reason = "intentional: legacy flag resolution path")]
847 if self.strict_audience_validation {
848 AudienceValidationMode::Strict
849 } else {
850 AudienceValidationMode::Warn
851 }
852 }
853
854 pub fn builder(
860 issuer: impl Into<String>,
861 audience: impl Into<String>,
862 jwks_uri: impl Into<String>,
863 ) -> OAuthConfigBuilder {
864 OAuthConfigBuilder {
865 inner: Self {
866 issuer: issuer.into(),
867 audience: audience.into(),
868 jwks_uri: jwks_uri.into(),
869 ..Self::default()
870 },
871 }
872 }
873
874 pub fn validate(&self) -> Result<(), crate::error::McpxError> {
890 let allow_http = self.allow_http_oauth_urls;
891 let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
892 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
893 return Err(crate::error::McpxError::Config(format!(
894 "oauth.issuer forbidden ({reason})"
895 )));
896 }
897 let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
898 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
899 return Err(crate::error::McpxError::Config(format!(
900 "oauth.jwks_uri forbidden ({reason})"
901 )));
902 }
903 if let Some(proxy) = &self.proxy {
904 let url = check_oauth_url(
905 "oauth.proxy.authorize_url",
906 &proxy.authorize_url,
907 allow_http,
908 )?;
909 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
910 return Err(crate::error::McpxError::Config(format!(
911 "oauth.proxy.authorize_url forbidden ({reason})"
912 )));
913 }
914 let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
915 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
916 return Err(crate::error::McpxError::Config(format!(
917 "oauth.proxy.token_url forbidden ({reason})"
918 )));
919 }
920 if let Some(url) = &proxy.introspection_url {
921 let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
922 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
923 return Err(crate::error::McpxError::Config(format!(
924 "oauth.proxy.introspection_url forbidden ({reason})"
925 )));
926 }
927 }
928 if let Some(url) = &proxy.revocation_url {
929 let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
930 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
931 return Err(crate::error::McpxError::Config(format!(
932 "oauth.proxy.revocation_url forbidden ({reason})"
933 )));
934 }
935 }
936 if proxy.expose_admin_endpoints
943 && !proxy.require_auth_on_admin_endpoints
944 && !proxy.allow_unauthenticated_admin_endpoints
945 {
946 return Err(crate::error::McpxError::Config(
947 "oauth.proxy: expose_admin_endpoints = true requires \
948 require_auth_on_admin_endpoints = true (recommended) \
949 or allow_unauthenticated_admin_endpoints = true \
950 (explicit opt-out, only safe behind an authenticated \
951 reverse proxy)"
952 .into(),
953 ));
954 }
955 }
956 if let Some(tx) = &self.token_exchange {
957 let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
958 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
959 return Err(crate::error::McpxError::Config(format!(
960 "oauth.token_exchange.token_url forbidden ({reason})"
961 )));
962 }
963 validate_token_exchange_client_auth(tx)?;
966 }
967 if let Some(raw) = &self.ssrf_allowlist {
971 let compiled = compile_oauth_ssrf_allowlist(raw).map_err(|e| {
972 crate::error::McpxError::Config(format!("oauth.ssrf_allowlist: {e}"))
973 })?;
974 if !compiled.is_empty() {
975 tracing::warn!(
976 host_count = compiled.host_count(),
977 cidr_count = compiled.cidr_count(),
978 "oauth.ssrf_allowlist is configured: private/loopback OAuth/JWKS targets \
979 are now reachable. Cloud-metadata addresses remain blocked. \
980 See SECURITY.md \"Operator allowlist\"."
981 );
982 }
983 }
984 humantime::parse_duration(&self.jwks_cache_ttl).map_err(|e| {
987 crate::error::McpxError::Config(format!(
988 "oauth.jwks_cache_ttl {:?} is not a valid humantime duration (e.g. \"10m\", \"1h30m\"): {e}",
989 self.jwks_cache_ttl
990 ))
991 })?;
992 Ok(())
993 }
994}
995
996fn validate_token_exchange_client_auth(
1002 tx: &TokenExchangeConfig,
1003) -> Result<(), crate::error::McpxError> {
1004 match (&tx.client_cert, tx.client_secret.is_some()) {
1005 (Some(_), true) => Err(crate::error::McpxError::Config(
1006 "oauth.token_exchange: client_cert and client_secret are mutually \
1007 exclusive (RFC 8705 ยง2). Set exactly one."
1008 .into(),
1009 )),
1010 (None, false) => Err(crate::error::McpxError::Config(
1011 "oauth.token_exchange: token exchange requires client authentication. \
1012 Set either client_secret (RFC 6749 ยง2.3.1) or client_cert (RFC 8705 ยง2)."
1013 .into(),
1014 )),
1015 (Some(cc), false) => validate_client_cert_config(cc),
1016 (None, true) => Ok(()),
1017 }
1018}
1019
1020fn validate_client_cert_config(cc: &ClientCertConfig) -> Result<(), crate::error::McpxError> {
1033 #[cfg(not(feature = "oauth-mtls-client"))]
1034 {
1035 let _ = cc;
1036 Err(crate::error::McpxError::Config(
1037 "oauth.token_exchange.client_cert requires the `oauth-mtls-client` cargo feature; \
1038 rebuild rmcp-server-kit with --features oauth-mtls-client (or have your \
1039 application crate enable it via `rmcp-server-kit/oauth-mtls-client`), or remove \
1040 the field"
1041 .into(),
1042 ))
1043 }
1044 #[cfg(feature = "oauth-mtls-client")]
1045 {
1046 let cert_bytes = std::fs::read(&cc.cert_path).map_err(|e| {
1047 tracing::warn!(error = %e, path = %cc.cert_path.display(), "client cert read failed");
1048 crate::error::McpxError::Config(format!(
1049 "oauth.token_exchange.client_cert.cert_path unreadable: {}",
1050 cc.cert_path.display()
1051 ))
1052 })?;
1053 let key_bytes = std::fs::read(&cc.key_path).map_err(|e| {
1054 tracing::warn!(error = %e, path = %cc.key_path.display(), "client cert key read failed");
1055 crate::error::McpxError::Config(format!(
1056 "oauth.token_exchange.client_cert.key_path unreadable: {}",
1057 cc.key_path.display()
1058 ))
1059 })?;
1060 let mut combined = Vec::with_capacity(cert_bytes.len() + 1 + key_bytes.len());
1061 combined.extend_from_slice(&cert_bytes);
1062 if !cert_bytes.ends_with(b"\n") {
1063 combined.push(b'\n');
1064 }
1065 combined.extend_from_slice(&key_bytes);
1066 let _identity = reqwest::Identity::from_pem(&combined).map_err(|e| {
1067 tracing::warn!(
1068 error = %e,
1069 cert_path = %cc.cert_path.display(),
1070 key_path = %cc.key_path.display(),
1071 "client cert PEM parse failed"
1072 );
1073 crate::error::McpxError::Config(format!(
1074 "oauth.token_exchange.client_cert: PEM parse failed (cert={}, key={})",
1075 cc.cert_path.display(),
1076 cc.key_path.display()
1077 ))
1078 })?;
1079 Ok(())
1080 }
1081}
1082
/// Builds the map of mTLS-capable HTTP clients.
///
/// The map is empty unless `config` carries a token-exchange section with a
/// `client_cert`. In that case one client is built with the same resolver,
/// timeouts and optional CA certificate as the default client, but with the
/// client identity attached and redirects disabled, and is stored under the
/// certificate/key path pair.
///
/// # Errors
///
/// Returns `McpxError::Startup` when a certificate, key or CA file cannot be
/// read or parsed, or when the HTTP client cannot be constructed.
#[cfg(feature = "oauth-mtls-client")]
fn build_mtls_clients(
    config: Option<&OAuthConfig>,
    allowlist: &Arc<crate::ssrf::CompiledSsrfAllowlist>,
    test_bypass: &crate::ssrf_resolver::TestLoopbackBypass,
) -> Result<Arc<HashMap<MtlsClientKey, reqwest::Client>>, crate::error::McpxError> {
    use crate::error::McpxError;

    let client_cert = config
        .and_then(|cfg| cfg.token_exchange.as_ref())
        .and_then(|tx| tx.client_cert.as_ref());
    let (Some(cfg), Some(cc)) = (config, client_cert) else {
        return Ok(Arc::new(HashMap::new()));
    };

    let cert_bytes = std::fs::read(&cc.cert_path).map_err(|e| {
        McpxError::Startup(format!(
            "oauth http client mTLS: read cert_path {}: {e}",
            cc.cert_path.display()
        ))
    })?;
    let key_bytes = std::fs::read(&cc.key_path).map_err(|e| {
        McpxError::Startup(format!(
            "oauth http client mTLS: read key_path {}: {e}",
            cc.key_path.display()
        ))
    })?;

    // Certificate first, then the key, separated by a newline when the
    // certificate lacks a trailing one.
    let needs_separator = !cert_bytes.ends_with(b"\n");
    let mut pem_bundle = cert_bytes;
    if needs_separator {
        pem_bundle.push(b'\n');
    }
    pem_bundle.extend_from_slice(&key_bytes);

    let identity = reqwest::Identity::from_pem(&pem_bundle).map_err(|e| {
        McpxError::Startup(format!(
            "oauth http client mTLS: PEM parse (cert={}, key={}): {e}",
            cc.cert_path.display(),
            cc.key_path.display()
        ))
    })?;

    let resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(allowlist),
            #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
            test_bypass.clone(),
        ));

    let mut builder = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&resolver))
        .connect_timeout(Duration::from_secs(10))
        .timeout(Duration::from_secs(30))
        .redirect(reqwest::redirect::Policy::none())
        .identity(identity);

    if let Some(ca_path) = cfg.ca_cert_path.as_ref() {
        let pem = std::fs::read(ca_path).map_err(|e| {
            McpxError::Startup(format!(
                "oauth http client mTLS: read ca_cert_path {}: {e}",
                ca_path.display()
            ))
        })?;
        let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
            McpxError::Startup(format!(
                "oauth http client mTLS: parse ca_cert_path {}: {e}",
                ca_path.display()
            ))
        })?;
        builder = builder.add_root_certificate(cert);
    }

    let client = builder
        .build()
        .map_err(|e| McpxError::Startup(format!("oauth http client mTLS init: {e}")))?;

    let key = MtlsClientKey {
        cert_path: cc.cert_path.clone(),
        key_path: cc.key_path.clone(),
    };
    Ok(Arc::new(HashMap::from([(key, client)])))
}
1181
1182fn check_oauth_url(
1189 field: &str,
1190 raw: &str,
1191 allow_http: bool,
1192) -> Result<url::Url, crate::error::McpxError> {
1193 let parsed = url::Url::parse(raw).map_err(|e| {
1194 crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
1195 })?;
1196 if !parsed.username().is_empty() || parsed.password().is_some() {
1197 return Err(crate::error::McpxError::Config(format!(
1198 "{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
1199 )));
1200 }
1201 match parsed.scheme() {
1202 "https" => Ok(parsed),
1203 "http" if allow_http => Ok(parsed),
1204 "http" => Err(crate::error::McpxError::Config(format!(
1205 "{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
1206 to override - strongly discouraged in production)"
1207 ))),
1208 other => Err(crate::error::McpxError::Config(format!(
1209 "{field}: must use https scheme (got {other:?})"
1210 ))),
1211 }
1212}
1213
/// Builder for [`OAuthConfig`], created by [`OAuthConfig::builder`].
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthConfigBuilder {
    /// Configuration being assembled.
    inner: OAuthConfig,
}
1224
impl OAuthConfigBuilder {
    /// Replaces the whole list of scope-to-role mappings.
    pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
        self.inner.scopes = scopes;
        self
    }

    /// Appends a single scope-to-role mapping.
    pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.scopes.push(ScopeMapping {
            scope: scope.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the name of the claim that carries role values.
    pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
        self.inner.role_claim = Some(claim.into());
        self
    }

    /// Replaces the whole list of claim-value-to-role mappings.
    pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
        self.inner.role_mappings = mappings;
        self
    }

    /// Appends a single claim-value-to-role mapping.
    pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.role_mappings.push(RoleMapping {
            claim_value: claim_value.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the JWKS cache lifetime as a humantime string (e.g. `"10m"`).
    /// The value is only checked later, by `OAuthConfig::validate`.
    pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
        self.inner.jwks_cache_ttl = ttl.into();
        self
    }

    /// Sets the OAuth proxy configuration.
    pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
        self.inner.proxy = Some(proxy);
        self
    }

    /// Sets the token-exchange configuration.
    pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
        self.inner.token_exchange = Some(token_exchange);
        self
    }

    /// Sets the path of a PEM CA certificate to trust for OAuth traffic.
    pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.inner.ca_cert_path = Some(path.into());
        self
    }

    /// Allows (`true`) or forbids (`false`) plain `http` OAuth URLs.
    pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
        self.inner.allow_http_oauth_urls = allow;
        self
    }

    /// Sets the legacy strictness flag and clears any explicit
    /// `audience_validation_mode`, so the flag becomes the effective setting.
    #[deprecated(since = "1.7.0", note = "use `audience_validation_mode` instead")]
    pub const fn strict_audience_validation(mut self, strict: bool) -> Self {
        #[allow(
            deprecated,
            reason = "intentional: deprecated builder forwards to deprecated field"
        )]
        {
            self.inner.strict_audience_validation = strict;
        }
        self.inner.audience_validation_mode = None;
        self
    }

    /// Sets the explicit audience validation mode, which takes precedence
    /// over the legacy strictness flag.
    pub const fn audience_validation_mode(mut self, mode: AudienceValidationMode) -> Self {
        self.inner.audience_validation_mode = Some(mode);
        self
    }

    /// Sets the upper bound on the JWKS response size, in bytes.
    pub const fn jwks_max_response_bytes(mut self, bytes: u64) -> Self {
        self.inner.jwks_max_response_bytes = bytes;
        self
    }

    /// Sets the operator SSRF allowlist.
    pub fn ssrf_allowlist(mut self, allowlist: OAuthSsrfAllowlist) -> Self {
        self.inner.ssrf_allowlist = Some(allowlist);
        self
    }

    /// Finishes the builder and returns the configuration. No validation is
    /// performed here; call `OAuthConfig::validate` on the result.
    #[must_use]
    pub fn build(self) -> OAuthConfig {
        self.inner
    }
}
1360
/// Associates an OAuth scope with a role name.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ScopeMapping {
    /// Scope string to match.
    pub scope: String,
    /// Role associated with that scope.
    pub role: String,
}
1370
/// Associates a value of the role claim with a role name.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleMapping {
    /// Claim value to match.
    pub claim_value: String,
    /// Role associated with that claim value.
    pub role: String,
}
1382
/// Token-exchange configuration.
///
/// Validation requires exactly one of `client_secret` and `client_cert`.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenExchangeConfig {
    /// Token endpoint URL; screened like every other OAuth URL.
    pub token_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Client secret; mutually exclusive with `client_cert`.
    pub client_secret: Option<secrecy::SecretString>,
    /// Client certificate for mTLS authentication; mutually exclusive with
    /// `client_secret` and only usable with the `oauth-mtls-client` feature.
    pub client_cert: Option<ClientCertConfig>,
    /// Audience value for the exchange -- presumably the audience requested
    /// for the exchanged token; confirm against the exchange code.
    pub audience: String,
}
1420
impl TokenExchangeConfig {
    /// Creates a token-exchange configuration from its parts.
    ///
    /// No validation happens here; the "exactly one of `client_secret` /
    /// `client_cert`" rule is enforced by `OAuthConfig::validate`.
    #[must_use]
    pub fn new(
        token_url: String,
        client_id: String,
        client_secret: Option<secrecy::SecretString>,
        client_cert: Option<ClientCertConfig>,
        audience: String,
    ) -> Self {
        Self {
            token_url,
            client_id,
            client_secret,
            client_cert,
            audience,
        }
    }
}
1440
/// Locations of the PEM files that make up an mTLS client identity.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ClientCertConfig {
    /// Path of the client certificate PEM file.
    pub cert_path: PathBuf,
    /// Path of the client private-key PEM file.
    pub key_path: PathBuf,
}
1455
impl ClientCertConfig {
    /// Creates a client-certificate configuration from the certificate and
    /// key paths. The files are not touched here.
    #[must_use]
    pub fn new(cert_path: PathBuf, key_path: PathBuf) -> Self {
        Self {
            cert_path,
            key_path,
        }
    }
}
1468
/// Deserialized token response; presumably the token-exchange endpoint's
/// reply -- confirm against the caller.
///
/// NOTE(review): the derived `Debug` prints `access_token` in clear text --
/// confirm this type is never formatted into logs.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ExchangedToken {
    /// The issued access token.
    pub access_token: String,
    /// Optional token lifetime -- presumably in seconds; confirm against the
    /// token endpoint contract.
    pub expires_in: Option<u64>,
    /// Optional identifier of the issued token's type.
    pub issued_token_type: Option<String>,
}
1481
/// OAuth proxy configuration.
///
/// All URLs are screened by `OAuthConfig::validate`.
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct OAuthProxyConfig {
    /// Upstream authorization endpoint URL.
    pub authorize_url: String,
    /// Upstream token endpoint URL.
    pub token_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Optional client secret.
    pub client_secret: Option<secrecy::SecretString>,
    /// Optional introspection endpoint URL.
    #[serde(default)]
    pub introspection_url: Option<String>,
    /// Optional revocation endpoint URL.
    #[serde(default)]
    pub revocation_url: Option<String>,
    /// Whether admin endpoints are exposed. When `true`, validation requires
    /// one of the two flags below to be set as well.
    #[serde(default)]
    pub expose_admin_endpoints: bool,
    /// Requires authentication on the admin endpoints (the recommended way
    /// to satisfy validation).
    #[serde(default)]
    pub require_auth_on_admin_endpoints: bool,
    /// Explicit opt-out that allows unauthenticated admin endpoints; per the
    /// validation message, only safe behind an authenticated reverse proxy.
    #[serde(default)]
    pub allow_unauthenticated_admin_endpoints: bool,
}
1544
1545impl OAuthProxyConfig {
1546 pub fn builder(
1554 authorize_url: impl Into<String>,
1555 token_url: impl Into<String>,
1556 client_id: impl Into<String>,
1557 ) -> OAuthProxyConfigBuilder {
1558 OAuthProxyConfigBuilder {
1559 inner: Self {
1560 authorize_url: authorize_url.into(),
1561 token_url: token_url.into(),
1562 client_id: client_id.into(),
1563 ..Self::default()
1564 },
1565 }
1566 }
1567}
1568
/// Builder for [`OAuthProxyConfig`], created by [`OAuthProxyConfig::builder`].
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthProxyConfigBuilder {
    /// Configuration being assembled.
    inner: OAuthProxyConfig,
}
1579
impl OAuthProxyConfigBuilder {
    /// Sets the client secret.
    pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
        self.inner.client_secret = Some(secret);
        self
    }

    /// Sets the introspection endpoint URL.
    pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
        self.inner.introspection_url = Some(url.into());
        self
    }

    /// Sets the revocation endpoint URL.
    pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
        self.inner.revocation_url = Some(url.into());
        self
    }

    /// Controls whether admin endpoints are exposed. When enabled, validation
    /// also requires an explicit decision about their authentication.
    pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
        self.inner.expose_admin_endpoints = expose;
        self
    }

    /// Controls whether the admin endpoints require authentication.
    pub const fn require_auth_on_admin_endpoints(mut self, require: bool) -> Self {
        self.inner.require_auth_on_admin_endpoints = require;
        self
    }

    /// Explicitly allows unauthenticated admin endpoints.
    pub const fn allow_unauthenticated_admin_endpoints(mut self, allow: bool) -> Self {
        self.inner.allow_unauthenticated_admin_endpoints = allow;
        self
    }

    /// Finishes the builder and returns the proxy configuration without
    /// validating it.
    #[must_use]
    pub fn build(self) -> OAuthProxyConfig {
        self.inner
    }
}
1636
/// Pair of decoded JWKS key collections: a map of keys looked up by string
/// (presumably the JWK `kid` -- confirm against the JWKS parsing code) and a
/// list of keys without such an identifier. Each key is stored with its
/// algorithm.
type JwksKeyCache = (
    HashMap<String, (Algorithm, DecodingKey)>,
    Vec<(Algorithm, DecodingKey)>,
);
1648
/// Snapshot of fetched JWKS keys together with its freshness bookkeeping.
struct CachedKeys {
    /// Keys addressable by string identifier (presumably the `kid`).
    keys: HashMap<String, (Algorithm, DecodingKey)>,
    /// Keys stored without an identifier.
    unnamed_keys: Vec<(Algorithm, DecodingKey)>,
    /// When this snapshot was taken.
    fetched_at: Instant,
    /// How long the snapshot stays valid after `fetched_at`.
    ttl: Duration,
}
1657
1658impl CachedKeys {
1659 fn is_expired(&self) -> bool {
1660 self.fetched_at.elapsed() >= self.ttl
1661 }
1662}
1663
/// JWKS cache and JWT validation state.
///
/// NOTE(review): most of the behaviour behind these fields lives in the
/// `impl` block, which is only partly visible from this section; the field
/// notes below are deliberately conservative.
#[allow(
    missing_debug_implementations,
    reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
)]
#[non_exhaustive]
pub struct JwksCache {
    /// JWKS document URL.
    jwks_uri: String,
    /// Cache lifetime.
    ttl: Duration,
    /// Upper bound on JWKS keys.
    max_jwks_keys: usize,
    /// Upper bound on the JWKS response size, in bytes.
    max_response_bytes: u64,
    /// Whether plain `http` URLs are permitted.
    allow_http: bool,
    /// Currently cached keys, if any.
    inner: RwLock<Option<CachedKeys>>,
    /// HTTP client -- presumably the one used to fetch the JWKS document.
    http: reqwest::Client,
    /// Base `Validation` settings for token validation.
    validation_template: Validation,
    /// Audience value tokens are expected to carry.
    expected_audience: String,
    /// Audience validation mode in effect.
    audience_mode: AudienceValidationMode,
    /// Flag related to an `azp` fallback warning -- presumably ensures the
    /// warning is emitted once; confirm in the validation code.
    azp_fallback_warned: AtomicBool,
    /// Scope-to-role mappings.
    scopes: Vec<ScopeMapping>,
    /// Optional name of the claim that carries role values.
    role_claim: Option<String>,
    /// Claim-value-to-role mappings.
    role_mappings: Vec<RoleMapping>,
    /// Time of the most recent refresh attempt, if any.
    last_refresh_attempt: RwLock<Option<Instant>>,
    /// Async mutex -- presumably serializes refreshes; confirm at use sites.
    refresh_lock: tokio::sync::Mutex<()>,
    /// Compiled operator SSRF allowlist.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Test-only loopback SSRF bypass switch.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
1712
/// Minimum spacing between two JWKS refresh attempts; `refresh_with_cooldown`
/// skips the fetch when the previous attempt is more recent than this.
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_secs(10);
1715
/// Signature algorithms a token header may declare.
///
/// Only RSA, ECDSA and EdDSA variants are listed; no HMAC (`HS*`) algorithm
/// is present, so tokens declaring one are rejected by `select_jwks_key`.
const ACCEPTED_ALGS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
    Algorithm::EdDSA,
];
1728
/// Coarse reason a JWT was rejected by `JwksCache::validate_token_with_reason`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
    /// Decoding failed specifically because the token's `exp` has passed.
    Expired,
    /// Any other rejection: malformed token, unaccepted algorithm, no matching
    /// key, bad signature or claims, audience mismatch, or no role mapping.
    Invalid,
}
1738
1739impl JwksCache {
    /// Builds a cache, its validation template and its dedicated HTTP client
    /// from `config`.
    ///
    /// The JWKS document is not fetched here: the key cache starts empty and
    /// is filled by the first refresh.
    ///
    /// # Errors
    ///
    /// Returns an error when `jwks_cache_ttl` is not a valid duration, when
    /// `ssrf_allowlist` fails to compile, when `ca_cert_path` cannot be read
    /// or parsed as PEM, or when the HTTP client cannot be built.
    pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        // Install default crypto providers for rustls and jsonwebtoken. Both
        // results are discarded with `.ok()`, so a failed install (presumably
        // because a provider is already installed — TODO confirm) is ignored.
        rustls::crypto::ring::default_provider()
            .install_default()
            .ok();
        jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
            .install_default()
            .ok();

        let ttl = humantime::parse_duration(&config.jwks_cache_ttl).map_err(|error| {
            format!(
                "invalid jwks_cache_ttl {:?}: {error}",
                config.jwks_cache_ttl
            )
        })?;

        // RS256 is only a placeholder: `decode_claims` overwrites
        // `algorithms` with the token's own algorithm on every validation.
        let mut validation = Validation::new(Algorithm::RS256);
        // Audience is not validated by the library; `check_audience` does it
        // so the `azp` fallback modes can be applied.
        validation.validate_aud = false;
        validation.set_issuer(&[&config.issuer]);
        validation.set_required_spec_claims(&["exp", "iss"]);
        validation.validate_exp = true;
        validation.validate_nbf = true;

        let allow_http = config.allow_http_oauth_urls;

        // An absent allowlist falls back to the type's `Default` value.
        let allowlist = match config.ssrf_allowlist.as_ref() {
            Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
                Box::<dyn std::error::Error + Send + Sync>::from(format!(
                    "oauth.ssrf_allowlist: {e}"
                ))
            })?),
            None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
        };
        // Separate handle that is moved into the redirect-policy closure.
        let redirect_allowlist = Arc::clone(&allowlist);

        // The bypass type differs per build: a shared flag (initially false)
        // in test builds, the unit value otherwise.
        #[cfg(any(test, feature = "test-helpers"))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
            Arc::new(AtomicBool::new(false));
        #[cfg(not(any(test, feature = "test-helpers")))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();

        let resolver: Arc<dyn reqwest::dns::Resolve> =
            Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
                Arc::clone(&allowlist),
                #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
                test_bypass.clone(),
            ));

        // Proxies are disabled, DNS goes through the screening resolver, and
        // every redirect hop is vetted by `evaluate_oauth_redirect`.
        let mut http_builder = reqwest::Client::builder()
            .no_proxy()
            .dns_resolver(Arc::clone(&resolver))
            .timeout(Duration::from_secs(10))
            .connect_timeout(Duration::from_secs(3))
            .redirect(reqwest::redirect::Policy::custom(move |attempt| {
                match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
                    Ok(()) => attempt.follow(),
                    Err(reason) => {
                        tracing::warn!(
                            reason = %reason,
                            target = %attempt.url(),
                            "oauth redirect rejected"
                        );
                        attempt.error(reason)
                    }
                }
            }));

        // Optional extra trust anchor, read from disk as PEM.
        if let Some(ref ca_path) = config.ca_cert_path {
            let pem = std::fs::read(ca_path)?;
            let cert = reqwest::tls::Certificate::from_pem(&pem)?;
            http_builder = http_builder.add_root_certificate(cert);
        }

        let http = http_builder.build()?;

        Ok(Self {
            jwks_uri: config.jwks_uri.clone(),
            ttl,
            max_jwks_keys: config.max_jwks_keys,
            max_response_bytes: config.jwks_max_response_bytes,
            allow_http,
            inner: RwLock::new(None),
            http,
            validation_template: validation,
            expected_audience: config.audience.clone(),
            audience_mode: config.effective_audience_validation_mode(),
            azp_fallback_warned: AtomicBool::new(false),
            scopes: config.scopes.clone(),
            role_claim: config.role_claim.clone(),
            role_mappings: config.role_mappings.clone(),
            last_refresh_attempt: RwLock::new(None),
            refresh_lock: tokio::sync::Mutex::new(()),
            allowlist,
            #[cfg(any(test, feature = "test-helpers"))]
            test_allow_loopback_ssrf: test_bypass,
        })
    }
1878
1879 #[cfg(any(test, feature = "test-helpers"))]
1883 #[doc(hidden)]
1884 #[must_use]
1885 pub fn __test_allow_loopback_ssrf(self) -> Self {
1886 self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
1889 self
1890 }
1891
1892 pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
1894 self.validate_token_with_reason(token).await.ok()
1895 }
1896
1897 pub async fn validate_token_with_reason(
1904 &self,
1905 token: &str,
1906 ) -> Result<AuthIdentity, JwtValidationFailure> {
1907 let claims = self.decode_claims(token).await?;
1908
1909 self.check_audience(&claims)?;
1910 let role = self.resolve_role(&claims)?;
1911
1912 let sub = claims.sub;
1915 let name = claims
1916 .extra
1917 .get("preferred_username")
1918 .and_then(|v| v.as_str())
1919 .map(String::from)
1920 .or_else(|| sub.clone())
1921 .or(claims.azp)
1922 .or(claims.client_id)
1923 .unwrap_or_else(|| "oauth-client".into());
1924
1925 Ok(AuthIdentity {
1926 name,
1927 role,
1928 method: AuthMethod::OAuthJwt,
1929 raw_token: None,
1930 sub,
1931 })
1932 }
1933
1934 async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
1946 let (key, alg) = self.select_jwks_key(token).await?;
1947
1948 let mut validation = self.validation_template.clone();
1952 validation.algorithms = vec![alg];
1953
1954 let token_owned = token.to_owned();
1957 let join =
1958 tokio::task::spawn_blocking(move || decode::<Claims>(&token_owned, &key, &validation))
1959 .await;
1960
1961 let decode_result = match join {
1962 Ok(r) => r,
1963 Err(join_err) => {
1964 core::hint::cold_path();
1965 tracing::error!(
1966 error = %join_err,
1967 "JWT decode task panicked or was cancelled"
1968 );
1969 return Err(JwtValidationFailure::Invalid);
1970 }
1971 };
1972
1973 decode_result.map(|td| td.claims).map_err(|e| {
1974 core::hint::cold_path();
1975 let failure = if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
1976 JwtValidationFailure::Expired
1977 } else {
1978 JwtValidationFailure::Invalid
1979 };
1980 tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
1981 failure
1982 })
1983 }
1984
    /// Picks the decoding key for `token` based on its (unverified) header.
    ///
    /// Returns the key together with the algorithm declared in the header.
    /// Every rejection is reported as [`JwtValidationFailure::Invalid`]: an
    /// undecodable header, an algorithm outside `ACCEPTED_ALGS`, or no cached
    /// key matching the header's `kid` / algorithm.
    #[allow(
        clippy::cognitive_complexity,
        reason = "each failure arm pairs `cold_path()` with a distinct `tracing::debug!` site for observability; collapsing into combinators would lose structured-field log sites without reducing real complexity"
    )]
    async fn select_jwks_key(
        &self,
        token: &str,
    ) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
        // The header is read before any signature check, so its contents are
        // untrusted at this point.
        let Ok(header) = decode_header(token) else {
            core::hint::cold_path();
            tracing::debug!("JWT header decode failed");
            return Err(JwtValidationFailure::Invalid);
        };
        let kid = header.kid.as_deref();
        tracing::debug!(alg = ?header.alg, kid = kid.unwrap_or("-"), "JWT header decoded");

        // Reject algorithms outside the allowlist before touching the cache.
        if !ACCEPTED_ALGS.contains(&header.alg) {
            core::hint::cold_path();
            tracing::debug!(alg = ?header.alg, "JWT algorithm not accepted");
            return Err(JwtValidationFailure::Invalid);
        }

        // May trigger a JWKS refresh when the cache has no usable match.
        let Some(key) = self.find_key(kid, header.alg).await else {
            core::hint::cold_path();
            tracing::debug!(kid = kid.unwrap_or("-"), alg = ?header.alg, "no matching JWKS key found");
            return Err(JwtValidationFailure::Invalid);
        };

        Ok((key, header.alg))
    }
2023
2024 fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
2033 if claims.aud.contains(&self.expected_audience) {
2034 return Ok(());
2035 }
2036 let azp_match = claims
2037 .azp
2038 .as_deref()
2039 .is_some_and(|azp| azp == self.expected_audience);
2040 if azp_match {
2041 match self.audience_mode {
2042 AudienceValidationMode::Permissive => return Ok(()),
2043 AudienceValidationMode::Warn => {
2044 if !self.azp_fallback_warned.swap(true, Ordering::Relaxed) {
2045 tracing::warn!(
2046 expected = %self.expected_audience,
2047 azp = ?claims.azp,
2048 "JWT accepted via deprecated azp-only audience fallback. \
2049 Configure your IdP to populate aud, or set \
2050 audience_validation_mode = \"strict\" once tokens carry aud correctly. \
2051 To silence this warning without changing acceptance, \
2052 set audience_validation_mode = \"permissive\". \
2053 This warning logs once per process."
2054 );
2055 }
2056 return Ok(());
2057 }
2058 AudienceValidationMode::Strict => {}
2059 }
2060 }
2061 core::hint::cold_path();
2062 tracing::debug!(
2063 aud = ?claims.aud.0,
2064 azp = ?claims.azp,
2065 expected = %self.expected_audience,
2066 mode = ?self.audience_mode,
2067 "JWT rejected: audience mismatch"
2068 );
2069 Err(JwtValidationFailure::Invalid)
2070 }
2071
2072 fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
2078 if let Some(ref claim_path) = self.role_claim {
2079 let owned_first_class: Vec<String> = first_class_claim_values(claims, claim_path);
2080 let mut values: Vec<&str> = owned_first_class.iter().map(String::as_str).collect();
2081 values.extend(resolve_claim_path(&claims.extra, claim_path));
2082 return self
2083 .role_mappings
2084 .iter()
2085 .find(|m| values.contains(&m.claim_value.as_str()))
2086 .map(|m| m.role.clone())
2087 .ok_or(JwtValidationFailure::Invalid);
2088 }
2089
2090 let token_scopes: Vec<&str> = claims
2091 .scope
2092 .as_deref()
2093 .unwrap_or("")
2094 .split_whitespace()
2095 .collect();
2096
2097 self.scopes
2098 .iter()
2099 .find(|m| token_scopes.contains(&m.scope.as_str()))
2100 .map(|m| m.role.clone())
2101 .ok_or(JwtValidationFailure::Invalid)
2102 }
2103
2104 async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2107 {
2109 let guard = self.inner.read().await;
2110 if let Some(cached) = guard.as_ref()
2111 && !cached.is_expired()
2112 && let Some(key) = lookup_key(cached, kid, alg)
2113 {
2114 return Some(key);
2115 }
2116 }
2117
2118 self.refresh_with_cooldown().await;
2120
2121 let guard = self.inner.read().await;
2122 guard
2123 .as_ref()
2124 .and_then(|cached| lookup_key(cached, kid, alg))
2125 }
2126
2127 async fn refresh_with_cooldown(&self) {
2132 let _guard = self.refresh_lock.lock().await;
2134
2135 {
2137 let last = self.last_refresh_attempt.read().await;
2138 if let Some(ts) = *last
2139 && ts.elapsed() < JWKS_REFRESH_COOLDOWN
2140 {
2141 tracing::debug!(
2142 elapsed_ms = ts.elapsed().as_millis(),
2143 cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
2144 "JWKS refresh skipped (cooldown active)"
2145 );
2146 return;
2147 }
2148 }
2149
2150 {
2153 let mut last = self.last_refresh_attempt.write().await;
2154 *last = Some(Instant::now());
2155 }
2156
2157 let _ = self.refresh_inner().await;
2159 }
2160
2161 async fn refresh_inner(&self) -> Result<(), String> {
2166 let Some(jwks) = self.fetch_jwks().await else {
2167 return Ok(());
2168 };
2169 let (keys, unnamed_keys) = match build_key_cache(&jwks, self.max_jwks_keys) {
2170 Ok(cache) => cache,
2171 Err(msg) => {
2172 tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
2173 return Err(msg);
2174 }
2175 };
2176
2177 tracing::debug!(
2178 named = keys.len(),
2179 unnamed = unnamed_keys.len(),
2180 "JWKS refreshed"
2181 );
2182
2183 let mut guard = self.inner.write().await;
2184 *guard = Some(CachedKeys {
2185 keys,
2186 unnamed_keys,
2187 fetched_at: Instant::now(),
2188 ttl: self.ttl,
2189 });
2190 Ok(())
2191 }
2192
2193 #[allow(
2195 clippy::cognitive_complexity,
2196 reason = "screening, bounded streaming, and parse logging are intentionally kept in one fetch path"
2197 )]
2198 async fn fetch_jwks(&self) -> Option<JwkSet> {
2199 #[cfg(any(test, feature = "test-helpers"))]
2200 let screening = if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
2201 screen_oauth_target_with_test_override(
2202 &self.jwks_uri,
2203 self.allow_http,
2204 &self.allowlist,
2205 true,
2206 )
2207 .await
2208 } else {
2209 screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await
2210 };
2211 #[cfg(not(any(test, feature = "test-helpers")))]
2212 let screening = screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await;
2213
2214 if let Err(error) = screening {
2215 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to screen JWKS target");
2216 return None;
2217 }
2218
2219 let mut resp = match self.http.get(&self.jwks_uri).send().await {
2220 Ok(resp) => resp,
2221 Err(e) => {
2222 tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
2223 return None;
2224 }
2225 };
2226
2227 let initial_capacity =
2228 usize::try_from(self.max_response_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
2229 let mut body = Vec::with_capacity(initial_capacity);
2230 while let Some(chunk) = match resp.chunk().await {
2231 Ok(chunk) => chunk,
2232 Err(error) => {
2233 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to read JWKS response");
2234 return None;
2235 }
2236 } {
2237 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
2238 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
2239 if body_len.saturating_add(chunk_len) > self.max_response_bytes {
2240 tracing::warn!(
2241 uri = %self.jwks_uri,
2242 max_bytes = self.max_response_bytes,
2243 "JWKS response exceeded configured size cap"
2244 );
2245 return None;
2246 }
2247 body.extend_from_slice(&chunk);
2248 }
2249
2250 match serde_json::from_slice::<JwkSet>(&body) {
2251 Ok(jwks) => Some(jwks),
2252 Err(error) => {
2253 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to parse JWKS");
2254 None
2255 }
2256 }
2257 }
2258
2259 #[cfg(any(test, feature = "test-helpers"))]
2262 #[doc(hidden)]
2263 pub async fn __test_refresh_now(&self) -> Result<(), String> {
2264 let jwks = self
2265 .fetch_jwks()
2266 .await
2267 .ok_or_else(|| "failed to fetch or parse JWKS".to_owned())?;
2268 let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;
2269 let mut guard = self.inner.write().await;
2270 *guard = Some(CachedKeys {
2271 keys,
2272 unnamed_keys,
2273 fetched_at: Instant::now(),
2274 ttl: self.ttl,
2275 });
2276 Ok(())
2277 }
2278
2279 #[cfg(any(test, feature = "test-helpers"))]
2282 #[doc(hidden)]
2283 pub async fn __test_has_kid(&self, kid: &str) -> bool {
2284 let guard = self.inner.read().await;
2285 guard
2286 .as_ref()
2287 .is_some_and(|cache| cache.keys.contains_key(kid))
2288 }
2289}
2290
2291fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
2293 if jwks.keys.len() > max_keys {
2294 return Err(format!(
2295 "jwks_key_count_exceeds_cap: got {} keys, max is {}",
2296 jwks.keys.len(),
2297 max_keys
2298 ));
2299 }
2300 let mut keys = HashMap::new();
2301 let mut unnamed_keys = Vec::new();
2302 for jwk in &jwks.keys {
2303 let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
2304 continue;
2305 };
2306 let Some(alg) = jwk_algorithm(jwk) else {
2307 continue;
2308 };
2309 if let Some(ref kid) = jwk.common.key_id {
2310 keys.insert(kid.clone(), (alg, decoding_key));
2311 } else {
2312 unnamed_keys.push((alg, decoding_key));
2313 }
2314 }
2315 Ok((keys, unnamed_keys))
2316}
2317
2318fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2320 if let Some(kid) = kid
2321 && let Some((cached_alg, key)) = cached.keys.get(kid)
2322 && *cached_alg == alg
2323 {
2324 return Some(key.clone());
2325 }
2326 cached
2328 .unnamed_keys
2329 .iter()
2330 .find(|(a, _)| *a == alg)
2331 .map(|(_, k)| k.clone())
2332}
2333
2334#[allow(
2336 clippy::wildcard_enum_match_arm,
2337 reason = "jsonwebtoken KeyAlgorithm is a large external enum; only the JWT-signing variants are mappable to `Algorithm`"
2338)]
2339fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
2340 jwk.common.key_algorithm.and_then(|ka| match ka {
2341 jsonwebtoken::jwk::KeyAlgorithm::RS256 => Some(Algorithm::RS256),
2342 jsonwebtoken::jwk::KeyAlgorithm::RS384 => Some(Algorithm::RS384),
2343 jsonwebtoken::jwk::KeyAlgorithm::RS512 => Some(Algorithm::RS512),
2344 jsonwebtoken::jwk::KeyAlgorithm::ES256 => Some(Algorithm::ES256),
2345 jsonwebtoken::jwk::KeyAlgorithm::ES384 => Some(Algorithm::ES384),
2346 jsonwebtoken::jwk::KeyAlgorithm::PS256 => Some(Algorithm::PS256),
2347 jsonwebtoken::jwk::KeyAlgorithm::PS384 => Some(Algorithm::PS384),
2348 jsonwebtoken::jwk::KeyAlgorithm::PS512 => Some(Algorithm::PS512),
2349 jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
2350 _ => None,
2351 })
2352}
2353
2354fn first_class_claim_values(claims: &Claims, path: &str) -> Vec<String> {
2375 match path {
2376 "sub" => claims.sub.iter().cloned().collect(),
2377 "azp" => claims.azp.iter().cloned().collect(),
2378 "client_id" => claims.client_id.iter().cloned().collect(),
2379 "aud" => claims.aud.0.clone(),
2380 "scope" => claims
2381 .scope
2382 .as_deref()
2383 .unwrap_or("")
2384 .split_whitespace()
2385 .map(str::to_owned)
2386 .collect(),
2387 _ => Vec::new(),
2388 }
2389}
2390
2391fn resolve_claim_path<'a>(
2401 extra: &'a HashMap<String, serde_json::Value>,
2402 path: &str,
2403) -> Vec<&'a str> {
2404 let mut segments = path.split('.');
2405 let Some(first) = segments.next() else {
2406 return Vec::new();
2407 };
2408
2409 let mut current: Option<&serde_json::Value> = extra.get(first);
2410
2411 for segment in segments {
2412 current = current.and_then(|v| v.get(segment));
2413 }
2414
2415 match current {
2416 Some(serde_json::Value::String(s)) => s.split_whitespace().collect(),
2417 Some(serde_json::Value::Array(arr)) => arr.iter().filter_map(|v| v.as_str()).collect(),
2418 _ => Vec::new(),
2419 }
2420}
2421
/// Claims deserialized from a validated JWT payload.
#[derive(Debug, Deserialize)]
struct Claims {
    /// Subject of the token, if present.
    sub: Option<String>,
    /// Audiences; accepts a single string or an array, and defaults to empty
    /// when the claim is missing.
    #[serde(default)]
    aud: OneOrMany,
    /// Authorized party, if present.
    azp: Option<String>,
    /// Client identifier claim, if present.
    client_id: Option<String>,
    /// Whitespace-separated scope string, if present.
    scope: Option<String>,
    /// Every other top-level claim, kept as raw JSON.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
2445
/// List of strings that may have been supplied either as one string or as an
/// array of strings.
#[derive(Debug, Default)]
struct OneOrMany(Vec<String>);

impl OneOrMany {
    /// Returns `true` when `value` is exactly equal to one of the entries.
    fn contains(&self, value: &str) -> bool {
        self.0
            .iter()
            .map(String::as_str)
            .any(|entry| entry == value)
    }
}
2455
2456impl<'de> Deserialize<'de> for OneOrMany {
2457 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
2458 use serde::de;
2459
2460 struct Visitor;
2461 impl<'de> de::Visitor<'de> for Visitor {
2462 type Value = OneOrMany;
2463 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2464 f.write_str("a string or array of strings")
2465 }
2466 fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
2467 Ok(OneOrMany(vec![v.to_owned()]))
2468 }
2469 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
2470 let mut v = Vec::new();
2471 while let Some(s) = seq.next_element::<String>()? {
2472 v.push(s);
2473 }
2474 Ok(OneOrMany(v))
2475 }
2476 }
2477 deserializer.deserialize_any(Visitor)
2478 }
2479}
2480
2481#[must_use]
2488pub fn looks_like_jwt(token: &str) -> bool {
2489 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2490
2491 let mut parts = token.splitn(4, '.');
2492 let Some(header_b64) = parts.next() else {
2493 return false;
2494 };
2495 if parts.next().is_none() || parts.next().is_none() || parts.next().is_some() {
2497 return false;
2498 }
2499 let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
2501 return false;
2502 };
2503 let Ok(header) = serde_json::from_slice::<serde_json::Value>(&header_bytes) else {
2505 return false;
2506 };
2507 header.get("alg").is_some()
2508}
2509
2510#[must_use]
2520pub fn protected_resource_metadata(
2521 resource_url: &str,
2522 server_url: &str,
2523 config: &OAuthConfig,
2524) -> serde_json::Value {
2525 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2530 let auth_server = server_url;
2531 serde_json::json!({
2532 "resource": resource_url,
2533 "authorization_servers": [auth_server],
2534 "scopes_supported": scopes,
2535 "bearer_methods_supported": ["header"]
2536 })
2537}
2538
2539#[must_use]
2544pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
2545 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2546 let mut meta = serde_json::json!({
2547 "issuer": &config.issuer,
2548 "authorization_endpoint": format!("{server_url}/authorize"),
2549 "token_endpoint": format!("{server_url}/token"),
2550 "registration_endpoint": format!("{server_url}/register"),
2551 "response_types_supported": ["code"],
2552 "grant_types_supported": ["authorization_code", "refresh_token"],
2553 "code_challenge_methods_supported": ["S256"],
2554 "scopes_supported": scopes,
2555 "token_endpoint_auth_methods_supported": ["none"],
2556 });
2557 if let Some(proxy) = &config.proxy
2558 && proxy.expose_admin_endpoints
2559 && let Some(obj) = meta.as_object_mut()
2560 {
2561 if proxy.introspection_url.is_some() {
2562 obj.insert(
2563 "introspection_endpoint".into(),
2564 serde_json::Value::String(format!("{server_url}/introspect")),
2565 );
2566 }
2567 if proxy.revocation_url.is_some() {
2568 obj.insert(
2569 "revocation_endpoint".into(),
2570 serde_json::Value::String(format!("{server_url}/revoke")),
2571 );
2572 }
2573 if proxy.require_auth_on_admin_endpoints {
2574 obj.insert(
2575 "introspection_endpoint_auth_methods_supported".into(),
2576 serde_json::json!(["bearer"]),
2577 );
2578 obj.insert(
2579 "revocation_endpoint_auth_methods_supported".into(),
2580 serde_json::json!(["bearer"]),
2581 );
2582 }
2583 }
2584 meta
2585}
2586
2587#[must_use]
2600pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
2601 use axum::{
2602 http::{StatusCode, header},
2603 response::IntoResponse,
2604 };
2605
2606 let upstream_query = replace_client_id(query, &proxy.client_id);
2608 let redirect_url = format!("{}?{upstream_query}", proxy.authorize_url);
2609
2610 (StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
2611}
2612
2613pub async fn handle_token(
2619 http: &OauthHttpClient,
2620 proxy: &OAuthProxyConfig,
2621 body: &str,
2622) -> axum::response::Response {
2623 use axum::{
2624 http::{StatusCode, header},
2625 response::IntoResponse,
2626 };
2627
2628 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2630
2631 if let Some(ref secret) = proxy.client_secret {
2633 use std::fmt::Write;
2634
2635 use secrecy::ExposeSecret;
2636 let _ = write!(
2637 upstream_body,
2638 "&client_secret={}",
2639 urlencoding::encode(secret.expose_secret())
2640 );
2641 }
2642
2643 let result = http
2644 .send_screened(
2645 &proxy.token_url,
2646 http.inner
2647 .post(&proxy.token_url)
2648 .header("Content-Type", "application/x-www-form-urlencoded")
2649 .body(upstream_body),
2650 )
2651 .await;
2652
2653 match result {
2654 Ok(resp) => {
2655 let status =
2656 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2657 let body_bytes = resp.bytes().await.unwrap_or_default();
2658 (
2659 status,
2660 [(header::CONTENT_TYPE, "application/json")],
2661 body_bytes,
2662 )
2663 .into_response()
2664 }
2665 Err(e) => {
2666 tracing::error!(error = %e, "OAuth token proxy request failed");
2667 (
2668 StatusCode::BAD_GATEWAY,
2669 [(header::CONTENT_TYPE, "application/json")],
2670 "{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
2671 )
2672 .into_response()
2673 }
2674 }
2675}
2676
2677#[must_use]
2684pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
2685 let mut resp = serde_json::json!({
2686 "client_id": proxy.client_id,
2687 "token_endpoint_auth_method": "none",
2688 });
2689 if let Some(uris) = body.get("redirect_uris")
2690 && let Some(obj) = resp.as_object_mut()
2691 {
2692 obj.insert("redirect_uris".into(), uris.clone());
2693 }
2694 if let Some(name) = body.get("client_name")
2695 && let Some(obj) = resp.as_object_mut()
2696 {
2697 obj.insert("client_name".into(), name.clone());
2698 }
2699 resp
2700}
2701
2702pub async fn handle_introspect(
2708 http: &OauthHttpClient,
2709 proxy: &OAuthProxyConfig,
2710 body: &str,
2711) -> axum::response::Response {
2712 let Some(ref url) = proxy.introspection_url else {
2713 return oauth_error_response(
2714 axum::http::StatusCode::NOT_FOUND,
2715 "not_supported",
2716 "introspection endpoint is not configured",
2717 );
2718 };
2719 proxy_oauth_admin_request(http, proxy, url, body).await
2720}
2721
2722pub async fn handle_revoke(
2729 http: &OauthHttpClient,
2730 proxy: &OAuthProxyConfig,
2731 body: &str,
2732) -> axum::response::Response {
2733 let Some(ref url) = proxy.revocation_url else {
2734 return oauth_error_response(
2735 axum::http::StatusCode::NOT_FOUND,
2736 "not_supported",
2737 "revocation endpoint is not configured",
2738 );
2739 };
2740 proxy_oauth_admin_request(http, proxy, url, body).await
2741}
2742
2743async fn proxy_oauth_admin_request(
2747 http: &OauthHttpClient,
2748 proxy: &OAuthProxyConfig,
2749 upstream_url: &str,
2750 body: &str,
2751) -> axum::response::Response {
2752 use axum::{
2753 http::{StatusCode, header},
2754 response::IntoResponse,
2755 };
2756
2757 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2758 if let Some(ref secret) = proxy.client_secret {
2759 use std::fmt::Write;
2760
2761 use secrecy::ExposeSecret;
2762 let _ = write!(
2763 upstream_body,
2764 "&client_secret={}",
2765 urlencoding::encode(secret.expose_secret())
2766 );
2767 }
2768
2769 let result = http
2770 .send_screened(
2771 upstream_url,
2772 http.inner
2773 .post(upstream_url)
2774 .header("Content-Type", "application/x-www-form-urlencoded")
2775 .body(upstream_body),
2776 )
2777 .await;
2778
2779 match result {
2780 Ok(resp) => {
2781 let status =
2782 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2783 let content_type = resp
2784 .headers()
2785 .get(header::CONTENT_TYPE)
2786 .and_then(|v| v.to_str().ok())
2787 .unwrap_or("application/json")
2788 .to_owned();
2789 let body_bytes = resp.bytes().await.unwrap_or_default();
2790 (status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
2791 }
2792 Err(e) => {
2793 tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
2794 oauth_error_response(
2795 StatusCode::BAD_GATEWAY,
2796 "server_error",
2797 "upstream endpoint unreachable",
2798 )
2799 }
2800 }
2801}
2802
2803fn oauth_error_response(
2804 status: axum::http::StatusCode,
2805 error: &str,
2806 description: &str,
2807) -> axum::response::Response {
2808 use axum::{http::header, response::IntoResponse};
2809 let body = serde_json::json!({
2810 "error": error,
2811 "error_description": description,
2812 });
2813 (
2814 status,
2815 [(header::CONTENT_TYPE, "application/json")],
2816 body.to_string(),
2817 )
2818 .into_response()
2819}
2820
/// Error body returned by an upstream authorization server, as parsed by
/// `exchange_token` from a non-success response.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Upstream error code; sanitized before being surfaced to the caller.
    error: String,
    /// Optional human-readable detail; only logged, never returned.
    error_description: Option<String>,
}
2831
/// Maps an upstream OAuth error code onto a fixed allowlist.
///
/// Codes on the allowlist are returned as the matching static string;
/// everything else, including the empty string and differently-cased
/// variants, collapses to `"server_error"`.
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
    const PASS_THROUGH: [&str; 8] = [
        "invalid_request",
        "invalid_client",
        "invalid_grant",
        "unauthorized_client",
        "unsupported_grant_type",
        "invalid_scope",
        "temporarily_unavailable",
        "invalid_target",
    ];

    PASS_THROUGH
        .iter()
        .copied()
        .find(|known| *known == raw)
        .unwrap_or("server_error")
}
2854
/// Exchanges `subject_token` for a new access token at the configured token
/// endpoint, using the token-exchange grant built by `build_exchange_form`.
///
/// # Errors
///
/// Always fails with `McpxError::Auth` carrying a short code: `server_error`
/// when the request cannot be sent, the body cannot be read, or a success
/// body cannot be parsed; otherwise the upstream error code after it has been
/// passed through `sanitize_oauth_error_code`. Upstream details are logged
/// but not returned.
pub async fn exchange_token(
    http: &OauthHttpClient,
    config: &TokenExchangeConfig,
    subject_token: &str,
) -> Result<ExchangedToken, crate::error::McpxError> {
    use secrecy::ExposeSecret;

    // NOTE(review): presumably selects a client suited to `config` (for
    // example one carrying the client certificate) — confirm in `client_for`.
    let client = http.client_for(config);
    let mut req = client
        .post(&config.token_url)
        .header("Content-Type", "application/x-www-form-urlencoded")
        .header("Accept", "application/json");

    // HTTP Basic client authentication is used only when a secret is
    // configured and no client certificate is. Id and secret are
    // percent-encoded before being base64-encoded.
    // NOTE(review): when both `client_cert` and `client_secret` are set, no
    // Basic header is sent and `build_exchange_form` also omits `client_id` —
    // confirm that combination is intended or rejected elsewhere.
    if config.client_cert.is_none()
        && let Some(ref secret) = config.client_secret
    {
        use base64::Engine;
        let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
            "{}:{}",
            urlencoding::encode(&config.client_id),
            urlencoding::encode(secret.expose_secret()),
        ));
        req = req.header("Authorization", format!("Basic {credentials}"));
    }

    let form_body = build_exchange_form(config, subject_token);

    let resp = http
        .send_screened(&config.token_url, req.body(form_body))
        .await
        .map_err(|e| {
            tracing::error!(error = %e, "token exchange request failed");
            crate::error::McpxError::Auth("server_error".into())
        })?;

    // Capture the status first; reading the body consumes the response.
    let status = resp.status();
    let body_bytes = resp.bytes().await.map_err(|e| {
        tracing::error!(error = %e, "failed to read token exchange response");
        crate::error::McpxError::Auth("server_error".into())
    })?;

    if !status.is_success() {
        core::hint::cold_path();
        let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
        // Only an allowlisted code reaches the caller; an unparseable body
        // becomes `server_error`.
        let short_code = parsed
            .as_ref()
            .map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
        if let Some(ref e) = parsed {
            tracing::warn!(
                status = %status,
                upstream_error = %e.error,
                upstream_error_description = e.error_description.as_deref().unwrap_or(""),
                client_code = %short_code,
                "token exchange rejected by authorization server",
            );
        } else {
            tracing::warn!(
                status = %status,
                client_code = %short_code,
                "token exchange rejected (unparseable upstream body)",
            );
        }
        return Err(crate::error::McpxError::Auth(short_code.into()));
    }

    let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
        tracing::error!(error = %e, "failed to parse token exchange response");
        crate::error::McpxError::Auth("server_error".into())
    })?;

    // Debug-level summary of the issued token; the token itself is not logged.
    log_exchanged_token(&exchanged);

    Ok(exchanged)
}
2953
2954fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
2957 let body = format!(
2958 "grant_type={}&subject_token={}&subject_token_type={}&requested_token_type={}&audience={}",
2959 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
2960 urlencoding::encode(subject_token),
2961 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
2962 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
2963 urlencoding::encode(&config.audience),
2964 );
2965 if config.client_secret.is_none() {
2966 format!(
2967 "{body}&client_id={}",
2968 urlencoding::encode(&config.client_id)
2969 )
2970 } else {
2971 body
2972 }
2973}
2974
2975fn log_exchanged_token(exchanged: &ExchangedToken) {
2978 use base64::Engine;
2979
2980 if !looks_like_jwt(&exchanged.access_token) {
2981 tracing::debug!(
2982 token_len = exchanged.access_token.len(),
2983 issued_token_type = ?exchanged.issued_token_type,
2984 expires_in = exchanged.expires_in,
2985 "exchanged token (opaque)",
2986 );
2987 return;
2988 }
2989 let Some(payload) = exchanged.access_token.split('.').nth(1) else {
2990 return;
2991 };
2992 let Ok(decoded) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload) else {
2993 return;
2994 };
2995 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&decoded) else {
2996 return;
2997 };
2998 tracing::debug!(
2999 sub = ?claims.get("sub"),
3000 aud = ?claims.get("aud"),
3001 azp = ?claims.get("azp"),
3002 iss = ?claims.get("iss"),
3003 expires_in = exchanged.expires_in,
3004 "exchanged token claims (JWT)",
3005 );
3006}
3007
3008fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
3010 let encoded_id = urlencoding::encode(upstream_client_id);
3011 let mut parts: Vec<String> = params
3012 .split('&')
3013 .filter(|p| !p.starts_with("client_id="))
3014 .map(String::from)
3015 .collect();
3016 parts.push(format!("client_id={encoded_id}"));
3017 parts.join("&")
3018}
3019
3020#[cfg(test)]
3021mod tests {
3022 use std::sync::Arc;
3023
3024 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
3025
3026 use super::*;
3027
3028 #[test]
3029 fn looks_like_jwt_valid() {
3030 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
3032 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3033 let token = format!("{header}.{payload}.signature");
3034 assert!(looks_like_jwt(&token));
3035 }
3036
3037 #[test]
3038 fn looks_like_jwt_rejects_opaque_token() {
3039 assert!(!looks_like_jwt("dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4"));
3040 }
3041
3042 #[test]
3043 fn looks_like_jwt_rejects_two_segments() {
3044 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
3045 let token = format!("{header}.payload");
3046 assert!(!looks_like_jwt(&token));
3047 }
3048
3049 #[test]
3050 fn looks_like_jwt_rejects_four_segments() {
3051 assert!(!looks_like_jwt("a.b.c.d"));
3052 }
3053
3054 #[test]
3055 fn looks_like_jwt_rejects_no_alg() {
3056 let header = URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}");
3057 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3058 let token = format!("{header}.{payload}.sig");
3059 assert!(!looks_like_jwt(&token));
3060 }
3061
3062 #[test]
3063 fn protected_resource_metadata_shape() {
3064 let config = OAuthConfig {
3065 issuer: "https://auth.example.com".into(),
3066 audience: "https://mcp.example.com/mcp".into(),
3067 jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
3068 scopes: vec![
3069 ScopeMapping {
3070 scope: "mcp:read".into(),
3071 role: "viewer".into(),
3072 },
3073 ScopeMapping {
3074 scope: "mcp:admin".into(),
3075 role: "ops".into(),
3076 },
3077 ],
3078 role_claim: None,
3079 role_mappings: vec![],
3080 jwks_cache_ttl: "10m".into(),
3081 proxy: None,
3082 token_exchange: None,
3083 ca_cert_path: None,
3084 allow_http_oauth_urls: false,
3085 max_jwks_keys: default_max_jwks_keys(),
3086 #[allow(
3087 deprecated,
3088 reason = "test fixture: explicit value for the deprecated field"
3089 )]
3090 strict_audience_validation: false,
3091 audience_validation_mode: None,
3092 jwks_max_response_bytes: default_jwks_max_bytes(),
3093 ssrf_allowlist: None,
3094 };
3095 let meta = protected_resource_metadata(
3096 "https://mcp.example.com/mcp",
3097 "https://mcp.example.com",
3098 &config,
3099 );
3100 assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
3101 assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
3102 assert_eq!(meta["scopes_supported"].as_array().unwrap().len(), 2);
3103 assert_eq!(meta["bearer_methods_supported"][0], "header");
3104 }
3105
3106 fn validation_https_config() -> OAuthConfig {
3111 OAuthConfig::builder(
3112 "https://auth.example.com",
3113 "mcp",
3114 "https://auth.example.com/.well-known/jwks.json",
3115 )
3116 .build()
3117 }
3118
3119 #[test]
3120 fn validate_accepts_all_https_urls() {
3121 let cfg = validation_https_config();
3122 cfg.validate().expect("all-HTTPS config must validate");
3123 }
3124
3125 #[test]
3126 fn validate_rejects_unparseable_jwks_cache_ttl() {
3127 let mut cfg = validation_https_config();
3128 cfg.jwks_cache_ttl = "not-a-duration".into();
3129 let err = cfg
3130 .validate()
3131 .expect_err("malformed jwks_cache_ttl must be rejected");
3132 let msg = err.to_string();
3133 assert!(
3134 msg.contains("jwks_cache_ttl"),
3135 "error must reference offending field; got {msg:?}"
3136 );
3137 }
3138
3139 #[test]
3140 fn validate_rejects_http_jwks_uri() {
3141 let mut cfg = validation_https_config();
3142 cfg.jwks_uri = "http://auth.example.com/.well-known/jwks.json".into();
3143 let err = cfg.validate().expect_err("http jwks_uri must be rejected");
3144 let msg = err.to_string();
3145 assert!(
3146 msg.contains("oauth.jwks_uri") && msg.contains("https"),
3147 "error must reference offending field + scheme requirement; got {msg:?}"
3148 );
3149 }
3150
3151 #[test]
3152 fn validate_rejects_http_proxy_authorize_url() {
3153 let mut cfg = validation_https_config();
3154 cfg.proxy = Some(
3155 OAuthProxyConfig::builder(
3156 "http://idp.example.com/authorize", "https://idp.example.com/token",
3158 "client",
3159 )
3160 .build(),
3161 );
3162 let err = cfg
3163 .validate()
3164 .expect_err("http authorize_url must be rejected");
3165 assert!(
3166 err.to_string().contains("oauth.proxy.authorize_url"),
3167 "error must reference proxy.authorize_url; got {err}"
3168 );
3169 }
3170
3171 #[test]
3172 fn validate_rejects_http_proxy_token_url() {
3173 let mut cfg = validation_https_config();
3174 cfg.proxy = Some(
3175 OAuthProxyConfig::builder(
3176 "https://idp.example.com/authorize",
3177 "http://idp.example.com/token", "client",
3179 )
3180 .build(),
3181 );
3182 let err = cfg.validate().expect_err("http token_url must be rejected");
3183 assert!(
3184 err.to_string().contains("oauth.proxy.token_url"),
3185 "error must reference proxy.token_url; got {err}"
3186 );
3187 }
3188
3189 #[test]
3190 fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
3191 let mut cfg = validation_https_config();
3192 cfg.proxy = Some(
3193 OAuthProxyConfig::builder(
3194 "https://idp.example.com/authorize",
3195 "https://idp.example.com/token",
3196 "client",
3197 )
3198 .introspection_url("http://idp.example.com/introspect")
3199 .build(),
3200 );
3201 let err = cfg
3202 .validate()
3203 .expect_err("http introspection_url must be rejected");
3204 assert!(err.to_string().contains("oauth.proxy.introspection_url"));
3205
3206 let mut cfg = validation_https_config();
3207 cfg.proxy = Some(
3208 OAuthProxyConfig::builder(
3209 "https://idp.example.com/authorize",
3210 "https://idp.example.com/token",
3211 "client",
3212 )
3213 .revocation_url("http://idp.example.com/revoke")
3214 .build(),
3215 );
3216 let err = cfg
3217 .validate()
3218 .expect_err("http revocation_url must be rejected");
3219 assert!(err.to_string().contains("oauth.proxy.revocation_url"));
3220 }
3221
3222 #[test]
3225 fn validate_rejects_exposed_admin_endpoints_without_auth() {
3226 let mut cfg = validation_https_config();
3227 cfg.proxy = Some(
3228 OAuthProxyConfig::builder(
3229 "https://idp.example.com/authorize",
3230 "https://idp.example.com/token",
3231 "client",
3232 )
3233 .introspection_url("https://idp.example.com/introspect")
3234 .expose_admin_endpoints(true)
3235 .build(),
3236 );
3237 let err = cfg
3238 .validate()
3239 .expect_err("expose_admin_endpoints without auth must fail");
3240 let msg = err.to_string();
3241 assert!(msg.contains("require_auth_on_admin_endpoints"), "{msg}");
3242 assert!(
3243 msg.contains("allow_unauthenticated_admin_endpoints"),
3244 "{msg}"
3245 );
3246 }
3247
3248 #[test]
3249 fn validate_accepts_exposed_admin_endpoints_with_auth() {
3250 let mut cfg = validation_https_config();
3251 cfg.proxy = Some(
3252 OAuthProxyConfig::builder(
3253 "https://idp.example.com/authorize",
3254 "https://idp.example.com/token",
3255 "client",
3256 )
3257 .introspection_url("https://idp.example.com/introspect")
3258 .expose_admin_endpoints(true)
3259 .require_auth_on_admin_endpoints(true)
3260 .build(),
3261 );
3262 cfg.validate()
3263 .expect("authed admin endpoints must validate");
3264 }
3265
3266 #[test]
3267 fn validate_accepts_exposed_admin_endpoints_with_explicit_unauth_optout() {
3268 let mut cfg = validation_https_config();
3269 cfg.proxy = Some(
3270 OAuthProxyConfig::builder(
3271 "https://idp.example.com/authorize",
3272 "https://idp.example.com/token",
3273 "client",
3274 )
3275 .introspection_url("https://idp.example.com/introspect")
3276 .expose_admin_endpoints(true)
3277 .allow_unauthenticated_admin_endpoints(true)
3278 .build(),
3279 );
3280 cfg.validate()
3281 .expect("explicit unauth opt-out must validate");
3282 }
3283
3284 #[test]
3285 fn validate_accepts_unexposed_admin_endpoints_without_auth() {
3286 let mut cfg = validation_https_config();
3289 cfg.proxy = Some(
3290 OAuthProxyConfig::builder(
3291 "https://idp.example.com/authorize",
3292 "https://idp.example.com/token",
3293 "client",
3294 )
3295 .introspection_url("https://idp.example.com/introspect")
3296 .build(),
3297 );
3298 cfg.validate()
3299 .expect("unexposed admin endpoints must validate");
3300 }
3301
3302 #[test]
3303 fn validate_rejects_http_token_exchange_url() {
3304 let mut cfg = validation_https_config();
3305 cfg.token_exchange = Some(TokenExchangeConfig::new(
3306 "http://idp.example.com/token".into(), "client".into(),
3308 None,
3309 None,
3310 "downstream".into(),
3311 ));
3312 let err = cfg
3313 .validate()
3314 .expect_err("http token_exchange.token_url must be rejected");
3315 assert!(
3316 err.to_string().contains("oauth.token_exchange.token_url"),
3317 "error must reference token_exchange.token_url; got {err}"
3318 );
3319 }
3320
3321 #[test]
3322 fn validate_rejects_unparseable_url() {
3323 let mut cfg = validation_https_config();
3324 cfg.jwks_uri = "not a url".into();
3325 let err = cfg
3326 .validate()
3327 .expect_err("unparseable URL must be rejected");
3328 assert!(err.to_string().contains("invalid URL"));
3329 }
3330
3331 #[test]
3332 fn validate_rejects_non_http_scheme() {
3333 let mut cfg = validation_https_config();
3334 cfg.jwks_uri = "file:///etc/passwd".into();
3335 let err = cfg.validate().expect_err("file:// scheme must be rejected");
3336 let msg = err.to_string();
3337 assert!(
3338 msg.contains("must use https scheme") && msg.contains("file"),
3339 "error must reject non-http(s) schemes; got {msg:?}"
3340 );
3341 }
3342
3343 #[test]
3344 fn validate_accepts_http_with_escape_hatch() {
3345 let mut cfg = OAuthConfig::builder(
3350 "http://auth.local",
3351 "mcp",
3352 "http://auth.local/.well-known/jwks.json",
3353 )
3354 .allow_http_oauth_urls(true)
3355 .build();
3356 cfg.proxy = Some(
3357 OAuthProxyConfig::builder(
3358 "http://idp.local/authorize",
3359 "http://idp.local/token",
3360 "client",
3361 )
3362 .introspection_url("http://idp.local/introspect")
3363 .revocation_url("http://idp.local/revoke")
3364 .build(),
3365 );
3366 cfg.token_exchange = Some(TokenExchangeConfig::new(
3367 "http://idp.local/token".into(),
3368 "client".into(),
3369 Some(secrecy::SecretString::new("dev-secret".into())),
3370 None,
3371 "downstream".into(),
3372 ));
3373 cfg.validate()
3374 .expect("escape hatch must permit http on all URL fields");
3375 }
3376
3377 #[test]
3378 fn validate_with_escape_hatch_still_rejects_unparseable() {
3379 let mut cfg = validation_https_config();
3382 cfg.allow_http_oauth_urls = true;
3383 cfg.jwks_uri = "::not-a-url::".into();
3384 cfg.validate()
3385 .expect_err("escape hatch must NOT bypass URL parsing");
3386 }
3387
3388 #[tokio::test]
3389 async fn jwks_cache_rejects_redirect_downgrade_to_http() {
3390 rustls::crypto::ring::default_provider()
3405 .install_default()
3406 .ok();
3407
3408 let policy = reqwest::redirect::Policy::custom(|attempt| {
3409 if attempt.url().scheme() != "https" {
3410 attempt.error("redirect to non-HTTPS URL refused")
3411 } else if attempt.previous().len() >= 2 {
3412 attempt.error("too many redirects (max 2)")
3413 } else {
3414 attempt.follow()
3415 }
3416 });
3417 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = Arc::new(AtomicBool::new(true));
3424 let allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
3425 let resolver: Arc<dyn reqwest::dns::Resolve> = Arc::new(
3426 crate::ssrf_resolver::SsrfScreeningResolver::new(Arc::clone(&allowlist), test_bypass),
3427 );
3428 let client = reqwest::Client::builder()
3429 .no_proxy()
3430 .dns_resolver(Arc::clone(&resolver))
3431 .timeout(Duration::from_secs(5))
3432 .connect_timeout(Duration::from_secs(3))
3433 .redirect(policy)
3434 .build()
3435 .expect("test client builds");
3436
3437 let mock = wiremock::MockServer::start().await;
3438 wiremock::Mock::given(wiremock::matchers::method("GET"))
3439 .and(wiremock::matchers::path("/jwks.json"))
3440 .respond_with(
3441 wiremock::ResponseTemplate::new(302)
3442 .insert_header("location", "http://example.invalid/jwks.json"),
3443 )
3444 .mount(&mock)
3445 .await;
3446
3447 let url = format!("{}/jwks.json", mock.uri());
3456 let err = client
3457 .get(&url)
3458 .send()
3459 .await
3460 .expect_err("redirect policy must reject scheme downgrade");
3461 let chain = format!("{err:#}");
3462 assert!(
3463 chain.contains("redirect to non-HTTPS URL refused")
3464 || chain.to_lowercase().contains("redirect"),
3465 "error must surface redirect-policy rejection; got {chain:?}"
3466 );
3467 }
3468
3469 use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
3474
3475 fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
3477 let mut rng = rsa::rand_core::OsRng;
3478 let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keypair generation");
3479 let private_pem = private_key
3480 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
3481 .expect("PKCS8 PEM export")
3482 .to_string();
3483
3484 let public_key = private_key.to_public_key();
3485 let n = URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be());
3486 let e = URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be());
3487
3488 let jwks = serde_json::json!({
3489 "keys": [{
3490 "kty": "RSA",
3491 "use": "sig",
3492 "alg": "RS256",
3493 "kid": kid,
3494 "n": n,
3495 "e": e
3496 }]
3497 });
3498
3499 (private_pem, jwks)
3500 }
3501
3502 fn mint_token(
3504 private_pem: &str,
3505 kid: &str,
3506 issuer: &str,
3507 audience: &str,
3508 subject: &str,
3509 scope: &str,
3510 ) -> String {
3511 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3512 .expect("encoding key from PEM");
3513 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3514 header.kid = Some(kid.into());
3515
3516 let now = jsonwebtoken::get_current_timestamp();
3517 let claims = serde_json::json!({
3518 "iss": issuer,
3519 "aud": audience,
3520 "sub": subject,
3521 "scope": scope,
3522 "exp": now + 3600,
3523 "iat": now,
3524 });
3525
3526 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3527 }
3528
3529 fn test_config(jwks_uri: &str) -> OAuthConfig {
3530 OAuthConfig {
3531 issuer: "https://auth.test.local".into(),
3532 audience: "https://mcp.test.local/mcp".into(),
3533 jwks_uri: jwks_uri.into(),
3534 scopes: vec![
3535 ScopeMapping {
3536 scope: "mcp:read".into(),
3537 role: "viewer".into(),
3538 },
3539 ScopeMapping {
3540 scope: "mcp:admin".into(),
3541 role: "ops".into(),
3542 },
3543 ],
3544 role_claim: None,
3545 role_mappings: vec![],
3546 jwks_cache_ttl: "5m".into(),
3547 proxy: None,
3548 token_exchange: None,
3549 ca_cert_path: None,
3550 allow_http_oauth_urls: true,
3551 max_jwks_keys: default_max_jwks_keys(),
3552 #[allow(
3553 deprecated,
3554 reason = "test fixture: explicit value for the deprecated field"
3555 )]
3556 strict_audience_validation: false,
3557 audience_validation_mode: None,
3558 jwks_max_response_bytes: default_jwks_max_bytes(),
3559 ssrf_allowlist: None,
3560 }
3561 }
3562
3563 fn test_cache(config: &OAuthConfig) -> JwksCache {
3564 JwksCache::new(config).unwrap().__test_allow_loopback_ssrf()
3565 }
3566
3567 #[tokio::test]
3568 async fn valid_jwt_returns_identity() {
3569 let kid = "test-key-1";
3570 let (pem, jwks) = generate_test_keypair(kid);
3571
3572 let mock_server = wiremock::MockServer::start().await;
3573 wiremock::Mock::given(wiremock::matchers::method("GET"))
3574 .and(wiremock::matchers::path("/jwks.json"))
3575 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3576 .mount(&mock_server)
3577 .await;
3578
3579 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3580 let config = test_config(&jwks_uri);
3581 let cache = test_cache(&config);
3582
3583 let token = mint_token(
3584 &pem,
3585 kid,
3586 "https://auth.test.local",
3587 "https://mcp.test.local/mcp",
3588 "ci-bot",
3589 "mcp:read mcp:other",
3590 );
3591
3592 let identity = cache.validate_token(&token).await;
3593 assert!(identity.is_some(), "valid JWT should authenticate");
3594 let id = identity.unwrap();
3595 assert_eq!(id.name, "ci-bot");
3596 assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
3598 }
3599
3600 #[tokio::test]
3601 async fn wrong_issuer_rejected() {
3602 let kid = "test-key-2";
3603 let (pem, jwks) = generate_test_keypair(kid);
3604
3605 let mock_server = wiremock::MockServer::start().await;
3606 wiremock::Mock::given(wiremock::matchers::method("GET"))
3607 .and(wiremock::matchers::path("/jwks.json"))
3608 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3609 .mount(&mock_server)
3610 .await;
3611
3612 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3613 let config = test_config(&jwks_uri);
3614 let cache = test_cache(&config);
3615
3616 let token = mint_token(
3617 &pem,
3618 kid,
3619 "https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
3621 "attacker",
3622 "mcp:admin",
3623 );
3624
3625 assert!(cache.validate_token(&token).await.is_none());
3626 }
3627
3628 #[tokio::test]
3629 async fn wrong_audience_rejected() {
3630 let kid = "test-key-3";
3631 let (pem, jwks) = generate_test_keypair(kid);
3632
3633 let mock_server = wiremock::MockServer::start().await;
3634 wiremock::Mock::given(wiremock::matchers::method("GET"))
3635 .and(wiremock::matchers::path("/jwks.json"))
3636 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3637 .mount(&mock_server)
3638 .await;
3639
3640 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3641 let config = test_config(&jwks_uri);
3642 let cache = test_cache(&config);
3643
3644 let token = mint_token(
3645 &pem,
3646 kid,
3647 "https://auth.test.local",
3648 "https://wrong-audience.example.com", "attacker",
3650 "mcp:admin",
3651 );
3652
3653 assert!(cache.validate_token(&token).await.is_none());
3654 }
3655
3656 #[tokio::test]
3657 async fn expired_jwt_rejected() {
3658 let kid = "test-key-4";
3659 let (pem, jwks) = generate_test_keypair(kid);
3660
3661 let mock_server = wiremock::MockServer::start().await;
3662 wiremock::Mock::given(wiremock::matchers::method("GET"))
3663 .and(wiremock::matchers::path("/jwks.json"))
3664 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3665 .mount(&mock_server)
3666 .await;
3667
3668 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3669 let config = test_config(&jwks_uri);
3670 let cache = test_cache(&config);
3671
3672 let encoding_key =
3674 jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
3675 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3676 header.kid = Some(kid.into());
3677 let now = jsonwebtoken::get_current_timestamp();
3678 let claims = serde_json::json!({
3679 "iss": "https://auth.test.local",
3680 "aud": "https://mcp.test.local/mcp",
3681 "sub": "expired-bot",
3682 "scope": "mcp:read",
3683 "exp": now - 120,
3684 "iat": now - 3720,
3685 });
3686 let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
3687
3688 assert!(cache.validate_token(&token).await.is_none());
3689 }
3690
3691 #[tokio::test]
3692 async fn no_matching_scope_rejected() {
3693 let kid = "test-key-5";
3694 let (pem, jwks) = generate_test_keypair(kid);
3695
3696 let mock_server = wiremock::MockServer::start().await;
3697 wiremock::Mock::given(wiremock::matchers::method("GET"))
3698 .and(wiremock::matchers::path("/jwks.json"))
3699 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3700 .mount(&mock_server)
3701 .await;
3702
3703 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3704 let config = test_config(&jwks_uri);
3705 let cache = test_cache(&config);
3706
3707 let token = mint_token(
3708 &pem,
3709 kid,
3710 "https://auth.test.local",
3711 "https://mcp.test.local/mcp",
3712 "limited-bot",
3713 "some:other:scope", );
3715
3716 assert!(cache.validate_token(&token).await.is_none());
3717 }
3718
3719 #[tokio::test]
3720 async fn wrong_signing_key_rejected() {
3721 let kid = "test-key-6";
3722 let (_pem, jwks) = generate_test_keypair(kid);
3723
3724 let (attacker_pem, _) = generate_test_keypair(kid);
3726
3727 let mock_server = wiremock::MockServer::start().await;
3728 wiremock::Mock::given(wiremock::matchers::method("GET"))
3729 .and(wiremock::matchers::path("/jwks.json"))
3730 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3731 .mount(&mock_server)
3732 .await;
3733
3734 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3735 let config = test_config(&jwks_uri);
3736 let cache = test_cache(&config);
3737
3738 let token = mint_token(
3740 &attacker_pem,
3741 kid,
3742 "https://auth.test.local",
3743 "https://mcp.test.local/mcp",
3744 "attacker",
3745 "mcp:admin",
3746 );
3747
3748 assert!(cache.validate_token(&token).await.is_none());
3749 }
3750
3751 #[tokio::test]
3752 async fn admin_scope_maps_to_ops_role() {
3753 let kid = "test-key-7";
3754 let (pem, jwks) = generate_test_keypair(kid);
3755
3756 let mock_server = wiremock::MockServer::start().await;
3757 wiremock::Mock::given(wiremock::matchers::method("GET"))
3758 .and(wiremock::matchers::path("/jwks.json"))
3759 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3760 .mount(&mock_server)
3761 .await;
3762
3763 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3764 let config = test_config(&jwks_uri);
3765 let cache = test_cache(&config);
3766
3767 let token = mint_token(
3768 &pem,
3769 kid,
3770 "https://auth.test.local",
3771 "https://mcp.test.local/mcp",
3772 "admin-bot",
3773 "mcp:admin",
3774 );
3775
3776 let id = cache
3777 .validate_token(&token)
3778 .await
3779 .expect("should authenticate");
3780 assert_eq!(id.role, "ops");
3781 assert_eq!(id.name, "admin-bot");
3782 }
3783
3784 #[tokio::test]
3785 async fn jwks_server_down_returns_none() {
3786 let config = test_config("http://127.0.0.1:1/jwks.json");
3788 let cache = test_cache(&config);
3789
3790 let kid = "orphan-key";
3791 let (pem, _) = generate_test_keypair(kid);
3792 let token = mint_token(
3793 &pem,
3794 kid,
3795 "https://auth.test.local",
3796 "https://mcp.test.local/mcp",
3797 "bot",
3798 "mcp:read",
3799 );
3800
3801 assert!(cache.validate_token(&token).await.is_none());
3802 }
3803
3804 #[test]
3809 fn resolve_claim_path_flat_string() {
3810 let mut extra = HashMap::new();
3811 extra.insert(
3812 "scope".into(),
3813 serde_json::Value::String("mcp:read mcp:admin".into()),
3814 );
3815 let values = resolve_claim_path(&extra, "scope");
3816 assert_eq!(values, vec!["mcp:read", "mcp:admin"]);
3817 }
3818
3819 #[test]
3820 fn resolve_claim_path_flat_array() {
3821 let mut extra = HashMap::new();
3822 extra.insert(
3823 "roles".into(),
3824 serde_json::json!(["mcp-admin", "mcp-viewer"]),
3825 );
3826 let values = resolve_claim_path(&extra, "roles");
3827 assert_eq!(values, vec!["mcp-admin", "mcp-viewer"]);
3828 }
3829
3830 #[test]
3831 fn resolve_claim_path_nested_keycloak() {
3832 let mut extra = HashMap::new();
3833 extra.insert(
3834 "realm_access".into(),
3835 serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]}),
3836 );
3837 let values = resolve_claim_path(&extra, "realm_access.roles");
3838 assert_eq!(values, vec!["uma_authorization", "mcp-admin"]);
3839 }
3840
3841 #[test]
3842 fn resolve_claim_path_missing_returns_empty() {
3843 let extra = HashMap::new();
3844 assert!(resolve_claim_path(&extra, "nonexistent.path").is_empty());
3845 }
3846
3847 #[test]
3848 fn resolve_claim_path_numeric_leaf_returns_empty() {
3849 let mut extra = HashMap::new();
3850 extra.insert("count".into(), serde_json::json!(42));
3851 assert!(resolve_claim_path(&extra, "count").is_empty());
3852 }
3853
3854 fn make_claims(json: serde_json::Value) -> Claims {
3855 serde_json::from_value(json).expect("test claims must deserialize")
3856 }
3857
3858 #[test]
3859 fn first_class_scope_claim_splits_on_whitespace() {
3860 let claims = make_claims(serde_json::json!({
3861 "iss": "https://issuer.example.com",
3862 "exp": 9_999_999_999_u64,
3863 "scope": "read write admin",
3864 }));
3865 let values = first_class_claim_values(&claims, "scope");
3866 assert_eq!(values, vec!["read", "write", "admin"]);
3867 }
3868
3869 #[test]
3870 fn first_class_sub_claim_returns_single_value() {
3871 let claims = make_claims(serde_json::json!({
3872 "iss": "https://issuer.example.com",
3873 "exp": 9_999_999_999_u64,
3874 "sub": "service-account-orders",
3875 }));
3876 let values = first_class_claim_values(&claims, "sub");
3877 assert_eq!(values, vec!["service-account-orders"]);
3878 }
3879
3880 #[test]
3881 fn first_class_aud_claim_returns_every_audience() {
3882 let claims = make_claims(serde_json::json!({
3883 "iss": "https://issuer.example.com",
3884 "exp": 9_999_999_999_u64,
3885 "aud": ["api-a", "api-b"],
3886 }));
3887 let values = first_class_claim_values(&claims, "aud");
3888 assert_eq!(values, vec!["api-a", "api-b"]);
3889 }
3890
3891 #[test]
3892 fn first_class_unknown_path_returns_empty() {
3893 let claims = make_claims(serde_json::json!({
3894 "iss": "https://issuer.example.com",
3895 "exp": 9_999_999_999_u64,
3896 }));
3897 assert!(first_class_claim_values(&claims, "realm_access.roles").is_empty());
3898 }
3899
3900 fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
3906 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3907 .expect("encoding key from PEM");
3908 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3909 header.kid = Some(kid.into());
3910 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3911 }
3912
3913 fn test_config_with_role_claim(
3914 jwks_uri: &str,
3915 role_claim: &str,
3916 role_mappings: Vec<RoleMapping>,
3917 ) -> OAuthConfig {
3918 OAuthConfig {
3919 issuer: "https://auth.test.local".into(),
3920 audience: "https://mcp.test.local/mcp".into(),
3921 jwks_uri: jwks_uri.into(),
3922 scopes: vec![],
3923 role_claim: Some(role_claim.into()),
3924 role_mappings,
3925 jwks_cache_ttl: "5m".into(),
3926 proxy: None,
3927 token_exchange: None,
3928 ca_cert_path: None,
3929 allow_http_oauth_urls: true,
3930 max_jwks_keys: default_max_jwks_keys(),
3931 #[allow(
3932 deprecated,
3933 reason = "test fixture: explicit value for the deprecated field"
3934 )]
3935 strict_audience_validation: false,
3936 audience_validation_mode: None,
3937 jwks_max_response_bytes: default_jwks_max_bytes(),
3938 ssrf_allowlist: None,
3939 }
3940 }
3941
3942 #[tokio::test]
3943 async fn screen_oauth_target_rejects_literal_ip() {
3944 let err = screen_oauth_target(
3945 "https://127.0.0.1/jwks.json",
3946 false,
3947 &crate::ssrf::CompiledSsrfAllowlist::default(),
3948 )
3949 .await
3950 .expect_err("literal IPs must be rejected");
3951 let msg = err.to_string();
3952 assert!(msg.contains("literal IPv4 addresses are forbidden"));
3953 }
3954
3955 #[tokio::test]
3956 async fn screen_oauth_target_rejects_private_dns_resolution() {
3957 let err = screen_oauth_target(
3958 "https://localhost/jwks.json",
3959 false,
3960 &crate::ssrf::CompiledSsrfAllowlist::default(),
3961 )
3962 .await
3963 .expect_err("localhost resolution must be rejected");
3964 let msg = err.to_string();
3965 assert!(
3966 msg.contains("blocked IP") && msg.contains("loopback"),
3967 "got {msg:?}"
3968 );
3969 }
3970
3971 #[tokio::test]
3972 async fn screen_oauth_target_rejects_literal_ip_even_with_allow_http() {
3973 let err = screen_oauth_target(
3974 "http://127.0.0.1/jwks.json",
3975 true,
3976 &crate::ssrf::CompiledSsrfAllowlist::default(),
3977 )
3978 .await
3979 .expect_err("literal IPs must still be rejected when http is allowed");
3980 let msg = err.to_string();
3981 assert!(msg.contains("literal IPv4 addresses are forbidden"));
3982 }
3983
3984 #[tokio::test]
3985 async fn screen_oauth_target_rejects_private_dns_even_with_allow_http() {
3986 let err = screen_oauth_target(
3987 "http://localhost/jwks.json",
3988 true,
3989 &crate::ssrf::CompiledSsrfAllowlist::default(),
3990 )
3991 .await
3992 .expect_err("private DNS resolution must still be rejected when http is allowed");
3993 let msg = err.to_string();
3994 assert!(
3995 msg.contains("blocked IP") && msg.contains("loopback"),
3996 "got {msg:?}"
3997 );
3998 }
3999
4000 #[tokio::test]
4001 async fn screen_oauth_target_allows_public_hostname() {
4002 screen_oauth_target(
4003 "https://example.com/.well-known/jwks.json",
4004 false,
4005 &crate::ssrf::CompiledSsrfAllowlist::default(),
4006 )
4007 .await
4008 .expect("public hostname should pass screening");
4009 }
4010
4011 fn make_allowlist(hosts: &[&str], cidrs: &[&str]) -> crate::ssrf::CompiledSsrfAllowlist {
4017 let raw = OAuthSsrfAllowlist {
4018 hosts: hosts.iter().map(|s| (*s).to_string()).collect(),
4019 cidrs: cidrs.iter().map(|s| (*s).to_string()).collect(),
4020 };
4021 compile_oauth_ssrf_allowlist(&raw).expect("test allowlist compiles")
4022 }
4023
4024 #[test]
4025 fn compile_oauth_ssrf_allowlist_lowercases_and_dedupes_hosts() {
4026 let raw = OAuthSsrfAllowlist {
4027 hosts: vec!["RHBK.ops.example.com".into(), "rhbk.ops.example.com".into()],
4028 cidrs: vec![],
4029 };
4030 let compiled = compile_oauth_ssrf_allowlist(&raw).expect("compiles");
4031 assert_eq!(compiled.host_count(), 1);
4032 assert!(compiled.host_allowed("rhbk.ops.example.com"));
4033 assert!(compiled.host_allowed("RHBK.OPS.EXAMPLE.COM"));
4034 }
4035
4036 #[test]
4037 fn compile_oauth_ssrf_allowlist_rejects_literal_ip_in_hosts() {
4038 let raw = OAuthSsrfAllowlist {
4039 hosts: vec!["10.0.0.1".into()],
4040 cidrs: vec![],
4041 };
4042 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("literal IP in hosts");
4043 assert!(err.contains("literal IPs are forbidden"), "got {err:?}");
4044 }
4045
4046 #[test]
4047 fn compile_oauth_ssrf_allowlist_rejects_host_with_port() {
4048 let raw = OAuthSsrfAllowlist {
4049 hosts: vec!["rhbk.ops.example.com:8443".into()],
4050 cidrs: vec![],
4051 };
4052 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("host:port");
4053 assert!(err.contains("must be a bare DNS hostname"), "got {err:?}");
4054 }
4055
4056 #[test]
4057 fn compile_oauth_ssrf_allowlist_rejects_invalid_cidr() {
4058 let raw = OAuthSsrfAllowlist {
4059 hosts: vec![],
4060 cidrs: vec!["not-a-cidr".into()],
4061 };
4062 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("invalid CIDR");
4063 assert!(err.contains("oauth.ssrf_allowlist.cidrs[0]"), "got {err:?}");
4064 }
4065
4066 #[test]
4067 fn validate_rejects_misconfigured_allowlist() {
4068 let mut cfg = OAuthConfig::builder(
4069 "https://auth.example.com/",
4070 "mcp",
4071 "https://auth.example.com/jwks.json",
4072 )
4073 .build();
4074 cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
4075 hosts: vec!["10.0.0.1".into()],
4076 cidrs: vec![],
4077 });
4078 let err = cfg
4079 .validate()
4080 .expect_err("literal IP host must be rejected");
4081 assert!(
4082 err.to_string().contains("oauth.ssrf_allowlist"),
4083 "got {err}"
4084 );
4085 }
4086
4087 #[tokio::test]
4088 async fn screen_oauth_target_with_allowlist_emits_helpful_error() {
4089 let allow = make_allowlist(&["other.example.com"], &["10.0.0.0/8"]);
4093 let err = screen_oauth_target("https://localhost/jwks.json", false, &allow)
4094 .await
4095 .expect_err("loopback must still be blocked when not in allowlist");
4096 let msg = err.to_string();
4097 assert!(msg.contains("OAuth target blocked"), "got {msg:?}");
4098 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4099 assert!(msg.contains("SECURITY.md"), "got {msg:?}");
4100 }
4101
4102 #[tokio::test]
4103 async fn screen_oauth_target_empty_allowlist_uses_legacy_message() {
4104 let err = screen_oauth_target(
4107 "https://localhost/jwks.json",
4108 false,
4109 &crate::ssrf::CompiledSsrfAllowlist::default(),
4110 )
4111 .await
4112 .expect_err("loopback rejection");
4113 let msg = err.to_string();
4114 assert!(msg.contains("blocked IP"), "got {msg:?}");
4115 assert!(msg.contains("loopback"), "got {msg:?}");
4116 assert!(!msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4118 }
4119
4120 #[tokio::test]
4121 async fn screen_oauth_target_allows_loopback_when_host_allowlisted() {
4122 let allow = make_allowlist(&["localhost"], &[]);
4124 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4125 .await
4126 .expect("allowlisted host must pass");
4127 }
4128
4129 #[tokio::test]
4130 async fn screen_oauth_target_allows_loopback_when_cidr_allowlisted() {
4131 let allow = make_allowlist(&[], &["127.0.0.0/8", "::1/128"]);
4134 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4135 .await
4136 .expect("allowlisted CIDR must pass");
4137 }
4138
4139 #[tokio::test]
4140 async fn jwks_cache_rejects_misconfigured_allowlist_at_startup() {
4141 let mut cfg = OAuthConfig::builder(
4142 "https://auth.example.com/",
4143 "mcp",
4144 "https://auth.example.com/jwks.json",
4145 )
4146 .build();
4147 cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
4148 hosts: vec![],
4149 cidrs: vec!["bad-cidr".into()],
4150 });
4151 let Err(err) = JwksCache::new(&cfg) else {
4152 panic!("invalid CIDR must fail JwksCache::new")
4153 };
4154 let msg = err.to_string();
4155 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4156 }
4157
4158 #[tokio::test]
4159 async fn jwks_cache_new_invalid_ttl_is_err() {
4160 let cfg = OAuthConfig::builder(
4163 "https://auth.example.com/",
4164 "mcp",
4165 "https://auth.example.com/jwks.json",
4166 )
4167 .jwks_cache_ttl("not-a-duration")
4168 .build();
4169 let Err(err) = JwksCache::new(&cfg) else {
4170 panic!("invalid jwks_cache_ttl must fail JwksCache::new")
4171 };
4172 let msg = err.to_string();
4173 assert!(msg.contains("jwks_cache_ttl"), "got {msg:?}");
4174 }
4175
4176 #[tokio::test]
4177 async fn audience_falls_back_to_azp_by_default() {
4178 let kid = "test-audience-azp-default";
4179 let (pem, jwks) = generate_test_keypair(kid);
4180
4181 let mock_server = wiremock::MockServer::start().await;
4182 wiremock::Mock::given(wiremock::matchers::method("GET"))
4183 .and(wiremock::matchers::path("/jwks.json"))
4184 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4185 .mount(&mock_server)
4186 .await;
4187
4188 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4189 let config = test_config(&jwks_uri);
4190 let cache = test_cache(&config);
4191
4192 let now = jsonwebtoken::get_current_timestamp();
4193 let token = mint_token_with_claims(
4194 &pem,
4195 kid,
4196 &serde_json::json!({
4197 "iss": "https://auth.test.local",
4198 "aud": "https://some-other-resource.example.com",
4199 "azp": "https://mcp.test.local/mcp",
4200 "sub": "compat-client",
4201 "scope": "mcp:read",
4202 "exp": now + 3600,
4203 "iat": now,
4204 }),
4205 );
4206
4207 let identity = cache
4208 .validate_token_with_reason(&token)
4209 .await
4210 .expect("azp fallback should remain enabled by default");
4211 assert_eq!(identity.role, "viewer");
4212 }
4213
4214 #[tokio::test]
4215 async fn strict_audience_validation_rejects_azp_only_match() {
4216 let kid = "test-audience-azp-strict";
4217 let (pem, jwks) = generate_test_keypair(kid);
4218
4219 let mock_server = wiremock::MockServer::start().await;
4220 wiremock::Mock::given(wiremock::matchers::method("GET"))
4221 .and(wiremock::matchers::path("/jwks.json"))
4222 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4223 .mount(&mock_server)
4224 .await;
4225
4226 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4227 let mut config = test_config(&jwks_uri);
4228 #[allow(deprecated, reason = "covers the legacy bool resolution path")]
4229 {
4230 config.strict_audience_validation = true;
4231 }
4232 let cache = test_cache(&config);
4233
4234 let now = jsonwebtoken::get_current_timestamp();
4235 let token = mint_token_with_claims(
4236 &pem,
4237 kid,
4238 &serde_json::json!({
4239 "iss": "https://auth.test.local",
4240 "aud": "https://some-other-resource.example.com",
4241 "azp": "https://mcp.test.local/mcp",
4242 "sub": "strict-client",
4243 "scope": "mcp:read",
4244 "exp": now + 3600,
4245 "iat": now,
4246 }),
4247 );
4248
4249 let failure = cache
4250 .validate_token_with_reason(&token)
4251 .await
4252 .expect_err("strict audience validation must ignore azp fallback");
4253 assert_eq!(failure, JwtValidationFailure::Invalid);
4254 }
4255
4256 #[tokio::test]
4257 async fn warn_mode_accepts_azp_only_match_and_warns_once() {
4258 let kid = "test-audience-warn-mode";
4259 let (pem, jwks) = generate_test_keypair(kid);
4260
4261 let mock_server = wiremock::MockServer::start().await;
4262 wiremock::Mock::given(wiremock::matchers::method("GET"))
4263 .and(wiremock::matchers::path("/jwks.json"))
4264 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4265 .mount(&mock_server)
4266 .await;
4267
4268 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4269 let mut config = test_config(&jwks_uri);
4270 config.audience_validation_mode = Some(AudienceValidationMode::Warn);
4271 let cache = test_cache(&config);
4272
4273 let now = jsonwebtoken::get_current_timestamp();
4274 let claims = serde_json::json!({
4275 "iss": "https://auth.test.local",
4276 "aud": "https://some-other-resource.example.com",
4277 "azp": "https://mcp.test.local/mcp",
4278 "sub": "warn-client",
4279 "scope": "mcp:read",
4280 "exp": now + 3600,
4281 "iat": now,
4282 });
4283 let token = mint_token_with_claims(&pem, kid, &claims);
4284
4285 let identity = cache
4286 .validate_token_with_reason(&token)
4287 .await
4288 .expect("warn mode must accept azp-only match");
4289 assert_eq!(identity.role, "viewer");
4290 assert!(
4291 cache.azp_fallback_warned.load(Ordering::Relaxed),
4292 "warn-once flag should be set after first azp-only match"
4293 );
4294
4295 let token2 = mint_token_with_claims(&pem, kid, &claims);
4296 cache
4297 .validate_token_with_reason(&token2)
4298 .await
4299 .expect("warn mode must continue accepting subsequent matches");
4300 assert!(
4301 cache.azp_fallback_warned.load(Ordering::Relaxed),
4302 "warn-once flag must remain set; the assertion guards against accidental clearing"
4303 );
4304 }
4305
4306 #[tokio::test]
4307 async fn permissive_mode_accepts_azp_only_match_silently() {
4308 let kid = "test-audience-permissive-mode";
4309 let (pem, jwks) = generate_test_keypair(kid);
4310
4311 let mock_server = wiremock::MockServer::start().await;
4312 wiremock::Mock::given(wiremock::matchers::method("GET"))
4313 .and(wiremock::matchers::path("/jwks.json"))
4314 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4315 .mount(&mock_server)
4316 .await;
4317
4318 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4319 let mut config = test_config(&jwks_uri);
4320 config.audience_validation_mode = Some(AudienceValidationMode::Permissive);
4321 let cache = test_cache(&config);
4322
4323 let now = jsonwebtoken::get_current_timestamp();
4324 let token = mint_token_with_claims(
4325 &pem,
4326 kid,
4327 &serde_json::json!({
4328 "iss": "https://auth.test.local",
4329 "aud": "https://some-other-resource.example.com",
4330 "azp": "https://mcp.test.local/mcp",
4331 "sub": "permissive-client",
4332 "scope": "mcp:read",
4333 "exp": now + 3600,
4334 "iat": now,
4335 }),
4336 );
4337
4338 cache
4339 .validate_token_with_reason(&token)
4340 .await
4341 .expect("permissive mode must accept azp-only match");
4342 assert!(
4343 !cache.azp_fallback_warned.load(Ordering::Relaxed),
4344 "permissive mode must not flip the warn-once flag"
4345 );
4346 }
4347
/// An explicitly set `audience_validation_mode` must win over the deprecated
/// `strict_audience_validation` bool, whichever value the bool holds.
#[test]
fn audience_validation_mode_overrides_legacy_bool() {
    // Builds a default config with both knobs set and returns the mode that
    // `effective_audience_validation_mode` resolves to.
    let resolve = |legacy_strict: bool, explicit: AudienceValidationMode| {
        let mut config = OAuthConfig::default();
        #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
        {
            config.strict_audience_validation = legacy_strict;
        }
        config.audience_validation_mode = Some(explicit);
        config.effective_audience_validation_mode()
    };

    assert_eq!(
        resolve(false, AudienceValidationMode::Strict),
        AudienceValidationMode::Strict,
        "explicit mode must override legacy false"
    );
    assert_eq!(
        resolve(true, AudienceValidationMode::Permissive),
        AudienceValidationMode::Permissive,
        "explicit mode must override legacy true"
    );
}
4374
/// A default `OAuthConfig` must resolve its effective audience validation mode
/// to `AudienceValidationMode::Warn`.
#[test]
fn audience_validation_mode_default_is_warn_when_unset() {
    let resolved = OAuthConfig::default().effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Warn,
        "unset mode + unset bool must resolve to Warn (the new default)"
    );
}
4384
/// With no explicit mode set on a default config, the deprecated
/// `strict_audience_validation = true` must resolve to
/// `AudienceValidationMode::Strict`.
#[test]
fn audience_validation_legacy_bool_true_resolves_to_strict() {
    let mut legacy_config = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the legacy bool resolution path")]
    {
        legacy_config.strict_audience_validation = true;
    }

    let resolved = legacy_config.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Strict,
        "legacy bool=true must resolve to Strict for backward compat"
    );
}
4398
/// Cloneable handle to a shared in-memory byte buffer that collects log
/// output so tests can inspect it afterwards.
#[derive(Clone, Default)]
struct CapturedLogs(Arc<std::sync::Mutex<Vec<u8>>>);

impl CapturedLogs {
    /// Returns everything captured so far as text.
    ///
    /// Invalid UTF-8 sequences are replaced with U+FFFD rather than discarding
    /// the whole buffer, and a poisoned mutex is recovered rather than treated
    /// as "nothing captured". Either condition would otherwise make a log
    /// assertion fail with an empty string and no hint as to why.
    fn contents(&self) -> String {
        let guard = self
            .0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        String::from_utf8_lossy(&guard).into_owned()
    }
}
4408
/// `std::io::Write` adapter that appends every written chunk to a shared
/// in-memory byte buffer.
struct CapturedLogsWriter(Arc<std::sync::Mutex<Vec<u8>>>);

impl std::io::Write for CapturedLogsWriter {
    /// Appends `buf` to the shared buffer and reports all bytes as written.
    ///
    /// A poisoned mutex is recovered instead of skipped, so the method never
    /// claims success for bytes it silently dropped.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .extend_from_slice(buf);
        Ok(buf.len())
    }

    /// Nothing is buffered outside the shared vector, so flushing is a no-op.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
4423
4424 impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CapturedLogs {
4425 type Writer = CapturedLogsWriter;
4426
4427 fn make_writer(&'a self) -> Self::Writer {
4428 CapturedLogsWriter(Arc::clone(&self.0))
4429 }
4430 }
4431
4432 #[tokio::test]
4433 async fn jwks_response_size_cap_returns_none_and_logs_warning() {
4434 let kid = "oversized-jwks";
4435 let (_pem, jwks) = generate_test_keypair(kid);
4436 let mut oversized_body = serde_json::to_string(&jwks).expect("jwks json");
4437 oversized_body.push_str(&" ".repeat(4096));
4438
4439 let mock_server = wiremock::MockServer::start().await;
4440 wiremock::Mock::given(wiremock::matchers::method("GET"))
4441 .and(wiremock::matchers::path("/jwks.json"))
4442 .respond_with(
4443 wiremock::ResponseTemplate::new(200)
4444 .insert_header("content-type", "application/json")
4445 .set_body_string(oversized_body),
4446 )
4447 .mount(&mock_server)
4448 .await;
4449
4450 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4451 let mut config = test_config(&jwks_uri);
4452 config.jwks_max_response_bytes = 256;
4453 let cache = test_cache(&config);
4454
4455 let logs = CapturedLogs::default();
4456 let subscriber = tracing_subscriber::fmt()
4457 .with_writer(logs.clone())
4458 .with_ansi(false)
4459 .without_time()
4460 .finish();
4461 let _guard = tracing::subscriber::set_default(subscriber);
4462
4463 let result = cache.fetch_jwks().await;
4464 assert!(result.is_none(), "oversized JWKS must be dropped");
4465 assert!(
4466 logs.contents()
4467 .contains("JWKS response exceeded configured size cap"),
4468 "expected cap-exceeded warning in logs"
4469 );
4470 }
4471
4472 #[tokio::test]
4473 async fn role_claim_keycloak_nested_array() {
4474 let kid = "test-role-1";
4475 let (pem, jwks) = generate_test_keypair(kid);
4476
4477 let mock_server = wiremock::MockServer::start().await;
4478 wiremock::Mock::given(wiremock::matchers::method("GET"))
4479 .and(wiremock::matchers::path("/jwks.json"))
4480 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4481 .mount(&mock_server)
4482 .await;
4483
4484 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4485 let config = test_config_with_role_claim(
4486 &jwks_uri,
4487 "realm_access.roles",
4488 vec![
4489 RoleMapping {
4490 claim_value: "mcp-admin".into(),
4491 role: "ops".into(),
4492 },
4493 RoleMapping {
4494 claim_value: "mcp-viewer".into(),
4495 role: "viewer".into(),
4496 },
4497 ],
4498 );
4499 let cache = test_cache(&config);
4500
4501 let now = jsonwebtoken::get_current_timestamp();
4502 let token = mint_token_with_claims(
4503 &pem,
4504 kid,
4505 &serde_json::json!({
4506 "iss": "https://auth.test.local",
4507 "aud": "https://mcp.test.local/mcp",
4508 "sub": "keycloak-user",
4509 "exp": now + 3600,
4510 "iat": now,
4511 "realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
4512 }),
4513 );
4514
4515 let id = cache
4516 .validate_token(&token)
4517 .await
4518 .expect("should authenticate");
4519 assert_eq!(id.name, "keycloak-user");
4520 assert_eq!(id.role, "ops");
4521 }
4522
4523 #[tokio::test]
4524 async fn role_claim_flat_roles_array() {
4525 let kid = "test-role-2";
4526 let (pem, jwks) = generate_test_keypair(kid);
4527
4528 let mock_server = wiremock::MockServer::start().await;
4529 wiremock::Mock::given(wiremock::matchers::method("GET"))
4530 .and(wiremock::matchers::path("/jwks.json"))
4531 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4532 .mount(&mock_server)
4533 .await;
4534
4535 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4536 let config = test_config_with_role_claim(
4537 &jwks_uri,
4538 "roles",
4539 vec![
4540 RoleMapping {
4541 claim_value: "MCP.Admin".into(),
4542 role: "ops".into(),
4543 },
4544 RoleMapping {
4545 claim_value: "MCP.Reader".into(),
4546 role: "viewer".into(),
4547 },
4548 ],
4549 );
4550 let cache = test_cache(&config);
4551
4552 let now = jsonwebtoken::get_current_timestamp();
4553 let token = mint_token_with_claims(
4554 &pem,
4555 kid,
4556 &serde_json::json!({
4557 "iss": "https://auth.test.local",
4558 "aud": "https://mcp.test.local/mcp",
4559 "sub": "azure-ad-user",
4560 "exp": now + 3600,
4561 "iat": now,
4562 "roles": ["MCP.Reader", "OtherApp.Admin"]
4563 }),
4564 );
4565
4566 let id = cache
4567 .validate_token(&token)
4568 .await
4569 .expect("should authenticate");
4570 assert_eq!(id.name, "azure-ad-user");
4571 assert_eq!(id.role, "viewer");
4572 }
4573
4574 #[tokio::test]
4575 async fn role_claim_no_matching_value_rejected() {
4576 let kid = "test-role-3";
4577 let (pem, jwks) = generate_test_keypair(kid);
4578
4579 let mock_server = wiremock::MockServer::start().await;
4580 wiremock::Mock::given(wiremock::matchers::method("GET"))
4581 .and(wiremock::matchers::path("/jwks.json"))
4582 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4583 .mount(&mock_server)
4584 .await;
4585
4586 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4587 let config = test_config_with_role_claim(
4588 &jwks_uri,
4589 "roles",
4590 vec![RoleMapping {
4591 claim_value: "mcp-admin".into(),
4592 role: "ops".into(),
4593 }],
4594 );
4595 let cache = test_cache(&config);
4596
4597 let now = jsonwebtoken::get_current_timestamp();
4598 let token = mint_token_with_claims(
4599 &pem,
4600 kid,
4601 &serde_json::json!({
4602 "iss": "https://auth.test.local",
4603 "aud": "https://mcp.test.local/mcp",
4604 "sub": "limited-user",
4605 "exp": now + 3600,
4606 "iat": now,
4607 "roles": ["some-other-role"]
4608 }),
4609 );
4610
4611 assert!(cache.validate_token(&token).await.is_none());
4612 }
4613
4614 #[tokio::test]
4615 async fn role_claim_space_separated_string() {
4616 let kid = "test-role-4";
4617 let (pem, jwks) = generate_test_keypair(kid);
4618
4619 let mock_server = wiremock::MockServer::start().await;
4620 wiremock::Mock::given(wiremock::matchers::method("GET"))
4621 .and(wiremock::matchers::path("/jwks.json"))
4622 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4623 .mount(&mock_server)
4624 .await;
4625
4626 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4627 let config = test_config_with_role_claim(
4628 &jwks_uri,
4629 "custom_scope",
4630 vec![
4631 RoleMapping {
4632 claim_value: "write".into(),
4633 role: "ops".into(),
4634 },
4635 RoleMapping {
4636 claim_value: "read".into(),
4637 role: "viewer".into(),
4638 },
4639 ],
4640 );
4641 let cache = test_cache(&config);
4642
4643 let now = jsonwebtoken::get_current_timestamp();
4644 let token = mint_token_with_claims(
4645 &pem,
4646 kid,
4647 &serde_json::json!({
4648 "iss": "https://auth.test.local",
4649 "aud": "https://mcp.test.local/mcp",
4650 "sub": "custom-client",
4651 "exp": now + 3600,
4652 "iat": now,
4653 "custom_scope": "read audit"
4654 }),
4655 );
4656
4657 let id = cache
4658 .validate_token(&token)
4659 .await
4660 .expect("should authenticate");
4661 assert_eq!(id.name, "custom-client");
4662 assert_eq!(id.role, "viewer");
4663 }
4664
4665 #[tokio::test]
4666 async fn scope_backward_compat_without_role_claim() {
4667 let kid = "test-compat-1";
4669 let (pem, jwks) = generate_test_keypair(kid);
4670
4671 let mock_server = wiremock::MockServer::start().await;
4672 wiremock::Mock::given(wiremock::matchers::method("GET"))
4673 .and(wiremock::matchers::path("/jwks.json"))
4674 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4675 .mount(&mock_server)
4676 .await;
4677
4678 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4679 let config = test_config(&jwks_uri); let cache = test_cache(&config);
4681
4682 let token = mint_token(
4683 &pem,
4684 kid,
4685 "https://auth.test.local",
4686 "https://mcp.test.local/mcp",
4687 "legacy-bot",
4688 "mcp:admin other:scope",
4689 );
4690
4691 let id = cache
4692 .validate_token(&token)
4693 .await
4694 .expect("should authenticate");
4695 assert_eq!(id.name, "legacy-bot");
4696 assert_eq!(id.role, "ops"); }
4698
4699 #[tokio::test]
4704 async fn jwks_refresh_deduplication() {
4705 let kid = "test-dedup";
4708 let (pem, jwks) = generate_test_keypair(kid);
4709
4710 let mock_server = wiremock::MockServer::start().await;
4711 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4712 .and(wiremock::matchers::path("/jwks.json"))
4713 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4714 .expect(1) .mount(&mock_server)
4716 .await;
4717
4718 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4719 let config = test_config(&jwks_uri);
4720 let cache = Arc::new(test_cache(&config));
4721
4722 let token = mint_token(
4724 &pem,
4725 kid,
4726 "https://auth.test.local",
4727 "https://mcp.test.local/mcp",
4728 "concurrent-bot",
4729 "mcp:read",
4730 );
4731
4732 let mut handles = Vec::new();
4733 for _ in 0..5 {
4734 let c = Arc::clone(&cache);
4735 let t = token.clone();
4736 handles.push(tokio::spawn(async move { c.validate_token(&t).await }));
4737 }
4738
4739 for h in handles {
4740 let result = h.await.unwrap();
4741 assert!(result.is_some(), "all concurrent requests should succeed");
4742 }
4743
4744 }
4746
4747 #[tokio::test]
4748 async fn jwks_refresh_cooldown_blocks_rapid_requests() {
4749 let kid = "test-cooldown";
4752 let (_pem, jwks) = generate_test_keypair(kid);
4753
4754 let mock_server = wiremock::MockServer::start().await;
4755 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4756 .and(wiremock::matchers::path("/jwks.json"))
4757 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4758 .expect(1) .mount(&mock_server)
4760 .await;
4761
4762 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4763 let config = test_config(&jwks_uri);
4764 let cache = test_cache(&config);
4765
4766 let fake_token1 =
4768 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig";
4769 let _ = cache.validate_token(fake_token1).await;
4770
4771 let fake_token2 =
4774 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig";
4775 let _ = cache.validate_token(fake_token2).await;
4776
4777 let fake_token3 =
4779 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig";
4780 let _ = cache.validate_token(fake_token3).await;
4781
4782 }
4784
4785 fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
4788 OAuthProxyConfig {
4789 authorize_url: "https://example.invalid/auth".into(),
4790 token_url: token_url.into(),
4791 client_id: "mcp-client".into(),
4792 client_secret: Some(secrecy::SecretString::from("shh".to_owned())),
4793 introspection_url: None,
4794 revocation_url: None,
4795 expose_admin_endpoints: false,
4796 require_auth_on_admin_endpoints: false,
4797 allow_unauthenticated_admin_endpoints: false,
4798 }
4799 }
4800
4801 fn test_http_client() -> OauthHttpClient {
4804 rustls::crypto::ring::default_provider()
4805 .install_default()
4806 .ok();
4807 let config = OAuthConfig::builder(
4808 "https://auth.test.local",
4809 "https://mcp.test.local/mcp",
4810 "https://auth.test.local/.well-known/jwks.json",
4811 )
4812 .allow_http_oauth_urls(true)
4813 .build();
4814 OauthHttpClient::with_config(&config)
4815 .expect("build test http client")
4816 .__test_allow_loopback_ssrf()
4817 }
4818
4819 #[tokio::test]
4820 async fn introspect_proxies_and_injects_client_credentials() {
4821 use wiremock::matchers::{body_string_contains, method, path};
4822
4823 let mock_server = wiremock::MockServer::start().await;
4824 wiremock::Mock::given(method("POST"))
4825 .and(path("/introspect"))
4826 .and(body_string_contains("client_id=mcp-client"))
4827 .and(body_string_contains("client_secret=shh"))
4828 .and(body_string_contains("token=abc"))
4829 .respond_with(
4830 wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
4831 "active": true,
4832 "scope": "read"
4833 })),
4834 )
4835 .expect(1)
4836 .mount(&mock_server)
4837 .await;
4838
4839 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4840 proxy.introspection_url = Some(format!("{}/introspect", mock_server.uri()));
4841
4842 let http = test_http_client();
4843 let resp = handle_introspect(&http, &proxy, "token=abc").await;
4844 assert_eq!(resp.status(), 200);
4845 }
4846
4847 #[tokio::test]
4848 async fn introspect_returns_404_when_not_configured() {
4849 let proxy = proxy_cfg("https://example.invalid/token");
4850 let http = test_http_client();
4851 let resp = handle_introspect(&http, &proxy, "token=abc").await;
4852 assert_eq!(resp.status(), 404);
4853 }
4854
4855 #[tokio::test]
4856 async fn revoke_proxies_and_returns_upstream_status() {
4857 use wiremock::matchers::{method, path};
4858
4859 let mock_server = wiremock::MockServer::start().await;
4860 wiremock::Mock::given(method("POST"))
4861 .and(path("/revoke"))
4862 .respond_with(wiremock::ResponseTemplate::new(200))
4863 .expect(1)
4864 .mount(&mock_server)
4865 .await;
4866
4867 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4868 proxy.revocation_url = Some(format!("{}/revoke", mock_server.uri()));
4869
4870 let http = test_http_client();
4871 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4872 assert_eq!(resp.status(), 200);
4873 }
4874
4875 #[tokio::test]
4876 async fn revoke_returns_404_when_not_configured() {
4877 let proxy = proxy_cfg("https://example.invalid/token");
4878 let http = test_http_client();
4879 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4880 assert_eq!(resp.status(), 404);
4881 }
4882
/// `authorization_server_metadata` must advertise `introspection_endpoint` and
/// `revocation_endpoint` only when a proxy is configured, the matching
/// upstream URL is set, and `expose_admin_endpoints` is true. The advertised
/// URLs must be rooted at the server's own base URL.
#[test]
fn metadata_advertises_endpoints_only_when_configured() {
    let base_url = "https://mcp.local";
    let mut cfg = test_config("https://auth.test.local/jwks.json");

    // Phase 1: the test adds no proxy of its own -> neither endpoint is
    // advertised.
    let m = authorization_server_metadata(base_url, &cfg);
    assert!(m.get("introspection_endpoint").is_none());
    assert!(m.get("revocation_endpoint").is_none());

    // Phase 2: upstream URLs are set but `expose_admin_endpoints` stays false
    // (the `proxy_cfg` default) -> still nothing advertised.
    let mut upstream = proxy_cfg("https://upstream.local/token");
    upstream.introspection_url = Some("https://upstream.local/introspect".into());
    upstream.revocation_url = Some("https://upstream.local/revoke".into());
    cfg.proxy = Some(upstream);
    let m = authorization_server_metadata(base_url, &cfg);
    assert!(
        m.get("introspection_endpoint").is_none(),
        "introspection must not be advertised when expose_admin_endpoints=false"
    );
    assert!(
        m.get("revocation_endpoint").is_none(),
        "revocation must not be advertised when expose_admin_endpoints=false"
    );

    // Phase 3: exposure switched on with only introspection configured.
    if let Some(proxy_mut) = cfg.proxy.as_mut() {
        proxy_mut.expose_admin_endpoints = true;
        proxy_mut.revocation_url = None;
    }
    let m = authorization_server_metadata(base_url, &cfg);
    let expected_introspect = serde_json::Value::String("https://mcp.local/introspect".into());
    assert_eq!(m["introspection_endpoint"], expected_introspect);
    assert!(m.get("revocation_endpoint").is_none());

    // Phase 4: revocation configured again -> it is advertised too.
    if let Some(proxy_mut) = cfg.proxy.as_mut() {
        proxy_mut.revocation_url = Some("https://upstream.local/revoke".into());
    }
    let m = authorization_server_metadata(base_url, &cfg);
    let expected_revoke = serde_json::Value::String("https://mcp.local/revoke".into());
    assert_eq!(m["revocation_endpoint"], expected_revoke);
}
4929
4930 fn https_cfg_with_tx(tx: TokenExchangeConfig) -> OAuthConfig {
4933 let mut cfg = validation_https_config();
4934 cfg.token_exchange = Some(tx);
4935 cfg
4936 }
4937
4938 fn tx_with(
4939 client_secret: Option<&str>,
4940 client_cert: Option<ClientCertConfig>,
4941 ) -> TokenExchangeConfig {
4942 TokenExchangeConfig::new(
4943 "https://idp.example.com/token".into(),
4944 "client".into(),
4945 client_secret.map(|s| secrecy::SecretString::new(s.into())),
4946 client_cert,
4947 "downstream".into(),
4948 )
4949 }
4950
/// A token-exchange section with neither a client secret nor a client
/// certificate must fail validation with an error that says client
/// authentication is required.
#[test]
fn validate_rejects_token_exchange_without_client_auth() {
    let unauthenticated_tx = tx_with(None, None);
    let cfg = https_cfg_with_tx(unauthenticated_tx);

    let msg = cfg
        .validate()
        .expect_err("token_exchange without client auth must be rejected")
        .to_string();
    let explains = msg.contains("requires client authentication");
    assert!(
        explains,
        "error must explain missing client auth; got {msg:?}"
    );
}
4963
/// Supplying both a client secret and a client certificate must fail
/// validation with an error containing both `mutually` and `exclusive`.
#[test]
fn validate_rejects_token_exchange_with_both_secret_and_cert() {
    let cert_pair = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let cfg = https_cfg_with_tx(tx_with(Some("s"), Some(cert_pair)));

    let msg = cfg
        .validate()
        .expect_err("client_secret + client_cert must be rejected")
        .to_string();
    let explains = ["mutually", "exclusive"]
        .iter()
        .all(|word| msg.contains(*word));
    assert!(
        explains,
        "error must explain mutual exclusion; got {msg:?}"
    );
}
4980
/// Built without the `oauth-mtls-client` feature, a token-exchange section
/// that carries a client certificate must fail validation with an error that
/// names the feature.
#[cfg(not(feature = "oauth-mtls-client"))]
#[test]
fn validate_rejects_client_cert_without_feature() {
    let cert_pair = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(cert_pair)));

    let outcome = cfg.validate();
    let err = outcome.expect_err("client_cert without feature must be rejected");
    let names_feature = err.to_string().contains("oauth-mtls-client");
    assert!(
        names_feature,
        "error must reference the cargo feature; got {err}"
    );
}
4997
/// With the `oauth-mtls-client` feature enabled, pointing the client
/// certificate at the `/nonexistent/...` paths must fail validation with an
/// error containing `unreadable`.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_missing_client_cert_files() {
    let missing_pair = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(missing_pair)));

    let outcome = cfg.validate();
    let err = outcome.expect_err("missing cert file must be rejected");
    let flags_unreadable = err.to_string().contains("unreadable");
    assert!(
        flags_unreadable,
        "error must call out unreadable file; got {err}"
    );
}
5014
/// With the `oauth-mtls-client` feature enabled, certificate and key files
/// that exist but do not contain PEM data must fail validation with an error
/// containing `PEM parse failed`.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_malformed_client_cert_pem() {
    // The process id keeps file names distinct across concurrent test
    // processes that share the temp directory.
    let dir = std::env::temp_dir();
    let cert = dir.join(format!("rmcp-mtls-bad-cert-{}.pem", std::process::id()));
    let key = dir.join(format!("rmcp-mtls-bad-key-{}.pem", std::process::id()));
    std::fs::write(&cert, b"not a real PEM").expect("write tmp cert");
    std::fs::write(&key, b"not a real PEM either").expect("write tmp key");
    let cc = ClientCertConfig {
        cert_path: cert.clone(),
        key_path: key.clone(),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));

    // Capture the outcome first and clean up before asserting on it, so the
    // temp files are removed even when validation unexpectedly succeeds and
    // `expect_err` panics.
    let outcome = cfg.validate();
    let _ = std::fs::remove_file(&cert);
    let _ = std::fs::remove_file(&key);

    let err = outcome.expect_err("malformed PEM must be rejected");
    assert!(
        err.to_string().contains("PEM parse failed"),
        "error must call out PEM parse failure; got {err}"
    );
}
5036
/// Generates a self-signed certificate for `client.test` with `rcgen`, writes
/// the certificate and its signing key as PEM files into the temp directory,
/// and returns `(cert_path, key_path)`. Callers are responsible for deleting
/// the files.
#[cfg(feature = "oauth-mtls-client")]
fn write_self_signed_pem() -> (PathBuf, PathBuf) {
    let generated = rcgen::generate_simple_self_signed(vec!["client.test".into()]).expect("rcgen");

    // Process id plus a random nonce keeps the names unique per call.
    let pid = std::process::id();
    let nonce: u64 = rand::random();
    let tmp = std::env::temp_dir();
    let cert_path = tmp.join(format!("rmcp-mtls-cert-{pid}-{nonce}.pem"));
    let key_path = tmp.join(format!("rmcp-mtls-key-{pid}-{nonce}.pem"));

    std::fs::write(&cert_path, generated.cert.pem()).expect("write cert");
    std::fs::write(&key_path, generated.signing_key.serialize_pem()).expect("write key");
    (cert_path, key_path)
}
5049
/// Installs the `ring` rustls crypto provider as the process default. The
/// result is ignored, presumably because another test may already have
/// installed one.
#[cfg(feature = "oauth-mtls-client")]
fn install_test_crypto_provider() {
    rustls::crypto::ring::default_provider()
        .install_default()
        .ok();
}
5054
/// With the `oauth-mtls-client` feature enabled, a freshly generated
/// self-signed certificate and key must pass validation.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_accepts_well_formed_client_cert() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();

    let cert_pair = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(cert_pair)));
    let outcome = cfg.validate();

    // Remove the temp files before asserting so a failure does not leak them.
    for leftover in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(leftover);
    }
    outcome.expect("well-formed cert+key must validate");
}
5070
/// `client_for` must hand out a different client for a token-exchange config
/// that carries a client certificate than for one that uses a shared secret.
/// The two returned references are compared by address.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_returns_cached_mtls_client() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();

    let cert_pair = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(cert_pair)));
    let http = OauthHttpClient::with_config(&cfg).expect("build mtls client");

    let cert_tx = cfg.token_exchange.as_ref().expect("tx set");
    let secret_tx = tx_with(Some("s"), None);
    let cert_client = http.client_for(cert_tx);
    let inner_client = http.client_for(&secret_tx);

    // Both lookups are done; remove the temp files before asserting.
    for leftover in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(leftover);
    }

    let same_client = std::ptr::eq(cert_client, inner_client);
    assert!(
        !same_client,
        "client_for must return distinct clients for cert vs no-cert configs"
    );
}
5092
/// When asked for a client certificate the `OauthHttpClient` was not built
/// with, `client_for` must return the same client it returns for a
/// secret-based config. The two returned references are compared by address.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_falls_back_to_inner_when_cache_miss() {
    install_test_crypto_provider();

    // This test adds no token-exchange section to the base config.
    let cfg = validation_https_config();
    let http = OauthHttpClient::with_config(&cfg).expect("build client");

    let unknown_pair = ClientCertConfig {
        cert_path: PathBuf::from("/cache/miss/cert.pem"),
        key_path: PathBuf::from("/cache/miss/key.pem"),
    };
    let unknown_cert_tx = tx_with(None, Some(unknown_pair));
    let secret_tx = tx_with(Some("s"), None);

    let fallback = http.client_for(&unknown_cert_tx);
    let inner = http.client_for(&secret_tx);

    let same_client = std::ptr::eq(fallback, inner);
    assert!(same_client, "cache miss must fall back to inner client");
}
5111}