1use std::{
17 collections::HashMap,
18 path::PathBuf,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering},
22 },
23 time::{Duration, Instant},
24};
25
26use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
27use serde::Deserialize;
28use tokio::{net::lookup_host, sync::RwLock};
29
30use crate::auth::{AuthIdentity, AuthMethod};
31
32fn evaluate_oauth_redirect(
58 attempt: &reqwest::redirect::Attempt<'_>,
59 allow_http: bool,
60 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
61) -> Result<(), String> {
62 let prev_https = attempt
63 .previous()
64 .last()
65 .is_some_and(|prev| prev.scheme() == "https");
66 let target_url = attempt.url();
67 let dest_scheme = target_url.scheme();
68 if dest_scheme != "https" {
69 if prev_https {
70 return Err("redirect downgrades https -> http".to_owned());
71 }
72 if !allow_http || dest_scheme != "http" {
73 return Err("redirect to non-HTTP(S) URL refused".to_owned());
74 }
75 }
76 if let Some(reason) = crate::ssrf::redirect_target_reason_with_allowlist(target_url, allowlist)
77 {
78 return Err(format!("redirect target forbidden: {reason}"));
79 }
80 if attempt.previous().len() >= 2 {
81 return Err("too many redirects (max 2)".to_owned());
82 }
83 Ok(())
84}
85
86#[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
100async fn screen_oauth_target_with_test_override(
101 url: &str,
102 allow_http: bool,
103 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
104 #[cfg(any(test, feature = "test-helpers"))] test_allow_loopback_ssrf: bool,
105) -> Result<(), crate::error::McpxError> {
106 let parsed = check_oauth_url("oauth target", url, allow_http)?;
107 #[cfg(any(test, feature = "test-helpers"))]
108 if test_allow_loopback_ssrf {
109 return Ok(());
110 }
111 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
112 return Err(crate::error::McpxError::Config(format!(
113 "OAuth target forbidden ({reason}): {url}"
114 )));
115 }
116
117 let host = parsed.host_str().ok_or_else(|| {
118 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
119 })?;
120 let port = parsed.port_or_known_default().ok_or_else(|| {
121 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
122 })?;
123
124 let addrs = lookup_host((host, port)).await.map_err(|error| {
125 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
126 })?;
127
128 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
129 let mut any_addr = false;
130 for addr in addrs {
131 any_addr = true;
132 let ip = addr.ip();
133 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
134 if reason == "cloud_metadata" {
137 return Err(crate::error::McpxError::Config(format!(
138 "OAuth target resolved to blocked IP ({reason}): {url}"
139 )));
140 }
141 if allowlist.is_empty() {
145 return Err(crate::error::McpxError::Config(format!(
146 "OAuth target resolved to blocked IP ({reason}): {url}"
147 )));
148 }
149 if host_allowed || allowlist.ip_allowed(ip) {
151 continue;
152 }
153 return Err(crate::error::McpxError::Config(format!(
154 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
155 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
156 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
157 URL: {url}"
158 )));
159 }
160 }
161 if !any_addr {
162 return Err(crate::error::McpxError::Config(format!(
163 "OAuth target DNS resolution returned no addresses: {url}"
164 )));
165 }
166
167 Ok(())
168}
169
170async fn screen_oauth_target(
171 url: &str,
172 allow_http: bool,
173 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
174) -> Result<(), crate::error::McpxError> {
175 #[cfg(any(test, feature = "test-helpers"))]
176 {
177 screen_oauth_target_with_test_override(url, allow_http, allowlist, false).await
178 }
179 #[cfg(not(any(test, feature = "test-helpers")))]
180 {
181 let parsed = check_oauth_url("oauth target", url, allow_http)?;
182 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
183 return Err(crate::error::McpxError::Config(format!(
184 "OAuth target forbidden ({reason}): {url}"
185 )));
186 }
187
188 let host = parsed.host_str().ok_or_else(|| {
189 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
190 })?;
191 let port = parsed.port_or_known_default().ok_or_else(|| {
192 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
193 })?;
194
195 let addrs = lookup_host((host, port)).await.map_err(|error| {
196 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
197 })?;
198
199 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
200 let mut any_addr = false;
201 for addr in addrs {
202 any_addr = true;
203 let ip = addr.ip();
204 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
205 if reason == "cloud_metadata" {
206 return Err(crate::error::McpxError::Config(format!(
207 "OAuth target resolved to blocked IP ({reason}): {url}"
208 )));
209 }
210 if allowlist.is_empty() {
211 return Err(crate::error::McpxError::Config(format!(
212 "OAuth target resolved to blocked IP ({reason}): {url}"
213 )));
214 }
215 if host_allowed || allowlist.ip_allowed(ip) {
216 continue;
217 }
218 return Err(crate::error::McpxError::Config(format!(
219 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
220 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
221 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
222 URL: {url}"
223 )));
224 }
225 }
226 if !any_addr {
227 return Err(crate::error::McpxError::Config(format!(
228 "OAuth target DNS resolution returned no addresses: {url}"
229 )));
230 }
231
232 Ok(())
233 }
234}
235
/// HTTP client for outbound OAuth traffic.
///
/// Bundles the underlying `reqwest::Client` with the scheme policy and the
/// compiled SSRF allowlist that `send_screened` consults before each request.
#[derive(Clone)]
pub struct OauthHttpClient {
    /// Default client, built without a client certificate.
    inner: reqwest::Client,
    /// Whether plain `http://` targets and redirects are accepted.
    allow_http: bool,
    /// Compiled operator allowlist shared with the DNS resolver and the
    /// redirect policy.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Pre-built mTLS clients, keyed by the certificate/key paths they were
    /// built from.
    #[cfg(feature = "oauth-mtls-client")]
    mtls_clients: Arc<HashMap<MtlsClientKey, reqwest::Client>>,
    /// Test-only flag that relaxes SSRF screening; also handed to the DNS
    /// resolver when the client is built.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
299
/// Lookup key for a pre-built mTLS client: the certificate and private-key
/// paths exactly as they appear in the configuration.
#[cfg(feature = "oauth-mtls-client")]
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct MtlsClientKey {
    /// Path of the client certificate file.
    cert_path: PathBuf,
    /// Path of the client private-key file.
    key_path: PathBuf,
}
309
impl OauthHttpClient {
    /// Builds a client that honours `config` (CA bundle, `http` toggle, SSRF
    /// allowlist and, with the `oauth-mtls-client` feature, the client
    /// certificate).
    ///
    /// # Errors
    /// Returns `McpxError::Startup` when the allowlist does not compile, a
    /// certificate file cannot be read or parsed, or `reqwest` fails to build
    /// a client.
    pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
        Self::build(Some(config))
    }

    /// Builds a client with no configuration: `http` is refused, the allowlist
    /// is empty and no extra CA certificate is trusted.
    ///
    /// # Errors
    /// Returns `McpxError::Startup` when `reqwest` fails to build the client.
    #[deprecated(
        since = "1.2.1",
        note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
    )]
    pub fn new() -> Result<Self, crate::error::McpxError> {
        Self::build(None)
    }

    /// Shared constructor behind `with_config` and `new`.
    fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
        // Without a config, plain http is never allowed.
        let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);

        // An absent allowlist section compiles to the default (empty) allowlist.
        let allowlist = match config.and_then(|c| c.ssrf_allowlist.as_ref()) {
            Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
                crate::error::McpxError::Startup(format!("oauth http client: {e}"))
            })?),
            None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
        };

        // Separate handle that is moved into the redirect-policy closure below.
        let redirect_allowlist = Arc::clone(&allowlist);

        // The bypass type differs per build: a shared atomic flag in test
        // builds, the unit type otherwise.
        #[cfg(any(test, feature = "test-helpers"))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
            Arc::new(AtomicBool::new(false));
        #[cfg(not(any(test, feature = "test-helpers")))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();

        // DNS resolver that receives the allowlist and the bypass handle.
        let resolver: Arc<dyn reqwest::dns::Resolve> =
            Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
                Arc::clone(&allowlist),
                #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
                test_bypass.clone(),
            ));

        let mut builder = reqwest::Client::builder()
            .no_proxy()
            .dns_resolver(Arc::clone(&resolver))
            .connect_timeout(Duration::from_secs(10))
            .timeout(Duration::from_secs(30))
            .redirect(reqwest::redirect::Policy::custom(move |attempt| {
                // Every hop is re-evaluated; rejections are logged and turned
                // into a request error.
                match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
                    Ok(()) => attempt.follow(),
                    Err(reason) => {
                        tracing::warn!(
                            reason = %reason,
                            target = %attempt.url(),
                            "oauth redirect rejected"
                        );
                        attempt.error(reason)
                    }
                }
            }));

        // Optional extra trust anchor, read from a PEM file.
        if let Some(cfg) = config
            && let Some(ref ca_path) = cfg.ca_cert_path
        {
            let pem = std::fs::read(ca_path).map_err(|e| {
                crate::error::McpxError::Startup(format!(
                    "oauth http client: read ca_cert_path {}: {e}",
                    ca_path.display()
                ))
            })?;
            let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
                crate::error::McpxError::Startup(format!(
                    "oauth http client: parse ca_cert_path {}: {e}",
                    ca_path.display()
                ))
            })?;
            builder = builder.add_root_certificate(cert);
        }

        let inner = builder.build().map_err(|e| {
            crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
        })?;

        #[cfg(feature = "oauth-mtls-client")]
        let mtls_clients = build_mtls_clients(config, &allowlist, &test_bypass)?;

        Ok(Self {
            inner,
            allow_http,
            allowlist,
            #[cfg(feature = "oauth-mtls-client")]
            mtls_clients,
            #[cfg(any(test, feature = "test-helpers"))]
            test_allow_loopback_ssrf: test_bypass,
        })
    }

    /// Screens `url` and, if it passes, sends `request`.
    ///
    /// `url` is used for screening and for the error message; the request
    /// itself is sent as built by the caller.
    ///
    /// # Errors
    /// Returns the screening error, or `McpxError::Config` wrapping the
    /// `reqwest` send error.
    async fn send_screened(
        &self,
        url: &str,
        request: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, crate::error::McpxError> {
        // Test builds: honour the bypass flag when it has been switched on.
        #[cfg(any(test, feature = "test-helpers"))]
        if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
            screen_oauth_target_with_test_override(url, self.allow_http, &self.allowlist, true)
                .await?;
        } else {
            screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
        }
        #[cfg(not(any(test, feature = "test-helpers")))]
        screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
        request.send().await.map_err(|error| {
            crate::error::McpxError::Config(format!("oauth request {url}: {error}"))
        })
    }

    /// Test helper: switches the SSRF bypass flag on and returns the client.
    ///
    /// The flag is the same shared atomic that was handed to the DNS resolver
    /// in `build`, so the store is visible there too.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_allow_loopback_ssrf(self) -> Self {
        self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
        self
    }

    /// Test helper: issues a GET on the default client directly, without the
    /// `send_screened` pre-flight check.
    // NOTE(review): unlike the other `__test_*` helpers this one carries no
    // `cfg` gate in the visible source -- confirm that is intentional.
    #[doc(hidden)]
    pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
        self.inner.get(url).send().await
    }

    /// Test helper: exposes the default `reqwest::Client`.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_inner_client(&self) -> &reqwest::Client {
        &self.inner
    }

    /// Picks the client for a token-exchange request: the pre-built mTLS
    /// client matching `cfg.client_cert`, or the default client when no
    /// certificate is configured or no matching client was built.
    #[cfg(feature = "oauth-mtls-client")]
    fn client_for(&self, cfg: &TokenExchangeConfig) -> &reqwest::Client {
        if let Some(cc) = &cfg.client_cert {
            let key = MtlsClientKey {
                cert_path: cc.cert_path.clone(),
                key_path: cc.key_path.clone(),
            };
            if let Some(client) = self.mtls_clients.get(&key) {
                return client;
            }
        }
        &self.inner
    }

    /// Without the `oauth-mtls-client` feature every request uses the default
    /// client.
    #[cfg(not(feature = "oauth-mtls-client"))]
    fn client_for(&self, _cfg: &TokenExchangeConfig) -> &reqwest::Client {
        &self.inner
    }
}
552
553impl std::fmt::Debug for OauthHttpClient {
554 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 f.debug_struct("OauthHttpClient").finish_non_exhaustive()
556 }
557}
558
/// Raw (uncompiled) operator allowlist for OAuth SSRF screening, as read from
/// configuration. Both lists default to empty when omitted.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct OAuthSsrfAllowlist {
    /// Bare DNS hostnames. Entries containing a scheme, port, path, userinfo,
    /// query or fragment, and literal IPs, are rejected when the list is
    /// compiled.
    #[serde(default)]
    pub hosts: Vec<String>,
    /// CIDR entries, parsed with `crate::ssrf::CidrEntry::parse` when the
    /// list is compiled.
    #[serde(default)]
    pub cidrs: Vec<String>,
}
635
636fn compile_oauth_ssrf_allowlist(
643 raw: &OAuthSsrfAllowlist,
644) -> Result<crate::ssrf::CompiledSsrfAllowlist, String> {
645 let mut hosts: Vec<String> = Vec::with_capacity(raw.hosts.len());
646 for (idx, entry) in raw.hosts.iter().enumerate() {
647 let trimmed = entry.trim();
648 if trimmed.is_empty() {
649 return Err(format!("oauth.ssrf_allowlist.hosts[{idx}]: empty entry"));
650 }
651 if trimmed.contains([':', '/', '@', '?', '#']) {
655 return Err(format!(
656 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: must be a bare DNS hostname \
657 (no scheme, port, path, userinfo, query, or fragment)"
658 ));
659 }
660 match url::Host::parse(trimmed) {
661 Ok(url::Host::Domain(_)) => {}
662 Ok(url::Host::Ipv4(_) | url::Host::Ipv6(_)) => {
663 return Err(format!(
664 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: literal IPs are forbidden \
665 here -- list them via oauth.ssrf_allowlist.cidrs instead"
666 ));
667 }
668 Err(e) => {
669 return Err(format!(
670 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: invalid hostname: {e}"
671 ));
672 }
673 }
674 hosts.push(trimmed.to_ascii_lowercase());
675 }
676 hosts.sort();
677 hosts.dedup();
678
679 let mut cidrs = Vec::with_capacity(raw.cidrs.len());
680 for (idx, entry) in raw.cidrs.iter().enumerate() {
681 let parsed = crate::ssrf::CidrEntry::parse(entry)
682 .map_err(|e| format!("oauth.ssrf_allowlist.cidrs[{idx}]: {e}"))?;
683 cidrs.push(parsed);
684 }
685
686 Ok(crate::ssrf::CompiledSsrfAllowlist::new(hosts, cidrs))
687}
688
/// OAuth configuration, deserialised from the application config.
///
/// Call `validate` after loading to check URLs, client authentication, the
/// SSRF allowlist and the cache TTL.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
    /// Issuer URL. Must pass the OAuth URL checks in `validate`.
    pub issuer: String,
    /// Expected audience value.
    // NOTE(review): presumably compared against the token `aud` claim -- the
    // comparison is not visible in this section.
    pub audience: String,
    /// JWKS endpoint URL. Must pass the OAuth URL checks in `validate`.
    pub jwks_uri: String,
    /// Scope-to-role mappings. Defaults to empty.
    #[serde(default)]
    pub scopes: Vec<ScopeMapping>,
    /// Optional name of the claim that carries role values.
    pub role_claim: Option<String>,
    /// Claim-value-to-role mappings. Defaults to empty.
    #[serde(default)]
    pub role_mappings: Vec<RoleMapping>,
    /// JWKS cache lifetime as a humantime string (e.g. `"10m"`, `"1h30m"`).
    /// Defaults to `"10m"`.
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl: String,
    /// Optional OAuth proxy settings.
    pub proxy: Option<OAuthProxyConfig>,
    /// Optional token-exchange settings.
    pub token_exchange: Option<TokenExchangeConfig>,
    /// Optional PEM file added as an extra root certificate to the OAuth HTTP
    /// clients.
    #[serde(default)]
    pub ca_cert_path: Option<PathBuf>,
    /// Accept `http://` OAuth URLs (and redirects). Defaults to `false`.
    #[serde(default)]
    pub allow_http_oauth_urls: bool,
    /// Optional operator allowlist for SSRF screening of OAuth targets.
    #[serde(default)]
    pub ssrf_allowlist: Option<OAuthSsrfAllowlist>,
    /// Upper bound on JWKS keys. Defaults to 256.
    #[serde(default = "default_max_jwks_keys")]
    pub max_jwks_keys: usize,
    /// Legacy strictness flag; consulted only when
    /// `audience_validation_mode` is `None`.
    #[serde(default)]
    #[deprecated(
        since = "1.7.0",
        note = "use `audience_validation_mode` instead; this field is consulted only when `audience_validation_mode` is None"
    )]
    pub strict_audience_validation: bool,
    /// Explicit audience validation mode; overrides the legacy flag when set.
    #[serde(default)]
    pub audience_validation_mode: Option<AudienceValidationMode>,
    /// Upper bound on the JWKS response size in bytes. Defaults to 1 MiB.
    #[serde(default = "default_jwks_max_bytes")]
    pub jwks_max_response_bytes: u64,
}
799
/// Serde default for `jwks_cache_ttl`: ten minutes, in humantime notation.
fn default_jwks_cache_ttl() -> String {
    String::from("10m")
}
803
/// Serde default for `max_jwks_keys`.
const fn default_max_jwks_keys() -> usize {
    const DEFAULT_KEY_LIMIT: usize = 256;
    DEFAULT_KEY_LIMIT
}
807
/// Serde default for `jwks_max_response_bytes`: one mebibyte.
const fn default_jwks_max_bytes() -> u64 {
    const ONE_MIB: u64 = 1 << 20;
    ONE_MIB
}
811
/// How audience validation failures are treated. Deserialised from
/// snake_case names (`permissive`, `warn`, `strict`).
// NOTE(review): the exact behaviour of each mode is implemented outside this
// section -- the per-variant notes below are inferred from the names; confirm.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AudienceValidationMode {
    /// Presumably the most lenient mode.
    Permissive,
    /// The default mode; also what the legacy flag resolves to when `false`.
    #[default]
    Warn,
    /// What the legacy `strict_audience_validation = true` flag resolves to.
    Strict,
}
846
847impl Default for OAuthConfig {
848 fn default() -> Self {
849 Self {
850 issuer: String::new(),
851 audience: String::new(),
852 jwks_uri: String::new(),
853 scopes: Vec::new(),
854 role_claim: None,
855 role_mappings: Vec::new(),
856 jwks_cache_ttl: default_jwks_cache_ttl(),
857 proxy: None,
858 token_exchange: None,
859 ca_cert_path: None,
860 allow_http_oauth_urls: false,
861 max_jwks_keys: default_max_jwks_keys(),
862 #[allow(
863 deprecated,
864 reason = "default-construct deprecated field for backward compat"
865 )]
866 strict_audience_validation: false,
867 audience_validation_mode: None,
868 jwks_max_response_bytes: default_jwks_max_bytes(),
869 ssrf_allowlist: None,
870 }
871 }
872}
873
874impl OAuthConfig {
875 #[must_use]
881 pub fn effective_audience_validation_mode(&self) -> AudienceValidationMode {
882 if let Some(mode) = self.audience_validation_mode {
883 return mode;
884 }
885 #[allow(deprecated, reason = "intentional: legacy flag resolution path")]
886 if self.strict_audience_validation {
887 AudienceValidationMode::Strict
888 } else {
889 AudienceValidationMode::Warn
890 }
891 }
892
893 pub fn builder(
899 issuer: impl Into<String>,
900 audience: impl Into<String>,
901 jwks_uri: impl Into<String>,
902 ) -> OAuthConfigBuilder {
903 OAuthConfigBuilder {
904 inner: Self {
905 issuer: issuer.into(),
906 audience: audience.into(),
907 jwks_uri: jwks_uri.into(),
908 ..Self::default()
909 },
910 }
911 }
912
913 pub fn validate(&self) -> Result<(), crate::error::McpxError> {
929 let allow_http = self.allow_http_oauth_urls;
930 let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
931 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
932 return Err(crate::error::McpxError::Config(format!(
933 "oauth.issuer forbidden ({reason})"
934 )));
935 }
936 let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
937 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
938 return Err(crate::error::McpxError::Config(format!(
939 "oauth.jwks_uri forbidden ({reason})"
940 )));
941 }
942 if let Some(proxy) = &self.proxy {
943 let url = check_oauth_url(
944 "oauth.proxy.authorize_url",
945 &proxy.authorize_url,
946 allow_http,
947 )?;
948 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
949 return Err(crate::error::McpxError::Config(format!(
950 "oauth.proxy.authorize_url forbidden ({reason})"
951 )));
952 }
953 let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
954 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
955 return Err(crate::error::McpxError::Config(format!(
956 "oauth.proxy.token_url forbidden ({reason})"
957 )));
958 }
959 if let Some(url) = &proxy.introspection_url {
960 let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
961 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
962 return Err(crate::error::McpxError::Config(format!(
963 "oauth.proxy.introspection_url forbidden ({reason})"
964 )));
965 }
966 }
967 if let Some(url) = &proxy.revocation_url {
968 let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
969 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
970 return Err(crate::error::McpxError::Config(format!(
971 "oauth.proxy.revocation_url forbidden ({reason})"
972 )));
973 }
974 }
975 if proxy.expose_admin_endpoints
982 && !proxy.require_auth_on_admin_endpoints
983 && !proxy.allow_unauthenticated_admin_endpoints
984 {
985 return Err(crate::error::McpxError::Config(
986 "oauth.proxy: expose_admin_endpoints = true requires \
987 require_auth_on_admin_endpoints = true (recommended) \
988 or allow_unauthenticated_admin_endpoints = true \
989 (explicit opt-out, only safe behind an authenticated \
990 reverse proxy)"
991 .into(),
992 ));
993 }
994 }
995 if let Some(tx) = &self.token_exchange {
996 let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
997 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
998 return Err(crate::error::McpxError::Config(format!(
999 "oauth.token_exchange.token_url forbidden ({reason})"
1000 )));
1001 }
1002 validate_token_exchange_client_auth(tx)?;
1005 }
1006 if let Some(raw) = &self.ssrf_allowlist {
1010 let compiled = compile_oauth_ssrf_allowlist(raw).map_err(|e| {
1011 crate::error::McpxError::Config(format!("oauth.ssrf_allowlist: {e}"))
1012 })?;
1013 if !compiled.is_empty() {
1014 tracing::warn!(
1015 host_count = compiled.host_count(),
1016 cidr_count = compiled.cidr_count(),
1017 "oauth.ssrf_allowlist is configured: private/loopback OAuth/JWKS targets \
1018 are now reachable. Cloud-metadata addresses remain blocked. \
1019 See SECURITY.md \"Operator allowlist\"."
1020 );
1021 }
1022 }
1023 humantime::parse_duration(&self.jwks_cache_ttl).map_err(|e| {
1026 crate::error::McpxError::Config(format!(
1027 "oauth.jwks_cache_ttl {:?} is not a valid humantime duration (e.g. \"10m\", \"1h30m\"): {e}",
1028 self.jwks_cache_ttl
1029 ))
1030 })?;
1031 Ok(())
1032 }
1033}
1034
1035fn validate_token_exchange_client_auth(
1041 tx: &TokenExchangeConfig,
1042) -> Result<(), crate::error::McpxError> {
1043 match (&tx.client_cert, tx.client_secret.is_some()) {
1044 (Some(_), true) => Err(crate::error::McpxError::Config(
1045 "oauth.token_exchange: client_cert and client_secret are mutually \
1046 exclusive (RFC 8705 ยง2). Set exactly one."
1047 .into(),
1048 )),
1049 (None, false) => Err(crate::error::McpxError::Config(
1050 "oauth.token_exchange: token exchange requires client authentication. \
1051 Set either client_secret (RFC 6749 ยง2.3.1) or client_cert (RFC 8705 ยง2)."
1052 .into(),
1053 )),
1054 (Some(cc), false) => validate_client_cert_config(cc),
1055 (None, true) => Ok(()),
1056 }
1057}
1058
1059fn validate_client_cert_config(cc: &ClientCertConfig) -> Result<(), crate::error::McpxError> {
1072 #[cfg(not(feature = "oauth-mtls-client"))]
1073 {
1074 let _ = cc;
1075 Err(crate::error::McpxError::Config(
1076 "oauth.token_exchange.client_cert requires the `oauth-mtls-client` cargo feature; \
1077 rebuild rmcp-server-kit with --features oauth-mtls-client (or have your \
1078 application crate enable it via `rmcp-server-kit/oauth-mtls-client`), or remove \
1079 the field"
1080 .into(),
1081 ))
1082 }
1083 #[cfg(feature = "oauth-mtls-client")]
1084 {
1085 let cert_bytes = std::fs::read(&cc.cert_path).map_err(|e| {
1086 tracing::warn!(error = %e, path = %cc.cert_path.display(), "client cert read failed");
1087 crate::error::McpxError::Config(format!(
1088 "oauth.token_exchange.client_cert.cert_path unreadable: {}",
1089 cc.cert_path.display()
1090 ))
1091 })?;
1092 let key_bytes = std::fs::read(&cc.key_path).map_err(|e| {
1093 tracing::warn!(error = %e, path = %cc.key_path.display(), "client cert key read failed");
1094 crate::error::McpxError::Config(format!(
1095 "oauth.token_exchange.client_cert.key_path unreadable: {}",
1096 cc.key_path.display()
1097 ))
1098 })?;
1099 let mut combined = Vec::with_capacity(cert_bytes.len() + 1 + key_bytes.len());
1100 combined.extend_from_slice(&cert_bytes);
1101 if !cert_bytes.ends_with(b"\n") {
1102 combined.push(b'\n');
1103 }
1104 combined.extend_from_slice(&key_bytes);
1105 let _identity = reqwest::Identity::from_pem(&combined).map_err(|e| {
1106 tracing::warn!(
1107 error = %e,
1108 cert_path = %cc.cert_path.display(),
1109 key_path = %cc.key_path.display(),
1110 "client cert PEM parse failed"
1111 );
1112 crate::error::McpxError::Config(format!(
1113 "oauth.token_exchange.client_cert: PEM parse failed (cert={}, key={})",
1114 cc.cert_path.display(),
1115 cc.key_path.display()
1116 ))
1117 })?;
1118 Ok(())
1119 }
1120}
1121
/// Builds the map of mTLS-enabled HTTP clients.
///
/// The map is empty unless `config` carries a token-exchange section with a
/// client certificate; in that case it holds one client keyed by the
/// certificate/key paths. That client presents the certificate as its
/// identity, follows no redirects, and otherwise shares the default client's
/// settings visible here: no proxy, an SSRF-screening DNS resolver, 10 s
/// connect / 30 s total timeouts and the optional extra root certificate.
///
/// # Errors
/// Returns `McpxError::Startup` when a certificate, key or CA file cannot be
/// read or parsed, or when `reqwest` fails to build the client.
#[cfg(feature = "oauth-mtls-client")]
fn build_mtls_clients(
    config: Option<&OAuthConfig>,
    allowlist: &Arc<crate::ssrf::CompiledSsrfAllowlist>,
    test_bypass: &crate::ssrf_resolver::TestLoopbackBypass,
) -> Result<Arc<HashMap<MtlsClientKey, reqwest::Client>>, crate::error::McpxError> {
    // Nothing to build unless config -> token_exchange -> client_cert are all present.
    let Some((cfg, cc)) = config.and_then(|cfg| {
        let cc = cfg.token_exchange.as_ref()?.client_cert.as_ref()?;
        Some((cfg, cc))
    }) else {
        return Ok(Arc::new(HashMap::new()));
    };

    let cert_pem = match std::fs::read(&cc.cert_path) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(crate::error::McpxError::Startup(format!(
                "oauth http client mTLS: read cert_path {}: {e}",
                cc.cert_path.display()
            )));
        }
    };
    let key_pem = match std::fs::read(&cc.key_path) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(crate::error::McpxError::Startup(format!(
                "oauth http client mTLS: read key_path {}: {e}",
                cc.key_path.display()
            )));
        }
    };

    // Certificate first, then key, separated by a newline when the
    // certificate does not already end in one.
    let mut bundle = Vec::with_capacity(cert_pem.len() + 1 + key_pem.len());
    bundle.extend_from_slice(&cert_pem);
    if cert_pem.last() != Some(&b'\n') {
        bundle.push(b'\n');
    }
    bundle.extend_from_slice(&key_pem);

    let identity = match reqwest::Identity::from_pem(&bundle) {
        Ok(identity) => identity,
        Err(e) => {
            return Err(crate::error::McpxError::Startup(format!(
                "oauth http client mTLS: PEM parse (cert={}, key={}): {e}",
                cc.cert_path.display(),
                cc.key_path.display()
            )));
        }
    };

    let resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(allowlist),
            #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
            test_bypass.clone(),
        ));

    let mut builder = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&resolver))
        .connect_timeout(Duration::from_secs(10))
        .timeout(Duration::from_secs(30))
        .redirect(reqwest::redirect::Policy::none())
        .identity(identity);

    if let Some(ca_path) = cfg.ca_cert_path.as_ref() {
        let ca_pem = match std::fs::read(ca_path) {
            Ok(bytes) => bytes,
            Err(e) => {
                return Err(crate::error::McpxError::Startup(format!(
                    "oauth http client mTLS: read ca_cert_path {}: {e}",
                    ca_path.display()
                )));
            }
        };
        let root = match reqwest::tls::Certificate::from_pem(&ca_pem) {
            Ok(root) => root,
            Err(e) => {
                return Err(crate::error::McpxError::Startup(format!(
                    "oauth http client mTLS: parse ca_cert_path {}: {e}",
                    ca_path.display()
                )));
            }
        };
        builder = builder.add_root_certificate(root);
    }

    let client = match builder.build() {
        Ok(client) => client,
        Err(e) => {
            return Err(crate::error::McpxError::Startup(format!(
                "oauth http client mTLS init: {e}"
            )));
        }
    };

    let key = MtlsClientKey {
        cert_path: cc.cert_path.clone(),
        key_path: cc.key_path.clone(),
    };
    Ok(Arc::new(HashMap::from([(key, client)])))
}
1220
1221fn check_oauth_url(
1228 field: &str,
1229 raw: &str,
1230 allow_http: bool,
1231) -> Result<url::Url, crate::error::McpxError> {
1232 let parsed = url::Url::parse(raw).map_err(|e| {
1233 crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
1234 })?;
1235 if !parsed.username().is_empty() || parsed.password().is_some() {
1236 return Err(crate::error::McpxError::Config(format!(
1237 "{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
1238 )));
1239 }
1240 match parsed.scheme() {
1241 "https" => Ok(parsed),
1242 "http" if allow_http => Ok(parsed),
1243 "http" => Err(crate::error::McpxError::Config(format!(
1244 "{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
1245 to override - strongly discouraged in production)"
1246 ))),
1247 other => Err(crate::error::McpxError::Config(format!(
1248 "{field}: must use https scheme (got {other:?})"
1249 ))),
1250 }
1251}
1252
/// Fluent builder for [`OAuthConfig`], created by `OAuthConfig::builder`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthConfigBuilder {
    /// Configuration under construction.
    inner: OAuthConfig,
}
1263
impl OAuthConfigBuilder {
    /// Replaces the whole scope-mapping list.
    pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
        self.inner.scopes = scopes;
        self
    }

    /// Appends one scope-to-role mapping.
    pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.scopes.push(ScopeMapping {
            scope: scope.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the name of the claim that carries role values.
    pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
        self.inner.role_claim = Some(claim.into());
        self
    }

    /// Replaces the whole role-mapping list.
    pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
        self.inner.role_mappings = mappings;
        self
    }

    /// Appends one claim-value-to-role mapping.
    pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.role_mappings.push(RoleMapping {
            claim_value: claim_value.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the JWKS cache TTL as a humantime string. The value is not parsed
    /// here; `OAuthConfig::validate` checks it.
    pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
        self.inner.jwks_cache_ttl = ttl.into();
        self
    }

    /// Sets the OAuth proxy configuration.
    pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
        self.inner.proxy = Some(proxy);
        self
    }

    /// Sets the token-exchange configuration.
    pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
        self.inner.token_exchange = Some(token_exchange);
        self
    }

    /// Sets the path of an extra PEM root certificate.
    pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.inner.ca_cert_path = Some(path.into());
        self
    }

    /// Enables or disables acceptance of `http://` OAuth URLs.
    pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
        self.inner.allow_http_oauth_urls = allow;
        self
    }

    /// Sets the legacy strictness flag.
    ///
    /// Also clears any explicit `audience_validation_mode`, so the legacy
    /// flag is the one that takes effect when the mode is resolved.
    #[deprecated(since = "1.7.0", note = "use `audience_validation_mode` instead")]
    pub const fn strict_audience_validation(mut self, strict: bool) -> Self {
        #[allow(
            deprecated,
            reason = "intentional: deprecated builder forwards to deprecated field"
        )]
        {
            self.inner.strict_audience_validation = strict;
        }
        self.inner.audience_validation_mode = None;
        self
    }

    /// Sets the explicit audience validation mode, which takes precedence
    /// over the legacy flag.
    pub const fn audience_validation_mode(mut self, mode: AudienceValidationMode) -> Self {
        self.inner.audience_validation_mode = Some(mode);
        self
    }

    /// Sets the upper bound on the JWKS response size, in bytes.
    pub const fn jwks_max_response_bytes(mut self, bytes: u64) -> Self {
        self.inner.jwks_max_response_bytes = bytes;
        self
    }

    /// Sets the operator SSRF allowlist.
    pub fn ssrf_allowlist(mut self, allowlist: OAuthSsrfAllowlist) -> Self {
        self.inner.ssrf_allowlist = Some(allowlist);
        self
    }

    /// Finishes the builder and returns the configuration. No validation is
    /// performed here.
    #[must_use]
    pub fn build(self) -> OAuthConfig {
        self.inner
    }
}
1399
/// Associates an OAuth scope with a role name.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ScopeMapping {
    /// Scope string.
    pub scope: String,
    /// Role associated with that scope.
    pub role: String,
}
1409
/// Associates a claim value with a role name.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleMapping {
    /// Claim value to match.
    pub claim_value: String,
    /// Role associated with that claim value.
    pub role: String,
}
1421
/// Token-exchange settings.
///
/// Validation requires exactly one of `client_secret` and `client_cert`.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenExchangeConfig {
    /// Token endpoint URL. Must pass the OAuth URL checks.
    pub token_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Client secret; mutually exclusive with `client_cert`.
    pub client_secret: Option<secrecy::SecretString>,
    /// Client certificate for mTLS; mutually exclusive with `client_secret`.
    pub client_cert: Option<ClientCertConfig>,
    /// Audience value for the exchange.
    // NOTE(review): presumably sent as the requested audience -- the request
    // construction is not visible in this section.
    pub audience: String,
}
1459
impl TokenExchangeConfig {
    /// Creates a configuration from its five fields. No validation is
    /// performed here.
    #[must_use]
    pub fn new(
        token_url: String,
        client_id: String,
        client_secret: Option<secrecy::SecretString>,
        client_cert: Option<ClientCertConfig>,
        audience: String,
    ) -> Self {
        Self {
            token_url,
            client_id,
            client_secret,
            client_cert,
            audience,
        }
    }
}
1479
/// Locations of the client certificate and private key used for mTLS.
///
/// Both files are read as PEM and concatenated (certificate first) to form
/// the client identity.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ClientCertConfig {
    /// Path of the PEM certificate file.
    pub cert_path: PathBuf,
    /// Path of the PEM private-key file.
    pub key_path: PathBuf,
}
1494
impl ClientCertConfig {
    /// Creates a configuration from the two paths. The files are not touched
    /// here.
    #[must_use]
    pub fn new(cert_path: PathBuf, key_path: PathBuf) -> Self {
        Self {
            cert_path,
            key_path,
        }
    }
}
1507
/// Deserialised token response.
// NOTE(review): the field names match an RFC 8693 token-exchange response --
// confirm against the code that performs the exchange.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ExchangedToken {
    /// The issued access token.
    pub access_token: String,
    /// Optional lifetime value (presumably seconds -- confirm).
    pub expires_in: Option<u64>,
    /// Optional type identifier of the issued token.
    pub issued_token_type: Option<String>,
}
1520
/// OAuth proxy settings.
///
/// Validation checks every configured URL and refuses
/// `expose_admin_endpoints = true` unless either
/// `require_auth_on_admin_endpoints` or
/// `allow_unauthenticated_admin_endpoints` is also set.
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct OAuthProxyConfig {
    /// Upstream authorization endpoint URL.
    pub authorize_url: String,
    /// Upstream token endpoint URL.
    pub token_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Optional client secret.
    pub client_secret: Option<secrecy::SecretString>,
    /// Optional introspection endpoint URL.
    #[serde(default)]
    pub introspection_url: Option<String>,
    /// Optional revocation endpoint URL.
    #[serde(default)]
    pub revocation_url: Option<String>,
    /// Whether admin endpoints are exposed. Defaults to `false`.
    #[serde(default)]
    pub expose_admin_endpoints: bool,
    /// Require authentication on exposed admin endpoints. Defaults to `false`.
    #[serde(default)]
    pub require_auth_on_admin_endpoints: bool,
    /// Explicit opt-out allowing unauthenticated admin endpoints. Defaults to
    /// `false`.
    #[serde(default)]
    pub allow_unauthenticated_admin_endpoints: bool,
}
1583
1584impl OAuthProxyConfig {
1585 pub fn builder(
1593 authorize_url: impl Into<String>,
1594 token_url: impl Into<String>,
1595 client_id: impl Into<String>,
1596 ) -> OAuthProxyConfigBuilder {
1597 OAuthProxyConfigBuilder {
1598 inner: Self {
1599 authorize_url: authorize_url.into(),
1600 token_url: token_url.into(),
1601 client_id: client_id.into(),
1602 ..Self::default()
1603 },
1604 }
1605 }
1606}
1607
/// Fluent builder for [`OAuthProxyConfig`], created by
/// `OAuthProxyConfig::builder`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthProxyConfigBuilder {
    /// Configuration under construction.
    inner: OAuthProxyConfig,
}
1618
impl OAuthProxyConfigBuilder {
    /// Sets the client secret.
    pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
        self.inner.client_secret = Some(secret);
        self
    }

    /// Sets the introspection endpoint URL.
    pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
        self.inner.introspection_url = Some(url.into());
        self
    }

    /// Sets the revocation endpoint URL.
    pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
        self.inner.revocation_url = Some(url.into());
        self
    }

    /// Sets whether admin endpoints are exposed.
    pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
        self.inner.expose_admin_endpoints = expose;
        self
    }

    /// Sets whether exposed admin endpoints require authentication.
    pub const fn require_auth_on_admin_endpoints(mut self, require: bool) -> Self {
        self.inner.require_auth_on_admin_endpoints = require;
        self
    }

    /// Sets the explicit opt-out that permits unauthenticated admin endpoints.
    pub const fn allow_unauthenticated_admin_endpoints(mut self, allow: bool) -> Self {
        self.inner.allow_unauthenticated_admin_endpoints = allow;
        self
    }

    /// Finishes the builder and returns the configuration. No validation is
    /// performed here.
    #[must_use]
    pub fn build(self) -> OAuthProxyConfig {
        self.inner
    }
}
1675
/// Result of converting a JWKS document into lookup structures:
/// keys indexed by their `kid`, followed by keys that carry no `kid`.
/// Each key is paired with the algorithm its JWK declares.
type JwksKeyCache = (
    HashMap<String, (Algorithm, DecodingKey)>,
    Vec<(Algorithm, DecodingKey)>,
);
1687
1688struct CachedKeys {
1689 keys: HashMap<String, (Algorithm, DecodingKey)>,
1691 unnamed_keys: Vec<(Algorithm, DecodingKey)>,
1693 fetched_at: Instant,
1694 ttl: Duration,
1695}
1696
1697impl CachedKeys {
1698 fn is_expired(&self) -> bool {
1699 self.fetched_at.elapsed() >= self.ttl
1700 }
1701}
1702
/// JWT validator backed by a cached, periodically refreshed JWKS key set.
///
/// Holds the HTTP client used to fetch the JWKS, the validation settings
/// applied to every token, and the claim-to-role mapping rules.
#[allow(
    missing_debug_implementations,
    reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
)]
#[non_exhaustive]
pub struct JwksCache {
    /// URL the key set is downloaded from.
    jwks_uri: String,
    /// Lifetime given to each stored key snapshot.
    ttl: Duration,
    /// Maximum number of keys accepted from one JWKS document.
    max_jwks_keys: usize,
    /// Maximum number of body bytes read from the JWKS endpoint.
    max_response_bytes: u64,
    /// Whether plain-`http` targets are permitted (passed to target
    /// screening and to the redirect policy).
    allow_http: bool,
    /// Current key snapshot; `None` until a refresh has succeeded.
    inner: RwLock<Option<CachedKeys>>,
    /// Client used for JWKS downloads.
    http: reqwest::Client,
    /// Base validation settings; cloned per token, with the algorithm list
    /// replaced by the token's own algorithm before decoding.
    validation_template: Validation,
    /// Audience value a token must carry (in `aud`, or in `azp` depending on
    /// `audience_mode`).
    expected_audience: String,
    /// Controls whether an `azp`-only match is accepted, accepted with a
    /// warning, or rejected.
    audience_mode: AudienceValidationMode,
    /// Set after the first `azp`-fallback warning so it is logged only once.
    azp_fallback_warned: AtomicBool,
    /// Scope-to-role mappings, used when `role_claim` is `None`.
    scopes: Vec<ScopeMapping>,
    /// Optional claim path whose values select the role via `role_mappings`.
    role_claim: Option<String>,
    /// Claim-value-to-role mappings, used when `role_claim` is set.
    role_mappings: Vec<RoleMapping>,
    /// Time of the most recent refresh attempt, used to enforce the cooldown.
    last_refresh_attempt: RwLock<Option<Instant>>,
    /// Held for the duration of a refresh so refreshes run one at a time.
    refresh_lock: tokio::sync::Mutex<()>,
    /// Compiled SSRF allowlist passed to target screening.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Test-only switch that relaxes loopback SSRF screening; absent from
    /// builds without `test` / `test-helpers`.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
1751
/// Minimum spacing between two JWKS refresh attempts; a refresh requested
/// sooner than this after the previous attempt is skipped.
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_secs(10);
1754
/// Signature algorithms a token header may declare. Tokens using any other
/// algorithm are rejected before key lookup. The list contains only
/// RSA, RSA-PSS, ECDSA and EdDSA variants — no HMAC (`HS*`) algorithms.
const ACCEPTED_ALGS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
    Algorithm::EdDSA,
];
1767
/// Coarse reason a token was rejected by [`JwksCache`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
    /// The decoder reported an expired signature (`exp` in the past).
    Expired,
    /// Any other failure: bad header, unaccepted algorithm, unknown key,
    /// decode error, audience mismatch, or no matching role.
    Invalid,
}
1777
1778impl JwksCache {
    /// Builds a validator from `config`. No network traffic happens here;
    /// the key cache starts empty and is filled on first use.
    ///
    /// # Errors
    ///
    /// Returns an error when the SSRF allowlist fails to compile, the
    /// configured CA certificate cannot be read or parsed, or the HTTP
    /// client cannot be built.
    ///
    /// # Panics
    ///
    /// Panics if `config.jwks_cache_ttl` is not a valid `humantime`
    /// duration; callers are expected to have validated the config first.
    pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        // Install process-wide crypto providers. `.ok()` discards the result,
        // so a failed install (presumably "already installed") is ignored.
        rustls::crypto::ring::default_provider()
            .install_default()
            .ok();
        jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
            .install_default()
            .ok();

        #[allow(
            clippy::expect_used,
            reason = "jwks_cache_ttl was already parsed successfully by OAuthConfig::validate (call site precondition); re-parsing the same validated string here is infallible"
        )]
        let ttl = humantime::parse_duration(&config.jwks_cache_ttl)
            .expect("jwks_cache_ttl validated by OAuthConfig::validate");

        // RS256 is only a placeholder: `decode_claims` overwrites the
        // algorithm list with the token's own algorithm for every token.
        let mut validation = Validation::new(Algorithm::RS256);
        // Audience is checked separately in `check_audience`, not by the decoder.
        validation.validate_aud = false;
        validation.set_issuer(&[&config.issuer]);
        validation.set_required_spec_claims(&["exp", "iss"]);
        validation.validate_exp = true;
        validation.validate_nbf = true;

        let allow_http = config.allow_http_oauth_urls;

        // Compile the configured allowlist, or fall back to the default one.
        let allowlist = match config.ssrf_allowlist.as_ref() {
            Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
                Box::<dyn std::error::Error + Send + Sync>::from(format!(
                    "oauth.ssrf_allowlist: {e}"
                ))
            })?),
            None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
        };
        // Separate handle moved into the redirect-policy closure below.
        let redirect_allowlist = Arc::clone(&allowlist);

        // The bypass is a shared flag in test builds and `()` otherwise.
        #[cfg(any(test, feature = "test-helpers"))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
            Arc::new(AtomicBool::new(false));
        #[cfg(not(any(test, feature = "test-helpers")))]
        let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();

        // DNS resolver that screens resolved addresses against the allowlist.
        let resolver: Arc<dyn reqwest::dns::Resolve> =
            Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
                Arc::clone(&allowlist),
                #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
                test_bypass.clone(),
            ));

        // Dedicated client: no proxy, screened DNS, 10 s total / 3 s connect
        // timeouts, and every redirect vetted by `evaluate_oauth_redirect`.
        let mut http_builder = reqwest::Client::builder()
            .no_proxy()
            .dns_resolver(Arc::clone(&resolver))
            .timeout(Duration::from_secs(10))
            .connect_timeout(Duration::from_secs(3))
            .redirect(reqwest::redirect::Policy::custom(move |attempt| {
                match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
                    Ok(()) => attempt.follow(),
                    Err(reason) => {
                        tracing::warn!(
                            reason = %reason,
                            target = %attempt.url(),
                            "oauth redirect rejected"
                        );
                        attempt.error(reason)
                    }
                }
            }));

        // Optionally trust an additional root certificate read from disk.
        if let Some(ref ca_path) = config.ca_cert_path {
            let pem = std::fs::read(ca_path)?;
            let cert = reqwest::tls::Certificate::from_pem(&pem)?;
            http_builder = http_builder.add_root_certificate(cert);
        }

        let http = http_builder.build()?;

        Ok(Self {
            jwks_uri: config.jwks_uri.clone(),
            ttl,
            max_jwks_keys: config.max_jwks_keys,
            max_response_bytes: config.jwks_max_response_bytes,
            allow_http,
            inner: RwLock::new(None),
            http,
            validation_template: validation,
            expected_audience: config.audience.clone(),
            audience_mode: config.effective_audience_validation_mode(),
            azp_fallback_warned: AtomicBool::new(false),
            scopes: config.scopes.clone(),
            role_claim: config.role_claim.clone(),
            role_mappings: config.role_mappings.clone(),
            last_refresh_attempt: RwLock::new(None),
            refresh_lock: tokio::sync::Mutex::new(()),
            allowlist,
            #[cfg(any(test, feature = "test-helpers"))]
            test_allow_loopback_ssrf: test_bypass,
        })
    }
1919
    /// Test-only: turns on the loopback SSRF bypass flag and returns the
    /// cache. The flag is the same one handed to the DNS resolver in `new`
    /// and is also consulted by `fetch_jwks` when screening the JWKS URL.
    #[cfg(any(test, feature = "test-helpers"))]
    #[doc(hidden)]
    #[must_use]
    pub fn __test_allow_loopback_ssrf(self) -> Self {
        self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
        self
    }
1932
1933 pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
1935 self.validate_token_with_reason(token).await.ok()
1936 }
1937
1938 pub async fn validate_token_with_reason(
1945 &self,
1946 token: &str,
1947 ) -> Result<AuthIdentity, JwtValidationFailure> {
1948 let claims = self.decode_claims(token).await?;
1949
1950 self.check_audience(&claims)?;
1951 let role = self.resolve_role(&claims)?;
1952
1953 let sub = claims.sub;
1956 let name = claims
1957 .extra
1958 .get("preferred_username")
1959 .and_then(|v| v.as_str())
1960 .map(String::from)
1961 .or_else(|| sub.clone())
1962 .or(claims.azp)
1963 .or(claims.client_id)
1964 .unwrap_or_else(|| "oauth-client".into());
1965
1966 Ok(AuthIdentity {
1967 name,
1968 role,
1969 method: AuthMethod::OAuthJwt,
1970 raw_token: None,
1971 sub,
1972 })
1973 }
1974
1975 async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
1987 let (key, alg) = self.select_jwks_key(token).await?;
1988
1989 let mut validation = self.validation_template.clone();
1993 validation.algorithms = vec![alg];
1994
1995 let token_owned = token.to_owned();
1998 let join =
1999 tokio::task::spawn_blocking(move || decode::<Claims>(&token_owned, &key, &validation))
2000 .await;
2001
2002 let decode_result = match join {
2003 Ok(r) => r,
2004 Err(join_err) => {
2005 core::hint::cold_path();
2006 tracing::error!(
2007 error = %join_err,
2008 "JWT decode task panicked or was cancelled"
2009 );
2010 return Err(JwtValidationFailure::Invalid);
2011 }
2012 };
2013
2014 decode_result.map(|td| td.claims).map_err(|e| {
2015 core::hint::cold_path();
2016 let failure = if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
2017 JwtValidationFailure::Expired
2018 } else {
2019 JwtValidationFailure::Invalid
2020 };
2021 tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
2022 failure
2023 })
2024 }
2025
2026 #[allow(
2035 clippy::cognitive_complexity,
2036 reason = "each failure arm pairs `cold_path()` with a distinct `tracing::debug!` site for observability; collapsing into combinators would lose structured-field log sites without reducing real complexity"
2037 )]
2038 async fn select_jwks_key(
2039 &self,
2040 token: &str,
2041 ) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
2042 let Ok(header) = decode_header(token) else {
2043 core::hint::cold_path();
2044 tracing::debug!("JWT header decode failed");
2045 return Err(JwtValidationFailure::Invalid);
2046 };
2047 let kid = header.kid.as_deref();
2048 tracing::debug!(alg = ?header.alg, kid = kid.unwrap_or("-"), "JWT header decoded");
2049
2050 if !ACCEPTED_ALGS.contains(&header.alg) {
2051 core::hint::cold_path();
2052 tracing::debug!(alg = ?header.alg, "JWT algorithm not accepted");
2053 return Err(JwtValidationFailure::Invalid);
2054 }
2055
2056 let Some(key) = self.find_key(kid, header.alg).await else {
2057 core::hint::cold_path();
2058 tracing::debug!(kid = kid.unwrap_or("-"), alg = ?header.alg, "no matching JWKS key found");
2059 return Err(JwtValidationFailure::Invalid);
2060 };
2061
2062 Ok((key, header.alg))
2063 }
2064
2065 fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
2074 if claims.aud.contains(&self.expected_audience) {
2075 return Ok(());
2076 }
2077 let azp_match = claims
2078 .azp
2079 .as_deref()
2080 .is_some_and(|azp| azp == self.expected_audience);
2081 if azp_match {
2082 match self.audience_mode {
2083 AudienceValidationMode::Permissive => return Ok(()),
2084 AudienceValidationMode::Warn => {
2085 if !self.azp_fallback_warned.swap(true, Ordering::Relaxed) {
2086 tracing::warn!(
2087 expected = %self.expected_audience,
2088 azp = ?claims.azp,
2089 "JWT accepted via deprecated `azp`-only audience fallback. \
2090 Configure your IdP to populate `aud`, or set \
2091 `audience_validation_mode = \"strict\"` once tokens carry `aud` correctly. \
2092 To silence this warning without changing acceptance, \
2093 set `audience_validation_mode = \"permissive\"`. \
2094 This warning logs once per process."
2095 );
2096 }
2097 return Ok(());
2098 }
2099 AudienceValidationMode::Strict => {}
2100 }
2101 }
2102 core::hint::cold_path();
2103 tracing::debug!(
2104 aud = ?claims.aud.0,
2105 azp = ?claims.azp,
2106 expected = %self.expected_audience,
2107 mode = ?self.audience_mode,
2108 "JWT rejected: audience mismatch"
2109 );
2110 Err(JwtValidationFailure::Invalid)
2111 }
2112
2113 fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
2119 if let Some(ref claim_path) = self.role_claim {
2120 let owned_first_class: Vec<String> = first_class_claim_values(claims, claim_path);
2121 let mut values: Vec<&str> = owned_first_class.iter().map(String::as_str).collect();
2122 values.extend(resolve_claim_path(&claims.extra, claim_path));
2123 return self
2124 .role_mappings
2125 .iter()
2126 .find(|m| values.contains(&m.claim_value.as_str()))
2127 .map(|m| m.role.clone())
2128 .ok_or(JwtValidationFailure::Invalid);
2129 }
2130
2131 let token_scopes: Vec<&str> = claims
2132 .scope
2133 .as_deref()
2134 .unwrap_or("")
2135 .split_whitespace()
2136 .collect();
2137
2138 self.scopes
2139 .iter()
2140 .find(|m| token_scopes.contains(&m.scope.as_str()))
2141 .map(|m| m.role.clone())
2142 .ok_or(JwtValidationFailure::Invalid)
2143 }
2144
2145 async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2148 {
2150 let guard = self.inner.read().await;
2151 if let Some(cached) = guard.as_ref()
2152 && !cached.is_expired()
2153 && let Some(key) = lookup_key(cached, kid, alg)
2154 {
2155 return Some(key);
2156 }
2157 }
2158
2159 self.refresh_with_cooldown().await;
2161
2162 let guard = self.inner.read().await;
2163 guard
2164 .as_ref()
2165 .and_then(|cached| lookup_key(cached, kid, alg))
2166 }
2167
2168 async fn refresh_with_cooldown(&self) {
2173 let _guard = self.refresh_lock.lock().await;
2175
2176 {
2178 let last = self.last_refresh_attempt.read().await;
2179 if let Some(ts) = *last
2180 && ts.elapsed() < JWKS_REFRESH_COOLDOWN
2181 {
2182 tracing::debug!(
2183 elapsed_ms = ts.elapsed().as_millis(),
2184 cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
2185 "JWKS refresh skipped (cooldown active)"
2186 );
2187 return;
2188 }
2189 }
2190
2191 {
2194 let mut last = self.last_refresh_attempt.write().await;
2195 *last = Some(Instant::now());
2196 }
2197
2198 let _ = self.refresh_inner().await;
2200 }
2201
2202 async fn refresh_inner(&self) -> Result<(), String> {
2207 let Some(jwks) = self.fetch_jwks().await else {
2208 return Ok(());
2209 };
2210 let (keys, unnamed_keys) = match build_key_cache(&jwks, self.max_jwks_keys) {
2211 Ok(cache) => cache,
2212 Err(msg) => {
2213 tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
2214 return Err(msg);
2215 }
2216 };
2217
2218 tracing::debug!(
2219 named = keys.len(),
2220 unnamed = unnamed_keys.len(),
2221 "JWKS refreshed"
2222 );
2223
2224 let mut guard = self.inner.write().await;
2225 *guard = Some(CachedKeys {
2226 keys,
2227 unnamed_keys,
2228 fetched_at: Instant::now(),
2229 ttl: self.ttl,
2230 });
2231 Ok(())
2232 }
2233
2234 #[allow(
2236 clippy::cognitive_complexity,
2237 reason = "screening, bounded streaming, and parse logging are intentionally kept in one fetch path"
2238 )]
2239 async fn fetch_jwks(&self) -> Option<JwkSet> {
2240 #[cfg(any(test, feature = "test-helpers"))]
2241 let screening = if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
2242 screen_oauth_target_with_test_override(
2243 &self.jwks_uri,
2244 self.allow_http,
2245 &self.allowlist,
2246 true,
2247 )
2248 .await
2249 } else {
2250 screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await
2251 };
2252 #[cfg(not(any(test, feature = "test-helpers")))]
2253 let screening = screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await;
2254
2255 if let Err(error) = screening {
2256 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to screen JWKS target");
2257 return None;
2258 }
2259
2260 let mut resp = match self.http.get(&self.jwks_uri).send().await {
2261 Ok(resp) => resp,
2262 Err(e) => {
2263 tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
2264 return None;
2265 }
2266 };
2267
2268 let initial_capacity =
2269 usize::try_from(self.max_response_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
2270 let mut body = Vec::with_capacity(initial_capacity);
2271 while let Some(chunk) = match resp.chunk().await {
2272 Ok(chunk) => chunk,
2273 Err(error) => {
2274 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to read JWKS response");
2275 return None;
2276 }
2277 } {
2278 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
2279 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
2280 if body_len.saturating_add(chunk_len) > self.max_response_bytes {
2281 tracing::warn!(
2282 uri = %self.jwks_uri,
2283 max_bytes = self.max_response_bytes,
2284 "JWKS response exceeded configured size cap"
2285 );
2286 return None;
2287 }
2288 body.extend_from_slice(&chunk);
2289 }
2290
2291 match serde_json::from_slice::<JwkSet>(&body) {
2292 Ok(jwks) => Some(jwks),
2293 Err(error) => {
2294 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to parse JWKS");
2295 None
2296 }
2297 }
2298 }
2299
2300 #[cfg(any(test, feature = "test-helpers"))]
2303 #[doc(hidden)]
2304 pub async fn __test_refresh_now(&self) -> Result<(), String> {
2305 let jwks = self
2306 .fetch_jwks()
2307 .await
2308 .ok_or_else(|| "failed to fetch or parse JWKS".to_owned())?;
2309 let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;
2310 let mut guard = self.inner.write().await;
2311 *guard = Some(CachedKeys {
2312 keys,
2313 unnamed_keys,
2314 fetched_at: Instant::now(),
2315 ttl: self.ttl,
2316 });
2317 Ok(())
2318 }
2319
2320 #[cfg(any(test, feature = "test-helpers"))]
2323 #[doc(hidden)]
2324 pub async fn __test_has_kid(&self, kid: &str) -> bool {
2325 let guard = self.inner.read().await;
2326 guard
2327 .as_ref()
2328 .is_some_and(|cache| cache.keys.contains_key(kid))
2329 }
2330}
2331
2332fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
2334 if jwks.keys.len() > max_keys {
2335 return Err(format!(
2336 "jwks_key_count_exceeds_cap: got {} keys, max is {}",
2337 jwks.keys.len(),
2338 max_keys
2339 ));
2340 }
2341 let mut keys = HashMap::new();
2342 let mut unnamed_keys = Vec::new();
2343 for jwk in &jwks.keys {
2344 let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
2345 continue;
2346 };
2347 let Some(alg) = jwk_algorithm(jwk) else {
2348 continue;
2349 };
2350 if let Some(ref kid) = jwk.common.key_id {
2351 keys.insert(kid.clone(), (alg, decoding_key));
2352 } else {
2353 unnamed_keys.push((alg, decoding_key));
2354 }
2355 }
2356 Ok((keys, unnamed_keys))
2357}
2358
2359fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2361 if let Some(kid) = kid
2362 && let Some((cached_alg, key)) = cached.keys.get(kid)
2363 && *cached_alg == alg
2364 {
2365 return Some(key.clone());
2366 }
2367 cached
2369 .unnamed_keys
2370 .iter()
2371 .find(|(a, _)| *a == alg)
2372 .map(|(_, k)| k.clone())
2373}
2374
2375#[allow(clippy::wildcard_enum_match_arm)]
2377fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
2378 jwk.common.key_algorithm.and_then(|ka| match ka {
2379 jsonwebtoken::jwk::KeyAlgorithm::RS256 => Some(Algorithm::RS256),
2380 jsonwebtoken::jwk::KeyAlgorithm::RS384 => Some(Algorithm::RS384),
2381 jsonwebtoken::jwk::KeyAlgorithm::RS512 => Some(Algorithm::RS512),
2382 jsonwebtoken::jwk::KeyAlgorithm::ES256 => Some(Algorithm::ES256),
2383 jsonwebtoken::jwk::KeyAlgorithm::ES384 => Some(Algorithm::ES384),
2384 jsonwebtoken::jwk::KeyAlgorithm::PS256 => Some(Algorithm::PS256),
2385 jsonwebtoken::jwk::KeyAlgorithm::PS384 => Some(Algorithm::PS384),
2386 jsonwebtoken::jwk::KeyAlgorithm::PS512 => Some(Algorithm::PS512),
2387 jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
2388 _ => None,
2389 })
2390}
2391
2392fn first_class_claim_values(claims: &Claims, path: &str) -> Vec<String> {
2413 match path {
2414 "sub" => claims.sub.iter().cloned().collect(),
2415 "azp" => claims.azp.iter().cloned().collect(),
2416 "client_id" => claims.client_id.iter().cloned().collect(),
2417 "aud" => claims.aud.0.clone(),
2418 "scope" => claims
2419 .scope
2420 .as_deref()
2421 .unwrap_or("")
2422 .split_whitespace()
2423 .map(str::to_owned)
2424 .collect(),
2425 _ => Vec::new(),
2426 }
2427}
2428
2429fn resolve_claim_path<'a>(
2439 extra: &'a HashMap<String, serde_json::Value>,
2440 path: &str,
2441) -> Vec<&'a str> {
2442 let mut segments = path.split('.');
2443 let Some(first) = segments.next() else {
2444 return Vec::new();
2445 };
2446
2447 let mut current: Option<&serde_json::Value> = extra.get(first);
2448
2449 for segment in segments {
2450 current = current.and_then(|v| v.get(segment));
2451 }
2452
2453 match current {
2454 Some(serde_json::Value::String(s)) => s.split_whitespace().collect(),
2455 Some(serde_json::Value::Array(arr)) => arr.iter().filter_map(|v| v.as_str()).collect(),
2456 _ => Vec::new(),
2457 }
2458}
2459
/// Token claims read by this module. Claims not listed as fields are
/// collected into `extra`.
#[derive(Debug, Deserialize)]
struct Claims {
    /// Subject of the token, if present.
    sub: Option<String>,
    /// Audience(s); accepts a single string or an array and is empty when
    /// the claim is absent.
    #[serde(default)]
    aud: OneOrMany,
    /// Authorized party, if present; also consulted as an audience fallback.
    azp: Option<String>,
    /// Client identifier claim, if present.
    client_id: Option<String>,
    /// Whitespace-separated scope string, if present.
    scope: Option<String>,
    /// Every other claim, keyed by name; searched for `preferred_username`
    /// and for the configured role claim path.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
2483
/// A list of strings that may be written as a single string or as an array.
/// The default value is the empty list.
#[derive(Debug, Default)]
struct OneOrMany(Vec<String>);

impl OneOrMany {
    /// Returns `true` when `value` equals one of the stored strings exactly.
    fn contains(&self, value: &str) -> bool {
        self.0
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == value)
    }
}
2493
2494impl<'de> Deserialize<'de> for OneOrMany {
2495 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
2496 use serde::de;
2497
2498 struct Visitor;
2499 impl<'de> de::Visitor<'de> for Visitor {
2500 type Value = OneOrMany;
2501 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2502 f.write_str("a string or array of strings")
2503 }
2504 fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
2505 Ok(OneOrMany(vec![v.to_owned()]))
2506 }
2507 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
2508 let mut v = Vec::new();
2509 while let Some(s) = seq.next_element::<String>()? {
2510 v.push(s);
2511 }
2512 Ok(OneOrMany(v))
2513 }
2514 }
2515 deserializer.deserialize_any(Visitor)
2516 }
2517}
2518
2519#[must_use]
2526pub fn looks_like_jwt(token: &str) -> bool {
2527 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2528
2529 let mut parts = token.splitn(4, '.');
2530 let Some(header_b64) = parts.next() else {
2531 return false;
2532 };
2533 if parts.next().is_none() || parts.next().is_none() || parts.next().is_some() {
2535 return false;
2536 }
2537 let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
2539 return false;
2540 };
2541 let Ok(header) = serde_json::from_slice::<serde_json::Value>(&header_bytes) else {
2543 return false;
2544 };
2545 header.get("alg").is_some()
2546}
2547
2548#[must_use]
2558pub fn protected_resource_metadata(
2559 resource_url: &str,
2560 server_url: &str,
2561 config: &OAuthConfig,
2562) -> serde_json::Value {
2563 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2568 let auth_server = server_url;
2569 serde_json::json!({
2570 "resource": resource_url,
2571 "authorization_servers": [auth_server],
2572 "scopes_supported": scopes,
2573 "bearer_methods_supported": ["header"]
2574 })
2575}
2576
2577#[must_use]
2582pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
2583 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2584 let mut meta = serde_json::json!({
2585 "issuer": &config.issuer,
2586 "authorization_endpoint": format!("{server_url}/authorize"),
2587 "token_endpoint": format!("{server_url}/token"),
2588 "registration_endpoint": format!("{server_url}/register"),
2589 "response_types_supported": ["code"],
2590 "grant_types_supported": ["authorization_code", "refresh_token"],
2591 "code_challenge_methods_supported": ["S256"],
2592 "scopes_supported": scopes,
2593 "token_endpoint_auth_methods_supported": ["none"],
2594 });
2595 if let Some(proxy) = &config.proxy
2596 && proxy.expose_admin_endpoints
2597 && let Some(obj) = meta.as_object_mut()
2598 {
2599 if proxy.introspection_url.is_some() {
2600 obj.insert(
2601 "introspection_endpoint".into(),
2602 serde_json::Value::String(format!("{server_url}/introspect")),
2603 );
2604 }
2605 if proxy.revocation_url.is_some() {
2606 obj.insert(
2607 "revocation_endpoint".into(),
2608 serde_json::Value::String(format!("{server_url}/revoke")),
2609 );
2610 }
2611 if proxy.require_auth_on_admin_endpoints {
2612 obj.insert(
2613 "introspection_endpoint_auth_methods_supported".into(),
2614 serde_json::json!(["bearer"]),
2615 );
2616 obj.insert(
2617 "revocation_endpoint_auth_methods_supported".into(),
2618 serde_json::json!(["bearer"]),
2619 );
2620 }
2621 }
2622 meta
2623}
2624
2625#[must_use]
2638pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
2639 use axum::{
2640 http::{StatusCode, header},
2641 response::IntoResponse,
2642 };
2643
2644 let upstream_query = replace_client_id(query, &proxy.client_id);
2646 let redirect_url = format!("{}?{upstream_query}", proxy.authorize_url);
2647
2648 (StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
2649}
2650
2651pub async fn handle_token(
2657 http: &OauthHttpClient,
2658 proxy: &OAuthProxyConfig,
2659 body: &str,
2660) -> axum::response::Response {
2661 use axum::{
2662 http::{StatusCode, header},
2663 response::IntoResponse,
2664 };
2665
2666 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2668
2669 if let Some(ref secret) = proxy.client_secret {
2671 use std::fmt::Write;
2672
2673 use secrecy::ExposeSecret;
2674 let _ = write!(
2675 upstream_body,
2676 "&client_secret={}",
2677 urlencoding::encode(secret.expose_secret())
2678 );
2679 }
2680
2681 let result = http
2682 .send_screened(
2683 &proxy.token_url,
2684 http.inner
2685 .post(&proxy.token_url)
2686 .header("Content-Type", "application/x-www-form-urlencoded")
2687 .body(upstream_body),
2688 )
2689 .await;
2690
2691 match result {
2692 Ok(resp) => {
2693 let status =
2694 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2695 let body_bytes = resp.bytes().await.unwrap_or_default();
2696 (
2697 status,
2698 [(header::CONTENT_TYPE, "application/json")],
2699 body_bytes,
2700 )
2701 .into_response()
2702 }
2703 Err(e) => {
2704 tracing::error!(error = %e, "OAuth token proxy request failed");
2705 (
2706 StatusCode::BAD_GATEWAY,
2707 [(header::CONTENT_TYPE, "application/json")],
2708 "{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
2709 )
2710 .into_response()
2711 }
2712 }
2713}
2714
2715#[must_use]
2722pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
2723 let mut resp = serde_json::json!({
2724 "client_id": proxy.client_id,
2725 "token_endpoint_auth_method": "none",
2726 });
2727 if let Some(uris) = body.get("redirect_uris")
2728 && let Some(obj) = resp.as_object_mut()
2729 {
2730 obj.insert("redirect_uris".into(), uris.clone());
2731 }
2732 if let Some(name) = body.get("client_name")
2733 && let Some(obj) = resp.as_object_mut()
2734 {
2735 obj.insert("client_name".into(), name.clone());
2736 }
2737 resp
2738}
2739
2740pub async fn handle_introspect(
2746 http: &OauthHttpClient,
2747 proxy: &OAuthProxyConfig,
2748 body: &str,
2749) -> axum::response::Response {
2750 let Some(ref url) = proxy.introspection_url else {
2751 return oauth_error_response(
2752 axum::http::StatusCode::NOT_FOUND,
2753 "not_supported",
2754 "introspection endpoint is not configured",
2755 );
2756 };
2757 proxy_oauth_admin_request(http, proxy, url, body).await
2758}
2759
2760pub async fn handle_revoke(
2767 http: &OauthHttpClient,
2768 proxy: &OAuthProxyConfig,
2769 body: &str,
2770) -> axum::response::Response {
2771 let Some(ref url) = proxy.revocation_url else {
2772 return oauth_error_response(
2773 axum::http::StatusCode::NOT_FOUND,
2774 "not_supported",
2775 "revocation endpoint is not configured",
2776 );
2777 };
2778 proxy_oauth_admin_request(http, proxy, url, body).await
2779}
2780
2781async fn proxy_oauth_admin_request(
2785 http: &OauthHttpClient,
2786 proxy: &OAuthProxyConfig,
2787 upstream_url: &str,
2788 body: &str,
2789) -> axum::response::Response {
2790 use axum::{
2791 http::{StatusCode, header},
2792 response::IntoResponse,
2793 };
2794
2795 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2796 if let Some(ref secret) = proxy.client_secret {
2797 use std::fmt::Write;
2798
2799 use secrecy::ExposeSecret;
2800 let _ = write!(
2801 upstream_body,
2802 "&client_secret={}",
2803 urlencoding::encode(secret.expose_secret())
2804 );
2805 }
2806
2807 let result = http
2808 .send_screened(
2809 upstream_url,
2810 http.inner
2811 .post(upstream_url)
2812 .header("Content-Type", "application/x-www-form-urlencoded")
2813 .body(upstream_body),
2814 )
2815 .await;
2816
2817 match result {
2818 Ok(resp) => {
2819 let status =
2820 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2821 let content_type = resp
2822 .headers()
2823 .get(header::CONTENT_TYPE)
2824 .and_then(|v| v.to_str().ok())
2825 .unwrap_or("application/json")
2826 .to_owned();
2827 let body_bytes = resp.bytes().await.unwrap_or_default();
2828 (status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
2829 }
2830 Err(e) => {
2831 tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
2832 oauth_error_response(
2833 StatusCode::BAD_GATEWAY,
2834 "server_error",
2835 "upstream endpoint unreachable",
2836 )
2837 }
2838 }
2839}
2840
2841fn oauth_error_response(
2842 status: axum::http::StatusCode,
2843 error: &str,
2844 description: &str,
2845) -> axum::response::Response {
2846 use axum::{http::header, response::IntoResponse};
2847 let body = serde_json::json!({
2848 "error": error,
2849 "error_description": description,
2850 });
2851 (
2852 status,
2853 [(header::CONTENT_TYPE, "application/json")],
2854 body.to_string(),
2855 )
2856 .into_response()
2857}
2858
/// Error body returned by the authorization server on a failed token
/// exchange. Parsed only for logging and for deriving the sanitized error
/// code handed back to the caller.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Upstream error code; required for the body to parse.
    error: String,
    /// Optional human-readable detail; only ever logged.
    error_description: Option<String>,
}
2869
/// Maps an upstream error code onto a fixed allowlist of known codes.
///
/// A `raw` value that exactly matches one of the allowlisted codes is
/// returned as the corresponding static string; anything else becomes
/// `"server_error"`, so arbitrary upstream text never reaches the caller.
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
    const PASS_THROUGH: [&str; 8] = [
        "invalid_request",
        "invalid_client",
        "invalid_grant",
        "unauthorized_client",
        "unsupported_grant_type",
        "invalid_scope",
        "temporarily_unavailable",
        "invalid_target",
    ];

    PASS_THROUGH
        .iter()
        .copied()
        .find(|known| *known == raw)
        .unwrap_or("server_error")
}
2892
2893pub async fn exchange_token(
2905 http: &OauthHttpClient,
2906 config: &TokenExchangeConfig,
2907 subject_token: &str,
2908) -> Result<ExchangedToken, crate::error::McpxError> {
2909 use secrecy::ExposeSecret;
2910
2911 let client = http.client_for(config);
2912 let mut req = client
2913 .post(&config.token_url)
2914 .header("Content-Type", "application/x-www-form-urlencoded")
2915 .header("Accept", "application/json");
2916
2917 if config.client_cert.is_none()
2926 && let Some(ref secret) = config.client_secret
2927 {
2928 use base64::Engine;
2929 let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
2930 "{}:{}",
2931 urlencoding::encode(&config.client_id),
2932 urlencoding::encode(secret.expose_secret()),
2933 ));
2934 req = req.header("Authorization", format!("Basic {credentials}"));
2935 }
2936
2937 let form_body = build_exchange_form(config, subject_token);
2938
2939 let resp = http
2940 .send_screened(&config.token_url, req.body(form_body))
2941 .await
2942 .map_err(|e| {
2943 tracing::error!(error = %e, "token exchange request failed");
2944 crate::error::McpxError::Auth("server_error".into())
2946 })?;
2947
2948 let status = resp.status();
2949 let body_bytes = resp.bytes().await.map_err(|e| {
2950 tracing::error!(error = %e, "failed to read token exchange response");
2951 crate::error::McpxError::Auth("server_error".into())
2952 })?;
2953
2954 if !status.is_success() {
2955 core::hint::cold_path();
2956 let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
2959 let short_code = parsed
2960 .as_ref()
2961 .map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
2962 if let Some(ref e) = parsed {
2963 tracing::warn!(
2964 status = %status,
2965 upstream_error = %e.error,
2966 upstream_error_description = e.error_description.as_deref().unwrap_or(""),
2967 client_code = %short_code,
2968 "token exchange rejected by authorization server",
2969 );
2970 } else {
2971 tracing::warn!(
2972 status = %status,
2973 client_code = %short_code,
2974 "token exchange rejected (unparseable upstream body)",
2975 );
2976 }
2977 return Err(crate::error::McpxError::Auth(short_code.into()));
2978 }
2979
2980 let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
2981 tracing::error!(error = %e, "failed to parse token exchange response");
2982 crate::error::McpxError::Auth("server_error".into())
2985 })?;
2986
2987 log_exchanged_token(&exchanged);
2988
2989 Ok(exchanged)
2990}
2991
2992fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
2995 let body = format!(
2996 "grant_type={}&subject_token={}&subject_token_type={}&requested_token_type={}&audience={}",
2997 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
2998 urlencoding::encode(subject_token),
2999 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
3000 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
3001 urlencoding::encode(&config.audience),
3002 );
3003 if config.client_secret.is_none() {
3004 format!(
3005 "{body}&client_id={}",
3006 urlencoding::encode(&config.client_id)
3007 )
3008 } else {
3009 body
3010 }
3011}
3012
3013fn log_exchanged_token(exchanged: &ExchangedToken) {
3016 use base64::Engine;
3017
3018 if !looks_like_jwt(&exchanged.access_token) {
3019 tracing::debug!(
3020 token_len = exchanged.access_token.len(),
3021 issued_token_type = ?exchanged.issued_token_type,
3022 expires_in = exchanged.expires_in,
3023 "exchanged token (opaque)",
3024 );
3025 return;
3026 }
3027 let Some(payload) = exchanged.access_token.split('.').nth(1) else {
3028 return;
3029 };
3030 let Ok(decoded) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload) else {
3031 return;
3032 };
3033 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&decoded) else {
3034 return;
3035 };
3036 tracing::debug!(
3037 sub = ?claims.get("sub"),
3038 aud = ?claims.get("aud"),
3039 azp = ?claims.get("azp"),
3040 iss = ?claims.get("iss"),
3041 expires_in = exchanged.expires_in,
3042 "exchanged token claims (JWT)",
3043 );
3044}
3045
3046fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
3048 let encoded_id = urlencoding::encode(upstream_client_id);
3049 let mut parts: Vec<String> = params
3050 .split('&')
3051 .filter(|p| !p.starts_with("client_id="))
3052 .map(String::from)
3053 .collect();
3054 parts.push(format!("client_id={encoded_id}"));
3055 parts.join("&")
3056}
3057
3058#[cfg(test)]
3059mod tests {
3060 use std::sync::Arc;
3061
3062 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
3063
3064 use super::*;
3065
3066 #[test]
3067 fn looks_like_jwt_valid() {
3068 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
3070 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3071 let token = format!("{header}.{payload}.signature");
3072 assert!(looks_like_jwt(&token));
3073 }
3074
3075 #[test]
3076 fn looks_like_jwt_rejects_opaque_token() {
3077 assert!(!looks_like_jwt("dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4"));
3078 }
3079
3080 #[test]
3081 fn looks_like_jwt_rejects_two_segments() {
3082 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
3083 let token = format!("{header}.payload");
3084 assert!(!looks_like_jwt(&token));
3085 }
3086
3087 #[test]
3088 fn looks_like_jwt_rejects_four_segments() {
3089 assert!(!looks_like_jwt("a.b.c.d"));
3090 }
3091
3092 #[test]
3093 fn looks_like_jwt_rejects_no_alg() {
3094 let header = URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}");
3095 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3096 let token = format!("{header}.{payload}.sig");
3097 assert!(!looks_like_jwt(&token));
3098 }
3099
3100 #[test]
3101 fn protected_resource_metadata_shape() {
3102 let config = OAuthConfig {
3103 issuer: "https://auth.example.com".into(),
3104 audience: "https://mcp.example.com/mcp".into(),
3105 jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
3106 scopes: vec![
3107 ScopeMapping {
3108 scope: "mcp:read".into(),
3109 role: "viewer".into(),
3110 },
3111 ScopeMapping {
3112 scope: "mcp:admin".into(),
3113 role: "ops".into(),
3114 },
3115 ],
3116 role_claim: None,
3117 role_mappings: vec![],
3118 jwks_cache_ttl: "10m".into(),
3119 proxy: None,
3120 token_exchange: None,
3121 ca_cert_path: None,
3122 allow_http_oauth_urls: false,
3123 max_jwks_keys: default_max_jwks_keys(),
3124 #[allow(
3125 deprecated,
3126 reason = "test fixture: explicit value for the deprecated field"
3127 )]
3128 strict_audience_validation: false,
3129 audience_validation_mode: None,
3130 jwks_max_response_bytes: default_jwks_max_bytes(),
3131 ssrf_allowlist: None,
3132 };
3133 let meta = protected_resource_metadata(
3134 "https://mcp.example.com/mcp",
3135 "https://mcp.example.com",
3136 &config,
3137 );
3138 assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
3139 assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
3140 assert_eq!(meta["scopes_supported"].as_array().unwrap().len(), 2);
3141 assert_eq!(meta["bearer_methods_supported"][0], "header");
3142 }
3143
3144 fn validation_https_config() -> OAuthConfig {
3149 OAuthConfig::builder(
3150 "https://auth.example.com",
3151 "mcp",
3152 "https://auth.example.com/.well-known/jwks.json",
3153 )
3154 .build()
3155 }
3156
3157 #[test]
3158 fn validate_accepts_all_https_urls() {
3159 let cfg = validation_https_config();
3160 cfg.validate().expect("all-HTTPS config must validate");
3161 }
3162
3163 #[test]
3164 fn validate_rejects_unparseable_jwks_cache_ttl() {
3165 let mut cfg = validation_https_config();
3166 cfg.jwks_cache_ttl = "not-a-duration".into();
3167 let err = cfg
3168 .validate()
3169 .expect_err("malformed jwks_cache_ttl must be rejected");
3170 let msg = err.to_string();
3171 assert!(
3172 msg.contains("jwks_cache_ttl"),
3173 "error must reference offending field; got {msg:?}"
3174 );
3175 }
3176
3177 #[test]
3178 fn validate_rejects_http_jwks_uri() {
3179 let mut cfg = validation_https_config();
3180 cfg.jwks_uri = "http://auth.example.com/.well-known/jwks.json".into();
3181 let err = cfg.validate().expect_err("http jwks_uri must be rejected");
3182 let msg = err.to_string();
3183 assert!(
3184 msg.contains("oauth.jwks_uri") && msg.contains("https"),
3185 "error must reference offending field + scheme requirement; got {msg:?}"
3186 );
3187 }
3188
3189 #[test]
3190 fn validate_rejects_http_proxy_authorize_url() {
3191 let mut cfg = validation_https_config();
3192 cfg.proxy = Some(
3193 OAuthProxyConfig::builder(
3194 "http://idp.example.com/authorize", "https://idp.example.com/token",
3196 "client",
3197 )
3198 .build(),
3199 );
3200 let err = cfg
3201 .validate()
3202 .expect_err("http authorize_url must be rejected");
3203 assert!(
3204 err.to_string().contains("oauth.proxy.authorize_url"),
3205 "error must reference proxy.authorize_url; got {err}"
3206 );
3207 }
3208
3209 #[test]
3210 fn validate_rejects_http_proxy_token_url() {
3211 let mut cfg = validation_https_config();
3212 cfg.proxy = Some(
3213 OAuthProxyConfig::builder(
3214 "https://idp.example.com/authorize",
3215 "http://idp.example.com/token", "client",
3217 )
3218 .build(),
3219 );
3220 let err = cfg.validate().expect_err("http token_url must be rejected");
3221 assert!(
3222 err.to_string().contains("oauth.proxy.token_url"),
3223 "error must reference proxy.token_url; got {err}"
3224 );
3225 }
3226
3227 #[test]
3228 fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
3229 let mut cfg = validation_https_config();
3230 cfg.proxy = Some(
3231 OAuthProxyConfig::builder(
3232 "https://idp.example.com/authorize",
3233 "https://idp.example.com/token",
3234 "client",
3235 )
3236 .introspection_url("http://idp.example.com/introspect")
3237 .build(),
3238 );
3239 let err = cfg
3240 .validate()
3241 .expect_err("http introspection_url must be rejected");
3242 assert!(err.to_string().contains("oauth.proxy.introspection_url"));
3243
3244 let mut cfg = validation_https_config();
3245 cfg.proxy = Some(
3246 OAuthProxyConfig::builder(
3247 "https://idp.example.com/authorize",
3248 "https://idp.example.com/token",
3249 "client",
3250 )
3251 .revocation_url("http://idp.example.com/revoke")
3252 .build(),
3253 );
3254 let err = cfg
3255 .validate()
3256 .expect_err("http revocation_url must be rejected");
3257 assert!(err.to_string().contains("oauth.proxy.revocation_url"));
3258 }
3259
3260 #[test]
3263 fn validate_rejects_exposed_admin_endpoints_without_auth() {
3264 let mut cfg = validation_https_config();
3265 cfg.proxy = Some(
3266 OAuthProxyConfig::builder(
3267 "https://idp.example.com/authorize",
3268 "https://idp.example.com/token",
3269 "client",
3270 )
3271 .introspection_url("https://idp.example.com/introspect")
3272 .expose_admin_endpoints(true)
3273 .build(),
3274 );
3275 let err = cfg
3276 .validate()
3277 .expect_err("expose_admin_endpoints without auth must fail");
3278 let msg = err.to_string();
3279 assert!(msg.contains("require_auth_on_admin_endpoints"), "{msg}");
3280 assert!(
3281 msg.contains("allow_unauthenticated_admin_endpoints"),
3282 "{msg}"
3283 );
3284 }
3285
3286 #[test]
3287 fn validate_accepts_exposed_admin_endpoints_with_auth() {
3288 let mut cfg = validation_https_config();
3289 cfg.proxy = Some(
3290 OAuthProxyConfig::builder(
3291 "https://idp.example.com/authorize",
3292 "https://idp.example.com/token",
3293 "client",
3294 )
3295 .introspection_url("https://idp.example.com/introspect")
3296 .expose_admin_endpoints(true)
3297 .require_auth_on_admin_endpoints(true)
3298 .build(),
3299 );
3300 cfg.validate()
3301 .expect("authed admin endpoints must validate");
3302 }
3303
3304 #[test]
3305 fn validate_accepts_exposed_admin_endpoints_with_explicit_unauth_optout() {
3306 let mut cfg = validation_https_config();
3307 cfg.proxy = Some(
3308 OAuthProxyConfig::builder(
3309 "https://idp.example.com/authorize",
3310 "https://idp.example.com/token",
3311 "client",
3312 )
3313 .introspection_url("https://idp.example.com/introspect")
3314 .expose_admin_endpoints(true)
3315 .allow_unauthenticated_admin_endpoints(true)
3316 .build(),
3317 );
3318 cfg.validate()
3319 .expect("explicit unauth opt-out must validate");
3320 }
3321
3322 #[test]
3323 fn validate_accepts_unexposed_admin_endpoints_without_auth() {
3324 let mut cfg = validation_https_config();
3327 cfg.proxy = Some(
3328 OAuthProxyConfig::builder(
3329 "https://idp.example.com/authorize",
3330 "https://idp.example.com/token",
3331 "client",
3332 )
3333 .introspection_url("https://idp.example.com/introspect")
3334 .build(),
3335 );
3336 cfg.validate()
3337 .expect("unexposed admin endpoints must validate");
3338 }
3339
3340 #[test]
3341 fn validate_rejects_http_token_exchange_url() {
3342 let mut cfg = validation_https_config();
3343 cfg.token_exchange = Some(TokenExchangeConfig::new(
3344 "http://idp.example.com/token".into(), "client".into(),
3346 None,
3347 None,
3348 "downstream".into(),
3349 ));
3350 let err = cfg
3351 .validate()
3352 .expect_err("http token_exchange.token_url must be rejected");
3353 assert!(
3354 err.to_string().contains("oauth.token_exchange.token_url"),
3355 "error must reference token_exchange.token_url; got {err}"
3356 );
3357 }
3358
3359 #[test]
3360 fn validate_rejects_unparseable_url() {
3361 let mut cfg = validation_https_config();
3362 cfg.jwks_uri = "not a url".into();
3363 let err = cfg
3364 .validate()
3365 .expect_err("unparseable URL must be rejected");
3366 assert!(err.to_string().contains("invalid URL"));
3367 }
3368
3369 #[test]
3370 fn validate_rejects_non_http_scheme() {
3371 let mut cfg = validation_https_config();
3372 cfg.jwks_uri = "file:///etc/passwd".into();
3373 let err = cfg.validate().expect_err("file:// scheme must be rejected");
3374 let msg = err.to_string();
3375 assert!(
3376 msg.contains("must use https scheme") && msg.contains("file"),
3377 "error must reject non-http(s) schemes; got {msg:?}"
3378 );
3379 }
3380
3381 #[test]
3382 fn validate_accepts_http_with_escape_hatch() {
3383 let mut cfg = OAuthConfig::builder(
3388 "http://auth.local",
3389 "mcp",
3390 "http://auth.local/.well-known/jwks.json",
3391 )
3392 .allow_http_oauth_urls(true)
3393 .build();
3394 cfg.proxy = Some(
3395 OAuthProxyConfig::builder(
3396 "http://idp.local/authorize",
3397 "http://idp.local/token",
3398 "client",
3399 )
3400 .introspection_url("http://idp.local/introspect")
3401 .revocation_url("http://idp.local/revoke")
3402 .build(),
3403 );
3404 cfg.token_exchange = Some(TokenExchangeConfig::new(
3405 "http://idp.local/token".into(),
3406 "client".into(),
3407 Some(secrecy::SecretString::new("dev-secret".into())),
3408 None,
3409 "downstream".into(),
3410 ));
3411 cfg.validate()
3412 .expect("escape hatch must permit http on all URL fields");
3413 }
3414
3415 #[test]
3416 fn validate_with_escape_hatch_still_rejects_unparseable() {
3417 let mut cfg = validation_https_config();
3420 cfg.allow_http_oauth_urls = true;
3421 cfg.jwks_uri = "::not-a-url::".into();
3422 cfg.validate()
3423 .expect_err("escape hatch must NOT bypass URL parsing");
3424 }
3425
3426 #[tokio::test]
3427 async fn jwks_cache_rejects_redirect_downgrade_to_http() {
3428 rustls::crypto::ring::default_provider()
3443 .install_default()
3444 .ok();
3445
3446 let policy = reqwest::redirect::Policy::custom(|attempt| {
3447 if attempt.url().scheme() != "https" {
3448 attempt.error("redirect to non-HTTPS URL refused")
3449 } else if attempt.previous().len() >= 2 {
3450 attempt.error("too many redirects (max 2)")
3451 } else {
3452 attempt.follow()
3453 }
3454 });
3455 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = Arc::new(AtomicBool::new(true));
3462 let allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
3463 let resolver: Arc<dyn reqwest::dns::Resolve> = Arc::new(
3464 crate::ssrf_resolver::SsrfScreeningResolver::new(Arc::clone(&allowlist), test_bypass),
3465 );
3466 let client = reqwest::Client::builder()
3467 .no_proxy()
3468 .dns_resolver(Arc::clone(&resolver))
3469 .timeout(Duration::from_secs(5))
3470 .connect_timeout(Duration::from_secs(3))
3471 .redirect(policy)
3472 .build()
3473 .expect("test client builds");
3474
3475 let mock = wiremock::MockServer::start().await;
3476 wiremock::Mock::given(wiremock::matchers::method("GET"))
3477 .and(wiremock::matchers::path("/jwks.json"))
3478 .respond_with(
3479 wiremock::ResponseTemplate::new(302)
3480 .insert_header("location", "http://example.invalid/jwks.json"),
3481 )
3482 .mount(&mock)
3483 .await;
3484
3485 let url = format!("{}/jwks.json", mock.uri());
3494 let err = client
3495 .get(&url)
3496 .send()
3497 .await
3498 .expect_err("redirect policy must reject scheme downgrade");
3499 let chain = format!("{err:#}");
3500 assert!(
3501 chain.contains("redirect to non-HTTPS URL refused")
3502 || chain.to_lowercase().contains("redirect"),
3503 "error must surface redirect-policy rejection; got {chain:?}"
3504 );
3505 }
3506
3507 use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
3512
3513 fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
3515 let mut rng = rsa::rand_core::OsRng;
3516 let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keypair generation");
3517 let private_pem = private_key
3518 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
3519 .expect("PKCS8 PEM export")
3520 .to_string();
3521
3522 let public_key = private_key.to_public_key();
3523 let n = URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be());
3524 let e = URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be());
3525
3526 let jwks = serde_json::json!({
3527 "keys": [{
3528 "kty": "RSA",
3529 "use": "sig",
3530 "alg": "RS256",
3531 "kid": kid,
3532 "n": n,
3533 "e": e
3534 }]
3535 });
3536
3537 (private_pem, jwks)
3538 }
3539
3540 fn mint_token(
3542 private_pem: &str,
3543 kid: &str,
3544 issuer: &str,
3545 audience: &str,
3546 subject: &str,
3547 scope: &str,
3548 ) -> String {
3549 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3550 .expect("encoding key from PEM");
3551 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3552 header.kid = Some(kid.into());
3553
3554 let now = jsonwebtoken::get_current_timestamp();
3555 let claims = serde_json::json!({
3556 "iss": issuer,
3557 "aud": audience,
3558 "sub": subject,
3559 "scope": scope,
3560 "exp": now + 3600,
3561 "iat": now,
3562 });
3563
3564 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3565 }
3566
3567 fn test_config(jwks_uri: &str) -> OAuthConfig {
3568 OAuthConfig {
3569 issuer: "https://auth.test.local".into(),
3570 audience: "https://mcp.test.local/mcp".into(),
3571 jwks_uri: jwks_uri.into(),
3572 scopes: vec![
3573 ScopeMapping {
3574 scope: "mcp:read".into(),
3575 role: "viewer".into(),
3576 },
3577 ScopeMapping {
3578 scope: "mcp:admin".into(),
3579 role: "ops".into(),
3580 },
3581 ],
3582 role_claim: None,
3583 role_mappings: vec![],
3584 jwks_cache_ttl: "5m".into(),
3585 proxy: None,
3586 token_exchange: None,
3587 ca_cert_path: None,
3588 allow_http_oauth_urls: true,
3589 max_jwks_keys: default_max_jwks_keys(),
3590 #[allow(
3591 deprecated,
3592 reason = "test fixture: explicit value for the deprecated field"
3593 )]
3594 strict_audience_validation: false,
3595 audience_validation_mode: None,
3596 jwks_max_response_bytes: default_jwks_max_bytes(),
3597 ssrf_allowlist: None,
3598 }
3599 }
3600
3601 fn test_cache(config: &OAuthConfig) -> JwksCache {
3602 JwksCache::new(config).unwrap().__test_allow_loopback_ssrf()
3603 }
3604
3605 #[tokio::test]
3606 async fn valid_jwt_returns_identity() {
3607 let kid = "test-key-1";
3608 let (pem, jwks) = generate_test_keypair(kid);
3609
3610 let mock_server = wiremock::MockServer::start().await;
3611 wiremock::Mock::given(wiremock::matchers::method("GET"))
3612 .and(wiremock::matchers::path("/jwks.json"))
3613 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3614 .mount(&mock_server)
3615 .await;
3616
3617 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3618 let config = test_config(&jwks_uri);
3619 let cache = test_cache(&config);
3620
3621 let token = mint_token(
3622 &pem,
3623 kid,
3624 "https://auth.test.local",
3625 "https://mcp.test.local/mcp",
3626 "ci-bot",
3627 "mcp:read mcp:other",
3628 );
3629
3630 let identity = cache.validate_token(&token).await;
3631 assert!(identity.is_some(), "valid JWT should authenticate");
3632 let id = identity.unwrap();
3633 assert_eq!(id.name, "ci-bot");
3634 assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
3636 }
3637
3638 #[tokio::test]
3639 async fn wrong_issuer_rejected() {
3640 let kid = "test-key-2";
3641 let (pem, jwks) = generate_test_keypair(kid);
3642
3643 let mock_server = wiremock::MockServer::start().await;
3644 wiremock::Mock::given(wiremock::matchers::method("GET"))
3645 .and(wiremock::matchers::path("/jwks.json"))
3646 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3647 .mount(&mock_server)
3648 .await;
3649
3650 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3651 let config = test_config(&jwks_uri);
3652 let cache = test_cache(&config);
3653
3654 let token = mint_token(
3655 &pem,
3656 kid,
3657 "https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
3659 "attacker",
3660 "mcp:admin",
3661 );
3662
3663 assert!(cache.validate_token(&token).await.is_none());
3664 }
3665
3666 #[tokio::test]
3667 async fn wrong_audience_rejected() {
3668 let kid = "test-key-3";
3669 let (pem, jwks) = generate_test_keypair(kid);
3670
3671 let mock_server = wiremock::MockServer::start().await;
3672 wiremock::Mock::given(wiremock::matchers::method("GET"))
3673 .and(wiremock::matchers::path("/jwks.json"))
3674 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3675 .mount(&mock_server)
3676 .await;
3677
3678 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3679 let config = test_config(&jwks_uri);
3680 let cache = test_cache(&config);
3681
3682 let token = mint_token(
3683 &pem,
3684 kid,
3685 "https://auth.test.local",
3686 "https://wrong-audience.example.com", "attacker",
3688 "mcp:admin",
3689 );
3690
3691 assert!(cache.validate_token(&token).await.is_none());
3692 }
3693
3694 #[tokio::test]
3695 async fn expired_jwt_rejected() {
3696 let kid = "test-key-4";
3697 let (pem, jwks) = generate_test_keypair(kid);
3698
3699 let mock_server = wiremock::MockServer::start().await;
3700 wiremock::Mock::given(wiremock::matchers::method("GET"))
3701 .and(wiremock::matchers::path("/jwks.json"))
3702 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3703 .mount(&mock_server)
3704 .await;
3705
3706 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3707 let config = test_config(&jwks_uri);
3708 let cache = test_cache(&config);
3709
3710 let encoding_key =
3712 jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
3713 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3714 header.kid = Some(kid.into());
3715 let now = jsonwebtoken::get_current_timestamp();
3716 let claims = serde_json::json!({
3717 "iss": "https://auth.test.local",
3718 "aud": "https://mcp.test.local/mcp",
3719 "sub": "expired-bot",
3720 "scope": "mcp:read",
3721 "exp": now - 120,
3722 "iat": now - 3720,
3723 });
3724 let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
3725
3726 assert!(cache.validate_token(&token).await.is_none());
3727 }
3728
3729 #[tokio::test]
3730 async fn no_matching_scope_rejected() {
3731 let kid = "test-key-5";
3732 let (pem, jwks) = generate_test_keypair(kid);
3733
3734 let mock_server = wiremock::MockServer::start().await;
3735 wiremock::Mock::given(wiremock::matchers::method("GET"))
3736 .and(wiremock::matchers::path("/jwks.json"))
3737 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3738 .mount(&mock_server)
3739 .await;
3740
3741 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3742 let config = test_config(&jwks_uri);
3743 let cache = test_cache(&config);
3744
3745 let token = mint_token(
3746 &pem,
3747 kid,
3748 "https://auth.test.local",
3749 "https://mcp.test.local/mcp",
3750 "limited-bot",
3751 "some:other:scope", );
3753
3754 assert!(cache.validate_token(&token).await.is_none());
3755 }
3756
3757 #[tokio::test]
3758 async fn wrong_signing_key_rejected() {
3759 let kid = "test-key-6";
3760 let (_pem, jwks) = generate_test_keypair(kid);
3761
3762 let (attacker_pem, _) = generate_test_keypair(kid);
3764
3765 let mock_server = wiremock::MockServer::start().await;
3766 wiremock::Mock::given(wiremock::matchers::method("GET"))
3767 .and(wiremock::matchers::path("/jwks.json"))
3768 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3769 .mount(&mock_server)
3770 .await;
3771
3772 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3773 let config = test_config(&jwks_uri);
3774 let cache = test_cache(&config);
3775
3776 let token = mint_token(
3778 &attacker_pem,
3779 kid,
3780 "https://auth.test.local",
3781 "https://mcp.test.local/mcp",
3782 "attacker",
3783 "mcp:admin",
3784 );
3785
3786 assert!(cache.validate_token(&token).await.is_none());
3787 }
3788
3789 #[tokio::test]
3790 async fn admin_scope_maps_to_ops_role() {
3791 let kid = "test-key-7";
3792 let (pem, jwks) = generate_test_keypair(kid);
3793
3794 let mock_server = wiremock::MockServer::start().await;
3795 wiremock::Mock::given(wiremock::matchers::method("GET"))
3796 .and(wiremock::matchers::path("/jwks.json"))
3797 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3798 .mount(&mock_server)
3799 .await;
3800
3801 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3802 let config = test_config(&jwks_uri);
3803 let cache = test_cache(&config);
3804
3805 let token = mint_token(
3806 &pem,
3807 kid,
3808 "https://auth.test.local",
3809 "https://mcp.test.local/mcp",
3810 "admin-bot",
3811 "mcp:admin",
3812 );
3813
3814 let id = cache
3815 .validate_token(&token)
3816 .await
3817 .expect("should authenticate");
3818 assert_eq!(id.role, "ops");
3819 assert_eq!(id.name, "admin-bot");
3820 }
3821
3822 #[tokio::test]
3823 async fn jwks_server_down_returns_none() {
3824 let config = test_config("http://127.0.0.1:1/jwks.json");
3826 let cache = test_cache(&config);
3827
3828 let kid = "orphan-key";
3829 let (pem, _) = generate_test_keypair(kid);
3830 let token = mint_token(
3831 &pem,
3832 kid,
3833 "https://auth.test.local",
3834 "https://mcp.test.local/mcp",
3835 "bot",
3836 "mcp:read",
3837 );
3838
3839 assert!(cache.validate_token(&token).await.is_none());
3840 }
3841
3842 #[test]
3847 fn resolve_claim_path_flat_string() {
3848 let mut extra = HashMap::new();
3849 extra.insert(
3850 "scope".into(),
3851 serde_json::Value::String("mcp:read mcp:admin".into()),
3852 );
3853 let values = resolve_claim_path(&extra, "scope");
3854 assert_eq!(values, vec!["mcp:read", "mcp:admin"]);
3855 }
3856
3857 #[test]
3858 fn resolve_claim_path_flat_array() {
3859 let mut extra = HashMap::new();
3860 extra.insert(
3861 "roles".into(),
3862 serde_json::json!(["mcp-admin", "mcp-viewer"]),
3863 );
3864 let values = resolve_claim_path(&extra, "roles");
3865 assert_eq!(values, vec!["mcp-admin", "mcp-viewer"]);
3866 }
3867
3868 #[test]
3869 fn resolve_claim_path_nested_keycloak() {
3870 let mut extra = HashMap::new();
3871 extra.insert(
3872 "realm_access".into(),
3873 serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]}),
3874 );
3875 let values = resolve_claim_path(&extra, "realm_access.roles");
3876 assert_eq!(values, vec!["uma_authorization", "mcp-admin"]);
3877 }
3878
3879 #[test]
3880 fn resolve_claim_path_missing_returns_empty() {
3881 let extra = HashMap::new();
3882 assert!(resolve_claim_path(&extra, "nonexistent.path").is_empty());
3883 }
3884
3885 #[test]
3886 fn resolve_claim_path_numeric_leaf_returns_empty() {
3887 let mut extra = HashMap::new();
3888 extra.insert("count".into(), serde_json::json!(42));
3889 assert!(resolve_claim_path(&extra, "count").is_empty());
3890 }
3891
3892 fn make_claims(json: serde_json::Value) -> Claims {
3893 serde_json::from_value(json).expect("test claims must deserialize")
3894 }
3895
3896 #[test]
3897 fn first_class_scope_claim_splits_on_whitespace() {
3898 let claims = make_claims(serde_json::json!({
3899 "iss": "https://issuer.example.com",
3900 "exp": 9_999_999_999_u64,
3901 "scope": "read write admin",
3902 }));
3903 let values = first_class_claim_values(&claims, "scope");
3904 assert_eq!(values, vec!["read", "write", "admin"]);
3905 }
3906
3907 #[test]
3908 fn first_class_sub_claim_returns_single_value() {
3909 let claims = make_claims(serde_json::json!({
3910 "iss": "https://issuer.example.com",
3911 "exp": 9_999_999_999_u64,
3912 "sub": "service-account-orders",
3913 }));
3914 let values = first_class_claim_values(&claims, "sub");
3915 assert_eq!(values, vec!["service-account-orders"]);
3916 }
3917
3918 #[test]
3919 fn first_class_aud_claim_returns_every_audience() {
3920 let claims = make_claims(serde_json::json!({
3921 "iss": "https://issuer.example.com",
3922 "exp": 9_999_999_999_u64,
3923 "aud": ["api-a", "api-b"],
3924 }));
3925 let values = first_class_claim_values(&claims, "aud");
3926 assert_eq!(values, vec!["api-a", "api-b"]);
3927 }
3928
3929 #[test]
3930 fn first_class_unknown_path_returns_empty() {
3931 let claims = make_claims(serde_json::json!({
3932 "iss": "https://issuer.example.com",
3933 "exp": 9_999_999_999_u64,
3934 }));
3935 assert!(first_class_claim_values(&claims, "realm_access.roles").is_empty());
3936 }
3937
3938 fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
3944 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3945 .expect("encoding key from PEM");
3946 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3947 header.kid = Some(kid.into());
3948 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3949 }
3950
3951 fn test_config_with_role_claim(
3952 jwks_uri: &str,
3953 role_claim: &str,
3954 role_mappings: Vec<RoleMapping>,
3955 ) -> OAuthConfig {
3956 OAuthConfig {
3957 issuer: "https://auth.test.local".into(),
3958 audience: "https://mcp.test.local/mcp".into(),
3959 jwks_uri: jwks_uri.into(),
3960 scopes: vec![],
3961 role_claim: Some(role_claim.into()),
3962 role_mappings,
3963 jwks_cache_ttl: "5m".into(),
3964 proxy: None,
3965 token_exchange: None,
3966 ca_cert_path: None,
3967 allow_http_oauth_urls: true,
3968 max_jwks_keys: default_max_jwks_keys(),
3969 #[allow(
3970 deprecated,
3971 reason = "test fixture: explicit value for the deprecated field"
3972 )]
3973 strict_audience_validation: false,
3974 audience_validation_mode: None,
3975 jwks_max_response_bytes: default_jwks_max_bytes(),
3976 ssrf_allowlist: None,
3977 }
3978 }
3979
3980 #[tokio::test]
3981 async fn screen_oauth_target_rejects_literal_ip() {
3982 let err = screen_oauth_target(
3983 "https://127.0.0.1/jwks.json",
3984 false,
3985 &crate::ssrf::CompiledSsrfAllowlist::default(),
3986 )
3987 .await
3988 .expect_err("literal IPs must be rejected");
3989 let msg = err.to_string();
3990 assert!(msg.contains("literal IPv4 addresses are forbidden"));
3991 }
3992
3993 #[tokio::test]
3994 async fn screen_oauth_target_rejects_private_dns_resolution() {
3995 let err = screen_oauth_target(
3996 "https://localhost/jwks.json",
3997 false,
3998 &crate::ssrf::CompiledSsrfAllowlist::default(),
3999 )
4000 .await
4001 .expect_err("localhost resolution must be rejected");
4002 let msg = err.to_string();
4003 assert!(
4004 msg.contains("blocked IP") && msg.contains("loopback"),
4005 "got {msg:?}"
4006 );
4007 }
4008
4009 #[tokio::test]
4010 async fn screen_oauth_target_rejects_literal_ip_even_with_allow_http() {
4011 let err = screen_oauth_target(
4012 "http://127.0.0.1/jwks.json",
4013 true,
4014 &crate::ssrf::CompiledSsrfAllowlist::default(),
4015 )
4016 .await
4017 .expect_err("literal IPs must still be rejected when http is allowed");
4018 let msg = err.to_string();
4019 assert!(msg.contains("literal IPv4 addresses are forbidden"));
4020 }
4021
4022 #[tokio::test]
4023 async fn screen_oauth_target_rejects_private_dns_even_with_allow_http() {
4024 let err = screen_oauth_target(
4025 "http://localhost/jwks.json",
4026 true,
4027 &crate::ssrf::CompiledSsrfAllowlist::default(),
4028 )
4029 .await
4030 .expect_err("private DNS resolution must still be rejected when http is allowed");
4031 let msg = err.to_string();
4032 assert!(
4033 msg.contains("blocked IP") && msg.contains("loopback"),
4034 "got {msg:?}"
4035 );
4036 }
4037
4038 #[tokio::test]
4039 async fn screen_oauth_target_allows_public_hostname() {
4040 screen_oauth_target(
4041 "https://example.com/.well-known/jwks.json",
4042 false,
4043 &crate::ssrf::CompiledSsrfAllowlist::default(),
4044 )
4045 .await
4046 .expect("public hostname should pass screening");
4047 }
4048
4049 fn make_allowlist(hosts: &[&str], cidrs: &[&str]) -> crate::ssrf::CompiledSsrfAllowlist {
4055 let raw = OAuthSsrfAllowlist {
4056 hosts: hosts.iter().map(|s| (*s).to_string()).collect(),
4057 cidrs: cidrs.iter().map(|s| (*s).to_string()).collect(),
4058 };
4059 compile_oauth_ssrf_allowlist(&raw).expect("test allowlist compiles")
4060 }
4061
4062 #[test]
4063 fn compile_oauth_ssrf_allowlist_lowercases_and_dedupes_hosts() {
4064 let raw = OAuthSsrfAllowlist {
4065 hosts: vec!["RHBK.ops.example.com".into(), "rhbk.ops.example.com".into()],
4066 cidrs: vec![],
4067 };
4068 let compiled = compile_oauth_ssrf_allowlist(&raw).expect("compiles");
4069 assert_eq!(compiled.host_count(), 1);
4070 assert!(compiled.host_allowed("rhbk.ops.example.com"));
4071 assert!(compiled.host_allowed("RHBK.OPS.EXAMPLE.COM"));
4072 }
4073
4074 #[test]
4075 fn compile_oauth_ssrf_allowlist_rejects_literal_ip_in_hosts() {
4076 let raw = OAuthSsrfAllowlist {
4077 hosts: vec!["10.0.0.1".into()],
4078 cidrs: vec![],
4079 };
4080 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("literal IP in hosts");
4081 assert!(err.contains("literal IPs are forbidden"), "got {err:?}");
4082 }
4083
/// A `host:port` entry in `hosts` must be refused: only bare DNS hostnames
/// are accepted.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_host_with_port() {
    let outcome = compile_oauth_ssrf_allowlist(&OAuthSsrfAllowlist {
        hosts: vec!["rhbk.ops.example.com:8443".into()],
        cidrs: vec![],
    });
    let err = outcome.expect_err("host:port");
    assert!(err.contains("must be a bare DNS hostname"), "got {err:?}");
}
4093
/// An unparseable CIDR must be refused, and the error must name the
/// offending entry by its index in the config path.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_invalid_cidr() {
    let outcome = compile_oauth_ssrf_allowlist(&OAuthSsrfAllowlist {
        hosts: vec![],
        cidrs: vec!["not-a-cidr".into()],
    });
    let err = outcome.expect_err("invalid CIDR");
    assert!(err.contains("oauth.ssrf_allowlist.cidrs[0]"), "got {err:?}");
}
4103
/// `OAuthConfig::validate` must surface allowlist compilation failures and
/// mention the `oauth.ssrf_allowlist` config key in the error.
#[test]
fn validate_rejects_misconfigured_allowlist() {
    let builder = OAuthConfig::builder(
        "https://auth.example.com/",
        "mcp",
        "https://auth.example.com/jwks.json",
    );
    let mut cfg = builder.build();
    cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
        hosts: vec!["10.0.0.1".into()],
        cidrs: vec![],
    });

    let outcome = cfg.validate();
    let err = outcome.expect_err("literal IP host must be rejected");
    let mentions_key = err.to_string().contains("oauth.ssrf_allowlist");
    assert!(mentions_key, "got {err}");
}
4124
4125 #[tokio::test]
4126 async fn screen_oauth_target_with_allowlist_emits_helpful_error() {
4127 let allow = make_allowlist(&["other.example.com"], &["10.0.0.0/8"]);
4131 let err = screen_oauth_target("https://localhost/jwks.json", false, &allow)
4132 .await
4133 .expect_err("loopback must still be blocked when not in allowlist");
4134 let msg = err.to_string();
4135 assert!(msg.contains("OAuth target blocked"), "got {msg:?}");
4136 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4137 assert!(msg.contains("SECURITY.md"), "got {msg:?}");
4138 }
4139
4140 #[tokio::test]
4141 async fn screen_oauth_target_empty_allowlist_uses_legacy_message() {
4142 let err = screen_oauth_target(
4145 "https://localhost/jwks.json",
4146 false,
4147 &crate::ssrf::CompiledSsrfAllowlist::default(),
4148 )
4149 .await
4150 .expect_err("loopback rejection");
4151 let msg = err.to_string();
4152 assert!(msg.contains("blocked IP"), "got {msg:?}");
4153 assert!(msg.contains("loopback"), "got {msg:?}");
4154 assert!(!msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4156 }
4157
4158 #[tokio::test]
4159 async fn screen_oauth_target_allows_loopback_when_host_allowlisted() {
4160 let allow = make_allowlist(&["localhost"], &[]);
4162 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4163 .await
4164 .expect("allowlisted host must pass");
4165 }
4166
4167 #[tokio::test]
4168 async fn screen_oauth_target_allows_loopback_when_cidr_allowlisted() {
4169 let allow = make_allowlist(&[], &["127.0.0.0/8", "::1/128"]);
4172 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4173 .await
4174 .expect("allowlisted CIDR must pass");
4175 }
4176
4177 #[tokio::test]
4178 async fn jwks_cache_rejects_misconfigured_allowlist_at_startup() {
4179 let mut cfg = OAuthConfig::builder(
4180 "https://auth.example.com/",
4181 "mcp",
4182 "https://auth.example.com/jwks.json",
4183 )
4184 .build();
4185 cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
4186 hosts: vec![],
4187 cidrs: vec!["bad-cidr".into()],
4188 });
4189 let Err(err) = JwksCache::new(&cfg) else {
4190 panic!("invalid CIDR must fail JwksCache::new")
4191 };
4192 let msg = err.to_string();
4193 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4194 }
4195
4196 #[tokio::test]
4197 async fn audience_falls_back_to_azp_by_default() {
4198 let kid = "test-audience-azp-default";
4199 let (pem, jwks) = generate_test_keypair(kid);
4200
4201 let mock_server = wiremock::MockServer::start().await;
4202 wiremock::Mock::given(wiremock::matchers::method("GET"))
4203 .and(wiremock::matchers::path("/jwks.json"))
4204 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4205 .mount(&mock_server)
4206 .await;
4207
4208 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4209 let config = test_config(&jwks_uri);
4210 let cache = test_cache(&config);
4211
4212 let now = jsonwebtoken::get_current_timestamp();
4213 let token = mint_token_with_claims(
4214 &pem,
4215 kid,
4216 &serde_json::json!({
4217 "iss": "https://auth.test.local",
4218 "aud": "https://some-other-resource.example.com",
4219 "azp": "https://mcp.test.local/mcp",
4220 "sub": "compat-client",
4221 "scope": "mcp:read",
4222 "exp": now + 3600,
4223 "iat": now,
4224 }),
4225 );
4226
4227 let identity = cache
4228 .validate_token_with_reason(&token)
4229 .await
4230 .expect("azp fallback should remain enabled by default");
4231 assert_eq!(identity.role, "viewer");
4232 }
4233
4234 #[tokio::test]
4235 async fn strict_audience_validation_rejects_azp_only_match() {
4236 let kid = "test-audience-azp-strict";
4237 let (pem, jwks) = generate_test_keypair(kid);
4238
4239 let mock_server = wiremock::MockServer::start().await;
4240 wiremock::Mock::given(wiremock::matchers::method("GET"))
4241 .and(wiremock::matchers::path("/jwks.json"))
4242 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4243 .mount(&mock_server)
4244 .await;
4245
4246 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4247 let mut config = test_config(&jwks_uri);
4248 #[allow(deprecated, reason = "covers the legacy bool resolution path")]
4249 {
4250 config.strict_audience_validation = true;
4251 }
4252 let cache = test_cache(&config);
4253
4254 let now = jsonwebtoken::get_current_timestamp();
4255 let token = mint_token_with_claims(
4256 &pem,
4257 kid,
4258 &serde_json::json!({
4259 "iss": "https://auth.test.local",
4260 "aud": "https://some-other-resource.example.com",
4261 "azp": "https://mcp.test.local/mcp",
4262 "sub": "strict-client",
4263 "scope": "mcp:read",
4264 "exp": now + 3600,
4265 "iat": now,
4266 }),
4267 );
4268
4269 let failure = cache
4270 .validate_token_with_reason(&token)
4271 .await
4272 .expect_err("strict audience validation must ignore azp fallback");
4273 assert_eq!(failure, JwtValidationFailure::Invalid);
4274 }
4275
4276 #[tokio::test]
4277 async fn warn_mode_accepts_azp_only_match_and_warns_once() {
4278 let kid = "test-audience-warn-mode";
4279 let (pem, jwks) = generate_test_keypair(kid);
4280
4281 let mock_server = wiremock::MockServer::start().await;
4282 wiremock::Mock::given(wiremock::matchers::method("GET"))
4283 .and(wiremock::matchers::path("/jwks.json"))
4284 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4285 .mount(&mock_server)
4286 .await;
4287
4288 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4289 let mut config = test_config(&jwks_uri);
4290 config.audience_validation_mode = Some(AudienceValidationMode::Warn);
4291 let cache = test_cache(&config);
4292
4293 let now = jsonwebtoken::get_current_timestamp();
4294 let claims = serde_json::json!({
4295 "iss": "https://auth.test.local",
4296 "aud": "https://some-other-resource.example.com",
4297 "azp": "https://mcp.test.local/mcp",
4298 "sub": "warn-client",
4299 "scope": "mcp:read",
4300 "exp": now + 3600,
4301 "iat": now,
4302 });
4303 let token = mint_token_with_claims(&pem, kid, &claims);
4304
4305 let identity = cache
4306 .validate_token_with_reason(&token)
4307 .await
4308 .expect("warn mode must accept azp-only match");
4309 assert_eq!(identity.role, "viewer");
4310 assert!(
4311 cache.azp_fallback_warned.load(Ordering::Relaxed),
4312 "warn-once flag should be set after first azp-only match"
4313 );
4314
4315 let token2 = mint_token_with_claims(&pem, kid, &claims);
4316 cache
4317 .validate_token_with_reason(&token2)
4318 .await
4319 .expect("warn mode must continue accepting subsequent matches");
4320 assert!(
4321 cache.azp_fallback_warned.load(Ordering::Relaxed),
4322 "warn-once flag must remain set; the assertion guards against accidental clearing"
4323 );
4324 }
4325
4326 #[tokio::test]
4327 async fn permissive_mode_accepts_azp_only_match_silently() {
4328 let kid = "test-audience-permissive-mode";
4329 let (pem, jwks) = generate_test_keypair(kid);
4330
4331 let mock_server = wiremock::MockServer::start().await;
4332 wiremock::Mock::given(wiremock::matchers::method("GET"))
4333 .and(wiremock::matchers::path("/jwks.json"))
4334 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4335 .mount(&mock_server)
4336 .await;
4337
4338 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4339 let mut config = test_config(&jwks_uri);
4340 config.audience_validation_mode = Some(AudienceValidationMode::Permissive);
4341 let cache = test_cache(&config);
4342
4343 let now = jsonwebtoken::get_current_timestamp();
4344 let token = mint_token_with_claims(
4345 &pem,
4346 kid,
4347 &serde_json::json!({
4348 "iss": "https://auth.test.local",
4349 "aud": "https://some-other-resource.example.com",
4350 "azp": "https://mcp.test.local/mcp",
4351 "sub": "permissive-client",
4352 "scope": "mcp:read",
4353 "exp": now + 3600,
4354 "iat": now,
4355 }),
4356 );
4357
4358 cache
4359 .validate_token_with_reason(&token)
4360 .await
4361 .expect("permissive mode must accept azp-only match");
4362 assert!(
4363 !cache.azp_fallback_warned.load(Ordering::Relaxed),
4364 "permissive mode must not flip the warn-once flag"
4365 );
4366 }
4367
/// An explicit `audience_validation_mode` takes precedence over the
/// deprecated `strict_audience_validation` bool, whichever way the bool is set.
#[test]
fn audience_validation_mode_overrides_legacy_bool() {
    // Legacy bool says "not strict", explicit mode says Strict.
    let mut legacy_false = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
    {
        legacy_false.strict_audience_validation = false;
    }
    legacy_false.audience_validation_mode = Some(AudienceValidationMode::Strict);
    let resolved = legacy_false.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Strict,
        "explicit mode must override legacy false"
    );

    // Legacy bool says "strict", explicit mode says Permissive.
    let mut legacy_true = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
    {
        legacy_true.strict_audience_validation = true;
    }
    legacy_true.audience_validation_mode = Some(AudienceValidationMode::Permissive);
    let resolved = legacy_true.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Permissive,
        "explicit mode must override legacy true"
    );
}
4394
/// A default `OAuthConfig` must resolve its audience mode to `Warn`.
#[test]
fn audience_validation_mode_default_is_warn_when_unset() {
    let resolved = OAuthConfig::default().effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Warn,
        "unset mode + unset bool must resolve to Warn (the new default)"
    );
}
4404
/// With no explicit mode, setting the deprecated bool to `true` must resolve
/// to `Strict`.
#[test]
fn audience_validation_legacy_bool_true_resolves_to_strict() {
    let mut legacy = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the legacy bool resolution path")]
    {
        legacy.strict_audience_validation = true;
    }
    let resolved = legacy.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Strict,
        "legacy bool=true must resolve to Strict for backward compat"
    );
}
4418
/// Cloneable in-memory log sink for tests; every clone shares one byte buffer.
#[derive(Clone, Default)]
struct CapturedLogs(Arc<std::sync::Mutex<Vec<u8>>>);

impl CapturedLogs {
    /// Returns everything captured so far as text.
    ///
    /// A poisoned mutex is recovered rather than treated as "no logs", and
    /// invalid UTF-8 is replaced with U+FFFD instead of discarding the whole
    /// buffer. Either way a failing log assertion still sees the output that
    /// was actually captured.
    fn contents(&self) -> String {
        let captured = self
            .0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        String::from_utf8_lossy(&captured).into_owned()
    }
}
4428
/// Writer handle that appends into a shared byte buffer behind a mutex.
struct CapturedLogsWriter(Arc<std::sync::Mutex<Vec<u8>>>);

impl std::io::Write for CapturedLogsWriter {
    /// Appends `buf` to the shared buffer and reports the full length written.
    ///
    /// A poisoned mutex is recovered so the bytes are really stored; without
    /// that, the write would be dropped while still returning `Ok(buf.len())`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .extend_from_slice(buf);
        Ok(buf.len())
    }

    /// Nothing is buffered outside the shared `Vec`, so flushing is a no-op.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
4443
4444 impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CapturedLogs {
4445 type Writer = CapturedLogsWriter;
4446
4447 fn make_writer(&'a self) -> Self::Writer {
4448 CapturedLogsWriter(Arc::clone(&self.0))
4449 }
4450 }
4451
/// A JWKS response larger than `jwks_max_response_bytes` must make
/// `fetch_jwks` return `None` and log a cap-exceeded warning.
#[tokio::test]
async fn jwks_response_size_cap_returns_none_and_logs_warning() {
    let kid = "oversized-jwks";
    let (_pem, jwks) = generate_test_keypair(kid);
    // Pad the JWKS JSON with 4096 spaces so the body is well past the
    // 256-byte cap configured below.
    let mut oversized_body = serde_json::to_string(&jwks).expect("jwks json");
    oversized_body.push_str(&" ".repeat(4096));

    let mock_server = wiremock::MockServer::start().await;
    wiremock::Mock::given(wiremock::matchers::method("GET"))
        .and(wiremock::matchers::path("/jwks.json"))
        .respond_with(
            wiremock::ResponseTemplate::new(200)
                .insert_header("content-type", "application/json")
                .set_body_string(oversized_body),
        )
        .mount(&mock_server)
        .await;

    let jwks_uri = format!("{}/jwks.json", mock_server.uri());
    let mut config = test_config(&jwks_uri);
    config.jwks_max_response_bytes = 256;
    let cache = test_cache(&config);

    // Install a capturing subscriber before the fetch. `_guard` must stay
    // alive until after the log assertion, because dropping it uninstalls
    // the subscriber.
    let logs = CapturedLogs::default();
    let subscriber = tracing_subscriber::fmt()
        .with_writer(logs.clone())
        .with_ansi(false)
        .without_time()
        .finish();
    let _guard = tracing::subscriber::set_default(subscriber);

    let result = cache.fetch_jwks().await;
    assert!(result.is_none(), "oversized JWKS must be dropped");
    assert!(
        logs.contents()
            .contains("JWKS response exceeded configured size cap"),
        "expected cap-exceeded warning in logs"
    );
}
4491
4492 #[tokio::test]
4493 async fn role_claim_keycloak_nested_array() {
4494 let kid = "test-role-1";
4495 let (pem, jwks) = generate_test_keypair(kid);
4496
4497 let mock_server = wiremock::MockServer::start().await;
4498 wiremock::Mock::given(wiremock::matchers::method("GET"))
4499 .and(wiremock::matchers::path("/jwks.json"))
4500 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4501 .mount(&mock_server)
4502 .await;
4503
4504 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4505 let config = test_config_with_role_claim(
4506 &jwks_uri,
4507 "realm_access.roles",
4508 vec![
4509 RoleMapping {
4510 claim_value: "mcp-admin".into(),
4511 role: "ops".into(),
4512 },
4513 RoleMapping {
4514 claim_value: "mcp-viewer".into(),
4515 role: "viewer".into(),
4516 },
4517 ],
4518 );
4519 let cache = test_cache(&config);
4520
4521 let now = jsonwebtoken::get_current_timestamp();
4522 let token = mint_token_with_claims(
4523 &pem,
4524 kid,
4525 &serde_json::json!({
4526 "iss": "https://auth.test.local",
4527 "aud": "https://mcp.test.local/mcp",
4528 "sub": "keycloak-user",
4529 "exp": now + 3600,
4530 "iat": now,
4531 "realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
4532 }),
4533 );
4534
4535 let id = cache
4536 .validate_token(&token)
4537 .await
4538 .expect("should authenticate");
4539 assert_eq!(id.name, "keycloak-user");
4540 assert_eq!(id.role, "ops");
4541 }
4542
4543 #[tokio::test]
4544 async fn role_claim_flat_roles_array() {
4545 let kid = "test-role-2";
4546 let (pem, jwks) = generate_test_keypair(kid);
4547
4548 let mock_server = wiremock::MockServer::start().await;
4549 wiremock::Mock::given(wiremock::matchers::method("GET"))
4550 .and(wiremock::matchers::path("/jwks.json"))
4551 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4552 .mount(&mock_server)
4553 .await;
4554
4555 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4556 let config = test_config_with_role_claim(
4557 &jwks_uri,
4558 "roles",
4559 vec![
4560 RoleMapping {
4561 claim_value: "MCP.Admin".into(),
4562 role: "ops".into(),
4563 },
4564 RoleMapping {
4565 claim_value: "MCP.Reader".into(),
4566 role: "viewer".into(),
4567 },
4568 ],
4569 );
4570 let cache = test_cache(&config);
4571
4572 let now = jsonwebtoken::get_current_timestamp();
4573 let token = mint_token_with_claims(
4574 &pem,
4575 kid,
4576 &serde_json::json!({
4577 "iss": "https://auth.test.local",
4578 "aud": "https://mcp.test.local/mcp",
4579 "sub": "azure-ad-user",
4580 "exp": now + 3600,
4581 "iat": now,
4582 "roles": ["MCP.Reader", "OtherApp.Admin"]
4583 }),
4584 );
4585
4586 let id = cache
4587 .validate_token(&token)
4588 .await
4589 .expect("should authenticate");
4590 assert_eq!(id.name, "azure-ad-user");
4591 assert_eq!(id.role, "viewer");
4592 }
4593
4594 #[tokio::test]
4595 async fn role_claim_no_matching_value_rejected() {
4596 let kid = "test-role-3";
4597 let (pem, jwks) = generate_test_keypair(kid);
4598
4599 let mock_server = wiremock::MockServer::start().await;
4600 wiremock::Mock::given(wiremock::matchers::method("GET"))
4601 .and(wiremock::matchers::path("/jwks.json"))
4602 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4603 .mount(&mock_server)
4604 .await;
4605
4606 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4607 let config = test_config_with_role_claim(
4608 &jwks_uri,
4609 "roles",
4610 vec![RoleMapping {
4611 claim_value: "mcp-admin".into(),
4612 role: "ops".into(),
4613 }],
4614 );
4615 let cache = test_cache(&config);
4616
4617 let now = jsonwebtoken::get_current_timestamp();
4618 let token = mint_token_with_claims(
4619 &pem,
4620 kid,
4621 &serde_json::json!({
4622 "iss": "https://auth.test.local",
4623 "aud": "https://mcp.test.local/mcp",
4624 "sub": "limited-user",
4625 "exp": now + 3600,
4626 "iat": now,
4627 "roles": ["some-other-role"]
4628 }),
4629 );
4630
4631 assert!(cache.validate_token(&token).await.is_none());
4632 }
4633
4634 #[tokio::test]
4635 async fn role_claim_space_separated_string() {
4636 let kid = "test-role-4";
4637 let (pem, jwks) = generate_test_keypair(kid);
4638
4639 let mock_server = wiremock::MockServer::start().await;
4640 wiremock::Mock::given(wiremock::matchers::method("GET"))
4641 .and(wiremock::matchers::path("/jwks.json"))
4642 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4643 .mount(&mock_server)
4644 .await;
4645
4646 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4647 let config = test_config_with_role_claim(
4648 &jwks_uri,
4649 "custom_scope",
4650 vec![
4651 RoleMapping {
4652 claim_value: "write".into(),
4653 role: "ops".into(),
4654 },
4655 RoleMapping {
4656 claim_value: "read".into(),
4657 role: "viewer".into(),
4658 },
4659 ],
4660 );
4661 let cache = test_cache(&config);
4662
4663 let now = jsonwebtoken::get_current_timestamp();
4664 let token = mint_token_with_claims(
4665 &pem,
4666 kid,
4667 &serde_json::json!({
4668 "iss": "https://auth.test.local",
4669 "aud": "https://mcp.test.local/mcp",
4670 "sub": "custom-client",
4671 "exp": now + 3600,
4672 "iat": now,
4673 "custom_scope": "read audit"
4674 }),
4675 );
4676
4677 let id = cache
4678 .validate_token(&token)
4679 .await
4680 .expect("should authenticate");
4681 assert_eq!(id.name, "custom-client");
4682 assert_eq!(id.role, "viewer");
4683 }
4684
4685 #[tokio::test]
4686 async fn scope_backward_compat_without_role_claim() {
4687 let kid = "test-compat-1";
4689 let (pem, jwks) = generate_test_keypair(kid);
4690
4691 let mock_server = wiremock::MockServer::start().await;
4692 wiremock::Mock::given(wiremock::matchers::method("GET"))
4693 .and(wiremock::matchers::path("/jwks.json"))
4694 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4695 .mount(&mock_server)
4696 .await;
4697
4698 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4699 let config = test_config(&jwks_uri); let cache = test_cache(&config);
4701
4702 let token = mint_token(
4703 &pem,
4704 kid,
4705 "https://auth.test.local",
4706 "https://mcp.test.local/mcp",
4707 "legacy-bot",
4708 "mcp:admin other:scope",
4709 );
4710
4711 let id = cache
4712 .validate_token(&token)
4713 .await
4714 .expect("should authenticate");
4715 assert_eq!(id.name, "legacy-bot");
4716 assert_eq!(id.role, "ops"); }
4718
4719 #[tokio::test]
4724 async fn jwks_refresh_deduplication() {
4725 let kid = "test-dedup";
4728 let (pem, jwks) = generate_test_keypair(kid);
4729
4730 let mock_server = wiremock::MockServer::start().await;
4731 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4732 .and(wiremock::matchers::path("/jwks.json"))
4733 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4734 .expect(1) .mount(&mock_server)
4736 .await;
4737
4738 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4739 let config = test_config(&jwks_uri);
4740 let cache = Arc::new(test_cache(&config));
4741
4742 let token = mint_token(
4744 &pem,
4745 kid,
4746 "https://auth.test.local",
4747 "https://mcp.test.local/mcp",
4748 "concurrent-bot",
4749 "mcp:read",
4750 );
4751
4752 let mut handles = Vec::new();
4753 for _ in 0..5 {
4754 let c = Arc::clone(&cache);
4755 let t = token.clone();
4756 handles.push(tokio::spawn(async move { c.validate_token(&t).await }));
4757 }
4758
4759 for h in handles {
4760 let result = h.await.unwrap();
4761 assert!(result.is_some(), "all concurrent requests should succeed");
4762 }
4763
4764 }
4766
4767 #[tokio::test]
4768 async fn jwks_refresh_cooldown_blocks_rapid_requests() {
4769 let kid = "test-cooldown";
4772 let (_pem, jwks) = generate_test_keypair(kid);
4773
4774 let mock_server = wiremock::MockServer::start().await;
4775 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4776 .and(wiremock::matchers::path("/jwks.json"))
4777 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4778 .expect(1) .mount(&mock_server)
4780 .await;
4781
4782 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4783 let config = test_config(&jwks_uri);
4784 let cache = test_cache(&config);
4785
4786 let fake_token1 =
4788 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig";
4789 let _ = cache.validate_token(fake_token1).await;
4790
4791 let fake_token2 =
4794 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig";
4795 let _ = cache.validate_token(fake_token2).await;
4796
4797 let fake_token3 =
4799 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig";
4800 let _ = cache.validate_token(fake_token3).await;
4801
4802 }
4804
4805 fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
4808 OAuthProxyConfig {
4809 authorize_url: "https://example.invalid/auth".into(),
4810 token_url: token_url.into(),
4811 client_id: "mcp-client".into(),
4812 client_secret: Some(secrecy::SecretString::from("shh".to_owned())),
4813 introspection_url: None,
4814 revocation_url: None,
4815 expose_admin_endpoints: false,
4816 require_auth_on_admin_endpoints: false,
4817 allow_unauthenticated_admin_endpoints: false,
4818 }
4819 }
4820
4821 fn test_http_client() -> OauthHttpClient {
4824 rustls::crypto::ring::default_provider()
4825 .install_default()
4826 .ok();
4827 let config = OAuthConfig::builder(
4828 "https://auth.test.local",
4829 "https://mcp.test.local/mcp",
4830 "https://auth.test.local/.well-known/jwks.json",
4831 )
4832 .allow_http_oauth_urls(true)
4833 .build();
4834 OauthHttpClient::with_config(&config)
4835 .expect("build test http client")
4836 .__test_allow_loopback_ssrf()
4837 }
4838
4839 #[tokio::test]
4840 async fn introspect_proxies_and_injects_client_credentials() {
4841 use wiremock::matchers::{body_string_contains, method, path};
4842
4843 let mock_server = wiremock::MockServer::start().await;
4844 wiremock::Mock::given(method("POST"))
4845 .and(path("/introspect"))
4846 .and(body_string_contains("client_id=mcp-client"))
4847 .and(body_string_contains("client_secret=shh"))
4848 .and(body_string_contains("token=abc"))
4849 .respond_with(
4850 wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
4851 "active": true,
4852 "scope": "read"
4853 })),
4854 )
4855 .expect(1)
4856 .mount(&mock_server)
4857 .await;
4858
4859 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4860 proxy.introspection_url = Some(format!("{}/introspect", mock_server.uri()));
4861
4862 let http = test_http_client();
4863 let resp = handle_introspect(&http, &proxy, "token=abc").await;
4864 assert_eq!(resp.status(), 200);
4865 }
4866
4867 #[tokio::test]
4868 async fn introspect_returns_404_when_not_configured() {
4869 let proxy = proxy_cfg("https://example.invalid/token");
4870 let http = test_http_client();
4871 let resp = handle_introspect(&http, &proxy, "token=abc").await;
4872 assert_eq!(resp.status(), 404);
4873 }
4874
4875 #[tokio::test]
4876 async fn revoke_proxies_and_returns_upstream_status() {
4877 use wiremock::matchers::{method, path};
4878
4879 let mock_server = wiremock::MockServer::start().await;
4880 wiremock::Mock::given(method("POST"))
4881 .and(path("/revoke"))
4882 .respond_with(wiremock::ResponseTemplate::new(200))
4883 .expect(1)
4884 .mount(&mock_server)
4885 .await;
4886
4887 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4888 proxy.revocation_url = Some(format!("{}/revoke", mock_server.uri()));
4889
4890 let http = test_http_client();
4891 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4892 assert_eq!(resp.status(), 200);
4893 }
4894
4895 #[tokio::test]
4896 async fn revoke_returns_404_when_not_configured() {
4897 let proxy = proxy_cfg("https://example.invalid/token");
4898 let http = test_http_client();
4899 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4900 assert_eq!(resp.status(), 404);
4901 }
4902
/// Authorization-server metadata advertises the introspection and revocation
/// endpoints only when a proxy is configured, `expose_admin_endpoints` is
/// on, and the matching upstream URL is set.
#[test]
fn metadata_advertises_endpoints_only_when_configured() {
    const BASE_URL: &str = "https://mcp.local";
    let mut cfg = test_config("https://auth.test.local/jwks.json");

    // Stage 1: the config as returned by `test_config` advertises neither.
    let bare = authorization_server_metadata(BASE_URL, &cfg);
    assert!(bare.get("introspection_endpoint").is_none());
    assert!(bare.get("revocation_endpoint").is_none());

    // Stage 2: upstream URLs are set but the admin endpoints are not exposed.
    let mut proxy = proxy_cfg("https://upstream.local/token");
    proxy.introspection_url = Some("https://upstream.local/introspect".into());
    proxy.revocation_url = Some("https://upstream.local/revoke".into());
    cfg.proxy = Some(proxy);
    let hidden = authorization_server_metadata(BASE_URL, &cfg);
    assert!(
        hidden.get("introspection_endpoint").is_none(),
        "introspection must not be advertised when expose_admin_endpoints=false"
    );
    assert!(
        hidden.get("revocation_endpoint").is_none(),
        "revocation must not be advertised when expose_admin_endpoints=false"
    );

    // Stage 3: exposed, but only introspection has an upstream URL.
    if let Some(proxy) = cfg.proxy.as_mut() {
        proxy.expose_admin_endpoints = true;
        proxy.revocation_url = None;
    }
    let introspection_only = authorization_server_metadata(BASE_URL, &cfg);
    assert_eq!(
        introspection_only["introspection_endpoint"],
        serde_json::Value::String("https://mcp.local/introspect".into())
    );
    assert!(introspection_only.get("revocation_endpoint").is_none());

    // Stage 4: revocation URL restored, so it is advertised as well.
    if let Some(proxy) = cfg.proxy.as_mut() {
        proxy.revocation_url = Some("https://upstream.local/revoke".into());
    }
    let with_revocation = authorization_server_metadata(BASE_URL, &cfg);
    assert_eq!(
        with_revocation["revocation_endpoint"],
        serde_json::Value::String("https://mcp.local/revoke".into())
    );
}
4949
4950 fn https_cfg_with_tx(tx: TokenExchangeConfig) -> OAuthConfig {
4953 let mut cfg = validation_https_config();
4954 cfg.token_exchange = Some(tx);
4955 cfg
4956 }
4957
4958 fn tx_with(
4959 client_secret: Option<&str>,
4960 client_cert: Option<ClientCertConfig>,
4961 ) -> TokenExchangeConfig {
4962 TokenExchangeConfig::new(
4963 "https://idp.example.com/token".into(),
4964 "client".into(),
4965 client_secret.map(|s| secrecy::SecretString::new(s.into())),
4966 client_cert,
4967 "downstream".into(),
4968 )
4969 }
4970
/// Token exchange configured with neither a secret nor a certificate must be
/// rejected by `validate`.
#[test]
fn validate_rejects_token_exchange_without_client_auth() {
    let outcome = https_cfg_with_tx(tx_with(None, None)).validate();
    let msg = outcome
        .expect_err("token_exchange without client auth must be rejected")
        .to_string();
    assert!(
        msg.contains("requires client authentication"),
        "error must explain missing client auth; got {msg:?}"
    );
}
4983
/// Supplying both a client secret and a client certificate must be rejected
/// as mutually exclusive.
#[test]
fn validate_rejects_token_exchange_with_both_secret_and_cert() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let outcome = https_cfg_with_tx(tx_with(Some("s"), Some(client_cert))).validate();
    let msg = outcome
        .expect_err("client_secret + client_cert must be rejected")
        .to_string();
    let explains_exclusion = msg.contains("mutually") && msg.contains("exclusive");
    assert!(
        explains_exclusion,
        "error must explain mutual exclusion; got {msg:?}"
    );
}
5000
/// Without the `oauth-mtls-client` feature, a configured client certificate
/// must be rejected with an error that names the feature.
#[cfg(not(feature = "oauth-mtls-client"))]
#[test]
fn validate_rejects_client_cert_without_feature() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let outcome = https_cfg_with_tx(tx_with(None, Some(client_cert))).validate();
    let err = outcome.expect_err("client_cert without feature must be rejected");
    let names_feature = err.to_string().contains("oauth-mtls-client");
    assert!(
        names_feature,
        "error must reference the cargo feature; got {err}"
    );
}
5017
/// With the mTLS feature on, certificate and key paths that do not exist
/// must be rejected as unreadable.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_missing_client_cert_files() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let outcome = https_cfg_with_tx(tx_with(None, Some(client_cert))).validate();
    let err = outcome.expect_err("missing cert file must be rejected");
    let calls_out_unreadable = err.to_string().contains("unreadable");
    assert!(
        calls_out_unreadable,
        "error must call out unreadable file; got {err}"
    );
}
5034
/// With the mTLS feature on, readable files that are not valid PEM must be
/// rejected with a PEM parse error.
///
/// The temp files are removed before the result is inspected, so an
/// unexpected `Ok` (which makes `expect_err` panic) cannot leave them behind.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_malformed_client_cert_pem() {
    let dir = std::env::temp_dir();
    // The process id keeps the names unique across concurrent test binaries.
    let cert = dir.join(format!("rmcp-mtls-bad-cert-{}.pem", std::process::id()));
    let key = dir.join(format!("rmcp-mtls-bad-key-{}.pem", std::process::id()));
    std::fs::write(&cert, b"not a real PEM").expect("write tmp cert");
    std::fs::write(&key, b"not a real PEM either").expect("write tmp key");
    let cc = ClientCertConfig {
        cert_path: cert.clone(),
        key_path: key.clone(),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));

    // Capture the outcome first, clean up, and only then assert.
    let outcome = cfg.validate();
    let _ = std::fs::remove_file(&cert);
    let _ = std::fs::remove_file(&key);

    let err = outcome.expect_err("malformed PEM must be rejected");
    assert!(
        err.to_string().contains("PEM parse failed"),
        "error must call out PEM parse failure; got {err}"
    );
}
5056
/// Test helper: generates a self-signed certificate for `client.test`,
/// writes the certificate and key PEM files into the system temp directory,
/// and returns `(cert_path, key_path)`. The caller removes the files.
#[cfg(feature = "oauth-mtls-client")]
fn write_self_signed_pem() -> (PathBuf, PathBuf) {
    let generated =
        rcgen::generate_simple_self_signed(vec!["client.test".into()]).expect("rcgen");

    // Process id plus a random nonce keeps concurrent callers from colliding.
    let tmp = std::env::temp_dir();
    let pid = std::process::id();
    let nonce: u64 = rand::random();
    let cert_path = tmp.join(format!("rmcp-mtls-cert-{pid}-{nonce}.pem"));
    let key_path = tmp.join(format!("rmcp-mtls-key-{pid}-{nonce}.pem"));

    std::fs::write(&cert_path, generated.cert.pem()).expect("write cert");
    std::fs::write(&key_path, generated.signing_key.serialize_pem()).expect("write key");
    (cert_path, key_path)
}
5069
/// Installs the ring-based rustls crypto provider as the process default.
/// The result is discarded on purpose: installation fails when a provider
/// is already in place, which repeated test calls make routine.
#[cfg(feature = "oauth-mtls-client")]
fn install_test_crypto_provider() {
    rustls::crypto::ring::default_provider()
        .install_default()
        .ok();
}
5074
/// A freshly generated self-signed certificate and key must pass `validate`.
/// The temp files are removed before the result is asserted.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_accepts_well_formed_client_cert() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();
    let client_cert = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };

    let outcome = https_cfg_with_tx(tx_with(None, Some(client_cert))).validate();

    // Clean up first so a failing assertion does not leak the files.
    for leftover in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(leftover);
    }
    outcome.expect("well-formed cert+key must validate");
}
5090
/// `client_for` must hand back a different client for the cert-bearing
/// token-exchange config than for a secret-only one.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_returns_cached_mtls_client() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();
    let client_cert = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(client_cert)));
    let http = OauthHttpClient::with_config(&cfg).expect("build mtls client");

    let cert_tx = cfg.token_exchange.as_ref().expect("tx set");
    let mtls_client = http.client_for(cert_tx);
    let secret_only_tx = tx_with(Some("s"), None);
    let plain_client = http.client_for(&secret_only_tx);

    // Both lookups are done, so the files can go before the assertion runs.
    for leftover in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(leftover);
    }
    assert!(
        !std::ptr::eq(mtls_client, plain_client),
        "client_for must return distinct clients for cert vs no-cert configs"
    );
}
5112
/// Asking `client_for` about a certificate config the client was never built
/// with must return the same client as a secret-only config.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_falls_back_to_inner_when_cache_miss() {
    install_test_crypto_provider();
    let base_cfg = validation_https_config();
    let http = OauthHttpClient::with_config(&base_cfg).expect("build client");

    // These paths were not part of `base_cfg` when the client was built.
    let unknown_cert_tx = tx_with(
        None,
        Some(ClientCertConfig {
            cert_path: PathBuf::from("/cache/miss/cert.pem"),
            key_path: PathBuf::from("/cache/miss/key.pem"),
        }),
    );
    let secret_only_tx = tx_with(Some("s"), None);

    let fallback = http.client_for(&unknown_cert_tx);
    let inner = http.client_for(&secret_only_tx);
    assert!(
        std::ptr::eq(fallback, inner),
        "cache miss must fall back to inner client"
    );
}
5131}