#[cfg(feature = "bot-auth")]
use crate::bot_auth::BotAuthConfig;
use crate::dns::DnsPolicy;
use crate::error::FetchError;
use crate::fetchers::FetcherRegistry;
use crate::types::{FetchRequest, FetchResponse};
use url::Url;
/// Options controlling a single fetch or a batch of fetches.
///
/// The derived `Default` leaves every flag `false`, every list empty and every
/// `Option` `None`; `dns_policy` takes `DnsPolicy::default()`, which blocks
/// private addresses (asserted by `test_fetch_options_default`). The
/// `fetch`/`batch_fetch` entry points start from this default but turn on
/// `enable_markdown` and `enable_text`.
#[derive(Debug, Clone, Default)]
pub struct FetchOptions {
/// User-Agent override. NOTE(review): presumably `None` means the fetcher's
/// built-in default — confirm in the fetcher implementations.
pub user_agent: Option<String>,
/// URL prefixes to allow. Not read in this module; semantics (e.g. whether an
/// empty list means "allow all") are defined by the fetchers — TODO confirm.
pub allow_prefixes: Vec<String>,
/// URL prefixes to block. Not read in this module; semantics are defined by
/// the fetchers — TODO confirm.
pub block_prefixes: Vec<String>,
/// Request markdown output. Not read in this module; presumably consumed by
/// the fetchers — confirm.
pub enable_markdown: bool,
/// Request plain-text output. Not read in this module; presumably consumed by
/// the fetchers — confirm.
pub enable_text: bool,
/// DNS resolution policy. Its default has `block_private` set.
pub dns_policy: DnsPolicy,
/// Optional response body size limit. NOTE(review): unit assumed to be bytes
/// and `None` assumed to mean "no limit" — confirm in the fetchers.
pub max_body_size: Option<usize>,
/// Whether saving the response to a file is permitted. Not read in this
/// module — TODO confirm exact behavior in the fetchers.
pub enable_save_to_file: bool,
/// Whether proxy settings from the environment are honored. Not read in this
/// module; presumably the conventional `*_PROXY` variables — confirm.
pub respect_proxy_env: bool,
/// Port allow-list checked by `validate_url`. Empty means any port is allowed;
/// otherwise the URL's explicit port, or its scheme's known default, must be
/// listed, and a URL with neither is rejected.
pub allowed_ports: Vec<u16>,
/// Host deny-list checked by `validate_url`. A plain rule must equal the host;
/// a rule starting with `.` matches that domain and all of its subdomains.
/// Comparison is ASCII case-insensitive and ignores trailing dots.
pub blocked_hosts: Vec<String>,
/// When `true`, `validate_redirect_target` rejects redirects whose target host
/// differs from the current URL's host.
pub same_host_redirects_only: bool,
/// Bot-authentication settings (only with the `bot-auth` feature). Not read
/// in this module.
#[cfg(feature = "bot-auth")]
pub bot_auth: Option<BotAuthConfig>,
}
impl FetchOptions {
    /// Checks `url` against the configured host and port restrictions.
    ///
    /// The host check runs first; the port check only runs if it passes.
    ///
    /// # Errors
    ///
    /// Returns [`FetchError::BlockedUrl`] if the host matches a `blocked_hosts`
    /// rule, or if `allowed_ports` is non-empty and the URL's port is not in it.
    pub(crate) fn validate_url(&self, url: &Url) -> Result<(), FetchError> {
        self.validate_host(url)
            .and_then(|()| self.validate_port(url))
    }

    /// Checks whether a redirect from `current_url` to `next_url` may be followed.
    ///
    /// The target must always pass [`Self::validate_url`]. When
    /// `same_host_redirects_only` is set, the two URLs must also share the same
    /// normalized host (two host-less URLs compare equal).
    ///
    /// # Errors
    ///
    /// Returns [`FetchError::BlockedUrl`] when either check fails.
    pub(crate) fn validate_redirect_target(
        &self,
        current_url: &Url,
        next_url: &Url,
    ) -> Result<(), FetchError> {
        self.validate_url(next_url)?;
        if !self.same_host_redirects_only {
            return Ok(());
        }
        if normalized_host(current_url) == normalized_host(next_url) {
            Ok(())
        } else {
            Err(FetchError::BlockedUrl)
        }
    }

    /// Rejects URLs whose normalized host matches any `blocked_hosts` rule.
    /// A URL without a host has nothing to match and is accepted.
    fn validate_host(&self, url: &Url) -> Result<(), FetchError> {
        match normalized_host(url) {
            Some(host)
                if self
                    .blocked_hosts
                    .iter()
                    .any(|rule| host_matches_rule(&host, rule)) =>
            {
                Err(FetchError::BlockedUrl)
            }
            _ => Ok(()),
        }
    }

    /// Enforces the `allowed_ports` allow-list when it is non-empty.
    fn validate_port(&self, url: &Url) -> Result<(), FetchError> {
        // An empty allow-list means "no port restriction".
        if self.allowed_ports.is_empty() {
            return Ok(());
        }
        // Fall back to the scheme's default port; a URL with neither is rejected.
        match url.port_or_known_default() {
            Some(port) if self.allowed_ports.contains(&port) => Ok(()),
            _ => Err(FetchError::BlockedUrl),
        }
    }
}
/// Returns the URL's host in canonical form for comparisons: ASCII-lowercased
/// with any trailing dots removed. Returns `None` when the URL has no host.
fn normalized_host(url: &Url) -> Option<String> {
    let raw_host = url.host_str()?;
    Some(raw_host.trim_end_matches('.').to_ascii_lowercase())
}
/// Reports whether an already-normalized `host` matches a `blocked_hosts` rule.
///
/// The rule is normalized like hosts are (trailing dots stripped, ASCII
/// lowercase). A rule beginning with `.` matches the bare domain and every
/// subdomain of it (`.internal` matches `internal` and `api.internal`, but not
/// `notinternal`); any other rule must match the host exactly.
fn host_matches_rule(host: &str, rule: &str) -> bool {
    let rule = rule.trim_end_matches('.').to_ascii_lowercase();
    match rule.strip_prefix('.') {
        Some(domain) => {
            // Subdomain match: whatever precedes `domain` must end in a dot.
            host == domain
                || matches!(host.strip_suffix(domain), Some(prefix) if prefix.ends_with('.'))
        }
        None => host == rule,
    }
}
pub async fn fetch(req: FetchRequest) -> Result<FetchResponse, FetchError> {
let options = FetchOptions {
enable_markdown: true,
enable_text: true,
..Default::default()
};
fetch_with_options(req, options).await
}
/// Fetches a single URL using the caller-supplied `options`.
///
/// A fresh default [`FetcherRegistry`] is built for each call and performs the
/// actual request.
///
/// # Errors
///
/// Returns [`FetchError::MissingUrl`] when `req.url` is empty; otherwise
/// propagates whatever the registry returns.
pub async fn fetch_with_options(
    req: FetchRequest,
    options: FetchOptions,
) -> Result<FetchResponse, FetchError> {
    // Reject an empty URL up front, before any fetcher is constructed.
    if req.url.is_empty() {
        return Err(FetchError::MissingUrl);
    }
    FetcherRegistry::with_defaults().fetch(req, options).await
}
const DEFAULT_BATCH_CONCURRENCY: usize = 5;
pub async fn batch_fetch(
requests: Vec<FetchRequest>,
concurrency: Option<usize>,
) -> Vec<Result<FetchResponse, FetchError>> {
let options = FetchOptions {
enable_markdown: true,
enable_text: true,
..Default::default()
};
batch_fetch_with_options(requests, options, concurrency).await
}
pub async fn batch_fetch_with_options(
requests: Vec<FetchRequest>,
options: FetchOptions,
concurrency: Option<usize>,
) -> Vec<Result<FetchResponse, FetchError>> {
use futures::stream::{self, StreamExt};
use std::sync::Arc;
let concurrency = concurrency.unwrap_or(DEFAULT_BATCH_CONCURRENCY).max(1);
let num_requests = requests.len();
let options = Arc::new(options);
let mut indexed_results: Vec<(usize, Result<FetchResponse, FetchError>)> =
stream::iter(requests.into_iter().enumerate())
.map(|(idx, req)| {
let options = Arc::clone(&options);
async move {
let registry = FetcherRegistry::with_defaults();
let result = registry.fetch(req, (*options).clone()).await;
(idx, result)
}
})
.buffer_unordered(concurrency)
.collect()
.await;
indexed_results.sort_by_key(|(idx, _)| *idx);
let mut results: Vec<Result<FetchResponse, FetchError>> = (0..num_requests)
.map(|_| Err(FetchError::MissingUrl))
.collect();
for (idx, result) in indexed_results {
results[idx] = result;
}
results
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_fetch_empty_url() {
        let req = FetchRequest::new("");
        let result = fetch(req).await;
        assert!(matches!(result, Err(FetchError::MissingUrl)));
    }

    #[tokio::test]
    async fn test_fetch_invalid_scheme() {
        let req = FetchRequest::new("ftp://example.com");
        let result = fetch(req).await;
        assert!(matches!(result, Err(FetchError::InvalidUrlScheme)));
    }

    // Pure construction check: nothing is awaited, so no async runtime is needed.
    #[test]
    fn test_fetch_options_default() {
        let options = FetchOptions::default();
        assert!(options.user_agent.is_none());
        assert!(options.allow_prefixes.is_empty());
        assert!(options.block_prefixes.is_empty());
        assert!(!options.enable_markdown);
        assert!(!options.enable_text);
        assert!(options.dns_policy.block_private);
        assert!(options.max_body_size.is_none());
        assert!(!options.enable_save_to_file);
        assert!(!options.respect_proxy_env);
        assert!(options.allowed_ports.is_empty());
        assert!(options.blocked_hosts.is_empty());
        assert!(!options.same_host_redirects_only);
    }

    #[test]
    fn test_validate_url_blocks_configured_host_and_port() {
        let options = FetchOptions {
            allowed_ports: vec![443],
            blocked_hosts: vec!["localhost".to_string(), ".internal".to_string()],
            ..Default::default()
        };
        assert!(matches!(
            options.validate_url(&Url::parse("https://api.internal").unwrap()),
            Err(FetchError::BlockedUrl)
        ));
        assert!(matches!(
            options.validate_url(&Url::parse("https://example.com:8443").unwrap()),
            Err(FetchError::BlockedUrl)
        ));
        assert!(options
            .validate_url(&Url::parse("https://example.com").unwrap())
            .is_ok());
    }

    #[test]
    fn test_host_matches_rule_exact_and_suffix() {
        assert!(host_matches_rule("localhost", "LocalHost."));
        assert!(!host_matches_rule("api.localhost", "localhost"));
        assert!(host_matches_rule("internal", ".internal"));
        assert!(host_matches_rule("api.internal", ".internal"));
        assert!(!host_matches_rule("notinternal", ".internal"));
    }

    #[test]
    fn test_validate_redirect_target_blocks_cross_host_when_enabled() {
        let options = FetchOptions {
            same_host_redirects_only: true,
            ..Default::default()
        };
        let current = Url::parse("https://example.com/start").unwrap();
        let next = Url::parse("https://www.example.com/end").unwrap();
        assert!(matches!(
            options.validate_redirect_target(&current, &next),
            Err(FetchError::BlockedUrl)
        ));
        // Same host differing only in case / trailing dot is still allowed.
        let same = Url::parse("https://EXAMPLE.com./end").unwrap();
        assert!(options.validate_redirect_target(&current, &same).is_ok());
    }

    #[test]
    fn test_validate_redirect_target_allows_cross_host_when_disabled() {
        let options = FetchOptions::default();
        let current = Url::parse("https://example.com/start").unwrap();
        let next = Url::parse("https://www.example.com/end").unwrap();
        assert!(options.validate_redirect_target(&current, &next).is_ok());
    }

    #[tokio::test]
    async fn test_batch_fetch_multiple_urls() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/page1"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("Page 1")
                    .insert_header("content-type", "text/plain"),
            )
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/page2"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("Page 2")
                    .insert_header("content-type", "text/plain"),
            )
            .mount(&server)
            .await;

        let requests = vec![
            FetchRequest::new(format!("{}/page1", server.uri())),
            FetchRequest::new(format!("{}/page2", server.uri())),
        ];
        let options = FetchOptions {
            enable_markdown: true,
            dns_policy: DnsPolicy::allow_all(),
            ..Default::default()
        };
        let results = batch_fetch_with_options(requests, options, None).await;

        assert_eq!(results.len(), 2);
        assert!(results[0]
            .as_ref()
            .unwrap()
            .content
            .as_deref()
            .unwrap()
            .contains("Page 1"));
        assert!(results[1]
            .as_ref()
            .unwrap()
            .content
            .as_deref()
            .unwrap()
            .contains("Page 2"));
    }

    #[tokio::test]
    async fn test_batch_fetch_partial_failure() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/ok"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("OK")
                    .insert_header("content-type", "text/plain"),
            )
            .mount(&server)
            .await;

        let requests = vec![
            FetchRequest::new(format!("{}/ok", server.uri())),
            // Empty URL: this entry must fail without affecting the other one.
            FetchRequest::new(""),
        ];
        let options = FetchOptions {
            dns_policy: DnsPolicy::allow_all(),
            ..Default::default()
        };
        let results = batch_fetch_with_options(requests, options, None).await;

        assert_eq!(results.len(), 2);
        assert!(results[0].is_ok());
        assert!(results[1].is_err());
    }

    // Only checks that a concurrency limit of 1 still yields one result per
    // request; it does not measure how many requests ran in parallel.
    #[tokio::test]
    async fn test_batch_fetch_respects_concurrency_limit() {
        let requests = vec![FetchRequest::new(""), FetchRequest::new("")];
        let results = batch_fetch(requests, Some(1)).await;
        assert_eq!(results.len(), 2);
        assert!(results[0].is_err());
        assert!(results[1].is_err());
    }

    #[tokio::test]
    async fn test_batch_fetch_empty_input() {
        let results = batch_fetch(vec![], None).await;
        assert!(results.is_empty());
    }
}