1use std::{
17 collections::HashMap,
18 path::PathBuf,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering},
22 },
23 time::{Duration, Instant},
24};
25
26use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
27use serde::Deserialize;
28use tokio::{net::lookup_host, sync::RwLock};
29
30use crate::auth::{AuthIdentity, AuthMethod};
31
32fn evaluate_oauth_redirect(
58 attempt: &reqwest::redirect::Attempt<'_>,
59 allow_http: bool,
60 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
61) -> Result<(), String> {
62 let prev_https = attempt
63 .previous()
64 .last()
65 .is_some_and(|prev| prev.scheme() == "https");
66 let target_url = attempt.url();
67 let dest_scheme = target_url.scheme();
68 if dest_scheme != "https" {
69 if prev_https {
70 return Err("redirect downgrades https -> http".to_owned());
71 }
72 if !allow_http || dest_scheme != "http" {
73 return Err("redirect to non-HTTP(S) URL refused".to_owned());
74 }
75 }
76 if let Some(reason) = crate::ssrf::redirect_target_reason_with_allowlist(target_url, allowlist)
77 {
78 return Err(format!("redirect target forbidden: {reason}"));
79 }
80 if attempt.previous().len() >= 2 {
81 return Err("too many redirects (max 2)".to_owned());
82 }
83 Ok(())
84}
85
86async fn screen_oauth_target_core(
106 url: &str,
107 allow_http: bool,
108 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
109 test_allow_loopback_ssrf: bool,
110) -> Result<(), crate::error::McpxError> {
111 let parsed = check_oauth_url("oauth target", url, allow_http)?;
112 if test_allow_loopback_ssrf {
113 return Ok(());
114 }
115 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
116 return Err(crate::error::McpxError::Config(format!(
117 "OAuth target forbidden ({reason}): {url}"
118 )));
119 }
120
121 let host = parsed.host_str().ok_or_else(|| {
122 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
123 })?;
124 let port = parsed.port_or_known_default().ok_or_else(|| {
125 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
126 })?;
127
128 let addrs = lookup_host((host, port)).await.map_err(|error| {
129 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
130 })?;
131
132 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
133 let mut any_addr = false;
134 for addr in addrs {
135 any_addr = true;
136 let ip = addr.ip();
137 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
138 if reason == "cloud_metadata" {
141 return Err(crate::error::McpxError::Config(format!(
142 "OAuth target resolved to blocked IP ({reason}): {url}"
143 )));
144 }
145 if allowlist.is_empty() {
149 return Err(crate::error::McpxError::Config(format!(
150 "OAuth target resolved to blocked IP ({reason}): {url}"
151 )));
152 }
153 if host_allowed || allowlist.ip_allowed(ip) {
155 continue;
156 }
157 return Err(crate::error::McpxError::Config(format!(
158 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
159 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
160 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
161 URL: {url}"
162 )));
163 }
164 }
165 if !any_addr {
166 return Err(crate::error::McpxError::Config(format!(
167 "OAuth target DNS resolution returned no addresses: {url}"
168 )));
169 }
170
171 Ok(())
172}
173
/// Screens `url` with the production SSRF rules.
///
/// Delegates to `screen_oauth_target_core` with the test loopback override
/// fixed to `false`, so the literal-IP and DNS screening always run.
///
/// # Errors
///
/// Propagates the `McpxError` returned by `screen_oauth_target_core`.
async fn screen_oauth_target(
    url: &str,
    allow_http: bool,
    allowlist: &crate::ssrf::CompiledSsrfAllowlist,
) -> Result<(), crate::error::McpxError> {
    screen_oauth_target_core(url, allow_http, allowlist, false).await
}
183
/// Test-only variant of `screen_oauth_target` that forwards the caller's
/// `test_allow_loopback_ssrf` flag to `screen_oauth_target_core`.
///
/// When the flag is `true`, only the URL scheme/userinfo check runs and the
/// SSRF screening is skipped. Compiled only for tests or with the
/// `test-helpers` feature.
///
/// # Errors
///
/// Propagates the `McpxError` returned by `screen_oauth_target_core`.
#[cfg(any(test, feature = "test-helpers"))]
async fn screen_oauth_target_with_test_override(
    url: &str,
    allow_http: bool,
    allowlist: &crate::ssrf::CompiledSsrfAllowlist,
    test_allow_loopback_ssrf: bool,
) -> Result<(), crate::error::McpxError> {
    screen_oauth_target_core(url, allow_http, allowlist, test_allow_loopback_ssrf).await
}
196
/// HTTP client wrapper for outbound OAuth traffic.
///
/// Bundles a `reqwest::Client` with the settings needed to screen each target
/// URL before a request is sent (see `send_screened`).
#[derive(Clone)]
pub struct OauthHttpClient {
    /// Default client used when no mTLS client applies.
    inner: reqwest::Client,
    /// Whether plain `http` OAuth URLs are accepted.
    allow_http: bool,
    /// Compiled operator SSRF allowlist (empty when none is configured).
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Clients carrying a TLS client identity, keyed by certificate/key path.
    /// Present only with the `oauth-mtls-client` feature.
    #[cfg(feature = "oauth-mtls-client")]
    mtls_clients: Arc<HashMap<MtlsClientKey, reqwest::Client>>,
    /// Test-only switch; when set, `send_screened` skips SSRF screening.
    /// The same handle is also passed to the DNS resolver when the client is built.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
260
/// Lookup key for a pre-built mTLS `reqwest::Client`: the pair of certificate
/// and private-key paths taken from a `ClientCertConfig`.
#[cfg(feature = "oauth-mtls-client")]
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct MtlsClientKey {
    /// Path of the client certificate file.
    cert_path: PathBuf,
    /// Path of the client private-key file.
    key_path: PathBuf,
}
270
impl OauthHttpClient {
    /// Builds a client that honours `config`: its `allow_http_oauth_urls`
    /// toggle, `ssrf_allowlist`, `ca_cert_path` and (with the
    /// `oauth-mtls-client` feature) the token-exchange client certificate.
    ///
    /// # Errors
    ///
    /// Returns `McpxError::Startup` when the allowlist does not compile, the
    /// CA certificate cannot be read or parsed, or the underlying
    /// `reqwest::Client` cannot be built.
    pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
        Self::build(Some(config))
    }

    /// Builds a client without any `OAuthConfig`: `http` URLs are refused, the
    /// SSRF allowlist is empty and no extra root certificate is installed.
    ///
    /// # Errors
    ///
    /// Returns `McpxError::Startup` when the underlying `reqwest::Client`
    /// cannot be built.
    #[deprecated(
        since = "1.2.1",
        note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
    )]
    pub fn new() -> Result<Self, crate::error::McpxError> {
        Self::build(None)
    }

    /// Shared constructor behind `with_config` and `new`.
    ///
    /// The resulting `reqwest::Client` uses no proxy, resolves names through
    /// `crate::ssrf_resolver::SsrfScreeningResolver`, applies a 10 s connect
    /// timeout and a 30 s overall timeout, and evaluates every redirect with
    /// `evaluate_oauth_redirect`.
    fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
        // Without a config, plain-http OAuth URLs stay disabled.
        let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);

        // Compile the operator allowlist, or fall back to the default (empty) one.
        let allowlist = match config.and_then(|c| c.ssrf_allowlist.as_ref()) {
            Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
                crate::error::McpxError::Startup(format!("oauth http client: {e}"))
            })?),
            None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
        };

        // Separate handle that is moved into the redirect-policy closure below.
        let redirect_allowlist = Arc::clone(&allowlist);

        // `TestLoopbackBypass` is a shared atomic flag in test builds and the
        // unit type otherwise, hence the two cfg-gated initialisers.
        #[cfg(any(test, feature = "test-helpers"))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
            Arc::new(AtomicBool::new(false));
        #[cfg(not(any(test, feature = "test-helpers")))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();

        let resolver: Arc<dyn reqwest::dns::Resolve> =
            Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
                Arc::clone(&allowlist),
                #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
                test_bypass.clone(),
            ));

        let mut builder = reqwest::Client::builder()
            .no_proxy()
            .dns_resolver(Arc::clone(&resolver))
            .connect_timeout(Duration::from_secs(10))
            .timeout(Duration::from_secs(30))
            .redirect(reqwest::redirect::Policy::custom(move |attempt| {
                match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
                    Ok(()) => attempt.follow(),
                    Err(reason) => {
                        // Log the rejected hop with a sanitized URL, then fail the request.
                        tracing::warn!(
                            reason = %reason,
                            target = %crate::ssrf::sanitized_url_for_log(attempt.url()),
                            "oauth redirect rejected"
                        );
                        attempt.error(reason)
                    }
                }
            }));

        // Optional extra trust anchor, read from a PEM file.
        if let Some(cfg) = config
            && let Some(ref ca_path) = cfg.ca_cert_path
        {
            let pem = std::fs::read(ca_path).map_err(|e| {
                crate::error::McpxError::Startup(format!(
                    "oauth http client: read ca_cert_path {}: {e}",
                    ca_path.display()
                ))
            })?;
            let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
                crate::error::McpxError::Startup(format!(
                    "oauth http client: parse ca_cert_path {}: {e}",
                    ca_path.display()
                ))
            })?;
            builder = builder.add_root_certificate(cert);
        }

        let inner = builder.build().map_err(|e| {
            crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
        })?;

        #[cfg(feature = "oauth-mtls-client")]
        let mtls_clients = build_mtls_clients(config, &allowlist, &test_bypass)?;

        Ok(Self {
            inner,
            allow_http,
            allowlist,
            #[cfg(feature = "oauth-mtls-client")]
            mtls_clients,
            #[cfg(any(test, feature = "test-helpers"))]
            test_allow_loopback_ssrf: test_bypass,
        })
    }

    /// Screens `url` and, if it passes, sends `request`.
    ///
    /// In test builds the screening is performed with the loopback override
    /// when the bypass flag is set; otherwise `screen_oauth_target` is used.
    ///
    /// # Errors
    ///
    /// Returns the screening error, or `McpxError::Config` wrapping the
    /// transport error when sending fails.
    async fn send_screened(
        &self,
        url: &str,
        request: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, crate::error::McpxError> {
        #[cfg(any(test, feature = "test-helpers"))]
        if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
            screen_oauth_target_with_test_override(url, self.allow_http, &self.allowlist, true)
                .await?;
        } else {
            screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
        }
        #[cfg(not(any(test, feature = "test-helpers")))]
        screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
        request.send().await.map_err(|error| {
            crate::error::McpxError::Config(format!("oauth request {url}: {error}"))
        })
    }

    /// Test-only: turns on the loopback SSRF bypass flag and returns `self`.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_allow_loopback_ssrf(self) -> Self {
        self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
        self
    }

    /// Issues a `GET` for `url` directly on the inner client. This does not
    /// go through `send_screened`.
    #[doc(hidden)]
    pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
        self.inner.get(url).send().await
    }

    /// Test-only: borrows the inner `reqwest::Client`.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_inner_client(&self) -> &reqwest::Client {
        &self.inner
    }

    /// Selects the client for a token exchange: the mTLS client registered for
    /// `cfg.client_cert`'s paths when one exists, otherwise the inner client.
    #[cfg(feature = "oauth-mtls-client")]
    fn client_for(&self, cfg: &TokenExchangeConfig) -> &reqwest::Client {
        if let Some(cc) = &cfg.client_cert {
            let key = MtlsClientKey {
                cert_path: cc.cert_path.clone(),
                key_path: cc.key_path.clone(),
            };
            if let Some(client) = self.mtls_clients.get(&key) {
                return client;
            }
        }
        &self.inner
    }

    /// Without the `oauth-mtls-client` feature there are no mTLS clients, so
    /// the inner client is always returned.
    #[cfg(not(feature = "oauth-mtls-client"))]
    fn client_for(&self, _cfg: &TokenExchangeConfig) -> &reqwest::Client {
        &self.inner
    }
}
516
517impl std::fmt::Debug for OauthHttpClient {
518 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519 f.debug_struct("OauthHttpClient").finish_non_exhaustive()
520 }
521}
522
/// Operator-supplied SSRF allowlist in its raw, deserialized form.
///
/// `compile_oauth_ssrf_allowlist` validates it and turns it into a
/// `crate::ssrf::CompiledSsrfAllowlist`.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct OAuthSsrfAllowlist {
    /// Bare DNS hostnames. Entries containing a scheme, port, path, userinfo,
    /// query or fragment, and literal IPs, are rejected when compiled.
    /// Empty when omitted.
    #[serde(default)]
    pub hosts: Vec<String>,
    /// CIDR entries, each parsed with `crate::ssrf::CidrEntry::parse` when
    /// compiled. Empty when omitted.
    #[serde(default)]
    pub cidrs: Vec<String>,
}
599
600fn compile_oauth_ssrf_allowlist(
607 raw: &OAuthSsrfAllowlist,
608) -> Result<crate::ssrf::CompiledSsrfAllowlist, String> {
609 let mut hosts: Vec<String> = Vec::with_capacity(raw.hosts.len());
610 for (idx, entry) in raw.hosts.iter().enumerate() {
611 let trimmed = entry.trim();
612 if trimmed.is_empty() {
613 return Err(format!("oauth.ssrf_allowlist.hosts[{idx}]: empty entry"));
614 }
615 if trimmed.contains([':', '/', '@', '?', '#']) {
619 return Err(format!(
620 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: must be a bare DNS hostname \
621 (no scheme, port, path, userinfo, query, or fragment)"
622 ));
623 }
624 match url::Host::parse(trimmed) {
625 Ok(url::Host::Domain(_)) => {}
626 Ok(url::Host::Ipv4(_) | url::Host::Ipv6(_)) => {
627 return Err(format!(
628 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: literal IPs are forbidden \
629 here -- list them via oauth.ssrf_allowlist.cidrs instead"
630 ));
631 }
632 Err(e) => {
633 return Err(format!(
634 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: invalid hostname: {e}"
635 ));
636 }
637 }
638 hosts.push(trimmed.to_ascii_lowercase());
639 }
640 hosts.sort();
641 hosts.dedup();
642
643 let mut cidrs = Vec::with_capacity(raw.cidrs.len());
644 for (idx, entry) in raw.cidrs.iter().enumerate() {
645 let parsed = crate::ssrf::CidrEntry::parse(entry)
646 .map_err(|e| format!("oauth.ssrf_allowlist.cidrs[{idx}]: {e}"))?;
647 cidrs.push(parsed);
648 }
649
650 Ok(crate::ssrf::CompiledSsrfAllowlist::new(hosts, cidrs))
651}
652
/// OAuth configuration, deserializable from the application's config file.
///
/// Call `OAuthConfig::validate` to check the URLs and the other fields
/// described below.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
    /// Issuer URL. Checked by `validate` and installed as the accepted `iss`
    /// value when a `JwksCache` is created.
    pub issuer: String,
    /// Audience value for this server.
    /// NOTE(review): presumably compared against the token's `aud` claim -- confirm.
    pub audience: String,
    /// URL of the JWKS document; checked by `validate`.
    pub jwks_uri: String,
    /// Scope-to-role mappings. Empty when omitted.
    #[serde(default)]
    pub scopes: Vec<ScopeMapping>,
    /// Optional name of a claim carrying roles.
    /// NOTE(review): exact lookup semantics are not visible here -- confirm.
    pub role_claim: Option<String>,
    /// Claim-value-to-role mappings. Empty when omitted.
    #[serde(default)]
    pub role_mappings: Vec<RoleMapping>,
    /// JWKS cache lifetime as a humantime duration string (e.g. `"10m"`,
    /// `"1h30m"`). Defaults to `"10m"`.
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl: String,
    /// Optional OAuth proxy settings; its URLs are checked by `validate`.
    pub proxy: Option<OAuthProxyConfig>,
    /// Optional token-exchange settings; checked by `validate`.
    pub token_exchange: Option<TokenExchangeConfig>,
    /// Optional PEM file added as a root certificate to the OAuth HTTP client.
    #[serde(default)]
    pub ca_cert_path: Option<PathBuf>,
    /// Accept plain `http` OAuth URLs. Defaults to `false`; strongly
    /// discouraged in production.
    #[serde(default)]
    pub allow_http_oauth_urls: bool,
    /// Optional operator SSRF allowlist for OAuth/JWKS targets.
    #[serde(default)]
    pub ssrf_allowlist: Option<OAuthSsrfAllowlist>,
    /// Upper bound on JWKS keys. Defaults to 256.
    /// NOTE(review): enforcement point is outside this view -- confirm.
    #[serde(default = "default_max_jwks_keys")]
    pub max_jwks_keys: usize,
    /// Legacy strict-audience switch; consulted only when
    /// `audience_validation_mode` is `None`.
    #[serde(default)]
    #[deprecated(
        since = "1.7.0",
        note = "use `audience_validation_mode` instead; this field is consulted only when `audience_validation_mode` is None"
    )]
    pub strict_audience_validation: bool,
    /// Explicit audience validation mode; takes precedence over the legacy
    /// flag (see `effective_audience_validation_mode`).
    #[serde(default)]
    pub audience_validation_mode: Option<AudienceValidationMode>,
    /// Upper bound, in bytes, for a JWKS response. Defaults to 1 MiB.
    /// NOTE(review): enforcement point is outside this view -- confirm.
    #[serde(default = "default_jwks_max_bytes")]
    pub jwks_max_response_bytes: u64,
}
763
/// Serde default for `OAuthConfig::jwks_cache_ttl`: ten minutes, written as a
/// humantime string.
fn default_jwks_cache_ttl() -> String {
    String::from("10m")
}
767
/// Serde default for `OAuthConfig::max_jwks_keys`.
const fn default_max_jwks_keys() -> usize {
    const DEFAULT_MAX_JWKS_KEYS: usize = 256;
    DEFAULT_MAX_JWKS_KEYS
}
771
/// Serde default for `OAuthConfig::jwks_max_response_bytes`: 1 MiB.
const fn default_jwks_max_bytes() -> u64 {
    const ONE_MIB: u64 = 1 << 20;
    ONE_MIB
}
775
/// How strictly token audiences are validated.
///
/// Deserialized from snake_case names (`"permissive"`, `"warn"`, `"strict"`).
/// `Warn` is the `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AudienceValidationMode {
    /// Least strict mode.
    /// NOTE(review): exact acceptance rules are implemented elsewhere -- confirm.
    Permissive,
    /// Default mode; also the result of the legacy flag being `false`.
    #[default]
    Warn,
    /// Strictest mode; also the result of the legacy flag being `true`.
    Strict,
}
810
811impl AudienceValidationMode {
812 #[must_use]
817 pub(crate) const fn as_str(self) -> &'static str {
818 match self {
819 Self::Permissive => "permissive",
820 Self::Warn => "warn",
821 Self::Strict => "strict",
822 }
823 }
824}
825
826impl Default for OAuthConfig {
827 fn default() -> Self {
828 Self {
829 issuer: String::new(),
830 audience: String::new(),
831 jwks_uri: String::new(),
832 scopes: Vec::new(),
833 role_claim: None,
834 role_mappings: Vec::new(),
835 jwks_cache_ttl: default_jwks_cache_ttl(),
836 proxy: None,
837 token_exchange: None,
838 ca_cert_path: None,
839 allow_http_oauth_urls: false,
840 max_jwks_keys: default_max_jwks_keys(),
841 #[allow(
842 deprecated,
843 reason = "default-construct deprecated field for backward compat"
844 )]
845 strict_audience_validation: false,
846 audience_validation_mode: None,
847 jwks_max_response_bytes: default_jwks_max_bytes(),
848 ssrf_allowlist: None,
849 }
850 }
851}
852
853impl OAuthConfig {
854 #[must_use]
860 pub fn effective_audience_validation_mode(&self) -> AudienceValidationMode {
861 if let Some(mode) = self.audience_validation_mode {
862 return mode;
863 }
864 #[allow(deprecated, reason = "intentional: legacy flag resolution path")]
865 if self.strict_audience_validation {
866 AudienceValidationMode::Strict
867 } else {
868 AudienceValidationMode::Warn
869 }
870 }
871
872 pub fn builder(
878 issuer: impl Into<String>,
879 audience: impl Into<String>,
880 jwks_uri: impl Into<String>,
881 ) -> OAuthConfigBuilder {
882 OAuthConfigBuilder {
883 inner: Self {
884 issuer: issuer.into(),
885 audience: audience.into(),
886 jwks_uri: jwks_uri.into(),
887 ..Self::default()
888 },
889 }
890 }
891
892 pub fn validate(&self) -> Result<(), crate::error::McpxError> {
908 let allow_http = self.allow_http_oauth_urls;
909 let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
910 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
911 return Err(crate::error::McpxError::Config(format!(
912 "oauth.issuer forbidden ({reason})"
913 )));
914 }
915 let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
916 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
917 return Err(crate::error::McpxError::Config(format!(
918 "oauth.jwks_uri forbidden ({reason})"
919 )));
920 }
921 if let Some(proxy) = &self.proxy {
922 let url = check_oauth_url(
923 "oauth.proxy.authorize_url",
924 &proxy.authorize_url,
925 allow_http,
926 )?;
927 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
928 return Err(crate::error::McpxError::Config(format!(
929 "oauth.proxy.authorize_url forbidden ({reason})"
930 )));
931 }
932 let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
933 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
934 return Err(crate::error::McpxError::Config(format!(
935 "oauth.proxy.token_url forbidden ({reason})"
936 )));
937 }
938 if let Some(url) = &proxy.introspection_url {
939 let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
940 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
941 return Err(crate::error::McpxError::Config(format!(
942 "oauth.proxy.introspection_url forbidden ({reason})"
943 )));
944 }
945 }
946 if let Some(url) = &proxy.revocation_url {
947 let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
948 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
949 return Err(crate::error::McpxError::Config(format!(
950 "oauth.proxy.revocation_url forbidden ({reason})"
951 )));
952 }
953 }
954 if proxy.expose_admin_endpoints
961 && !proxy.require_auth_on_admin_endpoints
962 && !proxy.allow_unauthenticated_admin_endpoints
963 {
964 return Err(crate::error::McpxError::Config(
965 "oauth.proxy: expose_admin_endpoints = true requires \
966 require_auth_on_admin_endpoints = true (recommended) \
967 or allow_unauthenticated_admin_endpoints = true \
968 (explicit opt-out, only safe behind an authenticated \
969 reverse proxy)"
970 .into(),
971 ));
972 }
973 }
974 if let Some(tx) = &self.token_exchange {
975 let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
976 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
977 return Err(crate::error::McpxError::Config(format!(
978 "oauth.token_exchange.token_url forbidden ({reason})"
979 )));
980 }
981 validate_token_exchange_client_auth(tx)?;
984 }
985 if let Some(raw) = &self.ssrf_allowlist {
989 let compiled = compile_oauth_ssrf_allowlist(raw).map_err(|e| {
990 crate::error::McpxError::Config(format!("oauth.ssrf_allowlist: {e}"))
991 })?;
992 if !compiled.is_empty() {
993 tracing::warn!(
994 host_count = compiled.host_count(),
995 cidr_count = compiled.cidr_count(),
996 "oauth.ssrf_allowlist is configured: private/loopback OAuth/JWKS targets \
997 are now reachable. Cloud-metadata addresses remain blocked. \
998 See SECURITY.md \"Operator allowlist\"."
999 );
1000 }
1001 }
1002 humantime::parse_duration(&self.jwks_cache_ttl).map_err(|e| {
1005 crate::error::McpxError::Config(format!(
1006 "oauth.jwks_cache_ttl {:?} is not a valid humantime duration (e.g. \"10m\", \"1h30m\"): {e}",
1007 self.jwks_cache_ttl
1008 ))
1009 })?;
1010 Ok(())
1011 }
1012}
1013
1014fn validate_token_exchange_client_auth(
1020 tx: &TokenExchangeConfig,
1021) -> Result<(), crate::error::McpxError> {
1022 match (&tx.client_cert, tx.client_secret.is_some()) {
1023 (Some(_), true) => Err(crate::error::McpxError::Config(
1024 "oauth.token_exchange: client_cert and client_secret are mutually \
1025 exclusive (RFC 8705 ยง2). Set exactly one."
1026 .into(),
1027 )),
1028 (None, false) => Err(crate::error::McpxError::Config(
1029 "oauth.token_exchange: token exchange requires client authentication. \
1030 Set either client_secret (RFC 6749 ยง2.3.1) or client_cert (RFC 8705 ยง2)."
1031 .into(),
1032 )),
1033 (Some(cc), false) => validate_client_cert_config(cc),
1034 (None, true) => Ok(()),
1035 }
1036}
1037
1038fn validate_client_cert_config(cc: &ClientCertConfig) -> Result<(), crate::error::McpxError> {
1051 #[cfg(not(feature = "oauth-mtls-client"))]
1052 {
1053 let _ = cc;
1054 Err(crate::error::McpxError::Config(
1055 "oauth.token_exchange.client_cert requires the `oauth-mtls-client` cargo feature; \
1056 rebuild rmcp-server-kit with --features oauth-mtls-client (or have your \
1057 application crate enable it via `rmcp-server-kit/oauth-mtls-client`), or remove \
1058 the field"
1059 .into(),
1060 ))
1061 }
1062 #[cfg(feature = "oauth-mtls-client")]
1063 {
1064 let cert_bytes = std::fs::read(&cc.cert_path).map_err(|e| {
1065 tracing::warn!(error = %e, path = %cc.cert_path.display(), "client cert read failed");
1066 crate::error::McpxError::Config(format!(
1067 "oauth.token_exchange.client_cert.cert_path unreadable: {}",
1068 cc.cert_path.display()
1069 ))
1070 })?;
1071 let key_bytes = std::fs::read(&cc.key_path).map_err(|e| {
1072 tracing::warn!(error = %e, path = %cc.key_path.display(), "client cert key read failed");
1073 crate::error::McpxError::Config(format!(
1074 "oauth.token_exchange.client_cert.key_path unreadable: {}",
1075 cc.key_path.display()
1076 ))
1077 })?;
1078 let mut combined = Vec::with_capacity(cert_bytes.len() + 1 + key_bytes.len());
1079 combined.extend_from_slice(&cert_bytes);
1080 if !cert_bytes.ends_with(b"\n") {
1081 combined.push(b'\n');
1082 }
1083 combined.extend_from_slice(&key_bytes);
1084 let _identity = reqwest::Identity::from_pem(&combined).map_err(|e| {
1085 tracing::warn!(
1086 error = %e,
1087 cert_path = %cc.cert_path.display(),
1088 key_path = %cc.key_path.display(),
1089 "client cert PEM parse failed"
1090 );
1091 crate::error::McpxError::Config(format!(
1092 "oauth.token_exchange.client_cert: PEM parse failed (cert={}, key={})",
1093 cc.cert_path.display(),
1094 cc.key_path.display()
1095 ))
1096 })?;
1097 Ok(())
1098 }
1099}
1100
/// Builds the map of mTLS-capable HTTP clients for token exchange.
///
/// The map is empty unless `config` has a `token_exchange` section with a
/// `client_cert`. In that case it holds one client keyed by the certificate
/// and key paths. That client presents the PEM identity, uses no proxy,
/// resolves names through the SSRF-screening resolver, applies the 10 s / 30 s
/// timeouts, follows no redirects, and trusts `ca_cert_path` when set.
///
/// # Errors
///
/// Returns `McpxError::Startup` when a certificate, key or CA file cannot be
/// read or parsed, or when the client cannot be built.
#[cfg(feature = "oauth-mtls-client")]
fn build_mtls_clients(
    config: Option<&OAuthConfig>,
    allowlist: &Arc<crate::ssrf::CompiledSsrfAllowlist>,
    test_bypass: &crate::ssrf_resolver::TestLoopbackBypass,
) -> Result<Arc<HashMap<MtlsClientKey, reqwest::Client>>, crate::error::McpxError> {
    let Some(cfg) = config else {
        return Ok(Arc::new(HashMap::new()));
    };
    let Some(cc) = cfg
        .token_exchange
        .as_ref()
        .and_then(|tx| tx.client_cert.as_ref())
    else {
        return Ok(Arc::new(HashMap::new()));
    };

    let cert_pem = std::fs::read(&cc.cert_path).map_err(|e| {
        crate::error::McpxError::Startup(format!(
            "oauth http client mTLS: read cert_path {}: {e}",
            cc.cert_path.display()
        ))
    })?;
    let key_pem = std::fs::read(&cc.key_path).map_err(|e| {
        crate::error::McpxError::Startup(format!(
            "oauth http client mTLS: read key_path {}: {e}",
            cc.key_path.display()
        ))
    })?;

    // Certificate first, then the key, separated by a newline.
    let mut bundle = Vec::with_capacity(cert_pem.len() + 1 + key_pem.len());
    bundle.extend_from_slice(&cert_pem);
    if !cert_pem.ends_with(b"\n") {
        bundle.push(b'\n');
    }
    bundle.extend_from_slice(&key_pem);
    let identity = reqwest::Identity::from_pem(&bundle).map_err(|e| {
        crate::error::McpxError::Startup(format!(
            "oauth http client mTLS: PEM parse (cert={}, key={}): {e}",
            cc.cert_path.display(),
            cc.key_path.display()
        ))
    })?;

    let resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(allowlist),
            #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
            test_bypass.clone(),
        ));

    let base = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&resolver))
        .connect_timeout(Duration::from_secs(10))
        .timeout(Duration::from_secs(30))
        .redirect(reqwest::redirect::Policy::none())
        .identity(identity);

    // Optional extra trust anchor.
    let builder = match cfg.ca_cert_path.as_ref() {
        Some(ca_path) => {
            let pem = std::fs::read(ca_path).map_err(|e| {
                crate::error::McpxError::Startup(format!(
                    "oauth http client mTLS: read ca_cert_path {}: {e}",
                    ca_path.display()
                ))
            })?;
            let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
                crate::error::McpxError::Startup(format!(
                    "oauth http client mTLS: parse ca_cert_path {}: {e}",
                    ca_path.display()
                ))
            })?;
            base.add_root_certificate(cert)
        }
        None => base,
    };

    let client = builder.build().map_err(|e| {
        crate::error::McpxError::Startup(format!("oauth http client mTLS init: {e}"))
    })?;

    let key = MtlsClientKey {
        cert_path: cc.cert_path.clone(),
        key_path: cc.key_path.clone(),
    };
    Ok(Arc::new(HashMap::from([(key, client)])))
}
1199
1200fn check_oauth_url(
1207 field: &str,
1208 raw: &str,
1209 allow_http: bool,
1210) -> Result<url::Url, crate::error::McpxError> {
1211 let parsed = url::Url::parse(raw).map_err(|e| {
1212 crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
1213 })?;
1214 if !parsed.username().is_empty() || parsed.password().is_some() {
1215 return Err(crate::error::McpxError::Config(format!(
1216 "{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
1217 )));
1218 }
1219 match parsed.scheme() {
1220 "https" => Ok(parsed),
1221 "http" if allow_http => Ok(parsed),
1222 "http" => Err(crate::error::McpxError::Config(format!(
1223 "{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
1224 to override - strongly discouraged in production)"
1225 ))),
1226 other => Err(crate::error::McpxError::Config(format!(
1227 "{field}: must use https scheme (got {other:?})"
1228 ))),
1229 }
1230}
1231
/// Builder for `OAuthConfig`, created by `OAuthConfig::builder`.
///
/// Each setter consumes and returns the builder; `build` yields the finished
/// configuration.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthConfigBuilder {
    /// Configuration being assembled.
    inner: OAuthConfig,
}
1242
impl OAuthConfigBuilder {
    /// Replaces the whole list of scope-to-role mappings.
    pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
        self.inner.scopes = scopes;
        self
    }

    /// Appends one scope-to-role mapping.
    pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.scopes.push(ScopeMapping {
            scope: scope.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the name of the role claim.
    pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
        self.inner.role_claim = Some(claim.into());
        self
    }

    /// Replaces the whole list of claim-value-to-role mappings.
    pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
        self.inner.role_mappings = mappings;
        self
    }

    /// Appends one claim-value-to-role mapping.
    pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.role_mappings.push(RoleMapping {
            claim_value: claim_value.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the JWKS cache lifetime as a humantime string (e.g. `"10m"`).
    /// The string is stored as given; `OAuthConfig::validate` parses it.
    pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
        self.inner.jwks_cache_ttl = ttl.into();
        self
    }

    /// Sets the OAuth proxy section.
    pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
        self.inner.proxy = Some(proxy);
        self
    }

    /// Sets the token-exchange section.
    pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
        self.inner.token_exchange = Some(token_exchange);
        self
    }

    /// Sets the path of an extra CA certificate (PEM) for OAuth HTTP traffic.
    pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.inner.ca_cert_path = Some(path.into());
        self
    }

    /// Allows (`true`) or forbids (`false`) plain `http` OAuth URLs.
    pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
        self.inner.allow_http_oauth_urls = allow;
        self
    }

    /// Sets the legacy strict-audience flag.
    ///
    /// Also clears any explicit `audience_validation_mode`, so the flag set
    /// here decides the effective mode.
    #[deprecated(since = "1.7.0", note = "use `audience_validation_mode` instead")]
    pub const fn strict_audience_validation(mut self, strict: bool) -> Self {
        #[allow(
            deprecated,
            reason = "intentional: deprecated builder forwards to deprecated field"
        )]
        {
            self.inner.strict_audience_validation = strict;
        }
        self.inner.audience_validation_mode = None;
        self
    }

    /// Sets an explicit audience validation mode, which takes precedence over
    /// the legacy flag.
    pub const fn audience_validation_mode(mut self, mode: AudienceValidationMode) -> Self {
        self.inner.audience_validation_mode = Some(mode);
        self
    }

    /// Sets the maximum JWKS response size in bytes.
    pub const fn jwks_max_response_bytes(mut self, bytes: u64) -> Self {
        self.inner.jwks_max_response_bytes = bytes;
        self
    }

    /// Sets the operator SSRF allowlist.
    pub fn ssrf_allowlist(mut self, allowlist: OAuthSsrfAllowlist) -> Self {
        self.inner.ssrf_allowlist = Some(allowlist);
        self
    }

    /// Finishes the builder and returns the configuration. No validation is
    /// performed here; call `OAuthConfig::validate` on the result.
    #[must_use]
    pub fn build(self) -> OAuthConfig {
        self.inner
    }
}
1378
/// Associates an OAuth scope with a role name.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ScopeMapping {
    /// OAuth scope string.
    pub scope: String,
    /// Role associated with that scope.
    pub role: String,
}
1388
/// Associates a claim value with a role name.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleMapping {
    /// Claim value to match.
    /// NOTE(review): presumably read from the claim named by
    /// `OAuthConfig::role_claim` -- confirm.
    pub claim_value: String,
    /// Role associated with that claim value.
    pub role: String,
}
1400
/// Settings for token exchange against an authorization server.
///
/// Exactly one of `client_secret` and `client_cert` must be set; this is
/// enforced by `validate_token_exchange_client_auth`.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenExchangeConfig {
    /// Token endpoint URL; checked by `OAuthConfig::validate`.
    pub token_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Client secret. Mutually exclusive with `client_cert`.
    pub client_secret: Option<secrecy::SecretString>,
    /// Client certificate for mTLS authentication. Mutually exclusive with
    /// `client_secret`; requires the `oauth-mtls-client` feature.
    pub client_cert: Option<ClientCertConfig>,
    /// Audience value for the exchange.
    /// NOTE(review): presumably the audience requested for the new token -- confirm.
    pub audience: String,
}
1438
impl TokenExchangeConfig {
    /// Creates a token-exchange configuration from its five fields, stored as
    /// given. No validation happens here; `OAuthConfig::validate` checks the
    /// URL and that exactly one authentication method is set.
    #[must_use]
    pub fn new(
        token_url: String,
        client_id: String,
        client_secret: Option<secrecy::SecretString>,
        client_cert: Option<ClientCertConfig>,
        audience: String,
    ) -> Self {
        Self {
            token_url,
            client_id,
            client_secret,
            client_cert,
            audience,
        }
    }
}
1458
/// Paths of the client certificate and private key used for mTLS client
/// authentication. Both files are read as PEM.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ClientCertConfig {
    /// Path of the PEM client certificate.
    pub cert_path: PathBuf,
    /// Path of the PEM private key.
    pub key_path: PathBuf,
}
1473
impl ClientCertConfig {
    /// Creates a client-certificate configuration from the two paths. The
    /// files are not touched here.
    #[must_use]
    pub fn new(cert_path: PathBuf, key_path: PathBuf) -> Self {
        Self {
            cert_path,
            key_path,
        }
    }
}
1486
1487#[derive(Debug, Deserialize)]
1489#[non_exhaustive]
1490pub struct ExchangedToken {
1491 pub access_token: String,
1493 pub expires_in: Option<u64>,
1495 pub issued_token_type: Option<String>,
1498}
1499
/// Settings for the OAuth proxy section (`oauth.proxy`).
///
/// Every URL here is checked by `OAuthConfig::validate`.
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct OAuthProxyConfig {
    /// Authorization endpoint URL.
    pub authorize_url: String,
    /// Token endpoint URL.
    pub token_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Optional client secret.
    pub client_secret: Option<secrecy::SecretString>,
    /// Optional introspection endpoint URL. `None` when omitted.
    #[serde(default)]
    pub introspection_url: Option<String>,
    /// Optional revocation endpoint URL. `None` when omitted.
    #[serde(default)]
    pub revocation_url: Option<String>,
    /// Expose the admin endpoints. When `true`, validation requires one of
    /// the two flags below to be `true` as well. Defaults to `false`.
    #[serde(default)]
    pub expose_admin_endpoints: bool,
    /// Require authentication on the admin endpoints (recommended).
    /// Defaults to `false`.
    #[serde(default)]
    pub require_auth_on_admin_endpoints: bool,
    /// Explicit opt-out that permits unauthenticated admin endpoints; only
    /// safe behind an authenticated reverse proxy. Defaults to `false`.
    #[serde(default)]
    pub allow_unauthenticated_admin_endpoints: bool,
}
1562
impl OAuthProxyConfig {
    /// Starts a builder from the three mandatory values; every other field
    /// starts at its `Default` value (`None` / `false`).
    pub fn builder(
        authorize_url: impl Into<String>,
        token_url: impl Into<String>,
        client_id: impl Into<String>,
    ) -> OAuthProxyConfigBuilder {
        OAuthProxyConfigBuilder {
            inner: Self {
                authorize_url: authorize_url.into(),
                token_url: token_url.into(),
                client_id: client_id.into(),
                ..Self::default()
            },
        }
    }
}
1586
/// Builder for `OAuthProxyConfig`, created by `OAuthProxyConfig::builder`.
///
/// Each setter consumes and returns the builder; `build` yields the finished
/// configuration.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthProxyConfigBuilder {
    /// Configuration being assembled.
    inner: OAuthProxyConfig,
}
1597
impl OAuthProxyConfigBuilder {
    /// Sets the client secret.
    pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
        self.inner.client_secret = Some(secret);
        self
    }

    /// Sets the introspection endpoint URL.
    pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
        self.inner.introspection_url = Some(url.into());
        self
    }

    /// Sets the revocation endpoint URL.
    pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
        self.inner.revocation_url = Some(url.into());
        self
    }

    /// Sets whether the admin endpoints are exposed. When enabled,
    /// `OAuthConfig::validate` requires one of the two authentication flags
    /// to be set.
    pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
        self.inner.expose_admin_endpoints = expose;
        self
    }

    /// Sets whether the admin endpoints require authentication.
    pub const fn require_auth_on_admin_endpoints(mut self, require: bool) -> Self {
        self.inner.require_auth_on_admin_endpoints = require;
        self
    }

    /// Sets the explicit opt-out that permits unauthenticated admin endpoints.
    pub const fn allow_unauthenticated_admin_endpoints(mut self, allow: bool) -> Self {
        self.inner.allow_unauthenticated_admin_endpoints = allow;
        self
    }

    /// Finishes the builder and returns the configuration. No validation is
    /// performed here.
    #[must_use]
    pub fn build(self) -> OAuthProxyConfig {
        self.inner
    }
}
1654
/// Pair of key collections: a map from a string key to `(Algorithm,
/// DecodingKey)`, and a list of `(Algorithm, DecodingKey)` entries that have
/// no such key. The shape mirrors `CachedKeys::keys` / `unnamed_keys`.
/// NOTE(review): the map key is presumably the JWK `kid` -- confirm.
type JwksKeyCache = (
    HashMap<String, (Algorithm, DecodingKey)>,
    Vec<(Algorithm, DecodingKey)>,
);
1666
/// One fetched key set together with its freshness information.
struct CachedKeys {
    /// Keys addressable by a string identifier.
    /// NOTE(review): presumably the JWK `kid` -- confirm.
    keys: HashMap<String, (Algorithm, DecodingKey)>,
    /// Keys without an identifier.
    unnamed_keys: Vec<(Algorithm, DecodingKey)>,
    /// Instant from which `ttl` is measured.
    fetched_at: Instant,
    /// Age at which the set counts as expired.
    ttl: Duration,
}
1675
1676impl CachedKeys {
1677 fn is_expired(&self) -> bool {
1678 self.fetched_at.elapsed() >= self.ttl
1679 }
1680}
1681
/// JWKS-backed token validator state: the cached key set, the HTTP client
/// and the validation settings.
///
/// Field notes below describe only what the declarations show; behaviour is
/// implemented in `impl JwksCache`.
#[allow(
    missing_debug_implementations,
    reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
)]
#[non_exhaustive]
pub struct JwksCache {
    /// URL of the JWKS document.
    jwks_uri: String,
    /// Lifetime of a cached key set.
    ttl: Duration,
    /// Upper bound on JWKS keys.
    max_jwks_keys: usize,
    /// Upper bound on the JWKS response size, in bytes.
    max_response_bytes: u64,
    /// Whether plain `http` URLs are accepted.
    allow_http: bool,
    /// Cached key set behind an async read/write lock.
    /// NOTE(review): presumably `None` until the first successful fetch -- confirm.
    inner: RwLock<Option<CachedKeys>>,
    /// HTTP client.
    /// NOTE(review): presumably used to fetch the JWKS document -- confirm.
    http: reqwest::Client,
    /// Base `jsonwebtoken::Validation` settings.
    validation_template: Validation,
    /// Audience this server expects.
    expected_audience: String,
    /// Audience validation mode in force.
    audience_mode: AudienceValidationMode,
    /// Atomic flag.
    /// NOTE(review): judging by the name, a one-shot guard for an `azp`
    /// fallback warning -- confirm.
    azp_fallback_warned: AtomicBool,
    /// Scope-to-role mappings.
    scopes: Vec<ScopeMapping>,
    /// Optional role claim name.
    role_claim: Option<String>,
    /// Claim-value-to-role mappings.
    role_mappings: Vec<RoleMapping>,
    /// Time of the most recent refresh attempt, if any.
    /// NOTE(review): presumably paired with `JWKS_REFRESH_COOLDOWN` -- confirm.
    last_refresh_attempt: RwLock<Option<Instant>>,
    /// Async mutex.
    /// NOTE(review): presumably serialises concurrent refreshes -- confirm.
    refresh_lock: tokio::sync::Mutex<()>,
    /// Compiled operator SSRF allowlist.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Test-only loopback SSRF bypass flag; present only in test builds or
    /// with the `test-helpers` feature.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
1730
/// Minimum time between two JWKS refresh attempts.
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_secs(10);

/// Size cap (1 MiB) applied to upstream OAuth proxy and token-exchange response bodies.
const OAUTH_PROXY_MAX_RESPONSE_BYTES: u64 = 1024 * 1024;

/// Signature algorithms accepted in a JWT header; any other `alg` is rejected
/// before key lookup.
const ACCEPTED_ALGS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
    Algorithm::EdDSA,
];
1757
/// Reason a JWT was rejected by [`JwksCache::validate_token_with_reason`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
    /// The token decoded but the JWT library reported an expired signature.
    Expired,
    /// Any other failure: bad header, unaccepted algorithm, unknown key,
    /// failed verification, audience mismatch, or no matching role.
    Invalid,
}
1767
1768impl JwksCache {
    /// Builds a validator from `config` without fetching any keys.
    ///
    /// Sets up the validation template, the SSRF allowlist, and an HTTP
    /// client that bypasses proxies, resolves hosts through the SSRF-screening
    /// resolver, and vets every redirect with `evaluate_oauth_redirect`.
    ///
    /// # Errors
    ///
    /// Returns an error when `jwks_cache_ttl` cannot be parsed, the SSRF
    /// allowlist fails to compile, the CA certificate cannot be read or
    /// parsed, or the HTTP client cannot be built.
    pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        // The install results are deliberately discarded.
        // NOTE(review): presumably because a provider may already be installed
        // process-wide — confirm.
        rustls::crypto::ring::default_provider()
            .install_default()
            .ok();
        jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
            .install_default()
            .ok();

        let ttl = humantime::parse_duration(&config.jwks_cache_ttl).map_err(|error| {
            format!(
                "invalid jwks_cache_ttl {:?}: {error}",
                config.jwks_cache_ttl
            )
        })?;

        // RS256 is only a placeholder: `decode_claims` overrides the algorithm
        // list per token. Audience validation is disabled here because
        // `check_audience` performs it separately.
        let mut validation = Validation::new(Algorithm::RS256);
        validation.validate_aud = false;
        validation.set_issuer(&[&config.issuer]);
        validation.set_required_spec_claims(&["exp", "iss"]);
        validation.validate_exp = true;
        validation.validate_nbf = true;

        let allow_http = config.allow_http_oauth_urls;

        // Compile the configured allowlist, or fall back to the default one.
        let allowlist = match config.ssrf_allowlist.as_ref() {
            Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
                Box::<dyn std::error::Error + Send + Sync>::from(format!(
                    "oauth.ssrf_allowlist: {e}"
                ))
            })?),
            None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
        };
        // Separate handle moved into the redirect-policy closure below.
        let redirect_allowlist = Arc::clone(&allowlist);

        // The bypass is a shared flag in test builds and a unit value otherwise.
        #[cfg(any(test, feature = "test-helpers"))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
            Arc::new(AtomicBool::new(false));
        #[cfg(not(any(test, feature = "test-helpers")))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();

        let resolver: Arc<dyn reqwest::dns::Resolve> =
            Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
                Arc::clone(&allowlist),
                #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
                test_bypass.clone(),
            ));

        let mut http_builder = reqwest::Client::builder()
            .no_proxy()
            .dns_resolver(Arc::clone(&resolver))
            .timeout(Duration::from_secs(10))
            .connect_timeout(Duration::from_secs(3))
            .redirect(reqwest::redirect::Policy::custom(move |attempt| {
                // Follow only redirects the shared policy approves; log and
                // fail the request otherwise.
                match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
                    Ok(()) => attempt.follow(),
                    Err(reason) => {
                        tracing::warn!(
                            reason = %reason,
                            target = %crate::ssrf::sanitized_url_for_log(attempt.url()),
                            "oauth redirect rejected"
                        );
                        attempt.error(reason)
                    }
                }
            }));

        // Optionally trust an extra root certificate read from disk.
        if let Some(ref ca_path) = config.ca_cert_path {
            let pem = std::fs::read(ca_path)?;
            let cert = reqwest::tls::Certificate::from_pem(&pem)?;
            http_builder = http_builder.add_root_certificate(cert);
        }

        let http = http_builder.build()?;

        Ok(Self {
            jwks_uri: config.jwks_uri.clone(),
            ttl,
            max_jwks_keys: config.max_jwks_keys,
            max_response_bytes: config.jwks_max_response_bytes,
            allow_http,
            inner: RwLock::new(None),
            http,
            validation_template: validation,
            expected_audience: config.audience.clone(),
            audience_mode: config.effective_audience_validation_mode(),
            azp_fallback_warned: AtomicBool::new(false),
            scopes: config.scopes.clone(),
            role_claim: config.role_claim.clone(),
            role_mappings: config.role_mappings.clone(),
            last_refresh_attempt: RwLock::new(None),
            refresh_lock: tokio::sync::Mutex::new(()),
            allowlist,
            #[cfg(any(test, feature = "test-helpers"))]
            test_allow_loopback_ssrf: test_bypass,
        })
    }
1910
    /// Test-only: sets the loopback SSRF bypass flag and returns the cache.
    ///
    /// The flag is the same shared value handed to the SSRF-screening
    /// resolver in `new`, and `fetch_jwks` reads it to pick the screening
    /// variant with the test override.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_allow_loopback_ssrf(self) -> Self {
        self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
        self
    }
1923
1924 pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
1926 self.validate_token_with_reason(token).await.ok()
1927 }
1928
1929 pub async fn validate_token_with_reason(
1939 &self,
1940 token: &str,
1941 ) -> Result<AuthIdentity, JwtValidationFailure> {
1942 let claims = self.decode_claims(token).await?;
1943
1944 self.check_audience(&claims)?;
1945 let role = self.resolve_role(&claims)?;
1946
1947 let sub = claims.sub;
1950 let name = claims
1951 .extra
1952 .get("preferred_username")
1953 .and_then(|v| v.as_str())
1954 .map(String::from)
1955 .or_else(|| sub.clone())
1956 .or(claims.azp)
1957 .or(claims.client_id)
1958 .unwrap_or_else(|| "oauth-client".into());
1959
1960 Ok(AuthIdentity {
1961 name,
1962 role,
1963 method: AuthMethod::OAuthJwt,
1964 raw_token: None,
1965 sub,
1966 })
1967 }
1968
1969 async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
1985 let (key, alg) = self.select_jwks_key(token).await?;
1986
1987 let mut validation = self.validation_template.clone();
1991 validation.algorithms = vec![alg];
1992
1993 let token_owned = token.to_owned();
1996 let join =
1997 tokio::task::spawn_blocking(move || decode::<Claims>(&token_owned, &key, &validation))
1998 .await;
1999
2000 let decode_result = match join {
2001 Ok(r) => r,
2002 Err(join_err) => {
2003 core::hint::cold_path();
2004 tracing::error!(
2005 error = %join_err,
2006 "JWT decode task panicked or was cancelled"
2007 );
2008 return Err(JwtValidationFailure::Invalid);
2009 }
2010 };
2011
2012 decode_result.map(|td| td.claims).map_err(|e| {
2013 core::hint::cold_path();
2014 let failure = if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
2015 JwtValidationFailure::Expired
2016 } else {
2017 JwtValidationFailure::Invalid
2018 };
2019 tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
2020 failure
2021 })
2022 }
2023
2024 #[allow(
2033 clippy::cognitive_complexity,
2034 reason = "each failure arm pairs `cold_path()` with a distinct `tracing::debug!` site for observability; collapsing into combinators would lose structured-field log sites without reducing real complexity"
2035 )]
2036 async fn select_jwks_key(
2037 &self,
2038 token: &str,
2039 ) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
2040 let Ok(header) = decode_header(token) else {
2041 core::hint::cold_path();
2042 tracing::debug!("JWT header decode failed");
2043 return Err(JwtValidationFailure::Invalid);
2044 };
2045 let kid = header.kid.as_deref();
2046 tracing::debug!(alg = ?header.alg, kid = kid.unwrap_or("-"), "JWT header decoded");
2047
2048 if !ACCEPTED_ALGS.contains(&header.alg) {
2049 core::hint::cold_path();
2050 tracing::debug!(alg = ?header.alg, "JWT algorithm not accepted");
2051 return Err(JwtValidationFailure::Invalid);
2052 }
2053
2054 let Some(key) = self.find_key(kid, header.alg).await else {
2055 core::hint::cold_path();
2056 tracing::debug!(kid = kid.unwrap_or("-"), alg = ?header.alg, "no matching JWKS key found");
2057 return Err(JwtValidationFailure::Invalid);
2058 };
2059
2060 Ok((key, header.alg))
2061 }
2062
2063 fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
2072 if claims.aud.contains(&self.expected_audience) {
2073 return Ok(());
2074 }
2075 let azp_match = claims
2076 .azp
2077 .as_deref()
2078 .is_some_and(|azp| azp == self.expected_audience);
2079 if azp_match {
2080 match self.audience_mode {
2081 AudienceValidationMode::Permissive => return Ok(()),
2082 AudienceValidationMode::Warn => {
2083 if !self.azp_fallback_warned.swap(true, Ordering::Relaxed) {
2084 tracing::warn!(
2085 expected = %self.expected_audience,
2086 azp = claims.azp.as_deref().unwrap_or("-"),
2087 "JWT accepted via deprecated azp-only audience fallback. \
2088 Configure your IdP to populate aud, or set \
2089 audience_validation_mode = \"strict\" once tokens carry aud correctly. \
2090 To silence this warning without changing acceptance, \
2091 set audience_validation_mode = \"permissive\". \
2092 This warning logs once per process."
2093 );
2094 }
2095 return Ok(());
2096 }
2097 AudienceValidationMode::Strict => {}
2098 }
2099 }
2100 core::hint::cold_path();
2101 tracing::debug!(
2102 aud = %claims.aud.log_display(),
2103 azp = claims.azp.as_deref().unwrap_or("-"),
2104 expected = %self.expected_audience,
2105 mode = self.audience_mode.as_str(),
2106 "JWT rejected: audience mismatch"
2107 );
2108 Err(JwtValidationFailure::Invalid)
2109 }
2110
2111 fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
2117 if let Some(ref claim_path) = self.role_claim {
2118 let owned_first_class: Vec<String> = first_class_claim_values(claims, claim_path);
2119 let mut values: Vec<&str> = owned_first_class.iter().map(String::as_str).collect();
2120 values.extend(resolve_claim_path(&claims.extra, claim_path));
2121 return self
2122 .role_mappings
2123 .iter()
2124 .find(|m| values.contains(&m.claim_value.as_str()))
2125 .map(|m| m.role.clone())
2126 .ok_or(JwtValidationFailure::Invalid);
2127 }
2128
2129 let token_scopes: Vec<&str> = claims
2130 .scope
2131 .as_deref()
2132 .unwrap_or("")
2133 .split_whitespace()
2134 .collect();
2135
2136 self.scopes
2137 .iter()
2138 .find(|m| token_scopes.contains(&m.scope.as_str()))
2139 .map(|m| m.role.clone())
2140 .ok_or(JwtValidationFailure::Invalid)
2141 }
2142
2143 async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2149 {
2151 let guard = self.inner.read().await;
2152 if let Some(cached) = guard.as_ref()
2153 && !cached.is_expired()
2154 && let Some(key) = lookup_key(cached, kid, alg)
2155 {
2156 return Some(key);
2157 }
2158 }
2159
2160 self.refresh_with_cooldown().await;
2162
2163 let guard = self.inner.read().await;
2164 guard
2165 .as_ref()
2166 .and_then(|cached| lookup_key(cached, kid, alg))
2167 }
2168
2169 async fn refresh_with_cooldown(&self) {
2189 let _guard = self.refresh_lock.lock().await;
2191
2192 {
2194 let last = self.last_refresh_attempt.read().await;
2195 if let Some(ts) = *last
2196 && ts.elapsed() < JWKS_REFRESH_COOLDOWN
2197 {
2198 tracing::debug!(
2199 elapsed_ms = ts.elapsed().as_millis(),
2200 cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
2201 "JWKS refresh skipped (cooldown active)"
2202 );
2203 return;
2204 }
2205 }
2206
2207 {
2210 let mut last = self.last_refresh_attempt.write().await;
2211 *last = Some(Instant::now());
2212 }
2213
2214 let _ = self.refresh_inner().await;
2216 }
2217
2218 async fn refresh_inner(&self) -> Result<(), String> {
2227 let Some(jwks) = self.fetch_jwks().await else {
2228 return Ok(());
2229 };
2230 let (keys, unnamed_keys) = match build_key_cache(&jwks, self.max_jwks_keys) {
2231 Ok(cache) => cache,
2232 Err(msg) => {
2233 tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
2234 return Err(msg);
2235 }
2236 };
2237
2238 tracing::debug!(
2239 named = keys.len(),
2240 unnamed = unnamed_keys.len(),
2241 "JWKS refreshed"
2242 );
2243
2244 let mut guard = self.inner.write().await;
2245 *guard = Some(CachedKeys {
2246 keys,
2247 unnamed_keys,
2248 fetched_at: Instant::now(),
2249 ttl: self.ttl,
2250 });
2251 drop(guard);
2252 Ok(())
2253 }
2254
2255 #[allow(
2257 clippy::cognitive_complexity,
2258 reason = "screening, bounded streaming, and parse logging are intentionally kept in one fetch path"
2259 )]
2260 async fn fetch_jwks(&self) -> Option<JwkSet> {
2261 #[cfg(any(test, feature = "test-helpers"))]
2262 let screening = if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
2263 screen_oauth_target_with_test_override(
2264 &self.jwks_uri,
2265 self.allow_http,
2266 &self.allowlist,
2267 true,
2268 )
2269 .await
2270 } else {
2271 screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await
2272 };
2273 #[cfg(not(any(test, feature = "test-helpers")))]
2274 let screening = screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await;
2275
2276 if let Err(error) = screening {
2277 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to screen JWKS target");
2278 return None;
2279 }
2280
2281 let mut resp = match self.http.get(&self.jwks_uri).send().await {
2282 Ok(resp) => resp,
2283 Err(e) => {
2284 tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
2285 return None;
2286 }
2287 };
2288
2289 let initial_capacity =
2290 usize::try_from(self.max_response_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
2291 let mut body = Vec::with_capacity(initial_capacity);
2292 while let Some(chunk) = match resp.chunk().await {
2293 Ok(chunk) => chunk,
2294 Err(error) => {
2295 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to read JWKS response");
2296 return None;
2297 }
2298 } {
2299 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
2300 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
2301 if body_len.saturating_add(chunk_len) > self.max_response_bytes {
2302 tracing::warn!(
2303 uri = %self.jwks_uri,
2304 max_bytes = self.max_response_bytes,
2305 "JWKS response exceeded configured size cap"
2306 );
2307 return None;
2308 }
2309 body.extend_from_slice(&chunk);
2310 }
2311
2312 match serde_json::from_slice::<JwkSet>(&body) {
2313 Ok(jwks) => Some(jwks),
2314 Err(error) => {
2315 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to parse JWKS");
2316 None
2317 }
2318 }
2319 }
2320
2321 #[cfg(any(test, feature = "test-helpers"))]
2324 #[doc(hidden)]
2325 pub async fn __test_refresh_now(&self) -> Result<(), String> {
2326 let jwks = self
2327 .fetch_jwks()
2328 .await
2329 .ok_or_else(|| "failed to fetch or parse JWKS".to_owned())?;
2330 let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;
2331 let mut guard = self.inner.write().await;
2332 *guard = Some(CachedKeys {
2333 keys,
2334 unnamed_keys,
2335 fetched_at: Instant::now(),
2336 ttl: self.ttl,
2337 });
2338 drop(guard);
2339 Ok(())
2340 }
2341
2342 #[cfg(any(test, feature = "test-helpers"))]
2345 #[doc(hidden)]
2346 pub async fn __test_has_kid(&self, kid: &str) -> bool {
2347 let guard = self.inner.read().await;
2348 guard
2349 .as_ref()
2350 .is_some_and(|cache| cache.keys.contains_key(kid))
2351 }
2352}
2353
2354fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
2356 if jwks.keys.len() > max_keys {
2357 return Err(format!(
2358 "jwks_key_count_exceeds_cap: got {} keys, max is {}",
2359 jwks.keys.len(),
2360 max_keys
2361 ));
2362 }
2363 let mut keys = HashMap::new();
2364 let mut unnamed_keys = Vec::new();
2365 for jwk in &jwks.keys {
2366 let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
2367 continue;
2368 };
2369 let Some(alg) = jwk_algorithm(jwk) else {
2370 continue;
2371 };
2372 if let Some(ref kid) = jwk.common.key_id {
2373 keys.insert(kid.clone(), (alg, decoding_key));
2374 } else {
2375 unnamed_keys.push((alg, decoding_key));
2376 }
2377 }
2378 Ok((keys, unnamed_keys))
2379}
2380
2381fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2383 if let Some(kid) = kid
2384 && let Some((cached_alg, key)) = cached.keys.get(kid)
2385 && *cached_alg == alg
2386 {
2387 return Some(key.clone());
2388 }
2389 cached
2391 .unnamed_keys
2392 .iter()
2393 .find(|(a, _)| *a == alg)
2394 .map(|(_, k)| k.clone())
2395}
2396
2397#[allow(
2399 clippy::wildcard_enum_match_arm,
2400 reason = "jsonwebtoken KeyAlgorithm is a large external enum; only the JWT-signing variants are mappable to `Algorithm`"
2401)]
2402fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
2403 jwk.common.key_algorithm.and_then(|ka| match ka {
2404 jsonwebtoken::jwk::KeyAlgorithm::RS256 => Some(Algorithm::RS256),
2405 jsonwebtoken::jwk::KeyAlgorithm::RS384 => Some(Algorithm::RS384),
2406 jsonwebtoken::jwk::KeyAlgorithm::RS512 => Some(Algorithm::RS512),
2407 jsonwebtoken::jwk::KeyAlgorithm::ES256 => Some(Algorithm::ES256),
2408 jsonwebtoken::jwk::KeyAlgorithm::ES384 => Some(Algorithm::ES384),
2409 jsonwebtoken::jwk::KeyAlgorithm::PS256 => Some(Algorithm::PS256),
2410 jsonwebtoken::jwk::KeyAlgorithm::PS384 => Some(Algorithm::PS384),
2411 jsonwebtoken::jwk::KeyAlgorithm::PS512 => Some(Algorithm::PS512),
2412 jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
2413 _ => None,
2414 })
2415}
2416
2417fn first_class_claim_values(claims: &Claims, path: &str) -> Vec<String> {
2438 match path {
2439 "sub" => claims.sub.iter().cloned().collect(),
2440 "azp" => claims.azp.iter().cloned().collect(),
2441 "client_id" => claims.client_id.iter().cloned().collect(),
2442 "aud" => claims.aud.0.clone(),
2443 "scope" => claims
2444 .scope
2445 .as_deref()
2446 .unwrap_or("")
2447 .split_whitespace()
2448 .map(str::to_owned)
2449 .collect(),
2450 _ => Vec::new(),
2451 }
2452}
2453
2454fn resolve_claim_path<'a>(
2464 extra: &'a HashMap<String, serde_json::Value>,
2465 path: &str,
2466) -> Vec<&'a str> {
2467 let mut segments = path.split('.');
2468 let Some(first) = segments.next() else {
2469 return Vec::new();
2470 };
2471
2472 let mut current: Option<&serde_json::Value> = extra.get(first);
2473
2474 for segment in segments {
2475 current = current.and_then(|v| v.get(segment));
2476 }
2477
2478 match current {
2479 Some(serde_json::Value::String(s)) => s.split_whitespace().collect(),
2480 Some(serde_json::Value::Array(arr)) => arr.iter().filter_map(|v| v.as_str()).collect(),
2481 _ => Vec::new(),
2482 }
2483}
2484
/// Claim set deserialized from a verified JWT.
#[derive(Debug, Deserialize)]
struct Claims {
    /// `sub` claim, if present.
    sub: Option<String>,
    /// `aud` claim as a list of strings; empty when the claim is absent.
    #[serde(default)]
    aud: OneOrMany,
    /// `azp` claim, if present.
    azp: Option<String>,
    /// `client_id` claim, if present.
    client_id: Option<String>,
    /// `scope` claim as a single whitespace-separated string, if present.
    scope: Option<String>,
    /// All remaining claims, keyed by claim name.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
2508
/// List of strings for a claim that may arrive as one string or an array.
#[derive(Debug, Default)]
struct OneOrMany(Vec<String>);

impl OneOrMany {
    /// Returns `true` when any entry equals `value` exactly.
    fn contains(&self, value: &str) -> bool {
        self.0.iter().map(String::as_str).any(|entry| entry == value)
    }

    /// Renders the entries for logging: `-` when empty, otherwise the
    /// entries joined with `, `.
    fn log_display(&self) -> String {
        match self.0.as_slice() {
            [] => "-".to_owned(),
            entries => entries.join(", "),
        }
    }
}
2529
2530fn fmt_json_aud(value: Option<&serde_json::Value>) -> String {
2540 match value {
2541 Some(serde_json::Value::String(s)) => s.clone(),
2542 Some(serde_json::Value::Array(items)) => {
2543 let joined = items
2544 .iter()
2545 .filter_map(serde_json::Value::as_str)
2546 .collect::<Vec<_>>()
2547 .join(", ");
2548 if joined.is_empty() {
2549 "-".to_owned()
2550 } else {
2551 joined
2552 }
2553 }
2554 Some(
2555 serde_json::Value::Null
2556 | serde_json::Value::Bool(_)
2557 | serde_json::Value::Number(_)
2558 | serde_json::Value::Object(_),
2559 )
2560 | None => "-".to_owned(),
2561 }
2562}
2563
2564fn fmt_json_str(value: Option<&serde_json::Value>) -> &str {
2568 value.and_then(serde_json::Value::as_str).unwrap_or("-")
2569}
2570
2571impl<'de> Deserialize<'de> for OneOrMany {
2572 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
2573 use serde::de;
2574
2575 struct Visitor;
2576 impl<'de> de::Visitor<'de> for Visitor {
2577 type Value = OneOrMany;
2578 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2579 f.write_str("a string or array of strings")
2580 }
2581 fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
2582 Ok(OneOrMany(vec![v.to_owned()]))
2583 }
2584 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
2585 let mut v = Vec::new();
2586 while let Some(s) = seq.next_element::<String>()? {
2587 v.push(s);
2588 }
2589 Ok(OneOrMany(v))
2590 }
2591 }
2592 deserializer.deserialize_any(Visitor)
2593 }
2594}
2595
2596#[must_use]
2603pub fn looks_like_jwt(token: &str) -> bool {
2604 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2605
2606 let mut parts = token.splitn(4, '.');
2607 let Some(header_b64) = parts.next() else {
2608 return false;
2609 };
2610 if parts.next().is_none() || parts.next().is_none() || parts.next().is_some() {
2612 return false;
2613 }
2614 let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
2616 return false;
2617 };
2618 let Ok(header) = serde_json::from_slice::<serde_json::Value>(&header_bytes) else {
2620 return false;
2621 };
2622 header.get("alg").is_some()
2623}
2624
2625#[must_use]
2635pub fn protected_resource_metadata(
2636 resource_url: &str,
2637 server_url: &str,
2638 config: &OAuthConfig,
2639) -> serde_json::Value {
2640 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2645 let auth_server = server_url;
2646 serde_json::json!({
2647 "resource": resource_url,
2648 "authorization_servers": [auth_server],
2649 "scopes_supported": scopes,
2650 "bearer_methods_supported": ["header"]
2651 })
2652}
2653
2654#[must_use]
2659pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
2660 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2661 let mut meta = serde_json::json!({
2662 "issuer": &config.issuer,
2663 "authorization_endpoint": format!("{server_url}/authorize"),
2664 "token_endpoint": format!("{server_url}/token"),
2665 "registration_endpoint": format!("{server_url}/register"),
2666 "response_types_supported": ["code"],
2667 "grant_types_supported": ["authorization_code", "refresh_token"],
2668 "code_challenge_methods_supported": ["S256"],
2669 "scopes_supported": scopes,
2670 "token_endpoint_auth_methods_supported": ["none"],
2671 });
2672 if let Some(proxy) = &config.proxy
2673 && proxy.expose_admin_endpoints
2674 && let Some(obj) = meta.as_object_mut()
2675 {
2676 if proxy.introspection_url.is_some() {
2677 obj.insert(
2678 "introspection_endpoint".into(),
2679 serde_json::Value::String(format!("{server_url}/introspect")),
2680 );
2681 }
2682 if proxy.revocation_url.is_some() {
2683 obj.insert(
2684 "revocation_endpoint".into(),
2685 serde_json::Value::String(format!("{server_url}/revoke")),
2686 );
2687 }
2688 if proxy.require_auth_on_admin_endpoints {
2689 obj.insert(
2690 "introspection_endpoint_auth_methods_supported".into(),
2691 serde_json::json!(["bearer"]),
2692 );
2693 obj.insert(
2694 "revocation_endpoint_auth_methods_supported".into(),
2695 serde_json::json!(["bearer"]),
2696 );
2697 }
2698 }
2699 meta
2700}
2701
2702#[must_use]
2715pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
2716 use axum::{
2717 http::{StatusCode, header},
2718 response::IntoResponse,
2719 };
2720
2721 let upstream_query = replace_client_id(query, &proxy.client_id);
2723 let redirect_url = format!("{}?{upstream_query}", proxy.authorize_url);
2724
2725 (StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
2726}
2727
2728pub async fn handle_token(
2734 http: &OauthHttpClient,
2735 proxy: &OAuthProxyConfig,
2736 body: &str,
2737) -> axum::response::Response {
2738 use axum::{
2739 http::{StatusCode, header},
2740 response::IntoResponse,
2741 };
2742
2743 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2745
2746 if let Some(ref secret) = proxy.client_secret {
2748 use std::fmt::Write;
2749
2750 use secrecy::ExposeSecret;
2751 let _ = write!(
2752 upstream_body,
2753 "&client_secret={}",
2754 urlencoding::encode(secret.expose_secret())
2755 );
2756 }
2757
2758 let result = http
2759 .send_screened(
2760 &proxy.token_url,
2761 http.inner
2762 .post(&proxy.token_url)
2763 .header("Content-Type", "application/x-www-form-urlencoded")
2764 .body(upstream_body),
2765 )
2766 .await;
2767
2768 match result {
2769 Ok(resp) => {
2770 let status =
2771 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2772 let Ok(body_bytes) =
2773 read_response_capped(resp, OAUTH_PROXY_MAX_RESPONSE_BYTES, "oauth/token").await
2774 else {
2775 return oauth_error_response(
2776 StatusCode::BAD_GATEWAY,
2777 "server_error",
2778 "upstream response too large or unreadable",
2779 );
2780 };
2781 (
2782 status,
2783 [(header::CONTENT_TYPE, "application/json")],
2784 body_bytes,
2785 )
2786 .into_response()
2787 }
2788 Err(e) => {
2789 tracing::error!(error = %e, "OAuth token proxy request failed");
2790 (
2791 StatusCode::BAD_GATEWAY,
2792 [(header::CONTENT_TYPE, "application/json")],
2793 "{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
2794 )
2795 .into_response()
2796 }
2797 }
2798}
2799
2800#[must_use]
2807pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
2808 let mut resp = serde_json::json!({
2809 "client_id": proxy.client_id,
2810 "token_endpoint_auth_method": "none",
2811 });
2812 if let Some(uris) = body.get("redirect_uris")
2813 && let Some(obj) = resp.as_object_mut()
2814 {
2815 obj.insert("redirect_uris".into(), uris.clone());
2816 }
2817 if let Some(name) = body.get("client_name")
2818 && let Some(obj) = resp.as_object_mut()
2819 {
2820 obj.insert("client_name".into(), name.clone());
2821 }
2822 resp
2823}
2824
2825pub async fn handle_introspect(
2831 http: &OauthHttpClient,
2832 proxy: &OAuthProxyConfig,
2833 body: &str,
2834) -> axum::response::Response {
2835 let Some(ref url) = proxy.introspection_url else {
2836 return oauth_error_response(
2837 axum::http::StatusCode::NOT_FOUND,
2838 "not_supported",
2839 "introspection endpoint is not configured",
2840 );
2841 };
2842 proxy_oauth_admin_request(http, proxy, url, body).await
2843}
2844
2845pub async fn handle_revoke(
2852 http: &OauthHttpClient,
2853 proxy: &OAuthProxyConfig,
2854 body: &str,
2855) -> axum::response::Response {
2856 let Some(ref url) = proxy.revocation_url else {
2857 return oauth_error_response(
2858 axum::http::StatusCode::NOT_FOUND,
2859 "not_supported",
2860 "revocation endpoint is not configured",
2861 );
2862 };
2863 proxy_oauth_admin_request(http, proxy, url, body).await
2864}
2865
2866async fn proxy_oauth_admin_request(
2870 http: &OauthHttpClient,
2871 proxy: &OAuthProxyConfig,
2872 upstream_url: &str,
2873 body: &str,
2874) -> axum::response::Response {
2875 use axum::{
2876 http::{StatusCode, header},
2877 response::IntoResponse,
2878 };
2879
2880 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2881 if let Some(ref secret) = proxy.client_secret {
2882 use std::fmt::Write;
2883
2884 use secrecy::ExposeSecret;
2885 let _ = write!(
2886 upstream_body,
2887 "&client_secret={}",
2888 urlencoding::encode(secret.expose_secret())
2889 );
2890 }
2891
2892 let result = http
2893 .send_screened(
2894 upstream_url,
2895 http.inner
2896 .post(upstream_url)
2897 .header("Content-Type", "application/x-www-form-urlencoded")
2898 .body(upstream_body),
2899 )
2900 .await;
2901
2902 match result {
2903 Ok(resp) => {
2904 let status =
2905 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2906 let content_type = resp
2907 .headers()
2908 .get(header::CONTENT_TYPE)
2909 .and_then(|v| v.to_str().ok())
2910 .unwrap_or("application/json")
2911 .to_owned();
2912 let Ok(body_bytes) =
2913 read_response_capped(resp, OAUTH_PROXY_MAX_RESPONSE_BYTES, "oauth/admin").await
2914 else {
2915 return oauth_error_response(
2916 StatusCode::BAD_GATEWAY,
2917 "server_error",
2918 "upstream response too large or unreadable",
2919 );
2920 };
2921 (status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
2922 }
2923 Err(e) => {
2924 tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
2925 oauth_error_response(
2926 StatusCode::BAD_GATEWAY,
2927 "server_error",
2928 "upstream endpoint unreachable",
2929 )
2930 }
2931 }
2932}
2933
2934async fn read_response_capped(
2944 mut resp: reqwest::Response,
2945 max_bytes: u64,
2946 context: &str,
2947) -> Result<Vec<u8>, ()> {
2948 let initial_capacity = usize::try_from(max_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
2949 let mut body = Vec::with_capacity(initial_capacity);
2950 loop {
2951 match resp.chunk().await {
2952 Ok(Some(chunk)) => {
2953 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
2954 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
2955 if body_len.saturating_add(chunk_len) > max_bytes {
2956 tracing::warn!(
2957 context = context,
2958 max_bytes = max_bytes,
2959 "upstream OAuth response exceeded size cap; failing closed"
2960 );
2961 return Err(());
2962 }
2963 body.extend_from_slice(&chunk);
2964 }
2965 Ok(None) => return Ok(body),
2966 Err(error) => {
2967 tracing::warn!(context = context, error = %error, "failed to read upstream OAuth response");
2968 return Err(());
2969 }
2970 }
2971 }
2972}
2973
2974fn oauth_error_response(
2975 status: axum::http::StatusCode,
2976 error: &str,
2977 description: &str,
2978) -> axum::response::Response {
2979 use axum::{http::header, response::IntoResponse};
2980 let body = serde_json::json!({
2981 "error": error,
2982 "error_description": description,
2983 });
2984 (
2985 status,
2986 [(header::CONTENT_TYPE, "application/json")],
2987 body.to_string(),
2988 )
2989 .into_response()
2990}
2991
/// Error body returned by an upstream authorization server, as parsed by
/// `exchange_token` from a non-success response.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Upstream error code; sanitized before it is surfaced to the caller.
    error: String,
    /// Optional upstream description; only written to the log.
    error_description: Option<String>,
}
3002
/// Reduces an upstream error code to a fixed set of known values.
///
/// A code on the pass-through list is returned unchanged (as a static
/// string); anything else, including the empty string, becomes
/// `server_error`. Matching is exact and case-sensitive.
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
    const PASS_THROUGH: [&str; 8] = [
        "invalid_request",
        "invalid_client",
        "invalid_grant",
        "unauthorized_client",
        "unsupported_grant_type",
        "invalid_scope",
        "temporarily_unavailable",
        "invalid_target",
    ];

    PASS_THROUGH
        .into_iter()
        .find(|known| *known == raw)
        .unwrap_or("server_error")
}
3025
3026pub async fn exchange_token(
3038 http: &OauthHttpClient,
3039 config: &TokenExchangeConfig,
3040 subject_token: &str,
3041) -> Result<ExchangedToken, crate::error::McpxError> {
3042 use secrecy::ExposeSecret;
3043
3044 let client = http.client_for(config);
3045 let mut req = client
3046 .post(&config.token_url)
3047 .header("Content-Type", "application/x-www-form-urlencoded")
3048 .header("Accept", "application/json");
3049
3050 if config.client_cert.is_none()
3059 && let Some(ref secret) = config.client_secret
3060 {
3061 use base64::Engine;
3062 let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
3063 "{}:{}",
3064 urlencoding::encode(&config.client_id),
3065 urlencoding::encode(secret.expose_secret()),
3066 ));
3067 req = req.header("Authorization", format!("Basic {credentials}"));
3068 }
3069
3070 let form_body = build_exchange_form(config, subject_token);
3071
3072 let resp = http
3073 .send_screened(&config.token_url, req.body(form_body))
3074 .await
3075 .map_err(|e| {
3076 tracing::error!(error = %e, "token exchange request failed");
3077 crate::error::McpxError::Auth("server_error".into())
3079 })?;
3080
3081 let status = resp.status();
3082 let body_bytes =
3083 read_response_capped(resp, OAUTH_PROXY_MAX_RESPONSE_BYTES, "oauth/token-exchange")
3084 .await
3085 .map_err(|()| {
3086 crate::error::McpxError::Auth("server_error".into())
3088 })?;
3089
3090 if !status.is_success() {
3091 core::hint::cold_path();
3092 let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
3095 let short_code = parsed
3096 .as_ref()
3097 .map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
3098 if let Some(ref e) = parsed {
3099 tracing::warn!(
3100 status = %status,
3101 upstream_error = %e.error,
3102 upstream_error_description = e.error_description.as_deref().unwrap_or(""),
3103 client_code = %short_code,
3104 "token exchange rejected by authorization server",
3105 );
3106 } else {
3107 tracing::warn!(
3108 status = %status,
3109 client_code = %short_code,
3110 "token exchange rejected (unparseable upstream body)",
3111 );
3112 }
3113 return Err(crate::error::McpxError::Auth(short_code.into()));
3114 }
3115
3116 let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
3117 tracing::error!(error = %e, "failed to parse token exchange response");
3118 crate::error::McpxError::Auth("server_error".into())
3121 })?;
3122
3123 log_exchanged_token(&exchanged);
3124
3125 Ok(exchanged)
3126}
3127
3128fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
3131 let body = format!(
3132 "grant_type={}&subject_token={}&subject_token_type={}&requested_token_type={}&audience={}",
3133 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
3134 urlencoding::encode(subject_token),
3135 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
3136 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
3137 urlencoding::encode(&config.audience),
3138 );
3139 if config.client_secret.is_none() {
3140 format!(
3141 "{body}&client_id={}",
3142 urlencoding::encode(&config.client_id)
3143 )
3144 } else {
3145 body
3146 }
3147}
3148
3149fn log_exchanged_token(exchanged: &ExchangedToken) {
3152 use base64::Engine;
3153
3154 if !looks_like_jwt(&exchanged.access_token) {
3155 tracing::debug!(
3156 token_len = exchanged.access_token.len(),
3157 issued_token_type = exchanged.issued_token_type.as_deref().unwrap_or("-"),
3158 expires_in = exchanged.expires_in,
3159 "exchanged token (opaque)",
3160 );
3161 return;
3162 }
3163 let Some(payload) = exchanged.access_token.split('.').nth(1) else {
3164 return;
3165 };
3166 let Ok(decoded) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload) else {
3167 return;
3168 };
3169 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&decoded) else {
3170 return;
3171 };
3172 tracing::debug!(
3173 sub = fmt_json_str(claims.get("sub")),
3174 aud = %fmt_json_aud(claims.get("aud")),
3175 azp = fmt_json_str(claims.get("azp")),
3176 iss = fmt_json_str(claims.get("iss")),
3177 expires_in = exchanged.expires_in,
3178 "exchanged token claims (JWT)",
3179 );
3180}
3181
3182fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
3184 let encoded_id = urlencoding::encode(upstream_client_id);
3185 let mut parts: Vec<String> = params
3186 .split('&')
3187 .filter(|p| !p.starts_with("client_id="))
3188 .map(String::from)
3189 .collect();
3190 parts.push(format!("client_id={encoded_id}"));
3191 parts.join("&")
3192}
3193
3194#[cfg(test)]
3195mod tests {
3196 use std::sync::Arc;
3197
3198 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
3199
3200 use super::*;
3201
3202 #[test]
3203 fn looks_like_jwt_valid() {
3204 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
3206 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3207 let token = format!("{header}.{payload}.signature");
3208 assert!(looks_like_jwt(&token));
3209 }
3210
3211 #[test]
3212 fn looks_like_jwt_rejects_opaque_token() {
3213 assert!(!looks_like_jwt("dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4"));
3214 }
3215
3216 #[test]
3217 fn looks_like_jwt_rejects_two_segments() {
3218 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
3219 let token = format!("{header}.payload");
3220 assert!(!looks_like_jwt(&token));
3221 }
3222
3223 #[test]
3224 fn looks_like_jwt_rejects_four_segments() {
3225 assert!(!looks_like_jwt("a.b.c.d"));
3226 }
3227
3228 #[test]
3229 fn looks_like_jwt_rejects_no_alg() {
3230 let header = URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}");
3231 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3232 let token = format!("{header}.{payload}.sig");
3233 assert!(!looks_like_jwt(&token));
3234 }
3235
3236 #[test]
3237 fn protected_resource_metadata_shape() {
3238 let config = OAuthConfig {
3239 issuer: "https://auth.example.com".into(),
3240 audience: "https://mcp.example.com/mcp".into(),
3241 jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
3242 scopes: vec![
3243 ScopeMapping {
3244 scope: "mcp:read".into(),
3245 role: "viewer".into(),
3246 },
3247 ScopeMapping {
3248 scope: "mcp:admin".into(),
3249 role: "ops".into(),
3250 },
3251 ],
3252 role_claim: None,
3253 role_mappings: vec![],
3254 jwks_cache_ttl: "10m".into(),
3255 proxy: None,
3256 token_exchange: None,
3257 ca_cert_path: None,
3258 allow_http_oauth_urls: false,
3259 max_jwks_keys: default_max_jwks_keys(),
3260 #[allow(
3261 deprecated,
3262 reason = "test fixture: explicit value for the deprecated field"
3263 )]
3264 strict_audience_validation: false,
3265 audience_validation_mode: None,
3266 jwks_max_response_bytes: default_jwks_max_bytes(),
3267 ssrf_allowlist: None,
3268 };
3269 let meta = protected_resource_metadata(
3270 "https://mcp.example.com/mcp",
3271 "https://mcp.example.com",
3272 &config,
3273 );
3274 assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
3275 assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
3276 assert_eq!(meta["scopes_supported"].as_array().unwrap().len(), 2);
3277 assert_eq!(meta["bearer_methods_supported"][0], "header");
3278 }
3279
3280 fn validation_https_config() -> OAuthConfig {
3285 OAuthConfig::builder(
3286 "https://auth.example.com",
3287 "mcp",
3288 "https://auth.example.com/.well-known/jwks.json",
3289 )
3290 .build()
3291 }
3292
3293 #[test]
3294 fn validate_accepts_all_https_urls() {
3295 let cfg = validation_https_config();
3296 cfg.validate().expect("all-HTTPS config must validate");
3297 }
3298
3299 #[test]
3300 fn validate_rejects_unparseable_jwks_cache_ttl() {
3301 let mut cfg = validation_https_config();
3302 cfg.jwks_cache_ttl = "not-a-duration".into();
3303 let err = cfg
3304 .validate()
3305 .expect_err("malformed jwks_cache_ttl must be rejected");
3306 let msg = err.to_string();
3307 assert!(
3308 msg.contains("jwks_cache_ttl"),
3309 "error must reference offending field; got {msg:?}"
3310 );
3311 }
3312
3313 #[test]
3314 fn validate_rejects_http_jwks_uri() {
3315 let mut cfg = validation_https_config();
3316 cfg.jwks_uri = "http://auth.example.com/.well-known/jwks.json".into();
3317 let err = cfg.validate().expect_err("http jwks_uri must be rejected");
3318 let msg = err.to_string();
3319 assert!(
3320 msg.contains("oauth.jwks_uri") && msg.contains("https"),
3321 "error must reference offending field + scheme requirement; got {msg:?}"
3322 );
3323 }
3324
3325 #[test]
3326 fn validate_rejects_http_proxy_authorize_url() {
3327 let mut cfg = validation_https_config();
3328 cfg.proxy = Some(
3329 OAuthProxyConfig::builder(
3330 "http://idp.example.com/authorize", "https://idp.example.com/token",
3332 "client",
3333 )
3334 .build(),
3335 );
3336 let err = cfg
3337 .validate()
3338 .expect_err("http authorize_url must be rejected");
3339 assert!(
3340 err.to_string().contains("oauth.proxy.authorize_url"),
3341 "error must reference proxy.authorize_url; got {err}"
3342 );
3343 }
3344
3345 #[test]
3346 fn validate_rejects_http_proxy_token_url() {
3347 let mut cfg = validation_https_config();
3348 cfg.proxy = Some(
3349 OAuthProxyConfig::builder(
3350 "https://idp.example.com/authorize",
3351 "http://idp.example.com/token", "client",
3353 )
3354 .build(),
3355 );
3356 let err = cfg.validate().expect_err("http token_url must be rejected");
3357 assert!(
3358 err.to_string().contains("oauth.proxy.token_url"),
3359 "error must reference proxy.token_url; got {err}"
3360 );
3361 }
3362
3363 #[test]
3364 fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
3365 let mut cfg = validation_https_config();
3366 cfg.proxy = Some(
3367 OAuthProxyConfig::builder(
3368 "https://idp.example.com/authorize",
3369 "https://idp.example.com/token",
3370 "client",
3371 )
3372 .introspection_url("http://idp.example.com/introspect")
3373 .build(),
3374 );
3375 let err = cfg
3376 .validate()
3377 .expect_err("http introspection_url must be rejected");
3378 assert!(err.to_string().contains("oauth.proxy.introspection_url"));
3379
3380 let mut cfg = validation_https_config();
3381 cfg.proxy = Some(
3382 OAuthProxyConfig::builder(
3383 "https://idp.example.com/authorize",
3384 "https://idp.example.com/token",
3385 "client",
3386 )
3387 .revocation_url("http://idp.example.com/revoke")
3388 .build(),
3389 );
3390 let err = cfg
3391 .validate()
3392 .expect_err("http revocation_url must be rejected");
3393 assert!(err.to_string().contains("oauth.proxy.revocation_url"));
3394 }
3395
3396 #[test]
3399 fn validate_rejects_exposed_admin_endpoints_without_auth() {
3400 let mut cfg = validation_https_config();
3401 cfg.proxy = Some(
3402 OAuthProxyConfig::builder(
3403 "https://idp.example.com/authorize",
3404 "https://idp.example.com/token",
3405 "client",
3406 )
3407 .introspection_url("https://idp.example.com/introspect")
3408 .expose_admin_endpoints(true)
3409 .build(),
3410 );
3411 let err = cfg
3412 .validate()
3413 .expect_err("expose_admin_endpoints without auth must fail");
3414 let msg = err.to_string();
3415 assert!(msg.contains("require_auth_on_admin_endpoints"), "{msg}");
3416 assert!(
3417 msg.contains("allow_unauthenticated_admin_endpoints"),
3418 "{msg}"
3419 );
3420 }
3421
3422 #[test]
3423 fn validate_accepts_exposed_admin_endpoints_with_auth() {
3424 let mut cfg = validation_https_config();
3425 cfg.proxy = Some(
3426 OAuthProxyConfig::builder(
3427 "https://idp.example.com/authorize",
3428 "https://idp.example.com/token",
3429 "client",
3430 )
3431 .introspection_url("https://idp.example.com/introspect")
3432 .expose_admin_endpoints(true)
3433 .require_auth_on_admin_endpoints(true)
3434 .build(),
3435 );
3436 cfg.validate()
3437 .expect("authed admin endpoints must validate");
3438 }
3439
3440 #[test]
3441 fn validate_accepts_exposed_admin_endpoints_with_explicit_unauth_optout() {
3442 let mut cfg = validation_https_config();
3443 cfg.proxy = Some(
3444 OAuthProxyConfig::builder(
3445 "https://idp.example.com/authorize",
3446 "https://idp.example.com/token",
3447 "client",
3448 )
3449 .introspection_url("https://idp.example.com/introspect")
3450 .expose_admin_endpoints(true)
3451 .allow_unauthenticated_admin_endpoints(true)
3452 .build(),
3453 );
3454 cfg.validate()
3455 .expect("explicit unauth opt-out must validate");
3456 }
3457
3458 #[test]
3459 fn validate_accepts_unexposed_admin_endpoints_without_auth() {
3460 let mut cfg = validation_https_config();
3463 cfg.proxy = Some(
3464 OAuthProxyConfig::builder(
3465 "https://idp.example.com/authorize",
3466 "https://idp.example.com/token",
3467 "client",
3468 )
3469 .introspection_url("https://idp.example.com/introspect")
3470 .build(),
3471 );
3472 cfg.validate()
3473 .expect("unexposed admin endpoints must validate");
3474 }
3475
3476 #[test]
3477 fn validate_rejects_http_token_exchange_url() {
3478 let mut cfg = validation_https_config();
3479 cfg.token_exchange = Some(TokenExchangeConfig::new(
3480 "http://idp.example.com/token".into(), "client".into(),
3482 None,
3483 None,
3484 "downstream".into(),
3485 ));
3486 let err = cfg
3487 .validate()
3488 .expect_err("http token_exchange.token_url must be rejected");
3489 assert!(
3490 err.to_string().contains("oauth.token_exchange.token_url"),
3491 "error must reference token_exchange.token_url; got {err}"
3492 );
3493 }
3494
3495 #[test]
3496 fn validate_rejects_unparseable_url() {
3497 let mut cfg = validation_https_config();
3498 cfg.jwks_uri = "not a url".into();
3499 let err = cfg
3500 .validate()
3501 .expect_err("unparseable URL must be rejected");
3502 assert!(err.to_string().contains("invalid URL"));
3503 }
3504
3505 #[test]
3506 fn validate_rejects_non_http_scheme() {
3507 let mut cfg = validation_https_config();
3508 cfg.jwks_uri = "file:///etc/passwd".into();
3509 let err = cfg.validate().expect_err("file:// scheme must be rejected");
3510 let msg = err.to_string();
3511 assert!(
3512 msg.contains("must use https scheme") && msg.contains("file"),
3513 "error must reject non-http(s) schemes; got {msg:?}"
3514 );
3515 }
3516
3517 #[test]
3518 fn validate_accepts_http_with_escape_hatch() {
3519 let mut cfg = OAuthConfig::builder(
3524 "http://auth.local",
3525 "mcp",
3526 "http://auth.local/.well-known/jwks.json",
3527 )
3528 .allow_http_oauth_urls(true)
3529 .build();
3530 cfg.proxy = Some(
3531 OAuthProxyConfig::builder(
3532 "http://idp.local/authorize",
3533 "http://idp.local/token",
3534 "client",
3535 )
3536 .introspection_url("http://idp.local/introspect")
3537 .revocation_url("http://idp.local/revoke")
3538 .build(),
3539 );
3540 cfg.token_exchange = Some(TokenExchangeConfig::new(
3541 "http://idp.local/token".into(),
3542 "client".into(),
3543 Some(secrecy::SecretString::new("dev-secret".into())),
3544 None,
3545 "downstream".into(),
3546 ));
3547 cfg.validate()
3548 .expect("escape hatch must permit http on all URL fields");
3549 }
3550
3551 #[test]
3552 fn validate_with_escape_hatch_still_rejects_unparseable() {
3553 let mut cfg = validation_https_config();
3556 cfg.allow_http_oauth_urls = true;
3557 cfg.jwks_uri = "::not-a-url::".into();
3558 cfg.validate()
3559 .expect_err("escape hatch must NOT bypass URL parsing");
3560 }
3561
3562 #[tokio::test]
3563 async fn jwks_cache_rejects_redirect_downgrade_to_http() {
3564 rustls::crypto::ring::default_provider()
3579 .install_default()
3580 .ok();
3581
3582 let policy = reqwest::redirect::Policy::custom(|attempt| {
3583 if attempt.url().scheme() != "https" {
3584 attempt.error("redirect to non-HTTPS URL refused")
3585 } else if attempt.previous().len() >= 2 {
3586 attempt.error("too many redirects (max 2)")
3587 } else {
3588 attempt.follow()
3589 }
3590 });
3591 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = Arc::new(AtomicBool::new(true));
3598 let allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
3599 let resolver: Arc<dyn reqwest::dns::Resolve> = Arc::new(
3600 crate::ssrf_resolver::SsrfScreeningResolver::new(Arc::clone(&allowlist), test_bypass),
3601 );
3602 let client = reqwest::Client::builder()
3603 .no_proxy()
3604 .dns_resolver(Arc::clone(&resolver))
3605 .timeout(Duration::from_secs(5))
3606 .connect_timeout(Duration::from_secs(3))
3607 .redirect(policy)
3608 .build()
3609 .expect("test client builds");
3610
3611 let mock = wiremock::MockServer::start().await;
3612 wiremock::Mock::given(wiremock::matchers::method("GET"))
3613 .and(wiremock::matchers::path("/jwks.json"))
3614 .respond_with(
3615 wiremock::ResponseTemplate::new(302)
3616 .insert_header("location", "http://example.invalid/jwks.json"),
3617 )
3618 .mount(&mock)
3619 .await;
3620
3621 let url = format!("{}/jwks.json", mock.uri());
3630 let err = client
3631 .get(&url)
3632 .send()
3633 .await
3634 .expect_err("redirect policy must reject scheme downgrade");
3635 let chain = format!("{err:#}");
3636 assert!(
3637 chain.contains("redirect to non-HTTPS URL refused")
3638 || chain.to_lowercase().contains("redirect"),
3639 "error must surface redirect-policy rejection; got {chain:?}"
3640 );
3641 }
3642
3643 use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
3648
3649 fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
3651 let mut rng = rsa::rand_core::OsRng;
3652 let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keypair generation");
3653 let private_pem = private_key
3654 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
3655 .expect("PKCS8 PEM export")
3656 .to_string();
3657
3658 let public_key = private_key.to_public_key();
3659 let n = URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be());
3660 let e = URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be());
3661
3662 let jwks = serde_json::json!({
3663 "keys": [{
3664 "kty": "RSA",
3665 "use": "sig",
3666 "alg": "RS256",
3667 "kid": kid,
3668 "n": n,
3669 "e": e
3670 }]
3671 });
3672
3673 (private_pem, jwks)
3674 }
3675
3676 fn mint_token(
3678 private_pem: &str,
3679 kid: &str,
3680 issuer: &str,
3681 audience: &str,
3682 subject: &str,
3683 scope: &str,
3684 ) -> String {
3685 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3686 .expect("encoding key from PEM");
3687 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3688 header.kid = Some(kid.into());
3689
3690 let now = jsonwebtoken::get_current_timestamp();
3691 let claims = serde_json::json!({
3692 "iss": issuer,
3693 "aud": audience,
3694 "sub": subject,
3695 "scope": scope,
3696 "exp": now + 3600,
3697 "iat": now,
3698 });
3699
3700 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3701 }
3702
3703 fn test_config(jwks_uri: &str) -> OAuthConfig {
3704 OAuthConfig {
3705 issuer: "https://auth.test.local".into(),
3706 audience: "https://mcp.test.local/mcp".into(),
3707 jwks_uri: jwks_uri.into(),
3708 scopes: vec![
3709 ScopeMapping {
3710 scope: "mcp:read".into(),
3711 role: "viewer".into(),
3712 },
3713 ScopeMapping {
3714 scope: "mcp:admin".into(),
3715 role: "ops".into(),
3716 },
3717 ],
3718 role_claim: None,
3719 role_mappings: vec![],
3720 jwks_cache_ttl: "5m".into(),
3721 proxy: None,
3722 token_exchange: None,
3723 ca_cert_path: None,
3724 allow_http_oauth_urls: true,
3725 max_jwks_keys: default_max_jwks_keys(),
3726 #[allow(
3727 deprecated,
3728 reason = "test fixture: explicit value for the deprecated field"
3729 )]
3730 strict_audience_validation: false,
3731 audience_validation_mode: None,
3732 jwks_max_response_bytes: default_jwks_max_bytes(),
3733 ssrf_allowlist: None,
3734 }
3735 }
3736
3737 fn test_cache(config: &OAuthConfig) -> JwksCache {
3738 JwksCache::new(config).unwrap().__test_allow_loopback_ssrf()
3739 }
3740
3741 #[tokio::test]
3742 async fn valid_jwt_returns_identity() {
3743 let kid = "test-key-1";
3744 let (pem, jwks) = generate_test_keypair(kid);
3745
3746 let mock_server = wiremock::MockServer::start().await;
3747 wiremock::Mock::given(wiremock::matchers::method("GET"))
3748 .and(wiremock::matchers::path("/jwks.json"))
3749 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3750 .mount(&mock_server)
3751 .await;
3752
3753 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3754 let config = test_config(&jwks_uri);
3755 let cache = test_cache(&config);
3756
3757 let token = mint_token(
3758 &pem,
3759 kid,
3760 "https://auth.test.local",
3761 "https://mcp.test.local/mcp",
3762 "ci-bot",
3763 "mcp:read mcp:other",
3764 );
3765
3766 let identity = cache.validate_token(&token).await;
3767 assert!(identity.is_some(), "valid JWT should authenticate");
3768 let id = identity.unwrap();
3769 assert_eq!(id.name, "ci-bot");
3770 assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
3772 }
3773
3774 #[tokio::test]
3775 async fn wrong_issuer_rejected() {
3776 let kid = "test-key-2";
3777 let (pem, jwks) = generate_test_keypair(kid);
3778
3779 let mock_server = wiremock::MockServer::start().await;
3780 wiremock::Mock::given(wiremock::matchers::method("GET"))
3781 .and(wiremock::matchers::path("/jwks.json"))
3782 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3783 .mount(&mock_server)
3784 .await;
3785
3786 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3787 let config = test_config(&jwks_uri);
3788 let cache = test_cache(&config);
3789
3790 let token = mint_token(
3791 &pem,
3792 kid,
3793 "https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
3795 "attacker",
3796 "mcp:admin",
3797 );
3798
3799 assert!(cache.validate_token(&token).await.is_none());
3800 }
3801
3802 #[tokio::test]
3803 async fn wrong_audience_rejected() {
3804 let kid = "test-key-3";
3805 let (pem, jwks) = generate_test_keypair(kid);
3806
3807 let mock_server = wiremock::MockServer::start().await;
3808 wiremock::Mock::given(wiremock::matchers::method("GET"))
3809 .and(wiremock::matchers::path("/jwks.json"))
3810 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3811 .mount(&mock_server)
3812 .await;
3813
3814 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3815 let config = test_config(&jwks_uri);
3816 let cache = test_cache(&config);
3817
3818 let token = mint_token(
3819 &pem,
3820 kid,
3821 "https://auth.test.local",
3822 "https://wrong-audience.example.com", "attacker",
3824 "mcp:admin",
3825 );
3826
3827 assert!(cache.validate_token(&token).await.is_none());
3828 }
3829
3830 #[tokio::test]
3831 async fn expired_jwt_rejected() {
3832 let kid = "test-key-4";
3833 let (pem, jwks) = generate_test_keypair(kid);
3834
3835 let mock_server = wiremock::MockServer::start().await;
3836 wiremock::Mock::given(wiremock::matchers::method("GET"))
3837 .and(wiremock::matchers::path("/jwks.json"))
3838 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3839 .mount(&mock_server)
3840 .await;
3841
3842 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3843 let config = test_config(&jwks_uri);
3844 let cache = test_cache(&config);
3845
3846 let encoding_key =
3848 jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
3849 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3850 header.kid = Some(kid.into());
3851 let now = jsonwebtoken::get_current_timestamp();
3852 let claims = serde_json::json!({
3853 "iss": "https://auth.test.local",
3854 "aud": "https://mcp.test.local/mcp",
3855 "sub": "expired-bot",
3856 "scope": "mcp:read",
3857 "exp": now - 120,
3858 "iat": now - 3720,
3859 });
3860 let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
3861
3862 assert!(cache.validate_token(&token).await.is_none());
3863 }
3864
3865 #[tokio::test]
3866 async fn no_matching_scope_rejected() {
3867 let kid = "test-key-5";
3868 let (pem, jwks) = generate_test_keypair(kid);
3869
3870 let mock_server = wiremock::MockServer::start().await;
3871 wiremock::Mock::given(wiremock::matchers::method("GET"))
3872 .and(wiremock::matchers::path("/jwks.json"))
3873 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3874 .mount(&mock_server)
3875 .await;
3876
3877 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3878 let config = test_config(&jwks_uri);
3879 let cache = test_cache(&config);
3880
3881 let token = mint_token(
3882 &pem,
3883 kid,
3884 "https://auth.test.local",
3885 "https://mcp.test.local/mcp",
3886 "limited-bot",
3887 "some:other:scope", );
3889
3890 assert!(cache.validate_token(&token).await.is_none());
3891 }
3892
3893 #[tokio::test]
3894 async fn wrong_signing_key_rejected() {
3895 let kid = "test-key-6";
3896 let (_pem, jwks) = generate_test_keypair(kid);
3897
3898 let (attacker_pem, _) = generate_test_keypair(kid);
3900
3901 let mock_server = wiremock::MockServer::start().await;
3902 wiremock::Mock::given(wiremock::matchers::method("GET"))
3903 .and(wiremock::matchers::path("/jwks.json"))
3904 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3905 .mount(&mock_server)
3906 .await;
3907
3908 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3909 let config = test_config(&jwks_uri);
3910 let cache = test_cache(&config);
3911
3912 let token = mint_token(
3914 &attacker_pem,
3915 kid,
3916 "https://auth.test.local",
3917 "https://mcp.test.local/mcp",
3918 "attacker",
3919 "mcp:admin",
3920 );
3921
3922 assert!(cache.validate_token(&token).await.is_none());
3923 }
3924
3925 #[tokio::test]
3926 async fn admin_scope_maps_to_ops_role() {
3927 let kid = "test-key-7";
3928 let (pem, jwks) = generate_test_keypair(kid);
3929
3930 let mock_server = wiremock::MockServer::start().await;
3931 wiremock::Mock::given(wiremock::matchers::method("GET"))
3932 .and(wiremock::matchers::path("/jwks.json"))
3933 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3934 .mount(&mock_server)
3935 .await;
3936
3937 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3938 let config = test_config(&jwks_uri);
3939 let cache = test_cache(&config);
3940
3941 let token = mint_token(
3942 &pem,
3943 kid,
3944 "https://auth.test.local",
3945 "https://mcp.test.local/mcp",
3946 "admin-bot",
3947 "mcp:admin",
3948 );
3949
3950 let id = cache
3951 .validate_token(&token)
3952 .await
3953 .expect("should authenticate");
3954 assert_eq!(id.role, "ops");
3955 assert_eq!(id.name, "admin-bot");
3956 }
3957
3958 #[tokio::test]
3959 async fn jwks_server_down_returns_none() {
3960 let config = test_config("http://127.0.0.1:1/jwks.json");
3962 let cache = test_cache(&config);
3963
3964 let kid = "orphan-key";
3965 let (pem, _) = generate_test_keypair(kid);
3966 let token = mint_token(
3967 &pem,
3968 kid,
3969 "https://auth.test.local",
3970 "https://mcp.test.local/mcp",
3971 "bot",
3972 "mcp:read",
3973 );
3974
3975 assert!(cache.validate_token(&token).await.is_none());
3976 }
3977
3978 #[test]
3983 fn resolve_claim_path_flat_string() {
3984 let mut extra = HashMap::new();
3985 extra.insert(
3986 "scope".into(),
3987 serde_json::Value::String("mcp:read mcp:admin".into()),
3988 );
3989 let values = resolve_claim_path(&extra, "scope");
3990 assert_eq!(values, vec!["mcp:read", "mcp:admin"]);
3991 }
3992
3993 #[test]
3994 fn resolve_claim_path_flat_array() {
3995 let mut extra = HashMap::new();
3996 extra.insert(
3997 "roles".into(),
3998 serde_json::json!(["mcp-admin", "mcp-viewer"]),
3999 );
4000 let values = resolve_claim_path(&extra, "roles");
4001 assert_eq!(values, vec!["mcp-admin", "mcp-viewer"]);
4002 }
4003
4004 #[test]
4005 fn resolve_claim_path_nested_keycloak() {
4006 let mut extra = HashMap::new();
4007 extra.insert(
4008 "realm_access".into(),
4009 serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]}),
4010 );
4011 let values = resolve_claim_path(&extra, "realm_access.roles");
4012 assert_eq!(values, vec!["uma_authorization", "mcp-admin"]);
4013 }
4014
4015 #[test]
4016 fn resolve_claim_path_missing_returns_empty() {
4017 let extra = HashMap::new();
4018 assert!(resolve_claim_path(&extra, "nonexistent.path").is_empty());
4019 }
4020
4021 #[test]
4022 fn resolve_claim_path_numeric_leaf_returns_empty() {
4023 let mut extra = HashMap::new();
4024 extra.insert("count".into(), serde_json::json!(42));
4025 assert!(resolve_claim_path(&extra, "count").is_empty());
4026 }
4027
4028 fn make_claims(json: serde_json::Value) -> Claims {
4029 serde_json::from_value(json).expect("test claims must deserialize")
4030 }
4031
4032 #[test]
4033 fn first_class_scope_claim_splits_on_whitespace() {
4034 let claims = make_claims(serde_json::json!({
4035 "iss": "https://issuer.example.com",
4036 "exp": 9_999_999_999_u64,
4037 "scope": "read write admin",
4038 }));
4039 let values = first_class_claim_values(&claims, "scope");
4040 assert_eq!(values, vec!["read", "write", "admin"]);
4041 }
4042
4043 #[test]
4044 fn first_class_sub_claim_returns_single_value() {
4045 let claims = make_claims(serde_json::json!({
4046 "iss": "https://issuer.example.com",
4047 "exp": 9_999_999_999_u64,
4048 "sub": "service-account-orders",
4049 }));
4050 let values = first_class_claim_values(&claims, "sub");
4051 assert_eq!(values, vec!["service-account-orders"]);
4052 }
4053
4054 #[test]
4055 fn first_class_aud_claim_returns_every_audience() {
4056 let claims = make_claims(serde_json::json!({
4057 "iss": "https://issuer.example.com",
4058 "exp": 9_999_999_999_u64,
4059 "aud": ["api-a", "api-b"],
4060 }));
4061 let values = first_class_claim_values(&claims, "aud");
4062 assert_eq!(values, vec!["api-a", "api-b"]);
4063 }
4064
4065 #[test]
4066 fn first_class_unknown_path_returns_empty() {
4067 let claims = make_claims(serde_json::json!({
4068 "iss": "https://issuer.example.com",
4069 "exp": 9_999_999_999_u64,
4070 }));
4071 assert!(first_class_claim_values(&claims, "realm_access.roles").is_empty());
4072 }
4073
4074 fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
4080 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
4081 .expect("encoding key from PEM");
4082 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
4083 header.kid = Some(kid.into());
4084 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
4085 }
4086
4087 fn test_config_with_role_claim(
4088 jwks_uri: &str,
4089 role_claim: &str,
4090 role_mappings: Vec<RoleMapping>,
4091 ) -> OAuthConfig {
4092 OAuthConfig {
4093 issuer: "https://auth.test.local".into(),
4094 audience: "https://mcp.test.local/mcp".into(),
4095 jwks_uri: jwks_uri.into(),
4096 scopes: vec![],
4097 role_claim: Some(role_claim.into()),
4098 role_mappings,
4099 jwks_cache_ttl: "5m".into(),
4100 proxy: None,
4101 token_exchange: None,
4102 ca_cert_path: None,
4103 allow_http_oauth_urls: true,
4104 max_jwks_keys: default_max_jwks_keys(),
4105 #[allow(
4106 deprecated,
4107 reason = "test fixture: explicit value for the deprecated field"
4108 )]
4109 strict_audience_validation: false,
4110 audience_validation_mode: None,
4111 jwks_max_response_bytes: default_jwks_max_bytes(),
4112 ssrf_allowlist: None,
4113 }
4114 }
4115
4116 #[tokio::test]
4117 async fn screen_oauth_target_rejects_literal_ip() {
4118 let err = screen_oauth_target(
4119 "https://127.0.0.1/jwks.json",
4120 false,
4121 &crate::ssrf::CompiledSsrfAllowlist::default(),
4122 )
4123 .await
4124 .expect_err("literal IPs must be rejected");
4125 let msg = err.to_string();
4126 assert!(msg.contains("literal IPv4 addresses are forbidden"));
4127 }
4128
4129 #[tokio::test]
4130 async fn screen_oauth_target_rejects_private_dns_resolution() {
4131 let err = screen_oauth_target(
4132 "https://localhost/jwks.json",
4133 false,
4134 &crate::ssrf::CompiledSsrfAllowlist::default(),
4135 )
4136 .await
4137 .expect_err("localhost resolution must be rejected");
4138 let msg = err.to_string();
4139 assert!(
4140 msg.contains("blocked IP") && msg.contains("loopback"),
4141 "got {msg:?}"
4142 );
4143 }
4144
4145 #[tokio::test]
4146 async fn screen_oauth_target_rejects_literal_ip_even_with_allow_http() {
4147 let err = screen_oauth_target(
4148 "http://127.0.0.1/jwks.json",
4149 true,
4150 &crate::ssrf::CompiledSsrfAllowlist::default(),
4151 )
4152 .await
4153 .expect_err("literal IPs must still be rejected when http is allowed");
4154 let msg = err.to_string();
4155 assert!(msg.contains("literal IPv4 addresses are forbidden"));
4156 }
4157
4158 #[tokio::test]
4159 async fn screen_oauth_target_rejects_private_dns_even_with_allow_http() {
4160 let err = screen_oauth_target(
4161 "http://localhost/jwks.json",
4162 true,
4163 &crate::ssrf::CompiledSsrfAllowlist::default(),
4164 )
4165 .await
4166 .expect_err("private DNS resolution must still be rejected when http is allowed");
4167 let msg = err.to_string();
4168 assert!(
4169 msg.contains("blocked IP") && msg.contains("loopback"),
4170 "got {msg:?}"
4171 );
4172 }
4173
4174 #[tokio::test]
4175 async fn screen_oauth_target_allows_public_hostname() {
4176 screen_oauth_target(
4177 "https://example.com/.well-known/jwks.json",
4178 false,
4179 &crate::ssrf::CompiledSsrfAllowlist::default(),
4180 )
4181 .await
4182 .expect("public hostname should pass screening");
4183 }
4184
4185 fn make_allowlist(hosts: &[&str], cidrs: &[&str]) -> crate::ssrf::CompiledSsrfAllowlist {
4191 let raw = OAuthSsrfAllowlist {
4192 hosts: hosts.iter().map(|s| (*s).to_owned()).collect(),
4193 cidrs: cidrs.iter().map(|s| (*s).to_owned()).collect(),
4194 };
4195 compile_oauth_ssrf_allowlist(&raw).expect("test allowlist compiles")
4196 }
4197
/// Two hosts differing only by case compile to one entry that matches either casing.
#[test]
fn compile_oauth_ssrf_allowlist_lowercases_and_dedupes_hosts() {
    let hosts = vec!["RHBK.ops.example.com".into(), "rhbk.ops.example.com".into()];
    let raw = OAuthSsrfAllowlist {
        hosts,
        cidrs: vec![],
    };

    let compiled = compile_oauth_ssrf_allowlist(&raw).expect("compiles");

    assert_eq!(compiled.host_count(), 1);
    for candidate in ["rhbk.ops.example.com", "RHBK.OPS.EXAMPLE.COM"] {
        assert!(compiled.host_allowed(candidate));
    }
}
4209
/// A literal IP in `hosts` must make compilation fail with a "literal IPs" error.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_literal_ip_in_hosts() {
    let hosts = vec!["10.0.0.1".into()];
    let raw = OAuthSsrfAllowlist {
        hosts,
        cidrs: vec![],
    };
    let outcome = compile_oauth_ssrf_allowlist(&raw);
    let err = outcome.expect_err("literal IP in hosts");
    assert!(err.contains("literal IPs are forbidden"), "got {err:?}");
}
4219
/// A `host:port` entry in `hosts` must be rejected as not being a bare DNS hostname.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_host_with_port() {
    let hosts = vec!["rhbk.ops.example.com:8443".into()];
    let raw = OAuthSsrfAllowlist {
        hosts,
        cidrs: vec![],
    };
    let outcome = compile_oauth_ssrf_allowlist(&raw);
    let err = outcome.expect_err("host:port");
    assert!(err.contains("must be a bare DNS hostname"), "got {err:?}");
}
4229
/// An unparsable CIDR must fail compilation with an error naming the offending index.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_invalid_cidr() {
    let cidrs = vec!["not-a-cidr".into()];
    let raw = OAuthSsrfAllowlist {
        hosts: vec![],
        cidrs,
    };
    let outcome = compile_oauth_ssrf_allowlist(&raw);
    let err = outcome.expect_err("invalid CIDR");
    assert!(err.contains("oauth.ssrf_allowlist.cidrs[0]"), "got {err:?}");
}
4239
/// `OAuthConfig::validate` must surface an `oauth.ssrf_allowlist` error when the
/// allowlist contains a literal IP host.
#[test]
fn validate_rejects_misconfigured_allowlist() {
    let bad_allowlist = OAuthSsrfAllowlist {
        hosts: vec!["10.0.0.1".into()],
        cidrs: vec![],
    };

    let builder = OAuthConfig::builder(
        "https://auth.example.com/",
        "mcp",
        "https://auth.example.com/jwks.json",
    );
    let mut cfg = builder.build();
    cfg.ssrf_allowlist = Some(bad_allowlist);

    let err = cfg
        .validate()
        .expect_err("literal IP host must be rejected");
    let rendered = err.to_string();
    assert!(rendered.contains("oauth.ssrf_allowlist"), "got {err}");
}
4260
4261 #[tokio::test]
4262 async fn screen_oauth_target_with_allowlist_emits_helpful_error() {
4263 let allow = make_allowlist(&["other.example.com"], &["10.0.0.0/8"]);
4267 let err = screen_oauth_target("https://localhost/jwks.json", false, &allow)
4268 .await
4269 .expect_err("loopback must still be blocked when not in allowlist");
4270 let msg = err.to_string();
4271 assert!(msg.contains("OAuth target blocked"), "got {msg:?}");
4272 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4273 assert!(msg.contains("SECURITY.md"), "got {msg:?}");
4274 }
4275
4276 #[tokio::test]
4277 async fn screen_oauth_target_empty_allowlist_uses_legacy_message() {
4278 let err = screen_oauth_target(
4281 "https://localhost/jwks.json",
4282 false,
4283 &crate::ssrf::CompiledSsrfAllowlist::default(),
4284 )
4285 .await
4286 .expect_err("loopback rejection");
4287 let msg = err.to_string();
4288 assert!(msg.contains("blocked IP"), "got {msg:?}");
4289 assert!(msg.contains("loopback"), "got {msg:?}");
4290 assert!(!msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4292 }
4293
4294 #[tokio::test]
4295 async fn screen_oauth_target_allows_loopback_when_host_allowlisted() {
4296 let allow = make_allowlist(&["localhost"], &[]);
4298 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4299 .await
4300 .expect("allowlisted host must pass");
4301 }
4302
4303 #[tokio::test]
4304 async fn screen_oauth_target_allows_loopback_when_cidr_allowlisted() {
4305 let allow = make_allowlist(&[], &["127.0.0.0/8", "::1/128"]);
4308 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4309 .await
4310 .expect("allowlisted CIDR must pass");
4311 }
4312
4313 #[tokio::test]
4314 async fn jwks_cache_rejects_misconfigured_allowlist_at_startup() {
4315 let mut cfg = OAuthConfig::builder(
4316 "https://auth.example.com/",
4317 "mcp",
4318 "https://auth.example.com/jwks.json",
4319 )
4320 .build();
4321 cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
4322 hosts: vec![],
4323 cidrs: vec!["bad-cidr".into()],
4324 });
4325 let Err(err) = JwksCache::new(&cfg) else {
4326 panic!("invalid CIDR must fail JwksCache::new")
4327 };
4328 let msg = err.to_string();
4329 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4330 }
4331
4332 #[tokio::test]
4333 async fn jwks_cache_new_invalid_ttl_is_err() {
4334 let cfg = OAuthConfig::builder(
4337 "https://auth.example.com/",
4338 "mcp",
4339 "https://auth.example.com/jwks.json",
4340 )
4341 .jwks_cache_ttl("not-a-duration")
4342 .build();
4343 let Err(err) = JwksCache::new(&cfg) else {
4344 panic!("invalid jwks_cache_ttl must fail JwksCache::new")
4345 };
4346 let msg = err.to_string();
4347 assert!(msg.contains("jwks_cache_ttl"), "got {msg:?}");
4348 }
4349
4350 #[tokio::test]
4351 async fn audience_falls_back_to_azp_by_default() {
4352 let kid = "test-audience-azp-default";
4353 let (pem, jwks) = generate_test_keypair(kid);
4354
4355 let mock_server = wiremock::MockServer::start().await;
4356 wiremock::Mock::given(wiremock::matchers::method("GET"))
4357 .and(wiremock::matchers::path("/jwks.json"))
4358 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4359 .mount(&mock_server)
4360 .await;
4361
4362 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4363 let config = test_config(&jwks_uri);
4364 let cache = test_cache(&config);
4365
4366 let now = jsonwebtoken::get_current_timestamp();
4367 let token = mint_token_with_claims(
4368 &pem,
4369 kid,
4370 &serde_json::json!({
4371 "iss": "https://auth.test.local",
4372 "aud": "https://some-other-resource.example.com",
4373 "azp": "https://mcp.test.local/mcp",
4374 "sub": "compat-client",
4375 "scope": "mcp:read",
4376 "exp": now + 3600,
4377 "iat": now,
4378 }),
4379 );
4380
4381 let identity = cache
4382 .validate_token_with_reason(&token)
4383 .await
4384 .expect("azp fallback should remain enabled by default");
4385 assert_eq!(identity.role, "viewer");
4386 }
4387
4388 #[tokio::test]
4389 async fn strict_audience_validation_rejects_azp_only_match() {
4390 let kid = "test-audience-azp-strict";
4391 let (pem, jwks) = generate_test_keypair(kid);
4392
4393 let mock_server = wiremock::MockServer::start().await;
4394 wiremock::Mock::given(wiremock::matchers::method("GET"))
4395 .and(wiremock::matchers::path("/jwks.json"))
4396 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4397 .mount(&mock_server)
4398 .await;
4399
4400 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4401 let mut config = test_config(&jwks_uri);
4402 #[allow(deprecated, reason = "covers the legacy bool resolution path")]
4403 {
4404 config.strict_audience_validation = true;
4405 }
4406 let cache = test_cache(&config);
4407
4408 let now = jsonwebtoken::get_current_timestamp();
4409 let token = mint_token_with_claims(
4410 &pem,
4411 kid,
4412 &serde_json::json!({
4413 "iss": "https://auth.test.local",
4414 "aud": "https://some-other-resource.example.com",
4415 "azp": "https://mcp.test.local/mcp",
4416 "sub": "strict-client",
4417 "scope": "mcp:read",
4418 "exp": now + 3600,
4419 "iat": now,
4420 }),
4421 );
4422
4423 let failure = cache
4424 .validate_token_with_reason(&token)
4425 .await
4426 .expect_err("strict audience validation must ignore azp fallback");
4427 assert_eq!(failure, JwtValidationFailure::Invalid);
4428 }
4429
4430 #[tokio::test]
4431 async fn warn_mode_accepts_azp_only_match_and_warns_once() {
4432 let kid = "test-audience-warn-mode";
4433 let (pem, jwks) = generate_test_keypair(kid);
4434
4435 let mock_server = wiremock::MockServer::start().await;
4436 wiremock::Mock::given(wiremock::matchers::method("GET"))
4437 .and(wiremock::matchers::path("/jwks.json"))
4438 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4439 .mount(&mock_server)
4440 .await;
4441
4442 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4443 let mut config = test_config(&jwks_uri);
4444 config.audience_validation_mode = Some(AudienceValidationMode::Warn);
4445 let cache = test_cache(&config);
4446
4447 let now = jsonwebtoken::get_current_timestamp();
4448 let claims = serde_json::json!({
4449 "iss": "https://auth.test.local",
4450 "aud": "https://some-other-resource.example.com",
4451 "azp": "https://mcp.test.local/mcp",
4452 "sub": "warn-client",
4453 "scope": "mcp:read",
4454 "exp": now + 3600,
4455 "iat": now,
4456 });
4457 let token = mint_token_with_claims(&pem, kid, &claims);
4458
4459 let identity = cache
4460 .validate_token_with_reason(&token)
4461 .await
4462 .expect("warn mode must accept azp-only match");
4463 assert_eq!(identity.role, "viewer");
4464 assert!(
4465 cache.azp_fallback_warned.load(Ordering::Relaxed),
4466 "warn-once flag should be set after first azp-only match"
4467 );
4468
4469 let token2 = mint_token_with_claims(&pem, kid, &claims);
4470 cache
4471 .validate_token_with_reason(&token2)
4472 .await
4473 .expect("warn mode must continue accepting subsequent matches");
4474 assert!(
4475 cache.azp_fallback_warned.load(Ordering::Relaxed),
4476 "warn-once flag must remain set; the assertion guards against accidental clearing"
4477 );
4478 }
4479
4480 #[tokio::test]
4481 async fn permissive_mode_accepts_azp_only_match_silently() {
4482 let kid = "test-audience-permissive-mode";
4483 let (pem, jwks) = generate_test_keypair(kid);
4484
4485 let mock_server = wiremock::MockServer::start().await;
4486 wiremock::Mock::given(wiremock::matchers::method("GET"))
4487 .and(wiremock::matchers::path("/jwks.json"))
4488 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4489 .mount(&mock_server)
4490 .await;
4491
4492 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4493 let mut config = test_config(&jwks_uri);
4494 config.audience_validation_mode = Some(AudienceValidationMode::Permissive);
4495 let cache = test_cache(&config);
4496
4497 let now = jsonwebtoken::get_current_timestamp();
4498 let token = mint_token_with_claims(
4499 &pem,
4500 kid,
4501 &serde_json::json!({
4502 "iss": "https://auth.test.local",
4503 "aud": "https://some-other-resource.example.com",
4504 "azp": "https://mcp.test.local/mcp",
4505 "sub": "permissive-client",
4506 "scope": "mcp:read",
4507 "exp": now + 3600,
4508 "iat": now,
4509 }),
4510 );
4511
4512 cache
4513 .validate_token_with_reason(&token)
4514 .await
4515 .expect("permissive mode must accept azp-only match");
4516 assert!(
4517 !cache.azp_fallback_warned.load(Ordering::Relaxed),
4518 "permissive mode must not flip the warn-once flag"
4519 );
4520 }
4521
/// An explicit `audience_validation_mode` wins over the deprecated
/// `strict_audience_validation` bool, whichever value the bool holds.
#[test]
fn audience_validation_mode_overrides_legacy_bool() {
    // Legacy bool says "not strict", explicit mode says Strict.
    let mut strict_over_false = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
    {
        strict_over_false.strict_audience_validation = false;
    }
    strict_over_false.audience_validation_mode = Some(AudienceValidationMode::Strict);
    assert_eq!(
        strict_over_false.effective_audience_validation_mode(),
        AudienceValidationMode::Strict,
        "explicit mode must override legacy false"
    );

    // Legacy bool says "strict", explicit mode says Permissive.
    let mut permissive_over_true = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
    {
        permissive_over_true.strict_audience_validation = true;
    }
    permissive_over_true.audience_validation_mode = Some(AudienceValidationMode::Permissive);
    assert_eq!(
        permissive_over_true.effective_audience_validation_mode(),
        AudienceValidationMode::Permissive,
        "explicit mode must override legacy true"
    );
}
4548
/// A default `OAuthConfig` resolves its effective audience mode to `Warn`.
#[test]
fn audience_validation_mode_default_is_warn_when_unset() {
    let resolved = OAuthConfig::default().effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Warn,
        "unset mode + unset bool must resolve to Warn (the new default)"
    );
}
4558
/// With no explicit mode, the deprecated bool set to `true` resolves to `Strict`.
#[test]
fn audience_validation_legacy_bool_true_resolves_to_strict() {
    let mut legacy = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the legacy bool resolution path")]
    {
        legacy.strict_audience_validation = true;
    }

    let resolved = legacy.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Strict,
        "legacy bool=true must resolve to Strict for backward compat"
    );
}
4572
/// Cloneable handle to a shared in-memory byte buffer used to capture log output in
/// tests. Clones share the same underlying buffer.
#[derive(Clone, Default)]
struct CapturedLogs(Arc<std::sync::Mutex<Vec<u8>>>);

impl CapturedLogs {
    /// Returns everything captured so far as text.
    ///
    /// A poisoned lock is recovered instead of being reported as "no logs", and invalid
    /// UTF-8 is replaced lossily instead of discarding the whole buffer. Either failure
    /// would otherwise turn into a misleading "expected warning not found" assertion.
    fn contents(&self) -> String {
        let guard = self
            .0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        String::from_utf8_lossy(&guard).into_owned()
    }
}
4582
/// `io::Write` adapter that appends every write to a shared in-memory buffer.
struct CapturedLogsWriter(Arc<std::sync::Mutex<Vec<u8>>>);

impl std::io::Write for CapturedLogsWriter {
    /// Appends `buf` to the shared buffer and reports the full length as written.
    ///
    /// A poisoned lock is recovered so bytes are never dropped while the call still
    /// claims success.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .extend_from_slice(buf);
        Ok(buf.len())
    }

    /// Nothing to flush: writes land in memory immediately.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
4597
4598 impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CapturedLogs {
4599 type Writer = CapturedLogsWriter;
4600
4601 fn make_writer(&'a self) -> Self::Writer {
4602 CapturedLogsWriter(Arc::clone(&self.0))
4603 }
4604 }
4605
4606 #[tokio::test]
4607 async fn jwks_response_size_cap_returns_none_and_logs_warning() {
4608 let kid = "oversized-jwks";
4609 let (_pem, jwks) = generate_test_keypair(kid);
4610 let mut oversized_body = serde_json::to_string(&jwks).expect("jwks json");
4611 oversized_body.push_str(&" ".repeat(4096));
4612
4613 let mock_server = wiremock::MockServer::start().await;
4614 wiremock::Mock::given(wiremock::matchers::method("GET"))
4615 .and(wiremock::matchers::path("/jwks.json"))
4616 .respond_with(
4617 wiremock::ResponseTemplate::new(200)
4618 .insert_header("content-type", "application/json")
4619 .set_body_string(oversized_body),
4620 )
4621 .mount(&mock_server)
4622 .await;
4623
4624 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4625 let mut config = test_config(&jwks_uri);
4626 config.jwks_max_response_bytes = 256;
4627 let cache = test_cache(&config);
4628
4629 let logs = CapturedLogs::default();
4630 let subscriber = tracing_subscriber::fmt()
4631 .with_writer(logs.clone())
4632 .with_ansi(false)
4633 .without_time()
4634 .finish();
4635 let _guard = tracing::subscriber::set_default(subscriber);
4636
4637 let result = cache.fetch_jwks().await;
4638 assert!(result.is_none(), "oversized JWKS must be dropped");
4639 assert!(
4640 logs.contents()
4641 .contains("JWKS response exceeded configured size cap"),
4642 "expected cap-exceeded warning in logs"
4643 );
4644 }
4645
4646 #[tokio::test]
4650 async fn redirect_rejection_log_does_not_echo_credentials() {
4651 let mock_server = wiremock::MockServer::start().await;
4652 wiremock::Mock::given(wiremock::matchers::method("GET"))
4653 .and(wiremock::matchers::path("/jwks.json"))
4654 .respond_with(
4655 wiremock::ResponseTemplate::new(302)
4656 .insert_header("location", "https://u:p@redirect-target.example/next"),
4657 )
4658 .mount(&mock_server)
4659 .await;
4660
4661 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4662 let config = test_config(&jwks_uri);
4663 let cache = test_cache(&config);
4664
4665 let logs = CapturedLogs::default();
4666 let subscriber = tracing_subscriber::fmt()
4667 .with_writer(logs.clone())
4668 .with_ansi(false)
4669 .without_time()
4670 .finish();
4671 let _guard = tracing::subscriber::set_default(subscriber);
4672
4673 let result = cache.fetch_jwks().await;
4674 assert!(result.is_none(), "rejected redirect must fail the fetch");
4675 let contents = logs.contents();
4676 assert!(
4677 contents.contains("oauth redirect rejected"),
4678 "expected redirect-rejection warning in logs: {contents}"
4679 );
4680 assert!(
4681 !contents.contains("u:p"),
4682 "rejection log must not echo userinfo credentials: {contents}"
4683 );
4684 }
4685
4686 #[tokio::test]
4687 async fn role_claim_keycloak_nested_array() {
4688 let kid = "test-role-1";
4689 let (pem, jwks) = generate_test_keypair(kid);
4690
4691 let mock_server = wiremock::MockServer::start().await;
4692 wiremock::Mock::given(wiremock::matchers::method("GET"))
4693 .and(wiremock::matchers::path("/jwks.json"))
4694 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4695 .mount(&mock_server)
4696 .await;
4697
4698 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4699 let config = test_config_with_role_claim(
4700 &jwks_uri,
4701 "realm_access.roles",
4702 vec![
4703 RoleMapping {
4704 claim_value: "mcp-admin".into(),
4705 role: "ops".into(),
4706 },
4707 RoleMapping {
4708 claim_value: "mcp-viewer".into(),
4709 role: "viewer".into(),
4710 },
4711 ],
4712 );
4713 let cache = test_cache(&config);
4714
4715 let now = jsonwebtoken::get_current_timestamp();
4716 let token = mint_token_with_claims(
4717 &pem,
4718 kid,
4719 &serde_json::json!({
4720 "iss": "https://auth.test.local",
4721 "aud": "https://mcp.test.local/mcp",
4722 "sub": "keycloak-user",
4723 "exp": now + 3600,
4724 "iat": now,
4725 "realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
4726 }),
4727 );
4728
4729 let id = cache
4730 .validate_token(&token)
4731 .await
4732 .expect("should authenticate");
4733 assert_eq!(id.name, "keycloak-user");
4734 assert_eq!(id.role, "ops");
4735 }
4736
4737 #[tokio::test]
4738 async fn role_claim_flat_roles_array() {
4739 let kid = "test-role-2";
4740 let (pem, jwks) = generate_test_keypair(kid);
4741
4742 let mock_server = wiremock::MockServer::start().await;
4743 wiremock::Mock::given(wiremock::matchers::method("GET"))
4744 .and(wiremock::matchers::path("/jwks.json"))
4745 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4746 .mount(&mock_server)
4747 .await;
4748
4749 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4750 let config = test_config_with_role_claim(
4751 &jwks_uri,
4752 "roles",
4753 vec![
4754 RoleMapping {
4755 claim_value: "MCP.Admin".into(),
4756 role: "ops".into(),
4757 },
4758 RoleMapping {
4759 claim_value: "MCP.Reader".into(),
4760 role: "viewer".into(),
4761 },
4762 ],
4763 );
4764 let cache = test_cache(&config);
4765
4766 let now = jsonwebtoken::get_current_timestamp();
4767 let token = mint_token_with_claims(
4768 &pem,
4769 kid,
4770 &serde_json::json!({
4771 "iss": "https://auth.test.local",
4772 "aud": "https://mcp.test.local/mcp",
4773 "sub": "azure-ad-user",
4774 "exp": now + 3600,
4775 "iat": now,
4776 "roles": ["MCP.Reader", "OtherApp.Admin"]
4777 }),
4778 );
4779
4780 let id = cache
4781 .validate_token(&token)
4782 .await
4783 .expect("should authenticate");
4784 assert_eq!(id.name, "azure-ad-user");
4785 assert_eq!(id.role, "viewer");
4786 }
4787
4788 #[tokio::test]
4789 async fn role_claim_no_matching_value_rejected() {
4790 let kid = "test-role-3";
4791 let (pem, jwks) = generate_test_keypair(kid);
4792
4793 let mock_server = wiremock::MockServer::start().await;
4794 wiremock::Mock::given(wiremock::matchers::method("GET"))
4795 .and(wiremock::matchers::path("/jwks.json"))
4796 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4797 .mount(&mock_server)
4798 .await;
4799
4800 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4801 let config = test_config_with_role_claim(
4802 &jwks_uri,
4803 "roles",
4804 vec![RoleMapping {
4805 claim_value: "mcp-admin".into(),
4806 role: "ops".into(),
4807 }],
4808 );
4809 let cache = test_cache(&config);
4810
4811 let now = jsonwebtoken::get_current_timestamp();
4812 let token = mint_token_with_claims(
4813 &pem,
4814 kid,
4815 &serde_json::json!({
4816 "iss": "https://auth.test.local",
4817 "aud": "https://mcp.test.local/mcp",
4818 "sub": "limited-user",
4819 "exp": now + 3600,
4820 "iat": now,
4821 "roles": ["some-other-role"]
4822 }),
4823 );
4824
4825 assert!(cache.validate_token(&token).await.is_none());
4826 }
4827
4828 #[tokio::test]
4829 async fn role_claim_space_separated_string() {
4830 let kid = "test-role-4";
4831 let (pem, jwks) = generate_test_keypair(kid);
4832
4833 let mock_server = wiremock::MockServer::start().await;
4834 wiremock::Mock::given(wiremock::matchers::method("GET"))
4835 .and(wiremock::matchers::path("/jwks.json"))
4836 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4837 .mount(&mock_server)
4838 .await;
4839
4840 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4841 let config = test_config_with_role_claim(
4842 &jwks_uri,
4843 "custom_scope",
4844 vec![
4845 RoleMapping {
4846 claim_value: "write".into(),
4847 role: "ops".into(),
4848 },
4849 RoleMapping {
4850 claim_value: "read".into(),
4851 role: "viewer".into(),
4852 },
4853 ],
4854 );
4855 let cache = test_cache(&config);
4856
4857 let now = jsonwebtoken::get_current_timestamp();
4858 let token = mint_token_with_claims(
4859 &pem,
4860 kid,
4861 &serde_json::json!({
4862 "iss": "https://auth.test.local",
4863 "aud": "https://mcp.test.local/mcp",
4864 "sub": "custom-client",
4865 "exp": now + 3600,
4866 "iat": now,
4867 "custom_scope": "read audit"
4868 }),
4869 );
4870
4871 let id = cache
4872 .validate_token(&token)
4873 .await
4874 .expect("should authenticate");
4875 assert_eq!(id.name, "custom-client");
4876 assert_eq!(id.role, "viewer");
4877 }
4878
4879 #[tokio::test]
4880 async fn scope_backward_compat_without_role_claim() {
4881 let kid = "test-compat-1";
4883 let (pem, jwks) = generate_test_keypair(kid);
4884
4885 let mock_server = wiremock::MockServer::start().await;
4886 wiremock::Mock::given(wiremock::matchers::method("GET"))
4887 .and(wiremock::matchers::path("/jwks.json"))
4888 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4889 .mount(&mock_server)
4890 .await;
4891
4892 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4893 let config = test_config(&jwks_uri); let cache = test_cache(&config);
4895
4896 let token = mint_token(
4897 &pem,
4898 kid,
4899 "https://auth.test.local",
4900 "https://mcp.test.local/mcp",
4901 "legacy-bot",
4902 "mcp:admin other:scope",
4903 );
4904
4905 let id = cache
4906 .validate_token(&token)
4907 .await
4908 .expect("should authenticate");
4909 assert_eq!(id.name, "legacy-bot");
4910 assert_eq!(id.role, "ops"); }
4912
4913 #[tokio::test]
4918 async fn jwks_refresh_deduplication() {
4919 let kid = "test-dedup";
4922 let (pem, jwks) = generate_test_keypair(kid);
4923
4924 let mock_server = wiremock::MockServer::start().await;
4925 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4926 .and(wiremock::matchers::path("/jwks.json"))
4927 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4928 .expect(1) .mount(&mock_server)
4930 .await;
4931
4932 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4933 let config = test_config(&jwks_uri);
4934 let cache = Arc::new(test_cache(&config));
4935
4936 let token = mint_token(
4938 &pem,
4939 kid,
4940 "https://auth.test.local",
4941 "https://mcp.test.local/mcp",
4942 "concurrent-bot",
4943 "mcp:read",
4944 );
4945
4946 let mut handles = Vec::new();
4947 for _ in 0..5 {
4948 let c = Arc::clone(&cache);
4949 let t = token.clone();
4950 handles.push(tokio::spawn(async move { c.validate_token(&t).await }));
4951 }
4952
4953 for h in handles {
4954 let result = h.await.unwrap();
4955 assert!(result.is_some(), "all concurrent requests should succeed");
4956 }
4957
4958 }
4960
4961 #[tokio::test]
4962 async fn jwks_refresh_cooldown_blocks_rapid_requests() {
4963 let kid = "test-cooldown";
4966 let (_pem, jwks) = generate_test_keypair(kid);
4967
4968 let mock_server = wiremock::MockServer::start().await;
4969 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4970 .and(wiremock::matchers::path("/jwks.json"))
4971 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4972 .expect(1) .mount(&mock_server)
4974 .await;
4975
4976 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4977 let config = test_config(&jwks_uri);
4978 let cache = test_cache(&config);
4979
4980 let fake_token1 =
4982 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig";
4983 let _ = cache.validate_token(fake_token1).await;
4984
4985 let fake_token2 =
4988 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig";
4989 let _ = cache.validate_token(fake_token2).await;
4990
4991 let fake_token3 =
4993 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig";
4994 let _ = cache.validate_token(fake_token3).await;
4995
4996 }
4998
4999 fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
5002 OAuthProxyConfig {
5003 authorize_url: "https://example.invalid/auth".into(),
5004 token_url: token_url.into(),
5005 client_id: "mcp-client".into(),
5006 client_secret: Some(secrecy::SecretString::from("shh".to_owned())),
5007 introspection_url: None,
5008 revocation_url: None,
5009 expose_admin_endpoints: false,
5010 require_auth_on_admin_endpoints: false,
5011 allow_unauthenticated_admin_endpoints: false,
5012 }
5013 }
5014
5015 fn test_http_client() -> OauthHttpClient {
5018 rustls::crypto::ring::default_provider()
5019 .install_default()
5020 .ok();
5021 let config = OAuthConfig::builder(
5022 "https://auth.test.local",
5023 "https://mcp.test.local/mcp",
5024 "https://auth.test.local/.well-known/jwks.json",
5025 )
5026 .allow_http_oauth_urls(true)
5027 .build();
5028 OauthHttpClient::with_config(&config)
5029 .expect("build test http client")
5030 .__test_allow_loopback_ssrf()
5031 }
5032
5033 #[tokio::test]
5034 async fn introspect_proxies_and_injects_client_credentials() {
5035 use wiremock::matchers::{body_string_contains, method, path};
5036
5037 let mock_server = wiremock::MockServer::start().await;
5038 wiremock::Mock::given(method("POST"))
5039 .and(path("/introspect"))
5040 .and(body_string_contains("client_id=mcp-client"))
5041 .and(body_string_contains("client_secret=shh"))
5042 .and(body_string_contains("token=abc"))
5043 .respond_with(
5044 wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
5045 "active": true,
5046 "scope": "read"
5047 })),
5048 )
5049 .expect(1)
5050 .mount(&mock_server)
5051 .await;
5052
5053 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
5054 proxy.introspection_url = Some(format!("{}/introspect", mock_server.uri()));
5055
5056 let http = test_http_client();
5057 let resp = handle_introspect(&http, &proxy, "token=abc").await;
5058 assert_eq!(resp.status(), 200);
5059 }
5060
5061 #[tokio::test]
5062 async fn token_proxy_fails_closed_on_oversized_upstream_response() {
5063 use http_body_util::BodyExt as _;
5064 use wiremock::matchers::{method, path};
5065
5066 let oversized = "x"
5068 .repeat(usize::try_from(OAUTH_PROXY_MAX_RESPONSE_BYTES).unwrap_or(usize::MAX) + 4096);
5069 let mock_server = wiremock::MockServer::start().await;
5070 wiremock::Mock::given(method("POST"))
5071 .and(path("/token"))
5072 .respond_with(wiremock::ResponseTemplate::new(200).set_body_string(oversized.clone()))
5073 .expect(1)
5074 .mount(&mock_server)
5075 .await;
5076
5077 let proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
5078 let http = test_http_client();
5079 let resp = handle_token(&http, &proxy, "grant_type=authorization_code&code=abc").await;
5080
5081 assert_eq!(
5083 resp.status(),
5084 502,
5085 "oversized upstream response must fail closed as 502"
5086 );
5087 let body = resp
5088 .into_body()
5089 .collect()
5090 .await
5091 .expect("collect body")
5092 .to_bytes();
5093 assert!(
5094 body.len() < 1024,
5095 "must return the small generic error body, not the oversized upstream body (got {} bytes)",
5096 body.len()
5097 );
5098 assert!(
5099 !body.windows(8).any(|w| w == b"xxxxxxxx"),
5100 "the oversized upstream payload must not be forwarded to the client"
5101 );
5102 }
5103
5104 #[tokio::test]
5105 async fn token_proxy_passes_through_normal_response() {
5106 use http_body_util::BodyExt as _;
5107 use wiremock::matchers::{method, path};
5108
5109 let mock_server = wiremock::MockServer::start().await;
5110 wiremock::Mock::given(method("POST"))
5111 .and(path("/token"))
5112 .respond_with(
5113 wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
5114 "access_token": "at-123",
5115 "token_type": "Bearer"
5116 })),
5117 )
5118 .expect(1)
5119 .mount(&mock_server)
5120 .await;
5121
5122 let proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
5123 let http = test_http_client();
5124 let resp = handle_token(&http, &proxy, "grant_type=authorization_code&code=abc").await;
5125
5126 assert_eq!(
5127 resp.status(),
5128 200,
5129 "a normal-sized response must pass through"
5130 );
5131 let body = resp
5132 .into_body()
5133 .collect()
5134 .await
5135 .expect("collect body")
5136 .to_bytes();
5137 let json: serde_json::Value =
5138 serde_json::from_slice(&body).expect("upstream JSON preserved");
5139 assert_eq!(json["access_token"], "at-123");
5140 }
5141
5142 #[tokio::test]
5143 async fn introspect_returns_404_when_not_configured() {
5144 let proxy = proxy_cfg("https://example.invalid/token");
5145 let http = test_http_client();
5146 let resp = handle_introspect(&http, &proxy, "token=abc").await;
5147 assert_eq!(resp.status(), 404);
5148 }
5149
5150 #[tokio::test]
5151 async fn revoke_proxies_and_returns_upstream_status() {
5152 use wiremock::matchers::{method, path};
5153
5154 let mock_server = wiremock::MockServer::start().await;
5155 wiremock::Mock::given(method("POST"))
5156 .and(path("/revoke"))
5157 .respond_with(wiremock::ResponseTemplate::new(200))
5158 .expect(1)
5159 .mount(&mock_server)
5160 .await;
5161
5162 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
5163 proxy.revocation_url = Some(format!("{}/revoke", mock_server.uri()));
5164
5165 let http = test_http_client();
5166 let resp = handle_revoke(&http, &proxy, "token=abc").await;
5167 assert_eq!(resp.status(), 200);
5168 }
5169
5170 #[tokio::test]
5171 async fn revoke_returns_404_when_not_configured() {
5172 let proxy = proxy_cfg("https://example.invalid/token");
5173 let http = test_http_client();
5174 let resp = handle_revoke(&http, &proxy, "token=abc").await;
5175 assert_eq!(resp.status(), 404);
5176 }
5177
/// Walks the metadata document through four configurations and checks that the
/// introspection/revocation endpoints appear only when the proxy both exposes admin
/// endpoints and has the corresponding upstream URL set.
#[test]
fn metadata_advertises_endpoints_only_when_configured() {
    const BASE: &str = "https://mcp.local";
    let mut cfg = test_config("https://auth.test.local/jwks.json");

    // 1. No proxy at all: nothing advertised.
    let without_proxy = authorization_server_metadata(BASE, &cfg);
    assert!(without_proxy.get("introspection_endpoint").is_none());
    assert!(without_proxy.get("revocation_endpoint").is_none());

    // 2. Proxy has both upstream URLs but admin endpoints are not exposed.
    let mut proxy = proxy_cfg("https://upstream.local/token");
    proxy.introspection_url = Some("https://upstream.local/introspect".into());
    proxy.revocation_url = Some("https://upstream.local/revoke".into());
    cfg.proxy = Some(proxy);
    let not_exposed = authorization_server_metadata(BASE, &cfg);
    assert!(
        not_exposed.get("introspection_endpoint").is_none(),
        "introspection must not be advertised when expose_admin_endpoints=false"
    );
    assert!(
        not_exposed.get("revocation_endpoint").is_none(),
        "revocation must not be advertised when expose_admin_endpoints=false"
    );

    // 3. Exposed, but only the introspection URL is configured.
    if let Some(admin) = &mut cfg.proxy {
        admin.expose_admin_endpoints = true;
        admin.revocation_url = None;
    }
    let introspection_only = authorization_server_metadata(BASE, &cfg);
    assert_eq!(
        introspection_only["introspection_endpoint"],
        serde_json::Value::String("https://mcp.local/introspect".into())
    );
    assert!(introspection_only.get("revocation_endpoint").is_none());

    // 4. Exposed with the revocation URL restored.
    if let Some(admin) = &mut cfg.proxy {
        admin.revocation_url = Some("https://upstream.local/revoke".into());
    }
    let with_revocation = authorization_server_metadata(BASE, &cfg);
    assert_eq!(
        with_revocation["revocation_endpoint"],
        serde_json::Value::String("https://mcp.local/revoke".into())
    );
}
5224
5225 fn https_cfg_with_tx(tx: TokenExchangeConfig) -> OAuthConfig {
5228 let mut cfg = validation_https_config();
5229 cfg.token_exchange = Some(tx);
5230 cfg
5231 }
5232
5233 fn tx_with(
5234 client_secret: Option<&str>,
5235 client_cert: Option<ClientCertConfig>,
5236 ) -> TokenExchangeConfig {
5237 TokenExchangeConfig::new(
5238 "https://idp.example.com/token".into(),
5239 "client".into(),
5240 client_secret.map(|s| secrecy::SecretString::new(s.into())),
5241 client_cert,
5242 "downstream".into(),
5243 )
5244 }
5245
/// A token-exchange config with neither a client secret nor a client
/// certificate must fail `validate()`, and the error text must mention
/// "requires client authentication".
#[test]
fn validate_rejects_token_exchange_without_client_auth() {
    let no_credentials = tx_with(None, None);
    let msg = https_cfg_with_tx(no_credentials)
        .validate()
        .expect_err("token_exchange without client auth must be rejected")
        .to_string();
    assert!(
        msg.contains("requires client authentication"),
        "error must explain missing client auth; got {msg:?}"
    );
}
5258
/// Supplying both a client secret and a client certificate must fail
/// `validate()`, and the error text must contain both "mutually" and
/// "exclusive".
#[test]
fn validate_rejects_token_exchange_with_both_secret_and_cert() {
    // The test expects rejection on the secret+cert combination itself, so the
    // paths deliberately point at files that are not expected to exist.
    let cert = ClientCertConfig {
        cert_path: "/nonexistent/cert.pem".into(),
        key_path: "/nonexistent/key.pem".into(),
    };
    let msg = https_cfg_with_tx(tx_with(Some("s"), Some(cert)))
        .validate()
        .expect_err("client_secret + client_cert must be rejected")
        .to_string();
    let mentions_exclusivity = ["mutually", "exclusive"]
        .iter()
        .all(|word| msg.contains(*word));
    assert!(
        mentions_exclusivity,
        "error must explain mutual exclusion; got {msg:?}"
    );
}
5275
/// Compiled only when the `oauth-mtls-client` feature is disabled: a config
/// that relies on a client certificate must fail `validate()` with an error
/// that names the missing cargo feature.
#[cfg(not(feature = "oauth-mtls-client"))]
#[test]
fn validate_rejects_client_cert_without_feature() {
    let cert = ClientCertConfig {
        cert_path: "/nonexistent/cert.pem".into(),
        key_path: "/nonexistent/key.pem".into(),
    };
    let config = https_cfg_with_tx(tx_with(None, Some(cert)));
    let err = config
        .validate()
        .expect_err("client_cert without feature must be rejected");
    let rendered = err.to_string();
    assert!(
        rendered.contains("oauth-mtls-client"),
        "error must reference the cargo feature; got {err}"
    );
}
5292
/// With the `oauth-mtls-client` feature enabled, certificate/key paths that do
/// not exist must make `validate()` fail with an error containing "unreadable".
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_missing_client_cert_files() {
    let missing = ClientCertConfig {
        cert_path: "/nonexistent/cert.pem".into(),
        key_path: "/nonexistent/key.pem".into(),
    };
    let config = https_cfg_with_tx(tx_with(None, Some(missing)));
    let err = config
        .validate()
        .expect_err("missing cert file must be rejected");
    let rendered = err.to_string();
    assert!(
        rendered.contains("unreadable"),
        "error must call out unreadable file; got {err}"
    );
}
5309
/// With the `oauth-mtls-client` feature enabled, certificate/key files that
/// exist but do not contain PEM data must make `validate()` fail with an error
/// containing "PEM parse failed".
///
/// Two scratch files are written to the system temp directory (named after
/// the current process id) and removed again before the result is inspected.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_malformed_client_cert_pem() {
    let dir = std::env::temp_dir();
    let cert = dir.join(format!("rmcp-mtls-bad-cert-{}.pem", std::process::id()));
    let key = dir.join(format!("rmcp-mtls-bad-key-{}.pem", std::process::id()));
    std::fs::write(&cert, b"not a real PEM").expect("write tmp cert");
    std::fs::write(&key, b"not a real PEM either").expect("write tmp key");
    let cc = ClientCertConfig {
        cert_path: cert.clone(),
        key_path: key.clone(),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));

    // Capture the outcome first and only unwrap it after cleanup: a failing
    // `expect_err` panics, and doing that before `remove_file` would leave the
    // scratch files behind in the temp directory.
    let res = cfg.validate();
    // Best-effort cleanup; a failure to delete must not mask the real result.
    let _ = std::fs::remove_file(&cert);
    let _ = std::fs::remove_file(&key);

    let err = res.expect_err("malformed PEM must be rejected");
    assert!(
        err.to_string().contains("PEM parse failed"),
        "error must call out PEM parse failure; got {err}"
    );
}
5331
/// Generates a self-signed certificate for `client.test` with `rcgen` and
/// writes the certificate and its signing key as PEM files into the system
/// temp directory.
///
/// File names include the process id and a random `u64`, so repeated calls
/// produce different paths. Returns `(cert_path, key_path)`; the caller is
/// responsible for deleting both files.
///
/// # Panics
///
/// Panics if certificate generation or either file write fails.
#[cfg(feature = "oauth-mtls-client")]
fn write_self_signed_pem() -> (PathBuf, PathBuf) {
    let generated =
        rcgen::generate_simple_self_signed(vec![String::from("client.test")]).expect("rcgen");
    let tmp = std::env::temp_dir();
    let (pid, nonce) = (std::process::id(), rand::random::<u64>());
    let cert_path = tmp.join(format!("rmcp-mtls-cert-{pid}-{nonce}.pem"));
    let key_path = tmp.join(format!("rmcp-mtls-key-{pid}-{nonce}.pem"));
    std::fs::write(&cert_path, generated.cert.pem()).expect("write cert");
    std::fs::write(&key_path, generated.signing_key.serialize_pem()).expect("write key");
    (cert_path, key_path)
}
5344
/// Installs the `ring` crypto provider as the rustls default for this test
/// process.
///
/// The result of `install_default` is deliberately discarded, so calling this
/// from several tests is harmless (presumably an `Err` only means a provider
/// was installed earlier; confirm against the rustls docs).
#[cfg(feature = "oauth-mtls-client")]
fn install_test_crypto_provider() {
    let provider = rustls::crypto::ring::default_provider();
    drop(provider.install_default());
}
5349
/// With the `oauth-mtls-client` feature enabled, a freshly generated
/// self-signed certificate and key must pass `validate()`.
///
/// The temp files are removed before the result is unwrapped, so a failing
/// validation does not leave them behind.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_accepts_well_formed_client_cert() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();
    let tx = tx_with(
        None,
        Some(ClientCertConfig {
            cert_path: cert_path.clone(),
            key_path: key_path.clone(),
        }),
    );
    let config = https_cfg_with_tx(tx);
    let outcome = config.validate();
    // Best-effort cleanup of both PEM files, in cert-then-key order.
    for path in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(path);
    }
    outcome.expect("well-formed cert+key must validate");
}
5365
/// With the `oauth-mtls-client` feature enabled, `OauthHttpClient::client_for`
/// must hand back a different client for the certificate-bearing
/// token-exchange config the client was built from than for a secret-only
/// config. The two returned references are compared by address.
///
/// The generated PEM files are removed by a drop guard, so they are cleaned up
/// even when one of the `expect` calls or the assertion panics (assuming the
/// test unwinds on panic).
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_returns_cached_mtls_client() {
    /// Deletes both paths when dropped; deletion errors are ignored.
    struct RemoveOnDrop<'a>(&'a std::path::Path, &'a std::path::Path);

    impl Drop for RemoveOnDrop<'_> {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(self.0);
            let _ = std::fs::remove_file(self.1);
        }
    }

    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();
    // Declared right after the files exist so every later panic path is covered.
    let _cleanup = RemoveOnDrop(&cert_path, &key_path);

    let cc = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));
    let http = OauthHttpClient::with_config(&cfg).expect("build mtls client");
    let tx_ref = cfg.token_exchange.as_ref().expect("tx set");
    let cert_client = http.client_for(tx_ref);
    let inner_client = http.client_for(&tx_with(Some("s"), None));
    assert!(
        !std::ptr::eq(cert_client, inner_client),
        "client_for must return distinct clients for cert vs no-cert configs"
    );
}
5387
/// With the `oauth-mtls-client` feature enabled, asking `client_for` about a
/// client-certificate config that the `OauthHttpClient` was not built with
/// must return the same client (by address) as a secret-only config does.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_falls_back_to_inner_when_cache_miss() {
    install_test_crypto_provider();
    let config = validation_https_config();
    let http = OauthHttpClient::with_config(&config).expect("build client");

    // A cert config whose paths were never part of `config`.
    let uncached_tx = tx_with(
        None,
        Some(ClientCertConfig {
            cert_path: "/cache/miss/cert.pem".into(),
            key_path: "/cache/miss/key.pem".into(),
        }),
    );
    let secret_only_tx = tx_with(Some("s"), None);

    let fallback = http.client_for(&uncached_tx);
    let inner = http.client_for(&secret_only_tx);
    assert!(
        std::ptr::eq(fallback, inner),
        "cache miss must fall back to inner client"
    );
}
5406}