//! OAuth 2.0 authorization server discovery plus token validation via
//! RFC 7662 token introspection, with a JWT-first fallback validator.

use std::sync::Arc;
use std::time::Duration;

use tower_mcp::oauth::OAuthError;
use tower_mcp::oauth::token::{TokenClaims, TokenValidator};

13#[derive(Debug, Clone, serde::Deserialize)]
19pub struct AuthServerMetadata {
20 pub issuer: String,
22 #[serde(default)]
24 pub jwks_uri: Option<String>,
25 #[serde(default)]
27 pub introspection_endpoint: Option<String>,
28 #[serde(default)]
30 pub token_endpoint: Option<String>,
31 #[serde(default)]
33 pub authorization_endpoint: Option<String>,
34 #[serde(default)]
36 pub scopes_supported: Vec<String>,
37 #[serde(default)]
39 pub response_types_supported: Vec<String>,
40 #[serde(default)]
42 pub grant_types_supported: Vec<String>,
43 #[serde(default)]
45 pub token_endpoint_auth_methods_supported: Vec<String>,
46}
47
48pub async fn discover_auth_server(issuer: &str) -> anyhow::Result<AuthServerMetadata> {
53 let client = reqwest::Client::new();
54 let issuer = issuer.trim_end_matches('/');
55
56 let rfc8414_url = format!("{issuer}/.well-known/oauth-authorization-server");
58 if let Ok(resp) = client.get(&rfc8414_url).send().await
59 && resp.status().is_success()
60 && let Ok(metadata) = resp.json::<AuthServerMetadata>().await
61 {
62 tracing::info!(
63 issuer = %metadata.issuer,
64 jwks_uri = ?metadata.jwks_uri,
65 introspection = ?metadata.introspection_endpoint,
66 "Discovered auth server metadata (RFC 8414)"
67 );
68 return Ok(metadata);
69 }
70
71 let oidc_url = format!("{issuer}/.well-known/openid-configuration");
73 let resp = client
74 .get(&oidc_url)
75 .send()
76 .await
77 .map_err(|e| anyhow::anyhow!("failed to discover auth server at {oidc_url}: {e}"))?;
78
79 if !resp.status().is_success() {
80 anyhow::bail!(
81 "auth server discovery failed: {} returned {}",
82 oidc_url,
83 resp.status()
84 );
85 }
86
87 let metadata = resp
88 .json::<AuthServerMetadata>()
89 .await
90 .map_err(|e| anyhow::anyhow!("failed to parse auth server metadata: {e}"))?;
91
92 tracing::info!(
93 issuer = %metadata.issuer,
94 jwks_uri = ?metadata.jwks_uri,
95 introspection = ?metadata.introspection_endpoint,
96 "Discovered auth server metadata (OIDC)"
97 );
98
99 Ok(metadata)
100}
101
102#[derive(Clone)]
111pub struct IntrospectionValidator {
112 inner: Arc<IntrospectionState>,
113}
114
115struct IntrospectionState {
116 introspection_endpoint: String,
117 client_id: String,
118 client_secret: String,
119 expected_audience: Option<String>,
120 http_client: reqwest::Client,
121}
122
123#[derive(Debug, serde::Deserialize)]
125struct IntrospectionResponse {
126 active: bool,
128 #[serde(default)]
130 sub: Option<String>,
131 #[serde(default)]
133 iss: Option<String>,
134 #[serde(default)]
136 aud: Option<serde_json::Value>,
137 #[serde(default)]
139 exp: Option<u64>,
140 #[serde(default)]
142 scope: Option<String>,
143 #[serde(default)]
145 client_id: Option<String>,
146}
147
148impl IntrospectionValidator {
149 pub fn new(introspection_endpoint: &str, client_id: &str, client_secret: &str) -> Self {
151 Self {
152 inner: Arc::new(IntrospectionState {
153 introspection_endpoint: introspection_endpoint.to_string(),
154 client_id: client_id.to_string(),
155 client_secret: client_secret.to_string(),
156 expected_audience: None,
157 http_client: reqwest::Client::new(),
158 }),
159 }
160 }
161
162 pub fn expected_audience(mut self, audience: &str) -> Self {
164 Arc::get_mut(&mut self.inner)
165 .expect("no other references")
166 .expected_audience = Some(audience.to_string());
167 self
168 }
169}
170
171impl TokenValidator for IntrospectionValidator {
172 async fn validate_token(&self, token: &str) -> Result<TokenClaims, OAuthError> {
173 let resp = self
174 .inner
175 .http_client
176 .post(&self.inner.introspection_endpoint)
177 .basic_auth(&self.inner.client_id, Some(&self.inner.client_secret))
178 .form(&[("token", token)])
179 .send()
180 .await
181 .map_err(|e| OAuthError::InvalidToken {
182 description: format!("introspection request failed: {e}"),
183 })?;
184
185 if !resp.status().is_success() {
186 return Err(OAuthError::InvalidToken {
187 description: format!("introspection endpoint returned {}", resp.status()),
188 });
189 }
190
191 let introspection: IntrospectionResponse =
192 resp.json().await.map_err(|e| OAuthError::InvalidToken {
193 description: format!("invalid introspection response: {e}"),
194 })?;
195
196 if !introspection.active {
197 return Err(OAuthError::InvalidToken {
198 description: "token is not active".to_string(),
199 });
200 }
201
202 if let Some(expected_aud) = &self.inner.expected_audience {
204 let aud_matches = match &introspection.aud {
205 Some(serde_json::Value::String(s)) => s == expected_aud,
206 Some(serde_json::Value::Array(arr)) => arr
207 .iter()
208 .any(|v| v.as_str().is_some_and(|s| s == expected_aud)),
209 _ => true, };
211 if !aud_matches {
212 return Err(OAuthError::InvalidAudience);
213 }
214 }
215
216 Ok(TokenClaims {
217 sub: introspection.sub,
218 iss: introspection.iss,
219 aud: None,
220 exp: introspection.exp,
221 scope: introspection.scope,
222 client_id: introspection.client_id,
223 extra: std::collections::HashMap::new(),
224 })
225 }
226}
227
228#[derive(Clone)]
238pub struct FallbackValidator<J: TokenValidator> {
239 jwt_validator: J,
240 introspection_validator: IntrospectionValidator,
241}
242
243impl<J: TokenValidator> FallbackValidator<J> {
244 pub fn new(jwt_validator: J, introspection_validator: IntrospectionValidator) -> Self {
247 Self {
248 jwt_validator,
249 introspection_validator,
250 }
251 }
252}
253
254impl<J: TokenValidator> TokenValidator for FallbackValidator<J> {
255 async fn validate_token(&self, token: &str) -> Result<TokenClaims, OAuthError> {
256 match self.jwt_validator.validate_token(token).await {
258 Ok(claims) => Ok(claims),
259 Err(_jwt_err) => {
260 self.introspection_validator.validate_token(token).await
262 }
263 }
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    const INTROSPECT_URL: &str = "https://auth.example.com/oauth/introspect";

    #[test]
    fn test_introspection_validator_creation() {
        let v = IntrospectionValidator::new(INTROSPECT_URL, "client-id", "client-secret")
            .expected_audience("mcp-proxy");

        let state = &v.inner;
        assert_eq!(state.introspection_endpoint, INTROSPECT_URL);
        assert_eq!(state.expected_audience.as_deref(), Some("mcp-proxy"));
    }

    #[test]
    fn test_fallback_validator_creation() {
        // Any TokenValidator works as the "JWT" half; reuse the introspection type.
        let make = || IntrospectionValidator::new("https://example.com/introspect", "id", "secret");
        let _fallback = FallbackValidator::new(make(), make());
    }
}
297}