use std::sync::Arc;
use tower_mcp::oauth::OAuthError;
use tower_mcp::oauth::token::{TokenClaims, TokenValidator};
/// Authorization server metadata, as fetched by [`discover_auth_server`] from either
/// `{issuer}/.well-known/oauth-authorization-server` (RFC 8414) or
/// `{issuer}/.well-known/openid-configuration` (OpenID Connect discovery).
///
/// Only a subset of the metadata members is modelled. `issuer` is the only required
/// member: deserialization fails if it is missing. Every other field is marked
/// `#[serde(default)]` and becomes `None` / an empty `Vec` when absent.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct AuthServerMetadata {
/// Issuer identifier reported by the server.
pub issuer: String,
/// `jwks_uri` member (JSON Web Key Set location), if advertised.
#[serde(default)]
pub jwks_uri: Option<String>,
/// `introspection_endpoint` member, if advertised.
#[serde(default)]
pub introspection_endpoint: Option<String>,
/// `token_endpoint` member, if advertised.
#[serde(default)]
pub token_endpoint: Option<String>,
/// `authorization_endpoint` member, if advertised.
#[serde(default)]
pub authorization_endpoint: Option<String>,
/// `scopes_supported` member; empty when absent.
#[serde(default)]
pub scopes_supported: Vec<String>,
/// `response_types_supported` member; empty when absent.
#[serde(default)]
pub response_types_supported: Vec<String>,
/// `grant_types_supported` member; empty when absent.
#[serde(default)]
pub grant_types_supported: Vec<String>,
/// `token_endpoint_auth_methods_supported` member; empty when absent.
#[serde(default)]
pub token_endpoint_auth_methods_supported: Vec<String>,
}
/// Discovers OAuth 2.0 authorization server metadata for `issuer`.
///
/// First tries the RFC 8414 document at
/// `{issuer}/.well-known/oauth-authorization-server`. If that request fails, returns
/// a non-success status, or does not parse as [`AuthServerMetadata`], the reason is
/// logged at debug level. The function then falls back to OpenID Connect discovery at
/// `{issuer}/.well-known/openid-configuration`. Trailing `/` characters on `issuer`
/// are stripped before either URL is built.
///
/// If the `issuer` reported in the metadata differs from the requested one (ignoring
/// trailing slashes), a warning is logged. The metadata is still returned.
///
/// # Errors
///
/// Returns an error only when the OIDC fallback fails:
/// - the request cannot be sent,
/// - the server answers with a non-success status, or
/// - the body is not a valid [`AuthServerMetadata`] document.
pub async fn discover_auth_server(issuer: &str) -> anyhow::Result<AuthServerMetadata> {
    let client = reqwest::Client::new();
    let issuer = issuer.trim_end_matches('/');

    let rfc8414_url = format!("{issuer}/.well-known/oauth-authorization-server");
    let (metadata, source) = match fetch_auth_server_metadata(&client, &rfc8414_url).await {
        Ok(metadata) => (metadata, "RFC 8414"),
        Err(e) => {
            // Not fatal: record why, then try the OIDC discovery document instead.
            tracing::debug!(
                url = %rfc8414_url,
                error = %e,
                "RFC 8414 metadata unavailable, falling back to OpenID Connect discovery"
            );
            let oidc_url = format!("{issuer}/.well-known/openid-configuration");
            (fetch_auth_server_metadata(&client, &oidc_url).await?, "OIDC")
        }
    };

    tracing::info!(
        issuer = %metadata.issuer,
        jwks_uri = ?metadata.jwks_uri,
        introspection = ?metadata.introspection_endpoint,
        "Discovered auth server metadata ({})",
        source
    );

    // RFC 8414 §3.3 requires the returned issuer to match the one used for discovery.
    // The mismatch is only reported, not rejected, to avoid breaking existing setups.
    if metadata.issuer.trim_end_matches('/') != issuer {
        tracing::warn!(
            expected = %issuer,
            actual = %metadata.issuer,
            "auth server metadata issuer does not match the requested issuer"
        );
    }

    Ok(metadata)
}

/// GETs `url` and parses the body as [`AuthServerMetadata`].
///
/// Fails if the request cannot be sent, the response status is not a success, or the
/// body cannot be deserialized. Each case has its own error message.
async fn fetch_auth_server_metadata(
    client: &reqwest::Client,
    url: &str,
) -> anyhow::Result<AuthServerMetadata> {
    let resp = client
        .get(url)
        .send()
        .await
        .map_err(|e| anyhow::anyhow!("failed to discover auth server at {url}: {e}"))?;
    if !resp.status().is_success() {
        anyhow::bail!(
            "auth server discovery failed: {} returned {}",
            url,
            resp.status()
        );
    }
    resp.json::<AuthServerMetadata>()
        .await
        .map_err(|e| anyhow::anyhow!("failed to parse auth server metadata: {e}"))
}
/// Token validator that checks bearer tokens by POSTing them to an OAuth 2.0 token
/// introspection endpoint (see the `TokenValidator` impl below).
///
/// All configuration lives behind a shared [`Arc`], so cloning a validator copies
/// only the pointer, not the endpoint, credentials or HTTP client.
#[derive(Clone)]
pub struct IntrospectionValidator {
/// Configuration and HTTP client shared by every clone of this validator.
inner: Arc<IntrospectionState>,
}
/// Configuration shared (via `Arc`) by all clones of an [`IntrospectionValidator`].
///
/// Note: this type does not derive `Debug`, which keeps `client_secret` out of
/// `{:?}` output.
struct IntrospectionState {
/// URL that tokens are POSTed to for introspection.
introspection_endpoint: String,
/// Client identifier, sent as the HTTP Basic username on each introspection call.
client_id: String,
/// Client secret, sent as the HTTP Basic password on each introspection call.
client_secret: String,
/// When `Some`, the introspection response's `aud` is checked against this value.
expected_audience: Option<String>,
/// HTTP client reused for every introspection request.
http_client: reqwest::Client,
}
/// The subset of an introspection endpoint's JSON response that this module reads.
///
/// `active` is required: a body without it fails to deserialize. All other members
/// are `#[serde(default)]` and become `None` when absent.
#[derive(Debug, serde::Deserialize)]
struct IntrospectionResponse {
/// Whether the server considers the token currently active.
active: bool,
/// Subject of the token, if reported.
#[serde(default)]
sub: Option<String>,
/// Issuer of the token, if reported.
#[serde(default)]
iss: Option<String>,
/// Audience, kept as raw JSON because it may be a single string or an array of
/// strings; both shapes are handled by the audience check in `validate_token`.
#[serde(default)]
aud: Option<serde_json::Value>,
/// Expiry timestamp, if reported (assumed seconds since the Unix epoch — confirm
/// against the introspection server).
#[serde(default)]
exp: Option<u64>,
/// Scope string, if reported.
#[serde(default)]
scope: Option<String>,
/// Client the token was issued to, if reported.
#[serde(default)]
client_id: Option<String>,
}
impl IntrospectionValidator {
    /// Creates a validator that introspects tokens at `introspection_endpoint`,
    /// authenticating with HTTP Basic using `client_id` / `client_secret`.
    ///
    /// No audience check is performed unless [`Self::expected_audience`] is called.
    pub fn new(introspection_endpoint: &str, client_id: &str, client_secret: &str) -> Self {
        Self {
            inner: Arc::new(IntrospectionState {
                introspection_endpoint: introspection_endpoint.to_string(),
                client_id: client_id.to_string(),
                client_secret: client_secret.to_string(),
                expected_audience: None,
                http_client: reqwest::Client::new(),
            }),
        }
    }

    /// Requires introspected tokens to list `audience` in their `aud` member.
    ///
    /// This can be called on a validator that has already been cloned. If the shared
    /// state has other owners, this validator gets its own copy of the configuration,
    /// with the HTTP client handle cloned. The other clones keep their previous audience
    /// setting. Previously this case panicked.
    pub fn expected_audience(mut self, audience: &str) -> Self {
        let audience = Some(audience.to_string());
        if let Some(state) = Arc::get_mut(&mut self.inner) {
            // Sole owner: update in place.
            state.expected_audience = audience;
        } else {
            // State is shared with other clones: copy it rather than mutate theirs.
            let shared = &self.inner;
            let state = IntrospectionState {
                introspection_endpoint: shared.introspection_endpoint.clone(),
                client_id: shared.client_id.clone(),
                client_secret: shared.client_secret.clone(),
                expected_audience: audience,
                http_client: shared.http_client.clone(),
            };
            self.inner = Arc::new(state);
        }
        self
    }
}
impl TokenValidator for IntrospectionValidator {
    /// Validates `token` by calling the configured introspection endpoint.
    ///
    /// The token is POSTed as the `token` form parameter, authenticated with HTTP Basic
    /// (`client_id` / `client_secret`). It is accepted only if all of these hold:
    /// - the endpoint answers with a success status,
    /// - the body parses, and
    /// - it reports `"active": true`.
    ///
    /// When an expected audience is configured, the response `aud` (a string or an
    /// array of strings) must contain it. A response that omits `aud` is accepted. An
    /// `aud` of any other JSON type is rejected.
    ///
    /// # Errors
    ///
    /// - [`OAuthError::InvalidToken`] if the request fails, the endpoint returns a
    ///   non-success status, the body cannot be parsed, or the token is inactive.
    /// - [`OAuthError::InvalidAudience`] if the audience check fails.
    async fn validate_token(&self, token: &str) -> Result<TokenClaims, OAuthError> {
        let state = &*self.inner;

        let resp = state
            .http_client
            .post(&state.introspection_endpoint)
            .basic_auth(&state.client_id, Some(&state.client_secret))
            .form(&[("token", token)])
            .send()
            .await
            .map_err(|e| OAuthError::InvalidToken {
                description: format!("introspection request failed: {e}"),
            })?;

        if !resp.status().is_success() {
            return Err(OAuthError::InvalidToken {
                description: format!("introspection endpoint returned {}", resp.status()),
            });
        }

        let introspection: IntrospectionResponse =
            resp.json().await.map_err(|e| OAuthError::InvalidToken {
                description: format!("invalid introspection response: {e}"),
            })?;

        if !introspection.active {
            return Err(OAuthError::InvalidToken {
                description: "token is not active".to_string(),
            });
        }

        if let Some(expected_aud) = &state.expected_audience {
            let aud_matches = match &introspection.aud {
                // `aud` is optional in introspection responses, so its absence
                // (including JSON `null`) is not treated as a mismatch.
                None => true,
                Some(serde_json::Value::String(aud)) => aud == expected_aud,
                Some(serde_json::Value::Array(auds)) => auds
                    .iter()
                    .any(|v| v.as_str() == Some(expected_aud.as_str())),
                // A number, bool or object is not a valid audience: fail closed.
                Some(_) => false,
            };
            if !aud_matches {
                return Err(OAuthError::InvalidAudience);
            }
        }

        Ok(TokenClaims {
            sub: introspection.sub,
            iss: introspection.iss,
            // The raw `aud` value is only used for the check above; it is not
            // forwarded into the claims.
            aud: None,
            exp: introspection.exp,
            scope: introspection.scope,
            client_id: introspection.client_id,
            extra: std::collections::HashMap::new(),
        })
    }
}
#[derive(Clone)]
pub struct FallbackValidator<J: TokenValidator> {
jwt_validator: J,
introspection_validator: IntrospectionValidator,
}
impl<J: TokenValidator> FallbackValidator<J> {
pub fn new(jwt_validator: J, introspection_validator: IntrospectionValidator) -> Self {
Self {
jwt_validator,
introspection_validator,
}
}
}
impl<J: TokenValidator> TokenValidator for FallbackValidator<J> {
async fn validate_token(&self, token: &str) -> Result<TokenClaims, OAuthError> {
match self.jwt_validator.validate_token(token).await {
Ok(claims) => Ok(claims),
Err(_jwt_err) => {
self.introspection_validator.validate_token(token).await
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint used by the builder test below.
    const INTROSPECT_URL: &str = "https://auth.example.com/oauth/introspect";

    /// The builder stores the endpoint and the configured audience.
    #[test]
    fn test_introspection_validator_creation() {
        let validator = IntrospectionValidator::new(INTROSPECT_URL, "client-id", "client-secret")
            .expected_audience("mcp-proxy");
        let state = &validator.inner;

        assert_eq!(state.introspection_endpoint, INTROSPECT_URL);
        assert_eq!(state.expected_audience.as_deref(), Some("mcp-proxy"));
    }

    /// A fallback validator can be assembled from two validators.
    #[test]
    fn test_fallback_validator_creation() {
        let make_validator =
            || IntrospectionValidator::new("https://example.com/introspect", "id", "secret");
        let _fallback = FallbackValidator::new(make_validator(), make_validator());
    }
}