use std::collections::HashMap;
use std::time::Duration;
use reqwest::header::{CONTENT_TYPE, HeaderMap, USER_AGENT};
use crate::error::{CrawlError, classify_reqwest_error, error_chain_string};
use crate::net::ssrf::validate_url;
#[cfg(not(target_arch = "wasm32"))]
use crate::types::CookieInfo;
use crate::types::WafClassifier;
use crate::types::{AuthConfig, CrawlConfig, ResponseMeta};
use crate::waf::TomlClassifier;
#[derive(Debug, Clone, Default)]
#[allow(dead_code)]
pub struct BrowserExtras {
pub eval_result: Option<serde_json::Value>,
pub network_events: Vec<crate::types::ResponseMeta>,
pub cookies: Vec<crate::types::CookieInfo>,
}
pub struct HttpResponse {
pub status: u16,
pub content_type: String,
pub body: String,
pub body_bytes: Vec<u8>,
#[allow(dead_code)]
pub headers: std::collections::HashMap<String, Vec<String>>,
#[allow(dead_code)]
pub browser_extras: Option<BrowserExtras>,
#[allow(dead_code)]
pub final_url: String,
}
pub(crate) async fn http_fetch(
url: &str,
config: &CrawlConfig,
extra_headers: &std::collections::HashMap<String, String>,
client: &reqwest::Client,
) -> Result<HttpResponse, CrawlError> {
let initial_url = url::Url::parse(url).map_err(|e| CrawlError::SsrfPolicyViolation {
url: url.to_string(),
reason: format!("invalid URL: {e}"),
})?;
validate_url(&initial_url, &config.ssrf)
.await
.map_err(|e| CrawlError::SsrfPolicyViolation {
url: url.to_string(),
reason: e.to_string(),
})?;
let mut current_url = initial_url.clone();
let mut final_url_str: String;
let mut redirects_followed = 0u8;
loop {
let mut req = client.get(current_url.to_string());
if let Some(ref ua) = config.user_agent {
req = req.header(USER_AGENT, ua.as_str());
} else {
req = req.header(USER_AGENT, concat!("crawlberg/", env!("CARGO_PKG_VERSION")));
}
match config.auth {
Some(AuthConfig::Basic {
ref username,
ref password,
}) => {
req = req.basic_auth(username, Some(password));
}
Some(AuthConfig::Bearer { ref token }) => {
req = req.bearer_auth(token);
}
Some(AuthConfig::Header { ref name, ref value }) => {
req = req.header(name.as_str(), value.as_str());
}
None => {}
}
for (k, v) in &config.custom_headers {
req = req.header(k.as_str(), v.as_str());
}
for (k, v) in extra_headers {
req = req.header(k.as_str(), v.as_str());
}
let resp = req.send().await.map_err(|e| classify_reqwest_error(&e))?;
let status = resp.status().as_u16();
final_url_str = resp.url().to_string();
let content_type = resp
.headers()
.get_all(CONTENT_TYPE)
.iter()
.next_back()
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_owned();
let headers = resp.headers().clone();
if (300..400).contains(&status) {
let location_header = headers
.get(reqwest::header::LOCATION)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
if let Some(location) = location_header {
let next_url = match current_url.join(&location) {
Ok(u) => u,
Err(_) => {
let body_bytes_vec = resp.bytes().await.unwrap_or_default().to_vec();
let body = String::from_utf8_lossy(&body_bytes_vec).into_owned();
let mut headers_map: std::collections::HashMap<String, Vec<String>> =
std::collections::HashMap::new();
for (name, value) in headers.iter() {
if let Ok(v) = value.to_str() {
headers_map
.entry(name.as_str().to_lowercase())
.or_default()
.push(v.to_string());
}
}
return Ok(HttpResponse {
status,
content_type,
body,
body_bytes: body_bytes_vec,
headers: headers_map,
browser_extras: None,
final_url: final_url_str,
});
}
};
if let Err(e) = validate_url(&next_url, &config.ssrf).await {
return Err(CrawlError::SsrfPolicyViolation {
url: next_url.to_string(),
reason: e.to_string(),
});
}
redirects_followed += 1;
if redirects_followed > config.ssrf.max_redirects {
return Err(CrawlError::SsrfPolicyViolation {
url: next_url.to_string(),
reason: "too many redirects".to_string(),
});
}
current_url = next_url;
continue;
}
}
match status {
401 => return Err(CrawlError::Unauthorized("unauthorized".into())),
403 => {
let body = resp.text().await.unwrap_or_default();
let partial_response = build_partial_response(status, &body, &headers);
let classifier = TomlClassifier::builtin();
if let Ok(Some(signal)) = classifier.classify(&partial_response) {
return Err(CrawlError::WafBlocked {
vendor: signal.vendor.clone(),
message: format!("waf/blocked detected: {}", signal.vendor),
});
}
return Err(CrawlError::Forbidden("forbidden".into()));
}
404 => return Err(CrawlError::NotFound(format!("not_found: {url}"))),
408 => return Err(CrawlError::Timeout("timeout: request timed out".into())),
410 => return Err(CrawlError::Gone("gone".into())),
429 => return Err(CrawlError::RateLimited("rate_limited".into())),
500 => return Err(CrawlError::ServerError("server_error".into())),
502 => return Err(CrawlError::BadGateway("bad_gateway".into())),
503 => {
return Err(CrawlError::ServerError("server_error: service unavailable".into()));
}
_ => {}
}
if (200..300).contains(&status) {
let headers_only_response = build_partial_response(status, "", &headers);
let classifier = TomlClassifier::builtin();
if let Ok(Some(signal)) = classifier.classify(&headers_only_response) {
let body = resp.text().await.unwrap_or_default();
let partial_response = build_partial_response(status, &body, &headers);
let vendor = classifier
.classify(&partial_response)
.ok()
.flatten()
.map(|s| s.vendor)
.unwrap_or(signal.vendor);
return Err(CrawlError::WafBlocked {
message: format!("waf/blocked detected on 2xx (header): {vendor}"),
vendor,
});
}
}
let expected_len = headers
.get("content-length")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<usize>().ok());
let body_bytes = resp.bytes().await.map_err(|e| {
let chain = error_chain_string(&e);
let is_body_error = chain.contains("content-length")
|| chain.contains("truncat")
|| chain.contains("incomplete")
|| chain.contains("end of file")
|| chain.contains("body error")
|| chain.contains("body from connection")
|| chain.contains("decoding response body")
|| chain.contains("error decoding");
#[cfg(not(target_arch = "wasm32"))]
let is_body_error = is_body_error || e.is_body();
if is_body_error {
CrawlError::DataLoss(format!("data_loss: {e}"))
} else {
classify_reqwest_error(&e)
}
})?;
if let Some(expected) = expected_len
&& body_bytes.len() < expected
&& expected - body_bytes.len() > 100
{
return Err(CrawlError::DataLoss(format!(
"data_loss: expected {expected} bytes, got {}",
body_bytes.len()
)));
}
let body_bytes_vec = body_bytes.to_vec();
let body = String::from_utf8_lossy(&body_bytes_vec).into_owned();
if (200..300).contains(&status) {
let partial_response = build_partial_response_with_bytes(status, &body_bytes_vec, &body, &headers);
let classifier = TomlClassifier::builtin();
if let Ok(Some(signal)) = classifier.classify(&partial_response) {
return Err(CrawlError::WafBlocked {
vendor: signal.vendor.clone(),
message: format!("waf/blocked detected on 2xx (body): {}", signal.vendor),
});
}
}
let mut headers_map: std::collections::HashMap<String, Vec<String>> = std::collections::HashMap::new();
for (name, value) in headers.iter() {
if let Ok(v) = value.to_str() {
headers_map
.entry(name.as_str().to_lowercase())
.or_default()
.push(v.to_string());
}
}
return Ok(HttpResponse {
status,
content_type,
body,
body_bytes: body_bytes_vec,
headers: headers_map,
browser_extras: None,
final_url: final_url_str,
});
}
}
/// Builds the `reqwest::Client` used by `http_fetch`.
///
/// On non-wasm targets the client:
/// - never follows redirects itself (`http_fetch` follows them hop by hop),
/// - applies `config.request_timeout`,
/// - keeps a cookie store when `config.cookies_enabled` is set,
/// - routes requests through `config.proxy_provider` if one is set (queried per target
///   host), otherwise through the static `config.proxy`, if any.
///
/// On wasm32 none of the above is configured and builder defaults are used.
///
/// # Errors
///
/// [`CrawlError::InvalidConfig`] if the static proxy URL is rejected by reqwest, and
/// [`CrawlError::Other`] if the client fails to build.
#[cfg_attr(target_arch = "wasm32", allow(unused_variables, unused_mut))]
pub(crate) fn build_client(config: &CrawlConfig) -> Result<reqwest::Client, CrawlError> {
    let mut builder = reqwest::Client::builder();

    #[cfg(not(target_arch = "wasm32"))]
    {
        builder = builder
            .redirect(reqwest::redirect::Policy::none())
            .timeout(config.request_timeout);

        if config.cookies_enabled {
            builder = builder.cookie_store(true);
        }

        if let Some(provider) = config.proxy_provider.clone() {
            // A per-request choice: if the provider has nothing for this host, or its URL
            // does not parse, the request goes out without a proxy.
            let rotating = reqwest::Proxy::custom(move |target| {
                let chosen = provider.next_proxy(target.host_str().unwrap_or(""))?;
                let mut proxy_url = reqwest::Url::parse(&chosen.url).ok()?;
                if let (Some(user), Some(pass)) = (&chosen.username, &chosen.password) {
                    // Credentials travel in the proxy URL's userinfo.
                    let _ = proxy_url.set_username(user);
                    let _ = proxy_url.set_password(Some(pass));
                }
                Some(proxy_url)
            });
            builder = builder.proxy(rotating);
        } else if let Some(proxy_config) = &config.proxy {
            let mut fixed = reqwest::Proxy::all(&proxy_config.url)
                .map_err(|e| CrawlError::InvalidConfig(format!("invalid proxy URL: {e}")))?;
            if let (Some(user), Some(pass)) = (&proxy_config.username, &proxy_config.password) {
                fixed = fixed.basic_auth(user, pass);
            }
            builder = builder.proxy(fixed);
        }
    }

    builder
        .build()
        .map_err(|e| CrawlError::Other(format!("Failed to build HTTP client: {e}")))
}
/// Calls [`http_fetch`], retrying retryable failures with exponential backoff.
///
/// At most `config.retry_count` retries are made, i.e. up to `retry_count + 1` attempts in
/// total. A failure is retried only when its status is listed in `config.retry_codes`:
/// - `ServerError` (which `http_fetch` raises for 500 and 503) when 500 or 503 is listed,
/// - `BadGateway` (502) when 502 is listed,
/// - `RateLimited` (429) when 429 is listed.
///
/// The delay before retry `n` (0-based) is `100ms * 2^n`. The arithmetic saturates instead
/// of overflowing, so very large retry counts cannot panic.
///
/// # Errors
///
/// Returns the first non-retryable error, or the error from the final attempt.
pub(crate) async fn fetch_with_retry(
    url: &str,
    config: &CrawlConfig,
    extra_headers: &std::collections::HashMap<String, String>,
    client: &reqwest::Client,
) -> Result<HttpResponse, CrawlError> {
    let retries = config.retry_count;
    let retry_codes = &config.retry_codes;
    for attempt in 0..=retries {
        let err = match http_fetch(url, config, extra_headers, client).await {
            Ok(resp) => return Ok(resp),
            Err(e) => e,
        };
        let should_retry = match &err {
            CrawlError::ServerError(_) => retry_codes.contains(&503) || retry_codes.contains(&500),
            CrawlError::BadGateway(_) => retry_codes.contains(&502),
            CrawlError::RateLimited(_) => retry_codes.contains(&429),
            _ => false,
        };
        if !should_retry || attempt >= retries {
            return Err(err);
        }
        // 100ms * 2^attempt, saturating at u64::MAX milliseconds instead of overflowing.
        let shift = u32::try_from(attempt).unwrap_or(u32::MAX);
        let factor = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
        tokio::time::sleep(Duration::from_millis(100u64.saturating_mul(factor))).await;
    }
    // Only reachable if the attempt range is empty (e.g. a negative signed retry count).
    Err(CrawlError::Other("retry exhausted".into()))
}
/// Parses `Set-Cookie` values from a lowercase-keyed header map into [`CookieInfo`]s.
///
/// Each value contributes at most one cookie. Its first `;`-separated segment must contain
/// `=`; otherwise the value is skipped. Name and value are split at the first `=` and
/// trimmed. Of the attributes, only `Domain` and `Path` are kept, and the last occurrence
/// wins; everything else (`Expires`, `Secure`, `HttpOnly`, ...) is ignored.
///
/// Attribute names match case-insensitively. The `Domain` value is lowercased because
/// domains are case-insensitive. The `Path` value keeps its original case, because cookie
/// path matching is case-sensitive (RFC 6265 §5.1.4).
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn extract_cookies_from_hashmap(
    headers: &std::collections::HashMap<String, Vec<String>>,
) -> Vec<CookieInfo> {
    let Some(values) = headers.get("set-cookie") else {
        return Vec::new();
    };
    let mut cookies = Vec::with_capacity(values.len());
    for raw in values {
        let mut segments = raw.split(';');
        let Some((name, value)) = segments.next().and_then(|nv| nv.split_once('=')) else {
            continue;
        };
        let mut cookie = CookieInfo {
            name: name.trim().to_owned(),
            value: value.trim().to_owned(),
            domain: None,
            path: None,
        };
        for attr in segments {
            // Valueless attributes (e.g. `Secure`) carry nothing we record.
            let Some((attr_name, attr_value)) = attr.split_once('=') else {
                continue;
            };
            let attr_name = attr_name.trim();
            let attr_value = attr_value.trim();
            if attr_name.eq_ignore_ascii_case("domain") {
                cookie.domain = Some(attr_value.to_lowercase());
            } else if attr_name.eq_ignore_ascii_case("path") {
                cookie.path = Some(attr_value.to_owned());
            }
        }
        cookies.push(cookie);
    }
    cookies
}
/// Collects common caching and server metadata from a lowercase-keyed header map.
///
/// Each field takes the first recorded value of its header. It is `None` when the header is
/// absent or has no values.
pub(crate) fn extract_response_meta_from_hashmap(
    headers: &std::collections::HashMap<String, Vec<String>>,
) -> ResponseMeta {
    let first_value = |key: &str| -> Option<String> { headers.get(key)?.first().cloned() };
    ResponseMeta {
        etag: first_value("etag"),
        last_modified: first_value("last-modified"),
        cache_control: first_value("cache-control"),
        server: first_value("server"),
        x_powered_by: first_value("x-powered-by"),
        content_language: first_value("content-language"),
        content_encoding: first_value("content-encoding"),
    }
}
/// Builds a synthetic [`HttpResponse`] from a text body, used as WAF-classifier input.
///
/// Delegates to `build_partial_response_with_bytes` with the UTF-8 bytes of `body`, so
/// `content_type` and `final_url` are left empty.
fn build_partial_response(status: u16, body: &str, headers: &HeaderMap) -> HttpResponse {
    // The callee copies the bytes itself; borrowing avoids a throwaway intermediate Vec.
    build_partial_response_with_bytes(status, body.as_bytes(), body, headers)
}
/// Builds a synthetic [`HttpResponse`] from explicit body bytes, body text, and headers.
///
/// Header names are lowercased. Values that are not valid header text are skipped.
/// `content_type` and `final_url` are empty, and `browser_extras` is `None`.
fn build_partial_response_with_bytes(
    status: u16,
    body_bytes: &[u8],
    body: &str,
    headers: &HeaderMap,
) -> HttpResponse {
    let lowered = headers.iter().fold(
        HashMap::<String, Vec<String>>::new(),
        |mut acc, (name, value)| {
            if let Ok(text) = value.to_str() {
                acc.entry(name.as_str().to_lowercase())
                    .or_default()
                    .push(text.to_owned());
            }
            acc
        },
    );
    HttpResponse {
        status,
        content_type: String::new(),
        body: body.to_owned(),
        body_bytes: body_bytes.to_vec(),
        headers: lowered,
        browser_extras: None,
        final_url: String::new(),
    }
}
/// Names the WAF vendor that matches a `Server` header value and a response body.
///
/// Runs the builtin TOML classifier on a synthetic 403 response. That response has a
/// `server` header only when `server` is non-empty. Returns `"unknown"` when nothing
/// matches or classification fails.
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn detect_waf_vendor(server: &str, body: &str) -> String {
    let mut synthetic_headers: HashMap<String, Vec<String>> = HashMap::new();
    if !server.is_empty() {
        synthetic_headers.insert("server".to_string(), vec![server.to_string()]);
    }
    let synthetic = HttpResponse {
        status: 403,
        content_type: String::new(),
        body: body.to_string(),
        body_bytes: body.as_bytes().to_vec(),
        headers: synthetic_headers,
        browser_extras: None,
        final_url: String::new(),
    };
    match TomlClassifier::builtin().classify(&synthetic) {
        Ok(Some(signal)) => signal.vendor,
        _ => "unknown".to_string(),
    }
}
/// Reports whether the builtin WAF classifier matches a server value, body, and headers.
///
/// The headers are copied with lowercased names. Values of names that differ only in case
/// (e.g. `Server` and `server`) are merged rather than one overwriting the other. A
/// non-empty `server` is appended to the `server` values. The synthetic response has status
/// 403. A classification error counts as "not blocked".
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn is_waf_blocked(server: &str, body: &str, headers: &HashMap<String, Vec<String>>) -> bool {
    let mut headers_map: HashMap<String, Vec<String>> = HashMap::with_capacity(headers.len() + 1);
    for (name, values) in headers {
        // `entry` + `extend` keeps every value; a plain `insert` would silently drop one
        // side of a case-only key collision (which side depends on hash iteration order).
        headers_map
            .entry(name.to_lowercase())
            .or_default()
            .extend(values.iter().cloned());
    }
    if !server.is_empty() {
        headers_map
            .entry("server".to_string())
            .or_default()
            .push(server.to_string());
    }
    let response = HttpResponse {
        status: 403,
        content_type: String::new(),
        body: body.to_string(),
        body_bytes: body.as_bytes().to_vec(),
        headers: headers_map,
        browser_extras: None,
        final_url: String::new(),
    };
    TomlClassifier::builtin().classify(&response).ok().flatten().is_some()
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// A successful fetch must report the URL it ended up at in `final_url`.
    #[tokio::test]
    async fn http_fetch_populates_final_url() {
        let server = MockServer::start().await;
        let page = ResponseTemplate::new(200)
            .set_body_string("<html><body>Hello</body></html>")
            .append_header("content-type", "text/html");
        Mock::given(method("GET"))
            .and(path("/page"))
            .respond_with(page)
            .mount(&server)
            .await;

        // The mock server runs locally, so private-address blocking must be switched off.
        let mut config = CrawlConfig::default();
        config.ssrf.deny_private = false;
        let client = build_client(&config).expect("client must build");

        let target = format!("{}/page", server.uri());
        let no_extra_headers = std::collections::HashMap::new();
        let fetched = http_fetch(&target, &config, &no_extra_headers, &client)
            .await
            .expect("http_fetch must succeed");

        let final_url = &fetched.final_url;
        assert!(
            !final_url.is_empty(),
            "final_url must not be empty after a successful fetch"
        );
        assert!(
            final_url.contains("/page"),
            "final_url must contain the requested path, got: {}",
            final_url
        );
    }
}