1use std::{
17 collections::HashMap,
18 path::PathBuf,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering},
22 },
23 time::{Duration, Instant},
24};
25
26use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
27use serde::Deserialize;
28use tokio::{net::lookup_host, sync::RwLock};
29
30use crate::auth::{AuthIdentity, AuthMethod};
31
32fn evaluate_oauth_redirect(
58 attempt: &reqwest::redirect::Attempt<'_>,
59 allow_http: bool,
60 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
61) -> Result<(), String> {
62 let prev_https = attempt
63 .previous()
64 .last()
65 .is_some_and(|prev| prev.scheme() == "https");
66 let target_url = attempt.url();
67 let dest_scheme = target_url.scheme();
68 if dest_scheme != "https" {
69 if prev_https {
70 return Err("redirect downgrades https -> http".to_owned());
71 }
72 if !allow_http || dest_scheme != "http" {
73 return Err("redirect to non-HTTP(S) URL refused".to_owned());
74 }
75 }
76 if let Some(reason) = crate::ssrf::redirect_target_reason_with_allowlist(target_url, allowlist)
77 {
78 return Err(format!("redirect target forbidden: {reason}"));
79 }
80 if attempt.previous().len() >= 2 {
81 return Err("too many redirects (max 2)".to_owned());
82 }
83 Ok(())
84}
85
86#[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
100async fn screen_oauth_target_with_test_override(
101 url: &str,
102 allow_http: bool,
103 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
104 #[cfg(any(test, feature = "test-helpers"))] test_allow_loopback_ssrf: bool,
105) -> Result<(), crate::error::McpxError> {
106 let parsed = check_oauth_url("oauth target", url, allow_http)?;
107 #[cfg(any(test, feature = "test-helpers"))]
108 if test_allow_loopback_ssrf {
109 return Ok(());
110 }
111 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
112 return Err(crate::error::McpxError::Config(format!(
113 "OAuth target forbidden ({reason}): {url}"
114 )));
115 }
116
117 let host = parsed.host_str().ok_or_else(|| {
118 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
119 })?;
120 let port = parsed.port_or_known_default().ok_or_else(|| {
121 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
122 })?;
123
124 let addrs = lookup_host((host, port)).await.map_err(|error| {
125 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
126 })?;
127
128 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
129 let mut any_addr = false;
130 for addr in addrs {
131 any_addr = true;
132 let ip = addr.ip();
133 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
134 if reason == "cloud_metadata" {
137 return Err(crate::error::McpxError::Config(format!(
138 "OAuth target resolved to blocked IP ({reason}): {url}"
139 )));
140 }
141 if allowlist.is_empty() {
145 return Err(crate::error::McpxError::Config(format!(
146 "OAuth target resolved to blocked IP ({reason}): {url}"
147 )));
148 }
149 if host_allowed || allowlist.ip_allowed(ip) {
151 continue;
152 }
153 return Err(crate::error::McpxError::Config(format!(
154 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
155 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
156 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
157 URL: {url}"
158 )));
159 }
160 }
161 if !any_addr {
162 return Err(crate::error::McpxError::Config(format!(
163 "OAuth target DNS resolution returned no addresses: {url}"
164 )));
165 }
166
167 Ok(())
168}
169
170async fn screen_oauth_target(
171 url: &str,
172 allow_http: bool,
173 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
174) -> Result<(), crate::error::McpxError> {
175 #[cfg(any(test, feature = "test-helpers"))]
176 {
177 screen_oauth_target_with_test_override(url, allow_http, allowlist, false).await
178 }
179 #[cfg(not(any(test, feature = "test-helpers")))]
180 {
181 let parsed = check_oauth_url("oauth target", url, allow_http)?;
182 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
183 return Err(crate::error::McpxError::Config(format!(
184 "OAuth target forbidden ({reason}): {url}"
185 )));
186 }
187
188 let host = parsed.host_str().ok_or_else(|| {
189 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
190 })?;
191 let port = parsed.port_or_known_default().ok_or_else(|| {
192 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
193 })?;
194
195 let addrs = lookup_host((host, port)).await.map_err(|error| {
196 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
197 })?;
198
199 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
200 let mut any_addr = false;
201 for addr in addrs {
202 any_addr = true;
203 let ip = addr.ip();
204 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
205 if reason == "cloud_metadata" {
206 return Err(crate::error::McpxError::Config(format!(
207 "OAuth target resolved to blocked IP ({reason}): {url}"
208 )));
209 }
210 if allowlist.is_empty() {
211 return Err(crate::error::McpxError::Config(format!(
212 "OAuth target resolved to blocked IP ({reason}): {url}"
213 )));
214 }
215 if host_allowed || allowlist.ip_allowed(ip) {
216 continue;
217 }
218 return Err(crate::error::McpxError::Config(format!(
219 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
220 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
221 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
222 URL: {url}"
223 )));
224 }
225 }
226 if !any_addr {
227 return Err(crate::error::McpxError::Config(format!(
228 "OAuth target DNS resolution returned no addresses: {url}"
229 )));
230 }
231
232 Ok(())
233 }
234}
235
/// HTTP client for outbound OAuth requests: a `reqwest::Client` bundled with
/// the screening settings applied before a request is sent.
#[derive(Clone)]
pub struct OauthHttpClient {
    /// Default client used when no mTLS client applies.
    inner: reqwest::Client,
    /// Whether plain `http` target URLs are accepted during screening.
    allow_http: bool,
    /// Compiled operator SSRF allowlist consulted during screening.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Clients carrying a TLS client identity, keyed by the certificate/key
    /// path pair they were built from.
    #[cfg(feature = "oauth-mtls-client")]
    mtls_clients: Arc<HashMap<MtlsClientKey, reqwest::Client>>,
    /// Test-only switch; when set, pre-request screening stops after URL
    /// validation.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
299
/// Lookup key for a prebuilt mTLS client: the pair of file paths its TLS
/// identity was loaded from.
#[cfg(feature = "oauth-mtls-client")]
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct MtlsClientKey {
    /// Path of the client certificate file.
    cert_path: PathBuf,
    /// Path of the private key file.
    key_path: PathBuf,
}
309
310impl OauthHttpClient {
311 pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
329 Self::build(Some(config))
330 }
331
332 #[deprecated(
355 since = "1.2.1",
356 note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
357 )]
358 pub fn new() -> Result<Self, crate::error::McpxError> {
359 Self::build(None)
360 }
361
362 fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
365 let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);
366
367 let allowlist = match config.and_then(|c| c.ssrf_allowlist.as_ref()) {
372 Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
373 crate::error::McpxError::Startup(format!("oauth http client: {e}"))
374 })?),
375 None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
376 };
377
378 let redirect_allowlist = Arc::clone(&allowlist);
381
382 #[cfg(any(test, feature = "test-helpers"))]
386 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
387 Arc::new(AtomicBool::new(false));
388 #[cfg(not(any(test, feature = "test-helpers")))]
389 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();
390
391 let resolver: Arc<dyn reqwest::dns::Resolve> =
392 Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
393 Arc::clone(&allowlist),
394 #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
398 test_bypass.clone(),
399 ));
400
401 let mut builder = reqwest::Client::builder()
402 .no_proxy()
406 .dns_resolver(Arc::clone(&resolver))
407 .connect_timeout(Duration::from_secs(10))
408 .timeout(Duration::from_secs(30))
409 .redirect(reqwest::redirect::Policy::custom(move |attempt| {
410 match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
420 Ok(()) => attempt.follow(),
421 Err(reason) => {
422 tracing::warn!(
423 reason = %reason,
424 target = %attempt.url(),
425 "oauth redirect rejected"
426 );
427 attempt.error(reason)
428 }
429 }
430 }));
431
432 if let Some(cfg) = config
433 && let Some(ref ca_path) = cfg.ca_cert_path
434 {
435 let pem = std::fs::read(ca_path).map_err(|e| {
440 crate::error::McpxError::Startup(format!(
441 "oauth http client: read ca_cert_path {}: {e}",
442 ca_path.display()
443 ))
444 })?;
445 let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
446 crate::error::McpxError::Startup(format!(
447 "oauth http client: parse ca_cert_path {}: {e}",
448 ca_path.display()
449 ))
450 })?;
451 builder = builder.add_root_certificate(cert);
452 }
453
454 let inner = builder.build().map_err(|e| {
455 crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
456 })?;
457
458 #[cfg(feature = "oauth-mtls-client")]
459 let mtls_clients = build_mtls_clients(config, &allowlist, &test_bypass)?;
460
461 Ok(Self {
462 inner,
463 allow_http,
464 allowlist,
465 #[cfg(feature = "oauth-mtls-client")]
466 mtls_clients,
467 #[cfg(any(test, feature = "test-helpers"))]
468 test_allow_loopback_ssrf: test_bypass,
469 })
470 }
471
472 async fn send_screened(
473 &self,
474 url: &str,
475 request: reqwest::RequestBuilder,
476 ) -> Result<reqwest::Response, crate::error::McpxError> {
477 #[cfg(any(test, feature = "test-helpers"))]
478 if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
479 screen_oauth_target_with_test_override(url, self.allow_http, &self.allowlist, true)
480 .await?;
481 } else {
482 screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
483 }
484 #[cfg(not(any(test, feature = "test-helpers")))]
485 screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
486 request.send().await.map_err(|error| {
487 crate::error::McpxError::Config(format!("oauth request {url}: {error}"))
488 })
489 }
490
491 #[cfg(any(test, feature = "test-helpers"))]
496 #[doc(hidden)]
497 #[must_use]
498 pub fn __test_allow_loopback_ssrf(self) -> Self {
499 self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
502 self
503 }
504
505 #[doc(hidden)]
511 pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
512 self.inner.get(url).send().await
513 }
514
515 #[cfg(any(test, feature = "test-helpers"))]
521 #[doc(hidden)]
522 #[must_use]
523 pub fn __test_inner_client(&self) -> &reqwest::Client {
524 &self.inner
525 }
526
527 #[cfg(feature = "oauth-mtls-client")]
534 fn client_for(&self, cfg: &TokenExchangeConfig) -> &reqwest::Client {
535 if let Some(cc) = &cfg.client_cert {
536 let key = MtlsClientKey {
537 cert_path: cc.cert_path.clone(),
538 key_path: cc.key_path.clone(),
539 };
540 if let Some(client) = self.mtls_clients.get(&key) {
541 return client;
542 }
543 }
544 &self.inner
545 }
546
547 #[cfg(not(feature = "oauth-mtls-client"))]
548 fn client_for(&self, _cfg: &TokenExchangeConfig) -> &reqwest::Client {
549 &self.inner
550 }
551}
552
553impl std::fmt::Debug for OauthHttpClient {
554 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 f.debug_struct("OauthHttpClient").finish_non_exhaustive()
556 }
557}
558
/// Raw (uncompiled) operator allowlist for OAuth SSRF screening, as read from
/// configuration. `compile_oauth_ssrf_allowlist` validates and compiles it.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct OAuthSsrfAllowlist {
    /// Bare DNS hostnames; entries with a scheme, port, path, userinfo, query
    /// or fragment, and literal IPs, are rejected at compile time.
    #[serde(default)]
    pub hosts: Vec<String>,
    /// CIDR entries, each parsed with `crate::ssrf::CidrEntry::parse`.
    #[serde(default)]
    pub cidrs: Vec<String>,
}
635
636fn compile_oauth_ssrf_allowlist(
643 raw: &OAuthSsrfAllowlist,
644) -> Result<crate::ssrf::CompiledSsrfAllowlist, String> {
645 let mut hosts: Vec<String> = Vec::with_capacity(raw.hosts.len());
646 for (idx, entry) in raw.hosts.iter().enumerate() {
647 let trimmed = entry.trim();
648 if trimmed.is_empty() {
649 return Err(format!("oauth.ssrf_allowlist.hosts[{idx}]: empty entry"));
650 }
651 if trimmed.contains([':', '/', '@', '?', '#']) {
655 return Err(format!(
656 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: must be a bare DNS hostname \
657 (no scheme, port, path, userinfo, query, or fragment)"
658 ));
659 }
660 match url::Host::parse(trimmed) {
661 Ok(url::Host::Domain(_)) => {}
662 Ok(url::Host::Ipv4(_) | url::Host::Ipv6(_)) => {
663 return Err(format!(
664 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: literal IPs are forbidden \
665 here -- list them via oauth.ssrf_allowlist.cidrs instead"
666 ));
667 }
668 Err(e) => {
669 return Err(format!(
670 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: invalid hostname: {e}"
671 ));
672 }
673 }
674 hosts.push(trimmed.to_ascii_lowercase());
675 }
676 hosts.sort();
677 hosts.dedup();
678
679 let mut cidrs = Vec::with_capacity(raw.cidrs.len());
680 for (idx, entry) in raw.cidrs.iter().enumerate() {
681 let parsed = crate::ssrf::CidrEntry::parse(entry)
682 .map_err(|e| format!("oauth.ssrf_allowlist.cidrs[{idx}]: {e}"))?;
683 cidrs.push(parsed);
684 }
685
686 Ok(crate::ssrf::CompiledSsrfAllowlist::new(hosts, cidrs))
687}
688
/// OAuth settings deserialized from configuration.
///
/// URL-valued fields, the SSRF allowlist and `jwks_cache_ttl` are checked by
/// `OAuthConfig::validate`.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
    /// Issuer URL.
    pub issuer: String,
    /// Expected token audience. NOTE(review): enforcement lives outside this
    /// struct; see the audience validation mode fields.
    pub audience: String,
    /// URL of the JWKS document.
    pub jwks_uri: String,
    /// Scope-to-role mappings; empty when omitted.
    #[serde(default)]
    pub scopes: Vec<ScopeMapping>,
    /// Optional claim name used for role lookup. NOTE(review): claim path
    /// syntax is not visible here -- confirm against the token validator.
    pub role_claim: Option<String>,
    /// Claim-value-to-role mappings; empty when omitted.
    #[serde(default)]
    pub role_mappings: Vec<RoleMapping>,
    /// JWKS cache lifetime as a humantime duration string (default `"10m"`).
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl: String,
    /// Optional OAuth proxy settings.
    pub proxy: Option<OAuthProxyConfig>,
    /// Optional token-exchange settings.
    pub token_exchange: Option<TokenExchangeConfig>,
    /// Optional PEM file added as an extra root certificate to the OAuth HTTP
    /// clients.
    #[serde(default)]
    pub ca_cert_path: Option<PathBuf>,
    /// When `true`, `http` OAuth URLs are accepted; otherwise only `https`.
    #[serde(default)]
    pub allow_http_oauth_urls: bool,
    /// Optional operator allowlist for SSRF screening.
    #[serde(default)]
    pub ssrf_allowlist: Option<OAuthSsrfAllowlist>,
    /// Upper bound on JWKS keys (default 256). NOTE(review): enforcement is
    /// outside this chunk.
    #[serde(default = "default_max_jwks_keys")]
    pub max_jwks_keys: usize,
    /// Legacy strictness flag; consulted only when `audience_validation_mode`
    /// is `None`.
    #[serde(default)]
    #[deprecated(
        since = "1.7.0",
        note = "use `audience_validation_mode` instead; this field is consulted only when `audience_validation_mode` is None"
    )]
    pub strict_audience_validation: bool,
    /// Explicit audience validation mode; takes precedence over the legacy flag.
    #[serde(default)]
    pub audience_validation_mode: Option<AudienceValidationMode>,
    /// Upper bound on the JWKS response size in bytes (default 1 MiB).
    /// NOTE(review): enforcement is outside this chunk.
    #[serde(default = "default_jwks_max_bytes")]
    pub jwks_max_response_bytes: u64,
}
799
/// Serde default for `OAuthConfig::jwks_cache_ttl`: ten minutes, written as a
/// humantime duration string.
fn default_jwks_cache_ttl() -> String {
    String::from("10m")
}
803
/// Serde default for `OAuthConfig::max_jwks_keys`.
const fn default_max_jwks_keys() -> usize {
    const DEFAULT_MAX_KEYS: usize = 256;
    DEFAULT_MAX_KEYS
}
807
/// Serde default for `OAuthConfig::jwks_max_response_bytes`: one mebibyte.
const fn default_jwks_max_bytes() -> u64 {
    const ONE_MIB: u64 = 1 << 20;
    ONE_MIB
}
811
/// How audience validation behaves. Deserialized from snake_case names
/// (`permissive`, `warn`, `strict`); the derived default is `Warn`.
///
/// NOTE(review): the exact runtime behaviour of each mode is implemented
/// outside this chunk -- confirm against the token validator.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AudienceValidationMode {
    /// Presumably the most lenient mode.
    Permissive,
    /// Default mode; also what the legacy flag resolves to when it is `false`.
    #[default]
    Warn,
    /// What the legacy `strict_audience_validation = true` flag resolves to.
    Strict,
}
846
847impl Default for OAuthConfig {
848 fn default() -> Self {
849 Self {
850 issuer: String::new(),
851 audience: String::new(),
852 jwks_uri: String::new(),
853 scopes: Vec::new(),
854 role_claim: None,
855 role_mappings: Vec::new(),
856 jwks_cache_ttl: default_jwks_cache_ttl(),
857 proxy: None,
858 token_exchange: None,
859 ca_cert_path: None,
860 allow_http_oauth_urls: false,
861 max_jwks_keys: default_max_jwks_keys(),
862 #[allow(
863 deprecated,
864 reason = "default-construct deprecated field for backward compat"
865 )]
866 strict_audience_validation: false,
867 audience_validation_mode: None,
868 jwks_max_response_bytes: default_jwks_max_bytes(),
869 ssrf_allowlist: None,
870 }
871 }
872}
873
874impl OAuthConfig {
875 #[must_use]
881 pub fn effective_audience_validation_mode(&self) -> AudienceValidationMode {
882 if let Some(mode) = self.audience_validation_mode {
883 return mode;
884 }
885 #[allow(deprecated, reason = "intentional: legacy flag resolution path")]
886 if self.strict_audience_validation {
887 AudienceValidationMode::Strict
888 } else {
889 AudienceValidationMode::Warn
890 }
891 }
892
893 pub fn builder(
899 issuer: impl Into<String>,
900 audience: impl Into<String>,
901 jwks_uri: impl Into<String>,
902 ) -> OAuthConfigBuilder {
903 OAuthConfigBuilder {
904 inner: Self {
905 issuer: issuer.into(),
906 audience: audience.into(),
907 jwks_uri: jwks_uri.into(),
908 ..Self::default()
909 },
910 }
911 }
912
913 pub fn validate(&self) -> Result<(), crate::error::McpxError> {
929 let allow_http = self.allow_http_oauth_urls;
930 let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
931 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
932 return Err(crate::error::McpxError::Config(format!(
933 "oauth.issuer forbidden ({reason})"
934 )));
935 }
936 let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
937 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
938 return Err(crate::error::McpxError::Config(format!(
939 "oauth.jwks_uri forbidden ({reason})"
940 )));
941 }
942 if let Some(proxy) = &self.proxy {
943 let url = check_oauth_url(
944 "oauth.proxy.authorize_url",
945 &proxy.authorize_url,
946 allow_http,
947 )?;
948 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
949 return Err(crate::error::McpxError::Config(format!(
950 "oauth.proxy.authorize_url forbidden ({reason})"
951 )));
952 }
953 let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
954 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
955 return Err(crate::error::McpxError::Config(format!(
956 "oauth.proxy.token_url forbidden ({reason})"
957 )));
958 }
959 if let Some(url) = &proxy.introspection_url {
960 let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
961 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
962 return Err(crate::error::McpxError::Config(format!(
963 "oauth.proxy.introspection_url forbidden ({reason})"
964 )));
965 }
966 }
967 if let Some(url) = &proxy.revocation_url {
968 let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
969 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
970 return Err(crate::error::McpxError::Config(format!(
971 "oauth.proxy.revocation_url forbidden ({reason})"
972 )));
973 }
974 }
975 if proxy.expose_admin_endpoints
982 && !proxy.require_auth_on_admin_endpoints
983 && !proxy.allow_unauthenticated_admin_endpoints
984 {
985 return Err(crate::error::McpxError::Config(
986 "oauth.proxy: expose_admin_endpoints = true requires \
987 require_auth_on_admin_endpoints = true (recommended) \
988 or allow_unauthenticated_admin_endpoints = true \
989 (explicit opt-out, only safe behind an authenticated \
990 reverse proxy)"
991 .into(),
992 ));
993 }
994 }
995 if let Some(tx) = &self.token_exchange {
996 let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
997 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
998 return Err(crate::error::McpxError::Config(format!(
999 "oauth.token_exchange.token_url forbidden ({reason})"
1000 )));
1001 }
1002 validate_token_exchange_client_auth(tx)?;
1005 }
1006 if let Some(raw) = &self.ssrf_allowlist {
1010 let compiled = compile_oauth_ssrf_allowlist(raw).map_err(|e| {
1011 crate::error::McpxError::Config(format!("oauth.ssrf_allowlist: {e}"))
1012 })?;
1013 if !compiled.is_empty() {
1014 tracing::warn!(
1015 host_count = compiled.host_count(),
1016 cidr_count = compiled.cidr_count(),
1017 "oauth.ssrf_allowlist is configured: private/loopback OAuth/JWKS targets \
1018 are now reachable. Cloud-metadata addresses remain blocked. \
1019 See SECURITY.md \"Operator allowlist\"."
1020 );
1021 }
1022 }
1023 humantime::parse_duration(&self.jwks_cache_ttl).map_err(|e| {
1026 crate::error::McpxError::Config(format!(
1027 "oauth.jwks_cache_ttl {:?} is not a valid humantime duration (e.g. \"10m\", \"1h30m\"): {e}",
1028 self.jwks_cache_ttl
1029 ))
1030 })?;
1031 Ok(())
1032 }
1033}
1034
1035fn validate_token_exchange_client_auth(
1041 tx: &TokenExchangeConfig,
1042) -> Result<(), crate::error::McpxError> {
1043 match (&tx.client_cert, tx.client_secret.is_some()) {
1044 (Some(_), true) => Err(crate::error::McpxError::Config(
1045 "oauth.token_exchange: client_cert and client_secret are mutually \
1046 exclusive (RFC 8705 ยง2). Set exactly one."
1047 .into(),
1048 )),
1049 (None, false) => Err(crate::error::McpxError::Config(
1050 "oauth.token_exchange: token exchange requires client authentication. \
1051 Set either client_secret (RFC 6749 ยง2.3.1) or client_cert (RFC 8705 ยง2)."
1052 .into(),
1053 )),
1054 (Some(cc), false) => validate_client_cert_config(cc),
1055 (None, true) => Ok(()),
1056 }
1057}
1058
1059fn validate_client_cert_config(cc: &ClientCertConfig) -> Result<(), crate::error::McpxError> {
1072 #[cfg(not(feature = "oauth-mtls-client"))]
1073 {
1074 let _ = cc;
1075 Err(crate::error::McpxError::Config(
1076 "oauth.token_exchange.client_cert requires the `oauth-mtls-client` cargo feature; \
1077 rebuild rmcp-server-kit with --features oauth-mtls-client (or have your \
1078 application crate enable it via `rmcp-server-kit/oauth-mtls-client`), or remove \
1079 the field"
1080 .into(),
1081 ))
1082 }
1083 #[cfg(feature = "oauth-mtls-client")]
1084 {
1085 let cert_bytes = std::fs::read(&cc.cert_path).map_err(|e| {
1086 tracing::warn!(error = %e, path = %cc.cert_path.display(), "client cert read failed");
1087 crate::error::McpxError::Config(format!(
1088 "oauth.token_exchange.client_cert.cert_path unreadable: {}",
1089 cc.cert_path.display()
1090 ))
1091 })?;
1092 let key_bytes = std::fs::read(&cc.key_path).map_err(|e| {
1093 tracing::warn!(error = %e, path = %cc.key_path.display(), "client cert key read failed");
1094 crate::error::McpxError::Config(format!(
1095 "oauth.token_exchange.client_cert.key_path unreadable: {}",
1096 cc.key_path.display()
1097 ))
1098 })?;
1099 let mut combined = Vec::with_capacity(cert_bytes.len() + 1 + key_bytes.len());
1100 combined.extend_from_slice(&cert_bytes);
1101 if !cert_bytes.ends_with(b"\n") {
1102 combined.push(b'\n');
1103 }
1104 combined.extend_from_slice(&key_bytes);
1105 let _identity = reqwest::Identity::from_pem(&combined).map_err(|e| {
1106 tracing::warn!(
1107 error = %e,
1108 cert_path = %cc.cert_path.display(),
1109 key_path = %cc.key_path.display(),
1110 "client cert PEM parse failed"
1111 );
1112 crate::error::McpxError::Config(format!(
1113 "oauth.token_exchange.client_cert: PEM parse failed (cert={}, key={})",
1114 cc.cert_path.display(),
1115 cc.key_path.display()
1116 ))
1117 })?;
1118 Ok(())
1119 }
1120}
1121
/// Builds the map of mTLS-capable HTTP clients.
///
/// The map is empty unless `config` carries a token-exchange section with a
/// client certificate. In that case one client is built with the TLS identity
/// loaded from the certificate and key files, the SSRF-screening resolver,
/// redirects disabled, and the optional extra root certificate; it is stored
/// under the certificate/key path pair.
///
/// # Errors
///
/// Returns `McpxError::Startup` when a file cannot be read, the PEM does not
/// parse, or `reqwest` fails to build the client.
#[cfg(feature = "oauth-mtls-client")]
fn build_mtls_clients(
    config: Option<&OAuthConfig>,
    allowlist: &Arc<crate::ssrf::CompiledSsrfAllowlist>,
    test_bypass: &crate::ssrf_resolver::TestLoopbackBypass,
) -> Result<Arc<HashMap<MtlsClientKey, reqwest::Client>>, crate::error::McpxError> {
    use crate::error::McpxError;

    let mut clients: HashMap<MtlsClientKey, reqwest::Client> = HashMap::new();

    // Nothing to build without config -> token_exchange -> client_cert.
    let mtls_source = config.and_then(|cfg| {
        let cc = cfg.token_exchange.as_ref()?.client_cert.as_ref()?;
        Some((cfg, cc))
    });
    let Some((cfg, cc)) = mtls_source else {
        return Ok(Arc::new(clients));
    };

    let cert_bytes = std::fs::read(&cc.cert_path).map_err(|e| {
        McpxError::Startup(format!(
            "oauth http client mTLS: read cert_path {}: {e}",
            cc.cert_path.display()
        ))
    })?;
    let key_bytes = std::fs::read(&cc.key_path).map_err(|e| {
        McpxError::Startup(format!(
            "oauth http client mTLS: read key_path {}: {e}",
            cc.key_path.display()
        ))
    })?;

    // Certificate, a newline if the certificate lacks one, then the key.
    let separator: &[u8] = if cert_bytes.ends_with(b"\n") {
        &b""[..]
    } else {
        &b"\n"[..]
    };
    let combined = [cert_bytes.as_slice(), separator, key_bytes.as_slice()].concat();

    let identity = reqwest::Identity::from_pem(&combined).map_err(|e| {
        McpxError::Startup(format!(
            "oauth http client mTLS: PEM parse (cert={}, key={}): {e}",
            cc.cert_path.display(),
            cc.key_path.display()
        ))
    })?;

    let resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(allowlist),
            #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
            test_bypass.clone(),
        ));

    let mut builder = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&resolver))
        .connect_timeout(Duration::from_secs(10))
        .timeout(Duration::from_secs(30))
        .redirect(reqwest::redirect::Policy::none())
        .identity(identity);

    if let Some(ca_path) = cfg.ca_cert_path.as_ref() {
        let pem = std::fs::read(ca_path).map_err(|e| {
            McpxError::Startup(format!(
                "oauth http client mTLS: read ca_cert_path {}: {e}",
                ca_path.display()
            ))
        })?;
        let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
            McpxError::Startup(format!(
                "oauth http client mTLS: parse ca_cert_path {}: {e}",
                ca_path.display()
            ))
        })?;
        builder = builder.add_root_certificate(cert);
    }

    let client = builder
        .build()
        .map_err(|e| McpxError::Startup(format!("oauth http client mTLS init: {e}")))?;

    let key = MtlsClientKey {
        cert_path: cc.cert_path.clone(),
        key_path: cc.key_path.clone(),
    };
    clients.insert(key, client);
    Ok(Arc::new(clients))
}
1220
1221fn check_oauth_url(
1228 field: &str,
1229 raw: &str,
1230 allow_http: bool,
1231) -> Result<url::Url, crate::error::McpxError> {
1232 let parsed = url::Url::parse(raw).map_err(|e| {
1233 crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
1234 })?;
1235 if !parsed.username().is_empty() || parsed.password().is_some() {
1236 return Err(crate::error::McpxError::Config(format!(
1237 "{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
1238 )));
1239 }
1240 match parsed.scheme() {
1241 "https" => Ok(parsed),
1242 "http" if allow_http => Ok(parsed),
1243 "http" => Err(crate::error::McpxError::Config(format!(
1244 "{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
1245 to override - strongly discouraged in production)"
1246 ))),
1247 other => Err(crate::error::McpxError::Config(format!(
1248 "{field}: must use https scheme (got {other:?})"
1249 ))),
1250 }
1251}
1252
/// Builder for `OAuthConfig`, created by `OAuthConfig::builder`.
///
/// Each setter consumes and returns the builder; `build` yields the finished
/// configuration.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthConfigBuilder {
    /// Configuration being assembled.
    inner: OAuthConfig,
}
1263
impl OAuthConfigBuilder {
    /// Replaces the whole scope-to-role mapping list.
    pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
        self.inner.scopes = scopes;
        self
    }

    /// Appends a single scope-to-role mapping.
    pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.scopes.push(ScopeMapping {
            scope: scope.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the claim name used for role lookup.
    pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
        self.inner.role_claim = Some(claim.into());
        self
    }

    /// Replaces the whole claim-value-to-role mapping list.
    pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
        self.inner.role_mappings = mappings;
        self
    }

    /// Appends a single claim-value-to-role mapping.
    pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
        self.inner.role_mappings.push(RoleMapping {
            claim_value: claim_value.into(),
            role: role.into(),
        });
        self
    }

    /// Sets the JWKS cache lifetime (a humantime duration string; parsed by
    /// `OAuthConfig::validate`, not here).
    pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
        self.inner.jwks_cache_ttl = ttl.into();
        self
    }

    /// Sets the OAuth proxy section.
    pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
        self.inner.proxy = Some(proxy);
        self
    }

    /// Sets the token-exchange section.
    pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
        self.inner.token_exchange = Some(token_exchange);
        self
    }

    /// Sets the path of the extra root certificate (PEM).
    pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.inner.ca_cert_path = Some(path.into());
        self
    }

    /// Sets whether `http` OAuth URLs are accepted.
    pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
        self.inner.allow_http_oauth_urls = allow;
        self
    }

    /// Sets the legacy strictness flag and clears `audience_validation_mode`,
    /// so the flag is what `effective_audience_validation_mode` consults.
    #[deprecated(since = "1.7.0", note = "use `audience_validation_mode` instead")]
    pub const fn strict_audience_validation(mut self, strict: bool) -> Self {
        #[allow(
            deprecated,
            reason = "intentional: deprecated builder forwards to deprecated field"
        )]
        {
            self.inner.strict_audience_validation = strict;
        }
        self.inner.audience_validation_mode = None;
        self
    }

    /// Sets the explicit audience validation mode, which takes precedence over
    /// the legacy flag.
    pub const fn audience_validation_mode(mut self, mode: AudienceValidationMode) -> Self {
        self.inner.audience_validation_mode = Some(mode);
        self
    }

    /// Sets the upper bound on the JWKS response size in bytes.
    pub const fn jwks_max_response_bytes(mut self, bytes: u64) -> Self {
        self.inner.jwks_max_response_bytes = bytes;
        self
    }

    /// Sets the operator SSRF allowlist (validated later, not here).
    pub fn ssrf_allowlist(mut self, allowlist: OAuthSsrfAllowlist) -> Self {
        self.inner.ssrf_allowlist = Some(allowlist);
        self
    }

    /// Returns the assembled configuration. No validation runs here; call
    /// `OAuthConfig::validate` separately.
    #[must_use]
    pub fn build(self) -> OAuthConfig {
        self.inner
    }
}
1399
/// One scope-to-role mapping entry.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ScopeMapping {
    /// OAuth scope name.
    pub scope: String,
    /// Role associated with that scope.
    pub role: String,
}
1409
/// One claim-value-to-role mapping entry.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleMapping {
    /// Value expected in the role claim.
    pub claim_value: String,
    /// Role associated with that claim value.
    pub role: String,
}
1421
/// Token-exchange settings.
///
/// Exactly one of `client_secret` and `client_cert` must be set; this is
/// enforced by `validate_token_exchange_client_auth`.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenExchangeConfig {
    /// Token endpoint URL.
    pub token_url: String,
    /// Client identifier.
    pub client_id: String,
    /// Client secret; mutually exclusive with `client_cert`.
    pub client_secret: Option<secrecy::SecretString>,
    /// Client certificate for mTLS; mutually exclusive with `client_secret`.
    pub client_cert: Option<ClientCertConfig>,
    /// Audience value. NOTE(review): presumably the audience requested for the
    /// exchanged token -- confirm against the exchange request code.
    pub audience: String,
}
1459
impl TokenExchangeConfig {
    /// Creates a token-exchange configuration from its five fields.
    ///
    /// No validation happens here; in particular the "exactly one of
    /// `client_secret` / `client_cert`" rule is checked during config
    /// validation.
    #[must_use]
    pub fn new(
        token_url: String,
        client_id: String,
        client_secret: Option<secrecy::SecretString>,
        client_cert: Option<ClientCertConfig>,
        audience: String,
    ) -> Self {
        Self {
            token_url,
            client_id,
            client_secret,
            client_cert,
            audience,
        }
    }
}
1479
/// File locations of a TLS client identity.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ClientCertConfig {
    /// Path of the PEM certificate file.
    pub cert_path: PathBuf,
    /// Path of the PEM private key file.
    pub key_path: PathBuf,
}
1494
impl ClientCertConfig {
    /// Creates a client-certificate configuration from the two file paths.
    /// The files are not touched here.
    #[must_use]
    pub fn new(cert_path: PathBuf, key_path: PathBuf) -> Self {
        Self {
            cert_path,
            key_path,
        }
    }
}
1507
/// Deserialized token returned by a token exchange.
///
/// NOTE(review): field names look like the RFC 8693 token-exchange response --
/// confirm against the code that performs the exchange.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ExchangedToken {
    /// The issued access token.
    pub access_token: String,
    /// Lifetime of the token, when the server reports one (presumably seconds).
    pub expires_in: Option<u64>,
    /// Type identifier of the issued token, when the server reports one.
    pub issued_token_type: Option<String>,
}
1520
/// OAuth proxy settings.
///
/// `OAuthConfig::validate` rejects `expose_admin_endpoints = true` unless
/// `require_auth_on_admin_endpoints` or
/// `allow_unauthenticated_admin_endpoints` is also `true`.
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct OAuthProxyConfig {
    /// Upstream authorization endpoint URL.
    pub authorize_url: String,
    /// Upstream token endpoint URL.
    pub token_url: String,
    /// Client identifier.
    pub client_id: String,
    /// Optional client secret.
    pub client_secret: Option<secrecy::SecretString>,
    /// Optional introspection endpoint URL.
    #[serde(default)]
    pub introspection_url: Option<String>,
    /// Optional revocation endpoint URL.
    #[serde(default)]
    pub revocation_url: Option<String>,
    /// Whether admin endpoints are exposed. NOTE(review): which endpoints
    /// count as "admin" is defined outside this chunk.
    #[serde(default)]
    pub expose_admin_endpoints: bool,
    /// Whether exposed admin endpoints require authentication (recommended).
    #[serde(default)]
    pub require_auth_on_admin_endpoints: bool,
    /// Explicit opt-out allowing unauthenticated admin endpoints.
    #[serde(default)]
    pub allow_unauthenticated_admin_endpoints: bool,
}
1583
1584impl OAuthProxyConfig {
1585 pub fn builder(
1593 authorize_url: impl Into<String>,
1594 token_url: impl Into<String>,
1595 client_id: impl Into<String>,
1596 ) -> OAuthProxyConfigBuilder {
1597 OAuthProxyConfigBuilder {
1598 inner: Self {
1599 authorize_url: authorize_url.into(),
1600 token_url: token_url.into(),
1601 client_id: client_id.into(),
1602 ..Self::default()
1603 },
1604 }
1605 }
1606}
1607
/// Builder for `OAuthProxyConfig`, created by `OAuthProxyConfig::builder`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthProxyConfigBuilder {
    /// Configuration being assembled.
    inner: OAuthProxyConfig,
}
1618
impl OAuthProxyConfigBuilder {
    /// Sets the client secret.
    pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
        self.inner.client_secret = Some(secret);
        self
    }

    /// Sets the introspection endpoint URL.
    pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
        self.inner.introspection_url = Some(url.into());
        self
    }

    /// Sets the revocation endpoint URL.
    pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
        self.inner.revocation_url = Some(url.into());
        self
    }

    /// Sets whether admin endpoints are exposed.
    pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
        self.inner.expose_admin_endpoints = expose;
        self
    }

    /// Sets whether exposed admin endpoints require authentication.
    pub const fn require_auth_on_admin_endpoints(mut self, require: bool) -> Self {
        self.inner.require_auth_on_admin_endpoints = require;
        self
    }

    /// Sets the explicit opt-out that allows unauthenticated admin endpoints.
    pub const fn allow_unauthenticated_admin_endpoints(mut self, allow: bool) -> Self {
        self.inner.allow_unauthenticated_admin_endpoints = allow;
        self
    }

    /// Returns the assembled proxy configuration without validating it.
    #[must_use]
    pub fn build(self) -> OAuthProxyConfig {
        self.inner
    }
}
1675
/// Result shape of `build_key_cache`: keys indexed by their `kid`, followed
/// by the keys that carry no `kid`. Every key is paired with the algorithm
/// its JWK declares.
type JwksKeyCache = (
    HashMap<String, (Algorithm, DecodingKey)>,
    Vec<(Algorithm, DecodingKey)>,
);
1687
/// One snapshot of the keys obtained from a JWKS fetch.
struct CachedKeys {
    /// Keys that have a `kid`, indexed by it.
    keys: HashMap<String, (Algorithm, DecodingKey)>,
    /// Keys published without a `kid`; matched by algorithm only.
    unnamed_keys: Vec<(Algorithm, DecodingKey)>,
    /// When this snapshot was stored.
    fetched_at: Instant,
    /// How long after `fetched_at` the snapshot counts as fresh.
    ttl: Duration,
}
1696
1697impl CachedKeys {
1698 fn is_expired(&self) -> bool {
1699 self.fetched_at.elapsed() >= self.ttl
1700 }
1701}
1702
/// JWT validator backed by a cached copy of a remote JWKS document.
///
/// Built by [`JwksCache::new`]. Tokens are checked with `validate_token` /
/// `validate_token_with_reason`; keys are fetched on demand from `jwks_uri`.
#[allow(
    missing_debug_implementations,
    reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
)]
#[non_exhaustive]
pub struct JwksCache {
    /// URL the JWKS document is fetched from.
    jwks_uri: String,
    /// Freshness window copied into every stored key snapshot.
    ttl: Duration,
    /// Maximum number of keys accepted from one JWKS document.
    max_jwks_keys: usize,
    /// Maximum number of body bytes read from the JWKS response.
    max_response_bytes: u64,
    /// Whether plain `http` targets are tolerated by URL screening and the
    /// redirect policy.
    allow_http: bool,
    /// Current key snapshot; `None` until a refresh has succeeded.
    inner: RwLock<Option<CachedKeys>>,
    /// HTTP client used for JWKS fetches (custom DNS resolver and redirect
    /// policy are attached in `new`).
    http: reqwest::Client,
    /// Base validation settings; cloned per token with the algorithm list
    /// replaced by the token's own algorithm.
    validation_template: Validation,
    /// Audience value a token must carry (see `check_audience`).
    expected_audience: String,
    /// How `check_audience` treats tokens that match only via `azp`.
    audience_mode: AudienceValidationMode,
    /// Set after the first `azp`-fallback warning so it is logged only once.
    azp_fallback_warned: AtomicBool,
    /// Scope-to-role mappings, used when `role_claim` is `None`.
    scopes: Vec<ScopeMapping>,
    /// Optional claim path whose values are matched against `role_mappings`.
    role_claim: Option<String>,
    /// Claim-value-to-role mappings, used when `role_claim` is set.
    role_mappings: Vec<RoleMapping>,
    /// Time of the most recent refresh attempt; drives the refresh cooldown.
    last_refresh_attempt: RwLock<Option<Instant>>,
    /// Held for the duration of a refresh attempt so attempts run one at a time.
    refresh_lock: tokio::sync::Mutex<()>,
    /// Compiled SSRF allowlist handed to target screening.
    allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
    /// Test-only switch (initialised to `false` in `new`) that makes
    /// `fetch_jwks` use the loopback-tolerant screening path.
    #[cfg(any(test, feature = "test-helpers"))]
    test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
1751
/// Minimum spacing between two JWKS refresh attempts. Attempts made sooner
/// than this after the previous one are skipped by `refresh_with_cooldown`.
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_secs(10);
1754
/// Signature algorithms a token header may declare. `select_jwks_key`
/// rejects any token whose `alg` is not in this list. Only RSA, ECDSA and
/// EdDSA variants appear; no HMAC algorithm is listed.
const ACCEPTED_ALGS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
    Algorithm::EdDSA,
];
1767
/// Reason a token was rejected by `JwksCache::validate_token_with_reason`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
    /// The decoder reported an expired signature.
    Expired,
    /// Every other rejection: bad header, unaccepted algorithm, unknown key,
    /// failed signature/claim validation, audience mismatch, or no role match.
    Invalid,
}
1777
1778impl JwksCache {
/// Builds a cache from `config` without fetching any keys.
///
/// Sets up the validation template, compiles the SSRF allowlist, and creates
/// the HTTP client used for JWKS fetches. That client has no proxy, uses a
/// screening DNS resolver, a 10 s total / 3 s connect timeout, and a custom
/// redirect policy.
///
/// # Errors
///
/// Returns an error when the SSRF allowlist fails to compile, the CA file
/// cannot be read or parsed as PEM, or the HTTP client cannot be built.
///
/// # Panics
///
/// Panics if `config.jwks_cache_ttl` is not a parseable duration; callers are
/// expected to have run `OAuthConfig::validate` first.
pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
    // Install process-wide crypto providers. `.ok()` discards the result
    // (presumably the "already installed" error on repeated construction).
    rustls::crypto::ring::default_provider()
        .install_default()
        .ok();
    jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
        .install_default()
        .ok();

    #[allow(
        clippy::expect_used,
        reason = "jwks_cache_ttl was already parsed successfully by OAuthConfig::validate (call site precondition); re-parsing the same validated string here is infallible"
    )]
    let ttl = humantime::parse_duration(&config.jwks_cache_ttl)
        .expect("jwks_cache_ttl validated by OAuthConfig::validate");

    // RS256 is only a placeholder: `decode_claims` overwrites the algorithm
    // list per token. The library's audience check is disabled because
    // `check_audience` performs that check itself.
    let mut validation = Validation::new(Algorithm::RS256);
    validation.validate_aud = false;
    validation.set_issuer(&[&config.issuer]);
    validation.set_required_spec_claims(&["exp", "iss"]);
    validation.validate_exp = true;
    validation.validate_nbf = true;

    let allow_http = config.allow_http_oauth_urls;

    // No configured allowlist means the default (empty) compiled allowlist.
    let allowlist = match config.ssrf_allowlist.as_ref() {
        Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
            Box::<dyn std::error::Error + Send + Sync>::from(format!(
                "oauth.ssrf_allowlist: {e}"
            ))
        })?),
        None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
    };
    // Second handle for the redirect-policy closure, which must own its data.
    let redirect_allowlist = Arc::clone(&allowlist);

    // The bypass handle is a shared flag in test builds and `()` otherwise.
    #[cfg(any(test, feature = "test-helpers"))]
    let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
        Arc::new(AtomicBool::new(false));
    #[cfg(not(any(test, feature = "test-helpers")))]
    let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();

    let resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(&allowlist),
            #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
            test_bypass.clone(),
        ));

    let mut http_builder = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&resolver))
        .timeout(Duration::from_secs(10))
        .connect_timeout(Duration::from_secs(3))
        .redirect(reqwest::redirect::Policy::custom(move |attempt| {
            // Every redirect hop is re-evaluated; refusals are logged and
            // surfaced as a request error.
            match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
                Ok(()) => attempt.follow(),
                Err(reason) => {
                    tracing::warn!(
                        reason = %reason,
                        target = %attempt.url(),
                        "oauth redirect rejected"
                    );
                    attempt.error(reason)
                }
            }
        }));

    // Optional extra trust root, read from disk as PEM.
    if let Some(ref ca_path) = config.ca_cert_path {
        let pem = std::fs::read(ca_path)?;
        let cert = reqwest::tls::Certificate::from_pem(&pem)?;
        http_builder = http_builder.add_root_certificate(cert);
    }

    let http = http_builder.build()?;

    Ok(Self {
        jwks_uri: config.jwks_uri.clone(),
        ttl,
        max_jwks_keys: config.max_jwks_keys,
        max_response_bytes: config.jwks_max_response_bytes,
        allow_http,
        inner: RwLock::new(None),
        http,
        validation_template: validation,
        expected_audience: config.audience.clone(),
        audience_mode: config.effective_audience_validation_mode(),
        azp_fallback_warned: AtomicBool::new(false),
        scopes: config.scopes.clone(),
        role_claim: config.role_claim.clone(),
        role_mappings: config.role_mappings.clone(),
        last_refresh_attempt: RwLock::new(None),
        refresh_lock: tokio::sync::Mutex::new(()),
        allowlist,
        #[cfg(any(test, feature = "test-helpers"))]
        test_allow_loopback_ssrf: test_bypass,
    })
}
1919
/// Test helper: turns on the loopback SSRF bypass flag and hands the cache back.
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
#[must_use]
pub fn __test_allow_loopback_ssrf(self) -> Self {
    let bypass = &self.test_allow_loopback_ssrf;
    bypass.store(true, Ordering::Relaxed);
    self
}
1932
1933 pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
1935 self.validate_token_with_reason(token).await.ok()
1936 }
1937
1938 pub async fn validate_token_with_reason(
1945 &self,
1946 token: &str,
1947 ) -> Result<AuthIdentity, JwtValidationFailure> {
1948 let claims = self.decode_claims(token).await?;
1949
1950 self.check_audience(&claims)?;
1951 let role = self.resolve_role(&claims)?;
1952
1953 let sub = claims.sub;
1956 let name = claims
1957 .extra
1958 .get("preferred_username")
1959 .and_then(|v| v.as_str())
1960 .map(String::from)
1961 .or_else(|| sub.clone())
1962 .or(claims.azp)
1963 .or(claims.client_id)
1964 .unwrap_or_else(|| "oauth-client".into());
1965
1966 Ok(AuthIdentity {
1967 name,
1968 role,
1969 method: AuthMethod::OAuthJwt,
1970 raw_token: None,
1971 sub,
1972 })
1973 }
1974
1975 async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
1987 let (key, alg) = self.select_jwks_key(token).await?;
1988
1989 let mut validation = self.validation_template.clone();
1993 validation.algorithms = vec![alg];
1994
1995 let token_owned = token.to_owned();
1998 let join =
1999 tokio::task::spawn_blocking(move || decode::<Claims>(&token_owned, &key, &validation))
2000 .await;
2001
2002 let decode_result = match join {
2003 Ok(r) => r,
2004 Err(join_err) => {
2005 core::hint::cold_path();
2006 tracing::error!(
2007 error = %join_err,
2008 "JWT decode task panicked or was cancelled"
2009 );
2010 return Err(JwtValidationFailure::Invalid);
2011 }
2012 };
2013
2014 decode_result.map(|td| td.claims).map_err(|e| {
2015 core::hint::cold_path();
2016 let failure = if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
2017 JwtValidationFailure::Expired
2018 } else {
2019 JwtValidationFailure::Invalid
2020 };
2021 tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
2022 failure
2023 })
2024 }
2025
2026 #[allow(
2035 clippy::cognitive_complexity,
2036 reason = "each failure arm pairs `cold_path()` with a distinct `tracing::debug!` site for observability; collapsing into combinators would lose structured-field log sites without reducing real complexity"
2037 )]
2038 async fn select_jwks_key(
2039 &self,
2040 token: &str,
2041 ) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
2042 let Ok(header) = decode_header(token) else {
2043 core::hint::cold_path();
2044 tracing::debug!("JWT header decode failed");
2045 return Err(JwtValidationFailure::Invalid);
2046 };
2047 let kid = header.kid.as_deref();
2048 tracing::debug!(alg = ?header.alg, kid = kid.unwrap_or("-"), "JWT header decoded");
2049
2050 if !ACCEPTED_ALGS.contains(&header.alg) {
2051 core::hint::cold_path();
2052 tracing::debug!(alg = ?header.alg, "JWT algorithm not accepted");
2053 return Err(JwtValidationFailure::Invalid);
2054 }
2055
2056 let Some(key) = self.find_key(kid, header.alg).await else {
2057 core::hint::cold_path();
2058 tracing::debug!(kid = kid.unwrap_or("-"), alg = ?header.alg, "no matching JWKS key found");
2059 return Err(JwtValidationFailure::Invalid);
2060 };
2061
2062 Ok((key, header.alg))
2063 }
2064
2065 fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
2074 if claims.aud.contains(&self.expected_audience) {
2075 return Ok(());
2076 }
2077 let azp_match = claims
2078 .azp
2079 .as_deref()
2080 .is_some_and(|azp| azp == self.expected_audience);
2081 if azp_match {
2082 match self.audience_mode {
2083 AudienceValidationMode::Permissive => return Ok(()),
2084 AudienceValidationMode::Warn => {
2085 if !self.azp_fallback_warned.swap(true, Ordering::Relaxed) {
2086 tracing::warn!(
2087 expected = %self.expected_audience,
2088 azp = ?claims.azp,
2089 "JWT accepted via deprecated `azp`-only audience fallback. \
2090 Configure your IdP to populate `aud`, or set \
2091 `audience_validation_mode = \"strict\"` once tokens carry `aud` correctly. \
2092 To silence this warning without changing acceptance, \
2093 set `audience_validation_mode = \"permissive\"`. \
2094 This warning logs once per process."
2095 );
2096 }
2097 return Ok(());
2098 }
2099 AudienceValidationMode::Strict => {}
2100 }
2101 }
2102 core::hint::cold_path();
2103 tracing::debug!(
2104 aud = ?claims.aud.0,
2105 azp = ?claims.azp,
2106 expected = %self.expected_audience,
2107 mode = ?self.audience_mode,
2108 "JWT rejected: audience mismatch"
2109 );
2110 Err(JwtValidationFailure::Invalid)
2111 }
2112
2113 fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
2119 if let Some(ref claim_path) = self.role_claim {
2120 let owned_first_class: Vec<String> = first_class_claim_values(claims, claim_path);
2121 let mut values: Vec<&str> = owned_first_class.iter().map(String::as_str).collect();
2122 values.extend(resolve_claim_path(&claims.extra, claim_path));
2123 return self
2124 .role_mappings
2125 .iter()
2126 .find(|m| values.contains(&m.claim_value.as_str()))
2127 .map(|m| m.role.clone())
2128 .ok_or(JwtValidationFailure::Invalid);
2129 }
2130
2131 let token_scopes: Vec<&str> = claims
2132 .scope
2133 .as_deref()
2134 .unwrap_or("")
2135 .split_whitespace()
2136 .collect();
2137
2138 self.scopes
2139 .iter()
2140 .find(|m| token_scopes.contains(&m.scope.as_str()))
2141 .map(|m| m.role.clone())
2142 .ok_or(JwtValidationFailure::Invalid)
2143 }
2144
2145 async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2148 {
2150 let guard = self.inner.read().await;
2151 if let Some(cached) = guard.as_ref()
2152 && !cached.is_expired()
2153 && let Some(key) = lookup_key(cached, kid, alg)
2154 {
2155 return Some(key);
2156 }
2157 }
2158
2159 self.refresh_with_cooldown().await;
2161
2162 let guard = self.inner.read().await;
2163 guard
2164 .as_ref()
2165 .and_then(|cached| lookup_key(cached, kid, alg))
2166 }
2167
2168 async fn refresh_with_cooldown(&self) {
2173 let _guard = self.refresh_lock.lock().await;
2175
2176 {
2178 let last = self.last_refresh_attempt.read().await;
2179 if let Some(ts) = *last
2180 && ts.elapsed() < JWKS_REFRESH_COOLDOWN
2181 {
2182 tracing::debug!(
2183 elapsed_ms = ts.elapsed().as_millis(),
2184 cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
2185 "JWKS refresh skipped (cooldown active)"
2186 );
2187 return;
2188 }
2189 }
2190
2191 {
2194 let mut last = self.last_refresh_attempt.write().await;
2195 *last = Some(Instant::now());
2196 }
2197
2198 let _ = self.refresh_inner().await;
2200 }
2201
2202 async fn refresh_inner(&self) -> Result<(), String> {
2207 let Some(jwks) = self.fetch_jwks().await else {
2208 return Ok(());
2209 };
2210 let (keys, unnamed_keys) = match build_key_cache(&jwks, self.max_jwks_keys) {
2211 Ok(cache) => cache,
2212 Err(msg) => {
2213 tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
2214 return Err(msg);
2215 }
2216 };
2217
2218 tracing::debug!(
2219 named = keys.len(),
2220 unnamed = unnamed_keys.len(),
2221 "JWKS refreshed"
2222 );
2223
2224 let mut guard = self.inner.write().await;
2225 *guard = Some(CachedKeys {
2226 keys,
2227 unnamed_keys,
2228 fetched_at: Instant::now(),
2229 ttl: self.ttl,
2230 });
2231 Ok(())
2232 }
2233
2234 #[allow(
2236 clippy::cognitive_complexity,
2237 reason = "screening, bounded streaming, and parse logging are intentionally kept in one fetch path"
2238 )]
2239 async fn fetch_jwks(&self) -> Option<JwkSet> {
2240 #[cfg(any(test, feature = "test-helpers"))]
2241 let screening = if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
2242 screen_oauth_target_with_test_override(
2243 &self.jwks_uri,
2244 self.allow_http,
2245 &self.allowlist,
2246 true,
2247 )
2248 .await
2249 } else {
2250 screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await
2251 };
2252 #[cfg(not(any(test, feature = "test-helpers")))]
2253 let screening = screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await;
2254
2255 if let Err(error) = screening {
2256 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to screen JWKS target");
2257 return None;
2258 }
2259
2260 let mut resp = match self.http.get(&self.jwks_uri).send().await {
2261 Ok(resp) => resp,
2262 Err(e) => {
2263 tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
2264 return None;
2265 }
2266 };
2267
2268 let initial_capacity =
2269 usize::try_from(self.max_response_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
2270 let mut body = Vec::with_capacity(initial_capacity);
2271 while let Some(chunk) = match resp.chunk().await {
2272 Ok(chunk) => chunk,
2273 Err(error) => {
2274 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to read JWKS response");
2275 return None;
2276 }
2277 } {
2278 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
2279 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
2280 if body_len.saturating_add(chunk_len) > self.max_response_bytes {
2281 tracing::warn!(
2282 uri = %self.jwks_uri,
2283 max_bytes = self.max_response_bytes,
2284 "JWKS response exceeded configured size cap"
2285 );
2286 return None;
2287 }
2288 body.extend_from_slice(&chunk);
2289 }
2290
2291 match serde_json::from_slice::<JwkSet>(&body) {
2292 Ok(jwks) => Some(jwks),
2293 Err(error) => {
2294 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to parse JWKS");
2295 None
2296 }
2297 }
2298 }
2299
/// Test helper: fetches the JWKS immediately, bypassing the refresh
/// cooldown, and replaces the cached snapshot.
///
/// # Errors
///
/// Returns a message when the fetch/parse fails or the key cap is exceeded.
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
pub async fn __test_refresh_now(&self) -> Result<(), String> {
    let Some(jwks) = self.fetch_jwks().await else {
        return Err("failed to fetch or parse JWKS".to_owned());
    };
    let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;

    let mut slot = self.inner.write().await;
    let snapshot = CachedKeys {
        keys,
        unnamed_keys,
        fetched_at: Instant::now(),
        ttl: self.ttl,
    };
    *slot = Some(snapshot);
    Ok(())
}
2319
/// Test helper: reports whether the cached snapshot holds a key named `kid`.
/// Returns `false` when nothing has been cached yet.
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
pub async fn __test_has_kid(&self, kid: &str) -> bool {
    let guard = self.inner.read().await;
    guard
        .as_ref()
        .and_then(|cache| cache.keys.get(kid))
        .is_some()
}
2330}
2331
2332fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
2334 if jwks.keys.len() > max_keys {
2335 return Err(format!(
2336 "jwks_key_count_exceeds_cap: got {} keys, max is {}",
2337 jwks.keys.len(),
2338 max_keys
2339 ));
2340 }
2341 let mut keys = HashMap::new();
2342 let mut unnamed_keys = Vec::new();
2343 for jwk in &jwks.keys {
2344 let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
2345 continue;
2346 };
2347 let Some(alg) = jwk_algorithm(jwk) else {
2348 continue;
2349 };
2350 if let Some(ref kid) = jwk.common.key_id {
2351 keys.insert(kid.clone(), (alg, decoding_key));
2352 } else {
2353 unnamed_keys.push((alg, decoding_key));
2354 }
2355 }
2356 Ok((keys, unnamed_keys))
2357}
2358
2359fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
2361 if let Some(kid) = kid
2362 && let Some((cached_alg, key)) = cached.keys.get(kid)
2363 && *cached_alg == alg
2364 {
2365 return Some(key.clone());
2366 }
2367 cached
2369 .unnamed_keys
2370 .iter()
2371 .find(|(a, _)| *a == alg)
2372 .map(|(_, k)| k.clone())
2373}
2374
2375#[allow(
2377 clippy::wildcard_enum_match_arm,
2378 reason = "jsonwebtoken KeyAlgorithm is a large external enum; only the JWT-signing variants are mappable to `Algorithm`"
2379)]
2380fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
2381 jwk.common.key_algorithm.and_then(|ka| match ka {
2382 jsonwebtoken::jwk::KeyAlgorithm::RS256 => Some(Algorithm::RS256),
2383 jsonwebtoken::jwk::KeyAlgorithm::RS384 => Some(Algorithm::RS384),
2384 jsonwebtoken::jwk::KeyAlgorithm::RS512 => Some(Algorithm::RS512),
2385 jsonwebtoken::jwk::KeyAlgorithm::ES256 => Some(Algorithm::ES256),
2386 jsonwebtoken::jwk::KeyAlgorithm::ES384 => Some(Algorithm::ES384),
2387 jsonwebtoken::jwk::KeyAlgorithm::PS256 => Some(Algorithm::PS256),
2388 jsonwebtoken::jwk::KeyAlgorithm::PS384 => Some(Algorithm::PS384),
2389 jsonwebtoken::jwk::KeyAlgorithm::PS512 => Some(Algorithm::PS512),
2390 jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
2391 _ => None,
2392 })
2393}
2394
2395fn first_class_claim_values(claims: &Claims, path: &str) -> Vec<String> {
2416 match path {
2417 "sub" => claims.sub.iter().cloned().collect(),
2418 "azp" => claims.azp.iter().cloned().collect(),
2419 "client_id" => claims.client_id.iter().cloned().collect(),
2420 "aud" => claims.aud.0.clone(),
2421 "scope" => claims
2422 .scope
2423 .as_deref()
2424 .unwrap_or("")
2425 .split_whitespace()
2426 .map(str::to_owned)
2427 .collect(),
2428 _ => Vec::new(),
2429 }
2430}
2431
2432fn resolve_claim_path<'a>(
2442 extra: &'a HashMap<String, serde_json::Value>,
2443 path: &str,
2444) -> Vec<&'a str> {
2445 let mut segments = path.split('.');
2446 let Some(first) = segments.next() else {
2447 return Vec::new();
2448 };
2449
2450 let mut current: Option<&serde_json::Value> = extra.get(first);
2451
2452 for segment in segments {
2453 current = current.and_then(|v| v.get(segment));
2454 }
2455
2456 match current {
2457 Some(serde_json::Value::String(s)) => s.split_whitespace().collect(),
2458 Some(serde_json::Value::Array(arr)) => arr.iter().filter_map(|v| v.as_str()).collect(),
2459 _ => Vec::new(),
2460 }
2461}
2462
/// JWT payload as deserialized by `decode_claims`.
///
/// Claims with a dedicated field are stored there. Everything else in the
/// payload lands in `extra`.
#[derive(Debug, Deserialize)]
struct Claims {
    /// `sub` claim, if present.
    sub: Option<String>,
    /// `aud` claim, accepted as a single string or an array of strings;
    /// empty when the claim is absent.
    #[serde(default)]
    aud: OneOrMany,
    /// `azp` claim, if present; used as an audience fallback by
    /// `check_audience`.
    azp: Option<String>,
    /// `client_id` claim, if present.
    client_id: Option<String>,
    /// `scope` claim as one whitespace-separated string, if present.
    scope: Option<String>,
    /// All remaining payload members, keyed by claim name.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
2486
/// List of strings that may arrive either as a single string or as an array
/// of strings. Defaults to an empty list.
#[derive(Debug, Default)]
struct OneOrMany(Vec<String>);

impl OneOrMany {
    /// Returns `true` when one of the stored entries equals `value` exactly.
    fn contains(&self, value: &str) -> bool {
        self.0
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == value)
    }
}
2496
2497impl<'de> Deserialize<'de> for OneOrMany {
2498 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
2499 use serde::de;
2500
2501 struct Visitor;
2502 impl<'de> de::Visitor<'de> for Visitor {
2503 type Value = OneOrMany;
2504 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2505 f.write_str("a string or array of strings")
2506 }
2507 fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
2508 Ok(OneOrMany(vec![v.to_owned()]))
2509 }
2510 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
2511 let mut v = Vec::new();
2512 while let Some(s) = seq.next_element::<String>()? {
2513 v.push(s);
2514 }
2515 Ok(OneOrMany(v))
2516 }
2517 }
2518 deserializer.deserialize_any(Visitor)
2519 }
2520}
2521
2522#[must_use]
2529pub fn looks_like_jwt(token: &str) -> bool {
2530 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2531
2532 let mut parts = token.splitn(4, '.');
2533 let Some(header_b64) = parts.next() else {
2534 return false;
2535 };
2536 if parts.next().is_none() || parts.next().is_none() || parts.next().is_some() {
2538 return false;
2539 }
2540 let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
2542 return false;
2543 };
2544 let Ok(header) = serde_json::from_slice::<serde_json::Value>(&header_bytes) else {
2546 return false;
2547 };
2548 header.get("alg").is_some()
2549}
2550
2551#[must_use]
2561pub fn protected_resource_metadata(
2562 resource_url: &str,
2563 server_url: &str,
2564 config: &OAuthConfig,
2565) -> serde_json::Value {
2566 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2571 let auth_server = server_url;
2572 serde_json::json!({
2573 "resource": resource_url,
2574 "authorization_servers": [auth_server],
2575 "scopes_supported": scopes,
2576 "bearer_methods_supported": ["header"]
2577 })
2578}
2579
2580#[must_use]
2585pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
2586 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2587 let mut meta = serde_json::json!({
2588 "issuer": &config.issuer,
2589 "authorization_endpoint": format!("{server_url}/authorize"),
2590 "token_endpoint": format!("{server_url}/token"),
2591 "registration_endpoint": format!("{server_url}/register"),
2592 "response_types_supported": ["code"],
2593 "grant_types_supported": ["authorization_code", "refresh_token"],
2594 "code_challenge_methods_supported": ["S256"],
2595 "scopes_supported": scopes,
2596 "token_endpoint_auth_methods_supported": ["none"],
2597 });
2598 if let Some(proxy) = &config.proxy
2599 && proxy.expose_admin_endpoints
2600 && let Some(obj) = meta.as_object_mut()
2601 {
2602 if proxy.introspection_url.is_some() {
2603 obj.insert(
2604 "introspection_endpoint".into(),
2605 serde_json::Value::String(format!("{server_url}/introspect")),
2606 );
2607 }
2608 if proxy.revocation_url.is_some() {
2609 obj.insert(
2610 "revocation_endpoint".into(),
2611 serde_json::Value::String(format!("{server_url}/revoke")),
2612 );
2613 }
2614 if proxy.require_auth_on_admin_endpoints {
2615 obj.insert(
2616 "introspection_endpoint_auth_methods_supported".into(),
2617 serde_json::json!(["bearer"]),
2618 );
2619 obj.insert(
2620 "revocation_endpoint_auth_methods_supported".into(),
2621 serde_json::json!(["bearer"]),
2622 );
2623 }
2624 }
2625 meta
2626}
2627
2628#[must_use]
2641pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
2642 use axum::{
2643 http::{StatusCode, header},
2644 response::IntoResponse,
2645 };
2646
2647 let upstream_query = replace_client_id(query, &proxy.client_id);
2649 let redirect_url = format!("{}?{upstream_query}", proxy.authorize_url);
2650
2651 (StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
2652}
2653
2654pub async fn handle_token(
2660 http: &OauthHttpClient,
2661 proxy: &OAuthProxyConfig,
2662 body: &str,
2663) -> axum::response::Response {
2664 use axum::{
2665 http::{StatusCode, header},
2666 response::IntoResponse,
2667 };
2668
2669 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2671
2672 if let Some(ref secret) = proxy.client_secret {
2674 use std::fmt::Write;
2675
2676 use secrecy::ExposeSecret;
2677 let _ = write!(
2678 upstream_body,
2679 "&client_secret={}",
2680 urlencoding::encode(secret.expose_secret())
2681 );
2682 }
2683
2684 let result = http
2685 .send_screened(
2686 &proxy.token_url,
2687 http.inner
2688 .post(&proxy.token_url)
2689 .header("Content-Type", "application/x-www-form-urlencoded")
2690 .body(upstream_body),
2691 )
2692 .await;
2693
2694 match result {
2695 Ok(resp) => {
2696 let status =
2697 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2698 let body_bytes = resp.bytes().await.unwrap_or_default();
2699 (
2700 status,
2701 [(header::CONTENT_TYPE, "application/json")],
2702 body_bytes,
2703 )
2704 .into_response()
2705 }
2706 Err(e) => {
2707 tracing::error!(error = %e, "OAuth token proxy request failed");
2708 (
2709 StatusCode::BAD_GATEWAY,
2710 [(header::CONTENT_TYPE, "application/json")],
2711 "{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
2712 )
2713 .into_response()
2714 }
2715 }
2716}
2717
2718#[must_use]
2725pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
2726 let mut resp = serde_json::json!({
2727 "client_id": proxy.client_id,
2728 "token_endpoint_auth_method": "none",
2729 });
2730 if let Some(uris) = body.get("redirect_uris")
2731 && let Some(obj) = resp.as_object_mut()
2732 {
2733 obj.insert("redirect_uris".into(), uris.clone());
2734 }
2735 if let Some(name) = body.get("client_name")
2736 && let Some(obj) = resp.as_object_mut()
2737 {
2738 obj.insert("client_name".into(), name.clone());
2739 }
2740 resp
2741}
2742
2743pub async fn handle_introspect(
2749 http: &OauthHttpClient,
2750 proxy: &OAuthProxyConfig,
2751 body: &str,
2752) -> axum::response::Response {
2753 let Some(ref url) = proxy.introspection_url else {
2754 return oauth_error_response(
2755 axum::http::StatusCode::NOT_FOUND,
2756 "not_supported",
2757 "introspection endpoint is not configured",
2758 );
2759 };
2760 proxy_oauth_admin_request(http, proxy, url, body).await
2761}
2762
2763pub async fn handle_revoke(
2770 http: &OauthHttpClient,
2771 proxy: &OAuthProxyConfig,
2772 body: &str,
2773) -> axum::response::Response {
2774 let Some(ref url) = proxy.revocation_url else {
2775 return oauth_error_response(
2776 axum::http::StatusCode::NOT_FOUND,
2777 "not_supported",
2778 "revocation endpoint is not configured",
2779 );
2780 };
2781 proxy_oauth_admin_request(http, proxy, url, body).await
2782}
2783
2784async fn proxy_oauth_admin_request(
2788 http: &OauthHttpClient,
2789 proxy: &OAuthProxyConfig,
2790 upstream_url: &str,
2791 body: &str,
2792) -> axum::response::Response {
2793 use axum::{
2794 http::{StatusCode, header},
2795 response::IntoResponse,
2796 };
2797
2798 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2799 if let Some(ref secret) = proxy.client_secret {
2800 use std::fmt::Write;
2801
2802 use secrecy::ExposeSecret;
2803 let _ = write!(
2804 upstream_body,
2805 "&client_secret={}",
2806 urlencoding::encode(secret.expose_secret())
2807 );
2808 }
2809
2810 let result = http
2811 .send_screened(
2812 upstream_url,
2813 http.inner
2814 .post(upstream_url)
2815 .header("Content-Type", "application/x-www-form-urlencoded")
2816 .body(upstream_body),
2817 )
2818 .await;
2819
2820 match result {
2821 Ok(resp) => {
2822 let status =
2823 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2824 let content_type = resp
2825 .headers()
2826 .get(header::CONTENT_TYPE)
2827 .and_then(|v| v.to_str().ok())
2828 .unwrap_or("application/json")
2829 .to_owned();
2830 let body_bytes = resp.bytes().await.unwrap_or_default();
2831 (status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
2832 }
2833 Err(e) => {
2834 tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
2835 oauth_error_response(
2836 StatusCode::BAD_GATEWAY,
2837 "server_error",
2838 "upstream endpoint unreachable",
2839 )
2840 }
2841 }
2842}
2843
2844fn oauth_error_response(
2845 status: axum::http::StatusCode,
2846 error: &str,
2847 description: &str,
2848) -> axum::response::Response {
2849 use axum::{http::header, response::IntoResponse};
2850 let body = serde_json::json!({
2851 "error": error,
2852 "error_description": description,
2853 });
2854 (
2855 status,
2856 [(header::CONTENT_TYPE, "application/json")],
2857 body.to_string(),
2858 )
2859 .into_response()
2860}
2861
/// Error body of a rejected token exchange, as parsed by `exchange_token`.
/// Both members are only logged or mapped to a sanitized code; they are
/// never returned to the caller verbatim.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Upstream error code (mapped through `sanitize_oauth_error_code`).
    error: String,
    /// Optional human-readable detail from the upstream server.
    error_description: Option<String>,
}
2872
/// Maps an upstream error code onto a fixed set of known codes.
///
/// A code on the pass-through list is returned as the matching static
/// string. Anything else, including different casing or an empty string,
/// becomes `"server_error"`, so arbitrary upstream text never reaches the
/// client.
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
    const PASSTHROUGH: [&str; 8] = [
        "invalid_request",
        "invalid_client",
        "invalid_grant",
        "unauthorized_client",
        "unsupported_grant_type",
        "invalid_scope",
        "temporarily_unavailable",
        "invalid_target",
    ];

    PASSTHROUGH
        .iter()
        .copied()
        .find(|known| *known == raw)
        .unwrap_or("server_error")
}
2895
2896pub async fn exchange_token(
2908 http: &OauthHttpClient,
2909 config: &TokenExchangeConfig,
2910 subject_token: &str,
2911) -> Result<ExchangedToken, crate::error::McpxError> {
2912 use secrecy::ExposeSecret;
2913
2914 let client = http.client_for(config);
2915 let mut req = client
2916 .post(&config.token_url)
2917 .header("Content-Type", "application/x-www-form-urlencoded")
2918 .header("Accept", "application/json");
2919
2920 if config.client_cert.is_none()
2929 && let Some(ref secret) = config.client_secret
2930 {
2931 use base64::Engine;
2932 let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
2933 "{}:{}",
2934 urlencoding::encode(&config.client_id),
2935 urlencoding::encode(secret.expose_secret()),
2936 ));
2937 req = req.header("Authorization", format!("Basic {credentials}"));
2938 }
2939
2940 let form_body = build_exchange_form(config, subject_token);
2941
2942 let resp = http
2943 .send_screened(&config.token_url, req.body(form_body))
2944 .await
2945 .map_err(|e| {
2946 tracing::error!(error = %e, "token exchange request failed");
2947 crate::error::McpxError::Auth("server_error".into())
2949 })?;
2950
2951 let status = resp.status();
2952 let body_bytes = resp.bytes().await.map_err(|e| {
2953 tracing::error!(error = %e, "failed to read token exchange response");
2954 crate::error::McpxError::Auth("server_error".into())
2955 })?;
2956
2957 if !status.is_success() {
2958 core::hint::cold_path();
2959 let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
2962 let short_code = parsed
2963 .as_ref()
2964 .map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
2965 if let Some(ref e) = parsed {
2966 tracing::warn!(
2967 status = %status,
2968 upstream_error = %e.error,
2969 upstream_error_description = e.error_description.as_deref().unwrap_or(""),
2970 client_code = %short_code,
2971 "token exchange rejected by authorization server",
2972 );
2973 } else {
2974 tracing::warn!(
2975 status = %status,
2976 client_code = %short_code,
2977 "token exchange rejected (unparseable upstream body)",
2978 );
2979 }
2980 return Err(crate::error::McpxError::Auth(short_code.into()));
2981 }
2982
2983 let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
2984 tracing::error!(error = %e, "failed to parse token exchange response");
2985 crate::error::McpxError::Auth("server_error".into())
2988 })?;
2989
2990 log_exchanged_token(&exchanged);
2991
2992 Ok(exchanged)
2993}
2994
2995fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
2998 let body = format!(
2999 "grant_type={}&subject_token={}&subject_token_type={}&requested_token_type={}&audience={}",
3000 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
3001 urlencoding::encode(subject_token),
3002 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
3003 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
3004 urlencoding::encode(&config.audience),
3005 );
3006 if config.client_secret.is_none() {
3007 format!(
3008 "{body}&client_id={}",
3009 urlencoding::encode(&config.client_id)
3010 )
3011 } else {
3012 body
3013 }
3014}
3015
3016fn log_exchanged_token(exchanged: &ExchangedToken) {
3019 use base64::Engine;
3020
3021 if !looks_like_jwt(&exchanged.access_token) {
3022 tracing::debug!(
3023 token_len = exchanged.access_token.len(),
3024 issued_token_type = ?exchanged.issued_token_type,
3025 expires_in = exchanged.expires_in,
3026 "exchanged token (opaque)",
3027 );
3028 return;
3029 }
3030 let Some(payload) = exchanged.access_token.split('.').nth(1) else {
3031 return;
3032 };
3033 let Ok(decoded) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload) else {
3034 return;
3035 };
3036 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&decoded) else {
3037 return;
3038 };
3039 tracing::debug!(
3040 sub = ?claims.get("sub"),
3041 aud = ?claims.get("aud"),
3042 azp = ?claims.get("azp"),
3043 iss = ?claims.get("iss"),
3044 expires_in = exchanged.expires_in,
3045 "exchanged token claims (JWT)",
3046 );
3047}
3048
3049fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
3051 let encoded_id = urlencoding::encode(upstream_client_id);
3052 let mut parts: Vec<String> = params
3053 .split('&')
3054 .filter(|p| !p.starts_with("client_id="))
3055 .map(String::from)
3056 .collect();
3057 parts.push(format!("client_id={encoded_id}"));
3058 parts.join("&")
3059}
3060
3061#[cfg(test)]
3062mod tests {
3063 use std::sync::Arc;
3064
3065 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
3066
3067 use super::*;
3068
3069 #[test]
3070 fn looks_like_jwt_valid() {
3071 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
3073 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3074 let token = format!("{header}.{payload}.signature");
3075 assert!(looks_like_jwt(&token));
3076 }
3077
3078 #[test]
3079 fn looks_like_jwt_rejects_opaque_token() {
3080 assert!(!looks_like_jwt("dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4"));
3081 }
3082
3083 #[test]
3084 fn looks_like_jwt_rejects_two_segments() {
3085 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
3086 let token = format!("{header}.payload");
3087 assert!(!looks_like_jwt(&token));
3088 }
3089
3090 #[test]
3091 fn looks_like_jwt_rejects_four_segments() {
3092 assert!(!looks_like_jwt("a.b.c.d"));
3093 }
3094
3095 #[test]
3096 fn looks_like_jwt_rejects_no_alg() {
3097 let header = URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}");
3098 let payload = URL_SAFE_NO_PAD.encode(b"{}");
3099 let token = format!("{header}.{payload}.sig");
3100 assert!(!looks_like_jwt(&token));
3101 }
3102
3103 #[test]
3104 fn protected_resource_metadata_shape() {
3105 let config = OAuthConfig {
3106 issuer: "https://auth.example.com".into(),
3107 audience: "https://mcp.example.com/mcp".into(),
3108 jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
3109 scopes: vec![
3110 ScopeMapping {
3111 scope: "mcp:read".into(),
3112 role: "viewer".into(),
3113 },
3114 ScopeMapping {
3115 scope: "mcp:admin".into(),
3116 role: "ops".into(),
3117 },
3118 ],
3119 role_claim: None,
3120 role_mappings: vec![],
3121 jwks_cache_ttl: "10m".into(),
3122 proxy: None,
3123 token_exchange: None,
3124 ca_cert_path: None,
3125 allow_http_oauth_urls: false,
3126 max_jwks_keys: default_max_jwks_keys(),
3127 #[allow(
3128 deprecated,
3129 reason = "test fixture: explicit value for the deprecated field"
3130 )]
3131 strict_audience_validation: false,
3132 audience_validation_mode: None,
3133 jwks_max_response_bytes: default_jwks_max_bytes(),
3134 ssrf_allowlist: None,
3135 };
3136 let meta = protected_resource_metadata(
3137 "https://mcp.example.com/mcp",
3138 "https://mcp.example.com",
3139 &config,
3140 );
3141 assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
3142 assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
3143 assert_eq!(meta["scopes_supported"].as_array().unwrap().len(), 2);
3144 assert_eq!(meta["bearer_methods_supported"][0], "header");
3145 }
3146
3147 fn validation_https_config() -> OAuthConfig {
3152 OAuthConfig::builder(
3153 "https://auth.example.com",
3154 "mcp",
3155 "https://auth.example.com/.well-known/jwks.json",
3156 )
3157 .build()
3158 }
3159
3160 #[test]
3161 fn validate_accepts_all_https_urls() {
3162 let cfg = validation_https_config();
3163 cfg.validate().expect("all-HTTPS config must validate");
3164 }
3165
3166 #[test]
3167 fn validate_rejects_unparseable_jwks_cache_ttl() {
3168 let mut cfg = validation_https_config();
3169 cfg.jwks_cache_ttl = "not-a-duration".into();
3170 let err = cfg
3171 .validate()
3172 .expect_err("malformed jwks_cache_ttl must be rejected");
3173 let msg = err.to_string();
3174 assert!(
3175 msg.contains("jwks_cache_ttl"),
3176 "error must reference offending field; got {msg:?}"
3177 );
3178 }
3179
3180 #[test]
3181 fn validate_rejects_http_jwks_uri() {
3182 let mut cfg = validation_https_config();
3183 cfg.jwks_uri = "http://auth.example.com/.well-known/jwks.json".into();
3184 let err = cfg.validate().expect_err("http jwks_uri must be rejected");
3185 let msg = err.to_string();
3186 assert!(
3187 msg.contains("oauth.jwks_uri") && msg.contains("https"),
3188 "error must reference offending field + scheme requirement; got {msg:?}"
3189 );
3190 }
3191
3192 #[test]
3193 fn validate_rejects_http_proxy_authorize_url() {
3194 let mut cfg = validation_https_config();
3195 cfg.proxy = Some(
3196 OAuthProxyConfig::builder(
3197 "http://idp.example.com/authorize", "https://idp.example.com/token",
3199 "client",
3200 )
3201 .build(),
3202 );
3203 let err = cfg
3204 .validate()
3205 .expect_err("http authorize_url must be rejected");
3206 assert!(
3207 err.to_string().contains("oauth.proxy.authorize_url"),
3208 "error must reference proxy.authorize_url; got {err}"
3209 );
3210 }
3211
3212 #[test]
3213 fn validate_rejects_http_proxy_token_url() {
3214 let mut cfg = validation_https_config();
3215 cfg.proxy = Some(
3216 OAuthProxyConfig::builder(
3217 "https://idp.example.com/authorize",
3218 "http://idp.example.com/token", "client",
3220 )
3221 .build(),
3222 );
3223 let err = cfg.validate().expect_err("http token_url must be rejected");
3224 assert!(
3225 err.to_string().contains("oauth.proxy.token_url"),
3226 "error must reference proxy.token_url; got {err}"
3227 );
3228 }
3229
3230 #[test]
3231 fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
3232 let mut cfg = validation_https_config();
3233 cfg.proxy = Some(
3234 OAuthProxyConfig::builder(
3235 "https://idp.example.com/authorize",
3236 "https://idp.example.com/token",
3237 "client",
3238 )
3239 .introspection_url("http://idp.example.com/introspect")
3240 .build(),
3241 );
3242 let err = cfg
3243 .validate()
3244 .expect_err("http introspection_url must be rejected");
3245 assert!(err.to_string().contains("oauth.proxy.introspection_url"));
3246
3247 let mut cfg = validation_https_config();
3248 cfg.proxy = Some(
3249 OAuthProxyConfig::builder(
3250 "https://idp.example.com/authorize",
3251 "https://idp.example.com/token",
3252 "client",
3253 )
3254 .revocation_url("http://idp.example.com/revoke")
3255 .build(),
3256 );
3257 let err = cfg
3258 .validate()
3259 .expect_err("http revocation_url must be rejected");
3260 assert!(err.to_string().contains("oauth.proxy.revocation_url"));
3261 }
3262
3263 #[test]
3266 fn validate_rejects_exposed_admin_endpoints_without_auth() {
3267 let mut cfg = validation_https_config();
3268 cfg.proxy = Some(
3269 OAuthProxyConfig::builder(
3270 "https://idp.example.com/authorize",
3271 "https://idp.example.com/token",
3272 "client",
3273 )
3274 .introspection_url("https://idp.example.com/introspect")
3275 .expose_admin_endpoints(true)
3276 .build(),
3277 );
3278 let err = cfg
3279 .validate()
3280 .expect_err("expose_admin_endpoints without auth must fail");
3281 let msg = err.to_string();
3282 assert!(msg.contains("require_auth_on_admin_endpoints"), "{msg}");
3283 assert!(
3284 msg.contains("allow_unauthenticated_admin_endpoints"),
3285 "{msg}"
3286 );
3287 }
3288
3289 #[test]
3290 fn validate_accepts_exposed_admin_endpoints_with_auth() {
3291 let mut cfg = validation_https_config();
3292 cfg.proxy = Some(
3293 OAuthProxyConfig::builder(
3294 "https://idp.example.com/authorize",
3295 "https://idp.example.com/token",
3296 "client",
3297 )
3298 .introspection_url("https://idp.example.com/introspect")
3299 .expose_admin_endpoints(true)
3300 .require_auth_on_admin_endpoints(true)
3301 .build(),
3302 );
3303 cfg.validate()
3304 .expect("authed admin endpoints must validate");
3305 }
3306
3307 #[test]
3308 fn validate_accepts_exposed_admin_endpoints_with_explicit_unauth_optout() {
3309 let mut cfg = validation_https_config();
3310 cfg.proxy = Some(
3311 OAuthProxyConfig::builder(
3312 "https://idp.example.com/authorize",
3313 "https://idp.example.com/token",
3314 "client",
3315 )
3316 .introspection_url("https://idp.example.com/introspect")
3317 .expose_admin_endpoints(true)
3318 .allow_unauthenticated_admin_endpoints(true)
3319 .build(),
3320 );
3321 cfg.validate()
3322 .expect("explicit unauth opt-out must validate");
3323 }
3324
3325 #[test]
3326 fn validate_accepts_unexposed_admin_endpoints_without_auth() {
3327 let mut cfg = validation_https_config();
3330 cfg.proxy = Some(
3331 OAuthProxyConfig::builder(
3332 "https://idp.example.com/authorize",
3333 "https://idp.example.com/token",
3334 "client",
3335 )
3336 .introspection_url("https://idp.example.com/introspect")
3337 .build(),
3338 );
3339 cfg.validate()
3340 .expect("unexposed admin endpoints must validate");
3341 }
3342
3343 #[test]
3344 fn validate_rejects_http_token_exchange_url() {
3345 let mut cfg = validation_https_config();
3346 cfg.token_exchange = Some(TokenExchangeConfig::new(
3347 "http://idp.example.com/token".into(), "client".into(),
3349 None,
3350 None,
3351 "downstream".into(),
3352 ));
3353 let err = cfg
3354 .validate()
3355 .expect_err("http token_exchange.token_url must be rejected");
3356 assert!(
3357 err.to_string().contains("oauth.token_exchange.token_url"),
3358 "error must reference token_exchange.token_url; got {err}"
3359 );
3360 }
3361
3362 #[test]
3363 fn validate_rejects_unparseable_url() {
3364 let mut cfg = validation_https_config();
3365 cfg.jwks_uri = "not a url".into();
3366 let err = cfg
3367 .validate()
3368 .expect_err("unparseable URL must be rejected");
3369 assert!(err.to_string().contains("invalid URL"));
3370 }
3371
3372 #[test]
3373 fn validate_rejects_non_http_scheme() {
3374 let mut cfg = validation_https_config();
3375 cfg.jwks_uri = "file:///etc/passwd".into();
3376 let err = cfg.validate().expect_err("file:// scheme must be rejected");
3377 let msg = err.to_string();
3378 assert!(
3379 msg.contains("must use https scheme") && msg.contains("file"),
3380 "error must reject non-http(s) schemes; got {msg:?}"
3381 );
3382 }
3383
3384 #[test]
3385 fn validate_accepts_http_with_escape_hatch() {
3386 let mut cfg = OAuthConfig::builder(
3391 "http://auth.local",
3392 "mcp",
3393 "http://auth.local/.well-known/jwks.json",
3394 )
3395 .allow_http_oauth_urls(true)
3396 .build();
3397 cfg.proxy = Some(
3398 OAuthProxyConfig::builder(
3399 "http://idp.local/authorize",
3400 "http://idp.local/token",
3401 "client",
3402 )
3403 .introspection_url("http://idp.local/introspect")
3404 .revocation_url("http://idp.local/revoke")
3405 .build(),
3406 );
3407 cfg.token_exchange = Some(TokenExchangeConfig::new(
3408 "http://idp.local/token".into(),
3409 "client".into(),
3410 Some(secrecy::SecretString::new("dev-secret".into())),
3411 None,
3412 "downstream".into(),
3413 ));
3414 cfg.validate()
3415 .expect("escape hatch must permit http on all URL fields");
3416 }
3417
3418 #[test]
3419 fn validate_with_escape_hatch_still_rejects_unparseable() {
3420 let mut cfg = validation_https_config();
3423 cfg.allow_http_oauth_urls = true;
3424 cfg.jwks_uri = "::not-a-url::".into();
3425 cfg.validate()
3426 .expect_err("escape hatch must NOT bypass URL parsing");
3427 }
3428
3429 #[tokio::test]
3430 async fn jwks_cache_rejects_redirect_downgrade_to_http() {
3431 rustls::crypto::ring::default_provider()
3446 .install_default()
3447 .ok();
3448
3449 let policy = reqwest::redirect::Policy::custom(|attempt| {
3450 if attempt.url().scheme() != "https" {
3451 attempt.error("redirect to non-HTTPS URL refused")
3452 } else if attempt.previous().len() >= 2 {
3453 attempt.error("too many redirects (max 2)")
3454 } else {
3455 attempt.follow()
3456 }
3457 });
3458 let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = Arc::new(AtomicBool::new(true));
3465 let allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
3466 let resolver: Arc<dyn reqwest::dns::Resolve> = Arc::new(
3467 crate::ssrf_resolver::SsrfScreeningResolver::new(Arc::clone(&allowlist), test_bypass),
3468 );
3469 let client = reqwest::Client::builder()
3470 .no_proxy()
3471 .dns_resolver(Arc::clone(&resolver))
3472 .timeout(Duration::from_secs(5))
3473 .connect_timeout(Duration::from_secs(3))
3474 .redirect(policy)
3475 .build()
3476 .expect("test client builds");
3477
3478 let mock = wiremock::MockServer::start().await;
3479 wiremock::Mock::given(wiremock::matchers::method("GET"))
3480 .and(wiremock::matchers::path("/jwks.json"))
3481 .respond_with(
3482 wiremock::ResponseTemplate::new(302)
3483 .insert_header("location", "http://example.invalid/jwks.json"),
3484 )
3485 .mount(&mock)
3486 .await;
3487
3488 let url = format!("{}/jwks.json", mock.uri());
3497 let err = client
3498 .get(&url)
3499 .send()
3500 .await
3501 .expect_err("redirect policy must reject scheme downgrade");
3502 let chain = format!("{err:#}");
3503 assert!(
3504 chain.contains("redirect to non-HTTPS URL refused")
3505 || chain.to_lowercase().contains("redirect"),
3506 "error must surface redirect-policy rejection; got {chain:?}"
3507 );
3508 }
3509
3510 use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
3515
3516 fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
3518 let mut rng = rsa::rand_core::OsRng;
3519 let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keypair generation");
3520 let private_pem = private_key
3521 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
3522 .expect("PKCS8 PEM export")
3523 .to_string();
3524
3525 let public_key = private_key.to_public_key();
3526 let n = URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be());
3527 let e = URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be());
3528
3529 let jwks = serde_json::json!({
3530 "keys": [{
3531 "kty": "RSA",
3532 "use": "sig",
3533 "alg": "RS256",
3534 "kid": kid,
3535 "n": n,
3536 "e": e
3537 }]
3538 });
3539
3540 (private_pem, jwks)
3541 }
3542
3543 fn mint_token(
3545 private_pem: &str,
3546 kid: &str,
3547 issuer: &str,
3548 audience: &str,
3549 subject: &str,
3550 scope: &str,
3551 ) -> String {
3552 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3553 .expect("encoding key from PEM");
3554 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3555 header.kid = Some(kid.into());
3556
3557 let now = jsonwebtoken::get_current_timestamp();
3558 let claims = serde_json::json!({
3559 "iss": issuer,
3560 "aud": audience,
3561 "sub": subject,
3562 "scope": scope,
3563 "exp": now + 3600,
3564 "iat": now,
3565 });
3566
3567 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3568 }
3569
3570 fn test_config(jwks_uri: &str) -> OAuthConfig {
3571 OAuthConfig {
3572 issuer: "https://auth.test.local".into(),
3573 audience: "https://mcp.test.local/mcp".into(),
3574 jwks_uri: jwks_uri.into(),
3575 scopes: vec![
3576 ScopeMapping {
3577 scope: "mcp:read".into(),
3578 role: "viewer".into(),
3579 },
3580 ScopeMapping {
3581 scope: "mcp:admin".into(),
3582 role: "ops".into(),
3583 },
3584 ],
3585 role_claim: None,
3586 role_mappings: vec![],
3587 jwks_cache_ttl: "5m".into(),
3588 proxy: None,
3589 token_exchange: None,
3590 ca_cert_path: None,
3591 allow_http_oauth_urls: true,
3592 max_jwks_keys: default_max_jwks_keys(),
3593 #[allow(
3594 deprecated,
3595 reason = "test fixture: explicit value for the deprecated field"
3596 )]
3597 strict_audience_validation: false,
3598 audience_validation_mode: None,
3599 jwks_max_response_bytes: default_jwks_max_bytes(),
3600 ssrf_allowlist: None,
3601 }
3602 }
3603
3604 fn test_cache(config: &OAuthConfig) -> JwksCache {
3605 JwksCache::new(config).unwrap().__test_allow_loopback_ssrf()
3606 }
3607
3608 #[tokio::test]
3609 async fn valid_jwt_returns_identity() {
3610 let kid = "test-key-1";
3611 let (pem, jwks) = generate_test_keypair(kid);
3612
3613 let mock_server = wiremock::MockServer::start().await;
3614 wiremock::Mock::given(wiremock::matchers::method("GET"))
3615 .and(wiremock::matchers::path("/jwks.json"))
3616 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3617 .mount(&mock_server)
3618 .await;
3619
3620 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3621 let config = test_config(&jwks_uri);
3622 let cache = test_cache(&config);
3623
3624 let token = mint_token(
3625 &pem,
3626 kid,
3627 "https://auth.test.local",
3628 "https://mcp.test.local/mcp",
3629 "ci-bot",
3630 "mcp:read mcp:other",
3631 );
3632
3633 let identity = cache.validate_token(&token).await;
3634 assert!(identity.is_some(), "valid JWT should authenticate");
3635 let id = identity.unwrap();
3636 assert_eq!(id.name, "ci-bot");
3637 assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
3639 }
3640
3641 #[tokio::test]
3642 async fn wrong_issuer_rejected() {
3643 let kid = "test-key-2";
3644 let (pem, jwks) = generate_test_keypair(kid);
3645
3646 let mock_server = wiremock::MockServer::start().await;
3647 wiremock::Mock::given(wiremock::matchers::method("GET"))
3648 .and(wiremock::matchers::path("/jwks.json"))
3649 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3650 .mount(&mock_server)
3651 .await;
3652
3653 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3654 let config = test_config(&jwks_uri);
3655 let cache = test_cache(&config);
3656
3657 let token = mint_token(
3658 &pem,
3659 kid,
3660 "https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
3662 "attacker",
3663 "mcp:admin",
3664 );
3665
3666 assert!(cache.validate_token(&token).await.is_none());
3667 }
3668
3669 #[tokio::test]
3670 async fn wrong_audience_rejected() {
3671 let kid = "test-key-3";
3672 let (pem, jwks) = generate_test_keypair(kid);
3673
3674 let mock_server = wiremock::MockServer::start().await;
3675 wiremock::Mock::given(wiremock::matchers::method("GET"))
3676 .and(wiremock::matchers::path("/jwks.json"))
3677 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3678 .mount(&mock_server)
3679 .await;
3680
3681 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3682 let config = test_config(&jwks_uri);
3683 let cache = test_cache(&config);
3684
3685 let token = mint_token(
3686 &pem,
3687 kid,
3688 "https://auth.test.local",
3689 "https://wrong-audience.example.com", "attacker",
3691 "mcp:admin",
3692 );
3693
3694 assert!(cache.validate_token(&token).await.is_none());
3695 }
3696
3697 #[tokio::test]
3698 async fn expired_jwt_rejected() {
3699 let kid = "test-key-4";
3700 let (pem, jwks) = generate_test_keypair(kid);
3701
3702 let mock_server = wiremock::MockServer::start().await;
3703 wiremock::Mock::given(wiremock::matchers::method("GET"))
3704 .and(wiremock::matchers::path("/jwks.json"))
3705 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3706 .mount(&mock_server)
3707 .await;
3708
3709 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3710 let config = test_config(&jwks_uri);
3711 let cache = test_cache(&config);
3712
3713 let encoding_key =
3715 jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
3716 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3717 header.kid = Some(kid.into());
3718 let now = jsonwebtoken::get_current_timestamp();
3719 let claims = serde_json::json!({
3720 "iss": "https://auth.test.local",
3721 "aud": "https://mcp.test.local/mcp",
3722 "sub": "expired-bot",
3723 "scope": "mcp:read",
3724 "exp": now - 120,
3725 "iat": now - 3720,
3726 });
3727 let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
3728
3729 assert!(cache.validate_token(&token).await.is_none());
3730 }
3731
3732 #[tokio::test]
3733 async fn no_matching_scope_rejected() {
3734 let kid = "test-key-5";
3735 let (pem, jwks) = generate_test_keypair(kid);
3736
3737 let mock_server = wiremock::MockServer::start().await;
3738 wiremock::Mock::given(wiremock::matchers::method("GET"))
3739 .and(wiremock::matchers::path("/jwks.json"))
3740 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3741 .mount(&mock_server)
3742 .await;
3743
3744 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3745 let config = test_config(&jwks_uri);
3746 let cache = test_cache(&config);
3747
3748 let token = mint_token(
3749 &pem,
3750 kid,
3751 "https://auth.test.local",
3752 "https://mcp.test.local/mcp",
3753 "limited-bot",
3754 "some:other:scope", );
3756
3757 assert!(cache.validate_token(&token).await.is_none());
3758 }
3759
3760 #[tokio::test]
3761 async fn wrong_signing_key_rejected() {
3762 let kid = "test-key-6";
3763 let (_pem, jwks) = generate_test_keypair(kid);
3764
3765 let (attacker_pem, _) = generate_test_keypair(kid);
3767
3768 let mock_server = wiremock::MockServer::start().await;
3769 wiremock::Mock::given(wiremock::matchers::method("GET"))
3770 .and(wiremock::matchers::path("/jwks.json"))
3771 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3772 .mount(&mock_server)
3773 .await;
3774
3775 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3776 let config = test_config(&jwks_uri);
3777 let cache = test_cache(&config);
3778
3779 let token = mint_token(
3781 &attacker_pem,
3782 kid,
3783 "https://auth.test.local",
3784 "https://mcp.test.local/mcp",
3785 "attacker",
3786 "mcp:admin",
3787 );
3788
3789 assert!(cache.validate_token(&token).await.is_none());
3790 }
3791
3792 #[tokio::test]
3793 async fn admin_scope_maps_to_ops_role() {
3794 let kid = "test-key-7";
3795 let (pem, jwks) = generate_test_keypair(kid);
3796
3797 let mock_server = wiremock::MockServer::start().await;
3798 wiremock::Mock::given(wiremock::matchers::method("GET"))
3799 .and(wiremock::matchers::path("/jwks.json"))
3800 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3801 .mount(&mock_server)
3802 .await;
3803
3804 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3805 let config = test_config(&jwks_uri);
3806 let cache = test_cache(&config);
3807
3808 let token = mint_token(
3809 &pem,
3810 kid,
3811 "https://auth.test.local",
3812 "https://mcp.test.local/mcp",
3813 "admin-bot",
3814 "mcp:admin",
3815 );
3816
3817 let id = cache
3818 .validate_token(&token)
3819 .await
3820 .expect("should authenticate");
3821 assert_eq!(id.role, "ops");
3822 assert_eq!(id.name, "admin-bot");
3823 }
3824
3825 #[tokio::test]
3826 async fn jwks_server_down_returns_none() {
3827 let config = test_config("http://127.0.0.1:1/jwks.json");
3829 let cache = test_cache(&config);
3830
3831 let kid = "orphan-key";
3832 let (pem, _) = generate_test_keypair(kid);
3833 let token = mint_token(
3834 &pem,
3835 kid,
3836 "https://auth.test.local",
3837 "https://mcp.test.local/mcp",
3838 "bot",
3839 "mcp:read",
3840 );
3841
3842 assert!(cache.validate_token(&token).await.is_none());
3843 }
3844
3845 #[test]
3850 fn resolve_claim_path_flat_string() {
3851 let mut extra = HashMap::new();
3852 extra.insert(
3853 "scope".into(),
3854 serde_json::Value::String("mcp:read mcp:admin".into()),
3855 );
3856 let values = resolve_claim_path(&extra, "scope");
3857 assert_eq!(values, vec!["mcp:read", "mcp:admin"]);
3858 }
3859
3860 #[test]
3861 fn resolve_claim_path_flat_array() {
3862 let mut extra = HashMap::new();
3863 extra.insert(
3864 "roles".into(),
3865 serde_json::json!(["mcp-admin", "mcp-viewer"]),
3866 );
3867 let values = resolve_claim_path(&extra, "roles");
3868 assert_eq!(values, vec!["mcp-admin", "mcp-viewer"]);
3869 }
3870
3871 #[test]
3872 fn resolve_claim_path_nested_keycloak() {
3873 let mut extra = HashMap::new();
3874 extra.insert(
3875 "realm_access".into(),
3876 serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]}),
3877 );
3878 let values = resolve_claim_path(&extra, "realm_access.roles");
3879 assert_eq!(values, vec!["uma_authorization", "mcp-admin"]);
3880 }
3881
3882 #[test]
3883 fn resolve_claim_path_missing_returns_empty() {
3884 let extra = HashMap::new();
3885 assert!(resolve_claim_path(&extra, "nonexistent.path").is_empty());
3886 }
3887
3888 #[test]
3889 fn resolve_claim_path_numeric_leaf_returns_empty() {
3890 let mut extra = HashMap::new();
3891 extra.insert("count".into(), serde_json::json!(42));
3892 assert!(resolve_claim_path(&extra, "count").is_empty());
3893 }
3894
3895 fn make_claims(json: serde_json::Value) -> Claims {
3896 serde_json::from_value(json).expect("test claims must deserialize")
3897 }
3898
3899 #[test]
3900 fn first_class_scope_claim_splits_on_whitespace() {
3901 let claims = make_claims(serde_json::json!({
3902 "iss": "https://issuer.example.com",
3903 "exp": 9_999_999_999_u64,
3904 "scope": "read write admin",
3905 }));
3906 let values = first_class_claim_values(&claims, "scope");
3907 assert_eq!(values, vec!["read", "write", "admin"]);
3908 }
3909
3910 #[test]
3911 fn first_class_sub_claim_returns_single_value() {
3912 let claims = make_claims(serde_json::json!({
3913 "iss": "https://issuer.example.com",
3914 "exp": 9_999_999_999_u64,
3915 "sub": "service-account-orders",
3916 }));
3917 let values = first_class_claim_values(&claims, "sub");
3918 assert_eq!(values, vec!["service-account-orders"]);
3919 }
3920
3921 #[test]
3922 fn first_class_aud_claim_returns_every_audience() {
3923 let claims = make_claims(serde_json::json!({
3924 "iss": "https://issuer.example.com",
3925 "exp": 9_999_999_999_u64,
3926 "aud": ["api-a", "api-b"],
3927 }));
3928 let values = first_class_claim_values(&claims, "aud");
3929 assert_eq!(values, vec!["api-a", "api-b"]);
3930 }
3931
3932 #[test]
3933 fn first_class_unknown_path_returns_empty() {
3934 let claims = make_claims(serde_json::json!({
3935 "iss": "https://issuer.example.com",
3936 "exp": 9_999_999_999_u64,
3937 }));
3938 assert!(first_class_claim_values(&claims, "realm_access.roles").is_empty());
3939 }
3940
3941 fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
3947 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3948 .expect("encoding key from PEM");
3949 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3950 header.kid = Some(kid.into());
3951 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3952 }
3953
3954 fn test_config_with_role_claim(
3955 jwks_uri: &str,
3956 role_claim: &str,
3957 role_mappings: Vec<RoleMapping>,
3958 ) -> OAuthConfig {
3959 OAuthConfig {
3960 issuer: "https://auth.test.local".into(),
3961 audience: "https://mcp.test.local/mcp".into(),
3962 jwks_uri: jwks_uri.into(),
3963 scopes: vec![],
3964 role_claim: Some(role_claim.into()),
3965 role_mappings,
3966 jwks_cache_ttl: "5m".into(),
3967 proxy: None,
3968 token_exchange: None,
3969 ca_cert_path: None,
3970 allow_http_oauth_urls: true,
3971 max_jwks_keys: default_max_jwks_keys(),
3972 #[allow(
3973 deprecated,
3974 reason = "test fixture: explicit value for the deprecated field"
3975 )]
3976 strict_audience_validation: false,
3977 audience_validation_mode: None,
3978 jwks_max_response_bytes: default_jwks_max_bytes(),
3979 ssrf_allowlist: None,
3980 }
3981 }
3982
3983 #[tokio::test]
3984 async fn screen_oauth_target_rejects_literal_ip() {
3985 let err = screen_oauth_target(
3986 "https://127.0.0.1/jwks.json",
3987 false,
3988 &crate::ssrf::CompiledSsrfAllowlist::default(),
3989 )
3990 .await
3991 .expect_err("literal IPs must be rejected");
3992 let msg = err.to_string();
3993 assert!(msg.contains("literal IPv4 addresses are forbidden"));
3994 }
3995
3996 #[tokio::test]
3997 async fn screen_oauth_target_rejects_private_dns_resolution() {
3998 let err = screen_oauth_target(
3999 "https://localhost/jwks.json",
4000 false,
4001 &crate::ssrf::CompiledSsrfAllowlist::default(),
4002 )
4003 .await
4004 .expect_err("localhost resolution must be rejected");
4005 let msg = err.to_string();
4006 assert!(
4007 msg.contains("blocked IP") && msg.contains("loopback"),
4008 "got {msg:?}"
4009 );
4010 }
4011
4012 #[tokio::test]
4013 async fn screen_oauth_target_rejects_literal_ip_even_with_allow_http() {
4014 let err = screen_oauth_target(
4015 "http://127.0.0.1/jwks.json",
4016 true,
4017 &crate::ssrf::CompiledSsrfAllowlist::default(),
4018 )
4019 .await
4020 .expect_err("literal IPs must still be rejected when http is allowed");
4021 let msg = err.to_string();
4022 assert!(msg.contains("literal IPv4 addresses are forbidden"));
4023 }
4024
4025 #[tokio::test]
4026 async fn screen_oauth_target_rejects_private_dns_even_with_allow_http() {
4027 let err = screen_oauth_target(
4028 "http://localhost/jwks.json",
4029 true,
4030 &crate::ssrf::CompiledSsrfAllowlist::default(),
4031 )
4032 .await
4033 .expect_err("private DNS resolution must still be rejected when http is allowed");
4034 let msg = err.to_string();
4035 assert!(
4036 msg.contains("blocked IP") && msg.contains("loopback"),
4037 "got {msg:?}"
4038 );
4039 }
4040
4041 #[tokio::test]
4042 async fn screen_oauth_target_allows_public_hostname() {
4043 screen_oauth_target(
4044 "https://example.com/.well-known/jwks.json",
4045 false,
4046 &crate::ssrf::CompiledSsrfAllowlist::default(),
4047 )
4048 .await
4049 .expect("public hostname should pass screening");
4050 }
4051
4052 fn make_allowlist(hosts: &[&str], cidrs: &[&str]) -> crate::ssrf::CompiledSsrfAllowlist {
4058 let raw = OAuthSsrfAllowlist {
4059 hosts: hosts.iter().map(|s| (*s).to_string()).collect(),
4060 cidrs: cidrs.iter().map(|s| (*s).to_string()).collect(),
4061 };
4062 compile_oauth_ssrf_allowlist(&raw).expect("test allowlist compiles")
4063 }
4064
4065 #[test]
4066 fn compile_oauth_ssrf_allowlist_lowercases_and_dedupes_hosts() {
4067 let raw = OAuthSsrfAllowlist {
4068 hosts: vec!["RHBK.ops.example.com".into(), "rhbk.ops.example.com".into()],
4069 cidrs: vec![],
4070 };
4071 let compiled = compile_oauth_ssrf_allowlist(&raw).expect("compiles");
4072 assert_eq!(compiled.host_count(), 1);
4073 assert!(compiled.host_allowed("rhbk.ops.example.com"));
4074 assert!(compiled.host_allowed("RHBK.OPS.EXAMPLE.COM"));
4075 }
4076
4077 #[test]
4078 fn compile_oauth_ssrf_allowlist_rejects_literal_ip_in_hosts() {
4079 let raw = OAuthSsrfAllowlist {
4080 hosts: vec!["10.0.0.1".into()],
4081 cidrs: vec![],
4082 };
4083 let err = compile_oauth_ssrf_allowlist(&raw).expect_err("literal IP in hosts");
4084 assert!(err.contains("literal IPs are forbidden"), "got {err:?}");
4085 }
4086
/// A `host:port` entry in `hosts` must fail compilation; only bare DNS
/// hostnames are accepted.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_host_with_port() {
    let with_port = OAuthSsrfAllowlist {
        hosts: vec!["rhbk.ops.example.com:8443".into()],
        cidrs: Vec::new(),
    };
    let outcome = compile_oauth_ssrf_allowlist(&with_port);
    let err = outcome.expect_err("host:port");
    assert!(err.contains("must be a bare DNS hostname"), "got {err:?}");
}
4096
/// An unparsable CIDR must fail compilation and the error must point at the
/// offending index in `oauth.ssrf_allowlist.cidrs`.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_invalid_cidr() {
    let with_bad_cidr = OAuthSsrfAllowlist {
        hosts: Vec::new(),
        cidrs: vec!["not-a-cidr".into()],
    };
    let outcome = compile_oauth_ssrf_allowlist(&with_bad_cidr);
    let err = outcome.expect_err("invalid CIDR");
    assert!(err.contains("oauth.ssrf_allowlist.cidrs[0]"), "got {err:?}");
}
4106
/// `OAuthConfig::validate` must surface allowlist compilation failures and
/// mention `oauth.ssrf_allowlist` in the error text.
#[test]
fn validate_rejects_misconfigured_allowlist() {
    let builder = OAuthConfig::builder(
        "https://auth.example.com/",
        "mcp",
        "https://auth.example.com/jwks.json",
    );
    let mut cfg = builder.build();
    let bad_allowlist = OAuthSsrfAllowlist {
        hosts: vec!["10.0.0.1".into()],
        cidrs: Vec::new(),
    };
    cfg.ssrf_allowlist = Some(bad_allowlist);

    let outcome = cfg.validate();
    let err = outcome.expect_err("literal IP host must be rejected");
    let rendered = err.to_string();
    assert!(rendered.contains("oauth.ssrf_allowlist"), "got {err}");
}
4127
4128 #[tokio::test]
4129 async fn screen_oauth_target_with_allowlist_emits_helpful_error() {
4130 let allow = make_allowlist(&["other.example.com"], &["10.0.0.0/8"]);
4134 let err = screen_oauth_target("https://localhost/jwks.json", false, &allow)
4135 .await
4136 .expect_err("loopback must still be blocked when not in allowlist");
4137 let msg = err.to_string();
4138 assert!(msg.contains("OAuth target blocked"), "got {msg:?}");
4139 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4140 assert!(msg.contains("SECURITY.md"), "got {msg:?}");
4141 }
4142
4143 #[tokio::test]
4144 async fn screen_oauth_target_empty_allowlist_uses_legacy_message() {
4145 let err = screen_oauth_target(
4148 "https://localhost/jwks.json",
4149 false,
4150 &crate::ssrf::CompiledSsrfAllowlist::default(),
4151 )
4152 .await
4153 .expect_err("loopback rejection");
4154 let msg = err.to_string();
4155 assert!(msg.contains("blocked IP"), "got {msg:?}");
4156 assert!(msg.contains("loopback"), "got {msg:?}");
4157 assert!(!msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4159 }
4160
4161 #[tokio::test]
4162 async fn screen_oauth_target_allows_loopback_when_host_allowlisted() {
4163 let allow = make_allowlist(&["localhost"], &[]);
4165 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4166 .await
4167 .expect("allowlisted host must pass");
4168 }
4169
4170 #[tokio::test]
4171 async fn screen_oauth_target_allows_loopback_when_cidr_allowlisted() {
4172 let allow = make_allowlist(&[], &["127.0.0.0/8", "::1/128"]);
4175 screen_oauth_target("https://localhost/jwks.json", false, &allow)
4176 .await
4177 .expect("allowlisted CIDR must pass");
4178 }
4179
4180 #[tokio::test]
4181 async fn jwks_cache_rejects_misconfigured_allowlist_at_startup() {
4182 let mut cfg = OAuthConfig::builder(
4183 "https://auth.example.com/",
4184 "mcp",
4185 "https://auth.example.com/jwks.json",
4186 )
4187 .build();
4188 cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
4189 hosts: vec![],
4190 cidrs: vec!["bad-cidr".into()],
4191 });
4192 let Err(err) = JwksCache::new(&cfg) else {
4193 panic!("invalid CIDR must fail JwksCache::new")
4194 };
4195 let msg = err.to_string();
4196 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
4197 }
4198
4199 #[tokio::test]
4200 async fn audience_falls_back_to_azp_by_default() {
4201 let kid = "test-audience-azp-default";
4202 let (pem, jwks) = generate_test_keypair(kid);
4203
4204 let mock_server = wiremock::MockServer::start().await;
4205 wiremock::Mock::given(wiremock::matchers::method("GET"))
4206 .and(wiremock::matchers::path("/jwks.json"))
4207 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4208 .mount(&mock_server)
4209 .await;
4210
4211 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4212 let config = test_config(&jwks_uri);
4213 let cache = test_cache(&config);
4214
4215 let now = jsonwebtoken::get_current_timestamp();
4216 let token = mint_token_with_claims(
4217 &pem,
4218 kid,
4219 &serde_json::json!({
4220 "iss": "https://auth.test.local",
4221 "aud": "https://some-other-resource.example.com",
4222 "azp": "https://mcp.test.local/mcp",
4223 "sub": "compat-client",
4224 "scope": "mcp:read",
4225 "exp": now + 3600,
4226 "iat": now,
4227 }),
4228 );
4229
4230 let identity = cache
4231 .validate_token_with_reason(&token)
4232 .await
4233 .expect("azp fallback should remain enabled by default");
4234 assert_eq!(identity.role, "viewer");
4235 }
4236
4237 #[tokio::test]
4238 async fn strict_audience_validation_rejects_azp_only_match() {
4239 let kid = "test-audience-azp-strict";
4240 let (pem, jwks) = generate_test_keypair(kid);
4241
4242 let mock_server = wiremock::MockServer::start().await;
4243 wiremock::Mock::given(wiremock::matchers::method("GET"))
4244 .and(wiremock::matchers::path("/jwks.json"))
4245 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4246 .mount(&mock_server)
4247 .await;
4248
4249 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4250 let mut config = test_config(&jwks_uri);
4251 #[allow(deprecated, reason = "covers the legacy bool resolution path")]
4252 {
4253 config.strict_audience_validation = true;
4254 }
4255 let cache = test_cache(&config);
4256
4257 let now = jsonwebtoken::get_current_timestamp();
4258 let token = mint_token_with_claims(
4259 &pem,
4260 kid,
4261 &serde_json::json!({
4262 "iss": "https://auth.test.local",
4263 "aud": "https://some-other-resource.example.com",
4264 "azp": "https://mcp.test.local/mcp",
4265 "sub": "strict-client",
4266 "scope": "mcp:read",
4267 "exp": now + 3600,
4268 "iat": now,
4269 }),
4270 );
4271
4272 let failure = cache
4273 .validate_token_with_reason(&token)
4274 .await
4275 .expect_err("strict audience validation must ignore azp fallback");
4276 assert_eq!(failure, JwtValidationFailure::Invalid);
4277 }
4278
4279 #[tokio::test]
4280 async fn warn_mode_accepts_azp_only_match_and_warns_once() {
4281 let kid = "test-audience-warn-mode";
4282 let (pem, jwks) = generate_test_keypair(kid);
4283
4284 let mock_server = wiremock::MockServer::start().await;
4285 wiremock::Mock::given(wiremock::matchers::method("GET"))
4286 .and(wiremock::matchers::path("/jwks.json"))
4287 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4288 .mount(&mock_server)
4289 .await;
4290
4291 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4292 let mut config = test_config(&jwks_uri);
4293 config.audience_validation_mode = Some(AudienceValidationMode::Warn);
4294 let cache = test_cache(&config);
4295
4296 let now = jsonwebtoken::get_current_timestamp();
4297 let claims = serde_json::json!({
4298 "iss": "https://auth.test.local",
4299 "aud": "https://some-other-resource.example.com",
4300 "azp": "https://mcp.test.local/mcp",
4301 "sub": "warn-client",
4302 "scope": "mcp:read",
4303 "exp": now + 3600,
4304 "iat": now,
4305 });
4306 let token = mint_token_with_claims(&pem, kid, &claims);
4307
4308 let identity = cache
4309 .validate_token_with_reason(&token)
4310 .await
4311 .expect("warn mode must accept azp-only match");
4312 assert_eq!(identity.role, "viewer");
4313 assert!(
4314 cache.azp_fallback_warned.load(Ordering::Relaxed),
4315 "warn-once flag should be set after first azp-only match"
4316 );
4317
4318 let token2 = mint_token_with_claims(&pem, kid, &claims);
4319 cache
4320 .validate_token_with_reason(&token2)
4321 .await
4322 .expect("warn mode must continue accepting subsequent matches");
4323 assert!(
4324 cache.azp_fallback_warned.load(Ordering::Relaxed),
4325 "warn-once flag must remain set; the assertion guards against accidental clearing"
4326 );
4327 }
4328
4329 #[tokio::test]
4330 async fn permissive_mode_accepts_azp_only_match_silently() {
4331 let kid = "test-audience-permissive-mode";
4332 let (pem, jwks) = generate_test_keypair(kid);
4333
4334 let mock_server = wiremock::MockServer::start().await;
4335 wiremock::Mock::given(wiremock::matchers::method("GET"))
4336 .and(wiremock::matchers::path("/jwks.json"))
4337 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4338 .mount(&mock_server)
4339 .await;
4340
4341 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4342 let mut config = test_config(&jwks_uri);
4343 config.audience_validation_mode = Some(AudienceValidationMode::Permissive);
4344 let cache = test_cache(&config);
4345
4346 let now = jsonwebtoken::get_current_timestamp();
4347 let token = mint_token_with_claims(
4348 &pem,
4349 kid,
4350 &serde_json::json!({
4351 "iss": "https://auth.test.local",
4352 "aud": "https://some-other-resource.example.com",
4353 "azp": "https://mcp.test.local/mcp",
4354 "sub": "permissive-client",
4355 "scope": "mcp:read",
4356 "exp": now + 3600,
4357 "iat": now,
4358 }),
4359 );
4360
4361 cache
4362 .validate_token_with_reason(&token)
4363 .await
4364 .expect("permissive mode must accept azp-only match");
4365 assert!(
4366 !cache.azp_fallback_warned.load(Ordering::Relaxed),
4367 "permissive mode must not flip the warn-once flag"
4368 );
4369 }
4370
/// An explicit `audience_validation_mode` takes precedence over the legacy
/// `strict_audience_validation` bool, whichever way the bool is set.
#[test]
fn audience_validation_mode_overrides_legacy_bool() {
    // Legacy bool says "not strict", explicit mode says Strict.
    let mut legacy_false = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
    {
        legacy_false.strict_audience_validation = false;
    }
    legacy_false.audience_validation_mode = Some(AudienceValidationMode::Strict);
    let resolved = legacy_false.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Strict,
        "explicit mode must override legacy false"
    );

    // Legacy bool says "strict", explicit mode says Permissive.
    let mut legacy_true = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
    {
        legacy_true.strict_audience_validation = true;
    }
    legacy_true.audience_validation_mode = Some(AudienceValidationMode::Permissive);
    let resolved = legacy_true.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Permissive,
        "explicit mode must override legacy true"
    );
}
4397
/// A default `OAuthConfig` resolves its effective audience mode to `Warn`.
#[test]
fn audience_validation_mode_default_is_warn_when_unset() {
    let resolved = OAuthConfig::default().effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Warn,
        "unset mode + unset bool must resolve to Warn (the new default)"
    );
}
4407
/// Setting only the legacy `strict_audience_validation` bool to `true`
/// resolves the effective mode to `Strict`.
#[test]
fn audience_validation_legacy_bool_true_resolves_to_strict() {
    let mut legacy_only = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the legacy bool resolution path")]
    {
        legacy_only.strict_audience_validation = true;
    }
    let resolved = legacy_only.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Strict,
        "legacy bool=true must resolve to Strict for backward compat"
    );
}
4421
/// Shared in-memory byte buffer that collects everything written through
/// [`CapturedLogsWriter`] handles, so tests can inspect emitted output.
#[derive(Clone, Default)]
struct CapturedLogs(Arc<std::sync::Mutex<Vec<u8>>>);

impl CapturedLogs {
    /// Returns everything captured so far as text.
    ///
    /// A poisoned lock is recovered rather than treated as "no output", and
    /// invalid UTF-8 is replaced with U+FFFD instead of discarding the whole
    /// capture, so an assertion on the contents fails with useful evidence
    /// rather than against a misleading empty string.
    fn contents(&self) -> String {
        let guard = self
            .0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        String::from_utf8_lossy(&guard).into_owned()
    }
}

/// Write handle that appends into the buffer shared with a [`CapturedLogs`].
struct CapturedLogsWriter(Arc<std::sync::Mutex<Vec<u8>>>);

impl std::io::Write for CapturedLogsWriter {
    /// Appends `buf` to the shared buffer and reports the full length as written.
    ///
    /// A poisoned lock is recovered so bytes are never silently dropped while
    /// still being reported as written.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .extend_from_slice(buf);
        Ok(buf.len())
    }

    /// Nothing is buffered outside the shared vector, so flushing is a no-op.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
4446
4447 impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CapturedLogs {
4448 type Writer = CapturedLogsWriter;
4449
4450 fn make_writer(&'a self) -> Self::Writer {
4451 CapturedLogsWriter(Arc::clone(&self.0))
4452 }
4453 }
4454
4455 #[tokio::test]
4456 async fn jwks_response_size_cap_returns_none_and_logs_warning() {
4457 let kid = "oversized-jwks";
4458 let (_pem, jwks) = generate_test_keypair(kid);
4459 let mut oversized_body = serde_json::to_string(&jwks).expect("jwks json");
4460 oversized_body.push_str(&" ".repeat(4096));
4461
4462 let mock_server = wiremock::MockServer::start().await;
4463 wiremock::Mock::given(wiremock::matchers::method("GET"))
4464 .and(wiremock::matchers::path("/jwks.json"))
4465 .respond_with(
4466 wiremock::ResponseTemplate::new(200)
4467 .insert_header("content-type", "application/json")
4468 .set_body_string(oversized_body),
4469 )
4470 .mount(&mock_server)
4471 .await;
4472
4473 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4474 let mut config = test_config(&jwks_uri);
4475 config.jwks_max_response_bytes = 256;
4476 let cache = test_cache(&config);
4477
4478 let logs = CapturedLogs::default();
4479 let subscriber = tracing_subscriber::fmt()
4480 .with_writer(logs.clone())
4481 .with_ansi(false)
4482 .without_time()
4483 .finish();
4484 let _guard = tracing::subscriber::set_default(subscriber);
4485
4486 let result = cache.fetch_jwks().await;
4487 assert!(result.is_none(), "oversized JWKS must be dropped");
4488 assert!(
4489 logs.contents()
4490 .contains("JWKS response exceeded configured size cap"),
4491 "expected cap-exceeded warning in logs"
4492 );
4493 }
4494
4495 #[tokio::test]
4496 async fn role_claim_keycloak_nested_array() {
4497 let kid = "test-role-1";
4498 let (pem, jwks) = generate_test_keypair(kid);
4499
4500 let mock_server = wiremock::MockServer::start().await;
4501 wiremock::Mock::given(wiremock::matchers::method("GET"))
4502 .and(wiremock::matchers::path("/jwks.json"))
4503 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4504 .mount(&mock_server)
4505 .await;
4506
4507 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4508 let config = test_config_with_role_claim(
4509 &jwks_uri,
4510 "realm_access.roles",
4511 vec![
4512 RoleMapping {
4513 claim_value: "mcp-admin".into(),
4514 role: "ops".into(),
4515 },
4516 RoleMapping {
4517 claim_value: "mcp-viewer".into(),
4518 role: "viewer".into(),
4519 },
4520 ],
4521 );
4522 let cache = test_cache(&config);
4523
4524 let now = jsonwebtoken::get_current_timestamp();
4525 let token = mint_token_with_claims(
4526 &pem,
4527 kid,
4528 &serde_json::json!({
4529 "iss": "https://auth.test.local",
4530 "aud": "https://mcp.test.local/mcp",
4531 "sub": "keycloak-user",
4532 "exp": now + 3600,
4533 "iat": now,
4534 "realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
4535 }),
4536 );
4537
4538 let id = cache
4539 .validate_token(&token)
4540 .await
4541 .expect("should authenticate");
4542 assert_eq!(id.name, "keycloak-user");
4543 assert_eq!(id.role, "ops");
4544 }
4545
4546 #[tokio::test]
4547 async fn role_claim_flat_roles_array() {
4548 let kid = "test-role-2";
4549 let (pem, jwks) = generate_test_keypair(kid);
4550
4551 let mock_server = wiremock::MockServer::start().await;
4552 wiremock::Mock::given(wiremock::matchers::method("GET"))
4553 .and(wiremock::matchers::path("/jwks.json"))
4554 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4555 .mount(&mock_server)
4556 .await;
4557
4558 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4559 let config = test_config_with_role_claim(
4560 &jwks_uri,
4561 "roles",
4562 vec![
4563 RoleMapping {
4564 claim_value: "MCP.Admin".into(),
4565 role: "ops".into(),
4566 },
4567 RoleMapping {
4568 claim_value: "MCP.Reader".into(),
4569 role: "viewer".into(),
4570 },
4571 ],
4572 );
4573 let cache = test_cache(&config);
4574
4575 let now = jsonwebtoken::get_current_timestamp();
4576 let token = mint_token_with_claims(
4577 &pem,
4578 kid,
4579 &serde_json::json!({
4580 "iss": "https://auth.test.local",
4581 "aud": "https://mcp.test.local/mcp",
4582 "sub": "azure-ad-user",
4583 "exp": now + 3600,
4584 "iat": now,
4585 "roles": ["MCP.Reader", "OtherApp.Admin"]
4586 }),
4587 );
4588
4589 let id = cache
4590 .validate_token(&token)
4591 .await
4592 .expect("should authenticate");
4593 assert_eq!(id.name, "azure-ad-user");
4594 assert_eq!(id.role, "viewer");
4595 }
4596
4597 #[tokio::test]
4598 async fn role_claim_no_matching_value_rejected() {
4599 let kid = "test-role-3";
4600 let (pem, jwks) = generate_test_keypair(kid);
4601
4602 let mock_server = wiremock::MockServer::start().await;
4603 wiremock::Mock::given(wiremock::matchers::method("GET"))
4604 .and(wiremock::matchers::path("/jwks.json"))
4605 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4606 .mount(&mock_server)
4607 .await;
4608
4609 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4610 let config = test_config_with_role_claim(
4611 &jwks_uri,
4612 "roles",
4613 vec![RoleMapping {
4614 claim_value: "mcp-admin".into(),
4615 role: "ops".into(),
4616 }],
4617 );
4618 let cache = test_cache(&config);
4619
4620 let now = jsonwebtoken::get_current_timestamp();
4621 let token = mint_token_with_claims(
4622 &pem,
4623 kid,
4624 &serde_json::json!({
4625 "iss": "https://auth.test.local",
4626 "aud": "https://mcp.test.local/mcp",
4627 "sub": "limited-user",
4628 "exp": now + 3600,
4629 "iat": now,
4630 "roles": ["some-other-role"]
4631 }),
4632 );
4633
4634 assert!(cache.validate_token(&token).await.is_none());
4635 }
4636
4637 #[tokio::test]
4638 async fn role_claim_space_separated_string() {
4639 let kid = "test-role-4";
4640 let (pem, jwks) = generate_test_keypair(kid);
4641
4642 let mock_server = wiremock::MockServer::start().await;
4643 wiremock::Mock::given(wiremock::matchers::method("GET"))
4644 .and(wiremock::matchers::path("/jwks.json"))
4645 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4646 .mount(&mock_server)
4647 .await;
4648
4649 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4650 let config = test_config_with_role_claim(
4651 &jwks_uri,
4652 "custom_scope",
4653 vec![
4654 RoleMapping {
4655 claim_value: "write".into(),
4656 role: "ops".into(),
4657 },
4658 RoleMapping {
4659 claim_value: "read".into(),
4660 role: "viewer".into(),
4661 },
4662 ],
4663 );
4664 let cache = test_cache(&config);
4665
4666 let now = jsonwebtoken::get_current_timestamp();
4667 let token = mint_token_with_claims(
4668 &pem,
4669 kid,
4670 &serde_json::json!({
4671 "iss": "https://auth.test.local",
4672 "aud": "https://mcp.test.local/mcp",
4673 "sub": "custom-client",
4674 "exp": now + 3600,
4675 "iat": now,
4676 "custom_scope": "read audit"
4677 }),
4678 );
4679
4680 let id = cache
4681 .validate_token(&token)
4682 .await
4683 .expect("should authenticate");
4684 assert_eq!(id.name, "custom-client");
4685 assert_eq!(id.role, "viewer");
4686 }
4687
4688 #[tokio::test]
4689 async fn scope_backward_compat_without_role_claim() {
4690 let kid = "test-compat-1";
4692 let (pem, jwks) = generate_test_keypair(kid);
4693
4694 let mock_server = wiremock::MockServer::start().await;
4695 wiremock::Mock::given(wiremock::matchers::method("GET"))
4696 .and(wiremock::matchers::path("/jwks.json"))
4697 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4698 .mount(&mock_server)
4699 .await;
4700
4701 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4702 let config = test_config(&jwks_uri); let cache = test_cache(&config);
4704
4705 let token = mint_token(
4706 &pem,
4707 kid,
4708 "https://auth.test.local",
4709 "https://mcp.test.local/mcp",
4710 "legacy-bot",
4711 "mcp:admin other:scope",
4712 );
4713
4714 let id = cache
4715 .validate_token(&token)
4716 .await
4717 .expect("should authenticate");
4718 assert_eq!(id.name, "legacy-bot");
4719 assert_eq!(id.role, "ops"); }
4721
4722 #[tokio::test]
4727 async fn jwks_refresh_deduplication() {
4728 let kid = "test-dedup";
4731 let (pem, jwks) = generate_test_keypair(kid);
4732
4733 let mock_server = wiremock::MockServer::start().await;
4734 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4735 .and(wiremock::matchers::path("/jwks.json"))
4736 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4737 .expect(1) .mount(&mock_server)
4739 .await;
4740
4741 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4742 let config = test_config(&jwks_uri);
4743 let cache = Arc::new(test_cache(&config));
4744
4745 let token = mint_token(
4747 &pem,
4748 kid,
4749 "https://auth.test.local",
4750 "https://mcp.test.local/mcp",
4751 "concurrent-bot",
4752 "mcp:read",
4753 );
4754
4755 let mut handles = Vec::new();
4756 for _ in 0..5 {
4757 let c = Arc::clone(&cache);
4758 let t = token.clone();
4759 handles.push(tokio::spawn(async move { c.validate_token(&t).await }));
4760 }
4761
4762 for h in handles {
4763 let result = h.await.unwrap();
4764 assert!(result.is_some(), "all concurrent requests should succeed");
4765 }
4766
4767 }
4769
4770 #[tokio::test]
4771 async fn jwks_refresh_cooldown_blocks_rapid_requests() {
4772 let kid = "test-cooldown";
4775 let (_pem, jwks) = generate_test_keypair(kid);
4776
4777 let mock_server = wiremock::MockServer::start().await;
4778 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
4779 .and(wiremock::matchers::path("/jwks.json"))
4780 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
4781 .expect(1) .mount(&mock_server)
4783 .await;
4784
4785 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
4786 let config = test_config(&jwks_uri);
4787 let cache = test_cache(&config);
4788
4789 let fake_token1 =
4791 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig";
4792 let _ = cache.validate_token(fake_token1).await;
4793
4794 let fake_token2 =
4797 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig";
4798 let _ = cache.validate_token(fake_token2).await;
4799
4800 let fake_token3 =
4802 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig";
4803 let _ = cache.validate_token(fake_token3).await;
4804
4805 }
4807
4808 fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
4811 OAuthProxyConfig {
4812 authorize_url: "https://example.invalid/auth".into(),
4813 token_url: token_url.into(),
4814 client_id: "mcp-client".into(),
4815 client_secret: Some(secrecy::SecretString::from("shh".to_owned())),
4816 introspection_url: None,
4817 revocation_url: None,
4818 expose_admin_endpoints: false,
4819 require_auth_on_admin_endpoints: false,
4820 allow_unauthenticated_admin_endpoints: false,
4821 }
4822 }
4823
4824 fn test_http_client() -> OauthHttpClient {
4827 rustls::crypto::ring::default_provider()
4828 .install_default()
4829 .ok();
4830 let config = OAuthConfig::builder(
4831 "https://auth.test.local",
4832 "https://mcp.test.local/mcp",
4833 "https://auth.test.local/.well-known/jwks.json",
4834 )
4835 .allow_http_oauth_urls(true)
4836 .build();
4837 OauthHttpClient::with_config(&config)
4838 .expect("build test http client")
4839 .__test_allow_loopback_ssrf()
4840 }
4841
4842 #[tokio::test]
4843 async fn introspect_proxies_and_injects_client_credentials() {
4844 use wiremock::matchers::{body_string_contains, method, path};
4845
4846 let mock_server = wiremock::MockServer::start().await;
4847 wiremock::Mock::given(method("POST"))
4848 .and(path("/introspect"))
4849 .and(body_string_contains("client_id=mcp-client"))
4850 .and(body_string_contains("client_secret=shh"))
4851 .and(body_string_contains("token=abc"))
4852 .respond_with(
4853 wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
4854 "active": true,
4855 "scope": "read"
4856 })),
4857 )
4858 .expect(1)
4859 .mount(&mock_server)
4860 .await;
4861
4862 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4863 proxy.introspection_url = Some(format!("{}/introspect", mock_server.uri()));
4864
4865 let http = test_http_client();
4866 let resp = handle_introspect(&http, &proxy, "token=abc").await;
4867 assert_eq!(resp.status(), 200);
4868 }
4869
4870 #[tokio::test]
4871 async fn introspect_returns_404_when_not_configured() {
4872 let proxy = proxy_cfg("https://example.invalid/token");
4873 let http = test_http_client();
4874 let resp = handle_introspect(&http, &proxy, "token=abc").await;
4875 assert_eq!(resp.status(), 404);
4876 }
4877
4878 #[tokio::test]
4879 async fn revoke_proxies_and_returns_upstream_status() {
4880 use wiremock::matchers::{method, path};
4881
4882 let mock_server = wiremock::MockServer::start().await;
4883 wiremock::Mock::given(method("POST"))
4884 .and(path("/revoke"))
4885 .respond_with(wiremock::ResponseTemplate::new(200))
4886 .expect(1)
4887 .mount(&mock_server)
4888 .await;
4889
4890 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4891 proxy.revocation_url = Some(format!("{}/revoke", mock_server.uri()));
4892
4893 let http = test_http_client();
4894 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4895 assert_eq!(resp.status(), 200);
4896 }
4897
4898 #[tokio::test]
4899 async fn revoke_returns_404_when_not_configured() {
4900 let proxy = proxy_cfg("https://example.invalid/token");
4901 let http = test_http_client();
4902 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4903 assert_eq!(resp.status(), 404);
4904 }
4905
/// Authorization-server metadata advertises introspection and revocation only
/// when `expose_admin_endpoints` is on and the matching upstream URL is set.
#[test]
fn metadata_advertises_endpoints_only_when_configured() {
    let base_url = "https://mcp.local";
    let mut cfg = test_config("https://auth.test.local/jwks.json");

    // No proxy configured at all: neither endpoint is advertised.
    let without_proxy = authorization_server_metadata(base_url, &cfg);
    assert!(without_proxy.get("introspection_endpoint").is_none());
    assert!(without_proxy.get("revocation_endpoint").is_none());

    // Upstream URLs are set but admin endpoints are not exposed.
    let mut proxy = proxy_cfg("https://upstream.local/token");
    proxy.introspection_url = Some("https://upstream.local/introspect".into());
    proxy.revocation_url = Some("https://upstream.local/revoke".into());
    cfg.proxy = Some(proxy);
    let not_exposed = authorization_server_metadata(base_url, &cfg);
    assert!(
        not_exposed.get("introspection_endpoint").is_none(),
        "introspection must not be advertised when expose_admin_endpoints=false"
    );
    assert!(
        not_exposed.get("revocation_endpoint").is_none(),
        "revocation must not be advertised when expose_admin_endpoints=false"
    );

    // Exposed, but with the revocation URL removed: only introspection shows.
    if let Some(proxy) = cfg.proxy.as_mut() {
        proxy.expose_admin_endpoints = true;
        proxy.revocation_url = None;
    }
    let introspection_only = authorization_server_metadata(base_url, &cfg);
    assert_eq!(
        introspection_only["introspection_endpoint"],
        serde_json::Value::String("https://mcp.local/introspect".into())
    );
    assert!(introspection_only.get("revocation_endpoint").is_none());

    // Revocation URL restored: the revocation endpoint is advertised too.
    if let Some(proxy) = cfg.proxy.as_mut() {
        proxy.revocation_url = Some("https://upstream.local/revoke".into());
    }
    let with_revocation = authorization_server_metadata(base_url, &cfg);
    assert_eq!(
        with_revocation["revocation_endpoint"],
        serde_json::Value::String("https://mcp.local/revoke".into())
    );
}
4952
4953 fn https_cfg_with_tx(tx: TokenExchangeConfig) -> OAuthConfig {
4956 let mut cfg = validation_https_config();
4957 cfg.token_exchange = Some(tx);
4958 cfg
4959 }
4960
4961 fn tx_with(
4962 client_secret: Option<&str>,
4963 client_cert: Option<ClientCertConfig>,
4964 ) -> TokenExchangeConfig {
4965 TokenExchangeConfig::new(
4966 "https://idp.example.com/token".into(),
4967 "client".into(),
4968 client_secret.map(|s| secrecy::SecretString::new(s.into())),
4969 client_cert,
4970 "downstream".into(),
4971 )
4972 }
4973
/// Token exchange configured with neither a secret nor a client certificate
/// must fail validation with an explanatory message.
#[test]
fn validate_rejects_token_exchange_without_client_auth() {
    let unauthenticated = tx_with(None, None);
    let cfg = https_cfg_with_tx(unauthenticated);
    let outcome = cfg.validate();
    let err = outcome.expect_err("token_exchange without client auth must be rejected");
    let msg = err.to_string();
    let explains = msg.contains("requires client authentication");
    assert!(
        explains,
        "error must explain missing client auth; got {msg:?}"
    );
}
4986
/// Supplying both a client secret and a client certificate must fail
/// validation with a message about mutual exclusion.
#[test]
fn validate_rejects_token_exchange_with_both_secret_and_cert() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let both = tx_with(Some("s"), Some(client_cert));
    let cfg = https_cfg_with_tx(both);
    let outcome = cfg.validate();
    let err = outcome.expect_err("client_secret + client_cert must be rejected");
    let msg = err.to_string();
    let explains = ["mutually", "exclusive"]
        .into_iter()
        .all(|word| msg.contains(word));
    assert!(explains, "error must explain mutual exclusion; got {msg:?}");
}
5003
/// Without the `oauth-mtls-client` feature, a client certificate config must
/// fail validation and the error must name the feature.
#[cfg(not(feature = "oauth-mtls-client"))]
#[test]
fn validate_rejects_client_cert_without_feature() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let cert_only = tx_with(None, Some(client_cert));
    let cfg = https_cfg_with_tx(cert_only);
    let outcome = cfg.validate();
    let err = outcome.expect_err("client_cert without feature must be rejected");
    let names_feature = err.to_string().contains("oauth-mtls-client");
    assert!(
        names_feature,
        "error must reference the cargo feature; got {err}"
    );
}
5020
/// With the mTLS feature enabled, nonexistent cert/key paths must fail
/// validation with an "unreadable" error.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_missing_client_cert_files() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let cert_only = tx_with(None, Some(client_cert));
    let cfg = https_cfg_with_tx(cert_only);
    let outcome = cfg.validate();
    let err = outcome.expect_err("missing cert file must be rejected");
    let calls_out_unreadable = err.to_string().contains("unreadable");
    assert!(
        calls_out_unreadable,
        "error must call out unreadable file; got {err}"
    );
}
5037
/// With the mTLS feature enabled, cert/key files that exist but do not hold
/// PEM data must fail validation with a "PEM parse failed" error.
///
/// The validation result is captured first and the temp files are removed
/// before any expectation runs, so a failing expectation cannot leak them.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_malformed_client_cert_pem() {
    let dir = std::env::temp_dir();
    let cert = dir.join(format!("rmcp-mtls-bad-cert-{}.pem", std::process::id()));
    let key = dir.join(format!("rmcp-mtls-bad-key-{}.pem", std::process::id()));
    std::fs::write(&cert, b"not a real PEM").expect("write tmp cert");
    std::fs::write(&key, b"not a real PEM either").expect("write tmp key");
    let cc = ClientCertConfig {
        cert_path: cert.clone(),
        key_path: key.clone(),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(cc)));
    let res = cfg.validate();
    // Best-effort cleanup: removal errors are ignored on purpose.
    let _ = std::fs::remove_file(&cert);
    let _ = std::fs::remove_file(&key);
    let err = res.expect_err("malformed PEM must be rejected");
    assert!(
        err.to_string().contains("PEM parse failed"),
        "error must call out PEM parse failure; got {err}"
    );
}
5059
/// Generates a self-signed certificate for `client.test` and writes the cert
/// and key as PEM files in the system temp directory.
///
/// Returns `(cert_path, key_path)`; the file names carry the process id and a
/// random nonce. Callers are responsible for removing the files.
#[cfg(feature = "oauth-mtls-client")]
fn write_self_signed_pem() -> (PathBuf, PathBuf) {
    let generated =
        rcgen::generate_simple_self_signed(vec!["client.test".into()]).expect("rcgen");
    let tmp = std::env::temp_dir();
    let pid = std::process::id();
    let nonce: u64 = rand::random();

    let cert_path = tmp.join(format!("rmcp-mtls-cert-{pid}-{nonce}.pem"));
    let key_path = tmp.join(format!("rmcp-mtls-key-{pid}-{nonce}.pem"));

    let cert_pem = generated.cert.pem();
    std::fs::write(&cert_path, cert_pem).expect("write cert");
    let key_pem = generated.signing_key.serialize_pem();
    std::fs::write(&key_path, key_pem).expect("write key");

    (cert_path, key_path)
}
5072
/// Installs the `ring` rustls crypto provider as the process default,
/// ignoring the error returned when a provider is already installed.
#[cfg(feature = "oauth-mtls-client")]
fn install_test_crypto_provider() {
    rustls::crypto::ring::default_provider()
        .install_default()
        .ok();
}
5077
/// A freshly generated self-signed cert/key pair must pass validation.
///
/// The temp files are removed before the result is asserted so they are
/// cleaned up even when the expectation fails.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_accepts_well_formed_client_cert() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();
    let client_cert = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };
    let cert_only = tx_with(None, Some(client_cert));
    let cfg = https_cfg_with_tx(cert_only);

    let outcome = cfg.validate();
    for path in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(path);
    }
    outcome.expect("well-formed cert+key must validate");
}
5093
/// `client_for` must return a different client for a cert-bearing
/// token-exchange config than for a secret-only one.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_returns_cached_mtls_client() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();
    let client_cert = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };
    let cfg = https_cfg_with_tx(tx_with(None, Some(client_cert)));
    let http = OauthHttpClient::with_config(&cfg).expect("build mtls client");

    let cert_tx = cfg.token_exchange.as_ref().expect("tx set");
    let cert_client = http.client_for(cert_tx);
    let secret_tx = tx_with(Some("s"), None);
    let inner_client = http.client_for(&secret_tx);

    // Both lookups are done; remove the temp files before asserting.
    for path in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(path);
    }

    let same_client = std::ptr::eq(cert_client, inner_client);
    assert!(
        !same_client,
        "client_for must return distinct clients for cert vs no-cert configs"
    );
}
5115
/// For a client certificate the HTTP client was not built with, `client_for`
/// must return the same client it returns for a secret-only config.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_falls_back_to_inner_when_cache_miss() {
    install_test_crypto_provider();
    let cfg = validation_https_config();
    let http = OauthHttpClient::with_config(&cfg).expect("build client");

    let unknown_cert = ClientCertConfig {
        cert_path: PathBuf::from("/cache/miss/cert.pem"),
        key_path: PathBuf::from("/cache/miss/key.pem"),
    };
    let unknown_tx = tx_with(None, Some(unknown_cert));
    let fallback = http.client_for(&unknown_tx);

    let secret_tx = tx_with(Some("s"), None);
    let inner = http.client_for(&secret_tx);

    let same_client = std::ptr::eq(fallback, inner);
    assert!(same_client, "cache miss must fall back to inner client");
}
5134}