use crate::client::FetchOptions;
use crate::convert::{
extract_headings, extract_metadata, filter_excessive_newlines, html_to_markdown, html_to_text,
is_html, is_markdown_content_type, is_plain_text_content_type, strip_boilerplate,
};
use crate::error::FetchError;
use crate::fetchers::Fetcher;
use crate::file_saver::FileSaver;
use crate::types::{FetchRequest, FetchResponse, HttpMethod};
use crate::DEFAULT_USER_AGENT;
use async_trait::async_trait;
use bytes::Bytes;
use futures::StreamExt;
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_DISPOSITION, LOCATION, USER_AGENT};
use std::time::Duration;
use tracing::{debug, error, warn};
use url::Url;
/// Lower-case `Content-Type` prefixes that `is_binary_content_type` treats as
/// binary. `fetch` returns an error response for these instead of reading the body.
const BINARY_PREFIXES: &[&str] = &[
"image/",
"audio/",
"video/",
"application/octet-stream",
"application/pdf",
"application/zip",
"application/gzip",
"application/x-tar",
"application/x-rar",
"application/x-7z",
"application/vnd.ms-",
"application/vnd.openxmlformats",
"font/",
];
/// Used as both the connect timeout and the client-level timeout in
/// `build_client_for_url`.
const FIRST_BYTE_TIMEOUT: Duration = Duration::from_secs(1);
/// Deadline handed to `read_body_with_timeout` when streaming a response body.
const BODY_TIMEOUT: Duration = Duration::from_secs(30);
/// Marker appended to the returned content when the body was cut short.
const TRUNCATION_MESSAGE: &str = "\n\n[..content truncated...]";
/// Maximum number of redirect hops followed before failing with "too many redirects".
const MAX_REDIRECTS: usize = 10;
/// Body size cap in bytes (10 MiB) used when `FetchOptions::max_body_size` is `None`.
const DEFAULT_MAX_BODY_SIZE: usize = 10 * 1024 * 1024;
/// General-purpose HTTP fetcher.
///
/// It carries no state; its `Fetcher::matches` implementation accepts every URL.
pub struct DefaultFetcher;

impl DefaultFetcher {
    /// Creates a new default fetcher.
    pub fn new() -> Self {
        DefaultFetcher
    }
}

impl Default for DefaultFetcher {
    /// Equivalent to [`DefaultFetcher::new`].
    fn default() -> Self {
        DefaultFetcher
    }
}
/// Assembles the header set used as client default headers for a request.
///
/// * `User-Agent` comes from `options.user_agent`, falling back to
///   `DEFAULT_USER_AGENT` when unset or not a valid header value.
/// * `Accept` is `accept`, falling back to `*/*` when it is not a valid header value.
/// * `If-None-Match` / `If-Modified-Since` are added from the request when present;
///   values that are not valid header values are silently skipped.
fn build_headers(options: &FetchOptions, accept: &str, request: &FetchRequest) -> HeaderMap {
    let mut map = HeaderMap::new();

    let ua = options.user_agent.as_deref().unwrap_or(DEFAULT_USER_AGENT);
    let ua_value = match HeaderValue::from_str(ua) {
        Ok(value) => value,
        Err(_) => HeaderValue::from_static(DEFAULT_USER_AGENT),
    };
    map.insert(USER_AGENT, ua_value);

    let accept_value = match HeaderValue::from_str(accept) {
        Ok(value) => value,
        Err(_) => HeaderValue::from_static("*/*"),
    };
    map.insert(ACCEPT, accept_value);

    // Conditional-request validators are best-effort: bad values are dropped.
    let etag_value = request
        .if_none_match
        .as_ref()
        .and_then(|etag| HeaderValue::from_str(etag).ok());
    if let Some(value) = etag_value {
        map.insert(reqwest::header::IF_NONE_MATCH, value);
    }

    let since_value = request
        .if_modified_since
        .as_ref()
        .and_then(|date| HeaderValue::from_str(date).ok());
    if let Some(value) = since_value {
        map.insert(reqwest::header::IF_MODIFIED_SINCE, value);
    }

    map
}
/// Adds bot-auth signature headers when `options.bot_auth` is configured.
///
/// The request is signed for the URL's host. On success `signature` and
/// `signature-input` are inserted, plus `signature-agent` when the signer
/// provides one; any value that is not a valid header value is skipped.
/// If there is no bot-auth config, the URL has no host, or signing fails
/// (logged as a warning), the headers are returned unchanged.
#[cfg(feature = "bot-auth")]
fn apply_bot_auth_if_enabled(
    mut headers: HeaderMap,
    options: &FetchOptions,
    url: &Url,
) -> HeaderMap {
    let Some(bot_auth) = options.bot_auth.as_ref() else {
        return headers;
    };
    let Some(authority) = url.host_str() else {
        return headers;
    };

    let signed = match bot_auth.sign_request(authority) {
        Ok(signed) => signed,
        Err(e) => {
            warn!("Bot-auth signing failed: {e}");
            return headers;
        }
    };

    if let Ok(value) = HeaderValue::from_str(&signed.signature) {
        headers.insert("signature", value);
    }
    if let Ok(value) = HeaderValue::from_str(&signed.signature_input) {
        headers.insert("signature-input", value);
    }
    let agent_value = signed
        .signature_agent
        .as_ref()
        .and_then(|fqdn| HeaderValue::from_str(fqdn).ok());
    if let Some(value) = agent_value {
        headers.insert("signature-agent", value);
    }

    headers
}
/// Variant compiled when the `bot-auth` feature is disabled: returns the
/// headers unchanged so call sites do not need their own `cfg` guards.
#[cfg(not(feature = "bot-auth"))]
fn apply_bot_auth_if_enabled(headers: HeaderMap, _options: &FetchOptions, _url: &Url) -> HeaderMap {
headers
}
/// Response header values collected by `extract_response_meta` and copied
/// into the returned `FetchResponse`.
struct ResponseMeta {
// Raw `Content-Type` header value, if present and readable as a string.
content_type: Option<String>,
// Raw `Last-Modified` header value.
last_modified: Option<String>,
// Raw `ETag` header value.
etag: Option<String>,
// `Content-Length` parsed as an integer; `None` when missing or unparsable.
content_length: Option<u64>,
// Filename derived by `extract_filename` (Content-Disposition, else URL path).
filename: Option<String>,
}
/// Collects the response headers this module reports back to callers.
///
/// Header values that cannot be read as strings are treated as absent, and a
/// `Content-Length` that does not parse as `u64` yields `None`. `url` is only
/// used as the fallback source for the filename.
fn extract_response_meta(headers: &HeaderMap, url: &str) -> ResponseMeta {
    /// Returns the named header as text, or `None` if missing or unreadable.
    fn header_text<'h>(headers: &'h HeaderMap, name: &str) -> Option<&'h str> {
        headers.get(name).and_then(|value| value.to_str().ok())
    }

    let content_type = header_text(headers, "content-type").map(str::to_string);
    let last_modified = header_text(headers, "last-modified").map(str::to_string);
    let etag = header_text(headers, "etag").map(str::to_string);
    let content_length =
        header_text(headers, "content-length").and_then(|raw| raw.parse::<u64>().ok());
    let filename = extract_filename(headers, url);

    ResponseMeta {
        content_type,
        last_modified,
        etag,
        content_length,
        filename,
    }
}
#[async_trait]
impl Fetcher for DefaultFetcher {
    /// Identifier of this fetcher.
    fn name(&self) -> &'static str {
        "default"
    }

    /// Accepts every URL.
    fn matches(&self, _url: &Url) -> bool {
        true
    }

    /// Fetches `request.url` and returns its (optionally converted) textual content.
    ///
    /// Flow:
    /// 1. Build headers (including conditional and optional bot-auth headers) and
    ///    follow redirects manually via `send_request_following_redirects`.
    /// 2. `304 Not Modified` → metadata-only response, no content.
    /// 3. `HEAD` → metadata-only response.
    /// 4. Binary content types → response with an `error` message, body not read.
    /// 5. Otherwise read the body (bounded by `BODY_TIMEOUT` and the size cap),
    ///    decode it lossily as UTF-8 and convert to markdown/text/raw depending
    ///    on the request and the enabled options.
    ///
    /// # Errors
    /// * `FetchError::MissingUrl` when the URL is empty.
    /// * `FetchError::InvalidUrlScheme` when the URL does not parse.
    /// * Any error from sending the request or following redirects.
    async fn fetch(
        &self,
        request: &FetchRequest,
        options: &FetchOptions,
    ) -> Result<FetchResponse, FetchError> {
        if request.url.is_empty() {
            return Err(FetchError::MissingUrl);
        }
        let method = request.effective_method();
        // A conversion is only honoured when both requested and enabled in options.
        let wants_markdown = options.enable_markdown && request.wants_markdown();
        let wants_text = options.enable_text && request.wants_text();
        let max_body_size = options.max_body_size.unwrap_or(DEFAULT_MAX_BODY_SIZE);
        let accept = if wants_markdown {
            "text/html, text/markdown, text/plain, */*;q=0.8"
        } else if wants_text {
            "text/html, text/plain, */*;q=0.8"
        } else {
            "*/*"
        };
        let headers = build_headers(options, accept, request);
        let parsed_url = url::Url::parse(&request.url).map_err(|_| FetchError::InvalidUrlScheme)?;
        let headers = apply_bot_auth_if_enabled(headers, options, &parsed_url);
        let reqwest_method = match method {
            HttpMethod::Get => reqwest::Method::GET,
            HttpMethod::Head => reqwest::Method::HEAD,
        };
        let (response, redirect_chain) =
            send_request_following_redirects(parsed_url, reqwest_method, headers, options).await?;
        let status_code = response.status().as_u16();
        let final_url = response.url().to_string();
        let meta = extract_response_meta(response.headers(), &final_url);

        // Conditional request satisfied: there is no body to process.
        if status_code == 304 {
            return Ok(FetchResponse {
                url: final_url,
                status_code,
                content_type: meta.content_type,
                last_modified: meta.last_modified,
                etag: meta.etag,
                ..Default::default()
            });
        }

        if method == HttpMethod::Head {
            return Ok(FetchResponse {
                url: final_url,
                status_code,
                content_type: meta.content_type,
                size: meta.content_length,
                last_modified: meta.last_modified,
                etag: meta.etag,
                filename: meta.filename,
                method: Some("HEAD".to_string()),
                redirect_chain,
                ..Default::default()
            });
        }

        // Binary payloads are reported with an error message; the body is not read.
        if let Some(ref ct) = meta.content_type {
            if is_binary_content_type(ct) {
                return Ok(FetchResponse {
                    url: final_url,
                    status_code,
                    content_type: meta.content_type,
                    size: meta.content_length,
                    last_modified: meta.last_modified,
                    etag: meta.etag,
                    filename: meta.filename,
                    redirect_chain,
                    error: Some(
                        "Binary content is not supported. Only textual content (HTML, text, JSON, etc.) can be fetched."
                            .to_string(),
                    ),
                    ..Default::default()
                });
            }
        }

        let (body, truncated) = read_body_with_timeout(response, BODY_TIMEOUT, max_body_size).await;
        let size = body.len() as u64;
        // `into_owned` reuses the buffer when the lossy conversion already allocated.
        let content = String::from_utf8_lossy(&body).into_owned();
        // Paywall detection runs on the unconverted body.
        let is_paywall = detect_paywall(&content);
        let is_html_content = is_html(&meta.content_type, &content);
        let wants_main = request.wants_main_content();

        let page_metadata = if is_html_content {
            let mut pm = extract_metadata(&content);
            pm.headings = extract_headings(&content);
            if pm.is_empty() {
                None
            } else {
                Some(pm)
            }
        } else {
            None
        };

        let (format, final_content) =
            if is_markdown_content_type(&meta.content_type) && wants_markdown {
                debug!("Content-type is markdown; skipping HTML conversion");
                ("markdown".to_string(), content)
            } else if is_plain_text_content_type(&meta.content_type) && wants_text {
                debug!("Content-type is plain text; skipping HTML conversion");
                ("text".to_string(), content)
            } else if is_html_content {
                let html = if wants_main {
                    strip_boilerplate(&content)
                } else {
                    content
                };
                if wants_markdown {
                    ("markdown".to_string(), html_to_markdown(&html))
                } else if wants_text {
                    ("text".to_string(), html_to_text(&html))
                } else {
                    ("raw".to_string(), html)
                }
            } else {
                ("raw".to_string(), content)
            };

        let mut final_content = filter_excessive_newlines(&final_content);
        if truncated {
            final_content.push_str(TRUNCATION_MESSAGE);
        }
        // Word count is taken after conversion and includes the truncation marker.
        let word_count = count_words(&final_content);

        Ok(FetchResponse {
            url: final_url,
            status_code,
            content_type: meta.content_type,
            size: Some(size),
            last_modified: meta.last_modified,
            etag: meta.etag,
            filename: meta.filename,
            format: Some(format),
            content: Some(final_content),
            truncated: truncated.then_some(true),
            metadata: page_metadata,
            word_count: Some(word_count),
            redirect_chain,
            is_paywall: is_paywall.then_some(true),
            ..Default::default()
        })
    }

    /// Fetches `request.url` and writes the raw body to `request.save_to_file`
    /// through `saver`. Falls back to [`Self::fetch`] when no save path is set.
    ///
    /// Unlike `fetch`, binary content is not rejected and no conversion is done.
    /// A `304 Not Modified` response returns metadata only and does not call the
    /// saver, so the empty body of a satisfied conditional request is never
    /// written to the target path.
    ///
    /// # Errors
    /// * `FetchError::MissingUrl` when the URL is empty.
    /// * `FetchError::InvalidUrlScheme` when the URL does not parse.
    /// * `FetchError::SaveError` when the saver fails.
    /// * Any error from sending the request or following redirects.
    async fn fetch_to_file(
        &self,
        request: &FetchRequest,
        options: &FetchOptions,
        saver: &dyn FileSaver,
    ) -> Result<FetchResponse, FetchError> {
        let save_path = match &request.save_to_file {
            Some(path) => path.clone(),
            None => return self.fetch(request, options).await,
        };
        if request.url.is_empty() {
            return Err(FetchError::MissingUrl);
        }
        let method = request.effective_method();
        let max_body_size = options.max_body_size.unwrap_or(DEFAULT_MAX_BODY_SIZE);
        let headers = build_headers(options, "*/*", request);
        let parsed_url = url::Url::parse(&request.url).map_err(|_| FetchError::InvalidUrlScheme)?;
        let headers = apply_bot_auth_if_enabled(headers, options, &parsed_url);
        let reqwest_method = match method {
            HttpMethod::Get => reqwest::Method::GET,
            HttpMethod::Head => reqwest::Method::HEAD,
        };
        let (response, redirect_chain) =
            send_request_following_redirects(parsed_url, reqwest_method, headers, options).await?;
        let status_code = response.status().as_u16();
        let final_url = response.url().to_string();
        let meta = extract_response_meta(response.headers(), &final_url);

        // Same handling as `fetch`: a satisfied conditional request has no body,
        // so nothing is passed to the saver.
        if status_code == 304 {
            return Ok(FetchResponse {
                url: final_url,
                status_code,
                content_type: meta.content_type,
                last_modified: meta.last_modified,
                etag: meta.etag,
                ..Default::default()
            });
        }

        if method == HttpMethod::Head {
            return Ok(FetchResponse {
                url: final_url,
                status_code,
                content_type: meta.content_type,
                size: meta.content_length,
                last_modified: meta.last_modified,
                etag: meta.etag,
                filename: meta.filename,
                method: Some("HEAD".to_string()),
                redirect_chain,
                ..Default::default()
            });
        }

        let (body, truncated) = read_body_with_timeout(response, BODY_TIMEOUT, max_body_size).await;
        let size = body.len() as u64;
        let save_result = saver
            .save(&save_path, &body)
            .await
            .map_err(|e| FetchError::SaveError(e.to_string()))?;

        Ok(FetchResponse {
            url: final_url,
            status_code,
            content_type: meta.content_type,
            size: Some(size),
            last_modified: meta.last_modified,
            etag: meta.etag,
            filename: meta.filename,
            truncated: truncated.then_some(true),
            saved_path: Some(save_result.path),
            bytes_written: Some(save_result.bytes_written),
            redirect_chain,
            ..Default::default()
        })
    }
}
/// Sends the request and follows redirects manually, one hop at a time.
///
/// A fresh client is built for every hop via `build_client_for_url`, and each
/// redirect target is checked by `redirect_target` before it is followed.
///
/// Returns the first non-redirect response together with the list of URLs
/// that redirected (in order, excluding the final URL).
///
/// # Errors
/// * `FetchError::RequestError("too many redirects")` when more than
///   `MAX_REDIRECTS` hops would be needed.
/// * Any error from building the client, sending the request, or validating
///   a redirect target.
async fn send_request_following_redirects(
    initial_url: Url,
    method: reqwest::Method,
    headers: HeaderMap,
    options: &FetchOptions,
) -> Result<(reqwest::Response, Vec<String>), FetchError> {
    let mut current_url = initial_url;
    let mut redirect_chain = Vec::new();

    // `0..=MAX_REDIRECTS` allows the initial request plus MAX_REDIRECTS hops.
    for redirect_count in 0..=MAX_REDIRECTS {
        let client = build_client_for_url(&current_url, headers.clone(), options)?;
        let response = client
            .request(method.clone(), current_url.clone())
            .send()
            .await
            .map_err(FetchError::from_reqwest)?;

        let Some(next_url) = redirect_target(&current_url, &response, options)? else {
            return Ok((response, redirect_chain));
        };

        // The last permitted request still answered with a redirect.
        if redirect_count == MAX_REDIRECTS {
            return Err(FetchError::RequestError("too many redirects".to_string()));
        }

        debug!(
            from = %current_url,
            to = %next_url,
            hop = redirect_count + 1,
            "Following redirect with IP validation"
        );
        redirect_chain.push(current_url.to_string());
        current_url = next_url;
    }

    unreachable!("redirect loop must return before exhausting iterations");
}
/// Builds a single-use client for one request hop to `url`.
///
/// The client never follows redirects itself, sends `headers` as default
/// headers, and ignores proxy settings unless `options.respect_proxy_env` is
/// set. When `options.dns_policy.block_private` is enabled and the URL has a
/// host, the host is resolved and validated up front and the client is told
/// to use that address for the host.
///
/// # Errors
/// * `FetchError::BlockedUrl` when DNS validation rejects the host.
/// * `FetchError::ClientBuildError` when the client cannot be constructed.
fn build_client_for_url(
    url: &Url,
    headers: HeaderMap,
    options: &FetchOptions,
) -> Result<reqwest::Client, FetchError> {
    // NOTE(review): reqwest documents `timeout` as a total deadline that also
    // covers the response body — confirm a 1s value here does not cut body
    // reads short of BODY_TIMEOUT.
    let mut builder = reqwest::Client::builder()
        .default_headers(headers)
        .connect_timeout(FIRST_BYTE_TIMEOUT)
        .timeout(FIRST_BYTE_TIMEOUT)
        .redirect(reqwest::redirect::Policy::none());

    if !options.respect_proxy_env {
        builder = builder.no_proxy();
    }

    // Only pin the address when private-network blocking is on and a host exists.
    let host_to_pin = if options.dns_policy.block_private {
        url.host_str()
    } else {
        None
    };
    if let Some(host) = host_to_pin {
        let port = url.port_or_known_default().unwrap_or(80);
        let validated_addr = options
            .dns_policy
            .resolve_and_validate(host, port)
            .map_err(|_| FetchError::BlockedUrl)?;
        builder = builder.resolve(host, validated_addr);
    }

    builder.build().map_err(FetchError::ClientBuildError)
}
/// Determines where (if anywhere) `response` redirects to.
///
/// Returns `Ok(None)` for non-3xx responses and for `304`. Otherwise the
/// `Location` header is resolved against `base_url` (so relative locations
/// work), restricted to `http`/`https`, and passed through
/// `options.validate_redirect_target`.
///
/// # Errors
/// * `FetchError::RequestError` when `Location` is missing, unreadable as a
///   string, or not a valid URL.
/// * `FetchError::InvalidUrlScheme` for non-HTTP(S) targets.
/// * Whatever `validate_redirect_target` returns on rejection.
fn redirect_target(
    base_url: &Url,
    response: &reqwest::Response,
    options: &FetchOptions,
) -> Result<Option<Url>, FetchError> {
    let status = response.status();
    let should_follow = status.is_redirection() && status.as_u16() != 304;
    if !should_follow {
        return Ok(None);
    }

    let Some(raw_location) = response.headers().get(LOCATION) else {
        return Err(FetchError::RequestError(
            "redirect response missing Location header".to_string(),
        ));
    };
    let Ok(location) = raw_location.to_str() else {
        return Err(FetchError::RequestError(
            "redirect Location header is not valid UTF-8".to_string(),
        ));
    };

    let target = match base_url.join(location) {
        Ok(url) => url,
        Err(_) => {
            return Err(FetchError::RequestError(
                "redirect Location is not a valid URL".to_string(),
            ));
        }
    };

    if !matches!(target.scheme(), "http" | "https") {
        return Err(FetchError::InvalidUrlScheme);
    }

    options.validate_redirect_target(base_url, &target)?;
    Ok(Some(target))
}
/// Returns `true` when the lower-cased `content_type` starts with one of
/// `BINARY_PREFIXES`.
fn is_binary_content_type(content_type: &str) -> bool {
    let normalized = content_type.to_lowercase();
    for prefix in BINARY_PREFIXES {
        if normalized.starts_with(*prefix) {
            return true;
        }
    }
    false
}
fn extract_filename(headers: &HeaderMap, url: &str) -> Option<String> {
if let Some(disposition) = headers.get(CONTENT_DISPOSITION) {
if let Ok(value) = disposition.to_str() {
if let Some(filename) = parse_content_disposition_filename(value) {
return Some(filename);
}
}
}
if let Ok(parsed) = url::Url::parse(url) {
if let Some(mut segments) = parsed.path_segments() {
if let Some(last) = segments.next_back() {
if last.contains('.') && !last.is_empty() {
return Some(last.to_string());
}
}
}
}
None
}
/// Extracts the `filename` parameter from a `Content-Disposition` header value.
///
/// The quoted form (`filename="name"`) is tried first, then the bare form
/// (`filename=name`, terminated by whitespace or `;`). The parameter name is
/// matched case-insensitively. An empty filename (e.g. `filename=""`) yields
/// `None`, so callers can fall back to another source.
///
/// The extended `filename*=` form is not handled.
fn parse_content_disposition_filename(value: &str) -> Option<String> {
    // ASCII lower-casing keeps byte offsets identical to `value`, so positions
    // found in `lowered` can be used to slice the original string.
    let lowered = value.to_ascii_lowercase();

    for pattern in ["filename=\"", "filename="] {
        let Some(start) = lowered.find(pattern) else {
            continue;
        };
        let rest = &value[start + pattern.len()..];

        let candidate = if pattern.ends_with('"') {
            // Quoted form: only valid when the closing quote is present.
            rest.find('"').map(|end| &rest[..end])
        } else {
            let end = rest
                .find(|c: char| c.is_whitespace() || c == ';')
                .unwrap_or(rest.len());
            // Strips stray quotes, e.g. from an unterminated quoted value.
            Some(rest[..end].trim_matches('"'))
        };

        if let Some(name) = candidate.filter(|name| !name.is_empty()) {
            return Some(name.to_string());
        }
    }

    None
}
async fn read_body_with_timeout(
response: reqwest::Response,
timeout: Duration,
max_size: usize,
) -> (Bytes, bool) {
let mut body = Vec::new();
let mut stream = response.bytes_stream();
let deadline = tokio::time::Instant::now() + timeout;
loop {
let chunk_future = stream.next();
let timeout_future = tokio::time::sleep_until(deadline);
tokio::select! {
chunk = chunk_future => {
match chunk {
Some(Ok(bytes)) => {
let remaining = max_size.saturating_sub(body.len());
if remaining == 0 {
warn!("Body size limit reached ({}), truncating", max_size);
return (Bytes::from(body), true);
}
if bytes.len() > remaining {
body.extend_from_slice(&bytes[..remaining]);
warn!("Body size limit reached ({}), truncating", max_size);
return (Bytes::from(body), true);
}
body.extend_from_slice(&bytes);
}
Some(Err(e)) => {
error!("Error reading body chunk: {}", e);
let has_content = !body.is_empty();
return (Bytes::from(body), has_content);
}
None => {
return (Bytes::from(body), false);
}
}
}
_ = timeout_future => {
warn!("Body timeout reached, returning partial content");
return (Bytes::from(body), true);
}
}
}
}
/// Counts whitespace-separated words in `text`.
fn count_words(text: &str) -> u64 {
    let words = text.split_whitespace();
    words.count() as u64
}
/// Lower-case phrases whose presence in a page marks it as likely paywalled.
const PAYWALL_INDICATORS: &[&str] = &[
    "paywall",
    "subscribe to read",
    "subscribe to continue",
    "subscription required",
    "premium content",
    "members only",
    "sign in to read",
    "log in to read",
    "create a free account",
    "already a subscriber",
    "unlock this article",
    "get unlimited access",
    "start your free trial",
];

/// Heuristic paywall check: case-insensitive substring search for any of
/// `PAYWALL_INDICATORS` in `html`.
fn detect_paywall(html: &str) -> bool {
    let haystack = html.to_lowercase();
    for phrase in PAYWALL_INDICATORS {
        if haystack.contains(*phrase) {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dns::DnsPolicy;
use crate::types::FetchRequest;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
#[test]
fn test_is_binary_content_type() {
assert!(is_binary_content_type("image/png"));
assert!(is_binary_content_type("image/jpeg"));
assert!(is_binary_content_type("audio/mp3"));
assert!(is_binary_content_type("video/mp4"));
assert!(is_binary_content_type("application/pdf"));
assert!(is_binary_content_type("application/octet-stream"));
assert!(is_binary_content_type("application/zip"));
assert!(is_binary_content_type("application/vnd.ms-excel"));
assert!(is_binary_content_type(
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
));
assert!(is_binary_content_type("font/woff2"));
assert!(!is_binary_content_type("text/html"));
assert!(!is_binary_content_type("text/plain"));
assert!(!is_binary_content_type("application/json"));
assert!(!is_binary_content_type("application/javascript"));
}
#[test]
fn test_parse_content_disposition_filename() {
assert_eq!(
parse_content_disposition_filename("attachment; filename=\"file.pdf\""),
Some("file.pdf".to_string())
);
assert_eq!(
parse_content_disposition_filename("attachment; filename=file.pdf"),
Some("file.pdf".to_string())
);
assert_eq!(
parse_content_disposition_filename("inline; filename=\"report.xlsx\"; size=1234"),
Some("report.xlsx".to_string())
);
assert_eq!(parse_content_disposition_filename("inline"), None);
}
#[test]
fn test_extract_filename_from_url() {
let headers = HeaderMap::new();
assert_eq!(
extract_filename(&headers, "https://example.com/path/to/file.pdf"),
Some("file.pdf".to_string())
);
assert_eq!(
extract_filename(&headers, "https://example.com/path/to/document"),
None
);
assert_eq!(extract_filename(&headers, "https://example.com/"), None);
}
#[test]
fn test_default_fetcher_matches_all() {
let fetcher = DefaultFetcher::new();
let url = Url::parse("https://example.com").unwrap();
assert!(fetcher.matches(&url));
let url = Url::parse("https://github.com/owner/repo").unwrap();
assert!(fetcher.matches(&url));
}
#[tokio::test]
async fn test_manual_redirect_following() {
let destination = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/final"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("redirected")
.insert_header("content-type", "text/plain"),
)
.mount(&destination)
.await;
let origin = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/start"))
.respond_with(
ResponseTemplate::new(302)
.insert_header("location", format!("{}/final", destination.uri())),
)
.mount(&origin)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
enable_markdown: true,
enable_text: true,
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/start", origin.uri())).as_markdown();
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.status_code, 200);
assert_eq!(response.content.as_deref(), Some("redirected"));
}
#[tokio::test]
async fn test_redirect_target_handles_relative_location() {
let origin = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/start"))
.respond_with(ResponseTemplate::new(302).insert_header("location", "/final"))
.mount(&origin)
.await;
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap();
let base_url = Url::parse(&format!("{}/start", origin.uri())).unwrap();
let response = client.get(base_url.clone()).send().await.unwrap();
let redirect = redirect_target(&base_url, &response, &FetchOptions::default()).unwrap();
assert_eq!(
redirect.unwrap(),
Url::parse(&format!("{}/final", origin.uri())).unwrap()
);
}
#[tokio::test]
async fn test_redirect_target_rejects_non_http_location() {
let origin = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/start"))
.respond_with(
ResponseTemplate::new(302).insert_header("location", "file:///etc/passwd"),
)
.mount(&origin)
.await;
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap();
let base_url = Url::parse(&format!("{}/start", origin.uri())).unwrap();
let response = client.get(base_url.clone()).send().await.unwrap();
let redirect = redirect_target(&base_url, &response, &FetchOptions::default());
assert!(matches!(redirect, Err(FetchError::InvalidUrlScheme)));
}
#[tokio::test]
async fn test_skip_conversion_for_markdown_content_type() {
let server = MockServer::start().await;
let md_body = "# Already Markdown\n\nThis is **already** formatted.";
Mock::given(method("GET"))
.and(path("/doc"))
.respond_with(
ResponseTemplate::new(200).set_body_raw(md_body, "text/markdown; charset=utf-8"),
)
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
enable_markdown: true,
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/doc", server.uri())).as_markdown();
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.format.as_deref(), Some("markdown"));
assert!(response
.content
.as_deref()
.unwrap()
.contains("# Already Markdown"));
assert!(response.content.as_deref().unwrap().contains("**already**"));
}
#[tokio::test]
async fn test_skip_conversion_for_plain_text_content_type() {
let server = MockServer::start().await;
let text_body = "Just plain text\nwith newlines.";
Mock::given(method("GET"))
.and(path("/plain"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string(text_body)
.insert_header("content-type", "text/plain"),
)
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
enable_text: true,
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/plain", server.uri())).as_text();
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.format.as_deref(), Some("text"));
assert!(response
.content
.as_deref()
.unwrap()
.contains("Just plain text"));
}
#[tokio::test]
async fn test_markdown_content_type_without_markdown_request_returns_raw() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/doc"))
.respond_with(
ResponseTemplate::new(200).set_body_raw("# Title", "text/markdown; charset=utf-8"),
)
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
enable_markdown: true,
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/doc", server.uri()));
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.format.as_deref(), Some("raw"));
assert!(response.content.as_deref().unwrap().contains("# Title"));
}
#[tokio::test]
async fn test_plain_text_content_type_without_text_request_returns_raw() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/plain"))
.respond_with(ResponseTemplate::new(200).set_body_raw("hello world", "text/plain"))
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/plain", server.uri()));
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.format.as_deref(), Some("raw"));
}
#[cfg(feature = "bot-auth")]
#[tokio::test]
async fn test_bot_auth_headers_sent() {
use crate::bot_auth::BotAuthConfig;
use wiremock::matchers::header_exists;
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/authed"))
.and(header_exists("signature"))
.and(header_exists("signature-input"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("ok")
.insert_header("content-type", "text/plain"),
)
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
enable_markdown: true,
dns_policy: DnsPolicy::allow_all(),
bot_auth: Some(BotAuthConfig::from_seed([10u8; 32])),
..Default::default()
};
let request = FetchRequest::new(format!("{}/authed", server.uri())).as_markdown();
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.status_code, 200);
assert_eq!(response.content.as_deref(), Some("ok"));
}
#[cfg(feature = "bot-auth")]
#[tokio::test]
async fn test_bot_auth_signature_agent_header_sent() {
use crate::bot_auth::BotAuthConfig;
use wiremock::matchers::{header, header_exists};
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/agent"))
.and(header_exists("signature"))
.and(header_exists("signature-input"))
.and(header("signature-agent", "bot.example.com"))
.respond_with(ResponseTemplate::new(200).set_body_string("agent ok"))
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
bot_auth: Some(BotAuthConfig::from_seed([11u8; 32]).with_agent_fqdn("bot.example.com")),
..Default::default()
};
let request = FetchRequest::new(format!("{}/agent", server.uri()));
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.status_code, 200);
}
#[tokio::test]
async fn test_etag_returned_in_response() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/page"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("content")
.insert_header("content-type", "text/plain")
.insert_header("etag", "\"abc123\""),
)
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/page", server.uri()));
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.status_code, 200);
assert_eq!(response.etag.as_deref(), Some("\"abc123\""));
}
#[tokio::test]
async fn test_conditional_fetch_304_not_modified() {
use wiremock::matchers::header;
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/page"))
.and(header("if-none-match", "\"abc123\""))
.respond_with(ResponseTemplate::new(304).insert_header("etag", "\"abc123\""))
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request =
FetchRequest::new(format!("{}/page", server.uri())).if_none_match("\"abc123\"");
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.status_code, 304);
assert_eq!(response.etag.as_deref(), Some("\"abc123\""));
assert!(response.content.is_none());
assert!(response.format.is_none());
}
#[tokio::test]
async fn test_conditional_fetch_if_modified_since() {
use wiremock::matchers::header_exists;
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/page"))
.and(header_exists("if-modified-since"))
.respond_with(ResponseTemplate::new(304))
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/page", server.uri()))
.if_modified_since("Wed, 21 Oct 2015 07:28:00 GMT");
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.status_code, 304);
assert!(response.content.is_none());
}
#[test]
fn test_count_words() {
assert_eq!(count_words("hello world"), 2);
assert_eq!(count_words(""), 0);
assert_eq!(count_words(" one two three "), 3);
assert_eq!(count_words("word"), 1);
}
#[test]
fn test_detect_paywall() {
assert!(detect_paywall("<div class=\"paywall\">Subscribe</div>"));
assert!(detect_paywall("<p>Subscribe to read the full article</p>"));
assert!(detect_paywall("<span>Already a subscriber? Log in</span>"));
assert!(detect_paywall("<div>Unlock this article</div>"));
assert!(!detect_paywall("<p>This is a normal article</p>"));
assert!(!detect_paywall("<h1>Hello World</h1><p>Free content</p>"));
}
#[tokio::test]
async fn test_word_count_in_response() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/article"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("Hello world this is a test")
.insert_header("content-type", "text/plain"),
)
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/article", server.uri()));
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.word_count, Some(6));
}
#[tokio::test]
async fn test_redirect_chain_tracked() {
let destination = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/final"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("arrived")
.insert_header("content-type", "text/plain"),
)
.mount(&destination)
.await;
let origin = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/start"))
.respond_with(
ResponseTemplate::new(302)
.insert_header("location", format!("{}/final", destination.uri())),
)
.mount(&origin)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/start", origin.uri()));
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.status_code, 200);
assert_eq!(response.redirect_chain.len(), 1);
assert!(response.redirect_chain[0].contains("/start"));
}
#[tokio::test]
async fn test_no_redirect_chain_for_direct_response() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/direct"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("direct")
.insert_header("content-type", "text/plain"),
)
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/direct", server.uri()));
let response = fetcher.fetch(&request, &options).await.unwrap();
assert!(response.redirect_chain.is_empty());
}
#[tokio::test]
async fn test_paywall_detection() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/paywalled"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("<html><body><div class='paywall'>Subscribe to read the full article</div><p>Preview...</p></body></html>")
.insert_header("content-type", "text/html"),
)
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
enable_markdown: true,
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/paywalled", server.uri())).as_markdown();
let response = fetcher.fetch(&request, &options).await.unwrap();
assert_eq!(response.is_paywall, Some(true));
}
#[tokio::test]
async fn test_no_paywall_for_normal_content() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/free"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("<html><body><p>This is free content</p></body></html>")
.insert_header("content-type", "text/html"),
)
.mount(&server)
.await;
let fetcher = DefaultFetcher::new();
let options = FetchOptions {
enable_markdown: true,
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/free", server.uri())).as_markdown();
let response = fetcher.fetch(&request, &options).await.unwrap();
assert!(response.is_paywall.is_none());
}
}