use super::types::{ClientMetadata, ClientMetadataError, ValidatedClientMetadata};
use crate::ssrf::{SsrfError, SsrfValidator};
use dashmap::DashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use thiserror::Error;
use tracing::{debug, warn};
/// Errors returned by [`MetadataFetcher`] while fetching and validating a client
/// metadata document.
#[derive(Debug, Error)]
pub enum FetcherError {
/// The client_id URL was rejected by the SSRF validator.
#[error("SSRF protection blocked request: {0}")]
SsrfBlocked(#[from] SsrfError),
/// HTTP-level failure: the client could not be built, the request or body read
/// failed, or the server answered with a non-2xx status. The string describes which.
#[error("HTTP request failed: {0}")]
HttpError(String),
/// The response body (declared or actual) exceeded `FetcherConfig::max_response_size`.
#[error("Response size limit exceeded")]
ResponseTooLarge,
/// The response body could not be deserialized as client metadata JSON.
#[error("Invalid JSON response: {0}")]
InvalidJson(String),
/// The parsed metadata was rejected by `ValidatedClientMetadata::new`.
#[error("Metadata validation failed: {0}")]
ValidationFailed(#[from] ClientMetadataError),
/// The per-client_id request budget for the current window is exhausted; carries
/// the offending client_id.
#[error("Rate limit exceeded for client_id: {0}")]
RateLimitExceeded(String),
/// Cache-related failure with a description.
/// NOTE(review): not constructed by the fetcher code in this module — confirm it is
/// still needed by other callers.
#[error("Cache error: {0}")]
CacheError(String),
}
/// A cached, already-validated metadata document together with its absolute expiry.
#[derive(Debug, Clone)]
struct CacheEntry {
/// The validated metadata returned to callers on a cache hit.
metadata: ValidatedClientMetadata,
/// Wall-clock instant after which the entry is treated as stale
/// (it is served only while `SystemTime::now() < expires_at`).
expires_at: SystemTime,
}
/// Fixed-window request counter for a single client_id.
#[derive(Debug, Clone)]
struct RateLimitEntry {
/// Number of requests admitted since `window_start` (rejected requests are not counted).
count: u32,
/// Wall-clock start of the current rate-limit window.
window_start: SystemTime,
}
/// Tunables for [`MetadataFetcher`]: HTTP limits, cache lifetimes and rate limiting.
#[derive(Debug, Clone)]
pub struct FetcherConfig {
/// Maximum accepted response body size, in bytes. Checked against the declared
/// `Content-Length` and against the body actually received.
pub max_response_size: usize,
/// Timeout configured on the underlying HTTP client (`reqwest::ClientBuilder::timeout`).
pub request_timeout: Duration,
/// Cache lifetime used when the response carries no usable `Cache-Control` directive.
pub default_cache_ttl: Duration,
/// Upper bound applied to a server-supplied `Cache-Control: max-age`.
pub max_cache_ttl: Duration,
/// Maximum number of `fetch` calls admitted per client_id URL within one
/// `rate_limit_window` (cache hits count too, since the limit is checked first).
pub rate_limit_max_requests: u32,
/// Length of the fixed rate-limit window.
pub rate_limit_window: Duration,
/// `User-Agent` header value sent with every request.
pub user_agent: String,
}
impl Default for FetcherConfig {
    /// Conservative defaults for fetching small metadata documents:
    /// a 5 KiB body cap, a 5 s request timeout, a 1 h default / 24 h maximum cache
    /// lifetime, and at most 10 requests per client_id per minute.
    fn default() -> Self {
        let one_minute = Duration::from_secs(60);
        let one_hour = one_minute * 60;
        let one_day = one_hour * 24;

        Self {
            user_agent: format!("TurboMCP/{}", env!("CARGO_PKG_VERSION")),
            max_response_size: 5 * 1024,
            request_timeout: Duration::from_secs(5),
            default_cache_ttl: one_hour,
            max_cache_ttl: one_day,
            rate_limit_max_requests: 10,
            rate_limit_window: one_minute,
        }
    }
}
/// Fetches client metadata documents from client_id URLs, applying SSRF validation,
/// per-URL rate limiting and an in-memory cache of validated results.
pub struct MetadataFetcher {
/// HTTP client built from the config, with redirects disabled.
client: reqwest::Client,
/// Validator every client_id URL must pass before any request is made.
ssrf_validator: Arc<SsrfValidator>,
/// Limits, cache lifetimes and rate-limit settings.
config: FetcherConfig,
/// Validated metadata keyed by client_id URL, each with an expiry time.
cache: Arc<DashMap<String, CacheEntry>>,
/// Fixed-window request counters keyed by client_id URL.
/// NOTE(review): entries are never removed by the code in this module, so the map
/// grows with the number of distinct client_id URLs seen.
rate_limits: Arc<DashMap<String, RateLimitEntry>>,
}
impl MetadataFetcher {
    /// Creates a fetcher using [`FetcherConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns [`FetcherError::HttpError`] if the underlying HTTP client cannot be built.
    pub fn new(ssrf_validator: SsrfValidator) -> Result<Self, FetcherError> {
        Self::with_config(ssrf_validator, FetcherConfig::default())
    }

    /// Creates a fetcher with an explicit configuration.
    ///
    /// # Errors
    ///
    /// Returns [`FetcherError::HttpError`] if the underlying HTTP client cannot be built.
    pub fn with_config(
        ssrf_validator: SsrfValidator,
        config: FetcherConfig,
    ) -> Result<Self, FetcherError> {
        let client = reqwest::Client::builder()
            .timeout(config.request_timeout)
            .user_agent(&config.user_agent)
            // Only `client_id_url` itself goes through the SSRF validator, so following a
            // redirect could reach a host that was never validated.
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .map_err(|e| FetcherError::HttpError(format!("Failed to create HTTP client: {}", e)))?;
        Ok(Self {
            client,
            ssrf_validator: Arc::new(ssrf_validator),
            config,
            cache: Arc::new(DashMap::new()),
            rate_limits: Arc::new(DashMap::new()),
        })
    }

    /// Fetches, validates and caches the client metadata document at `client_id_url`.
    ///
    /// Order of operations: SSRF validation, per-URL rate limiting, cache lookup, network
    /// fetch, size-limited body read, JSON parsing, metadata validation, cache insert.
    /// The rate limit is checked before the cache, so cache hits also count toward it.
    ///
    /// # Errors
    ///
    /// - [`FetcherError::SsrfBlocked`] if the URL fails SSRF validation.
    /// - [`FetcherError::RateLimitExceeded`] if this URL's budget for the window is spent.
    /// - [`FetcherError::HttpError`] on transport failures or a non-2xx status.
    /// - [`FetcherError::ResponseTooLarge`] if the body exceeds `max_response_size`.
    /// - [`FetcherError::InvalidJson`] if the body is not valid client metadata JSON.
    /// - [`FetcherError::ValidationFailed`] if the parsed metadata is rejected.
    pub async fn fetch(
        &self,
        client_id_url: &str,
    ) -> Result<ValidatedClientMetadata, FetcherError> {
        debug!("Validating client_id URL: {}", client_id_url);
        self.ssrf_validator.validate_url(client_id_url)?;
        self.check_rate_limit(client_id_url)?;
        if let Some(cached) = self.get_cached(client_id_url) {
            debug!("Returning cached metadata for: {}", client_id_url);
            return Ok(cached);
        }

        debug!("Fetching metadata from network: {}", client_id_url);
        let response = self
            .client
            .get(client_id_url)
            .send()
            .await
            .map_err(|e| FetcherError::HttpError(format!("Request failed: {}", e)))?;
        let status = response.status();
        if !status.is_success() {
            warn!(
                "Non-success status {} for client_id: {}",
                status, client_id_url
            );
            return Err(FetcherError::HttpError(format!(
                "HTTP {} {}",
                status.as_u16(),
                status.canonical_reason().unwrap_or("Unknown")
            )));
        }

        let cache_ttl = self.parse_cache_headers(&response);
        let body = self.read_body_limited(response).await?;
        let metadata: ClientMetadata = serde_json::from_slice(&body)
            .map_err(|e| FetcherError::InvalidJson(format!("Failed to parse JSON: {}", e)))?;
        let validated = ValidatedClientMetadata::new(metadata, client_id_url.to_string())?;
        self.cache_metadata(client_id_url, validated.clone(), cache_ttl);
        Ok(validated)
    }

    /// Reads the response body, failing as soon as it exceeds `max_response_size`.
    ///
    /// The body is pulled chunk by chunk instead of with `Response::bytes`, which buffers
    /// the entire body before its length can be checked. Without this, a server that omits
    /// `Content-Length` (e.g. chunked encoding) could make us buffer an unbounded body.
    ///
    /// # Errors
    ///
    /// [`FetcherError::ResponseTooLarge`] when the declared or received size exceeds the
    /// limit; [`FetcherError::HttpError`] when reading a chunk fails.
    async fn read_body_limited(
        &self,
        mut response: reqwest::Response,
    ) -> Result<Vec<u8>, FetcherError> {
        let limit = self.config.max_response_size;
        // Cheap early rejection when the server declares an oversized body up front.
        if let Some(content_length) = response.content_length()
            && content_length > limit as u64
        {
            return Err(FetcherError::ResponseTooLarge);
        }

        let declared = response
            .content_length()
            .and_then(|n| usize::try_from(n).ok())
            .unwrap_or(0);
        let mut body = Vec::with_capacity(declared.min(limit));
        while let Some(chunk) = response
            .chunk()
            .await
            .map_err(|e| FetcherError::HttpError(format!("Failed to read response: {}", e)))?
        {
            // Invariant: body.len() <= limit, so this subtraction cannot underflow.
            if chunk.len() > limit - body.len() {
                return Err(FetcherError::ResponseTooLarge);
            }
            body.extend_from_slice(&chunk);
        }
        Ok(body)
    }

    /// Records one request for `client_id` and enforces the fixed-window rate limit.
    ///
    /// Each key is admitted at most `rate_limit_max_requests` times per
    /// `rate_limit_window`; the window restarts on the first call after it has elapsed.
    ///
    /// # Errors
    ///
    /// [`FetcherError::RateLimitExceeded`] when the current window's budget is used up.
    /// Rejected calls do not consume budget.
    fn check_rate_limit(&self, client_id: &str) -> Result<(), FetcherError> {
        let now = SystemTime::now();
        // `entry` is a DashMap `RefMut`, which holds the shard's write lock until dropped,
        // so the read-modify-write below is not interleaved with other callers.
        let mut entry = self
            .rate_limits
            .entry(client_id.to_string())
            .or_insert(RateLimitEntry {
                count: 0,
                window_start: now,
            });

        // `duration_since` fails if the wall clock stepped backwards past the window
        // start; treat that as an expired window rather than freezing the key until the
        // clock catches up again.
        let window_expired = match now.duration_since(entry.window_start) {
            Ok(elapsed) => elapsed >= self.config.rate_limit_window,
            Err(_) => true,
        };
        if window_expired {
            entry.count = 0;
            entry.window_start = now;
        }

        if entry.count >= self.config.rate_limit_max_requests {
            warn!("Rate limit exceeded for client_id: {}", client_id);
            return Err(FetcherError::RateLimitExceeded(client_id.to_string()));
        }
        entry.count += 1;
        Ok(())
    }

    /// Returns a clone of the cached metadata for `client_id` if it has not expired.
    ///
    /// An expired entry is evicted as a side effect of the lookup.
    fn get_cached(&self, client_id: &str) -> Option<ValidatedClientMetadata> {
        let now = SystemTime::now();
        {
            let entry = self.cache.get(client_id)?;
            if now < entry.expires_at {
                return Some(entry.metadata.clone());
            }
        } // Read guard released here; removing while it is held would deadlock the shard.

        // Re-check expiry under the write lock: another task may have stored a fresh entry
        // between releasing the read guard and this call, and that entry must survive.
        self.cache
            .remove_if(client_id, |_, entry| now >= entry.expires_at);
        None
    }

    /// Stores `metadata` for `client_id`, expiring `ttl` from now.
    ///
    /// A zero TTL (e.g. `no-store` or `max-age=0`) stores nothing and drops any existing
    /// entry, so the most recent response's caching directive still takes effect.
    fn cache_metadata(&self, client_id: &str, metadata: ValidatedClientMetadata, ttl: Duration) {
        if ttl.is_zero() {
            debug!("Not caching metadata for {} (zero TTL)", client_id);
            self.cache.remove(client_id);
            return;
        }
        // `SystemTime + Duration` panics on overflow; an absurdly large configured TTL
        // should not be able to crash the fetcher.
        let Some(expires_at) = SystemTime::now().checked_add(ttl) else {
            warn!(
                "Cache TTL for {} is too large to represent; not caching",
                client_id
            );
            return;
        };
        debug!(
            "Caching metadata for {} with TTL of {}s",
            client_id,
            ttl.as_secs()
        );
        self.cache.insert(
            client_id.to_string(),
            CacheEntry {
                metadata,
                expires_at,
            },
        );
    }

    /// Derives the cache lifetime from the response's `Cache-Control` header.
    ///
    /// - `no-store` or `no-cache` anywhere in the header yields zero. They take
    ///   precedence over `max-age`, since this fetcher cannot revalidate.
    /// - Otherwise the first parseable `max-age`, capped at `max_cache_ttl`, is used.
    /// - Otherwise (no header, non-ASCII header, no usable directive) `default_cache_ttl`.
    ///
    /// Directive names are matched case-insensitively.
    fn parse_cache_headers(&self, response: &reqwest::Response) -> Duration {
        let Some(value) = response
            .headers()
            .get(reqwest::header::CACHE_CONTROL)
            .and_then(|header| header.to_str().ok())
        else {
            return self.config.default_cache_ttl;
        };

        let mut max_age: Option<u64> = None;
        for directive in value.split(',') {
            let (name, argument) = match directive.split_once('=') {
                Some((name, argument)) => (name.trim(), Some(argument.trim().trim_matches('"'))),
                None => (directive.trim(), None),
            };
            if name.eq_ignore_ascii_case("no-store") || name.eq_ignore_ascii_case("no-cache") {
                return Duration::ZERO;
            }
            if max_age.is_none() && name.eq_ignore_ascii_case("max-age") {
                max_age = argument.and_then(|seconds| seconds.parse().ok());
            }
        }

        match max_age {
            Some(seconds) => Duration::from_secs(seconds).min(self.config.max_cache_ttl),
            None => self.config.default_cache_ttl,
        }
    }

    /// Removes every cached metadata entry (rate-limit counters are left untouched).
    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    /// Returns a snapshot of cache occupancy, split into expired and still-valid entries.
    ///
    /// Both counts come from the same pass over the map. Concurrent inserts therefore
    /// cannot make `expired_entries` exceed `total_entries`.
    pub fn cache_stats(&self) -> CacheStats {
        let now = SystemTime::now();
        let (total_entries, expired_entries) =
            self.cache
                .iter()
                .fold((0usize, 0usize), |(total, expired), entry| {
                    (total + 1, expired + usize::from(now >= entry.expires_at))
                });
        CacheStats {
            total_entries,
            expired_entries,
            valid_entries: total_entries - expired_entries,
        }
    }
}
/// Snapshot of metadata-cache occupancy, as returned by `MetadataFetcher::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
/// Number of entries currently held, expired or not.
pub total_entries: usize,
/// Entries whose expiry time has passed but which have not been evicted yet
/// (expired entries are removed lazily when looked up).
pub expired_entries: usize,
/// `total_entries - expired_entries`: entries that would be served on lookup.
pub valid_entries: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssrf::SsrfPolicy;

    /// Builds a fetcher with the default SSRF policy and default configuration.
    fn default_fetcher() -> MetadataFetcher {
        MetadataFetcher::new(SsrfValidator::default()).expect("default fetcher must build")
    }

    /// Builds an empty-bodied response carrying the given `Cache-Control` value, if any.
    fn response_with_cache_control(cache_control: Option<&str>) -> reqwest::Response {
        let mut builder = http::Response::builder();
        if let Some(value) = cache_control {
            builder = builder.header("cache-control", value);
        }
        reqwest::Response::from(builder.body("").unwrap())
    }

    #[test]
    fn test_fetcher_creation() {
        let validator = SsrfValidator::default();
        let fetcher = MetadataFetcher::new(validator);
        assert!(fetcher.is_ok());
    }

    #[test]
    fn test_cache_ttl_parsing() {
        let fetcher = default_fetcher();
        let ttl = fetcher.parse_cache_headers(&response_with_cache_control(Some("max-age=3600")));
        assert_eq!(ttl, Duration::from_secs(3600));
    }

    #[test]
    fn test_cache_ttl_with_other_directives() {
        let fetcher = default_fetcher();
        let ttl =
            fetcher.parse_cache_headers(&response_with_cache_control(Some("public, max-age=120")));
        assert_eq!(ttl, Duration::from_secs(120));
    }

    #[test]
    fn test_cache_ttl_is_capped_at_max() {
        let fetcher = default_fetcher();
        let ttl =
            fetcher.parse_cache_headers(&response_with_cache_control(Some("max-age=999999999")));
        assert_eq!(ttl, fetcher.config.max_cache_ttl);
    }

    #[test]
    fn test_cache_ttl_defaults_without_header() {
        let fetcher = default_fetcher();
        let ttl = fetcher.parse_cache_headers(&response_with_cache_control(None));
        assert_eq!(ttl, fetcher.config.default_cache_ttl);
    }

    #[test]
    fn test_no_cache_and_no_store_disable_caching() {
        let fetcher = default_fetcher();
        for header in ["no-cache", "no-store"] {
            let ttl = fetcher.parse_cache_headers(&response_with_cache_control(Some(header)));
            assert_eq!(ttl, Duration::from_secs(0), "header: {header}");
        }
    }

    #[test]
    fn test_cache_stats() {
        let fetcher = default_fetcher();
        let stats = fetcher.cache_stats();
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.valid_entries, 0);
        assert_eq!(stats.expired_entries, 0);
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let validator = SsrfValidator::new(SsrfPolicy {
            allow_private_networks: true,
            allow_localhost: true,
            ..Default::default()
        });
        let config = FetcherConfig {
            rate_limit_max_requests: 2,
            rate_limit_window: Duration::from_secs(60),
            ..Default::default()
        };
        let fetcher = MetadataFetcher::with_config(validator, config).unwrap();
        let client_id = "https://example.com/metadata.json";
        assert!(fetcher.check_rate_limit(client_id).is_ok());
        assert!(fetcher.check_rate_limit(client_id).is_ok());
        assert!(matches!(
            fetcher.check_rate_limit(client_id),
            Err(FetcherError::RateLimitExceeded(id)) if id == client_id
        ));
    }

    #[test]
    fn test_rate_limit_is_per_client_id() {
        let config = FetcherConfig {
            rate_limit_max_requests: 1,
            ..Default::default()
        };
        let fetcher = MetadataFetcher::with_config(SsrfValidator::default(), config).unwrap();
        let first = "https://a.example/metadata.json";
        let second = "https://b.example/metadata.json";
        assert!(fetcher.check_rate_limit(first).is_ok());
        // Exhausting one key's budget must not affect another key.
        assert!(fetcher.check_rate_limit(second).is_ok());
        assert!(matches!(
            fetcher.check_rate_limit(first),
            Err(FetcherError::RateLimitExceeded(_))
        ));
    }
}