1use std::{
17 collections::HashMap,
18 path::PathBuf,
19 sync::Arc,
20 time::{Duration, Instant},
21};
22
23use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
24use serde::Deserialize;
25use tokio::{net::lookup_host, sync::RwLock};
26
27use crate::auth::{AuthIdentity, AuthMethod};
28
29fn evaluate_oauth_redirect(
55 attempt: &reqwest::redirect::Attempt<'_>,
56 allow_http: bool,
57 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
58) -> Result<(), String> {
59 let prev_https = attempt
60 .previous()
61 .last()
62 .is_some_and(|prev| prev.scheme() == "https");
63 let target_url = attempt.url();
64 let dest_scheme = target_url.scheme();
65 if dest_scheme != "https" {
66 if prev_https {
67 return Err("redirect downgrades https -> http".to_owned());
68 }
69 if !allow_http || dest_scheme != "http" {
70 return Err("redirect to non-HTTP(S) URL refused".to_owned());
71 }
72 }
73 if let Some(reason) = crate::ssrf::redirect_target_reason_with_allowlist(target_url, allowlist)
74 {
75 return Err(format!("redirect target forbidden: {reason}"));
76 }
77 if attempt.previous().len() >= 2 {
78 return Err("too many redirects (max 2)".to_owned());
79 }
80 Ok(())
81}
82
83#[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
97async fn screen_oauth_target_with_test_override(
98 url: &str,
99 allow_http: bool,
100 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
101 #[cfg(any(test, feature = "test-helpers"))] test_allow_loopback_ssrf: bool,
102) -> Result<(), crate::error::McpxError> {
103 let parsed = check_oauth_url("oauth target", url, allow_http)?;
104 #[cfg(any(test, feature = "test-helpers"))]
105 if test_allow_loopback_ssrf {
106 return Ok(());
107 }
108 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
109 return Err(crate::error::McpxError::Config(format!(
110 "OAuth target forbidden ({reason}): {url}"
111 )));
112 }
113
114 let host = parsed.host_str().ok_or_else(|| {
115 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
116 })?;
117 let port = parsed.port_or_known_default().ok_or_else(|| {
118 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
119 })?;
120
121 let addrs = lookup_host((host, port)).await.map_err(|error| {
122 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
123 })?;
124
125 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
126 let mut any_addr = false;
127 for addr in addrs {
128 any_addr = true;
129 let ip = addr.ip();
130 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
131 if reason == "cloud_metadata" {
134 return Err(crate::error::McpxError::Config(format!(
135 "OAuth target resolved to blocked IP ({reason}): {url}"
136 )));
137 }
138 if allowlist.is_empty() {
142 return Err(crate::error::McpxError::Config(format!(
143 "OAuth target resolved to blocked IP ({reason}): {url}"
144 )));
145 }
146 if host_allowed || allowlist.ip_allowed(ip) {
148 continue;
149 }
150 return Err(crate::error::McpxError::Config(format!(
151 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
152 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
153 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
154 URL: {url}"
155 )));
156 }
157 }
158 if !any_addr {
159 return Err(crate::error::McpxError::Config(format!(
160 "OAuth target DNS resolution returned no addresses: {url}"
161 )));
162 }
163
164 Ok(())
165}
166
167async fn screen_oauth_target(
168 url: &str,
169 allow_http: bool,
170 allowlist: &crate::ssrf::CompiledSsrfAllowlist,
171) -> Result<(), crate::error::McpxError> {
172 #[cfg(any(test, feature = "test-helpers"))]
173 {
174 screen_oauth_target_with_test_override(url, allow_http, allowlist, false).await
175 }
176 #[cfg(not(any(test, feature = "test-helpers")))]
177 {
178 let parsed = check_oauth_url("oauth target", url, allow_http)?;
179 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
180 return Err(crate::error::McpxError::Config(format!(
181 "OAuth target forbidden ({reason}): {url}"
182 )));
183 }
184
185 let host = parsed.host_str().ok_or_else(|| {
186 crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
187 })?;
188 let port = parsed.port_or_known_default().ok_or_else(|| {
189 crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
190 })?;
191
192 let addrs = lookup_host((host, port)).await.map_err(|error| {
193 crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
194 })?;
195
196 let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
197 let mut any_addr = false;
198 for addr in addrs {
199 any_addr = true;
200 let ip = addr.ip();
201 if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
202 if reason == "cloud_metadata" {
203 return Err(crate::error::McpxError::Config(format!(
204 "OAuth target resolved to blocked IP ({reason}): {url}"
205 )));
206 }
207 if allowlist.is_empty() {
208 return Err(crate::error::McpxError::Config(format!(
209 "OAuth target resolved to blocked IP ({reason}): {url}"
210 )));
211 }
212 if host_allowed || allowlist.ip_allowed(ip) {
213 continue;
214 }
215 return Err(crate::error::McpxError::Config(format!(
216 "OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
217 To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
218 to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
219 URL: {url}"
220 )));
221 }
222 }
223 if !any_addr {
224 return Err(crate::error::McpxError::Config(format!(
225 "OAuth target DNS resolution returned no addresses: {url}"
226 )));
227 }
228
229 Ok(())
230 }
231}
232
233#[derive(Clone)]
274pub struct OauthHttpClient {
275 inner: reqwest::Client,
276 allow_http: bool,
277 allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
282 #[cfg(any(test, feature = "test-helpers"))]
283 test_allow_loopback_ssrf: bool,
284}
285
286impl OauthHttpClient {
287 pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
305 Self::build(Some(config))
306 }
307
308 #[deprecated(
331 since = "1.2.1",
332 note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
333 )]
334 pub fn new() -> Result<Self, crate::error::McpxError> {
335 Self::build(None)
336 }
337
338 fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
341 let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);
342
343 let allowlist = match config.and_then(|c| c.ssrf_allowlist.as_ref()) {
348 Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
349 crate::error::McpxError::Startup(format!("oauth http client: {e}"))
350 })?),
351 None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
352 };
353
354 let redirect_allowlist = Arc::clone(&allowlist);
357
358 let mut builder = reqwest::Client::builder()
359 .connect_timeout(Duration::from_secs(10))
360 .timeout(Duration::from_secs(30))
361 .redirect(reqwest::redirect::Policy::custom(move |attempt| {
362 match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
372 Ok(()) => attempt.follow(),
373 Err(reason) => {
374 tracing::warn!(
375 reason = %reason,
376 target = %attempt.url(),
377 "oauth redirect rejected"
378 );
379 attempt.error(reason)
380 }
381 }
382 }));
383
384 if let Some(cfg) = config
385 && let Some(ref ca_path) = cfg.ca_cert_path
386 {
387 let pem = std::fs::read(ca_path).map_err(|e| {
392 crate::error::McpxError::Startup(format!(
393 "oauth http client: read ca_cert_path {}: {e}",
394 ca_path.display()
395 ))
396 })?;
397 let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
398 crate::error::McpxError::Startup(format!(
399 "oauth http client: parse ca_cert_path {}: {e}",
400 ca_path.display()
401 ))
402 })?;
403 builder = builder.add_root_certificate(cert);
404 }
405
406 let inner = builder.build().map_err(|e| {
407 crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
408 })?;
409 Ok(Self {
410 inner,
411 allow_http,
412 allowlist,
413 #[cfg(any(test, feature = "test-helpers"))]
414 test_allow_loopback_ssrf: false,
415 })
416 }
417
418 async fn send_screened(
419 &self,
420 url: &str,
421 request: reqwest::RequestBuilder,
422 ) -> Result<reqwest::Response, crate::error::McpxError> {
423 #[cfg(any(test, feature = "test-helpers"))]
424 if self.test_allow_loopback_ssrf {
425 screen_oauth_target_with_test_override(url, self.allow_http, &self.allowlist, true)
426 .await?;
427 } else {
428 screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
429 }
430 #[cfg(not(any(test, feature = "test-helpers")))]
431 screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
432 request.send().await.map_err(|error| {
433 crate::error::McpxError::Config(format!("oauth request {url}: {error}"))
434 })
435 }
436
437 #[cfg(any(test, feature = "test-helpers"))]
442 #[doc(hidden)]
443 #[must_use]
444 pub fn __test_allow_loopback_ssrf(mut self) -> Self {
445 self.test_allow_loopback_ssrf = true;
446 self
447 }
448
449 #[doc(hidden)]
455 pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
456 self.inner.get(url).send().await
457 }
458}
459
460impl std::fmt::Debug for OauthHttpClient {
461 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462 f.debug_struct("OauthHttpClient").finish_non_exhaustive()
463 }
464}
465
466#[derive(Debug, Clone, Default, Deserialize)]
526#[non_exhaustive]
527pub struct OAuthSsrfAllowlist {
528 #[serde(default)]
533 pub hosts: Vec<String>,
534 #[serde(default)]
540 pub cidrs: Vec<String>,
541}
542
543fn compile_oauth_ssrf_allowlist(
550 raw: &OAuthSsrfAllowlist,
551) -> Result<crate::ssrf::CompiledSsrfAllowlist, String> {
552 let mut hosts: Vec<String> = Vec::with_capacity(raw.hosts.len());
553 for (idx, entry) in raw.hosts.iter().enumerate() {
554 let trimmed = entry.trim();
555 if trimmed.is_empty() {
556 return Err(format!("oauth.ssrf_allowlist.hosts[{idx}]: empty entry"));
557 }
558 if trimmed.contains([':', '/', '@', '?', '#']) {
562 return Err(format!(
563 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: must be a bare DNS hostname \
564 (no scheme, port, path, userinfo, query, or fragment)"
565 ));
566 }
567 match url::Host::parse(trimmed) {
568 Ok(url::Host::Domain(_)) => {}
569 Ok(url::Host::Ipv4(_) | url::Host::Ipv6(_)) => {
570 return Err(format!(
571 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: literal IPs are forbidden \
572 here -- list them via oauth.ssrf_allowlist.cidrs instead"
573 ));
574 }
575 Err(e) => {
576 return Err(format!(
577 "oauth.ssrf_allowlist.hosts[{idx}] = {trimmed:?}: invalid hostname: {e}"
578 ));
579 }
580 }
581 hosts.push(trimmed.to_ascii_lowercase());
582 }
583 hosts.sort();
584 hosts.dedup();
585
586 let mut cidrs = Vec::with_capacity(raw.cidrs.len());
587 for (idx, entry) in raw.cidrs.iter().enumerate() {
588 let parsed = crate::ssrf::CidrEntry::parse(entry)
589 .map_err(|e| format!("oauth.ssrf_allowlist.cidrs[{idx}]: {e}"))?;
590 cidrs.push(parsed);
591 }
592
593 Ok(crate::ssrf::CompiledSsrfAllowlist::new(hosts, cidrs))
594}
595
596#[derive(Debug, Clone, Deserialize)]
598#[non_exhaustive]
599pub struct OAuthConfig {
600 pub issuer: String,
602 pub audience: String,
604 pub jwks_uri: String,
606 #[serde(default)]
609 pub scopes: Vec<ScopeMapping>,
610 pub role_claim: Option<String>,
616 #[serde(default)]
619 pub role_mappings: Vec<RoleMapping>,
620 #[serde(default = "default_jwks_cache_ttl")]
623 pub jwks_cache_ttl: String,
624 pub proxy: Option<OAuthProxyConfig>,
628 pub token_exchange: Option<TokenExchangeConfig>,
633 #[serde(default)]
648 pub ca_cert_path: Option<PathBuf>,
649 #[serde(default)]
661 pub allow_http_oauth_urls: bool,
662 #[serde(default)]
671 pub ssrf_allowlist: Option<OAuthSsrfAllowlist>,
672 #[serde(default = "default_max_jwks_keys")]
676 pub max_jwks_keys: usize,
677 #[serde(default)]
684 pub strict_audience_validation: bool,
685 #[serde(default = "default_jwks_max_bytes")]
689 pub jwks_max_response_bytes: u64,
690}
691
/// Serde default for `OAuthConfig::jwks_cache_ttl`: ten minutes, in humantime syntax.
fn default_jwks_cache_ttl() -> String {
    String::from("10m")
}
695
/// Serde default for `OAuthConfig::max_jwks_keys`.
const fn default_max_jwks_keys() -> usize {
    256_usize
}
699
/// Serde default for `OAuthConfig::jwks_max_response_bytes`: 1 MiB.
const fn default_jwks_max_bytes() -> u64 {
    1 << 20
}
703
704impl Default for OAuthConfig {
705 fn default() -> Self {
706 Self {
707 issuer: String::new(),
708 audience: String::new(),
709 jwks_uri: String::new(),
710 scopes: Vec::new(),
711 role_claim: None,
712 role_mappings: Vec::new(),
713 jwks_cache_ttl: default_jwks_cache_ttl(),
714 proxy: None,
715 token_exchange: None,
716 ca_cert_path: None,
717 allow_http_oauth_urls: false,
718 max_jwks_keys: default_max_jwks_keys(),
719 strict_audience_validation: false,
720 jwks_max_response_bytes: default_jwks_max_bytes(),
721 ssrf_allowlist: None,
722 }
723 }
724}
725
726impl OAuthConfig {
727 pub fn builder(
733 issuer: impl Into<String>,
734 audience: impl Into<String>,
735 jwks_uri: impl Into<String>,
736 ) -> OAuthConfigBuilder {
737 OAuthConfigBuilder {
738 inner: Self {
739 issuer: issuer.into(),
740 audience: audience.into(),
741 jwks_uri: jwks_uri.into(),
742 ..Self::default()
743 },
744 }
745 }
746
747 pub fn validate(&self) -> Result<(), crate::error::McpxError> {
763 let allow_http = self.allow_http_oauth_urls;
764 let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
765 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
766 return Err(crate::error::McpxError::Config(format!(
767 "oauth.issuer forbidden ({reason})"
768 )));
769 }
770 let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
771 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
772 return Err(crate::error::McpxError::Config(format!(
773 "oauth.jwks_uri forbidden ({reason})"
774 )));
775 }
776 if let Some(proxy) = &self.proxy {
777 let url = check_oauth_url(
778 "oauth.proxy.authorize_url",
779 &proxy.authorize_url,
780 allow_http,
781 )?;
782 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
783 return Err(crate::error::McpxError::Config(format!(
784 "oauth.proxy.authorize_url forbidden ({reason})"
785 )));
786 }
787 let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
788 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
789 return Err(crate::error::McpxError::Config(format!(
790 "oauth.proxy.token_url forbidden ({reason})"
791 )));
792 }
793 if let Some(url) = &proxy.introspection_url {
794 let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
795 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
796 return Err(crate::error::McpxError::Config(format!(
797 "oauth.proxy.introspection_url forbidden ({reason})"
798 )));
799 }
800 }
801 if let Some(url) = &proxy.revocation_url {
802 let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
803 if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
804 return Err(crate::error::McpxError::Config(format!(
805 "oauth.proxy.revocation_url forbidden ({reason})"
806 )));
807 }
808 }
809 }
810 if let Some(tx) = &self.token_exchange {
811 let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
812 if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
813 return Err(crate::error::McpxError::Config(format!(
814 "oauth.token_exchange.token_url forbidden ({reason})"
815 )));
816 }
817 }
818 if let Some(raw) = &self.ssrf_allowlist {
822 let compiled = compile_oauth_ssrf_allowlist(raw).map_err(|e| {
823 crate::error::McpxError::Config(format!("oauth.ssrf_allowlist: {e}"))
824 })?;
825 if !compiled.is_empty() {
826 tracing::warn!(
827 host_count = compiled.host_count(),
828 cidr_count = compiled.cidr_count(),
829 "oauth.ssrf_allowlist is configured: private/loopback OAuth/JWKS targets \
830 are now reachable. Cloud-metadata addresses remain blocked. \
831 See SECURITY.md \"Operator allowlist\"."
832 );
833 }
834 }
835 Ok(())
836 }
837}
838
839fn check_oauth_url(
846 field: &str,
847 raw: &str,
848 allow_http: bool,
849) -> Result<url::Url, crate::error::McpxError> {
850 let parsed = url::Url::parse(raw).map_err(|e| {
851 crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
852 })?;
853 if !parsed.username().is_empty() || parsed.password().is_some() {
854 return Err(crate::error::McpxError::Config(format!(
855 "{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
856 )));
857 }
858 match parsed.scheme() {
859 "https" => Ok(parsed),
860 "http" if allow_http => Ok(parsed),
861 "http" => Err(crate::error::McpxError::Config(format!(
862 "{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
863 to override - strongly discouraged in production)"
864 ))),
865 other => Err(crate::error::McpxError::Config(format!(
866 "{field}: must use https scheme (got {other:?})"
867 ))),
868 }
869}
870
871#[derive(Debug, Clone)]
877#[must_use = "builders do nothing until `.build()` is called"]
878pub struct OAuthConfigBuilder {
879 inner: OAuthConfig,
880}
881
882impl OAuthConfigBuilder {
883 pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
885 self.inner.scopes = scopes;
886 self
887 }
888
889 pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
891 self.inner.scopes.push(ScopeMapping {
892 scope: scope.into(),
893 role: role.into(),
894 });
895 self
896 }
897
898 pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
901 self.inner.role_claim = Some(claim.into());
902 self
903 }
904
905 pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
907 self.inner.role_mappings = mappings;
908 self
909 }
910
911 pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
914 self.inner.role_mappings.push(RoleMapping {
915 claim_value: claim_value.into(),
916 role: role.into(),
917 });
918 self
919 }
920
921 pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
924 self.inner.jwks_cache_ttl = ttl.into();
925 self
926 }
927
928 pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
931 self.inner.proxy = Some(proxy);
932 self
933 }
934
935 pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
937 self.inner.token_exchange = Some(token_exchange);
938 self
939 }
940
941 pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
946 self.inner.ca_cert_path = Some(path.into());
947 self
948 }
949
950 pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
956 self.inner.allow_http_oauth_urls = allow;
957 self
958 }
959
960 pub const fn strict_audience_validation(mut self, strict: bool) -> Self {
963 self.inner.strict_audience_validation = strict;
964 self
965 }
966
967 pub const fn jwks_max_response_bytes(mut self, bytes: u64) -> Self {
969 self.inner.jwks_max_response_bytes = bytes;
970 self
971 }
972
973 pub fn ssrf_allowlist(mut self, allowlist: OAuthSsrfAllowlist) -> Self {
981 self.inner.ssrf_allowlist = Some(allowlist);
982 self
983 }
984
985 #[must_use]
987 pub fn build(self) -> OAuthConfig {
988 self.inner
989 }
990}
991
992#[derive(Debug, Clone, Deserialize)]
994#[non_exhaustive]
995pub struct ScopeMapping {
996 pub scope: String,
998 pub role: String,
1000}
1001
1002#[derive(Debug, Clone, Deserialize)]
1006#[non_exhaustive]
1007pub struct RoleMapping {
1008 pub claim_value: String,
1010 pub role: String,
1012}
1013
1014#[derive(Debug, Clone, Deserialize)]
1021#[non_exhaustive]
1022pub struct TokenExchangeConfig {
1023 pub token_url: String,
1026 pub client_id: String,
1028 pub client_secret: Option<secrecy::SecretString>,
1031 pub client_cert: Option<ClientCertConfig>,
1035 pub audience: String,
1039}
1040
1041impl TokenExchangeConfig {
1042 #[must_use]
1044 pub fn new(
1045 token_url: String,
1046 client_id: String,
1047 client_secret: Option<secrecy::SecretString>,
1048 client_cert: Option<ClientCertConfig>,
1049 audience: String,
1050 ) -> Self {
1051 Self {
1052 token_url,
1053 client_id,
1054 client_secret,
1055 client_cert,
1056 audience,
1057 }
1058 }
1059}
1060
1061#[derive(Debug, Clone, Deserialize)]
1064#[non_exhaustive]
1065pub struct ClientCertConfig {
1066 pub cert_path: PathBuf,
1068 pub key_path: PathBuf,
1070}
1071
1072#[derive(Debug, Deserialize)]
1074#[non_exhaustive]
1075pub struct ExchangedToken {
1076 pub access_token: String,
1078 pub expires_in: Option<u64>,
1080 pub issued_token_type: Option<String>,
1083}
1084
1085#[derive(Debug, Clone, Deserialize, Default)]
1092#[non_exhaustive]
1093pub struct OAuthProxyConfig {
1094 pub authorize_url: String,
1097 pub token_url: String,
1100 pub client_id: String,
1102 pub client_secret: Option<secrecy::SecretString>,
1104 #[serde(default)]
1108 pub introspection_url: Option<String>,
1109 #[serde(default)]
1113 pub revocation_url: Option<String>,
1114 #[serde(default)]
1126 pub expose_admin_endpoints: bool,
1127 #[serde(default)]
1133 pub require_auth_on_admin_endpoints: bool,
1134}
1135
1136impl OAuthProxyConfig {
1137 pub fn builder(
1145 authorize_url: impl Into<String>,
1146 token_url: impl Into<String>,
1147 client_id: impl Into<String>,
1148 ) -> OAuthProxyConfigBuilder {
1149 OAuthProxyConfigBuilder {
1150 inner: Self {
1151 authorize_url: authorize_url.into(),
1152 token_url: token_url.into(),
1153 client_id: client_id.into(),
1154 ..Self::default()
1155 },
1156 }
1157 }
1158}
1159
1160#[derive(Debug, Clone)]
1166#[must_use = "builders do nothing until `.build()` is called"]
1167pub struct OAuthProxyConfigBuilder {
1168 inner: OAuthProxyConfig,
1169}
1170
1171impl OAuthProxyConfigBuilder {
1172 pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
1174 self.inner.client_secret = Some(secret);
1175 self
1176 }
1177
1178 pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
1182 self.inner.introspection_url = Some(url.into());
1183 self
1184 }
1185
1186 pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
1190 self.inner.revocation_url = Some(url.into());
1191 self
1192 }
1193
1194 pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
1202 self.inner.expose_admin_endpoints = expose;
1203 self
1204 }
1205
1206 pub const fn require_auth_on_admin_endpoints(mut self, require: bool) -> Self {
1209 self.inner.require_auth_on_admin_endpoints = require;
1210 self
1211 }
1212
1213 #[must_use]
1215 pub fn build(self) -> OAuthProxyConfig {
1216 self.inner
1217 }
1218}
1219
1220type JwksKeyCache = (
1228 HashMap<String, (Algorithm, DecodingKey)>,
1229 Vec<(Algorithm, DecodingKey)>,
1230);
1231
1232struct CachedKeys {
1233 keys: HashMap<String, (Algorithm, DecodingKey)>,
1235 unnamed_keys: Vec<(Algorithm, DecodingKey)>,
1237 fetched_at: Instant,
1238 ttl: Duration,
1239}
1240
1241impl CachedKeys {
1242 fn is_expired(&self) -> bool {
1243 self.fetched_at.elapsed() >= self.ttl
1244 }
1245}
1246
1247#[allow(
1256 missing_debug_implementations,
1257 reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
1258)]
1259#[non_exhaustive]
1260pub struct JwksCache {
1261 jwks_uri: String,
1262 ttl: Duration,
1263 max_jwks_keys: usize,
1264 max_response_bytes: u64,
1265 allow_http: bool,
1266 inner: RwLock<Option<CachedKeys>>,
1267 http: reqwest::Client,
1268 validation_template: Validation,
1269 expected_audience: String,
1273 strict_audience_validation: bool,
1274 scopes: Vec<ScopeMapping>,
1275 role_claim: Option<String>,
1276 role_mappings: Vec<RoleMapping>,
1277 last_refresh_attempt: RwLock<Option<Instant>>,
1280 refresh_lock: tokio::sync::Mutex<()>,
1282 allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
1286 #[cfg(any(test, feature = "test-helpers"))]
1287 test_allow_loopback_ssrf: bool,
1288}
1289
/// Minimum spacing between JWKS refresh attempts, whether or not they succeeded.
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_millis(10_000);
1292
1293const ACCEPTED_ALGS: &[Algorithm] = &[
1295 Algorithm::RS256,
1296 Algorithm::RS384,
1297 Algorithm::RS512,
1298 Algorithm::ES256,
1299 Algorithm::ES384,
1300 Algorithm::PS256,
1301 Algorithm::PS384,
1302 Algorithm::PS512,
1303 Algorithm::EdDSA,
1304];
1305
/// Coarse reason a JWT was rejected by [`JwksCache::validate_token_with_reason`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
    /// The signature verified but the token's `exp` has passed.
    Expired,
    /// Any other failure: header, algorithm, key, signature, claims, audience or role.
    Invalid,
}
1315
1316impl JwksCache {
1317 pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
1324 rustls::crypto::ring::default_provider()
1327 .install_default()
1328 .ok();
1329 jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
1330 .install_default()
1331 .ok();
1332
1333 let ttl =
1334 humantime::parse_duration(&config.jwks_cache_ttl).unwrap_or(Duration::from_mins(10));
1335
1336 let mut validation = Validation::new(Algorithm::RS256);
1337 validation.validate_aud = false;
1349 validation.set_issuer(&[&config.issuer]);
1350 validation.set_required_spec_claims(&["exp", "iss"]);
1351 validation.validate_exp = true;
1352 validation.validate_nbf = true;
1353
1354 let allow_http = config.allow_http_oauth_urls;
1355
1356 let allowlist = match config.ssrf_allowlist.as_ref() {
1359 Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
1360 Box::<dyn std::error::Error + Send + Sync>::from(format!(
1361 "oauth.ssrf_allowlist: {e}"
1362 ))
1363 })?),
1364 None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
1365 };
1366 let redirect_allowlist = Arc::clone(&allowlist);
1367
1368 let mut http_builder = reqwest::Client::builder()
1369 .timeout(Duration::from_secs(10))
1370 .connect_timeout(Duration::from_secs(3))
1371 .redirect(reqwest::redirect::Policy::custom(move |attempt| {
1372 match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
1382 Ok(()) => attempt.follow(),
1383 Err(reason) => {
1384 tracing::warn!(
1385 reason = %reason,
1386 target = %attempt.url(),
1387 "oauth redirect rejected"
1388 );
1389 attempt.error(reason)
1390 }
1391 }
1392 }));
1393
1394 if let Some(ref ca_path) = config.ca_cert_path {
1395 let pem = std::fs::read(ca_path)?;
1401 let cert = reqwest::tls::Certificate::from_pem(&pem)?;
1402 http_builder = http_builder.add_root_certificate(cert);
1403 }
1404
1405 let http = http_builder.build()?;
1406
1407 Ok(Self {
1408 jwks_uri: config.jwks_uri.clone(),
1409 ttl,
1410 max_jwks_keys: config.max_jwks_keys,
1411 max_response_bytes: config.jwks_max_response_bytes,
1412 allow_http,
1413 inner: RwLock::new(None),
1414 http,
1415 validation_template: validation,
1416 expected_audience: config.audience.clone(),
1417 strict_audience_validation: config.strict_audience_validation,
1418 scopes: config.scopes.clone(),
1419 role_claim: config.role_claim.clone(),
1420 role_mappings: config.role_mappings.clone(),
1421 last_refresh_attempt: RwLock::new(None),
1422 refresh_lock: tokio::sync::Mutex::new(()),
1423 allowlist,
1424 #[cfg(any(test, feature = "test-helpers"))]
1425 test_allow_loopback_ssrf: false,
1426 })
1427 }
1428
/// Test-only: skip SSRF screening of the JWKS URL so tests can use loopback servers.
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
#[must_use]
pub fn __test_allow_loopback_ssrf(self) -> Self {
    Self {
        test_allow_loopback_ssrf: true,
        ..self
    }
}
1439
1440 pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
1442 self.validate_token_with_reason(token).await.ok()
1443 }
1444
1445 pub async fn validate_token_with_reason(
1452 &self,
1453 token: &str,
1454 ) -> Result<AuthIdentity, JwtValidationFailure> {
1455 let claims = self.decode_claims(token).await?;
1456
1457 self.check_audience(&claims)?;
1458 let role = self.resolve_role(&claims)?;
1459
1460 let sub = claims.sub;
1463 let name = claims
1464 .extra
1465 .get("preferred_username")
1466 .and_then(|v| v.as_str())
1467 .map(String::from)
1468 .or_else(|| sub.clone())
1469 .or(claims.azp)
1470 .or(claims.client_id)
1471 .unwrap_or_else(|| "oauth-client".into());
1472
1473 Ok(AuthIdentity {
1474 name,
1475 role,
1476 method: AuthMethod::OAuthJwt,
1477 raw_token: None,
1478 sub,
1479 })
1480 }
1481
1482 async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
1494 let (key, alg) = self.select_jwks_key(token).await?;
1495
1496 let mut validation = self.validation_template.clone();
1500 validation.algorithms = vec![alg];
1501
1502 let token_owned = token.to_owned();
1505 let join =
1506 tokio::task::spawn_blocking(move || decode::<Claims>(&token_owned, &key, &validation))
1507 .await;
1508
1509 let decode_result = match join {
1510 Ok(r) => r,
1511 Err(join_err) => {
1512 core::hint::cold_path();
1513 tracing::error!(
1514 error = %join_err,
1515 "JWT decode task panicked or was cancelled"
1516 );
1517 return Err(JwtValidationFailure::Invalid);
1518 }
1519 };
1520
1521 decode_result.map(|td| td.claims).map_err(|e| {
1522 core::hint::cold_path();
1523 let failure = if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
1524 JwtValidationFailure::Expired
1525 } else {
1526 JwtValidationFailure::Invalid
1527 };
1528 tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
1529 failure
1530 })
1531 }
1532
1533 #[allow(clippy::cognitive_complexity)]
1542 async fn select_jwks_key(
1543 &self,
1544 token: &str,
1545 ) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
1546 let Ok(header) = decode_header(token) else {
1547 core::hint::cold_path();
1548 tracing::debug!("JWT header decode failed");
1549 return Err(JwtValidationFailure::Invalid);
1550 };
1551 let kid = header.kid.as_deref();
1552 tracing::debug!(alg = ?header.alg, kid = kid.unwrap_or("-"), "JWT header decoded");
1553
1554 if !ACCEPTED_ALGS.contains(&header.alg) {
1555 core::hint::cold_path();
1556 tracing::debug!(alg = ?header.alg, "JWT algorithm not accepted");
1557 return Err(JwtValidationFailure::Invalid);
1558 }
1559
1560 let Some(key) = self.find_key(kid, header.alg).await else {
1561 core::hint::cold_path();
1562 tracing::debug!(kid = kid.unwrap_or("-"), alg = ?header.alg, "no matching JWKS key found");
1563 return Err(JwtValidationFailure::Invalid);
1564 };
1565
1566 Ok((key, header.alg))
1567 }
1568
1569 fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
1577 let aud_ok = claims.aud.contains(&self.expected_audience)
1578 || (!self.strict_audience_validation
1579 && claims
1580 .azp
1581 .as_deref()
1582 .is_some_and(|azp| azp == self.expected_audience));
1583 if aud_ok {
1584 return Ok(());
1585 }
1586 core::hint::cold_path();
1587 tracing::debug!(
1588 aud = ?claims.aud.0,
1589 azp = ?claims.azp,
1590 expected = %self.expected_audience,
1591 strict = self.strict_audience_validation,
1592 "JWT rejected: audience mismatch"
1593 );
1594 Err(JwtValidationFailure::Invalid)
1595 }
1596
1597 fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
1603 if let Some(ref claim_path) = self.role_claim {
1604 let values = resolve_claim_path(&claims.extra, claim_path);
1605 return self
1606 .role_mappings
1607 .iter()
1608 .find(|m| values.contains(&m.claim_value.as_str()))
1609 .map(|m| m.role.clone())
1610 .ok_or(JwtValidationFailure::Invalid);
1611 }
1612
1613 let token_scopes: Vec<&str> = claims
1614 .scope
1615 .as_deref()
1616 .unwrap_or("")
1617 .split_whitespace()
1618 .collect();
1619
1620 self.scopes
1621 .iter()
1622 .find(|m| token_scopes.contains(&m.scope.as_str()))
1623 .map(|m| m.role.clone())
1624 .ok_or(JwtValidationFailure::Invalid)
1625 }
1626
1627 async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
1630 {
1632 let guard = self.inner.read().await;
1633 if let Some(cached) = guard.as_ref()
1634 && !cached.is_expired()
1635 && let Some(key) = lookup_key(cached, kid, alg)
1636 {
1637 return Some(key);
1638 }
1639 }
1640
1641 self.refresh_with_cooldown().await;
1643
1644 let guard = self.inner.read().await;
1645 guard
1646 .as_ref()
1647 .and_then(|cached| lookup_key(cached, kid, alg))
1648 }
1649
1650 async fn refresh_with_cooldown(&self) {
1655 let _guard = self.refresh_lock.lock().await;
1657
1658 {
1660 let last = self.last_refresh_attempt.read().await;
1661 if let Some(ts) = *last
1662 && ts.elapsed() < JWKS_REFRESH_COOLDOWN
1663 {
1664 tracing::debug!(
1665 elapsed_ms = ts.elapsed().as_millis(),
1666 cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
1667 "JWKS refresh skipped (cooldown active)"
1668 );
1669 return;
1670 }
1671 }
1672
1673 {
1676 let mut last = self.last_refresh_attempt.write().await;
1677 *last = Some(Instant::now());
1678 }
1679
1680 let _ = self.refresh_inner().await;
1682 }
1683
1684 async fn refresh_inner(&self) -> Result<(), String> {
1689 let Some(jwks) = self.fetch_jwks().await else {
1690 return Ok(());
1691 };
1692 let (keys, unnamed_keys) = match build_key_cache(&jwks, self.max_jwks_keys) {
1693 Ok(cache) => cache,
1694 Err(msg) => {
1695 tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
1696 return Err(msg);
1697 }
1698 };
1699
1700 tracing::debug!(
1701 named = keys.len(),
1702 unnamed = unnamed_keys.len(),
1703 "JWKS refreshed"
1704 );
1705
1706 let mut guard = self.inner.write().await;
1707 *guard = Some(CachedKeys {
1708 keys,
1709 unnamed_keys,
1710 fetched_at: Instant::now(),
1711 ttl: self.ttl,
1712 });
1713 Ok(())
1714 }
1715
1716 #[allow(
1718 clippy::cognitive_complexity,
1719 reason = "screening, bounded streaming, and parse logging are intentionally kept in one fetch path"
1720 )]
1721 async fn fetch_jwks(&self) -> Option<JwkSet> {
1722 #[cfg(any(test, feature = "test-helpers"))]
1723 let screening = if self.test_allow_loopback_ssrf {
1724 screen_oauth_target_with_test_override(
1725 &self.jwks_uri,
1726 self.allow_http,
1727 &self.allowlist,
1728 true,
1729 )
1730 .await
1731 } else {
1732 screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await
1733 };
1734 #[cfg(not(any(test, feature = "test-helpers")))]
1735 let screening = screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await;
1736
1737 if let Err(error) = screening {
1738 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to screen JWKS target");
1739 return None;
1740 }
1741
1742 let mut resp = match self.http.get(&self.jwks_uri).send().await {
1743 Ok(resp) => resp,
1744 Err(e) => {
1745 tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
1746 return None;
1747 }
1748 };
1749
1750 let initial_capacity =
1751 usize::try_from(self.max_response_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
1752 let mut body = Vec::with_capacity(initial_capacity);
1753 while let Some(chunk) = match resp.chunk().await {
1754 Ok(chunk) => chunk,
1755 Err(error) => {
1756 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to read JWKS response");
1757 return None;
1758 }
1759 } {
1760 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
1761 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
1762 if body_len.saturating_add(chunk_len) > self.max_response_bytes {
1763 tracing::warn!(
1764 uri = %self.jwks_uri,
1765 max_bytes = self.max_response_bytes,
1766 "JWKS response exceeded configured size cap"
1767 );
1768 return None;
1769 }
1770 body.extend_from_slice(&chunk);
1771 }
1772
1773 match serde_json::from_slice::<JwkSet>(&body) {
1774 Ok(jwks) => Some(jwks),
1775 Err(error) => {
1776 tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to parse JWKS");
1777 None
1778 }
1779 }
1780 }
1781
1782 #[cfg(any(test, feature = "test-helpers"))]
1785 #[doc(hidden)]
1786 pub async fn __test_refresh_now(&self) -> Result<(), String> {
1787 let jwks = self
1788 .fetch_jwks()
1789 .await
1790 .ok_or_else(|| "failed to fetch or parse JWKS".to_owned())?;
1791 let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;
1792 let mut guard = self.inner.write().await;
1793 *guard = Some(CachedKeys {
1794 keys,
1795 unnamed_keys,
1796 fetched_at: Instant::now(),
1797 ttl: self.ttl,
1798 });
1799 Ok(())
1800 }
1801
1802 #[cfg(any(test, feature = "test-helpers"))]
1805 #[doc(hidden)]
1806 pub async fn __test_has_kid(&self, kid: &str) -> bool {
1807 let guard = self.inner.read().await;
1808 guard
1809 .as_ref()
1810 .is_some_and(|cache| cache.keys.contains_key(kid))
1811 }
1812}
1813
1814fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
1816 if jwks.keys.len() > max_keys {
1817 return Err(format!(
1818 "jwks_key_count_exceeds_cap: got {} keys, max is {}",
1819 jwks.keys.len(),
1820 max_keys
1821 ));
1822 }
1823 let mut keys = HashMap::new();
1824 let mut unnamed_keys = Vec::new();
1825 for jwk in &jwks.keys {
1826 let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
1827 continue;
1828 };
1829 let Some(alg) = jwk_algorithm(jwk) else {
1830 continue;
1831 };
1832 if let Some(ref kid) = jwk.common.key_id {
1833 keys.insert(kid.clone(), (alg, decoding_key));
1834 } else {
1835 unnamed_keys.push((alg, decoding_key));
1836 }
1837 }
1838 Ok((keys, unnamed_keys))
1839}
1840
1841fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
1843 if let Some(kid) = kid
1844 && let Some((cached_alg, key)) = cached.keys.get(kid)
1845 && *cached_alg == alg
1846 {
1847 return Some(key.clone());
1848 }
1849 cached
1851 .unnamed_keys
1852 .iter()
1853 .find(|(a, _)| *a == alg)
1854 .map(|(_, k)| k.clone())
1855}
1856
1857#[allow(clippy::wildcard_enum_match_arm)]
1859fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
1860 jwk.common.key_algorithm.and_then(|ka| match ka {
1861 jsonwebtoken::jwk::KeyAlgorithm::RS256 => Some(Algorithm::RS256),
1862 jsonwebtoken::jwk::KeyAlgorithm::RS384 => Some(Algorithm::RS384),
1863 jsonwebtoken::jwk::KeyAlgorithm::RS512 => Some(Algorithm::RS512),
1864 jsonwebtoken::jwk::KeyAlgorithm::ES256 => Some(Algorithm::ES256),
1865 jsonwebtoken::jwk::KeyAlgorithm::ES384 => Some(Algorithm::ES384),
1866 jsonwebtoken::jwk::KeyAlgorithm::PS256 => Some(Algorithm::PS256),
1867 jsonwebtoken::jwk::KeyAlgorithm::PS384 => Some(Algorithm::PS384),
1868 jsonwebtoken::jwk::KeyAlgorithm::PS512 => Some(Algorithm::PS512),
1869 jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
1870 _ => None,
1871 })
1872}
1873
1874fn resolve_claim_path<'a>(
1888 extra: &'a HashMap<String, serde_json::Value>,
1889 path: &str,
1890) -> Vec<&'a str> {
1891 let mut segments = path.split('.');
1892 let Some(first) = segments.next() else {
1893 return Vec::new();
1894 };
1895
1896 let mut current: Option<&serde_json::Value> = extra.get(first);
1897
1898 for segment in segments {
1899 current = current.and_then(|v| v.get(segment));
1900 }
1901
1902 match current {
1903 Some(serde_json::Value::String(s)) => s.split_whitespace().collect(),
1904 Some(serde_json::Value::Array(arr)) => arr.iter().filter_map(|v| v.as_str()).collect(),
1905 _ => Vec::new(),
1906 }
1907}
1908
/// JWT payload claims read by this module when authorizing a request.
#[derive(Debug, Deserialize)]
struct Claims {
    /// Subject claim (`sub`), if present.
    sub: Option<String>,
    /// Audience claim (`aud`): a single string or an array of strings.
    /// Defaults to an empty list when the claim is absent; an explicit JSON
    /// `null` is not accepted by `OneOrMany` and fails deserialization.
    #[serde(default)]
    aud: OneOrMany,
    /// Authorized party (`azp`). `check_audience` uses it as an audience
    /// fallback when strict audience validation is disabled.
    azp: Option<String>,
    /// `client_id` claim, if present.
    client_id: Option<String>,
    /// Space-separated scope list; `resolve_role` matches it against the
    /// configured scope mappings.
    scope: Option<String>,
    /// Every claim not captured by a named field above. `resolve_claim_path`
    /// walks this map when a `role_claim` is configured. Named fields are
    /// consumed by serde and do not reappear here.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
1932
/// A JSON value that may be either a single string or an array of strings,
/// normalized to a list (used for the JWT `aud` claim).
#[derive(Debug, Default)]
struct OneOrMany(Vec<String>);

impl OneOrMany {
    /// Return `true` when `value` is exactly equal (case-sensitive) to one of
    /// the entries.
    fn contains(&self, value: &str) -> bool {
        self.0
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == value)
    }
}
1942
1943impl<'de> Deserialize<'de> for OneOrMany {
1944 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
1945 use serde::de;
1946
1947 struct Visitor;
1948 impl<'de> de::Visitor<'de> for Visitor {
1949 type Value = OneOrMany;
1950 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1951 f.write_str("a string or array of strings")
1952 }
1953 fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
1954 Ok(OneOrMany(vec![v.to_owned()]))
1955 }
1956 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
1957 let mut v = Vec::new();
1958 while let Some(s) = seq.next_element::<String>()? {
1959 v.push(s);
1960 }
1961 Ok(OneOrMany(v))
1962 }
1963 }
1964 deserializer.deserialize_any(Visitor)
1965 }
1966}
1967
1968#[must_use]
1975pub fn looks_like_jwt(token: &str) -> bool {
1976 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
1977
1978 let mut parts = token.splitn(4, '.');
1979 let Some(header_b64) = parts.next() else {
1980 return false;
1981 };
1982 if parts.next().is_none() || parts.next().is_none() || parts.next().is_some() {
1984 return false;
1985 }
1986 let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
1988 return false;
1989 };
1990 let Ok(header) = serde_json::from_slice::<serde_json::Value>(&header_bytes) else {
1992 return false;
1993 };
1994 header.get("alg").is_some()
1995}
1996
1997#[must_use]
2007pub fn protected_resource_metadata(
2008 resource_url: &str,
2009 server_url: &str,
2010 config: &OAuthConfig,
2011) -> serde_json::Value {
2012 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2017 let auth_server = server_url;
2018 serde_json::json!({
2019 "resource": resource_url,
2020 "authorization_servers": [auth_server],
2021 "scopes_supported": scopes,
2022 "bearer_methods_supported": ["header"]
2023 })
2024}
2025
2026#[must_use]
2031pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
2032 let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
2033 let mut meta = serde_json::json!({
2034 "issuer": &config.issuer,
2035 "authorization_endpoint": format!("{server_url}/authorize"),
2036 "token_endpoint": format!("{server_url}/token"),
2037 "registration_endpoint": format!("{server_url}/register"),
2038 "response_types_supported": ["code"],
2039 "grant_types_supported": ["authorization_code", "refresh_token"],
2040 "code_challenge_methods_supported": ["S256"],
2041 "scopes_supported": scopes,
2042 "token_endpoint_auth_methods_supported": ["none"],
2043 });
2044 if let Some(proxy) = &config.proxy
2045 && proxy.expose_admin_endpoints
2046 && let Some(obj) = meta.as_object_mut()
2047 {
2048 if proxy.introspection_url.is_some() {
2049 obj.insert(
2050 "introspection_endpoint".into(),
2051 serde_json::Value::String(format!("{server_url}/introspect")),
2052 );
2053 }
2054 if proxy.revocation_url.is_some() {
2055 obj.insert(
2056 "revocation_endpoint".into(),
2057 serde_json::Value::String(format!("{server_url}/revoke")),
2058 );
2059 }
2060 if proxy.require_auth_on_admin_endpoints {
2061 obj.insert(
2062 "introspection_endpoint_auth_methods_supported".into(),
2063 serde_json::json!(["bearer"]),
2064 );
2065 obj.insert(
2066 "revocation_endpoint_auth_methods_supported".into(),
2067 serde_json::json!(["bearer"]),
2068 );
2069 }
2070 }
2071 meta
2072}
2073
2074#[must_use]
2087pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
2088 use axum::{
2089 http::{StatusCode, header},
2090 response::IntoResponse,
2091 };
2092
2093 let upstream_query = replace_client_id(query, &proxy.client_id);
2095 let redirect_url = format!("{}?{upstream_query}", proxy.authorize_url);
2096
2097 (StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
2098}
2099
2100pub async fn handle_token(
2106 http: &OauthHttpClient,
2107 proxy: &OAuthProxyConfig,
2108 body: &str,
2109) -> axum::response::Response {
2110 use axum::{
2111 http::{StatusCode, header},
2112 response::IntoResponse,
2113 };
2114
2115 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2117
2118 if let Some(ref secret) = proxy.client_secret {
2120 use std::fmt::Write;
2121
2122 use secrecy::ExposeSecret;
2123 let _ = write!(
2124 upstream_body,
2125 "&client_secret={}",
2126 urlencoding::encode(secret.expose_secret())
2127 );
2128 }
2129
2130 let result = http
2131 .send_screened(
2132 &proxy.token_url,
2133 http.inner
2134 .post(&proxy.token_url)
2135 .header("Content-Type", "application/x-www-form-urlencoded")
2136 .body(upstream_body),
2137 )
2138 .await;
2139
2140 match result {
2141 Ok(resp) => {
2142 let status =
2143 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2144 let body_bytes = resp.bytes().await.unwrap_or_default();
2145 (
2146 status,
2147 [(header::CONTENT_TYPE, "application/json")],
2148 body_bytes,
2149 )
2150 .into_response()
2151 }
2152 Err(e) => {
2153 tracing::error!(error = %e, "OAuth token proxy request failed");
2154 (
2155 StatusCode::BAD_GATEWAY,
2156 [(header::CONTENT_TYPE, "application/json")],
2157 "{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
2158 )
2159 .into_response()
2160 }
2161 }
2162}
2163
2164#[must_use]
2171pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
2172 let mut resp = serde_json::json!({
2173 "client_id": proxy.client_id,
2174 "token_endpoint_auth_method": "none",
2175 });
2176 if let Some(uris) = body.get("redirect_uris")
2177 && let Some(obj) = resp.as_object_mut()
2178 {
2179 obj.insert("redirect_uris".into(), uris.clone());
2180 }
2181 if let Some(name) = body.get("client_name")
2182 && let Some(obj) = resp.as_object_mut()
2183 {
2184 obj.insert("client_name".into(), name.clone());
2185 }
2186 resp
2187}
2188
2189pub async fn handle_introspect(
2195 http: &OauthHttpClient,
2196 proxy: &OAuthProxyConfig,
2197 body: &str,
2198) -> axum::response::Response {
2199 let Some(ref url) = proxy.introspection_url else {
2200 return oauth_error_response(
2201 axum::http::StatusCode::NOT_FOUND,
2202 "not_supported",
2203 "introspection endpoint is not configured",
2204 );
2205 };
2206 proxy_oauth_admin_request(http, proxy, url, body).await
2207}
2208
2209pub async fn handle_revoke(
2216 http: &OauthHttpClient,
2217 proxy: &OAuthProxyConfig,
2218 body: &str,
2219) -> axum::response::Response {
2220 let Some(ref url) = proxy.revocation_url else {
2221 return oauth_error_response(
2222 axum::http::StatusCode::NOT_FOUND,
2223 "not_supported",
2224 "revocation endpoint is not configured",
2225 );
2226 };
2227 proxy_oauth_admin_request(http, proxy, url, body).await
2228}
2229
2230async fn proxy_oauth_admin_request(
2234 http: &OauthHttpClient,
2235 proxy: &OAuthProxyConfig,
2236 upstream_url: &str,
2237 body: &str,
2238) -> axum::response::Response {
2239 use axum::{
2240 http::{StatusCode, header},
2241 response::IntoResponse,
2242 };
2243
2244 let mut upstream_body = replace_client_id(body, &proxy.client_id);
2245 if let Some(ref secret) = proxy.client_secret {
2246 use std::fmt::Write;
2247
2248 use secrecy::ExposeSecret;
2249 let _ = write!(
2250 upstream_body,
2251 "&client_secret={}",
2252 urlencoding::encode(secret.expose_secret())
2253 );
2254 }
2255
2256 let result = http
2257 .send_screened(
2258 upstream_url,
2259 http.inner
2260 .post(upstream_url)
2261 .header("Content-Type", "application/x-www-form-urlencoded")
2262 .body(upstream_body),
2263 )
2264 .await;
2265
2266 match result {
2267 Ok(resp) => {
2268 let status =
2269 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
2270 let content_type = resp
2271 .headers()
2272 .get(header::CONTENT_TYPE)
2273 .and_then(|v| v.to_str().ok())
2274 .unwrap_or("application/json")
2275 .to_owned();
2276 let body_bytes = resp.bytes().await.unwrap_or_default();
2277 (status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
2278 }
2279 Err(e) => {
2280 tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
2281 oauth_error_response(
2282 StatusCode::BAD_GATEWAY,
2283 "server_error",
2284 "upstream endpoint unreachable",
2285 )
2286 }
2287 }
2288}
2289
2290fn oauth_error_response(
2291 status: axum::http::StatusCode,
2292 error: &str,
2293 description: &str,
2294) -> axum::response::Response {
2295 use axum::{http::header, response::IntoResponse};
2296 let body = serde_json::json!({
2297 "error": error,
2298 "error_description": description,
2299 });
2300 (
2301 status,
2302 [(header::CONTENT_TYPE, "application/json")],
2303 body.to_string(),
2304 )
2305 .into_response()
2306}
2307
/// Error body returned by an upstream authorization server
/// (`{"error": ..., "error_description": ...}`).
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Upstream error code. It is only logged verbatim; clients see it after
    /// `sanitize_oauth_error_code` has mapped it to a fixed allowlist.
    error: String,
    /// Optional human-readable description; logged, never returned to clients.
    error_description: Option<String>,
}
2318
/// Map an upstream OAuth error code onto a fixed allowlist.
///
/// Known codes (the RFC 6749 §5.2 set plus `invalid_target`) are returned
/// unchanged as `'static` strings. Anything else, including differently-cased
/// variants, becomes `"server_error"`, so upstream text never reaches the
/// client.
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
    const FORWARDABLE_CODES: [&str; 8] = [
        "invalid_request",
        "invalid_client",
        "invalid_grant",
        "unauthorized_client",
        "unsupported_grant_type",
        "invalid_scope",
        "temporarily_unavailable",
        "invalid_target",
    ];

    FORWARDABLE_CODES
        .into_iter()
        .find(|&code| code == raw)
        .unwrap_or("server_error")
}
2341
2342pub async fn exchange_token(
2354 http: &OauthHttpClient,
2355 config: &TokenExchangeConfig,
2356 subject_token: &str,
2357) -> Result<ExchangedToken, crate::error::McpxError> {
2358 use secrecy::ExposeSecret;
2359
2360 let mut req = http
2361 .inner
2362 .post(&config.token_url)
2363 .header("Content-Type", "application/x-www-form-urlencoded")
2364 .header("Accept", "application/json");
2365
2366 if let Some(ref secret) = config.client_secret {
2368 use base64::Engine;
2369 let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
2370 "{}:{}",
2371 urlencoding::encode(&config.client_id),
2372 urlencoding::encode(secret.expose_secret()),
2373 ));
2374 req = req.header("Authorization", format!("Basic {credentials}"));
2375 }
2376 let form_body = build_exchange_form(config, subject_token);
2379
2380 let resp = http
2381 .send_screened(&config.token_url, req.body(form_body))
2382 .await
2383 .map_err(|e| {
2384 tracing::error!(error = %e, "token exchange request failed");
2385 crate::error::McpxError::Auth("server_error".into())
2387 })?;
2388
2389 let status = resp.status();
2390 let body_bytes = resp.bytes().await.map_err(|e| {
2391 tracing::error!(error = %e, "failed to read token exchange response");
2392 crate::error::McpxError::Auth("server_error".into())
2393 })?;
2394
2395 if !status.is_success() {
2396 core::hint::cold_path();
2397 let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
2400 let short_code = parsed
2401 .as_ref()
2402 .map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
2403 if let Some(ref e) = parsed {
2404 tracing::warn!(
2405 status = %status,
2406 upstream_error = %e.error,
2407 upstream_error_description = e.error_description.as_deref().unwrap_or(""),
2408 client_code = %short_code,
2409 "token exchange rejected by authorization server",
2410 );
2411 } else {
2412 tracing::warn!(
2413 status = %status,
2414 client_code = %short_code,
2415 "token exchange rejected (unparseable upstream body)",
2416 );
2417 }
2418 return Err(crate::error::McpxError::Auth(short_code.into()));
2419 }
2420
2421 let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
2422 tracing::error!(error = %e, "failed to parse token exchange response");
2423 crate::error::McpxError::Auth("server_error".into())
2426 })?;
2427
2428 log_exchanged_token(&exchanged);
2429
2430 Ok(exchanged)
2431}
2432
2433fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
2436 let body = format!(
2437 "grant_type={}&subject_token={}&subject_token_type={}&requested_token_type={}&audience={}",
2438 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
2439 urlencoding::encode(subject_token),
2440 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
2441 urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
2442 urlencoding::encode(&config.audience),
2443 );
2444 if config.client_secret.is_none() {
2445 format!(
2446 "{body}&client_id={}",
2447 urlencoding::encode(&config.client_id)
2448 )
2449 } else {
2450 body
2451 }
2452}
2453
2454fn log_exchanged_token(exchanged: &ExchangedToken) {
2457 use base64::Engine;
2458
2459 if !looks_like_jwt(&exchanged.access_token) {
2460 tracing::debug!(
2461 token_len = exchanged.access_token.len(),
2462 issued_token_type = ?exchanged.issued_token_type,
2463 expires_in = exchanged.expires_in,
2464 "exchanged token (opaque)",
2465 );
2466 return;
2467 }
2468 let Some(payload) = exchanged.access_token.split('.').nth(1) else {
2469 return;
2470 };
2471 let Ok(decoded) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload) else {
2472 return;
2473 };
2474 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&decoded) else {
2475 return;
2476 };
2477 tracing::debug!(
2478 sub = ?claims.get("sub"),
2479 aud = ?claims.get("aud"),
2480 azp = ?claims.get("azp"),
2481 iss = ?claims.get("iss"),
2482 expires_in = exchanged.expires_in,
2483 "exchanged token claims (JWT)",
2484 );
2485}
2486
2487fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
2489 let encoded_id = urlencoding::encode(upstream_client_id);
2490 let mut parts: Vec<String> = params
2491 .split('&')
2492 .filter(|p| !p.starts_with("client_id="))
2493 .map(String::from)
2494 .collect();
2495 parts.push(format!("client_id={encoded_id}"));
2496 parts.join("&")
2497}
2498
2499#[cfg(test)]
2500mod tests {
2501 use std::sync::Arc;
2502
2503 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2504
2505 use super::*;
2506
2507 #[test]
2508 fn looks_like_jwt_valid() {
2509 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
2511 let payload = URL_SAFE_NO_PAD.encode(b"{}");
2512 let token = format!("{header}.{payload}.signature");
2513 assert!(looks_like_jwt(&token));
2514 }
2515
2516 #[test]
2517 fn looks_like_jwt_rejects_opaque_token() {
2518 assert!(!looks_like_jwt("dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4"));
2519 }
2520
2521 #[test]
2522 fn looks_like_jwt_rejects_two_segments() {
2523 let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
2524 let token = format!("{header}.payload");
2525 assert!(!looks_like_jwt(&token));
2526 }
2527
2528 #[test]
2529 fn looks_like_jwt_rejects_four_segments() {
2530 assert!(!looks_like_jwt("a.b.c.d"));
2531 }
2532
2533 #[test]
2534 fn looks_like_jwt_rejects_no_alg() {
2535 let header = URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}");
2536 let payload = URL_SAFE_NO_PAD.encode(b"{}");
2537 let token = format!("{header}.{payload}.sig");
2538 assert!(!looks_like_jwt(&token));
2539 }
2540
2541 #[test]
2542 fn protected_resource_metadata_shape() {
2543 let config = OAuthConfig {
2544 issuer: "https://auth.example.com".into(),
2545 audience: "https://mcp.example.com/mcp".into(),
2546 jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
2547 scopes: vec![
2548 ScopeMapping {
2549 scope: "mcp:read".into(),
2550 role: "viewer".into(),
2551 },
2552 ScopeMapping {
2553 scope: "mcp:admin".into(),
2554 role: "ops".into(),
2555 },
2556 ],
2557 role_claim: None,
2558 role_mappings: vec![],
2559 jwks_cache_ttl: "10m".into(),
2560 proxy: None,
2561 token_exchange: None,
2562 ca_cert_path: None,
2563 allow_http_oauth_urls: false,
2564 max_jwks_keys: default_max_jwks_keys(),
2565 strict_audience_validation: false,
2566 jwks_max_response_bytes: default_jwks_max_bytes(),
2567 ssrf_allowlist: None,
2568 };
2569 let meta = protected_resource_metadata(
2570 "https://mcp.example.com/mcp",
2571 "https://mcp.example.com",
2572 &config,
2573 );
2574 assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
2575 assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
2576 assert_eq!(meta["scopes_supported"].as_array().unwrap().len(), 2);
2577 assert_eq!(meta["bearer_methods_supported"][0], "header");
2578 }
2579
2580 fn validation_https_config() -> OAuthConfig {
2585 OAuthConfig::builder(
2586 "https://auth.example.com",
2587 "mcp",
2588 "https://auth.example.com/.well-known/jwks.json",
2589 )
2590 .build()
2591 }
2592
2593 #[test]
2594 fn validate_accepts_all_https_urls() {
2595 let cfg = validation_https_config();
2596 cfg.validate().expect("all-HTTPS config must validate");
2597 }
2598
2599 #[test]
2600 fn validate_rejects_http_jwks_uri() {
2601 let mut cfg = validation_https_config();
2602 cfg.jwks_uri = "http://auth.example.com/.well-known/jwks.json".into();
2603 let err = cfg.validate().expect_err("http jwks_uri must be rejected");
2604 let msg = err.to_string();
2605 assert!(
2606 msg.contains("oauth.jwks_uri") && msg.contains("https"),
2607 "error must reference offending field + scheme requirement; got {msg:?}"
2608 );
2609 }
2610
2611 #[test]
2612 fn validate_rejects_http_proxy_authorize_url() {
2613 let mut cfg = validation_https_config();
2614 cfg.proxy = Some(
2615 OAuthProxyConfig::builder(
2616 "http://idp.example.com/authorize", "https://idp.example.com/token",
2618 "client",
2619 )
2620 .build(),
2621 );
2622 let err = cfg
2623 .validate()
2624 .expect_err("http authorize_url must be rejected");
2625 assert!(
2626 err.to_string().contains("oauth.proxy.authorize_url"),
2627 "error must reference proxy.authorize_url; got {err}"
2628 );
2629 }
2630
2631 #[test]
2632 fn validate_rejects_http_proxy_token_url() {
2633 let mut cfg = validation_https_config();
2634 cfg.proxy = Some(
2635 OAuthProxyConfig::builder(
2636 "https://idp.example.com/authorize",
2637 "http://idp.example.com/token", "client",
2639 )
2640 .build(),
2641 );
2642 let err = cfg.validate().expect_err("http token_url must be rejected");
2643 assert!(
2644 err.to_string().contains("oauth.proxy.token_url"),
2645 "error must reference proxy.token_url; got {err}"
2646 );
2647 }
2648
2649 #[test]
2650 fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
2651 let mut cfg = validation_https_config();
2652 cfg.proxy = Some(
2653 OAuthProxyConfig::builder(
2654 "https://idp.example.com/authorize",
2655 "https://idp.example.com/token",
2656 "client",
2657 )
2658 .introspection_url("http://idp.example.com/introspect")
2659 .build(),
2660 );
2661 let err = cfg
2662 .validate()
2663 .expect_err("http introspection_url must be rejected");
2664 assert!(err.to_string().contains("oauth.proxy.introspection_url"));
2665
2666 let mut cfg = validation_https_config();
2667 cfg.proxy = Some(
2668 OAuthProxyConfig::builder(
2669 "https://idp.example.com/authorize",
2670 "https://idp.example.com/token",
2671 "client",
2672 )
2673 .revocation_url("http://idp.example.com/revoke")
2674 .build(),
2675 );
2676 let err = cfg
2677 .validate()
2678 .expect_err("http revocation_url must be rejected");
2679 assert!(err.to_string().contains("oauth.proxy.revocation_url"));
2680 }
2681
2682 #[test]
2683 fn validate_rejects_http_token_exchange_url() {
2684 let mut cfg = validation_https_config();
2685 cfg.token_exchange = Some(TokenExchangeConfig::new(
2686 "http://idp.example.com/token".into(), "client".into(),
2688 None,
2689 None,
2690 "downstream".into(),
2691 ));
2692 let err = cfg
2693 .validate()
2694 .expect_err("http token_exchange.token_url must be rejected");
2695 assert!(
2696 err.to_string().contains("oauth.token_exchange.token_url"),
2697 "error must reference token_exchange.token_url; got {err}"
2698 );
2699 }
2700
2701 #[test]
2702 fn validate_rejects_unparseable_url() {
2703 let mut cfg = validation_https_config();
2704 cfg.jwks_uri = "not a url".into();
2705 let err = cfg
2706 .validate()
2707 .expect_err("unparseable URL must be rejected");
2708 assert!(err.to_string().contains("invalid URL"));
2709 }
2710
2711 #[test]
2712 fn validate_rejects_non_http_scheme() {
2713 let mut cfg = validation_https_config();
2714 cfg.jwks_uri = "file:///etc/passwd".into();
2715 let err = cfg.validate().expect_err("file:// scheme must be rejected");
2716 let msg = err.to_string();
2717 assert!(
2718 msg.contains("must use https scheme") && msg.contains("file"),
2719 "error must reject non-http(s) schemes; got {msg:?}"
2720 );
2721 }
2722
2723 #[test]
2724 fn validate_accepts_http_with_escape_hatch() {
2725 let mut cfg = OAuthConfig::builder(
2730 "http://auth.local",
2731 "mcp",
2732 "http://auth.local/.well-known/jwks.json",
2733 )
2734 .allow_http_oauth_urls(true)
2735 .build();
2736 cfg.proxy = Some(
2737 OAuthProxyConfig::builder(
2738 "http://idp.local/authorize",
2739 "http://idp.local/token",
2740 "client",
2741 )
2742 .introspection_url("http://idp.local/introspect")
2743 .revocation_url("http://idp.local/revoke")
2744 .build(),
2745 );
2746 cfg.token_exchange = Some(TokenExchangeConfig::new(
2747 "http://idp.local/token".into(),
2748 "client".into(),
2749 None,
2750 None,
2751 "downstream".into(),
2752 ));
2753 cfg.validate()
2754 .expect("escape hatch must permit http on all URL fields");
2755 }
2756
2757 #[test]
2758 fn validate_with_escape_hatch_still_rejects_unparseable() {
2759 let mut cfg = validation_https_config();
2762 cfg.allow_http_oauth_urls = true;
2763 cfg.jwks_uri = "::not-a-url::".into();
2764 cfg.validate()
2765 .expect_err("escape hatch must NOT bypass URL parsing");
2766 }
2767
2768 #[tokio::test]
2769 async fn jwks_cache_rejects_redirect_downgrade_to_http() {
2770 rustls::crypto::ring::default_provider()
2785 .install_default()
2786 .ok();
2787
2788 let policy = reqwest::redirect::Policy::custom(|attempt| {
2789 if attempt.url().scheme() != "https" {
2790 attempt.error("redirect to non-HTTPS URL refused")
2791 } else if attempt.previous().len() >= 2 {
2792 attempt.error("too many redirects (max 2)")
2793 } else {
2794 attempt.follow()
2795 }
2796 });
2797 let client = reqwest::Client::builder()
2798 .timeout(Duration::from_secs(5))
2799 .connect_timeout(Duration::from_secs(3))
2800 .redirect(policy)
2801 .build()
2802 .expect("test client builds");
2803
2804 let mock = wiremock::MockServer::start().await;
2805 wiremock::Mock::given(wiremock::matchers::method("GET"))
2806 .and(wiremock::matchers::path("/jwks.json"))
2807 .respond_with(
2808 wiremock::ResponseTemplate::new(302)
2809 .insert_header("location", "http://example.invalid/jwks.json"),
2810 )
2811 .mount(&mock)
2812 .await;
2813
2814 let url = format!("{}/jwks.json", mock.uri());
2823 let err = client
2824 .get(&url)
2825 .send()
2826 .await
2827 .expect_err("redirect policy must reject scheme downgrade");
2828 let chain = format!("{err:#}");
2829 assert!(
2830 chain.contains("redirect to non-HTTPS URL refused")
2831 || chain.to_lowercase().contains("redirect"),
2832 "error must surface redirect-policy rejection; got {chain:?}"
2833 );
2834 }
2835
2836 use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
2841
2842 fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
2844 let mut rng = rsa::rand_core::OsRng;
2845 let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keypair generation");
2846 let private_pem = private_key
2847 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
2848 .expect("PKCS8 PEM export")
2849 .to_string();
2850
2851 let public_key = private_key.to_public_key();
2852 let n = URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be());
2853 let e = URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be());
2854
2855 let jwks = serde_json::json!({
2856 "keys": [{
2857 "kty": "RSA",
2858 "use": "sig",
2859 "alg": "RS256",
2860 "kid": kid,
2861 "n": n,
2862 "e": e
2863 }]
2864 });
2865
2866 (private_pem, jwks)
2867 }
2868
2869 fn mint_token(
2871 private_pem: &str,
2872 kid: &str,
2873 issuer: &str,
2874 audience: &str,
2875 subject: &str,
2876 scope: &str,
2877 ) -> String {
2878 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
2879 .expect("encoding key from PEM");
2880 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
2881 header.kid = Some(kid.into());
2882
2883 let now = jsonwebtoken::get_current_timestamp();
2884 let claims = serde_json::json!({
2885 "iss": issuer,
2886 "aud": audience,
2887 "sub": subject,
2888 "scope": scope,
2889 "exp": now + 3600,
2890 "iat": now,
2891 });
2892
2893 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
2894 }
2895
2896 fn test_config(jwks_uri: &str) -> OAuthConfig {
2897 OAuthConfig {
2898 issuer: "https://auth.test.local".into(),
2899 audience: "https://mcp.test.local/mcp".into(),
2900 jwks_uri: jwks_uri.into(),
2901 scopes: vec![
2902 ScopeMapping {
2903 scope: "mcp:read".into(),
2904 role: "viewer".into(),
2905 },
2906 ScopeMapping {
2907 scope: "mcp:admin".into(),
2908 role: "ops".into(),
2909 },
2910 ],
2911 role_claim: None,
2912 role_mappings: vec![],
2913 jwks_cache_ttl: "5m".into(),
2914 proxy: None,
2915 token_exchange: None,
2916 ca_cert_path: None,
2917 allow_http_oauth_urls: true,
2918 max_jwks_keys: default_max_jwks_keys(),
2919 strict_audience_validation: false,
2920 jwks_max_response_bytes: default_jwks_max_bytes(),
2921 ssrf_allowlist: None,
2922 }
2923 }
2924
2925 fn test_cache(config: &OAuthConfig) -> JwksCache {
2926 JwksCache::new(config).unwrap().__test_allow_loopback_ssrf()
2927 }
2928
2929 #[tokio::test]
2930 async fn valid_jwt_returns_identity() {
2931 let kid = "test-key-1";
2932 let (pem, jwks) = generate_test_keypair(kid);
2933
2934 let mock_server = wiremock::MockServer::start().await;
2935 wiremock::Mock::given(wiremock::matchers::method("GET"))
2936 .and(wiremock::matchers::path("/jwks.json"))
2937 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2938 .mount(&mock_server)
2939 .await;
2940
2941 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2942 let config = test_config(&jwks_uri);
2943 let cache = test_cache(&config);
2944
2945 let token = mint_token(
2946 &pem,
2947 kid,
2948 "https://auth.test.local",
2949 "https://mcp.test.local/mcp",
2950 "ci-bot",
2951 "mcp:read mcp:other",
2952 );
2953
2954 let identity = cache.validate_token(&token).await;
2955 assert!(identity.is_some(), "valid JWT should authenticate");
2956 let id = identity.unwrap();
2957 assert_eq!(id.name, "ci-bot");
2958 assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
2960 }
2961
2962 #[tokio::test]
2963 async fn wrong_issuer_rejected() {
2964 let kid = "test-key-2";
2965 let (pem, jwks) = generate_test_keypair(kid);
2966
2967 let mock_server = wiremock::MockServer::start().await;
2968 wiremock::Mock::given(wiremock::matchers::method("GET"))
2969 .and(wiremock::matchers::path("/jwks.json"))
2970 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2971 .mount(&mock_server)
2972 .await;
2973
2974 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
2975 let config = test_config(&jwks_uri);
2976 let cache = test_cache(&config);
2977
2978 let token = mint_token(
2979 &pem,
2980 kid,
2981 "https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
2983 "attacker",
2984 "mcp:admin",
2985 );
2986
2987 assert!(cache.validate_token(&token).await.is_none());
2988 }
2989
2990 #[tokio::test]
2991 async fn wrong_audience_rejected() {
2992 let kid = "test-key-3";
2993 let (pem, jwks) = generate_test_keypair(kid);
2994
2995 let mock_server = wiremock::MockServer::start().await;
2996 wiremock::Mock::given(wiremock::matchers::method("GET"))
2997 .and(wiremock::matchers::path("/jwks.json"))
2998 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
2999 .mount(&mock_server)
3000 .await;
3001
3002 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3003 let config = test_config(&jwks_uri);
3004 let cache = test_cache(&config);
3005
3006 let token = mint_token(
3007 &pem,
3008 kid,
3009 "https://auth.test.local",
3010 "https://wrong-audience.example.com", "attacker",
3012 "mcp:admin",
3013 );
3014
3015 assert!(cache.validate_token(&token).await.is_none());
3016 }
3017
3018 #[tokio::test]
3019 async fn expired_jwt_rejected() {
3020 let kid = "test-key-4";
3021 let (pem, jwks) = generate_test_keypair(kid);
3022
3023 let mock_server = wiremock::MockServer::start().await;
3024 wiremock::Mock::given(wiremock::matchers::method("GET"))
3025 .and(wiremock::matchers::path("/jwks.json"))
3026 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3027 .mount(&mock_server)
3028 .await;
3029
3030 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3031 let config = test_config(&jwks_uri);
3032 let cache = test_cache(&config);
3033
3034 let encoding_key =
3036 jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
3037 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3038 header.kid = Some(kid.into());
3039 let now = jsonwebtoken::get_current_timestamp();
3040 let claims = serde_json::json!({
3041 "iss": "https://auth.test.local",
3042 "aud": "https://mcp.test.local/mcp",
3043 "sub": "expired-bot",
3044 "scope": "mcp:read",
3045 "exp": now - 120,
3046 "iat": now - 3720,
3047 });
3048 let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
3049
3050 assert!(cache.validate_token(&token).await.is_none());
3051 }
3052
3053 #[tokio::test]
3054 async fn no_matching_scope_rejected() {
3055 let kid = "test-key-5";
3056 let (pem, jwks) = generate_test_keypair(kid);
3057
3058 let mock_server = wiremock::MockServer::start().await;
3059 wiremock::Mock::given(wiremock::matchers::method("GET"))
3060 .and(wiremock::matchers::path("/jwks.json"))
3061 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3062 .mount(&mock_server)
3063 .await;
3064
3065 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3066 let config = test_config(&jwks_uri);
3067 let cache = test_cache(&config);
3068
3069 let token = mint_token(
3070 &pem,
3071 kid,
3072 "https://auth.test.local",
3073 "https://mcp.test.local/mcp",
3074 "limited-bot",
3075 "some:other:scope", );
3077
3078 assert!(cache.validate_token(&token).await.is_none());
3079 }
3080
3081 #[tokio::test]
3082 async fn wrong_signing_key_rejected() {
3083 let kid = "test-key-6";
3084 let (_pem, jwks) = generate_test_keypair(kid);
3085
3086 let (attacker_pem, _) = generate_test_keypair(kid);
3088
3089 let mock_server = wiremock::MockServer::start().await;
3090 wiremock::Mock::given(wiremock::matchers::method("GET"))
3091 .and(wiremock::matchers::path("/jwks.json"))
3092 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3093 .mount(&mock_server)
3094 .await;
3095
3096 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3097 let config = test_config(&jwks_uri);
3098 let cache = test_cache(&config);
3099
3100 let token = mint_token(
3102 &attacker_pem,
3103 kid,
3104 "https://auth.test.local",
3105 "https://mcp.test.local/mcp",
3106 "attacker",
3107 "mcp:admin",
3108 );
3109
3110 assert!(cache.validate_token(&token).await.is_none());
3111 }
3112
3113 #[tokio::test]
3114 async fn admin_scope_maps_to_ops_role() {
3115 let kid = "test-key-7";
3116 let (pem, jwks) = generate_test_keypair(kid);
3117
3118 let mock_server = wiremock::MockServer::start().await;
3119 wiremock::Mock::given(wiremock::matchers::method("GET"))
3120 .and(wiremock::matchers::path("/jwks.json"))
3121 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3122 .mount(&mock_server)
3123 .await;
3124
3125 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3126 let config = test_config(&jwks_uri);
3127 let cache = test_cache(&config);
3128
3129 let token = mint_token(
3130 &pem,
3131 kid,
3132 "https://auth.test.local",
3133 "https://mcp.test.local/mcp",
3134 "admin-bot",
3135 "mcp:admin",
3136 );
3137
3138 let id = cache
3139 .validate_token(&token)
3140 .await
3141 .expect("should authenticate");
3142 assert_eq!(id.role, "ops");
3143 assert_eq!(id.name, "admin-bot");
3144 }
3145
3146 #[tokio::test]
3147 async fn jwks_server_down_returns_none() {
3148 let config = test_config("http://127.0.0.1:1/jwks.json");
3150 let cache = test_cache(&config);
3151
3152 let kid = "orphan-key";
3153 let (pem, _) = generate_test_keypair(kid);
3154 let token = mint_token(
3155 &pem,
3156 kid,
3157 "https://auth.test.local",
3158 "https://mcp.test.local/mcp",
3159 "bot",
3160 "mcp:read",
3161 );
3162
3163 assert!(cache.validate_token(&token).await.is_none());
3164 }
3165
/// A top-level string claim is split on whitespace into individual values.
#[test]
fn resolve_claim_path_flat_string() {
    let mut claims = HashMap::new();
    claims.insert("scope".into(), serde_json::json!("mcp:read mcp:admin"));

    let resolved = resolve_claim_path(&claims, "scope");
    assert_eq!(resolved, vec!["mcp:read", "mcp:admin"]);
}
3180
/// A top-level array-of-strings claim yields its elements in order.
#[test]
fn resolve_claim_path_flat_array() {
    let roles = serde_json::json!(["mcp-admin", "mcp-viewer"]);
    let mut claims = HashMap::new();
    claims.insert("roles".into(), roles);

    let resolved = resolve_claim_path(&claims, "roles");
    assert_eq!(resolved, vec!["mcp-admin", "mcp-viewer"]);
}
3191
/// A dotted path descends into nested objects, following Keycloak's `realm_access.roles` layout.
#[test]
fn resolve_claim_path_nested_keycloak() {
    let realm_access = serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]});
    let mut claims = HashMap::new();
    claims.insert("realm_access".into(), realm_access);

    let resolved = resolve_claim_path(&claims, "realm_access.roles");
    assert_eq!(resolved, vec!["uma_authorization", "mcp-admin"]);
}
3202
/// A path that does not exist in the claims resolves to no values.
#[test]
fn resolve_claim_path_missing_returns_empty() {
    let no_claims = HashMap::new();
    let resolved = resolve_claim_path(&no_claims, "nonexistent.path");
    assert!(resolved.is_empty());
}
3208
/// A claim whose value is a number (neither string nor array) resolves to no values.
#[test]
fn resolve_claim_path_numeric_leaf_returns_empty() {
    let mut claims = HashMap::new();
    claims.insert("count".into(), serde_json::json!(42));

    let resolved = resolve_claim_path(&claims, "count");
    assert!(resolved.is_empty());
}
3215
3216 fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
3222 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
3223 .expect("encoding key from PEM");
3224 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
3225 header.kid = Some(kid.into());
3226 jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
3227 }
3228
3229 fn test_config_with_role_claim(
3230 jwks_uri: &str,
3231 role_claim: &str,
3232 role_mappings: Vec<RoleMapping>,
3233 ) -> OAuthConfig {
3234 OAuthConfig {
3235 issuer: "https://auth.test.local".into(),
3236 audience: "https://mcp.test.local/mcp".into(),
3237 jwks_uri: jwks_uri.into(),
3238 scopes: vec![],
3239 role_claim: Some(role_claim.into()),
3240 role_mappings,
3241 jwks_cache_ttl: "5m".into(),
3242 proxy: None,
3243 token_exchange: None,
3244 ca_cert_path: None,
3245 allow_http_oauth_urls: true,
3246 max_jwks_keys: default_max_jwks_keys(),
3247 strict_audience_validation: false,
3248 jwks_max_response_bytes: default_jwks_max_bytes(),
3249 ssrf_allowlist: None,
3250 }
3251 }
3252
3253 #[tokio::test]
3254 async fn screen_oauth_target_rejects_literal_ip() {
3255 let err = screen_oauth_target(
3256 "https://127.0.0.1/jwks.json",
3257 false,
3258 &crate::ssrf::CompiledSsrfAllowlist::default(),
3259 )
3260 .await
3261 .expect_err("literal IPs must be rejected");
3262 let msg = err.to_string();
3263 assert!(msg.contains("literal IPv4 addresses are forbidden"));
3264 }
3265
3266 #[tokio::test]
3267 async fn screen_oauth_target_rejects_private_dns_resolution() {
3268 let err = screen_oauth_target(
3269 "https://localhost/jwks.json",
3270 false,
3271 &crate::ssrf::CompiledSsrfAllowlist::default(),
3272 )
3273 .await
3274 .expect_err("localhost resolution must be rejected");
3275 let msg = err.to_string();
3276 assert!(
3277 msg.contains("blocked IP") && msg.contains("loopback"),
3278 "got {msg:?}"
3279 );
3280 }
3281
3282 #[tokio::test]
3283 async fn screen_oauth_target_rejects_literal_ip_even_with_allow_http() {
3284 let err = screen_oauth_target(
3285 "http://127.0.0.1/jwks.json",
3286 true,
3287 &crate::ssrf::CompiledSsrfAllowlist::default(),
3288 )
3289 .await
3290 .expect_err("literal IPs must still be rejected when http is allowed");
3291 let msg = err.to_string();
3292 assert!(msg.contains("literal IPv4 addresses are forbidden"));
3293 }
3294
3295 #[tokio::test]
3296 async fn screen_oauth_target_rejects_private_dns_even_with_allow_http() {
3297 let err = screen_oauth_target(
3298 "http://localhost/jwks.json",
3299 true,
3300 &crate::ssrf::CompiledSsrfAllowlist::default(),
3301 )
3302 .await
3303 .expect_err("private DNS resolution must still be rejected when http is allowed");
3304 let msg = err.to_string();
3305 assert!(
3306 msg.contains("blocked IP") && msg.contains("loopback"),
3307 "got {msg:?}"
3308 );
3309 }
3310
3311 #[tokio::test]
3312 async fn screen_oauth_target_allows_public_hostname() {
3313 screen_oauth_target(
3314 "https://example.com/.well-known/jwks.json",
3315 false,
3316 &crate::ssrf::CompiledSsrfAllowlist::default(),
3317 )
3318 .await
3319 .expect("public hostname should pass screening");
3320 }
3321
3322 fn make_allowlist(hosts: &[&str], cidrs: &[&str]) -> crate::ssrf::CompiledSsrfAllowlist {
3328 let raw = OAuthSsrfAllowlist {
3329 hosts: hosts.iter().map(|s| (*s).to_string()).collect(),
3330 cidrs: cidrs.iter().map(|s| (*s).to_string()).collect(),
3331 };
3332 compile_oauth_ssrf_allowlist(&raw).expect("test allowlist compiles")
3333 }
3334
/// Host entries are case-normalised, so case-only duplicates collapse into one entry
/// and lookups are case-insensitive.
#[test]
fn compile_oauth_ssrf_allowlist_lowercases_and_dedupes_hosts() {
    let mixed_case = vec!["RHBK.ops.example.com".into(), "rhbk.ops.example.com".into()];
    let raw = OAuthSsrfAllowlist {
        hosts: mixed_case,
        cidrs: vec![],
    };

    let compiled = compile_oauth_ssrf_allowlist(&raw).expect("compiles");
    assert_eq!(compiled.host_count(), 1);
    for probe in ["rhbk.ops.example.com", "RHBK.OPS.EXAMPLE.COM"] {
        assert!(compiled.host_allowed(probe));
    }
}
3346
/// IP addresses belong in `cidrs`; a literal IP in `hosts` is a configuration error.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_literal_ip_in_hosts() {
    let raw = OAuthSsrfAllowlist {
        cidrs: vec![],
        hosts: vec!["10.0.0.1".into()],
    };
    let Err(err) = compile_oauth_ssrf_allowlist(&raw) else {
        panic!("literal IP in hosts");
    };
    assert!(err.contains("literal IPs are forbidden"), "got {err:?}");
}
3356
/// Host entries must be bare DNS names; a `host:port` form is rejected.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_host_with_port() {
    let with_port = "rhbk.ops.example.com:8443";
    let raw = OAuthSsrfAllowlist {
        hosts: vec![with_port.into()],
        cidrs: vec![],
    };
    let err = compile_oauth_ssrf_allowlist(&raw).expect_err("host:port");
    assert!(err.contains("must be a bare DNS hostname"), "got {err:?}");
}
3366
/// An unparsable CIDR is reported with its config path and index.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_invalid_cidr() {
    let raw = OAuthSsrfAllowlist {
        hosts: Vec::new(),
        cidrs: vec!["not-a-cidr".into()],
    };
    let err = compile_oauth_ssrf_allowlist(&raw).expect_err("invalid CIDR");
    let names_offending_entry = err.contains("oauth.ssrf_allowlist.cidrs[0]");
    assert!(names_offending_entry, "got {err:?}");
}
3376
/// `OAuthConfig::validate` surfaces allowlist compilation errors under the
/// `oauth.ssrf_allowlist` key.
#[test]
fn validate_rejects_misconfigured_allowlist() {
    let mut cfg = OAuthConfig::builder(
        "https://auth.example.com/",
        "mcp",
        "https://auth.example.com/jwks.json",
    )
    .build();
    let bad_allowlist = OAuthSsrfAllowlist {
        hosts: vec!["10.0.0.1".into()],
        cidrs: vec![],
    };
    cfg.ssrf_allowlist = Some(bad_allowlist);

    let message = cfg
        .validate()
        .expect_err("literal IP host must be rejected")
        .to_string();
    assert!(message.contains("oauth.ssrf_allowlist"), "got {message}");
}
3397
3398 #[tokio::test]
3399 async fn screen_oauth_target_with_allowlist_emits_helpful_error() {
3400 let allow = make_allowlist(&["other.example.com"], &["10.0.0.0/8"]);
3404 let err = screen_oauth_target("https://localhost/jwks.json", false, &allow)
3405 .await
3406 .expect_err("loopback must still be blocked when not in allowlist");
3407 let msg = err.to_string();
3408 assert!(msg.contains("OAuth target blocked"), "got {msg:?}");
3409 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
3410 assert!(msg.contains("SECURITY.md"), "got {msg:?}");
3411 }
3412
3413 #[tokio::test]
3414 async fn screen_oauth_target_empty_allowlist_uses_legacy_message() {
3415 let err = screen_oauth_target(
3418 "https://localhost/jwks.json",
3419 false,
3420 &crate::ssrf::CompiledSsrfAllowlist::default(),
3421 )
3422 .await
3423 .expect_err("loopback rejection");
3424 let msg = err.to_string();
3425 assert!(msg.contains("blocked IP"), "got {msg:?}");
3426 assert!(msg.contains("loopback"), "got {msg:?}");
3427 assert!(!msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
3429 }
3430
3431 #[tokio::test]
3432 async fn screen_oauth_target_allows_loopback_when_host_allowlisted() {
3433 let allow = make_allowlist(&["localhost"], &[]);
3435 screen_oauth_target("https://localhost/jwks.json", false, &allow)
3436 .await
3437 .expect("allowlisted host must pass");
3438 }
3439
3440 #[tokio::test]
3441 async fn screen_oauth_target_allows_loopback_when_cidr_allowlisted() {
3442 let allow = make_allowlist(&[], &["127.0.0.0/8", "::1/128"]);
3445 screen_oauth_target("https://localhost/jwks.json", false, &allow)
3446 .await
3447 .expect("allowlisted CIDR must pass");
3448 }
3449
3450 #[tokio::test]
3451 async fn jwks_cache_rejects_misconfigured_allowlist_at_startup() {
3452 let mut cfg = OAuthConfig::builder(
3453 "https://auth.example.com/",
3454 "mcp",
3455 "https://auth.example.com/jwks.json",
3456 )
3457 .build();
3458 cfg.ssrf_allowlist = Some(OAuthSsrfAllowlist {
3459 hosts: vec![],
3460 cidrs: vec!["bad-cidr".into()],
3461 });
3462 let Err(err) = JwksCache::new(&cfg) else {
3463 panic!("invalid CIDR must fail JwksCache::new")
3464 };
3465 let msg = err.to_string();
3466 assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
3467 }
3468
3469 #[tokio::test]
3470 async fn audience_falls_back_to_azp_by_default() {
3471 let kid = "test-audience-azp-default";
3472 let (pem, jwks) = generate_test_keypair(kid);
3473
3474 let mock_server = wiremock::MockServer::start().await;
3475 wiremock::Mock::given(wiremock::matchers::method("GET"))
3476 .and(wiremock::matchers::path("/jwks.json"))
3477 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3478 .mount(&mock_server)
3479 .await;
3480
3481 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3482 let config = test_config(&jwks_uri);
3483 let cache = test_cache(&config);
3484
3485 let now = jsonwebtoken::get_current_timestamp();
3486 let token = mint_token_with_claims(
3487 &pem,
3488 kid,
3489 &serde_json::json!({
3490 "iss": "https://auth.test.local",
3491 "aud": "https://some-other-resource.example.com",
3492 "azp": "https://mcp.test.local/mcp",
3493 "sub": "compat-client",
3494 "scope": "mcp:read",
3495 "exp": now + 3600,
3496 "iat": now,
3497 }),
3498 );
3499
3500 let identity = cache
3501 .validate_token_with_reason(&token)
3502 .await
3503 .expect("azp fallback should remain enabled by default");
3504 assert_eq!(identity.role, "viewer");
3505 }
3506
3507 #[tokio::test]
3508 async fn strict_audience_validation_rejects_azp_only_match() {
3509 let kid = "test-audience-azp-strict";
3510 let (pem, jwks) = generate_test_keypair(kid);
3511
3512 let mock_server = wiremock::MockServer::start().await;
3513 wiremock::Mock::given(wiremock::matchers::method("GET"))
3514 .and(wiremock::matchers::path("/jwks.json"))
3515 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3516 .mount(&mock_server)
3517 .await;
3518
3519 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3520 let mut config = test_config(&jwks_uri);
3521 config.strict_audience_validation = true;
3522 let cache = test_cache(&config);
3523
3524 let now = jsonwebtoken::get_current_timestamp();
3525 let token = mint_token_with_claims(
3526 &pem,
3527 kid,
3528 &serde_json::json!({
3529 "iss": "https://auth.test.local",
3530 "aud": "https://some-other-resource.example.com",
3531 "azp": "https://mcp.test.local/mcp",
3532 "sub": "strict-client",
3533 "scope": "mcp:read",
3534 "exp": now + 3600,
3535 "iat": now,
3536 }),
3537 );
3538
3539 let failure = cache
3540 .validate_token_with_reason(&token)
3541 .await
3542 .expect_err("strict audience validation must ignore azp fallback");
3543 assert_eq!(failure, JwtValidationFailure::Invalid);
3544 }
3545
/// Shared in-memory sink for `tracing` output, so tests can assert on emitted log lines.
///
/// Clones share one buffer. A clone can be handed to the subscriber while the test keeps
/// the original for reading.
#[derive(Clone, Default)]
struct CapturedLogs(Arc<std::sync::Mutex<Vec<u8>>>);

impl CapturedLogs {
    /// Returns everything captured so far as text.
    ///
    /// A poisoned lock is recovered, because the buffered bytes are still intact.
    /// Invalid UTF-8 is replaced rather than discarding the whole capture. Either failure
    /// would otherwise yield an empty string and make `contains` assertions fail misleadingly.
    fn contents(&self) -> String {
        let guard = self
            .0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        String::from_utf8_lossy(&guard).into_owned()
    }
}
3555
/// `io::Write` handle that appends into a `CapturedLogs` buffer.
struct CapturedLogsWriter(Arc<std::sync::Mutex<Vec<u8>>>);

impl std::io::Write for CapturedLogsWriter {
    /// Appends `buf` to the shared buffer and reports it as fully written.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // Recover a poisoned lock. The original code silently dropped the bytes here
        // while still reporting them as written.
        self.0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .extend_from_slice(buf);
        Ok(buf.len())
    }

    /// No-op: bytes are stored in the shared buffer as soon as `write` returns.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
3570
3571 impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CapturedLogs {
3572 type Writer = CapturedLogsWriter;
3573
3574 fn make_writer(&'a self) -> Self::Writer {
3575 CapturedLogsWriter(Arc::clone(&self.0))
3576 }
3577 }
3578
3579 #[tokio::test]
3580 async fn jwks_response_size_cap_returns_none_and_logs_warning() {
3581 let kid = "oversized-jwks";
3582 let (_pem, jwks) = generate_test_keypair(kid);
3583 let mut oversized_body = serde_json::to_string(&jwks).expect("jwks json");
3584 oversized_body.push_str(&" ".repeat(4096));
3585
3586 let mock_server = wiremock::MockServer::start().await;
3587 wiremock::Mock::given(wiremock::matchers::method("GET"))
3588 .and(wiremock::matchers::path("/jwks.json"))
3589 .respond_with(
3590 wiremock::ResponseTemplate::new(200)
3591 .insert_header("content-type", "application/json")
3592 .set_body_string(oversized_body),
3593 )
3594 .mount(&mock_server)
3595 .await;
3596
3597 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3598 let mut config = test_config(&jwks_uri);
3599 config.jwks_max_response_bytes = 256;
3600 let cache = test_cache(&config);
3601
3602 let logs = CapturedLogs::default();
3603 let subscriber = tracing_subscriber::fmt()
3604 .with_writer(logs.clone())
3605 .with_ansi(false)
3606 .without_time()
3607 .finish();
3608 let _guard = tracing::subscriber::set_default(subscriber);
3609
3610 let result = cache.fetch_jwks().await;
3611 assert!(result.is_none(), "oversized JWKS must be dropped");
3612 assert!(
3613 logs.contents()
3614 .contains("JWKS response exceeded configured size cap"),
3615 "expected cap-exceeded warning in logs"
3616 );
3617 }
3618
3619 #[tokio::test]
3620 async fn role_claim_keycloak_nested_array() {
3621 let kid = "test-role-1";
3622 let (pem, jwks) = generate_test_keypair(kid);
3623
3624 let mock_server = wiremock::MockServer::start().await;
3625 wiremock::Mock::given(wiremock::matchers::method("GET"))
3626 .and(wiremock::matchers::path("/jwks.json"))
3627 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3628 .mount(&mock_server)
3629 .await;
3630
3631 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3632 let config = test_config_with_role_claim(
3633 &jwks_uri,
3634 "realm_access.roles",
3635 vec![
3636 RoleMapping {
3637 claim_value: "mcp-admin".into(),
3638 role: "ops".into(),
3639 },
3640 RoleMapping {
3641 claim_value: "mcp-viewer".into(),
3642 role: "viewer".into(),
3643 },
3644 ],
3645 );
3646 let cache = test_cache(&config);
3647
3648 let now = jsonwebtoken::get_current_timestamp();
3649 let token = mint_token_with_claims(
3650 &pem,
3651 kid,
3652 &serde_json::json!({
3653 "iss": "https://auth.test.local",
3654 "aud": "https://mcp.test.local/mcp",
3655 "sub": "keycloak-user",
3656 "exp": now + 3600,
3657 "iat": now,
3658 "realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
3659 }),
3660 );
3661
3662 let id = cache
3663 .validate_token(&token)
3664 .await
3665 .expect("should authenticate");
3666 assert_eq!(id.name, "keycloak-user");
3667 assert_eq!(id.role, "ops");
3668 }
3669
3670 #[tokio::test]
3671 async fn role_claim_flat_roles_array() {
3672 let kid = "test-role-2";
3673 let (pem, jwks) = generate_test_keypair(kid);
3674
3675 let mock_server = wiremock::MockServer::start().await;
3676 wiremock::Mock::given(wiremock::matchers::method("GET"))
3677 .and(wiremock::matchers::path("/jwks.json"))
3678 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3679 .mount(&mock_server)
3680 .await;
3681
3682 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3683 let config = test_config_with_role_claim(
3684 &jwks_uri,
3685 "roles",
3686 vec![
3687 RoleMapping {
3688 claim_value: "MCP.Admin".into(),
3689 role: "ops".into(),
3690 },
3691 RoleMapping {
3692 claim_value: "MCP.Reader".into(),
3693 role: "viewer".into(),
3694 },
3695 ],
3696 );
3697 let cache = test_cache(&config);
3698
3699 let now = jsonwebtoken::get_current_timestamp();
3700 let token = mint_token_with_claims(
3701 &pem,
3702 kid,
3703 &serde_json::json!({
3704 "iss": "https://auth.test.local",
3705 "aud": "https://mcp.test.local/mcp",
3706 "sub": "azure-ad-user",
3707 "exp": now + 3600,
3708 "iat": now,
3709 "roles": ["MCP.Reader", "OtherApp.Admin"]
3710 }),
3711 );
3712
3713 let id = cache
3714 .validate_token(&token)
3715 .await
3716 .expect("should authenticate");
3717 assert_eq!(id.name, "azure-ad-user");
3718 assert_eq!(id.role, "viewer");
3719 }
3720
3721 #[tokio::test]
3722 async fn role_claim_no_matching_value_rejected() {
3723 let kid = "test-role-3";
3724 let (pem, jwks) = generate_test_keypair(kid);
3725
3726 let mock_server = wiremock::MockServer::start().await;
3727 wiremock::Mock::given(wiremock::matchers::method("GET"))
3728 .and(wiremock::matchers::path("/jwks.json"))
3729 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3730 .mount(&mock_server)
3731 .await;
3732
3733 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3734 let config = test_config_with_role_claim(
3735 &jwks_uri,
3736 "roles",
3737 vec![RoleMapping {
3738 claim_value: "mcp-admin".into(),
3739 role: "ops".into(),
3740 }],
3741 );
3742 let cache = test_cache(&config);
3743
3744 let now = jsonwebtoken::get_current_timestamp();
3745 let token = mint_token_with_claims(
3746 &pem,
3747 kid,
3748 &serde_json::json!({
3749 "iss": "https://auth.test.local",
3750 "aud": "https://mcp.test.local/mcp",
3751 "sub": "limited-user",
3752 "exp": now + 3600,
3753 "iat": now,
3754 "roles": ["some-other-role"]
3755 }),
3756 );
3757
3758 assert!(cache.validate_token(&token).await.is_none());
3759 }
3760
3761 #[tokio::test]
3762 async fn role_claim_space_separated_string() {
3763 let kid = "test-role-4";
3764 let (pem, jwks) = generate_test_keypair(kid);
3765
3766 let mock_server = wiremock::MockServer::start().await;
3767 wiremock::Mock::given(wiremock::matchers::method("GET"))
3768 .and(wiremock::matchers::path("/jwks.json"))
3769 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3770 .mount(&mock_server)
3771 .await;
3772
3773 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3774 let config = test_config_with_role_claim(
3775 &jwks_uri,
3776 "custom_scope",
3777 vec![
3778 RoleMapping {
3779 claim_value: "write".into(),
3780 role: "ops".into(),
3781 },
3782 RoleMapping {
3783 claim_value: "read".into(),
3784 role: "viewer".into(),
3785 },
3786 ],
3787 );
3788 let cache = test_cache(&config);
3789
3790 let now = jsonwebtoken::get_current_timestamp();
3791 let token = mint_token_with_claims(
3792 &pem,
3793 kid,
3794 &serde_json::json!({
3795 "iss": "https://auth.test.local",
3796 "aud": "https://mcp.test.local/mcp",
3797 "sub": "custom-client",
3798 "exp": now + 3600,
3799 "iat": now,
3800 "custom_scope": "read audit"
3801 }),
3802 );
3803
3804 let id = cache
3805 .validate_token(&token)
3806 .await
3807 .expect("should authenticate");
3808 assert_eq!(id.name, "custom-client");
3809 assert_eq!(id.role, "viewer");
3810 }
3811
3812 #[tokio::test]
3813 async fn scope_backward_compat_without_role_claim() {
3814 let kid = "test-compat-1";
3816 let (pem, jwks) = generate_test_keypair(kid);
3817
3818 let mock_server = wiremock::MockServer::start().await;
3819 wiremock::Mock::given(wiremock::matchers::method("GET"))
3820 .and(wiremock::matchers::path("/jwks.json"))
3821 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3822 .mount(&mock_server)
3823 .await;
3824
3825 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3826 let config = test_config(&jwks_uri); let cache = test_cache(&config);
3828
3829 let token = mint_token(
3830 &pem,
3831 kid,
3832 "https://auth.test.local",
3833 "https://mcp.test.local/mcp",
3834 "legacy-bot",
3835 "mcp:admin other:scope",
3836 );
3837
3838 let id = cache
3839 .validate_token(&token)
3840 .await
3841 .expect("should authenticate");
3842 assert_eq!(id.name, "legacy-bot");
3843 assert_eq!(id.role, "ops"); }
3845
3846 #[tokio::test]
3851 async fn jwks_refresh_deduplication() {
3852 let kid = "test-dedup";
3855 let (pem, jwks) = generate_test_keypair(kid);
3856
3857 let mock_server = wiremock::MockServer::start().await;
3858 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
3859 .and(wiremock::matchers::path("/jwks.json"))
3860 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3861 .expect(1) .mount(&mock_server)
3863 .await;
3864
3865 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3866 let config = test_config(&jwks_uri);
3867 let cache = Arc::new(test_cache(&config));
3868
3869 let token = mint_token(
3871 &pem,
3872 kid,
3873 "https://auth.test.local",
3874 "https://mcp.test.local/mcp",
3875 "concurrent-bot",
3876 "mcp:read",
3877 );
3878
3879 let mut handles = Vec::new();
3880 for _ in 0..5 {
3881 let c = Arc::clone(&cache);
3882 let t = token.clone();
3883 handles.push(tokio::spawn(async move { c.validate_token(&t).await }));
3884 }
3885
3886 for h in handles {
3887 let result = h.await.unwrap();
3888 assert!(result.is_some(), "all concurrent requests should succeed");
3889 }
3890
3891 }
3893
3894 #[tokio::test]
3895 async fn jwks_refresh_cooldown_blocks_rapid_requests() {
3896 let kid = "test-cooldown";
3899 let (_pem, jwks) = generate_test_keypair(kid);
3900
3901 let mock_server = wiremock::MockServer::start().await;
3902 let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
3903 .and(wiremock::matchers::path("/jwks.json"))
3904 .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
3905 .expect(1) .mount(&mock_server)
3907 .await;
3908
3909 let jwks_uri = format!("{}/jwks.json", mock_server.uri());
3910 let config = test_config(&jwks_uri);
3911 let cache = test_cache(&config);
3912
3913 let fake_token1 =
3915 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig";
3916 let _ = cache.validate_token(fake_token1).await;
3917
3918 let fake_token2 =
3921 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig";
3922 let _ = cache.validate_token(fake_token2).await;
3923
3924 let fake_token3 =
3926 "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig";
3927 let _ = cache.validate_token(fake_token3).await;
3928
3929 }
3931
3932 fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
3935 OAuthProxyConfig {
3936 authorize_url: "https://example.invalid/auth".into(),
3937 token_url: token_url.into(),
3938 client_id: "mcp-client".into(),
3939 client_secret: Some(secrecy::SecretString::from("shh".to_owned())),
3940 introspection_url: None,
3941 revocation_url: None,
3942 expose_admin_endpoints: false,
3943 require_auth_on_admin_endpoints: false,
3944 }
3945 }
3946
3947 fn test_http_client() -> OauthHttpClient {
3950 rustls::crypto::ring::default_provider()
3951 .install_default()
3952 .ok();
3953 let config = OAuthConfig::builder(
3954 "https://auth.test.local",
3955 "https://mcp.test.local/mcp",
3956 "https://auth.test.local/.well-known/jwks.json",
3957 )
3958 .allow_http_oauth_urls(true)
3959 .build();
3960 OauthHttpClient::with_config(&config)
3961 .expect("build test http client")
3962 .__test_allow_loopback_ssrf()
3963 }
3964
3965 #[tokio::test]
3966 async fn introspect_proxies_and_injects_client_credentials() {
3967 use wiremock::matchers::{body_string_contains, method, path};
3968
3969 let mock_server = wiremock::MockServer::start().await;
3970 wiremock::Mock::given(method("POST"))
3971 .and(path("/introspect"))
3972 .and(body_string_contains("client_id=mcp-client"))
3973 .and(body_string_contains("client_secret=shh"))
3974 .and(body_string_contains("token=abc"))
3975 .respond_with(
3976 wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
3977 "active": true,
3978 "scope": "read"
3979 })),
3980 )
3981 .expect(1)
3982 .mount(&mock_server)
3983 .await;
3984
3985 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
3986 proxy.introspection_url = Some(format!("{}/introspect", mock_server.uri()));
3987
3988 let http = test_http_client();
3989 let resp = handle_introspect(&http, &proxy, "token=abc").await;
3990 assert_eq!(resp.status(), 200);
3991 }
3992
3993 #[tokio::test]
3994 async fn introspect_returns_404_when_not_configured() {
3995 let proxy = proxy_cfg("https://example.invalid/token");
3996 let http = test_http_client();
3997 let resp = handle_introspect(&http, &proxy, "token=abc").await;
3998 assert_eq!(resp.status(), 404);
3999 }
4000
4001 #[tokio::test]
4002 async fn revoke_proxies_and_returns_upstream_status() {
4003 use wiremock::matchers::{method, path};
4004
4005 let mock_server = wiremock::MockServer::start().await;
4006 wiremock::Mock::given(method("POST"))
4007 .and(path("/revoke"))
4008 .respond_with(wiremock::ResponseTemplate::new(200))
4009 .expect(1)
4010 .mount(&mock_server)
4011 .await;
4012
4013 let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
4014 proxy.revocation_url = Some(format!("{}/revoke", mock_server.uri()));
4015
4016 let http = test_http_client();
4017 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4018 assert_eq!(resp.status(), 200);
4019 }
4020
4021 #[tokio::test]
4022 async fn revoke_returns_404_when_not_configured() {
4023 let proxy = proxy_cfg("https://example.invalid/token");
4024 let http = test_http_client();
4025 let resp = handle_revoke(&http, &proxy, "token=abc").await;
4026 assert_eq!(resp.status(), 404);
4027 }
4028
4029 #[test]
4030 fn metadata_advertises_endpoints_only_when_configured() {
4031 let mut cfg = test_config("https://auth.test.local/jwks.json");
4032 let m = authorization_server_metadata("https://mcp.local", &cfg);
4034 assert!(m.get("introspection_endpoint").is_none());
4035 assert!(m.get("revocation_endpoint").is_none());
4036
4037 let mut proxy = proxy_cfg("https://upstream.local/token");
4040 proxy.introspection_url = Some("https://upstream.local/introspect".into());
4041 proxy.revocation_url = Some("https://upstream.local/revoke".into());
4042 cfg.proxy = Some(proxy);
4043 let m = authorization_server_metadata("https://mcp.local", &cfg);
4044 assert!(
4045 m.get("introspection_endpoint").is_none(),
4046 "introspection must not be advertised when expose_admin_endpoints=false"
4047 );
4048 assert!(
4049 m.get("revocation_endpoint").is_none(),
4050 "revocation must not be advertised when expose_admin_endpoints=false"
4051 );
4052
4053 if let Some(p) = cfg.proxy.as_mut() {
4055 p.expose_admin_endpoints = true;
4056 p.revocation_url = None;
4057 }
4058 let m = authorization_server_metadata("https://mcp.local", &cfg);
4059 assert_eq!(
4060 m["introspection_endpoint"],
4061 serde_json::Value::String("https://mcp.local/introspect".into())
4062 );
4063 assert!(m.get("revocation_endpoint").is_none());
4064
4065 if let Some(p) = cfg.proxy.as_mut() {
4067 p.revocation_url = Some("https://upstream.local/revoke".into());
4068 }
4069 let m = authorization_server_metadata("https://mcp.local", &cfg);
4070 assert_eq!(
4071 m["revocation_endpoint"],
4072 serde_json::Value::String("https://mcp.local/revoke".into())
4073 );
4074 }
4075}