use super::types::{
AuthorizationServerMetadata, DiscoveryError, OIDCProviderMetadata, ValidatedDiscoveryMetadata,
};
use crate::ssrf::{SsrfError, SsrfValidator};
use dashmap::DashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use thiserror::Error;
use tracing::{debug, warn};
/// Errors produced while fetching and validating authorization-server
/// discovery metadata.
#[derive(Debug, Error)]
pub enum FetcherError {
    /// The discovery URL was rejected by the SSRF validator.
    #[error("SSRF protection blocked request: {0}")]
    SsrfBlocked(#[from] SsrfError),
    /// HTTP-level failure: building the client, sending the request,
    /// a non-success status code, or reading the response body.
    #[error("HTTP request failed: {0}")]
    HttpError(String),
    /// The response exceeded `FetcherConfig::max_response_size`, either by its
    /// declared Content-Length or by the bytes actually received.
    #[error("Response size limit exceeded")]
    ResponseTooLarge,
    /// The response body could not be deserialized as discovery metadata JSON.
    #[error("Invalid JSON response: {0}")]
    InvalidJson(String),
    /// The fetched metadata was rejected by `ValidatedDiscoveryMetadata`
    /// construction.
    #[error("Discovery validation failed: {0}")]
    ValidationFailed(#[from] DiscoveryError),
    /// RFC 8414 discovery failed and the OIDC Discovery fallback also failed;
    /// both underlying error messages are preserved.
    #[error("All discovery endpoints failed. RFC 8414: {oauth2_error}, OIDC: {oidc_error}")]
    AllEndpointsFailed {
        oauth2_error: String,
        oidc_error: String,
    },
    /// The issuer string is not an acceptable issuer URL (unparsable or not https).
    #[error("Invalid issuer URL: {0}")]
    InvalidIssuer(String),
    /// Cache-related failure.
    /// NOTE(review): not constructed by the fetcher in this module — confirm
    /// whether other code relies on it before removing.
    #[error("Cache error: {0}")]
    CacheError(String),
}
/// A cached discovery document plus the wall-clock time after which it is stale.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Validated metadata returned on a cache hit.
    metadata: ValidatedDiscoveryMetadata,
    /// The entry is served only while `SystemTime::now() < expires_at`.
    expires_at: SystemTime,
}
/// Tunables for `DiscoveryFetcher`.
#[derive(Debug, Clone)]
pub struct FetcherConfig {
    /// Maximum accepted response body size in bytes (default 10 KiB).
    pub max_response_size: usize,
    /// Timeout applied to each HTTP request (default 5 seconds).
    pub request_timeout: Duration,
    /// Fallback cache TTL when `Cache-Control` does not determine one
    /// (default 1 hour).
    pub default_cache_ttl: Duration,
    /// Upper bound applied to a server-supplied `max-age` (default 24 hours).
    pub max_cache_ttl: Duration,
    /// `User-Agent` header sent with every discovery request
    /// (default `TurboMCP/<crate version>`).
    pub user_agent: String,
    /// When `true`, OIDC Discovery is attempted after RFC 8414 discovery fails
    /// (default `true`).
    pub fallback_to_oidc: bool,
}
impl Default for FetcherConfig {
    /// Conservative defaults: 10 KiB bodies, a 5 second timeout, a 1 hour
    /// default TTL capped at 24 hours, and OIDC fallback enabled.
    fn default() -> Self {
        const KIB: usize = 1024;
        const HOUR_SECS: u64 = 60 * 60;

        let user_agent = concat!("TurboMCP/", env!("CARGO_PKG_VERSION")).to_owned();

        Self {
            fallback_to_oidc: true,
            user_agent,
            max_cache_ttl: Duration::from_secs(24 * HOUR_SECS),
            default_cache_ttl: Duration::from_secs(HOUR_SECS),
            request_timeout: Duration::from_secs(5),
            max_response_size: 10 * KIB,
        }
    }
}
/// Fetches OAuth 2.0 Authorization Server Metadata (RFC 8414), with an
/// optional OpenID Connect Discovery fallback. It applies SSRF validation, a
/// response-size limit and an in-memory TTL cache keyed by issuer string.
pub struct DiscoveryFetcher {
    /// HTTP client configured with the request timeout and user agent, with
    /// redirects disabled.
    client: reqwest::Client,
    /// Validator applied to every discovery URL before it is requested.
    ssrf_validator: Arc<SsrfValidator>,
    /// Limits, TTLs and fallback behaviour.
    config: FetcherConfig,
    /// Maps an issuer string to its cached metadata.
    cache: Arc<DashMap<String, CacheEntry>>,
}
impl DiscoveryFetcher {
    /// Creates a fetcher using [`FetcherConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns [`FetcherError::HttpError`] if the HTTP client cannot be built.
    pub fn new(ssrf_validator: SsrfValidator) -> Result<Self, FetcherError> {
        Self::with_config(ssrf_validator, FetcherConfig::default())
    }

    /// Creates a fetcher with an explicit configuration.
    ///
    /// # Errors
    ///
    /// Returns [`FetcherError::HttpError`] if the HTTP client cannot be built.
    pub fn with_config(
        ssrf_validator: SsrfValidator,
        config: FetcherConfig,
    ) -> Result<Self, FetcherError> {
        let client = reqwest::Client::builder()
            .timeout(config.request_timeout)
            .user_agent(&config.user_agent)
            // Only the initial discovery URL is SSRF-validated, so following a
            // redirect could reach an unvalidated host; refuse redirects.
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .map_err(|e| FetcherError::HttpError(format!("Failed to create HTTP client: {}", e)))?;

        Ok(Self {
            client,
            ssrf_validator: Arc::new(ssrf_validator),
            config,
            cache: Arc::new(DashMap::new()),
        })
    }

    /// Fetches and validates discovery metadata for `issuer`.
    ///
    /// The issuer must be an absolute `https` URL without query or fragment
    /// components (RFC 8414 §2). A fresh cached entry is returned without any
    /// network access. Otherwise RFC 8414 metadata is requested first. If that
    /// fails and [`FetcherConfig::fallback_to_oidc`] is set, OpenID Connect
    /// Discovery is tried next. Successful results are cached for the TTL
    /// derived from the response's `Cache-Control` header.
    ///
    /// # Errors
    ///
    /// - [`FetcherError::InvalidIssuer`] if `issuer` does not parse, is not
    ///   https, or carries a query or fragment.
    /// - [`FetcherError::AllEndpointsFailed`] if both discovery endpoints fail.
    /// - The RFC 8414 error itself when the OIDC fallback is disabled.
    pub async fn fetch(&self, issuer: &str) -> Result<ValidatedDiscoveryMetadata, FetcherError> {
        let issuer_url = url::Url::parse(issuer)
            .map_err(|e| FetcherError::InvalidIssuer(format!("Invalid URL: {}", e)))?;
        if issuer_url.scheme() != "https" {
            return Err(FetcherError::InvalidIssuer(
                "Issuer MUST use https scheme".to_string(),
            ));
        }
        // RFC 8414 §2: the issuer identifier has no query or fragment components.
        if issuer_url.query().is_some() || issuer_url.fragment().is_some() {
            return Err(FetcherError::InvalidIssuer(
                "Issuer MUST NOT contain query or fragment components".to_string(),
            ));
        }

        if let Some(cached) = self.get_cached(issuer) {
            debug!("Returning cached discovery metadata for: {}", issuer);
            return Ok(cached);
        }

        let oauth2_url = self.build_oauth2_discovery_url(&issuer_url)?;
        debug!("Trying RFC 8414 discovery: {}", oauth2_url);
        let oauth2_error = match self.fetch_oauth2(&oauth2_url, issuer).await {
            Ok(metadata) => {
                debug!("Successfully fetched RFC 8414 metadata for: {}", issuer);
                return Ok(metadata);
            }
            Err(e) => e,
        };
        debug!("RFC 8414 discovery failed: {}", oauth2_error);

        if !self.config.fallback_to_oidc {
            return Err(oauth2_error);
        }

        let oidc_url = self.build_oidc_discovery_url(&issuer_url)?;
        debug!("Trying OIDC Discovery fallback: {}", oidc_url);
        match self.fetch_oidc(&oidc_url, issuer).await {
            Ok(metadata) => {
                debug!("Successfully fetched OIDC metadata for: {}", issuer);
                Ok(metadata)
            }
            Err(oidc_error) => {
                warn!(
                    "Both RFC 8414 and OIDC Discovery failed for issuer: {}",
                    issuer
                );
                Err(FetcherError::AllEndpointsFailed {
                    oauth2_error: oauth2_error.to_string(),
                    oidc_error: oidc_error.to_string(),
                })
            }
        }
    }

    /// Builds the RFC 8414 §3.1 metadata URL.
    ///
    /// It inserts `/.well-known/oauth-authorization-server` between the host
    /// and the issuer path, with any trailing `/` removed from the path. For
    /// example, `https://example.com/tenant` becomes
    /// `https://example.com/.well-known/oauth-authorization-server/tenant`.
    fn build_oauth2_discovery_url(&self, issuer: &url::Url) -> Result<String, FetcherError> {
        let issuer_path = issuer.path().trim_end_matches('/');
        let mut url = issuer.clone();
        url.set_path(&format!(
            "/.well-known/oauth-authorization-server{}",
            issuer_path
        ));
        // Query and fragment are not part of a valid issuer and must not leak
        // into the well-known URL.
        url.set_query(None);
        url.set_fragment(None);
        Ok(url.to_string())
    }

    /// Builds the OpenID Connect Discovery 1.0 §4 configuration URL.
    ///
    /// It appends `/.well-known/openid-configuration` to the issuer path, with
    /// any trailing `/` removed from the path. For example,
    /// `https://example.com/tenant` becomes
    /// `https://example.com/tenant/.well-known/openid-configuration`.
    fn build_oidc_discovery_url(&self, issuer: &url::Url) -> Result<String, FetcherError> {
        let issuer_path = issuer.path().trim_end_matches('/');
        let mut url = issuer.clone();
        url.set_path(&format!("{}/.well-known/openid-configuration", issuer_path));
        url.set_query(None);
        url.set_fragment(None);
        Ok(url.to_string())
    }

    /// Performs the SSRF-checked GET for a discovery document.
    ///
    /// Returns the raw body together with the cache TTL derived from the
    /// response headers. The body is read chunk by chunk, so a response
    /// without (or with a lying) Content-Length is aborted as soon as it
    /// exceeds `max_response_size`, rather than being buffered whole first.
    async fn fetch_document(
        &self,
        discovery_url: &str,
    ) -> Result<(Vec<u8>, Duration), FetcherError> {
        self.ssrf_validator.validate_url(discovery_url)?;

        let mut response = self
            .client
            .get(discovery_url)
            .send()
            .await
            .map_err(|e| FetcherError::HttpError(format!("Request failed: {}", e)))?;

        let status = response.status();
        if !status.is_success() {
            return Err(FetcherError::HttpError(format!(
                "HTTP {} {}",
                status.as_u16(),
                status.canonical_reason().unwrap_or("Unknown")
            )));
        }

        let cache_ttl = self.parse_cache_headers(&response);
        let max_size = self.config.max_response_size;

        // Cheap early rejection when the server declares an oversized body.
        if let Some(content_length) = response.content_length()
            && content_length > max_size as u64
        {
            return Err(FetcherError::ResponseTooLarge);
        }

        let mut body = Vec::new();
        while let Some(chunk) = response
            .chunk()
            .await
            .map_err(|e| FetcherError::HttpError(format!("Failed to read response: {}", e)))?
        {
            if body.len().saturating_add(chunk.len()) > max_size {
                return Err(FetcherError::ResponseTooLarge);
            }
            body.extend_from_slice(&chunk);
        }

        Ok((body, cache_ttl))
    }

    /// Fetches, parses and validates RFC 8414 metadata, then caches the result.
    async fn fetch_oauth2(
        &self,
        discovery_url: &str,
        issuer: &str,
    ) -> Result<ValidatedDiscoveryMetadata, FetcherError> {
        let (body, cache_ttl) = self.fetch_document(discovery_url).await?;
        let metadata: AuthorizationServerMetadata = serde_json::from_slice(&body)
            .map_err(|e| FetcherError::InvalidJson(format!("Failed to parse JSON: {}", e)))?;
        let validated = ValidatedDiscoveryMetadata::new_oauth2(metadata, issuer.to_string())?;
        self.cache_metadata(issuer, validated.clone(), cache_ttl);
        Ok(validated)
    }

    /// Fetches, parses and validates OIDC provider metadata, then caches the result.
    async fn fetch_oidc(
        &self,
        discovery_url: &str,
        issuer: &str,
    ) -> Result<ValidatedDiscoveryMetadata, FetcherError> {
        let (body, cache_ttl) = self.fetch_document(discovery_url).await?;
        let metadata: OIDCProviderMetadata = serde_json::from_slice(&body)
            .map_err(|e| FetcherError::InvalidJson(format!("Failed to parse JSON: {}", e)))?;
        let validated = ValidatedDiscoveryMetadata::new_oidc(metadata, issuer.to_string())?;
        self.cache_metadata(issuer, validated.clone(), cache_ttl);
        Ok(validated)
    }

    /// Returns the cached metadata for `issuer` if it has not expired, evicting
    /// it if it has.
    fn get_cached(&self, issuer: &str) -> Option<ValidatedDiscoveryMetadata> {
        let now = SystemTime::now();
        {
            // The read guard must be released before removing: mutating the
            // map while holding a reference into it can deadlock.
            let entry = self.cache.get(issuer)?;
            if now < entry.expires_at {
                return Some(entry.metadata.clone());
            }
        }
        // Re-check expiry at removal time: another task may have stored a
        // fresh entry after the guard above was dropped.
        self.cache
            .remove_if(issuer, |_, entry| now >= entry.expires_at);
        None
    }

    /// Stores `metadata` for `issuer` for `ttl`.
    ///
    /// A zero TTL (`no-store`, `no-cache`, `max-age=0`) or one that would
    /// overflow the system clock is not stored. In either case any existing
    /// entry for the issuer is dropped instead.
    fn cache_metadata(&self, issuer: &str, metadata: ValidatedDiscoveryMetadata, ttl: Duration) {
        if ttl.is_zero() {
            debug!("Not caching discovery metadata for {} (zero TTL)", issuer);
            self.cache.remove(issuer);
            return;
        }
        let Some(expires_at) = SystemTime::now().checked_add(ttl) else {
            warn!(
                "Cache TTL of {}s for {} overflows the system clock; not caching",
                ttl.as_secs(),
                issuer
            );
            self.cache.remove(issuer);
            return;
        };
        debug!(
            "Caching discovery metadata for {} with TTL of {}s",
            issuer,
            ttl.as_secs()
        );
        self.cache.insert(
            issuer.to_string(),
            CacheEntry {
                metadata,
                expires_at,
            },
        );
    }

    /// Derives the cache TTL from every `Cache-Control` header on `response`.
    ///
    /// `no-store` and `no-cache` always yield zero, even when `max-age` is also
    /// present. Otherwise the first parsable `max-age` (token or quoted form)
    /// is used, capped at `max_cache_ttl`. If there is none,
    /// `default_cache_ttl` applies. Directive names are matched
    /// case-insensitively, as RFC 9111 requires.
    fn parse_cache_headers(&self, response: &reqwest::Response) -> Duration {
        let directives = response
            .headers()
            .get_all(reqwest::header::CACHE_CONTROL)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            .map(str::trim);

        let mut max_age: Option<Duration> = None;
        for directive in directives {
            let (name, argument) = match directive.split_once('=') {
                Some((name, argument)) => (name.trim(), Some(argument.trim().trim_matches('"'))),
                None => (directive, None),
            };
            if name.eq_ignore_ascii_case("no-store") || name.eq_ignore_ascii_case("no-cache") {
                return Duration::ZERO;
            }
            if max_age.is_none()
                && name.eq_ignore_ascii_case("max-age")
                && let Some(seconds) = argument.and_then(|arg| arg.parse::<u64>().ok())
            {
                max_age = Some(Duration::from_secs(seconds));
            }
        }

        max_age.map_or(self.config.default_cache_ttl, |ttl| {
            ttl.min(self.config.max_cache_ttl)
        })
    }

    /// Removes every cached entry.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    /// Returns a point-in-time count of cached entries.
    ///
    /// Totals and expired counts come from a single pass over the map, so
    /// `expired_entries <= total_entries` holds even while other tasks insert
    /// or remove entries concurrently.
    pub fn cache_stats(&self) -> CacheStats {
        let now = SystemTime::now();
        let (total_entries, expired_entries) =
            self.cache
                .iter()
                .fold((0usize, 0usize), |(total, expired), entry| {
                    (total + 1, expired + usize::from(now >= entry.expires_at))
                });

        CacheStats {
            total_entries,
            expired_entries,
            valid_entries: total_entries - expired_entries,
        }
    }
}

/// Point-in-time cache counts returned by `DiscoveryFetcher::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Entries currently stored, including expired ones not yet evicted.
    pub total_entries: usize,
    /// Entries whose expiry time has passed.
    pub expired_entries: usize,
    /// Entries still servable from the cache (`total_entries - expired_entries`).
    pub valid_entries: usize,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssrf::SsrfValidator;

    fn fetcher() -> DiscoveryFetcher {
        DiscoveryFetcher::new(SsrfValidator::default()).unwrap()
    }

    /// Builds an empty-bodied response with one `Cache-Control` line per value.
    fn response_with_cache_control(values: &[&str]) -> reqwest::Response {
        let builder = values
            .iter()
            .fold(http::Response::builder(), |builder, value| {
                builder.header("cache-control", *value)
            });
        reqwest::Response::from(builder.body("").unwrap())
    }

    #[test]
    fn test_fetcher_creation() {
        assert!(DiscoveryFetcher::new(SsrfValidator::default()).is_ok());
    }

    #[test]
    fn test_oauth2_discovery_url_building() {
        let fetcher = fetcher();
        let cases = [
            (
                "https://example.com",
                "https://example.com/.well-known/oauth-authorization-server",
            ),
            (
                "https://example.com/issuer1",
                "https://example.com/.well-known/oauth-authorization-server/issuer1",
            ),
            (
                "https://example.com/issuer1/",
                "https://example.com/.well-known/oauth-authorization-server/issuer1",
            ),
            (
                "https://example.com/issuer1?x=1#frag",
                "https://example.com/.well-known/oauth-authorization-server/issuer1",
            ),
        ];
        for (issuer, expected) in cases {
            let issuer = url::Url::parse(issuer).unwrap();
            assert_eq!(fetcher.build_oauth2_discovery_url(&issuer).unwrap(), expected);
        }
    }

    #[test]
    fn test_oidc_discovery_url_building() {
        let fetcher = fetcher();
        let cases = [
            (
                "https://example.com",
                "https://example.com/.well-known/openid-configuration",
            ),
            (
                "https://example.com/issuer1",
                "https://example.com/issuer1/.well-known/openid-configuration",
            ),
            (
                "https://example.com/realms/demo/",
                "https://example.com/realms/demo/.well-known/openid-configuration",
            ),
            (
                "https://example.com/issuer1?x=1#frag",
                "https://example.com/issuer1/.well-known/openid-configuration",
            ),
        ];
        for (issuer, expected) in cases {
            let issuer = url::Url::parse(issuer).unwrap();
            assert_eq!(fetcher.build_oidc_discovery_url(&issuer).unwrap(), expected);
        }
    }

    #[test]
    fn test_cache_ttl_parsing() {
        let fetcher = fetcher();
        let ttl = fetcher.parse_cache_headers(&response_with_cache_control(&["max-age=3600"]));
        assert_eq!(ttl, Duration::from_secs(3600));
    }

    #[test]
    fn test_cache_ttl_no_store_overrides_max_age() {
        let fetcher = fetcher();
        let ttl = fetcher
            .parse_cache_headers(&response_with_cache_control(&["max-age=3600, no-store"]));
        assert_eq!(ttl, Duration::ZERO);
    }

    #[test]
    fn test_cache_ttl_considers_every_header_line() {
        let fetcher = fetcher();
        let ttl =
            fetcher.parse_cache_headers(&response_with_cache_control(&["public", "no-cache"]));
        assert_eq!(ttl, Duration::ZERO);
    }

    #[test]
    fn test_cache_ttl_case_insensitive_and_quoted() {
        let fetcher = fetcher();
        let ttl = fetcher.parse_cache_headers(&response_with_cache_control(&["Max-Age=\"120\""]));
        assert_eq!(ttl, Duration::from_secs(120));
    }

    #[test]
    fn test_cache_ttl_is_capped() {
        let fetcher = fetcher();
        let ttl = fetcher.parse_cache_headers(&response_with_cache_control(&["max-age=999999"]));
        assert_eq!(ttl, FetcherConfig::default().max_cache_ttl);
    }

    #[test]
    fn test_cache_ttl_defaults_without_header() {
        let fetcher = fetcher();
        let ttl = fetcher.parse_cache_headers(&response_with_cache_control(&[]));
        assert_eq!(ttl, FetcherConfig::default().default_cache_ttl);
    }

    #[test]
    fn test_cache_stats() {
        let stats = fetcher().cache_stats();
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.valid_entries, 0);
        assert_eq!(stats.expired_entries, 0);
    }
}