use super::{JwksCache, JwksClient, StandardClaims};
use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, decode, decode_header};
use serde::Deserialize;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::OnceCell;
use tracing::{debug, error, info, warn};
use turbomcp_protocol::{Error as McpError, Result as McpResult};
/// Subset of an OpenID Connect discovery document (`/.well-known/openid-configuration`).
///
/// Only `jwks_uri` is read by this module; every other member of the document is
/// collected into `_additional` and otherwise ignored.
#[derive(Clone, Debug, Deserialize)]
struct OidcDiscoveryDocument {
    /// URL of the issuer's JSON Web Key Set.
    jwks_uri: String,
    /// Remaining document members, captured via `#[serde(flatten)]` but never read.
    #[serde(flatten)]
    _additional: serde_json::Value,
}
/// Outcome of a successful JWT validation.
#[derive(Clone, Debug)]
pub struct JwtValidationResult {
    /// Claims decoded from the verified token payload.
    pub claims: StandardClaims,
    /// Signature algorithm taken from the token header (already checked against the allowlist).
    pub algorithm: Algorithm,
    /// `kid` from the token header that selected the verification key.
    pub key_id: Option<String>,
    /// `iat` claim converted to a `SystemTime`; `None` when it is not available.
    pub issued_at: Option<SystemTime>,
    /// `exp` claim converted to a `SystemTime`; `None` when it is not available.
    pub expires_at: Option<SystemTime>,
}
/// Validates JWTs from a single issuer for a single audience, using keys fetched from a JWKS.
pub struct JwtValidator {
    /// Required value of the `iss` claim.
    expected_issuer: String,
    /// Required value of the `aud` claim.
    expected_audience: String,
    /// Source of verification keys.
    jwks_client: Arc<JwksClient>,
    /// Tolerance applied to time-based claim checks (whole seconds are used).
    clock_skew_leeway: Duration,
    /// Header `alg` values this validator accepts.
    allowed_algorithms: Vec<Algorithm>,
    /// JWKS URI this validator was built from, when it is known.
    discovered_jwks_uri: OnceCell<String>,
    /// SSRF guard supplied at construction, if any.
    ssrf_validator: Option<Arc<crate::ssrf::SsrfValidator>>,
}
impl std::fmt::Debug for JwtValidator {
    // Shows configuration fields verbatim, but reports the cached JWKS URI and the SSRF
    // validator only as presence markers rather than printing their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let jwks_uri_marker = self.discovered_jwks_uri.get().map(|_| "<cached>");
        let ssrf_marker = self.ssrf_validator.as_ref().map(|_| "<SsrfValidator>");

        let mut out = f.debug_struct("JwtValidator");
        out.field("expected_issuer", &self.expected_issuer);
        out.field("expected_audience", &self.expected_audience);
        out.field("jwks_client", &self.jwks_client);
        out.field("clock_skew_leeway", &self.clock_skew_leeway);
        out.field("allowed_algorithms", &self.allowed_algorithms);
        out.field("discovered_jwks_uri", &jwks_uri_marker);
        out.field("ssrf_validator", &ssrf_marker);
        out.finish()
    }
}
impl JwtValidator {
/// Resolve the JWKS endpoint for `issuer`.
///
/// Tries OpenID Connect discovery first (`{issuer}/.well-known/openid-configuration`) and
/// uses the `jwks_uri` it advertises. If the request fails, returns a non-success status,
/// or the body cannot be parsed, it falls back to `{issuer}/.well-known/jwks.json`.
///
/// Trailing `/` characters on `issuer` are removed before a well-known suffix is appended
/// (OpenID Connect Discovery 1.0 §4). As a result, `https://idp.example/` and
/// `https://idp.example` resolve to the same URLs instead of producing `//.well-known/...`.
///
/// The discovery request uses a 5 s timeout and does not follow redirects.
///
/// When `ssrf_validator` is provided, it is applied to the discovery URL *and* to the JWKS
/// URI that is finally returned. The discovery document is remote input, so the `jwks_uri`
/// it advertises must pass the same checks before any client is pointed at it.
///
/// # Errors
/// - `McpError::authentication` if the SSRF validator rejects the discovery URL or the
///   resolved JWKS URI.
/// - `McpError::internal` if the HTTP client cannot be constructed.
async fn discover_jwks_uri(
    issuer: &str,
    ssrf_validator: Option<&crate::ssrf::SsrfValidator>,
) -> McpResult<String> {
    // Per OIDC Discovery §4, strip any terminating '/' before appending the well-known path.
    let issuer_base = issuer.trim_end_matches('/');
    let discovery_url = format!("{issuer_base}/.well-known/openid-configuration");
    debug!(
        issuer = issuer,
        discovery_url = %discovery_url,
        "Attempting RFC 8414 OIDC discovery"
    );

    if let Some(validator) = ssrf_validator {
        validator.validate_url(&discovery_url).map_err(|e| {
            McpError::authentication(format!("SSRF validation failed for discovery URL: {e}"))
        })?;
    }

    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(5))
        .redirect(reqwest::redirect::Policy::none())
        .build()
        .map_err(|e| McpError::internal(format!("Failed to build HTTP client: {e}")))?;

    // `None` means discovery did not yield a usable document; the fallback is used below.
    let discovered: Option<String> = match client.get(&discovery_url).send().await {
        Ok(response) if response.status().is_success() => {
            match response.json::<OidcDiscoveryDocument>().await {
                Ok(doc) => {
                    info!(
                        issuer = issuer,
                        jwks_uri = %doc.jwks_uri,
                        "Successfully discovered JWKS URI via RFC 8414"
                    );
                    Some(doc.jwks_uri)
                }
                Err(e) => {
                    warn!(
                        error = %e,
                        issuer = issuer,
                        "Failed to parse OIDC discovery document, trying fallback"
                    );
                    None
                }
            }
        }
        Ok(response) => {
            warn!(
                status = %response.status(),
                issuer = issuer,
                "OIDC discovery endpoint returned non-success status, trying fallback"
            );
            None
        }
        Err(e) => {
            warn!(
                error = %e,
                issuer = issuer,
                "Failed to fetch OIDC discovery document, trying fallback"
            );
            None
        }
    };

    let jwks_uri = match discovered {
        Some(uri) => uri,
        None => {
            let fallback_uri = format!("{issuer_base}/.well-known/jwks.json");
            info!(
                issuer = issuer,
                jwks_uri = %fallback_uri,
                "Using fallback JWKS URI pattern (RFC 8414 discovery failed)"
            );
            fallback_uri
        }
    };

    // The advertised `jwks_uri` came from the network; guard it like the discovery URL.
    if let Some(validator) = ssrf_validator {
        validator.validate_url(&jwks_uri).map_err(|e| {
            McpError::authentication(format!("SSRF validation failed for JWKS URI: {e}"))
        })?;
    }

    Ok(jwks_uri)
}
/// Build a validator for `expected_issuer`, discovering its JWKS endpoint over the network.
///
/// Discovery runs without an SSRF validator; see [`JwtValidator::new_with_ssrf`] for the
/// guarded variant. The resulting validator has exactly the configuration produced by
/// [`JwtValidator::with_jwks_uri`] for the discovered URI.
///
/// # Errors
/// Propagates any error returned by JWKS URI discovery.
pub async fn new(expected_issuer: String, expected_audience: String) -> McpResult<Self> {
    let discovered_uri = Self::discover_jwks_uri(&expected_issuer, None).await?;
    Ok(Self::with_jwks_uri(
        expected_issuer,
        expected_audience,
        discovered_uri,
    ))
}
/// Like [`JwtValidator::new`], but runs JWKS discovery through `ssrf_validator` and keeps
/// the validator attached to the returned instance.
///
/// # Errors
/// Propagates any error returned by JWKS URI discovery, including SSRF rejections.
pub async fn new_with_ssrf(
    expected_issuer: String,
    expected_audience: String,
    ssrf_validator: Arc<crate::ssrf::SsrfValidator>,
) -> McpResult<Self> {
    let discovered_uri =
        Self::discover_jwks_uri(&expected_issuer, Some(&*ssrf_validator)).await?;
    let validator = Self::with_jwks_uri(expected_issuer, expected_audience, discovered_uri)
        .with_ssrf_validator(ssrf_validator);
    Ok(validator)
}
pub fn with_jwks_uri(
expected_issuer: String,
expected_audience: String,
jwks_uri: String,
) -> Self {
let jwks_client = Arc::new(JwksClient::new(jwks_uri.clone()));
Self {
expected_issuer,
expected_audience,
jwks_client,
clock_skew_leeway: Duration::from_secs(60),
allowed_algorithms: vec![Algorithm::ES256, Algorithm::RS256, Algorithm::PS256],
discovered_jwks_uri: OnceCell::new_with(Some(jwks_uri)),
ssrf_validator: None,
}
}
pub fn with_jwks_client(
expected_issuer: String,
expected_audience: String,
jwks_client: Arc<JwksClient>,
) -> Self {
Self {
expected_issuer,
expected_audience,
jwks_client,
clock_skew_leeway: Duration::from_secs(60),
allowed_algorithms: vec![Algorithm::ES256, Algorithm::RS256, Algorithm::PS256],
discovered_jwks_uri: OnceCell::new(), ssrf_validator: None,
}
}
pub fn with_ssrf_validator(mut self, ssrf_validator: Arc<crate::ssrf::SsrfValidator>) -> Self {
self.ssrf_validator = Some(ssrf_validator);
self
}
pub fn with_clock_skew(mut self, leeway: Duration) -> Self {
self.clock_skew_leeway = leeway;
self
}
pub fn with_algorithms(mut self, algorithms: Vec<Algorithm>) -> Self {
self.allowed_algorithms = algorithms;
self
}
/// Verify `token` against this validator's issuer, audience, algorithm allowlist, and JWKS.
///
/// Steps, in order:
/// 1. Decode the (unverified) header and reject any `alg` not in `allowed_algorithms`.
/// 2. Require a `kid` and resolve it to a [`DecodingKey`] from the JWKS.
/// 3. Verify the signature and the `iss`/`aud` claims with [`jsonwebtoken::decode`],
///    applying `clock_skew_leeway` (whole seconds) as leeway.
///
/// `iat`/`exp` are converted to [`SystemTime`] with overflow checking. A timestamp too large
/// for the platform clock is reported as `None` instead of panicking.
///
/// # Errors
/// Returns `McpError::invalid_params` for an undecodable header, a disallowed algorithm, a
/// missing `kid`, or a failed signature/claim check. Errors from resolving the key (unknown
/// `kid`, JWKS fetch failure, unusable JWK) are propagated unchanged.
pub async fn validate(&self, token: &str) -> McpResult<JwtValidationResult> {
    let header = decode_header(token).map_err(|e| {
        debug!(error = %e, "Failed to decode JWT header");
        McpError::invalid_params(format!("Invalid JWT format: {e}"))
    })?;
    let algorithm = header.alg;

    if !self.allowed_algorithms.contains(&algorithm) {
        error!(
            algorithm = ?algorithm,
            allowed = ?self.allowed_algorithms,
            "JWT algorithm not allowed"
        );
        return Err(McpError::invalid_params(format!(
            "Algorithm {:?} not allowed",
            algorithm
        )));
    }

    // Move `kid` out of the header instead of cloning it; only `algorithm` (Copy) is used later.
    let key_id = header.kid.ok_or_else(|| {
        error!("JWT missing kid (key ID) in header");
        McpError::invalid_params("JWT must include kid (key ID) in header".to_string())
    })?;

    let decoding_key = self.get_decoding_key(&key_id, algorithm).await?;

    let mut validation = Validation::new(algorithm);
    validation.set_audience(&[&self.expected_audience]);
    validation.set_issuer(&[&self.expected_issuer]);
    validation.leeway = self.clock_skew_leeway.as_secs();

    let token_data: TokenData<StandardClaims> = decode(token, &decoding_key, &validation)
        .map_err(|e| {
            warn!(
                error = %e,
                issuer = %self.expected_issuer,
                audience = %self.expected_audience,
                "JWT validation failed"
            );
            McpError::invalid_params(format!("JWT validation failed: {e}"))
        })?;

    // `SystemTime + Duration` panics on overflow. A correctly signed token carrying an absurd
    // `iat`/`exp` (e.g. u64::MAX) must not abort the task, so convert with `checked_add`.
    let to_system_time = |secs: u64| UNIX_EPOCH.checked_add(Duration::from_secs(secs));
    let issued_at = token_data.claims.iat.and_then(to_system_time);
    let expires_at = token_data.claims.exp.and_then(to_system_time);

    debug!(
        issuer = %self.expected_issuer,
        audience = %self.expected_audience,
        subject = ?token_data.claims.sub,
        algorithm = ?algorithm,
        "JWT validation successful"
    );

    Ok(JwtValidationResult {
        claims: token_data.claims,
        algorithm,
        key_id: Some(key_id),
        issued_at,
        expires_at,
    })
}
pub async fn validate_with_refresh(&self, token: &str) -> McpResult<JwtValidationResult> {
match self.validate(token).await {
Ok(result) => Ok(result),
Err(first_error) => {
warn!(
error = %first_error,
"JWT validation failed, refreshing JWKS and retrying"
);
self.jwks_client.refresh().await?;
self.validate(token).await.map_err(|e| {
error!(error = %e, "JWT validation failed after JWKS refresh");
e
})
}
}
}
/// Look up `key_id` in the JWKS and convert the matching JWK into a [`DecodingKey`].
///
/// The `_algorithm` argument is accepted but not consulted here.
///
/// # Errors
/// - Propagates errors from fetching the JWKS.
/// - `McpError::invalid_params` if no key with `key_id` exists.
/// - `McpError::internal` if the JWK cannot be turned into a decoding key.
async fn get_decoding_key(
    &self,
    key_id: &str,
    _algorithm: Algorithm,
) -> McpResult<DecodingKey> {
    let key_set = self.jwks_client.get_jwks().await?;
    let Some(matching_jwk) = key_set.find(key_id) else {
        error!(key_id = key_id, "Key ID not found in JWKS");
        return Err(McpError::invalid_params(format!(
            "Key ID '{key_id}' not found in JWKS"
        )));
    };
    match DecodingKey::from_jwk(matching_jwk) {
        Ok(decoding_key) => Ok(decoding_key),
        Err(e) => {
            error!(key_id = key_id, error = %e, "Failed to create decoding key from JWK");
            Err(McpError::internal(format!("Invalid JWK: {e}")))
        }
    }
}
/// The `iss` value this validator requires.
pub fn expected_issuer(&self) -> &str {
    self.expected_issuer.as_str()
}
/// The `aud` value this validator requires.
pub fn expected_audience(&self) -> &str {
    self.expected_audience.as_str()
}
}
/// Routes tokens to a per-issuer [`JwtValidator`] based on the token's `iss` claim.
#[derive(Debug)]
pub struct MultiIssuerValidator {
    /// Audience required for tokens from every registered issuer.
    expected_audience: String,
    /// Validators keyed by the exact issuer string.
    validators: std::collections::HashMap<String, Arc<JwtValidator>>,
    // No method in this module reads this field, hence the `dead_code` allowance.
    // Presumably reserved for a JWKS cache shared across issuers — confirm before relying on it.
    #[allow(dead_code)]
    jwks_cache: Arc<JwksCache>,
}
impl MultiIssuerValidator {
/// Create a router with no registered issuers for tokens addressed to `expected_audience`.
pub fn new(expected_audience: String) -> Self {
    let jwks_cache = Arc::new(JwksCache::new());
    Self {
        expected_audience,
        validators: Default::default(),
        jwks_cache,
    }
}
/// Register `issuer`, discovering its JWKS endpoint over the network (no SSRF checks).
/// An existing registration for the same issuer is replaced.
///
/// # Errors
/// Propagates any error returned by JWKS URI discovery.
pub async fn add_issuer(&mut self, issuer: String) -> McpResult<()> {
    let discovered_uri = JwtValidator::discover_jwks_uri(&issuer, None).await?;
    self.add_issuer_with_jwks_uri(issuer, discovered_uri);
    Ok(())
}
/// Register `issuer`, running JWKS discovery through `ssrf_validator` and attaching that
/// validator to the per-issuer [`JwtValidator`]. An existing registration is replaced.
///
/// # Errors
/// Propagates any error returned by JWKS URI discovery, including SSRF rejections.
pub async fn add_issuer_with_ssrf(
    &mut self,
    issuer: String,
    ssrf_validator: Arc<crate::ssrf::SsrfValidator>,
) -> McpResult<()> {
    let discovered_uri =
        JwtValidator::discover_jwks_uri(&issuer, Some(&*ssrf_validator)).await?;
    let client = Arc::new(JwksClient::new(discovered_uri));
    let audience = self.expected_audience.clone();
    let per_issuer = JwtValidator::with_jwks_client(issuer.clone(), audience, client)
        .with_ssrf_validator(ssrf_validator);
    self.validators.insert(issuer, Arc::new(per_issuer));
    Ok(())
}
/// Register `issuer` with an explicitly supplied `jwks_uri` (no discovery, no network I/O
/// here). An existing registration for the same issuer is replaced.
pub fn add_issuer_with_jwks_uri(&mut self, issuer: String, jwks_uri: String) {
    let client = Arc::new(JwksClient::new(jwks_uri));
    let audience = self.expected_audience.clone();
    let per_issuer = JwtValidator::with_jwks_client(issuer.clone(), audience, client);
    self.validators.insert(issuer, Arc::new(per_issuer));
}
/// Validate `token` by routing it to the validator registered for its `iss` claim.
///
/// Before routing, the header `alg` is checked against a fixed asymmetric-only allowlist.
/// The payload is then base64url-decoded **without signature verification**, solely to read
/// `iss`. The selected [`JwtValidator`] performs full verification via
/// `validate_with_refresh`.
///
/// # Errors
/// Returns `McpError::invalid_params` for any of these:
/// - an undecodable header or a disallowed algorithm;
/// - a token without exactly three segments;
/// - a payload that is not valid base64url JSON;
/// - a missing `iss`, or an unregistered issuer.
///
/// Otherwise returns whatever the routed validator returns.
pub async fn validate(&self, token: &str) -> McpResult<JwtValidationResult> {
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};

    const ALLOWED_ALGORITHMS: &[Algorithm] = &[
        Algorithm::ES256,
        Algorithm::ES384,
        Algorithm::RS256,
        Algorithm::RS384,
        Algorithm::RS512,
        Algorithm::PS256,
        Algorithm::PS384,
        Algorithm::PS512,
    ];

    let header = decode_header(token)
        .map_err(|e| McpError::invalid_params(format!("Invalid JWT format: {e}")))?;
    if !ALLOWED_ALGORITHMS.contains(&header.alg) {
        error!(algorithm = ?header.alg, "JWT algorithm not in allowlist");
        return Err(McpError::invalid_params(format!(
            "JWT algorithm {:?} not allowed. Only asymmetric algorithms (ES*, RS*, PS*) are permitted.",
            header.alg
        )));
    }

    // Exactly three dot-separated segments: header, payload, signature.
    let mut segments = token.split('.');
    let (Some(_), Some(encoded_payload), Some(_), None) = (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) else {
        return Err(McpError::invalid_params("Invalid JWT format".to_string()));
    };

    // Unverified read, used only to pick a validator; that validator checks the signature.
    let claims: StandardClaims = URL_SAFE_NO_PAD
        .decode(encoded_payload)
        .map_err(|e| McpError::invalid_params(format!("Invalid JWT payload encoding: {e}")))
        .and_then(|raw| {
            serde_json::from_slice(&raw)
                .map_err(|e| McpError::invalid_params(format!("Invalid JWT claims: {e}")))
        })?;

    let Some(issuer) = claims.iss else {
        return Err(McpError::invalid_params(
            "JWT missing iss (issuer) claim".to_string(),
        ));
    };

    match self.validators.get(&issuer) {
        Some(routed) => routed.validate_with_refresh(token).await,
        None => {
            error!(issuer = %issuer, "Unknown issuer");
            Err(McpError::invalid_params(format!(
                "Issuer '{issuer}' not supported"
            )))
        }
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const ISSUER: &str = "https://auth.example.com";
    const AUDIENCE: &str = "https://mcp.example.com";
    const JWKS_URI: &str = "https://auth.example.com/jwks";

    /// Validator pinned to a fixed JWKS URI, so construction performs no discovery request.
    fn fixture() -> JwtValidator {
        JwtValidator::with_jwks_uri(ISSUER.to_owned(), AUDIENCE.to_owned(), JWKS_URI.to_owned())
    }

    #[test]
    fn test_jwt_validator_creation_with_jwks_uri() {
        let validator = fixture();
        assert_eq!(validator.expected_issuer(), ISSUER);
        assert_eq!(validator.expected_audience(), AUDIENCE);
        assert_eq!(validator.clock_skew_leeway, Duration::from_secs(60));
        assert_eq!(validator.allowed_algorithms.len(), 3);
    }

    #[test]
    fn test_jwt_validator_custom_clock_skew() {
        let validator = fixture().with_clock_skew(Duration::from_secs(30));
        assert_eq!(validator.clock_skew_leeway, Duration::from_secs(30));
    }

    #[test]
    fn test_jwt_validator_custom_algorithms() {
        let validator = fixture().with_algorithms(vec![Algorithm::ES256]);
        assert_eq!(validator.allowed_algorithms, vec![Algorithm::ES256]);
    }

    #[test]
    fn test_multi_issuer_validator_creation() {
        let router = MultiIssuerValidator::new(AUDIENCE.to_owned());
        assert_eq!(router.expected_audience, AUDIENCE);
        assert!(router.validators.is_empty());
    }

    #[test]
    fn test_multi_issuer_validator_add_issuer() {
        let mut router = MultiIssuerValidator::new(AUDIENCE.to_owned());
        router.add_issuer_with_jwks_uri(ISSUER.to_owned(), JWKS_URI.to_owned());
        assert_eq!(router.validators.len(), 1);
        assert!(router.validators.contains_key(ISSUER));
    }
}