1use std::{
17 collections::HashMap,
18 path::PathBuf,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering},
22 },
23 time::{Duration, Instant},
24};
25
26use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
27use serde::Deserialize;
28use tokio::{net::lookup_host, sync::RwLock};
29
30use crate::auth::{AuthIdentity, AuthMethod};
31
32fn evaluate_oauth_redirect(
58 attempt: &reqwest::redirect::Attempt<'_>,
59 allow_http: bool,
60 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
61) -> Result<(), String> {
62 let prev_https = attempt
63 .previous()
64 .last()
65 .is_some_and(|prev| prev.scheme() == "https");
66 let target_url = attempt.url();
67 let dest_scheme = target_url.scheme();
68 if dest_scheme != "https" {
69 if prev_https {
70 return Err("redirect downgrades https -> http".to_owned());
71 }
72 if !allow_http || dest_scheme != "http" {
73 return Err("redirect to non-HTTP(S) URL refused".to_owned());
74 }
75 }
76 if let Some(reason) = crate::ssrf::redirect_target_reason_with_allowlist(target_url, allowlist)
77 {
78 return Err(format!("redirect target forbidden: {reason}"));
79 }
80 if attempt.previous().len() >= 2 {
81 return Err("too many redirects (max 2)".to_owned());
82 }
83 Ok(())
84}
85
86#[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
100async fn screen_oauth_target_with_test_override(
101 url: &str,
102 allow_http: bool,
103 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
104 #[cfg(any(test, feature = "test-helpers"))] test_allow_loopback_ssrf: bool,
105) -> Result<(), crate::error::McpxError> {
106 let parsed = check_oauth_url("oauth target", url, allow_http)?;
107 #[cfg(any(test, feature = "test-helpers"))]
108 if test_allow_loopback_ssrf {
109 return Ok(());
110 }
111 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
112 return Err(crate::error::McpxError::Config(format!(
113 "OAuth target forbidden ({reason}): {url}"
114 )));
115 }
116
117 let host = parsed.host_str().ok_or_else(|| {
118 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
119 })?;
120 let port = parsed.port_or_known_default().ok_or_else(|| {
121 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
122 })?;
123
124 let addrs = lookup_host((host, port)).await.map_err(|error| {
125 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
126 })?;
127
128 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
129 let mut any_addr = false;
130 for addr in addrs {
131 any_addr = true;
132 let ip = addr.ip();
133 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
134 if reason == "cloud_metadata" {
137 return Err(crate::error::McpxError::Config(format!(
138 "OAuth target resolved to blocked IP ({reason}): {url}"
139 )));
140 }
141 if allowlist.is_empty() {
145 return Err(crate::error::McpxError::Config(format!(
146 "OAuth target resolved to blocked IP ({reason}): {url}"
147 )));
148 }
149 if host_allowed || allowlist.ip_allowed(ip) {
151 continue;
152 }
153 return Err(crate::error::McpxError::Config(format!(
154 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
155 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
156 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
157 URL: {url}"
158 )));
159 }
160 }
161 if !any_addr {
162 return Err(crate::error::McpxError::Config(format!(
163 "OAuth target DNS resolution returned no addresses: {url}"
164 )));
165 }
166
167 Ok(())
168}
169
170async fn screen_oauth_target(
171 url: &str,
172 allow_http: bool,
173 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
174) -> Result<(), crate::error::McpxError> {
175 #[cfg(any(test, feature = "test-helpers"))]
176 {
177 screen_oauth_target_with_test_override(url, allow_http, allowlist, false).await
178 }
179 #[cfg(not(any(test, feature = "test-helpers")))]
180 {
181 let parsed = check_oauth_url("oauth target", url, allow_http)?;
182 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
183 return Err(crate::error::McpxError::Config(format!(
184 "OAuth target forbidden ({reason}): {url}"
185 )));
186 }
187
188 let host = parsed.host_str().ok_or_else(|| {
189 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
190 })?;
191 let port = parsed.port_or_known_default().ok_or_else(|| {
192 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
193 })?;
194
195 let addrs = lookup_host((host, port)).await.map_err(|error| {
196 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
197 })?;
198
199 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
200 let mut any_addr = false;
201 for addr in addrs {
202 any_addr = true;
203 let ip = addr.ip();
204 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
205 if reason == "cloud_metadata" {
206 return Err(crate::error::McpxError::Config(format!(
207 "OAuth target resolved to blocked IP ({reason}): {url}"
208 )));
209 }
210 if allowlist.is_empty() {
211 return Err(crate::error::McpxError::Config(format!(
212 "OAuth target resolved to blocked IP ({reason}): {url}"
213 )));
214 }
215 if host_allowed || allowlist.ip_allowed(ip) {
216 continue;
217 }
218 return Err(crate::error::McpxError::Config(format!(
219 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
220 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
221 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
222 URL: {url}"
223 )));
224 }
225 }
226 if !any_addr {
227 return Err(crate::error::McpxError::Config(format!(
228 "OAuth target DNS resolution returned no addresses: {url}"
229 )));
230 }
231
232 Ok(())
233 }
234}
235
/// HTTP client for outbound OAuth traffic (token, introspection, revocation,
/// exchange), with SSRF screening wired in.
///
/// `inner` is built by `OauthHttpClient::build` with an SSRF-screening DNS
/// resolver, no proxy, fixed timeouts and a custom redirect policy. Requests
/// sent through `send_screened` are additionally pre-screened by URL.
#[derive(Clone)]
pub struct OauthHttpClient {
    /// Default client used for OAuth requests.
    inner: reqwest::Client,
    /// Whether plain-`http` OAuth URLs are accepted (`allow_http_oauth_urls`).
    allow_http: bool,
    /// Compiled operator SSRF allowlist; empty when none is configured.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// mTLS clients keyed by client-cert/key paths; consulted by `client_for`.
    #[cfg(feature = "oauth-mtls-client")]
    mtls_clients: Arc<HashMap<MtlsClientKey, reqwest::Client>>,
    /// Test-only loopback bypass flag (an `Arc<AtomicBool>` in these builds).
    /// `build` hands clones of it to the screening resolver(s), and the derived
    /// `Clone` copies the handle, so every clone of this client shares it.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
299
/// Lookup key for `OauthHttpClient::mtls_clients`: the certificate and
/// private-key file paths taken from a `ClientCertConfig`.
#[cfg(feature = "oauth-mtls-client")]
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct MtlsClientKey {
    /// Path of the PEM client certificate.
    cert_path: PathBuf,
    /// Path of the PEM private key.
    key_path: PathBuf,
}
309
impl OauthHttpClient {
    /// Builds a client that inherits `config`'s `allow_http_oauth_urls`,
    /// `ssrf_allowlist` and `ca_cert_path` settings. With the
    /// `oauth-mtls-client` feature it also builds a client for the
    /// token-exchange `client_cert`.
    ///
    /// # Errors
    ///
    /// `McpxError::Startup` if the allowlist fails to compile, a certificate
    /// file cannot be read or parsed, or a `reqwest` client cannot be built.
    pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
        Self::build(Some(config))
    }

    /// Builds a client without an `OAuthConfig`: plain `http` refused, empty
    /// SSRF allowlist, no extra CA root, no mTLS clients.
    ///
    /// # Errors
    ///
    /// `McpxError::Startup` if the `reqwest` client cannot be built.
    #[deprecated(
        since = "1.2.1",
        note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
    )]
    pub fn new() -> Result<Self, crate::error::McpxError> {
        Self::build(None)
    }

    /// Shared constructor behind `with_config` and `new`.
    ///
    /// The main client uses:
    /// - no proxy;
    /// - the SSRF-screening DNS resolver;
    /// - a 10 s connect timeout and a 30 s overall timeout;
    /// - a redirect policy that defers to `evaluate_oauth_redirect` and logs
    ///   each rejection;
    /// - the optional `ca_cert_path` root.
    fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
        // Plain-http OAuth URLs are opt-in; without a config they are refused.
        let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);

        // A missing allowlist becomes the empty default.
        let allowlist = match config.and_then(|c| c.ssrf_allowlist.as_ref()) {
            Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
                crate::error::McpxError::Startup(format!("oauth http client: {e}"))
            })?),
            None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
        };

        // Second handle, moved into the redirect-policy closure below.
        let redirect_allowlist = Arc::clone(&allowlist);

        // `TestLoopbackBypass` is an `Arc<AtomicBool>` (initially false) in
        // test / `test-helpers` builds and `()` otherwise. Clones go to each
        // screening resolver; in test builds the original is also stored on
        // `Self`, so they all share one flag.
        #[cfg(any(test, feature = "test-helpers"))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
            Arc::new(AtomicBool::new(false));
        #[cfg(not(any(test, feature = "test-helpers")))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();

        let resolver: Arc<dyn reqwest::dns::Resolve> =
            Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
                Arc::clone(&allowlist),
                #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
                test_bypass.clone(),
            ));

        // NOTE(review): `no_proxy()` presumably keeps connections direct so the
        // screening resolver sees the real target host -- confirm.
        let mut builder = reqwest::Client::builder()
            .no_proxy()
            .dns_resolver(Arc::clone(&resolver))
            .connect_timeout(Duration::from_secs(10))
            .timeout(Duration::from_secs(30))
            .redirect(reqwest::redirect::Policy::custom(move |attempt| {
                match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
                    Ok(()) => attempt.follow(),
                    Err(reason) => {
                        tracing::warn!(
                            reason = %reason,
                            target = %attempt.url(),
                            "oauth redirect rejected"
                        );
                        attempt.error(reason)
                    }
                }
            }));

        // Optional extra trust root read from a PEM file.
        if let Some(cfg) = config
            && let Some(ref ca_path) = cfg.ca_cert_path
        {
            let pem = std::fs::read(ca_path).map_err(|e| {
                crate::error::McpxError::Startup(format!(
                    "oauth http client: read ca_cert_path {}: {e}",
                    ca_path.display()
                ))
            })?;
            let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
                crate::error::McpxError::Startup(format!(
                    "oauth http client: parse ca_cert_path {}: {e}",
                    ca_path.display()
                ))
            })?;
            builder = builder.add_root_certificate(cert);
        }

        let inner = builder.build().map_err(|e| {
            crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
        })?;

        #[cfg(feature = "oauth-mtls-client")]
        let mtls_clients = build_mtls_clients(config, &allowlist, &test_bypass)?;

        Ok(Self {
            inner,
            allow_http,
            allowlist,
            #[cfg(feature = "oauth-mtls-client")]
            mtls_clients,
            #[cfg(any(test, feature = "test-helpers"))]
            test_allow_loopback_ssrf: test_bypass,
        })
    }

    /// Screens `url`, then sends `request`.
    ///
    /// In test builds with the loopback bypass set, the bypassing screen
    /// variant is used. Otherwise `screen_oauth_target` runs.
    ///
    /// NOTE(review): `request` is assumed to target `url`; nothing here
    /// checks that the screened URL and the request URL match.
    ///
    /// # Errors
    ///
    /// Screening failures as returned by the screen; send failures as
    /// `McpxError::Config`.
    async fn send_screened(
        &self,
        url: &str,
        request: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, crate::error::McpxError> {
        #[cfg(any(test, feature = "test-helpers"))]
        if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
            screen_oauth_target_with_test_override(url, self.allow_http, &self.allowlist, true)
                .await?;
        } else {
            screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
        }
        #[cfg(not(any(test, feature = "test-helpers")))]
        screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
        request.send().await.map_err(|error| {
            crate::error::McpxError::Config(format!("oauth request {url}: {error}"))
        })
    }

    /// Test helper: turns on the loopback SSRF bypass.
    ///
    /// The flag is a shared `Arc<AtomicBool>`, so this also affects clones of
    /// this client and the resolver(s) built alongside it.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_allow_loopback_ssrf(self) -> Self {
        self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
        self
    }

    /// Test helper: GET via the inner client, skipping the `send_screened`
    /// URL pre-screen (the client's resolver and redirect policy still apply).
    ///
    /// NOTE(review): unlike the neighbouring test helpers this one has no
    /// `cfg(any(test, feature = "test-helpers"))` gate, so it is compiled into
    /// every build -- confirm that is intended.
    #[doc(hidden)]
    pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
        self.inner.get(url).send().await
    }

    /// Test helper: borrow the inner `reqwest::Client`.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_inner_client(&self) -> &reqwest::Client {
        &self.inner
    }

    /// Selects the client for a token-exchange request.
    ///
    /// Returns the mTLS client matching `cfg.client_cert`'s cert/key paths
    /// when one was built, otherwise `inner`.
    #[cfg(feature = "oauth-mtls-client")]
    fn client_for(&self, cfg: &TokenExchangeConfig) -> &reqwest::Client {
        if let Some(cc) = &cfg.client_cert {
            let key = MtlsClientKey {
                cert_path: cc.cert_path.clone(),
                key_path: cc.key_path.clone(),
            };
            if let Some(client) = self.mtls_clients.get(&key) {
                return client;
            }
        }
        &self.inner
    }

    /// Without the `oauth-mtls-client` feature, always the `inner` client.
    #[cfg(not(feature = "oauth-mtls-client"))]
    fn client_for(&self, _cfg: &TokenExchangeConfig) -> &reqwest::Client {
        &self.inner
    }
}
552
553impl std::fmt::Debug for OauthHttpClient {
554 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 f.debug_struct("OauthHttpClient").finish_non_exhaustive()
556 }
557}
558
/// Operator SSRF allowlist for OAuth/JWKS targets (`oauth.ssrf_allowlist`).
///
/// Entries re-permit targets that the SSRF screen would otherwise refuse
/// (e.g. private/loopback addresses). Cloud-metadata addresses stay blocked
/// regardless; that is enforced by the screening code, not here. Entries are
/// validated by `compile_oauth_ssrf_allowlist`.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct OAuthSsrfAllowlist {
    /// Bare DNS hostnames. No scheme, port, path, userinfo, query or
    /// fragment, and no IP literals (use `cidrs` for those).
    #[serde(default)]
    pub hosts: Vec<String>,
    /// CIDR ranges, parsed by `crate::ssrf::CidrEntry::parse`.
    #[serde(default)]
    pub cidrs: Vec<String>,
}
635
636fn compile_oauth_ssrf_allowlist(
643 raw: &OAuthSsrfAllowlist,
644) -> Result<crate::ssrf::CompiledSsrfAllowlist, String> {
645 let mut hosts: Vec<String> = Vec::with_capacity(raw.hosts.len());
646 for (idx, entry) in raw.hosts.iter().enumerate() {
647 let trimmed = entry.trim();
648 if trimmed.is_empty() {
649 return Err(format!("oauth.ssrf_allowlist.hosts[{idx}]: empty entry"));
650 }
651 if trimmed.contains([':', '/', '@', '?', '#']) {
655 return Err(format!(
656 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: must be a bare DNS hostname \
657 (no scheme, port, path, userinfo, query, or fragment)"
658 ));
659 }
660 match url::Host::parse(trimmed) {
661 Ok(url::Host::Domain(_)) => {}
662 Ok(url::Host::Ipv4(_) | url::Host::Ipv6(_)) => {
663 return Err(format!(
664 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: literal IPs are forbidden \
665 here -- list them via oauth.ssrf_allowlist.cidrs instead"
666 ));
667 }
668 Err(e) => {
669 return Err(format!(
670 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: invalid hostname: {e}"
671 ));
672 }
673 }
674 hosts.push(trimmed.to_ascii_lowercase());
675 }
676 hosts.sort();
677 hosts.dedup();
678
679 let mut cidrs = Vec::with_capacity(raw.cidrs.len());
680 for (idx, entry) in raw.cidrs.iter().enumerate() {
681 let parsed = crate::ssrf::CidrEntry::parse(entry)
682 .map_err(|e| format!("oauth.ssrf_allowlist.cidrs[{idx}]: {e}"))?;
683 cidrs.push(parsed);
684 }
685
686 Ok(crate::ssrf::CompiledSsrfAllowlist::new(hosts, cidrs))
687}
688
/// OAuth / JWT configuration, deserialized from the `oauth` config section.
///
/// `OAuthConfig::validate` performs startup checks. `OauthHttpClient::with_config`
/// consumes the transport-related fields (`allow_http_oauth_urls`,
/// `ssrf_allowlist`, `ca_cert_path`, `token_exchange.client_cert`).
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
    /// Issuer URL; must pass `check_oauth_url` and the literal-IP screen.
    pub issuer: String,
    /// Expected audience. How strictly it is enforced depends on
    /// `effective_audience_validation_mode()`.
    pub audience: String,
    /// JWKS endpoint URL; same URL checks as `issuer`.
    pub jwks_uri: String,
    /// Scope -> role mappings (empty by default).
    #[serde(default)]
    pub scopes: Vec<ScopeMapping>,
    /// Optional claim name, presumably consulted for `role_mappings` -- confirm.
    pub role_claim: Option<String>,
    /// Claim value -> role mappings (empty by default).
    #[serde(default)]
    pub role_mappings: Vec<RoleMapping>,
    /// JWKS cache lifetime in humantime syntax; default `"10m"`.
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl: String,
    /// Optional OAuth proxy (authorize/token forwarding) settings.
    pub proxy: Option<OAuthProxyConfig>,
    /// Optional token-exchange settings.
    pub token_exchange: Option<TokenExchangeConfig>,
    /// Optional PEM CA certificate added as an extra trust root for OAuth
    /// HTTP clients.
    #[serde(default)]
    pub ca_cert_path: Option<PathBuf>,
    /// Allow plain `http` OAuth URLs (default false; discouraged in production).
    #[serde(default)]
    pub allow_http_oauth_urls: bool,
    /// Optional operator SSRF allowlist.
    #[serde(default)]
    pub ssrf_allowlist: Option<OAuthSsrfAllowlist>,
    /// Presumably the maximum number of JWKS keys accepted; default 256.
    #[serde(default = "default_max_jwks_keys")]
    pub max_jwks_keys: usize,
    /// Legacy flag: when `audience_validation_mode` is `None`, `true` means
    /// `Strict` and `false` means `Warn`.
    #[serde(default)]
    #[deprecated(
        since = "1.7.0",
        note = "use `audience_validation_mode` instead; this field is consulted only when `audience_validation_mode` is None"
    )]
    pub strict_audience_validation: bool,
    /// Explicit audience-validation mode; overrides the legacy flag when set.
    #[serde(default)]
    pub audience_validation_mode: Option<AudienceValidationMode>,
    /// Presumably the size cap for JWKS responses, in bytes; default 1 MiB.
    #[serde(default = "default_jwks_max_bytes")]
    pub jwks_max_response_bytes: u64,
}
799
/// Serde default for `OAuthConfig::jwks_cache_ttl`: ten minutes, written in
/// humantime syntax.
fn default_jwks_cache_ttl() -> String {
    String::from("10m")
}
803
/// Serde default for `OAuthConfig::max_jwks_keys`.
const fn default_max_jwks_keys() -> usize {
    256_usize
}
807
/// Serde default for `OAuthConfig::jwks_max_response_bytes`: 1 MiB.
const fn default_jwks_max_bytes() -> u64 {
    1 << 20
}
811
/// How strictly the token audience is checked against `OAuthConfig::audience`.
///
/// Deserialized from `snake_case` strings (`"permissive"`, `"warn"`,
/// `"strict"`). `Warn` is the `Default`. When no mode is configured,
/// `OAuthConfig::effective_audience_validation_mode` derives `Strict` or
/// `Warn` from the deprecated `strict_audience_validation` flag.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AudienceValidationMode {
    /// Presumably accepts tokens despite an audience mismatch -- confirm
    /// against the validator.
    Permissive,
    /// Default. Presumably accepts a mismatch but logs a warning -- confirm.
    #[default]
    Warn,
    /// Presumably rejects tokens whose audience does not match -- confirm.
    Strict,
}
846
847impl Default for OAuthConfig {
848 fn default() -> Self {
849 Self {
850 issuer: String::new(),
851 audience: String::new(),
852 jwks_uri: String::new(),
853 scopes: Vec::new(),
854 role_claim: None,
855 role_mappings: Vec::new(),
856 jwks_cache_ttl: default_jwks_cache_ttl(),
857 proxy: None,
858 token_exchange: None,
859 ca_cert_path: None,
860 allow_http_oauth_urls: false,
861 max_jwks_keys: default_max_jwks_keys(),
862 #[allow(
863 deprecated,
864 reason = "default-construct deprecated field for backward compat"
865 )]
866 strict_audience_validation: false,
867 audience_validation_mode: None,
868 jwks_max_response_bytes: default_jwks_max_bytes(),
869 ssrf_allowlist: None,
870 }
871 }
872}
873
874impl OAuthConfig {
875 #[must_use]
881 pub fn effective_audience_validation_mode(&self) -> AudienceValidationMode {
882 if let Some(mode) = self.audience_validation_mode {
883 return mode;
884 }
885 #[allow(deprecated, reason = "intentional: legacy flag resolution path")]
886 if self.strict_audience_validation {
887 AudienceValidationMode::Strict
888 } else {
889 AudienceValidationMode::Warn
890 }
891 }
892
893 pub fn builder(
899 issuer: impl Into<String>,
900 audience: impl Into<String>,
901 jwks_uri: impl Into<String>,
902 ) -> OAuthConfigBuilder {
903 OAuthConfigBuilder {
904 inner: Self {
905 issuer: issuer.into(),
906 audience: audience.into(),
907 jwks_uri: jwks_uri.into(),
908 ..Self::default()
909 },
910 }
911 }
912
913 pub fn validate(&self) -> Result<(), crate::error::McpxError> {
929 let allow_http = self.allow_http_oauth_urls;
930 let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
931 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
932 return Err(crate::error::McpxError::Config(format!(
933 "oauth.issuer forbidden ({reason})"
934 )));
935 }
936 let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
937 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
938 return Err(crate::error::McpxError::Config(format!(
939 "oauth.jwks_uri forbidden ({reason})"
940 )));
941 }
942 if let Some(proxy) = &self.proxy {
943 let url = check_oauth_url(
944 "oauth.proxy.authorize_url",
945 &proxy.authorize_url,
946 allow_http,
947 )?;
948 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
949 return Err(crate::error::McpxError::Config(format!(
950 "oauth.proxy.authorize_url forbidden ({reason})"
951 )));
952 }
953 let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
954 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
955 return Err(crate::error::McpxError::Config(format!(
956 "oauth.proxy.token_url forbidden ({reason})"
957 )));
958 }
959 if let Some(url) = &proxy.introspection_url {
960 let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
961 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
962 return Err(crate::error::McpxError::Config(format!(
963 "oauth.proxy.introspection_url forbidden ({reason})"
964 )));
965 }
966 }
967 if let Some(url) = &proxy.revocation_url {
968 let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
969 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
970 return Err(crate::error::McpxError::Config(format!(
971 "oauth.proxy.revocation_url forbidden ({reason})"
972 )));
973 }
974 }
975 if proxy.expose_admin_endpoints
982 && !proxy.require_auth_on_admin_endpoints
983 && !proxy.allow_unauthenticated_admin_endpoints
984 {
985 return Err(crate::error::McpxError::Config(
986 "oauth.proxy: expose_admin_endpoints = true requires \
987 require_auth_on_admin_endpoints = true (recommended) \
988 or allow_unauthenticated_admin_endpoints = true \
989 (explicit opt-out, only safe behind an authenticated \
990 reverse proxy)"
991 .into(),
992 ));
993 }
994 }
995 if let Some(tx) = &self.token_exchange {
996 let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
997 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
998 return Err(crate::error::McpxError::Config(format!(
999 "oauth.token_exchange.token_url forbidden ({reason})"
1000 )));
1001 }
1002 validate_token_exchange_client_auth(tx)?;
1005 }
1006 if let Some(raw) = &self.ssrf_allowlist {
1010 let compiled = compile_oauth_ssrf_allowlist(raw).map_err(|e| {
1011 crate::error::McpxError::Config(format!("oauth.ssrf_allowlist: {e}"))
1012 })?;
1013 if !compiled.is_empty() {
1014 tracing::warn!(
1015 host_count = compiled.host_count(),
1016 cidr_count = compiled.cidr_count(),
1017 "oauth.ssrf_allowlist is configured: private/loopback OAuth/JWKS targets \
1018 are now reachable. Cloud-metadata addresses remain blocked. \
1019 See SECURITY.md \"Operator allowlist\"."
1020 );
1021 }
1022 }
1023 humantime::parse_duration(&self.jwks_cache_ttl).map_err(|e| {
1026 crate::error::McpxError::Config(format!(
1027 "oauth.jwks_cache_ttl {:?} is not a valid humantime duration (e.g. \"10m\", \"1h30m\"): {e}",
1028 self.jwks_cache_ttl
1029 ))
1030 })?;
1031 Ok(())
1032 }
1033}
1034
1035fn validate_token_exchange_client_auth(
1041 tx: &TokenExchangeConfig,
1042) -> Result<(), crate::error::McpxError> {
1043 match (&tx.client_cert, tx.client_secret.is_some()) {
1044 (Some(_), true) => Err(crate::error::McpxError::Config(
1045 "oauth.token_exchange: client_cert and client_secret are mutually \
1046 exclusive (RFC 8705 ยง2). Set exactly one."
1047 .into(),
1048 )),
1049 (None, false) => Err(crate::error::McpxError::Config(
1050 "oauth.token_exchange: token exchange requires client authentication. \
1051 Set either client_secret (RFC 6749 ยง2.3.1) or client_cert (RFC 8705 ยง2)."
1052 .into(),
1053 )),
1054 (Some(cc), false) => validate_client_cert_config(cc),
1055 (None, true) => Ok(()),
1056 }
1057}
1058
1059fn validate_client_cert_config(cc: &ClientCertConfig) -> Result<(), crate::error::McpxError> {
1072 #[cfg(not(feature = "oauth-mtls-client"))]
1073 {
1074 let _ = cc;
1075 Err(crate::error::McpxError::Config(
1076 "oauth.token_exchange.client_cert requires the `oauth-mtls-client` cargo feature; \
1077 rebuild rmcp-server-kit with --features oauth-mtls-client (or have your \
1078 application crate enable it via `rmcp-server-kit/oauth-mtls-client`), or remove \
1079 the field"
1080 .into(),
1081 ))
1082 }
1083 #[cfg(feature = "oauth-mtls-client")]
1084 {
1085 let cert_bytes = std::fs::read(&cc.cert_path).map_err(|e| {
1086 tracing::warn!(error = %e, path = %cc.cert_path.display(), "client cert read failed");
1087 crate::error::McpxError::Config(format!(
1088 "oauth.token_exchange.client_cert.cert_path unreadable: {}",
1089 cc.cert_path.display()
1090 ))
1091 })?;
1092 let key_bytes = std::fs::read(&cc.key_path).map_err(|e| {
1093 tracing::warn!(error = %e, path = %cc.key_path.display(), "client cert key read failed");
1094 crate::error::McpxError::Config(format!(
1095 "oauth.token_exchange.client_cert.key_path unreadable: {}",
1096 cc.key_path.display()
1097 ))
1098 })?;
1099 let mut combined = Vec::with_capacity(cert_bytes.len() + 1 + key_bytes.len());
1100 combined.extend_from_slice(&cert_bytes);
1101 if !cert_bytes.ends_with(b"\n") {
1102 combined.push(b'\n');
1103 }
1104 combined.extend_from_slice(&key_bytes);
1105 let _identity = reqwest::Identity::from_pem(&combined).map_err(|e| {
1106 tracing::warn!(
1107 error = %e,
1108 cert_path = %cc.cert_path.display(),
1109 key_path = %cc.key_path.display(),
1110 "client cert PEM parse failed"
1111 );
1112 crate::error::McpxError::Config(format!(
1113 "oauth.token_exchange.client_cert: PEM parse failed (cert={}, key={})",
1114 cc.cert_path.display(),
1115 cc.key_path.display()
1116 ))
1117 })?;
1118 Ok(())
1119 }
1120}
1121
/// Builds the per-identity mTLS clients stored on `OauthHttpClient`.
///
/// Returns an empty map unless `config.token_exchange.client_cert` is set. In
/// that case it builds one client presenting that identity, keyed by the
/// cert/key paths. The client shares the main client's setup:
/// - no proxy and the screening resolver;
/// - 10 s / 30 s timeouts;
/// - the optional `ca_cert_path` root.
///
/// Unlike the main client, redirects are disabled (`Policy::none()`).
///
/// # Errors
///
/// `McpxError::Startup` when a file cannot be read, the PEM material does not
/// parse, or the client fails to build.
#[cfg(feature = "oauth-mtls-client")]
fn build_mtls_clients(
    config: Option<&OAuthConfig>,
    allowlist: &Arc<crate::ssrf::CompiledSsrfAllowlist>,
    test_bypass: &crate::ssrf_resolver::TestLoopbackBypass,
) -> Result<Arc<HashMap<MtlsClientKey, reqwest::Client>>, crate::error::McpxError> {
    let mtls_source = config.and_then(|cfg| {
        let cert = cfg.token_exchange.as_ref()?.client_cert.as_ref()?;
        Some((cfg, cert))
    });
    let Some((cfg, cc)) = mtls_source else {
        return Ok(Arc::new(HashMap::new()));
    };

    // Every file read below shares one error shape: "read <label> <path>: <err>".
    let read_file = |label: &str, path: &PathBuf| -> Result<Vec<u8>, crate::error::McpxError> {
        std::fs::read(path).map_err(|e| {
            crate::error::McpxError::Startup(format!(
                "oauth http client mTLS: read {label} {}: {e}",
                path.display()
            ))
        })
    };

    let cert_pem = read_file("cert_path", &cc.cert_path)?;
    let key_pem = read_file("key_path", &cc.key_path)?;
    let separator: &[u8] = if cert_pem.ends_with(b"\n") { b"" } else { b"\n" };
    let bundle = [cert_pem.as_slice(), separator, key_pem.as_slice()].concat();
    let identity = reqwest::Identity::from_pem(&bundle).map_err(|e| {
        crate::error::McpxError::Startup(format!(
            "oauth http client mTLS: PEM parse (cert={}, key={}): {e}",
            cc.cert_path.display(),
            cc.key_path.display()
        ))
    })?;

    let resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(allowlist),
            #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
            test_bypass.clone(),
        ));

    let mut builder = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&resolver))
        .connect_timeout(Duration::from_secs(10))
        .timeout(Duration::from_secs(30))
        .redirect(reqwest::redirect::Policy::none())
        .identity(identity);

    if let Some(ref ca_path) = cfg.ca_cert_path {
        let ca_pem = read_file("ca_cert_path", ca_path)?;
        let ca_cert = reqwest::tls::Certificate::from_pem(&ca_pem).map_err(|e| {
            crate::error::McpxError::Startup(format!(
                "oauth http client mTLS: parse ca_cert_path {}: {e}",
                ca_path.display()
            ))
        })?;
        builder = builder.add_root_certificate(ca_cert);
    }

    let client = builder.build().map_err(|e| {
        crate::error::McpxError::Startup(format!("oauth http client mTLS init: {e}"))
    })?;
    let key = MtlsClientKey {
        cert_path: cc.cert_path.clone(),
        key_path: cc.key_path.clone(),
    };
    Ok(Arc::new(HashMap::from([(key, client)])))
}
1220
1221fn check_oauth_url(
1228 field: &str,
1229 raw: &str,
1230 allow_http: bool,
1231) -> Result<url::Url, crate::error::McpxError> {
1232 let parsed = url::Url::parse(raw).map_err(|e| {
1233 crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
1234 })?;
1235 if !parsed.username().is_empty() || parsed.password().is_some() {
1236 return Err(crate::error::McpxError::Config(format!(
1237 "{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
1238 )));
1239 }
1240 match parsed.scheme() {
1241 "https" => Ok(parsed),
1242 "http" if allow_http => Ok(parsed),
1243 "http" => Err(crate::error::McpxError::Config(format!(
1244 "{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
1245 to override - strongly discouraged in production)"
1246 ))),
1247 other => Err(crate::error::McpxError::Config(format!(
1248 "{field}: must use https scheme (got {other:?})"
1249 ))),
1250 }
1251}
1252
/// Fluent builder for [`OAuthConfig`], created by `OAuthConfig::builder`.
///
/// Each setter overwrites (or, for `scope` / `role_mapping`, appends to) one
/// field of the wrapped config. `build` returns it without validation; call
/// `OAuthConfig::validate` separately.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthConfigBuilder {
    /// Config under construction.
    inner: OAuthConfig,
}
1263
1264impl OAuthConfigBuilder {
1265 pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
1267 self.inner.scopes = scopes;
1268 self
1269 }
1270
1271 pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
1273 self.inner.scopes.push(ScopeMapping {
1274 scope: scope.into(),
1275 role: role.into(),
1276 });
1277 self
1278 }
1279
1280 pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
1283 self.inner.role_claim = Some(claim.into());
1284 self
1285 }
1286
1287 pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
1289 self.inner.role_mappings = mappings;
1290 self
1291 }
1292
1293 pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
1296 self.inner.role_mappings.push(RoleMapping {
1297 claim_value: claim_value.into(),
1298 role: role.into(),
1299 });
1300 self
1301 }
1302
1303 pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
1306 self.inner.jwks_cache_ttl = ttl.into();
1307 self
1308 }
1309
1310 pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
1313 self.inner.proxy = Some(proxy);
1314 self
1315 }
1316
1317 pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
1319 self.inner.token_exchange = Some(token_exchange);
1320 self
1321 }
1322
1323 pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
1328 self.inner.ca_cert_path = Some(path.into());
1329 self
1330 }
1331
1332 pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
1338 self.inner.allow_http_oauth_urls = allow;
1339 self
1340 }
1341
1342 #[deprecated(since = "1.7.0", note = "use `audience_validation_mode` instead")]
1351 pub const fn strict_audience_validation(mut self, strict: bool) -> Self {
1352 #[allow(
1353 deprecated,
1354 reason = "intentional: deprecated builder forwards to deprecated field"
1355 )]
1356 {
1357 self.inner.strict_audience_validation = strict;
1358 }
1359 self.inner.audience_validation_mode = None;
1360 self
1361 }
1362
1363 pub const fn audience_validation_mode(mut self, mode: AudienceValidationMode) -> Self {
1371 self.inner.audience_validation_mode = Some(mode);
1372 self
1373 }
1374
1375 pub const fn jwks_max_response_bytes(mut self, bytes: u64) -> Self {
1377 self.inner.jwks_max_response_bytes = bytes;
1378 self
1379 }
1380
1381 pub fn ssrf_allowlist(mut self, allowlist: OAuthSsrfAllowlist) -> Self {
1389 self.inner.ssrf_allowlist = Some(allowlist);
1390 self
1391 }
1392
1393 #[must_use]
1395 pub fn build(self) -> OAuthConfig {
1396 self.inner
1397 }
1398}
1399
/// Maps an OAuth scope to a role name.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ScopeMapping {
    /// Scope value to match.
    pub scope: String,
    /// Role granted when the scope matches.
    pub role: String,
}
1409
/// Maps a claim value (presumably of `OAuthConfig::role_claim` -- confirm) to
/// a role name.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleMapping {
    /// Claim value to match.
    pub claim_value: String,
    /// Role granted when the value matches.
    pub role: String,
}
1421
/// Token-exchange settings (`oauth.token_exchange`).
///
/// `OAuthConfig::validate` requires exactly one of `client_secret` /
/// `client_cert`.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenExchangeConfig {
    /// Token endpoint URL; same URL checks as the other OAuth endpoints.
    pub token_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Client secret (kept in a `SecretString`); mutually exclusive with
    /// `client_cert`.
    pub client_secret: Option<secrecy::SecretString>,
    /// mTLS client certificate; mutually exclusive with `client_secret` and
    /// only usable with the `oauth-mtls-client` feature.
    pub client_cert: Option<ClientCertConfig>,
    /// Audience requested for the exchanged token (presumably sent in the
    /// exchange request -- confirm).
    pub audience: String,
}
1459
1460impl TokenExchangeConfig {
1461 #[must_use]
1463 pub fn new(
1464 token_url: String,
1465 client_id: String,
1466 client_secret: Option<secrecy::SecretString>,
1467 client_cert: Option<ClientCertConfig>,
1468 audience: String,
1469 ) -> Self {
1470 Self {
1471 token_url,
1472 client_id,
1473 client_secret,
1474 client_cert,
1475 audience,
1476 }
1477 }
1478}
1479
/// Paths of the PEM client certificate and private key used for mTLS client
/// authentication on token exchange.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ClientCertConfig {
    /// PEM certificate file.
    pub cert_path: PathBuf,
    /// PEM private-key file.
    pub key_path: PathBuf,
}
1494
1495impl ClientCertConfig {
1496 #[must_use]
1500 pub fn new(cert_path: PathBuf, key_path: PathBuf) -> Self {
1501 Self {
1502 cert_path,
1503 key_path,
1504 }
1505 }
1506}
1507
1508#[derive(Debug, Deserialize)]
1510#[non_exhaustive]
1511pub struct ExchangedToken {
1512 pub access_token: String,
1514 pub expires_in: Option<u64>,
1516 pub issued_token_type: Option<String>,
1519}
1520
/// OAuth proxy settings (`oauth.proxy`).
///
/// All URLs go through the same scheme/userinfo/literal-IP checks in
/// `OAuthConfig::validate`. Booleans default to false.
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct OAuthProxyConfig {
    /// Upstream authorization endpoint.
    pub authorize_url: String,
    /// Upstream token endpoint.
    pub token_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Optional client secret (kept in a `SecretString`).
    pub client_secret: Option<secrecy::SecretString>,
    /// Optional upstream introspection endpoint.
    #[serde(default)]
    pub introspection_url: Option<String>,
    /// Optional upstream revocation endpoint.
    #[serde(default)]
    pub revocation_url: Option<String>,
    /// Expose the proxy's admin endpoints. When true, `validate` requires one
    /// of the two flags below to also be true.
    #[serde(default)]
    pub expose_admin_endpoints: bool,
    /// Require authentication on exposed admin endpoints (recommended).
    #[serde(default)]
    pub require_auth_on_admin_endpoints: bool,
    /// Explicit opt-out allowing unauthenticated admin endpoints; only safe
    /// behind an authenticating reverse proxy.
    #[serde(default)]
    pub allow_unauthenticated_admin_endpoints: bool,
}
1583
1584impl OAuthProxyConfig {
1585 pub fn builder(
1593 authorize_url: impl Into<String>,
1594 token_url: impl Into<String>,
1595 client_id: impl Into<String>,
1596 ) -> OAuthProxyConfigBuilder {
1597 OAuthProxyConfigBuilder {
1598 inner: Self {
1599 authorize_url: authorize_url.into(),
1600 token_url: token_url.into(),
1601 client_id: client_id.into(),
1602 ..Self::default()
1603 },
1604 }
1605 }
1606}
1607
/// Fluent builder for [`OAuthProxyConfig`], created by
/// `OAuthProxyConfig::builder`. `build` performs no validation.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthProxyConfigBuilder {
    /// Config under construction.
    inner: OAuthProxyConfig,
}
1618
impl OAuthProxyConfigBuilder {
    /// Sets the upstream client secret forwarded as `client_secret`.
    pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
        self.inner.client_secret = Some(secret);
        self
    }

    /// Sets the upstream introspection endpoint proxied by `/introspect`.
    pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
        self.inner.introspection_url = Some(url.into());
        self
    }

    /// Sets the upstream revocation endpoint proxied by `/revoke`.
    pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
        self.inner.revocation_url = Some(url.into());
        self
    }

    /// Controls whether metadata advertises the introspection/revocation
    /// endpoints.
    pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
        self.inner.expose_admin_endpoints = expose;
        self
    }

    /// Sets `require_auth_on_admin_endpoints` (see the field docs).
    pub const fn require_auth_on_admin_endpoints(mut self, require: bool) -> Self {
        self.inner.require_auth_on_admin_endpoints = require;
        self
    }

    /// Sets `allow_unauthenticated_admin_endpoints` (see the field docs).
    pub const fn allow_unauthenticated_admin_endpoints(mut self, allow: bool) -> Self {
        self.inner.allow_unauthenticated_admin_endpoints = allow;
        self
    }

    /// Returns the assembled configuration; no validation is performed here.
    #[must_use]
    pub fn build(self) -> OAuthProxyConfig {
        self.inner
    }
}
1675
/// Output of `build_key_cache`: `(keys indexed by kid, keys without a kid)`,
/// each paired with the algorithm declared by its JWK.
type JwksKeyCache = (
    HashMap<String, (Algorithm, DecodingKey)>,
    Vec<(Algorithm, DecodingKey)>,
);
1687
/// One successfully fetched JWKS snapshot, plus the bookkeeping needed to
/// decide when it is stale.
struct CachedKeys {
    /// Keys that carried a `kid`, indexed by that kid.
    keys: HashMap<String, (Algorithm, DecodingKey)>,
    /// Keys with no `kid`; matched by algorithm only.
    unnamed_keys: Vec<(Algorithm, DecodingKey)>,
    /// When this snapshot was stored.
    fetched_at: Instant,
    /// How long the snapshot counts as fresh.
    ttl: Duration,
}
1696
1697impl CachedKeys {
1698 fn is_expired(&self) -> bool {
1699 self.fetched_at.elapsed() >= self.ttl
1700 }
1701}
1702
/// Validates OAuth JWT bearer tokens against keys fetched from a JWKS
/// endpoint, then maps their claims to an [`AuthIdentity`].
///
/// Keys are cached for `ttl`. Refetches are serialized by `refresh_lock` and
/// rate-limited by [`JWKS_REFRESH_COOLDOWN`]. All outbound HTTP goes through
/// a client with an SSRF-screening DNS resolver and a restrictive redirect
/// policy (both built in [`JwksCache::new`]).
#[allow(
    missing_debug_implementations,
    reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
)]
#[non_exhaustive]
pub struct JwksCache {
    /// JWKS endpoint fetched by `fetch_jwks`.
    jwks_uri: String,
    /// Freshness window applied to each fetched snapshot.
    ttl: Duration,
    /// Upper bound on the number of keys accepted from one JWKS document.
    max_jwks_keys: usize,
    /// Upper bound on the JWKS response body size.
    max_response_bytes: u64,
    /// Whether plain `http` JWKS/redirect targets are permitted.
    allow_http: bool,
    /// Current key snapshot; `None` until the first successful refresh.
    inner: RwLock<Option<CachedKeys>>,
    /// HTTP client used for JWKS fetches.
    http: reqwest::Client,
    /// Issuer/exp/nbf settings; cloned per token and pinned to its algorithm.
    validation_template: Validation,
    /// Audience required in `aud` (or in `azp`, depending on `audience_mode`).
    expected_audience: String,
    /// How an `azp`-only audience match is treated.
    audience_mode: AudienceValidationMode,
    /// Set once the `azp` fallback warning has been logged.
    azp_fallback_warned: AtomicBool,
    /// Scope→role mappings used when `role_claim` is unset.
    scopes: Vec<ScopeMapping>,
    /// Optional claim path from which roles are read.
    role_claim: Option<String>,
    /// Claim-value→role mappings used with `role_claim`.
    role_mappings: Vec<RoleMapping>,
    /// Time of the last refresh attempt (successful or not), for the cooldown.
    last_refresh_attempt: RwLock<Option<Instant>>,
    /// Serializes refreshes so concurrent misses trigger one fetch.
    refresh_lock: tokio::sync::Mutex<()>,
    /// Compiled SSRF allowlist shared with the resolver and redirect policy.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Test-only loopback bypass flag, shared with the DNS resolver.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
1751
/// Minimum interval between JWKS refresh attempts; misses inside this
/// window are answered from whatever snapshot is cached.
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_secs(10);
1754
/// JWT header algorithms accepted for verification. Only asymmetric
/// algorithms are listed; any HMAC (`HS*`) header is rejected before key
/// lookup.
const ACCEPTED_ALGS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
    Algorithm::EdDSA,
];
1767
/// Coarse reason a bearer JWT was rejected by
/// [`JwksCache::validate_token_with_reason`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
    /// Decoding failed with `ExpiredSignature` (the `exp` claim has passed).
    Expired,
    /// Any other failure: malformed header, unsupported algorithm, no
    /// matching key, bad signature/claims, audience mismatch, or no role.
    Invalid,
}
1777
1778impl JwksCache {
1779 pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
1793 rustls::crypto::ring::default_provider()
1796 .install_default()
1797 .ok();
1798 jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
1799 .install_default()
1800 .ok();
1801
1802 let ttl = humantime::parse_duration(&config.jwks_cache_ttl)
1803 .expect("jwks_cache_ttl validated by OAuthConfig::validate");
1804
1805 let mut validation = Validation::new(Algorithm::RS256);
1806 validation.validate_aud = false;
1818 validation.set_issuer(&[&config.issuer]);
1819 validation.set_required_spec_claims(&["exp", "iss"]);
1820 validation.validate_exp = true;
1821 validation.validate_nbf = true;
1822
1823 let allow_http = config.allow_http_oauth_urls;
1824
1825 let allowlist = match config.ssrf_allowlist.as_ref() {
1828 Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
1829 Box::<dyn std::error::Error + Send + Sync>::from(format!(
1830 "oauth.ssrf_allowlist: {e}"
1831 ))
1832 })?),
1833 None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
1834 };
1835 let redirect_allowlist = Arc::clone(&allowlist);
1836
1837 #[cfg(any(test, feature = "test-helpers"))]
1839 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
1840 Arc::new(AtomicBool::new(false));
1841 #[cfg(not(any(test, feature = "test-helpers")))]
1842 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();
1843
1844 let resolver: Arc<dyn reqwest::dns::Resolve> =
1845 Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
1846 Arc::clone(&allowlist),
1847 #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
1848 test_bypass.clone(),
1849 ));
1850
1851 let mut http_builder = reqwest::Client::builder()
1852 .no_proxy()
1854 .dns_resolver(Arc::clone(&resolver))
1855 .timeout(Duration::from_secs(10))
1856 .connect_timeout(Duration::from_secs(3))
1857 .redirect(reqwest::redirect::Policy::custom(move |attempt| {
1858 match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
1868 Ok(()) => attempt.follow(),
1869 Err(reason) => {
1870 tracing::warn!(
1871 reason = %reason,
1872 target = %attempt.url(),
1873 "oauth redirect rejected"
1874 );
1875 attempt.error(reason)
1876 }
1877 }
1878 }));
1879
1880 if let Some(ref ca_path) = config.ca_cert_path {
1881 let pem = std::fs::read(ca_path)?;
1887 let cert = reqwest::tls::Certificate::from_pem(&pem)?;
1888 http_builder = http_builder.add_root_certificate(cert);
1889 }
1890
1891 let http = http_builder.build()?;
1892
1893 Ok(Self {
1894 jwks_uri: config.jwks_uri.clone(),
1895 ttl,
1896 max_jwks_keys: config.max_jwks_keys,
1897 max_response_bytes: config.jwks_max_response_bytes,
1898 allow_http,
1899 inner: RwLock::new(None),
1900 http,
1901 validation_template: validation,
1902 expected_audience: config.audience.clone(),
1903 audience_mode: config.effective_audience_validation_mode(),
1904 azp_fallback_warned: AtomicBool::new(false),
1905 scopes: config.scopes.clone(),
1906 role_claim: config.role_claim.clone(),
1907 role_mappings: config.role_mappings.clone(),
1908 last_refresh_attempt: RwLock::new(None),
1909 refresh_lock: tokio::sync::Mutex::new(()),
1910 allowlist,
1911 #[cfg(any(test, feature = "test-helpers"))]
1912 test_allow_loopback_ssrf: test_bypass,
1913 })
1914 }
1915
    /// Test-only: disables SSRF screening so JWKS can be served from a
    /// loopback mock server.
    ///
    /// The flag is the same `Arc` the DNS resolver received in `new`, so this
    /// affects both the resolver and `fetch_jwks`'s pre-flight screening.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_allow_loopback_ssrf(self) -> Self {
        self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
        self
    }
1928
1929 pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
1931 self.validate_token_with_reason(token).await.ok()
1932 }
1933
1934 pub async fn validate_token_with_reason(
1941 &self,
1942 token: &str,
1943 ) -> Result<AuthIdentity, JwtValidationFailure> {
1944 let claims = self.decode_claims(token).await?;
1945
1946 self.check_audience(&claims)?;
1947 let role = self.resolve_role(&claims)?;
1948
1949 let sub = claims.sub;
1952 let name = claims
1953 .extra
1954 .get("preferred_username")
1955 .and_then(|v| v.as_str())
1956 .map(String::from)
1957 .or_else(|| sub.clone())
1958 .or(claims.azp)
1959 .or(claims.client_id)
1960 .unwrap_or_else(|| "oauth-client".into());
1961
1962 Ok(AuthIdentity {
1963 name,
1964 role,
1965 method: AuthMethod::OAuthJwt,
1966 raw_token: None,
1967 sub,
1968 })
1969 }
1970
1971 async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
1983 let (key, alg) = self.select_jwks_key(token).await?;
1984
1985 let mut validation = self.validation_template.clone();
1989 validation.algorithms = vec![alg];
1990
1991 let token_owned = token.to_owned();
1994 let join =
1995 tokio::task::spawn_blocking(move || decode::<Claims>(&token_owned, &key, &validation))
1996 .await;
1997
1998 let decode_result = match join {
1999 Ok(r) => r,
2000 Err(join_err) => {
2001 core::hint::cold_path();
2002 tracing::error!(
2003 error = %join_err,
2004 "JWT decode task panicked or was cancelled"
2005 );
2006 return Err(JwtValidationFailure::Invalid);
2007 }
2008 };
2009
2010 decode_result.map(|td| td.claims).map_err(|e| {
2011 core::hint::cold_path();
2012 let failure = if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
2013 JwtValidationFailure::Expired
2014 } else {
2015 JwtValidationFailure::Invalid
2016 };
2017 tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
2018 failure
2019 })
2020 }
2021
2022 #[allow(clippy::cognitive_complexity)]
2031 async fn select_jwks_key(
2032 &self,
2033 token: &str,
2034 ) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
2035 let Ok(header) = decode_header(token) else {
2036 core::hint::cold_path();
2037 tracing::debug!("JWT header decode failed");
2038 return Err(JwtValidationFailure::Invalid);
2039 };
2040 let kid = header.kid.as_deref();
2041 tracing::debug!(alg = ?header.alg, kid = kid.unwrap_or("-"), "JWT header decoded");
2042
2043 if !ACCEPTED_ALGS.contains(&header.alg) {
2044 core::hint::cold_path();
2045 tracing::debug!(alg = ?header.alg, "JWT algorithm not accepted");
2046 return Err(JwtValidationFailure::Invalid);
2047 }
2048
2049 let Some(key) = self.find_key(kid, header.alg).await else {
2050 core::hint::cold_path();
2051 tracing::debug!(kid = kid.unwrap_or("-"), alg = ?header.alg, "no matching JWKS key found");
2052 return Err(JwtValidationFailure::Invalid);
2053 };
2054
2055 Ok((key, header.alg))
2056 }
2057
    /// Accepts the token when `aud` contains the expected audience.
    ///
    /// Otherwise, if `azp` equals the expected audience, acceptance depends
    /// on `audience_mode`:
    /// - `Permissive`: accept silently.
    /// - `Warn`: accept, logging a deprecation warning only the first time.
    /// - `Strict`: reject.
    ///
    /// All other cases are rejected as `Invalid`.
    fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
        if claims.aud.contains(&self.expected_audience) {
            return Ok(());
        }
        let azp_match = claims
            .azp
            .as_deref()
            .is_some_and(|azp| azp == self.expected_audience);
        if azp_match {
            match self.audience_mode {
                AudienceValidationMode::Permissive => return Ok(()),
                AudienceValidationMode::Warn => {
                    // `swap` returns the previous value, so only the first
                    // caller to flip the flag emits the warning.
                    if !self.azp_fallback_warned.swap(true, Ordering::Relaxed) {
                        tracing::warn!(
                            expected = %self.expected_audience,
                            azp = ?claims.azp,
                            "JWT accepted via deprecated `azp`-only audience fallback. \
                             Configure your IdP to populate `aud`, or set \
                             `audience_validation_mode = \"strict\"` once tokens carry `aud` correctly. \
                             To silence this warning without changing acceptance, \
                             set `audience_validation_mode = \"permissive\"`. \
                             This warning logs once per process."
                        );
                    }
                    return Ok(());
                }
                // Strict: fall through to the rejection below.
                AudienceValidationMode::Strict => {}
            }
        }
        core::hint::cold_path();
        tracing::debug!(
            aud = ?claims.aud.0,
            azp = ?claims.azp,
            expected = %self.expected_audience,
            mode = ?self.audience_mode,
            "JWT rejected: audience mismatch"
        );
        Err(JwtValidationFailure::Invalid)
    }
2105
2106 fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
2112 if let Some(ref claim_path) = self.role_claim {
2113 let owned_first_class: Vec<String> = first_class_claim_values(claims, claim_path);
2114 let mut values: Vec<&str> = owned_first_class.iter().map(String::as_str).collect();
2115 values.extend(resolve_claim_path(&claims.extra, claim_path));
2116 return self
2117 .role_mappings
2118 .iter()
2119 .find(|m| values.contains(&m.claim_value.as_str()))
2120 .map(|m| m.role.clone())
2121 .ok_or(JwtValidationFailure::Invalid);
2122 }
2123
2124 let token_scopes: Vec<&str> = claims
2125 .scope
2126 .as_deref()
2127 .unwrap_or("")
2128 .split_whitespace()
2129 .collect();
2130
2131 self.scopes
2132 .iter()
2133 .find(|m| token_scopes.contains(&m.scope.as_str()))
2134 .map(|m| m.role.clone())
2135 .ok_or(JwtValidationFailure::Invalid)
2136 }
2137
2138 async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2141 {
2143 let guard = self.inner.read().await;
2144 if let Some(cached) = guard.as_ref()
2145 && !cached.is_expired()
2146 && let Some(key) = lookup_key(cached, kid, alg)
2147 {
2148 return Some(key);
2149 }
2150 }
2151
2152 self.refresh_with_cooldown().await;
2154
2155 let guard = self.inner.read().await;
2156 guard
2157 .as_ref()
2158 .and_then(|cached| lookup_key(cached, kid, alg))
2159 }
2160
2161 async fn refresh_with_cooldown(&self) {
2166 let _guard = self.refresh_lock.lock().await;
2168
2169 {
2171 let last = self.last_refresh_attempt.read().await;
2172 if let Some(ts) = *last
2173 && ts.elapsed() < JWKS_REFRESH_COOLDOWN
2174 {
2175 tracing::debug!(
2176 elapsed_ms = ts.elapsed().as_millis(),
2177 cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
2178 "JWKS refresh skipped (cooldown active)"
2179 );
2180 return;
2181 }
2182 }
2183
2184 {
2187 let mut last = self.last_refresh_attempt.write().await;
2188 *last = Some(Instant::now());
2189 }
2190
2191 let _ = self.refresh_inner().await;
2193 }
2194
2195 async fn refresh_inner(&self) -> Result<(), String> {
2200 let Some(jwks) = self.fetch_jwks().await else {
2201 return Ok(());
2202 };
2203 let (keys, unnamed_keys) = match build_key_cache(&jwks, self.max_jwks_keys) {
2204 Ok(cache) => cache,
2205 Err(msg) => {
2206 tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
2207 return Err(msg);
2208 }
2209 };
2210
2211 tracing::debug!(
2212 named = keys.len(),
2213 unnamed = unnamed_keys.len(),
2214 "JWKS refreshed"
2215 );
2216
2217 let mut guard = self.inner.write().await;
2218 *guard = Some(CachedKeys {
2219 keys,
2220 unnamed_keys,
2221 fetched_at: Instant::now(),
2222 ttl: self.ttl,
2223 });
2224 Ok(())
2225 }
2226
2227 #[allow(
2229 clippy::cognitive_complexity,
2230 reason = "screening, bounded streaming, and parse logging are intentionally kept in one fetch path"
2231 )]
2232 async fn fetch_jwks(&self) -> Option<JwkSet> {
2233 #[cfg(any(test, feature = "test-helpers"))]
2234 let screening = if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
2235 screen_oauth_target_with_test_override(
2236 &self.jwks_uri,
2237 self.allow_http,
2238 &self.allowlist,
2239 true,
2240 )
2241 .await
2242 } else {
2243 screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await
2244 };
2245 #[cfg(not(any(test, feature = "test-helpers")))]
2246 let screening = screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await;
2247
2248 if let Err(error) = screening {
2249 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to screen JWKS target");
2250 return None;
2251 }
2252
2253 let mut resp = match self.http.get(&self.jwks_uri).send().await {
2254 Ok(resp) => resp,
2255 Err(e) => {
2256 tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
2257 return None;
2258 }
2259 };
2260
2261 let initial_capacity =
2262 usize::try_from(self.max_response_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
2263 let mut body = Vec::with_capacity(initial_capacity);
2264 while let Some(chunk) = match resp.chunk().await {
2265 Ok(chunk) => chunk,
2266 Err(error) => {
2267 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to read JWKS response");
2268 return None;
2269 }
2270 } {
2271 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
2272 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
2273 if body_len.saturating_add(chunk_len) > self.max_response_bytes {
2274 tracing::warn!(
2275 uri = %self.jwks_uri,
2276 max_bytes = self.max_response_bytes,
2277 "JWKS response exceeded configured size cap"
2278 );
2279 return None;
2280 }
2281 body.extend_from_slice(&chunk);
2282 }
2283
2284 match serde_json::from_slice::<JwkSet>(&body) {
2285 Ok(jwks) => Some(jwks),
2286 Err(error) => {
2287 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to parse JWKS");
2288 None
2289 }
2290 }
2291 }
2292
2293 #[cfg(any(test, feature = "test-helpers"))]
2296 #[doc(hidden)]
2297 pub async fn __test_refresh_now(&self) -> Result<(), String> {
2298 let jwks = self
2299 .fetch_jwks()
2300 .await
2301 .ok_or_else(|| "failed to fetch or parse JWKS".to_owned())?;
2302 let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;
2303 let mut guard = self.inner.write().await;
2304 *guard = Some(CachedKeys {
2305 keys,
2306 unnamed_keys,
2307 fetched_at: Instant::now(),
2308 ttl: self.ttl,
2309 });
2310 Ok(())
2311 }
2312
2313 #[cfg(any(test, feature = "test-helpers"))]
2316 #[doc(hidden)]
2317 pub async fn __test_has_kid(&self, kid: &str) -> bool {
2318 let guard = self.inner.read().await;
2319 guard
2320 .as_ref()
2321 .is_some_and(|cache| cache.keys.contains_key(kid))
2322 }
2323}
2324
2325fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
2327 if jwks.keys.len() > max_keys {
2328 return Err(format!(
2329 "jwks_key_count_exceeds_cap: got {} keys, max is {}",
2330 jwks.keys.len(),
2331 max_keys
2332 ));
2333 }
2334 let mut keys = HashMap::new();
2335 let mut unnamed_keys = Vec::new();
2336 for jwk in &jwks.keys {
2337 let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
2338 continue;
2339 };
2340 let Some(alg) = jwk_algorithm(jwk) else {
2341 continue;
2342 };
2343 if let Some(ref kid) = jwk.common.key_id {
2344 keys.insert(kid.clone(), (alg, decoding_key));
2345 } else {
2346 unnamed_keys.push((alg, decoding_key));
2347 }
2348 }
2349 Ok((keys, unnamed_keys))
2350}
2351
2352fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2354 if let Some(kid) = kid
2355 && let Some((cached_alg, key)) = cached.keys.get(kid)
2356 && *cached_alg == alg
2357 {
2358 return Some(key.clone());
2359 }
2360 cached
2362 .unnamed_keys
2363 .iter()
2364 .find(|(a, _)| *a == alg)
2365 .map(|(_, k)| k.clone())
2366}
2367
2368#[allow(clippy::wildcard_enum_match_arm)]
2370fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
2371 jwk.common.key_algorithm.and_then(|ka| match ka {
2372 jsonwebtoken::jwk::KeyAlgorithm::RS256 => Some(Algorithm::RS256),
2373 jsonwebtoken::jwk::KeyAlgorithm::RS384 => Some(Algorithm::RS384),
2374 jsonwebtoken::jwk::KeyAlgorithm::RS512 => Some(Algorithm::RS512),
2375 jsonwebtoken::jwk::KeyAlgorithm::ES256 => Some(Algorithm::ES256),
2376 jsonwebtoken::jwk::KeyAlgorithm::ES384 => Some(Algorithm::ES384),
2377 jsonwebtoken::jwk::KeyAlgorithm::PS256 => Some(Algorithm::PS256),
2378 jsonwebtoken::jwk::KeyAlgorithm::PS384 => Some(Algorithm::PS384),
2379 jsonwebtoken::jwk::KeyAlgorithm::PS512 => Some(Algorithm::PS512),
2380 jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
2381 _ => None,
2382 })
2383}
2384
2385fn first_class_claim_values(claims: &Claims, path: &str) -> Vec<String> {
2406 match path {
2407 "sub" => claims.sub.iter().cloned().collect(),
2408 "azp" => claims.azp.iter().cloned().collect(),
2409 "client_id" => claims.client_id.iter().cloned().collect(),
2410 "aud" => claims.aud.0.clone(),
2411 "scope" => claims
2412 .scope
2413 .as_deref()
2414 .unwrap_or("")
2415 .split_whitespace()
2416 .map(str::to_owned)
2417 .collect(),
2418 _ => Vec::new(),
2419 }
2420}
2421
2422fn resolve_claim_path<'a>(
2432 extra: &'a HashMap<String, serde_json::Value>,
2433 path: &str,
2434) -> Vec<&'a str> {
2435 let mut segments = path.split('.');
2436 let Some(first) = segments.next() else {
2437 return Vec::new();
2438 };
2439
2440 let mut current: Option<&serde_json::Value> = extra.get(first);
2441
2442 for segment in segments {
2443 current = current.and_then(|v| v.get(segment));
2444 }
2445
2446 match current {
2447 Some(serde_json::Value::String(s)) => s.split_whitespace().collect(),
2448 Some(serde_json::Value::Array(arr)) => arr.iter().filter_map(|v| v.as_str()).collect(),
2449 _ => Vec::new(),
2450 }
2451}
2452
/// JWT payload fields this module inspects; every other claim is kept in
/// `extra`.
#[derive(Debug, Deserialize)]
struct Claims {
    /// Subject, if present.
    sub: Option<String>,
    /// Audience(s); a single string or an array, empty when the claim is
    /// absent.
    #[serde(default)]
    aud: OneOrMany,
    /// Authorized party; consulted by the audience fallback and name chain.
    azp: Option<String>,
    /// Client identifier claim, if the issuer emits one.
    client_id: Option<String>,
    /// Space-separated scope string, if present.
    scope: Option<String>,
    /// All remaining claims, used for `preferred_username` and role paths.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
2476
/// A claim that may be encoded either as one string or as an array of
/// strings (as `aud` is in JWTs); always stored as a list.
#[derive(Debug, Default)]
struct OneOrMany(Vec<String>);

impl OneOrMany {
    /// Returns `true` if any stored value equals `value` exactly.
    fn contains(&self, value: &str) -> bool {
        self.0
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == value)
    }
}
2486
2487impl<'de> Deserialize<'de> for OneOrMany {
2488 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
2489 use serde::de;
2490
2491 struct Visitor;
2492 impl<'de> de::Visitor<'de> for Visitor {
2493 type Value = OneOrMany;
2494 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2495 f.write_str("a string or array of strings")
2496 }
2497 fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
2498 Ok(OneOrMany(vec![v.to_owned()]))
2499 }
2500 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
2501 let mut v = Vec::new();
2502 while let Some(s) = seq.next_element::<String>()? {
2503 v.push(s);
2504 }
2505 Ok(OneOrMany(v))
2506 }
2507 }
2508 deserializer.deserialize_any(Visitor)
2509 }
2510}
2511
2512#[must_use]
2519pub fn looks_like_jwt(token: &str) -> bool {
2520 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2521
2522 let mut parts = token.splitn(4, '.');
2523 let Some(header_b64) = parts.next() else {
2524 return false;
2525 };
2526 if parts.next().is_none() || parts.next().is_none() || parts.next().is_some() {
2528 return false;
2529 }
2530 let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
2532 return false;
2533 };
2534 let Ok(header) = serde_json::from_slice::<serde_json::Value>(&header_bytes) else {
2536 return false;
2537 };
2538 header.get("alg").is_some()
2539}
2540
2541#[must_use]
2551pub fn protected_resource_metadata(
2552 resource_url: &str,
2553 server_url: &str,
2554 config: &OAuthConfig,
2555) -> serde_json::Value {
2556 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2561 let auth_server = server_url;
2562 serde_json::json!({
2563 "resource": resource_url,
2564 "authorization_servers": [auth_server],
2565 "scopes_supported": scopes,
2566 "bearer_methods_supported": ["header"]
2567 })
2568}
2569
2570#[must_use]
2575pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
2576 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2577 let mut meta = serde_json::json!({
2578 "issuer": &config.issuer,
2579 "authorization_endpoint": format!("{server_url}/authorize"),
2580 "token_endpoint": format!("{server_url}/token"),
2581 "registration_endpoint": format!("{server_url}/register"),
2582 "response_types_supported": ["code"],
2583 "grant_types_supported": ["authorization_code", "refresh_token"],
2584 "code_challenge_methods_supported": ["S256"],
2585 "scopes_supported": scopes,
2586 "token_endpoint_auth_methods_supported": ["none"],
2587 });
2588 if let Some(proxy) = &config.proxy
2589 && proxy.expose_admin_endpoints
2590 && let Some(obj) = meta.as_object_mut()
2591 {
2592 if proxy.introspection_url.is_some() {
2593 obj.insert(
2594 "introspection_endpoint".into(),
2595 serde_json::Value::String(format!("{server_url}/introspect")),
2596 );
2597 }
2598 if proxy.revocation_url.is_some() {
2599 obj.insert(
2600 "revocation_endpoint".into(),
2601 serde_json::Value::String(format!("{server_url}/revoke")),
2602 );
2603 }
2604 if proxy.require_auth_on_admin_endpoints {
2605 obj.insert(
2606 "introspection_endpoint_auth_methods_supported".into(),
2607 serde_json::json!(["bearer"]),
2608 );
2609 obj.insert(
2610 "revocation_endpoint_auth_methods_supported".into(),
2611 serde_json::json!(["bearer"]),
2612 );
2613 }
2614 }
2615 meta
2616}
2617
2618#[must_use]
2631pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
2632 use axum::{
2633 http::{StatusCode, header},
2634 response::IntoResponse,
2635 };
2636
2637 let upstream_query = replace_client_id(query, &proxy.client_id);
2639 let redirect_url = format!("{}?{upstream_query}", proxy.authorize_url);
2640
2641 (StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
2642}
2643
2644pub async fn handle_token(
2650 http: &OauthHttpClient,
2651 proxy: &OAuthProxyConfig,
2652 body: &str,
2653) -> axum::response::Response {
2654 use axum::{
2655 http::{StatusCode, header},
2656 response::IntoResponse,
2657 };
2658
2659 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2661
2662 if let Some(ref secret) = proxy.client_secret {
2664 use std::fmt::Write;
2665
2666 use secrecy::ExposeSecret;
2667 let _ = write!(
2668 upstream_body,
2669 "&client_secret={}",
2670 urlencoding::encode(secret.expose_secret())
2671 );
2672 }
2673
2674 let result = http
2675 .send_screened(
2676 &proxy.token_url,
2677 http.inner
2678 .post(&proxy.token_url)
2679 .header("Content-Type", "application/x-www-form-urlencoded")
2680 .body(upstream_body),
2681 )
2682 .await;
2683
2684 match result {
2685 Ok(resp) => {
2686 let status =
2687 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2688 let body_bytes = resp.bytes().await.unwrap_or_default();
2689 (
2690 status,
2691 [(header::CONTENT_TYPE, "application/json")],
2692 body_bytes,
2693 )
2694 .into_response()
2695 }
2696 Err(e) => {
2697 tracing::error!(error = %e, "OAuth token proxy request failed");
2698 (
2699 StatusCode::BAD_GATEWAY,
2700 [(header::CONTENT_TYPE, "application/json")],
2701 "{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
2702 )
2703 .into_response()
2704 }
2705 }
2706}
2707
2708#[must_use]
2715pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
2716 let mut resp = serde_json::json!({
2717 "client_id": proxy.client_id,
2718 "token_endpoint_auth_method": "none",
2719 });
2720 if let Some(uris) = body.get("redirect_uris")
2721 && let Some(obj) = resp.as_object_mut()
2722 {
2723 obj.insert("redirect_uris".into(), uris.clone());
2724 }
2725 if let Some(name) = body.get("client_name")
2726 && let Some(obj) = resp.as_object_mut()
2727 {
2728 obj.insert("client_name".into(), name.clone());
2729 }
2730 resp
2731}
2732
2733pub async fn handle_introspect(
2739 http: &OauthHttpClient,
2740 proxy: &OAuthProxyConfig,
2741 body: &str,
2742) -> axum::response::Response {
2743 let Some(ref url) = proxy.introspection_url else {
2744 return oauth_error_response(
2745 axum::http::StatusCode::NOT_FOUND,
2746 "not_supported",
2747 "introspection endpoint is not configured",
2748 );
2749 };
2750 proxy_oauth_admin_request(http, proxy, url, body).await
2751}
2752
2753pub async fn handle_revoke(
2760 http: &OauthHttpClient,
2761 proxy: &OAuthProxyConfig,
2762 body: &str,
2763) -> axum::response::Response {
2764 let Some(ref url) = proxy.revocation_url else {
2765 return oauth_error_response(
2766 axum::http::StatusCode::NOT_FOUND,
2767 "not_supported",
2768 "revocation endpoint is not configured",
2769 );
2770 };
2771 proxy_oauth_admin_request(http, proxy, url, body).await
2772}
2773
2774async fn proxy_oauth_admin_request(
2778 http: &OauthHttpClient,
2779 proxy: &OAuthProxyConfig,
2780 upstream_url: &str,
2781 body: &str,
2782) -> axum::response::Response {
2783 use axum::{
2784 http::{StatusCode, header},
2785 response::IntoResponse,
2786 };
2787
2788 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2789 if let Some(ref secret) = proxy.client_secret {
2790 use std::fmt::Write;
2791
2792 use secrecy::ExposeSecret;
2793 let _ = write!(
2794 upstream_body,
2795 "&client_secret={}",
2796 urlencoding::encode(secret.expose_secret())
2797 );
2798 }
2799
2800 let result = http
2801 .send_screened(
2802 upstream_url,
2803 http.inner
2804 .post(upstream_url)
2805 .header("Content-Type", "application/x-www-form-urlencoded")
2806 .body(upstream_body),
2807 )
2808 .await;
2809
2810 match result {
2811 Ok(resp) => {
2812 let status =
2813 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2814 let content_type = resp
2815 .headers()
2816 .get(header::CONTENT_TYPE)
2817 .and_then(|v| v.to_str().ok())
2818 .unwrap_or("application/json")
2819 .to_owned();
2820 let body_bytes = resp.bytes().await.unwrap_or_default();
2821 (status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
2822 }
2823 Err(e) => {
2824 tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
2825 oauth_error_response(
2826 StatusCode::BAD_GATEWAY,
2827 "server_error",
2828 "upstream endpoint unreachable",
2829 )
2830 }
2831 }
2832}
2833
2834fn oauth_error_response(
2835 status: axum::http::StatusCode,
2836 error: &str,
2837 description: &str,
2838) -> axum::response::Response {
2839 use axum::{http::header, response::IntoResponse};
2840 let body = serde_json::json!({
2841 "error": error,
2842 "error_description": description,
2843 });
2844 (
2845 status,
2846 [(header::CONTENT_TYPE, "application/json")],
2847 body.to_string(),
2848 )
2849 .into_response()
2850}
2851
/// Error body returned by an upstream authorization server on a failed
/// token request. It is only logged; the code sent to clients passes
/// through `sanitize_oauth_error_code`.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Upstream error code (untrusted).
    error: String,
    /// Optional human-readable detail (untrusted).
    error_description: Option<String>,
}
2862
/// Maps an untrusted upstream OAuth error code onto a fixed allow-list so
/// that only well-known codes reach clients; anything else becomes
/// `server_error`.
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
    const PASSTHROUGH: [&str; 8] = [
        "invalid_request",
        "invalid_client",
        "invalid_grant",
        "unauthorized_client",
        "unsupported_grant_type",
        "invalid_scope",
        "temporarily_unavailable",
        "invalid_target",
    ];
    PASSTHROUGH
        .iter()
        .copied()
        .find(|&known| known == raw)
        .unwrap_or("server_error")
}
2885
2886pub async fn exchange_token(
2898 http: &OauthHttpClient,
2899 config: &TokenExchangeConfig,
2900 subject_token: &str,
2901) -> Result<ExchangedToken, crate::error::McpxError> {
2902 use secrecy::ExposeSecret;
2903
2904 let client = http.client_for(config);
2905 let mut req = client
2906 .post(&config.token_url)
2907 .header("Content-Type", "application/x-www-form-urlencoded")
2908 .header("Accept", "application/json");
2909
2910 if config.client_cert.is_none()
2919 && let Some(ref secret) = config.client_secret
2920 {
2921 use base64::Engine;
2922 let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
2923 "{}:{}",
2924 urlencoding::encode(&config.client_id),
2925 urlencoding::encode(secret.expose_secret()),
2926 ));
2927 req = req.header("Authorization", format!("Basic {credentials}"));
2928 }
2929
2930 let form_body = build_exchange_form(config, subject_token);
2931
2932 let resp = http
2933 .send_screened(&config.token_url, req.body(form_body))
2934 .await
2935 .map_err(|e| {
2936 tracing::error!(error = %e, "token exchange request failed");
2937 crate::error::McpxError::Auth("server_error".into())
2939 })?;
2940
2941 let status = resp.status();
2942 let body_bytes = resp.bytes().await.map_err(|e| {
2943 tracing::error!(error = %e, "failed to read token exchange response");
2944 crate::error::McpxError::Auth("server_error".into())
2945 })?;
2946
2947 if !status.is_success() {
2948 core::hint::cold_path();
2949 let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
2952 let short_code = parsed
2953 .as_ref()
2954 .map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
2955 if let Some(ref e) = parsed {
2956 tracing::warn!(
2957 status = %status,
2958 upstream_error = %e.error,
2959 upstream_error_description = e.error_description.as_deref().unwrap_or(""),
2960 client_code = %short_code,
2961 "token exchange rejected by authorization server",
2962 );
2963 } else {
2964 tracing::warn!(
2965 status = %status,
2966 client_code = %short_code,
2967 "token exchange rejected (unparseable upstream body)",
2968 );
2969 }
2970 return Err(crate::error::McpxError::Auth(short_code.into()));
2971 }
2972
2973 let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
2974 tracing::error!(error = %e, "failed to parse token exchange response");
2975 crate::error::McpxError::Auth("server_error".into())
2978 })?;
2979
2980 log_exchanged_token(&exchanged);
2981
2982 Ok(exchanged)
2983}
2984
2985fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
2988 let body = format!(
2989 "grant_type={}&subject_token={}&subject_token_type={}&requested_token_type={}&audience={}",
2990 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
2991 urlencoding::encode(subject_token),
2992 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
2993 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
2994 urlencoding::encode(&config.audience),
2995 );
2996 if config.client_secret.is_none() {
2997 format!(
2998 "{body}&client_id={}",
2999 urlencoding::encode(&config.client_id)
3000 )
3001 } else {
3002 body
3003 }
3004}
3005
3006fn log_exchanged_token(exchanged: &ExchangedToken) {
3009 use base64::Engine;
3010
3011 if !looks_like_jwt(&exchanged.access_token) {
3012 tracing::debug!(
3013 token_len = exchanged.access_token.len(),
3014 issued_token_type = ?exchanged.issued_token_type,
3015 expires_in = exchanged.expires_in,
3016 "exchanged token (opaque)",
3017 );
3018 return;
3019 }
3020 let Some(payload) = exchanged.access_token.split('.').nth(1) else {
3021 return;
3022 };
3023 let Ok(decoded) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload) else {
3024 return;
3025 };
3026 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&decoded) else {
3027 return;
3028 };
3029 tracing::debug!(
3030 sub = ?claims.get("sub"),
3031 aud = ?claims.get("aud"),
3032 azp = ?claims.get("azp"),
3033 iss = ?claims.get("iss"),
3034 expires_in = exchanged.expires_in,
3035 "exchanged token claims (JWT)",
3036 );
3037}
3038
3039fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
3041 let encoded_id = urlencoding::encode(upstream_client_id);
3042 let mut parts: Vec<String> = params
3043 .split('&')
3044 .filter(|p| !p.starts_with("client_id="))
3045 .map(String::from)
3046 .collect();
3047 parts.push(format!("client_id={encoded_id}"));
3048 parts.join("&")
3049}
3050
3051#[cfg(test)]
3052mod tests {
3053 use std::sync::Arc;
3054
3055 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
3056
3057 use super::*;
3058
3059 #[test]
3060 fn looks_like_jwt_valid() {
3061 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
3063 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3064 let token = format!("{header}.{payload}.signature");
3065 assert!(looks_like_jwt(&token));
3066 }
3067
3068 #[test]
3069 fn looks_like_jwt_rejects_opaque_token() {
3070 assert!(!looks_like_jwt("dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4"));
3071 }
3072
3073 #[test]
3074 fn looks_like_jwt_rejects_two_segments() {
3075 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
3076 let token = format!("{header}.payload");
3077 assert!(!looks_like_jwt(&token));
3078 }
3079
3080 #[test]
3081 fn looks_like_jwt_rejects_four_segments() {
3082 assert!(!looks_like_jwt("a.b.c.d"));
3083 }
3084
3085 #[test]
3086 fn looks_like_jwt_rejects_no_alg() {
3087 let header = URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}");
3088 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3089 let token = format!("{header}.{payload}.sig");
3090 assert!(!looks_like_jwt(&token));
3091 }
3092
3093 #[test]
3094 fn protected_resource_metadata_shape() {
3095 let config = OAuthConfig {
3096 issuer: "https://auth.example.com".into(),
3097 audience: "https://mcp.example.com/mcp".into(),
3098 jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
3099 scopes: vec![
3100 ScopeMapping {
3101 scope: "mcp:read".into(),
3102 role: "viewer".into(),
3103 },
3104 ScopeMapping {
3105 scope: "mcp:admin".into(),
3106 role: "ops".into(),
3107 },
3108 ],
3109 role_claim: None,
3110 role_mappings: vec![],
3111 jwks_cache_ttl: "10m".into(),
3112 proxy: None,
3113 token_exchange: None,
3114 ca_cert_path: None,
3115 allow_http_oauth_urls: false,
3116 max_jwks_keys: default_max_jwks_keys(),
3117 #[allow(
3118 deprecated,
3119 reason = "test fixture: explicit value for the deprecated field"
3120 )]
3121 strict_audience_validation: false,
3122 audience_validation_mode: None,
3123 jwks_max_response_bytes: default_jwks_max_bytes(),
3124 ssrf_allowlist: None,
3125 };
3126 let meta = protected_resource_metadata(
3127 "https://mcp.example.com/mcp",
3128 "https://mcp.example.com",
3129 &config,
3130 );
3131 assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
3132 assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
3133 assert_eq!(meta["scopes_supported"].as_array().unwrap().len(), 2);
3134 assert_eq!(meta["bearer_methods_supported"][0], "header");
3135 }
3136
3137 fn validation_https_config() -> OAuthConfig {
3142 OAuthConfig::builder(
3143 "https://auth.example.com",
3144 "mcp",
3145 "https://auth.example.com/.well-known/jwks.json",
3146 )
3147 .build()
3148 }
3149
3150 #[test]
3151 fn validate_accepts_all_https_urls() {
3152 let cfg = validation_https_config();
3153 cfg.validate().expect("all-HTTPS config must validate");
3154 }
3155
3156 #[test]
3157 fn validate_rejects_unparseable_jwks_cache_ttl() {
3158 let mut cfg = validation_https_config();
3159 cfg.jwks_cache_ttl = "not-a-duration".into();
3160 let err = cfg
3161 .validate()
3162 .expect_err("malformed jwks_cache_ttl must be rejected");
3163 let msg = err.to_string();
3164 assert!(
3165 msg.contains("jwks_cache_ttl"),
3166 "error must reference offending field; got {msg:?}"
3167 );
3168 }
3169
3170 #[test]
3171 fn validate_rejects_http_jwks_uri() {
3172 let mut cfg = validation_https_config();
3173 cfg.jwks_uri = "http://auth.example.com/.well-known/jwks.json".into();
3174 let err = cfg.validate().expect_err("http jwks_uri must be rejected");
3175 let msg = err.to_string();
3176 assert!(
3177 msg.contains("oauth.jwks_uri") && msg.contains("https"),
3178 "error must reference offending field + scheme requirement; got {msg:?}"
3179 );
3180 }
3181
3182 #[test]
3183 fn validate_rejects_http_proxy_authorize_url() {
3184 let mut cfg = validation_https_config();
3185 cfg.proxy = Some(
3186 OAuthProxyConfig::builder(
3187 "http://idp.example.com/authorize", "https://idp.example.com/token",
3189 "client",
3190 )
3191 .build(),
3192 );
3193 let err = cfg
3194 .validate()
3195 .expect_err("http authorize_url must be rejected");
3196 assert!(
3197 err.to_string().contains("oauth.proxy.authorize_url"),
3198 "error must reference proxy.authorize_url; got {err}"
3199 );
3200 }
3201
3202 #[test]
3203 fn validate_rejects_http_proxy_token_url() {
3204 let mut cfg = validation_https_config();
3205 cfg.proxy = Some(
3206 OAuthProxyConfig::builder(
3207 "https://idp.example.com/authorize",
3208 "http://idp.example.com/token", "client",
3210 )
3211 .build(),
3212 );
3213 let err = cfg.validate().expect_err("http token_url must be rejected");
3214 assert!(
3215 err.to_string().contains("oauth.proxy.token_url"),
3216 "error must reference proxy.token_url; got {err}"
3217 );
3218 }
3219
3220 #[test]
3221 fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
3222 let mut cfg = validation_https_config();
3223 cfg.proxy = Some(
3224 OAuthProxyConfig::builder(
3225 "https://idp.example.com/authorize",
3226 "https://idp.example.com/token",
3227 "client",
3228 )
3229 .introspection_url("http://idp.example.com/introspect")
3230 .build(),
3231 );
3232 let err = cfg
3233 .validate()
3234 .expect_err("http introspection_url must be rejected");
3235 assert!(err.to_string().contains("oauth.proxy.introspection_url"));
3236
3237 let mut cfg = validation_https_config();
3238 cfg.proxy = Some(
3239 OAuthProxyConfig::builder(
3240 "https://idp.example.com/authorize",
3241 "https://idp.example.com/token",
3242 "client",
3243 )
3244 .revocation_url("http://idp.example.com/revoke")
3245 .build(),
3246 );
3247 let err = cfg
3248 .validate()
3249 .expect_err("http revocation_url must be rejected");
3250 assert!(err.to_string().contains("oauth.proxy.revocation_url"));
3251 }
3252
3253 #[test]
3256 fn validate_rejects_exposed_admin_endpoints_without_auth() {
3257 let mut cfg = validation_https_config();
3258 cfg.proxy = Some(
3259 OAuthProxyConfig::builder(
3260 "https://idp.example.com/authorize",
3261 "https://idp.example.com/token",
3262 "client",
3263 )
3264 .introspection_url("https://idp.example.com/introspect")
3265 .expose_admin_endpoints(true)
3266 .build(),
3267 );
3268 let err = cfg
3269 .validate()
3270 .expect_err("expose_admin_endpoints without auth must fail");
3271 let msg = err.to_string();
3272 assert!(msg.contains("require_auth_on_admin_endpoints"), "{msg}");
3273 assert!(
3274 msg.contains("allow_unauthenticated_admin_endpoints"),
3275 "{msg}"
3276 );
3277 }
3278
3279 #[test]
3280 fn validate_accepts_exposed_admin_endpoints_with_auth() {
3281 let mut cfg = validation_https_config();
3282 cfg.proxy = Some(
3283 OAuthProxyConfig::builder(
3284 "https://idp.example.com/authorize",
3285 "https://idp.example.com/token",
3286 "client",
3287 )
3288 .introspection_url("https://idp.example.com/introspect")
3289 .expose_admin_endpoints(true)
3290 .require_auth_on_admin_endpoints(true)
3291 .build(),
3292 );
3293 cfg.validate()
3294 .expect("authed admin endpoints must validate");
3295 }
3296
3297 #[test]
3298 fn validate_accepts_exposed_admin_endpoints_with_explicit_unauth_optout() {
3299 let mut cfg = validation_https_config();
3300 cfg.proxy = Some(
3301 OAuthProxyConfig::builder(
3302 "https://idp.example.com/authorize",
3303 "https://idp.example.com/token",
3304 "client",
3305 )
3306 .introspection_url("https://idp.example.com/introspect")
3307 .expose_admin_endpoints(true)
3308 .allow_unauthenticated_admin_endpoints(true)
3309 .build(),
3310 );
3311 cfg.validate()
3312 .expect("explicit unauth opt-out must validate");
3313 }
3314
3315 #[test]
3316 fn validate_accepts_unexposed_admin_endpoints_without_auth() {
3317 let mut cfg = validation_https_config();
3320 cfg.proxy = Some(
3321 OAuthProxyConfig::builder(
3322 "https://idp.example.com/authorize",
3323 "https://idp.example.com/token",
3324 "client",
3325 )
3326 .introspection_url("https://idp.example.com/introspect")
3327 .build(),
3328 );
3329 cfg.validate()
3330 .expect("unexposed admin endpoints must validate");
3331 }
3332
3333 #[test]
3334 fn validate_rejects_http_token_exchange_url() {
3335 let mut cfg = validation_https_config();
3336 cfg.token_exchange = Some(TokenExchangeConfig::new(
3337 "http://idp.example.com/token".into(), "client".into(),
3339 None,
3340 None,
3341 "downstream".into(),
3342 ));
3343 let err = cfg
3344 .validate()
3345 .expect_err("http token_exchange.token_url must be rejected");
3346 assert!(
3347 err.to_string().contains("oauth.token_exchange.token_url"),
3348 "error must reference token_exchange.token_url; got {err}"
3349 );
3350 }
3351
3352 #[test]
3353 fn validate_rejects_unparseable_url() {
3354 let mut cfg = validation_https_config();
3355 cfg.jwks_uri = "not a url".into();
3356 let err = cfg
3357 .validate()
3358 .expect_err("unparseable URL must be rejected");
3359 assert!(err.to_string().contains("invalid URL"));
3360 }
3361
3362 #[test]
3363 fn validate_rejects_non_http_scheme() {
3364 let mut cfg = validation_https_config();
3365 cfg.jwks_uri = "file:///etc/passwd".into();
3366 let err = cfg.validate().expect_err("file:// scheme must be rejected");
3367 let msg = err.to_string();
3368 assert!(
3369 msg.contains("must use https scheme") && msg.contains("file"),
3370 "error must reject non-http(s) schemes; got {msg:?}"
3371 );
3372 }
3373
3374 #[test]
3375 fn validate_accepts_http_with_escape_hatch() {
3376 let mut cfg = OAuthConfig::builder(
3381 "http://auth.local",
3382 "mcp",
3383 "http://auth.local/.well-known/jwks.json",
3384 )
3385 .allow_http_oauth_urls(true)
3386 .build();
3387 cfg.proxy = Some(
3388 OAuthProxyConfig::builder(
3389 "http://idp.local/authorize",
3390 "http://idp.local/token",
3391 "client",
3392 )
3393 .introspection_url("http://idp.local/introspect")
3394 .revocation_url("http://idp.local/revoke")
3395 .build(),
3396 );
3397 cfg.token_exchange = Some(TokenExchangeConfig::new(
3398 "http://idp.local/token".into(),
3399 "client".into(),
3400 Some(secrecy::SecretString::new("dev-secret".into())),
3401 None,
3402 "downstream".into(),
3403 ));
3404 cfg.validate()
3405 .expect("escape hatch must permit http on all URL fields");
3406 }
3407
3408 #[test]
3409 fn validate_with_escape_hatch_still_rejects_unparseable() {
3410 let mut cfg = validation_https_config();
3413 cfg.allow_http_oauth_urls = true;
3414 cfg.jwks_uri = "::not-a-url::".into();
3415 cfg.validate()
3416 .expect_err("escape hatch must NOT bypass URL parsing");
3417 }
3418
3419 #[tokio::test]
3420 async fn jwks_cache_rejects_redirect_downgrade_to_http() {
3421 rustls::crypto::ring::default_provider()
3436 .install_default()
3437 .ok();
3438
3439 let policy = reqwest::redirect::Policy::custom(|attempt| {
3440 if attempt.url().scheme() != "https" {
3441 attempt.error("redirect to non-HTTPS URL refused")
3442 } else if attempt.previous().len() >= 2 {
3443 attempt.error("too many redirects (max 2)")
3444 } else {
3445 attempt.follow()
3446 }
3447 });
3448 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = Arc::new(AtomicBool::new(true));
3455 let allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
3456 let resolver: Arc<dyn reqwest::dns::Resolve> = Arc::new(
3457 crate::ssrf_resolver::SsrfScreeningResolver::new(Arc::clone(&allowlist), test_bypass),
3458 );
3459 let client = reqwest::Client::builder()
3460 .no_proxy()
3461 .dns_resolver(Arc::clone(&resolver))
3462 .timeout(Duration::from_secs(5))
3463 .connect_timeout(Duration::from_secs(3))
3464 .redirect(policy)
3465 .build()
3466 .expect("test client builds");
3467
3468 let mock = wiremock::MockServer::start().await;
3469 wiremock::Mock::given(wiremock::matchers::method("GET"))
3470 .and(wiremock::matchers::path("/jwks.json"))
3471 .respond_with(
3472 wiremock::ResponseTemplate::new(302)
3473 .insert_header("location", "http://example.invalid/jwks.json"),
3474 )
3475 .mount(&mock)
3476 .await;
3477
3478 let url = format!("{}/jwks.json", mock.uri());
3487 let err = client
3488 .get(&url)
3489 .send()
3490 .await
3491 .expect_err("redirect policy must reject scheme downgrade");
3492 let chain = format!("{err:#}");
3493 assert!(
3494 chain.contains("redirect to non-HTTPS URL refused")
3495 || chain.to_lowercase().contains("redirect"),
3496 "error must surface redirect-policy rejection; got {chain:?}"
3497 );
3498 }
3499
3500 use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
3505
3506 fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
3508 let mut rng = rsa::rand_core::OsRng;
3509 let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keypair generation");
3510 let private_pem = private_key
3511 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
3512 .expect("PKCS8 PEM export")
3513 .to_string();
3514
3515 let public_key = private_key.to_public_key();
3516 let n = URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be());
3517 let e = URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be());
3518
3519 let jwks = serde_json::json!({
3520 "keys": [{
3521 "kty": "RSA",
3522 "use": "sig",
3523 "alg": "RS256",
3524 "kid": kid,
3525 "n": n,
3526 "e": e
3527 }]
3528 });
3529
3530 (private_pem, jwks)
3531 }
3532
3533 fn mint_token(
3535 private_pem: &str,
3536 kid: &str,
3537 issuer: &str,
3538 audience: &str,
3539 subject: &str,
3540 scope: &str,
3541 ) -> String {
3542 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3543 .expect("encoding key from PEM");
3544 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3545 header.kid = Some(kid.into());
3546
3547 let now = jsonwebtoken::get_current_timestamp();
3548 let claims = serde_json::json!({
3549 "iss": issuer,
3550 "aud": audience,
3551 "sub": subject,
3552 "scope": scope,
3553 "exp": now + 3600,
3554 "iat": now,
3555 });
3556
3557 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3558 }
3559
3560 fn test_config(jwks_uri: &str) -> OAuthConfig {
3561 OAuthConfig {
3562 issuer: "https://auth.test.local".into(),
3563 audience: "https://mcp.test.local/mcp".into(),
3564 jwks_uri: jwks_uri.into(),
3565 scopes: vec![
3566 ScopeMapping {
3567 scope: "mcp:read".into(),
3568 role: "viewer".into(),
3569 },
3570 ScopeMapping {
3571 scope: "mcp:admin".into(),
3572 role: "ops".into(),
3573 },
3574 ],
3575 role_claim: None,
3576 role_mappings: vec![],
3577 jwks_cache_ttl: "5m".into(),
3578 proxy: None,
3579 token_exchange: None,
3580 ca_cert_path: None,
3581 allow_http_oauth_urls: true,
3582 max_jwks_keys: default_max_jwks_keys(),
3583 #[allow(
3584 deprecated,
3585 reason = "test fixture: explicit value for the deprecated field"
3586 )]
3587 strict_audience_validation: false,
3588 audience_validation_mode: None,
3589 jwks_max_response_bytes: default_jwks_max_bytes(),
3590 ssrf_allowlist: None,
3591 }
3592 }
3593
3594 fn test_cache(config: &OAuthConfig) -> JwksCache {
3595 JwksCache::new(config).unwrap().__test_allow_loopback_ssrf()
3596 }
3597
3598 #[tokio::test]
3599 async fn valid_jwt_returns_identity() {
3600 let kid = "test-key-1";
3601 let (pem, jwks) = generate_test_keypair(kid);
3602
3603 let mock_server = wiremock::MockServer::start().await;
3604 wiremock::Mock::given(wiremock::matchers::method("GET"))
3605 .and(wiremock::matchers::path("/jwks.json"))
3606 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3607 .mount(&mock_server)
3608 .await;
3609
3610 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3611 let config = test_config(&jwks_uri);
3612 let cache = test_cache(&config);
3613
3614 let token = mint_token(
3615 &pem,
3616 kid,
3617 "https://auth.test.local",
3618 "https://mcp.test.local/mcp",
3619 "ci-bot",
3620 "mcp:read mcp:other",
3621 );
3622
3623 let identity = cache.validate_token(&token).await;
3624 assert!(identity.is_some(), "valid JWT should authenticate");
3625 let id = identity.unwrap();
3626 assert_eq!(id.name, "ci-bot");
3627 assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
3629 }
3630
3631 #[tokio::test]
3632 async fn wrong_issuer_rejected() {
3633 let kid = "test-key-2";
3634 let (pem, jwks) = generate_test_keypair(kid);
3635
3636 let mock_server = wiremock::MockServer::start().await;
3637 wiremock::Mock::given(wiremock::matchers::method("GET"))
3638 .and(wiremock::matchers::path("/jwks.json"))
3639 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3640 .mount(&mock_server)
3641 .await;
3642
3643 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3644 let config = test_config(&jwks_uri);
3645 let cache = test_cache(&config);
3646
3647 let token = mint_token(
3648 &pem,
3649 kid,
3650 "https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
3652 "attacker",
3653 "mcp:admin",
3654 );
3655
3656 assert!(cache.validate_token(&token).await.is_none());
3657 }
3658
3659 #[tokio::test]
3660 async fn wrong_audience_rejected() {
3661 let kid = "test-key-3";
3662 let (pem, jwks) = generate_test_keypair(kid);
3663
3664 let mock_server = wiremock::MockServer::start().await;
3665 wiremock::Mock::given(wiremock::matchers::method("GET"))
3666 .and(wiremock::matchers::path("/jwks.json"))
3667 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3668 .mount(&mock_server)
3669 .await;
3670
3671 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3672 let config = test_config(&jwks_uri);
3673 let cache = test_cache(&config);
3674
3675 let token = mint_token(
3676 &pem,
3677 kid,
3678 "https://auth.test.local",
3679 "https://wrong-audience.example.com", "attacker",
3681 "mcp:admin",
3682 );
3683
3684 assert!(cache.validate_token(&token).await.is_none());
3685 }
3686
3687 #[tokio::test]
3688 async fn expired_jwt_rejected() {
3689 let kid = "test-key-4";
3690 let (pem, jwks) = generate_test_keypair(kid);
3691
3692 let mock_server = wiremock::MockServer::start().await;
3693 wiremock::Mock::given(wiremock::matchers::method("GET"))
3694 .and(wiremock::matchers::path("/jwks.json"))
3695 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3696 .mount(&mock_server)
3697 .await;
3698
3699 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3700 let config = test_config(&jwks_uri);
3701 let cache = test_cache(&config);
3702
3703 let encoding_key =
3705 jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
3706 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3707 header.kid = Some(kid.into());
3708 let now = jsonwebtoken::get_current_timestamp();
3709 let claims = serde_json::json!({
3710 "iss": "https://auth.test.local",
3711 "aud": "https://mcp.test.local/mcp",
3712 "sub": "expired-bot",
3713 "scope": "mcp:read",
3714 "exp": now - 120,
3715 "iat": now - 3720,
3716 });
3717 let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
3718
3719 assert!(cache.validate_token(&token).await.is_none());
3720 }
3721
3722 #[tokio::test]
3723 async fn no_matching_scope_rejected() {
3724 let kid = "test-key-5";
3725 let (pem, jwks) = generate_test_keypair(kid);
3726
3727 let mock_server = wiremock::MockServer::start().await;
3728 wiremock::Mock::given(wiremock::matchers::method("GET"))
3729 .and(wiremock::matchers::path("/jwks.json"))
3730 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3731 .mount(&mock_server)
3732 .await;
3733
3734 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3735 let config = test_config(&jwks_uri);
3736 let cache = test_cache(&config);
3737
3738 let token = mint_token(
3739 &pem,
3740 kid,
3741 "https://auth.test.local",
3742 "https://mcp.test.local/mcp",
3743 "limited-bot",
3744 "some:other:scope", );
3746
3747 assert!(cache.validate_token(&token).await.is_none());
3748 }
3749
3750 #[tokio::test]
3751 async fn wrong_signing_key_rejected() {
3752 let kid = "test-key-6";
3753 let (_pem, jwks) = generate_test_keypair(kid);
3754
3755 let (attacker_pem, _) = generate_test_keypair(kid);
3757
3758 let mock_server = wiremock::MockServer::start().await;
3759 wiremock::Mock::given(wiremock::matchers::method("GET"))
3760 .and(wiremock::matchers::path("/jwks.json"))
3761 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3762 .mount(&mock_server)
3763 .await;
3764
3765 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3766 let config = test_config(&jwks_uri);
3767 let cache = test_cache(&config);
3768
3769 let token = mint_token(
3771 &attacker_pem,
3772 kid,
3773 "https://auth.test.local",
3774 "https://mcp.test.local/mcp",
3775 "attacker",
3776 "mcp:admin",
3777 );
3778
3779 assert!(cache.validate_token(&token).await.is_none());
3780 }
3781
3782 #[tokio::test]
3783 async fn admin_scope_maps_to_ops_role() {
3784 let kid = "test-key-7";
3785 let (pem, jwks) = generate_test_keypair(kid);
3786
3787 let mock_server = wiremock::MockServer::start().await;
3788 wiremock::Mock::given(wiremock::matchers::method("GET"))
3789 .and(wiremock::matchers::path("/jwks.json"))
3790 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3791 .mount(&mock_server)
3792 .await;
3793
3794 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3795 let config = test_config(&jwks_uri);
3796 let cache = test_cache(&config);
3797
3798 let token = mint_token(
3799 &pem,
3800 kid,
3801 "https://auth.test.local",
3802 "https://mcp.test.local/mcp",
3803 "admin-bot",
3804 "mcp:admin",
3805 );
3806
3807 let id = cache
3808 .validate_token(&token)
3809 .await
3810 .expect("should authenticate");
3811 assert_eq!(id.role, "ops");
3812 assert_eq!(id.name, "admin-bot");
3813 }
3814
3815 #[tokio::test]
3816 async fn jwks_server_down_returns_none() {
3817 let config = test_config("http://127.0.0.1:1/jwks.json");
3819 let cache = test_cache(&config);
3820
3821 let kid = "orphan-key";
3822 let (pem, _) = generate_test_keypair(kid);
3823 let token = mint_token(
3824 &pem,
3825 kid,
3826 "https://auth.test.local",
3827 "https://mcp.test.local/mcp",
3828 "bot",
3829 "mcp:read",
3830 );
3831
3832 assert!(cache.validate_token(&token).await.is_none());
3833 }
3834
3835 #[test]
3840 fn resolve_claim_path_flat_string() {
3841 let mut extra = HashMap::new();
3842 extra.insert(
3843 "scope".into(),
3844 serde_json::Value::String("mcp:read mcp:admin".into()),
3845 );
3846 let values = resolve_claim_path(&extra, "scope");
3847 assert_eq!(values, vec!["mcp:read", "mcp:admin"]);
3848 }
3849
3850 #[test]
3851 fn resolve_claim_path_flat_array() {
3852 let mut extra = HashMap::new();
3853 extra.insert(
3854 "roles".into(),
3855 serde_json::json!(["mcp-admin", "mcp-viewer"]),
3856 );
3857 let values = resolve_claim_path(&extra, "roles");
3858 assert_eq!(values, vec!["mcp-admin", "mcp-viewer"]);
3859 }
3860
3861 #[test]
3862 fn resolve_claim_path_nested_keycloak() {
3863 let mut extra = HashMap::new();
3864 extra.insert(
3865 "realm_access".into(),
3866 serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]}),
3867 );
3868 let values = resolve_claim_path(&extra, "realm_access.roles");
3869 assert_eq!(values, vec!["uma_authorization", "mcp-admin"]);
3870 }
3871
3872 #[test]
3873 fn resolve_claim_path_missing_returns_empty() {
3874 let extra = HashMap::new();
3875 assert!(resolve_claim_path(&extra, "nonexistent.path").is_empty());
3876 }
3877
3878 #[test]
3879 fn resolve_claim_path_numeric_leaf_returns_empty() {
3880 let mut extra = HashMap::new();
3881 extra.insert("count".into(), serde_json::json!(42));
3882 assert!(resolve_claim_path(&extra, "count").is_empty());
3883 }
3884
3885 fn make_claims(json: serde_json::Value) -> Claims {
3886 serde_json::from_value(json).expect("test claims must deserialize")
3887 }
3888
3889 #[test]
3890 fn first_class_scope_claim_splits_on_whitespace() {
3891 let claims = make_claims(serde_json::json!({
3892 "iss": "https://issuer.example.com",
3893 "exp": 9_999_999_999_u64,
3894 "scope": "read write admin",
3895 }));
3896 let values = first_class_claim_values(&claims, "scope");
3897 assert_eq!(values, vec!["read", "write", "admin"]);
3898 }
3899
3900 #[test]
3901 fn first_class_sub_claim_returns_single_value() {
3902 let claims = make_claims(serde_json::json!({
3903 "iss": "https://issuer.example.com",
3904 "exp": 9_999_999_999_u64,
3905 "sub": "service-account-orders",
3906 }));
3907 let values = first_class_claim_values(&claims, "sub");
3908 assert_eq!(values, vec!["service-account-orders"]);
3909 }
3910
3911 #[test]
3912 fn first_class_aud_claim_returns_every_audience() {
3913 let claims = make_claims(serde_json::json!({
3914 "iss": "https://issuer.example.com",
3915 "exp": 9_999_999_999_u64,
3916 "aud": ["api-a", "api-b"],
3917 }));
3918 let values = first_class_claim_values(&claims, "aud");
3919 assert_eq!(values, vec!["api-a", "api-b"]);
3920 }
3921
3922 #[test]
3923 fn first_class_unknown_path_returns_empty() {
3924 let claims = make_claims(serde_json::json!({
3925 "iss": "https://issuer.example.com",
3926 "exp": 9_999_999_999_u64,
3927 }));
3928 assert!(first_class_claim_values(&claims, "realm_access.roles").is_empty());
3929 }
3930
3931 fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
3937 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3938 .expect("encoding key from PEM");
3939 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3940 header.kid = Some(kid.into());
3941 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3942 }
3943
3944 fn test_config_with_role_claim(
3945 jwks_uri: &str,
3946 role_claim: &str,
3947 role_mappings: Vec<RoleMapping>,
3948 ) -> OAuthConfig {
3949 OAuthConfig {
3950 issuer: "https://auth.test.local".into(),
3951 audience: "https://mcp.test.local/mcp".into(),
3952 jwks_uri: jwks_uri.into(),
3953 scopes: vec![],
3954 role_claim: Some(role_claim.into()),
3955 role_mappings,
3956 jwks_cache_ttl: "5m".into(),
3957 proxy: None,
3958 token_exchange: None,
3959 ca_cert_path: None,
3960 allow_http_oauth_urls: true,
3961 max_jwks_keys: default_max_jwks_keys(),
3962 #[allow(
3963 deprecated,
3964 reason = "test fixture: explicit value for the deprecated field"
3965 )]
3966 strict_audience_validation: false,
3967 audience_validation_mode: None,
3968 jwks_max_response_bytes: default_jwks_max_bytes(),
3969 ssrf_allowlist: None,
3970 }
3971 }
3972
3973 #[tokio::test]
3974 async fn screen_oauth_target_rejects_literal_ip() {
3975 let err = screen_oauth_target(
3976 "https://127.0.0.1/jwks.json",
3977 false,
3978 &crate::ssrf::CompiledSsrfAllowlist::default(),
3979 )
3980 .await
3981 .expect_err("literal IPs must be rejected");
3982 let msg = err.to_string();
3983 assert!(msg.contains("literal IPv4 addresses are forbidden"));
3984 }
3985
3986 #[tokio::test]
3987 async fn screen_oauth_target_rejects_private_dns_resolution() {
3988 let err = screen_oauth_target(
3989 "https://localhost/jwks.json",
3990 false,
3991 &crate::ssrf::CompiledSsrfAllowlist::default(),
3992 )
3993 .await
3994 .expect_err("localhost resolution must be rejected");
3995 let msg = err.to_string();
3996 assert!(
3997 msg.contains("blocked IP") && msg.contains("loopback"),
3998 "got {msg:?}"
3999 );
4000 }
4001
4002 #[tokio::test]
4003 async fn screen_oauth_target_rejects_literal_ip_even_with_allow_http() {
4004 let err = screen_oauth_target(
4005 "http://127.0.0.1/jwks.json",
4006 true,
4007 &crate::ssrf::CompiledSsrfAllowlist::default(),
4008 )
4009 .await
4010 .expect_err("literal IPs must still be rejected when http is allowed");
4011 let msg = err.to_string();
4012 assert!(msg.contains("literal IPv4 addresses are forbidden"));
4013 }
4014
4015 #[tokio::test]
4016 async fn screen_oauth_target_rejects_private_dns_even_with_allow_http() {
4017 let err = screen_oauth_target(
4018 "http://localhost/jwks.json",
4019 true,
4020 &crate::ssrf::CompiledSsrfAllowlist::default(),
4021 )
4022 .await
4023 .expect_err("private DNS resolution must still be rejected when http is allowed");
4024 let msg = err.to_string();
4025 assert!(
4026 msg.contains("blocked IP") && msg.contains("loopback"),
4027 "got {msg:?}"
4028 );
4029 }
4030
4031 #[tokio::test]
4032 async fn screen_oauth_target_allows_public_hostname() {
4033 screen_oauth_target(
4034 "https://example.com/.well-known/jwks.json",
4035 false,
4036 &crate::ssrf::CompiledSsrfAllowlist::default(),
4037 )
4038 .await
4039 .expect("public hostname should pass screening");
4040 }
4041
4042 fn make_allowlist(hosts: &[&str], cidrs: &[&str]) -> crate::ssrf::CompiledSsrfAllowlist {
4048 let raw = OAuthSsrfAllowlist {
4049 hosts: hosts.iter().map(|s| (*s).to_string()).collect(),
4050 cidrs: cidrs.iter().map(|s| (*s).to_string()).collect(),
4051 };
4052 compile_oauth_ssrf_allowlist(&raw).expect("test allowlist compiles")
4053 }
4054
4055 #[test]
4056 fn compile_oauth_ssrf_allowlist_lowercases_and_dedupes_hosts() {
4057 let raw = OAuthSsrfAllowlist {
4058 hosts: vec!["RHBK.ops.example.com".into(), "rhbk.ops.example.com".into()],
4059 cidrs: vec![],
4060 };
4061 let compiled = compile_oauth_ssrf_allowlist(&raw).expect("compiles");
4062 assert_eq!(compiled.host_count(), 1);
4063 assert!(compiled.host_allowed("rhbk.ops.example.com"));
4064 assert!(compiled.host_allowed("RHBK.OPS.EXAMPLE.COM"));
4065 }
4066
4067 #[test]
4068 fn compile_oauth_ssrf_allowlist_rejects_literal_ip_in_hosts() {
4069 let raw = OAuthSsrfAllowlist {
4070 hosts: vec!["10.0.0.1".into()],
4071 cidrs: vec![],
4072 };
4073 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("literal IP in hosts");
4074 assert!(err.contains("literal IPs are forbidden"), "got {err:?}");
4075 }
4076
4077 #[test]
4078 fn compile_oauth_ssrf_allowlist_rejects_host_with_port() {
4079 let raw = OAuthSsrfAllowlist {
4080 hosts: vec!["rhbk.ops.example.com:8443".into()],
4081 cidrs: vec![],
4082 };
4083 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("host:port");
4084 assert!(err.contains("must be a bare DNS hostname"), "got {err:?}");
4085 }
4086
4087 #[test]
4088 fn compile_oauth_ssrf_allowlist_rejects_invalid_cidr() {
4089 let raw = OAuthSsrfAllowlist {
4090 hosts: vec![],
4091 cidrs: vec!["not-a-cidr".into()],
4092 };
4093 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("invalid CIDR");
4094 assert!(err.contains("oauth.ssrf_allowlist.cidrs[0]"), "got {err:?}");
4095 }
4096
4097 #[test]
4098 fn validate_rejects_misconfigured_allowlist() {
4099 let mut cfg = OAuthConfig::builder(
4100 "https://auth.example.com/",
4101 "mcp",
4102 "https://auth.example.com/jwks.json",
4103 )
4104 .build();
4105 cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
4106 hosts: vec!["10.0.0.1".into()],
4107 cidrs: vec![],
4108 });
4109 let err = cfg
4110 .validate()
4111 .expect_err("literal IP host must be rejected");
4112 assert!(
4113 err.to_string().contains("oauth.ssrf_allowlist"),
4114 "got {err}"
4115 );
4116 }
4117
4118 #[tokio::test]
4119 async fn screen_oauth_target_with_allowlist_emits_helpful_error() {
4120 let allow = make_allowlist(&["other.example.com"], &["10.0.0.0/8"]);
4124 let err = screen_oauth_target("https://localhost/jwks.json", false, &allow)
4125 .await
4126 .expect_err("loopback must still be blocked when not in allowlist");
4127 let msg = err.to_string();
4128 assert!(msg.contains("OAuth target blocked"), "got {msg:?}");
4129 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4130 assert!(msg.contains("SECURITY.md"), "got {msg:?}");
4131 }
4132
4133 #[tokio::test]
4134 async fn screen_oauth_target_empty_allowlist_uses_legacy_message() {
4135 let err = screen_oauth_target(
4138 "https://localhost/jwks.json",
4139 false,
4140 &crate::ssrf::CompiledSsrfAllowlist::default(),
4141 )
4142 .await
4143 .expect_err("loopback rejection");
4144 let msg = err.to_string();
4145 assert!(msg.contains("blocked IP"), "got {msg:?}");
4146 assert!(msg.contains("loopback"), "got {msg:?}");
4147 assert!(!msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4149 }
4150
4151 #[tokio::test]
4152 async fn screen_oauth_target_allows_loopback_when_host_allowlisted() {
4153 let allow = make_allowlist(&["localhost"], &[]);
4155 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4156 .await
4157 .expect("allowlisted host must pass");
4158 }
4159
4160 #[tokio::test]
4161 async fn screen_oauth_target_allows_loopback_when_cidr_allowlisted() {
4162 let allow = make_allowlist(&[], &["127.0.0.0/8", "::1/128"]);
4165 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4166 .await
4167 .expect("allowlisted CIDR must pass");
4168 }
4169
4170 #[tokio::test]
4171 async fn jwks_cache_rejects_misconfigured_allowlist_at_startup() {
4172 let mut cfg = OAuthConfig::builder(
4173 "https://auth.example.com/",
4174 "mcp",
4175 "https://auth.example.com/jwks.json",
4176 )
4177 .build();
4178 cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
4179 hosts: vec![],
4180 cidrs: vec!["bad-cidr".into()],
4181 });
4182 let Err(err) = JwksCache::new(&cfg) else {
4183 panic!("invalid CIDR must fail JwksCache::new")
4184 };
4185 let msg = err.to_string();
4186 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4187 }
4188
4189 #[tokio::test]
4190 async fn audience_falls_back_to_azp_by_default() {
4191 let kid = "test-audience-azp-default";
4192 let (pem, jwks) = generate_test_keypair(kid);
4193
4194 let mock_server = wiremock::MockServer::start().await;
4195 wiremock::Mock::given(wiremock::matchers::method("GET"))
4196 .and(wiremock::matchers::path("/jwks.json"))
4197 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4198 .mount(&mock_server)
4199 .await;
4200
4201 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4202 let config = test_config(&jwks_uri);
4203 let cache = test_cache(&config);
4204
4205 let now = jsonwebtoken::get_current_timestamp();
4206 let token = mint_token_with_claims(
4207 &pem,
4208 kid,
4209 &serde_json::json!({
4210 "iss": "https://auth.test.local",
4211 "aud": "https://some-other-resource.example.com",
4212 "azp": "https://mcp.test.local/mcp",
4213 "sub": "compat-client",
4214 "scope": "mcp:read",
4215 "exp": now + 3600,
4216 "iat": now,
4217 }),
4218 );
4219
4220 let identity = cache
4221 .validate_token_with_reason(&token)
4222 .await
4223 .expect("azp fallback should remain enabled by default");
4224 assert_eq!(identity.role, "viewer");
4225 }
4226
4227 #[tokio::test]
4228 async fn strict_audience_validation_rejects_azp_only_match() {
4229 let kid = "test-audience-azp-strict";
4230 let (pem, jwks) = generate_test_keypair(kid);
4231
4232 let mock_server = wiremock::MockServer::start().await;
4233 wiremock::Mock::given(wiremock::matchers::method("GET"))
4234 .and(wiremock::matchers::path("/jwks.json"))
4235 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4236 .mount(&mock_server)
4237 .await;
4238
4239 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4240 let mut config = test_config(&jwks_uri);
4241 #[allow(deprecated, reason = "covers the legacy bool resolution path")]
4242 {
4243 config.strict_audience_validation = true;
4244 }
4245 let cache = test_cache(&config);
4246
4247 let now = jsonwebtoken::get_current_timestamp();
4248 let token = mint_token_with_claims(
4249 &pem,
4250 kid,
4251 &serde_json::json!({
4252 "iss": "https://auth.test.local",
4253 "aud": "https://some-other-resource.example.com",
4254 "azp": "https://mcp.test.local/mcp",
4255 "sub": "strict-client",
4256 "scope": "mcp:read",
4257 "exp": now + 3600,
4258 "iat": now,
4259 }),
4260 );
4261
4262 let failure = cache
4263 .validate_token_with_reason(&token)
4264 .await
4265 .expect_err("strict audience validation must ignore azp fallback");
4266 assert_eq!(failure, JwtValidationFailure::Invalid);
4267 }
4268
4269 #[tokio::test]
4270 async fn warn_mode_accepts_azp_only_match_and_warns_once() {
4271 let kid = "test-audience-warn-mode";
4272 let (pem, jwks) = generate_test_keypair(kid);
4273
4274 let mock_server = wiremock::MockServer::start().await;
4275 wiremock::Mock::given(wiremock::matchers::method("GET"))
4276 .and(wiremock::matchers::path("/jwks.json"))
4277 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4278 .mount(&mock_server)
4279 .await;
4280
4281 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4282 let mut config = test_config(&jwks_uri);
4283 config.audience_validation_mode = Some(AudienceValidationMode::Warn);
4284 let cache = test_cache(&config);
4285
4286 let now = jsonwebtoken::get_current_timestamp();
4287 let claims = serde_json::json!({
4288 "iss": "https://auth.test.local",
4289 "aud": "https://some-other-resource.example.com",
4290 "azp": "https://mcp.test.local/mcp",
4291 "sub": "warn-client",
4292 "scope": "mcp:read",
4293 "exp": now + 3600,
4294 "iat": now,
4295 });
4296 let token = mint_token_with_claims(&pem, kid, &claims);
4297
4298 let identity = cache
4299 .validate_token_with_reason(&token)
4300 .await
4301 .expect("warn mode must accept azp-only match");
4302 assert_eq!(identity.role, "viewer");
4303 assert!(
4304 cache.azp_fallback_warned.load(Ordering::Relaxed),
4305 "warn-once flag should be set after first azp-only match"
4306 );
4307
4308 let token2 = mint_token_with_claims(&pem, kid, &claims);
4309 cache
4310 .validate_token_with_reason(&token2)
4311 .await
4312 .expect("warn mode must continue accepting subsequent matches");
4313 assert!(
4314 cache.azp_fallback_warned.load(Ordering::Relaxed),
4315 "warn-once flag must remain set; the assertion guards against accidental clearing"
4316 );
4317 }
4318
4319 #[tokio::test]
4320 async fn permissive_mode_accepts_azp_only_match_silently() {
4321 let kid = "test-audience-permissive-mode";
4322 let (pem, jwks) = generate_test_keypair(kid);
4323
4324 let mock_server = wiremock::MockServer::start().await;
4325 wiremock::Mock::given(wiremock::matchers::method("GET"))
4326 .and(wiremock::matchers::path("/jwks.json"))
4327 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4328 .mount(&mock_server)
4329 .await;
4330
4331 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4332 let mut config = test_config(&jwks_uri);
4333 config.audience_validation_mode = Some(AudienceValidationMode::Permissive);
4334 let cache = test_cache(&config);
4335
4336 let now = jsonwebtoken::get_current_timestamp();
4337 let token = mint_token_with_claims(
4338 &pem,
4339 kid,
4340 &serde_json::json!({
4341 "iss": "https://auth.test.local",
4342 "aud": "https://some-other-resource.example.com",
4343 "azp": "https://mcp.test.local/mcp",
4344 "sub": "permissive-client",
4345 "scope": "mcp:read",
4346 "exp": now + 3600,
4347 "iat": now,
4348 }),
4349 );
4350
4351 cache
4352 .validate_token_with_reason(&token)
4353 .await
4354 .expect("permissive mode must accept azp-only match");
4355 assert!(
4356 !cache.azp_fallback_warned.load(Ordering::Relaxed),
4357 "permissive mode must not flip the warn-once flag"
4358 );
4359 }
4360
4361 #[test]
4362 fn audience_validation_mode_overrides_legacy_bool() {
4363 let mut config = OAuthConfig::default();
4364 #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
4365 {
4366 config.strict_audience_validation = false;
4367 }
4368 config.audience_validation_mode = Some(AudienceValidationMode::Strict);
4369 assert_eq!(
4370 config.effective_audience_validation_mode(),
4371 AudienceValidationMode::Strict,
4372 "explicit mode must override legacy false"
4373 );
4374
4375 let mut config = OAuthConfig::default();
4376 #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
4377 {
4378 config.strict_audience_validation = true;
4379 }
4380 config.audience_validation_mode = Some(AudienceValidationMode::Permissive);
4381 assert_eq!(
4382 config.effective_audience_validation_mode(),
4383 AudienceValidationMode::Permissive,
4384 "explicit mode must override legacy true"
4385 );
4386 }
4387
4388 #[test]
4389 fn audience_validation_mode_default_is_warn_when_unset() {
4390 let config = OAuthConfig::default();
4391 assert_eq!(
4392 config.effective_audience_validation_mode(),
4393 AudienceValidationMode::Warn,
4394 "unset mode + unset bool must resolve to Warn (the new default)"
4395 );
4396 }
4397
4398 #[test]
4399 fn audience_validation_legacy_bool_true_resolves_to_strict() {
4400 let mut config = OAuthConfig::default();
4401 #[allow(deprecated, reason = "covers the legacy bool resolution path")]
4402 {
4403 config.strict_audience_validation = true;
4404 }
4405 assert_eq!(
4406 config.effective_audience_validation_mode(),
4407 AudienceValidationMode::Strict,
4408 "legacy bool=true must resolve to Strict for backward compat"
4409 );
4410 }
4411
4412 #[derive(Clone, Default)]
4413 struct CapturedLogs(Arc<std::sync::Mutex<Vec<u8>>>);
4414
4415 impl CapturedLogs {
4416 fn contents(&self) -> String {
4417 let bytes = self.0.lock().map(|guard| guard.clone()).unwrap_or_default();
4418 String::from_utf8(bytes).unwrap_or_default()
4419 }
4420 }
4421
4422 struct CapturedLogsWriter(Arc<std::sync::Mutex<Vec<u8>>>);
4423
4424 impl std::io::Write for CapturedLogsWriter {
4425 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
4426 if let Ok(mut guard) = self.0.lock() {
4427 guard.extend_from_slice(buf);
4428 }
4429 Ok(buf.len())
4430 }
4431
4432 fn flush(&mut self) -> std::io::Result<()> {
4433 Ok(())
4434 }
4435 }
4436
4437 impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CapturedLogs {
4438 type Writer = CapturedLogsWriter;
4439
4440 fn make_writer(&'a self) -> Self::Writer {
4441 CapturedLogsWriter(Arc::clone(&self.0))
4442 }
4443 }
4444
4445 #[tokio::test]
4446 async fn jwks_response_size_cap_returns_none_and_logs_warning() {
4447 let kid = "oversized-jwks";
4448 let (_pem, jwks) = generate_test_keypair(kid);
4449 let mut oversized_body = serde_json::to_string(&jwks).expect("jwks json");
4450 oversized_body.push_str(&" ".repeat(4096));
4451
4452 let mock_server = wiremock::MockServer::start().await;
4453 wiremock::Mock::given(wiremock::matchers::method("GET"))
4454 .and(wiremock::matchers::path("/jwks.json"))
4455 .respond_with(
4456 wiremock::ResponseTemplate::new(200)
4457 .insert_header("content-type", "application/json")
4458 .set_body_string(oversized_body),
4459 )
4460 .mount(&mock_server)
4461 .await;
4462
4463 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4464 let mut config = test_config(&jwks_uri);
4465 config.jwks_max_response_bytes = 256;
4466 let cache = test_cache(&config);
4467
4468 let logs = CapturedLogs::default();
4469 let subscriber = tracing_subscriber::fmt()
4470 .with_writer(logs.clone())
4471 .with_ansi(false)
4472 .without_time()
4473 .finish();
4474 let _guard = tracing::subscriber::set_default(subscriber);
4475
4476 let result = cache.fetch_jwks().await;
4477 assert!(result.is_none(), "oversized JWKS must be dropped");
4478 assert!(
4479 logs.contents()
4480 .contains("JWKS response exceeded configured size cap"),
4481 "expected cap-exceeded warning in logs"
4482 );
4483 }
4484
4485 #[tokio::test]
4486 async fn role_claim_keycloak_nested_array() {
4487 let kid = "test-role-1";
4488 let (pem, jwks) = generate_test_keypair(kid);
4489
4490 let mock_server = wiremock::MockServer::start().await;
4491 wiremock::Mock::given(wiremock::matchers::method("GET"))
4492 .and(wiremock::matchers::path("/jwks.json"))
4493 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4494 .mount(&mock_server)
4495 .await;
4496
4497 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4498 let config = test_config_with_role_claim(
4499 &jwks_uri,
4500 "realm_access.roles",
4501 vec![
4502 RoleMapping {
4503 claim_value: "mcp-admin".into(),
4504 role: "ops".into(),
4505 },
4506 RoleMapping {
4507 claim_value: "mcp-viewer".into(),
4508 role: "viewer".into(),
4509 },
4510 ],
4511 );
4512 let cache = test_cache(&config);
4513
4514 let now = jsonwebtoken::get_current_timestamp();
4515 let token = mint_token_with_claims(
4516 &pem,
4517 kid,
4518 &serde_json::json!({
4519 "iss": "https://auth.test.local",
4520 "aud": "https://mcp.test.local/mcp",
4521 "sub": "keycloak-user",
4522 "exp": now + 3600,
4523 "iat": now,
4524 "realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
4525 }),
4526 );
4527
4528 let id = cache
4529 .validate_token(&token)
4530 .await
4531 .expect("should authenticate");
4532 assert_eq!(id.name, "keycloak-user");
4533 assert_eq!(id.role, "ops");
4534 }
4535
4536 #[tokio::test]
4537 async fn role_claim_flat_roles_array() {
4538 let kid = "test-role-2";
4539 let (pem, jwks) = generate_test_keypair(kid);
4540
4541 let mock_server = wiremock::MockServer::start().await;
4542 wiremock::Mock::given(wiremock::matchers::method("GET"))
4543 .and(wiremock::matchers::path("/jwks.json"))
4544 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4545 .mount(&mock_server)
4546 .await;
4547
4548 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4549 let config = test_config_with_role_claim(
4550 &jwks_uri,
4551 "roles",
4552 vec![
4553 RoleMapping {
4554 claim_value: "MCP.Admin".into(),
4555 role: "ops".into(),
4556 },
4557 RoleMapping {
4558 claim_value: "MCP.Reader".into(),
4559 role: "viewer".into(),
4560 },
4561 ],
4562 );
4563 let cache = test_cache(&config);
4564
4565 let now = jsonwebtoken::get_current_timestamp();
4566 let token = mint_token_with_claims(
4567 &pem,
4568 kid,
4569 &serde_json::json!({
4570 "iss": "https://auth.test.local",
4571 "aud": "https://mcp.test.local/mcp",
4572 "sub": "azure-ad-user",
4573 "exp": now + 3600,
4574 "iat": now,
4575 "roles": ["MCP.Reader", "OtherApp.Admin"]
4576 }),
4577 );
4578
4579 let id = cache
4580 .validate_token(&token)
4581 .await
4582 .expect("should authenticate");
4583 assert_eq!(id.name, "azure-ad-user");
4584 assert_eq!(id.role, "viewer");
4585 }
4586
4587 #[tokio::test]
4588 async fn role_claim_no_matching_value_rejected() {
4589 let kid = "test-role-3";
4590 let (pem, jwks) = generate_test_keypair(kid);
4591
4592 let mock_server = wiremock::MockServer::start().await;
4593 wiremock::Mock::given(wiremock::matchers::method("GET"))
4594 .and(wiremock::matchers::path("/jwks.json"))
4595 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4596 .mount(&mock_server)
4597 .await;
4598
4599 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4600 let config = test_config_with_role_claim(
4601 &jwks_uri,
4602 "roles",
4603 vec![RoleMapping {
4604 claim_value: "mcp-admin".into(),
4605 role: "ops".into(),
4606 }],
4607 );
4608 let cache = test_cache(&config);
4609
4610 let now = jsonwebtoken::get_current_timestamp();
4611 let token = mint_token_with_claims(
4612 &pem,
4613 kid,
4614 &serde_json::json!({
4615 "iss": "https://auth.test.local",
4616 "aud": "https://mcp.test.local/mcp",
4617 "sub": "limited-user",
4618 "exp": now + 3600,
4619 "iat": now,
4620 "roles": ["some-other-role"]
4621 }),
4622 );
4623
4624 assert!(cache.validate_token(&token).await.is_none());
4625 }
4626
4627 #[tokio::test]
4628 async fn role_claim_space_separated_string() {
4629 let kid = "test-role-4";
4630 let (pem, jwks) = generate_test_keypair(kid);
4631
4632 let mock_server = wiremock::MockServer::start().await;
4633 wiremock::Mock::given(wiremock::matchers::method("GET"))
4634 .and(wiremock::matchers::path("/jwks.json"))
4635 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4636 .mount(&mock_server)
4637 .await;
4638
4639 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4640 let config = test_config_with_role_claim(
4641 &jwks_uri,
4642 "custom_scope",
4643 vec![
4644 RoleMapping {
4645 claim_value: "write".into(),
4646 role: "ops".into(),
4647 },
4648 RoleMapping {
4649 claim_value: "read".into(),
4650 role: "viewer".into(),
4651 },
4652 ],
4653 );
4654 let cache = test_cache(&config);
4655
4656 let now = jsonwebtoken::get_current_timestamp();
4657 let token = mint_token_with_claims(
4658 &pem,
4659 kid,
4660 &serde_json::json!({
4661 "iss": "https://auth.test.local",
4662 "aud": "https://mcp.test.local/mcp",
4663 "sub": "custom-client",
4664 "exp": now + 3600,
4665 "iat": now,
4666 "custom_scope": "read audit"
4667 }),
4668 );
4669
4670 let id = cache
4671 .validate_token(&token)
4672 .await
4673 .expect("should authenticate");
4674 assert_eq!(id.name, "custom-client");
4675 assert_eq!(id.role, "viewer");
4676 }
4677
4678 #[tokio::test]
4679 async fn scope_backward_compat_without_role_claim() {
4680 let kid = "test-compat-1";
4682 let (pem, jwks) = generate_test_keypair(kid);
4683
4684 let mock_server = wiremock::MockServer::start().await;
4685 wiremock::Mock::given(wiremock::matchers::method("GET"))
4686 .and(wiremock::matchers::path("/jwks.json"))
4687 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4688 .mount(&mock_server)
4689 .await;
4690
4691 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4692 let config = test_config(&jwks_uri); let cache = test_cache(&config);
4694
4695 let token = mint_token(
4696 &pem,
4697 kid,
4698 "https://auth.test.local",
4699 "https://mcp.test.local/mcp",
4700 "legacy-bot",
4701 "mcp:admin other:scope",
4702 );
4703
4704 let id = cache
4705 .validate_token(&token)
4706 .await
4707 .expect("should authenticate");
4708 assert_eq!(id.name, "legacy-bot");
4709 assert_eq!(id.role, "ops"); }
4711
4712 #[tokio::test]
4717 async fn jwks_refresh_deduplication() {
4718 let kid = "test-dedup";
4721 let (pem, jwks) = generate_test_keypair(kid);
4722
4723 let mock_server = wiremock::MockServer::start().await;
4724 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4725 .and(wiremock::matchers::path("/jwks.json"))
4726 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4727 .expect(1) .mount(&mock_server)
4729 .await;
4730
4731 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4732 let config = test_config(&jwks_uri);
4733 let cache = Arc::new(test_cache(&config));
4734
4735 let token = mint_token(
4737 &pem,
4738 kid,
4739 "https://auth.test.local",
4740 "https://mcp.test.local/mcp",
4741 "concurrent-bot",
4742 "mcp:read",
4743 );
4744
4745 let mut handles = Vec::new();
4746 for _ in 0..5 {
4747 let c = Arc::clone(&cache);
4748 let t = token.clone();
4749 handles.push(tokio::spawn(async move { c.validate_token(&t).await }));
4750 }
4751
4752 for h in handles {
4753 let result = h.await.unwrap();
4754 assert!(result.is_some(), "all concurrent requests should succeed");
4755 }
4756
4757 }
4759
4760 #[tokio::test]
4761 async fn jwks_refresh_cooldown_blocks_rapid_requests() {
4762 let kid = "test-cooldown";
4765 let (_pem, jwks) = generate_test_keypair(kid);
4766
4767 let mock_server = wiremock::MockServer::start().await;
4768 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4769 .and(wiremock::matchers::path("/jwks.json"))
4770 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4771 .expect(1) .mount(&mock_server)
4773 .await;
4774
4775 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4776 let config = test_config(&jwks_uri);
4777 let cache = test_cache(&config);
4778
4779 let fake_token1 =
4781 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig";
4782 let _ = cache.validate_token(fake_token1).await;
4783
4784 let fake_token2 =
4787 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig";
4788 let _ = cache.validate_token(fake_token2).await;
4789
4790 let fake_token3 =
4792 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig";
4793 let _ = cache.validate_token(fake_token3).await;
4794
4795 }
4797
4798 fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
4801 OAuthProxyConfig {
4802 authorize_url: "https://example.invalid/auth".into(),
4803 token_url: token_url.into(),
4804 client_id: "mcp-client".into(),
4805 client_secret: Some(secrecy::SecretString::from("shh".to_owned())),
4806 introspection_url: None,
4807 revocation_url: None,
4808 expose_admin_endpoints: false,
4809 require_auth_on_admin_endpoints: false,
4810 allow_unauthenticated_admin_endpoints: false,
4811 }
4812 }
4813
4814 fn test_http_client() -> OauthHttpClient {
4817 rustls::crypto::ring::default_provider()
4818 .install_default()
4819 .ok();
4820 let config = OAuthConfig::builder(
4821 "https://auth.test.local",
4822 "https://mcp.test.local/mcp",
4823 "https://auth.test.local/.well-known/jwks.json",
4824 )
4825 .allow_http_oauth_urls(true)
4826 .build();
4827 OauthHttpClient::with_config(&config)
4828 .expect("build test http client")
4829 .__test_allow_loopback_ssrf()
4830 }
4831
4832 #[tokio::test]
4833 async fn introspect_proxies_and_injects_client_credentials() {
4834 use wiremock::matchers::{body_string_contains, method, path};
4835
4836 let mock_server = wiremock::MockServer::start().await;
4837 wiremock::Mock::given(method("POST"))
4838 .and(path("/introspect"))
4839 .and(body_string_contains("client_id=mcp-client"))
4840 .and(body_string_contains("client_secret=shh"))
4841 .and(body_string_contains("token=abc"))
4842 .respond_with(
4843 wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
4844 "active": true,
4845 "scope": "read"
4846 })),
4847 )
4848 .expect(1)
4849 .mount(&mock_server)
4850 .await;
4851
4852 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4853 proxy.introspection_url = Some(format!("{}/introspect", mock_server.uri()));
4854
4855 let http = test_http_client();
4856 let resp = handle_introspect(&http, &proxy, "token=abc").await;
4857 assert_eq!(resp.status(), 200);
4858 }
4859
4860 #[tokio::test]
4861 async fn introspect_returns_404_when_not_configured() {
4862 let proxy = proxy_cfg("https://example.invalid/token");
4863 let http = test_http_client();
4864 let resp = handle_introspect(&http, &proxy, "token=abc").await;
4865 assert_eq!(resp.status(), 404);
4866 }
4867
4868 #[tokio::test]
4869 async fn revoke_proxies_and_returns_upstream_status() {
4870 use wiremock::matchers::{method, path};
4871
4872 let mock_server = wiremock::MockServer::start().await;
4873 wiremock::Mock::given(method("POST"))
4874 .and(path("/revoke"))
4875 .respond_with(wiremock::ResponseTemplate::new(200))
4876 .expect(1)
4877 .mount(&mock_server)
4878 .await;
4879
4880 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4881 proxy.revocation_url = Some(format!("{}/revoke", mock_server.uri()));
4882
4883 let http = test_http_client();
4884 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4885 assert_eq!(resp.status(), 200);
4886 }
4887
4888 #[tokio::test]
4889 async fn revoke_returns_404_when_not_configured() {
4890 let proxy = proxy_cfg("https://example.invalid/token");
4891 let http = test_http_client();
4892 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4893 assert_eq!(resp.status(), 404);
4894 }
4895
4896 #[test]
4897 fn metadata_advertises_endpoints_only_when_configured() {
4898 let mut cfg = test_config("https://auth.test.local/jwks.json");
4899 let m = authorization_server_metadata("https://mcp.local", &cfg);
4901 assert!(m.get("introspection_endpoint").is_none());
4902 assert!(m.get("revocation_endpoint").is_none());
4903
4904 let mut proxy = proxy_cfg("https://upstream.local/token");
4907 proxy.introspection_url = Some("https://upstream.local/introspect".into());
4908 proxy.revocation_url = Some("https://upstream.local/revoke".into());
4909 cfg.proxy = Some(proxy);
4910 let m = authorization_server_metadata("https://mcp.local", &cfg);
4911 assert!(
4912 m.get("introspection_endpoint").is_none(),
4913 "introspection must not be advertised when expose_admin_endpoints=false"
4914 );
4915 assert!(
4916 m.get("revocation_endpoint").is_none(),
4917 "revocation must not be advertised when expose_admin_endpoints=false"
4918 );
4919
4920 if let Some(p) = cfg.proxy.as_mut() {
4922 p.expose_admin_endpoints = true;
4923 p.revocation_url = None;
4924 }
4925 let m = authorization_server_metadata("https://mcp.local", &cfg);
4926 assert_eq!(
4927 m["introspection_endpoint"],
4928 serde_json::Value::String("https://mcp.local/introspect".into())
4929 );
4930 assert!(m.get("revocation_endpoint").is_none());
4931
4932 if let Some(p) = cfg.proxy.as_mut() {
4934 p.revocation_url = Some("https://upstream.local/revoke".into());
4935 }
4936 let m = authorization_server_metadata("https://mcp.local", &cfg);
4937 assert_eq!(
4938 m["revocation_endpoint"],
4939 serde_json::Value::String("https://mcp.local/revoke".into())
4940 );
4941 }
4942
4943 fn https_cfg_with_tx(tx: TokenExchangeConfig) -> OAuthConfig {
4946 let mut cfg = validation_https_config();
4947 cfg.token_exchange = Some(tx);
4948 cfg
4949 }
4950
4951 fn tx_with(
4952 client_secret: Option<&str>,
4953 client_cert: Option<ClientCertConfig>,
4954 ) -> TokenExchangeConfig {
4955 TokenExchangeConfig::new(
4956 "https://idp.example.com/token".into(),
4957 "client".into(),
4958 client_secret.map(|s| secrecy::SecretString::new(s.into())),
4959 client_cert,
4960 "downstream".into(),
4961 )
4962 }
4963
4964 #[test]
4965 fn validate_rejects_token_exchange_without_client_auth() {
4966 let cfg = https_cfg_with_tx(tx_with(None, None));
4967 let err = cfg
4968 .validate()
4969 .expect_err("token_exchange without client auth must be rejected");
4970 let msg = err.to_string();
4971 assert!(
4972 msg.contains("requires client authentication"),
4973 "error must explain missing client auth; got {msg:?}"
4974 );
4975 }
4976
4977 #[test]
4978 fn validate_rejects_token_exchange_with_both_secret_and_cert() {
4979 let cc = ClientCertConfig {
4980 cert_path: PathBuf::from("/nonexistent/cert.pem"),
4981 key_path: PathBuf::from("/nonexistent/key.pem"),
4982 };
4983 let cfg = https_cfg_with_tx(tx_with(Some("s"), Some(cc)));
4984 let err = cfg
4985 .validate()
4986 .expect_err("client_secret + client_cert must be rejected");
4987 let msg = err.to_string();
4988 assert!(
4989 msg.contains("mutually") && msg.contains("exclusive"),
4990 "error must explain mutual exclusion; got {msg:?}"
4991 );
4992 }
4993
4994 #[cfg(not(feature = "oauth-mtls-client"))]
4995 #[test]
4996 fn validate_rejects_client_cert_without_feature() {
4997 let cc = ClientCertConfig {
4998 cert_path: PathBuf::from("/nonexistent/cert.pem"),
4999 key_path: PathBuf::from("/nonexistent/key.pem"),
5000 };
5001 let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));
5002 let err = cfg
5003 .validate()
5004 .expect_err("client_cert without feature must be rejected");
5005 assert!(
5006 err.to_string().contains("oauth-mtls-client"),
5007 "error must reference the cargo feature; got {err}"
5008 );
5009 }
5010
5011 #[cfg(feature = "oauth-mtls-client")]
5012 #[test]
5013 fn validate_rejects_missing_client_cert_files() {
5014 let cc = ClientCertConfig {
5015 cert_path: PathBuf::from("/nonexistent/cert.pem"),
5016 key_path: PathBuf::from("/nonexistent/key.pem"),
5017 };
5018 let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));
5019 let err = cfg
5020 .validate()
5021 .expect_err("missing cert file must be rejected");
5022 assert!(
5023 err.to_string().contains("unreadable"),
5024 "error must call out unreadable file; got {err}"
5025 );
5026 }
5027
5028 #[cfg(feature = "oauth-mtls-client")]
5029 #[test]
5030 fn validate_rejects_malformed_client_cert_pem() {
5031 let dir = std::env::temp_dir();
5032 let cert = dir.join(format!("rmcp-mtls-bad-cert-{}.pem", std::process::id()));
5033 let key = dir.join(format!("rmcp-mtls-bad-key-{}.pem", std::process::id()));
5034 std::fs::write(&cert, b"not a real PEM").expect("write tmp cert");
5035 std::fs::write(&key, b"not a real PEM either").expect("write tmp key");
5036 let cc = ClientCertConfig {
5037 cert_path: cert.clone(),
5038 key_path: key.clone(),
5039 };
5040 let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));
5041 let err = cfg.validate().expect_err("malformed PEM must be rejected");
5042 let _ = std::fs::remove_file(&cert);
5043 let _ = std::fs::remove_file(&key);
5044 assert!(
5045 err.to_string().contains("PEM parse failed"),
5046 "error must call out PEM parse failure; got {err}"
5047 );
5048 }
5049
5050 #[cfg(feature = "oauth-mtls-client")]
5051 fn write_self_signed_pem() -> (PathBuf, PathBuf) {
5052 let cert = rcgen::generate_simple_self_signed(vec!["client.test".into()]).expect("rcgen");
5053 let dir = std::env::temp_dir();
5054 let pid = std::process::id();
5055 let nonce: u64 = rand::random();
5056 let cert_path = dir.join(format!("rmcp-mtls-cert-{pid}-{nonce}.pem"));
5057 let key_path = dir.join(format!("rmcp-mtls-key-{pid}-{nonce}.pem"));
5058 std::fs::write(&cert_path, cert.cert.pem()).expect("write cert");
5059 std::fs::write(&key_path, cert.signing_key.serialize_pem()).expect("write key");
5060 (cert_path, key_path)
5061 }
5062
5063 #[cfg(feature = "oauth-mtls-client")]
5064 fn install_test_crypto_provider() {
5065 let _ = rustls::crypto::ring::default_provider().install_default();
5066 }
5067
5068 #[cfg(feature = "oauth-mtls-client")]
5069 #[test]
5070 fn validate_accepts_well_formed_client_cert() {
5071 install_test_crypto_provider();
5072 let (cert_path, key_path) = write_self_signed_pem();
5073 let cc = ClientCertConfig {
5074 cert_path: cert_path.clone(),
5075 key_path: key_path.clone(),
5076 };
5077 let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));
5078 let res = cfg.validate();
5079 let _ = std::fs::remove_file(&cert_path);
5080 let _ = std::fs::remove_file(&key_path);
5081 res.expect("well-formed cert+key must validate");
5082 }
5083
5084 #[cfg(feature = "oauth-mtls-client")]
5085 #[test]
5086 fn client_for_returns_cached_mtls_client() {
5087 install_test_crypto_provider();
5088 let (cert_path, key_path) = write_self_signed_pem();
5089 let cc = ClientCertConfig {
5090 cert_path: cert_path.clone(),
5091 key_path: key_path.clone(),
5092 };
5093 let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));
5094 let http = OauthHttpClient::with_config(&cfg).expect("build mtls client");
5095 let tx_ref = cfg.token_exchange.as_ref().expect("tx set");
5096 let cert_client = http.client_for(tx_ref);
5097 let inner_client = http.client_for(&tx_with(Some("s"), None));
5098 let _ = std::fs::remove_file(&cert_path);
5099 let _ = std::fs::remove_file(&key_path);
5100 assert!(
5101 !std::ptr::eq(cert_client, inner_client),
5102 "client_for must return distinct clients for cert vs no-cert configs"
5103 );
5104 }
5105
5106 #[cfg(feature = "oauth-mtls-client")]
5107 #[test]
5108 fn client_for_falls_back_to_inner_when_cache_miss() {
5109 install_test_crypto_provider();
5110 let cfg = validation_https_config();
5111 let http = OauthHttpClient::with_config(&cfg).expect("build client");
5112 let unrelated_cc = ClientCertConfig {
5113 cert_path: PathBuf::from("/cache/miss/cert.pem"),
5114 key_path: PathBuf::from("/cache/miss/key.pem"),
5115 };
5116 let tx_unknown = tx_with(None, Some(unrelated_cc));
5117 let fallback = http.client_for(&tx_unknown);
5118 let inner = http.client_for(&tx_with(Some("s"), None));
5119 assert!(
5120 std::ptr::eq(fallback, inner),
5121 "cache miss must fall back to inner client"
5122 );
5123 }
5124}